diff --git a/.bazelrc b/.bazelrc index 554440cfe3d..53485cb9743 100644 --- a/.bazelrc +++ b/.bazelrc @@ -1 +1 @@ -build --cxxopt=-std=c++14 --host_cxxopt=-std=c++14 +build --cxxopt=-std=c++17 --host_cxxopt=-std=c++17 diff --git a/.gemini/config.yaml b/.gemini/config.yaml new file mode 100644 index 00000000000..71adf793964 --- /dev/null +++ b/.gemini/config.yaml @@ -0,0 +1,13 @@ +have_fun: false +memory_config: + disabled: false +code_review: + disable: false + comment_severity_threshold: MEDIUM + max_review_comments: -1 + pull_request_opened: + help: false + summary: false + code_review: false + include_drafts: false +ignore_patterns: [] diff --git a/.github/workflows/branch-testing.yml b/.github/workflows/branch-testing.yml new file mode 100644 index 00000000000..ece8ec4cd58 --- /dev/null +++ b/.github/workflows/branch-testing.yml @@ -0,0 +1,41 @@ +name: GitHub Actions Branch Testing + +on: + push: + branches: + - master + - 'v1.*' + schedule: + - cron: '54 19 * * SUN' # weekly at a "random" time + +permissions: + contents: read + +jobs: + arm64: + runs-on: ubuntu-24.04-arm + strategy: + matrix: + jre: [17] + fail-fast: false # Should swap to true if we grow a large matrix + + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-java@v4 + with: + java-version: ${{ matrix.jre }} + distribution: 'temurin' + + - name: Gradle cache + uses: actions/cache@v4 + with: + path: | + ~/.gradle/caches + ~/.gradle/wrapper + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*', '**/gradle-wrapper.properties') }} + restore-keys: | + ${{ runner.os }}-gradle- + + - name: Build + run: ./gradlew -Dorg.gradle.parallel=true -Dorg.gradle.jvmargs='-Xmx1g' -PskipAndroid=true -PskipCodegen=true -PerrorProne=false test + diff --git a/.github/workflows/gradle-wrapper-validation.yml b/.github/workflows/gradle-wrapper-validation.yml index b827468d719..da1e2fed114 100644 --- a/.github/workflows/gradle-wrapper-validation.yml +++ b/.github/workflows/gradle-wrapper-validation.yml @@ -9,5 +9,5 @@ jobs: name: "Gradle wrapper validation" runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - uses: gradle/wrapper-validation-action@v1 + - uses: actions/checkout@v4 + - uses: gradle/actions/wrapper-validation@v4 diff --git a/.github/workflows/lock.yml b/.github/workflows/lock.yml index 907a9dad2b5..3070a1a2f7c 100644 --- a/.github/workflows/lock.yml +++ b/.github/workflows/lock.yml @@ -13,7 +13,7 @@ jobs: lock: runs-on: ubuntu-latest steps: - - uses: dessant/lock-threads@v4 + - uses: dessant/lock-threads@v5 with: github-token: ${{ github.token }} issue-inactive-days: 90 diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index adc02dc3519..953edd12e04 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -17,18 +17,18 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - jre: [8, 11, 17] + jre: [8, 11, 17, 21] fail-fast: false # Should swap to true if we grow a large matrix steps: - - uses: actions/checkout@v3 - - uses: actions/setup-java@v3 + - uses: actions/checkout@v4 + - uses: actions/setup-java@v4 with: java-version: ${{ matrix.jre }} distribution: 'temurin' - name: Gradle cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: | ~/.gradle/caches @@ -37,7 +37,7 @@ jobs: restore-keys: | ${{ runner.os }}-gradle- - name: Maven cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: | ~/.m2/repository @@ -46,7 +46,7 @@ jobs: restore-keys: | ${{ runner.os }}-maven- - name: Protobuf cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: /tmp/protobuf-cache key: ${{ runner.os }}-maven-${{ hashFiles('buildscripts/make_dependencies.sh') }} @@ -55,7 +55,7 @@ jobs: run: buildscripts/kokoro/unix.sh - name: Post Failure Upload Test Reports to Artifacts if: ${{ failure() }} - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: Test Reports (JRE ${{ matrix.jre }}) path: | @@ -71,18 +71,32 @@ jobs: COVERALLS_REPO_TOKEN: ${{ secrets.COVERALLS_REPO_TOKEN }} run: ./gradlew :grpc-all:coveralls -PskipAndroid=true -x compileJava - name: Codecov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 + with: + token: ${{ secrets.CODECOV_TOKEN }} bazel: runs-on: ubuntu-latest + strategy: + matrix: + bzlmod: [true, false] + bazel_version: [8.7.0, 9.1.0] + exclude: + - bazel_version: 9.1.0 + bzlmod: false env: - USE_BAZEL_VERSION: 6.0.0 + USE_BAZEL_VERSION: ${{ matrix.bazel_version }} steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 + + - name: Check versions match in MODULE.bazel and repositories.bzl + run: | + diff -u <(sed -n '/GRPC_DEPS_START/,/GRPC_DEPS_END/ {/GRPC_DEPS_/! p}' MODULE.bazel) \ + <(sed -n '/GRPC_DEPS_START/,/GRPC_DEPS_END/ {/GRPC_DEPS_/! p}' repositories.bzl) - name: Bazel cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: | ~/.cache/bazel/*/cache @@ -90,4 +104,11 @@ jobs: key: ${{ runner.os }}-bazel-${{ env.USE_BAZEL_VERSION }}-${{ hashFiles('WORKSPACE', 'repositories.bzl') }} - name: Run bazel build - run: bazelisk build //... + run: bazelisk build //... --enable_bzlmod=${{ matrix.bzlmod }} --enable_workspace=${{ !matrix.bzlmod }} + + - name: Run bazel test + run: bazelisk test //... --enable_bzlmod=${{ matrix.bzlmod }} --enable_workspace=${{ !matrix.bzlmod }} + + - name: Run example bazel build + run: bazelisk build //... --enable_bzlmod=${{ matrix.bzlmod }} --enable_workspace=${{ !matrix.bzlmod }} + working-directory: ./examples diff --git a/.gitignore b/.gitignore index 9fd0d7fb574..b078d891adf 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ bazel-genfiles bazel-grpc-java bazel-out bazel-testlogs +MODULE.bazel.lock # IntelliJ IDEA .idea @@ -30,6 +31,9 @@ bazel-testlogs .gitignore bin +# VsCode +.vscode + # OS X .DS_Store diff --git a/BUILD.bazel b/BUILD.bazel index 40c04022673..27a99fb62eb 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@com_google_protobuf//bazel:java_proto_library.bzl", "java_proto_library") +load("@rules_java//java:defs.bzl", "java_library", "java_plugin") +load("@rules_jvm_external//:defs.bzl", "artifact") load(":java_grpc_library.bzl", "java_grpc_library") java_proto_library( @@ -32,10 +35,9 @@ java_library( "//api", "//protobuf", "//stub", - "//stub:javax_annotation", - "@com_google_code_findbugs_jsr305//jar", - "@com_google_guava_guava//jar", "@com_google_protobuf//:protobuf_java", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.guava:guava"), ], ) @@ -46,9 +48,8 @@ java_library( "//api", "//protobuf-lite", "//stub", - "//stub:javax_annotation", - "@com_google_code_findbugs_jsr305//jar", - "@com_google_guava_guava//jar", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.guava:guava"), ], ) @@ -56,7 +57,7 @@ java_plugin( name = "auto_value", generates_api = 1, processor_class = "com.google.auto.value.processor.AutoValueProcessor", - deps = ["@com_google_auto_value_auto_value//jar"], + deps = [artifact("com.google.auto.value:auto-value")], ) java_library( @@ -65,7 +66,6 @@ java_library( neverlink = 1, visibility = ["//:__subpackages__"], exports = [ - "@com_google_auto_value_auto_value_annotations//jar", - "@org_apache_tomcat_annotations_api//jar", # @Generated for Java 9+ + artifact("com.google.auto.value:auto-value-annotations"), ], ) diff --git a/COMPILING.md b/COMPILING.md index de3cbb026c1..b7df1319beb 100644 --- a/COMPILING.md +++ b/COMPILING.md @@ -44,11 +44,11 @@ This section is only necessary if you are making changes to the code generation. Most users only need to use `skipCodegen=true` as discussed above. ### Build Protobuf -The codegen plugin is C++ code and requires protobuf 21.7 or later. +The codegen plugin is C++ code and requires protobuf 22.5 or later. For Linux, Mac and MinGW: ``` -$ PROTOBUF_VERSION=21.7 +$ PROTOBUF_VERSION=22.5 $ curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v$PROTOBUF_VERSION/protobuf-all-$PROTOBUF_VERSION.tar.gz $ tar xzf protobuf-all-$PROTOBUF_VERSION.tar.gz $ cd protobuf-$PROTOBUF_VERSION diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ce40827e748..646a7d986fd 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -30,43 +30,36 @@ style configurations are commonly useful. For IntelliJ 14, copy the style to `~/.IdeaIC14/config/codestyles/`, start IntelliJ, go to File > Settings > Code Style, and set the Scheme to `GoogleStyle`. -## Maintaining clean commit history - -We have few conventions for keeping history clean and making code reviews easier -for reviewers: - -* First line of commit messages should be in format of - - `package-name: summary of change` - - where the summary finishes the sentence: `This commit improves gRPC to ____________.` - - for example: - - `core,netty,interop-testing: add capacitive duractance to turbo encabulators` - -* Every time you receive a feedback on your pull request, push changes that - address it as a separate one or multiple commits with a descriptive commit - message (try avoid using vauge `addressed pr feedback` type of messages). - - Project maintainers are obligated to squash those commits into one when - merging. - ## Guidelines for Pull Requests How to get your contributions merged smoothly and quickly. - Create **small PRs** that are narrowly focused on **addressing a single concern**. We often times receive PRs that are trying to fix several things at a time, but only one fix is considered acceptable, nothing gets merged and both author's & review's time is wasted. Create more PRs to address different concerns and everyone will be happy. -- For speculative changes, consider opening an issue and discussing it first. If you are suggesting a behavioral or API change, consider starting with a [gRFC proposal](https://github.com/grpc/proposal). - -- Provide a good **PR description** as a record of **what** change is being made and **why** it was made. Link to a github issue if it exists. - -- Don't fix code style and formatting unless you are already changing that line to address an issue. PRs with irrelevant changes won't be merged. If you do want to fix formatting or style, do that in a separate PR. - -- Unless your PR is trivial, you should expect there will be reviewer comments that you'll need to address before merging. We expect you to be reasonably responsive to those comments, otherwise the PR will be closed after 2-3 weeks of inactivity. - -- Maintain **clean commit history** and use **meaningful commit messages**. See [maintaining clean commit history](#maintaining-clean-commit-history) for details. - +- For speculative changes, consider opening an issue and discussing it to avoid + wasting time on an inappropriate approach. If you are suggesting a behavioral + or API change, consider starting with a [gRFC + proposal](https://github.com/grpc/proposal). + +- Follow [typical Git commit message](https://cbea.ms/git-commit/#seven-rules) + structure. Have a good **commit description** as a record of **what** and + **why** the change is being made. Link to a GitHub issue if it exists. The + commit description makes a good PR description and is auto-copied by GitHub if + you have a single commit when creating the PR. + + If your change is mostly for a single module (e.g., other module changes are + trivial), prefix your commit summary with the module name changed. Instead of + "Add HTTP/2 faster-than-light support to gRPC Netty" it is more terse as + "netty: Add faster-than-light support". + +- Don't fix code style and formatting unless you are already changing that line + to address an issue. If you do want to fix formatting or style, do that in a + separate PR. + +- Unless your PR is trivial, you should expect there will be reviewer comments + that you'll need to address before merging. Address comments with additional + commits so the reviewer can review just the changes; do not squash reviewed + commits unless the reviewer agrees. PRs are squashed when merging. + - Keep your PR up to date with upstream/master (if there are merge conflicts, we can't really merge your change). - **All tests need to be passing** before your change can be merged. We recommend you **run tests locally** before creating your PR to catch breakages early on. Also, `./gradlew build` (`gradlew build` on Windows) **must not introduce any new warnings**. diff --git a/MAINTAINERS.md b/MAINTAINERS.md index f05542e1987..5048c7c5aca 100644 --- a/MAINTAINERS.md +++ b/MAINTAINERS.md @@ -10,25 +10,26 @@ for general contribution guidelines. ## Maintainers (in alphabetical order) - [ejona86](https://github.com/ejona86), Google LLC - [jdcormie](https://github.com/jdcormie), Google LLC -- [larry-safran](https://github.com/larry-safran), Google LLC -- [markb74](https://github.com/markb74), Google LLC +- [kannanjgithub](https://github.com/kannanjgithub), Google LLC - [ran-su](https://github.com/ran-su), Google LLC -- [sanjaypujare](https://github.com/sanjaypujare), Google LLC - [sergiitk](https://github.com/sergiitk), Google LLC - [temawi](https://github.com/temawi), Google LLC - [YifeiZhuang](https://github.com/YifeiZhuang), Google LLC - [zhangkun83](https://github.com/zhangkun83), Google LLC ## Emeritus Maintainers (in alphabetical order) -- [carl-mastrangelo](https://github.com/carl-mastrangelo), Google LLC -- [creamsoup](https://github.com/creamsoup), Google LLC -- [dapengzhang0](https://github.com/dapengzhang0), Google LLC -- [ericgribkoff](https://github.com/ericgribkoff), Google LLC -- [jiangtaoli2016](https://github.com/jiangtaoli2016), Google LLC -- [jtattermusch](https://github.com/jtattermusch), Google LLC -- [louiscryan](https://github.com/louiscryan), Google LLC -- [nicolasnoble](https://github.com/nicolasnoble), Google LLC -- [nmittler](https://github.com/nmittler), Google LLC -- [srini100](https://github.com/srini100), Google LLC -- [voidzcy](https://github.com/voidzcy), Google LLC -- [zpencer](https://github.com/zpencer), Google LLC +- [carl-mastrangelo](https://github.com/carl-mastrangelo) +- [creamsoup](https://github.com/creamsoup) +- [dapengzhang0](https://github.com/dapengzhang0) +- [ericgribkoff](https://github.com/ericgribkoff) +- [jiangtaoli2016](https://github.com/jiangtaoli2016) +- [jtattermusch](https://github.com/jtattermusch) +- [larry-safran](https://github.com/larry-safran) +- [louiscryan](https://github.com/louiscryan) +- [markb74](https://github.com/markb74) +- [nicolasnoble](https://github.com/nicolasnoble) +- [nmittler](https://github.com/nmittler) +- [sanjaypujare](https://github.com/sanjaypujare) +- [srini100](https://github.com/srini100) +- [voidzcy](https://github.com/voidzcy) +- [zpencer](https://github.com/zpencer) diff --git a/MODULE.bazel b/MODULE.bazel new file mode 100644 index 00000000000..0ad9acb28aa --- /dev/null +++ b/MODULE.bazel @@ -0,0 +1,160 @@ +module( + name = "grpc-java", + version = "1.82.0-SNAPSHOT", # CURRENT_GRPC_VERSION + repo_name = "io_grpc_grpc_java", +) + +# GRPC_DEPS_START +IO_GRPC_GRPC_JAVA_ARTIFACTS = [ + "com.google.android:annotations:4.1.1.4", + "com.google.api.grpc:proto-google-common-protos:2.64.1", + "com.google.auth:google-auth-library-credentials:1.42.1", + "com.google.auth:google-auth-library-oauth2-http:1.42.1", + "com.google.auto.value:auto-value-annotations:1.11.0", + "com.google.auto.value:auto-value:1.11.0", + "com.google.code.findbugs:jsr305:3.0.2", + "com.google.code.gson:gson:2.13.2", + "com.google.errorprone:error_prone_annotations:2.48.0", + "com.google.guava:failureaccess:1.0.1", + "com.google.guava:guava:33.5.0-android", + "com.google.re2j:re2j:1.8", + "com.google.s2a.proto.v2:s2a-proto:0.1.3", + "com.google.truth:truth:1.4.5", + "com.squareup.okhttp:okhttp:2.7.5", + "com.squareup.okio:okio:2.10.0", # 3.0+ needs swapping to -jvm; need work to avoid flag-day + "io.netty:netty-buffer:4.1.133.Final", + "io.netty:netty-codec-http2:4.1.133.Final", + "io.netty:netty-codec-http:4.1.133.Final", + "io.netty:netty-codec-socks:4.1.133.Final", + "io.netty:netty-codec:4.1.133.Final", + "io.netty:netty-common:4.1.133.Final", + "io.netty:netty-handler-proxy:4.1.133.Final", + "io.netty:netty-handler:4.1.133.Final", + "io.netty:netty-resolver:4.1.133.Final", + "io.netty:netty-tcnative-boringssl-static:2.0.75.Final", + "io.netty:netty-tcnative-classes:2.0.75.Final", + "io.netty:netty-transport-native-epoll:jar:linux-x86_64:4.1.133.Final", + "io.netty:netty-transport-native-unix-common:4.1.133.Final", + "io.netty:netty-transport:4.1.133.Final", + "io.opencensus:opencensus-api:0.31.0", + "io.opencensus:opencensus-contrib-grpc-metrics:0.31.0", + "io.perfmark:perfmark-api:0.27.0", + "junit:junit:4.13.2", + "org.mockito:mockito-core:4.4.0", + "org.checkerframework:checker-qual:3.49.5", + "org.codehaus.mojo:animal-sniffer-annotations:1.27", +] +# GRPC_DEPS_END + +bazel_dep(name = "abseil-cpp", version = "20250512.1") +bazel_dep(name = "bazel_jar_jar", version = "0.1.11.bcr.1") +bazel_dep(name = "bazel_skylib", version = "1.7.1") +bazel_dep(name = "googleapis", version = "0.0.0-20260514-1dbb1a14", repo_name = "com_google_googleapis") +bazel_dep(name = "grpc-proto", version = "0.0.0-20240627-ec30f58.bcr.1", repo_name = "io_grpc_grpc_proto") +bazel_dep(name = "protobuf", version = "33.4", repo_name = "com_google_protobuf") +bazel_dep(name = "rules_cc", version = "0.0.9") +bazel_dep(name = "rules_java", version = "9.1.0") +bazel_dep(name = "rules_jvm_external", version = "6.0") + +maven = use_extension("@rules_jvm_external//:extensions.bzl", "maven") +maven.install( + artifacts = IO_GRPC_GRPC_JAVA_ARTIFACTS, + repositories = [ + "https://repo.maven.apache.org/maven2/", + ], + strict_visibility = True, +) +use_repo(maven, "maven") + +maven.override( + coordinates = "com.google.protobuf:protobuf-java", + target = "@com_google_protobuf//:protobuf_java", +) +maven.override( + coordinates = "com.google.protobuf:protobuf-java-util", + target = "@com_google_protobuf//:protobuf_java_util", +) +maven.override( + coordinates = "com.google.protobuf:protobuf-javalite", + target = "@com_google_protobuf//:protobuf_javalite", +) +maven.override( + coordinates = "io.grpc:grpc-alts", + target = "@io_grpc_grpc_java//alts", +) +maven.override( + coordinates = "io.grpc:grpc-api", + target = "@io_grpc_grpc_java//api", +) +maven.override( + coordinates = "io.grpc:grpc-auth", + target = "@io_grpc_grpc_java//auth", +) +maven.override( + coordinates = "io.grpc:grpc-census", + target = "@io_grpc_grpc_java//census", +) +maven.override( + coordinates = "io.grpc:grpc-context", + target = "@io_grpc_grpc_java//context", +) +maven.override( + coordinates = "io.grpc:grpc-core", + target = "@io_grpc_grpc_java//core:core_maven", +) +maven.override( + coordinates = "io.grpc:grpc-googleapis", + target = "@io_grpc_grpc_java//googleapis", +) +maven.override( + coordinates = "io.grpc:grpc-grpclb", + target = "@io_grpc_grpc_java//grpclb", +) +maven.override( + coordinates = "io.grpc:grpc-inprocess", + target = "@io_grpc_grpc_java//inprocess", +) +maven.override( + coordinates = "io.grpc:grpc-netty", + target = "@io_grpc_grpc_java//netty", +) +maven.override( + coordinates = "io.grpc:grpc-netty-shaded", + target = "@io_grpc_grpc_java//netty:shaded_maven", +) +maven.override( + coordinates = "io.grpc:grpc-okhttp", + target = "@io_grpc_grpc_java//okhttp", +) +maven.override( + coordinates = "io.grpc:grpc-protobuf", + target = "@io_grpc_grpc_java//protobuf", +) +maven.override( + coordinates = "io.grpc:grpc-protobuf-lite", + target = "@io_grpc_grpc_java//protobuf-lite", +) +maven.override( + coordinates = "io.grpc:grpc-rls", + target = "@io_grpc_grpc_java//rls", +) +maven.override( + coordinates = "io.grpc:grpc-services", + target = "@io_grpc_grpc_java//services:services_maven", +) +maven.override( + coordinates = "io.grpc:grpc-stub", + target = "@io_grpc_grpc_java//stub", +) +maven.override( + coordinates = "io.grpc:grpc-testing", + target = "@io_grpc_grpc_java//testing", +) +maven.override( + coordinates = "io.grpc:grpc-xds", + target = "@io_grpc_grpc_java//xds:xds_maven", +) +maven.override( + coordinates = "io.grpc:grpc-util", + target = "@io_grpc_grpc_java//util", +) diff --git a/README.md b/README.md index f29f35781e8..8e6620c927e 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ gRPC-Java - An RPC library and framework Supported Platforms ------------------- -gRPC-Java supports Java 8 and later. Android minSdkVersion 21 (Lollipop) and +gRPC-Java supports Java 8 and later. Android minSdkVersion 23 (Marshmallow) and later are supported with [Java 8 language desugaring][android-java-8]. TLS usage on Android typically requires Play Services Dynamic Security Provider. @@ -44,8 +44,8 @@ For a guided tour, take a look at the [quick start guide](https://grpc.io/docs/languages/java/quickstart) or the more explanatory [gRPC basics](https://grpc.io/docs/languages/java/basics). -The [examples](https://github.com/grpc/grpc-java/tree/v1.62.2/examples) and the -[Android example](https://github.com/grpc/grpc-java/tree/v1.62.2/examples/android) +The [examples](https://github.com/grpc/grpc-java/tree/v1.81.0/examples) and the +[Android example](https://github.com/grpc/grpc-java/tree/v1.81.0/examples/android) are standalone projects that showcase the usage of gRPC. Download @@ -56,42 +56,34 @@ Download [the JARs][]. Or for Maven with non-Android, add to your `pom.xml`: io.grpc grpc-netty-shaded - 1.62.2 + 1.81.0 runtime io.grpc grpc-protobuf - 1.62.2 + 1.81.0 io.grpc grpc-stub - 1.62.2 - - - org.apache.tomcat - annotations-api - 6.0.53 - provided + 1.81.0 ``` Or for Gradle with non-Android, add to your dependencies: ```gradle -runtimeOnly 'io.grpc:grpc-netty-shaded:1.62.2' -implementation 'io.grpc:grpc-protobuf:1.62.2' -implementation 'io.grpc:grpc-stub:1.62.2' -compileOnly 'org.apache.tomcat:annotations-api:6.0.53' // necessary for Java 9+ +runtimeOnly 'io.grpc:grpc-netty-shaded:1.81.0' +implementation 'io.grpc:grpc-protobuf:1.81.0' +implementation 'io.grpc:grpc-stub:1.81.0' ``` For Android client, use `grpc-okhttp` instead of `grpc-netty-shaded` and `grpc-protobuf-lite` instead of `grpc-protobuf`: ```gradle -implementation 'io.grpc:grpc-okhttp:1.62.2' -implementation 'io.grpc:grpc-protobuf-lite:1.62.2' -implementation 'io.grpc:grpc-stub:1.62.2' -compileOnly 'org.apache.tomcat:annotations-api:6.0.53' // necessary for Java 9+ +implementation 'io.grpc:grpc-okhttp:1.81.0' +implementation 'io.grpc:grpc-protobuf-lite:1.81.0' +implementation 'io.grpc:grpc-stub:1.81.0' ``` For [Bazel](https://bazel.build), you can either @@ -99,10 +91,10 @@ For [Bazel](https://bazel.build), you can either (with the GAVs from above), or use `@io_grpc_grpc_java//api` et al (see below). [the JARs]: -https://search.maven.org/search?q=g:io.grpc%20AND%20v:1.62.2 +https://search.maven.org/search?q=g:io.grpc%20AND%20v:1.81.0 Development snapshots are available in [Sonatypes's snapshot -repository](https://oss.sonatype.org/content/repositories/snapshots/). +repository](https://central.sonatype.com/repository/maven-snapshots/). Generated Code -------------- @@ -129,9 +121,9 @@ For protobuf-based codegen integrated with the Maven build system, you can use protobuf-maven-plugin 0.6.1 - com.google.protobuf:protoc:3.25.1:exe:${os.detected.classifier} + com.google.protobuf:protoc:3.25.8:exe:${os.detected.classifier} grpc-java - io.grpc:protoc-gen-grpc-java:1.62.2:exe:${os.detected.classifier} + io.grpc:protoc-gen-grpc-java:1.81.0:exe:${os.detected.classifier} @@ -152,16 +144,16 @@ For non-Android protobuf-based codegen integrated with the Gradle build system, you can use [protobuf-gradle-plugin][]: ```gradle plugins { - id 'com.google.protobuf' version '0.9.4' + id 'com.google.protobuf' version '0.9.5' } protobuf { protoc { - artifact = "com.google.protobuf:protoc:3.25.1" + artifact = "com.google.protobuf:protoc:3.25.8" } plugins { grpc { - artifact = 'io.grpc:protoc-gen-grpc-java:1.62.2' + artifact = 'io.grpc:protoc-gen-grpc-java:1.81.0' } } generateProtoTasks { @@ -185,16 +177,16 @@ use protobuf-gradle-plugin but specify the 'lite' options: ```gradle plugins { - id 'com.google.protobuf' version '0.9.4' + id 'com.google.protobuf' version '0.9.5' } protobuf { protoc { - artifact = "com.google.protobuf:protoc:3.25.1" + artifact = "com.google.protobuf:protoc:3.25.8" } plugins { grpc { - artifact = 'io.grpc:protoc-gen-grpc-java:1.62.2' + artifact = 'io.grpc:protoc-gen-grpc-java:1.81.0' } } generateProtoTasks { @@ -269,6 +261,9 @@ gRPC comes with multiple Transport implementations: 1. The Netty-based HTTP/2 transport is the main transport implementation based on [Netty](https://netty.io). It is not officially supported on Android. + There is a "grpc-netty-shaded" version of this transport. It is generally + preferred over using the Netty-based transport directly as it requires less + dependency management and is easier to upgrade within many applications. 2. The OkHttp-based HTTP/2 transport is a lightweight transport based on [Okio](https://square.github.io/okio/) and forked low-level parts of [OkHttp](https://square.github.io/okhttp/). It is mainly for use on Android. diff --git a/RELEASING.md b/RELEASING.md index ff43823ed43..c57829b8c25 100644 --- a/RELEASING.md +++ b/RELEASING.md @@ -18,8 +18,10 @@ them before continuing, and set them again when resuming. ```bash MAJOR=1 MINOR=7 PATCH=0 # Set appropriately for new release VERSION_FILES=( + MODULE.bazel build.gradle core/src/main/java/io/grpc/internal/GrpcUtil.java + examples/MODULE.bazel examples/build.gradle examples/pom.xml examples/android/clientcache/app/build.gradle @@ -63,7 +65,7 @@ would be used to create all `v1.7` tags (e.g. `v1.7.0`, `v1.7.1`). ```bash git fetch upstream git checkout -b v$MAJOR.$MINOR.x \ - $(git log --pretty=format:%H --grep "^Start $MAJOR.$((MINOR+1)).0 development cycle$" upstream/master)^ + $(git log --pretty=format:%H --grep "^Start $MAJOR.$((MINOR+1)).0 development cycle" upstream/master)^ git push upstream v$MAJOR.$MINOR.x ``` 5. Continue with Google-internal steps at go/grpc-java/releasing, but stop @@ -130,7 +132,9 @@ Tagging the Release compiler/src/test{,Lite}/golden/Test{,Deprecated}Service.java.txt ./gradlew build git commit -a -m "Bump version to $MAJOR.$MINOR.$((PATCH+1))-SNAPSHOT" + git push -u origin release-v$MAJOR.$MINOR.$PATCH ``` + Raise a PR and set the base branch of the PR to v$MAJOR.$MINOR.x of the upstream grpc-java repo. 6. Go through PR review and push the release tag and updated release branch to GitHub (DO NOT click the merge button on the GitHub page): @@ -156,21 +160,21 @@ Tagging the Release repository can then be `released`, which will begin the process of pushing the new artifacts to Maven Central (the staging repository will be destroyed in the process). You can see the complete process for releasing to Maven - Central on the [OSSRH site](https://central.sonatype.org/pages/releasing-the-deployment.html). + Central on the [OSSRH site](https://central.sonatype.org/publish/publish-portal-ossrh-staging-api/#deploying). 10. We have containers for each release to detect compatibility regressions with old releases. Generate one for the new release by following the [GCR image generation instructions][gcr-image]. Summary: ```bash # If you haven't previously configured docker: - gcloud auth configure-docker + gcloud auth configure-docker us-docker.pkg.dev # In main grpc repo, add the new version to matrix ${EDITOR:-nano -w} tools/interop_matrix/client_matrix.py tools/interop_matrix/create_matrix_images.py --git_checkout --release=v$MAJOR.$MINOR.$PATCH \ --upload_images --language java - docker pull gcr.io/grpc-testing/grpc_interop_java:v$MAJOR.$MINOR.$PATCH - docker_image=gcr.io/grpc-testing/grpc_interop_java:v$MAJOR.$MINOR.$PATCH \ + docker pull us-docker.pkg.dev/grpc-testing/testing-images-public/grpc_interop_java:v$MAJOR.$MINOR.$PATCH + docker_image=us-docker.pkg.dev/grpc-testing/testing-images-public/grpc_interop_java:v$MAJOR.$MINOR.$PATCH \ tools/interop_matrix/testcases/java__master # Commit the changes @@ -202,31 +206,47 @@ Tagging the Release 12. Add [Release Notes](https://github.com/grpc/grpc-java/releases) for the new tag. *Make sure that any backports are reflected in the release notes.* +13. Notify the Community. Post a release announcement to + [grpc-io](https://groups.google.com/forum/#!forum/grpc-io) + (`grpc-io@googlegroups.com`) with the title `gRPC-Java v$MAJOR.$MINOR.$PATCH + Released`. The email content should link to the GitHub release notes and + include a copy of them. -Update README.md ----------------- -After waiting ~1 day and verifying that the release is indexed on [Maven -Central](https://search.maven.org/search?q=g:io.grpc), cherry-pick the commit -that updated the README into the master branch. +14. Update README.md. Cherry-pick the commit that updated the README.md into the + master branch. -```bash -git checkout -b bump-readme master -git cherry-pick v$MAJOR.$MINOR.$PATCH^ -git push --set-upstream origin bump-readme -``` + ```bash + git checkout -b bump-readme master + git cherry-pick v$MAJOR.$MINOR.$PATCH^ + git push --set-upstream origin bump-readme + ``` + + Create a PR and go through the review process -Create a PR and go through the review process +15. Update version referenced by tutorials. Update `params.grpc_vers.java` in + [config.yaml](https://github.com/grpc/grpc.io/blob/master/config.yaml) of + the grpc.io repository. Create a PR and go through the review process. -Update version referenced by tutorials --------------------------------------- +Post-release upgrades +--------------------- +Upgrade dependencies after the release so they can be well-tested before the +next release. + +Upgrade the Gradle plugins in `settings.gradle` and the Gradle version in +`gradle/wrapper/gradle-wrapper.properties`. Make sure to read the release notes +for each dependency upgraded. Test by doing a regular build. + +Upgrade the regular dependencies in `gradle/libs.versions.toml`, except for +Netty and netty-tcnative. To find available upgrades: + +```bash +./gradlew checkForUpdates +``` -Update `params.grpc_vers.java` in -[config.yaml](https://github.com/grpc/grpc.io/blob/master/config.yaml) -of the grpc.io repository. +Test by doing a regular build. For each step, if a dependency cannot be +upgraded, add a comment. Create issues in other projects for breakages, and in +gRPC for things that will need a migration effort. -Notify the Community --------------------- -Post a release announcement to [grpc-io](https://groups.google.com/forum/#!forum/grpc-io) -(`grpc-io@googlegroups.com`) with the title `gRPC-Java v$MAJOR.$MINOR.$PATCH -Released`. The email content should link to the GitHub release notes and include -a copy of them. +When happy with the dependency upgrades, update the versions in `MODULE.bazel`, +`repositories.bzl`, and the various `pom.xml` and `build.gradle` files in +`examples/`. diff --git a/SECURITY.md b/SECURITY.md index bf2206a7d1a..e710ceaabe1 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -84,7 +84,7 @@ with OpenSSL](#tls-with-netty-tcnative-on-openssl) are other valid options. [Apache Tomcat's tcnative](https://tomcat.apache.org/native-doc/) and is a JNI wrapper around OpenSSL/BoringSSL/LibreSSL. -We recommend BoringSSL for its simplicitly and low occurrence of security +We recommend BoringSSL for its simplicity and low occurrence of security vulnerabilities relative to OpenSSL. BoringSSL is used by Conscrypt as well. ### TLS with netty-tcnative on BoringSSL @@ -330,14 +330,10 @@ is an option](#tls-with-conscrypt). Otherwise you need to [build your own 32-bit version of `netty-tcnative`](https://netty.io/wiki/forked-tomcat-native.html#wiki-h2-6). -If on Alpine Linux and you see "Error loading shared library libcrypt.so.1: No -such file or directory". Run `apk update && apk add libc6-compat` to install the -necessary dependency. - -If on Alpine Linux, try to use `grpc-netty-shaded` instead of `grpc-netty` or -(if you need `grpc-netty`) `netty-tcnative-boringssl-static` instead of -`netty-tcnative`. If those are not an option, you may consider using -[netty-tcnative-alpine](https://github.com/pires/netty-tcnative-alpine). +If on Alpine Linux, depending on your specific JDK you may see a crash in +netty_tcnative. This is generally caused by a missing symbol. Run `apk install +gcompat` and use the environment variable `LD_PRELOAD=/lib/libgcompat.so.0` when +executing Java. If on Fedora 30 or later and you see "libcrypt.so.1: cannot open shared object file: No such file or directory". Run `dnf -y install libxcrypt-compat` to @@ -398,7 +394,13 @@ grpc-netty version | netty-handler version | netty-tcnative-boringssl-static ver 1.56.x | 4.1.87.Final | 2.0.61.Final 1.57.x-1.58.x | 4.1.93.Final | 2.0.61.Final 1.59.x | 4.1.97.Final | 2.0.61.Final -1.60.x- | 4.1.100.Final | 2.0.61.Final +1.60.x-1.66.x | 4.1.100.Final | 2.0.61.Final +1.67.x-1.70.x | 4.1.110.Final | 2.0.65.Final +1.71.x-1.74.x | 4.1.110.Final | 2.0.70.Final +1.75.x-1.76.x | 4.1.124.Final | 2.0.72.Final +1.77.x-1.78.x | 4.1.127.Final | 2.0.74.Final +1.79.x-1.80.x | 4.1.130.Final | 2.0.74.Final +1.81.x- | 4.1.132.Final | 2.0.75.Final _(grpc-netty-shaded avoids issues with keeping these versions in sync.)_ diff --git a/WORKSPACE b/WORKSPACE index bb38b185552..1efdf2793a8 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -1,47 +1,58 @@ workspace(name = "io_grpc_grpc_java") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +load("//:repositories.bzl", "IO_GRPC_GRPC_JAVA_ARTIFACTS", "IO_GRPC_GRPC_JAVA_OVERRIDE_TARGETS", "grpc_java_repositories") + +grpc_java_repositories() http_archive( - name = "rules_jvm_external", - sha256 = "d31e369b854322ca5098ea12c69d7175ded971435e55c18dd9dd5f29cc5249ac", - strip_prefix = "rules_jvm_external-5.3", - url = "https://github.com/bazelbuild/rules_jvm_external/releases/download/5.3/rules_jvm_external-5.3.tar.gz", + name = "rules_java", + sha256 = "47632cc506c858011853073449801d648e10483d4b50e080ec2549a4b2398960", + urls = [ + "https://github.com/bazelbuild/rules_java/releases/download/8.15.2/rules_java-8.15.2.tar.gz", + ], ) -load("@rules_jvm_external//:defs.bzl", "maven_install") -load("//:repositories.bzl", "IO_GRPC_GRPC_JAVA_ARTIFACTS") -load("//:repositories.bzl", "IO_GRPC_GRPC_JAVA_OVERRIDE_TARGETS") -load("//:repositories.bzl", "grpc_java_repositories") +load("@com_google_protobuf//:protobuf_deps.bzl", "PROTOBUF_MAVEN_ARTIFACTS", "protobuf_deps") -grpc_java_repositories() +protobuf_deps() -load("@com_google_protobuf//:protobuf_deps.bzl", "PROTOBUF_MAVEN_ARTIFACTS") -load("@com_google_protobuf//:protobuf_deps.bzl", "protobuf_deps") +load("@rules_java//java:rules_java_deps.bzl", "rules_java_dependencies") -protobuf_deps() +rules_java_dependencies() + +load("@bazel_features//:deps.bzl", "bazel_features_deps") + +bazel_features_deps() -load("@envoy_api//bazel:repositories.bzl", "api_dependencies") +load("@bazel_jar_jar//:jar_jar.bzl", "jar_jar_repositories") -api_dependencies() +jar_jar_repositories() + +load("@rules_python//python:repositories.bzl", "py_repositories") + +py_repositories() load("@com_google_googleapis//:repository_rules.bzl", "switched_rules_by_language") switched_rules_by_language( name = "com_google_googleapis_imports", - java = True, ) +http_archive( + name = "rules_jvm_external", + sha256 = "d31e369b854322ca5098ea12c69d7175ded971435e55c18dd9dd5f29cc5249ac", + strip_prefix = "rules_jvm_external-5.3", + url = "https://github.com/bazelbuild/rules_jvm_external/releases/download/5.3/rules_jvm_external-5.3.tar.gz", +) + +load("@rules_jvm_external//:defs.bzl", "maven_install") + maven_install( artifacts = IO_GRPC_GRPC_JAVA_ARTIFACTS + PROTOBUF_MAVEN_ARTIFACTS, - generate_compat_repositories = True, override_targets = IO_GRPC_GRPC_JAVA_OVERRIDE_TARGETS, repositories = [ "https://repo.maven.apache.org/maven2/", ], strict_visibility = True, ) - -load("@maven//:compat.bzl", "compat_repositories") - -compat_repositories() diff --git a/WORKSPACE.bzlmod b/WORKSPACE.bzlmod new file mode 100644 index 00000000000..4ecb9e5d985 --- /dev/null +++ b/WORKSPACE.bzlmod @@ -0,0 +1 @@ +# When using bzlmod this makes sure nothing from the legacy WORKSPACE is loaded diff --git a/all/build.gradle b/all/build.gradle index 42c57531f34..11eec4b7ff8 100644 --- a/all/build.gradle +++ b/all/build.gradle @@ -12,9 +12,11 @@ def subprojects = [ project(':grpc-auth'), project(':grpc-core'), project(':grpc-grpclb'), + project(':grpc-gcp-csm-observability'), project(':grpc-inprocess'), project(':grpc-netty'), project(':grpc-okhttp'), + project(':grpc-opentelemetry'), project(':grpc-protobuf'), project(':grpc-protobuf-lite'), project(':grpc-rls'), diff --git a/alts/BUILD.bazel b/alts/BUILD.bazel index c99689bac11..f29df303fbe 100644 --- a/alts/BUILD.bazel +++ b/alts/BUILD.bazel @@ -1,4 +1,7 @@ -load("@rules_proto//proto:defs.bzl", "proto_library") +load("@com_google_protobuf//bazel:java_proto_library.bzl", "java_proto_library") +load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library") +load("@rules_java//java:defs.bzl", "java_library") +load("@rules_jvm_external//:defs.bzl", "artifact") load("//:java_grpc_library.bzl", "java_grpc_library") java_library( @@ -11,19 +14,18 @@ java_library( ":handshaker_java_proto", "//api", "//core:internal", - "//grpclb", "//netty", "//stub", - "@com_google_code_findbugs_jsr305//jar", - "@com_google_guava_guava//jar", - "@com_google_j2objc_j2objc_annotations//jar", "@com_google_protobuf//:protobuf_java", "@com_google_protobuf//:protobuf_java_util", - "@io_netty_netty_buffer//jar", - "@io_netty_netty_codec//jar", - "@io_netty_netty_common//jar", - "@io_netty_netty_handler//jar", - "@io_netty_netty_transport//jar", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:guava"), + artifact("io.netty:netty-buffer"), + artifact("io.netty:netty-codec"), + artifact("io.netty:netty-common"), + artifact("io.netty:netty-handler"), + artifact("io.netty:netty-transport"), ], ) @@ -35,19 +37,18 @@ java_library( visibility = ["//visibility:public"], deps = [ ":alts_internal", - ":handshaker_java_proto", ":handshaker_java_grpc", + ":handshaker_java_proto", "//api", "//auth", "//core:internal", "//netty", - "@com_google_auth_google_auth_library_oauth2_http//jar", - "@com_google_code_findbugs_jsr305//jar", - "@com_google_guava_guava//jar", - "@com_google_j2objc_j2objc_annotations//jar", - "@io_netty_netty_common//jar", - "@io_netty_netty_handler//jar", - "@io_netty_netty_transport//jar", + artifact("com.google.auth:google-auth-library-oauth2-http"), + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.guava:guava"), + artifact("io.netty:netty-common"), + artifact("io.netty:netty-handler"), + artifact("io.netty:netty-transport"), ], ) diff --git a/alts/build.gradle b/alts/build.gradle index 187100698c8..c206a37bcef 100644 --- a/alts/build.gradle +++ b/alts/build.gradle @@ -2,8 +2,8 @@ plugins { id "java-library" id "maven-publish" - id "com.github.johnrengelman.shadow" id "com.google.protobuf" + id "com.gradleup.shadow" id "ru.vyarus.animalsniffer" } @@ -13,15 +13,13 @@ dependencies { api project(':grpc-api') implementation project(':grpc-auth'), project(':grpc-core'), - project(':grpc-grpclb'), + project(":grpc-context"), // Override google-auth dependency with our newer version project(':grpc-protobuf'), project(':grpc-stub'), libraries.protobuf.java, libraries.conscrypt, - libraries.guava.jre, // JRE required by protobuf-java-util from grpclb libraries.google.auth.oauth2Http def nettyDependency = implementation project(':grpc-netty') - compileOnly libraries.javax.annotation shadow configurations.implementation.getDependencies().minus(nettyDependency) shadow project(path: ':grpc-netty-shaded', configuration: 'shadow') @@ -43,7 +41,11 @@ dependencies { classifier = "linux-x86_64" } } - signature libraries.signature.java + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } } configureProtoCompilation() diff --git a/alts/src/generated/main/grpc/io/grpc/alts/internal/HandshakerServiceGrpc.java b/alts/src/generated/main/grpc/io/grpc/alts/internal/HandshakerServiceGrpc.java index 2caba4a0544..07e4256eb75 100644 --- a/alts/src/generated/main/grpc/io/grpc/alts/internal/HandshakerServiceGrpc.java +++ b/alts/src/generated/main/grpc/io/grpc/alts/internal/HandshakerServiceGrpc.java @@ -4,9 +4,6 @@ /** */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/gcp/handshaker.proto") @io.grpc.stub.annotations.GrpcGenerated public final class HandshakerServiceGrpc { @@ -60,6 +57,21 @@ public HandshakerServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOption return HandshakerServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static HandshakerServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public HandshakerServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new HandshakerServiceBlockingV2Stub(channel, callOptions); + } + }; + return HandshakerServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -157,6 +169,40 @@ public io.grpc.stub.StreamObserver doHandsh /** * A stub to allow clients to do synchronous rpc calls to service HandshakerService. */ + public static final class HandshakerServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private HandshakerServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected HandshakerServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new HandshakerServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * Handshaker service accepts a stream of handshaker request, returning a
+     * stream of handshaker response. Client is expected to send exactly one
+     * message with either client_start or server_start followed by one or more
+     * messages with next. Each time client sends a request, the handshaker
+     * service expects to respond. Client does not have to wait for service's
+     * response before sending next request.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + doHandshake() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getDoHandshakeMethod(), getCallOptions()); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service HandshakerService. + */ public static final class HandshakerServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private HandshakerServiceBlockingStub( diff --git a/alts/src/main/java/io/grpc/alts/AltsChannelBuilder.java b/alts/src/main/java/io/grpc/alts/AltsChannelBuilder.java index 81dd75eab46..ca33f8b00b9 100644 --- a/alts/src/main/java/io/grpc/alts/AltsChannelBuilder.java +++ b/alts/src/main/java/io/grpc/alts/AltsChannelBuilder.java @@ -29,7 +29,7 @@ /** * ALTS version of {@code ManagedChannelBuilder}. This class sets up a secure and authenticated - * commmunication between two cloud VMs using ALTS. + * communication between two cloud VMs using ALTS. */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/4151") public final class AltsChannelBuilder extends ForwardingChannelBuilder2 { @@ -38,7 +38,7 @@ public final class AltsChannelBuilder extends ForwardingChannelBuilder2 call) { - Object authContext = call.getAttributes().get(AltsProtocolNegotiator.AUTH_CONTEXT_KEY); + public static AltsContext createFrom(ServerCall call) { + return createFrom(call.getAttributes()); + } + + /** + * Creates an {@link AltsContext} from ALTS context information in the {@link ClientCall}. + * + * @param call the {@link ClientCall} containing the ALTS information + * @return the created {@link AltsContext} + * @throws IllegalArgumentException if the {@link ClientCall} has no ALTS information. + */ + public static AltsContext createFrom(ClientCall call) { + return createFrom(call.getAttributes()); + } + + /** + * Creates an {@link AltsContext} from ALTS context information in the {@link Attributes}. + * + * @param attributes the {@link Attributes} containing the ALTS information + * @return the created {@link AltsContext} + * @throws IllegalArgumentException if the {@link Attributes} has no ALTS information. + */ + public static AltsContext createFrom(Attributes attributes) { + Object authContext = attributes.get(AltsProtocolNegotiator.AUTH_CONTEXT_KEY); if (!(authContext instanceof AltsInternalContext)) { throw new IllegalArgumentException("No ALTS context information found"); } @@ -49,8 +72,28 @@ public static AltsContext createFrom(ServerCall call) { * @param call the {@link ServerCall} to check * @return true, if the {@link ServerCall} contains ALTS information and false otherwise. */ - public static boolean check(ServerCall call) { - Object authContext = call.getAttributes().get(AltsProtocolNegotiator.AUTH_CONTEXT_KEY); + public static boolean check(ServerCall call) { + return check(call.getAttributes()); + } + + /** + * Checks if the {@link ClientCall} contains ALTS information. + * + * @param call the {@link ClientCall} to check + * @return true, if the {@link ClientCall} contains ALTS information and false otherwise. + */ + public static boolean check(ClientCall call) { + return check(call.getAttributes()); + } + + /** + * Checks if the {@link Attributes} contains ALTS information. + * + * @param attributes the {@link Attributes} to check + * @return true, if the {@link Attributes} contains ALTS information and false otherwise. + */ + public static boolean check(Attributes attributes) { + Object authContext = attributes.get(AltsProtocolNegotiator.AUTH_CONTEXT_KEY); return authContext instanceof AltsInternalContext; } } diff --git a/alts/src/main/java/io/grpc/alts/ComputeEngineChannelBuilder.java b/alts/src/main/java/io/grpc/alts/ComputeEngineChannelBuilder.java index 8898c33cd30..b5ee6a8d362 100644 --- a/alts/src/main/java/io/grpc/alts/ComputeEngineChannelBuilder.java +++ b/alts/src/main/java/io/grpc/alts/ComputeEngineChannelBuilder.java @@ -35,7 +35,7 @@ private ComputeEngineChannelBuilder(String target) { } /** "Overrides" the static method in {@link ManagedChannelBuilder}. */ - public static final ComputeEngineChannelBuilder forTarget(String target) { + public static ComputeEngineChannelBuilder forTarget(String target) { return new ComputeEngineChannelBuilder(target); } diff --git a/alts/src/main/java/io/grpc/alts/DualCallCredentials.java b/alts/src/main/java/io/grpc/alts/DualCallCredentials.java new file mode 100644 index 00000000000..08104712e65 --- /dev/null +++ b/alts/src/main/java/io/grpc/alts/DualCallCredentials.java @@ -0,0 +1,46 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.alts; + +import io.grpc.CallCredentials; +import java.util.concurrent.Executor; + +/** + * {@code CallCredentials} that will pick the right credentials based on whether the established + * connection is ALTS or TLS. + */ +final class DualCallCredentials extends CallCredentials { + private final CallCredentials tlsCallCredentials; + private final CallCredentials altsCallCredentials; + + public DualCallCredentials(CallCredentials tlsCallCreds, CallCredentials altsCallCreds) { + tlsCallCredentials = tlsCallCreds; + altsCallCredentials = altsCallCreds; + } + + @Override + public void applyRequestMetadata( + CallCredentials.RequestInfo requestInfo, + Executor appExecutor, + CallCredentials.MetadataApplier applier) { + if (AltsContextUtil.check(requestInfo.getTransportAttrs())) { + altsCallCredentials.applyRequestMetadata(requestInfo, appExecutor, applier); + } else { + tlsCallCredentials.applyRequestMetadata(requestInfo, appExecutor, applier); + } + } +} diff --git a/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelBuilder.java b/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelBuilder.java index 8e628065113..c78b94417c4 100644 --- a/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelBuilder.java +++ b/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelBuilder.java @@ -35,7 +35,7 @@ private GoogleDefaultChannelBuilder(String target) { } /** "Overrides" the static method in {@link ManagedChannelBuilder}. */ - public static final GoogleDefaultChannelBuilder forTarget(String target) { + public static GoogleDefaultChannelBuilder forTarget(String target) { return new GoogleDefaultChannelBuilder(target); } diff --git a/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelCredentials.java b/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelCredentials.java index d9c2ddaaed7..1b5880120a4 100644 --- a/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelCredentials.java +++ b/alts/src/main/java/io/grpc/alts/GoogleDefaultChannelCredentials.java @@ -63,6 +63,7 @@ public static Builder newBuilder() { */ public static final class Builder { private CallCredentials callCredentials; + private CallCredentials altsCallCredentials; private Builder() {} @@ -72,23 +73,32 @@ public Builder callCredentials(CallCredentials callCreds) { return this; } + /** Constructs GoogleDefaultChannelCredentials with an ALTS-specific call credential. */ + public Builder altsCallCredentials(CallCredentials callCreds) { + altsCallCredentials = callCreds; + return this; + } + /** Builds a GoogleDefaultChannelCredentials instance. */ public ChannelCredentials build() { ChannelCredentials nettyCredentials = InternalNettyChannelCredentials.create(createClientFactory()); - if (callCredentials != null) { - return CompositeChannelCredentials.create(nettyCredentials, callCredentials); - } - CallCredentials callCreds; - try { - callCreds = MoreCallCredentials.from(GoogleCredentials.getApplicationDefault()); - } catch (IOException e) { - callCreds = - new FailingCallCredentials( - Status.UNAUTHENTICATED - .withDescription("Failed to get Google default credentials") - .withCause(e)); + CallCredentials tlsCallCreds = callCredentials; + if (tlsCallCreds == null) { + try { + tlsCallCreds = MoreCallCredentials.from(GoogleCredentials.getApplicationDefault()); + } catch (IOException e) { + tlsCallCreds = + new FailingCallCredentials( + Status.UNAUTHENTICATED + .withDescription("Failed to get Google default credentials") + .withCause(e)); + } } + CallCredentials callCreds = + altsCallCredentials == null + ? tlsCallCreds + : new DualCallCredentials(tlsCallCreds, altsCallCredentials); return CompositeChannelCredentials.create(nettyCredentials, callCreds); } diff --git a/alts/src/main/java/io/grpc/alts/HandshakerServiceChannel.java b/alts/src/main/java/io/grpc/alts/HandshakerServiceChannel.java index 8e8d175b7af..5e32d22d901 100644 --- a/alts/src/main/java/io/grpc/alts/HandshakerServiceChannel.java +++ b/alts/src/main/java/io/grpc/alts/HandshakerServiceChannel.java @@ -21,6 +21,7 @@ import io.grpc.ClientCall; import io.grpc.ManagedChannel; import io.grpc.MethodDescriptor; +import io.grpc.internal.GrpcUtil; import io.grpc.internal.SharedResourceHolder.Resource; import io.grpc.netty.NettyChannelBuilder; import io.netty.channel.EventLoopGroup; @@ -36,15 +37,37 @@ * application will have at most one connection to the handshaker service. */ final class HandshakerServiceChannel { + // Port 8080 is necessary for ALTS handshake. + private static final int ALTS_PORT = 8080; + private static final String DEFAULT_TARGET = "metadata.google.internal.:8080"; static final Resource SHARED_HANDSHAKER_CHANNEL = - new ChannelResource("metadata.google.internal.:8080"); - + new ChannelResource(getHandshakerTarget(System.getenv("GCE_METADATA_HOST"))); + + /** + * Returns handshaker target. When GCE_METADATA_HOST is provided, it might contain port which we + * will discard and use ALTS_PORT instead. + */ + static String getHandshakerTarget(String envValue) { + if (envValue == null || envValue.isEmpty()) { + return DEFAULT_TARGET; + } + String host = envValue; + int portIndex = host.lastIndexOf(':'); + if (portIndex != -1) { + host = host.substring(0, portIndex); // Discard port if specified + } + return host + ":" + ALTS_PORT; // Utilize ALTS port in all cases + } + /** Returns a resource of handshaker service channel for testing only. */ static Resource getHandshakerChannelForTesting(String handshakerAddress) { return new ChannelResource(handshakerAddress); } + private static final boolean EXPERIMENTAL_ALTS_HANDSHAKER_KEEPALIVE_PARAMS = + GrpcUtil.getFlag("GRPC_EXPERIMENTAL_ALTS_HANDSHAKER_KEEPALIVE_PARAMS", false); + private static class ChannelResource implements Resource { private final String target; @@ -57,12 +80,16 @@ public Channel create() { /* Use its own event loop thread pool to avoid blocking. */ EventLoopGroup eventGroup = new NioEventLoopGroup(1, new DefaultThreadFactory("handshaker pool", true)); - ManagedChannel channel = NettyChannelBuilder.forTarget(target) + NettyChannelBuilder channelBuilder = + NettyChannelBuilder.forTarget(target) .channelType(NioSocketChannel.class, InetSocketAddress.class) .directExecutor() .eventLoopGroup(eventGroup) - .usePlaintext() - .build(); + .usePlaintext(); + if (EXPERIMENTAL_ALTS_HANDSHAKER_KEEPALIVE_PARAMS) { + channelBuilder.keepAliveTime(10, TimeUnit.MINUTES).keepAliveTimeout(10, TimeUnit.SECONDS); + } + ManagedChannel channel = channelBuilder.build(); return new EventLoopHoldingChannel(channel, eventGroup); } diff --git a/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java b/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java index e0343f83c51..9c51cf6a053 100644 --- a/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java +++ b/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java @@ -30,7 +30,6 @@ import io.grpc.SecurityLevel; import io.grpc.Status; import io.grpc.alts.internal.RpcProtocolVersionsUtil.RpcVersionsCheckResult; -import io.grpc.grpclb.GrpclbConstants; import io.grpc.internal.ObjectPool; import io.grpc.netty.GrpcHttp2ConnectionHandler; import io.grpc.netty.InternalProtocolNegotiator; @@ -299,9 +298,7 @@ public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { isXdsDirectPath = isDirectPathCluster( grpcHandler.getEagAttributes().get(clusterNameAttrKey)); } - if (grpcHandler.getEagAttributes().get(GrpclbConstants.ATTR_LB_ADDR_AUTHORITY) != null - || grpcHandler.getEagAttributes().get(GrpclbConstants.ATTR_LB_PROVIDED_BACKEND) != null - || isXdsDirectPath) { + if (isXdsDirectPath) { TsiHandshaker handshaker = handshakerFactory.newHandshaker(grpcHandler.getAuthority(), negotiationLogger); NettyTsiHandshaker nettyHandshaker = new NettyTsiHandshaker(handshaker); diff --git a/alts/src/main/java/io/grpc/alts/internal/AltsTsiHandshaker.java b/alts/src/main/java/io/grpc/alts/internal/AltsTsiHandshaker.java index 007db9e1eed..2d6c322c1b1 100644 --- a/alts/src/main/java/io/grpc/alts/internal/AltsTsiHandshaker.java +++ b/alts/src/main/java/io/grpc/alts/internal/AltsTsiHandshaker.java @@ -80,7 +80,7 @@ public boolean processBytesFromPeer(ByteBuffer bytes) throws GeneralSecurityExce return true; } int remaining = bytes.remaining(); - // Call handshaker service to proceess the bytes. + // Call handshaker service to process the bytes. if (outputFrame == null) { checkState(!isClient, "Client handshaker should not process any frame at the beginning."); outputFrame = handshaker.startServerHandshake(bytes); diff --git a/alts/src/main/java/io/grpc/alts/internal/AsyncSemaphore.java b/alts/src/main/java/io/grpc/alts/internal/AsyncSemaphore.java index 3ccdcfc763a..a8251c7fbd3 100644 --- a/alts/src/main/java/io/grpc/alts/internal/AsyncSemaphore.java +++ b/alts/src/main/java/io/grpc/alts/internal/AsyncSemaphore.java @@ -16,12 +16,12 @@ package io.grpc.alts.internal; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; import java.util.LinkedList; import java.util.Queue; -import javax.annotation.concurrent.GuardedBy; /** Provides a semaphore primitive, without blocking waiting on permits. */ final class AsyncSemaphore { diff --git a/alts/src/main/java/io/grpc/alts/internal/ProtectedPromise.java b/alts/src/main/java/io/grpc/alts/internal/ProtectedPromise.java index e204acdd5f9..871a51f1bea 100644 --- a/alts/src/main/java/io/grpc/alts/internal/ProtectedPromise.java +++ b/alts/src/main/java/io/grpc/alts/internal/ProtectedPromise.java @@ -78,7 +78,7 @@ public ChannelPromise doneAllocatingPromises() { if (!doneAllocating) { doneAllocating = true; if (successfulCount == expectedCount) { - trySuccessInternal(null); + trySuccessInternal(); return super.setSuccess(null); } } @@ -117,18 +117,18 @@ private boolean awaitingPromises() { } @Override - public ChannelPromise setSuccess(Void result) { - trySuccess(result); + public ChannelPromise setSuccess(Void unused) { + trySuccess(null); return this; } @Override - public boolean trySuccess(Void result) { + public boolean trySuccess(Void unused) { if (awaitingPromises()) { ++successfulCount; if (successfulCount == expectedCount && doneAllocating) { - trySuccessInternal(result); - return super.trySuccess(result); + trySuccessInternal(); + return super.trySuccess(null); } // TODO: We break the interface a bit here. // Multiple success events can be processed without issue because this is an aggregation. @@ -137,9 +137,9 @@ public boolean trySuccess(Void result) { return false; } - private void trySuccessInternal(Void result) { + private void trySuccessInternal() { for (int i = 0; i < unprotectedPromises.size(); ++i) { - unprotectedPromises.get(i).trySuccess(result); + unprotectedPromises.get(i).trySuccess(null); } } diff --git a/alts/src/test/java/io/grpc/alts/AltsContextUtilTest.java b/alts/src/test/java/io/grpc/alts/AltsContextUtilTest.java index 6fd2d840d45..675fa29fc99 100644 --- a/alts/src/test/java/io/grpc/alts/AltsContextUtilTest.java +++ b/alts/src/test/java/io/grpc/alts/AltsContextUtilTest.java @@ -24,6 +24,7 @@ import static org.mockito.Mockito.when; import io.grpc.Attributes; +import io.grpc.ClientCall; import io.grpc.ServerCall; import io.grpc.alts.AltsContext.SecurityLevel; import io.grpc.alts.internal.AltsInternalContext; @@ -37,27 +38,38 @@ /** Unit tests for {@link AltsContextUtil}. */ @RunWith(JUnit4.class) public class AltsContextUtilTest { - - private final ServerCall call = mock(ServerCall.class); - @Test public void check_noAttributeValue() { - when(call.getAttributes()).thenReturn(Attributes.newBuilder().build()); + assertFalse(AltsContextUtil.check(Attributes.newBuilder().build())); + } - assertFalse(AltsContextUtil.check(call)); + @Test + public void check_unexpectedAttributeValueType() { + assertFalse(AltsContextUtil.check(Attributes.newBuilder() + .set(AltsProtocolNegotiator.AUTH_CONTEXT_KEY, new Object()) + .build())); } @Test - public void contains_unexpectedAttributeValueType() { + public void check_altsInternalContext() { + assertTrue(AltsContextUtil.check(Attributes.newBuilder() + .set(AltsProtocolNegotiator.AUTH_CONTEXT_KEY, AltsInternalContext.getDefaultInstance()) + .build())); + } + + @Test + public void checkServer_altsInternalContext() { + ServerCall call = mock(ServerCall.class); when(call.getAttributes()).thenReturn(Attributes.newBuilder() - .set(AltsProtocolNegotiator.AUTH_CONTEXT_KEY, new Object()) + .set(AltsProtocolNegotiator.AUTH_CONTEXT_KEY, AltsInternalContext.getDefaultInstance()) .build()); - assertFalse(AltsContextUtil.check(call)); + assertTrue(AltsContextUtil.check(call)); } @Test - public void contains_altsInternalContext() { + public void checkClient_altsInternalContext() { + ClientCall call = mock(ClientCall.class); when(call.getAttributes()).thenReturn(Attributes.newBuilder() .set(AltsProtocolNegotiator.AUTH_CONTEXT_KEY, AltsInternalContext.getDefaultInstance()) .build()); @@ -66,26 +78,57 @@ public void contains_altsInternalContext() { } @Test - public void from_altsInternalContext() { + public void createFrom_altsInternalContext() { HandshakerResult handshakerResult = HandshakerResult.newBuilder() .setPeerIdentity(Identity.newBuilder().setServiceAccount("remote@peer")) .setLocalIdentity(Identity.newBuilder().setServiceAccount("local@peer")) .build(); - when(call.getAttributes()).thenReturn(Attributes.newBuilder() - .set(AltsProtocolNegotiator.AUTH_CONTEXT_KEY, new AltsInternalContext(handshakerResult)) - .build()); - AltsContext context = AltsContextUtil.createFrom(call); + AltsContext context = AltsContextUtil.createFrom(Attributes.newBuilder() + .set(AltsProtocolNegotiator.AUTH_CONTEXT_KEY, new AltsInternalContext(handshakerResult)) + .build()); assertEquals("remote@peer", context.getPeerServiceAccount()); assertEquals("local@peer", context.getLocalServiceAccount()); assertEquals(SecurityLevel.INTEGRITY_AND_PRIVACY, context.getSecurityLevel()); } @Test(expected = IllegalArgumentException.class) - public void from_noAttributeValue() { - when(call.getAttributes()).thenReturn(Attributes.newBuilder().build()); + public void createFrom_noAttributeValue() { + AltsContextUtil.createFrom(Attributes.newBuilder().build()); + } - AltsContextUtil.createFrom(call); + @Test + public void createFromServer_altsInternalContext() { + HandshakerResult handshakerResult = + HandshakerResult.newBuilder() + .setPeerIdentity(Identity.newBuilder().setServiceAccount("remote@peer")) + .setLocalIdentity(Identity.newBuilder().setServiceAccount("local@peer")) + .build(); + + ServerCall call = mock(ServerCall.class); + when(call.getAttributes()).thenReturn(Attributes.newBuilder() + .set(AltsProtocolNegotiator.AUTH_CONTEXT_KEY, new AltsInternalContext(handshakerResult)) + .build()); + + AltsContext context = AltsContextUtil.createFrom(call); + assertEquals("remote@peer", context.getPeerServiceAccount()); + } + + @Test + public void createFromClient_altsInternalContext() { + HandshakerResult handshakerResult = + HandshakerResult.newBuilder() + .setPeerIdentity(Identity.newBuilder().setServiceAccount("remote@peer")) + .setLocalIdentity(Identity.newBuilder().setServiceAccount("local@peer")) + .build(); + + ClientCall call = mock(ClientCall.class); + when(call.getAttributes()).thenReturn(Attributes.newBuilder() + .set(AltsProtocolNegotiator.AUTH_CONTEXT_KEY, new AltsInternalContext(handshakerResult)) + .build()); + + AltsContext context = AltsContextUtil.createFrom(call); + assertEquals("remote@peer", context.getPeerServiceAccount()); } } diff --git a/alts/src/test/java/io/grpc/alts/DualCallCredentialsTest.java b/alts/src/test/java/io/grpc/alts/DualCallCredentialsTest.java new file mode 100644 index 00000000000..29646191be1 --- /dev/null +++ b/alts/src/test/java/io/grpc/alts/DualCallCredentialsTest.java @@ -0,0 +1,109 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.alts; + +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import io.grpc.Attributes; +import io.grpc.CallCredentials; +import io.grpc.CallCredentials.RequestInfo; +import io.grpc.MethodDescriptor; +import io.grpc.SecurityLevel; +import io.grpc.alts.internal.AltsInternalContext; +import io.grpc.alts.internal.AltsProtocolNegotiator; +import io.grpc.testing.TestMethodDescriptors; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +/** Unit tests for {@link DualCallCredentials}. */ +@RunWith(JUnit4.class) +public class DualCallCredentialsTest { + + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + + @Mock CallCredentials tlsCallCredentials; + + @Mock CallCredentials altsCallCredentials; + + private static final String AUTHORITY = "testauthority"; + private static final SecurityLevel SECURITY_LEVEL = SecurityLevel.PRIVACY_AND_INTEGRITY; + + @Test + public void invokeTlsCallCredentials() { + DualCallCredentials callCredentials = + new DualCallCredentials(tlsCallCredentials, altsCallCredentials); + RequestInfo requestInfo = new RequestInfoImpl(false); + callCredentials.applyRequestMetadata(requestInfo, null, null); + + verify(altsCallCredentials, never()).applyRequestMetadata(any(), any(), any()); + verify(tlsCallCredentials, times(1)).applyRequestMetadata(requestInfo, null, null); + } + + @Test + public void invokeAltsCallCredentials() { + DualCallCredentials callCredentials = + new DualCallCredentials(tlsCallCredentials, altsCallCredentials); + RequestInfo requestInfo = new RequestInfoImpl(true); + callCredentials.applyRequestMetadata(requestInfo, null, null); + + verify(altsCallCredentials, times(1)).applyRequestMetadata(requestInfo, null, null); + verify(tlsCallCredentials, never()).applyRequestMetadata(any(), any(), any()); + } + + private static final class RequestInfoImpl extends CallCredentials.RequestInfo { + private Attributes attrs; + + RequestInfoImpl(boolean hasAltsContext) { + attrs = + hasAltsContext + ? Attributes.newBuilder() + .set( + AltsProtocolNegotiator.AUTH_CONTEXT_KEY, + AltsInternalContext.getDefaultInstance()) + .build() + : Attributes.EMPTY; + } + + @Override + public MethodDescriptor getMethodDescriptor() { + return TestMethodDescriptors.voidMethod(); + } + + @Override + public SecurityLevel getSecurityLevel() { + return SECURITY_LEVEL; + } + + @Override + public String getAuthority() { + return AUTHORITY; + } + + @Override + public Attributes getTransportAttrs() { + return attrs; + } + } +} diff --git a/alts/src/test/java/io/grpc/alts/HandshakerServiceChannelTest.java b/alts/src/test/java/io/grpc/alts/HandshakerServiceChannelTest.java index a3937904cd7..221001157f1 100644 --- a/alts/src/test/java/io/grpc/alts/HandshakerServiceChannelTest.java +++ b/alts/src/test/java/io/grpc/alts/HandshakerServiceChannelTest.java @@ -67,6 +67,24 @@ public void sharedChannel_authority() { } } + @Test + public void getHandshakerTarget_nullEnvVar() { + assertThat(HandshakerServiceChannel.getHandshakerTarget(null)) + .isEqualTo("metadata.google.internal.:8080"); + } + + @Test + public void getHandshakerTarget_envVarWithPort() { + assertThat(HandshakerServiceChannel.getHandshakerTarget("169.254.169.254:80")) + .isEqualTo("169.254.169.254:8080"); + } + + @Test + public void getHandshakerTarget_envVarWithHostOnly() { + assertThat(HandshakerServiceChannel.getHandshakerTarget("169.254.169.254")) + .isEqualTo("169.254.169.254:8080"); + } + @Test public void resource_works() { Channel channel = resource.create(); diff --git a/alts/src/test/java/io/grpc/alts/internal/AltsHandshakerStubTest.java b/alts/src/test/java/io/grpc/alts/internal/AltsHandshakerStubTest.java index fc2b440749a..7a6018b5064 100644 --- a/alts/src/test/java/io/grpc/alts/internal/AltsHandshakerStubTest.java +++ b/alts/src/test/java/io/grpc/alts/internal/AltsHandshakerStubTest.java @@ -32,7 +32,7 @@ @RunWith(JUnit4.class) public class AltsHandshakerStubTest { /** Mock status of handshaker service. */ - private static enum Status { + private enum Status { OK, ERROR, COMPLETE diff --git a/alts/src/test/java/io/grpc/alts/internal/AltsProtocolNegotiatorTest.java b/alts/src/test/java/io/grpc/alts/internal/AltsProtocolNegotiatorTest.java index 24392af75fd..d47607ed90f 100644 --- a/alts/src/test/java/io/grpc/alts/internal/AltsProtocolNegotiatorTest.java +++ b/alts/src/test/java/io/grpc/alts/internal/AltsProtocolNegotiatorTest.java @@ -202,8 +202,11 @@ public void operationComplete(ChannelFuture future) throws Exception { channel.flush(); // Capture the protected data written to the wire. - assertEquals(1, channel.outboundMessages().size()); - ByteBuf protectedData = channel.readOutbound(); + assertThat(channel.outboundMessages()).isNotEmpty(); + ByteBuf protectedData = channel.alloc().buffer(); + while (!channel.outboundMessages().isEmpty()) { + protectedData.writeBytes((ByteBuf) channel.readOutbound()); + } assertEquals(message.length(), writeCount.get()); // Read the protected message at the server and verify it matches the original message. @@ -327,16 +330,18 @@ public void doNotFlushEmptyBuffer() throws Exception { String message = "hello"; ByteBuf in = Unpooled.copiedBuffer(message, UTF_8); - assertEquals(0, protector.flushes.get()); + int flushes = protector.flushes.get(); Future done = channel.write(in); channel.flush(); + flushes++; done.get(5, TimeUnit.SECONDS); - assertEquals(1, protector.flushes.get()); + assertEquals(flushes, protector.flushes.get()); + // Flush does not propagate done = channel.write(Unpooled.EMPTY_BUFFER); channel.flush(); done.get(5, TimeUnit.SECONDS); - assertEquals(1, protector.flushes.get()); + assertEquals(flushes, protector.flushes.get()); } @Test diff --git a/alts/src/test/java/io/grpc/alts/internal/FakeTsiHandshaker.java b/alts/src/test/java/io/grpc/alts/internal/FakeTsiHandshaker.java index a68f842a98e..7a6119dc0be 100644 --- a/alts/src/test/java/io/grpc/alts/internal/FakeTsiHandshaker.java +++ b/alts/src/test/java/io/grpc/alts/internal/FakeTsiHandshaker.java @@ -68,6 +68,7 @@ enum State { SERVER_FINISHED; // Returns the next State. In order to advance to sendState=N, receiveState must be N-1. + @SuppressWarnings("EnumOrdinal") public State next() { if (ordinal() + 1 < values().length) { return values()[ordinal() + 1]; @@ -147,7 +148,7 @@ public void getBytesToSendToPeer(ByteBuffer bytes) throws GeneralSecurityExcepti return; } - // Prepare the next message, if neeeded. + // Prepare the next message, if needed. if (sendBuffer == null) { if (sendState.next() != receiveState) { // We're still waiting for bytes from the peer, so bail. diff --git a/alts/src/test/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiatorTest.java b/alts/src/test/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiatorTest.java index 9a520720beb..14c19e554ae 100644 --- a/alts/src/test/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiatorTest.java +++ b/alts/src/test/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiatorTest.java @@ -29,7 +29,6 @@ import io.grpc.ChannelLogger; import io.grpc.ChannelLogger.ChannelLogLevel; import io.grpc.ManagedChannel; -import io.grpc.grpclb.GrpclbConstants; import io.grpc.inprocess.InProcessChannelBuilder; import io.grpc.internal.ObjectPool; import io.grpc.netty.GrpcHttp2ConnectionHandler; @@ -95,13 +94,6 @@ public void tearDown() { @Nullable abstract Attributes.Key getClusterNameAttrKey(); - @Test - public void altsHandler_lbProvidedBackend() { - Attributes attrs = - Attributes.newBuilder().set(GrpclbConstants.ATTR_LB_PROVIDED_BACKEND, true).build(); - subtest_altsHandler(attrs); - } - @Test public void tlsHandler_emptyAttributes() { subtest_tlsHandler(Attributes.EMPTY); diff --git a/android-interop-testing/build.gradle b/android-interop-testing/build.gradle index 4d96adbd0dc..b61d50a6763 100644 --- a/android-interop-testing/build.gradle +++ b/android-interop-testing/build.gradle @@ -7,11 +7,10 @@ description = 'gRPC: Android Integration Testing' repositories { google() - mavenCentral() } android { - namespace 'io.grpc.android.integrationtest' + namespace = 'io.grpc.android.integrationtest' sourceSets { main { java { @@ -34,15 +33,11 @@ android { defaultConfig { applicationId "io.grpc.android.integrationtest" - // Held back to 20 as Gradle fails to build at the 21 level. This is - // presumably a Gradle bug that can be revisited later. - // Maybe this issue: https://github.com/gradle/gradle/issues/20778 - minSdkVersion 20 + minSdkVersion 23 targetSdkVersion 33 versionCode 1 versionName "1.0" testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" - multiDexEnabled true } buildTypes { debug { minifyEnabled false } @@ -63,37 +58,30 @@ android { dependencies { implementation 'androidx.appcompat:appcompat:1.3.0' - implementation 'androidx.multidex:multidex:2.0.0' implementation libraries.androidx.annotation implementation 'com.google.android.gms:play-services-base:18.0.1' implementation project(':grpc-android'), + project(':grpc-api'), project(':grpc-core'), - project(':grpc-auth'), project(':grpc-census'), project(':grpc-okhttp'), project(':grpc-protobuf-lite'), project(':grpc-stub'), project(':grpc-testing'), - libraries.hdrhistogram, libraries.junit, libraries.truth, libraries.androidx.test.rules, + libraries.androidx.test.core, libraries.opencensus.contrib.grpc.metrics - implementation (libraries.google.auth.oauth2Http) { - exclude group: 'org.apache.httpcomponents' - } - implementation (project(':grpc-services')) { exclude group: 'com.google.protobuf' exclude group: 'com.google.guava' } - compileOnly libraries.javax.annotation - - androidTestImplementation 'androidx.test.ext:junit:1.1.3', - 'androidx.test:runner:1.4.0' + androidTestImplementation libraries.androidx.test.ext.junit, + libraries.androidx.test.runner } // Checkstyle doesn't run automatically with android @@ -116,7 +104,6 @@ import net.ltgt.gradle.errorprone.CheckSeverity tasks.withType(JavaCompile).configureEach { options.compilerArgs += [ "-Xlint:-cast", - "-Xlint:-deprecation", // https://github.com/grpc/grpc-java/issues/10298 ] appendToProperty(it.options.errorprone.excludedPaths, ".*/R.java", "|") appendToProperty( @@ -125,6 +112,25 @@ tasks.withType(JavaCompile).configureEach { "|") } +// Workaround error seen with Gradle 8.14.3 and AGP 7.4.1 when building: +// ./gradlew clean :grpc-android-interop-testing:build -PskipAndroid=false \ +// -Pandroid.useAndroidX=true --no-build-cache +// +// Error message: +// +// Execution failed for task ':grpc-android-interop-testing:mergeExtDexDebug'. +// > Could not resolve all files for configuration ':grpc-android-interop-testing:debugRuntimeClasspath'. +// > Failed to transform opencensus-contrib-grpc-metrics-0.31.1.jar (io.opencensus:opencensus-contrib-grpc-metrics:0.31.1) to match attributes {artifactType=android-dex, asm-transformed-variant=NONE, dexing-enable-desugaring=true, dexing-enable-jacoco-instrumentation=false, dexing-is-debuggable=true, dexing-min-sdk=23, org.gradle.category=library, org.gradle.libraryelements=jar, org.gradle.status=release, org.gradle.usage=java-runtime}. +// > Could not resolve all files for configuration ':grpc-android-interop-testing:debugRuntimeClasspath'. +// > Failed to transform grpc-api-1.81.0-SNAPSHOT.jar (project :grpc-api) to match attributes {artifactType=android-classes-jar, org.gradle.category=library, org.gradle.dependency.bundling=external, org.gradle.jvm.version=8, org.gradle.libraryelements=jar, org.gradle.usage=java-runtime}. +// > Execution failed for IdentityTransform: grpc-java/api/build/libs/grpc-api-1.81.0-SNAPSHOT.jar. +// > File/directory does not exist: grpc-java/api/build/libs/grpc-api-1.81.0-SNAPSHOT.jar +tasks.configureEach { task -> + if (task.name.equals("mergeExtDexDebug")) { + dependsOn(':grpc-api:jar') + } +} + afterEvaluate { // Hack to workaround "Task ':grpc-android-interop-testing:extractIncludeDebugProto' uses this // output of task ':grpc-context:jar' without declaring an explicit or implicit dependency." The diff --git a/android-interop-testing/src/androidTest/AndroidManifest.xml b/android-interop-testing/src/androidTest/AndroidManifest.xml index b0507f10ab9..3cc0a29a85f 100644 --- a/android-interop-testing/src/androidTest/AndroidManifest.xml +++ b/android-interop-testing/src/androidTest/AndroidManifest.xml @@ -5,8 +5,7 @@ android:name="androidx.test.runner.AndroidJUnitRunner" android:targetPackage="io.grpc.android.integrationtest" /> - + diff --git a/android-interop-testing/src/androidTest/java/io/grpc/android/integrationtest/InteropInstrumentationTest.java b/android-interop-testing/src/androidTest/java/io/grpc/android/integrationtest/InteropInstrumentationTest.java index 5b06c91fe09..a5870063a87 100644 --- a/android-interop-testing/src/androidTest/java/io/grpc/android/integrationtest/InteropInstrumentationTest.java +++ b/android-interop-testing/src/androidTest/java/io/grpc/android/integrationtest/InteropInstrumentationTest.java @@ -30,6 +30,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -115,9 +116,10 @@ private void runTest(String testCase) throws Exception { result = executor.submit(new TestCallable( TesterOkHttpChannelBuilder.build(host, port, serverHostOverride, useTls, testCa), testCase)).get(TIMEOUT_SECONDS, TimeUnit.SECONDS); - assertEquals(testCase + " failed", TestCallable.SUCCESS_MESSAGE, result); - } catch (ExecutionException | InterruptedException e) { + } catch (ExecutionException | InterruptedException | TimeoutException e) { + Log.e(LOG_TAG, "Error while executing test case " + testCase, e); result = e.getMessage(); } + assertEquals(testCase + " failed", TestCallable.SUCCESS_MESSAGE, result); } } diff --git a/android-interop-testing/src/androidTest/java/io/grpc/android/integrationtest/UdsChannelInteropTest.java b/android-interop-testing/src/androidTest/java/io/grpc/android/integrationtest/UdsChannelInteropTest.java index f5e54da5d4e..5b98665ba29 100644 --- a/android-interop-testing/src/androidTest/java/io/grpc/android/integrationtest/UdsChannelInteropTest.java +++ b/android-interop-testing/src/androidTest/java/io/grpc/android/integrationtest/UdsChannelInteropTest.java @@ -19,9 +19,9 @@ import static org.junit.Assert.assertEquals; import android.net.LocalSocketAddress.Namespace; -import androidx.test.InstrumentationRegistry; +import androidx.test.ext.junit.rules.ActivityScenarioRule; import androidx.test.ext.junit.runners.AndroidJUnit4; -import androidx.test.rule.ActivityTestRule; +import androidx.test.platform.app.InstrumentationRegistry; import io.grpc.Grpc; import io.grpc.InsecureServerCredentials; import io.grpc.Server; @@ -60,8 +60,8 @@ public class UdsChannelInteropTest { // Ensures Looper is initialized for tests running on API level 15. Otherwise instantiating an // AsyncTask throws an exception. @Rule - public ActivityTestRule activityRule = - new ActivityTestRule(TesterActivity.class); + public ActivityScenarioRule activityRule = + new ActivityScenarioRule<>(TesterActivity.class); @Before public void setUp() throws IOException { diff --git a/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/LoadBalancerStatsServiceGrpc.java b/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/LoadBalancerStatsServiceGrpc.java index e030fde13e3..33b914bb4b3 100644 --- a/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/LoadBalancerStatsServiceGrpc.java +++ b/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/LoadBalancerStatsServiceGrpc.java @@ -7,9 +7,6 @@ * A service used to obtain stats for verifying LB behavior. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class LoadBalancerStatsServiceGrpc { @@ -92,6 +89,21 @@ public LoadBalancerStatsServiceStub newStub(io.grpc.Channel channel, io.grpc.Cal return LoadBalancerStatsServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static LoadBalancerStatsServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public LoadBalancerStatsServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new LoadBalancerStatsServiceBlockingV2Stub(channel, callOptions); + } + }; + return LoadBalancerStatsServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -212,6 +224,46 @@ public void getClientAccumulatedStats(io.grpc.testing.integration.Messages.LoadB * A service used to obtain stats for verifying LB behavior. * */ + public static final class LoadBalancerStatsServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private LoadBalancerStatsServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected LoadBalancerStatsServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new LoadBalancerStatsServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * Gets the backend distribution for RPCs sent by a test client.
+     * 
+ */ + public io.grpc.testing.integration.Messages.LoadBalancerStatsResponse getClientStats(io.grpc.testing.integration.Messages.LoadBalancerStatsRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getGetClientStatsMethod(), getCallOptions(), request); + } + + /** + *
+     * Gets the accumulated stats for RPCs sent by a test client.
+     * 
+ */ + public io.grpc.testing.integration.Messages.LoadBalancerAccumulatedStatsResponse getClientAccumulatedStats(io.grpc.testing.integration.Messages.LoadBalancerAccumulatedStatsRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getGetClientAccumulatedStatsMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service LoadBalancerStatsService. + *
+   * A service used to obtain stats for verifying LB behavior.
+   * 
+ */ public static final class LoadBalancerStatsServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private LoadBalancerStatsServiceBlockingStub( diff --git a/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/MetricsServiceGrpc.java b/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/MetricsServiceGrpc.java index e8726d5adc4..c99abcff7cb 100644 --- a/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/MetricsServiceGrpc.java +++ b/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/MetricsServiceGrpc.java @@ -4,9 +4,6 @@ /** */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/metrics.proto") @io.grpc.stub.annotations.GrpcGenerated public final class MetricsServiceGrpc { @@ -89,6 +86,21 @@ public MetricsServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions c return MetricsServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static MetricsServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public MetricsServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new MetricsServiceBlockingV2Stub(channel, callOptions); + } + }; + return MetricsServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -199,6 +211,46 @@ public void getGauge(io.grpc.testing.integration.Metrics.GaugeRequest request, /** * A stub to allow clients to do synchronous rpc calls to service MetricsService. */ + public static final class MetricsServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private MetricsServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected MetricsServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new MetricsServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * Returns the values of all the gauges that are currently being maintained by
+     * the service
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + getAllGauges(io.grpc.testing.integration.Metrics.EmptyMessage request) { + return io.grpc.stub.ClientCalls.blockingV2ServerStreamingCall( + getChannel(), getGetAllGaugesMethod(), getCallOptions(), request); + } + + /** + *
+     * Returns the value of one gauge
+     * 
+ */ + public io.grpc.testing.integration.Metrics.GaugeResponse getGauge(io.grpc.testing.integration.Metrics.GaugeRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getGetGaugeMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service MetricsService. + */ public static final class MetricsServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private MetricsServiceBlockingStub( diff --git a/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java b/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java index 8ede6407cd0..fffcaad2df2 100644 --- a/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java +++ b/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java @@ -7,9 +7,6 @@ * A service used to control reconnect server. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class ReconnectServiceGrpc { @@ -92,6 +89,21 @@ public ReconnectServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions return ReconnectServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static ReconnectServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public ReconnectServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ReconnectServiceBlockingV2Stub(channel, callOptions); + } + }; + return ReconnectServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -200,6 +212,40 @@ public void stop(io.grpc.testing.integration.EmptyProtos.Empty request, * A service used to control reconnect server. * */ + public static final class ReconnectServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private ReconnectServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected ReconnectServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ReconnectServiceBlockingV2Stub(channel, callOptions); + } + + /** + */ + public io.grpc.testing.integration.EmptyProtos.Empty start(io.grpc.testing.integration.Messages.ReconnectParams request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getStartMethod(), getCallOptions(), request); + } + + /** + */ + public io.grpc.testing.integration.Messages.ReconnectInfo stop(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getStopMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service ReconnectService. + *
+   * A service used to control reconnect server.
+   * 
+ */ public static final class ReconnectServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private ReconnectServiceBlockingStub( diff --git a/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/TestServiceGrpc.java b/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/TestServiceGrpc.java index 01e2678a12f..1d7805e3a3f 100644 --- a/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/TestServiceGrpc.java +++ b/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/TestServiceGrpc.java @@ -8,9 +8,6 @@ * performance with various types of payload. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class TestServiceGrpc { @@ -273,6 +270,21 @@ public TestServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions call return TestServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static TestServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public TestServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new TestServiceBlockingV2Stub(channel, callOptions); + } + }; + return TestServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -543,6 +555,125 @@ public void unimplementedCall(io.grpc.testing.integration.EmptyProtos.Empty requ * performance with various types of payload. * */ + public static final class TestServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private TestServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected TestServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new TestServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * One empty request followed by one empty response.
+     * 
+ */ + public io.grpc.testing.integration.EmptyProtos.Empty emptyCall(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getEmptyCallMethod(), getCallOptions(), request); + } + + /** + *
+     * One request followed by one response.
+     * 
+ */ + public io.grpc.testing.integration.Messages.SimpleResponse unaryCall(io.grpc.testing.integration.Messages.SimpleRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getUnaryCallMethod(), getCallOptions(), request); + } + + /** + *
+     * One request followed by one response. Response has cache control
+     * headers set such that a caching HTTP proxy (such as GFE) can
+     * satisfy subsequent requests.
+     * 
+ */ + public io.grpc.testing.integration.Messages.SimpleResponse cacheableUnaryCall(io.grpc.testing.integration.Messages.SimpleRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getCacheableUnaryCallMethod(), getCallOptions(), request); + } + + /** + *
+     * One request followed by a sequence of responses (streamed download).
+     * The server returns the payload with client desired type and sizes.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamingOutputCall(io.grpc.testing.integration.Messages.StreamingOutputCallRequest request) { + return io.grpc.stub.ClientCalls.blockingV2ServerStreamingCall( + getChannel(), getStreamingOutputCallMethod(), getCallOptions(), request); + } + + /** + *
+     * A sequence of requests followed by one response (streamed upload).
+     * The server returns the aggregated size of client payload as the result.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamingInputCall() { + return io.grpc.stub.ClientCalls.blockingClientStreamingCall( + getChannel(), getStreamingInputCallMethod(), getCallOptions()); + } + + /** + *
+     * A sequence of requests with each request served by the server immediately.
+     * As one request could lead to multiple responses, this interface
+     * demonstrates the idea of full duplexing.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + fullDuplexCall() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getFullDuplexCallMethod(), getCallOptions()); + } + + /** + *
+     * A sequence of requests followed by a sequence of responses.
+     * The server buffers all the client requests and then serves them in order. A
+     * stream of responses are returned to the client when the server starts with
+     * first request.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + halfDuplexCall() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getHalfDuplexCallMethod(), getCallOptions()); + } + + /** + *
+     * The test server will not implement this method. It will be used
+     * to test the behavior when clients call unimplemented methods.
+     * 
+ */ + public io.grpc.testing.integration.EmptyProtos.Empty unimplementedCall(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getUnimplementedCallMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service TestService. + *
+   * A simple service to test the various types of RPCs and experiment with
+   * performance with various types of payload.
+   * 
+ */ public static final class TestServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private TestServiceBlockingStub( diff --git a/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/UnimplementedServiceGrpc.java b/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/UnimplementedServiceGrpc.java index 743d68c3828..bec9b5a723a 100644 --- a/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/UnimplementedServiceGrpc.java +++ b/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/UnimplementedServiceGrpc.java @@ -8,9 +8,6 @@ * that case. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class UnimplementedServiceGrpc { @@ -63,6 +60,21 @@ public UnimplementedServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOpt return UnimplementedServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static UnimplementedServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public UnimplementedServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new UnimplementedServiceBlockingV2Stub(channel, callOptions); + } + }; + return UnimplementedServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -166,6 +178,37 @@ public void unimplementedCall(io.grpc.testing.integration.EmptyProtos.Empty requ * that case. * */ + public static final class UnimplementedServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private UnimplementedServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected UnimplementedServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new UnimplementedServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * A call that no server should implement
+     * 
+ */ + public io.grpc.testing.integration.EmptyProtos.Empty unimplementedCall(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getUnimplementedCallMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service UnimplementedService. + *
+   * A simple service NOT implemented at servers so clients can test for
+   * that case.
+   * 
+ */ public static final class UnimplementedServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private UnimplementedServiceBlockingStub( diff --git a/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/XdsUpdateClientConfigureServiceGrpc.java b/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/XdsUpdateClientConfigureServiceGrpc.java index 61cfc19d29b..3453b6c01be 100644 --- a/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/XdsUpdateClientConfigureServiceGrpc.java +++ b/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/XdsUpdateClientConfigureServiceGrpc.java @@ -7,9 +7,6 @@ * A service to dynamically update the configuration of an xDS test client. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class XdsUpdateClientConfigureServiceGrpc { @@ -62,6 +59,21 @@ public XdsUpdateClientConfigureServiceStub newStub(io.grpc.Channel channel, io.g return XdsUpdateClientConfigureServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static XdsUpdateClientConfigureServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public XdsUpdateClientConfigureServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new XdsUpdateClientConfigureServiceBlockingV2Stub(channel, callOptions); + } + }; + return XdsUpdateClientConfigureServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -161,6 +173,36 @@ public void configure(io.grpc.testing.integration.Messages.ClientConfigureReques * A service to dynamically update the configuration of an xDS test client. * */ + public static final class XdsUpdateClientConfigureServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private XdsUpdateClientConfigureServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected XdsUpdateClientConfigureServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new XdsUpdateClientConfigureServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * Update the tes client's configuration.
+     * 
+ */ + public io.grpc.testing.integration.Messages.ClientConfigureResponse configure(io.grpc.testing.integration.Messages.ClientConfigureRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getConfigureMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service XdsUpdateClientConfigureService. + *
+   * A service to dynamically update the configuration of an xDS test client.
+   * 
+ */ public static final class XdsUpdateClientConfigureServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private XdsUpdateClientConfigureServiceBlockingStub( diff --git a/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/XdsUpdateHealthServiceGrpc.java b/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/XdsUpdateHealthServiceGrpc.java index 6ba9419dedf..fb5f2cdebc7 100644 --- a/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/XdsUpdateHealthServiceGrpc.java +++ b/android-interop-testing/src/generated/debug/grpc/io/grpc/testing/integration/XdsUpdateHealthServiceGrpc.java @@ -7,9 +7,6 @@ * A service to remotely control health status of an xDS test server. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class XdsUpdateHealthServiceGrpc { @@ -92,6 +89,21 @@ public XdsUpdateHealthServiceStub newStub(io.grpc.Channel channel, io.grpc.CallO return XdsUpdateHealthServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static XdsUpdateHealthServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public XdsUpdateHealthServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new XdsUpdateHealthServiceBlockingV2Stub(channel, callOptions); + } + }; + return XdsUpdateHealthServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -200,6 +212,40 @@ public void setNotServing(io.grpc.testing.integration.EmptyProtos.Empty request, * A service to remotely control health status of an xDS test server. * */ + public static final class XdsUpdateHealthServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private XdsUpdateHealthServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected XdsUpdateHealthServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new XdsUpdateHealthServiceBlockingV2Stub(channel, callOptions); + } + + /** + */ + public io.grpc.testing.integration.EmptyProtos.Empty setServing(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getSetServingMethod(), getCallOptions(), request); + } + + /** + */ + public io.grpc.testing.integration.EmptyProtos.Empty setNotServing(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getSetNotServingMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service XdsUpdateHealthService. + *
+   * A service to remotely control health status of an xDS test server.
+   * 
+ */ public static final class XdsUpdateHealthServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private XdsUpdateHealthServiceBlockingStub( diff --git a/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/LoadBalancerStatsServiceGrpc.java b/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/LoadBalancerStatsServiceGrpc.java index e030fde13e3..33b914bb4b3 100644 --- a/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/LoadBalancerStatsServiceGrpc.java +++ b/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/LoadBalancerStatsServiceGrpc.java @@ -7,9 +7,6 @@ * A service used to obtain stats for verifying LB behavior. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class LoadBalancerStatsServiceGrpc { @@ -92,6 +89,21 @@ public LoadBalancerStatsServiceStub newStub(io.grpc.Channel channel, io.grpc.Cal return LoadBalancerStatsServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static LoadBalancerStatsServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public LoadBalancerStatsServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new LoadBalancerStatsServiceBlockingV2Stub(channel, callOptions); + } + }; + return LoadBalancerStatsServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -212,6 +224,46 @@ public void getClientAccumulatedStats(io.grpc.testing.integration.Messages.LoadB * A service used to obtain stats for verifying LB behavior. * */ + public static final class LoadBalancerStatsServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private LoadBalancerStatsServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected LoadBalancerStatsServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new LoadBalancerStatsServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * Gets the backend distribution for RPCs sent by a test client.
+     * 
+ */ + public io.grpc.testing.integration.Messages.LoadBalancerStatsResponse getClientStats(io.grpc.testing.integration.Messages.LoadBalancerStatsRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getGetClientStatsMethod(), getCallOptions(), request); + } + + /** + *
+     * Gets the accumulated stats for RPCs sent by a test client.
+     * 
+ */ + public io.grpc.testing.integration.Messages.LoadBalancerAccumulatedStatsResponse getClientAccumulatedStats(io.grpc.testing.integration.Messages.LoadBalancerAccumulatedStatsRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getGetClientAccumulatedStatsMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service LoadBalancerStatsService. + *
+   * A service used to obtain stats for verifying LB behavior.
+   * 
+ */ public static final class LoadBalancerStatsServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private LoadBalancerStatsServiceBlockingStub( diff --git a/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/MetricsServiceGrpc.java b/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/MetricsServiceGrpc.java index e8726d5adc4..c99abcff7cb 100644 --- a/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/MetricsServiceGrpc.java +++ b/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/MetricsServiceGrpc.java @@ -4,9 +4,6 @@ /** */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/metrics.proto") @io.grpc.stub.annotations.GrpcGenerated public final class MetricsServiceGrpc { @@ -89,6 +86,21 @@ public MetricsServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions c return MetricsServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static MetricsServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public MetricsServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new MetricsServiceBlockingV2Stub(channel, callOptions); + } + }; + return MetricsServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -199,6 +211,46 @@ public void getGauge(io.grpc.testing.integration.Metrics.GaugeRequest request, /** * A stub to allow clients to do synchronous rpc calls to service MetricsService. */ + public static final class MetricsServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private MetricsServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected MetricsServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new MetricsServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * Returns the values of all the gauges that are currently being maintained by
+     * the service
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + getAllGauges(io.grpc.testing.integration.Metrics.EmptyMessage request) { + return io.grpc.stub.ClientCalls.blockingV2ServerStreamingCall( + getChannel(), getGetAllGaugesMethod(), getCallOptions(), request); + } + + /** + *
+     * Returns the value of one gauge
+     * 
+ */ + public io.grpc.testing.integration.Metrics.GaugeResponse getGauge(io.grpc.testing.integration.Metrics.GaugeRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getGetGaugeMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service MetricsService. + */ public static final class MetricsServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private MetricsServiceBlockingStub( diff --git a/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java b/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java index 8ede6407cd0..fffcaad2df2 100644 --- a/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java +++ b/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java @@ -7,9 +7,6 @@ * A service used to control reconnect server. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class ReconnectServiceGrpc { @@ -92,6 +89,21 @@ public ReconnectServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions return ReconnectServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static ReconnectServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public ReconnectServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ReconnectServiceBlockingV2Stub(channel, callOptions); + } + }; + return ReconnectServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -200,6 +212,40 @@ public void stop(io.grpc.testing.integration.EmptyProtos.Empty request, * A service used to control reconnect server. * */ + public static final class ReconnectServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private ReconnectServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected ReconnectServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ReconnectServiceBlockingV2Stub(channel, callOptions); + } + + /** + */ + public io.grpc.testing.integration.EmptyProtos.Empty start(io.grpc.testing.integration.Messages.ReconnectParams request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getStartMethod(), getCallOptions(), request); + } + + /** + */ + public io.grpc.testing.integration.Messages.ReconnectInfo stop(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getStopMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service ReconnectService. + *
+   * A service used to control reconnect server.
+   * 
+ */ public static final class ReconnectServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private ReconnectServiceBlockingStub( diff --git a/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/TestServiceGrpc.java b/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/TestServiceGrpc.java index 01e2678a12f..1d7805e3a3f 100644 --- a/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/TestServiceGrpc.java +++ b/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/TestServiceGrpc.java @@ -8,9 +8,6 @@ * performance with various types of payload. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class TestServiceGrpc { @@ -273,6 +270,21 @@ public TestServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions call return TestServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static TestServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public TestServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new TestServiceBlockingV2Stub(channel, callOptions); + } + }; + return TestServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -543,6 +555,125 @@ public void unimplementedCall(io.grpc.testing.integration.EmptyProtos.Empty requ * performance with various types of payload. * */ + public static final class TestServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private TestServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected TestServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new TestServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * One empty request followed by one empty response.
+     * 
+ */ + public io.grpc.testing.integration.EmptyProtos.Empty emptyCall(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getEmptyCallMethod(), getCallOptions(), request); + } + + /** + *
+     * One request followed by one response.
+     * 
+ */ + public io.grpc.testing.integration.Messages.SimpleResponse unaryCall(io.grpc.testing.integration.Messages.SimpleRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getUnaryCallMethod(), getCallOptions(), request); + } + + /** + *
+     * One request followed by one response. Response has cache control
+     * headers set such that a caching HTTP proxy (such as GFE) can
+     * satisfy subsequent requests.
+     * 
+ */ + public io.grpc.testing.integration.Messages.SimpleResponse cacheableUnaryCall(io.grpc.testing.integration.Messages.SimpleRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getCacheableUnaryCallMethod(), getCallOptions(), request); + } + + /** + *
+     * One request followed by a sequence of responses (streamed download).
+     * The server returns the payload with client desired type and sizes.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamingOutputCall(io.grpc.testing.integration.Messages.StreamingOutputCallRequest request) { + return io.grpc.stub.ClientCalls.blockingV2ServerStreamingCall( + getChannel(), getStreamingOutputCallMethod(), getCallOptions(), request); + } + + /** + *
+     * A sequence of requests followed by one response (streamed upload).
+     * The server returns the aggregated size of client payload as the result.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamingInputCall() { + return io.grpc.stub.ClientCalls.blockingClientStreamingCall( + getChannel(), getStreamingInputCallMethod(), getCallOptions()); + } + + /** + *
+     * A sequence of requests with each request served by the server immediately.
+     * As one request could lead to multiple responses, this interface
+     * demonstrates the idea of full duplexing.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + fullDuplexCall() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getFullDuplexCallMethod(), getCallOptions()); + } + + /** + *
+     * A sequence of requests followed by a sequence of responses.
+     * The server buffers all the client requests and then serves them in order. A
+     * stream of responses are returned to the client when the server starts with
+     * first request.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + halfDuplexCall() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getHalfDuplexCallMethod(), getCallOptions()); + } + + /** + *
+     * The test server will not implement this method. It will be used
+     * to test the behavior when clients call unimplemented methods.
+     * 
+ */ + public io.grpc.testing.integration.EmptyProtos.Empty unimplementedCall(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getUnimplementedCallMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service TestService. + *
+   * A simple service to test the various types of RPCs and experiment with
+   * performance with various types of payload.
+   * 
+ */ public static final class TestServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private TestServiceBlockingStub( diff --git a/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/UnimplementedServiceGrpc.java b/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/UnimplementedServiceGrpc.java index 743d68c3828..bec9b5a723a 100644 --- a/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/UnimplementedServiceGrpc.java +++ b/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/UnimplementedServiceGrpc.java @@ -8,9 +8,6 @@ * that case. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class UnimplementedServiceGrpc { @@ -63,6 +60,21 @@ public UnimplementedServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOpt return UnimplementedServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static UnimplementedServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public UnimplementedServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new UnimplementedServiceBlockingV2Stub(channel, callOptions); + } + }; + return UnimplementedServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -166,6 +178,37 @@ public void unimplementedCall(io.grpc.testing.integration.EmptyProtos.Empty requ * that case. * */ + public static final class UnimplementedServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private UnimplementedServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected UnimplementedServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new UnimplementedServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * A call that no server should implement
+     * 
+ */ + public io.grpc.testing.integration.EmptyProtos.Empty unimplementedCall(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getUnimplementedCallMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service UnimplementedService. + *
+   * A simple service NOT implemented at servers so clients can test for
+   * that case.
+   * 
+ */ public static final class UnimplementedServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private UnimplementedServiceBlockingStub( diff --git a/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/XdsUpdateClientConfigureServiceGrpc.java b/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/XdsUpdateClientConfigureServiceGrpc.java index 61cfc19d29b..3453b6c01be 100644 --- a/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/XdsUpdateClientConfigureServiceGrpc.java +++ b/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/XdsUpdateClientConfigureServiceGrpc.java @@ -7,9 +7,6 @@ * A service to dynamically update the configuration of an xDS test client. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class XdsUpdateClientConfigureServiceGrpc { @@ -62,6 +59,21 @@ public XdsUpdateClientConfigureServiceStub newStub(io.grpc.Channel channel, io.g return XdsUpdateClientConfigureServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static XdsUpdateClientConfigureServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public XdsUpdateClientConfigureServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new XdsUpdateClientConfigureServiceBlockingV2Stub(channel, callOptions); + } + }; + return XdsUpdateClientConfigureServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -161,6 +173,36 @@ public void configure(io.grpc.testing.integration.Messages.ClientConfigureReques * A service to dynamically update the configuration of an xDS test client. * */ + public static final class XdsUpdateClientConfigureServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private XdsUpdateClientConfigureServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected XdsUpdateClientConfigureServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new XdsUpdateClientConfigureServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * Update the tes client's configuration.
+     * 
+ */ + public io.grpc.testing.integration.Messages.ClientConfigureResponse configure(io.grpc.testing.integration.Messages.ClientConfigureRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getConfigureMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service XdsUpdateClientConfigureService. + *
+   * A service to dynamically update the configuration of an xDS test client.
+   * 
+ */ public static final class XdsUpdateClientConfigureServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private XdsUpdateClientConfigureServiceBlockingStub( diff --git a/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/XdsUpdateHealthServiceGrpc.java b/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/XdsUpdateHealthServiceGrpc.java index 6ba9419dedf..fb5f2cdebc7 100644 --- a/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/XdsUpdateHealthServiceGrpc.java +++ b/android-interop-testing/src/generated/release/grpc/io/grpc/testing/integration/XdsUpdateHealthServiceGrpc.java @@ -7,9 +7,6 @@ * A service to remotely control health status of an xDS test server. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class XdsUpdateHealthServiceGrpc { @@ -92,6 +89,21 @@ public XdsUpdateHealthServiceStub newStub(io.grpc.Channel channel, io.grpc.CallO return XdsUpdateHealthServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static XdsUpdateHealthServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public XdsUpdateHealthServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new XdsUpdateHealthServiceBlockingV2Stub(channel, callOptions); + } + }; + return XdsUpdateHealthServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -200,6 +212,40 @@ public void setNotServing(io.grpc.testing.integration.EmptyProtos.Empty request, * A service to remotely control health status of an xDS test server. * */ + public static final class XdsUpdateHealthServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private XdsUpdateHealthServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected XdsUpdateHealthServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new XdsUpdateHealthServiceBlockingV2Stub(channel, callOptions); + } + + /** + */ + public io.grpc.testing.integration.EmptyProtos.Empty setServing(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getSetServingMethod(), getCallOptions(), request); + } + + /** + */ + public io.grpc.testing.integration.EmptyProtos.Empty setNotServing(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getSetNotServingMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service XdsUpdateHealthService. + *
+   * A service to remotely control health status of an xDS test server.
+   * 
+ */ public static final class XdsUpdateHealthServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private XdsUpdateHealthServiceBlockingStub( diff --git a/android-interop-testing/src/main/AndroidManifest.xml b/android-interop-testing/src/main/AndroidManifest.xml index 35f3ee33a2b..08c139e5880 100644 --- a/android-interop-testing/src/main/AndroidManifest.xml +++ b/android-interop-testing/src/main/AndroidManifest.xml @@ -5,19 +5,19 @@ - + + android:theme="@style/Base.V7.Theme.AppCompat.Light"> + android:exported="true"> diff --git a/android-interop-testing/src/main/java/io/grpc/android/integrationtest/TesterActivity.java b/android-interop-testing/src/main/java/io/grpc/android/integrationtest/TesterActivity.java index fb5b35c42d5..17c7e24cbfa 100644 --- a/android-interop-testing/src/main/java/io/grpc/android/integrationtest/TesterActivity.java +++ b/android-interop-testing/src/main/java/io/grpc/android/integrationtest/TesterActivity.java @@ -121,7 +121,7 @@ private void startTest(String testCase) { ((InputMethodManager) getSystemService(Context.INPUT_METHOD_SERVICE)).hideSoftInputFromWindow( hostEdit.getWindowToken(), 0); enableButtons(false); - resultText.setText("Testing..."); + resultText.setText(R.string.testing_message); String host = hostEdit.getText().toString(); String portStr = portEdit.getText().toString(); diff --git a/android-interop-testing/src/main/res/layout/activity_tester.xml b/android-interop-testing/src/main/res/layout/activity_tester.xml index e25bd1bb6f6..042da6437c0 100644 --- a/android-interop-testing/src/main/res/layout/activity_tester.xml +++ b/android-interop-testing/src/main/res/layout/activity_tester.xml @@ -16,6 +16,7 @@ android:layout_weight="2" android:layout_width="0dp" android:layout_height="wrap_content" + android:inputType="text" android:hint="Enter Host" /> gRPC Integration Test + Testing… diff --git a/android/build.gradle b/android/build.gradle index 3b3bfa59b96..e94bf03ff37 100644 --- a/android/build.gradle +++ b/android/build.gradle @@ -7,20 +7,20 @@ plugins { description = 'gRPC: Android' android { - namespace 'io.grpc.android' + namespace = 'io.grpc.android' compileOptions { sourceCompatibility JavaVersion.VERSION_1_8 targetCompatibility JavaVersion.VERSION_1_8 } compileSdkVersion 34 defaultConfig { - minSdkVersion 21 + minSdkVersion 23 targetSdkVersion 33 versionCode 1 versionName "1.0" testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" } - lintOptions { abortOnError true } + lintOptions { abortOnError = true } publishing { singleVariant('release') { withSourcesJar() @@ -31,7 +31,6 @@ android { repositories { google() - mavenCentral() } dependencies { diff --git a/android/src/main/java/io/grpc/android/AndroidChannelBuilder.java b/android/src/main/java/io/grpc/android/AndroidChannelBuilder.java index 317b7a50b74..3a750e02795 100644 --- a/android/src/main/java/io/grpc/android/AndroidChannelBuilder.java +++ b/android/src/main/java/io/grpc/android/AndroidChannelBuilder.java @@ -28,6 +28,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.errorprone.annotations.InlineMe; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.CallOptions; import io.grpc.ClientCall; import io.grpc.ConnectivityState; @@ -41,7 +42,6 @@ import io.grpc.internal.GrpcUtil; import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; /** * Builds a {@link ManagedChannel} that, when provided with a {@link Context}, will automatically @@ -217,7 +217,6 @@ private void configureNetworkMonitoring() { connectivityManager.registerDefaultNetworkCallback(defaultNetworkCallback); unregisterRunnable = new Runnable() { - @TargetApi(Build.VERSION_CODES.LOLLIPOP) @Override public void run() { connectivityManager.unregisterNetworkCallback(defaultNetworkCallback); @@ -231,7 +230,6 @@ public void run() { context.registerReceiver(networkReceiver, networkIntentFilter); unregisterRunnable = new Runnable() { - @TargetApi(Build.VERSION_CODES.LOLLIPOP) @Override public void run() { context.unregisterReceiver(networkReceiver); @@ -325,7 +323,6 @@ public void onBlockedStatusChanged(Network network, boolean blocked) { /** Respond to network changes. Only used on API levels < 24. */ private class NetworkReceiver extends BroadcastReceiver { - private boolean isConnected = false; @SuppressWarnings("deprecation") @Override @@ -333,9 +330,8 @@ public void onReceive(Context context, Intent intent) { ConnectivityManager conn = (ConnectivityManager) context.getSystemService(Context.CONNECTIVITY_SERVICE); android.net.NetworkInfo networkInfo = conn.getActiveNetworkInfo(); - boolean wasConnected = isConnected; - isConnected = networkInfo != null && networkInfo.isConnected(); - if (isConnected && !wasConnected) { + + if (networkInfo != null && networkInfo.isConnected()) { delegate.enterIdle(); } } diff --git a/android/src/main/java/io/grpc/android/UdsChannelBuilder.java b/android/src/main/java/io/grpc/android/UdsChannelBuilder.java index e2dc7232378..6f03aa0ee5e 100644 --- a/android/src/main/java/io/grpc/android/UdsChannelBuilder.java +++ b/android/src/main/java/io/grpc/android/UdsChannelBuilder.java @@ -21,6 +21,7 @@ import io.grpc.ExperimentalApi; import io.grpc.InsecureChannelCredentials; import io.grpc.ManagedChannelBuilder; +import io.grpc.internal.GrpcUtil; import java.lang.reflect.InvocationTargetException; import javax.annotation.Nullable; import javax.net.SocketFactory; @@ -68,17 +69,20 @@ public static ManagedChannelBuilder forPath(String path, Namespace namespace) throw new UnsupportedOperationException("OkHttpChannelBuilder not found on the classpath"); } try { - // Target 'dns:///localhost' is unused, but necessary as an argument for OkHttpChannelBuilder. + // Target 'dns:///127.0.0.1' is unused, but necessary as an argument for OkHttpChannelBuilder. + // An IP address is used instead of localhost to avoid a DNS lookup (see #11442). This should + // work even if IPv4 is unavailable, as the DNS resolver doesn't need working IPv4 to parse an + // IPv4 address. Unavailable IPv4 fails when we connect(), not at resolution time. // TLS is unsupported because Conscrypt assumes the platform Socket implementation to improve // performance by using the file descriptor directly. Object o = OKHTTP_CHANNEL_BUILDER_CLASS .getMethod("forTarget", String.class, ChannelCredentials.class) - .invoke(null, "dns:///localhost", InsecureChannelCredentials.create()); + .invoke(null, "dns:///127.0.0.1", InsecureChannelCredentials.create()); ManagedChannelBuilder builder = OKHTTP_CHANNEL_BUILDER_CLASS.cast(o); OKHTTP_CHANNEL_BUILDER_CLASS .getMethod("socketFactory", SocketFactory.class) .invoke(builder, new UdsSocketFactory(path, namespace)); - return builder; + return builder.proxyDetector(GrpcUtil.NOOP_PROXY_DETECTOR); } catch (IllegalAccessException e) { throw new RuntimeException("Failed to create OkHttpChannelBuilder", e); } catch (NoSuchMethodException e) { diff --git a/android/src/test/java/io/grpc/android/AndroidChannelBuilderTest.java b/android/src/test/java/io/grpc/android/AndroidChannelBuilderTest.java index 83367d93b32..c0884e4d7cf 100644 --- a/android/src/test/java/io/grpc/android/AndroidChannelBuilderTest.java +++ b/android/src/test/java/io/grpc/android/AndroidChannelBuilderTest.java @@ -152,12 +152,6 @@ public void networkChanges_api23() { .sendBroadcast(new Intent(ConnectivityManager.CONNECTIVITY_ACTION)); assertThat(delegateChannel.enterIdleCount).isEqualTo(1); - // The broadcast receiver may fire when the active network status has not actually changed - ApplicationProvider - .getApplicationContext() - .sendBroadcast(new Intent(ConnectivityManager.CONNECTIVITY_ACTION)); - assertThat(delegateChannel.enterIdleCount).isEqualTo(1); - // Drop the connection shadowOf(connectivityManager).setActiveNetworkInfo(null); ApplicationProvider diff --git a/api/BUILD.bazel b/api/BUILD.bazel index efd1e8d9660..6de00d6272d 100644 --- a/api/BUILD.bazel +++ b/api/BUILD.bazel @@ -1,16 +1,35 @@ +load("@rules_java//java:defs.bzl", "java_library") +load("@rules_jvm_external//:defs.bzl", "artifact") + java_library( name = "api", srcs = glob([ "src/main/java/**/*.java", "src/context/java/**/*.java", ]), - javacopts = ["-Xep:DoNotCall:OFF"], # Remove once requiring Bazel 3.4.0+; allows non-final visibility = ["//visibility:public"], deps = [ - "@com_google_code_findbugs_jsr305//jar", - "@com_google_errorprone_error_prone_annotations//jar", - "@com_google_guava_failureaccess//jar", # future transitive dep of Guava. See #5214 - "@com_google_guava_guava//jar", - "@com_google_j2objc_j2objc_annotations//jar", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:failureaccess"), # future transitive dep of Guava. See #5214 + artifact("com.google.guava:guava"), + ], +) + +java_library( + name = "test_fixtures", + testonly = 1, + srcs = glob([ + "src/testFixtures/java/io/grpc/**/*.java", + ]), + visibility = ["//xds:__pkg__"], + deps = [ + "//core", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:guava"), + artifact("com.google.truth:truth"), + artifact("junit:junit"), + artifact("org.mockito:mockito-core"), ], ) diff --git a/api/build.gradle b/api/build.gradle index 0a80a1e48b9..745fa00b3f1 100644 --- a/api/build.gradle +++ b/api/build.gradle @@ -47,15 +47,33 @@ dependencies { testImplementation project(':grpc-core') testImplementation project(':grpc-testing') testImplementation libraries.guava.testlib + testImplementation libraries.truth - signature libraries.signature.java - signature libraries.signature.android + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } + signature (libraries.signature.android) { + artifact { + extension = "signature" + } + } +} + +animalsniffer { + annotation = 'io.grpc.IgnoreJRERequirement' } tasks.named("javadoc").configure { source sourceSets.context.allSource // We want io.grpc.Internal, but not io.grpc.Internal* + exclude 'io/grpc/*MetricInstrument.java' + exclude 'io/grpc/*MetricInstrumentRegistry.java' exclude 'io/grpc/Internal?*.java' + exclude 'io/grpc/MetricRecorder.java' + exclude 'io/grpc/MetricSink.java' + exclude 'io/grpc/Uri.java' } tasks.named("sourcesJar").configure { diff --git a/api/src/context/java/io/grpc/Deadline.java b/api/src/context/java/io/grpc/Deadline.java index 62b803267a8..92eeba5ffce 100644 --- a/api/src/context/java/io/grpc/Deadline.java +++ b/api/src/context/java/io/grpc/Deadline.java @@ -16,8 +16,10 @@ package io.grpc; -import java.util.Arrays; +import static java.util.Objects.requireNonNull; + import java.util.Locale; +import java.util.Objects; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; @@ -33,7 +35,7 @@ * passed to the various components unambiguously. */ public final class Deadline implements Comparable { - private static final SystemTicker SYSTEM_TICKER = new SystemTicker(); + private static final Ticker SYSTEM_TICKER = new SystemTicker(); // nanoTime has a range of just under 300 years. Only allow up to 100 years in the past or future // to prevent wraparound as long as process runs for less than ~100 years. private static final long MAX_OFFSET = TimeUnit.DAYS.toNanos(100 * 365); @@ -91,7 +93,7 @@ public static Deadline after(long duration, TimeUnit units) { * @since 1.24.0 */ public static Deadline after(long duration, TimeUnit units, Ticker ticker) { - checkNotNull(units, "units"); + requireNonNull(units, "units"); return new Deadline(ticker, units.toNanos(duration), true); } @@ -191,8 +193,8 @@ public long timeRemaining(TimeUnit unit) { * @return {@link ScheduledFuture} which can be used to cancel execution of the task */ public ScheduledFuture runOnExpiration(Runnable task, ScheduledExecutorService scheduler) { - checkNotNull(task, "task"); - checkNotNull(scheduler, "scheduler"); + requireNonNull(task, "task"); + requireNonNull(scheduler, "scheduler"); return scheduler.schedule(task, deadlineNanos - ticker.nanoTime(), TimeUnit.NANOSECONDS); } @@ -225,37 +227,27 @@ public String toString() { @Override public int compareTo(Deadline that) { checkTicker(that); - long diff = this.deadlineNanos - that.deadlineNanos; - if (diff < 0) { - return -1; - } else if (diff > 0) { - return 1; - } - return 0; + return Long.compare(this.deadlineNanos, that.deadlineNanos); } @Override public int hashCode() { - return Arrays.asList(this.ticker, this.deadlineNanos).hashCode(); + return Objects.hash(this.ticker, this.deadlineNanos); } @Override - public boolean equals(final Object o) { - if (o == this) { + public boolean equals(final Object object) { + if (object == this) { return true; } - if (!(o instanceof Deadline)) { - return false; - } - - final Deadline other = (Deadline) o; - if (this.ticker == null ? other.ticker != null : this.ticker != other.ticker) { + if (!(object instanceof Deadline)) { return false; } - if (this.deadlineNanos != other.deadlineNanos) { + final Deadline that = (Deadline) object; + if (this.ticker == null ? that.ticker != null : this.ticker != that.ticker) { return false; } - return true; + return this.deadlineNanos == that.deadlineNanos; } /** @@ -275,24 +267,17 @@ public boolean equals(final Object o) { * @since 1.24.0 */ public abstract static class Ticker { - /** Returns the number of nanoseconds since this source's epoch. */ + /** Returns the number of nanoseconds elapsed since this ticker's reference point in time. */ public abstract long nanoTime(); } - private static class SystemTicker extends Ticker { + private static final class SystemTicker extends Ticker { @Override public long nanoTime() { return System.nanoTime(); } } - private static T checkNotNull(T reference, Object errorMessage) { - if (reference == null) { - throw new NullPointerException(String.valueOf(errorMessage)); - } - return reference; - } - private void checkTicker(Deadline other) { if (ticker != other.ticker) { throw new AssertionError( diff --git a/api/src/main/java/io/grpc/Attributes.java b/api/src/main/java/io/grpc/Attributes.java index ca065765506..c8550d176b4 100644 --- a/api/src/main/java/io/grpc/Attributes.java +++ b/api/src/main/java/io/grpc/Attributes.java @@ -215,6 +215,7 @@ public int hashCode() { * The helper class to build an Attributes instance. */ public static final class Builder { + // Exactly one of base and newdata will be set private Attributes base; private IdentityHashMap, Object> newdata; @@ -225,8 +226,11 @@ private Builder(Attributes base) { private IdentityHashMap, Object> data(int size) { if (newdata == null) { - newdata = new IdentityHashMap<>(size); + newdata = new IdentityHashMap<>(base.data.size() + size); + newdata.putAll(base.data); + base = null; } + assert base == null; return newdata; } @@ -236,20 +240,18 @@ public Builder set(Key key, T value) { } /** - * Removes the key and associated value from the attribtues. + * Removes the key and associated value from the attributes. * * @since 1.22.0 * @param key The key to remove * @return this */ - @ExperimentalApi("https://github.com/grpc/grpc-java/issues/5777") public Builder discard(Key key) { - if (base.data.containsKey(key)) { - IdentityHashMap, Object> newBaseData = new IdentityHashMap<>(base.data); - newBaseData.remove(key); - base = new Attributes(newBaseData); - } - if (newdata != null) { + if (base != null) { + if (base.data.containsKey(key)) { + data(0).remove(key); + } + } else { newdata.remove(key); } return this; @@ -265,11 +267,6 @@ public Builder setAll(Attributes other) { */ public Attributes build() { if (newdata != null) { - for (Map.Entry, Object> entry : base.data.entrySet()) { - if (!newdata.containsKey(entry.getKey())) { - newdata.put(entry.getKey(), entry.getValue()); - } - } base = new Attributes(newdata); newdata = null; } diff --git a/api/src/main/java/io/grpc/CallCredentials.java b/api/src/main/java/io/grpc/CallCredentials.java index 31b68b22dae..eb92a6f15fa 100644 --- a/api/src/main/java/io/grpc/CallCredentials.java +++ b/api/src/main/java/io/grpc/CallCredentials.java @@ -43,7 +43,7 @@ public abstract class CallCredentials { *

It is called for each individual RPC, within the {@link Context} of the call, before the * stream is about to be created on a transport. Implementations should not block in this * method. If metadata is not immediately available, e.g., needs to be fetched from network, the - * implementation may give the {@code applier} to an asynchronous task which will eventually call + * implementation may give the {@code appExecutor} an asynchronous task which will eventually call * the {@code applier}. The RPC proceeds only after the {@code applier} is called. * * @param requestInfo request-related information diff --git a/api/src/main/java/io/grpc/CallOptions.java b/api/src/main/java/io/grpc/CallOptions.java index 87493d2ba0b..800bdfb6c90 100644 --- a/api/src/main/java/io/grpc/CallOptions.java +++ b/api/src/main/java/io/grpc/CallOptions.java @@ -17,16 +17,18 @@ package io.grpc; import static com.google.common.base.Preconditions.checkArgument; +import static io.grpc.TimeUtils.convertToNanos; import com.google.common.base.MoreObjects; import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.CheckReturnValue; +import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; import javax.annotation.concurrent.Immutable; @@ -79,6 +81,8 @@ public final class CallOptions { private final Integer maxInboundMessageSize; @Nullable private final Integer maxOutboundMessageSize; + @Nullable + private final Integer onReadyThreshold; private CallOptions(Builder builder) { this.deadline = builder.deadline; @@ -91,6 +95,7 @@ private CallOptions(Builder builder) { this.waitForReady = builder.waitForReady; this.maxInboundMessageSize = builder.maxInboundMessageSize; this.maxOutboundMessageSize = builder.maxOutboundMessageSize; + this.onReadyThreshold = builder.onReadyThreshold; } static class Builder { @@ -105,6 +110,7 @@ static class Builder { Boolean waitForReady; Integer maxInboundMessageSize; Integer maxOutboundMessageSize; + Integer onReadyThreshold; private CallOptions build() { return new CallOptions(this); @@ -172,6 +178,11 @@ public CallOptions withDeadlineAfter(long duration, TimeUnit unit) { return withDeadline(Deadline.after(duration, unit)); } + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/11657") + public CallOptions withDeadlineAfter(Duration duration) { + return withDeadlineAfter(convertToNanos(duration), TimeUnit.NANOSECONDS); + } + /** * Returns the deadline or {@code null} if the deadline is not set. */ @@ -203,6 +214,46 @@ public CallOptions withoutWaitForReady() { return builder.build(); } + /** + * Specifies how many bytes must be queued before the call is + * considered not ready to send more messages. + * + * @param numBytes The number of bytes that must be queued. Must be a + * positive integer. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/11021") + public CallOptions withOnReadyThreshold(int numBytes) { + checkArgument(numBytes > 0, "numBytes must be positive: %s", numBytes); + Builder builder = toBuilder(this); + builder.onReadyThreshold = numBytes; + return builder.build(); + } + + /** + * Resets to the default number of bytes that must be queued before the + * call will leave the + * 'wait for ready' state. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/11021") + public CallOptions clearOnReadyThreshold() { + Builder builder = toBuilder(this); + builder.onReadyThreshold = null; + return builder.build(); + } + + /** + * Returns to the default number of bytes that must be queued before the + * call will leave the + * 'wait for ready' state. + * + * @return null if the default threshold is used. + */ + @Nullable + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/11021") + public Integer getOnReadyThreshold() { + return onReadyThreshold; + } + /** * Returns the compressor's name. */ @@ -468,6 +519,7 @@ private static Builder toBuilder(CallOptions other) { builder.waitForReady = other.waitForReady; builder.maxInboundMessageSize = other.maxInboundMessageSize; builder.maxOutboundMessageSize = other.maxOutboundMessageSize; + builder.onReadyThreshold = other.onReadyThreshold; return builder; } @@ -483,6 +535,7 @@ public String toString() { .add("waitForReady", isWaitForReady()) .add("maxInboundMessageSize", maxInboundMessageSize) .add("maxOutboundMessageSize", maxOutboundMessageSize) + .add("onReadyThreshold", onReadyThreshold) .add("streamTracerFactories", streamTracerFactories) .toString(); } diff --git a/api/src/main/java/io/grpc/CallbackMetricInstrument.java b/api/src/main/java/io/grpc/CallbackMetricInstrument.java new file mode 100644 index 00000000000..1d66d5340ed --- /dev/null +++ b/api/src/main/java/io/grpc/CallbackMetricInstrument.java @@ -0,0 +1,23 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +/** + * Tagging interface for MetricInstruments that can be used with batch callbacks. + */ +@Internal +public interface CallbackMetricInstrument extends MetricInstrument {} diff --git a/api/src/main/java/io/grpc/Channel.java b/api/src/main/java/io/grpc/Channel.java index 60ff76ff082..e2787eb2f26 100644 --- a/api/src/main/java/io/grpc/Channel.java +++ b/api/src/main/java/io/grpc/Channel.java @@ -16,7 +16,6 @@ package io.grpc; -import javax.annotation.concurrent.ThreadSafe; /** * A virtual connection to a conceptual endpoint, to perform RPCs. A channel is free to have zero or @@ -29,8 +28,9 @@ * implementations using {@link ClientInterceptor}. It is expected that most application * code will not use this class directly but rather work with stubs that have been bound to a * Channel that was decorated during application initialization. + * + *

This class is thread-safe. */ -@ThreadSafe public abstract class Channel { /** * Create a {@link ClientCall} to the remote operation specified by the given diff --git a/api/src/main/java/io/grpc/ChannelLogger.java b/api/src/main/java/io/grpc/ChannelLogger.java index ce654ec9d5b..2cdf4c84724 100644 --- a/api/src/main/java/io/grpc/ChannelLogger.java +++ b/api/src/main/java/io/grpc/ChannelLogger.java @@ -16,15 +16,15 @@ package io.grpc; -import javax.annotation.concurrent.ThreadSafe; /** * A Channel-specific logger provided by GRPC library to {@link LoadBalancer} implementations. * Information logged here goes to Channelz, and to the Java logger of this class * as well. + * + *

This class is thread-safe. */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/5029") -@ThreadSafe public abstract class ChannelLogger { /** * Log levels. See the table below for the mapping from the ChannelLogger levels to Channelz diff --git a/api/src/main/java/io/grpc/ClientCall.java b/api/src/main/java/io/grpc/ClientCall.java index df9e15001e1..c915c8beaac 100644 --- a/api/src/main/java/io/grpc/ClientCall.java +++ b/api/src/main/java/io/grpc/ClientCall.java @@ -67,7 +67,7 @@ * manner, and notifies gRPC library to receive additional response after one is consumed by * a fictional processResponse(). * - *

+ * 
  *   call = channel.newCall(bidiStreamingMethod, callOptions);
  *   listener = new ClientCall.Listener<FooResponse>() {
  *     @Override
diff --git a/api/src/main/java/io/grpc/ClientInterceptor.java b/api/src/main/java/io/grpc/ClientInterceptor.java
index c27c31c8474..d6c8cd7e6fb 100644
--- a/api/src/main/java/io/grpc/ClientInterceptor.java
+++ b/api/src/main/java/io/grpc/ClientInterceptor.java
@@ -16,7 +16,6 @@
 
 package io.grpc;
 
-import javax.annotation.concurrent.ThreadSafe;
 
 /**
  * Interface for intercepting outgoing calls before they are dispatched by a {@link Channel}.
@@ -37,8 +36,10 @@
  * without completing the previous ones first. Refer to the
  * {@link io.grpc.ClientCall.Listener ClientCall.Listener} docs for more details regarding thread
  * safety of the returned listener.
+ * 
+ * 

This is thread-safe and should be considered + * for the errorprone ThreadSafe annotation in the future. */ -@ThreadSafe public interface ClientInterceptor { /** * Intercept {@link ClientCall} creation by the {@code next} {@link Channel}. diff --git a/api/src/main/java/io/grpc/ClientStreamTracer.java b/api/src/main/java/io/grpc/ClientStreamTracer.java index cb2f5538e34..8e11e781e7c 100644 --- a/api/src/main/java/io/grpc/ClientStreamTracer.java +++ b/api/src/main/java/io/grpc/ClientStreamTracer.java @@ -19,13 +19,13 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.base.MoreObjects; -import javax.annotation.concurrent.ThreadSafe; /** * {@link StreamTracer} for the client-side. + * + *

This class is thread-safe. */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/2861") -@ThreadSafe public abstract class ClientStreamTracer extends StreamTracer { /** * Indicates how long the call was delayed, in nanoseconds, due to waiting for name resolution @@ -70,15 +70,35 @@ public void inboundHeaders() { } /** - * Trailing metadata has been received from the server. + * Headers has been received from the server. This method does not pass ownership to {@code + * headers}, so implementations must not access the metadata after returning. Modifications to the + * metadata within this method will be seen by interceptors and the application. + * + * @param headers the received header metadata + */ + public void inboundHeaders(Metadata headers) { + inboundHeaders(); + } + + /** + * Trailing metadata has been received from the server. This method does not pass ownership to + * {@code trailers}, so implementations must not access the metadata after returning. + * Modifications to the metadata within this method will be seen by interceptors and the + * application. * - * @param trailers the mutable trailing metadata. Modifications to it will be seen by - * interceptors and the application. + * @param trailers the received trailing metadata * @since 1.17.0 */ public void inboundTrailers(Metadata trailers) { } + /** + * Information providing context to the call became available. + */ + @Internal + public void addOptionalLabel(String key, String value) { + } + /** * Factory class for {@link ClientStreamTracer}. */ @@ -112,12 +132,15 @@ public static final class StreamInfo { private final CallOptions callOptions; private final int previousAttempts; private final boolean isTransparentRetry; + private final boolean isHedging; StreamInfo( - CallOptions callOptions, int previousAttempts, boolean isTransparentRetry) { + CallOptions callOptions, int previousAttempts, boolean isTransparentRetry, + boolean isHedging) { this.callOptions = checkNotNull(callOptions, "callOptions"); this.previousAttempts = previousAttempts; this.isTransparentRetry = isTransparentRetry; + this.isHedging = isHedging; } /** @@ -145,6 +168,15 @@ public boolean isTransparentRetry() { return isTransparentRetry; } + /** + * Whether the stream is hedging. + * + * @since 1.74.0 + */ + public boolean isHedging() { + return isHedging; + } + /** * Converts this StreamInfo into a new Builder. * @@ -154,7 +186,9 @@ public Builder toBuilder() { return new Builder() .setCallOptions(callOptions) .setPreviousAttempts(previousAttempts) - .setIsTransparentRetry(isTransparentRetry); + .setIsTransparentRetry(isTransparentRetry) + .setIsHedging(isHedging); + } /** @@ -172,6 +206,7 @@ public String toString() { .add("callOptions", callOptions) .add("previousAttempts", previousAttempts) .add("isTransparentRetry", isTransparentRetry) + .add("isHedging", isHedging) .toString(); } @@ -184,6 +219,7 @@ public static final class Builder { private CallOptions callOptions = CallOptions.DEFAULT; private int previousAttempts; private boolean isTransparentRetry; + private boolean isHedging; Builder() { } @@ -216,11 +252,21 @@ public Builder setIsTransparentRetry(boolean isTransparentRetry) { return this; } + /** + * Sets whether the stream is hedging. + * + * @since 1.74.0 + */ + public Builder setIsHedging(boolean isHedging) { + this.isHedging = isHedging; + return this; + } + /** * Builds a new StreamInfo. */ public StreamInfo build() { - return new StreamInfo(callOptions, previousAttempts, isTransparentRetry); + return new StreamInfo(callOptions, previousAttempts, isTransparentRetry, isHedging); } } } diff --git a/api/src/main/java/io/grpc/Configurator.java b/api/src/main/java/io/grpc/Configurator.java new file mode 100644 index 00000000000..90468769a8d --- /dev/null +++ b/api/src/main/java/io/grpc/Configurator.java @@ -0,0 +1,36 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +/** + * Provides hooks for modifying gRPC channels and servers during their construction. + */ +interface Configurator { + /** + * Allows implementations to modify the channel builder. + * + * @param channelBuilder the channel builder being constructed + */ + default void configureChannelBuilder(ManagedChannelBuilder channelBuilder) {} + + /** + * Allows implementations to modify the server builder. + * + * @param serverBuilder the server builder being constructed + */ + default void configureServerBuilder(ServerBuilder serverBuilder) {} +} diff --git a/api/src/main/java/io/grpc/ConfiguratorRegistry.java b/api/src/main/java/io/grpc/ConfiguratorRegistry.java new file mode 100644 index 00000000000..19d6703d308 --- /dev/null +++ b/api/src/main/java/io/grpc/ConfiguratorRegistry.java @@ -0,0 +1,87 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import com.google.errorprone.annotations.concurrent.GuardedBy; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * A registry for {@link Configurator} instances. + * + *

This class is responsible for maintaining a list of configurators and providing access to + * them. The default registry can be obtained using {@link #getDefaultRegistry()}. + */ +final class ConfiguratorRegistry { + private static ConfiguratorRegistry instance; + + @GuardedBy("this") + private boolean wasConfiguratorsSet; + @GuardedBy("this") + private List configurators = Collections.emptyList(); + @GuardedBy("this") + private int configuratorsCallCountBeforeSet = 0; + + ConfiguratorRegistry() {} + + /** + * Returns the default global instance of the configurator registry. + */ + public static synchronized ConfiguratorRegistry getDefaultRegistry() { + if (instance == null) { + instance = new ConfiguratorRegistry(); + } + return instance; + } + + /** + * Sets the configurators in this registry. This method can only be called once. + * + * @param configurators the configurators to set + * @throws IllegalStateException if this method is called more than once + */ + public synchronized void setConfigurators(List configurators) { + if (wasConfiguratorsSet) { + throw new IllegalStateException("Configurators are already set"); + } + this.configurators = Collections.unmodifiableList(new ArrayList<>(configurators)); + wasConfiguratorsSet = true; + } + + /** + * Returns a list of the configurators in this registry. + */ + public synchronized List getConfigurators() { + if (!wasConfiguratorsSet) { + configuratorsCallCountBeforeSet++; + } + return configurators; + } + + /** + * Returns the number of times getConfigurators() was called before + * setConfigurators() was successfully invoked. + */ + public synchronized int getConfiguratorsCallCountBeforeSet() { + return configuratorsCallCountBeforeSet; + } + + public synchronized boolean wasSetConfiguratorsCalled() { + return wasConfiguratorsSet; + } +} diff --git a/api/src/main/java/io/grpc/ConnectivityState.java b/api/src/main/java/io/grpc/ConnectivityState.java index 677039b2517..a7407efb2e9 100644 --- a/api/src/main/java/io/grpc/ConnectivityState.java +++ b/api/src/main/java/io/grpc/ConnectivityState.java @@ -20,7 +20,7 @@ * The connectivity states. * * @see - * more information + * more information */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/4359") public enum ConnectivityState { diff --git a/api/src/main/java/io/grpc/DoubleCounterMetricInstrument.java b/api/src/main/java/io/grpc/DoubleCounterMetricInstrument.java new file mode 100644 index 00000000000..3f07d83d58f --- /dev/null +++ b/api/src/main/java/io/grpc/DoubleCounterMetricInstrument.java @@ -0,0 +1,30 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import java.util.List; + +/** + * Represents a double-valued counter metric instrument. + */ +@Internal +public final class DoubleCounterMetricInstrument extends PartialMetricInstrument { + public DoubleCounterMetricInstrument(int index, String name, String description, String unit, + List requiredLabelKeys, List optionalLabelKeys, boolean enableByDefault) { + super(index, name, description, unit, requiredLabelKeys, optionalLabelKeys, enableByDefault); + } +} diff --git a/api/src/main/java/io/grpc/DoubleHistogramMetricInstrument.java b/api/src/main/java/io/grpc/DoubleHistogramMetricInstrument.java new file mode 100644 index 00000000000..9039a8c62c1 --- /dev/null +++ b/api/src/main/java/io/grpc/DoubleHistogramMetricInstrument.java @@ -0,0 +1,38 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import java.util.List; + +/** + * Represents a double-valued histogram metric instrument. + */ +@Internal +public final class DoubleHistogramMetricInstrument extends PartialMetricInstrument { + private final List bucketBoundaries; + + public DoubleHistogramMetricInstrument(int index, String name, String description, String unit, + List bucketBoundaries, List requiredLabelKeys, List optionalLabelKeys, + boolean enableByDefault) { + super(index, name, description, unit, requiredLabelKeys, optionalLabelKeys, enableByDefault); + this.bucketBoundaries = bucketBoundaries; + } + + public List getBucketBoundaries() { + return bucketBoundaries; + } +} diff --git a/api/src/main/java/io/grpc/EquivalentAddressGroup.java b/api/src/main/java/io/grpc/EquivalentAddressGroup.java index 4b3db006684..2dd52fe7f21 100644 --- a/api/src/main/java/io/grpc/EquivalentAddressGroup.java +++ b/api/src/main/java/io/grpc/EquivalentAddressGroup.java @@ -50,6 +50,26 @@ public final class EquivalentAddressGroup { @ExperimentalApi("https://github.com/grpc/grpc-java/issues/6138") public static final Attributes.Key ATTR_AUTHORITY_OVERRIDE = Attributes.Key.create("io.grpc.EquivalentAddressGroup.ATTR_AUTHORITY_OVERRIDE"); + /** + * The name of the locality that this EquivalentAddressGroup is in. + */ + public static final Attributes.Key ATTR_LOCALITY_NAME = + Attributes.Key.create("io.grpc.EquivalentAddressGroup.LOCALITY"); + /** + * The backend service associated with this EquivalentAddressGroup. + */ + @Attr + static final Attributes.Key ATTR_BACKEND_SERVICE = + Attributes.Key.create("io.grpc.EquivalentAddressGroup.BACKEND_SERVICE"); + /** + * Endpoint weight for load balancing purposes. While the type is Long, it must be a valid uint32. + * Must not be zero. The weight is proportional to the other endpoints; if an endpoint's weight is + * twice that of another endpoint, it is intended to receive twice the load. + */ + @Attr + static final Attributes.Key ATTR_WEIGHT = + Attributes.Key.create("io.grpc.EquivalentAddressGroup.ATTR_WEIGHT"); + private final List addrs; private final Attributes attrs; @@ -108,7 +128,9 @@ public Attributes getAttributes() { @Override public String toString() { - // TODO(zpencer): Summarize return value if addr is very large + // EquivalentAddressGroup is intended to contain a small number of addresses for the same + // endpoint(e.g., IPv4/IPv6). Aggregating many groups into a single EquivalentAddressGroup + // is no longer done, so this no longer needs summarization. return "[" + addrs + "/" + attrs + "]"; } diff --git a/api/src/main/java/io/grpc/FeatureFlags.java b/api/src/main/java/io/grpc/FeatureFlags.java new file mode 100644 index 00000000000..0e414ed7b31 --- /dev/null +++ b/api/src/main/java/io/grpc/FeatureFlags.java @@ -0,0 +1,54 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Strings; + +class FeatureFlags { + private static boolean enableRfc3986Uris = getFlag("GRPC_ENABLE_RFC3986_URIS", false); + + /** Whether to parse targets as RFC 3986 URIs (true), or use {@link java.net.URI} (false). */ + @VisibleForTesting + static boolean setRfc3986UrisEnabled(boolean value) { + boolean prevValue = enableRfc3986Uris; + enableRfc3986Uris = value; + return prevValue; + } + + /** Whether to parse targets as RFC 3986 URIs (true), or use {@link java.net.URI} (false). */ + static boolean getRfc3986UrisEnabled() { + return enableRfc3986Uris; + } + + static boolean getFlag(String envVarName, boolean enableByDefault) { + String envVar = System.getenv(envVarName); + if (envVar == null) { + envVar = System.getProperty(envVarName); + } + if (envVar != null) { + envVar = envVar.trim(); + } + if (enableByDefault) { + return Strings.isNullOrEmpty(envVar) || Boolean.parseBoolean(envVar); + } else { + return !Strings.isNullOrEmpty(envVar) && Boolean.parseBoolean(envVar); + } + } + + private FeatureFlags() {} +} diff --git a/api/src/main/java/io/grpc/ForwardingChannelBuilder2.java b/api/src/main/java/io/grpc/ForwardingChannelBuilder2.java index 12ed275c06e..78fe730d91a 100644 --- a/api/src/main/java/io/grpc/ForwardingChannelBuilder2.java +++ b/api/src/main/java/io/grpc/ForwardingChannelBuilder2.java @@ -94,6 +94,12 @@ public T intercept(ClientInterceptor... interceptors) { return thisT(); } + @Override + protected T interceptWithTarget(InterceptorFactory factory) { + delegate().interceptWithTarget(factory); + return thisT(); + } + @Override public T addTransportFilter(ClientTransportFilter transportFilter) { delegate().addTransportFilter(transportFilter); @@ -251,6 +257,18 @@ public T disableServiceConfigLookUp() { return thisT(); } + @Override + protected T addMetricSink(MetricSink metricSink) { + delegate().addMetricSink(metricSink); + return thisT(); + } + + @Override + public T setNameResolverArg(NameResolver.Args.Key key, X value) { + delegate().setNameResolverArg(key, value); + return thisT(); + } + /** * Returns the {@link ManagedChannel} built by the delegate by default. Overriding method can * return different value. diff --git a/api/src/main/java/io/grpc/ForwardingServerBuilder.java b/api/src/main/java/io/grpc/ForwardingServerBuilder.java index 9cef7cfa331..d1f183dd824 100644 --- a/api/src/main/java/io/grpc/ForwardingServerBuilder.java +++ b/api/src/main/java/io/grpc/ForwardingServerBuilder.java @@ -201,6 +201,12 @@ public Server build() { return delegate().build(); } + @Override + public T addMetricSink(MetricSink metricSink) { + delegate().addMetricSink(metricSink); + return thisT(); + } + @Override public String toString() { return MoreObjects.toStringHelper(this).add("delegate", delegate()).toString(); diff --git a/api/src/main/java/io/grpc/GlobalInterceptors.java b/api/src/main/java/io/grpc/GlobalInterceptors.java deleted file mode 100644 index e5fd86170f0..00000000000 --- a/api/src/main/java/io/grpc/GlobalInterceptors.java +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Copyright 2022 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc; - -import static com.google.common.base.Preconditions.checkNotNull; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; - -/** The collection of global interceptors and global server stream tracers. */ -@Internal -final class GlobalInterceptors { - private static List clientInterceptors = null; - private static List serverInterceptors = null; - private static List serverStreamTracerFactories = - null; - private static boolean isGlobalInterceptorsTracersSet; - private static boolean isGlobalInterceptorsTracersGet; - - // Prevent instantiation - private GlobalInterceptors() {} - - /** - * Sets the list of global interceptors and global server stream tracers. - * - *

If {@code setInterceptorsTracers()} is called again, this method will throw {@link - * IllegalStateException}. - * - *

It is only safe to call early. This method throws {@link IllegalStateException} after any of - * the get calls [{@link #getClientInterceptors()}, {@link #getServerInterceptors()} or {@link - * #getServerStreamTracerFactories()}] has been called, in order to limit changes to the result of - * {@code setInterceptorsTracers()}. - * - * @param clientInterceptorList list of {@link ClientInterceptor} that make up global Client - * Interceptors. - * @param serverInterceptorList list of {@link ServerInterceptor} that make up global Server - * Interceptors. - * @param serverStreamTracerFactoryList list of {@link ServerStreamTracer.Factory} that make up - * global ServerStreamTracer factories. - */ - static synchronized void setInterceptorsTracers( - List clientInterceptorList, - List serverInterceptorList, - List serverStreamTracerFactoryList) { - if (isGlobalInterceptorsTracersGet) { - throw new IllegalStateException("Set cannot be called after any get call"); - } - if (isGlobalInterceptorsTracersSet) { - throw new IllegalStateException("Global interceptors and tracers are already set"); - } - checkNotNull(clientInterceptorList); - checkNotNull(serverInterceptorList); - checkNotNull(serverStreamTracerFactoryList); - clientInterceptors = Collections.unmodifiableList(new ArrayList<>(clientInterceptorList)); - serverInterceptors = Collections.unmodifiableList(new ArrayList<>(serverInterceptorList)); - serverStreamTracerFactories = - Collections.unmodifiableList(new ArrayList<>(serverStreamTracerFactoryList)); - isGlobalInterceptorsTracersSet = true; - } - - /** Returns the list of global {@link ClientInterceptor}. If not set, this returns null. */ - static synchronized List getClientInterceptors() { - isGlobalInterceptorsTracersGet = true; - return clientInterceptors; - } - - /** Returns list of global {@link ServerInterceptor}. If not set, this returns null. */ - static synchronized List getServerInterceptors() { - isGlobalInterceptorsTracersGet = true; - return serverInterceptors; - } - - /** Returns list of global {@link ServerStreamTracer.Factory}. If not set, this returns null. */ - static synchronized List getServerStreamTracerFactories() { - isGlobalInterceptorsTracersGet = true; - return serverStreamTracerFactories; - } -} diff --git a/api/src/main/java/io/grpc/Grpc.java b/api/src/main/java/io/grpc/Grpc.java index baa9f5f0ab6..a45c613fd18 100644 --- a/api/src/main/java/io/grpc/Grpc.java +++ b/api/src/main/java/io/grpc/Grpc.java @@ -56,6 +56,13 @@ private Grpc() { public static final Attributes.Key TRANSPORT_ATTR_SSL_SESSION = Attributes.Key.create("io.grpc.Grpc.TRANSPORT_ATTR_SSL_SESSION"); + /** + * The value for the custom label of per-RPC metrics. Defaults to empty string when unset. Must + * not be set to {@code null}. + */ + public static final CallOptions.Key CALL_OPTION_CUSTOM_LABEL = + CallOptions.Key.createWithDefault("io.grpc.Grpc.CALL_OPTION_CUSTOM_LABEL", ""); + /** * Annotation for transport attributes. It follows the annotation semantics defined * by {@link Attributes}. diff --git a/api/src/main/java/io/grpc/HandlerRegistry.java b/api/src/main/java/io/grpc/HandlerRegistry.java index 4aaf0114fb1..148573ada9a 100644 --- a/api/src/main/java/io/grpc/HandlerRegistry.java +++ b/api/src/main/java/io/grpc/HandlerRegistry.java @@ -19,12 +19,12 @@ import java.util.Collections; import java.util.List; import javax.annotation.Nullable; -import javax.annotation.concurrent.ThreadSafe; /** * Registry of services and their methods used by servers to dispatching incoming calls. + * + *

This class is thread-safe. */ -@ThreadSafe public abstract class HandlerRegistry { /** diff --git a/api/src/main/java/io/grpc/HttpConnectProxiedSocketAddress.java b/api/src/main/java/io/grpc/HttpConnectProxiedSocketAddress.java index d59c53db1d1..0df8dc452c1 100644 --- a/api/src/main/java/io/grpc/HttpConnectProxiedSocketAddress.java +++ b/api/src/main/java/io/grpc/HttpConnectProxiedSocketAddress.java @@ -23,6 +23,9 @@ import com.google.common.base.Objects; import java.net.InetSocketAddress; import java.net.SocketAddress; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import javax.annotation.Nullable; /** @@ -33,6 +36,8 @@ public final class HttpConnectProxiedSocketAddress extends ProxiedSocketAddress private final SocketAddress proxyAddress; private final InetSocketAddress targetAddress; + @SuppressWarnings("serial") + private final Map headers; @Nullable private final String username; @Nullable @@ -41,6 +46,7 @@ public final class HttpConnectProxiedSocketAddress extends ProxiedSocketAddress private HttpConnectProxiedSocketAddress( SocketAddress proxyAddress, InetSocketAddress targetAddress, + Map headers, @Nullable String username, @Nullable String password) { checkNotNull(proxyAddress, "proxyAddress"); @@ -53,6 +59,7 @@ private HttpConnectProxiedSocketAddress( } this.proxyAddress = proxyAddress; this.targetAddress = targetAddress; + this.headers = headers; this.username = username; this.password = password; } @@ -87,6 +94,14 @@ public InetSocketAddress getTargetAddress() { return targetAddress; } + /** + * Returns the custom HTTP headers to be sent during the HTTP CONNECT handshake. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/12479") + public Map getHeaders() { + return headers; + } + @Override public boolean equals(Object o) { if (!(o instanceof HttpConnectProxiedSocketAddress)) { @@ -95,13 +110,14 @@ public boolean equals(Object o) { HttpConnectProxiedSocketAddress that = (HttpConnectProxiedSocketAddress) o; return Objects.equal(proxyAddress, that.proxyAddress) && Objects.equal(targetAddress, that.targetAddress) + && Objects.equal(headers, that.headers) && Objects.equal(username, that.username) && Objects.equal(password, that.password); } @Override public int hashCode() { - return Objects.hashCode(proxyAddress, targetAddress, username, password); + return Objects.hashCode(proxyAddress, targetAddress, username, password, headers); } @Override @@ -109,6 +125,7 @@ public String toString() { return MoreObjects.toStringHelper(this) .add("proxyAddr", proxyAddress) .add("targetAddr", targetAddress) + .add("headers", headers) .add("username", username) // Intentionally mask out password .add("hasPassword", password != null) @@ -129,6 +146,7 @@ public static final class Builder { private SocketAddress proxyAddress; private InetSocketAddress targetAddress; + private Map headers = Collections.emptyMap(); @Nullable private String username; @Nullable @@ -153,6 +171,18 @@ public Builder setTargetAddress(InetSocketAddress targetAddress) { return this; } + /** + * Sets custom HTTP headers to be sent during the HTTP CONNECT handshake. This is an optional + * field. The headers will be sent in addition to any authentication headers (if username and + * password are set). + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/12479") + public Builder setHeaders(Map headers) { + this.headers = Collections.unmodifiableMap( + new HashMap<>(checkNotNull(headers, "headers"))); + return this; + } + /** * Sets the username used to connect to the proxy. This is an optional field and can be {@code * null}. @@ -175,7 +205,8 @@ public Builder setPassword(@Nullable String password) { * Creates an {@code HttpConnectProxiedSocketAddress}. */ public HttpConnectProxiedSocketAddress build() { - return new HttpConnectProxiedSocketAddress(proxyAddress, targetAddress, username, password); + return new HttpConnectProxiedSocketAddress( + proxyAddress, targetAddress, headers, username, password); } } } diff --git a/api/src/main/java/io/grpc/IgnoreJRERequirement.java b/api/src/main/java/io/grpc/IgnoreJRERequirement.java new file mode 100644 index 00000000000..2db406c5953 --- /dev/null +++ b/api/src/main/java/io/grpc/IgnoreJRERequirement.java @@ -0,0 +1,30 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Target; + +/** + * Disables Animal Sniffer's signature checking. This is our own package-private version to avoid + * dependening on animalsniffer-annotations. + * + *

FIELD is purposefully not supported, as Android wouldn't be able to ignore a field. Instead, + * the entire class would need to be avoided on Android. + */ +@Target({ElementType.METHOD, ElementType.CONSTRUCTOR, ElementType.TYPE}) +@interface IgnoreJRERequirement {} diff --git a/api/src/main/java/io/grpc/InternalConfigSelector.java b/api/src/main/java/io/grpc/InternalConfigSelector.java index 38856f440b4..a63009361d4 100644 --- a/api/src/main/java/io/grpc/InternalConfigSelector.java +++ b/api/src/main/java/io/grpc/InternalConfigSelector.java @@ -35,7 +35,7 @@ public abstract class InternalConfigSelector { = Attributes.Key.create("internal:io.grpc.config-selector"); // Use PickSubchannelArgs for SelectConfigArgs for now. May change over time. - /** Selects the config for an PRC. */ + /** Selects the config for an RPC. */ public abstract Result selectConfig(LoadBalancer.PickSubchannelArgs args); public static final class Result { diff --git a/api/src/main/java/io/grpc/InternalConfigurator.java b/api/src/main/java/io/grpc/InternalConfigurator.java new file mode 100644 index 00000000000..7091767a265 --- /dev/null +++ b/api/src/main/java/io/grpc/InternalConfigurator.java @@ -0,0 +1,23 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +/** + * Internal access to Configurator API. + */ +@Internal +public interface InternalConfigurator extends Configurator {} diff --git a/api/src/main/java/io/grpc/InternalConfiguratorRegistry.java b/api/src/main/java/io/grpc/InternalConfiguratorRegistry.java new file mode 100644 index 00000000000..f567dab74c4 --- /dev/null +++ b/api/src/main/java/io/grpc/InternalConfiguratorRegistry.java @@ -0,0 +1,55 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import java.util.List; + +/** + * Access internal global configurators. + */ +@Internal +public final class InternalConfiguratorRegistry { + private InternalConfiguratorRegistry() {} + + public static void setConfigurators(List configurators) { + ConfiguratorRegistry.getDefaultRegistry().setConfigurators(configurators); + } + + public static List getConfigurators() { + return ConfiguratorRegistry.getDefaultRegistry().getConfigurators(); + } + + public static void configureChannelBuilder(ManagedChannelBuilder channelBuilder) { + for (Configurator configurator : ConfiguratorRegistry.getDefaultRegistry().getConfigurators()) { + configurator.configureChannelBuilder(channelBuilder); + } + } + + public static void configureServerBuilder(ServerBuilder serverBuilder) { + for (Configurator configurator : ConfiguratorRegistry.getDefaultRegistry().getConfigurators()) { + configurator.configureServerBuilder(serverBuilder); + } + } + + public static boolean wasSetConfiguratorsCalled() { + return ConfiguratorRegistry.getDefaultRegistry().wasSetConfiguratorsCalled(); + } + + public static int getConfiguratorsCallCountBeforeSet() { + return ConfiguratorRegistry.getDefaultRegistry().getConfiguratorsCallCountBeforeSet(); + } +} diff --git a/api/src/main/java/io/grpc/InternalEquivalentAddressGroup.java b/api/src/main/java/io/grpc/InternalEquivalentAddressGroup.java new file mode 100644 index 00000000000..cd171208af7 --- /dev/null +++ b/api/src/main/java/io/grpc/InternalEquivalentAddressGroup.java @@ -0,0 +1,35 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +@Internal +public final class InternalEquivalentAddressGroup { + private InternalEquivalentAddressGroup() {} + + /** + * Endpoint weight for load balancing purposes. While the type is Long, it must be a valid uint32. + * Must not be zero. The weight is proportional to the other endpoints; if an endpoint's weight is + * twice that of another endpoint, it is intended to receive twice the load. + */ + public static final Attributes.Key ATTR_WEIGHT = EquivalentAddressGroup.ATTR_WEIGHT; + + /** + * The backend service associated with this EquivalentAddressGroup. + */ + public static final Attributes.Key ATTR_BACKEND_SERVICE = + EquivalentAddressGroup.ATTR_BACKEND_SERVICE; +} diff --git a/api/src/main/java/io/grpc/InternalFeatureFlags.java b/api/src/main/java/io/grpc/InternalFeatureFlags.java new file mode 100644 index 00000000000..a1e771a7571 --- /dev/null +++ b/api/src/main/java/io/grpc/InternalFeatureFlags.java @@ -0,0 +1,41 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import com.google.common.annotations.VisibleForTesting; + +/** Global variables that govern major changes to the behavior of more than one grpc module. */ +@Internal +public class InternalFeatureFlags { + + /** Whether to parse targets as RFC 3986 URIs (true), or use {@link java.net.URI} (false). */ + @VisibleForTesting + public static boolean setRfc3986UrisEnabled(boolean value) { + return FeatureFlags.setRfc3986UrisEnabled(value); + } + + /** Whether to parse targets as RFC 3986 URIs (true), or use {@link java.net.URI} (false). */ + public static boolean getRfc3986UrisEnabled() { + return FeatureFlags.getRfc3986UrisEnabled(); + } + + public static boolean getFlag(String envVarName, boolean enableByDefault) { + return FeatureFlags.getFlag(envVarName, enableByDefault); + } + + private InternalFeatureFlags() {} +} diff --git a/api/src/main/java/io/grpc/InternalGlobalInterceptors.java b/api/src/main/java/io/grpc/InternalGlobalInterceptors.java deleted file mode 100644 index db0ff6e2ce9..00000000000 --- a/api/src/main/java/io/grpc/InternalGlobalInterceptors.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright 2022 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc; - -import java.util.List; - -/** Accessor to internal methods of {@link GlobalInterceptors}. */ -@Internal -public final class InternalGlobalInterceptors { - - public static void setInterceptorsTracers( - List clientInterceptorList, - List serverInterceptorList, - List serverStreamTracerFactoryList) { - GlobalInterceptors.setInterceptorsTracers( - clientInterceptorList, serverInterceptorList, serverStreamTracerFactoryList); - } - - public static List getClientInterceptors() { - return GlobalInterceptors.getClientInterceptors(); - } - - public static List getServerInterceptors() { - return GlobalInterceptors.getServerInterceptors(); - } - - public static List getServerStreamTracerFactories() { - return GlobalInterceptors.getServerStreamTracerFactories(); - } - - private InternalGlobalInterceptors() {} -} diff --git a/api/src/main/java/io/grpc/InternalManagedChannelBuilder.java b/api/src/main/java/io/grpc/InternalManagedChannelBuilder.java new file mode 100644 index 00000000000..083cad40098 --- /dev/null +++ b/api/src/main/java/io/grpc/InternalManagedChannelBuilder.java @@ -0,0 +1,37 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +/** + * Internal accessors for {@link ManagedChannelBuilder}. + */ +@Internal +public final class InternalManagedChannelBuilder { + private InternalManagedChannelBuilder() {} + + public static > T interceptWithTarget( + ManagedChannelBuilder builder, InternalInterceptorFactory factory) { + return builder.interceptWithTarget(factory); + } + + public static > T addMetricSink( + ManagedChannelBuilder builder, MetricSink metricSink) { + return builder.addMetricSink(metricSink); + } + + public interface InternalInterceptorFactory extends ManagedChannelBuilder.InterceptorFactory {} +} diff --git a/api/src/main/java/io/grpc/InternalMethodDescriptor.java b/api/src/main/java/io/grpc/InternalMethodDescriptor.java index 23bb039e0f1..345f6065813 100644 --- a/api/src/main/java/io/grpc/InternalMethodDescriptor.java +++ b/api/src/main/java/io/grpc/InternalMethodDescriptor.java @@ -30,10 +30,12 @@ public InternalMethodDescriptor(InternalKnownTransport transport) { this.transport = checkNotNull(transport, "transport"); } + @SuppressWarnings("EnumOrdinal") public Object geRawMethodName(MethodDescriptor descriptor) { return descriptor.getRawMethodName(transport.ordinal()); } + @SuppressWarnings("EnumOrdinal") public void setRawMethodName(MethodDescriptor descriptor, Object o) { descriptor.setRawMethodName(transport.ordinal(), o); } diff --git a/api/src/main/java/io/grpc/InternalServiceProviders.java b/api/src/main/java/io/grpc/InternalServiceProviders.java index c72e01db67a..debc786a82a 100644 --- a/api/src/main/java/io/grpc/InternalServiceProviders.java +++ b/api/src/main/java/io/grpc/InternalServiceProviders.java @@ -17,7 +17,9 @@ package io.grpc; import com.google.common.annotations.VisibleForTesting; +import java.util.Iterator; import java.util.List; +import java.util.ServiceLoader; @Internal public final class InternalServiceProviders { @@ -27,12 +29,17 @@ private InternalServiceProviders() { /** * Accessor for method. */ - public static T load( + @Deprecated + public static List loadAll( Class klass, - Iterable> hardcoded, + Iterable> hardCodedClasses, ClassLoader classLoader, PriorityAccessor priorityAccessor) { - return ServiceProviders.load(klass, hardcoded, classLoader, priorityAccessor); + return loadAll( + klass, + ServiceLoader.load(klass, classLoader).iterator(), + () -> hardCodedClasses, + priorityAccessor); } /** @@ -40,10 +47,10 @@ public static T load( */ public static List loadAll( Class klass, - Iterable> hardCodedClasses, - ClassLoader classLoader, + Iterator serviceLoader, + Supplier>> hardCodedClasses, PriorityAccessor priorityAccessor) { - return ServiceProviders.loadAll(klass, hardCodedClasses, classLoader, priorityAccessor); + return ServiceProviders.loadAll(klass, serviceLoader, hardCodedClasses::get, priorityAccessor); } /** @@ -71,4 +78,8 @@ public static boolean isAndroid(ClassLoader cl) { } public interface PriorityAccessor extends ServiceProviders.PriorityAccessor {} + + public interface Supplier { + T get(); + } } diff --git a/api/src/main/java/io/grpc/InternalStatus.java b/api/src/main/java/io/grpc/InternalStatus.java index b6549bb435f..56df1decf38 100644 --- a/api/src/main/java/io/grpc/InternalStatus.java +++ b/api/src/main/java/io/grpc/InternalStatus.java @@ -38,12 +38,11 @@ private InternalStatus() {} public static final Metadata.Key CODE_KEY = Status.CODE_KEY; /** - * Create a new {@link StatusRuntimeException} with the internal option of skipping the filling - * of the stack trace. + * Create a new {@link StatusRuntimeException} skipping the filling of the stack trace. */ @Internal - public static final StatusRuntimeException asRuntimeException(Status status, - @Nullable Metadata trailers, boolean fillInStackTrace) { - return new StatusRuntimeException(status, trailers, fillInStackTrace); + public static StatusRuntimeException asRuntimeExceptionWithoutStacktrace(Status status, + @Nullable Metadata trailers) { + return new InternalStatusRuntimeException(status, trailers); } } diff --git a/api/src/main/java/io/grpc/InternalStatusRuntimeException.java b/api/src/main/java/io/grpc/InternalStatusRuntimeException.java new file mode 100644 index 00000000000..6090b701f0b --- /dev/null +++ b/api/src/main/java/io/grpc/InternalStatusRuntimeException.java @@ -0,0 +1,39 @@ +/* + * Copyright 2015 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import javax.annotation.Nullable; + +/** + * StatusRuntimeException without stack trace, implemented as a subclass, as the + * {@code String, Throwable, boolean, boolean} constructor is not available in the supported + * version of Android. + * + * @see StatusRuntimeException + */ +class InternalStatusRuntimeException extends StatusRuntimeException { + private static final long serialVersionUID = 0; + + public InternalStatusRuntimeException(Status status, @Nullable Metadata trailers) { + super(status, trailers); + } + + @Override + public synchronized Throwable fillInStackTrace() { + return this; + } +} diff --git a/api/src/main/java/io/grpc/InternalSubchannelAddressAttributes.java b/api/src/main/java/io/grpc/InternalSubchannelAddressAttributes.java new file mode 100644 index 00000000000..cfc2f7c5137 --- /dev/null +++ b/api/src/main/java/io/grpc/InternalSubchannelAddressAttributes.java @@ -0,0 +1,31 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +/** + * An internal class. Do not use. + * + *

An interface to provide the attributes for address connected by subchannel. + */ +@Internal +public interface InternalSubchannelAddressAttributes { + + /** + * Return attributes of the server address connected by sub channel. + */ + public Attributes getConnectedAddressAttributes(); +} diff --git a/api/src/main/java/io/grpc/InternalTcpMetrics.java b/api/src/main/java/io/grpc/InternalTcpMetrics.java new file mode 100644 index 00000000000..3dd89b6f76c --- /dev/null +++ b/api/src/main/java/io/grpc/InternalTcpMetrics.java @@ -0,0 +1,98 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** + * TCP Metrics defined to be shared across transport implementations. + * These metrics and their definitions are specified in + * gRFC + * A80. + */ +@Internal +public final class InternalTcpMetrics { + + private InternalTcpMetrics() { + } + + private static final List OPTIONAL_LABELS = Arrays.asList( + "network.local.address", + "network.local.port", + "network.peer.address", + "network.peer.port"); + + public static final DoubleHistogramMetricInstrument MIN_RTT_INSTRUMENT = + MetricInstrumentRegistry.getDefaultRegistry() + .registerDoubleHistogram( + "grpc.tcp.min_rtt", + "Minimum round-trip time of a TCP connection", + "s", + Collections.emptyList(), + Collections.emptyList(), + OPTIONAL_LABELS, + false); + + public static final LongCounterMetricInstrument CONNECTIONS_CREATED_INSTRUMENT = + MetricInstrumentRegistry + .getDefaultRegistry() + .registerLongCounter( + "grpc.tcp.connections_created", + "The total number of TCP connections established.", + "{connection}", + Collections.emptyList(), + OPTIONAL_LABELS, + false); + + public static final LongUpDownCounterMetricInstrument CONNECTION_COUNT_INSTRUMENT = + MetricInstrumentRegistry + .getDefaultRegistry() + .registerLongUpDownCounter( + "grpc.tcp.connection_count", + "The current number of active TCP connections.", + "{connection}", + Collections.emptyList(), + OPTIONAL_LABELS, + false); + + public static final LongCounterMetricInstrument PACKETS_RETRANSMITTED_INSTRUMENT = + MetricInstrumentRegistry + .getDefaultRegistry() + .registerLongCounter( + "grpc.tcp.packets_retransmitted", + "The total number of packets retransmitted for all TCP connections.", + "{packet}", + Collections.emptyList(), + OPTIONAL_LABELS, + false); + + public static final LongCounterMetricInstrument RECURRING_RETRANSMITS_INSTRUMENT = + MetricInstrumentRegistry + .getDefaultRegistry() + .registerLongCounter( + "grpc.tcp.recurring_retransmits", + "The total number of times the retransmit timer " + + "popped for all TCP connections.", + "{timeout}", + Collections.emptyList(), + OPTIONAL_LABELS, + false); + +} diff --git a/api/src/main/java/io/grpc/InternalTimeUtils.java b/api/src/main/java/io/grpc/InternalTimeUtils.java new file mode 100644 index 00000000000..ef8022f53c5 --- /dev/null +++ b/api/src/main/java/io/grpc/InternalTimeUtils.java @@ -0,0 +1,26 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import java.time.Duration; + +@Internal +public final class InternalTimeUtils { + public static long convert(Duration duration) { + return TimeUtils.convertToNanos(duration); + } +} diff --git a/api/src/main/java/io/grpc/LoadBalancer.java b/api/src/main/java/io/grpc/LoadBalancer.java index 84f108c4198..ae83af2804c 100644 --- a/api/src/main/java/io/grpc/LoadBalancer.java +++ b/api/src/main/java/io/grpc/LoadBalancer.java @@ -32,7 +32,6 @@ import javax.annotation.Nullable; import javax.annotation.concurrent.Immutable; import javax.annotation.concurrent.NotThreadSafe; -import javax.annotation.concurrent.ThreadSafe; /** * A pluggable component that receives resolved addresses from {@link NameResolver} and provides the @@ -64,7 +63,7 @@ * allows implementations to schedule tasks to be run in the same Synchronization Context, with or * without a delay, thus those tasks don't need to worry about synchronizing with the balancer * methods. - * + * *

However, the actual running thread may be the network thread, thus the following rules must be * followed to prevent blocking or even dead-locking in a network: * @@ -121,6 +120,12 @@ public abstract class LoadBalancer { HEALTH_CONSUMER_LISTENER_ARG_KEY = LoadBalancer.CreateSubchannelArgs.Key.create("internal:health-check-consumer-listener"); + @Internal + public static final LoadBalancer.CreateSubchannelArgs.Key + DISABLE_SUBCHANNEL_RECONNECT_KEY = + LoadBalancer.CreateSubchannelArgs.Key.createWithDefault( + "internal:disable-subchannel-reconnect", Boolean.FALSE); + @Internal public static final Attributes.Key HAS_HEALTH_PRODUCER_LISTENER_KEY = @@ -150,15 +155,16 @@ public String toString() { private int recursionCount; /** - * Handles newly resolved server groups and metadata attributes from name resolution system. - * {@code servers} contained in {@link EquivalentAddressGroup} should be considered equivalent - * but may be flattened into a single list if needed. - * - *

Implementations should not modify the given {@code servers}. + * Handles newly resolved addresses and metadata attributes from name resolution system. + * Addresses in {@link EquivalentAddressGroup} should be considered equivalent but may be + * flattened into a single list if needed. * * @param resolvedAddresses the resolved server addresses, attributes, and config. * @since 1.21.0 + * + * @deprecated Use instead {@link #acceptResolvedAddresses(ResolvedAddresses)} */ + @Deprecated public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { if (recursionCount++ == 0) { // Note that the information about the addresses actually being accepted will be lost @@ -173,12 +179,10 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { * EquivalentAddressGroup} addresses should be considered equivalent but may be flattened into a * single list if needed. * - *

Implementations can choose to reject the given addresses by returning {@code false}. - * - *

Implementations should not modify the given {@code addresses}. + * @param resolvedAddresses the resolved server addresses, attributes, and config + * @return {@code Status.OK} if the resolved addresses were accepted, otherwise an error to report + * to the name resolver * - * @param resolvedAddresses the resolved server addresses, attributes, and config. - * @return {@code true} if the resolved addresses were accepted. {@code false} if rejected. * @since 1.49.0 */ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { @@ -206,7 +210,7 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { * * @since 1.21.0 */ - @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1771") + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/11657") public static final class ResolvedAddresses { private final List addresses; @NameResolver.ResolutionResultAttr @@ -412,7 +416,16 @@ public void handleSubchannelState( * *

This method should always return a constant value. It's not specified when this will be * called. + * + *

Note that this method is only called when implementing {@code handleResolvedAddresses()} + * instead of {@code acceptResolvedAddresses()}. + * + * @deprecated Instead of overwriting this and {@code handleResolvedAddresses()}, only + * overwrite {@code acceptResolvedAddresses()} which indicates if the addresses provided + * by the name resolver are acceptable with the {@code boolean} return value. */ + @Deprecated + @SuppressWarnings("InlineMeSuggester") public boolean canHandleEmptyAddressListFromNameResolution() { return false; } @@ -436,7 +449,6 @@ public void requestConnection() {} * * @since 1.2.0 */ - @ThreadSafe @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1771") public abstract static class SubchannelPicker { /** @@ -446,18 +458,6 @@ public abstract static class SubchannelPicker { * @since 1.3.0 */ public abstract PickResult pickSubchannel(PickSubchannelArgs args); - - /** - * Tries to establish connections now so that the upcoming RPC may then just pick a ready - * connection without having to connect first. - * - *

No-op if unsupported. - * - * @deprecated override {@link LoadBalancer#requestConnection} instead. - * @since 1.11.0 - */ - @Deprecated - public void requestConnection() {} } /** @@ -490,6 +490,29 @@ public abstract static class PickSubchannelArgs { * @since 1.2.0 */ public abstract MethodDescriptor getMethodDescriptor(); + + /** + * Gets an object that can be informed about what sort of pick was made. + */ + @Internal + public PickDetailsConsumer getPickDetailsConsumer() { + return new PickDetailsConsumer() {}; + } + } + + /** Receives information about the pick being chosen. */ + @Internal + public interface PickDetailsConsumer { + /** + * Optional labels that provide context of how the pick was routed. Particularly helpful for + * per-RPC metrics. + * + * @throws NullPointerException if key or value is {@code null} + */ + default void addOptionalLabel(String key, String value) { + checkNotNull(key, "key"); + checkNotNull(value, "value"); + } } /** @@ -523,6 +546,7 @@ public static final class PickResult { private final Status status; // True if the result is created by withDrop() private final boolean drop; + @Nullable private final String authorityOverride; private PickResult( @Nullable Subchannel subchannel, @Nullable ClientStreamTracer.Factory streamTracerFactory, @@ -531,6 +555,17 @@ private PickResult( this.streamTracerFactory = streamTracerFactory; this.status = checkNotNull(status, "status"); this.drop = drop; + this.authorityOverride = null; + } + + private PickResult( + @Nullable Subchannel subchannel, @Nullable ClientStreamTracer.Factory streamTracerFactory, + Status status, boolean drop, @Nullable String authorityOverride) { + this.subchannel = subchannel; + this.streamTracerFactory = streamTracerFactory; + this.status = checkNotNull(status, "status"); + this.drop = drop; + this.authorityOverride = authorityOverride; } /** @@ -603,6 +638,8 @@ private PickResult( * stream is created at all in some cases. * @since 1.3.0 */ + // TODO(shivaspeaks): Need to deprecate old APIs and create new ones, + // per https://github.com/grpc/grpc-java/issues/12662. public static PickResult withSubchannel( Subchannel subchannel, @Nullable ClientStreamTracer.Factory streamTracerFactory) { return new PickResult( @@ -610,6 +647,19 @@ public static PickResult withSubchannel( false); } + /** + * Same as {@code withSubchannel(subchannel, streamTracerFactory)} but with an authority name + * to override in the host header. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/11656") + public static PickResult withSubchannel( + Subchannel subchannel, @Nullable ClientStreamTracer.Factory streamTracerFactory, + @Nullable String authorityOverride) { + return new PickResult( + checkNotNull(subchannel, "subchannel"), streamTracerFactory, Status.OK, + false, authorityOverride); + } + /** * Equivalent to {@code withSubchannel(subchannel, null)}. * @@ -619,6 +669,28 @@ public static PickResult withSubchannel(Subchannel subchannel) { return withSubchannel(subchannel, null); } + /** + * Creates a new {@code PickResult} with the given {@code subchannel}, + * but retains all other properties from this {@code PickResult}. + * + * @since 1.80.0 + */ + public PickResult copyWithSubchannel(Subchannel subchannel) { + return new PickResult(checkNotNull(subchannel, "subchannel"), streamTracerFactory, + status, drop, authorityOverride); + } + + /** + * Creates a new {@code PickResult} with the given {@code streamTracerFactory}, + * but retains all other properties from this {@code PickResult}. + * + * @since 1.80.0 + */ + public PickResult copyWithStreamTracerFactory( + @Nullable ClientStreamTracer.Factory streamTracerFactory) { + return new PickResult(subchannel, streamTracerFactory, status, drop, authorityOverride); + } + /** * A decision to report a connectivity error to the RPC. If the RPC is {@link * CallOptions#withWaitForReady wait-for-ready}, it will stay buffered. Otherwise, it will fail @@ -653,6 +725,13 @@ public static PickResult withNoResult() { return NO_RESULT; } + /** Returns the authority override if any. */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/11656") + @Nullable + public String getAuthorityOverride() { + return authorityOverride; + } + /** * The Subchannel if this result was created by {@link #withSubchannel withSubchannel()}, or * null otherwise. @@ -693,6 +772,13 @@ public boolean isDrop() { return drop; } + /** + * Returns {@code true} if the pick was not created with {@link #withNoResult()}. + */ + public boolean hasResult() { + return !(subchannel == null && status.isOk()); + } + @Override public String toString() { return MoreObjects.toStringHelper(this) @@ -700,6 +786,7 @@ public String toString() { .add("streamTracerFactory", streamTracerFactory) .add("status", status) .add("drop", drop) + .add("authority-override", authorityOverride) .toString(); } @@ -798,9 +885,11 @@ public String toString() { @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1771") public static final class Builder { + private static final Object[][] EMPTY_CUSTOM_OPTIONS = new Object[0][2]; + private List addrs; private Attributes attrs = Attributes.EMPTY; - private Object[][] customOptions = new Object[0][2]; + private Object[][] customOptions = EMPTY_CUSTOM_OPTIONS; Builder() { } @@ -939,9 +1028,10 @@ public String toString() { /** * Provides essentials for LoadBalancer implementations. * + *

This class is thread-safe. + * * @since 1.2.0 */ - @ThreadSafe @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1771") public abstract static class Helper { /** @@ -964,8 +1054,8 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { } /** - * Out-of-band channel for LoadBalancer’s own RPC needs, e.g., talking to an external - * load-balancer service. + * Create an out-of-band channel for the LoadBalancer’s own RPC needs, e.g., talking to an + * external load-balancer service. * *

The LoadBalancer is responsible for closing unused OOB channels, and closing all OOB * channels within {@link #shutdown}. @@ -975,7 +1065,12 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { public abstract ManagedChannel createOobChannel(EquivalentAddressGroup eag, String authority); /** - * Accept a list of EAG for multiple authorities: https://github.com/grpc/grpc-java/issues/4618 + * Create an out-of-band channel for the LoadBalancer's own RPC needs, e.g., talking to an + * external load-balancer service. This version of the method allows multiple EAGs, so different + * addresses can have different authorities. + * + *

The LoadBalancer is responsible for closing unused OOB channels, and closing all OOB + * channels within {@link #shutdown}. * */ public ManagedChannel createOobChannel(List eag, String authority) { @@ -1127,6 +1222,10 @@ public void ignoreRefreshNameResolutionCheck() { * Returns a {@link SynchronizationContext} that runs tasks in the same Synchronization Context * as that the callback methods on the {@link LoadBalancer} interface are run in. * + *

Work added to the synchronization context might not run immediately, so LB implementations + * must be careful to ensure that any assumptions still hold when it is executed. In particular, + * the LB might have been shut down or subchannels might have changed state. + * *

Pro-tip: in order to call {@link SynchronizationContext#schedule}, you need to provide a * {@link ScheduledExecutorService}. {@link #getScheduledExecutorService} is provided for your * convenience. @@ -1162,6 +1261,13 @@ public ScheduledExecutorService getScheduledExecutorService() { */ public abstract String getAuthority(); + /** + * Returns the target string of the channel, guaranteed to include its scheme. + */ + public String getChannelTarget() { + throw new UnsupportedOperationException(); + } + /** * Returns the ChannelCredentials used to construct the channel, without bearer tokens. * @@ -1212,10 +1318,20 @@ public NameResolver.Args getNameResolverArgs() { public NameResolverRegistry getNameResolverRegistry() { throw new UnsupportedOperationException(); } + + /** + * Returns the {@link MetricRecorder} that the channel uses to record metrics. + * + * @since 1.64.0 + */ + @Internal + public MetricRecorder getMetricRecorder() { + return new MetricRecorder() {}; + } } /** - * A logical connection to a server, or a group of equivalent servers represented by an {@link + * A logical connection to a server, or a group of equivalent servers represented by an {@link * EquivalentAddressGroup}. * *

It maintains at most one physical connection (aka transport) for sending new RPCs, while @@ -1381,6 +1497,18 @@ public void updateAddresses(List addrs) { public Object getInternalSubchannel() { throw new UnsupportedOperationException(); } + + /** + * (Internal use only) returns attributes of the address subchannel is connected to. + * + *

Warning: this is INTERNAL API, is not supposed to be used by external users, and may + * change without notice. If you think you must use it, please file an issue and we can consider + * removing its "internal" status. + */ + @Internal + public Attributes getConnectedAddressAttributes() { + throw new UnsupportedOperationException(); + } } /** @@ -1422,9 +1550,10 @@ public interface SubchannelStateListener { /** * Factory to create {@link LoadBalancer} instance. * + *

This class is thread-safe. + * * @since 1.2.0 */ - @ThreadSafe @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1771") public abstract static class Factory { /** @@ -1479,5 +1608,19 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { public String toString() { return "FixedResultPicker(" + result + ")"; } + + @Override + public int hashCode() { + return result.hashCode(); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof FixedResultPicker)) { + return false; + } + FixedResultPicker that = (FixedResultPicker) o; + return this.result.equals(that.result); + } } } diff --git a/api/src/main/java/io/grpc/LoadBalancerProvider.java b/api/src/main/java/io/grpc/LoadBalancerProvider.java index bb4c574211e..7dc30d6baaf 100644 --- a/api/src/main/java/io/grpc/LoadBalancerProvider.java +++ b/api/src/main/java/io/grpc/LoadBalancerProvider.java @@ -81,7 +81,7 @@ public abstract class LoadBalancerProvider extends LoadBalancer.Factory { * @return a tuple of the fully parsed and validated balancer configuration, else the Status. * @since 1.20.0 * @see - * A24-lb-policy-config.md + * A24-lb-policy-config.md */ public ConfigOrError parseLoadBalancingPolicyConfig(Map rawLoadBalancingPolicyConfig) { return UNKNOWN_CONFIG; diff --git a/api/src/main/java/io/grpc/LoadBalancerRegistry.java b/api/src/main/java/io/grpc/LoadBalancerRegistry.java index f6b69f978b8..a8fbc102f5f 100644 --- a/api/src/main/java/io/grpc/LoadBalancerRegistry.java +++ b/api/src/main/java/io/grpc/LoadBalancerRegistry.java @@ -26,6 +26,7 @@ import java.util.LinkedHashSet; import java.util.List; import java.util.Map; +import java.util.ServiceLoader; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; @@ -42,7 +43,6 @@ public final class LoadBalancerRegistry { private static final Logger logger = Logger.getLogger(LoadBalancerRegistry.class.getName()); private static LoadBalancerRegistry instance; - private static final Iterable> HARDCODED_CLASSES = getHardCodedClasses(); private final LinkedHashSet allProviders = new LinkedHashSet<>(); @@ -101,8 +101,10 @@ public static synchronized LoadBalancerRegistry getDefaultRegistry() { if (instance == null) { List providerList = ServiceProviders.loadAll( LoadBalancerProvider.class, - HARDCODED_CLASSES, - LoadBalancerProvider.class.getClassLoader(), + ServiceLoader + .load(LoadBalancerProvider.class, LoadBalancerProvider.class.getClassLoader()) + .iterator(), + LoadBalancerRegistry::getHardCodedClasses, new LoadBalancerPriorityAccessor()); instance = new LoadBalancerRegistry(); for (LoadBalancerProvider provider : providerList) { diff --git a/api/src/main/java/io/grpc/LongCounterMetricInstrument.java b/api/src/main/java/io/grpc/LongCounterMetricInstrument.java new file mode 100644 index 00000000000..73516dfb9e4 --- /dev/null +++ b/api/src/main/java/io/grpc/LongCounterMetricInstrument.java @@ -0,0 +1,30 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import java.util.List; + +/** + * Represents a long-valued counter metric instrument. + */ +@Internal +public final class LongCounterMetricInstrument extends PartialMetricInstrument { + public LongCounterMetricInstrument(int index, String name, String description, String unit, + List requiredLabelKeys, List optionalLabelKeys, boolean enableByDefault) { + super(index, name, description, unit, requiredLabelKeys, optionalLabelKeys, enableByDefault); + } +} diff --git a/api/src/main/java/io/grpc/LongGaugeMetricInstrument.java b/api/src/main/java/io/grpc/LongGaugeMetricInstrument.java new file mode 100644 index 00000000000..393bdeb355c --- /dev/null +++ b/api/src/main/java/io/grpc/LongGaugeMetricInstrument.java @@ -0,0 +1,31 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import java.util.List; + +/** + * Represents a long-valued gauge metric instrument. + */ +@Internal +public final class LongGaugeMetricInstrument extends PartialMetricInstrument + implements CallbackMetricInstrument { + public LongGaugeMetricInstrument(int index, String name, String description, String unit, + List requiredLabelKeys, List optionalLabelKeys, boolean enableByDefault) { + super(index, name, description, unit, requiredLabelKeys, optionalLabelKeys, enableByDefault); + } +} diff --git a/api/src/main/java/io/grpc/LongHistogramMetricInstrument.java b/api/src/main/java/io/grpc/LongHistogramMetricInstrument.java new file mode 100644 index 00000000000..2a4e56ffd5a --- /dev/null +++ b/api/src/main/java/io/grpc/LongHistogramMetricInstrument.java @@ -0,0 +1,38 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import java.util.List; + +/** + * Represents a long-valued histogram metric instrument. + */ +@Internal +public final class LongHistogramMetricInstrument extends PartialMetricInstrument { + private final List bucketBoundaries; + + public LongHistogramMetricInstrument(int index, String name, String description, String unit, + List bucketBoundaries, List requiredLabelKeys, List optionalLabelKeys, + boolean enableByDefault) { + super(index, name, description, unit, requiredLabelKeys, optionalLabelKeys, enableByDefault); + this.bucketBoundaries = bucketBoundaries; + } + + public List getBucketBoundaries() { + return bucketBoundaries; + } +} diff --git a/api/src/main/java/io/grpc/LongUpDownCounterMetricInstrument.java b/api/src/main/java/io/grpc/LongUpDownCounterMetricInstrument.java new file mode 100644 index 00000000000..07e099cde5d --- /dev/null +++ b/api/src/main/java/io/grpc/LongUpDownCounterMetricInstrument.java @@ -0,0 +1,32 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import java.util.List; + +/** + * Represents a long-valued up down counter metric instrument. + */ +@Internal +public final class LongUpDownCounterMetricInstrument extends PartialMetricInstrument { + public LongUpDownCounterMetricInstrument(int index, String name, String description, String unit, + List requiredLabelKeys, + List optionalLabelKeys, + boolean enableByDefault) { + super(index, name, description, unit, requiredLabelKeys, optionalLabelKeys, enableByDefault); + } +} \ No newline at end of file diff --git a/api/src/main/java/io/grpc/ManagedChannel.java b/api/src/main/java/io/grpc/ManagedChannel.java index 7875fdb57f2..2b1d89946bf 100644 --- a/api/src/main/java/io/grpc/ManagedChannel.java +++ b/api/src/main/java/io/grpc/ManagedChannel.java @@ -17,12 +17,12 @@ package io.grpc; import java.util.concurrent.TimeUnit; -import javax.annotation.concurrent.ThreadSafe; /** * A {@link Channel} that provides lifecycle management. + * + *

This class is thread-safe. */ -@ThreadSafe public abstract class ManagedChannel extends Channel { /** * Initiates an orderly shutdown in which preexisting calls continue but new calls are immediately diff --git a/api/src/main/java/io/grpc/ManagedChannelBuilder.java b/api/src/main/java/io/grpc/ManagedChannelBuilder.java index 7fe183f2049..3f370ab3003 100644 --- a/api/src/main/java/io/grpc/ManagedChannelBuilder.java +++ b/api/src/main/java/io/grpc/ManagedChannelBuilder.java @@ -159,6 +159,21 @@ public T offloadExecutor(Executor executor) { */ public abstract T intercept(ClientInterceptor... interceptors); + /** + * Internal-only: Adds a factory that will construct an interceptor based on the channel's target. + * This can be used to work around nameResolverFactory() changing the target string. + */ + @Internal + protected T interceptWithTarget(InterceptorFactory factory) { + throw new UnsupportedOperationException(); + } + + /** Internal-only. */ + @Internal + protected interface InterceptorFactory { + ClientInterceptor newInterceptor(String target); + } + /** * Adds a {@link ClientTransportFilter}. The order of filters being added is the order they will * be executed @@ -359,9 +374,17 @@ public T maxInboundMetadataSize(int bytes) { * notice when they are causing excessive load. Clients are strongly encouraged to use only as * small of a value as necessary. * + *

When the channel implementation supports TCP_USER_TIMEOUT, enabling keepalive will also + * enable TCP_USER_TIMEOUT for the connection. This requires all sent packets to receive + * a TCP acknowledgement before the keepalive timeout. The keepalive time is not used for + * TCP_USER_TIMEOUT, except as a signal to enable the feature. grpc-netty supports + * TCP_USER_TIMEOUT on Linux platforms supported by netty-transport-native-epoll. + * * @throws UnsupportedOperationException if unsupported * @see gRFC A8 * Client-side Keepalive + * @see gRFC A18 + * TCP User Timeout * @since 1.7.0 */ public T keepAliveTime(long keepAliveTime, TimeUnit timeUnit) { @@ -378,6 +401,8 @@ public T keepAliveTime(long keepAliveTime, TimeUnit timeUnit) { * @throws UnsupportedOperationException if unsupported * @see gRFC A8 * Client-side Keepalive + * @see gRFC A18 + * TCP User Timeout * @since 1.7.0 */ public T keepAliveTimeout(long keepAliveTimeout, TimeUnit timeUnit) { @@ -607,6 +632,35 @@ public T disableServiceConfigLookUp() { throw new UnsupportedOperationException(); } + /** + * Adds a {@link MetricSink} for channel to use for configuring and recording metrics. + * + * @return this + * @since 1.64.0 + */ + @Internal + protected T addMetricSink(MetricSink metricSink) { + throw new UnsupportedOperationException(); + } + + /** + * Provides a "custom" argument for the {@link NameResolver}, if applicable, replacing any 'value' + * previously provided for 'key'. + * + *

NB: If the selected {@link NameResolver} does not understand 'key', or target URI resolution + * isn't needed at all, your custom argument will be silently ignored. + * + *

See {@link NameResolver.Args#getArg(NameResolver.Args.Key)} for more. + * + * @param key identifies the argument in a type-safe manner + * @param value the argument itself + * @return this + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1770") + public T setNameResolverArg(NameResolver.Args.Key key, X value) { + throw new UnsupportedOperationException(); + } + /** * Builds a channel using the given parameters. * diff --git a/api/src/main/java/io/grpc/ManagedChannelRegistry.java b/api/src/main/java/io/grpc/ManagedChannelRegistry.java index 31f874b8094..ec47b325ffc 100644 --- a/api/src/main/java/io/grpc/ManagedChannelRegistry.java +++ b/api/src/main/java/io/grpc/ManagedChannelRegistry.java @@ -18,6 +18,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.net.SocketAddress; import java.net.URI; import java.net.URISyntaxException; @@ -28,9 +29,9 @@ import java.util.Comparator; import java.util.LinkedHashSet; import java.util.List; +import java.util.ServiceLoader; import java.util.logging.Level; import java.util.logging.Logger; -import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; /** @@ -100,8 +101,10 @@ public static synchronized ManagedChannelRegistry getDefaultRegistry() { if (instance == null) { List providerList = ServiceProviders.loadAll( ManagedChannelProvider.class, - getHardCodedClasses(), - ManagedChannelProvider.class.getClassLoader(), + ServiceLoader + .load(ManagedChannelProvider.class, ManagedChannelProvider.class.getClassLoader()) + .iterator(), + ManagedChannelRegistry::getHardCodedClasses, new ManagedChannelPriorityAccessor()); instance = new ManagedChannelRegistry(); for (ManagedChannelProvider provider : providerList) { @@ -160,8 +163,11 @@ ManagedChannelBuilder newChannelBuilder(NameResolverRegistry nameResolverRegi String target, ChannelCredentials creds) { NameResolverProvider nameResolverProvider = null; try { - URI uri = new URI(target); - nameResolverProvider = nameResolverRegistry.getProviderForScheme(uri.getScheme()); + String scheme = + FeatureFlags.getRfc3986UrisEnabled() + ? Uri.parse(target).getScheme() + : new URI(target).getScheme(); + nameResolverProvider = nameResolverRegistry.getProviderForScheme(scheme); } catch (URISyntaxException ignore) { // bad URI found, just ignore and continue } diff --git a/api/src/main/java/io/grpc/Metadata.java b/api/src/main/java/io/grpc/Metadata.java index 58fcefe1373..8a958d127df 100644 --- a/api/src/main/java/io/grpc/Metadata.java +++ b/api/src/main/java/io/grpc/Metadata.java @@ -16,12 +16,14 @@ package io.grpc; -import static com.google.common.base.Charsets.US_ASCII; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; +import static java.nio.charset.StandardCharsets.US_ASCII; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; import com.google.common.io.BaseEncoding; import com.google.common.io.ByteStreams; import java.io.ByteArrayInputStream; @@ -32,8 +34,6 @@ import java.util.Arrays; import java.util.BitSet; import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Locale; @@ -325,7 +325,7 @@ public Set keys() { if (isEmpty()) { return Collections.emptySet(); } - Set ks = new HashSet<>(size); + Set ks = Sets.newHashSetWithExpectedSize(size); for (int i = 0; i < size; i++) { ks.add(new String(name(i), 0 /* hibyte */)); } @@ -526,7 +526,7 @@ public void merge(Metadata other) { public void merge(Metadata other, Set> keys) { Preconditions.checkNotNull(other, "other"); // Use ByteBuffer for equals and hashCode. - Map> asciiKeys = new HashMap<>(keys.size()); + Map> asciiKeys = Maps.newHashMapWithExpectedSize(keys.size()); for (Key key : keys) { asciiKeys.put(ByteBuffer.wrap(key.asciiName()), key); } diff --git a/api/src/main/java/io/grpc/MethodDescriptor.java b/api/src/main/java/io/grpc/MethodDescriptor.java index 1bfaccb4201..a02eb840deb 100644 --- a/api/src/main/java/io/grpc/MethodDescriptor.java +++ b/api/src/main/java/io/grpc/MethodDescriptor.java @@ -20,9 +20,9 @@ import com.google.common.base.MoreObjects; import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.CheckReturnValue; import java.io.InputStream; import java.util.concurrent.atomic.AtomicReferenceArray; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; import javax.annotation.concurrent.Immutable; diff --git a/api/src/main/java/io/grpc/MetricInstrument.java b/api/src/main/java/io/grpc/MetricInstrument.java new file mode 100644 index 00000000000..1930319060d --- /dev/null +++ b/api/src/main/java/io/grpc/MetricInstrument.java @@ -0,0 +1,75 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import java.util.List; + +/** + * Represents a metric instrument. Metric instrument contains information used to describe a metric. + */ +@Internal +public interface MetricInstrument { + /** + * Returns the unique index of this metric instrument. + * + * @return the index of the metric instrument. + */ + public int getIndex(); + + /** + * Returns the name of the metric. + * + * @return the name of the metric. + */ + public String getName(); + + /** + * Returns a description of the metric. + * + * @return a description of the metric. + */ + public String getDescription(); + + /** + * Returns the unit of measurement for the metric. + * + * @return the unit of measurement. + */ + public String getUnit(); + + /** + * Returns a list of required label keys for this metric instrument. + * + * @return a list of required label keys. + */ + public List getRequiredLabelKeys(); + + /** + * Returns a list of optional label keys for this metric instrument. + * + * @return a list of optional label keys. + */ + public List getOptionalLabelKeys(); + + /** + * Indicates whether this metric instrument is enabled by default. + * + * @return {@code true} if this metric instrument is enabled by default, + * {@code false} otherwise. + */ + public boolean isEnableByDefault(); +} diff --git a/api/src/main/java/io/grpc/MetricInstrumentRegistry.java b/api/src/main/java/io/grpc/MetricInstrumentRegistry.java new file mode 100644 index 00000000000..ce0f8f1b5cb --- /dev/null +++ b/api/src/main/java/io/grpc/MetricInstrumentRegistry.java @@ -0,0 +1,317 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Strings; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * A registry for globally registered metric instruments. + */ +@Internal +public final class MetricInstrumentRegistry { + static final int INITIAL_INSTRUMENT_CAPACITY = 5; + private static MetricInstrumentRegistry instance; + private final Object lock = new Object(); + @GuardedBy("lock") + private final Set registeredMetricNames = new HashSet<>(); + @GuardedBy("lock") + private MetricInstrument[] metricInstruments = + new MetricInstrument[INITIAL_INSTRUMENT_CAPACITY]; + @GuardedBy("lock") + private int nextAvailableMetricIndex; + + @VisibleForTesting + MetricInstrumentRegistry() {} + + /** + * Returns the default metric instrument registry. + */ + public static synchronized MetricInstrumentRegistry getDefaultRegistry() { + if (instance == null) { + instance = new MetricInstrumentRegistry(); + } + return instance; + } + + /** + * Returns a list of registered metric instruments. + */ + public List getMetricInstruments() { + synchronized (lock) { + return Collections.unmodifiableList( + Arrays.asList(Arrays.copyOfRange(metricInstruments, 0, nextAvailableMetricIndex))); + } + } + + /** + * Registers a new Double Counter metric instrument. + * + * @param name the name of the metric + * @param description a description of the metric + * @param unit the unit of measurement for the metric + * @param requiredLabelKeys a list of required label keys + * @param optionalLabelKeys a list of optional label keys + * @param enableByDefault whether the metric should be enabled by default + * @return the newly created DoubleCounterMetricInstrument + * @throws IllegalStateException if a metric with the same name already exists + */ + public DoubleCounterMetricInstrument registerDoubleCounter(String name, + String description, String unit, List requiredLabelKeys, + List optionalLabelKeys, boolean enableByDefault) { + checkArgument(!Strings.isNullOrEmpty(name), "missing metric name"); + checkNotNull(description, "description"); + checkNotNull(unit, "unit"); + checkNotNull(requiredLabelKeys, "requiredLabelKeys"); + checkNotNull(optionalLabelKeys, "optionalLabelKeys"); + synchronized (lock) { + if (registeredMetricNames.contains(name)) { + throw new IllegalStateException("Metric with name " + name + " already exists"); + } + int index = nextAvailableMetricIndex; + if (index + 1 == metricInstruments.length) { + resizeMetricInstruments(); + } + // TODO(dnvindhya): add limit for number of optional labels allowed + DoubleCounterMetricInstrument instrument = new DoubleCounterMetricInstrument( + index, name, description, unit, requiredLabelKeys, optionalLabelKeys, + enableByDefault); + metricInstruments[index] = instrument; + registeredMetricNames.add(name); + nextAvailableMetricIndex += 1; + return instrument; + } + } + + /** + * Registers a new Long Counter metric instrument. + * + * @param name the name of the metric + * @param description a description of the metric + * @param unit the unit of measurement for the metric + * @param requiredLabelKeys a list of required label keys + * @param optionalLabelKeys a list of optional label keys + * @param enableByDefault whether the metric should be enabled by default + * @return the newly created LongCounterMetricInstrument + * @throws IllegalStateException if a metric with the same name already exists + */ + public LongCounterMetricInstrument registerLongCounter(String name, + String description, String unit, List requiredLabelKeys, + List optionalLabelKeys, boolean enableByDefault) { + checkArgument(!Strings.isNullOrEmpty(name), "missing metric name"); + checkNotNull(description, "description"); + checkNotNull(unit, "unit"); + checkNotNull(requiredLabelKeys, "requiredLabelKeys"); + checkNotNull(optionalLabelKeys, "optionalLabelKeys"); + synchronized (lock) { + if (registeredMetricNames.contains(name)) { + throw new IllegalStateException("Metric with name " + name + " already exists"); + } + int index = nextAvailableMetricIndex; + if (index + 1 == metricInstruments.length) { + resizeMetricInstruments(); + } + LongCounterMetricInstrument instrument = new LongCounterMetricInstrument( + index, name, description, unit, requiredLabelKeys, optionalLabelKeys, + enableByDefault); + metricInstruments[index] = instrument; + registeredMetricNames.add(name); + nextAvailableMetricIndex += 1; + return instrument; + } + } + + /** + * Registers a new Long Up Down Counter metric instrument. + * + * @param name the name of the metric + * @param description a description of the metric + * @param unit the unit of measurement for the metric + * @param requiredLabelKeys a list of required label keys + * @param optionalLabelKeys a list of optional label keys + * @param enableByDefault whether the metric should be enabled by default + * @return the newly created LongUpDownCounterMetricInstrument + * @throws IllegalStateException if a metric with the same name already exists + */ + public LongUpDownCounterMetricInstrument registerLongUpDownCounter(String name, + String description, + String unit, + List requiredLabelKeys, + List optionalLabelKeys, + boolean enableByDefault) { + checkArgument(!Strings.isNullOrEmpty(name), "missing metric name"); + checkNotNull(description, "description"); + checkNotNull(unit, "unit"); + checkNotNull(requiredLabelKeys, "requiredLabelKeys"); + checkNotNull(optionalLabelKeys, "optionalLabelKeys"); + synchronized (lock) { + if (registeredMetricNames.contains(name)) { + throw new IllegalStateException("Metric with name " + name + " already exists"); + } + int index = nextAvailableMetricIndex; + if (index + 1 == metricInstruments.length) { + resizeMetricInstruments(); + } + LongUpDownCounterMetricInstrument instrument = new LongUpDownCounterMetricInstrument( + index, name, description, unit, requiredLabelKeys, optionalLabelKeys, + enableByDefault); + metricInstruments[index] = instrument; + registeredMetricNames.add(name); + nextAvailableMetricIndex += 1; + return instrument; + } + } + + /** + * Registers a new Double Histogram metric instrument. + * + * @param name the name of the metric + * @param description a description of the metric + * @param unit the unit of measurement for the metric + * @param bucketBoundaries recommended set of explicit bucket boundaries for the histogram + * @param requiredLabelKeys a list of required label keys + * @param optionalLabelKeys a list of optional label keys + * @param enableByDefault whether the metric should be enabled by default + * @return the newly created DoubleHistogramMetricInstrument + * @throws IllegalStateException if a metric with the same name already exists + */ + public DoubleHistogramMetricInstrument registerDoubleHistogram(String name, + String description, String unit, List bucketBoundaries, + List requiredLabelKeys, List optionalLabelKeys, boolean enableByDefault) { + checkArgument(!Strings.isNullOrEmpty(name), "missing metric name"); + checkNotNull(description, "description"); + checkNotNull(unit, "unit"); + checkNotNull(bucketBoundaries, "bucketBoundaries"); + checkNotNull(requiredLabelKeys, "requiredLabelKeys"); + checkNotNull(optionalLabelKeys, "optionalLabelKeys"); + synchronized (lock) { + if (registeredMetricNames.contains(name)) { + throw new IllegalStateException("Metric with name " + name + " already exists"); + } + int index = nextAvailableMetricIndex; + if (index + 1 == metricInstruments.length) { + resizeMetricInstruments(); + } + DoubleHistogramMetricInstrument instrument = new DoubleHistogramMetricInstrument( + index, name, description, unit, bucketBoundaries, requiredLabelKeys, + optionalLabelKeys, + enableByDefault); + metricInstruments[index] = instrument; + registeredMetricNames.add(name); + nextAvailableMetricIndex += 1; + return instrument; + } + } + + /** + * Registers a new Long Histogram metric instrument. + * + * @param name the name of the metric + * @param description a description of the metric + * @param unit the unit of measurement for the metric + * @param bucketBoundaries recommended set of explicit bucket boundaries for the histogram + * @param requiredLabelKeys a list of required label keys + * @param optionalLabelKeys a list of optional label keys + * @param enableByDefault whether the metric should be enabled by default + * @return the newly created LongHistogramMetricInstrument + * @throws IllegalStateException if a metric with the same name already exists + */ + public LongHistogramMetricInstrument registerLongHistogram(String name, + String description, String unit, List bucketBoundaries, List requiredLabelKeys, + List optionalLabelKeys, boolean enableByDefault) { + checkArgument(!Strings.isNullOrEmpty(name), "missing metric name"); + checkNotNull(description, "description"); + checkNotNull(unit, "unit"); + checkNotNull(bucketBoundaries, "bucketBoundaries"); + checkNotNull(requiredLabelKeys, "requiredLabelKeys"); + checkNotNull(optionalLabelKeys, "optionalLabelKeys"); + synchronized (lock) { + if (registeredMetricNames.contains(name)) { + throw new IllegalStateException("Metric with name " + name + " already exists"); + } + int index = nextAvailableMetricIndex; + if (index + 1 == metricInstruments.length) { + resizeMetricInstruments(); + } + LongHistogramMetricInstrument instrument = new LongHistogramMetricInstrument( + index, name, description, unit, bucketBoundaries, requiredLabelKeys, + optionalLabelKeys, + enableByDefault); + metricInstruments[index] = instrument; + registeredMetricNames.add(name); + nextAvailableMetricIndex += 1; + return instrument; + } + } + + + /** + * Registers a new Long Gauge metric instrument. + * + * @param name the name of the metric + * @param description a description of the metric + * @param unit the unit of measurement for the metric + * @param requiredLabelKeys a list of required label keys + * @param optionalLabelKeys a list of optional label keys + * @param enableByDefault whether the metric should be enabled by default + * @return the newly created LongGaugeMetricInstrument + * @throws IllegalStateException if a metric with the same name already exists + */ + public LongGaugeMetricInstrument registerLongGauge(String name, String description, + String unit, List requiredLabelKeys, List optionalLabelKeys, boolean + enableByDefault) { + checkArgument(!Strings.isNullOrEmpty(name), "missing metric name"); + checkNotNull(description, "description"); + checkNotNull(unit, "unit"); + checkNotNull(requiredLabelKeys, "requiredLabelKeys"); + checkNotNull(optionalLabelKeys, "optionalLabelKeys"); + synchronized (lock) { + if (registeredMetricNames.contains(name)) { + throw new IllegalStateException("Metric with name " + name + " already exists"); + } + int index = nextAvailableMetricIndex; + if (index + 1 == metricInstruments.length) { + resizeMetricInstruments(); + } + LongGaugeMetricInstrument instrument = new LongGaugeMetricInstrument( + index, name, description, unit, requiredLabelKeys, optionalLabelKeys, + enableByDefault); + metricInstruments[index] = instrument; + registeredMetricNames.add(name); + nextAvailableMetricIndex += 1; + return instrument; + } + } + + @GuardedBy("lock") + private void resizeMetricInstruments() { + // Increase the capacity of the metricInstruments array by INITIAL_INSTRUMENT_CAPACITY + int newInstrumentsCapacity = metricInstruments.length + INITIAL_INSTRUMENT_CAPACITY; + MetricInstrument[] resizedMetricInstruments = Arrays.copyOf(metricInstruments, + newInstrumentsCapacity); + metricInstruments = resizedMetricInstruments; + } +} diff --git a/api/src/main/java/io/grpc/MetricRecorder.java b/api/src/main/java/io/grpc/MetricRecorder.java new file mode 100644 index 00000000000..897c28011cd --- /dev/null +++ b/api/src/main/java/io/grpc/MetricRecorder.java @@ -0,0 +1,179 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import static com.google.common.base.Preconditions.checkArgument; + +import java.util.List; + +/** + * An interface used for recording gRPC metrics. Implementations of this interface are responsible + * for collecting and potentially reporting metrics from various gRPC components. + */ +@Internal +public interface MetricRecorder { + /** + * Adds a value for a double-precision counter metric instrument. + * + * @param metricInstrument The counter metric instrument to add the value against. + * @param value The value to add. + * @param requiredLabelValues A list of required label values for the metric. + * @param optionalLabelValues A list of additional, optional label values for the metric. + */ + default void addDoubleCounter(DoubleCounterMetricInstrument metricInstrument, double value, + List requiredLabelValues, List optionalLabelValues) { + checkArgument(requiredLabelValues != null + && requiredLabelValues.size() == metricInstrument.getRequiredLabelKeys().size(), + "Incorrect number of required labels provided. Expected: %s", + metricInstrument.getRequiredLabelKeys().size()); + checkArgument(optionalLabelValues != null + && optionalLabelValues.size() == metricInstrument.getOptionalLabelKeys().size(), + "Incorrect number of optional labels provided. Expected: %s", + metricInstrument.getOptionalLabelKeys().size()); + } + + /** + * Adds a value for a long valued counter metric instrument. + * + * @param metricInstrument The counter metric instrument to add the value against. + * @param value The value to add. MUST be non-negative. + * @param requiredLabelValues A list of required label values for the metric. + * @param optionalLabelValues A list of additional, optional label values for the metric. + */ + default void addLongCounter(LongCounterMetricInstrument metricInstrument, long value, + List requiredLabelValues, List optionalLabelValues) { + checkArgument(requiredLabelValues != null + && requiredLabelValues.size() == metricInstrument.getRequiredLabelKeys().size(), + "Incorrect number of required labels provided. Expected: %s", + metricInstrument.getRequiredLabelKeys().size()); + checkArgument(optionalLabelValues != null + && optionalLabelValues.size() == metricInstrument.getOptionalLabelKeys().size(), + "Incorrect number of optional labels provided. Expected: %s", + metricInstrument.getOptionalLabelKeys().size()); + } + + /** + * Adds a value for a long valued up down counter metric instrument. + * + * @param metricInstrument The counter metric instrument to add the value against. + * @param value The value to add. May be positive, negative or zero. + * @param requiredLabelValues A list of required label values for the metric. + * @param optionalLabelValues A list of additional, optional label values for the metric. + */ + default void addLongUpDownCounter(LongUpDownCounterMetricInstrument metricInstrument, + long value, + List requiredLabelValues, + List optionalLabelValues) { + checkArgument(requiredLabelValues != null + && requiredLabelValues.size() == metricInstrument.getRequiredLabelKeys().size(), + "Incorrect number of required labels provided. Expected: %s", + metricInstrument.getRequiredLabelKeys().size()); + checkArgument(optionalLabelValues != null + && optionalLabelValues.size() == metricInstrument.getOptionalLabelKeys().size(), + "Incorrect number of optional labels provided. Expected: %s", + metricInstrument.getOptionalLabelKeys().size()); + } + + + /** + * Records a value for a double-precision histogram metric instrument. + * + * @param metricInstrument The histogram metric instrument to record the value against. + * @param value The value to record. + * @param requiredLabelValues A list of required label values for the metric. + * @param optionalLabelValues A list of additional, optional label values for the metric. + */ + default void recordDoubleHistogram(DoubleHistogramMetricInstrument metricInstrument, double value, + List requiredLabelValues, List optionalLabelValues) { + checkArgument(requiredLabelValues != null + && requiredLabelValues.size() == metricInstrument.getRequiredLabelKeys().size(), + "Incorrect number of required labels provided. Expected: %s", + metricInstrument.getRequiredLabelKeys().size()); + checkArgument(optionalLabelValues != null + && optionalLabelValues.size() == metricInstrument.getOptionalLabelKeys().size(), + "Incorrect number of optional labels provided. Expected: %s", + metricInstrument.getOptionalLabelKeys().size()); + } + + /** + * Records a value for a long valued histogram metric instrument. + * + * @param metricInstrument The histogram metric instrument to record the value against. + * @param value The value to record. + * @param requiredLabelValues A list of required label values for the metric. + * @param optionalLabelValues A list of additional, optional label values for the metric. + */ + default void recordLongHistogram(LongHistogramMetricInstrument metricInstrument, long value, + List requiredLabelValues, List optionalLabelValues) { + checkArgument(requiredLabelValues != null + && requiredLabelValues.size() == metricInstrument.getRequiredLabelKeys().size(), + "Incorrect number of required labels provided. Expected: %s", + metricInstrument.getRequiredLabelKeys().size()); + checkArgument(optionalLabelValues != null + && optionalLabelValues.size() == metricInstrument.getOptionalLabelKeys().size(), + "Incorrect number of optional labels provided. Expected: %s", + metricInstrument.getOptionalLabelKeys().size()); + } + + /** + * Registers a callback to produce metric values for only the listed instruments. The returned + * registration must be closed when no longer needed, which will remove the callback. + * + * @param callback The callback to call to record. + * @param metricInstruments The metric instruments the callback will record against. + */ + default Registration registerBatchCallback(BatchCallback callback, + CallbackMetricInstrument... metricInstruments) { + return () -> { }; + } + + /** Callback to record gauge values. */ + interface BatchCallback { + /** Records instrument values into {@code recorder}. */ + void accept(BatchRecorder recorder); + } + + /** Recorder for instrument values produced by a batch callback. */ + interface BatchRecorder { + /** + * Record a long gauge value. + * + * @param value The value to record. + * @param requiredLabelValues A list of required label values for the metric. + * @param optionalLabelValues A list of additional, optional label values for the metric. + */ + default void recordLongGauge(LongGaugeMetricInstrument metricInstrument, long value, + List requiredLabelValues, List optionalLabelValues) { + checkArgument(requiredLabelValues != null + && requiredLabelValues.size() == metricInstrument.getRequiredLabelKeys().size(), + "Incorrect number of required labels provided. Expected: %s", + metricInstrument.getRequiredLabelKeys().size()); + checkArgument(optionalLabelValues != null + && optionalLabelValues.size() == metricInstrument.getOptionalLabelKeys().size(), + "Incorrect number of optional labels provided. Expected: %s", + metricInstrument.getOptionalLabelKeys().size()); + } + } + + /** A handle to a registration, that allows unregistration. */ + interface Registration extends AutoCloseable { + // Redefined to not throw an exception. + /** Unregister. */ + @Override + void close(); + } +} diff --git a/api/src/main/java/io/grpc/MetricSink.java b/api/src/main/java/io/grpc/MetricSink.java new file mode 100644 index 00000000000..ce5d3822520 --- /dev/null +++ b/api/src/main/java/io/grpc/MetricSink.java @@ -0,0 +1,142 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * An internal interface representing a receiver or aggregator of gRPC metrics data. + */ +@Internal +public interface MetricSink { + + /** + * Returns a set of names for the metrics that are currently enabled or disabled. + * + * @return A set of enabled metric names. + */ + Map getEnabledMetrics(); + + /** + * Returns a set of optional label names for metrics that the sink actually wants. + * + * @return A set of optional label names. + */ + Set getOptionalLabels(); + + /** + * Returns size of metric measures used to record metric values. These measures are created + * based on registered metrics (via MetricInstrumentRegistry) and are ordered according to their + * registration sequence. + * + * @return Size of metric measures. + */ + int getMeasuresSize(); + + /** + * Adds a value for a double-precision counter associated with specified metric instrument. + * + * @param metricInstrument The counter metric instrument identifies metric measure to add. + * @param value The value to record. + * @param requiredLabelValues A list of required label values for the metric. + * @param optionalLabelValues A list of additional, optional label values for the metric. + */ + default void addDoubleCounter(DoubleCounterMetricInstrument metricInstrument, double value, + List requiredLabelValues, List optionalLabelValues) { + } + + /** + * Adds a value for a long valued counter metric associated with specified metric instrument. + * + * @param metricInstrument The counter metric instrument identifies metric measure to add. + * @param value The value to record. MUST be non-negative. + * @param requiredLabelValues A list of required label values for the metric. + * @param optionalLabelValues A list of additional, optional label values for the metric. + */ + default void addLongCounter(LongCounterMetricInstrument metricInstrument, long value, + List requiredLabelValues, List optionalLabelValues) { + } + + /** + * Adds a value for a long valued up down counter metric associated with specified metric + * instrument. + * + * @param metricInstrument The counter metric instrument identifies metric measure to add. + * @param value The value to record. May be positive, negative or zero. + * @param requiredLabelValues A list of required label values for the metric. + * @param optionalLabelValues A list of additional, optional label values for the metric. + */ + default void addLongUpDownCounter(LongUpDownCounterMetricInstrument metricInstrument, long value, + List requiredLabelValues, + List optionalLabelValues) { + } + + /** + * Records a value for a double-precision histogram metric associated with specified metric + * instrument. + * + * @param metricInstrument The histogram metric instrument identifies metric measure to record. + * @param value The value to record. + * @param requiredLabelValues A list of required label values for the metric. + * @param optionalLabelValues A list of additional, optional label values for the metric. + */ + default void recordDoubleHistogram(DoubleHistogramMetricInstrument metricInstrument, double value, + List requiredLabelValues, List optionalLabelValues) { + } + + /** + * Records a value for a long valued histogram metric associated with specified metric + * instrument. + * + * @param metricInstrument The histogram metric instrument identifies metric measure to record. + * @param value The value to record. + * @param requiredLabelValues A list of required label values for the metric. + * @param optionalLabelValues A list of additional, optional label values for the metric. + */ + default void recordLongHistogram(LongHistogramMetricInstrument metricInstrument, long value, + List requiredLabelValues, List optionalLabelValues) { + } + + /** + * Record a long gauge value. + * + * @param value The value to record. + * @param requiredLabelValues A list of required label values for the metric. + * @param optionalLabelValues A list of additional, optional label values for the metric. + */ + default void recordLongGauge(LongGaugeMetricInstrument metricInstrument, long value, + List requiredLabelValues, List optionalLabelValues){ + } + + /** + * Registers a callback to produce metric values for only the listed instruments. The returned + * registration must be closed when no longer needed, which will remove the callback. + * + * @param callback The callback to call to record. + * @param metricInstruments The metric instruments the callback will record against. + */ + default Registration registerBatchCallback(Runnable callback, + CallbackMetricInstrument... metricInstruments) { + return () -> { }; + } + + interface Registration extends MetricRecorder.Registration {} + + void updateMeasures(List instruments); +} diff --git a/api/src/main/java/io/grpc/NameResolver.java b/api/src/main/java/io/grpc/NameResolver.java index a74512eb7e3..80bc338d86b 100644 --- a/api/src/main/java/io/grpc/NameResolver.java +++ b/api/src/main/java/io/grpc/NameResolver.java @@ -20,20 +20,21 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.base.MoreObjects; +import com.google.common.base.MoreObjects.ToStringHelper; import com.google.common.base.Objects; import com.google.errorprone.annotations.InlineMe; import java.lang.annotation.Documented; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.net.URI; -import java.util.ArrayList; import java.util.Collections; +import java.util.IdentityHashMap; import java.util.List; import java.util.Map; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import javax.annotation.Nullable; -import javax.annotation.concurrent.ThreadSafe; +import javax.annotation.concurrent.Immutable; /** * A pluggable component that resolves a target {@link URI} and return addresses to the caller. @@ -76,7 +77,7 @@ public abstract class NameResolver { * Starts the resolution. The method is not supposed to throw any exceptions. That might cause the * Channel that the name resolver is serving to crash. Errors should be propagated * through {@link Listener#onError}. - * + * *

An instance may not be started more than once, by any overload of this method, even after * an intervening call to {@link #shutdown}. * @@ -95,7 +96,14 @@ public void onError(Status error) { @Override public void onResult(ResolutionResult resolutionResult) { - listener.onAddresses(resolutionResult.getAddresses(), resolutionResult.getAttributes()); + StatusOr> addressesOrError = + resolutionResult.getAddressesOrError(); + if (addressesOrError.hasValue()) { + listener.onAddresses(addressesOrError.getValue(), + resolutionResult.getAttributes()); + } else { + listener.onError(addressesOrError.getStatus()); + } } }); } @@ -105,7 +113,7 @@ public void onResult(ResolutionResult resolutionResult) { * Starts the resolution. The method is not supposed to throw any exceptions. That might cause the * Channel that the name resolver is serving to crash. Errors should be propagated * through {@link Listener2#onError}. - * + * *

An instance may not be started more than once, by any overload of this method, even after * an intervening call to {@link #shutdown}. * @@ -149,6 +157,10 @@ public abstract static class Factory { * cannot be resolved by this factory. The decision should be solely based on the scheme of the * URI. * + *

This method will eventually be deprecated and removed as part of a migration from {@code + * java.net.URI} to {@code io.grpc.Uri}. Implementations will override {@link + * #newNameResolver(Uri, Args)} instead. + * * @param targetUri the target URI to be resolved, whose scheme must not be {@code null} * @param args other information that may be useful * @@ -156,6 +168,37 @@ public abstract static class Factory { */ public abstract NameResolver newNameResolver(URI targetUri, final Args args); + /** + * Creates a {@link NameResolver} for the given target URI. + * + *

Implementations return {@code null} if 'targetUri' cannot be resolved by this factory. The + * decision should be solely based on the target's scheme. + * + *

All {@link NameResolver.Factory} implementations should override this method, as it will + * eventually replace {@link #newNameResolver(URI, Args)}. For backwards compatibility, this + * default implementation delegates to {@link #newNameResolver(URI, Args)} if 'targetUri' can be + * converted to a java.net.URI. + * + *

NB: Conversion is not always possible, for example {@code scheme:#frag} is a valid {@link + * Uri} but not a valid {@link URI} because its path is empty. The default implementation throws + * IllegalArgumentException in these cases. + * + * @param targetUri the target URI to be resolved + * @param args other information that may be useful + * @throws IllegalArgumentException if targetUri does not have the expected form + * @since 1.79 + */ + public NameResolver newNameResolver(Uri targetUri, final Args args) { + // Not every io.grpc.Uri can be converted but in the ordinary ManagedChannel creation flow, + // any IllegalArgumentException thrown here would happened anyway, just earlier. That's + // because parse/toString is transparent so java.net.URI#create here sees the original target + // string just like it did before the io.grpc.Uri migration. + // + // Throwing IAE shouldn't surprise non-framework callers either. After all, many existing + // Factory impls are picky about targetUri and throw IAE when it doesn't look how they expect. + return newNameResolver(URI.create(targetUri.toString()), args); + } + /** * Returns the default scheme, which will be used to construct a URI when {@link * ManagedChannelBuilder#forTarget(String)} is given an authority string instead of a compliant @@ -171,10 +214,11 @@ public abstract static class Factory { * *

All methods are expected to return quickly. * + *

This interface is thread-safe. + * * @since 1.0.0 */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1770") - @ThreadSafe public interface Listener { /** * Handles updates on resolved addresses and attributes. @@ -218,19 +262,26 @@ public abstract static class Listener2 implements Listener { @Override @Deprecated @InlineMe( - replacement = "this.onResult(ResolutionResult.newBuilder().setAddresses(servers)" - + ".setAttributes(attributes).build())", - imports = "io.grpc.NameResolver.ResolutionResult") + replacement = "this.onResult(ResolutionResult.newBuilder().setAddressesOrError(" + + "StatusOr.fromValue(servers)).setAttributes(attributes).build())", + imports = {"io.grpc.NameResolver.ResolutionResult", "io.grpc.StatusOr"}) public final void onAddresses( List servers, @ResolutionResultAttr Attributes attributes) { // TODO(jihuncho) need to promote Listener2 if we want to use ConfigOrError + // Calling onResult and not onResult2 because onResult2 can only be called from a + // synchronization context. onResult( - ResolutionResult.newBuilder().setAddresses(servers).setAttributes(attributes).build()); + ResolutionResult.newBuilder().setAddressesOrError( + StatusOr.fromValue(servers)).setAttributes(attributes).build()); } /** * Handles updates on resolved addresses and attributes. If - * {@link ResolutionResult#getAddresses()} is empty, {@link #onError(Status)} will be called. + * {@link ResolutionResult#getAddressesOrError()} is empty, {@link #onError(Status)} will be + * called. + * + *

Newer NameResolver implementations should prefer calling onResult2. This method exists to + * facilitate older {@link Listener} implementations to migrate to {@link Listener2}. * * @param resolutionResult the resolved server addresses, attributes, and Service Config. * @since 1.21.0 @@ -241,11 +292,31 @@ public final void onAddresses( * Handles a name resolving error from the resolver. The listener is responsible for eventually * invoking {@link NameResolver#refresh()} to re-attempt resolution. * + *

New NameResolver implementations should prefer calling onResult2 which will have the + * address resolution error in {@link ResolutionResult}'s addressesOrError. This method exists + * to facilitate older implementations using {@link Listener} to migrate to {@link Listener2}. + * * @param error a non-OK status * @since 1.21.0 */ @Override public abstract void onError(Status error); + + /** + * Handles updates on resolved addresses and attributes. Must be called from the same + * {@link SynchronizationContext} available in {@link NameResolver.Args} that is passed + * from the channel. + * + * @param resolutionResult the resolved server addresses or error in address resolution, + * attributes, and Service Config or error + * @return status indicating whether the resolutionResult was accepted by the listener, + * typically the result from a load balancer. + * @since 1.66 + */ + public Status onResult2(ResolutionResult resolutionResult) { + onResult(resolutionResult); + return Status.OK; + } } /** @@ -257,10 +328,20 @@ public final void onAddresses( @Documented public @interface ResolutionResultAttr {} + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/11989") + @ResolutionResultAttr + public static final Attributes.Key ATTR_BACKEND_SERVICE = + Attributes.Key.create("io.grpc.NameResolver.ATTR_BACKEND_SERVICE"); + /** * Information that a {@link Factory} uses to create a {@link NameResolver}. * - *

Note this class doesn't override neither {@code equals()} nor {@code hashCode()}. + *

Args applicable to all {@link NameResolver}s are defined here using ordinary setters and + * getters. This container can also hold externally-defined "custom" args that aren't so widely + * useful or that would be inappropriate dependencies for this low level API. See {@link + * Args#getArg} for more. + * + *

Note this class overrides neither {@code equals()} nor {@code hashCode()}. * * @since 1.21.0 */ @@ -274,24 +355,24 @@ public static final class Args { @Nullable private final ChannelLogger channelLogger; @Nullable private final Executor executor; @Nullable private final String overrideAuthority; - - private Args( - Integer defaultPort, - ProxyDetector proxyDetector, - SynchronizationContext syncContext, - ServiceConfigParser serviceConfigParser, - @Nullable ScheduledExecutorService scheduledExecutorService, - @Nullable ChannelLogger channelLogger, - @Nullable Executor executor, - @Nullable String overrideAuthority) { - this.defaultPort = checkNotNull(defaultPort, "defaultPort not set"); - this.proxyDetector = checkNotNull(proxyDetector, "proxyDetector not set"); - this.syncContext = checkNotNull(syncContext, "syncContext not set"); - this.serviceConfigParser = checkNotNull(serviceConfigParser, "serviceConfigParser not set"); - this.scheduledExecutorService = scheduledExecutorService; - this.channelLogger = channelLogger; - this.executor = executor; - this.overrideAuthority = overrideAuthority; + private final MetricRecorder metricRecorder; + @Nullable private final NameResolverRegistry nameResolverRegistry; + @Nullable private final IdentityHashMap, Object> customArgs; + + private Args(Builder builder) { + this.defaultPort = checkNotNull(builder.defaultPort, "defaultPort not set"); + this.proxyDetector = checkNotNull(builder.proxyDetector, "proxyDetector not set"); + this.syncContext = checkNotNull(builder.syncContext, "syncContext not set"); + this.serviceConfigParser = + checkNotNull(builder.serviceConfigParser, "serviceConfigParser not set"); + this.scheduledExecutorService = builder.scheduledExecutorService; + this.channelLogger = builder.channelLogger; + this.executor = builder.executor; + this.overrideAuthority = builder.overrideAuthority; + this.metricRecorder = builder.metricRecorder != null ? builder.metricRecorder + : new MetricRecorder() {}; + this.nameResolverRegistry = builder.nameResolverRegistry; + this.customArgs = cloneCustomArgs(builder.customArgs); } /** @@ -300,6 +381,7 @@ private Args( * * @since 1.21.0 */ + //

TODO: Only meaningful for InetSocketAddress producers. Make this a custom arg? public int getDefaultPort() { return defaultPort; } @@ -352,6 +434,30 @@ public ServiceConfigParser getServiceConfigParser() { return serviceConfigParser; } + /** + * Returns the value of a custom arg named 'key', or {@code null} if it's not set. + * + *

While ordinary {@link Args} should be universally useful and meaningful, custom arguments + * can apply just to resolvers of a certain URI scheme, just to resolvers producing a particular + * type of {@link java.net.SocketAddress}, or even an individual {@link NameResolver} subclass. + * Custom args are identified by an instance of {@link Args.Key} which should be a constant + * defined in a java package and class appropriate for the argument's scope. + * + *

{@link Args} are normally reserved for information in *support* of name resolution, not + * the name to be resolved itself. However, there are rare cases where all or part of the target + * name can't be represented by any standard URI scheme or can't be encoded as a String at all. + * Custom args, in contrast, can hold arbitrary Java types, making them a useful work around in + * these cases. + * + *

Custom args can also be used simply to avoid adding inappropriate deps to the low level + * io.grpc package. + */ + @SuppressWarnings("unchecked") // Cast is safe because all put()s go through the setArg() API. + @Nullable + public T getArg(Key key) { + return customArgs != null ? (T) customArgs.get(key) : null; + } + /** * Returns the {@link ChannelLogger} for the Channel served by this NameResolver. * @@ -389,6 +495,25 @@ public String getOverrideAuthority() { return overrideAuthority; } + /** + * Returns the {@link MetricRecorder} that the channel uses to record metrics. + */ + public MetricRecorder getMetricRecorder() { + return metricRecorder; + } + + /** + * Returns the {@link NameResolverRegistry} that the Channel uses to look for {@link + * NameResolver}s. + * + * @since 1.74.0 + */ + public NameResolverRegistry getNameResolverRegistry() { + if (nameResolverRegistry == null) { + throw new IllegalStateException("NameResolverRegistry is not set in Builder"); + } + return nameResolverRegistry; + } @Override public String toString() { @@ -397,10 +522,13 @@ public String toString() { .add("proxyDetector", proxyDetector) .add("syncContext", syncContext) .add("serviceConfigParser", serviceConfigParser) + .add("customArgs", customArgs) .add("scheduledExecutorService", scheduledExecutorService) .add("channelLogger", channelLogger) .add("executor", executor) .add("overrideAuthority", overrideAuthority) + .add("metricRecorder", metricRecorder) + .add("nameResolverRegistry", nameResolverRegistry) .toString(); } @@ -419,6 +547,9 @@ public Builder toBuilder() { builder.setChannelLogger(channelLogger); builder.setOffloadExecutor(executor); builder.setOverrideAuthority(overrideAuthority); + builder.setMetricRecorder(metricRecorder); + builder.setNameResolverRegistry(nameResolverRegistry); + builder.customArgs = cloneCustomArgs(customArgs); return builder; } @@ -445,6 +576,9 @@ public static final class Builder { private ChannelLogger channelLogger; private Executor executor; private String overrideAuthority; + private MetricRecorder metricRecorder; + private NameResolverRegistry nameResolverRegistry; + private IdentityHashMap, Object> customArgs; Builder() { } @@ -531,16 +665,75 @@ public Builder setOverrideAuthority(String authority) { return this; } + /** See {@link Args#getArg(Key)}. */ + public Builder setArg(Key key, T value) { + checkNotNull(key, "key"); + checkNotNull(value, "value"); + if (customArgs == null) { + customArgs = new IdentityHashMap<>(); + } + customArgs.put(key, value); + return this; + } + + /** + * See {@link Args#getMetricRecorder()}. This is an optional field. + */ + public Builder setMetricRecorder(MetricRecorder metricRecorder) { + this.metricRecorder = checkNotNull(metricRecorder, "metricRecorder"); + return this; + } + + /** + * See {@link Args#getNameResolverRegistry}. This is an optional field. + * + * @since 1.74.0 + */ + public Builder setNameResolverRegistry(NameResolverRegistry registry) { + this.nameResolverRegistry = registry; + return this; + } + /** * Builds an {@link Args}. * * @since 1.21.0 */ public Args build() { - return - new Args( - defaultPort, proxyDetector, syncContext, serviceConfigParser, - scheduledExecutorService, channelLogger, executor, overrideAuthority); + return new Args(this); + } + } + + /** + * Identifies an externally-defined custom argument that can be stored in {@link Args}. + * + *

Uses reference equality so keys should be defined as global constants. + * + * @param type of values that can be stored under this key + */ + @Immutable + @SuppressWarnings("UnusedTypeParameter") + public static final class Key { + private final String debugString; + + private Key(String debugString) { + this.debugString = debugString; + } + + @Override + public String toString() { + return debugString; + } + + /** + * Creates a new instance of {@link Key}. + * + * @param debugString a string used to describe the key, used for debugging. + * @param Key type + * @return a new instance of Key + */ + public static Key create(String debugString) { + return new Key<>(debugString); } } } @@ -573,17 +766,17 @@ public abstract static class ServiceConfigParser { */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1770") public static final class ResolutionResult { - private final List addresses; + private final StatusOr> addressesOrError; @ResolutionResultAttr private final Attributes attributes; @Nullable private final ConfigOrError serviceConfig; ResolutionResult( - List addresses, + StatusOr> addressesOrError, @ResolutionResultAttr Attributes attributes, ConfigOrError serviceConfig) { - this.addresses = Collections.unmodifiableList(new ArrayList<>(addresses)); + this.addressesOrError = addressesOrError; this.attributes = checkNotNull(attributes, "attributes"); this.serviceConfig = serviceConfig; } @@ -604,7 +797,7 @@ public static Builder newBuilder() { */ public Builder toBuilder() { return newBuilder() - .setAddresses(addresses) + .setAddressesOrError(addressesOrError) .setAttributes(attributes) .setServiceConfig(serviceConfig); } @@ -613,9 +806,20 @@ public Builder toBuilder() { * Gets the addresses resolved by name resolution. * * @since 1.21.0 + * @deprecated Will be superseded by getAddressesOrError */ + @Deprecated public List getAddresses() { - return addresses; + return addressesOrError.getValue(); + } + + /** + * Gets the addresses resolved by name resolution or the error in doing so. + * + * @since 1.65.0 + */ + public StatusOr> getAddressesOrError() { + return addressesOrError; } /** @@ -641,11 +845,11 @@ public ConfigOrError getServiceConfig() { @Override public String toString() { - return MoreObjects.toStringHelper(this) - .add("addresses", addresses) - .add("attributes", attributes) - .add("serviceConfig", serviceConfig) - .toString(); + ToStringHelper stringHelper = MoreObjects.toStringHelper(this); + stringHelper.add("addressesOrError", addressesOrError.toString()); + stringHelper.add("attributes", attributes); + stringHelper.add("serviceConfigOrError", serviceConfig); + return stringHelper.toString(); } /** @@ -657,7 +861,7 @@ public boolean equals(Object obj) { return false; } ResolutionResult that = (ResolutionResult) obj; - return Objects.equal(this.addresses, that.addresses) + return Objects.equal(this.addressesOrError, that.addressesOrError) && Objects.equal(this.attributes, that.attributes) && Objects.equal(this.serviceConfig, that.serviceConfig); } @@ -667,7 +871,7 @@ public boolean equals(Object obj) { */ @Override public int hashCode() { - return Objects.hashCode(addresses, attributes, serviceConfig); + return Objects.hashCode(addressesOrError, attributes, serviceConfig); } /** @@ -677,7 +881,8 @@ public int hashCode() { */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1770") public static final class Builder { - private List addresses = Collections.emptyList(); + private StatusOr> addresses = + StatusOr.fromValue(Collections.emptyList()); private Attributes attributes = Attributes.EMPTY; @Nullable private ConfigOrError serviceConfig; @@ -689,9 +894,21 @@ public static final class Builder { * Sets the addresses resolved by name resolution. This field is required. * * @since 1.21.0 + * @deprecated Will be superseded by setAddressesOrError */ + @Deprecated public Builder setAddresses(List addresses) { - this.addresses = addresses; + setAddressesOrError(StatusOr.fromValue(addresses)); + return this; + } + + /** + * Sets the addresses resolved by name resolution or the error in doing so. This field is + * required. + * @param addresses Resolved addresses or an error in resolving addresses + */ + public Builder setAddressesOrError(StatusOr> addresses) { + this.addresses = checkNotNull(addresses, "StatusOr addresses cannot be null."); return this; } @@ -814,4 +1031,10 @@ public String toString() { } } } + + @Nullable + private static IdentityHashMap, Object> cloneCustomArgs( + @Nullable IdentityHashMap, Object> customArgs) { + return customArgs != null ? new IdentityHashMap<>(customArgs) : null; + } } diff --git a/api/src/main/java/io/grpc/NameResolverRegistry.java b/api/src/main/java/io/grpc/NameResolverRegistry.java index 23eec23fd6a..c5e9f7467ab 100644 --- a/api/src/main/java/io/grpc/NameResolverRegistry.java +++ b/api/src/main/java/io/grpc/NameResolverRegistry.java @@ -20,6 +20,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.net.URI; import java.util.ArrayList; import java.util.Collections; @@ -28,10 +29,10 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.ServiceLoader; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; /** @@ -125,8 +126,10 @@ public static synchronized NameResolverRegistry getDefaultRegistry() { if (instance == null) { List providerList = ServiceProviders.loadAll( NameResolverProvider.class, - getHardCodedClasses(), - NameResolverProvider.class.getClassLoader(), + ServiceLoader + .load(NameResolverProvider.class, NameResolverProvider.class.getClassLoader()) + .iterator(), + NameResolverRegistry::getHardCodedClasses, new NameResolverPriorityAccessor()); if (providerList.isEmpty()) { logger.warning("No NameResolverProviders found via ServiceLoader, including for DNS. This " @@ -166,6 +169,11 @@ static List> getHardCodedClasses() { } catch (ClassNotFoundException e) { logger.log(Level.FINE, "Unable to find DNS NameResolver", e); } + try { + list.add(Class.forName("io.grpc.binder.internal.IntentNameResolverProvider")); + } catch (ClassNotFoundException e) { + logger.log(Level.FINE, "Unable to find IntentNameResolverProvider", e); + } return Collections.unmodifiableList(list); } @@ -177,6 +185,13 @@ public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) { return provider == null ? null : provider.newNameResolver(targetUri, args); } + @Override + @Nullable + public NameResolver newNameResolver(io.grpc.Uri targetUri, NameResolver.Args args) { + NameResolverProvider provider = getProviderForScheme(targetUri.getScheme()); + return provider == null ? null : provider.newNameResolver(targetUri, args); + } + @Override public String getDefaultScheme() { return NameResolverRegistry.this.getDefaultScheme(); diff --git a/api/src/main/java/io/grpc/PartialForwardingServerCall.java b/api/src/main/java/io/grpc/PartialForwardingServerCall.java index a7da647308b..a313407b23e 100644 --- a/api/src/main/java/io/grpc/PartialForwardingServerCall.java +++ b/api/src/main/java/io/grpc/PartialForwardingServerCall.java @@ -58,6 +58,12 @@ public void setMessageCompression(boolean enabled) { delegate().setMessageCompression(enabled); } + @Override + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/11021") + public void setOnReadyThreshold(int numBytes) { + delegate().setOnReadyThreshold(numBytes); + } + @Override @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1704") public void setCompression(String compressor) { diff --git a/api/src/main/java/io/grpc/PartialMetricInstrument.java b/api/src/main/java/io/grpc/PartialMetricInstrument.java new file mode 100644 index 00000000000..7e032634f96 --- /dev/null +++ b/api/src/main/java/io/grpc/PartialMetricInstrument.java @@ -0,0 +1,97 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import com.google.common.collect.ImmutableList; +import java.util.List; + +/** + * A partial implementation of the {@link MetricInstrument} interface. This class + * provides common fields and functionality for metric instruments. + */ +@Internal +abstract class PartialMetricInstrument implements MetricInstrument { + protected final int index; + protected final String name; + protected final String description; + protected final String unit; + protected final List requiredLabelKeys; + protected final List optionalLabelKeys; + protected final boolean enableByDefault; + + /** + * Constructs a new PartialMetricInstrument with the specified attributes. + * + * @param index the unique index of this metric instrument + * @param name the name of the metric + * @param description a description of the metric + * @param unit the unit of measurement for the metric + * @param requiredLabelKeys a list of required label keys for the metric + * @param optionalLabelKeys a list of optional label keys for the metric + * @param enableByDefault whether the metric should be enabled by default + */ + protected PartialMetricInstrument(int index, String name, String description, String unit, + List requiredLabelKeys, List optionalLabelKeys, boolean enableByDefault) { + this.index = index; + this.name = name; + this.description = description; + this.unit = unit; + this.requiredLabelKeys = ImmutableList.copyOf(requiredLabelKeys); + this.optionalLabelKeys = ImmutableList.copyOf(optionalLabelKeys); + this.enableByDefault = enableByDefault; + } + + @Override + public int getIndex() { + return index; + } + + @Override + public String getName() { + return name; + } + + @Override + public String getDescription() { + return description; + } + + @Override + public String getUnit() { + return unit; + } + + @Override + public List getRequiredLabelKeys() { + return requiredLabelKeys; + } + + @Override + public List getOptionalLabelKeys() { + return optionalLabelKeys; + } + + @Override + public boolean isEnableByDefault() { + return enableByDefault; + } + + @Override + public String toString() { + return getClass().getName() + "(" + getName() + ")"; + } +} diff --git a/api/src/main/java/io/grpc/QueryParams.java b/api/src/main/java/io/grpc/QueryParams.java new file mode 100644 index 00000000000..31bc2e0e6da --- /dev/null +++ b/api/src/main/java/io/grpc/QueryParams.java @@ -0,0 +1,289 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.base.Splitter; +import java.io.UnsupportedEncodingException; +import java.net.URLDecoder; +import java.net.URLEncoder; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import javax.annotation.Nullable; + +/** + * A parser and mutable container class for {@code application/x-www-form-urlencoded}-style URL + * parameters as conceived by + * RFC 1866 Section 8.2.1. + * + *

For example, a URI like {@code "http://who?name=John+Doe&role=admin&role=user&active"} has: + * + *

    + *
  • A key {@code name} with value {@code John Doe} + *
  • A key {@code role} with value {@code admin} + *
  • A second key named {@code role} with value {@code user} + *
  • "Lone" key {@code active} without a value. + *
+ * + *

This class is meant to be used with {@link io.grpc.Uri}. For example: + * + *

{@code
+ * Uri uri = Uri.parse("http://who?name=John+Doe&role=admin&role=user&active");
+ * QueryParams params = QueryParams.fromRawQuery(uri.getRawQuery());
+ * params.asList().removeIf(e -> "role".equals(e.getKey()) && "admin".equals(e.getValue()));
+ *
+ * Uri modifiedUri = uri.toBuilder().setRawQuery(params.toRawQuery()).build();
+ * }
+ * + *

Note that the empty collection is encoded as a null raw query string, which means "absent" to + * {@link io.grpc.Uri.Builder#setRawQuery}. An empty string query component (""), on the other hand, + * is modeled as an instance of QueryParams containing a single lone (empty) key. It must be this + * way if we are to simultaneously 1) support lone keys, 2) have parse/toRawQuery round-trip + * transparency, and 3) never fail to parse a valid RFC 3986 query component. + * + *

This container and its {@link Entry} take the same position as {@link io.grpc.Uri} on + * equality: raw keys and values must match exactly to be equal. Most callers won't care about how + * keys and values are encoded on the wire and will work with the getters for cooked keys and values + * instead. + * + *

Instances are not safe for concurrent access by multiple threads, including by way of the + * {@link #asList()} view method. + */ +@Internal +public final class QueryParams { + + private static final String UTF_8 = "UTF-8"; + private final List entries = new ArrayList<>(); + + /** Creates a new, empty {@code QueryParams} instance. */ + public QueryParams() {} + + /** + * Parses a raw query string into a {@code QueryParams} instance. + * + *

The input is split on {@code '&'} and each parameter is parsed as either a key/value pair + * (if it contains an equals sign) or a "lone" key (if it does not). + * + *

No valid RFC 3986 query component will fail to parse. For example, {@code ===} is parsed as + * a single parameter with "" as the key and "==" as the value. {@code &&&} is parsed as three + * lone keys named "". And so on. If {@code rawQuery} is not a valid RFC 3986 query component, the + * behavior is undefined. But if you are starting with a {@link io.grpc.Uri}, passing the value + * returned by {@link io.grpc.Uri#getRawQuery()} is always well-defined and will never fail. + * + *

Calling {@link #toRawQuery()} on the returned object is guaranteed to return exactly {@code + * rawQuery}. + * + * @param rawQuery the raw query component to parse, or null to return an empty container + * @return a new instance of {@code QueryParams} representing the input + */ + public static QueryParams fromRawQuery(@Nullable String rawQuery) { + QueryParams params = new QueryParams(); + if (rawQuery != null) { + for (String part : Splitter.on('&').split(rawQuery)) { + int equalsIndex = part.indexOf('='); + if (equalsIndex == -1) { + params.entries.add(Entry.forRawLoneKey(part)); + } else { + String rawKey = part.substring(0, equalsIndex); + String rawValue = part.substring(equalsIndex + 1); + params.entries.add(Entry.forRawKeyValue(rawKey, rawValue)); + } + } + } + return params; + } + + /** + * Returns a mutable list view of the query parameters. + * + * @return the mutable list of entries + */ + public List asList() { + return entries; + } + + /** + * Returns the "raw" query string representation of these parameters, suitable for passing to the + * {@link io.grpc.Uri.Builder#setRawQuery} method. + * + * @return the raw query string + */ + @Nullable + public String toRawQuery() { + if (entries.isEmpty()) { + return null; + } + StringBuilder resultBuilder = new StringBuilder(); + boolean first = true; + for (Entry entry : entries) { + if (!first) { + resultBuilder.append('&'); + } + entry.appendToRawQueryStringBuilder(resultBuilder); + first = false; + } + return resultBuilder.toString(); + } + + @Override + public String toString() { + return entries.toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof QueryParams)) { + return false; + } + QueryParams other = (QueryParams) o; + return entries.equals(other.entries); + } + + @Override + public int hashCode() { + return entries.hashCode(); + } + + /** A single query parameter entry. */ + public static final class Entry { + private final String rawKey; + @Nullable private final String rawValue; + private final String key; + @Nullable private final String value; + + private Entry(String rawKey, @Nullable String rawValue, String key, @Nullable String value) { + this.rawKey = checkNotNull(rawKey, "rawKey"); + this.rawValue = rawValue; + this.key = checkNotNull(key, "key"); + this.value = value; + } + + /** + * Returns the key. + * + *

Any characters that needed URL encoding have already been decoded. + */ + public String getKey() { + return key; + } + + /** + * Returns the value, or {@code null} if this is a "lone" key. + * + *

Any characters that needed URL encoding have already been decoded. + */ + @Nullable + public String getValue() { + return value; + } + + /** Returns {@code true} if this entry has a value, {@code false} if it is a "lone" key. */ + public boolean hasValue() { + return value != null; + } + + /** + * Creates a new key/value pair entry. + * + *

Both key and value can contain any character. They will be URL encoded for you if + * necessary. + */ + public static Entry forKeyValue(String key, String value) { + checkNotNull(key, "key"); + checkNotNull(value, "value"); + return new Entry(encode(key), encode(value), key, value); + } + + /** + * Creates a new query parameter with a "lone" key. + * + *

'key' can contain any character. It will be URL encoded for you later, as necessary. + * + * @param key the decoded key, must not be null + * @return a new {@code Entry} + */ + public static Entry forLoneKey(String key) { + checkNotNull(key, "key"); + return new Entry(encode(key), null, key, null); + } + + static Entry forRawKeyValue(String rawKey, String rawValue) { + checkNotNull(rawKey, "rawKey"); + checkNotNull(rawValue, "rawValue"); + return new Entry(rawKey, rawValue, decode(rawKey), decode(rawValue)); + } + + static Entry forRawLoneKey(String rawKey) { + checkNotNull(rawKey, "rawKey"); + return new Entry(rawKey, null, decode(rawKey), null); + } + + void appendToRawQueryStringBuilder(StringBuilder sb) { + sb.append(rawKey); + if (rawValue != null) { + sb.append('=').append(rawValue); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof Entry)) { + return false; + } + Entry entry = (Entry) o; + return Objects.equals(rawKey, entry.rawKey) && Objects.equals(rawValue, entry.rawValue); + } + + @Override + public int hashCode() { + return Objects.hash(rawKey, rawValue); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + appendToRawQueryStringBuilder(sb); + return sb.toString(); + } + } + + private static String decode(String s) { + try { + // TODO: Use URLDecoder.decode(String, Charset) when available + return URLDecoder.decode(s, UTF_8); + } catch (UnsupportedEncodingException impossible) { + throw new AssertionError("UTF-8 is not supported", impossible); + } + } + + private static String encode(String s) { + try { + // TODO: Use URLEncoder.encode(String, Charset) when available + return URLEncoder.encode(s, UTF_8); + } catch (UnsupportedEncodingException impossible) { + throw new AssertionError("UTF-8 is not supported", impossible); + } + } +} diff --git a/api/src/main/java/io/grpc/Server.java b/api/src/main/java/io/grpc/Server.java index 97ea06a81c2..97c4d495b8a 100644 --- a/api/src/main/java/io/grpc/Server.java +++ b/api/src/main/java/io/grpc/Server.java @@ -21,13 +21,13 @@ import java.util.Collections; import java.util.List; import java.util.concurrent.TimeUnit; -import javax.annotation.concurrent.ThreadSafe; /** * Server for listening for and dispatching incoming calls. It is not expected to be implemented by * application code or interceptors. + * + *

This class is thread-safe. */ -@ThreadSafe public abstract class Server { /** diff --git a/api/src/main/java/io/grpc/ServerBuilder.java b/api/src/main/java/io/grpc/ServerBuilder.java index c2ad566f90f..3effe593e57 100644 --- a/api/src/main/java/io/grpc/ServerBuilder.java +++ b/api/src/main/java/io/grpc/ServerBuilder.java @@ -114,13 +114,15 @@ public T callExecutor(ServerCallExecutorSupplier executorSupplier) { public abstract T addService(BindableService bindableService); /** - * Adds a list of service implementations to the handler registry together. + * Adds a list of service implementations to the handler registry together. This exists for + * convenience - equivalent to repeatedly calling addService() with different services. + * If multiple services on the list use the same name, only the last one on the list will + * be added. * * @param services the list of ServerServiceDefinition objects * @return this * @since 1.37.0 */ - @ExperimentalApi("https://github.com/grpc/grpc-java/issues/7925") public final T addServices(List services) { checkNotNull(services, "services"); for (ServerServiceDefinition service : services) { @@ -433,6 +435,17 @@ public T setBinaryLog(BinaryLog binaryLog) { */ public abstract Server build(); + /** + * Adds a metric sink to the server. + * + * @param metricSink the metric sink to add. + * @return this + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/12693") + public T addMetricSink(MetricSink metricSink) { + return thisT(); + } + /** * Returns the correctly typed version of the builder. */ diff --git a/api/src/main/java/io/grpc/ServerCall.java b/api/src/main/java/io/grpc/ServerCall.java index 7408479a230..3db8ac30e83 100644 --- a/api/src/main/java/io/grpc/ServerCall.java +++ b/api/src/main/java/io/grpc/ServerCall.java @@ -16,6 +16,8 @@ package io.grpc; +import static com.google.common.base.Preconditions.checkArgument; + import javax.annotation.Nullable; /** @@ -209,6 +211,19 @@ public void setCompression(String compressor) { // noop } + /** + * A hint to the call that specifies how many bytes must be queued before + * {@link #isReady()} will return false. A call may ignore this property if + * unsupported. This may only be set before any messages are sent. + * + * @param numBytes The number of bytes that must be queued. Must be a + * positive integer. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/11021") + public void setOnReadyThreshold(int numBytes) { + checkArgument(numBytes > 0, "numBytes must be positive: %s", numBytes); + } + /** * Returns the level of security guarantee in communications * diff --git a/api/src/main/java/io/grpc/ServerCallHandler.java b/api/src/main/java/io/grpc/ServerCallHandler.java index fdfa9997957..7d7d8217300 100644 --- a/api/src/main/java/io/grpc/ServerCallHandler.java +++ b/api/src/main/java/io/grpc/ServerCallHandler.java @@ -16,13 +16,12 @@ package io.grpc; -import javax.annotation.concurrent.ThreadSafe; - /** * Interface to initiate processing of incoming remote calls. Advanced applications and generated * code will implement this interface to allows {@link Server}s to invoke service methods. + * + *

This interface is thread-safe. */ -@ThreadSafe public interface ServerCallHandler { /** * Starts asynchronous processing of an incoming call. diff --git a/api/src/main/java/io/grpc/ServerInterceptor.java b/api/src/main/java/io/grpc/ServerInterceptor.java index 272b10636cd..2944cf680fe 100644 --- a/api/src/main/java/io/grpc/ServerInterceptor.java +++ b/api/src/main/java/io/grpc/ServerInterceptor.java @@ -16,10 +16,9 @@ package io.grpc; -import javax.annotation.concurrent.ThreadSafe; /** - * Interface for intercepting incoming calls before that are dispatched by + * Interface for intercepting incoming calls before they are dispatched by * {@link ServerCallHandler}. * *

Implementers use this mechanism to add cross-cutting behavior to server-side calls. Common @@ -34,8 +33,9 @@ * without completing the previous ones first. Refer to the * {@link io.grpc.ServerCall.Listener ServerCall.Listener} docs for more details regarding thread * safety of the returned listener. + * + *

This interface is thread-safe. */ -@ThreadSafe public interface ServerInterceptor { /** * Intercept {@link ServerCall} dispatch by the {@code next} {@link ServerCallHandler}. General diff --git a/api/src/main/java/io/grpc/ServerInterceptors.java b/api/src/main/java/io/grpc/ServerInterceptors.java index 0bc6d07c83c..6626c8e5810 100644 --- a/api/src/main/java/io/grpc/ServerInterceptors.java +++ b/api/src/main/java/io/grpc/ServerInterceptors.java @@ -197,7 +197,7 @@ public static ServerServiceDefinition useMarshalledMessages( * to allow for interceptors to handle messages as multiple different ReqT/RespT types within * the chain if the added cost of serialization is not a concern. * - * @param serviceDef the sevice definition to add request and response marshallers to. + * @param serviceDef the service definition to add request and response marshallers to. * @param requestMarshaller request marshaller * @param responseMarshaller response marshaller * @param the request payload type diff --git a/api/src/main/java/io/grpc/ServerRegistry.java b/api/src/main/java/io/grpc/ServerRegistry.java index a083e45a000..1ec7030b82b 100644 --- a/api/src/main/java/io/grpc/ServerRegistry.java +++ b/api/src/main/java/io/grpc/ServerRegistry.java @@ -18,14 +18,15 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; import java.util.LinkedHashSet; import java.util.List; +import java.util.ServiceLoader; import java.util.logging.Level; import java.util.logging.Logger; -import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; /** @@ -93,8 +94,9 @@ public static synchronized ServerRegistry getDefaultRegistry() { if (instance == null) { List providerList = ServiceProviders.loadAll( ServerProvider.class, - getHardCodedClasses(), - ServerProvider.class.getClassLoader(), + ServiceLoader.load(ServerProvider.class, ServerProvider.class.getClassLoader()) + .iterator(), + ServerRegistry::getHardCodedClasses, new ServerPriorityAccessor()); instance = new ServerRegistry(); for (ServerProvider provider : providerList) { diff --git a/api/src/main/java/io/grpc/ServerStreamTracer.java b/api/src/main/java/io/grpc/ServerStreamTracer.java index d522610ab3a..81691642131 100644 --- a/api/src/main/java/io/grpc/ServerStreamTracer.java +++ b/api/src/main/java/io/grpc/ServerStreamTracer.java @@ -17,13 +17,13 @@ package io.grpc; import javax.annotation.Nullable; -import javax.annotation.concurrent.ThreadSafe; /** * Listens to events on a stream to collect metrics. + * + *

This class is thread-safe. */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/2861") -@ThreadSafe public abstract class ServerStreamTracer extends StreamTracer { /** * Called before the interceptors and the call handlers and make changes to the Context object diff --git a/api/src/main/java/io/grpc/ServiceProviders.java b/api/src/main/java/io/grpc/ServiceProviders.java index ac4b27d8783..861688be9fb 100644 --- a/api/src/main/java/io/grpc/ServiceProviders.java +++ b/api/src/main/java/io/grpc/ServiceProviders.java @@ -17,10 +17,13 @@ package io.grpc; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Supplier; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; +import java.util.Iterator; import java.util.List; +import java.util.ListIterator; import java.util.ServiceConfigurationError; import java.util.ServiceLoader; @@ -29,42 +32,44 @@ private ServiceProviders() { // do not instantiate } - /** - * If this is not Android, returns the highest priority implementation of the class via - * {@link ServiceLoader}. - * If this is Android, returns an instance of the highest priority class in {@code hardcoded}. - */ - public static T load( - Class klass, - Iterable> hardcoded, - ClassLoader cl, - PriorityAccessor priorityAccessor) { - List candidates = loadAll(klass, hardcoded, cl, priorityAccessor); - if (candidates.isEmpty()) { - return null; - } - return candidates.get(0); - } - /** * If this is not Android, returns all available implementations discovered via * {@link ServiceLoader}. * If this is Android, returns all available implementations in {@code hardcoded}. * The list is sorted in descending priority order. + * + *

{@code serviceLoader} should be created with {@code ServiceLoader.load(MyClass.class, + * MyClass.class.getClassLoader()).iterator()} in order to be detected by R8 so that R8 full mode + * will keep the constructors for the provider classes. */ public static List loadAll( Class klass, - Iterable> hardcoded, - ClassLoader cl, + Iterator serviceLoader, + Supplier>> hardcoded, final PriorityAccessor priorityAccessor) { - Iterable candidates; - if (isAndroid(cl)) { - candidates = getCandidatesViaHardCoded(klass, hardcoded); + Iterator candidates; + if (serviceLoader instanceof ListIterator) { + // A rewriting tool has replaced the ServiceLoader with a List of some sort (R8 uses + // ArrayList, AppReduce uses singletonList). We prefer to use such iterators on Android as + // they won't need reflection like the hard-coded list does. In addition, the provider + // instances will have already been created, so it seems we should use them. + // + // R8: https://r8.googlesource.com/r8/+/490bc53d9310d4cc2a5084c05df4aadaec8c885d/src/main/java/com/android/tools/r8/ir/optimize/ServiceLoaderRewriter.java + // AppReduce: service_loader_pass.cc + candidates = serviceLoader; + } else if (isAndroid(klass.getClassLoader())) { + // Avoid getResource() on Android, which must read from a zip which uses a lot of memory + candidates = getCandidatesViaHardCoded(klass, hardcoded.get()).iterator(); + } else if (!serviceLoader.hasNext()) { + // Attempt to load using the context class loader and ServiceLoader. + // This allows frameworks like http://aries.apache.org/modules/spi-fly.html to plug in. + candidates = ServiceLoader.load(klass).iterator(); } else { - candidates = getCandidatesViaServiceLoader(klass, cl); + candidates = serviceLoader; } List list = new ArrayList<>(); - for (T current: candidates) { + while (candidates.hasNext()) { + T current = candidates.next(); if (!priorityAccessor.isAvailable(current)) { continue; } @@ -101,15 +106,14 @@ static boolean isAndroid(ClassLoader cl) { } /** - * Loads service providers for the {@code klass} service using {@link ServiceLoader}. + * For testing only: Loads service providers for the {@code klass} service using {@link + * ServiceLoader}. Does not support spi-fly and related tricks. */ @VisibleForTesting public static Iterable getCandidatesViaServiceLoader(Class klass, ClassLoader cl) { Iterable i = ServiceLoader.load(klass, cl); - // Attempt to load using the context class loader and ServiceLoader. - // This allows frameworks like http://aries.apache.org/modules/spi-fly.html to plug in. if (!i.iterator().hasNext()) { - i = ServiceLoader.load(klass); + return null; } return i; } diff --git a/api/src/main/java/io/grpc/Status.java b/api/src/main/java/io/grpc/Status.java index 8e7f0b835c2..38cd9581f8e 100644 --- a/api/src/main/java/io/grpc/Status.java +++ b/api/src/main/java/io/grpc/Status.java @@ -16,13 +16,14 @@ package io.grpc; -import static com.google.common.base.Charsets.US_ASCII; -import static com.google.common.base.Charsets.UTF_8; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Throwables.getStackTraceAsString; +import static java.nio.charset.StandardCharsets.US_ASCII; +import static java.nio.charset.StandardCharsets.UTF_8; import com.google.common.base.MoreObjects; import com.google.common.base.Objects; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.Metadata.TrustedAsciiMarshaller; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -30,11 +31,9 @@ import java.util.Collections; import java.util.List; import java.util.TreeMap; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; import javax.annotation.concurrent.Immutable; - /** * Defines the status of an operation by providing a standard {@link Code} in conjunction with an * optional descriptive message. Instances of {@code Status} are created by starting with the diff --git a/api/src/main/java/io/grpc/StatusException.java b/api/src/main/java/io/grpc/StatusException.java index e89ac16dc6c..c0a67a375b2 100644 --- a/api/src/main/java/io/grpc/StatusException.java +++ b/api/src/main/java/io/grpc/StatusException.java @@ -25,9 +25,10 @@ */ public class StatusException extends Exception { private static final long serialVersionUID = -660954903976144640L; + @SuppressWarnings("serial") // https://github.com/grpc/grpc-java/issues/1913 private final Status status; + @SuppressWarnings("serial") private final Metadata trailers; - private final boolean fillInStackTrace; /** * Constructs an exception with both a status. See also {@link Status#asException()}. @@ -45,25 +46,9 @@ public StatusException(Status status) { * @since 1.0.0 */ public StatusException(Status status, @Nullable Metadata trailers) { - this(status, trailers, /*fillInStackTrace=*/ true); - } - - StatusException(Status status, @Nullable Metadata trailers, boolean fillInStackTrace) { super(Status.formatThrowableMessage(status), status.getCause()); this.status = status; this.trailers = trailers; - this.fillInStackTrace = fillInStackTrace; - fillInStackTrace(); - } - - @Override - public synchronized Throwable fillInStackTrace() { - // Let's observe final variables in two states! This works because Throwable will invoke this - // method before fillInStackTrace is set, thus doing nothing. After the constructor has set - // fillInStackTrace, this method will properly fill it in. Additionally, sub classes may call - // this normally, because fillInStackTrace will either be set, or this method will be - // overriden. - return fillInStackTrace ? super.fillInStackTrace() : this; } /** @@ -80,6 +65,7 @@ public final Status getStatus() { * * @since 1.0.0 */ + @Nullable public final Metadata getTrailers() { return trailers; } diff --git a/api/src/main/java/io/grpc/StatusOr.java b/api/src/main/java/io/grpc/StatusOr.java new file mode 100644 index 00000000000..b7dd68cfd7b --- /dev/null +++ b/api/src/main/java/io/grpc/StatusOr.java @@ -0,0 +1,111 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.base.MoreObjects; +import com.google.common.base.MoreObjects.ToStringHelper; +import com.google.common.base.Objects; +import javax.annotation.Nullable; + +/** Either a Status or a value. */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/11563") +public class StatusOr { + private StatusOr(Status status, T value) { + this.status = status; + this.value = value; + } + + /** Construct from a value. */ + public static StatusOr fromValue(T value) { + StatusOr result = new StatusOr(null, value); + return result; + } + + /** Construct from a non-Ok status. */ + public static StatusOr fromStatus(Status status) { + StatusOr result = new StatusOr(checkNotNull(status, "status"), null); + checkArgument(!status.isOk(), "cannot use OK status: %s", status); + return result; + } + + /** Returns whether there is a value. */ + public boolean hasValue() { + return status == null; + } + + /** + * Returns the value if set or throws exception if there is no value set. This method is meant + * to be called after checking the return value of hasValue() first. + */ + public T getValue() { + if (status != null) { + throw new IllegalStateException("No value present."); + } + return value; + } + + /** Returns the status. If there is a value (which can be null), returns OK. */ + public Status getStatus() { + return status == null ? Status.OK : status; + } + + /** + * Note that StatusOr containing statuses, the equality comparision is delegated to + * {@link Status#equals} which just does a reference equality check because equality on + * Statuses is not well defined. + * Instead, do comparison based on their Code with {@link Status#getCode}. The description and + * cause of the Status are unlikely to be stable, and additional fields may be added to Status + * in the future. + */ + @Override + public boolean equals(Object other) { + if (!(other instanceof StatusOr)) { + return false; + } + StatusOr otherStatus = (StatusOr) other; + if (hasValue() != otherStatus.hasValue()) { + return false; + } + if (hasValue()) { + return Objects.equal(value, otherStatus.value); + } + return Objects.equal(status, otherStatus.status); + } + + @Override + public int hashCode() { + return Objects.hashCode(status, value); + } + + @Override + public String toString() { + ToStringHelper stringHelper = MoreObjects.toStringHelper(this); + if (status == null) { + stringHelper.add("value", value); + } else { + stringHelper.add("error", status); + } + return stringHelper.toString(); + } + + @Nullable + private final Status status; + private final T value; +} diff --git a/api/src/main/java/io/grpc/StatusRuntimeException.java b/api/src/main/java/io/grpc/StatusRuntimeException.java index 68b816fc7fa..ebcc2f0d671 100644 --- a/api/src/main/java/io/grpc/StatusRuntimeException.java +++ b/api/src/main/java/io/grpc/StatusRuntimeException.java @@ -26,13 +26,13 @@ public class StatusRuntimeException extends RuntimeException { private static final long serialVersionUID = 1950934672280720624L; + @SuppressWarnings("serial") // https://github.com/grpc/grpc-java/issues/1913 private final Status status; + @SuppressWarnings("serial") private final Metadata trailers; - private final boolean fillInStackTrace; - /** - * Constructs the exception with both a status. See also {@link Status#asRuntimeException()}. + * Constructs the exception with a status. See also {@link Status#asRuntimeException()}. * * @since 1.0.0 */ @@ -47,25 +47,9 @@ public StatusRuntimeException(Status status) { * @since 1.0.0 */ public StatusRuntimeException(Status status, @Nullable Metadata trailers) { - this(status, trailers, /*fillInStackTrace=*/ true); - } - - StatusRuntimeException(Status status, @Nullable Metadata trailers, boolean fillInStackTrace) { super(Status.formatThrowableMessage(status), status.getCause()); this.status = status; this.trailers = trailers; - this.fillInStackTrace = fillInStackTrace; - fillInStackTrace(); - } - - @Override - public synchronized Throwable fillInStackTrace() { - // Let's observe final variables in two states! This works because Throwable will invoke this - // method before fillInStackTrace is set, thus doing nothing. After the constructor has set - // fillInStackTrace, this method will properly fill it in. Additionally, sub classes may call - // this normally, because fillInStackTrace will either be set, or this method will be - // overriden. - return fillInStackTrace ? super.fillInStackTrace() : this; } /** diff --git a/api/src/main/java/io/grpc/StreamTracer.java b/api/src/main/java/io/grpc/StreamTracer.java index 66b3de8be6b..251e6e2b49f 100644 --- a/api/src/main/java/io/grpc/StreamTracer.java +++ b/api/src/main/java/io/grpc/StreamTracer.java @@ -16,15 +16,14 @@ package io.grpc; -import javax.annotation.concurrent.ThreadSafe; - /** * Listens to events on a stream to collect metrics. * *

DO NOT MOCK: Use TestStreamTracer. Mocks are not thread-safe + * + *

This class is thread-safe. */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/2861") -@ThreadSafe public abstract class StreamTracer { /** * Stream is closed. This will be called exactly once. diff --git a/api/src/main/java/io/grpc/SynchronizationContext.java b/api/src/main/java/io/grpc/SynchronizationContext.java index 910219b4523..94916a1b473 100644 --- a/api/src/main/java/io/grpc/SynchronizationContext.java +++ b/api/src/main/java/io/grpc/SynchronizationContext.java @@ -18,8 +18,10 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import static io.grpc.TimeUtils.convertToNanos; import java.lang.Thread.UncaughtExceptionHandler; +import java.time.Duration; import java.util.Queue; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.Executor; @@ -162,8 +164,14 @@ public String toString() { return new ScheduledHandle(runnable, future); } + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/11657") + public final ScheduledHandle schedule( + final Runnable task, Duration delay, ScheduledExecutorService timerService) { + return schedule(task, convertToNanos(delay), TimeUnit.NANOSECONDS, timerService); + } + /** - * Schedules a task to be added and run via {@link #execute} after an inital delay and then + * Schedules a task to be added and run via {@link #execute} after an initial delay and then * repeated after the delay until cancelled. * * @param task the task being scheduled @@ -193,6 +201,14 @@ public String toString() { return new ScheduledHandle(runnable, future); } + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/11657") + public final ScheduledHandle scheduleWithFixedDelay( + final Runnable task, Duration initialDelay, Duration delay, + ScheduledExecutorService timerService) { + return scheduleWithFixedDelay(task, convertToNanos(initialDelay), convertToNanos(delay), + TimeUnit.NANOSECONDS, timerService); + } + private static class ManagedRunnable implements Runnable { final Runnable task; @@ -246,4 +262,4 @@ public boolean isPending() { return !(runnable.hasStarted || runnable.isCancelled); } } -} +} \ No newline at end of file diff --git a/api/src/main/java/io/grpc/TimeUtils.java b/api/src/main/java/io/grpc/TimeUtils.java new file mode 100644 index 00000000000..01b8c158822 --- /dev/null +++ b/api/src/main/java/io/grpc/TimeUtils.java @@ -0,0 +1,32 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import java.time.Duration; + +final class TimeUtils { + private TimeUtils() {} + + @IgnoreJRERequirement + static long convertToNanos(Duration duration) { + try { + return duration.toNanos(); + } catch (ArithmeticException tooBig) { + return duration.isNegative() ? Long.MIN_VALUE : Long.MAX_VALUE; + } + } +} diff --git a/api/src/main/java/io/grpc/Uri.java b/api/src/main/java/io/grpc/Uri.java new file mode 100644 index 00000000000..a88bb6138d8 --- /dev/null +++ b/api/src/main/java/io/grpc/Uri.java @@ -0,0 +1,1184 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import com.google.common.base.VerifyException; +import com.google.common.collect.ImmutableList; +import com.google.common.net.InetAddresses; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import java.net.InetAddress; +import java.net.URISyntaxException; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.CharacterCodingException; +import java.nio.charset.CharsetEncoder; +import java.nio.charset.CodingErrorAction; +import java.nio.charset.MalformedInputException; +import java.nio.charset.StandardCharsets; +import java.util.BitSet; +import java.util.List; +import java.util.Locale; +import java.util.Objects; +import javax.annotation.Nullable; + +/** + * A not-quite-general-purpose representation of a Uniform Resource Identifier (URI), as defined by + * RFC 3986. + * + *

The URI

+ * + *

A URI identifies a resource by its name or location or both. The resource could be a file, + * service, or some other abstract entity. + * + *

Examples

+ * + *
    + *
  • http://admin@example.com:8080/controlpanel?filter=users#settings + *
  • ftp://[2001:db8::7]/docs/report.pdf + *
  • file:///My%20Computer/Documents/letter.doc + *
  • dns://8.8.8.8/storage.googleapis.com + *
  • mailto:John.Doe@example.com + *
  • tel:+1-206-555-1212 + *
  • urn:isbn:978-1492082798 + *
+ * + *

Limitations

+ * + *

This class aims to meet the needs of grpc-java itself and RPC related code that depend on it. + * It isn't quite general-purpose. It definitely would not be suitable for building an HTTP user + * agent or proxy server. In particular, it: + * + *

    + *
  • Can only represent a URI, not a "URI-reference" or "relative reference". In other words, a + * "scheme" is always required. + *
  • Has no knowledge of the particulars of any scheme, with respect to normalization and + * comparison. We don't know https://google.com is the same as + * https://google.com:443, that file:/// is the same as + * file://localhost, or that joe@example.com is the same as + * joe@EXAMPLE.COM. No one class can or should know everything about every scheme so + * all this is better handled at a higher layer. + *
  • Implements {@link #equals(Object)} as a char-by-char comparison. Expect false negatives. + *
  • Does not support "IPvFuture" literal addresses. + *
  • Does not reflect how web browsers parse user input or the URL Living Standard. + *
  • Does not support different character encodings. Assumes UTF-8 in several places. + *
+ * + *

Migrating from RFC 2396 and {@link java.net.URI}

+ * + *

Those migrating from {@link java.net.URI} and/or its primary specification in RFC 2396 should + * note some differences. + * + *

Uniform Hierarchical Syntax

+ * + *

RFC 3986 unifies the older ideas of "hierarchical" and "opaque" URIs into a single generic + * syntax. What RFC 2396 called an opaque "scheme-specific part" is always broken out by RFC 3986 + * into an authority and path hierarchy, followed by query and fragment components. Accordingly, + * this class has only getters for those components but no {@link + * java.net.URI#getSchemeSpecificPart()} analog. + * + *

The RFC 3986 definition of path is now more liberal to accommodate this: + * + *

    + *
  • Path doesn't have to start with a slash. For example, the path of + * urn:isbn:978-1492082798 is isbn:978-1492082798 even though it doesn't + * look much like a file system path. + *
  • The path can now be empty. So Android's + * intent:#Intent;action=MAIN;category=LAUNCHER;end is now a valid {@link Uri}. Even + * the scheme-only about: is now valid. + *
+ * + *

The uniform syntax always understands what follows a '?' to be a query string. For example, + * mailto:me@example.com?subject=foo now has a query component whereas RFC 2396 + * considered everything after the mailto: scheme to be opaque. + * + *

Same goes for fragment. data:image/png;...#xywh=0,0,10,10 now has a fragment + * whereas RFC 2396 considered everything after the scheme to be opaque. + * + *

Uniform Authority Syntax

+ * + *

RFC 2396 tried to guess if an authority was a "server" (host:port) or "registry-based" + * (arbitrary string) based on its contents. RFC 3986 expects every authority to look like + * [userinfo@]host[:port] and loosens the definition of a "host" to accommodate. Accordingly, this + * class has no equivalent to {@link java.net.URI#parseServerAuthority()} -- authority was parsed + * into its components and checked for validity when the {@link Uri} was created. + * + *

Other Specific Differences

+ * + *

RFC 2396 does not allow underscores in a host name, meaning {@link java.net.URI} switches to + * opaque mode when it sees one. {@link Uri} does allow underscores in host, to accommodate + * registries other than DNS. So http://my_site.com:8080/index.html now parses as a + * host, port and path rather than a single opaque scheme-specific part. + * + *

{@link Uri} strictly *requires* square brackets in the query string and fragment to be + * percent-encoded whereas RFC 2396 merely recommended doing so. + * + *

Other URx classes are "liberal in what they accept and strict in what they produce." {@link + * Uri#parse(String)} and {@link Uri#create(String)}, however, are strict in what they accept and + * transparent when asked to reproduce it via {@link Uri#toString()}. The former policy may be + * appropriate for parsing user input or web content, but this class is meant for gRPC clients, + * servers and plugins like name resolvers where human error at runtime is less likely and best + * detected early. {@link java.net.URI#create(String)} is similarly strict, which makes migration + * easy, except for the server/registry-based ambiguity addressed by {@link + * java.net.URI#parseServerAuthority()}. + * + *

{@link java.net.URI} and {@link Uri} both support IPv6 literals in square brackets as defined + * by RFC 2732. + * + *

{@link java.net.URI} supports IPv6 scope IDs but accepts and emits a non-standard syntax. + * {@link Uri} implements the newer RFC 6874, which percent encodes scope IDs and the % delimiter + * itself. RFC 9844 claims to obsolete RFC 6874 because web browsers would not support it. This + * class implements RFC 6874 anyway, mostly to avoid creating a barrier to migration away from + * {@link java.net.URI}. + * + *

Some URI components, e.g. scheme, are required while others may or may not be present, e.g. + * authority. {@link Uri} is careful to preserve the distinction between an absent string component + * (getter returns null) and one with an empty value (getter returns ""). {@link java.net.URI} makes + * this distinction too, *except* when it comes to the authority and host components: {@link + * java.net.URI#getAuthority()} and {@link java.net.URI#getHost()} return null when an authority is + * absent, e.g. file:/path as expected. But these methods surprisingly also return null + * when the authority is the empty string, e.g.file:///path. {@link Uri}'s getters + * correctly return null and "" in these cases, respectively, as one would expect. + */ +@Internal +public final class Uri { + // Components are stored percent-encoded, just as originally parsed for transparent parse/toString + // round-tripping. + private final String scheme; // != null since we don't support relative references. + @Nullable private final String userInfo; + @Nullable private final String host; + @Nullable private final String port; + private final String path; // In RFC 3986, path is always defined (but can be empty). + @Nullable private final String query; + @Nullable private final String fragment; + + private Uri(Builder builder) { + this.scheme = checkNotNull(builder.scheme, "scheme"); + this.userInfo = builder.userInfo; + this.host = builder.host; + this.port = builder.port; + this.path = builder.path; + this.query = builder.query; + this.fragment = builder.fragment; + + // Checks common to the parse() and Builder code paths. + if (hasAuthority()) { + if (!path.isEmpty() && !path.startsWith("/")) { + throw new IllegalArgumentException("Has authority -- Non-empty path must start with '/'"); + } + } else { + if (path.startsWith("//")) { + throw new IllegalArgumentException("No authority -- Path cannot start with '//'"); + } + } + } + + /** + * Parses a URI from its string form. + * + * @throws URISyntaxException if 's' is not a valid RFC 3986 URI. + */ + public static Uri parse(String s) throws URISyntaxException { + try { + return create(s); + } catch (IllegalArgumentException e) { + throw new URISyntaxException(s, e.getMessage()); + } + } + + /** + * Creates a URI from a string assumed to be valid. + * + *

Useful for defining URI constants in code. Not for user input. + * + * @throws IllegalArgumentException if 's' is not a valid RFC 3986 URI. + */ + public static Uri create(String s) { + Builder builder = new Builder(); + int i = 0; + final int n = s.length(); + + // 3.1. Scheme: Look for a ':' before '/', '?', or '#'. + int schemeColon = -1; + for (; i < n; ++i) { + char c = s.charAt(i); + if (c == ':') { + schemeColon = i; + break; + } else if (c == '/' || c == '?' || c == '#') { + break; + } + } + if (schemeColon < 0) { + throw new IllegalArgumentException("Missing required scheme."); + } + builder.setRawScheme(s.substring(0, schemeColon)); + + // 3.2. Authority. Look for '//' then keep scanning until '/', '?', or '#'. + i = schemeColon + 1; + if (i + 1 < n && s.charAt(i) == '/' && s.charAt(i + 1) == '/') { + // "//" just means we have an authority. Skip over it. + i += 2; + + int authorityStart = i; + for (; i < n; ++i) { + char c = s.charAt(i); + if (c == '/' || c == '?' || c == '#') { + break; + } + } + builder.setRawAuthority(s.substring(authorityStart, i)); + } + + // 3.3. Path: Whatever is left before '?' or '#'. + int pathStart = i; + for (; i < n; ++i) { + char c = s.charAt(i); + if (c == '?' || c == '#') { + break; + } + } + builder.setRawPath(s.substring(pathStart, i)); + + // 3.4. Query, if we stopped at '?'. + if (i < n && s.charAt(i) == '?') { + i++; // Skip '?' + int queryStart = i; + for (; i < n; ++i) { + char c = s.charAt(i); + if (c == '#') { + break; + } + } + builder.setRawQuery(s.substring(queryStart, i)); + } + + // 3.5. Fragment, if we stopped at '#'. + if (i < n && s.charAt(i) == '#') { + ++i; // Skip '#' + builder.setRawFragment(s.substring(i)); + } + + return builder.build(); + } + + private static int findPortStartColon(String authority, int hostStart) { + for (int i = authority.length() - 1; i >= hostStart; --i) { + char c = authority.charAt(i); + if (c == ':') { + return i; + } + if (c == ']') { + // Hit the end of IP-literal. Any further colon is inside it and couldn't indicate a port. + break; + } + if (!digitChars.get(c)) { + // Found a non-digit, non-colon, non-bracket. + // This means there is no valid port (e.g. host is "example.com") + break; + } + } + return -1; + } + + // Checks a raw path for validity and parses it into segments. Let 'out' be null to just validate. + private static void parseAssumedUtf8PathIntoSegments( + String path, ImmutableList.Builder out) { + // Skip the first slash so it doesn't count as an empty segment at the start. + // (e.g., "/a" -> ["a"], not ["", "a"]) + int start = path.startsWith("/") ? 1 : 0; + + for (int i = start; i < path.length(); ) { + int nextSlash = path.indexOf('/', i); + String segment; + if (nextSlash >= 0) { + // Typical segment case (e.g., "foo" in "/foo/bar"). + segment = path.substring(i, nextSlash); + i = nextSlash + 1; + } else { + // Final segment case (e.g., "bar" in "/foo/bar"). + segment = path.substring(i); + i = path.length(); + } + if (out != null) { + out.add(percentDecodeAssumedUtf8(segment)); + } else { + checkPercentEncodedArg(segment, "path segment", pChars); + } + } + + // RFC 3986 says a trailing slash creates a final empty segment. + // (e.g., "/foo/" -> ["foo", ""]) + if (path.endsWith("/") && out != null) { + out.add(""); + } + } + + /** Returns the scheme of this URI. */ + public String getScheme() { + return scheme; + } + + /** + * Returns the percent-decoded "Authority" component of this URI, or null if not present. + * + *

NB: This method's decoding is lossy -- It only exists for compatibility with {@link + * java.net.URI}. Prefer {@link #getRawAuthority()} or work instead with authority in terms of its + * individual components ({@link #getUserInfo()}, {@link #getHost()} and {@link #getPort()}). The + * problem with getAuthority() is that it returns the delimited concatenation of the percent- + * decoded userinfo, host and port components. But both userinfo and host can contain the '@' + * character, which becomes indistinguishable from the userinfo/host delimiter after decoding. For + * example, URIs scheme://x@y%40z and scheme://x%40y@z have different + * userinfo and host components but getAuthority() returns "x@y@z" for both of them. + * + *

NB: This method assumes the "host" component was encoded as UTF-8, as mandated by RFC 3986. + * This method also assumes the "user information" part of authority was encoded as UTF-8, + * although RFC 3986 doesn't specify an encoding. + * + *

Decoding errors are indicated by a {@code '\u005CuFFFD'} unicode replacement character in + * the output. Callers who want to detect and handle errors in some other way should call {@link + * #getRawAuthority()}, {@link #percentDecode(CharSequence)}, then decode the bytes for + * themselves. + */ + @Nullable + public String getAuthority() { + return percentDecodeAssumedUtf8(getRawAuthority()); + } + + private boolean hasAuthority() { + return host != null; + } + + /** + * Returns the "authority" component of this URI in its originally parsed, possibly + * percent-encoded form. + */ + @Nullable + public String getRawAuthority() { + if (hasAuthority()) { + StringBuilder sb = new StringBuilder(); + appendAuthority(sb); + return sb.toString(); + } + return null; + } + + private void appendAuthority(StringBuilder sb) { + if (userInfo != null) { + sb.append(userInfo).append('@'); + } + if (host != null) { + sb.append(host); + } + if (port != null) { + sb.append(':').append(port); + } + } + + /** + * Returns the percent-decoded "User Information" component of this URI, or null if not present. + * + *

NB: This method *assumes* this component was encoded as UTF-8, although RFC 3986 doesn't + * specify an encoding. + * + *

Decoding errors are indicated by a {@code '\u005CuFFFD'} unicode replacement character in + * the output. Callers who want to detect and handle errors in some other way should call {@link + * #getRawUserInfo()}, {@link #percentDecode(CharSequence)}, then decode the bytes for themselves. + */ + @Nullable + public String getUserInfo() { + return percentDecodeAssumedUtf8(userInfo); + } + + /** + * Returns the "User Information" component of this URI in its originally parsed, possibly + * percent-encoded form. + */ + @Nullable + public String getRawUserInfo() { + return userInfo; + } + + /** + * Returns the percent-decoded "host" component of this URI, or null if not present. + * + *

This method assumes the host was encoded as UTF-8, as mandated by RFC 3986. + * + *

Decoding errors are indicated by a {@code '\u005CuFFFD'} unicode replacement character in + * the output. Callers who want to detect and handle errors in some other way should call {@link + * #getRawHost()}, {@link #percentDecode(CharSequence)}, then decode the bytes for themselves. + */ + @Nullable + public String getHost() { + return percentDecodeAssumedUtf8(host); + } + + /** + * Returns the host component of this URI in its originally parsed, possibly percent-encoded form. + */ + @Nullable + public String getRawHost() { + return host; + } + + /** Returns the "port" component of this URI, or -1 if empty or not present. */ + public int getPort() { + return port != null && !port.isEmpty() ? Integer.parseInt(port) : -1; + } + + /** Returns the raw port component of this URI in its originally parsed form. */ + @Nullable + public String getRawPort() { + return port; + } + + /** + * Returns the (possibly empty) percent-decoded "path" component of this URI. + * + *

NB: This method *assumes* the path was encoded as UTF-8, although RFC 3986 doesn't specify + * an encoding. + * + *

Decoding errors are indicated by a {@code '\u005CuFFFD'} unicode replacement character in + * the output. Callers who want to detect and handle errors in some other way should call {@link + * #getRawPath()}, {@link #percentDecode(CharSequence)}, then decode the bytes for themselves. + * + *

NB: Prefer {@link #getPathSegments()} because this method's decoding is lossy. For example, + * consider these (different) URIs: + * + *

    + *
  • file:///home%2Ffolder/my%20file + *
  • file:///home/folder/my%20file + *
+ * + *

Calling getPath() on each returns the same string: /home/folder/my file. You + * can't tell whether the second '/' character is part of the first path segment or separates the + * first and second path segments. This method only exists to ease migration from {@link + * java.net.URI}. + */ + public String getPath() { + return percentDecodeAssumedUtf8(path); + } + + /** + * Returns this URI's path as a list of path segments not including the '/' segment delimiters. + * + *

Prefer this method over {@link #getPath()} because it preserves the distinction between + * segment separators and literal '/'s within a path segment. + * + *

A trailing '/' delimiter in the path results in the empty string as the last element in the + * returned list. For example, file://localhost/foo/bar/ has path segments + * ["foo", "bar", ""] + * + *

A leading '/' delimiter cannot be detected using this method. For example, both + * dns:example.com and dns:///example.com have the same list of path segments: + * ["example.com"]. Use {@link #isPathAbsolute()} or {@link #isPathRootless()} to + * distinguish these cases. + * + *

The returned list is immutable. + */ + public List getPathSegments() { + // Returned list must be immutable but we intentionally keep guava out of the public API. + ImmutableList.Builder segmentsBuilder = ImmutableList.builder(); + parseAssumedUtf8PathIntoSegments(path, segmentsBuilder); + return segmentsBuilder.build(); + } + + /** + * Returns true iff this URI's path component starts with a path segment (rather than the '/' + * segment delimiter). + * + *

The path of an RFC 3986 URI is either empty, absolute (starts with the '/' segment + * delimiter) or rootless (starts with a path segment). For example, tel:+1-206-555-1212 + * , mailto:me@example.com and urn:isbn:978-1492082798 all have + * rootless paths. mailto:%2Fdev%2Fnull@example.com is also rootless because its + * percent-encoded slashes are not segment delimiters but rather part of the first and only path + * segment. + * + *

Contrast rootless paths with absolute ones (see {@link #isPathAbsolute()}. + */ + public boolean isPathRootless() { + return !path.isEmpty() && !path.startsWith("/"); + } + + /** + * Returns true iff this URI's path component starts with the '/' segment delimiter (rather than a + * path segment). + * + *

The path of an RFC 3986 URI is either empty, absolute (starts with the '/' segment + * delimiter) or rootless (starts with a path segment). For example, file:///resume.txt + * , file:/resume.txt and file://localhost/ all have absolute + * paths while tel:+1-206-555-1212's path is not absolute. + * mailto:%2Fdev%2Fnull@example.com is also not absolute because its percent-encoded + * slashes are not segment delimiters but rather part of the first and only path segment. + * + *

Contrast absolute paths with rootless ones (see {@link #isPathRootless()}. + * + *

NB: The term "absolute" has two different meanings in RFC 3986 which are easily confused. + * This method tests for a property of this URI's path component. Contrast with {@link + * #isAbsolute()} which tests the URI itself for a different property. + */ + public boolean isPathAbsolute() { + return path.startsWith("/"); + } + + /** + * Returns the path component of this URI in its originally parsed, possibly percent-encoded form. + */ + public String getRawPath() { + return path; + } + + /** + * Returns the query component of this URI in its originally parsed, possibly percent-encoded + * form, without any leading '?' character, or null if not present. + * + *

The query component can only be read in its raw form. That’s because virtually everyone uses + * query as a container for structured data, with some additional layer of encoding not present in + * RFC-3986. Like 'application/x-www-form-urlencoded', which encodes key/value pairs like so: + * ?k1=v1&k2=v+2. The encoding of these containers always has characters that take on + * a special delimiter meaning when not percent-encoded and a literal meaning when they are (like + * '&', '=' and '+' above). Since it matters whether a character was percent encoded or not, + * offering a '#getQuery()' method that percent-decodes everything like we do for other components + * would be error-prone. + */ + @Nullable + public String getRawQuery() { + return query; + } + + /** + * Returns the percent-decoded "fragment" component of this URI, or null if not present. + * + *

NB: This method assumes the fragment was encoded as UTF-8, although RFC 3986 doesn't specify + * an encoding. + * + *

Decoding errors are indicated by a {@code '\u005CuFFFD'} unicode replacement character in + * the output. Callers who want to detect and handle errors in some other way should call {@link + * #getRawFragment()}, {@link #percentDecode(CharSequence)}, then decode the bytes for themselves. + * + *

NB: Choose carefully between this method and {@link #getRawFragment()}. Many URI schemes + * embed further structure inside the fragment that isn't part of the RFC 3986 generic syntax. For + * example, Android uses the fragment to encode the many fields of an Intent, like {@code + * intent:#Intent;S.key=val;end;}. And the URI of a JSON resource may use RFC 6901 in its fragment + * to point at a particular node, e.g. {@code + * file:/etc/config/service.json#/methodConfig/0/retryPolicy/maxBackoff}. + * + *

When percent-encoding is used to escape internal delimiters, like a literal ';' and '=' in + * an `intent:`, call {@link #getRawFragment()} to preserve that percent-encoding, or risk + * corruption. Conversely, use *this* method when percent-decoding is needed *before* any further + * interpretation, like with a JSON pointer, which must be percent-encoded in a URI fragment but + * uses a completely different method of escaping literal '/' characters. + */ + @Nullable + public String getFragment() { + return percentDecodeAssumedUtf8(fragment); + } + + /** + * Returns the fragment component of this URI in its original, possibly percent-encoded form, and + * without any leading '#' character. + * + *

NB: Choose carefully between this method and {@link #getFragment()}. See that Javadoc for + * details. + */ + @Nullable + public String getRawFragment() { + return fragment; + } + + /** + * {@inheritDoc} + * + *

If this URI was created by {@link #parse(String)} or {@link #create(String)}, then the + * returned string will match that original input exactly. + */ + @Override + public String toString() { + // https://datatracker.ietf.org/doc/html/rfc3986#section-5.3 + StringBuilder sb = new StringBuilder(); + sb.append(scheme).append(':'); + if (hasAuthority()) { + sb.append("//"); + appendAuthority(sb); + } + sb.append(path); + if (query != null) { + sb.append('?').append(query); + } + if (fragment != null) { + sb.append('#').append(fragment); + } + return sb.toString(); + } + + /** + * Returns true iff this URI has a scheme and an authority/path hierarchy, but no fragment. + * + *

All instances of {@link Uri} are RFC 3986 URIs, not "relative references", so this method is + * equivalent to {@code getFragment() == null}. It mostly exists for compatibility with {@link + * java.net.URI}. + */ + public boolean isAbsolute() { + return scheme != null && fragment == null; + } + + /** + * {@inheritDoc} + * + *

Two instances of {@link Uri} are equal if and only if they have the same string + * representation, which RFC 3986 calls "Simple String Comparison" (6.2.1). Callers with a higher + * layer expectation of equality (e.g. http://some%2Dhost:80/foo/./bar.txt ~= + * http://some-host/foo/bar.txt) will experience false negatives. + */ + @Override + public boolean equals(Object otherObj) { + if (!(otherObj instanceof Uri)) { + return false; + } + Uri other = (Uri) otherObj; + return Objects.equals(scheme, other.scheme) + && Objects.equals(userInfo, other.userInfo) + && Objects.equals(host, other.host) + && Objects.equals(port, other.port) + && Objects.equals(path, other.path) + && Objects.equals(query, other.query) + && Objects.equals(fragment, other.fragment); + } + + @Override + public int hashCode() { + return Objects.hash(scheme, userInfo, host, port, path, query, fragment); + } + + /** Returns a new Builder initialized with the fields of this URI. */ + public Builder toBuilder() { + return new Builder(this); + } + + /** Creates a new {@link Builder} with all fields uninitialized or set to their default values. */ + public static Builder newBuilder() { + return new Builder(); + } + + /** Builder for {@link Uri}. */ + public static final class Builder { + private String scheme; + private String path = ""; + private String query; + private String fragment; + private String userInfo; + private String host; + private String port; + + private Builder() {} + + Builder(Uri prototype) { + this.scheme = prototype.scheme; + this.userInfo = prototype.userInfo; + this.host = prototype.host; + this.port = prototype.port; + this.path = prototype.path; + this.query = prototype.query; + this.fragment = prototype.fragment; + } + + /** + * Sets the scheme, e.g. "https", "dns" or "xds". + * + *

This field is required. + * + * @return this, for fluent building + * @throws IllegalArgumentException if the scheme is invalid. + */ + @CanIgnoreReturnValue + public Builder setScheme(String scheme) { + return setRawScheme(scheme.toLowerCase(Locale.ROOT)); + } + + @CanIgnoreReturnValue + Builder setRawScheme(String scheme) { + if (scheme.isEmpty() || !alphaChars.get(scheme.charAt(0))) { + throw new IllegalArgumentException("Scheme must start with an alphabetic char"); + } + for (int i = 0; i < scheme.length(); i++) { + char c = scheme.charAt(i); + if (!schemeChars.get(c)) { + throw new IllegalArgumentException("Invalid character in scheme at index " + i); + } + } + this.scheme = scheme; + return this; + } + + /** + * Specifies the new URI's path component as a string of zero or more '/' delimited segments. + * + *

Path segments can consist of any string of codepoints. Codepoints that can't be encoded + * literally will be percent-encoded for you. + * + *

If a URI contains an authority component, then the path component must either be empty or + * begin with a slash ("/") character. If a URI does not contain an authority component, then + * the path cannot begin with two slash characters ("//"). + * + *

This method interprets all '/' characters in 'path' as segment delimiters. If any of your + * segments contain literal '/' characters, call {@link #setRawPath(String)} instead. + * + *

See RFC 3986 3.3 + * for more. + * + *

This field is required but can be empty (its default value). + * + * @param path the new path + * @return this, for fluent building + */ + @CanIgnoreReturnValue + public Builder setPath(String path) { + checkArgument(path != null, "Path can be empty but not null"); + this.path = percentEncode(path, pCharsAndSlash); + return this; + } + + /** + * Specifies the new URI's path component as a string of zero or more '/' delimited segments. + * + *

Path segments can consist of any string of codepoints but the caller must first percent- + * encode anything other than RFC 3986's "pchar" character class using UTF-8. + * + *

If a URI contains an authority component, then the path component must either be empty or + * begin with a slash ("/") character. If a URI does not contain an authority component, then + * the path cannot begin with two slash characters ("//"). + * + *

This method interprets all '/' characters in 'path' as segment delimiters. If any of your + * segments contain literal '/' characters, you must percent-encode them. + * + *

See RFC 3986 3.3 + * for more. + * + *

This field is required but can be empty (its default value). + * + * @param path the new path, a string consisting of characters from "pchar" + * @return this, for fluent building + */ + @CanIgnoreReturnValue + public Builder setRawPath(String path) { + checkArgument(path != null, "Path can be empty but not null"); + parseAssumedUtf8PathIntoSegments(path, null); + this.path = path; + return this; + } + + /** + * Specifies the query component of the new URI, possibly percent-encoded, exactly as it will + * appear in the string form of the built URI. + * + *

'query' must only contain codepoints from RFC 3986's "query" character class. Any other + * characters must be percent-encoded using UTF-8. Do not include the leading '?' delimiter. + * + *

The query component can only be provided in its raw form. That’s because virtually + * everyone uses query as a container for structured data, with some additional layer of + * encoding not present in RFC-3986. Like 'application/x-www-form-urlencoded', which encodes + * key/value pairs like so: ?k1=v1&k2=v+2. The encoding of these containers always + * has characters that take on a special delimiter meaning when not percent-encoded and a + * literal meaning when they are (like '&', '=' and '+' above). Since 'query' must have already + * been carefully percent-encoded externally, a '#setQuery(String)' method that percent-encodes + * an assumed-cooked string would be error-prone. + * + *

This field is optional. + * + * @param query the new query component, or null to clear this field + * @return this, for fluent building + */ + @CanIgnoreReturnValue + public Builder setRawQuery(@Nullable String query) { + if (query != null) { + checkPercentEncodedArg(query, "query", queryChars); + } + this.query = query; + return this; + } + + /** + * Specifies the fragment component of the new URI (not including the leading '#'). + * + *

The fragment can contain any string of codepoints. Codepoints that can't be encoded + * literally will be percent-encoded for you as UTF-8. + * + *

NB: Choose carefully between this method and {@link #setRawFragment(String)}. Many URI + * schemes embed further structure in the fragment that isn't part of the RFC 3986 generic + * syntax. These schemes often use internal delimiters that must be carefully percent-encoded in + * ways that this method doesn't understand. See {@link #getFragment()} for an example. In that + * case, callers should percent-encode externally then call {@link #setRawFragment(String)} + * instead. + * + *

This field is optional. + * + * @param fragment the new fragment component, or null to clear this field + * @return this, for fluent building + */ + @CanIgnoreReturnValue + public Builder setFragment(@Nullable String fragment) { + this.fragment = percentEncode(fragment, fragmentChars); + return this; + } + + /** + * Specifies the fragment component of the new URI, already percent-encoded, exactly as it will + * appear after the '#' delimiter in the string form of the built URI. + * + *

NB: Choose carefully between this method and {@link #setFragment(String)}. {@code + * fragment} must only contain codepoints from RFC 3986's "fragment" character class. Use + * percent-encoding and UTF-8 to represent anything else. In certain cases, you can use {@link + * #setFragment(String)} to have the fragment percent-encoded for you instead, but see that + * method's Javadoc for its limitations. + * + *

This field is optional. + * + * @param fragment the new fragment component, or null to clear this field + * @return this, for fluent building + * @throws IllegalArgumentException if 'fragment' contains forbidden characters + */ + @CanIgnoreReturnValue + public Builder setRawFragment(@Nullable String fragment) { + if (fragment != null) { + checkPercentEncodedArg(fragment, "fragment", fragmentChars); + } + this.fragment = fragment; + return this; + } + + /** + * Set the "user info" component of the new URI, e.g. "username:password", not including the + * trailing '@' character. + * + *

User info can contain any string of codepoints. Codepoints that can't be encoded literally + * will be percent-encoded for you as UTF-8. + * + *

This field is optional. + * + * @param userInfo the new "user info" component, or null to clear this field + * @return this, for fluent building + */ + @CanIgnoreReturnValue + public Builder setUserInfo(@Nullable String userInfo) { + this.userInfo = percentEncode(userInfo, userInfoChars); + return this; + } + + @CanIgnoreReturnValue + Builder setRawUserInfo(String userInfo) { + checkPercentEncodedArg(userInfo, "userInfo", userInfoChars); + this.userInfo = userInfo; + return this; + } + + /** + * Specifies the "host" component of the new URI in its "registered name" form (usually DNS), + * e.g. "server.com". + * + *

The registered name can contain any string of codepoints. Codepoints that can't be encoded + * literally will be percent-encoded for you as UTF-8. + * + *

This field is optional. + * + * @param regName the new host component in "registered name" form, or null to clear this field + * @return this, for fluent building + */ + @CanIgnoreReturnValue + public Builder setHost(@Nullable String regName) { + if (regName != null) { + regName = regName.toLowerCase(Locale.ROOT); + regName = percentEncode(regName, regNameChars); + } + this.host = regName; + return this; + } + + /** + * Specifies the "host" component of the new URI as an IP address. + * + *

This field is optional. + * + * @param addr the new "host" component in InetAddress form, or null to clear this field + * @return this, for fluent building + */ + @CanIgnoreReturnValue + public Builder setHost(@Nullable InetAddress addr) { + this.host = addr != null ? toUriString(addr) : null; + return this; + } + + private static String toUriString(InetAddress addr) { + // InetAddresses.toUriString(addr) is almost enough but neglects RFC 6874 percent encoding. + String inetAddrStr = InetAddresses.toUriString(addr); + int percentIndex = inetAddrStr.indexOf('%'); + if (percentIndex < 0) { + return inetAddrStr; + } + + String scope = inetAddrStr.substring(percentIndex, inetAddrStr.length() - 1); + return inetAddrStr.substring(0, percentIndex) + percentEncode(scope, unreservedChars) + "]"; + } + + @CanIgnoreReturnValue + Builder setRawHost(String host) { + if (host.startsWith("[") && host.endsWith("]")) { + // IP-literal: Guava's isUriInetAddress() is almost enough but it doesn't check the scope. + int percentIndex = host.indexOf('%'); + if (percentIndex > 0) { + String scope = host.substring(percentIndex, host.length() - 1); + checkPercentEncodedArg(scope, "scope", unreservedChars); + } + } + // IP-literal validation is complicated so we delegate it to Guava. We use this particular + // method of InetAddresses because it doesn't try to match interfaces on the local machine. + // (The validity of a URI should be the same no matter which machine does the parsing.) + // TODO(jdcormie): IPFuture + if (!InetAddresses.isUriInetAddress(host)) { + // Must be a "registered name". + checkPercentEncodedArg(host, "host", regNameChars); + } + this.host = host; + return this; + } + + /** + * Specifies the "port" component of the new URI, e.g. "8080". + * + *

The port can be any non-negative integer. A negative value represents "no port". + * + *

This field is optional. + * + * @param port the new "port" component, or -1 to clear this field + * @return this, for fluent building + */ + @CanIgnoreReturnValue + public Builder setPort(int port) { + this.port = port < 0 ? null : Integer.toString(port); + return this; + } + + @CanIgnoreReturnValue + Builder setRawPort(String port) { + if (port != null && !port.isEmpty()) { + try { + Integer.parseInt(port); // Result unused. + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid port", e); + } + } + this.port = port; + return this; + } + + /** + * Specifies the userinfo, host and port URI components all at once using a single string. + * + *

This setter is "raw" in the sense that special characters in userinfo and host must be + * passed in percent-encoded. See RFC 3986 3.2 for the set + * of characters allowed in each component of an authority. + * + *

There's no "cooked" method to set authority like for other URI components because + * authority is a *compound* URI component whose userinfo, host and port components are + * delimited with special characters '@' and ':'. But the first two of those components can + * themselves contain these delimiters so we need percent-encoding to parse them unambiguously. + * + * @param authority an RFC 3986 authority string that will be used to set userinfo, host and + * port, or null to clear all three of those components + */ + @CanIgnoreReturnValue + public Builder setRawAuthority(@Nullable String authority) { + if (authority == null) { + setUserInfo(null); + setHost((String) null); + setPort(-1); + } else { + // UserInfo. Easy because '@' cannot appear unencoded inside userinfo or host. + int userInfoEnd = authority.indexOf('@'); + if (userInfoEnd >= 0) { + setRawUserInfo(authority.substring(0, userInfoEnd)); + } else { + setUserInfo(null); + } + + // Host/Port. + int hostStart = userInfoEnd >= 0 ? userInfoEnd + 1 : 0; + int portStartColon = findPortStartColon(authority, hostStart); + if (portStartColon < 0) { + setRawHost(authority.substring(hostStart)); + setPort(-1); + } else { + setRawHost(authority.substring(hostStart, portStartColon)); + setRawPort(authority.substring(portStartColon + 1)); + } + } + return this; + } + + /** Builds a new instance of {@link Uri} as specified by the setters. */ + public Uri build() { + checkState(scheme != null, "Missing required scheme."); + if (host == null) { + checkState(port == null, "Cannot set port without host."); + checkState(userInfo == null, "Cannot set userInfo without host."); + } + return new Uri(this); + } + } + + /** + * Decodes a string of characters in the range [U+0000, U+007F] to bytes. + * + *

Each percent-encoded sequence (e.g. "%F0" or "%2a", as defined by RFC 3986 2.1) is decoded + * to the octet it encodes. Other characters are decoded to their code point's single byte value. + * A literal % character must be encoded as %25. + * + * @throws IllegalArgumentException if 's' contains characters out of range or invalid percent + * encoding sequences. + */ + public static ByteBuffer percentDecode(CharSequence s) { + // This is large enough because each input character needs *at most* one byte of output. + ByteBuffer outBuf = ByteBuffer.allocate(s.length()); + percentDecode(s, "input", null, outBuf); + outBuf.flip(); + return outBuf; + } + + private static void percentDecode( + CharSequence s, String what, BitSet allowedChars, ByteBuffer outBuf) { + for (int i = 0; i < s.length(); i++) { + char c = s.charAt(i); + if (c == '%') { + if (i + 2 >= s.length()) { + throw new IllegalArgumentException( + "Invalid percent-encoding at index " + i + " of " + what + ": " + s); + } + int h1 = Character.digit(s.charAt(i + 1), 16); + int h2 = Character.digit(s.charAt(i + 2), 16); + if (h1 == -1 || h2 == -1) { + throw new IllegalArgumentException( + "Invalid hex digit in " + what + " at index " + i + " of: " + s); + } + if (outBuf != null) { + outBuf.put((byte) (h1 << 4 | h2)); + } + i += 2; + } else if (allowedChars == null || allowedChars.get(c)) { + if (outBuf != null) { + outBuf.put((byte) c); + } + } else { + throw new IllegalArgumentException("Invalid character in " + what + " at index " + i); + } + } + } + + @Nullable + private static String percentDecodeAssumedUtf8(@Nullable String s) { + if (s == null || s.indexOf('%') == -1) { + return s; + } + + ByteBuffer utf8Bytes = percentDecode(s); + try { + return StandardCharsets.UTF_8 + .newDecoder() + .onMalformedInput(CodingErrorAction.REPLACE) + .onUnmappableCharacter(CodingErrorAction.REPLACE) + .decode(utf8Bytes) + .toString(); + } catch (CharacterCodingException e) { + throw new VerifyException(e); // Should not happen in REPLACE mode. + } + } + + @Nullable + private static String percentEncode(String s, BitSet allowedCodePoints) { + if (s == null) { + return null; + } + CharsetEncoder encoder = + StandardCharsets.UTF_8 + .newEncoder() + .onMalformedInput(CodingErrorAction.REPORT) + .onUnmappableCharacter(CodingErrorAction.REPORT); + ByteBuffer utf8Bytes; + try { + utf8Bytes = encoder.encode(CharBuffer.wrap(s)); + } catch (MalformedInputException e) { + throw new IllegalArgumentException("Malformed input", e); // Must be a broken surrogate pair. + } catch (CharacterCodingException e) { + throw new VerifyException(e); // Should not happen when encoding to UTF-8. + } + + StringBuilder sb = new StringBuilder(); + while (utf8Bytes.hasRemaining()) { + int b = 0xff & utf8Bytes.get(); + if (allowedCodePoints.get(b)) { + sb.append((char) b); + } else { + sb.append('%'); + sb.append(hexDigitsByVal[(b & 0xF0) >> 4]); + sb.append(hexDigitsByVal[b & 0x0F]); + } + } + return sb.toString(); + } + + private static void checkPercentEncodedArg(String s, String what, BitSet allowedChars) { + percentDecode(s, what, allowedChars, null); + } + + // See UriTest for how these were computed from the ABNF constants in RFC 3986. + static final BitSet digitChars = BitSet.valueOf(new long[] {0x3ff000000000000L}); + static final BitSet alphaChars = BitSet.valueOf(new long[] {0L, 0x7fffffe07fffffeL}); + // scheme = ALPHA *( ALPHA / DIGIT / "+" / "-" / "." ) + static final BitSet schemeChars = + BitSet.valueOf(new long[] {0x3ff680000000000L, 0x7fffffe07fffffeL}); + // unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~" + static final BitSet unreservedChars = + BitSet.valueOf(new long[] {0x3ff600000000000L, 0x47fffffe87fffffeL}); + // gen-delims = ":" / "/" / "?" / "#" / "[" / "]" / "@" + static final BitSet genDelimsChars = + BitSet.valueOf(new long[] {0x8400800800000000L, 0x28000001L}); + // sub-delims = "!" / "$" / "&" / "'" / "(" / ")" / "*" / "+" / "," / ";" / "=" + static final BitSet subDelimsChars = BitSet.valueOf(new long[] {0x28001fd200000000L}); + // reserved = gen-delims / sub-delims + static final BitSet reservedChars = BitSet.valueOf(new long[] {0xac009fda00000000L, 0x28000001L}); + // reg-name = *( unreserved / pct-encoded / sub-delims ) + static final BitSet regNameChars = + BitSet.valueOf(new long[] {0x2bff7fd200000000L, 0x47fffffe87fffffeL}); + // userinfo = *( unreserved / pct-encoded / sub-delims / ":" ) + static final BitSet userInfoChars = + BitSet.valueOf(new long[] {0x2fff7fd200000000L, 0x47fffffe87fffffeL}); + // pchar = unreserved / pct-encoded / sub-delims / ":" / "@" + static final BitSet pChars = + BitSet.valueOf(new long[] {0x2fff7fd200000000L, 0x47fffffe87ffffffL}); + static final BitSet pCharsAndSlash = + BitSet.valueOf(new long[] {0x2fffffd200000000L, 0x47fffffe87ffffffL}); + // query = *( pchar / "/" / "?" ) + static final BitSet queryChars = + BitSet.valueOf(new long[] {0xafffffd200000000L, 0x47fffffe87ffffffL}); + // fragment = *( pchar / "/" / "?" ) + static final BitSet fragmentChars = queryChars; + + private static final char[] hexDigitsByVal = "0123456789ABCDEF".toCharArray(); +} diff --git a/api/src/test/java/io/grpc/CallOptionsTest.java b/api/src/test/java/io/grpc/CallOptionsTest.java index f4c98c1369e..65fb7ff3bf2 100644 --- a/api/src/test/java/io/grpc/CallOptionsTest.java +++ b/api/src/test/java/io/grpc/CallOptionsTest.java @@ -32,6 +32,7 @@ import com.google.common.base.Objects; import io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.internal.SerializingExecutor; +import java.time.Duration; import java.util.concurrent.Executor; import org.junit.Test; import org.junit.runner.RunWith; @@ -81,6 +82,16 @@ public void withAndWithoutWaitForReady() { .isFalse(); } + @Test + public void withOnReadyThreshold() { + int onReadyThreshold = 1024; + CallOptions callOptions = CallOptions.DEFAULT.withOnReadyThreshold(onReadyThreshold); + callOptions = callOptions.withWaitForReady(); + assertThat(callOptions.getOnReadyThreshold()).isEqualTo(onReadyThreshold); + callOptions = callOptions.clearOnReadyThreshold(); + assertThat(callOptions.getOnReadyThreshold()).isNull(); + } + @Test public void allWiths() { assertThat(allSet.getAuthority()).isSameInstanceAs(sampleAuthority); @@ -140,6 +151,15 @@ public void withDeadlineAfter() { assertAbout(deadline()).that(actual).isWithin(10, MILLISECONDS).of(expected); } + @Test + @IgnoreJRERequirement + public void withDeadlineAfterDuration() { + Deadline actual = CallOptions.DEFAULT.withDeadlineAfter(Duration.ofMinutes(1L)).getDeadline(); + Deadline expected = Deadline.after(1, MINUTES); + + assertAbout(deadline()).that(actual).isWithin(10, MILLISECONDS).of(expected); + } + @Test public void toStringMatches_noDeadline_default() { String actual = allSet @@ -148,6 +168,7 @@ public void toStringMatches_noDeadline_default() { .withCallCredentials(null) .withMaxInboundMessageSize(44) .withMaxOutboundMessageSize(55) + .withOnReadyThreshold(1024) .toString(); assertThat(actual).contains("deadline=null"); @@ -159,6 +180,7 @@ public void toStringMatches_noDeadline_default() { assertThat(actual).contains("waitForReady=true"); assertThat(actual).contains("maxInboundMessageSize=44"); assertThat(actual).contains("maxOutboundMessageSize=55"); + assertThat(actual).contains("onReadyThreshold=1024"); assertThat(actual).contains("streamTracerFactories=[tracerFactory1, tracerFactory2]"); } diff --git a/core/src/test/java/io/grpc/ClientStreamTracerTest.java b/api/src/test/java/io/grpc/ClientStreamTracerTest.java similarity index 100% rename from core/src/test/java/io/grpc/ClientStreamTracerTest.java rename to api/src/test/java/io/grpc/ClientStreamTracerTest.java diff --git a/api/src/test/java/io/grpc/ConfiguratorRegistryTest.java b/api/src/test/java/io/grpc/ConfiguratorRegistryTest.java new file mode 100644 index 00000000000..457d5a36e77 --- /dev/null +++ b/api/src/test/java/io/grpc/ConfiguratorRegistryTest.java @@ -0,0 +1,98 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; + +import java.util.Arrays; +import java.util.List; +import java.util.regex.Pattern; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ConfiguratorRegistryTest { + + private final StaticTestingClassLoader classLoader = + new StaticTestingClassLoader( + getClass().getClassLoader(), Pattern.compile("io\\.grpc\\.[^.]+")); + + @Test + public void setConfigurators() throws Exception { + Class runnable = classLoader.loadClass(StaticTestingClassLoaderSet.class.getName()); + ((Runnable) runnable.getDeclaredConstructor().newInstance()).run(); + } + + @Test + public void setGlobalConfigurators_twice() throws Exception { + Class runnable = classLoader.loadClass(StaticTestingClassLoaderSetTwice.class.getName()); + ((Runnable) runnable.getDeclaredConstructor().newInstance()).run(); + } + + @Test + public void getBeforeSet() throws Exception { + Class runnable = + classLoader.loadClass( + StaticTestingClassLoaderGetBeforeSet.class.getName()); + ((Runnable) runnable.getDeclaredConstructor().newInstance()).run(); + } + + // UsedReflectively + public static final class StaticTestingClassLoaderSet implements Runnable { + @Override + public void run() { + List configurators = Arrays.asList(new NoopConfigurator()); + + ConfiguratorRegistry.getDefaultRegistry().setConfigurators(configurators); + + assertThat(ConfiguratorRegistry.getDefaultRegistry().getConfigurators()) + .isEqualTo(configurators); + } + } + + public static final class StaticTestingClassLoaderSetTwice implements Runnable { + @Override + public void run() { + ConfiguratorRegistry.getDefaultRegistry() + .setConfigurators(Arrays.asList(new NoopConfigurator())); + try { + ConfiguratorRegistry.getDefaultRegistry() + .setConfigurators(Arrays.asList(new NoopConfigurator())); + fail("should have failed for calling setConfigurators() again"); + } catch (IllegalStateException e) { + assertThat(e).hasMessageThat().isEqualTo("Configurators are already set"); + } + } + } + + public static final class StaticTestingClassLoaderGetBeforeSet implements Runnable { + @Override + public void run() { + assertThat(ConfiguratorRegistry.getDefaultRegistry().getConfigurators()).isEmpty(); + NoopConfigurator noopConfigurator = new NoopConfigurator(); + ConfiguratorRegistry.getDefaultRegistry() + .setConfigurators(Arrays.asList(noopConfigurator)); + assertThat(ConfiguratorRegistry.getDefaultRegistry().getConfigurators()) + .containsExactly(noopConfigurator); + assertThat(InternalConfiguratorRegistry.getConfiguratorsCallCountBeforeSet()).isEqualTo(1); + } + } + + private static class NoopConfigurator implements Configurator {} +} diff --git a/api/src/test/java/io/grpc/GlobalInterceptorsTest.java b/api/src/test/java/io/grpc/GlobalInterceptorsTest.java deleted file mode 100644 index 7315186f1ee..00000000000 --- a/api/src/test/java/io/grpc/GlobalInterceptorsTest.java +++ /dev/null @@ -1,188 +0,0 @@ -/* - * Copyright 2022 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc; - -import static com.google.common.truth.Truth.assertThat; -import static org.junit.Assert.fail; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.regex.Pattern; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -@RunWith(JUnit4.class) -public class GlobalInterceptorsTest { - - private final StaticTestingClassLoader classLoader = - new StaticTestingClassLoader( - getClass().getClassLoader(), Pattern.compile("io\\.grpc\\.[^.]+")); - - @Test - public void setInterceptorsTracers() throws Exception { - Class runnable = classLoader.loadClass(StaticTestingClassLoaderSet.class.getName()); - ((Runnable) runnable.getDeclaredConstructor().newInstance()).run(); - } - - @Test - public void setGlobalInterceptorsTracers_twice() throws Exception { - Class runnable = classLoader.loadClass(StaticTestingClassLoaderSetTwice.class.getName()); - ((Runnable) runnable.getDeclaredConstructor().newInstance()).run(); - } - - @Test - public void getBeforeSet_clientInterceptors() throws Exception { - Class runnable = - classLoader.loadClass( - StaticTestingClassLoaderGetBeforeSetClientInterceptor.class.getName()); - ((Runnable) runnable.getDeclaredConstructor().newInstance()).run(); - } - - @Test - public void getBeforeSet_serverInterceptors() throws Exception { - Class runnable = - classLoader.loadClass( - StaticTestingClassLoaderGetBeforeSetServerInterceptor.class.getName()); - ((Runnable) runnable.getDeclaredConstructor().newInstance()).run(); - } - - @Test - public void getBeforeSet_serverStreamTracerFactories() throws Exception { - Class runnable = - classLoader.loadClass( - StaticTestingClassLoaderGetBeforeSetServerStreamTracerFactory.class.getName()); - ((Runnable) runnable.getDeclaredConstructor().newInstance()).run(); - } - - // UsedReflectively - public static final class StaticTestingClassLoaderSet implements Runnable { - @Override - public void run() { - List clientInterceptorList = - new ArrayList<>(Arrays.asList(new NoopClientInterceptor())); - List serverInterceptorList = - new ArrayList<>(Arrays.asList(new NoopServerInterceptor())); - List serverStreamTracerFactoryList = - new ArrayList<>( - Arrays.asList( - new NoopServerStreamTracerFactory(), new NoopServerStreamTracerFactory())); - - GlobalInterceptors.setInterceptorsTracers( - clientInterceptorList, serverInterceptorList, serverStreamTracerFactoryList); - - assertThat(GlobalInterceptors.getClientInterceptors()).isEqualTo(clientInterceptorList); - assertThat(GlobalInterceptors.getServerInterceptors()).isEqualTo(serverInterceptorList); - assertThat(GlobalInterceptors.getServerStreamTracerFactories()) - .isEqualTo(serverStreamTracerFactoryList); - } - } - - public static final class StaticTestingClassLoaderSetTwice implements Runnable { - @Override - public void run() { - GlobalInterceptors.setInterceptorsTracers( - new ArrayList<>(Arrays.asList(new NoopClientInterceptor())), - Collections.emptyList(), - new ArrayList<>(Arrays.asList(new NoopServerStreamTracerFactory()))); - try { - GlobalInterceptors.setInterceptorsTracers( - null, new ArrayList<>(Arrays.asList(new NoopServerInterceptor())), null); - fail("should have failed for calling setGlobalInterceptorsTracers() again"); - } catch (IllegalStateException e) { - assertThat(e).hasMessageThat().isEqualTo("Global interceptors and tracers are already set"); - } - } - } - - public static final class StaticTestingClassLoaderGetBeforeSetClientInterceptor - implements Runnable { - @Override - public void run() { - List clientInterceptors = GlobalInterceptors.getClientInterceptors(); - assertThat(clientInterceptors).isNull(); - - try { - GlobalInterceptors.setInterceptorsTracers( - new ArrayList<>(Arrays.asList(new NoopClientInterceptor())), null, null); - fail("should have failed for invoking set call after get is already called"); - } catch (IllegalStateException e) { - assertThat(e).hasMessageThat().isEqualTo("Set cannot be called after any get call"); - } - } - } - - public static final class StaticTestingClassLoaderGetBeforeSetServerInterceptor - implements Runnable { - @Override - public void run() { - List serverInterceptors = GlobalInterceptors.getServerInterceptors(); - assertThat(serverInterceptors).isNull(); - - try { - GlobalInterceptors.setInterceptorsTracers( - null, new ArrayList<>(Arrays.asList(new NoopServerInterceptor())), null); - fail("should have failed for invoking set call after get is already called"); - } catch (IllegalStateException e) { - assertThat(e).hasMessageThat().isEqualTo("Set cannot be called after any get call"); - } - } - } - - public static final class StaticTestingClassLoaderGetBeforeSetServerStreamTracerFactory - implements Runnable { - @Override - public void run() { - List serverStreamTracerFactories = - GlobalInterceptors.getServerStreamTracerFactories(); - assertThat(serverStreamTracerFactories).isNull(); - - try { - GlobalInterceptors.setInterceptorsTracers( - null, null, new ArrayList<>(Arrays.asList(new NoopServerStreamTracerFactory()))); - fail("should have failed for invoking set call after get is already called"); - } catch (IllegalStateException e) { - assertThat(e).hasMessageThat().isEqualTo("Set cannot be called after any get call"); - } - } - } - - private static class NoopClientInterceptor implements ClientInterceptor { - @Override - public ClientCall interceptCall( - MethodDescriptor method, CallOptions callOptions, Channel next) { - return next.newCall(method, callOptions); - } - } - - private static class NoopServerInterceptor implements ServerInterceptor { - @Override - public ServerCall.Listener interceptCall( - ServerCall call, Metadata headers, ServerCallHandler next) { - return next.startCall(call, headers); - } - } - - private static class NoopServerStreamTracerFactory extends ServerStreamTracer.Factory { - @Override - public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata headers) { - throw new UnsupportedOperationException(); - } - } -} diff --git a/api/src/test/java/io/grpc/HttpConnectProxiedSocketAddressTest.java b/api/src/test/java/io/grpc/HttpConnectProxiedSocketAddressTest.java new file mode 100644 index 00000000000..6620a7d413a --- /dev/null +++ b/api/src/test/java/io/grpc/HttpConnectProxiedSocketAddressTest.java @@ -0,0 +1,248 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertThrows; + +import com.google.common.testing.EqualsTester; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class HttpConnectProxiedSocketAddressTest { + + private final InetSocketAddress proxyAddress = + new InetSocketAddress(InetAddress.getLoopbackAddress(), 8080); + private final InetSocketAddress targetAddress = + InetSocketAddress.createUnresolved("example.com", 443); + + @Test + public void buildWithAllFields() { + Map headers = new HashMap<>(); + headers.put("X-Custom-Header", "custom-value"); + headers.put("Proxy-Authorization", "Bearer token"); + + HttpConnectProxiedSocketAddress address = HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(headers) + .setUsername("user") + .setPassword("pass") + .build(); + + assertThat(address.getProxyAddress()).isEqualTo(proxyAddress); + assertThat(address.getTargetAddress()).isEqualTo(targetAddress); + assertThat(address.getHeaders()).hasSize(2); + assertThat(address.getHeaders()).containsEntry("X-Custom-Header", "custom-value"); + assertThat(address.getHeaders()).containsEntry("Proxy-Authorization", "Bearer token"); + assertThat(address.getUsername()).isEqualTo("user"); + assertThat(address.getPassword()).isEqualTo("pass"); + } + + @Test + public void buildWithoutOptionalFields() { + HttpConnectProxiedSocketAddress address = HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .build(); + + assertThat(address.getProxyAddress()).isEqualTo(proxyAddress); + assertThat(address.getTargetAddress()).isEqualTo(targetAddress); + assertThat(address.getHeaders()).isEmpty(); + assertThat(address.getUsername()).isNull(); + assertThat(address.getPassword()).isNull(); + } + + @Test + public void buildWithEmptyHeaders() { + HttpConnectProxiedSocketAddress address = HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(Collections.emptyMap()) + .build(); + + assertThat(address.getHeaders()).isEmpty(); + } + + @Test + public void headersAreImmutable() { + Map headers = new HashMap<>(); + headers.put("key1", "value1"); + + HttpConnectProxiedSocketAddress address = HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(headers) + .build(); + + headers.put("key2", "value2"); + + assertThat(address.getHeaders()).hasSize(1); + assertThat(address.getHeaders()).containsEntry("key1", "value1"); + assertThat(address.getHeaders()).doesNotContainKey("key2"); + } + + @Test + public void returnedHeadersAreUnmodifiable() { + Map headers = new HashMap<>(); + headers.put("key", "value"); + + HttpConnectProxiedSocketAddress address = HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(headers) + .build(); + + assertThrows(UnsupportedOperationException.class, + () -> address.getHeaders().put("newKey", "newValue")); + } + + @Test + public void nullHeadersThrowsException() { + assertThrows(NullPointerException.class, + () -> HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(null) + .build()); + } + + @Test + public void equalsAndHashCode() { + Map headers1 = new HashMap<>(); + headers1.put("header", "value"); + + Map headers2 = new HashMap<>(); + headers2.put("header", "value"); + + Map differentHeaders = new HashMap<>(); + differentHeaders.put("different", "header"); + + new EqualsTester() + .addEqualityGroup( + HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(headers1) + .setUsername("user") + .setPassword("pass") + .build(), + HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(headers2) + .setUsername("user") + .setPassword("pass") + .build()) + .addEqualityGroup( + HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(differentHeaders) + .setUsername("user") + .setPassword("pass") + .build()) + .addEqualityGroup( + HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .build()) + .testEquals(); + } + + @Test + public void toStringContainsHeaders() { + Map headers = new HashMap<>(); + headers.put("X-Test", "test-value"); + + HttpConnectProxiedSocketAddress address = HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(headers) + .setUsername("user") + .setPassword("secret") + .build(); + + String toString = address.toString(); + assertThat(toString).contains("headers"); + assertThat(toString).contains("X-Test"); + assertThat(toString).contains("hasPassword=true"); + assertThat(toString).doesNotContain("secret"); + } + + @Test + public void toStringWithoutPassword() { + HttpConnectProxiedSocketAddress address = HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .build(); + + String toString = address.toString(); + assertThat(toString).contains("hasPassword=false"); + } + + @Test + public void hashCodeDependsOnHeaders() { + Map headers1 = new HashMap<>(); + headers1.put("header", "value1"); + + Map headers2 = new HashMap<>(); + headers2.put("header", "value2"); + + HttpConnectProxiedSocketAddress address1 = HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(headers1) + .build(); + + HttpConnectProxiedSocketAddress address2 = HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(headers2) + .build(); + + assertNotEquals(address1.hashCode(), address2.hashCode()); + } + + @Test + public void multipleHeadersSupported() { + Map headers = new HashMap<>(); + headers.put("X-Header-1", "value1"); + headers.put("X-Header-2", "value2"); + headers.put("X-Header-3", "value3"); + + HttpConnectProxiedSocketAddress address = HttpConnectProxiedSocketAddress.newBuilder() + .setProxyAddress(proxyAddress) + .setTargetAddress(targetAddress) + .setHeaders(headers) + .build(); + + assertThat(address.getHeaders()).hasSize(3); + assertThat(address.getHeaders()).containsEntry("X-Header-1", "value1"); + assertThat(address.getHeaders()).containsEntry("X-Header-2", "value2"); + assertThat(address.getHeaders()).containsEntry("X-Header-3", "value3"); + } +} + diff --git a/api/src/test/java/io/grpc/LoadBalancerRegistryTest.java b/api/src/test/java/io/grpc/LoadBalancerRegistryTest.java index 5b348b7adab..690db6622e0 100644 --- a/api/src/test/java/io/grpc/LoadBalancerRegistryTest.java +++ b/api/src/test/java/io/grpc/LoadBalancerRegistryTest.java @@ -40,7 +40,7 @@ public void getClassesViaHardcoded_classesPresent() throws Exception { @Test public void stockProviders() { LoadBalancerRegistry defaultRegistry = LoadBalancerRegistry.getDefaultRegistry(); - assertThat(defaultRegistry.providers()).hasSize(3); + assertThat(defaultRegistry.providers()).hasSize(4); LoadBalancerProvider pickFirst = defaultRegistry.getProvider("pick_first"); assertThat(pickFirst).isInstanceOf(PickFirstLoadBalancerProvider.class); @@ -56,6 +56,12 @@ public void stockProviders() { assertThat(outlierDetection.getClass().getName()).isEqualTo( "io.grpc.util.OutlierDetectionLoadBalancerProvider"); assertThat(roundRobin.getPriority()).isEqualTo(5); + + LoadBalancerProvider randomSubsetting = defaultRegistry.getProvider( + "random_subsetting_experimental"); + assertThat(randomSubsetting.getClass().getName()).isEqualTo( + "io.grpc.util.RandomSubsettingLoadBalancerProvider"); + assertThat(randomSubsetting.getPriority()).isEqualTo(5); } @Test diff --git a/api/src/test/java/io/grpc/LoadBalancerTest.java b/api/src/test/java/io/grpc/LoadBalancerTest.java index 5e9e5cbe816..22fdc220081 100644 --- a/api/src/test/java/io/grpc/LoadBalancerTest.java +++ b/api/src/test/java/io/grpc/LoadBalancerTest.java @@ -64,6 +64,26 @@ public void pickResult_withSubchannelAndTracer() { assertThat(result.isDrop()).isFalse(); } + @Test + public void pickResult_withSubchannelReplacement() { + PickResult result = PickResult.withSubchannel(subchannel, tracerFactory) + .copyWithSubchannel(subchannel2); + assertThat(result.getSubchannel()).isSameInstanceAs(subchannel2); + assertThat(result.getStatus()).isSameInstanceAs(Status.OK); + assertThat(result.getStreamTracerFactory()).isSameInstanceAs(tracerFactory); + assertThat(result.isDrop()).isFalse(); + } + + @Test + public void pickResult_withStreamTracerFactory() { + PickResult result = PickResult.withSubchannel(subchannel) + .copyWithStreamTracerFactory(tracerFactory); + assertThat(result.getSubchannel()).isSameInstanceAs(subchannel); + assertThat(result.getStatus()).isSameInstanceAs(Status.OK); + assertThat(result.getStreamTracerFactory()).isSameInstanceAs(tracerFactory); + assertThat(result.isDrop()).isFalse(); + } + @Test public void pickResult_withNoResult() { PickResult result = PickResult.withNoResult(); diff --git a/api/src/test/java/io/grpc/ManagedChannelRegistryTest.java b/api/src/test/java/io/grpc/ManagedChannelRegistryTest.java index 30de2477d77..2479e339791 100644 --- a/api/src/test/java/io/grpc/ManagedChannelRegistryTest.java +++ b/api/src/test/java/io/grpc/ManagedChannelRegistryTest.java @@ -20,17 +20,23 @@ import static org.junit.Assert.fail; import com.google.common.collect.ImmutableSet; +import io.grpc.FlagResetRule; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.net.URI; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import org.junit.Before; +import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; /** Unit tests for {@link ManagedChannelRegistry}. */ -@RunWith(JUnit4.class) +@RunWith(Parameterized.class) public class ManagedChannelRegistryTest { private String target = "testing123"; private ChannelCredentials creds = new ChannelCredentials() { @@ -40,6 +46,20 @@ public ChannelCredentials withoutBearerTokens() { } }; + @Rule public final FlagResetRule flagResetRule = new FlagResetRule(); + + @Parameters(name = "enableRfc3986UrisParam={0}") + public static Iterable data() { + return Arrays.asList(new Object[][] {{true}, {false}}); + } + + @Parameter public boolean enableRfc3986UrisParam; + + @Before + public void setUp() { + flagResetRule.setFlagForTest(FeatureFlags::setRfc3986UrisEnabled, enableRfc3986UrisParam); + } + @Test public void register_unavailableProviderThrows() { ManagedChannelRegistry reg = new ManagedChannelRegistry(); diff --git a/api/src/test/java/io/grpc/MetadataTest.java b/api/src/test/java/io/grpc/MetadataTest.java index 073a505c824..a858fff5e5a 100644 --- a/api/src/test/java/io/grpc/MetadataTest.java +++ b/api/src/test/java/io/grpc/MetadataTest.java @@ -16,20 +16,22 @@ package io.grpc; -import static com.google.common.base.Charsets.US_ASCII; -import static com.google.common.base.Charsets.UTF_8; +import static com.google.common.truth.Truth.assertThat; +import static java.nio.charset.StandardCharsets.US_ASCII; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import com.google.common.collect.Lists; import com.google.common.io.ByteStreams; +import com.google.common.testing.EqualsTester; import io.grpc.internal.GrpcUtil; import java.io.ByteArrayInputStream; import java.io.IOException; @@ -37,9 +39,7 @@ import java.util.Arrays; import java.util.Iterator; import java.util.Locale; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -49,9 +49,6 @@ @RunWith(JUnit4.class) public class MetadataTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); - private static final Metadata.BinaryMarshaller FISH_MARSHALLER = new Metadata.BinaryMarshaller() { @Override @@ -65,7 +62,7 @@ public Fish parseBytes(byte[] serialized) { } }; - private static class FishStreamMarsaller implements Metadata.BinaryStreamMarshaller { + private static class FishStreamMarshaller implements Metadata.BinaryStreamMarshaller { @Override public InputStream toStream(Fish fish) { return new ByteArrayInputStream(FISH_MARSHALLER.toBytes(fish)); @@ -82,7 +79,7 @@ public Fish parseStream(InputStream stream) { } private static final Metadata.BinaryStreamMarshaller FISH_STREAM_MARSHALLER = - new FishStreamMarsaller(); + new FishStreamMarshaller(); /** A pattern commonly used to avoid unnecessary serialization of immutable objects. */ private static final class FakeFishStream extends InputStream { @@ -121,10 +118,9 @@ public Fish parseStream(InputStream stream) { @Test public void noPseudoHeaders() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Invalid character"); - - Metadata.Key.of(":test-bin", FISH_MARSHALLER); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> Metadata.Key.of(":test-bin", FISH_MARSHALLER)); + assertThat(e).hasMessageThat().isEqualTo("Invalid character ':' in key name ':test-bin'"); } @Test @@ -186,8 +182,7 @@ public void testGetAllNoRemove() { Iterator i = metadata.getAll(KEY).iterator(); assertEquals(lance, i.next()); - thrown.expect(UnsupportedOperationException.class); - i.remove(); + assertThrows(UnsupportedOperationException.class, i::remove); } @Test @@ -271,17 +266,15 @@ public void mergeExpands() { @Test public void shortBinaryKeyName() { - thrown.expect(IllegalArgumentException.class); - - Metadata.Key.of("-bin", FISH_MARSHALLER); + assertThrows(IllegalArgumentException.class, () -> Metadata.Key.of("-bin", FISH_MARSHALLER)); } @Test public void invalidSuffixBinaryKeyName() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Binary header is named"); - - Metadata.Key.of("nonbinary", FISH_MARSHALLER); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> Metadata.Key.of("nonbinary", FISH_MARSHALLER)); + assertThat(e).hasMessageThat() + .isEqualTo("Binary header is named nonbinary. It must end with -bin"); } @Test @@ -368,14 +361,12 @@ public void removeAllIgnoresMissingValue() { @Test public void keyEqualsHashNameWorks() { Metadata.Key k1 = Metadata.Key.of("case", Metadata.ASCII_STRING_MARSHALLER); - Metadata.Key k2 = Metadata.Key.of("CASE", Metadata.ASCII_STRING_MARSHALLER); - assertEquals(k1, k1); - assertNotEquals(k1, null); - assertNotEquals(k1, new Object(){}); - assertEquals(k1, k2); - assertEquals(k1.hashCode(), k2.hashCode()); + new EqualsTester() + .addEqualityGroup(k1, k2) + .addEqualityGroup(new Object(){}) + .testEquals(); // Check that the casing is preserved. assertEquals("CASE", k2.originalName()); assertEquals("case", k2.name()); @@ -417,7 +408,7 @@ public void streamedValueDifferentMarshaller() { h.put(KEY_STREAMED, salmon); // Get using a different marshaller instance. - Fish fish = h.get(copyKey(KEY_STREAMED, new FishStreamMarsaller())); + Fish fish = h.get(copyKey(KEY_STREAMED, new FishStreamMarshaller())); assertEquals(salmon, fish); } diff --git a/api/src/test/java/io/grpc/MethodDescriptorTest.java b/api/src/test/java/io/grpc/MethodDescriptorTest.java index 5e742fb47ed..e068e0c1108 100644 --- a/api/src/test/java/io/grpc/MethodDescriptorTest.java +++ b/api/src/test/java/io/grpc/MethodDescriptorTest.java @@ -26,9 +26,7 @@ import io.grpc.MethodDescriptor.Marshaller; import io.grpc.MethodDescriptor.MethodType; import io.grpc.testing.TestMethodDescriptors; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -37,10 +35,6 @@ */ @RunWith(JUnit4.class) public class MethodDescriptorTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); - @Test public void createMethodDescriptor() { MethodDescriptor descriptor = MethodDescriptor.newBuilder() @@ -67,7 +61,7 @@ public void idempotent() { assertFalse(descriptor.isIdempotent()); - // Create a new desriptor by setting idempotent to true + // Create a new descriptor by setting idempotent to true MethodDescriptor newDescriptor = descriptor.toBuilder().setIdempotent(true).build(); assertTrue(newDescriptor.isIdempotent()); @@ -86,7 +80,7 @@ public void safe() { .build(); assertFalse(descriptor.isSafe()); - // Create a new desriptor by setting safe to true + // Create a new descriptor by setting safe to true MethodDescriptor newDescriptor = descriptor.toBuilder().setSafe(true).build(); assertTrue(newDescriptor.isSafe()); // All other fields should staty the same diff --git a/api/src/test/java/io/grpc/MetricInstrumentRegistryTest.java b/api/src/test/java/io/grpc/MetricInstrumentRegistryTest.java new file mode 100644 index 00000000000..b378f4aaef5 --- /dev/null +++ b/api/src/test/java/io/grpc/MetricInstrumentRegistryTest.java @@ -0,0 +1,193 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.MetricInstrumentRegistry.INITIAL_INSTRUMENT_CAPACITY; + +import com.google.common.collect.ImmutableList; +import java.util.List; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Unit test for {@link MetricInstrumentRegistry}. + */ +@RunWith(JUnit4.class) +public class MetricInstrumentRegistryTest { + private static final ImmutableList REQUIRED_LABEL_KEYS = ImmutableList.of("KEY1", "KEY2"); + private static final ImmutableList OPTIONAL_LABEL_KEYS = ImmutableList.of( + "OPTIONAL_KEY_1"); + private static final ImmutableList DOUBLE_HISTOGRAM_BUCKETS = ImmutableList.of(0.01, 0.1); + private static final ImmutableList LONG_HISTOGRAM_BUCKETS = ImmutableList.of(1L, 10L); + private static final String METRIC_NAME_1 = "testMetric1"; + private static final String DESCRIPTION_1 = "description1"; + private static final String DESCRIPTION_2 = "description2"; + private static final String UNIT_1 = "unit1"; + private static final String UNIT_2 = "unit2"; + private static final boolean ENABLED = true; + private static final boolean DISABLED = false; + private MetricInstrumentRegistry registry = new MetricInstrumentRegistry(); + + @Test + public void registerDoubleCounterSuccess() { + DoubleCounterMetricInstrument instrument = registry.registerDoubleCounter( + METRIC_NAME_1, DESCRIPTION_1, UNIT_1, REQUIRED_LABEL_KEYS, OPTIONAL_LABEL_KEYS, ENABLED); + assertThat(registry.getMetricInstruments().contains(instrument)).isTrue(); + assertThat(registry.getMetricInstruments().size()).isEqualTo(1); + assertThat(instrument.getName()).isEqualTo(METRIC_NAME_1); + assertThat(instrument.getDescription()).isEqualTo(DESCRIPTION_1); + assertThat(instrument.getUnit()).isEqualTo(UNIT_1); + assertThat(instrument.getRequiredLabelKeys()).isEqualTo(REQUIRED_LABEL_KEYS); + assertThat(instrument.getOptionalLabelKeys()).isEqualTo(OPTIONAL_LABEL_KEYS); + assertThat(instrument.isEnableByDefault()).isTrue(); + } + + @Test + public void registerLongCounterSuccess() { + LongCounterMetricInstrument instrument2 = registry.registerLongCounter( + METRIC_NAME_1, DESCRIPTION_1, UNIT_1, REQUIRED_LABEL_KEYS, OPTIONAL_LABEL_KEYS, ENABLED); + assertThat(registry.getMetricInstruments().contains(instrument2)).isTrue(); + assertThat(registry.getMetricInstruments().size()).isEqualTo(1); + assertThat(instrument2.getName()).isEqualTo(METRIC_NAME_1); + assertThat(instrument2.getDescription()).isEqualTo(DESCRIPTION_1); + assertThat(instrument2.getUnit()).isEqualTo(UNIT_1); + assertThat(instrument2.getRequiredLabelKeys()).isEqualTo(REQUIRED_LABEL_KEYS); + assertThat(instrument2.getOptionalLabelKeys()).isEqualTo(OPTIONAL_LABEL_KEYS); + assertThat(instrument2.isEnableByDefault()).isTrue(); + } + + @Test + public void registerDoubleHistogramSuccess() { + DoubleHistogramMetricInstrument instrument3 = registry.registerDoubleHistogram( + METRIC_NAME_1, DESCRIPTION_1, UNIT_1, DOUBLE_HISTOGRAM_BUCKETS, REQUIRED_LABEL_KEYS, + OPTIONAL_LABEL_KEYS, ENABLED); + assertThat(registry.getMetricInstruments().contains(instrument3)).isTrue(); + assertThat(registry.getMetricInstruments().size()).isEqualTo(1); + assertThat(instrument3.getName()).isEqualTo(METRIC_NAME_1); + assertThat(instrument3.getDescription()).isEqualTo(DESCRIPTION_1); + assertThat(instrument3.getUnit()).isEqualTo(UNIT_1); + assertThat(instrument3.getBucketBoundaries()).isEqualTo(DOUBLE_HISTOGRAM_BUCKETS); + assertThat(instrument3.getRequiredLabelKeys()).isEqualTo(REQUIRED_LABEL_KEYS); + assertThat(instrument3.getOptionalLabelKeys()).isEqualTo(OPTIONAL_LABEL_KEYS); + assertThat(instrument3.isEnableByDefault()).isTrue(); + } + + @Test + public void registerLongHistogramSuccess() { + LongHistogramMetricInstrument instrument4 = registry.registerLongHistogram( + METRIC_NAME_1, DESCRIPTION_1, UNIT_1, LONG_HISTOGRAM_BUCKETS, REQUIRED_LABEL_KEYS, + OPTIONAL_LABEL_KEYS, ENABLED); + assertThat(registry.getMetricInstruments().contains(instrument4)).isTrue(); + assertThat(registry.getMetricInstruments().size()).isEqualTo(1); + assertThat(instrument4.getName()).isEqualTo(METRIC_NAME_1); + assertThat(instrument4.getDescription()).isEqualTo(DESCRIPTION_1); + assertThat(instrument4.getUnit()).isEqualTo(UNIT_1); + assertThat(instrument4.getBucketBoundaries()).isEqualTo(LONG_HISTOGRAM_BUCKETS); + assertThat(instrument4.getRequiredLabelKeys()).isEqualTo(REQUIRED_LABEL_KEYS); + assertThat(instrument4.getOptionalLabelKeys()).isEqualTo(OPTIONAL_LABEL_KEYS); + assertThat(instrument4.isEnableByDefault()).isTrue(); + } + + @Test + public void registerLongGaugeSuccess() { + LongGaugeMetricInstrument instrument4 = registry.registerLongGauge( + METRIC_NAME_1, DESCRIPTION_1, UNIT_1, REQUIRED_LABEL_KEYS, + OPTIONAL_LABEL_KEYS, ENABLED); + assertThat(registry.getMetricInstruments().contains(instrument4)).isTrue(); + assertThat(registry.getMetricInstruments().size()).isEqualTo(1); + assertThat(instrument4.getName()).isEqualTo(METRIC_NAME_1); + assertThat(instrument4.getDescription()).isEqualTo(DESCRIPTION_1); + assertThat(instrument4.getUnit()).isEqualTo(UNIT_1); + assertThat(instrument4.getRequiredLabelKeys()).isEqualTo(REQUIRED_LABEL_KEYS); + assertThat(instrument4.getOptionalLabelKeys()).isEqualTo(OPTIONAL_LABEL_KEYS); + assertThat(instrument4.isEnableByDefault()).isTrue(); + } + + @Test(expected = IllegalStateException.class) + public void registerDoubleCounterDuplicateName() { + registry.registerDoubleCounter(METRIC_NAME_1, DESCRIPTION_1, UNIT_1, REQUIRED_LABEL_KEYS, + OPTIONAL_LABEL_KEYS, ENABLED); + registry.registerDoubleCounter(METRIC_NAME_1, DESCRIPTION_2, UNIT_2, REQUIRED_LABEL_KEYS, + OPTIONAL_LABEL_KEYS, DISABLED); + } + + @Test(expected = IllegalStateException.class) + public void registerLongCounterDuplicateName() { + registry.registerDoubleCounter(METRIC_NAME_1, DESCRIPTION_1, UNIT_1, REQUIRED_LABEL_KEYS, + OPTIONAL_LABEL_KEYS, ENABLED); + registry.registerLongCounter(METRIC_NAME_1, DESCRIPTION_2, UNIT_2, REQUIRED_LABEL_KEYS, + OPTIONAL_LABEL_KEYS, DISABLED); + } + + @Test(expected = IllegalStateException.class) + public void registerDoubleHistogramDuplicateName() { + registry.registerLongHistogram(METRIC_NAME_1, DESCRIPTION_1, UNIT_1, LONG_HISTOGRAM_BUCKETS, + REQUIRED_LABEL_KEYS, OPTIONAL_LABEL_KEYS, ENABLED); + registry.registerDoubleHistogram(METRIC_NAME_1, DESCRIPTION_2, UNIT_2, DOUBLE_HISTOGRAM_BUCKETS, + REQUIRED_LABEL_KEYS, OPTIONAL_LABEL_KEYS, DISABLED); + } + + @Test(expected = IllegalStateException.class) + public void registerLongHistogramDuplicateName() { + registry.registerLongCounter(METRIC_NAME_1, DESCRIPTION_1, UNIT_1, REQUIRED_LABEL_KEYS, + OPTIONAL_LABEL_KEYS, ENABLED); + registry.registerLongHistogram(METRIC_NAME_1, DESCRIPTION_2, UNIT_2, LONG_HISTOGRAM_BUCKETS, + REQUIRED_LABEL_KEYS, OPTIONAL_LABEL_KEYS, DISABLED); + } + + @Test(expected = IllegalStateException.class) + public void registerLongGaugeDuplicateName() { + registry.registerDoubleHistogram(METRIC_NAME_1, DESCRIPTION_1, UNIT_1, DOUBLE_HISTOGRAM_BUCKETS, + REQUIRED_LABEL_KEYS, OPTIONAL_LABEL_KEYS, ENABLED); + registry.registerLongGauge(METRIC_NAME_1, DESCRIPTION_2, UNIT_2, REQUIRED_LABEL_KEYS, + OPTIONAL_LABEL_KEYS, DISABLED); + } + + @Test + public void getMetricInstrumentsMultipleRegistered() { + DoubleCounterMetricInstrument instrument1 = registry.registerDoubleCounter( + "testMetric1", DESCRIPTION_1, UNIT_1, REQUIRED_LABEL_KEYS, OPTIONAL_LABEL_KEYS, ENABLED); + LongCounterMetricInstrument instrument2 = registry.registerLongCounter( + "testMetric2", DESCRIPTION_2, UNIT_2, REQUIRED_LABEL_KEYS, OPTIONAL_LABEL_KEYS, DISABLED); + DoubleHistogramMetricInstrument instrument3 = registry.registerDoubleHistogram( + "testMetric3", DESCRIPTION_2, UNIT_2, DOUBLE_HISTOGRAM_BUCKETS, REQUIRED_LABEL_KEYS, + OPTIONAL_LABEL_KEYS, DISABLED); + + List instruments = registry.getMetricInstruments(); + assertThat(instruments.size()).isEqualTo(3); + assertThat(instruments.contains(instrument1)).isTrue(); + assertThat(instruments.contains(instrument2)).isTrue(); + assertThat(instruments.contains(instrument3)).isTrue(); + } + + @Test + public void resizeMetricInstrumentsCapacityIncrease() { + int initialCapacity = INITIAL_INSTRUMENT_CAPACITY; + MetricInstrumentRegistry testRegistry = new MetricInstrumentRegistry(); + + // Registering enough instruments to trigger resize + for (int i = 0; i < initialCapacity + 1; i++) { + testRegistry.registerLongHistogram("name" + i, "desc", "unit", ImmutableList.of(), + ImmutableList.of(), ImmutableList.of(), true); + } + + assertThat(testRegistry.getMetricInstruments().size()).isGreaterThan(initialCapacity); + } + +} diff --git a/api/src/test/java/io/grpc/NameResolverRegistryTest.java b/api/src/test/java/io/grpc/NameResolverRegistryTest.java index 2fd23e3a974..76976c3b59b 100644 --- a/api/src/test/java/io/grpc/NameResolverRegistryTest.java +++ b/api/src/test/java/io/grpc/NameResolverRegistryTest.java @@ -33,7 +33,8 @@ /** Unit tests for {@link NameResolverRegistry}. */ @RunWith(JUnit4.class) public class NameResolverRegistryTest { - private final URI uri = URI.create("dns:///localhost"); + private final URI javaNetUri = URI.create("dns:///localhost"); + private final Uri ioGrpcUri = Uri.create("dns:///localhost"); private final NameResolver.Args args = NameResolver.Args.newBuilder() .setDefaultPort(8080) .setProxyDetector(mock(ProxyDetector.class)) @@ -96,43 +97,80 @@ public void getDefaultScheme_noProvider() { } @Test - public void newNameResolver_providerReturnsNull() { + public void newNameResolver_providerReturnsNull_ioGrpcUri() { NameResolverRegistry registry = new NameResolverRegistry(); registry.register( - new BaseProvider(true, 5, "noScheme") { + new BaseProvider(true, 5, ioGrpcUri.getScheme()) { @Override - public NameResolver newNameResolver(URI passedUri, NameResolver.Args passedArgs) { - assertThat(passedUri).isSameInstanceAs(uri); + public NameResolver newNameResolver(Uri passedUri, NameResolver.Args passedArgs) { + assertThat(passedUri).isSameInstanceAs(ioGrpcUri); assertThat(passedArgs).isSameInstanceAs(args); return null; } }); - assertThat(registry.asFactory().newNameResolver(uri, args)).isNull(); - assertThat(registry.asFactory().getDefaultScheme()).isEqualTo("noScheme"); + assertThat(registry.asFactory().newNameResolver(ioGrpcUri, args)).isNull(); + assertThat(registry.asFactory().getDefaultScheme()).isEqualTo(ioGrpcUri.getScheme()); } @Test - public void newNameResolver_providerReturnsNonNull() { + public void newNameResolver_providerReturnsNull_javaNetUri() { NameResolverRegistry registry = new NameResolverRegistry(); - registry.register(new BaseProvider(true, 5, uri.getScheme()) { - @Override - public NameResolver newNameResolver(URI passedUri, NameResolver.Args passedArgs) { - return null; - } - }); - final NameResolver nr = new NameResolver() { - @Override public String getServiceAuthority() { - throw new UnsupportedOperationException(); - } + registry.register( + new BaseProvider(true, 5, javaNetUri.getScheme()) { + @Override + public NameResolver newNameResolver(URI passedUri, NameResolver.Args passedArgs) { + assertThat(passedUri).isSameInstanceAs(javaNetUri); + assertThat(passedArgs).isSameInstanceAs(args); + return null; + } + }); + assertThat(registry.asFactory().newNameResolver(javaNetUri, args)).isNull(); + assertThat(registry.asFactory().getDefaultScheme()).isEqualTo(javaNetUri.getScheme()); + } - @Override public void start(Listener2 listener) { - throw new UnsupportedOperationException(); - } + @Test + public void newNameResolver_providerReturnsNonNull_ioGrpcUri() { + NameResolverRegistry registry = new NameResolverRegistry(); + Uri uri = ioGrpcUri; + registry.register( + new BaseProvider(true, 5, uri.getScheme()) { + @Override + public NameResolver newNameResolver(Uri passedUri, NameResolver.Args passedArgs) { + return null; + } + }); + final NameResolver nr = new DummyNameResolver(); + registry.register( + new BaseProvider(true, 4, uri.getScheme()) { + @Override + public NameResolver newNameResolver(Uri passedUri, NameResolver.Args passedArgs) { + return nr; + } + }); + registry.register( + new BaseProvider(true, 3, uri.getScheme()) { + @Override + public NameResolver newNameResolver(Uri passedUri, NameResolver.Args passedArgs) { + fail("Should not be called"); + throw new AssertionError(); + } + }); + assertThat(registry.asFactory().newNameResolver(uri, args)).isNull(); + assertThat(registry.asFactory().getDefaultScheme()).isEqualTo(uri.getScheme()); + } - @Override public void shutdown() { - throw new UnsupportedOperationException(); - } - }; + @Test + public void newNameResolver_providerReturnsNonNull_javaNetUri() { + NameResolverRegistry registry = new NameResolverRegistry(); + URI uri = javaNetUri; + registry.register( + new BaseProvider(true, 5, uri.getScheme()) { + @Override + public NameResolver newNameResolver(URI passedUri, NameResolver.Args passedArgs) { + return null; + } + }); + final NameResolver nr = new DummyNameResolver(); registry.register( new BaseProvider(true, 4, uri.getScheme()) { @Override @@ -153,27 +191,45 @@ public NameResolver newNameResolver(URI passedUri, NameResolver.Args passedArgs) } @Test - public void newNameResolver_multipleScheme() { + public void newNameResolver_multipleScheme_ioGrpcUri() { NameResolverRegistry registry = new NameResolverRegistry(); - registry.register(new BaseProvider(true, 5, uri.getScheme()) { - @Override - public NameResolver newNameResolver(URI passedUri, NameResolver.Args passedArgs) { - return null; - } - }); - final NameResolver nr = new NameResolver() { - @Override public String getServiceAuthority() { - throw new UnsupportedOperationException(); - } + Uri uri = ioGrpcUri; + registry.register( + new BaseProvider(true, 5, uri.getScheme()) { + @Override + public NameResolver newNameResolver(Uri passedUri, NameResolver.Args passedArgs) { + return null; + } + }); + final NameResolver nr = new DummyNameResolver(); + registry.register( + new BaseProvider(true, 4, "other") { + @Override + public NameResolver newNameResolver(Uri passedUri, NameResolver.Args passedArgs) { + return nr; + } + }); - @Override public void start(Listener2 listener) { - throw new UnsupportedOperationException(); - } + assertThat(registry.asFactory().newNameResolver(uri, args)).isNull(); + assertThat(registry.asFactory().newNameResolver(Uri.create("other:///0.0.0.0:80"), args)) + .isSameInstanceAs(nr); + assertThat(registry.asFactory().newNameResolver(Uri.create("OTHER:///0.0.0.0:80"), args)) + .isSameInstanceAs(nr); + assertThat(registry.asFactory().getDefaultScheme()).isEqualTo("dns"); + } - @Override public void shutdown() { - throw new UnsupportedOperationException(); - } - }; + @Test + public void newNameResolver_multipleScheme_javaNetUri() { + NameResolverRegistry registry = new NameResolverRegistry(); + URI uri = javaNetUri; + registry.register( + new BaseProvider(true, 5, uri.getScheme()) { + @Override + public NameResolver newNameResolver(URI passedUri, NameResolver.Args passedArgs) { + return null; + } + }); + final NameResolver nr = new DummyNameResolver(); registry.register( new BaseProvider(true, 4, "other") { @Override @@ -186,16 +242,17 @@ public NameResolver newNameResolver(URI passedUri, NameResolver.Args passedArgs) assertThat(registry.asFactory().newNameResolver(URI.create("/0.0.0.0:80"), args)).isNull(); assertThat(registry.asFactory().newNameResolver(URI.create("///0.0.0.0:80"), args)).isNull(); assertThat(registry.asFactory().newNameResolver(URI.create("other:///0.0.0.0:80"), args)) - .isSameInstanceAs(nr); + .isSameInstanceAs(nr); assertThat(registry.asFactory().newNameResolver(URI.create("OTHER:///0.0.0.0:80"), args)) - .isSameInstanceAs(nr); + .isSameInstanceAs(nr); assertThat(registry.asFactory().getDefaultScheme()).isEqualTo("dns"); } @Test public void newNameResolver_noProvider() { NameResolver.Factory factory = new NameResolverRegistry().asFactory(); - assertThat(factory.newNameResolver(uri, args)).isNull(); + assertThat(factory.newNameResolver(javaNetUri, args)).isNull(); + assertThat(factory.newNameResolver(ioGrpcUri, args)).isNull(); assertThat(factory.getDefaultScheme()).isEqualTo("unknown"); } @@ -261,9 +318,31 @@ public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) { throw new UnsupportedOperationException(); } + @Override + public NameResolver newNameResolver(Uri targetUri, NameResolver.Args args) { + throw new UnsupportedOperationException(); + } + @Override public String getDefaultScheme() { return scheme == null ? "scheme" + getClass().getSimpleName() : scheme; } } + + private static class DummyNameResolver extends NameResolver { + @Override + public String getServiceAuthority() { + throw new UnsupportedOperationException(); + } + + @Override + public void start(Listener2 listener) { + throw new UnsupportedOperationException(); + } + + @Override + public void shutdown() { + throw new UnsupportedOperationException(); + } + } } diff --git a/api/src/test/java/io/grpc/NameResolverTest.java b/api/src/test/java/io/grpc/NameResolverTest.java index f825de354af..82abe5c7505 100644 --- a/api/src/test/java/io/grpc/NameResolverTest.java +++ b/api/src/test/java/io/grpc/NameResolverTest.java @@ -17,20 +17,47 @@ package io.grpc; import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import com.google.common.base.Objects; +import io.grpc.NameResolver.ConfigOrError; +import io.grpc.NameResolver.Listener2; +import io.grpc.NameResolver.ResolutionResult; import io.grpc.NameResolver.ServiceConfigParser; import java.lang.Thread.UncaughtExceptionHandler; +import java.net.SocketAddress; +import java.util.Collections; +import java.util.List; import java.util.concurrent.Executor; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; +import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; /** Unit tests for the inner classes in {@link NameResolver}. */ @RunWith(JUnit4.class) public class NameResolverTest { + private static final List ADDRESSES = + Collections.singletonList( + new EquivalentAddressGroup(new FakeSocketAddress("fake-address-1"), Attributes.EMPTY)); + private static final Attributes.Key YOLO_ATTR_KEY = Attributes.Key.create("yolo"); + private static Attributes ATTRIBUTES = + Attributes.newBuilder().set(YOLO_ATTR_KEY, "To be, or not to be?").build(); + private static final NameResolver.Args.Key FOO_ARG_KEY = + NameResolver.Args.Key.create("foo"); + private static final NameResolver.Args.Key BAR_ARG_KEY = + NameResolver.Args.Key.create("bar"); + private static ConfigOrError CONFIG = ConfigOrError.fromConfig("foo"); + + @Rule + public final MockitoRule mocks = MockitoJUnit.rule(); private final int defaultPort = 293; private final ProxyDetector proxyDetector = mock(ProxyDetector.class); private final SynchronizationContext syncContext = @@ -41,6 +68,9 @@ public class NameResolverTest { private final ChannelLogger channelLogger = mock(ChannelLogger.class); private final Executor executor = Executors.newSingleThreadExecutor(); private final String overrideAuthority = "grpc.io"; + private final MetricRecorder metricRecorder = new MetricRecorder() {}; + private final int customArgValue = 42; + @Mock NameResolver.Listener mockListener; @Test public void args() { @@ -53,6 +83,9 @@ public void args() { assertThat(args.getChannelLogger()).isSameInstanceAs(channelLogger); assertThat(args.getOffloadExecutor()).isSameInstanceAs(executor); assertThat(args.getOverrideAuthority()).isSameInstanceAs(overrideAuthority); + assertThat(args.getMetricRecorder()).isSameInstanceAs(metricRecorder); + assertThat(args.getArg(FOO_ARG_KEY)).isEqualTo(customArgValue); + assertThat(args.getArg(BAR_ARG_KEY)).isNull(); NameResolver.Args args2 = args.toBuilder().build(); assertThat(args2.getDefaultPort()).isEqualTo(defaultPort); @@ -63,6 +96,9 @@ public void args() { assertThat(args2.getChannelLogger()).isSameInstanceAs(channelLogger); assertThat(args2.getOffloadExecutor()).isSameInstanceAs(executor); assertThat(args2.getOverrideAuthority()).isSameInstanceAs(overrideAuthority); + assertThat(args.getMetricRecorder()).isSameInstanceAs(metricRecorder); + assertThat(args.getArg(FOO_ARG_KEY)).isEqualTo(customArgValue); + assertThat(args.getArg(BAR_ARG_KEY)).isNull(); assertThat(args2).isNotSameInstanceAs(args); assertThat(args2).isNotEqualTo(args); @@ -78,6 +114,144 @@ private NameResolver.Args createArgs() { .setChannelLogger(channelLogger) .setOffloadExecutor(executor) .setOverrideAuthority(overrideAuthority) + .setMetricRecorder(metricRecorder) + .setArg(FOO_ARG_KEY, customArgValue) + .build(); + } + + @Test + @SuppressWarnings("deprecation") + public void startOnOldListener_wrapperListener2UsedToStart() { + final Listener2[] listener2 = new Listener2[1]; + NameResolver nameResolver = new NameResolver() { + @Override + public String getServiceAuthority() { + return null; + } + + @Override + public void shutdown() {} + + @Override + public void start(Listener2 listener2Arg) { + listener2[0] = listener2Arg; + } + }; + nameResolver.start(mockListener); + + listener2[0].onResult(ResolutionResult.newBuilder().setAddresses(ADDRESSES) + .setAttributes(ATTRIBUTES).build()); + verify(mockListener).onAddresses(eq(ADDRESSES), eq(ATTRIBUTES)); + listener2[0].onError(Status.CANCELLED); + verify(mockListener).onError(Status.CANCELLED); + } + + @Test + @SuppressWarnings({"deprecation", "InlineMeInliner"}) + public void listener2AddressesToListener2ResolutionResultConversion() { + final ResolutionResult[] resolutionResult = new ResolutionResult[1]; + NameResolver.Listener2 listener2 = new Listener2() { + @Override + public void onResult(ResolutionResult resolutionResultArg) { + resolutionResult[0] = resolutionResultArg; + } + + @Override + public void onError(Status error) {} + }; + + listener2.onAddresses(ADDRESSES, ATTRIBUTES); + + assertThat(resolutionResult[0].getAddressesOrError().getValue()).isEqualTo(ADDRESSES); + assertThat(resolutionResult[0].getAttributes()).isEqualTo(ATTRIBUTES); + } + + @Test + public void resolutionResult_toString_addressesAttributesAndConfig() { + ResolutionResult resolutionResult = ResolutionResult.newBuilder() + .setAddressesOrError(StatusOr.fromValue(ADDRESSES)) + .setAttributes(ATTRIBUTES) + .setServiceConfig(CONFIG) + .build(); + + assertThat(resolutionResult.toString()).isEqualTo( + "ResolutionResult{addressesOrError=StatusOr{value=" + + "[[[FakeSocketAddress-fake-address-1]/{}]]}, attributes={yolo=To be, or not to be?}, " + + "serviceConfigOrError=ConfigOrError{config=foo}}"); + } + + @Test + public void resolutionResult_hashCode() { + ResolutionResult resolutionResult = ResolutionResult.newBuilder() + .setAddressesOrError(StatusOr.fromValue(ADDRESSES)) + .setAttributes(ATTRIBUTES) + .setServiceConfig(CONFIG) .build(); + + assertThat(resolutionResult.hashCode()).isEqualTo( + Objects.hashCode(StatusOr.fromValue(ADDRESSES), ATTRIBUTES, CONFIG)); + } + + @Test + public void startOnOldListener_resolverReportsError() { + final boolean[] onErrorCalled = new boolean[1]; + final Status[] receivedError = new Status[1]; + + NameResolver resolver = new NameResolver() { + @Override + public String getServiceAuthority() { + return "example.com"; + } + + @Override + public void shutdown() { + } + + @Override + public void start(Listener2 listener2) { + ResolutionResult errorResult = ResolutionResult.newBuilder() + .setAddressesOrError(StatusOr.fromStatus( + Status.UNAVAILABLE + .withDescription("DNS resolution failed with UNAVAILABLE"))) + .build(); + + listener2.onResult(errorResult); + } + }; + + NameResolver.Listener listener = new NameResolver.Listener() { + @Override + public void onAddresses( + List servers, + Attributes attributes) { + throw new AssertionError("Called onAddresses on error"); + } + + @Override + public void onError(Status error) { + onErrorCalled[0] = true; + receivedError[0] = error; + } + }; + + resolver.start(listener); + + assertThat(onErrorCalled[0]).isTrue(); + assertThat(receivedError[0].getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(receivedError[0].getDescription()).isEqualTo( + "DNS resolution failed with UNAVAILABLE"); + } + + private static class FakeSocketAddress extends SocketAddress { + final String name; + + FakeSocketAddress(String name) { + this.name = name; + } + + @Override + public String toString() { + return "FakeSocketAddress-" + name; + } } } diff --git a/api/src/test/java/io/grpc/QueryParamsTest.java b/api/src/test/java/io/grpc/QueryParamsTest.java new file mode 100644 index 00000000000..2def165a170 --- /dev/null +++ b/api/src/test/java/io/grpc/QueryParamsTest.java @@ -0,0 +1,274 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import io.grpc.QueryParams.Entry; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link QueryParams}. */ +@RunWith(JUnit4.class) +public class QueryParamsTest { + + @Test + public void emptyInstance() { + QueryParams params = new QueryParams(); + assertThat(params.asList()).isEmpty(); + assertThat(params.toRawQuery()).isNull(); + } + + @Test + public void parseNull_yieldsEmptyInstance() { + QueryParams params = QueryParams.fromRawQuery(null); + assertThat(params.asList()).isEmpty(); + assertThat(params.toRawQuery()).isNull(); + } + + @Test + public void parseEmptyString_yieldsSingleLoneKey() { + QueryParams params = QueryParams.fromRawQuery(""); + assertThat(params.toRawQuery()).isEmpty(); + assertThat(params.asList()).isNotEmpty(); + Entry entry = params.asList().get(0); + assertThat(entry).isNotNull(); + assertThat(entry.getKey()).isEmpty(); + assertThat(entry.hasValue()).isFalse(); + assertThat(entry.getValue()).isNull(); + } + + @Test + public void parseNormalPairs() { + QueryParams params = QueryParams.fromRawQuery("a=b&c=d"); + assertThat(params.toRawQuery()).isEqualTo("a=b&c=d"); + + QueryParams.Entry a = params.asList().get(0); + assertThat(a.getKey()).isEqualTo("a"); + assertThat(a.hasValue()).isTrue(); + assertThat(a.getValue()).isEqualTo("b"); + + QueryParams.Entry c = params.asList().get(1); + assertThat(c.getKey()).isEqualTo("c"); + assertThat(c.getValue()).isEqualTo("d"); + } + + @Test + public void parseLoneKey() { + QueryParams params = QueryParams.fromRawQuery("a&b"); + assertThat(params.toRawQuery()).isEqualTo("a&b"); + + QueryParams.Entry a = params.asList().get(0); + assertThat(a.getKey()).isEqualTo("a"); + assertThat(a.hasValue()).isFalse(); + + QueryParams.Entry b = params.asList().get(1); + assertThat(b.getKey()).isEqualTo("b"); + assertThat(b.hasValue()).isFalse(); + } + + @Test + public void parseEmptyKeysAndValues() { + QueryParams params = QueryParams.fromRawQuery("=&="); + assertThat(params.toRawQuery()).isEqualTo("=&="); + + assertThat(params.asList()).hasSize(2); + assertThat(params.asList().get(0).getKey()).isEmpty(); + assertThat(params.asList().get(0).hasValue()).isTrue(); + assertThat(params.asList().get(0).getValue()).isEmpty(); + assertThat(params.asList().get(1).getKey()).isEmpty(); + assertThat(params.asList().get(1).hasValue()).isTrue(); + assertThat(params.asList().get(1).getValue()).isEmpty(); + } + + @Test + public void roundTripPreservesEncodingOfSpaces() { + // Spaces can be encoded as + or %20. + QueryParams params = QueryParams.fromRawQuery("a+b=c%20d"); + assertThat(params.asList().get(0).getKey()).isEqualTo("a b"); + assertThat(params.asList().get(0).getValue()).isEqualTo("c d"); + assertThat(params.toRawQuery()).isEqualTo("a+b=c%20d"); + } + + @Test + public void roundTripPreservesCaseOfHexDigits() { + // Percent encoding can use upper or lower case. + QueryParams params = QueryParams.fromRawQuery("%4A%4a=%4B%4b"); + assertThat(params.asList().get(0).getKey()).isEqualTo("JJ"); + assertThat(params.asList().get(0).getValue()).isEqualTo("KK"); + assertThat(params.toRawQuery()).isEqualTo("%4A%4a=%4B%4b"); + } + + @Test + public void asListMethod() { + QueryParams params = new QueryParams(); + params.asList().add(QueryParams.Entry.forKeyValue("a b", "c d")); + params.asList().add(QueryParams.Entry.forLoneKey("e f")); + + // URLEncoder encodes spaces as + + assertThat(params.toRawQuery()).isEqualTo("a+b=c+d&e+f"); + } + + @Test + public void parseInvalidPercentEncodingThrows() { + assertThrows(IllegalArgumentException.class, () -> QueryParams.fromRawQuery("a=%GH")); + } + + @Test + public void parseInvalidKeyValueEncodingSucceeds() { + QueryParams params = QueryParams.fromRawQuery("===="); + assertThat(params.asList()) + .containsExactly(Entry.forRawKeyValue("", "===")) + .inOrder(); + assertThat(params.toRawQuery()).isEqualTo("===="); + } + + @Test + public void uriIntegration_canBuild() { + QueryParams params = new QueryParams(); + params.asList().add(Entry.forKeyValue("a", "b")); + params.asList().add(Entry.forKeyValue("c", "d")); + + Uri uri = + Uri.newBuilder() + .setScheme("http") + .setHost("example.com") + .setRawQuery(params.toRawQuery()) + .build(); + + assertThat(uri.toString()).isEqualTo("http://example.com?a=b&c=d"); + assertThat(uri.getRawQuery()).isEqualTo("a=b&c=d"); + } + + @Test + public void uriIntegration_canBuildEmpty() { + QueryParams params = new QueryParams(); + Uri uri = + Uri.newBuilder() + .setScheme("http") + .setHost("example.com") + .setRawQuery(params.toRawQuery()) + .build(); + + assertThat(uri.toString()).isEqualTo("http://example.com"); + assertThat(uri.getRawQuery()).isNull(); + } + + @Test + public void uriIntegration_canParse() { + Uri uri = Uri.create("http://example.com?a=b&c=d&e"); + QueryParams params = QueryParams.fromRawQuery(uri.getRawQuery()); + + assertThat(params.asList()) + .containsExactly( + Entry.forKeyValue("a", "b"), Entry.forKeyValue("c", "d"), Entry.forLoneKey("e")) + .inOrder(); + } + + @Test + public void keysAndValuesWithCharactersNeedingUrlEncoding() { + QueryParams params = new QueryParams(); + params.asList().add(Entry.forKeyValue("a=b", "c&d")); + params.asList().add(Entry.forKeyValue("e+f", "g h")); + + assertThat(params.toRawQuery()).isEqualTo("a%3Db=c%26d&e%2Bf=g+h"); + + QueryParams roundTripped = QueryParams.fromRawQuery(params.toRawQuery()); + assertThat(roundTripped).isEqualTo(params); + } + + @Test + public void keysAndValuesWithCodePointsOutsideAsciiRange() { + QueryParams params = new QueryParams(); + params.asList().add(Entry.forKeyValue("€", "𐐷")); + + assertThat(params.toRawQuery()).isEqualTo("%E2%82%AC=%F0%90%90%B7"); + + QueryParams roundTripped = QueryParams.fromRawQuery(params.toRawQuery()); + assertThat(roundTripped).isEqualTo(params); + } + + @Test + public void toStringMethod() { + QueryParams params = new QueryParams(); + assertThat(params.toString()).isEqualTo("[]"); + + params.asList().add(Entry.forKeyValue("a", "b")); + assertThat(params.toString()).isEqualTo("[a=b]"); + + params.asList().add(Entry.forLoneKey("c")); + assertThat(params.toString()).isEqualTo("[a=b, c]"); + + params.asList().add(Entry.forKeyValue("d=e", "f&g")); + assertThat(params.toString()).isEqualTo("[a=b, c, d%3De=f%26g]"); + } + + @Test + public void entryProperties() { + Entry keyValue = Entry.forKeyValue("key", "val"); + assertThat(keyValue.getKey()).isEqualTo("key"); + assertThat(keyValue.getValue()).isEqualTo("val"); + assertThat(keyValue.hasValue()).isTrue(); + + Entry loneKey = Entry.forLoneKey("key"); + assertThat(loneKey.getKey()).isEqualTo("key"); + assertThat(loneKey.getValue()).isNull(); + assertThat(loneKey.hasValue()).isFalse(); + } + + @Test + public void equalsAndHashCode_container() { + QueryParams params1 = new QueryParams(); + QueryParams params2 = new QueryParams(); + + // Empty instances are equal + assertThat(params1).isEqualTo(params2); + assertThat(params1.hashCode()).isEqualTo(params2.hashCode()); + + params1.asList().add(Entry.forKeyValue("a", "b")); + params1.asList().add(Entry.forLoneKey("c")); + + params2.asList().add(Entry.forKeyValue("a", "b")); + params2.asList().add(Entry.forLoneKey("c")); + + // Identical parameters in identical order are equal + assertThat(params1).isEqualTo(params2); + assertThat(params1.hashCode()).isEqualTo(params2.hashCode()); + + // Order matters. + QueryParams params3 = new QueryParams(); + params3.asList().add(Entry.forLoneKey("c")); + params3.asList().add(Entry.forKeyValue("a", "b")); + assertThat(params1).isNotEqualTo(params3); + } + + @Test + public void equalsAndHashCode_entry() { + // Raw matches are equal. + assertThat(Entry.forRawKeyValue("a+b", "c")).isEqualTo(Entry.forRawKeyValue("a+b", "c")); + assertThat(Entry.forRawKeyValue("a+b", "c").hashCode()) + .isEqualTo(Entry.forRawKeyValue("a+b", "c").hashCode()); + + // Spaces encoding matters. + and %20 are not equal. + assertThat(Entry.forRawKeyValue("a+b", "c")).isNotEqualTo(Entry.forRawKeyValue("a%20b", "c")); + + // Case of hex digits matter: %4A vs %4a are not equal raw keys. + assertThat(Entry.forRawKeyValue("a", "%4A")).isNotEqualTo(Entry.forRawKeyValue("a", "%4a")); + } +} diff --git a/api/src/test/java/io/grpc/ServerInterceptorsTest.java b/api/src/test/java/io/grpc/ServerInterceptorsTest.java index abfb3540fe4..b84b3838afa 100644 --- a/api/src/test/java/io/grpc/ServerInterceptorsTest.java +++ b/api/src/test/java/io/grpc/ServerInterceptorsTest.java @@ -19,6 +19,7 @@ import static com.google.common.collect.Iterables.getOnlyElement; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.same; @@ -40,7 +41,6 @@ import org.junit.Before; import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentMatchers; @@ -55,10 +55,6 @@ public class ServerInterceptorsTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); - @Mock private Marshaller requestMarshaller; @@ -111,21 +107,21 @@ public void makeSureExpectedMocksUnused() { public void npeForNullServiceDefinition() { ServerServiceDefinition serviceDef = null; List interceptors = Arrays.asList(); - thrown.expect(NullPointerException.class); - ServerInterceptors.intercept(serviceDef, interceptors); + assertThrows(NullPointerException.class, + () -> ServerInterceptors.intercept(serviceDef, interceptors)); } @Test public void npeForNullInterceptorList() { - thrown.expect(NullPointerException.class); - ServerInterceptors.intercept(serviceDefinition, (List) null); + assertThrows(NullPointerException.class, + () -> ServerInterceptors.intercept(serviceDefinition, (List) null)); } @Test public void npeForNullInterceptor() { List interceptors = Arrays.asList((ServerInterceptor) null); - thrown.expect(NullPointerException.class); - ServerInterceptors.intercept(serviceDefinition, interceptors); + assertThrows(NullPointerException.class, + () -> ServerInterceptors.intercept(serviceDefinition, interceptors)); } @Test diff --git a/api/src/test/java/io/grpc/ServerServiceDefinitionTest.java b/api/src/test/java/io/grpc/ServerServiceDefinitionTest.java index 6a84d640d78..9e43302e210 100644 --- a/api/src/test/java/io/grpc/ServerServiceDefinitionTest.java +++ b/api/src/test/java/io/grpc/ServerServiceDefinitionTest.java @@ -18,14 +18,13 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.fail; import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -52,9 +51,6 @@ public class ServerServiceDefinitionTest { = ServerMethodDefinition.create(method1, methodHandler1); private ServerMethodDefinition methodDef2 = ServerMethodDefinition.create(method2, methodHandler2); - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public ExpectedException thrown = ExpectedException.none(); @Test public void noMethods() { @@ -91,9 +87,7 @@ public void addMethod_duplicateName() { ServiceDescriptor sd = new ServiceDescriptor(serviceName, method1); ServerServiceDefinition.Builder ssd = ServerServiceDefinition.builder(sd) .addMethod(method1, methodHandler1); - thrown.expect(IllegalStateException.class); - ssd.addMethod(diffMethod1, methodHandler2) - .build(); + assertThrows(IllegalStateException.class, () -> ssd.addMethod(diffMethod1, methodHandler2)); } @Test @@ -101,8 +95,7 @@ public void buildMisaligned_extraMethod() { ServiceDescriptor sd = new ServiceDescriptor(serviceName); ServerServiceDefinition.Builder ssd = ServerServiceDefinition.builder(sd) .addMethod(methodDef1); - thrown.expect(IllegalStateException.class); - ssd.build(); + assertThrows(IllegalStateException.class, ssd::build); } @Test @@ -110,16 +103,14 @@ public void buildMisaligned_diffMethodInstance() { ServiceDescriptor sd = new ServiceDescriptor(serviceName, method1); ServerServiceDefinition.Builder ssd = ServerServiceDefinition.builder(sd) .addMethod(diffMethod1, methodHandler1); - thrown.expect(IllegalStateException.class); - ssd.build(); + assertThrows(IllegalStateException.class, ssd::build); } @Test public void buildMisaligned_missingMethod() { ServiceDescriptor sd = new ServiceDescriptor(serviceName, method1); ServerServiceDefinition.Builder ssd = ServerServiceDefinition.builder(sd); - thrown.expect(IllegalStateException.class); - ssd.build(); + assertThrows(IllegalStateException.class, ssd::build); } @Test diff --git a/api/src/test/java/io/grpc/ServiceDescriptorTest.java b/api/src/test/java/io/grpc/ServiceDescriptorTest.java index a05858680d5..89bdead3632 100644 --- a/api/src/test/java/io/grpc/ServiceDescriptorTest.java +++ b/api/src/test/java/io/grpc/ServiceDescriptorTest.java @@ -16,17 +16,18 @@ package io.grpc; +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; +import com.google.common.truth.StringSubject; import io.grpc.MethodDescriptor.MethodType; import io.grpc.testing.TestMethodDescriptors; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.List; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -36,32 +37,27 @@ @RunWith(JUnit4.class) public class ServiceDescriptorTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); - @Test public void failsOnNullName() { - thrown.expect(NullPointerException.class); - thrown.expectMessage("name"); - - new ServiceDescriptor(null, Collections.>emptyList()); + List> methods = Collections.emptyList(); + NullPointerException e = assertThrows(NullPointerException.class, + () -> new ServiceDescriptor(null, methods)); + assertThat(e).hasMessageThat().isEqualTo("name"); } @Test public void failsOnNullMethods() { - thrown.expect(NullPointerException.class); - thrown.expectMessage("methods"); - - new ServiceDescriptor("name", (Collection>) null); + NullPointerException e = assertThrows(NullPointerException.class, + () -> new ServiceDescriptor("name", (Collection>) null)); + assertThat(e).hasMessageThat().isEqualTo("methods"); } @Test public void failsOnNullMethod() { - thrown.expect(NullPointerException.class); - thrown.expectMessage("method"); - - new ServiceDescriptor("name", Collections.>singletonList(null)); + List> methods = Collections.singletonList(null); + NullPointerException e = assertThrows(NullPointerException.class, + () -> new ServiceDescriptor("name", methods)); + assertThat(e).hasMessageThat().isEqualTo("method"); } @Test @@ -69,15 +65,17 @@ public void failsOnNonMatchingNames() { List> descriptors = Collections.>singletonList( MethodDescriptor.newBuilder() .setType(MethodType.UNARY) - .setFullMethodName(MethodDescriptor.generateFullMethodName("wrongservice", "method")) + .setFullMethodName(MethodDescriptor.generateFullMethodName("wrongService", "method")) .setRequestMarshaller(TestMethodDescriptors.voidMarshaller()) .setResponseMarshaller(TestMethodDescriptors.voidMarshaller()) .build()); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("service names"); - - new ServiceDescriptor("name", descriptors); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> new ServiceDescriptor("fooService", descriptors)); + StringSubject error = assertThat(e).hasMessageThat(); + error.contains("service names"); + error.contains("fooService"); + error.contains("wrongService"); } @Test @@ -96,10 +94,9 @@ public void failsOnNonDuplicateNames() { .setResponseMarshaller(TestMethodDescriptors.voidMarshaller()) .build()); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("duplicate"); - - new ServiceDescriptor("name", descriptors); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> new ServiceDescriptor("name", descriptors)); + assertThat(e).hasMessageThat().isEqualTo("duplicate name name/method"); } @Test diff --git a/api/src/test/java/io/grpc/ServiceProvidersTest.java b/api/src/test/java/io/grpc/ServiceProvidersTest.java index 7d4388a5bb9..f971ed42646 100644 --- a/api/src/test/java/io/grpc/ServiceProvidersTest.java +++ b/api/src/test/java/io/grpc/ServiceProvidersTest.java @@ -16,6 +16,7 @@ package io.grpc; +import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; @@ -23,12 +24,15 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import com.google.common.base.Supplier; import com.google.common.collect.ImmutableList; import io.grpc.InternalServiceProviders.PriorityAccessor; import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.ServiceConfigurationError; +import java.util.ServiceLoader; +import org.junit.After; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -36,7 +40,6 @@ /** Unit tests for {@link ServiceProviders}. */ @RunWith(JUnit4.class) public class ServiceProvidersTest { - private static final List> NO_HARDCODED = Collections.emptyList(); private static final PriorityAccessor ACCESSOR = new PriorityAccessor() { @Override @@ -51,6 +54,19 @@ public int getPriority(ServiceProvidersTestAbstractProvider provider) { }; private final String serviceFile = "META-INF/services/io.grpc.ServiceProvidersTestAbstractProvider"; + private boolean failingHardCodedAccessed; + private final Supplier>> failingHardCoded = new Supplier>>() { + @Override + public Iterable> get() { + failingHardCodedAccessed = true; + throw new AssertionError(); + } + }; + + @After + public void tearDown() { + assertThat(failingHardCodedAccessed).isFalse(); + } @Test public void contextClassLoaderProvider() { @@ -69,8 +85,8 @@ public void contextClassLoaderProvider() { Thread.currentThread().setContextClassLoader(rcll); assertEquals( Available7Provider.class, - ServiceProviders.load( - ServiceProvidersTestAbstractProvider.class, NO_HARDCODED, cl, ACCESSOR).getClass()); + load(ServiceProvidersTestAbstractProvider.class, failingHardCoded, cl, ACCESSOR) + .getClass()); } finally { Thread.currentThread().setContextClassLoader(ccl); } @@ -85,8 +101,7 @@ public void noProvider() { serviceFile, "io/grpc/ServiceProvidersTestAbstractProvider-doesNotExist.txt"); Thread.currentThread().setContextClassLoader(cl); - assertNull(ServiceProviders.load( - ServiceProvidersTestAbstractProvider.class, NO_HARDCODED, cl, ACCESSOR)); + assertNull(load(ServiceProvidersTestAbstractProvider.class, failingHardCoded, cl, ACCESSOR)); } finally { Thread.currentThread().setContextClassLoader(ccl); } @@ -98,11 +113,11 @@ public void multipleProvider() throws Exception { "io/grpc/ServiceProvidersTestAbstractProvider-multipleProvider.txt"); assertSame( Available7Provider.class, - ServiceProviders.load( - ServiceProvidersTestAbstractProvider.class, NO_HARDCODED, cl, ACCESSOR).getClass()); + load(ServiceProvidersTestAbstractProvider.class, failingHardCoded, cl, ACCESSOR) + .getClass()); - List providers = ServiceProviders.loadAll( - ServiceProvidersTestAbstractProvider.class, NO_HARDCODED, cl, ACCESSOR); + List providers = loadAll( + ServiceProvidersTestAbstractProvider.class, failingHardCoded, cl, ACCESSOR); assertEquals(3, providers.size()); assertEquals(Available7Provider.class, providers.get(0).getClass()); assertEquals(Available5Provider.class, providers.get(1).getClass()); @@ -116,8 +131,8 @@ public void unavailableProvider() { "io/grpc/ServiceProvidersTestAbstractProvider-unavailableProvider.txt"); assertEquals( Available7Provider.class, - ServiceProviders.load( - ServiceProvidersTestAbstractProvider.class, NO_HARDCODED, cl, ACCESSOR).getClass()); + load(ServiceProvidersTestAbstractProvider.class, failingHardCoded, cl, ACCESSOR) + .getClass()); } @Test @@ -125,8 +140,7 @@ public void unknownClassProvider() { ClassLoader cl = new ReplacingClassLoader(getClass().getClassLoader(), serviceFile, "io/grpc/ServiceProvidersTestAbstractProvider-unknownClassProvider.txt"); try { - ServiceProviders.load( - ServiceProvidersTestAbstractProvider.class, NO_HARDCODED, cl, ACCESSOR); + loadAll(ServiceProvidersTestAbstractProvider.class, failingHardCoded, cl, ACCESSOR); fail("Exception expected"); } catch (ServiceConfigurationError e) { // noop @@ -140,8 +154,7 @@ public void exceptionSurfacedToCaller_failAtInit() { try { // Even though there is a working provider, if any providers fail then we should fail // completely to avoid returning something unexpected. - ServiceProviders.load( - ServiceProvidersTestAbstractProvider.class, NO_HARDCODED, cl, ACCESSOR); + loadAll(ServiceProvidersTestAbstractProvider.class, failingHardCoded, cl, ACCESSOR); fail("Expected exception"); } catch (ServiceConfigurationError expected) { // noop @@ -154,8 +167,7 @@ public void exceptionSurfacedToCaller_failAtPriority() { "io/grpc/ServiceProvidersTestAbstractProvider-failAtPriorityProvider.txt"); try { // The exception should be surfaced to the caller - ServiceProviders.load( - ServiceProvidersTestAbstractProvider.class, NO_HARDCODED, cl, ACCESSOR); + loadAll(ServiceProvidersTestAbstractProvider.class, failingHardCoded, cl, ACCESSOR); fail("Expected exception"); } catch (FailAtPriorityProvider.PriorityException expected) { // noop @@ -168,8 +180,7 @@ public void exceptionSurfacedToCaller_failAtAvailable() { "io/grpc/ServiceProvidersTestAbstractProvider-failAtAvailableProvider.txt"); try { // The exception should be surfaced to the caller - ServiceProviders.load( - ServiceProvidersTestAbstractProvider.class, NO_HARDCODED, cl, ACCESSOR); + loadAll(ServiceProvidersTestAbstractProvider.class, failingHardCoded, cl, ACCESSOR); fail("Expected exception"); } catch (FailAtAvailableProvider.AvailableException expected) { // noop @@ -244,6 +255,30 @@ class RandomClass {} assertFalse(candidates.iterator().hasNext()); } + private static T load( + Class klass, + Supplier>> hardCoded, + ClassLoader cl, + PriorityAccessor priorityAccessor) { + List candidates = loadAll(klass, hardCoded, cl, priorityAccessor); + if (candidates.isEmpty()) { + return null; + } + return candidates.get(0); + } + + private static List loadAll( + Class klass, + Supplier>> hardCoded, + ClassLoader classLoader, + PriorityAccessor priorityAccessor) { + return ServiceProviders.loadAll( + klass, + ServiceLoader.load(klass, classLoader).iterator(), + hardCoded, + priorityAccessor); + } + private static class BaseProvider extends ServiceProvidersTestAbstractProvider { private final boolean isAvailable; private final int priority; diff --git a/api/src/test/java/io/grpc/StatusExceptionTest.java b/api/src/test/java/io/grpc/StatusExceptionTest.java index dd0d12dccda..410cfb2289a 100644 --- a/api/src/test/java/io/grpc/StatusExceptionTest.java +++ b/api/src/test/java/io/grpc/StatusExceptionTest.java @@ -28,14 +28,6 @@ @RunWith(JUnit4.class) public class StatusExceptionTest { - @Test - public void internalCtorRemovesStack() { - StackTraceElement[] trace = - new StatusException(Status.CANCELLED, null, false) {}.getStackTrace(); - - assertThat(trace).isEmpty(); - } - @Test public void normalCtorKeepsStack() { StackTraceElement[] trace = diff --git a/api/src/test/java/io/grpc/StatusOrTest.java b/api/src/test/java/io/grpc/StatusOrTest.java new file mode 100644 index 00000000000..f63a314a2bb --- /dev/null +++ b/api/src/test/java/io/grpc/StatusOrTest.java @@ -0,0 +1,81 @@ +/* + * Copyright 2015 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.fail; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link StatusOr}. **/ +@RunWith(JUnit4.class) +public class StatusOrTest { + + @Test + public void getValue_throwsIfNoValuePresent() { + try { + StatusOr.fromStatus(Status.ABORTED).getValue(); + + fail("Expected exception."); + } catch (IllegalStateException expected) { } + } + + @Test + @SuppressWarnings("TruthIncompatibleType") + public void equals_differentValueTypes() { + assertThat(StatusOr.fromValue(1)).isNotEqualTo(StatusOr.fromValue("1")); + } + + @Test + public void equals_differentValues() { + assertThat(StatusOr.fromValue(1)).isNotEqualTo(StatusOr.fromValue(2)); + } + + @Test + public void equals_sameValues() { + assertThat(StatusOr.fromValue(1)).isEqualTo(StatusOr.fromValue(1)); + } + + @Test + public void equals_differentStatuses() { + assertThat(StatusOr.fromStatus(Status.ABORTED)).isNotEqualTo( + StatusOr.fromStatus(Status.CANCELLED)); + } + + @Test + public void equals_sameStatuses() { + assertThat(StatusOr.fromStatus(Status.ABORTED)).isEqualTo(StatusOr.fromStatus(Status.ABORTED)); + } + + @Test + public void toString_value() { + assertThat(StatusOr.fromValue(1).toString()).isEqualTo("StatusOr{value=1}"); + } + + @Test + public void toString_nullValue() { + assertThat(StatusOr.fromValue(null).toString()).isEqualTo("StatusOr{value=null}"); + } + + @Test + public void toString_errorStatus() { + assertThat(StatusOr.fromStatus(Status.ABORTED).toString()).isEqualTo( + "StatusOr{error=Status{code=ABORTED, description=null, cause=null}}"); + } +} \ No newline at end of file diff --git a/api/src/test/java/io/grpc/StatusRuntimeExceptionTest.java b/api/src/test/java/io/grpc/StatusRuntimeExceptionTest.java index ab20c111254..d965ed86253 100644 --- a/api/src/test/java/io/grpc/StatusRuntimeExceptionTest.java +++ b/api/src/test/java/io/grpc/StatusRuntimeExceptionTest.java @@ -31,7 +31,7 @@ public class StatusRuntimeExceptionTest { @Test public void internalCtorRemovesStack() { StackTraceElement[] trace = - new StatusRuntimeException(Status.CANCELLED, null, false) {}.getStackTrace(); + new InternalStatusRuntimeException(Status.CANCELLED, null) {}.getStackTrace(); assertThat(trace).isEmpty(); } diff --git a/api/src/test/java/io/grpc/SynchronizationContextTest.java b/api/src/test/java/io/grpc/SynchronizationContextTest.java index 3d5e7fa42b9..668f5ae4d6d 100644 --- a/api/src/test/java/io/grpc/SynchronizationContextTest.java +++ b/api/src/test/java/io/grpc/SynchronizationContextTest.java @@ -27,6 +27,7 @@ import com.google.common.util.concurrent.testing.TestingExecutors; import io.grpc.SynchronizationContext.ScheduledHandle; +import java.time.Duration; import java.util.concurrent.BlockingQueue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.LinkedBlockingQueue; @@ -72,7 +73,7 @@ public void uncaughtException(Thread t, Throwable e) { @Mock private Runnable task3; - + @After public void tearDown() { assertThat(uncaughtErrors).isEmpty(); } @@ -246,6 +247,43 @@ public void schedule() { verify(task1).run(); } + @Test + @IgnoreJRERequirement + public void scheduleDuration() { + MockScheduledExecutorService executorService = new MockScheduledExecutorService(); + ScheduledHandle handle = + syncContext.schedule(task1, Duration.ofSeconds(10), executorService); + + assertThat(executorService.delay) + .isEqualTo(executorService.unit.convert(10, TimeUnit.SECONDS)); + assertThat(handle.isPending()).isTrue(); + verify(task1, never()).run(); + + executorService.command.run(); + + assertThat(handle.isPending()).isFalse(); + verify(task1).run(); + } + + @Test + @IgnoreJRERequirement + public void scheduleWithFixedDelayDuration() { + MockScheduledExecutorService executorService = new MockScheduledExecutorService(); + ScheduledHandle handle = + syncContext.scheduleWithFixedDelay(task1, Duration.ofSeconds(10), + Duration.ofSeconds(10), executorService); + + assertThat(executorService.delay) + .isEqualTo(executorService.unit.convert(10, TimeUnit.SECONDS)); + assertThat(handle.isPending()).isTrue(); + verify(task1, never()).run(); + + executorService.command.run(); + + assertThat(handle.isPending()).isFalse(); + verify(task1).run(); + } + @Test public void scheduleDueImmediately() { MockScheduledExecutorService executorService = new MockScheduledExecutorService(); @@ -357,5 +395,13 @@ static class MockScheduledExecutorService extends ForwardingScheduledExecutorSer this.unit = unit; return future = super.schedule(command, delay, unit); } + + @Override public ScheduledFuture scheduleWithFixedDelay(Runnable command, long intialDelay, + long delay, TimeUnit unit) { + this.command = command; + this.delay = delay; + this.unit = unit; + return future = super.scheduleWithFixedDelay(command, intialDelay, delay, unit); + } } } diff --git a/api/src/test/java/io/grpc/TimeUtilsTest.java b/api/src/test/java/io/grpc/TimeUtilsTest.java new file mode 100644 index 00000000000..728b8512cd7 --- /dev/null +++ b/api/src/test/java/io/grpc/TimeUtilsTest.java @@ -0,0 +1,60 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import static org.junit.Assert.assertEquals; + +import java.time.Duration; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link TimeUtils}. */ +@RunWith(JUnit4.class) +@IgnoreJRERequirement +public class TimeUtilsTest { + + @Test + public void testConvertNormalDuration() { + Duration duration = Duration.ofSeconds(10); + long expected = 10 * 1_000_000_000L; + + assertEquals(expected, TimeUtils.convertToNanos(duration)); + } + + @Test + public void testConvertNegativeDuration() { + Duration duration = Duration.ofSeconds(-3); + long expected = -3 * 1_000_000_000L; + + assertEquals(expected, TimeUtils.convertToNanos(duration)); + } + + @Test + public void testConvertTooLargeDuration() { + Duration duration = Duration.ofSeconds(Long.MAX_VALUE / 1_000_000_000L + 1); + + assertEquals(Long.MAX_VALUE, TimeUtils.convertToNanos(duration)); + } + + @Test + public void testConvertTooLargeNegativeDuration() { + Duration duration = Duration.ofSeconds(Long.MIN_VALUE / 1_000_000_000L - 1); + + assertEquals(Long.MIN_VALUE, TimeUtils.convertToNanos(duration)); + } +} diff --git a/api/src/test/java/io/grpc/UriTest.java b/api/src/test/java/io/grpc/UriTest.java new file mode 100644 index 00000000000..0de7ef0fc10 --- /dev/null +++ b/api/src/test/java/io/grpc/UriTest.java @@ -0,0 +1,828 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; +import static org.junit.Assume.assumeNoException; + +import com.google.common.net.InetAddresses; +import com.google.common.testing.EqualsTester; +import java.net.Inet6Address; +import java.net.URISyntaxException; +import java.net.UnknownHostException; +import java.util.BitSet; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class UriTest { + + @Test + public void parse_allComponents() throws URISyntaxException { + Uri uri = Uri.parse("scheme://user@host:0443/path?query#fragment"); + assertThat(uri.getScheme()).isEqualTo("scheme"); + assertThat(uri.getAuthority()).isEqualTo("user@host:0443"); + assertThat(uri.getUserInfo()).isEqualTo("user"); + assertThat(uri.getPort()).isEqualTo(443); + assertThat(uri.getRawPort()).isEqualTo("0443"); + assertThat(uri.getPath()).isEqualTo("/path"); + assertThat(uri.getRawQuery()).isEqualTo("query"); + assertThat(uri.getFragment()).isEqualTo("fragment"); + assertThat(uri.toString()).isEqualTo("scheme://user@host:0443/path?query#fragment"); + assertThat(uri.isAbsolute()).isFalse(); // Has a fragment. + assertThat(uri.isPathAbsolute()).isTrue(); + assertThat(uri.isPathRootless()).isFalse(); + } + + @Test + public void parse_noAuthority() throws URISyntaxException { + Uri uri = Uri.parse("scheme:/path?query#fragment"); + assertThat(uri.getScheme()).isEqualTo("scheme"); + assertThat(uri.getAuthority()).isNull(); + assertThat(uri.getPath()).isEqualTo("/path"); + assertThat(uri.getRawQuery()).isEqualTo("query"); + assertThat(uri.getFragment()).isEqualTo("fragment"); + assertThat(uri.toString()).isEqualTo("scheme:/path?query#fragment"); + assertThat(uri.isAbsolute()).isFalse(); // Has a fragment. + } + + @Test + public void parse_ipv6Literal_withPort() throws URISyntaxException { + Uri uri = Uri.parse("scheme://[2001:db8::7]:012345"); + assertThat(uri.getAuthority()).isEqualTo("[2001:db8::7]:012345"); + assertThat(uri.getRawHost()).isEqualTo("[2001:db8::7]"); + assertThat(uri.getHost()).isEqualTo("[2001:db8::7]"); + assertThat(uri.getRawPort()).isEqualTo("012345"); + assertThat(uri.getPort()).isEqualTo(12345); + } + + @Test + public void parse_ipv6Literal_noPort() throws URISyntaxException { + Uri uri = Uri.parse("scheme://[2001:db8::7]"); + assertThat(uri.getAuthority()).isEqualTo("[2001:db8::7]"); + assertThat(uri.getRawHost()).isEqualTo("[2001:db8::7]"); + assertThat(uri.getHost()).isEqualTo("[2001:db8::7]"); + assertThat(uri.getRawPort()).isNull(); + assertThat(uri.getPort()).isLessThan(0); + } + + @Test + public void parse_ipv6ScopedLiteral() throws URISyntaxException { + Uri uri = Uri.parse("http://[fe80::1%25eth0]"); + assertThat(uri.getRawHost()).isEqualTo("[fe80::1%25eth0]"); + assertThat(uri.getHost()).isEqualTo("[fe80::1%eth0]"); + } + + @Test + public void parse_ipv6ScopedPercentEncodedLiteral() throws URISyntaxException { + Uri uri = Uri.parse("http://[fe80::1%25foo-bar%2Fblah]"); + assertThat(uri.getRawHost()).isEqualTo("[fe80::1%25foo-bar%2Fblah]"); + assertThat(uri.getHost()).isEqualTo("[fe80::1%foo-bar/blah]"); + } + + @Test + public void parse_noQuery() throws URISyntaxException { + Uri uri = Uri.parse("scheme://authority/path#fragment"); + assertThat(uri.getScheme()).isEqualTo("scheme"); + assertThat(uri.getAuthority()).isEqualTo("authority"); + assertThat(uri.getPath()).isEqualTo("/path"); + assertThat(uri.getRawQuery()).isNull(); + assertThat(uri.getFragment()).isEqualTo("fragment"); + assertThat(uri.toString()).isEqualTo("scheme://authority/path#fragment"); + } + + @Test + public void parse_noFragment() throws URISyntaxException { + Uri uri = Uri.parse("scheme://authority/path?query"); + assertThat(uri.getScheme()).isEqualTo("scheme"); + assertThat(uri.getAuthority()).isEqualTo("authority"); + assertThat(uri.getPath()).isEqualTo("/path"); + assertThat(uri.getRawQuery()).isEqualTo("query"); + assertThat(uri.getFragment()).isNull(); + assertThat(uri.toString()).isEqualTo("scheme://authority/path?query"); + assertThat(uri.isAbsolute()).isTrue(); + } + + @Test + public void parse_emptyPathWithAuthority() throws URISyntaxException { + Uri uri = Uri.parse("scheme://authority"); + assertThat(uri.getScheme()).isEqualTo("scheme"); + assertThat(uri.getAuthority()).isEqualTo("authority"); + assertThat(uri.getPath()).isEmpty(); + assertThat(uri.getRawQuery()).isNull(); + assertThat(uri.getFragment()).isNull(); + assertThat(uri.toString()).isEqualTo("scheme://authority"); + assertThat(uri.isAbsolute()).isTrue(); + assertThat(uri.isPathAbsolute()).isFalse(); + assertThat(uri.isPathRootless()).isFalse(); + } + + @Test + public void parse_rootless() throws URISyntaxException { + Uri uri = Uri.parse("mailto:ceo@company.com?subject=raise"); + assertThat(uri.getScheme()).isEqualTo("mailto"); + assertThat(uri.getAuthority()).isNull(); + assertThat(uri.getPath()).isEqualTo("ceo@company.com"); + assertThat(uri.getRawQuery()).isEqualTo("subject=raise"); + assertThat(uri.getFragment()).isNull(); + assertThat(uri.toString()).isEqualTo("mailto:ceo@company.com?subject=raise"); + assertThat(uri.isAbsolute()).isTrue(); + assertThat(uri.isPathAbsolute()).isFalse(); + assertThat(uri.isPathRootless()).isTrue(); + } + + @Test + public void parse_emptyPath() throws URISyntaxException { + Uri uri = Uri.parse("scheme:"); + assertThat(uri.getScheme()).isEqualTo("scheme"); + assertThat(uri.getAuthority()).isNull(); + assertThat(uri.getPath()).isEmpty(); + assertThat(uri.getRawQuery()).isNull(); + assertThat(uri.getFragment()).isNull(); + assertThat(uri.toString()).isEqualTo("scheme:"); + assertThat(uri.isAbsolute()).isTrue(); + assertThat(uri.isPathAbsolute()).isFalse(); + assertThat(uri.isPathRootless()).isFalse(); + } + + @Test + public void parse_emptyQuery() { + Uri uri = Uri.create("scheme:?"); + assertThat(uri.getScheme()).isEqualTo("scheme"); + assertThat(uri.getRawQuery()).isEmpty(); + } + + @Test + public void parse_emptyFragment() { + Uri uri = Uri.create("scheme:#"); + assertThat(uri.getScheme()).isEqualTo("scheme"); + assertThat(uri.getFragment()).isEmpty(); + } + + @Test + public void parse_emptyUserInfo() { + Uri uri = Uri.create("scheme://@host"); + assertThat(uri.getScheme()).isEqualTo("scheme"); + assertThat(uri.getAuthority()).isEqualTo("@host"); + assertThat(uri.getHost()).isEqualTo("host"); + assertThat(uri.getUserInfo()).isEmpty(); + assertThat(uri.toString()).isEqualTo("scheme://@host"); + } + + @Test + public void parse_emptyPort() { + Uri uri = Uri.create("scheme://host:"); + assertThat(uri.getScheme()).isEqualTo("scheme"); + assertThat(uri.getAuthority()).isEqualTo("host:"); + assertThat(uri.getRawAuthority()).isEqualTo("host:"); + assertThat(uri.getHost()).isEqualTo("host"); + assertThat(uri.getPort()).isEqualTo(-1); + assertThat(uri.getRawPort()).isEqualTo(""); + assertThat(uri.toString()).isEqualTo("scheme://host:"); + } + + @Test + public void parse_invalidScheme_throws() { + URISyntaxException e = + assertThrows(URISyntaxException.class, () -> Uri.parse("1scheme://authority/path")); + assertThat(e).hasMessageThat().contains("Scheme must start with an alphabetic char"); + + e = assertThrows(URISyntaxException.class, () -> Uri.parse(":path")); + assertThat(e).hasMessageThat().contains("Scheme must start with an alphabetic char"); + } + + @Test + public void parse_unTerminatedScheme_throws() { + URISyntaxException e = assertThrows(URISyntaxException.class, () -> Uri.parse("scheme/")); + assertThat(e).hasMessageThat().contains("Missing required scheme"); + + e = assertThrows(URISyntaxException.class, () -> Uri.parse("scheme?")); + assertThat(e).hasMessageThat().contains("Missing required scheme"); + + e = assertThrows(URISyntaxException.class, () -> Uri.parse("scheme#")); + assertThat(e).hasMessageThat().contains("Missing required scheme"); + } + + @Test + public void parse_invalidCharactersInScheme_throws() { + URISyntaxException e = + assertThrows(URISyntaxException.class, () -> Uri.parse("schem e://authority/path")); + assertThat(e).hasMessageThat().contains("Invalid character in scheme"); + } + + @Test + public void parse_unTerminatedAuthority_throws() { + Uri uri = Uri.create("s://auth/"); + assertThat(uri.getAuthority()).isEqualTo("auth"); + uri = Uri.create("s://auth?"); + assertThat(uri.getAuthority()).isEqualTo("auth"); + uri = Uri.create("s://auth#"); + assertThat(uri.getAuthority()).isEqualTo("auth"); + } + + @Test + public void parse_invalidCharactersInUserinfo_throws() { + URISyntaxException e = + assertThrows(URISyntaxException.class, () -> Uri.parse("scheme://u ser@host/path")); + assertThat(e).hasMessageThat().contains("Invalid character in userInfo"); + } + + @Test + public void parse_invalidBackslashInUserinfo_throws() { + URISyntaxException e = + assertThrows(URISyntaxException.class, () -> Uri.parse("http://other.com\\@intended.com")); + assertThat(e).hasMessageThat().contains("Invalid character in userInfo"); + } + + @Test + public void parse_invalidCharactersInHost_throws() { + URISyntaxException e = + assertThrows(URISyntaxException.class, () -> Uri.parse("scheme://h ost/path")); + assertThat(e).hasMessageThat().contains("Invalid character in host"); + } + + @Test + public void parse_invalidBackslashInHost_throws() { + URISyntaxException e = + assertThrows(URISyntaxException.class, () -> Uri.parse("http://other.com\\.intended.com")); + assertThat(e).hasMessageThat().contains("Invalid character in host"); + } + + @Test + public void parse_invalidBackslashScope_throws() { + URISyntaxException e = + assertThrows(URISyntaxException.class, () -> Uri.parse("http://[::1%25foo\\bar]")); + assertThat(e).hasMessageThat().contains("Invalid character in scope"); + } + + @Test + public void parse_invalidCharactersInPort_throws() { + URISyntaxException e = + assertThrows(URISyntaxException.class, () -> Uri.parse("scheme://user@host:8 0/path")); + assertThat(e).hasMessageThat().contains("Invalid character"); + } + + @Test + public void parse_nonAsciiCharacterInPath_throws() throws URISyntaxException { + URISyntaxException e = assertThrows(URISyntaxException.class, () -> Uri.parse("foo:bär")); + assertThat(e).hasMessageThat().contains("Invalid character in path"); + } + + @Test + public void parse_invalidCharactersInPath_throws() { + URISyntaxException e = assertThrows(URISyntaxException.class, () -> Uri.parse("scheme:/p ath")); + assertThat(e).hasMessageThat().contains("Invalid character in path"); + } + + @Test + public void parse_invalidCharactersInQuery_throws() { + URISyntaxException e = + assertThrows(URISyntaxException.class, () -> Uri.parse("scheme://user@host/p?q[]uery")); + assertThat(e).hasMessageThat().contains("Invalid character in query"); + } + + @Test + public void parse_invalidCharactersInFragment_throws() { + URISyntaxException e = + assertThrows(URISyntaxException.class, () -> Uri.parse("scheme://user@host/path#f[]rag")); + assertThat(e).hasMessageThat().contains("Invalid character in fragment"); + } + + @Test + public void parse_nonAsciiCharacterInFragment_throws() throws URISyntaxException { + URISyntaxException e = assertThrows(URISyntaxException.class, () -> Uri.parse("foo:#bär")); + assertThat(e).hasMessageThat().contains("Invalid character in fragment"); + } + + @Test + public void parse_decoding() throws URISyntaxException { + Uri uri = Uri.parse("s://user%2Ename:pass%2Eword@a%2db:1234/p%20ath?q%20uery#f%20ragment"); + assertThat(uri.getAuthority()).isEqualTo("user.name:pass.word@a-b:1234"); + assertThat(uri.getRawAuthority()).isEqualTo("user%2Ename:pass%2Eword@a%2db:1234"); + assertThat(uri.getUserInfo()).isEqualTo("user.name:pass.word"); + assertThat(uri.getRawUserInfo()).isEqualTo("user%2Ename:pass%2Eword"); + assertThat(uri.getHost()).isEqualTo("a-b"); + assertThat(uri.getRawHost()).isEqualTo("a%2db"); + assertThat(uri.getPort()).isEqualTo(1234); + assertThat(uri.getPath()).isEqualTo("/p ath"); + assertThat(uri.getRawPath()).isEqualTo("/p%20ath"); + assertThat(uri.getRawQuery()).isEqualTo("q%20uery"); + assertThat(uri.getFragment()).isEqualTo("f ragment"); + assertThat(uri.getRawFragment()).isEqualTo("f%20ragment"); + } + + @Test + public void parse_decodingNonAscii() throws URISyntaxException { + Uri uri = Uri.parse("s://a/%E2%82%AC"); + assertThat(uri.getPath()).isEqualTo("/€"); + } + + @Test + public void parse_decodingPercent() throws URISyntaxException { + Uri uri = Uri.parse("s://a/p%2520ath#f%25ragment"); + assertThat(uri.getPath()).isEqualTo("/p%20ath"); + assertThat(uri.getFragment()).isEqualTo("f%ragment"); + } + + @Test + public void parse_invalidPercentEncoding_throws() { + URISyntaxException e = assertThrows(URISyntaxException.class, () -> Uri.parse("s://a/p%2")); + assertThat(e).hasMessageThat().contains("Invalid"); + + e = assertThrows(URISyntaxException.class, () -> Uri.parse("s://a/p%2G")); + assertThat(e).hasMessageThat().contains("Invalid"); + } + + @Test + public void parse_emptyAuthority() { + Uri uri = Uri.create("file:///foo/bar"); + assertThat(uri.getAuthority()).isEmpty(); + assertThat(uri.getHost()).isEmpty(); + assertThat(uri.getUserInfo()).isNull(); + assertThat(uri.getPort()).isEqualTo(-1); + assertThat(uri.getPath()).isEqualTo("/foo/bar"); + } + + @Test + public void parse_pathSegments_empty() throws URISyntaxException { + Uri uri = Uri.create("scheme:"); + assertThat(uri.getPathSegments()).isEmpty(); + } + + @Test + public void parse_pathSegments_root() throws URISyntaxException { + Uri uri = Uri.create("scheme:/"); + assertThat(uri.getPathSegments()).containsExactly(""); + } + + @Test + public void parse_onePathSegment() throws URISyntaxException { + Uri uri = Uri.create("file:/foo"); + assertThat(uri.getPathSegments()).containsExactly("foo"); + } + + @Test + public void parse_onePathSegment_trailingSlash() throws URISyntaxException { + Uri uri = Uri.create("file:/foo/"); + assertThat(uri.getPathSegments()).containsExactly("foo", ""); + } + + @Test + public void parse_onePathSegment_rootless() throws URISyntaxException { + Uri uri = Uri.create("dns:www.example.com"); + assertThat(uri.getPathSegments()).containsExactly("www.example.com"); + assertThat(uri.isPathAbsolute()).isFalse(); + assertThat(uri.isPathRootless()).isTrue(); + } + + @Test + public void parse_twoPathSegments() throws URISyntaxException { + Uri uri = Uri.create("file:/foo/bar"); + assertThat(uri.getPathSegments()).containsExactly("foo", "bar"); + } + + @Test + public void parse_twoPathSegments_rootless() throws URISyntaxException { + Uri uri = Uri.create("file:foo/bar"); + assertThat(uri.getPathSegments()).containsExactly("foo", "bar"); + } + + @Test + public void parse_percentEncodedPathSegment_rootless() throws URISyntaxException { + Uri uri = Uri.create("mailto:%2Fdev%2Fnull@example.com"); + assertThat(uri.getPathSegments()).containsExactly("/dev/null@example.com"); + assertThat(uri.isPathAbsolute()).isFalse(); + assertThat(uri.isPathRootless()).isTrue(); + } + + @Test + public void toString_percentEncoding() throws URISyntaxException { + Uri uri = + Uri.newBuilder() + .setScheme("s") + .setHost("a b") + .setPath("/p ath") + .setRawQuery("q%20uery") + .setFragment("f ragment") + .build(); + assertThat(uri.toString()).isEqualTo("s://a%20b/p%20ath?q%20uery#f%20ragment"); + } + + @Test + public void parse_transparentRoundTrip_ipLiteral() { + Uri uri = Uri.create("http://[2001:dB8::7]:080/%4a%4B%2f%2F?%4c%4D#%4e%4F").toBuilder().build(); + assertThat(uri.toString()).isEqualTo("http://[2001:dB8::7]:080/%4a%4B%2f%2F?%4c%4D#%4e%4F"); + + // IPv6 host has non-canonical :: zeros and mixed case hex digits. + assertThat(uri.getRawHost()).isEqualTo("[2001:dB8::7]"); + assertThat(uri.getHost()).isEqualTo("[2001:dB8::7]"); + assertThat(uri.getRawPort()).isEqualTo("080"); // Leading zeros. + assertThat(uri.getPort()).isEqualTo(80); + // Unnecessary and mixed case percent encodings. + assertThat(uri.getRawPath()).isEqualTo("/%4a%4B%2f%2F"); + assertThat(uri.getPathSegments()).containsExactly("JK//"); + assertThat(uri.getRawQuery()).isEqualTo("%4c%4D"); + assertThat(uri.getRawFragment()).isEqualTo("%4e%4F"); + assertThat(uri.getFragment()).isEqualTo("NO"); + } + + @Test + public void parse_transparentRoundTrip_regName() { + Uri uri = Uri.create("http://aB%4A%4b:080/%4a%4B%2f%2F?%4c%4D#%4e%4F").toBuilder().build(); + assertThat(uri.toString()).isEqualTo("http://aB%4A%4b:080/%4a%4B%2f%2F?%4c%4D#%4e%4F"); + + // Mixed case literal chars and hex digits. + assertThat(uri.getRawHost()).isEqualTo("aB%4A%4b"); + assertThat(uri.getHost()).isEqualTo("aBJK"); + assertThat(uri.getRawPort()).isEqualTo("080"); // Leading zeros. + assertThat(uri.getPort()).isEqualTo(80); + // Unnecessary and mixed case percent encodings. + assertThat(uri.getRawPath()).isEqualTo("/%4a%4B%2f%2F"); + assertThat(uri.getPathSegments()).containsExactly("JK//"); + assertThat(uri.getRawQuery()).isEqualTo("%4c%4D"); + assertThat(uri.getRawFragment()).isEqualTo("%4e%4F"); + assertThat(uri.getFragment()).isEqualTo("NO"); + } + + @Test + public void builder_numericPort() throws URISyntaxException { + Uri uri = Uri.newBuilder().setScheme("scheme").setHost("host").setPort(80).build(); + assertThat(uri.toString()).isEqualTo("scheme://host:80"); + } + + @Test + public void builder_ipv6Literal() throws URISyntaxException { + Uri uri = + Uri.newBuilder() + .setScheme("scheme") + .setHost(InetAddresses.forString("2001:4860:4860::8844")) + .build(); + assertThat(uri.toString()).isEqualTo("scheme://[2001:4860:4860::8844]"); + } + + @Test + public void builder_ipv6ScopedLiteral_numeric() throws UnknownHostException { + Uri uri = + Uri.newBuilder() + .setScheme("http") + // Create an address with a numeric scope_id, which should always be valid. + .setHost( + Inet6Address.getByAddress(null, InetAddresses.forString("fe80::1").getAddress(), 1)) + .build(); + + // We expect the scope ID to be percent encoded. + assertThat(uri.getRawHost()).isEqualTo("[fe80::1%251]"); + assertThat(uri.getHost()).isEqualTo("[fe80::1%1]"); + } + + @Test + public void builder_ipv6ScopedLiteral_named() throws UnknownHostException { + // Unfortunately, there's no Java API to create an Inet6Address with an arbitrary interface- + // scoped name. There's actually no way to hermetically create an Inet6Address with a scope name + // at all! The following address/interface is likely to be present on Linux test runners. + Inet6Address address; + try { + address = (Inet6Address) InetAddresses.forString("::1%lo"); + } catch (IllegalArgumentException e) { + assumeNoException(e); + return; // Not reached. + } + Uri uri = Uri.newBuilder().setScheme("http").setHost(address).build(); + + // We expect the scope ID to be percent encoded. + assertThat(uri.getRawHost()).isEqualTo("[::1%25lo]"); + assertThat(uri.getHost()).isEqualTo("[::1%lo]"); + } + + @Test + public void builder_ipv6PercentEncodedScopedLiteral() { + Uri uri = Uri.newBuilder().setScheme("http").setRawHost("[fe80::1%25foo%2Dbar%2Fblah]").build(); + assertThat(uri.getRawHost()).isEqualTo("[fe80::1%25foo%2Dbar%2Fblah]"); + assertThat(uri.getHost()).isEqualTo("[fe80::1%foo-bar/blah]"); + } + + @Test + public void builder_encodingWithAllowedReservedChars() throws URISyntaxException { + Uri uri = + Uri.newBuilder() + .setScheme("s") + .setUserInfo("u@") + .setHost("a[]") + .setPath("/p:/@") + .setRawQuery("q/?") + .setFragment("f/?") + .build(); + assertThat(uri.toString()).isEqualTo("s://u%40@a%5B%5D/p:/@?q/?#f/?"); + } + + @Test + public void builder_percentEncodingNonAscii() throws URISyntaxException { + Uri uri = Uri.newBuilder().setScheme("s").setHost("a").setPath("/€").build(); + assertThat(uri.toString()).isEqualTo("s://a/%E2%82%AC"); + } + + @Test + public void builder_percentEncodingLoneHighSurrogate_throws() { + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, + () -> Uri.newBuilder().setPath("\uD83D")); // Lone high surrogate. + assertThat(e.getMessage()).contains("Malformed input"); + } + + @Test + public void builder_hasAuthority_pathStartsWithSlash_throws() throws URISyntaxException { + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, + () -> Uri.newBuilder().setScheme("s").setHost("a").setPath("path").build()); + assertThat(e.getMessage()).contains("Non-empty path must start with '/'"); + } + + @Test + public void builder_noAuthority_pathStartsWithDoubleSlash_throws() throws URISyntaxException { + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, + () -> Uri.newBuilder().setScheme("s").setPath("//path").build()); + assertThat(e.getMessage()).contains("Path cannot start with '//'"); + } + + @Test + public void builder_noScheme_throws() { + IllegalStateException e = + assertThrows(IllegalStateException.class, () -> Uri.newBuilder().build()); + assertThat(e.getMessage()).contains("Missing required scheme"); + } + + @Test + public void builder_noHost_hasUserInfo_throws() { + IllegalStateException e = + assertThrows( + IllegalStateException.class, + () -> Uri.newBuilder().setScheme("scheme").setUserInfo("user").build()); + assertThat(e.getMessage()).contains("Cannot set userInfo without host"); + } + + @Test + public void builder_noHost_hasPort_throws() { + IllegalStateException e = + assertThrows( + IllegalStateException.class, + () -> Uri.newBuilder().setScheme("scheme").setPort(1234).build()); + assertThat(e.getMessage()).contains("Cannot set port without host"); + } + + @Test + public void builder_normalizesCaseWhereAppropriate() { + Uri uri = + Uri.newBuilder() + .setScheme("hTtP") // #section-3.1 says producers (Builder) should normalize to lower. + .setHost("aBc") // #section-3.2.2 says producers (Builder) should normalize to lower. + .setPath("/CdE") // #section-6.2.2.1 says the rest are assumed to be case-sensitive + .setRawQuery("fGh") + .setFragment("IjK") + .build(); + assertThat(uri.toString()).isEqualTo("http://abc/CdE?fGh#IjK"); + } + + @Test + public void builder_normalizesIpv6Literal() { + Uri uri = + Uri.newBuilder().setScheme("scheme").setHost(InetAddresses.forString("ABCD::EFAB")).build(); + assertThat(uri.toString()).isEqualTo("scheme://[abcd::efab]"); + } + + @Test + public void builder_canClearAllOptionalFields() { + Uri uri = + Uri.create("http://user@host:80/path?query#fragment").toBuilder() + .setHost((String) null) + .setPath("") + .setUserInfo(null) + .setPort(-1) + .setRawQuery(null) + .setFragment(null) + .build(); + assertThat(uri.toString()).isEqualTo("http:"); + } + + @Test + public void builder_setRawQuery() { + Uri uri = Uri.newBuilder().setScheme("http").setHost("host").setRawQuery("%61=b&c=%64").build(); + assertThat(uri.getRawQuery()).isEqualTo("%61=b&c=%64"); + assertThat(uri.toString()).isEqualTo("http://host?%61=b&c=%64"); + } + + @Test + public void builder_setRawQuery_null() { + Uri uri = + Uri.newBuilder() + .setScheme("http") + .setHost("host") + .setRawQuery("a=b") + .setRawQuery(null) + .build(); + assertThat(uri.getRawQuery()).isNull(); + assertThat(uri.toString()).isEqualTo("http://host"); + } + + @Test + public void builder_setRawFragment() { + Uri uri = Uri.newBuilder().setScheme("http").setHost("host").setRawFragment("a%20b").build(); + assertThat(uri.getRawFragment()).isEqualTo("a%20b"); + assertThat(uri.getFragment()).isEqualTo("a b"); + assertThat(uri.toString()).isEqualTo("http://host#a%20b"); + } + + @Test + public void builder_setRawFragment_null() { + Uri uri = + Uri.newBuilder() + .setScheme("http") + .setHost("host") + .setRawFragment("a%20b") + .setRawFragment(null) + .build(); + assertThat(uri.getRawFragment()).isNull(); + assertThat(uri.getFragment()).isNull(); + assertThat(uri.toString()).isEqualTo("http://host"); + } + + @Test + public void builder_setRawFragment_invalidCharacters_throws() { + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, + () -> Uri.newBuilder().setRawFragment("f[]rag")); + assertThat(e).hasMessageThat().contains("Invalid character in fragment"); + } + + @Test + public void builder_setRawFragment_invalidPercentEncoding_throws() { + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, + () -> Uri.newBuilder().setRawFragment("f%XXragment")); + assertThat(e).hasMessageThat().contains("Invalid"); + } + + @Test + public void builder_canClearAuthorityComponents() { + Uri uri = Uri.create("s://user@host:80/path").toBuilder().setRawAuthority(null).build(); + assertThat(uri.toString()).isEqualTo("s:/path"); + } + + @Test + public void builder_canSetEmptyAuthority() { + Uri uri = Uri.create("s://user@host:80/path").toBuilder().setRawAuthority("").build(); + assertThat(uri.toString()).isEqualTo("s:///path"); + } + + @Test + public void builder_canSetRawAuthority() { + Uri uri = Uri.newBuilder().setScheme("http").setRawAuthority("user@host:1234").build(); + assertThat(uri.getUserInfo()).isEqualTo("user"); + assertThat(uri.getHost()).isEqualTo("host"); + assertThat(uri.getPort()).isEqualTo(1234); + } + + @Test + public void builder_setRawAuthorityPercentDecodes() { + Uri uri = + Uri.newBuilder() + .setScheme("http") + .setRawAuthority("user:user%40user@host%40host%3Ahost") + .build(); + assertThat(uri.getUserInfo()).isEqualTo("user:user@user"); + assertThat(uri.getHost()).isEqualTo("host@host:host"); + assertThat(uri.getPort()).isEqualTo(-1); + } + + @Test + public void builder_setRawAuthorityReplacesAllComponents() { + Uri uri = + Uri.newBuilder() + .setScheme("http") + .setUserInfo("user") + .setHost("host") + .setPort(1234) + .setRawAuthority("other") + .build(); + assertThat(uri.getUserInfo()).isNull(); + assertThat(uri.getHost()).isEqualTo("other"); + assertThat(uri.getPort()).isEqualTo(-1); + } + + @Test + public void toString_percentEncodingMultiChar() throws URISyntaxException { + Uri uri = + Uri.newBuilder() + .setScheme("s") + .setHost("a") + .setPath("/emojis/😊/icon.png") // Smile requires two chars to express in a java String. + .build(); + assertThat(uri.toString()).isEqualTo("s://a/emojis/%F0%9F%98%8A/icon.png"); + } + + @Test + public void toString_percentEncodingLiteralPercent() throws URISyntaxException { + Uri uri = + Uri.newBuilder() + .setScheme("s") + .setHost("a") + .setPath("/p%20ath") + .setRawQuery("q%25uery") + .setFragment("f%ragment") + .build(); + assertThat(uri.toString()).isEqualTo("s://a/p%2520ath?q%25uery#f%25ragment"); + } + + @Test + public void equalsAndHashCode() { + new EqualsTester() + .addEqualityGroup( + Uri.create("scheme://authority/path?query#fragment"), + Uri.create("scheme://authority/path?query#fragment")) + .addEqualityGroup(Uri.create("scheme://authority/path")) + .addEqualityGroup(Uri.create("scheme://authority/path?query")) + .addEqualityGroup(Uri.create("scheme:/path")) + .addEqualityGroup(Uri.create("scheme:/path?query")) + .addEqualityGroup(Uri.create("scheme:/path#fragment")) + .addEqualityGroup(Uri.create("scheme:path")) + .addEqualityGroup(Uri.create("scheme:path?query")) + .addEqualityGroup(Uri.create("scheme:path#fragment")) + .addEqualityGroup(Uri.create("scheme:")) + .testEquals(); + } + + @Test + public void isAbsolute() { + assertThat(Uri.create("scheme://authority/path").isAbsolute()).isTrue(); + assertThat(Uri.create("scheme://authority/path?query").isAbsolute()).isTrue(); + assertThat(Uri.create("scheme://authority/path#fragment").isAbsolute()).isFalse(); + assertThat(Uri.create("scheme://authority/path?query#fragment").isAbsolute()).isFalse(); + } + + @Test + public void serializedCharacterClasses_matchComputed() { + assertThat(Uri.digitChars).isEqualTo(bitSetOfRange('0', '9')); + assertThat(Uri.alphaChars).isEqualTo(or(bitSetOfRange('A', 'Z'), bitSetOfRange('a', 'z'))); + assertThat(Uri.schemeChars) + .isEqualTo(or(Uri.digitChars, Uri.alphaChars, bitSetOf('+', '-', '.'))); + assertThat(Uri.unreservedChars) + .isEqualTo(or(Uri.alphaChars, Uri.digitChars, bitSetOf('-', '.', '_', '~'))); + assertThat(Uri.genDelimsChars).isEqualTo(bitSetOf(':', '/', '?', '#', '[', ']', '@')); + assertThat(Uri.subDelimsChars) + .isEqualTo(bitSetOf('!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=')); + assertThat(Uri.reservedChars).isEqualTo(or(Uri.genDelimsChars, Uri.subDelimsChars)); + assertThat(Uri.regNameChars).isEqualTo(or(Uri.unreservedChars, Uri.subDelimsChars)); + assertThat(Uri.userInfoChars) + .isEqualTo(or(Uri.unreservedChars, Uri.subDelimsChars, bitSetOf(':'))); + assertThat(Uri.pChars) + .isEqualTo(or(Uri.unreservedChars, Uri.subDelimsChars, bitSetOf(':', '@'))); + assertThat(Uri.pCharsAndSlash).isEqualTo(or(Uri.pChars, bitSetOf('/'))); + assertThat(Uri.queryChars).isEqualTo(or(Uri.pChars, bitSetOf('/', '?'))); + assertThat(Uri.fragmentChars).isEqualTo(or(Uri.pChars, bitSetOf('/', '?'))); + } + + private static BitSet bitSetOfRange(char from, char to) { + BitSet bitset = new BitSet(); + for (char c = from; c <= to; c++) { + bitset.set(c); + } + return bitset; + } + + private static BitSet bitSetOf(char... chars) { + BitSet bitset = new BitSet(); + for (char c : chars) { + bitset.set(c); + } + return bitset; + } + + private static BitSet or(BitSet... bitsets) { + BitSet bitset = new BitSet(); + for (BitSet bs : bitsets) { + bitset.or(bs); + } + return bitset; + } +} diff --git a/api/src/testFixtures/java/io/grpc/FlagResetRule.java b/api/src/testFixtures/java/io/grpc/FlagResetRule.java new file mode 100644 index 00000000000..08ce7ce82f2 --- /dev/null +++ b/api/src/testFixtures/java/io/grpc/FlagResetRule.java @@ -0,0 +1,96 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import java.util.ArrayDeque; +import java.util.Deque; +import javax.annotation.Nullable; +import org.junit.rules.ExternalResource; + +/** + * A {@link org.junit.rules.TestRule} that lets you set one or more feature flags just for + * the duration of the current test case. + * + *

Flags and other global variables must be reset to ensure no state leaks across tests. + */ +public final class FlagResetRule extends ExternalResource { + + /** A functional interface representing a standard gRPC feature flag setter. */ + public interface SetterMethod { + /** Sets a flag for testing and returns its previous value. */ + T set(T val); + } + + private final Deque toRunAfter = new ArrayDeque<>(); + + /** + * Sets a global feature flag to 'value' using 'setter' and arranges for its previous value to be + * unconditionally restored when the test completes. + */ + public void setFlagForTest(SetterMethod setter, T value) { + final T oldValue = setter.set(value); + toRunAfter.push(() -> setter.set(oldValue)); + } + + /** + * Sets java system property 'key' to 'value' and arranges for its previous value to be + * unconditionally restored when the test completes. + */ + public void setSystemPropertyForTest(String key, String value) { + String oldValue = System.setProperty(key, value); + restoreSystemPropertyAfterTest(key, oldValue); + } + + /** + * Clears java system property 'key' and arranges for its previous value to be unconditionally + * restored when the test completes. + */ + public void clearSystemPropertyForTest(String key) { + String oldValue = System.clearProperty(key); + restoreSystemPropertyAfterTest(key, oldValue); + } + + private void restoreSystemPropertyAfterTest(String key, @Nullable String oldValue) { + toRunAfter.push( + () -> { + if (oldValue == null) { + System.clearProperty(key); + } else { + System.setProperty(key, oldValue); + } + }); + } + + @Override + protected void after() { + RuntimeException toThrow = null; + while (!toRunAfter.isEmpty()) { + try { + toRunAfter.pop().run(); + } catch (RuntimeException e) { + if (toThrow == null) { + toThrow = e; + } else { + toThrow.addSuppressed(e); + } + } + } + if (toThrow != null) { + throw toThrow; + } + } +} diff --git a/api/src/testFixtures/java/io/grpc/MetricInstrumentRegistryAccessor.java b/api/src/testFixtures/java/io/grpc/MetricInstrumentRegistryAccessor.java new file mode 100644 index 00000000000..bd17dccad58 --- /dev/null +++ b/api/src/testFixtures/java/io/grpc/MetricInstrumentRegistryAccessor.java @@ -0,0 +1,30 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +/** + * Accesses test-only methods of {@link MetricInstrumentRegistry}. + */ +public final class MetricInstrumentRegistryAccessor { + + private MetricInstrumentRegistryAccessor() { + } + + public static MetricInstrumentRegistry createMetricInstrumentRegistry() { + return new MetricInstrumentRegistry(); + } +} diff --git a/api/src/testFixtures/java/io/grpc/NoopMetricSink.java b/api/src/testFixtures/java/io/grpc/NoopMetricSink.java new file mode 100644 index 00000000000..b7717d75bcf --- /dev/null +++ b/api/src/testFixtures/java/io/grpc/NoopMetricSink.java @@ -0,0 +1,49 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * A MetricSink that discards all records. + */ +public class NoopMetricSink implements MetricSink { + private int size; + + @Override + public Map getEnabledMetrics() { + return Collections.emptyMap(); + } + + @Override + public Set getOptionalLabels() { + return Collections.emptySet(); + } + + @Override + public synchronized int getMeasuresSize() { + return size; + } + + @Override + public synchronized void updateMeasures(List instruments) { + size = instruments.size(); + } +} diff --git a/api/src/testFixtures/java/io/grpc/PickSubchannelArgsMatcher.java b/api/src/testFixtures/java/io/grpc/PickSubchannelArgsMatcher.java new file mode 100644 index 00000000000..50140649810 --- /dev/null +++ b/api/src/testFixtures/java/io/grpc/PickSubchannelArgsMatcher.java @@ -0,0 +1,59 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import com.google.common.base.Preconditions; +import io.grpc.CallOptions; +import io.grpc.LoadBalancer.PickSubchannelArgs; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import org.mockito.ArgumentMatcher; +import org.mockito.ArgumentMatchers; + +/** + * Mockito Matcher for {@link PickSubchannelArgs}. + */ +public final class PickSubchannelArgsMatcher implements ArgumentMatcher { + private final MethodDescriptor method; + private final Metadata headers; + private final CallOptions callOptions; + + public PickSubchannelArgsMatcher( + MethodDescriptor method, Metadata headers, CallOptions callOptions) { + this.method = Preconditions.checkNotNull(method, "method"); + this.headers = Preconditions.checkNotNull(headers, "headers"); + this.callOptions = Preconditions.checkNotNull(callOptions, "callOptions"); + } + + @Override + public boolean matches(PickSubchannelArgs args) { + return args != null + && method.equals(args.getMethodDescriptor()) + && headers.equals(args.getHeaders()) + && callOptions.equals(args.getCallOptions()); + } + + @Override + public final String toString() { + return "[method=" + method + " headers=" + headers + " callOptions=" + callOptions + "]"; + } + + public static PickSubchannelArgs eqPickSubchannelArgs( + MethodDescriptor method, Metadata headers, CallOptions callOptions) { + return ArgumentMatchers.argThat(new PickSubchannelArgsMatcher(method, headers, callOptions)); + } +} diff --git a/api/src/testFixtures/java/io/grpc/StatusMatcher.java b/api/src/testFixtures/java/io/grpc/StatusMatcher.java new file mode 100644 index 00000000000..08e9fffb013 --- /dev/null +++ b/api/src/testFixtures/java/io/grpc/StatusMatcher.java @@ -0,0 +1,135 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import org.mockito.ArgumentMatcher; + +/** + * Mockito matcher for {@link Status}. + */ +public final class StatusMatcher implements ArgumentMatcher { + public static StatusMatcher statusHasCode(ArgumentMatcher codeMatcher) { + return new StatusMatcher(codeMatcher, null, null); + } + + public static StatusMatcher statusHasCode(Status.Code code) { + return statusHasCode(new EqualsMatcher<>(code)); + } + + private final ArgumentMatcher codeMatcher; + private final ArgumentMatcher descriptionMatcher; + private final ArgumentMatcher causeMatcher; + + private StatusMatcher( + ArgumentMatcher codeMatcher, + ArgumentMatcher descriptionMatcher, + ArgumentMatcher causeMatcher) { + this.codeMatcher = checkNotNull(codeMatcher, "codeMatcher"); + this.descriptionMatcher = descriptionMatcher; + this.causeMatcher = causeMatcher; + } + + public StatusMatcher andDescription(ArgumentMatcher descriptionMatcher) { + checkState(this.descriptionMatcher == null, "Already has a description matcher"); + return new StatusMatcher(codeMatcher, descriptionMatcher, causeMatcher); + } + + public StatusMatcher andDescription(String description) { + return andDescription(new EqualsMatcher<>(description)); + } + + public StatusMatcher andDescriptionContains(String substring) { + return andDescription(new StringContainsMatcher(substring)); + } + + public StatusMatcher andCause(ArgumentMatcher causeMatcher) { + checkState(this.causeMatcher == null, "Already has a cause matcher"); + return new StatusMatcher(codeMatcher, descriptionMatcher, causeMatcher); + } + + public StatusMatcher andCause(Throwable cause) { + return andCause(new EqualsMatcher<>(cause)); + } + + @Override + public boolean matches(Status status) { + return status != null + && codeMatcher.matches(status.getCode()) + && (descriptionMatcher == null || descriptionMatcher.matches(status.getDescription())) + && (causeMatcher == null || causeMatcher.matches(status.getCause())); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append("{code="); + sb.append(codeMatcher); + if (descriptionMatcher != null) { + sb.append(", description="); + sb.append(descriptionMatcher); + } + if (causeMatcher != null) { + sb.append(", cause="); + sb.append(causeMatcher); + } + sb.append("}"); + return sb.toString(); + } + + // Use instead of lambda for better error message. + static final class EqualsMatcher implements ArgumentMatcher { + private final T obj; + + EqualsMatcher(T obj) { + this.obj = checkNotNull(obj, "obj"); + } + + @Override + public boolean matches(Object other) { + return obj.equals(other); + } + + @Override + public String toString() { + return obj.toString(); + } + } + + static final class StringContainsMatcher implements ArgumentMatcher { + private final String needle; + + StringContainsMatcher(String needle) { + this.needle = checkNotNull(needle, "needle"); + } + + @Override + public boolean matches(String haystack) { + if (haystack == null) { + return false; + } + return haystack.contains(needle); + } + + @Override + public String toString() { + return "contains " + needle; + } + } +} diff --git a/api/src/testFixtures/java/io/grpc/StatusOrMatcher.java b/api/src/testFixtures/java/io/grpc/StatusOrMatcher.java new file mode 100644 index 00000000000..1e70ae97853 --- /dev/null +++ b/api/src/testFixtures/java/io/grpc/StatusOrMatcher.java @@ -0,0 +1,66 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import static com.google.common.base.Preconditions.checkNotNull; + +import org.mockito.ArgumentMatcher; + +/** + * Mockito matcher for {@link StatusOr}. + */ +public final class StatusOrMatcher implements ArgumentMatcher> { + public static StatusOrMatcher hasValue(ArgumentMatcher valueMatcher) { + return new StatusOrMatcher(checkNotNull(valueMatcher, "valueMatcher"), null); + } + + public static StatusOrMatcher hasStatus(ArgumentMatcher statusMatcher) { + return new StatusOrMatcher(null, checkNotNull(statusMatcher, "statusMatcher")); + } + + private final ArgumentMatcher valueMatcher; + private final ArgumentMatcher statusMatcher; + + private StatusOrMatcher(ArgumentMatcher valueMatcher, ArgumentMatcher statusMatcher) { + this.valueMatcher = valueMatcher; + this.statusMatcher = statusMatcher; + } + + @Override + public boolean matches(StatusOr statusOr) { + if (statusOr == null) { + return false; + } + if (statusOr.hasValue() != (valueMatcher != null)) { + return false; + } + if (valueMatcher != null) { + return valueMatcher.matches(statusOr.getValue()); + } else { + return statusMatcher.matches(statusOr.getStatus()); + } + } + + @Override + public String toString() { + if (valueMatcher != null) { + return "{value=" + valueMatcher + "}"; + } else { + return "{status=" + statusMatcher + "}"; + } + } +} diff --git a/api/src/testFixtures/java/io/grpc/StatusSubject.java b/api/src/testFixtures/java/io/grpc/StatusSubject.java new file mode 100644 index 00000000000..0b00df96140 --- /dev/null +++ b/api/src/testFixtures/java/io/grpc/StatusSubject.java @@ -0,0 +1,68 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import static com.google.common.truth.Fact.fact; + +import com.google.common.truth.FailureMetadata; +import com.google.common.truth.Subject; +import javax.annotation.Nullable; + +/** Propositions for {@link Status} subjects. */ +public final class StatusSubject extends Subject { + + private static final Subject.Factory statusFactory = new Factory(); + + public static Subject.Factory status() { + return statusFactory; + } + + private final Status actual; + + private StatusSubject(FailureMetadata metadata, @Nullable Status subject) { + super(metadata, subject); + this.actual = subject; + } + + /** Fails if the subject is not OK. */ + public void isOk() { + if (actual == null) { + failWithActual("expected to be OK but was", "null"); + } else if (!actual.isOk()) { + failWithoutActual( + fact("expected to be OK but was", actual.getCode()), + fact("description", actual.getDescription()), + fact("cause", actual.getCause())); + } + } + + /** Fails if the subject does not have the given code. */ + public void hasCode(Status.Code expectedCode) { + if (actual == null) { + failWithActual("expected to have code " + expectedCode + " but was", "null"); + } else { + check("getCode()").that(actual.getCode()).isEqualTo(expectedCode); + } + } + + private static final class Factory implements Subject.Factory { + @Override + public StatusSubject createSubject(FailureMetadata metadata, @Nullable Status that) { + return new StatusSubject(metadata, that); + } + } +} diff --git a/api/src/testFixtures/java/io/grpc/StringMarshaller.java b/api/src/testFixtures/java/io/grpc/StringMarshaller.java index af53d420e2b..e8358b76333 100644 --- a/api/src/testFixtures/java/io/grpc/StringMarshaller.java +++ b/api/src/testFixtures/java/io/grpc/StringMarshaller.java @@ -16,7 +16,7 @@ package io.grpc; -import static com.google.common.base.Charsets.UTF_8; +import static java.nio.charset.StandardCharsets.UTF_8; import com.google.common.io.ByteStreams; import java.io.ByteArrayInputStream; diff --git a/api/src/testFixtures/java/io/grpc/testing/DeadlineSubject.java b/api/src/testFixtures/java/io/grpc/testing/DeadlineSubject.java index 5d4e86fac15..c2b4d8412a7 100644 --- a/api/src/testFixtures/java/io/grpc/testing/DeadlineSubject.java +++ b/api/src/testFixtures/java/io/grpc/testing/DeadlineSubject.java @@ -24,9 +24,9 @@ import com.google.common.truth.ComparableSubject; import com.google.common.truth.FailureMetadata; import com.google.common.truth.Subject; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.Deadline; import java.util.concurrent.TimeUnit; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; /** Propositions for {@link Deadline} subjects. */ @@ -67,7 +67,7 @@ public void of(Deadline expected) { if (Math.abs(actualNanos - expectedNanos) > deltaNanos) { failWithoutActual( fact("expected", expectedNanos / NANOSECONDS_IN_A_SECOND), - fact("but was", expectedNanos / NANOSECONDS_IN_A_SECOND), + fact("but was", actualNanos / NANOSECONDS_IN_A_SECOND), fact("outside tolerance in seconds", deltaNanos / NANOSECONDS_IN_A_SECOND)); } } diff --git a/auth/BUILD.bazel b/auth/BUILD.bazel index cc923dce2b5..da44243e583 100644 --- a/auth/BUILD.bazel +++ b/auth/BUILD.bazel @@ -1,3 +1,6 @@ +load("@rules_java//java:defs.bzl", "java_library") +load("@rules_jvm_external//:defs.bzl", "artifact") + java_library( name = "auth", srcs = glob([ @@ -6,9 +9,8 @@ java_library( visibility = ["//visibility:public"], deps = [ "//api", - "@com_google_auth_google_auth_library_credentials//jar", - "@com_google_code_findbugs_jsr305//jar", - "@com_google_guava_guava//jar", - "@com_google_j2objc_j2objc_annotations//jar", + artifact("com.google.auth:google-auth-library-credentials"), + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.guava:guava"), ], ) diff --git a/auth/build.gradle b/auth/build.gradle index 093c798fa7f..d56802c14ca 100644 --- a/auth/build.gradle +++ b/auth/build.gradle @@ -20,7 +20,16 @@ dependencies { implementation libraries.guava testImplementation project(':grpc-testing'), project(':grpc-core'), + project(":grpc-context"), // Override google-auth dependency with our newer version libraries.google.auth.oauth2Http - signature libraries.signature.java - signature libraries.signature.android + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } + signature (libraries.signature.android) { + artifact { + extension = "signature" + } + } } diff --git a/auth/src/test/java/io/grpc/auth/GoogleAuthLibraryCallCredentialsTest.java b/auth/src/test/java/io/grpc/auth/GoogleAuthLibraryCallCredentialsTest.java index 6c350894929..75026fd7c18 100644 --- a/auth/src/test/java/io/grpc/auth/GoogleAuthLibraryCallCredentialsTest.java +++ b/auth/src/test/java/io/grpc/auth/GoogleAuthLibraryCallCredentialsTest.java @@ -16,7 +16,7 @@ package io.grpc.auth; -import static com.google.common.base.Charsets.US_ASCII; +import static java.nio.charset.StandardCharsets.US_ASCII; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; @@ -50,10 +50,12 @@ import io.grpc.Status; import io.grpc.internal.JsonParser; import io.grpc.testing.TestMethodDescriptors; +import io.grpc.testing.TlsTesting; +import io.grpc.util.CertificateUtils; import java.io.IOException; +import java.io.InputStream; import java.net.URI; -import java.security.KeyPair; -import java.security.KeyPairGenerator; +import java.security.PrivateKey; import java.util.ArrayList; import java.util.Date; import java.util.List; @@ -342,7 +344,10 @@ public void serviceUri() throws Exception { @Test public void serviceAccountToJwt() throws Exception { - KeyPair pair = KeyPairGenerator.getInstance("RSA").generateKeyPair(); + PrivateKey privateKey; + try (InputStream server1Key = TlsTesting.loadCert("server1.key")) { + privateKey = CertificateUtils.getPrivateKey(server1Key); + } HttpTransportFactory factory = Mockito.mock(HttpTransportFactory.class); Mockito.when(factory.create()).thenThrow(new AssertionError()); @@ -350,7 +355,7 @@ public void serviceAccountToJwt() throws Exception { ServiceAccountCredentials credentials = ServiceAccountCredentials.newBuilder() .setClientEmail("test-email@example.com") - .setPrivateKey(pair.getPrivate()) + .setPrivateKey(privateKey) .setPrivateKeyId("test-private-key-id") .setHttpTransportFactory(factory) .build(); @@ -390,13 +395,16 @@ public void oauthClassesNotInClassPath() throws Exception { @Test public void jwtAccessCredentialsInRequestMetadata() throws Exception { - KeyPair pair = KeyPairGenerator.getInstance("RSA").generateKeyPair(); + PrivateKey privateKey; + try (InputStream server1Key = TlsTesting.loadCert("server1.key")) { + privateKey = CertificateUtils.getPrivateKey(server1Key); + } ServiceAccountCredentials credentials = ServiceAccountCredentials.newBuilder() .setClientId("test-client") .setClientEmail("test-email@example.com") - .setPrivateKey(pair.getPrivate()) + .setPrivateKey(privateKey) .setPrivateKeyId("test-private-key-id") .setQuotaProjectId("test-quota-project-id") .build(); diff --git a/authz/build.gradle b/authz/build.gradle index 491e8f32a74..4b02b01aa29 100644 --- a/authz/build.gradle +++ b/authz/build.gradle @@ -2,8 +2,8 @@ plugins { id "java-library" id "maven-publish" - id "com.github.johnrengelman.shadow" id "com.google.protobuf" + id "com.gradleup.shadow" id "ru.vyarus.animalsniffer" } @@ -15,7 +15,6 @@ dependencies { libraries.guava.jre // JRE required by transitive protobuf-java-util annotationProcessor libraries.auto.value - compileOnly libraries.javax.annotation testImplementation project(':grpc-testing'), project(':grpc-testing-proto'), @@ -26,7 +25,11 @@ dependencies { shadow configurations.implementation.getDependencies().minus([xdsDependency]) shadow project(path: ':grpc-xds', configuration: 'shadow') - signature libraries.signature.java + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } } tasks.named("jar").configure { diff --git a/authz/src/main/java/io/grpc/authz/AuthorizationPolicyTranslator.java b/authz/src/main/java/io/grpc/authz/AuthorizationPolicyTranslator.java index 1637af737ad..183ae2c3f55 100644 --- a/authz/src/main/java/io/grpc/authz/AuthorizationPolicyTranslator.java +++ b/authz/src/main/java/io/grpc/authz/AuthorizationPolicyTranslator.java @@ -33,6 +33,7 @@ import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; +import java.util.Locale; import java.util.Map; /** @@ -77,7 +78,7 @@ private static Permission parseHeader(Map header) throws IllegalArgum } if (key.charAt(0) == ':' || key.startsWith("grpc-") - || UNSUPPORTED_HEADERS.contains(key.toLowerCase())) { + || UNSUPPORTED_HEADERS.contains(key.toLowerCase(Locale.ROOT))) { throw new IllegalArgumentException(String.format("Unsupported \"key\" %s", key)); } List valuesList = JsonUtil.getListOfStrings(header, "values"); @@ -155,19 +156,19 @@ private static Map parseRules( } /** - * Translates a gRPC authorization policy in JSON string to Envoy RBAC policies. - * On success, will return one of the following - - * 1. One allow RBAC policy or, - * 2. Two RBAC policies, deny policy followed by allow policy. - * If the policy cannot be parsed or is invalid, an exception will be thrown. - */ + * Translates a gRPC authorization policy in JSON string to Envoy RBAC policies. + * On success, will return one of the following - + * 1. One allow RBAC policy or, + * 2. Two RBAC policies, deny policy followed by allow policy. + * If the policy cannot be parsed or is invalid, an exception will be thrown. + */ public static List translate(String authorizationPolicy) throws IllegalArgumentException, IOException { Object jsonObject = JsonParser.parse(authorizationPolicy); if (!(jsonObject instanceof Map)) { throw new IllegalArgumentException( - "Authorization policy should be a JSON object. Found: " - + (jsonObject == null ? null : jsonObject.getClass())); + "Authorization policy should be a JSON object. Found: " + + (jsonObject == null ? null : jsonObject.getClass())); } @SuppressWarnings("unchecked") Map json = (Map)jsonObject; diff --git a/authz/src/test/java/io/grpc/authz/AuthorizationPolicyTranslatorTest.java b/authz/src/test/java/io/grpc/authz/AuthorizationPolicyTranslatorTest.java index 557458e97d7..17e6d4fe98b 100644 --- a/authz/src/test/java/io/grpc/authz/AuthorizationPolicyTranslatorTest.java +++ b/authz/src/test/java/io/grpc/authz/AuthorizationPolicyTranslatorTest.java @@ -45,9 +45,8 @@ public void invalidPolicy() throws Exception { AuthorizationPolicyTranslator.translate(policy); fail("exception expected"); } catch (IOException ioe) { - assertThat(ioe).hasMessageThat().isEqualTo( - "Use JsonReader.setLenient(true) to accept malformed JSON" - + " at line 1 column 18 path $.name"); + assertThat(ioe).hasMessageThat().contains("malformed JSON"); + assertThat(ioe).hasMessageThat().contains("at line 1 column 18 path $.name"); } } diff --git a/authz/src/test/java/io/grpc/authz/AuthorizationServerInterceptorTest.java b/authz/src/test/java/io/grpc/authz/AuthorizationServerInterceptorTest.java index b07a71bfb9f..65c08ef247f 100644 --- a/authz/src/test/java/io/grpc/authz/AuthorizationServerInterceptorTest.java +++ b/authz/src/test/java/io/grpc/authz/AuthorizationServerInterceptorTest.java @@ -35,9 +35,8 @@ public void invalidPolicyFailsStaticAuthzInterceptorCreation() throws Exception AuthorizationServerInterceptor.create(policy); fail("exception expected"); } catch (IOException ioe) { - assertThat(ioe).hasMessageThat().isEqualTo( - "Use JsonReader.setLenient(true) to accept malformed JSON" - + " at line 1 column 18 path $.name"); + assertThat(ioe).hasMessageThat().contains("malformed JSON"); + assertThat(ioe).hasMessageThat().contains("at line 1 column 18 path $.name"); } } diff --git a/benchmarks/build.gradle b/benchmarks/build.gradle index 5c9e1125f0b..88b26397e78 100644 --- a/benchmarks/build.gradle +++ b/benchmarks/build.gradle @@ -38,12 +38,15 @@ dependencies { classifier = "linux-x86_64" } } - compileOnly libraries.javax.annotation testImplementation libraries.junit, libraries.mockito.core - signature libraries.signature.java + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } } import net.ltgt.gradle.errorprone.CheckSeverity @@ -107,7 +110,9 @@ application { from(openloop_client) from(qps_server) from(benchmark_worker) - fileMode = 0755 + filePermissions { + unix(0755) + } } } diff --git a/benchmarks/src/generated/main/grpc/io/grpc/benchmarks/proto/BenchmarkServiceGrpc.java b/benchmarks/src/generated/main/grpc/io/grpc/benchmarks/proto/BenchmarkServiceGrpc.java index e62c2274ee9..68e911afc4a 100644 --- a/benchmarks/src/generated/main/grpc/io/grpc/benchmarks/proto/BenchmarkServiceGrpc.java +++ b/benchmarks/src/generated/main/grpc/io/grpc/benchmarks/proto/BenchmarkServiceGrpc.java @@ -4,9 +4,6 @@ /** */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/services.proto") @io.grpc.stub.annotations.GrpcGenerated public final class BenchmarkServiceGrpc { @@ -184,6 +181,21 @@ public BenchmarkServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions return BenchmarkServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static BenchmarkServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public BenchmarkServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new BenchmarkServiceBlockingV2Stub(channel, callOptions); + } + }; + return BenchmarkServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -367,6 +379,87 @@ public io.grpc.stub.StreamObserver { + private BenchmarkServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected BenchmarkServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new BenchmarkServiceBlockingV2Stub(channel, callOptions); + } + + /** + *

+     * One request followed by one response.
+     * The server returns the client payload as-is.
+     * 
+ */ + public io.grpc.benchmarks.proto.Messages.SimpleResponse unaryCall(io.grpc.benchmarks.proto.Messages.SimpleRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getUnaryCallMethod(), getCallOptions(), request); + } + + /** + *
+     * Repeated sequence of one request followed by one response.
+     * Should be called streaming ping-pong
+     * The server returns the client payload as-is on each response
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamingCall() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getStreamingCallMethod(), getCallOptions()); + } + + /** + *
+     * Single-sided unbounded streaming from client to server
+     * The server returns the client payload as-is once the client does WritesDone
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamingFromClient() { + return io.grpc.stub.ClientCalls.blockingClientStreamingCall( + getChannel(), getStreamingFromClientMethod(), getCallOptions()); + } + + /** + *
+     * Single-sided unbounded streaming from server to client
+     * The server repeatedly returns the client payload as-is
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamingFromServer(io.grpc.benchmarks.proto.Messages.SimpleRequest request) { + return io.grpc.stub.ClientCalls.blockingV2ServerStreamingCall( + getChannel(), getStreamingFromServerMethod(), getCallOptions(), request); + } + + /** + *
+     * Two-sided unbounded streaming between server to client
+     * Both sides send the content of their own choice to the other
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamingBothWays() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getStreamingBothWaysMethod(), getCallOptions()); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service BenchmarkService. + */ public static final class BenchmarkServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private BenchmarkServiceBlockingStub( diff --git a/benchmarks/src/generated/main/grpc/io/grpc/benchmarks/proto/ReportQpsScenarioServiceGrpc.java b/benchmarks/src/generated/main/grpc/io/grpc/benchmarks/proto/ReportQpsScenarioServiceGrpc.java index b24c3813c19..c5064875bb6 100644 --- a/benchmarks/src/generated/main/grpc/io/grpc/benchmarks/proto/ReportQpsScenarioServiceGrpc.java +++ b/benchmarks/src/generated/main/grpc/io/grpc/benchmarks/proto/ReportQpsScenarioServiceGrpc.java @@ -4,9 +4,6 @@ /** */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/services.proto") @io.grpc.stub.annotations.GrpcGenerated public final class ReportQpsScenarioServiceGrpc { @@ -60,6 +57,21 @@ public ReportQpsScenarioServiceStub newStub(io.grpc.Channel channel, io.grpc.Cal return ReportQpsScenarioServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static ReportQpsScenarioServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public ReportQpsScenarioServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ReportQpsScenarioServiceBlockingV2Stub(channel, callOptions); + } + }; + return ReportQpsScenarioServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -147,6 +159,33 @@ public void reportScenario(io.grpc.benchmarks.proto.Control.ScenarioResult reque /** * A stub to allow clients to do synchronous rpc calls to service ReportQpsScenarioService. */ + public static final class ReportQpsScenarioServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private ReportQpsScenarioServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected ReportQpsScenarioServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ReportQpsScenarioServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * Report results of a QPS test benchmark scenario.
+     * 
+ */ + public io.grpc.benchmarks.proto.Control.Void reportScenario(io.grpc.benchmarks.proto.Control.ScenarioResult request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getReportScenarioMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service ReportQpsScenarioService. + */ public static final class ReportQpsScenarioServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private ReportQpsScenarioServiceBlockingStub( diff --git a/benchmarks/src/generated/main/grpc/io/grpc/benchmarks/proto/WorkerServiceGrpc.java b/benchmarks/src/generated/main/grpc/io/grpc/benchmarks/proto/WorkerServiceGrpc.java index 0ee6797c8e3..721b4f9ab19 100644 --- a/benchmarks/src/generated/main/grpc/io/grpc/benchmarks/proto/WorkerServiceGrpc.java +++ b/benchmarks/src/generated/main/grpc/io/grpc/benchmarks/proto/WorkerServiceGrpc.java @@ -4,9 +4,6 @@ /** */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/services.proto") @io.grpc.stub.annotations.GrpcGenerated public final class WorkerServiceGrpc { @@ -153,6 +150,21 @@ public WorkerServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions ca return WorkerServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static WorkerServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public WorkerServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new WorkerServiceBlockingV2Stub(channel, callOptions); + } + }; + return WorkerServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -323,6 +335,77 @@ public void quitWorker(io.grpc.benchmarks.proto.Control.Void request, /** * A stub to allow clients to do synchronous rpc calls to service WorkerService. */ + public static final class WorkerServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private WorkerServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected WorkerServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new WorkerServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * Start server with specified workload.
+     * First request sent specifies the ServerConfig followed by ServerStatus
+     * response. After that, a "Mark" can be sent anytime to request the latest
+     * stats. Closing the stream will initiate shutdown of the test server
+     * and once the shutdown has finished, the OK status is sent to terminate
+     * this RPC.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + runServer() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getRunServerMethod(), getCallOptions()); + } + + /** + *
+     * Start client with specified workload.
+     * First request sent specifies the ClientConfig followed by ClientStatus
+     * response. After that, a "Mark" can be sent anytime to request the latest
+     * stats. Closing the stream will initiate shutdown of the test client
+     * and once the shutdown has finished, the OK status is sent to terminate
+     * this RPC.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + runClient() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getRunClientMethod(), getCallOptions()); + } + + /** + *
+     * Just return the core count - unary call
+     * 
+ */ + public io.grpc.benchmarks.proto.Control.CoreResponse coreCount(io.grpc.benchmarks.proto.Control.CoreRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getCoreCountMethod(), getCallOptions(), request); + } + + /** + *
+     * Quit this worker
+     * 
+ */ + public io.grpc.benchmarks.proto.Control.Void quitWorker(io.grpc.benchmarks.proto.Control.Void request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getQuitWorkerMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service WorkerService. + */ public static final class WorkerServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private WorkerServiceBlockingStub( diff --git a/benchmarks/src/main/java/io/grpc/benchmarks/Transport.java b/benchmarks/src/main/java/io/grpc/benchmarks/Transport.java index 820b3ac1968..fa21e03b6b8 100644 --- a/benchmarks/src/main/java/io/grpc/benchmarks/Transport.java +++ b/benchmarks/src/main/java/io/grpc/benchmarks/Transport.java @@ -16,6 +16,8 @@ package io.grpc.benchmarks; +import java.util.Locale; + /** * All of the supported transports. */ @@ -64,11 +66,16 @@ public static String getDescriptionString() { if (!first) { builder.append("\n"); } - builder.append(transport.name().toLowerCase()); + builder.append(transport); builder.append(": "); builder.append(transport.description); first = false; } return builder.toString(); } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } } diff --git a/benchmarks/src/main/java/io/grpc/benchmarks/qps/ClientConfiguration.java b/benchmarks/src/main/java/io/grpc/benchmarks/qps/ClientConfiguration.java index 3bafdb836ba..45061ffec49 100644 --- a/benchmarks/src/main/java/io/grpc/benchmarks/qps/ClientConfiguration.java +++ b/benchmarks/src/main/java/io/grpc/benchmarks/qps/ClientConfiguration.java @@ -31,6 +31,7 @@ import java.util.Collection; import java.util.Collections; import java.util.LinkedHashSet; +import java.util.Locale; import java.util.Set; /** @@ -102,7 +103,7 @@ protected ClientConfiguration build0(ClientConfiguration config) { if (config.tls) { if (!config.transport.tlsSupported) { throw new IllegalArgumentException( - "Transport " + config.transport.name().toLowerCase() + " does not support TLS."); + "Transport " + config.transport + " does not support TLS."); } } @@ -166,10 +167,10 @@ protected void setClientValue(ClientConfiguration config, String value) { config.testca = parseBoolean(value); } }, - TRANSPORT("STR", Transport.getDescriptionString(), DEFAULT.transport.name().toLowerCase()) { + TRANSPORT("STR", Transport.getDescriptionString(), DEFAULT.transport.toString()) { @Override protected void setClientValue(ClientConfiguration config, String value) { - config.transport = Transport.valueOf(value.toUpperCase()); + config.transport = Transport.valueOf(value.toUpperCase(Locale.ROOT)); } }, DURATION("SECONDS", "Duration of the benchmark.", "" + DEFAULT.duration) { @@ -236,7 +237,7 @@ protected void setClientValue(ClientConfiguration config, String value) { @Override public String getName() { - return name().toLowerCase(); + return name().toLowerCase(Locale.ROOT); } @Override diff --git a/benchmarks/src/main/java/io/grpc/benchmarks/qps/ServerConfiguration.java b/benchmarks/src/main/java/io/grpc/benchmarks/qps/ServerConfiguration.java index 915c1da75eb..eb0b45c85e0 100644 --- a/benchmarks/src/main/java/io/grpc/benchmarks/qps/ServerConfiguration.java +++ b/benchmarks/src/main/java/io/grpc/benchmarks/qps/ServerConfiguration.java @@ -29,6 +29,7 @@ import java.util.Collection; import java.util.Collections; import java.util.List; +import java.util.Locale; /** * Configuration options for benchmark servers. @@ -69,7 +70,7 @@ protected Collection getParams() { protected ServerConfiguration build0(ServerConfiguration config) { if (config.tls && !config.transport.tlsSupported) { throw new IllegalArgumentException( - "TLS unsupported with the " + config.transport.name().toLowerCase() + " transport"); + "TLS unsupported with the " + config.transport + " transport"); } // Verify that the address type is correct for the transport type. @@ -109,6 +110,11 @@ public enum Transport { this.socketAddressValidator = socketAddressValidator; } + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } + /** * Validates the given address for this transport. * @@ -128,7 +134,7 @@ static String getDescriptionString() { if (!first) { builder.append("\n"); } - builder.append(transport.name().toLowerCase()); + builder.append(transport); builder.append(": "); builder.append(transport.description); first = false; @@ -158,10 +164,10 @@ protected void setServerValue(ServerConfiguration config, String value) { config.tls = parseBoolean(value); } }, - TRANSPORT("STR", Transport.getDescriptionString(), DEFAULT.transport.name().toLowerCase()) { + TRANSPORT("STR", Transport.getDescriptionString(), DEFAULT.transport.toString()) { @Override protected void setServerValue(ServerConfiguration config, String value) { - config.transport = Transport.valueOf(value.toUpperCase()); + config.transport = Transport.valueOf(value.toUpperCase(Locale.ROOT)); } }, DIRECTEXECUTOR("", "Don't use a threadpool for RPC calls, instead execute calls directly " @@ -197,7 +203,7 @@ protected void setServerValue(ServerConfiguration config, String value) { @Override public String getName() { - return name().toLowerCase(); + return name().toLowerCase(Locale.ROOT); } @Override diff --git a/binder/build.gradle b/binder/build.gradle index 62613b00cb5..7e7d4810e98 100644 --- a/binder/build.gradle +++ b/binder/build.gradle @@ -6,32 +6,38 @@ plugins { description = 'gRPC BinderChannel' android { - namespace 'io.grpc.binder' + namespace = 'io.grpc.binder' compileSdkVersion 34 compileOptions { sourceCompatibility 1.8 targetCompatibility 1.8 } defaultConfig { - minSdkVersion 21 + minSdkVersion 23 targetSdkVersion 33 versionCode 1 versionName "1.0" testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" - multiDexEnabled true } - lintOptions { abortOnError false } + lintOptions { abortOnError = false } + buildTypes { + debug { + testCoverageEnabled true // For robolectric unit tests. + enableUnitTestCoverage true // For tests that run on an emulator. + } + } + publishing { singleVariant('release') { withSourcesJar() withJavadocJar() } } + testFixtures { enable = true } } repositories { google() - mavenCentral() } dependencies { @@ -55,6 +61,7 @@ dependencies { testImplementation project(':grpc-testing') testImplementation project(':grpc-inprocess') testImplementation testFixtures(project(':grpc-core')) + testImplementation testFixtures(project(':grpc-api')) androidTestAnnotationProcessor libraries.auto.value androidTestImplementation project(':grpc-testing') @@ -70,15 +77,21 @@ dependencies { androidTestImplementation libraries.androidx.lifecycle.service androidTestImplementation libraries.guava.testlib androidTestImplementation testFixtures(project(':grpc-core')) + + testFixturesImplementation libraries.guava.testlib + testFixturesImplementation testFixtures(project(':grpc-core')) } import net.ltgt.gradle.errorprone.CheckSeverity tasks.withType(JavaCompile).configureEach { options.compilerArgs += [ - "-Xlint:-cast" + "-Xlint:-cast", + // For junit-1.15-api & org.robolectric/shadows-framework/4.11.1 + "-Xlint:-classfile", + // Unclaimed annotations. TODO(jdcormie): Fix? + "-Xlint:-processing", ] - options.compilerArgs -= ["-Werror"] // https://github.com/grpc/grpc-java/issues/10297 appendToProperty(it.options.errorprone.excludedPaths, ".*/R.java", "|") } @@ -123,3 +136,36 @@ publishing { } } } + +afterEvaluate { + components.release.withVariantsFromConfiguration(configurations.releaseTestFixturesVariantReleaseApiPublication) { skip() } + components.release.withVariantsFromConfiguration(configurations.releaseTestFixturesVariantReleaseRuntimePublication) { skip() } +} + +tasks.withType(Test) { + // Robolectric modifies classes in memory at runtime, so they lack a java.security.CodeSource + // URL to their on-disk location. By default, JaCoCo ignores classes without this property. + // Overriding this allows Robolectric tests to be instrumented. + jacoco.includeNoLocationClasses = true + // Don't instrument certain JDK internals protected from modification by JEP 403's "strong + // encapsulation." Avoids IllegalAccessError, InvalidClassException and similar at runtime. + jacoco.excludes = ["jdk.internal.**"] +} + +// Android projects don't automatically get a coverage report task. We must +// register one manually here and wire it up to AGP's test tasks. +tasks.register("jacocoTestReport", JacocoReport) { + dependsOn "testDebugUnitTest" + + reports { + // For codecov.io and coveralls. + xml.required = true + // Use the same output location as the other subprojects. + html.outputLocation = layout.buildDirectory.dir("reports/jacoco/test/html") + } + + sourceDirectories.from = android.sourceSets.main.java.srcDirs + classDirectories.from = fileTree(dir: layout.buildDirectory.dir("intermediates/javac/debug/classes"), + excludes: ['**/R.class', '**/R$*.class', '**/BuildConfig.class', '**/Manifest*.*', '**/*Test*.*', 'android/**/*.*']) + executionData.from = tasks.named("testDebugUnitTest").map { it.jacoco.destinationFile } +} diff --git a/binder/src/androidTest/AndroidManifest.xml b/binder/src/androidTest/AndroidManifest.xml index b6d71574410..44f21e104d9 100644 --- a/binder/src/androidTest/AndroidManifest.xml +++ b/binder/src/androidTest/AndroidManifest.xml @@ -11,11 +11,13 @@ + + diff --git a/binder/src/androidTest/java/io/grpc/binder/BinderChannelSmokeTest.java b/binder/src/androidTest/java/io/grpc/binder/BinderChannelSmokeTest.java index 985d5188a1c..4e3cfcf0d05 100644 --- a/binder/src/androidTest/java/io/grpc/binder/BinderChannelSmokeTest.java +++ b/binder/src/androidTest/java/io/grpc/binder/BinderChannelSmokeTest.java @@ -23,6 +23,7 @@ import android.content.Context; import android.content.Intent; +import android.net.Uri; import android.os.Parcel; import android.os.Parcelable; import androidx.test.core.app.ApplicationProvider; @@ -39,7 +40,6 @@ import io.grpc.ManagedChannel; import io.grpc.Metadata; import io.grpc.MethodDescriptor; -import io.grpc.NameResolverRegistry; import io.grpc.ServerCall; import io.grpc.ServerCall.Listener; import io.grpc.ServerCallHandler; @@ -49,7 +49,6 @@ import io.grpc.Status.Code; import io.grpc.StatusRuntimeException; import io.grpc.internal.GrpcUtil; -import io.grpc.internal.testing.FakeNameResolverProvider; import io.grpc.stub.ClientCalls; import io.grpc.stub.MetadataUtils; import io.grpc.stub.ServerCalls; @@ -77,9 +76,8 @@ public final class BinderChannelSmokeTest { private static final int SLIGHTLY_MORE_THAN_ONE_BLOCK = 16 * 1024 + 100; private static final String MSG = "Some text which will be repeated many many times"; - private static final String SERVER_TARGET_URI = "fake://server"; - private static final Metadata.Key POISON_KEY = ParcelableUtils.metadataKey( - "poison-bin", PoisonParcelable.CREATOR); + private static final Metadata.Key POISON_KEY = + ParcelableUtils.metadataKey("poison-bin", PoisonParcelable.CREATOR); final MethodDescriptor method = MethodDescriptor.newBuilder(StringMarshaller.INSTANCE, StringMarshaller.INSTANCE) @@ -99,7 +97,7 @@ public final class BinderChannelSmokeTest { .setType(MethodDescriptor.MethodType.BIDI_STREAMING) .build(); - FakeNameResolverProvider fakeNameResolverProvider; + AndroidComponentAddress serverAddress; ManagedChannel channel; AtomicReference headersCapture = new AtomicReference<>(); AtomicReference clientUidCapture = new AtomicReference<>(); @@ -137,31 +135,35 @@ public void setUp() throws Exception { TestUtils.recordRequestHeadersInterceptor(headersCapture), PeerUids.newPeerIdentifyingServerInterceptor()); - AndroidComponentAddress serverAddress = HostServices.allocateService(appContext); - fakeNameResolverProvider = new FakeNameResolverProvider(SERVER_TARGET_URI, serverAddress); - NameResolverRegistry.getDefaultRegistry().register(fakeNameResolverProvider); - HostServices.configureService(serverAddress, + serverAddress = HostServices.allocateService(appContext); + HostServices.configureService( + serverAddress, HostServices.serviceParamsBuilder() - .setServerFactory((service, receiver) -> - BinderServerBuilder.forAddress(serverAddress, receiver) - .inboundParcelablePolicy(InboundParcelablePolicy.newBuilder() - .setAcceptParcelableMetadataValues(true) + .setServerFactory( + (service, receiver) -> + BinderServerBuilder.forAddress(serverAddress, receiver) + .inboundParcelablePolicy( + InboundParcelablePolicy.newBuilder() + .setAcceptParcelableMetadataValues(true) + .build()) + .addService(serviceDef) .build()) - .addService(serviceDef) - .build()) .build()); - channel = BinderChannelBuilder.forAddress(serverAddress, appContext) - .inboundParcelablePolicy(InboundParcelablePolicy.newBuilder() - .setAcceptParcelableMetadataValues(true) - .build()) - .build(); + channel = newBinderChannelBuilder().build(); + } + + BinderChannelBuilder newBinderChannelBuilder() { + return BinderChannelBuilder.forAddress(serverAddress, appContext) + .inboundParcelablePolicy( + InboundParcelablePolicy.newBuilder() + .setAcceptParcelableMetadataValues(true) + .build()); } @After public void tearDown() throws Exception { channel.shutdownNow(); - NameResolverRegistry.getDefaultRegistry().deregister(fakeNameResolverProvider); HostServices.awaitServiceShutdown(); } @@ -186,6 +188,18 @@ public void testBasicCall() throws Exception { assertThat(doCall("Hello").get()).isEqualTo("Hello"); } + @Test + public void testBasicCallWithLegacyAuthStrategy() throws Exception { + channel = newBinderChannelBuilder().useLegacyAuthStrategy().build(); + assertThat(doCall("Hello").get()).isEqualTo("Hello"); + } + + @Test + public void testBasicCallWithV2AuthStrategy() throws Exception { + channel = newBinderChannelBuilder().useV2AuthStrategy().build(); + assertThat(doCall("Hello").get()).isEqualTo("Hello"); + } + @Test public void testPeerUidIsRecorded() throws Exception { assertThat(doCall("Hello").get()).isEqualTo("Hello"); @@ -230,7 +244,11 @@ public void testStreamingCallOptionHeaders() throws Exception { @Test public void testConnectViaTargetUri() throws Exception { - channel = BinderChannelBuilder.forTarget(SERVER_TARGET_URI, appContext).build(); + // Compare with the mapping in AndroidManifest.xml. + channel = + BinderChannelBuilder.forTarget( + "intent://authority/path#Intent;action=action1;scheme=scheme;end;", appContext) + .build(); assertThat(doCall("Hello").get()).isEqualTo("Hello"); } @@ -240,7 +258,10 @@ public void testConnectViaIntentFilter() throws Exception { channel = BinderChannelBuilder.forAddress( AndroidComponentAddress.forBindIntent( - new Intent().setAction("action1").setPackage(appContext.getPackageName())), + new Intent() + .setAction("action1") + .setData(Uri.parse("scheme://authority/path")) + .setPackage(appContext.getPackageName())), appContext) .build(); assertThat(doCall("Hello").get()).isEqualTo("Hello"); @@ -253,8 +274,8 @@ public void testUncaughtServerException() throws Exception { Metadata extraHeadersToSend = new Metadata(); extraHeadersToSend.put(POISON_KEY, bad); Channel interceptedChannel = - ClientInterceptors.intercept(channel, - MetadataUtils.newAttachHeadersInterceptor(extraHeadersToSend)); + ClientInterceptors.intercept( + channel, MetadataUtils.newAttachHeadersInterceptor(extraHeadersToSend)); CallOptions callOptions = CallOptions.DEFAULT.withDeadlineAfter(5, SECONDS); try { ClientCalls.blockingUnaryCall(interceptedChannel, method, callOptions, "hello"); @@ -361,33 +382,36 @@ public void onCompleted() { class AddParcelableServerInterceptor implements ServerInterceptor { @Override - public Listener interceptCall(ServerCall call, - Metadata headers, ServerCallHandler next) { - return next.startCall(new SimpleForwardingServerCall(call) { - @Override - public void sendHeaders(Metadata headers) { - if (parcelableForResponseHeaders != null) { - headers.put(POISON_KEY, parcelableForResponseHeaders); - } - super.sendHeaders(headers); - } - }, headers); + public Listener interceptCall( + ServerCall call, Metadata headers, ServerCallHandler next) { + return next.startCall( + new SimpleForwardingServerCall(call) { + @Override + public void sendHeaders(Metadata headers) { + if (parcelableForResponseHeaders != null) { + headers.put(POISON_KEY, parcelableForResponseHeaders); + } + super.sendHeaders(headers); + } + }, + headers); } } static class PoisonParcelable implements Parcelable { - public static final Creator CREATOR = new Parcelable.Creator() { - @Override - public PoisonParcelable createFromParcel(Parcel parcel) { - throw new RuntimeException("ouch"); - } + public static final Creator CREATOR = + new Parcelable.Creator() { + @Override + public PoisonParcelable createFromParcel(Parcel parcel) { + throw new RuntimeException("ouch"); + } - @Override - public PoisonParcelable[] newArray(int n) { - return new PoisonParcelable[n]; - } - }; + @Override + public PoisonParcelable[] newArray(int n) { + return new PoisonParcelable[n]; + } + }; @Override public int describeContents() { @@ -395,7 +419,6 @@ public int describeContents() { } @Override - public void writeToParcel(Parcel parcel, int flags) { - } + public void writeToParcel(Parcel parcel, int flags) {} } } diff --git a/binder/src/androidTest/java/io/grpc/binder/BinderSecurityTest.java b/binder/src/androidTest/java/io/grpc/binder/BinderSecurityTest.java index 1c770110fb6..35ed379556e 100644 --- a/binder/src/androidTest/java/io/grpc/binder/BinderSecurityTest.java +++ b/binder/src/androidTest/java/io/grpc/binder/BinderSecurityTest.java @@ -47,7 +47,6 @@ import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicReference; - import javax.annotation.Nullable; import org.junit.After; import org.junit.Before; @@ -111,15 +110,14 @@ private void createChannel() throws Exception { private void createChannel(ServerSecurityPolicy serverPolicy, SecurityPolicy channelPolicy) throws Exception { AndroidComponentAddress addr = HostServices.allocateService(appContext); - HostServices.configureService(addr, + HostServices.configureService( + addr, HostServices.serviceParamsBuilder() - .setServerFactory((service, receiver) -> buildServer(addr, receiver, serverPolicy)) - .build()); + .setServerFactory((service, receiver) -> buildServer(addr, receiver, serverPolicy)) + .build()); channel = - BinderChannelBuilder.forAddress(addr, appContext) - .securityPolicy(channelPolicy) - .build(); + BinderChannelBuilder.forAddress(addr, appContext).securityPolicy(channelPolicy).build(); } private Server buildServer( @@ -149,7 +147,7 @@ private StatusRuntimeException assertCallFailure( try { ClientCalls.blockingUnaryCall(channel, method, CallOptions.DEFAULT, null); fail("Expected call to " + method.getFullMethodName() + " to fail but it succeeded."); - throw new AssertionError(); // impossible + throw new AssertionError(); // impossible } catch (StatusRuntimeException sre) { assertThat(sre.getStatus().getCode()).isEqualTo(status.getCode()); return sre; @@ -186,12 +184,14 @@ public void testFailedFuturesPropagateOriginalException() throws Exception { IllegalStateException originalException = new IllegalStateException(errorMessage); createChannel( ServerSecurityPolicy.newBuilder() - .servicePolicy("foo", new AsyncSecurityPolicy() { - @Override - public ListenableFuture checkAuthorizationAsync(int uid) { - return Futures.immediateFailedFuture(originalException); - } - }) + .servicePolicy( + "foo", + new AsyncSecurityPolicy() { + @Override + public ListenableFuture checkAuthorizationAsync(int uid) { + return Futures.immediateFailedFuture(originalException); + } + }) .build(), SecurityPolicies.internalOnly()); MethodDescriptor method = methods.get("foo/method0"); @@ -205,15 +205,17 @@ public void testFailedFuturesAreNotCachedPermanently() throws Exception { AtomicReference firstAttempt = new AtomicReference<>(true); createChannel( ServerSecurityPolicy.newBuilder() - .servicePolicy("foo", new AsyncSecurityPolicy() { - @Override - public ListenableFuture checkAuthorizationAsync(int uid) { - if (firstAttempt.getAndSet(false)) { - return Futures.immediateFailedFuture(new IllegalStateException()); - } - return Futures.immediateFuture(Status.OK); - } - }) + .servicePolicy( + "foo", + new AsyncSecurityPolicy() { + @Override + public ListenableFuture checkAuthorizationAsync(int uid) { + if (firstAttempt.getAndSet(false)) { + return Futures.immediateFailedFuture(new IllegalStateException()); + } + return Futures.immediateFuture(Status.OK); + } + }) .build(), SecurityPolicies.internalOnly()); MethodDescriptor method = methods.get("foo/method0"); @@ -227,15 +229,17 @@ public void testCancelledFuturesAreNotCachedPermanently() throws Exception { AtomicReference firstAttempt = new AtomicReference<>(true); createChannel( ServerSecurityPolicy.newBuilder() - .servicePolicy("foo", new AsyncSecurityPolicy() { - @Override - public ListenableFuture checkAuthorizationAsync(int uid) { - if (firstAttempt.getAndSet(false)) { - return Futures.immediateCancelledFuture(); - } - return Futures.immediateFuture(Status.OK); - } - }) + .servicePolicy( + "foo", + new AsyncSecurityPolicy() { + @Override + public ListenableFuture checkAuthorizationAsync(int uid) { + if (firstAttempt.getAndSet(false)) { + return Futures.immediateCancelledFuture(); + } + return Futures.immediateFuture(Status.OK); + } + }) .build(), SecurityPolicies.internalOnly()); MethodDescriptor method = methods.get("foo/method0"); @@ -275,11 +279,11 @@ public void testPerServicePolicy() throws Exception { @Test public void testPerServicePolicyAsync() throws Exception { createChannel( - ServerSecurityPolicy.newBuilder() - .servicePolicy("foo", asyncPolicy((uid) -> Futures.immediateFuture(true))) - .servicePolicy("bar", asyncPolicy((uid) -> Futures.immediateFuture(false))) - .build(), - SecurityPolicies.internalOnly()); + ServerSecurityPolicy.newBuilder() + .servicePolicy("foo", asyncPolicy((uid) -> Futures.immediateFuture(true))) + .servicePolicy("bar", asyncPolicy((uid) -> Futures.immediateFuture(false))) + .build(), + SecurityPolicies.internalOnly()); assertThat(methods).isNotEmpty(); for (MethodDescriptor method : methods.values()) { @@ -326,11 +330,10 @@ private static AsyncSecurityPolicy asyncPolicy( return new AsyncSecurityPolicy() { @Override public ListenableFuture checkAuthorizationAsync(int uid) { - return Futures - .transform( - func.apply(uid), - allowed -> allowed ? Status.OK : Status.PERMISSION_DENIED, - MoreExecutors.directExecutor()); + return Futures.transform( + func.apply(uid), + allowed -> allowed ? Status.OK : Status.PERMISSION_DENIED, + MoreExecutors.directExecutor()); } }; } @@ -340,9 +343,7 @@ private final class CountingServerInterceptor implements ServerInterceptor { @Override public ServerCall.Listener interceptCall( - ServerCall call, - Metadata headers, - ServerCallHandler next) { + ServerCall call, Metadata headers, ServerCallHandler next) { numInterceptedCalls += 1; return next.startCall(call, headers); } diff --git a/binder/src/androidTest/java/io/grpc/binder/HostServices.java b/binder/src/androidTest/java/io/grpc/binder/HostServices.java index 92b232f1ff0..5d4a06a27fe 100644 --- a/binder/src/androidTest/java/io/grpc/binder/HostServices.java +++ b/binder/src/androidTest/java/io/grpc/binder/HostServices.java @@ -16,7 +16,6 @@ package io.grpc.binder; -import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; import static java.util.concurrent.TimeUnit.SECONDS; @@ -30,24 +29,16 @@ import androidx.lifecycle.LifecycleService; import com.google.auto.value.AutoValue; import com.google.common.base.Supplier; -import com.google.common.collect.ImmutableList; -import io.grpc.NameResolver; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Server; -import io.grpc.ServerServiceDefinition; -import io.grpc.ServerStreamTracer; -import io.grpc.binder.AndroidComponentAddress; -import io.grpc.internal.InternalServer; import java.io.IOException; import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; -import java.util.concurrent.ScheduledExecutorService; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; /** * A test helper class for creating android services to host gRPC servers. @@ -64,7 +55,6 @@ public final class HostServices { HostService1.class, HostService2.class, }; - public interface ServerFactory { Server createServer(Service service, IBinderReceiver receiver); } diff --git a/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java b/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java index 7710924d8c7..aa3fb573ab5 100644 --- a/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java +++ b/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java @@ -17,16 +17,18 @@ package io.grpc.binder.internal; import static com.google.common.truth.Truth.assertThat; +import static java.util.concurrent.TimeUnit.SECONDS; import android.content.Context; import android.os.DeadObjectException; import android.os.Parcel; import android.os.RemoteException; -import androidx.core.content.ContextCompat; import androidx.test.core.app.ApplicationProvider; import androidx.test.ext.junit.runners.AndroidJUnit4; +import com.google.common.util.concurrent.SettableFuture; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import com.google.errorprone.annotations.concurrent.GuardedBy; import com.google.protobuf.Empty; -import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.ClientStreamTracer; import io.grpc.Metadata; @@ -36,17 +38,18 @@ import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.binder.AndroidComponentAddress; -import io.grpc.binder.BindServiceFlags; -import io.grpc.binder.BinderChannelCredentials; import io.grpc.binder.BinderServerBuilder; import io.grpc.binder.HostServices; -import io.grpc.binder.InboundParcelablePolicy; -import io.grpc.binder.SecurityPolicies; import io.grpc.binder.SecurityPolicy; +import io.grpc.binder.internal.FakeDeadBinder; +import io.grpc.binder.internal.OneWayBinderProxies.BlackHoleOneWayBinderProxy; import io.grpc.binder.internal.OneWayBinderProxies.BlockingBinderDecorator; import io.grpc.binder.internal.OneWayBinderProxies.ThrowingOneWayBinderProxy; +import io.grpc.binder.internal.SettableAsyncSecurityPolicy.AuthRequest; import io.grpc.internal.ClientStream; import io.grpc.internal.ClientStreamListener; +import io.grpc.internal.ClientTransportFactory.ClientTransportOptions; +import io.grpc.internal.DisconnectError; import io.grpc.internal.FixedObjectPool; import io.grpc.internal.ManagedClientTransport; import io.grpc.internal.ObjectPool; @@ -58,11 +61,11 @@ import java.util.ArrayDeque; import java.util.Deque; import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ScheduledExecutorService; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -76,9 +79,10 @@ */ @RunWith(AndroidJUnit4.class) public final class BinderClientTransportTest { - private static final ClientStreamTracer[] tracers = new ClientStreamTracer[] { - new ClientStreamTracer() {} - }; + private static final long TIMEOUT_SECONDS = 5; + + private static final ClientStreamTracer[] tracers = + new ClientStreamTracer[] {new ClientStreamTracer() {}}; private final Context appContext = ApplicationProvider.getApplicationContext(); @@ -98,10 +102,13 @@ public final class BinderClientTransportTest { .build(); AndroidComponentAddress serverAddress; - BinderTransport.BinderClientTransport transport; + BinderClientTransport transport; + BlockingSecurityPolicy blockingSecurityPolicy = new BlockingSecurityPolicy(); private final ObjectPool executorServicePool = new FixedObjectPool<>(Executors.newScheduledThreadPool(1)); + private final ObjectPool offloadServicePool = + new FixedObjectPool<>(Executors.newScheduledThreadPool(1)); private final TestTransportListener transportListener = new TestTransportListener(); private final TestStreamListener streamListener = new TestStreamListener(); @@ -142,50 +149,67 @@ public void setUp() throws Exception { } private class BinderClientTransportBuilder { - private SecurityPolicy securityPolicy = SecurityPolicies.internalOnly(); - private OneWayBinderProxy.Decorator binderDecorator = OneWayBinderProxy.IDENTITY_DECORATOR; + final BinderClientTransportFactory.Builder factoryBuilder = + new BinderClientTransportFactory.Builder() + .setSourceContext(appContext) + .setScheduledExecutorPool(executorServicePool) + .setOffloadExecutorPool(offloadServicePool); + @CanIgnoreReturnValue public BinderClientTransportBuilder setSecurityPolicy(SecurityPolicy securityPolicy) { - this.securityPolicy = securityPolicy; + factoryBuilder.setSecurityPolicy(securityPolicy); return this; } + @CanIgnoreReturnValue public BinderClientTransportBuilder setBinderDecorator( OneWayBinderProxy.Decorator binderDecorator) { - this.binderDecorator = binderDecorator; + factoryBuilder.setBinderDecorator(binderDecorator); + return this; + } + + @CanIgnoreReturnValue + public BinderClientTransportBuilder setReadyTimeoutMillis(int timeoutMillis) { + factoryBuilder.setReadyTimeoutMillis(timeoutMillis); return this; } - public BinderTransport.BinderClientTransport build() { - return new BinderTransport.BinderClientTransport( - appContext, - BinderChannelCredentials.forDefault(), - serverAddress, - null, - BindServiceFlags.DEFAULTS, - ContextCompat.getMainExecutor(appContext), - executorServicePool, - executorServicePool, - securityPolicy, - InboundParcelablePolicy.DEFAULT, - binderDecorator, - Attributes.EMPTY); + @CanIgnoreReturnValue + public BinderClientTransportBuilder setPreAuthorizeServer(boolean preAuthorizeServer) { + factoryBuilder.setPreAuthorizeServers(preAuthorizeServer); + return this; + } + + public BinderClientTransport build() { + return factoryBuilder + .buildClientTransportFactory() + .newClientTransport(serverAddress, new ClientTransportOptions(), null); } } @After public void tearDown() throws Exception { + blockingSecurityPolicy.provideNextCheckAuthorizationResult(Status.ABORTED); transport.shutdownNow(Status.OK); HostServices.awaitServiceShutdown(); - executorServicePool.getObject().shutdownNow(); + shutdownAndTerminate(executorServicePool.getObject()); + shutdownAndTerminate(offloadServicePool.getObject()); + } + + private static void shutdownAndTerminate(ExecutorService executorService) + throws InterruptedException { + executorService.shutdownNow(); + if (!executorService.awaitTermination(TIMEOUT_SECONDS, SECONDS)) { + throw new AssertionError("executor failed to terminate promptly"); + } } @Test public void testShutdownBeforeStreamStart_b153326034() throws Exception { transport = new BinderClientTransportBuilder().build(); startAndAwaitReady(transport, transportListener); - ClientStream stream = transport.newStream( - methodDesc, new Metadata(), CallOptions.DEFAULT, tracers); + ClientStream stream = + transport.newStream(methodDesc, new Metadata(), CallOptions.DEFAULT, tracers); transport.shutdownNow(Status.UNKNOWN.withDescription("reasons")); // This shouldn't throw an exception. @@ -269,10 +293,10 @@ public void testMessageProducerClosedAfterStream_b169313545() throws Exception { } @Test - public void testNewStreamBeforeTransportReadyFails() throws InterruptedException { + public void testNewStreamBeforeTransportReadyFails() throws Exception { // Use a special SecurityPolicy that lets us act before the transport is setup/ready. - BlockingSecurityPolicy bsp = new BlockingSecurityPolicy(); - transport = new BinderClientTransportBuilder().setSecurityPolicy(bsp).build(); + transport = + new BinderClientTransportBuilder().setSecurityPolicy(blockingSecurityPolicy).build(); transport.start(transportListener).run(); ClientStream stream = transport.newStream(streamingMethodDesc, new Metadata(), CallOptions.DEFAULT, tracers); @@ -280,19 +304,17 @@ public void testNewStreamBeforeTransportReadyFails() throws InterruptedException assertThat(streamListener.awaitClose().getCode()).isEqualTo(Code.INTERNAL); // Unblock the SETUP_TRANSPORT handshake and make sure it becomes ready in the usual way. - bsp.provideNextCheckAuthorizationResult(Status.OK); + blockingSecurityPolicy.provideNextCheckAuthorizationResult(Status.OK); transportListener.awaitReady(); } @Test - public void testTxnFailureDuringSetup() throws InterruptedException { + public void testTxnFailureDuringSetup() throws Exception { BlockingBinderDecorator decorator = new BlockingBinderDecorator<>(); - transport = new BinderClientTransportBuilder() - .setBinderDecorator(decorator) - .build(); + transport = new BinderClientTransportBuilder().setBinderDecorator(decorator).build(); transport.start(transportListener).run(); - ThrowingOneWayBinderProxy endpointBinder = new ThrowingOneWayBinderProxy( - decorator.takeNextRequest()); + ThrowingOneWayBinderProxy endpointBinder = + new ThrowingOneWayBinderProxy(decorator.takeNextRequest()); DeadObjectException doe = new DeadObjectException("ouch"); endpointBinder.setRemoteException(doe); decorator.putNextResult(endpointBinder); @@ -312,17 +334,15 @@ public void testTxnFailureDuringSetup() throws InterruptedException { } @Test - public void testTxnFailurePostSetup() throws InterruptedException { + public void testTxnFailurePostSetup() throws Exception { BlockingBinderDecorator decorator = new BlockingBinderDecorator<>(); - transport = new BinderClientTransportBuilder() - .setBinderDecorator(decorator) - .build(); + transport = new BinderClientTransportBuilder().setBinderDecorator(decorator).build(); transport.start(transportListener).run(); - ThrowingOneWayBinderProxy endpointBinder = new ThrowingOneWayBinderProxy( - decorator.takeNextRequest()); + ThrowingOneWayBinderProxy endpointBinder = + new ThrowingOneWayBinderProxy(decorator.takeNextRequest()); decorator.putNextResult(endpointBinder); - ThrowingOneWayBinderProxy serverBinder = new ThrowingOneWayBinderProxy( - decorator.takeNextRequest()); + ThrowingOneWayBinderProxy serverBinder = + new ThrowingOneWayBinderProxy(decorator.takeNextRequest()); DeadObjectException doe = new DeadObjectException("ouch"); serverBinder.setRemoteException(doe); decorator.putNextResult(serverBinder); @@ -340,59 +360,206 @@ public void testTxnFailurePostSetup() throws InterruptedException { assertThat(streamStatus.getCause()).isSameInstanceAs(doe); } + @Test + public void testServerBinderDeadOnArrival() throws Exception { + BlockingBinderDecorator decorator = new BlockingBinderDecorator<>(); + transport = new BinderClientTransportBuilder().setBinderDecorator(decorator).build(); + transport.start(transportListener).run(); + decorator.putNextResult(decorator.takeNextRequest()); // Server's "Endpoint" Binder. + OneWayBinderProxy unusedServerBinder = decorator.takeNextRequest(); + decorator.putNextResult( + OneWayBinderProxy.wrap(new FakeDeadBinder(), offloadServicePool.getObject())); + Status clientStatus = transportListener.awaitShutdown(); + assertThat(clientStatus.getCode()).isEqualTo(Code.UNAVAILABLE); + assertThat(clientStatus.getDescription()).contains("Failed to observe outgoing binder"); + } + + @Test + public void testBlackHoleEndpointConnectTimeout() throws Exception { + BlockingBinderDecorator decorator = new BlockingBinderDecorator<>(); + transport = + new BinderClientTransportBuilder() + .setBinderDecorator(decorator) + .setReadyTimeoutMillis(1_234) + .build(); + transport.start(transportListener).run(); + BlackHoleOneWayBinderProxy endpointBinder = + new BlackHoleOneWayBinderProxy(decorator.takeNextRequest()); + endpointBinder.dropAllTransactions(true); + decorator.putNextResult(endpointBinder); + Status transportStatus = transportListener.awaitShutdown(); + assertThat(transportStatus.getCode()).isEqualTo(Code.DEADLINE_EXCEEDED); + assertThat(transportStatus.getDescription()).contains("1234"); + transportListener.awaitTermination(); + } + + @Test + public void testBlackHoleSecurityPolicyAuthTimeout() throws Exception { + SettableAsyncSecurityPolicy securityPolicy = new SettableAsyncSecurityPolicy(); + transport = + new BinderClientTransportBuilder() + .setSecurityPolicy(securityPolicy) + .setPreAuthorizeServer(false) + .setReadyTimeoutMillis(1_234) + .build(); + transport.start(transportListener).run(); + // Take the next authRequest but don't respond to it, in order to trigger the ready timeout. + AuthRequest authRequest = securityPolicy.takeNextAuthRequest(TIMEOUT_SECONDS, SECONDS); + + Status transportStatus = transportListener.awaitShutdown(); + assertThat(transportStatus.getCode()).isEqualTo(Code.DEADLINE_EXCEEDED); + assertThat(transportStatus.getDescription()).contains("1234"); + transportListener.awaitTermination(); + // If the transport gave up waiting on auth, it should cancel its request. + assertThat(authRequest.isCancelled()).isTrue(); + } + + @Test + public void testBlackHoleSecurityPolicyPreAuthTimeout() throws Exception { + SettableAsyncSecurityPolicy securityPolicy = new SettableAsyncSecurityPolicy(); + transport = + new BinderClientTransportBuilder() + .setSecurityPolicy(securityPolicy) + .setPreAuthorizeServer(true) + .setReadyTimeoutMillis(1_234) + .build(); + transport.start(transportListener).run(); + // Take the next authRequest but don't respond to it, in order to trigger the ready timeout. + AuthRequest preAuthRequest = securityPolicy.takeNextAuthRequest(TIMEOUT_SECONDS, SECONDS); + + Status transportStatus = transportListener.awaitShutdown(); + assertThat(transportStatus.getCode()).isEqualTo(Code.DEADLINE_EXCEEDED); + assertThat(transportStatus.getDescription()).contains("1234"); + transportListener.awaitTermination(); + // If the transport gave up waiting on auth, it should cancel its request. + assertThat(preAuthRequest.isCancelled()).isTrue(); + } + + @Test + public void testAsyncSecurityPolicyAuthFailure() throws Exception { + SettableAsyncSecurityPolicy securityPolicy = new SettableAsyncSecurityPolicy(); + transport = + new BinderClientTransportBuilder() + .setPreAuthorizeServer(false) + .setSecurityPolicy(securityPolicy) + .build(); + RuntimeException exception = new NullPointerException(); + transport.start(transportListener).run(); + securityPolicy.takeNextAuthRequest(TIMEOUT_SECONDS, SECONDS).setResult(exception); + Status transportStatus = transportListener.awaitShutdown(); + assertThat(transportStatus.getCode()).isEqualTo(Code.INTERNAL); + assertThat(transportStatus.getCause()).isEqualTo(exception); + transportListener.awaitTermination(); + } + + @Test + public void testAsyncSecurityPolicyPreAuthFailure() throws Exception { + SettableAsyncSecurityPolicy securityPolicy = new SettableAsyncSecurityPolicy(); + transport = + new BinderClientTransportBuilder() + .setPreAuthorizeServer(true) + .setSecurityPolicy(securityPolicy) + .build(); + RuntimeException exception = new NullPointerException(); + transport.start(transportListener).run(); + securityPolicy.takeNextAuthRequest(TIMEOUT_SECONDS, SECONDS).setResult(exception); + Status transportStatus = transportListener.awaitShutdown(); + assertThat(transportStatus.getCode()).isEqualTo(Code.INTERNAL); + assertThat(transportStatus.getCause()).isEqualTo(exception); + transportListener.awaitTermination(); + } + + @Test + public void testAsyncSecurityPolicyAuthSuccess() throws Exception { + SettableAsyncSecurityPolicy securityPolicy = new SettableAsyncSecurityPolicy(); + transport = + new BinderClientTransportBuilder() + .setPreAuthorizeServer(false) + .setSecurityPolicy(securityPolicy) + .build(); + transport.start(transportListener).run(); + securityPolicy + .takeNextAuthRequest(TIMEOUT_SECONDS, SECONDS) + .setResult(Status.PERMISSION_DENIED.withDescription("xyzzy")); + Status transportStatus = transportListener.awaitShutdown(); + assertThat(transportStatus.getCode()).isEqualTo(Code.PERMISSION_DENIED); + assertThat(transportStatus.getDescription()).contains("xyzzy"); + transportListener.awaitTermination(); + } + + @Test + public void testAsyncSecurityPolicyPreAuthSuccess() throws Exception { + SettableAsyncSecurityPolicy securityPolicy = new SettableAsyncSecurityPolicy(); + transport = + new BinderClientTransportBuilder() + .setPreAuthorizeServer(true) + .setSecurityPolicy(securityPolicy) + .build(); + transport.start(transportListener).run(); + securityPolicy + .takeNextAuthRequest(TIMEOUT_SECONDS, SECONDS) + .setResult(Status.PERMISSION_DENIED.withDescription("xyzzy")); + Status transportStatus = transportListener.awaitShutdown(); + assertThat(transportStatus.getCode()).isEqualTo(Code.PERMISSION_DENIED); + assertThat(transportStatus.getDescription()).contains("xyzzy"); + transportListener.awaitTermination(); + } + + @Test + public void testAsyncSecurityPolicyCancelledUponExternalTermination() throws Exception { + SettableAsyncSecurityPolicy securityPolicy = new SettableAsyncSecurityPolicy(); + transport = new BinderClientTransportBuilder().setSecurityPolicy(securityPolicy).build(); + transport.start(transportListener).run(); + AuthRequest authRequest = securityPolicy.takeNextAuthRequest(TIMEOUT_SECONDS, SECONDS); + transport.shutdownNow(Status.UNAVAILABLE); // 'authRequest' remains unanswered! + transportListener.awaitShutdown(); + transportListener.awaitTermination(); + assertThat(authRequest.isCancelled()).isTrue(); + } + private static void startAndAwaitReady( - BinderTransport.BinderClientTransport transport, TestTransportListener transportListener) { + BinderClientTransport transport, TestTransportListener transportListener) throws Exception { transport.start(transportListener).run(); transportListener.awaitReady(); } private static final class TestTransportListener implements ManagedClientTransport.Listener { - @GuardedBy("this") - private boolean ready; - public boolean inUse; - @Nullable public Status shutdownStatus; - public boolean terminated; + private final SettableFuture isReady = SettableFuture.create(); + private final SettableFuture shutdownStatus = SettableFuture.create(); + private final SettableFuture isTerminated = SettableFuture.create(); @Override - public synchronized void transportShutdown(Status shutdownStatus) { - this.shutdownStatus = shutdownStatus; - notifyAll(); + public void transportShutdown(Status shutdownStatus, DisconnectError disconnectError) { + if (!this.shutdownStatus.set(shutdownStatus)) { + throw new IllegalStateException("transportShutdown() already called"); + } } - public synchronized Status awaitShutdown() throws InterruptedException { - while (shutdownStatus == null) { - wait(); - } - return shutdownStatus; + public Status awaitShutdown() throws Exception { + return shutdownStatus.get(TIMEOUT_SECONDS, SECONDS); } @Override - public synchronized void transportTerminated() { - terminated = true; - notifyAll(); + public void transportTerminated() { + if (!isTerminated.set(true)) { + throw new IllegalStateException("isTerminated() already called"); + } } - public synchronized void awaitTermination() throws InterruptedException { - while (!terminated) { - wait(); - } + public void awaitTermination() throws Exception { + isTerminated.get(TIMEOUT_SECONDS, SECONDS); } @Override - public synchronized void transportReady() { - ready = true; - notifyAll(); + public void transportReady() { + if (!isReady.set(true)) { + throw new IllegalStateException("isTerminated() already called"); + } } - public synchronized void awaitReady() { - while (!ready) { - try { - wait(); - } catch (InterruptedException inte) { - throw new AssertionError("Interrupted waiting for ready"); - } - } + public void awaitReady() throws Exception { + isReady.get(TIMEOUT_SECONDS, SECONDS); } @Override diff --git a/binder/src/androidTest/java/io/grpc/binder/internal/BinderTransportTest.java b/binder/src/androidTest/java/io/grpc/binder/internal/BinderTransportTest.java index 84ff74b4f8e..7932cabde89 100644 --- a/binder/src/androidTest/java/io/grpc/binder/internal/BinderTransportTest.java +++ b/binder/src/androidTest/java/io/grpc/binder/internal/BinderTransportTest.java @@ -17,18 +17,13 @@ package io.grpc.binder.internal; import android.content.Context; -import androidx.core.content.ContextCompat; import androidx.test.core.app.ApplicationProvider; import androidx.test.ext.junit.runners.AndroidJUnit4; import io.grpc.ServerStreamTracer; import io.grpc.binder.AndroidComponentAddress; -import io.grpc.binder.BindServiceFlags; -import io.grpc.binder.BinderChannelCredentials; -import io.grpc.binder.BinderInternal; import io.grpc.binder.HostServices; -import io.grpc.binder.InboundParcelablePolicy; -import io.grpc.binder.SecurityPolicies; import io.grpc.internal.AbstractTransportTest; +import io.grpc.internal.ClientTransportFactory.ClientTransportOptions; import io.grpc.internal.GrpcUtil; import io.grpc.internal.InternalServer; import io.grpc.internal.ManagedClientTransport; @@ -56,6 +51,8 @@ public final class BinderTransportTest extends AbstractTransportTest { SharedResourcePool.forResource(GrpcUtil.TIMER_SERVICE); private final ObjectPool offloadExecutorPool = SharedResourcePool.forResource(GrpcUtil.SHARED_CHANNEL_EXECUTOR); + private final ObjectPool serverExecutorPool = + SharedResourcePool.forResource(GrpcUtil.SHARED_CHANNEL_EXECUTOR); @Override @After @@ -68,14 +65,16 @@ public void tearDown() throws InterruptedException { protected InternalServer newServer(List streamTracerFactories) { AndroidComponentAddress addr = HostServices.allocateService(appContext); - BinderServer binderServer = new BinderServer(addr, - executorServicePool, - streamTracerFactories, - BinderInternal.createPolicyChecker(SecurityPolicies.serverInternalOnly()), - InboundParcelablePolicy.DEFAULT, - /* shutdownListener=*/ () -> {}); + BinderServer binderServer = + new BinderServer.Builder() + .setListenAddress(addr) + .setExecutorPool(serverExecutorPool) + .setExecutorServicePool(executorServicePool) + .setStreamTracerFactories(streamTracerFactories) + .build(); - HostServices.configureService(addr, + HostServices.configureService( + addr, HostServices.serviceParamsBuilder() .setRawBinderSupplier(() -> binderServer.getHostBinder()) .build()); @@ -97,19 +96,17 @@ protected String testAuthority(InternalServer server) { @Override protected ManagedClientTransport newClientTransport(InternalServer server) { AndroidComponentAddress addr = (AndroidComponentAddress) server.getListenSocketAddress(); - return new BinderTransport.BinderClientTransport( - appContext, - BinderChannelCredentials.forDefault(), - addr, - null, - BindServiceFlags.DEFAULTS, - ContextCompat.getMainExecutor(appContext), - executorServicePool, - offloadExecutorPool, - SecurityPolicies.internalOnly(), - InboundParcelablePolicy.DEFAULT, - OneWayBinderProxy.IDENTITY_DECORATOR, - eagAttrs()); + BinderClientTransportFactory.Builder builder = + new BinderClientTransportFactory.Builder() + .setSourceContext(appContext) + .setScheduledExecutorPool(executorServicePool) + .setOffloadExecutorPool(offloadExecutorPool); + + ClientTransportOptions options = new ClientTransportOptions(); + options.setEagAttributes(eagAttrs()); + options.setChannelLogger(transportLogger()); + + return new BinderClientTransport(builder.buildClientTransportFactory(), addr, options); } @Test @@ -122,11 +119,6 @@ public void socketStats() throws Exception {} @Override public void flowControlPushBack() throws Exception {} - @Test - @Ignore("Not yet implemented. See https://github.com/grpc/grpc-java/issues/8931") - @Override - public void serverNotListening() throws Exception {} - @Test @Ignore("This test isn't appropriate for BinderTransport.") @Override @@ -136,7 +128,7 @@ public void serverAlreadyListening() throws Exception { // refers to an Android Service class declared in an applications manifest. // // However, unlike a regular network server, which is responsible for listening on its port, a - // BinderServier is not responsible for the creation of its host Service. The opposite is + // BinderServer is not responsible for the creation of its host Service. The opposite is // the case, with the host Android Service (itself created by the Android platform in // response to a connection) building the gRPC server. // diff --git a/binder/src/androidTest/java/io/grpc/binder/internal/LeakSafeOneWayBinderTest.java b/binder/src/androidTest/java/io/grpc/binder/internal/LeakSafeOneWayBinderTest.java index f7ed5ad13cb..835c73bee50 100644 --- a/binder/src/androidTest/java/io/grpc/binder/internal/LeakSafeOneWayBinderTest.java +++ b/binder/src/androidTest/java/io/grpc/binder/internal/LeakSafeOneWayBinderTest.java @@ -21,6 +21,7 @@ import android.os.Parcel; import androidx.test.ext.junit.runners.AndroidJUnit4; +import io.grpc.binder.internal.LeakSafeOneWayBinder.TransactionHandler; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -29,28 +30,34 @@ public final class LeakSafeOneWayBinderTest { private LeakSafeOneWayBinder binder; + private final FakeHandler handler = new FakeHandler(); - private int transactionsHandled; - private int lastCode; - private Parcel lastParcel; + static class FakeHandler implements TransactionHandler { + int transactionsHandled; + int lastCode; + Parcel lastParcel; - @Before - public void setUp() { - binder = new LeakSafeOneWayBinder((code, parcel) -> { + @Override + public boolean handleTransaction(int code, Parcel parcel) { transactionsHandled++; lastCode = code; lastParcel = parcel; return true; - }); + } + } + + @Before + public void setUp() { + binder = new LeakSafeOneWayBinder(handler); } @Test public void testTransaction() { Parcel p = Parcel.obtain(); assertThat(binder.onTransact(123, p, null, FLAG_ONEWAY)).isTrue(); - assertThat(transactionsHandled).isEqualTo(1); - assertThat(lastCode).isEqualTo(123); - assertThat(lastParcel).isSameInstanceAs(p); + assertThat(handler.transactionsHandled).isEqualTo(1); + assertThat(handler.lastCode).isEqualTo(123); + assertThat(handler.lastParcel).isSameInstanceAs(p); p.recycle(); } @@ -59,7 +66,7 @@ public void testDropsTwoWayTransactions() { Parcel p = Parcel.obtain(); Parcel reply = Parcel.obtain(); assertThat(binder.onTransact(123, p, reply, 0)).isFalse(); - assertThat(transactionsHandled).isEqualTo(0); + assertThat(handler.transactionsHandled).isEqualTo(0); p.recycle(); reply.recycle(); } @@ -71,7 +78,21 @@ public void testDetach() { assertThat(binder.onTransact(456, p, null, FLAG_ONEWAY)).isFalse(); // The transaction shouldn't have been processed. - assertThat(transactionsHandled).isEqualTo(0); + assertThat(handler.transactionsHandled).isEqualTo(0); + + p.recycle(); + } + + @Test + public void testReplace() { + binder = new LeakSafeOneWayBinder(handler); + Parcel p = Parcel.obtain(); + FakeHandler handler2 = new FakeHandler(); + binder.setHandler(handler2); + assertThat(binder.onTransact(456, p, null, FLAG_ONEWAY)).isTrue(); + + assertThat(handler.transactionsHandled).isEqualTo(0); + assertThat(handler2.transactionsHandled).isEqualTo(1); p.recycle(); } @@ -81,9 +102,9 @@ public void testMultipleTransactions() { Parcel p = Parcel.obtain(); assertThat(binder.onTransact(123, p, null, FLAG_ONEWAY)).isTrue(); assertThat(binder.onTransact(456, p, null, FLAG_ONEWAY)).isTrue(); - assertThat(transactionsHandled).isEqualTo(2); - assertThat(lastCode).isEqualTo(456); - assertThat(lastParcel).isSameInstanceAs(p); + assertThat(handler.transactionsHandled).isEqualTo(2); + assertThat(handler.lastCode).isEqualTo(456); + assertThat(handler.lastParcel).isSameInstanceAs(p); p.recycle(); } diff --git a/binder/src/androidTest/java/io/grpc/binder/internal/OneWayBinderProxies.java b/binder/src/androidTest/java/io/grpc/binder/internal/OneWayBinderProxies.java deleted file mode 100644 index 229c9426125..00000000000 --- a/binder/src/androidTest/java/io/grpc/binder/internal/OneWayBinderProxies.java +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Copyright 2024 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.grpc.binder.internal; - -import android.os.RemoteException; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.LinkedBlockingQueue; -import javax.annotation.Nullable; - -/** - * A collection of {@link OneWayBinderProxy}-related test helpers. - */ -public final class OneWayBinderProxies { - /** - * A {@link OneWayBinderProxy.Decorator} that blocks calling threads while an (external) test - * provides the actual decoration. - */ - public static final class BlockingBinderDecorator implements - OneWayBinderProxy.Decorator { - private final BlockingQueue requests = new LinkedBlockingQueue<>(); - private final BlockingQueue results = new LinkedBlockingQueue<>(); - - /** - * Returns the next {@link OneWayBinderProxy} that needs decorating, blocking if it hasn't yet - * been provided to {@link #decorate}. - * - *

Follow this with a call to {@link #putNextResult(OneWayBinderProxy)} to provide - * the result of {@link #decorate} and unblock the waiting caller. - */ - public OneWayBinderProxy takeNextRequest() throws InterruptedException { - return requests.take(); - } - - /** - * Provides the next value to return from {@link #decorate}. - */ - public void putNextResult(T next) throws InterruptedException { - results.put(next); - } - - @Override - public OneWayBinderProxy decorate(OneWayBinderProxy in) { - try { - requests.put(in); - return results.take(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException(e); - } - } - } - - /** - * A {@link OneWayBinderProxy} decorator whose transact method can artificially throw. - */ - public static final class ThrowingOneWayBinderProxy extends OneWayBinderProxy { - private final OneWayBinderProxy wrapped; - @Nullable - private RemoteException remoteException; - - ThrowingOneWayBinderProxy(OneWayBinderProxy wrapped) { - super(wrapped.getDelegate()); - this.wrapped = wrapped; - } - - /** - * Causes all future invocations of transact to throw `remoteException`. - * - *

Users are responsible for ensuring their calls "happen-before" the relevant calls to - * {@link #transact(int, ParcelHolder)}. - */ - public void setRemoteException(RemoteException remoteException) { - this.remoteException = remoteException; - } - - @Override - public void transact(int code, ParcelHolder data) throws RemoteException { - if (remoteException != null) { - throw remoteException; - } - wrapped.transact(code, data); - } - } - - // Cannot be instantiated. - private OneWayBinderProxies() {}; -} diff --git a/binder/src/main/AndroidManifest.xml b/binder/src/main/AndroidManifest.xml index a30cbbdd6fa..239c3b39b38 100644 --- a/binder/src/main/AndroidManifest.xml +++ b/binder/src/main/AndroidManifest.xml @@ -1,2 +1,11 @@ - - + + + + + + + + + + + \ No newline at end of file diff --git a/binder/src/main/java/io/grpc/binder/AndroidComponentAddress.java b/binder/src/main/java/io/grpc/binder/AndroidComponentAddress.java index 2cad159f2e9..b390c1f0ccd 100644 --- a/binder/src/main/java/io/grpc/binder/AndroidComponentAddress.java +++ b/binder/src/main/java/io/grpc/binder/AndroidComponentAddress.java @@ -18,10 +18,14 @@ import static android.content.Intent.URI_ANDROID_APP_SCHEME; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; import android.content.ComponentName; import android.content.Context; import android.content.Intent; +import android.os.UserHandle; +import com.google.common.base.Objects; +import io.grpc.ExperimentalApi; import java.net.SocketAddress; import javax.annotation.Nullable; @@ -41,18 +45,25 @@ * fields, namely, an action of {@link ApiConstants#ACTION_BIND}, an empty category set and null * type and data URI. * - *

The semantics of {@link #equals(Object)} are the same as {@link Intent#filterEquals(Intent)}. + *

Optionally contains a {@link UserHandle} that must be considered wherever the {@link Intent} + * is evaluated. + * + *

{@link #equals(Object)} uses {@link Intent#filterEquals(Intent)} semantics to compare Intents. */ public final class AndroidComponentAddress extends SocketAddress { private static final long serialVersionUID = 0L; private final Intent bindIntent; // "Explicit", having either a component or package restriction. - protected AndroidComponentAddress(Intent bindIntent) { + @Nullable + private final UserHandle targetUser; // null means the same user that hosts this process. + + private AndroidComponentAddress(Intent bindIntent, @Nullable UserHandle targetUser) { checkArgument( bindIntent.getComponent() != null || bindIntent.getPackage() != null, "'bindIntent' must be explicit. Specify either a package or ComponentName."); this.bindIntent = bindIntent; + this.targetUser = targetUser; } /** @@ -72,8 +83,8 @@ public static AndroidComponentAddress forLocalComponent(Context context, ClassNB: The returned Intent does not specify a target Android user. If {@link #getTargetUser()} + * is non-null, {@link Context#bindServiceAsUser} should be called instead. */ public Intent asBindIntent() { return bindIntent.cloneFilter(); // Intent is mutable so return a copy. @@ -177,13 +191,92 @@ public int hashCode() { public boolean equals(Object obj) { if (obj instanceof AndroidComponentAddress) { AndroidComponentAddress that = (AndroidComponentAddress) obj; - return bindIntent.filterEquals(that.bindIntent); + return bindIntent.filterEquals(that.bindIntent) + && Objects.equal(this.targetUser, that.targetUser); } return false; } @Override public String toString() { - return "AndroidComponentAddress[" + bindIntent + "]"; + StringBuilder builder = new StringBuilder("AndroidComponentAddress["); + if (targetUser != null) { + builder.append(targetUser); + builder.append("@"); + } + builder.append(bindIntent); + builder.append("]"); + return builder.toString(); + } + + /** + * Identifies the Android user in which the bind Intent will be evaluated. + * + *

Returns the {@link UserHandle}, or null which means that the Android user hosting the + * current process will be used. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/10173") + @Nullable + public UserHandle getTargetUser() { + return targetUser; + } + + public static Builder newBuilder() { + return new Builder(); + } + + /** Fluently builds instances of {@link AndroidComponentAddress}. */ + public static class Builder { + Intent bindIntent; + UserHandle targetUser; + + /** + * Sets the binding {@link Intent} to one having the "filter matching" fields of 'intent'. + * + *

'intent' must be "explicit", i.e. having either a target component ({@link + * Intent#getComponent()}) or package restriction ({@link Intent#getPackage()}). + */ + public Builder setBindIntent(Intent intent) { + this.bindIntent = intent.cloneFilter(); + return this; + } + + /** + * Sets the binding {@link Intent} to one with the specified 'component' and default values for + * all other fields, for convenience. + */ + public Builder setBindIntentFromComponent(ComponentName component) { + this.bindIntent = new Intent(ApiConstants.ACTION_BIND).setComponent(component); + return this; + } + + /** + * Specifies the Android user in which the built Address' bind Intent will be evaluated. + * + *

Connecting to a server in a different Android user is uncommon and requires the client app + * have runtime visibility of @SystemApi's and hold certain @SystemApi permissions. + * The device must also be running Android SDK version 30 or higher. + * + *

See https://developer.android.com/guide/app-compatibility/restrictions-non-sdk-interfaces + * for details on which apps can call the underlying @SystemApi's needed to make this type + * of connection. + * + *

One of the "android.permission.INTERACT_ACROSS_XXX" permissions is required. The exact one + * depends on the calling user's relationship to the target user, whether client and server are + * in the same or different apps, and the version of Android in use. See {@link + * Context#bindServiceAsUser}, the essential underlying Android API, for details. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/10173") + public Builder setTargetUser(@Nullable UserHandle targetUser) { + this.targetUser = targetUser; + return this; + } + + public AndroidComponentAddress build() { + // We clone any incoming mutable intent in the setter, not here. AndroidComponentAddress + // itself is immutable so multiple instances built from here can safely share 'bindIntent'. + checkState(bindIntent != null, "Required property 'bindIntent' unset"); + return new AndroidComponentAddress(bindIntent, targetUser); + } } } diff --git a/binder/src/main/java/io/grpc/binder/ApiConstants.java b/binder/src/main/java/io/grpc/binder/ApiConstants.java index 43e94338fdc..fbf4be6b7ce 100644 --- a/binder/src/main/java/io/grpc/binder/ApiConstants.java +++ b/binder/src/main/java/io/grpc/binder/ApiConstants.java @@ -17,7 +17,11 @@ package io.grpc.binder; import android.content.Intent; +import android.os.UserHandle; +import io.grpc.Attributes; +import io.grpc.EquivalentAddressGroup; import io.grpc.ExperimentalApi; +import io.grpc.NameResolver; /** Constant parts of the gRPC binder transport public API. */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/8022") @@ -29,4 +33,43 @@ private ApiConstants() {} * themselves in a {@link android.app.Service#onBind(Intent)} call. */ public static final String ACTION_BIND = "grpc.io.action.BIND"; + + /** + * Gives a {@link NameResolver} access to its Channel's "source" {@link android.content.Context}, + * the entry point to almost every other Android API. + * + *

This argument is set automatically by {@link BinderChannelBuilder}. Any value passed to + * {@link io.grpc.ManagedChannelBuilder#setNameResolverArg} will be ignored. + * + *

See {@link BinderChannelBuilder#forTarget(String, android.content.Context)} for more. + */ + public static final NameResolver.Args.Key SOURCE_ANDROID_CONTEXT = + NameResolver.Args.Key.create("source-android-context"); + + /** + * Specifies the Android user in which target URIs should be resolved. + * + *

{@link UserHandle} can't reasonably be encoded in a target URI string. Instead, all {@link + * io.grpc.NameResolverProvider}s producing {@link AndroidComponentAddress}es should let clients + * address servers in another Android user using this argument. + * + *

Connecting to a server in a different Android user is uncommon and can only be done by a + * "system app" client with special permissions. See {@link + * AndroidComponentAddress.Builder#setTargetUser(UserHandle)} for details. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/10173") + public static final NameResolver.Args.Key TARGET_ANDROID_USER = + NameResolver.Args.Key.create("target-android-user"); + + /** + * Lets you override a Channel's pre-auth configuration (see {@link + * BinderChannelBuilder#preAuthorizeServers(boolean)}) for a given {@link EquivalentAddressGroup}. + * + *

A {@link NameResolver} that discovers servers from an untrusted source like PackageManager + * can use this to force server pre-auth and prevent abuse. + */ + @EquivalentAddressGroup.Attr + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/12191") + public static final Attributes.Key PRE_AUTH_SERVER_OVERRIDE = + Attributes.Key.create("pre-auth-server-override"); } diff --git a/binder/src/main/java/io/grpc/binder/AsyncSecurityPolicy.java b/binder/src/main/java/io/grpc/binder/AsyncSecurityPolicy.java index 11952a21f9b..9594c644e0c 100644 --- a/binder/src/main/java/io/grpc/binder/AsyncSecurityPolicy.java +++ b/binder/src/main/java/io/grpc/binder/AsyncSecurityPolicy.java @@ -17,12 +17,11 @@ package io.grpc.binder; import com.google.common.util.concurrent.ListenableFuture; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.ExperimentalApi; import io.grpc.Status; - import java.util.concurrent.CancellationException; import java.util.concurrent.ExecutionException; -import javax.annotation.CheckReturnValue; /** * Decides whether a given Android UID is authorized to access some resource. @@ -37,24 +36,24 @@ @CheckReturnValue public abstract class AsyncSecurityPolicy extends SecurityPolicy { -/** - * @deprecated Prefer {@link #checkAuthorizationAsync(int)} for async or slow calls or subclass - * {@link SecurityPolicy} directly for quick, synchronous implementations. - */ -@Override -@Deprecated -public final Status checkAuthorization(int uid) { - try { - return checkAuthorizationAsync(uid).get(); - } catch (ExecutionException e) { - return Status.fromThrowable(e); - } catch (CancellationException e) { - return Status.CANCELLED.withCause(e); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); // re-set the current thread's interruption state - return Status.CANCELLED.withCause(e); + /** + * @deprecated Prefer {@link #checkAuthorizationAsync(int)} for async or slow calls or subclass + * {@link SecurityPolicy} directly for quick, synchronous implementations. + */ + @Override + @Deprecated + public final Status checkAuthorization(int uid) { + try { + return checkAuthorizationAsync(uid).get(); + } catch (ExecutionException e) { + return Status.fromThrowable(e); + } catch (CancellationException e) { + return Status.CANCELLED.withCause(e); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); // re-set the current thread's interruption state + return Status.CANCELLED.withCause(e); + } } -} /** * Decides whether the given Android UID is authorized. (Validity is implementation dependent). @@ -68,4 +67,25 @@ public final Status checkAuthorization(int uid) { * authorized. */ public abstract ListenableFuture checkAuthorizationAsync(int uid); + + /** + * Decides whether the given Android UID is authorized, without providing its raw integer value. + * + *

Calling this is equivalent to calling {@link SecurityPolicy#checkAuthorization(int)}, except + * the caller provides a {@link PeerUid} wrapper instead of the raw integer uid (known only to the + * transport). This allows a server to check additional application-layer security policy for + * itself *after* the call itself is authorized by the transport layer. Cross cutting application- + * layer checks could be done from a {@link io.grpc.ServerInterceptor}. Checks based on the + * substance of a request message could be done by the individual RPC method implementations + * themselves. + * + *

See #checkAuthorizationAsync(int) for details on the semantics. See {@link + * PeerUids#newPeerIdentifyingServerInterceptor()} for how to get a {@link PeerUid}. + * + * @param uid The Android UID to authenticate. + * @return A gRPC {@link Status} object, with OK indicating authorized. + */ + public final ListenableFuture checkAuthorizationAsync(PeerUid uid) { + return checkAuthorizationAsync(uid.getUid()); + } } diff --git a/binder/src/main/java/io/grpc/binder/BinderChannelBuilder.java b/binder/src/main/java/io/grpc/binder/BinderChannelBuilder.java index 133c8c5dd13..a241634dd22 100644 --- a/binder/src/main/java/io/grpc/binder/BinderChannelBuilder.java +++ b/binder/src/main/java/io/grpc/binder/BinderChannelBuilder.java @@ -20,29 +20,14 @@ import static com.google.common.base.Preconditions.checkState; import android.content.Context; -import android.os.UserHandle; -import androidx.annotation.RequiresApi; -import androidx.core.content.ContextCompat; import com.google.errorprone.annotations.DoNotCall; -import io.grpc.ChannelCredentials; -import io.grpc.ChannelLogger; import io.grpc.ExperimentalApi; import io.grpc.ForwardingChannelBuilder; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; -import io.grpc.binder.internal.BinderTransport; -import io.grpc.binder.internal.OneWayBinderProxy; -import io.grpc.internal.ClientTransportFactory; -import io.grpc.internal.ConnectionClientTransport; +import io.grpc.binder.internal.BinderClientTransportFactory; import io.grpc.internal.FixedObjectPool; -import io.grpc.internal.GrpcUtil; import io.grpc.internal.ManagedChannelImplBuilder; -import io.grpc.internal.ManagedChannelImplBuilder.ClientTransportFactoryBuilder; -import io.grpc.internal.ObjectPool; -import io.grpc.internal.SharedResourcePool; -import java.net.SocketAddress; -import java.util.Collection; -import java.util.Collections; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; @@ -55,8 +40,7 @@ * Services */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/8022") -public final class BinderChannelBuilder - extends ForwardingChannelBuilder { +public final class BinderChannelBuilder extends ForwardingChannelBuilder { /** * Creates a channel builder that will bind to a remote Android service. @@ -110,8 +94,8 @@ public static BinderChannelBuilder forAddress( } /** - * Creates a channel builder that will bind to a remote Android service, via a string - * target name which will be resolved. + * Creates a channel builder that will bind to a remote Android service, via a string target name + * which will be resolved. * *

The underlying Android binding will be torn down when the channel becomes idle. This happens * after 30 minutes without use by default but can be configured via {@link @@ -122,16 +106,13 @@ public static BinderChannelBuilder forAddress( * resulting builder. They will not be shut down automatically. * * @param target A target uri which should resolve into an {@link AndroidComponentAddress} - * referencing the service to bind to. + * referencing the service to bind to. * @param sourceContext the context to bind from (e.g. The current Activity or Application). * @return a new builder */ public static BinderChannelBuilder forTarget(String target, Context sourceContext) { return new BinderChannelBuilder( - null, - checkNotNull(target, "target"), - sourceContext, - BinderChannelCredentials.forDefault()); + null, checkNotNull(target, "target"), sourceContext, BinderChannelCredentials.forDefault()); } /** @@ -160,18 +141,14 @@ public static BinderChannelBuilder forTarget( null, checkNotNull(target, "target"), sourceContext, channelCredentials); } - /** - * Always fails. Call {@link #forAddress(AndroidComponentAddress, Context)} instead. - */ + /** Always fails. Call {@link #forAddress(AndroidComponentAddress, Context)} instead. */ @DoNotCall("Unsupported. Use forAddress(AndroidComponentAddress, Context) instead") public static BinderChannelBuilder forAddress(String name, int port) { throw new UnsupportedOperationException( "call forAddress(AndroidComponentAddress, Context) instead"); } - /** - * Always fails. Call {@link #forAddress(AndroidComponentAddress, Context)} instead. - */ + /** Always fails. Call {@link #forAddress(AndroidComponentAddress, Context)} instead. */ @DoNotCall("Unsupported. Use forTarget(String, Context) instead") public static BinderChannelBuilder forTarget(String target) { throw new UnsupportedOperationException( @@ -179,14 +156,8 @@ public static BinderChannelBuilder forTarget(String target) { } private final ManagedChannelImplBuilder managedChannelImplBuilder; + private final BinderClientTransportFactory.Builder transportFactoryBuilder; - private Executor mainThreadExecutor; - private ObjectPool schedulerPool = - SharedResourcePool.forResource(GrpcUtil.TIMER_SERVICE); - private SecurityPolicy securityPolicy; - private InboundParcelablePolicy inboundParcelablePolicy; - private BindServiceFlags bindServiceFlags; - @Nullable private UserHandle targetUserHandle; private boolean strictLifecycleManagement; private BinderChannelBuilder( @@ -194,42 +165,18 @@ private BinderChannelBuilder( @Nullable String target, Context sourceContext, BinderChannelCredentials channelCredentials) { - mainThreadExecutor = - ContextCompat.getMainExecutor(checkNotNull(sourceContext, "sourceContext")); - securityPolicy = SecurityPolicies.internalOnly(); - inboundParcelablePolicy = InboundParcelablePolicy.DEFAULT; - bindServiceFlags = BindServiceFlags.DEFAULTS; - - final class BinderChannelTransportFactoryBuilder - implements ClientTransportFactoryBuilder { - @Override - public ClientTransportFactory buildClientTransportFactory() { - return new TransportFactory( - sourceContext, - channelCredentials, - mainThreadExecutor, - schedulerPool, - managedChannelImplBuilder.getOffloadExecutorPool(), - securityPolicy, - targetUserHandle, - bindServiceFlags, - inboundParcelablePolicy); - } - } + transportFactoryBuilder = + new BinderClientTransportFactory.Builder() + .setSourceContext(sourceContext) + .setChannelCredentials(channelCredentials); if (directAddress != null) { managedChannelImplBuilder = new ManagedChannelImplBuilder( - directAddress, - directAddress.getAuthority(), - new BinderChannelTransportFactoryBuilder(), - null); + directAddress, directAddress.getAuthority(), transportFactoryBuilder, null); } else { managedChannelImplBuilder = - new ManagedChannelImplBuilder( - target, - new BinderChannelTransportFactoryBuilder(), - null); + new ManagedChannelImplBuilder(target, transportFactoryBuilder, null); } idleTimeout(60, TimeUnit.SECONDS); } @@ -242,7 +189,7 @@ protected ManagedChannelBuilder delegate() { /** Specifies certain optional aspects of the underlying Android Service binding. */ public BinderChannelBuilder setBindServiceFlags(BindServiceFlags bindServiceFlags) { - this.bindServiceFlags = bindServiceFlags; + transportFactoryBuilder.setBindServiceFlags(bindServiceFlags); return this; } @@ -256,8 +203,8 @@ public BinderChannelBuilder setBindServiceFlags(BindServiceFlags bindServiceFlag */ public BinderChannelBuilder scheduledExecutorService( ScheduledExecutorService scheduledExecutorService) { - schedulerPool = - new FixedObjectPool<>(checkNotNull(scheduledExecutorService, "scheduledExecutorService")); + transportFactoryBuilder.setScheduledExecutorPool( + new FixedObjectPool<>(checkNotNull(scheduledExecutorService, "scheduledExecutorService"))); return this; } @@ -269,7 +216,7 @@ public BinderChannelBuilder scheduledExecutorService( * @return this */ public BinderChannelBuilder mainThreadExecutor(Executor mainThreadExecutor) { - this.mainThreadExecutor = mainThreadExecutor; + transportFactoryBuilder.setMainThreadExecutor(mainThreadExecutor); return this; } @@ -282,40 +229,21 @@ public BinderChannelBuilder mainThreadExecutor(Executor mainThreadExecutor) { * @return this */ public BinderChannelBuilder securityPolicy(SecurityPolicy securityPolicy) { - this.securityPolicy = checkNotNull(securityPolicy, "securityPolicy"); - return this; - } - -/** - * Provides the target {@UserHandle} of the remote Android service. - * - *

When targetUserHandle is set, Context.bindServiceAsUser will used and additional Android - * permissions will be required. If your usage does not require cross-user communications, please - * do not set this field. It is the caller's responsibility to make sure that it holds the - * corresponding permissions. - * - * @param targetUserHandle the target user to bind into. - * @return this - */ - @ExperimentalApi("https://github.com/grpc/grpc-java/issues/10173") - @RequiresApi(30) - public BinderChannelBuilder bindAsUser(UserHandle targetUserHandle) { - this.targetUserHandle = targetUserHandle; + transportFactoryBuilder.setSecurityPolicy(securityPolicy); return this; } /** Sets the policy for inbound parcelable objects. */ public BinderChannelBuilder inboundParcelablePolicy( InboundParcelablePolicy inboundParcelablePolicy) { - this.inboundParcelablePolicy = checkNotNull(inboundParcelablePolicy, "inboundParcelablePolicy"); + transportFactoryBuilder.setInboundParcelablePolicy(inboundParcelablePolicy); return this; } - /** - * Disables the channel idle timeout and prevents it from being enabled. This - * allows a centralized application method to configure the channel builder - * and return it, without worrying about another part of the application - * accidentally enabling the idle timeout. + /** + * Disables the channel idle timeout and prevents it from being enabled. This allows a centralized + * application method to configure the channel builder and return it, without worrying about + * another part of the application accidentally enabling the idle timeout. */ public BinderChannelBuilder strictLifecycleManagement() { strictLifecycleManagement = true; @@ -323,94 +251,111 @@ public BinderChannelBuilder strictLifecycleManagement() { return this; } - @Override - public BinderChannelBuilder idleTimeout(long value, TimeUnit unit) { - checkState(!strictLifecycleManagement, "Idle timeouts are not supported when strict lifecycle management is enabled"); - super.idleTimeout(value, unit); + /** + * Checks servers against this Channel's {@link SecurityPolicy} *before* binding. + * + *

Android users can be tricked into installing a malicious app with the same package name as a + * legitimate server. That's why we don't send calls to a server until it has been authorized by + * an appropriate {@link SecurityPolicy}. But merely binding to a malicious server can enable + * "keep-alive" and "background activity launch" abuse, even if it's ultimately unauthorized. + * Pre-authorization mitigates these threats by performing a preliminary {@link SecurityPolicy} + * check against a server app's PackageManager-registered identity without actually creating an + * instance of it. This is especially important for security when the server's direct address + * isn't known in advance but rather resolved via target URI or discovered by other means. + * + *

Note that, unlike ordinary authorization, pre-authorization is performed against the server + * app's UID, not the UID of the process hosting the bound Service. These can be different, most + * commonly due to services that set `android:isolatedProcess=true`. + * + *

Pre-authorization is strongly recommended but it remains optional for now because of this + * behavior change and the small performance cost. + * + *

The default value of this property is false but it will become true in a future release. + * Clients that require a particular behavior should configure it explicitly using this method + * rather than relying on the default. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/12191") + public BinderChannelBuilder preAuthorizeServers(boolean preAuthorize) { + transportFactoryBuilder.setPreAuthorizeServers(preAuthorize); return this; } - /** Creates new binder transports. */ - private static final class TransportFactory implements ClientTransportFactory { - private final Context sourceContext; - private final BinderChannelCredentials channelCredentials; - private final Executor mainThreadExecutor; - private final ObjectPool scheduledExecutorPool; - private final ObjectPool offloadExecutorPool; - private final SecurityPolicy securityPolicy; - @Nullable private final UserHandle targetUserHandle; - private final BindServiceFlags bindServiceFlags; - private final InboundParcelablePolicy inboundParcelablePolicy; - - private ScheduledExecutorService executorService; - private Executor offloadExecutor; - private boolean closed; - - TransportFactory( - Context sourceContext, - BinderChannelCredentials channelCredentials, - Executor mainThreadExecutor, - ObjectPool scheduledExecutorPool, - ObjectPool offloadExecutorPool, - SecurityPolicy securityPolicy, - @Nullable UserHandle targetUserHandle, - BindServiceFlags bindServiceFlags, - InboundParcelablePolicy inboundParcelablePolicy) { - this.sourceContext = sourceContext; - this.channelCredentials = channelCredentials; - this.mainThreadExecutor = mainThreadExecutor; - this.scheduledExecutorPool = scheduledExecutorPool; - this.offloadExecutorPool = offloadExecutorPool; - this.securityPolicy = securityPolicy; - this.targetUserHandle = targetUserHandle; - this.bindServiceFlags = bindServiceFlags; - this.inboundParcelablePolicy = inboundParcelablePolicy; - - executorService = scheduledExecutorPool.getObject(); - offloadExecutor = offloadExecutorPool.getObject(); - } - - @Override - public ConnectionClientTransport newClientTransport( - SocketAddress addr, ClientTransportOptions options, ChannelLogger channelLogger) { - if (closed) { - throw new IllegalStateException("The transport factory is closed."); - } - return new BinderTransport.BinderClientTransport( - sourceContext, - channelCredentials, - (AndroidComponentAddress) addr, - targetUserHandle, - bindServiceFlags, - mainThreadExecutor, - scheduledExecutorPool, - offloadExecutorPool, - securityPolicy, - inboundParcelablePolicy, - OneWayBinderProxy.IDENTITY_DECORATOR, - options.getEagAttributes()); - } - - @Override - public ScheduledExecutorService getScheduledExecutorService() { - return executorService; - } + /** + * Specifies how and when to authorize a server against this Channel's {@link SecurityPolicy}. + * + *

This method selects the original "legacy" authorization strategy, which is no longer + * preferred for two reasons: First, the legacy strategy considers the UID of the server *process* + * we connect to. This is problematic for services using the `android:isolatedProcess` attribute, + * which runs them under a different "ephemeral" UID. This UID lacks all the privileges of the + * hosting app -- any non-trivial SecurityPolicy would fail to authorize it. Second, the legacy + * authorization strategy performs SecurityPolicy checks later in the connection handshake, which + * means the calling UID must be rechecked on every subsequent RPC. For these reasons, prefer + * {@link #useV2AuthStrategy} instead. + * + *

The server does not know which authorization strategy a client is using. Both strategies + * work with all versions of the grpc-binder server. + * + *

Callers need not specify an authorization strategy, but the default is unspecified and will + * eventually become {@link #useV2AuthStrategy()}. Clients that require the legacy strategy should + * configure it explicitly using this method. Eventually, however, legacy support will be + * deprecated and removed. + * + * @return this + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/12397") + public BinderChannelBuilder useLegacyAuthStrategy() { + transportFactoryBuilder.setUseLegacyAuthStrategy(true); + return this; + } - @Override - public SwapChannelCredentialsResult swapChannelCredentials(ChannelCredentials channelCreds) { - return null; - } + /** + * Specifies how and when to authorize a server against this Channel's {@link SecurityPolicy}. + * + *

This method selects the v2 authorization strategy. It improves on the original strategy + * ({@link #useLegacyAuthStrategy}), by considering the UID of the server *app* we connect to, + * rather than the server *process*. This allows clients to connect to services configured with + * the `android:isolatedProcess` attribute, which run with the same authority as the hosting app, + * but under a different "ephemeral" UID that any non-trivial SecurityPolicy would fail to + * authorize. + * + *

Furthermore, the v2 authorization strategy performs SecurityPolicy checks earlier in the + * connection handshake, which allows subsequent RPCs over that connection to proceed securely + * without further UID checks. For these reasons, clients should prefer the v2 strategy. + * + *

The server does not know which authorization strategy a client is using. Both strategies + * work with all versions of the grpc-binder server. + * + *

Callers need not specify an authorization strategy, but the default is unspecified and can + * change over time. Clients that require the v2 strategy should configure it explicitly using + * this method. Eventually, this strategy will become the default and legacy support will be + * removed. + * + *

If moving to the new authorization strategy causes a robolectric test to fail, ensure your + * fake Service component is registered with `ShadowPackageManager` using `addOrUpdateService()`. + * + * @return this + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/12397") + public BinderChannelBuilder useV2AuthStrategy() { + transportFactoryBuilder.setUseLegacyAuthStrategy(false); + return this; + } - @Override - public void close() { - closed = true; - executorService = scheduledExecutorPool.returnObject(executorService); - offloadExecutor = offloadExecutorPool.returnObject(offloadExecutor); - } + @Override + public BinderChannelBuilder idleTimeout(long value, TimeUnit unit) { + checkState( + !strictLifecycleManagement, + "Idle timeouts are not supported when strict lifecycle management is enabled"); + super.idleTimeout(value, unit); + return this; + } - @Override - public Collection> getSupportedSocketAddressTypes() { - return Collections.singleton(AndroidComponentAddress.class); - } + @Override + public ManagedChannel build() { + transportFactoryBuilder.setOffloadExecutorPool( + managedChannelImplBuilder.getOffloadExecutorPool()); + setNameResolverArg( + ApiConstants.SOURCE_ANDROID_CONTEXT, transportFactoryBuilder.getSourceContext()); + return super.build(); } } diff --git a/binder/src/main/java/io/grpc/binder/BinderChannelCredentials.java b/binder/src/main/java/io/grpc/binder/BinderChannelCredentials.java index 193b1010a16..f6e6f666494 100644 --- a/binder/src/main/java/io/grpc/binder/BinderChannelCredentials.java +++ b/binder/src/main/java/io/grpc/binder/BinderChannelCredentials.java @@ -16,8 +16,6 @@ package io.grpc.binder; -import static com.google.common.base.Preconditions.checkNotNull; - import android.content.ComponentName; import androidx.annotation.RequiresApi; import io.grpc.ChannelCredentials; @@ -61,9 +59,9 @@ public ChannelCredentials withoutBearerTokens() { return this; } - /** + /** * Returns the admin component to be specified with DevicePolicyManager - * bindDeviceAdminServiceAsUser API. + * bindDeviceAdminServiceAsUser API. */ @Nullable public ComponentName getDevicePolicyAdminComponentName() { diff --git a/binder/src/main/java/io/grpc/binder/BinderInternal.java b/binder/src/main/java/io/grpc/binder/BinderInternal.java index 18af43ce2b3..5ed24a07901 100644 --- a/binder/src/main/java/io/grpc/binder/BinderInternal.java +++ b/binder/src/main/java/io/grpc/binder/BinderInternal.java @@ -20,26 +20,22 @@ import io.grpc.Internal; import io.grpc.binder.internal.BinderTransportSecurity; -/** - * Helper class to expose IBinderReceiver methods for legacy internal builders. - */ +/** Helper class to expose IBinderReceiver methods for legacy internal builders. */ @Internal public class BinderInternal { - /** - * Sets the receiver's {@link IBinder} using {@link IBinderReceiver#set(IBinder)}. - */ + /** Sets the receiver's {@link IBinder} using {@link IBinderReceiver#set(IBinder)}. */ public static void setIBinder(IBinderReceiver receiver, IBinder binder) { receiver.set(binder); } /** - * Creates a {@link BinderTransportSecurity.ServerPolicyChecker} from a - * {@link ServerSecurityPolicy}. This exposes to callers an interface to check security policies - * without causing hard dependencies on a specific class. + * Creates a {@link BinderTransportSecurity.ServerPolicyChecker} from a {@link + * ServerSecurityPolicy}. This exposes to callers an interface to check security policies without + * causing hard dependencies on a specific class. */ public static BinderTransportSecurity.ServerPolicyChecker createPolicyChecker( - ServerSecurityPolicy securityPolicy) { + ServerSecurityPolicy securityPolicy) { return securityPolicy::checkAuthorizationForServiceAsync; } } diff --git a/binder/src/main/java/io/grpc/binder/BinderServerBuilder.java b/binder/src/main/java/io/grpc/binder/BinderServerBuilder.java index 158f7947ee8..5f0885883a5 100644 --- a/binder/src/main/java/io/grpc/binder/BinderServerBuilder.java +++ b/binder/src/main/java/io/grpc/binder/BinderServerBuilder.java @@ -29,74 +29,51 @@ import io.grpc.binder.internal.BinderServer; import io.grpc.binder.internal.BinderTransportSecurity; import io.grpc.internal.FixedObjectPool; -import io.grpc.internal.GrpcUtil; import io.grpc.internal.ServerImplBuilder; -import io.grpc.internal.ObjectPool; -import io.grpc.internal.SharedResourcePool; - -import java.io.Closeable; import java.io.File; -import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; -import javax.annotation.Nullable; - -/** - * Builder for a server that services requests from an Android Service. - */ -public final class BinderServerBuilder - extends ForwardingServerBuilder { +/** Builder for a server that services requests from an Android Service. */ +public final class BinderServerBuilder extends ForwardingServerBuilder { /** * Creates a server builder that will listen for bindings to the specified address. * - *

The listening {@link IBinder} associated with new {@link Server}s will be stored - * in {@code binderReceiver} upon {@link #build()}. Callers should return it from {@link + *

The listening {@link IBinder} associated with new {@link Server}s will be stored in {@code + * binderReceiver} upon {@link #build()}. Callers should return it from {@link * Service#onBind(Intent)} when the binding intent matches {@code listenAddress}. * * @param listenAddress an Android Service and binding Intent associated with this server. * @param receiver an "out param" for the new {@link Server}'s listening {@link IBinder} * @return a new builder */ - public static BinderServerBuilder forAddress(AndroidComponentAddress listenAddress, - IBinderReceiver receiver) { + public static BinderServerBuilder forAddress( + AndroidComponentAddress listenAddress, IBinderReceiver receiver) { return new BinderServerBuilder(listenAddress, receiver); } - /** - * Always fails. Call {@link #forAddress(AndroidComponentAddress, IBinderReceiver)} instead. - */ + /** Always fails. Call {@link #forAddress(AndroidComponentAddress, IBinderReceiver)} instead. */ @DoNotCall("Unsupported. Use forAddress() instead") public static BinderServerBuilder forPort(int port) { throw new UnsupportedOperationException("call forAddress() instead"); } private final ServerImplBuilder serverImplBuilder; - private ObjectPool schedulerPool = - SharedResourcePool.forResource(GrpcUtil.TIMER_SERVICE); - private ServerSecurityPolicy securityPolicy; - private InboundParcelablePolicy inboundParcelablePolicy; + private final BinderServer.Builder internalBuilder = new BinderServer.Builder(); private boolean isBuilt; - @Nullable private BinderTransportSecurity.ShutdownListener shutdownListener = null; private BinderServerBuilder( - AndroidComponentAddress listenAddress, - IBinderReceiver binderReceiver) { - securityPolicy = SecurityPolicies.serverInternalOnly(); - inboundParcelablePolicy = InboundParcelablePolicy.DEFAULT; - - serverImplBuilder = new ServerImplBuilder(streamTracerFactories -> { - BinderServer server = new BinderServer( - listenAddress, - schedulerPool, - streamTracerFactories, - BinderInternal.createPolicyChecker(securityPolicy), - inboundParcelablePolicy, - // 'shutdownListener' should have been set by build() - checkNotNull(shutdownListener)); - BinderInternal.setIBinder(binderReceiver, server.getHostBinder()); - return server; - }); + AndroidComponentAddress listenAddress, IBinderReceiver binderReceiver) { + internalBuilder.setListenAddress(listenAddress); + + serverImplBuilder = + new ServerImplBuilder( + (streamTracerFactories, metricRecorder) -> { + internalBuilder.setStreamTracerFactories(streamTracerFactories); + BinderServer server = internalBuilder.build(); + BinderInternal.setIBinder(binderReceiver, server.getHostBinder()); + return server; + }); // Disable stats and tracing by default. serverImplBuilder.setStatsEnabled(false); @@ -132,8 +109,8 @@ public BinderServerBuilder enableTracing() { */ public BinderServerBuilder scheduledExecutorService( ScheduledExecutorService scheduledExecutorService) { - schedulerPool = - new FixedObjectPool<>(checkNotNull(scheduledExecutorService, "scheduledExecutorService")); + internalBuilder.setExecutorServicePool( + new FixedObjectPool<>(checkNotNull(scheduledExecutorService, "scheduledExecutorService"))); return this; } @@ -146,7 +123,7 @@ public BinderServerBuilder scheduledExecutorService( * @return this */ public BinderServerBuilder securityPolicy(ServerSecurityPolicy securityPolicy) { - this.securityPolicy = checkNotNull(securityPolicy, "securityPolicy"); + internalBuilder.setServerSecurityPolicy(securityPolicy); return this; } @@ -154,13 +131,11 @@ public BinderServerBuilder securityPolicy(ServerSecurityPolicy securityPolicy) { @ExperimentalApi("https://github.com/grpc/grpc-java/issues/8022") public BinderServerBuilder inboundParcelablePolicy( InboundParcelablePolicy inboundParcelablePolicy) { - this.inboundParcelablePolicy = checkNotNull(inboundParcelablePolicy, "inboundParcelablePolicy"); + internalBuilder.setInboundParcelablePolicy(inboundParcelablePolicy); return this; } - /** - * Always fails. TLS is not supported in BinderServer. - */ + /** Always fails. TLS is not supported in BinderServer. */ @Override public BinderServerBuilder useTransportSecurity(File certChain, File privateKey) { throw new UnsupportedOperationException("TLS not supported in BinderServer"); @@ -173,16 +148,14 @@ public BinderServerBuilder useTransportSecurity(File certChain, File privateKey) * * @return the new Server */ - @Override // For javadoc refinement only. + @Override public Server build() { // Since we install a final interceptor here, we need to ensure we're only built once. checkState(!isBuilt, "BinderServerBuilder can only be used to build one server instance."); isBuilt = true; // We install the security interceptor last, so it's closest to the transport. - ObjectPool executorPool = serverImplBuilder.getExecutorPool(); - Executor executor = executorPool.getObject(); - BinderTransportSecurity.installAuthInterceptor(this, executor); - shutdownListener = () -> executorPool.returnObject(executor); + BinderTransportSecurity.installAuthInterceptor(this); + internalBuilder.setExecutorPool(serverImplBuilder.getExecutorPool()); return super.build(); } } diff --git a/binder/src/main/java/io/grpc/binder/ParcelableUtils.java b/binder/src/main/java/io/grpc/binder/ParcelableUtils.java index 969344ea68d..0082d33aff5 100644 --- a/binder/src/main/java/io/grpc/binder/ParcelableUtils.java +++ b/binder/src/main/java/io/grpc/binder/ParcelableUtils.java @@ -50,8 +50,6 @@ public static

Metadata.Key

metadataKey( */ public static

Metadata.Key

metadataKeyForImmutableType( String name, Parcelable.Creator

creator) { - return Metadata.Key.of( - name, new MetadataHelper.ParcelableMetadataMarshaller

(creator, true)); + return Metadata.Key.of(name, new MetadataHelper.ParcelableMetadataMarshaller

(creator, true)); } } - diff --git a/binder/src/main/java/io/grpc/binder/PeerUid.java b/binder/src/main/java/io/grpc/binder/PeerUid.java index 1f812bf478b..87ff6763378 100644 --- a/binder/src/main/java/io/grpc/binder/PeerUid.java +++ b/binder/src/main/java/io/grpc/binder/PeerUid.java @@ -75,4 +75,4 @@ public int hashCode() { public String toString() { return "PeerUid{" + uid + '}'; } -} \ No newline at end of file +} diff --git a/binder/src/main/java/io/grpc/binder/PeerUids.java b/binder/src/main/java/io/grpc/binder/PeerUids.java index d25d595c0cd..4c4143166a1 100644 --- a/binder/src/main/java/io/grpc/binder/PeerUids.java +++ b/binder/src/main/java/io/grpc/binder/PeerUids.java @@ -100,4 +100,4 @@ public ServerCall.Listener interceptCall( } private PeerUids() {} -} \ No newline at end of file +} diff --git a/binder/src/main/java/io/grpc/binder/SecurityPolicies.java b/binder/src/main/java/io/grpc/binder/SecurityPolicies.java index ea17b9828f0..c0f6fe81989 100644 --- a/binder/src/main/java/io/grpc/binder/SecurityPolicies.java +++ b/binder/src/main/java/io/grpc/binder/SecurityPolicies.java @@ -83,8 +83,8 @@ public Status checkAuthorization(int uid) { } /** - * Creates a {@link SecurityPolicy} which checks if the package signature - * matches {@code requiredSignature}. + * Creates a {@link SecurityPolicy} which checks if the package signature matches {@code + * requiredSignature}. * * @param packageName the package name of the allowed package. * @param requiredSignature the allowed signature of the allowed package. @@ -93,8 +93,7 @@ public Status checkAuthorization(int uid) { @ExperimentalApi("https://github.com/grpc/grpc-java/issues/8022") public static SecurityPolicy hasSignature( PackageManager packageManager, String packageName, Signature requiredSignature) { - return oneOfSignatures( - packageManager, packageName, ImmutableList.of(requiredSignature)); + return oneOfSignatures(packageManager, packageName, ImmutableList.of(requiredSignature)); } /** @@ -114,8 +113,8 @@ public static SecurityPolicy hasSignatureSha256Hash( } /** - * Creates a {@link SecurityPolicy} which checks if the package signature - * matches any of {@code requiredSignatures}. + * Creates a {@link SecurityPolicy} which checks if the package signature matches any of {@code + * requiredSignatures}. * * @param packageName the package name of the allowed package. * @param requiredSignatures the allowed signatures of the allowed package. @@ -124,14 +123,11 @@ public static SecurityPolicy hasSignatureSha256Hash( */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/8022") public static SecurityPolicy oneOfSignatures( - PackageManager packageManager, - String packageName, - Collection requiredSignatures) { + PackageManager packageManager, String packageName, Collection requiredSignatures) { Preconditions.checkNotNull(packageManager, "packageManager"); Preconditions.checkNotNull(packageName, "packageName"); Preconditions.checkNotNull(requiredSignatures, "requiredSignatures"); - Preconditions.checkArgument(!requiredSignatures.isEmpty(), - "requiredSignatures"); + Preconditions.checkArgument(!requiredSignatures.isEmpty(), "requiredSignatures"); ImmutableList requiredSignaturesImmutable = ImmutableList.copyOf(requiredSignatures); for (Signature requiredSignature : requiredSignaturesImmutable) { @@ -141,8 +137,7 @@ public static SecurityPolicy oneOfSignatures( return new SecurityPolicy() { @Override public Status checkAuthorization(int uid) { - return checkUidSignature( - packageManager, uid, packageName, requiredSignaturesImmutable); + return checkUidSignature(packageManager, uid, packageName, requiredSignaturesImmutable); } }; } @@ -186,24 +181,23 @@ public Status checkAuthorization(int uid) { } /** - * Creates {@link SecurityPolicy} which checks if the app is a device owner app. See - * {@link DevicePolicyManager}. + * Creates {@link SecurityPolicy} which checks if the app is a device owner app. See {@link + * DevicePolicyManager}. */ - @androidx.annotation.RequiresApi(18) public static io.grpc.binder.SecurityPolicy isDeviceOwner(Context applicationContext) { DevicePolicyManager devicePolicyManager = (DevicePolicyManager) applicationContext.getSystemService(Context.DEVICE_POLICY_SERVICE); return anyPackageWithUidSatisfies( - applicationContext, pkg -> devicePolicyManager.isDeviceOwnerApp(pkg), + applicationContext, + pkg -> devicePolicyManager.isDeviceOwnerApp(pkg), "Rejected by device owner policy. No packages found for UID.", "Rejected by device owner policy"); } /** - * Creates {@link SecurityPolicy} which checks if the app is a profile owner app. See - * {@link DevicePolicyManager}. + * Creates {@link SecurityPolicy} which checks if the app is a profile owner app. See {@link + * DevicePolicyManager}. */ - @androidx.annotation.RequiresApi(21) public static SecurityPolicy isProfileOwner(Context applicationContext) { DevicePolicyManager devicePolicyManager = (DevicePolicyManager) applicationContext.getSystemService(Context.DEVICE_POLICY_SERVICE); @@ -223,9 +217,10 @@ public static SecurityPolicy isProfileOwnerOnOrganizationOwnedDevice(Context app (DevicePolicyManager) applicationContext.getSystemService(Context.DEVICE_POLICY_SERVICE); return anyPackageWithUidSatisfies( applicationContext, - pkg -> VERSION.SDK_INT >= 30 - && devicePolicyManager.isProfileOwnerApp(pkg) - && devicePolicyManager.isOrganizationOwnedDeviceWithManagedProfile(), + pkg -> + VERSION.SDK_INT >= 30 + && devicePolicyManager.isProfileOwnerApp(pkg) + && devicePolicyManager.isOrganizationOwnedDeviceWithManagedProfile(), "Rejected by profile owner on organization-owned device policy. No packages found for UID.", "Rejected by profile owner on organization-owned device policy"); } @@ -237,8 +232,7 @@ private static Status checkUidSignature( ImmutableList requiredSignatures) { String[] packages = packageManager.getPackagesForUid(uid); if (packages == null) { - return Status.UNAUTHENTICATED.withDescription( - "Rejected by signature check security policy"); + return Status.UNAUTHENTICATED.withDescription("Rejected by signature check security policy"); } boolean packageNameMatched = false; for (String pkg : packages) { @@ -251,8 +245,7 @@ private static Status checkUidSignature( } } return Status.PERMISSION_DENIED.withDescription( - "Rejected by signature check security policy. Package name matched: " - + packageNameMatched); + "Rejected by signature check security policy. Package name matched: " + packageNameMatched); } private static Status checkUidSha256Signature( @@ -289,9 +282,8 @@ private static Status checkUidSha256Signature( * * @param packageName the package to be checked * @param signatureCheckFunction {@link Predicate} that takes a signature and verifies if it - * satisfies any signature constraints - * return {@code true} if {@code packageName} has a signature that satisfies {@code - * signatureCheckFunction}. + * satisfies any signature constraints return {@code true} if {@code packageName} has a + * signature that satisfies {@code signatureCheckFunction}. */ @SuppressWarnings("deprecation") // For PackageInfo.signatures @SuppressLint("PackageManagerGetSignatures") // We only allow 1 signature. @@ -321,7 +313,7 @@ private static boolean checkPackageSignature( packageInfo = packageManager.getPackageInfo(packageName, PackageManager.GET_SIGNATURES); if (packageInfo.signatures == null || packageInfo.signatures.length != 1) { // Reject multiply-signed apks because of b/13678484 - // (See PackageManagerGetSignatures supression above). + // (See PackageManagerGetSignatures suppression above). return false; } @@ -423,23 +415,23 @@ public Status checkAuthorization(int uid) { /** * Creates a {@link SecurityPolicy} which checks if the caller has all of the given permissions * from {@code permissions}. - * + * *

The gRPC framework assumes that a {@link SecurityPolicy}'s verdict for a given peer UID will * not change over the lifetime of any process with that UID. But Android runtime permissions can - * be granted or revoked by the user at any time and so using the {@link #hasPermissions} - * {@link SecurityPolicy} comes with certain special responsibilities. - * - *

In particular, callers must ensure that the *subjects* of the returned - * {@link SecurityPolicy} hold all required {@code permissions} *before* making use of it. Android - * kills an app's processes when it loses any permission but the same isn't true when a permission - * is granted. And so without special care, a {@link #hasPermissions} denial could incorrectly + * be granted or revoked by the user at any time and so using the {@link #hasPermissions} {@link + * SecurityPolicy} comes with certain special responsibilities. + * + *

In particular, callers must ensure that the *subjects* of the returned {@link + * SecurityPolicy} hold all required {@code permissions} *before* making use of it. Android kills + * an app's processes when it loses any permission but the same isn't true when a permission is + * granted. And so without special care, a {@link #hasPermissions} denial could incorrectly * persist even if the subject is later granted all required {@code permissions}. - * + * *

A server using {@link #hasPermissions} must, as part of its RPC API contract, require * clients to request and receive all {@code permissions} before making a call. This is in line * with official Android guidance to request and confirm receipt of runtime permissions before - * using them. - * + * using them. + * *

A client, on the other hand, should only use {@link #hasPermissions} policies that require * install-time permissions which cannot change. * diff --git a/binder/src/main/java/io/grpc/binder/SecurityPolicy.java b/binder/src/main/java/io/grpc/binder/SecurityPolicy.java index 6b0fb40310a..3ad8903407f 100644 --- a/binder/src/main/java/io/grpc/binder/SecurityPolicy.java +++ b/binder/src/main/java/io/grpc/binder/SecurityPolicy.java @@ -16,16 +16,16 @@ package io.grpc.binder; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.Status; -import javax.annotation.CheckReturnValue; /** * Decides whether a given Android UID is authorized to access some resource. * - * While it's possible to extend this class to define your own policy, it's strongly - * recommended that you only use the policies provided by the {@link SecurityPolicies} or - * {@link UntrustedSecurityPolicies} classes. Implementing your own security policy requires - * significant care, and an understanding of the details and pitfalls of Android security. + *

While it's possible to extend this class to define your own policy, it's strongly recommended + * that you only use the policies provided by the {@link SecurityPolicies} or {@link + * UntrustedSecurityPolicies} classes. Implementing your own security policy requires significant + * care, and an understanding of the details and pitfalls of Android security. * *

IMPORTANT For any concrete extensions of this class, it's assumed that the * authorization status of a given UID will not change as long as a process with that UID is @@ -53,4 +53,25 @@ protected SecurityPolicy() {} * @return A gRPC {@link Status} object, with OK indicating authorized. */ public abstract Status checkAuthorization(int uid); + + /** + * Decides whether the given Android UID is authorized, without providing its raw integer value. + * + *

Calling this is equivalent to calling {@link SecurityPolicy#checkAuthorization(int)}, except + * the caller provides a {@link PeerUid} wrapper instead of the raw integer uid (known only to the + * transport). This allows a server to check additional application-layer security policy for + * itself *after* the call itself is authorized by the transport layer. Cross cutting application- + * layer checks could be done from a {@link io.grpc.ServerInterceptor}. Checks based on the + * substance of a request message could be done by the individual RPC method implementations + * themselves. + * + *

See #checkAuthorizationAsync(int) for details on the semantics. See {@link + * PeerUids#newPeerIdentifyingServerInterceptor()} for how to get a {@link PeerUid}. + * + * @param uid The Android UID to authenticate. + * @return A gRPC {@link Status} object, with OK indicating authorized. + */ + public final Status checkAuthorization(PeerUid uid) { + return checkAuthorization(uid.getUid()); + } } diff --git a/binder/src/main/java/io/grpc/binder/ServerSecurityPolicy.java b/binder/src/main/java/io/grpc/binder/ServerSecurityPolicy.java index ced973ede1c..4786a5e6cc4 100644 --- a/binder/src/main/java/io/grpc/binder/ServerSecurityPolicy.java +++ b/binder/src/main/java/io/grpc/binder/ServerSecurityPolicy.java @@ -19,15 +19,15 @@ import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.Status; import java.util.HashMap; import java.util.Map; -import javax.annotation.CheckReturnValue; /** * A security policy for a gRPC server. * - * Contains a default policy, and optional policies for each server. + *

Contains a default policy, and optional policies for each server. */ public final class ServerSecurityPolicy { @@ -61,8 +61,8 @@ public Status checkAuthorizationForService(int uid, String serviceName) { /** * Returns whether the given Android UID is authorized to access a particular service. * - *

This method never throws an exception. If the execution of the security policy check - * fails, a failed future with such exception is returned. + *

This method never throws an exception. If the execution of the security policy check fails, + * a failed future with such exception is returned. * * @param uid The Android UID to authenticate. * @param serviceName The name of the gRPC service being called. diff --git a/binder/src/main/java/io/grpc/binder/UntrustedSecurityPolicies.java b/binder/src/main/java/io/grpc/binder/UntrustedSecurityPolicies.java index 7c842b025ac..44612a82109 100644 --- a/binder/src/main/java/io/grpc/binder/UntrustedSecurityPolicies.java +++ b/binder/src/main/java/io/grpc/binder/UntrustedSecurityPolicies.java @@ -16,13 +16,11 @@ package io.grpc.binder; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.ExperimentalApi; import io.grpc.Status; -import javax.annotation.CheckReturnValue; -/** - * Static factory methods for creating untrusted security policies. - */ +/** Static factory methods for creating untrusted security policies. */ @CheckReturnValue @ExperimentalApi("https://github.com/grpc/grpc-java/issues/8022") public final class UntrustedSecurityPolicies { @@ -30,11 +28,9 @@ public final class UntrustedSecurityPolicies { private UntrustedSecurityPolicies() {} /** - * Return a security policy which allows any peer on device. - * Servers should only use this policy if they intend to expose - * a service to all applications on device. - * Clients should only use this policy if they don't need to trust the - * application they're connecting to. + * Return a security policy which allows any peer on device. Servers should only use this policy + * if they intend to expose a service to all applications on device. Clients should only use this + * policy if they don't need to trust the application they're connecting to. */ public static SecurityPolicy untrustedPublic() { return new SecurityPolicy() { diff --git a/binder/src/main/java/io/grpc/binder/internal/ActiveTransportTracker.java b/binder/src/main/java/io/grpc/binder/internal/ActiveTransportTracker.java new file mode 100644 index 00000000000..01505bfd509 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/ActiveTransportTracker.java @@ -0,0 +1,110 @@ +package io.grpc.binder.internal; + +import static com.google.common.base.Preconditions.checkState; + +import com.google.errorprone.annotations.concurrent.GuardedBy; +import io.grpc.Attributes; +import io.grpc.Metadata; +import io.grpc.internal.ServerListener; +import io.grpc.internal.ServerStream; +import io.grpc.internal.ServerTransport; +import io.grpc.internal.ServerTransportListener; + +/** + * Tracks which {@link BinderServerTransport} are currently active and allows invoking a {@link + * Runnable} only once all transports are terminated. + */ +final class ActiveTransportTracker implements ServerListener { + private final ServerListener delegate; + private final Runnable terminationListener; + + @GuardedBy("this") + private boolean shutdown = false; + + @GuardedBy("this") + private int activeTransportCount = 0; + + /** + * @param delegate the original server listener that this object decorates. Usually passed to + * {@link BinderServer#start(ServerListener)}. + * @param terminationListener invoked only once the server has started shutdown ({@link + * #serverShutdown()} AND the last active transport is terminated. + */ + ActiveTransportTracker(ServerListener delegate, Runnable terminationListener) { + this.delegate = delegate; + this.terminationListener = terminationListener; + } + + @Override + public ServerTransportListener transportCreated(ServerTransport transport) { + synchronized (this) { + checkState(!shutdown, "Illegal transportCreated() after serverShutdown()"); + activeTransportCount++; + } + ServerTransportListener originalListener = delegate.transportCreated(transport); + return new TrackedTransportListener(originalListener); + } + + private void untrack() { + Runnable maybeTerminationListener; + synchronized (this) { + activeTransportCount--; + maybeTerminationListener = getListenerIfTerminated(); + } + // Prefer running the listener outside of the synchronization lock to release it sooner, since + // we don't know how the callback is implemented nor how long it will take. This should + // minimize the possibility of deadlocks. + if (maybeTerminationListener != null) { + maybeTerminationListener.run(); + } + } + + @Override + public void serverShutdown() { + delegate.serverShutdown(); + Runnable maybeTerminationListener; + synchronized (this) { + shutdown = true; + maybeTerminationListener = getListenerIfTerminated(); + } + // We may be able to shutdown immediately if there are no active transports. + // + // Executed outside of the lock. See "untrack()" above. + if (maybeTerminationListener != null) { + maybeTerminationListener.run(); + } + } + + @GuardedBy("this") + private Runnable getListenerIfTerminated() { + return (shutdown && activeTransportCount == 0) ? terminationListener : null; + } + + /** + * Wraps a {@link ServerTransportListener}, unregistering it from the parent tracker once the + * transport terminates. + */ + private final class TrackedTransportListener implements ServerTransportListener { + private final ServerTransportListener delegate; + + TrackedTransportListener(ServerTransportListener delegate) { + this.delegate = delegate; + } + + @Override + public void streamCreated(ServerStream stream, String method, Metadata headers) { + delegate.streamCreated(stream, method, headers); + } + + @Override + public Attributes transportReady(Attributes attributes) { + return delegate.transportReady(attributes); + } + + @Override + public void transportTerminated() { + delegate.transportTerminated(); + untrack(); + } + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/Bindable.java b/binder/src/main/java/io/grpc/binder/internal/Bindable.java index 8e1af64b63d..59a2502de2b 100644 --- a/binder/src/main/java/io/grpc/binder/internal/Bindable.java +++ b/binder/src/main/java/io/grpc/binder/internal/Bindable.java @@ -16,10 +16,12 @@ package io.grpc.binder.internal; +import android.content.pm.ServiceInfo; import android.os.IBinder; import androidx.annotation.AnyThread; import androidx.annotation.MainThread; import io.grpc.Status; +import io.grpc.StatusException; /** An interface for managing a {@code Binder} connection. */ interface Bindable { @@ -45,6 +47,22 @@ interface Observer { void onUnbound(Status reason); } + /** + * Fetches details about the remote Service from PackageManager without binding to it. + * + *

Resolving an untrusted address before binding to it lets you screen out problematic servers + * before giving them a chance to run. However, note that the identity/existence of the resolved + * Service can change between the time this method returns and the time you actually bind/connect + * to it. For example, suppose the target package gets uninstalled or upgraded right after this + * method returns. + * + *

Compare with {@link #getConnectedServiceInfo()}, which can only be called after {@link + * Observer#onBound(IBinder)} but can be used to learn about the service you actually connected + * to. + */ + @AnyThread + ServiceInfo resolve() throws StatusException; + /** * Attempt to bind with the remote service. * @@ -53,6 +71,21 @@ interface Observer { @AnyThread void bind(); + /** + * Asks PackageManager for details about the remote Service we *actually* connected to. + * + *

Can only be called after {@link Observer#onBound}. + * + *

Compare with {@link #resolve()}, which reports which service would be selected as of now but + * *without* connecting. + * + * @throws StatusException UNIMPLEMENTED if the connected service isn't found (an {@link + * Observer#onUnbound} callback has likely already happened or is on its way!) + * @throws IllegalStateException if {@link Observer#onBound} has not "happened-before" this call + */ + @AnyThread + ServiceInfo getConnectedServiceInfo() throws StatusException; + /** * Unbind from the remote service if connected. * diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderClientTransport.java b/binder/src/main/java/io/grpc/binder/internal/BinderClientTransport.java new file mode 100644 index 00000000000..58e7d7e2b31 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/BinderClientTransport.java @@ -0,0 +1,547 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.grpc.binder.internal; + +import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.binder.ApiConstants.PRE_AUTH_SERVER_OVERRIDE; +import static java.util.concurrent.TimeUnit.MILLISECONDS; + +import android.content.Context; +import android.content.pm.ServiceInfo; +import android.os.Binder; +import android.os.IBinder; +import android.os.Parcel; +import android.os.Process; +import androidx.annotation.BinderThread; +import androidx.annotation.MainThread; +import com.google.common.base.Ticker; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.errorprone.annotations.CheckReturnValue; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import io.grpc.Attributes; +import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; +import io.grpc.Grpc; +import io.grpc.Internal; +import io.grpc.InternalLogId; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.SecurityLevel; +import io.grpc.Status; +import io.grpc.StatusException; +import io.grpc.binder.AndroidComponentAddress; +import io.grpc.binder.AsyncSecurityPolicy; +import io.grpc.binder.InboundParcelablePolicy; +import io.grpc.binder.SecurityPolicy; +import io.grpc.internal.ClientStream; +import io.grpc.internal.ClientTransportFactory.ClientTransportOptions; +import io.grpc.internal.ConnectionClientTransport; +import io.grpc.internal.FailingClientStream; +import io.grpc.internal.GrpcAttributes; +import io.grpc.internal.GrpcUtil; +import io.grpc.internal.ManagedClientTransport; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.SimpleDisconnectError; +import io.grpc.internal.StatsTraceContext; +import java.util.concurrent.Executor; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.atomic.AtomicInteger; +import javax.annotation.Nullable; +import javax.annotation.concurrent.ThreadSafe; + +/** Concrete client-side transport implementation. */ +@ThreadSafe +@Internal +public final class BinderClientTransport extends BinderTransport + implements ConnectionClientTransport, Bindable.Observer { + + private final ObjectPool offloadExecutorPool; + private final Executor offloadExecutor; + private final SecurityPolicy securityPolicy; + private final Bindable serviceBinding; + + @GuardedBy("this") + private final ClientHandshake handshake; + + /** Number of ongoing calls which keep this transport "in-use". */ + private final AtomicInteger numInUseStreams; + + private final long readyTimeoutMillis; + private final PingTracker pingTracker; + private final boolean preAuthorizeServer; + + @Nullable private ManagedClientTransport.Listener clientTransportListener; + + @GuardedBy("this") + private int latestCallId = FIRST_CALL_ID; + + @GuardedBy("this") + private ScheduledFuture readyTimeoutFuture; // != null iff timeout scheduled. + + /** + * Constructs a new transport instance. + * + * @param factory parameters common to all a Channel's transports + * @param targetAddress the fully resolved and load-balanced server address + * @param options other parameters that can vary as transports come and go within a Channel + */ + public BinderClientTransport( + BinderClientTransportFactory factory, + AndroidComponentAddress targetAddress, + ClientTransportOptions options) { + super( + factory.scheduledExecutorPool, + buildClientAttributes( + options.getEagAttributes(), + factory.sourceContext, + targetAddress, + factory.inboundParcelablePolicy), + factory.binderDecorator, + buildLogId(factory.sourceContext, targetAddress)); + this.offloadExecutorPool = factory.offloadExecutorPool; + this.securityPolicy = factory.securityPolicy; + this.offloadExecutor = offloadExecutorPool.getObject(); + this.readyTimeoutMillis = factory.readyTimeoutMillis; + Boolean preAuthServerOverride = options.getEagAttributes().get(PRE_AUTH_SERVER_OVERRIDE); + this.preAuthorizeServer = + preAuthServerOverride != null ? preAuthServerOverride : factory.preAuthorizeServers; + this.handshake = + factory.useLegacyAuthStrategy ? new LegacyClientHandshake() : new V2ClientHandshake(); + numInUseStreams = new AtomicInteger(); + pingTracker = new PingTracker(Ticker.systemTicker(), (id) -> sendPing(id)); + serviceBinding = + new ServiceBinding( + factory.mainThreadExecutor, + factory.sourceContext, + factory.channelCredentials, + targetAddress.asBindIntent(), + targetAddress.getTargetUser(), + factory.bindServiceFlags.toInteger(), + this); + } + + @Override + void releaseExecutors() { + super.releaseExecutors(); + offloadExecutorPool.returnObject(offloadExecutor); + } + + @Override + public synchronized void onBound(IBinder binder) { + handshake.onBound(binderDecorator.decorate(OneWayBinderProxy.wrap(binder, offloadExecutor))); + } + + @Override + public synchronized void onUnbound(Status reason) { + shutdownInternal(reason, true); + } + + @CheckReturnValue + @Override + public synchronized Runnable start(Listener clientTransportListener) { + this.clientTransportListener = checkNotNull(clientTransportListener); + return this::postStartRunnable; + } + + private synchronized void postStartRunnable() { + if (!inState(TransportState.NOT_STARTED)) { + return; + } + + setState(TransportState.SETUP); + + try { + if (preAuthorizeServer) { + preAuthorize(serviceBinding.resolve()); + } else { + serviceBinding.bind(); + } + } catch (StatusException e) { + shutdownInternal(e.getStatus(), true); + return; + } + + if (readyTimeoutMillis >= 0) { + readyTimeoutFuture = + getScheduledExecutorService() + .schedule( + BinderClientTransport.this::onReadyTimeout, readyTimeoutMillis, MILLISECONDS); + } + } + + @GuardedBy("this") + private void preAuthorize(ServiceInfo serviceInfo) { + // It's unlikely, but the identity/existence of this Service could change by the time we + // actually connect. It doesn't matter though, because: + // - If pre-auth fails (but would succeed against the server's new state), the grpc-core layer + // will eventually retry using a new transport instance that will see the Service's new state. + // - If pre-auth succeeds (but would fail against the server's new state), we might give an + // unauthorized server a chance to run, but the connection will still fail by SecurityPolicy + // check later in handshake. Pre-auth remains effective at mitigating abuse because malware + // can't typically control the exact timing of its installation. + ListenableFuture preAuthResultFuture = + register(checkServerAuthorizationAsync(serviceInfo.applicationInfo.uid)); + Futures.addCallback( + preAuthResultFuture, + new FutureCallback() { + @Override + public void onSuccess(Status result) { + handlePreAuthResult(result); + } + + @Override + public void onFailure(Throwable t) { + handleAuthResult(t); + } + }, + offloadExecutor); + } + + private synchronized void handlePreAuthResult(Status authorization) { + if (!inState(TransportState.SETUP)) { + return; + } + + if (!authorization.isOk()) { + shutdownInternal(authorization, true); + return; + } + + serviceBinding.bind(); + } + + private synchronized void onReadyTimeout() { + if (inState(TransportState.SETUP)) { + readyTimeoutFuture = null; + shutdownInternal( + Status.DEADLINE_EXCEEDED.withDescription( + "Connect timeout " + readyTimeoutMillis + "ms lapsed"), + true); + } + } + + @Override + public synchronized ClientStream newStream( + final MethodDescriptor method, + final Metadata headers, + final CallOptions callOptions, + ClientStreamTracer[] tracers) { + if (!inState(TransportState.READY)) { + return newFailingClientStream( + isShutdown() + ? shutdownStatus + : Status.INTERNAL.withDescription("newStream() before transportReady()"), + attributes, + headers, + tracers); + } + + int callId = latestCallId++; + if (latestCallId == LAST_CALL_ID) { + latestCallId = FIRST_CALL_ID; + } + StatsTraceContext statsTraceContext = + StatsTraceContext.newClientContext(tracers, attributes, headers); + Inbound.ClientInbound inbound = + new Inbound.ClientInbound( + this, attributes, callId, GrpcUtil.shouldBeCountedForInUse(callOptions)); + if (ongoingCalls.putIfAbsent(callId, inbound) != null) { + Status failure = Status.INTERNAL.withDescription("Clashing call IDs"); + shutdownInternal(failure, true); + return newFailingClientStream(failure, attributes, headers, tracers); + } + + if (inbound.countsForInUse() && numInUseStreams.getAndIncrement() == 0) { + clientTransportListener.transportInUse(true); + } + Outbound.ClientOutbound outbound = + new Outbound.ClientOutbound(this, callId, method, headers, statsTraceContext); + if (method.getType().clientSendsOneMessage()) { + return new SingleMessageClientStream(inbound, outbound, attributes); + } else { + return new MultiMessageClientStream(inbound, outbound, attributes); + } + } + + @Override + protected void unregisterInbound(Inbound inbound) { + if (inbound.countsForInUse() && numInUseStreams.decrementAndGet() == 0) { + clientTransportListener.transportInUse(false); + } + super.unregisterInbound(inbound); + } + + @Override + public void ping(final PingCallback callback, Executor executor) { + pingTracker.startPing(callback, executor); + } + + @Override + public synchronized void shutdown(Status reason) { + checkNotNull(reason, "reason"); + shutdownInternal(reason, false); + } + + @Override + public synchronized void shutdownNow(Status reason) { + checkNotNull(reason, "reason"); + shutdownInternal(reason, true); + } + + @Override + @GuardedBy("this") + void notifyShutdown(Status status) { + clientTransportListener.transportShutdown(status, SimpleDisconnectError.UNKNOWN); + } + + @Override + @GuardedBy("this") + void notifyTerminated() { + if (numInUseStreams.getAndSet(0) > 0) { + clientTransportListener.transportInUse(false); + } + if (readyTimeoutFuture != null) { + readyTimeoutFuture.cancel(false); + readyTimeoutFuture = null; + } + serviceBinding.unbind(); + clientTransportListener.transportTerminated(); + } + + @Override + @GuardedBy("this") + protected void handleSetupTransport(Parcel parcel) { + if (!inState(TransportState.SETUP)) { + return; + } + + int version = parcel.readInt(); + if (version != WIRE_FORMAT_VERSION) { + shutdownInternal(Status.UNAVAILABLE.withDescription("Wire format version mismatch"), true); + return; + } + + IBinder binder = parcel.readStrongBinder(); + if (binder == null) { + shutdownInternal(Status.UNAVAILABLE.withDescription("Malformed SETUP_TRANSPORT data"), true); + return; + } + + if (!setOutgoingBinder(OneWayBinderProxy.wrap(binder, offloadExecutor))) { + shutdownInternal( + Status.UNAVAILABLE.withDescription("Failed to observe outgoing binder"), true); + return; + } + handshake.handleSetupTransport(); + } + + @GuardedBy("this") + private void checkServerAuthorization(int remoteUid) { + ListenableFuture authResultFuture = register(checkServerAuthorizationAsync(remoteUid)); + Futures.addCallback( + authResultFuture, + new FutureCallback() { + @Override + public void onSuccess(Status result) { + handleAuthResult(result); + } + + @Override + public void onFailure(Throwable t) { + handleAuthResult(t); + } + }, + offloadExecutor); + } + + private ListenableFuture checkServerAuthorizationAsync(int remoteUid) { + return (securityPolicy instanceof AsyncSecurityPolicy) + ? ((AsyncSecurityPolicy) securityPolicy).checkAuthorizationAsync(remoteUid) + : Futures.submit(() -> securityPolicy.checkAuthorization(remoteUid), offloadExecutor); + } + + private synchronized void handleAuthResult(Status authorization) { + if (!inState(TransportState.SETUP)) { + return; + } + + if (!authorization.isOk()) { + shutdownInternal(authorization, true); + return; + } + handshake.onServerAuthorizationOk(); + } + + private final class V2ClientHandshake implements ClientHandshake { + + private OneWayBinderProxy endpointBinder; + + @Override + @GuardedBy("BinderClientTransport.this") // By way of @GuardedBy("this") `handshake` member. + public void onBound(OneWayBinderProxy endpointBinder) { + this.endpointBinder = endpointBinder; + Futures.addCallback( + Futures.submit(serviceBinding::getConnectedServiceInfo, offloadExecutor), + new FutureCallback() { + @Override + public void onSuccess(ServiceInfo result) { + synchronized (BinderClientTransport.this) { + onConnectedServiceInfo(result); + } + } + + @Override + public void onFailure(Throwable t) { + synchronized (BinderClientTransport.this) { + shutdownInternal(Status.fromThrowable(t), true); + } + } + }, + offloadExecutor); + } + + @GuardedBy("BinderClientTransport.this") + private void onConnectedServiceInfo(ServiceInfo serviceInfo) { + if (!inState(TransportState.SETUP)) { + return; + } + attributes = setSecurityAttrs(attributes, serviceInfo.applicationInfo.uid); + checkServerAuthorization(serviceInfo.applicationInfo.uid); + } + + @Override + @GuardedBy("BinderClientTransport.this") + public void onServerAuthorizationOk() { + sendSetupTransaction(endpointBinder); + } + + @Override + @GuardedBy("BinderClientTransport.this") // By way of @GuardedBy("this") `handshake` member. + public void handleSetupTransport() { + onHandshakeComplete(); + } + } + + @GuardedBy("this") + private void onHandshakeComplete() { + setState(TransportState.READY); + attributes = clientTransportListener.filterTransport(attributes); + clientTransportListener.transportReady(); + if (readyTimeoutFuture != null) { + readyTimeoutFuture.cancel(false); + readyTimeoutFuture = null; + } + } + + private synchronized void handleAuthResult(Throwable t) { + shutdownInternal( + Status.INTERNAL.withDescription("Could not evaluate SecurityPolicy").withCause(t), true); + } + + @GuardedBy("this") + @Override + protected void handlePingResponse(Parcel parcel) { + pingTracker.onPingResponse(parcel.readInt()); + } + + /** + * An abstract implementation of the client's connection handshake. + * + *

Supports a clean migration away from the legacy approach, one client at a time. + */ + private interface ClientHandshake { + /** + * Notifies the implementation that the binding has succeeded and we are now connected to the + * server's "endpoint" which can be reached at 'endpointBinder'. + */ + @MainThread + void onBound(OneWayBinderProxy endpointBinder); + + /** Notifies the implementation that we've received a valid SETUP_TRANSPORT transaction. */ + @BinderThread + void handleSetupTransport(); + + /** Notifies the implementation that the SecurityPolicy check of the server succeeded. */ + void onServerAuthorizationOk(); + } + + private final class LegacyClientHandshake implements ClientHandshake { + @Override + @MainThread + @GuardedBy("BinderClientTransport.this") // By way of @GuardedBy("this") `handshake` member. + public void onBound(OneWayBinderProxy binder) { + sendSetupTransaction(binder); + } + + @Override + @BinderThread + @GuardedBy("BinderClientTransport.this") // By way of @GuardedBy("this") `handshake` member. + public void handleSetupTransport() { + int remoteUid = Binder.getCallingUid(); + restrictIncomingBinderToCallsFrom(remoteUid); + attributes = setSecurityAttrs(attributes, remoteUid); + checkServerAuthorization(remoteUid); + } + + @Override + @GuardedBy("BinderClientTransport.this") // By way of @GuardedBy("this") `handshake` member. + public void onServerAuthorizationOk() { + onHandshakeComplete(); + } + } + + private static ClientStream newFailingClientStream( + Status failure, Attributes attributes, Metadata headers, ClientStreamTracer[] tracers) { + StatsTraceContext statsTraceContext = + StatsTraceContext.newClientContext(tracers, attributes, headers); + statsTraceContext.clientOutboundHeaders(); + return new FailingClientStream(failure, tracers); + } + + private static InternalLogId buildLogId( + Context sourceContext, AndroidComponentAddress targetAddress) { + return InternalLogId.allocate( + BinderClientTransport.class, + sourceContext.getClass().getSimpleName() + "->" + targetAddress); + } + + private static Attributes buildClientAttributes( + Attributes eagAttrs, + Context sourceContext, + AndroidComponentAddress targetAddress, + InboundParcelablePolicy inboundParcelablePolicy) { + return Attributes.newBuilder() + .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.NONE) // Trust noone for now. + .set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, eagAttrs) + .set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, AndroidComponentAddress.forContext(sourceContext)) + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, targetAddress) + .set(INBOUND_PARCELABLE_POLICY, inboundParcelablePolicy) + .build(); + } + + private static Attributes setSecurityAttrs(Attributes attributes, int uid) { + return attributes.toBuilder() + .set(REMOTE_UID, uid) + .set( + GrpcAttributes.ATTR_SECURITY_LEVEL, + uid == Process.myUid() + ? SecurityLevel.PRIVACY_AND_INTEGRITY + : SecurityLevel.INTEGRITY) // TODO: Have the SecrityPolicy decide this. + .build(); + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderClientTransportFactory.java b/binder/src/main/java/io/grpc/binder/internal/BinderClientTransportFactory.java new file mode 100644 index 00000000000..459e064ad9b --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/BinderClientTransportFactory.java @@ -0,0 +1,232 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.grpc.binder.internal; + +import static com.google.common.base.Preconditions.checkNotNull; + +import android.content.Context; +import androidx.core.content.ContextCompat; +import io.grpc.ChannelCredentials; +import io.grpc.ChannelLogger; +import io.grpc.Internal; +import io.grpc.binder.AndroidComponentAddress; +import io.grpc.binder.BindServiceFlags; +import io.grpc.binder.BinderChannelCredentials; +import io.grpc.binder.InboundParcelablePolicy; +import io.grpc.binder.SecurityPolicies; +import io.grpc.binder.SecurityPolicy; +import io.grpc.internal.ClientTransportFactory; +import io.grpc.internal.GrpcUtil; +import io.grpc.internal.ManagedChannelImplBuilder.ClientTransportFactoryBuilder; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.SharedResourcePool; +import java.net.SocketAddress; +import java.util.Collection; +import java.util.Collections; +import java.util.concurrent.Executor; +import java.util.concurrent.ScheduledExecutorService; + +/** Creates new binder transports. */ +@Internal +public final class BinderClientTransportFactory implements ClientTransportFactory { + final Context sourceContext; + final BinderChannelCredentials channelCredentials; + final Executor mainThreadExecutor; + final ObjectPool scheduledExecutorPool; + final ObjectPool offloadExecutorPool; + final SecurityPolicy securityPolicy; + final BindServiceFlags bindServiceFlags; + final InboundParcelablePolicy inboundParcelablePolicy; + final OneWayBinderProxy.Decorator binderDecorator; + final long readyTimeoutMillis; + final boolean preAuthorizeServers; // TODO(jdcormie): Default to true. + final boolean useLegacyAuthStrategy; + + ScheduledExecutorService executorService; + Executor offloadExecutor; + private boolean closed; + + private BinderClientTransportFactory(Builder builder) { + sourceContext = checkNotNull(builder.sourceContext); + channelCredentials = checkNotNull(builder.channelCredentials); + mainThreadExecutor = + builder.mainThreadExecutor != null + ? builder.mainThreadExecutor + : ContextCompat.getMainExecutor(sourceContext); + scheduledExecutorPool = checkNotNull(builder.scheduledExecutorPool); + offloadExecutorPool = checkNotNull(builder.offloadExecutorPool); + securityPolicy = checkNotNull(builder.securityPolicy); + bindServiceFlags = checkNotNull(builder.bindServiceFlags); + inboundParcelablePolicy = checkNotNull(builder.inboundParcelablePolicy); + binderDecorator = checkNotNull(builder.binderDecorator); + readyTimeoutMillis = builder.readyTimeoutMillis; + preAuthorizeServers = builder.preAuthorizeServers; + useLegacyAuthStrategy = builder.useLegacyAuthStrategy; + + executorService = scheduledExecutorPool.getObject(); + offloadExecutor = offloadExecutorPool.getObject(); + } + + @Override + public BinderClientTransport newClientTransport( + SocketAddress addr, ClientTransportOptions options, ChannelLogger channelLogger) { + if (closed) { + throw new IllegalStateException("The transport factory is closed."); + } + return new BinderClientTransport(this, (AndroidComponentAddress) addr, options); + } + + @Override + public ScheduledExecutorService getScheduledExecutorService() { + return executorService; + } + + @Override + public SwapChannelCredentialsResult swapChannelCredentials(ChannelCredentials channelCreds) { + return null; + } + + @Override + public void close() { + closed = true; + executorService = scheduledExecutorPool.returnObject(executorService); + offloadExecutor = offloadExecutorPool.returnObject(offloadExecutor); + } + + @Override + public Collection> getSupportedSocketAddressTypes() { + return Collections.singleton(AndroidComponentAddress.class); + } + + /** Allows fluent construction of ClientTransportFactory. */ + public static final class Builder implements ClientTransportFactoryBuilder { + // Required. + Context sourceContext; + ObjectPool offloadExecutorPool; + + // Optional. + BinderChannelCredentials channelCredentials = BinderChannelCredentials.forDefault(); + Executor mainThreadExecutor; // Default filled-in at build time once sourceContext is decided. + ObjectPool scheduledExecutorPool = + SharedResourcePool.forResource(GrpcUtil.TIMER_SERVICE); + SecurityPolicy securityPolicy = SecurityPolicies.internalOnly(); + BindServiceFlags bindServiceFlags = BindServiceFlags.DEFAULTS; + InboundParcelablePolicy inboundParcelablePolicy = InboundParcelablePolicy.DEFAULT; + OneWayBinderProxy.Decorator binderDecorator = OneWayBinderProxy.IDENTITY_DECORATOR; + long readyTimeoutMillis = 60_000; + boolean preAuthorizeServers; + boolean useLegacyAuthStrategy = true; // TODO(jdcormie): Default to false. + + @Override + public BinderClientTransportFactory buildClientTransportFactory() { + return new BinderClientTransportFactory(this); + } + + public Builder setSourceContext(Context sourceContext) { + this.sourceContext = checkNotNull(sourceContext); + return this; + } + + public Context getSourceContext() { + return sourceContext; + } + + public Builder setOffloadExecutorPool(ObjectPool offloadExecutorPool) { + this.offloadExecutorPool = checkNotNull(offloadExecutorPool, "offloadExecutorPool"); + return this; + } + + public Builder setChannelCredentials(BinderChannelCredentials channelCredentials) { + this.channelCredentials = checkNotNull(channelCredentials, "channelCredentials"); + return this; + } + + public Builder setMainThreadExecutor(Executor mainThreadExecutor) { + this.mainThreadExecutor = checkNotNull(mainThreadExecutor, "mainThreadExecutor"); + return this; + } + + public Builder setScheduledExecutorPool( + ObjectPool scheduledExecutorPool) { + this.scheduledExecutorPool = checkNotNull(scheduledExecutorPool, "scheduledExecutorPool"); + return this; + } + + public Builder setSecurityPolicy(SecurityPolicy securityPolicy) { + this.securityPolicy = checkNotNull(securityPolicy, "securityPolicy"); + return this; + } + + public Builder setBindServiceFlags(BindServiceFlags bindServiceFlags) { + this.bindServiceFlags = checkNotNull(bindServiceFlags, "bindServiceFlags"); + return this; + } + + public Builder setInboundParcelablePolicy(InboundParcelablePolicy inboundParcelablePolicy) { + this.inboundParcelablePolicy = + checkNotNull(inboundParcelablePolicy, "inboundParcelablePolicy"); + return this; + } + + /** + * Decorates both the "endpoint" and "server" binders, for fault injection. + * + *

Optional. If absent, these objects will go undecorated. + */ + public Builder setBinderDecorator(OneWayBinderProxy.Decorator binderDecorator) { + this.binderDecorator = checkNotNull(binderDecorator, "binderDecorator"); + return this; + } + + /** + * Limits how long it can take to for a new transport to become ready after being started. + * + *

This process currently includes: + * + *

    + *
  • Creating an Android binding. + *
  • Waiting for Android to create the server process. + *
  • Waiting for the remote Service to be created and handle onBind(). + *
  • Exchanging handshake transactions according to the wire protocol. + *
  • Evaluating a {@link SecurityPolicy} on both sides. + *
+ * + *

This setting doesn't change the need for deadlines at the call level. It merely ensures + * that gRPC features like load balancing and + * fail-fast work + * as expected despite certain edge cases that could otherwise stall the transport indefinitely. + * + *

Optional but enabled by default. Use a negative value to wait indefinitely. + */ + public Builder setReadyTimeoutMillis(long readyTimeoutMillis) { + this.readyTimeoutMillis = readyTimeoutMillis; + return this; + } + + /** Whether to check server addresses against the SecurityPolicy *before* binding to them. */ + public Builder setPreAuthorizeServers(boolean preAuthorizeServers) { + this.preAuthorizeServers = preAuthorizeServers; + return this; + } + + /** Specifies which version of the client handshake to use. */ + public Builder setUseLegacyAuthStrategy(boolean useLegacyAuthStrategy) { + this.useLegacyAuthStrategy = useLegacyAuthStrategy; + return this; + } + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderServer.java b/binder/src/main/java/io/grpc/binder/internal/BinderServer.java index 6dd2f231650..f913775fcbe 100644 --- a/binder/src/main/java/io/grpc/binder/internal/BinderServer.java +++ b/binder/src/main/java/io/grpc/binder/internal/BinderServer.java @@ -16,31 +16,39 @@ package io.grpc.binder.internal; +import static android.os.IBinder.FLAG_ONEWAY; import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.binder.internal.BinderTransport.SHUTDOWN_TRANSPORT; import android.os.Binder; import android.os.IBinder; import android.os.Parcel; +import android.os.RemoteException; import com.google.common.collect.ImmutableList; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.Grpc; import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalInstrumented; import io.grpc.ServerStreamTracer; import io.grpc.binder.AndroidComponentAddress; +import io.grpc.binder.BinderInternal; import io.grpc.binder.InboundParcelablePolicy; +import io.grpc.binder.SecurityPolicies; import io.grpc.binder.ServerSecurityPolicy; import io.grpc.internal.GrpcUtil; import io.grpc.internal.InternalServer; import io.grpc.internal.ObjectPool; import io.grpc.internal.ServerListener; -import io.grpc.internal.SharedResourceHolder; +import io.grpc.internal.SharedResourcePool; import java.io.IOException; import java.net.SocketAddress; import java.util.List; +import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; +import java.util.logging.Level; +import java.util.logging.Logger; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; /** @@ -48,19 +56,21 @@ * *

Multiple incoming connections transports may be active at a time. * - * IMPORTANT: This implementation must comply with this published wire format. + *

IMPORTANT: This implementation must comply with this published wire format. * https://github.com/grpc/proposal/blob/master/L73-java-binderchannel/wireformat.md */ @ThreadSafe public final class BinderServer implements InternalServer, LeakSafeOneWayBinder.TransactionHandler { + private static final Logger logger = Logger.getLogger(BinderServer.class.getName()); private final ObjectPool executorServicePool; + private final ObjectPool executorPool; private final ImmutableList streamTracerFactories; private final AndroidComponentAddress listenAddress; private final LeakSafeOneWayBinder hostServiceBinder; private final BinderTransportSecurity.ServerPolicyChecker serverPolicyChecker; private final InboundParcelablePolicy inboundParcelablePolicy; - private final BinderTransportSecurity.ShutdownListener transportSecurityShutdownListener; + private final OneWayBinderProxy.Decorator clientBinderDecorator; @GuardedBy("this") private ServerListener listener; @@ -68,27 +78,22 @@ public final class BinderServer implements InternalServer, LeakSafeOneWayBinder. @GuardedBy("this") private ScheduledExecutorService executorService; + @Nullable // Before start() and after termination. + @GuardedBy("this") + private Executor executor; + @GuardedBy("this") private boolean shutdown; - /** - * @param transportSecurityShutdownListener represents resources that should be cleaned up once - * the server shuts down. - */ - public BinderServer( - AndroidComponentAddress listenAddress, - ObjectPool executorServicePool, - List streamTracerFactories, - BinderTransportSecurity.ServerPolicyChecker serverPolicyChecker, - InboundParcelablePolicy inboundParcelablePolicy, - BinderTransportSecurity.ShutdownListener transportSecurityShutdownListener) { - this.listenAddress = listenAddress; - this.executorServicePool = executorServicePool; + private BinderServer(Builder builder) { + this.listenAddress = checkNotNull(builder.listenAddress); + this.executorPool = checkNotNull(builder.executorPool); + this.executorServicePool = builder.executorServicePool; this.streamTracerFactories = - ImmutableList.copyOf(checkNotNull(streamTracerFactories, "streamTracerFactories")); - this.serverPolicyChecker = checkNotNull(serverPolicyChecker, "serverPolicyChecker"); - this.inboundParcelablePolicy = inboundParcelablePolicy; - this.transportSecurityShutdownListener = transportSecurityShutdownListener; + ImmutableList.copyOf(checkNotNull(builder.streamTracerFactories, "streamTracerFactories")); + this.serverPolicyChecker = BinderInternal.createPolicyChecker(builder.serverSecurityPolicy); + this.inboundParcelablePolicy = builder.inboundParcelablePolicy; + this.clientBinderDecorator = builder.clientBinderDecorator; hostServiceBinder = new LeakSafeOneWayBinder(this); } @@ -99,8 +104,9 @@ public IBinder getHostBinder() { @Override public synchronized void start(ServerListener serverListener) throws IOException { - this.listener = serverListener; + listener = new ActiveTransportTracker(serverListener, this::onTerminated); executorService = executorServicePool.getObject(); + executor = executorPool.getObject(); } @Override @@ -129,13 +135,17 @@ public synchronized void shutdown() { if (!shutdown) { shutdown = true; // Break the connection to the binder. We'll receive no more transactions. - hostServiceBinder.detach(); + hostServiceBinder.setHandler(GoAwayHandler.INSTANCE); listener.serverShutdown(); + // TODO(jdcormie): Shouldn't this happen in onTerminated()? Is this even used anywhere? executorService = executorServicePool.returnObject(executorService); - transportSecurityShutdownListener.onServerShutdown(); } } + private synchronized void onTerminated() { + executor = executorPool.returnObject(executor); + } + @Override public String toString() { return "BinderServer[" + listenAddress + "]"; @@ -144,6 +154,12 @@ public String toString() { @Override public synchronized boolean handleTransaction(int code, Parcel parcel) { if (code == BinderTransport.SETUP_TRANSPORT) { + if (shutdown) { + // An incoming SETUP_TRANSPORT transaction may have already been in-flight when we removed + // ourself as TransactionHandler in #shutdown(). So we must check for shutdown again here. + return GoAwayHandler.INSTANCE.handleTransaction(code, parcel); + } + int version = parcel.readInt(); // If the client-provided version is more recent, we accept the connection, // but specify the older version which we support. @@ -158,18 +174,143 @@ public synchronized boolean handleTransaction(int code, Parcel parcel) { .set(BinderTransport.REMOTE_UID, callingUid) .set(BinderTransport.SERVER_AUTHORITY, listenAddress.getAuthority()) .set(BinderTransport.INBOUND_PARCELABLE_POLICY, inboundParcelablePolicy); - BinderTransportSecurity.attachAuthAttrs(attrsBuilder, callingUid, serverPolicyChecker); + BinderTransportSecurity.attachAuthAttrs( + attrsBuilder, + callingUid, + serverPolicyChecker, + checkNotNull(executor, "Not started?")); // Create a new transport and let our listener know about it. - BinderTransport.BinderServerTransport transport = - new BinderTransport.BinderServerTransport( - executorServicePool, attrsBuilder.build(), streamTracerFactories, - OneWayBinderProxy.IDENTITY_DECORATOR, + BinderServerTransport transport = + BinderServerTransport.create( + executorServicePool, + attrsBuilder.build(), + streamTracerFactories, + clientBinderDecorator, callbackBinder); - transport.setServerTransportListener(listener.transportCreated(transport)); + transport.start(listener.transportCreated(transport)); return true; } } } return false; } + + static final class GoAwayHandler implements LeakSafeOneWayBinder.TransactionHandler { + static final GoAwayHandler INSTANCE = new GoAwayHandler(); + + @Override + public boolean handleTransaction(int code, Parcel parcel) { + if (code == BinderTransport.SETUP_TRANSPORT) { + int version = parcel.readInt(); + if (version >= BinderTransport.EARLIEST_SUPPORTED_WIRE_FORMAT_VERSION) { + IBinder callbackBinder = parcel.readStrongBinder(); + try (ParcelHolder goAwayReply = ParcelHolder.obtain()) { + // Send empty flags to avoid a memory leak linked to empty parcels (b/207778694). + goAwayReply.get().writeInt(0); + callbackBinder.transact(SHUTDOWN_TRANSPORT, goAwayReply.get(), null, FLAG_ONEWAY); + } catch (RemoteException re) { + logger.log(Level.WARNING, "Couldn't reply to post-shutdown() SETUP_TRANSPORT.", re); + } + } + } + return false; + } + } + + /** Fluent builder of {@link BinderServer} instances. */ + public static class Builder { + @Nullable AndroidComponentAddress listenAddress; + @Nullable List streamTracerFactories; + @Nullable ObjectPool executorPool; + + ObjectPool executorServicePool = + SharedResourcePool.forResource(GrpcUtil.TIMER_SERVICE); + ServerSecurityPolicy serverSecurityPolicy = SecurityPolicies.serverInternalOnly(); + InboundParcelablePolicy inboundParcelablePolicy = InboundParcelablePolicy.DEFAULT; + OneWayBinderProxy.Decorator clientBinderDecorator = OneWayBinderProxy.IDENTITY_DECORATOR; + + public BinderServer build() { + return new BinderServer(this); + } + + /** + * Sets the "listen" address for this server. + * + *

This is somewhat of a grpc-java formality. Binder servers don't really listen, rather, + * Android creates and destroys them according to client needs. + * + *

Required. + */ + public Builder setListenAddress(AndroidComponentAddress listenAddress) { + this.listenAddress = listenAddress; + return this; + } + + /** + * Sets the source for {@link ServerStreamTracer}s that will be installed on all new streams. + * + *

Required. + */ + public Builder setStreamTracerFactories( + List streamTracerFactories) { + this.streamTracerFactories = streamTracerFactories; + return this; + } + + /** + * Sets the executor to be used for calling into the application. + * + *

Required. + */ + public Builder setExecutorPool(ObjectPool executorPool) { + this.executorPool = executorPool; + return this; + } + + /** + * Sets the executor to be used for scheduling channel timers. + * + *

Optional. A process-wide default executor will be used if unset. + */ + public Builder setExecutorServicePool( + ObjectPool executorServicePool) { + this.executorServicePool = checkNotNull(executorServicePool, "executorServicePool"); + return this; + } + + /** + * Sets the {@link ServerSecurityPolicy} to be used for built servers. + * + *

Optional, {@link SecurityPolicies#serverInternalOnly()} is the default. + */ + public Builder setServerSecurityPolicy(ServerSecurityPolicy serverSecurityPolicy) { + this.serverSecurityPolicy = checkNotNull(serverSecurityPolicy, "serverSecurityPolicy"); + return this; + } + + /** + * Sets the {@link InboundParcelablePolicy} to be used for built servers. + * + *

Optional, {@link InboundParcelablePolicy#DEFAULT} is the default. + */ + public Builder setInboundParcelablePolicy(InboundParcelablePolicy inboundParcelablePolicy) { + this.inboundParcelablePolicy = + checkNotNull(inboundParcelablePolicy, "inboundParcelablePolicy"); + return this; + } + + /** + * Sets the {@link OneWayBinderProxy.Decorator} to be applied to this server's "client Binders". + * + *

Tests can use this to capture post-setup transactions from server to client. The specified + * decorator will be applied every time a client connects. The decorated result will be used for + * all subsequent transactions to this client from the new ServerTransport. + * + *

Optional, {@link OneWayBinderProxy#IDENTITY_DECORATOR} is the default. + */ + public Builder setClientBinderDecorator(OneWayBinderProxy.Decorator clientBinderDecorator) { + this.clientBinderDecorator = checkNotNull(clientBinderDecorator); + return this; + } + } } diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderServerTransport.java b/binder/src/main/java/io/grpc/binder/internal/BinderServerTransport.java new file mode 100644 index 00000000000..784d833bdf5 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/BinderServerTransport.java @@ -0,0 +1,157 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.grpc.binder.internal; + +import android.os.IBinder; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import io.grpc.Attributes; +import io.grpc.Grpc; +import io.grpc.Internal; +import io.grpc.InternalLogId; +import io.grpc.Metadata; +import io.grpc.ServerStreamTracer; +import io.grpc.Status; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.ServerStream; +import io.grpc.internal.ServerTransport; +import io.grpc.internal.ServerTransportListener; +import io.grpc.internal.StatsTraceContext; +import java.util.List; +import java.util.concurrent.ScheduledExecutorService; +import javax.annotation.Nullable; + +/** Concrete server-side transport implementation. */ +@Internal +public final class BinderServerTransport extends BinderTransport implements ServerTransport { + + private final List streamTracerFactories; + + @GuardedBy("this") + private final SimplePromise listenerPromise = new SimplePromise<>(); + + private BinderServerTransport( + ObjectPool executorServicePool, + Attributes attributes, + List streamTracerFactories, + OneWayBinderProxy.Decorator binderDecorator) { + super(executorServicePool, attributes, binderDecorator, buildLogId(attributes)); + this.streamTracerFactories = streamTracerFactories; + } + + /** + * Constructs a new transport instance. + * + * @param binderDecorator used to decorate 'callbackBinder', for fault injection. + */ + public static BinderServerTransport create( + ObjectPool executorServicePool, + Attributes attributes, + List streamTracerFactories, + OneWayBinderProxy.Decorator binderDecorator, + IBinder callbackBinder) { + BinderServerTransport transport = + new BinderServerTransport( + executorServicePool, attributes, streamTracerFactories, binderDecorator); + // TODO(jdcormie): Plumb in the Server's executor() and use it here instead. + // No need to handle failure here because if 'callbackBinder' is already dead, we'll notice it + // again in start() when we send the first transaction. + synchronized (transport) { + transport.setOutgoingBinder( + OneWayBinderProxy.wrap(callbackBinder, transport.getScheduledExecutorService())); + } + return transport; + } + + /** + * Initializes this transport instance. + * + *

Must be called exactly once, even if {@link #shutdown} or {@link #shutdownNow} was called + * first. + * + * @param serverTransportListener where this transport will report events + */ + public synchronized void start(ServerTransportListener serverTransportListener) { + this.listenerPromise.set(serverTransportListener); + if (isShutdown()) { + // It's unlikely, but we could be shutdown externally between construction and start(). One + // possible cause is an extremely short handshake timeout. + return; + } + + sendSetupTransaction(); + + // Check we're not shutdown again, since a failure inside sendSetupTransaction (or a callback + // it triggers), could have shut us down. + if (isShutdown()) { + return; + } + + setState(TransportState.READY); + attributes = serverTransportListener.transportReady(attributes); + } + + StatsTraceContext createStatsTraceContext(String methodName, Metadata headers) { + return StatsTraceContext.newServerContext(streamTracerFactories, methodName, headers); + } + + /** + * Reports a new ServerStream requested by the remote client. + * + *

Precondition: {@link #start(ServerTransportListener)} must already have been called. + */ + synchronized Status startStream(ServerStream stream, String methodName, Metadata headers) { + if (isShutdown()) { + return Status.UNAVAILABLE.withDescription("transport is shutdown"); + } + + listenerPromise.get().streamCreated(stream, methodName, headers); + return Status.OK; + } + + @Override + @GuardedBy("this") + void notifyShutdown(Status status) { + // Nothing to do. + } + + @Override + @GuardedBy("this") + void notifyTerminated() { + listenerPromise.runWhenSet(ServerTransportListener::transportTerminated); + } + + @Override + public synchronized void shutdown() { + shutdownInternal(Status.OK, false); + } + + @Override + public synchronized void shutdownNow(Status reason) { + shutdownInternal(reason, true); + } + + @Override + @Nullable + @GuardedBy("this") + protected Inbound createInbound(int callId) { + return new Inbound.ServerInbound(this, attributes, callId); + } + + private static InternalLogId buildLogId(Attributes attributes) { + return InternalLogId.allocate( + BinderServerTransport.class, "from " + attributes.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR)); + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java b/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java index 4a33adb2154..2b7aa97bfd9 100644 --- a/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java +++ b/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java @@ -19,49 +19,29 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; import static com.google.common.util.concurrent.Futures.immediateFuture; +import static io.grpc.binder.internal.TransactionUtils.newCallerFilteringHandler; -import android.content.Context; -import android.os.Binder; import android.os.DeadObjectException; import android.os.IBinder; import android.os.Parcel; -import android.os.Process; import android.os.RemoteException; import android.os.TransactionTooLargeException; -import android.os.UserHandle; +import androidx.annotation.BinderThread; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Ticker; import com.google.common.base.Verify; import com.google.common.util.concurrent.ListenableFuture; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; -import io.grpc.CallOptions; -import io.grpc.ClientStreamTracer; import io.grpc.Grpc; import io.grpc.Internal; +import io.grpc.InternalChannelz; import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalLogId; -import io.grpc.Metadata; -import io.grpc.MethodDescriptor; -import io.grpc.SecurityLevel; -import io.grpc.ServerStreamTracer; import io.grpc.Status; import io.grpc.StatusException; -import io.grpc.binder.AndroidComponentAddress; -import io.grpc.binder.BindServiceFlags; -import io.grpc.binder.BinderChannelCredentials; import io.grpc.binder.InboundParcelablePolicy; -import io.grpc.binder.SecurityPolicy; -import io.grpc.internal.ClientStream; -import io.grpc.internal.ConnectionClientTransport; -import io.grpc.internal.FailingClientStream; -import io.grpc.internal.GrpcAttributes; -import io.grpc.internal.GrpcUtil; -import io.grpc.internal.ManagedClientTransport; +import io.grpc.binder.internal.LeakSafeOneWayBinder.TransactionHandler; import io.grpc.internal.ObjectPool; -import io.grpc.internal.ServerStream; -import io.grpc.internal.ServerTransport; -import io.grpc.internal.ServerTransportListener; -import io.grpc.internal.StatsTraceContext; import java.util.ArrayList; import java.util.Iterator; import java.util.LinkedHashSet; @@ -69,16 +49,11 @@ import java.util.Map; import java.util.NoSuchElementException; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.Executor; +import java.util.concurrent.Future; import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; import java.util.logging.Level; import java.util.logging.Logger; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; /** * Base class for binder-based gRPC transport. @@ -100,10 +75,10 @@ * *

IMPORTANT: This implementation must comply with this published wire format. * https://github.com/grpc/proposal/blob/master/L73-java-binderchannel/wireformat.md + * + *

This class is thread-safe. */ -@ThreadSafe -public abstract class BinderTransport - implements LeakSafeOneWayBinder.TransactionHandler, IBinder.DeathRecipient { +public abstract class BinderTransport implements IBinder.DeathRecipient { private static final Logger logger = Logger.getLogger(BinderTransport.class.getName()); @@ -131,12 +106,10 @@ public abstract class BinderTransport *

Should this change, we should still endeavor to support earlier wire-format versions. If * that's not possible, {@link EARLIEST_SUPPORTED_WIRE_FORMAT_VERSION} should be updated below. */ - @Internal - public static final int WIRE_FORMAT_VERSION = 1; + @Internal public static final int WIRE_FORMAT_VERSION = 1; /** The version code of the earliest wire format we support. */ - @Internal - public static final int EARLIEST_SUPPORTED_WIRE_FORMAT_VERSION = 1; + @Internal public static final int EARLIEST_SUPPORTED_WIRE_FORMAT_VERSION = 1; /** The max number of "in-flight" bytes before we start buffering transactions. */ private static final int TRANSACTION_BYTES_WINDOW = 128 * 1024; @@ -149,12 +122,10 @@ public abstract class BinderTransport * the binder. and from the host s Followed by: int wire_protocol_version IBinder * client_transports_callback_binder */ - @Internal - public static final int SETUP_TRANSPORT = IBinder.FIRST_CALL_TRANSACTION; + @Internal public static final int SETUP_TRANSPORT = IBinder.FIRST_CALL_TRANSACTION; /** Send to shutdown the transport from either end. */ - @Internal - public static final int SHUTDOWN_TRANSPORT = IBinder.FIRST_CALL_TRANSACTION + 1; + @Internal public static final int SHUTDOWN_TRANSPORT = IBinder.FIRST_CALL_TRANSACTION + 1; /** Send to acknowledge receipt of rpc bytes, for flow control. */ static final int ACKNOWLEDGE_BYTES = IBinder.FIRST_CALL_TRANSACTION + 2; @@ -169,10 +140,10 @@ public abstract class BinderTransport private static final int RESERVED_TRANSACTIONS = 1000; /** The first call ID we can use. */ - private static final int FIRST_CALL_ID = IBinder.FIRST_CALL_TRANSACTION + RESERVED_TRANSACTIONS; + static final int FIRST_CALL_ID = IBinder.FIRST_CALL_TRANSACTION + RESERVED_TRANSACTIONS; /** The last call ID we can use. */ - private static final int LAST_CALL_ID = IBinder.LAST_CALL_TRANSACTION; + static final int LAST_CALL_ID = IBinder.LAST_CALL_TRANSACTION; /** The states of this transport. */ protected enum TransportState { @@ -188,14 +159,19 @@ protected enum TransportState { private final ObjectPool executorServicePool; private final ScheduledExecutorService scheduledExecutorService; private final InternalLogId logId; + + @GuardedBy("this") private final LeakSafeOneWayBinder incomingBinder; - protected final ConcurrentHashMap> ongoingCalls; + protected final ConcurrentHashMap> ongoingCalls; protected final OneWayBinderProxy.Decorator binderDecorator; @GuardedBy("this") private final LinkedHashSet callIdsToNotifyWhenReady = new LinkedHashSet<>(); + @GuardedBy("this") + private final List> ownedFutures = new ArrayList<>(); // To cancel upon terminate. + @GuardedBy("this") protected Attributes attributes; @@ -211,12 +187,14 @@ protected enum TransportState { private final FlowController flowController; /** The number of incoming bytes we've received. */ - private final AtomicLong numIncomingBytes; + // Only read/written on @BinderThread. + private long numIncomingBytes; /** The number of incoming bytes we've told our peer we've received. */ + // Only read/written on @BinderThread. private long acknowledgedIncomingBytes; - private BinderTransport( + protected BinderTransport( ObjectPool executorServicePool, Attributes attributes, OneWayBinderProxy.Decorator binderDecorator, @@ -226,10 +204,9 @@ private BinderTransport( this.attributes = attributes; this.logId = logId; scheduledExecutorService = executorServicePool.getObject(); - incomingBinder = new LeakSafeOneWayBinder(this); + incomingBinder = new LeakSafeOneWayBinder(this::handleTransaction); ongoingCalls = new ConcurrentHashMap<>(); flowController = new FlowController(TRANSACTION_BYTES_WINDOW); - numIncomingBytes = new AtomicLong(); } // Override in child class. @@ -239,7 +216,15 @@ public final ScheduledExecutorService getScheduledExecutorService() { // Override in child class. public final ListenableFuture getStats() { - return immediateFuture(null); + Attributes attributes = getAttributes(); + return immediateFuture( + new InternalChannelz.SocketStats( + /* data= */ null, // TODO: Keep track of these stats with TransportTracer or similar. + /* local= */ attributes.get(Grpc.TRANSPORT_ATTR_LOCAL_ADDR), + /* remote= */ attributes.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR), + // TODO: SocketOptions are meaningless for binder but we're still forced to provide one. + new InternalChannelz.SocketOptions.Builder().build(), + /* security= */ null)); } // Override in child class. @@ -260,14 +245,23 @@ final boolean isReady() { return !flowController.isTransmitWindowFull(); } + @GuardedBy("this") abstract void notifyShutdown(Status shutdownStatus); + @GuardedBy("this") abstract void notifyTerminated(); void releaseExecutors() { executorServicePool.returnObject(scheduledExecutorService); } + // Registers the specified future for eventual safe cancellation upon shutdown/terminate. + @GuardedBy("this") + protected final > T register(T future) { + ownedFutures.add(future); + return future; + } + @GuardedBy("this") boolean inState(TransportState transportState) { return this.transportState == transportState; @@ -284,6 +278,14 @@ final void setState(TransportState newState) { transportState = newState; } + /** + * Sets the binder to use for sending subsequent transactions to our peer. + * + *

Subclasses should call this as early as possible but not from a constructor. + * + *

Returns true for success, false if the process hosting 'binder' is already dead. Callers are + * responsible for handling this. + */ @GuardedBy("this") protected boolean setOutgoingBinder(OneWayBinderProxy binder) { binder = binderDecorator.decorate(binder); @@ -298,7 +300,10 @@ protected boolean setOutgoingBinder(OneWayBinderProxy binder) { @Override public synchronized void binderDied() { - shutdownInternal(Status.UNAVAILABLE.withDescription("binderDied"), true); + shutdownInternal( + Status.UNAVAILABLE.withDescription( + "Peer process crashed, exited or was killed (binderDied)"), + true); } @GuardedBy("this") @@ -313,16 +318,26 @@ final void shutdownInternal(Status shutdownStatus, boolean forceTerminate) { incomingBinder.detach(); setState(TransportState.SHUTDOWN_TERMINATED); sendShutdownTransaction(); - ArrayList> calls = new ArrayList<>(ongoingCalls.values()); + ArrayList> calls = new ArrayList<>(ongoingCalls.values()); ongoingCalls.clear(); + ArrayList> futuresToCancel = new ArrayList<>(ownedFutures); + ownedFutures.clear(); scheduledExecutorService.execute( () -> { - for (Inbound inbound : calls) { + for (Inbound inbound : calls) { synchronized (inbound) { inbound.closeAbnormal(shutdownStatus); } } - notifyTerminated(); + + for (Future future : futuresToCancel) { + // Not holding any locks here just in case some listener runs on a direct Executor. + future.cancel(false); // No effect if already isDone(). + } + + synchronized (this) { + notifyTerminated(); + } releaseExecutors(); }); } @@ -377,7 +392,7 @@ protected synchronized void sendPing(int id) throws StatusException { } } - protected void unregisterInbound(Inbound inbound) { + protected void unregisterInbound(Inbound inbound) { unregisterCall(inbound.callId); } @@ -420,13 +435,14 @@ final void sendOutOfBandClose(int callId, Status status) { } } - @Override - public final boolean handleTransaction(int code, Parcel parcel) { + @BinderThread + @VisibleForTesting + final boolean handleTransaction(int code, Parcel parcel) { try { return handleTransactionInternal(code, parcel); } catch (RuntimeException e) { - logger.log(Level.SEVERE, - "Terminating transport for uncaught Exception in transaction " + code, e); + logger.log( + Level.SEVERE, "Terminating transport for uncaught Exception in transaction " + code, e); synchronized (this) { // This unhandled exception may have put us in an inconsistent state. Force terminate the // whole transport so our peer knows something is wrong and so that clients can retry with @@ -437,6 +453,7 @@ public final boolean handleTransaction(int code, Parcel parcel) { } } + @BinderThread private boolean handleTransactionInternal(int code, Parcel parcel) { if (code < FIRST_CALL_ID) { synchronized (this) { @@ -464,13 +481,13 @@ private boolean handleTransactionInternal(int code, Parcel parcel) { } } else { int size = parcel.dataSize(); - Inbound inbound = ongoingCalls.get(code); + Inbound inbound = ongoingCalls.get(code); if (inbound == null) { synchronized (this) { if (!isShutdown()) { inbound = createInbound(code); if (inbound != null) { - Inbound existing = ongoingCalls.put(code, inbound); + Inbound existing = ongoingCalls.put(code, inbound); // Can't happen as only one invocation of handleTransaction() is running at a time. Verify.verify(existing == null, "impossible appearance of %s", existing); } @@ -480,19 +497,29 @@ private boolean handleTransactionInternal(int code, Parcel parcel) { if (inbound != null) { inbound.handleTransaction(parcel); } - long nib = numIncomingBytes.addAndGet(size); - if ((nib - acknowledgedIncomingBytes) > TRANSACTION_BYTES_WINDOW_FORCE_ACK) { + numIncomingBytes += size; + if ((numIncomingBytes - acknowledgedIncomingBytes) > TRANSACTION_BYTES_WINDOW_FORCE_ACK) { synchronized (this) { - sendAcknowledgeBytes(checkNotNull(outgoingBinder)); + sendAcknowledgeBytes(checkNotNull(outgoingBinder), numIncomingBytes); } + acknowledgedIncomingBytes = numIncomingBytes; } return true; } } + @BinderThread + @GuardedBy("this") + protected void restrictIncomingBinderToCallsFrom(int allowedCallingUid) { + TransactionHandler currentHandler = incomingBinder.getHandler(); + if (currentHandler != null) { + incomingBinder.setHandler(newCallerFilteringHandler(allowedCallingUid, currentHandler)); + } + } + @Nullable @GuardedBy("this") - protected Inbound createInbound(int callId) { + protected Inbound createInbound(int callId) { return null; } @@ -516,10 +543,8 @@ private final void handlePing(Parcel requestParcel) { protected void handlePingResponse(Parcel parcel) {} @GuardedBy("this") - private void sendAcknowledgeBytes(OneWayBinderProxy iBinder) { + private void sendAcknowledgeBytes(OneWayBinderProxy iBinder, long n) { // Send a transaction to acknowledge reception of incoming data. - long n = numIncomingBytes.get(); - acknowledgedIncomingBytes = n; try (ParcelHolder parcel = ParcelHolder.obtain()) { parcel.get().writeLong(n); iBinder.transact(ACKNOWLEDGE_BYTES, parcel); @@ -541,7 +566,7 @@ final void handleAcknowledgedBytes(long numBytes) { Iterator i = callIdsToNotifyWhenReady.iterator(); while (isReady() && i.hasNext()) { - Inbound inbound = ongoingCalls.get(i.next()); + Inbound inbound = ongoingCalls.get(i.next()); i.remove(); if (inbound != null) { // Calls can be removed out from under us. inbound.onTransportReady(); @@ -550,371 +575,6 @@ final void handleAcknowledgedBytes(long numBytes) { } } - /** Concrete client-side transport implementation. */ - @ThreadSafe - @Internal - public static final class BinderClientTransport extends BinderTransport - implements ConnectionClientTransport, Bindable.Observer { - - private final ObjectPool offloadExecutorPool; - private final Executor offloadExecutor; - private final SecurityPolicy securityPolicy; - private final Bindable serviceBinding; - /** Number of ongoing calls which keep this transport "in-use". */ - private final AtomicInteger numInUseStreams; - - private final PingTracker pingTracker; - - @Nullable private ManagedClientTransport.Listener clientTransportListener; - - @GuardedBy("this") - private int latestCallId = FIRST_CALL_ID; - - /** - * Constructs a new transport instance. - * - * @param binderDecorator used to decorate both the "endpoint" and "server" binders, for fault - * injection. - */ - public BinderClientTransport( - Context sourceContext, - BinderChannelCredentials channelCredentials, - AndroidComponentAddress targetAddress, - @Nullable UserHandle targetUserHandle, - BindServiceFlags bindServiceFlags, - Executor mainThreadExecutor, - ObjectPool executorServicePool, - ObjectPool offloadExecutorPool, - SecurityPolicy securityPolicy, - InboundParcelablePolicy inboundParcelablePolicy, - OneWayBinderProxy.Decorator binderDecorator, - Attributes eagAttrs) { - super( - executorServicePool, - buildClientAttributes(eagAttrs, sourceContext, targetAddress, inboundParcelablePolicy), - binderDecorator, - buildLogId(sourceContext, targetAddress)); - this.offloadExecutorPool = offloadExecutorPool; - this.securityPolicy = securityPolicy; - this.offloadExecutor = offloadExecutorPool.getObject(); - numInUseStreams = new AtomicInteger(); - pingTracker = new PingTracker(Ticker.systemTicker(), (id) -> sendPing(id)); - - serviceBinding = - new ServiceBinding( - mainThreadExecutor, - sourceContext, - channelCredentials, - targetAddress.asBindIntent(), - targetUserHandle, - bindServiceFlags.toInteger(), - this); - } - - @Override - void releaseExecutors() { - super.releaseExecutors(); - offloadExecutorPool.returnObject(offloadExecutor); - } - - @Override - public synchronized void onBound(IBinder binder) { - sendSetupTransaction(binderDecorator.decorate(OneWayBinderProxy.wrap(binder, offloadExecutor))); - } - - @Override - public synchronized void onUnbound(Status reason) { - shutdownInternal(reason, true); - } - - @CheckReturnValue - @Override - public synchronized Runnable start(ManagedClientTransport.Listener clientTransportListener) { - this.clientTransportListener = checkNotNull(clientTransportListener); - return () -> { - synchronized (BinderClientTransport.this) { - if (inState(TransportState.NOT_STARTED)) { - setState(TransportState.SETUP); - serviceBinding.bind(); - } - } - }; - } - - @Override - public synchronized ClientStream newStream( - final MethodDescriptor method, - final Metadata headers, - final CallOptions callOptions, - ClientStreamTracer[] tracers) { - if (!inState(TransportState.READY)) { - return newFailingClientStream( - isShutdown() - ? shutdownStatus - : Status.INTERNAL.withDescription("newStream() before transportReady()"), - attributes, - headers, - tracers); - } - - int callId = latestCallId++; - if (latestCallId == LAST_CALL_ID) { - latestCallId = FIRST_CALL_ID; - } - StatsTraceContext statsTraceContext = - StatsTraceContext.newClientContext(tracers, attributes, headers); - Inbound.ClientInbound inbound = - new Inbound.ClientInbound( - this, attributes, callId, GrpcUtil.shouldBeCountedForInUse(callOptions)); - if (ongoingCalls.putIfAbsent(callId, inbound) != null) { - Status failure = Status.INTERNAL.withDescription("Clashing call IDs"); - shutdownInternal(failure, true); - return newFailingClientStream(failure, attributes, headers, tracers); - } else { - if (inbound.countsForInUse() && numInUseStreams.getAndIncrement() == 0) { - clientTransportListener.transportInUse(true); - } - Outbound.ClientOutbound outbound = - new Outbound.ClientOutbound(this, callId, method, headers, statsTraceContext); - if (method.getType().clientSendsOneMessage()) { - return new SingleMessageClientStream(inbound, outbound, attributes); - } else { - return new MultiMessageClientStream(inbound, outbound, attributes); - } - } - } - - @Override - protected void unregisterInbound(Inbound inbound) { - if (inbound.countsForInUse() && numInUseStreams.decrementAndGet() == 0) { - clientTransportListener.transportInUse(false); - } - super.unregisterInbound(inbound); - } - - @Override - public void ping(final PingCallback callback, Executor executor) { - pingTracker.startPing(callback, executor); - } - - @Override - public synchronized void shutdown(Status reason) { - checkNotNull(reason, "reason"); - shutdownInternal(reason, false); - } - - @Override - public synchronized void shutdownNow(Status reason) { - checkNotNull(reason, "reason"); - shutdownInternal(reason, true); - } - - @Override - @GuardedBy("this") - public void notifyShutdown(Status status) { - clientTransportListener.transportShutdown(status); - } - - @Override - @GuardedBy("this") - public void notifyTerminated() { - if (numInUseStreams.getAndSet(0) > 0) { - clientTransportListener.transportInUse(false); - } - serviceBinding.unbind(); - clientTransportListener.transportTerminated(); - } - - @Override - @GuardedBy("this") - protected void handleSetupTransport(Parcel parcel) { - // Add the remote uid to our attributes. - attributes = setSecurityAttrs(attributes, Binder.getCallingUid()); - if (inState(TransportState.SETUP)) { - int version = parcel.readInt(); - IBinder binder = parcel.readStrongBinder(); - if (version != WIRE_FORMAT_VERSION) { - shutdownInternal( - Status.UNAVAILABLE.withDescription("Wire format version mismatch"), true); - } else if (binder == null) { - shutdownInternal( - Status.UNAVAILABLE.withDescription("Malformed SETUP_TRANSPORT data"), true); - } else { - offloadExecutor.execute(() -> checkSecurityPolicy(binder)); - } - } - } - - private void checkSecurityPolicy(IBinder binder) { - Status authorization; - Integer remoteUid; - synchronized (this) { - remoteUid = attributes.get(REMOTE_UID); - } - if (remoteUid == null) { - authorization = Status.UNAUTHENTICATED.withDescription("No remote UID available"); - } else { - authorization = securityPolicy.checkAuthorization(remoteUid); - } - synchronized (this) { - if (inState(TransportState.SETUP)) { - if (!authorization.isOk()) { - shutdownInternal(authorization, true); - } else if (!setOutgoingBinder(OneWayBinderProxy.wrap(binder, offloadExecutor))) { - shutdownInternal( - Status.UNAVAILABLE.withDescription("Failed to observe outgoing binder"), true); - } else { - // Check state again, since a failure inside setOutgoingBinder (or a callback it - // triggers), could have shut us down. - if (!isShutdown()) { - setState(TransportState.READY); - attributes = clientTransportListener.filterTransport(attributes); - clientTransportListener.transportReady(); - } - } - } - } - } - - @GuardedBy("this") - @Override - protected void handlePingResponse(Parcel parcel) { - pingTracker.onPingResponse(parcel.readInt()); - } - - private static ClientStream newFailingClientStream( - Status failure, Attributes attributes, Metadata headers, - ClientStreamTracer[] tracers) { - StatsTraceContext statsTraceContext = - StatsTraceContext.newClientContext(tracers, attributes, headers); - statsTraceContext.clientOutboundHeaders(); - return new FailingClientStream(failure, tracers); - } - - private static InternalLogId buildLogId( - Context sourceContext, AndroidComponentAddress targetAddress) { - return InternalLogId.allocate( - BinderClientTransport.class, - sourceContext.getClass().getSimpleName() + "->" + targetAddress); - } - - private static Attributes buildClientAttributes( - Attributes eagAttrs, - Context sourceContext, - AndroidComponentAddress targetAddress, - InboundParcelablePolicy inboundParcelablePolicy) { - return Attributes.newBuilder() - .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.NONE) // Trust noone for now. - .set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, eagAttrs) - .set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, AndroidComponentAddress.forContext(sourceContext)) - .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, targetAddress) - .set(INBOUND_PARCELABLE_POLICY, inboundParcelablePolicy) - .build(); - } - - private static Attributes setSecurityAttrs(Attributes attributes, int uid) { - return attributes.toBuilder() - .set(REMOTE_UID, uid) - .set( - GrpcAttributes.ATTR_SECURITY_LEVEL, - uid == Process.myUid() - ? SecurityLevel.PRIVACY_AND_INTEGRITY - : SecurityLevel.INTEGRITY) // TODO: Have the SecrityPolicy decide this. - .build(); - } - } - - /** Concrete server-side transport implementation. */ - @Internal - public static final class BinderServerTransport extends BinderTransport implements ServerTransport { - - private final List streamTracerFactories; - @Nullable private ServerTransportListener serverTransportListener; - - /** - * Constructs a new transport instance. - * - * @param binderDecorator used to decorate 'callbackBinder', for fault injection. - */ - public BinderServerTransport( - ObjectPool executorServicePool, - Attributes attributes, - List streamTracerFactories, - OneWayBinderProxy.Decorator binderDecorator, - IBinder callbackBinder) { - super(executorServicePool, attributes, binderDecorator, buildLogId(attributes)); - this.streamTracerFactories = streamTracerFactories; - // TODO(jdcormie): Plumb in the Server's executor() and use it here instead. - setOutgoingBinder(OneWayBinderProxy.wrap(callbackBinder, getScheduledExecutorService())); - } - - public synchronized void setServerTransportListener(ServerTransportListener serverTransportListener) { - this.serverTransportListener = serverTransportListener; - if (isShutdown()) { - setState(TransportState.SHUTDOWN_TERMINATED); - notifyTerminated(); - releaseExecutors(); - } else { - sendSetupTransaction(); - // Check we're not shutdown again, since a failure inside sendSetupTransaction (or a - // callback it triggers), could have shut us down. - if (!isShutdown()) { - setState(TransportState.READY); - attributes = serverTransportListener.transportReady(attributes); - } - } - } - - StatsTraceContext createStatsTraceContext(String methodName, Metadata headers) { - return StatsTraceContext.newServerContext(streamTracerFactories, methodName, headers); - } - - synchronized Status startStream(ServerStream stream, String methodName, Metadata headers) { - if (isShutdown()) { - return Status.UNAVAILABLE.withDescription("transport is shutdown"); - } else { - serverTransportListener.streamCreated(stream, methodName, headers); - return Status.OK; - } - } - - @Override - @GuardedBy("this") - public void notifyShutdown(Status status) { - // Nothing to do. - } - - @Override - @GuardedBy("this") - public void notifyTerminated() { - if (serverTransportListener != null) { - serverTransportListener.transportTerminated(); - } - } - - @Override - public synchronized void shutdown() { - shutdownInternal(Status.OK, false); - } - - @Override - public synchronized void shutdownNow(Status reason) { - shutdownInternal(reason, true); - } - - @Override - @Nullable - @GuardedBy("this") - protected Inbound createInbound(int callId) { - return new Inbound.ServerInbound(this, attributes, callId); - } - - private static InternalLogId buildLogId(Attributes attributes) { - return InternalLogId.allocate( - BinderServerTransport.class, "from " + attributes.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR)); - } - } - private static void checkTransition(TransportState current, TransportState next) { switch (next) { case SETUP: @@ -938,10 +598,15 @@ private static void checkTransition(TransportState current, TransportState next) } @VisibleForTesting - Map> getOngoingCalls() { + Map> getOngoingCalls() { return ongoingCalls; } + @VisibleForTesting + synchronized LeakSafeOneWayBinder getIncomingBinderForTesting() { + return this.incomingBinder; + } + private static Status statusFromRemoteException(RemoteException e) { if (e instanceof DeadObjectException || e instanceof TransactionTooLargeException) { // These are to be expected from time to time and can simply be retried. @@ -951,4 +616,3 @@ private static Status statusFromRemoteException(RemoteException e) { return Status.INTERNAL.withCause(e); } } - diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderTransportSecurity.java b/binder/src/main/java/io/grpc/binder/internal/BinderTransportSecurity.java index 56464d58a4b..6f95ef8a83c 100644 --- a/binder/src/main/java/io/grpc/binder/internal/BinderTransportSecurity.java +++ b/binder/src/main/java/io/grpc/binder/internal/BinderTransportSecurity.java @@ -20,7 +20,7 @@ import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.MoreExecutors; - +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.Attributes; import io.grpc.Internal; import io.grpc.Metadata; @@ -32,19 +32,17 @@ import io.grpc.ServerInterceptor; import io.grpc.Status; import io.grpc.internal.GrpcAttributes; - import java.util.concurrent.CancellationException; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; /** * Manages security for an Android Service hosted gRPC server. * - *

Attaches authorization state to a newly-created transport, and contains a - * ServerInterceptor which ensures calls are authorized before allowing them to proceed. + *

Attaches authorization state to a newly-created transport, and contains a ServerInterceptor + * which ensures calls are authorized before allowing them to proceed. */ public final class BinderTransportSecurity { @@ -57,11 +55,10 @@ private BinderTransportSecurity() {} * Install a security policy on an about-to-be created server. * * @param serverBuilder The ServerBuilder being used to create the server. - * @param executor The executor in which the authorization result will be handled. */ @Internal - public static void installAuthInterceptor(ServerBuilder serverBuilder, Executor executor) { - serverBuilder.intercept(new ServerAuthInterceptor(executor)); + public static void installAuthInterceptor(ServerBuilder serverBuilder) { + serverBuilder.intercept(new ServerAuthInterceptor()); } /** @@ -71,14 +68,18 @@ public static void installAuthInterceptor(ServerBuilder serverBuilder, Execut * @param builder The {@link Attributes.Builder} for the transport being created. * @param remoteUid The remote UID of the transport. * @param serverPolicyChecker The policy checker for this transport. + * @param executor used for calling into the application. Must outlive the transport. */ @Internal public static void attachAuthAttrs( - Attributes.Builder builder, int remoteUid, ServerPolicyChecker serverPolicyChecker) { + Attributes.Builder builder, + int remoteUid, + ServerPolicyChecker serverPolicyChecker, + Executor executor) { builder .set( TRANSPORT_AUTHORIZATION_STATE, - new TransportAuthorizationState(remoteUid, serverPolicyChecker)) + new TransportAuthorizationState(remoteUid, serverPolicyChecker, executor)) .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.PRIVACY_AND_INTEGRITY); } @@ -88,25 +89,20 @@ public static void attachAuthAttrs( */ private static final class ServerAuthInterceptor implements ServerInterceptor { - private final Executor executor; - - ServerAuthInterceptor(Executor executor) { - this.executor = executor; - } - @Override public ServerCall.Listener interceptCall( ServerCall call, Metadata headers, ServerCallHandler next) { + TransportAuthorizationState transportAuthState = + call.getAttributes().get(TRANSPORT_AUTHORIZATION_STATE); ListenableFuture authStatusFuture = - call.getAttributes() - .get(TRANSPORT_AUTHORIZATION_STATE) - .checkAuthorization(call.getMethodDescriptor()); + transportAuthState.checkAuthorization(call.getMethodDescriptor()); // Most SecurityPolicy will have synchronous implementations that provide an // immediately-resolved Future. In that case, short-circuit to avoid unnecessary allocations // and asynchronous code if the authorization result is already present. if (!authStatusFuture.isDone()) { - return newServerCallListenerForPendingAuthResult(authStatusFuture, call, headers, next); + return newServerCallListenerForPendingAuthResult( + authStatusFuture, transportAuthState.executor, call, headers, next); } Status authStatus; @@ -130,31 +126,33 @@ public ServerCall.Listener interceptCall( } private ServerCall.Listener newServerCallListenerForPendingAuthResult( - ListenableFuture authStatusFuture, - ServerCall call, - Metadata headers, - ServerCallHandler next) { + ListenableFuture authStatusFuture, + Executor executor, + ServerCall call, + Metadata headers, + ServerCallHandler next) { PendingAuthListener listener = new PendingAuthListener<>(); Futures.addCallback( - authStatusFuture, - new FutureCallback() { - @Override - public void onSuccess(Status authStatus) { - if (!authStatus.isOk()) { - call.close(authStatus, new Metadata()); - return; - } - - listener.startCall(call, headers, next); - } - - @Override - public void onFailure(Throwable t) { - call.close( - Status.INTERNAL.withCause(t).withDescription("Authorization future failed"), - new Metadata()); - } - }, executor); + authStatusFuture, + new FutureCallback() { + @Override + public void onSuccess(Status authStatus) { + if (!authStatus.isOk()) { + call.close(authStatus, new Metadata()); + return; + } + + listener.startCall(call, headers, next); + } + + @Override + public void onFailure(Throwable t) { + call.close( + Status.INTERNAL.withCause(t).withDescription("Authorization future failed"), + new Metadata()); + } + }, + executor); return listener; } } @@ -167,10 +165,16 @@ private static final class TransportAuthorizationState { private final int uid; private final ServerPolicyChecker serverPolicyChecker; private final ConcurrentHashMap> serviceAuthorization; + private final Executor executor; - TransportAuthorizationState(int uid, ServerPolicyChecker serverPolicyChecker) { + /** + * @param executor used for calling into the application. Must outlive the transport. + */ + TransportAuthorizationState( + int uid, ServerPolicyChecker serverPolicyChecker, Executor executor) { this.uid = uid; this.serverPolicyChecker = serverPolicyChecker; + this.executor = executor; serviceAuthorization = new ConcurrentHashMap<>(8); } @@ -201,15 +205,18 @@ ListenableFuture checkAuthorization(MethodDescriptor method) { serverPolicyChecker.checkAuthorizationForServiceAsync(uid, serviceName); if (useCache) { serviceAuthorization.putIfAbsent(serviceName, authorization); - Futures.addCallback(authorization, new FutureCallback() { - @Override - public void onSuccess(Status result) {} - - @Override - public void onFailure(Throwable t) { - serviceAuthorization.remove(serviceName, authorization); - } - }, MoreExecutors.directExecutor()); + Futures.addCallback( + authorization, + new FutureCallback() { + @Override + public void onSuccess(Status result) {} + + @Override + public void onFailure(Throwable t) { + serviceAuthorization.remove(serviceName, authorization); + } + }, + MoreExecutors.directExecutor()); } return authorization; } @@ -234,16 +241,8 @@ public interface ServerPolicyChecker { * @param uid The Android UID to authenticate. * @param serviceName The name of the gRPC service being called. * @return a future with the result of the authorization check. A failed future represents a - * failure to perform the authorization check, not that the access is denied. + * failure to perform the authorization check, not that the access is denied. */ ListenableFuture checkAuthorizationForServiceAsync(int uid, String serviceName); } - - /** - * A listener invoked when the {@link io.grpc.binder.internal.BinderServer} shuts down, allowing - * resources to be potentially cleaned up. - */ - public interface ShutdownListener { - void onServerShutdown(); - } } diff --git a/binder/src/main/java/io/grpc/binder/internal/BlockInputStream.java b/binder/src/main/java/io/grpc/binder/internal/BlockInputStream.java index 1ac1531da18..ae2a650831c 100644 --- a/binder/src/main/java/io/grpc/binder/internal/BlockInputStream.java +++ b/binder/src/main/java/io/grpc/binder/internal/BlockInputStream.java @@ -28,20 +28,17 @@ /** * A simple InputStream from a 2-dimensional byte array. * - * Used to provide message data from incoming blocks of data. It is assumed that - * all byte arrays passed in the constructor of this this class are owned by the new - * instance. + *

Used to provide message data from incoming blocks of data. It is assumed that all byte arrays + * passed in the constructor of this this class are owned by the new instance. * - * This also assumes byte arrays are created by the BlockPool class, and should - * be returned to it when this class is closed. + *

This also assumes byte arrays are created by the BlockPool class, and should be returned to it + * when this class is closed. */ @NotThreadSafe final class BlockInputStream extends InputStream implements KnownLength, Drainable { - @Nullable - private byte[][] blocks; - @Nullable - private byte[] currentBlock; + @Nullable private byte[][] blocks; + @Nullable private byte[] currentBlock; private int blockIndex; private int blockOffset; private int available; @@ -50,8 +47,7 @@ final class BlockInputStream extends InputStream implements KnownLength, Drainab /** * Creates a new stream with a single block. * - * @param block The single byte array block, ownership of which is - * passed to this instance. + * @param block The single byte array block, ownership of which is passed to this instance. */ BlockInputStream(byte[] block) { this.blocks = null; @@ -62,10 +58,10 @@ final class BlockInputStream extends InputStream implements KnownLength, Drainab /** * Creates a new stream from a sequence of blocks. * - * @param blocks A two dimensional byte array containing the data. Ownership - * of all blocks is passed to this instance. - * @param available The number of bytes available in total. This may be - * less than (but never more than) the total size of all byte arrays in blocks. + * @param blocks A two dimensional byte array containing the data. Ownership of all blocks is + * passed to this instance. + * @param available The number of bytes available in total. This may be less than (but never more + * than) the total size of all byte arrays in blocks. */ BlockInputStream(byte[][] blocks, int available) { this.blocks = blocks; diff --git a/binder/src/main/java/io/grpc/binder/internal/BlockPool.java b/binder/src/main/java/io/grpc/binder/internal/BlockPool.java index 9ca766791bf..985e465ab4b 100644 --- a/binder/src/main/java/io/grpc/binder/internal/BlockPool.java +++ b/binder/src/main/java/io/grpc/binder/internal/BlockPool.java @@ -27,9 +27,8 @@ * byte array of size N. This means we can't simply read into a large block and be done with it, we * need to allocate a new buffer specifically. Boo, Android. * - *

When writing data though, we can use a fixed-size buffer, so when large messages are - * split into standard-sized blocks, we only need a byte array allocation to read the last - * block. + *

When writing data though, we can use a fixed-size buffer, so when large messages are split + * into standard-sized blocks, we only need a byte array allocation to read the last block. * *

This class maintains a pool of blocks of standard size, but also provides smaller blocks when * requested. Currently, blocks of standard size are retained in the pool, when released, but we @@ -38,10 +37,10 @@ final class BlockPool { /** - * The size of each standard block. (Currently 16k) - * The block size must be at least as large as the maximum header list size. + * The size of each standard block. (Currently 16k) The block size must be at least as large as + * the maximum header list size. */ - private static final int BLOCK_SIZE = Math.max(16 * 1024, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE); + static final int BLOCK_SIZE = Math.max(16 * 1024, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE); /** * Maximum number of blocks to keep around. (Max 128k). This limit is a judgement call. 128k is diff --git a/binder/src/main/java/io/grpc/binder/internal/FlowController.java b/binder/src/main/java/io/grpc/binder/internal/FlowController.java index 1972ea00e6c..135f363a01e 100644 --- a/binder/src/main/java/io/grpc/binder/internal/FlowController.java +++ b/binder/src/main/java/io/grpc/binder/internal/FlowController.java @@ -15,7 +15,7 @@ */ package io.grpc.binder.internal; -import javax.annotation.concurrent.GuardedBy; +import com.google.errorprone.annotations.concurrent.GuardedBy; /** Keeps track of the number of bytes on the wire in a single direction. */ final class FlowController { diff --git a/binder/src/main/java/io/grpc/binder/internal/Inbound.java b/binder/src/main/java/io/grpc/binder/internal/Inbound.java index 5ab96085a41..83fc8273d6f 100644 --- a/binder/src/main/java/io/grpc/binder/internal/Inbound.java +++ b/binder/src/main/java/io/grpc/binder/internal/Inbound.java @@ -20,6 +20,7 @@ import static com.google.common.base.Preconditions.checkState; import android.os.Parcel; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.Metadata; import io.grpc.Status; @@ -34,7 +35,6 @@ import java.io.InputStream; import java.util.ArrayList; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; /** * Handles incoming binder transactions for a single stream, turning those transactions into calls @@ -42,9 +42,10 @@ * *

Out-of-order messages are reassembled into their correct order. */ -abstract class Inbound implements StreamListener.MessageProducer { +abstract class Inbound + implements StreamListener.MessageProducer { - protected final BinderTransport transport; + protected final T transport; protected final Attributes attributes; final int callId; @@ -145,7 +146,7 @@ enum State { @GuardedBy("this") private boolean producingMessages; - private Inbound(BinderTransport transport, Attributes attributes, int callId) { + private Inbound(T transport, Attributes attributes, int callId) { this.transport = transport; this.attributes = attributes; this.callId = callId; @@ -344,8 +345,7 @@ final synchronized void handleTransaction(Parcel parcel) { } int index = parcel.readInt(); boolean hasPrefix = TransactionUtils.hasFlag(flags, TransactionUtils.FLAG_PREFIX); - boolean hasMessageData = - TransactionUtils.hasFlag(flags, TransactionUtils.FLAG_MESSAGE_DATA); + boolean hasMessageData = TransactionUtils.hasFlag(flags, TransactionUtils.FLAG_MESSAGE_DATA); boolean hasSuffix = TransactionUtils.hasFlag(flags, TransactionUtils.FLAG_SUFFIX); if (hasPrefix) { handlePrefix(flags, parcel); @@ -400,6 +400,13 @@ private void handleMessageData(int flags, int index, Parcel parcel) throws Statu numBytes = parcel.dataPosition() - startPos; } else { numBytes = parcel.readInt(); + if (numBytes > parcel.dataAvail()) { + throw Status.INTERNAL + .withDescription( + "Message size is larger than remaining parcel size: " + + numBytes + " > " + parcel.dataAvail()) + .asException(); + } block = BlockPool.acquireBlock(numBytes); if (numBytes > 0) { parcel.readByteArray(block); @@ -552,7 +559,7 @@ public synchronized String toString() { // ====================================== // Client-side inbound transactions. - static final class ClientInbound extends Inbound { + static final class ClientInbound extends Inbound { private final boolean countsForInUse; @@ -565,7 +572,10 @@ static final class ClientInbound extends Inbound { private Metadata trailers; ClientInbound( - BinderTransport transport, Attributes attributes, int callId, boolean countsForInUse) { + BinderClientTransport transport, + Attributes attributes, + int callId, + boolean countsForInUse) { super(transport, attributes, callId); this.countsForInUse = countsForInUse; } @@ -579,7 +589,7 @@ boolean countsForInUse() { @GuardedBy("this") protected void handlePrefix(int flags, Parcel parcel) throws StatusException { Metadata headers = MetadataHelper.readMetadata(parcel, attributes); - statsTraceContext.clientInboundHeaders(); + statsTraceContext.clientInboundHeaders(headers); listener.headersRead(headers); } @@ -609,14 +619,9 @@ protected void deliverCloseAbnormal(Status status) { // ====================================== // Server-side inbound transactions. - static final class ServerInbound extends Inbound { - - private final BinderTransport.BinderServerTransport serverTransport; - - ServerInbound( - BinderTransport.BinderServerTransport transport, Attributes attributes, int callId) { + static final class ServerInbound extends Inbound { + ServerInbound(BinderServerTransport transport, Attributes attributes, int callId) { super(transport, attributes, callId); - this.serverTransport = transport; } @GuardedBy("this") @@ -625,17 +630,16 @@ protected void handlePrefix(int flags, Parcel parcel) throws StatusException { String methodName = parcel.readString(); Metadata headers = MetadataHelper.readMetadata(parcel, attributes); - StatsTraceContext statsTraceContext = - serverTransport.createStatsTraceContext(methodName, headers); + StatsTraceContext statsTraceContext = transport.createStatsTraceContext(methodName, headers); Outbound.ServerOutbound outbound = - new Outbound.ServerOutbound(serverTransport, callId, statsTraceContext); + new Outbound.ServerOutbound(transport, callId, statsTraceContext); ServerStream stream; if ((flags & TransactionUtils.FLAG_EXPECT_SINGLE_MESSAGE) != 0) { stream = new SingleMessageServerStream(this, outbound, attributes); } else { stream = new MultiMessageServerStream(this, outbound, attributes); } - Status status = serverTransport.startStream(stream, methodName, headers); + Status status = transport.startStream(stream, methodName, headers); if (status.isOk()) { checkNotNull(listener); // Is it ok to assume this will happen synchronously? if (transport.isReady()) { diff --git a/binder/src/main/java/io/grpc/binder/internal/IntentNameResolver.java b/binder/src/main/java/io/grpc/binder/internal/IntentNameResolver.java new file mode 100644 index 00000000000..ce3e2a96a42 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/IntentNameResolver.java @@ -0,0 +1,299 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.grpc.binder.internal; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; +import static io.grpc.binder.internal.SystemApis.createContextAsUser; + +import android.annotation.SuppressLint; +import android.content.BroadcastReceiver; +import android.content.ComponentName; +import android.content.Context; +import android.content.Intent; +import android.content.IntentFilter; +import android.content.pm.PackageManager; +import android.content.pm.ResolveInfo; +import android.os.Build; +import android.os.UserHandle; +import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.MoreExecutors; +import io.grpc.Attributes; +import io.grpc.EquivalentAddressGroup; +import io.grpc.NameResolver; +import io.grpc.Status; +import io.grpc.StatusException; +import io.grpc.StatusOr; +import io.grpc.SynchronizationContext; +import io.grpc.binder.AndroidComponentAddress; +import io.grpc.binder.ApiConstants; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.concurrent.Executor; +import javax.annotation.Nullable; + +/** + * A {@link NameResolver} that resolves Android-standard "intent:" target URIs to the list of {@link + * AndroidComponentAddress} that match it by manifest intent filter. + */ +final class IntentNameResolver extends NameResolver { + private final Intent targetIntent; // Never mutated. + @Nullable private final UserHandle targetUser; // null means same user that hosts this process. + private final Context targetUserContext; + private final Executor offloadExecutor; + private final Executor sequentialExecutor; + private final SynchronizationContext syncContext; + private final ServiceConfigParser serviceConfigParser; + + // Accessed only on `sequentialExecutor` + @Nullable private PackageChangeReceiver receiver; // != null when registered + + // Accessed only on 'syncContext'. + private boolean shutdown; + private boolean queryNeeded; + @Nullable private Listener2 listener; // != null after start(). + @Nullable private ListenableFuture queryResultFuture; // != null when querying. + + @EquivalentAddressGroup.Attr + private static final Attributes CONSTANT_EAG_ATTRS = + Attributes.newBuilder() + // Servers discovered in PackageManager are especially untrusted. After all, any app can + // declare any intent filter it wants! Require pre-authorization so that unauthorized apps + // don't even get a chance to run onCreate()/onBind(). + .set(ApiConstants.PRE_AUTH_SERVER_OVERRIDE, true) + .build(); + + IntentNameResolver(Intent targetIntent, Args args) { + this.targetIntent = targetIntent; + this.targetUser = args.getArg(ApiConstants.TARGET_ANDROID_USER); + Context context = + checkNotNull(args.getArg(ApiConstants.SOURCE_ANDROID_CONTEXT), "SOURCE_ANDROID_CONTEXT") + .getApplicationContext(); + this.targetUserContext = + targetUser != null ? createContextForTargetUserOrThrow(context, targetUser) : context; + // This Executor is nominally optional but all grpc-java Channels provide it since 1.25. + this.offloadExecutor = + checkNotNull(args.getOffloadExecutor(), "NameResolver.Args.getOffloadExecutor()"); + // Ensures start()'s work runs before resolve()'s' work, and both run before shutdown()'s. + this.sequentialExecutor = MoreExecutors.newSequentialExecutor(offloadExecutor); + this.syncContext = args.getSynchronizationContext(); + this.serviceConfigParser = args.getServiceConfigParser(); + } + + private static Context createContextForTargetUserOrThrow(Context context, UserHandle targetUser) { + try { + return createContextAsUser(context, targetUser, /* flags= */ 0); // @SystemApi since R. + } catch (ReflectiveOperationException e) { + throw new IllegalArgumentException( + "TARGET_ANDROID_USER NameResolver.Arg requires SDK_INT >= R and @SystemApi visibility"); + } + } + + @Override + public void start(Listener2 listener) { + checkState(this.listener == null, "Already started!"); + checkState(!shutdown, "Resolver is shutdown"); + this.listener = checkNotNull(listener); + sequentialExecutor.execute(this::registerReceiver); + resolve(); + } + + @Override + public void refresh() { + checkState(listener != null, "Not started!"); + resolve(); + } + + private void resolve() { + syncContext.throwIfNotInThisSynchronizationContext(); + + if (shutdown) { + return; + } + + // We can't block here in 'syncContext' so we offload PackageManager queries to an Executor. + // But offloading complicates things a bit because other calls can arrive while we wait for the + // results. We keep 'listener' up-to-date with the latest state in PackageManager by doing: + // 1. Only one query-and-report-to-listener operation at a time. + // 2. At least one query-and-report-to-listener AFTER every PackageManager state change. + if (queryResultFuture == null) { + queryResultFuture = Futures.submit(this::queryPackageManager, sequentialExecutor); + queryResultFuture.addListener(this::onQueryComplete, syncContext); + } else { + // There's already a query in-flight but (2) says we need at least one more. Our sequential + // Executor would be enough to ensure (1) but we also don't want a backlog of work to build up + // if things change rapidly. Just make a note to start a new query when this one finishes. + queryNeeded = true; + } + } + + private void onQueryComplete() { + syncContext.throwIfNotInThisSynchronizationContext(); + checkState(queryResultFuture != null); + checkState(queryResultFuture.isDone()); + + // Capture non-final `listener` here while we're on 'syncContext'. + Listener2 listener = checkNotNull(this.listener); + Futures.addCallback( + queryResultFuture, // Already isDone() so this execute()s immediately. + new FutureCallback() { + @Override + public void onSuccess(ResolutionResult result) { + listener.onResult2(result); + } + + @Override + public void onFailure(Throwable t) { + listener.onResult2( + ResolutionResult.newBuilder() + .setAddressesOrError(StatusOr.fromStatus(Status.fromThrowable(t))) + .build()); + } + }, + syncContext); // Already on 'syncContext' but addCallback() is faster than try/get/catch. + queryResultFuture = null; + + if (queryNeeded) { + // One or more resolve() requests arrived while we were working on the last one. Just one + // follow-on query can subsume all of them. + queryNeeded = false; + resolve(); + } + } + + @Override + public String getServiceAuthority() { + return "localhost"; + } + + @Override + public void shutdown() { + syncContext.throwIfNotInThisSynchronizationContext(); + if (!shutdown) { + shutdown = true; + sequentialExecutor.execute(this::maybeUnregisterReceiver); + } + } + + private ResolutionResult queryPackageManager() throws StatusException { + List queryResults = queryIntentServices(targetIntent); + + // Avoid a spurious UnsafeIntentLaunchViolation later. Since S, Android's StrictMode is very + // conservative, marking any Intent parsed from a string as suspicious and complaining when you + // bind to it. But all this is pointless with grpc-binder, which already goes even further by + // not trusting addresses at all! Instead, we rely on SecurityPolicy, which won't allow a + // connection to an unauthorized server UID no matter how you got there. + Intent prototypeBindIntent = sanitize(targetIntent); + + // Model each matching android.app.Service as an EAG (server) with a single address. + List addresses = new ArrayList<>(); + for (ResolveInfo resolveInfo : queryResults) { + prototypeBindIntent.setComponent( + new ComponentName(resolveInfo.serviceInfo.packageName, resolveInfo.serviceInfo.name)); + addresses.add( + new EquivalentAddressGroup( + AndroidComponentAddress.newBuilder() + .setBindIntent(prototypeBindIntent) // Makes a copy. + .setTargetUser(targetUser) + .build(), + CONSTANT_EAG_ATTRS)); + } + + return ResolutionResult.newBuilder() + .setAddressesOrError(StatusOr.fromValue(addresses)) + // Empty service config means we get the default 'pick_first' load balancing policy. + .setServiceConfig(serviceConfigParser.parseServiceConfig(ImmutableMap.of())) + .build(); + } + + private List queryIntentServices(Intent intent) throws StatusException { + int flags = 0; + if (Build.VERSION.SDK_INT >= 29) { + // Don't match direct-boot-unaware Services that can't presently be created. We'll query again + // after the user is unlocked. The MATCH_DIRECT_BOOT_AUTO behavior is actually the default but + // being explicit here avoids an android.os.strictmode.ImplicitDirectBootViolation. + flags |= PackageManager.MATCH_DIRECT_BOOT_AUTO; + } + + List intentServices = + targetUserContext.getPackageManager().queryIntentServices(intent, flags); + if (intentServices == null || intentServices.isEmpty()) { + // Must be the same as when ServiceBinding's call to bindService() returns false. + throw Status.UNIMPLEMENTED + .withDescription("Service not found for intent " + intent) + .asException(); + } + return intentServices; + } + + // Returns a new Intent with the same action, data and categories as 'input'. + private static Intent sanitize(Intent input) { + Intent output = new Intent(); + output.setAction(input.getAction()); + output.setData(input.getData()); + + Set categories = input.getCategories(); + if (categories != null) { + for (String category : categories) { + output.addCategory(category); + } + } + // Don't bother copying extras and flags since AndroidComponentAddress (rightly) ignores them. + // Don't bother copying package or ComponentName either, since we're about to set that. + return output; + } + + final class PackageChangeReceiver extends BroadcastReceiver { + @Override + public void onReceive(Context context, Intent intent) { + // Get off the main thread and into the correct SynchronizationContext. + syncContext.executeLater(IntentNameResolver.this::resolve); + offloadExecutor.execute(syncContext::drain); + } + } + + @SuppressLint("UnprotectedReceiver") // All of these are protected system broadcasts. + private void registerReceiver() { + checkState(receiver == null, "Already registered!"); + receiver = new PackageChangeReceiver(); + IntentFilter filter = new IntentFilter(); + filter.addDataScheme("package"); + filter.addAction(Intent.ACTION_PACKAGE_ADDED); + filter.addAction(Intent.ACTION_PACKAGE_CHANGED); + filter.addAction(Intent.ACTION_PACKAGE_REMOVED); + filter.addAction(Intent.ACTION_PACKAGE_REPLACED); + + targetUserContext.registerReceiver(receiver, filter); + + if (Build.VERSION.SDK_INT >= 24) { + // Clients running in direct boot mode must refresh() when the user is unlocked because + // that's when `directBootAware=false` services become visible in queryIntentServices() + // results. ACTION_BOOT_COMPLETED would work too but it's delivered with lower priority. + targetUserContext.registerReceiver(receiver, new IntentFilter(Intent.ACTION_USER_UNLOCKED)); + } + } + + private void maybeUnregisterReceiver() { + if (receiver != null) { // NameResolver API contract appears to allow shutdown without start(). + targetUserContext.unregisterReceiver(receiver); + receiver = null; + } + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/IntentNameResolverProvider.java b/binder/src/main/java/io/grpc/binder/internal/IntentNameResolverProvider.java new file mode 100644 index 00000000000..5a3c9fcc986 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/IntentNameResolverProvider.java @@ -0,0 +1,88 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.grpc.binder.internal; + +import static android.content.Intent.URI_INTENT_SCHEME; + +import android.content.Intent; +import com.google.common.collect.ImmutableSet; +import io.grpc.NameResolver; +import io.grpc.Uri; +import io.grpc.NameResolver.Args; +import io.grpc.NameResolverProvider; +import io.grpc.binder.AndroidComponentAddress; +import java.net.SocketAddress; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Objects; +import javax.annotation.Nullable; + +/** + * A {@link NameResolverProvider} that handles Android-standard "intent:" target URIs, resolving + * them to the list of {@link AndroidComponentAddress} that match by manifest intent filter. + */ +public final class IntentNameResolverProvider extends NameResolverProvider { + + static final String ANDROID_INTENT_SCHEME = "intent"; + + @Override + public String getDefaultScheme() { + return ANDROID_INTENT_SCHEME; + } + + @Nullable + @Override + public NameResolver newNameResolver(URI targetUri, final Args args) { + if (Objects.equals(targetUri.getScheme(), ANDROID_INTENT_SCHEME)) { + return new IntentNameResolver(parseUriArg(targetUri.toString()), args); + } else { + return null; + } + } + + @Nullable + @Override + public NameResolver newNameResolver(Uri targetUri, final Args args) { + if (Objects.equals(targetUri.getScheme(), ANDROID_INTENT_SCHEME)) { + return new IntentNameResolver(parseUriArg(targetUri.toString()), args); + } else { + return null; + } + } + + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int priority() { + return 3; // Lower than DNS so we don't accidentally become the default scheme for a registry. + } + + @Override + public ImmutableSet> getProducedSocketAddressTypes() { + return ImmutableSet.of(AndroidComponentAddress.class); + } + + private static Intent parseUriArg(String targetUri) { + try { + return Intent.parseUri(targetUri, URI_INTENT_SCHEME); + } catch (URISyntaxException e) { + throw new IllegalArgumentException(e); + } + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/LeakSafeOneWayBinder.java b/binder/src/main/java/io/grpc/binder/internal/LeakSafeOneWayBinder.java index 4b735f9f596..c36bc7d5bd3 100644 --- a/binder/src/main/java/io/grpc/binder/internal/LeakSafeOneWayBinder.java +++ b/binder/src/main/java/io/grpc/binder/internal/LeakSafeOneWayBinder.java @@ -19,6 +19,7 @@ import android.os.Binder; import android.os.IBinder; import android.os.Parcel; +import androidx.annotation.BinderThread; import io.grpc.Internal; import java.util.logging.Level; import java.util.logging.Logger; @@ -58,6 +59,7 @@ public interface TransactionHandler { * @return the value to return from {@link Binder#onTransact}. NB: "oneway" semantics mean this * result will not delivered to the caller of {@link IBinder#transact} */ + @BinderThread boolean handleTransaction(int code, Parcel data); } @@ -68,7 +70,26 @@ public LeakSafeOneWayBinder(TransactionHandler handler) { } public void detach() { - handler = null; + setHandler(null); + } + + /** Returns the current {@link TransactionHandler} or null if already detached. */ + public @Nullable TransactionHandler getHandler() { + return handler; + } + + /** + * Replaces the current {@link TransactionHandler} with `handler`. + * + *

{@link TransactionHandler} mutations race against incoming transactions except in the + * special case where the caller is already handling an incoming transaction on this same {@link + * LeakSafeOneWayBinder} instance. In that case, mutations are safe and the provided 'handler' is + * guaranteed to be used for the very next transaction. This follows from the one-at-a-time + * property of one-way Binder transactions as explained by {@link + * TransactionHandler#handleTransaction}. + */ + public void setHandler(@Nullable TransactionHandler handler) { + this.handler = handler; } @Override diff --git a/binder/src/main/java/io/grpc/binder/internal/MetadataHelper.java b/binder/src/main/java/io/grpc/binder/internal/MetadataHelper.java index bd473780788..7a8368d0b49 100644 --- a/binder/src/main/java/io/grpc/binder/internal/MetadataHelper.java +++ b/binder/src/main/java/io/grpc/binder/internal/MetadataHelper.java @@ -50,7 +50,7 @@ public final class MetadataHelper { /** The generic metadata marshaller we use for reading parcelables from the transport. */ private static final Metadata.BinaryStreamMarshaller TRANSPORT_INBOUND_MARSHALLER = - new ParcelableMetadataMarshaller<>(null, true); + new ParcelableMetadataMarshaller<>(null, true); /** Indicates the following value is a parcelable. */ private static final int PARCELABLE_SENTINEL = -1; @@ -96,14 +96,16 @@ public static void writeMetadata(Parcel parcel, @Nullable Metadata metadata) InputStream stream = (InputStream) value; int total = 0; while (total < buffer.length) { - int read = stream.read(buffer, total, buffer.length - total); + int read = stream.read(buffer, total, buffer.length - total); if (read == -1) { break; } total += read; } if (total == buffer.length) { - throw Status.RESOURCE_EXHAUSTED.withDescription("Metadata value too large").asException(); + throw Status.RESOURCE_EXHAUSTED + .withDescription("Metadata value too large") + .asException(); } parcel.writeInt(total); if (total > 0) { @@ -148,6 +150,9 @@ public static Metadata readMetadata(Parcel parcel, Attributes attributes) throws } int parcelableStartPos = parcel.dataPosition(); try { + // readParcelable(Classloader, Class<>) requires SDK 33 and at this layer we can't know + // value's type anyway. + @SuppressWarnings("deprecation") Parcelable value = parcel.readParcelable(MetadataHelper.class.getClassLoader()); if (value == null) { throw Status.INTERNAL.withDescription("Read null parcelable in metadata").asException(); @@ -179,10 +184,8 @@ public static Metadata readMetadata(Parcel parcel, Attributes attributes) throws } /** Read a byte array checking that we're not reading too much. */ - private static byte[] readBytesChecked( - Parcel parcel, - int numBytes, - int bytesRead) throws StatusException { + private static byte[] readBytesChecked(Parcel parcel, int numBytes, int bytesRead) + throws StatusException { if (bytesRead + numBytes > GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE) { throw Status.RESOURCE_EXHAUSTED.withDescription("Metadata too large").asException(); } @@ -200,7 +203,8 @@ public static final class ParcelableMetadataMarshaller

@Nullable private final Parcelable.Creator

creator; private final boolean immutableType; - public ParcelableMetadataMarshaller(@Nullable Parcelable.Creator

creator, boolean immutableType) { + public ParcelableMetadataMarshaller( + @Nullable Parcelable.Creator

creator, boolean immutableType) { this.creator = creator; this.immutableType = immutableType; } diff --git a/binder/src/main/java/io/grpc/binder/internal/MultiMessageServerStream.java b/binder/src/main/java/io/grpc/binder/internal/MultiMessageServerStream.java index cba18ba5b2f..f54769caefa 100644 --- a/binder/src/main/java/io/grpc/binder/internal/MultiMessageServerStream.java +++ b/binder/src/main/java/io/grpc/binder/internal/MultiMessageServerStream.java @@ -64,6 +64,11 @@ public void setListener(ServerStreamListener listener) { } } + @Override + public void setOnReadyThreshold(int numBytes) { + // No-op + } + @Override public boolean isReady() { return outbound.isReady(); diff --git a/binder/src/main/java/io/grpc/binder/internal/OneWayBinderProxy.java b/binder/src/main/java/io/grpc/binder/internal/OneWayBinderProxy.java index 82da825b975..fd883ca3b62 100644 --- a/binder/src/main/java/io/grpc/binder/internal/OneWayBinderProxy.java +++ b/binder/src/main/java/io/grpc/binder/internal/OneWayBinderProxy.java @@ -71,15 +71,13 @@ public static OneWayBinderProxy wrap(IBinder iBinder, Executor inProcessThreadHo */ public interface Decorator { /** - * Returns an instance of {@link OneWayBinderProxy} that decorates {@code input} with some - * new behavior. + * Returns an instance of {@link OneWayBinderProxy} that decorates {@code input} with some new + * behavior. */ OneWayBinderProxy decorate(OneWayBinderProxy input); } - /** - * A {@link Decorator} that does nothing. - */ + /** A {@link Decorator} that does nothing. */ public static final Decorator IDENTITY_DECORATOR = (x) -> x; /** diff --git a/binder/src/main/java/io/grpc/binder/internal/Outbound.java b/binder/src/main/java/io/grpc/binder/internal/Outbound.java index e2896be02a1..7db5bf0fbe4 100644 --- a/binder/src/main/java/io/grpc/binder/internal/Outbound.java +++ b/binder/src/main/java/io/grpc/binder/internal/Outbound.java @@ -19,9 +19,9 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; import static io.grpc.internal.GrpcUtil.TIMEOUT_KEY; -import static java.lang.Math.max; import android.os.Parcel; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Deadline; import io.grpc.Metadata; import io.grpc.MethodDescriptor; @@ -34,7 +34,6 @@ import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; /** * Sends the set of outbound transactions for a single BinderStream (rpc). @@ -397,8 +396,7 @@ protected int writeSuffix(Parcel parcel) throws IOException { @GuardedBy("this") void setDeadline(Deadline deadline) { headers.discardAll(TIMEOUT_KEY); - long effectiveTimeoutNanos = max(0, deadline.timeRemaining(TimeUnit.NANOSECONDS)); - headers.put(TIMEOUT_KEY, effectiveTimeoutNanos); + headers.put(TIMEOUT_KEY, deadline.timeRemaining(TimeUnit.NANOSECONDS)); } } diff --git a/binder/src/main/java/io/grpc/binder/internal/ParcelableInputStream.java b/binder/src/main/java/io/grpc/binder/internal/ParcelableInputStream.java index 09b8cfc43f9..5f1132de54f 100644 --- a/binder/src/main/java/io/grpc/binder/internal/ParcelableInputStream.java +++ b/binder/src/main/java/io/grpc/binder/internal/ParcelableInputStream.java @@ -79,6 +79,8 @@ final class ParcelableInputStream

extends InputStream { @SuppressWarnings("unchecked") static

ParcelableInputStream

readFromParcel( Parcel parcel, ClassLoader classLoader) { + // readParcelable(Classloader, Class

) requires SDK 33 and this class isn't typesafe anyway. + @SuppressWarnings("deprecation") P value = (P) parcel.readParcelable(classLoader); return new ParcelableInputStream<>(null, value, true); } diff --git a/binder/src/main/java/io/grpc/binder/internal/PendingAuthListener.java b/binder/src/main/java/io/grpc/binder/internal/PendingAuthListener.java index cdafc9c9191..ad993b8c93b 100644 --- a/binder/src/main/java/io/grpc/binder/internal/PendingAuthListener.java +++ b/binder/src/main/java/io/grpc/binder/internal/PendingAuthListener.java @@ -4,10 +4,8 @@ import io.grpc.ServerCall; import io.grpc.ServerCallHandler; import io.grpc.Status; - import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicReference; - import javax.annotation.Nullable; /** @@ -23,16 +21,14 @@ final class PendingAuthListener extends ServerCall.Listener { PendingAuthListener() {} - void startCall(ServerCall call, - Metadata headers, - ServerCallHandler next) { + void startCall( + ServerCall call, Metadata headers, ServerCallHandler next) { ServerCall.Listener delegate; try { delegate = next.startCall(call, headers); } catch (RuntimeException e) { call.close( - Status - .INTERNAL + Status.INTERNAL .withCause(e) .withDescription("Failed to start server call after authorization check"), new Metadata()); diff --git a/binder/src/main/java/io/grpc/binder/internal/PingTracker.java b/binder/src/main/java/io/grpc/binder/internal/PingTracker.java index 640d6006824..5a4300443ba 100644 --- a/binder/src/main/java/io/grpc/binder/internal/PingTracker.java +++ b/binder/src/main/java/io/grpc/binder/internal/PingTracker.java @@ -17,12 +17,12 @@ package io.grpc.binder.internal; import com.google.common.base.Ticker; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Status; import io.grpc.StatusException; import io.grpc.internal.ClientTransport.PingCallback; import java.util.concurrent.Executor; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; /** * Tracks an ongoing ping request for a client-side binder transport. We only handle a single active @@ -99,15 +99,14 @@ private final class Ping { private synchronized void fail(Status status) { if (!done) { done = true; - executor.execute(() -> callback.onFailure(status.asException())); + executor.execute(() -> callback.onFailure(status)); } } private synchronized void success() { if (!done) { done = true; - executor.execute( - () -> callback.onSuccess(ticker.read() - startTimeNanos)); + executor.execute(() -> callback.onSuccess(ticker.read() - startTimeNanos)); } } } diff --git a/binder/src/main/java/io/grpc/binder/internal/ServiceBinding.java b/binder/src/main/java/io/grpc/binder/internal/ServiceBinding.java index 32d0e7a4add..4b6bf7d06fb 100644 --- a/binder/src/main/java/io/grpc/binder/internal/ServiceBinding.java +++ b/binder/src/main/java/io/grpc/binder/internal/ServiceBinding.java @@ -17,24 +17,30 @@ package io.grpc.binder.internal; import static com.google.common.base.Preconditions.checkState; +import static io.grpc.binder.internal.SystemApis.createContextAsUser; import android.app.admin.DevicePolicyManager; import android.content.ComponentName; import android.content.Context; import android.content.Intent; import android.content.ServiceConnection; +import android.content.pm.PackageManager; +import android.content.pm.ResolveInfo; +import android.content.pm.ServiceInfo; +import android.os.Build; import android.os.IBinder; import android.os.UserHandle; import androidx.annotation.AnyThread; import androidx.annotation.MainThread; import com.google.common.annotations.VisibleForTesting; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Status; +import io.grpc.StatusException; import io.grpc.binder.BinderChannelCredentials; import java.util.concurrent.Executor; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; /** @@ -96,6 +102,9 @@ public String methodName() { private State reportedState; // Only used on the main thread. + @GuardedBy("this") + private ComponentName connectedServiceName; + @AnyThread ServiceBinding( Executor mainThreadExecutor, @@ -147,12 +156,7 @@ public synchronized void bind() { state = State.BINDING; Status bindResult = bindInternal( - sourceContext, - bindIntent, - this, - bindFlags, - channelCredentials, - targetUserHandle); + sourceContext, bindIntent, this, bindFlags, channelCredentials, targetUserHandle); if (!bindResult.isOk()) { handleBindServiceFailure(sourceContext, this); state = State.UNBOUND; @@ -184,26 +188,38 @@ private static Status bindInternal( } boolean bindResult = false; switch (bindMethodType) { - case BIND_SERVICE: + case BIND_SERVICE: bindResult = context.bindService(bindIntent, conn, flags); break; case BIND_SERVICE_AS_USER: - bindResult = context.bindServiceAsUser(bindIntent, conn, flags, targetUserHandle); + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.R) { + // We don't need SystemApis because bindServiceAsUser() is simply public in R+. + bindResult = context.bindServiceAsUser(bindIntent, conn, flags, targetUserHandle); + } else { + // TODO(#12279): Use SystemApis to make this work pre-R. + return Status.INTERNAL.withDescription("Cross user Channel requires Android R+"); + } break; case DEVICE_POLICY_BIND_SEVICE_ADMIN: DevicePolicyManager devicePolicyManager = (DevicePolicyManager) context.getSystemService(Context.DEVICE_POLICY_SERVICE); - bindResult = devicePolicyManager.bindDeviceAdminServiceAsUser( - channelCredentials.getDevicePolicyAdminComponentName(), - bindIntent, - conn, - flags, - targetUserHandle); + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.R) { + bindResult = + devicePolicyManager.bindDeviceAdminServiceAsUser( + channelCredentials.getDevicePolicyAdminComponentName(), + bindIntent, + conn, + flags, + targetUserHandle); + } else { + return Status.INTERNAL.withDescription( + "Device policy admin binding requires Android R+"); + } break; } if (!bindResult) { return Status.UNIMPLEMENTED.withDescription( - bindMethodType.methodName() + "(" + bindIntent + ") returned false"); + bindMethodType.methodName() + "(" + bindIntent + ") returned false"); } return Status.OK; } catch (SecurityException e) { @@ -211,8 +227,9 @@ private static Status bindInternal( .withCause(e) .withDescription("SecurityException from " + bindMethodType.methodName()); } catch (RuntimeException e) { - return Status.INTERNAL.withCause(e).withDescription( - "RuntimeException from " + bindMethodType.methodName()); + return Status.INTERNAL + .withCause(e) + .withDescription("RuntimeException from " + bindMethodType.methodName()); } } @@ -250,11 +267,67 @@ void unbindInternal(Status reason) { } } + @AnyThread + @Override + public ServiceInfo resolve() throws StatusException { + int flags = 0; + if (Build.VERSION.SDK_INT >= 29) { + // Filter out non-'directBootAware' s when 'targetUserHandle' is locked. Here's why: + // Callers want 'bindIntent' to #resolve() to the same thing a follow-up call to #bind() will. + // But bindService() *always* ignores services that can't presently be created for lack of + // 'directBootAware'-ness. This flag explicitly tells resolveService() to act the same way. + flags |= PackageManager.MATCH_DIRECT_BOOT_AUTO; + } + ResolveInfo resolveInfo = + getContextForTargetUser("Cross-user pre-auth") + .getPackageManager() + .resolveService(bindIntent, flags); + if (resolveInfo == null) { + throw Status.UNIMPLEMENTED // Same status code as when bindService() returns false. + .withDescription("resolveService(" + bindIntent + " / " + targetUserHandle + ") was null") + .asException(); + } + return resolveInfo.serviceInfo; + } + + private Context getContextForTargetUser(String purpose) throws StatusException { + checkState(sourceContext != null, "Already unbound!"); + try { + return targetUserHandle == null + ? sourceContext + : createContextAsUser(sourceContext, targetUserHandle, /* flags= */ 0); + } catch (ReflectiveOperationException e) { + throw Status.INTERNAL + .withDescription(purpose + " requires SDK_INT >= R and @SystemApi visibility") + .asException(); + } + } + @MainThread private void clearReferences() { sourceContext = null; } + @AnyThread + @Override + public ServiceInfo getConnectedServiceInfo() throws StatusException { + try { + return getContextForTargetUser("cross-user v2 handshake") + .getPackageManager() + .getServiceInfo(getConnectedServiceName(), /* flags= */ 0); + } catch (PackageManager.NameNotFoundException e) { + throw Status.UNIMPLEMENTED + .withCause(e) + .withDescription("connected remote service was uninstalled/disabled during handshake") + .asException(); + } + } + + private synchronized ComponentName getConnectedServiceName() { + checkState(connectedServiceName != null, "onBound() not yet called!"); + return connectedServiceName; + } + @Override @MainThread public void onServiceConnected(ComponentName className, IBinder binder) { @@ -262,6 +335,7 @@ public void onServiceConnected(ComponentName className, IBinder binder) { synchronized (this) { if (state == State.BINDING) { state = State.BOUND; + connectedServiceName = className; bound = true; } } @@ -275,19 +349,32 @@ public void onServiceConnected(ComponentName className, IBinder binder) { @Override @MainThread public void onServiceDisconnected(ComponentName name) { - unbindInternal(Status.UNAVAILABLE.withDescription("onServiceDisconnected: " + name)); + unbindInternal( + Status.UNAVAILABLE.withDescription( + "Server process crashed, exited or was killed (onServiceDisconnected): " + name)); } @Override @MainThread public void onNullBinding(ComponentName name) { - unbindInternal(Status.UNIMPLEMENTED.withDescription("onNullBinding: " + name)); + unbindInternal( + Status.UNIMPLEMENTED.withDescription( + "Remote Service returned null from onBind() for " + + bindIntent + + " (onNullBinding): " + + name)); } @Override @MainThread public void onBindingDied(ComponentName name) { - unbindInternal(Status.UNAVAILABLE.withDescription("onBindingDied: " + name)); + unbindInternal( + Status.UNAVAILABLE.withDescription( + "Remote Service component " + + name.getClassName() + + " was disabled, or its package " + + name.getPackageName() + + " was disabled, force-stopped, replaced or uninstalled (onBindingDied).")); } @VisibleForTesting diff --git a/binder/src/main/java/io/grpc/binder/internal/SimplePromise.java b/binder/src/main/java/io/grpc/binder/internal/SimplePromise.java new file mode 100644 index 00000000000..c7d227fbf64 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/SimplePromise.java @@ -0,0 +1,97 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import java.util.ArrayList; +import java.util.List; + +/** + * Placeholder for an object that will be provided later. + * + *

Similar to {@link com.google.common.util.concurrent.SettableFuture}, except it cannot fail or + * be cancelled. Most importantly, this class guarantees that {@link Listener}s run one-at-a-time + * and in the same order that they were scheduled. This conveniently matches the expectations of + * most listener interfaces in the io.grpc universe. + * + *

Not safe for concurrent use by multiple threads. Thread-compatible for callers that provide + * synchronization externally. + */ +public class SimplePromise { + private T value; + private List> pendingListeners; // Allocated lazily in the hopes it's never needed. + + /** + * Provides the promised object and runs any pending listeners. + * + * @throws IllegalStateException if this method has already been called + * @throws RuntimeException if some pending listener threw when we tried to run it + */ + public void set(T value) { + checkNotNull(value, "value"); + checkState(this.value == null, "Already set!"); + this.value = value; + if (pendingListeners != null) { + for (Listener listener : pendingListeners) { + listener.notify(value); + } + pendingListeners = null; + } + } + + /** + * Returns the promised object, under the assumption that it's already been set. + * + *

Compared to {@link #runWhenSet(Listener)}, this method may be a more efficient way to access + * the promised value in the case where you somehow know externally that {@link #set(T)} has + * "happened-before" this call. + * + * @throws IllegalStateException if {@link #set(T)} has not yet been called + */ + public T get() { + checkState(value != null, "Not yet set!"); + return value; + } + + /** + * Runs the given listener when this promise is fulfilled, or immediately if already fulfilled. + * + * @throws RuntimeException if already fulfilled and 'listener' threw when we tried to run it + */ + public void runWhenSet(Listener listener) { + if (value != null) { + listener.notify(value); + } else { + if (pendingListeners == null) { + pendingListeners = new ArrayList<>(); + } + pendingListeners.add(listener); + } + } + + /** + * An object that wants to get notified when a SimplePromise has been fulfilled. + */ + public interface Listener { + /** + * Indicates that the associated SimplePromise has been fulfilled with the given `value`. + */ + void notify(T value); + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/SingleMessageServerStream.java b/binder/src/main/java/io/grpc/binder/internal/SingleMessageServerStream.java index 92e9ff4477c..383bd7a2593 100644 --- a/binder/src/main/java/io/grpc/binder/internal/SingleMessageServerStream.java +++ b/binder/src/main/java/io/grpc/binder/internal/SingleMessageServerStream.java @@ -67,6 +67,11 @@ public void setListener(ServerStreamListener listener) { } } + @Override + public void setOnReadyThreshold(int numBytes) { + // No-op + } + @Override public boolean isReady() { return outbound.isReady(); diff --git a/binder/src/main/java/io/grpc/binder/internal/SystemApis.java b/binder/src/main/java/io/grpc/binder/internal/SystemApis.java new file mode 100644 index 00000000000..a4feec86a11 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/SystemApis.java @@ -0,0 +1,60 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.grpc.binder.internal; + +import android.content.Context; +import android.os.UserHandle; +import java.lang.reflect.Method; + +/** + * A collection of static methods that wrap hidden Android "System APIs." + * + *

grpc-java can't call Android methods marked @SystemApi directly, even though many of our users + * are "system apps" entitled to do so. Being a library built outside the Android source tree, these + * "non-SDK" elements simply don't exist from our compiler's perspective. Instead we resort to + * reflection but use the static wrappers found here to keep call sites readable and type safe. + * + *

Modern Android's JRE also limits the visibility of these methods at *runtime*. Only certain + * privileged apps installed on the system image app can call them, even using reflection, and this + * wrapper doesn't change that. Callers are responsible for ensuring that the host app actually has + * the ability to call @SystemApis and all methods throw {@link ReflectiveOperationException} as a + * reminder to do that. See + * https://developer.android.com/guide/app-compatibility/restrictions-non-sdk-interfaces for more. + */ +final class SystemApis { + private static volatile Method createContextAsUserMethod; + + // Not to be instantiated. + private SystemApis() {} + + /** + * Returns a new Context object whose methods act as if they were running in the given user. + * + * @throws ReflectiveOperationException if SDK_INT < R or host app lacks @SystemApi visibility + */ + public static Context createContextAsUser(Context context, UserHandle userHandle, int flags) + throws ReflectiveOperationException { + if (createContextAsUserMethod == null) { + synchronized (SystemApis.class) { + if (createContextAsUserMethod == null) { + createContextAsUserMethod = + Context.class.getMethod("createContextAsUser", UserHandle.class, int.class); + } + } + } + return (Context) createContextAsUserMethod.invoke(context, userHandle, flags); + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/TransactionUtils.java b/binder/src/main/java/io/grpc/binder/internal/TransactionUtils.java index 91f7fb8028f..2777a78d4ac 100644 --- a/binder/src/main/java/io/grpc/binder/internal/TransactionUtils.java +++ b/binder/src/main/java/io/grpc/binder/internal/TransactionUtils.java @@ -16,19 +16,26 @@ package io.grpc.binder.internal; +import android.os.Binder; import android.os.Parcel; import io.grpc.MethodDescriptor.MethodType; import io.grpc.Status; +import java.util.logging.Level; +import java.util.logging.Logger; +import io.grpc.binder.internal.LeakSafeOneWayBinder.TransactionHandler; import javax.annotation.Nullable; /** Constants and helpers for managing inbound / outbound transactions. */ final class TransactionUtils { /** Set when the transaction contains rpc prefix data. */ static final int FLAG_PREFIX = 0x1; + /** Set when the transaction contains some message data. */ static final int FLAG_MESSAGE_DATA = 0x2; + /** Set when the transaction contains rpc suffix data. */ static final int FLAG_SUFFIX = 0x4; + /** Set when the transaction is an out-of-band close event. */ static final int FLAG_OUT_OF_BAND_CLOSE = 0x8; @@ -96,4 +103,24 @@ static void fillInFlags(Parcel parcel, int flags) { parcel.writeInt(flags); parcel.setDataPosition(pos); } + + /** + * Decorates the given {@link TransactionHandler} with a wrapper that only forwards transactions + * from the given `allowedCallingUid`. + */ + static TransactionHandler newCallerFilteringHandler( + int allowedCallingUid, TransactionHandler wrapped) { + final Logger logger = Logger.getLogger(TransactionUtils.class.getName()); + return new TransactionHandler() { + @Override + public boolean handleTransaction(int code, Parcel data) { + int callingUid = Binder.getCallingUid(); + if (callingUid != allowedCallingUid) { + logger.log(Level.WARNING, "dropped txn from " + callingUid + " !=" + allowedCallingUid); + return false; + } + return wrapped.handleTransaction(code, data); + } + }; + } } diff --git a/binder/src/test/java/io/grpc/binder/AndroidComponentAddressTest.java b/binder/src/test/java/io/grpc/binder/AndroidComponentAddressTest.java index 6d7e53e5a19..d7d77d7feb1 100644 --- a/binder/src/test/java/io/grpc/binder/AndroidComponentAddressTest.java +++ b/binder/src/test/java/io/grpc/binder/AndroidComponentAddressTest.java @@ -18,11 +18,14 @@ import static android.content.Intent.URI_ANDROID_APP_SCHEME; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import android.content.ComponentName; import android.content.Context; import android.content.Intent; import android.net.Uri; +import android.os.Parcel; +import android.os.UserHandle; import androidx.test.core.app.ApplicationProvider; import com.google.common.testing.EqualsTester; import java.net.URISyntaxException; @@ -83,6 +86,32 @@ public void testAsBindIntent() { assertThat(addr.asBindIntent().filterEquals(bindIntent)).isTrue(); } + @Test + public void testPostCreateIntentMutation() { + Intent bindIntent = new Intent().setAction("foo-action").setComponent(hostComponent); + AndroidComponentAddress addr = AndroidComponentAddress.forBindIntent(bindIntent); + bindIntent.setAction("bar-action"); + assertThat(addr.asBindIntent().getAction()).isEqualTo("foo-action"); + } + + @Test + public void testPostBuildIntentMutation() { + Intent bindIntent = new Intent().setAction("foo-action").setComponent(hostComponent); + AndroidComponentAddress addr = + AndroidComponentAddress.newBuilder().setBindIntent(bindIntent).build(); + bindIntent.setAction("bar-action"); + assertThat(addr.asBindIntent().getAction()).isEqualTo("foo-action"); + } + + @Test + public void testBuilderMissingRequired() { + IllegalStateException ise = + assertThrows( + IllegalStateException.class, + () -> AndroidComponentAddress.newBuilder().setTargetUser(newUserHandle(123)).build()); + assertThat(ise.getMessage()).contains("bindIntent"); + } + @Test @Config(sdk = 30) public void testAsAndroidAppUriSdk30() throws URISyntaxException { @@ -117,13 +146,21 @@ public void testEquality() { AndroidComponentAddress.forContext(appContext), AndroidComponentAddress.forLocalComponent(appContext, appContext.getClass()), AndroidComponentAddress.forRemoteComponent( - appContext.getPackageName(), appContext.getClass().getName())) + appContext.getPackageName(), appContext.getClass().getName()), + AndroidComponentAddress.newBuilder() + .setBindIntentFromComponent(hostComponent) + .setTargetUser(null) + .build()) .addEqualityGroup( AndroidComponentAddress.forRemoteComponent("appy.mcappface", ".McActivity")) .addEqualityGroup(AndroidComponentAddress.forLocalComponent(appContext, getClass())) .addEqualityGroup( AndroidComponentAddress.forBindIntent( - new Intent().setAction("custom-action").setComponent(hostComponent))) + new Intent().setAction("custom-action").setComponent(hostComponent)), + AndroidComponentAddress.newBuilder() + .setBindIntent(new Intent().setAction("custom-action").setComponent(hostComponent)) + .setTargetUser(null) + .build()) .addEqualityGroup( AndroidComponentAddress.forBindIntent( new Intent() @@ -133,6 +170,31 @@ public void testEquality() { .testEquals(); } + @Test + public void testUnequalTargetUsers() { + new EqualsTester() + .addEqualityGroup( + AndroidComponentAddress.newBuilder() + .setBindIntentFromComponent(hostComponent) + .setTargetUser(newUserHandle(10)) + .build(), + AndroidComponentAddress.newBuilder() + .setBindIntentFromComponent(hostComponent) + .setTargetUser(newUserHandle(10)) + .build()) + .addEqualityGroup( + AndroidComponentAddress.newBuilder() + .setBindIntentFromComponent(hostComponent) + .setTargetUser(newUserHandle(11)) + .build()) + .addEqualityGroup( + AndroidComponentAddress.newBuilder() + .setBindIntentFromComponent(hostComponent) + .setTargetUser(null) + .build()) + .testEquals(); + } + @Test @Config(sdk = 30) public void testPackageFilterEquality30AndUp() { @@ -163,4 +225,15 @@ public void testPackageFilterEqualityPre30() { .setComponent(new ComponentName("pkg", "cls")))) .testEquals(); } + + private static UserHandle newUserHandle(int userId) { + Parcel parcel = Parcel.obtain(); + try { + parcel.writeInt(userId); + parcel.setDataPosition(0); + return new UserHandle(parcel); + } finally { + parcel.recycle(); + } + } } diff --git a/binder/src/test/java/io/grpc/binder/BinderChannelCredentialsTest.java b/binder/src/test/java/io/grpc/binder/BinderChannelCredentialsTest.java index d31065dfe52..f4e1a57e127 100644 --- a/binder/src/test/java/io/grpc/binder/BinderChannelCredentialsTest.java +++ b/binder/src/test/java/io/grpc/binder/BinderChannelCredentialsTest.java @@ -18,7 +18,7 @@ public void defaultBinderChannelCredentials() { BinderChannelCredentials channelCredentials = BinderChannelCredentials.forDefault(); assertThat(channelCredentials.getDevicePolicyAdminComponentName()).isNull(); } - + @Test public void binderChannelCredentialsForDevicePolicyAdmin() { String deviceAdminClassName = "DevicePolicyAdmin"; diff --git a/binder/src/test/java/io/grpc/binder/PeerUidTest.java b/binder/src/test/java/io/grpc/binder/PeerUidTest.java index c5326f64673..5c7c9f2ba5a 100644 --- a/binder/src/test/java/io/grpc/binder/PeerUidTest.java +++ b/binder/src/test/java/io/grpc/binder/PeerUidTest.java @@ -31,4 +31,4 @@ public void shouldImplementEqualsAndHashCode() { .addEqualityGroup(new PeerUid(456)) .testEquals(); } -} \ No newline at end of file +} diff --git a/binder/src/test/java/io/grpc/binder/PeerUidTestHelperTest.java b/binder/src/test/java/io/grpc/binder/PeerUidTestHelperTest.java new file mode 100644 index 00000000000..c9e055b9fe0 --- /dev/null +++ b/binder/src/test/java/io/grpc/binder/PeerUidTestHelperTest.java @@ -0,0 +1,122 @@ +package io.grpc.binder; + +import static com.google.common.truth.Truth.assertThat; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.common.collect.ImmutableList; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientInterceptors; +import io.grpc.ManagedChannel; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.ServerInterceptors; +import io.grpc.ServerServiceDefinition; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.stub.ClientCalls; +import io.grpc.stub.MetadataUtils; +import io.grpc.stub.ServerCalls; +import io.grpc.testing.GrpcCleanupRule; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class PeerUidTestHelperTest { + + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + + private static final int FAKE_UID = 12345; + + private final AtomicReference clientUidCapture = new AtomicReference<>(); + + @Test + public void keyPopulatedWithInterceptorAndHeader() throws Exception { + makeServiceCall(/* includeInterceptor= */ true, /* includeUidInHeader= */ true, FAKE_UID); + assertThat(clientUidCapture.get()).isEqualTo(new PeerUid(FAKE_UID)); + } + + @Test + public void keyNotPopulatedWithInterceptorAndNoHeader() throws Exception { + makeServiceCall(/* includeInterceptor= */ true, /* includeUidInHeader= */ false, /* uid= */ -1); + assertThat(clientUidCapture.get()).isNull(); + } + + @Test + public void keyNotPopulatedWithoutInterceptorAndWithHeader() throws Exception { + makeServiceCall( + /* includeInterceptor= */ false, /* includeUidInHeader= */ true, /* uid= */ FAKE_UID); + assertThat(clientUidCapture.get()).isNull(); + } + + private final MethodDescriptor method = + MethodDescriptor.newBuilder(StringMarshaller.INSTANCE, StringMarshaller.INSTANCE) + .setFullMethodName("test/method") + .setType(MethodDescriptor.MethodType.UNARY) + .build(); + + private void makeServiceCall(boolean includeInterceptor, boolean includeUidInHeader, int uid) + throws Exception { + ServerCallHandler callHandler = + ServerCalls.asyncUnaryCall( + (req, respObserver) -> { + clientUidCapture.set(PeerUids.REMOTE_PEER.get()); + respObserver.onNext(req); + respObserver.onCompleted(); + }); + ImmutableList interceptors; + if (includeInterceptor) { + interceptors = ImmutableList.of(PeerUidTestHelper.newTestPeerIdentifyingServerInterceptor()); + } else { + interceptors = ImmutableList.of(); + } + ServerServiceDefinition serviceDef = + ServerInterceptors.intercept( + ServerServiceDefinition.builder("test").addMethod(method, callHandler).build(), + interceptors); + + InProcessServerBuilder server = + InProcessServerBuilder.forName("test").directExecutor().addService(serviceDef); + + grpcCleanup.register(server.build().start()); + + Channel channel = InProcessChannelBuilder.forName("test").directExecutor().build(); + grpcCleanup.register((ManagedChannel) channel); + + if (includeUidInHeader) { + Metadata header = new Metadata(); + header.put(PeerUidTestHelper.UID_KEY, uid); + channel = + ClientInterceptors.intercept(channel, MetadataUtils.newAttachHeadersInterceptor(header)); + } + + ClientCalls.blockingUnaryCall(channel, method, CallOptions.DEFAULT, "hello"); + } + + private static class StringMarshaller implements MethodDescriptor.Marshaller { + + public static final StringMarshaller INSTANCE = new StringMarshaller(); + + @Override + public InputStream stream(String value) { + return new ByteArrayInputStream(value.getBytes(UTF_8)); + } + + @Override + public String parse(InputStream stream) { + try { + return new String(stream.readAllBytes(), UTF_8); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + } + } +} diff --git a/binder/src/test/java/io/grpc/binder/PeerUidsTest.java b/binder/src/test/java/io/grpc/binder/PeerUidsTest.java index 41956e60ca7..f87f7bd56f1 100644 --- a/binder/src/test/java/io/grpc/binder/PeerUidsTest.java +++ b/binder/src/test/java/io/grpc/binder/PeerUidsTest.java @@ -133,4 +133,4 @@ public String parse(InputStream stream) { } } } -} \ No newline at end of file +} diff --git a/binder/src/test/java/io/grpc/binder/RobolectricBinderSecurityTest.java b/binder/src/test/java/io/grpc/binder/RobolectricBinderSecurityTest.java index 44e863780cb..ffd1d89e69c 100644 --- a/binder/src/test/java/io/grpc/binder/RobolectricBinderSecurityTest.java +++ b/binder/src/test/java/io/grpc/binder/RobolectricBinderSecurityTest.java @@ -22,13 +22,13 @@ import static org.robolectric.Shadows.shadowOf; import android.app.Application; -import android.content.ComponentName; -import android.content.Intent; -import android.os.Handler; -import android.os.IBinder; -import android.os.Looper; -import androidx.lifecycle.LifecycleService; +import android.content.pm.ApplicationInfo; +import android.content.pm.PackageInfo; +import android.content.pm.ServiceInfo; import androidx.test.core.app.ApplicationProvider; +import androidx.test.core.content.pm.ApplicationInfoBuilder; +import androidx.test.core.content.pm.PackageInfoBuilder; +import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; @@ -48,84 +48,138 @@ import io.grpc.stub.ServerCalls; import java.io.IOException; import java.util.concurrent.ArrayBlockingQueue; -import java.util.concurrent.ScheduledExecutorService; -import javax.annotation.Nullable; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; -import org.robolectric.Robolectric; -import org.robolectric.RobolectricTestRunner; -import org.robolectric.android.controller.ServiceController; - -@RunWith(RobolectricTestRunner.class) +import org.robolectric.ParameterizedRobolectricTestRunner; +import org.robolectric.ParameterizedRobolectricTestRunner.Parameter; +import org.robolectric.ParameterizedRobolectricTestRunner.Parameters; +import org.robolectric.annotation.LooperMode; +import org.robolectric.annotation.LooperMode.Mode; + +@RunWith(ParameterizedRobolectricTestRunner.class) +@LooperMode(Mode.INSTRUMENTATION_TEST) public final class RobolectricBinderSecurityTest { private static final String SERVICE_NAME = "fake_service"; private static final String FULL_METHOD_NAME = "fake_service/fake_method"; private final Application context = ApplicationProvider.getApplicationContext(); - private ServiceController controller; - private SomeService service; + private final ArrayBlockingQueue> statusesToSet = + new ArrayBlockingQueue<>(128); private ManagedChannel channel; + private Server server; + + @Parameter public boolean preAuthServersParam; + + @Parameters(name = "preAuthServersParam={0}") + public static ImmutableList data() { + return ImmutableList.of(true, false); + } @Before public void setUp() { - controller = Robolectric.buildService(SomeService.class); - service = controller.create().get(); + ApplicationInfo serverAppInfo = + ApplicationInfoBuilder.newBuilder().setPackageName(context.getPackageName()).build(); + serverAppInfo.uid = android.os.Process.myUid(); + PackageInfo serverPkgInfo = + PackageInfoBuilder.newBuilder() + .setPackageName(serverAppInfo.packageName) + .setApplicationInfo(serverAppInfo) + .build(); + shadowOf(context.getPackageManager()).installPackage(serverPkgInfo); + + ServiceInfo serviceInfo = new ServiceInfo(); + serviceInfo.name = "SomeService"; + serviceInfo.packageName = serverAppInfo.packageName; + serviceInfo.applicationInfo = serverAppInfo; + shadowOf(context.getPackageManager()).addOrUpdateService(serviceInfo); + + AndroidComponentAddress listenAddress = + AndroidComponentAddress.forRemoteComponent(serviceInfo.packageName, serviceInfo.name); + + MethodDescriptor methodDesc = getMethodDescriptor(); + ServerCallHandler callHandler = + ServerCalls.asyncUnaryCall( + (req, respObserver) -> { + respObserver.onNext(req); + respObserver.onCompleted(); + }); + ServerMethodDefinition methodDef = + ServerMethodDefinition.create(methodDesc, callHandler); + ServerServiceDefinition def = + ServerServiceDefinition.builder(SERVICE_NAME).addMethod(methodDef).build(); + + IBinderReceiver binderReceiver = new IBinderReceiver(); + server = + BinderServerBuilder.forAddress(listenAddress, binderReceiver) + .addService(def) + .securityPolicy( + ServerSecurityPolicy.newBuilder() + .servicePolicy( + SERVICE_NAME, + new AsyncSecurityPolicy() { + @Override + public ListenableFuture checkAuthorizationAsync(int uid) { + SettableFuture status = SettableFuture.create(); + statusesToSet.add(status); + return status; + } + }) + .build()) + .build(); + try { + server.start(); + } catch (IOException e) { + throw new IllegalStateException(e); + } - AndroidComponentAddress listenAddress = AndroidComponentAddress.forContext(service); - ScheduledExecutorService executor = service.getExecutor(); + shadowOf(context) + .setComponentNameAndServiceForBindServiceForIntent( + listenAddress.asBindIntent(), + listenAddress.getComponent(), + checkNotNull(binderReceiver.get())); channel = BinderChannelBuilder.forAddress(listenAddress, context) - .executor(executor) - .scheduledExecutorService(executor) - .offloadExecutor(executor) + .preAuthorizeServers(preAuthServersParam) .build(); - idleLoopers(); } @After public void tearDown() { channel.shutdownNow(); - controller.destroy(); + server.shutdownNow(); } @Test public void testAsyncServerSecurityPolicy_failed_returnsFailureStatus() throws Exception { ListenableFuture status = makeCall(); - service.setSecurityPolicyStatusWhenReady(Status.ALREADY_EXISTS); - idleLoopers(); + statusesToSet.take().set(Status.ALREADY_EXISTS); - assertThat(Futures.getDone(status).getCode()).isEqualTo(Status.Code.ALREADY_EXISTS); + assertThat(status.get().getCode()).isEqualTo(Status.Code.ALREADY_EXISTS); } @Test public void testAsyncServerSecurityPolicy_failedFuture_failsWithCodeInternal() throws Exception { ListenableFuture status = makeCall(); - service.setSecurityPolicyFailed(new IllegalStateException("oops")); - idleLoopers(); + statusesToSet.take().setException(new IllegalStateException("oops")); - assertThat(Futures.getDone(status).getCode()).isEqualTo(Status.Code.INTERNAL); + assertThat(status.get().getCode()).isEqualTo(Status.Code.INTERNAL); } @Test public void testAsyncServerSecurityPolicy_allowed_returnsOkStatus() throws Exception { ListenableFuture status = makeCall(); - service.setSecurityPolicyStatusWhenReady(Status.OK); - idleLoopers(); + statusesToSet.take().set(Status.OK); - assertThat(Futures.getDone(status).getCode()).isEqualTo(Status.Code.OK); + assertThat(status.get().getCode()).isEqualTo(Status.Code.OK); } private ListenableFuture makeCall() { - ClientCall call = - channel.newCall( - getMethodDescriptor(), CallOptions.DEFAULT.withExecutor(service.getExecutor())); + ClientCall call = channel.newCall(getMethodDescriptor(), CallOptions.DEFAULT); ListenableFuture responseFuture = ClientCalls.futureUnaryCall(call, Empty.getDefaultInstance()); - idleLoopers(); - return Futures.catching( Futures.transform(responseFuture, unused -> Status.OK, directExecutor()), StatusRuntimeException.class, @@ -133,10 +187,6 @@ private ListenableFuture makeCall() { directExecutor()); } - private static void idleLoopers() { - shadowOf(Looper.getMainLooper()).idle(); - } - private static MethodDescriptor getMethodDescriptor() { MethodDescriptor.Marshaller marshaller = ProtoLiteUtils.marshaller(Empty.getDefaultInstance()); @@ -147,109 +197,4 @@ private static MethodDescriptor getMethodDescriptor() { .setSampledToLocalTracing(true) .build(); } - - private static class SomeService extends LifecycleService { - - private final IBinderReceiver binderReceiver = new IBinderReceiver(); - private final ArrayBlockingQueue> statusesToSet = - new ArrayBlockingQueue<>(128); - private Server server; - private final ScheduledExecutorService scheduledExecutorService = - new MainThreadScheduledExecutorService(); - - @Override - public void onCreate() { - super.onCreate(); - - MethodDescriptor methodDesc = getMethodDescriptor(); - ServerCallHandler callHandler = - ServerCalls.asyncUnaryCall( - (req, respObserver) -> { - respObserver.onNext(req); - respObserver.onCompleted(); - }); - ServerMethodDefinition methodDef = - ServerMethodDefinition.create(methodDesc, callHandler); - ServerServiceDefinition def = - ServerServiceDefinition.builder(SERVICE_NAME).addMethod(methodDef).build(); - - server = - BinderServerBuilder.forAddress(AndroidComponentAddress.forContext(this), binderReceiver) - .addService(def) - .securityPolicy( - ServerSecurityPolicy.newBuilder() - .servicePolicy( - SERVICE_NAME, - new AsyncSecurityPolicy() { - @Override - public ListenableFuture checkAuthorizationAsync(int uid) { - return Futures.submitAsync( - () -> { - SettableFuture status = SettableFuture.create(); - statusesToSet.add(status); - return status; - }, - getExecutor()); - } - }) - .build()) - .executor(getExecutor()) - .scheduledExecutorService(getExecutor()) - .build(); - try { - server.start(); - } catch (IOException e) { - throw new IllegalStateException(e); - } - - Application context = ApplicationProvider.getApplicationContext(); - ComponentName componentName = new ComponentName(context, SomeService.class); - shadowOf(context) - .setComponentNameAndServiceForBindService( - componentName, checkNotNull(binderReceiver.get())); - } - - /** - * Returns an {@link ScheduledExecutorService} under which all of the gRPC computations run. The - * execution of any pending tasks on this executor can be triggered via {@link #idleLoopers()}. - */ - ScheduledExecutorService getExecutor() { - return scheduledExecutorService; - } - - void setSecurityPolicyStatusWhenReady(Status status) { - getNextEnqueuedStatus().set(status); - } - - void setSecurityPolicyFailed(Exception e) { - getNextEnqueuedStatus().setException(e); - } - - private SettableFuture getNextEnqueuedStatus() { - @Nullable SettableFuture future = statusesToSet.poll(); - while (future == null) { - // Keep idling until the future is available. - idleLoopers(); - future = statusesToSet.poll(); - } - return checkNotNull(future); - } - - @Override - public IBinder onBind(Intent intent) { - super.onBind(intent); - return checkNotNull(binderReceiver.get()); - } - - @Override - public void onDestroy() { - super.onDestroy(); - server.shutdownNow(); - } - - /** A future representing a task submitted to a {@link Handler}. */ - - - } - } diff --git a/binder/src/test/java/io/grpc/binder/SecurityPoliciesTest.java b/binder/src/test/java/io/grpc/binder/SecurityPoliciesTest.java index e0f9f987939..71180ed43c5 100644 --- a/binder/src/test/java/io/grpc/binder/SecurityPoliciesTest.java +++ b/binder/src/test/java/io/grpc/binder/SecurityPoliciesTest.java @@ -107,8 +107,7 @@ public void testPermissionDenied() throws Exception { } @Test - public void testHasSignature_succeedsIfPackageNameAndSignaturesMatch() - throws Exception { + public void testHasSignature_succeedsIfPackageNameAndSignaturesMatch() throws Exception { PackageInfo info = newBuilder().setPackageName(OTHER_UID_PACKAGE_NAME).setSignatures(SIG2).build(); @@ -152,8 +151,7 @@ public void testHasSignature_failsIfSignatureDoesNotMatch() throws Exception { } @Test - public void testOneOfSignatures_succeedsIfPackageNameAndSignaturesMatch() - throws Exception { + public void testOneOfSignatures_succeedsIfPackageNameAndSignaturesMatch() throws Exception { PackageInfo info = newBuilder().setPackageName(OTHER_UID_PACKAGE_NAME).setSignatures(SIG2).build(); @@ -189,8 +187,7 @@ public void testOneOfSignature_failsIfAllSignaturesDoNotMatch() throws Exception } @Test - public void testOneOfSignature_succeedsIfPackageNameAndOneOfSignaturesMatch() - throws Exception { + public void testOneOfSignature_succeedsIfPackageNameAndOneOfSignaturesMatch() throws Exception { PackageInfo info = newBuilder().setPackageName(OTHER_UID_PACKAGE_NAME).setSignatures(SIG2).build(); @@ -206,11 +203,7 @@ public void testOneOfSignature_succeedsIfPackageNameAndOneOfSignaturesMatch() @Test public void testHasSignature_failsIfUidUnknown() throws Exception { - policy = - SecurityPolicies.hasSignature( - packageManager, - appContext.getPackageName(), - SIG1); + policy = SecurityPolicies.hasSignature(packageManager, appContext.getPackageName(), SIG1); assertThat(policy.checkAuthorization(OTHER_UID_UNKNOWN).getCode()) .isEqualTo(Status.UNAUTHENTICATED.getCode()); @@ -335,8 +328,7 @@ public void testIsDeviceOwner_succeedsForDeviceOwner() throws Exception { newBuilder().setPackageName(OTHER_UID_PACKAGE_NAME).setSignatures(SIG2).build(); installPackages(OTHER_UID, info); - shadowOf(devicePolicyManager) - .setDeviceOwner(new ComponentName(OTHER_UID_PACKAGE_NAME, "foo")); + shadowOf(devicePolicyManager).setDeviceOwner(new ComponentName(OTHER_UID_PACKAGE_NAME, "foo")); policy = SecurityPolicies.isDeviceOwner(appContext); @@ -352,26 +344,26 @@ public void testIsDeviceOwner_failsForNotDeviceOwner() throws Exception { policy = SecurityPolicies.isDeviceOwner(appContext); - assertThat(policy.checkAuthorization(OTHER_UID).getCode()).isEqualTo(Status.PERMISSION_DENIED.getCode()); + assertThat(policy.checkAuthorization(OTHER_UID).getCode()) + .isEqualTo(Status.PERMISSION_DENIED.getCode()); } @Test public void testIsDeviceOwner_failsWhenNoPackagesForUid() throws Exception { policy = SecurityPolicies.isDeviceOwner(appContext); - assertThat(policy.checkAuthorization(OTHER_UID).getCode()).isEqualTo(Status.UNAUTHENTICATED.getCode()); + assertThat(policy.checkAuthorization(OTHER_UID).getCode()) + .isEqualTo(Status.UNAUTHENTICATED.getCode()); } - @Test - @Config(sdk = 21) + @Config(sdk = Config.OLDEST_SDK) public void testIsProfileOwner_succeedsForProfileOwner() throws Exception { PackageInfo info = newBuilder().setPackageName(OTHER_UID_PACKAGE_NAME).setSignatures(SIG2).build(); installPackages(OTHER_UID, info); - shadowOf(devicePolicyManager) - .setProfileOwner(new ComponentName(OTHER_UID_PACKAGE_NAME, "foo")); + shadowOf(devicePolicyManager).setProfileOwner(new ComponentName(OTHER_UID_PACKAGE_NAME, "foo")); policy = SecurityPolicies.isProfileOwner(appContext); @@ -379,7 +371,7 @@ public void testIsProfileOwner_succeedsForProfileOwner() throws Exception { } @Test - @Config(sdk = 21) + @Config(sdk = Config.OLDEST_SDK) public void testIsProfileOwner_failsForNotProfileOwner() throws Exception { PackageInfo info = newBuilder().setPackageName(OTHER_UID_PACKAGE_NAME).setSignatures(SIG2).build(); @@ -388,15 +380,17 @@ public void testIsProfileOwner_failsForNotProfileOwner() throws Exception { policy = SecurityPolicies.isProfileOwner(appContext); - assertThat(policy.checkAuthorization(OTHER_UID).getCode()).isEqualTo(Status.PERMISSION_DENIED.getCode()); + assertThat(policy.checkAuthorization(OTHER_UID).getCode()) + .isEqualTo(Status.PERMISSION_DENIED.getCode()); } @Test - @Config(sdk = 21) + @Config(sdk = Config.OLDEST_SDK) public void testIsProfileOwner_failsWhenNoPackagesForUid() throws Exception { policy = SecurityPolicies.isProfileOwner(appContext); - assertThat(policy.checkAuthorization(OTHER_UID).getCode()).isEqualTo(Status.UNAUTHENTICATED.getCode()); + assertThat(policy.checkAuthorization(OTHER_UID).getCode()) + .isEqualTo(Status.UNAUTHENTICATED.getCode()); } @Test @@ -406,14 +400,12 @@ public void testIsProfileOwnerOnOrgOwned_succeedsForProfileOwnerOnOrgOwned() thr newBuilder().setPackageName(OTHER_UID_PACKAGE_NAME).setSignatures(SIG2).build(); installPackages(OTHER_UID, info); - shadowOf(devicePolicyManager) - .setProfileOwner(new ComponentName(OTHER_UID_PACKAGE_NAME, "foo")); + shadowOf(devicePolicyManager).setProfileOwner(new ComponentName(OTHER_UID_PACKAGE_NAME, "foo")); shadowOf(devicePolicyManager).setOrganizationOwnedDeviceWithManagedProfile(true); policy = SecurityPolicies.isProfileOwnerOnOrganizationOwnedDevice(appContext); assertThat(policy.checkAuthorization(OTHER_UID).getCode()).isEqualTo(Status.OK.getCode()); - } @Test @@ -423,17 +415,17 @@ public void testIsProfileOwnerOnOrgOwned_failsForProfileOwnerOnNonOrgOwned() thr newBuilder().setPackageName(OTHER_UID_PACKAGE_NAME).setSignatures(SIG2).build(); installPackages(OTHER_UID, info); - shadowOf(devicePolicyManager) - .setProfileOwner(new ComponentName(OTHER_UID_PACKAGE_NAME, "foo")); + shadowOf(devicePolicyManager).setProfileOwner(new ComponentName(OTHER_UID_PACKAGE_NAME, "foo")); shadowOf(devicePolicyManager).setOrganizationOwnedDeviceWithManagedProfile(false); policy = SecurityPolicies.isProfileOwnerOnOrganizationOwnedDevice(appContext); - assertThat(policy.checkAuthorization(OTHER_UID).getCode()).isEqualTo(Status.PERMISSION_DENIED.getCode()); + assertThat(policy.checkAuthorization(OTHER_UID).getCode()) + .isEqualTo(Status.PERMISSION_DENIED.getCode()); } @Test - @Config(sdk = 21) + @Config(sdk = Config.OLDEST_SDK) public void testIsProfileOwnerOnOrgOwned_failsForNotProfileOwner() throws Exception { PackageInfo info = newBuilder().setPackageName(OTHER_UID_PACKAGE_NAME).setSignatures(SIG2).build(); @@ -442,15 +434,17 @@ public void testIsProfileOwnerOnOrgOwned_failsForNotProfileOwner() throws Except policy = SecurityPolicies.isProfileOwnerOnOrganizationOwnedDevice(appContext); - assertThat(policy.checkAuthorization(OTHER_UID).getCode()).isEqualTo(Status.PERMISSION_DENIED.getCode()); + assertThat(policy.checkAuthorization(OTHER_UID).getCode()) + .isEqualTo(Status.PERMISSION_DENIED.getCode()); } @Test - @Config(sdk = 21) + @Config(sdk = Config.OLDEST_SDK) public void testIsProfileOwnerOnOrgOwned_failsWhenNoPackagesForUid() throws Exception { policy = SecurityPolicies.isProfileOwnerOnOrganizationOwnedDevice(appContext); - assertThat(policy.checkAuthorization(OTHER_UID).getCode()).isEqualTo(Status.UNAUTHENTICATED.getCode()); + assertThat(policy.checkAuthorization(OTHER_UID).getCode()) + .isEqualTo(Status.UNAUTHENTICATED.getCode()); } @Test @@ -463,7 +457,8 @@ public void testIsProfileOwnerOnOrgOwned_failsForSdkLevelTooLow() throws Excepti policy = SecurityPolicies.isProfileOwner(appContext); - assertThat(policy.checkAuthorization(OTHER_UID).getCode()).isEqualTo(Status.PERMISSION_DENIED.getCode()); + assertThat(policy.checkAuthorization(OTHER_UID).getCode()) + .isEqualTo(Status.PERMISSION_DENIED.getCode()); } private static PackageInfoBuilder newBuilder() { @@ -490,6 +485,7 @@ public PackageInfoBuilder setSignatures(Signature... signatures) { return this; } + @SuppressWarnings("deprecation") // 'signatures': We don't yet support signing cert rotation. public PackageInfo build() { checkState(this.packageName != null, "packageName is a mandatory field"); @@ -666,8 +662,8 @@ public void testOneOfSignatureSha256Hash_succeedsIfPackageNameAndOneOfSignatureH @Test public void - testOneOfSignatureSha256Hash_failsIfPackageNameDoNotMatchAndOneOfSignatureHashesMatch() - throws Exception { + testOneOfSignatureSha256Hash_failsIfPackageNameDoNotMatchAndOneOfSignatureHashesMatch() + throws Exception { PackageInfo info = newBuilder().setPackageName(OTHER_UID_PACKAGE_NAME).setSignatures(SIG2).build(); installPackages(OTHER_UID, info); diff --git a/binder/src/test/java/io/grpc/binder/ServerSecurityPolicyTest.java b/binder/src/test/java/io/grpc/binder/ServerSecurityPolicyTest.java index fb7e8e05566..eedc3f590cd 100644 --- a/binder/src/test/java/io/grpc/binder/ServerSecurityPolicyTest.java +++ b/binder/src/test/java/io/grpc/binder/ServerSecurityPolicyTest.java @@ -18,8 +18,8 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertThrows; - import static org.junit.Assert.fail; + import android.os.Process; import com.google.common.base.Function; import com.google.common.util.concurrent.Futures; @@ -27,18 +27,16 @@ import com.google.common.util.concurrent.ListeningExecutorService; import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.Uninterruptibles; - import io.grpc.Status; -import io.grpc.Status.Code; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.robolectric.RobolectricTestRunner; import java.util.concurrent.BrokenBarrierException; import java.util.concurrent.CancellationException; import java.util.concurrent.CountDownLatch; import java.util.concurrent.CyclicBarrier; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executors; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.robolectric.RobolectricTestRunner; @RunWith(RobolectricTestRunner.class) public final class ServerSecurityPolicyTest { @@ -62,6 +60,7 @@ public void testDefaultInternalOnly() throws Exception { } @Test + @Deprecated public void testDefaultInternalOnly_legacyApi() { policy = new ServerSecurityPolicy(); assertThat(policy.checkAuthorizationForService(MY_UID, SERVICE1).getCode()) @@ -80,6 +79,7 @@ public void testInternalOnly_AnotherUid() throws Exception { } @Test + @Deprecated public void testInternalOnly_AnotherUid_legacyApi() { policy = new ServerSecurityPolicy(); assertThat(policy.checkAuthorizationForService(OTHER_UID, SERVICE1).getCode()) @@ -98,6 +98,7 @@ public void testBuilderDefault() throws Exception { } @Test + @Deprecated public void testBuilderDefault_legacyApi() { policy = ServerSecurityPolicy.newBuilder().build(); assertThat(policy.checkAuthorizationForService(MY_UID, SERVICE1).getCode()) @@ -123,8 +124,8 @@ public void testPerService() throws Exception { .isEqualTo(Status.OK.getCode()); } - @Test + @Deprecated public void testPerService_legacyApi() { policy = ServerSecurityPolicy.newBuilder() @@ -145,13 +146,16 @@ public void testPerService_legacyApi() { public void testPerServiceAsync() throws Exception { policy = ServerSecurityPolicy.newBuilder() - .servicePolicy(SERVICE2, asyncPolicy(uid -> { - // Add some extra future transformation to confirm that a chain - // of futures gets properly handled. - ListenableFuture dependency = Futures.immediateVoidFuture(); - return Futures - .transform(dependency, unused -> Status.OK, MoreExecutors.directExecutor()); - })) + .servicePolicy( + SERVICE2, + asyncPolicy( + uid -> { + // Add some extra future transformation to confirm that a chain + // of futures gets properly handled. + ListenableFuture dependency = Futures.immediateVoidFuture(); + return Futures.transform( + dependency, unused -> Status.OK, MoreExecutors.directExecutor()); + })) .build(); assertThat(checkAuthorizationForServiceAsync(policy, MY_UID, SERVICE1)) @@ -168,11 +172,12 @@ public void testPerServiceAsync() throws Exception { public void testPerService_failedSecurityPolicyFuture_returnsAFailedFuture() { policy = ServerSecurityPolicy.newBuilder() - .servicePolicy(SERVICE1, asyncPolicy(uid -> - Futures - .immediateFailedFuture( - new IllegalStateException("something went wrong")) - )) + .servicePolicy( + SERVICE1, + asyncPolicy( + uid -> + Futures.immediateFailedFuture( + new IllegalStateException("something went wrong")))) .build(); ListenableFuture statusFuture = @@ -199,24 +204,31 @@ public void testPerServiceAsync_interrupted_cancelledFuture() { ListeningExecutorService listeningExecutorService = MoreExecutors.listeningDecorator(Executors.newSingleThreadExecutor()); CountDownLatch unsatisfiedLatch = new CountDownLatch(1); - ListenableFuture toBeInterruptedFuture = listeningExecutorService.submit(() -> { - unsatisfiedLatch.await(); // waits forever - return null; - }); + ListenableFuture toBeInterruptedFuture = + listeningExecutorService.submit( + () -> { + unsatisfiedLatch.await(); // waits forever + return null; + }); CyclicBarrier barrier = new CyclicBarrier(2); Thread testThread = Thread.currentThread(); - new Thread(() -> { - awaitOrFail(barrier); - testThread.interrupt(); - }).start(); + new Thread( + () -> { + awaitOrFail(barrier); + testThread.interrupt(); + }) + .start(); policy = ServerSecurityPolicy.newBuilder() - .servicePolicy(SERVICE1, asyncPolicy(unused -> { - awaitOrFail(barrier); - return toBeInterruptedFuture; - })) + .servicePolicy( + SERVICE1, + asyncPolicy( + unused -> { + awaitOrFail(barrier); + return toBeInterruptedFuture; + })) .build(); ListenableFuture statusFuture = policy.checkAuthorizationForServiceAsync(MY_UID, SERVICE1); @@ -243,14 +255,18 @@ SERVICE2, policy((uid) -> uid == OTHER_UID ? Status.OK : Status.PERMISSION_DENIE // Uses the specified policy for service2. assertThat(checkAuthorizationForServiceAsync(policy, MY_UID, SERVICE2)) .isEqualTo(Status.PERMISSION_DENIED.getCode()); - assertThat(checkAuthorizationForServiceAsync(policy, OTHER_UID, SERVICE2)).isEqualTo(Status.OK.getCode()); + assertThat(checkAuthorizationForServiceAsync(policy, OTHER_UID, SERVICE2)) + .isEqualTo(Status.OK.getCode()); // Falls back to the default. - assertThat(checkAuthorizationForServiceAsync(policy, MY_UID, SERVICE3)).isEqualTo(Status.OK.getCode()); + assertThat(checkAuthorizationForServiceAsync(policy, MY_UID, SERVICE3)) + .isEqualTo(Status.OK.getCode()); assertThat(checkAuthorizationForServiceAsync(policy, OTHER_UID, SERVICE3)) .isEqualTo(Status.PERMISSION_DENIED.getCode()); } + @Test + @Deprecated public void testPerServiceNoDefault_legacyApi() { policy = ServerSecurityPolicy.newBuilder() @@ -281,44 +297,40 @@ SERVICE2, policy((uid) -> uid == OTHER_UID ? Status.OK : Status.PERMISSION_DENIE @Test public void testPerServiceNoDefaultAsync() throws Exception { policy = - ServerSecurityPolicy.newBuilder() - .servicePolicy( - SERVICE1, - asyncPolicy((uid) -> Futures.immediateFuture(Status.INTERNAL))) - .servicePolicy( - SERVICE2, asyncPolicy((uid) -> { - // Add some extra future transformation to confirm that a chain - // of futures gets properly handled. - ListenableFuture anotherUidFuture = - Futures.immediateFuture(uid == OTHER_UID); - return Futures - .transform( - anotherUidFuture, - anotherUid -> - anotherUid - ? Status.OK - : Status.PERMISSION_DENIED, - MoreExecutors.directExecutor()); - })) - .build(); + ServerSecurityPolicy.newBuilder() + .servicePolicy(SERVICE1, asyncPolicy((uid) -> Futures.immediateFuture(Status.INTERNAL))) + .servicePolicy( + SERVICE2, + asyncPolicy( + (uid) -> { + // Add some extra future transformation to confirm that a chain + // of futures gets properly handled. + ListenableFuture anotherUidFuture = + Futures.immediateFuture(uid == OTHER_UID); + return Futures.transform( + anotherUidFuture, + anotherUid -> anotherUid ? Status.OK : Status.PERMISSION_DENIED, + MoreExecutors.directExecutor()); + })) + .build(); // Uses the specified policy for service1. assertThat(checkAuthorizationForServiceAsync(policy, MY_UID, SERVICE1)) - .isEqualTo(Status.INTERNAL.getCode()); + .isEqualTo(Status.INTERNAL.getCode()); assertThat(checkAuthorizationForServiceAsync(policy, OTHER_UID, SERVICE1)) - .isEqualTo(Status.INTERNAL.getCode()); + .isEqualTo(Status.INTERNAL.getCode()); // Uses the specified policy for service2. assertThat(checkAuthorizationForServiceAsync(policy, MY_UID, SERVICE2)) - .isEqualTo(Status.PERMISSION_DENIED.getCode()); + .isEqualTo(Status.PERMISSION_DENIED.getCode()); assertThat(checkAuthorizationForServiceAsync(policy, OTHER_UID, SERVICE2)) - .isEqualTo(Status.OK.getCode()); + .isEqualTo(Status.OK.getCode()); // Falls back to the default. assertThat(checkAuthorizationForServiceAsync(policy, MY_UID, SERVICE3)) - .isEqualTo(Status.OK.getCode()); + .isEqualTo(Status.OK.getCode()); assertThat(checkAuthorizationForServiceAsync(policy, OTHER_UID, SERVICE3)) - .isEqualTo(Status.PERMISSION_DENIED.getCode()); + .isEqualTo(Status.PERMISSION_DENIED.getCode()); } /** @@ -326,9 +338,7 @@ SERVICE2, asyncPolicy((uid) -> { * dealing with concurrency details. Returns a {link @Status.Code} for convenience. */ private static Status.Code checkAuthorizationForServiceAsync( - ServerSecurityPolicy policy, - int callerUid, - String service) throws ExecutionException { + ServerSecurityPolicy policy, int callerUid, String service) throws ExecutionException { ListenableFuture statusFuture = policy.checkAuthorizationForServiceAsync(callerUid, service); return Uninterruptibles.getUninterruptibly(statusFuture).getCode(); @@ -354,12 +364,12 @@ public ListenableFuture checkAuthorizationAsync(int uid) { private static void awaitOrFail(CyclicBarrier barrier) { try { - barrier.await(); + barrier.await(); } catch (BrokenBarrierException e) { - fail(e.getMessage()); + fail(e.getMessage()); } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - fail(e.getMessage()); + Thread.currentThread().interrupt(); + fail(e.getMessage()); } } } diff --git a/binder/src/test/java/io/grpc/binder/internal/ActiveTransportTrackerTest.java b/binder/src/test/java/io/grpc/binder/internal/ActiveTransportTrackerTest.java new file mode 100644 index 00000000000..099756075f1 --- /dev/null +++ b/binder/src/test/java/io/grpc/binder/internal/ActiveTransportTrackerTest.java @@ -0,0 +1,113 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +import io.grpc.internal.ServerListener; +import io.grpc.internal.ServerTransport; +import io.grpc.internal.ServerTransportListener; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; +import org.robolectric.RobolectricTestRunner; + +@RunWith(RobolectricTestRunner.class) +public final class ActiveTransportTrackerTest { + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + + private ActiveTransportTracker tracker; + + @Mock Runnable mockShutdownListener; + @Mock ServerListener mockServerListener; + @Mock ServerTransportListener mockServerTransportListener; + @Mock ServerTransport mockServerTransport; + + @Before + public void setUp() { + when(mockServerListener.transportCreated(any())).thenReturn(mockServerTransportListener); + tracker = new ActiveTransportTracker(mockServerListener, mockShutdownListener); + } + + @Test + public void testServerShutdown_onlyNotifiesAfterAllTransportAreTerminated() { + ServerTransportListener wrapperListener1 = registerNewTransport(); + ServerTransportListener wrapperListener2 = registerNewTransport(); + + tracker.serverShutdown(); + // 2 active transports, notification scheduled + verifyNoInteractions(mockShutdownListener); + + wrapperListener1.transportTerminated(); + // 1 active transport remaining, notification still pending + verifyNoInteractions(mockShutdownListener); + + wrapperListener2.transportTerminated(); + // No more active transports, shutdown notified + verify(mockShutdownListener).run(); + } + + @Test + public void testServerShutdown_noActiveTransport_notifiesTerminationImmediately() { + verifyNoInteractions(mockShutdownListener); + + tracker.serverShutdown(); + + verify(mockShutdownListener).run(); + } + + @Test + public void testLastTransportTerminated_serverNotShutdownYet_doesNotNotify() { + ServerTransportListener wrapperListener = registerNewTransport(); + verifyNoInteractions(mockShutdownListener); + + wrapperListener.transportTerminated(); + + verifyNoInteractions(mockShutdownListener); + } + + @Test + public void testTransportCreation_afterServerShutdown_throws() { + tracker.serverShutdown(); + + assertThrows(IllegalStateException.class, this::registerNewTransport); + } + + @Test + public void testServerListenerCallbacks_invokesDelegates() { + ServerTransportListener listener = tracker.transportCreated(mockServerTransport); + verify(mockServerListener).transportCreated(mockServerTransport); + + listener.transportTerminated(); + verify(mockServerTransportListener).transportTerminated(); + + tracker.serverShutdown(); + verify(mockServerListener).serverShutdown(); + } + + private ServerTransportListener registerNewTransport() { + return tracker.transportCreated(mockServerTransport); + } +} diff --git a/binder/src/test/java/io/grpc/binder/internal/BinderServerTransportTest.java b/binder/src/test/java/io/grpc/binder/internal/BinderServerTransportTest.java index f1e5c5a9553..d261ce43c8c 100644 --- a/binder/src/test/java/io/grpc/binder/internal/BinderServerTransportTest.java +++ b/binder/src/test/java/io/grpc/binder/internal/BinderServerTransportTest.java @@ -16,24 +16,25 @@ package io.grpc.binder.internal; -import static com.google.common.base.Preconditions.checkState; import static com.google.common.truth.Truth.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.isNull; -import static org.mockito.Mockito.when; -import static org.robolectric.annotation.LooperMode.Mode.PAUSED; +import static org.mockito.Mockito.doThrow; +import static org.robolectric.Shadows.shadowOf; import android.os.IBinder; +import android.os.Looper; import android.os.Parcel; +import android.os.RemoteException; import com.google.common.collect.ImmutableList; -import com.google.common.util.concurrent.testing.TestingExecutors; import io.grpc.Attributes; -import io.grpc.Metadata; +import io.grpc.ServerStreamTracer; import io.grpc.Status; import io.grpc.internal.FixedObjectPool; -import io.grpc.internal.ServerStream; -import io.grpc.internal.ServerTransportListener; +import io.grpc.internal.MockServerTransportListener; +import io.grpc.internal.ObjectPool; +import java.util.List; import java.util.concurrent.ScheduledExecutorService; import org.junit.Before; import org.junit.Rule; @@ -43,75 +44,126 @@ import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; import org.robolectric.RobolectricTestRunner; -import org.robolectric.annotation.LooperMode; /** * Low-level server-side transport tests for binder channel. Like BinderChannelSmokeTest, this * convers edge cases not exercised by AbstractTransportTest, but it deals with the * binderTransport.BinderServerTransport directly. */ -@LooperMode(PAUSED) @RunWith(RobolectricTestRunner.class) public final class BinderServerTransportTest { @Rule public MockitoRule mocks = MockitoJUnit.rule(); - private final ScheduledExecutorService executorService = - TestingExecutors.sameThreadScheduledExecutor(); - private final TestTransportListener transportListener = new TestTransportListener(); + private final ScheduledExecutorService executorService = new MainThreadScheduledExecutorService(); + private MockServerTransportListener transportListener; @Mock IBinder mockBinder; - BinderTransport.BinderServerTransport transport; + BinderServerTransport transport; @Before public void setUp() throws Exception { - transport = - new BinderTransport.BinderServerTransport( - new FixedObjectPool<>(executorService), - Attributes.EMPTY, - ImmutableList.of(), - OneWayBinderProxy.IDENTITY_DECORATOR, - mockBinder); + transportListener = new MockServerTransportListener(transport); + } + + // Provide defaults so that we can "include only relevant details in tests." + BinderServerTransportBuilder newBinderServerTransportBuilder() { + return new BinderServerTransportBuilder() + .setExecutorServicePool(new FixedObjectPool<>(executorService)) + .setAttributes(Attributes.EMPTY) + .setStreamTracerFactories(ImmutableList.of()) + .setBinderDecorator(OneWayBinderProxy.IDENTITY_DECORATOR) + .setCallbackBinder(mockBinder); } @Test - public void testSetupTransactionFailureCausesMultipleShutdowns_b153460678() throws Exception { + public void testSetupTransactionFailureReportsMultipleTerminations_b153460678() throws Exception { // Make the binder fail the setup transaction. - when(mockBinder.transact(anyInt(), any(Parcel.class), isNull(), anyInt())).thenReturn(false); - transport.setServerTransportListener(transportListener); + doThrow(new RemoteException()) + .when(mockBinder) + .transact(anyInt(), any(Parcel.class), isNull(), anyInt()); + transport = newBinderServerTransportBuilder().setCallbackBinder(mockBinder).build(); + shadowOf(Looper.getMainLooper()).idle(); + transport.start(transportListener); + + // Now shut it down externally *before* executing Runnables scheduled on the executor. + transport.shutdownNow(Status.UNKNOWN.withDescription("reasons")); + shadowOf(Looper.getMainLooper()).idle(); + + assertThat(transportListener.isTerminated()).isTrue(); + } + + @Test + public void testClientBinderIsDeadOnArrival() throws Exception { + transport = newBinderServerTransportBuilder() + .setCallbackBinder(new FakeDeadBinder()) + .build(); + transport.start(transportListener); + shadowOf(Looper.getMainLooper()).idle(); + + assertThat(transportListener.isTerminated()).isTrue(); + } - // Now shut it down. + @Test + public void testStartAfterShutdownAndIdle() throws Exception { + transport = newBinderServerTransportBuilder().build(); transport.shutdownNow(Status.UNKNOWN.withDescription("reasons")); + shadowOf(Looper.getMainLooper()).idle(); + transport.start(transportListener); + shadowOf(Looper.getMainLooper()).idle(); - assertThat(transportListener.terminated).isTrue(); + assertThat(transportListener.isTerminated()).isTrue(); } - private static final class TestTransportListener implements ServerTransportListener { - - public boolean ready; - public boolean terminated; - - /** - * Called when a new stream was created by the remote client. - * - * @param stream the newly created stream. - * @param method the fully qualified method name being called on the server. - * @param headers containing metadata for the call. - */ - @Override - public void streamCreated(ServerStream stream, String method, Metadata headers) {} - - @Override - public Attributes transportReady(Attributes attributes) { - ready = true; - return attributes; + @Test + public void testStartAfterShutdownNoIdle() throws Exception { + transport = newBinderServerTransportBuilder().build(); + transport.shutdownNow(Status.UNKNOWN.withDescription("reasons")); + transport.start(transportListener); + shadowOf(Looper.getMainLooper()).idle(); + + assertThat(transportListener.isTerminated()).isTrue(); + } + + static class BinderServerTransportBuilder { + ObjectPool executorServicePool; + Attributes attributes; + List streamTracerFactories; + OneWayBinderProxy.Decorator binderDecorator; + IBinder callbackBinder; + + public BinderServerTransport build() { + return BinderServerTransport.create( + executorServicePool, attributes, streamTracerFactories, binderDecorator, callbackBinder); + } + + public BinderServerTransportBuilder setExecutorServicePool( + ObjectPool executorServicePool) { + this.executorServicePool = executorServicePool; + return this; + } + + public BinderServerTransportBuilder setAttributes(Attributes attributes) { + this.attributes = attributes; + return this; + } + + public BinderServerTransportBuilder setStreamTracerFactories( + List streamTracerFactories) { + this.streamTracerFactories = streamTracerFactories; + return this; + } + + public BinderServerTransportBuilder setBinderDecorator( + OneWayBinderProxy.Decorator binderDecorator) { + this.binderDecorator = binderDecorator; + return this; } - @Override - public void transportTerminated() { - checkState(!terminated, "Terminated twice"); - terminated = true; + public BinderServerTransportBuilder setCallbackBinder(IBinder callbackBinder) { + this.callbackBinder = callbackBinder; + return this; } } } diff --git a/binder/src/test/java/io/grpc/binder/internal/BlockInputStreamTest.java b/binder/src/test/java/io/grpc/binder/internal/BlockInputStreamTest.java index 5a0279fea22..d78a26cf9b9 100644 --- a/binder/src/test/java/io/grpc/binder/internal/BlockInputStreamTest.java +++ b/binder/src/test/java/io/grpc/binder/internal/BlockInputStreamTest.java @@ -45,8 +45,7 @@ public void testNoBlocks() throws Exception { @Test public void testSingleBlock() throws Exception { - BlockInputStream bis = - new BlockInputStream(new byte[][] {getBytes(10, 1)}, 10); + BlockInputStream bis = new BlockInputStream(new byte[][] {getBytes(10, 1)}, 10); assertThat(bis.read(buff, 0, 20)).isEqualTo(10); assertBytes(buff, 0, 10, 1); } @@ -95,8 +94,7 @@ public void testMultipleBlocksLessData_drain() throws Exception { @Test public void testMultipleBlocksEmptyFinalBlock() throws Exception { - BlockInputStream bis = - new BlockInputStream(new byte[][] {getBytes(10, 1), getBytes(0, 0)}, 10); + BlockInputStream bis = new BlockInputStream(new byte[][] {getBytes(10, 1), getBytes(0, 0)}, 10); assertThat(bis.read(buff, 0, 20)).isEqualTo(10); assertBytes(buff, 0, 10, 1); @@ -106,8 +104,7 @@ public void testMultipleBlocksEmptyFinalBlock() throws Exception { @Test public void testMultipleBlocksEmptyFinalBlock_drain() throws Exception { - BlockInputStream bis = - new BlockInputStream(new byte[][] {getBytes(10, 1), getBytes(0, 0)}, 10); + BlockInputStream bis = new BlockInputStream(new byte[][] {getBytes(10, 1), getBytes(0, 0)}, 10); ByteArrayOutputStream baos = new ByteArrayOutputStream(); bis.drainTo(baos); byte[] data = baos.toByteArray(); diff --git a/binder/src/test/java/io/grpc/binder/internal/IntentNameResolverProviderTest.java b/binder/src/test/java/io/grpc/binder/internal/IntentNameResolverProviderTest.java new file mode 100644 index 00000000000..2809a72fee1 --- /dev/null +++ b/binder/src/test/java/io/grpc/binder/internal/IntentNameResolverProviderTest.java @@ -0,0 +1,130 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.grpc.binder.internal; + +import static android.os.Looper.getMainLooper; +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.robolectric.Shadows.shadowOf; + +import android.app.Application; +import androidx.core.content.ContextCompat; +import androidx.test.core.app.ApplicationProvider; +import io.grpc.NameResolver; +import io.grpc.NameResolver.ResolutionResult; +import io.grpc.NameResolver.ServiceConfigParser; +import io.grpc.NameResolverProvider; +import io.grpc.SynchronizationContext; +import io.grpc.Uri; +import io.grpc.binder.ApiConstants; +import java.net.URI; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoTestRule; +import org.robolectric.RobolectricTestRunner; + +/** A test for IntentNameResolverProvider. */ +@RunWith(RobolectricTestRunner.class) +public final class IntentNameResolverProviderTest { + + private final Application appContext = ApplicationProvider.getApplicationContext(); + private final SynchronizationContext syncContext = newSynchronizationContext(); + private final NameResolver.Args args = newNameResolverArgs(); + + private NameResolverProvider provider; + + @Rule public MockitoTestRule mockitoTestRule = MockitoJUnit.testRule(this); + @Mock public NameResolver.Listener2 mockListener; + @Captor public ArgumentCaptor resultCaptor; + + @Before + public void setUp() { + provider = new IntentNameResolverProvider(); + } + + @Test + public void testProviderScheme_returnsIntentScheme() throws Exception { + assertThat(provider.getDefaultScheme()) + .isEqualTo(IntentNameResolverProvider.ANDROID_INTENT_SCHEME); + } + + @Test + public void testNoResolverForUnknownScheme_returnsNull() throws Exception { + assertThat(provider.newNameResolver(Uri.create("random://uri"), args)).isNull(); + } + + @Test + public void testResolutionWithBadUri_throwsIllegalArg() throws Exception { + assertThrows( + IllegalArgumentException.class, + () -> provider.newNameResolver(Uri.create("intent:xxx#Intent;e.x=1;end;"), args)); + } + + @Test + public void testResolverForIntentScheme_returnsResolver() throws Exception { + Uri uri = Uri.create("intent:#Intent;action=action;end"); + NameResolver resolver = provider.newNameResolver(uri, args); + assertThat(resolver).isNotNull(); + assertThat(resolver.getServiceAuthority()).isEqualTo("localhost"); + syncContext.execute(() -> resolver.start(mockListener)); + shadowOf(getMainLooper()).idle(); + verify(mockListener).onResult2(resultCaptor.capture()); + assertThat(resultCaptor.getValue().getAddressesOrError()).isNotNull(); + syncContext.execute(resolver::shutdown); + shadowOf(getMainLooper()).idle(); + } + + @Test + public void testResolverForIntentScheme_returnsResolver_javaNetUri() throws Exception { + URI uri = new URI("intent://authority/path#Intent;action=action;scheme=scheme;end"); + NameResolver resolver = provider.newNameResolver(uri, args); + assertThat(resolver).isNotNull(); + assertThat(resolver.getServiceAuthority()).isEqualTo("localhost"); + syncContext.execute(() -> resolver.start(mockListener)); + shadowOf(getMainLooper()).idle(); + verify(mockListener).onResult2(resultCaptor.capture()); + assertThat(resultCaptor.getValue().getAddressesOrError()).isNotNull(); + syncContext.execute(resolver::shutdown); + shadowOf(getMainLooper()).idle(); + } + + /** Returns a new test-specific {@link NameResolver.Args} instance. */ + private NameResolver.Args newNameResolverArgs() { + return NameResolver.Args.newBuilder() + .setDefaultPort(-1) + .setProxyDetector((target) -> null) // No proxies here. + .setSynchronizationContext(syncContext) + .setOffloadExecutor(ContextCompat.getMainExecutor(appContext)) + .setServiceConfigParser(mock(ServiceConfigParser.class)) + .setArg(ApiConstants.SOURCE_ANDROID_CONTEXT, appContext) + .build(); + } + + private static SynchronizationContext newSynchronizationContext() { + return new SynchronizationContext( + (thread, exception) -> { + throw new AssertionError(exception); + }); + } +} diff --git a/binder/src/test/java/io/grpc/binder/internal/IntentNameResolverTest.java b/binder/src/test/java/io/grpc/binder/internal/IntentNameResolverTest.java new file mode 100644 index 00000000000..b1bfcd4fd56 --- /dev/null +++ b/binder/src/test/java/io/grpc/binder/internal/IntentNameResolverTest.java @@ -0,0 +1,531 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.grpc.binder.internal; + +import static android.content.Intent.ACTION_PACKAGE_ADDED; +import static android.content.Intent.ACTION_PACKAGE_REPLACED; +import static android.os.Looper.getMainLooper; +import static android.os.Process.myUserHandle; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.robolectric.Shadows.shadowOf; + +import android.app.Application; +import android.content.ComponentName; +import android.content.Intent; +import android.content.IntentFilter; +import android.content.pm.ServiceInfo; +import android.net.Uri; +import android.os.UserHandle; +import android.os.UserManager; +import androidx.annotation.NonNull; +import androidx.core.content.ContextCompat; +import androidx.test.core.app.ApplicationProvider; +import com.google.common.collect.ImmutableList; +import io.grpc.EquivalentAddressGroup; +import io.grpc.NameResolver; +import io.grpc.NameResolver.ResolutionResult; +import io.grpc.NameResolver.ServiceConfigParser; +import io.grpc.Status; +import io.grpc.StatusOr; +import io.grpc.SynchronizationContext; +import io.grpc.binder.AndroidComponentAddress; +import io.grpc.binder.ApiConstants; +import java.lang.Thread.UncaughtExceptionHandler; +import java.net.SocketAddress; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoTestRule; +import org.robolectric.RobolectricTestRunner; +import org.robolectric.annotation.Config; +import org.robolectric.shadows.ShadowPackageManager; + +/** A test for IntentNameResolverProvider. */ +@RunWith(RobolectricTestRunner.class) +public final class IntentNameResolverTest { + + private static final ComponentName SOME_COMPONENT_NAME = + new ComponentName("com.foo.bar", "SomeComponent"); + private static final ComponentName ANOTHER_COMPONENT_NAME = + new ComponentName("org.blah", "AnotherComponent"); + private final Application appContext = ApplicationProvider.getApplicationContext(); + private final SynchronizationContext syncContext = newSynchronizationContext(); + private final NameResolver.Args args = newNameResolverArgs().build(); + + private final ShadowPackageManager shadowPackageManager = + shadowOf(appContext.getPackageManager()); + + @Rule public MockitoTestRule mockitoTestRule = MockitoJUnit.testRule(this); + @Mock public NameResolver.Listener2 mockListener; + @Captor public ArgumentCaptor resultCaptor; + + @Test + public void testResolverForIntentScheme_returnsResolverWithLocalHostAuthority() throws Exception { + NameResolver resolver = newNameResolver(newIntent()); + assertThat(resolver).isNotNull(); + assertThat(resolver.getServiceAuthority()).isEqualTo("localhost"); + } + + @Test + public void testResolutionWithoutServicesAvailable_returnsUnimplemented() throws Exception { + NameResolver nameResolver = newNameResolver(newIntent()); + syncContext.execute(() -> nameResolver.start(mockListener)); + shadowOf(getMainLooper()).idle(); + verify(mockListener).onResult2(resultCaptor.capture()); + assertThat(resultCaptor.getValue().getAddressesOrError().getStatus().getCode()) + .isEqualTo(Status.UNIMPLEMENTED.getCode()); + } + + @Test + public void testResolutionWithMultipleServicesAvailable_returnsAndroidComponentAddresses() + throws Exception { + Intent intent = newIntent(); + IntentFilter serviceIntentFilter = newFilterMatching(intent); + + shadowPackageManager.addServiceIfNotPresent(SOME_COMPONENT_NAME); + shadowPackageManager.addIntentFilterForService(SOME_COMPONENT_NAME, serviceIntentFilter); + + // Adds another valid Service + shadowPackageManager.addServiceIfNotPresent(ANOTHER_COMPONENT_NAME); + shadowPackageManager.addIntentFilterForService(ANOTHER_COMPONENT_NAME, serviceIntentFilter); + + NameResolver nameResolver = newNameResolver(intent); + syncContext.execute(() -> nameResolver.start(mockListener)); + shadowOf(getMainLooper()).idle(); + + verify(mockListener, never()).onError(any()); + verify(mockListener).onResult2(resultCaptor.capture()); + assertThat(getAddressesOrThrow(resultCaptor.getValue())) + .containsExactly( + toAddressList(intent.cloneFilter().setComponent(SOME_COMPONENT_NAME)), + toAddressList(intent.cloneFilter().setComponent(ANOTHER_COMPONENT_NAME))); + + syncContext.execute(nameResolver::shutdown); + shadowOf(getMainLooper()).idle(); + } + + @Test + public void testExplicitResolutionByComponent_returnsRestrictedResults() throws Exception { + Intent intent = newIntent(); + IntentFilter serviceIntentFilter = newFilterMatching(intent); + + shadowPackageManager.addServiceIfNotPresent(SOME_COMPONENT_NAME); + shadowPackageManager.addIntentFilterForService(SOME_COMPONENT_NAME, serviceIntentFilter); + shadowPackageManager.addServiceIfNotPresent(ANOTHER_COMPONENT_NAME); + shadowPackageManager.addIntentFilterForService(ANOTHER_COMPONENT_NAME, serviceIntentFilter); + + NameResolver nameResolver = + newNameResolver(intent.cloneFilter().setComponent(ANOTHER_COMPONENT_NAME)); + syncContext.execute(() -> nameResolver.start(mockListener)); + shadowOf(getMainLooper()).idle(); + + verify(mockListener, never()).onError(any()); + verify(mockListener).onResult2(resultCaptor.capture()); + assertThat(getAddressesOrThrow(resultCaptor.getValue())) + .containsExactly(toAddressList(intent.cloneFilter().setComponent(ANOTHER_COMPONENT_NAME))); + + syncContext.execute(nameResolver::shutdown); + shadowOf(getMainLooper()).idle(); + } + + @Test + public void testExplicitResolutionByPackage_returnsRestrictedResults() throws Exception { + Intent intent = newIntent(); + IntentFilter serviceIntentFilter = newFilterMatching(intent); + + shadowPackageManager.addServiceIfNotPresent(SOME_COMPONENT_NAME); + shadowPackageManager.addIntentFilterForService(SOME_COMPONENT_NAME, serviceIntentFilter); + shadowPackageManager.addServiceIfNotPresent(ANOTHER_COMPONENT_NAME); + shadowPackageManager.addIntentFilterForService(ANOTHER_COMPONENT_NAME, serviceIntentFilter); + + NameResolver nameResolver = + newNameResolver(intent.cloneFilter().setPackage(ANOTHER_COMPONENT_NAME.getPackageName())); + syncContext.execute(() -> nameResolver.start(mockListener)); + shadowOf(getMainLooper()).idle(); + + verify(mockListener, never()).onError(any()); + verify(mockListener).onResult2(resultCaptor.capture()); + assertThat(getAddressesOrThrow(resultCaptor.getValue())) + .containsExactly(toAddressList(intent.cloneFilter().setComponent(ANOTHER_COMPONENT_NAME))); + + syncContext.execute(nameResolver::shutdown); + shadowOf(getMainLooper()).idle(); + } + + @Test + public void testResolution_setsPreAuthEagAttribute() throws Exception { + Intent intent = newIntent(); + IntentFilter serviceIntentFilter = newFilterMatching(intent); + + shadowPackageManager.addServiceIfNotPresent(SOME_COMPONENT_NAME); + shadowPackageManager.addIntentFilterForService(SOME_COMPONENT_NAME, serviceIntentFilter); + + NameResolver nameResolver = newNameResolver(intent); + syncContext.execute(() -> nameResolver.start(mockListener)); + shadowOf(getMainLooper()).idle(); + + verify(mockListener).onResult2(resultCaptor.capture()); + assertThat(getAddressesOrThrow(resultCaptor.getValue())) + .containsExactly(toAddressList(intent.cloneFilter().setComponent(SOME_COMPONENT_NAME))); + assertThat( + getEagsOrThrow(resultCaptor.getValue()).stream() + .map(EquivalentAddressGroup::getAttributes) + .collect(toImmutableList()) + .get(0) + .get(ApiConstants.PRE_AUTH_SERVER_OVERRIDE)) + .isTrue(); + + syncContext.execute(nameResolver::shutdown); + shadowOf(getMainLooper()).idle(); + } + + @Test + public void testServiceRemoved_pushesUpdatedAndroidComponentAddresses() throws Exception { + Intent intent = newIntent(); + IntentFilter serviceIntentFilter = newFilterMatching(intent); + + shadowPackageManager.addServiceIfNotPresent(SOME_COMPONENT_NAME); + shadowPackageManager.addIntentFilterForService(SOME_COMPONENT_NAME, serviceIntentFilter); + shadowPackageManager.addServiceIfNotPresent(ANOTHER_COMPONENT_NAME); + shadowPackageManager.addIntentFilterForService(ANOTHER_COMPONENT_NAME, serviceIntentFilter); + + NameResolver nameResolver = newNameResolver(intent); + syncContext.execute(() -> nameResolver.start(mockListener)); + shadowOf(getMainLooper()).idle(); + + verify(mockListener, never()).onError(any()); + verify(mockListener).onResult2(resultCaptor.capture()); + assertThat(getAddressesOrThrow(resultCaptor.getValue())) + .containsExactly( + toAddressList(intent.cloneFilter().setComponent(SOME_COMPONENT_NAME)), + toAddressList(intent.cloneFilter().setComponent(ANOTHER_COMPONENT_NAME))); + + shadowPackageManager.removeService(ANOTHER_COMPONENT_NAME); + broadcastPackageChange(ACTION_PACKAGE_REPLACED, ANOTHER_COMPONENT_NAME.getPackageName()); + shadowOf(getMainLooper()).idle(); + + verify(mockListener, never()).onError(any()); + verify(mockListener, times(2)).onResult2(resultCaptor.capture()); + assertThat(getAddressesOrThrow(resultCaptor.getValue())) + .containsExactly(toAddressList(intent.cloneFilter().setComponent(SOME_COMPONENT_NAME))); + + syncContext.execute(nameResolver::shutdown); + shadowOf(getMainLooper()).idle(); + + verifyNoMoreInteractions(mockListener); + assertThat(shadowOf(appContext).getRegisteredReceivers()).isEmpty(); + } + + @Test + @Config(sdk = 30) + public void testTargetAndroidUser_pushesUpdatedAddresses() throws Exception { + Intent intent = newIntent(); + IntentFilter serviceIntentFilter = newFilterMatching(intent); + + NameResolver nameResolver = + newNameResolver( + intent, + newNameResolverArgs().setArg(ApiConstants.TARGET_ANDROID_USER, myUserHandle()).build()); + syncContext.execute(() -> nameResolver.start(mockListener)); + shadowOf(getMainLooper()).idle(); + verify(mockListener).onResult2(resultCaptor.capture()); + assertThat(resultCaptor.getValue().getAddressesOrError().getStatus().getCode()) + .isEqualTo(Status.UNIMPLEMENTED.getCode()); + + shadowPackageManager.addServiceIfNotPresent(SOME_COMPONENT_NAME); + shadowPackageManager.addIntentFilterForService(SOME_COMPONENT_NAME, serviceIntentFilter); + broadcastPackageChange(ACTION_PACKAGE_ADDED, SOME_COMPONENT_NAME.getPackageName()); + shadowOf(getMainLooper()).idle(); + + verify(mockListener, never()).onError(any()); + verify(mockListener, times(2)).onResult2(resultCaptor.capture()); + assertThat(getAddressesOrThrow(resultCaptor.getValue())) + .containsExactly( + ImmutableList.of( + AndroidComponentAddress.newBuilder() + .setTargetUser(myUserHandle()) + .setBindIntent(intent.cloneFilter().setComponent(SOME_COMPONENT_NAME)) + .build())); + + syncContext.execute(nameResolver::shutdown); + shadowOf(getMainLooper()).idle(); + + verifyNoMoreInteractions(mockListener); + assertThat(shadowOf(appContext).getRegisteredReceivers()).isEmpty(); + } + + @Test + @Config(sdk = 29) + public void testTargetAndroidUser_notSupported_throwsWithHelpfulMessage() throws Exception { + NameResolver.Args args = + newNameResolverArgs().setArg(ApiConstants.TARGET_ANDROID_USER, myUserHandle()).build(); + IllegalArgumentException iae = + assertThrows(IllegalArgumentException.class, () -> newNameResolver(newIntent(), args)); + assertThat(iae.getMessage()).contains("TARGET_ANDROID_USER"); + assertThat(iae.getMessage()).contains("SDK_INT >= R"); + } + + @Test + @Config(sdk = 29) + public void testServiceAppearsUponBootComplete_pushesUpdatedAndroidComponentAddresses() + throws Exception { + Intent intent = newIntent(); + IntentFilter serviceIntentFilter = newFilterMatching(intent); + + // Suppose this directBootAware=true Service appears in PackageManager before a user unlock. + shadowOf(appContext.getSystemService(UserManager.class)).setUserUnlocked(false); + ServiceInfo someServiceInfo = shadowPackageManager.addServiceIfNotPresent(SOME_COMPONENT_NAME); + someServiceInfo.directBootAware = true; + shadowPackageManager.addIntentFilterForService(SOME_COMPONENT_NAME, serviceIntentFilter); + + NameResolver nameResolver = newNameResolver(intent); + syncContext.execute(() -> nameResolver.start(mockListener)); + shadowOf(getMainLooper()).idle(); + + verify(mockListener, never()).onError(any()); + verify(mockListener).onResult2(resultCaptor.capture()); + assertThat(getAddressesOrThrow(resultCaptor.getValue())) + .containsExactly(toAddressList(intent.cloneFilter().setComponent(SOME_COMPONENT_NAME))); + + // TODO(b/331618070): Robolectric doesn't yet support ServiceInfo.directBootAware filtering. + // Simulate support by waiting for a user unlock to add this !directBootAware Service. + ServiceInfo anotherServiceInfo = + shadowPackageManager.addServiceIfNotPresent(ANOTHER_COMPONENT_NAME); + anotherServiceInfo.directBootAware = false; + shadowPackageManager.addIntentFilterForService(ANOTHER_COMPONENT_NAME, serviceIntentFilter); + + shadowOf(appContext.getSystemService(UserManager.class)).setUserUnlocked(true); + broadcastUserUnlocked(myUserHandle()); + shadowOf(getMainLooper()).idle(); + + verify(mockListener, never()).onError(any()); + verify(mockListener, times(2)).onResult2(resultCaptor.capture()); + assertThat(getAddressesOrThrow(resultCaptor.getValue())) + .containsExactly( + toAddressList(intent.cloneFilter().setComponent(SOME_COMPONENT_NAME)), + toAddressList(intent.cloneFilter().setComponent(ANOTHER_COMPONENT_NAME))); + + syncContext.execute(nameResolver::shutdown); + shadowOf(getMainLooper()).idle(); + verifyNoMoreInteractions(mockListener); + } + + @Test + public void testRefresh_returnsSameAndroidComponentAddresses() throws Exception { + Intent intent = newIntent(); + IntentFilter serviceIntentFilter = newFilterMatching(intent); + + shadowPackageManager.addServiceIfNotPresent(SOME_COMPONENT_NAME); + shadowPackageManager.addIntentFilterForService(SOME_COMPONENT_NAME, serviceIntentFilter); + shadowPackageManager.addServiceIfNotPresent(ANOTHER_COMPONENT_NAME); + shadowPackageManager.addIntentFilterForService(ANOTHER_COMPONENT_NAME, serviceIntentFilter); + + NameResolver nameResolver = newNameResolver(intent); + syncContext.execute(() -> nameResolver.start(mockListener)); + shadowOf(getMainLooper()).idle(); + + verify(mockListener, never()).onError(any()); + verify(mockListener).onResult2(resultCaptor.capture()); + assertThat(getAddressesOrThrow(resultCaptor.getValue())) + .containsExactly( + toAddressList(intent.cloneFilter().setComponent(SOME_COMPONENT_NAME)), + toAddressList(intent.cloneFilter().setComponent(ANOTHER_COMPONENT_NAME))); + + syncContext.execute(nameResolver::refresh); + shadowOf(getMainLooper()).idle(); + verify(mockListener, never()).onError(any()); + verify(mockListener, times(2)).onResult2(resultCaptor.capture()); + assertThat(getAddressesOrThrow(resultCaptor.getValue())) + .containsExactly( + toAddressList(intent.cloneFilter().setComponent(SOME_COMPONENT_NAME)), + toAddressList(intent.cloneFilter().setComponent(ANOTHER_COMPONENT_NAME))); + + syncContext.execute(nameResolver::shutdown); + shadowOf(getMainLooper()).idle(); + assertThat(shadowOf(appContext).getRegisteredReceivers()).isEmpty(); + } + + @Test + public void testRefresh_collapsesMultipleRequestsIntoOneLookup() throws Exception { + Intent intent = newIntent(); + IntentFilter serviceIntentFilter = newFilterMatching(intent); + + shadowPackageManager.addServiceIfNotPresent(SOME_COMPONENT_NAME); + shadowPackageManager.addIntentFilterForService(SOME_COMPONENT_NAME, serviceIntentFilter); + + NameResolver nameResolver = newNameResolver(intent); + syncContext.execute(() -> nameResolver.start(mockListener)); // Should kick off the 1st lookup. + syncContext.execute(nameResolver::refresh); // Should queue a lookup to run when 1st finishes. + syncContext.execute(nameResolver::refresh); // Should be ignored since a lookup is already Q'd. + syncContext.execute(nameResolver::refresh); // Also ignored. + shadowOf(getMainLooper()).idle(); + + verify(mockListener, never()).onError(any()); + verify(mockListener, times(2)).onResult2(resultCaptor.capture()); + assertThat(getAddressesOrThrow(resultCaptor.getValue())) + .containsExactly(toAddressList(intent.cloneFilter().setComponent(SOME_COMPONENT_NAME))); + + syncContext.execute(nameResolver::shutdown); + shadowOf(getMainLooper()).idle(); + } + + private void broadcastPackageChange(String action, String pkgName) { + Intent broadcast = new Intent(); + broadcast.setAction(action); + broadcast.setData(Uri.parse("package:" + pkgName)); + appContext.sendBroadcast(broadcast); + } + + private void broadcastUserUnlocked(UserHandle userHandle) { + Intent unlockedBroadcast = new Intent(Intent.ACTION_USER_UNLOCKED); + unlockedBroadcast.putExtra(Intent.EXTRA_USER, userHandle); + appContext.sendBroadcast(unlockedBroadcast); + } + + @Test + public void testResolutionOnResultThrows_onErrorNotCalled() throws Exception { + RetainingUncaughtExceptionHandler exceptionHandler = new RetainingUncaughtExceptionHandler(); + SynchronizationContext syncContext = new SynchronizationContext(exceptionHandler); + Intent intent = newIntent(); + shadowPackageManager.addServiceIfNotPresent(SOME_COMPONENT_NAME); + shadowPackageManager.addIntentFilterForService(SOME_COMPONENT_NAME, newFilterMatching(intent)); + + @SuppressWarnings("serial") + class SomeRuntimeException extends RuntimeException {} + doThrow(SomeRuntimeException.class).when(mockListener).onResult2(any()); + + NameResolver nameResolver = + newNameResolver( + intent, newNameResolverArgs().setSynchronizationContext(syncContext).build()); + syncContext.execute(() -> nameResolver.start(mockListener)); + shadowOf(getMainLooper()).idle(); + + verify(mockListener).onResult2(any()); + verify(mockListener, never()).onError(any()); + assertThat(exceptionHandler.uncaught).hasSize(1); + assertThat(exceptionHandler.uncaught.get(0)).isInstanceOf(SomeRuntimeException.class); + } + + private static Intent newIntent() { + Intent intent = new Intent(); + intent.setAction("test.action"); + intent.setData(Uri.parse("grpc:ServiceName")); + return intent; + } + + private static IntentFilter newFilterMatching(Intent intent) { + IntentFilter filter = new IntentFilter(); + if (intent.getAction() != null) { + filter.addAction(intent.getAction()); + } + Uri data = intent.getData(); + if (data != null) { + if (data.getScheme() != null) { + filter.addDataScheme(data.getScheme()); + } + if (data.getSchemeSpecificPart() != null) { + filter.addDataSchemeSpecificPart(data.getSchemeSpecificPart(), 0); + } + } + Set categories = intent.getCategories(); + if (categories != null) { + for (String category : categories) { + filter.addCategory(category); + } + } + return filter; + } + + private static List getEagsOrThrow(ResolutionResult result) { + StatusOr> eags = result.getAddressesOrError(); + if (!eags.hasValue()) { + throw eags.getStatus().asRuntimeException(); + } + return eags.getValue(); + } + + // Extracts just the addresses from 'result's EquivalentAddressGroups. + private static ImmutableList> getAddressesOrThrow(ResolutionResult result) { + return getEagsOrThrow(result).stream() + .map(EquivalentAddressGroup::getAddresses) + .collect(toImmutableList()); + } + + // Converts given Intents to a list of ACAs, for convenient comparison with getAddressesOrThrow(). + private static ImmutableList toAddressList(Intent... bindIntents) { + ImmutableList.Builder builder = ImmutableList.builder(); + for (Intent bindIntent : bindIntents) { + builder.add(AndroidComponentAddress.forBindIntent(bindIntent)); + } + return builder.build(); + } + + private NameResolver newNameResolver(Intent targetIntent) { + return newNameResolver(targetIntent, args); + } + + private NameResolver newNameResolver(Intent targetIntent, NameResolver.Args args) { + return new IntentNameResolver(targetIntent, args); + } + + /** Returns a new test-specific {@link NameResolver.Args} instance. */ + private NameResolver.Args.Builder newNameResolverArgs() { + return NameResolver.Args.newBuilder() + .setDefaultPort(-1) + .setProxyDetector((target) -> null) // No proxies here. + .setSynchronizationContext(syncContext) + .setOffloadExecutor(ContextCompat.getMainExecutor(appContext)) + .setArg(ApiConstants.SOURCE_ANDROID_CONTEXT, appContext) + .setServiceConfigParser(mock(ServiceConfigParser.class)); + } + + /** + * Returns a test {@link SynchronizationContext}. + * + *

Exceptions will cause the test to fail with {@link AssertionError}. + */ + private static SynchronizationContext newSynchronizationContext() { + return new SynchronizationContext( + (thread, exception) -> { + throw new AssertionError(exception); + }); + } + + static final class RetainingUncaughtExceptionHandler implements UncaughtExceptionHandler { + final ArrayList uncaught = new ArrayList<>(); + + @Override + public void uncaughtException(@NonNull Thread t, @NonNull Throwable e) { + uncaught.add(e); + } + } +} diff --git a/binder/src/test/java/io/grpc/binder/internal/ParcelableInputStreamTest.java b/binder/src/test/java/io/grpc/binder/internal/ParcelableInputStreamTest.java index bf90e21d046..657c6d77db5 100644 --- a/binder/src/test/java/io/grpc/binder/internal/ParcelableInputStreamTest.java +++ b/binder/src/test/java/io/grpc/binder/internal/ParcelableInputStreamTest.java @@ -88,8 +88,9 @@ public void testWriteToParcel() throws Exception { stream.writeToParcel(parcel); parcel.setDataPosition(0); - assertThat((TestParcelable) parcel.readParcelable(getClass().getClassLoader())) - .isEqualTo(testParcelable); + @SuppressWarnings("deprecation") // readParcelable(ClassLoader)'s replacement is only in 33+. + TestParcelable clone = parcel.readParcelable(getClass().getClassLoader()); + assertThat(clone).isEqualTo(testParcelable); } @Test @@ -113,8 +114,9 @@ public void testAsRegularInputStream() throws Exception { parcel.unmarshall(data, 0, data.length); parcel.setDataPosition(0); - assertThat((TestParcelable) parcel.readParcelable(getClass().getClassLoader())) - .isEqualTo(testParcelable); + @SuppressWarnings("deprecation") // readParcelable(ClassLoader)'s replacement is only in 33+. + TestParcelable clone = parcel.readParcelable(getClass().getClassLoader()); + assertThat(clone).isEqualTo(testParcelable); } @Test diff --git a/binder/src/test/java/io/grpc/binder/internal/PendingAuthListenerTest.java b/binder/src/test/java/io/grpc/binder/internal/PendingAuthListenerTest.java index 29b35b309fb..9cdf123033b 100644 --- a/binder/src/test/java/io/grpc/binder/internal/PendingAuthListenerTest.java +++ b/binder/src/test/java/io/grpc/binder/internal/PendingAuthListenerTest.java @@ -6,6 +6,10 @@ import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.Status; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -19,11 +23,6 @@ import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; -import io.grpc.Metadata; -import io.grpc.ServerCall; -import io.grpc.ServerCallHandler; -import io.grpc.Status; - @RunWith(JUnit4.class) public final class PendingAuthListenerTest { diff --git a/binder/src/test/java/io/grpc/binder/internal/PingTrackerTest.java b/binder/src/test/java/io/grpc/binder/internal/PingTrackerTest.java index 60e7c163105..c662cafe5fa 100644 --- a/binder/src/test/java/io/grpc/binder/internal/PingTrackerTest.java +++ b/binder/src/test/java/io/grpc/binder/internal/PingTrackerTest.java @@ -96,7 +96,7 @@ private static final class TestCallback implements ClientTransport.PingCallback private int numCallbacks; private boolean success; private boolean failure; - private Throwable failureException; + private Status failureStatus; private long roundtripTimeNanos; @Override @@ -107,10 +107,10 @@ public synchronized void onSuccess(long roundtripTimeNanos) { } @Override - public synchronized void onFailure(Throwable failureException) { + public synchronized void onFailure(Status failureStatus) { numCallbacks += 1; failure = true; - this.failureException = failureException; + this.failureStatus = failureStatus; } public void assertNotCalled() { @@ -130,13 +130,13 @@ public void assertSuccess(long expectRoundTripTimeNanos) { public void assertFailure(Status status) { assertThat(numCallbacks).isEqualTo(1); assertThat(failure).isTrue(); - assertThat(((StatusException) failureException).getStatus()).isSameInstanceAs(status); + assertThat(failureStatus).isSameInstanceAs(status); } public void assertFailure(Status.Code statusCode) { assertThat(numCallbacks).isEqualTo(1); assertThat(failure).isTrue(); - assertThat(((StatusException) failureException).getStatus().getCode()).isEqualTo(statusCode); + assertThat(failureStatus.getCode()).isEqualTo(statusCode); } } } diff --git a/binder/src/test/java/io/grpc/binder/internal/RobolectricBinderTransportTest.java b/binder/src/test/java/io/grpc/binder/internal/RobolectricBinderTransportTest.java new file mode 100644 index 00000000000..63c47bf4f19 --- /dev/null +++ b/binder/src/test/java/io/grpc/binder/internal/RobolectricBinderTransportTest.java @@ -0,0 +1,698 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import static android.os.IBinder.FLAG_ONEWAY; +import static android.os.Process.myUid; +import static com.google.common.truth.Truth.assertAbout; +import static com.google.common.truth.Truth.assertThat; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static io.grpc.StatusSubject.status; +import static io.grpc.binder.internal.BinderTransport.REMOTE_UID; +import static io.grpc.binder.internal.BinderTransport.SETUP_TRANSPORT; +import static io.grpc.binder.internal.BinderTransport.SHUTDOWN_TRANSPORT; +import static io.grpc.binder.internal.BinderTransport.WIRE_FORMAT_VERSION; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.junit.Assume.assumeTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.timeout; +import static org.mockito.Mockito.verify; +import static org.robolectric.Shadows.shadowOf; + +import android.app.Application; +import android.content.Intent; +import android.content.pm.ApplicationInfo; +import android.content.pm.PackageInfo; +import android.content.pm.ServiceInfo; +import android.os.Binder; +import android.os.Parcel; +import androidx.test.core.app.ApplicationProvider; +import androidx.test.core.content.pm.ApplicationInfoBuilder; +import androidx.test.core.content.pm.PackageInfoBuilder; +import com.google.common.collect.ImmutableList; +import com.google.common.truth.TruthJUnit; +import io.grpc.Attributes; +import io.grpc.CallOptions; +import io.grpc.InternalChannelz.SocketStats; +import io.grpc.Metadata; +import io.grpc.ServerStreamTracer; +import io.grpc.Status; +import io.grpc.binder.AndroidComponentAddress; +import io.grpc.binder.ApiConstants; +import io.grpc.binder.AsyncSecurityPolicy; +import io.grpc.binder.SecurityPolicies; +import io.grpc.binder.internal.OneWayBinderProxies.*; +import io.grpc.binder.internal.SettableAsyncSecurityPolicy.AuthRequest; +import io.grpc.internal.AbstractTransportTest; +import io.grpc.internal.ClientStream; +import io.grpc.internal.ClientStreamListenerBase; +import io.grpc.internal.ClientTransport; +import io.grpc.internal.ClientTransportFactory.ClientTransportOptions; +import io.grpc.internal.ConnectionClientTransport; +import io.grpc.internal.DisconnectError; +import io.grpc.internal.GrpcUtil; +import io.grpc.internal.InternalServer; +import io.grpc.internal.ManagedClientTransport; +import io.grpc.internal.MockServerTransportListener; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.SharedResourcePool; +import java.io.InputStream; +import java.util.List; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.Executor; +import java.util.concurrent.ScheduledExecutorService; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; +import org.robolectric.ParameterizedRobolectricTestRunner; +import org.robolectric.ParameterizedRobolectricTestRunner.Parameter; +import org.robolectric.ParameterizedRobolectricTestRunner.Parameters; +import org.robolectric.annotation.LooperMode; +import org.robolectric.annotation.LooperMode.Mode; +import org.robolectric.shadows.ShadowBinder; + +/** + * All of the AbstractTransportTest cases applied to {@link BinderTransport} running in a + * Robolectric environment. + * + *

Runs much faster than BinderTransportTest and doesn't require an Android device/emulator. + * Somewhat less realistic but allows simulating behavior that would be difficult or impossible with + * real Android. + * + *

NB: Unlike most robolectric tests, we run in {@link LooperMode.Mode#INSTRUMENTATION_TEST}, + * meaning test cases don't run on the main thread. This supports the AbstractTransportTest approach + * where the test thread frequently blocks waiting for transport state changes to take effect. + */ +@RunWith(ParameterizedRobolectricTestRunner.class) +@LooperMode(Mode.INSTRUMENTATION_TEST) +public final class RobolectricBinderTransportTest extends AbstractTransportTest { + + static final int SERVER_APP_UID = 11111; + static final int EPHEMERAL_SERVER_UID = 22222; // UID of isolated server process. + + private final Application application = ApplicationProvider.getApplicationContext(); + private final ObjectPool executorServicePool = + SharedResourcePool.forResource(GrpcUtil.TIMER_SERVICE); + private final ObjectPool offloadExecutorPool = + SharedResourcePool.forResource(GrpcUtil.SHARED_CHANNEL_EXECUTOR); + private final ObjectPool serverExecutorPool = + SharedResourcePool.forResource(GrpcUtil.SHARED_CHANNEL_EXECUTOR); + + @Rule public MockitoRule mocks = MockitoJUnit.rule(); + + @Mock AsyncSecurityPolicy mockClientSecurityPolicy; + + @Captor ArgumentCaptor statusCaptor; + + ApplicationInfo serverAppInfo; + PackageInfo serverPkgInfo; + ServiceInfo serviceInfo; + + private int nextServerAddress; + private BlockingBinderDecorator blockingDecorator = + new BlockingBinderDecorator<>(); + + @Parameter(value = 0) + public boolean preAuthServersParam; + + @Parameter(value = 1) + public boolean useLegacyAuthStrategy; + + @Parameters(name = "preAuthServersParam={0};useLegacyAuthStrategy={1}") + public static ImmutableList data() { + return ImmutableList.of( + new Object[] {false, false}, + new Object[] {false, true}, + new Object[] {true, false}, + new Object[] {true, true}); + } + + @Override + public void setUp() { + serverAppInfo = + ApplicationInfoBuilder.newBuilder().setPackageName("the.server.package").build(); + serverAppInfo.uid = myUid(); + serverPkgInfo = + PackageInfoBuilder.newBuilder() + .setPackageName(serverAppInfo.packageName) + .setApplicationInfo(serverAppInfo) + .build(); + shadowOf(application.getPackageManager()).installPackage(serverPkgInfo); + + serviceInfo = new ServiceInfo(); + serviceInfo.name = "SomeService"; + serviceInfo.packageName = serverAppInfo.packageName; + serviceInfo.applicationInfo = serverAppInfo; + shadowOf(application.getPackageManager()).addOrUpdateService(serviceInfo); + + super.setUp(); + } + + @Before + public void requestRealisticBindServiceBehavior() { + shadowOf(application).setBindServiceCallsOnServiceConnectedDirectly(false); + shadowOf(application).setUnbindServiceCallsOnServiceDisconnected(false); + } + + BinderServer.Builder newServerBuilder() { + AndroidComponentAddress listenAddr = + AndroidComponentAddress.forBindIntent( + new Intent() + .setClassName(serviceInfo.packageName, serviceInfo.name) + .setAction("io.grpc.action.BIND." + nextServerAddress++)); + + return new BinderServer.Builder() + .setListenAddress(listenAddr) + .setExecutorPool(serverExecutorPool) + .setExecutorServicePool(executorServicePool) + .setStreamTracerFactories(List.of()); + } + + void registerServerWithRobolectric(BinderServer server) { + AndroidComponentAddress listenAddr = (AndroidComponentAddress) server.getListenSocketAddress(); + shadowOf(application.getPackageManager()).addServiceIfNotPresent(listenAddr.getComponent()); + shadowOf(application) + .setComponentNameAndServiceForBindServiceForIntent( + listenAddr.asBindIntent(), listenAddr.getComponent(), server.getHostBinder()); + } + + @Override + protected InternalServer newServer(List streamTracerFactories) { + BinderServer server = + newServerBuilder().setStreamTracerFactories(streamTracerFactories).build(); + registerServerWithRobolectric(server); + return server; + } + + @Override + protected InternalServer newServer( + int port, List streamTracerFactories) { + if (port > 0) { + // TODO: TCP ports have no place in an *abstract* transport test. Replace with SocketAddress. + throw new UnsupportedOperationException(); + } + return newServer(streamTracerFactories); + } + + BinderClientTransportFactory.Builder newClientTransportFactoryBuilder() { + return new BinderClientTransportFactory.Builder() + .setPreAuthorizeServers(preAuthServersParam) + .setUseLegacyAuthStrategy(useLegacyAuthStrategy) + .setSourceContext(application) + .setScheduledExecutorPool(executorServicePool) + .setOffloadExecutorPool(offloadExecutorPool); + } + + BinderClientTransportBuilder newClientTransportBuilder() { + return new BinderClientTransportBuilder() + .setFactory(newClientTransportFactoryBuilder().buildClientTransportFactory()) + .setServerAddress(server.getListenSocketAddress()); + } + + @Override + protected ManagedClientTransport newClientTransport(InternalServer server) { + ClientTransportOptions options = new ClientTransportOptions(); + options.setEagAttributes(eagAttrs()); + options.setChannelLogger(transportLogger()); + + return newClientTransportBuilder() + .setServerAddress(server.getListenSocketAddress()) + .setOptions(options) + .build(); + } + + @Override + protected String testAuthority(InternalServer server) { + return ((AndroidComponentAddress) server.getListenSocketAddress()).getAuthority(); + } + + @Test + public void clientAuthorizesServerUidsInOrder() throws Exception { + // TODO(jdcormie): In real Android, Binder#getCallingUid is thread-local but Robolectric only + // lets us fake value this *globally*. So the ShadowBinder#setCallingUid() here unrealistically + // affects the server's view of the client's uid too. For now this doesn't matter because this + // test never exercises server SecurityPolicy. + ShadowBinder.setCallingUid(EPHEMERAL_SERVER_UID); + + serverPkgInfo.applicationInfo.uid = SERVER_APP_UID; + shadowOf(application.getPackageManager()).installPackage(serverPkgInfo); + shadowOf(application.getPackageManager()).addOrUpdateService(serviceInfo); + server = newServer(ImmutableList.of()); + server.start(serverListener); + + SettableAsyncSecurityPolicy securityPolicy = new SettableAsyncSecurityPolicy(); + client = + newClientTransportBuilder() + .setFactory( + newClientTransportFactoryBuilder() + .setSecurityPolicy(securityPolicy) + .buildClientTransportFactory()) + .build(); + runIfNotNull(client.start(mockClientTransportListener)); + + if (preAuthServersParam) { + AuthRequest preAuthRequest = securityPolicy.takeNextAuthRequest(TIMEOUT_MS, MILLISECONDS); + assertThat(preAuthRequest.uid).isEqualTo(SERVER_APP_UID); + verify(mockClientTransportListener, never()).transportReady(); + preAuthRequest.setResult(Status.OK); + } + + AuthRequest authRequest = securityPolicy.takeNextAuthRequest(TIMEOUT_MS, MILLISECONDS); + if (useLegacyAuthStrategy) { + assertThat(authRequest.uid).isEqualTo(EPHEMERAL_SERVER_UID); + } else { + assertThat(authRequest.uid).isEqualTo(SERVER_APP_UID); + } + verify(mockClientTransportListener, never()).transportReady(); + authRequest.setResult(Status.OK); + + verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportReady(); + } + + @Test + public void eagAttributeCanOverrideChannelPreAuthServerSetting() throws Exception { + server.start(serverListener); + SettableAsyncSecurityPolicy securityPolicy = new SettableAsyncSecurityPolicy(); + ClientTransportOptions options = new ClientTransportOptions(); + options.setEagAttributes( + Attributes.newBuilder().set(ApiConstants.PRE_AUTH_SERVER_OVERRIDE, true).build()); + client = + newClientTransportBuilder() + .setOptions(options) + .setFactory( + newClientTransportFactoryBuilder() + .setPreAuthorizeServers(preAuthServersParam) // To be overridden. + .setSecurityPolicy(securityPolicy) + .buildClientTransportFactory()) + .build(); + runIfNotNull(client.start(mockClientTransportListener)); + + AuthRequest preAuthRequest = securityPolicy.takeNextAuthRequest(TIMEOUT_MS, MILLISECONDS); + verify(mockClientTransportListener, never()).transportReady(); + preAuthRequest.setResult(Status.OK); + + AuthRequest authRequest = securityPolicy.takeNextAuthRequest(TIMEOUT_MS, MILLISECONDS); + verify(mockClientTransportListener, never()).transportReady(); + authRequest.setResult(Status.OK); + + verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportReady(); + } + + @Test + public void clientIgnoresDuplicateSetupTransaction() throws Exception { + server.start(serverListener); + client = + newClientTransportBuilder() + .setFactory( + newClientTransportFactoryBuilder() + .setSecurityPolicy(SecurityPolicies.internalOnly()) + .buildClientTransportFactory()) + .build(); + runIfNotNull(client.start(mockClientTransportListener)); + verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportReady(); + + assertThat(((ConnectionClientTransport) client).getAttributes().get(REMOTE_UID)) + .isEqualTo(myUid()); + + Parcel setupParcel = Parcel.obtain(); + try { + setupParcel.writeInt(WIRE_FORMAT_VERSION); + setupParcel.writeStrongBinder(new Binder()); + setupParcel.setDataPosition(0); + ShadowBinder.setCallingUid(1 + myUid()); + ((BinderClientTransport) client).handleTransaction(SETUP_TRANSPORT, setupParcel); + } finally { + ShadowBinder.setCallingUid(myUid()); + setupParcel.recycle(); + } + + assertThat(((ConnectionClientTransport) client).getAttributes().get(REMOTE_UID)) + .isEqualTo(myUid()); + } + + @Test + public void clientIgnoresTransactionFromNonServerUids() throws Exception { + server.start(serverListener); + + // This test is not applicable to the new auth strategy which keeps the client Binder a secret. + assumeTrue(useLegacyAuthStrategy); + + client = newClientTransport(server); + startTransport(client, mockClientTransportListener); + + int serverUid = ((ConnectionClientTransport) client).getAttributes().get(REMOTE_UID); + int someOtherUid = 1 + serverUid; + sendShutdownTransportTransactionAsUid(client, someOtherUid); + + // Demonstrate that the transport is still working and that shutdown transaction was ignored. + ClientTransport.PingCallback mockPingCallback = mock(ClientTransport.PingCallback.class); + client.ping(mockPingCallback, directExecutor()); + verify(mockPingCallback, timeout(TIMEOUT_MS)).onSuccess(anyLong()); + + // Try again as the expected uid to demonstrate that this wasn't ignored for some other reason. + sendShutdownTransportTransactionAsUid(client, serverUid); + + verify(mockClientTransportListener, timeout(TIMEOUT_MS)) + .transportShutdown(statusCaptor.capture(), any(DisconnectError.class)); + assertThat(statusCaptor.getValue().getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(statusCaptor.getValue().getDescription()).contains("shutdown"); + } + + static void sendShutdownTransportTransactionAsUid(ClientTransport client, int sendingUid) { + int originalUid = Binder.getCallingUid(); + try { + ShadowBinder.setCallingUid(sendingUid); + ((BinderClientTransport) client) + .getIncomingBinderForTesting() + .onTransact(SHUTDOWN_TRANSPORT, null, null, FLAG_ONEWAY); + } finally { + ShadowBinder.setCallingUid(originalUid); + } + } + + @Test + public void clientReportsAuthzErrorToServer() throws Exception { + server.start(serverListener); + client = + newClientTransportBuilder() + .setFactory( + newClientTransportFactoryBuilder() + .setSecurityPolicy(SecurityPolicies.permissionDenied("test")) + .buildClientTransportFactory()) + .build(); + runIfNotNull(client.start(mockClientTransportListener)); + verify(mockClientTransportListener, timeout(TIMEOUT_MS)) + .transportShutdown(statusCaptor.capture(), any(DisconnectError.class)); + assertThat(statusCaptor.getValue().getCode()).isEqualTo(Status.Code.PERMISSION_DENIED); + + // Client doesn't tell the server in this case by design -- we don't even want to start it! + TruthJUnit.assume().that(preAuthServersParam).isFalse(); + // Similar story here. The client won't send a setup transaction to an unauthorized server. + TruthJUnit.assume().that(useLegacyAuthStrategy).isTrue(); + + MockServerTransportListener serverTransportListener = + serverListener.takeListenerOrFail(TIMEOUT_MS, MILLISECONDS); + serverTransportListener.waitForTermination(TIMEOUT_MS, MILLISECONDS); + assertThat(serverTransportListener.isTerminated()).isTrue(); + } + + @Test + @Override + // We don't quite pass the official/abstract version of this test yet because + // today's binder client and server transports have different ideas of each others' address. + // TODO(#12347): Remove this @Override once this difference is resolved. + public void socketStats() throws Exception { + server.start(serverListener); + ManagedClientTransport client = newClientTransport(server); + startTransport(client, mockClientTransportListener); + + SocketStats clientSocketStats = client.getStats().get(); + assertThat(clientSocketStats.local).isInstanceOf(AndroidComponentAddress.class); + assertThat(((AndroidComponentAddress) clientSocketStats.remote).getPackage()) + .isEqualTo(((AndroidComponentAddress) server.getListenSocketAddress()).getPackage()); + + MockServerTransportListener serverTransportListener = + serverListener.takeListenerOrFail(TIMEOUT_MS, MILLISECONDS); + SocketStats serverSocketStats = serverTransportListener.transport.getStats().get(); + assertThat(serverSocketStats.local).isEqualTo(server.getListenSocketAddress()); + assertThat(serverSocketStats.remote).isEqualTo(new BoundClientAddress(myUid())); + } + + @Test + @Ignore("See BinderTransportTest#flowControlPushBack") + @Override + public void flowControlPushBack() {} + + @Test + @Ignore("See BinderTransportTest#serverAlreadyListening") + @Override + public void serverAlreadyListening() {} + + @Test + public void singleTxnMsgsDeliveredToServerOutOfOrder() throws Exception { + server.start(serverListener); + client = + newClientTransportBuilder() + .setFactory( + newClientTransportFactoryBuilder() + .setBinderDecorator(blockingDecorator) + .buildClientTransportFactory()) + .build(); + runIfNotNull(client.start(mockClientTransportListener)); + blockingDecorator.putNextResult(takeNextBinder(blockingDecorator)); // Endpoint binder. + QueueingOneWayBinderProxy queueingServerProxy = + new QueueingOneWayBinderProxy(takeNextBinder(blockingDecorator)); // Server binder. + blockingDecorator.putNextResult(queueingServerProxy); + + verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportReady(); + + ClientStream stream = + client.newStream(methodDescriptor, new Metadata(), CallOptions.DEFAULT, noopTracers); + ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); + stream.start(clientStreamListener); + stream.writeMessage(methodDescriptor.streamRequest("one")); + stream.writeMessage(methodDescriptor.streamRequest("two")); + stream.halfClose(); + + // Expect one transaction for headers, one for each message, and one for half-close. + QueueingOneWayBinderProxy.Transaction txHeaders = takeNextTransaction(queueingServerProxy); + QueueingOneWayBinderProxy.Transaction tx1 = takeNextTransaction(queueingServerProxy); + QueueingOneWayBinderProxy.Transaction tx2 = takeNextTransaction(queueingServerProxy); + QueueingOneWayBinderProxy.Transaction txHalfClose = takeNextTransaction(queueingServerProxy); + + // Deliver messages out of order! + queueingServerProxy.deliver(txHeaders); + queueingServerProxy.deliver(tx2); + queueingServerProxy.deliver(tx1); + queueingServerProxy.deliver(txHalfClose); + + MockServerTransportListener serverTransportListener = + serverListener.takeListenerOrFail(TIMEOUT_MS, MILLISECONDS); + MockServerTransportListener.StreamCreation serverStreamCreation = + serverTransportListener.takeStreamOrFail(TIMEOUT_MS, MILLISECONDS); + serverStreamCreation.stream.request(2); + + // Expect the server to deliver the messages in the order they were originally sent. + InputStream msg1 = takeNextMessage(serverStreamCreation.listener.messageQueue); + assertThat(methodDescriptor.parseResponse(msg1)).isEqualTo("one"); + + InputStream msg2 = takeNextMessage(serverStreamCreation.listener.messageQueue); + assertThat(methodDescriptor.parseResponse(msg2)).isEqualTo("two"); + + assertThat(serverStreamCreation.listener.awaitHalfClosed(TIMEOUT_MS, MILLISECONDS)).isTrue(); + serverStreamCreation.stream.close(Status.OK, new Metadata()); + + assertAbout(status()).that(clientStreamListener.awaitClose(TIMEOUT_MS, MILLISECONDS)).isOk(); + assertAbout(status()) + .that(serverStreamCreation.listener.awaitClose(TIMEOUT_MS, MILLISECONDS)) + .isOk(); + } + + @Test + public void msgFragmentsDeliveredToServerOutOfOrder() throws Exception { + server.start(serverListener); + client = + newClientTransportBuilder() + .setFactory( + newClientTransportFactoryBuilder() + .setBinderDecorator(blockingDecorator) + .buildClientTransportFactory()) + .build(); + runIfNotNull(client.start(mockClientTransportListener)); + blockingDecorator.putNextResult(takeNextBinder(blockingDecorator)); // Endpoint binder. + QueueingOneWayBinderProxy queueingServerProxy = + new QueueingOneWayBinderProxy(takeNextBinder(blockingDecorator)); // Server binder. + blockingDecorator.putNextResult(queueingServerProxy); + + verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportReady(); + + ClientStream stream = + client.newStream(methodDescriptor, new Metadata(), CallOptions.DEFAULT, noopTracers); + ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); + stream.start(clientStreamListener); + + String largeMessage = newStringOfLength(BlockPool.BLOCK_SIZE + 1); + stream.writeMessage(methodDescriptor.streamRequest(largeMessage)); + stream.halfClose(); + + // Expect the client to split largeMessage into two transactions, plus headers and half-close. + QueueingOneWayBinderProxy.Transaction txHeaders = takeNextTransaction(queueingServerProxy); + QueueingOneWayBinderProxy.Transaction tx1 = takeNextTransaction(queueingServerProxy); + QueueingOneWayBinderProxy.Transaction tx2 = takeNextTransaction(queueingServerProxy); + QueueingOneWayBinderProxy.Transaction txHalfClose = takeNextTransaction(queueingServerProxy); + + // Deliver fragments out of order! + queueingServerProxy.deliver(txHeaders); + queueingServerProxy.deliver(tx2); + queueingServerProxy.deliver(tx1); + queueingServerProxy.deliver(txHalfClose); + + // Verify that the server reassembles the transactions correctly. + MockServerTransportListener serverTransportListener = + serverListener.takeListenerOrFail(TIMEOUT_MS, MILLISECONDS); + MockServerTransportListener.StreamCreation serverStreamCreation = + serverTransportListener.takeStreamOrFail(TIMEOUT_MS, MILLISECONDS); + serverStreamCreation.stream.request(1); + InputStream msg = takeNextMessage(serverStreamCreation.listener.messageQueue); + assertThat(methodDescriptor.parseResponse(msg)).isEqualTo(largeMessage); + + assertThat(serverStreamCreation.listener.awaitHalfClosed(TIMEOUT_MS, MILLISECONDS)).isTrue(); + serverStreamCreation.stream.close(Status.OK, new Metadata()); + + assertAbout(status()).that(clientStreamListener.awaitClose(TIMEOUT_MS, MILLISECONDS)).isOk(); + assertAbout(status()) + .that(serverStreamCreation.listener.awaitClose(TIMEOUT_MS, MILLISECONDS)) + .isOk(); + } + + @Test + public void singleTxnMsgsDeliveredToClientOutOfOrder() throws Exception { + server = newServerBuilder().setClientBinderDecorator(blockingDecorator).build(); + registerServerWithRobolectric((BinderServer) server); + server.start(serverListener); + + client = newClientTransport(server); + runIfNotNull(client.start(mockClientTransportListener)); + + QueueingOneWayBinderProxy queueingClientProxy = + new QueueingOneWayBinderProxy(takeNextBinder(blockingDecorator)); + blockingDecorator.putNextResult(queueingClientProxy); + + // Deliver the setup transaction without interference. + queueingClientProxy.deliver(takeNextTransaction(queueingClientProxy)); + verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportReady(); + + ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); + ClientStream stream = + client.newStream(methodDescriptor, new Metadata(), CallOptions.DEFAULT, noopTracers); + stream.start(clientStreamListener); + stream.halfClose(); + stream.request(2); + + MockServerTransportListener serverTransportListener = + serverListener.takeListenerOrFail(TIMEOUT_MS, MILLISECONDS); + MockServerTransportListener.StreamCreation serverStreamCreation = + serverTransportListener.takeStreamOrFail(TIMEOUT_MS, MILLISECONDS); + + serverStreamCreation.stream.writeMessage(methodDescriptor.streamResponse("one")); + serverStreamCreation.stream.writeMessage(methodDescriptor.streamResponse("two")); + serverStreamCreation.stream.close(Status.OK, new Metadata()); + + // Expect one transaction from the server for each message. + QueueingOneWayBinderProxy.Transaction tx1 = takeNextTransaction(queueingClientProxy); + QueueingOneWayBinderProxy.Transaction tx2 = takeNextTransaction(queueingClientProxy); + QueueingOneWayBinderProxy.Transaction txClose = takeNextTransaction(queueingClientProxy); + + // Deliver messages to the client out of order! + queueingClientProxy.deliver(tx2); + queueingClientProxy.deliver(tx1); + queueingClientProxy.deliver(txClose); + + // Client should deliver messages to the application in the order sent. + InputStream msg1 = takeNextMessage(clientStreamListener.messageQueue); + assertThat(methodDescriptor.parseResponse(msg1)).isEqualTo("one"); + InputStream msg2 = takeNextMessage(clientStreamListener.messageQueue); + assertThat(methodDescriptor.parseResponse(msg2)).isEqualTo("two"); + + assertAbout(status()).that(clientStreamListener.awaitClose(TIMEOUT_MS, MILLISECONDS)).isOk(); + assertAbout(status()) + .that(serverStreamCreation.listener.awaitClose(TIMEOUT_MS, MILLISECONDS)) + .isOk(); + } + + @Test + public void msgFragmentsDeliveredToClientOutOfOrder() throws Exception { + server = newServerBuilder().setClientBinderDecorator(blockingDecorator).build(); + registerServerWithRobolectric((BinderServer) server); + server.start(serverListener); + + client = newClientTransport(server); + runIfNotNull(client.start(mockClientTransportListener)); + + QueueingOneWayBinderProxy queueingClientProxy = + new QueueingOneWayBinderProxy(takeNextBinder(blockingDecorator)); + blockingDecorator.putNextResult(queueingClientProxy); + + // Deliver the setup transaction without interference. + queueingClientProxy.deliver(takeNextTransaction(queueingClientProxy)); + verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportReady(); + + ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); + ClientStream stream = + client.newStream(methodDescriptor, new Metadata(), CallOptions.DEFAULT, noopTracers); + stream.start(clientStreamListener); + stream.request(1); + + MockServerTransportListener serverTransportListener = + serverListener.takeListenerOrFail(TIMEOUT_MS, MILLISECONDS); + MockServerTransportListener.StreamCreation serverStreamCreation = + serverTransportListener.takeStreamOrFail(TIMEOUT_MS, MILLISECONDS); + + String largeMessage = newStringOfLength(BlockPool.BLOCK_SIZE + 1); + serverStreamCreation.stream.writeMessage(methodDescriptor.streamResponse(largeMessage)); + serverStreamCreation.stream.flush(); + + // Expect the client to split largeMessage into two transactions. + QueueingOneWayBinderProxy.Transaction tx1 = takeNextTransaction(queueingClientProxy); + QueueingOneWayBinderProxy.Transaction tx2 = takeNextTransaction(queueingClientProxy); + + // Deliver them to the client out of order! + queueingClientProxy.deliver(tx2); + queueingClientProxy.deliver(tx1); + + // Client should reassemble the message correctly. + InputStream msg = takeNextMessage(clientStreamListener.messageQueue); + assertThat(methodDescriptor.parseResponse(msg)).isEqualTo(largeMessage); + } + + private static OneWayBinderProxy takeNextBinder( + BlockingBinderDecorator decorator) throws InterruptedException { + OneWayBinderProxy proxy = decorator.takeNextRequest(TIMEOUT_MS, MILLISECONDS); + assertThat(proxy).isNotNull(); + return proxy; + } + + private static QueueingOneWayBinderProxy.Transaction takeNextTransaction( + QueueingOneWayBinderProxy proxy) throws InterruptedException { + QueueingOneWayBinderProxy.Transaction tx = proxy.pollNextTransaction(TIMEOUT_MS, MILLISECONDS); + assertThat(tx).isNotNull(); + return tx; + } + + private static InputStream takeNextMessage(BlockingQueue messageQueue) + throws InterruptedException { + InputStream msg = messageQueue.poll(TIMEOUT_MS, MILLISECONDS); + assertThat(msg).isNotNull(); + return msg; + } + + private static String newStringOfLength(int numChars) { + char[] chars = new char[numChars]; + java.util.Arrays.fill(chars, 'x'); + return new String(chars); + } +} diff --git a/binder/src/test/java/io/grpc/binder/internal/ServiceBindingTest.java b/binder/src/test/java/io/grpc/binder/internal/ServiceBindingTest.java index 3ec65624fb8..0f57b6f8a30 100644 --- a/binder/src/test/java/io/grpc/binder/internal/ServiceBindingTest.java +++ b/binder/src/test/java/io/grpc/binder/internal/ServiceBindingTest.java @@ -19,15 +19,17 @@ import static android.content.Context.BIND_AUTO_CREATE; import static android.os.Looper.getMainLooper; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.fail; import static org.robolectric.Shadows.shadowOf; -import static org.robolectric.annotation.LooperMode.Mode.PAUSED; import android.app.Application; import android.app.admin.DevicePolicyManager; import android.content.ComponentName; import android.content.Context; import android.content.Intent; +import android.content.pm.ServiceInfo; +import android.os.Build; import android.os.IBinder; import android.os.Parcel; import android.os.UserHandle; @@ -35,6 +37,7 @@ import androidx.test.core.app.ApplicationProvider; import io.grpc.Status; import io.grpc.Status.Code; +import io.grpc.StatusException; import io.grpc.binder.BinderChannelCredentials; import io.grpc.binder.internal.Bindable.Observer; import java.util.Arrays; @@ -48,11 +51,8 @@ import org.mockito.junit.MockitoRule; import org.robolectric.RobolectricTestRunner; import org.robolectric.annotation.Config; -import org.robolectric.annotation.LooperMode; import org.robolectric.shadows.ShadowApplication; -import org.robolectric.shadows.ShadowDevicePolicyManager; -@LooperMode(PAUSED) @RunWith(RobolectricTestRunner.class) public final class ServiceBindingTest { @@ -62,6 +62,7 @@ public final class ServiceBindingTest { private Application appContext; private ComponentName serviceComponent; + private ServiceInfo serviceInfo = new ServiceInfo(); private ShadowApplication shadowApplication; private TestObserver observer; private ServiceBinding binding; @@ -70,13 +71,17 @@ public final class ServiceBindingTest { public void setUp() { appContext = ApplicationProvider.getApplicationContext(); serviceComponent = new ComponentName("DUMMY", "SERVICE"); + serviceInfo.packageName = serviceComponent.getPackageName(); + serviceInfo.name = serviceComponent.getClassName(); observer = new TestObserver(); shadowApplication = shadowOf(appContext); shadowApplication.setComponentNameAndServiceForBindService(serviceComponent, mockBinder); + shadowOf(appContext.getPackageManager()).addOrUpdateService(serviceInfo); // Don't call onServiceDisconnected() upon unbindService(), just like the real Android doesn't. shadowApplication.setUnbindServiceCallsOnServiceDisconnected(false); + shadowApplication.setBindServiceCallsOnServiceConnectedDirectly(false); binding = newBuilder().build(); shadowOf(getMainLooper()).idle(); @@ -110,6 +115,32 @@ public void testBind() throws Exception { assertThat(binding.isSourceContextCleared()).isFalse(); } + @Test + public void testGetConnectedServiceInfo() throws Exception { + binding = newBuilder().setTargetComponent(serviceComponent).build(); + binding.bind(); + shadowOf(getMainLooper()).idle(); + + assertThat(observer.gotBoundEvent).isTrue(); + + ServiceInfo serviceInfo = binding.getConnectedServiceInfo(); + assertThat(serviceInfo.name).isEqualTo(serviceComponent.getClassName()); + assertThat(serviceInfo.packageName).isEqualTo(serviceComponent.getPackageName()); + } + + @Test + public void testGetConnectedServiceInfoThrows() throws Exception { + binding = newBuilder().setTargetComponent(serviceComponent).build(); + binding.bind(); + shadowOf(getMainLooper()).idle(); + + assertThat(observer.gotBoundEvent).isTrue(); + shadowOf(appContext.getPackageManager()).removeService(serviceComponent); + + StatusException se = assertThrows(StatusException.class, binding::getConnectedServiceInfo); + assertThat(se.getStatus().getCode()).isEqualTo(Code.UNIMPLEMENTED); + } + @Test public void testBindingIntent() throws Exception { shadowApplication.setComponentNameAndServiceForBindService(null, null); @@ -266,8 +297,7 @@ public void testCallsAfterUnbindDontCrash() throws Exception { @Test @Config(sdk = 30) public void testBindWithTargetUserHandle() throws Exception { - binding = - newBuilder().setTargetUserHandle(generateUserHandle(/* userId= */ 0)).build(); + binding = newBuilder().setTargetUserHandle(generateUserHandle(/* userId= */ 0)).build(); shadowOf(getMainLooper()).idle(); binding.bind(); @@ -280,18 +310,113 @@ public void testBindWithTargetUserHandle() throws Exception { assertThat(binding.isSourceContextCleared()).isFalse(); } + @Test + public void testResolve() throws Exception { + serviceInfo.processName = "x"; // ServiceInfo has no equals() so look for one distinctive field. + shadowOf(appContext.getPackageManager()).addOrUpdateService(serviceInfo); + ServiceInfo resolvedServiceInfo = binding.resolve(); + assertThat(resolvedServiceInfo.processName).isEqualTo(serviceInfo.processName); + } + + @Test + @Config(sdk = 33) + public void testResolveWithTargetUserHandle() throws Exception { + serviceInfo.processName = "x"; // ServiceInfo has no equals() so look for one distinctive field. + // Robolectric just ignores the user arg to resolveServiceAsUser() so this is all we can do. + shadowOf(appContext.getPackageManager()).addOrUpdateService(serviceInfo); + binding = newBuilder().setTargetUserHandle(generateUserHandle(/* userId= */ 0)).build(); + ServiceInfo resolvedServiceInfo = binding.resolve(); + assertThat(resolvedServiceInfo.processName).isEqualTo(serviceInfo.processName); + } + + @Test + @Config(sdk = 29) + public void testResolveWithUnsupportedTargetUserHandle() throws Exception { + binding = newBuilder().setTargetUserHandle(generateUserHandle(/* userId= */ 0)).build(); + StatusException statusException = assertThrows(StatusException.class, binding::resolve); + assertThat(statusException.getStatus().getCode()).isEqualTo(Code.INTERNAL); + assertThat(statusException.getStatus().getDescription()).contains("SDK_INT >= R"); + } + + @Test + public void testResolveNonExistentServiceThrows() throws Exception { + ComponentName doesNotExistService = new ComponentName("does.not.exist", "NoService"); + binding = newBuilder().setTargetComponent(doesNotExistService).build(); + StatusException statusException = assertThrows(StatusException.class, binding::resolve); + assertThat(statusException.getStatus().getCode()).isEqualTo(Code.UNIMPLEMENTED); + assertThat(statusException.getStatus().getDescription()).contains("does.not.exist"); + } + + @Test + @Config(sdk = 33) + public void testResolveNonExistentServiceWithTargetUserThrows() throws Exception { + ComponentName doesNotExistService = new ComponentName("does.not.exist", "NoService"); + binding = + newBuilder() + .setTargetUserHandle(generateUserHandle(/* userId= */ 12345)) + .setTargetComponent(doesNotExistService) + .build(); + StatusException statusException = assertThrows(StatusException.class, binding::resolve); + assertThat(statusException.getStatus().getCode()).isEqualTo(Code.UNIMPLEMENTED); + assertThat(statusException.getStatus().getDescription()).contains("does.not.exist"); + assertThat(statusException.getStatus().getDescription()).contains("12345"); + } + + @Test + @Config(sdk = 30) + public void testBindService_doesNotThrowInternalErrorWhenSdkAtLeastR() { + UserHandle userHandle = generateUserHandle(/* userId= */ 12345); + binding = newBuilder().setTargetUserHandle(userHandle).build(); + binding.bind(); + shadowOf(getMainLooper()).idle(); + + assertThat(Build.VERSION.SDK_INT).isEqualTo(Build.VERSION_CODES.R); + assertThat(observer.unboundReason).isNull(); + } + + @Test + @Config(sdk = 28) + public void testBindServiceAsUser_returnsErrorWhenSdkBelowR() { + UserHandle userHandle = generateUserHandle(/* userId= */ 12345); + binding = newBuilder().setTargetUserHandle(userHandle).build(); + binding.bind(); + shadowOf(getMainLooper()).idle(); + + assertThat(observer.unboundReason.getCode()).isEqualTo(Code.INTERNAL); + assertThat(observer.unboundReason.getDescription()) + .isEqualTo("Cross user Channel requires Android R+"); + } + + @Test + @Config(sdk = 28) + public void testDevicePolicyBlind_returnsErrorWhenSdkBelowR() { + ComponentName adminComponent = new ComponentName(appContext, "DevicePolicyAdmin"); + UserHandle user10 = generateUserHandle(/* userId= */ 10); + allowBindDeviceAdminForUser(appContext, adminComponent, user10); + binding = + newBuilder() + .setTargetUserHandle(user10) + .setChannelCredentials(BinderChannelCredentials.forDevicePolicyAdmin(adminComponent)) + .build(); + binding.bind(); + shadowOf(getMainLooper()).idle(); + + assertThat(observer.unboundReason.getCode()).isEqualTo(Code.INTERNAL); + assertThat(observer.unboundReason.getDescription()) + .isEqualTo("Device policy admin binding requires Android R+"); + } + @Test @Config(sdk = 30) public void testBindWithDeviceAdmin() throws Exception { - String deviceAdminClassName = "DevicePolicyAdmin"; - ComponentName adminComponent = new ComponentName(appContext, deviceAdminClassName); - allowBindDeviceAdminForUser(appContext, adminComponent, /* userId= */ 0); + ComponentName adminComponent = new ComponentName(appContext, "DevicePolicyAdmin"); + UserHandle user0 = generateUserHandle(/* userId= */ 0); + allowBindDeviceAdminForUser(appContext, adminComponent, user0); binding = newBuilder() - .setTargetUserHandle(UserHandle.getUserHandleForUid(/* userId= */ 0)) - .setTargetUserHandle(generateUserHandle(/* userId= */ 0)) - .setChannelCredentials( - BinderChannelCredentials.forDevicePolicyAdmin(adminComponent)) + .setTargetUserHandle(user0) + .setTargetComponent(serviceComponent) + .setChannelCredentials(BinderChannelCredentials.forDevicePolicyAdmin(adminComponent)) .build(); shadowOf(getMainLooper()).idle(); @@ -303,6 +428,10 @@ public void testBindWithDeviceAdmin() throws Exception { assertThat(observer.binder).isSameInstanceAs(mockBinder); assertThat(observer.gotUnboundEvent).isFalse(); assertThat(binding.isSourceContextCleared()).isFalse(); + + ServiceInfo serviceInfo = binding.getConnectedServiceInfo(); + assertThat(serviceInfo.name).isEqualTo(serviceComponent.getClassName()); + assertThat(serviceInfo.packageName).isEqualTo(serviceComponent.getPackageName()); } private void assertNoLockHeld() { @@ -317,16 +446,11 @@ private void assertNoLockHeld() { } } - private static void allowBindDeviceAdminForUser(Context context, ComponentName admin, int userId) { - ShadowDevicePolicyManager devicePolicyManager = - shadowOf(context.getSystemService(DevicePolicyManager.class)); - devicePolicyManager.setDeviceOwner(admin); - devicePolicyManager.setBindDeviceAdminTargetUsers( - Arrays.asList(UserHandle.getUserHandleForUid(userId))); - shadowOf((DevicePolicyManager) context.getSystemService(Context.DEVICE_POLICY_SERVICE)); - devicePolicyManager.setDeviceOwner(admin); - devicePolicyManager.setBindDeviceAdminTargetUsers( - Arrays.asList(generateUserHandle(userId))); + private static void allowBindDeviceAdminForUser( + Context context, ComponentName admin, UserHandle user) { + DevicePolicyManager devicePolicyManager = context.getSystemService(DevicePolicyManager.class); + shadowOf(devicePolicyManager).setBindDeviceAdminTargetUsers(Arrays.asList(user)); + shadowOf(devicePolicyManager).setDeviceOwner(admin); } /** Generate UserHandles the hard way. */ @@ -373,7 +497,7 @@ private static class ServiceBindingBuilder { private BinderChannelCredentials channelCredentials = BinderChannelCredentials.forDefault(); public ServiceBindingBuilder setSourceContext(Context sourceContext) { - this.sourceContext = sourceContext; + this.sourceContext = sourceContext; return this; } diff --git a/binder/src/test/java/io/grpc/binder/internal/SimplePromiseTest.java b/binder/src/test/java/io/grpc/binder/internal/SimplePromiseTest.java new file mode 100644 index 00000000000..6486ff5e8a1 --- /dev/null +++ b/binder/src/test/java/io/grpc/binder/internal/SimplePromiseTest.java @@ -0,0 +1,143 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import io.grpc.binder.internal.SimplePromise.Listener; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.InOrder; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public final class SimplePromiseTest { + + private static final String FULFILLED_VALUE = "a fulfilled value"; + + @Mock private Listener mockListener1; + @Mock private Listener mockListener2; + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + + private SimplePromise promise = new SimplePromise<>(); + + @Before + public void setUp() { + } + + @Test + public void get_beforeFulfilled_throws() { + IllegalStateException e = assertThrows(IllegalStateException.class, () -> promise.get()); + assertThat(e).hasMessageThat().isEqualTo("Not yet set!"); + } + + @Test + public void get_afterFulfilled_returnsValue() { + promise.set(FULFILLED_VALUE); + assertThat(promise.get()).isEqualTo(FULFILLED_VALUE); + } + + @Test + public void set_withNull_throws() { + assertThrows(NullPointerException.class, () -> promise.set(null)); + } + + @Test + public void set_calledTwice_throws() { + promise.set(FULFILLED_VALUE); + IllegalStateException e = + assertThrows(IllegalStateException.class, () -> promise.set("another value")); + assertThat(e).hasMessageThat().isEqualTo("Already set!"); + } + + @Test + public void runWhenSet_beforeFulfill_listenerIsNotifiedUponSet() { + promise.runWhenSet(mockListener1); + + // Should not have been called yet. + verify(mockListener1, never()).notify(FULFILLED_VALUE); + + promise.set(FULFILLED_VALUE); + + // Now it should be called. + verify(mockListener1, times(1)).notify(FULFILLED_VALUE); + } + + @Test + public void runWhenSet_afterSet_listenerIsNotifiedImmediately() { + promise.set(FULFILLED_VALUE); + promise.runWhenSet(mockListener1); + + // Should have been called immediately. + verify(mockListener1, times(1)).notify(FULFILLED_VALUE); + } + + @Test + public void multipleListeners_addedBeforeSet_allNotifiedInOrder() { + promise.runWhenSet(mockListener1); + promise.runWhenSet(mockListener2); + + promise.set(FULFILLED_VALUE); + + InOrder inOrder = inOrder(mockListener1, mockListener2); + inOrder.verify(mockListener1).notify(FULFILLED_VALUE); + inOrder.verify(mockListener2).notify(FULFILLED_VALUE); + } + + @Test + public void listenerThrows_duringSet_propagatesException() { + // A listener that will throw when notified. + Listener throwingListener = + (value) -> { + throw new UnsupportedOperationException("Listener failed"); + }; + + promise.runWhenSet(throwingListener); + + // Fulfilling the promise should now throw the exception from the listener. + UnsupportedOperationException e = + assertThrows(UnsupportedOperationException.class, () -> promise.set(FULFILLED_VALUE)); + assertThat(e).hasMessageThat().isEqualTo("Listener failed"); + } + + @Test + public void listenerThrows_whenAddedAfterSet_propagatesException() { + promise.set(FULFILLED_VALUE); + + // A listener that will throw when notified. + Listener throwingListener = + (value) -> { + throw new UnsupportedOperationException("Listener failed"); + }; + + // Running the listener should throw immediately because the promise is already fulfilled. + UnsupportedOperationException e = + assertThrows( + UnsupportedOperationException.class, () -> promise.runWhenSet(throwingListener)); + assertThat(e).hasMessageThat().isEqualTo("Listener failed"); + } +} diff --git a/binder/src/test/java/io/grpc/binder/internal/TransactionUtilsTest.java b/binder/src/test/java/io/grpc/binder/internal/TransactionUtilsTest.java new file mode 100644 index 00000000000..44a3ce3ef26 --- /dev/null +++ b/binder/src/test/java/io/grpc/binder/internal/TransactionUtilsTest.java @@ -0,0 +1,70 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.binder.internal.TransactionUtils.newCallerFilteringHandler; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import android.os.Binder; +import android.os.Parcel; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; +import org.robolectric.RobolectricTestRunner; +import org.robolectric.shadows.ShadowBinder; + +@RunWith(RobolectricTestRunner.class) +public final class TransactionUtilsTest { + + @Rule public MockitoRule mocks = MockitoJUnit.rule(); + + @Mock LeakSafeOneWayBinder.TransactionHandler mockHandler; + + @Test + public void shouldIgnoreTransactionFromWrongUid() { + Parcel p = Parcel.obtain(); + int originalUid = Binder.getCallingUid(); + try { + when(mockHandler.handleTransaction(eq(1234), same(p))).thenReturn(true); + LeakSafeOneWayBinder.TransactionHandler uid100OnlyHandler = + newCallerFilteringHandler(1000, mockHandler); + + ShadowBinder.setCallingUid(9999); + boolean result = uid100OnlyHandler.handleTransaction(1234, p); + assertThat(result).isFalse(); + verify(mockHandler, never()).handleTransaction(anyInt(), any()); + + ShadowBinder.setCallingUid(1000); + result = uid100OnlyHandler.handleTransaction(1234, p); + assertThat(result).isTrue(); + verify(mockHandler).handleTransaction(1234, p); + } finally { + ShadowBinder.setCallingUid(originalUid); + p.recycle(); + } + } +} diff --git a/binder/src/testFixtures/java/io/grpc/binder/PeerUidTestHelper.java b/binder/src/testFixtures/java/io/grpc/binder/PeerUidTestHelper.java new file mode 100644 index 00000000000..ecd9312b02d --- /dev/null +++ b/binder/src/testFixtures/java/io/grpc/binder/PeerUidTestHelper.java @@ -0,0 +1,57 @@ +package io.grpc.binder; + +import io.grpc.Context; +import io.grpc.Contexts; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; + +/** Class which helps set up {@link PeerUids} to be used in tests. */ +public final class PeerUidTestHelper { + + /** The UID of the calling package is set with the value of this key. */ + public static final Metadata.Key UID_KEY = + Metadata.Key.of("binder-remote-uid-for-unit-testing", PeerUidTestMarshaller.INSTANCE); + + /** + * Creates an interceptor that associates the {@link PeerUids#REMOTE_PEER} key in the request + * {@link Context} with a UID provided by the client in the {@link #UID_KEY} request header, if + * present. + * + *

The returned interceptor works with any gRPC transport but is meant for in-process unit + * testing of gRPC/binder services that depend on {@link PeerUids}. + */ + public static ServerInterceptor newTestPeerIdentifyingServerInterceptor() { + return new ServerInterceptor() { + @Override + public ServerCall.Listener interceptCall( + ServerCall call, Metadata headers, ServerCallHandler next) { + if (headers.containsKey(UID_KEY)) { + Context context = + Context.current().withValue(PeerUids.REMOTE_PEER, new PeerUid(headers.get(UID_KEY))); + return Contexts.interceptCall(context, call, headers, next); + } + return next.startCall(call, headers); + } + }; + } + + private PeerUidTestHelper() {} + + private static class PeerUidTestMarshaller implements Metadata.AsciiMarshaller { + + public static final PeerUidTestMarshaller INSTANCE = new PeerUidTestMarshaller(); + + @Override + public String toAsciiString(Integer value) { + return value.toString(); + } + + @Override + public Integer parseAsciiString(String serialized) { + return Integer.parseInt(serialized); + } + } + ; +} diff --git a/binder/src/testFixtures/java/io/grpc/binder/internal/BinderClientTransportBuilder.java b/binder/src/testFixtures/java/io/grpc/binder/internal/BinderClientTransportBuilder.java new file mode 100644 index 00000000000..f732ff64663 --- /dev/null +++ b/binder/src/testFixtures/java/io/grpc/binder/internal/BinderClientTransportBuilder.java @@ -0,0 +1,61 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import static com.google.common.base.Preconditions.checkNotNull; + +import io.grpc.ChannelLogger; +import io.grpc.internal.ClientTransportFactory.ClientTransportOptions; +import io.grpc.internal.TestUtils.NoopChannelLogger; +import java.net.SocketAddress; + +/** + * Helps unit tests create {@link BinderClientTransport} instances without having to mention + * irrelevant details (go/tott/719). + */ +public class BinderClientTransportBuilder { + private BinderClientTransportFactory factory; + private SocketAddress serverAddress; + private ChannelLogger channelLogger = new NoopChannelLogger(); + private io.grpc.internal.ClientTransportFactory.ClientTransportOptions options = + new ClientTransportOptions(); + + public BinderClientTransportBuilder setServerAddress(SocketAddress serverAddress) { + this.serverAddress = checkNotNull(serverAddress); + return this; + } + + public BinderClientTransportBuilder setChannelLogger(ChannelLogger channelLogger) { + this.channelLogger = checkNotNull(channelLogger); + return this; + } + + public BinderClientTransportBuilder setOptions(ClientTransportOptions options) { + this.options = checkNotNull(options); + return this; + } + + public BinderClientTransportBuilder setFactory(BinderClientTransportFactory factory) { + this.factory = checkNotNull(factory); + return this; + } + + public BinderClientTransport build() { + return factory.newClientTransport( + checkNotNull(serverAddress), checkNotNull(options), checkNotNull(channelLogger)); + } +} diff --git a/binder/src/testFixtures/java/io/grpc/binder/internal/FakeDeadBinder.java b/binder/src/testFixtures/java/io/grpc/binder/internal/FakeDeadBinder.java new file mode 100644 index 00000000000..5bce7498c4b --- /dev/null +++ b/binder/src/testFixtures/java/io/grpc/binder/internal/FakeDeadBinder.java @@ -0,0 +1,74 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import android.os.DeadObjectException; +import android.os.IBinder; +import android.os.IInterface; +import android.os.Parcel; +import android.os.RemoteException; +import java.io.FileDescriptor; + +/** An {@link IBinder} that behaves as if its hosting process has died, for testing. */ +public class FakeDeadBinder implements IBinder { + @Override + public boolean isBinderAlive() { + return false; + } + + @Override + public IInterface queryLocalInterface(String descriptor) { + return null; + } + + @Override + public String getInterfaceDescriptor() throws RemoteException { + throw new DeadObjectException(); + } + + @Override + public boolean pingBinder() { + return false; + } + + @Override + public void dump(FileDescriptor fd, String[] args) throws RemoteException { + throw new DeadObjectException(); + } + + @Override + public void dumpAsync(FileDescriptor fd, String[] args) throws RemoteException { + throw new DeadObjectException(); + } + + @Override + public boolean transact(int code, Parcel data, Parcel reply, int flags) throws RemoteException { + throw new DeadObjectException(); + } + + @Override + public void linkToDeath(DeathRecipient r, int flags) throws RemoteException { + throw new DeadObjectException(); + } + + @Override + public boolean unlinkToDeath(DeathRecipient deathRecipient, int flags) { + // No need to check whether 'deathRecipient' was ever actually passed to linkToDeath(): Per our + // API contract, if "the IBinder has already died" we never throw and always return false. + return false; + } +} diff --git a/binder/src/test/java/io/grpc/binder/MainThreadScheduledExecutorService.java b/binder/src/testFixtures/java/io/grpc/binder/internal/MainThreadScheduledExecutorService.java similarity index 95% rename from binder/src/test/java/io/grpc/binder/MainThreadScheduledExecutorService.java rename to binder/src/testFixtures/java/io/grpc/binder/internal/MainThreadScheduledExecutorService.java index 9429a423ac0..23285b462e6 100644 --- a/binder/src/test/java/io/grpc/binder/MainThreadScheduledExecutorService.java +++ b/binder/src/testFixtures/java/io/grpc/binder/internal/MainThreadScheduledExecutorService.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.grpc.binder; +package io.grpc.binder.internal; import android.os.Handler; import android.os.Looper; @@ -37,7 +37,7 @@ * *

Use {@link org.robolectric.shadows.ShadowLooper#idle()} to run queued work. */ -class MainThreadScheduledExecutorService extends AbstractExecutorService +public class MainThreadScheduledExecutorService extends AbstractExecutorService implements ScheduledExecutorService { private final Handler handler = new Handler(Looper.getMainLooper()); @@ -110,8 +110,7 @@ public void run() { } @Override - public void shutdown() { - } + public void shutdown() {} @Override public List shutdownNow() { @@ -154,8 +153,7 @@ public long getDelay(TimeUnit timeUnit) { @Override public int compareTo(Delayed other) { - return Comparator.comparingLong( - (Delayed delayed) -> delayed.getDelay(TimeUnit.MILLISECONDS)) + return Comparator.comparingLong((Delayed delayed) -> delayed.getDelay(TimeUnit.MILLISECONDS)) .compare(this, other); } diff --git a/binder/src/testFixtures/java/io/grpc/binder/internal/OneWayBinderProxies.java b/binder/src/testFixtures/java/io/grpc/binder/internal/OneWayBinderProxies.java new file mode 100644 index 00000000000..c7eee06e73a --- /dev/null +++ b/binder/src/testFixtures/java/io/grpc/binder/internal/OneWayBinderProxies.java @@ -0,0 +1,181 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.grpc.binder.internal; + +import android.os.RemoteException; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; + +/** A collection of {@link OneWayBinderProxy}-related test helpers. */ +public final class OneWayBinderProxies { + /** + * A {@link OneWayBinderProxy.Decorator} that blocks calling threads while an (external) test + * provides the actual decoration. + */ + public static final class BlockingBinderDecorator + implements OneWayBinderProxy.Decorator { + private final BlockingQueue requests = new LinkedBlockingQueue<>(); + private final BlockingQueue results = new LinkedBlockingQueue<>(); + + /** + * Returns the next {@link OneWayBinderProxy} that needs decorating, blocking if it hasn't yet + * been provided to {@link #decorate}. + * + *

Follow this with a call to {@link #putNextResult(OneWayBinderProxy)} to provide the result + * of {@link #decorate} and unblock the waiting caller. + */ + public OneWayBinderProxy takeNextRequest() throws InterruptedException { + return requests.take(); + } + + /** + * Returns the next {@link OneWayBinderProxy} that needs decorating, blocking for up to the + * specified timeout if it hasn't yet been provided to {@link #decorate}. + * + *

Follow this with a call to {@link #putNextResult(OneWayBinderProxy)} to provide the result + * of {@link #decorate} and unblock the waiting caller. + */ + public OneWayBinderProxy takeNextRequest(long timeout, TimeUnit unit) + throws InterruptedException { + return requests.poll(timeout, unit); + } + + /** Provides the next value to return from {@link #decorate}. */ + public void putNextResult(T next) throws InterruptedException { + results.put(next); + } + + @Override + public OneWayBinderProxy decorate(OneWayBinderProxy in) { + try { + requests.put(in); + return results.take(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + } + } + + /** A {@link OneWayBinderProxy} decorator whose transact method can artificially throw. */ + public static final class ThrowingOneWayBinderProxy extends OneWayBinderProxy { + private final OneWayBinderProxy wrapped; + @Nullable private RemoteException remoteException; + + ThrowingOneWayBinderProxy(OneWayBinderProxy wrapped) { + super(wrapped.getDelegate()); + this.wrapped = wrapped; + } + + /** + * Causes all future invocations of transact to throw `remoteException`. + * + *

Users are responsible for ensuring their calls "happen-before" the relevant calls to + * {@link #transact(int, ParcelHolder)}. + */ + public void setRemoteException(RemoteException remoteException) { + this.remoteException = remoteException; + } + + @Override + public void transact(int code, ParcelHolder data) throws RemoteException { + if (remoteException != null) { + throw remoteException; + } + wrapped.transact(code, data); + } + } + + /** + * A {@link OneWayBinderProxy} decorator whose transact method can be configured to silently drop. + */ + public static final class BlackHoleOneWayBinderProxy extends OneWayBinderProxy { + + private final OneWayBinderProxy wrapped; + private boolean dropAllTransactions; + + BlackHoleOneWayBinderProxy(OneWayBinderProxy wrapped) { + super(wrapped.getDelegate()); + this.wrapped = wrapped; + } + + /** + * Causes all future invocations of transact to be silently dropped. + * + *

Users are responsible for ensuring their calls "happen-before" the relevant calls to + * {@link #transact(int, ParcelHolder)}. + */ + public void dropAllTransactions(boolean dropAllTransactions) { + this.dropAllTransactions = dropAllTransactions; + } + + @Override + public void transact(int code, ParcelHolder data) throws RemoteException { + if (!dropAllTransactions) { + wrapped.transact(code, data); + } + } + } + + /** A {@link OneWayBinderProxy} that queues transactions for a test to deliver manually later. */ + public static final class QueueingOneWayBinderProxy extends OneWayBinderProxy { + public static final class Transaction { + public final int code; + private final ParcelHolder parcel; + + public Transaction(int code, ParcelHolder parcel) { + this.code = code; + this.parcel = parcel; + } + } + + private final BlockingQueue queue = new LinkedBlockingQueue<>(); + private final OneWayBinderProxy wrapped; + + public QueueingOneWayBinderProxy(OneWayBinderProxy wrapped) { + super(wrapped.getDelegate()); + this.wrapped = wrapped; + } + + @Override + public void transact(int code, ParcelHolder data) throws RemoteException { + queue.add(new Transaction(code, new ParcelHolder(data.release()))); + } + + /** + * Returns the next transaction that was queued in order, waiting up to the specified timeout. + */ + public Transaction pollNextTransaction(long timeout, TimeUnit unit) + throws InterruptedException { + return queue.poll(timeout, unit); + } + + /** + * Delivers a previously queued transaction to its original destination. + * + * @throws IllegalStateException if transaction was already delivered once before + */ + public void deliver(Transaction transaction) throws RemoteException { + wrapped.transact(transaction.code, transaction.parcel); + } + } + + // Cannot be instantiated. + private OneWayBinderProxies() {} + ; +} diff --git a/binder/src/testFixtures/java/io/grpc/binder/internal/SettableAsyncSecurityPolicy.java b/binder/src/testFixtures/java/io/grpc/binder/internal/SettableAsyncSecurityPolicy.java new file mode 100644 index 00000000000..2cb22c2fdbf --- /dev/null +++ b/binder/src/testFixtures/java/io/grpc/binder/internal/SettableAsyncSecurityPolicy.java @@ -0,0 +1,83 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import static com.google.common.base.Preconditions.checkState; + +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import io.grpc.Status; +import io.grpc.binder.AsyncSecurityPolicy; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +/** + * An {@link AsyncSecurityPolicy} that lets unit tests verify the exact order of authorization + * requests and respond to them one at a time. + */ +public class SettableAsyncSecurityPolicy extends AsyncSecurityPolicy { + private final LinkedBlockingDeque pendingRequests = new LinkedBlockingDeque<>(); + + @Override + public ListenableFuture checkAuthorizationAsync(int uid) { + AuthRequest request = new AuthRequest(uid); + pendingRequests.add(request); + return request.resultFuture; + } + + /** + * Waits for the next "check authorization" request to be made and returns it, throwing in case no + * request arrives in time. + */ + public AuthRequest takeNextAuthRequest(long timeout, TimeUnit unit) + throws InterruptedException, TimeoutException { + AuthRequest nextAuthRequest = pendingRequests.poll(timeout, unit); + if (nextAuthRequest == null) { + throw new TimeoutException(); + } + return nextAuthRequest; + } + + /** Represents a single call to {@link AsyncSecurityPolicy#checkAuthorizationAsync(int)}. */ + public static class AuthRequest { + + /** The argument passed to {@link AsyncSecurityPolicy#checkAuthorizationAsync(int)}. */ + public final int uid; + + private final SettableFuture resultFuture = SettableFuture.create(); + + private AuthRequest(int uid) { + this.uid = uid; + } + + /** Provides this SecurityPolicy's response to this authorization request. */ + public void setResult(Status result) { + checkState(resultFuture.set(result)); + } + + /** Simulates an exceptional response to this authorization request. */ + public void setResult(Throwable t) { + checkState(resultFuture.setException(t)); + } + + /** Tests if the future returned for this authorization request was cancelled by the caller. */ + public boolean isCancelled() { + return resultFuture.isCancelled(); + } + } +} \ No newline at end of file diff --git a/bom/build.gradle b/bom/build.gradle index 1b1f98cff18..f7f3918372f 100644 --- a/bom/build.gradle +++ b/bom/build.gradle @@ -1,40 +1,32 @@ plugins { + id 'java-platform' id "maven-publish" } description = 'gRPC: BOM' +gradle.projectsEvaluated { + def projectsToInclude = rootProject.subprojects.findAll { + return it.name != 'grpc-compiler' + && it.plugins.hasPlugin('java') + && it.plugins.hasPlugin('maven-publish') + && it.tasks.findByName('publishMavenPublicationToMavenRepository')?.enabled + } + dependencies { + constraints { + projectsToInclude.each { api it } + } + } +} + publishing { publications { maven(MavenPublication) { - // remove all other artifacts since BOM doesn't generates any Jar - artifacts = [] - + from components.javaPlatform pom.withXml { - // Generate bom using subprojects - def internalProjects = [ - project.name, - 'grpc-compiler', - ] - - def dependencyManagement = asNode().appendNode('dependencyManagement') - def dependencies = dependencyManagement.appendNode('dependencies') - rootProject.subprojects.each { subproject -> - if (internalProjects.contains(subproject.name)) { - return - } - if (!subproject.hasProperty('publishMavenPublicationToMavenRepository')) { - return - } - if (!subproject.publishMavenPublicationToMavenRepository.enabled) { - return - } - def dependencyNode = dependencies.appendNode('dependency') - dependencyNode.appendNode('groupId', subproject.group) - dependencyNode.appendNode('artifactId', subproject.name) - dependencyNode.appendNode('version', subproject.version) - } + def dependencies = asNode().dependencyManagement.dependencies.last() // add protoc gen (produced by grpc-compiler with different artifact name) + // not sure how to express "pom" in gradle, kept in XML def dependencyNode = dependencies.appendNode('dependency') dependencyNode.appendNode('groupId', project.group) dependencyNode.appendNode('artifactId', 'protoc-gen-grpc-java') diff --git a/build.gradle b/build.gradle index 1d8876295a2..e65261b0cc4 100644 --- a/build.gradle +++ b/build.gradle @@ -21,13 +21,28 @@ subprojects { apply plugin: "net.ltgt.errorprone" group = "io.grpc" - version = "1.63.0-SNAPSHOT" // CURRENT_GRPC_VERSION + version = "1.82.0-SNAPSHOT" // CURRENT_GRPC_VERSION repositories { maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" } - mavenCentral() - mavenLocal() + url = "https://maven-central.storage-download.googleapis.com/maven2/" + metadataSources { + mavenPom() + ignoreGradleMetadataRedirection() + } + } + mavenCentral() { + metadataSources { + mavenPom() + ignoreGradleMetadataRedirection() + } + } + mavenLocal() { + metadataSources { + mavenPom() + ignoreGradleMetadataRedirection() + } + } } tasks.withType(JavaCompile).configureEach { @@ -136,7 +151,7 @@ subprojects { appendToProperty( it.options.errorprone.excludedPaths, ".*/src/generated/[^/]+/java/.*" + - "|.*/build/generated/source/proto/[^/]+/java/.*", + "|.*/build/generated/sources/proto/[^/]+/java/.*", "|") } } @@ -182,6 +197,25 @@ subprojects { } } + plugins.withId("com.android.base") { + android { + lint { + abortOnError true + if (rootProject.hasProperty('failOnWarnings') && rootProject.failOnWarnings.toBoolean()) { + warningsAsErrors true + } + } + } + tasks.withType(JavaCompile).configureEach { + it.options.compilerArgs += [ + "-Xlint:all" + ] + if (rootProject.hasProperty('failOnWarnings') && rootProject.failOnWarnings.toBoolean()) { + it.options.compilerArgs += ["-Werror"] + } + } + } + plugins.withId("java") { dependencies { testImplementation libraries.junit, @@ -223,12 +257,12 @@ subprojects { // At a test failure, log the stack trace to the console so that we don't // have to open the HTML in a browser. - tasks.named("test").configure { + tasks.withType(Test).configureEach { testLogging { exceptionFormat = 'full' - showExceptions true - showCauses true - showStackTraces true + showExceptions = true + showCauses = true + showStackTraces = true } maxHeapSize = '1500m' } @@ -254,16 +288,23 @@ subprojects { // The warning fails to provide a source location options.errorprone.check("MissingSummary", CheckSeverity.OFF) - // TODO(https://github.com/grpc/grpc-java/issues/10372): remove when fixed. - if (JavaVersion.current().isJava11Compatible()) { - options.errorprone.check("StringCaseLocaleUsage", CheckSeverity.OFF) - } + // This check is in libs.errorprone.corejava8 but has been removed + // in later versions. It isn't smart enough to realize the field is + // actually immutable. And it also doesn't complain about arrays + // that are actually mutable. + options.errorprone.check("MutableConstantField", CheckSeverity.OFF) } tasks.named("compileTestJava").configure { // LinkedList doesn't hurt much in tests and has lots of usages options.errorprone.check("JdkObsolete", CheckSeverity.OFF) options.errorprone.check("PreferJavaTimeOverload", CheckSeverity.OFF) options.errorprone.check("JavaUtilDate", CheckSeverity.OFF) + + // This check is in libs.errorprone.corejava8 but has been removed + // in later versions. It isn't smart enough to realize the field is + // actually immutable. And it also doesn't complain about arrays + // that are actually mutable. + options.errorprone.check("MutableConstantField", CheckSeverity.OFF) } plugins.withId("ru.vyarus.animalsniffer") { @@ -304,7 +345,7 @@ subprojects { } } - plugins.withId("com.github.johnrengelman.shadow") { + plugins.withId("com.gradleup.shadow") { tasks.named("shadowJar").configure { // Do a dance to remove Class-Path. This needs to run after the doFirst() from the // shadow plugin that adds Class-Path and before the core jar action. Using doFirst will @@ -367,11 +408,11 @@ subprojects { url = new File(rootProject.repositoryDir).toURI() } else { String stagingUrl + String baseUrl = "https://ossrh-staging-api.central.sonatype.com/service/local" if (rootProject.hasProperty('repositoryId')) { - stagingUrl = 'https://oss.sonatype.org/service/local/staging/deployByRepositoryId/' + - rootProject.repositoryId + stagingUrl = "${baseUrl}/staging/deployByRepositoryId/" + rootProject.repositoryId } else { - stagingUrl = 'https://oss.sonatype.org/service/local/staging/deploy/maven2/' + stagingUrl = "${baseUrl}/staging/deploy/maven2/" } credentials { if (rootProject.hasProperty('ossrhUsername') && rootProject.hasProperty('ossrhPassword')) { @@ -380,7 +421,7 @@ subprojects { } } def releaseUrl = stagingUrl - def snapshotUrl = 'https://oss.sonatype.org/content/repositories/snapshots/' + def snapshotUrl = 'https://central.sonatype.com/repository/maven-snapshots/' url = version.endsWith('SNAPSHOT') ? snapshotUrl : releaseUrl } } @@ -388,7 +429,7 @@ subprojects { } signing { - required false + required = false sign publishing.publications.maven } @@ -467,8 +508,14 @@ def isAcceptableVersion(ModuleComponentIdentifier candidate) { return true if (group == 'io.netty' && version.contains('Final')) return true + if (group == 'io.undertow' && version.contains('Final')) + return true if (module == 'android-api-level-19') return true + if (module == 'opentelemetry-exporter-prometheus') + return true + if (module == 'opentelemetry-gcp-resources') + return true return version ==~ /^[0-9]+(\.[0-9]+)+$/ } @@ -486,4 +533,5 @@ configurations { } } -tasks.register('checkForUpdates', CheckForUpdatesTask, project.configurations.checkForUpdates, "libs") +tasks.register('checkForUpdates', CheckForUpdatesTask, project.configurations.checkForUpdates, "libs", layout.projectDirectory.file("gradle/libs.versions.toml")) + diff --git a/buildSrc/src/main/java/io/grpc/gradle/CheckForUpdatesTask.java b/buildSrc/src/main/java/io/grpc/gradle/CheckForUpdatesTask.java index 01f702c8167..b7c28dbbb2d 100644 --- a/buildSrc/src/main/java/io/grpc/gradle/CheckForUpdatesTask.java +++ b/buildSrc/src/main/java/io/grpc/gradle/CheckForUpdatesTask.java @@ -16,11 +16,15 @@ package io.grpc.gradle; +import java.io.IOException; +import java.nio.file.Files; import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashSet; +import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; import javax.inject.Inject; import org.gradle.api.DefaultTask; import org.gradle.api.artifacts.Configuration; @@ -28,8 +32,11 @@ import org.gradle.api.artifacts.ModuleVersionIdentifier; import org.gradle.api.artifacts.VersionCatalog; import org.gradle.api.artifacts.VersionCatalogsExtension; +import org.gradle.api.artifacts.result.DependencyResult; import org.gradle.api.artifacts.result.ResolvedComponentResult; import org.gradle.api.artifacts.result.ResolvedDependencyResult; +import org.gradle.api.artifacts.result.UnresolvedDependencyResult; +import org.gradle.api.file.RegularFile; import org.gradle.api.provider.Provider; import org.gradle.api.tasks.Input; import org.gradle.api.tasks.Nested; @@ -43,7 +50,23 @@ public abstract class CheckForUpdatesTask extends DefaultTask { private final Set libraries; @Inject - public CheckForUpdatesTask(Configuration updateConf, String catalog) { + public CheckForUpdatesTask(Configuration updateConf, String catalog, RegularFile commentFile) + throws IOException { + // Check for overrides to the default version selection ('+'), using comments of the form: + // # checkForUpdates: library-name:1.2.+ + List fileComments = Files.lines(commentFile.getAsFile().toPath()) + .filter(l -> l.matches("# *checkForUpdates:.*")) + .map(l -> l.replaceFirst("# *checkForUpdates:", "").trim()) + .collect(Collectors.toList()); + Map aliasToVersionSelector = new HashMap<>(2*fileComments.size()); + for (String comment : fileComments) { + String[] parts = comment.split(":", 2); + String name = parts[0].replaceAll("[_-]", "."); + if (aliasToVersionSelector.put(name, parts[1]) != null) { + throw new RuntimeException("Duplicate checkForUpdates comment for library: " + name); + } + } + updateConf.setVisible(false); updateConf.setTransitive(false); VersionCatalog versionCatalog = getProject().getExtensions().getByType(VersionCatalogsExtension.class).named(catalog); @@ -57,8 +80,12 @@ public CheckForUpdatesTask(Configuration updateConf, String catalog) { oldConf.getDependencies().add(oldDep); Configuration newConf = updateConf.copy(); + String versionSelector = aliasToVersionSelector.remove(name); + if (versionSelector == null) { + versionSelector = "+"; + } Dependency newDep = getProject().getDependencies().create( - depMap(dep.getGroup(), dep.getName(), "+", "pom")); + depMap(dep.getGroup(), dep.getName(), versionSelector, "pom")); newConf.getDependencies().add(newDep); libraries.add(new Library( @@ -66,6 +93,10 @@ public CheckForUpdatesTask(Configuration updateConf, String catalog) { oldConf.getIncoming().getResolutionResult().getRootComponent(), newConf.getIncoming().getResolutionResult().getRootComponent())); } + if (!aliasToVersionSelector.isEmpty()) { + throw new RuntimeException( + "Unused checkForUpdates comments: " + aliasToVersionSelector.keySet()); + } this.libraries = Collections.unmodifiableSet(libraries); } @@ -88,14 +119,26 @@ protected Set getLibraries() { public void checkForUpdates() { for (Library lib : libraries) { String name = lib.getName(); - ModuleVersionIdentifier oldId = ((ResolvedDependencyResult) lib.getOldResult().get() - .getDependencies().iterator().next()).getSelected().getModuleVersion(); - ModuleVersionIdentifier newId = ((ResolvedDependencyResult) lib.getNewResult().get() - .getDependencies().iterator().next()).getSelected().getModuleVersion(); + DependencyResult oldResult = lib.getOldResult().get().getDependencies().iterator().next(); + if (oldResult instanceof UnresolvedDependencyResult) { + System.out.println(String.format( + "- Current version of libs.%s not resolved", name)); + continue; + } + DependencyResult newResult = lib.getNewResult().get().getDependencies().iterator().next(); + if (newResult instanceof UnresolvedDependencyResult) { + System.out.println(String.format( + "- New version of libs.%s not resolved", name)); + continue; + } + ModuleVersionIdentifier oldId = + ((ResolvedDependencyResult) oldResult).getSelected().getModuleVersion(); + ModuleVersionIdentifier newId = + ((ResolvedDependencyResult) newResult).getSelected().getModuleVersion(); if (oldId != newId) { System.out.println(String.format( - "libs.%s = %s:%s %s -> %s", - name, newId.getGroup(), newId.getModule(), oldId.getVersion(), newId.getVersion())); + "libs.%s = %s %s -> %s", + name, newId.getModule(), oldId.getVersion(), newId.getVersion())); } } } diff --git a/buildscripts/checkstyle.xml b/buildscripts/checkstyle.xml index 960fa162ed1..0ec8ecc79ce 100644 --- a/buildscripts/checkstyle.xml +++ b/buildscripts/checkstyle.xml @@ -38,6 +38,12 @@ + + + + + + @@ -202,7 +208,7 @@ - diff --git a/buildscripts/cloudbuild-testing.yaml b/buildscripts/cloudbuild-testing.yaml new file mode 100644 index 00000000000..623b85b6882 --- /dev/null +++ b/buildscripts/cloudbuild-testing.yaml @@ -0,0 +1,64 @@ +substitutions: + _GAE_SERVICE_ACCOUNT: appengine-testing-java@grpc-testing.iam.gserviceaccount.com +options: + env: + - BUILD_ID=$BUILD_ID + - KOKORO_GAE_SERVICE=java-gae-interop-test + - DUMMY_DEFAULT_VERSION=dummy-default + - GRADLE_OPTS=-Dorg.gradle.jvmargs='-Xmx1g' + - GRADLE_FLAGS=-PskipCodegen=true -PskipAndroid=true + logging: CLOUD_LOGGING_ONLY + machineType: E2_HIGHCPU_8 + +steps: +- id: clean-stale-deploys + name: gcr.io/cloud-builders/gcloud + allowFailure: true + script: | + #!/usr/bin/env bash + set -e + echo "Cleaning out stale deploys from previous runs, it is ok if this part fails" + # If the test fails, the deployment is leaked. + # Delete all versions whose name is not 'dummy-default' and is older than 1 hour. + # This expression is an ISO8601 relative date: + # https://cloud.google.com/sdk/gcloud/reference/topic/datetimes + (gcloud app versions list --format="get(version.id)" \ + --filter="service=$KOKORO_GAE_SERVICE AND NOT version : '$DUMMY_DEFAULT_VERSION' AND version.createTime<'-p1h'" \ + | xargs -i gcloud app services delete "$KOKORO_GAE_SERVICE" --version {} --quiet) || true + +- name: gcr.io/cloud-builders/docker + args: ['build', '-t', 'gae-build', 'buildscripts/gae-build/'] + +- id: build + name: gae-build + script: | + #!/usr/bin/env bash + exec ./gradlew $GRADLE_FLAGS :grpc-gae-interop-testing-jdk8:appengineStage + +- id: deploy + name: gcr.io/cloud-builders/gcloud + args: + - app + - deploy + - gae-interop-testing/gae-jdk8/build/staged-app/app.yaml + - --service-account=$_GAE_SERVICE_ACCOUNT + - --no-promote + - --no-stop-previous-version + - --version=cb-$BUILD_ID + +- id: runInteropTestRemote + name: eclipse-temurin:17-jdk + env: + - PROJECT_ID=$PROJECT_ID + script: | + #!/usr/bin/env bash + exec ./gradlew $GRADLE_FLAGS --stacktrace -PgaeDeployVersion="cb-$BUILD_ID" \ + -PgaeProjectId="$PROJECT_ID" :grpc-gae-interop-testing-jdk8:runInteropTestRemote + +- id: cleanup + name: gcr.io/cloud-builders/gcloud + script: | + #!/usr/bin/env bash + set -e + echo "Performing cleanup now." + gcloud app services delete "$KOKORO_GAE_SERVICE" --version "cb-$BUILD_ID" --quiet diff --git a/buildscripts/gae-build/Dockerfile b/buildscripts/gae-build/Dockerfile new file mode 100644 index 00000000000..7e68b270801 --- /dev/null +++ b/buildscripts/gae-build/Dockerfile @@ -0,0 +1,10 @@ +FROM eclipse-temurin:17-jdk + +# The AppEngine Gradle plugin downloads and runs its own gcloud to get the .jar +# to link against, so we need Python even if we use gcloud deploy directly +# instead of using the plugin. +RUN export DEBIAN_FRONTEND=noninteractive && \ + apt-get update && \ + apt-get upgrade -y && \ + apt-get install -y --no-install-recommends python3 && \ + rm -rf /var/lib/apt/lists/* diff --git a/buildscripts/grpc-java-artifacts/Dockerfile b/buildscripts/grpc-java-artifacts/Dockerfile index 47be5c46145..54c595cd960 100644 --- a/buildscripts/grpc-java-artifacts/Dockerfile +++ b/buildscripts/grpc-java-artifacts/Dockerfile @@ -1,17 +1,15 @@ -FROM centos:7.9.2009 +FROM almalinux:8 RUN yum install -y \ autoconf \ automake \ + diffutils \ gcc-c++ \ - gcc-c++.i686 \ glibc-devel \ glibc-devel.i686 \ java-11-openjdk-devel \ libstdc++-devel \ libstdc++-devel.i686 \ - libstdc++-static \ - libstdc++-static.i686 \ libtool \ make \ tar \ @@ -29,7 +27,11 @@ RUN mkdir -p "$ANDROID_HOME/cmdline-tools" && \ mv "$ANDROID_HOME/cmdline-tools/cmdline-tools" "$ANDROID_HOME/cmdline-tools/latest" && \ yes | "$ANDROID_HOME/cmdline-tools/latest/bin/sdkmanager" --licenses +RUN curl -Ls https://github.com/Kitware/CMake/releases/download/v3.26.3/cmake-3.26.3-linux-x86_64.tar.gz | \ + tar xz -C /var/local + # Install Maven -RUN curl -Ls https://repo.maven.apache.org/maven2/org/apache/maven/apache-maven/3.3.9/apache-maven-3.3.9-bin.tar.gz | \ +RUN curl -Ls https://archive.apache.org/dist/maven/maven-3/3.8.8/binaries/apache-maven-3.8.8-bin.tar.gz | \ tar xz -C /var/local -ENV PATH /var/local/apache-maven-3.3.9/bin:$PATH +ENV PATH /var/local/cmake-3.26.3-linux-x86_64/bin:/var/local/apache-maven-3.8.8/bin:$PATH + diff --git a/buildscripts/grpc-java-artifacts/Dockerfile.multiarch.base b/buildscripts/grpc-java-artifacts/Dockerfile.multiarch.base index 8f7cfae2f52..6b670994677 100644 --- a/buildscripts/grpc-java-artifacts/Dockerfile.multiarch.base +++ b/buildscripts/grpc-java-artifacts/Dockerfile.multiarch.base @@ -1,4 +1,7 @@ -FROM ubuntu:18.04 +FROM ubuntu:24.04 + +# Redirect to the internal mirror to bypass the Kokoro network block +RUN sed -i 's|http://archive.ubuntu.com/ubuntu/|http://mirror.bazel.build/archive.ubuntu.com/ubuntu/|g' /etc/apt/sources.list RUN export DEBIAN_FRONTEND=noninteractive && \ apt-get update && \ @@ -9,6 +12,12 @@ RUN export DEBIAN_FRONTEND=noninteractive && \ curl \ g++-aarch64-linux-gnu \ g++-powerpc64le-linux-gnu \ - openjdk-8-jdk \ + openjdk-11-jdk \ + pkgconf \ && \ rm -rf /var/lib/apt/lists/* + +RUN curl -Ls https://github.com/Kitware/CMake/releases/download/v3.26.3/cmake-3.26.3-linux-x86_64.tar.gz | \ + tar xz -C /var/local +ENV PATH /var/local/cmake-3.26.3-linux-x86_64/bin:$PATH + diff --git a/buildscripts/grpc-java-artifacts/Dockerfile.ubuntu2004.base b/buildscripts/grpc-java-artifacts/Dockerfile.ubuntu2004.base index 2d11d76c373..e987fb3e684 100644 --- a/buildscripts/grpc-java-artifacts/Dockerfile.ubuntu2004.base +++ b/buildscripts/grpc-java-artifacts/Dockerfile.ubuntu2004.base @@ -9,5 +9,11 @@ RUN export DEBIAN_FRONTEND=noninteractive && \ curl \ g++-s390x-linux-gnu \ openjdk-8-jdk \ + pkg-config \ && \ rm -rf /var/lib/apt/lists/* + +RUN curl -Ls https://github.com/Kitware/CMake/releases/download/v3.26.3/cmake-3.26.3-linux-x86_64.tar.gz | \ + tar xz -C /var/local +ENV PATH /var/local/cmake-3.26.3-linux-x86_64/bin:$PATH + diff --git a/buildscripts/kokoro/android-interop.sh b/buildscripts/kokoro/android-interop.sh index 9e3b134dfe1..877311daca5 100755 --- a/buildscripts/kokoro/android-interop.sh +++ b/buildscripts/kokoro/android-interop.sh @@ -1,37 +1,32 @@ #!/bin/bash set -exu -o pipefail -if [[ -f /VERSION ]]; then - cat /VERSION -fi - -# Install gRPC and codegen for the Android interop app -# (a composite gradle build can't find protoc-gen-grpc-java) cd github/grpc-java -export GRADLE_OPTS=-Xmx512m -export LDFLAGS=-L/tmp/protobuf/lib -export CXXFLAGS=-I/tmp/protobuf/include -export LD_LIBRARY_PATH=/tmp/protobuf/lib -export OS_NAME=$(uname) - -(yes || true) | "${ANDROID_HOME}/tools/bin/sdkmanager" --licenses - -# Proto deps -buildscripts/make_dependencies.sh +export ANDROID_HOME=/tmp/Android/Sdk +mkdir -p "${ANDROID_HOME}/cmdline-tools" +curl -Ls -o cmdline.zip \ + "https://dl.google.com/android/repository/commandlinetools-linux-9477386_latest.zip" +unzip -qd "${ANDROID_HOME}/cmdline-tools" cmdline.zip +rm cmdline.zip +mv "${ANDROID_HOME}/cmdline-tools/cmdline-tools" "${ANDROID_HOME}/cmdline-tools/latest" +(yes || true) | "${ANDROID_HOME}/cmdline-tools/latest/bin/sdkmanager" --licenses # Build Android with Java 11, this adds it to the PATH sudo update-java-alternatives --set java-1.11.0-openjdk-amd64 # Unset any existing JAVA_HOME env var to stop Gradle from using it unset JAVA_HOME -GRADLE_FLAGS="-Pandroid.useAndroidX=true" +GRADLE_FLAGS="-Pandroid.useAndroidX=true -Dorg.gradle.jvmargs=-Xmx1024m -PskipCodegen=true" ./gradlew $GRADLE_FLAGS :grpc-android-interop-testing:assembleDebug ./gradlew $GRADLE_FLAGS :grpc-android-interop-testing:assembleDebugAndroidTest ./gradlew $GRADLE_FLAGS :grpc-binder:assembleDebugAndroidTest +# To see currently-available virtual devices: +# gcloud firebase test android models list --filter=form=virtual + # Run interop instrumentation tests on Firebase Test Lab gcloud firebase test android run \ --type instrumentation \ @@ -46,9 +41,6 @@ gcloud firebase test android run \ --device model=MediumPhone.arm,version=26,locale=en,orientation=portrait \ --device model=Nexus6P,version=25,locale=en,orientation=portrait \ --device model=Nexus6P,version=24,locale=en,orientation=portrait \ - --device model=Nexus6P,version=23,locale=en,orientation=portrait \ - --device model=Nexus6,version=22,locale=en,orientation=portrait \ - --device model=Nexus6,version=21,locale=en,orientation=portrait # Run binderchannel instrumentation tests on Firebase Test Lab gcloud firebase test android run \ @@ -62,6 +54,3 @@ gcloud firebase test android run \ --device model=MediumPhone.arm,version=26,locale=en,orientation=portrait \ --device model=Nexus6P,version=25,locale=en,orientation=portrait \ --device model=Nexus6P,version=24,locale=en,orientation=portrait \ - --device model=Nexus6P,version=23,locale=en,orientation=portrait \ - --device model=Nexus6,version=22,locale=en,orientation=portrait \ - --device model=Nexus6,version=21,locale=en,orientation=portrait diff --git a/buildscripts/kokoro/android.sh b/buildscripts/kokoro/android.sh index cdf4938b670..677825ae66b 100755 --- a/buildscripts/kokoro/android.sh +++ b/buildscripts/kokoro/android.sh @@ -1,7 +1,6 @@ #!/bin/bash set -exu -o pipefail -cat /VERSION BASE_DIR="$(pwd)" @@ -10,9 +9,6 @@ BASE_DIR="$(pwd)" cd "$BASE_DIR/github/grpc-java" -export LDFLAGS=-L/tmp/protobuf/lib -export CXXFLAGS=-I/tmp/protobuf/include -export LD_LIBRARY_PATH=/tmp/protobuf/lib export OS_NAME=$(uname) cat <> gradle.properties @@ -23,11 +19,26 @@ cat <> gradle.properties org.gradle.jvmargs=-Xmx2048m -XX:MaxMetaspaceSize=1024m EOF -(yes || true) | "${ANDROID_HOME}/tools/bin/sdkmanager" --licenses - +export ANDROID_HOME=/tmp/Android/Sdk +mkdir -p "${ANDROID_HOME}/cmdline-tools" +curl -Ls -o cmdline.zip \ + "https://dl.google.com/android/repository/commandlinetools-linux-9477386_latest.zip" +unzip -qd "${ANDROID_HOME}/cmdline-tools" cmdline.zip +rm cmdline.zip +mv "${ANDROID_HOME}/cmdline-tools/cmdline-tools" "${ANDROID_HOME}/cmdline-tools/latest" +(yes || true) | "${ANDROID_HOME}/cmdline-tools/latest/bin/sdkmanager" --licenses +curl -Ls https://github.com/Kitware/CMake/releases/download/v3.26.3/cmake-3.26.3-linux-x86_64.tar.gz | \ + tar xz -C /tmp +export PATH=/tmp/cmake-3.26.3-linux-x86_64/bin:$PATH + # Proto deps buildscripts/make_dependencies.sh +sudo apt-get update && sudo apt-get install pkg-config +export LDFLAGS="$(PKG_CONFIG_PATH=/tmp/protobuf/lib/pkgconfig pkg-config --libs protobuf)" +export CXXFLAGS="$(PKG_CONFIG_PATH=/tmp/protobuf/lib/pkgconfig pkg-config --cflags protobuf)" +export LD_LIBRARY_PATH=/tmp/protobuf/lib + # Build Android with Java 11, this adds it to the PATH sudo update-java-alternatives --set java-1.11.0-openjdk-amd64 # Unset any existing JAVA_HOME env var to stop Gradle from using it @@ -68,23 +79,18 @@ if [[ -z "${KOKORO_GITHUB_PULL_REQUEST_COMMIT:-}" ]]; then exit 0 fi -# Save a copy of set_github_status.py (it may differ from the base commit) - -SET_GITHUB_STATUS="$TMPDIR/set_github_status.py" -cp "$BASE_DIR/github/grpc-java/buildscripts/set_github_status.py" "$SET_GITHUB_STATUS" - - # Collect APK size and dex count stats for the helloworld example -sudo update-java-alternatives --set java-1.8.0-openjdk-amd64 - HELLO_WORLD_OUTPUT_DIR="$BASE_DIR/github/grpc-java/examples/android/helloworld/app/build/outputs" +# Install dependencies of apkanalyzer +"${ANDROID_HOME}/cmdline-tools/latest/bin/sdkmanager" --install "build-tools;35.0.0" + read -r ignored new_dex_count < \ - <("${ANDROID_HOME}/tools/bin/apkanalyzer" dex references \ + <("${ANDROID_HOME}/cmdline-tools/latest/bin/apkanalyzer" dex references \ "$HELLO_WORLD_OUTPUT_DIR/apk/release/app-release-unsigned.apk") set +x -all_new_methods=`"${ANDROID_HOME}/tools/bin/apkanalyzer" dex packages \ +all_new_methods=`"${ANDROID_HOME}/cmdline-tools/latest/bin/apkanalyzer" dex packages \ --proguard-mapping "$HELLO_WORLD_OUTPUT_DIR/mapping/release/mapping.txt" \ "$HELLO_WORLD_OUTPUT_DIR/apk/release/app-release-unsigned.apk" | grep ^M | cut -f4 | sort` set -x @@ -93,22 +99,20 @@ new_apk_size="$(stat --printf=%s $HELLO_WORLD_OUTPUT_DIR/apk/release/app-release # Get the APK size and dex count stats using the pull request base commit -sudo update-java-alternatives --set java-1.11.0-openjdk-amd64 - cd $BASE_DIR/github/grpc-java ./gradlew clean git checkout HEAD^ ./gradlew --stop # use a new daemon to build the previous commit +GRADLE_FLAGS="${GRADLE_FLAGS} -PskipCodegen=true" # skip codegen for build from previous commit since it wasn't built with --std=c++14 when making this change ./gradlew publishToMavenLocal $GRADLE_FLAGS cd examples/android/helloworld/ ../../gradlew build $GRADLE_FLAGS -sudo update-java-alternatives --set java-1.8.0-openjdk-amd64 read -r ignored old_dex_count < \ - <("${ANDROID_HOME}/tools/bin/apkanalyzer" dex references app/build/outputs/apk/release/app-release-unsigned.apk) + <("${ANDROID_HOME}/cmdline-tools/latest/bin/apkanalyzer" dex references app/build/outputs/apk/release/app-release-unsigned.apk) set +x -all_old_methods=`"${ANDROID_HOME}/tools/bin/apkanalyzer" dex packages --proguard-mapping app/build/outputs/mapping/release/mapping.txt app/build/outputs/apk/release/app-release-unsigned.apk | grep ^M | cut -f4 | sort` +all_old_methods=`"${ANDROID_HOME}/cmdline-tools/latest/bin/apkanalyzer" dex packages --proguard-mapping app/build/outputs/mapping/release/mapping.txt app/build/outputs/apk/release/app-release-unsigned.apk | grep ^M | cut -f4 | sort` set -x old_apk_size="$(stat --printf=%s app/build/outputs/apk/release/app-release-unsigned.apk)" @@ -128,16 +132,19 @@ fi # Update the statuses with the deltas +set +x gsutil cp gs://grpc-testing-secrets/github_credentials/oauth_token.txt ~/ -"$SET_GITHUB_STATUS" \ - --sha1 "$KOKORO_GITHUB_PULL_REQUEST_COMMIT" \ - --state success \ - --description "New DEX reference count: $(printf "%'d" "$new_dex_count") (delta: $(printf "%'d" "$dex_count_delta"))" \ - --context android/dex_diff --oauth_file ~/oauth_token.txt - -"$SET_GITHUB_STATUS" \ - --sha1 "$KOKORO_GITHUB_PULL_REQUEST_COMMIT" \ - --state success \ - --description "New APK size in bytes: $(printf "%'d" "$new_apk_size") (delta: $(printf "%'d" "$apk_size_delta"))" \ - --context android/apk_diff --oauth_file ~/oauth_token.txt +desc="New DEX reference count: $(printf "%'d" "$new_dex_count") (delta: $(printf "%'d" "$dex_count_delta"))" +echo "Setting status: $desc" +curl -f -s -X POST -H "Content-Type: application/json" \ + -H "Authorization: token $(cat ~/oauth_token.txt | tr -d '\n')" \ + -d '{"state": "success", "context": "android/dex_diff", "description": "'"${desc}"'"}' \ + "https://api.github.com/repos/grpc/grpc-java/statuses/${KOKORO_GITHUB_PULL_REQUEST_COMMIT}" + +desc="New APK size in bytes: $(printf "%'d" "$new_apk_size") (delta: $(printf "%'d" "$apk_size_delta"))" +echo "Setting status: $desc" +curl -f -s -X POST -H "Content-Type: application/json" \ + -H "Authorization: token $(cat ~/oauth_token.txt | tr -d '\n')" \ + -d '{"state": "success", "context": "android/apk_diff", "description": "'"${desc}"'"}' \ + "https://api.github.com/repos/grpc/grpc-java/statuses/${KOKORO_GITHUB_PULL_REQUEST_COMMIT}" diff --git a/buildscripts/kokoro/gae-interop.sh b/buildscripts/kokoro/gae-interop.sh deleted file mode 100755 index c4ce56cac52..00000000000 --- a/buildscripts/kokoro/gae-interop.sh +++ /dev/null @@ -1,55 +0,0 @@ -#!/bin/bash - -set -exu -o pipefail -if [[ -f /VERSION ]]; then - cat /VERSION -fi - -KOKORO_GAE_SERVICE="java-gae-interop-test" - -# We deploy as different versions of a single service, this way any stale -# lingering deploys can be easily cleaned up by purging all running versions -# of this service. -KOKORO_GAE_APP_VERSION=$(hostname) - -# A dummy version that can be the recipient of all traffic, so that the kokoro test version can be -# set to 0 traffic. This is a requirement in order to delete it. -DUMMY_DEFAULT_VERSION='dummy-default' - -function cleanup() { - echo "Performing cleanup now." - gcloud app services delete $KOKORO_GAE_SERVICE --version $KOKORO_GAE_APP_VERSION --quiet -} -trap cleanup SIGHUP SIGINT SIGTERM EXIT - -readonly GRPC_JAVA_DIR="$(cd "$(dirname "$0")"/../.. && pwd)" -cd "$GRPC_JAVA_DIR" - -## -## Deploy the dummy 'default' version of the service -## -GRADLE_FLAGS="--stacktrace -DgaeStopPreviousVersion=false -PskipCodegen=true -PskipAndroid=true" -export GRADLE_OPTS="-Dorg.gradle.jvmargs='-Xmx1g'" - -# Deploy the dummy 'default' version. We only require that it exists when cleanup() is called. -# It ok if we race with another run and fail here, because the end result is idempotent. -set +e -if ! gcloud app versions describe "$DUMMY_DEFAULT_VERSION" --service="$KOKORO_GAE_SERVICE"; then - ./gradlew $GRADLE_FLAGS -DgaeDeployVersion="$DUMMY_DEFAULT_VERSION" -DgaePromote=true :grpc-gae-interop-testing-jdk8:appengineDeploy -else - echo "default version already exists: $DUMMY_DEFAULT_VERSION" -fi -set -e - -# Deploy and test the real app (jdk8) -./gradlew $GRADLE_FLAGS -DgaeDeployVersion="$KOKORO_GAE_APP_VERSION" :grpc-gae-interop-testing-jdk8:runInteropTestRemote - -set +e -echo "Cleaning out stale deploys from previous runs, it is ok if this part fails" - -# Sometimes the trap based cleanup fails. -# Delete all versions whose name is not 'dummy-default' and is older than 1 hour. -# This expression is an ISO8601 relative date: -# https://cloud.google.com/sdk/gcloud/reference/topic/datetimes -gcloud app versions list --format="get(version.id)" --filter="service=$KOKORO_GAE_SERVICE AND NOT version : 'dummy-default' AND version.createTime<'-p1h'" | xargs -i gcloud app services delete "$KOKORO_GAE_SERVICE" --version {} --quiet -exit 0 diff --git a/buildscripts/kokoro/linux_aarch64.cfg b/buildscripts/kokoro/linux_aarch64.cfg deleted file mode 100644 index 325d910c5ea..00000000000 --- a/buildscripts/kokoro/linux_aarch64.cfg +++ /dev/null @@ -1,13 +0,0 @@ -# Config file for internal CI - -# Location of the continuous shell script in repository. -build_file: "grpc-java/buildscripts/kokoro/linux_aarch64.sh" -timeout_mins: 60 - -action { - define_artifacts { - regex: "github/grpc-java/**/build/test-results/**/sponge_log.xml" - regex: "github/grpc-java/mvn-artifacts/**" - regex: "github/grpc-java/artifacts/**" - } -} diff --git a/buildscripts/kokoro/linux_aarch64.sh b/buildscripts/kokoro/linux_aarch64.sh deleted file mode 100755 index f4a1292efb5..00000000000 --- a/buildscripts/kokoro/linux_aarch64.sh +++ /dev/null @@ -1,17 +0,0 @@ -#!/bin/bash -set -veux -o pipefail - -if [[ -f /VERSION ]]; then - cat /VERSION -fi - -readonly GRPC_JAVA_DIR="$(cd "$(dirname "$0")"/../.. && pwd)" - -. "$GRPC_JAVA_DIR"/buildscripts/kokoro/kokoro.sh -trap spongify_logs EXIT - -cd github/grpc-java - -buildscripts/qemu_helpers/prepare_qemu.sh - -buildscripts/run_arm64_tests_in_docker.sh diff --git a/buildscripts/kokoro/macos.cfg b/buildscripts/kokoro/macos.cfg index a58691a7102..4c79743692e 100644 --- a/buildscripts/kokoro/macos.cfg +++ b/buildscripts/kokoro/macos.cfg @@ -2,7 +2,7 @@ # Location of the continuous shell script in repository. build_file: "grpc-java/buildscripts/kokoro/macos.sh" -timeout_mins: 45 +timeout_mins: 60 # We always build mvn artifacts. action { diff --git a/buildscripts/kokoro/macos.sh b/buildscripts/kokoro/macos.sh index 97259231ee8..0240c0650f7 100755 --- a/buildscripts/kokoro/macos.sh +++ b/buildscripts/kokoro/macos.sh @@ -1,5 +1,6 @@ #!/bin/bash set -veux -o pipefail +CMAKE_VERSION=3.31.10 if [[ -f /VERSION ]]; then cat /VERSION @@ -7,6 +8,10 @@ fi readonly GRPC_JAVA_DIR="$(cd "$(dirname "$0")"/../.. && pwd)" +DOWNLOAD_DIR=/tmp/source +mkdir -p ${DOWNLOAD_DIR} +curl -Ls https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}-macos-universal.tar.gz | tar xz -C ${DOWNLOAD_DIR} + # We had problems with random tests timing out because it took seconds to do # trivial (ns) operations. The Kokoro Mac machines have 2 cores with 4 logical # threads, so Gradle should be using 4 workers by default. @@ -15,4 +20,9 @@ export GRADLE_FLAGS="${GRADLE_FLAGS:-} --max-workers=2" . "$GRPC_JAVA_DIR"/buildscripts/kokoro/kokoro.sh trap spongify_logs EXIT +brew install --cask temurin@8 +export PATH="$(/usr/libexec/java_home -v"1.8.0")/bin:${DOWNLOAD_DIR}/cmake-${CMAKE_VERSION}-macos-universal/CMake.app/Contents/bin:${PATH}" +export JAVA_HOME="$(/usr/libexec/java_home -v"1.8.0")" +brew install maven + "$GRPC_JAVA_DIR"/buildscripts/kokoro/unix.sh diff --git a/buildscripts/kokoro/psm-cloud-run.cfg b/buildscripts/kokoro/psm-cloud-run.cfg new file mode 100644 index 00000000000..1f2d6da208f --- /dev/null +++ b/buildscripts/kokoro/psm-cloud-run.cfg @@ -0,0 +1,17 @@ +# Config file for internal CI + +# Location of the continuous shell script in repository. +build_file: "grpc-java/buildscripts/kokoro/psm-interop-test-java.sh" +timeout_mins: 240 + +action { + define_artifacts { + regex: "artifacts/**/*sponge_log.xml" + regex: "artifacts/**/*.log" + strip_prefix: "artifacts" + } +} +env_vars { + key: "PSM_TEST_SUITE" + value: "cloud_run" +} diff --git a/buildscripts/kokoro/psm-csm.cfg b/buildscripts/kokoro/psm-csm.cfg new file mode 100644 index 00000000000..6f28dec20a0 --- /dev/null +++ b/buildscripts/kokoro/psm-csm.cfg @@ -0,0 +1,17 @@ +# Config file for internal CI + +# Location of the continuous shell script in repository. +build_file: "grpc-java/buildscripts/kokoro/psm-interop-test-java.sh" +timeout_mins: 120 + +action { + define_artifacts { + regex: "artifacts/**/*sponge_log.xml" + regex: "artifacts/**/*.log" + strip_prefix: "artifacts" + } +} +env_vars { + key: "PSM_TEST_SUITE" + value: "csm" +} diff --git a/buildscripts/kokoro/psm-dualstack.cfg b/buildscripts/kokoro/psm-dualstack.cfg new file mode 100644 index 00000000000..a55d91a95b0 --- /dev/null +++ b/buildscripts/kokoro/psm-dualstack.cfg @@ -0,0 +1,17 @@ +# Config file for internal CI + +# Location of the continuous shell script in repository. +build_file: "grpc-java/buildscripts/kokoro/psm-interop-test-java.sh" +timeout_mins: 240 + +action { + define_artifacts { + regex: "artifacts/**/*sponge_log.xml" + regex: "artifacts/**/*.log" + strip_prefix: "artifacts" + } +} +env_vars { + key: "PSM_TEST_SUITE" + value: "dualstack" +} diff --git a/buildscripts/kokoro/psm-interop-build-java.sh b/buildscripts/kokoro/psm-interop-build-java.sh new file mode 100755 index 00000000000..8c7a970ef23 --- /dev/null +++ b/buildscripts/kokoro/psm-interop-build-java.sh @@ -0,0 +1,90 @@ +#!/usr/bin/env bash +# Copyright 2024 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -eo pipefail + +# This file defines psm::lang::build_docker_images, which is directly called +# from psm_interop_kokoro_lib.sh. + +# Used locally. +readonly BUILD_APP_PATH="interop-testing/build/install/grpc-interop-testing" + +####################################### +# Builds the test app using gradle and smoke-checks its binaries +# Globals: +# SRC_DIR Absolute path to the source repo. +# BUILD_APP_PATH +# Arguments: +# None +# Outputs: +# Writes the output of xds-test-client and xds-test-server --help to stderr +####################################### +_build_java_test_app() { + psm::tools::log "Building Java test app" + cd "${SRC_DIR}" + + set -x + GRADLE_OPTS="-Dorg.gradle.jvmargs='-Xmx1g'" \ + ./gradlew --no-daemon grpc-interop-testing:installDist -x test \ + -PskipCodegen=true -PskipAndroid=true --console=plain + set +x + + psm::tools::log "Test-run grpc-java PSM interop binaries" + psm::tools::run_ignore_exit_code "${SRC_DIR}/${BUILD_APP_PATH}/bin/xds-test-client" --help + psm::tools::run_ignore_exit_code "${SRC_DIR}/${BUILD_APP_PATH}/bin/xds-test-server" --help +} + +####################################### +# Builds test app Docker images and pushes them to GCR +# Globals: +# BUILD_APP_PATH +# SERVER_IMAGE_NAME: Test server Docker image name +# CLIENT_IMAGE_NAME: Test client Docker image name +# GIT_COMMIT: SHA-1 of git commit being built +# TESTING_VERSION: version branch under test, f.e. v1.42.x, master +# Arguments: +# None +# Outputs: +# Writes the output of `gcloud builds submit` to stdout, stderr +####################################### +psm::lang::build_docker_images() { + local java_build_log="${BUILD_LOGS_ROOT}/build-lang-java.log" + _build_java_test_app |& tee "${java_build_log}" + + psm::tools::log "Building Java xDS interop test app Docker images" + local docker_dir="${SRC_DIR}/buildscripts/xds-k8s" + local build_dir + build_dir="$(mktemp -d)" + + # Copy Docker files, log properties, and the test app to the build dir + { + cp -v "${docker_dir}/"*.Dockerfile "${build_dir}" + cp -v "${docker_dir}/"*.properties "${build_dir}" + cp -rv "${SRC_DIR}/${BUILD_APP_PATH}" "${build_dir}" + } >> "${java_build_log}" + + + # cloudbuild.yaml substitution variables + local substitutions="" + substitutions+="_SERVER_IMAGE_NAME=${SERVER_IMAGE_NAME}," + substitutions+="_CLIENT_IMAGE_NAME=${CLIENT_IMAGE_NAME}," + substitutions+="COMMIT_SHA=${GIT_COMMIT}," + substitutions+="BRANCH_NAME=${TESTING_VERSION}," + + # Run Google Cloud Build + gcloud builds submit "${build_dir}" \ + --config="${docker_dir}/cloudbuild.yaml" \ + --substitutions="${substitutions}" \ + | tee -a "${java_build_log}" +} diff --git a/buildscripts/kokoro/psm-interop-test-java.sh b/buildscripts/kokoro/psm-interop-test-java.sh new file mode 100755 index 00000000000..f249d579d82 --- /dev/null +++ b/buildscripts/kokoro/psm-interop-test-java.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash +set -eo pipefail + +# Input parameters to psm:: methods of the install script. +readonly GRPC_LANGUAGE="java" +readonly BUILD_SCRIPT_DIR="$(dirname "$0")" + +# Used locally. +readonly TEST_DRIVER_INSTALL_SCRIPT_URL="https://raw.githubusercontent.com/${TEST_DRIVER_REPO_OWNER:-grpc}/psm-interop/${TEST_DRIVER_BRANCH:-main}/.kokoro/psm_interop_kokoro_lib.sh" + +psm::lang::source_install_lib() { + echo "Sourcing test driver install script from: ${TEST_DRIVER_INSTALL_SCRIPT_URL}" + local install_lib + # Download to a tmp file. + install_lib="$(mktemp -d)/psm_interop_kokoro_lib.sh" + curl -s --retry-connrefused --retry 5 -o "${install_lib}" "${TEST_DRIVER_INSTALL_SCRIPT_URL}" + # Checksum. + if command -v sha256sum &> /dev/null; then + echo "Install script checksum:" + sha256sum "${install_lib}" + fi + source "${install_lib}" +} + +psm::lang::source_install_lib +source "${BUILD_SCRIPT_DIR}/psm-interop-build-${GRPC_LANGUAGE}.sh" +psm::run "${PSM_TEST_SUITE}" diff --git a/buildscripts/kokoro/psm-light.cfg b/buildscripts/kokoro/psm-light.cfg new file mode 100644 index 00000000000..decd179efa3 --- /dev/null +++ b/buildscripts/kokoro/psm-light.cfg @@ -0,0 +1,17 @@ +# Config file for internal CI + +# Location of the continuous shell script in repository. +build_file: "grpc-java/buildscripts/kokoro/psm-interop-test-java.sh" +timeout_mins: 120 + +action { + define_artifacts { + regex: "artifacts/**/*sponge_log.xml" + regex: "artifacts/**/*.log" + strip_prefix: "artifacts" + } +} +env_vars { + key: "PSM_TEST_SUITE" + value: "light" +} diff --git a/buildscripts/kokoro/psm-security.cfg b/buildscripts/kokoro/psm-security.cfg index 508f4dbe9c1..76d2c578597 100644 --- a/buildscripts/kokoro/psm-security.cfg +++ b/buildscripts/kokoro/psm-security.cfg @@ -1,7 +1,7 @@ # Config file for internal CI # Location of the continuous shell script in repository. -build_file: "grpc-java/buildscripts/kokoro/psm-security.sh" +build_file: "grpc-java/buildscripts/kokoro/psm-interop-test-java.sh" timeout_mins: 240 action { @@ -11,3 +11,7 @@ action { strip_prefix: "artifacts" } } +env_vars { + key: "PSM_TEST_SUITE" + value: "security" +} diff --git a/buildscripts/kokoro/psm-security.sh b/buildscripts/kokoro/psm-security.sh deleted file mode 100755 index 651521d8fc6..00000000000 --- a/buildscripts/kokoro/psm-security.sh +++ /dev/null @@ -1,185 +0,0 @@ -#!/usr/bin/env bash -set -eo pipefail - -# Constants -readonly GITHUB_REPOSITORY_NAME="grpc-java" -readonly TEST_DRIVER_INSTALL_SCRIPT_URL="https://raw.githubusercontent.com/${TEST_DRIVER_REPO_OWNER:-grpc}/psm-interop/${TEST_DRIVER_BRANCH:-main}/.kokoro/psm_interop_kokoro_lib.sh" -## xDS test server/client Docker images -readonly SERVER_IMAGE_NAME="gcr.io/grpc-testing/xds-interop/java-server" -readonly CLIENT_IMAGE_NAME="gcr.io/grpc-testing/xds-interop/java-client" -readonly FORCE_IMAGE_BUILD="${FORCE_IMAGE_BUILD:-0}" -readonly BUILD_APP_PATH="interop-testing/build/install/grpc-interop-testing" - -####################################### -# Builds the test app using gradle and smoke-checks its binaries -# Globals: -# SRC_DIR -# BUILD_APP_PATH -# Arguments: -# None -# Outputs: -# Writes the output of xds-test-client and xds-test-server --help to stderr -####################################### -build_java_test_app() { - echo "Building Java test app" - cd "${SRC_DIR}" - GRADLE_OPTS="-Dorg.gradle.jvmargs='-Xmx1g'" \ - ./gradlew --no-daemon grpc-interop-testing:installDist -x test \ - -PskipCodegen=true -PskipAndroid=true --console=plain - - # Test-run binaries - run_ignore_exit_code "${SRC_DIR}/${BUILD_APP_PATH}/bin/xds-test-client" --help - run_ignore_exit_code "${SRC_DIR}/${BUILD_APP_PATH}/bin/xds-test-server" --help -} - -####################################### -# Builds test app Docker images and pushes them to GCR -# Globals: -# BUILD_APP_PATH -# SERVER_IMAGE_NAME: Test server Docker image name -# CLIENT_IMAGE_NAME: Test client Docker image name -# GIT_COMMIT: SHA-1 of git commit being built -# TESTING_VERSION: version branch under test, f.e. v1.42.x, master -# Arguments: -# None -# Outputs: -# Writes the output of `gcloud builds submit` to stdout, stderr -####################################### -build_test_app_docker_images() { - echo "Building Java xDS interop test app Docker images" - local docker_dir="${SRC_DIR}/buildscripts/xds-k8s" - local build_dir - build_dir="$(mktemp -d)" - # Copy Docker files, log properties, and the test app to the build dir - cp -v "${docker_dir}/"*.Dockerfile "${build_dir}" - cp -v "${docker_dir}/"*.properties "${build_dir}" - cp -rv "${SRC_DIR}/${BUILD_APP_PATH}" "${build_dir}" - # Pick a branch name for the built image - local branch_name='experimental' - if is_version_branch "${TESTING_VERSION}"; then - branch_name="${TESTING_VERSION}" - fi - # Run Google Cloud Build - gcloud builds submit "${build_dir}" \ - --config "${docker_dir}/cloudbuild.yaml" \ - --substitutions "_SERVER_IMAGE_NAME=${SERVER_IMAGE_NAME},_CLIENT_IMAGE_NAME=${CLIENT_IMAGE_NAME},COMMIT_SHA=${GIT_COMMIT},BRANCH_NAME=${branch_name}" - # TODO(sergiitk): extra "cosmetic" tags for versioned branches, e.g. v1.34.x - # TODO(sergiitk): do this when adding support for custom configs per version -} - -####################################### -# Builds test app and its docker images unless they already exist -# Globals: -# SERVER_IMAGE_NAME: Test server Docker image name -# CLIENT_IMAGE_NAME: Test client Docker image name -# GIT_COMMIT: SHA-1 of git commit being built -# FORCE_IMAGE_BUILD -# Arguments: -# None -# Outputs: -# Writes the output to stdout, stderr -####################################### -build_docker_images_if_needed() { - # Check if images already exist - server_tags="$(gcloud_gcr_list_image_tags "${SERVER_IMAGE_NAME}" "${GIT_COMMIT}")" - printf "Server image: %s:%s\n" "${SERVER_IMAGE_NAME}" "${GIT_COMMIT}" - echo "${server_tags:-Server image not found}" - - client_tags="$(gcloud_gcr_list_image_tags "${CLIENT_IMAGE_NAME}" "${GIT_COMMIT}")" - printf "Client image: %s:%s\n" "${CLIENT_IMAGE_NAME}" "${GIT_COMMIT}" - echo "${client_tags:-Client image not found}" - - # Build if any of the images are missing, or FORCE_IMAGE_BUILD=1 - if [[ "${FORCE_IMAGE_BUILD}" == "1" || -z "${server_tags}" || -z "${client_tags}" ]]; then - build_java_test_app - build_test_app_docker_images - else - echo "Skipping Java test app build" - fi -} - -####################################### -# Executes the test case -# Globals: -# TEST_DRIVER_FLAGFILE: Relative path to test driver flagfile -# KUBE_CONTEXT: The name of kubectl context with GKE cluster access -# TEST_XML_OUTPUT_DIR: Output directory for the test xUnit XML report -# SERVER_IMAGE_NAME: Test server Docker image name -# CLIENT_IMAGE_NAME: Test client Docker image name -# GIT_COMMIT: SHA-1 of git commit being built -# TESTING_VERSION: version branch under test: used by the framework to -# determine the supported PSM features. -# Arguments: -# Test case name -# Outputs: -# Writes the output of test execution to stdout, stderr -# Test xUnit report to ${TEST_XML_OUTPUT_DIR}/${test_name}/sponge_log.xml -####################################### -run_test() { - # Test driver usage: - # https://github.com/grpc/grpc/tree/master/tools/run_tests/xds_k8s_test_driver#basic-usage - local test_name="${1:?Usage: run_test test_name}" - local out_dir="${TEST_XML_OUTPUT_DIR}/${test_name}" - mkdir -pv "${out_dir}" - set -x - python -m "tests.${test_name}" \ - --flagfile="${TEST_DRIVER_FLAGFILE}" \ - --kube_context="${KUBE_CONTEXT}" \ - --server_image="${SERVER_IMAGE_NAME}:${GIT_COMMIT}" \ - --client_image="${CLIENT_IMAGE_NAME}:${GIT_COMMIT}" \ - --testing_version="${TESTING_VERSION}" \ - --force_cleanup \ - --collect_app_logs \ - --log_dir="${out_dir}" \ - --xml_output_file="${out_dir}/sponge_log.xml" \ - |& tee "${out_dir}/sponge_log.log" -} - -####################################### -# Main function: provision software necessary to execute tests, and run them -# Globals: -# KOKORO_ARTIFACTS_DIR -# GITHUB_REPOSITORY_NAME -# SRC_DIR: Populated with absolute path to the source repo -# TEST_DRIVER_REPO_DIR: Populated with the path to the repo containing -# the test driver -# TEST_DRIVER_FULL_DIR: Populated with the path to the test driver source code -# TEST_DRIVER_FLAGFILE: Populated with relative path to test driver flagfile -# TEST_XML_OUTPUT_DIR: Populated with the path to test xUnit XML report -# GIT_ORIGIN_URL: Populated with the origin URL of git repo used for the build -# GIT_COMMIT: Populated with the SHA-1 of git commit being built -# GIT_COMMIT_SHORT: Populated with the short SHA-1 of git commit being built -# KUBE_CONTEXT: Populated with name of kubectl context with GKE cluster access -# Arguments: -# None -# Outputs: -# Writes the output of test execution to stdout, stderr -####################################### -main() { - local script_dir - script_dir="$(dirname "$0")" - - # Source the test driver from the master branch. - echo "Sourcing test driver install script from: ${TEST_DRIVER_INSTALL_SCRIPT_URL}" - source /dev/stdin <<< "$(curl -s "${TEST_DRIVER_INSTALL_SCRIPT_URL}")" - - activate_gke_cluster GKE_CLUSTER_PSM_SECURITY - - set -x - if [[ -n "${KOKORO_ARTIFACTS_DIR}" ]]; then - kokoro_setup_test_driver "${GITHUB_REPOSITORY_NAME}" - else - local_setup_test_driver "${script_dir}" - fi - build_docker_images_if_needed - # Run tests - cd "${TEST_DRIVER_FULL_DIR}" - local failed_tests=0 - test_suites=("baseline_test" "security_test" "authz_test") - for test in "${test_suites[@]}"; do - run_test $test || (( ++failed_tests )) - done - echo "Failed test suites: ${failed_tests}" -} - -main "$@" diff --git a/buildscripts/kokoro/psm-spiffe.cfg b/buildscripts/kokoro/psm-spiffe.cfg new file mode 100644 index 00000000000..b04d715fca1 --- /dev/null +++ b/buildscripts/kokoro/psm-spiffe.cfg @@ -0,0 +1,17 @@ +# Config file for internal CI + +# Location of the continuous shell script in repository. +build_file: "grpc-java/buildscripts/kokoro/psm-interop-test-java.sh" +timeout_mins: 240 + +action { + define_artifacts { + regex: "artifacts/**/*sponge_log.xml" + regex: "artifacts/**/*.log" + strip_prefix: "artifacts" + } +} +env_vars { + key: "PSM_TEST_SUITE" + value: "spiffe" +} diff --git a/buildscripts/kokoro/unix.sh b/buildscripts/kokoro/unix.sh index 9b1a4054c7e..693768a0270 100755 --- a/buildscripts/kokoro/unix.sh +++ b/buildscripts/kokoro/unix.sh @@ -23,11 +23,6 @@ readonly GRPC_JAVA_DIR="$(cd "$(dirname "$0")"/../.. && pwd)" # cd to the root dir of grpc-java cd $(dirname $0)/../.. -# TODO(zpencer): always make sure we are using Oracle jdk8 -if [[ -f /usr/libexec/java_home ]]; then - JAVA_HOME=$(/usr/libexec/java_home -v"1.8.0") -fi - # ARCH is x86_64 unless otherwise specified. ARCH="${ARCH:-x86_64}" @@ -43,7 +38,13 @@ ARCH="$ARCH" buildscripts/make_dependencies.sh # Set properties via flags, do not pollute gradle.properties GRADLE_FLAGS="${GRADLE_FLAGS:-}" +GRADLE_FLAGS+=" --stacktrace" GRADLE_FLAGS+=" -PtargetArch=$ARCH" + +# For universal binaries on macOS, signal Gradle to use universal flags. +if [[ "$(uname -s)" == "Darwin" ]]; then + GRADLE_FLAGS+=" -PbuildUniversal=true" +fi GRADLE_FLAGS+=" -Pcheckstyle.ignoreFailures=false" GRADLE_FLAGS+=" -PfailOnWarnings=true" GRADLE_FLAGS+=" -PerrorProne=true" @@ -56,9 +57,9 @@ fi export GRADLE_OPTS="-Dorg.gradle.jvmargs='-Xmx1g'" # Make protobuf discoverable by :grpc-compiler -export LD_LIBRARY_PATH=/tmp/protobuf/lib -export LDFLAGS=-L/tmp/protobuf/lib -export CXXFLAGS="-I/tmp/protobuf/include" +export LDFLAGS="$(PKG_CONFIG_PATH=/tmp/protobuf/lib/pkgconfig pkg-config --libs protobuf)" +export CXXFLAGS="$(PKG_CONFIG_PATH=/tmp/protobuf/lib/pkgconfig pkg-config --cflags protobuf)" +export LIBRARY_PATH=/tmp/protobuf/lib ./gradlew grpc-compiler:clean $GRADLE_FLAGS diff --git a/buildscripts/kokoro/windows.cfg b/buildscripts/kokoro/windows.cfg index bdfaa38904f..ec0a3c9ae34 100644 --- a/buildscripts/kokoro/windows.cfg +++ b/buildscripts/kokoro/windows.cfg @@ -2,7 +2,7 @@ # Location of the continuous shell script in repository. build_file: "grpc-java/buildscripts/kokoro/windows.bat" -timeout_mins: 45 +timeout_mins: 90 # We always build mvn artifacts. action { diff --git a/buildscripts/kokoro/windows32.bat b/buildscripts/kokoro/windows32.bat index ffd4d3b99a6..d51beba82f9 100644 --- a/buildscripts/kokoro/windows32.bat +++ b/buildscripts/kokoro/windows32.bat @@ -15,19 +15,21 @@ set ESCWORKSPACE=%WORKSPACE:\=\\% @rem Clear JAVA_HOME to prevent a different Java version from being used set JAVA_HOME= -set PATH=C:\Program Files\OpenJDK\openjdk-11.0.12_7\bin;%PATH% mkdir grpc-java-helper32 cd grpc-java-helper32 -call "%VS140COMNTOOLS%\vsvars32.bat" || exit /b 1 +call "%VS170COMNTOOLS%\..\..\VC\Auxiliary\Build\vcvars32.bat" || exit /b 1 call "%WORKSPACE%\buildscripts\make_dependencies.bat" || exit /b 1 cd "%WORKSPACE%" SET TARGET_ARCH=x86_32 SET FAIL_ON_WARNINGS=true -SET VC_PROTOBUF_LIBS=%ESCWORKSPACE%\\grpc-java-helper32\\protobuf-%PROTOBUF_VER%\\build\\Release -SET VC_PROTOBUF_INCLUDE=%ESCWORKSPACE%\\grpc-java-helper32\\protobuf-%PROTOBUF_VER%\\build\\include +SET PROTOBUF_VER=33.4 +SET PKG_CONFIG_PATH=%ESCWORKSPACE%\\grpc-java-helper32\\protobuf-%PROTOBUF_VER%\\build\\protobuf-%PROTOBUF_VER%\\lib\\pkgconfig +SET VC_PROTOBUF_LIBS=/LIBPATH:%ESCWORKSPACE%\\grpc-java-helper32\\protobuf-%PROTOBUF_VER%\\build\\protobuf-%PROTOBUF_VER%\\lib +SET VC_PROTOBUF_INCLUDE=%ESCWORKSPACE%\\grpc-java-helper32\\protobuf-%PROTOBUF_VER%\\build\\protobuf-%PROTOBUF_VER%\\include +call :Get_Libs SET GRADLE_FLAGS=-PtargetArch=%TARGET_ARCH% -PfailOnWarnings=%FAIL_ON_WARNINGS% -PvcProtobufLibs=%VC_PROTOBUF_LIBS% -PvcProtobufInclude=%VC_PROTOBUF_INCLUDE% -PskipAndroid=true SET GRADLE_OPTS="-Dorg.gradle.jvmargs='-Xmx1g'" @@ -50,3 +52,34 @@ IF NOT %GRADLEEXIT% == 0 ( cmd.exe /C "%WORKSPACE%\gradlew.bat --stop" cmd.exe /C "%WORKSPACE%\gradlew.bat %GRADLE_FLAGS% -Dorg.gradle.parallel=false -PrepositoryDir=%WORKSPACE%\artifacts clean grpc-compiler:build grpc-compiler:publish" || exit /b 1 + +goto :eof +:Get_Libs +SetLocal EnableDelayedExpansion +set "libs_list=" +for /f "tokens=*" %%a in ('pkg-config --libs protobuf') do ( + for %%b in (%%a) do ( + set lib=%%b + set libfirst2char=!lib:~0,2! + if !libfirst2char!==-l ( + @rem remove the leading -l + set lib=!lib:~2! + @rem remove spaces + set lib=!lib: =! + set libprefix=!lib:~0,4! + if !libprefix!==absl ( + set lib=!lib!.lib + ) else ( + set lib=lib!lib!.lib + ) + if "!libs_list!"=="" ( + set libs_list=!lib! + ) else ( + set libs_list=!libs_list!,!lib! + ) + ) + ) +) +EndLocal & set "VC_PROTOBUF_LIBS=%VC_PROTOBUF_LIBS%,%libs_list%" +exit /b 0 + diff --git a/buildscripts/kokoro/windows64.bat b/buildscripts/kokoro/windows64.bat index 8542f1c0536..180025d5e82 100644 --- a/buildscripts/kokoro/windows64.bat +++ b/buildscripts/kokoro/windows64.bat @@ -14,19 +14,21 @@ set ESCWORKSPACE=%WORKSPACE:\=\\% @rem Clear JAVA_HOME to prevent a different Java version from being used set JAVA_HOME= -set PATH=C:\Program Files\OpenJDK\openjdk-11.0.12_7\bin;%PATH% mkdir grpc-java-helper64 cd grpc-java-helper64 -call "%VS140COMNTOOLS%\..\..\VC\bin\amd64\vcvars64.bat" || exit /b 1 +call "%VS170COMNTOOLS%\..\..\VC\Auxiliary\Build\vcvars64.bat" || exit /b 1 call "%WORKSPACE%\buildscripts\make_dependencies.bat" || exit /b 1 cd "%WORKSPACE%" SET TARGET_ARCH=x86_64 SET FAIL_ON_WARNINGS=true -SET VC_PROTOBUF_LIBS=%ESCWORKSPACE%\\grpc-java-helper64\\protobuf-%PROTOBUF_VER%\\build\\Release -SET VC_PROTOBUF_INCLUDE=%ESCWORKSPACE%\\grpc-java-helper64\\protobuf-%PROTOBUF_VER%\\build\\include +SET PROTOBUF_VER=33.4 +SET PKG_CONFIG_PATH=%ESCWORKSPACE%\\grpc-java-helper64\\protobuf-%PROTOBUF_VER%\\build\\protobuf-%PROTOBUF_VER%\\lib\\pkgconfig +SET VC_PROTOBUF_LIBS=/LIBPATH:%ESCWORKSPACE%\\grpc-java-helper64\\protobuf-%PROTOBUF_VER%\\build\\protobuf-%PROTOBUF_VER%\\lib +SET VC_PROTOBUF_INCLUDE=%ESCWORKSPACE%\\grpc-java-helper64\\protobuf-%PROTOBUF_VER%\\build\\protobuf-%PROTOBUF_VER%\\include +call :Get_Libs SET GRADLE_FLAGS=-PtargetArch=%TARGET_ARCH% -PfailOnWarnings=%FAIL_ON_WARNINGS% -PvcProtobufLibs=%VC_PROTOBUF_LIBS% -PvcProtobufInclude=%VC_PROTOBUF_INCLUDE% -PskipAndroid=true SET GRADLE_OPTS="-Dorg.gradle.jvmargs='-Xmx1g'" @@ -34,3 +36,34 @@ SET GRADLE_OPTS="-Dorg.gradle.jvmargs='-Xmx1g'" cmd.exe /C "%WORKSPACE%\gradlew.bat --stop" cmd.exe /C "%WORKSPACE%\gradlew.bat %GRADLE_FLAGS% -Dorg.gradle.parallel=false -PrepositoryDir=%WORKSPACE%\artifacts grpc-compiler:clean grpc-compiler:build grpc-compiler:publish" || exit /b 1 + +goto :eof +:Get_Libs +SetLocal EnableDelayedExpansion +set "libs_list=" +for /f "tokens=*" %%a in ('pkg-config --libs protobuf') do ( + for %%b in (%%a) do ( + set lib=%%b + set libfirst2char=!lib:~0,2! + if !libfirst2char!==-l ( + @rem remove the leading -l + set lib=!lib:~2! + @rem remove spaces + set lib=!lib: =! + set libprefix=!lib:~0,4! + if !libprefix!==absl ( + set lib=!lib!.lib + ) else ( + set lib=lib!lib!.lib + ) + if "!libs_list!"=="" ( + set libs_list=!lib! + ) else ( + set libs_list=!libs_list!,!lib! + ) + ) + ) +) +EndLocal & set "VC_PROTOBUF_LIBS=%VC_PROTOBUF_LIBS%,%libs_list%" +exit /b 0 + diff --git a/buildscripts/kokoro/xds_k8s_lb.cfg b/buildscripts/kokoro/xds_k8s_lb.cfg index 10ea2d43b5d..4dab80bf76e 100644 --- a/buildscripts/kokoro/xds_k8s_lb.cfg +++ b/buildscripts/kokoro/xds_k8s_lb.cfg @@ -1,8 +1,8 @@ # Config file for internal CI # Location of the continuous shell script in repository. -build_file: "grpc-java/buildscripts/kokoro/xds_k8s_lb.sh" -timeout_mins: 180 +build_file: "grpc-java/buildscripts/kokoro/psm-interop-test-java.sh" +timeout_mins: 300 action { define_artifacts { @@ -11,3 +11,7 @@ action { strip_prefix: "artifacts" } } +env_vars { + key: "PSM_TEST_SUITE" + value: "lb" +} diff --git a/buildscripts/kokoro/xds_k8s_lb.sh b/buildscripts/kokoro/xds_k8s_lb.sh deleted file mode 100755 index 0ff85d82e8e..00000000000 --- a/buildscripts/kokoro/xds_k8s_lb.sh +++ /dev/null @@ -1,189 +0,0 @@ -#!/usr/bin/env bash -set -eo pipefail - -# Constants -readonly GITHUB_REPOSITORY_NAME="grpc-java" -readonly TEST_DRIVER_INSTALL_SCRIPT_URL="https://raw.githubusercontent.com/${TEST_DRIVER_REPO_OWNER:-grpc}/psm-interop/${TEST_DRIVER_BRANCH:-main}/.kokoro/psm_interop_kokoro_lib.sh" -## xDS test server/client Docker images -readonly SERVER_IMAGE_NAME="gcr.io/grpc-testing/xds-interop/java-server" -readonly CLIENT_IMAGE_NAME="gcr.io/grpc-testing/xds-interop/java-client" -readonly FORCE_IMAGE_BUILD="${FORCE_IMAGE_BUILD:-0}" -readonly BUILD_APP_PATH="interop-testing/build/install/grpc-interop-testing" - -####################################### -# Builds the test app using gradle and smoke-checks its binaries -# Globals: -# SRC_DIR -# BUILD_APP_PATH -# Arguments: -# None -# Outputs: -# Writes the output of xds-test-client and xds-test-server --help to stderr -####################################### -build_java_test_app() { - echo "Building Java test app" - cd "${SRC_DIR}" - GRADLE_OPTS="-Dorg.gradle.jvmargs='-Xmx1g'" \ - ./gradlew --no-daemon grpc-interop-testing:installDist -x test \ - -PskipCodegen=true -PskipAndroid=true --console=plain - - # Test-run binaries - run_ignore_exit_code "${SRC_DIR}/${BUILD_APP_PATH}/bin/xds-test-client" --help - run_ignore_exit_code "${SRC_DIR}/${BUILD_APP_PATH}/bin/xds-test-server" --help -} - -####################################### -# Builds test app Docker images and pushes them to GCR -# Globals: -# BUILD_APP_PATH -# SERVER_IMAGE_NAME: Test server Docker image name -# CLIENT_IMAGE_NAME: Test client Docker image name -# GIT_COMMIT: SHA-1 of git commit being built -# TESTING_VERSION: version branch under test, f.e. v1.42.x, master -# Arguments: -# None -# Outputs: -# Writes the output of `gcloud builds submit` to stdout, stderr -####################################### -build_test_app_docker_images() { - echo "Building Java xDS interop test app Docker images" - local docker_dir="${SRC_DIR}/buildscripts/xds-k8s" - local build_dir - build_dir="$(mktemp -d)" - # Copy Docker files, log properties, and the test app to the build dir - cp -v "${docker_dir}/"*.Dockerfile "${build_dir}" - cp -v "${docker_dir}/"*.properties "${build_dir}" - cp -rv "${SRC_DIR}/${BUILD_APP_PATH}" "${build_dir}" - # Pick a branch name for the built image - local branch_name='experimental' - if is_version_branch "${TESTING_VERSION}"; then - branch_name="${TESTING_VERSION}" - fi - # Run Google Cloud Build - gcloud builds submit "${build_dir}" \ - --config "${docker_dir}/cloudbuild.yaml" \ - --substitutions "_SERVER_IMAGE_NAME=${SERVER_IMAGE_NAME},_CLIENT_IMAGE_NAME=${CLIENT_IMAGE_NAME},COMMIT_SHA=${GIT_COMMIT},BRANCH_NAME=${branch_name}" - # TODO(sergiitk): extra "cosmetic" tags for versioned branches, e.g. v1.34.x - # TODO(sergiitk): do this when adding support for custom configs per version -} - -####################################### -# Builds test app and its docker images unless they already exist -# Globals: -# SERVER_IMAGE_NAME: Test server Docker image name -# CLIENT_IMAGE_NAME: Test client Docker image name -# GIT_COMMIT: SHA-1 of git commit being built -# FORCE_IMAGE_BUILD -# Arguments: -# None -# Outputs: -# Writes the output to stdout, stderr -####################################### -build_docker_images_if_needed() { - # Check if images already exist - server_tags="$(gcloud_gcr_list_image_tags "${SERVER_IMAGE_NAME}" "${GIT_COMMIT}")" - printf "Server image: %s:%s\n" "${SERVER_IMAGE_NAME}" "${GIT_COMMIT}" - echo "${server_tags:-Server image not found}" - - client_tags="$(gcloud_gcr_list_image_tags "${CLIENT_IMAGE_NAME}" "${GIT_COMMIT}")" - printf "Client image: %s:%s\n" "${CLIENT_IMAGE_NAME}" "${GIT_COMMIT}" - echo "${client_tags:-Client image not found}" - - # Build if any of the images are missing, or FORCE_IMAGE_BUILD=1 - if [[ "${FORCE_IMAGE_BUILD}" == "1" || -z "${server_tags}" || -z "${client_tags}" ]]; then - build_java_test_app - build_test_app_docker_images - else - echo "Skipping Java test app build" - fi -} - -####################################### -# Executes the test case -# Globals: -# TEST_DRIVER_FLAGFILE: Relative path to test driver flagfile -# KUBE_CONTEXT: The name of kubectl context with GKE cluster access -# SECONDARY_KUBE_CONTEXT: The name of kubectl context with secondary GKE cluster access, if any -# TEST_XML_OUTPUT_DIR: Output directory for the test xUnit XML report -# SERVER_IMAGE_NAME: Test server Docker image name -# CLIENT_IMAGE_NAME: Test client Docker image name -# GIT_COMMIT: SHA-1 of git commit being built -# Arguments: -# Test case name -# Outputs: -# Writes the output of test execution to stdout, stderr -# Test xUnit report to ${TEST_XML_OUTPUT_DIR}/${test_name}/sponge_log.xml -####################################### -run_test() { - # Test driver usage: - # https://github.com/grpc/grpc/tree/master/tools/run_tests/xds_k8s_test_driver#basic-usage - local test_name="${1:?Usage: run_test test_name}" - local out_dir="${TEST_XML_OUTPUT_DIR}/${test_name}" - mkdir -pv "${out_dir}" - set -x - python -m "tests.${test_name}" \ - --flagfile="${TEST_DRIVER_FLAGFILE}" \ - --kube_context="${KUBE_CONTEXT}" \ - --secondary_kube_context="${SECONDARY_KUBE_CONTEXT}" \ - --server_image="${SERVER_IMAGE_NAME}:${GIT_COMMIT}" \ - --client_image="${CLIENT_IMAGE_NAME}:${GIT_COMMIT}" \ - --testing_version="${TESTING_VERSION}" \ - --force_cleanup \ - --collect_app_logs \ - --log_dir="${out_dir}" \ - --xml_output_file="${out_dir}/sponge_log.xml" \ - |& tee "${out_dir}/sponge_log.log" -} - -####################################### -# Main function: provision software necessary to execute tests, and run them -# Globals: -# KOKORO_ARTIFACTS_DIR -# GITHUB_REPOSITORY_NAME -# SRC_DIR: Populated with absolute path to the source repo -# TEST_DRIVER_REPO_DIR: Populated with the path to the repo containing -# the test driver -# TEST_DRIVER_FULL_DIR: Populated with the path to the test driver source code -# TEST_DRIVER_FLAGFILE: Populated with relative path to test driver flagfile -# TEST_XML_OUTPUT_DIR: Populated with the path to test xUnit XML report -# GIT_ORIGIN_URL: Populated with the origin URL of git repo used for the build -# GIT_COMMIT: Populated with the SHA-1 of git commit being built -# GIT_COMMIT_SHORT: Populated with the short SHA-1 of git commit being built -# KUBE_CONTEXT: Populated with name of kubectl context with GKE cluster access -# Arguments: -# None -# Outputs: -# Writes the output of test execution to stdout, stderr -####################################### -main() { - local script_dir - script_dir="$(dirname "$0")" - - # Source the test driver from the master branch. - echo "Sourcing test driver install script from: ${TEST_DRIVER_INSTALL_SCRIPT_URL}" - source /dev/stdin <<< "$(curl -s "${TEST_DRIVER_INSTALL_SCRIPT_URL}")" - - activate_gke_cluster GKE_CLUSTER_PSM_LB - activate_secondary_gke_cluster GKE_CLUSTER_PSM_LB - - set -x - if [[ -n "${KOKORO_ARTIFACTS_DIR}" ]]; then - kokoro_setup_test_driver "${GITHUB_REPOSITORY_NAME}" - else - local_setup_test_driver "${script_dir}" - fi - build_docker_images_if_needed - # Run tests - cd "${TEST_DRIVER_FULL_DIR}" - local failed_tests=0 - test_suites=("api_listener_test" "change_backend_service_test" "failover_test" "remove_neg_test" "round_robin_test" "affinity_test" "outlier_detection_test" "custom_lb_test") - if [[ "${TESTING_VERSION}" =~ "master" ]]; then - test_suites+=('bootstrap_generator_test') - fi - for test in "${test_suites[@]}"; do - run_test $test || (( ++failed_tests )) - done - echo "Failed test suites: ${failed_tests}" -} - -main "$@" diff --git a/buildscripts/kokoro/xds_url_map.cfg b/buildscripts/kokoro/xds_url_map.cfg index 1fa6c0141cb..3e27164fe26 100644 --- a/buildscripts/kokoro/xds_url_map.cfg +++ b/buildscripts/kokoro/xds_url_map.cfg @@ -1,7 +1,7 @@ # Config file for internal CI # Location of the continuous shell script in repository. -build_file: "grpc-java/buildscripts/kokoro/xds_url_map.sh" +build_file: "grpc-java/buildscripts/kokoro/psm-interop-test-java.sh" timeout_mins: 90 action { @@ -11,3 +11,7 @@ action { strip_prefix: "artifacts" } } +env_vars { + key: "PSM_TEST_SUITE" + value: "url_map" +} diff --git a/buildscripts/kokoro/xds_url_map.sh b/buildscripts/kokoro/xds_url_map.sh deleted file mode 100755 index 4f160728eb8..00000000000 --- a/buildscripts/kokoro/xds_url_map.sh +++ /dev/null @@ -1,178 +0,0 @@ -#!/usr/bin/env bash -set -eo pipefail - -# Constants -readonly GITHUB_REPOSITORY_NAME="grpc-java" -readonly TEST_DRIVER_INSTALL_SCRIPT_URL="https://raw.githubusercontent.com/${TEST_DRIVER_REPO_OWNER:-grpc}/psm-interop/${TEST_DRIVER_BRANCH:-main}/.kokoro/psm_interop_kokoro_lib.sh" -## xDS test client Docker images -readonly SERVER_IMAGE_NAME="gcr.io/grpc-testing/xds-interop/java-server" -readonly CLIENT_IMAGE_NAME="gcr.io/grpc-testing/xds-interop/java-client" -readonly FORCE_IMAGE_BUILD="${FORCE_IMAGE_BUILD:-0}" -readonly BUILD_APP_PATH="interop-testing/build/install/grpc-interop-testing" - -####################################### -# Builds the test app using gradle and smoke-checks its binaries -# Globals: -# SRC_DIR -# BUILD_APP_PATH -# Arguments: -# None -# Outputs: -# Writes the output of xds-test-client and xds-test-server --help to stderr -####################################### -build_java_test_app() { - echo "Building Java test app" - cd "${SRC_DIR}" - GRADLE_OPTS="-Dorg.gradle.jvmargs='-Xmx1g'" \ - ./gradlew --no-daemon grpc-interop-testing:installDist -x test \ - -PskipCodegen=true -PskipAndroid=true --console=plain - - # Test-run binaries - run_ignore_exit_code "${SRC_DIR}/${BUILD_APP_PATH}/bin/xds-test-client" --help - run_ignore_exit_code "${SRC_DIR}/${BUILD_APP_PATH}/bin/xds-test-server" --help -} - -####################################### -# Builds test app Docker images and pushes them to GCR -# Globals: -# BUILD_APP_PATH -# SERVER_IMAGE_NAME: Test server Docker image name -# CLIENT_IMAGE_NAME: Test client Docker image name -# GIT_COMMIT: SHA-1 of git commit being built -# TESTING_VERSION: version branch under test, f.e. v1.42.x, master -# Arguments: -# None -# Outputs: -# Writes the output of `gcloud builds submit` to stdout, stderr -####################################### -build_test_app_docker_images() { - echo "Building Java xDS interop test app Docker images" - local docker_dir="${SRC_DIR}/buildscripts/xds-k8s" - local build_dir - build_dir="$(mktemp -d)" - # Copy Docker files, log properties, and the test app to the build dir - cp -v "${docker_dir}/"*.Dockerfile "${build_dir}" - cp -v "${docker_dir}/"*.properties "${build_dir}" - cp -rv "${SRC_DIR}/${BUILD_APP_PATH}" "${build_dir}" - # Pick a branch name for the built image - local branch_name='experimental' - if is_version_branch "${TESTING_VERSION}"; then - branch_name="${TESTING_VERSION}" - fi - # Run Google Cloud Build - gcloud builds submit "${build_dir}" \ - --config "${docker_dir}/cloudbuild.yaml" \ - --substitutions "_SERVER_IMAGE_NAME=${SERVER_IMAGE_NAME},_CLIENT_IMAGE_NAME=${CLIENT_IMAGE_NAME},COMMIT_SHA=${GIT_COMMIT},BRANCH_NAME=${branch_name}" - # TODO(sergiitk): extra "cosmetic" tags for versioned branches, e.g. v1.34.x - # TODO(sergiitk): do this when adding support for custom configs per version -} - -####################################### -# Builds test app and its docker images unless they already exist -# Globals: -# SERVER_IMAGE_NAME: Test server Docker image name -# CLIENT_IMAGE_NAME: Test client Docker image name -# GIT_COMMIT: SHA-1 of git commit being built -# FORCE_IMAGE_BUILD -# Arguments: -# None -# Outputs: -# Writes the output to stdout, stderr -####################################### -build_docker_images_if_needed() { - # Check if images already exist - server_tags="$(gcloud_gcr_list_image_tags "${SERVER_IMAGE_NAME}" "${GIT_COMMIT}")" - printf "Server image: %s:%s\n" "${SERVER_IMAGE_NAME}" "${GIT_COMMIT}" - echo "${server_tags:-Server image not found}" - - client_tags="$(gcloud_gcr_list_image_tags "${CLIENT_IMAGE_NAME}" "${GIT_COMMIT}")" - printf "Client image: %s:%s\n" "${CLIENT_IMAGE_NAME}" "${GIT_COMMIT}" - echo "${client_tags:-Client image not found}" - - # Build if any of the images are missing, or FORCE_IMAGE_BUILD=1 - if [[ "${FORCE_IMAGE_BUILD}" == "1" || -z "${server_tags}" || -z "${client_tags}" ]]; then - build_java_test_app - build_test_app_docker_images - else - echo "Skipping Java test app build" - fi -} - -####################################### -# Executes the test case -# Globals: -# TEST_DRIVER_FLAGFILE: Relative path to test driver flagfile -# KUBE_CONTEXT: The name of kubectl context with GKE cluster access -# TEST_XML_OUTPUT_DIR: Output directory for the test xUnit XML report -# CLIENT_IMAGE_NAME: Test client Docker image name -# GIT_COMMIT: SHA-1 of git commit being built -# TESTING_VERSION: version branch under test: used by the framework to -# determine the supported PSM features. -# Arguments: -# Test case name -# Outputs: -# Writes the output of test execution to stdout, stderr -# Test xUnit report to ${TEST_XML_OUTPUT_DIR}/${test_name}/sponge_log.xml -####################################### -run_test() { - # Test driver usage: - # https://github.com/grpc/grpc/tree/master/tools/run_tests/xds_k8s_test_driver#basic-usage - local test_name="${1:?Usage: run_test test_name}" - local out_dir="${TEST_XML_OUTPUT_DIR}/${test_name}" - mkdir -pv "${out_dir}" - set -x - python -m "tests.${test_name}" \ - --flagfile="${TEST_DRIVER_FLAGFILE}" \ - --flagfile="config/url-map.cfg" \ - --kube_context="${KUBE_CONTEXT}" \ - --client_image="${CLIENT_IMAGE_NAME}:${GIT_COMMIT}" \ - --testing_version="${TESTING_VERSION}" \ - --collect_app_logs \ - --log_dir="${out_dir}" \ - --xml_output_file="${out_dir}/sponge_log.xml" \ - |& tee "${out_dir}/sponge_log.log" -} - -####################################### -# Main function: provision software necessary to execute tests, and run them -# Globals: -# KOKORO_ARTIFACTS_DIR -# GITHUB_REPOSITORY_NAME -# SRC_DIR: Populated with absolute path to the source repo -# TEST_DRIVER_REPO_DIR: Populated with the path to the repo containing -# the test driver -# TEST_DRIVER_FULL_DIR: Populated with the path to the test driver source code -# TEST_DRIVER_FLAGFILE: Populated with relative path to test driver flagfile -# TEST_XML_OUTPUT_DIR: Populated with the path to test xUnit XML report -# GIT_ORIGIN_URL: Populated with the origin URL of git repo used for the build -# GIT_COMMIT: Populated with the SHA-1 of git commit being built -# GIT_COMMIT_SHORT: Populated with the short SHA-1 of git commit being built -# KUBE_CONTEXT: Populated with name of kubectl context with GKE cluster access -# Arguments: -# None -# Outputs: -# Writes the output of test execution to stdout, stderr -####################################### -main() { - local script_dir - script_dir="$(dirname "$0")" - - # Source the test driver from the master branch. - echo "Sourcing test driver install script from: ${TEST_DRIVER_INSTALL_SCRIPT_URL}" - source /dev/stdin <<< "$(curl -s "${TEST_DRIVER_INSTALL_SCRIPT_URL}")" - - activate_gke_cluster GKE_CLUSTER_PSM_BASIC - - set -x - if [[ -n "${KOKORO_ARTIFACTS_DIR}" ]]; then - kokoro_setup_test_driver "${GITHUB_REPOSITORY_NAME}" - else - local_setup_test_driver "${script_dir}" - fi - build_docker_images_if_needed - # Run tests - cd "${TEST_DRIVER_FULL_DIR}" - run_test url_map || echo "Failed url_map test" -} - -main "$@" diff --git a/buildscripts/make_dependencies.bat b/buildscripts/make_dependencies.bat index 2bbfd394d46..a11f84d998e 100644 --- a/buildscripts/make_dependencies.bat +++ b/buildscripts/make_dependencies.bat @@ -1,12 +1,16 @@ -set PROTOBUF_VER=21.7 -set CMAKE_NAME=cmake-3.3.2-win32-x86 +choco install -y pkgconfiglite +choco install -y openjdk --version=17.0 +set PATH=%PATH%;"c:\Program Files\OpenJDK\jdk-17\bin" +set PROTOBUF_VER=33.4 +set ABSL_VERSION=20250127.1 +set CMAKE_NAME=cmake-3.26.3-windows-x86_64 if not exist "protobuf-%PROTOBUF_VER%\build\Release\" ( call :installProto || exit /b 1 ) echo Compile gRPC-Java with something like: -echo -PtargetArch=x86_32 -PvcProtobufLibs=%cd%\protobuf-%PROTOBUF_VER%\build\Release -PvcProtobufInclude=%cd%\protobuf-%PROTOBUF_VER%\build\include +echo -PtargetArch=x86_32 -PvcProtobufLibPath=%cd%\protobuf-%PROTOBUF_VER%\build\protobuf-%PROTOBUF_VER%\lib -PvcProtobufInclude=%cd%\protobuf-%PROTOBUF_VER%\build\protobuf-%PROTOBUF_VER%\include -PvcProtobufLibs=insert-list-of-libs-from-pkg-config-output-here goto :eof @@ -20,25 +24,34 @@ if not exist "%CMAKE_NAME%" ( set PATH=%PATH%;%cd%\%CMAKE_NAME%\bin :hasCmake @rem GitHub requires TLSv1.2, and for whatever reason our powershell doesn't have it enabled -powershell -command "$ErrorActionPreference = 'stop'; & { [Net.ServicePointManager]::SecurityProtocol = [Net.SecurityProtocolType]::Tls12 ; iwr https://github.com/google/protobuf/archive/v%PROTOBUF_VER%.zip -OutFile protobuf.zip }" || exit /b 1 +powershell -command "$ProgressPreference = 'SilentlyContinue'; $ErrorActionPreference = 'stop'; & { [Net.ServicePointManager]::SecurityProtocol = [Net.SecurityProtocolType]::Tls12 ; iwr https://github.com/google/protobuf/releases/download/v%PROTOBUF_VER%/protobuf-%PROTOBUF_VER%.zip -OutFile protobuf.zip }" || exit /b 1 powershell -command "$ErrorActionPreference = 'stop'; & { Add-Type -AssemblyName System.IO.Compression.FileSystem; [System.IO.Compression.ZipFile]::ExtractToDirectory('protobuf.zip', '.') }" || exit /b 1 del protobuf.zip +powershell -command "$ProgressPreference = 'SilentlyContinue'; $ErrorActionPreference = 'stop'; & { [Net.ServicePointManager]::SecurityProtocol = [Net.SecurityProtocolType]::Tls12 ; iwr https://github.com/abseil/abseil-cpp/archive/refs/tags/%ABSL_VERSION%.zip -OutFile absl.zip }" || exit /b 1 +powershell -command "$ErrorActionPreference = 'stop'; & { Add-Type -AssemblyName System.IO.Compression.FileSystem; [System.IO.Compression.ZipFile]::ExtractToDirectory('absl.zip', '.') }" || exit /b 1 +del absl.zip +move abseil-cpp-%ABSL_VERSION% protobuf-%PROTOBUF_VER%\third_party\abseil-cpp mkdir protobuf-%PROTOBUF_VER%\build pushd protobuf-%PROTOBUF_VER%\build -@rem Workaround https://github.com/protocolbuffers/protobuf/issues/10174 -powershell -command "(Get-Content ..\cmake\extract_includes.bat.in) -replace '\.\.\\', '' | Out-File -encoding ascii ..\cmake\extract_includes.bat.in" @rem cmake does not detect x86_64 from the vcvars64.bat variables. -@rem If vcvars64.bat has set PLATFORM to X64, then inform cmake to use the Win64 version of VS -if "%PLATFORM%" == "X64" ( - @rem Note the space - SET CMAKE_VSARCH= Win64 +@rem If vcvars64.bat has set PLATFORM to X64, then inform cmake to use the Win64 version of VS, likewise for x32 +if "%PLATFORM%" == "x64" ( + SET CMAKE_VSARCH=-A x64 +) else if "%PLATFORM%" == "x86" ( + @rem -A x86 doesn't work: https://github.com/microsoft/vcpkg/issues/15465 + SET CMAKE_VSARCH=-DCMAKE_GENERATOR_PLATFORM=WIN32 ) else ( SET CMAKE_VSARCH= ) -cmake -Dprotobuf_BUILD_TESTS=OFF -G "Visual Studio %VisualStudioVersion:~0,2%%CMAKE_VSARCH%" .. || exit /b 1 -msbuild /maxcpucount /p:Configuration=Release /verbosity:minimal libprotoc.vcxproj || exit /b 1 -call extract_includes.bat || exit /b 1 +for /f "tokens=4 delims=\" %%a in ("%VCINSTALLDIR%") do ( + SET VC_YEAR=%%a +) +for /f "tokens=1 delims=." %%a in ("%VisualStudioVersion%") do ( + SET visual_studio_major_version=%%a +) +cmake -DCMAKE_CXX_STANDARD=17 -DABSL_MSVC_STATIC_RUNTIME=ON -Dprotobuf_BUILD_TESTS=OFF -DCMAKE_INSTALL_PREFIX=%cd%\protobuf-%PROTOBUF_VER% -DCMAKE_PREFIX_PATH=%cd%\protobuf-%PROTOBUF_VER% -G "Visual Studio %visual_studio_major_version% %VC_YEAR%" %CMAKE_VSARCH% .. || exit /b 1 +cmake --build . --config Release --target install || exit /b 1 popd goto :eof @@ -49,3 +62,4 @@ powershell -command "$ErrorActionPreference = 'stop'; & { iwr https://cmake.org/ powershell -command "$ErrorActionPreference = 'stop'; & { Add-Type -AssemblyName System.IO.Compression.FileSystem; [System.IO.Compression.ZipFile]::ExtractToDirectory('cmake.zip', '.') }" || exit /b 1 del cmake.zip goto :eof + diff --git a/buildscripts/make_dependencies.sh b/buildscripts/make_dependencies.sh index 3d02a72f4eb..8cbefddd2eb 100755 --- a/buildscripts/make_dependencies.sh +++ b/buildscripts/make_dependencies.sh @@ -3,13 +3,63 @@ # Build protoc set -evux -o pipefail -PROTOBUF_VERSION=21.7 +PROTOBUF_VERSION=33.4 +ABSL_VERSION=20250127.1 # ARCH is x86_64 bit unless otherwise specified. ARCH="${ARCH:-x86_64}" DOWNLOAD_DIR=/tmp/source INSTALL_DIR="/tmp/protobuf-cache/$PROTOBUF_VERSION/$(uname -s)-$ARCH" +BUILDSCRIPTS_DIR="$(cd "$(dirname "$0")" && pwd)" + +function build_and_install() { + if [[ "$1" == "abseil" ]]; then + TESTS_OFF_ARG=ABSL_BUILD_TEST_HELPERS + else + TESTS_OFF_ARG=protobuf_BUILD_TESTS + fi + if [[ "$(uname -s)" == "Darwin" ]]; then + cmake .. \ + -DCMAKE_CXX_STANDARD=17 -D${TESTS_OFF_ARG}=OFF -DBUILD_SHARED_LIBS=OFF \ + -DCMAKE_INSTALL_PREFIX="$INSTALL_DIR" \ + -DCMAKE_PREFIX_PATH="$INSTALL_DIR" \ + -DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" \ + -B. || exit 1 + elif [[ "$ARCH" == x86* ]]; then + CFLAGS=-m${ARCH#*_} CXXFLAGS=-m${ARCH#*_} cmake .. \ + -DCMAKE_CXX_STANDARD=17 -D${TESTS_OFF_ARG}=OFF -DBUILD_SHARED_LIBS=OFF \ + -DCMAKE_INSTALL_PREFIX="$INSTALL_DIR" \ + -DCMAKE_PREFIX_PATH="$INSTALL_DIR" \ + -B. || exit 1 + else + if [[ "$ARCH" == aarch_64 ]]; then + GCC_ARCH=aarch64-linux-gnu + elif [[ "$ARCH" == ppcle_64 ]]; then + GCC_ARCH=powerpc64le-linux-gnu + elif [[ "$ARCH" == s390_64 ]]; then + GCC_ARCH=s390x-linux-gnu + elif [[ "$ARCH" == loongarch_64 ]]; then + GCC_ARCH=loongarch64-unknown-linux-gnu + else + echo "Unknown architecture: $ARCH" + exit 1 + fi + cmake .. \ + -DCMAKE_CXX_STANDARD=17 -D${TESTS_OFF_ARG}=OFF -DBUILD_SHARED_LIBS=OFF \ + -DCMAKE_INSTALL_PREFIX="$INSTALL_DIR" \ + -DCMAKE_PREFIX_PATH="$INSTALL_DIR" \ + -Dcrosscompile_ARCH="$GCC_ARCH" \ + -DCMAKE_TOOLCHAIN_FILE=$BUILDSCRIPTS_DIR/toolchain.cmake \ + -B. || exit 1 + fi + export CMAKE_BUILD_PARALLEL_LEVEL="$NUM_CPU" + cmake --build . || exit 1 + # install here so we don't need sudo + cmake --install . || exit 1 +} + mkdir -p $DOWNLOAD_DIR +cd "$DOWNLOAD_DIR" # Start with a sane default NUM_CPU=4 @@ -19,6 +69,7 @@ fi if [[ $(uname) == 'Darwin' ]]; then NUM_CPU=$(sysctl -n hw.ncpu) fi +export CMAKE_BUILD_PARALLEL_LEVEL="$NUM_CPU" # Make protoc # Can't check for presence of directory as cache auto-creates it. @@ -26,28 +77,24 @@ if [ -f ${INSTALL_DIR}/bin/protoc ]; then echo "Not building protobuf. Already built" # TODO(ejona): swap to `brew install --devel protobuf` once it is up-to-date else - if [[ ! -d "$DOWNLOAD_DIR"/protobuf-"${PROTOBUF_VERSION}" ]]; then - curl -Ls https://github.com/google/protobuf/releases/download/v${PROTOBUF_VERSION}/protobuf-all-${PROTOBUF_VERSION}.tar.gz | tar xz -C $DOWNLOAD_DIR - fi - pushd $DOWNLOAD_DIR/protobuf-${PROTOBUF_VERSION} - # install here so we don't need sudo - if [[ "$ARCH" == x86* ]]; then - ./configure CFLAGS=-m${ARCH#*_} CXXFLAGS=-m${ARCH#*_} --disable-shared \ - --prefix="$INSTALL_DIR" - elif [[ "$ARCH" == aarch* ]]; then - ./configure --disable-shared --host=aarch64-linux-gnu --prefix="$INSTALL_DIR" - elif [[ "$ARCH" == ppc* ]]; then - ./configure --disable-shared --host=powerpc64le-linux-gnu --prefix="$INSTALL_DIR" - elif [[ "$ARCH" == s390* ]]; then - ./configure --disable-shared --host=s390x-linux-gnu --prefix="$INSTALL_DIR" - elif [[ "$ARCH" == loongarch* ]]; then - ./configure --disable-shared --host=loongarch64-unknown-linux-gnu --prefix="$INSTALL_DIR" + if [[ ! -d "protobuf-${PROTOBUF_VERSION}" ]]; then + curl -Ls "https://github.com/google/protobuf/releases/download/v${PROTOBUF_VERSION}/protobuf-${PROTOBUF_VERSION}.tar.gz" | tar xz + curl -Ls "https://github.com/abseil/abseil-cpp/archive/refs/tags/${ABSL_VERSION}.tar.gz" | tar xz fi # the same source dir is used for 32 and 64 bit builds, so we need to clean stale data first - make clean - make V=0 -j$NUM_CPU - make install + rm -rf "$DOWNLOAD_DIR/abseil-cpp-${ABSL_VERSION}/build" + mkdir "$DOWNLOAD_DIR/abseil-cpp-${ABSL_VERSION}/build" + pushd "$DOWNLOAD_DIR/abseil-cpp-${ABSL_VERSION}/build" + build_and_install "abseil" + popd + + rm -rf "$DOWNLOAD_DIR/protobuf-${PROTOBUF_VERSION}/build" + mkdir "$DOWNLOAD_DIR/protobuf-${PROTOBUF_VERSION}/build" + pushd "$DOWNLOAD_DIR/protobuf-${PROTOBUF_VERSION}/build" + build_and_install "protobuf" popd + + [ -d "$INSTALL_DIR/lib64" ] && mv "$INSTALL_DIR/lib64" "$INSTALL_DIR/lib" fi # If /tmp/protobuf exists then we just assume it's a symlink created by us. @@ -60,7 +107,9 @@ ln -s "$INSTALL_DIR" /tmp/protobuf cat <> "${grpc_java_dir}/gradle.properties" -skipAndroid=true -skipCodegen=true -org.gradle.parallel=true -org.gradle.jvmargs=-Xmx1024m -EOF - -export JAVA_OPTS="-Duser.home=/grpc-java/.current-user-home -Djava.util.prefs.userRoot=/grpc-java/.current-user-home/.java/.userPrefs" - -# build under x64 docker image to save time over building everything under -# aarch64 emulator. We've already built and tested the protoc binaries -# so for the rest of the build we will be using "-PskipCodegen=true" -# avoid further complicating the build. -docker run $DOCKER_ARGS --rm=true -v "${grpc_java_dir}":/grpc-java -w /grpc-java \ - --user "$(id -u):$(id -g)" -e JAVA_OPTS \ - openjdk:11-jdk-slim-buster \ - ./gradlew build -x test - -# Build and run java tests under aarch64 image. -# To be able to run this docker container on x64 machine, one needs to have -# qemu-user-static properly registered with binfmt_misc. -# The most important flag binfmt_misc flag we need is "F" (set by "--persistent yes"), -# which allows the qemu-aarch64-static binary to be loaded eagerly at the time of registration with binfmt_misc. -# That way, we can emulate aarch64 binaries running inside docker containers transparently, without needing the emulator -# binary to be accessible from the docker image we're emulating. -# Note that on newer distributions (such as glinux), simply "apt install qemu-user-static" is sufficient -# to install qemu-user-static with the right flags. -# A note on the "docker run" args used: -# - run docker container under current user's UID to avoid polluting the workspace -# - set the user.home property to avoid creating a "?" directory under grpc-java -docker run $DOCKER_ARGS --rm=true -v "${grpc_java_dir}":/grpc-java -w /grpc-java \ - --user "$(id -u):$(id -g)" -e JAVA_OPTS \ - arm64v8/openjdk:11-jdk-slim-buster \ - ./gradlew build diff --git a/buildscripts/set_github_status.py b/buildscripts/set_github_status.py deleted file mode 100755 index 09b2ad2ace6..00000000000 --- a/buildscripts/set_github_status.py +++ /dev/null @@ -1,64 +0,0 @@ -#!/usr/bin/env python2.7 -# -# Copyright 2018 The gRPC Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import json -import urllib2 - - -def run(): - argp = argparse.ArgumentParser(description='Set status on pull request') - - argp.add_argument( - '--sha1', type=str, help='SHA1 of the commit', required=True) - argp.add_argument( - '--state', - type=str, - choices=('error', 'failure', 'pending', 'success'), - help='State to set', - required=True) - argp.add_argument( - '--description', type=str, help='Status description', required=True) - argp.add_argument('--context', type=str, help='Status context', required=True) - argp.add_argument( - '--oauth_file', type=str, help='File with OAuth token', required=True) - - args = argp.parse_args() - sha1 = args.sha1 - state = args.state - description = args.description - context = args.context - oauth_file = args.oauth_file - - with open(oauth_file, 'r') as oauth_file_reader: - oauth_token = oauth_file_reader.read().replace('\n', '') - - req = urllib2.Request( - url='https://api.github.com/repos/grpc/grpc-java/statuses/%s' % sha1, - data=json.dumps({ - 'state': state, - 'description': description, - 'context': context, - }), - headers={ - 'Authorization': 'token %s' % oauth_token, - 'Content-Type': 'application/json', - }) - print urllib2.urlopen(req).read() - - -if __name__ == '__main__': - run() diff --git a/buildscripts/sonatype-upload.sh b/buildscripts/sonatype-upload.sh index 16637149126..4baa4e46ca0 100755 --- a/buildscripts/sonatype-upload.sh +++ b/buildscripts/sonatype-upload.sh @@ -59,7 +59,7 @@ if [ -z "$USERNAME" -o -z "$PASSWORD" ]; then exit 1 fi -STAGING_URL="https://oss.sonatype.org/service/local/staging" +STAGING_URL="https://ossrh-staging-api.central.sonatype.com/service/local/staging" # We go through the effort of using deloyByRepositoryId/ because it is # _substantially_ faster to upload files than deploy/maven2/. When using @@ -108,3 +108,18 @@ XML=" " curl --fail-with-body -X POST -d "$XML" -u "$USERPASS" -H "Content-Type: application/xml" \ "$STAGING_URL/profiles/$PROFILE_ID/finish" + +# TODO (okshiva): After 2-3 releases make it automatic. +# After closing the repository on the staging API, we must manually trigger +# its upload to the main Central Publisher Portal. We set publishing_type=automatic +# to have it release automatically upon passing validation. +# echo "Triggering release of repository ${REPOID} to the Central Portal" + +# MANUAL_API_URL="https://ossrh-staging-api.central.sonatype.com/service/local/manual" + +#curl --fail-with-body -X POST \ +# -H "Authorization: Bearer ${USERPASS}" \ +# -H "Content-Type: application/json" \ +# "${MANUAL_API_URL}/upload/repository/${REPOID}?publishing_type=automatic" + +# echo "Release triggered. Monitor progress at https://central.sonatype.com/publishing/deployments" diff --git a/buildscripts/toolchain.cmake b/buildscripts/toolchain.cmake new file mode 100644 index 00000000000..b71515cebda --- /dev/null +++ b/buildscripts/toolchain.cmake @@ -0,0 +1,9 @@ +set(CMAKE_SYSTEM_NAME Linux) + +set(CMAKE_C_COMPILER "${crosscompile_ARCH}-gcc") +set(CMAKE_CXX_COMPILER "${crosscompile_ARCH}-g++") +set(CMAKE_FIND_ROOT_PATH "/usr/${crosscompile_ARCH}/") + +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) diff --git a/buildscripts/xds-k8s/cloudbuild.yaml b/buildscripts/xds-k8s/cloudbuild.yaml index 577ed73ce58..ad4e8049e22 100644 --- a/buildscripts/xds-k8s/cloudbuild.yaml +++ b/buildscripts/xds-k8s/cloudbuild.yaml @@ -16,8 +16,8 @@ steps: - '.' substitutions: - _SERVER_IMAGE_NAME: gcr.io/grpc-testing/xds-interop/java-server - _CLIENT_IMAGE_NAME: gcr.io/grpc-testing/xds-interop/java-client + _SERVER_IMAGE_NAME: us-docker.pkg.dev/grpc-testing/psm-interop/java-server + _CLIENT_IMAGE_NAME: us-docker.pkg.dev/grpc-testing/psm-interop/java-client images: - '${_SERVER_IMAGE_NAME}:${COMMIT_SHA}' diff --git a/census/BUILD.bazel b/census/BUILD.bazel index c0bf29b3f37..f017eeaf8bd 100644 --- a/census/BUILD.bazel +++ b/census/BUILD.bazel @@ -1,3 +1,6 @@ +load("@rules_java//java:defs.bzl", "java_library") +load("@rules_jvm_external//:defs.bzl", "artifact") + java_library( name = "census", srcs = glob([ @@ -5,11 +8,12 @@ java_library( ]), visibility = ["//visibility:public"], deps = [ - "//api", - "//context", - "@com_google_code_findbugs_jsr305//jar", - "@com_google_guava_guava//jar", - "@io_opencensus_opencensus_api//jar", - "@io_opencensus_opencensus_contrib_grpc_metrics//jar", + "//api", + "//context", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:guava"), + artifact("io.opencensus:opencensus-api"), + artifact("io.opencensus:opencensus-contrib-grpc-metrics"), ], ) diff --git a/census/build.gradle b/census/build.gradle index 15b68acbb02..c7cb02c15a0 100644 --- a/census/build.gradle +++ b/census/build.gradle @@ -18,6 +18,7 @@ dependencies { // force dependent jars to depend on latest grpc-context runtimeOnly project(":grpc-context") implementation libraries.guava, + project(":grpc-context"), // Override opencensus dependency with our newer version libraries.opencensus.api, libraries.opencensus.contrib.grpc.metrics @@ -26,12 +27,20 @@ dependencies { project(':grpc-testing'), libraries.opencensus.impl - signature libraries.signature.java - signature libraries.signature.android + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } + signature (libraries.signature.android) { + artifact { + extension = "signature" + } + } } tasks.named("javadoc").configure { - failOnError false // no public or protected classes found to document + failOnError = false // no public or protected classes found to document exclude 'io/grpc/census/internal/**' exclude 'io/grpc/census/Internal*' } diff --git a/census/src/main/java/io/grpc/census/CensusStatsModule.java b/census/src/main/java/io/grpc/census/CensusStatsModule.java index ad16bef9604..8f571ceb627 100644 --- a/census/src/main/java/io/grpc/census/CensusStatsModule.java +++ b/census/src/main/java/io/grpc/census/CensusStatsModule.java @@ -22,6 +22,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Stopwatch; import com.google.common.base.Supplier; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.Channel; @@ -62,7 +63,6 @@ import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; /** * Provides factories for {@link StreamTracer} that records stats to Census. diff --git a/census/src/main/java/io/grpc/census/GrpcCensus.java b/census/src/main/java/io/grpc/census/GrpcCensus.java new file mode 100644 index 00000000000..c564c349ae4 --- /dev/null +++ b/census/src/main/java/io/grpc/census/GrpcCensus.java @@ -0,0 +1,176 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.census; + +import com.google.common.base.Stopwatch; +import com.google.common.base.Supplier; +import io.grpc.ClientInterceptor; +import io.grpc.ExperimentalApi; +import io.grpc.ManagedChannelBuilder; +import io.grpc.ServerBuilder; +import io.grpc.ServerStreamTracer; +import io.opencensus.trace.Tracing; + +/** + * The entrypoint for OpenCensus instrumentation functionality in gRPC. + * + *

GrpcCensus uses {@link io.opencensus.api.OpenCensus} APIs for instrumentation. + * + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/12178") +public final class GrpcCensus { + + private final boolean statsEnabled; + private final boolean tracingEnabled; + + private GrpcCensus(Builder builder) { + this.statsEnabled = builder.statsEnabled; + this.tracingEnabled = builder.tracingEnabled; + } + + /** + * Creates a new builder for {@link GrpcCensus}. + */ + public static Builder newBuilder() { + return new Builder(); + } + + private static final Supplier STOPWATCH_SUPPLIER = new Supplier() { + @Override + public Stopwatch get() { + return Stopwatch.createUnstarted(); + } + }; + + /** + * Configures a {@link ServerBuilder} to enable census stats and tracing. + * + * @param serverBuilder The server builder to configure. + * @return The configured server builder. + */ + public > T configureServerBuilder(T serverBuilder) { + if (statsEnabled) { + serverBuilder.addStreamTracerFactory(newServerStatsStreamTracerFactory()); + } + if (tracingEnabled) { + serverBuilder.addStreamTracerFactory(newServerTracingStreamTracerFactory()); + } + return serverBuilder; + } + + /** + * Configures a {@link ManagedChannelBuilder} to enable census stats and tracing. + * + * @param channelBuilder The channel builder to configure. + * @return The configured channel builder. + */ + public > T configureChannelBuilder(T channelBuilder) { + if (statsEnabled) { + channelBuilder.intercept(newClientStatsInterceptor()); + } + if (tracingEnabled) { + channelBuilder.intercept(newClientTracingInterceptor()); + } + return channelBuilder; + } + + /** + * Returns a {@link ClientInterceptor} with default stats implementation. + */ + private static ClientInterceptor newClientStatsInterceptor() { + CensusStatsModule censusStats = + new CensusStatsModule( + STOPWATCH_SUPPLIER, + true, + true, + true, + false, + true); + return censusStats.getClientInterceptor(); + } + + /** + * Returns a {@link ClientInterceptor} with default tracing implementation. + */ + private static ClientInterceptor newClientTracingInterceptor() { + CensusTracingModule censusTracing = + new CensusTracingModule( + Tracing.getTracer(), + Tracing.getPropagationComponent().getBinaryFormat()); + return censusTracing.getClientInterceptor(); + } + + /** + * Returns a {@link ServerStreamTracer.Factory} with default stats implementation. + */ + private static ServerStreamTracer.Factory newServerStatsStreamTracerFactory() { + CensusStatsModule censusStats = + new CensusStatsModule( + STOPWATCH_SUPPLIER, + true, + true, + true, + false, + true); + return censusStats.getServerTracerFactory(); + } + + /** + * Returns a {@link ServerStreamTracer.Factory} with default tracing implementation. + */ + private static ServerStreamTracer.Factory newServerTracingStreamTracerFactory() { + CensusTracingModule censusTracing = + new CensusTracingModule( + Tracing.getTracer(), + Tracing.getPropagationComponent().getBinaryFormat()); + return censusTracing.getServerTracerFactory(); + } + + /** + * Builder for {@link GrpcCensus}. + */ + public static final class Builder { + private boolean statsEnabled = true; + private boolean tracingEnabled = true; + + private Builder() { + } + + /** + * Disables stats collection. + */ + public Builder disableStats() { + this.statsEnabled = false; + return this; + } + + /** + * Disables tracing. + */ + public Builder disableTracing() { + this.tracingEnabled = false; + return this; + } + + /** + * Builds a new {@link GrpcCensus}. + */ + public GrpcCensus build() { + return new GrpcCensus(this); + } + } +} diff --git a/census/src/test/java/io/grpc/census/CensusModulesTest.java b/census/src/test/java/io/grpc/census/CensusModulesTest.java index 6ccaf78314f..9e0b4d935d3 100644 --- a/census/src/test/java/io/grpc/census/CensusModulesTest.java +++ b/census/src/test/java/io/grpc/census/CensusModulesTest.java @@ -56,6 +56,7 @@ import io.grpc.ClientInterceptors; import io.grpc.ClientStreamTracer; import io.grpc.Context; +import io.grpc.KnownLength; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.ServerCall; @@ -99,6 +100,7 @@ import io.opencensus.trace.Tracer; import io.opencensus.trace.propagation.BinaryFormat; import io.opencensus.trace.propagation.SpanContextParseException; +import java.io.IOException; import java.io.InputStream; import java.util.HashSet; import java.util.List; @@ -136,7 +138,7 @@ public class CensusModulesTest { ClientStreamTracer.StreamInfo.newBuilder() .setCallOptions(CallOptions.DEFAULT.withOption(NAME_RESOLUTION_DELAYED, 10L)).build(); - private static class StringInputStream extends InputStream { + private static class StringInputStream extends InputStream implements KnownLength { final String string; StringInputStream(String string) { @@ -149,6 +151,11 @@ public int read() { // passed to the InProcess server and consumed by MARSHALLER.parse(). throw new UnsupportedOperationException("Should not be called"); } + + @Override + public int available() throws IOException { + return string == null ? 0 : string.length(); + } } private static final MethodDescriptor.Marshaller MARSHALLER = diff --git a/compiler/BUILD.bazel b/compiler/BUILD.bazel index f88075c0e09..a9ffe77a55a 100644 --- a/compiler/BUILD.bazel +++ b/compiler/BUILD.bazel @@ -1,4 +1,6 @@ load("@rules_cc//cc:defs.bzl", "cc_binary") +load("@rules_java//java:defs.bzl", "java_library") +load("@rules_jvm_external//:defs.bzl", "artifact") load("//:java_grpc_library.bzl", "java_rpc_toolchain") # This should not generally be referenced. Users should use java_grpc_library @@ -11,19 +13,20 @@ cc_binary( ], visibility = ["//visibility:public"], deps = [ + "@abseil-cpp//absl/strings", "@com_google_protobuf//:protoc_lib", ], ) java_library( name = "java_grpc_library_deps__do_not_reference", + visibility = ["//xds:__pkg__"], exports = [ "//api", "//protobuf", "//stub", - "//stub:javax_annotation", - "@com_google_code_findbugs_jsr305//jar", - "@com_google_guava_guava//jar", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.guava:guava"), "@com_google_protobuf//:protobuf_java", ], ) @@ -34,9 +37,8 @@ java_library( "//api", "//protobuf-lite", "//stub", - "//stub:javax_annotation", - "@com_google_code_findbugs_jsr305//jar", - "@com_google_guava_guava//jar", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.guava:guava"), ], ) diff --git a/compiler/build.gradle b/compiler/build.gradle index 8bed90b8678..f970f629e19 100644 --- a/compiler/build.gradle +++ b/compiler/build.gradle @@ -76,6 +76,7 @@ model { aarch_64 { architecture "aarch_64" } s390_64 { architecture "s390_64" } loongarch_64 { architecture "loongarch_64" } + riscv_64 { architecture "riscv_64" } } components { @@ -84,6 +85,7 @@ model { 'x86_32', 'x86_64', 'ppcle_64', + 'riscv_64', 'aarch_64', 's390_64', 'loongarch_64' @@ -100,37 +102,45 @@ model { all { if (toolChain in Gcc || toolChain in Clang) { cppCompiler.define("GRPC_VERSION", version) - cppCompiler.args "--std=c++0x" + cppCompiler.args "--std=c++17" addEnvArgs("CXXFLAGS", cppCompiler.args) addEnvArgs("CPPFLAGS", cppCompiler.args) + if (project.hasProperty('buildUniversal') && + project.getProperty('buildUniversal').toBoolean() && + osdetector.os == "osx") { + cppCompiler.args "-arch", "arm64", "-arch", "x86_64" + linker.args "-arch", "arm64", "-arch", "x86_64" + } if (osdetector.os == "osx") { cppCompiler.args "-mmacosx-version-min=10.7", "-stdlib=libc++" + linker.args "-framework", "CoreFoundation" addLibraryIfNotLinked('protoc', linker.args) addLibraryIfNotLinked('protobuf', linker.args) } else if (osdetector.os == "windows") { linker.args "-static", "-lprotoc", "-lprotobuf", "-static-libgcc", "-static-libstdc++", "-s" - } else if (osdetector.arch == "ppcle_64") { - linker.args "-Wl,-Bstatic", "-lprotoc", "-lprotobuf", "-Wl,-Bdynamic", "-lpthread", "-s" - } else { + } else if (osdetector.arch == "ppcle_64") { + linker.args "-Wl,-Bstatic", "-lprotoc", "-lprotobuf", "-Wl,-Bdynamic", "-lpthread", "-s" + } else { // Link protoc, protobuf, libgcc and libstdc++ statically. // Link other (system) libraries dynamically. // Clang under OSX doesn't support these options. linker.args "-Wl,-Bstatic", "-lprotoc", "-lprotobuf", "-static-libgcc", - "-static-libstdc++", "-Wl,-Bdynamic", "-lpthread", "-s" } addEnvArgs("LDFLAGS", linker.args) } else if (toolChain in VisualCpp) { usingVisualCpp = true cppCompiler.define("GRPC_VERSION", version) - cppCompiler.args "/EHsc", "/MT" + cppCompiler.args "/EHsc", "/MT", "/std:c++17" if (rootProject.hasProperty('vcProtobufInclude')) { cppCompiler.args "/I${rootProject.vcProtobufInclude}" - } - linker.args "libprotobuf.lib", "libprotoc.lib" + } + linker.args.add("libprotoc.lib") + linker.args.add("libprotobuf.lib") if (rootProject.hasProperty('vcProtobufLibs')) { - linker.args "/LIBPATH:${rootProject.vcProtobufLibs}" + String libsList = rootProject.property('vcProtobufLibs') as String + libsList.split(',').each() { lib -> linker.args.add(lib) } } } } @@ -145,15 +155,13 @@ sourceSets { dependencies { testImplementation project(':grpc-protobuf'), - project(':grpc-stub'), - libraries.javax.annotation + project(':grpc-stub') testLiteImplementation project(':grpc-protobuf-lite'), - project(':grpc-stub'), - libraries.javax.annotation + project(':grpc-stub') } tasks.named("compileTestJava").configure { - options.errorprone.excludedPaths = ".*/build/generated/source/proto/.*" + options.errorprone.excludedPaths = ".*/build/generated/sources/proto/.*" } tasks.named("compileTestLiteJava").configure { @@ -161,7 +169,7 @@ tasks.named("compileTestLiteJava").configure { options.compilerArgs += [ "-Xlint:-cast" ] - options.errorprone.excludedPaths = ".*/build/generated/source/proto/.*" + options.errorprone.excludedPaths = ".*/build/generated/sources/proto/.*" } tasks.named("checkstyleTestLite").configure { @@ -185,14 +193,20 @@ protobuf { inputs.file javaPluginPath } ofSourceSet('test').configureEach { - plugins { grpc {} } + plugins { + grpc { + option '@generated=javax' + } + } } ofSourceSet('testLite').configureEach { builtins { java { option 'lite' } } plugins { - grpc { option 'lite' } + grpc { + option 'lite' + } } } } @@ -237,9 +251,10 @@ def checkArtifacts = tasks.register("checkArtifacts") { if (ret.exitValue != 0) { throw new GradleException("dumpbin exited with " + ret.exitValue) } - def dlls = os.toString() =~ /Image has the following dependencies:\s+(.*)\s+Summary/ - if (dlls[0][1] != "KERNEL32.dll") { - throw new Exception("unexpected dll deps: " + dlls[0][1]); + def dlls_match_results = os.toString() =~ /Image has the following dependencies:([\S\s]*)Summary/ + def dlls = dlls_match_results[0][1].trim().split("\\s+").sort() + if (dlls != ["KERNEL32.dll", "dbghelp.dll"]) { + throw new Exception("unexpected dll deps: " + dlls); } os.reset() ret = exec { diff --git a/compiler/check-artifact.sh b/compiler/check-artifact.sh index a80207af692..83b41f50282 100755 --- a/compiler/check-artifact.sh +++ b/compiler/check-artifact.sh @@ -86,17 +86,17 @@ checkArch () fi fi elif [[ "$OS" == osx ]]; then - format="$(file -b "$1" | grep -o "[^ ]*$")" - echo Format=$format - if [[ "$ARCH" == x86_32 ]]; then - assertEq "$format" "i386" $LINENO - elif [[ "$ARCH" == x86_64 ]]; then - assertEq "$format" "x86_64" $LINENO - elif [[ "$ARCH" == aarch_64 ]]; then - assertEq "$format" "arm64" $LINENO - else - fail "Unsupported arch: $ARCH" + # For macOS, we now build a universal binary. We check that both + # required architectures are present. + format="$(lipo -archs "$1")" + echo "Architectures found: $format" + if ! echo "$format" | grep -q "x86_64"; then + fail "Universal binary is missing x86_64 architecture." + fi + if ! echo "$format" | grep -q "arm64"; then + fail "Universal binary is missing arm64 architecture." fi + echo "Universal binary check successful." else fail "Unsupported system: $OS" fi @@ -113,23 +113,14 @@ checkDependencies () dump_cmd='objdump -x '"$1"' | fgrep "DLL Name"' white_list="KERNEL32\.dll\|msvcrt\.dll\|USER32\.dll" elif [[ "$OS" == linux ]]; then - dump_cmd='ldd '"$1" + dump_cmd='objdump -x '"$1"' | grep "NEEDED"' + white_list="libpthread\.so\.0\|libstdc++\.so\.6\|libc\.so\.6\|librt\.so\.1\|libm\.so\.6" if [[ "$ARCH" == x86_32 ]]; then - white_list="linux-gate\.so\.1\|libpthread\.so\.0\|libm\.so\.6\|libc\.so\.6\|ld-linux\.so\.2" + white_list="${white_list}\|libm\.so\.6" elif [[ "$ARCH" == x86_64 ]]; then - white_list="linux-vdso\.so\.1\|libpthread\.so\.0\|libm\.so\.6\|libc\.so\.6\|ld-linux-x86-64\.so\.2" + white_list="${white_list}\|libm\.so\.6" elif [[ "$ARCH" == aarch_64 ]]; then - dump_cmd='aarch64-linux-gnu-objdump -x '"$1"' |grep "NEEDED"' - white_list="linux-vdso\.so\.1\|libpthread\.so\.0\|libm\.so\.6\|libc\.so\.6\|ld-linux-aarch64\.so\.1" - elif [[ "$ARCH" == loongarch_64 ]]; then - dump_cmd='objdump -x '"$1"' | grep NEEDED' - white_list="linux-vdso\.so\.1\|libpthread\.so\.0\|libm\.so\.6\|libc\.so\.6\|ld\.so\.1" - elif [[ "$ARCH" == ppcle_64 ]]; then - dump_cmd='powerpc64le-linux-gnu-objdump -x '"$1"' |grep "NEEDED"' - white_list="linux-vdso64\.so\.1\|libpthread\.so\.0\|libm\.so\.6\|libc\.so\.6\|ld64\.so\.2" - elif [[ "$ARCH" == s390_64 ]]; then - dump_cmd='s390x-linux-gnu-objdump -x '"$1"' |grep "NEEDED"' - white_list="linux-vdso64\.so\.1\|libpthread\.so\.0\|libm\.so\.6\|libc\.so\.6\|ld64\.so\.1" + white_list="${white_list}\|ld-linux-aarch64\.so\.1" fi elif [[ "$OS" == osx ]]; then dump_cmd='otool -L '"$1"' | fgrep dylib' diff --git a/compiler/src/java_plugin/cpp/java_generator.cpp b/compiler/src/java_plugin/cpp/java_generator.cpp index 00855df3d04..d0f8cdd13d5 100644 --- a/compiler/src/java_plugin/cpp/java_generator.cpp +++ b/compiler/src/java_plugin/cpp/java_generator.cpp @@ -46,6 +46,7 @@ #include #include #include +#include "absl/strings/escaping.h" #include #include #include @@ -143,11 +144,24 @@ static std::set java_keywords = { "false", }; +// Methods on java.lang.Object that take no arguments. +static std::set java_object_methods = { + "clone", + "finalize", + "getClass", + "hashCode", + "notify", + "notifyAll", + "toString", + "wait", +}; + // Adjust a method name prefix identifier to follow the JavaBean spec: // - decapitalize the first letter // - remove embedded underscores & capitalize the following letter -// Finally, if the result is a reserved java keyword, append an underscore. -static std::string MixedLower(const std::string& word) { +// Finally, if the result is a reserved java keyword or an Object method, +// append an underscore. +static std::string MixedLower(std::string word, bool mangle_object_methods = false) { std::string w; w += tolower(word[0]); bool after_underscore = false; @@ -159,7 +173,9 @@ static std::string MixedLower(const std::string& word) { after_underscore = false; } } - if (java_keywords.find(w) != java_keywords.end()) { + if (java_keywords.find(w) != java_keywords.end() || + (mangle_object_methods && + java_object_methods.find(w) != java_object_methods.end())) { return w + "_"; } return w; @@ -169,7 +185,7 @@ static std::string MixedLower(const std::string& word) { // - An underscore is inserted where a lower case letter is followed by an // upper case letter. // - All letters are converted to upper case -static std::string ToAllUpperCase(const std::string& word) { +static std::string ToAllUpperCase(std::string word) { std::string w; for (size_t i = 0; i < word.length(); ++i) { w += toupper(word[i]); @@ -180,24 +196,25 @@ static std::string ToAllUpperCase(const std::string& word) { return w; } -static inline std::string LowerMethodName(const MethodDescriptor* method) { - return MixedLower(method->name()); +static inline std::string LowerMethodName(const MethodDescriptor* method, + bool mangle_object_methods = false) { + return MixedLower(std::string(method->name()), mangle_object_methods); } static inline std::string MethodPropertiesFieldName(const MethodDescriptor* method) { - return "METHOD_" + ToAllUpperCase(method->name()); + return "METHOD_" + ToAllUpperCase(std::string(method->name())); } static inline std::string MethodPropertiesGetterName(const MethodDescriptor* method) { - return MixedLower("get_" + method->name() + "_method"); + return MixedLower("get_" + std::string(method->name()) + "_method"); } static inline std::string MethodIdFieldName(const MethodDescriptor* method) { - return "METHODID_" + ToAllUpperCase(method->name()); + return "METHODID_" + ToAllUpperCase(std::string(method->name())); } static inline std::string MessageFullJavaName(const Descriptor* desc) { - return protobuf::compiler::java::ClassName(desc); + return protobuf::compiler::java::QualifiedClassName(desc); } // TODO(nmittler): Remove once protobuf includes javadoc methods in distribution. @@ -355,13 +372,15 @@ enum StubType { BLOCKING_CLIENT_IMPL = 5, FUTURE_CLIENT_IMPL = 6, ABSTRACT_CLASS = 7, - NONE = 8, + BLOCKING_V2_CLIENT_IMPL = 8, + NONE = 999, }; enum CallType { ASYNC_CALL = 0, BLOCKING_CALL = 1, - FUTURE_CALL = 2 + FUTURE_CALL = 2, + BLOCKING_V2_CALL = 3, }; // TODO(nmittler): Remove once protobuf includes javadoc methods in distribution. @@ -404,12 +423,15 @@ static void GrpcWriteServiceDocComment(Printer* printer, StubType type) { printer->Print("/**\n"); - std::map vars = {{"service", service->name()}}; + std::map vars = {{"service", std::string(service->name())}}; switch (type) { case ASYNC_CLIENT_IMPL: printer->Print(vars, " * A stub to allow clients to do asynchronous rpc calls to service $service$.\n"); break; case BLOCKING_CLIENT_IMPL: + printer->Print(vars, " * A stub to allow clients to do limited synchronous rpc calls to service $service$.\n"); + break; + case BLOCKING_V2_CLIENT_IMPL: printer->Print(vars, " * A stub to allow clients to do synchronous rpc calls to service $service$.\n"); break; case FUTURE_CLIENT_IMPL: @@ -515,7 +537,8 @@ static void PrintMethodFields( " .setResponseMarshaller($ProtoUtils$.marshaller(\n" " $output_type$.getDefaultInstance()))\n"); - (*vars)["proto_method_descriptor_supplier"] = service->name() + "MethodDescriptorSupplier"; + (*vars)["proto_method_descriptor_supplier"] + = std::string(service->name()) + "MethodDescriptorSupplier"; if (flavor == ProtoFlavor::NORMAL) { p->Print( *vars, @@ -555,6 +578,9 @@ static void PrintStubFactory( case BLOCKING_CLIENT_IMPL: stub_type_name = "Blocking"; break; + case BLOCKING_V2_CLIENT_IMPL: + stub_type_name = "BlockingV2"; + break; default: GRPC_CODEGEN_FAIL << "Cannot generate StubFactory for StubType: " << type; } @@ -575,7 +601,7 @@ static void PrintStub( const ServiceDescriptor* service, std::map* vars, Printer* p, StubType type) { - const std::string service_name = service->name(); + std::string service_name = std::string(service->name()); (*vars)["service_name"] = service_name; std::string stub_name = service_name; std::string stub_base_class_name = "AbstractStub"; @@ -597,6 +623,11 @@ static void PrintStub( stub_name += "BlockingStub"; stub_base_class_name = "AbstractBlockingStub"; break; + case BLOCKING_V2_CLIENT_IMPL: + call_type = BLOCKING_V2_CALL; + stub_name += "BlockingV2Stub"; + stub_base_class_name = "AbstractBlockingStub"; + break; case FUTURE_CLIENT_IMPL: call_type = FUTURE_CALL; stub_name += "FutureStub"; @@ -662,10 +693,12 @@ static void PrintStub( const MethodDescriptor* method = service->method(i); (*vars)["input_type"] = MessageFullJavaName(method->input_type()); (*vars)["output_type"] = MessageFullJavaName(method->output_type()); - (*vars)["lower_method_name"] = LowerMethodName(method); - (*vars)["method_method_name"] = MethodPropertiesGetterName(method); bool client_streaming = method->client_streaming(); bool server_streaming = method->server_streaming(); + bool mangle_object_methods = (call_type == BLOCKING_V2_CALL && client_streaming) + || (call_type == BLOCKING_CALL && client_streaming && server_streaming); + (*vars)["lower_method_name"] = LowerMethodName(method, mangle_object_methods); + (*vars)["method_method_name"] = MethodPropertiesGetterName(method); if (call_type == BLOCKING_CALL && client_streaming) { // Blocking client interface with client streaming is not available @@ -679,13 +712,17 @@ static void PrintStub( // Method signature p->Print("\n"); - // TODO(nmittler): Replace with WriteMethodDocComment once included by the protobuf distro. GrpcWriteMethodDocComment(p, method); if (method->options().deprecated()) { p->Print(*vars, "@$Deprecated$\n"); } + if ((call_type == BLOCKING_CALL && client_streaming && server_streaming) + || (call_type == BLOCKING_V2_CALL && (client_streaming || server_streaming))) { + p->Print(*vars, "@io.grpc.ExperimentalApi(\"https://github.com/grpc/grpc-java/issues/10918\")\n"); + } + if (!interface) { p->Print("public "); } else { @@ -695,7 +732,12 @@ static void PrintStub( case BLOCKING_CALL: GRPC_CODEGEN_CHECK(!client_streaming) << "Blocking client interface with client streaming is unavailable"; - if (server_streaming) { + if (client_streaming && server_streaming) { + p->Print( + *vars, + "$BlockingClientCall$<$input_type$, $output_type$>\n" + " $lower_method_name$()"); + } else if (server_streaming) { // Server streaming p->Print( *vars, @@ -708,6 +750,26 @@ static void PrintStub( "$output_type$ $lower_method_name$($input_type$ request)"); } break; + case BLOCKING_V2_CALL: + if (client_streaming) { // Both Bidi and Client Streaming + p->Print( + *vars, + "$BlockingClientCall$<$input_type$, $output_type$>\n" + " $lower_method_name$()"); + } else if (server_streaming) { + // Server streaming + p->Print( + *vars, + "$BlockingClientCall$\n" + " $lower_method_name$($input_type$ request)"); + } else { + // Simple RPC + (*vars)["throws_decl"] = " throws io.grpc.StatusException"; + p->Print( + *vars, + "$output_type$ $lower_method_name$($input_type$ request)$throws_decl$"); + } + break; case ASYNC_CALL: if (client_streaming) { // Bidirectional streaming or client streaming @@ -753,21 +815,47 @@ static void PrintStub( "$method_method_name$(), responseObserver);\n"); } } else if (!interface) { - switch (call_type) { + switch (call_type) { case BLOCKING_CALL: GRPC_CODEGEN_CHECK(!client_streaming) - << "Blocking client streaming interface is not available"; - if (server_streaming) { - (*vars)["calls_method"] = "io.grpc.stub.ClientCalls.blockingServerStreamingCall"; - (*vars)["params"] = "request"; - } else { - (*vars)["calls_method"] = "io.grpc.stub.ClientCalls.blockingUnaryCall"; - (*vars)["params"] = "request"; + << "Blocking client and bidi streaming interface are not available"; + if (server_streaming) { + (*vars)["calls_method"] = "io.grpc.stub.ClientCalls.blockingServerStreamingCall"; + (*vars)["params"] = "request"; + } else { + (*vars)["calls_method"] = "io.grpc.stub.ClientCalls.blockingUnaryCall"; + (*vars)["params"] = "request"; + } + p->Print( + *vars, + "return $calls_method$(\n" + " getChannel(), $method_method_name$(), getCallOptions(), $params$);\n"); + break; + case BLOCKING_V2_CALL: + if (client_streaming) { // client and bidi streaming + if (server_streaming) { + (*vars)["calls_method"] = "io.grpc.stub.ClientCalls.blockingBidiStreamingCall"; + } else { + (*vars)["calls_method"] = "io.grpc.stub.ClientCalls.blockingClientStreamingCall"; + } + p->Print( + *vars, + "return $calls_method$(\n" + " getChannel(), $method_method_name$(), getCallOptions());\n"); + } else { // server streaming and unary + (*vars)["params"] = "request"; + if (server_streaming) { + (*vars)["calls_method"] = "io.grpc.stub.ClientCalls.blockingV2ServerStreamingCall"; + } else { + (*vars)["calls_method"] = "io.grpc.stub.ClientCalls.blockingV2UnaryCall"; + (*vars)["throws_decl"] = " throws io.grpc.StatusException"; + } + + p->Print( + *vars, + "return $calls_method$(\n" + " getChannel(), $method_method_name$(), getCallOptions(), $params$);\n"); } - p->Print( - *vars, - "return $calls_method$(\n" - " getChannel(), $method_method_name$(), getCallOptions(), $params$);\n"); break; case ASYNC_CALL: if (server_streaming) { @@ -804,7 +892,7 @@ static void PrintStub( "return $calls_method$(\n" " getChannel().newCall($method_method_name$(), getCallOptions()), request);\n"); break; - } + } } else { GRPC_CODEGEN_FAIL << "Do not create Stub interfaces"; } @@ -821,8 +909,7 @@ static void PrintAbstractClassStub( const ServiceDescriptor* service, std::map* vars, Printer* p) { - const std::string service_name = service->name(); - (*vars)["service_name"] = service_name; + (*vars)["service_name"] = service->name(); GrpcWriteServiceDocComment(p, service, ABSTRACT_CLASS); if (service->options().deprecated()) { @@ -956,14 +1043,15 @@ static void PrintGetServiceDescriptorMethod(const ServiceDescriptor* service, std::map* vars, Printer* p, ProtoFlavor flavor) { - (*vars)["service_name"] = service->name(); + std::string service_name = std::string(service->name()); + (*vars)["service_name"] = service_name; if (flavor == ProtoFlavor::NORMAL) { - (*vars)["proto_base_descriptor_supplier"] = service->name() + "BaseDescriptorSupplier"; - (*vars)["proto_file_descriptor_supplier"] = service->name() + "FileDescriptorSupplier"; - (*vars)["proto_method_descriptor_supplier"] = service->name() + "MethodDescriptorSupplier"; - (*vars)["proto_class_name"] = protobuf::compiler::java::ClassName(service->file()); + (*vars)["proto_base_descriptor_supplier"] = service_name + "BaseDescriptorSupplier"; + (*vars)["proto_file_descriptor_supplier"] = service_name + "FileDescriptorSupplier"; + (*vars)["proto_method_descriptor_supplier"] = service_name + "MethodDescriptorSupplier"; + (*vars)["proto_class_name"] = protobuf::compiler::java::QualifiedClassName(service->file()); p->Print( *vars, "private static abstract class $proto_base_descriptor_supplier$\n" @@ -1116,9 +1204,10 @@ static void PrintService(const ServiceDescriptor* service, std::map* vars, Printer* p, ProtoFlavor flavor, - bool disable_version) { + bool disable_version, + GeneratedAnnotation generated_annotation) { (*vars)["service_name"] = service->name(); - (*vars)["file_name"] = service->file()->name(); + (*vars)["file_name"] = absl::Utf8SafeCEscape(service->file()->name()); (*vars)["service_class_name"] = ServiceClassName(service); (*vars)["grpc_version"] = ""; #ifdef GRPC_VERSION @@ -1129,23 +1218,16 @@ static void PrintService(const ServiceDescriptor* service, // TODO(nmittler): Replace with WriteServiceDocComment once included by protobuf distro. GrpcWriteServiceDocComment(p, service, NONE); - if ((*vars)["JakartaMode"] == "javax") { + if (generated_annotation == GeneratedAnnotation::JAVAX) { p->Print( *vars, "@javax.annotation.Generated(\n" " value = \"by gRPC proto compiler$grpc_version$\",\n" " comments = \"Source: $file_name$\")\n" "@$GrpcGenerated$\n"); - } else if ((*vars)["JakartaMode"] == "omit") { - p->Print( - *vars, - "@$GrpcGenerated$\n"); - } else { + } else { // GeneratedAnnotation::OMIT p->Print( *vars, - "@javax.annotation.Generated(\n" - " value = \"by gRPC proto compiler$grpc_version$\",\n" - " comments = \"Source: $file_name$\")\n" "@$GrpcGenerated$\n"); } @@ -1179,6 +1261,21 @@ static void PrintService(const ServiceDescriptor* service, p->Outdent(); p->Print("}\n\n"); + // TODO(nmittler): Replace with WriteDocComment once included by protobuf distro. + GrpcWriteDocComment(p, " Creates a new blocking-style stub that supports all types of calls " + "on the service"); + p->Print( + *vars, + "public static $service_name$BlockingV2Stub newBlockingV2Stub(\n" + " $Channel$ channel) {\n"); + p->Indent(); + PrintStubFactory(service, vars, p, BLOCKING_V2_CLIENT_IMPL); + p->Print( + *vars, + "return $service_name$BlockingV2Stub.newStub(factory, channel);\n"); + p->Outdent(); + p->Print("}\n\n"); + // TODO(nmittler): Replace with WriteDocComment once included by protobuf distro. GrpcWriteDocComment(p, " Creates a new blocking-style stub that supports unary and streaming " "output calls on the service"); @@ -1212,6 +1309,7 @@ static void PrintService(const ServiceDescriptor* service, PrintStub(service, vars, p, ASYNC_INTERFACE); PrintAbstractClassStub(service, vars, p); PrintStub(service, vars, p, ASYNC_CLIENT_IMPL); + PrintStub(service, vars, p, BLOCKING_V2_CLIENT_IMPL); PrintStub(service, vars, p, BLOCKING_CLIENT_IMPL); PrintStub(service, vars, p, FUTURE_CLIENT_IMPL); @@ -1232,7 +1330,7 @@ void GenerateService(const ServiceDescriptor* service, protobuf::io::ZeroCopyOutputStream* out, ProtoFlavor flavor, bool disable_version, - std::string jakarta_mode) { + GeneratedAnnotation generated_annotation) { // All non-generated classes must be referred by fully qualified names to // avoid collision with generated classes. std::map vars; @@ -1263,8 +1361,8 @@ void GenerateService(const ServiceDescriptor* service, vars["RpcMethod"] = "io.grpc.stub.annotations.RpcMethod"; vars["MethodDescriptor"] = "io.grpc.MethodDescriptor"; vars["StreamObserver"] = "io.grpc.stub.StreamObserver"; + vars["BlockingClientCall"] = "io.grpc.stub.BlockingClientCall"; vars["Iterator"] = "java.util.Iterator"; - vars["JakartaMode"] = jakarta_mode; vars["GrpcGenerated"] = "io.grpc.stub.annotations.GrpcGenerated"; vars["ListenableFuture"] = "com.google.common.util.concurrent.ListenableFuture"; @@ -1283,11 +1381,11 @@ void GenerateService(const ServiceDescriptor* service, if (!vars["Package"].empty()) { vars["Package"].append("."); } - PrintService(service, &vars, &printer, flavor, disable_version); + PrintService(service, &vars, &printer, flavor, disable_version, generated_annotation); } std::string ServiceJavaPackage(const FileDescriptor* file) { - std::string result = protobuf::compiler::java::ClassName(file); + std::string result = protobuf::compiler::java::QualifiedClassName(file); size_t last_dot_pos = result.find_last_of('.'); if (last_dot_pos != std::string::npos) { result.resize(last_dot_pos); @@ -1298,7 +1396,7 @@ std::string ServiceJavaPackage(const FileDescriptor* file) { } std::string ServiceClassName(const ServiceDescriptor* service) { - return service->name() + "Grpc"; + return std::string(service->name()) + "Grpc"; } } // namespace java_grpc_generator diff --git a/compiler/src/java_plugin/cpp/java_generator.h b/compiler/src/java_plugin/cpp/java_generator.h index d30179d334e..857fcab31d0 100644 --- a/compiler/src/java_plugin/cpp/java_generator.h +++ b/compiler/src/java_plugin/cpp/java_generator.h @@ -57,6 +57,10 @@ enum ProtoFlavor { NORMAL, LITE }; +enum GeneratedAnnotation { + OMIT, JAVAX +}; + // Returns the package name of the gRPC services defined in the given file. std::string ServiceJavaPackage(const impl::protobuf::FileDescriptor* file); @@ -69,7 +73,7 @@ void GenerateService(const impl::protobuf::ServiceDescriptor* service, impl::protobuf::io::ZeroCopyOutputStream* out, ProtoFlavor flavor, bool disable_version, - std::string jakarta_mode); + GeneratedAnnotation generated_annotation); } // namespace java_grpc_generator diff --git a/compiler/src/java_plugin/cpp/java_plugin.cpp b/compiler/src/java_plugin/cpp/java_plugin.cpp index 36f22893f63..4b02d6e9884 100644 --- a/compiler/src/java_plugin/cpp/java_plugin.cpp +++ b/compiler/src/java_plugin/cpp/java_plugin.cpp @@ -23,6 +23,9 @@ #include "java_generator.h" #include +#if GOOGLE_PROTOBUF_VERSION >= 5027000 +#include +#endif #include #include #include @@ -45,9 +48,31 @@ class JavaGrpcGenerator : public protobuf::compiler::CodeGenerator { JavaGrpcGenerator() {} virtual ~JavaGrpcGenerator() {} +// Protobuf 5.27 released edition 2023. +#if GOOGLE_PROTOBUF_VERSION >= 5027000 uint64_t GetSupportedFeatures() const override { - return FEATURE_PROTO3_OPTIONAL; + return Feature::FEATURE_PROTO3_OPTIONAL | + Feature::FEATURE_SUPPORTS_EDITIONS; } + protobuf::Edition GetMinimumEdition() const override { + return protobuf::Edition::EDITION_PROTO2; + } + protobuf::Edition GetMaximumEdition() const override { +#if GOOGLE_PROTOBUF_VERSION >= 6032000 + return protobuf::Edition::EDITION_2024; +#else + return protobuf::Edition::EDITION_2023; +#endif + } + std::vector GetFeatureExtensions() + const override { + return {GetExtensionReflection(pb::java)}; + } +#else + uint64_t GetSupportedFeatures() const override { + return Feature::FEATURE_PROTO3_OPTIONAL; + } +#endif virtual bool Generate(const protobuf::FileDescriptor* file, const std::string& parameter, @@ -58,24 +83,21 @@ class JavaGrpcGenerator : public protobuf::compiler::CodeGenerator { java_grpc_generator::ProtoFlavor flavor = java_grpc_generator::ProtoFlavor::NORMAL; + java_grpc_generator::GeneratedAnnotation generated_annotation = + java_grpc_generator::GeneratedAnnotation::OMIT; - /* - jakarta_mode has these values: - javax, the original behavior - add @javax.annotation.Generated - omit, "less controversial" = just add @io.grpc.stub.annotations.GrpcGenerated - and maybe others in the future - */ - std::string jakarta_mode; bool disable_version = false; for (size_t i = 0; i < options.size(); i++) { if (options[i].first == "lite") { flavor = java_grpc_generator::ProtoFlavor::LITE; } else if (options[i].first == "noversion") { disable_version = true; - } else if (options[i].first == "jakarta_javax") { - jakarta_mode = "javax"; - } else if (options[i].first == "jakarta_omit") { - jakarta_mode = "omit"; + } else if (options[i].first == "@generated") { + if (options[i].second == "omit") { + generated_annotation = java_grpc_generator::GeneratedAnnotation::OMIT; + } else if (options[i].second == "javax") { + generated_annotation = java_grpc_generator::GeneratedAnnotation::JAVAX; + } } } @@ -88,7 +110,7 @@ class JavaGrpcGenerator : public protobuf::compiler::CodeGenerator { std::unique_ptr output( context->Open(filename)); java_grpc_generator::GenerateService( - service, output.get(), flavor, disable_version, jakarta_mode); + service, output.get(), flavor, disable_version, generated_annotation); } return true; } diff --git a/compiler/src/test/golden/TestDeprecatedService.java.txt b/compiler/src/test/golden/TestDeprecatedService.java.txt index f6e1797b2a0..0b4924f3e6a 100644 --- a/compiler/src/test/golden/TestDeprecatedService.java.txt +++ b/compiler/src/test/golden/TestDeprecatedService.java.txt @@ -8,7 +8,7 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; *

*/ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.63.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.82.0-SNAPSHOT)", comments = "Source: grpc/testing/compiler/test.proto") @io.grpc.stub.annotations.GrpcGenerated @java.lang.Deprecated @@ -64,6 +64,21 @@ public final class TestDeprecatedServiceGrpc { return TestDeprecatedServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static TestDeprecatedServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public TestDeprecatedServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new TestDeprecatedServiceBlockingV2Stub(channel, callOptions); + } + }; + return TestDeprecatedServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -169,6 +184,38 @@ public final class TestDeprecatedServiceGrpc { *
*/ @java.lang.Deprecated + public static final class TestDeprecatedServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private TestDeprecatedServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected TestDeprecatedServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new TestDeprecatedServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * An RPC method that has been deprecated and should generate with Java's @Deprecated annotation
+     * 
+ */ + @java.lang.Deprecated + public io.grpc.testing.compiler.Test.SimpleResponse deprecatedMethod(io.grpc.testing.compiler.Test.SimpleRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getDeprecatedMethodMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service TestDeprecatedService. + *
+   * Test service that has been deprecated and should generate with Java's @Deprecated annotation
+   * 
+ */ + @java.lang.Deprecated public static final class TestDeprecatedServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private TestDeprecatedServiceBlockingStub( diff --git a/compiler/src/test/golden/TestService.java.txt b/compiler/src/test/golden/TestService.java.txt index 8e297eb6bee..5c65890273c 100644 --- a/compiler/src/test/golden/TestService.java.txt +++ b/compiler/src/test/golden/TestService.java.txt @@ -8,7 +8,7 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; * */ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.63.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.82.0-SNAPSHOT)", comments = "Source: grpc/testing/compiler/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class TestServiceGrpc { @@ -282,6 +282,21 @@ public final class TestServiceGrpc { return TestServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static TestServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public TestServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new TestServiceBlockingV2Stub(channel, callOptions); + } + }; + return TestServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -546,6 +561,125 @@ public final class TestServiceGrpc { * Test service that supports all call types. * */ + public static final class TestServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private TestServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected TestServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new TestServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * One request followed by one response.
+     * The server returns the client payload as-is.
+     * 
+ */ + public io.grpc.testing.compiler.Test.SimpleResponse unaryCall(io.grpc.testing.compiler.Test.SimpleRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getUnaryCallMethod(), getCallOptions(), request); + } + + /** + *
+     * One request followed by a sequence of responses (streamed download).
+     * The server returns the payload with client desired type and sizes.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamingOutputCall(io.grpc.testing.compiler.Test.StreamingOutputCallRequest request) { + return io.grpc.stub.ClientCalls.blockingV2ServerStreamingCall( + getChannel(), getStreamingOutputCallMethod(), getCallOptions(), request); + } + + /** + *
+     * A sequence of requests followed by one response (streamed upload).
+     * The server returns the aggregated size of client payload as the result.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamingInputCall() { + return io.grpc.stub.ClientCalls.blockingClientStreamingCall( + getChannel(), getStreamingInputCallMethod(), getCallOptions()); + } + + /** + *
+     * A sequence of requests with each request served by the server immediately.
+     * As one request could lead to multiple responses, this interface
+     * demonstrates the idea of full bidirectionality.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + fullBidiCall() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getFullBidiCallMethod(), getCallOptions()); + } + + /** + *
+     * A sequence of requests followed by a sequence of responses.
+     * The server buffers all the client requests and then serves them in order. A
+     * stream of responses are returned to the client when the server starts with
+     * first request.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + halfBidiCall() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getHalfBidiCallMethod(), getCallOptions()); + } + + /** + *
+     * An RPC method whose Java name collides with a keyword, and whose generated
+     * method should have a '_' appended.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + import_() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getImportMethod(), getCallOptions()); + } + + /** + *
+     * A unary call that is Safe.
+     * 
+ */ + public io.grpc.testing.compiler.Test.SimpleResponse safeCall(io.grpc.testing.compiler.Test.SimpleRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getSafeCallMethod(), getCallOptions(), request); + } + + /** + *
+     * A unary call that is Idempotent.
+     * 
+ */ + public io.grpc.testing.compiler.Test.SimpleResponse idempotentCall(io.grpc.testing.compiler.Test.SimpleRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getIdempotentCallMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service TestService. + *
+   * Test service that supports all call types.
+   * 
+ */ public static final class TestServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private TestServiceBlockingStub( diff --git a/compiler/src/testLite/golden/TestDeprecatedService.java.txt b/compiler/src/testLite/golden/TestDeprecatedService.java.txt index 6906e339f41..89ea2e698bf 100644 --- a/compiler/src/testLite/golden/TestDeprecatedService.java.txt +++ b/compiler/src/testLite/golden/TestDeprecatedService.java.txt @@ -7,9 +7,6 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; * Test service that has been deprecated and should generate with Java's @Deprecated annotation * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.63.0-SNAPSHOT)", - comments = "Source: grpc/testing/compiler/test.proto") @io.grpc.stub.annotations.GrpcGenerated @java.lang.Deprecated public final class TestDeprecatedServiceGrpc { @@ -63,6 +60,21 @@ public final class TestDeprecatedServiceGrpc { return TestDeprecatedServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static TestDeprecatedServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public TestDeprecatedServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new TestDeprecatedServiceBlockingV2Stub(channel, callOptions); + } + }; + return TestDeprecatedServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -168,6 +180,38 @@ public final class TestDeprecatedServiceGrpc { * */ @java.lang.Deprecated + public static final class TestDeprecatedServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private TestDeprecatedServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected TestDeprecatedServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new TestDeprecatedServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * An RPC method that has been deprecated and should generate with Java's @Deprecated annotation
+     * 
+ */ + @java.lang.Deprecated + public io.grpc.testing.compiler.Test.SimpleResponse deprecatedMethod(io.grpc.testing.compiler.Test.SimpleRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getDeprecatedMethodMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service TestDeprecatedService. + *
+   * Test service that has been deprecated and should generate with Java's @Deprecated annotation
+   * 
+ */ + @java.lang.Deprecated public static final class TestDeprecatedServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private TestDeprecatedServiceBlockingStub( diff --git a/compiler/src/testLite/golden/TestService.java.txt b/compiler/src/testLite/golden/TestService.java.txt index bbeed66d9d0..4e9dfb8d682 100644 --- a/compiler/src/testLite/golden/TestService.java.txt +++ b/compiler/src/testLite/golden/TestService.java.txt @@ -7,9 +7,6 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; * Test service that supports all call types. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.63.0-SNAPSHOT)", - comments = "Source: grpc/testing/compiler/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class TestServiceGrpc { @@ -274,6 +271,21 @@ public final class TestServiceGrpc { return TestServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static TestServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public TestServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new TestServiceBlockingV2Stub(channel, callOptions); + } + }; + return TestServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -538,6 +550,125 @@ public final class TestServiceGrpc { * Test service that supports all call types. * */ + public static final class TestServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private TestServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected TestServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new TestServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * One request followed by one response.
+     * The server returns the client payload as-is.
+     * 
+ */ + public io.grpc.testing.compiler.Test.SimpleResponse unaryCall(io.grpc.testing.compiler.Test.SimpleRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getUnaryCallMethod(), getCallOptions(), request); + } + + /** + *
+     * One request followed by a sequence of responses (streamed download).
+     * The server returns the payload with client desired type and sizes.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamingOutputCall(io.grpc.testing.compiler.Test.StreamingOutputCallRequest request) { + return io.grpc.stub.ClientCalls.blockingV2ServerStreamingCall( + getChannel(), getStreamingOutputCallMethod(), getCallOptions(), request); + } + + /** + *
+     * A sequence of requests followed by one response (streamed upload).
+     * The server returns the aggregated size of client payload as the result.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamingInputCall() { + return io.grpc.stub.ClientCalls.blockingClientStreamingCall( + getChannel(), getStreamingInputCallMethod(), getCallOptions()); + } + + /** + *
+     * A sequence of requests with each request served by the server immediately.
+     * As one request could lead to multiple responses, this interface
+     * demonstrates the idea of full bidirectionality.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + fullBidiCall() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getFullBidiCallMethod(), getCallOptions()); + } + + /** + *
+     * A sequence of requests followed by a sequence of responses.
+     * The server buffers all the client requests and then serves them in order. A
+     * stream of responses are returned to the client when the server starts with
+     * first request.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + halfBidiCall() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getHalfBidiCallMethod(), getCallOptions()); + } + + /** + *
+     * An RPC method whose Java name collides with a keyword, and whose generated
+     * method should have a '_' appended.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + import_() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getImportMethod(), getCallOptions()); + } + + /** + *
+     * A unary call that is Safe.
+     * 
+ */ + public io.grpc.testing.compiler.Test.SimpleResponse safeCall(io.grpc.testing.compiler.Test.SimpleRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getSafeCallMethod(), getCallOptions(), request); + } + + /** + *
+     * A unary call that is Idempotent.
+     * 
+ */ + public io.grpc.testing.compiler.Test.SimpleResponse idempotentCall(io.grpc.testing.compiler.Test.SimpleRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getIdempotentCallMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service TestService. + *
+   * Test service that supports all call types.
+   * 
+ */ public static final class TestServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private TestServiceBlockingStub( diff --git a/context/BUILD.bazel b/context/BUILD.bazel index d0c4b04ce00..0a51dca24a9 100644 --- a/context/BUILD.bazel +++ b/context/BUILD.bazel @@ -1,3 +1,5 @@ +load("@rules_java//java:defs.bzl", "java_library") + java_library( name = "context", visibility = ["//visibility:public"], diff --git a/contextstorage/build.gradle b/contextstorage/build.gradle new file mode 100644 index 00000000000..b1e78ea0e17 --- /dev/null +++ b/contextstorage/build.gradle @@ -0,0 +1,35 @@ +plugins { + id "java-library" + id "maven-publish" + + id "ru.vyarus.animalsniffer" +} + +description = 'gRPC: ContextStorageOverride' + +dependencies { + api project(':grpc-api') + implementation libraries.opentelemetry.api + + testImplementation libraries.junit, + libraries.opentelemetry.sdk.testing, + libraries.assertj.core + testImplementation 'junit:junit:4.13.1'// opentelemetry.sdk.testing uses compileOnly for assertj + + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } + signature (libraries.signature.android) { + artifact { + extension = "signature" + } + } +} + +tasks.named("jar").configure { + manifest { + attributes('Automatic-Module-Name': 'io.grpc.override') + } +} diff --git a/contextstorage/src/main/java/io/grpc/override/ContextStorageOverride.java b/contextstorage/src/main/java/io/grpc/override/ContextStorageOverride.java new file mode 100644 index 00000000000..41b24765de0 --- /dev/null +++ b/contextstorage/src/main/java/io/grpc/override/ContextStorageOverride.java @@ -0,0 +1,46 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.override; + +import io.grpc.Context; + +/** + * Including this class in your dependencies will override the default gRPC context storage using + * reflection. It is a bridge between {@link io.grpc.Context} and + * {@link io.opentelemetry.context.Context}, i.e. propagating io.grpc.context.Context also + * propagates io.opentelemetry.context, and propagating io.opentelemetry.context will also propagate + * io.grpc.context. + */ +public final class ContextStorageOverride extends Context.Storage { + + private final Context.Storage delegate = new OpenTelemetryContextStorage(); + + @Override + public Context doAttach(Context toAttach) { + return delegate.doAttach(toAttach); + } + + @Override + public void detach(Context toDetach, Context toRestore) { + delegate.detach(toDetach, toRestore); + } + + @Override + public Context current() { + return delegate.current(); + } +} diff --git a/contextstorage/src/main/java/io/grpc/override/OpenTelemetryContextStorage.java b/contextstorage/src/main/java/io/grpc/override/OpenTelemetryContextStorage.java new file mode 100644 index 00000000000..01356e9f406 --- /dev/null +++ b/contextstorage/src/main/java/io/grpc/override/OpenTelemetryContextStorage.java @@ -0,0 +1,72 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.override; + +import io.grpc.Context; +import io.opentelemetry.context.ContextKey; +import io.opentelemetry.context.Scope; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * A Context.Storage implementation that attaches io.grpc.context to OpenTelemetry's context and + * io.opentelemetry.context is also saved in the io.grpc.context. + * Bridge between {@link io.grpc.Context} and {@link io.opentelemetry.context.Context}. + */ +final class OpenTelemetryContextStorage extends Context.Storage { + private static final Logger logger = Logger.getLogger( + OpenTelemetryContextStorage.class.getName()); + + private static final io.grpc.Context.Key OTEL_CONTEXT_OVER_GRPC + = io.grpc.Context.key("otel-context-over-grpc"); + private static final Context.Key OTEL_SCOPE = Context.key("otel-scope"); + private static final ContextKey GRPC_CONTEXT_OVER_OTEL = + ContextKey.named("grpc-context-over-otel"); + + @Override + @SuppressWarnings("MustBeClosedChecker") + public Context doAttach(Context toAttach) { + io.grpc.Context previous = current(); + io.opentelemetry.context.Context otelContext = OTEL_CONTEXT_OVER_GRPC.get(toAttach); + if (otelContext == null) { + otelContext = io.opentelemetry.context.Context.current(); + } + Scope scope = otelContext.with(GRPC_CONTEXT_OVER_OTEL, toAttach).makeCurrent(); + return previous.withValue(OTEL_SCOPE, scope); + } + + @Override + public void detach(Context toDetach, Context toRestore) { + Scope scope = OTEL_SCOPE.get(toRestore); + if (scope == null) { + logger.log( + Level.SEVERE, "Detaching context which was not attached."); + } else { + scope.close(); + } + } + + @Override + public Context current() { + io.opentelemetry.context.Context otelCurrent = io.opentelemetry.context.Context.current(); + io.grpc.Context grpcCurrent = otelCurrent.get(GRPC_CONTEXT_OVER_OTEL); + if (grpcCurrent == null) { + grpcCurrent = Context.ROOT; + } + return grpcCurrent.withValue(OTEL_CONTEXT_OVER_GRPC, otelCurrent); + } +} diff --git a/contextstorage/src/test/java/io/grpc/override/OpenTelemetryContextStorageTest.java b/contextstorage/src/test/java/io/grpc/override/OpenTelemetryContextStorageTest.java new file mode 100644 index 00000000000..3c628964342 --- /dev/null +++ b/contextstorage/src/test/java/io/grpc/override/OpenTelemetryContextStorageTest.java @@ -0,0 +1,144 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.override; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +import com.google.common.util.concurrent.SettableFuture; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.ContextKey; +import io.opentelemetry.context.Scope; +import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class OpenTelemetryContextStorageTest { + @Rule + public final OpenTelemetryRule openTelemetryRule = OpenTelemetryRule.create(); + private Tracer tracerRule = openTelemetryRule.getOpenTelemetry().getTracer( + "context-storage-test"); + private final io.grpc.Context.Key username = io.grpc.Context.key("username"); + private final ContextKey password = ContextKey.named("password"); + + @Test + public void grpcContextPropagation() throws Exception { + final Span parentSpan = tracerRule.spanBuilder("test-context").startSpan(); + final SettableFuture spanPropagated = SettableFuture.create(); + final SettableFuture grpcContextPropagated = SettableFuture.create(); + final SettableFuture spanDetached = SettableFuture.create(); + final SettableFuture grpcContextDetached = SettableFuture.create(); + + io.grpc.Context grpcContext; + try (Scope scope = Context.current().with(parentSpan).makeCurrent()) { + grpcContext = io.grpc.Context.current().withValue(username, "jeff"); + } + new Thread(new Runnable() { + @Override + public void run() { + io.grpc.Context previous = grpcContext.attach(); + try { + grpcContextPropagated.set(username.get(io.grpc.Context.current())); + spanPropagated.set(Span.fromContext(io.opentelemetry.context.Context.current())); + } finally { + grpcContext.detach(previous); + spanDetached.set(Span.fromContext(io.opentelemetry.context.Context.current())); + grpcContextDetached.set(username.get(io.grpc.Context.current())); + } + } + }).start(); + Assert.assertEquals(spanPropagated.get(5, TimeUnit.SECONDS), parentSpan); + Assert.assertEquals(grpcContextPropagated.get(5, TimeUnit.SECONDS), "jeff"); + Assert.assertEquals(spanDetached.get(5, TimeUnit.SECONDS), Span.getInvalid()); + Assert.assertNull(grpcContextDetached.get(5, TimeUnit.SECONDS)); + } + + @Test + public void otelContextPropagation() throws Exception { + final SettableFuture grpcPropagated = SettableFuture.create(); + final AtomicReference otelPropagation = new AtomicReference<>(); + + io.grpc.Context grpcContext = io.grpc.Context.current().withValue(username, "jeff"); + io.grpc.Context previous = grpcContext.attach(); + Context original = Context.current().with(password, "valentine"); + try { + new Thread( + () -> { + try (Scope scope = original.makeCurrent()) { + otelPropagation.set(Context.current().get(password)); + grpcPropagated.set(username.get(io.grpc.Context.current())); + } + } + ).start(); + } finally { + grpcContext.detach(previous); + } + Assert.assertEquals(grpcPropagated.get(5, TimeUnit.SECONDS), "jeff"); + Assert.assertEquals(otelPropagation.get(), "valentine"); + } + + @Test + public void grpcOtelMix() { + io.grpc.Context grpcContext = io.grpc.Context.current().withValue(username, "jeff"); + Context otelContext = Context.current().with(password, "valentine"); + Assert.assertNull(username.get(io.grpc.Context.current())); + Assert.assertNull(Context.current().get(password)); + io.grpc.Context previous = grpcContext.attach(); + try { + assertEquals(username.get(io.grpc.Context.current()), "jeff"); + try (Scope scope = otelContext.makeCurrent()) { + Assert.assertEquals(Context.current().get(password), "valentine"); + assertNull(username.get(io.grpc.Context.current())); + + io.grpc.Context grpcContext2 = io.grpc.Context.current().withValue(username, "frank"); + io.grpc.Context previous2 = grpcContext2.attach(); + try { + assertEquals(username.get(io.grpc.Context.current()), "frank"); + Assert.assertEquals(Context.current().get(password), "valentine"); + } finally { + grpcContext2.detach(previous2); + } + assertNull(username.get(io.grpc.Context.current())); + Assert.assertEquals(Context.current().get(password), "valentine"); + } + } finally { + grpcContext.detach(previous); + } + Assert.assertNull(username.get(io.grpc.Context.current())); + Assert.assertNull(Context.current().get(password)); + } + + @Test + public void grpcContextDetachError() { + io.grpc.Context grpcContext = io.grpc.Context.current().withValue(username, "jeff"); + io.grpc.Context previous = grpcContext.attach(); + try { + previous.detach(grpcContext); + assertEquals(username.get(io.grpc.Context.current()), "jeff"); + } finally { + grpcContext.detach(previous); + } + } +} diff --git a/core/BUILD.bazel b/core/BUILD.bazel index ebe5b64c277..1a743ff9eda 100644 --- a/core/BUILD.bazel +++ b/core/BUILD.bazel @@ -1,3 +1,6 @@ +load("@rules_java//java:defs.bzl", "java_library") +load("@rules_jvm_external//:defs.bzl", "artifact") + java_library( name = "core", visibility = ["//visibility:public"], @@ -15,7 +18,6 @@ java_library( srcs = glob([ "src/main/java/io/grpc/internal/*.java", ]), - javacopts = ["-Xep:DoNotCall:OFF"], # Remove once requiring Bazel 3.4.0+; allows non-final resources = glob([ "src/bazel-internal/resources/**", ]), @@ -23,14 +25,13 @@ java_library( deps = [ "//api", "//context", - "@com_google_android_annotations//jar", - "@com_google_code_findbugs_jsr305//jar", - "@com_google_code_gson_gson//jar", - "@com_google_errorprone_error_prone_annotations//jar", - "@com_google_guava_guava//jar", - "@com_google_j2objc_j2objc_annotations//jar", - "@io_perfmark_perfmark_api//jar", - "@org_codehaus_mojo_animal_sniffer_annotations//jar", + artifact("com.google.code.gson:gson"), + artifact("com.google.android:annotations"), + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:guava"), + artifact("io.perfmark:perfmark-api"), + artifact("org.codehaus.mojo:animal-sniffer-annotations"), ], ) diff --git a/core/build.gradle b/core/build.gradle index 22c68b21147..b320f326b41 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -1,6 +1,6 @@ buildscript { dependencies { - classpath 'com.google.guava:guava:30.0-android' + classpath 'com.google.guava:guava:33.4.8-android' } } @@ -32,7 +32,6 @@ dependencies { libraries.truth, project(':grpc-testing') testImplementation testFixtures(project(':grpc-api')), - project(':grpc-inprocess'), project(':grpc-testing') testImplementation libraries.guava.testlib @@ -40,8 +39,16 @@ dependencies { jmh project(':grpc-testing') - signature libraries.signature.java - signature libraries.signature.android + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } + signature (libraries.signature.android) { + artifact { + extension = "signature" + } + } } tasks.named("javadoc").configure { diff --git a/core/src/main/java/io/grpc/internal/AbstractClientStream.java b/core/src/main/java/io/grpc/internal/AbstractClientStream.java index a4ebfa52d63..bce1820b482 100644 --- a/core/src/main/java/io/grpc/internal/AbstractClientStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractClientStream.java @@ -21,7 +21,6 @@ import static io.grpc.internal.GrpcUtil.CONTENT_ENCODING_KEY; import static io.grpc.internal.GrpcUtil.MESSAGE_ENCODING_KEY; import static io.grpc.internal.GrpcUtil.TIMEOUT_KEY; -import static java.lang.Math.max; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; @@ -44,9 +43,9 @@ import javax.annotation.Nullable; /** - * The abstract base class for {@link ClientStream} implementations. Extending classes only need to - * implement {@link #transportState()} and {@link #abstractClientStreamSink()}. Must only be called - * from the sending application thread. + * The abstract base class for {@link ClientStream} implementations. + * + *

Must only be called from the sending application thread. */ public abstract class AbstractClientStream extends AbstractStream implements ClientStream, MessageFramer.Sink { @@ -92,8 +91,8 @@ void writeFrame( private final TransportTracer transportTracer; private final Framer framer; - private boolean shouldBeCountedForInUse; - private boolean useGet; + private final boolean shouldBeCountedForInUse; + private final boolean useGet; private Metadata headers; /** * Whether cancel() has been called. This is not strictly necessary, but removes the delay between @@ -102,6 +101,7 @@ void writeFrame( */ private volatile boolean cancelled; + @SuppressWarnings("this-escape") protected AbstractClientStream( WritableBufferAllocator bufferAllocator, StatsTraceContext statsTraceCtx, @@ -114,7 +114,7 @@ protected AbstractClientStream( this.shouldBeCountedForInUse = GrpcUtil.shouldBeCountedForInUse(callOptions); this.useGet = useGet; if (!useGet) { - framer = new MessageFramer(this, bufferAllocator, statsTraceCtx); + this.framer = new MessageFramer(this, bufferAllocator, statsTraceCtx); this.headers = headers; } else { framer = new GetFramer(headers, statsTraceCtx); @@ -124,8 +124,7 @@ protected AbstractClientStream( @Override public void setDeadline(Deadline deadline) { headers.discardAll(TIMEOUT_KEY); - long effectiveTimeout = max(0, deadline.timeRemaining(TimeUnit.NANOSECONDS)); - headers.put(TIMEOUT_KEY, effectiveTimeout); + headers.put(TIMEOUT_KEY, deadline.timeRemaining(TimeUnit.NANOSECONDS)); } @Override @@ -243,9 +242,13 @@ protected abstract static class TransportState extends AbstractStream.TransportS protected TransportState( int maxMessageSize, StatsTraceContext statsTraceCtx, - TransportTracer transportTracer) { + TransportTracer transportTracer, + CallOptions options) { super(maxMessageSize, statsTraceCtx, transportTracer); this.statsTraceCtx = checkNotNull(statsTraceCtx, "statsTraceCtx"); + if (options.getOnReadyThreshold() != null) { + this.setOnReadyThreshold(options.getOnReadyThreshold()); + } } private void setFullStreamDecompression(boolean fullStreamDecompression) { @@ -300,7 +303,7 @@ protected final boolean isOutboundClosed() { */ protected void inboundHeadersReceived(Metadata headers) { checkState(!statusReported, "Received headers on closed stream"); - statsTraceCtx.clientInboundHeaders(); + statsTraceCtx.clientInboundHeaders(headers); boolean compressedStream = false; String streamEncoding = headers.get(CONTENT_ENCODING_KEY); diff --git a/core/src/main/java/io/grpc/internal/AbstractServerStream.java b/core/src/main/java/io/grpc/internal/AbstractServerStream.java index d781cfa9b8a..c468cba978a 100644 --- a/core/src/main/java/io/grpc/internal/AbstractServerStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractServerStream.java @@ -75,10 +75,11 @@ protected interface Sink { private boolean outboundClosed; private boolean headersSent; + @SuppressWarnings("this-escape") protected AbstractServerStream( WritableBufferAllocator bufferAllocator, StatsTraceContext statsTraceCtx) { this.statsTraceCtx = Preconditions.checkNotNull(statsTraceCtx, "statsTraceCtx"); - framer = new MessageFramer(this, bufferAllocator, statsTraceCtx); + this.framer = new MessageFramer(this, bufferAllocator, statsTraceCtx); } @Override @@ -177,6 +178,19 @@ public StatsTraceContext statsTraceContext() { return statsTraceCtx; } + /** + * A hint to the stream that specifies how many bytes must be queued before + * {@link #isReady()} will return false. A stream may ignore this property + * if unsupported. This may only be set before any messages are sent. + * + * @param numBytes The number of bytes that must be queued. Must be a + * positive integer. + */ + @Override + public void setOnReadyThreshold(int numBytes) { + super.setOnReadyThreshold(numBytes); + } + /** * This should only be called from the transport thread (except for private interactions with * {@code AbstractServerStream}). @@ -243,6 +257,8 @@ public void deframerClosed(boolean hasPartialMessage) { } } + + @Override protected ServerStreamListener listener() { return listener; @@ -278,6 +294,7 @@ public void inboundDataReceived(ReadableBuffer frame, boolean endOfStream) { */ public final void transportReportStatus(final Status status) { Preconditions.checkArgument(!status.isOk(), "status must not be OK"); + onStreamDeallocated(); if (deframerClosed) { deframerClosedTask = null; closeListener(status); @@ -300,6 +317,7 @@ public void run() { * #transportReportStatus}. */ public void complete() { + onStreamDeallocated(); if (deframerClosed) { deframerClosedTask = null; closeListener(Status.OK); @@ -335,7 +353,6 @@ private void closeListener(Status newStatus) { getTransportTracer().reportStreamClosed(closedStatus.isOk()); } listenerClosed = true; - onStreamDeallocated(); listener().closed(newStatus); } } diff --git a/core/src/main/java/io/grpc/internal/AbstractStream.java b/core/src/main/java/io/grpc/internal/AbstractStream.java index cda08576eae..9f5fb035dab 100644 --- a/core/src/main/java/io/grpc/internal/AbstractStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractStream.java @@ -20,6 +20,7 @@ import static com.google.common.base.Preconditions.checkState; import com.google.common.annotations.VisibleForTesting; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Codec; import io.grpc.Compressor; import io.grpc.Decompressor; @@ -27,13 +28,16 @@ import io.perfmark.PerfMark; import io.perfmark.TaskCloseable; import java.io.InputStream; -import javax.annotation.concurrent.GuardedBy; +import java.util.logging.Level; +import java.util.logging.Logger; /** * The stream and stream state as used by the application. Must only be called from the sending * application thread. */ public abstract class AbstractStream implements Stream { + private static final Logger log = Logger.getLogger(AbstractStream.class.getName()); + /** The framer to use for sending messages. */ protected abstract Framer framer(); @@ -77,6 +81,19 @@ public final void flush() { } } + /** + * A hint to the stream that specifies how many bytes must be queued before + * {@link #isReady()} will return false. A stream may ignore this property if + * unsupported. This may only be set during stream initialization before + * any messages are set. + * + * @param numBytes The number of bytes that must be queued. Must be a + * positive integer. + */ + protected void setOnReadyThreshold(int numBytes) { + transportState().setOnReadyThreshold(numBytes); + } + /** * Closes the underlying framer. Should be called when the outgoing stream is gracefully closed * (half closure on client; closure on server). @@ -143,20 +160,25 @@ public abstract static class TransportState @GuardedBy("onReadyLock") private boolean deallocated; + @GuardedBy("onReadyLock") + private int onReadyThreshold; + + @SuppressWarnings("this-escape") protected TransportState( int maxMessageSize, StatsTraceContext statsTraceCtx, TransportTracer transportTracer) { this.statsTraceCtx = checkNotNull(statsTraceCtx, "statsTraceCtx"); this.transportTracer = checkNotNull(transportTracer, "transportTracer"); - rawDeframer = new MessageDeframer( + this.rawDeframer = new MessageDeframer( this, Codec.Identity.NONE, maxMessageSize, statsTraceCtx, transportTracer); // TODO(#7168): use MigratingThreadDeframer when enabling retry doesn't break. - deframer = rawDeframer; + deframer = this.rawDeframer; + onReadyThreshold = DEFAULT_ONREADY_THRESHOLD; } final void optimizeForDirectExecutor() { @@ -178,6 +200,20 @@ final void setMaxInboundMessageSize(int maxSize) { */ protected abstract StreamListener listener(); + /** + * A hint to the stream that specifies how many bytes must be queued before + * {@link #isReady()} will return false. A stream may ignore this property if + * unsupported. This may only be set before any messages are sent. + * + * @param numBytes The number of bytes that must be queued. Must be a + * positive integer. + */ + void setOnReadyThreshold(int numBytes) { + synchronized (onReadyLock) { + this.onReadyThreshold = numBytes; + } + } + @Override public void messagesAvailable(StreamListener.MessageProducer producer) { listener().messagesAvailable(producer); @@ -259,7 +295,7 @@ protected final void setDecompressor(Decompressor decompressor) { private boolean isReady() { synchronized (onReadyLock) { - return allocated && numSentBytesQueued < DEFAULT_ONREADY_THRESHOLD && !deallocated; + return allocated && numSentBytesQueued < onReadyThreshold && !deallocated; } } @@ -291,6 +327,12 @@ protected final void onStreamDeallocated() { } } + protected boolean isStreamDeallocated() { + synchronized (onReadyLock) { + return deallocated; + } + } + /** * Event handler to be called by the subclass when a number of bytes are being queued for * sending to the remote endpoint. @@ -316,9 +358,9 @@ public final void onSentBytes(int numBytes) { synchronized (onReadyLock) { checkState(allocated, "onStreamAllocated was not called, but it seems the stream is active"); - boolean belowThresholdBefore = numSentBytesQueued < DEFAULT_ONREADY_THRESHOLD; + boolean belowThresholdBefore = numSentBytesQueued < onReadyThreshold; numSentBytesQueued -= numBytes; - boolean belowThresholdAfter = numSentBytesQueued < DEFAULT_ONREADY_THRESHOLD; + boolean belowThresholdAfter = numSentBytesQueued < onReadyThreshold; doNotify = !belowThresholdBefore && belowThresholdAfter; } if (doNotify) { @@ -334,6 +376,12 @@ private void notifyIfReady() { boolean doNotify; synchronized (onReadyLock) { doNotify = isReady(); + if (!doNotify && log.isLoggable(Level.FINEST)) { + log.log(Level.FINEST, + "Stream not ready so skip notifying listener.\n" + + "details: allocated/deallocated:{0}/{3}, sent queued: {1}, ready thresh: {2}", + new Object[] {allocated, numSentBytesQueued, onReadyThreshold, deallocated}); + } } if (doNotify) { listener().onReady(); diff --git a/core/src/main/java/io/grpc/internal/AuthorityVerifier.java b/core/src/main/java/io/grpc/internal/AuthorityVerifier.java new file mode 100644 index 00000000000..e6164a7dc4d --- /dev/null +++ b/core/src/main/java/io/grpc/internal/AuthorityVerifier.java @@ -0,0 +1,24 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import io.grpc.Status; + +/** Verifier for the outgoing authority pseudo-header against peer cert. */ +public interface AuthorityVerifier { + Status verifyAuthority(String authority); +} diff --git a/core/src/main/java/io/grpc/internal/AutoConfiguredLoadBalancerFactory.java b/core/src/main/java/io/grpc/internal/AutoConfiguredLoadBalancerFactory.java index a382227fd6c..dcefa8f8351 100644 --- a/core/src/main/java/io/grpc/internal/AutoConfiguredLoadBalancerFactory.java +++ b/core/src/main/java/io/grpc/internal/AutoConfiguredLoadBalancerFactory.java @@ -19,17 +19,15 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.MoreObjects; import io.grpc.ChannelLogger.ChannelLogLevel; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.FixedResultPicker; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.PickResult; -import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.ResolvedAddresses; import io.grpc.LoadBalancer.Subchannel; -import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancerProvider; import io.grpc.LoadBalancerRegistry; import io.grpc.NameResolver.ConfigOrError; @@ -40,12 +38,10 @@ import java.util.Map; import javax.annotation.Nullable; -// TODO(creamsoup) fully deprecate LoadBalancer.ATTR_LOAD_BALANCING_CONFIG -@SuppressWarnings("deprecation") -public final class AutoConfiguredLoadBalancerFactory { +public final class AutoConfiguredLoadBalancerFactory extends LoadBalancerProvider { private final LoadBalancerRegistry registry; - private final String defaultPolicy; + private final LoadBalancerProvider defaultProvider; public AutoConfiguredLoadBalancerFactory(String defaultPolicy) { this(LoadBalancerRegistry.getDefaultRegistry(), defaultPolicy); @@ -54,47 +50,34 @@ public AutoConfiguredLoadBalancerFactory(String defaultPolicy) { @VisibleForTesting AutoConfiguredLoadBalancerFactory(LoadBalancerRegistry registry, String defaultPolicy) { this.registry = checkNotNull(registry, "registry"); - this.defaultPolicy = checkNotNull(defaultPolicy, "defaultPolicy"); + LoadBalancerProvider provider = + registry.getProvider(checkNotNull(defaultPolicy, "defaultPolicy")); + if (provider == null) { + Status status = Status.INTERNAL.withDescription("Could not find policy '" + defaultPolicy + + "'. Make sure its implementation is either registered to LoadBalancerRegistry or" + + " included in META-INF/services/io.grpc.LoadBalancerProvider from your jar files."); + provider = new FixedPickerLoadBalancerProvider( + ConnectivityState.TRANSIENT_FAILURE, + new LoadBalancer.FixedResultPicker(PickResult.withError(status)), + status); + } + this.defaultProvider = provider; } + @Override public AutoConfiguredLoadBalancer newLoadBalancer(Helper helper) { return new AutoConfiguredLoadBalancer(helper); } - private static final class NoopLoadBalancer extends LoadBalancer { - - @Override - @Deprecated - @SuppressWarnings("InlineMeSuggester") - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { - } - - @Override - public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { - return Status.OK; - } - - @Override - public void handleNameResolutionError(Status error) {} - - @Override - public void shutdown() {} - } - @VisibleForTesting - public final class AutoConfiguredLoadBalancer { + public final class AutoConfiguredLoadBalancer extends LoadBalancer { private final Helper helper; private LoadBalancer delegate; private LoadBalancerProvider delegateProvider; AutoConfiguredLoadBalancer(Helper helper) { this.helper = helper; - delegateProvider = registry.getProvider(defaultPolicy); - if (delegateProvider == null) { - throw new IllegalStateException("Could not find policy '" + defaultPolicy - + "'. Make sure its implementation is either registered to LoadBalancerRegistry or" - + " included in META-INF/services/io.grpc.LoadBalancerProvider from your jar files."); - } + this.delegateProvider = defaultProvider; delegate = delegateProvider.newLoadBalancer(helper); } @@ -102,29 +85,20 @@ public final class AutoConfiguredLoadBalancer { * Returns non-OK status if the delegate rejects the resolvedAddresses (e.g. if it does not * support an empty list). */ - Status tryAcceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + @Override + public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { PolicySelection policySelection = (PolicySelection) resolvedAddresses.getLoadBalancingPolicyConfig(); if (policySelection == null) { - LoadBalancerProvider defaultProvider; - try { - defaultProvider = getProviderOrThrow(defaultPolicy, "using default policy"); - } catch (PolicyException e) { - Status s = Status.INTERNAL.withDescription(e.getMessage()); - helper.updateBalancingState(ConnectivityState.TRANSIENT_FAILURE, new FailingPicker(s)); - delegate.shutdown(); - delegateProvider = null; - delegate = new NoopLoadBalancer(); - return Status.OK; - } policySelection = new PolicySelection(defaultProvider, /* config= */ null); } if (delegateProvider == null || !policySelection.provider.getPolicyName().equals(delegateProvider.getPolicyName())) { - helper.updateBalancingState(ConnectivityState.CONNECTING, new EmptyPicker()); + helper.updateBalancingState( + ConnectivityState.CONNECTING, new FixedResultPicker(PickResult.withNoResult())); delegate.shutdown(); delegateProvider = policySelection.provider; LoadBalancer old = delegate; @@ -147,20 +121,24 @@ Status tryAcceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { .build()); } - void handleNameResolutionError(Status error) { + @Override + public void handleNameResolutionError(Status error) { getDelegate().handleNameResolutionError(error); } + @Override @Deprecated - void handleSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) { + public void handleSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) { getDelegate().handleSubchannelState(subchannel, stateInfo); } - void requestConnection() { + @Override + public void requestConnection() { getDelegate().requestConnection(); } - void shutdown() { + @Override + public void shutdown() { delegate.shutdown(); delegate = null; } @@ -181,16 +159,6 @@ LoadBalancerProvider getDelegateProvider() { } } - private LoadBalancerProvider getProviderOrThrow(String policy, String choiceReason) - throws PolicyException { - LoadBalancerProvider provider = registry.getProvider(policy); - if (provider == null) { - throw new PolicyException( - "Trying to load '" + policy + "' because " + choiceReason + ", but it's unavailable"); - } - return provider; - } - /** * Parses first available LoadBalancer policy from service config. Available LoadBalancer should * be registered to {@link LoadBalancerRegistry}. If the first available LoadBalancer policy is @@ -211,8 +179,11 @@ private LoadBalancerProvider getProviderOrThrow(String policy, String choiceReas * * @return the parsed {@link PolicySelection}, or {@code null} if no selection could be made. */ + // TODO(ejona): The Provider API doesn't allow null, but ScParser can handle this and it will need + // tweaking to ManagedChannelImpl.defaultServiceConfig to fix. @Nullable - ConfigOrError parseLoadBalancerPolicy(Map serviceConfig) { + @Override + public ConfigOrError parseLoadBalancingPolicyConfig(Map serviceConfig) { try { List loadBalancerConfigs = null; if (serviceConfig != null) { @@ -230,38 +201,18 @@ ConfigOrError parseLoadBalancerPolicy(Map serviceConfig) { } } - @VisibleForTesting - static final class PolicyException extends Exception { - private static final long serialVersionUID = 1L; - - private PolicyException(String msg) { - super(msg); - } + @Override + public boolean isAvailable() { + return true; } - private static final class EmptyPicker extends SubchannelPicker { - - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return PickResult.withNoResult(); - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(EmptyPicker.class).toString(); - } + @Override + public int getPriority() { + return 5; } - private static final class FailingPicker extends SubchannelPicker { - private final Status failure; - - FailingPicker(Status failure) { - this.failure = failure; - } - - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return PickResult.withError(failure); - } + @Override + public String getPolicyName() { + return "auto_configured_internal"; } } diff --git a/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java b/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java index 42631851974..97a74bda97e 100644 --- a/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java +++ b/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java @@ -19,6 +19,7 @@ import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.Preconditions.checkNotNull; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.CallCredentials; import io.grpc.CallCredentials.RequestInfo; @@ -38,7 +39,6 @@ import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.atomic.AtomicInteger; -import javax.annotation.concurrent.GuardedBy; final class CallCredentialsApplyingTransportFactory implements ClientTransportFactory { private final ClientTransportFactory delegate; diff --git a/core/src/main/java/io/grpc/internal/CertificateUtils.java b/core/src/main/java/io/grpc/internal/CertificateUtils.java new file mode 100644 index 00000000000..130a435bb1a --- /dev/null +++ b/core/src/main/java/io/grpc/internal/CertificateUtils.java @@ -0,0 +1,106 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.security.GeneralSecurityException; +import java.security.KeyStore; +import java.security.cert.Certificate; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.util.Collection; +import java.util.List; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509TrustManager; +import javax.security.auth.x500.X500Principal; + +/** + * Contains certificate/key PEM file utility method(s) for internal usage. + */ +public final class CertificateUtils { + private static final Class x509ExtendedTrustManagerClass; + + static { + Class x509ExtendedTrustManagerClass1; + try { + x509ExtendedTrustManagerClass1 = Class.forName("javax.net.ssl.X509ExtendedTrustManager"); + } catch (ClassNotFoundException e) { + x509ExtendedTrustManagerClass1 = null; + // Will disallow per-rpc authority override via call option. + } + x509ExtendedTrustManagerClass = x509ExtendedTrustManagerClass1; + } + + /** + * Creates X509TrustManagers using the provided CA certs. + */ + public static TrustManager[] createTrustManager(byte[] rootCerts) + throws GeneralSecurityException { + InputStream rootCertsStream = new ByteArrayInputStream(rootCerts); + try { + return CertificateUtils.createTrustManager(rootCertsStream); + } finally { + GrpcUtil.closeQuietly(rootCertsStream); + } + } + + /** + * Creates X509TrustManagers using the provided input stream of CA certs. + */ + public static TrustManager[] createTrustManager(InputStream rootCerts) + throws GeneralSecurityException { + KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType()); + try { + ks.load(null, null); + } catch (IOException ex) { + // Shouldn't really happen, as we're not loading any data. + throw new GeneralSecurityException(ex); + } + X509Certificate[] certs = CertificateUtils.getX509Certificates(rootCerts); + for (X509Certificate cert : certs) { + X500Principal principal = cert.getSubjectX500Principal(); + ks.setCertificateEntry(principal.getName("RFC2253"), cert); + } + + TrustManagerFactory trustManagerFactory = + TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init(ks); + return trustManagerFactory.getTrustManagers(); + } + + public static X509TrustManager getX509ExtendedTrustManager(List trustManagers) { + if (x509ExtendedTrustManagerClass != null) { + for (TrustManager trustManager : trustManagers) { + if (x509ExtendedTrustManagerClass.isInstance(trustManager)) { + return (X509TrustManager) trustManager; + } + } + } + return null; + } + + private static X509Certificate[] getX509Certificates(InputStream inputStream) + throws CertificateException { + CertificateFactory factory = CertificateFactory.getInstance("X.509"); + Collection certs = factory.generateCertificates(inputStream); + return certs.toArray(new X509Certificate[0]); + } +} diff --git a/core/src/main/java/io/grpc/internal/ChannelTracer.java b/core/src/main/java/io/grpc/internal/ChannelTracer.java index 8c8243c9021..a9730a365cc 100644 --- a/core/src/main/java/io/grpc/internal/ChannelTracer.java +++ b/core/src/main/java/io/grpc/internal/ChannelTracer.java @@ -18,6 +18,7 @@ import static com.google.common.base.Preconditions.checkNotNull; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.ChannelLogger; import io.grpc.InternalChannelz.ChannelStats; import io.grpc.InternalChannelz.ChannelTrace; @@ -31,7 +32,6 @@ import java.util.logging.LogRecord; import java.util.logging.Logger; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; /** * Tracks a collections of channel tracing events for a channel/subchannel. diff --git a/core/src/main/java/io/grpc/internal/ClientCallImpl.java b/core/src/main/java/io/grpc/internal/ClientCallImpl.java index e2176668b73..4b24b1eae3d 100644 --- a/core/src/main/java/io/grpc/internal/ClientCallImpl.java +++ b/core/src/main/java/io/grpc/internal/ClientCallImpl.java @@ -28,7 +28,6 @@ import static io.grpc.internal.GrpcUtil.CONTENT_LENGTH_KEY; import static io.grpc.internal.GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY; import static io.grpc.internal.GrpcUtil.MESSAGE_ENCODING_KEY; -import static java.lang.Math.max; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; @@ -62,6 +61,7 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; @@ -82,16 +82,13 @@ final class ClientCallImpl extends ClientCall { private final boolean callExecutorIsDirect; private final CallTracer channelCallsTracer; private final Context context; - private volatile ScheduledFuture deadlineCancellationFuture; + private CancellationHandler cancellationHandler; private final boolean unaryRequest; private CallOptions callOptions; private ClientStream stream; - private volatile boolean cancelListenersShouldBeRemoved; private boolean cancelCalled; private boolean halfCloseCalled; private final ClientStreamProvider clientStreamProvider; - private final ContextCancellationListener cancellationListener = - new ContextCancellationListener(); private final ScheduledExecutorService deadlineCancellationExecutor; private boolean fullStreamDecompression; private DecompressorRegistry decompressorRegistry = DecompressorRegistry.getDefaultInstance(); @@ -128,13 +125,6 @@ final class ClientCallImpl extends ClientCall { PerfMark.event("ClientCall.", tag); } - private final class ContextCancellationListener implements CancellationListener { - @Override - public void cancelled(Context context) { - stream.cancel(statusFromCancelled(context)); - } - } - /** * Provider of {@link ClientStream}s. */ @@ -252,21 +242,22 @@ public void runInContext() { prepareHeaders(headers, decompressorRegistry, compressor, fullStreamDecompression); Deadline effectiveDeadline = effectiveDeadline(); - boolean deadlineExceeded = effectiveDeadline != null && effectiveDeadline.isExpired(); + boolean contextIsDeadlineSource = effectiveDeadline != null + && effectiveDeadline.equals(context.getDeadline()); + cancellationHandler = new CancellationHandler(effectiveDeadline, contextIsDeadlineSource); + boolean deadlineExceeded = effectiveDeadline != null && cancellationHandler.remainingNanos <= 0; if (!deadlineExceeded) { - logIfContextNarrowedTimeout( - effectiveDeadline, context.getDeadline(), callOptions.getDeadline()); stream = clientStreamProvider.newStream(method, callOptions, headers, context); } else { ClientStreamTracer[] tracers = - GrpcUtil.getClientStreamTracers(callOptions, headers, 0, false); - String deadlineName = - isFirstMin(callOptions.getDeadline(), context.getDeadline()) ? "CallOptions" : "Context"; + GrpcUtil.getClientStreamTracers(callOptions, headers, 0, + false, false); + String deadlineName = contextIsDeadlineSource ? "Context" : "CallOptions"; Long nameResolutionDelay = callOptions.getOption(NAME_RESOLUTION_DELAYED); String description = String.format( "ClientCall started after %s deadline was exceeded %.9f seconds ago. " + "Name resolution delay %.9f seconds.", deadlineName, - effectiveDeadline.timeRemaining(TimeUnit.NANOSECONDS) / NANO_TO_SECS, + cancellationHandler.remainingNanos / NANO_TO_SECS, nameResolutionDelay == null ? 0 : nameResolutionDelay / NANO_TO_SECS); stream = new FailingClientStream(DEADLINE_EXCEEDED.withDescription(description), tracers); } @@ -298,21 +289,7 @@ public void runInContext() { // they receive cancel before start. Issue #1343 has more details // Propagate later Context cancellation to the remote side. - context.addListener(cancellationListener, directExecutor()); - if (effectiveDeadline != null - // If the context has the effective deadline, we don't need to schedule an extra task. - && !effectiveDeadline.equals(context.getDeadline()) - // If the channel has been terminated, we don't need to schedule an extra task. - && deadlineCancellationExecutor != null) { - deadlineCancellationFuture = startDeadlineTimer(effectiveDeadline); - } - if (cancelListenersShouldBeRemoved) { - // Race detected! ClientStreamListener.closed may have been called before - // deadlineCancellationFuture was set / context listener added, thereby preventing the future - // and listener from being cancelled. Go ahead and cancel again, just to be sure it - // was cancelled. - removeContextListenerAndCancelDeadlineFuture(); - } + cancellationHandler.setUp(); } private void applyMethodConfig() { @@ -354,54 +331,77 @@ private void applyMethodConfig() { } } - private static void logIfContextNarrowedTimeout( - Deadline effectiveDeadline, @Nullable Deadline outerCallDeadline, - @Nullable Deadline callDeadline) { - if (!log.isLoggable(Level.FINE) || effectiveDeadline == null - || !effectiveDeadline.equals(outerCallDeadline)) { - return; + private final class CancellationHandler implements Runnable, CancellationListener { + private final boolean contextIsDeadlineSource; + private final boolean hasDeadline; + private final long remainingNanos; + private volatile ScheduledFuture deadlineCancellationFuture; + private volatile boolean tearDownCalled; + + CancellationHandler(Deadline deadline, boolean contextIsDeadlineSource) { + this.contextIsDeadlineSource = contextIsDeadlineSource; + if (deadline == null) { + hasDeadline = false; + remainingNanos = 0; + } else { + hasDeadline = true; + remainingNanos = deadline.timeRemaining(TimeUnit.NANOSECONDS); + } } - long effectiveTimeout = max(0, effectiveDeadline.timeRemaining(TimeUnit.NANOSECONDS)); - StringBuilder builder = new StringBuilder(String.format( - Locale.US, - "Call timeout set to '%d' ns, due to context deadline.", effectiveTimeout)); - if (callDeadline == null) { - builder.append(" Explicit call timeout was not set."); - } else { - long callTimeout = callDeadline.timeRemaining(TimeUnit.NANOSECONDS); - builder.append(String.format(Locale.US, " Explicit call timeout was '%d' ns.", callTimeout)); + void setUp() { + if (tearDownCalled) { + return; + } + if (hasDeadline + // If the context has the effective deadline, we don't need to schedule an extra task. + && !contextIsDeadlineSource + // If the channel has been terminated, we don't need to schedule an extra task. + && deadlineCancellationExecutor != null) { + deadlineCancellationFuture = deadlineCancellationExecutor.schedule( + new LogExceptionRunnable(this), remainingNanos, TimeUnit.NANOSECONDS); + } + context.addListener(this, directExecutor()); + if (tearDownCalled) { + // Race detected! Re-run to make sure the future is cancelled and context listener removed + tearDown(); + } } - log.fine(builder.toString()); - } - - private void removeContextListenerAndCancelDeadlineFuture() { - context.removeListener(cancellationListener); - ScheduledFuture f = deadlineCancellationFuture; - if (f != null) { - f.cancel(false); + // May be called multiple times, and race with setUp() + void tearDown() { + tearDownCalled = true; + ScheduledFuture deadlineCancellationFuture = this.deadlineCancellationFuture; + if (deadlineCancellationFuture != null) { + deadlineCancellationFuture.cancel(false); + } + context.removeListener(this); } - } - - private class DeadlineTimer implements Runnable { - private final long remainingNanos; - DeadlineTimer(long remainingNanos) { - this.remainingNanos = remainingNanos; + @Override + public void cancelled(Context context) { + if (hasDeadline && contextIsDeadlineSource + && context.cancellationCause() instanceof TimeoutException) { + stream.cancel(formatDeadlineExceededStatus()); + return; + } + stream.cancel(statusFromCancelled(context)); } @Override public void run() { - InsightBuilder insight = new InsightBuilder(); - stream.appendTimeoutInsight(insight); + stream.cancel(formatDeadlineExceededStatus()); + } + + Status formatDeadlineExceededStatus() { // DelayedStream.cancel() is safe to call from a thread that is different from where the // stream is created. long seconds = Math.abs(remainingNanos) / TimeUnit.SECONDS.toNanos(1); long nanos = Math.abs(remainingNanos) % TimeUnit.SECONDS.toNanos(1); StringBuilder buf = new StringBuilder(); - buf.append("deadline exceeded after "); + buf.append(contextIsDeadlineSource ? "Context" : "CallOptions"); + buf.append(" deadline exceeded after "); if (remainingNanos < 0) { buf.append('-'); } @@ -409,20 +409,18 @@ public void run() { buf.append(String.format(Locale.US, ".%09d", nanos)); buf.append("s. "); Long nsDelay = callOptions.getOption(NAME_RESOLUTION_DELAYED); - buf.append(String.format(Locale.US, "Name resolution delay %.9f seconds. ", + buf.append(String.format(Locale.US, "Name resolution delay %.9f seconds.", nsDelay == null ? 0 : nsDelay / NANO_TO_SECS)); - buf.append(insight); - stream.cancel(DEADLINE_EXCEEDED.augmentDescription(buf.toString())); + if (stream != null) { + InsightBuilder insight = new InsightBuilder(); + stream.appendTimeoutInsight(insight); + buf.append(" "); + buf.append(insight); + } + return DEADLINE_EXCEEDED.withDescription(buf.toString()); } } - private ScheduledFuture startDeadlineTimer(Deadline deadline) { - long remainingNanos = deadline.timeRemaining(TimeUnit.NANOSECONDS); - return deadlineCancellationExecutor.schedule( - new LogExceptionRunnable( - new DeadlineTimer(remainingNanos)), remainingNanos, TimeUnit.NANOSECONDS); - } - @Nullable private Deadline effectiveDeadline() { // Call options and context are immutable, so we don't need to cache the deadline. @@ -440,16 +438,6 @@ private static Deadline min(@Nullable Deadline deadline0, @Nullable Deadline dea return deadline0.minimum(deadline1); } - private static boolean isFirstMin(@Nullable Deadline deadline0, @Nullable Deadline deadline1) { - if (deadline0 == null) { - return false; - } - if (deadline1 == null) { - return true; - } - return deadline0.isBefore(deadline1); - } - @Override public void request(int numMessages) { try (TaskCloseable ignore = PerfMark.traceTask("ClientCall.request")) { @@ -493,7 +481,10 @@ private void cancelInternal(@Nullable String message, @Nullable Throwable cause) stream.cancel(status); } } finally { - removeContextListenerAndCancelDeadlineFuture(); + // start() might not have been called + if (cancellationHandler != null) { + cancellationHandler.tearDown(); + } } } @@ -571,7 +562,11 @@ public Attributes getAttributes() { } private void closeObserver(Listener observer, Status status, Metadata trailers) { - observer.onClose(status, trailers); + try { + observer.onClose(status, trailers); + } catch (RuntimeException ex) { + log.log(Level.WARNING, "Exception thrown by onClose() in ClientCall", ex); + } } @Override @@ -699,10 +694,7 @@ private void closedInternal( // description. Since our timer may be delayed in firing, we double-check the deadline and // turn the failure into the likely more helpful DEADLINE_EXCEEDED status. if (deadline.isExpired()) { - InsightBuilder insight = new InsightBuilder(); - stream.appendTimeoutInsight(insight); - status = DEADLINE_EXCEEDED.augmentDescription( - "ClientCall was cancelled at or after deadline. " + insight); + status = cancellationHandler.formatDeadlineExceededStatus(); // Replace trailers to prevent mixing sources of status and trailers. trailers = new Metadata(); } @@ -725,6 +717,7 @@ public void runInContext() { } private void runInternal() { + cancellationHandler.tearDown(); Status status = savedStatus; Metadata trailers = savedTrailers; if (exceptionStatus != null) { @@ -737,11 +730,9 @@ private void runInternal() { // Replace trailers to prevent mixing sources of status and trailers. trailers = new Metadata(); } - cancelListenersShouldBeRemoved = true; try { closeObserver(observer, status, trailers); } finally { - removeContextListenerAndCancelDeadlineFuture(); channelCallsTracer.reportCallEnded(status.isOk()); } } diff --git a/core/src/main/java/io/grpc/internal/ClientStreamListener.java b/core/src/main/java/io/grpc/internal/ClientStreamListener.java index 8db1fbe445f..85a0626d426 100644 --- a/core/src/main/java/io/grpc/internal/ClientStreamListener.java +++ b/core/src/main/java/io/grpc/internal/ClientStreamListener.java @@ -53,11 +53,12 @@ public interface ClientStreamListener extends StreamListener { */ enum RpcProgress { /** - * The RPC is processed by the server normally. + * The RPC may have been processed by the server. */ PROCESSED, /** - * The stream on the wire is created but not processed by the server's application logic. + * Some part of the RPC may have been sent, but the server has guaranteed it didn't process any + * part of the RPC. */ REFUSED, /** diff --git a/core/src/main/java/io/grpc/internal/ClientTransport.java b/core/src/main/java/io/grpc/internal/ClientTransport.java index a569a7922df..3e2c2aea247 100644 --- a/core/src/main/java/io/grpc/internal/ClientTransport.java +++ b/core/src/main/java/io/grpc/internal/ClientTransport.java @@ -22,16 +22,17 @@ import io.grpc.InternalInstrumented; import io.grpc.Metadata; import io.grpc.MethodDescriptor; +import io.grpc.Status; import java.util.concurrent.Executor; -import javax.annotation.concurrent.ThreadSafe; /** * The client-side transport typically encapsulating a single connection to a remote * server. However, streams created before the client has discovered any server address may * eventually be issued on different connections. All methods on the transport and its callbacks * are expected to execute quickly. + * + *

This interface is thread-safe. */ -@ThreadSafe public interface ClientTransport extends InternalInstrumented { /** @@ -61,7 +62,7 @@ ClientStream newStream( * Pings a remote endpoint. When an acknowledgement is received, the given callback will be * invoked using the given executor. * - *

Pings are not necessarily sent to the same endpont, thus a successful ping only means at + *

Pings are not necessarily sent to the same endpoint, thus a successful ping only means at * least one endpoint responded, but doesn't imply the availability of other endpoints (if there * is any). * @@ -90,6 +91,6 @@ interface PingCallback { * * @param cause the cause of the ping failure */ - void onFailure(Throwable cause); + void onFailure(Status cause); } } diff --git a/core/src/main/java/io/grpc/internal/ClientTransportFactory.java b/core/src/main/java/io/grpc/internal/ClientTransportFactory.java index d987f9d5068..6023fb14aa9 100644 --- a/core/src/main/java/io/grpc/internal/ClientTransportFactory.java +++ b/core/src/main/java/io/grpc/internal/ClientTransportFactory.java @@ -18,16 +18,17 @@ import com.google.common.base.Objects; import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.Attributes; import io.grpc.CallCredentials; import io.grpc.ChannelCredentials; import io.grpc.ChannelLogger; import io.grpc.HttpConnectProxiedSocketAddress; +import io.grpc.MetricRecorder; import java.io.Closeable; import java.net.SocketAddress; import java.util.Collection; import java.util.concurrent.ScheduledExecutorService; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; /** Pre-configured factory for creating {@link ConnectionClientTransport} instances. */ @@ -91,6 +92,8 @@ final class ClientTransportOptions { private Attributes eagAttributes = Attributes.EMPTY; @Nullable private String userAgent; @Nullable private HttpConnectProxiedSocketAddress connectProxiedSocketAddr; + private MetricRecorder metricRecorder = new MetricRecorder() { + }; public ChannelLogger getChannelLogger() { return channelLogger; @@ -101,6 +104,15 @@ public ClientTransportOptions setChannelLogger(ChannelLogger channelLogger) { return this; } + public MetricRecorder getMetricRecorder() { + return metricRecorder; + } + + public ClientTransportOptions setMetricRecorder(MetricRecorder metricRecorder) { + this.metricRecorder = Preconditions.checkNotNull(metricRecorder, "metricRecorder"); + return this; + } + public String getAuthority() { return authority; } diff --git a/core/src/main/java/io/grpc/internal/CompositeReadableBuffer.java b/core/src/main/java/io/grpc/internal/CompositeReadableBuffer.java index ea654c5b9ba..6cedb2caee9 100644 --- a/core/src/main/java/io/grpc/internal/CompositeReadableBuffer.java +++ b/core/src/main/java/io/grpc/internal/CompositeReadableBuffer.java @@ -18,12 +18,10 @@ import java.io.IOException; import java.io.OutputStream; -import java.nio.Buffer; import java.nio.ByteBuffer; import java.nio.InvalidMarkException; import java.util.ArrayDeque; import java.util.Deque; -import java.util.Queue; import javax.annotation.Nullable; /** @@ -39,7 +37,6 @@ public class CompositeReadableBuffer extends AbstractReadableBuffer { private final Deque readableBuffers; private Deque rewindableBuffers; private int readableBytes; - private final Queue buffers = new ArrayDeque(2); private boolean marked; public CompositeReadableBuffer(int initialCapacity) { @@ -122,30 +119,6 @@ public int read(ReadableBuffer buffer, int length, byte[] dest, int offset) { } }; - @Override - public void readBytes(byte[] dest, int destOffset, int length) { - executeNoThrow(BYTE_ARRAY_OP, length, dest, destOffset); - } - - private static final NoThrowReadOperation BYTE_BUF_OP = - new NoThrowReadOperation() { - @Override - public int read(ReadableBuffer buffer, int length, ByteBuffer dest, int unused) { - // Change the limit so that only lengthToCopy bytes are available. - int prevLimit = dest.limit(); - ((Buffer) dest).limit(dest.position() + length); - // Write the bytes and restore the original limit. - buffer.readBytes(dest); - ((Buffer) dest).limit(prevLimit); - return 0; - } - }; - - @Override - public void readBytes(ByteBuffer dest) { - executeNoThrow(BYTE_BUF_OP, dest.remaining(), dest, 0); - } - private static final ReadOperation STREAM_OP = new ReadOperation() { @Override @@ -157,33 +130,13 @@ public int read(ReadableBuffer buffer, int length, OutputStream dest, int unused }; @Override - public void readBytes(OutputStream dest, int length) throws IOException { - execute(STREAM_OP, length, dest, 0); + public void readBytes(byte[] dest, int destOffset, int length) { + executeNoThrow(BYTE_ARRAY_OP, length, dest, destOffset); } - /** - * Reads {@code length} bytes from this buffer and writes them to the destination buffer. - * Increments the read position by {@code length}. If the required bytes are not readable, throws - * {@link IndexOutOfBoundsException}. - * - * @param dest the destination buffer to receive the bytes. - * @param length the number of bytes to be copied. - * @throws IndexOutOfBoundsException if required bytes are not readable - */ - public void readBytes(CompositeReadableBuffer dest, int length) { - checkReadable(length); - readableBytes -= length; - - while (length > 0) { - ReadableBuffer buffer = buffers.peek(); - if (buffer.readableBytes() > length) { - dest.addBuffer(buffer.readBytes(length)); - length = 0; - } else { - dest.addBuffer(buffers.poll()); - length -= buffer.readableBytes(); - } - } + @Override + public void readBytes(OutputStream dest, int length) throws IOException { + execute(STREAM_OP, length, dest, 0); } @Override diff --git a/core/src/main/java/io/grpc/internal/ConcurrentTimeProvider.java b/core/src/main/java/io/grpc/internal/ConcurrentTimeProvider.java new file mode 100644 index 00000000000..c82a68222b4 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/ConcurrentTimeProvider.java @@ -0,0 +1,32 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import java.util.concurrent.TimeUnit; + +/** + * {@link ConcurrentTimeProvider} resolves ConcurrentTimeProvider which implements + * {@link TimeProvider}. + */ + +final class ConcurrentTimeProvider implements TimeProvider { + + @Override + public long currentTimeNanos() { + return TimeUnit.MILLISECONDS.toNanos(System.currentTimeMillis()); + } +} diff --git a/core/src/main/java/io/grpc/internal/ConnectionClientTransport.java b/core/src/main/java/io/grpc/internal/ConnectionClientTransport.java index 8385316d608..6199d9dad4d 100644 --- a/core/src/main/java/io/grpc/internal/ConnectionClientTransport.java +++ b/core/src/main/java/io/grpc/internal/ConnectionClientTransport.java @@ -17,12 +17,12 @@ package io.grpc.internal; import io.grpc.Attributes; -import javax.annotation.concurrent.ThreadSafe; /** * A {@link ManagedClientTransport} that is based on a connection. + * + *

This interface is thread-safe. */ -@ThreadSafe public interface ConnectionClientTransport extends ManagedClientTransport { /** * Returns a set of attributes, which may vary depending on the state of the transport. The keys diff --git a/core/src/main/java/io/grpc/internal/DelayedClientCall.java b/core/src/main/java/io/grpc/internal/DelayedClientCall.java index 92034e83f45..e0c05ca637e 100644 --- a/core/src/main/java/io/grpc/internal/DelayedClientCall.java +++ b/core/src/main/java/io/grpc/internal/DelayedClientCall.java @@ -22,6 +22,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.ClientCall; import io.grpc.Context; @@ -38,7 +39,6 @@ import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; /** * A call that queues requests before a real call is ready to be delegated to. @@ -64,6 +64,8 @@ public class DelayedClientCall extends ClientCall { * order, but also used if an error occurs before {@code realCall} is set. */ private Listener listener; + // No need to synchronize; start() synchronization provides a happens-before + private Metadata startHeaders; // Must hold {@code this} lock when setting. private ClientCall realCall; @GuardedBy("this") @@ -96,15 +98,13 @@ private boolean isAbeforeB(@Nullable Deadline a, @Nullable Deadline b) { private ScheduledFuture scheduleDeadlineIfNeeded( ScheduledExecutorService scheduler, @Nullable Deadline deadline) { Deadline contextDeadline = context.getDeadline(); - if (deadline == null && contextDeadline == null) { - return null; - } - long remainingNanos = Long.MAX_VALUE; - if (deadline != null) { + String deadlineName; + long remainingNanos; + if (deadline != null && isAbeforeB(deadline, contextDeadline)) { + deadlineName = "CallOptions"; remainingNanos = deadline.timeRemaining(NANOSECONDS); - } - - if (contextDeadline != null && contextDeadline.timeRemaining(NANOSECONDS) < remainingNanos) { + } else if (contextDeadline != null) { + deadlineName = "Context"; remainingNanos = contextDeadline.timeRemaining(NANOSECONDS); if (logger.isLoggable(Level.FINE)) { StringBuilder builder = @@ -121,29 +121,29 @@ private ScheduledFuture scheduleDeadlineIfNeeded( } logger.fine(builder.toString()); } - } - - long seconds = Math.abs(remainingNanos) / TimeUnit.SECONDS.toNanos(1); - long nanos = Math.abs(remainingNanos) % TimeUnit.SECONDS.toNanos(1); - final StringBuilder buf = new StringBuilder(); - String deadlineName = isAbeforeB(contextDeadline, deadline) ? "Context" : "CallOptions"; - if (remainingNanos < 0) { - buf.append("ClientCall started after "); - buf.append(deadlineName); - buf.append(" deadline was exceeded. Deadline has been exceeded for "); } else { - buf.append("Deadline "); - buf.append(deadlineName); - buf.append(" will be exceeded in "); + return null; } - buf.append(seconds); - buf.append(String.format(Locale.US, ".%09d", nanos)); - buf.append("s. "); /* Cancels the call if deadline exceeded prior to the real call being set. */ class DeadlineExceededRunnable implements Runnable { @Override public void run() { + long seconds = Math.abs(remainingNanos) / TimeUnit.SECONDS.toNanos(1); + long nanos = Math.abs(remainingNanos) % TimeUnit.SECONDS.toNanos(1); + StringBuilder buf = new StringBuilder(); + if (remainingNanos < 0) { + buf.append("ClientCall started after "); + buf.append(deadlineName); + buf.append(" deadline was exceeded. Deadline has been exceeded for "); + } else { + buf.append("Deadline "); + buf.append(deadlineName); + buf.append(" was exceeded after "); + } + buf.append(seconds); + buf.append(String.format(Locale.US, ".%09d", nanos)); + buf.append("s"); cancel( Status.DEADLINE_EXCEEDED.withDescription(buf.toString()), // We should not cancel the call if the realCall is set because there could be a @@ -163,13 +163,23 @@ public void run() { */ // When this method returns, passThrough is guaranteed to be true public final Runnable setCall(ClientCall call) { + Listener savedDelayedListener; synchronized (this) { // If realCall != null, then either setCall() or cancel() has been called. if (realCall != null) { return null; } setRealCall(checkNotNull(call, "call")); + // start() not yet called + if (delayedListener == null) { + assert pendingRunnables.isEmpty(); + pendingRunnables = null; + passThrough = true; + return null; + } + savedDelayedListener = this.delayedListener; } + internalStart(savedDelayedListener); return new ContextRunnable(context) { @Override public void runInContext() { @@ -178,8 +188,15 @@ public void runInContext() { }; } + private void internalStart(Listener listener) { + Metadata savedStartHeaders = this.startHeaders; + this.startHeaders = null; + context.run(() -> realCall.start(listener, savedStartHeaders)); + } + @Override public final void start(Listener listener, final Metadata headers) { + checkNotNull(headers, "headers"); checkState(this.listener == null, "already started"); Status savedError; boolean savedPassThrough; @@ -189,7 +206,8 @@ public final void start(Listener listener, final Metadata headers) { savedError = error; savedPassThrough = passThrough; if (!savedPassThrough) { - listener = delayedListener = new DelayedListener<>(listener); + listener = delayedListener = new DelayedListener<>(this, listener); + startHeaders = headers; } } if (savedError != null) { @@ -198,15 +216,7 @@ public final void start(Listener listener, final Metadata headers) { } if (savedPassThrough) { realCall.start(listener, headers); - } else { - final Listener finalListener = listener; - delayOrExecute(new Runnable() { - @Override - public void run() { - realCall.start(finalListener, headers); - } - }); - } + } // else realCall.start() will be called by setCall } // When this method returns, passThrough is guaranteed to be true @@ -255,6 +265,7 @@ public void run() { if (listenerToClose != null) { callExecutor.execute(new CloseListenerRunnable(listenerToClose, status)); } + internalStart(listenerToClose); // listener instance doesn't matter drainPendingCalls(); } callCancelled(); @@ -434,15 +445,33 @@ public void runInContext() { } private static final class DelayedListener extends Listener { + private final DelayedClientCall call; private final Listener realListener; private volatile boolean passThrough; + private volatile Status exceptionStatus; @GuardedBy("this") private List pendingCallbacks = new ArrayList<>(); - public DelayedListener(Listener listener) { + public DelayedListener(DelayedClientCall call, Listener listener) { + this.call = call; this.realListener = listener; } + /** + * Cancels call and schedules onClose() notification. May only be called from within a + * DelayedListener callback dispatch (either queued drain or passThrough). Visibility of the + * write to {@code exceptionStatus} does not rely on a single callback executor; it is a + * {@code volatile} field, and callback queuing/pass-through transitions are coordinated by + * this listener's synchronization so subsequent callbacks observe the updated status. + */ + private void exceptionThrown(Throwable t, String description) { + // onClose() must be delivered exactly once and last. Other callbacks may already be queued + // ahead of realCall's eventual onClose, so we can't call onClose() here. We set the status + // and overwrite the onClose() details when it arrives. + exceptionStatus = Status.CANCELLED.withCause(t).withDescription(description); + call.cancel(description, t); + } + private void delayOrExecute(Runnable runnable) { synchronized (this) { if (!passThrough) { @@ -456,37 +485,75 @@ private void delayOrExecute(Runnable runnable) { @Override public void onHeaders(final Metadata headers) { if (passThrough) { - realListener.onHeaders(headers); + deliverHeaders(headers); } else { delayOrExecute(new Runnable() { @Override public void run() { - realListener.onHeaders(headers); + deliverHeaders(headers); } }); } } + private void deliverHeaders(Metadata headers) { + if (exceptionStatus != null) { + return; + } + try { + realListener.onHeaders(headers); + } catch (Throwable t) { + exceptionThrown(t, "Failed to read headers"); + } + } + @Override public void onMessage(final RespT message) { if (passThrough) { - realListener.onMessage(message); + deliverMessage(message); } else { delayOrExecute(new Runnable() { @Override public void run() { - realListener.onMessage(message); + deliverMessage(message); } }); } } + private void deliverMessage(RespT message) { + if (exceptionStatus != null) { + return; + } + try { + realListener.onMessage(message); + } catch (Throwable t) { + exceptionThrown(t, "Failed to read message."); + } + } + @Override public void onClose(final Status status, final Metadata trailers) { delayOrExecute(new Runnable() { @Override public void run() { - realListener.onClose(status, trailers); + Status effectiveStatus = status; + Metadata effectiveTrailers = trailers; + if (exceptionStatus != null) { + // Ideally status matches exceptionStatus, since exceptionStatus was used to cancel + // the call. However, cancel() may reconstruct a new Status instance, and the cancel + // is racy so this onClose may have already been queued when the cancellation + // occurred. Since other callbacks throw away data if exceptionStatus != null, it is + // semantically essential that we _not_ use a status provided by the server. + effectiveStatus = exceptionStatus; + // Replace trailers to prevent mixing sources of status and trailers. + effectiveTrailers = new Metadata(); + } + try { + realListener.onClose(effectiveStatus, effectiveTrailers); + } catch (RuntimeException ex) { + logger.log(Level.WARNING, "Exception thrown by onClose() in ClientCall", ex); + } } }); } @@ -494,17 +561,28 @@ public void run() { @Override public void onReady() { if (passThrough) { - realListener.onReady(); + deliverOnReady(); } else { delayOrExecute(new Runnable() { @Override public void run() { - realListener.onReady(); + deliverOnReady(); } }); } } + private void deliverOnReady() { + if (exceptionStatus != null) { + return; + } + try { + realListener.onReady(); + } catch (Throwable t) { + exceptionThrown(t, "Failed to call onReady."); + } + } + void drainPendingCallbacks() { assert !passThrough; List toRun = new ArrayList<>(); @@ -524,7 +602,6 @@ void drainPendingCallbacks() { } for (Runnable runnable : toRun) { // Avoid calling listener while lock is held to prevent deadlocks. - // TODO(ejona): exception handling runnable.run(); } toRun.clear(); diff --git a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java index d71de1f5d53..5569e1eecf8 100644 --- a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java +++ b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java @@ -19,6 +19,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.CallOptions; import io.grpc.ClientStreamTracer; import io.grpc.Context; @@ -39,7 +40,6 @@ import java.util.concurrent.Executor; import javax.annotation.Nonnull; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; /** * A client transport that queues requests before a real transport is available. When {@link @@ -69,28 +69,13 @@ final class DelayedClientTransport implements ManagedClientTransport { @GuardedBy("lock") private Collection pendingStreams = new LinkedHashSet<>(); - /** - * When {@code shutdownStatus != null && !hasPendingStreams()}, then the transport is considered - * terminated. - */ - @GuardedBy("lock") - private Status shutdownStatus; - - /** - * The last picker that {@link #reprocess} has used. May be set to null when the channel has moved - * to idle. - */ - @GuardedBy("lock") - @Nullable - private SubchannelPicker lastPicker; - - @GuardedBy("lock") - private long lastPickerVersion; + /** Immutable state needed for picking. 'lock' must be held for writing. */ + private volatile PickerState pickerState = new PickerState(null, null); /** * Creates a new delayed transport. * - * @param defaultAppExecutor pending streams will create real streams and run bufferred operations + * @param defaultAppExecutor pending streams will create real streams and run buffered operations * in an application executor, which will be this executor, unless there is on provided in * {@link CallOptions}. * @param syncContext all listener callbacks of the delayed transport will be run from this @@ -137,34 +122,45 @@ public final ClientStream newStream( MethodDescriptor method, Metadata headers, CallOptions callOptions, ClientStreamTracer[] tracers) { try { - PickSubchannelArgs args = new PickSubchannelArgsImpl(method, headers, callOptions); - SubchannelPicker picker = null; - long pickerVersion = -1; + PickSubchannelArgs args = new PickSubchannelArgsImpl( + method, headers, callOptions, new PickDetailsConsumerImpl(tracers)); + PickerState state = pickerState; while (true) { - synchronized (lock) { - if (shutdownStatus != null) { - return new FailingClientStream(shutdownStatus, tracers); - } - if (lastPicker == null) { - return createPendingStream(args, tracers); + if (state.shutdownStatus != null) { + return new FailingClientStream(state.shutdownStatus, tracers); + } + PickResult pickResult = null; + if (state.lastPicker != null) { + pickResult = state.lastPicker.pickSubchannel(args); + callOptions = args.getCallOptions(); + // User code provided authority takes precedence over the LB provided one. + if (callOptions.getAuthority() == null + && pickResult.getAuthorityOverride() != null) { + callOptions = callOptions.withAuthority(pickResult.getAuthorityOverride()); } - // Check for second time through the loop, and whether anything changed - if (picker != null && pickerVersion == lastPickerVersion) { - return createPendingStream(args, tracers); + ClientTransport transport = GrpcUtil.getTransportFromPickResult(pickResult, + callOptions.isWaitForReady()); + if (transport != null) { + ClientStream stream = transport.newStream( + args.getMethodDescriptor(), args.getHeaders(), callOptions, + tracers); + // User code provided authority takes precedence over the LB provided one; this will be + // overwritten by ClientCallImpl if the application sets an authority override + if (pickResult.getAuthorityOverride() != null) { + stream.setAuthority(pickResult.getAuthorityOverride()); + } + return stream; } - picker = lastPicker; - pickerVersion = lastPickerVersion; - } - PickResult pickResult = picker.pickSubchannel(args); - ClientTransport transport = GrpcUtil.getTransportFromPickResult(pickResult, - callOptions.isWaitForReady()); - if (transport != null) { - return transport.newStream( - args.getMethodDescriptor(), args.getHeaders(), args.getCallOptions(), - tracers); } // This picker's conclusion is "buffer". If there hasn't been a newer picker set (possible - // race with reprocess()), we will buffer it. Otherwise, will try with the new picker. + // race with reprocess()), we will buffer the RPC. Otherwise, will try with the new picker. + synchronized (lock) { + PickerState newerState = pickerState; + if (state == newerState) { + return createPendingStream(args, tracers, pickResult); + } + state = newerState; + } } } finally { syncContext.drain(); @@ -176,9 +172,12 @@ public final ClientStream newStream( * schedule tasks on syncContext. */ @GuardedBy("lock") - private PendingStream createPendingStream( - PickSubchannelArgs args, ClientStreamTracer[] tracers) { + private PendingStream createPendingStream(PickSubchannelArgs args, ClientStreamTracer[] tracers, + PickResult pickResult) { PendingStream pendingStream = new PendingStream(args, tracers); + if (args.getCallOptions().isWaitForReady() && pickResult != null && pickResult.hasResult()) { + pendingStream.lastPickStatus = pickResult.getStatus(); + } pendingStreams.add(pendingStream); if (getPendingStreamsCount() == 1) { syncContext.executeLater(reportTransportInUse); @@ -202,21 +201,21 @@ public ListenableFuture getStats() { } /** - * Prevents creating any new streams. Buffered streams are not failed and may still proceed - * when {@link #reprocess} is called. The delayed transport will be terminated when there is no + * Prevents creating any new streams. Buffered streams are not failed and may still proceed + * when {@link #reprocess} is called. The delayed transport will be terminated when there is no * more buffered streams. */ @Override public final void shutdown(final Status status) { synchronized (lock) { - if (shutdownStatus != null) { + if (pickerState.shutdownStatus != null) { return; } - shutdownStatus = status; + pickerState = pickerState.withShutdownStatus(status); syncContext.executeLater(new Runnable() { @Override public void run() { - listener.transportShutdown(status); + listener.transportShutdown(status, SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); } }); if (!hasPendingStreams() && reportTransportTerminated != null) { @@ -287,8 +286,7 @@ final int getPendingStreamsCount() { final void reprocess(@Nullable SubchannelPicker picker) { ArrayList toProcess; synchronized (lock) { - lastPicker = picker; - lastPickerVersion++; + pickerState = pickerState.withPicker(picker); if (picker == null || !hasPendingStreams()) { return; } @@ -299,6 +297,9 @@ final void reprocess(@Nullable SubchannelPicker picker) { for (final PendingStream stream : toProcess) { PickResult pickResult = picker.pickSubchannel(stream.args); CallOptions callOptions = stream.args.getCallOptions(); + if (callOptions.isWaitForReady() && pickResult.hasResult()) { + stream.lastPickStatus = pickResult.getStatus(); + } final ClientTransport transport = GrpcUtil.getTransportFromPickResult(pickResult, callOptions.isWaitForReady()); if (transport != null) { @@ -309,7 +310,7 @@ final void reprocess(@Nullable SubchannelPicker picker) { if (callOptions.getExecutor() != null) { executor = callOptions.getExecutor(); } - Runnable runnable = stream.createRealStream(transport); + Runnable runnable = stream.createRealStream(transport, pickResult.getAuthorityOverride()); if (runnable != null) { executor.execute(runnable); } @@ -324,7 +325,11 @@ final void reprocess(@Nullable SubchannelPicker picker) { if (!hasPendingStreams()) { return; } - pendingStreams.removeAll(toRemove); + // Avoid pendingStreams.removeAll() as it can degrade to calling toRemove.contains() for each + // element in pendingStreams. + for (PendingStream stream : toRemove) { + pendingStreams.remove(stream); + } // Because delayed transport is long-lived, we take this opportunity to down-size the // hashmap. if (pendingStreams.isEmpty()) { @@ -337,7 +342,7 @@ final void reprocess(@Nullable SubchannelPicker picker) { // (which would shutdown the transports and LoadBalancer) because the gap should be shorter // than IDLE_MODE_DEFAULT_TIMEOUT_MILLIS (1 second). syncContext.executeLater(reportTransportNotInUse); - if (shutdownStatus != null && reportTransportTerminated != null) { + if (pickerState.shutdownStatus != null && reportTransportTerminated != null) { syncContext.executeLater(reportTransportTerminated); reportTransportTerminated = null; } @@ -355,14 +360,16 @@ private class PendingStream extends DelayedStream { private final PickSubchannelArgs args; private final Context context = Context.current(); private final ClientStreamTracer[] tracers; + private volatile Status lastPickStatus; private PendingStream(PickSubchannelArgs args, ClientStreamTracer[] tracers) { + super("connecting_and_lb"); this.args = args; this.tracers = tracers; } /** Runnable may be null. */ - private Runnable createRealStream(ClientTransport transport) { + private Runnable createRealStream(ClientTransport transport, String authorityOverride) { ClientStream realStream; Context origContext = context.attach(); try { @@ -372,6 +379,13 @@ private Runnable createRealStream(ClientTransport transport) { } finally { context.detach(origContext); } + if (authorityOverride != null) { + // User code provided authority takes precedence over the LB provided one; this will be + // overwritten by an enqueud call from ClientCallImpl if the application sets an authority + // override. We must call the real stream directly because stream.start() has likely already + // been called on the delayed stream. + realStream.setAuthority(authorityOverride); + } return setStream(realStream); } @@ -383,7 +397,7 @@ public void cancel(Status reason) { boolean justRemovedAnElement = pendingStreams.remove(this); if (!hasPendingStreams() && justRemovedAnElement) { syncContext.executeLater(reportTransportNotInUse); - if (shutdownStatus != null) { + if (pickerState.shutdownStatus != null) { syncContext.executeLater(reportTransportTerminated); reportTransportTerminated = null; } @@ -404,8 +418,40 @@ protected void onEarlyCancellation(Status reason) { public void appendTimeoutInsight(InsightBuilder insight) { if (args.getCallOptions().isWaitForReady()) { insight.append("wait_for_ready"); + Status status = lastPickStatus; + if (status != null && !status.isOk()) { + insight.appendKeyValue("Last Pick Failure", status); + } } super.appendTimeoutInsight(insight); } } + + static final class PickerState { + /** + * The last picker that {@link #reprocess} has used. May be set to null when the channel has + * moved to idle. + */ + @Nullable + final SubchannelPicker lastPicker; + /** + * When {@code shutdownStatus != null && !hasPendingStreams()}, then the transport is considered + * terminated. + */ + @Nullable + final Status shutdownStatus; + + private PickerState(SubchannelPicker lastPicker, Status shutdownStatus) { + this.lastPicker = lastPicker; + this.shutdownStatus = shutdownStatus; + } + + public PickerState withPicker(SubchannelPicker newPicker) { + return new PickerState(newPicker, this.shutdownStatus); + } + + public PickerState withShutdownStatus(Status newShutdownStatus) { + return new PickerState(this.lastPicker, newShutdownStatus); + } + } } diff --git a/core/src/main/java/io/grpc/internal/DelayedStream.java b/core/src/main/java/io/grpc/internal/DelayedStream.java index 28ce2764c75..a2b1e963ac5 100644 --- a/core/src/main/java/io/grpc/internal/DelayedStream.java +++ b/core/src/main/java/io/grpc/internal/DelayedStream.java @@ -20,6 +20,8 @@ import static com.google.common.base.Preconditions.checkState; import com.google.common.annotations.VisibleForTesting; +import com.google.errorprone.annotations.CheckReturnValue; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.Compressor; import io.grpc.Deadline; @@ -30,8 +32,6 @@ import java.io.InputStream; import java.util.ArrayList; import java.util.List; -import javax.annotation.CheckReturnValue; -import javax.annotation.concurrent.GuardedBy; /** * A stream that queues requests before the transport is available, and delegates to a real stream @@ -42,11 +42,12 @@ * necessary. */ class DelayedStream implements ClientStream { + private final String bufferContext; /** {@code true} once realStream is valid and all pending calls have been drained. */ private volatile boolean passThrough; /** * Non-{@code null} iff start has been called. Used to assert methods are called in appropriate - * order, but also used if an error occurrs before {@code realStream} is set. + * order, but also used if an error occurs before {@code realStream} is set. */ private ClientStreamListener listener; /** Must hold {@code this} lock when setting. */ @@ -64,6 +65,14 @@ class DelayedStream implements ClientStream { // No need to synchronize; start() synchronization provides a happens-before private List preStartPendingCalls = new ArrayList<>(); + /** + * Create a delayed stream with debug context {@code bufferContext}. The context is what this + * stream is delayed by (e.g., "connecting", "call_credentials"). + */ + public DelayedStream(String bufferContext) { + this.bufferContext = checkNotNull(bufferContext, "bufferContext"); + } + @Override public void setMaxInboundMessageSize(final int maxSize) { checkState(listener == null, "May only be called before start"); @@ -104,11 +113,13 @@ public void appendTimeoutInsight(InsightBuilder insight) { return; } if (realStream != null) { - insight.appendKeyValue("buffered_nanos", streamSetTimeNanos - startTimeNanos); + insight.appendKeyValue( + bufferContext + "_delay", "" + (streamSetTimeNanos - startTimeNanos) + "ns"); realStream.appendTimeoutInsight(insight); } else { - insight.appendKeyValue("buffered_nanos", System.nanoTime() - startTimeNanos); - insight.append("waiting_for_connection"); + insight.appendKeyValue( + bufferContext + "_delay", "" + (System.nanoTime() - startTimeNanos) + "ns"); + insight.append("was_still_waiting"); } } } @@ -208,7 +219,6 @@ private void delayOrExecute(Runnable runnable) { @Override public void setAuthority(final String authority) { - checkState(listener == null, "May only be called before start"); checkNotNull(authority, "authority"); preStartPendingCalls.add(new Runnable() { @Override diff --git a/core/src/main/java/io/grpc/internal/DisconnectError.java b/core/src/main/java/io/grpc/internal/DisconnectError.java new file mode 100644 index 00000000000..771024f106e --- /dev/null +++ b/core/src/main/java/io/grpc/internal/DisconnectError.java @@ -0,0 +1,34 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import javax.annotation.concurrent.Immutable; + +/** + * Represents the reason for a subchannel disconnection. + * Implementations are either the SimpleDisconnectError enum or the GoAwayDisconnectError class for + * dynamic ones. + */ +@Immutable +public interface DisconnectError { + /** + * Returns the string representation suitable for use as an error tag. + * + * @return The formatted error tag string. + */ + String toErrorString(); +} diff --git a/core/src/main/java/io/grpc/internal/DnsNameResolver.java b/core/src/main/java/io/grpc/internal/DnsNameResolver.java index 5ef6dd863c2..1c1d95ed616 100644 --- a/core/src/main/java/io/grpc/internal/DnsNameResolver.java +++ b/core/src/main/java/io/grpc/internal/DnsNameResolver.java @@ -23,15 +23,14 @@ import com.google.common.base.Objects; import com.google.common.base.Preconditions; import com.google.common.base.Stopwatch; -import com.google.common.base.Throwables; import com.google.common.base.Verify; import com.google.common.base.VerifyException; -import io.grpc.Attributes; import io.grpc.EquivalentAddressGroup; import io.grpc.NameResolver; import io.grpc.ProxiedSocketAddress; import io.grpc.ProxyDetector; import io.grpc.Status; +import io.grpc.StatusOr; import io.grpc.SynchronizationContext; import io.grpc.internal.SharedResourceHolder.Resource; import java.io.IOException; @@ -59,7 +58,7 @@ * A DNS-based {@link NameResolver}. * *

Each {@code A} or {@code AAAA} record emits an {@link EquivalentAddressGroup} in the list - * passed to {@link NameResolver.Listener2#onResult(ResolutionResult)}. + * passed to {@link NameResolver.Listener2#onResult2(ResolutionResult)}. * * @see DnsNameResolverProvider */ @@ -100,7 +99,7 @@ public class DnsNameResolver extends NameResolver { * not installed, the ttl value is {@code null} which falls back to {@link * #DEFAULT_NETWORK_CACHE_TTL_SECONDS gRPC default value}. * - *

For android, gRPC doesn't attempt to cache; this property value will be ignored. + *

For android, gRPC uses a fixed value; this property value will be ignored. */ @VisibleForTesting static final String NETWORKADDRESS_CACHE_TTL_PROPERTY = "networkaddress.cache.ttl"; @@ -133,10 +132,10 @@ public class DnsNameResolver extends NameResolver { private final String host; private final int port; - /** Executor that will be used if an Executor is not provide via {@link NameResolver.Args}. */ - private final Resource executorResource; + private final ObjectPool executorPool; private final long cacheTtlNanos; private final SynchronizationContext syncContext; + private final ServiceConfigParser serviceConfigParser; // Following fields must be accessed from syncContext private final Stopwatch stopwatch; @@ -144,10 +143,6 @@ public class DnsNameResolver extends NameResolver { private boolean shutdown; private Executor executor; - /** True if using an executor resource that should be released after use. */ - private final boolean usingExecutorResource; - private final ServiceConfigParser serviceConfigParser; - private boolean resolving; // The field must be accessed from syncContext, although the methods on an Listener2 can be called @@ -164,7 +159,7 @@ protected DnsNameResolver( checkNotNull(args, "args"); // TODO: if a DNS server is provided as nsAuthority, use it. // https://www.captechconsulting.com/blogs/accessing-the-dusty-corners-of-dns-with-java - this.executorResource = executorResource; + // Must prepend a "//" to the name when constructing a URI, otherwise it will be treated as an // opaque URI, thus the authority and host of the resulted URI would be null. URI nameUri = URI.create("//" + checkNotNull(name, "name")); @@ -178,11 +173,15 @@ protected DnsNameResolver( port = nameUri.getPort(); } this.proxyDetector = checkNotNull(args.getProxyDetector(), "proxyDetector"); + Executor offloadExecutor = args.getOffloadExecutor(); + if (offloadExecutor != null) { + this.executorPool = new FixedObjectPool<>(offloadExecutor); + } else { + this.executorPool = SharedResourcePool.forResource(executorResource); + } this.cacheTtlNanos = getNetworkAddressCacheTtlNanos(isAndroid); this.stopwatch = checkNotNull(stopwatch, "stopwatch"); this.syncContext = checkNotNull(args.getSynchronizationContext(), "syncContext"); - this.executor = args.getOffloadExecutor(); - this.usingExecutorResource = executor == null; this.serviceConfigParser = checkNotNull(args.getServiceConfigParser(), "serviceConfigParser"); } @@ -199,9 +198,7 @@ protected String getHost() { @Override public void start(Listener2 listener) { Preconditions.checkState(this.listener == null, "already started"); - if (usingExecutorResource) { - executor = SharedResourceHolder.get(executorResource); - } + executor = executorPool.getObject(); this.listener = checkNotNull(listener, "listener"); resolve(); } @@ -212,20 +209,8 @@ public void refresh() { resolve(); } - private List resolveAddresses() { - List addresses; - Exception addressesException = null; - try { - addresses = addressResolver.resolveAddress(host); - } catch (Exception e) { - addressesException = e; - Throwables.throwIfUnchecked(e); - throw new RuntimeException(e); - } finally { - if (addressesException != null) { - logger.log(Level.FINE, "Address resolution failure", addressesException); - } - } + private List resolveAddresses() throws Exception { + List addresses = addressResolver.resolveAddress(host); // Each address forms an EAG List servers = new ArrayList<>(addresses.size()); for (InetAddress inetAddr : addresses) { @@ -276,21 +261,19 @@ private EquivalentAddressGroup detectProxy() throws IOException { /** * Main logic of name resolution. */ - protected InternalResolutionResult doResolve(boolean forceTxt) { - InternalResolutionResult result = new InternalResolutionResult(); + protected ResolutionResult doResolve() { + ResolutionResult.Builder resultBuilder = ResolutionResult.newBuilder(); try { - result.addresses = resolveAddresses(); + resultBuilder.setAddressesOrError(StatusOr.fromValue(resolveAddresses())); } catch (Exception e) { - if (!forceTxt) { - result.error = - Status.UNAVAILABLE.withDescription("Unable to resolve host " + host).withCause(e); - return result; - } + logger.log(Level.FINE, "Address resolution failure", e); + resultBuilder.setAddressesOrError(StatusOr.fromStatus( + Status.UNAVAILABLE.withDescription("Unable to resolve host " + host).withCause(e))); } if (enableTxt) { - result.config = resolveServiceConfig(); + resultBuilder.setServiceConfig(resolveServiceConfig()); } - return result; + return resultBuilder.build(); } private final class Resolve implements Runnable { @@ -305,37 +288,32 @@ public void run() { if (logger.isLoggable(Level.FINER)) { logger.finer("Attempting DNS resolution of " + host); } - InternalResolutionResult result = null; + ResolutionResult result = null; try { EquivalentAddressGroup proxiedAddr = detectProxy(); - ResolutionResult.Builder resolutionResultBuilder = ResolutionResult.newBuilder(); if (proxiedAddr != null) { if (logger.isLoggable(Level.FINER)) { logger.finer("Using proxy address " + proxiedAddr); } - resolutionResultBuilder.setAddresses(Collections.singletonList(proxiedAddr)); + result = ResolutionResult.newBuilder() + .setAddressesOrError(StatusOr.fromValue(Collections.singletonList(proxiedAddr))) + .build(); } else { - result = doResolve(false); - if (result.error != null) { - savedListener.onError(result.error); - return; - } - if (result.addresses != null) { - resolutionResultBuilder.setAddresses(result.addresses); - } - if (result.config != null) { - resolutionResultBuilder.setServiceConfig(result.config); - } - if (result.attributes != null) { - resolutionResultBuilder.setAttributes(result.attributes); - } + result = doResolve(); } - savedListener.onResult(resolutionResultBuilder.build()); + ResolutionResult savedResult = result; + syncContext.execute(() -> { + savedListener.onResult2(savedResult); + }); } catch (IOException e) { - savedListener.onError( - Status.UNAVAILABLE.withDescription("Unable to resolve host " + host).withCause(e)); + syncContext.execute(() -> + savedListener.onResult2(ResolutionResult.newBuilder() + .setAddressesOrError( + StatusOr.fromStatus( + Status.UNAVAILABLE.withDescription( + "Unable to resolve host " + host).withCause(e))).build())); } finally { - final boolean succeed = result != null && result.error == null; + final boolean succeed = result != null && result.getAddressesOrError().hasValue(); syncContext.execute(new Runnable() { @Override public void run() { @@ -401,8 +379,8 @@ public void shutdown() { return; } shutdown = true; - if (executor != null && usingExecutorResource) { - executor = SharedResourceHolder.release(executorResource, executor); + if (executor != null) { + executor = executorPool.returnObject(executor); } } @@ -453,12 +431,14 @@ private static final List getHostnamesFromChoice(Map serviceC /** * Returns value of network address cache ttl property if not Android environment. For android, - * DnsNameResolver does not cache the dns lookup result. + * DnsNameResolver uses a fixed value. */ private static long getNetworkAddressCacheTtlNanos(boolean isAndroid) { if (isAndroid) { - // on Android, ignore dns cache. - return 0; + // On Android, use fixed value. If the network used changes this value shouldn't matter, as + // channel.enterIdle() should be called and this name resolver instance will be discarded. The + // new name resolver instance will then re-request. + return TimeUnit.SECONDS.toNanos(DEFAULT_NETWORK_CACHE_TTL_SECONDS); } String cacheTtlPropertyValue = System.getProperty(NETWORKADDRESS_CACHE_TTL_PROPERTY); @@ -480,7 +460,7 @@ private static long getNetworkAddressCacheTtlNanos(boolean isAndroid) { * Determines if a given Service Config choice applies, and if so, returns it. * * @see - * Service Config in DNS + * Service Config in DNS * @param choice The service config choice. * @return The service config object or {@code null} if this choice does not apply. */ @@ -535,18 +515,6 @@ private static long getNetworkAddressCacheTtlNanos(boolean isAndroid) { return sc; } - /** - * Used as a DNS-based name resolver's internal representation of resolution result. - */ - protected static final class InternalResolutionResult { - private Status error; - private List addresses; - private ConfigOrError config; - public Attributes attributes; - - private InternalResolutionResult() {} - } - /** * Describes a parsed SRV record. */ diff --git a/core/src/main/java/io/grpc/internal/DnsNameResolverProvider.java b/core/src/main/java/io/grpc/internal/DnsNameResolverProvider.java index c977fbb0cca..14b56f1a12a 100644 --- a/core/src/main/java/io/grpc/internal/DnsNameResolverProvider.java +++ b/core/src/main/java/io/grpc/internal/DnsNameResolverProvider.java @@ -21,25 +21,31 @@ import io.grpc.InternalServiceProviders; import io.grpc.NameResolver; import io.grpc.NameResolverProvider; +import io.grpc.Uri; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.net.URI; import java.util.Collection; import java.util.Collections; +import java.util.List; /** * A provider for {@link DnsNameResolver}. * *

It resolves a target URI whose scheme is {@code "dns"}. The (optional) authority of the target - * URI is reserved for the address of alternative DNS server (not implemented yet). The path of the - * target URI, excluding the leading slash {@code '/'}, is treated as the host name and the optional - * port to be resolved by DNS. Example target URIs: + * URI is reserved for the address of alternative DNS server (not implemented yet). The target URI + * must be hierarchical and have exactly one path segment which will be interpreted as an RFC 2396 + * "server-based" authority and used as the "service authority" of the resulting {@link + * NameResolver}. The "host" part of this authority is the name to be resolved by DNS. The "port" + * part of this authority (if present) will become the port number for all {@link InetSocketAddress} + * produced by this resolver. For example: * *

    *
  • {@code "dns:///foo.googleapis.com:8080"} (using default DNS)
  • *
  • {@code "dns://8.8.8.8/foo.googleapis.com:8080"} (using alternative DNS (not implemented * yet))
  • - *
  • {@code "dns:///foo.googleapis.com"} (without port)
  • + *
  • {@code "dns:///foo.googleapis.com"} (output addresses will have port {@link + * NameResolver.Args#getDefaultPort()})
  • *
*/ public final class DnsNameResolverProvider extends NameResolverProvider { @@ -51,6 +57,7 @@ public final class DnsNameResolverProvider extends NameResolverProvider { @Override public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) { + // TODO(jdcormie): Remove once RFC 3986 migration is complete. if (SCHEME.equals(targetUri.getScheme())) { String targetPath = Preconditions.checkNotNull(targetUri.getPath(), "targetPath"); Preconditions.checkArgument(targetPath.startsWith("/"), @@ -68,6 +75,25 @@ public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) { } } + @Override + public NameResolver newNameResolver(Uri targetUri, final NameResolver.Args args) { + if (SCHEME.equals(targetUri.getScheme())) { + List pathSegments = targetUri.getPathSegments(); + Preconditions.checkArgument(!pathSegments.isEmpty(), + "expected 1 path segment in target %s but found %s", targetUri, pathSegments); + String domainNameToResolve = pathSegments.get(0); + return new DnsNameResolver( + targetUri.getAuthority(), + domainNameToResolve, + args, + GrpcUtil.SHARED_CHANNEL_EXECUTOR, + Stopwatch.createUnstarted(), + IS_ANDROID); + } else { + return null; + } + } + @Override public String getDefaultScheme() { return SCHEME; diff --git a/core/src/main/java/io/grpc/internal/FailingClientTransport.java b/core/src/main/java/io/grpc/internal/FailingClientTransport.java index 5b31e6e5073..37194c46a29 100644 --- a/core/src/main/java/io/grpc/internal/FailingClientTransport.java +++ b/core/src/main/java/io/grpc/internal/FailingClientTransport.java @@ -55,7 +55,7 @@ public ClientStream newStream( public void ping(final PingCallback callback, Executor executor) { executor.execute(new Runnable() { @Override public void run() { - callback.onFailure(error.asException()); + callback.onFailure(error); } }); } diff --git a/core/src/main/java/io/grpc/internal/FixedPickerLoadBalancerProvider.java b/core/src/main/java/io/grpc/internal/FixedPickerLoadBalancerProvider.java new file mode 100644 index 00000000000..a632948bdb9 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/FixedPickerLoadBalancerProvider.java @@ -0,0 +1,80 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static java.util.Objects.requireNonNull; + +import io.grpc.ConnectivityState; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancerProvider; +import io.grpc.Status; + +/** A LB provider whose LB always uses the same picker. */ +final class FixedPickerLoadBalancerProvider extends LoadBalancerProvider { + private final ConnectivityState state; + private final LoadBalancer.SubchannelPicker picker; + private final Status acceptAddressesStatus; + + public FixedPickerLoadBalancerProvider( + ConnectivityState state, LoadBalancer.SubchannelPicker picker, Status acceptAddressesStatus) { + this.state = requireNonNull(state, "state"); + this.picker = requireNonNull(picker, "picker"); + this.acceptAddressesStatus = requireNonNull(acceptAddressesStatus, "acceptAddressesStatus"); + } + + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int getPriority() { + return 5; + } + + @Override + public String getPolicyName() { + return "fixed_picker_lb_internal"; + } + + @Override + public LoadBalancer newLoadBalancer(LoadBalancer.Helper helper) { + return new FixedPickerLoadBalancer(helper); + } + + private final class FixedPickerLoadBalancer extends LoadBalancer { + private final Helper helper; + + public FixedPickerLoadBalancer(Helper helper) { + this.helper = requireNonNull(helper, "helper"); + } + + @Override + public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + helper.updateBalancingState(state, picker); + return acceptAddressesStatus; + } + + @Override + public void handleNameResolutionError(Status error) { + helper.updateBalancingState(state, picker); + } + + @Override + public void shutdown() {} + } +} diff --git a/core/src/main/java/io/grpc/internal/ForwardingClientStreamTracer.java b/core/src/main/java/io/grpc/internal/ForwardingClientStreamTracer.java index 4740a811f3a..e7679ea14cc 100644 --- a/core/src/main/java/io/grpc/internal/ForwardingClientStreamTracer.java +++ b/core/src/main/java/io/grpc/internal/ForwardingClientStreamTracer.java @@ -49,11 +49,21 @@ public void inboundHeaders() { delegate().inboundHeaders(); } + @Override + public void inboundHeaders(Metadata headers) { + delegate().inboundHeaders(headers); + } + @Override public void inboundTrailers(Metadata trailers) { delegate().inboundTrailers(trailers); } + @Override + public void addOptionalLabel(String key, String value) { + delegate().addOptionalLabel(key, value); + } + @Override public void streamClosed(Status status) { delegate().streamClosed(status); diff --git a/core/src/main/java/io/grpc/internal/ForwardingReadableBuffer.java b/core/src/main/java/io/grpc/internal/ForwardingReadableBuffer.java index 06d04b6de2d..7e690309647 100644 --- a/core/src/main/java/io/grpc/internal/ForwardingReadableBuffer.java +++ b/core/src/main/java/io/grpc/internal/ForwardingReadableBuffer.java @@ -67,11 +67,6 @@ public void readBytes(byte[] dest, int destOffset, int length) { buf.readBytes(dest, destOffset, length); } - @Override - public void readBytes(ByteBuffer dest) { - buf.readBytes(dest); - } - @Override public void readBytes(OutputStream dest, int length) throws IOException { buf.readBytes(dest, length); diff --git a/core/src/main/java/io/grpc/internal/GoAwayDisconnectError.java b/core/src/main/java/io/grpc/internal/GoAwayDisconnectError.java new file mode 100644 index 00000000000..20c8c709932 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/GoAwayDisconnectError.java @@ -0,0 +1,64 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + + +import javax.annotation.concurrent.Immutable; + +/** + * Represents a dynamic disconnection due to an HTTP/2 GOAWAY frame. + * This class is immutable and holds the specific error code from the frame. + */ +@Immutable +public final class GoAwayDisconnectError implements DisconnectError { + private static final String ERROR_TAG = "GOAWAY"; + private final GrpcUtil.Http2Error errorCode; + + /** + * Creates a GoAway reason. + * + * @param errorCode The specific, non-null HTTP/2 error code (e.g., "NO_ERROR"). + */ + public GoAwayDisconnectError(GrpcUtil.Http2Error errorCode) { + if (errorCode == null) { + throw new NullPointerException("Http2Error cannot be null for GOAWAY"); + } + this.errorCode = errorCode; + } + + @Override + public String toErrorString() { + return ERROR_TAG + " " + errorCode.name(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + GoAwayDisconnectError goAwayDisconnectError = (GoAwayDisconnectError) o; + return errorCode == goAwayDisconnectError.errorCode; + } + + @Override + public int hashCode() { + return errorCode.hashCode(); + } +} diff --git a/core/src/main/java/io/grpc/internal/GrpcAttributes.java b/core/src/main/java/io/grpc/internal/GrpcAttributes.java index da43ae14800..f95f9b9dab8 100644 --- a/core/src/main/java/io/grpc/internal/GrpcAttributes.java +++ b/core/src/main/java/io/grpc/internal/GrpcAttributes.java @@ -42,5 +42,8 @@ public final class GrpcAttributes { public static final Attributes.Key ATTR_CLIENT_EAG_ATTRS = Attributes.Key.create("io.grpc.internal.GrpcAttributes.clientEagAttrs"); + public static final Attributes.Key ATTR_AUTHORITY_VERIFIER = + Attributes.Key.create("io.grpc.internal.GrpcAttributes.authorityVerifier"); + private GrpcAttributes() {} } diff --git a/core/src/main/java/io/grpc/internal/GrpcUtil.java b/core/src/main/java/io/grpc/internal/GrpcUtil.java index aaf7cea3c75..c419f028f58 100644 --- a/core/src/main/java/io/grpc/internal/GrpcUtil.java +++ b/core/src/main/java/io/grpc/internal/GrpcUtil.java @@ -24,7 +24,6 @@ import com.google.common.base.Preconditions; import com.google.common.base.Splitter; import com.google.common.base.Stopwatch; -import com.google.common.base.Strings; import com.google.common.base.Supplier; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ThreadFactoryBuilder; @@ -32,6 +31,7 @@ import io.grpc.ClientStreamTracer; import io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.InternalChannelz.SocketStats; +import io.grpc.InternalFeatureFlags; import io.grpc.InternalLogId; import io.grpc.InternalMetadata; import io.grpc.InternalMetadata.TrustedAsciiMarshaller; @@ -48,10 +48,8 @@ import java.io.Closeable; import java.io.IOException; import java.io.InputStream; -import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.net.HttpURLConnection; -import java.net.InetSocketAddress; import java.net.SocketAddress; import java.net.URI; import java.net.URISyntaxException; @@ -221,7 +219,7 @@ public byte[] parseAsciiString(byte[] serialized) { public static final Splitter ACCEPT_ENCODING_SPLITTER = Splitter.on(',').trimResults(); - public static final String IMPLEMENTATION_VERSION = "1.63.0-SNAPSHOT"; // CURRENT_GRPC_VERSION + public static final String IMPLEMENTATION_VERSION = "1.82.0-SNAPSHOT"; // CURRENT_GRPC_VERSION /** * The default timeout in nanos for a keepalive ping request. @@ -243,6 +241,12 @@ public byte[] parseAsciiString(byte[] serialized) { */ public static final long DEFAULT_SERVER_KEEPALIVE_TIMEOUT_NANOS = TimeUnit.SECONDS.toNanos(20L); + /** + * The default minimum time between client keepalive pings permitted by server. + */ + public static final long DEFAULT_SERVER_PERMIT_KEEPALIVE_TIME_NANOS + = TimeUnit.MINUTES.toNanos(5L); + /** * The magic keepalive time value that disables keepalive. */ @@ -632,25 +636,6 @@ public Stopwatch get() { } }; - /** - * Returns the host via {@link InetSocketAddress#getHostString} if it is possible, - * i.e. in jdk >= 7. - * Otherwise, return it via {@link InetSocketAddress#getHostName} which may incur a DNS lookup. - */ - public static String getHost(InetSocketAddress addr) { - try { - Method getHostStringMethod = InetSocketAddress.class.getMethod("getHostString"); - return (String) getHostStringMethod.invoke(addr); - } catch (NoSuchMethodException e) { - // noop - } catch (IllegalAccessException e) { - // noop - } catch (InvocationTargetException e) { - // noop - } - return addr.getHostName(); - } - /** * Marshals a nanoseconds representation of the timeout to and from a string representation, * consisting of an ASCII decimal representation of a number with at most 8 digits, followed by a @@ -672,12 +657,14 @@ public static String getHost(InetSocketAddress addr) { static class TimeoutMarshaller implements Metadata.AsciiMarshaller { @Override - public String toAsciiString(Long timeoutNanos) { + public String toAsciiString(Long timeoutNanosObject) { long cutoff = 100000000; + // Timeout checking is inherently racy. RPCs with timeouts in the past ideally don't even get + // here, but if the timeout is expired assume that happened recently and adjust it to the + // smallest allowed timeout + long timeoutNanos = Math.max(1, timeoutNanosObject); TimeUnit unit = TimeUnit.NANOSECONDS; - if (timeoutNanos < 0) { - throw new IllegalArgumentException("Timeout too small"); - } else if (timeoutNanos < cutoff) { + if (timeoutNanos < cutoff) { return timeoutNanos + "n"; } else if (timeoutNanos < cutoff * 1000L) { return unit.toMicros(timeoutNanos) + "u"; @@ -778,13 +765,15 @@ public ListenableFuture getStats() { /** Gets stream tracers based on CallOptions. */ public static ClientStreamTracer[] getClientStreamTracers( - CallOptions callOptions, Metadata headers, int previousAttempts, boolean isTransparentRetry) { + CallOptions callOptions, Metadata headers, int previousAttempts, boolean isTransparentRetry, + boolean isHedging) { List factories = callOptions.getStreamTracerFactories(); ClientStreamTracer[] tracers = new ClientStreamTracer[factories.size() + 1]; StreamInfo streamInfo = StreamInfo.newBuilder() .setCallOptions(callOptions) .setPreviousAttempts(previousAttempts) .setIsTransparentRetry(isTransparentRetry) + .setIsHedging(isHedging) .build(); for (int i = 0; i < factories.size(); i++) { tracers[i] = factories.get(i).newClientStreamTracer(streamInfo, headers); @@ -838,6 +827,31 @@ public static Status replaceInappropriateControlPlaneStatus(Status status) { + status.getDescription()).withCause(status.getCause()) : status; } + /** + * Returns a "clean" representation of a status code and description (not cause) like + * "UNAVAILABLE: The description". Should be similar to Status.formatThrowableMessage(). + */ + public static String statusToPrettyString(Status status) { + if (status.getDescription() == null) { + return status.getCode().toString(); + } else { + return status.getCode() + ": " + status.getDescription(); + } + } + + /** + * Create a status with contextual information, propagating details from a non-null status that + * contributed to the failure. For example, if UNAVAILABLE, "Couldn't load bar", and status + * "FAILED_PRECONDITION: Foo missing" were passed as arguments, then this method would produce the + * status "UNAVAILABLE: Couldn't load bar: FAILED_PRECONDITION: Foo missing" with a cause if the + * passed status had a cause. + */ + public static Status statusWithDetails(Status.Code code, String description, Status causeStatus) { + return code.toStatus() + .withDescription(description + ": " + statusToPrettyString(causeStatus)) + .withCause(causeStatus.getCause()); + } + /** * Checks whether the given item exists in the iterable. This is copied from Guava Collect's * {@code Iterables.contains()} because Guava Collect is not Android-friendly thus core can't @@ -950,15 +964,7 @@ public static String encodeAuthority(String authority) { } public static boolean getFlag(String envVarName, boolean enableByDefault) { - String envVar = System.getenv(envVarName); - if (envVar == null) { - envVar = System.getProperty(envVarName); - } - if (enableByDefault) { - return Strings.isNullOrEmpty(envVar) || Boolean.parseBoolean(envVar); - } else { - return !Strings.isNullOrEmpty(envVar) && Boolean.parseBoolean(envVar); - } + return InternalFeatureFlags.getFlag(envVarName, enableByDefault); } diff --git a/core/src/main/java/io/grpc/internal/Http2ClientStreamTransportState.java b/core/src/main/java/io/grpc/internal/Http2ClientStreamTransportState.java index c2bcb350a9f..7124f2fc88a 100644 --- a/core/src/main/java/io/grpc/internal/Http2ClientStreamTransportState.java +++ b/core/src/main/java/io/grpc/internal/Http2ClientStreamTransportState.java @@ -16,13 +16,14 @@ package io.grpc.internal; -import com.google.common.base.Charsets; import com.google.common.base.Preconditions; +import io.grpc.CallOptions; import io.grpc.InternalMetadata; import io.grpc.InternalStatus; import io.grpc.Metadata; import io.grpc.Status; import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; import javax.annotation.Nullable; /** @@ -61,14 +62,15 @@ public Integer parseAsciiString(byte[] serialized) { /** When non-{@code null}, {@link #transportErrorMetadata} must also be non-{@code null}. */ private Status transportError; private Metadata transportErrorMetadata; - private Charset errorCharset = Charsets.UTF_8; + private Charset errorCharset = StandardCharsets.UTF_8; private boolean headersReceived; protected Http2ClientStreamTransportState( int maxMessageSize, StatsTraceContext statsTraceCtx, - TransportTracer transportTracer) { - super(maxMessageSize, statsTraceCtx, transportTracer); + TransportTracer transportTracer, + CallOptions options) { + super(maxMessageSize, statsTraceCtx, transportTracer, options); } /** @@ -138,6 +140,7 @@ protected void transportDataReceived(ReadableBuffer frame, boolean endOfStream) } } else { if (!headersReceived) { + frame.close(); http2ProcessingFailed( Status.INTERNAL.withDescription("headers not received before payload"), false, @@ -220,8 +223,11 @@ private Status validateInitialMetadata(Metadata headers) { } String contentType = headers.get(GrpcUtil.CONTENT_TYPE_KEY); if (!GrpcUtil.isGrpcContentType(contentType)) { - return GrpcUtil.httpStatusToGrpcStatus(httpStatus) - .augmentDescription("invalid content-type: " + contentType); + Status status = GrpcUtil.httpStatusToGrpcStatus(httpStatus); + if (contentType == null) { + return status.augmentDescription("missing content-type in response headers"); + } + return status.augmentDescription("invalid content-type: " + contentType); } return null; } @@ -239,7 +245,7 @@ private static Charset extractCharset(Metadata headers) { // Ignore and assume UTF-8 } } - return Charsets.UTF_8; + return StandardCharsets.UTF_8; } /** diff --git a/core/src/main/java/io/grpc/internal/Http2Ping.java b/core/src/main/java/io/grpc/internal/Http2Ping.java index 6104d876373..e3520295625 100644 --- a/core/src/main/java/io/grpc/internal/Http2Ping.java +++ b/core/src/main/java/io/grpc/internal/Http2Ping.java @@ -17,6 +17,8 @@ package io.grpc.internal; import com.google.common.base.Stopwatch; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import io.grpc.Status; import io.grpc.internal.ClientTransport.PingCallback; import java.util.LinkedHashMap; import java.util.Map; @@ -24,7 +26,6 @@ import java.util.concurrent.TimeUnit; import java.util.logging.Level; import java.util.logging.Logger; -import javax.annotation.concurrent.GuardedBy; /** * Represents an outstanding PING operation on an HTTP/2 channel. This can be used by HTTP/2-based @@ -62,7 +63,7 @@ public class Http2Ping { /** * If non-null, indicates the ping failed. */ - @GuardedBy("this") private Throwable failureCause; + @GuardedBy("this") private Status failureCause; /** * The round-trip time for the ping, in nanoseconds. This value is only meaningful when @@ -144,7 +145,7 @@ public boolean complete() { * * @param failureCause the cause of failure */ - public void failed(Throwable failureCause) { + public void failed(Status failureCause) { Map callbacks; synchronized (this) { if (completed) { @@ -167,7 +168,7 @@ public void failed(Throwable failureCause) { * @param executor the executor used to invoke the callback * @param cause the cause of failure */ - public static void notifyFailed(PingCallback callback, Executor executor, Throwable cause) { + public static void notifyFailed(PingCallback callback, Executor executor, Status cause) { doExecute(executor, asRunnable(callback, cause)); } @@ -203,7 +204,7 @@ public void run() { * failure. */ private static Runnable asRunnable(final ClientTransport.PingCallback callback, - final Throwable failureCause) { + final Status failureCause) { return new Runnable() { @Override public void run() { diff --git a/core/src/main/java/io/grpc/internal/InstantTimeProvider.java b/core/src/main/java/io/grpc/internal/InstantTimeProvider.java new file mode 100644 index 00000000000..12996163753 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/InstantTimeProvider.java @@ -0,0 +1,36 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static com.google.common.math.LongMath.saturatedAdd; + +import java.time.Instant; +import java.util.concurrent.TimeUnit; +import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; + +/** + * {@link InstantTimeProvider} resolves InstantTimeProvider which implements {@link TimeProvider}. + */ +final class InstantTimeProvider implements TimeProvider { + @Override + @IgnoreJRERequirement + public long currentTimeNanos() { + Instant now = Instant.now(); + long epochSeconds = now.getEpochSecond(); + return saturatedAdd(TimeUnit.SECONDS.toNanos(epochSeconds), now.getNano()); + } +} diff --git a/core/src/main/java/io/grpc/internal/InternalServer.java b/core/src/main/java/io/grpc/internal/InternalServer.java index a6079081233..8449f352b17 100644 --- a/core/src/main/java/io/grpc/internal/InternalServer.java +++ b/core/src/main/java/io/grpc/internal/InternalServer.java @@ -22,13 +22,13 @@ import java.net.SocketAddress; import java.util.List; import javax.annotation.Nullable; -import javax.annotation.concurrent.ThreadSafe; /** * An object that accepts new incoming connections on one or more listening socket addresses. * This would commonly encapsulate a bound socket that {@code accept()}s new connections. + * + *

This interface is thread-safe. */ -@ThreadSafe public interface InternalServer { /** * Starts transport. Implementations must not call {@code listener} until after {@code start()} diff --git a/core/src/main/java/io/grpc/internal/InternalSubchannel.java b/core/src/main/java/io/grpc/internal/InternalSubchannel.java index a986cb2deff..00a66b1c1df 100644 --- a/core/src/main/java/io/grpc/internal/InternalSubchannel.java +++ b/core/src/main/java/io/grpc/internal/InternalSubchannel.java @@ -42,11 +42,16 @@ import io.grpc.HttpConnectProxiedSocketAddress; import io.grpc.InternalChannelz; import io.grpc.InternalChannelz.ChannelStats; +import io.grpc.InternalEquivalentAddressGroup; import io.grpc.InternalInstrumented; import io.grpc.InternalLogId; import io.grpc.InternalWithLogId; +import io.grpc.LoadBalancer; import io.grpc.Metadata; import io.grpc.MethodDescriptor; +import io.grpc.MetricRecorder; +import io.grpc.NameResolver; +import io.grpc.SecurityLevel; import io.grpc.Status; import io.grpc.SynchronizationContext; import io.grpc.SynchronizationContext.ScheduledHandle; @@ -76,7 +81,9 @@ final class InternalSubchannel implements InternalInstrumented, Tr private final InternalChannelz channelz; private final CallTracer callsTracer; private final ChannelTracer channelTracer; + private final MetricRecorder metricRecorder; private final ChannelLogger channelLogger; + private final boolean reconnectDisabled; private final List transportFilters; @@ -157,13 +164,21 @@ protected void handleNotInUse() { private Status shutdownReason; - InternalSubchannel(List addressGroups, String authority, String userAgent, - BackoffPolicy.Provider backoffPolicyProvider, - ClientTransportFactory transportFactory, ScheduledExecutorService scheduledExecutor, - Supplier stopwatchSupplier, SynchronizationContext syncContext, Callback callback, - InternalChannelz channelz, CallTracer callsTracer, ChannelTracer channelTracer, - InternalLogId logId, ChannelLogger channelLogger, - List transportFilters) { + private volatile Attributes connectedAddressAttributes; + private final SubchannelMetrics subchannelMetrics; + private final String target; + + InternalSubchannel(LoadBalancer.CreateSubchannelArgs args, String authority, String userAgent, + BackoffPolicy.Provider backoffPolicyProvider, + ClientTransportFactory transportFactory, + ScheduledExecutorService scheduledExecutor, + Supplier stopwatchSupplier, SynchronizationContext syncContext, + Callback callback, InternalChannelz channelz, CallTracer callsTracer, + ChannelTracer channelTracer, InternalLogId logId, + ChannelLogger channelLogger, List transportFilters, + String target, + MetricRecorder metricRecorder) { + List addressGroups = args.getAddresses(); Preconditions.checkNotNull(addressGroups, "addressGroups"); Preconditions.checkArgument(!addressGroups.isEmpty(), "addressGroups is empty"); checkListHasNoNulls(addressGroups, "addressGroups contains null entry"); @@ -178,6 +193,7 @@ protected void handleNotInUse() { this.scheduledExecutor = scheduledExecutor; this.connectingTimer = stopwatchSupplier.get(); this.syncContext = syncContext; + this.metricRecorder = metricRecorder; this.callback = callback; this.channelz = channelz; this.callsTracer = callsTracer; @@ -185,6 +201,9 @@ protected void handleNotInUse() { this.logId = Preconditions.checkNotNull(logId, "logId"); this.channelLogger = Preconditions.checkNotNull(channelLogger, "channelLogger"); this.transportFilters = transportFilters; + this.reconnectDisabled = args.getOption(LoadBalancer.DISABLE_SUBCHANNEL_RECONNECT_KEY); + this.target = target; + this.subchannelMetrics = new SubchannelMetrics(metricRecorder); } ChannelLogger getChannelLogger() { @@ -249,6 +268,7 @@ private void startNewTransport() { .setAuthority(eagChannelAuthority != null ? eagChannelAuthority : authority) .setEagAttributes(currentEagAttributes) .setUserAgent(userAgent) + .setMetricRecorder(metricRecorder) .setHttpConnectProxiedSocketAddress(proxiedAddr); TransportLogger transportLogger = new TransportLogger(); // In case the transport logs in the constructor, use the subchannel logId @@ -287,6 +307,11 @@ public void run() { } gotoState(ConnectivityStateInfo.forTransientFailure(status)); + + if (reconnectDisabled) { + return; + } + if (reconnectPolicy == null) { reconnectPolicy = backoffPolicyProvider.get(); } @@ -305,7 +330,7 @@ public void run() { } /** - * Immediately attempt to reconnect if the current state is TRANSIENT_FAILURE. Otherwise this + * Immediately attempt to reconnect if the current state is TRANSIENT_FAILURE. Otherwise, this * method has no effect. */ void resetConnectBackoff() { @@ -334,8 +359,12 @@ private void gotoState(final ConnectivityStateInfo newState) { if (state.getState() != newState.getState()) { Preconditions.checkState(state.getState() != SHUTDOWN, - "Cannot transition out of SHUTDOWN to " + newState); - state = newState; + "Cannot transition out of SHUTDOWN to %s", newState.getState()); + if (reconnectDisabled && newState.getState() == TRANSIENT_FAILURE) { + state = ConnectivityStateInfo.forNonError(IDLE); + } else { + state = newState; + } callback.onStateChange(InternalSubchannel.this, newState); } } @@ -525,6 +554,13 @@ public void run() { return channelStatsFuture; } + /** + * Return attributes for server address connected by sub channel. + */ + public Attributes getConnectedAddressAttributes() { + return connectedAddressAttributes; + } + ConnectivityState getState() { return state.getState(); } @@ -568,7 +604,15 @@ public void run() { } else if (pendingTransport == transport) { activeTransport = transport; pendingTransport = null; + connectedAddressAttributes = addressIndex.getCurrentEagAttributes(); gotoNonErrorState(READY); + subchannelMetrics.recordConnectionAttemptSucceeded(/* target= */ target, + /* backendService= */ getBackendServiceOrDefault( + addressIndex.getCurrentEagAttributes()), + /* locality= */ getAttributeOrDefault(addressIndex.getCurrentEagAttributes(), + EquivalentAddressGroup.ATTR_LOCALITY_NAME), + /* securityLevel= */ extractSecurityLevel(addressIndex.getCurrentEagAttributes() + .get(GrpcAttributes.ATTR_SECURITY_LEVEL))); } } }); @@ -580,7 +624,7 @@ public void transportInUse(boolean inUse) { } @Override - public void transportShutdown(final Status s) { + public void transportShutdown(final Status s, final DisconnectError disconnectError) { channelLogger.log( ChannelLogLevel.INFO, "{0} SHUTDOWN with {1}", transport.getLogId(), printShortStatus(s)); shutdownInitiated = true; @@ -594,11 +638,24 @@ public void run() { activeTransport = null; addressIndex.reset(); gotoNonErrorState(IDLE); + subchannelMetrics.recordDisconnection(/* target= */ target, + /* backendService= */ getBackendServiceOrDefault( + addressIndex.getCurrentEagAttributes()), + /* locality= */ getAttributeOrDefault(addressIndex.getCurrentEagAttributes(), + EquivalentAddressGroup.ATTR_LOCALITY_NAME), + /* disconnectError= */ disconnectError.toErrorString(), + /* securityLevel= */ extractSecurityLevel(addressIndex.getCurrentEagAttributes() + .get(GrpcAttributes.ATTR_SECURITY_LEVEL))); } else if (pendingTransport == transport) { + subchannelMetrics.recordConnectionAttemptFailed(/* target= */ target, + /* backendService= */ getBackendServiceOrDefault( + addressIndex.getCurrentEagAttributes()), + /* locality= */ getAttributeOrDefault(addressIndex.getCurrentEagAttributes(), + EquivalentAddressGroup.ATTR_LOCALITY_NAME)); Preconditions.checkState(state.getState() == CONNECTING, "Expected state is CONNECTING, actual state is %s", state.getState()); addressIndex.increment(); - // Continue reconnect if there are still addresses to try. + // Continue to reconnect if there are still addresses to try. if (!addressIndex.isValid()) { pendingTransport = null; addressIndex.reset(); @@ -634,6 +691,35 @@ public void run() { } }); } + + private String extractSecurityLevel(SecurityLevel securityLevel) { + if (securityLevel == null) { + return "none"; + } + switch (securityLevel) { + case NONE: + return "none"; + case INTEGRITY: + return "integrity_only"; + case PRIVACY_AND_INTEGRITY: + return "privacy_and_integrity"; + default: + throw new IllegalArgumentException("Unknown SecurityLevel: " + securityLevel); + } + } + + private String getAttributeOrDefault(Attributes attributes, Attributes.Key key) { + String value = attributes.get(key); + return value == null ? "" : value; + } + + private String getBackendServiceOrDefault(Attributes attributes) { + String value = attributes.get(InternalEquivalentAddressGroup.ATTR_BACKEND_SERVICE); + if (value == null) { + value = attributes.get(NameResolver.ATTR_BACKEND_SERVICE); + } + return value == null ? "" : value; + } } // All methods are called in syncContext diff --git a/core/src/main/java/io/grpc/internal/JsonParser.java b/core/src/main/java/io/grpc/internal/JsonParser.java index 384d29754f0..14f78c09e72 100644 --- a/core/src/main/java/io/grpc/internal/JsonParser.java +++ b/core/src/main/java/io/grpc/internal/JsonParser.java @@ -16,6 +16,7 @@ package io.grpc.internal; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import com.google.gson.stream.JsonReader; @@ -41,7 +42,8 @@ private JsonParser() {} /** * Parses a json string, returning either a {@code Map}, {@code List}, - * {@code String}, {@code Double}, {@code Boolean}, or {@code null}. + * {@code String}, {@code Double}, {@code Boolean}, or {@code null}. Fails if duplicate names + * found. */ public static Object parse(String raw) throws IOException { JsonReader jr = new JsonReader(new StringReader(raw)); @@ -81,6 +83,7 @@ private static Object parseRecursive(JsonReader jr) throws IOException { Map obj = new LinkedHashMap<>(); while (jr.hasNext()) { String name = jr.nextName(); + checkArgument(!obj.containsKey(name), "Duplicate key found: %s", name); Object value = parseRecursive(jr); obj.put(name, value); } @@ -105,4 +108,4 @@ private static Void parseJsonNull(JsonReader jr) throws IOException { jr.nextNull(); return null; } -} +} \ No newline at end of file diff --git a/core/src/main/java/io/grpc/internal/JsonUtil.java b/core/src/main/java/io/grpc/internal/JsonUtil.java index 44cb22abda5..6c9274702b6 100644 --- a/core/src/main/java/io/grpc/internal/JsonUtil.java +++ b/core/src/main/java/io/grpc/internal/JsonUtil.java @@ -356,23 +356,24 @@ private static int parseNanos(String value) throws ParseException { return result; } - private static final long NANOS_PER_SECOND = TimeUnit.SECONDS.toNanos(1); + private static final int NANOS_PER_SECOND = 1_000_000_000; /** * Copy of {@link com.google.protobuf.util.Durations#normalizedDuration}. */ - @SuppressWarnings("NarrowingCompoundAssignment") + // Math.addExact() requires Android API level 24 + @SuppressWarnings({"NarrowingCompoundAssignment", "InlineMeInliner"}) private static long normalizedDuration(long seconds, int nanos) { if (nanos <= -NANOS_PER_SECOND || nanos >= NANOS_PER_SECOND) { seconds = checkedAdd(seconds, nanos / NANOS_PER_SECOND); nanos %= NANOS_PER_SECOND; } if (seconds > 0 && nanos < 0) { - nanos += NANOS_PER_SECOND; // no overflow since nanos is negative (and we're adding) + nanos += NANOS_PER_SECOND; // no overflow— nanos is negative (and we're adding) seconds--; // no overflow since seconds is positive (and we're decrementing) } if (seconds < 0 && nanos > 0) { - nanos -= NANOS_PER_SECOND; // no overflow since nanos is positive (and we're subtracting) + nanos -= NANOS_PER_SECOND; // no overflow— nanos is positive (and we're subtracting) seconds++; // no overflow since seconds is negative (and we're incrementing) } if (!durationIsValid(seconds, nanos)) { diff --git a/core/src/main/java/io/grpc/internal/KeepAliveEnforcer.java b/core/src/main/java/io/grpc/internal/KeepAliveEnforcer.java index dd539e75a18..6480336470c 100644 --- a/core/src/main/java/io/grpc/internal/KeepAliveEnforcer.java +++ b/core/src/main/java/io/grpc/internal/KeepAliveEnforcer.java @@ -18,8 +18,8 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.CheckReturnValue; import java.util.concurrent.TimeUnit; -import javax.annotation.CheckReturnValue; /** Monitors the client's PING usage to make sure the rate is permitted. */ public final class KeepAliveEnforcer { diff --git a/core/src/main/java/io/grpc/internal/KeepAliveManager.java b/core/src/main/java/io/grpc/internal/KeepAliveManager.java index 28e2a87276b..1937da6f467 100644 --- a/core/src/main/java/io/grpc/internal/KeepAliveManager.java +++ b/core/src/main/java/io/grpc/internal/KeepAliveManager.java @@ -22,11 +22,11 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Stopwatch; import com.google.common.util.concurrent.MoreExecutors; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Status; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; -import javax.annotation.concurrent.GuardedBy; /** * Manages keepalive pings. @@ -262,9 +262,25 @@ public interface KeepAlivePinger { * Default client side {@link KeepAlivePinger}. */ public static final class ClientKeepAlivePinger implements KeepAlivePinger { - private final ConnectionClientTransport transport; - public ClientKeepAlivePinger(ConnectionClientTransport transport) { + + /** + * A {@link ClientTransport} that has life-cycle management. + * + *

This interface is thread-safe. + */ + public interface TransportWithDisconnectReason extends ClientTransport { + + /** + * Initiates a forceful shutdown in which preexisting and new calls are closed. Existing calls + * should be closed with the provided {@code reason} and {@code disconnectError}. + */ + void shutdownNow(Status reason, DisconnectError disconnectError); + } + + private final TransportWithDisconnectReason transport; + + public ClientKeepAlivePinger(TransportWithDisconnectReason transport) { this.transport = transport; } @@ -275,9 +291,10 @@ public void ping() { public void onSuccess(long roundTripTimeNanos) {} @Override - public void onFailure(Throwable cause) { + public void onFailure(Status cause) { transport.shutdownNow(Status.UNAVAILABLE.withDescription( - "Keepalive failed. The connection is likely gone")); + "Keepalive failed. The connection is likely gone"), + SimpleDisconnectError.CONNECTION_TIMED_OUT); } }, MoreExecutors.directExecutor()); } @@ -285,8 +302,8 @@ public void onFailure(Throwable cause) { @Override public void onPingTimeout() { transport.shutdownNow(Status.UNAVAILABLE.withDescription( - "Keepalive failed. The connection is likely gone")); + "Keepalive failed. The connection is likely gone"), + SimpleDisconnectError.CONNECTION_TIMED_OUT); } } } - diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java index 5d600d1ca5e..e423220e3ad 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java @@ -20,6 +20,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; import static io.grpc.ClientStreamTracer.NAME_RESOLUTION_DELAYED; +import static io.grpc.ConnectivityState.CONNECTING; import static io.grpc.ConnectivityState.IDLE; import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; @@ -31,6 +32,7 @@ import com.google.common.base.Supplier; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.CallCredentials; import io.grpc.CallOptions; @@ -67,10 +69,13 @@ import io.grpc.LoadBalancer.ResolvedAddresses; import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancer.SubchannelStateListener; +import io.grpc.LoadBalancerProvider; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; import io.grpc.Metadata; import io.grpc.MethodDescriptor; +import io.grpc.MetricInstrumentRegistry; +import io.grpc.MetricRecorder; import io.grpc.NameResolver; import io.grpc.NameResolver.ConfigOrError; import io.grpc.NameResolver.ResolutionResult; @@ -78,9 +83,9 @@ import io.grpc.NameResolverRegistry; import io.grpc.ProxyDetector; import io.grpc.Status; +import io.grpc.StatusOr; import io.grpc.SynchronizationContext; import io.grpc.SynchronizationContext.ScheduledHandle; -import io.grpc.internal.AutoConfiguredLoadBalancerFactory.AutoConfiguredLoadBalancer; import io.grpc.internal.ClientCallImpl.ClientStreamProvider; import io.grpc.internal.ClientTransportFactory.SwapChannelCredentialsResult; import io.grpc.internal.ManagedChannelImplBuilder.ClientTransportFactoryBuilder; @@ -89,8 +94,6 @@ import io.grpc.internal.ManagedChannelServiceConfig.ServiceConfigConvertedSelector; import io.grpc.internal.RetriableStream.ChannelBufferMeter; import io.grpc.internal.RetriableStream.Throttle; -import io.grpc.internal.RetryingNameResolver.ResolutionResultListener; -import java.net.SocketAddress; import java.net.URI; import java.net.URISyntaxException; import java.util.ArrayList; @@ -114,9 +117,7 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Level; import java.util.logging.Logger; -import java.util.regex.Pattern; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; /** A communication channel for making outgoing RPCs. */ @@ -126,12 +127,6 @@ final class ManagedChannelImpl extends ManagedChannel implements @VisibleForTesting static final Logger logger = Logger.getLogger(ManagedChannelImpl.class.getName()); - // Matching this pattern means the target string is a URI target or at least intended to be one. - // A URI target must be an absolute hierarchical URI. - // From RFC 2396: scheme = alpha *( alpha | digit | "+" | "-" | "." ) - @VisibleForTesting - static final Pattern URI_PATTERN = Pattern.compile("[a-zA-Z][a-zA-Z0-9+.-]*:/.*"); - static final long IDLE_TIMEOUT_MILLIS_DISABLE = -1; static final long SUBCHANNEL_SHUTDOWN_DELAY_SECONDS = 5; @@ -157,23 +152,25 @@ public Result selectConfig(PickSubchannelArgs args) { throw new IllegalStateException("Resolution is pending"); } }; + private static final LoadBalancer.PickDetailsConsumer NOOP_PICK_DETAILS_CONSUMER = + new LoadBalancer.PickDetailsConsumer() {}; private final InternalLogId logId; private final String target; @Nullable private final String authorityOverride; private final NameResolverRegistry nameResolverRegistry; + private final UriWrapper targetUri; + private final NameResolverProvider nameResolverProvider; private final NameResolver.Args nameResolverArgs; - private final AutoConfiguredLoadBalancerFactory loadBalancerFactory; + private final LoadBalancerProvider loadBalancerFactory; private final ClientTransportFactory originalTransportFactory; @Nullable private final ChannelCredentials originalChannelCreds; private final ClientTransportFactory transportFactory; - private final ClientTransportFactory oobTransportFactory; private final RestrictedScheduledExecutor scheduledExecutor; private final Executor executor; private final ObjectPool executorPool; - private final ObjectPool balancerRpcExecutorPool; private final ExecutorHolder balancerRpcExecutorHolder; private final ExecutorHolder offloadExecutorHolder; private final TimeProvider timeProvider; @@ -188,7 +185,12 @@ public void uncaughtException(Thread t, Throwable e) { Level.SEVERE, "[" + getLogId() + "] Uncaught exception in the SynchronizationContext. Panic!", e); - panic(e); + try { + panic(e); + } catch (Throwable anotherT) { + logger.log( + Level.SEVERE, "[" + getLogId() + "] Uncaught exception while panicking", anotherT); + } } }); @@ -198,7 +200,7 @@ public void uncaughtException(Thread t, Throwable e) { private final CompressorRegistry compressorRegistry; private final Supplier stopwatchSupplier; - /** The timout before entering idle mode. */ + /** The timeout before entering idle mode. */ private final long idleTimeoutMillis; private final ConnectivityStateManager channelStateManager = new ConnectivityStateManager(); @@ -224,11 +226,6 @@ public void uncaughtException(Thread t, Throwable e) { @Nullable private LbHelperImpl lbHelper; - // Must ONLY be assigned from updateSubchannelPicker(), which is called from syncContext. - // null if channel is in idle mode. - @Nullable - private volatile SubchannelPicker subchannelPicker; - // Must be accessed from the syncContext private boolean panicMode; @@ -242,9 +239,6 @@ public void uncaughtException(Thread t, Throwable e) { private Collection> pendingCalls; private final Object pendingCallsInUseObject = new Object(); - // Must be mutated from syncContext - private final Set oobChannels = new HashSet<>(1, .75f); - // reprocess() must be run from syncContext private final DelayedClientTransport delayedTransport; private final UncommittedRetriableStreamsRegistry uncommittedRetriableStreamsRegistry @@ -255,8 +249,7 @@ public void uncaughtException(Thread t, Throwable e) { // Channel's shutdown process: // 1. shutdown(): stop accepting new calls from applications // 1a shutdown <- true - // 1b subchannelPicker <- null - // 1c delayedTransport.shutdown() + // 1b delayedTransport.shutdown() // 2. delayedTransport terminated: stop stream-creation functionality // 2a terminating <- true // 2b loadBalancer.shutdown() @@ -315,9 +308,6 @@ private void maybeShutdownNowSubchannels() { for (InternalSubchannel subchannel : subchannels) { subchannel.shutdownNow(SHUTDOWN_NOW_STATUS); } - for (OobChannel oobChannel : oobChannels) { - oobChannel.getInternalSubchannel().shutdownNow(SHUTDOWN_NOW_STATUS); - } } } @@ -337,7 +327,6 @@ public void run() { builder.setTarget(target).setState(channelStateManager.getState()); List children = new ArrayList<>(); children.addAll(subchannels); - children.addAll(oobChannels); builder.setSubchannels(children); ret.set(builder.build()); } @@ -380,8 +369,7 @@ private void shutdownNameResolverAndLoadBalancer(boolean channelIsActive) { nameResolverStarted = false; if (channelIsActive) { nameResolver = getNameResolver( - target, authorityOverride, nameResolverRegistry, nameResolverArgs, - transportFactory.getSupportedSocketAddressTypes()); + targetUri, authorityOverride, nameResolverProvider, nameResolverArgs); } else { nameResolver = null; } @@ -390,7 +378,6 @@ private void shutdownNameResolverAndLoadBalancer(boolean channelIsActive) { lbHelper.lb.shutdown(); lbHelper = null; } - subchannelPicker = null; } /** @@ -420,9 +407,10 @@ void exitIdleMode() { LbHelperImpl lbHelper = new LbHelperImpl(); lbHelper.lb = loadBalancerFactory.newLoadBalancer(lbHelper); // Delay setting lbHelper until fully initialized, since loadBalancerFactory is user code and - // may throw. We don't want to confuse our state, even if we will enter panic mode. + // may throw. We don't want to confuse our state, even if we enter panic mode. this.lbHelper = lbHelper; + channelStateManager.gotoState(CONNECTING); NameResolverListener listener = new NameResolverListener(lbHelper, nameResolver); nameResolver.start(listener); nameResolverStarted = true; @@ -473,57 +461,21 @@ private void refreshNameResolution() { private final class ChannelStreamProvider implements ClientStreamProvider { volatile Throttle throttle; - private ClientTransport getTransport(PickSubchannelArgs args) { - SubchannelPicker pickerCopy = subchannelPicker; - if (shutdown.get()) { - // If channel is shut down, delayedTransport is also shut down which will fail the stream - // properly. - return delayedTransport; - } - if (pickerCopy == null) { - final class ExitIdleModeForTransport implements Runnable { - @Override - public void run() { - exitIdleMode(); - } - } - - syncContext.execute(new ExitIdleModeForTransport()); - return delayedTransport; - } - // There is no need to reschedule the idle timer here. - // - // pickerCopy != null, which means idle timer has not expired when this method starts. - // Even if idle timer expires right after we grab pickerCopy, and it shuts down LoadBalancer - // which calls Subchannel.shutdown(), the InternalSubchannel will be actually shutdown after - // SUBCHANNEL_SHUTDOWN_DELAY_SECONDS, which gives the caller time to start RPC on it. - // - // In most cases the idle timer is scheduled to fire after the transport has created the - // stream, which would have reported in-use state to the channel that would have cancelled - // the idle timer. - PickResult pickResult = pickerCopy.pickSubchannel(args); - ClientTransport transport = GrpcUtil.getTransportFromPickResult( - pickResult, args.getCallOptions().isWaitForReady()); - if (transport != null) { - return transport; - } - return delayedTransport; - } - @Override public ClientStream newStream( final MethodDescriptor method, final CallOptions callOptions, final Metadata headers, final Context context) { + // There is no need to reschedule the idle timer here. If the channel isn't shut down, either + // the delayed transport or a real transport will go in-use and cancel the idle timer. if (!retryEnabled) { - ClientTransport transport = - getTransport(new PickSubchannelArgsImpl(method, headers, callOptions)); - Context origContext = context.attach(); ClientStreamTracer[] tracers = GrpcUtil.getClientStreamTracers( - callOptions, headers, 0, /* isTransparentRetry= */ false); + callOptions, headers, 0, /* isTransparentRetry= */ false, + /* isHedging= */false); + Context origContext = context.attach(); try { - return transport.newStream(method, headers, callOptions, tracers); + return delayedTransport.newStream(method, headers, callOptions, tracers); } finally { context.detach(origContext); } @@ -560,15 +512,13 @@ void postCommit() { @Override ClientStream newSubstream( Metadata newHeaders, ClientStreamTracer.Factory factory, int previousAttempts, - boolean isTransparentRetry) { + boolean isTransparentRetry, boolean isHedgedStream) { CallOptions newOptions = callOptions.withStreamTracerFactory(factory); ClientStreamTracer[] tracers = GrpcUtil.getClientStreamTracers( - newOptions, newHeaders, previousAttempts, isTransparentRetry); - ClientTransport transport = - getTransport(new PickSubchannelArgsImpl(method, newHeaders, newOptions)); + newOptions, newHeaders, previousAttempts, isTransparentRetry, isHedgedStream); Context origContext = context.attach(); try { - return transport.newStream(method, newHeaders, newOptions, tracers); + return delayedTransport.newStream(method, newHeaders, newOptions, tracers); } finally { context.detach(origContext); } @@ -583,10 +533,13 @@ ClientStream newSubstream( private final ChannelStreamProvider transportProvider = new ChannelStreamProvider(); private final Rescheduler idleTimer; + private final MetricRecorder metricRecorder; ManagedChannelImpl( ManagedChannelImplBuilder builder, ClientTransportFactory clientTransportFactory, + UriWrapper targetUri, + NameResolverProvider nameResolverProvider, BackoffPolicy.Provider backoffPolicyProvider, ObjectPool balancerRpcExecutorPool, Supplier stopwatchSupplier, @@ -603,8 +556,6 @@ ClientStream newSubstream( new ExecutorHolder(checkNotNull(builder.offloadExecutorPool, "offloadExecutorPool")); this.transportFactory = new CallCredentialsApplyingTransportFactory( clientTransportFactory, builder.callCredentials, this.offloadExecutorHolder); - this.oobTransportFactory = new CallCredentialsApplyingTransportFactory( - clientTransportFactory, null, this.offloadExecutorHolder); this.scheduledExecutor = new RestrictedScheduledExecutor(transportFactory.getScheduledExecutorService()); maxTraceEvents = builder.maxTraceEvents; @@ -617,6 +568,8 @@ ClientStream newSubstream( this.retryEnabled = builder.retryEnabled; this.loadBalancerFactory = new AutoConfiguredLoadBalancerFactory(builder.defaultLbPolicy); this.nameResolverRegistry = builder.nameResolverRegistry; + this.targetUri = checkNotNull(targetUri, "targetUri"); + this.nameResolverProvider = checkNotNull(nameResolverProvider, "nameResolverProvider"); ScParser serviceConfigParser = new ScParser( retryEnabled, @@ -624,8 +577,9 @@ ClientStream newSubstream( builder.maxHedgedAttempts, loadBalancerFactory); this.authorityOverride = builder.authorityOverride; - this.nameResolverArgs = - NameResolver.Args.newBuilder() + this.metricRecorder = new MetricRecorderImpl(builder.metricSinks, + MetricInstrumentRegistry.getDefaultRegistry()); + NameResolver.Args.Builder nameResolverArgsBuilder = NameResolver.Args.newBuilder() .setDefaultPort(builder.getDefaultPort()) .setProxyDetector(proxyDetector) .setSynchronizationContext(syncContext) @@ -634,12 +588,14 @@ ClientStream newSubstream( .setChannelLogger(channelLogger) .setOffloadExecutor(this.offloadExecutorHolder) .setOverrideAuthority(this.authorityOverride) - .build(); + .setMetricRecorder(this.metricRecorder) + .setNameResolverRegistry(builder.nameResolverRegistry); + builder.copyAllNameResolverCustomArgsTo(nameResolverArgsBuilder); + this.nameResolverArgs = nameResolverArgsBuilder.build(); this.nameResolver = getNameResolver( - target, authorityOverride, nameResolverRegistry, nameResolverArgs, - transportFactory.getSupportedSocketAddressTypes()); - this.balancerRpcExecutorPool = checkNotNull(balancerRpcExecutorPool, "balancerRpcExecutorPool"); - this.balancerRpcExecutorHolder = new ExecutorHolder(balancerRpcExecutorPool); + targetUri, authorityOverride, nameResolverProvider, nameResolverArgs); + this.balancerRpcExecutorHolder = new ExecutorHolder( + checkNotNull(balancerRpcExecutorPool, "balancerRpcExecutorPool")); this.delayedTransport = new DelayedClientTransport(this.executor, this.syncContext); this.delayedTransport.start(delayedTransportListener); this.backoffPolicyProvider = backoffPolicyProvider; @@ -653,7 +609,7 @@ ClientStream newSubstream( parsedDefaultServiceConfig.getError()); this.defaultServiceConfig = (ManagedChannelServiceConfig) parsedDefaultServiceConfig.getConfig(); - this.lastServiceConfig = this.defaultServiceConfig; + this.transportProvider.throttle = this.defaultServiceConfig.getRetryThrottling(); } else { this.defaultServiceConfig = null; } @@ -709,80 +665,19 @@ public CallTracer create() { } } - private static NameResolver getNameResolver( - String target, NameResolverRegistry nameResolverRegistry, NameResolver.Args nameResolverArgs, - Collection> channelTransportSocketAddressTypes) { - // Finding a NameResolver. Try using the target string as the URI. If that fails, try prepending - // "dns:///". - NameResolverProvider provider = null; - URI targetUri = null; - StringBuilder uriSyntaxErrors = new StringBuilder(); - try { - targetUri = new URI(target); - } catch (URISyntaxException e) { - // Can happen with ip addresses like "[::1]:1234" or 127.0.0.1:1234. - uriSyntaxErrors.append(e.getMessage()); - } - if (targetUri != null) { - // For "localhost:8080" this would likely cause provider to be null, because "localhost" is - // parsed as the scheme. Will hit the next case and try "dns:///localhost:8080". - provider = nameResolverRegistry.getProviderForScheme(targetUri.getScheme()); - } - - if (provider == null && !URI_PATTERN.matcher(target).matches()) { - // It doesn't look like a URI target. Maybe it's an authority string. Try with the default - // scheme from the registry. - try { - targetUri = new URI(nameResolverRegistry.getDefaultScheme(), "", "/" + target, null); - } catch (URISyntaxException e) { - // Should not be possible. - throw new IllegalArgumentException(e); - } - provider = nameResolverRegistry.getProviderForScheme(targetUri.getScheme()); - } - - if (provider == null) { - throw new IllegalArgumentException(String.format( - "Could not find a NameResolverProvider for %s%s", - target, uriSyntaxErrors.length() > 0 ? " (" + uriSyntaxErrors + ")" : "")); - } - - if (channelTransportSocketAddressTypes != null) { - Collection> nameResolverSocketAddressTypes - = provider.getProducedSocketAddressTypes(); - if (!channelTransportSocketAddressTypes.containsAll(nameResolverSocketAddressTypes)) { - throw new IllegalArgumentException(String.format( - "Address types of NameResolver '%s' for '%s' not supported by transport", - targetUri.getScheme(), target)); - } - } - - NameResolver resolver = provider.newNameResolver(targetUri, nameResolverArgs); - if (resolver != null) { - return resolver; - } - - throw new IllegalArgumentException(String.format( - "cannot create a NameResolver for %s%s", - target, uriSyntaxErrors.length() > 0 ? " (" + uriSyntaxErrors + ")" : "")); - } - @VisibleForTesting static NameResolver getNameResolver( - String target, @Nullable final String overrideAuthority, - NameResolverRegistry nameResolverRegistry, NameResolver.Args nameResolverArgs, - Collection> channelTransportSocketAddressTypes) { - NameResolver resolver = getNameResolver(target, nameResolverRegistry, nameResolverArgs, - channelTransportSocketAddressTypes); + UriWrapper targetUri, @Nullable final String overrideAuthority, + NameResolverProvider provider, NameResolver.Args nameResolverArgs) { + NameResolver resolver = targetUri.newNameResolver(provider, nameResolverArgs); + if (resolver == null) { + throw new IllegalArgumentException("cannot create a NameResolver for " + targetUri); + } // We wrap the name resolver in a RetryingNameResolver to give it the ability to retry failures. // TODO: After a transition period, all NameResolver implementations that need retry should use // RetryingNameResolver directly and this step can be removed. - NameResolver usedNameResolver = new RetryingNameResolver(resolver, - new BackoffPolicyRetryScheduler(new ExponentialBackoffPolicy.Provider(), - nameResolverArgs.getScheduledExecutorService(), - nameResolverArgs.getSynchronizationContext()), - nameResolverArgs.getSynchronizationContext()); + NameResolver usedNameResolver = RetryingNameResolver.wrap(resolver, nameResolverArgs); if (overrideAuthority == null) { return usedNameResolver; @@ -800,6 +695,11 @@ public String getServiceAuthority() { InternalConfigSelector getConfigSelector() { return realChannel.configSelector.get(); } + + @VisibleForTesting + boolean hasThrottle() { + return this.transportProvider.throttle != null; + } /** * Initiates an orderly shutdown in which preexisting calls continue but new calls are immediately @@ -865,30 +765,16 @@ void panic(final Throwable t) { return; } panicMode = true; - cancelIdleTimer(/* permanent= */ true); - shutdownNameResolverAndLoadBalancer(false); - final class PanicSubchannelPicker extends SubchannelPicker { - private final PickResult panicPickResult = - PickResult.withDrop( - Status.INTERNAL.withDescription("Panic! This is a bug!").withCause(t)); - - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return panicPickResult; - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(PanicSubchannelPicker.class) - .add("panicPickResult", panicPickResult) - .toString(); - } + try { + cancelIdleTimer(/* permanent= */ true); + shutdownNameResolverAndLoadBalancer(false); + } finally { + updateSubchannelPicker(new LoadBalancer.FixedResultPicker(PickResult.withDrop( + Status.INTERNAL.withDescription("Panic! This is a bug!").withCause(t)))); + realChannel.updateConfigSelector(null); + channelLogger.log(ChannelLogLevel.ERROR, "PANIC! Entering TRANSIENT_FAILURE"); + channelStateManager.gotoState(TRANSIENT_FAILURE); } - - updateSubchannelPicker(new PanicSubchannelPicker()); - realChannel.updateConfigSelector(null); - channelLogger.log(ChannelLogLevel.ERROR, "PANIC! Entering TRANSIENT_FAILURE"); - channelStateManager.gotoState(TRANSIENT_FAILURE); } @VisibleForTesting @@ -898,7 +784,6 @@ boolean isInPanicMode() { // Called from syncContext private void updateSubchannelPicker(SubchannelPicker newPicker) { - subchannelPicker = newPicker; delayedTransport.reprocess(newPicker); } @@ -1043,7 +928,15 @@ void updateConfigSelector(@Nullable InternalConfigSelector config) { // Must run in SynchronizationContext. void onConfigError() { if (configSelector.get() == INITIAL_PENDING_SELECTOR) { - updateConfigSelector(null); + // Apply Default Service Config if initial name resolution fails. + if (defaultServiceConfig != null) { + updateConfigSelector(defaultServiceConfig.getDefaultConfigSelector()); + lastServiceConfig = defaultServiceConfig; + channelLogger.log(ChannelLogLevel.ERROR, + "Initial Name Resolution error, using default service config"); + } else { + updateConfigSelector(null); + } } } @@ -1205,7 +1098,8 @@ protected ClientCall delegate() { @SuppressWarnings("unchecked") @Override public void start(Listener observer, Metadata headers) { - PickSubchannelArgs args = new PickSubchannelArgsImpl(method, headers, callOptions); + PickSubchannelArgs args = + new PickSubchannelArgsImpl(method, headers, callOptions, NOOP_PICK_DETAILS_CONSUMER); InternalConfigSelector.Result result = configSelector.selectConfig(args); Status status = result.getStatus(); if (!status.isOk()) { @@ -1283,7 +1177,7 @@ private void maybeTerminateChannel() { if (terminated) { return; } - if (shutdown.get() && subchannels.isEmpty() && oobChannels.isEmpty()) { + if (shutdown.get() && subchannels.isEmpty()) { channelLogger.log(ChannelLogLevel.INFO, "Terminated"); channelz.removeRootChannel(this); executorPool.returnObject(executor); @@ -1297,15 +1191,7 @@ private void maybeTerminateChannel() { } } - // Must be called from syncContext - private void handleInternalSubchannelState(ConnectivityStateInfo newState) { - if (newState.getState() == TRANSIENT_FAILURE || newState.getState() == IDLE) { - refreshNameResolution(); - } - } - @Override - @SuppressWarnings("deprecation") public ConnectivityState getState(boolean requestConnection) { ConnectivityState savedChannelState = channelStateManager.getState(); if (requestConnection && savedChannelState == IDLE) { @@ -1313,9 +1199,6 @@ final class RequestConnection implements Runnable { @Override public void run() { exitIdleMode(); - if (subchannelPicker != null) { - subchannelPicker.requestConnection(); - } if (lbHelper != null) { lbHelper.lb.requestConnection(); } @@ -1353,9 +1236,6 @@ public void run() { for (InternalSubchannel subchannel : subchannels) { subchannel.resetConnectBackoff(); } - for (OobChannel oobChannel : oobChannels) { - oobChannel.resetConnectBackoff(); - } } } @@ -1462,7 +1342,7 @@ void remove(RetriableStream retriableStream) { } private final class LbHelperImpl extends LoadBalancer.Helper { - AutoConfiguredLoadBalancer lb; + LoadBalancer lb; @Override public AbstractSubchannel createSubchannel(CreateSubchannelArgs args) { @@ -1478,24 +1358,18 @@ public void updateBalancingState( syncContext.throwIfNotInThisSynchronizationContext(); checkNotNull(newState, "newState"); checkNotNull(newPicker, "newPicker"); - final class UpdateBalancingState implements Runnable { - @Override - public void run() { - if (LbHelperImpl.this != lbHelper) { - return; - } - updateSubchannelPicker(newPicker); - // It's not appropriate to report SHUTDOWN state from lb. - // Ignore the case of newState == SHUTDOWN for now. - if (newState != SHUTDOWN) { - channelLogger.log( - ChannelLogLevel.INFO, "Entering {0} state with picker: {1}", newState, newPicker); - channelStateManager.gotoState(newState); - } - } - } - syncContext.execute(new UpdateBalancingState()); + if (LbHelperImpl.this != lbHelper || panicMode) { + return; + } + updateSubchannelPicker(newPicker); + // It's not appropriate to report SHUTDOWN state from lb. + // Ignore the case of newState == SHUTDOWN for now. + if (newState != SHUTDOWN) { + channelLogger.log( + ChannelLogLevel.INFO, "Entering {0} state with picker: {1}", newState, newPicker); + channelStateManager.gotoState(newState); + } } @Override @@ -1519,84 +1393,28 @@ public ManagedChannel createOobChannel(EquivalentAddressGroup addressGroup, Stri @Override public ManagedChannel createOobChannel(List addressGroup, String authority) { - // TODO(ejona): can we be even stricter? Like terminating? - checkState(!terminated, "Channel is terminated"); - long oobChannelCreationTime = timeProvider.currentTimeNanos(); - InternalLogId oobLogId = InternalLogId.allocate("OobChannel", /*details=*/ null); - InternalLogId subchannelLogId = - InternalLogId.allocate("Subchannel-OOB", /*details=*/ authority); - ChannelTracer oobChannelTracer = - new ChannelTracer( - oobLogId, maxTraceEvents, oobChannelCreationTime, - "OobChannel for " + addressGroup); - final OobChannel oobChannel = new OobChannel( - authority, balancerRpcExecutorPool, oobTransportFactory.getScheduledExecutorService(), - syncContext, callTracerFactory.create(), oobChannelTracer, channelz, timeProvider); - channelTracer.reportEvent(new ChannelTrace.Event.Builder() - .setDescription("Child OobChannel created") - .setSeverity(ChannelTrace.Event.Severity.CT_INFO) - .setTimestampNanos(oobChannelCreationTime) - .setChannelRef(oobChannel) - .build()); - ChannelTracer subchannelTracer = - new ChannelTracer(subchannelLogId, maxTraceEvents, oobChannelCreationTime, - "Subchannel for " + addressGroup); - ChannelLogger subchannelLogger = new ChannelLoggerImpl(subchannelTracer, timeProvider); - final class ManagedOobChannelCallback extends InternalSubchannel.Callback { - @Override - void onTerminated(InternalSubchannel is) { - oobChannels.remove(oobChannel); - channelz.removeSubchannel(is); - oobChannel.handleSubchannelTerminated(); - maybeTerminateChannel(); - } - - @Override - void onStateChange(InternalSubchannel is, ConnectivityStateInfo newState) { - // TODO(chengyuanzhang): change to let LB policies explicitly manage OOB channel's - // state and refresh name resolution if necessary. - handleInternalSubchannelState(newState); - oobChannel.handleSubchannelStateChange(newState); - } - } - - final InternalSubchannel internalSubchannel = new InternalSubchannel( - addressGroup, - authority, userAgent, backoffPolicyProvider, oobTransportFactory, - oobTransportFactory.getScheduledExecutorService(), stopwatchSupplier, syncContext, - // All callback methods are run from syncContext - new ManagedOobChannelCallback(), - channelz, - callTracerFactory.create(), - subchannelTracer, - subchannelLogId, - subchannelLogger, - transportFilters); - oobChannelTracer.reportEvent(new ChannelTrace.Event.Builder() - .setDescription("Child Subchannel created") - .setSeverity(ChannelTrace.Event.Severity.CT_INFO) - .setTimestampNanos(oobChannelCreationTime) - .setSubchannelRef(internalSubchannel) - .build()); - channelz.addSubchannel(oobChannel); - channelz.addSubchannel(internalSubchannel); - oobChannel.setSubchannel(internalSubchannel); - final class AddOobChannel implements Runnable { - @Override - public void run() { - if (terminating) { - oobChannel.shutdown(); - } - if (!terminated) { - // If channel has not terminated, it will track the subchannel and block termination - // for it. - oobChannels.add(oobChannel); - } - } + NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); + OobNameResolverProvider resolverProvider = + new OobNameResolverProvider(authority, addressGroup, syncContext); + nameResolverRegistry.register(resolverProvider); + // We could use a hard-coded target, as the name resolver won't actually use this string. + // However, that would make debugging less clear, as we use the target to identify the + // channel. + String target; + try { + target = new URI("oob", "", "/" + authority, null, null).toString(); + } catch (URISyntaxException ex) { + // Any special characters in the path will be percent encoded. So this should be impossible. + throw new AssertionError(ex); } - - syncContext.execute(new AddOobChannel()); - return oobChannel; + ManagedChannel delegate = createResolvingOobChannelBuilder( + target, new DefaultChannelCreds(), nameResolverRegistry) + // TODO(zdapeng): executors should not outlive the parent channel. + .executor(balancerRpcExecutorHolder.getExecutor()) + .idleTimeout(Integer.MAX_VALUE, TimeUnit.SECONDS) + .disableRetry() + .build(); + return new OobChannel(delegate, resolverProvider); } @Deprecated @@ -1608,11 +1426,17 @@ public ManagedChannelBuilder createResolvingOobChannelBuilder(String target) .overrideAuthority(getAuthority()); } - // TODO(creamsoup) prevent main channel to shutdown if oob channel is not terminated - // TODO(zdapeng) register the channel as a subchannel of the parent channel in channelz. @Override public ManagedChannelBuilder createResolvingOobChannelBuilder( final String target, final ChannelCredentials channelCreds) { + return createResolvingOobChannelBuilder(target, channelCreds, nameResolverRegistry); + } + + // TODO(creamsoup) prevent main channel to shutdown if oob channel is not terminated + // TODO(zdapeng) register the channel as a subchannel of the parent channel in channelz. + private ManagedChannelBuilder createResolvingOobChannelBuilder( + final String target, final ChannelCredentials channelCreds, + NameResolverRegistry nameResolverRegistry) { checkNotNull(channelCreds, "channelCreds"); final class ResolvingOobChannelBuilder @@ -1660,7 +1484,6 @@ protected ManagedChannelBuilder delegate() { checkState(!terminated, "Channel is terminated"); - @SuppressWarnings("deprecation") ResolvingOobChannelBuilder builder = new ResolvingOobChannelBuilder(); return builder @@ -1698,6 +1521,11 @@ public String getAuthority() { return ManagedChannelImpl.this.authority(); } + @Override + public String getChannelTarget() { + return targetUri.toString(); + } + @Override public SynchronizationContext getSynchronizationContext() { return syncContext; @@ -1723,6 +1551,11 @@ public NameResolverRegistry getNameResolverRegistry() { return nameResolverRegistry; } + @Override + public MetricRecorder getMetricRecorder() { + return metricRecorder; + } + /** * A placeholder for channel creds if user did not specify channel creds for the channel. */ @@ -1736,6 +1569,19 @@ public ChannelCredentials withoutBearerTokens() { } } + static final class OobChannel extends ForwardingManagedChannel { + private final OobNameResolverProvider resolverProvider; + + public OobChannel(ManagedChannel delegate, OobNameResolverProvider resolverProvider) { + super(delegate); + this.resolverProvider = checkNotNull(resolverProvider, "resolverProvider"); + } + + public void updateAddresses(List eags) { + resolverProvider.updateAddresses(eags); + } + } + final class NameResolverListener extends NameResolver.Listener2 { final LbHelperImpl helper; final NameResolver resolver; @@ -1747,148 +1593,145 @@ final class NameResolverListener extends NameResolver.Listener2 { @Override public void onResult(final ResolutionResult resolutionResult) { - final class NamesResolved implements Runnable { - - @SuppressWarnings("ReferenceEquality") - @Override - public void run() { - if (ManagedChannelImpl.this.nameResolver != resolver) { - return; - } + syncContext.execute(() -> onResult2(resolutionResult)); + } - List servers = resolutionResult.getAddresses(); + @SuppressWarnings("ReferenceEquality") + @Override + public Status onResult2(final ResolutionResult resolutionResult) { + syncContext.throwIfNotInThisSynchronizationContext(); + if (ManagedChannelImpl.this.nameResolver != resolver) { + return Status.OK; + } + + StatusOr> serversOrError = + resolutionResult.getAddressesOrError(); + if (!serversOrError.hasValue()) { + handleErrorInSyncContext(serversOrError.getStatus()); + return serversOrError.getStatus(); + } + List servers = serversOrError.getValue(); + channelLogger.log( + ChannelLogLevel.DEBUG, + "Resolved address: {0}, config={1}", + servers, + resolutionResult.getAttributes()); + + if (lastResolutionState != ResolutionState.SUCCESS) { + channelLogger.log(ChannelLogLevel.INFO, "Address resolved: {0}", + servers); + lastResolutionState = ResolutionState.SUCCESS; + } + ConfigOrError configOrError = resolutionResult.getServiceConfig(); + InternalConfigSelector resolvedConfigSelector = + resolutionResult.getAttributes().get(InternalConfigSelector.KEY); + ManagedChannelServiceConfig validServiceConfig = + configOrError != null && configOrError.getConfig() != null + ? (ManagedChannelServiceConfig) configOrError.getConfig() + : null; + Status serviceConfigError = configOrError != null ? configOrError.getError() : null; + + ManagedChannelServiceConfig effectiveServiceConfig; + if (!lookUpServiceConfig) { + if (validServiceConfig != null) { channelLogger.log( - ChannelLogLevel.DEBUG, - "Resolved address: {0}, config={1}", - servers, - resolutionResult.getAttributes()); - - if (lastResolutionState != ResolutionState.SUCCESS) { - channelLogger.log(ChannelLogLevel.INFO, "Address resolved: {0}", servers); - lastResolutionState = ResolutionState.SUCCESS; - } - - ConfigOrError configOrError = resolutionResult.getServiceConfig(); - ResolutionResultListener resolutionResultListener = resolutionResult.getAttributes() - .get(RetryingNameResolver.RESOLUTION_RESULT_LISTENER_KEY); - InternalConfigSelector resolvedConfigSelector = - resolutionResult.getAttributes().get(InternalConfigSelector.KEY); - ManagedChannelServiceConfig validServiceConfig = - configOrError != null && configOrError.getConfig() != null - ? (ManagedChannelServiceConfig) configOrError.getConfig() - : null; - Status serviceConfigError = configOrError != null ? configOrError.getError() : null; - - ManagedChannelServiceConfig effectiveServiceConfig; - if (!lookUpServiceConfig) { - if (validServiceConfig != null) { - channelLogger.log( - ChannelLogLevel.INFO, - "Service config from name resolver discarded by channel settings"); - } - effectiveServiceConfig = - defaultServiceConfig == null ? EMPTY_SERVICE_CONFIG : defaultServiceConfig; - if (resolvedConfigSelector != null) { + ChannelLogLevel.INFO, + "Service config from name resolver discarded by channel settings"); + } + effectiveServiceConfig = + defaultServiceConfig == null ? EMPTY_SERVICE_CONFIG : defaultServiceConfig; + if (resolvedConfigSelector != null) { + channelLogger.log( + ChannelLogLevel.INFO, + "Config selector from name resolver discarded by channel settings"); + } + realChannel.updateConfigSelector(effectiveServiceConfig.getDefaultConfigSelector()); + } else { + // Try to use config if returned from name resolver + // Otherwise, try to use the default config if available + if (validServiceConfig != null) { + effectiveServiceConfig = validServiceConfig; + if (resolvedConfigSelector != null) { + realChannel.updateConfigSelector(resolvedConfigSelector); + if (effectiveServiceConfig.getDefaultConfigSelector() != null) { channelLogger.log( - ChannelLogLevel.INFO, - "Config selector from name resolver discarded by channel settings"); + ChannelLogLevel.DEBUG, + "Method configs in service config will be discarded due to presence of" + + "config-selector"); } + } else { realChannel.updateConfigSelector(effectiveServiceConfig.getDefaultConfigSelector()); + } + } else if (defaultServiceConfig != null) { + effectiveServiceConfig = defaultServiceConfig; + realChannel.updateConfigSelector(effectiveServiceConfig.getDefaultConfigSelector()); + channelLogger.log( + ChannelLogLevel.INFO, + "Received no service config, using default service config"); + } else if (serviceConfigError != null) { + if (!serviceConfigUpdated) { + // First DNS lookup has invalid service config, and cannot fall back to default + channelLogger.log( + ChannelLogLevel.INFO, + "Fallback to error due to invalid first service config without default config"); + // This error could be an "inappropriate" control plane error that should not bleed + // through to client code using gRPC. We let them flow through here to the LB as + // we later check for these error codes when investigating pick results in + // GrpcUtil.getTransportFromPickResult(). + onError(configOrError.getError()); + return configOrError.getError(); } else { - // Try to use config if returned from name resolver - // Otherwise, try to use the default config if available - if (validServiceConfig != null) { - effectiveServiceConfig = validServiceConfig; - if (resolvedConfigSelector != null) { - realChannel.updateConfigSelector(resolvedConfigSelector); - if (effectiveServiceConfig.getDefaultConfigSelector() != null) { - channelLogger.log( - ChannelLogLevel.DEBUG, - "Method configs in service config will be discarded due to presence of" - + "config-selector"); - } - } else { - realChannel.updateConfigSelector(effectiveServiceConfig.getDefaultConfigSelector()); - } - } else if (defaultServiceConfig != null) { - effectiveServiceConfig = defaultServiceConfig; - realChannel.updateConfigSelector(effectiveServiceConfig.getDefaultConfigSelector()); - channelLogger.log( - ChannelLogLevel.INFO, - "Received no service config, using default service config"); - } else if (serviceConfigError != null) { - if (!serviceConfigUpdated) { - // First DNS lookup has invalid service config, and cannot fall back to default - channelLogger.log( - ChannelLogLevel.INFO, - "Fallback to error due to invalid first service config without default config"); - // This error could be an "inappropriate" control plane error that should not bleed - // through to client code using gRPC. We let them flow through here to the LB as - // we later check for these error codes when investigating pick results in - // GrpcUtil.getTransportFromPickResult(). - onError(configOrError.getError()); - if (resolutionResultListener != null) { - resolutionResultListener.resolutionAttempted(configOrError.getError()); - } - return; - } else { - effectiveServiceConfig = lastServiceConfig; - } - } else { - effectiveServiceConfig = EMPTY_SERVICE_CONFIG; - realChannel.updateConfigSelector(null); - } - if (!effectiveServiceConfig.equals(lastServiceConfig)) { - channelLogger.log( - ChannelLogLevel.INFO, - "Service config changed{0}", - effectiveServiceConfig == EMPTY_SERVICE_CONFIG ? " to empty" : ""); - lastServiceConfig = effectiveServiceConfig; - transportProvider.throttle = effectiveServiceConfig.getRetryThrottling(); - } - - try { - // TODO(creamsoup): when `servers` is empty and lastResolutionStateCopy == SUCCESS - // and lbNeedAddress, it shouldn't call the handleServiceConfigUpdate. But, - // lbNeedAddress is not deterministic - serviceConfigUpdated = true; - } catch (RuntimeException re) { - logger.log( - Level.WARNING, - "[" + getLogId() + "] Unexpected exception from parsing service config", - re); - } + effectiveServiceConfig = lastServiceConfig; } + } else { + effectiveServiceConfig = EMPTY_SERVICE_CONFIG; + realChannel.updateConfigSelector(null); + } + if (!effectiveServiceConfig.equals(lastServiceConfig)) { + channelLogger.log( + ChannelLogLevel.INFO, + "Service config changed{0}", + effectiveServiceConfig == EMPTY_SERVICE_CONFIG ? " to empty" : ""); + lastServiceConfig = effectiveServiceConfig; + transportProvider.throttle = effectiveServiceConfig.getRetryThrottling(); + } - Attributes effectiveAttrs = resolutionResult.getAttributes(); - // Call LB only if it's not shutdown. If LB is shutdown, lbHelper won't match. - if (NameResolverListener.this.helper == ManagedChannelImpl.this.lbHelper) { - Attributes.Builder attrBuilder = - effectiveAttrs.toBuilder().discard(InternalConfigSelector.KEY); - Map healthCheckingConfig = - effectiveServiceConfig.getHealthCheckingConfig(); - if (healthCheckingConfig != null) { - attrBuilder - .set(LoadBalancer.ATTR_HEALTH_CHECKING_CONFIG, healthCheckingConfig) - .build(); - } - Attributes attributes = attrBuilder.build(); - - Status addressAcceptanceStatus = helper.lb.tryAcceptResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(servers) - .setAttributes(attributes) - .setLoadBalancingPolicyConfig(effectiveServiceConfig.getLoadBalancingConfig()) - .build()); - // If a listener is provided, let it know if the addresses were accepted. - if (resolutionResultListener != null) { - resolutionResultListener.resolutionAttempted(addressAcceptanceStatus); - } - } + try { + // TODO(creamsoup): when `serversOrError` is empty and lastResolutionStateCopy == SUCCESS + // and lbNeedAddress, it shouldn't call the handleServiceConfigUpdate. But, + // lbNeedAddress is not deterministic + serviceConfigUpdated = true; + } catch (RuntimeException re) { + logger.log( + Level.WARNING, + "[" + getLogId() + "] Unexpected exception from parsing service config", + re); } } - syncContext.execute(new NamesResolved()); + Attributes effectiveAttrs = resolutionResult.getAttributes(); + // Call LB only if it's not shutdown. If LB is shutdown, lbHelper won't match. + if (NameResolverListener.this.helper == ManagedChannelImpl.this.lbHelper) { + Attributes.Builder attrBuilder = + effectiveAttrs.toBuilder().discard(InternalConfigSelector.KEY); + Map healthCheckingConfig = + effectiveServiceConfig.getHealthCheckingConfig(); + if (healthCheckingConfig != null) { + attrBuilder + .set(LoadBalancer.ATTR_HEALTH_CHECKING_CONFIG, healthCheckingConfig) + .build(); + } + Attributes attributes = attrBuilder.build(); + + ResolvedAddresses.Builder resolvedAddresses = ResolvedAddresses.newBuilder() + .setAddresses(serversOrError.getValue()) + .setAttributes(attributes) + .setLoadBalancingPolicyConfig(effectiveServiceConfig.getLoadBalancingConfig()); + Status addressAcceptanceStatus = helper.lb.acceptResolvedAddresses( + resolvedAddresses.build()); + return addressAcceptanceStatus; + } + return Status.OK; } @Override @@ -1982,7 +1825,7 @@ void onNotInUse(InternalSubchannel is) { } final InternalSubchannel internalSubchannel = new InternalSubchannel( - args.getAddresses(), + args, authority(), userAgent, backoffPolicyProvider, @@ -1996,7 +1839,8 @@ void onNotInUse(InternalSubchannel is) { subchannelTracer, subchannelLogId, subchannelLogger, - transportFilters); + transportFilters, target, + lbHelper.getMetricRecorder()); channelTracer.reportEvent(new ChannelTrace.Event.Builder() .setDescription("Child Subchannel started") @@ -2068,6 +1912,9 @@ public void run() { public void requestConnection() { syncContext.throwIfNotInThisSynchronizationContext(); checkState(started, "not started"); + if (shutdown) { + return; + } subchannel.obtainActiveTransport(); } @@ -2119,6 +1966,11 @@ public void updateAddresses(List addrs) { subchannel.updateAddresses(addrs); } + @Override + public Attributes getConnectedAddressAttributes() { + return subchannel.getConnectedAddressAttributes(); + } + private List stripOverrideAuthorityAttributes( List eags) { List eagsWithoutOverrideAttr = new ArrayList<>(); @@ -2145,7 +1997,7 @@ public String toString() { */ private final class DelayedTransportListener implements ManagedClientTransport.Listener { @Override - public void transportShutdown(Status s) { + public void transportShutdown(Status s, DisconnectError e) { checkState(shutdown.get(), "Channel must have been shut down"); } @@ -2162,6 +2014,12 @@ public Attributes filterTransport(Attributes attributes) { @Override public void transportInUse(final boolean inUse) { inUseStateAggregator.updateObjectInUse(delayedTransport, inUse); + if (inUse) { + // It's possible to be in idle mode while inUseStateAggregator is in-use, if one of the + // subchannels is in use. But we should never be in idle mode when delayed transport is in + // use. + exitIdleMode(); + } } @Override diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java b/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java index 1e40e547755..128c929ec0e 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java @@ -18,6 +18,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.internal.UriWrapper.wrap; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; @@ -26,20 +27,28 @@ import io.grpc.Attributes; import io.grpc.BinaryLog; import io.grpc.CallCredentials; +import io.grpc.CallOptions; +import io.grpc.Channel; import io.grpc.ChannelCredentials; +import io.grpc.ClientCall; import io.grpc.ClientInterceptor; import io.grpc.ClientTransportFilter; import io.grpc.CompressorRegistry; import io.grpc.DecompressorRegistry; import io.grpc.EquivalentAddressGroup; import io.grpc.InternalChannelz; -import io.grpc.InternalGlobalInterceptors; +import io.grpc.InternalConfiguratorRegistry; +import io.grpc.InternalFeatureFlags; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; +import io.grpc.MethodDescriptor; +import io.grpc.MetricSink; import io.grpc.NameResolver; import io.grpc.NameResolverProvider; import io.grpc.NameResolverRegistry; import io.grpc.ProxyDetector; +import io.grpc.StatusOr; +import io.grpc.Uri; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.net.SocketAddress; @@ -49,6 +58,7 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.IdentityHashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -56,6 +66,7 @@ import java.util.concurrent.TimeUnit; import java.util.logging.Level; import java.util.logging.Logger; +import java.util.regex.Pattern; import javax.annotation.Nullable; /** @@ -108,6 +119,12 @@ public static ManagedChannelBuilder forTarget(String target) { private static final long DEFAULT_RETRY_BUFFER_SIZE_IN_BYTES = 1L << 24; // 16M private static final long DEFAULT_PER_RPC_BUFFER_LIMIT_IN_BYTES = 1L << 20; // 1M + // Matching this pattern means the target string is a URI target or at least intended to be one. + // A URI target must be an absolute hierarchical URI. + // From RFC 2396: scheme = alpha *( alpha | digit | "+" | "-" | "." ) + @VisibleForTesting + static final Pattern URI_PATTERN = Pattern.compile("[a-zA-Z][a-zA-Z0-9+.-]*:/.*"); + private static final Method GET_CLIENT_INTERCEPTOR_METHOD; static { @@ -146,6 +163,8 @@ public static ManagedChannelBuilder forTarget(String target) { final ChannelCredentials channelCredentials; @Nullable final CallCredentials callCredentials; + @Nullable + IdentityHashMap, Object> nameResolverCustomArgs; @Nullable private final SocketAddress directServerAddress; @@ -192,6 +211,7 @@ public static ManagedChannelBuilder forTarget(String target) { private boolean recordRealTimeMetrics = false; private boolean recordRetryMetrics = true; private boolean tracingEnabled = true; + List metricSinks = new ArrayList<>(); /** * An interface for Transport implementors to provide the {@link ClientTransportFactory} @@ -283,6 +303,8 @@ public ManagedChannelImplBuilder( } else { this.channelBuilderDefaultPortProvider = new ManagedChannelDefaultPortProvider(); } + // TODO(dnvindhya): Move configurator to all the individual builders + InternalConfiguratorRegistry.configureChannelBuilder(this); } /** @@ -340,6 +362,8 @@ public ManagedChannelImplBuilder(SocketAddress directServerAddress, String autho } else { this.channelBuilderDefaultPortProvider = new ManagedChannelDefaultPortProvider(); } + // TODO(dnvindhya): Move configurator to all the individual builders + InternalConfiguratorRegistry.configureChannelBuilder(this); } @Override @@ -378,6 +402,14 @@ public ManagedChannelImplBuilder intercept(ClientInterceptor... interceptors) { return intercept(Arrays.asList(interceptors)); } + @Override + protected ManagedChannelImplBuilder interceptWithTarget(InterceptorFactory factory) { + // Add a placeholder instance to the interceptor list, and replace it with a real instance + // during build(). + this.interceptors.add(new InterceptorFactoryWrapper(factory)); + return this; + } + @Override public ManagedChannelImplBuilder addTransportFilter(ClientTransportFilter hook) { transportFilters.add(checkNotNull(hook, "transport filter")); @@ -587,6 +619,24 @@ private static List checkListEntryTypes(List list) { return Collections.unmodifiableList(parsedList); } + @Override + public ManagedChannelImplBuilder setNameResolverArg(NameResolver.Args.Key key, X value) { + if (nameResolverCustomArgs == null) { + nameResolverCustomArgs = new IdentityHashMap<>(); + } + nameResolverCustomArgs.put(checkNotNull(key, "key"), checkNotNull(value, "value")); + return this; + } + + @SuppressWarnings("unchecked") // This cast is safe because of setNameResolverArg()'s signature. + void copyAllNameResolverCustomArgsTo(NameResolver.Args.Builder dest) { + if (nameResolverCustomArgs != null) { + for (Map.Entry, Object> entry : nameResolverCustomArgs.entrySet()) { + dest.setArg((NameResolver.Args.Key) entry.getKey(), entry.getValue()); + } + } + } + @Override public ManagedChannelImplBuilder disableServiceConfigLookUp() { this.lookUpServiceConfig = false; @@ -661,15 +711,30 @@ public ManagedChannelImplBuilder enableCheckAuthority() { return this; } + @Override + protected ManagedChannelImplBuilder addMetricSink(MetricSink metricSink) { + metricSinks.add(checkNotNull(metricSink, "metric sink")); + return this; + } + @Override public ManagedChannel build() { + ClientTransportFactory clientTransportFactory = + clientTransportFactoryBuilder.buildClientTransportFactory(); + ResolvedNameResolver resolvedResolver = + InternalFeatureFlags.getRfc3986UrisEnabled() + ? getNameResolverProviderRfc3986(target, nameResolverRegistry) + : getNameResolverProvider(target, nameResolverRegistry); + resolvedResolver.checkAddressTypes(clientTransportFactory.getSupportedSocketAddressTypes()); return new ManagedChannelOrphanWrapper(new ManagedChannelImpl( this, - clientTransportFactoryBuilder.buildClientTransportFactory(), + clientTransportFactory, + resolvedResolver.targetUri, + resolvedResolver.provider, new ExponentialBackoffPolicy.Provider(), SharedResourcePool.forResource(GrpcUtil.SHARED_CHANNEL_EXECUTOR), GrpcUtil.STOPWATCH_SUPPLIER, - getEffectiveInterceptors(), + getEffectiveInterceptors(resolvedResolver.targetUri.toString()), TimeProvider.SYSTEM_TIME_PROVIDER)); } @@ -677,22 +742,30 @@ public ManagedChannel build() { // what should be the desired behavior for retry + stats/tracing. // TODO(zdapeng): FIX IT @VisibleForTesting - List getEffectiveInterceptors() { - List effectiveInterceptors = new ArrayList<>(this.interceptors); - boolean isGlobalInterceptorsSet = false; - List globalClientInterceptors = - InternalGlobalInterceptors.getClientInterceptors(); - if (globalClientInterceptors != null) { - effectiveInterceptors.addAll(globalClientInterceptors); - isGlobalInterceptorsSet = true; - } - if (!isGlobalInterceptorsSet && statsEnabled) { + List getEffectiveInterceptors(String computedTarget) { + List effectiveInterceptors = new ArrayList<>(this.interceptors.size()); + for (ClientInterceptor interceptor : this.interceptors) { + if (interceptor instanceof InterceptorFactoryWrapper) { + InterceptorFactory factory = ((InterceptorFactoryWrapper) interceptor).factory; + interceptor = factory.newInterceptor(computedTarget); + if (interceptor == null) { + throw new NullPointerException("Factory returned null interceptor: " + factory); + } + } + effectiveInterceptors.add(interceptor); + } + + boolean disableImplicitCensus = InternalConfiguratorRegistry.wasSetConfiguratorsCalled(); + if (disableImplicitCensus) { + return effectiveInterceptors; + } + if (statsEnabled) { ClientInterceptor statsInterceptor = null; if (GET_CLIENT_INTERCEPTOR_METHOD != null) { try { statsInterceptor = - (ClientInterceptor) GET_CLIENT_INTERCEPTOR_METHOD + (ClientInterceptor) GET_CLIENT_INTERCEPTOR_METHOD .invoke( null, recordStartedRpcs, @@ -712,7 +785,7 @@ List getEffectiveInterceptors() { effectiveInterceptors.add(0, statsInterceptor); } } - if (!isGlobalInterceptorsSet && tracingEnabled) { + if (tracingEnabled) { ClientInterceptor tracingInterceptor = null; try { Class censusTracingAccessor = @@ -745,6 +818,114 @@ int getDefaultPort() { return channelBuilderDefaultPortProvider.getDefaultPort(); } + @VisibleForTesting + static class ResolvedNameResolver { + public final UriWrapper targetUri; + public final NameResolverProvider provider; + + public ResolvedNameResolver(UriWrapper targetUri, NameResolverProvider provider) { + this.targetUri = checkNotNull(targetUri, "targetUri"); + this.provider = checkNotNull(provider, "provider"); + } + + void checkAddressTypes( + Collection> channelTransportSocketAddressTypes) { + if (channelTransportSocketAddressTypes != null) { + Collection> nameResolverSocketAddressTypes = + provider.getProducedSocketAddressTypes(); + if (!channelTransportSocketAddressTypes.containsAll(nameResolverSocketAddressTypes)) { + throw new IllegalArgumentException( + String.format( + "Address types of NameResolver '%s' for '%s' not supported by transport", + provider.getDefaultScheme(), targetUri)); + } + } + } + } + + @VisibleForTesting + static ResolvedNameResolver getNameResolverProvider( + String target, NameResolverRegistry nameResolverRegistry) { + // Finding a NameResolver. Try using the target string as the URI. If that fails, try prepending + // "dns:///". + NameResolverProvider provider = null; + URI targetUri = null; + StringBuilder uriSyntaxErrors = new StringBuilder(); + try { + targetUri = new URI(target); + } catch (URISyntaxException e) { + // Can happen with ip addresses like "[::1]:1234" or 127.0.0.1:1234. + uriSyntaxErrors.append(e.getMessage()); + } + if (targetUri != null) { + // For "localhost:8080" this would likely cause provider to be null, because "localhost" is + // parsed as the scheme. Will hit the next case and try "dns:///localhost:8080". + provider = nameResolverRegistry.getProviderForScheme(targetUri.getScheme()); + } + + if (provider == null && !URI_PATTERN.matcher(target).matches()) { + // It doesn't look like a URI target. Maybe it's an authority string. Try with the default + // scheme from the registry. + try { + targetUri = new URI(nameResolverRegistry.getDefaultScheme(), "", "/" + target, null); + } catch (URISyntaxException e) { + // Should not be possible. + throw new IllegalArgumentException(e); + } + provider = nameResolverRegistry.getProviderForScheme(targetUri.getScheme()); + } + + if (provider == null) { + throw new IllegalArgumentException(String.format( + "Could not find a NameResolverProvider for %s%s", + target, uriSyntaxErrors.length() > 0 ? " (" + uriSyntaxErrors + ")" : "")); + } + + return new ResolvedNameResolver(wrap(targetUri), provider); + } + + @VisibleForTesting + static ResolvedNameResolver getNameResolverProviderRfc3986( + String target, NameResolverRegistry nameResolverRegistry) { + // Finding a NameResolver. Try using the target string as the URI. If that fails, try prepending + // "dns:///". + NameResolverProvider provider = null; + Uri targetUri = null; + StringBuilder uriSyntaxErrors = new StringBuilder(); + try { + targetUri = Uri.parse(target); + } catch (URISyntaxException e) { + // Can happen with ip addresses like "[::1]:1234" or 127.0.0.1:1234. + uriSyntaxErrors.append(e.getMessage()); + } + if (targetUri != null) { + // For "localhost:8080" this would likely cause provider to be null, because "localhost" is + // parsed as the scheme. Will hit the next case and try "dns:///localhost:8080". + provider = nameResolverRegistry.getProviderForScheme(targetUri.getScheme()); + } + + if (provider == null && !URI_PATTERN.matcher(target).matches()) { + // It doesn't look like a URI target. Maybe it's an authority string. Try with the default + // scheme from the registry. + targetUri = + Uri.newBuilder() + .setScheme(nameResolverRegistry.getDefaultScheme()) + .setHost("") + .setPath("/" + target) + .build(); + provider = nameResolverRegistry.getProviderForScheme(targetUri.getScheme()); + } + + if (provider == null) { + throw new IllegalArgumentException( + String.format( + "Could not find a NameResolverProvider for %s%s", + target, uriSyntaxErrors.length() > 0 ? " (" + uriSyntaxErrors + ")" : "")); + } + + return new ResolvedNameResolver(wrap(targetUri), provider); + } + private static class DirectAddressNameResolverProvider extends NameResolverProvider { final SocketAddress address; final String authority; @@ -767,9 +948,11 @@ public String getServiceAuthority() { @Override public void start(Listener2 listener) { - listener.onResult( + listener.onResult2( ResolutionResult.newBuilder() - .setAddresses(Collections.singletonList(new EquivalentAddressGroup(address))) + .setAddressesOrError( + StatusOr.fromValue( + Collections.singletonList(new EquivalentAddressGroup(address)))) .setAttributes(Attributes.EMPTY) .build()); } @@ -800,6 +983,20 @@ public Collection> getProducedSocketAddressTypes( } } + private static final class InterceptorFactoryWrapper implements ClientInterceptor { + final InterceptorFactory factory; + + public InterceptorFactoryWrapper(InterceptorFactory factory) { + this.factory = checkNotNull(factory, "factory"); + } + + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + throw new AssertionError("Should have been replaced with real instance"); + } + } + /** * Returns the internal offload executor pool for offloading tasks. */ diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelOrphanWrapper.java b/core/src/main/java/io/grpc/internal/ManagedChannelOrphanWrapper.java index eac9b64d9db..790d5bd297f 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelOrphanWrapper.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelOrphanWrapper.java @@ -63,12 +63,20 @@ final class ManagedChannelOrphanWrapper extends ForwardingManagedChannel { @Override public ManagedChannel shutdown() { phantom.clearSafely(); + // This dummy check prevents the JIT from collecting 'this' too early + if (this.getClass() == null) { + throw new AssertionError(); + } return super.shutdown(); } @Override public ManagedChannel shutdownNow() { phantom.clearSafely(); + // This dummy check prevents the JIT from collecting 'this' too early + if (this.getClass() == null) { + throw new AssertionError(); + } return super.shutdownNow(); } @@ -151,8 +159,9 @@ static int cleanQueue(ReferenceQueue refqueue) { int orphanedChannels = 0; while ((ref = (ManagedChannelReference) refqueue.poll()) != null) { RuntimeException maybeAllocationSite = ref.allocationSite.get(); + boolean wasShutdown = ref.shutdown.get(); ref.clearInternal(); // technically the reference is gone already. - if (!ref.shutdown.get()) { + if (!wasShutdown) { orphanedChannels++; Level level = Level.SEVERE; if (logger.isLoggable(level)) { diff --git a/core/src/main/java/io/grpc/internal/ManagedClientTransport.java b/core/src/main/java/io/grpc/internal/ManagedClientTransport.java index 1f18e317849..99a3bd1eceb 100644 --- a/core/src/main/java/io/grpc/internal/ManagedClientTransport.java +++ b/core/src/main/java/io/grpc/internal/ManagedClientTransport.java @@ -16,11 +16,10 @@ package io.grpc.internal; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.Attributes; import io.grpc.Status; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; -import javax.annotation.concurrent.ThreadSafe; /** * A {@link ClientTransport} that has life-cycle management. @@ -32,16 +31,15 @@ * implementations may transfer the streams to somewhere else. Either way they must conform to the * contract defined by {@link #shutdown}, {@link Listener#transportShutdown} and * {@link Listener#transportTerminated}. + * + *

This interface is thread-safe. */ -@ThreadSafe public interface ManagedClientTransport extends ClientTransport { /** * Starts transport. This method may only be called once. * - *

Implementations must not call {@code listener} from within {@link #start}; implementations - * are expected to notify listener on a separate thread or when the returned {@link Runnable} is - * run. This method and the returned {@code Runnable} should not throw any exceptions. + *

This method and the returned {@code Runnable} should not throw any exceptions. * * @param listener non-{@code null} listener of transport events * @return a {@link Runnable} that is executed after-the-fact by the original caller, typically @@ -79,8 +77,9 @@ interface Listener { *

This is called exactly once, and must be called prior to {@link #transportTerminated}. * * @param s the reason for the shutdown. + * @param e the disconnect error. */ - void transportShutdown(Status s); + void transportShutdown(Status s, DisconnectError e); /** * The transport completed shutting down. All resources have been released. All streams have diff --git a/core/src/main/java/io/grpc/internal/MessageDeframer.java b/core/src/main/java/io/grpc/internal/MessageDeframer.java index c8b250c2143..f388c006e97 100644 --- a/core/src/main/java/io/grpc/internal/MessageDeframer.java +++ b/core/src/main/java/io/grpc/internal/MessageDeframer.java @@ -314,6 +314,12 @@ private boolean readRequiredBytes() { int totalBytesRead = 0; int deflatedBytesRead = 0; try { + // Avoid allocating nextFrame when idle + if (requiredLength > 0 && fullStreamDecompressor == null + && unprocessed.readableBytes() == 0) { + return false; + } + if (nextFrame == null) { nextFrame = new CompositeReadableBuffer(); } @@ -406,7 +412,8 @@ private void processBody() { // There is no reliable way to get the uncompressed size per message when it's compressed, // because the uncompressed bytes are provided through an InputStream whose total size is // unknown until all bytes are read, and we don't know when it happens. - statsTraceCtx.inboundMessageRead(currentMessageSeqNo, inboundBodyWireSize, -1); + statsTraceCtx.inboundMessageRead(currentMessageSeqNo, inboundBodyWireSize, + (compressedFlag || fullStreamDecompressor != null) ? -1 : inboundBodyWireSize); inboundBodyWireSize = 0; InputStream stream = compressedFlag ? getCompressedBody() : getUncompressedBody(); nextFrame.touch(); diff --git a/core/src/main/java/io/grpc/internal/MessageFramer.java b/core/src/main/java/io/grpc/internal/MessageFramer.java index 5e75fa2e6fe..8b5ccb864a4 100644 --- a/core/src/main/java/io/grpc/internal/MessageFramer.java +++ b/core/src/main/java/io/grpc/internal/MessageFramer.java @@ -75,6 +75,10 @@ void deliverFrame( // effectively final. Can only be set once. private int maxOutboundMessageSize = NO_MAX_OUTBOUND_MESSAGE_SIZE; private WritableBuffer buffer; + /** + * if > 0 - the number of bytes to allocate for the current known-length message. + */ + private int knownLengthPendingAllocation; private Compressor compressor = Codec.Identity.NONE; private boolean messageCompression = true; private final OutputStreamAdapter outputStreamAdapter = new OutputStreamAdapter(); @@ -222,9 +226,7 @@ private int writeKnownLengthUncompressed(InputStream message, int messageLength) headerScratch.put(UNCOMPRESSED).putInt(messageLength); // Allocate the initial buffer chunk based on frame header + payload length. // Note that the allocator may allocate a buffer larger or smaller than this length - if (buffer == null) { - buffer = bufferAllocator.allocate(headerScratch.position() + messageLength); - } + knownLengthPendingAllocation = HEADER_LENGTH + messageLength; writeRaw(headerScratch.array(), 0, headerScratch.position()); return writeToOutputStream(message, outputStreamAdapter); } @@ -288,8 +290,9 @@ private void writeRaw(byte[] b, int off, int len) { commitToSink(false, false); } if (buffer == null) { - // Request a buffer allocation using the message length as a hint. - buffer = bufferAllocator.allocate(len); + checkState(knownLengthPendingAllocation > 0, "knownLengthPendingAllocation reached 0"); + buffer = bufferAllocator.allocate(knownLengthPendingAllocation); + knownLengthPendingAllocation -= min(knownLengthPendingAllocation, buffer.writableBytes()); } int toWrite = min(len, buffer.writableBytes()); buffer.write(b, off, toWrite); @@ -388,6 +391,8 @@ public void write(byte[] b, int off, int len) { * {@link OutputStream}. */ private final class BufferChainOutputStream extends OutputStream { + private static final int FIRST_BUFFER_SIZE = 4096; + private final List bufferList = new ArrayList<>(); private WritableBuffer current; @@ -397,7 +402,7 @@ private final class BufferChainOutputStream extends OutputStream { * {@link #write(byte[], int, int)}. */ @Override - public void write(int b) throws IOException { + public void write(int b) { if (current != null && current.writableBytes() > 0) { current.write((byte)b); return; @@ -410,7 +415,7 @@ public void write(int b) throws IOException { public void write(byte[] b, int off, int len) { if (current == null) { // Request len bytes initially from the allocator, it may give us more. - current = bufferAllocator.allocate(len); + current = bufferAllocator.allocate(Math.max(FIRST_BUFFER_SIZE, len)); bufferList.add(current); } while (len > 0) { diff --git a/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java b/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java index 12cab15053f..166f97b78f5 100644 --- a/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java +++ b/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java @@ -20,6 +20,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.CallCredentials.MetadataApplier; import io.grpc.CallOptions; import io.grpc.ClientStreamTracer; @@ -28,7 +29,6 @@ import io.grpc.MethodDescriptor; import io.grpc.Status; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; final class MetadataApplierImpl extends MetadataApplier { private final ClientTransport transport; @@ -120,7 +120,7 @@ ClientStream returnStream() { synchronized (lock) { if (returnedStream == null) { // apply() has not been called, needs to buffer the requests. - delayedStream = new DelayedStream(); + delayedStream = new DelayedStream("call_credentials"); return returnedStream = delayedStream; } else { return returnedStream; diff --git a/core/src/main/java/io/grpc/internal/MetricRecorderImpl.java b/core/src/main/java/io/grpc/internal/MetricRecorderImpl.java new file mode 100644 index 00000000000..6a12a38d677 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/MetricRecorderImpl.java @@ -0,0 +1,235 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import io.grpc.CallbackMetricInstrument; +import io.grpc.DoubleCounterMetricInstrument; +import io.grpc.DoubleHistogramMetricInstrument; +import io.grpc.LongCounterMetricInstrument; +import io.grpc.LongGaugeMetricInstrument; +import io.grpc.LongHistogramMetricInstrument; +import io.grpc.LongUpDownCounterMetricInstrument; +import io.grpc.MetricInstrument; +import io.grpc.MetricInstrumentRegistry; +import io.grpc.MetricRecorder; +import io.grpc.MetricSink; +import java.util.ArrayList; +import java.util.BitSet; +import java.util.List; + +/** + * Provides a central point for gRPC components to record metric values. Metrics can be exported to + * monitoring systems by configuring one or more {@link MetricSink}s. + * + *

This class encapsulates the interaction with metric sinks, including updating them with + * the latest set of {@link MetricInstrument}s provided by the {@link MetricInstrumentRegistry}. + */ +final class MetricRecorderImpl implements MetricRecorder { + + private final List metricSinks; + private final MetricInstrumentRegistry registry; + + @VisibleForTesting + MetricRecorderImpl(List metricSinks, MetricInstrumentRegistry registry) { + this.metricSinks = ImmutableList.copyOf(metricSinks); + this.registry = registry; + } + + /** + * Records a double counter value. + * + * @param metricInstrument the {@link DoubleCounterMetricInstrument} to record. + * @param value the value to record. + * @param requiredLabelValues the required label values for the metric. + * @param optionalLabelValues the optional label values for the metric. + */ + @Override + public void addDoubleCounter(DoubleCounterMetricInstrument metricInstrument, double value, + List requiredLabelValues, List optionalLabelValues) { + MetricRecorder.super.addDoubleCounter(metricInstrument, value, requiredLabelValues, + optionalLabelValues); + for (MetricSink sink : metricSinks) { + // TODO(dnvindhya): Move updating measures logic from sink to here + int measuresSize = sink.getMeasuresSize(); + if (measuresSize <= metricInstrument.getIndex()) { + // Measures may need updating in two cases: + // 1. When the sink is initially created with an empty list of measures. + // 2. When new metric instruments are registered, requiring the sink to accommodate them. + sink.updateMeasures(registry.getMetricInstruments()); + } + sink.addDoubleCounter(metricInstrument, value, requiredLabelValues, optionalLabelValues); + } + } + + /** + * Records a long counter value. + * + * @param metricInstrument the {@link LongCounterMetricInstrument} to record. + * @param value the value to record. Must be non-negative. + * @param requiredLabelValues the required label values for the metric. + * @param optionalLabelValues the optional label values for the metric. + */ + @Override + public void addLongCounter(LongCounterMetricInstrument metricInstrument, long value, + List requiredLabelValues, List optionalLabelValues) { + MetricRecorder.super.addLongCounter(metricInstrument, value, requiredLabelValues, + optionalLabelValues); + for (MetricSink sink : metricSinks) { + int measuresSize = sink.getMeasuresSize(); + if (measuresSize <= metricInstrument.getIndex()) { + // Measures may need updating in two cases: + // 1. When the sink is initially created with an empty list of measures. + // 2. When new metric instruments are registered, requiring the sink to accommodate them. + sink.updateMeasures(registry.getMetricInstruments()); + } + sink.addLongCounter(metricInstrument, value, requiredLabelValues, optionalLabelValues); + } + } + + /** + * Adds a long up down counter value. + * + * @param metricInstrument the {@link io.grpc.LongUpDownCounterMetricInstrument} to record. + * @param value the value to record. May be positive, negative or zero. + * @param requiredLabelValues the required label values for the metric. + * @param optionalLabelValues the optional label values for the metric. + */ + @Override + public void addLongUpDownCounter(LongUpDownCounterMetricInstrument metricInstrument, long value, + List requiredLabelValues, + List optionalLabelValues) { + MetricRecorder.super.addLongUpDownCounter(metricInstrument, value, requiredLabelValues, + optionalLabelValues); + for (MetricSink sink : metricSinks) { + int measuresSize = sink.getMeasuresSize(); + if (measuresSize <= metricInstrument.getIndex()) { + // Measures may need updating in two cases: + // 1. When the sink is initially created with an empty list of measures. + // 2. When new metric instruments are registered, requiring the sink to accommodate them. + sink.updateMeasures(registry.getMetricInstruments()); + } + sink.addLongUpDownCounter(metricInstrument, value, requiredLabelValues, optionalLabelValues); + } + } + + /** + * Records a double histogram value. + * + * @param metricInstrument the {@link DoubleHistogramMetricInstrument} to record. + * @param value the value to record. + * @param requiredLabelValues the required label values for the metric. + * @param optionalLabelValues the optional label values for the metric. + */ + @Override + public void recordDoubleHistogram(DoubleHistogramMetricInstrument metricInstrument, double value, + List requiredLabelValues, List optionalLabelValues) { + MetricRecorder.super.recordDoubleHistogram(metricInstrument, value, requiredLabelValues, + optionalLabelValues); + for (MetricSink sink : metricSinks) { + int measuresSize = sink.getMeasuresSize(); + if (measuresSize <= metricInstrument.getIndex()) { + // Measures may need updating in two cases: + // 1. When the sink is initially created with an empty list of measures. + // 2. When new metric instruments are registered, requiring the sink to accommodate them. + sink.updateMeasures(registry.getMetricInstruments()); + } + sink.recordDoubleHistogram(metricInstrument, value, requiredLabelValues, optionalLabelValues); + } + } + + /** + * Records a long histogram value. + * + * @param metricInstrument the {@link LongHistogramMetricInstrument} to record. + * @param value the value to record. + * @param requiredLabelValues the required label values for the metric. + * @param optionalLabelValues the optional label values for the metric. + */ + @Override + public void recordLongHistogram(LongHistogramMetricInstrument metricInstrument, long value, + List requiredLabelValues, List optionalLabelValues) { + MetricRecorder.super.recordLongHistogram(metricInstrument, value, requiredLabelValues, + optionalLabelValues); + for (MetricSink sink : metricSinks) { + int measuresSize = sink.getMeasuresSize(); + if (measuresSize <= metricInstrument.getIndex()) { + // Measures may need updating in two cases: + // 1. When the sink is initially created with an empty list of measures. + // 2. When new metric instruments are registered, requiring the sink to accommodate them. + sink.updateMeasures(registry.getMetricInstruments()); + } + sink.recordLongHistogram(metricInstrument, value, requiredLabelValues, optionalLabelValues); + } + } + + @Override + public Registration registerBatchCallback(BatchCallback callback, + CallbackMetricInstrument... metricInstruments) { + long largestMetricInstrumentIndex = -1; + BitSet allowedInstruments = new BitSet(); + for (CallbackMetricInstrument metricInstrument : metricInstruments) { + largestMetricInstrumentIndex = + Math.max(largestMetricInstrumentIndex, metricInstrument.getIndex()); + allowedInstruments.set(metricInstrument.getIndex()); + } + List registrations = new ArrayList<>(); + for (MetricSink sink : metricSinks) { + int measuresSize = sink.getMeasuresSize(); + if (measuresSize <= largestMetricInstrumentIndex) { + // Measures may need updating in two cases: + // 1. When the sink is initially created with an empty list of measures. + // 2. When new metric instruments are registered, requiring the sink to accommodate them. + sink.updateMeasures(registry.getMetricInstruments()); + } + BatchRecorder singleSinkRecorder = new BatchRecorderImpl(sink, allowedInstruments); + registrations.add(sink.registerBatchCallback( + () -> callback.accept(singleSinkRecorder), metricInstruments)); + } + return () -> { + for (MetricSink.Registration registration : registrations) { + registration.close(); + } + }; + } + + /** Recorder for instrument values produced by a batch callback. */ + static class BatchRecorderImpl implements BatchRecorder { + private final MetricSink sink; + private final BitSet allowedInstruments; + + BatchRecorderImpl(MetricSink sink, BitSet allowedInstruments) { + this.sink = checkNotNull(sink, "sink"); + this.allowedInstruments = checkNotNull(allowedInstruments, "allowedInstruments"); + } + + @Override + public void recordLongGauge(LongGaugeMetricInstrument metricInstrument, long value, + List requiredLabelValues, List optionalLabelValues) { + BatchRecorder.super.recordLongGauge(metricInstrument, value, requiredLabelValues, + optionalLabelValues); + checkArgument(allowedInstruments.get(metricInstrument.getIndex()), + "Instrument was not listed when registering callback: %s", metricInstrument); + // Registering the callback checked that the instruments were be present in sink. + sink.recordLongGauge(metricInstrument, value, requiredLabelValues, optionalLabelValues); + } + } +} diff --git a/core/src/main/java/io/grpc/internal/MigratingThreadDeframer.java b/core/src/main/java/io/grpc/internal/MigratingThreadDeframer.java index c3342556c9f..e4f499ab483 100644 --- a/core/src/main/java/io/grpc/internal/MigratingThreadDeframer.java +++ b/core/src/main/java/io/grpc/internal/MigratingThreadDeframer.java @@ -18,6 +18,7 @@ import static com.google.common.base.Preconditions.checkNotNull; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Decompressor; import io.perfmark.Link; import io.perfmark.PerfMark; @@ -26,7 +27,6 @@ import java.io.InputStream; import java.util.ArrayDeque; import java.util.Queue; -import javax.annotation.concurrent.GuardedBy; /** * A deframer that moves decoding between the transport and app threads based on which is more diff --git a/core/src/main/java/io/grpc/internal/NameResolverFactoryToProviderFacade.java b/core/src/main/java/io/grpc/internal/NameResolverFactoryToProviderFacade.java index 31c20f6e499..e52eb5e38d4 100644 --- a/core/src/main/java/io/grpc/internal/NameResolverFactoryToProviderFacade.java +++ b/core/src/main/java/io/grpc/internal/NameResolverFactoryToProviderFacade.java @@ -19,6 +19,7 @@ import io.grpc.NameResolver; import io.grpc.NameResolver.Args; import io.grpc.NameResolverProvider; +import io.grpc.Uri; import java.net.URI; public class NameResolverFactoryToProviderFacade extends NameResolverProvider { @@ -34,6 +35,11 @@ public NameResolver newNameResolver(URI targetUri, Args args) { return factory.newNameResolver(targetUri, args); } + @Override + public NameResolver newNameResolver(Uri targetUri, Args args) { + return factory.newNameResolver(targetUri, args); + } + @Override public String getDefaultScheme() { return factory.getDefaultScheme(); diff --git a/core/src/main/java/io/grpc/internal/NoopClientStream.java b/core/src/main/java/io/grpc/internal/NoopClientStream.java index d44170f69fa..d77d72a5412 100644 --- a/core/src/main/java/io/grpc/internal/NoopClientStream.java +++ b/core/src/main/java/io/grpc/internal/NoopClientStream.java @@ -45,7 +45,9 @@ public Attributes getAttributes() { public void request(int numMessages) {} @Override - public void writeMessage(InputStream message) {} + public void writeMessage(InputStream message) { + GrpcUtil.closeQuietly(message); + } @Override public void flush() {} diff --git a/core/src/main/java/io/grpc/internal/NoopSslSession.java b/core/src/main/java/io/grpc/internal/NoopSslSession.java new file mode 100644 index 00000000000..9a79d281ad5 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/NoopSslSession.java @@ -0,0 +1,132 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import java.security.Principal; +import java.security.cert.Certificate; +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSessionContext; + +/** A no-op ssl session, to facilitate overriding only the required methods in specific + * implementations. + */ +public class NoopSslSession implements SSLSession { + @Override + public byte[] getId() { + return new byte[0]; + } + + @Override + public SSLSessionContext getSessionContext() { + return null; + } + + @Override + @SuppressWarnings("deprecation") + public javax.security.cert.X509Certificate[] getPeerCertificateChain() { + throw new UnsupportedOperationException("This method is deprecated and marked for removal. " + + "Use the getPeerCertificates() method instead."); + } + + @Override + public long getCreationTime() { + return 0; + } + + @Override + public long getLastAccessedTime() { + return 0; + } + + @Override + public void invalidate() { + } + + @Override + public boolean isValid() { + return false; + } + + @Override + public void putValue(String s, Object o) { + } + + @Override + public Object getValue(String s) { + return null; + } + + @Override + public void removeValue(String s) { + } + + @Override + public String[] getValueNames() { + return new String[0]; + } + + @Override + public Certificate[] getPeerCertificates() throws SSLPeerUnverifiedException { + return new Certificate[0]; + } + + @Override + public Certificate[] getLocalCertificates() { + return new Certificate[0]; + } + + @Override + public Principal getPeerPrincipal() throws SSLPeerUnverifiedException { + return null; + } + + @Override + public Principal getLocalPrincipal() { + return null; + } + + @Override + public String getCipherSuite() { + return null; + } + + @Override + public String getProtocol() { + return null; + } + + @Override + public String getPeerHost() { + return null; + } + + @Override + public int getPeerPort() { + return 0; + } + + @Override + public int getPacketBufferSize() { + return 0; + } + + @Override + public int getApplicationBufferSize() { + return 0; + } +} diff --git a/core/src/main/java/io/grpc/internal/ObjectPool.java b/core/src/main/java/io/grpc/internal/ObjectPool.java index 13547bc274a..5589cbbdf3c 100644 --- a/core/src/main/java/io/grpc/internal/ObjectPool.java +++ b/core/src/main/java/io/grpc/internal/ObjectPool.java @@ -16,12 +16,11 @@ package io.grpc.internal; -import javax.annotation.concurrent.ThreadSafe; - /** * An object pool. + * + *

This interface is thread-safe. */ -@ThreadSafe public interface ObjectPool { /** * Get an object from the pool. diff --git a/core/src/main/java/io/grpc/internal/OobChannel.java b/core/src/main/java/io/grpc/internal/OobChannel.java deleted file mode 100644 index 01ef457460f..00000000000 --- a/core/src/main/java/io/grpc/internal/OobChannel.java +++ /dev/null @@ -1,344 +0,0 @@ -/* - * Copyright 2016 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.internal; - -import static com.google.common.base.Preconditions.checkNotNull; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.MoreObjects; -import com.google.common.base.Preconditions; -import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.SettableFuture; -import io.grpc.Attributes; -import io.grpc.CallOptions; -import io.grpc.ClientCall; -import io.grpc.ClientStreamTracer; -import io.grpc.ConnectivityState; -import io.grpc.ConnectivityStateInfo; -import io.grpc.Context; -import io.grpc.EquivalentAddressGroup; -import io.grpc.InternalChannelz; -import io.grpc.InternalChannelz.ChannelStats; -import io.grpc.InternalChannelz.ChannelTrace; -import io.grpc.InternalInstrumented; -import io.grpc.InternalLogId; -import io.grpc.InternalWithLogId; -import io.grpc.LoadBalancer; -import io.grpc.LoadBalancer.PickResult; -import io.grpc.LoadBalancer.PickSubchannelArgs; -import io.grpc.LoadBalancer.Subchannel; -import io.grpc.LoadBalancer.SubchannelPicker; -import io.grpc.ManagedChannel; -import io.grpc.Metadata; -import io.grpc.MethodDescriptor; -import io.grpc.Status; -import io.grpc.SynchronizationContext; -import io.grpc.internal.ClientCallImpl.ClientStreamProvider; -import java.util.Collections; -import java.util.List; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.Executor; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.TimeUnit; -import java.util.logging.Level; -import java.util.logging.Logger; -import javax.annotation.concurrent.ThreadSafe; - -/** - * A ManagedChannel backed by a single {@link InternalSubchannel} and used for {@link LoadBalancer} - * to its own RPC needs. - */ -@ThreadSafe -final class OobChannel extends ManagedChannel implements InternalInstrumented { - private static final Logger log = Logger.getLogger(OobChannel.class.getName()); - - private InternalSubchannel subchannel; - private AbstractSubchannel subchannelImpl; - private SubchannelPicker subchannelPicker; - - private final InternalLogId logId; - private final String authority; - private final DelayedClientTransport delayedTransport; - private final InternalChannelz channelz; - private final ObjectPool executorPool; - private final Executor executor; - private final ScheduledExecutorService deadlineCancellationExecutor; - private final CountDownLatch terminatedLatch = new CountDownLatch(1); - private volatile boolean shutdown; - private final CallTracer channelCallsTracer; - private final ChannelTracer channelTracer; - private final TimeProvider timeProvider; - - private final ClientStreamProvider transportProvider = new ClientStreamProvider() { - @Override - public ClientStream newStream(MethodDescriptor method, - CallOptions callOptions, Metadata headers, Context context) { - ClientStreamTracer[] tracers = GrpcUtil.getClientStreamTracers( - callOptions, headers, 0, /* isTransparentRetry= */ false); - Context origContext = context.attach(); - // delayed transport's newStream() always acquires a lock, but concurrent performance doesn't - // matter here because OOB communication should be sparse, and it's not on application RPC's - // critical path. - try { - return delayedTransport.newStream(method, headers, callOptions, tracers); - } finally { - context.detach(origContext); - } - } - }; - - OobChannel( - String authority, ObjectPool executorPool, - ScheduledExecutorService deadlineCancellationExecutor, SynchronizationContext syncContext, - CallTracer callsTracer, ChannelTracer channelTracer, InternalChannelz channelz, - TimeProvider timeProvider) { - this.authority = checkNotNull(authority, "authority"); - this.logId = InternalLogId.allocate(getClass(), authority); - this.executorPool = checkNotNull(executorPool, "executorPool"); - this.executor = checkNotNull(executorPool.getObject(), "executor"); - this.deadlineCancellationExecutor = checkNotNull( - deadlineCancellationExecutor, "deadlineCancellationExecutor"); - this.delayedTransport = new DelayedClientTransport(executor, syncContext); - this.channelz = Preconditions.checkNotNull(channelz); - this.delayedTransport.start(new ManagedClientTransport.Listener() { - @Override - public void transportShutdown(Status s) { - // Don't care - } - - @Override - public void transportTerminated() { - subchannelImpl.shutdown(); - } - - @Override - public void transportReady() { - // Don't care - } - - @Override - public Attributes filterTransport(Attributes attributes) { - return attributes; - } - - @Override - public void transportInUse(boolean inUse) { - // Don't care - } - }); - this.channelCallsTracer = callsTracer; - this.channelTracer = checkNotNull(channelTracer, "channelTracer"); - this.timeProvider = checkNotNull(timeProvider, "timeProvider"); - } - - // Must be called only once, right after the OobChannel is created. - void setSubchannel(final InternalSubchannel subchannel) { - log.log(Level.FINE, "[{0}] Created with [{1}]", new Object[] {this, subchannel}); - this.subchannel = subchannel; - subchannelImpl = new AbstractSubchannel() { - @Override - public void shutdown() { - subchannel.shutdown(Status.UNAVAILABLE.withDescription("OobChannel is shutdown")); - } - - @Override - InternalInstrumented getInstrumentedInternalSubchannel() { - return subchannel; - } - - @Override - public void requestConnection() { - subchannel.obtainActiveTransport(); - } - - @Override - public List getAllAddresses() { - return subchannel.getAddressGroups(); - } - - @Override - public Attributes getAttributes() { - return Attributes.EMPTY; - } - - @Override - public Object getInternalSubchannel() { - return subchannel; - } - }; - - final class OobSubchannelPicker extends SubchannelPicker { - final PickResult result = PickResult.withSubchannel(subchannelImpl); - - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return result; - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(OobSubchannelPicker.class) - .add("result", result) - .toString(); - } - } - - subchannelPicker = new OobSubchannelPicker(); - delayedTransport.reprocess(subchannelPicker); - } - - void updateAddresses(List eag) { - subchannel.updateAddresses(eag); - } - - @Override - public ClientCall newCall( - MethodDescriptor methodDescriptor, CallOptions callOptions) { - return new ClientCallImpl<>(methodDescriptor, - callOptions.getExecutor() == null ? executor : callOptions.getExecutor(), - callOptions, transportProvider, deadlineCancellationExecutor, channelCallsTracer, null); - } - - @Override - public String authority() { - return authority; - } - - @Override - public boolean isTerminated() { - return terminatedLatch.getCount() == 0; - } - - @Override - public boolean awaitTermination(long time, TimeUnit unit) throws InterruptedException { - return terminatedLatch.await(time, unit); - } - - @Override - public ConnectivityState getState(boolean requestConnectionIgnored) { - if (subchannel == null) { - return ConnectivityState.IDLE; - } - return subchannel.getState(); - } - - @Override - public ManagedChannel shutdown() { - shutdown = true; - delayedTransport.shutdown(Status.UNAVAILABLE.withDescription("OobChannel.shutdown() called")); - return this; - } - - @Override - public boolean isShutdown() { - return shutdown; - } - - @Override - public ManagedChannel shutdownNow() { - shutdown = true; - delayedTransport.shutdownNow( - Status.UNAVAILABLE.withDescription("OobChannel.shutdownNow() called")); - return this; - } - - void handleSubchannelStateChange(final ConnectivityStateInfo newState) { - channelTracer.reportEvent( - new ChannelTrace.Event.Builder() - .setDescription("Entering " + newState.getState() + " state") - .setSeverity(ChannelTrace.Event.Severity.CT_INFO) - .setTimestampNanos(timeProvider.currentTimeNanos()) - .build()); - switch (newState.getState()) { - case READY: - case IDLE: - delayedTransport.reprocess(subchannelPicker); - break; - case TRANSIENT_FAILURE: - final class OobErrorPicker extends SubchannelPicker { - final PickResult errorResult = PickResult.withError(newState.getStatus()); - - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return errorResult; - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(OobErrorPicker.class) - .add("errorResult", errorResult) - .toString(); - } - } - - delayedTransport.reprocess(new OobErrorPicker()); - break; - default: - // Do nothing - } - } - - // must be run from channel executor - void handleSubchannelTerminated() { - channelz.removeSubchannel(this); - // When delayedTransport is terminated, it shuts down subchannel. Therefore, at this point - // both delayedTransport and subchannel have terminated. - executorPool.returnObject(executor); - terminatedLatch.countDown(); - } - - @VisibleForTesting - Subchannel getSubchannel() { - return subchannelImpl; - } - - InternalSubchannel getInternalSubchannel() { - return subchannel; - } - - @Override - public ListenableFuture getStats() { - final SettableFuture ret = SettableFuture.create(); - final ChannelStats.Builder builder = new ChannelStats.Builder(); - channelCallsTracer.updateBuilder(builder); - channelTracer.updateBuilder(builder); - builder - .setTarget(authority) - .setState(subchannel.getState()) - .setSubchannels(Collections.singletonList(subchannel)); - ret.set(builder.build()); - return ret; - } - - @Override - public InternalLogId getLogId() { - return logId; - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(this) - .add("logId", logId.getId()) - .add("authority", authority) - .toString(); - } - - @Override - public void resetConnectBackoff() { - subchannel.resetConnectBackoff(); - } -} diff --git a/core/src/main/java/io/grpc/internal/OobNameResolverProvider.java b/core/src/main/java/io/grpc/internal/OobNameResolverProvider.java new file mode 100644 index 00000000000..408b92e0c84 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/OobNameResolverProvider.java @@ -0,0 +1,121 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static java.util.Objects.requireNonNull; + +import io.grpc.EquivalentAddressGroup; +import io.grpc.NameResolver; +import io.grpc.NameResolverProvider; +import io.grpc.StatusOr; +import io.grpc.SynchronizationContext; +import java.net.URI; +import java.util.Collection; +import java.util.LinkedList; +import java.util.List; + +/** + * A provider that is passed addresses and relays those addresses to its created resolvers. + */ +final class OobNameResolverProvider extends NameResolverProvider { + private final String authority; + private final SynchronizationContext parentSyncContext; + // Only accessed from parentSyncContext + @SuppressWarnings("JdkObsolete") // LinkedList uses O(n) memory, including after deletions + private final Collection resolvers = new LinkedList<>(); + // Only accessed from parentSyncContext + private List lastEags; + + public OobNameResolverProvider( + String authority, List eags, SynchronizationContext syncContext) { + this.authority = requireNonNull(authority, "authority"); + this.lastEags = requireNonNull(eags, "eags"); + this.parentSyncContext = requireNonNull(syncContext, "syncContext"); + } + + @Override + public String getDefaultScheme() { + return "oob"; + } + + @Override + protected boolean isAvailable() { + return true; + } + + @Override + protected int priority() { + return 5; // Doesn't matter, as we expect only one provider in the registry + } + + public void updateAddresses(List eags) { + requireNonNull(eags, "eags"); + parentSyncContext.execute(() -> { + this.lastEags = eags; + for (OobNameResolver resolver : resolvers) { + resolver.updateAddresses(eags); + } + }); + } + + @Override + public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) { + return new OobNameResolver(args.getSynchronizationContext()); + } + + final class OobNameResolver extends NameResolver { + private final SynchronizationContext syncContext; + // Null before started, and after shutdown. Only accessed from syncContext + private Listener2 listener; + + public OobNameResolver(SynchronizationContext syncContext) { + this.syncContext = requireNonNull(syncContext, "syncContext"); + } + + @Override + public String getServiceAuthority() { + return authority; + } + + @Override + public void start(Listener2 listener) { + this.listener = requireNonNull(listener, "listener"); + parentSyncContext.execute(() -> { + resolvers.add(this); + updateAddresses(lastEags); + }); + } + + void updateAddresses(List eags) { + parentSyncContext.throwIfNotInThisSynchronizationContext(); + syncContext.execute(() -> { + if (listener == null) { + return; + } + listener.onResult2(ResolutionResult.newBuilder() + .setAddressesOrError(StatusOr.fromValue(lastEags)) + .build()); + }); + } + + @Override + public void shutdown() { + this.listener = null; + parentSyncContext.execute(() -> resolvers.remove(this)); + } + } +} diff --git a/core/src/main/java/io/grpc/internal/PickDetailsConsumerImpl.java b/core/src/main/java/io/grpc/internal/PickDetailsConsumerImpl.java new file mode 100644 index 00000000000..5c69757afbf --- /dev/null +++ b/core/src/main/java/io/grpc/internal/PickDetailsConsumerImpl.java @@ -0,0 +1,42 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import com.google.common.base.Preconditions; +import io.grpc.ClientStreamTracer; +import io.grpc.LoadBalancer.PickDetailsConsumer; + +/** + * Adapter for tracers into details consumers. + */ +final class PickDetailsConsumerImpl implements PickDetailsConsumer { + private final ClientStreamTracer[] tracers; + + /** Construct a consumer with unchanging tracers array. */ + public PickDetailsConsumerImpl(ClientStreamTracer[] tracers) { + this.tracers = Preconditions.checkNotNull(tracers, "tracers"); + } + + @Override + public void addOptionalLabel(String key, String value) { + Preconditions.checkNotNull(key, "key"); + Preconditions.checkNotNull(value, "value"); + for (ClientStreamTracer tracer : tracers) { + tracer.addOptionalLabel(key, value); + } + } +} diff --git a/core/src/main/java/io/grpc/internal/PickFirstLeafLoadBalancer.java b/core/src/main/java/io/grpc/internal/PickFirstLeafLoadBalancer.java index 3d6fadeffd1..ab60a024e1f 100644 --- a/core/src/main/java/io/grpc/internal/PickFirstLeafLoadBalancer.java +++ b/core/src/main/java/io/grpc/internal/PickFirstLeafLoadBalancer.java @@ -24,17 +24,19 @@ import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.MoreObjects; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.Attributes; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; -import io.grpc.ExperimentalApi; +import io.grpc.InternalEquivalentAddressGroup; import io.grpc.LoadBalancer; import io.grpc.Status; import io.grpc.SynchronizationContext.ScheduledHandle; +import java.net.Inet4Address; +import java.net.InetSocketAddress; import java.net.SocketAddress; import java.util.ArrayList; import java.util.Collections; @@ -55,35 +57,48 @@ * io.grpc.NameResolver}. The channel's default behavior is used, which is walking down the address * list and sticking to the first that works. */ -@ExperimentalApi("https://github.com/grpc/grpc-java/issues/10383") final class PickFirstLeafLoadBalancer extends LoadBalancer { private static final Logger log = Logger.getLogger(PickFirstLeafLoadBalancer.class.getName()); @VisibleForTesting static final int CONNECTION_DELAY_INTERVAL_MS = 250; - public static final String GRPC_EXPERIMENTAL_XDS_DUALSTACK_ENDPOINTS = - "GRPC_EXPERIMENTAL_XDS_DUALSTACK_ENDPOINTS"; + private final boolean enableHappyEyeballs = !isSerializingRetries() + && PickFirstLoadBalancerProvider.isEnabledHappyEyeballs(); + static boolean weightedShuffling = + GrpcUtil.getFlag("GRPC_EXPERIMENTAL_PF_WEIGHTED_SHUFFLING", true); private final Helper helper; private final Map subchannels = new HashMap<>(); - private Index addressIndex; + private final Index addressIndex = new Index(ImmutableList.of(), this.enableHappyEyeballs); private int numTf = 0; private boolean firstPass = true; @Nullable - private ScheduledHandle scheduleConnectionTask; + private ScheduledHandle scheduleConnectionTask = null; private ConnectivityState rawConnectivityState = IDLE; private ConnectivityState concludedState = IDLE; - private final boolean enableHappyEyeballs = - GrpcUtil.getFlag(GRPC_EXPERIMENTAL_XDS_DUALSTACK_ENDPOINTS, false); + private boolean notAPetiolePolicy = true; // means not under a petiole policy + private final BackoffPolicy.Provider bkoffPolProvider = new ExponentialBackoffPolicy.Provider(); + private BackoffPolicy reconnectPolicy; + @Nullable + private ScheduledHandle reconnectTask = null; + private final boolean serializingRetries = isSerializingRetries(); PickFirstLeafLoadBalancer(Helper helper) { this.helper = checkNotNull(helper, "helper"); } + static boolean isSerializingRetries() { + return GrpcUtil.getFlag("GRPC_SERIALIZE_RETRIES", false); + } + @Override public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { if (rawConnectivityState == SHUTDOWN) { return Status.FAILED_PRECONDITION.withDescription("Already shut down"); } + // Check whether this is a petiole policy, which is based off of an address attribute + Boolean isPetiolePolicy = resolvedAddresses.getAttributes().get(IS_PETIOLE_POLICY); + this.notAPetiolePolicy = isPetiolePolicy == null || !isPetiolePolicy; + List servers = resolvedAddresses.getAddresses(); // Validate the address list @@ -108,6 +123,8 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { // Since we have a new set of addresses, we are again at first pass firstPass = true; + List cleanServers = deDupAddresses(servers); + // We can optionally be configured to shuffle the address list. This can help better distribute // the load. if (resolvedAddresses.getLoadBalancingPolicyConfig() @@ -115,33 +132,66 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { PickFirstLeafLoadBalancerConfig config = (PickFirstLeafLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); if (config.shuffleAddressList != null && config.shuffleAddressList) { - servers = new ArrayList<>(servers); - Collections.shuffle(servers, - config.randomSeed != null ? new Random(config.randomSeed) : new Random()); + cleanServers = shuffle( + cleanServers, config.randomSeed != null ? new Random(config.randomSeed) : new Random()); } } - // Make sure we're storing our own list rather than what was passed in final ImmutableList newImmutableAddressGroups = - ImmutableList.builder().addAll(servers).build(); - - if (addressIndex == null) { - addressIndex = new Index(newImmutableAddressGroups); - } else if (rawConnectivityState == READY) { - // If the previous ready subchannel exists in new address list, - // keep this connection and don't create new subchannels + ImmutableList.copyOf(cleanServers); + + if (rawConnectivityState == READY + || (rawConnectivityState == CONNECTING + && (!enableHappyEyeballs || addressIndex.isValid()))) { + // If the previous ready (or connecting) subchannel exists in new address list, + // keep this connection and don't create new subchannels. Happy Eyeballs is excluded when + // connecting, because it allows multiple attempts simultaneously, thus is fine to start at + // the beginning. SocketAddress previousAddress = addressIndex.getCurrentAddress(); addressIndex.updateGroups(newImmutableAddressGroups); if (addressIndex.seekTo(previousAddress)) { + SubchannelData subchannelData = subchannels.get(previousAddress); + subchannelData.getSubchannel().updateAddresses(addressIndex.getCurrentEagAsList()); + shutdownRemovedAddresses(newImmutableAddressGroups); return Status.OK; - } else { - addressIndex.reset(); // Previous ready subchannel not in the new list of addresses } + // Previous ready subchannel not in the new list of addresses } else { addressIndex.updateGroups(newImmutableAddressGroups); } - // remove old subchannels that were not in new address list + // No old addresses means first time through, so we will do an explicit move to CONNECTING + // which is what we implicitly started with + boolean noOldAddrs = shutdownRemovedAddresses(newImmutableAddressGroups); + + if (noOldAddrs) { + // Make tests happy; they don't properly assume starting in CONNECTING + rawConnectivityState = CONNECTING; + updateBalancingState(CONNECTING, new FixedResultPicker(PickResult.withNoResult())); + } + + if (rawConnectivityState == READY) { + // connect from beginning when prompted + rawConnectivityState = IDLE; + updateBalancingState(IDLE, new RequestConnectionPicker(this)); + + } else if (rawConnectivityState == CONNECTING || rawConnectivityState == TRANSIENT_FAILURE) { + // start connection attempt at first address + cancelScheduleTask(); + requestConnection(); + } + + return Status.OK; + } + + /** + * Compute the difference between the flattened new addresses and the old addresses that had been + * made into subchannels and then shutdown the matching subchannels. + * @return true if there were no old addresses + */ + private boolean shutdownRemovedAddresses( + ImmutableList newImmutableAddressGroups) { + Set oldAddrs = new HashSet<>(subchannels.keySet()); // Flatten the new EAGs addresses @@ -156,54 +206,101 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { subchannels.remove(oldAddr).getSubchannel().shutdown(); } } + return oldAddrs.isEmpty(); + } - if (oldAddrs.size() == 0 || rawConnectivityState == CONNECTING - || rawConnectivityState == READY) { - // start connection attempt at first address - rawConnectivityState = CONNECTING; - updateBalancingState(CONNECTING, new Picker(PickResult.withNoResult())); - cancelScheduleTask(); - requestConnection(); + private static List deDupAddresses(List groups) { + Set seenAddresses = new HashSet<>(); + List newGroups = new ArrayList<>(); + + for (EquivalentAddressGroup group : groups) { + List addrs = new ArrayList<>(); + for (SocketAddress addr : group.getAddresses()) { + if (seenAddresses.add(addr)) { + addrs.add(addr); + } + } + if (!addrs.isEmpty()) { + newGroups.add(new EquivalentAddressGroup(addrs, group.getAttributes())); + } + } - } else if (rawConnectivityState == IDLE) { - // start connection attempt at first address when requested - SubchannelPicker picker = new RequestConnectionPicker(this); - updateBalancingState(IDLE, picker); + return newGroups; + } - } else if (rawConnectivityState == TRANSIENT_FAILURE) { - // start connection attempt at first address - cancelScheduleTask(); - requestConnection(); + // Also used by PickFirstLoadBalancer + @CheckReturnValue + static List shuffle(List eags, Random random) { + if (weightedShuffling) { + List weightedEntries = new ArrayList<>(eags.size()); + for (EquivalentAddressGroup eag : eags) { + weightedEntries.add(new WeightEntry(eag, eagToWeight(eag, random))); + } + Collections.sort(weightedEntries, Collections.reverseOrder() /* descending */); + return Lists.transform(weightedEntries, entry -> entry.eag); + } else { + List eagsCopy = new ArrayList<>(eags); + Collections.shuffle(eagsCopy, random); + return eagsCopy; } + } - return Status.OK; + private static double eagToWeight(EquivalentAddressGroup eag, Random random) { + Long weight = eag.getAttributes().get(InternalEquivalentAddressGroup.ATTR_WEIGHT); + if (weight == null) { + weight = 1L; + } + return Math.pow(random.nextDouble(), 1.0 / weight); + } + + private static final class WeightEntry implements Comparable { + final EquivalentAddressGroup eag; + final double weight; + + public WeightEntry(EquivalentAddressGroup eag, double weight) { + this.eag = eag; + this.weight = weight; + } + + @Override + public int compareTo(WeightEntry entry) { + return Double.compare(this.weight, entry.weight); + } } @Override public void handleNameResolutionError(Status error) { + if (rawConnectivityState == SHUTDOWN) { + return; + } + for (SubchannelData subchannelData : subchannels.values()) { subchannelData.getSubchannel().shutdown(); } subchannels.clear(); - updateBalancingState(TRANSIENT_FAILURE, new Picker(PickResult.withError(error))); + addressIndex.updateGroups(ImmutableList.of()); + rawConnectivityState = TRANSIENT_FAILURE; + updateBalancingState(TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(error))); } - void processSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) { + void processSubchannelState(SubchannelData subchannelData, ConnectivityStateInfo stateInfo) { ConnectivityState newState = stateInfo.getState(); + // Shutdown channels/previously relevant subchannels can still callback with state updates. // To prevent pickers from returning these obsolete subchannels, this logic // is included to check if the current list of active subchannels includes this subchannel. - SubchannelData subchannelData = subchannels.get(getAddress(subchannel)); - if (subchannelData == null || subchannelData.getSubchannel() != subchannel) { + if (subchannelData != subchannels.get(getAddress(subchannelData.subchannel))) { return; } + if (newState == SHUTDOWN) { return; } - if (newState == IDLE) { + if (newState == IDLE && subchannelData.state == READY) { helper.refreshNameResolution(); } + // If we are transitioning from a TRANSIENT_FAILURE to CONNECTING or IDLE we ignore this state // transition and still keep the LB in TRANSIENT_FAILURE state. This is referred to as "sticky // transient failure". Only a subchannel state change to READY will get the LB out of @@ -236,12 +333,22 @@ void processSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateIn case CONNECTING: rawConnectivityState = CONNECTING; - updateBalancingState(CONNECTING, new Picker(PickResult.withNoResult())); + // If we get a newly resolved address list via acceptResolvedAddresses, + // as we are in CONNECTING, we will try to .updateAddresses the currently + // connecting subchannel if it exists in the new list. + // As such, We need to make sure that with transitioning to CONNECTING the subchannel for + // the current address of a valid index exists. + if ((!enableHappyEyeballs && !addressIndex.isValid()) + || (addressIndex.isValid() && !subchannels.containsKey( + addressIndex.getCurrentAddress()))) { + addressIndex.seekTo(getAddress(subchannelData.subchannel)); + } + updateBalancingState(CONNECTING, new FixedResultPicker(PickResult.withNoResult())); break; case READY: shutdownRemaining(subchannelData); - addressIndex.seekTo(getAddress(subchannel)); + addressIndex.seekTo(getAddress(subchannelData.subchannel)); rawConnectivityState = READY; updateHealthCheckedState(subchannelData); break; @@ -249,17 +356,26 @@ void processSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateIn case TRANSIENT_FAILURE: // If we are looking at current channel, request a connection if possible if (addressIndex.isValid() - && subchannels.get(addressIndex.getCurrentAddress()).getSubchannel() == subchannel) { + && subchannels.get(addressIndex.getCurrentAddress()) == subchannelData) { if (addressIndex.increment()) { cancelScheduleTask(); requestConnection(); // is recursive so might hit the end of the addresses + } else { + if (subchannels.size() >= addressIndex.size()) { + scheduleBackoff(); + } else { + // We must have done a seek to the middle of the list lets start over from the + // beginning + addressIndex.reset(); + requestConnection(); + } } } if (isPassComplete()) { rawConnectivityState = TRANSIENT_FAILURE; updateBalancingState(TRANSIENT_FAILURE, - new Picker(PickResult.withError(stateInfo.getStatus()))); + new FixedResultPicker(PickResult.withError(stateInfo.getStatus()))); // Refresh Name Resolution, but only when all 3 conditions are met // * We are at the end of addressIndex @@ -280,19 +396,53 @@ void processSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateIn } } + /** + * Only called after all addresses attempted and failed (TRANSIENT_FAILURE). + */ + private void scheduleBackoff() { + if (!serializingRetries) { + return; + } + + class EndOfCurrentBackoff implements Runnable { + @Override + public void run() { + reconnectTask = null; + addressIndex.reset(); + requestConnection(); + } + } + + // Just allow the previous one to trigger when ready if we're already in backoff + if (reconnectTask != null) { + return; + } + + if (reconnectPolicy == null) { + reconnectPolicy = bkoffPolProvider.get(); + } + long delayNanos = reconnectPolicy.nextBackoffNanos(); + reconnectTask = helper.getSynchronizationContext().schedule( + new EndOfCurrentBackoff(), + delayNanos, + TimeUnit.NANOSECONDS, + helper.getScheduledExecutorService()); + } + private void updateHealthCheckedState(SubchannelData subchannelData) { if (subchannelData.state != READY) { return; } - if (subchannelData.getHealthState() == READY) { + + if (notAPetiolePolicy || subchannelData.getHealthState() == READY) { updateBalancingState(READY, new FixedResultPicker(PickResult.withSubchannel(subchannelData.subchannel))); } else if (subchannelData.getHealthState() == TRANSIENT_FAILURE) { - updateBalancingState(TRANSIENT_FAILURE, new Picker(PickResult.withError( - subchannelData.healthListener.healthStateInfo.getStatus()))); + updateBalancingState(TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError( + subchannelData.healthStateInfo.getStatus()))); } else if (concludedState != TRANSIENT_FAILURE) { updateBalancingState(subchannelData.getHealthState(), - new Picker(PickResult.withNoResult())); + new FixedResultPicker(PickResult.withNoResult())); } } @@ -312,6 +462,11 @@ public void shutdown() { rawConnectivityState = SHUTDOWN; concludedState = SHUTDOWN; cancelScheduleTask(); + if (reconnectTask != null) { + reconnectTask.cancel(); + reconnectTask = null; + } + reconnectPolicy = null; for (SubchannelData subchannelData : subchannels.values()) { subchannelData.getSubchannel().shutdown(); @@ -325,6 +480,12 @@ public void shutdown() { * that all other subchannels must be shutdown. */ private void shutdownRemaining(SubchannelData activeSubchannelData) { + if (reconnectTask != null) { + reconnectTask.cancel(); + reconnectTask = null; + } + reconnectPolicy = null; + cancelScheduleTask(); for (SubchannelData subchannelData : subchannels.values()) { if (!subchannelData.getSubchannel().equals(activeSubchannelData.subchannel)) { @@ -345,41 +506,41 @@ private void shutdownRemaining(SubchannelData activeSubchannelData) { */ @Override public void requestConnection() { - if (addressIndex == null || !addressIndex.isValid() || rawConnectivityState == SHUTDOWN ) { + if (!addressIndex.isValid() || rawConnectivityState == SHUTDOWN || reconnectTask != null) { return; } - Subchannel subchannel; - SocketAddress currentAddress; - currentAddress = addressIndex.getCurrentAddress(); - subchannel = subchannels.containsKey(currentAddress) - ? subchannels.get(currentAddress).getSubchannel() - : createNewSubchannel(currentAddress); + SocketAddress currentAddress = addressIndex.getCurrentAddress(); + SubchannelData subchannelData = subchannels.get(currentAddress); + if (subchannelData == null) { + subchannelData = createNewSubchannel(currentAddress, addressIndex.getCurrentEagAttributes()); + } - ConnectivityState subchannelState = subchannels.get(currentAddress).getState(); + ConnectivityState subchannelState = subchannelData.getState(); switch (subchannelState) { case IDLE: - subchannel.requestConnection(); - subchannels.get(currentAddress).updateState(CONNECTING); + subchannelData.subchannel.requestConnection(); + subchannelData.updateState(CONNECTING); scheduleNextConnection(); break; case CONNECTING: - if (enableHappyEyeballs) { - scheduleNextConnection(); - } else { - subchannel.requestConnection(); - } + scheduleNextConnection(); break; case TRANSIENT_FAILURE: - addressIndex.increment(); - requestConnection(); - break; - case READY: // Shouldn't ever happen - log.warning("Requesting a connection even though we have a READY subchannel"); + if (!serializingRetries) { + addressIndex.increment(); + requestConnection(); + } else { + if (!addressIndex.isValid()) { + scheduleBackoff(); + } else { + subchannelData.subchannel.requestConnection(); + subchannelData.updateState(CONNECTING); + } + } break; - case SHUTDOWN: default: - // Makes checkstyle happy + // Wait for current subchannel to change state } } @@ -418,32 +579,32 @@ private void cancelScheduleTask() { } } - private Subchannel createNewSubchannel(SocketAddress addr) { + private SubchannelData createNewSubchannel(SocketAddress addr, Attributes attrs) { HealthListener hcListener = new HealthListener(); final Subchannel subchannel = helper.createSubchannel( CreateSubchannelArgs.newBuilder() - .setAddresses(Lists.newArrayList( - new EquivalentAddressGroup(addr))) - .addOption(HEALTH_CONSUMER_LISTENER_ARG_KEY, hcListener) + .setAddresses(Lists.newArrayList( + new EquivalentAddressGroup(addr, attrs))) + .addOption(HEALTH_CONSUMER_LISTENER_ARG_KEY, hcListener) + .addOption(LoadBalancer.DISABLE_SUBCHANNEL_RECONNECT_KEY, serializingRetries) .build()); if (subchannel == null) { log.warning("Was not able to create subchannel for " + addr); throw new IllegalStateException("Can't create subchannel"); } - SubchannelData subchannelData = new SubchannelData(subchannel, IDLE, hcListener); + SubchannelData subchannelData = new SubchannelData(subchannel, IDLE); hcListener.subchannelData = subchannelData; subchannels.put(addr, subchannelData); - Attributes attrs = subchannel.getAttributes(); - if (attrs.get(LoadBalancer.HAS_HEALTH_PRODUCER_LISTENER_KEY) == null) { - hcListener.healthStateInfo = ConnectivityStateInfo.forNonError(READY); + Attributes scAttrs = subchannel.getAttributes(); + if (notAPetiolePolicy || scAttrs.get(LoadBalancer.HAS_HEALTH_PRODUCER_LISTENER_KEY) == null) { + subchannelData.healthStateInfo = ConnectivityStateInfo.forNonError(READY); } - subchannel.start(stateInfo -> processSubchannelState(subchannel, stateInfo)); - return subchannel; + subchannel.start(stateInfo -> processSubchannelState(subchannelData, stateInfo)); + return subchannelData; } private boolean isPassComplete() { - if (addressIndex == null || addressIndex.isValid() - || subchannels.size() < addressIndex.size()) { + if (subchannels.size() < addressIndex.size()) { return false; } for (SubchannelData sc : subchannels.values()) { @@ -455,16 +616,22 @@ private boolean isPassComplete() { } private final class HealthListener implements SubchannelStateListener { - private ConnectivityStateInfo healthStateInfo = ConnectivityStateInfo.forNonError(IDLE); private SubchannelData subchannelData; @Override public void onSubchannelState(ConnectivityStateInfo newState) { + if (notAPetiolePolicy) { + log.log(Level.WARNING, + "Ignoring health status {0} for subchannel {1} as this is not under a petiole policy", + new Object[]{newState, subchannelData.subchannel}); + return; + } + log.log(Level.FINE, "Received health status {0} for subchannel {1}", new Object[]{newState, subchannelData.subchannel}); - healthStateInfo = newState; + subchannelData.healthStateInfo = newState; if (addressIndex.isValid() - && subchannels.get(addressIndex.getCurrentAddress()).healthListener == this) { + && subchannelData == subchannels.get(addressIndex.getCurrentAddress())) { updateHealthCheckedState(subchannelData); } } @@ -479,26 +646,9 @@ ConnectivityState getConcludedConnectivityState() { return this.concludedState; } - /** - * No-op picker which doesn't add any custom picking logic. It just passes already known result - * received in constructor. - */ - private static final class Picker extends SubchannelPicker { - private final PickResult result; - - Picker(PickResult result) { - this.result = checkNotNull(result, "result"); - } - - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return result; - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(Picker.class).add("result", result).toString(); - } + @VisibleForTesting + ConnectivityState getRawConnectivityState() { + return this.rawConnectivityState; } /** @@ -523,26 +673,26 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { } /** - * Index as in 'i', the pointer to an entry. Not a "search index." + * This contains both an ordered list of addresses and a pointer(i.e. index) to the current entry. * All updates should be done in a synchronization context. */ @VisibleForTesting static final class Index { - private List addressGroups; - private int groupIndex; - private int addressIndex; + private List orderedAddresses; + private int activeElement = 0; + private boolean enableHappyEyeballs; - public Index(List groups) { - this.addressGroups = groups != null ? groups : Collections.emptyList(); + Index(List groups, boolean enableHappyEyeballs) { + this.enableHappyEyeballs = enableHappyEyeballs; + updateGroups(groups); } public boolean isValid() { - // Is invalid if empty or has incremented off the end - return groupIndex < addressGroups.size(); + return activeElement < orderedAddresses.size(); } public boolean isAtBeginning() { - return groupIndex == 0 && addressIndex == 0; + return activeElement == 0; } /** @@ -554,41 +704,48 @@ public boolean increment() { return false; } - EquivalentAddressGroup group = addressGroups.get(groupIndex); - addressIndex++; - if (addressIndex >= group.getAddresses().size()) { - groupIndex++; - addressIndex = 0; - return groupIndex < addressGroups.size(); - } + activeElement++; - return true; + return isValid(); } public void reset() { - groupIndex = 0; - addressIndex = 0; + activeElement = 0; } public SocketAddress getCurrentAddress() { if (!isValid()) { throw new IllegalStateException("Index is past the end of the address group list"); } - return addressGroups.get(groupIndex).getAddresses().get(addressIndex); + return orderedAddresses.get(activeElement).address; } public Attributes getCurrentEagAttributes() { if (!isValid()) { throw new IllegalStateException("Index is off the end of the address group list"); } - return addressGroups.get(groupIndex).getAttributes(); + return orderedAddresses.get(activeElement).attributes; + } + + public List getCurrentEagAsList() { + return Collections.singletonList(getCurrentEag()); + } + + private EquivalentAddressGroup getCurrentEag() { + if (!isValid()) { + throw new IllegalStateException("Index is past the end of the address group list"); + } + return orderedAddresses.get(activeElement).asEag(); } /** * Update to new groups, resetting the current index. */ - public void updateGroups(ImmutableList newGroups) { - addressGroups = newGroups != null ? newGroups : Collections.emptyList(); + public void updateGroups(List newGroups) { + checkNotNull(newGroups, "newGroups"); + orderedAddresses = enableHappyEyeballs + ? updateGroupsHE(newGroups) + : updateGroupsNonHE(newGroups); reset(); } @@ -596,35 +753,117 @@ public void updateGroups(ImmutableList newGroups) { * Returns false if the needle was not found and the current index was left unchanged. */ public boolean seekTo(SocketAddress needle) { - for (int i = 0; i < addressGroups.size(); i++) { - EquivalentAddressGroup group = addressGroups.get(i); - int j = group.getAddresses().indexOf(needle); - if (j == -1) { - continue; + checkNotNull(needle, "needle"); + for (int i = 0; i < orderedAddresses.size(); i++) { + if (orderedAddresses.get(i).address.equals(needle)) { + this.activeElement = i; + return true; } - this.groupIndex = i; - this.addressIndex = j; - return true; } return false; } public int size() { - return (addressGroups != null) ? addressGroups.size() : 0; + return orderedAddresses.size(); + } + + private List updateGroupsNonHE(List newGroups) { + List entries = new ArrayList<>(); + for (int g = 0; g < newGroups.size(); g++) { + EquivalentAddressGroup eag = newGroups.get(g); + for (int a = 0; a < eag.getAddresses().size(); a++) { + SocketAddress addr = eag.getAddresses().get(a); + entries.add(new UnwrappedEag(eag.getAttributes(), addr)); + } + } + + return entries; + } + + private List updateGroupsHE(List newGroups) { + Boolean firstIsV6 = null; + List v4Entries = new ArrayList<>(); + List v6Entries = new ArrayList<>(); + for (int g = 0; g < newGroups.size(); g++) { + EquivalentAddressGroup eag = newGroups.get(g); + for (int a = 0; a < eag.getAddresses().size(); a++) { + SocketAddress addr = eag.getAddresses().get(a); + boolean isIpV4 = addr instanceof InetSocketAddress + && ((InetSocketAddress) addr).getAddress() instanceof Inet4Address; + if (isIpV4) { + if (firstIsV6 == null) { + firstIsV6 = false; + } + v4Entries.add(new UnwrappedEag(eag.getAttributes(), addr)); + } else { + if (firstIsV6 == null) { + firstIsV6 = true; + } + v6Entries.add(new UnwrappedEag(eag.getAttributes(), addr)); + } + } + } + + return firstIsV6 != null && firstIsV6 + ? interleave(v6Entries, v4Entries) + : interleave(v4Entries, v6Entries); + } + + private List interleave(List firstFamily, + List secondFamily) { + if (firstFamily.isEmpty()) { + return secondFamily; + } + if (secondFamily.isEmpty()) { + return firstFamily; + } + + List result = new ArrayList<>(firstFamily.size() + secondFamily.size()); + for (int i = 0; i < Math.max(firstFamily.size(), secondFamily.size()); i++) { + if (i < firstFamily.size()) { + result.add(firstFamily.get(i)); + } + if (i < secondFamily.size()) { + result.add(secondFamily.get(i)); + } + } + return result; + } + + private static final class UnwrappedEag { + private final Attributes attributes; + private final SocketAddress address; + + public UnwrappedEag(Attributes attributes, SocketAddress address) { + this.attributes = attributes; + this.address = address; + } + + private EquivalentAddressGroup asEag() { + return new EquivalentAddressGroup(address, attributes); + } } } + @VisibleForTesting + int getIndexLocation() { + return addressIndex.activeElement; + } + + @VisibleForTesting + boolean isIndexValid() { + return addressIndex.isValid(); + } + private static final class SubchannelData { private final Subchannel subchannel; private ConnectivityState state; - private final HealthListener healthListener; private boolean completedConnectivityAttempt = false; + private ConnectivityStateInfo healthStateInfo = ConnectivityStateInfo.forNonError(IDLE); - public SubchannelData(Subchannel subchannel, ConnectivityState state, - HealthListener subchannelHealthListener) { + public SubchannelData(Subchannel subchannel, ConnectivityState state) { this.subchannel = subchannel; this.state = state; - this.healthListener = subchannelHealthListener; } public Subchannel getSubchannel() { @@ -649,7 +888,7 @@ private void updateState(ConnectivityState newState) { } private ConnectivityState getHealthState() { - return healthListener.healthStateInfo.getState(); + return healthStateInfo.getState(); } } @@ -672,4 +911,5 @@ public PickFirstLeafLoadBalancerConfig(@Nullable Boolean shuffleAddressList) { this.randomSeed = randomSeed; } } + } diff --git a/core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java b/core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java index acef79d3d9f..cf4b4c94e04 100644 --- a/core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java +++ b/core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java @@ -22,14 +22,11 @@ import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; -import com.google.common.base.MoreObjects; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer; import io.grpc.Status; -import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.Random; import java.util.concurrent.atomic.AtomicBoolean; @@ -66,9 +63,8 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { PickFirstLoadBalancerConfig config = (PickFirstLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); if (config.shuffleAddressList != null && config.shuffleAddressList) { - servers = new ArrayList(servers); - Collections.shuffle(servers, - config.randomSeed != null ? new Random(config.randomSeed) : new Random()); + servers = PickFirstLeafLoadBalancer.shuffle( + servers, config.randomSeed != null ? new Random(config.randomSeed) : new Random()); } } @@ -87,7 +83,7 @@ public void onSubchannelState(ConnectivityStateInfo stateInfo) { // The channel state does not get updated when doing name resolving today, so for the moment // let LB report CONNECTION and call subchannel.requestConnection() immediately. - updateBalancingState(CONNECTING, new Picker(PickResult.withSubchannel(subchannel))); + updateBalancingState(CONNECTING, new FixedResultPicker(PickResult.withNoResult())); subchannel.requestConnection(); } else { subchannel.updateAddresses(servers); @@ -105,7 +101,7 @@ public void handleNameResolutionError(Status error) { // NB(lukaszx0) Whether we should propagate the error unconditionally is arguable. It's fine // for time being. - updateBalancingState(TRANSIENT_FAILURE, new Picker(PickResult.withError(error))); + updateBalancingState(TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(error))); } private void processSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) { @@ -134,18 +130,18 @@ private void processSubchannelState(Subchannel subchannel, ConnectivityStateInfo SubchannelPicker picker; switch (newState) { case IDLE: - picker = new RequestConnectionPicker(subchannel); + picker = new RequestConnectionPicker(); break; case CONNECTING: // It's safe to use RequestConnectionPicker here, so when coming from IDLE we could leave // the current picker in-place. But ignoring the potential optimization is simpler. - picker = new Picker(PickResult.withNoResult()); + picker = new FixedResultPicker(PickResult.withNoResult()); break; case READY: - picker = new Picker(PickResult.withSubchannel(subchannel)); + picker = new FixedResultPicker(PickResult.withSubchannel(subchannel)); break; case TRANSIENT_FAILURE: - picker = new Picker(PickResult.withError(stateInfo.getStatus())); + picker = new FixedResultPicker(PickResult.withError(stateInfo.getStatus())); break; default: throw new IllegalArgumentException("Unsupported state:" + newState); @@ -173,46 +169,14 @@ public void requestConnection() { } } - /** - * No-op picker which doesn't add any custom picking logic. It just passes already known result - * received in constructor. - */ - private static final class Picker extends SubchannelPicker { - private final PickResult result; - - Picker(PickResult result) { - this.result = checkNotNull(result, "result"); - } - - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return result; - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(Picker.class).add("result", result).toString(); - } - } - /** Picker that requests connection during the first pick, and returns noResult. */ private final class RequestConnectionPicker extends SubchannelPicker { - private final Subchannel subchannel; private final AtomicBoolean connectionRequested = new AtomicBoolean(false); - RequestConnectionPicker(Subchannel subchannel) { - this.subchannel = checkNotNull(subchannel, "subchannel"); - } - @Override public PickResult pickSubchannel(PickSubchannelArgs args) { if (connectionRequested.compareAndSet(false, true)) { - helper.getSynchronizationContext().execute(new Runnable() { - @Override - public void run() { - subchannel.requestConnection(); - } - }); + helper.getSynchronizationContext().execute(PickFirstLoadBalancer.this::requestConnection); } return PickResult.withNoResult(); } diff --git a/core/src/main/java/io/grpc/internal/PickFirstLoadBalancerProvider.java b/core/src/main/java/io/grpc/internal/PickFirstLoadBalancerProvider.java index 27f25e78e18..83b3fb7d8e6 100644 --- a/core/src/main/java/io/grpc/internal/PickFirstLoadBalancerProvider.java +++ b/core/src/main/java/io/grpc/internal/PickFirstLoadBalancerProvider.java @@ -16,12 +16,13 @@ package io.grpc.internal; -import com.google.common.base.Strings; +import com.google.common.annotations.VisibleForTesting; import io.grpc.LoadBalancer; import io.grpc.LoadBalancerProvider; import io.grpc.NameResolver; import io.grpc.NameResolver.ConfigOrError; import io.grpc.Status; +import io.grpc.internal.PickFirstLeafLoadBalancer.PickFirstLeafLoadBalancerConfig; import io.grpc.internal.PickFirstLoadBalancer.PickFirstLoadBalancerConfig; import java.util.Map; @@ -32,11 +33,16 @@ * down the address list and sticks to the first that works. */ public final class PickFirstLoadBalancerProvider extends LoadBalancerProvider { + public static final String GRPC_PF_USE_HAPPY_EYEBALLS = "GRPC_PF_USE_HAPPY_EYEBALLS"; private static final String SHUFFLE_ADDRESS_LIST_KEY = "shuffleAddressList"; static boolean enableNewPickFirst = - !Strings.isNullOrEmpty(System.getenv("GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST")) - && Boolean.parseBoolean(System.getenv("GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST")); + GrpcUtil.getFlag("GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST", false); + + public static boolean isEnabledHappyEyeballs() { + + return GrpcUtil.getFlag(GRPC_PF_USE_HAPPY_EYEBALLS, false); + } @Override public boolean isAvailable() { @@ -63,16 +69,28 @@ public LoadBalancer newLoadBalancer(LoadBalancer.Helper helper) { } @Override - public ConfigOrError parseLoadBalancingPolicyConfig( - Map rawLoadBalancingPolicyConfig) { + public ConfigOrError parseLoadBalancingPolicyConfig(Map rawLbPolicyConfig) { try { - return ConfigOrError.fromConfig( - new PickFirstLoadBalancerConfig(JsonUtil.getBoolean(rawLoadBalancingPolicyConfig, - SHUFFLE_ADDRESS_LIST_KEY))); + Object config = getLbPolicyConfig(rawLbPolicyConfig); + return ConfigOrError.fromConfig(config); } catch (RuntimeException e) { return ConfigOrError.fromError( Status.UNAVAILABLE.withCause(e).withDescription( "Failed parsing configuration for " + getPolicyName())); } } + + private static Object getLbPolicyConfig(Map rawLbPolicyConfig) { + Boolean shuffleAddressList = JsonUtil.getBoolean(rawLbPolicyConfig, SHUFFLE_ADDRESS_LIST_KEY); + if (enableNewPickFirst) { + return new PickFirstLeafLoadBalancerConfig(shuffleAddressList); + } else { + return new PickFirstLoadBalancerConfig(shuffleAddressList); + } + } + + @VisibleForTesting + public static boolean isEnabledNewPickFirst() { + return enableNewPickFirst; + } } diff --git a/core/src/main/java/io/grpc/internal/PickSubchannelArgsImpl.java b/core/src/main/java/io/grpc/internal/PickSubchannelArgsImpl.java index dd938303dee..c61fcac6f69 100644 --- a/core/src/main/java/io/grpc/internal/PickSubchannelArgsImpl.java +++ b/core/src/main/java/io/grpc/internal/PickSubchannelArgsImpl.java @@ -20,6 +20,7 @@ import com.google.common.base.Objects; import io.grpc.CallOptions; +import io.grpc.LoadBalancer.PickDetailsConsumer; import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.Metadata; import io.grpc.MethodDescriptor; @@ -29,15 +30,18 @@ public final class PickSubchannelArgsImpl extends PickSubchannelArgs { private final CallOptions callOptions; private final Metadata headers; private final MethodDescriptor method; + private final PickDetailsConsumer pickDetailsConsumer; /** * Creates call args object for given method with its call options, metadata. */ public PickSubchannelArgsImpl( - MethodDescriptor method, Metadata headers, CallOptions callOptions) { + MethodDescriptor method, Metadata headers, CallOptions callOptions, + PickDetailsConsumer pickDetailsConsumer) { this.method = checkNotNull(method, "method"); this.headers = checkNotNull(headers, "headers"); this.callOptions = checkNotNull(callOptions, "callOptions"); + this.pickDetailsConsumer = checkNotNull(pickDetailsConsumer, "pickDetailsConsumer"); } @Override @@ -55,6 +59,11 @@ public CallOptions getCallOptions() { return method; } + @Override + public PickDetailsConsumer getPickDetailsConsumer() { + return pickDetailsConsumer; + } + @Override public boolean equals(Object o) { if (this == o) { @@ -66,12 +75,13 @@ public boolean equals(Object o) { PickSubchannelArgsImpl that = (PickSubchannelArgsImpl) o; return Objects.equal(callOptions, that.callOptions) && Objects.equal(headers, that.headers) - && Objects.equal(method, that.method); + && Objects.equal(method, that.method) + && Objects.equal(pickDetailsConsumer, that.pickDetailsConsumer); } @Override public int hashCode() { - return Objects.hashCode(callOptions, headers, method); + return Objects.hashCode(callOptions, headers, method, pickDetailsConsumer); } @Override diff --git a/core/src/main/java/io/grpc/internal/ProxyDetectorImpl.java b/core/src/main/java/io/grpc/internal/ProxyDetectorImpl.java index 7e402bcf655..2f9903eac03 100644 --- a/core/src/main/java/io/grpc/internal/ProxyDetectorImpl.java +++ b/core/src/main/java/io/grpc/internal/ProxyDetectorImpl.java @@ -147,18 +147,9 @@ public ProxySelector get() { } }; - /** - * Experimental environment variable name for enabling proxy support. - * - * @deprecated Use the standard Java proxy configuration instead with flags such as: - * -Dhttps.proxyHost=HOST -Dhttps.proxyPort=PORT - */ - @Deprecated - private static final String GRPC_PROXY_ENV_VAR = "GRPC_PROXY_EXP"; // Do not hard code a ProxySelector because the global default ProxySelector can change private final Supplier proxySelector; private final AuthenticationProvider authenticationProvider; - private final InetSocketAddress overrideProxyAddress; // We want an HTTPS proxy, which operates on the entire data stream (See IETF rfc2817). static final String PROXY_SCHEME = "https"; @@ -168,21 +159,15 @@ public ProxySelector get() { * {@link ProxyDetectorImpl.AuthenticationProvider} to detect proxy parameters. */ public ProxyDetectorImpl() { - this(DEFAULT_PROXY_SELECTOR, DEFAULT_AUTHENTICATOR, System.getenv(GRPC_PROXY_ENV_VAR)); + this(DEFAULT_PROXY_SELECTOR, DEFAULT_AUTHENTICATOR); } @VisibleForTesting ProxyDetectorImpl( Supplier proxySelector, - AuthenticationProvider authenticationProvider, - @Nullable String proxyEnvString) { + AuthenticationProvider authenticationProvider) { this.proxySelector = checkNotNull(proxySelector); this.authenticationProvider = checkNotNull(authenticationProvider); - if (proxyEnvString != null) { - overrideProxyAddress = overrideProxy(proxyEnvString); - } else { - overrideProxyAddress = null; - } } @Nullable @@ -191,25 +176,12 @@ public ProxiedSocketAddress proxyFor(SocketAddress targetServerAddress) throws I if (!(targetServerAddress instanceof InetSocketAddress)) { return null; } - if (overrideProxyAddress != null) { - return HttpConnectProxiedSocketAddress.newBuilder() - .setProxyAddress(overrideProxyAddress) - .setTargetAddress((InetSocketAddress) targetServerAddress) - .build(); - } return detectProxy((InetSocketAddress) targetServerAddress); } private ProxiedSocketAddress detectProxy(InetSocketAddress targetAddr) throws IOException { URI uri; - String host; - try { - host = GrpcUtil.getHost(targetAddr); - } catch (Throwable t) { - // Workaround for Android API levels < 19 if getHostName causes a NetworkOnMainThreadException - log.log(Level.WARNING, "Failed to get host for proxy lookup, proceeding without proxy", t); - return null; - } + String host = targetAddr.getHostString(); try { uri = new URI( @@ -235,6 +207,14 @@ private ProxiedSocketAddress detectProxy(InetSocketAddress targetAddr) throws IO } List proxies = proxySelector.select(uri); + // ProxySelector.select(URI) is contractually required to return a non-null, non-empty list. + // Surface the offending implementation's class name so a broken ProxySelector can be fixed. + if (proxies == null || proxies.isEmpty()) { + throw new IOException( + "ProxySelector " + proxySelector.getClass().getName() + + " returned " + (proxies == null ? "null" : "an empty list") + + ", which violates the java.net.ProxySelector#select(URI) contract"); + } if (proxies.size() > 1) { log.warning("More than 1 proxy detected, gRPC will select the first one"); } @@ -247,13 +227,14 @@ private ProxiedSocketAddress detectProxy(InetSocketAddress targetAddr) throws IO // The prompt string should be the realm as returned by the server. // We don't have it because we are avoiding the full handshake. String promptString = ""; - PasswordAuthentication auth = authenticationProvider.requestPasswordAuthentication( - GrpcUtil.getHost(proxyAddr), - proxyAddr.getAddress(), - proxyAddr.getPort(), - PROXY_SCHEME, - promptString, - null); + PasswordAuthentication auth = + authenticationProvider.requestPasswordAuthentication( + proxyAddr.getHostString(), + proxyAddr.getAddress(), + proxyAddr.getPort(), + PROXY_SCHEME, + promptString, + null); final InetSocketAddress resolvedProxyAddr; if (proxyAddr.isUnresolved()) { @@ -278,27 +259,6 @@ private ProxiedSocketAddress detectProxy(InetSocketAddress targetAddr) throws IO .build(); } - /** - * GRPC_PROXY_EXP is deprecated but let's maintain compatibility for now. - */ - private static InetSocketAddress overrideProxy(String proxyHostPort) { - if (proxyHostPort == null) { - return null; - } - - String[] parts = proxyHostPort.split(":", 2); - int port = 80; - if (parts.length > 1) { - port = Integer.parseInt(parts[1]); - } - log.warning( - "Detected GRPC_PROXY_EXP and will honor it, but this feature will " - + "be removed in a future release. Use the JVM flags " - + "\"-Dhttps.proxyHost=HOST -Dhttps.proxyPort=PORT\" to set the https proxy for " - + "this JVM."); - return new InetSocketAddress(parts[0], port); - } - /** * This interface makes unit testing easier by avoiding direct calls to static methods. */ diff --git a/core/src/main/java/io/grpc/internal/ReadableBuffer.java b/core/src/main/java/io/grpc/internal/ReadableBuffer.java index 6963c78203e..20f64719875 100644 --- a/core/src/main/java/io/grpc/internal/ReadableBuffer.java +++ b/core/src/main/java/io/grpc/internal/ReadableBuffer.java @@ -71,15 +71,6 @@ public interface ReadableBuffer extends Closeable { */ void readBytes(byte[] dest, int destOffset, int length); - /** - * Reads from this buffer until the destination's position reaches its limit, and increases the - * read position by the number of the transferred bytes. - * - * @param dest the destination buffer to receive the bytes. - * @throws IndexOutOfBoundsException if required bytes are not readable - */ - void readBytes(ByteBuffer dest); - /** * Reads {@code length} bytes from this buffer and writes them to the destination stream. * Increments the read position by {@code length}. If the required bytes are not readable, throws diff --git a/core/src/main/java/io/grpc/internal/ReadableBuffers.java b/core/src/main/java/io/grpc/internal/ReadableBuffers.java index c54cb0e67d0..439745e29b2 100644 --- a/core/src/main/java/io/grpc/internal/ReadableBuffers.java +++ b/core/src/main/java/io/grpc/internal/ReadableBuffers.java @@ -16,7 +16,7 @@ package io.grpc.internal; -import static com.google.common.base.Charsets.UTF_8; +import static java.nio.charset.StandardCharsets.UTF_8; import com.google.common.base.Preconditions; import io.grpc.Detachable; @@ -171,15 +171,6 @@ public void readBytes(byte[] dest, int destIndex, int length) { offset += length; } - @Override - public void readBytes(ByteBuffer dest) { - Preconditions.checkNotNull(dest, "dest"); - int length = dest.remaining(); - checkReadable(length); - dest.put(bytes, offset, length); - offset += length; - } - @Override public void readBytes(OutputStream dest, int length) throws IOException { checkReadable(length); @@ -262,21 +253,6 @@ public void readBytes(byte[] dest, int destOffset, int length) { bytes.get(dest, destOffset, length); } - @Override - public void readBytes(ByteBuffer dest) { - Preconditions.checkNotNull(dest, "dest"); - int length = dest.remaining(); - checkReadable(length); - - // Change the limit so that only length bytes are available. - int prevLimit = bytes.limit(); - ((Buffer) bytes).limit(bytes.position() + length); - - // Write the bytes and restore the original limit. - dest.put(bytes); - bytes.limit(prevLimit); - } - @Override public void readBytes(OutputStream dest, int length) throws IOException { checkReadable(length); @@ -415,6 +391,7 @@ public ByteBuffer getByteBuffer() { public InputStream detach() { ReadableBuffer detachedBuffer = buffer; buffer = buffer.readBytes(0); + detachedBuffer.touch(); return new BufferInputStream(detachedBuffer); } diff --git a/core/src/main/java/io/grpc/internal/RetriableStream.java b/core/src/main/java/io/grpc/internal/RetriableStream.java index f301eee1f98..0c37a0beaca 100644 --- a/core/src/main/java/io/grpc/internal/RetriableStream.java +++ b/core/src/main/java/io/grpc/internal/RetriableStream.java @@ -22,6 +22,8 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Objects; +import com.google.errorprone.annotations.CheckReturnValue; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.ClientStreamTracer; import io.grpc.Compressor; @@ -47,9 +49,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import javax.annotation.CheckForNull; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; /** A logical {@link ClientStream} that is retriable. */ abstract class RetriableStream implements ClientStream { @@ -149,11 +149,10 @@ public void uncaughtException(Thread t, Throwable e) { this.throttle = throttle; } - @SuppressWarnings("GuardedBy") + @SuppressWarnings("GuardedBy") // TODO(b/145386688) this.lock==ScheduledCancellor.lock so ok @Nullable // null if already committed @CheckReturnValue private Runnable commit(final Substream winningSubstream) { - synchronized (lock) { if (state.winningSubstream != null) { return null; @@ -165,10 +164,10 @@ private Runnable commit(final Substream winningSubstream) { // subtract the share of this RPC from channelBufferUsed. channelBufferUsed.addAndGet(-perRpcBufferUsed); + final boolean wasCancelled = (scheduledRetry != null) ? scheduledRetry.isCancelled() : false; final Future retryFuture; - if (scheduledRetry != null) { - // TODO(b/145386688): This access should be guarded by 'this.scheduledRetry.lock'; instead - // found: 'this.lock' + final boolean retryWasScheduled = scheduledRetry != null; + if (retryWasScheduled) { retryFuture = scheduledRetry.markCancelled(); scheduledRetry = null; } else { @@ -177,8 +176,6 @@ private Runnable commit(final Substream winningSubstream) { // cancel the scheduled hedging if it is scheduled prior to the commitment final Future hedgingFuture; if (scheduledHedging != null) { - // TODO(b/145386688): This access should be guarded by 'this.scheduledHedging.lock'; instead - // found: 'this.lock' hedgingFuture = scheduledHedging.markCancelled(); scheduledHedging = null; } else { @@ -194,9 +191,25 @@ public void run() { substream.stream.cancel(CANCELLED_BECAUSE_COMMITTED); } } - if (retryFuture != null) { - retryFuture.cancel(false); + if (retryWasScheduled) { + if (retryFuture != null) { + retryFuture.cancel(false); + } + if (!wasCancelled && inFlightSubStreams.decrementAndGet() == Integer.MIN_VALUE) { + assert savedCloseMasterListenerReason != null; + listenerSerializeExecutor.execute( + new Runnable() { + @Override + public void run() { + isClosed = true; + masterListener.closed(savedCloseMasterListenerReason.status, + savedCloseMasterListenerReason.progress, + savedCloseMasterListenerReason.metadata); + } + }); + } } + if (hedgingFuture != null) { hedgingFuture.cancel(false); } @@ -235,7 +248,8 @@ private void commitAndRun(Substream winningSubstream) { // returns null means we should not create new sub streams, e.g. cancelled or // other close condition is met for retriableStream. @Nullable - private Substream createSubstream(int previousAttemptCount, boolean isTransparentRetry) { + private Substream createSubstream(int previousAttemptCount, boolean isTransparentRetry, + boolean isHedgedStream) { int inFlight; do { inFlight = inFlightSubStreams.get(); @@ -256,7 +270,8 @@ public ClientStreamTracer newClientStreamTracer( Metadata newHeaders = updateHeaders(headers, previousAttemptCount); // NOTICE: This set _must_ be done before stream.start() and it actually is. - sub.stream = newSubstream(newHeaders, tracerFactory, previousAttemptCount, isTransparentRetry); + sub.stream = newSubstream(newHeaders, tracerFactory, previousAttemptCount, isTransparentRetry, + isHedgedStream); return sub; } @@ -266,7 +281,7 @@ public ClientStreamTracer newClientStreamTracer( */ abstract ClientStream newSubstream( Metadata headers, ClientStreamTracer.Factory tracerFactory, int previousAttempts, - boolean isTransparentRetry); + boolean isTransparentRetry, boolean isHedgedStream); /** Adds grpc-previous-rpc-attempts in the headers of a retry/hedging RPC. */ @VisibleForTesting @@ -372,7 +387,7 @@ public void runWith(Substream substream) { } } - /** Starts the first PRC attempt. */ + /** Starts the first RPC attempt. */ @Override public final void start(ClientStreamListener listener) { masterListener = listener; @@ -388,7 +403,7 @@ public final void start(ClientStreamListener listener) { state.buffer.add(new StartEntry()); } - Substream substream = createSubstream(0, false); + Substream substream = createSubstream(0, false, false); if (substream == null) { return; } @@ -415,7 +430,7 @@ public final void start(ClientStreamListener listener) { drain(substream); } - @SuppressWarnings("GuardedBy") + @SuppressWarnings("GuardedBy") // TODO(b/145386688) this.lock==ScheduledCancellor.lock so ok private void pushbackHedging(@Nullable Integer delayMillis) { if (delayMillis == null) { return; @@ -434,8 +449,6 @@ private void pushbackHedging(@Nullable Integer delayMillis) { return; } - // TODO(b/145386688): This access should be guarded by 'this.scheduledHedging.lock'; instead - // found: 'this.lock' futureToBeCancelled = scheduledHedging.markCancelled(); scheduledHedging = future = new FutureCanceller(lock); } @@ -463,22 +476,19 @@ public void run() { // If this run is not cancelled, the value of state.hedgingAttemptCount won't change // until state.addActiveHedge() is called subsequently, even the state could possibly // change. - Substream newSubstream = createSubstream(state.hedgingAttemptCount, false); + Substream newSubstream = createSubstream(state.hedgingAttemptCount, false, true); if (newSubstream == null) { return; } callExecutor.execute( new Runnable() { - @SuppressWarnings("GuardedBy") + @SuppressWarnings("GuardedBy") //TODO(b/145386688) lock==ScheduledCancellor.lock so ok @Override public void run() { boolean cancelled = false; FutureCanceller future = null; synchronized (lock) { - // TODO(b/145386688): This access should be guarded by - // 'HedgingRunnable.this.scheduledHedgingRef.lock'; instead found: - // 'RetriableStream.this.lock' if (scheduledHedgingRef.isCancelled()) { cancelled = true; } else { @@ -810,13 +820,11 @@ private boolean hasPotentialHedging(State state) { && !state.hedgingFrozen; } - @SuppressWarnings("GuardedBy") + @SuppressWarnings("GuardedBy") // TODO(b/145386688) this.lock==ScheduledCancellor.lock so ok private void freezeHedging() { Future futureToBeCancelled = null; synchronized (lock) { if (scheduledHedging != null) { - // TODO(b/145386688): This access should be guarded by 'this.scheduledHedging.lock'; instead - // found: 'this.lock' futureToBeCancelled = scheduledHedging.markCancelled(); scheduledHedging = null; } @@ -843,6 +851,15 @@ public void run() { } } + private static final boolean isExperimentalRetryJitterEnabled = GrpcUtil + .getFlag("GRPC_EXPERIMENTAL_XDS_RLS_LB", true); + + public static long intervalWithJitter(long intervalNanos) { + double inverseJitterFactor = isExperimentalRetryJitterEnabled + ? 0.4 * random.nextDouble() + 0.8 : random.nextDouble(); + return (long) (intervalNanos * inverseJitterFactor); + } + private static final class SavedCloseMasterListenerReason { private final Status status; private final RpcProgress progress; @@ -924,9 +941,8 @@ public void run() { && localOnlyTransparentRetries.incrementAndGet() > 1_000) { commitAndRun(substream); if (state.winningSubstream == substream) { - Status tooManyTransparentRetries = Status.INTERNAL - .withDescription("Too many transparent retries. Might be a bug in gRPC") - .withCause(status.asRuntimeException()); + Status tooManyTransparentRetries = GrpcUtil.statusWithDetails( + Status.Code.INTERNAL, "Too many transparent retries. Might be a bug in gRPC", status); safeCloseMasterListener(tooManyTransparentRetries, rpcProgress, trailers); } return; @@ -937,7 +953,8 @@ public void run() { || (rpcProgress == RpcProgress.REFUSED && noMoreTransparentRetry.compareAndSet(false, true))) { // transparent retry - final Substream newSubstream = createSubstream(substream.previousAttemptCount, true); + final Substream newSubstream = createSubstream(substream.previousAttemptCount, + true, false); if (newSubstream == null) { return; } @@ -989,7 +1006,8 @@ public void run() { RetryPlan retryPlan = makeRetryDecision(status, trailers); if (retryPlan.shouldRetry) { // retry - Substream newSubstream = createSubstream(substream.previousAttemptCount + 1, false); + Substream newSubstream = createSubstream(substream.previousAttemptCount + 1, + false, false); if (newSubstream == null) { return; } @@ -999,9 +1017,19 @@ public void run() { synchronized (lock) { scheduledRetry = scheduledRetryCopy = new FutureCanceller(lock); } + class RetryBackoffRunnable implements Runnable { @Override + @SuppressWarnings("FutureReturnValueIgnored") public void run() { + synchronized (scheduledRetryCopy.lock) { + if (scheduledRetryCopy.isCancelled()) { + return; + } else { + scheduledRetryCopy.markCancelled(); + } + } + callExecutor.execute( new Runnable() { @Override @@ -1053,7 +1081,7 @@ private RetryPlan makeRetryDecision(Status status, Metadata trailer) { if (pushbackMillis == null) { if (isRetryableStatusCode) { shouldRetry = true; - backoffNanos = (long) (nextBackoffIntervalNanos * random.nextDouble()); + backoffNanos = intervalWithJitter(nextBackoffIntervalNanos); nextBackoffIntervalNanos = Math.min( (long) (nextBackoffIntervalNanos * retryPolicy.backoffMultiplier), retryPolicy.maxBackoffNanos); @@ -1563,11 +1591,16 @@ private static final class FutureCanceller { } void setFuture(Future future) { + boolean wasCancelled; synchronized (lock) { - if (!cancelled) { + wasCancelled = cancelled; + if (!wasCancelled) { this.future = future; } } + if (wasCancelled) { + future.cancel(false); + } } @GuardedBy("lock") diff --git a/core/src/main/java/io/grpc/internal/RetryingNameResolver.java b/core/src/main/java/io/grpc/internal/RetryingNameResolver.java index 6d806e95944..90827fa8acb 100644 --- a/core/src/main/java/io/grpc/internal/RetryingNameResolver.java +++ b/core/src/main/java/io/grpc/internal/RetryingNameResolver.java @@ -17,7 +17,6 @@ package io.grpc.internal; import com.google.common.annotations.VisibleForTesting; -import io.grpc.Attributes; import io.grpc.NameResolver; import io.grpc.Status; import io.grpc.SynchronizationContext; @@ -28,16 +27,22 @@ * *

The {@link NameResolver} used with this */ -final class RetryingNameResolver extends ForwardingNameResolver { +public final class RetryingNameResolver extends ForwardingNameResolver { + public static NameResolver wrap(NameResolver retriedNameResolver, Args args) { + // For migration, this might become conditional + return new RetryingNameResolver( + retriedNameResolver, + new BackoffPolicyRetryScheduler( + new ExponentialBackoffPolicy.Provider(), + args.getScheduledExecutorService(), + args.getSynchronizationContext()), + args.getSynchronizationContext()); + } private final NameResolver retriedNameResolver; private final RetryScheduler retryScheduler; private final SynchronizationContext syncContext; - static final Attributes.Key RESOLUTION_RESULT_LISTENER_KEY - = Attributes.Key.create( - "io.grpc.internal.RetryingNameResolver.RESOLUTION_RESULT_LISTENER_KEY"); - /** * Creates a new {@link RetryingNameResolver}. * @@ -88,38 +93,24 @@ private class RetryingListener extends Listener2 { @Override public void onResult(ResolutionResult resolutionResult) { - // If the resolution result listener is already an attribute it indicates that a name resolver - // has already been wrapped with this class. This indicates a misconfiguration. - if (resolutionResult.getAttributes().get(RESOLUTION_RESULT_LISTENER_KEY) != null) { - throw new IllegalStateException( - "RetryingNameResolver can only be used once to wrap a NameResolver"); - } - - delegateListener.onResult(resolutionResult.toBuilder().setAttributes( - resolutionResult.getAttributes().toBuilder() - .set(RESOLUTION_RESULT_LISTENER_KEY, new ResolutionResultListener()).build()) - .build()); + syncContext.execute(() -> onResult2(resolutionResult)); } @Override - public void onError(Status error) { - delegateListener.onError(error); - syncContext.execute(() -> retryScheduler.schedule(new DelayedNameResolverRefresh())); - } - } - - /** - * Simple callback class to store in {@link ResolutionResult} attributes so that - * ManagedChannel can indicate if the resolved addresses were accepted. Temporary until - * the Listener2.onResult() API can be changed to return a boolean for this purpose. - */ - class ResolutionResultListener { - public void resolutionAttempted(Status successStatus) { - if (successStatus.isOk()) { + public Status onResult2(ResolutionResult resolutionResult) { + Status status = delegateListener.onResult2(resolutionResult); + if (status.isOk()) { retryScheduler.reset(); } else { retryScheduler.schedule(new DelayedNameResolverRefresh()); } + return status; + } + + @Override + public void onError(Status error) { + delegateListener.onError(error); + syncContext.execute(() -> retryScheduler.schedule(new DelayedNameResolverRefresh())); } } } diff --git a/core/src/main/java/io/grpc/internal/ScParser.java b/core/src/main/java/io/grpc/internal/ScParser.java index f94449f7c7b..71d6d33877f 100644 --- a/core/src/main/java/io/grpc/internal/ScParser.java +++ b/core/src/main/java/io/grpc/internal/ScParser.java @@ -19,6 +19,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; +import io.grpc.LoadBalancerProvider; import io.grpc.NameResolver; import io.grpc.NameResolver.ConfigOrError; import io.grpc.Status; @@ -31,18 +32,18 @@ public final class ScParser extends NameResolver.ServiceConfigParser { private final boolean retryEnabled; private final int maxRetryAttemptsLimit; private final int maxHedgedAttemptsLimit; - private final AutoConfiguredLoadBalancerFactory autoLoadBalancerFactory; + private final LoadBalancerProvider parser; /** Creates a parse with global retry settings and an auto configured lb factory. */ public ScParser( boolean retryEnabled, int maxRetryAttemptsLimit, int maxHedgedAttemptsLimit, - AutoConfiguredLoadBalancerFactory autoLoadBalancerFactory) { + LoadBalancerProvider parser) { this.retryEnabled = retryEnabled; this.maxRetryAttemptsLimit = maxRetryAttemptsLimit; this.maxHedgedAttemptsLimit = maxHedgedAttemptsLimit; - this.autoLoadBalancerFactory = checkNotNull(autoLoadBalancerFactory, "autoLoadBalancerFactory"); + this.parser = checkNotNull(parser, "parser"); } @Override @@ -50,7 +51,9 @@ public ConfigOrError parseServiceConfig(Map rawServiceConfig) { try { Object loadBalancingPolicySelection; ConfigOrError choiceFromLoadBalancer = - autoLoadBalancerFactory.parseLoadBalancerPolicy(rawServiceConfig); + parser.parseLoadBalancingPolicyConfig(rawServiceConfig); + // TODO(ejona): The Provider API doesn't allow null, but AutoConfiguredLoadBalancerFactory can + // return null and it will need tweaking to ManagedChannelImpl.defaultServiceConfig to fix. if (choiceFromLoadBalancer == null) { loadBalancingPolicySelection = null; } else if (choiceFromLoadBalancer.getError() != null) { @@ -66,8 +69,19 @@ public ConfigOrError parseServiceConfig(Map rawServiceConfig) { maxHedgedAttemptsLimit, loadBalancingPolicySelection)); } catch (RuntimeException e) { + // TODO(ejona): We really don't want parsers throwing exceptions; they should return an error. + // However, right now ManagedChannelServiceConfig itself uses exceptions like + // ClassCastException. We should handle those with a graceful return within + // ManagedChannelServiceConfig and then get rid of this case. Then all exceptions are + // "unexpected" and the INTERNAL status code makes it clear a bug needs to be fixed. return ConfigOrError.fromError( Status.UNKNOWN.withDescription("failed to parse service config").withCause(e)); + } catch (Throwable t) { + // Even catch Errors, since broken config parsing could trigger AssertionError, + // StackOverflowError, and other errors we can reasonably safely recover. Since the config + // could be untrusted, we want to error on the side of recovering. + return ConfigOrError.fromError( + Status.INTERNAL.withDescription("Unexpected error parsing service config").withCause(t)); } } } diff --git a/core/src/main/java/io/grpc/internal/SerializingExecutor.java b/core/src/main/java/io/grpc/internal/SerializingExecutor.java index 73133a339e4..7044b4e17fc 100644 --- a/core/src/main/java/io/grpc/internal/SerializingExecutor.java +++ b/core/src/main/java/io/grpc/internal/SerializingExecutor.java @@ -113,7 +113,7 @@ private void schedule(@Nullable Runnable removable) { // ConcurrentLinkedQueue claims that null elements are not allowed, but seems to not // throw if the item to remove is null. If removable is present in the queue twice, // the wrong one may be removed. It doesn't seem possible for this case to exist today. - // This is important to run in case of RejectedExectuionException, so that future calls + // This is important to run in case of RejectedExecutionException, so that future calls // to execute don't succeed and accidentally run a previous runnable. runQueue.remove(removable); } diff --git a/core/src/main/java/io/grpc/internal/ServerCallImpl.java b/core/src/main/java/io/grpc/internal/ServerCallImpl.java index 1bfee21e055..e224384ce8f 100644 --- a/core/src/main/java/io/grpc/internal/ServerCallImpl.java +++ b/core/src/main/java/io/grpc/internal/ServerCallImpl.java @@ -184,6 +184,11 @@ public void setMessageCompression(boolean enable) { stream.setMessageCompression(enable); } + @Override + public void setOnReadyThreshold(int numBytes) { + stream.setOnReadyThreshold(numBytes); + } + @Override public void setCompression(String compressorName) { // Added here to give a better error message. @@ -368,10 +373,10 @@ private void closedInternal(Status status) { } else { call.cancelled = true; listener.onCancel(); - // The status will not have a cause in all failure scenarios but we want to make sure + // The status will not have a cause in all failure scenarios, but we want to make sure // we always cancel the context with one to keep the context cancelled state consistent. - cancelCause = InternalStatus.asRuntimeException( - Status.CANCELLED.withDescription("RPC cancelled"), null, false); + cancelCause = InternalStatus.asRuntimeExceptionWithoutStacktrace( + Status.CANCELLED.withDescription("RPC cancelled"), null); } } finally { // Cancel context after delivering RPC closure notification to allow the application to diff --git a/core/src/main/java/io/grpc/internal/ServerImpl.java b/core/src/main/java/io/grpc/internal/ServerImpl.java index cec2a13a301..d9f64c2d473 100644 --- a/core/src/main/java/io/grpc/internal/ServerImpl.java +++ b/core/src/main/java/io/grpc/internal/ServerImpl.java @@ -31,6 +31,7 @@ import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.BinaryLog; import io.grpc.CompressorRegistry; @@ -75,7 +76,6 @@ import java.util.concurrent.TimeUnit; import java.util.logging.Level; import java.util.logging.Logger; -import javax.annotation.concurrent.GuardedBy; /** * Default implementation of {@link io.grpc.Server}, for creation by transports. @@ -99,7 +99,7 @@ public final class ServerImpl extends io.grpc.Server implements InternalInstrume private final ObjectPool executorPool; /** Executor for application processing. Safe to read after {@link #start()}. */ private Executor executor; - private final HandlerRegistry registry; + private final InternalHandlerRegistry registry; private final HandlerRegistry fallbackRegistry; private final List transportFilters; // This is iterated on a per-call basis. Use an array instead of a Collection to avoid iterator @@ -151,7 +151,9 @@ public final class ServerImpl extends io.grpc.Server implements InternalInstrume InternalLogId.allocate("Server", String.valueOf(getListenSocketsIgnoringLifecycle())); // Fork from the passed in context so that it does not propagate cancellation, it only // inherits values. - this.rootContext = Preconditions.checkNotNull(rootContext, "rootContext").fork(); + this.rootContext = Preconditions.checkNotNull(rootContext, "rootContext") + .fork() + .withValue(io.grpc.InternalServer.SERVER_CONTEXT_KEY, ServerImpl.this); this.decompressorRegistry = builder.decompressorRegistry; this.compressorRegistry = builder.compressorRegistry; this.transportFilters = Collections.unmodifiableList( @@ -496,8 +498,12 @@ private void streamCreatedInternal( final StatsTraceContext statsTraceCtx = Preconditions.checkNotNull( stream.statsTraceContext(), "statsTraceCtx not present from stream"); + final ServerMethodDefinition primaryMethod = registry.lookupMethod(methodName, null); final Context.CancellableContext context = createContext(headers, statsTraceCtx); + if (primaryMethod != null) { + statsTraceCtx.serverCallMethodResolved(primaryMethod.getMethodDescriptor()); + } final Link link = PerfMark.linkOut(); @@ -534,7 +540,7 @@ private void runInternal() { ServerMethodDefinition wrapMethod; ServerCallParameters callParams; try { - ServerMethodDefinition method = registry.lookupMethod(methodName); + ServerMethodDefinition method = primaryMethod; if (method == null) { method = fallbackRegistry.lookupMethod(methodName, stream.getAuthority()); } @@ -622,19 +628,7 @@ private void runInternal() { // An extremely short deadline may expire before stream.setListener(jumpListener). // This causes NPE as in issue: https://github.com/grpc/grpc-java/issues/6300 // Delay of setting cancellationListener to context will fix the issue. - final class ServerStreamCancellationListener implements Context.CancellationListener { - @Override - public void cancelled(Context context) { - Status status = statusFromCancelled(context); - if (DEADLINE_EXCEEDED.getCode().equals(status.getCode())) { - // This should rarely get run, since the client will likely cancel the stream - // before the timeout is reached. - stream.cancel(status); - } - } - } - - context.addListener(new ServerStreamCancellationListener(), directExecutor()); + context.addListener(new ServerStreamCancellationListener(stream), directExecutor()); } } @@ -648,8 +642,7 @@ private Context.CancellableContext createContext( Context baseContext = statsTraceCtx - .serverFilterContext(rootContext) - .withValue(io.grpc.InternalServer.SERVER_CONTEXT_KEY, ServerImpl.this); + .serverFilterContext(rootContext); if (timeoutNanos == null) { return baseContext.withCancellation(); @@ -707,6 +700,31 @@ private ServerStreamListener startWrappedCall( } } + /** + * Propagates context cancellation to the ServerStream. + * + *

This is outside of HandleServerCall because that class holds Metadata and other state needed + * only when starting the RPC. The cancellation listener will live for the life of the call, so we + * avoid that useless state being retained. + */ + static final class ServerStreamCancellationListener implements Context.CancellationListener { + private final ServerStream stream; + + ServerStreamCancellationListener(ServerStream stream) { + this.stream = checkNotNull(stream, "stream"); + } + + @Override + public void cancelled(Context context) { + Status status = statusFromCancelled(context); + if (DEADLINE_EXCEEDED.getCode().equals(status.getCode())) { + // This should rarely get run, since the client will likely cancel the stream + // before the timeout is reached. + stream.cancel(status); + } + } + } + @Override public InternalLogId getLogId() { return logId; @@ -887,8 +905,8 @@ private void closedInternal(final Status status) { // failed status has an exception we will create one here if needed. Throwable cancelCause = status.getCause(); if (cancelCause == null) { - cancelCause = InternalStatus.asRuntimeException( - Status.CANCELLED.withDescription("RPC cancelled"), null, false); + cancelCause = InternalStatus.asRuntimeExceptionWithoutStacktrace( + Status.CANCELLED.withDescription("RPC cancelled"), null); } // The callExecutor might be busy doing user work. To avoid waiting, use an executor that diff --git a/core/src/main/java/io/grpc/internal/ServerImplBuilder.java b/core/src/main/java/io/grpc/internal/ServerImplBuilder.java index cd18457d51b..62a0e66f314 100644 --- a/core/src/main/java/io/grpc/internal/ServerImplBuilder.java +++ b/core/src/main/java/io/grpc/internal/ServerImplBuilder.java @@ -30,7 +30,10 @@ import io.grpc.DecompressorRegistry; import io.grpc.HandlerRegistry; import io.grpc.InternalChannelz; -import io.grpc.InternalGlobalInterceptors; +import io.grpc.InternalConfiguratorRegistry; +import io.grpc.MetricInstrumentRegistry; +import io.grpc.MetricRecorder; +import io.grpc.MetricSink; import io.grpc.Server; import io.grpc.ServerBuilder; import io.grpc.ServerCallExecutorSupplier; @@ -80,6 +83,7 @@ public static ServerBuilder forPort(int port) { final List transportFilters = new ArrayList<>(); final List interceptors = new ArrayList<>(); private final List streamTracerFactories = new ArrayList<>(); + final List metricSinks = new ArrayList<>(); private final ClientTransportServersBuilder clientTransportServersBuilder; HandlerRegistry fallbackRegistry = DEFAULT_FALLBACK_REGISTRY; ObjectPool executorPool = DEFAULT_EXECUTOR_POOL; @@ -99,12 +103,13 @@ public static ServerBuilder forPort(int port) { ServerCallExecutorSupplier executorSupplier; /** - * An interface to provide to provide transport specific information for the server. This method + * An interface to provide transport specific information for the server. This method * is meant for Transport implementors and should not be used by normal users. */ public interface ClientTransportServersBuilder { InternalServer buildClientTransportServers( - List streamTracerFactories); + List streamTracerFactories, + MetricRecorder metricRecorder); } /** @@ -113,6 +118,8 @@ InternalServer buildClientTransportServers( public ServerImplBuilder(ClientTransportServersBuilder clientTransportServersBuilder) { this.clientTransportServersBuilder = checkNotNull(clientTransportServersBuilder, "clientTransportServersBuilder"); + // TODO(dnvindhya): Move configurator to all the individual builders + InternalConfiguratorRegistry.configureServerBuilder(this); } @Override @@ -155,6 +162,15 @@ public ServerImplBuilder intercept(ServerInterceptor interceptor) { return this; } + /** + * Adds a MetricSink to the server. + */ + @Override + public ServerImplBuilder addMetricSink(MetricSink metricSink) { + metricSinks.add(checkNotNull(metricSink, "metricSink")); + return this; + } + @Override public ServerImplBuilder addStreamTracerFactory(ServerStreamTracer.Factory factory) { streamTracerFactories.add(checkNotNull(factory, "factory")); @@ -239,25 +255,22 @@ public void setDeadlineTicker(Deadline.Ticker ticker) { @Override public Server build() { + MetricRecorder metricRecorder = new MetricRecorderImpl(metricSinks, + MetricInstrumentRegistry.getDefaultRegistry()); return new ServerImpl(this, - clientTransportServersBuilder.buildClientTransportServers(getTracerFactories()), + clientTransportServersBuilder.buildClientTransportServers( + getTracerFactories(), metricRecorder), Context.ROOT); } @VisibleForTesting List getTracerFactories() { - ArrayList tracerFactories = new ArrayList<>(); - boolean isGlobalInterceptorsTracersSet = false; - List globalServerInterceptors - = InternalGlobalInterceptors.getServerInterceptors(); - List globalServerStreamTracerFactories - = InternalGlobalInterceptors.getServerStreamTracerFactories(); - if (globalServerInterceptors != null) { - tracerFactories.addAll(globalServerStreamTracerFactories); - interceptors.addAll(globalServerInterceptors); - isGlobalInterceptorsTracersSet = true; + boolean disableImplicitCensus = InternalConfiguratorRegistry.wasSetConfiguratorsCalled(); + if (disableImplicitCensus) { + return streamTracerFactories; } - if (!isGlobalInterceptorsTracersSet && statsEnabled) { + ArrayList tracerFactories = new ArrayList<>(); + if (statsEnabled) { ServerStreamTracer.Factory censusStatsTracerFactory = null; try { Class censusStatsAccessor = @@ -289,7 +302,7 @@ List getTracerFactories() { tracerFactories.add(censusStatsTracerFactory); } } - if (!isGlobalInterceptorsTracersSet && tracingEnabled) { + if (tracingEnabled) { ServerStreamTracer.Factory tracingStreamTracerFactory = null; try { Class censusTracingAccessor = diff --git a/core/src/main/java/io/grpc/internal/ServerStream.java b/core/src/main/java/io/grpc/internal/ServerStream.java index 861d5f36cc7..aa5ba10329c 100644 --- a/core/src/main/java/io/grpc/internal/ServerStream.java +++ b/core/src/main/java/io/grpc/internal/ServerStream.java @@ -96,4 +96,15 @@ public interface ServerStream extends Stream { * The HTTP/2 stream id, or {@code -1} if not supported. */ int streamId(); + + /** + * A hint to the stream that specifies how many bytes must be queued before + * {@link #isReady()} will return false. A stream may ignore this property if + * unsupported. This may only be set during stream initialization before + * any messages are set. + * + * @param numBytes The number of bytes that must be queued. Must be a + * positive integer. + */ + void setOnReadyThreshold(int numBytes); } diff --git a/core/src/main/java/io/grpc/internal/SharedResourceHolder.java b/core/src/main/java/io/grpc/internal/SharedResourceHolder.java index 67d1a98b545..1dfa1f90718 100644 --- a/core/src/main/java/io/grpc/internal/SharedResourceHolder.java +++ b/core/src/main/java/io/grpc/internal/SharedResourceHolder.java @@ -134,18 +134,16 @@ synchronized T releaseInternal(final Resource resource, final T instance) public void run() { synchronized (SharedResourceHolder.this) { // Refcount may have gone up since the task was scheduled. Re-check it. - if (cached.refcount == 0) { - try { - resource.close(instance); - } finally { - instances.remove(resource); - if (instances.isEmpty()) { - destroyer.shutdown(); - destroyer = null; - } - } + if (cached.refcount != 0) { + return; + } + instances.remove(resource); + if (instances.isEmpty()) { + destroyer.shutdown(); + destroyer = null; } } + resource.close(instance); } }), DESTROY_DELAY_SECONDS, TimeUnit.SECONDS); } diff --git a/core/src/main/java/io/grpc/internal/SimpleDisconnectError.java b/core/src/main/java/io/grpc/internal/SimpleDisconnectError.java new file mode 100644 index 00000000000..addbfbe10a3 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/SimpleDisconnectError.java @@ -0,0 +1,68 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import javax.annotation.concurrent.Immutable; + +/** + * Represents a fixed, static reason for disconnection. + */ +@Immutable +public enum SimpleDisconnectError implements DisconnectError { + /** + * The subchannel was shut down for various reasons like parent channel shutdown, + * idleness, or load balancing policy changes. + */ + SUBCHANNEL_SHUTDOWN("subchannel shutdown"), + + /** + * Connection was reset (e.g., ECONNRESET, WSAECONNERESET). + */ + CONNECTION_RESET("connection reset"), + + /** + * Connection timed out (e.g., ETIMEDOUT, WSAETIMEDOUT), including closures + * from gRPC keepalives. + */ + CONNECTION_TIMED_OUT("connection timed out"), + + /** + * Connection was aborted (e.g., ECONNABORTED, WSAECONNABORTED). + */ + CONNECTION_ABORTED("connection aborted"), + + /** + * Any socket error not covered by other specific disconnect errors. + */ + SOCKET_ERROR("socket error"), + + /** + * A catch-all for any other unclassified reason. + */ + UNKNOWN("unknown"); + + private final String errorTag; + + SimpleDisconnectError(String errorTag) { + this.errorTag = errorTag; + } + + @Override + public String toErrorString() { + return this.errorTag; + } +} diff --git a/core/src/main/java/io/grpc/internal/SpiffeUtil.java b/core/src/main/java/io/grpc/internal/SpiffeUtil.java new file mode 100644 index 00000000000..9eafc9950e2 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/SpiffeUtil.java @@ -0,0 +1,312 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.base.Optional; +import com.google.common.base.Splitter; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.io.Files; +import java.io.ByteArrayInputStream; +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.security.cert.Certificate; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; +import java.security.cert.CertificateParsingException; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +/** + * Provides utilities to manage SPIFFE bundles, extract SPIFFE IDs from X.509 certificate chains, + * and parse SPIFFE IDs. + * @see Standard + */ +public final class SpiffeUtil { + + private static final Integer URI_SAN_TYPE = 6; + private static final String USE_PARAMETER_VALUE = "x509-svid"; + private static final ImmutableSet KTY_PARAMETER_VALUES = ImmutableSet.of("RSA", "EC"); + private static final String CERTIFICATE_PREFIX = "-----BEGIN CERTIFICATE-----\n"; + private static final String CERTIFICATE_SUFFIX = "-----END CERTIFICATE-----"; + private static final String PREFIX = "spiffe://"; + + private SpiffeUtil() {} + + /** + * Parses a URI string, applies validation rules described in SPIFFE standard, and, in case of + * success, returns parsed TrustDomain and Path. + * + * @param uri a String representing a SPIFFE ID + */ + public static SpiffeId parse(String uri) { + doInitialUriValidation(uri); + checkArgument(uri.toLowerCase(Locale.US).startsWith(PREFIX), "Spiffe Id must start with " + + PREFIX); + String domainAndPath = uri.substring(PREFIX.length()); + String trustDomain; + String path; + if (!domainAndPath.contains("/")) { + trustDomain = domainAndPath; + path = ""; + } else { + String[] parts = domainAndPath.split("/", 2); + trustDomain = parts[0]; + path = parts[1]; + checkArgument(!path.isEmpty(), "Path must not include a trailing '/'"); + } + validateTrustDomain(trustDomain); + validatePath(path); + if (!path.isEmpty()) { + path = "/" + path; + } + return new SpiffeId(trustDomain, path); + } + + private static void doInitialUriValidation(String uri) { + checkArgument(checkNotNull(uri, "uri").length() > 0, "Spiffe Id can't be empty"); + checkArgument(uri.length() <= 2048, "Spiffe Id maximum length is 2048 characters"); + checkArgument(!uri.contains("#"), "Spiffe Id must not contain query fragments"); + checkArgument(!uri.contains("?"), "Spiffe Id must not contain query parameters"); + } + + private static void validateTrustDomain(String trustDomain) { + checkArgument(!trustDomain.isEmpty(), "Trust Domain can't be empty"); + checkArgument(trustDomain.length() < 256, "Trust Domain maximum length is 255 characters"); + checkArgument(trustDomain.matches("[a-z0-9._-]+"), + "Trust Domain must contain only letters, numbers, dots, dashes, and underscores" + + " ([a-z0-9.-_])"); + } + + private static void validatePath(String path) { + if (path.isEmpty()) { + return; + } + checkArgument(!path.endsWith("/"), "Path must not include a trailing '/'"); + for (String segment : Splitter.on("/").split(path)) { + validatePathSegment(segment); + } + } + + private static void validatePathSegment(String pathSegment) { + checkArgument(!pathSegment.isEmpty(), "Individual path segments must not be empty"); + checkArgument(!(pathSegment.equals(".") || pathSegment.equals("..")), + "Individual path segments must not be relative path modifiers (i.e. ., ..)"); + checkArgument(pathSegment.matches("[a-zA-Z0-9._-]+"), + "Individual path segments must contain only letters, numbers, dots, dashes, and underscores" + + " ([a-zA-Z0-9.-_])"); + } + + /** + * Returns the SPIFFE ID from the leaf certificate, if present. + * + * @param certChain certificate chain to extract SPIFFE ID from + */ + public static Optional extractSpiffeId(X509Certificate[] certChain) + throws CertificateParsingException { + checkArgument(checkNotNull(certChain, "certChain").length > 0, "certChain can't be empty"); + Collection> subjectAltNames = certChain[0].getSubjectAlternativeNames(); + if (subjectAltNames == null) { + return Optional.absent(); + } + String uri = null; + // Search for the unique URI SAN. + for (List altName : subjectAltNames) { + if (altName.size() < 2 ) { + continue; + } + if (URI_SAN_TYPE.equals(altName.get(0))) { + if (uri != null) { + throw new IllegalArgumentException("Multiple URI SAN values found in the leaf cert."); + } + uri = (String) altName.get(1); + } + } + if (uri == null) { + return Optional.absent(); + } + return Optional.of(parse(uri)); + } + + /** + * Loads a SPIFFE trust bundle from a file, parsing it from the JSON format. + * In case of success, returns {@link SpiffeBundle}. + * If any element of the JSON content is invalid or unsupported, an + * {@link IllegalArgumentException} is thrown and the entire Bundle is considered invalid. + * + * @param trustBundleFile the file path to the JSON file containing the trust bundle + * @see JSON format + * @see JWK entry format + * @see x5c (certificate) parameter + */ + public static SpiffeBundle loadTrustBundleFromFile(String trustBundleFile) throws IOException { + Map trustDomainsNode = readTrustDomainsFromFile(trustBundleFile); + Map> trustBundleMap = new HashMap<>(); + Map sequenceNumbers = new HashMap<>(); + for (String trustDomainName : trustDomainsNode.keySet()) { + Map domainNode = JsonUtil.getObject(trustDomainsNode, trustDomainName); + if (domainNode.size() == 0) { + trustBundleMap.put(trustDomainName, Collections.emptyList()); + continue; + } + Long sequenceNumber = JsonUtil.getNumberAsLong(domainNode, "spiffe_sequence"); + sequenceNumbers.put(trustDomainName, sequenceNumber == null ? -1L : sequenceNumber); + List> keysNode = JsonUtil.getListOfObjects(domainNode, "keys"); + if (keysNode == null || keysNode.size() == 0) { + trustBundleMap.put(trustDomainName, Collections.emptyList()); + continue; + } + trustBundleMap.put(trustDomainName, extractCert(keysNode, trustDomainName)); + } + return new SpiffeBundle(sequenceNumbers, trustBundleMap); + } + + private static Map readTrustDomainsFromFile(String filePath) throws IOException { + File file = new File(checkNotNull(filePath, "trustBundleFile")); + String json = new String(Files.toByteArray(file), StandardCharsets.UTF_8); + Object jsonObject = JsonParser.parse(json); + if (!(jsonObject instanceof Map)) { + throw new IllegalArgumentException( + "SPIFFE Trust Bundle should be a JSON object. Found: " + + (jsonObject == null ? null : jsonObject.getClass())); + } + @SuppressWarnings("unchecked") + Map root = (Map)jsonObject; + Map trustDomainsNode = JsonUtil.getObject(root, "trust_domains"); + checkNotNull(trustDomainsNode, "Mandatory trust_domains element is missing"); + checkArgument(trustDomainsNode.size() > 0, "Mandatory trust_domains element is missing"); + return trustDomainsNode; + } + + private static void checkJwkEntry(Map jwkNode, String trustDomainName) { + String kty = JsonUtil.getString(jwkNode, "kty"); + if (kty == null || !KTY_PARAMETER_VALUES.contains(kty)) { + throw new IllegalArgumentException( + String.format( + "'kty' parameter must be one of %s but '%s' " + + "found. Certificate loading for trust domain '%s' failed.", + KTY_PARAMETER_VALUES, kty, trustDomainName)); + } + if (jwkNode.containsKey("kid")) { + throw new IllegalArgumentException(String.format("'kid' parameter must not be set. " + + "Certificate loading for trust domain '%s' failed.", trustDomainName)); + } + String use = JsonUtil.getString(jwkNode, "use"); + if (use == null || !use.equals(USE_PARAMETER_VALUE)) { + throw new IllegalArgumentException(String.format("'use' parameter must be '%s' but '%s' " + + "found. Certificate loading for trust domain '%s' failed.", USE_PARAMETER_VALUE, + use, trustDomainName)); + } + } + + private static List extractCert(List> keysNode, + String trustDomainName) { + List result = new ArrayList<>(); + for (Map keyNode : keysNode) { + checkJwkEntry(keyNode, trustDomainName); + List rawCerts = JsonUtil.getListOfStrings(keyNode, "x5c"); + if (rawCerts == null) { + break; + } + if (rawCerts.size() != 1) { + throw new IllegalArgumentException(String.format("Exactly 1 certificate is expected, but " + + "%s found. Certificate loading for trust domain '%s' failed.", rawCerts.size(), + trustDomainName)); + } + InputStream stream = new ByteArrayInputStream((CERTIFICATE_PREFIX + rawCerts.get(0) + "\n" + + CERTIFICATE_SUFFIX) + .getBytes(StandardCharsets.UTF_8)); + try { + Collection certs = CertificateFactory.getInstance("X509") + .generateCertificates(stream); + X509Certificate[] certsArray = certs.toArray(new X509Certificate[0]); + assert certsArray.length == 1; + result.add(certsArray[0]); + } catch (CertificateException e) { + throw new IllegalArgumentException(String.format("Certificate can't be parsed. Certificate " + + "loading for trust domain '%s' failed.", trustDomainName), e); + } + } + return result; + } + + /** + * Represents a SPIFFE ID as defined in the SPIFFE standard. + * @see Standard + */ + public static class SpiffeId { + + private final String trustDomain; + private final String path; + + private SpiffeId(String trustDomain, String path) { + this.trustDomain = trustDomain; + this.path = path; + } + + public String getTrustDomain() { + return trustDomain; + } + + public String getPath() { + return path; + } + } + + /** + * Represents a SPIFFE trust bundle; that is, a map from trust domain to set of trusted + * certificates. Only trust domain's sequence numbers and x509 certificates are supported. + * @see Standard + */ + public static final class SpiffeBundle { + + private final ImmutableMap sequenceNumbers; + + private final ImmutableMap> bundleMap; + + private SpiffeBundle(Map sequenceNumbers, + Map> trustDomainMap) { + this.sequenceNumbers = ImmutableMap.copyOf(sequenceNumbers); + ImmutableMap.Builder> builder = ImmutableMap.builder(); + for (Map.Entry> entry : trustDomainMap.entrySet()) { + builder.put(entry.getKey(), ImmutableList.copyOf(entry.getValue())); + } + this.bundleMap = builder.build(); + } + + public ImmutableMap getSequenceNumbers() { + return sequenceNumbers; + } + + public ImmutableMap> getBundleMap() { + return bundleMap; + } + } + +} diff --git a/core/src/main/java/io/grpc/internal/StatsTraceContext.java b/core/src/main/java/io/grpc/internal/StatsTraceContext.java index 889be30e712..007aefc0fb8 100644 --- a/core/src/main/java/io/grpc/internal/StatsTraceContext.java +++ b/core/src/main/java/io/grpc/internal/StatsTraceContext.java @@ -23,6 +23,7 @@ import io.grpc.ClientStreamTracer; import io.grpc.Context; import io.grpc.Metadata; +import io.grpc.MethodDescriptor; import io.grpc.ServerStreamTracer; import io.grpc.ServerStreamTracer.ServerCallInfo; import io.grpc.Status; @@ -38,6 +39,14 @@ */ @ThreadSafe public final class StatsTraceContext { + /** + * Internal hook for server tracers that can use the resolved method descriptor before + * {@link ServerStreamTracer#serverCallStarted(ServerCallInfo)} runs. + */ + public interface ServerCallMethodListener { + void serverCallMethodResolved(MethodDescriptor method); + } + public static final StatsTraceContext NOOP = new StatsTraceContext(new StreamTracer[0]); private final StreamTracer[] tracers; @@ -101,9 +110,9 @@ public void clientOutboundHeaders() { * *

Called from abstract stream implementations. */ - public void clientInboundHeaders() { + public void clientInboundHeaders(Metadata headers) { for (StreamTracer tracer : tracers) { - ((ClientStreamTracer) tracer).inboundHeaders(); + ((ClientStreamTracer) tracer).inboundHeaders(headers); } } @@ -144,6 +153,20 @@ public void serverCallStarted(ServerCallInfo callInfo) { } } + /** + * Notifies server tracers that a primary-registry method descriptor was resolved before + * {@link ServerStreamTracer#serverCallStarted(ServerCallInfo)}. + * + *

Called from {@link io.grpc.internal.ServerImpl}. + */ + public void serverCallMethodResolved(MethodDescriptor method) { + for (StreamTracer tracer : tracers) { + if (tracer instanceof ServerCallMethodListener) { + ((ServerCallMethodListener) tracer).serverCallMethodResolved(method); + } + } + } + /** * See {@link StreamTracer#streamClosed}. This may be called multiple times, and only the first * value will be taken. diff --git a/core/src/main/java/io/grpc/internal/SubchannelChannel.java b/core/src/main/java/io/grpc/internal/SubchannelChannel.java index 773dcb99dd7..ced4272afe3 100644 --- a/core/src/main/java/io/grpc/internal/SubchannelChannel.java +++ b/core/src/main/java/io/grpc/internal/SubchannelChannel.java @@ -59,7 +59,8 @@ public ClientStream newStream(MethodDescriptor method, transport = notReadyTransport; } ClientStreamTracer[] tracers = GrpcUtil.getClientStreamTracers( - callOptions, headers, 0, /* isTransparentRetry= */ false); + callOptions, headers, 0, /* isTransparentRetry= */ false, + /* isHedging= */ false); Context origContext = context.attach(); try { return transport.newStream(method, headers, callOptions, tracers); diff --git a/core/src/main/java/io/grpc/internal/SubchannelMetrics.java b/core/src/main/java/io/grpc/internal/SubchannelMetrics.java new file mode 100644 index 00000000000..4bc2cf47046 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/SubchannelMetrics.java @@ -0,0 +1,108 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import io.grpc.LongCounterMetricInstrument; +import io.grpc.LongUpDownCounterMetricInstrument; +import io.grpc.MetricInstrumentRegistry; +import io.grpc.MetricRecorder; + +final class SubchannelMetrics { + + private static final LongCounterMetricInstrument disconnections; + private static final LongCounterMetricInstrument connectionAttemptsSucceeded; + private static final LongCounterMetricInstrument connectionAttemptsFailed; + private static final LongUpDownCounterMetricInstrument openConnections; + private final MetricRecorder metricRecorder; + + public SubchannelMetrics(MetricRecorder metricRecorder) { + this.metricRecorder = metricRecorder; + } + + static { + MetricInstrumentRegistry metricInstrumentRegistry + = MetricInstrumentRegistry.getDefaultRegistry(); + disconnections = metricInstrumentRegistry.registerLongCounter( + "grpc.subchannel.disconnections", + "EXPERIMENTAL. Number of times the selected subchannel becomes disconnected", + "{disconnection}", + Lists.newArrayList("grpc.target"), + Lists.newArrayList("grpc.lb.backend_service", "grpc.lb.locality", "grpc.disconnect_error"), + false + ); + + connectionAttemptsSucceeded = metricInstrumentRegistry.registerLongCounter( + "grpc.subchannel.connection_attempts_succeeded", + "EXPERIMENTAL. Number of successful connection attempts", + "{attempt}", + Lists.newArrayList("grpc.target"), + Lists.newArrayList("grpc.lb.backend_service", "grpc.lb.locality"), + false + ); + + connectionAttemptsFailed = metricInstrumentRegistry.registerLongCounter( + "grpc.subchannel.connection_attempts_failed", + "EXPERIMENTAL. Number of failed connection attempts", + "{attempt}", + Lists.newArrayList("grpc.target"), + Lists.newArrayList("grpc.lb.backend_service", "grpc.lb.locality"), + false + ); + + openConnections = metricInstrumentRegistry.registerLongUpDownCounter( + "grpc.subchannel.open_connections", + "EXPERIMENTAL. Number of open connections.", + "{connection}", + Lists.newArrayList("grpc.target"), + Lists.newArrayList("grpc.security_level", "grpc.lb.backend_service", "grpc.lb.locality"), + false + ); + } + + public void recordConnectionAttemptSucceeded(String target, String backendService, + String locality, String securityLevel) { + metricRecorder + .addLongCounter(connectionAttemptsSucceeded, 1, + ImmutableList.of(target), + ImmutableList.of(backendService, locality)); + metricRecorder + .addLongUpDownCounter(openConnections, 1, + ImmutableList.of(target), + ImmutableList.of(securityLevel, backendService, locality)); + } + + public void recordConnectionAttemptFailed(String target, String backendService, String locality) { + metricRecorder + .addLongCounter(connectionAttemptsFailed, 1, + ImmutableList.of(target), + ImmutableList.of(backendService, locality)); + } + + public void recordDisconnection(String target, String backendService, String locality, + String disconnectError, String securityLevel) { + metricRecorder + .addLongCounter(disconnections, 1, + ImmutableList.of(target), + ImmutableList.of(backendService, locality, disconnectError)); + metricRecorder + .addLongUpDownCounter(openConnections, -1, + ImmutableList.of(target), + ImmutableList.of(securityLevel, backendService, locality)); + } +} diff --git a/core/src/main/java/io/grpc/internal/TimeProvider.java b/core/src/main/java/io/grpc/internal/TimeProvider.java index b0ea147ada1..3bd052ab3e0 100644 --- a/core/src/main/java/io/grpc/internal/TimeProvider.java +++ b/core/src/main/java/io/grpc/internal/TimeProvider.java @@ -16,8 +16,6 @@ package io.grpc.internal; -import java.util.concurrent.TimeUnit; - /** * Time source representing the current system time in nanos. Used to inject a fake clock * into unit tests. @@ -26,10 +24,5 @@ public interface TimeProvider { /** Returns the current nano time. */ long currentTimeNanos(); - TimeProvider SYSTEM_TIME_PROVIDER = new TimeProvider() { - @Override - public long currentTimeNanos() { - return TimeUnit.MILLISECONDS.toNanos(System.currentTimeMillis()); - } - }; + TimeProvider SYSTEM_TIME_PROVIDER = TimeProviderResolverFactory.resolveTimeProvider(); } diff --git a/core/src/main/java/io/grpc/internal/TimeProviderResolverFactory.java b/core/src/main/java/io/grpc/internal/TimeProviderResolverFactory.java new file mode 100644 index 00000000000..04272034ce9 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/TimeProviderResolverFactory.java @@ -0,0 +1,32 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +/** + * {@link TimeProviderResolverFactory} resolves Time providers. + */ + +final class TimeProviderResolverFactory { + static TimeProvider resolveTimeProvider() { + try { + Class.forName("java.time.Instant"); + return new InstantTimeProvider(); + } catch (ClassNotFoundException ex) { + return new ConcurrentTimeProvider(); + } + } +} diff --git a/core/src/main/java/io/grpc/internal/TransportFrameUtil.java b/core/src/main/java/io/grpc/internal/TransportFrameUtil.java index 51854410843..3bd7ee72239 100644 --- a/core/src/main/java/io/grpc/internal/TransportFrameUtil.java +++ b/core/src/main/java/io/grpc/internal/TransportFrameUtil.java @@ -16,16 +16,16 @@ package io.grpc.internal; -import static com.google.common.base.Charsets.US_ASCII; +import static java.nio.charset.StandardCharsets.US_ASCII; import com.google.common.io.BaseEncoding; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.InternalMetadata; import io.grpc.Metadata; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.logging.Logger; -import javax.annotation.CheckReturnValue; /** * Utility functions for transport layer framing. diff --git a/core/src/main/java/io/grpc/internal/UriWrapper.java b/core/src/main/java/io/grpc/internal/UriWrapper.java new file mode 100644 index 00000000000..ca5835cabd8 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/UriWrapper.java @@ -0,0 +1,139 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static com.google.common.base.Preconditions.checkNotNull; + +import io.grpc.NameResolver; +import io.grpc.Uri; +import java.net.URI; +import javax.annotation.Nullable; + +/** Temporary wrapper for a URI-like object to ease the migration to io.grpc.Uri. */ +interface UriWrapper { + + static UriWrapper wrap(URI uri) { + return new JavaNetUriWrapper(uri); + } + + static UriWrapper wrap(Uri uri) { + return new IoGrpcUriWrapper(uri); + } + + /** Uses the given factory and args to create a {@link NameResolver} for this URI. */ + NameResolver newNameResolver(NameResolver.Factory factory, NameResolver.Args args); + + /** Returns the scheme component of this URI, e.g. "http", "mailto" or "dns". */ + String getScheme(); + + /** + * Returns the authority component of this URI, e.g. "google.com", "127.0.0.1:8080", or null if + * not present. + */ + @Nullable + String getAuthority(); + + /** Wraps an instance of java.net.URI. */ + final class JavaNetUriWrapper implements UriWrapper { + private final URI uri; + + private JavaNetUriWrapper(URI uri) { + this.uri = checkNotNull(uri); + } + + @Override + public NameResolver newNameResolver(NameResolver.Factory factory, NameResolver.Args args) { + return factory.newNameResolver(uri, args); + } + + @Override + public String getScheme() { + return uri.getScheme(); + } + + @Override + public String getAuthority() { + return uri.getAuthority(); + } + + @Override + public String toString() { + return uri.toString(); + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } + if (!(other instanceof JavaNetUriWrapper)) { + return false; + } + return uri.equals(((JavaNetUriWrapper) other).uri); + } + + @Override + public int hashCode() { + return uri.hashCode(); + } + } + + /** Wraps an instance of io.grpc.Uri. */ + final class IoGrpcUriWrapper implements UriWrapper { + private final Uri uri; + + private IoGrpcUriWrapper(Uri uri) { + this.uri = checkNotNull(uri); + } + + @Override + public NameResolver newNameResolver(NameResolver.Factory factory, NameResolver.Args args) { + return factory.newNameResolver(uri, args); + } + + @Override + public String getScheme() { + return uri.getScheme(); + } + + @Override + public String getAuthority() { + return uri.getAuthority(); + } + + @Override + public String toString() { + return uri.toString(); + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } + if (!(other instanceof IoGrpcUriWrapper)) { + return false; + } + return uri.equals(((IoGrpcUriWrapper) other).uri); + } + + @Override + public int hashCode() { + return uri.hashCode(); + } + } +} diff --git a/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java b/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java index 4ce8a467d9f..8f14b74035c 100644 --- a/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java @@ -20,8 +20,10 @@ import static io.grpc.internal.ClientStreamListener.RpcProgress.PROCESSED; import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.AdditionalAnswers.delegatesTo; @@ -56,7 +58,6 @@ import org.junit.Before; import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; @@ -75,8 +76,6 @@ public class AbstractClientStreamTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); private final StatsTraceContext statsTraceCtx = StatsTraceContext.NOOP; private final TransportTracer transportTracer = new TransportTracer(); @@ -135,9 +134,7 @@ public void cancel_failsOnNull() { AbstractClientStream stream = new BaseAbstractClientStream(allocator, statsTraceCtx, transportTracer); stream.start(listener); - thrown.expect(NullPointerException.class); - - stream.cancel(null); + assertThrows(NullPointerException.class, () -> stream.cancel(null)); } @Test @@ -163,9 +160,7 @@ public void startFailsOnNullListener() { AbstractClientStream stream = new BaseAbstractClientStream(allocator, statsTraceCtx, transportTracer); - thrown.expect(NullPointerException.class); - - stream.start(null); + assertThrows(NullPointerException.class, () -> stream.start(null)); } @Test @@ -173,9 +168,7 @@ public void cantCallStartTwice() { AbstractClientStream stream = new BaseAbstractClientStream(allocator, statsTraceCtx, transportTracer); stream.start(mockListener); - thrown.expect(IllegalStateException.class); - - stream.start(mockListener); + assertThrows(IllegalStateException.class, () -> stream.start(mockListener)); } @Test @@ -187,8 +180,7 @@ public void inboundDataReceived_failsOnNullFrame() { TransportState state = stream.transportState(); - thrown.expect(NullPointerException.class); - state.inboundDataReceived(null); + assertThrows(NullPointerException.class, () -> state.inboundDataReceived(null)); } @Test @@ -211,8 +203,8 @@ public void inboundHeadersReceived_failsIfStatusReported() { TransportState state = stream.transportState(); - thrown.expect(IllegalStateException.class); - state.inboundHeadersReceived(new Metadata()); + Metadata headers = new Metadata(); + assertThrows(IllegalStateException.class, () -> state.inboundHeadersReceived(headers)); } @Test @@ -473,6 +465,24 @@ allocator, new BaseTransportState(statsTraceCtx, transportTracer), sink, statsTr .isGreaterThan(TimeUnit.MILLISECONDS.toNanos(600)); } + @Test + public void setDeadline_thePastBecomesPositive() { + AbstractClientStream.Sink sink = mock(AbstractClientStream.Sink.class); + ClientStream stream = new BaseAbstractClientStream( + allocator, new BaseTransportState(statsTraceCtx, transportTracer), sink, statsTraceCtx, + transportTracer); + + stream.setDeadline(Deadline.after(-1, TimeUnit.NANOSECONDS)); + stream.start(mockListener); + + ArgumentCaptor headersCaptor = ArgumentCaptor.forClass(Metadata.class); + verify(sink).writeHeaders(headersCaptor.capture(), ArgumentMatchers.any()); + + Metadata headers = headersCaptor.getValue(); + assertThat(headers.get(Metadata.Key.of("grpc-timeout", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo("1n"); + } + @Test public void appendTimeoutInsight() { InsightBuilder insight = new InsightBuilder(); @@ -482,6 +492,41 @@ public void appendTimeoutInsight() { assertThat(insight.toString()).isEqualTo("[remote_addr=fake_server_addr]"); } + @Test + public void overrideOnReadyThreshold() { + AbstractClientStream.Sink sink = mock(AbstractClientStream.Sink.class); + BaseTransportState state = new BaseTransportState(statsTraceCtx, transportTracer); + AbstractClientStream stream = new BaseAbstractClientStream( + allocator, + state, + sink, + statsTraceCtx, + transportTracer, + CallOptions.DEFAULT.withOnReadyThreshold(10), + true); + ClientStreamListener listener = new NoopClientStreamListener(); + stream.start(listener); + state.onStreamAllocated(); + + // Stream should be ready. 0 bytes are queued. + assertTrue(stream.isReady()); + + // Queue some bytes above the custom threshold and check that the stream is not ready. + stream.onSendingBytes(100); + assertFalse(stream.isReady()); + + // Simulate a flush and verify ready now. + stream.transportState().onSentBytes(91); + assertTrue(stream.isReady()); + } + + @Test + public void resetOnReadyThreshold() { + CallOptions options = CallOptions.DEFAULT.withOnReadyThreshold(10); + assertEquals(Integer.valueOf(10), options.getOnReadyThreshold()); + assertNull(options.clearOnReadyThreshold().getOnReadyThreshold()); + } + /** * No-op base class for testing. */ @@ -517,9 +562,23 @@ public BaseAbstractClientStream( StatsTraceContext statsTraceCtx, TransportTracer transportTracer, boolean useGet) { - super(allocator, statsTraceCtx, transportTracer, new Metadata(), CallOptions.DEFAULT, useGet); + this(allocator, state, sink, statsTraceCtx, transportTracer, CallOptions.DEFAULT, useGet); + } + + public BaseAbstractClientStream( + WritableBufferAllocator allocator, + TransportState state, + Sink sink, + StatsTraceContext statsTraceCtx, + TransportTracer transportTracer, + CallOptions callOptions, + boolean useGet) { + super(allocator, statsTraceCtx, transportTracer, new Metadata(), callOptions, useGet); this.state = state; this.sink = sink; + if (callOptions.getOnReadyThreshold() != null) { + this.transportState().setOnReadyThreshold(callOptions.getOnReadyThreshold()); + } } @Override @@ -567,7 +626,7 @@ private Throwable getDeframeFailedCause() { } public BaseTransportState(StatsTraceContext statsTraceCtx, TransportTracer transportTracer) { - super(DEFAULT_MAX_MESSAGE_SIZE, statsTraceCtx, transportTracer); + super(DEFAULT_MAX_MESSAGE_SIZE, statsTraceCtx, transportTracer, CallOptions.DEFAULT); } @Override diff --git a/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java b/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java index 618af766c08..137ba19bfea 100644 --- a/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java @@ -18,6 +18,8 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; @@ -44,9 +46,7 @@ import java.util.Queue; import java.util.concurrent.TimeUnit; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; @@ -59,9 +59,6 @@ public class AbstractServerStreamTest { private static final int TIMEOUT_MS = 1000; private static final int MAX_MESSAGE_SIZE = 100; - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); - private final WritableBufferAllocator allocator = new WritableBufferAllocator() { @Override public WritableBuffer allocate(int capacityHint) { @@ -84,7 +81,7 @@ public void setUp() { } /** - * Test for issue https://github.com/grpc/grpc-java/issues/1795 + * Test for issue https://github.com/grpc/grpc-java/issues/1795 . */ @Test public void frameShouldBeIgnoredAfterDeframerClosed() { @@ -211,7 +208,7 @@ public void closed(Status status) { } /** - * Test for issue https://github.com/grpc/grpc-java/issues/615 + * Test for issue https://github.com/grpc/grpc-java/issues/615 . */ @Test public void completeWithoutClose() { @@ -225,9 +222,9 @@ public void completeWithoutClose() { public void setListener_setOnlyOnce() { TransportState state = stream.transportState(); state.setListener(new ServerStreamListenerBase()); - thrown.expect(IllegalStateException.class); - state.setListener(new ServerStreamListenerBase()); + ServerStreamListenerBase listener2 = new ServerStreamListenerBase(); + assertThrows(IllegalStateException.class, () -> state.setListener(listener2)); } @Test @@ -237,8 +234,7 @@ public void listenerReady_onlyOnce() { TransportState state = stream.transportState(); - thrown.expect(IllegalStateException.class); - state.onStreamAllocated(); + assertThrows(IllegalStateException.class, state::onStreamAllocated); } @Test @@ -254,8 +250,7 @@ public void listenerReady_readyCalled() { public void setListener_failsOnNull() { TransportState state = stream.transportState(); - thrown.expect(NullPointerException.class); - state.setListener(null); + assertThrows(NullPointerException.class, () -> state.setListener(null)); } // TODO(ericgribkoff) This test is only valid if deframeInTransportThread=true, as otherwise the @@ -283,9 +278,7 @@ public void messagesAvailable(MessageProducer producer) { @Test public void writeHeaders_failsOnNullHeaders() { - thrown.expect(NullPointerException.class); - - stream.writeHeaders(null, true); + assertThrows(NullPointerException.class, () -> stream.writeHeaders(null, true)); } @Test @@ -335,16 +328,13 @@ public void writeMessage_closesStream() throws Exception { @Test public void close_failsOnNullStatus() { - thrown.expect(NullPointerException.class); - - stream.close(null, new Metadata()); + Metadata trailers = new Metadata(); + assertThrows(NullPointerException.class, () -> stream.close(null, trailers)); } @Test public void close_failsOnNullMetadata() { - thrown.expect(NullPointerException.class); - - stream.close(Status.INTERNAL, null); + assertThrows(NullPointerException.class, () -> stream.close(Status.INTERNAL, null)); } @Test @@ -371,6 +361,15 @@ public void close_sendTrailersClearsReservedFields() { assertEquals("bad", metadataCaptor.getValue().get(InternalStatus.MESSAGE_KEY)); } + @Test + public void changeOnReadyThreshold() { + stream.setListener(new ServerStreamListenerBase()); + stream.transportState().onStreamAllocated(); + stream.setOnReadyThreshold(Integer.MAX_VALUE); + stream.onSendingBytes(Integer.MAX_VALUE - 1); + assertTrue(stream.isReady()); + } + private static class ServerStreamListenerBase implements ServerStreamListener { @Override public void messagesAvailable(MessageProducer producer) { @@ -441,4 +440,3 @@ public int streamId() { } } } - diff --git a/core/src/test/java/io/grpc/internal/AutoConfiguredLoadBalancerFactoryTest.java b/core/src/test/java/io/grpc/internal/AutoConfiguredLoadBalancerFactoryTest.java index 3bd98d39d31..07d19d41a86 100644 --- a/core/src/test/java/io/grpc/internal/AutoConfiguredLoadBalancerFactoryTest.java +++ b/core/src/test/java/io/grpc/internal/AutoConfiguredLoadBalancerFactoryTest.java @@ -48,7 +48,9 @@ import io.grpc.LoadBalancerRegistry; import io.grpc.NameResolver.ConfigOrError; import io.grpc.Status; +import io.grpc.SynchronizationContext; import io.grpc.internal.AutoConfiguredLoadBalancerFactory.AutoConfiguredLoadBalancer; +import io.grpc.internal.PickFirstLeafLoadBalancer.PickFirstLeafLoadBalancerConfig; import io.grpc.internal.PickFirstLoadBalancer.PickFirstLoadBalancerConfig; import io.grpc.internal.ServiceConfigUtil.PolicySelection; import io.grpc.util.ForwardingLoadBalancerHelper; @@ -57,6 +59,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; @@ -95,6 +98,11 @@ public class AutoConfiguredLoadBalancerFactoryTest { delegatesTo( new FakeLoadBalancerProvider("test_lb2", testLbBalancer2, nextParsedConfigOrError2))); + private final Class pfLbClass = + PickFirstLoadBalancerProvider.isEnabledNewPickFirst() + ? PickFirstLeafLoadBalancer.class + : PickFirstLoadBalancer.class; + @Before public void setUp() { when(testLbBalancer.acceptResolvedAddresses(isA(ResolvedAddresses.class))).thenReturn( @@ -185,7 +193,7 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { AutoConfiguredLoadBalancer lb = lbf.newLoadBalancer(helper); LoadBalancer oldDelegate = lb.getDelegate(); - Status addressAcceptanceStatus = lb.tryAcceptResolvedAddresses( + Status addressAcceptanceStatus = lb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setAttributes(Attributes.EMPTY) @@ -200,7 +208,7 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { public void acceptResolvedAddresses_shutsDownOldBalancer() throws Exception { Map serviceConfig = parseConfig("{\"loadBalancingConfig\": [ {\"round_robin\": { } } ] }"); - ConfigOrError lbConfigs = lbf.parseLoadBalancerPolicy(serviceConfig); + ConfigOrError lbConfigs = lbf.parseLoadBalancingPolicyConfig(serviceConfig); final List servers = Collections.singletonList(new EquivalentAddressGroup(new SocketAddress(){})); @@ -227,7 +235,7 @@ public void shutdown() { }; lb.setDelegate(testlb); - Status addressesAcceptanceStatus = lb.tryAcceptResolvedAddresses( + Status addressesAcceptanceStatus = lb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setLoadBalancingPolicyConfig(lbConfigs.getConfig()) @@ -244,7 +252,7 @@ public void shutdown() { public void acceptResolvedAddresses_propagateLbConfigToDelegate() throws Exception { Map rawServiceConfig = parseConfig("{\"loadBalancingConfig\": [ {\"test_lb\": { \"setting1\": \"high\" } } ] }"); - ConfigOrError lbConfigs = lbf.parseLoadBalancerPolicy(rawServiceConfig); + ConfigOrError lbConfigs = lbf.parseLoadBalancingPolicyConfig(rawServiceConfig); assertThat(lbConfigs.getConfig()).isNotNull(); final List servers = @@ -252,7 +260,7 @@ public void acceptResolvedAddresses_propagateLbConfigToDelegate() throws Excepti Helper helper = new TestHelper(); AutoConfiguredLoadBalancer lb = lbf.newLoadBalancer(helper); - Status addressesAcceptanceStatus = lb.tryAcceptResolvedAddresses( + Status addressesAcceptanceStatus = lb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setLoadBalancingPolicyConfig(lbConfigs.getConfig()) @@ -272,9 +280,9 @@ public void acceptResolvedAddresses_propagateLbConfigToDelegate() throws Excepti rawServiceConfig = parseConfig("{\"loadBalancingConfig\": [ {\"test_lb\": { \"setting1\": \"low\" } } ] }"); - lbConfigs = lbf.parseLoadBalancerPolicy(rawServiceConfig); + lbConfigs = lbf.parseLoadBalancingPolicyConfig(rawServiceConfig); - addressesAcceptanceStatus = lb.tryAcceptResolvedAddresses( + addressesAcceptanceStatus = lb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setLoadBalancingPolicyConfig(lbConfigs.getConfig()) @@ -297,7 +305,7 @@ public void acceptResolvedAddresses_propagateLbConfigToDelegate() throws Excepti public void acceptResolvedAddresses_propagateAddrsToDelegate() throws Exception { Map rawServiceConfig = parseConfig("{\"loadBalancingConfig\": [ {\"test_lb\": { \"setting1\": \"high\" } } ] }"); - ConfigOrError lbConfigs = lbf.parseLoadBalancerPolicy(rawServiceConfig); + ConfigOrError lbConfigs = lbf.parseLoadBalancingPolicyConfig(rawServiceConfig); assertThat(lbConfigs.getConfig()).isNotNull(); Helper helper = new TestHelper(); @@ -305,7 +313,7 @@ public void acceptResolvedAddresses_propagateAddrsToDelegate() throws Exception List servers = Collections.singletonList(new EquivalentAddressGroup(new InetSocketAddress(8080){})); - Status addressesAcceptanceStatus = lb.tryAcceptResolvedAddresses( + Status addressesAcceptanceStatus = lb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setLoadBalancingPolicyConfig(lbConfigs.getConfig()) @@ -321,7 +329,7 @@ public void acceptResolvedAddresses_propagateAddrsToDelegate() throws Exception servers = Collections.singletonList(new EquivalentAddressGroup(new InetSocketAddress(9090){})); - addressesAcceptanceStatus = lb.tryAcceptResolvedAddresses( + addressesAcceptanceStatus = lb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setLoadBalancingPolicyConfig(lbConfigs.getConfig()) @@ -345,8 +353,8 @@ public void acceptResolvedAddresses_delegateDoNotAcceptEmptyAddressList_nothing( Map serviceConfig = parseConfig("{\"loadBalancingConfig\": [ {\"test_lb\": { \"setting1\": \"high\" } } ] }"); - ConfigOrError lbConfig = lbf.parseLoadBalancerPolicy(serviceConfig); - Status addressesAcceptanceStatus = lb.tryAcceptResolvedAddresses( + ConfigOrError lbConfig = lbf.parseLoadBalancingPolicyConfig(serviceConfig); + Status addressesAcceptanceStatus = lb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(Collections.emptyList()) .setLoadBalancingPolicyConfig(lbConfig.getConfig()) @@ -365,8 +373,8 @@ public void acceptResolvedAddresses_delegateAcceptsEmptyAddressList() Map rawServiceConfig = parseConfig("{\"loadBalancingConfig\": [ {\"test_lb2\": { \"setting1\": \"high\" } } ] }"); ConfigOrError lbConfigs = - lbf.parseLoadBalancerPolicy(rawServiceConfig); - Status addressesAcceptanceStatus = lb.tryAcceptResolvedAddresses( + lbf.parseLoadBalancingPolicyConfig(rawServiceConfig); + Status addressesAcceptanceStatus = lb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(Collections.emptyList()) .setLoadBalancingPolicyConfig(lbConfigs.getConfig()) @@ -386,7 +394,7 @@ public void acceptResolvedAddresses_delegateAcceptsEmptyAddressList() public void acceptResolvedAddresses_useSelectedLbPolicy() throws Exception { Map rawServiceConfig = parseConfig("{\"loadBalancingConfig\": [{\"round_robin\": {}}]}"); - ConfigOrError lbConfigs = lbf.parseLoadBalancerPolicy(rawServiceConfig); + ConfigOrError lbConfigs = lbf.parseLoadBalancingPolicyConfig(rawServiceConfig); assertThat(lbConfigs.getConfig()).isNotNull(); assertThat(((PolicySelection) lbConfigs.getConfig()).provider.getClass().getName()) .isEqualTo("io.grpc.util.SecretRoundRobinLoadBalancerProvider$Provider"); @@ -401,7 +409,7 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { } }; AutoConfiguredLoadBalancer lb = lbf.newLoadBalancer(helper); - Status addressesAcceptanceStatus = lb.tryAcceptResolvedAddresses( + Status addressesAcceptanceStatus = lb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setLoadBalancingPolicyConfig(lbConfigs.getConfig()) @@ -423,13 +431,13 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { } }; AutoConfiguredLoadBalancer lb = lbf.newLoadBalancer(helper); - Status addressesAcceptanceStatus = lb.tryAcceptResolvedAddresses( + Status addressesAcceptanceStatus = lb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setLoadBalancingPolicyConfig(null) .build()); assertThat(addressesAcceptanceStatus.isOk()).isTrue(); - assertThat(lb.getDelegate()).isInstanceOf(PickFirstLoadBalancer.class); + assertThat(lb.getDelegate()).isInstanceOf(pfLbClass); } @Test @@ -438,7 +446,7 @@ public void acceptResolvedAddresses_noLbPolicySelected_defaultToCustomDefault() .newLoadBalancer(new TestHelper()); List servers = Collections.singletonList(new EquivalentAddressGroup(new SocketAddress(){})); - Status addressesAcceptanceStatus = lb.tryAcceptResolvedAddresses( + Status addressesAcceptanceStatus = lb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setLoadBalancingPolicyConfig(null) @@ -460,7 +468,7 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { AutoConfiguredLoadBalancer lb = new AutoConfiguredLoadBalancerFactory(GrpcUtil.DEFAULT_LB_POLICY).newLoadBalancer(helper); - Status addressesAcceptanceStatus = lb.tryAcceptResolvedAddresses( + Status addressesAcceptanceStatus = lb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setAttributes(Attributes.EMPTY) @@ -473,8 +481,8 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { nextParsedConfigOrError.set(testLbParsedConfig); Map serviceConfig = parseConfig("{\"loadBalancingConfig\": [ {\"test_lb\": { } } ] }"); - ConfigOrError lbConfigs = lbf.parseLoadBalancerPolicy(serviceConfig); - addressesAcceptanceStatus = lb.tryAcceptResolvedAddresses( + ConfigOrError lbConfigs = lbf.parseLoadBalancingPolicyConfig(serviceConfig); + addressesAcceptanceStatus = lb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setLoadBalancingPolicyConfig(lbConfigs.getConfig()) @@ -484,7 +492,7 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { verify(channelLogger).log( eq(ChannelLogLevel.INFO), eq("Load balancer changed from {0} to {1}"), - eq("PickFirstLoadBalancer"), + eq(pfLbClass.getSimpleName()), eq(testLbBalancer.getClass().getSimpleName())); verify(channelLogger).log( @@ -496,8 +504,8 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { testLbParsedConfig = ConfigOrError.fromConfig("bar"); nextParsedConfigOrError.set(testLbParsedConfig); serviceConfig = parseConfig("{\"loadBalancingConfig\": [ {\"test_lb\": { } } ] }"); - lbConfigs = lbf.parseLoadBalancerPolicy(serviceConfig); - addressesAcceptanceStatus = lb.tryAcceptResolvedAddresses( + lbConfigs = lbf.parseLoadBalancingPolicyConfig(serviceConfig); + addressesAcceptanceStatus = lb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers) .setLoadBalancingPolicyConfig(lbConfigs.getConfig()) @@ -511,33 +519,33 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { } @Test - public void parseLoadBalancerConfig_failedOnUnknown() throws Exception { + public void parseLoadBalancingConfig_failedOnUnknown() throws Exception { Map serviceConfig = parseConfig("{\"loadBalancingConfig\": [ {\"magic_balancer\": {} } ] }"); - ConfigOrError parsed = lbf.parseLoadBalancerPolicy(serviceConfig); + ConfigOrError parsed = lbf.parseLoadBalancingPolicyConfig(serviceConfig); assertThat(parsed.getError()).isNotNull(); assertThat(parsed.getError().getDescription()) .isEqualTo("None of [magic_balancer] specified by Service Config are available."); } @Test - public void parseLoadBalancerPolicy_failedOnUnknown() throws Exception { + public void parseLoadBalancingPolicy_failedOnUnknown() throws Exception { Map serviceConfig = parseConfig("{\"loadBalancingPolicy\": \"magic_balancer\"}"); - ConfigOrError parsed = lbf.parseLoadBalancerPolicy(serviceConfig); + ConfigOrError parsed = lbf.parseLoadBalancingPolicyConfig(serviceConfig); assertThat(parsed.getError()).isNotNull(); assertThat(parsed.getError().getDescription()) .isEqualTo("None of [magic_balancer] specified by Service Config are available."); } @Test - public void parseLoadBalancerConfig_multipleValidPolicies() throws Exception { + public void parseLoadBalancingConfig_multipleValidPolicies() throws Exception { Map serviceConfig = parseConfig( "{\"loadBalancingConfig\": [" + "{\"round_robin\": {}}," + "{\"test_lb\": {} } ] }"); - ConfigOrError parsed = lbf.parseLoadBalancerPolicy(serviceConfig); + ConfigOrError parsed = lbf.parseLoadBalancingPolicyConfig(serviceConfig); assertThat(parsed).isNotNull(); assertThat(parsed.getError()).isNull(); assertThat(parsed.getConfig()).isInstanceOf(PolicySelection.class); @@ -546,12 +554,12 @@ public void parseLoadBalancerConfig_multipleValidPolicies() throws Exception { } @Test - public void parseLoadBalancerConfig_policyShouldBeIgnoredIfConfigExists() throws Exception { + public void parseLoadBalancingConfig_policyShouldBeIgnoredIfConfigExists() throws Exception { Map serviceConfig = parseConfig( "{\"loadBalancingConfig\": [{\"round_robin\": {} } ]," + "\"loadBalancingPolicy\": \"pick_first\" }"); - ConfigOrError parsed = lbf.parseLoadBalancerPolicy(serviceConfig); + ConfigOrError parsed = lbf.parseLoadBalancingPolicyConfig(serviceConfig); assertThat(parsed).isNotNull(); assertThat(parsed.getError()).isNull(); assertThat(parsed.getConfig()).isInstanceOf(PolicySelection.class); @@ -560,13 +568,13 @@ public void parseLoadBalancerConfig_policyShouldBeIgnoredIfConfigExists() throws } @Test - public void parseLoadBalancerConfig_policyShouldBeIgnoredEvenIfUnknownPolicyExists() + public void parseLoadBalancingConfig_policyShouldBeIgnoredEvenIfUnknownPolicyExists() throws Exception { Map serviceConfig = parseConfig( "{\"loadBalancingConfig\": [{\"magic_balancer\": {} } ]," + "\"loadBalancingPolicy\": \"round_robin\" }"); - ConfigOrError parsed = lbf.parseLoadBalancerPolicy(serviceConfig); + ConfigOrError parsed = lbf.parseLoadBalancingPolicyConfig(serviceConfig); assertThat(parsed.getError()).isNotNull(); assertThat(parsed.getError().getDescription()) .isEqualTo("None of [magic_balancer] specified by Service Config are available."); @@ -574,7 +582,7 @@ public void parseLoadBalancerConfig_policyShouldBeIgnoredEvenIfUnknownPolicyExis @Test @SuppressWarnings("unchecked") - public void parseLoadBalancerConfig_firstInvalidPolicy() throws Exception { + public void parseLoadBalancingConfig_firstInvalidPolicy() throws Exception { when(testLbBalancerProvider.parseLoadBalancingPolicyConfig(any(Map.class))) .thenReturn(ConfigOrError.fromError(Status.UNKNOWN)); Map serviceConfig = @@ -582,7 +590,7 @@ public void parseLoadBalancerConfig_firstInvalidPolicy() throws Exception { "{\"loadBalancingConfig\": [" + "{\"test_lb\": {}}," + "{\"round_robin\": {} } ] }"); - ConfigOrError parsed = lbf.parseLoadBalancerPolicy(serviceConfig); + ConfigOrError parsed = lbf.parseLoadBalancingPolicyConfig(serviceConfig); assertThat(parsed).isNotNull(); assertThat(parsed.getConfig()).isNull(); assertThat(parsed.getError()).isEqualTo(Status.UNKNOWN); @@ -590,7 +598,7 @@ public void parseLoadBalancerConfig_firstInvalidPolicy() throws Exception { @Test @SuppressWarnings("unchecked") - public void parseLoadBalancerConfig_firstValidSecondInvalidPolicy() throws Exception { + public void parseLoadBalancingConfig_firstValidSecondInvalidPolicy() throws Exception { when(testLbBalancerProvider.parseLoadBalancingPolicyConfig(any(Map.class))) .thenReturn(ConfigOrError.fromError(Status.UNKNOWN)); Map serviceConfig = @@ -598,38 +606,45 @@ public void parseLoadBalancerConfig_firstValidSecondInvalidPolicy() throws Excep "{\"loadBalancingConfig\": [" + "{\"round_robin\": {}}," + "{\"test_lb\": {} } ] }"); - ConfigOrError parsed = lbf.parseLoadBalancerPolicy(serviceConfig); + ConfigOrError parsed = lbf.parseLoadBalancingPolicyConfig(serviceConfig); assertThat(parsed).isNotNull(); assertThat(parsed.getConfig()).isNotNull(); assertThat(((PolicySelection) parsed.getConfig()).config).isNotNull(); } @Test - public void parseLoadBalancerConfig_someProvidesAreNotAvailable() throws Exception { + public void parseLoadBalancingConfig_someProvidesAreNotAvailable() throws Exception { Map serviceConfig = parseConfig("{\"loadBalancingConfig\": [ " + "{\"magic_balancer\": {} }," + "{\"round_robin\": {}} ] }"); - ConfigOrError parsed = lbf.parseLoadBalancerPolicy(serviceConfig); + ConfigOrError parsed = lbf.parseLoadBalancingPolicyConfig(serviceConfig); assertThat(parsed).isNotNull(); assertThat(parsed.getConfig()).isNotNull(); assertThat(((PolicySelection) parsed.getConfig()).config).isNotNull(); } @Test - public void parseLoadBalancerConfig_lbConfigPropagated() throws Exception { + public void parseLoadBalancingConfig_lbConfigPropagated() throws Exception { Map rawServiceConfig = parseConfig( "{\"loadBalancingConfig\": [" + "{\"pick_first\": {\"shuffleAddressList\": true } }" + "] }"); - ConfigOrError parsed = lbf.parseLoadBalancerPolicy(rawServiceConfig); + ConfigOrError parsed = lbf.parseLoadBalancingPolicyConfig(rawServiceConfig); assertThat(parsed).isNotNull(); assertThat(parsed.getConfig()).isNotNull(); PolicySelection policySelection = (PolicySelection) parsed.getConfig(); assertThat(policySelection.provider).isInstanceOf(PickFirstLoadBalancerProvider.class); - assertThat(policySelection.config).isInstanceOf(PickFirstLoadBalancerConfig.class); - assertThat(((PickFirstLoadBalancerConfig) policySelection.config).shuffleAddressList).isTrue(); + if (PickFirstLoadBalancerProvider.isEnabledNewPickFirst()) { + assertThat(policySelection.config).isInstanceOf(PickFirstLeafLoadBalancerConfig.class); + assertThat(((PickFirstLeafLoadBalancerConfig) policySelection.config).shuffleAddressList) + .isTrue(); + } else { + assertThat(policySelection.config).isInstanceOf(PickFirstLoadBalancerConfig.class); + assertThat(((PickFirstLoadBalancerConfig) policySelection.config).shuffleAddressList) + .isTrue(); + } verifyNoInteractions(channelLogger); } @@ -678,6 +693,16 @@ private static class TestLoadBalancer extends ForwardingLoadBalancer { } private class TestHelper extends ForwardingLoadBalancerHelper { + final SynchronizationContext syncContext = new SynchronizationContext( + new Thread.UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + throw new AssertionError(e); + } + }); + + final FakeClock fakeClock = new FakeClock(); + @Override protected Helper delegate() { return null; @@ -692,6 +717,16 @@ public ChannelLogger getChannelLogger() { public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { // noop } + + @Override + public SynchronizationContext getSynchronizationContext() { + return syncContext; + } + + @Override + public ScheduledExecutorService getScheduledExecutorService() { + return fakeClock.getScheduledExecutorService(); + } } private static class TestSubchannel extends Subchannel { diff --git a/core/src/test/java/io/grpc/internal/ClientCallImplTest.java b/core/src/test/java/io/grpc/internal/ClientCallImplTest.java index 34011cd844d..03e613e13d9 100644 --- a/core/src/test/java/io/grpc/internal/ClientCallImplTest.java +++ b/core/src/test/java/io/grpc/internal/ClientCallImplTest.java @@ -926,7 +926,7 @@ public void expiredDeadlineCancelsStream_CallOptions() { verify(stream, times(1)).cancel(statusCaptor.capture()); assertEquals(Status.Code.DEADLINE_EXCEEDED, statusCaptor.getValue().getCode()); assertThat(statusCaptor.getValue().getDescription()) - .matches("deadline exceeded after [0-9]+\\.[0-9]+s. " + .matches("CallOptions deadline exceeded after [0-9]+\\.[0-9]+s. " + "Name resolution delay 0.000000000 seconds. \\[remote_addr=127\\.0\\.0\\.1:443\\]"); } @@ -954,7 +954,24 @@ public void expiredDeadlineCancelsStream_Context() { verify(stream, times(1)).cancel(statusCaptor.capture()); assertEquals(Status.Code.DEADLINE_EXCEEDED, statusCaptor.getValue().getCode()); - assertThat(statusCaptor.getValue().getDescription()).isEqualTo("context timed out"); + assertThat(statusCaptor.getValue().getDescription()) + .matches("Context deadline exceeded after [0-9]+\\.[0-9]+s. " + + "Name resolution delay 0.000000000 seconds. \\[remote_addr=127\\.0\\.0\\.1:443\\]"); + } + + @Test + public void cancelWithoutStart() { + fakeClock.forwardTime(System.nanoTime(), TimeUnit.NANOSECONDS); + + ClientCallImpl call = new ClientCallImpl<>( + method, + MoreExecutors.directExecutor(), + baseCallOptions.withDeadline(Deadline.after(1, TimeUnit.SECONDS)), + clientStreamProvider, + deadlineCancellationExecutor, + channelCallTracer, configSelector); + // Nothing happens as a result, but it shouldn't throw + call.cancel("canceled", null); } @Test @@ -1088,6 +1105,32 @@ public void getAttributes() { assertEquals(attrs, call.getAttributes()); } + @Test + public void onCloseExceptionCaughtAndLogged() { + DelayedExecutor executor = new DelayedExecutor(); + ClientCallImpl call = new ClientCallImpl<>( + method, + executor, + baseCallOptions, + clientStreamProvider, + deadlineCancellationExecutor, + channelCallTracer, configSelector); + + call.start(callListener, new Metadata()); + verify(stream).start(listenerArgumentCaptor.capture()); + final ClientStreamListener streamListener = listenerArgumentCaptor.getValue(); + streamListener.headersRead(new Metadata()); + + doThrow(new RuntimeException("Exception thrown by onClose() in ClientCall")).when(callListener) + .onClose(any(Status.class), any(Metadata.class)); + + Status status = Status.RESOURCE_EXHAUSTED.withDescription("simulated"); + streamListener.closed(status, PROCESSED, new Metadata()); + executor.release(); + + verify(callListener).onClose(same(status), any(Metadata.class)); + } + private static final class DelayedExecutor implements Executor { private final BlockingQueue commands = new LinkedBlockingQueue<>(); diff --git a/core/src/test/java/io/grpc/internal/CompositeReadableBufferTest.java b/core/src/test/java/io/grpc/internal/CompositeReadableBufferTest.java index 011d83b548a..749b71d681e 100644 --- a/core/src/test/java/io/grpc/internal/CompositeReadableBufferTest.java +++ b/core/src/test/java/io/grpc/internal/CompositeReadableBufferTest.java @@ -16,7 +16,7 @@ package io.grpc.internal; -import static com.google.common.base.Charsets.UTF_8; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -28,8 +28,6 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.nio.Buffer; -import java.nio.ByteBuffer; import java.nio.InvalidMarkException; import org.junit.After; import org.junit.Before; @@ -121,27 +119,6 @@ public void readByteArrayShouldSucceed() { assertEquals(EXPECTED_VALUE, new String(bytes, UTF_8)); } - @Test - public void readByteBufferShouldSucceed() { - ByteBuffer byteBuffer = ByteBuffer.allocate(EXPECTED_VALUE.length()); - int remaining = EXPECTED_VALUE.length(); - - ((Buffer) byteBuffer).limit(1); - composite.readBytes(byteBuffer); - remaining--; - assertEquals(remaining, composite.readableBytes()); - - ((Buffer) byteBuffer).limit(byteBuffer.limit() + 5); - composite.readBytes(byteBuffer); - remaining -= 5; - assertEquals(remaining, composite.readableBytes()); - - ((Buffer) byteBuffer).limit(byteBuffer.limit() + remaining); - composite.readBytes(byteBuffer); - assertEquals(0, composite.readableBytes()); - assertEquals(EXPECTED_VALUE, new String(byteBuffer.array(), UTF_8)); - } - @Test public void readStreamShouldSucceed() throws IOException { ByteArrayOutputStream bos = new ByteArrayOutputStream(); @@ -216,18 +193,6 @@ public void markAndResetWithReadByteArrayShouldSucceed() { assertArrayEquals(first, second); } - @Test - public void markAndResetWithReadByteBufferShouldSucceed() { - byte[] first = new byte[EXPECTED_VALUE.length()]; - composite.mark(); - composite.readBytes(ByteBuffer.wrap(first)); - composite.reset(); - byte[] second = new byte[EXPECTED_VALUE.length()]; - assertEquals(EXPECTED_VALUE.length(), composite.readableBytes()); - composite.readBytes(ByteBuffer.wrap(second)); - assertArrayEquals(first, second); - } - @Test public void markAndResetWithReadStreamShouldSucceed() throws IOException { ByteArrayOutputStream first = new ByteArrayOutputStream(); diff --git a/core/src/test/java/io/grpc/internal/ConcurrentTimeProviderTest.java b/core/src/test/java/io/grpc/internal/ConcurrentTimeProviderTest.java new file mode 100644 index 00000000000..7983530456c --- /dev/null +++ b/core/src/test/java/io/grpc/internal/ConcurrentTimeProviderTest.java @@ -0,0 +1,45 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static com.google.common.truth.Truth.assertThat; + +import java.util.concurrent.TimeUnit; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Unit tests for {@link ConcurrentTimeProvider}. + */ +@RunWith(JUnit4.class) +public class ConcurrentTimeProviderTest { + @Test + public void testConcurrentCurrentTimeNanos() { + + ConcurrentTimeProvider concurrentTimeProvider = new ConcurrentTimeProvider(); + // Get the current time from the ConcurrentTimeProvider + long actualTimeNanos = concurrentTimeProvider.currentTimeNanos(); + + // Get the current time from System for comparison + long expectedTimeNanos = TimeUnit.MILLISECONDS.toNanos(System.currentTimeMillis()); + + // Validate the time returned is close to the expected value within a tolerance + // (i,e 10 millisecond tolerance in nanoseconds). + assertThat(actualTimeNanos).isWithin(10_000_000L).of(expectedTimeNanos); + } +} diff --git a/core/src/test/java/io/grpc/internal/ConnectivityStateManagerTest.java b/core/src/test/java/io/grpc/internal/ConnectivityStateManagerTest.java index 2a759a4f386..dfd6ed56a1e 100644 --- a/core/src/test/java/io/grpc/internal/ConnectivityStateManagerTest.java +++ b/core/src/test/java/io/grpc/internal/ConnectivityStateManagerTest.java @@ -27,9 +27,7 @@ import io.grpc.ConnectivityState; import java.util.LinkedList; import java.util.concurrent.Executor; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -38,10 +36,6 @@ */ @RunWith(JUnit4.class) public class ConnectivityStateManagerTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); - private final FakeClock executor = new FakeClock(); private final ConnectivityStateManager state = new ConnectivityStateManager(); private final LinkedList sink = new LinkedList<>(); @@ -75,7 +69,7 @@ public void run() { assertEquals(1, sink.size()); assertEquals(TRANSIENT_FAILURE, sink.poll()); } - + @Test public void registerCallbackAfterStateChanged() { state.gotoState(CONNECTING); diff --git a/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java b/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java index 45682b3a385..0d30e947b0c 100644 --- a/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java @@ -151,10 +151,12 @@ public void startThenSetCall() { delayedClientCall.request(1); Runnable r = delayedClientCall.setCall(mockRealCall); assertThat(r).isNotNull(); - r.run(); @SuppressWarnings("unchecked") ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(Listener.class); + // start() must be called before setCall() returns (not in runnable), to ensure the in-use + // counts keeping the channel alive after shutdown() don't momentarily decrease to zero. verify(mockRealCall).start(listenerCaptor.capture(), any(Metadata.class)); + r.run(); Listener realCallListener = listenerCaptor.getValue(); verify(mockRealCall).request(1); realCallListener.onMessage(1); @@ -204,7 +206,7 @@ public void delayedCallsRunUnderContext() throws Exception { Object goldenValue = new Object(); DelayedClientCall delayedClientCall = Context.current().withValue(contextKey, goldenValue).call(() -> - new DelayedClientCall<>(callExecutor, fakeClock.getScheduledExecutorService(), null)); + new DelayedClientCall<>(callExecutor, fakeClock.getScheduledExecutorService(), null)); AtomicReference readyContext = new AtomicReference<>(); delayedClientCall.start(new ClientCall.Listener() { @Override public void onReady() { @@ -227,6 +229,232 @@ public void delayedCallsRunUnderContext() throws Exception { assertThat(contextKey.get(readyContext.get())).isEqualTo(goldenValue); } + @Test + public void listenerThrowsInPendingCallback_cancelsRealCall() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + final RuntimeException boom = new RuntimeException("boom"); + ClientCall.Listener throwingListener = new ClientCall.Listener() { + @Override + public void onMessage(Integer msg) { + throw boom; + } + }; + delayedClientCall.start(throwingListener, new Metadata()); + // Deliver onMessage while the wrapping DelayedListener is still buffering, by firing + // it from within realCall.start() — drainPendingCalls has not yet flipped the listener + // to pass-through. The queued onMessage is then drained and throws; the fix must catch + // the throwable and cancel the real call rather than let it escape. + Runnable r = delayedClientCall.setCall(new SimpleForwardingClientCall( + mockRealCall) { + @Override + public void start(Listener listener, Metadata metadata) { + super.start(listener, metadata); + listener.onMessage(42); + } + }); + assertThat(r).isNotNull(); + r.run(); // Must not propagate `boom`. + verify(mockRealCall).cancel(eq("Failed to read message."), eq(boom)); + } + + @Test + public void listenerThrowsInPendingOnHeaders_cancelsRealCall() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + final RuntimeException boom = new RuntimeException("boom"); + ClientCall.Listener throwingListener = new ClientCall.Listener() { + @Override + public void onHeaders(Metadata headers) { + throw boom; + } + }; + delayedClientCall.start(throwingListener, new Metadata()); + Runnable r = delayedClientCall.setCall(new SimpleForwardingClientCall( + mockRealCall) { + @Override + public void start(Listener listener, Metadata metadata) { + super.start(listener, metadata); + listener.onHeaders(new Metadata()); + } + }); + assertThat(r).isNotNull(); + r.run(); + verify(mockRealCall).cancel(eq("Failed to read headers"), eq(boom)); + } + + @Test + public void listenerThrowsInPendingOnReady_cancelsRealCall() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + final RuntimeException boom = new RuntimeException("boom"); + ClientCall.Listener throwingListener = new ClientCall.Listener() { + @Override + public void onReady() { + throw boom; + } + }; + delayedClientCall.start(throwingListener, new Metadata()); + Runnable r = delayedClientCall.setCall(new SimpleForwardingClientCall( + mockRealCall) { + @Override + public void start(Listener listener, Metadata metadata) { + super.start(listener, metadata); + listener.onReady(); + } + }); + assertThat(r).isNotNull(); + r.run(); + verify(mockRealCall).cancel(eq("Failed to call onReady."), eq(boom)); + } + + @Test + public void onCloseExceptionCaughtAndLogged() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + final RuntimeException boom = new RuntimeException("boom"); + final AtomicReference observed = new AtomicReference<>(); + ClientCall.Listener throwingListener = new ClientCall.Listener() { + @Override + public void onClose(Status status, Metadata trailers) { + observed.set(status); + throw boom; + } + }; + delayedClientCall.start(throwingListener, new Metadata()); + Runnable r = delayedClientCall.setCall(new SimpleForwardingClientCall( + mockRealCall) { + @Override + public void start(Listener listener, Metadata metadata) { + super.start(listener, metadata); + listener.onClose(Status.DATA_LOSS, new Metadata()); + } + }); + assertThat(r).isNotNull(); + r.run(); // Must not propagate `boom`. + assertThat(observed.get().getCode()).isEqualTo(Status.Code.DATA_LOSS); + verify(mockRealCall, never()).cancel(any(), any()); + } + + @Test + public void listenerThrowsInPassThroughOnMessage_cancelsRealCall() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + final RuntimeException boom = new RuntimeException("boom"); + ClientCall.Listener throwingListener = new ClientCall.Listener() { + @Override + public void onMessage(Integer msg) { + throw boom; + } + }; + delayedClientCall.start(throwingListener, new Metadata()); + Runnable r = delayedClientCall.setCall(mockRealCall); + assertThat(r).isNotNull(); + r.run(); // drain completes, listener transitions to passThrough + @SuppressWarnings("unchecked") + ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(Listener.class); + verify(mockRealCall).start(listenerCaptor.capture(), any(Metadata.class)); + Listener realCallListener = listenerCaptor.getValue(); + realCallListener.onMessage(42); // dispatched on passThrough fast path + verify(mockRealCall).cancel(eq("Failed to read message."), eq(boom)); + } + + @Test + public void listenerThrowsInPassThroughOnHeaders_cancelsRealCall() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + final RuntimeException boom = new RuntimeException("boom"); + ClientCall.Listener throwingListener = new ClientCall.Listener() { + @Override + public void onHeaders(Metadata headers) { + throw boom; + } + }; + delayedClientCall.start(throwingListener, new Metadata()); + Runnable r = delayedClientCall.setCall(mockRealCall); + assertThat(r).isNotNull(); + r.run(); + @SuppressWarnings("unchecked") + ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(Listener.class); + verify(mockRealCall).start(listenerCaptor.capture(), any(Metadata.class)); + Listener realCallListener = listenerCaptor.getValue(); + realCallListener.onHeaders(new Metadata()); + verify(mockRealCall).cancel(eq("Failed to read headers"), eq(boom)); + } + + @Test + public void listenerThrowsInPassThroughOnReady_cancelsRealCall() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + final RuntimeException boom = new RuntimeException("boom"); + ClientCall.Listener throwingListener = new ClientCall.Listener() { + @Override + public void onReady() { + throw boom; + } + }; + delayedClientCall.start(throwingListener, new Metadata()); + Runnable r = delayedClientCall.setCall(mockRealCall); + assertThat(r).isNotNull(); + r.run(); + @SuppressWarnings("unchecked") + ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(Listener.class); + verify(mockRealCall).start(listenerCaptor.capture(), any(Metadata.class)); + Listener realCallListener = listenerCaptor.getValue(); + realCallListener.onReady(); + verify(mockRealCall).cancel(eq("Failed to call onReady."), eq(boom)); + } + + @Test + public void listenerThrowsInPassThrough_subsequentCallbacksSwallowedAndOnCloseOverridden() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + final RuntimeException boom = new RuntimeException("boom"); + final AtomicReference lastMessage = new AtomicReference<>(); + final AtomicReference closeStatus = new AtomicReference<>(); + final AtomicReference closeTrailers = new AtomicReference<>(); + ClientCall.Listener throwingListener = new ClientCall.Listener() { + @Override + public void onMessage(Integer msg) { + lastMessage.set(msg); + if (msg == 1) { + throw boom; + } + } + + @Override + public void onClose(Status status, Metadata trailers) { + closeStatus.set(status); + closeTrailers.set(trailers); + } + }; + delayedClientCall.start(throwingListener, new Metadata()); + Runnable r = delayedClientCall.setCall(mockRealCall); + assertThat(r).isNotNull(); + r.run(); + @SuppressWarnings("unchecked") + ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(Listener.class); + verify(mockRealCall).start(listenerCaptor.capture(), any(Metadata.class)); + Listener realCallListener = listenerCaptor.getValue(); + + realCallListener.onMessage(1); // throws -> exceptionStatus captured + assertThat(lastMessage.get()).isEqualTo(1); + verify(mockRealCall).cancel(eq("Failed to read message."), eq(boom)); + + // Later callbacks are swallowed — the listener must not see message 2. + realCallListener.onMessage(2); + assertThat(lastMessage.get()).isEqualTo(1); + + // Transport onClose with OK must be overridden by the captured CANCELLED status. + Metadata serverTrailers = new Metadata(); + serverTrailers.put(Metadata.Key.of("k", Metadata.ASCII_STRING_MARSHALLER), "v"); + realCallListener.onClose(Status.OK, serverTrailers); + assertThat(closeStatus.get().getCode()).isEqualTo(Status.Code.CANCELLED); + assertThat(closeStatus.get().getCause()).isEqualTo(boom); + // Trailers replaced to avoid mixing sources. + assertThat(closeTrailers.get()).isNotSameInstanceAs(serverTrailers); + } + private void callMeMaybe(Runnable r) { if (r != null) { r.run(); diff --git a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java index 4cae565a19e..d7e1d4ca4f6 100644 --- a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java @@ -17,12 +17,14 @@ package io.grpc.internal; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.PickSubchannelArgsMatcher.eqPickSubchannelArgs; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.doAnswer; @@ -44,6 +46,7 @@ import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor.MethodType; +import io.grpc.PickSubchannelArgsMatcher; import io.grpc.Status; import io.grpc.StringMarshaller; import io.grpc.SynchronizationContext; @@ -172,7 +175,8 @@ public void uncaughtException(Thread t, Throwable e) { delayedTransport.reprocess(mockPicker); assertEquals(0, delayedTransport.getPendingStreamsCount()); delayedTransport.shutdown(SHUTDOWN_STATUS); - verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); + verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS), + eq(SimpleDisconnectError.SUBCHANNEL_SHUTDOWN)); verify(transportListener).transportTerminated(); assertEquals(0, fakeExecutor.runDueTasks()); verify(mockRealTransport).newStream( @@ -184,7 +188,8 @@ public void uncaughtException(Thread t, Throwable e) { @Test public void transportTerminatedThenAssignTransport() { delayedTransport.shutdown(SHUTDOWN_STATUS); - verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); + verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS), + eq(SimpleDisconnectError.SUBCHANNEL_SHUTDOWN)); verify(transportListener).transportTerminated(); delayedTransport.reprocess(mockPicker); verifyNoMoreInteractions(transportListener); @@ -193,7 +198,8 @@ public void uncaughtException(Thread t, Throwable e) { @Test public void assignTransportThenShutdownThenNewStream() { delayedTransport.reprocess(mockPicker); delayedTransport.shutdown(SHUTDOWN_STATUS); - verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); + verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS), + eq(SimpleDisconnectError.SUBCHANNEL_SHUTDOWN)); verify(transportListener).transportTerminated(); ClientStream stream = delayedTransport.newStream( method, headers, callOptions, tracers); @@ -207,7 +213,8 @@ public void uncaughtException(Thread t, Throwable e) { @Test public void assignTransportThenShutdownNowThenNewStream() { delayedTransport.reprocess(mockPicker); delayedTransport.shutdownNow(Status.UNAVAILABLE); - verify(transportListener).transportShutdown(any(Status.class)); + verify(transportListener).transportShutdown(any(Status.class), + eq(SimpleDisconnectError.SUBCHANNEL_SHUTDOWN)); verify(transportListener).transportTerminated(); ClientStream stream = delayedTransport.newStream( method, headers, callOptions, tracers); @@ -238,7 +245,8 @@ public void uncaughtException(Thread t, Throwable e) { delayedTransport.shutdown(SHUTDOWN_STATUS); // Stream is still buffered - verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); + verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS), + eq(SimpleDisconnectError.SUBCHANNEL_SHUTDOWN)); verify(transportListener, times(0)).transportTerminated(); assertEquals(1, delayedTransport.getPendingStreamsCount()); @@ -272,7 +280,8 @@ public void uncaughtException(Thread t, Throwable e) { ClientStream stream = delayedTransport.newStream( method, new Metadata(), CallOptions.DEFAULT, tracers); delayedTransport.shutdown(SHUTDOWN_STATUS); - verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); + verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS), + eq(SimpleDisconnectError.SUBCHANNEL_SHUTDOWN)); verify(transportListener, times(0)).transportTerminated(); assertEquals(1, delayedTransport.getPendingStreamsCount()); stream.start(streamListener); @@ -285,7 +294,8 @@ public void uncaughtException(Thread t, Throwable e) { @Test public void shutdownThenNewStream() { delayedTransport.shutdown(SHUTDOWN_STATUS); - verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); + verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS), + eq(SimpleDisconnectError.SUBCHANNEL_SHUTDOWN)); verify(transportListener).transportTerminated(); ClientStream stream = delayedTransport.newStream( method, new Metadata(), CallOptions.DEFAULT, tracers); @@ -300,7 +310,8 @@ public void uncaughtException(Thread t, Throwable e) { method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(streamListener); delayedTransport.shutdownNow(Status.UNAVAILABLE); - verify(transportListener).transportShutdown(any(Status.class)); + verify(transportListener).transportShutdown(any(Status.class), + eq(SimpleDisconnectError.SUBCHANNEL_SHUTDOWN)); verify(transportListener).transportTerminated(); verify(streamListener) .closed(statusCaptor.capture(), eq(RpcProgress.REFUSED), any(Metadata.class)); @@ -309,7 +320,8 @@ public void uncaughtException(Thread t, Throwable e) { @Test public void shutdownNowThenNewStream() { delayedTransport.shutdownNow(Status.UNAVAILABLE); - verify(transportListener).transportShutdown(any(Status.class)); + verify(transportListener).transportShutdown(any(Status.class), + eq(SimpleDisconnectError.SUBCHANNEL_SHUTDOWN)); verify(transportListener).transportTerminated(); ClientStream stream = delayedTransport.newStream( method, new Metadata(), CallOptions.DEFAULT, tracers); @@ -344,31 +356,31 @@ public void uncaughtException(Thread t, Throwable e) { method, headers, failFastCallOptions, tracers); ff1.start(mock(ClientStreamListener.class)); ff1.halfClose(); - PickSubchannelArgsImpl ff1args = new PickSubchannelArgsImpl(method, headers, + PickSubchannelArgsMatcher ff1args = new PickSubchannelArgsMatcher(method, headers, failFastCallOptions); verify(transportListener).transportInUse(true); DelayedStream ff2 = (DelayedStream) delayedTransport.newStream( method2, headers2, failFastCallOptions, tracers); - PickSubchannelArgsImpl ff2args = new PickSubchannelArgsImpl(method2, headers2, + PickSubchannelArgsMatcher ff2args = new PickSubchannelArgsMatcher(method2, headers2, failFastCallOptions); DelayedStream ff3 = (DelayedStream) delayedTransport.newStream( method, headers, failFastCallOptions, tracers); - PickSubchannelArgsImpl ff3args = new PickSubchannelArgsImpl(method, headers, + PickSubchannelArgsMatcher ff3args = new PickSubchannelArgsMatcher(method, headers, failFastCallOptions); DelayedStream ff4 = (DelayedStream) delayedTransport.newStream( method2, headers2, failFastCallOptions, tracers); - PickSubchannelArgsImpl ff4args = new PickSubchannelArgsImpl(method2, headers2, + PickSubchannelArgsMatcher ff4args = new PickSubchannelArgsMatcher(method2, headers2, failFastCallOptions); // Wait-for-ready streams FakeClock wfr3Executor = new FakeClock(); DelayedStream wfr1 = (DelayedStream) delayedTransport.newStream( method, headers, waitForReadyCallOptions, tracers); - PickSubchannelArgsImpl wfr1args = new PickSubchannelArgsImpl(method, headers, + PickSubchannelArgsMatcher wfr1args = new PickSubchannelArgsMatcher(method, headers, waitForReadyCallOptions); DelayedStream wfr2 = (DelayedStream) delayedTransport.newStream( method2, headers2, waitForReadyCallOptions, tracers); - PickSubchannelArgsImpl wfr2args = new PickSubchannelArgsImpl(method2, headers2, + PickSubchannelArgsMatcher wfr2args = new PickSubchannelArgsMatcher(method2, headers2, waitForReadyCallOptions); CallOptions wfr3callOptions = waitForReadyCallOptions.withExecutor( wfr3Executor.getScheduledExecutorService()); @@ -376,11 +388,11 @@ public void uncaughtException(Thread t, Throwable e) { method, headers, wfr3callOptions, tracers); wfr3.start(mock(ClientStreamListener.class)); wfr3.halfClose(); - PickSubchannelArgsImpl wfr3args = new PickSubchannelArgsImpl(method, headers, + PickSubchannelArgsMatcher wfr3args = new PickSubchannelArgsMatcher(method, headers, wfr3callOptions); DelayedStream wfr4 = (DelayedStream) delayedTransport.newStream( method2, headers2, waitForReadyCallOptions, tracers); - PickSubchannelArgsImpl wfr4args = new PickSubchannelArgsImpl(method2, headers2, + PickSubchannelArgsMatcher wfr4args = new PickSubchannelArgsMatcher(method2, headers2, waitForReadyCallOptions); assertEquals(8, delayedTransport.getPendingStreamsCount()); @@ -401,14 +413,14 @@ public void uncaughtException(Thread t, Throwable e) { delayedTransport.reprocess(picker); assertEquals(5, delayedTransport.getPendingStreamsCount()); - inOrder.verify(picker).pickSubchannel(ff1args); - inOrder.verify(picker).pickSubchannel(ff2args); - inOrder.verify(picker).pickSubchannel(ff3args); - inOrder.verify(picker).pickSubchannel(ff4args); - inOrder.verify(picker).pickSubchannel(wfr1args); - inOrder.verify(picker).pickSubchannel(wfr2args); - inOrder.verify(picker).pickSubchannel(wfr3args); - inOrder.verify(picker).pickSubchannel(wfr4args); + inOrder.verify(picker).pickSubchannel(argThat(ff1args)); + inOrder.verify(picker).pickSubchannel(argThat(ff2args)); + inOrder.verify(picker).pickSubchannel(argThat(ff3args)); + inOrder.verify(picker).pickSubchannel(argThat(ff4args)); + inOrder.verify(picker).pickSubchannel(argThat(wfr1args)); + inOrder.verify(picker).pickSubchannel(argThat(wfr2args)); + inOrder.verify(picker).pickSubchannel(argThat(wfr3args)); + inOrder.verify(picker).pickSubchannel(argThat(wfr4args)); inOrder.verifyNoMoreInteractions(); // Make sure that streams are created and started immediately, not in any executor. This is @@ -454,11 +466,11 @@ public void uncaughtException(Thread t, Throwable e) { delayedTransport.reprocess(picker); assertEquals(0, delayedTransport.getPendingStreamsCount()); verify(transportListener).transportInUse(false); - inOrder.verify(picker).pickSubchannel(ff3args); // ff3 - inOrder.verify(picker).pickSubchannel(ff4args); // ff4 - inOrder.verify(picker).pickSubchannel(wfr2args); // wfr2 - inOrder.verify(picker).pickSubchannel(wfr3args); // wfr3 - inOrder.verify(picker).pickSubchannel(wfr4args); // wfr4 + inOrder.verify(picker).pickSubchannel(argThat(ff3args)); // ff3 + inOrder.verify(picker).pickSubchannel(argThat(ff4args)); // ff4 + inOrder.verify(picker).pickSubchannel(argThat(wfr2args)); // wfr2 + inOrder.verify(picker).pickSubchannel(argThat(wfr3args)); // wfr3 + inOrder.verify(picker).pickSubchannel(argThat(wfr4args)); // wfr4 inOrder.verifyNoMoreInteractions(); fakeExecutor.runDueTasks(); assertEquals(0, fakeExecutor.numPendingTasks()); @@ -478,13 +490,14 @@ public void uncaughtException(Thread t, Throwable e) { method, headers, waitForReadyCallOptions, tracers); assertNull(wfr5.getRealStream()); inOrder.verify(picker).pickSubchannel( - new PickSubchannelArgsImpl(method, headers, waitForReadyCallOptions)); + eqPickSubchannelArgs(method, headers, waitForReadyCallOptions)); inOrder.verifyNoMoreInteractions(); assertEquals(1, delayedTransport.getPendingStreamsCount()); // wfr5 will stop delayed transport from terminating delayedTransport.shutdown(SHUTDOWN_STATUS); - verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); + verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS), + eq(SimpleDisconnectError.SUBCHANNEL_SHUTDOWN)); verify(transportListener, never()).transportTerminated(); // ... until it's gone picker = mock(SubchannelPicker.class); @@ -492,13 +505,38 @@ public void uncaughtException(Thread t, Throwable e) { PickResult.withSubchannel(subchannel1)); delayedTransport.reprocess(picker); verify(picker).pickSubchannel( - new PickSubchannelArgsImpl(method, headers, waitForReadyCallOptions)); + eqPickSubchannelArgs(method, headers, waitForReadyCallOptions)); fakeExecutor.runDueTasks(); assertSame(mockRealStream, wfr5.getRealStream()); assertEquals(0, delayedTransport.getPendingStreamsCount()); verify(transportListener).transportTerminated(); } + @Test + public void reprocess_authorityOverrideFromLb() { + InOrder inOrder = inOrder(mockRealStream); + DelayedStream delayedStream = (DelayedStream) delayedTransport.newStream( + method, headers, callOptions.withAuthority(null), tracers); + delayedStream.setAuthority("authority-override-from-calloptions"); + delayedStream.start(mock(ClientStreamListener.class)); + SubchannelPicker picker = mock(SubchannelPicker.class); + PickResult pickResult = PickResult.withSubchannel( + mockSubchannel, null, "authority-override-hostname-from-lb"); + when(picker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn(pickResult); + when(mockRealTransport.newStream( + same(method), same(headers), any(CallOptions.class), + ArgumentMatchers.any())) + .thenReturn(mockRealStream); + + delayedTransport.reprocess(picker); + fakeExecutor.runDueTasks(); + + // Must be set before start(), and may be overwritten + inOrder.verify(mockRealStream).setAuthority("authority-override-hostname-from-lb"); + inOrder.verify(mockRealStream).setAuthority("authority-override-from-calloptions"); + inOrder.verify(mockRealStream).start(any(ClientStreamListener.class)); + } + @Test public void reprocess_NoPendingStream() { SubchannelPicker picker = mock(SubchannelPicker.class); @@ -517,11 +555,58 @@ public void reprocess_NoPendingStream() { // Though picker was not originally used, it will be saved and serve future streams. ClientStream stream = delayedTransport.newStream( method, headers, CallOptions.DEFAULT, tracers); - verify(picker).pickSubchannel(new PickSubchannelArgsImpl(method, headers, CallOptions.DEFAULT)); + verify(picker).pickSubchannel(eqPickSubchannelArgs(method, headers, CallOptions.DEFAULT)); verify(mockInternalSubchannel).obtainActiveTransport(); assertSame(mockRealStream, stream); } + @Test + public void newStream_authorityOverrideFromLb() { + InOrder inOrder = inOrder(mockRealStream); + SubchannelPicker picker = mock(SubchannelPicker.class); + PickResult pickResult = PickResult.withSubchannel( + mockSubchannel, null, "authority-override-hostname-from-lb"); + when(picker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn(pickResult); + when(mockRealTransport.newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), any())) + .thenReturn(mockRealStream); + delayedTransport.reprocess(picker); + + ClientStream stream = delayedTransport.newStream(method, headers, callOptions, tracers); + assertThat(stream).isSameInstanceAs(mockRealStream); + stream.setAuthority("authority-override-from-calloptions"); + stream.start(mock(ClientStreamListener.class)); + + // Must be set before start(), and may be overwritten + inOrder.verify(mockRealStream).setAuthority("authority-override-hostname-from-lb"); + inOrder.verify(mockRealStream).setAuthority("authority-override-from-calloptions"); + inOrder.verify(mockRealStream).start(any(ClientStreamListener.class)); + } + + @Test + public void newStream_assignsTransport_authorityFromLB() { + SubchannelPicker picker = mock(SubchannelPicker.class); + AbstractSubchannel subchannel = mock(AbstractSubchannel.class); + when(subchannel.getInternalSubchannel()).thenReturn(mockInternalSubchannel); + PickResult pickResult = PickResult.withSubchannel( + subchannel, null, "authority-override-hostname-from-lb"); + when(picker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn(pickResult); + ArgumentCaptor callOptionsArgumentCaptor = + ArgumentCaptor.forClass(CallOptions.class); + when(mockRealTransport.newStream( + any(MethodDescriptor.class), any(Metadata.class), callOptionsArgumentCaptor.capture(), + ArgumentMatchers.any())) + .thenReturn(mockRealStream); + delayedTransport.reprocess(picker); + verifyNoMoreInteractions(picker); + verifyNoMoreInteractions(transportListener); + + CallOptions callOptions = CallOptions.DEFAULT; + delayedTransport.newStream(method, headers, callOptions, tracers); + assertThat(callOptionsArgumentCaptor.getValue().getAuthority()).isEqualTo( + "authority-override-hostname-from-lb"); + } + @Test public void reprocess_newStreamRacesWithReprocess() throws Exception { final CyclicBarrier barrier = new CyclicBarrier(2); @@ -559,16 +644,16 @@ public void run() { }; sideThread.start(); - PickSubchannelArgsImpl args = new PickSubchannelArgsImpl(method, headers, callOptions); - PickSubchannelArgsImpl args2 = new PickSubchannelArgsImpl(method, headers2, callOptions); + PickSubchannelArgsMatcher args = new PickSubchannelArgsMatcher(method, headers, callOptions); + PickSubchannelArgsMatcher args2 = new PickSubchannelArgsMatcher(method, headers2, callOptions); // Is called from sideThread - verify(picker, timeout(5000)).pickSubchannel(args); + verify(picker, timeout(5000)).pickSubchannel(argThat(args)); // Because stream has not been buffered (it's still stuck in newStream()), this will do nothing, // but incrementing the picker version. delayedTransport.reprocess(picker); - verify(picker).pickSubchannel(args); + verify(picker).pickSubchannel(argThat(args)); // Now let the stuck newStream() through barrier.await(5, TimeUnit.SECONDS); @@ -576,7 +661,7 @@ public void run() { sideThread.join(5000); assertFalse("sideThread should've exited", sideThread.isAlive()); // newStream() detects that there has been a new picker while it's stuck, thus will pick again. - verify(picker, times(2)).pickSubchannel(args); + verify(picker, times(2)).pickSubchannel(argThat(args)); barrier.reset(); nextPickShouldWait.set(true); @@ -592,9 +677,9 @@ public void run() { }; sideThread2.start(); // The second stream will see the first picker - verify(picker, timeout(5000)).pickSubchannel(args2); + verify(picker, timeout(5000)).pickSubchannel(argThat(args2)); // While the first stream won't use the first picker any more. - verify(picker, times(2)).pickSubchannel(args); + verify(picker, times(2)).pickSubchannel(argThat(args)); // Now use a different picker SubchannelPicker picker2 = mock(SubchannelPicker.class); @@ -602,9 +687,9 @@ public void run() { .thenReturn(PickResult.withNoResult()); delayedTransport.reprocess(picker2); // The pending first stream uses the new picker - verify(picker2).pickSubchannel(args); + verify(picker2).pickSubchannel(argThat(args)); // The second stream is still pending in creation, doesn't use the new picker. - verify(picker2, never()).pickSubchannel(args2); + verify(picker2, never()).pickSubchannel(argThat(args2)); // Now let the second stream finish creation barrier.await(5, TimeUnit.SECONDS); @@ -612,13 +697,30 @@ public void run() { sideThread2.join(5000); assertFalse("sideThread2 should've exited", sideThread2.isAlive()); // The second stream should see the new picker - verify(picker2, timeout(5000)).pickSubchannel(args2); + verify(picker2, timeout(5000)).pickSubchannel(argThat(args2)); // Wrapping up - verify(picker, times(2)).pickSubchannel(args); - verify(picker).pickSubchannel(args2); - verify(picker2).pickSubchannel(args); - verify(picker2).pickSubchannel(args); + verify(picker, times(2)).pickSubchannel(argThat(args)); + verify(picker).pickSubchannel(argThat(args2)); + verify(picker2).pickSubchannel(argThat(args)); + verify(picker2).pickSubchannel(argThat(args)); + } + + @Test + public void reprocess_addOptionalLabelCallsTracer() throws Exception { + delayedTransport.reprocess(new SubchannelPicker() { + @Override public PickResult pickSubchannel(PickSubchannelArgs args) { + args.getPickDetailsConsumer().addOptionalLabel("routed", "perfectly"); + return PickResult.withError(Status.UNAVAILABLE.withDescription("expected")); + } + }); + + ClientStreamTracer tracer = mock(ClientStreamTracer.class); + ClientStream stream = delayedTransport.newStream( + method, headers, callOptions, new ClientStreamTracer[] {tracer}); + stream.start(streamListener); + + verify(tracer).addOptionalLabel("routed", "perfectly"); } @Test @@ -650,7 +752,24 @@ public void pendingStream_appendTimeoutInsight_waitForReady() { InsightBuilder insight = new InsightBuilder(); stream.appendTimeoutInsight(insight); assertThat(insight.toString()) - .matches("\\[wait_for_ready, buffered_nanos=[0-9]+\\, waiting_for_connection]"); + .matches("\\[wait_for_ready, connecting_and_lb_delay=[0-9]+ns\\, was_still_waiting]"); + } + + @Test + public void pendingStream_appendTimeoutInsight_waitForReady_withLastPickFailure() { + ClientStream stream = delayedTransport.newStream( + method, headers, callOptions.withWaitForReady(), tracers); + stream.start(streamListener); + SubchannelPicker picker = mock(SubchannelPicker.class); + when(picker.pickSubchannel(any(PickSubchannelArgs.class))) + .thenReturn(PickResult.withError(Status.PERMISSION_DENIED)); + delayedTransport.reprocess(picker); + InsightBuilder insight = new InsightBuilder(); + stream.appendTimeoutInsight(insight); + assertThat(insight.toString()) + .matches("\\[wait_for_ready, " + + "Last Pick Failure=Status\\{code=PERMISSION_DENIED, description=null, cause=null\\}," + + " connecting_and_lb_delay=[0-9]+ns, was_still_waiting]"); } private static TransportProvider newTransportProvider(final ClientTransport transport) { diff --git a/core/src/test/java/io/grpc/internal/DelayedStreamTest.java b/core/src/test/java/io/grpc/internal/DelayedStreamTest.java index e39e8d420a2..12c32fcf126 100644 --- a/core/src/test/java/io/grpc/internal/DelayedStreamTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedStreamTest.java @@ -71,7 +71,7 @@ public class DelayedStreamTest { @Mock private ClientStreamListener listener; @Mock private ClientStream realStream; @Captor private ArgumentCaptor listenerCaptor; - private DelayedStream stream = new DelayedStream(); + private DelayedStream stream = new DelayedStream("test_op"); @Test public void setStream_setAuthority() { @@ -84,12 +84,6 @@ public void setStream_setAuthority() { inOrder.verify(realStream).start(any(ClientStreamListener.class)); } - @Test(expected = IllegalStateException.class) - public void setAuthority_afterStart() { - stream.start(listener); - stream.setAuthority("notgonnawork"); - } - @Test(expected = IllegalStateException.class) public void start_afterStart() { stream.start(listener); @@ -456,7 +450,7 @@ public void appendTimeoutInsight_realStreamNotSet() { InsightBuilder insight = new InsightBuilder(); stream.start(listener); stream.appendTimeoutInsight(insight); - assertThat(insight.toString()).matches("\\[buffered_nanos=[0-9]+\\, waiting_for_connection]"); + assertThat(insight.toString()).matches("\\[test_op_delay=[0-9]+ns\\, was_still_waiting]"); } @Test @@ -475,7 +469,7 @@ public Void answer(InvocationOnMock in) { InsightBuilder insight = new InsightBuilder(); stream.appendTimeoutInsight(insight); assertThat(insight.toString()) - .matches("\\[buffered_nanos=[0-9]+, remote_addr=127\\.0\\.0\\.1:443\\]"); + .matches("\\[test_op_delay=[0-9]+ns, remote_addr=127\\.0\\.0\\.1:443\\]"); } private void callMeMaybe(Runnable r) { diff --git a/core/src/test/java/io/grpc/internal/DnsNameResolverProviderTest.java b/core/src/test/java/io/grpc/internal/DnsNameResolverProviderTest.java index aff10ce9337..75b82df544f 100644 --- a/core/src/test/java/io/grpc/internal/DnsNameResolverProviderTest.java +++ b/core/src/test/java/io/grpc/internal/DnsNameResolverProviderTest.java @@ -16,8 +16,9 @@ package io.grpc.internal; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; +import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.TruthJUnit.assume; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.mock; @@ -25,16 +26,27 @@ import io.grpc.NameResolver; import io.grpc.NameResolver.ServiceConfigParser; import io.grpc.SynchronizationContext; +import io.grpc.Uri; import java.net.URI; +import java.util.Arrays; import org.junit.Test; import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; /** Unit tests for {@link DnsNameResolverProvider}. */ -@RunWith(JUnit4.class) +@RunWith(Parameterized.class) public class DnsNameResolverProviderTest { private final FakeClock fakeClock = new FakeClock(); + @Parameters(name = "enableRfc3986UrisParam={0}") + public static Iterable data() { + return Arrays.asList(new Object[][] {{true}, {false}}); + } + + @Parameter public boolean enableRfc3986UrisParam; + private final SynchronizationContext syncContext = new SynchronizationContext( new Thread.UncaughtExceptionHandler() { @Override @@ -59,10 +71,75 @@ public void isAvailable() { } @Test - public void newNameResolver() { - assertSame(DnsNameResolver.class, - provider.newNameResolver(URI.create("dns:///localhost:443"), args).getClass()); - assertNull( - provider.newNameResolver(URI.create("notdns:///localhost:443"), args)); + public void newNameResolver_acceptsHostAndPort() { + NameResolver nameResolver = newNameResolver("dns:///localhost:443", args); + assertThat(nameResolver).isNotNull(); + assertThat(nameResolver.getClass()).isSameInstanceAs(DnsNameResolver.class); + assertThat(nameResolver.getServiceAuthority()).isEqualTo("localhost:443"); + assertThat(((DnsNameResolver) nameResolver).getPort()).isEqualTo(443); + } + + @Test + public void newNameResolver_acceptsRootless() { + assume().that(enableRfc3986UrisParam).isTrue(); + NameResolver nameResolver = newNameResolver("dns:localhost:443", args); + assertThat(nameResolver).isNotNull(); + assertThat(nameResolver.getClass()).isSameInstanceAs(DnsNameResolver.class); + assertThat(nameResolver.getServiceAuthority()).isEqualTo("localhost:443"); + } + + @Test + public void newNameResolver_rejectsNonDnsScheme() { + NameResolver nameResolver = newNameResolver("notdns:///localhost:443", args); + assertThat(nameResolver).isNull(); + } + + @Test + public void newNameResolver_validDnsNameWithoutPort_usesDefaultPort() { + DnsNameResolver nameResolver = + (DnsNameResolver) newNameResolver("dns:/foo.googleapis.com", args); + assertThat(nameResolver).isNotNull(); + assertThat(nameResolver.getServiceAuthority()).isEqualTo("foo.googleapis.com"); + assertThat(nameResolver.getPort()).isEqualTo(args.getDefaultPort()); + } + + // TODO(jdcormie): Trailing path segments *should* be forbidden. This test just demonstrates that + // both newNameResolver() overloads behave the same with respect to this bug. + @Test + public void newNameResolver_toleratesTrailingPathSegments() { + NameResolver nameResolver = newNameResolver("dns:///foo.googleapis.com/ig/nor/ed", args); + assertThat(nameResolver).isNotNull(); + assertThat(nameResolver.getClass()).isSameInstanceAs(DnsNameResolver.class); + assertThat(nameResolver.getServiceAuthority()).isEqualTo("foo.googleapis.com"); + } + + @Test + public void newNameResolver_toleratesAuthority() { + NameResolver nameResolver = newNameResolver("dns://8.8.8.8/foo.googleapis.com", args); + assertThat(nameResolver).isNotNull(); + assertThat(nameResolver.getClass()).isSameInstanceAs(DnsNameResolver.class); + assertThat(nameResolver.getServiceAuthority()).isEqualTo("foo.googleapis.com"); + } + + @Test + public void newNameResolver_validIpv6Host() { + NameResolver nameResolver = newNameResolver("dns:/%5B::1%5D", args); + assertThat(nameResolver).isNotNull(); + assertThat(nameResolver.getClass()).isSameInstanceAs(DnsNameResolver.class); + assertThat(nameResolver.getServiceAuthority()).isEqualTo("[::1]"); + } + + @Test + public void newNameResolver_invalidIpv6Host_throws() { + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, () -> newNameResolver("dns:/%5Binvalid%5D", args)); + assertThat(e).hasMessageThat().contains("invalid"); + } + + private NameResolver newNameResolver(String uriString, NameResolver.Args args) { + return enableRfc3986UrisParam + ? provider.newNameResolver(Uri.create(uriString), args) + : provider.newNameResolver(URI.create(uriString), args); } } diff --git a/core/src/test/java/io/grpc/internal/DnsNameResolverTest.java b/core/src/test/java/io/grpc/internal/DnsNameResolverTest.java index 9c245f615da..c53863dcf5d 100644 --- a/core/src/test/java/io/grpc/internal/DnsNameResolverTest.java +++ b/core/src/test/java/io/grpc/internal/DnsNameResolverTest.java @@ -17,16 +17,18 @@ package io.grpc.internal; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.internal.DnsNameResolver.NETWORKADDRESS_CACHE_TTL_PROPERTY; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; @@ -35,12 +37,14 @@ import static org.mockito.Mockito.when; import com.google.common.base.Stopwatch; +import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import com.google.common.net.InetAddresses; import com.google.common.testing.FakeTicker; import io.grpc.ChannelLogger; import io.grpc.EquivalentAddressGroup; +import io.grpc.FlagResetRule; import io.grpc.HttpConnectProxiedSocketAddress; import io.grpc.NameResolver; import io.grpc.NameResolver.ConfigOrError; @@ -61,7 +65,6 @@ import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.SocketAddress; -import java.net.URI; import java.net.UnknownHostException; import java.util.ArrayList; import java.util.Arrays; @@ -76,13 +79,10 @@ import java.util.logging.Level; import java.util.logging.Logger; import java.util.regex.Pattern; -import javax.annotation.Nullable; -import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.DisableOnDebug; -import org.junit.rules.ExpectedException; import org.junit.rules.TestRule; import org.junit.rules.Timeout; import org.junit.runner.RunWith; @@ -99,8 +99,7 @@ public class DnsNameResolverTest { @Rule public final TestRule globalTimeout = new DisableOnDebug(Timeout.seconds(10)); @Rule public final MockitoRule mocks = MockitoJUnit.rule(); - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); + @Rule public final FlagResetRule flagResetRule = new FlagResetRule(); private final Map serviceConfig = new LinkedHashMap<>(); @@ -113,7 +112,6 @@ public void uncaughtException(Thread t, Throwable e) { } }); - private final DnsNameResolverProvider provider = new DnsNameResolverProvider(); private final FakeClock fakeClock = new FakeClock(); private final FakeClock fakeExecutor = new FakeClock(); private static final FakeClock.TaskFilter NAME_RESOLVER_REFRESH_TASK_FILTER = @@ -140,37 +138,19 @@ public Executor create() { public void close(Executor instance) {} } - private final NameResolver.Args args = NameResolver.Args.newBuilder() - .setDefaultPort(DEFAULT_PORT) - .setProxyDetector(GrpcUtil.DEFAULT_PROXY_DETECTOR) - .setSynchronizationContext(syncContext) - .setServiceConfigParser(mock(ServiceConfigParser.class)) - .setChannelLogger(mock(ChannelLogger.class)) - .setScheduledExecutorService(fakeExecutor.getScheduledExecutorService()) - .build(); - @Mock private NameResolver.Listener2 mockListener; @Captor private ArgumentCaptor resultCaptor; - @Captor - private ArgumentCaptor errorCaptor; - @Nullable - private String networkaddressCacheTtlPropertyValue; @Mock private RecordFetcher recordFetcher; + @Mock private ProxyDetector mockProxyDetector; private RetryingNameResolver newResolver(String name, int defaultPort) { return newResolver( name, defaultPort, GrpcUtil.NOOP_PROXY_DETECTOR, Stopwatch.createUnstarted()); } - private RetryingNameResolver newResolver(String name, int defaultPort, boolean isAndroid) { - return newResolver( - name, defaultPort, GrpcUtil.NOOP_PROXY_DETECTOR, Stopwatch.createUnstarted(), - isAndroid); - } - private RetryingNameResolver newResolver( String name, int defaultPort, @@ -209,63 +189,15 @@ private RetryingNameResolver newResolver( // In practice the DNS name resolver provider always wraps the resolver in a // RetryingNameResolver which adds retry capabilities to it. We use the same setup here. - return new RetryingNameResolver( - dnsResolver, - new BackoffPolicyRetryScheduler( - new ExponentialBackoffPolicy.Provider(), - fakeExecutor.getScheduledExecutorService(), - syncContext - ), - syncContext); + return (RetryingNameResolver) RetryingNameResolver.wrap(dnsResolver, args); } @Before public void setUp() { DnsNameResolver.enableJndi = true; - networkaddressCacheTtlPropertyValue = - System.getProperty(DnsNameResolver.NETWORKADDRESS_CACHE_TTL_PROPERTY); // By default the mock listener processes the result successfully. - doAnswer(invocation -> { - ResolutionResult result = invocation.getArgument(0); - syncContext.execute( - () -> result.getAttributes().get(RetryingNameResolver.RESOLUTION_RESULT_LISTENER_KEY) - .resolutionAttempted(Status.OK)); - return null; - }).when(mockListener).onResult(isA(ResolutionResult.class)); - } - - @After - public void restoreSystemProperty() { - if (networkaddressCacheTtlPropertyValue == null) { - System.clearProperty(DnsNameResolver.NETWORKADDRESS_CACHE_TTL_PROPERTY); - } else { - System.setProperty( - DnsNameResolver.NETWORKADDRESS_CACHE_TTL_PROPERTY, - networkaddressCacheTtlPropertyValue); - } - } - - @Test - public void invalidDnsName() throws Exception { - testInvalidUri(new URI("dns", null, "/[invalid]", null)); - } - - @Test - public void validIpv6() throws Exception { - testValidUri(new URI("dns", null, "/[::1]", null), "[::1]", DEFAULT_PORT); - } - - @Test - public void validDnsNameWithoutPort() throws Exception { - testValidUri(new URI("dns", null, "/foo.googleapis.com", null), - "foo.googleapis.com", DEFAULT_PORT); - } - - @Test - public void validDnsNameWithPort() throws Exception { - testValidUri(new URI("dns", null, "/foo.googleapis.com:456", null), - "foo.googleapis.com:456", 456); + when(mockListener.onResult2(isA(ResolutionResult.class))).thenReturn(Status.OK); } @Test @@ -288,30 +220,14 @@ public void invalidDnsName_containsUnderscore() { } } - @Test - public void resolve_androidIgnoresPropertyValue() throws Exception { - System.setProperty(DnsNameResolver.NETWORKADDRESS_CACHE_TTL_PROPERTY, Long.toString(2)); - resolveNeverCache(true); - } - - @Test - public void resolve_androidIgnoresPropertyValueCacheForever() throws Exception { - System.setProperty(DnsNameResolver.NETWORKADDRESS_CACHE_TTL_PROPERTY, Long.toString(-1)); - resolveNeverCache(true); - } - @Test public void resolve_neverCache() throws Exception { - System.setProperty(DnsNameResolver.NETWORKADDRESS_CACHE_TTL_PROPERTY, "0"); - resolveNeverCache(false); - } - - private void resolveNeverCache(boolean isAndroid) throws Exception { + flagResetRule.setSystemPropertyForTest(NETWORKADDRESS_CACHE_TTL_PROPERTY, "0"); final List answer1 = createAddressList(2); final List answer2 = createAddressList(1); String name = "foo.googleapis.com"; - RetryingNameResolver resolver = newResolver(name, 81, isAndroid); + RetryingNameResolver resolver = newResolver(name, 81); DnsNameResolver dnsResolver = (DnsNameResolver) resolver.getRetriedNameResolver(); AddressResolver mockResolver = mock(AddressResolver.class); when(mockResolver.resolveAddress(anyString())).thenReturn(answer1).thenReturn(answer2); @@ -319,13 +235,13 @@ private void resolveNeverCache(boolean isAndroid) throws Exception { resolver.start(mockListener); assertEquals(1, fakeExecutor.runDueTasks()); - verify(mockListener).onResult(resultCaptor.capture()); + verify(mockListener).onResult2(resultCaptor.capture()); assertAnswerMatches(answer1, 81, resultCaptor.getValue()); assertEquals(0, fakeClock.numPendingTasks()); resolver.refresh(); assertEquals(1, fakeExecutor.runDueTasks()); - verify(mockListener, times(2)).onResult(resultCaptor.capture()); + verify(mockListener, times(2)).onResult2(resultCaptor.capture()); assertAnswerMatches(answer2, 81, resultCaptor.getValue()); assertEquals(0, fakeClock.numPendingTasks()); assertEquals(0, fakeExecutor.numPendingTasks()); @@ -347,7 +263,7 @@ public void testExecutor_default() throws Exception { resolver.start(mockListener); assertEquals(1, fakeExecutor.runDueTasks()); - verify(mockListener).onResult(resultCaptor.capture()); + verify(mockListener).onResult2(resultCaptor.capture()); assertAnswerMatches(answer, 81, resultCaptor.getValue()); assertEquals(0, fakeClock.numPendingTasks()); assertEquals(0, fakeExecutor.numPendingTasks()); @@ -389,7 +305,7 @@ public void execute(Runnable command) { resolver.start(mockListener); assertEquals(0, fakeExecutor.runDueTasks()); - verify(mockListener).onResult(resultCaptor.capture()); + verify(mockListener).onResult2(resultCaptor.capture()); assertAnswerMatches(answer, 81, resultCaptor.getValue()); assertEquals(0, fakeClock.numPendingTasks()); assertEquals(0, fakeExecutor.numPendingTasks()); @@ -402,7 +318,7 @@ public void execute(Runnable command) { @Test public void resolve_cacheForever() throws Exception { - System.setProperty(DnsNameResolver.NETWORKADDRESS_CACHE_TTL_PROPERTY, "-1"); + flagResetRule.setSystemPropertyForTest(NETWORKADDRESS_CACHE_TTL_PROPERTY, "-1"); final List answer1 = createAddressList(2); String name = "foo.googleapis.com"; FakeTicker fakeTicker = new FakeTicker(); @@ -418,7 +334,7 @@ public void resolve_cacheForever() throws Exception { resolver.start(mockListener); assertEquals(1, fakeExecutor.runDueTasks()); - verify(mockListener).onResult(resultCaptor.capture()); + verify(mockListener).onResult2(resultCaptor.capture()); assertAnswerMatches(answer1, 81, resultCaptor.getValue()); assertEquals(0, fakeClock.numPendingTasks()); @@ -436,7 +352,7 @@ public void resolve_cacheForever() throws Exception { @Test public void resolve_usingCache() throws Exception { long ttl = 60; - System.setProperty(DnsNameResolver.NETWORKADDRESS_CACHE_TTL_PROPERTY, Long.toString(ttl)); + flagResetRule.setSystemPropertyForTest(NETWORKADDRESS_CACHE_TTL_PROPERTY, Long.toString(ttl)); final List answer = createAddressList(2); String name = "foo.googleapis.com"; FakeTicker fakeTicker = new FakeTicker(); @@ -452,7 +368,7 @@ public void resolve_usingCache() throws Exception { resolver.start(mockListener); assertEquals(1, fakeExecutor.runDueTasks()); - verify(mockListener).onResult(resultCaptor.capture()); + verify(mockListener).onResult2(resultCaptor.capture()); assertAnswerMatches(answer, 81, resultCaptor.getValue()); assertEquals(0, fakeClock.numPendingTasks()); @@ -471,7 +387,7 @@ public void resolve_usingCache() throws Exception { @Test public void resolve_cacheExpired() throws Exception { long ttl = 60; - System.setProperty(DnsNameResolver.NETWORKADDRESS_CACHE_TTL_PROPERTY, Long.toString(ttl)); + flagResetRule.setSystemPropertyForTest(NETWORKADDRESS_CACHE_TTL_PROPERTY, Long.toString(ttl)); final List answer1 = createAddressList(2); final List answer2 = createAddressList(1); String name = "foo.googleapis.com"; @@ -487,14 +403,14 @@ public void resolve_cacheExpired() throws Exception { resolver.start(mockListener); assertEquals(1, fakeExecutor.runDueTasks()); - verify(mockListener).onResult(resultCaptor.capture()); + verify(mockListener).onResult2(resultCaptor.capture()); assertAnswerMatches(answer1, 81, resultCaptor.getValue()); assertEquals(0, fakeClock.numPendingTasks()); fakeTicker.advance(ttl + 1, TimeUnit.SECONDS); resolver.refresh(); assertEquals(1, fakeExecutor.runDueTasks()); - verify(mockListener, times(2)).onResult(resultCaptor.capture()); + verify(mockListener, times(2)).onResult2(resultCaptor.capture()); assertAnswerMatches(answer2, 81, resultCaptor.getValue()); assertEquals(0, fakeClock.numPendingTasks()); assertEquals(0, fakeExecutor.numPendingTasks()); @@ -504,26 +420,38 @@ public void resolve_cacheExpired() throws Exception { verify(mockResolver, times(2)).resolveAddress(anyString()); } + @Test + public void resolve_androidIgnoresPropertyValue() throws Exception { + flagResetRule.setSystemPropertyForTest(NETWORKADDRESS_CACHE_TTL_PROPERTY, "2"); + resolveDefaultValue(true); + } + + @Test + public void resolve_androidIgnoresPropertyValueCacheForever() throws Exception { + flagResetRule.setSystemPropertyForTest(NETWORKADDRESS_CACHE_TTL_PROPERTY, "-1"); + resolveDefaultValue(true); + } + @Test public void resolve_invalidTtlPropertyValue() throws Exception { - System.setProperty(DnsNameResolver.NETWORKADDRESS_CACHE_TTL_PROPERTY, "not_a_number"); - resolveDefaultValue(); + flagResetRule.setSystemPropertyForTest(NETWORKADDRESS_CACHE_TTL_PROPERTY, "not_a_number"); + resolveDefaultValue(false); } @Test public void resolve_noPropertyValue() throws Exception { - System.clearProperty(DnsNameResolver.NETWORKADDRESS_CACHE_TTL_PROPERTY); - resolveDefaultValue(); + flagResetRule.clearSystemPropertyForTest(NETWORKADDRESS_CACHE_TTL_PROPERTY); + resolveDefaultValue(false); } - private void resolveDefaultValue() throws Exception { + private void resolveDefaultValue(boolean isAndroid) throws Exception { final List answer1 = createAddressList(2); final List answer2 = createAddressList(1); String name = "foo.googleapis.com"; FakeTicker fakeTicker = new FakeTicker(); RetryingNameResolver resolver = newResolver( - name, 81, GrpcUtil.NOOP_PROXY_DETECTOR, Stopwatch.createUnstarted(fakeTicker)); + name, 81, GrpcUtil.NOOP_PROXY_DETECTOR, Stopwatch.createUnstarted(fakeTicker), isAndroid); DnsNameResolver dnsResolver = (DnsNameResolver) resolver.getRetriedNameResolver(); AddressResolver mockResolver = mock(AddressResolver.class); when(mockResolver.resolveAddress(anyString())).thenReturn(answer1).thenReturn(answer2); @@ -531,7 +459,7 @@ private void resolveDefaultValue() throws Exception { resolver.start(mockListener); assertEquals(1, fakeExecutor.runDueTasks()); - verify(mockListener).onResult(resultCaptor.capture()); + verify(mockListener).onResult2(resultCaptor.capture()); assertAnswerMatches(answer1, 81, resultCaptor.getValue()); assertEquals(0, fakeClock.numPendingTasks()); @@ -544,7 +472,7 @@ private void resolveDefaultValue() throws Exception { fakeTicker.advance(1, TimeUnit.SECONDS); resolver.refresh(); assertEquals(1, fakeExecutor.runDueTasks()); - verify(mockListener, times(2)).onResult(resultCaptor.capture()); + verify(mockListener, times(2)).onResult2(resultCaptor.capture()); assertAnswerMatches(answer2, 81, resultCaptor.getValue()); assertEquals(0, fakeClock.numPendingTasks()); assertEquals(0, fakeExecutor.numPendingTasks()); @@ -575,9 +503,9 @@ public List resolveAddress(String host) throws Exception { assertThat(fakeExecutor.runDueTasks()).isEqualTo(1); ArgumentCaptor ac = ArgumentCaptor.forClass(ResolutionResult.class); - verify(mockListener).onResult(ac.capture()); + verify(mockListener).onResult2(ac.capture()); verifyNoMoreInteractions(mockListener); - assertThat(ac.getValue().getAddresses()).isEmpty(); + assertThat(ac.getValue().getAddressesOrError().getValue()).isEmpty(); assertThat(ac.getValue().getServiceConfig()).isNull(); verify(mockResourceResolver, never()).resolveSrv(anyString()); @@ -585,15 +513,43 @@ public List resolveAddress(String host) throws Exception { assertEquals(0, fakeExecutor.numPendingTasks()); } + @Test + public void resolve_addressResolutionError() throws Exception { + DnsNameResolver.enableTxt = true; + when(mockProxyDetector.proxyFor(any(SocketAddress.class))).thenThrow(new IOException()); + RetryingNameResolver resolver = newResolver( + "addr.fake:1234", 443, mockProxyDetector, Stopwatch.createUnstarted()); + DnsNameResolver dnsResolver = (DnsNameResolver) resolver.getRetriedNameResolver(); + dnsResolver.setAddressResolver(new AddressResolver() { + @Override + public List resolveAddress(String host) throws Exception { + return Collections.emptyList(); + } + }); + ResourceResolver mockResourceResolver = mock(ResourceResolver.class); + when(mockResourceResolver.resolveTxt(anyString())) + .thenReturn(Collections.emptyList()); + + dnsResolver.setResourceResolver(mockResourceResolver); + + resolver.start(mockListener); + assertThat(fakeExecutor.runDueTasks()).isEqualTo(1); + + ArgumentCaptor ac = ArgumentCaptor.forClass(ResolutionResult.class); + verify(mockListener).onResult2(ac.capture()); + verifyNoMoreInteractions(mockListener); + assertThat(ac.getValue().getAddressesOrError().getStatus().getCode()).isEqualTo( + Status.UNAVAILABLE.getCode()); + assertThat(ac.getValue().getAddressesOrError().getStatus().getDescription()).isEqualTo( + "Unable to resolve host addr.fake"); + assertThat(ac.getValue().getAddressesOrError().getStatus().getCause()) + .isInstanceOf(IOException.class); + } + // Load balancer rejects the empty addresses. @Test public void resolve_emptyResult_notAccepted() throws Exception { - doAnswer(invocation -> { - ResolutionResult result = invocation.getArgument(0); - result.getAttributes().get(RetryingNameResolver.RESOLUTION_RESULT_LISTENER_KEY) - .resolutionAttempted(Status.UNAVAILABLE); - return null; - }).when(mockListener).onResult(isA(ResolutionResult.class)); + when(mockListener.onResult2(isA(ResolutionResult.class))).thenReturn(Status.UNAVAILABLE); DnsNameResolver.enableTxt = true; RetryingNameResolver resolver = newResolver("dns:///addr.fake:1234", 443); @@ -614,9 +570,9 @@ public List resolveAddress(String host) throws Exception { syncContext.execute(() -> assertThat(fakeExecutor.runDueTasks()).isEqualTo(1)); ArgumentCaptor ac = ArgumentCaptor.forClass(ResolutionResult.class); - verify(mockListener).onResult(ac.capture()); + verify(mockListener).onResult2(ac.capture()); verifyNoMoreInteractions(mockListener); - assertThat(ac.getValue().getAddresses()).isEmpty(); + assertThat(ac.getValue().getAddressesOrError().getValue()).isEmpty(); assertThat(ac.getValue().getServiceConfig()).isNull(); verify(mockResourceResolver, never()).resolveSrv(anyString()); @@ -640,11 +596,11 @@ public void resolve_nullResourceResolver() throws Exception { dnsResolver.setResourceResolver(null); resolver.start(mockListener); assertEquals(1, fakeExecutor.runDueTasks()); - verify(mockListener).onResult(resultCaptor.capture()); + verify(mockListener).onResult2(resultCaptor.capture()); ResolutionResult result = resultCaptor.getValue(); InetSocketAddress resolvedBackendAddr = (InetSocketAddress) Iterables.getOnlyElement( - Iterables.getOnlyElement(result.getAddresses()).getAddresses()); + Iterables.getOnlyElement(result.getAddressesOrError().getValue()).getAddresses()); assertThat(resolvedBackendAddr.getAddress()).isEqualTo(backendAddr); verify(mockAddressResolver).resolveAddress(name); assertThat(result.getServiceConfig()).isNull(); @@ -659,6 +615,7 @@ public void resolve_nullResourceResolver_addressFailure() throws Exception { AddressResolver mockAddressResolver = mock(AddressResolver.class); when(mockAddressResolver.resolveAddress(anyString())) .thenThrow(new IOException("no addr")); + when(mockListener.onResult2(isA(ResolutionResult.class))).thenReturn(Status.UNAVAILABLE); String name = "foo.googleapis.com"; RetryingNameResolver resolver = newResolver(name, 81); @@ -667,8 +624,8 @@ public void resolve_nullResourceResolver_addressFailure() throws Exception { dnsResolver.setResourceResolver(null); resolver.start(mockListener); assertEquals(1, fakeExecutor.runDueTasks()); - verify(mockListener).onError(errorCaptor.capture()); - Status errorStatus = errorCaptor.getValue(); + verify(mockListener).onResult2(resultCaptor.capture()); + Status errorStatus = resultCaptor.getValue().getAddressesOrError().getStatus(); assertThat(errorStatus.getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(errorStatus.getCause()).hasMessageThat().contains("no addr"); @@ -712,11 +669,11 @@ public ConfigOrError parseServiceConfig(Map rawServiceConfig) { resolver.start(mockListener); assertEquals(1, fakeExecutor.runDueTasks()); - verify(mockListener).onResult(resultCaptor.capture()); + verify(mockListener).onResult2(resultCaptor.capture()); ResolutionResult result = resultCaptor.getValue(); InetSocketAddress resolvedBackendAddr = (InetSocketAddress) Iterables.getOnlyElement( - Iterables.getOnlyElement(result.getAddresses()).getAddresses()); + Iterables.getOnlyElement(result.getAddressesOrError().getValue()).getAddresses()); assertThat(resolvedBackendAddr.getAddress()).isEqualTo(backendAddr); assertThat(result.getServiceConfig().getConfig()).isNotNull(); verify(mockAddressResolver).resolveAddress(name); @@ -727,11 +684,12 @@ public ConfigOrError parseServiceConfig(Map rawServiceConfig) { } @Test - public void resolve_addressFailure_neverLookUpServiceConfig() throws Exception { + public void resolve_addressFailure_stillLookUpServiceConfig() throws Exception { DnsNameResolver.enableTxt = true; AddressResolver mockAddressResolver = mock(AddressResolver.class); when(mockAddressResolver.resolveAddress(anyString())) .thenThrow(new IOException("no addr")); + when(mockListener.onResult2(isA(ResolutionResult.class))).thenReturn(Status.UNAVAILABLE); String name = "foo.googleapis.com"; ResourceResolver mockResourceResolver = mock(ResourceResolver.class); @@ -741,11 +699,11 @@ public void resolve_addressFailure_neverLookUpServiceConfig() throws Exception { dnsResolver.setResourceResolver(mockResourceResolver); resolver.start(mockListener); assertEquals(1, fakeExecutor.runDueTasks()); - verify(mockListener).onError(errorCaptor.capture()); - Status errorStatus = errorCaptor.getValue(); + verify(mockListener).onResult2(resultCaptor.capture()); + Status errorStatus = resultCaptor.getValue().getAddressesOrError().getStatus(); assertThat(errorStatus.getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(errorStatus.getCause()).hasMessageThat().contains("no addr"); - verify(mockResourceResolver, never()).resolveTxt(anyString()); + verify(mockResourceResolver).resolveTxt("_grpc_config." + name); assertEquals(0, fakeClock.numPendingTasks()); // A retry should be scheduled @@ -770,11 +728,11 @@ public void resolve_serviceConfigLookupFails_nullServiceConfig() throws Exceptio dnsResolver.setResourceResolver(mockResourceResolver); resolver.start(mockListener); assertEquals(1, fakeExecutor.runDueTasks()); - verify(mockListener).onResult(resultCaptor.capture()); + verify(mockListener).onResult2(resultCaptor.capture()); ResolutionResult result = resultCaptor.getValue(); InetSocketAddress resolvedBackendAddr = (InetSocketAddress) Iterables.getOnlyElement( - Iterables.getOnlyElement(result.getAddresses()).getAddresses()); + Iterables.getOnlyElement(result.getAddressesOrError().getValue()).getAddresses()); assertThat(resolvedBackendAddr.getAddress()).isEqualTo(backendAddr); verify(mockAddressResolver).resolveAddress(name); assertThat(result.getServiceConfig()).isNull(); @@ -802,11 +760,11 @@ public void resolve_serviceConfigMalformed_serviceConfigError() throws Exception dnsResolver.setResourceResolver(mockResourceResolver); resolver.start(mockListener); assertEquals(1, fakeExecutor.runDueTasks()); - verify(mockListener).onResult(resultCaptor.capture()); + verify(mockListener).onResult2(resultCaptor.capture()); ResolutionResult result = resultCaptor.getValue(); InetSocketAddress resolvedBackendAddr = (InetSocketAddress) Iterables.getOnlyElement( - Iterables.getOnlyElement(result.getAddresses()).getAddresses()); + Iterables.getOnlyElement(result.getAddressesOrError().getValue()).getAddresses()); assertThat(resolvedBackendAddr.getAddress()).isEqualTo(backendAddr); verify(mockAddressResolver).resolveAddress(name); assertThat(result.getServiceConfig()).isNotNull(); @@ -870,8 +828,8 @@ public HttpConnectProxiedSocketAddress proxyFor(SocketAddress targetAddress) { resolver.start(mockListener); assertEquals(1, fakeExecutor.runDueTasks()); - verify(mockListener).onResult(resultCaptor.capture()); - List result = resultCaptor.getValue().getAddresses(); + verify(mockListener).onResult2(resultCaptor.capture()); + List result = resultCaptor.getValue().getAddressesOrError().getValue(); assertThat(result).hasSize(1); EquivalentAddressGroup eag = result.get(0); assertThat(eag.getAddresses()).hasSize(1); @@ -891,9 +849,10 @@ public HttpConnectProxiedSocketAddress proxyFor(SocketAddress targetAddress) { public void maybeChooseServiceConfig_failsOnMisspelling() { Map bad = new LinkedHashMap<>(); bad.put("parcentage", 1.0); - thrown.expectMessage("Bad key"); - - DnsNameResolver.maybeChooseServiceConfig(bad, new Random(), "host"); + Random random = new Random(); + VerifyException e = assertThrows(VerifyException.class, + () -> DnsNameResolver.maybeChooseServiceConfig(bad, random, "host")); + assertThat(e).hasMessageThat().isEqualTo("Bad key: parcentage=1.0"); } @Test @@ -930,7 +889,7 @@ public void maybeChooseServiceConfig_clientLanguageCaseInsensitive() { } @Test - public void maybeChooseServiceConfig_clientLanguageMatchesEmtpy() { + public void maybeChooseServiceConfig_clientLanguageMatchesEmpty() { Map choice = new LinkedHashMap<>(); List langs = new ArrayList<>(); choice.put("clientLanguage", langs); @@ -1099,7 +1058,7 @@ public void maybeChooseServiceConfig_clientLanguageCaseSensitive() { } @Test - public void maybeChooseServiceConfig_hostnameMatchesEmtpy() { + public void maybeChooseServiceConfig_hostnameMatchesEmpty() { Map choice = new LinkedHashMap<>(); List hosts = new ArrayList<>(); choice.put("clientHostname", hosts); @@ -1132,25 +1091,25 @@ public void parseTxtResults_misspelledName() throws Exception { } @Test - public void parseTxtResults_badTypeFails() throws Exception { + public void parseTxtResults_badTypeFails() { List txtRecords = new ArrayList<>(); txtRecords.add("some_record"); txtRecords.add("grpc_config={}"); - thrown.expect(ClassCastException.class); - thrown.expectMessage("wrong type"); - DnsNameResolver.parseTxtResults(txtRecords); + ClassCastException e = assertThrows(ClassCastException.class, + () -> DnsNameResolver.parseTxtResults(txtRecords)); + assertThat(e).hasMessageThat().isEqualTo("wrong type {}"); } @Test - public void parseTxtResults_badInnerTypeFails() throws Exception { + public void parseTxtResults_badInnerTypeFails() { List txtRecords = new ArrayList<>(); txtRecords.add("some_record"); txtRecords.add("grpc_config=[\"bogus\"]"); - thrown.expect(ClassCastException.class); - thrown.expectMessage("not object"); - DnsNameResolver.parseTxtResults(txtRecords); + ClassCastException e = assertThrows(ClassCastException.class, + () -> DnsNameResolver.parseTxtResults(txtRecords)); + assertThat(e).hasMessageThat().isEqualTo("value bogus for idx 0 in [bogus] is not object"); } @Test @@ -1191,7 +1150,7 @@ public void shouldUseJndi_falseIfDisabledForLocalhost() { } @Test - public void shouldUseJndi_trueIfLocalhostOverriden() { + public void shouldUseJndi_trueIfLocalhostOverridden() { boolean enableJndi = true; boolean enableJndiLocalhost = true; String host = "localhost"; @@ -1283,22 +1242,6 @@ public void parseServiceConfig_matches() { assertThat(result.getConfig()).isEqualTo(ImmutableMap.of()); } - private void testInvalidUri(URI uri) { - try { - provider.newNameResolver(uri, args); - fail("Should have failed"); - } catch (IllegalArgumentException e) { - // expected - } - } - - private void testValidUri(URI uri, String exportedAuthority, int expectedPort) { - DnsNameResolver resolver = (DnsNameResolver) provider.newNameResolver(uri, args); - assertNotNull(resolver); - assertEquals(expectedPort, resolver.getPort()); - assertEquals(exportedAuthority, resolver.getServiceAuthority()); - } - private byte lastByte = 0; private List createAddressList(int n) throws UnknownHostException { @@ -1311,9 +1254,9 @@ private List createAddressList(int n) throws UnknownHostException { private static void assertAnswerMatches( List addrs, int port, ResolutionResult resolutionResult) { - assertThat(resolutionResult.getAddresses()).hasSize(addrs.size()); + assertThat(resolutionResult.getAddressesOrError().getValue()).hasSize(addrs.size()); for (int i = 0; i < addrs.size(); i++) { - EquivalentAddressGroup addrGroup = resolutionResult.getAddresses().get(i); + EquivalentAddressGroup addrGroup = resolutionResult.getAddressesOrError().getValue().get(i); InetSocketAddress socketAddr = (InetSocketAddress) Iterables.getOnlyElement(addrGroup.getAddresses()); assertEquals("Addr " + i, port, socketAddr.getPort()); diff --git a/core/src/test/java/io/grpc/internal/ForwardingReadableBufferTest.java b/core/src/test/java/io/grpc/internal/ForwardingReadableBufferTest.java index 8ce45bc77cf..696fb35e379 100644 --- a/core/src/test/java/io/grpc/internal/ForwardingReadableBufferTest.java +++ b/core/src/test/java/io/grpc/internal/ForwardingReadableBufferTest.java @@ -25,7 +25,6 @@ import java.io.IOException; import java.io.OutputStream; import java.lang.reflect.Method; -import java.nio.ByteBuffer; import java.util.Collections; import org.junit.Before; import org.junit.Rule; @@ -91,14 +90,6 @@ public void readBytes() { verify(delegate).readBytes(dest, 1, 2); } - @Test - public void readBytes_overload1() { - ByteBuffer dest = ByteBuffer.allocate(0); - buffer.readBytes(dest); - - verify(delegate).readBytes(dest); - } - @Test public void readBytes_overload2() throws IOException { OutputStream dest = mock(OutputStream.class); diff --git a/core/src/test/java/io/grpc/internal/GrpcUtilTest.java b/core/src/test/java/io/grpc/internal/GrpcUtilTest.java index 39acb582d28..c243790028c 100644 --- a/core/src/test/java/io/grpc/internal/GrpcUtilTest.java +++ b/core/src/test/java/io/grpc/internal/GrpcUtilTest.java @@ -22,6 +22,7 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -41,7 +42,6 @@ import java.util.ArrayList; import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; @@ -57,8 +57,6 @@ public class GrpcUtilTest { new ClientStreamTracer() {} }; - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); @Rule public final MockitoRule mocks = MockitoJUnit.rule(); @Captor @@ -100,8 +98,8 @@ public void timeoutTest() { GrpcUtil.TimeoutMarshaller marshaller = new GrpcUtil.TimeoutMarshaller(); // nanos - assertEquals("0n", marshaller.toAsciiString(0L)); - assertEquals(0L, (long) marshaller.parseAsciiString("0n")); + assertEquals("1n", marshaller.toAsciiString(1L)); + assertEquals(1L, (long) marshaller.parseAsciiString("1n")); assertEquals("99999999n", marshaller.toAsciiString(99999999L)); assertEquals(99999999L, (long) marshaller.parseAsciiString("99999999n")); @@ -201,9 +199,7 @@ public void urlAuthorityEscape_unicodeAreNotEncoded() { @Test public void checkAuthority_failsOnNull() { - thrown.expect(NullPointerException.class); - - GrpcUtil.checkAuthority(null); + assertThrows(NullPointerException.class, () -> GrpcUtil.checkAuthority(null)); } @Test @@ -229,19 +225,18 @@ public void checkAuthority_succeedsOnIpV6() { @Test public void checkAuthority_failsOnInvalidAuthority() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Invalid authority"); - - GrpcUtil.checkAuthority("[ : : 1]"); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> GrpcUtil.checkAuthority("[ : : 1]")); + assertThat(e).hasMessageThat().isEqualTo("Invalid authority: [ : : 1]"); } @Test public void checkAuthority_userInfoNotAllowed() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Userinfo"); - - GrpcUtil.checkAuthority("foo@valid"); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> GrpcUtil.checkAuthority("foo@valid")); + assertThat(e).hasMessageThat() + .isEqualTo("Userinfo must not be present on authority: 'foo@valid'"); } @Test diff --git a/core/src/test/java/io/grpc/internal/Http2ClientStreamTransportStateTest.java b/core/src/test/java/io/grpc/internal/Http2ClientStreamTransportStateTest.java index d49e41a4f4a..66df062a3e0 100644 --- a/core/src/test/java/io/grpc/internal/Http2ClientStreamTransportStateTest.java +++ b/core/src/test/java/io/grpc/internal/Http2ClientStreamTransportStateTest.java @@ -16,9 +16,9 @@ package io.grpc.internal; -import static com.google.common.base.Charsets.US_ASCII; import static io.grpc.internal.ClientStreamListener.RpcProgress.PROCESSED; import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; +import static java.nio.charset.StandardCharsets.US_ASCII; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; @@ -27,6 +27,7 @@ import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; +import io.grpc.CallOptions; import io.grpc.InternalMetadata; import io.grpc.Metadata; import io.grpc.Status; @@ -300,6 +301,24 @@ public void transportTrailersReceived_missingStatusUsesHttpStatus() { assertTrue(statusCaptor.getValue().getDescription().contains("401")); } + @Test + public void transportTrailersReceived_missingContentTypeUsesHttpStatus() { + BaseTransportState state = new BaseTransportState(transportTracer); + state.setListener(mockListener); + Metadata trailers = new Metadata(); + trailers.put(testStatusMashaller, "431"); + + state.transportTrailersReceived(trailers); + + verify(mockListener, never()).headersRead(any(Metadata.class)); + verify(mockListener).closed(statusCaptor.capture(), same(PROCESSED), same(trailers)); + assertEquals(Code.INTERNAL, statusCaptor.getValue().getCode()); + assertTrue(statusCaptor.getValue().getDescription().contains("HTTP status code 431")); + assertTrue( + statusCaptor.getValue().getDescription().contains( + "missing content-type in response headers")); + } + @Test public void transportTrailersReceived_missingHttpStatus() { BaseTransportState state = new BaseTransportState(transportTracer); @@ -347,9 +366,22 @@ public void transportTrailersReceived_missingStatusAfterHeadersIgnoresHttpStatus assertEquals(Code.UNKNOWN, statusCaptor.getValue().getCode()); } + @Test + public void transportStateWithOnReadyThreshold() { + BaseTransportState state = new BaseTransportState(transportTracer, + CallOptions.DEFAULT.withOnReadyThreshold(Integer.MAX_VALUE)); + assertEquals(Integer.MAX_VALUE, state.onReadyThreshold); + } + private static class BaseTransportState extends Http2ClientStreamTransportState { + private int onReadyThreshold; + + public BaseTransportState(TransportTracer transportTracer, CallOptions options) { + super(DEFAULT_MAX_MESSAGE_SIZE, StatsTraceContext.NOOP, transportTracer, options); + } + public BaseTransportState(TransportTracer transportTracer) { - super(DEFAULT_MAX_MESSAGE_SIZE, StatsTraceContext.NOOP, transportTracer); + this(transportTracer, CallOptions.DEFAULT); } @Override @@ -367,5 +399,11 @@ public void bytesRead(int processedBytes) {} public void runOnTransportThread(Runnable r) { r.run(); } + + @Override + void setOnReadyThreshold(int numBytes) { + onReadyThreshold = numBytes; + super.setOnReadyThreshold(numBytes); + } } } diff --git a/core/src/test/java/io/grpc/internal/InstantTimeProviderTest.java b/core/src/test/java/io/grpc/internal/InstantTimeProviderTest.java new file mode 100644 index 00000000000..6702bc421a5 --- /dev/null +++ b/core/src/test/java/io/grpc/internal/InstantTimeProviderTest.java @@ -0,0 +1,51 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static com.google.common.truth.Truth.assertThat; + +import java.time.Instant; +import java.util.concurrent.TimeUnit; +import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Unit tests for {@link InstantTimeProvider}. + */ +@RunWith(JUnit4.class) +@IgnoreJRERequirement +public class InstantTimeProviderTest { + @Test + public void testInstantCurrentTimeNanos() throws Exception { + + InstantTimeProvider instantTimeProvider = new InstantTimeProvider(); + + // Get the current time from the InstantTimeProvider + long actualTimeNanos = instantTimeProvider.currentTimeNanos(); + + // Get the current time from Instant for comparison + Instant instantNow = Instant.now(); + long expectedTimeNanos = TimeUnit.SECONDS.toNanos(instantNow.getEpochSecond()) + + instantNow.getNano(); + + // Validate the time returned is close to the expected value within a tolerance + // (i,e 1000 millisecond (1 second) tolerance in nanoseconds). + assertThat(actualTimeNanos).isWithin(1000_000_000L).of(expectedTimeNanos); + } +} diff --git a/core/src/test/java/io/grpc/internal/InternalSubchannelTest.java b/core/src/test/java/io/grpc/internal/InternalSubchannelTest.java index 05dc4549a0e..4236c091d9c 100644 --- a/core/src/test/java/io/grpc/internal/InternalSubchannelTest.java +++ b/core/src/test/java/io/grpc/internal/InternalSubchannelTest.java @@ -27,11 +27,15 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; +import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; @@ -44,8 +48,14 @@ import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; import io.grpc.InternalChannelz; +import io.grpc.InternalEquivalentAddressGroup; import io.grpc.InternalLogId; import io.grpc.InternalWithLogId; +import io.grpc.LoadBalancer; +import io.grpc.MetricInstrument; +import io.grpc.MetricRecorder; +import io.grpc.NameResolver; +import io.grpc.SecurityLevel; import io.grpc.Status; import io.grpc.SynchronizationContext; import io.grpc.internal.InternalSubchannel.CallTracingTransport; @@ -64,9 +74,9 @@ import org.junit.Before; import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.InOrder; import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; @@ -78,11 +88,11 @@ public class InternalSubchannelTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); private static final String AUTHORITY = "fakeauthority"; + private static final String BACKEND_SERVICE = "ice-cream-factory-service"; + private static final String LOCALITY = "mars-olympus-mons-datacenter"; + private static final SecurityLevel SECURITY_LEVEL = SecurityLevel.PRIVACY_AND_INTEGRITY; private static final String USER_AGENT = "mosaic"; private static final ConnectivityStateInfo UNAVAILABLE_STATE = ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE); @@ -110,6 +120,10 @@ public void uncaughtException(Thread t, Throwable e) { @Mock private BackoffPolicy.Provider mockBackoffPolicyProvider; @Mock private ClientTransportFactory mockTransportFactory; + @Mock private BackoffPolicy mockBackoffPolicy; + private MetricRecorder mockMetricRecorder = mock(MetricRecorder.class, + delegatesTo(new MetricRecorderImpl())); + private final LinkedList callbackInvokes = new LinkedList<>(); private final InternalSubchannel.Callback mockInternalSubchannelCallback = new InternalSubchannel.Callback() { @@ -220,7 +234,8 @@ public void constructor_eagListWithNull_throws() { // Fail this one. Because there is only one address to try, enter TRANSIENT_FAILURE. assertNoCallbackInvoke(); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertEquals(TRANSIENT_FAILURE, internalSubchannel.getState()); assertExactCallbackInvokes("onStateChange:" + UNAVAILABLE_STATE); // Backoff reset and using first back-off value interval @@ -251,7 +266,8 @@ public void constructor_eagListWithNull_throws() { assertNoCallbackInvoke(); // Here we use a different status from the first failure, and verify that it's passed to // the callback. - transports.poll().listener.transportShutdown(Status.RESOURCE_EXHAUSTED); + transports.poll().listener.transportShutdown(Status.RESOURCE_EXHAUSTED, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertEquals(TRANSIENT_FAILURE, internalSubchannel.getState()); assertExactCallbackInvokes("onStateChange:" + RESOURCE_EXHAUSTED_STATE); // Second back-off interval @@ -289,7 +305,8 @@ public void constructor_eagListWithNull_throws() { // Close the READY transport, will enter IDLE state. assertNoCallbackInvoke(); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertEquals(IDLE, internalSubchannel.getState()); assertExactCallbackInvokes("onStateChange:IDLE"); @@ -309,10 +326,59 @@ public void constructor_eagListWithNull_throws() { verify(mockBackoffPolicy2, times(backoff2Consulted)).nextBackoffNanos(); } + @Test public void twoAddressesReconnectDisabled() { + SocketAddress addr1 = mock(SocketAddress.class); + SocketAddress addr2 = mock(SocketAddress.class); + createInternalSubchannel(true, + new EquivalentAddressGroup(Arrays.asList(addr1, addr2))); + assertEquals(IDLE, internalSubchannel.getState()); + + assertNull(internalSubchannel.obtainActiveTransport()); + assertExactCallbackInvokes("onStateChange:CONNECTING"); + assertEquals(CONNECTING, internalSubchannel.getState()); + verify(mockTransportFactory).newClientTransport(eq(addr1), any(), any()); + // Let this one fail without success + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); + // Still in CONNECTING + assertNull(internalSubchannel.obtainActiveTransport()); + assertNoCallbackInvoke(); + assertEquals(CONNECTING, internalSubchannel.getState()); + + // Second attempt will start immediately. Still no back-off policy. + verify(mockBackoffPolicyProvider, times(0)).get(); + verify(mockTransportFactory, times(1)) + .newClientTransport( + eq(addr2), + eq(createClientTransportOptions()), + isA(TransportLogger.class)); + assertNull(internalSubchannel.obtainActiveTransport()); + // Fail this one too + assertNoCallbackInvoke(); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); + // All addresses have failed, but we aren't controlling retries. + assertEquals(IDLE, internalSubchannel.getState()); + assertExactCallbackInvokes("onStateChange:" + UNAVAILABLE_STATE); + // Backoff reset and first back-off interval begins + verify(mockBackoffPolicy1, never()).nextBackoffNanos(); + verify(mockBackoffPolicyProvider, never()).get(); + assertTrue("Nothing should have been scheduled", fakeClock.getPendingTasks().isEmpty()); + + // Should follow orders and create an active transport. + internalSubchannel.obtainActiveTransport(); + assertExactCallbackInvokes("onStateChange:CONNECTING"); + assertEquals(CONNECTING, internalSubchannel.getState()); + + // Shouldn't have anything scheduled, so shouldn't do anything + assertTrue("Nothing should have been scheduled 2", fakeClock.getPendingTasks().isEmpty()); + } + @Test public void twoAddressesReconnect() { SocketAddress addr1 = mock(SocketAddress.class); SocketAddress addr2 = mock(SocketAddress.class); - createInternalSubchannel(addr1, addr2); + createInternalSubchannel(false, + new EquivalentAddressGroup(Arrays.asList(addr1, addr2))); assertEquals(IDLE, internalSubchannel.getState()); // Invocation counters int transportsAddr1 = 0; @@ -334,7 +400,8 @@ public void constructor_eagListWithNull_throws() { isA(TransportLogger.class)); // Let this one fail without success - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); // Still in CONNECTING assertNull(internalSubchannel.obtainActiveTransport()); assertNoCallbackInvoke(); @@ -350,7 +417,8 @@ public void constructor_eagListWithNull_throws() { assertNull(internalSubchannel.obtainActiveTransport()); // Fail this one too assertNoCallbackInvoke(); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); // All addresses have failed. Delayed transport will be in back-off interval. assertEquals(TRANSIENT_FAILURE, internalSubchannel.getState()); assertExactCallbackInvokes("onStateChange:" + UNAVAILABLE_STATE); @@ -381,7 +449,8 @@ public void constructor_eagListWithNull_throws() { eq(createClientTransportOptions()), isA(TransportLogger.class)); // Fail this one too - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertEquals(CONNECTING, internalSubchannel.getState()); // Forth attempt will start immediately. Keep back-off policy. @@ -395,7 +464,8 @@ public void constructor_eagListWithNull_throws() { isA(TransportLogger.class)); // Fail this one too assertNoCallbackInvoke(); - transports.poll().listener.transportShutdown(Status.RESOURCE_EXHAUSTED); + transports.poll().listener.transportShutdown(Status.RESOURCE_EXHAUSTED, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); // All addresses have failed again. Delayed transport will be in back-off interval. assertExactCallbackInvokes("onStateChange:" + RESOURCE_EXHAUSTED_STATE); assertEquals(TRANSIENT_FAILURE, internalSubchannel.getState()); @@ -432,7 +502,8 @@ public void constructor_eagListWithNull_throws() { ((CallTracingTransport) internalSubchannel.obtainActiveTransport()).delegate()); // Then close it. assertNoCallbackInvoke(); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertExactCallbackInvokes("onStateChange:IDLE"); assertEquals(IDLE, internalSubchannel.getState()); @@ -448,7 +519,8 @@ public void constructor_eagListWithNull_throws() { eq(createClientTransportOptions()), isA(TransportLogger.class)); // Fail the transport - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertEquals(CONNECTING, internalSubchannel.getState()); // Second attempt will start immediately. Still no new back-off policy. @@ -460,7 +532,8 @@ public void constructor_eagListWithNull_throws() { isA(TransportLogger.class)); // Fail this one too assertEquals(CONNECTING, internalSubchannel.getState()); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); // All addresses have failed. Enter TRANSIENT_FAILURE. Back-off in effect. assertExactCallbackInvokes("onStateChange:" + UNAVAILABLE_STATE); assertEquals(TRANSIENT_FAILURE, internalSubchannel.getState()); @@ -496,8 +569,9 @@ public void constructor_eagListWithNull_throws() { public void updateAddresses_emptyEagList_throws() { SocketAddress addr = new FakeSocketAddress(); createInternalSubchannel(addr); - thrown.expect(IllegalArgumentException.class); - internalSubchannel.updateAddresses(Arrays.asList()); + List newAddressGroups = Collections.emptyList(); + assertThrows(IllegalArgumentException.class, + () -> internalSubchannel.updateAddresses(newAddressGroups)); } @Test @@ -505,8 +579,7 @@ public void updateAddresses_eagListWithNull_throws() { SocketAddress addr = new FakeSocketAddress(); createInternalSubchannel(addr); List eags = Arrays.asList((EquivalentAddressGroup) null); - thrown.expect(NullPointerException.class); - internalSubchannel.updateAddresses(eags); + assertThrows(NullPointerException.class, () -> internalSubchannel.updateAddresses(eags)); } @Test public void updateAddresses_intersecting_ready() { @@ -524,7 +597,8 @@ public void updateAddresses_eagListWithNull_throws() { eq(addr1), eq(createClientTransportOptions()), isA(TransportLogger.class)); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertEquals(CONNECTING, internalSubchannel.getState()); // Second address connects @@ -546,7 +620,8 @@ public void updateAddresses_eagListWithNull_throws() { verify(transports.peek().transport, never()).shutdownNow(any(Status.class)); // And new addresses chosen when re-connecting - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertExactCallbackInvokes("onStateChange:IDLE"); assertNull(internalSubchannel.obtainActiveTransport()); @@ -556,13 +631,15 @@ public void updateAddresses_eagListWithNull_throws() { eq(addr2), eq(createClientTransportOptions()), isA(TransportLogger.class)); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); verify(mockTransportFactory) .newClientTransport( eq(addr3), eq(createClientTransportOptions()), isA(TransportLogger.class)); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); verifyNoMoreInteractions(mockTransportFactory); fakeClock.forwardNanos(10); // Drain retry, but don't care about result @@ -583,7 +660,8 @@ public void updateAddresses_eagListWithNull_throws() { eq(addr1), eq(createClientTransportOptions()), isA(TransportLogger.class)); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertEquals(CONNECTING, internalSubchannel.getState()); // Second address connecting @@ -606,7 +684,8 @@ public void updateAddresses_eagListWithNull_throws() { // And new addresses chosen when re-connecting transports.peek().listener.transportReady(); assertExactCallbackInvokes("onStateChange:READY"); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertExactCallbackInvokes("onStateChange:IDLE"); assertNull(internalSubchannel.obtainActiveTransport()); @@ -616,13 +695,15 @@ public void updateAddresses_eagListWithNull_throws() { eq(addr2), eq(createClientTransportOptions()), isA(TransportLogger.class)); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); verify(mockTransportFactory) .newClientTransport( eq(addr3), eq(createClientTransportOptions()), isA(TransportLogger.class)); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); verifyNoMoreInteractions(mockTransportFactory); fakeClock.forwardNanos(10); // Drain retry, but don't care about result @@ -661,7 +742,8 @@ public void updateAddresses_eagListWithNull_throws() { // And no other addresses attempted assertEquals(0, fakeClock.numPendingTasks()); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertExactCallbackInvokes("onStateChange:" + UNAVAILABLE_STATE); assertEquals(TRANSIENT_FAILURE, internalSubchannel.getState()); verifyNoMoreInteractions(mockTransportFactory); @@ -685,7 +767,8 @@ public void updateAddresses_eagListWithNull_throws() { eq(addr1), eq(createClientTransportOptions()), isA(TransportLogger.class)); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertEquals(CONNECTING, internalSubchannel.getState()); // Second address connects @@ -709,7 +792,8 @@ public void updateAddresses_eagListWithNull_throws() { verify(transports.peek().transport).shutdown(any(Status.class)); // And new addresses chosen when re-connecting - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertNoCallbackInvoke(); assertEquals(IDLE, internalSubchannel.getState()); @@ -720,13 +804,15 @@ public void updateAddresses_eagListWithNull_throws() { eq(addr3), eq(createClientTransportOptions()), isA(TransportLogger.class)); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); verify(mockTransportFactory) .newClientTransport( eq(addr4), eq(createClientTransportOptions()), isA(TransportLogger.class)); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); verifyNoMoreInteractions(mockTransportFactory); fakeClock.forwardNanos(10); // Drain retry, but don't care about result @@ -748,7 +834,8 @@ public void updateAddresses_eagListWithNull_throws() { eq(addr1), eq(createClientTransportOptions()), isA(TransportLogger.class)); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertEquals(CONNECTING, internalSubchannel.getState()); // Second address connecting @@ -778,13 +865,15 @@ public void updateAddresses_eagListWithNull_throws() { eq(addr3), eq(createClientTransportOptions()), isA(TransportLogger.class)); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); verify(mockTransportFactory) .newClientTransport( eq(addr4), eq(createClientTransportOptions()), isA(TransportLogger.class)); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); verifyNoMoreInteractions(mockTransportFactory); fakeClock.forwardNanos(10); // Drain retry, but don't care about result @@ -868,7 +957,8 @@ public void connectIsLazy() { isA(TransportLogger.class)); // Fail this one - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertExactCallbackInvokes("onStateChange:" + UNAVAILABLE_STATE); // Will always reconnect after back-off @@ -884,7 +974,8 @@ public void connectIsLazy() { transports.peek().listener.transportReady(); assertExactCallbackInvokes("onStateChange:READY"); // Then go-away - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertExactCallbackInvokes("onStateChange:IDLE"); // No scheduled tasks that would ever try to reconnect ... @@ -914,7 +1005,8 @@ public void shutdownWhenReady() throws Exception { internalSubchannel.shutdown(SHUTDOWN_REASON); verify(transportInfo.transport).shutdown(same(SHUTDOWN_REASON)); assertExactCallbackInvokes("onStateChange:SHUTDOWN"); - transportInfo.listener.transportShutdown(SHUTDOWN_REASON); + transportInfo.listener.transportShutdown(SHUTDOWN_REASON, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); transportInfo.listener.transportTerminated(); assertExactCallbackInvokes("onTerminated"); @@ -937,7 +1029,8 @@ public void shutdownBeforeTransportCreated() throws Exception { // Fail this one MockClientTransportInfo transportInfo = transports.poll(); - transportInfo.listener.transportShutdown(Status.UNAVAILABLE); + transportInfo.listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); transportInfo.listener.transportTerminated(); // Entering TRANSIENT_FAILURE, waiting for back-off @@ -964,7 +1057,7 @@ public void shutdownBeforeTransportCreated() throws Exception { // This should not lead to the creation of a new transport. reconnectTask.command.run(); - // Futher call to obtainActiveTransport() is no-op. + // Further call to obtainActiveTransport() is no-op. assertNull(internalSubchannel.obtainActiveTransport()); assertEquals(SHUTDOWN, internalSubchannel.getState()); assertNoCallbackInvoke(); @@ -993,7 +1086,8 @@ public void shutdownBeforeTransportReady() throws Exception { // The transport should've been shut down even though it's not the active transport yet. verify(transportInfo.transport).shutdown(same(SHUTDOWN_REASON)); - transportInfo.listener.transportShutdown(Status.UNAVAILABLE); + transportInfo.listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertNoCallbackInvoke(); transportInfo.listener.transportTerminated(); assertExactCallbackInvokes("onTerminated"); @@ -1009,7 +1103,7 @@ public void shutdownNow() throws Exception { MockClientTransportInfo t1 = transports.poll(); t1.listener.transportReady(); assertExactCallbackInvokes("onStateChange:CONNECTING", "onStateChange:READY"); - t1.listener.transportShutdown(Status.UNAVAILABLE); + t1.listener.transportShutdown(Status.UNAVAILABLE, SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertExactCallbackInvokes("onStateChange:IDLE"); internalSubchannel.obtainActiveTransport(); @@ -1066,7 +1160,7 @@ public void inUseState() { t0.listener.transportInUse(true); assertExactCallbackInvokes("onInUse"); - t0.listener.transportShutdown(Status.UNAVAILABLE); + t0.listener.transportShutdown(Status.UNAVAILABLE, SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertExactCallbackInvokes("onStateChange:IDLE"); assertNull(internalSubchannel.obtainActiveTransport()); @@ -1099,7 +1193,7 @@ public void transportTerminateWithoutExitingInUse() { t0.listener.transportInUse(true); assertExactCallbackInvokes("onInUse"); - t0.listener.transportShutdown(Status.UNAVAILABLE); + t0.listener.transportShutdown(Status.UNAVAILABLE, SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertExactCallbackInvokes("onStateChange:IDLE"); t0.listener.transportTerminated(); assertExactCallbackInvokes("onNotInUse"); @@ -1126,12 +1220,12 @@ public void run() { assertEquals(1, runnableInvokes.get()); MockClientTransportInfo t0 = transports.poll(); - t0.listener.transportShutdown(Status.UNAVAILABLE); + t0.listener.transportShutdown(Status.UNAVAILABLE, SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertEquals(2, runnableInvokes.get()); // 2nd address: reconnect immediatly MockClientTransportInfo t1 = transports.poll(); - t1.listener.transportShutdown(Status.UNAVAILABLE); + t1.listener.transportShutdown(Status.UNAVAILABLE, SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); // Addresses exhausted, waiting for back-off. assertEquals(2, runnableInvokes.get()); @@ -1158,7 +1252,8 @@ public void resetConnectBackoff() throws Exception { eq(addr), eq(createClientTransportOptions()), isA(TransportLogger.class)); - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertExactCallbackInvokes("onStateChange:" + UNAVAILABLE_STATE); // Save the reconnectTask @@ -1194,7 +1289,8 @@ public void resetConnectBackoff() throws Exception { // Fail the reconnect attempt to verify that a fresh reconnect policy is generated after // invoking resetConnectBackoff() - transports.poll().listener.transportShutdown(Status.UNAVAILABLE); + transports.poll().listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertExactCallbackInvokes("onStateChange:" + UNAVAILABLE_STATE); verify(mockBackoffPolicyProvider, times(2)).get(); fakeClock.forwardNanos(10); @@ -1222,7 +1318,8 @@ public void channelzMembership() throws Exception { MockClientTransportInfo t0 = transports.poll(); t0.listener.transportReady(); assertTrue(channelz.containsClientSocket(t0.transport.getLogId())); - t0.listener.transportShutdown(Status.RESOURCE_EXHAUSTED); + t0.listener.transportShutdown(Status.RESOURCE_EXHAUSTED, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); t0.listener.transportTerminated(); assertFalse(channelz.containsClientSocket(t0.transport.getLogId())); } @@ -1339,6 +1436,32 @@ public void channelzStatContainsTransport() throws Exception { assertThat(index.getCurrentAddress()).isSameInstanceAs(addr2); } + @Test + public void connectedAddressAttributes_ready() { + SocketAddress addr = new SocketAddress() {}; + Attributes attr = Attributes.newBuilder().set(Attributes.Key.create("some-key"), "1").build(); + createInternalSubchannel(new EquivalentAddressGroup(Arrays.asList(addr), attr)); + + assertEquals(IDLE, internalSubchannel.getState()); + assertNoCallbackInvoke(); + assertNull(internalSubchannel.obtainActiveTransport()); + assertNull(internalSubchannel.getConnectedAddressAttributes()); + + assertExactCallbackInvokes("onStateChange:CONNECTING"); + assertEquals(CONNECTING, internalSubchannel.getState()); + verify(mockTransportFactory).newClientTransport( + eq(addr), + eq(createClientTransportOptions().setEagAttributes(attr)), + isA(TransportLogger.class)); + assertNull(internalSubchannel.getConnectedAddressAttributes()); + + internalSubchannel.obtainActiveTransport(); + transports.peek().listener.transportReady(); + assertExactCallbackInvokes("onStateChange:READY"); + assertEquals(READY, internalSubchannel.getState()); + assertEquals(attr, internalSubchannel.getConnectedAddressAttributes()); + } + /** Create ClientTransportOptions. Should not be reused if it may be mutated. */ private ClientTransportFactory.ClientTransportOptions createClientTransportOptions() { return new ClientTransportFactory.ClientTransportOptions() @@ -1351,18 +1474,201 @@ private void createInternalSubchannel(SocketAddress ... addrs) { } private void createInternalSubchannel(EquivalentAddressGroup ... addrs) { + createInternalSubchannel(false, addrs); + } + + private void createInternalSubchannel(boolean reconnectDisabled, + EquivalentAddressGroup ... addrs) { List addressGroups = Arrays.asList(addrs); InternalLogId logId = InternalLogId.allocate("Subchannel", /*details=*/ AUTHORITY); ChannelTracer subchannelTracer = new ChannelTracer(logId, 10, fakeClock.getTimeProvider().currentTimeNanos(), "Subchannel"); - internalSubchannel = new InternalSubchannel(addressGroups, AUTHORITY, USER_AGENT, + LoadBalancer.CreateSubchannelArgs.Builder argBuilder = + LoadBalancer.CreateSubchannelArgs.newBuilder().setAddresses(addressGroups); + if (reconnectDisabled) { + argBuilder.addOption(LoadBalancer.DISABLE_SUBCHANNEL_RECONNECT_KEY, reconnectDisabled); + } + LoadBalancer.CreateSubchannelArgs createSubchannelArgs = argBuilder.build(); + internalSubchannel = new InternalSubchannel( + createSubchannelArgs, + AUTHORITY, USER_AGENT, mockBackoffPolicyProvider, mockTransportFactory, fakeClock.getScheduledExecutorService(), fakeClock.getStopwatchSupplier(), syncContext, mockInternalSubchannelCallback, channelz, CallTracer.getDefaultFactory().create(), subchannelTracer, logId, new ChannelLoggerImpl(subchannelTracer, fakeClock.getTimeProvider()), - Collections.emptyList()); + Collections.emptyList(), + "", + new MetricRecorder() { + } + ); + } + + @Test + public void subchannelStateChanges_triggersAttemptFailedMetric() { + // 1. Setup: Standard subchannel initialization + when(mockBackoffPolicyProvider.get()).thenReturn(mockBackoffPolicy); + SocketAddress addr = mock(SocketAddress.class); + Attributes eagAttributes = Attributes.newBuilder() + .set(InternalEquivalentAddressGroup.ATTR_BACKEND_SERVICE, BACKEND_SERVICE) + .set(EquivalentAddressGroup.ATTR_LOCALITY_NAME, LOCALITY) + .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SECURITY_LEVEL) + .build(); + List addressGroups = + Arrays.asList(new EquivalentAddressGroup(Arrays.asList(addr), eagAttributes)); + InternalLogId logId = InternalLogId.allocate("Subchannel", /*details=*/ AUTHORITY); + ChannelTracer subchannelTracer = new ChannelTracer(logId, 10, + fakeClock.getTimeProvider().currentTimeNanos(), "Subchannel"); + LoadBalancer.CreateSubchannelArgs createSubchannelArgs = + LoadBalancer.CreateSubchannelArgs.newBuilder().setAddresses(addressGroups).build(); + internalSubchannel = new InternalSubchannel( + createSubchannelArgs, AUTHORITY, USER_AGENT, mockBackoffPolicyProvider, + mockTransportFactory, fakeClock.getScheduledExecutorService(), + fakeClock.getStopwatchSupplier(), syncContext, mockInternalSubchannelCallback, channelz, + CallTracer.getDefaultFactory().create(), subchannelTracer, logId, + new ChannelLoggerImpl(subchannelTracer, fakeClock.getTimeProvider()), + Collections.emptyList(), AUTHORITY, mockMetricRecorder + ); + + // --- Action: Simulate the "connecting to failed" transition --- + // a. Initiate the connection attempt. The subchannel is now CONNECTING. + internalSubchannel.obtainActiveTransport(); + MockClientTransportInfo transportInfo = transports.poll(); + assertNotNull("A connection attempt should have been made", transportInfo); + + // b. Fail the transport before it can signal `transportReady()`. + transportInfo.listener.transportShutdown( + Status.INTERNAL.withDescription("Simulated connect failure"), + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); + fakeClock.runDueTasks(); // Process the failure event + + // --- Verification --- + // a. Verify that the "connection_attempts_failed" metric was recorded exactly once. + verify(mockMetricRecorder).addLongCounter( + eqMetricInstrumentName("grpc.subchannel.connection_attempts_failed"), + eq(1L), + eq(Arrays.asList(AUTHORITY)), + eq(Arrays.asList(BACKEND_SERVICE, LOCALITY)) + ); + + // b. Verify no other metrics were recorded. This confirms it wasn't incorrectly + // logged as a success, disconnection, or open connection. + verifyNoMoreInteractions(mockMetricRecorder); + } + + @Test + public void subchannelStateChanges_triggersSuccessAndDisconnectMetrics() { + // 1. Mock the backoff policy (needed for subchannel creation) + when(mockBackoffPolicyProvider.get()).thenReturn(mockBackoffPolicy); + + // 2. Setup Subchannel with attributes + SocketAddress addr = mock(SocketAddress.class); + Attributes eagAttributes = Attributes.newBuilder() + .set(InternalEquivalentAddressGroup.ATTR_BACKEND_SERVICE, BACKEND_SERVICE) + .set(EquivalentAddressGroup.ATTR_LOCALITY_NAME, LOCALITY) + .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SECURITY_LEVEL) + .build(); + List addressGroups = + Arrays.asList(new EquivalentAddressGroup(Arrays.asList(addr), eagAttributes)); + createInternalSubchannel(new EquivalentAddressGroup(addr)); + InternalLogId logId = InternalLogId.allocate("Subchannel", /*details=*/ AUTHORITY); + ChannelTracer subchannelTracer = new ChannelTracer(logId, 10, + fakeClock.getTimeProvider().currentTimeNanos(), "Subchannel"); + LoadBalancer.CreateSubchannelArgs createSubchannelArgs = + LoadBalancer.CreateSubchannelArgs.newBuilder().setAddresses(addressGroups).build(); + internalSubchannel = new InternalSubchannel( + createSubchannelArgs, AUTHORITY, USER_AGENT, mockBackoffPolicyProvider, + mockTransportFactory, fakeClock.getScheduledExecutorService(), + fakeClock.getStopwatchSupplier(), syncContext, mockInternalSubchannelCallback, channelz, + CallTracer.getDefaultFactory().create(), subchannelTracer, logId, + new ChannelLoggerImpl(subchannelTracer, fakeClock.getTimeProvider()), + Collections.emptyList(), AUTHORITY, mockMetricRecorder + ); + + // --- Action: Successful connection --- + internalSubchannel.obtainActiveTransport(); + MockClientTransportInfo transportInfo = transports.poll(); + assertNotNull(transportInfo); + transportInfo.listener.transportReady(); + fakeClock.runDueTasks(); // Process the successful connection + + // --- Action: Transport is shut down --- + transportInfo.listener.transportShutdown(Status.UNAVAILABLE.withDescription("unknown"), + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); + fakeClock.runDueTasks(); // Process the shutdown + + // --- Verification --- + InOrder inOrder = inOrder(mockMetricRecorder); + + // Verify successful connection metrics + inOrder.verify(mockMetricRecorder).addLongCounter( + eqMetricInstrumentName("grpc.subchannel.connection_attempts_succeeded"), + eq(1L), + eq(Arrays.asList(AUTHORITY)), + eq(Arrays.asList(BACKEND_SERVICE, LOCALITY)) + ); + inOrder.verify(mockMetricRecorder).addLongUpDownCounter( + eqMetricInstrumentName("grpc.subchannel.open_connections"), + eq(1L), + eq(Arrays.asList(AUTHORITY)), + eq(Arrays.asList("privacy_and_integrity", BACKEND_SERVICE, LOCALITY)) + ); + + // Verify disconnection metrics + inOrder.verify(mockMetricRecorder).addLongCounter( + eqMetricInstrumentName("grpc.subchannel.disconnections"), + eq(1L), + eq(Arrays.asList(AUTHORITY)), + eq(Arrays.asList(BACKEND_SERVICE, LOCALITY, "subchannel shutdown")) + ); + inOrder.verify(mockMetricRecorder).addLongUpDownCounter( + eqMetricInstrumentName("grpc.subchannel.open_connections"), + eq(-1L), + eq(Arrays.asList(AUTHORITY)), + eq(Arrays.asList("privacy_and_integrity", BACKEND_SERVICE, LOCALITY)) + ); + + inOrder.verifyNoMoreInteractions(); + } + + @Test + public void subchannelStateChanges_backendServiceFallsBackToResolutionResultAttr() { + when(mockBackoffPolicyProvider.get()).thenReturn(mockBackoffPolicy); + SocketAddress addr = mock(SocketAddress.class); + Attributes eagAttributes = Attributes.newBuilder() + .set(NameResolver.ATTR_BACKEND_SERVICE, BACKEND_SERVICE) + .set(EquivalentAddressGroup.ATTR_LOCALITY_NAME, LOCALITY) + .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SECURITY_LEVEL) + .build(); + List addressGroups = + Arrays.asList(new EquivalentAddressGroup(Arrays.asList(addr), eagAttributes)); + InternalLogId logId = InternalLogId.allocate("Subchannel", /*details=*/ AUTHORITY); + ChannelTracer subchannelTracer = new ChannelTracer(logId, 10, + fakeClock.getTimeProvider().currentTimeNanos(), "Subchannel"); + LoadBalancer.CreateSubchannelArgs createSubchannelArgs = + LoadBalancer.CreateSubchannelArgs.newBuilder().setAddresses(addressGroups).build(); + internalSubchannel = new InternalSubchannel( + createSubchannelArgs, AUTHORITY, USER_AGENT, mockBackoffPolicyProvider, + mockTransportFactory, fakeClock.getScheduledExecutorService(), + fakeClock.getStopwatchSupplier(), syncContext, mockInternalSubchannelCallback, channelz, + CallTracer.getDefaultFactory().create(), subchannelTracer, logId, + new ChannelLoggerImpl(subchannelTracer, fakeClock.getTimeProvider()), + Collections.emptyList(), AUTHORITY, mockMetricRecorder + ); + + internalSubchannel.obtainActiveTransport(); + MockClientTransportInfo transportInfo = transports.poll(); + assertNotNull(transportInfo); + transportInfo.listener.transportReady(); + fakeClock.runDueTasks(); + + verify(mockMetricRecorder).addLongCounter( + eqMetricInstrumentName("grpc.subchannel.connection_attempts_succeeded"), + eq(1L), + eq(Arrays.asList(AUTHORITY)), + eq(Arrays.asList(BACKEND_SERVICE, LOCALITY)) + ); } private void assertNoCallbackInvoke() { @@ -1375,5 +1681,13 @@ private void assertExactCallbackInvokes(String ... expectedInvokes) { callbackInvokes.clear(); } + static class MetricRecorderImpl implements MetricRecorder { + } + + @SuppressWarnings("TypeParameterUnusedInFormals") + private T eqMetricInstrumentName(String name) { + return argThat(instrument -> instrument.getName().equals(name)); + } + private static class FakeSocketAddress extends SocketAddress {} } diff --git a/core/src/test/java/io/grpc/internal/JsonParserTest.java b/core/src/test/java/io/grpc/internal/JsonParserTest.java index 1e74c753d4d..a0dd81c20ce 100644 --- a/core/src/test/java/io/grpc/internal/JsonParserTest.java +++ b/core/src/test/java/io/grpc/internal/JsonParserTest.java @@ -17,15 +17,14 @@ package io.grpc.internal; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import com.google.gson.stream.MalformedJsonException; import java.io.EOFException; import java.io.IOException; import java.util.ArrayList; import java.util.LinkedHashMap; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -35,10 +34,6 @@ @RunWith(JUnit4.class) public class JsonParserTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); - @Test public void emptyObject() throws IOException { assertEquals(new LinkedHashMap(), JsonParser.parse("{}")); @@ -75,45 +70,33 @@ public void nullValue() throws IOException { } @Test - public void nanFails() throws IOException { - thrown.expect(MalformedJsonException.class); - - JsonParser.parse("NaN"); + public void nanFails() { + assertThrows(MalformedJsonException.class, () -> JsonParser.parse("NaN")); } @Test - public void objectEarlyEnd() throws IOException { - thrown.expect(MalformedJsonException.class); - - JsonParser.parse("{foo:}"); + public void objectEarlyEnd() { + assertThrows(MalformedJsonException.class, () -> JsonParser.parse("{foo:}")); } @Test - public void earlyEndArray() throws IOException { - thrown.expect(EOFException.class); - - JsonParser.parse("[1, 2, "); + public void earlyEndArray() { + assertThrows(EOFException.class, () -> JsonParser.parse("[1, 2, ")); } @Test - public void arrayMissingElement() throws IOException { - thrown.expect(MalformedJsonException.class); - - JsonParser.parse("[1, 2, ]"); + public void arrayMissingElement() { + assertThrows(MalformedJsonException.class, () -> JsonParser.parse("[1, 2, ]")); } @Test - public void objectMissingElement() throws IOException { - thrown.expect(MalformedJsonException.class); - - JsonParser.parse("{1: "); + public void objectMissingElement() { + assertThrows(MalformedJsonException.class, () -> JsonParser.parse("{1: ")); } @Test - public void objectNoName() throws IOException { - thrown.expect(MalformedJsonException.class); - - JsonParser.parse("{: 1"); + public void objectNoName() { + assertThrows(MalformedJsonException.class, () -> JsonParser.parse("{: 1")); } @Test @@ -123,4 +106,9 @@ public void objectStringName() throws IOException { assertEquals(expected, JsonParser.parse("{\"hi\": 2}")); } + + @Test + public void duplicate() { + assertThrows(IllegalArgumentException.class, () -> JsonParser.parse("{\"hi\": 2, \"hi\": 3}")); + } } diff --git a/core/src/test/java/io/grpc/internal/KeepAliveManagerTest.java b/core/src/test/java/io/grpc/internal/KeepAliveManagerTest.java index 411a9fbe9fc..81e3d1b2638 100644 --- a/core/src/test/java/io/grpc/internal/KeepAliveManagerTest.java +++ b/core/src/test/java/io/grpc/internal/KeepAliveManagerTest.java @@ -19,6 +19,7 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -104,13 +105,15 @@ public void keepAlivePingDelayedByIncomingData() { @Test public void clientKeepAlivePinger_pingTimeout() { - ConnectionClientTransport transport = mock(ConnectionClientTransport.class); + ClientKeepAlivePinger.TransportWithDisconnectReason transport = + mock(ClientKeepAlivePinger.TransportWithDisconnectReason.class); ClientKeepAlivePinger pinger = new ClientKeepAlivePinger(transport); pinger.onPingTimeout(); ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); - verify(transport).shutdownNow(statusCaptor.capture()); + verify(transport).shutdownNow(statusCaptor.capture(), + eq(SimpleDisconnectError.CONNECTION_TIMED_OUT)); Status status = statusCaptor.getValue(); assertThat(status.getCode()).isEqualTo(Status.Code.UNAVAILABLE); assertThat(status.getDescription()).isEqualTo( @@ -119,7 +122,8 @@ public void clientKeepAlivePinger_pingTimeout() { @Test public void clientKeepAlivePinger_pingFailure() { - ConnectionClientTransport transport = mock(ConnectionClientTransport.class); + ClientKeepAlivePinger.TransportWithDisconnectReason transport = + mock(ClientKeepAlivePinger.TransportWithDisconnectReason.class); ClientKeepAlivePinger pinger = new ClientKeepAlivePinger(transport); pinger.ping(); ArgumentCaptor pingCallbackCaptor = @@ -127,10 +131,11 @@ public void clientKeepAlivePinger_pingFailure() { verify(transport).ping(pingCallbackCaptor.capture(), isA(Executor.class)); ClientTransport.PingCallback pingCallback = pingCallbackCaptor.getValue(); - pingCallback.onFailure(new Throwable()); + pingCallback.onFailure(Status.UNAVAILABLE.withDescription("I must write descriptions")); ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); - verify(transport).shutdownNow(statusCaptor.capture()); + verify(transport).shutdownNow(statusCaptor.capture(), + eq(SimpleDisconnectError.CONNECTION_TIMED_OUT)); Status status = statusCaptor.getValue(); assertThat(status.getCode()).isEqualTo(Status.Code.UNAVAILABLE); assertThat(status.getDescription()).isEqualTo( diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java index faa03b20319..b0939239477 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java @@ -22,6 +22,8 @@ import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Mockito.doReturn; @@ -36,13 +38,18 @@ import io.grpc.ClientInterceptor; import io.grpc.CompressorRegistry; import io.grpc.DecompressorRegistry; -import io.grpc.InternalGlobalInterceptors; +import io.grpc.FlagResetRule; +import io.grpc.InternalConfigurator; +import io.grpc.InternalConfiguratorRegistry; +import io.grpc.InternalFeatureFlags; +import io.grpc.InternalManagedChannelBuilder.InternalInterceptorFactory; import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; import io.grpc.MethodDescriptor; +import io.grpc.MetricSink; import io.grpc.NameResolver; import io.grpc.NameResolverRegistry; import io.grpc.StaticTestingClassLoader; -import io.grpc.inprocess.InProcessSocketAddress; import io.grpc.internal.ManagedChannelImplBuilder.ChannelBuilderDefaultPortProvider; import io.grpc.internal.ManagedChannelImplBuilder.ClientTransportFactoryBuilder; import io.grpc.internal.ManagedChannelImplBuilder.FixedPortProvider; @@ -63,15 +70,16 @@ import org.junit.Before; import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; /** Unit tests for {@link ManagedChannelImplBuilder}. */ -@RunWith(JUnit4.class) +@RunWith(Parameterized.class) public class ManagedChannelImplBuilderTest { private static final int DUMMY_PORT = 42; private static final String DUMMY_TARGET = "fake-target"; @@ -94,10 +102,16 @@ public ClientCall interceptCall( } }; + @Parameters(name = "enableRfc3986UrisParam={0}") + public static Iterable data() { + return Arrays.asList(new Object[][] {{true}, {false}}); + } + + @Parameter public boolean enableRfc3986UrisParam; + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); @Rule public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule(); + @Rule public final FlagResetRule flagResetRule = new FlagResetRule(); @Mock private ClientTransportFactory mockClientTransportFactory; @Mock private ClientTransportFactoryBuilder mockClientTransportFactoryBuilder; @@ -109,11 +123,15 @@ public ClientCall interceptCall( new StaticTestingClassLoader( getClass().getClassLoader(), Pattern.compile( - "io\\.grpc\\.InternalGlobalInterceptors|io\\.grpc\\.GlobalInterceptors|" + "io\\.grpc\\.InternalConfigurator|io\\.grpc\\.Configurator|" + + "io\\.grpc\\.InternalConfiguratorRegistry|io\\.grpc\\.ConfiguratorRegistry|" + "io\\.grpc\\.internal\\.[^.]+")); @Before public void setUp() throws Exception { + flagResetRule.setFlagForTest( + InternalFeatureFlags::setRfc3986UrisEnabled, enableRfc3986UrisParam); + builder = new ManagedChannelImplBuilder( DUMMY_TARGET, new UnsupportedClientTransportFactoryBuilder(), @@ -165,37 +183,37 @@ public void executor_default() { @Test public void executor_normal() { Executor executor = mock(Executor.class); - assertEquals(builder, builder.executor(executor)); - assertEquals(executor, builder.executorPool.getObject()); + assertSame(builder, builder.executor(executor)); + assertSame(executor, builder.executorPool.getObject()); } @Test public void executor_null() { ObjectPool defaultValue = builder.executorPool; builder.executor(mock(Executor.class)); - assertEquals(builder, builder.executor(null)); - assertEquals(defaultValue, builder.executorPool); + assertSame(builder, builder.executor(null)); + assertSame(defaultValue, builder.executorPool); } @Test public void directExecutor() { - assertEquals(builder, builder.directExecutor()); + assertSame(builder, builder.directExecutor()); assertEquals(MoreExecutors.directExecutor(), builder.executorPool.getObject()); } @Test public void offloadExecutor_normal() { Executor executor = mock(Executor.class); - assertEquals(builder, builder.offloadExecutor(executor)); - assertEquals(executor, builder.offloadExecutorPool.getObject()); + assertSame(builder, builder.offloadExecutor(executor)); + assertSame(executor, builder.offloadExecutorPool.getObject()); } @Test public void offloadExecutor_null() { ObjectPool defaultValue = builder.offloadExecutorPool; builder.offloadExecutor(mock(Executor.class)); - assertEquals(builder, builder.offloadExecutor(null)); - assertEquals(defaultValue, builder.offloadExecutorPool); + assertSame(builder, builder.offloadExecutor(null)); + assertSame(defaultValue, builder.offloadExecutorPool); } @Test @@ -208,7 +226,7 @@ public void nameResolverRegistry_default() { public void nameResolverFactory_normal() { NameResolver.Factory nameResolverFactory = mock(NameResolver.Factory.class); doReturn("testscheme").when(nameResolverFactory).getDefaultScheme(); - assertEquals(builder, builder.nameResolverFactory(nameResolverFactory)); + assertSame(builder, builder.nameResolverFactory(nameResolverFactory)); assertNotNull(builder.nameResolverRegistry); assertEquals("testscheme", builder.nameResolverRegistry.asFactory().getDefaultScheme()); } @@ -238,7 +256,7 @@ public void defaultLoadBalancingPolicy_default() { @Test public void defaultLoadBalancingPolicy_normal() { - assertEquals(builder, builder.defaultLoadBalancingPolicy("magic_balancer")); + assertSame(builder, builder.defaultLoadBalancingPolicy("magic_balancer")); assertEquals("magic_balancer", builder.defaultLbPolicy); } @@ -266,14 +284,14 @@ public void decompressorRegistry_default() { public void decompressorRegistry_normal() { DecompressorRegistry decompressorRegistry = DecompressorRegistry.emptyInstance(); assertNotEquals(decompressorRegistry, builder.decompressorRegistry); - assertEquals(builder, builder.decompressorRegistry(decompressorRegistry)); + assertSame(builder, builder.decompressorRegistry(decompressorRegistry)); assertEquals(decompressorRegistry, builder.decompressorRegistry); } @Test public void decompressorRegistry_null() { DecompressorRegistry defaultValue = builder.decompressorRegistry; - assertEquals(builder, builder.decompressorRegistry(DecompressorRegistry.emptyInstance())); + assertSame(builder, builder.decompressorRegistry(DecompressorRegistry.emptyInstance())); assertNotEquals(defaultValue, builder.decompressorRegistry); builder.decompressorRegistry(null); assertEquals(defaultValue, builder.decompressorRegistry); @@ -288,8 +306,8 @@ public void compressorRegistry_default() { public void compressorRegistry_normal() { CompressorRegistry compressorRegistry = CompressorRegistry.newEmptyInstance(); assertNotEquals(compressorRegistry, builder.compressorRegistry); - assertEquals(builder, builder.compressorRegistry(compressorRegistry)); - assertEquals(compressorRegistry, builder.compressorRegistry); + assertSame(builder, builder.compressorRegistry(compressorRegistry)); + assertSame(compressorRegistry, builder.compressorRegistry); } @Test @@ -297,8 +315,8 @@ public void compressorRegistry_null() { CompressorRegistry defaultValue = builder.compressorRegistry; builder.compressorRegistry(CompressorRegistry.newEmptyInstance()); assertNotEquals(defaultValue, builder.compressorRegistry); - assertEquals(builder, builder.compressorRegistry(null)); - assertEquals(defaultValue, builder.compressorRegistry); + assertSame(builder, builder.compressorRegistry(null)); + assertSame(defaultValue, builder.compressorRegistry); } @Test @@ -309,13 +327,13 @@ public void userAgent_default() { @Test public void userAgent_normal() { String userAgent = "user-agent/1"; - assertEquals(builder, builder.userAgent(userAgent)); - assertEquals(userAgent, builder.userAgent); + assertSame(builder, builder.userAgent(userAgent)); + assertSame(userAgent, builder.userAgent); } @Test public void userAgent_null() { - assertEquals(builder, builder.userAgent(null)); + assertSame(builder, builder.userAgent(null)); assertNull(builder.userAgent); builder.userAgent("user-agent/1"); @@ -362,7 +380,7 @@ public void transportDoesNotSupportAddressTypes() { when(mockClientTransportFactoryBuilder.buildClientTransportFactory()) .thenReturn(mockClientTransportFactory); when(mockClientTransportFactory.getSupportedSocketAddressTypes()) - .thenReturn(Collections.singleton(InProcessSocketAddress.class)); + .thenReturn(Collections.singleton(CustomSocketAddress.class)); builder = new ManagedChannelImplBuilder(DUMMY_AUTHORITY_VALID, mockClientTransportFactoryBuilder, new FixedPortProvider(DUMMY_PORT)); @@ -370,8 +388,11 @@ public void transportDoesNotSupportAddressTypes() { ManagedChannel unused = grpcCleanupRule.register(builder.build()); fail("Should fail"); } catch (IllegalArgumentException e) { - assertThat(e).hasMessageThat().isEqualTo( - "Address types of NameResolver 'dns' for 'valid:1234' not supported by transport"); + assertThat(e) + .hasMessageThat() + .isEqualTo( + "Address types of NameResolver 'dns' for 'dns:///valid:1234' not supported by" + + " transport"); } } @@ -398,8 +419,8 @@ public void overrideAuthority_default() { @Test public void overrideAuthority_normal() { String overrideAuthority = "best-authority"; - assertEquals(builder, builder.overrideAuthority(overrideAuthority)); - assertEquals(overrideAuthority, builder.authorityOverride); + assertSame(builder, builder.overrideAuthority(overrideAuthority)); + assertSame(overrideAuthority, builder.authorityOverride); } @Test(expected = NullPointerException.class) @@ -419,10 +440,9 @@ public void checkAuthority_validAuthorityAllowed() { @Test public void checkAuthority_invalidAuthorityFailed() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Invalid authority"); - - builder.checkAuthority(DUMMY_AUTHORITY_INVALID); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.checkAuthority(DUMMY_AUTHORITY_INVALID)); + assertThat(e).hasMessageThat().isEqualTo("Invalid authority: [ : : 1]"); } @Test @@ -445,11 +465,10 @@ public void enableCheckAuthority_validAuthorityAllowed() { @Test public void disableCheckAuthority_invalidAuthorityFailed() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Invalid authority"); - builder.disableCheckAuthority().enableCheckAuthority(); - builder.checkAuthority(DUMMY_AUTHORITY_INVALID); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.checkAuthority(DUMMY_AUTHORITY_INVALID)); + assertThat(e).hasMessageThat().isEqualTo("Invalid authority: [ : : 1]"); } @Test @@ -469,7 +488,7 @@ public void makeTargetStringForDirectAddress_scopedIpv6() throws Exception { @Test public void getEffectiveInterceptors_default() { builder.intercept(DUMMY_USER_INTERCEPTOR); - List effectiveInterceptors = builder.getEffectiveInterceptors(); + List effectiveInterceptors = builder.getEffectiveInterceptors("unused:///"); assertEquals(3, effectiveInterceptors.size()); assertThat(effectiveInterceptors.get(0).getClass().getName()) .isEqualTo("io.grpc.census.CensusTracingModule$TracingClientInterceptor"); @@ -482,7 +501,7 @@ public void getEffectiveInterceptors_default() { public void getEffectiveInterceptors_disableStats() { builder.intercept(DUMMY_USER_INTERCEPTOR); builder.setStatsEnabled(false); - List effectiveInterceptors = builder.getEffectiveInterceptors(); + List effectiveInterceptors = builder.getEffectiveInterceptors("unused:///"); assertEquals(2, effectiveInterceptors.size()); assertThat(effectiveInterceptors.get(0).getClass().getName()) .isEqualTo("io.grpc.census.CensusTracingModule$TracingClientInterceptor"); @@ -493,7 +512,7 @@ public void getEffectiveInterceptors_disableStats() { public void getEffectiveInterceptors_disableTracing() { builder.intercept(DUMMY_USER_INTERCEPTOR); builder.setTracingEnabled(false); - List effectiveInterceptors = builder.getEffectiveInterceptors(); + List effectiveInterceptors = builder.getEffectiveInterceptors("unused:///"); assertEquals(2, effectiveInterceptors.size()); assertThat(effectiveInterceptors.get(0).getClass().getName()) .isEqualTo("io.grpc.census.CensusStatsModule$StatsClientInterceptor"); @@ -505,12 +524,12 @@ public void getEffectiveInterceptors_disableBoth() { builder.intercept(DUMMY_USER_INTERCEPTOR); builder.setStatsEnabled(false); builder.setTracingEnabled(false); - List effectiveInterceptors = builder.getEffectiveInterceptors(); + List effectiveInterceptors = builder.getEffectiveInterceptors("unused:///"); assertThat(effectiveInterceptors).containsExactly(DUMMY_USER_INTERCEPTOR); } @Test - public void getEffectiveInterceptors_callsGetGlobalInterceptors() throws Exception { + public void getEffectiveInterceptors_callsGetConfiguratorRegistry() throws Exception { Class runnable = classLoader.loadClass(StaticTestingClassLoaderCallsGet.class.getName()); ((Runnable) runnable.getDeclaredConstructor().newInstance()).run(); } @@ -525,22 +544,17 @@ public void run() { DUMMY_TARGET, new UnsupportedClientTransportFactoryBuilder(), new FixedPortProvider(DUMMY_PORT)); - List effectiveInterceptors = builder.getEffectiveInterceptors(); + List effectiveInterceptors = + builder.getEffectiveInterceptors("unused:///"); assertThat(effectiveInterceptors).hasSize(2); - try { - InternalGlobalInterceptors.setInterceptorsTracers( - Arrays.asList(DUMMY_USER_INTERCEPTOR), - Collections.emptyList(), - Collections.emptyList()); - fail("exception expected"); - } catch (IllegalStateException e) { - assertThat(e).hasMessageThat().contains("Set cannot be called after any get call"); - } + InternalConfiguratorRegistry.setConfigurators(Collections.emptyList()); + assertThat(InternalConfiguratorRegistry.getConfigurators()).isEmpty(); + assertThat(InternalConfiguratorRegistry.getConfiguratorsCallCountBeforeSet()).isEqualTo(1); } } @Test - public void getEffectiveInterceptors_callsSetGlobalInterceptors() throws Exception { + public void getEffectiveInterceptors_callsSetConfiguratorRegistry() throws Exception { Class runnable = classLoader.loadClass(StaticTestingClassLoaderCallsSet.class.getName()); ((Runnable) runnable.getDeclaredConstructor().newInstance()).run(); } @@ -550,23 +564,27 @@ public static final class StaticTestingClassLoaderCallsSet implements Runnable { @Override public void run() { - InternalGlobalInterceptors.setInterceptorsTracers( - Arrays.asList(DUMMY_USER_INTERCEPTOR, DUMMY_USER_INTERCEPTOR1), - Collections.emptyList(), - Collections.emptyList()); + InternalConfiguratorRegistry.setConfigurators( + Arrays.asList(new InternalConfigurator() { + @Override + public void configureChannelBuilder(ManagedChannelBuilder builder) { + builder.intercept(DUMMY_USER_INTERCEPTOR, DUMMY_USER_INTERCEPTOR1); + } + })); ManagedChannelImplBuilder builder = new ManagedChannelImplBuilder( DUMMY_TARGET, new UnsupportedClientTransportFactoryBuilder(), new FixedPortProvider(DUMMY_PORT)); - List effectiveInterceptors = builder.getEffectiveInterceptors(); + List effectiveInterceptors = + builder.getEffectiveInterceptors("unused:///"); assertThat(effectiveInterceptors) .containsExactly(DUMMY_USER_INTERCEPTOR, DUMMY_USER_INTERCEPTOR1); } } @Test - public void getEffectiveInterceptors_setEmptyGlobalInterceptors() throws Exception { + public void getEffectiveInterceptors_setEmptyConfiguratorRegistry() throws Exception { Class runnable = classLoader.loadClass(StaticTestingClassLoaderCallsSetEmpty.class.getName()); ((Runnable) runnable.getDeclaredConstructor().newInstance()).run(); @@ -577,18 +595,41 @@ public static final class StaticTestingClassLoaderCallsSetEmpty implements Runna @Override public void run() { - InternalGlobalInterceptors.setInterceptorsTracers( - Collections.emptyList(), Collections.emptyList(), Collections.emptyList()); + InternalConfiguratorRegistry.setConfigurators(Collections.emptyList()); ManagedChannelImplBuilder builder = new ManagedChannelImplBuilder( DUMMY_TARGET, new UnsupportedClientTransportFactoryBuilder(), new FixedPortProvider(DUMMY_PORT)); - List effectiveInterceptors = builder.getEffectiveInterceptors(); + List effectiveInterceptors = + builder.getEffectiveInterceptors("unused:///"); assertThat(effectiveInterceptors).isEmpty(); } } + @Test + public void getEffectiveInterceptors_createsFromInterceptorFactories() throws Exception { + String target = "dns:///the-host"; + builder.setStatsEnabled(false); + builder.setTracingEnabled(false); + + builder.intercept(DUMMY_USER_INTERCEPTOR) + .interceptWithTarget(new InternalInterceptorFactory() { + @Override + public ClientInterceptor newInterceptor(String passedTarget) { + assertThat(passedTarget).isEqualTo(target); + return DUMMY_USER_INTERCEPTOR1; + } + }) + .intercept(DUMMY_USER_INTERCEPTOR); + + assertThat(builder.getEffectiveInterceptors(target)) + .isEqualTo(Arrays.asList( + DUMMY_USER_INTERCEPTOR, + DUMMY_USER_INTERCEPTOR1, + DUMMY_USER_INTERCEPTOR)); + } + @Test public void idleTimeout() { assertEquals(ManagedChannelImplBuilder.IDLE_MODE_DEFAULT_TIMEOUT_MILLIS, @@ -650,14 +691,12 @@ public void perRpcBufferLimit() { @Test public void retryBufferSizeInvalidArg() { - thrown.expect(IllegalArgumentException.class); - builder.retryBufferSize(0L); + assertThrows(IllegalArgumentException.class, () -> builder.retryBufferSize(0L)); } @Test public void perRpcBufferLimitInvalidArg() { - thrown.expect(IllegalArgumentException.class); - builder.perRpcBufferLimit(0L); + assertThrows(IllegalArgumentException.class, () -> builder.perRpcBufferLimit(0L)); } @Test @@ -680,8 +719,7 @@ public void defaultServiceConfig_nullKey() { Map config = new HashMap<>(); config.put(null, "val"); - thrown.expect(IllegalArgumentException.class); - builder.defaultServiceConfig(config); + assertThrows(IllegalArgumentException.class, () -> builder.defaultServiceConfig(config)); } @Test @@ -691,8 +729,7 @@ public void defaultServiceConfig_intKey() { Map config = new HashMap<>(); config.put("key", subConfig); - thrown.expect(IllegalArgumentException.class); - builder.defaultServiceConfig(config); + assertThrows(IllegalArgumentException.class, () -> builder.defaultServiceConfig(config)); } @Test @@ -700,8 +737,7 @@ public void defaultServiceConfig_intValue() { Map config = new HashMap<>(); config.put("key", 3); - thrown.expect(IllegalArgumentException.class); - builder.defaultServiceConfig(config); + assertThrows(IllegalArgumentException.class, () -> builder.defaultServiceConfig(config)); } @Test @@ -733,4 +769,35 @@ public void disableNameResolverServiceConfig() { builder.disableServiceConfigLookUp(); assertThat(builder.lookUpServiceConfig).isFalse(); } + + @Test + public void setNameResolverExtArgs() { + assertThat(builder.nameResolverCustomArgs) + .isNull(); + + NameResolver.Args.Key testKey = NameResolver.Args.Key.create("test-key"); + builder.setNameResolverArg(testKey, 42); + assertThat(builder.nameResolverCustomArgs.get(testKey)).isEqualTo(42); + } + + @Test + public void metricSinks() { + MetricSink mocksink = mock(MetricSink.class); + builder.addMetricSink(mocksink); + + assertThat(builder.metricSinks).contains(mocksink); + } + + @Test + public void uriPattern() { + Pattern uriPattern = ManagedChannelImplBuilder.URI_PATTERN; + assertTrue(uriPattern.matcher("a:/").matches()); + assertTrue(uriPattern.matcher("Z019+-.:/!@ #~ ").matches()); + assertFalse(uriPattern.matcher("a/:").matches()); // "/:" not matched + assertFalse(uriPattern.matcher("0a:/").matches()); // '0' not matched + assertFalse(uriPattern.matcher("a,:/").matches()); // ',' not matched + assertFalse(uriPattern.matcher(" a:/").matches()); // space not matched + } + + private static class CustomSocketAddress extends SocketAddress {} } diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverRfc3986Test.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverRfc3986Test.java new file mode 100644 index 00000000000..5bcf24a30e2 --- /dev/null +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverRfc3986Test.java @@ -0,0 +1,243 @@ +/* + * Copyright 2015 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.internal.UriWrapper.wrap; +import static org.junit.Assert.fail; + +import io.grpc.NameResolver; +import io.grpc.NameResolverProvider; +import io.grpc.NameResolverRegistry; +import io.grpc.Uri; +import java.net.SocketAddress; +import java.net.URI; +import java.util.Collections; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for ManagedChannelImplBuilder#getNameResolverProviderNew(). */ +@RunWith(JUnit4.class) +public class ManagedChannelImplGetNameResolverRfc3986Test { + @Test + public void invalidUriTarget() { + testInvalidTarget("defaultscheme:///[invalid]"); + } + + @Test + public void invalidUnescapedSquareBracketsInRfc3986UriFragment() { + testInvalidTarget("defaultscheme://8.8.8.8/host#section[1]"); + } + + @Test + public void invalidUnescapedSquareBracketsInRfc3986UriQuery() { + testInvalidTarget("dns://8.8.8.8/path?section=[1]"); + } + + @Test + public void validTargetWithInvalidDnsName() throws Exception { + testValidTarget( + "[valid]", + "defaultscheme:///%5Bvalid%5D", + Uri.newBuilder().setScheme("defaultscheme").setHost("").setPath("/[valid]").build()); + } + + @Test + public void validAuthorityTarget() throws Exception { + testValidTarget( + "foo.googleapis.com:8080", + "defaultscheme:///foo.googleapis.com:8080", + Uri.newBuilder() + .setScheme("defaultscheme") + .setHost("") + .setPath("/foo.googleapis.com:8080") + .build()); + } + + @Test + public void validUriTarget() throws Exception { + testValidTarget( + "scheme:///foo.googleapis.com:8080", + "scheme:///foo.googleapis.com:8080", + Uri.newBuilder() + .setScheme("scheme") + .setHost("") + .setPath("/foo.googleapis.com:8080") + .build()); + } + + @Test + public void validIpv4AuthorityTarget() throws Exception { + testValidTarget( + "127.0.0.1:1234", + "defaultscheme:///127.0.0.1:1234", + Uri.newBuilder().setScheme("defaultscheme").setHost("").setPath("/127.0.0.1:1234").build()); + } + + @Test + public void validIpv4UriTarget() throws Exception { + testValidTarget( + "dns:///127.0.0.1:1234", + "dns:///127.0.0.1:1234", + Uri.newBuilder().setScheme("dns").setHost("").setPath("/127.0.0.1:1234").build()); + } + + @Test + public void validIpv6AuthorityTarget() throws Exception { + testValidTarget( + "[::1]:1234", + "defaultscheme:///%5B::1%5D:1234", + Uri.newBuilder().setScheme("defaultscheme").setHost("").setPath("/[::1]:1234").build()); + } + + @Test + public void invalidIpv6UriTarget() throws Exception { + testInvalidTarget("dns:///[::1]:1234"); + } + + @Test + public void invalidIpv6UriWithUnescapedScope() { + testInvalidTarget("dns://[::1%eth0]:53/host"); + } + + @Test + public void validIpv6UriTarget() throws Exception { + testValidTarget( + "dns:///%5B::1%5D:1234", + "dns:///%5B::1%5D:1234", + Uri.newBuilder().setScheme("dns").setHost("").setPath("/[::1]:1234").build()); + } + + @Test + public void validTargetStartingWithSlash() throws Exception { + testValidTarget( + "/target", + "defaultscheme:////target", + Uri.newBuilder().setScheme("defaultscheme").setHost("").setPath("//target").build()); + } + + @Test + public void validTargetNoProvider() { + NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); + try { + ManagedChannelImplBuilder.getNameResolverProviderRfc3986( + "foo.googleapis.com:8080", nameResolverRegistry); + fail("Should fail"); + } catch (IllegalArgumentException e) { + // expected + } + } + + @Test + public void validTargetProviderAddrTypesNotSupported() { + NameResolverRegistry nameResolverRegistry = getTestRegistry("testscheme"); + try { + ManagedChannelImplBuilder.getNameResolverProviderRfc3986( + "testscheme:///foo.googleapis.com:8080", nameResolverRegistry) + .checkAddressTypes(Collections.singleton(CustomSocketAddress.class)); + fail("Should fail"); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .isEqualTo( + "Address types of NameResolver 'testscheme' for " + + "'testscheme:///foo.googleapis.com:8080' not supported by transport"); + } + } + + private void testValidTarget(String target, String expectedUriString, Uri expectedUri) { + NameResolverRegistry nameResolverRegistry = getTestRegistry(expectedUri.getScheme()); + ManagedChannelImplBuilder.ResolvedNameResolver resolved = + ManagedChannelImplBuilder.getNameResolverProviderRfc3986(target, nameResolverRegistry); + assertThat(resolved.provider).isInstanceOf(FakeNameResolverProvider.class); + assertThat(resolved.targetUri).isEqualTo(wrap(expectedUri)); + assertThat(resolved.targetUri.toString()).isEqualTo(expectedUriString); + } + + private void testInvalidTarget(String target) { + NameResolverRegistry nameResolverRegistry = getTestRegistry("dns"); + + try { + ManagedChannelImplBuilder.ResolvedNameResolver resolved = + ManagedChannelImplBuilder.getNameResolverProviderRfc3986(target, nameResolverRegistry); + FakeNameResolverProvider nameResolverProvider = (FakeNameResolverProvider) resolved.provider; + fail("Should have failed, but got resolver provider " + nameResolverProvider); + } catch (IllegalArgumentException e) { + // expected + } + } + + private static NameResolverRegistry getTestRegistry(String expectedScheme) { + NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); + FakeNameResolverProvider nameResolverProvider = new FakeNameResolverProvider(expectedScheme); + nameResolverRegistry.register(nameResolverProvider); + return nameResolverRegistry; + } + + private static class FakeNameResolverProvider extends NameResolverProvider { + final String expectedScheme; + + FakeNameResolverProvider(String expectedScheme) { + this.expectedScheme = expectedScheme; + } + + @Override + public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) { + if (expectedScheme.equals(targetUri.getScheme())) { + return new FakeNameResolver(targetUri); + } + return null; + } + + @Override + public String getDefaultScheme() { + return expectedScheme; + } + + @Override + protected boolean isAvailable() { + return true; + } + + @Override + protected int priority() { + return 5; + } + } + + private static class FakeNameResolver extends NameResolver { + final URI uri; + + FakeNameResolver(URI uri) { + this.uri = uri; + } + + @Override + public String getServiceAuthority() { + return uri.getAuthority(); + } + + @Override + public void start(final Listener2 listener) {} + + @Override + public void shutdown() {} + } + + private static class CustomSocketAddress extends SocketAddress {} +} diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverTest.java index 452e071912c..792f4daca4e 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverTest.java @@ -17,45 +17,42 @@ package io.grpc.internal; import static com.google.common.truth.Truth.assertThat; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; +import static io.grpc.internal.UriWrapper.wrap; import static org.junit.Assert.fail; -import static org.mockito.Mockito.mock; -import io.grpc.ChannelLogger; import io.grpc.NameResolver; -import io.grpc.NameResolver.Args; -import io.grpc.NameResolver.ServiceConfigParser; import io.grpc.NameResolverProvider; import io.grpc.NameResolverRegistry; -import io.grpc.ProxyDetector; -import io.grpc.SynchronizationContext; -import io.grpc.inprocess.InProcessSocketAddress; -import java.lang.Thread.UncaughtExceptionHandler; -import java.net.InetSocketAddress; +import java.net.SocketAddress; import java.net.URI; import java.util.Collections; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -/** Unit tests for ManagedChannelImpl#getNameResolver(). */ +/** Unit tests for ManagedChannelImplBuilder#getNameResolverProvider(). */ @RunWith(JUnit4.class) public class ManagedChannelImplGetNameResolverTest { - private static final NameResolver.Args NAMERESOLVER_ARGS = NameResolver.Args.newBuilder() - .setDefaultPort(447) - .setProxyDetector(mock(ProxyDetector.class)) - .setSynchronizationContext(new SynchronizationContext(mock(UncaughtExceptionHandler.class))) - .setServiceConfigParser(mock(ServiceConfigParser.class)) - .setChannelLogger(mock(ChannelLogger.class)) - .setScheduledExecutorService(new FakeClock().getScheduledExecutorService()) - .build(); - @Test public void invalidUriTarget() { testInvalidTarget("defaultscheme:///[invalid]"); } + @Test + public void validSquareBracketsInRfc2396UriFragment() throws Exception { + testValidTarget("dns://8.8.8.8/host#section[1]", + "dns://8.8.8.8/host#section[1]", + new URI("dns", "8.8.8.8", "/host", null, "section[1]")); + } + + + @Test + public void validSquareBracketsInRfc2396UriQuery() throws Exception { + testValidTarget("dns://8.8.8.8/host?section=[1]", + "dns://8.8.8.8/host?section=[1]", + new URI("dns", "8.8.8.8", "/host", "section=[1]", null)); + } + @Test public void validTargetWithInvalidDnsName() throws Exception { testValidTarget("[valid]", "defaultscheme:///%5Bvalid%5D", @@ -68,18 +65,6 @@ public void validAuthorityTarget() throws Exception { new URI("defaultscheme", "", "/foo.googleapis.com:8080", null)); } - @Test - public void validAuthorityTarget_overrideAuthority() throws Exception { - String target = "foo.googleapis.com:8080"; - String overrideAuthority = "override.authority"; - URI expectedUri = new URI("defaultscheme", "", "/foo.googleapis.com:8080", null); - NameResolverRegistry nameResolverRegistry = getTestRegistry(expectedUri.getScheme()); - NameResolver nameResolver = ManagedChannelImpl.getNameResolver( - target, overrideAuthority, nameResolverRegistry, NAMERESOLVER_ARGS, - Collections.singleton(InetSocketAddress.class)); - assertThat(nameResolver.getServiceAuthority()).isEqualTo(overrideAuthority); - } - @Test public void validUriTarget() throws Exception { testValidTarget("scheme:///foo.googleapis.com:8080", "scheme:///foo.googleapis.com:8080", @@ -104,6 +89,13 @@ public void validIpv6AuthorityTarget() throws Exception { new URI("defaultscheme", "", "/[::1]:1234", null)); } + @Test + public void validIpv6UriWithJavaNetUriScopeName() throws Exception { + testValidTarget("dns://[::1%eth0]:53/host", + "dns://[::1%eth0]:53/host", + new URI("dns", "[::1%eth0]:53", "/host", null, null)); + } + @Test public void invalidIpv6UriTarget() throws Exception { testInvalidTarget("dns:///[::1]:1234"); @@ -121,48 +113,12 @@ public void validTargetStartingWithSlash() throws Exception { new URI("defaultscheme", "", "//target", null)); } - @Test - public void validTargetNoResolver() { - NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); - NameResolverProvider nameResolverProvider = new NameResolverProvider() { - @Override - protected boolean isAvailable() { - return true; - } - - @Override - protected int priority() { - return 5; - } - - @Override - public NameResolver newNameResolver(URI targetUri, Args args) { - return null; - } - - @Override - public String getDefaultScheme() { - return "defaultscheme"; - } - }; - nameResolverRegistry.register(nameResolverProvider); - try { - ManagedChannelImpl.getNameResolver( - "foo.googleapis.com:8080", null, nameResolverRegistry, NAMERESOLVER_ARGS, - Collections.singleton(InetSocketAddress.class)); - fail("Should fail"); - } catch (IllegalArgumentException e) { - // expected - } - } - @Test public void validTargetNoProvider() { NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); try { - ManagedChannelImpl.getNameResolver( - "foo.googleapis.com:8080", null, nameResolverRegistry, NAMERESOLVER_ARGS, - Collections.singleton(InetSocketAddress.class)); + ManagedChannelImplBuilder.getNameResolverProvider( + "foo.googleapis.com:8080", nameResolverRegistry); fail("Should fail"); } catch (IllegalArgumentException e) { // expected @@ -173,9 +129,9 @@ public void validTargetNoProvider() { public void validTargetProviderAddrTypesNotSupported() { NameResolverRegistry nameResolverRegistry = getTestRegistry("testscheme"); try { - ManagedChannelImpl.getNameResolver( - "testscheme:///foo.googleapis.com:8080", null, nameResolverRegistry, NAMERESOLVER_ARGS, - Collections.singleton(InProcessSocketAddress.class)); + ManagedChannelImplBuilder.getNameResolverProvider( + "testscheme:///foo.googleapis.com:8080", nameResolverRegistry) + .checkAddressTypes(Collections.singleton(CustomSocketAddress.class)); fail("Should fail"); } catch (IllegalArgumentException e) { assertThat(e).hasMessageThat().isEqualTo( @@ -184,26 +140,23 @@ public void validTargetProviderAddrTypesNotSupported() { } } - private void testValidTarget(String target, String expectedUriString, URI expectedUri) { NameResolverRegistry nameResolverRegistry = getTestRegistry(expectedUri.getScheme()); - FakeNameResolver nameResolver - = (FakeNameResolver) ((RetryingNameResolver) ManagedChannelImpl.getNameResolver( - target, null, nameResolverRegistry, NAMERESOLVER_ARGS, - Collections.singleton(InetSocketAddress.class))).getRetriedNameResolver(); - assertNotNull(nameResolver); - assertEquals(expectedUri, nameResolver.uri); - assertEquals(expectedUriString, nameResolver.uri.toString()); + ManagedChannelImplBuilder.ResolvedNameResolver resolved = + ManagedChannelImplBuilder.getNameResolverProvider(target, nameResolverRegistry); + assertThat(resolved.provider).isInstanceOf(FakeNameResolverProvider.class); + assertThat(resolved.targetUri).isEqualTo(wrap(expectedUri)); + assertThat(resolved.targetUri.toString()).isEqualTo(expectedUriString); } private void testInvalidTarget(String target) { NameResolverRegistry nameResolverRegistry = getTestRegistry("dns"); try { - FakeNameResolver nameResolver = (FakeNameResolver) ManagedChannelImpl.getNameResolver( - target, null, nameResolverRegistry, NAMERESOLVER_ARGS, - Collections.singleton(InetSocketAddress.class)); - fail("Should have failed, but got resolver with " + nameResolver.uri); + ManagedChannelImplBuilder.ResolvedNameResolver resolved = + ManagedChannelImplBuilder.getNameResolverProvider(target, nameResolverRegistry); + FakeNameResolverProvider nameResolverProvider = (FakeNameResolverProvider) resolved.provider; + fail("Should have failed, but got resolver provider " + nameResolverProvider); } catch (IllegalArgumentException e) { // expected } @@ -262,4 +215,6 @@ private static class FakeNameResolver extends NameResolver { @Override public void shutdown() {} } + + private static class CustomSocketAddress extends SocketAddress {} } diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java index e50eeaf7686..97e92be7fdd 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java @@ -19,6 +19,7 @@ import static com.google.common.truth.Truth.assertThat; import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; +import static io.grpc.internal.UriWrapper.wrap; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; @@ -61,7 +62,9 @@ import io.grpc.MethodDescriptor.MethodType; import io.grpc.NameResolver; import io.grpc.NameResolver.ResolutionResult; +import io.grpc.NameResolverProvider; import io.grpc.Status; +import io.grpc.StatusOr; import io.grpc.StringMarshaller; import io.grpc.internal.FakeClock.ScheduledTask; import io.grpc.internal.ManagedChannelImplBuilder.UnsupportedClientTransportFactoryBuilder; @@ -169,7 +172,9 @@ public void setUp() { when(mockTransportFactory.getSupportedSocketAddressTypes()) .thenReturn(Collections.singleton(InetSocketAddress.class)); - ManagedChannelImplBuilder builder = new ManagedChannelImplBuilder("mockscheme:///target", + String target = "mockscheme:///target"; + URI targetUri = URI.create(target); + ManagedChannelImplBuilder builder = new ManagedChannelImplBuilder(target, new UnsupportedClientTransportFactoryBuilder(), null); builder @@ -178,8 +183,11 @@ public void setUp() { .idleTimeout(IDLE_TIMEOUT_SECONDS, TimeUnit.SECONDS) .userAgent(USER_AGENT); builder.executorPool = executorPool; + NameResolverProvider nameResolverProvider = + builder.nameResolverRegistry.getProviderForScheme(targetUri.getScheme()); channel = new ManagedChannelImpl( - builder, mockTransportFactory, new FakeBackoffPolicyProvider(), + builder, mockTransportFactory, wrap(targetUri), nameResolverProvider, + new FakeBackoffPolicyProvider(), oobExecutorPool, timer.getStopwatchSupplier(), Collections.emptyList(), TimeProvider.SYSTEM_TIME_PROVIDER); @@ -609,7 +617,7 @@ private void deliverResolutionResult() { // the NameResolver. ResolutionResult resolutionResult = ResolutionResult.newBuilder() - .setAddresses(servers) + .setAddressesOrError(StatusOr.fromValue(servers)) .setAttributes(Attributes.EMPTY) .build(); nameResolverListenerCaptor.getValue().onResult(resolutionResult); diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java index 04926cc25a5..ae224af27e1 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java @@ -26,7 +26,9 @@ import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; import static io.grpc.EquivalentAddressGroup.ATTR_AUTHORITY_OVERRIDE; +import static io.grpc.PickSubchannelArgsMatcher.eqPickSubchannelArgs; import static io.grpc.internal.ClientStreamListener.RpcProgress.PROCESSED; +import static io.grpc.internal.UriWrapper.wrap; import static junit.framework.TestCase.assertNotSame; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -56,6 +58,7 @@ import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.SettableFuture; @@ -83,6 +86,7 @@ import io.grpc.InternalChannelz; import io.grpc.InternalChannelz.ChannelStats; import io.grpc.InternalChannelz.ChannelTrace; +import io.grpc.InternalChannelz.ChannelTrace.Event.Severity; import io.grpc.InternalConfigSelector; import io.grpc.InternalInstrumented; import io.grpc.LoadBalancer; @@ -96,13 +100,17 @@ import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.LoadBalancerProvider; import io.grpc.LoadBalancerRegistry; +import io.grpc.LongCounterMetricInstrument; import io.grpc.ManagedChannel; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor.MethodType; +import io.grpc.MetricInstrumentRegistry; +import io.grpc.MetricSink; import io.grpc.NameResolver; import io.grpc.NameResolver.ConfigOrError; import io.grpc.NameResolver.ResolutionResult; +import io.grpc.NameResolverProvider; import io.grpc.NameResolverRegistry; import io.grpc.ProxiedSocketAddress; import io.grpc.ProxyDetector; @@ -110,13 +118,16 @@ import io.grpc.ServerMethodDefinition; import io.grpc.Status; import io.grpc.Status.Code; +import io.grpc.StatusOr; import io.grpc.StringMarshaller; +import io.grpc.SynchronizationContext; import io.grpc.internal.ClientTransportFactory.ClientTransportOptions; import io.grpc.internal.ClientTransportFactory.SwapChannelCredentialsResult; import io.grpc.internal.InternalSubchannel.TransportLogger; import io.grpc.internal.ManagedChannelImplBuilder.ClientTransportFactoryBuilder; import io.grpc.internal.ManagedChannelImplBuilder.FixedPortProvider; import io.grpc.internal.ManagedChannelImplBuilder.UnsupportedClientTransportFactoryBuilder; +import io.grpc.internal.ManagedChannelServiceConfig.MethodInfo; import io.grpc.internal.ServiceConfigUtil.PolicySelection; import io.grpc.internal.TestUtils.MockClientTransportInfo; import io.grpc.stub.ClientCalls; @@ -187,6 +198,17 @@ public class ManagedChannelImplTest { .setUserAgent(USER_AGENT); private static final String TARGET = "fake://" + SERVICE_NAME; private static final String MOCK_POLICY_NAME = "mock_lb"; + private static final NameResolver.Args NAMERESOLVER_ARGS = NameResolver.Args.newBuilder() + .setDefaultPort(447) + .setProxyDetector(mock(ProxyDetector.class)) + .setSynchronizationContext( + new SynchronizationContext(mock(Thread.UncaughtExceptionHandler.class))) + .setServiceConfigParser(mock(NameResolver.ServiceConfigParser.class)) + .setScheduledExecutorService(new FakeClock().getScheduledExecutorService()) + .build(); + private static final NameResolver.Args.Key TEST_RESOLVER_CUSTOM_ARG_KEY = + NameResolver.Args.Key.create("test-key"); + private URI expectedUri; private final SocketAddress socketAddress = new SocketAddress() { @@ -211,6 +233,9 @@ public String toString() { private final InternalChannelz channelz = new InternalChannelz(); + private final MetricInstrumentRegistry metricInstrumentRegistry = + MetricInstrumentRegistry.getDefaultRegistry(); + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); private ManagedChannelImpl channel; @@ -261,10 +286,6 @@ public String getPolicyName() { @Mock private ClientCall.Listener mockCallListener3; @Mock - private ClientCall.Listener mockCallListener4; - @Mock - private ClientCall.Listener mockCallListener5; - @Mock private ObjectPool executorPool; @Mock private ObjectPool balancerRpcExecutorPool; @@ -292,8 +313,11 @@ private void createChannel(boolean nameResolutionExpectedToFail, when(mockTransportFactory.getSupportedSocketAddressTypes()).thenReturn(Collections.singleton( InetSocketAddress.class)); + NameResolverProvider nameResolverProvider = + channelBuilder.nameResolverRegistry.getProviderForScheme(expectedUri.getScheme()); channel = new ManagedChannelImpl( - channelBuilder, mockTransportFactory, new FakeBackoffPolicyProvider(), + channelBuilder, mockTransportFactory, wrap(expectedUri), nameResolverProvider, + new FakeBackoffPolicyProvider(), balancerRpcExecutorPool, timer.getStopwatchSupplier(), Arrays.asList(interceptors), timer.getTimeProvider()); @@ -481,7 +505,8 @@ public void startCallBeforeNameResolution() throws Exception { when(mockTransportFactory.getSupportedSocketAddressTypes()).thenReturn(Collections.singleton( InetSocketAddress.class)); channel = new ManagedChannelImpl( - channelBuilder, mockTransportFactory, new FakeBackoffPolicyProvider(), + channelBuilder, mockTransportFactory, wrap(expectedUri), nameResolverFactory, + new FakeBackoffPolicyProvider(), balancerRpcExecutorPool, timer.getStopwatchSupplier(), Collections.emptyList(), timer.getTimeProvider()); Map rawServiceConfig = @@ -545,7 +570,8 @@ public void newCallWithConfigSelector() { when(mockTransportFactory.getSupportedSocketAddressTypes()).thenReturn(Collections.singleton( InetSocketAddress.class)); channel = new ManagedChannelImpl( - channelBuilder, mockTransportFactory, new FakeBackoffPolicyProvider(), + channelBuilder, mockTransportFactory, wrap(expectedUri), nameResolverFactory, + new FakeBackoffPolicyProvider(), balancerRpcExecutorPool, timer.getStopwatchSupplier(), Collections.emptyList(), timer.getTimeProvider()); nameResolverFactory.nextConfigOrError.set( @@ -617,6 +643,74 @@ public ClientCall interceptCall( TimeUnit.SECONDS.toNanos(ManagedChannelImpl.SUBCHANNEL_SHUTDOWN_DELAY_SECONDS)); } + @Test + public void pickSubchannelAddOptionalLabel_callsTracer() { + channelBuilder.directExecutor(); + createChannel(); + + updateBalancingStateSafely(helper, TRANSIENT_FAILURE, new SubchannelPicker() { + @Override + public PickResult pickSubchannel(PickSubchannelArgs args) { + args.getPickDetailsConsumer().addOptionalLabel("routed", "perfectly"); + return PickResult.withError(Status.UNAVAILABLE.withDescription("expected")); + } + }); + ClientStreamTracer tracer = mock(ClientStreamTracer.class); + ClientStreamTracer.Factory tracerFactory = new ClientStreamTracer.Factory() { + @Override + public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { + return tracer; + } + }; + ClientCall call = channel.newCall( + method, CallOptions.DEFAULT.withStreamTracerFactory(tracerFactory)); + call.start(mockCallListener, new Metadata()); + + verify(tracer).addOptionalLabel("routed", "perfectly"); + } + + @Test + public void metricRecorder_recordsToMetricSink() { + MetricSink mockSink = mock(MetricSink.class); + channelBuilder.addMetricSink(mockSink); + createChannel(); + + LongCounterMetricInstrument counter = metricInstrumentRegistry.registerLongCounter( + "recorder_duration", "Time taken by metric recorder", "s", + ImmutableList.of("grpc.method"), Collections.emptyList(), false); + List requiredLabelValues = ImmutableList.of("testMethod"); + List optionalLabelValues = Collections.emptyList(); + + helper.getMetricRecorder() + .addLongCounter(counter, 32, requiredLabelValues, optionalLabelValues); + verify(mockSink).addLongCounter(eq(counter), eq(32L), eq(requiredLabelValues), + eq(optionalLabelValues)); + } + + @Test + public void metricRecorder_fromNameResolverArgs_recordsToMetricSink() { + MetricSink mockSink1 = mock(MetricSink.class); + MetricSink mockSink2 = mock(MetricSink.class); + channelBuilder.addMetricSink(mockSink1); + channelBuilder.addMetricSink(mockSink2); + createChannel(); + + LongCounterMetricInstrument counter = metricInstrumentRegistry.registerLongCounter( + "test_counter", "Time taken by metric recorder", "s", + ImmutableList.of("grpc.method"), Collections.emptyList(), false); + List requiredLabelValues = ImmutableList.of("testMethod"); + List optionalLabelValues = Collections.emptyList(); + + NameResolver.Args args = helper.getNameResolverArgs(); + assertThat(args.getMetricRecorder()).isNotNull(); + args.getMetricRecorder() + .addLongCounter(counter, 10, requiredLabelValues, optionalLabelValues); + verify(mockSink1).addLongCounter(eq(counter), eq(10L), eq(requiredLabelValues), + eq(optionalLabelValues)); + verify(mockSink2).addLongCounter(eq(counter), eq(10L), eq(requiredLabelValues), + eq(optionalLabelValues)); + } + @Test public void shutdownWithNoTransportsEverCreated() { channelBuilder.nameResolverFactory( @@ -701,7 +795,8 @@ public void channelzMembership_subchannel() throws Exception { transportInfo.listener.transportReady(); // terminate transport - transportInfo.listener.transportShutdown(Status.CANCELLED); + transportInfo.listener.transportShutdown(Status.CANCELLED, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); transportInfo.listener.transportTerminated(); assertFalse(channelz.containsClientSocket(transportInfo.transport.getLogId())); @@ -719,46 +814,6 @@ public void channelzMembership_subchannel() throws Exception { assertNotNull(channelz.getRootChannel(channel.getLogId().getId())); } - @Test - public void channelzMembership_oob() throws Exception { - createChannel(); - OobChannel oob = (OobChannel) helper.createOobChannel( - Collections.singletonList(addressGroup), AUTHORITY); - // oob channels are not root channels - assertNull(channelz.getRootChannel(oob.getLogId().getId())); - assertTrue(channelz.containsSubchannel(oob.getLogId())); - assertThat(getStats(channel).subchannels).containsExactly(oob); - assertTrue(channelz.containsSubchannel(oob.getLogId())); - - AbstractSubchannel subchannel = (AbstractSubchannel) oob.getSubchannel(); - assertTrue( - channelz.containsSubchannel(subchannel.getInstrumentedInternalSubchannel().getLogId())); - assertThat(getStats(oob).subchannels) - .containsExactly(subchannel.getInstrumentedInternalSubchannel()); - assertTrue( - channelz.containsSubchannel(subchannel.getInstrumentedInternalSubchannel().getLogId())); - - oob.getSubchannel().requestConnection(); - MockClientTransportInfo transportInfo = transports.poll(); - assertNotNull(transportInfo); - assertTrue(channelz.containsClientSocket(transportInfo.transport.getLogId())); - - // terminate transport - transportInfo.listener.transportShutdown(Status.INTERNAL); - transportInfo.listener.transportTerminated(); - assertFalse(channelz.containsClientSocket(transportInfo.transport.getLogId())); - - // terminate oobchannel - oob.shutdown(); - assertFalse(channelz.containsSubchannel(oob.getLogId())); - assertThat(getStats(channel).subchannels).isEmpty(); - assertFalse( - channelz.containsSubchannel(subchannel.getInstrumentedInternalSubchannel().getLogId())); - - // channel still appears - assertNotNull(channelz.getRootChannel(channel.getLogId().getId())); - } - @Test public void callsAndShutdown() { subtestCallsAndShutdown(false, false); @@ -808,10 +863,10 @@ private void subtestCallsAndShutdown(boolean shutdownNow, boolean shutdownNowAft .thenReturn(mockStream2); transportListener.transportReady(); when(mockPicker.pickSubchannel( - new PickSubchannelArgsImpl(method, headers, CallOptions.DEFAULT))).thenReturn( + eqPickSubchannelArgs(method, headers, CallOptions.DEFAULT))).thenReturn( PickResult.withNoResult()); when(mockPicker.pickSubchannel( - new PickSubchannelArgsImpl(method, headers2, CallOptions.DEFAULT))).thenReturn( + eqPickSubchannelArgs(method, headers2, CallOptions.DEFAULT))).thenReturn( PickResult.withSubchannel(subchannel)); updateBalancingStateSafely(helper, READY, mockPicker); @@ -875,7 +930,7 @@ private void subtestCallsAndShutdown(boolean shutdownNow, boolean shutdownNowAft assertFalse(nameResolverFactory.resolvers.get(0).shutdown); // call and call2 are still alive, and can still be assigned to a real transport SubchannelPicker picker2 = mock(SubchannelPicker.class); - when(picker2.pickSubchannel(new PickSubchannelArgsImpl(method, headers, CallOptions.DEFAULT))) + when(picker2.pickSubchannel(eqPickSubchannelArgs(method, headers, CallOptions.DEFAULT))) .thenReturn(PickResult.withSubchannel(subchannel)); updateBalancingStateSafely(helper, READY, picker2); executor.runDueTasks(); @@ -905,7 +960,8 @@ private void subtestCallsAndShutdown(boolean shutdownNow, boolean shutdownNowAft } // Killing the remaining real transport will terminate the channel - transportListener.transportShutdown(Status.UNAVAILABLE); + transportListener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); assertFalse(channel.isTerminated()); verify(executorPool, never()).returnObject(any()); transportListener.transportTerminated(); @@ -975,7 +1031,8 @@ public void noMoreCallbackAfterLoadBalancerShutdown() { // Since subchannels are shutdown, SubchannelStateListeners will only get SHUTDOWN regardless of // the transport states. - transportInfo1.listener.transportShutdown(Status.UNAVAILABLE); + transportInfo1.listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); transportInfo2.listener.transportReady(); verify(stateListener1).onSubchannelState(ConnectivityStateInfo.forNonError(SHUTDOWN)); verify(stateListener2).onSubchannelState(ConnectivityStateInfo.forNonError(SHUTDOWN)); @@ -987,6 +1044,131 @@ public void noMoreCallbackAfterLoadBalancerShutdown() { verifyNoMoreInteractions(mockLoadBalancer); } + @Test + public void noMoreCallbackAfterLoadBalancerShutdown_configError() throws InterruptedException { + FakeNameResolverFactory nameResolverFactory = + new FakeNameResolverFactory.Builder(expectedUri) + .setServers(Collections.singletonList(new EquivalentAddressGroup(socketAddress))) + .build(); + channelBuilder.nameResolverFactory(nameResolverFactory); + Status resolutionError = Status.UNAVAILABLE.withDescription("Resolution failed"); + createChannel(); + + FakeNameResolverFactory.FakeNameResolver resolver = nameResolverFactory.resolvers.get(0); + verify(mockLoadBalancerProvider).newLoadBalancer(any(Helper.class)); + verify(mockLoadBalancer).acceptResolvedAddresses(resolvedAddressCaptor.capture()); + assertThat(resolvedAddressCaptor.getValue().getAddresses()).containsExactly(addressGroup); + + SubchannelStateListener stateListener1 = mock(SubchannelStateListener.class); + SubchannelStateListener stateListener2 = mock(SubchannelStateListener.class); + Subchannel subchannel1 = + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, stateListener1); + Subchannel subchannel2 = + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, stateListener2); + requestConnectionSafely(helper, subchannel1); + requestConnectionSafely(helper, subchannel2); + verify(mockTransportFactory, times(2)) + .newClientTransport( + any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class)); + MockClientTransportInfo transportInfo1 = transports.poll(); + MockClientTransportInfo transportInfo2 = transports.poll(); + + // LoadBalancer receives all sorts of callbacks + transportInfo1.listener.transportReady(); + + verify(stateListener1, times(2)).onSubchannelState(stateInfoCaptor.capture()); + assertSame(CONNECTING, stateInfoCaptor.getAllValues().get(0).getState()); + assertSame(READY, stateInfoCaptor.getAllValues().get(1).getState()); + + verify(stateListener2).onSubchannelState(stateInfoCaptor.capture()); + assertSame(CONNECTING, stateInfoCaptor.getValue().getState()); + + channel.syncContext.execute(() -> + resolver.listener.onResult2( + ResolutionResult.newBuilder() + .setAddressesOrError(StatusOr.fromStatus(resolutionError)).build())); + verify(mockLoadBalancer).handleNameResolutionError(resolutionError); + + verifyNoMoreInteractions(mockLoadBalancer); + + channel.shutdown(); + verify(mockLoadBalancer).shutdown(); + verifyNoMoreInteractions(stateListener1, stateListener2); + + // LoadBalancer will normally shutdown all subchannels + shutdownSafely(helper, subchannel1); + shutdownSafely(helper, subchannel2); + + // Since subchannels are shutdown, SubchannelStateListeners will only get SHUTDOWN regardless of + // the transport states. + transportInfo1.listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); + transportInfo2.listener.transportReady(); + verify(stateListener1).onSubchannelState(ConnectivityStateInfo.forNonError(SHUTDOWN)); + verify(stateListener2).onSubchannelState(ConnectivityStateInfo.forNonError(SHUTDOWN)); + verifyNoMoreInteractions(stateListener1, stateListener2); + + // No more callback should be delivered to LoadBalancer after it's shut down + channel.syncContext.execute(() -> + resolver.listener.onResult2( + ResolutionResult.newBuilder() + .setAddressesOrError(StatusOr.fromStatus(resolutionError)).build())); + assertThat(timer.getPendingTasks()).isEmpty(); + resolver.resolved(); + verifyNoMoreInteractions(mockLoadBalancer); + } + + @Test + public void addressResolutionError_noPriorNameResolution_usesDefaultServiceConfig() + throws Exception { + Map rawServiceConfig = + parseConfig("{\"methodConfig\":[{" + + "\"name\":[{\"service\":\"service\"}]," + + "\"waitForReady\":true}]}"); + ManagedChannelServiceConfig managedChannelServiceConfig = + createManagedChannelServiceConfig(rawServiceConfig, null); + FakeNameResolverFactory nameResolverFactory = + new FakeNameResolverFactory.Builder(expectedUri) + .setServers(Collections.singletonList(new EquivalentAddressGroup(socketAddress))) + .setResolvedAtStart(false) + .build(); + nameResolverFactory.nextConfigOrError.set( + ConfigOrError.fromConfig(managedChannelServiceConfig)); + channelBuilder.nameResolverFactory(nameResolverFactory); + Map defaultServiceConfig = + parseConfig("{\"methodConfig\":[{" + + "\"name\":[{\"service\":\"service\"}]," + + "\"waitForReady\":true}]}"); + channelBuilder.defaultServiceConfig(defaultServiceConfig); + Status resolutionError = Status.UNAVAILABLE.withDescription("Resolution failed"); + channelBuilder.maxTraceEvents(10); + createChannel(); + FakeNameResolverFactory.FakeNameResolver resolver = nameResolverFactory.resolvers.get(0); + + resolver.listener.onError(resolutionError); + + InternalConfigSelector configSelector = channel.getConfigSelector(); + ManagedChannelServiceConfig config = + (ManagedChannelServiceConfig) configSelector.selectConfig(null).getConfig(); + MethodInfo methodConfig = config.getMethodConfig(method); + assertThat(methodConfig.waitForReady).isTrue(); + timer.forwardNanos(1234); + assertThat(getStats(channel).channelTrace.events).contains(new ChannelTrace.Event.Builder() + .setDescription("Initial Name Resolution error, using default service config") + .setSeverity(Severity.CT_ERROR) + .setTimestampNanos(0) + .build()); + + // Check that "lastServiceConfig" variable has been set above: a config resolution with the same + // config simply gets ignored and not gets reassigned. + resolver.resolved(); + timer.forwardNanos(1234); + assertThat(Iterables.filter( + getStats(channel).channelTrace.events, + event -> event.description.equals("Service config changed"))) + .isEmpty(); + } + @Test public void interceptor() throws Exception { final AtomicLong atomic = new AtomicLong(); @@ -1056,7 +1238,8 @@ public void callOptionsExecutor() { verify(mockCallListener).onClose(same(Status.CANCELLED), same(trailers)); - transportListener.transportShutdown(Status.UNAVAILABLE); + transportListener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); transportListener.transportTerminated(); // Clean up as much as possible to allow the channel to terminate. @@ -1122,7 +1305,7 @@ public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata header PickResult.withSubchannel(subchannel)); updateBalancingStateSafely(helper, READY, mockPicker); - assertEquals(2, executor.runDueTasks()); + assertEquals(3, executor.runDueTasks()); verify(mockPicker).pickSubchannel(any(PickSubchannelArgs.class)); verify(mockTransport).newStream( @@ -1208,7 +1391,8 @@ public void firstResolvedServerFailedToConnect() throws Exception { MockClientTransportInfo badTransportInfo = transports.poll(); // Which failed to connect - badTransportInfo.listener.transportShutdown(Status.UNAVAILABLE); + badTransportInfo.listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); inOrder.verifyNoMoreInteractions(); // The channel then try the second address (goodAddress) @@ -1358,7 +1542,8 @@ public void allServersFailedToConnect() throws Exception { .newClientTransport( same(addr2), any(ClientTransportOptions.class), any(ChannelLogger.class)); MockClientTransportInfo transportInfo1 = transports.poll(); - transportInfo1.listener.transportShutdown(Status.UNAVAILABLE); + transportInfo1.listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); // Connecting to server2, which will fail too verify(mockTransportFactory) @@ -1366,7 +1551,8 @@ public void allServersFailedToConnect() throws Exception { same(addr2), any(ClientTransportOptions.class), any(ChannelLogger.class)); MockClientTransportInfo transportInfo2 = transports.poll(); Status server2Error = Status.UNAVAILABLE.withDescription("Server2 failed to connect"); - transportInfo2.listener.transportShutdown(server2Error); + transportInfo2.listener.transportShutdown(server2Error, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); // ... which makes the subchannel enter TRANSIENT_FAILURE. The last error Status is propagated // to LoadBalancer. @@ -1476,9 +1662,11 @@ public void run() { verify(transportInfo2.transport).shutdown(same(ManagedChannelImpl.SHUTDOWN_STATUS)); // Cleanup - transportInfo1.listener.transportShutdown(Status.UNAVAILABLE); + transportInfo1.listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); transportInfo1.listener.transportTerminated(); - transportInfo2.listener.transportShutdown(Status.UNAVAILABLE); + transportInfo2.listener.transportShutdown(Status.UNAVAILABLE, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); transportInfo2.listener.transportTerminated(); timer.forwardTime(ManagedChannelImpl.SUBCHANNEL_SHUTDOWN_DELAY_SECONDS, TimeUnit.SECONDS); } @@ -1518,8 +1706,10 @@ public void subchannelsWhenChannelShutdownNow() { verify(ti1.transport).shutdownNow(any(Status.class)); verify(ti2.transport).shutdownNow(any(Status.class)); - ti1.listener.transportShutdown(Status.UNAVAILABLE.withDescription("shutdown now")); - ti2.listener.transportShutdown(Status.UNAVAILABLE.withDescription("shutdown now")); + ti1.listener.transportShutdown(Status.UNAVAILABLE.withDescription("shutdown now"), + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); + ti2.listener.transportShutdown(Status.UNAVAILABLE.withDescription("shutdown now"), + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); ti1.listener.transportTerminated(); assertFalse(channel.isTerminated()); @@ -1546,6 +1736,19 @@ public void subchannelsNoConnectionShutdown() { any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class)); } + @Test + public void subchannelsRequestConnectionNoopAfterShutdown() { + createChannel(); + Subchannel sub1 = + createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); + + shutdownSafely(helper, sub1); + requestConnectionSafely(helper, sub1); + verify(mockTransportFactory, never()) + .newClientTransport( + any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class)); + } + @Test public void subchannelsNoConnectionShutdownNow() { createChannel(); @@ -1554,7 +1757,7 @@ public void subchannelsNoConnectionShutdownNow() { channel.shutdownNow(); verify(mockLoadBalancer).shutdown(); - // Channel's shutdownNow() will call shutdownNow() on all subchannels and oobchannels. + // Channel's shutdownNow() will call shutdownNow() on all subchannels. // Therefore, channel is terminated without relying on LoadBalancer to shutdown subchannels. assertTrue(channel.isTerminated()); verify(mockTransportFactory, never()) @@ -1562,112 +1765,6 @@ public void subchannelsNoConnectionShutdownNow() { any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class)); } - @Test - public void oobchannels() { - createChannel(); - - ManagedChannel oob1 = helper.createOobChannel( - Collections.singletonList(addressGroup), "oob1authority"); - ManagedChannel oob2 = helper.createOobChannel( - Collections.singletonList(addressGroup), "oob2authority"); - verify(balancerRpcExecutorPool, times(2)).getObject(); - - assertEquals("oob1authority", oob1.authority()); - assertEquals("oob2authority", oob2.authority()); - - // OOB channels create connections lazily. A new call will initiate the connection. - Metadata headers = new Metadata(); - ClientCall call = oob1.newCall(method, CallOptions.DEFAULT); - call.start(mockCallListener, headers); - verify(mockTransportFactory) - .newClientTransport( - eq(socketAddress), - eq(new ClientTransportOptions().setAuthority("oob1authority").setUserAgent(USER_AGENT)), - isA(ChannelLogger.class)); - MockClientTransportInfo transportInfo = transports.poll(); - assertNotNull(transportInfo); - - assertEquals(0, balancerRpcExecutor.numPendingTasks()); - transportInfo.listener.transportReady(); - assertEquals(1, balancerRpcExecutor.runDueTasks()); - verify(transportInfo.transport).newStream( - same(method), same(headers), same(CallOptions.DEFAULT), - ArgumentMatchers.any()); - - // The transport goes away - transportInfo.listener.transportShutdown(Status.UNAVAILABLE); - transportInfo.listener.transportTerminated(); - - // A new call will trigger a new transport - ClientCall call2 = oob1.newCall(method, CallOptions.DEFAULT); - call2.start(mockCallListener2, headers); - ClientCall call3 = - oob1.newCall(method, CallOptions.DEFAULT.withWaitForReady()); - call3.start(mockCallListener3, headers); - verify(mockTransportFactory, times(2)).newClientTransport( - eq(socketAddress), - eq(new ClientTransportOptions().setAuthority("oob1authority").setUserAgent(USER_AGENT)), - isA(ChannelLogger.class)); - transportInfo = transports.poll(); - assertNotNull(transportInfo); - - // This transport fails - Status transportError = Status.UNAVAILABLE.withDescription("Connection refused"); - assertEquals(0, balancerRpcExecutor.numPendingTasks()); - transportInfo.listener.transportShutdown(transportError); - assertTrue(balancerRpcExecutor.runDueTasks() > 0); - - // Fail-fast RPC will fail, while wait-for-ready RPC will still be pending - verify(mockCallListener2).onClose(same(transportError), any(Metadata.class)); - verify(mockCallListener3, never()).onClose(any(Status.class), any(Metadata.class)); - - // Shutdown - assertFalse(oob1.isShutdown()); - assertFalse(oob2.isShutdown()); - oob1.shutdown(); - oob2.shutdownNow(); - assertTrue(oob1.isShutdown()); - assertTrue(oob2.isShutdown()); - assertTrue(oob2.isTerminated()); - verify(balancerRpcExecutorPool).returnObject(balancerRpcExecutor.getScheduledExecutorService()); - - // New RPCs will be rejected. - assertEquals(0, balancerRpcExecutor.numPendingTasks()); - ClientCall call4 = oob1.newCall(method, CallOptions.DEFAULT); - ClientCall call5 = oob2.newCall(method, CallOptions.DEFAULT); - call4.start(mockCallListener4, headers); - call5.start(mockCallListener5, headers); - assertTrue(balancerRpcExecutor.runDueTasks() > 0); - verify(mockCallListener4).onClose(statusCaptor.capture(), any(Metadata.class)); - Status status4 = statusCaptor.getValue(); - assertEquals(Status.Code.UNAVAILABLE, status4.getCode()); - verify(mockCallListener5).onClose(statusCaptor.capture(), any(Metadata.class)); - Status status5 = statusCaptor.getValue(); - assertEquals(Status.Code.UNAVAILABLE, status5.getCode()); - - // The pending RPC will still be pending - verify(mockCallListener3, never()).onClose(any(Status.class), any(Metadata.class)); - - // This will shutdownNow() the delayed transport, terminating the pending RPC - assertEquals(0, balancerRpcExecutor.numPendingTasks()); - oob1.shutdownNow(); - assertTrue(balancerRpcExecutor.runDueTasks() > 0); - verify(mockCallListener3).onClose(any(Status.class), any(Metadata.class)); - - // Shut down the channel, and it will not terminated because OOB channel has not. - channel.shutdown(); - assertFalse(channel.isTerminated()); - // Delayed transport has already terminated. Terminating the transport terminates the - // subchannel, which in turn terimates the OOB channel, which terminates the channel. - assertFalse(oob1.isTerminated()); - verify(balancerRpcExecutorPool).returnObject(balancerRpcExecutor.getScheduledExecutorService()); - transportInfo.listener.transportTerminated(); - assertTrue(oob1.isTerminated()); - assertTrue(channel.isTerminated()); - verify(balancerRpcExecutorPool, times(2)) - .returnObject(balancerRpcExecutor.getScheduledExecutorService()); - } - @Test public void oobChannelHasNoChannelCallCredentials() { Metadata.Key metadataKey = @@ -1719,7 +1816,7 @@ public void oobChannelHasNoChannelCallCredentials() { balancerRpcExecutor.runDueTasks(); verify(transportInfo.transport).newStream( - same(method), same(headers), same(callOptions), + same(method), same(headers), ArgumentMatchers.any(), ArgumentMatchers.any()); assertThat(headers.getAll(metadataKey)).containsExactly(callCredValue); oob.shutdownNow(); @@ -1846,74 +1943,6 @@ public SwapChannelCredentialsResult answer(InvocationOnMock invocation) { oob.shutdownNow(); } - @Test - public void oobChannelsWhenChannelShutdownNow() { - createChannel(); - ManagedChannel oob1 = helper.createOobChannel( - Collections.singletonList(addressGroup), "oob1Authority"); - ManagedChannel oob2 = helper.createOobChannel( - Collections.singletonList(addressGroup), "oob2Authority"); - - oob1.newCall(method, CallOptions.DEFAULT).start(mockCallListener, new Metadata()); - oob2.newCall(method, CallOptions.DEFAULT).start(mockCallListener2, new Metadata()); - - assertThat(transports).hasSize(2); - MockClientTransportInfo ti1 = transports.poll(); - MockClientTransportInfo ti2 = transports.poll(); - - ti1.listener.transportReady(); - ti2.listener.transportReady(); - - channel.shutdownNow(); - verify(ti1.transport).shutdownNow(any(Status.class)); - verify(ti2.transport).shutdownNow(any(Status.class)); - - ti1.listener.transportShutdown(Status.UNAVAILABLE.withDescription("shutdown now")); - ti2.listener.transportShutdown(Status.UNAVAILABLE.withDescription("shutdown now")); - ti1.listener.transportTerminated(); - - assertFalse(channel.isTerminated()); - ti2.listener.transportTerminated(); - assertTrue(channel.isTerminated()); - } - - @Test - public void oobChannelsNoConnectionShutdown() { - createChannel(); - ManagedChannel oob1 = helper.createOobChannel( - Collections.singletonList(addressGroup), "oob1Authority"); - ManagedChannel oob2 = helper.createOobChannel( - Collections.singletonList(addressGroup), "oob2Authority"); - channel.shutdown(); - - verify(mockLoadBalancer).shutdown(); - oob1.shutdown(); - assertTrue(oob1.isTerminated()); - assertFalse(channel.isTerminated()); - oob2.shutdown(); - assertTrue(oob2.isTerminated()); - assertTrue(channel.isTerminated()); - verify(mockTransportFactory, never()) - .newClientTransport( - any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class)); - } - - @Test - public void oobChannelsNoConnectionShutdownNow() { - createChannel(); - helper.createOobChannel(Collections.singletonList(addressGroup), "oob1Authority"); - helper.createOobChannel(Collections.singletonList(addressGroup), "oob2Authority"); - channel.shutdownNow(); - - verify(mockLoadBalancer).shutdown(); - assertTrue(channel.isTerminated()); - // Channel's shutdownNow() will call shutdownNow() on all subchannels and oobchannels. - // Therefore, channel is terminated without relying on LoadBalancer to shutdown oobchannels. - verify(mockTransportFactory, never()) - .newClientTransport( - any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class)); - } - @Test public void subchannelChannel_normalUsage() { createChannel(); @@ -2048,6 +2077,7 @@ public void lbHelper_getNameResolverArgs() { assertThat(args.getSynchronizationContext()) .isSameInstanceAs(helper.getSynchronizationContext()); assertThat(args.getServiceConfigParser()).isNotNull(); + assertThat(args.getMetricRecorder()).isNotNull(); } @Test @@ -2058,77 +2088,6 @@ public void lbHelper_getNonDefaultNameResolverRegistry() { .isNotSameInstanceAs(NameResolverRegistry.getDefaultRegistry()); } - @Test - public void refreshNameResolution_whenOobChannelConnectionFailed_notIdle() { - subtestNameResolutionRefreshWhenConnectionFailed(false); - } - - @Test - public void notRefreshNameResolution_whenOobChannelConnectionFailed_idle() { - subtestNameResolutionRefreshWhenConnectionFailed(true); - } - - private void subtestNameResolutionRefreshWhenConnectionFailed(boolean isIdle) { - FakeNameResolverFactory nameResolverFactory = - new FakeNameResolverFactory.Builder(expectedUri) - .setServers(Collections.singletonList(new EquivalentAddressGroup(socketAddress))) - .build(); - channelBuilder.nameResolverFactory(nameResolverFactory); - createChannel(); - OobChannel oobChannel = (OobChannel) helper.createOobChannel( - Collections.singletonList(addressGroup), "oobAuthority"); - oobChannel.getSubchannel().requestConnection(); - - MockClientTransportInfo transportInfo = transports.poll(); - assertNotNull(transportInfo); - - FakeNameResolverFactory.FakeNameResolver resolver = nameResolverFactory.resolvers.remove(0); - - if (isIdle) { - channel.enterIdle(); - // Entering idle mode will result in a new resolver - resolver = nameResolverFactory.resolvers.remove(0); - } - - assertEquals(0, nameResolverFactory.resolvers.size()); - - int expectedRefreshCount = 0; - - // Transport closed when connecting - assertEquals(expectedRefreshCount, resolver.refreshCalled); - transportInfo.listener.transportShutdown(Status.UNAVAILABLE); - // When channel enters idle, new resolver is created but not started. - if (!isIdle) { - expectedRefreshCount++; - } - assertEquals(expectedRefreshCount, resolver.refreshCalled); - - timer.forwardNanos(RECONNECT_BACKOFF_INTERVAL_NANOS); - transportInfo = transports.poll(); - assertNotNull(transportInfo); - - transportInfo.listener.transportReady(); - - // Transport closed when ready - assertEquals(expectedRefreshCount, resolver.refreshCalled); - transportInfo.listener.transportShutdown(Status.UNAVAILABLE); - // When channel enters idle, new resolver is created but not started. - if (!isIdle) { - expectedRefreshCount++; - } - assertEquals(expectedRefreshCount, resolver.refreshCalled); - } - - @Test - public void uriPattern() { - assertTrue(ManagedChannelImpl.URI_PATTERN.matcher("a:/").matches()); - assertTrue(ManagedChannelImpl.URI_PATTERN.matcher("Z019+-.:/!@ #~ ").matches()); - assertFalse(ManagedChannelImpl.URI_PATTERN.matcher("a/:").matches()); // "/:" not matched - assertFalse(ManagedChannelImpl.URI_PATTERN.matcher("0a:/").matches()); // '0' not matched - assertFalse(ManagedChannelImpl.URI_PATTERN.matcher("a,:/").matches()); // ',' not matched - assertFalse(ManagedChannelImpl.URI_PATTERN.matcher(" a:/").matches()); // space not matched - } - /** * Test that information such as the Call's context, MethodDescriptor, authority, executor are * propagated to newStream() and applyRequestMetadata(). @@ -2338,7 +2297,7 @@ public void getState_loadBalancerSupportsChannelState() { channelBuilder.nameResolverFactory( new FakeNameResolverFactory.Builder(expectedUri).setResolvedAtStart(false).build()); createChannel(); - assertEquals(IDLE, channel.getState(false)); + assertEquals(CONNECTING, channel.getState(false)); updateBalancingStateSafely(helper, TRANSIENT_FAILURE, mockPicker); assertEquals(TRANSIENT_FAILURE, channel.getState(false)); @@ -2378,7 +2337,6 @@ public void getState_withRequestConnect_IdleWithLbRunning() { assertEquals(IDLE, channel.getState(true)); verify(mockLoadBalancerProvider).newLoadBalancer(any(Helper.class)); - verify(mockPicker).requestConnection(); verify(mockLoadBalancer).requestConnection(); } @@ -2395,21 +2353,21 @@ public void run() { channelBuilder.nameResolverFactory( new FakeNameResolverFactory.Builder(expectedUri).setResolvedAtStart(false).build()); createChannel(); - assertEquals(IDLE, channel.getState(false)); + assertEquals(CONNECTING, channel.getState(false)); - channel.notifyWhenStateChanged(IDLE, onStateChanged); + channel.notifyWhenStateChanged(CONNECTING, onStateChanged); executor.runDueTasks(); assertFalse(stateChanged.get()); - // state change from IDLE to CONNECTING - updateBalancingStateSafely(helper, CONNECTING, mockPicker); + // state change from CONNECTING to IDLE + updateBalancingStateSafely(helper, IDLE, mockPicker); // onStateChanged callback should run executor.runDueTasks(); assertTrue(stateChanged.get()); - // clear and test form CONNECTING + // clear and test form IDLE stateChanged.set(false); - channel.notifyWhenStateChanged(IDLE, onStateChanged); + channel.notifyWhenStateChanged(CONNECTING, onStateChanged); // onStateChanged callback should run immediately executor.runDueTasks(); assertTrue(stateChanged.get()); @@ -2428,8 +2386,8 @@ public void run() { channelBuilder.nameResolverFactory( new FakeNameResolverFactory.Builder(expectedUri).setResolvedAtStart(false).build()); createChannel(); - assertEquals(IDLE, channel.getState(false)); - channel.notifyWhenStateChanged(IDLE, onStateChanged); + assertEquals(CONNECTING, channel.getState(false)); + channel.notifyWhenStateChanged(CONNECTING, onStateChanged); executor.runDueTasks(); assertFalse(stateChanged.get()); @@ -2452,9 +2410,6 @@ public void stateIsIdleOnIdleTimeout() { long idleTimeoutMillis = 2000L; channelBuilder.idleTimeout(idleTimeoutMillis, TimeUnit.MILLISECONDS); createChannel(); - assertEquals(IDLE, channel.getState(false)); - - updateBalancingStateSafely(helper, CONNECTING, mockPicker); assertEquals(CONNECTING, channel.getState(false)); timer.forwardNanos(TimeUnit.MILLISECONDS.toNanos(idleTimeoutMillis)); @@ -2677,11 +2632,11 @@ public void idleTimeoutAndReconnect() { // Updating on the old helper (whose balancer has been shutdown) does not change the channel // state. - updateBalancingStateSafely(helper, CONNECTING, mockPicker); - assertEquals(IDLE, channel.getState(false)); - - updateBalancingStateSafely(helper2, CONNECTING, mockPicker); + updateBalancingStateSafely(helper, IDLE, mockPicker); assertEquals(CONNECTING, channel.getState(false)); + + updateBalancingStateSafely(helper2, IDLE, mockPicker); + assertEquals(IDLE, channel.getState(false)); } @Test @@ -2695,7 +2650,7 @@ public void idleMode_resetsDelayedTransportPicker() { .setServers(Collections.singletonList(new EquivalentAddressGroup(socketAddress))) .build()); createChannel(); - assertEquals(IDLE, channel.getState(false)); + assertEquals(CONNECTING, channel.getState(false)); // This call will be buffered in delayedTransport ClientCall call = channel.newCall(method, CallOptions.DEFAULT); @@ -2790,7 +2745,7 @@ public void enterIdle_exitsIdleIfDelayedStreamPending() { // enterIdle() will shut down the name resolver and lb policy used to get a pick for the delayed // call channel.enterIdle(); - assertEquals(IDLE, channel.getState(false)); + assertEquals(CONNECTING, channel.getState(false)); // enterIdle() will restart the delayed call by exiting idle. This creates a new helper. ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(Helper.class); @@ -2912,14 +2867,14 @@ public void updateBalancingStateWithShutdownShouldBeIgnored() { channelBuilder.nameResolverFactory( new FakeNameResolverFactory.Builder(expectedUri).setResolvedAtStart(false).build()); createChannel(); - assertEquals(IDLE, channel.getState(false)); + assertEquals(CONNECTING, channel.getState(false)); Runnable onStateChanged = mock(Runnable.class); - channel.notifyWhenStateChanged(IDLE, onStateChanged); + channel.notifyWhenStateChanged(CONNECTING, onStateChanged); updateBalancingStateSafely(helper, SHUTDOWN, mockPicker); - assertEquals(IDLE, channel.getState(false)); + assertEquals(CONNECTING, channel.getState(false)); executor.runDueTasks(); verify(onStateChanged, never()).run(); } @@ -3084,6 +3039,56 @@ public void channelTracing_nameResolvedEvent_zeorAndNonzeroBackends() throws Exc assertThat(getStats(channel).channelTrace.events).hasSize(prevSize + 1); } + @Test + public void channelTracing_nameResolvedEvent_zeorAndNonzeroBackends_usesListener2onResult2() + throws Exception { + timer.forwardNanos(1234); + channelBuilder.maxTraceEvents(10); + List servers = new ArrayList<>(); + servers.add(new EquivalentAddressGroup(socketAddress)); + FakeNameResolverFactory nameResolverFactory = + new FakeNameResolverFactory.Builder(expectedUri).setServers(servers).build(); + channelBuilder.nameResolverFactory(nameResolverFactory); + createChannel(); + + int prevSize = getStats(channel).channelTrace.events.size(); + ResolutionResult resolutionResult1 = ResolutionResult.newBuilder() + .setAddresses(Collections.singletonList( + new EquivalentAddressGroup( + Arrays.asList(new SocketAddress() {}, new SocketAddress() {})))) + .build(); + + channel.syncContext.execute( + () -> nameResolverFactory.resolvers.get(0).listener.onResult2(resolutionResult1)); + assertThat(getStats(channel).channelTrace.events).hasSize(prevSize); + + prevSize = getStats(channel).channelTrace.events.size(); + channel.syncContext.execute(() -> + nameResolverFactory.resolvers.get(0).listener.onResult2( + ResolutionResult.newBuilder() + .setAddressesOrError( + StatusOr.fromStatus(Status.INTERNAL)).build())); + assertThat(getStats(channel).channelTrace.events).hasSize(prevSize + 1); + + prevSize = getStats(channel).channelTrace.events.size(); + channel.syncContext.execute(() -> + nameResolverFactory.resolvers.get(0).listener.onResult2( + ResolutionResult.newBuilder() + .setAddressesOrError( + StatusOr.fromStatus(Status.INTERNAL)).build())); + assertThat(getStats(channel).channelTrace.events).hasSize(prevSize); + + prevSize = getStats(channel).channelTrace.events.size(); + ResolutionResult resolutionResult2 = ResolutionResult.newBuilder() + .setAddresses(Collections.singletonList( + new EquivalentAddressGroup( + Arrays.asList(new SocketAddress() {}, new SocketAddress() {})))) + .build(); + channel.syncContext.execute( + () -> nameResolverFactory.resolvers.get(0).listener.onResult2(resolutionResult2)); + assertThat(getStats(channel).channelTrace.events).hasSize(prevSize + 1); + } + @Test public void channelTracing_serviceConfigChange() throws Exception { timer.forwardNanos(1234); @@ -3143,6 +3148,69 @@ public void channelTracing_serviceConfigChange() throws Exception { .build()); } + @Test + public void channelTracing_serviceConfigChange_usesListener2OnResult2() throws Exception { + timer.forwardNanos(1234); + channelBuilder.maxTraceEvents(10); + List servers = new ArrayList<>(); + servers.add(new EquivalentAddressGroup(socketAddress)); + FakeNameResolverFactory nameResolverFactory = + new FakeNameResolverFactory.Builder(expectedUri).setServers(servers).build(); + channelBuilder.nameResolverFactory(nameResolverFactory); + createChannel(); + + int prevSize = getStats(channel).channelTrace.events.size(); + ManagedChannelServiceConfig mcsc1 = createManagedChannelServiceConfig( + ImmutableMap.of(), + new PolicySelection( + mockLoadBalancerProvider, null)); + ResolutionResult resolutionResult1 = ResolutionResult.newBuilder() + .setAddresses(Collections.singletonList( + new EquivalentAddressGroup( + Arrays.asList(new SocketAddress() {}, new SocketAddress() {})))) + .setServiceConfig(ConfigOrError.fromConfig(mcsc1)) + .build(); + + channel.syncContext.execute(() -> + nameResolverFactory.resolvers.get(0).listener.onResult2(resolutionResult1)); + assertThat(getStats(channel).channelTrace.events).hasSize(prevSize + 1); + assertThat(getStats(channel).channelTrace.events.get(prevSize)) + .isEqualTo(new ChannelTrace.Event.Builder() + .setDescription("Service config changed") + .setSeverity(ChannelTrace.Event.Severity.CT_INFO) + .setTimestampNanos(timer.getTicker().read()) + .build()); + + prevSize = getStats(channel).channelTrace.events.size(); + ResolutionResult resolutionResult2 = ResolutionResult.newBuilder().setAddresses( + Collections.singletonList( + new EquivalentAddressGroup( + Arrays.asList(new SocketAddress() {}, new SocketAddress() {})))) + .setServiceConfig(ConfigOrError.fromConfig(mcsc1)) + .build(); + channel.syncContext.execute(() -> + nameResolverFactory.resolvers.get(0).listener.onResult(resolutionResult2)); + assertThat(getStats(channel).channelTrace.events).hasSize(prevSize); + + prevSize = getStats(channel).channelTrace.events.size(); + timer.forwardNanos(1234); + ResolutionResult resolutionResult3 = ResolutionResult.newBuilder() + .setAddresses(Collections.singletonList( + new EquivalentAddressGroup( + Arrays.asList(new SocketAddress() {}, new SocketAddress() {})))) + .setServiceConfig(ConfigOrError.fromConfig(ManagedChannelServiceConfig.empty())) + .build(); + channel.syncContext.execute(() -> + nameResolverFactory.resolvers.get(0).listener.onResult(resolutionResult3)); + assertThat(getStats(channel).channelTrace.events).hasSize(prevSize + 1); + assertThat(getStats(channel).channelTrace.events.get(prevSize)) + .isEqualTo(new ChannelTrace.Event.Builder() + .setDescription("Service config changed") + .setSeverity(ChannelTrace.Event.Severity.CT_INFO) + .setTimestampNanos(timer.getTicker().read()) + .build()); + } + @Test public void channelTracing_stateChangeEvent() throws Exception { channelBuilder.maxTraceEvents(10); @@ -3172,48 +3240,6 @@ public void channelTracing_subchannelStateChangeEvent() throws Exception { .build()); } - @Test - public void channelTracing_oobChannelStateChangeEvent() throws Exception { - channelBuilder.maxTraceEvents(10); - createChannel(); - OobChannel oobChannel = (OobChannel) helper.createOobChannel( - Collections.singletonList(addressGroup), "authority"); - timer.forwardNanos(1234); - oobChannel.handleSubchannelStateChange( - ConnectivityStateInfo.forNonError(ConnectivityState.CONNECTING)); - assertThat(getStats(oobChannel).channelTrace.events).contains(new ChannelTrace.Event.Builder() - .setDescription("Entering CONNECTING state") - .setSeverity(ChannelTrace.Event.Severity.CT_INFO) - .setTimestampNanos(timer.getTicker().read()) - .build()); - } - - @Test - public void channelTracing_oobChannelCreationEvents() throws Exception { - channelBuilder.maxTraceEvents(10); - createChannel(); - timer.forwardNanos(1234); - OobChannel oobChannel = (OobChannel) helper.createOobChannel( - Collections.singletonList(addressGroup), "authority"); - assertThat(getStats(channel).channelTrace.events).contains(new ChannelTrace.Event.Builder() - .setDescription("Child OobChannel created") - .setSeverity(ChannelTrace.Event.Severity.CT_INFO) - .setTimestampNanos(timer.getTicker().read()) - .setChannelRef(oobChannel) - .build()); - assertThat(getStats(oobChannel).channelTrace.events).contains(new ChannelTrace.Event.Builder() - .setDescription("OobChannel for [[[test-addr]/{}]] created") - .setSeverity(ChannelTrace.Event.Severity.CT_INFO) - .setTimestampNanos(timer.getTicker().read()) - .build()); - assertThat(getStats(oobChannel.getInternalSubchannel()).channelTrace.events).contains( - new ChannelTrace.Event.Builder() - .setDescription("Subchannel for [[[test-addr]/{}]] created") - .setSeverity(ChannelTrace.Event.Severity.CT_INFO) - .setTimestampNanos(timer.getTicker().read()) - .build()); - } - @Test public void channelsAndSubchannels_instrumented_state() throws Exception { createChannel(); @@ -3222,8 +3248,6 @@ public void channelsAndSubchannels_instrumented_state() throws Exception { verify(mockLoadBalancerProvider).newLoadBalancer(helperCaptor.capture()); helper = helperCaptor.getValue(); - assertEquals(IDLE, getStats(channel).state); - updateBalancingStateSafely(helper, CONNECTING, mockPicker); assertEquals(CONNECTING, getStats(channel).state); AbstractSubchannel subchannel = @@ -3331,115 +3355,6 @@ private void channelsAndSubchannels_instrumented0(boolean success) throws Except } } - @Test - public void channelsAndSubchannels_oob_instrumented_success() throws Exception { - channelsAndSubchannels_oob_instrumented0(true); - } - - @Test - public void channelsAndSubchannels_oob_instrumented_fail() throws Exception { - channelsAndSubchannels_oob_instrumented0(false); - } - - private void channelsAndSubchannels_oob_instrumented0(boolean success) throws Exception { - // set up - ClientStream mockStream = mock(ClientStream.class); - createChannel(); - - OobChannel oobChannel = (OobChannel) helper.createOobChannel( - Collections.singletonList(addressGroup), "oobauthority"); - AbstractSubchannel oobSubchannel = (AbstractSubchannel) oobChannel.getSubchannel(); - FakeClock callExecutor = new FakeClock(); - CallOptions options = - CallOptions.DEFAULT.withExecutor(callExecutor.getScheduledExecutorService()); - ClientCall call = oobChannel.newCall(method, options); - Metadata headers = new Metadata(); - - // Channel stat bumped when ClientCall.start() called - assertEquals(0, getStats(oobChannel).callsStarted); - call.start(mockCallListener, headers); - assertEquals(1, getStats(oobChannel).callsStarted); - - MockClientTransportInfo transportInfo = transports.poll(); - ConnectionClientTransport mockTransport = transportInfo.transport; - ManagedClientTransport.Listener transportListener = transportInfo.listener; - when(mockTransport.newStream( - same(method), same(headers), any(CallOptions.class), - ArgumentMatchers.any())) - .thenReturn(mockStream); - - // subchannel stat bumped when call gets assigned to it - assertEquals(0, getStats(oobSubchannel).callsStarted); - transportListener.transportReady(); - callExecutor.runDueTasks(); - verify(mockStream).start(streamListenerCaptor.capture()); - assertEquals(1, getStats(oobSubchannel).callsStarted); - - ClientStreamListener streamListener = streamListenerCaptor.getValue(); - call.halfClose(); - - // closing stream listener affects subchannel stats immediately - assertEquals(0, getStats(oobSubchannel).callsSucceeded); - assertEquals(0, getStats(oobSubchannel).callsFailed); - streamListener.closed(success ? Status.OK : Status.UNKNOWN, PROCESSED, new Metadata()); - if (success) { - assertEquals(1, getStats(oobSubchannel).callsSucceeded); - assertEquals(0, getStats(oobSubchannel).callsFailed); - } else { - assertEquals(0, getStats(oobSubchannel).callsSucceeded); - assertEquals(1, getStats(oobSubchannel).callsFailed); - } - - // channel stats bumped when the ClientCall.Listener is notified - assertEquals(0, getStats(oobChannel).callsSucceeded); - assertEquals(0, getStats(oobChannel).callsFailed); - callExecutor.runDueTasks(); - if (success) { - assertEquals(1, getStats(oobChannel).callsSucceeded); - assertEquals(0, getStats(oobChannel).callsFailed); - } else { - assertEquals(0, getStats(oobChannel).callsSucceeded); - assertEquals(1, getStats(oobChannel).callsFailed); - } - // oob channel is separate from the original channel - assertEquals(0, getStats(channel).callsSucceeded); - assertEquals(0, getStats(channel).callsFailed); - } - - @Test - public void channelsAndSubchannels_oob_instrumented_name() throws Exception { - createChannel(); - - String authority = "oobauthority"; - OobChannel oobChannel = (OobChannel) helper.createOobChannel( - Collections.singletonList(addressGroup), authority); - assertEquals(authority, getStats(oobChannel).target); - } - - @Test - public void channelsAndSubchannels_oob_instrumented_state() throws Exception { - createChannel(); - - OobChannel oobChannel = (OobChannel) helper.createOobChannel( - Collections.singletonList(addressGroup), "oobauthority"); - assertEquals(IDLE, getStats(oobChannel).state); - - oobChannel.getSubchannel().requestConnection(); - assertEquals(CONNECTING, getStats(oobChannel).state); - - MockClientTransportInfo transportInfo = transports.poll(); - ManagedClientTransport.Listener transportListener = transportInfo.listener; - - transportListener.transportReady(); - assertEquals(READY, getStats(oobChannel).state); - - // oobchannel state is separate from the ManagedChannel - assertEquals(IDLE, getStats(channel).state); - channel.shutdownNow(); - assertEquals(SHUTDOWN, getStats(channel).state); - assertEquals(SHUTDOWN, getStats(oobChannel).state); - } - @Test public void binaryLogInstalled() throws Exception { final SettableFuture intercepted = SettableFuture.create(); @@ -3525,8 +3440,6 @@ public double nextDouble() { verify(mockLoadBalancer).acceptResolvedAddresses(resolvedAddressCaptor.capture()); ResolvedAddresses resolvedAddresses = resolvedAddressCaptor.getValue(); assertThat(resolvedAddresses.getAddresses()).isEqualTo(nameResolverFactory.servers); - assertThat(resolvedAddresses.getAttributes() - .get(RetryingNameResolver.RESOLUTION_RESULT_LISTENER_KEY)).isNotNull(); // simulating request connection and then transport ready after resolved address Subchannel subchannel = @@ -3555,7 +3468,7 @@ public double nextDouble() { Status.UNAVAILABLE, PROCESSED, new Metadata()); // in backoff - timer.forwardTime(5, TimeUnit.SECONDS); + timer.forwardTime(6, TimeUnit.SECONDS); assertThat(timer.getPendingTasks()).hasSize(1); verify(mockStream2, never()).start(any(ClientStreamListener.class)); @@ -3574,7 +3487,7 @@ public double nextDouble() { assertEquals("Channel shutdown invoked", statusCaptor.getValue().getDescription()); // backoff ends - timer.forwardTime(5, TimeUnit.SECONDS); + timer.forwardTime(6, TimeUnit.SECONDS); assertThat(timer.getPendingTasks()).isEmpty(); verify(mockStream2).start(streamListenerCaptor.capture()); verify(mockLoadBalancer, never()).shutdown(); @@ -3587,7 +3500,8 @@ public double nextDouble() { verify(mockLoadBalancer).shutdown(); // simulating the shutdown of load balancer triggers the shutdown of subchannel shutdownSafely(helper, subchannel); - transportInfo.listener.transportShutdown(Status.INTERNAL); + transportInfo.listener.transportShutdown(Status.INTERNAL, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); transportInfo.listener.transportTerminated(); // simulating transport terminated assertTrue( "channel.isTerminated() is expected to be true but was false", @@ -3632,8 +3546,6 @@ public void hedgingScheduledThenChannelShutdown_hedgeShouldStillHappen_newCallSh verify(mockLoadBalancer).acceptResolvedAddresses(resolvedAddressCaptor.capture()); ResolvedAddresses resolvedAddresses = resolvedAddressCaptor.getValue(); assertThat(resolvedAddresses.getAddresses()).isEqualTo(nameResolverFactory.servers); - assertThat(resolvedAddresses.getAttributes() - .get(RetryingNameResolver.RESOLUTION_RESULT_LISTENER_KEY)).isNotNull(); // simulating request connection and then transport ready after resolved address Subchannel subchannel = @@ -3694,7 +3606,8 @@ public void hedgingScheduledThenChannelShutdown_hedgeShouldStillHappen_newCallSh // simulating the shutdown of load balancer triggers the shutdown of subchannel shutdownSafely(helper, subchannel); // simulating transport shutdown & terminated - transportInfo.listener.transportShutdown(Status.INTERNAL); + transportInfo.listener.transportShutdown(Status.INTERNAL, + SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); transportInfo.listener.transportTerminated(); assertTrue( "channel.isTerminated() is expected to be true but was false", @@ -3805,6 +3718,120 @@ public ClientTransportFactory buildClientTransportFactory() { mychannel.shutdownNow(); } + @Test + public void badServiceConfigIsRecoverable_usesListener2OnResult2() throws Exception { + final List addresses = + ImmutableList.of(new EquivalentAddressGroup(new SocketAddress() {})); + final class FakeNameResolver extends NameResolver { + Listener2 listener; + private final SynchronizationContext syncContext; + + FakeNameResolver(Args args) { + this.syncContext = args.getSynchronizationContext(); + } + + @Override + public String getServiceAuthority() { + return "also fake"; + } + + @Override + public void start(Listener2 listener) { + this.listener = listener; + syncContext.execute(() -> + listener.onResult2( + ResolutionResult.newBuilder() + .setAddresses(addresses) + .setServiceConfig( + ConfigOrError.fromError( + Status.INTERNAL.withDescription("kaboom is invalid"))) + .build())); + } + + @Override + public void shutdown() {} + } + + final class FakeNameResolverFactory2 extends NameResolver.Factory { + FakeNameResolver resolver; + ManagedChannelImpl managedChannel; + SynchronizationContext syncContext; + + @Nullable + @Override + public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) { + syncContext = args.getSynchronizationContext(); + return (resolver = new FakeNameResolver(args)); + } + + @Override + public String getDefaultScheme() { + return "fake"; + } + } + + FakeNameResolverFactory2 factory = new FakeNameResolverFactory2(); + + ManagedChannelImplBuilder customBuilder = new ManagedChannelImplBuilder(TARGET, + new ClientTransportFactoryBuilder() { + @Override + public ClientTransportFactory buildClientTransportFactory() { + return mockTransportFactory; + } + }, + null); + when(mockTransportFactory.getSupportedSocketAddressTypes()).thenReturn(Collections.singleton( + InetSocketAddress.class)); + customBuilder.executorPool = executorPool; + customBuilder.channelz = channelz; + ManagedChannel mychannel = customBuilder.nameResolverFactory(factory).build(); + + ClientCall call1 = + mychannel.newCall(TestMethodDescriptors.voidMethod(), CallOptions.DEFAULT); + ListenableFuture future1 = ClientCalls.futureUnaryCall(call1, null); + executor.runDueTasks(); + try { + future1.get(1, TimeUnit.SECONDS); + Assert.fail(); + } catch (ExecutionException e) { + assertThat(Throwables.getStackTraceAsString(e.getCause())).contains("kaboom"); + } + + // ok the service config is bad, let's fix it. + Map rawServiceConfig = + parseConfig("{\"loadBalancingConfig\": [{\"round_robin\": {}}]}"); + Object fakeLbConfig = new Object(); + PolicySelection lbConfigs = + new PolicySelection( + mockLoadBalancerProvider, fakeLbConfig); + mockLoadBalancerProvider.parseLoadBalancingPolicyConfig(rawServiceConfig); + ManagedChannelServiceConfig managedChannelServiceConfig = + createManagedChannelServiceConfig(rawServiceConfig, lbConfigs); + factory.syncContext.execute(() -> + factory.resolver.listener.onResult2( + ResolutionResult.newBuilder() + .setAddresses(addresses) + .setServiceConfig(ConfigOrError.fromConfig(managedChannelServiceConfig)) + .build())); + + ClientCall call2 = mychannel.newCall( + TestMethodDescriptors.voidMethod(), + CallOptions.DEFAULT.withDeadlineAfter(5, TimeUnit.SECONDS)); + ListenableFuture future2 = ClientCalls.futureUnaryCall(call2, null); + + timer.forwardTime(1234, TimeUnit.SECONDS); + + executor.runDueTasks(); + try { + future2.get(); + Assert.fail(); + } catch (ExecutionException e) { + assertThat(Throwables.getStackTraceAsString(e.getCause())).contains("deadline"); + } + + mychannel.shutdownNow(); + } + @Test public void nameResolverArgsPropagation() { final AtomicReference capturedArgs = new AtomicReference<>(); @@ -3839,13 +3866,18 @@ public String getDefaultScheme() { return "fake"; } }; - channelBuilder.nameResolverFactory(factory).proxyDetector(neverProxy); + channelBuilder + .nameResolverFactory(factory) + .proxyDetector(neverProxy) + .setNameResolverArg(TEST_RESOLVER_CUSTOM_ARG_KEY, "test-value"); + createChannel(); NameResolver.Args args = capturedArgs.get(); assertThat(args).isNotNull(); assertThat(args.getDefaultPort()).isEqualTo(DEFAULT_PORT); assertThat(args.getProxyDetector()).isSameInstanceAs(neverProxy); + assertThat(args.getArg(TEST_RESOLVER_CUSTOM_ARG_KEY)).isEqualTo("test-value"); verify(offloadExecutor, never()).execute(any(Runnable.class)); args.getOffloadExecutor() @@ -3910,13 +3942,37 @@ public void nameResolverHelper_badConfigFails() { assertThat(coe.getError().getCause()).isInstanceOf(ClassCastException.class); } + @Test + public void nameResolverHelper_badParser_failsGracefully() { + boolean retryEnabled = false; + int maxRetryAttemptsLimit = 2; + int maxHedgedAttemptsLimit = 3; + + Throwable t = new Error("really poor config parser"); + when(mockLoadBalancerProvider.parseLoadBalancingPolicyConfig(any())).thenThrow(t); + ScParser parser = new ScParser( + retryEnabled, + maxRetryAttemptsLimit, + maxHedgedAttemptsLimit, + mockLoadBalancerProvider); + + ConfigOrError coe = parser.parseServiceConfig(ImmutableMap.of()); + + assertThat(coe.getError()).isNotNull(); + assertThat(coe.getError().getCode()).isEqualTo(Code.INTERNAL); + assertThat(coe.getError().getDescription()).contains("Unexpected error parsing service config"); + assertThat(coe.getError().getCause()).isSameInstanceAs(t); + } + @Test public void nameResolverHelper_noConfigChosen() { boolean retryEnabled = false; int maxRetryAttemptsLimit = 2; int maxHedgedAttemptsLimit = 3; + LoadBalancerRegistry registry = new LoadBalancerRegistry(); + registry.register(mockLoadBalancerProvider); AutoConfiguredLoadBalancerFactory autoConfiguredLoadBalancerFactory = - new AutoConfiguredLoadBalancerFactory("pick_first"); + new AutoConfiguredLoadBalancerFactory(registry, MOCK_POLICY_NAME); ScParser parser = new ScParser( retryEnabled, @@ -4006,6 +4062,40 @@ public void disableServiceConfigLookUp_withDefaultConfig() throws Exception { } } + @Test + public void disableServiceConfigLookUp_withDefaultConfig_withRetryThrottle() throws Exception { + LoadBalancerRegistry.getDefaultRegistry().register(mockLoadBalancerProvider); + try { + FakeNameResolverFactory nameResolverFactory = + new FakeNameResolverFactory.Builder(expectedUri) + .setServers(ImmutableList.of(addressGroup)).build(); + channelBuilder.nameResolverFactory(nameResolverFactory); + channelBuilder.disableServiceConfigLookUp(); + channelBuilder.enableRetry(); + Map defaultServiceConfig = + parseConfig("{" + + "\"retryThrottling\":{\"maxTokens\": 1, \"tokenRatio\": 1}," + + "\"methodConfig\":[{" + + "\"name\":[{\"service\":\"SimpleService1\"}]," + + "\"waitForReady\":true" + + "}]}"); + channelBuilder.defaultServiceConfig(defaultServiceConfig); + + createChannel(); + + ArgumentCaptor resultCaptor = + ArgumentCaptor.forClass(ResolvedAddresses.class); + verify(mockLoadBalancer).acceptResolvedAddresses(resultCaptor.capture()); + assertThat(resultCaptor.getValue().getAddresses()).containsExactly(addressGroup); + assertThat(resultCaptor.getValue().getAttributes().get(InternalConfigSelector.KEY)).isNull(); + verify(mockLoadBalancer, never()).handleNameResolutionError(any(Status.class)); + assertThat(channel.hasThrottle()).isTrue(); + + } finally { + LoadBalancerRegistry.getDefaultRegistry().deregister(mockLoadBalancerProvider); + } + } + @Test public void enableServiceConfigLookUp_noDefaultConfig() throws Exception { LoadBalancerRegistry.getDefaultRegistry().register(mockLoadBalancerProvider); @@ -4113,6 +4203,39 @@ public void enableServiceConfigLookUp_resolverReturnsNoConfig_withDefaultConfig( } } + + @Test + public void enableServiceConfigLookUp_usingDefaultConfig_withRetryThrottling() throws Exception { + LoadBalancerRegistry.getDefaultRegistry().register(mockLoadBalancerProvider); + try { + FakeNameResolverFactory nameResolverFactory = + new FakeNameResolverFactory.Builder(expectedUri) + .setServers(ImmutableList.of(addressGroup)).build(); + channelBuilder.nameResolverFactory(nameResolverFactory); + channelBuilder.enableRetry(); + Map defaultServiceConfig = + parseConfig("{" + + "\"retryThrottling\":{\"maxTokens\": 1, \"tokenRatio\": 1}," + + "\"methodConfig\":[{" + + "\"name\":[{\"service\":\"SimpleService1\"}]," + + "\"waitForReady\":true" + + "}]}"); + channelBuilder.defaultServiceConfig(defaultServiceConfig); + + createChannel(); + + ArgumentCaptor resultCaptor = + ArgumentCaptor.forClass(ResolvedAddresses.class); + verify(mockLoadBalancer).acceptResolvedAddresses(resultCaptor.capture()); + assertThat(resultCaptor.getValue().getAddresses()).containsExactly(addressGroup); + assertThat(resultCaptor.getValue().getAttributes().get(InternalConfigSelector.KEY)).isNull(); + verify(mockLoadBalancer, never()).handleNameResolutionError(any(Status.class)); + assertThat(channel.hasThrottle()).isTrue(); + } finally { + LoadBalancerRegistry.getDefaultRegistry().deregister(mockLoadBalancerProvider); + } + } + @Test public void enableServiceConfigLookUp_resolverReturnsNoConfig_noDefaultConfig() { LoadBalancerRegistry.getDefaultRegistry().register(mockLoadBalancerProvider); @@ -4184,7 +4307,7 @@ public void notUseDefaultImmediatelyIfEnableLookUp() throws Exception { int size = getStats(channel).channelTrace.events.size(); assertThat(getStats(channel).channelTrace.events.get(size - 1)) .isNotEqualTo(new ChannelTrace.Event.Builder() - .setDescription("Using default service config") + .setDescription("timer.forwardNanos(1234);") .setSeverity(ChannelTrace.Event.Severity.CT_INFO) .setTimestampNanos(timer.getTicker().read()) .build()); @@ -4278,12 +4401,86 @@ public void transportTerminated(Attributes transportAttrs) { assertEquals(1, readyCallbackCalled.get()); assertEquals(0, terminationCallbackCalled.get()); - transportListener.transportShutdown(Status.OK); + transportListener.transportShutdown(Status.OK, SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); transportListener.transportTerminated(); assertEquals(1, terminationCallbackCalled.get()); } + @Test + public void validAuthorityTarget_overrideAuthority() throws Exception { + String overrideAuthority = "override.authority"; + String serviceAuthority = "fakeauthority"; + NameResolverProvider nameResolverProvider = new NameResolverProvider() { + @Override protected boolean isAvailable() { + return true; + } + + @Override protected int priority() { + return 5; + } + + @Override public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) { + return new NameResolver() { + @Override public String getServiceAuthority() { + return serviceAuthority; + } + + @Override public void start(final Listener2 listener) {} + + @Override public void shutdown() {} + }; + } + + @Override public String getDefaultScheme() { + return "defaultscheme"; + } + }; + + URI targetUri = new URI("defaultscheme", "", "/foo.googleapis.com:8080", null); + NameResolver nameResolver = ManagedChannelImpl.getNameResolver( + wrap(targetUri), null, nameResolverProvider, NAMERESOLVER_ARGS); + assertThat(nameResolver.getServiceAuthority()).isEqualTo(serviceAuthority); + + nameResolver = ManagedChannelImpl.getNameResolver( + wrap(targetUri), overrideAuthority, nameResolverProvider, NAMERESOLVER_ARGS); + assertThat(nameResolver.getServiceAuthority()).isEqualTo(overrideAuthority); + } + + @Test + public void validTargetNoResolver_throws() { + NameResolverProvider nameResolverProvider = new NameResolverProvider() { + @Override + protected boolean isAvailable() { + return true; + } + + @Override + protected int priority() { + return 5; + } + + @Override + public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) { + return null; + } + + @Override + public String getDefaultScheme() { + return "defaultscheme"; + } + }; + try { + ManagedChannelImpl.getNameResolver( + wrap(URI.create("defaultscheme:///foo.gogoleapis.com:8080")), + null, nameResolverProvider, NAMERESOLVER_ARGS); + fail("Should fail"); + } catch (IllegalArgumentException e) { + // expected + } + } + + private static final class FakeBackoffPolicyProvider implements BackoffPolicy.Provider { @Override public BackoffPolicy get() { @@ -4298,7 +4495,7 @@ public long nextBackoffNanos() { } } - private static final class FakeNameResolverFactory extends NameResolver.Factory { + private static final class FakeNameResolverFactory extends NameResolverProvider { final List expectedUris; final List servers; final boolean resolvedAtStart; @@ -4325,7 +4522,7 @@ public NameResolver newNameResolver(final URI targetUri, NameResolver.Args args) } assertEquals(DEFAULT_PORT, args.getDefaultPort()); FakeNameResolverFactory.FakeNameResolver resolver = - new FakeNameResolverFactory.FakeNameResolver(targetUri, error); + new FakeNameResolverFactory.FakeNameResolver(targetUri, error, args); resolvers.add(resolver); return resolver; } @@ -4335,6 +4532,16 @@ public String getDefaultScheme() { return "fake"; } + @Override + public int priority() { + return 9; + } + + @Override + public boolean isAvailable() { + return true; + } + void allResolved() { for (FakeNameResolverFactory.FakeNameResolver resolver : resolvers) { resolver.resolved(); @@ -4343,14 +4550,16 @@ void allResolved() { final class FakeNameResolver extends NameResolver { final URI targetUri; + final SynchronizationContext syncContext; Listener2 listener; boolean shutdown; int refreshCalled; Status error; - FakeNameResolver(URI targetUri, Status error) { + FakeNameResolver(URI targetUri, Status error, Args args) { this.targetUri = targetUri; this.error = error; + syncContext = args.getSynchronizationContext(); } @Override public String getServiceAuthority() { @@ -4371,7 +4580,10 @@ final class FakeNameResolver extends NameResolver { void resolved() { if (error != null) { - listener.onError(error); + syncContext.execute(() -> + listener.onResult2( + ResolutionResult.newBuilder() + .setAddressesOrError(StatusOr.fromStatus(error)).build())); return; } ResolutionResult.Builder builder = @@ -4382,7 +4594,7 @@ void resolved() { if (configOrError != null) { builder.setServiceConfig(configOrError); } - listener.onResult(builder.build()); + syncContext.execute(() -> listener.onResult(builder.build())); } @Override public void shutdown() { diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelOrphanWrapperTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelOrphanWrapperTest.java index 5ae97c69211..45fb3881722 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelOrphanWrapperTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelOrphanWrapperTest.java @@ -101,6 +101,45 @@ public boolean isDone() { } } + @Test + public void shutdown_withDelegateStillReferenced_doesNotLogWarning() { + ManagedChannel mc = new TestManagedChannel(); + final ReferenceQueue refqueue = new ReferenceQueue<>(); + ConcurrentMap refs = + new ConcurrentHashMap<>(); + + ManagedChannelOrphanWrapper wrapper = new ManagedChannelOrphanWrapper(mc, refqueue, refs); + WeakReference wrapperWeakRef = new WeakReference<>(wrapper); + + final List records = new ArrayList<>(); + Logger orphanLogger = Logger.getLogger(ManagedChannelOrphanWrapper.class.getName()); + Filter oldFilter = orphanLogger.getFilter(); + orphanLogger.setFilter(new Filter() { + @Override + public boolean isLoggable(LogRecord record) { + synchronized (records) { + records.add(record); + } + return false; + } + }); + + try { + wrapper.shutdown(); + wrapper = null; + + // Wait for the WRAPPER itself to be garbage collected + GcFinalization.awaitClear(wrapperWeakRef); + ManagedChannelReference.cleanQueue(refqueue); + + synchronized (records) { + assertEquals("Warning was logged even though shutdownNow() was called!", 0, records.size()); + } + } finally { + orphanLogger.setFilter(oldFilter); + } + } + @Test public void refCycleIsGCed() { ReferenceQueue refqueue = diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelServiceConfigTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelServiceConfigTest.java index c25a0808584..fefc37e4fdc 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelServiceConfigTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelServiceConfigTest.java @@ -20,6 +20,7 @@ import static io.grpc.MethodDescriptor.MethodType.UNARY; import static io.grpc.Status.Code.UNAVAILABLE; import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.fail; import com.google.common.collect.ImmutableList; @@ -27,25 +28,20 @@ import io.grpc.CallOptions; import io.grpc.InternalConfigSelector; import io.grpc.InternalConfigSelector.Result; +import io.grpc.LoadBalancer.PickDetailsConsumer; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.internal.ManagedChannelServiceConfig.MethodInfo; import io.grpc.testing.TestMethodDescriptors; import java.util.Collections; import java.util.Map; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @RunWith(JUnit4.class) public class ManagedChannelServiceConfigTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); - @Test public void managedChannelServiceConfig_shouldParseHealthCheckingConfig() throws Exception { Map rawServiceConfig = @@ -78,10 +74,9 @@ public void createManagedChannelServiceConfig_failsOnDuplicateMethod() { Map methodConfig = ImmutableMap.of("name", ImmutableList.of(name1, name2)); Map serviceConfig = ImmutableMap.of("methodConfig", ImmutableList.of(methodConfig)); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Duplicate method"); - - ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null)); + assertThat(e).hasMessageThat().isEqualTo("Duplicate method name service/method"); } @Test @@ -91,10 +86,9 @@ public void createManagedChannelServiceConfig_failsOnDuplicateService() { Map methodConfig = ImmutableMap.of("name", ImmutableList.of(name1, name2)); Map serviceConfig = ImmutableMap.of("methodConfig", ImmutableList.of(methodConfig)); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Duplicate service"); - - ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null)); + assertThat(e).hasMessageThat().isEqualTo("Duplicate service service"); } @Test @@ -106,10 +100,9 @@ public void createManagedChannelServiceConfig_failsOnDuplicateServiceMultipleCon Map serviceConfig = ImmutableMap.of("methodConfig", ImmutableList.of(methodConfig1, methodConfig2)); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Duplicate service"); - - ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null)); + assertThat(e).hasMessageThat().isEqualTo("Duplicate service service"); } @Test @@ -118,10 +111,9 @@ public void createManagedChannelServiceConfig_failsOnMethodNameWithEmptyServiceN Map methodConfig = ImmutableMap.of("name", ImmutableList.of(name)); Map serviceConfig = ImmutableMap.of("methodConfig", ImmutableList.of(methodConfig)); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("missing service name for method method1"); - - ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null)); + assertThat(e).hasMessageThat().isEqualTo("missing service name for method method1"); } @Test @@ -130,10 +122,9 @@ public void createManagedChannelServiceConfig_failsOnMethodNameWithoutServiceNam Map methodConfig = ImmutableMap.of("name", ImmutableList.of(name)); Map serviceConfig = ImmutableMap.of("methodConfig", ImmutableList.of(methodConfig)); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("missing service name for method method1"); - - ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null)); + assertThat(e).hasMessageThat().isEqualTo("missing service name for method method1"); } @Test @@ -142,10 +133,9 @@ public void createManagedChannelServiceConfig_failsOnMissingServiceName() { Map methodConfig = ImmutableMap.of("name", ImmutableList.of(name)); Map serviceConfig = ImmutableMap.of("methodConfig", ImmutableList.of(methodConfig)); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("missing service"); - - ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null)); + assertThat(e).hasMessageThat().isEqualTo("missing service name for method method"); } @Test @@ -209,7 +199,8 @@ public void getDefaultConfigSelectorFromConfig() { InternalConfigSelector configSelector = serviceConfig.getDefaultConfigSelector(); MethodDescriptor method = methodForName("service1", "method1"); Result result = configSelector.selectConfig( - new PickSubchannelArgsImpl(method, new Metadata(), CallOptions.DEFAULT)); + new PickSubchannelArgsImpl( + method, new Metadata(), CallOptions.DEFAULT, new PickDetailsConsumer() {})); MethodInfo methodInfoFromDefaultConfigSelector = ((ManagedChannelServiceConfig) result.getConfig()).getMethodConfig(method); assertThat(methodInfoFromDefaultConfigSelector) diff --git a/core/src/test/java/io/grpc/internal/ManagedClientTransportTest.java b/core/src/test/java/io/grpc/internal/ManagedClientTransportTest.java index 0af88a62728..5ddea08131b 100644 --- a/core/src/test/java/io/grpc/internal/ManagedClientTransportTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedClientTransportTest.java @@ -32,7 +32,7 @@ public class ManagedClientTransportTest { public void testListener() { ManagedClientTransport.Listener listener = new ManagedClientTransport.Listener() { @Override - public void transportShutdown(Status s) {} + public void transportShutdown(Status s, DisconnectError e) {} @Override public void transportTerminated() {} @@ -45,7 +45,7 @@ public void transportInUse(boolean inUse) {} }; // Test that the listener methods do not throw. - listener.transportShutdown(Status.OK); + listener.transportShutdown(Status.OK, SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); listener.transportTerminated(); listener.transportReady(); listener.transportInUse(true); diff --git a/core/src/test/java/io/grpc/internal/MessageDeframerTest.java b/core/src/test/java/io/grpc/internal/MessageDeframerTest.java index 98ed0691458..54758bc096f 100644 --- a/core/src/test/java/io/grpc/internal/MessageDeframerTest.java +++ b/core/src/test/java/io/grpc/internal/MessageDeframerTest.java @@ -20,6 +20,7 @@ import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assume.assumeTrue; import static org.mockito.ArgumentMatchers.anyInt; @@ -31,7 +32,6 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; -import com.google.common.base.Charsets; import com.google.common.io.ByteStreams; import com.google.common.primitives.Bytes; import io.grpc.Codec; @@ -46,6 +46,7 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Collection; import java.util.List; @@ -53,10 +54,8 @@ import java.util.concurrent.TimeUnit; import java.util.zip.GZIPOutputStream; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; import org.junit.experimental.runners.Enclosed; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.junit.runners.Parameterized; @@ -133,7 +132,7 @@ public void simplePayload() { assertEquals(Bytes.asList(new byte[]{3, 14}), bytes(producer.getValue().next())); verify(listener, atLeastOnce()).bytesRead(anyInt()); verifyNoMoreInteractions(listener); - checkStats(tracer, transportTracer.getStats(), fakeClock, 2, 2); + checkStats(tracer, transportTracer.getStats(), fakeClock, useGzipInflatingBuffer, 2, 2); } @Test @@ -148,7 +147,7 @@ public void smallCombinedPayloads() { verify(listener, atLeastOnce()).bytesRead(anyInt()); assertEquals(Bytes.asList(new byte[]{14, 15}), bytes(streams.get(1).next())); verifyNoMoreInteractions(listener); - checkStats(tracer, transportTracer.getStats(), fakeClock, 1, 1, 2, 2); + checkStats(tracer, transportTracer.getStats(), fakeClock, useGzipInflatingBuffer, 1, 1, 2, 2); } @Test @@ -162,7 +161,7 @@ public void endOfStreamWithPayloadShouldNotifyEndOfStream() { verify(listener).deframerClosed(false); verify(listener, atLeastOnce()).bytesRead(anyInt()); verifyNoMoreInteractions(listener); - checkStats(tracer, transportTracer.getStats(), fakeClock, 1, 1); + checkStats(tracer, transportTracer.getStats(), fakeClock, useGzipInflatingBuffer, 1, 1); } @Test @@ -177,7 +176,7 @@ public void endOfStreamShouldNotifyEndOfStream() { } verify(listener).deframerClosed(false); verifyNoMoreInteractions(listener); - checkStats(tracer, transportTracer.getStats(), fakeClock); + checkStats(tracer, transportTracer.getStats(), fakeClock, false); } @Test @@ -189,7 +188,7 @@ public void endOfStreamWithPartialMessageShouldNotifyDeframerClosedWithPartialMe verify(listener, atLeastOnce()).bytesRead(anyInt()); verify(listener).deframerClosed(true); verifyNoMoreInteractions(listener); - checkStats(tracer, transportTracer.getStats(), fakeClock); + checkStats(tracer, transportTracer.getStats(), fakeClock, false); } @Test @@ -206,7 +205,7 @@ public void endOfStreamWithInvalidGzipBlockShouldNotifyDeframerClosedWithPartial deframer.closeWhenComplete(); verify(listener).deframerClosed(true); verifyNoMoreInteractions(listener); - checkStats(tracer, transportTracer.getStats(), fakeClock); + checkStats(tracer, transportTracer.getStats(), fakeClock, false); } @Test @@ -228,10 +227,11 @@ public void payloadSplitBetweenBuffers() { tracer, transportTracer.getStats(), fakeClock, + true, 7 /* msg size */ + 2 /* second buffer adds two bytes of overhead in deflate block */, 7); } else { - checkStats(tracer, transportTracer.getStats(), fakeClock, 7, 7); + checkStats(tracer, transportTracer.getStats(), fakeClock, false, 7, 7); } } @@ -248,7 +248,7 @@ public void frameHeaderSplitBetweenBuffers() { assertEquals(Bytes.asList(new byte[]{3}), bytes(producer.getValue().next())); verify(listener, atLeastOnce()).bytesRead(anyInt()); verifyNoMoreInteractions(listener); - checkStats(tracer, transportTracer.getStats(), fakeClock, 1, 1); + checkStats(tracer, transportTracer.getStats(), fakeClock, useGzipInflatingBuffer, 1, 1); } @Test @@ -259,7 +259,7 @@ public void emptyPayload() { assertEquals(Bytes.asList(), bytes(producer.getValue().next())); verify(listener, atLeastOnce()).bytesRead(anyInt()); verifyNoMoreInteractions(listener); - checkStats(tracer, transportTracer.getStats(), fakeClock, 0, 0); + checkStats(tracer, transportTracer.getStats(), fakeClock, useGzipInflatingBuffer, 0, 0); } @Test @@ -273,9 +273,10 @@ public void largerFrameSize() { verify(listener, atLeastOnce()).bytesRead(anyInt()); verifyNoMoreInteractions(listener); if (useGzipInflatingBuffer) { - checkStats(tracer, transportTracer.getStats(), fakeClock, 8 /* compressed size */, 1000); + checkStats(tracer, transportTracer.getStats(), fakeClock,true, + 8 /* compressed size */, 1000); } else { - checkStats(tracer, transportTracer.getStats(), fakeClock, 1000, 1000); + checkStats(tracer, transportTracer.getStats(), fakeClock, false, 1000, 1000); } } @@ -292,7 +293,7 @@ public void endOfStreamCallbackShouldWaitForMessageDelivery() { verify(listener).deframerClosed(false); verify(listener, atLeastOnce()).bytesRead(anyInt()); verifyNoMoreInteractions(listener); - checkStats(tracer, transportTracer.getStats(), fakeClock, 1, 1); + checkStats(tracer, transportTracer.getStats(), fakeClock, useGzipInflatingBuffer, 1, 1); } @Test @@ -308,6 +309,7 @@ public void compressed() { verify(listener).messagesAvailable(producer.capture()); assertEquals(Bytes.asList(new byte[1000]), bytes(producer.getValue().next())); verify(listener, atLeastOnce()).bytesRead(anyInt()); + checkStats(tracer, transportTracer.getStats(), fakeClock, true, 29, 1000); verifyNoMoreInteractions(listener); } @@ -338,16 +340,13 @@ public Void answer(InvocationOnMock invocation) throws Throwable { @RunWith(JUnit4.class) public static class SizeEnforcingInputStreamTests { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); private TestBaseStreamTracer tracer = new TestBaseStreamTracer(); private StatsTraceContext statsTraceCtx = new StatsTraceContext(new StreamTracer[]{tracer}); @Test public void sizeEnforcingInputStream_readByteBelowLimit() throws IOException { - ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); + ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(StandardCharsets.UTF_8)); SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 4, statsTraceCtx); @@ -360,7 +359,7 @@ public void sizeEnforcingInputStream_readByteBelowLimit() throws IOException { @Test public void sizeEnforcingInputStream_readByteAtLimit() throws IOException { - ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); + ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(StandardCharsets.UTF_8)); SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 3, statsTraceCtx); @@ -373,16 +372,17 @@ public void sizeEnforcingInputStream_readByteAtLimit() throws IOException { @Test public void sizeEnforcingInputStream_readByteAboveLimit() throws IOException { - ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); + ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(StandardCharsets.UTF_8)); SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 2, statsTraceCtx); try { - thrown.expect(StatusRuntimeException.class); - thrown.expectMessage("RESOURCE_EXHAUSTED: Decompressed gRPC message exceeds"); - - while (stream.read() != -1) { - } + StatusRuntimeException e = assertThrows(StatusRuntimeException.class, () -> { + while (stream.read() != -1) { + } + }); + assertThat(e).hasMessageThat() + .isEqualTo("RESOURCE_EXHAUSTED: Decompressed gRPC message exceeds maximum size 2"); } finally { stream.close(); } @@ -390,7 +390,7 @@ public void sizeEnforcingInputStream_readByteAboveLimit() throws IOException { @Test public void sizeEnforcingInputStream_readBelowLimit() throws IOException { - ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); + ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(StandardCharsets.UTF_8)); SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 4, statsTraceCtx); byte[] buf = new byte[10]; @@ -404,7 +404,7 @@ public void sizeEnforcingInputStream_readBelowLimit() throws IOException { @Test public void sizeEnforcingInputStream_readAtLimit() throws IOException { - ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); + ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(StandardCharsets.UTF_8)); SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 3, statsTraceCtx); byte[] buf = new byte[10]; @@ -418,16 +418,16 @@ public void sizeEnforcingInputStream_readAtLimit() throws IOException { @Test public void sizeEnforcingInputStream_readAboveLimit() throws IOException { - ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); + ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(StandardCharsets.UTF_8)); SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 2, statsTraceCtx); byte[] buf = new byte[10]; try { - thrown.expect(StatusRuntimeException.class); - thrown.expectMessage("RESOURCE_EXHAUSTED: Decompressed gRPC message exceeds"); - - stream.read(buf, 0, buf.length); + StatusRuntimeException e = assertThrows(StatusRuntimeException.class, + () -> stream.read(buf, 0, buf.length)); + assertThat(e).hasMessageThat() + .isEqualTo("RESOURCE_EXHAUSTED: Decompressed gRPC message exceeds maximum size 2"); } finally { stream.close(); } @@ -435,7 +435,7 @@ public void sizeEnforcingInputStream_readAboveLimit() throws IOException { @Test public void sizeEnforcingInputStream_skipBelowLimit() throws IOException { - ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); + ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(StandardCharsets.UTF_8)); SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 4, statsTraceCtx); @@ -449,7 +449,7 @@ public void sizeEnforcingInputStream_skipBelowLimit() throws IOException { @Test public void sizeEnforcingInputStream_skipAtLimit() throws IOException { - ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); + ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(StandardCharsets.UTF_8)); SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 3, statsTraceCtx); @@ -462,15 +462,14 @@ public void sizeEnforcingInputStream_skipAtLimit() throws IOException { @Test public void sizeEnforcingInputStream_skipAboveLimit() throws IOException { - ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); + ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(StandardCharsets.UTF_8)); SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 2, statsTraceCtx); try { - thrown.expect(StatusRuntimeException.class); - thrown.expectMessage("RESOURCE_EXHAUSTED: Decompressed gRPC message exceeds"); - - stream.skip(4); + StatusRuntimeException e = assertThrows(StatusRuntimeException.class, () -> stream.skip(4)); + assertThat(e).hasMessageThat() + .isEqualTo("RESOURCE_EXHAUSTED: Decompressed gRPC message exceeds maximum size 2"); } finally { stream.close(); } @@ -478,7 +477,7 @@ public void sizeEnforcingInputStream_skipAboveLimit() throws IOException { @Test public void sizeEnforcingInputStream_markReset() throws IOException { - ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(Charsets.UTF_8)); + ByteArrayInputStream in = new ByteArrayInputStream("foo".getBytes(StandardCharsets.UTF_8)); SizeEnforcingInputStream stream = new MessageDeframer.SizeEnforcingInputStream(in, 3, statsTraceCtx); // stream currently looks like: |foo @@ -502,7 +501,8 @@ public void sizeEnforcingInputStream_markReset() throws IOException { * @param sizes in the format {wire0, uncompressed0, wire1, uncompressed1, ...} */ private static void checkStats( - TestBaseStreamTracer tracer, TransportStats transportStats, FakeClock clock, long... sizes) { + TestBaseStreamTracer tracer, TransportStats transportStats, FakeClock clock, + boolean compressed, long... sizes) { assertEquals(0, sizes.length % 2); int count = sizes.length / 2; long expectedWireSize = 0; @@ -510,7 +510,8 @@ private static void checkStats( for (int i = 0; i < count; i++) { assertEquals("inboundMessage(" + i + ")", tracer.nextInboundEvent()); assertEquals( - String.format(Locale.US, "inboundMessageRead(%d, %d, -1)", i, sizes[i * 2]), + String.format(Locale.US, "inboundMessageRead(%d, %d, %d)", i, sizes[i * 2], + compressed ? -1 : sizes[i * 2 + 1]), tracer.nextInboundEvent()); expectedWireSize += sizes[i * 2]; expectedUncompressedSize += sizes[i * 2 + 1]; diff --git a/core/src/test/java/io/grpc/internal/MetricRecorderImplTest.java b/core/src/test/java/io/grpc/internal/MetricRecorderImplTest.java new file mode 100644 index 00000000000..33bf9bb41e2 --- /dev/null +++ b/core/src/test/java/io/grpc/internal/MetricRecorderImplTest.java @@ -0,0 +1,329 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableList; +import io.grpc.DoubleCounterMetricInstrument; +import io.grpc.DoubleHistogramMetricInstrument; +import io.grpc.LongCounterMetricInstrument; +import io.grpc.LongGaugeMetricInstrument; +import io.grpc.LongHistogramMetricInstrument; +import io.grpc.LongUpDownCounterMetricInstrument; +import io.grpc.MetricInstrumentRegistry; +import io.grpc.MetricInstrumentRegistryAccessor; +import io.grpc.MetricRecorder; +import io.grpc.MetricSink; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; + +/** + * Unit test for {@link MetricRecorderImpl}. + */ +@RunWith(JUnit4.class) +public class MetricRecorderImplTest { + private static final String DESCRIPTION = "description"; + private static final String UNIT = "unit"; + private static final boolean ENABLED = true; + private static final ImmutableList REQUIRED_LABEL_KEYS = ImmutableList.of("KEY1", "KEY2"); + private static final ImmutableList OPTIONAL_LABEL_KEYS = ImmutableList.of( + "OPTIONAL_KEY_1"); + private static final ImmutableList REQUIRED_LABEL_VALUES = ImmutableList.of("VALUE1", + "VALUE2"); + private static final ImmutableList OPTIONAL_LABEL_VALUES = ImmutableList.of( + "OPTIONAL_VALUE_1"); + private MetricSink mockSink = mock(MetricSink.class); + private List sinks = Arrays.asList(mockSink, mockSink); + private MetricInstrumentRegistry registry = + MetricInstrumentRegistryAccessor.createMetricInstrumentRegistry(); + private final DoubleCounterMetricInstrument doubleCounterInstrument = + registry.registerDoubleCounter("counter0", DESCRIPTION, UNIT, REQUIRED_LABEL_KEYS, + OPTIONAL_LABEL_KEYS, ENABLED); + private final LongCounterMetricInstrument longCounterInstrument = + registry.registerLongCounter("counter1", DESCRIPTION, UNIT, REQUIRED_LABEL_KEYS, + OPTIONAL_LABEL_KEYS, ENABLED); + private final DoubleHistogramMetricInstrument doubleHistogramInstrument = + registry.registerDoubleHistogram("histogram1", DESCRIPTION, UNIT, + Collections.emptyList(), REQUIRED_LABEL_KEYS, OPTIONAL_LABEL_KEYS, ENABLED); + private final LongHistogramMetricInstrument longHistogramInstrument = + registry.registerLongHistogram("histogram2", DESCRIPTION, UNIT, + Collections.emptyList(), REQUIRED_LABEL_KEYS, OPTIONAL_LABEL_KEYS, ENABLED); + private final LongGaugeMetricInstrument longGaugeInstrument = + registry.registerLongGauge("gauge0", DESCRIPTION, UNIT, REQUIRED_LABEL_KEYS, + OPTIONAL_LABEL_KEYS, ENABLED); + private final LongUpDownCounterMetricInstrument longUpDownCounterInstrument = + registry.registerLongUpDownCounter("upDownCounter0", DESCRIPTION, UNIT, + REQUIRED_LABEL_KEYS, OPTIONAL_LABEL_KEYS, ENABLED); + private MetricRecorder recorder; + + @Before + public void setUp() { + recorder = new MetricRecorderImpl(sinks, registry); + } + + @Test + public void addCounter() { + when(mockSink.getMeasuresSize()).thenReturn(6); + + recorder.addDoubleCounter(doubleCounterInstrument, 1.0, REQUIRED_LABEL_VALUES, + OPTIONAL_LABEL_VALUES); + verify(mockSink, times(2)).addDoubleCounter(eq(doubleCounterInstrument), eq(1D), + eq(REQUIRED_LABEL_VALUES), eq(OPTIONAL_LABEL_VALUES)); + + recorder.addLongCounter(longCounterInstrument, 1, REQUIRED_LABEL_VALUES, + OPTIONAL_LABEL_VALUES); + verify(mockSink, times(2)).addLongCounter(eq(longCounterInstrument), eq(1L), + eq(REQUIRED_LABEL_VALUES), eq(OPTIONAL_LABEL_VALUES)); + + recorder.addLongUpDownCounter(longUpDownCounterInstrument, -10, REQUIRED_LABEL_VALUES, + OPTIONAL_LABEL_VALUES); + verify(mockSink, times(2)) + .addLongUpDownCounter(eq(longUpDownCounterInstrument), eq(-10L), + eq(REQUIRED_LABEL_VALUES), eq(OPTIONAL_LABEL_VALUES)); + + verify(mockSink, never()).updateMeasures(registry.getMetricInstruments()); + } + + @Test + public void recordHistogram() { + when(mockSink.getMeasuresSize()).thenReturn(4); + + recorder.recordDoubleHistogram(doubleHistogramInstrument, 99.0, REQUIRED_LABEL_VALUES, + OPTIONAL_LABEL_VALUES); + verify(mockSink, times(2)).recordDoubleHistogram(eq(doubleHistogramInstrument), + eq(99D), eq(REQUIRED_LABEL_VALUES), eq(OPTIONAL_LABEL_VALUES)); + + recorder.recordLongHistogram(longHistogramInstrument, 99, REQUIRED_LABEL_VALUES, + OPTIONAL_LABEL_VALUES); + verify(mockSink, times(2)).recordLongHistogram(eq(longHistogramInstrument), eq(99L), + eq(REQUIRED_LABEL_VALUES), eq(OPTIONAL_LABEL_VALUES)); + + verify(mockSink, never()).updateMeasures(registry.getMetricInstruments()); + } + + @Test + public void recordCallback() { + MetricSink.Registration mockRegistration = mock(MetricSink.Registration.class); + when(mockSink.getMeasuresSize()).thenReturn(5); + when(mockSink.registerBatchCallback(any(Runnable.class), eq(longGaugeInstrument))) + .thenReturn(mockRegistration); + + MetricRecorder.Registration registration = recorder.registerBatchCallback((recorder) -> { + recorder.recordLongGauge( + longGaugeInstrument, 99, REQUIRED_LABEL_VALUES, OPTIONAL_LABEL_VALUES); + }, longGaugeInstrument); + + ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(Runnable.class); + verify(mockSink, times(2)) + .registerBatchCallback(callbackCaptor.capture(), eq(longGaugeInstrument)); + + callbackCaptor.getValue().run(); + // Only once, for the one sink that called the callback. + verify(mockSink).recordLongGauge( + longGaugeInstrument, 99, REQUIRED_LABEL_VALUES, OPTIONAL_LABEL_VALUES); + + verify(mockRegistration, never()).close(); + registration.close(); + verify(mockRegistration, times(2)).close(); + + verify(mockSink, never()).updateMeasures(registry.getMetricInstruments()); + } + + @Test + public void newRegisteredMetricUpdateMeasures() { + // Sink is initialized with zero measures, should trigger updateMeasures() on sinks + when(mockSink.getMeasuresSize()).thenReturn(0); + + // Double Counter + recorder.addDoubleCounter(doubleCounterInstrument, 1.0, REQUIRED_LABEL_VALUES, + OPTIONAL_LABEL_VALUES); + verify(mockSink, times(2)).updateMeasures(anyList()); + verify(mockSink, times(2)).addDoubleCounter(eq(doubleCounterInstrument), eq(1D), + eq(REQUIRED_LABEL_VALUES), eq(OPTIONAL_LABEL_VALUES)); + + // Long Counter + recorder.addLongCounter(longCounterInstrument, 1, REQUIRED_LABEL_VALUES, + OPTIONAL_LABEL_VALUES); + verify(mockSink, times(4)).updateMeasures(anyList()); + verify(mockSink, times(2)).addLongCounter(eq(longCounterInstrument), eq(1L), + eq(REQUIRED_LABEL_VALUES), eq(OPTIONAL_LABEL_VALUES)); + + // Double Histogram + recorder.recordDoubleHistogram(doubleHistogramInstrument, 99.0, REQUIRED_LABEL_VALUES, + OPTIONAL_LABEL_VALUES); + verify(mockSink, times(6)).updateMeasures(anyList()); + verify(mockSink, times(2)).recordDoubleHistogram(eq(doubleHistogramInstrument), + eq(99D), eq(REQUIRED_LABEL_VALUES), eq(OPTIONAL_LABEL_VALUES)); + + // Long Histogram + recorder.recordLongHistogram(longHistogramInstrument, 99, REQUIRED_LABEL_VALUES, + OPTIONAL_LABEL_VALUES); + verify(mockSink, times(8)).updateMeasures(registry.getMetricInstruments()); + verify(mockSink, times(2)).recordLongHistogram(eq(longHistogramInstrument), eq(99L), + eq(REQUIRED_LABEL_VALUES), eq(OPTIONAL_LABEL_VALUES)); + + // Callback + when(mockSink.registerBatchCallback(any(Runnable.class), eq(longGaugeInstrument))) + .thenReturn(mock(MetricSink.Registration.class)); + MetricRecorder.Registration registration = recorder.registerBatchCallback( + (recorder) -> { }, longGaugeInstrument); + verify(mockSink, times(10)).updateMeasures(registry.getMetricInstruments()); + verify(mockSink, times(2)) + .registerBatchCallback(any(Runnable.class), eq(longGaugeInstrument)); + registration.close(); + + // Long UpDown Counter + recorder.addLongUpDownCounter(longUpDownCounterInstrument, -10, REQUIRED_LABEL_VALUES, + OPTIONAL_LABEL_VALUES); + verify(mockSink, times(12)).updateMeasures(anyList()); + verify(mockSink, times(2)).addLongUpDownCounter(eq(longUpDownCounterInstrument), eq(-10L), + eq(REQUIRED_LABEL_VALUES), eq(OPTIONAL_LABEL_VALUES)); + } + + @Test(expected = IllegalArgumentException.class) + public void addDoubleCounterMismatchedRequiredLabelValues() { + when(mockSink.getMeasuresSize()).thenReturn(4); + + recorder.addDoubleCounter(doubleCounterInstrument, 1.0, ImmutableList.of(), + OPTIONAL_LABEL_VALUES); + } + + @Test(expected = IllegalArgumentException.class) + public void addLongCounterMismatchedRequiredLabelValues() { + when(mockSink.getMeasuresSize()).thenReturn(4); + + recorder.addLongCounter(longCounterInstrument, 1, ImmutableList.of(), + OPTIONAL_LABEL_VALUES); + } + + @Test(expected = IllegalArgumentException.class) + public void addLongUpDownCounterMismatchedRequiredLabelValues() { + when(mockSink.getMeasuresSize()).thenReturn(6); + recorder.addLongUpDownCounter(longUpDownCounterInstrument, 1, ImmutableList.of(), + OPTIONAL_LABEL_VALUES); + } + + @Test(expected = IllegalArgumentException.class) + public void recordDoubleHistogramMismatchedRequiredLabelValues() { + when(mockSink.getMeasuresSize()).thenReturn(4); + + recorder.recordDoubleHistogram(doubleHistogramInstrument, 99.0, ImmutableList.of(), + OPTIONAL_LABEL_VALUES); + } + + @Test(expected = IllegalArgumentException.class) + public void recordLongHistogramMismatchedRequiredLabelValues() { + when(mockSink.getMeasuresSize()).thenReturn(4); + + recorder.recordLongHistogram(longHistogramInstrument, 99, ImmutableList.of(), + OPTIONAL_LABEL_VALUES); + } + + @Test + public void recordLongGaugeMismatchedRequiredLabelValues() { + when(mockSink.getMeasuresSize()).thenReturn(4); + when(mockSink.registerBatchCallback(any(Runnable.class), eq(longGaugeInstrument))) + .thenReturn(mock(MetricSink.Registration.class)); + + MetricRecorder.Registration registration = recorder.registerBatchCallback((recorder) -> { + assertThrows( + IllegalArgumentException.class, + () -> recorder.recordLongGauge( + longGaugeInstrument, 99, ImmutableList.of(), OPTIONAL_LABEL_VALUES)); + }, longGaugeInstrument); + + ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(Runnable.class); + verify(mockSink, times(2)) + .registerBatchCallback(callbackCaptor.capture(), eq(longGaugeInstrument)); + callbackCaptor.getValue().run(); + registration.close(); + } + + @Test(expected = IllegalArgumentException.class) + public void addDoubleCounterMismatchedOptionalLabelValues() { + when(mockSink.getMeasuresSize()).thenReturn(4); + + recorder.addDoubleCounter(doubleCounterInstrument, 1.0, REQUIRED_LABEL_VALUES, + ImmutableList.of()); + } + + @Test(expected = IllegalArgumentException.class) + public void addLongCounterMismatchedOptionalLabelValues() { + when(mockSink.getMeasuresSize()).thenReturn(4); + + recorder.addLongCounter(longCounterInstrument, 1, REQUIRED_LABEL_VALUES, + ImmutableList.of()); + } + + @Test(expected = IllegalArgumentException.class) + public void addLongUpDownCounterMismatchedOptionalLabelValues() { + when(mockSink.getMeasuresSize()).thenReturn(6); + recorder.addLongUpDownCounter(longUpDownCounterInstrument, 1, REQUIRED_LABEL_VALUES, + ImmutableList.of()); + } + + @Test(expected = IllegalArgumentException.class) + public void recordDoubleHistogramMismatchedOptionalLabelValues() { + when(mockSink.getMeasuresSize()).thenReturn(4); + + recorder.recordDoubleHistogram(doubleHistogramInstrument, 99.0, REQUIRED_LABEL_VALUES, + ImmutableList.of()); + } + + @Test(expected = IllegalArgumentException.class) + public void recordLongHistogramMismatchedOptionalLabelValues() { + when(mockSink.getMeasuresSize()).thenReturn(4); + + recorder.recordLongHistogram(longHistogramInstrument, 99, REQUIRED_LABEL_VALUES, + ImmutableList.of()); + } + + @Test + public void recordLongGaugeMismatchedOptionalLabelValues() { + when(mockSink.getMeasuresSize()).thenReturn(4); + when(mockSink.registerBatchCallback(any(Runnable.class), eq(longGaugeInstrument))) + .thenReturn(mock(MetricSink.Registration.class)); + + MetricRecorder.Registration registration = recorder.registerBatchCallback((recorder) -> { + assertThrows( + IllegalArgumentException.class, + () -> recorder.recordLongGauge( + longGaugeInstrument, 99, REQUIRED_LABEL_VALUES, ImmutableList.of())); + }, longGaugeInstrument); + + ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(Runnable.class); + verify(mockSink, times(2)) + .registerBatchCallback(callbackCaptor.capture(), eq(longGaugeInstrument)); + callbackCaptor.getValue().run(); + registration.close(); + } +} diff --git a/core/src/test/java/io/grpc/internal/NoopClientStreamTest.java b/core/src/test/java/io/grpc/internal/NoopClientStreamTest.java new file mode 100644 index 00000000000..d68642dad85 --- /dev/null +++ b/core/src/test/java/io/grpc/internal/NoopClientStreamTest.java @@ -0,0 +1,44 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +import java.io.InputStream; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Unit tests for {@link NoopClientStream}. + */ +@RunWith(JUnit4.class) +public class NoopClientStreamTest { + + @Test + public void writeMessageShouldCloseInputStream() throws Exception { + // NoopClientStream.writeMessage() is called when a stream is cancelled or failed + // before the real transport stream is established (e.g. via DelayedStream draining + // buffered messages to NoopClientStream on cancellation, or FailingClientStream + // which extends NoopClientStream). The InputStream must be closed to avoid leaking + // resources such as ref-counted ByteBufs. + InputStream message = mock(InputStream.class); + NoopClientStream.INSTANCE.writeMessage(message); + verify(message).close(); + } +} diff --git a/core/src/test/java/io/grpc/internal/PickFirstLeafLoadBalancerTest.java b/core/src/test/java/io/grpc/internal/PickFirstLeafLoadBalancerTest.java index 92222ac9af6..0467e57223d 100644 --- a/core/src/test/java/io/grpc/internal/PickFirstLeafLoadBalancerTest.java +++ b/core/src/test/java/io/grpc/internal/PickFirstLeafLoadBalancerTest.java @@ -23,13 +23,18 @@ import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; +import static io.grpc.InternalEquivalentAddressGroup.ATTR_WEIGHT; import static io.grpc.LoadBalancer.HAS_HEALTH_PRODUCER_LISTENER_KEY; import static io.grpc.LoadBalancer.HEALTH_CONSUMER_LISTENER_ARG_KEY; +import static io.grpc.LoadBalancer.IS_PETIOLE_POLICY; import static io.grpc.internal.PickFirstLeafLoadBalancer.CONNECTION_DELAY_INTERVAL_MS; +import static io.grpc.internal.PickFirstLeafLoadBalancer.isSerializingRetries; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assume.assumeTrue; import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -64,15 +69,18 @@ import io.grpc.Status.Code; import io.grpc.SynchronizationContext; import io.grpc.internal.PickFirstLeafLoadBalancer.PickFirstLeafLoadBalancerConfig; +import java.net.InetSocketAddress; import java.net.SocketAddress; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Queue; +import java.util.Random; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import org.junit.After; -import org.junit.Assume; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -91,14 +99,22 @@ public class PickFirstLeafLoadBalancerTest { public static final Status CONNECTION_ERROR = Status.UNAVAILABLE.withDescription("Simulated connection error"); - - @Parameterized.Parameters(name = "{0}") - public static List enableHappyEyeballs() { - return Arrays.asList(true, false); + public static final String GRPC_SERIALIZE_RETRIES = "GRPC_SERIALIZE_RETRIES"; + + @Parameterized.Parameters(name = "{0}-{1}") + public static List data() { + return Arrays.asList(new Object[][] { + {false, false}, + {false, true}, + {true, false}}); } - @Parameterized.Parameter + @Parameterized.Parameter(value = 0) + public boolean serializeRetries; + + @Parameterized.Parameter(value = 1) public boolean enableHappyEyeballs; + private PickFirstLeafLoadBalancer loadBalancer; private final List servers = Lists.newArrayList(); private static final Attributes.Key FOO = Attributes.Key.create("foo"); @@ -123,61 +139,82 @@ public void uncaughtException(Thread t, Throwable e) { private ArgumentCaptor createArgsCaptor; @Captor private ArgumentCaptor stateListenerCaptor; - private final Helper mockHelper = mock(Helper.class, delegatesTo(new MockHelperImpl())); - @Mock + private Helper mockHelper; private FakeSubchannel mockSubchannel1; - @Mock + private FakeSubchannel mockSubchannel1n2; private FakeSubchannel mockSubchannel2; - @Mock + private FakeSubchannel mockSubchannel2n2; private FakeSubchannel mockSubchannel3; - @Mock + private FakeSubchannel mockSubchannel3n2; private FakeSubchannel mockSubchannel4; - @Mock private FakeSubchannel mockSubchannel5; @Mock // This LoadBalancer doesn't use any of the arg fields, as verified in tearDown(). private PickSubchannelArgs mockArgs; private String originalHappyEyeballsEnabledValue; + private String originalSerializeRetriesValue; + private boolean originalWeightedShuffling; + + private long backoffMillis; @Before public void setUp() { + assumeTrue(!serializeRetries || !enableHappyEyeballs); // they are not compatible + + backoffMillis = TimeUnit.SECONDS.toMillis(1); + originalSerializeRetriesValue = System.getProperty(GRPC_SERIALIZE_RETRIES); + System.setProperty(GRPC_SERIALIZE_RETRIES, Boolean.toString(serializeRetries)); + originalHappyEyeballsEnabledValue = - System.getProperty(PickFirstLeafLoadBalancer.GRPC_EXPERIMENTAL_XDS_DUALSTACK_ENDPOINTS); - System.setProperty(PickFirstLeafLoadBalancer.GRPC_EXPERIMENTAL_XDS_DUALSTACK_ENDPOINTS, - enableHappyEyeballs ? "true" : "false"); + System.getProperty(PickFirstLoadBalancerProvider.GRPC_PF_USE_HAPPY_EYEBALLS); + System.setProperty(PickFirstLoadBalancerProvider.GRPC_PF_USE_HAPPY_EYEBALLS, + Boolean.toString(enableHappyEyeballs)); + + originalWeightedShuffling = PickFirstLeafLoadBalancer.weightedShuffling; for (int i = 1; i <= 5; i++) { SocketAddress addr = new FakeSocketAddress("server" + i); servers.add(new EquivalentAddressGroup(addr)); } - mockSubchannel1 = mock(FakeSubchannel.class); - mockSubchannel2 = mock(FakeSubchannel.class); - mockSubchannel3 = mock(FakeSubchannel.class); - mockSubchannel4 = mock(FakeSubchannel.class); - mockSubchannel5 = mock(FakeSubchannel.class); - when(mockSubchannel1.getAttributes()).thenReturn(Attributes.EMPTY); - when(mockSubchannel2.getAttributes()).thenReturn(Attributes.EMPTY); - when(mockSubchannel3.getAttributes()).thenReturn(Attributes.EMPTY); - when(mockSubchannel4.getAttributes()).thenReturn(Attributes.EMPTY); - when(mockSubchannel5.getAttributes()).thenReturn(Attributes.EMPTY); - - when(mockSubchannel1.getAllAddresses()).thenReturn(Lists.newArrayList(servers.get(0))); - when(mockSubchannel2.getAllAddresses()).thenReturn(Lists.newArrayList(servers.get(1))); - when(mockSubchannel3.getAllAddresses()).thenReturn(Lists.newArrayList(servers.get(2))); - when(mockSubchannel4.getAllAddresses()).thenReturn(Lists.newArrayList(servers.get(3))); - when(mockSubchannel5.getAllAddresses()).thenReturn(Lists.newArrayList(servers.get(4))); - + mockSubchannel1 = mock(FakeSubchannel.class, delegatesTo( + new FakeSubchannel(Arrays.asList(servers.get(0)), Attributes.EMPTY))); + mockSubchannel1n2 = mock(FakeSubchannel.class, delegatesTo( + new FakeSubchannel(Arrays.asList(servers.get(0)), Attributes.EMPTY))); + mockSubchannel2 = mock(FakeSubchannel.class, delegatesTo( + new FakeSubchannel(Arrays.asList(servers.get(1)), Attributes.EMPTY))); + mockSubchannel2n2 = mock(FakeSubchannel.class, delegatesTo( + new FakeSubchannel(Arrays.asList(servers.get(1)), Attributes.EMPTY))); + mockSubchannel3 = mock(FakeSubchannel.class, delegatesTo( + new FakeSubchannel(Arrays.asList(servers.get(2)), Attributes.EMPTY))); + mockSubchannel3n2 = mock(FakeSubchannel.class, delegatesTo( + new FakeSubchannel(Arrays.asList(servers.get(2)), Attributes.EMPTY))); + mockSubchannel4 = mock(FakeSubchannel.class, delegatesTo( + new FakeSubchannel(Arrays.asList(servers.get(3)), Attributes.EMPTY))); + mockSubchannel5 = mock(FakeSubchannel.class, delegatesTo( + new FakeSubchannel(Arrays.asList(servers.get(4)), Attributes.EMPTY))); + + mockHelper = mock(Helper.class, delegatesTo(new MockHelperImpl(Arrays.asList( + mockSubchannel1, mockSubchannel1n2, + mockSubchannel2, mockSubchannel2n2, + mockSubchannel3, mockSubchannel3n2, + mockSubchannel4, mockSubchannel5)))); loadBalancer = new PickFirstLeafLoadBalancer(mockHelper); } @After public void tearDown() { + if (originalSerializeRetriesValue == null) { + System.clearProperty(GRPC_SERIALIZE_RETRIES); + } else { + System.setProperty(GRPC_SERIALIZE_RETRIES, originalSerializeRetriesValue); + } if (originalHappyEyeballsEnabledValue == null) { - System.clearProperty(PickFirstLeafLoadBalancer.GRPC_EXPERIMENTAL_XDS_DUALSTACK_ENDPOINTS); + System.clearProperty(PickFirstLoadBalancerProvider.GRPC_PF_USE_HAPPY_EYEBALLS); } else { - System.setProperty(PickFirstLeafLoadBalancer.GRPC_EXPERIMENTAL_XDS_DUALSTACK_ENDPOINTS, + System.setProperty(PickFirstLoadBalancerProvider.GRPC_PF_USE_HAPPY_EYEBALLS, originalHappyEyeballsEnabledValue); } + PickFirstLeafLoadBalancer.weightedShuffling = originalWeightedShuffling; loadBalancer.shutdown(); verifyNoMoreInteractions(mockArgs); @@ -213,6 +250,12 @@ public void pickAfterResolved() { verifyNoMoreInteractions(mockHelper); } + @Test + public void pickAfterResolved_shuffle_oppositeWeightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = !PickFirstLeafLoadBalancer.weightedShuffling; + pickAfterResolved_shuffle(); + } + @Test public void pickAfterResolved_shuffle() { servers.remove(4); @@ -251,14 +294,14 @@ public void pickAfterResolved_shuffle() { PickResult pick2 = pickerCaptor.getValue().pickSubchannel(mockArgs); assertEquals(pick1, pick2); verifyNoMoreInteractions(mockHelper); - assertThat(pick1.toString()).contains("subchannel=null"); + assertThat(pick1.getSubchannel()).isNull(); stateListener2.onSubchannelState(ConnectivityStateInfo.forNonError(READY)); verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture()); PickResult pick3 = pickerCaptor.getValue().pickSubchannel(mockArgs); PickResult pick4 = pickerCaptor.getValue().pickSubchannel(mockArgs); assertEquals(pick3, pick4); - assertThat(pick3.toString()).contains("subchannel=Mock"); + assertThat(pick3.getSubchannel()).isEqualTo(mockSubchannel2); } @Test @@ -276,6 +319,103 @@ public void pickAfterResolved_noShuffle() { assertNotNull(pickerCaptor.getValue().pickSubchannel(mockArgs)); } + @Test + public void pickAfterResolved_shuffleImplicitUniform_oppositeWeightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = !PickFirstLeafLoadBalancer.weightedShuffling; + pickAfterResolved_shuffleImplicitUniform(); + } + + @Test + public void pickAfterResolved_shuffleImplicitUniform() { + EquivalentAddressGroup eag1 = new EquivalentAddressGroup(new FakeSocketAddress("server1")); + EquivalentAddressGroup eag2 = new EquivalentAddressGroup(new FakeSocketAddress("server2")); + EquivalentAddressGroup eag3 = new EquivalentAddressGroup(new FakeSocketAddress("server3")); + + int[] counts = countAddressSelections(99, Arrays.asList(eag1, eag2, eag3)); + assertThat(counts[0]).isWithin(7).of(33); + assertThat(counts[1]).isWithin(7).of(33); + assertThat(counts[2]).isWithin(7).of(33); + } + + @Test + public void pickAfterResolved_shuffleExplicitUniform_oppositeWeightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = !PickFirstLeafLoadBalancer.weightedShuffling; + pickAfterResolved_shuffleExplicitUniform(); + } + + @Test + public void pickAfterResolved_shuffleExplicitUniform() { + EquivalentAddressGroup eag1 = new EquivalentAddressGroup( + new FakeSocketAddress("server1"), Attributes.newBuilder().set(ATTR_WEIGHT, 111L).build()); + EquivalentAddressGroup eag2 = new EquivalentAddressGroup( + new FakeSocketAddress("server2"), Attributes.newBuilder().set(ATTR_WEIGHT, 111L).build()); + EquivalentAddressGroup eag3 = new EquivalentAddressGroup( + new FakeSocketAddress("server3"), Attributes.newBuilder().set(ATTR_WEIGHT, 111L).build()); + + int[] counts = countAddressSelections(99, Arrays.asList(eag1, eag2, eag3)); + assertThat(counts[0]).isWithin(7).of(33); + assertThat(counts[1]).isWithin(7).of(33); + assertThat(counts[2]).isWithin(7).of(33); + } + + @Test + public void pickAfterResolved_shuffleWeighted_noWeightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = false; + EquivalentAddressGroup eag1 = new EquivalentAddressGroup( + new FakeSocketAddress("server1"), Attributes.newBuilder().set(ATTR_WEIGHT, 12L).build()); + EquivalentAddressGroup eag2 = new EquivalentAddressGroup( + new FakeSocketAddress("server2"), Attributes.newBuilder().set(ATTR_WEIGHT, 3L).build()); + EquivalentAddressGroup eag3 = new EquivalentAddressGroup( + new FakeSocketAddress("server3"), Attributes.newBuilder().set(ATTR_WEIGHT, 1L).build()); + + int[] counts = countAddressSelections(100, Arrays.asList(eag1, eag2, eag3)); + assertThat(counts[0]).isWithin(7).of(33); + assertThat(counts[1]).isWithin(7).of(33); + assertThat(counts[2]).isWithin(7).of(33); + } + + @Test + public void pickAfterResolved_shuffleWeighted_weightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = true; + EquivalentAddressGroup eag1 = new EquivalentAddressGroup( + new FakeSocketAddress("server1"), Attributes.newBuilder().set(ATTR_WEIGHT, 12L).build()); + EquivalentAddressGroup eag2 = new EquivalentAddressGroup( + new FakeSocketAddress("server2"), Attributes.newBuilder().set(ATTR_WEIGHT, 3L).build()); + EquivalentAddressGroup eag3 = new EquivalentAddressGroup( + new FakeSocketAddress("server3"), Attributes.newBuilder().set(ATTR_WEIGHT, 1L).build()); + + int[] counts = countAddressSelections(100, Arrays.asList(eag1, eag2, eag3)); + assertThat(counts[0]).isWithin(7).of(75); // 100*12/16 + assertThat(counts[1]).isWithin(7).of(19); // 100*3/16 + assertThat(counts[2]).isWithin(7).of(6); // 100*1/16 + } + + /** Returns int[index_of_eag] array with number of times each eag was selected. */ + private int[] countAddressSelections(int trials, List eags) { + int[] counts = new int[eags.size()]; + Random random = new Random(1); + for (int i = 0; i < trials; i++) { + RecordingHelper helper = new RecordingHelper(); + LoadBalancer lb = new PickFirstLeafLoadBalancer(helper); + assertThat(lb.acceptResolvedAddresses(ResolvedAddresses.newBuilder() + .setAddresses(eags) + .setAttributes(affinity) + .setLoadBalancingPolicyConfig( + new PickFirstLeafLoadBalancerConfig(true, random.nextLong())) + .build())) + .isSameInstanceAs(Status.OK); + helper.subchannels.remove().listener.onSubchannelState( + ConnectivityStateInfo.forNonError(READY)); + + assertThat(helper.state).isEqualTo(READY); + Subchannel subchannel = helper.picker.pickSubchannel(mockArgs).getSubchannel(); + counts[eags.indexOf(subchannel.getAddresses())]++; + + lb.shutdown(); + } + return counts; + } + @Test public void requestConnectionPicker() { // Set up @@ -358,11 +498,7 @@ public void pickAfterResolvedAndUnchanged() { // Second acceptResolvedAddresses shouldn't do anything loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build()); - if (enableHappyEyeballs) { - inOrder.verify(mockSubchannel1, never()).requestConnection(); - } else { - inOrder.verify(mockSubchannel1, times(1)).requestConnection(); - } + inOrder.verify(mockSubchannel1, never()).requestConnection(); inOrder.verify(mockHelper, never()).updateBalancingState(any(), any()); } @@ -390,15 +526,44 @@ public void pickAfterResolvedAndChanged() { verify(mockSubchannel2).requestConnection(); } + @Test + public void healthCheck_nonPetiolePolicy() { + when(mockSubchannel1.getAttributes()).thenReturn( + Attributes.newBuilder().set(HAS_HEALTH_PRODUCER_LISTENER_KEY, true).build()); + + // Initialize with one server loadbalancer and both health and state listeners + List oneServer = Lists.newArrayList(servers.get(0)); + loadBalancer.acceptResolvedAddresses(ResolvedAddresses.newBuilder().setAddresses(oneServer) + .setAttributes(Attributes.EMPTY).build()); + InOrder inOrder = inOrder(mockHelper, mockSubchannel1); + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); + inOrder.verify(mockHelper).createSubchannel(createArgsCaptor.capture()); + SubchannelStateListener healthListener = createArgsCaptor.getValue() + .getOption(HEALTH_CONSUMER_LISTENER_ARG_KEY); + inOrder.verify(mockSubchannel1).start(stateListenerCaptor.capture()); + SubchannelStateListener stateListener = stateListenerCaptor.getValue(); + + stateListener.onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING)); + healthListener.onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING)); + inOrder.verify(mockHelper, never()).updateBalancingState(any(), any()); + + stateListener.onSubchannelState(ConnectivityStateInfo.forNonError(READY)); + inOrder.verify(mockHelper).updateBalancingState(eq(READY), any()); // health listener ignored + + healthListener.onSubchannelState(ConnectivityStateInfo.forTransientFailure(Status.INTERNAL)); + inOrder.verify(mockHelper, never()).updateBalancingState(any(), any(SubchannelPicker.class)); + } + @Test public void healthCheckFlow() { when(mockSubchannel1.getAttributes()).thenReturn( Attributes.newBuilder().set(HAS_HEALTH_PRODUCER_LISTENER_KEY, true).build()); when(mockSubchannel2.getAttributes()).thenReturn( Attributes.newBuilder().set(HAS_HEALTH_PRODUCER_LISTENER_KEY, true).build()); + List oneServer = Lists.newArrayList(servers.get(0), servers.get(1)); loadBalancer.acceptResolvedAddresses(ResolvedAddresses.newBuilder().setAddresses(oneServer) - .setAttributes(Attributes.EMPTY).build()); + .setAttributes(Attributes.newBuilder().set(IS_PETIOLE_POLICY, true).build()).build()); InOrder inOrder = inOrder(mockHelper, mockSubchannel1, mockSubchannel2); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); @@ -414,13 +579,13 @@ public void healthCheckFlow() { // subchannel2 | IDLE | IDLE stateListener.onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING)); healthListener.onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING)); - inOrder.verify(mockHelper, times(0)).updateBalancingState(any(), any()); + inOrder.verify(mockHelper, never()).updateBalancingState(any(), any()); // subchannel | state | health // subchannel1 | READY | CONNECTING // subchannel2 | IDLE | IDLE stateListener.onSubchannelState(ConnectivityStateInfo.forNonError(READY)); - inOrder.verify(mockHelper, times(0)).updateBalancingState(any(), any()); + inOrder.verify(mockHelper, never()).updateBalancingState(any(), any()); // subchannel | state | health // subchannel1 | READY | READY @@ -469,11 +634,147 @@ public void healthCheckFlow() { inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture()); assertThat(pickerCaptor.getValue().pickSubchannel(mockArgs) .getSubchannel()).isSameInstanceAs(mockSubchannel1); + verify(mockHelper, atLeast(0)).getSynchronizationContext(); + verify(mockHelper, atLeast(0)).getScheduledExecutorService(); + verifyNoMoreInteractions(mockHelper); healthListener2.onSubchannelState(ConnectivityStateInfo.forNonError(READY)); verifyNoMoreInteractions(mockHelper); } + // reproduces #12796 + @Test + public void healthCheckWithTF_AllowsStateInconsistency() { + assumeTrue(!serializeRetries); + + when(mockSubchannel1.getAttributes()).thenReturn( + Attributes.newBuilder().set(HAS_HEALTH_PRODUCER_LISTENER_KEY, true).build()); + + Attributes petioleAttributes = + Attributes.newBuilder().set(IS_PETIOLE_POLICY, true).build(); + + loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses( + Lists.newArrayList( + /* server 1 */servers.get(0), + /* server 3 */servers.get(2) + )) + .setAttributes(petioleAttributes) + .build()); + + // Get the state and health listener for subchannel 1 + verify(mockHelper).createSubchannel(createArgsCaptor.capture()); + SubchannelStateListener healthListener1 = + createArgsCaptor.getValue().getOption(HEALTH_CONSUMER_LISTENER_ARG_KEY); + verify(mockSubchannel1).start(stateListenerCaptor.capture()); + SubchannelStateListener stateListener1 = stateListenerCaptor.getValue(); + + // As start() was called, we transition subchannel 1 to CONNECTING... + stateListener1.onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING)); + healthListener1.onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING)); + + // ...which eventually ends up READY. + stateListener1.onSubchannelState(ConnectivityStateInfo.forNonError(READY)); + healthListener1.onSubchannelState(ConnectivityStateInfo.forNonError(READY)); + + // Let the fun begin: subchannel 1's health turns into TRANSIENT_FAILURE + healthListener1.onSubchannelState( + ConnectivityStateInfo.forTransientFailure( + Status.UNAVAILABLE.withDescription("health failure"))); + // HealthListener.onSubchannelState gets called. It updates the LBs balancing + // state/concludedState. + assertEquals(TRANSIENT_FAILURE, loadBalancer.getConcludedConnectivityState()); + assertEquals(READY, loadBalancer.getRawConnectivityState()); + + // Subchannel 1's transport goes idle + stateListener1.onSubchannelState(ConnectivityStateInfo.forNonError(IDLE)); + + // LB's raw connectivity stays ready as the TRANSIENT_FAILURE health state + assertEquals(TRANSIENT_FAILURE, loadBalancer.getConcludedConnectivityState()); + assertEquals(READY, loadBalancer.getRawConnectivityState()); + assertEquals(0, loadBalancer.getIndexLocation()); + + // LB tries to reconnect subchannel 1. + verify(mockSubchannel1, times(2)).requestConnection(); + + stateListener1.onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING)); + + // LB is waiting for subchannel 1 to report status. + assertEquals(TRANSIENT_FAILURE, loadBalancer.getConcludedConnectivityState()); + assertEquals(READY, loadBalancer.getRawConnectivityState()); + assertEquals(0, loadBalancer.getIndexLocation()); + + // Subchannel 1's new connection attempt fails and reports TRANSIENT_FAILURE. + stateListener1.onSubchannelState(ConnectivityStateInfo.forTransientFailure(CONNECTION_ERROR)); + + // LB increments the index and tries to connect to server 3. + assertEquals(TRANSIENT_FAILURE, loadBalancer.getConcludedConnectivityState()); + assertEquals(READY, loadBalancer.getRawConnectivityState()); + assertEquals(1, loadBalancer.getIndexLocation()); + verify(mockSubchannel3).start(stateListenerCaptor.capture()); + SubchannelStateListener stateListener3 = stateListenerCaptor.getValue(); + verify(mockSubchannel3).requestConnection(); + stateListener3.onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING)); + + // Subchannel 3 connection did not change the state as we are + // still in TRANSIENT_FAILURE health state. + assertEquals(TRANSIENT_FAILURE, loadBalancer.getConcludedConnectivityState()); + assertEquals(READY, loadBalancer.getRawConnectivityState()); + assertEquals(1, loadBalancer.getIndexLocation()); + + List newServers = + Lists.newArrayList( + /* server 2 */ + servers.get(1), + /* server 1 */ + servers.get(0) + ); + + // The resolver update removes the (current) subchannel 3, keeps server 1, and + // resets addressIndex to server2, which has no subchannel. + loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(newServers) + .setAttributes(petioleAttributes) + .build()); + + verify(mockSubchannel3, times(1)).shutdown(); + + // LB thinks that there are no subchannels that are trying to connect. + assertEquals(IDLE, loadBalancer.getRawConnectivityState()); + assertEquals(IDLE, loadBalancer.getConcludedConnectivityState()); + // As mentioned, the LB resets the index to 0 by calling addressIndex.updateGroups. + // Given the new list, it is now pointing to server 2 which does not have a subchannel. + assertEquals(0, loadBalancer.getIndexLocation()); + + // Subchannel 1 is still in TRANSIENT_FAILURE state. Is backoff expires, + // and now it is retrying to connect. This state listener transitions the LB to CONNECTING. + stateListener1.onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING)); + + // As our health state is IDLE now the LB handles the CONNECTING subchannel state change + // by transitioning into CONNECTING itself. + assertEquals(CONNECTING, loadBalancer.getRawConnectivityState()); + assertEquals(CONNECTING, loadBalancer.getConcludedConnectivityState()); + + // Before the fix: + // The index is now pointing to server 2 for which the LB did not create a subchannel yet. + // assertEquals(0, loadBalancer.getIndexLocation()); + + // The index is now pointing to server 1 + assertEquals(1, loadBalancer.getIndexLocation()); + + // The resolver refreshes and provides the same addresses. + // As the LB is in CONNECTING, acceptResolvedAddresses tries + // to get the subchannel represented from the current index (server 2) and + // update its addresses. As the subchannel still does not exist an NPE is thrown. + assertEquals(Status.OK, loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(newServers) + .setAttributes(petioleAttributes) + .build())); + } + @Test public void pickAfterStateChangeAfterResolution() { InOrder inOrder = @@ -491,20 +792,7 @@ public void pickAfterStateChangeAfterResolution() { inOrder.verify(mockSubchannel1).start(stateListenerCaptor.capture()); stateListeners[0] = stateListenerCaptor.getValue(); - if (enableHappyEyeballs) { - forwardTimeByConnectionDelay(); - inOrder.verify(mockSubchannel2).start(stateListenerCaptor.capture()); - stateListeners[1] = stateListenerCaptor.getValue(); - forwardTimeByConnectionDelay(); - inOrder.verify(mockSubchannel3).start(stateListenerCaptor.capture()); - stateListeners[2] = stateListenerCaptor.getValue(); - forwardTimeByConnectionDelay(); - inOrder.verify(mockSubchannel4).start(stateListenerCaptor.capture()); - stateListeners[3] = stateListenerCaptor.getValue(); - } - - reset(mockHelper); - + stateListeners[0].onSubchannelState(ConnectivityStateInfo.forNonError(READY)); stateListeners[0].onSubchannelState(ConnectivityStateInfo.forNonError(IDLE)); inOrder.verify(mockHelper).refreshNameResolution(); inOrder.verify(mockHelper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); @@ -514,11 +802,23 @@ public void pickAfterStateChangeAfterResolution() { stateListeners[0].onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING)); Status error = Status.UNAVAILABLE.withDescription("boom!"); + reset(mockHelper); if (enableHappyEyeballs) { - for (SubchannelStateListener listener : stateListeners) { - listener.onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); - } + stateListeners[0].onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); + forwardTimeByConnectionDelay(); + inOrder.verify(mockSubchannel2).start(stateListenerCaptor.capture()); + stateListeners[1] = stateListenerCaptor.getValue(); + stateListeners[1].onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); + forwardTimeByConnectionDelay(); + inOrder.verify(mockSubchannel3).start(stateListenerCaptor.capture()); + stateListeners[2] = stateListenerCaptor.getValue(); + stateListeners[2].onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); + forwardTimeByConnectionDelay(); + inOrder.verify(mockSubchannel4).start(stateListenerCaptor.capture()); + stateListeners[3] = stateListenerCaptor.getValue(); + stateListeners[3].onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); + forwardTimeByConnectionDelay(); } else { stateListeners[0].onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); for (int i = 1; i < stateListeners.length; i++) { @@ -560,8 +860,81 @@ public void pickAfterResolutionAfterTransientValue() { // Transition from TRANSIENT_ERROR to CONNECTING should also be ignored. stateListener.onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING)); + verify(mockHelper, atLeast(0)).getSynchronizationContext(); + verify(mockHelper, atLeast(0)).getScheduledExecutorService(); + verifyNoMoreInteractions(mockHelper); + assertEquals(error, pickerCaptor.getValue().pickSubchannel(mockArgs).getStatus()); + } + + @Test + public void pickWithDupAddressesUpDownUp() { + InOrder inOrder = inOrder(mockHelper); + SocketAddress socketAddress = servers.get(0).getAddresses().get(0); + EquivalentAddressGroup badEag = new EquivalentAddressGroup( + Lists.newArrayList(socketAddress, socketAddress)); + List newServers = Lists.newArrayList(badEag); + + loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(newServers).setAttributes(affinity).build()); + verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + verify(mockHelper).createSubchannel(createArgsCaptor.capture()); + verify(mockSubchannel1).start(stateListenerCaptor.capture()); + SubchannelStateListener stateListener = stateListenerCaptor.getValue(); + + reset(mockHelper); + + // An error has happened. + Status error = Status.UNAVAILABLE.withDescription("boom!"); + stateListener.onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); + inOrder.verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); + inOrder.verify(mockHelper).refreshNameResolution(); + assertEquals(error, pickerCaptor.getValue().pickSubchannel(mockArgs).getStatus()); + + // Transition from TRANSIENT_ERROR to CONNECTING should also be ignored. + stateListener.onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING)); + verify(mockHelper, atLeast(0)).getSynchronizationContext(); + verify(mockHelper, atLeast(0)).getScheduledExecutorService(); + verifyNoMoreInteractions(mockHelper); + assertEquals(error, pickerCaptor.getValue().pickSubchannel(mockArgs).getStatus()); + + // Transition from CONNECTING to READY . + stateListener.onSubchannelState(ConnectivityStateInfo.forNonError(READY)); + inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture()); + assertEquals(Status.OK, pickerCaptor.getValue().pickSubchannel(mockArgs).getStatus()); + } + + @Test + public void pickWithDupEagsUpDownUp() { + InOrder inOrder = inOrder(mockHelper); + List newServers = Lists.newArrayList(servers.get(0), servers.get(0)); + + loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(newServers).setAttributes(affinity).build()); + verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + verify(mockHelper).createSubchannel(createArgsCaptor.capture()); + verify(mockSubchannel1).start(stateListenerCaptor.capture()); + SubchannelStateListener stateListener = stateListenerCaptor.getValue(); + + reset(mockHelper); + + // An error has happened. + Status error = Status.UNAVAILABLE.withDescription("boom!"); + stateListener.onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); + inOrder.verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); + inOrder.verify(mockHelper).refreshNameResolution(); + assertEquals(error, pickerCaptor.getValue().pickSubchannel(mockArgs).getStatus()); + + // Transition from TRANSIENT_ERROR to CONNECTING should also be ignored. + stateListener.onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING)); + verify(mockHelper, atLeast(0)).getSynchronizationContext(); + verify(mockHelper, atLeast(0)).getScheduledExecutorService(); verifyNoMoreInteractions(mockHelper); assertEquals(error, pickerCaptor.getValue().pickSubchannel(mockArgs).getStatus()); + + // Transition from CONNECTING to READY . + stateListener.onSubchannelState(ConnectivityStateInfo.forNonError(READY)); + inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture()); + assertEquals(Status.OK, pickerCaptor.getValue().pickSubchannel(mockArgs).getStatus()); } @Test @@ -592,7 +965,7 @@ public void nameResolutionError_emptyAddressList() { } @Test - public void nameResolutionAfterSufficientTFs() { + public void nameResolutionAfterSufficientTFs_multipleEags() { InOrder inOrder = inOrder(mockHelper); acceptXSubchannels(3); Status error = Status.UNAVAILABLE.withDescription("boom!"); @@ -637,6 +1010,57 @@ public void nameResolutionAfterSufficientTFs() { inOrder.verify(mockHelper).refreshNameResolution(); } + @Test + public void nameResolutionAfterSufficientTFs_singleEag() { + InOrder inOrder = inOrder(mockHelper); + EquivalentAddressGroup eag = new EquivalentAddressGroup(Arrays.asList( + new FakeSocketAddress("server1"), + new FakeSocketAddress("server2"), + new FakeSocketAddress("server3"))); + loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(Arrays.asList(eag)).build()); + Status error = Status.UNAVAILABLE.withDescription("boom!"); + + // Initial subchannel gets TF, LB is still in CONNECTING + verify(mockSubchannel1).start(stateListenerCaptor.capture()); + SubchannelStateListener stateListener1 = stateListenerCaptor.getValue(); + stateListener1.onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + assertEquals(Status.OK, pickerCaptor.getValue().pickSubchannel(mockArgs).getStatus()); + + // Second subchannel gets TF, no UpdateBalancingState called + verify(mockSubchannel2).start(stateListenerCaptor.capture()); + SubchannelStateListener stateListener2 = stateListenerCaptor.getValue(); + stateListener2.onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); + inOrder.verify(mockHelper, never()).refreshNameResolution(); + inOrder.verify(mockHelper, never()).updateBalancingState(any(), any()); + + // Third subchannel gets TF, LB goes into TRANSIENT_FAILURE and does a refreshNameResolution + verify(mockSubchannel3).start(stateListenerCaptor.capture()); + SubchannelStateListener stateListener3 = stateListenerCaptor.getValue(); + stateListener3.onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); + inOrder.verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); + inOrder.verify(mockHelper).refreshNameResolution(); + assertEquals(error, pickerCaptor.getValue().pickSubchannel(mockArgs).getStatus()); + + // Only after we have TFs reported for # of subchannels do we call refreshNameResolution + stateListener2.onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); + inOrder.verify(mockHelper, never()).refreshNameResolution(); + stateListener2.onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); + inOrder.verify(mockHelper, never()).refreshNameResolution(); + stateListener2.onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); + inOrder.verify(mockHelper).refreshNameResolution(); + + // Now that we have refreshed, the count should have been reset + // Only after we have TFs reported for # of subchannels do we call refreshNameResolution + stateListener1.onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); + inOrder.verify(mockHelper, never()).refreshNameResolution(); + stateListener2.onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); + inOrder.verify(mockHelper, never()).refreshNameResolution(); + stateListener3.onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); + inOrder.verify(mockHelper).refreshNameResolution(); + } + @Test public void nameResolutionSuccessAfterError() { loadBalancer.handleNameResolutionError(Status.NOT_FOUND.withDescription("nameResolutionError")); @@ -657,6 +1081,38 @@ public void nameResolutionSuccessAfterError() { assertEquals(mockSubchannel1, pickerCaptor.getValue().pickSubchannel(mockArgs).getSubchannel()); } + @Test + public void nameResolutionTemporaryError() { + List newServers = Lists.newArrayList(servers.get(0)); + InOrder inOrder = inOrder(mockHelper, mockSubchannel1, mockSubchannel1n2); + loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(newServers).setAttributes(affinity).build()); + inOrder.verify(mockSubchannel1).start(stateListenerCaptor.capture()); + SubchannelStateListener stateListener1 = stateListenerCaptor.getValue(); + stateListener1.onSubchannelState(ConnectivityStateInfo.forNonError(READY)); + inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture()); + assertEquals(mockSubchannel1, pickerCaptor.getValue().pickSubchannel(mockArgs).getSubchannel()); + + loadBalancer.handleNameResolutionError( + Status.UNAVAILABLE.withDescription("nameResolutionError")); + inOrder.verify(mockHelper).updateBalancingState( + eq(TRANSIENT_FAILURE), any(SubchannelPicker.class)); + + loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build()); + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + inOrder.verify(mockSubchannel1n2).start(stateListenerCaptor.capture()); + SubchannelStateListener stateListener2 = stateListenerCaptor.getValue(); + + assertNull(pickerCaptor.getValue().pickSubchannel(mockArgs).getSubchannel()); + + stateListener2.onSubchannelState(ConnectivityStateInfo.forNonError(READY)); + inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture()); + assertEquals(mockSubchannel1n2, + pickerCaptor.getValue().pickSubchannel(mockArgs).getSubchannel()); + } + + @Test public void nameResolutionErrorWithStateChanges() { List newServers = Lists.newArrayList(servers.get(0)); @@ -709,8 +1165,7 @@ public void requestConnection() { loadBalancer.requestConnection(); inOrder.verify(mockSubchannel2).start(stateListenerCaptor.capture()); SubchannelStateListener stateListener2 = stateListenerCaptor.getValue(); - int expectedRequests = enableHappyEyeballs ? 1 : 2; - inOrder.verify(mockSubchannel2, times(expectedRequests)).requestConnection(); + inOrder.verify(mockSubchannel2).requestConnection(); stateListener2.onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING)); @@ -718,11 +1173,34 @@ public void requestConnection() { loadBalancer.requestConnection(); inOrder.verify(mockHelper, never()).updateBalancingState(any(), any()); inOrder.verify(mockSubchannel1, never()).requestConnection(); - if (enableHappyEyeballs) { - inOrder.verify(mockSubchannel2, never()).requestConnection(); - } else { - inOrder.verify(mockSubchannel2).requestConnection(); - } + inOrder.verify(mockSubchannel2, never()).requestConnection(); + } + + @Test + public void failChannelWhenSubchannelsFail() { + List newServers = Lists.newArrayList(servers.get(0), servers.get(1)); + when(mockSubchannel1.getAllAddresses()).thenReturn(Lists.newArrayList(servers.get(0))); + when(mockSubchannel2.getAllAddresses()).thenReturn(Lists.newArrayList(servers.get(1))); + + // accept resolved addresses + loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(newServers).setAttributes(affinity).build()); + InOrder inOrder = inOrder(mockHelper, mockSubchannel1, mockSubchannel2); + verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + verify(mockHelper).createSubchannel(createArgsCaptor.capture()); + inOrder.verify(mockSubchannel1).start(stateListenerCaptor.capture()); + SubchannelStateListener stateListener = stateListenerCaptor.getValue(); + assertNull(pickerCaptor.getValue().pickSubchannel(mockArgs).getSubchannel()); + + inOrder.verify(mockSubchannel1).requestConnection(); + stateListener.onSubchannelState(ConnectivityStateInfo.forTransientFailure(CONNECTION_ERROR)); + + inOrder.verify(mockSubchannel2).start(stateListenerCaptor.capture()); + SubchannelStateListener stateListener2 = stateListenerCaptor.getValue(); + stateListener2.onSubchannelState(ConnectivityStateInfo.forTransientFailure(CONNECTION_ERROR)); + + inOrder.verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); + assertEquals(CONNECTION_ERROR, pickerCaptor.getValue().pickSubchannel(mockArgs).getStatus()); } @Test @@ -902,7 +1380,7 @@ public void updateAddresses_disjoint_connecting() { @Test public void updateAddresses_disjoint_ready_twice() { InOrder inOrder = inOrder(mockHelper, mockSubchannel1, mockSubchannel2, - mockSubchannel3, mockSubchannel4); + mockSubchannel3, mockSubchannel4, mockSubchannel1n2, mockSubchannel2n2); // Creating first set of endpoints/addresses List oldServers = Lists.newArrayList(servers.get(0), servers.get(1)); SubchannelStateListener stateListener2 = null; @@ -950,10 +1428,17 @@ public void updateAddresses_disjoint_ready_twice() { loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(newServers).setAttributes(affinity).build()); inOrder.verify(mockSubchannel1).shutdown(); - inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + inOrder.verify(mockHelper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); + inOrder.verify(mockSubchannel3, never()).start(stateListenerCaptor.capture()); + + // Trigger connection creation + picker = pickerCaptor.getValue(); + assertEquals(PickResult.withNoResult(), picker.pickSubchannel(mockArgs)); inOrder.verify(mockSubchannel3).start(stateListenerCaptor.capture()); SubchannelStateListener stateListener3 = stateListenerCaptor.getValue(); inOrder.verify(mockSubchannel3).requestConnection(); + stateListener3.onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING)); + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); if (enableHappyEyeballs) { forwardTimeByConnectionDelay(); @@ -1000,17 +1485,19 @@ public void updateAddresses_disjoint_ready_twice() { loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(newestServers).setAttributes(affinity).build()); inOrder.verify(mockSubchannel3).shutdown(); - inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); - inOrder.verify(mockSubchannel1).start(stateListenerCaptor.capture()); - stateListener = stateListenerCaptor.getValue(); - assertEquals(CONNECTING, loadBalancer.getConcludedConnectivityState()); + inOrder.verify(mockHelper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); + assertEquals(IDLE, loadBalancer.getConcludedConnectivityState()); picker = pickerCaptor.getValue(); // Calling pickSubchannel() twice gave the same result assertEquals(picker.pickSubchannel(mockArgs), picker.pickSubchannel(mockArgs)); // But the picker calls requestConnection() only once - inOrder.verify(mockSubchannel1).requestConnection(); + inOrder.verify(mockSubchannel1n2).start(stateListenerCaptor.capture()); + stateListener = stateListenerCaptor.getValue(); + inOrder.verify(mockSubchannel1n2).requestConnection(); + stateListener.onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING)); + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); assertEquals(PickResult.withNoResult(), pickerCaptor.getValue().pickSubchannel(mockArgs)); assertEquals(CONNECTING, loadBalancer.getConcludedConnectivityState()); @@ -1025,23 +1512,24 @@ public void updateAddresses_disjoint_ready_twice() { stateListener.onSubchannelState(ConnectivityStateInfo.forTransientFailure(CONNECTION_ERROR)); // Starting connection attempt to address 2 - if (!enableHappyEyeballs) { - inOrder.verify(mockSubchannel2).start(stateListenerCaptor.capture()); - stateListener2 = stateListenerCaptor.getValue(); - } - inOrder.verify(mockSubchannel2).requestConnection(); + FakeSubchannel mockSubchannel2Attempt = + enableHappyEyeballs ? mockSubchannel2n2 : mockSubchannel2; + inOrder.verify(mockSubchannel2Attempt).start(stateListenerCaptor.capture()); + stateListener2 = stateListenerCaptor.getValue(); + inOrder.verify(mockSubchannel2Attempt).requestConnection(); // Connection attempt to address 2 is successful stateListener2.onSubchannelState(ConnectivityStateInfo.forNonError(READY)); assertEquals(READY, loadBalancer.getConcludedConnectivityState()); - inOrder.verify(mockSubchannel1).shutdown(); + inOrder.verify(mockSubchannel1n2).shutdown(); // Successful connection shuts down other subchannel inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture()); picker = pickerCaptor.getValue(); // Verify that picker still returns correct subchannel - assertEquals(PickResult.withSubchannel(mockSubchannel2), picker.pickSubchannel(mockArgs)); + assertEquals( + PickResult.withSubchannel(mockSubchannel2Attempt), picker.pickSubchannel(mockArgs)); } @Test @@ -1091,6 +1579,11 @@ public void updateAddresses_disjoint_transient_failure() { loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(newServers).setAttributes(affinity).build()); + if (serializeRetries) { + inOrder.verify(mockSubchannel3, never()).start(stateListenerCaptor.capture()); + forwardTimeByBackoffDelay(); + } + // subchannel 3 still attempts a connection even though we stay in transient failure assertEquals(TRANSIENT_FAILURE, loadBalancer.getConcludedConnectivityState()); inOrder.verify(mockSubchannel3).start(stateListenerCaptor.capture()); @@ -1307,6 +1800,8 @@ public void updateAddresses_intersecting_ready() { @Test public void updateAddresses_intersecting_transient_failure() { + assumeTrue(!isSerializingRetries()); + // Starting first connection attempt InOrder inOrder = inOrder(mockHelper, mockSubchannel1, mockSubchannel2, mockSubchannel3, mockSubchannel4); // captor: captures @@ -1571,6 +2066,8 @@ public void updateAddresses_identical_ready() { @Test public void updateAddresses_identical_transient_failure() { + assumeTrue(!isSerializingRetries()); + InOrder inOrder = inOrder(mockHelper, mockSubchannel1, mockSubchannel2, mockSubchannel3, mockSubchannel4); // Creating first set of endpoints/addresses @@ -1624,6 +2121,45 @@ public void updateAddresses_identical_transient_failure() { assertEquals(PickResult.withSubchannel(mockSubchannel1), picker.pickSubchannel(mockArgs)); } + @Test + public void updateAddresses_identicalSingleAddress_connecting() { + // Creating first set of endpoints/addresses + List oldServers = Lists.newArrayList(servers.get(0)); + + // Accept Addresses and verify proper connection flow + assertEquals(IDLE, loadBalancer.getConcludedConnectivityState()); + loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(oldServers).setAttributes(affinity).build()); + verify(mockSubchannel1).start(stateListenerCaptor.capture()); + SubchannelStateListener stateListener = stateListenerCaptor.getValue(); + assertEquals(CONNECTING, loadBalancer.getConcludedConnectivityState()); + + // First connection attempt is successful + stateListener.onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING)); + assertEquals(CONNECTING, loadBalancer.getConcludedConnectivityState()); + fakeClock.forwardTime(CONNECTION_DELAY_INTERVAL_MS, TimeUnit.MILLISECONDS); + + // verify that picker returns no subchannel + verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + SubchannelPicker picker = pickerCaptor.getValue(); + assertEquals(PickResult.withNoResult(), picker.pickSubchannel(mockArgs)); + + // Accept same resolved addresses to update + reset(mockHelper); + loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(oldServers).setAttributes(affinity).build()); + fakeClock.forwardTime(CONNECTION_DELAY_INTERVAL_MS, TimeUnit.MILLISECONDS); + + // Verify that no new subchannels were created or started + verify(mockSubchannel2, never()).start(any()); + assertEquals(CONNECTING, loadBalancer.getConcludedConnectivityState()); + + // verify that picker hasn't changed via checking mock helper's interactions + verify(mockHelper, atLeast(0)).getSynchronizationContext(); // Don't care + verify(mockHelper, atLeast(0)).getScheduledExecutorService(); + verifyNoMoreInteractions(mockHelper); + } + @Test public void twoAddressesSeriallyConnect() { // Starting first connection attempt @@ -1885,18 +2421,20 @@ public void lastAddressFailingNotTransientFailure() { loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(newServers).setAttributes(affinity).build()); - // Verify that no new subchannels were created or started + // Subchannel 2 should be reused since it was trying to connect and is present. inOrder.verify(mockSubchannel1).shutdown(); - inOrder.verify(mockSubchannel3).start(stateListenerCaptor.capture()); - SubchannelStateListener stateListener3 = stateListenerCaptor.getValue(); - inOrder.verify(mockSubchannel3).requestConnection(); + inOrder.verify(mockSubchannel3, never()).start(stateListenerCaptor.capture()); assertEquals(CONNECTING, loadBalancer.getConcludedConnectivityState()); - // Second address connection attempt is unsuccessful, but should not go into transient failure + // Second address connection attempt is unsuccessful, so since at end, but don't have all + // subchannels, schedule a backoff for the first address stateListener2.onSubchannelState(ConnectivityStateInfo.forTransientFailure(CONNECTION_ERROR)); + fakeClock.forwardTime(1, TimeUnit.SECONDS); + inOrder.verify(mockSubchannel3).start(stateListenerCaptor.capture()); + SubchannelStateListener stateListener3 = stateListenerCaptor.getValue(); assertEquals(CONNECTING, loadBalancer.getConcludedConnectivityState()); - // Third address connection attempt is unsuccessful, now we enter transient failure + // Third address connection attempt is unsuccessful, now we enter TF, do name resolution stateListener3.onSubchannelState(ConnectivityStateInfo.forTransientFailure(CONNECTION_ERROR)); assertEquals(TRANSIENT_FAILURE, loadBalancer.getConcludedConnectivityState()); @@ -1923,7 +2461,7 @@ public void recreate_shutdown_subchannel() { // Starting first connection attempt InOrder inOrder = inOrder(mockHelper, mockSubchannel1, mockSubchannel2, - mockSubchannel3, mockSubchannel4); // captor: captures + mockSubchannel3, mockSubchannel4, mockSubchannel1n2); // captor: captures // Creating first set of endpoints/addresses List addrs = @@ -1959,9 +2497,9 @@ public void recreate_shutdown_subchannel() { // Calling pickSubchannel() requests a connection. assertEquals(picker.pickSubchannel(mockArgs), picker.pickSubchannel(mockArgs)); - inOrder.verify(mockSubchannel1).start(stateListenerCaptor.capture()); + inOrder.verify(mockSubchannel1n2).start(stateListenerCaptor.capture()); SubchannelStateListener stateListener3 = stateListenerCaptor.getValue(); - inOrder.verify(mockSubchannel1).requestConnection(); + inOrder.verify(mockSubchannel1n2).requestConnection(); when(mockSubchannel1.getAllAddresses()).thenReturn(Lists.newArrayList(servers.get(0))); // gives the same result when called twice @@ -1976,7 +2514,7 @@ public void recreate_shutdown_subchannel() { // second subchannel connection attempt succeeds inOrder.verify(mockSubchannel2).requestConnection(); stateListener2.onSubchannelState(ConnectivityStateInfo.forNonError(READY)); - inOrder.verify(mockSubchannel1).shutdown(); + inOrder.verify(mockSubchannel1n2).shutdown(); inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture()); assertEquals(READY, loadBalancer.getConcludedConnectivityState()); @@ -2021,7 +2559,7 @@ public void shutdown() { public void ready_then_transient_failure_again() { // Starting first connection attempt InOrder inOrder = inOrder(mockHelper, mockSubchannel1, mockSubchannel2, - mockSubchannel3, mockSubchannel4); // captor: captures + mockSubchannel3, mockSubchannel4, mockSubchannel1n2); // captor: captures // Creating first set of endpoints/addresses List addrs = @@ -2058,9 +2596,9 @@ public void ready_then_transient_failure_again() { // Calling pickSubchannel() requests a connection, gives the same result when called twice. assertEquals(picker.pickSubchannel(mockArgs), picker.pickSubchannel(mockArgs)); - inOrder.verify(mockSubchannel1).start(stateListenerCaptor.capture()); + inOrder.verify(mockSubchannel1n2).start(stateListenerCaptor.capture()); SubchannelStateListener stateListener3 = stateListenerCaptor.getValue(); - inOrder.verify(mockSubchannel1).requestConnection(); + inOrder.verify(mockSubchannel1n2).requestConnection(); when(mockSubchannel3.getAllAddresses()).thenReturn(Lists.newArrayList(servers.get(0))); stateListener3.onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING)); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); @@ -2076,7 +2614,7 @@ public void ready_then_transient_failure_again() { assertEquals(READY, loadBalancer.getConcludedConnectivityState()); // verify that picker returns correct subchannel - inOrder.verify(mockSubchannel1).shutdown(); + inOrder.verify(mockSubchannel1n2).shutdown(); inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture()); picker = pickerCaptor.getValue(); assertEquals(PickResult.withSubchannel(mockSubchannel2), picker.pickSubchannel(mockArgs)); @@ -2084,7 +2622,7 @@ public void ready_then_transient_failure_again() { @Test public void happy_eyeballs_trigger_connection_delay() { - Assume.assumeTrue(enableHappyEyeballs); // This test is only for happy eyeballs + assumeTrue(enableHappyEyeballs); // This test is only for happy eyeballs // Starting first connection attempt InOrder inOrder = inOrder(mockHelper, mockSubchannel1, mockSubchannel2, mockSubchannel3, mockSubchannel4); @@ -2129,7 +2667,7 @@ public void happy_eyeballs_trigger_connection_delay() { @Test public void happy_eyeballs_connection_results_happen_after_get_to_end() { - Assume.assumeTrue(enableHappyEyeballs); // This test is only for happy eyeballs + assumeTrue(enableHappyEyeballs); // This test is only for happy eyeballs InOrder inOrder = inOrder(mockHelper, mockSubchannel1, mockSubchannel2, mockSubchannel3); Status error = Status.UNAUTHENTICATED.withDescription("simulated failure"); @@ -2182,9 +2720,10 @@ public void happy_eyeballs_connection_results_happen_after_get_to_end() { @Test public void happy_eyeballs_pick_pushes_index_over_end() { - Assume.assumeTrue(enableHappyEyeballs); // This test is only for happy eyeballs + assumeTrue(enableHappyEyeballs); // This test is only for happy eyeballs - InOrder inOrder = inOrder(mockHelper, mockSubchannel1, mockSubchannel2, mockSubchannel3); + InOrder inOrder = inOrder(mockHelper, mockSubchannel1, mockSubchannel2, mockSubchannel3, + mockSubchannel2n2, mockSubchannel3n2); Status error = Status.UNAUTHENTICATED.withDescription("simulated failure"); List addrs = @@ -2234,9 +2773,9 @@ public void happy_eyeballs_pick_pushes_index_over_end() { // Try pushing after end with just picks listeners[0].onSubchannelState(ConnectivityStateInfo.forNonError(READY)); - for (SubchannelStateListener listener : listeners) { - listener.onSubchannelState(ConnectivityStateInfo.forNonError(IDLE)); - } + verify(mockSubchannel2).shutdown(); + verify(mockSubchannel3).shutdown(); + listeners[0].onSubchannelState(ConnectivityStateInfo.forNonError(IDLE)); loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(addrs).setAttributes(affinity).build()); inOrder.verify(mockHelper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); @@ -2247,16 +2786,19 @@ public void happy_eyeballs_pick_pushes_index_over_end() { } assertEquals(IDLE, loadBalancer.getConcludedConnectivityState()); - for (SubchannelStateListener listener : listeners) { - listener.onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); - } + listeners[0].onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); + inOrder.verify(mockSubchannel2n2).start(stateListenerCaptor.capture()); + stateListenerCaptor.getValue().onSubchannelState( + ConnectivityStateInfo.forTransientFailure(error)); + inOrder.verify(mockSubchannel3n2).start(stateListenerCaptor.capture()); + stateListenerCaptor.getValue().onSubchannelState( + ConnectivityStateInfo.forTransientFailure(error)); assertEquals(TRANSIENT_FAILURE, loadBalancer.getConcludedConnectivityState()); - } @Test public void happy_eyeballs_fail_then_trigger_connection_delay() { - Assume.assumeTrue(enableHappyEyeballs); // This test is only for happy eyeballs + assumeTrue(enableHappyEyeballs); // This test is only for happy eyeballs // Starting first connection attempt InOrder inOrder = inOrder(mockHelper, mockSubchannel1, mockSubchannel2, mockSubchannel3); assertEquals(IDLE, loadBalancer.getConcludedConnectivityState()); @@ -2335,6 +2877,44 @@ public void advance_index_then_request_connection() { loadBalancer.requestConnection(); // should be handled without throwing exception } + @Test + public void serialized_retries_two_passes() { + assumeTrue(serializeRetries); // This test is only for serialized retries + + InOrder inOrder = inOrder(mockHelper, mockSubchannel1, mockSubchannel2, mockSubchannel3); + Status error = Status.UNAUTHENTICATED.withDescription("simulated failure"); + + List addrs = + Lists.newArrayList(servers.get(0), servers.get(1), servers.get(2)); + Subchannel[] subchannels = new Subchannel[]{mockSubchannel1, mockSubchannel2, mockSubchannel3}; + SubchannelStateListener[] listeners = new SubchannelStateListener[subchannels.length]; + loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(addrs).build()); + forwardTimeByConnectionDelay(2); + for (int i = 0; i < subchannels.length; i++) { + inOrder.verify(subchannels[i]).start(stateListenerCaptor.capture()); + inOrder.verify(subchannels[i]).requestConnection(); + listeners[i] = stateListenerCaptor.getValue(); + listeners[i].onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); + } + assertEquals(TRANSIENT_FAILURE, loadBalancer.getConcludedConnectivityState()); + assertFalse("Index should be at end", loadBalancer.isIndexValid()); + + forwardTimeByBackoffDelay(); // should trigger retry + for (int i = 0; i < subchannels.length; i++) { + inOrder.verify(subchannels[i]).requestConnection(); + listeners[i].onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); // cascade + } + inOrder.verify(subchannels[0], never()).requestConnection(); // should wait for backoff delay + + forwardTimeByBackoffDelay(); // should trigger retry again + for (int i = 0; i < subchannels.length; i++) { + inOrder.verify(subchannels[i]).requestConnection(); + assertEquals(i, loadBalancer.getIndexLocation()); + listeners[i].onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); // cascade + } + } + @Test public void index_looping() { Attributes.Key key = Attributes.Key.create("some-key"); @@ -2349,7 +2929,7 @@ public void index_looping() { PickFirstLeafLoadBalancer.Index index = new PickFirstLeafLoadBalancer.Index(Arrays.asList( new EquivalentAddressGroup(Arrays.asList(addr1, addr2), attr1), new EquivalentAddressGroup(Arrays.asList(addr3), attr2), - new EquivalentAddressGroup(Arrays.asList(addr4, addr5), attr3))); + new EquivalentAddressGroup(Arrays.asList(addr4, addr5), attr3)), enableHappyEyeballs); assertThat(index.getCurrentAddress()).isSameInstanceAs(addr1); assertThat(index.getCurrentEagAttributes()).isSameInstanceAs(attr1); assertThat(index.isAtBeginning()).isTrue(); @@ -2408,7 +2988,7 @@ public void index_updateGroups_resets() { SocketAddress addr3 = new FakeSocketAddress("addr3"); PickFirstLeafLoadBalancer.Index index = new PickFirstLeafLoadBalancer.Index(Arrays.asList( new EquivalentAddressGroup(Arrays.asList(addr1)), - new EquivalentAddressGroup(Arrays.asList(addr2, addr3)))); + new EquivalentAddressGroup(Arrays.asList(addr2, addr3))), enableHappyEyeballs); index.increment(); index.increment(); // We want to make sure both groupIndex and addressIndex are reset @@ -2425,7 +3005,7 @@ public void index_seekTo() { SocketAddress addr3 = new FakeSocketAddress("addr3"); PickFirstLeafLoadBalancer.Index index = new PickFirstLeafLoadBalancer.Index(Arrays.asList( new EquivalentAddressGroup(Arrays.asList(addr1, addr2)), - new EquivalentAddressGroup(Arrays.asList(addr3)))); + new EquivalentAddressGroup(Arrays.asList(addr3))), enableHappyEyeballs); assertThat(index.seekTo(addr3)).isTrue(); assertThat(index.getCurrentAddress()).isSameInstanceAs(addr3); assertThat(index.seekTo(addr1)).isTrue(); @@ -2437,6 +3017,83 @@ public void index_seekTo() { assertThat(index.getCurrentAddress()).isSameInstanceAs(addr2); } + @Test + public void index_interleaving() { + InetSocketAddress addr1_6 = new InetSocketAddress("f38:1:1", 1234); + InetSocketAddress addr1_4 = new InetSocketAddress("10.1.1.1", 1234); + InetSocketAddress addr2_4 = new InetSocketAddress("10.1.1.2", 1234); + InetSocketAddress addr3_4 = new InetSocketAddress("10.1.1.3", 1234); + InetSocketAddress addr4_4 = new InetSocketAddress("10.1.1.4", 1234); + InetSocketAddress addr4_6 = new InetSocketAddress("f38:1:4", 1234); + + Attributes attrs1 = Attributes.newBuilder().build(); + Attributes attrs2 = Attributes.newBuilder().build(); + Attributes attrs3 = Attributes.newBuilder().build(); + Attributes attrs4 = Attributes.newBuilder().build(); + + PickFirstLeafLoadBalancer.Index index = new PickFirstLeafLoadBalancer.Index(Arrays.asList( + new EquivalentAddressGroup(Arrays.asList(addr1_4, addr1_6), attrs1), + new EquivalentAddressGroup(Arrays.asList(addr2_4), attrs2), + new EquivalentAddressGroup(Arrays.asList(addr3_4), attrs3), + new EquivalentAddressGroup(Arrays.asList(addr4_4, addr4_6), attrs4)), enableHappyEyeballs); + + assertThat(index.getCurrentAddress()).isSameInstanceAs(addr1_4); + assertThat(index.getCurrentEagAttributes()).isSameInstanceAs(attrs1); + assertThat(index.isAtBeginning()).isTrue(); + + index.increment(); + assertThat(index.isValid()).isTrue(); + assertThat(index.getCurrentAddress()).isSameInstanceAs(addr1_6); + assertThat(index.getCurrentEagAttributes()).isSameInstanceAs(attrs1); + assertThat(index.isAtBeginning()).isFalse(); + + index.increment(); + assertThat(index.getCurrentAddress()).isSameInstanceAs(addr2_4); + assertThat(index.getCurrentEagAttributes()).isSameInstanceAs(attrs2); + + index.increment(); + if (enableHappyEyeballs) { + assertThat(index.getCurrentAddress()).isSameInstanceAs(addr4_6); + assertThat(index.getCurrentEagAttributes()).isSameInstanceAs(attrs4); + } else { + assertThat(index.getCurrentAddress()).isSameInstanceAs(addr3_4); + assertThat(index.getCurrentEagAttributes()).isSameInstanceAs(attrs3); + } + + index.increment(); + if (enableHappyEyeballs) { + assertThat(index.getCurrentAddress()).isSameInstanceAs(addr3_4); + assertThat(index.getCurrentEagAttributes()).isSameInstanceAs(attrs3); + } else { + assertThat(index.getCurrentAddress()).isSameInstanceAs(addr4_4); + assertThat(index.getCurrentEagAttributes()).isSameInstanceAs(attrs4); + } + + // Move to last entry + assertThat(index.increment()).isTrue(); + assertThat(index.isValid()).isTrue(); + if (enableHappyEyeballs) { + assertThat(index.getCurrentAddress()).isSameInstanceAs(addr4_4); + } else { + assertThat(index.getCurrentAddress()).isSameInstanceAs(addr4_6); + } + + // Move off of the end + assertThat(index.increment()).isFalse(); + assertThat(index.isValid()).isFalse(); + assertThrows(IllegalStateException.class, index::getCurrentAddress); + + // Reset + index.reset(); + assertThat(index.getCurrentAddress()).isSameInstanceAs(addr1_4); + assertThat(index.isAtBeginning()).isTrue(); + assertThat(index.isValid()).isTrue(); + + // Seek to an address + assertThat(index.seekTo(addr4_4)).isTrue(); + assertThat(index.getCurrentAddress()).isSameInstanceAs(addr4_4); + } + private static class FakeSocketAddress extends SocketAddress { final String name; @@ -2446,9 +3103,22 @@ private static class FakeSocketAddress extends SocketAddress { @Override public String toString() { - return "FakeSocketAddress-" + name; + return "FakeSocketAddress(" + name + ")"; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof FakeSocketAddress)) { + return false; + } + FakeSocketAddress that = (FakeSocketAddress) o; + return this.name.equals(that.name); } + @Override + public int hashCode() { + return name.hashCode(); + } } private void forwardTimeByConnectionDelay() { @@ -2461,6 +3131,11 @@ private void forwardTimeByConnectionDelay(int times) { } } + private void forwardTimeByBackoffDelay() { + backoffMillis = (long) (backoffMillis * 1.8); // backoff factor default is 1.6 with Jitter .2 + fakeClock.forwardTime(backoffMillis, TimeUnit.MILLISECONDS); + } + private void acceptXSubchannels(int num) { List newServers = new ArrayList<>(); for (int i = 0; i < num; i++) { @@ -2506,15 +3181,20 @@ public void updateAddresses(List addrs) { @Override public void shutdown() { + listener.onSubchannelState(ConnectivityStateInfo.forNonError(SHUTDOWN)); } @Override public void requestConnection() { - listener.onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING)); + } + + @Override + public String toString() { + return "FakeSubchannel@" + hashCode() + "(" + eags + ")"; } } - private class MockHelperImpl extends LoadBalancer.Helper { + private class BaseHelper extends LoadBalancer.Helper { @Override public ManagedChannel createOobChannel(EquivalentAddressGroup eag, String authority) { return null; @@ -2544,19 +3224,47 @@ public ScheduledExecutorService getScheduledExecutorService() { public void refreshNameResolution() { // noop } + } + + private class MockHelperImpl extends BaseHelper { + private final List subchannels; + + public MockHelperImpl(List subchannels) { + this.subchannels = new ArrayList(subchannels); + } @Override public Subchannel createSubchannel(CreateSubchannelArgs args) { - SocketAddress addr = args.getAddresses().get(0).getAddresses().get(0); - List fakeSubchannels = - Arrays.asList(mockSubchannel1, mockSubchannel2, mockSubchannel3, mockSubchannel4, - mockSubchannel5); - for (int i = 1; i <= 5; i++) { - if (addr.toString().equals(new FakeSocketAddress("server" + i).toString())) { - return fakeSubchannels.get(i - 1); + for (int i = 0; i < subchannels.size(); i++) { + Subchannel subchannel = subchannels.get(i); + List addrs = subchannel.getAllAddresses(); + verify(subchannel, atLeast(1)).getAllAddresses(); // ignore the interaction + if (!args.getAddresses().equals(addrs)) { + continue; } + subchannels.remove(i); + return subchannel; } - throw new IllegalArgumentException("Unexpected address: " + addr); + throw new IllegalArgumentException("Unexpected addresses: " + args.getAddresses()); + } + } + + class RecordingHelper extends BaseHelper { + ConnectivityState state; + SubchannelPicker picker; + final Queue subchannels = new ArrayDeque<>(); + + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + this.state = newState; + this.picker = newPicker; + } + + @Override + public Subchannel createSubchannel(CreateSubchannelArgs args) { + FakeSubchannel subchannel = new FakeSubchannel(args.getAddresses(), args.getAttributes()); + subchannels.add(subchannel); + return subchannel; } } -} \ No newline at end of file +} diff --git a/core/src/test/java/io/grpc/internal/PickFirstLoadBalancerProviderTest.java b/core/src/test/java/io/grpc/internal/PickFirstLoadBalancerProviderTest.java index 3aa9b1872c3..7844aebd3fd 100644 --- a/core/src/test/java/io/grpc/internal/PickFirstLoadBalancerProviderTest.java +++ b/core/src/test/java/io/grpc/internal/PickFirstLoadBalancerProviderTest.java @@ -19,6 +19,7 @@ import static com.google.common.truth.Truth.assertThat; import io.grpc.NameResolver.ConfigOrError; +import io.grpc.internal.PickFirstLeafLoadBalancer.PickFirstLeafLoadBalancerConfig; import io.grpc.internal.PickFirstLoadBalancer.PickFirstLoadBalancerConfig; import java.util.HashMap; import java.util.Map; @@ -35,10 +36,23 @@ public void parseWithConfig() { rawConfig.put("shuffleAddressList", true); ConfigOrError parsedConfig = new PickFirstLoadBalancerProvider().parseLoadBalancingPolicyConfig( rawConfig); - PickFirstLoadBalancerConfig config = (PickFirstLoadBalancerConfig) parsedConfig.getConfig(); - assertThat(config.shuffleAddressList).isTrue(); - assertThat(config.randomSeed).isNull(); + Boolean shuffleAddressList; + Long randomSeed; + + if (PickFirstLoadBalancerProvider.isEnabledNewPickFirst()) { + PickFirstLeafLoadBalancerConfig config = + (PickFirstLeafLoadBalancerConfig) parsedConfig.getConfig(); + shuffleAddressList = config.shuffleAddressList; + randomSeed = config.randomSeed; + } else { + PickFirstLoadBalancerConfig config = (PickFirstLoadBalancerConfig) parsedConfig.getConfig(); + shuffleAddressList = config.shuffleAddressList; + randomSeed = config.randomSeed; + } + + assertThat(shuffleAddressList).isTrue(); + assertThat(randomSeed).isNull(); } @Test @@ -46,9 +60,22 @@ public void parseWithoutConfig() { Map rawConfig = new HashMap<>(); ConfigOrError parsedConfig = new PickFirstLoadBalancerProvider().parseLoadBalancingPolicyConfig( rawConfig); - PickFirstLoadBalancerConfig config = (PickFirstLoadBalancerConfig) parsedConfig.getConfig(); - assertThat(config.shuffleAddressList).isNull(); - assertThat(config.randomSeed).isNull(); + Boolean shuffleAddressList; + Long randomSeed; + + if (PickFirstLoadBalancerProvider.isEnabledNewPickFirst()) { + PickFirstLeafLoadBalancerConfig config = + (PickFirstLeafLoadBalancerConfig) parsedConfig.getConfig(); + shuffleAddressList = config.shuffleAddressList; + randomSeed = config.randomSeed; + } else { + PickFirstLoadBalancerConfig config = (PickFirstLoadBalancerConfig) parsedConfig.getConfig(); + shuffleAddressList = config.shuffleAddressList; + randomSeed = config.randomSeed; + } + + assertThat(shuffleAddressList).isNull(); + assertThat(randomSeed).isNull(); } } diff --git a/core/src/test/java/io/grpc/internal/PickFirstLoadBalancerTest.java b/core/src/test/java/io/grpc/internal/PickFirstLoadBalancerTest.java index 3e0258f2e40..1e130423a45 100644 --- a/core/src/test/java/io/grpc/internal/PickFirstLoadBalancerTest.java +++ b/core/src/test/java/io/grpc/internal/PickFirstLoadBalancerTest.java @@ -21,6 +21,7 @@ import static io.grpc.ConnectivityState.IDLE; import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; +import static io.grpc.InternalEquivalentAddressGroup.ATTR_WEIGHT; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; @@ -49,12 +50,18 @@ import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancer.SubchannelStateListener; +import io.grpc.ManagedChannel; import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.SynchronizationContext; import io.grpc.internal.PickFirstLoadBalancer.PickFirstLoadBalancerConfig; import java.net.SocketAddress; +import java.util.ArrayDeque; +import java.util.Arrays; +import java.util.Collections; import java.util.List; +import java.util.Queue; +import java.util.Random; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -103,8 +110,12 @@ public void uncaughtException(Thread t, Throwable e) { @Mock // This LoadBalancer doesn't use any of the arg fields, as verified in tearDown(). private PickSubchannelArgs mockArgs; + private boolean originalWeightedShuffling; + @Before public void setUp() { + originalWeightedShuffling = PickFirstLeafLoadBalancer.weightedShuffling; + for (int i = 0; i < 3; i++) { SocketAddress addr = new FakeSocketAddress("server" + i); servers.add(new EquivalentAddressGroup(addr)); @@ -120,6 +131,7 @@ public void setUp() { @After public void tearDown() throws Exception { + PickFirstLeafLoadBalancer.weightedShuffling = originalWeightedShuffling; verifyNoMoreInteractions(mockArgs); } @@ -141,6 +153,12 @@ public void pickAfterResolved() throws Exception { verifyNoMoreInteractions(mockHelper); } + @Test + public void pickAfterResolved_shuffle_oppositeWeightedShuffling() throws Exception { + PickFirstLeafLoadBalancer.weightedShuffling = !PickFirstLeafLoadBalancer.weightedShuffling; + pickAfterResolved_shuffle(); + } + @Test public void pickAfterResolved_shuffle() throws Exception { loadBalancer.acceptResolvedAddresses( @@ -184,6 +202,103 @@ public void pickAfterResolved_noShuffle() throws Exception { verifyNoMoreInteractions(mockHelper); } + @Test + public void pickAfterResolved_shuffleImplicitUniform_oppositeWeightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = !PickFirstLeafLoadBalancer.weightedShuffling; + pickAfterResolved_shuffleImplicitUniform(); + } + + @Test + public void pickAfterResolved_shuffleImplicitUniform() { + EquivalentAddressGroup eag1 = new EquivalentAddressGroup(new FakeSocketAddress("server1")); + EquivalentAddressGroup eag2 = new EquivalentAddressGroup(new FakeSocketAddress("server2")); + EquivalentAddressGroup eag3 = new EquivalentAddressGroup(new FakeSocketAddress("server3")); + + int[] counts = countAddressSelections(99, Arrays.asList(eag1, eag2, eag3)); + assertThat(counts[0]).isWithin(7).of(33); + assertThat(counts[1]).isWithin(7).of(33); + assertThat(counts[2]).isWithin(7).of(33); + } + + @Test + public void pickAfterResolved_shuffleExplicitUniform_oppositeWeightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = !PickFirstLeafLoadBalancer.weightedShuffling; + pickAfterResolved_shuffleExplicitUniform(); + } + + @Test + public void pickAfterResolved_shuffleExplicitUniform() { + EquivalentAddressGroup eag1 = new EquivalentAddressGroup( + new FakeSocketAddress("server1"), Attributes.newBuilder().set(ATTR_WEIGHT, 111L).build()); + EquivalentAddressGroup eag2 = new EquivalentAddressGroup( + new FakeSocketAddress("server2"), Attributes.newBuilder().set(ATTR_WEIGHT, 111L).build()); + EquivalentAddressGroup eag3 = new EquivalentAddressGroup( + new FakeSocketAddress("server3"), Attributes.newBuilder().set(ATTR_WEIGHT, 111L).build()); + + int[] counts = countAddressSelections(99, Arrays.asList(eag1, eag2, eag3)); + assertThat(counts[0]).isWithin(7).of(33); + assertThat(counts[1]).isWithin(7).of(33); + assertThat(counts[2]).isWithin(7).of(33); + } + + @Test + public void pickAfterResolved_shuffleWeighted_noWeightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = false; + EquivalentAddressGroup eag1 = new EquivalentAddressGroup( + new FakeSocketAddress("server1"), Attributes.newBuilder().set(ATTR_WEIGHT, 12L).build()); + EquivalentAddressGroup eag2 = new EquivalentAddressGroup( + new FakeSocketAddress("server2"), Attributes.newBuilder().set(ATTR_WEIGHT, 3L).build()); + EquivalentAddressGroup eag3 = new EquivalentAddressGroup( + new FakeSocketAddress("server3"), Attributes.newBuilder().set(ATTR_WEIGHT, 1L).build()); + + int[] counts = countAddressSelections(100, Arrays.asList(eag1, eag2, eag3)); + assertThat(counts[0]).isWithin(7).of(33); + assertThat(counts[1]).isWithin(7).of(33); + assertThat(counts[2]).isWithin(7).of(33); + } + + @Test + public void pickAfterResolved_shuffleWeighted_weightedShuffling() { + PickFirstLeafLoadBalancer.weightedShuffling = true; + EquivalentAddressGroup eag1 = new EquivalentAddressGroup( + new FakeSocketAddress("server1"), Attributes.newBuilder().set(ATTR_WEIGHT, 12L).build()); + EquivalentAddressGroup eag2 = new EquivalentAddressGroup( + new FakeSocketAddress("server2"), Attributes.newBuilder().set(ATTR_WEIGHT, 3L).build()); + EquivalentAddressGroup eag3 = new EquivalentAddressGroup( + new FakeSocketAddress("server3"), Attributes.newBuilder().set(ATTR_WEIGHT, 1L).build()); + + int[] counts = countAddressSelections(100, Arrays.asList(eag1, eag2, eag3)); + assertThat(counts[0]).isWithin(7).of(75); // 100*12/16 + assertThat(counts[1]).isWithin(7).of(19); // 100*3/16 + assertThat(counts[2]).isWithin(7).of(6); // 100*1/16 + } + + /** Returns int[index_of_eag] array with number of times each eag was selected. */ + private int[] countAddressSelections(int trials, List eags) { + int[] counts = new int[eags.size()]; + Random random = new Random(1); + for (int i = 0; i < trials; i++) { + RecordingHelper helper = new RecordingHelper(); + PickFirstLoadBalancer lb = new PickFirstLoadBalancer(helper); + assertThat(lb.acceptResolvedAddresses(ResolvedAddresses.newBuilder() + .setAddresses(eags) + .setAttributes(affinity) + .setLoadBalancingPolicyConfig( + new PickFirstLoadBalancerConfig(true, random.nextLong())) + .build())) + .isSameInstanceAs(Status.OK); + helper.subchannels.remove().listener.onSubchannelState( + ConnectivityStateInfo.forNonError(READY)); + + assertThat(helper.state).isEqualTo(READY); + Subchannel subchannel = helper.picker.pickSubchannel(mockArgs).getSubchannel(); + counts[eags.indexOf(subchannel.getAllAddresses().get(0))]++; + + lb.shutdown(); + } + return counts; + } + @Test public void requestConnectionPicker() throws Exception { loadBalancer.acceptResolvedAddresses( @@ -219,7 +334,7 @@ public void refreshNameResolutionAfterSubchannelConnectionBroken() { inOrder.verify(mockSubchannel).start(stateListenerCaptor.capture()); SubchannelStateListener stateListener = stateListenerCaptor.getValue(); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); - assertSame(mockSubchannel, pickerCaptor.getValue().pickSubchannel(mockArgs).getSubchannel()); + assertThat(pickerCaptor.getValue().pickSubchannel(mockArgs).hasResult()).isFalse(); inOrder.verify(mockSubchannel).requestConnection(); stateListener.onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING)); @@ -278,7 +393,7 @@ public void pickAfterResolvedAndChanged() throws Exception { assertThat(args.getAddresses()).isEqualTo(servers); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); verify(mockSubchannel).requestConnection(); - assertEquals(mockSubchannel, pickerCaptor.getValue().pickSubchannel(mockArgs).getSubchannel()); + assertThat(pickerCaptor.getValue().pickSubchannel(mockArgs).hasResult()).isFalse(); loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(newServers).setAttributes(affinity).build()); @@ -300,7 +415,7 @@ public void pickAfterStateChangeAfterResolution() throws Exception { verify(mockSubchannel).start(stateListenerCaptor.capture()); SubchannelStateListener stateListener = stateListenerCaptor.getValue(); verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); - Subchannel subchannel = pickerCaptor.getValue().pickSubchannel(mockArgs).getSubchannel(); + assertThat(pickerCaptor.getValue().pickSubchannel(mockArgs).hasResult()).isFalse(); reset(mockHelper); when(mockHelper.getSynchronizationContext()).thenReturn(syncContext); @@ -317,7 +432,7 @@ public void pickAfterStateChangeAfterResolution() throws Exception { stateListener.onSubchannelState(ConnectivityStateInfo.forNonError(READY)); inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture()); - assertEquals(subchannel, pickerCaptor.getValue().pickSubchannel(mockArgs).getSubchannel()); + assertEquals(mockSubchannel, pickerCaptor.getValue().pickSubchannel(mockArgs).getSubchannel()); verify(mockHelper, atLeast(0)).getSynchronizationContext(); // Don't care verifyNoMoreInteractions(mockHelper); @@ -405,8 +520,7 @@ public void nameResolutionSuccessAfterError() throws Exception { inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); verify(mockSubchannel).requestConnection(); - assertEquals(mockSubchannel, pickerCaptor.getValue().pickSubchannel(mockArgs) - .getSubchannel()); + assertThat(pickerCaptor.getValue().pickSubchannel(mockArgs).hasResult()).isFalse(); assertEquals(pickerCaptor.getValue().pickSubchannel(mockArgs), pickerCaptor.getValue().pickSubchannel(mockArgs)); @@ -487,4 +601,96 @@ public String toString() { return "FakeSocketAddress-" + name; } } + + private static class FakeSubchannel extends Subchannel { + private final Attributes attributes; + private List eags; + private SubchannelStateListener listener; + + public FakeSubchannel(List eags, Attributes attributes) { + this.eags = Collections.unmodifiableList(eags); + this.attributes = attributes; + } + + @Override + public List getAllAddresses() { + return eags; + } + + @Override + public Attributes getAttributes() { + return attributes; + } + + @Override + public void start(SubchannelStateListener listener) { + this.listener = listener; + } + + @Override + public void updateAddresses(List addrs) { + this.eags = Collections.unmodifiableList(addrs); + } + + @Override + public void shutdown() { + listener.onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.SHUTDOWN)); + } + + @Override + public void requestConnection() { + } + + @Override + public String toString() { + return "FakeSubchannel@" + hashCode() + "(" + eags + ")"; + } + } + + private class BaseHelper extends Helper { + @Override + public ManagedChannel createOobChannel(EquivalentAddressGroup eag, String authority) { + return null; + } + + @Override + public String getAuthority() { + return null; + } + + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + // ignore + } + + @Override + public SynchronizationContext getSynchronizationContext() { + return syncContext; + } + + @Override + public void refreshNameResolution() { + // noop + } + } + + class RecordingHelper extends BaseHelper { + ConnectivityState state; + SubchannelPicker picker; + final Queue subchannels = new ArrayDeque<>(); + + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + this.state = newState; + this.picker = newPicker; + } + + @Override + public Subchannel createSubchannel(CreateSubchannelArgs args) { + FakeSubchannel subchannel = new FakeSubchannel(args.getAddresses(), args.getAttributes()); + subchannels.add(subchannel); + return subchannel; + } + } + } diff --git a/core/src/test/java/io/grpc/internal/ProxyDetectorImplTest.java b/core/src/test/java/io/grpc/internal/ProxyDetectorImplTest.java index 0432a474ac5..af0ed1f35d3 100644 --- a/core/src/test/java/io/grpc/internal/ProxyDetectorImplTest.java +++ b/core/src/test/java/io/grpc/internal/ProxyDetectorImplTest.java @@ -21,6 +21,7 @@ import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; @@ -33,6 +34,7 @@ import io.grpc.HttpConnectProxiedSocketAddress; import io.grpc.ProxiedSocketAddress; import io.grpc.ProxyDetector; +import java.io.IOException; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.PasswordAuthentication; @@ -40,6 +42,7 @@ import java.net.ProxySelector; import java.net.SocketAddress; import java.net.URI; +import java.util.Collections; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -73,7 +76,7 @@ public ProxySelector get() { return proxySelector; } }; - proxyDetector = new ProxyDetectorImpl(proxySelectorSupplier, authenticator, null); + proxyDetector = new ProxyDetectorImpl(proxySelectorSupplier, authenticator); unresolvedProxy = InetSocketAddress.createUnresolved("10.0.0.1", proxyPort); proxySocketAddress = HttpConnectProxiedSocketAddress.newBuilder() .setTargetAddress(destination) @@ -82,45 +85,6 @@ public ProxySelector get() { .build(); } - @Test - public void override_hostPort() throws Exception { - final String overrideHost = "10.99.99.99"; - final int overridePort = 1234; - final String overrideHostWithPort = overrideHost + ":" + overridePort; - ProxyDetectorImpl proxyDetector = new ProxyDetectorImpl( - proxySelectorSupplier, - authenticator, - overrideHostWithPort); - ProxiedSocketAddress detected = proxyDetector.proxyFor(destination); - assertNotNull(detected); - assertEquals( - HttpConnectProxiedSocketAddress.newBuilder() - .setTargetAddress(destination) - .setProxyAddress( - new InetSocketAddress(InetAddress.getByName(overrideHost), overridePort)) - .build(), - detected); - } - - @Test - public void override_hostOnly() throws Exception { - final String overrideHostWithoutPort = "10.99.99.99"; - final int defaultPort = 80; - ProxyDetectorImpl proxyDetector = new ProxyDetectorImpl( - proxySelectorSupplier, - authenticator, - overrideHostWithoutPort); - ProxiedSocketAddress detected = proxyDetector.proxyFor(destination); - assertNotNull(detected); - assertEquals( - HttpConnectProxiedSocketAddress.newBuilder() - .setTargetAddress(destination) - .setProxyAddress( - new InetSocketAddress(InetAddress.getByName(overrideHostWithoutPort), defaultPort)) - .build(), - detected); - } - @Test public void returnNullWhenNoProxy() throws Exception { when(proxySelector.select(any(URI.class))) @@ -227,8 +191,27 @@ public ProxySelector get() { return null; } }, - authenticator, - null); + authenticator); assertNull(proxyDetector.proxyFor(destination)); } + + @Test + public void throwsWhenProxySelectorReturnsEmptyList() throws Exception { + when(proxySelector.select(any(URI.class))).thenReturn(Collections.emptyList()); + + IOException e = + assertThrows(IOException.class, () -> proxyDetector.proxyFor(destination)); + assertTrue(e.getMessage(), e.getMessage().contains("empty list")); + assertTrue(e.getMessage(), e.getMessage().contains(proxySelector.getClass().getName())); + } + + @Test + public void throwsWhenProxySelectorReturnsNullList() throws Exception { + when(proxySelector.select(any(URI.class))).thenReturn(null); + + IOException e = + assertThrows(IOException.class, () -> proxyDetector.proxyFor(destination)); + assertTrue(e.getMessage(), e.getMessage().contains("null")); + assertTrue(e.getMessage(), e.getMessage().contains(proxySelector.getClass().getName())); + } } diff --git a/core/src/test/java/io/grpc/internal/ReadableBuffersArrayTest.java b/core/src/test/java/io/grpc/internal/ReadableBuffersArrayTest.java index d5c4fa77fd8..5b0fb02c611 100644 --- a/core/src/test/java/io/grpc/internal/ReadableBuffersArrayTest.java +++ b/core/src/test/java/io/grpc/internal/ReadableBuffersArrayTest.java @@ -16,8 +16,8 @@ package io.grpc.internal; -import static com.google.common.base.Charsets.UTF_8; import static io.grpc.internal.ReadableBuffers.wrap; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; diff --git a/core/src/test/java/io/grpc/internal/ReadableBuffersByteBufferTest.java b/core/src/test/java/io/grpc/internal/ReadableBuffersByteBufferTest.java index a040182c259..67e7aaf9132 100644 --- a/core/src/test/java/io/grpc/internal/ReadableBuffersByteBufferTest.java +++ b/core/src/test/java/io/grpc/internal/ReadableBuffersByteBufferTest.java @@ -16,7 +16,7 @@ package io.grpc.internal; -import static com.google.common.base.Charsets.UTF_8; +import static java.nio.charset.StandardCharsets.UTF_8; import java.nio.ByteBuffer; diff --git a/core/src/test/java/io/grpc/internal/ReadableBuffersTest.java b/core/src/test/java/io/grpc/internal/ReadableBuffersTest.java index 2bc5a8a3760..b9135b49503 100644 --- a/core/src/test/java/io/grpc/internal/ReadableBuffersTest.java +++ b/core/src/test/java/io/grpc/internal/ReadableBuffersTest.java @@ -16,7 +16,7 @@ package io.grpc.internal; -import static com.google.common.base.Charsets.UTF_8; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThrows; diff --git a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java index 3487ef02b46..afbdaa395b0 100644 --- a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java +++ b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java @@ -147,6 +147,17 @@ public double nextDouble() { private final ChannelBufferMeter channelBufferUsed = new ChannelBufferMeter(); private final FakeClock fakeClock = new FakeClock(); + private static long calculateBackoffWithRetries(int retryCount) { + // Calculate the exponential backoff delay with jitter + double exponent = retryCount > 0 ? Math.pow(BACKOFF_MULTIPLIER, retryCount) : 1; + long delay = (long) (INITIAL_BACKOFF_IN_SECONDS * exponent); + return RetriableStream.intervalWithJitter(delay); + } + + private static long calculateMaxBackoff() { + return RetriableStream.intervalWithJitter(MAX_BACKOFF_IN_SECONDS); + } + private final class RecordedRetriableStream extends RetriableStream { RecordedRetriableStream(MethodDescriptor method, Metadata headers, ChannelBufferMeter channelBufferUsed, long perRpcBufferLimit, long channelBufferLimit, @@ -175,7 +186,8 @@ ClientStream newSubstream( Metadata metadata, ClientStreamTracer.Factory tracerFactory, int previousAttempts, - boolean isTransparentRetry) { + boolean isTransparentRetry, + boolean isHedgedStream) { bufferSizeTracer = tracerFactory.newClientStreamTracer(STREAM_INFO, metadata); int actualPreviousRpcAttemptsInHeader = metadata.get(GRPC_PREVIOUS_RPC_ATTEMPTS) == null @@ -307,7 +319,7 @@ public Void answer(InvocationOnMock in) { retriableStream.sendMessage("msg1 during backoff1"); retriableStream.sendMessage("msg2 during backoff1"); - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM) - 1L, TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0) - 1L, TimeUnit.SECONDS); inOrder.verifyNoMoreInteractions(); assertEquals(1, fakeClock.numPendingTasks()); fakeClock.forwardTime(1L, TimeUnit.SECONDS); @@ -364,9 +376,7 @@ public Void answer(InvocationOnMock in) { retriableStream.sendMessage("msg2 during backoff2"); retriableStream.sendMessage("msg3 during backoff2"); - fakeClock.forwardTime( - (long) (INITIAL_BACKOFF_IN_SECONDS * BACKOFF_MULTIPLIER * FAKE_RANDOM) - 1L, - TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(1) - 1L, TimeUnit.SECONDS); inOrder.verifyNoMoreInteractions(); assertEquals(1, fakeClock.numPendingTasks()); fakeClock.forwardTime(1L, TimeUnit.SECONDS); @@ -459,7 +469,7 @@ public void retry_headersRead_cancel() { sublistenerCaptor1.getValue().closed( Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); assertEquals(1, fakeClock.numPendingTasks()); - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0), TimeUnit.SECONDS); ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); @@ -518,7 +528,7 @@ public void retry_headersRead_closed() { doReturn(mockStream2).when(retriableStreamRecorder).newSubstream(1); sublistenerCaptor1.getValue().closed( Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0), TimeUnit.SECONDS); ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); @@ -584,7 +594,7 @@ public void retry_cancel_closed() { doReturn(mockStream2).when(retriableStreamRecorder).newSubstream(1); sublistenerCaptor1.getValue().closed( Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0), TimeUnit.SECONDS); ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); @@ -687,7 +697,7 @@ public void retry_unretriableClosed_cancel() { doReturn(mockStream2).when(retriableStreamRecorder).newSubstream(1); sublistenerCaptor1.getValue().closed( Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0), TimeUnit.SECONDS); ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); @@ -705,6 +715,7 @@ public void retry_unretriableClosed_cancel() { // cancel retriableStream.cancel(Status.CANCELLED); inOrder.verify(retriableStreamRecorder, never()).postCommit(); + verify(masterListener, times(1)).closed(any(), any(), any()); } @Test @@ -733,6 +744,7 @@ public void retry_cancelWhileBackoff() { verifyNoMoreInteractions(mockStream1); verifyNoMoreInteractions(mockStream2); + verify(masterListener, times(1)).closed(any(), any(), any()); } @Test @@ -819,7 +831,7 @@ public boolean isReady() { // send more requests during backoff retriableStream.request(789); - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0), TimeUnit.SECONDS); inOrder.verify(mockStream2).start(sublistenerCaptor2.get()); inOrder.verify(mockStream2).request(3); @@ -873,7 +885,7 @@ public void request(int numMessages) { doReturn(mockStream2).when(retriableStreamRecorder).newSubstream(1); sublistenerCaptor1.getValue().closed( Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0), TimeUnit.SECONDS); inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); inOrder.verify(mockStream2).request(3); @@ -918,7 +930,7 @@ public void start(ClientStreamListener listener) { doReturn(mockStream2).when(retriableStreamRecorder).newSubstream(1); sublistenerCaptor1.getValue().closed( Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0), TimeUnit.SECONDS); inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); inOrder.verify(retriableStreamRecorder).postCommit(); @@ -1026,7 +1038,7 @@ public boolean isReady() { retriableStream.request(789); readiness.add(retriableStream.isReady()); // expected false b/c in backoff - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0), TimeUnit.SECONDS); verify(mockStream2).start(any(ClientStreamListener.class)); readiness.add(retriableStream.isReady()); // expected true @@ -1108,7 +1120,7 @@ public void addPrevRetryAttemptsToRespHeaders() { doReturn(mockStream2).when(retriableStreamRecorder).newSubstream(1); sublistenerCaptor1.getValue().closed( Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0), TimeUnit.SECONDS); ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); @@ -1158,13 +1170,12 @@ public void start(ClientStreamListener listener) { listener1.closed( Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); assertEquals(1, fakeClock.numPendingTasks()); - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0), TimeUnit.SECONDS); assertEquals(1, fakeClock.numPendingTasks()); // send requests during backoff retriableStream.request(3); - fakeClock.forwardTime( - (long) (INITIAL_BACKOFF_IN_SECONDS * BACKOFF_MULTIPLIER * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(1), TimeUnit.SECONDS); retriableStream.request(1); verify(mockStream1, never()).request(anyInt()); @@ -1205,7 +1216,7 @@ public void start(ClientStreamListener listener) { // retry listener1.closed( Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0), TimeUnit.SECONDS); verify(mockStream2).start(any(ClientStreamListener.class)); verify(retriableStreamRecorder).postCommit(); @@ -1258,7 +1269,7 @@ public void perRpcBufferLimitExceededDuringBackoff() { bufferSizeTracer.outboundWireSize(2); verify(retriableStreamRecorder, never()).postCommit(); - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0), TimeUnit.SECONDS); verify(mockStream2).start(any(ClientStreamListener.class)); verify(mockStream2).isReady(); @@ -1330,7 +1341,7 @@ public void expBackoff_maxBackoff_maxRetryAttempts() { sublistenerCaptor1.getValue().closed( Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); assertEquals(1, fakeClock.numPendingTasks()); - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM) - 1L, TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0) - 1L, TimeUnit.SECONDS); assertEquals(1, fakeClock.numPendingTasks()); fakeClock.forwardTime(1L, TimeUnit.SECONDS); assertEquals(0, fakeClock.numPendingTasks()); @@ -1345,9 +1356,7 @@ public void expBackoff_maxBackoff_maxRetryAttempts() { sublistenerCaptor2.getValue().closed( Status.fromCode(RETRIABLE_STATUS_CODE_2), PROCESSED, new Metadata()); assertEquals(1, fakeClock.numPendingTasks()); - fakeClock.forwardTime( - (long) (INITIAL_BACKOFF_IN_SECONDS * BACKOFF_MULTIPLIER * FAKE_RANDOM) - 1L, - TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(1) - 1L, TimeUnit.SECONDS); assertEquals(1, fakeClock.numPendingTasks()); fakeClock.forwardTime(1L, TimeUnit.SECONDS); assertEquals(0, fakeClock.numPendingTasks()); @@ -1362,10 +1371,7 @@ public void expBackoff_maxBackoff_maxRetryAttempts() { sublistenerCaptor3.getValue().closed( Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); assertEquals(1, fakeClock.numPendingTasks()); - fakeClock.forwardTime( - (long) (INITIAL_BACKOFF_IN_SECONDS * BACKOFF_MULTIPLIER * BACKOFF_MULTIPLIER * FAKE_RANDOM) - - 1L, - TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(2) - 1L, TimeUnit.SECONDS); assertEquals(1, fakeClock.numPendingTasks()); fakeClock.forwardTime(1L, TimeUnit.SECONDS); assertEquals(0, fakeClock.numPendingTasks()); @@ -1380,7 +1386,7 @@ public void expBackoff_maxBackoff_maxRetryAttempts() { sublistenerCaptor4.getValue().closed( Status.fromCode(RETRIABLE_STATUS_CODE_2), PROCESSED, new Metadata()); assertEquals(1, fakeClock.numPendingTasks()); - fakeClock.forwardTime((long) (MAX_BACKOFF_IN_SECONDS * FAKE_RANDOM) - 1L, TimeUnit.SECONDS); + fakeClock.forwardTime(calculateMaxBackoff() - 1L, TimeUnit.SECONDS); assertEquals(1, fakeClock.numPendingTasks()); fakeClock.forwardTime(1L, TimeUnit.SECONDS); assertEquals(0, fakeClock.numPendingTasks()); @@ -1395,7 +1401,7 @@ public void expBackoff_maxBackoff_maxRetryAttempts() { sublistenerCaptor5.getValue().closed( Status.fromCode(RETRIABLE_STATUS_CODE_2), PROCESSED, new Metadata()); assertEquals(1, fakeClock.numPendingTasks()); - fakeClock.forwardTime((long) (MAX_BACKOFF_IN_SECONDS * FAKE_RANDOM) - 1L, TimeUnit.SECONDS); + fakeClock.forwardTime(calculateMaxBackoff() - 1L, TimeUnit.SECONDS); assertEquals(1, fakeClock.numPendingTasks()); fakeClock.forwardTime(1L, TimeUnit.SECONDS); assertEquals(0, fakeClock.numPendingTasks()); @@ -1478,7 +1484,7 @@ public void pushback() { sublistenerCaptor3.getValue().closed( Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); assertEquals(1, fakeClock.numPendingTasks()); - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM) - 1L, TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0) - 1L, TimeUnit.SECONDS); assertEquals(1, fakeClock.numPendingTasks()); fakeClock.forwardTime(1L, TimeUnit.SECONDS); assertEquals(0, fakeClock.numPendingTasks()); @@ -1493,9 +1499,7 @@ public void pushback() { sublistenerCaptor4.getValue().closed( Status.fromCode(RETRIABLE_STATUS_CODE_2), PROCESSED, new Metadata()); assertEquals(1, fakeClock.numPendingTasks()); - fakeClock.forwardTime( - (long) (INITIAL_BACKOFF_IN_SECONDS * BACKOFF_MULTIPLIER * FAKE_RANDOM) - 1L, - TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(1) - 1L, TimeUnit.SECONDS); assertEquals(1, fakeClock.numPendingTasks()); fakeClock.forwardTime(1L, TimeUnit.SECONDS); assertEquals(0, fakeClock.numPendingTasks()); @@ -1510,10 +1514,7 @@ public void pushback() { sublistenerCaptor5.getValue().closed( Status.fromCode(RETRIABLE_STATUS_CODE_2), PROCESSED, new Metadata()); assertEquals(1, fakeClock.numPendingTasks()); - fakeClock.forwardTime( - (long) (INITIAL_BACKOFF_IN_SECONDS * BACKOFF_MULTIPLIER * BACKOFF_MULTIPLIER * FAKE_RANDOM) - - 1L, - TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(2) - 1L, TimeUnit.SECONDS); assertEquals(1, fakeClock.numPendingTasks()); fakeClock.forwardTime(1L, TimeUnit.SECONDS); assertEquals(0, fakeClock.numPendingTasks()); @@ -1802,7 +1803,7 @@ public void transparentRetry_onlyOnceOnRefused() { .closed(Status.fromCode(RETRIABLE_STATUS_CODE_1), REFUSED, new Metadata()); assertEquals(1, fakeClock.numPendingTasks()); - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0), TimeUnit.SECONDS); inOrder.verify(retriableStreamRecorder).newSubstream(1); ArgumentCaptor sublistenerCaptor3 = ArgumentCaptor.forClass(ClientStreamListener.class); @@ -1905,7 +1906,7 @@ public void normalRetry_thenNoTransparentRetry_butNormalRetry() { .closed(Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); assertEquals(1, fakeClock.numPendingTasks()); - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0), TimeUnit.SECONDS); inOrder.verify(retriableStreamRecorder).newSubstream(1); ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); @@ -1921,8 +1922,7 @@ public void normalRetry_thenNoTransparentRetry_butNormalRetry() { .closed(Status.fromCode(RETRIABLE_STATUS_CODE_1), REFUSED, new Metadata()); assertEquals(1, fakeClock.numPendingTasks()); - fakeClock.forwardTime( - (long) (INITIAL_BACKOFF_IN_SECONDS * BACKOFF_MULTIPLIER * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(1), TimeUnit.SECONDS); inOrder.verify(retriableStreamRecorder).newSubstream(2); ArgumentCaptor sublistenerCaptor3 = ArgumentCaptor.forClass(ClientStreamListener.class); @@ -1958,7 +1958,7 @@ public void normalRetry_thenNoTransparentRetry_andNoMoreRetry() { .closed(Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); assertEquals(1, fakeClock.numPendingTasks()); - fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + fakeClock.forwardTime(calculateBackoffWithRetries(0), TimeUnit.SECONDS); inOrder.verify(retriableStreamRecorder).newSubstream(1); ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); @@ -2590,9 +2590,7 @@ public void run() { .closed(Status.fromCode(NON_FATAL_STATUS_CODE_1), REFUSED, new Metadata()); } finally { transport2Lock.unlock(); - if (transport1Lock.tryLock()) { - transport1Lock.unlock(); - } + transport1Lock.unlock(); } } }, "Thread-transport2"); diff --git a/core/src/test/java/io/grpc/internal/RetryingNameResolverTest.java b/core/src/test/java/io/grpc/internal/RetryingNameResolverTest.java index 8801f540394..1da93f05fe2 100644 --- a/core/src/test/java/io/grpc/internal/RetryingNameResolverTest.java +++ b/core/src/test/java/io/grpc/internal/RetryingNameResolverTest.java @@ -17,17 +17,16 @@ package io.grpc.internal; import static com.google.common.truth.Truth.assertThat; -import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import io.grpc.NameResolver; import io.grpc.NameResolver.Listener2; import io.grpc.NameResolver.ResolutionResult; import io.grpc.Status; import io.grpc.SynchronizationContext; -import io.grpc.internal.RetryingNameResolver.ResolutionResultListener; import java.lang.Thread.UncaughtExceptionHandler; import org.junit.Before; import org.junit.Rule; @@ -57,8 +56,6 @@ public class RetryingNameResolverTest { private RetryScheduler mockRetryScheduler; @Captor private ArgumentCaptor listenerCaptor; - @Captor - private ArgumentCaptor onResultCaptor; private final SynchronizationContext syncContext = new SynchronizationContext( mock(UncaughtExceptionHandler.class)); @@ -76,58 +73,52 @@ public void startAndShutdown() { retryingNameResolver.shutdown(); } - // Make sure the ResolutionResultListener callback is added to the ResolutionResult attributes, - // and the retry scheduler is reset since the name resolution was successful. @Test - public void onResult_sucess() { + public void onResult_success() { + when(mockListener.onResult2(isA(ResolutionResult.class))).thenReturn(Status.OK); retryingNameResolver.start(mockListener); verify(mockNameResolver).start(listenerCaptor.capture()); listenerCaptor.getValue().onResult(ResolutionResult.newBuilder().build()); - verify(mockListener).onResult(onResultCaptor.capture()); - ResolutionResultListener resolutionResultListener = onResultCaptor.getValue() - .getAttributes() - .get(RetryingNameResolver.RESOLUTION_RESULT_LISTENER_KEY); - assertThat(resolutionResultListener).isNotNull(); - resolutionResultListener.resolutionAttempted(Status.OK); verify(mockRetryScheduler).reset(); } - // Make sure the ResolutionResultListener callback is added to the ResolutionResult attributes, - // and that a retry gets scheduled when the resolution results are rejected. + @Test + public void onResult2_sucesss() { + when(mockListener.onResult2(isA(ResolutionResult.class))).thenReturn(Status.OK); + retryingNameResolver.start(mockListener); + verify(mockNameResolver).start(listenerCaptor.capture()); + + assertThat(listenerCaptor.getValue().onResult2(ResolutionResult.newBuilder().build())) + .isEqualTo(Status.OK); + + verify(mockRetryScheduler).reset(); + } + + // Make sure that a retry gets scheduled when the resolution results are rejected. @Test public void onResult_failure() { + when(mockListener.onResult2(isA(ResolutionResult.class))).thenReturn(Status.UNAVAILABLE); retryingNameResolver.start(mockListener); verify(mockNameResolver).start(listenerCaptor.capture()); listenerCaptor.getValue().onResult(ResolutionResult.newBuilder().build()); - verify(mockListener).onResult(onResultCaptor.capture()); - ResolutionResultListener resolutionResultListener = onResultCaptor.getValue() - .getAttributes() - .get(RetryingNameResolver.RESOLUTION_RESULT_LISTENER_KEY); - assertThat(resolutionResultListener).isNotNull(); - resolutionResultListener.resolutionAttempted(Status.UNAVAILABLE); verify(mockRetryScheduler).schedule(isA(Runnable.class)); } - // Wrapping a NameResolver more than once is a misconfiguration. + // Make sure that a retry gets scheduled when the resolution results are rejected. @Test - public void onResult_failure_doubleWrapped() { - NameResolver doubleWrappedResolver = new RetryingNameResolver(retryingNameResolver, - mockRetryScheduler, syncContext); - - doubleWrappedResolver.start(mockListener); + public void onResult2_failure() { + when(mockListener.onResult2(isA(ResolutionResult.class))).thenReturn(Status.UNAVAILABLE); + retryingNameResolver.start(mockListener); verify(mockNameResolver).start(listenerCaptor.capture()); - try { - listenerCaptor.getValue().onResult(ResolutionResult.newBuilder().build()); - } catch (IllegalStateException e) { - assertThat(e).hasMessageThat().contains("can only be used once"); - return; - } - fail("An exception should have been thrown for a double wrapped NAmeResolver"); + assertThat(listenerCaptor.getValue().onResult2(ResolutionResult.newBuilder().build())) + .isEqualTo(Status.UNAVAILABLE); + + verify(mockRetryScheduler).schedule(isA(Runnable.class)); } // A retry should get scheduled when name resolution fails. @@ -139,4 +130,4 @@ public void onError() { verify(mockListener).onError(Status.DEADLINE_EXCEEDED); verify(mockRetryScheduler).schedule(isA(Runnable.class)); } -} \ No newline at end of file +} diff --git a/core/src/test/java/io/grpc/internal/ServerCallImplTest.java b/core/src/test/java/io/grpc/internal/ServerCallImplTest.java index 833f5109e34..7394c83eab2 100644 --- a/core/src/test/java/io/grpc/internal/ServerCallImplTest.java +++ b/core/src/test/java/io/grpc/internal/ServerCallImplTest.java @@ -16,12 +16,14 @@ package io.grpc.internal; -import static com.google.common.base.Charsets.UTF_8; +import static com.google.common.truth.Truth.assertThat; import static io.grpc.internal.GrpcUtil.CONTENT_LENGTH_KEY; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.any; @@ -54,7 +56,6 @@ import org.junit.Before; import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; @@ -64,8 +65,6 @@ @RunWith(JUnit4.class) public class ServerCallImplTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); @Rule public final MockitoRule mocks = MockitoJUnit.rule(); @Mock private ServerStream stream; @@ -147,6 +146,12 @@ public void request() { verify(stream).request(10); } + @Test + public void setOnReadyThreshold() { + call.setOnReadyThreshold(10); + verify(stream).setOnReadyThreshold(10); + } + @Test public void sendHeader_firstCall() { Metadata headers = new Metadata(); @@ -169,20 +174,20 @@ public void sendHeader_contentLengthDiscarded() { @Test public void sendHeader_failsOnSecondCall() { call.sendHeaders(new Metadata()); - thrown.expect(IllegalStateException.class); - thrown.expectMessage("sendHeaders has already been called"); - - call.sendHeaders(new Metadata()); + Metadata headers = new Metadata(); + IllegalStateException e = assertThrows(IllegalStateException.class, + () -> call.sendHeaders(headers)); + assertThat(e).hasMessageThat().isEqualTo("sendHeaders has already been called"); } @Test public void sendHeader_failsOnClosed() { call.close(Status.CANCELLED, new Metadata()); - thrown.expect(IllegalStateException.class); - thrown.expectMessage("call is closed"); - - call.sendHeaders(new Metadata()); + Metadata headers = new Metadata(); + IllegalStateException e = assertThrows(IllegalStateException.class, + () -> call.sendHeaders(headers)); + assertThat(e).hasMessageThat().isEqualTo("call is closed"); } @Test @@ -198,18 +203,16 @@ public void sendMessage_failsOnClosed() { call.sendHeaders(new Metadata()); call.close(Status.CANCELLED, new Metadata()); - thrown.expect(IllegalStateException.class); - thrown.expectMessage("call is closed"); - - call.sendMessage(1234L); + IllegalStateException e = assertThrows(IllegalStateException.class, + () -> call.sendMessage(1234L)); + assertThat(e).hasMessageThat().isEqualTo("call is closed"); } @Test public void sendMessage_failsIfheadersUnsent() { - thrown.expect(IllegalStateException.class); - thrown.expectMessage("sendHeaders has not been called"); - - call.sendMessage(1234L); + IllegalStateException e = assertThrows(IllegalStateException.class, + () -> call.sendMessage(1234L)); + assertThat(e).hasMessageThat().isEqualTo("sendHeaders has not been called"); } @Test @@ -484,9 +487,10 @@ public void streamListener_unexpectedRuntimeException() { InputStream inputStream = UNARY_METHOD.streamRequest(1234L); - thrown.expect(RuntimeException.class); - thrown.expectMessage("unexpected exception"); - streamListener.messagesAvailable(new SingleMessageProducer(inputStream)); + SingleMessageProducer producer = new SingleMessageProducer(inputStream); + RuntimeException e = assertThrows(RuntimeException.class, + () -> streamListener.messagesAvailable(producer)); + assertThat(e).hasMessageThat().isEqualTo("unexpected exception"); } private static class LongMarshaller implements Marshaller { diff --git a/core/src/test/java/io/grpc/internal/ServerImplBuilderTest.java b/core/src/test/java/io/grpc/internal/ServerImplBuilderTest.java index ce601c5f837..c2cb281a19e 100644 --- a/core/src/test/java/io/grpc/internal/ServerImplBuilderTest.java +++ b/core/src/test/java/io/grpc/internal/ServerImplBuilderTest.java @@ -18,10 +18,14 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; -import io.grpc.InternalGlobalInterceptors; +import io.grpc.InternalConfigurator; +import io.grpc.InternalConfiguratorRegistry; import io.grpc.Metadata; +import io.grpc.MetricRecorder; +import io.grpc.MetricSink; +import io.grpc.NoopMetricSink; +import io.grpc.ServerBuilder; import io.grpc.ServerCall; import io.grpc.ServerCallHandler; import io.grpc.ServerInterceptor; @@ -60,7 +64,8 @@ public ServerCall.Listener interceptCall( new StaticTestingClassLoader( getClass().getClassLoader(), Pattern.compile( - "io\\.grpc\\.InternalGlobalInterceptors|io\\.grpc\\.GlobalInterceptors|" + "io\\.grpc\\.InternalConfigurator|io\\.grpc\\.Configurator|" + + "io\\.grpc\\.InternalConfiguratorRegistry|io\\.grpc\\.ConfiguratorRegistry|" + "io\\.grpc\\.internal\\.[^.]+")); private ServerImplBuilder builder; @@ -71,7 +76,8 @@ public void setUp() throws Exception { new ClientTransportServersBuilder() { @Override public InternalServer buildClientTransportServers( - List streamTracerFactories) { + List streamTracerFactories, + MetricRecorder metricRecorder) { throw new UnsupportedOperationException(); } }); @@ -126,6 +132,13 @@ public void getTracerFactories_disableBoth() { assertThat(factories).containsExactly(DUMMY_USER_TRACER); } + @Test + public void addMetricSink_addsToSinks() { + MetricSink noopMetricSink = new NoopMetricSink(); + builder.addMetricSink(noopMetricSink); + assertThat(builder.metricSinks).containsExactly(noopMetricSink); + } + @Test public void getTracerFactories_callsGet() throws Exception { Class runnable = classLoader.loadClass(StaticTestingClassLoaderCallsGet.class.getName()); @@ -137,18 +150,14 @@ public static final class StaticTestingClassLoaderCallsGet implements Runnable { public void run() { ServerImplBuilder builder = new ServerImplBuilder( - streamTracerFactories -> { + (streamTracerFactories, metricRecorder) -> { throw new UnsupportedOperationException(); }); assertThat(builder.getTracerFactories()).hasSize(2); assertThat(builder.interceptors).hasSize(0); - try { - InternalGlobalInterceptors.setInterceptorsTracers( - Collections.emptyList(), Collections.emptyList(), Collections.emptyList()); - fail("exception expected"); - } catch (IllegalStateException e) { - assertThat(e).hasMessageThat().contains("Set cannot be called after any get call"); - } + InternalConfiguratorRegistry.setConfigurators(Collections.emptyList()); + assertThat(InternalConfiguratorRegistry.getConfigurators()).isEmpty(); + assertThat(InternalConfiguratorRegistry.getConfiguratorsCallCountBeforeSet()).isEqualTo(1); } } @@ -161,13 +170,17 @@ public void getTracerFactories_callsSet() throws Exception { public static final class StaticTestingClassLoaderCallsSet implements Runnable { @Override public void run() { - InternalGlobalInterceptors.setInterceptorsTracers( - Collections.emptyList(), - Arrays.asList(DUMMY_TEST_INTERCEPTOR), - Arrays.asList(DUMMY_USER_TRACER)); + InternalConfiguratorRegistry.setConfigurators( + Arrays.asList(new InternalConfigurator() { + @Override + public void configureServerBuilder(ServerBuilder builder) { + builder.intercept(DUMMY_TEST_INTERCEPTOR); + builder.addStreamTracerFactory(DUMMY_USER_TRACER); + } + })); ServerImplBuilder builder = new ServerImplBuilder( - streamTracerFactories -> { + (streamTracerFactories, metricRecorder) -> { throw new UnsupportedOperationException(); }); assertThat(builder.getTracerFactories()).containsExactly(DUMMY_USER_TRACER); @@ -187,11 +200,10 @@ public static final class StaticTestingClassLoaderCallsSetEmpty implements Runna @Override public void run() { - InternalGlobalInterceptors.setInterceptorsTracers( - Collections.emptyList(), Collections.emptyList(), Collections.emptyList()); + InternalConfiguratorRegistry.setConfigurators(Collections.emptyList()); ServerImplBuilder builder = new ServerImplBuilder( - streamTracerFactories -> { + (streamTracerFactories, metricRecorder) -> { throw new UnsupportedOperationException(); }); assertThat(builder.getTracerFactories()).isEmpty(); diff --git a/core/src/test/java/io/grpc/internal/ServerImplTest.java b/core/src/test/java/io/grpc/internal/ServerImplTest.java index dd93e296208..91969dd6910 100644 --- a/core/src/test/java/io/grpc/internal/ServerImplTest.java +++ b/core/src/test/java/io/grpc/internal/ServerImplTest.java @@ -26,6 +26,7 @@ import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.AdditionalAnswers.delegatesTo; @@ -52,6 +53,7 @@ import io.grpc.Channel; import io.grpc.Compressor; import io.grpc.Context; +import io.grpc.Deadline; import io.grpc.Grpc; import io.grpc.HandlerRegistry; import io.grpc.IntegerMarshaller; @@ -63,6 +65,7 @@ import io.grpc.InternalServerInterceptors; import io.grpc.Metadata; import io.grpc.MethodDescriptor; +import io.grpc.MetricRecorder; import io.grpc.ServerCall; import io.grpc.ServerCall.Listener; import io.grpc.ServerCallExecutorSupplier; @@ -104,7 +107,6 @@ import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; @@ -127,6 +129,10 @@ public class ServerImplTest { .setRequestMarshaller(STRING_MARSHALLER) .setResponseMarshaller(INTEGER_MARSHALLER) .build(); + private static final MethodDescriptor GENERATED_METHOD = + METHOD.toBuilder() + .setSampledToLocalTracing(true) + .build(); private static final Context.Key SERVER_ONLY = Context.key("serverOnly"); private static final Context.Key SERVER_TRACER_ADDED_KEY = Context.key("tracer-added"); private static final Context.CancellableContext SERVER_CONTEXT = @@ -140,8 +146,60 @@ public boolean shouldAccept(Runnable runnable) { }; private static final String AUTHORITY = "some_authority"; - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); + private static final class MethodNameCapturingTracer extends ServerStreamTracer + implements StatsTraceContext.ServerCallMethodListener { + @Nullable private ServerCallInfo serverCallInfo; + @Nullable private String recordedMethodName; + @Nullable private String resolvedMethodName; + private boolean streamClosed; + + @Override + public synchronized void serverCallMethodResolved(MethodDescriptor method) { + resolvedMethodName = + recordMethodName(method.isSampledToLocalTracing(), method.getFullMethodName()); + } + + @Override + public synchronized void streamClosed(Status status) { + streamClosed = true; + if (serverCallInfo != null) { + recordedMethodName = + recordMethodName( + serverCallInfo.getMethodDescriptor().isSampledToLocalTracing(), + serverCallInfo.getMethodDescriptor().getFullMethodName()); + } else if (resolvedMethodName != null) { + recordedMethodName = resolvedMethodName; + } else { + recordedMethodName = "other"; + } + } + + @Override + public synchronized void serverCallStarted(ServerCallInfo callInfo) { + serverCallInfo = callInfo; + if (streamClosed) { + recordedMethodName = + recordMethodName( + callInfo.getMethodDescriptor().isSampledToLocalTracing(), + callInfo.getMethodDescriptor().getFullMethodName()); + } + } + + @Nullable + synchronized ServerCallInfo getServerCallInfo() { + return serverCallInfo; + } + + @Nullable + synchronized String getRecordedMethodName() { + return recordedMethodName; + } + + private static String recordMethodName(boolean generatedMethod, String fullMethodName) { + return generatedMethod ? fullMethodName : "other"; + } + } + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); @BeforeClass @@ -207,7 +265,8 @@ public void startUp() throws IOException { new ClientTransportServersBuilder() { @Override public InternalServer buildClientTransportServers( - List streamTracerFactories) { + List streamTracerFactories, + MetricRecorder metricRecorder) { throw new UnsupportedOperationException(); } }); @@ -461,6 +520,172 @@ public void methodNotFound() throws Exception { assertEquals(Status.Code.UNIMPLEMENTED, statusCaptor.getValue().getCode()); } + @Test + public void primaryRegistryGeneratedMethod_streamClosedBeforeStart_preservesMethodName() + throws Exception { + MethodNameCapturingTracer methodNameTracer = new MethodNameCapturingTracer(); + streamTracerFactories = + Collections.singletonList( + new ServerStreamTracer.Factory() { + @Override + public ServerStreamTracer newServerStreamTracer( + String fullMethodName, Metadata headers) { + return methodNameTracer; + } + }); + builder.addService( + ServerServiceDefinition.builder(new ServiceDescriptor("Waiter", GENERATED_METHOD)) + .addMethod( + GENERATED_METHOD, + new ServerCallHandler() { + @Override + public ServerCall.Listener startCall( + ServerCall call, Metadata headers) { + return callListener; + } + }) + .build()); + + createAndStartServer(); + ServerTransportListener transportListener + = transportServer.registerNewServerTransport(new SimpleServerTransport()); + transportListener.transportReady(Attributes.EMPTY); + Metadata requestHeaders = new Metadata(); + StatsTraceContext statsTraceCtx = + StatsTraceContext.newServerContext( + streamTracerFactories, GENERATED_METHOD.getFullMethodName(), requestHeaders); + when(stream.getAttributes()).thenReturn(Attributes.EMPTY); + when(stream.statsTraceContext()).thenReturn(statsTraceCtx); + + transportListener.streamCreated(stream, GENERATED_METHOD.getFullMethodName(), requestHeaders); + verify(stream).setListener(isA(ServerStreamListener.class)); + verify(stream, atLeast(1)).statsTraceContext(); + + statsTraceCtx.streamClosed(Status.CANCELLED); + assertNull(methodNameTracer.getServerCallInfo()); + assertEquals( + GENERATED_METHOD.getFullMethodName(), + methodNameTracer.getRecordedMethodName()); + + assertEquals(1, executor.runDueTasks()); + + assertNotNull(methodNameTracer.getServerCallInfo()); + assertSame(GENERATED_METHOD, methodNameTracer.getServerCallInfo().getMethodDescriptor()); + assertEquals( + GENERATED_METHOD.getFullMethodName(), + methodNameTracer.getRecordedMethodName()); + verify(fallbackRegistry, never()).lookupMethod(anyString(), any()); + } + + @Test + public void primaryRegistryNonGeneratedMethod_streamClosedBeforeStart_recordsOther() + throws Exception { + MethodNameCapturingTracer methodNameTracer = new MethodNameCapturingTracer(); + streamTracerFactories = + Collections.singletonList( + new ServerStreamTracer.Factory() { + @Override + public ServerStreamTracer newServerStreamTracer( + String fullMethodName, Metadata headers) { + return methodNameTracer; + } + }); + builder.addService( + ServerServiceDefinition.builder(new ServiceDescriptor("Waiter", METHOD)) + .addMethod( + METHOD, + new ServerCallHandler() { + @Override + public ServerCall.Listener startCall( + ServerCall call, Metadata headers) { + return callListener; + } + }) + .build()); + + createAndStartServer(); + ServerTransportListener transportListener + = transportServer.registerNewServerTransport(new SimpleServerTransport()); + transportListener.transportReady(Attributes.EMPTY); + Metadata requestHeaders = new Metadata(); + StatsTraceContext statsTraceCtx = + StatsTraceContext.newServerContext( + streamTracerFactories, METHOD.getFullMethodName(), requestHeaders); + when(stream.getAttributes()).thenReturn(Attributes.EMPTY); + when(stream.statsTraceContext()).thenReturn(statsTraceCtx); + + transportListener.streamCreated(stream, METHOD.getFullMethodName(), requestHeaders); + verify(stream).setListener(isA(ServerStreamListener.class)); + verify(stream, atLeast(1)).statsTraceContext(); + + statsTraceCtx.streamClosed(Status.CANCELLED); + assertNull(methodNameTracer.getServerCallInfo()); + assertEquals("other", methodNameTracer.getRecordedMethodName()); + + assertEquals(1, executor.runDueTasks()); + + assertNotNull(methodNameTracer.getServerCallInfo()); + assertSame(METHOD, methodNameTracer.getServerCallInfo().getMethodDescriptor()); + assertEquals("other", methodNameTracer.getRecordedMethodName()); + verify(fallbackRegistry, never()).lookupMethod(anyString(), any()); + } + + @Test + public void fallbackRegistryGeneratedMethod_streamClosedBeforeStart_resolvesOnAsyncLookup() + throws Exception { + MethodNameCapturingTracer methodNameTracer = new MethodNameCapturingTracer(); + streamTracerFactories = + Collections.singletonList( + new ServerStreamTracer.Factory() { + @Override + public ServerStreamTracer newServerStreamTracer( + String fullMethodName, Metadata headers) { + return methodNameTracer; + } + }); + mutableFallbackRegistry.addService( + ServerServiceDefinition.builder(new ServiceDescriptor("Waiter", GENERATED_METHOD)) + .addMethod( + GENERATED_METHOD, + new ServerCallHandler() { + @Override + public ServerCall.Listener startCall( + ServerCall call, Metadata headers) { + return callListener; + } + }) + .build()); + + createAndStartServer(); + ServerTransportListener transportListener + = transportServer.registerNewServerTransport(new SimpleServerTransport()); + transportListener.transportReady(Attributes.EMPTY); + Metadata requestHeaders = new Metadata(); + StatsTraceContext statsTraceCtx = + StatsTraceContext.newServerContext( + streamTracerFactories, GENERATED_METHOD.getFullMethodName(), requestHeaders); + when(stream.getAttributes()).thenReturn(Attributes.EMPTY); + when(stream.statsTraceContext()).thenReturn(statsTraceCtx); + + transportListener.streamCreated(stream, GENERATED_METHOD.getFullMethodName(), requestHeaders); + verify(stream).setListener(isA(ServerStreamListener.class)); + verify(stream, atLeast(1)).statsTraceContext(); + + statsTraceCtx.streamClosed(Status.CANCELLED); + assertNull(methodNameTracer.getServerCallInfo()); + assertEquals("other", methodNameTracer.getRecordedMethodName()); + verify(fallbackRegistry, never()).lookupMethod(anyString(), any()); + + assertEquals(1, executor.runDueTasks()); + + assertNotNull(methodNameTracer.getServerCallInfo()); + assertSame(GENERATED_METHOD, methodNameTracer.getServerCallInfo().getMethodDescriptor()); + assertEquals( + GENERATED_METHOD.getFullMethodName(), + methodNameTracer.getRecordedMethodName()); + verify(fallbackRegistry).lookupMethod(GENERATED_METHOD.getFullMethodName(), AUTHORITY); + } + @Test public void executorSupplierSameExecutorBasic() throws Exception { @@ -932,7 +1157,7 @@ public void shutdown() { } catch (Exception ex) { throw new AssertionError(ex); } - // If deadlock is possible with this setup, this sychronization completes the loop because + // If deadlock is possible with this setup, this synchronization completes the loop because // the serverShutdown needs a lock that Server is holding while calling this method. synchronized (lock) { } @@ -972,7 +1197,7 @@ public void shutdown() { } catch (Exception ex) { throw new AssertionError(ex); } - // If deadlock is possible with this setup, this sychronization completes the loop + // If deadlock is possible with this setup, this synchronization completes the loop // because the transportTerminated needs a lock that Server is holding while calling this // method. synchronized (lock) { @@ -1148,11 +1373,21 @@ public ServerCall.Listener startCall( @Test public void testContextExpiredBeforeStreamCreate_StreamCancelNotCalledBeforeSetListener() throws Exception { + builder.ticker = new Deadline.Ticker() { + private long time; + + @Override + public long nanoTime() { + time += 1000; + return time; + } + }; + AtomicBoolean contextCancelled = new AtomicBoolean(false); AtomicReference context = new AtomicReference<>(); AtomicReference> callReference = new AtomicReference<>(); - testStreamClose_setup(callReference, context, contextCancelled, 0L); + testStreamClose_setup(callReference, context, contextCancelled, 1L); // This assert that stream.setListener(jumpListener) is called before stream.cancel(), which // prevents extremely short deadlines causing NPEs. @@ -1228,7 +1463,7 @@ public void testStreamClose_deadlineExceededTriggersImmediateCancellation() thro assertFalse(context.get().isCancelled()); assertEquals(1, timer.forwardNanos(1)); - + assertTrue(callReference.get().isCancelled()); assertTrue(context.get().isCancelled()); assertThat(context.get().cancellationCause()).isNotNull(); @@ -1260,9 +1495,8 @@ public List getListenSocketAddresses() { public void getPortBeforeStartedFails() { transportServer = new SimpleServer(); createServer(); - thrown.expect(IllegalStateException.class); - thrown.expectMessage("started"); - server.getPort(); + IllegalStateException e = assertThrows(IllegalStateException.class, () -> server.getPort()); + assertThat(e).hasMessageThat().isEqualTo("Not started"); } @Test @@ -1271,9 +1505,8 @@ public void getPortAfterTerminationFails() throws Exception { createAndStartServer(); server.shutdown(); server.awaitTermination(); - thrown.expect(IllegalStateException.class); - thrown.expectMessage("terminated"); - server.getPort(); + IllegalStateException e = assertThrows(IllegalStateException.class, () -> server.getPort()); + assertThat(e).hasMessageThat().isEqualTo("Already terminated"); } @Test @@ -1298,7 +1531,7 @@ public void handlerRegistryPriorities() throws Exception { assertEquals(1, executor.runDueTasks()); verify(callHandler).startCall(ArgumentMatchers.>any(), ArgumentMatchers.any()); - // This call will be handled by the fallbackRegistry because it's not registred in the internal + // This call will be handled by the fallbackRegistry because it's not registered in the internal // registry. transportListener.streamCreated(stream, "Service1/Method2", requestHeaders); assertEquals(1, executor.runDueTasks()); diff --git a/core/src/test/java/io/grpc/internal/ServiceConfigErrorHandlingTest.java b/core/src/test/java/io/grpc/internal/ServiceConfigErrorHandlingTest.java index 0d050a09a9a..0daee676b82 100644 --- a/core/src/test/java/io/grpc/internal/ServiceConfigErrorHandlingTest.java +++ b/core/src/test/java/io/grpc/internal/ServiceConfigErrorHandlingTest.java @@ -19,6 +19,7 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertWithMessage; +import static io.grpc.internal.UriWrapper.wrap; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.mockito.AdditionalAnswers.delegatesTo; @@ -45,6 +46,7 @@ import io.grpc.LoadBalancerRegistry; import io.grpc.NameResolver; import io.grpc.NameResolver.ConfigOrError; +import io.grpc.NameResolverProvider; import io.grpc.Status; import io.grpc.internal.ManagedChannelImplBuilder.FixedPortProvider; import io.grpc.internal.ManagedChannelImplBuilder.UnsupportedClientTransportFactoryBuilder; @@ -161,10 +163,14 @@ private void createChannel(ClientInterceptor... interceptors) { when(mockTransportFactory.getSupportedSocketAddressTypes()).thenReturn(Collections.singleton( InetSocketAddress.class)); + NameResolverProvider nameResolverProvider = + channelBuilder.nameResolverRegistry.getProviderForScheme(expectedUri.getScheme()); channel = new ManagedChannelImpl( channelBuilder, mockTransportFactory, + wrap(expectedUri), + nameResolverProvider, new FakeBackoffPolicyProvider(), balancerRpcExecutorPool, timer.getStopwatchSupplier(), @@ -277,7 +283,7 @@ public void emptyAddresses_validConfig_2ndResolution_lbNeedsAddress() throws Exc assertThat(resolvedAddresses.getLoadBalancingPolicyConfig()).isEqualTo("12"); verify(mockLoadBalancer, never()).handleNameResolutionError(any(Status.class)); - assertThat(channel.getState(true)).isEqualTo(ConnectivityState.IDLE); + assertThat(channel.getState(true)).isEqualTo(ConnectivityState.CONNECTING); reset(mockLoadBalancer); nameResolverFactory.servers.clear(); @@ -480,7 +486,7 @@ public void invalidConfig_2ndResolution() throws Exception { assertThat(newResolvedAddress.getLoadBalancingPolicyConfig()).isEqualTo("1st raw config"); assertThat(channel.getConfigSelector()).isSameInstanceAs(configSelector); verify(mockLoadBalancer, never()).handleNameResolutionError(any(Status.class)); - assertThat(channel.getState(false)).isEqualTo(ConnectivityState.IDLE); + assertThat(channel.getState(false)).isEqualTo(ConnectivityState.CONNECTING); } @Test diff --git a/core/src/test/java/io/grpc/internal/SharedResourceHolderTest.java b/core/src/test/java/io/grpc/internal/SharedResourceHolderTest.java index 531632ca79c..692b22a0a68 100644 --- a/core/src/test/java/io/grpc/internal/SharedResourceHolderTest.java +++ b/core/src/test/java/io/grpc/internal/SharedResourceHolderTest.java @@ -30,7 +30,9 @@ import io.grpc.internal.SharedResourceHolder.Resource; import java.util.LinkedList; +import java.util.concurrent.CyclicBarrier; import java.util.concurrent.Delayed; +import java.util.concurrent.FutureTask; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; @@ -98,7 +100,7 @@ public void close(ResourceInstance instance) { assertEquals(SharedResourceHolder.DESTROY_DELAY_SECONDS, scheduledDestroyTask.getDelay(TimeUnit.SECONDS)); - // Simluate that the destroyer executes the foo destroying task + // Simulate that the destroyer executes the foo destroying task scheduledDestroyTask.runTask(); assertTrue(sharedFoo.closed); @@ -201,6 +203,46 @@ public void close(ResourceInstance instance) { assertNotSame(instance, holder.getInternal(resource)); } + @Test(timeout = 5000) + public void closeRunsConcurrently() throws Exception { + CyclicBarrier barrier = new CyclicBarrier(2); + class SlowResource implements Resource { + @Override + public ResourceInstance create() { + return new ResourceInstance(); + } + + @Override + public void close(ResourceInstance instance) { + instance.closed = true; + try { + barrier.await(); + barrier.await(); + } catch (Exception ex) { + throw new AssertionError(ex); + } + } + } + + Resource resource = new SlowResource(); + ResourceInstance instance = holder.getInternal(resource); + holder.releaseInternal(resource, instance); + MockScheduledFuture scheduledDestroyTask = scheduledDestroyTasks.poll(); + FutureTask runTask = new FutureTask<>(scheduledDestroyTask::runTask, null); + Thread t = new Thread(runTask); + t.start(); + + barrier.await(); // Ensure the other thread has blocked + assertTrue(instance.closed); + instance = holder.getInternal(resource); + assertFalse(instance.closed); + holder.releaseInternal(resource, instance); + + barrier.await(); // Resume the other thread + t.join(); + runTask.get(); // Check for exception + } + private class MockExecutorFactory implements SharedResourceHolder.ScheduledExecutorFactory { @Override diff --git a/core/src/test/java/io/grpc/internal/SpiffeUtilTest.java b/core/src/test/java/io/grpc/internal/SpiffeUtilTest.java new file mode 100644 index 00000000000..57824cf207f --- /dev/null +++ b/core/src/test/java/io/grpc/internal/SpiffeUtilTest.java @@ -0,0 +1,388 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import com.google.common.base.Optional; +import com.google.common.io.ByteStreams; +import io.grpc.internal.SpiffeUtil.SpiffeBundle; +import io.grpc.internal.SpiffeUtil.SpiffeId; +import io.grpc.testing.TlsTesting; +import io.grpc.util.CertificateUtils; +import java.io.File; +import java.io.FileNotFoundException; +import java.io.FileOutputStream; +import java.io.InputStream; +import java.io.OutputStream; +import java.security.cert.X509Certificate; +import java.util.Arrays; +import java.util.Collection; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.runners.Enclosed; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + + +@RunWith(Enclosed.class) +public class SpiffeUtilTest { + + @RunWith(Parameterized.class) + public static class ParseSuccessTest { + @Parameter + public String uri; + + @Parameter(1) + public String trustDomain; + + @Parameter(2) + public String path; + + @Test + public void parseSuccessTest() { + SpiffeUtil.SpiffeId spiffeId = SpiffeUtil.parse(uri); + assertEquals(trustDomain, spiffeId.getTrustDomain()); + assertEquals(path, spiffeId.getPath()); + } + + @Parameters(name = "spiffeId={0}") + public static Collection data() { + return Arrays.asList(new String[][] { + {"spiffe://example.com", "example.com", ""}, + {"spiffe://example.com/us", "example.com", "/us"}, + {"spIFfe://qa-staging.final_check.example.com/us", "qa-staging.final_check.example.com", + "/us"}, + {"spiffe://example.com/country/us/state/FL/city/Miami", "example.com", + "/country/us/state/FL/city/Miami"}, + {"SPIFFE://example.com/Czech.Republic/region0.1/city_of-Prague", "example.com", + "/Czech.Republic/region0.1/city_of-Prague"}, + {"spiffe://trust-domain-name/path", "trust-domain-name", "/path"}, + {"spiffe://staging.example.com/payments/mysql", "staging.example.com", "/payments/mysql"}, + {"spiffe://staging.example.com/payments/web-fe", "staging.example.com", + "/payments/web-fe"}, + {"spiffe://k8s-west.example.com/ns/staging/sa/default", "k8s-west.example.com", + "/ns/staging/sa/default"}, + {"spiffe://example.com/9eebccd2-12bf-40a6-b262-65fe0487d453", "example.com", + "/9eebccd2-12bf-40a6-b262-65fe0487d453"}, + {"spiffe://trustdomain/.a..", "trustdomain", "/.a.."}, + {"spiffe://trustdomain/...", "trustdomain", "/..."}, + {"spiffe://trustdomain/abcdefghijklmnopqrstuvwxyz", "trustdomain", + "/abcdefghijklmnopqrstuvwxyz"}, + {"spiffe://trustdomain/abc0123.-_", "trustdomain", "/abc0123.-_"}, + {"spiffe://trustdomain/0123456789", "trustdomain", "/0123456789"}, + {"spiffe://trustdomain0123456789/path", "trustdomain0123456789", "/path"}, + }); + } + } + + @RunWith(Parameterized.class) + public static class ParseFailureTest { + @Parameter + public String uri; + + @Test + public void parseFailureTest() { + assertThrows(IllegalArgumentException.class, () -> SpiffeUtil.parse(uri)); + } + + @Parameters(name = "spiffeId={0}") + public static Collection data() { + return Arrays.asList( + "spiffe:///", + "spiffe://example!com", + "spiffe://exampleя.com/workload-1", + "spiffe://example.com/us/florida/miamiя", + "spiffe:/trustdomain/path", + "spiffe:///path", + "spiffe://trust%20domain/path", + "spiffe://user@trustdomain/path", + "spiffe:// /", + "", + "http://trustdomain/path", + "//trustdomain/path", + "://trustdomain/path", + "piffe://trustdomain/path", + "://", + "://trustdomain", + "spiff", + "spiffe", + "spiffe:////", + "spiffe://trust.domain/../path" + ); + } + } + + public static class ExceptionMessageTest { + + @Test + public void spiffeUriFormatTest() { + NullPointerException npe = assertThrows(NullPointerException.class, () -> + SpiffeUtil.parse(null)); + assertEquals("uri", npe.getMessage()); + + IllegalArgumentException iae = assertThrows(IllegalArgumentException.class, () -> + SpiffeUtil.parse("https://example.com")); + assertEquals("Spiffe Id must start with spiffe://", iae.getMessage()); + + iae = assertThrows(IllegalArgumentException.class, () -> + SpiffeUtil.parse("spiffe://example.com/workload#1")); + assertEquals("Spiffe Id must not contain query fragments", iae.getMessage()); + + iae = assertThrows(IllegalArgumentException.class, () -> + SpiffeUtil.parse("spiffe://example.com/workload-1?t=1")); + assertEquals("Spiffe Id must not contain query parameters", iae.getMessage()); + } + + @Test + public void spiffeTrustDomainFormatTest() { + IllegalArgumentException iae = assertThrows(IllegalArgumentException.class, () -> + SpiffeUtil.parse("spiffe://")); + assertEquals("Trust Domain can't be empty", iae.getMessage()); + + iae = assertThrows(IllegalArgumentException.class, () -> + SpiffeUtil.parse("spiffe://eXample.com")); + assertEquals( + "Trust Domain must contain only letters, numbers, dots, dashes, and underscores " + + "([a-z0-9.-_])", + iae.getMessage()); + + StringBuilder longTrustDomain = new StringBuilder("spiffe://pi.eu."); + for (int i = 0; i < 50; i++) { + longTrustDomain.append("pi.eu"); + } + iae = assertThrows(IllegalArgumentException.class, () -> + SpiffeUtil.parse(longTrustDomain.toString())); + assertEquals("Trust Domain maximum length is 255 characters", iae.getMessage()); + + @SuppressWarnings("OrphanedFormatString") + StringBuilder longSpiffe = new StringBuilder("spiffe://mydomain%21com/"); + for (int i = 0; i < 405; i++) { + longSpiffe.append("qwert"); + } + iae = assertThrows(IllegalArgumentException.class, () -> + SpiffeUtil.parse(longSpiffe.toString())); + assertEquals("Spiffe Id maximum length is 2048 characters", iae.getMessage()); + } + + @Test + public void spiffePathFormatTest() { + IllegalArgumentException iae = assertThrows(IllegalArgumentException.class, () -> + SpiffeUtil.parse("spiffe://example.com//")); + assertEquals("Path must not include a trailing '/'", iae.getMessage()); + + iae = assertThrows(IllegalArgumentException.class, () -> + SpiffeUtil.parse("spiffe://example.com/")); + assertEquals("Path must not include a trailing '/'", iae.getMessage()); + + iae = assertThrows(IllegalArgumentException.class, () -> + SpiffeUtil.parse("spiffe://example.com/us//miami")); + assertEquals("Individual path segments must not be empty", iae.getMessage()); + + iae = assertThrows(IllegalArgumentException.class, () -> + SpiffeUtil.parse("spiffe://example.com/us/.")); + assertEquals("Individual path segments must not be relative path modifiers (i.e. ., ..)", + iae.getMessage()); + + iae = assertThrows(IllegalArgumentException.class, () -> + SpiffeUtil.parse("spiffe://example.com/us!")); + assertEquals("Individual path segments must contain only letters, numbers, dots, dashes, and " + + "underscores ([a-zA-Z0-9.-_])", iae.getMessage()); + } + } + + public static class CertificateApiTest { + private static final String SPIFFE_PEM_FILE = "spiffe_cert.pem"; + private static final String MULTI_URI_SAN_PEM_FILE = "spiffe_multi_uri_san_cert.pem"; + private static final String SERVER_0_PEM_FILE = "server0.pem"; + private static final String TEST_DIRECTORY_PREFIX = "io/grpc/internal/"; + private static final String SPIFFE_TRUST_BUNDLE = "spiffebundle.json"; + private static final String SPIFFE_TRUST_BUNDLE_WITH_EC_KTY = "spiffebundle_ec.json"; + private static final String SPIFFE_TRUST_BUNDLE_MALFORMED = "spiffebundle_malformed.json"; + private static final String SPIFFE_TRUST_BUNDLE_CORRUPTED_CERT = + "spiffebundle_corrupted_cert.json"; + private static final String SPIFFE_TRUST_BUNDLE_WRONG_KTY = "spiffebundle_wrong_kty.json"; + private static final String SPIFFE_TRUST_BUNDLE_WRONG_KID = "spiffebundle_wrong_kid.json"; + private static final String SPIFFE_TRUST_BUNDLE_WRONG_USE = "spiffebundle_wrong_use.json"; + private static final String SPIFFE_TRUST_BUNDLE_WRONG_MULTI_CERTS = + "spiffebundle_wrong_multi_certs.json"; + private static final String SPIFFE_TRUST_BUNDLE_DUPLICATES = "spiffebundle_duplicates.json"; + private static final String SPIFFE_TRUST_BUNDLE_WRONG_ROOT = "spiffebundle_wrong_root.json"; + private static final String SPIFFE_TRUST_BUNDLE_WRONG_SEQ = "spiffebundle_wrong_seq_type.json"; + private static final String DOMAIN_ERROR_MESSAGE = + " Certificate loading for trust domain 'google.com' failed."; + + + @Rule public TemporaryFolder tempFolder = new TemporaryFolder(); + + private X509Certificate[] spiffeCert; + private X509Certificate[] multipleUriSanCert; + private X509Certificate[] serverCert0; + + @Before + public void setUp() throws Exception { + spiffeCert = CertificateUtils.getX509Certificates(TlsTesting.loadCert(SPIFFE_PEM_FILE)); + multipleUriSanCert = CertificateUtils.getX509Certificates(TlsTesting + .loadCert(MULTI_URI_SAN_PEM_FILE)); + serverCert0 = CertificateUtils.getX509Certificates(TlsTesting.loadCert(SERVER_0_PEM_FILE)); + } + + private String copyFileToTmp(String fileName) throws Exception { + File tempFile = tempFolder.newFile(fileName); + try (InputStream resourceStream = SpiffeUtilTest.class.getClassLoader() + .getResourceAsStream(TEST_DIRECTORY_PREFIX + fileName); + OutputStream fileStream = new FileOutputStream(tempFile)) { + ByteStreams.copy(resourceStream, fileStream); + fileStream.flush(); + } + return tempFile.toString(); + } + + @Test + public void extractSpiffeIdSuccessTest() throws Exception { + Optional spiffeId = SpiffeUtil.extractSpiffeId(spiffeCert); + assertTrue(spiffeId.isPresent()); + assertEquals("foo.bar.com", spiffeId.get().getTrustDomain()); + assertEquals("/client/workload/1", spiffeId.get().getPath()); + } + + @Test + public void extractSpiffeIdFailureTest() throws Exception { + Optional spiffeId = SpiffeUtil.extractSpiffeId(serverCert0); + assertFalse(spiffeId.isPresent()); + IllegalArgumentException iae = assertThrows(IllegalArgumentException.class, () -> SpiffeUtil + .extractSpiffeId(multipleUriSanCert)); + assertEquals("Multiple URI SAN values found in the leaf cert.", iae.getMessage()); + + } + + @Test + public void extractSpiffeIdFromChainTest() throws Exception { + // Check that the SPIFFE ID is extracted only from the leaf cert in the chain (spiffeCert + // contains it, but serverCert0 does not). + X509Certificate[] leafWithSpiffeChain = new X509Certificate[]{spiffeCert[0], serverCert0[0]}; + assertTrue(SpiffeUtil.extractSpiffeId(leafWithSpiffeChain).isPresent()); + X509Certificate[] leafWithoutSpiffeChain = + new X509Certificate[]{serverCert0[0], spiffeCert[0]}; + assertFalse(SpiffeUtil.extractSpiffeId(leafWithoutSpiffeChain).isPresent()); + } + + @Test + public void extractSpiffeIdParameterValidityTest() { + NullPointerException npe = assertThrows(NullPointerException.class, () -> SpiffeUtil + .extractSpiffeId(null)); + assertEquals("certChain", npe.getMessage()); + IllegalArgumentException iae = assertThrows(IllegalArgumentException.class, () -> SpiffeUtil + .extractSpiffeId(new X509Certificate[]{})); + assertEquals("certChain can't be empty", iae.getMessage()); + } + + @Test + public void loadTrustBundleFromFileSuccessTest() throws Exception { + SpiffeBundle tb = SpiffeUtil.loadTrustBundleFromFile(copyFileToTmp(SPIFFE_TRUST_BUNDLE)); + assertEquals(2, tb.getSequenceNumbers().size()); + assertEquals(12035488L, (long) tb.getSequenceNumbers().get("example.com")); + assertEquals(-1L, (long) tb.getSequenceNumbers().get("test.example.com")); + assertEquals(3, tb.getBundleMap().size()); + assertEquals(0, tb.getBundleMap().get("test.google.com.au").size()); + assertEquals(1, tb.getBundleMap().get("example.com").size()); + assertEquals(2, tb.getBundleMap().get("test.example.com").size()); + Optional spiffeId = SpiffeUtil.extractSpiffeId(tb.getBundleMap().get("example.com") + .toArray(new X509Certificate[0])); + assertTrue(spiffeId.isPresent()); + assertEquals("foo.bar.com", spiffeId.get().getTrustDomain()); + + SpiffeBundle tb_ec = SpiffeUtil.loadTrustBundleFromFile( + copyFileToTmp(SPIFFE_TRUST_BUNDLE_WITH_EC_KTY)); + assertEquals(2, tb_ec.getSequenceNumbers().size()); + assertEquals(12035488L, (long) tb_ec.getSequenceNumbers().get("example.com")); + assertEquals(-1L, (long) tb_ec.getSequenceNumbers().get("test.example.com")); + assertEquals(3, tb_ec.getBundleMap().size()); + assertEquals(0, tb_ec.getBundleMap().get("test.google.com.au").size()); + assertEquals(1, tb_ec.getBundleMap().get("example.com").size()); + assertEquals(2, tb_ec.getBundleMap().get("test.example.com").size()); + Optional spiffeId_ec = + SpiffeUtil.extractSpiffeId(tb_ec.getBundleMap().get("example.com") + .toArray(new X509Certificate[0])); + assertTrue(spiffeId_ec.isPresent()); + assertEquals("foo.bar.com", spiffeId_ec.get().getTrustDomain()); + } + + @Test + public void loadTrustBundleFromFileFailureTest() { + // Check the exception if JSON root element is different from 'trust_domains' + NullPointerException npe = assertThrows(NullPointerException.class, () -> SpiffeUtil + .loadTrustBundleFromFile(copyFileToTmp(SPIFFE_TRUST_BUNDLE_WRONG_ROOT))); + assertEquals("Mandatory trust_domains element is missing", npe.getMessage()); + // Check the exception if JSON root element is different from 'trust_domains' + ClassCastException cce = assertThrows(ClassCastException.class, () -> SpiffeUtil + .loadTrustBundleFromFile(copyFileToTmp(SPIFFE_TRUST_BUNDLE_WRONG_SEQ))); + assertTrue(cce.getMessage().contains("Number expected to be long")); + // Check the exception if JSON file doesn't contain an object + IllegalArgumentException iae = assertThrows(IllegalArgumentException.class, () -> SpiffeUtil + .loadTrustBundleFromFile(copyFileToTmp(SPIFFE_TRUST_BUNDLE_MALFORMED))); + assertTrue(iae.getMessage().contains("SPIFFE Trust Bundle should be a JSON object.")); + // Check the exception if JSON contains duplicates + iae = assertThrows(IllegalArgumentException.class, () -> SpiffeUtil + .loadTrustBundleFromFile(copyFileToTmp(SPIFFE_TRUST_BUNDLE_DUPLICATES))); + assertEquals("Duplicate key found: google.com", iae.getMessage()); + // Check the exception if 'x5c' value cannot be parsed + iae = assertThrows(IllegalArgumentException.class, () -> SpiffeUtil + .loadTrustBundleFromFile(copyFileToTmp(SPIFFE_TRUST_BUNDLE_CORRUPTED_CERT))); + assertEquals("Certificate can't be parsed." + DOMAIN_ERROR_MESSAGE, iae.getMessage()); + // Check the exception if 'kty' value differs from 'RSA' + iae = assertThrows(IllegalArgumentException.class, () -> SpiffeUtil + .loadTrustBundleFromFile(copyFileToTmp(SPIFFE_TRUST_BUNDLE_WRONG_KTY))); + assertEquals( + "'kty' parameter must be one of [RSA, EC] but 'null' found." + DOMAIN_ERROR_MESSAGE, + iae.getMessage()); + // Check the exception if 'kid' has a value + iae = assertThrows(IllegalArgumentException.class, () -> SpiffeUtil + .loadTrustBundleFromFile(copyFileToTmp(SPIFFE_TRUST_BUNDLE_WRONG_KID))); + assertEquals("'kid' parameter must not be set." + DOMAIN_ERROR_MESSAGE, iae.getMessage()); + // Check the exception if 'use' value differs from 'x509-svid' + iae = assertThrows(IllegalArgumentException.class, () -> SpiffeUtil + .loadTrustBundleFromFile(copyFileToTmp(SPIFFE_TRUST_BUNDLE_WRONG_USE))); + assertEquals("'use' parameter must be 'x509-svid' but 'i_am_not_x509-svid' found." + + DOMAIN_ERROR_MESSAGE, iae.getMessage()); + // Check the exception if multiple certs are provided for 'x5c' + iae = assertThrows(IllegalArgumentException.class, () -> SpiffeUtil + .loadTrustBundleFromFile(copyFileToTmp(SPIFFE_TRUST_BUNDLE_WRONG_MULTI_CERTS))); + assertEquals("Exactly 1 certificate is expected, but 2 found." + DOMAIN_ERROR_MESSAGE, + iae.getMessage()); + } + + @Test + public void loadTrustBundleFromFileParameterValidityTest() { + NullPointerException npe = assertThrows(NullPointerException.class, () -> SpiffeUtil + .loadTrustBundleFromFile(null)); + assertEquals("trustBundleFile", npe.getMessage()); + FileNotFoundException nsfe = assertThrows(FileNotFoundException.class, () -> SpiffeUtil + .loadTrustBundleFromFile("i_do_not_exist")); + assertTrue( + "Did not contain expected substring: " + nsfe.getMessage(), + nsfe.getMessage().contains("i_do_not_exist")); + } + } +} diff --git a/core/src/test/java/io/grpc/internal/TransportFrameUtilTest.java b/core/src/test/java/io/grpc/internal/TransportFrameUtilTest.java index 5fa789df4f3..8b4bc170d52 100644 --- a/core/src/test/java/io/grpc/internal/TransportFrameUtilTest.java +++ b/core/src/test/java/io/grpc/internal/TransportFrameUtilTest.java @@ -16,10 +16,10 @@ package io.grpc.internal; -import static com.google.common.base.Charsets.US_ASCII; -import static com.google.common.base.Charsets.UTF_8; import static io.grpc.Metadata.ASCII_STRING_MARSHALLER; import static io.grpc.Metadata.BINARY_BYTE_MARSHALLER; +import static java.nio.charset.StandardCharsets.US_ASCII; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; diff --git a/core/src/test/resources/io/grpc/internal/spiffebundle.json b/core/src/test/resources/io/grpc/internal/spiffebundle.json new file mode 100644 index 00000000000..f968f730d94 --- /dev/null +++ b/core/src/test/resources/io/grpc/internal/spiffebundle.json @@ -0,0 +1,115 @@ +{ + "trust_domains": { + "test.google.com.au": {}, + "example.com": { + "spiffe_sequence": 12035488, + "keys": [ + { + "kty": "RSA", + "use": "x509-svid", + "x5c": ["MIIFsjCCA5qgAwIBAgIURygVMMzdr+Q7rsUaz189JozyHMwwDQYJKoZIhvcNAQEL + BQAwTjELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkNBMQwwCgYDVQQHDANTVkwxDTAL + BgNVBAoMBGdSUEMxFTATBgNVBAMMDHRlc3QtY2xpZW50MTAeFw0yMTEyMjMxODQy + NTJaFw0zMTEyMjExODQyNTJaME4xCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJDQTEM + MAoGA1UEBwwDU1ZMMQ0wCwYDVQQKDARnUlBDMRUwEwYDVQQDDAx0ZXN0LWNsaWVu + dDEwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQDJ4AqpGetyVSqGUuBJ + LVFla+7bEfca7UYzfVSSZLZ/X+JDmWIVN8UIPuFib5jhMEc3XaUnFXUmM7zEtz/Z + G5hapwLwOb2C3ZxOP6PQjYCJxbkLie+b43UQrFu1xxd3vMhVJgcj/AIxEpmszuqO + a6kUrkYifjJADQ+64kZgl66bsTdXMCzpxyFl9xUfff59L8OX+HUfAcoZz3emjg3Z + JPYURQEmjdZTOau1EjFilwHgd989Jt7NKgx30NXoHmw7nusVBIY94fL2VKN3f1XV + m0dHu5NI279Q6zr0ZBU7k5T3IeHnzsUesQS4NGlklDWoVTKk73Uv9Pna8yQsSW75 + 7PEbHOGp9Knu4bnoGPOlsG81yIPipO6hTgGFK24pF97M9kpGbWqYX4+2vLlrCAfc + msHqaUPmQlYeRVTT6vw7ctYo2kyUYGtnODXk76LqewRBVvkzx75QUhfjAyb740Yc + DmIenc56Tq6gebJHjhEmVSehR6xIpXP7SVeurTyhPsEQnpJHtgs4dcwWOZp7BvPN + zHXmJqfr7vsshie3vS5kQ0u1e1yqAqXgyDjqKXOkx+dpgUTehSJHhPNHvTc5LXRs + vvXKYz6FrwR/DZ8t7BNEvPeLjFgxpH7QVJFLCvCbXs5K6yYbsnLfxFIBPRnrbJkI + sK+sQwnRdnsiUdPsTkG5B2lQfQIDAQABo4GHMIGEMB0GA1UdDgQWBBQ2lBp0PiRH + HvQ5IRURm8aHsj4RETAfBgNVHSMEGDAWgBQ2lBp0PiRHHvQ5IRURm8aHsj4RETAP + BgNVHRMBAf8EBTADAQH/MDEGA1UdEQQqMCiGJnNwaWZmZTovL2Zvby5iYXIuY29t + L2NsaWVudC93b3JrbG9hZC8xMA0GCSqGSIb3DQEBCwUAA4ICAQA1mSkgRclAl+E/ + aS9zJ7t8+Y4n3T24nOKKveSIjxXm/zjhWqVsLYBI6kglWtih2+PELvU8JdPqNZK3 + 4Kl0Q6FWpVSGDdWN1i6NyORt2ocggL3ke3iXxRk3UpUKJmqwz81VhA2KUHnMlyE0 + IufFfZNwNWWHBv13uJfRbjeQpKPhU+yf4DeXrsWcvrZlGvAET+mcplafUzCp7Iv+ + PcISJtUerbxbVtuHVeZCLlgDXWkLAWJN8rf0dIG4x060LJ+j6j9uRVhb9sZn1HJV + +j4XdIYm1VKilluhOtNwP2d3Ox/JuTBxf7hFHXZPfMagQE5k5PzmxRaCAEMJ1l2D + vUbZw+shJfSNoWcBo2qadnUaWT3BmmJRBDh7ZReib/RQ1Rd4ygOyzP3E0vkV4/gq + yjLdApXh5PZP8KLQZ+1JN/sdWt7VfIt9wYOpkIqujdll51ESHzwQeAK9WVCB4UvV + z6zdhItB9CRbXPreWC+wCB1xDovIzFKOVsLs5+Gqs1m7VinG2LxbDqaKyo/FB0Hx + x0acBNzezLWoDwXYQrN0T0S4pnqhKD1CYPpdArBkNezUYAjS725FkApuK+mnBX3U + 0msBffEaUEOkcyar1EW2m/33vpetD/k3eQQkmvQf4Hbiu9AF+9cNDm/hMuXEw5EX + GA91fn0891b5eEW8BJHXX0jri0aN8g=="], + "n": "", + "e": "AQAB" + } + ] + }, + "test.example.com": { + "keys": [ + { + "kty": "RSA", + "use": "x509-svid", + "x5c": ["MIIFsjCCA5qgAwIBAgIURygVMMzdr+Q7rsUaz189JozyHMwwDQYJKoZIhvcNAQEL + BQAwTjELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkNBMQwwCgYDVQQHDANTVkwxDTAL + BgNVBAoMBGdSUEMxFTATBgNVBAMMDHRlc3QtY2xpZW50MTAeFw0yMTEyMjMxODQy + NTJaFw0zMTEyMjExODQyNTJaME4xCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJDQTEM + MAoGA1UEBwwDU1ZMMQ0wCwYDVQQKDARnUlBDMRUwEwYDVQQDDAx0ZXN0LWNsaWVu + dDEwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQDJ4AqpGetyVSqGUuBJ + LVFla+7bEfca7UYzfVSSZLZ/X+JDmWIVN8UIPuFib5jhMEc3XaUnFXUmM7zEtz/Z + G5hapwLwOb2C3ZxOP6PQjYCJxbkLie+b43UQrFu1xxd3vMhVJgcj/AIxEpmszuqO + a6kUrkYifjJADQ+64kZgl66bsTdXMCzpxyFl9xUfff59L8OX+HUfAcoZz3emjg3Z + JPYURQEmjdZTOau1EjFilwHgd989Jt7NKgx30NXoHmw7nusVBIY94fL2VKN3f1XV + m0dHu5NI279Q6zr0ZBU7k5T3IeHnzsUesQS4NGlklDWoVTKk73Uv9Pna8yQsSW75 + 7PEbHOGp9Knu4bnoGPOlsG81yIPipO6hTgGFK24pF97M9kpGbWqYX4+2vLlrCAfc + msHqaUPmQlYeRVTT6vw7ctYo2kyUYGtnODXk76LqewRBVvkzx75QUhfjAyb740Yc + DmIenc56Tq6gebJHjhEmVSehR6xIpXP7SVeurTyhPsEQnpJHtgs4dcwWOZp7BvPN + zHXmJqfr7vsshie3vS5kQ0u1e1yqAqXgyDjqKXOkx+dpgUTehSJHhPNHvTc5LXRs + vvXKYz6FrwR/DZ8t7BNEvPeLjFgxpH7QVJFLCvCbXs5K6yYbsnLfxFIBPRnrbJkI + sK+sQwnRdnsiUdPsTkG5B2lQfQIDAQABo4GHMIGEMB0GA1UdDgQWBBQ2lBp0PiRH + HvQ5IRURm8aHsj4RETAfBgNVHSMEGDAWgBQ2lBp0PiRHHvQ5IRURm8aHsj4RETAP + BgNVHRMBAf8EBTADAQH/MDEGA1UdEQQqMCiGJnNwaWZmZTovL2Zvby5iYXIuY29t + L2NsaWVudC93b3JrbG9hZC8xMA0GCSqGSIb3DQEBCwUAA4ICAQA1mSkgRclAl+E/ + aS9zJ7t8+Y4n3T24nOKKveSIjxXm/zjhWqVsLYBI6kglWtih2+PELvU8JdPqNZK3 + 4Kl0Q6FWpVSGDdWN1i6NyORt2ocggL3ke3iXxRk3UpUKJmqwz81VhA2KUHnMlyE0 + IufFfZNwNWWHBv13uJfRbjeQpKPhU+yf4DeXrsWcvrZlGvAET+mcplafUzCp7Iv+ + PcISJtUerbxbVtuHVeZCLlgDXWkLAWJN8rf0dIG4x060LJ+j6j9uRVhb9sZn1HJV + +j4XdIYm1VKilluhOtNwP2d3Ox/JuTBxf7hFHXZPfMagQE5k5PzmxRaCAEMJ1l2D + vUbZw+shJfSNoWcBo2qadnUaWT3BmmJRBDh7ZReib/RQ1Rd4ygOyzP3E0vkV4/gq + yjLdApXh5PZP8KLQZ+1JN/sdWt7VfIt9wYOpkIqujdll51ESHzwQeAK9WVCB4UvV + z6zdhItB9CRbXPreWC+wCB1xDovIzFKOVsLs5+Gqs1m7VinG2LxbDqaKyo/FB0Hx + x0acBNzezLWoDwXYQrN0T0S4pnqhKD1CYPpdArBkNezUYAjS725FkApuK+mnBX3U + 0msBffEaUEOkcyar1EW2m/33vpetD/k3eQQkmvQf4Hbiu9AF+9cNDm/hMuXEw5EX + GA91fn0891b5eEW8BJHXX0jri0aN8g=="], + "n": "", + "e": "AQAB" + }, + { + "kty": "RSA", + "use": "x509-svid", + "x5c": ["MIIELTCCAxWgAwIBAgIUVXGlXjNENtOZbI12epjgIhMaShEwDQYJKoZIhvcNAQEL + BQAwVjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM + GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEPMA0GA1UEAwwGdGVzdGNhMB4XDTI0 + MDkxNzE2MTk0NFoXDTM0MDkxNTE2MTk0NFowTjELMAkGA1UEBhMCVVMxCzAJBgNV + BAgMAkNBMQwwCgYDVQQHDANTVkwxDTALBgNVBAoMBGdSUEMxFTATBgNVBAMMDHRl + c3QtY2xpZW50MTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAOcTjjcS + SfG/EGrr6G+f+3T2GXyHHfroQFi9mZUz80L7uKBdECOImID+YhoK8vcxLQjPmEEv + FIYgJT5amugDcYIgUhMjBx/8RPJaP/nGmBngAqsuuNCaZfyaHBRqN8XdS/AwmsI5 + Wo+nru0+0/7aQFdqqtd2+e9dHjUWwgHxXvMgC4hkHpsdCGIZWVzWyBliwTYQYb1Y + yYe1LzqqQA5OMbZfKOY9MYDCEYOliRiunOn30iIOHj9V5qLzWGfSyxCRuvLRdEP8 + iDeNweHbdaKuI80nQmxuBdRIspE9k5sD1WA4vLZpeg3zggxp4rfLL5zBJgb/33D3 + d9Rkm14xfDPihhkCAwEAAaOB+jCB9zBZBgNVHREEUjBQhiZzcGlmZmU6Ly9mb28u + YmFyLmNvbS9jbGllbnQvd29ya2xvYWQvMYYmc3BpZmZlOi8vZm9vLmJhci5jb20v + Y2xpZW50L3dvcmtsb2FkLzIwHQYDVR0OBBYEFG9GkBgdBg/p0U9/lXv8zIJ+2c2N + MHsGA1UdIwR0MHKhWqRYMFYxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0 + YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQxDzANBgNVBAMM + BnRlc3RjYYIUWrP0VvHcy+LP6UuYNtiL9gBhD5owDQYJKoZIhvcNAQELBQADggEB + AJ4Cbxv+02SpUgkEu4hP/1+8DtSBXUxNxI0VG4e3Ap2+Rhjm3YiFeS/UeaZhNrrw + UEjkSTPFODyXR7wI7UO9OO1StyD6CMkp3SEvevU5JsZtGL6mTiTLTi3Qkywa91Bt + GlyZdVMghA1bBJLBMwiD5VT5noqoJBD7hDy6v9yNmt1Sw2iYBJPqI3Gnf5bMjR3s + UICaxmFyqaMCZsPkfJh0DmZpInGJys3m4QqGz6ZE2DWgcSr1r/ML7/5bSPjjr8j4 + WFFSqFR3dMu8CbGnfZTCTXa4GTX/rARXbAO67Z/oJbJBK7VKayskL+PzKuohb9ox + jGL772hQMbwtFCOFXu5VP0s="] + } + ] + } + } +} \ No newline at end of file diff --git a/core/src/test/resources/io/grpc/internal/spiffebundle_corrupted_cert.json b/core/src/test/resources/io/grpc/internal/spiffebundle_corrupted_cert.json new file mode 100644 index 00000000000..9ca51733ff3 --- /dev/null +++ b/core/src/test/resources/io/grpc/internal/spiffebundle_corrupted_cert.json @@ -0,0 +1,14 @@ +{ + "trust_domains": { + "google.com": { + "spiffe_sequence": 123, + "keys": [ + { + "kty": "RSA", + "use": "x509-svid", + "x5c": ["UNPARSABLE_CERTIFICATE"] + } + ] + } + } +} \ No newline at end of file diff --git a/core/src/test/resources/io/grpc/internal/spiffebundle_duplicates.json b/core/src/test/resources/io/grpc/internal/spiffebundle_duplicates.json new file mode 100644 index 00000000000..3f015bd1568 --- /dev/null +++ b/core/src/test/resources/io/grpc/internal/spiffebundle_duplicates.json @@ -0,0 +1,23 @@ +{ + "trust_domains": { + "google.com": { + "spiffe_sequence": 123, + "keys": [ + { + "x5c": "VALUE_DOESN'T_MATTER" + } + ] + }, + "google.com": { + "spiffe_sequence": 123, + "keys": [ + { + "use": "x509-svid", + "kid": "some_value", + "x5c": "VALUE_DOESN'T_MATTER" + } + ] + }, + "test.google.com.au": {} + } +} \ No newline at end of file diff --git a/core/src/test/resources/io/grpc/internal/spiffebundle_ec.json b/core/src/test/resources/io/grpc/internal/spiffebundle_ec.json new file mode 100644 index 00000000000..1732310f8cf --- /dev/null +++ b/core/src/test/resources/io/grpc/internal/spiffebundle_ec.json @@ -0,0 +1,116 @@ +{ + "trust_domains": { + "test.google.com.au": {}, + "example.com": { + "spiffe_sequence": 12035488, + "keys": [ + { + + "kty": "EC", + "use": "x509-svid", + "x5c": ["MIIFsjCCA5qgAwIBAgIURygVMMzdr+Q7rsUaz189JozyHMwwDQYJKoZIhvcNAQEL + BQAwTjELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkNBMQwwCgYDVQQHDANTVkwxDTAL + BgNVBAoMBGdSUEMxFTATBgNVBAMMDHRlc3QtY2xpZW50MTAeFw0yMTEyMjMxODQy + NTJaFw0zMTEyMjExODQyNTJaME4xCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJDQTEM + MAoGA1UEBwwDU1ZMMQ0wCwYDVQQKDARnUlBDMRUwEwYDVQQDDAx0ZXN0LWNsaWVu + dDEwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQDJ4AqpGetyVSqGUuBJ + LVFla+7bEfca7UYzfVSSZLZ/X+JDmWIVN8UIPuFib5jhMEc3XaUnFXUmM7zEtz/Z + G5hapwLwOb2C3ZxOP6PQjYCJxbkLie+b43UQrFu1xxd3vMhVJgcj/AIxEpmszuqO + a6kUrkYifjJADQ+64kZgl66bsTdXMCzpxyFl9xUfff59L8OX+HUfAcoZz3emjg3Z + JPYURQEmjdZTOau1EjFilwHgd989Jt7NKgx30NXoHmw7nusVBIY94fL2VKN3f1XV + m0dHu5NI279Q6zr0ZBU7k5T3IeHnzsUesQS4NGlklDWoVTKk73Uv9Pna8yQsSW75 + 7PEbHOGp9Knu4bnoGPOlsG81yIPipO6hTgGFK24pF97M9kpGbWqYX4+2vLlrCAfc + msHqaUPmQlYeRVTT6vw7ctYo2kyUYGtnODXk76LqewRBVvkzx75QUhfjAyb740Yc + DmIenc56Tq6gebJHjhEmVSehR6xIpXP7SVeurTyhPsEQnpJHtgs4dcwWOZp7BvPN + zHXmJqfr7vsshie3vS5kQ0u1e1yqAqXgyDjqKXOkx+dpgUTehSJHhPNHvTc5LXRs + vvXKYz6FrwR/DZ8t7BNEvPeLjFgxpH7QVJFLCvCbXs5K6yYbsnLfxFIBPRnrbJkI + sK+sQwnRdnsiUdPsTkG5B2lQfQIDAQABo4GHMIGEMB0GA1UdDgQWBBQ2lBp0PiRH + HvQ5IRURm8aHsj4RETAfBgNVHSMEGDAWgBQ2lBp0PiRHHvQ5IRURm8aHsj4RETAP + BgNVHRMBAf8EBTADAQH/MDEGA1UdEQQqMCiGJnNwaWZmZTovL2Zvby5iYXIuY29t + L2NsaWVudC93b3JrbG9hZC8xMA0GCSqGSIb3DQEBCwUAA4ICAQA1mSkgRclAl+E/ + aS9zJ7t8+Y4n3T24nOKKveSIjxXm/zjhWqVsLYBI6kglWtih2+PELvU8JdPqNZK3 + 4Kl0Q6FWpVSGDdWN1i6NyORt2ocggL3ke3iXxRk3UpUKJmqwz81VhA2KUHnMlyE0 + IufFfZNwNWWHBv13uJfRbjeQpKPhU+yf4DeXrsWcvrZlGvAET+mcplafUzCp7Iv+ + PcISJtUerbxbVtuHVeZCLlgDXWkLAWJN8rf0dIG4x060LJ+j6j9uRVhb9sZn1HJV + +j4XdIYm1VKilluhOtNwP2d3Ox/JuTBxf7hFHXZPfMagQE5k5PzmxRaCAEMJ1l2D + vUbZw+shJfSNoWcBo2qadnUaWT3BmmJRBDh7ZReib/RQ1Rd4ygOyzP3E0vkV4/gq + yjLdApXh5PZP8KLQZ+1JN/sdWt7VfIt9wYOpkIqujdll51ESHzwQeAK9WVCB4UvV + z6zdhItB9CRbXPreWC+wCB1xDovIzFKOVsLs5+Gqs1m7VinG2LxbDqaKyo/FB0Hx + x0acBNzezLWoDwXYQrN0T0S4pnqhKD1CYPpdArBkNezUYAjS725FkApuK+mnBX3U + 0msBffEaUEOkcyar1EW2m/33vpetD/k3eQQkmvQf4Hbiu9AF+9cNDm/hMuXEw5EX + GA91fn0891b5eEW8BJHXX0jri0aN8g=="], + "n": "", + "e": "AQAB" + } + ] + }, + "test.example.com": { + "keys": [ + { + "kty": "RSA", + "use": "x509-svid", + "x5c": ["MIIFsjCCA5qgAwIBAgIURygVMMzdr+Q7rsUaz189JozyHMwwDQYJKoZIhvcNAQEL + BQAwTjELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkNBMQwwCgYDVQQHDANTVkwxDTAL + BgNVBAoMBGdSUEMxFTATBgNVBAMMDHRlc3QtY2xpZW50MTAeFw0yMTEyMjMxODQy + NTJaFw0zMTEyMjExODQyNTJaME4xCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJDQTEM + MAoGA1UEBwwDU1ZMMQ0wCwYDVQQKDARnUlBDMRUwEwYDVQQDDAx0ZXN0LWNsaWVu + dDEwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQDJ4AqpGetyVSqGUuBJ + LVFla+7bEfca7UYzfVSSZLZ/X+JDmWIVN8UIPuFib5jhMEc3XaUnFXUmM7zEtz/Z + G5hapwLwOb2C3ZxOP6PQjYCJxbkLie+b43UQrFu1xxd3vMhVJgcj/AIxEpmszuqO + a6kUrkYifjJADQ+64kZgl66bsTdXMCzpxyFl9xUfff59L8OX+HUfAcoZz3emjg3Z + JPYURQEmjdZTOau1EjFilwHgd989Jt7NKgx30NXoHmw7nusVBIY94fL2VKN3f1XV + m0dHu5NI279Q6zr0ZBU7k5T3IeHnzsUesQS4NGlklDWoVTKk73Uv9Pna8yQsSW75 + 7PEbHOGp9Knu4bnoGPOlsG81yIPipO6hTgGFK24pF97M9kpGbWqYX4+2vLlrCAfc + msHqaUPmQlYeRVTT6vw7ctYo2kyUYGtnODXk76LqewRBVvkzx75QUhfjAyb740Yc + DmIenc56Tq6gebJHjhEmVSehR6xIpXP7SVeurTyhPsEQnpJHtgs4dcwWOZp7BvPN + zHXmJqfr7vsshie3vS5kQ0u1e1yqAqXgyDjqKXOkx+dpgUTehSJHhPNHvTc5LXRs + vvXKYz6FrwR/DZ8t7BNEvPeLjFgxpH7QVJFLCvCbXs5K6yYbsnLfxFIBPRnrbJkI + sK+sQwnRdnsiUdPsTkG5B2lQfQIDAQABo4GHMIGEMB0GA1UdDgQWBBQ2lBp0PiRH + HvQ5IRURm8aHsj4RETAfBgNVHSMEGDAWgBQ2lBp0PiRHHvQ5IRURm8aHsj4RETAP + BgNVHRMBAf8EBTADAQH/MDEGA1UdEQQqMCiGJnNwaWZmZTovL2Zvby5iYXIuY29t + L2NsaWVudC93b3JrbG9hZC8xMA0GCSqGSIb3DQEBCwUAA4ICAQA1mSkgRclAl+E/ + aS9zJ7t8+Y4n3T24nOKKveSIjxXm/zjhWqVsLYBI6kglWtih2+PELvU8JdPqNZK3 + 4Kl0Q6FWpVSGDdWN1i6NyORt2ocggL3ke3iXxRk3UpUKJmqwz81VhA2KUHnMlyE0 + IufFfZNwNWWHBv13uJfRbjeQpKPhU+yf4DeXrsWcvrZlGvAET+mcplafUzCp7Iv+ + PcISJtUerbxbVtuHVeZCLlgDXWkLAWJN8rf0dIG4x060LJ+j6j9uRVhb9sZn1HJV + +j4XdIYm1VKilluhOtNwP2d3Ox/JuTBxf7hFHXZPfMagQE5k5PzmxRaCAEMJ1l2D + vUbZw+shJfSNoWcBo2qadnUaWT3BmmJRBDh7ZReib/RQ1Rd4ygOyzP3E0vkV4/gq + yjLdApXh5PZP8KLQZ+1JN/sdWt7VfIt9wYOpkIqujdll51ESHzwQeAK9WVCB4UvV + z6zdhItB9CRbXPreWC+wCB1xDovIzFKOVsLs5+Gqs1m7VinG2LxbDqaKyo/FB0Hx + x0acBNzezLWoDwXYQrN0T0S4pnqhKD1CYPpdArBkNezUYAjS725FkApuK+mnBX3U + 0msBffEaUEOkcyar1EW2m/33vpetD/k3eQQkmvQf4Hbiu9AF+9cNDm/hMuXEw5EX + GA91fn0891b5eEW8BJHXX0jri0aN8g=="], + "n": "", + "e": "AQAB" + }, + { + "kty": "RSA", + "use": "x509-svid", + "x5c": ["MIIELTCCAxWgAwIBAgIUVXGlXjNENtOZbI12epjgIhMaShEwDQYJKoZIhvcNAQEL + BQAwVjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM + GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEPMA0GA1UEAwwGdGVzdGNhMB4XDTI0 + MDkxNzE2MTk0NFoXDTM0MDkxNTE2MTk0NFowTjELMAkGA1UEBhMCVVMxCzAJBgNV + BAgMAkNBMQwwCgYDVQQHDANTVkwxDTALBgNVBAoMBGdSUEMxFTATBgNVBAMMDHRl + c3QtY2xpZW50MTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAOcTjjcS + SfG/EGrr6G+f+3T2GXyHHfroQFi9mZUz80L7uKBdECOImID+YhoK8vcxLQjPmEEv + FIYgJT5amugDcYIgUhMjBx/8RPJaP/nGmBngAqsuuNCaZfyaHBRqN8XdS/AwmsI5 + Wo+nru0+0/7aQFdqqtd2+e9dHjUWwgHxXvMgC4hkHpsdCGIZWVzWyBliwTYQYb1Y + yYe1LzqqQA5OMbZfKOY9MYDCEYOliRiunOn30iIOHj9V5qLzWGfSyxCRuvLRdEP8 + iDeNweHbdaKuI80nQmxuBdRIspE9k5sD1WA4vLZpeg3zggxp4rfLL5zBJgb/33D3 + d9Rkm14xfDPihhkCAwEAAaOB+jCB9zBZBgNVHREEUjBQhiZzcGlmZmU6Ly9mb28u + YmFyLmNvbS9jbGllbnQvd29ya2xvYWQvMYYmc3BpZmZlOi8vZm9vLmJhci5jb20v + Y2xpZW50L3dvcmtsb2FkLzIwHQYDVR0OBBYEFG9GkBgdBg/p0U9/lXv8zIJ+2c2N + MHsGA1UdIwR0MHKhWqRYMFYxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0 + YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQxDzANBgNVBAMM + BnRlc3RjYYIUWrP0VvHcy+LP6UuYNtiL9gBhD5owDQYJKoZIhvcNAQELBQADggEB + AJ4Cbxv+02SpUgkEu4hP/1+8DtSBXUxNxI0VG4e3Ap2+Rhjm3YiFeS/UeaZhNrrw + UEjkSTPFODyXR7wI7UO9OO1StyD6CMkp3SEvevU5JsZtGL6mTiTLTi3Qkywa91Bt + GlyZdVMghA1bBJLBMwiD5VT5noqoJBD7hDy6v9yNmt1Sw2iYBJPqI3Gnf5bMjR3s + UICaxmFyqaMCZsPkfJh0DmZpInGJys3m4QqGz6ZE2DWgcSr1r/ML7/5bSPjjr8j4 + WFFSqFR3dMu8CbGnfZTCTXa4GTX/rARXbAO67Z/oJbJBK7VKayskL+PzKuohb9ox + jGL772hQMbwtFCOFXu5VP0s="] + } + ] + } + } +} \ No newline at end of file diff --git a/core/src/test/resources/io/grpc/internal/spiffebundle_malformed.json b/core/src/test/resources/io/grpc/internal/spiffebundle_malformed.json new file mode 100644 index 00000000000..a2488eeb3cd --- /dev/null +++ b/core/src/test/resources/io/grpc/internal/spiffebundle_malformed.json @@ -0,0 +1,4 @@ +[ + "test.google.com", + "test.google.com.au" +] \ No newline at end of file diff --git a/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_kid.json b/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_kid.json new file mode 100644 index 00000000000..f93af634a54 --- /dev/null +++ b/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_kid.json @@ -0,0 +1,15 @@ +{ + "trust_domains": { + "google.com": { + "spiffe_sequence": 123, + "keys": [ + { + "kty": "RSA", + "use": "x509-svid", + "kid": "some_value", + "x5c": "VALUE_DOESN'T_MATTER" + } + ] + } + } +} \ No newline at end of file diff --git a/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_kty.json b/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_kty.json new file mode 100644 index 00000000000..384da03fd6f --- /dev/null +++ b/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_kty.json @@ -0,0 +1,12 @@ +{ + "trust_domains": { + "google.com": { + "spiffe_sequence": 123, + "keys": [ + { + "x5c": "VALUE_DOESN'T_MATTER" + } + ] + } + } +} \ No newline at end of file diff --git a/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_multi_certs.json b/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_multi_certs.json new file mode 100644 index 00000000000..5e85635bb02 --- /dev/null +++ b/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_multi_certs.json @@ -0,0 +1,67 @@ +{ + "trust_domains": { + "google.com": { + "spiffe_sequence": 123, + "keys": [ + { + "kty": "RSA", + "use": "x509-svid", + "x5c": ["MIIFsjCCA5qgAwIBAgIURygVMMzdr+Q7rsUaz189JozyHMwwDQYJKoZIhvcNAQEL + BQAwTjELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkNBMQwwCgYDVQQHDANTVkwxDTAL + BgNVBAoMBGdSUEMxFTATBgNVBAMMDHRlc3QtY2xpZW50MTAeFw0yMTEyMjMxODQy + NTJaFw0zMTEyMjExODQyNTJaME4xCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJDQTEM + MAoGA1UEBwwDU1ZMMQ0wCwYDVQQKDARnUlBDMRUwEwYDVQQDDAx0ZXN0LWNsaWVu + dDEwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQDJ4AqpGetyVSqGUuBJ + LVFla+7bEfca7UYzfVSSZLZ/X+JDmWIVN8UIPuFib5jhMEc3XaUnFXUmM7zEtz/Z + G5hapwLwOb2C3ZxOP6PQjYCJxbkLie+b43UQrFu1xxd3vMhVJgcj/AIxEpmszuqO + a6kUrkYifjJADQ+64kZgl66bsTdXMCzpxyFl9xUfff59L8OX+HUfAcoZz3emjg3Z + JPYURQEmjdZTOau1EjFilwHgd989Jt7NKgx30NXoHmw7nusVBIY94fL2VKN3f1XV + m0dHu5NI279Q6zr0ZBU7k5T3IeHnzsUesQS4NGlklDWoVTKk73Uv9Pna8yQsSW75 + 7PEbHOGp9Knu4bnoGPOlsG81yIPipO6hTgGFK24pF97M9kpGbWqYX4+2vLlrCAfc + msHqaUPmQlYeRVTT6vw7ctYo2kyUYGtnODXk76LqewRBVvkzx75QUhfjAyb740Yc + DmIenc56Tq6gebJHjhEmVSehR6xIpXP7SVeurTyhPsEQnpJHtgs4dcwWOZp7BvPN + zHXmJqfr7vsshie3vS5kQ0u1e1yqAqXgyDjqKXOkx+dpgUTehSJHhPNHvTc5LXRs + vvXKYz6FrwR/DZ8t7BNEvPeLjFgxpH7QVJFLCvCbXs5K6yYbsnLfxFIBPRnrbJkI + sK+sQwnRdnsiUdPsTkG5B2lQfQIDAQABo4GHMIGEMB0GA1UdDgQWBBQ2lBp0PiRH + HvQ5IRURm8aHsj4RETAfBgNVHSMEGDAWgBQ2lBp0PiRHHvQ5IRURm8aHsj4RETAP + BgNVHRMBAf8EBTADAQH/MDEGA1UdEQQqMCiGJnNwaWZmZTovL2Zvby5iYXIuY29t + L2NsaWVudC93b3JrbG9hZC8xMA0GCSqGSIb3DQEBCwUAA4ICAQA1mSkgRclAl+E/ + aS9zJ7t8+Y4n3T24nOKKveSIjxXm/zjhWqVsLYBI6kglWtih2+PELvU8JdPqNZK3 + 4Kl0Q6FWpVSGDdWN1i6NyORt2ocggL3ke3iXxRk3UpUKJmqwz81VhA2KUHnMlyE0 + IufFfZNwNWWHBv13uJfRbjeQpKPhU+yf4DeXrsWcvrZlGvAET+mcplafUzCp7Iv+ + PcISJtUerbxbVtuHVeZCLlgDXWkLAWJN8rf0dIG4x060LJ+j6j9uRVhb9sZn1HJV + +j4XdIYm1VKilluhOtNwP2d3Ox/JuTBxf7hFHXZPfMagQE5k5PzmxRaCAEMJ1l2D + vUbZw+shJfSNoWcBo2qadnUaWT3BmmJRBDh7ZReib/RQ1Rd4ygOyzP3E0vkV4/gq + yjLdApXh5PZP8KLQZ+1JN/sdWt7VfIt9wYOpkIqujdll51ESHzwQeAK9WVCB4UvV + z6zdhItB9CRbXPreWC+wCB1xDovIzFKOVsLs5+Gqs1m7VinG2LxbDqaKyo/FB0Hx + x0acBNzezLWoDwXYQrN0T0S4pnqhKD1CYPpdArBkNezUYAjS725FkApuK+mnBX3U + 0msBffEaUEOkcyar1EW2m/33vpetD/k3eQQkmvQf4Hbiu9AF+9cNDm/hMuXEw5EX + GA91fn0891b5eEW8BJHXX0jri0aN8g==", + "MIIELTCCAxWgAwIBAgIUVXGlXjNENtOZbI12epjgIhMaShEwDQYJKoZIhvcNAQEL + BQAwVjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM + GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEPMA0GA1UEAwwGdGVzdGNhMB4XDTI0 + MDkxNzE2MTk0NFoXDTM0MDkxNTE2MTk0NFowTjELMAkGA1UEBhMCVVMxCzAJBgNV + BAgMAkNBMQwwCgYDVQQHDANTVkwxDTALBgNVBAoMBGdSUEMxFTATBgNVBAMMDHRl + c3QtY2xpZW50MTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAOcTjjcS + SfG/EGrr6G+f+3T2GXyHHfroQFi9mZUz80L7uKBdECOImID+YhoK8vcxLQjPmEEv + FIYgJT5amugDcYIgUhMjBx/8RPJaP/nGmBngAqsuuNCaZfyaHBRqN8XdS/AwmsI5 + Wo+nru0+0/7aQFdqqtd2+e9dHjUWwgHxXvMgC4hkHpsdCGIZWVzWyBliwTYQYb1Y + yYe1LzqqQA5OMbZfKOY9MYDCEYOliRiunOn30iIOHj9V5qLzWGfSyxCRuvLRdEP8 + iDeNweHbdaKuI80nQmxuBdRIspE9k5sD1WA4vLZpeg3zggxp4rfLL5zBJgb/33D3 + d9Rkm14xfDPihhkCAwEAAaOB+jCB9zBZBgNVHREEUjBQhiZzcGlmZmU6Ly9mb28u + YmFyLmNvbS9jbGllbnQvd29ya2xvYWQvMYYmc3BpZmZlOi8vZm9vLmJhci5jb20v + Y2xpZW50L3dvcmtsb2FkLzIwHQYDVR0OBBYEFG9GkBgdBg/p0U9/lXv8zIJ+2c2N + MHsGA1UdIwR0MHKhWqRYMFYxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0 + YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQxDzANBgNVBAMM + BnRlc3RjYYIUWrP0VvHcy+LP6UuYNtiL9gBhD5owDQYJKoZIhvcNAQELBQADggEB + AJ4Cbxv+02SpUgkEu4hP/1+8DtSBXUxNxI0VG4e3Ap2+Rhjm3YiFeS/UeaZhNrrw + UEjkSTPFODyXR7wI7UO9OO1StyD6CMkp3SEvevU5JsZtGL6mTiTLTi3Qkywa91Bt + GlyZdVMghA1bBJLBMwiD5VT5noqoJBD7hDy6v9yNmt1Sw2iYBJPqI3Gnf5bMjR3s + UICaxmFyqaMCZsPkfJh0DmZpInGJys3m4QqGz6ZE2DWgcSr1r/ML7/5bSPjjr8j4 + WFFSqFR3dMu8CbGnfZTCTXa4GTX/rARXbAO67Z/oJbJBK7VKayskL+PzKuohb9ox + jGL772hQMbwtFCOFXu5VP0s="] + } + ] + } + } +} \ No newline at end of file diff --git a/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_root.json b/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_root.json new file mode 100644 index 00000000000..90d2847dc05 --- /dev/null +++ b/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_root.json @@ -0,0 +1,6 @@ +{ + "trustDomains": { + "test.google.com": {}, + "test.google.com.au": {} + } +} \ No newline at end of file diff --git a/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_seq_type.json b/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_seq_type.json new file mode 100644 index 00000000000..4e0aeacc89f --- /dev/null +++ b/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_seq_type.json @@ -0,0 +1,12 @@ +{ + "trust_domains": { + "google.com": { + "spiffe_sequence": 123.5, + "keys": [ + { + "x5c": "VALUE_DOESN'T_MATTER" + } + ] + } + } +} \ No newline at end of file diff --git a/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_use.json b/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_use.json new file mode 100644 index 00000000000..166be04846c --- /dev/null +++ b/core/src/test/resources/io/grpc/internal/spiffebundle_wrong_use.json @@ -0,0 +1,13 @@ +{ + "trust_domains": { + "google.com": { + "keys": [ + { + "kty": "RSA", + "use": "i_am_not_x509-svid", + "x5c": "VALUE_DOESN'T_MATTER" + } + ] + } + } +} \ No newline at end of file diff --git a/core/src/testFixtures/java/io/grpc/internal/AbstractTransportTest.java b/core/src/testFixtures/java/io/grpc/internal/AbstractTransportTest.java index 57d870575d7..5d07de32df9 100644 --- a/core/src/testFixtures/java/io/grpc/internal/AbstractTransportTest.java +++ b/core/src/testFixtures/java/io/grpc/internal/AbstractTransportTest.java @@ -16,14 +16,15 @@ package io.grpc.internal; -import static com.google.common.base.Charsets.UTF_8; import static com.google.common.truth.Truth.assertThat; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.junit.Assume.assumeTrue; @@ -57,6 +58,7 @@ import io.grpc.MethodDescriptor; import io.grpc.ServerStreamTracer; import io.grpc.Status; +import io.grpc.internal.MockServerTransportListener.StreamCreation; import io.grpc.internal.testing.TestClientStreamTracer; import io.grpc.internal.testing.TestServerStreamTracer; import java.io.ByteArrayInputStream; @@ -68,17 +70,13 @@ import java.util.Arrays; import java.util.List; import java.util.concurrent.BlockingQueue; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; -import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import org.junit.After; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; @@ -94,7 +92,10 @@ public abstract class AbstractTransportTest { */ public static final int TEST_FLOW_CONTROL_WINDOW = 65 * 1024; - private static final int TIMEOUT_MS = 5000; + protected static final int TIMEOUT_MS = 5000; + + protected static final String GRPC_EXPERIMENTAL_SUPPORT_TRACING_MESSAGE_SIZES = + "GRPC_EXPERIMENTAL_SUPPORT_TRACING_MESSAGE_SIZES"; private static final Attributes.Key ADDITIONAL_TRANSPORT_ATTR_KEY = Attributes.Key.create("additional-attr"); @@ -136,13 +137,6 @@ protected abstract InternalServer newServer( */ protected abstract String testAuthority(InternalServer server); - /** - * Returns true (which is default) if the transport reports message sizes to StreamTracers. - */ - protected boolean sizesReported() { - return true; - } - protected final Attributes eagAttrs() { return EAG_ATTRS; } @@ -163,9 +157,9 @@ public void log(ChannelLogLevel level, String messageFormat, Object... args) {} * tests in an indeterminate state. */ protected InternalServer server; - private ServerTransport serverTransport; - private ManagedClientTransport client; - private MethodDescriptor methodDescriptor = + protected ServerTransport serverTransport; + protected ManagedClientTransport client; + protected MethodDescriptor methodDescriptor = MethodDescriptor.newBuilder() .setType(MethodDescriptor.MethodType.UNKNOWN) .setFullMethodName("service/method") @@ -182,22 +176,22 @@ public void log(ChannelLogLevel level, String messageFormat, Object... args) {} "tracer-key", Metadata.ASCII_STRING_MARSHALLER); private final String tracerKeyValue = "tracer-key-value"; - private ManagedClientTransport.Listener mockClientTransportListener + protected ManagedClientTransport.Listener mockClientTransportListener = mock(ManagedClientTransport.Listener.class); - private MockServerListener serverListener = new MockServerListener(); - private ArgumentCaptor throwableCaptor = ArgumentCaptor.forClass(Throwable.class); - private final TestClientStreamTracer clientStreamTracer1 = new TestHeaderClientStreamTracer(); + protected MockServerListener serverListener = new MockServerListener(); + private ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + protected final TestClientStreamTracer clientStreamTracer1 = new TestHeaderClientStreamTracer(); private final TestClientStreamTracer clientStreamTracer2 = new TestHeaderClientStreamTracer(); - private final ClientStreamTracer[] tracers = new ClientStreamTracer[] { + protected final ClientStreamTracer[] tracers = new ClientStreamTracer[] { clientStreamTracer1, clientStreamTracer2 }; - private final ClientStreamTracer[] noopTracers = new ClientStreamTracer[] { + protected final ClientStreamTracer[] noopTracers = new ClientStreamTracer[] { new ClientStreamTracer() {} }; - private final TestServerStreamTracer serverStreamTracer1 = new TestServerStreamTracer(); + protected final TestServerStreamTracer serverStreamTracer1 = new TestServerStreamTracer(); private final TestServerStreamTracer serverStreamTracer2 = new TestServerStreamTracer(); - private final ServerStreamTracer.Factory serverStreamTracerFactory = mock( + protected final ServerStreamTracer.Factory serverStreamTracerFactory = mock( ServerStreamTracer.Factory.class, delegatesTo(new ServerStreamTracer.Factory() { final ArrayDeque tracers = @@ -213,10 +207,6 @@ public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata } })); - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public ExpectedException thrown = ExpectedException.none(); - @Before public void setUp() { server = newServer(Arrays.asList(serverStreamTracerFactory)); @@ -245,6 +235,13 @@ protected void advanceClock(long offset, TimeUnit unit) { throw new UnsupportedOperationException(); } + /** + * Returns true if env var is set. + */ + protected static boolean isEnabledSupportTracingMessageSizes() { + return GrpcUtil.getFlag(GRPC_EXPERIMENTAL_SUPPORT_TRACING_MESSAGE_SIZES, false); + } + /** * Returns the current time, for tests that rely on the clock. */ @@ -266,7 +263,7 @@ protected long fakeCurrentTimeNanos() { // (and maybe exceptions handled) /** - * Test for issue https://github.com/grpc/grpc-java/issues/1682 + * Test for issue https://github.com/grpc/grpc-java/issues/1682 . */ @Test public void frameAfterRstStreamShouldNotBreakClientChannel() throws Exception { @@ -298,8 +295,8 @@ public void frameAfterRstStreamShouldNotBreakClientChannel() throws Exception { serverStreamCreation.stream.flush(); assertEquals( - Status.CANCELLED, clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + Status.CANCELLED, clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS)); ClientStreamListener mockClientStreamListener2 = mock(ClientStreamListener.class); @@ -329,7 +326,8 @@ public void serverNotListening() throws Exception { runIfNotNull(client.start(mockClientTransportListener)); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportTerminated(); ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); - inOrder.verify(mockClientTransportListener).transportShutdown(statusCaptor.capture()); + inOrder.verify(mockClientTransportListener).transportShutdown(statusCaptor.capture(), + any(DisconnectError.class)); assertCodeEquals(Status.UNAVAILABLE, statusCaptor.getValue()); inOrder.verify(mockClientTransportListener).transportTerminated(); verify(mockClientTransportListener, never()).transportReady(); @@ -345,7 +343,8 @@ public void clientStartStop() throws Exception { Status shutdownReason = Status.UNAVAILABLE.withDescription("shutdown called"); client.shutdown(shutdownReason); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportTerminated(); - inOrder.verify(mockClientTransportListener).transportShutdown(same(shutdownReason)); + inOrder.verify(mockClientTransportListener).transportShutdown(same(shutdownReason), + any(DisconnectError.class)); inOrder.verify(mockClientTransportListener).transportTerminated(); verify(mockClientTransportListener, never()).transportInUse(anyBoolean()); } @@ -361,7 +360,8 @@ public void clientStartAndStopOnceConnected() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); client.shutdown(Status.UNAVAILABLE); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportTerminated(); - inOrder.verify(mockClientTransportListener).transportShutdown(any(Status.class)); + inOrder.verify(mockClientTransportListener).transportShutdown(any(Status.class), + any(DisconnectError.class)); inOrder.verify(mockClientTransportListener).transportTerminated(); assertTrue(serverTransportListener.waitForTermination(TIMEOUT_MS, TimeUnit.MILLISECONDS)); server.shutdown(); @@ -393,8 +393,7 @@ public void serverAlreadyListening() throws Exception { port = ((InetSocketAddress) addr).getPort(); } InternalServer server2 = newServer(port, Arrays.asList(serverStreamTracerFactory)); - thrown.expect(IOException.class); - server2.start(new MockServerListener()); + assertThrows(IOException.class, () -> server2.start(new MockServerListener())); } @Test @@ -458,7 +457,8 @@ public void openStreamPreventsTermination() throws Exception { serverTransport.shutdown(); serverTransport = null; - verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportShutdown(any(Status.class)); + verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportShutdown(any(Status.class), + any(DisconnectError.class)); assertTrue(serverListener.waitForShutdown(TIMEOUT_MS, TimeUnit.MILLISECONDS)); // A new server should be able to start listening, since the current server has given up @@ -472,7 +472,7 @@ public void openStreamPreventsTermination() throws Exception { // the stream still functions. serverStream.writeHeaders(new Metadata(), true); clientStream.halfClose(); - assertNotNull(clientStreamListener.headers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.awaitHeaders(TIMEOUT_MS, TimeUnit.MILLISECONDS)); assertTrue(serverStreamListener.awaitHalfClosed(TIMEOUT_MS, TimeUnit.MILLISECONDS)); verify(mockClientTransportListener, never()).transportTerminated(); @@ -508,15 +508,16 @@ public void shutdownNowKillsClientStream() throws Exception { client.shutdownNow(status); client = null; - verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportShutdown(any(Status.class)); + verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportShutdown(any(Status.class), + any(DisconnectError.class)); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportTerminated(); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportInUse(false); assertTrue(serverTransportListener.waitForTermination(TIMEOUT_MS, TimeUnit.MILLISECONDS)); assertTrue(serverTransportListener.isTerminated()); - assertEquals(status, clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - Status serverStatus = serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertEquals(status, clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + Status serverStatus = serverStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS); assertFalse(serverStatus.isOk()); assertTrue(clientStreamTracer1.await(TIMEOUT_MS, TimeUnit.MILLISECONDS)); assertNull(clientStreamTracer1.getInboundTrailers()); @@ -547,15 +548,16 @@ public void shutdownNowKillsServerStream() throws Exception { serverTransport.shutdownNow(shutdownStatus); serverTransport = null; - verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportShutdown(any(Status.class)); + verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportShutdown(any(Status.class), + any(DisconnectError.class)); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportTerminated(); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportInUse(false); assertTrue(serverTransportListener.waitForTermination(TIMEOUT_MS, TimeUnit.MILLISECONDS)); assertTrue(serverTransportListener.isTerminated()); - Status clientStreamStatus = clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + Status clientStreamStatus = clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS); assertFalse(clientStreamStatus.isOk()); - assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS)); assertTrue(clientStreamTracer1.await(TIMEOUT_MS, TimeUnit.MILLISECONDS)); assertNull(clientStreamTracer1.getInboundTrailers()); assertStatusEquals(clientStreamStatus, clientStreamTracer1.getStatus()); @@ -565,7 +567,7 @@ public void shutdownNowKillsServerStream() throws Exception { // Generally will be same status provided to shutdownNow, but InProcessTransport can't // differentiate between client and server shutdownNow. The status is not really used on // server-side, so we don't care much. - assertNotNull(serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(serverStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); } @Test @@ -595,7 +597,8 @@ public void ping_duringShutdown() throws Exception { ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); stream.start(clientStreamListener); client.shutdown(Status.UNAVAILABLE); - verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportShutdown(any(Status.class)); + verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportShutdown(any(Status.class), + any(DisconnectError.class)); ClientTransport.PingCallback mockPingCallback = mock(ClientTransport.PingCallback.class); try { client.ping(mockPingCallback, MoreExecutors.directExecutor()); @@ -623,8 +626,8 @@ public void ping_afterTermination() throws Exception { // Transport doesn't support ping, so this neither passes nor fails. assumeTrue(false); } - verify(mockPingCallback, timeout(TIMEOUT_MS)).onFailure(throwableCaptor.capture()); - Status status = Status.fromThrowable(throwableCaptor.getValue()); + verify(mockPingCallback, timeout(TIMEOUT_MS)).onFailure(statusCaptor.capture()); + Status status = statusCaptor.getValue(); assertSame(shutdownReason, status); } @@ -639,15 +642,16 @@ public void newStream_duringShutdown() throws Exception { ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); stream.start(clientStreamListener); client.shutdown(Status.UNAVAILABLE); - verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportShutdown(any(Status.class)); + verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportShutdown(any(Status.class), + any(DisconnectError.class)); ClientStream stream2 = client.newStream( methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener2 = new ClientStreamListenerBase(); stream2.start(clientStreamListener2); Status clientStreamStatus2 = - clientStreamListener2.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); - assertNotNull(clientStreamListener2.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + clientStreamListener2.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertNotNull(clientStreamListener2.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS)); assertCodeEquals(Status.UNAVAILABLE, clientStreamStatus2); assertNull(clientStreamTracer2.getInboundTrailers()); assertSame(clientStreamStatus2, clientStreamTracer2.getStatus()); @@ -661,8 +665,8 @@ public void newStream_duringShutdown() throws Exception { StreamCreation serverStreamCreation = serverTransportListener.takeStreamOrFail(20 * TIMEOUT_MS, TimeUnit.MILLISECONDS); serverStreamCreation.stream.close(Status.OK, new Metadata()); - assertCodeEquals(Status.OK, clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertCodeEquals(Status.OK, clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS)); } @Test @@ -682,8 +686,8 @@ public void newStream_afterTermination() throws Exception { ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); stream.start(clientStreamListener); assertEquals( - shutdownReason, clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + shutdownReason, clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS)); verify(mockClientTransportListener, never()).transportInUse(anyBoolean()); assertNull(clientStreamTracer1.getInboundTrailers()); assertSame(shutdownReason, clientStreamTracer1.getStatus()); @@ -791,6 +795,17 @@ public void transportInUse_clientCancel() throws Exception { @Test public void basicStream() throws Exception { + serverListener = + new MockServerListener( + transport -> + new MockServerTransportListener(transport) { + @Override + public Attributes transportReady(Attributes attributes) { + return super.transportReady(attributes).toBuilder() + .set(ADDITIONAL_TRANSPORT_ATTR_KEY, "additional attribute value") + .build(); + } + }); InOrder serverInOrder = inOrder(serverStreamTracerFactory); server.start(serverListener); client = newClientTransport(server); @@ -857,25 +872,20 @@ public void basicStream() throws Exception { message.close(); assertThat(clientStreamTracer1.nextOutboundEvent()) .matches("outboundMessageSent\\(0, -?[0-9]+, -?[0-9]+\\)"); - if (sizesReported()) { + if (isEnabledSupportTracingMessageSizes()) { assertThat(clientStreamTracer1.getOutboundWireSize()).isGreaterThan(0L); assertThat(clientStreamTracer1.getOutboundUncompressedSize()).isGreaterThan(0L); - } else { - assertThat(clientStreamTracer1.getOutboundWireSize()).isEqualTo(0L); - assertThat(clientStreamTracer1.getOutboundUncompressedSize()).isEqualTo(0L); } + assertThat(serverStreamTracer1.nextInboundEvent()).isEqualTo("inboundMessage(0)"); assertNull("no additional message expected", serverStreamListener.messageQueue.poll()); clientStream.halfClose(); assertTrue(serverStreamListener.awaitHalfClosed(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - if (sizesReported()) { + if (isEnabledSupportTracingMessageSizes()) { assertThat(serverStreamTracer1.getInboundWireSize()).isGreaterThan(0L); assertThat(serverStreamTracer1.getInboundUncompressedSize()).isGreaterThan(0L); - } else { - assertThat(serverStreamTracer1.getInboundWireSize()).isEqualTo(0L); - assertThat(serverStreamTracer1.getInboundUncompressedSize()).isEqualTo(0L); } assertThat(serverStreamTracer1.nextInboundEvent()) .matches("inboundMessageRead\\(0, -?[0-9]+, -?[0-9]+\\)"); @@ -889,7 +899,7 @@ public void basicStream() throws Exception { Metadata serverHeadersCopy = new Metadata(); serverHeadersCopy.merge(serverHeaders); serverStream.writeHeaders(serverHeaders, true); - Metadata headers = clientStreamListener.headers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + Metadata headers = clientStreamListener.awaitHeaders(TIMEOUT_MS, TimeUnit.MILLISECONDS); assertNotNull(headers); assertAsciiMetadataValuesEqual(serverHeadersCopy.getAll(asciiKey), headers.getAll(asciiKey)); assertEquals( @@ -907,24 +917,18 @@ public void basicStream() throws Exception { assertNotNull("message expected", message); assertThat(serverStreamTracer1.nextOutboundEvent()) .matches("outboundMessageSent\\(0, -?[0-9]+, -?[0-9]+\\)"); - if (sizesReported()) { + if (isEnabledSupportTracingMessageSizes()) { assertThat(serverStreamTracer1.getOutboundWireSize()).isGreaterThan(0L); assertThat(serverStreamTracer1.getOutboundUncompressedSize()).isGreaterThan(0L); - } else { - assertThat(serverStreamTracer1.getOutboundWireSize()).isEqualTo(0L); - assertThat(serverStreamTracer1.getOutboundUncompressedSize()).isEqualTo(0L); } assertTrue(clientStreamTracer1.getInboundHeaders()); assertThat(clientStreamTracer1.nextInboundEvent()).isEqualTo("inboundMessage(0)"); assertEquals("Hi. Who are you?", methodDescriptor.parseResponse(message)); assertThat(clientStreamTracer1.nextInboundEvent()) .matches("inboundMessageRead\\(0, -?[0-9]+, -?[0-9]+\\)"); - if (sizesReported()) { + if (isEnabledSupportTracingMessageSizes()) { assertThat(clientStreamTracer1.getInboundWireSize()).isGreaterThan(0L); assertThat(clientStreamTracer1.getInboundUncompressedSize()).isGreaterThan(0L); - } else { - assertThat(clientStreamTracer1.getInboundWireSize()).isEqualTo(0L); - assertThat(clientStreamTracer1.getInboundUncompressedSize()).isEqualTo(0L); } message.close(); @@ -940,11 +944,11 @@ public void basicStream() throws Exception { serverStream.close(status, trailers); assertNull(serverStreamTracer1.nextInboundEvent()); assertNull(serverStreamTracer1.nextOutboundEvent()); - assertCodeEquals(Status.OK, serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertCodeEquals(Status.OK, serverStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); assertSame(status, serverStreamTracer1.getStatus()); - Status clientStreamStatus = clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + Status clientStreamStatus = clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS); Metadata clientStreamTrailers = - clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS); assertSame(clientStreamTrailers, clientStreamTracer1.getInboundTrailers()); assertSame(clientStreamStatus, clientStreamTracer1.getStatus()); assertNull(clientStreamTracer1.nextInboundEvent()); @@ -1013,14 +1017,14 @@ public void zeroMessageStream() throws Exception { assertTrue(serverStreamListener.awaitHalfClosed(TIMEOUT_MS, TimeUnit.MILLISECONDS)); serverStream.writeHeaders(new Metadata(), true); - assertNotNull(clientStreamListener.headers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.awaitHeaders(TIMEOUT_MS, TimeUnit.MILLISECONDS)); Status status = Status.OK.withDescription("Nice talking to you"); serverStream.close(status, new Metadata()); - assertCodeEquals(Status.OK, serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - Status clientStreamStatus = clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertCodeEquals(Status.OK, serverStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + Status clientStreamStatus = clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS); Metadata clientStreamTrailers = - clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS); assertNotNull(clientStreamTrailers); assertEquals(status.getCode(), clientStreamStatus.getCode()); assertEquals(status.getDescription(), clientStreamStatus.getDescription()); @@ -1050,15 +1054,15 @@ public void earlyServerClose_withServerHeaders() throws Exception { ServerStreamListenerBase serverStreamListener = serverStreamCreation.listener; serverStream.writeHeaders(new Metadata(), true); - assertNotNull(clientStreamListener.headers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.awaitHeaders(TIMEOUT_MS, TimeUnit.MILLISECONDS)); Status strippedStatus = Status.OK.withDescription("Hello. Goodbye."); Status status = strippedStatus.withCause(new Exception()); serverStream.close(status, new Metadata()); - assertCodeEquals(Status.OK, serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - Status clientStreamStatus = clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertCodeEquals(Status.OK, serverStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + Status clientStreamStatus = clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS); Metadata clientStreamTrailers = - clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS); assertNotNull(clientStreamTrailers); checkClientStatus(status, clientStreamStatus); assertTrue(clientStreamTracer1.getOutboundHeaders()); @@ -1094,10 +1098,10 @@ public void earlyServerClose_noServerHeaders() throws Exception { trailers.put(asciiKey, "dupvalue"); trailers.put(binaryKey, "äbinarytrailers"); serverStream.close(status, trailers); - assertCodeEquals(Status.OK, serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - Status clientStreamStatus = clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertCodeEquals(Status.OK, serverStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + Status clientStreamStatus = clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS); Metadata clientStreamTrailers = - clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS); checkClientStatus(status, clientStreamStatus); assertEquals( Lists.newArrayList(trailers.getAll(asciiKey)), @@ -1132,10 +1136,10 @@ public void earlyServerClose_serverFailure() throws Exception { Status strippedStatus = Status.INTERNAL.withDescription("I'm not listening"); Status status = strippedStatus.withCause(new Exception()); serverStream.close(status, new Metadata()); - assertCodeEquals(Status.OK, serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - Status clientStreamStatus = clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertCodeEquals(Status.OK, serverStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + Status clientStreamStatus = clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS); Metadata clientStreamTrailers = - clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS); assertNotNull(clientStreamTrailers); checkClientStatus(status, clientStreamStatus); assertTrue(clientStreamTracer1.getOutboundHeaders()); @@ -1175,10 +1179,10 @@ public void closed( Status strippedStatus = Status.INTERNAL.withDescription("I'm not listening"); Status status = strippedStatus.withCause(new Exception()); serverStream.close(status, new Metadata()); - assertCodeEquals(Status.OK, serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - Status clientStreamStatus = clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertCodeEquals(Status.OK, serverStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + Status clientStreamStatus = clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS); Metadata clientStreamTrailers = - clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS); assertNotNull(clientStreamTrailers); checkClientStatus(status, clientStreamStatus); assertTrue(clientStreamTracer1.getOutboundHeaders()); @@ -1206,9 +1210,9 @@ public void clientCancel() throws Exception { Status status = Status.CANCELLED.withDescription("Nevermind").withCause(new Exception()); clientStream.cancel(status); - assertEquals(status, clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - Status serverStatus = serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertEquals(status, clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + Status serverStatus = serverStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS); assertNotEquals(Status.Code.OK, serverStatus.getCode()); // Cause should not be transmitted between client and server by default assertNull(serverStatus.getCause()); @@ -1285,16 +1289,11 @@ public void onReady() { serverStream.close(Status.OK, new Metadata()); assertTrue(clientStreamTracer1.getOutboundHeaders()); assertTrue(clientStreamTracer1.getInboundHeaders()); - if (sizesReported()) { + if (isEnabledSupportTracingMessageSizes()) { assertThat(clientStreamTracer1.getInboundWireSize()).isGreaterThan(0L); assertThat(clientStreamTracer1.getInboundUncompressedSize()).isGreaterThan(0L); assertThat(serverStreamTracer1.getOutboundWireSize()).isGreaterThan(0L); assertThat(serverStreamTracer1.getOutboundUncompressedSize()).isGreaterThan(0L); - } else { - assertThat(clientStreamTracer1.getInboundWireSize()).isEqualTo(0L); - assertThat(clientStreamTracer1.getInboundUncompressedSize()).isEqualTo(0L); - assertThat(serverStreamTracer1.getOutboundWireSize()).isEqualTo(0L); - assertThat(serverStreamTracer1.getOutboundUncompressedSize()).isEqualTo(0L); } assertNull(clientStreamTracer1.getInboundTrailers()); assertSame(status, clientStreamTracer1.getStatus()); @@ -1325,9 +1324,9 @@ public void serverCancel() throws Exception { Status status = Status.DEADLINE_EXCEEDED.withDescription("It was bound to happen") .withCause(new Exception()); serverStream.cancel(status); - assertEquals(status, serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - Status clientStreamStatus = clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); - assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertEquals(status, serverStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + Status clientStreamStatus = clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertNotNull(clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS)); // Presently we can't sent much back to the client in this case. Verify that is the current // behavior for consistency between transports. assertCodeEquals(Status.CANCELLED, clientStreamStatus); @@ -1458,7 +1457,7 @@ public void flowControlPushBack() throws Exception { clientStream.flush(); clientStream.halfClose(); doPingPong(serverListener); - assertFalse(serverStreamListener.awaitHalfClosed(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertFalse(serverStreamListener.isHalfClosed()); serverStream.request(1); serverReceived += verifyMessageCountAndClose(serverStreamListener.messageQueue, 1); @@ -1470,18 +1469,14 @@ public void flowControlPushBack() throws Exception { Status status = Status.OK.withDescription("... quite a lengthy discussion"); serverStream.close(status, new Metadata()); doPingPong(serverListener); - try { - clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); - fail("Expected TimeoutException"); - } catch (TimeoutException expectedException) { - } + assertFalse(clientStreamListener.isClosed()); clientStream.request(1); clientReceived += verifyMessageCountAndClose(clientStreamListener.messageQueue, 1); assertEquals(serverSent + 6, clientReceived); - assertCodeEquals(Status.OK, serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - Status clientStreamStatus = clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); - assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertCodeEquals(Status.OK, serverStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + Status clientStreamStatus = clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertNotNull(clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS)); assertEquals(status.getCode(), clientStreamStatus.getCode()); assertEquals(status.getDescription(), clientStreamStatus.getDescription()); } @@ -1537,9 +1532,9 @@ public void flowControlDoesNotDeadlockLargeMessage() throws Exception { serverStream.close(status, new Metadata()); doPingPong(serverListener); clientStream.request(1); - assertCodeEquals(Status.OK, serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - Status clientStreamStatus = clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); - assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertCodeEquals(Status.OK, serverStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + Status clientStreamStatus = clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertNotNull(clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS)); assertEquals(status.getCode(), clientStreamStatus.getCode()); assertEquals(status.getDescription(), clientStreamStatus.getDescription()); } @@ -1607,8 +1602,8 @@ public void interactionsAfterServerStreamCloseAreNoops() throws Exception { // setup clientStream.request(1); server.stream.close(Status.INTERNAL, new Metadata()); - assertNotNull(clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS)); // Ensure that for a closed ServerStream, interactions are noops server.stream.writeHeaders(new Metadata(), true); @@ -1640,7 +1635,7 @@ public void interactionsAfterClientStreamCancelAreNoops() throws Exception { // setup server.stream.request(1); clientStream.cancel(Status.UNKNOWN); - assertNotNull(server.listener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(server.listener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); // Ensure that for a cancelled ClientStream, interactions are noops clientStream.writeMessage(methodDescriptor.streamRequest("request")); @@ -1763,9 +1758,8 @@ public void transportTracer_server_streamEnded_ok() throws Exception { clientStream.halfClose(); serverStream.close(Status.OK, new Metadata()); // do not validate stats until close() has been called on client - assertNotNull(clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - + assertNotNull(clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS)); TransportStats serverAfter = getTransportStats(serverTransportListener.transport); assertEquals(1, serverAfter.streamsSucceeded); @@ -1802,9 +1796,8 @@ public void transportTracer_server_streamEnded_nonOk() throws Exception { serverStream.close(Status.UNKNOWN, new Metadata()); // do not validate stats until close() has been called on client - assertNotNull(clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - + assertNotNull(clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS)); TransportStats serverAfter = getTransportStats(serverTransportListener.transport); assertEquals(1, serverAfter.streamsFailed); @@ -1842,7 +1835,7 @@ public void transportTracer_client_streamEnded_nonOk() throws Exception { clientStream.cancel(Status.UNKNOWN); // do not validate stats until close() has been called on server - assertNotNull(serverStreamCreation.listener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(serverStreamCreation.listener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); TransportStats serverAfter = getTransportStats(serverTransportListener.transport); assertEquals(1, serverAfter.streamsFailed); @@ -1999,7 +1992,7 @@ public void serverChecksInboundMetadataSize() throws Exception { // Server shouldn't have created a stream, so nothing to clean up on server-side // If this times out, the server probably isn't noticing the metadata size - Status status = clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + Status status = clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS); List codeOptions = Arrays.asList( Status.Code.UNKNOWN, Status.Code.RESOURCE_EXHAUSTED, Status.Code.INTERNAL); if (!codeOptions.contains(status.getCode())) { @@ -2040,13 +2033,13 @@ public void clientChecksInboundMetadataSize_header() throws Exception { serverStreamCreation.stream.writeMessage(methodDescriptor.streamResponse("response")); serverStreamCreation.stream.close(Status.OK, new Metadata()); - Status status = clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + Status status = clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS); List codeOptions = Arrays.asList( Status.Code.UNKNOWN, Status.Code.RESOURCE_EXHAUSTED, Status.Code.INTERNAL); if (!codeOptions.contains(status.getCode())) { fail("Status code was not expected: " + status); } - assertFalse(clientStreamListener.headers.isDone()); + assertFalse(clientStreamListener.hasHeaders()); } /** This assumes the client limits metadata size to GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE. */ @@ -2085,13 +2078,13 @@ public void clientChecksInboundMetadataSize_trailer() throws Exception { serverStreamCreation.stream.writeMessage(methodDescriptor.streamResponse("response")); serverStreamCreation.stream.close(Status.OK, tooLargeMetadata); - Status status = clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + Status status = clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS); List codeOptions = Arrays.asList( Status.Code.UNKNOWN, Status.Code.RESOURCE_EXHAUSTED, Status.Code.INTERNAL); if (!codeOptions.contains(status.getCode())) { fail("Status code was not expected: " + status); } - Metadata metadata = clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS); + Metadata metadata = clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS); assertNull(metadata.get(tellTaleKey)); } @@ -2119,9 +2112,9 @@ methodDescriptor, new Metadata(), callOptions, ServerStreamListenerBase serverStreamListener = serverStreamCreation.listener; serverStream.close(Status.OK, new Metadata()); - assertNotNull(clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); - assertNotNull(serverStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(clientStreamListener.awaitTrailers(TIMEOUT_MS, TimeUnit.MILLISECONDS)); + assertNotNull(serverStreamListener.awaitClose(TIMEOUT_MS, TimeUnit.MILLISECONDS)); client.shutdown(Status.UNAVAILABLE); } @@ -2166,7 +2159,7 @@ private static void checkClientStatus(Status expectedStatus, Status clientStream assertNull(clientStreamStatus.getCause()); } - private static boolean waitForFuture(Future future, long timeout, TimeUnit unit) + static boolean waitForFuture(Future future, long timeout, TimeUnit unit) throws InterruptedException { try { future.get(timeout, unit); @@ -2178,13 +2171,13 @@ private static boolean waitForFuture(Future future, long timeout, TimeUnit un return true; } - private static void runIfNotNull(Runnable runnable) { + protected static void runIfNotNull(Runnable runnable) { if (runnable != null) { runnable.run(); } } - private static void startTransport( + protected static void startTransport( ManagedClientTransport clientTransport, ManagedClientTransport.Listener listener) { runIfNotNull(clientTransport.start(listener)); @@ -2202,217 +2195,6 @@ public void streamCreated(Attributes transportAttrs, Metadata metadata) { } } - private static class MockServerListener implements ServerListener { - public final BlockingQueue listeners - = new LinkedBlockingQueue<>(); - private final SettableFuture shutdown = SettableFuture.create(); - - @Override - public ServerTransportListener transportCreated(ServerTransport transport) { - MockServerTransportListener listener = new MockServerTransportListener(transport); - listeners.add(listener); - return listener; - } - - @Override - public void serverShutdown() { - assertTrue(shutdown.set(null)); - } - - public boolean waitForShutdown(long timeout, TimeUnit unit) throws InterruptedException { - return waitForFuture(shutdown, timeout, unit); - } - - public MockServerTransportListener takeListenerOrFail(long timeout, TimeUnit unit) - throws InterruptedException { - MockServerTransportListener listener = listeners.poll(timeout, unit); - if (listener == null) { - fail("Timed out waiting for server transport"); - } - return listener; - } - } - - private static class MockServerTransportListener implements ServerTransportListener { - public final ServerTransport transport; - public final BlockingQueue streams = new LinkedBlockingQueue<>(); - private final SettableFuture terminated = SettableFuture.create(); - - public MockServerTransportListener(ServerTransport transport) { - this.transport = transport; - } - - @Override - public void streamCreated(ServerStream stream, String method, Metadata headers) { - ServerStreamListenerBase listener = new ServerStreamListenerBase(); - streams.add(new StreamCreation(stream, method, headers, listener)); - stream.setListener(listener); - } - - @Override - public Attributes transportReady(Attributes attributes) { - return Attributes.newBuilder() - .setAll(attributes) - .set(ADDITIONAL_TRANSPORT_ATTR_KEY, "additional attribute value") - .build(); - } - - @Override - public void transportTerminated() { - assertTrue(terminated.set(null)); - } - - public boolean waitForTermination(long timeout, TimeUnit unit) throws InterruptedException { - return waitForFuture(terminated, timeout, unit); - } - - public boolean isTerminated() { - return terminated.isDone(); - } - - public StreamCreation takeStreamOrFail(long timeout, TimeUnit unit) - throws InterruptedException { - StreamCreation stream = streams.poll(timeout, unit); - if (stream == null) { - fail("Timed out waiting for server stream"); - } - return stream; - } - } - - private static class ServerStreamListenerBase implements ServerStreamListener { - private final BlockingQueue messageQueue = new LinkedBlockingQueue<>(); - // Would have used Void instead of Object, but null elements are not allowed - private final BlockingQueue readyQueue = new LinkedBlockingQueue<>(); - private final CountDownLatch halfClosedLatch = new CountDownLatch(1); - private final SettableFuture status = SettableFuture.create(); - - private boolean awaitOnReady(int timeout, TimeUnit unit) throws Exception { - return readyQueue.poll(timeout, unit) != null; - } - - private boolean awaitOnReadyAndDrain(int timeout, TimeUnit unit) throws Exception { - if (!awaitOnReady(timeout, unit)) { - return false; - } - // Throw the rest away - readyQueue.drainTo(Lists.newArrayList()); - return true; - } - - private boolean awaitHalfClosed(int timeout, TimeUnit unit) throws Exception { - return halfClosedLatch.await(timeout, unit); - } - - @Override - public void messagesAvailable(MessageProducer producer) { - if (status.isDone()) { - fail("messagesAvailable invoked after closed"); - } - InputStream message; - while ((message = producer.next()) != null) { - messageQueue.add(message); - } - } - - @Override - public void onReady() { - if (status.isDone()) { - fail("onReady invoked after closed"); - } - readyQueue.add(new Object()); - } - - @Override - public void halfClosed() { - if (status.isDone()) { - fail("halfClosed invoked after closed"); - } - halfClosedLatch.countDown(); - } - - @Override - public void closed(Status status) { - if (this.status.isDone()) { - fail("closed invoked more than once"); - } - this.status.set(status); - } - } - - private static class ClientStreamListenerBase implements ClientStreamListener { - private final BlockingQueue messageQueue = new LinkedBlockingQueue<>(); - // Would have used Void instead of Object, but null elements are not allowed - private final BlockingQueue readyQueue = new LinkedBlockingQueue<>(); - private final SettableFuture headers = SettableFuture.create(); - private final SettableFuture trailers = SettableFuture.create(); - private final SettableFuture status = SettableFuture.create(); - - private boolean awaitOnReady(int timeout, TimeUnit unit) throws Exception { - return readyQueue.poll(timeout, unit) != null; - } - - private boolean awaitOnReadyAndDrain(int timeout, TimeUnit unit) throws Exception { - if (!awaitOnReady(timeout, unit)) { - return false; - } - // Throw the rest away - readyQueue.drainTo(Lists.newArrayList()); - return true; - } - - @Override - public void messagesAvailable(MessageProducer producer) { - if (status.isDone()) { - fail("messagesAvailable invoked after closed"); - } - InputStream message; - while ((message = producer.next()) != null) { - messageQueue.add(message); - } - } - - @Override - public void onReady() { - if (status.isDone()) { - fail("onReady invoked after closed"); - } - readyQueue.add(new Object()); - } - - @Override - public void headersRead(Metadata headers) { - if (status.isDone()) { - fail("headersRead invoked after closed"); - } - this.headers.set(headers); - } - - @Override - public void closed(Status status, RpcProgress rpcProgress, Metadata trailers) { - if (this.status.isDone()) { - fail("headersRead invoked after closed"); - } - this.status.set(status); - this.trailers.set(trailers); - } - } - - private static class StreamCreation { - public final ServerStream stream; - public final String method; - public final Metadata headers; - public final ServerStreamListenerBase listener; - - public StreamCreation( - ServerStream stream, String method, Metadata headers, ServerStreamListenerBase listener) { - this.stream = stream; - this.method = method; - this.headers = headers; - this.listener = listener; - } - } - private static class StringMarshaller implements MethodDescriptor.Marshaller { public static final StringMarshaller INSTANCE = new StringMarshaller(); diff --git a/core/src/testFixtures/java/io/grpc/internal/ClientStreamListenerBase.java b/core/src/testFixtures/java/io/grpc/internal/ClientStreamListenerBase.java new file mode 100644 index 00000000000..3c35cf59225 --- /dev/null +++ b/core/src/testFixtures/java/io/grpc/internal/ClientStreamListenerBase.java @@ -0,0 +1,126 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static org.junit.Assert.fail; + +import com.google.common.collect.Lists; +import com.google.common.util.concurrent.SettableFuture; +import io.grpc.Metadata; +import io.grpc.Status; +import java.io.InputStream; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +public class ClientStreamListenerBase implements ClientStreamListener { + public final BlockingQueue messageQueue = new LinkedBlockingQueue<>(); + // Would have used Void instead of Object, but null elements are not allowed + private final BlockingQueue readyQueue = new LinkedBlockingQueue<>(); + private final SettableFuture headers = SettableFuture.create(); + private final SettableFuture trailers = SettableFuture.create(); + private final SettableFuture status = SettableFuture.create(); + + /** + * Returns the stream's status or throws {@link java.util.concurrent.TimeoutException} if it isn't + * closed before the timeout. + */ + public Status awaitClose(int timeout, TimeUnit unit) throws Exception { + return status.get(timeout, unit); + } + + /** + * Return {@code true} if {@code #awaitClose} would return immediately with a status. + */ + public boolean isClosed() { + return status.isDone(); + } + + /** + * Returns response headers from the server or throws {@link + * java.util.concurrent.TimeoutException} if they aren't delivered before the timeout. + * + *

Callers must not modify the returned object. + */ + public Metadata awaitHeaders(int timeout, TimeUnit unit) throws Exception { + return headers.get(timeout, unit); + } + + /** + * Returns response trailers from the server or throws {@link + * java.util.concurrent.TimeoutException} if they aren't delivered before the timeout. + * + *

Callers must not modify the returned object. + */ + public Metadata awaitTrailers(int timeout, TimeUnit unit) throws Exception { + return trailers.get(timeout, unit); + } + + public boolean awaitOnReady(int timeout, TimeUnit unit) throws Exception { + return readyQueue.poll(timeout, unit) != null; + } + + public boolean awaitOnReadyAndDrain(int timeout, TimeUnit unit) throws Exception { + if (!awaitOnReady(timeout, unit)) { + return false; + } + // Throw the rest away + readyQueue.drainTo(Lists.newArrayList()); + return true; + } + + @Override + public void messagesAvailable(MessageProducer producer) { + if (status.isDone()) { + fail("messagesAvailable invoked after closed"); + } + InputStream message; + while ((message = producer.next()) != null) { + messageQueue.add(message); + } + } + + @Override + public void onReady() { + if (status.isDone()) { + fail("onReady invoked after closed"); + } + readyQueue.add(new Object()); + } + + @Override + public void headersRead(Metadata headers) { + if (status.isDone()) { + fail("headersRead invoked after closed"); + } + this.headers.set(headers); + } + + @Override + public void closed(Status status, RpcProgress rpcProgress, Metadata trailers) { + if (this.status.isDone()) { + fail("headersRead invoked after closed"); + } + this.status.set(status); + this.trailers.set(trailers); + } + + /** Returns true iff response headers have been received from the server. */ + public boolean hasHeaders() { + return headers.isDone(); + } +} diff --git a/core/src/testFixtures/java/io/grpc/internal/FakeClock.java b/core/src/testFixtures/java/io/grpc/internal/FakeClock.java index 9cc9178f1ff..1a3584f4e2c 100644 --- a/core/src/testFixtures/java/io/grpc/internal/FakeClock.java +++ b/core/src/testFixtures/java/io/grpc/internal/FakeClock.java @@ -188,7 +188,8 @@ private void schedule(ScheduledTask task, long delay, TimeUnit unit) { } @Override public boolean isShutdown() { - throw new UnsupportedOperationException(); + // If shutdown is not implemented, then it is never shutdown. + return false; } @Override public boolean isTerminated() { diff --git a/core/src/testFixtures/java/io/grpc/internal/MockServerListener.java b/core/src/testFixtures/java/io/grpc/internal/MockServerListener.java new file mode 100644 index 00000000000..0c33b98cf1c --- /dev/null +++ b/core/src/testFixtures/java/io/grpc/internal/MockServerListener.java @@ -0,0 +1,78 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import com.google.common.util.concurrent.SettableFuture; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +/** + * A {@link ServerListener} that helps you write blocking unit tests. + * + *

TODO: Rename, since this is not actually a mock: + * https://testing.googleblog.com/2013/07/testing-on-toilet-know-your-test-doubles.html + */ +public class MockServerListener implements ServerListener { + private final BlockingQueue listeners = new LinkedBlockingQueue<>(); + private final SettableFuture shutdown = SettableFuture.create(); + private final ServerTransportListenerFactory serverTransportListenerFactory; + + /** + * Lets you customize the {@link MockServerTransportListener} installed on newly created + * {@link ServerTransport}s. + */ + public interface ServerTransportListenerFactory { + MockServerTransportListener create(ServerTransport transport); + } + + public MockServerListener(ServerTransportListenerFactory serverTransportListenerFactory) { + this.serverTransportListenerFactory = serverTransportListenerFactory; + } + + public MockServerListener() { + this(MockServerTransportListener::new); + } + + @Override + public ServerTransportListener transportCreated(ServerTransport transport) { + MockServerTransportListener listener = serverTransportListenerFactory.create(transport); + listeners.add(listener); + return listener; + } + + @Override + public void serverShutdown() { + assertTrue(shutdown.set(null)); + } + + public boolean waitForShutdown(long timeout, TimeUnit unit) throws InterruptedException { + return AbstractTransportTest.waitForFuture(shutdown, timeout, unit); + } + + public MockServerTransportListener takeListenerOrFail(long timeout, TimeUnit unit) + throws InterruptedException { + MockServerTransportListener listener = listeners.poll(timeout, unit); + if (listener == null) { + fail("Timed out waiting for server transport"); + } + return listener; + } +} diff --git a/core/src/testFixtures/java/io/grpc/internal/MockServerTransportListener.java b/core/src/testFixtures/java/io/grpc/internal/MockServerTransportListener.java new file mode 100644 index 00000000000..e6c4e2f578e --- /dev/null +++ b/core/src/testFixtures/java/io/grpc/internal/MockServerTransportListener.java @@ -0,0 +1,93 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import com.google.common.util.concurrent.SettableFuture; +import io.grpc.Attributes; +import io.grpc.Metadata; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +/** + * A {@link ServerTransportListener} that helps you write blocking unit tests. + * + *

TODO: Rename, since this is not actually a mock: + * https://testing.googleblog.com/2013/07/testing-on-toilet-know-your-test-doubles.html + */ +public class MockServerTransportListener implements ServerTransportListener { + public final ServerTransport transport; + private final BlockingQueue streams = new LinkedBlockingQueue<>(); + private final SettableFuture terminated = SettableFuture.create(); + + public MockServerTransportListener(ServerTransport transport) { + this.transport = transport; + } + + @Override + public void streamCreated(ServerStream stream, String method, Metadata headers) { + ServerStreamListenerBase listener = new ServerStreamListenerBase(); + streams.add(new StreamCreation(stream, method, headers, listener)); + stream.setListener(listener); + } + + @Override + public Attributes transportReady(Attributes attributes) { + assertFalse(terminated.isDone()); + return attributes; + } + + @Override + public void transportTerminated() { + assertTrue(terminated.set(null)); + } + + public boolean waitForTermination(long timeout, TimeUnit unit) throws InterruptedException { + return AbstractTransportTest.waitForFuture(terminated, timeout, unit); + } + + public boolean isTerminated() { + return terminated.isDone(); + } + + public StreamCreation takeStreamOrFail(long timeout, TimeUnit unit) throws InterruptedException { + StreamCreation stream = streams.poll(timeout, unit); + if (stream == null) { + fail("Timed out waiting for server stream"); + } + return stream; + } + + public static class StreamCreation { + public final ServerStream stream; + public final String method; + public final Metadata headers; + public final ServerStreamListenerBase listener; + + public StreamCreation( + ServerStream stream, String method, Metadata headers, ServerStreamListenerBase listener) { + this.stream = stream; + this.method = method; + this.headers = headers; + this.listener = listener; + } + } +} diff --git a/core/src/testFixtures/java/io/grpc/internal/PickFirstLoadBalancerProviderAccessor.java b/core/src/testFixtures/java/io/grpc/internal/PickFirstLoadBalancerProviderAccessor.java new file mode 100644 index 00000000000..a6e94df03c2 --- /dev/null +++ b/core/src/testFixtures/java/io/grpc/internal/PickFirstLoadBalancerProviderAccessor.java @@ -0,0 +1,28 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +/** + * Accessor for PickFirstLoadBalancerProvider, allowing access only during tests. + */ +public final class PickFirstLoadBalancerProviderAccessor { + private PickFirstLoadBalancerProviderAccessor() {} + + public static void setEnableNewPickFirst(boolean enableNewPickFirst) { + PickFirstLoadBalancerProvider.enableNewPickFirst = enableNewPickFirst; + } +} diff --git a/core/src/testFixtures/java/io/grpc/internal/ReadableBufferTestBase.java b/core/src/testFixtures/java/io/grpc/internal/ReadableBufferTestBase.java index 97e0df38ae7..2262f0466f7 100644 --- a/core/src/testFixtures/java/io/grpc/internal/ReadableBufferTestBase.java +++ b/core/src/testFixtures/java/io/grpc/internal/ReadableBufferTestBase.java @@ -16,12 +16,11 @@ package io.grpc.internal; -import static com.google.common.base.Charsets.UTF_8; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import java.io.ByteArrayOutputStream; -import java.nio.Buffer; import java.nio.ByteBuffer; import java.util.Arrays; import org.junit.Assume; @@ -83,30 +82,6 @@ public void partialReadToStreamShouldSucceed() throws Exception { assertEquals(msg.length() - 2, buffer.readableBytes()); } - @Test - public void readToByteBufferShouldSucceed() { - ReadableBuffer buffer = buffer(); - ByteBuffer byteBuffer = ByteBuffer.allocate(msg.length()); - buffer.readBytes(byteBuffer); - ((Buffer) byteBuffer).flip(); - byte[] array = new byte[msg.length()]; - byteBuffer.get(array); - assertArrayEquals(msg.getBytes(UTF_8), array); - assertEquals(0, buffer.readableBytes()); - } - - @Test - public void partialReadToByteBufferShouldSucceed() { - ReadableBuffer buffer = buffer(); - ByteBuffer byteBuffer = ByteBuffer.allocate(2); - buffer.readBytes(byteBuffer); - ((Buffer) byteBuffer).flip(); - byte[] array = new byte[2]; - byteBuffer.get(array); - assertArrayEquals(new byte[]{'h', 'e'}, array); - assertEquals(msg.length() - 2, buffer.readableBytes()); - } - @Test public void partialReadToReadableBufferShouldSucceed() { ReadableBuffer buffer = buffer(); diff --git a/core/src/testFixtures/java/io/grpc/internal/ServerStreamListenerBase.java b/core/src/testFixtures/java/io/grpc/internal/ServerStreamListenerBase.java new file mode 100644 index 00000000000..aaa70600542 --- /dev/null +++ b/core/src/testFixtures/java/io/grpc/internal/ServerStreamListenerBase.java @@ -0,0 +1,99 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static org.junit.Assert.fail; + +import com.google.common.collect.Lists; +import com.google.common.util.concurrent.SettableFuture; +import io.grpc.Status; +import java.io.InputStream; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +/** + * A {@link ServerStreamListener} that helps you write blocking unit tests. + */ +public class ServerStreamListenerBase implements ServerStreamListener { + public final BlockingQueue messageQueue = new LinkedBlockingQueue<>(); + // Would have used Void instead of Object, but null elements are not allowed + private final BlockingQueue readyQueue = new LinkedBlockingQueue<>(); + private final CountDownLatch halfClosedLatch = new CountDownLatch(1); + private final SettableFuture status = SettableFuture.create(); + + public boolean awaitOnReady(int timeout, TimeUnit unit) throws Exception { + return readyQueue.poll(timeout, unit) != null; + } + + public boolean awaitOnReadyAndDrain(int timeout, TimeUnit unit) throws Exception { + if (!awaitOnReady(timeout, unit)) { + return false; + } + // Throw the rest away + readyQueue.drainTo(Lists.newArrayList()); + return true; + } + + public boolean awaitHalfClosed(int timeout, TimeUnit unit) throws Exception { + return halfClosedLatch.await(timeout, unit); + } + + public boolean isHalfClosed() { + return halfClosedLatch.getCount() == 0; + } + + public Status awaitClose(int timeout, TimeUnit unit) throws Exception { + return status.get(timeout, unit); + } + + @Override + public void messagesAvailable(MessageProducer producer) { + if (status.isDone()) { + fail("messagesAvailable invoked after closed"); + } + InputStream message; + while ((message = producer.next()) != null) { + messageQueue.add(message); + } + } + + @Override + public void onReady() { + if (status.isDone()) { + fail("onReady invoked after closed"); + } + readyQueue.add(new Object()); + } + + @Override + public void halfClosed() { + if (status.isDone()) { + fail("halfClosed invoked after closed"); + } + halfClosedLatch.countDown(); + } + + @Override + public void closed(Status status) { + if (this.status.isDone()) { + fail("closed invoked more than once"); + } + this.status.set(status); + } +} diff --git a/core/src/testFixtures/java/io/grpc/internal/TestUtils.java b/core/src/testFixtures/java/io/grpc/internal/TestUtils.java index 055a7f80283..e5aab81c1bb 100644 --- a/core/src/testFixtures/java/io/grpc/internal/TestUtils.java +++ b/core/src/testFixtures/java/io/grpc/internal/TestUtils.java @@ -146,7 +146,7 @@ public Runnable answer(InvocationOnMock invocation) throws Throwable { } @SuppressWarnings("ReferenceEquality") - public static final EquivalentAddressGroup stripAttrs(EquivalentAddressGroup eag) { + public static EquivalentAddressGroup stripAttrs(EquivalentAddressGroup eag) { if (eag.getAttributes() == Attributes.EMPTY) { return eag; } diff --git a/cronet/build.gradle b/cronet/build.gradle index 3252a9d249b..e096761ddd2 100644 --- a/cronet/build.gradle +++ b/cronet/build.gradle @@ -8,14 +8,13 @@ description = "gRPC: Cronet Android" repositories { google() - mavenCentral() } android { - namespace 'io.grpc.cronet' + namespace = 'io.grpc.cronet' compileSdkVersion 33 defaultConfig { - minSdkVersion 21 + minSdkVersion 23 targetSdkVersion 33 versionCode 1 versionName "1.0" @@ -47,6 +46,7 @@ dependencies { libraries.cronet.api implementation project(':grpc-core') implementation libraries.guava + implementation 'org.checkerframework:checker-qual:3.49.5' testImplementation project(':grpc-testing') testImplementation libraries.cronet.embedded diff --git a/cronet/src/main/java/io/grpc/cronet/CronetChannelBuilder.java b/cronet/src/main/java/io/grpc/cronet/CronetChannelBuilder.java index 93413aa22a3..7ea1bc891c2 100644 --- a/cronet/src/main/java/io/grpc/cronet/CronetChannelBuilder.java +++ b/cronet/src/main/java/io/grpc/cronet/CronetChannelBuilder.java @@ -20,10 +20,11 @@ import static com.google.common.base.Preconditions.checkNotNull; import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; -import android.util.Log; +import android.net.Network; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.util.concurrent.MoreExecutors; +import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.DoNotCall; import io.grpc.ChannelCredentials; import io.grpc.ChannelLogger; @@ -38,8 +39,6 @@ import io.grpc.internal.ManagedChannelImplBuilder.ClientTransportFactoryBuilder; import io.grpc.internal.SharedResourceHolder; import io.grpc.internal.TransportTracer; -import java.lang.reflect.InvocationTargetException; -import java.lang.reflect.Method; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.util.Collection; @@ -49,15 +48,11 @@ import javax.annotation.Nullable; import org.chromium.net.BidirectionalStream; import org.chromium.net.CronetEngine; -import org.chromium.net.ExperimentalBidirectionalStream; -import org.chromium.net.ExperimentalCronetEngine; /** Convenience class for building channels with the cronet transport. */ @ExperimentalApi("There is no plan to make this API stable, given transport API instability") public final class CronetChannelBuilder extends ForwardingChannelBuilder2 { - private static final String LOG_TAG = "CronetChannelBuilder"; - /** BidirectionalStream.Builder factory used for getting the gRPC BidirectionalStream. */ public static abstract class StreamBuilderFactory { public abstract BidirectionalStream.Builder newBidirectionalStreamBuilder( @@ -91,7 +86,7 @@ public static CronetChannelBuilder forAddress(String name, int port) { private final CronetEngine cronetEngine; private final ManagedChannelImplBuilder managedChannelImplBuilder; - private TransportTracer.Factory transportTracerFactory = TransportTracer.getDefaultFactory(); + private final TransportTracer.Factory transportTracerFactory = TransportTracer.getDefaultFactory(); private boolean alwaysUsePut = false; @@ -112,6 +107,7 @@ public static CronetChannelBuilder forAddress(String name, int port) { private int trafficStatsTag; private boolean trafficStatsUidSet; private int trafficStatsUid; + private Network network; private CronetChannelBuilder(String host, int port, CronetEngine cronetEngine) { final class CronetChannelTransportFactoryBuilder implements ClientTransportFactoryBuilder { @@ -139,7 +135,7 @@ protected ManagedChannelBuilder delegate() { * Sets the maximum message size allowed to be received on the channel. If not called, * defaults to {@link io.grpc.internal.GrpcUtil#DEFAULT_MAX_MESSAGE_SIZE}. */ - public final CronetChannelBuilder maxMessageSize(int maxMessageSize) { + public CronetChannelBuilder maxMessageSize(int maxMessageSize) { checkArgument(maxMessageSize >= 0, "maxMessageSize must be >= 0"); this.maxMessageSize = maxMessageSize; return this; @@ -148,7 +144,7 @@ public final CronetChannelBuilder maxMessageSize(int maxMessageSize) { /** * Sets the Cronet channel to always use PUT instead of POST. Defaults to false. */ - public final CronetChannelBuilder alwaysUsePut(boolean enable) { + public CronetChannelBuilder alwaysUsePut(boolean enable) { this.alwaysUsePut = enable; return this; } @@ -170,7 +166,7 @@ public final CronetChannelBuilder alwaysUsePut(boolean enable) { * application. * @return the builder to facilitate chaining. */ - final CronetChannelBuilder setTrafficStatsTag(int tag) { + CronetChannelBuilder setTrafficStatsTag(int tag) { trafficStatsTagSet = true; trafficStatsTag = tag; return this; @@ -180,7 +176,7 @@ final CronetChannelBuilder setTrafficStatsTag(int tag) { * Sets specific UID to use when accounting socket traffic caused by this channel. See {@link * android.net.TrafficStats} for more information. Designed for use when performing an operation * on behalf of another application. Caller must hold {@link - * android.Manifest.permission#MODIFY_NETWORK_ACCOUNTING} permission. By default traffic is + * android.Manifest.permission#UPDATE_DEVICE_STATS} permission. By default traffic is * attributed to UID of caller. * *

NOTE:Setting a UID disallows sharing of sockets with channels with other UIDs, which @@ -191,12 +187,19 @@ final CronetChannelBuilder setTrafficStatsTag(int tag) { * @param uid the UID to attribute socket traffic caused by this channel. * @return the builder to facilitate chaining. */ - final CronetChannelBuilder setTrafficStatsUid(int uid) { + CronetChannelBuilder setTrafficStatsUid(int uid) { trafficStatsUidSet = true; trafficStatsUid = uid; return this; } + /** Sets the network ID to use for this channel traffic. */ + @CanIgnoreReturnValue + CronetChannelBuilder bindToNetwork(@Nullable Network network) { + this.network = network; + return this; + } + /** * Provides a custom scheduled executor service. * @@ -207,7 +210,7 @@ final CronetChannelBuilder setTrafficStatsUid(int uid) { * * @since 1.12.0 */ - public final CronetChannelBuilder scheduledExecutorService( + public CronetChannelBuilder scheduledExecutorService( ScheduledExecutorService scheduledExecutorService) { this.scheduledExecutorService = checkNotNull(scheduledExecutorService, "scheduledExecutorService"); @@ -217,7 +220,12 @@ public final CronetChannelBuilder scheduledExecutorService( ClientTransportFactory buildTransportFactory() { return new CronetTransportFactory( new TaggingStreamFactory( - cronetEngine, trafficStatsTagSet, trafficStatsTag, trafficStatsUidSet, trafficStatsUid), + cronetEngine, + trafficStatsTagSet, + trafficStatsTag, + trafficStatsUidSet, + trafficStatsUid, + network), MoreExecutors.directExecutor(), scheduledExecutorService, maxMessageSize, @@ -296,101 +304,44 @@ public Collection> getSupportedSocketAddressTypes * StreamBuilderFactory impl that applies TrafficStats tags to stream builders that are produced. */ private static class TaggingStreamFactory extends StreamBuilderFactory { - private static volatile boolean loadSetTrafficStatsTagAttempted; - private static volatile boolean loadSetTrafficStatsUidAttempted; - private static volatile Method setTrafficStatsTagMethod; - private static volatile Method setTrafficStatsUidMethod; - private final CronetEngine cronetEngine; private final boolean trafficStatsTagSet; private final int trafficStatsTag; private final boolean trafficStatsUidSet; private final int trafficStatsUid; + private final Network network; TaggingStreamFactory( CronetEngine cronetEngine, boolean trafficStatsTagSet, int trafficStatsTag, boolean trafficStatsUidSet, - int trafficStatsUid) { + int trafficStatsUid, + Network network) { this.cronetEngine = cronetEngine; this.trafficStatsTagSet = trafficStatsTagSet; this.trafficStatsTag = trafficStatsTag; this.trafficStatsUidSet = trafficStatsUidSet; this.trafficStatsUid = trafficStatsUid; + this.network = network; } @Override public BidirectionalStream.Builder newBidirectionalStreamBuilder( String url, BidirectionalStream.Callback callback, Executor executor) { - ExperimentalBidirectionalStream.Builder builder = - ((ExperimentalCronetEngine) cronetEngine) + BidirectionalStream.Builder builder = + cronetEngine .newBidirectionalStreamBuilder(url, callback, executor); if (trafficStatsTagSet) { - setTrafficStatsTag(builder, trafficStatsTag); + builder.setTrafficStatsTag(trafficStatsTag); } if (trafficStatsUidSet) { - setTrafficStatsUid(builder, trafficStatsUid); - } - return builder; - } - - private static void setTrafficStatsTag(ExperimentalBidirectionalStream.Builder builder, - int tag) { - if (!loadSetTrafficStatsTagAttempted) { - synchronized (TaggingStreamFactory.class) { - if (!loadSetTrafficStatsTagAttempted) { - try { - setTrafficStatsTagMethod = ExperimentalBidirectionalStream.Builder.class - .getMethod("setTrafficStatsTag", int.class); - } catch (NoSuchMethodException e) { - Log.w(LOG_TAG, - "Failed to load method ExperimentalBidirectionalStream.Builder.setTrafficStatsTag", - e); - } finally { - loadSetTrafficStatsTagAttempted = true; - } - } - } + builder.setTrafficStatsUid(trafficStatsUid); } - if (setTrafficStatsTagMethod != null) { - try { - setTrafficStatsTagMethod.invoke(builder, tag); - } catch (InvocationTargetException e) { - throw new RuntimeException(e.getCause() == null ? e.getTargetException() : e.getCause()); - } catch (IllegalAccessException e) { - Log.w(LOG_TAG, "Failed to set traffic stats tag: " + tag, e); - } - } - } - - private static void setTrafficStatsUid(ExperimentalBidirectionalStream.Builder builder, - int uid) { - if (!loadSetTrafficStatsUidAttempted) { - synchronized (TaggingStreamFactory.class) { - if (!loadSetTrafficStatsUidAttempted) { - try { - setTrafficStatsUidMethod = ExperimentalBidirectionalStream.Builder.class - .getMethod("setTrafficStatsUid", int.class); - } catch (NoSuchMethodException e) { - Log.w(LOG_TAG, - "Failed to load method ExperimentalBidirectionalStream.Builder.setTrafficStatsUid", - e); - } finally { - loadSetTrafficStatsUidAttempted = true; - } - } - } - } - if (setTrafficStatsUidMethod != null) { - try { - setTrafficStatsUidMethod.invoke(builder, uid); - } catch (InvocationTargetException e) { - throw new RuntimeException(e.getCause() == null ? e.getTargetException() : e.getCause()); - } catch (IllegalAccessException e) { - Log.w(LOG_TAG, "Failed to set traffic stats uid: " + uid, e); - } + if (network != null) { + builder.bindToNetwork(network.getNetworkHandle()); } + return builder; } } } diff --git a/cronet/src/main/java/io/grpc/cronet/CronetClientStream.java b/cronet/src/main/java/io/grpc/cronet/CronetClientStream.java index d44b716146e..07bbb953489 100644 --- a/cronet/src/main/java/io/grpc/cronet/CronetClientStream.java +++ b/cronet/src/main/java/io/grpc/cronet/CronetClientStream.java @@ -25,6 +25,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.io.BaseEncoding; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.InternalMetadata; @@ -40,11 +41,9 @@ import io.grpc.internal.TransportFrameUtil; import io.grpc.internal.TransportTracer; import io.grpc.internal.WritableBuffer; -import java.lang.reflect.InvocationTargetException; -import java.lang.reflect.Method; import java.nio.Buffer; import java.nio.ByteBuffer; -import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -52,29 +51,29 @@ import java.util.Map; import java.util.concurrent.Executor; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; import org.chromium.net.BidirectionalStream; import org.chromium.net.CronetException; -import org.chromium.net.ExperimentalBidirectionalStream; import org.chromium.net.UrlResponseInfo; /** * Client stream for the cronet transport. */ class CronetClientStream extends AbstractClientStream { - private static final int READ_BUFFER_CAPACITY = 4 * 1024; private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocateDirect(0); private static final String LOG_TAG = "grpc-java-cronet"; - private static volatile boolean loadAddRequestAnnotationAttempted; - private static volatile Method addRequestAnnotationMethod; - @Deprecated static final CallOptions.Key CRONET_ANNOTATION_KEY = CallOptions.Key.create("cronet-annotation"); static final CallOptions.Key> CRONET_ANNOTATIONS_KEY = CallOptions.Key.create("cronet-annotations"); + /** + * Sets the read buffer size which the GRPC layer will use to read data from Cronet. Higher buffer + * size leads to less overhead but more memory consumption. The current default value is 4KB. + */ + static final CallOptions.Key CRONET_READ_BUFFER_SIZE_KEY = + CallOptions.Key.createWithDefault("cronet-read-buffer-size", 4 * 1024); private final String url; private final String userAgent; @@ -91,6 +90,8 @@ class CronetClientStream extends AbstractClientStream { private final Collection annotations; private final TransportState state; private final Sink sink = new Sink(); + @VisibleForTesting + final int readBufferSize; private StreamBuilderFactory streamFactory; CronetClientStream( @@ -124,7 +125,9 @@ class CronetClientStream extends AbstractClientStream { this.delayRequestHeader = (method.getType() == MethodDescriptor.MethodType.UNARY); this.annotation = callOptions.getOption(CRONET_ANNOTATION_KEY); this.annotations = callOptions.getOption(CRONET_ANNOTATIONS_KEY); - this.state = new TransportState(maxMessageSize, statsTraceCtx, lock, transportTracer); + this.state = new TransportState(maxMessageSize, statsTraceCtx, lock, transportTracer, + callOptions); + this.readBufferSize = callOptions.getOption(CRONET_READ_BUFFER_SIZE_KEY); // Tests expect the "plain" deframer behavior, not MigratingDeframer // https://github.com/grpc/grpc-java/issues/7140 @@ -193,14 +196,12 @@ public void writeHeaders(Metadata metadata, byte[] payload) { builder.delayRequestHeadersUntilFirstFlush(true); } if (annotation != null || annotations != null) { - ExperimentalBidirectionalStream.Builder expBidiStreamBuilder = - (ExperimentalBidirectionalStream.Builder) builder; if (annotation != null) { - addRequestAnnotation(expBidiStreamBuilder, annotation); + builder.addRequestAnnotation(annotation); } if (annotations != null) { for (Object o : annotations) { - addRequestAnnotation(expBidiStreamBuilder, o); + builder.addRequestAnnotation(o); } } } @@ -254,7 +255,7 @@ public void cancel(Status reason) { class TransportState extends Http2ClientStreamTransportState { private final Object lock; @GuardedBy("lock") - private Collection pendingData = new ArrayList(); + private final Collection pendingData = new ArrayList<>(); @GuardedBy("lock") private boolean streamReady; @GuardedBy("lock") @@ -270,8 +271,8 @@ class TransportState extends Http2ClientStreamTransportState { public TransportState( int maxMessageSize, StatsTraceContext statsTraceCtx, Object lock, - TransportTracer transportTracer) { - super(maxMessageSize, statsTraceCtx, transportTracer); + TransportTracer transportTracer, CallOptions options) { + super(maxMessageSize, statsTraceCtx, transportTracer, options); this.lock = Preconditions.checkNotNull(lock, "lock"); } @@ -316,7 +317,7 @@ public void bytesRead(int processedBytes) { if (Log.isLoggable(LOG_TAG, Log.VERBOSE)) { Log.v(LOG_TAG, "BidirectionalStream.read"); } - stream.read(ByteBuffer.allocateDirect(READ_BUFFER_CAPACITY)); + stream.read(ByteBuffer.allocateDirect(readBufferSize)); } } @@ -366,39 +367,9 @@ private static boolean isApplicationHeader(String key) { && !TE_HEADER.name().equalsIgnoreCase(key); } - private static void addRequestAnnotation(ExperimentalBidirectionalStream.Builder builder, - Object annotation) { - if (!loadAddRequestAnnotationAttempted) { - synchronized (CronetClientStream.class) { - if (!loadAddRequestAnnotationAttempted) { - try { - addRequestAnnotationMethod = ExperimentalBidirectionalStream.Builder.class - .getMethod("addRequestAnnotation", Object.class); - } catch (NoSuchMethodException e) { - Log.w(LOG_TAG, - "Failed to load method ExperimentalBidirectionalStream.Builder.addRequestAnnotation", - e); - } finally { - loadAddRequestAnnotationAttempted = true; - } - } - } - } - if (addRequestAnnotationMethod != null) { - try { - addRequestAnnotationMethod.invoke(builder, annotation); - } catch (InvocationTargetException e) { - throw new RuntimeException(e.getCause() == null ? e.getTargetException() : e.getCause()); - } catch (IllegalAccessException e) { - Log.w(LOG_TAG, "Failed to add request annotation: " + annotation, e); - } - } - } - private void setGrpcHeaders(BidirectionalStream.Builder builder) { // Psuedo-headers are set by cronet. // All non-pseudo headers must come after pseudo headers. - // TODO(ericgribkoff): remove this and set it on CronetEngine after crbug.com/588204 gets fixed. builder.addHeader(USER_AGENT_KEY.name(), userAgent); builder.addHeader(CONTENT_TYPE_KEY.name(), GrpcUtil.CONTENT_TYPE_GRPC); builder.addHeader("te", GrpcUtil.TE_TRAILERS); @@ -408,10 +379,10 @@ private void setGrpcHeaders(BidirectionalStream.Builder builder) { // String and byte array. byte[][] serializedHeaders = TransportFrameUtil.toHttp2Headers(headers); for (int i = 0; i < serializedHeaders.length; i += 2) { - String key = new String(serializedHeaders[i], Charset.forName("UTF-8")); + String key = new String(serializedHeaders[i], StandardCharsets.UTF_8); // TODO(ericgribkoff): log an error or throw an exception if (isApplicationHeader(key)) { - String value = new String(serializedHeaders[i + 1], Charset.forName("UTF-8")); + String value = new String(serializedHeaders[i + 1], StandardCharsets.UTF_8); builder.addHeader(key, value); } } @@ -466,7 +437,7 @@ public void onResponseHeadersReceived(BidirectionalStream stream, UrlResponseInf Log.v(LOG_TAG, "BidirectionalStream.read"); } reportHeaders(info.getAllHeadersAsList(), false); - stream.read(ByteBuffer.allocateDirect(READ_BUFFER_CAPACITY)); + stream.read(ByteBuffer.allocateDirect(readBufferSize)); } @Override @@ -588,8 +559,8 @@ private void reportHeaders(List> headers, boolean endO byte[][] headerValues = new byte[headerList.size()][]; for (int i = 0; i < headerList.size(); i += 2) { - headerValues[i] = headerList.get(i).getBytes(Charset.forName("UTF-8")); - headerValues[i + 1] = headerList.get(i + 1).getBytes(Charset.forName("UTF-8")); + headerValues[i] = headerList.get(i).getBytes(StandardCharsets.UTF_8); + headerValues[i + 1] = headerList.get(i + 1).getBytes(StandardCharsets.UTF_8); } Metadata metadata = InternalMetadata.newMetadata(TransportFrameUtil.toRawSerializedHeaders(headerValues)); diff --git a/cronet/src/main/java/io/grpc/cronet/CronetClientTransport.java b/cronet/src/main/java/io/grpc/cronet/CronetClientTransport.java index 800d9155854..99eb88737aa 100644 --- a/cronet/src/main/java/io/grpc/cronet/CronetClientTransport.java +++ b/cronet/src/main/java/io/grpc/cronet/CronetClientTransport.java @@ -19,6 +19,7 @@ import com.google.common.base.Preconditions; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.ClientStreamTracer; @@ -33,6 +34,7 @@ import io.grpc.internal.ConnectionClientTransport; import io.grpc.internal.GrpcAttributes; import io.grpc.internal.GrpcUtil; +import io.grpc.internal.SimpleDisconnectError; import io.grpc.internal.StatsTraceContext; import io.grpc.internal.TransportTracer; import java.net.InetSocketAddress; @@ -42,7 +44,6 @@ import java.util.Set; import java.util.concurrent.Executor; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; /** * A cronet-based {@link ConnectionClientTransport} implementation. @@ -56,7 +57,7 @@ class CronetClientTransport implements ConnectionClientTransport { private final Object lock = new Object(); @GuardedBy("lock") private final Set streams = Collections.newSetFromMap( - new IdentityHashMap()); + new IdentityHashMap<>()); private final Executor executor; private final int maxMessageSize; private final boolean alwaysUsePut; @@ -64,6 +65,7 @@ class CronetClientTransport implements ConnectionClientTransport { private Attributes attrs; private final boolean useGetForSafeMethods; private final boolean usePutForIdempotentMethods; + private final StreamBuilderFactory streamFactory; // Indicates the transport is in go-away state: no new streams will be processed, // but existing streams may continue. @GuardedBy("lock") @@ -79,7 +81,6 @@ class CronetClientTransport implements ConnectionClientTransport { @GuardedBy("lock") // Whether this transport has started. private boolean started; - private StreamBuilderFactory streamFactory; CronetClientTransport( StreamBuilderFactory streamFactory, @@ -205,9 +206,9 @@ public void shutdownNow(Status status) { // streams.remove() streamsCopy = new ArrayList<>(streams); } - for (int i = 0; i < streamsCopy.size(); i++) { + for (CronetClientStream cronetClientStream : streamsCopy) { // Avoid deadlock by calling into stream without lock held - streamsCopy.get(i).cancel(status); + cronetClientStream.cancel(status); } stopIfNecessary(); } @@ -229,7 +230,7 @@ private void startGoAway(Status status) { startedGoAway = true; } - listener.transportShutdown(status); + listener.transportShutdown(status, SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); synchronized (lock) { goAway = true; @@ -255,7 +256,7 @@ public InternalLogId getLogId() { */ void stopIfNecessary() { synchronized (lock) { - if (goAway && !stopped && streams.size() == 0) { + if (goAway && !stopped && streams.isEmpty()) { stopped = true; } else { return; diff --git a/cronet/src/main/java/io/grpc/cronet/InternalCronetCallOptions.java b/cronet/src/main/java/io/grpc/cronet/InternalCronetCallOptions.java index e7c4144e63a..9261a0a8f4b 100644 --- a/cronet/src/main/java/io/grpc/cronet/InternalCronetCallOptions.java +++ b/cronet/src/main/java/io/grpc/cronet/InternalCronetCallOptions.java @@ -36,6 +36,18 @@ public static CallOptions withAnnotation(CallOptions callOptions, Object annotat return CronetClientStream.withAnnotation(callOptions, annotation); } + public static CallOptions withReadBufferSize(CallOptions callOptions, int size) { + return callOptions.withOption(CronetClientStream.CRONET_READ_BUFFER_SIZE_KEY, size); + } + + /** + * Returns Cronet read buffer size for gRPC included in the given {@code callOptions}. Read + * buffer can be customized via {@link #withReadBufferSize(CallOptions, int)}. + */ + public static int getReadBufferSize(CallOptions callOptions) { + return callOptions.getOption(CronetClientStream.CRONET_READ_BUFFER_SIZE_KEY); + } + /** * Returns Cronet annotations for gRPC included in the given {@code callOptions}. Annotations * are attached via {@link #withAnnotation(CallOptions, Object)}. diff --git a/cronet/src/main/java/io/grpc/cronet/InternalCronetChannelBuilder.java b/cronet/src/main/java/io/grpc/cronet/InternalCronetChannelBuilder.java index f61685937a8..7e5e610ca67 100644 --- a/cronet/src/main/java/io/grpc/cronet/InternalCronetChannelBuilder.java +++ b/cronet/src/main/java/io/grpc/cronet/InternalCronetChannelBuilder.java @@ -16,7 +16,9 @@ package io.grpc.cronet; +import android.net.Network; import io.grpc.Internal; +import org.checkerframework.checker.nullness.qual.Nullable; /** * Internal {@link CronetChannelBuilder} accessor. This is intended for usage internal to the gRPC @@ -47,7 +49,7 @@ public static void setTrafficStatsTag(CronetChannelBuilder builder, int tag) { * Sets specific UID to use when accounting socket traffic caused by this channel. See {@link * android.net.TrafficStats} for more information. Designed for use when performing an operation * on behalf of another application. Caller must hold {@link - * android.Manifest.permission#MODIFY_NETWORK_ACCOUNTING} permission. By default traffic is + * android.Manifest.permission#UPDATE_DEVICE_STATS} permission. By default traffic is * attributed to UID of caller. * *

NOTE:Setting a UID disallows sharing of sockets with channels with other UIDs, which @@ -58,4 +60,9 @@ public static void setTrafficStatsTag(CronetChannelBuilder builder, int tag) { public static void setTrafficStatsUid(CronetChannelBuilder builder, int uid) { builder.setTrafficStatsUid(uid); } + + /** Sets the network {@link android.net.Network} to use when relying traffic by this channel. */ + public static void bindToNetwork(CronetChannelBuilder builder, @Nullable Network network) { + builder.bindToNetwork(network); + } } diff --git a/cronet/src/test/java/io/grpc/cronet/CronetChannelBuilderTest.java b/cronet/src/test/java/io/grpc/cronet/CronetChannelBuilderTest.java index b31b742577d..be437b3c80b 100644 --- a/cronet/src/test/java/io/grpc/cronet/CronetChannelBuilderTest.java +++ b/cronet/src/test/java/io/grpc/cronet/CronetChannelBuilderTest.java @@ -16,7 +16,9 @@ package io.grpc.cronet; +import static io.grpc.cronet.CronetClientStream.CRONET_READ_BUFFER_SIZE_KEY; import static io.grpc.internal.GrpcUtil.TIMER_SERVICE; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; @@ -35,7 +37,7 @@ import io.grpc.testing.TestMethodDescriptors; import java.net.InetSocketAddress; import java.util.concurrent.ScheduledExecutorService; -import org.chromium.net.ExperimentalCronetEngine; +import org.chromium.net.CronetEngine; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -50,7 +52,7 @@ public final class CronetChannelBuilderTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); - @Mock private ExperimentalCronetEngine mockEngine; + @Mock private CronetEngine mockEngine; @Mock private ChannelLogger channelLogger; private final ClientStreamTracer[] tracers = @@ -92,6 +94,41 @@ public void alwaysUsePut_defaultsToFalse() throws Exception { assertFalse(stream.idempotent); } + @Test + public void channelBuilderReadBufferSize_defaultsTo4Kb() throws Exception { + CronetChannelBuilder builder = CronetChannelBuilder.forAddress("address", 1234, mockEngine); + CronetTransportFactory transportFactory = + (CronetTransportFactory) builder.buildTransportFactory(); + CronetClientTransport transport = + (CronetClientTransport) + transportFactory.newClientTransport( + new InetSocketAddress("localhost", 443), + new ClientTransportOptions(), + channelLogger); + CronetClientStream stream = transport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); + + assertEquals(4 * 1024, stream.readBufferSize); + } + + @Test + public void channelBuilderReadBufferSize_changeReflected() throws Exception { + CronetChannelBuilder builder = CronetChannelBuilder.forAddress("address", 1234, mockEngine); + CronetTransportFactory transportFactory = + (CronetTransportFactory) builder.buildTransportFactory(); + CronetClientTransport transport = + (CronetClientTransport) + transportFactory.newClientTransport( + new InetSocketAddress("localhost", 443), + new ClientTransportOptions(), + channelLogger); + CronetClientStream stream = transport.newStream( + method, new Metadata(), + CallOptions.DEFAULT.withOption(CRONET_READ_BUFFER_SIZE_KEY, 32 * 1024), tracers); + + assertEquals(32 * 1024, stream.readBufferSize); + } + @Test public void scheduledExecutorService_default() { CronetChannelBuilder builder = CronetChannelBuilder.forAddress("address", 1234, mockEngine); diff --git a/cronet/src/test/java/io/grpc/cronet/CronetClientStreamTest.java b/cronet/src/test/java/io/grpc/cronet/CronetClientStreamTest.java index cfbe27a6257..e2b0e0b26ca 100644 --- a/cronet/src/test/java/io/grpc/cronet/CronetClientStreamTest.java +++ b/cronet/src/test/java/io/grpc/cronet/CronetClientStreamTest.java @@ -18,6 +18,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -46,7 +47,7 @@ import java.io.ByteArrayInputStream; import java.nio.Buffer; import java.nio.ByteBuffer; -import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -54,7 +55,6 @@ import java.util.concurrent.Executor; import org.chromium.net.BidirectionalStream; import org.chromium.net.CronetException; -import org.chromium.net.ExperimentalBidirectionalStream; import org.chromium.net.UrlResponseInfo; import org.chromium.net.impl.UrlResponseInfoImpl; import org.junit.Before; @@ -76,17 +76,12 @@ public final class CronetClientStreamTest { @Mock private CronetClientTransport transport; private Metadata metadata = new Metadata(); @Mock private StreamBuilderFactory factory; - @Mock private ExperimentalBidirectionalStream cronetStream; + @Mock private BidirectionalStream cronetStream; @Mock private ClientStreamListener clientListener; - @Mock private ExperimentalBidirectionalStream.Builder builder; + @Mock private BidirectionalStream.Builder builder; private final Object lock = new Object(); private final TransportTracer transportTracer = TransportTracer.getDefaultFactory().create(); - private final Executor executor = new Executor() { - @Override - public void execute(Runnable r) { - r.run(); - } - }; + private final Executor executor = Runnable::run; CronetClientStream clientStream; private MethodDescriptor.Marshaller marshaller = TestMethodDescriptors.voidMarshaller(); @@ -108,7 +103,7 @@ void setStream(CronetClientStream stream) { @Override @SuppressWarnings("GuardedBy") public void run() { - assertTrue(stream != null); + assertNotNull(stream); stream.transportState().start(factory); } } @@ -172,9 +167,9 @@ public void write() { String[] requests = new String[5]; WritableBuffer[] buffers = new WritableBuffer[5]; for (int i = 0; i < 5; ++i) { - requests[i] = new String("request" + String.valueOf(i)); + requests[i] = "request" + i; buffers[i] = allocator.allocate(requests[i].length()); - buffers[i].write(requests[i].getBytes(Charset.forName("UTF-8")), 0, requests[i].length()); + buffers[i].write(requests[i].getBytes(StandardCharsets.UTF_8), 0, requests[i].length()); // The 3rd and 5th writeFrame calls have flush=true. clientStream.abstractClientStreamSink().writeFrame(buffers[i], false, i == 2 || i == 4, 1); } @@ -207,27 +202,19 @@ public void write() { } private static List> responseHeader(String status) { - Map headers = new HashMap(); + Map headers = new HashMap<>(); headers.put(":status", status); headers.put("content-type", "application/grpc"); headers.put("test-key", "test-value"); - List> headerList = new ArrayList>(3); - for (Map.Entry entry : headers.entrySet()) { - headerList.add(entry); - } - return headerList; + return new ArrayList<>(headers.entrySet()); } private static List> trailers(int status) { - Map trailers = new HashMap(); + Map trailers = new HashMap<>(); trailers.put("grpc-status", String.valueOf(status)); trailers.put("content-type", "application/grpc"); trailers.put("test-trailer-key", "test-trailer-value"); - List> trailerList = new ArrayList>(3); - for (Map.Entry entry : trailers.entrySet()) { - trailerList.add(entry); - } - return trailerList; + return new ArrayList<>(trailers.entrySet()); } private static ByteBuffer createMessageFrame(byte[] bytes) { @@ -267,7 +254,7 @@ public void read() { callback.onReadCompleted( cronetStream, info, - createMessageFrame(new String("response1").getBytes(Charset.forName("UTF-8"))), + createMessageFrame("response1".getBytes(StandardCharsets.UTF_8)), false); // Haven't request any message, so no callback is called here. verify(clientListener, times(0)).messagesAvailable(isA(MessageProducer.class)); @@ -297,9 +284,9 @@ public void streamSucceeded() { verify(cronetStream, times(0)).write(isA(ByteBuffer.class), isA(Boolean.class)); // Send the first data frame. CronetWritableBufferAllocator allocator = new CronetWritableBufferAllocator(); - String request = new String("request"); + String request = "request"; WritableBuffer writableBuffer = allocator.allocate(request.length()); - writableBuffer.write(request.getBytes(Charset.forName("UTF-8")), 0, request.length()); + writableBuffer.write(request.getBytes(StandardCharsets.UTF_8), 0, request.length()); clientStream.abstractClientStreamSink().writeFrame(writableBuffer, false, true, 1); ArgumentCaptor bufferCaptor = ArgumentCaptor.forClass(ByteBuffer.class); verify(cronetStream, times(1)).write(bufferCaptor.capture(), isA(Boolean.class)); @@ -318,7 +305,7 @@ public void streamSucceeded() { callback.onReadCompleted( cronetStream, info, - createMessageFrame(new String("response").getBytes(Charset.forName("UTF-8"))), + createMessageFrame("response".getBytes(StandardCharsets.UTF_8)), false); verify(clientListener, times(1)).messagesAvailable(isA(MessageProducer.class)); verify(cronetStream, times(2)).read(isA(ByteBuffer.class)); @@ -681,8 +668,8 @@ public void getUnaryRequest() { true, false); callback.setStream(stream); - ExperimentalBidirectionalStream.Builder getBuilder = - mock(ExperimentalBidirectionalStream.Builder.class); + BidirectionalStream.Builder getBuilder = + mock(BidirectionalStream.Builder.class); when(getFactory.newBidirectionalStreamBuilder( any(String.class), any(BidirectionalStream.Callback.class), any(Executor.class))) .thenReturn(getBuilder); @@ -694,7 +681,7 @@ public void getUnaryRequest() { .newBidirectionalStreamBuilder( isA(String.class), isA(BidirectionalStream.Callback.class), isA(Executor.class)); - byte[] msg = "request".getBytes(Charset.forName("UTF-8")); + byte[] msg = "request".getBytes(StandardCharsets.UTF_8); stream.writeMessage(new ByteArrayInputStream(msg)); // We still haven't built the stream or sent anything. verify(cronetStream, times(0)).write(isA(ByteBuffer.class), isA(Boolean.class)); @@ -738,8 +725,8 @@ public void idempotentMethod_usesHttpPut() { true, true); callback.setStream(stream); - ExperimentalBidirectionalStream.Builder builder = - mock(ExperimentalBidirectionalStream.Builder.class); + BidirectionalStream.Builder builder = + mock(BidirectionalStream.Builder.class); when(factory.newBidirectionalStreamBuilder( any(String.class), any(BidirectionalStream.Callback.class), any(Executor.class))) .thenReturn(builder); @@ -770,8 +757,8 @@ public void alwaysUsePutOption_usesHttpPut() { true, true); callback.setStream(stream); - ExperimentalBidirectionalStream.Builder builder = - mock(ExperimentalBidirectionalStream.Builder.class); + BidirectionalStream.Builder builder = + mock(BidirectionalStream.Builder.class); when(factory.newBidirectionalStreamBuilder( any(String.class), any(BidirectionalStream.Callback.class), any(Executor.class))) .thenReturn(builder); @@ -810,8 +797,8 @@ public void reservedHeadersStripped() { false, false); callback.setStream(stream); - ExperimentalBidirectionalStream.Builder builder = - mock(ExperimentalBidirectionalStream.Builder.class); + BidirectionalStream.Builder builder = + mock(BidirectionalStream.Builder.class); when(factory.newBidirectionalStreamBuilder( any(String.class), any(BidirectionalStream.Callback.class), any(Executor.class))) .thenReturn(builder); diff --git a/cronet/src/test/java/io/grpc/cronet/CronetClientTransportTest.java b/cronet/src/test/java/io/grpc/cronet/CronetClientTransportTest.java index cc18f33aaea..3a79cc0b6a8 100644 --- a/cronet/src/test/java/io/grpc/cronet/CronetClientTransportTest.java +++ b/cronet/src/test/java/io/grpc/cronet/CronetClientTransportTest.java @@ -17,7 +17,7 @@ package io.grpc.cronet; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertNotNull; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -34,6 +34,7 @@ import io.grpc.Status; import io.grpc.cronet.CronetChannelBuilder.StreamBuilderFactory; import io.grpc.internal.ClientStreamListener; +import io.grpc.internal.DisconnectError; import io.grpc.internal.GrpcAttributes; import io.grpc.internal.ManagedClientTransport; import io.grpc.internal.TransportTracer; @@ -71,12 +72,7 @@ public final class CronetClientTransportTest { private MethodDescriptor descriptor = TestMethodDescriptors.voidMethod(); @Mock private ManagedClientTransport.Listener clientTransportListener; @Mock private BidirectionalStream.Builder builder; - private final Executor executor = new Executor() { - @Override - public void execute(Runnable r) { - r.run(); - } - }; + private final Executor executor = Runnable::run; @Before public void setUp() { @@ -96,7 +92,7 @@ public void setUp() { false, false); Runnable callback = transport.start(clientTransportListener); - assertTrue(callback != null); + assertNotNull(callback); callback.run(); verify(clientTransportListener).transportReady(); } @@ -133,7 +129,8 @@ public void shutdownTransport() throws Exception { BidirectionalStream.Callback callback2 = callbackCaptor.getValue(); // Shut down the transport. transportShutdown should be called immediately. transport.shutdown(); - verify(clientTransportListener).transportShutdown(any(Status.class)); + verify(clientTransportListener).transportShutdown(any(Status.class), + any(DisconnectError.class)); // Have two live streams. Transport has not been terminated. verify(clientTransportListener, times(0)).transportTerminated(); diff --git a/documentation/android-binderchannel-status-codes.md b/documentation/android-binderchannel-status-codes.md index dda0220bf8a..fae4ef406af 100644 --- a/documentation/android-binderchannel-status-codes.md +++ b/documentation/android-binderchannel-status-codes.md @@ -23,51 +23,66 @@ Consider the table that follows as an BinderChannel-specific addendum to the “ - 1 + 0 - Server app not installed + Server app not visible. + + bindService() returns false - bindService() returns false +

UNIMPLEMENTED

“The operation is not implemented or is not supported / enabled in this service.” + + Give up - This is an error in the client manifest. + + + + 1 -

UNIMPLEMENTED

“The operation is not implemented or is not supported / enabled in this service.” + + Safer Intents violation. - Direct the user to install/reinstall the server app. + Direct the user to install/reinstall the server app. 2 - Old version of the server app doesn’t declare the target android.app.Service in its manifest. + Server app not installed 3 - Target android.app.Service is disabled + Old version of the server app doesn’t declare the target android.app.Service in its manifest. 4 - The whole server app is disabled + Target android.app.Service is disabled 5 - Server app predates the Android M permissions model and the user must review and approve some newly requested permissions before it can run. + The whole server app is disabled 6 + Server app predates the Android M permissions model and the user must review and approve some newly requested permissions before it can run. + + + + 7 + Target android.app.Service doesn’t recognize grpc binding Intent (old version of server app?) onNullBinding() ServiceConnection callback - 7 + 8 Method not found on the io.grpc.Server (old version of server app?) @@ -75,13 +90,13 @@ Consider the table that follows as an BinderChannel-specific addendum to the “ - 8 + 9 Request cardinality violation (old version of server app expects unary rather than streaming, say) - 9 + 10 Old version of the server app exposes target android.app.Service but doesn’t android:export it. @@ -90,9 +105,11 @@ Consider the table that follows as an BinderChannel-specific addendum to the “

PERMISSION_DENIED

“The caller does not have permission to execute the specified operation …” + Direct the user to update the server app in the hopes that a newer version fixes this error in its manifest. + - 10 + 11 Target android.app.Service requires an <android:permission> that client doesn’t hold. @@ -100,7 +117,7 @@ Consider the table that follows as an BinderChannel-specific addendum to the “ - 11 + 12 Violations of the security policy for miscellaneous Android features like android:isolatedProcess, android:externalService, android:singleUser, instant apps, BIND_TREAT_LIKE_ACTIVITY, etc, @@ -108,7 +125,7 @@ Consider the table that follows as an BinderChannel-specific addendum to the “ - 12 + 13 Calling Android UID not allowed by ServerSecurityPolicy @@ -116,13 +133,13 @@ Consider the table that follows as an BinderChannel-specific addendum to the “ - 13 + 14 Server Android UID not allowed by client’s SecurityPolicy - 14 + 15 Server process crashed or killed with request in flight. @@ -144,7 +161,7 @@ Consider the table that follows as an BinderChannel-specific addendum to the “ - 15 + 16 Server app is currently being upgraded to a new version @@ -152,13 +169,13 @@ Consider the table that follows as an BinderChannel-specific addendum to the “ - 16 + 17 The whole server app or the target android.app.Service was disabled - 17 + 18 Binder transaction buffer overflow @@ -166,7 +183,7 @@ Consider the table that follows as an BinderChannel-specific addendum to the “ - 18 + 19 Source Context for bindService() is destroyed with a request in flight @@ -178,11 +195,11 @@ Consider the table that follows as an BinderChannel-specific addendum to the “ Give up for now.

-(Re. 18: The caller can try again later when the user opens the source Activity or restarts the source Service) +(Re. 19: The caller can try again later when the user opens the source Activity or restarts the source Service) - 19 + 20 Client application cancelled the request @@ -190,7 +207,7 @@ Consider the table that follows as an BinderChannel-specific addendum to the “ - 19 + 21 Bug in Android itself or the way the io.grpc.binder transport uses it. @@ -208,7 +225,7 @@ Consider the table that follows as an BinderChannel-specific addendum to the “ - 20 + 22 Flow-control protocol violation @@ -216,7 +233,7 @@ Consider the table that follows as an BinderChannel-specific addendum to the “ - 21 + 23 Can’t parse request/response proto @@ -226,27 +243,27 @@ Consider the table that follows as an BinderChannel-specific addendum to the “ ### Ambiguity -We say a status code is ambiguous if it maps to two error cases that reasonable clients want to handle differently. For instance, a client may have good reasons to handle error cases 9 and 10 above differently. But they can’t do so based on status code alone because those error cases map to the same one. +We say a status code is ambiguous if it maps to two error cases that reasonable clients want to handle differently. For instance, a client may have good reasons to handle error cases 10 and 11 above differently. But they can’t do so based on status code alone because those error cases map to the same one. -In contrast, for example, even though error case 18 and 19 both map to the status code (`CANCELLED`), they are not ambiguous because we see no reason that clients would want to distinguish them. In both cases, clients will simply give up on the request. +In contrast, for example, even though error case 19 and 20 both map to the status code (`CANCELLED`), they are not ambiguous because we see no reason that clients would want to distinguish them. In both cases, clients will simply give up on the request. #### Ambiguity of PERMISSION_DENIED and Mitigations The mapping above has only one apparently ambiguous status code: `PERMISSION_DENIED`. However, this isn’t so bad because of the following: -The use of ``s for inter-app IPC access control (error case 10) is uncommon. Instead, we recommend that server apps only allow IPC from a limited set of client apps known in advance and identified by signature. +The use of ``s for inter-app IPC access control (error case 11) is uncommon. Instead, we recommend that server apps only allow IPC from a limited set of client apps known in advance and identified by signature. -However, there may be gRPC server apps that want to use custom <android:permission>’s to let the end user decide which arbitrary other apps can make use of its gRPC services. In that case, clients should preempt error case 10 simply by [checking whether they hold the required permissions](https://developer.android.com/training/permissions/requesting) before sending a request. +However, there may be gRPC server apps that want to use custom <android:permission>’s to let the end user decide which arbitrary other apps can make use of its gRPC services. In that case, clients should preempt error case 11 simply by [checking whether they hold the required permissions](https://developer.android.com/training/permissions/requesting) before sending a request. -Server apps can avoid error case 9 by never reusing an android.app.Service as a gRPC host if it has ever been android:exported=false in some previous app version. Instead they should simply create a new android.app.Service for this purpose. +Server apps can avoid error case 10 by never reusing an android.app.Service as a gRPC host if it has ever been android:exported=false in some previous app version. Instead they should simply create a new android.app.Service for this purpose. -Only error cases 11 - 13 remain, making `PERMISSION_DENIED` unambiguous for the purpose of error handling. Reasonable client apps can handle it in a generic way by displaying an error message and/or proceeding with degraded functionality. +Only error cases 12 - 14 remain, making `PERMISSION_DENIED` unambiguous for the purpose of error handling. Reasonable client apps can handle it in a generic way by displaying an error message and/or proceeding with degraded functionality. #### Non-Ambiguity of UNIMPLEMENTED -The `UNIMPLEMENTED` status code corresponds to quite a few different problems with the server app: It’s either not installed, too old, or disabled in whole or in part. Despite the diversity of underlying error cases, we believe most client apps will and should handle `UNIMPLEMENTED` in the same way: by sending the user to the app store to (re)install the server app. Reinstalling might be overkill for the disabled cases but most end users don't know what it means to enable/disable an app and there’s neither enough space in a UI dialog nor enough reader attention to explain it. Reinstalling is something users likely already understand and very likely to cure problems 1-8. +The `UNIMPLEMENTED` status code corresponds to quite a few different problems with the server app: It’s either not installed, too old, misconfigured, or disabled in whole or in part. Despite the diversity of underlying error cases, we believe most client apps will and should handle `UNIMPLEMENTED` in the same way: by sending the user to the app store to (re)install the server app. Reinstalling might be overkill for the disabled cases but most end users don't know what it means to enable/disable an app and there’s neither enough space in a UI dialog nor enough reader attention to explain it. Reinstalling is something users likely already understand and likely to cure problems 0-9 (once a fixed version of the server is available). ## Detailed Discussion of Binder Failure Modes @@ -315,6 +332,8 @@ According to a review of the AOSP source code, there are in fact several cases: 1. The target package is not installed 2. The target package is installed but does not declare the target Service in its manifest. 3. The target package requests dangerous permissions but targets sdk <= M and therefore requires a permissions review, but the caller is not running in the foreground and so it would be inappropriate to launch the review UI. +4. The target package is not visible to the client due to [Android 11 package visibility rules](https://developer.android.com/training/package-visibility). +5. One of the new [Safer Intents](https://developer.android.com/about/versions/15/behavior-changes-15#safer-intents) rules is violated. Most commonly, the bind `Intent` specifies a `ComponentName` explicitly but doesn't match any of its <intent-filter>s. Status code mapping: **UNIMPLEMENTED** @@ -322,6 +341,7 @@ Status code mapping: **UNIMPLEMENTED** Unfortunately `UNIMPLEMENTED` doesn’t capture (3) but none of the other canonical status codes do either and we expect this case to be extremely rare. +(4) and (5) are intentially indistinguishable from (1) by Android design so we can't handle them differently. However, as an error in its own manifest, (4) isn't something a reasonable client would handle at runtime anyway. (5) is an error in the server manifest and so, just like the other cases, the best practice for handling it is to send the user to the app store in the hope that the server can be updated with a fix. ### bindService() throws SecurityException @@ -382,4 +402,4 @@ Android’s Parcel class exposes a mechanism for marshalling certain types of `R The calling Activity or Service Context might be destroyed with a gRPC request in flight. Apps should cease operations when the Context hosting it goes away and this includes cancelling any outstanding RPCs. -Status code mapping: **CANCELLED** \ No newline at end of file +Status code mapping: **CANCELLED** diff --git a/documentation/server-reflection-tutorial.md b/documentation/server-reflection-tutorial.md index 5fad5a22333..f452174738a 100644 --- a/documentation/server-reflection-tutorial.md +++ b/documentation/server-reflection-tutorial.md @@ -10,9 +10,9 @@ proto-based services. ## Enable Server Reflection gRPC-Java Server Reflection is implemented by -`io.grpc.protobuf.services.ProtoReflectionService` in the `grpc-services` +`io.grpc.protobuf.services.ProtoReflectionServiceV1` in the `grpc-services` package. To enable server reflection, you need to add the -`ProtoReflectionService` to your gRPC server. +`ProtoReflectionServiceV1` to your gRPC server. For example, to enable server reflection in `examples/src/main/java/io/grpc/examples/helloworld/HelloWorldServer.java`, we @@ -28,14 +28,14 @@ need to make the following changes: + compile "io.grpc:grpc-services:${grpcVersion}" compile "io.grpc:grpc-stub:${grpcVersion}" - testCompile "junit:junit:4.12" + testCompile "junit:junit:4.13.2" --- a/examples/src/main/java/io/grpc/examples/helloworld/HelloWorldServer.java +++ b/examples/src/main/java/io/grpc/examples/helloworld/HelloWorldServer.java @@ -33,6 +33,7 @@ package io.grpc.examples.helloworld; import io.grpc.Server; import io.grpc.ServerBuilder; -+import io.grpc.protobuf.services.ProtoReflectionService; ++import io.grpc.protobuf.services.ProtoReflectionServiceV1; import io.grpc.stub.StreamObserver; import java.io.IOException; import java.util.logging.Logger; @@ -43,7 +43,7 @@ need to make the following changes: int port = 50051; server = ServerBuilder.forPort(port) .addService(new GreeterImpl()) -+ .addService(ProtoReflectionService.newInstance()) ++ .addService(ProtoReflectionServiceV1.newInstance()) .build() .start(); logger.info("Server started, listening on " + port); diff --git a/examples/.bazelrc b/examples/.bazelrc index 554440cfe3d..53485cb9743 100644 --- a/examples/.bazelrc +++ b/examples/.bazelrc @@ -1 +1 @@ -build --cxxopt=-std=c++14 --host_cxxopt=-std=c++14 +build --cxxopt=-std=c++17 --host_cxxopt=-std=c++17 diff --git a/examples/BUILD.bazel b/examples/BUILD.bazel index 563fa07ce84..e3ef8c5ac5d 100644 --- a/examples/BUILD.bazel +++ b/examples/BUILD.bazel @@ -1,5 +1,8 @@ -load("@rules_proto//proto:defs.bzl", "proto_library") +load("@com_google_protobuf//bazel:java_proto_library.bzl", "java_proto_library") +load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library") load("@io_grpc_grpc_java//:java_grpc_library.bzl", "java_grpc_library") +load("@rules_java//java:java_binary.bzl", "java_binary") +load("@rules_java//java:java_library.bzl", "java_library") proto_library( name = "helloworld_proto", @@ -97,8 +100,8 @@ java_library( "@io_grpc_grpc_java//services:health", "@io_grpc_grpc_java//services:healthlb", "@io_grpc_grpc_java//stub", - "@io_grpc_grpc_proto//:health_proto", "@io_grpc_grpc_proto//:health_java_proto", + "@io_grpc_grpc_proto//:health_proto", "@maven//:com_google_api_grpc_proto_google_common_protos", "@maven//:com_google_code_findbugs_jsr305", "@maven//:com_google_code_gson_gson", @@ -248,3 +251,30 @@ java_grpc_library( deps = ["@io_grpc_grpc_proto//:health_java_proto"], ) +java_binary( + name = "retrying-hello-world-client", + testonly = 1, + main_class = "io.grpc.examples.retrying.RetryingHelloWorldClient", + runtime_deps = [ + ":examples", + ], +) + +java_binary( + name = "retrying-hello-world-server", + testonly = 1, + main_class = "io.grpc.examples.retrying.RetryingHelloWorldServer", + runtime_deps = [ + ":examples", + ], +) + +# grpc-xds requires some WORKSPACE/MODULE deps that aren't needed by the other +# targets. This just makes sure the example WORKSPACE/MODULE works with +# grpc-xds. +java_library( + name = "test_grpc_xds_compiles", + runtime_deps = [ + "@io_grpc_grpc_java//xds", + ], +) diff --git a/examples/MODULE.bazel b/examples/MODULE.bazel new file mode 100644 index 00000000000..105fcecaafe --- /dev/null +++ b/examples/MODULE.bazel @@ -0,0 +1,32 @@ +bazel_dep(name = "grpc-java", version = "1.82.0-SNAPSHOT", repo_name = "io_grpc_grpc_java") # CURRENT_GRPC_VERSION +bazel_dep(name = "rules_java", version = "9.3.0") +bazel_dep(name = "grpc-proto", version = "0.0.0-20240627-ec30f58", repo_name = "io_grpc_grpc_proto") +bazel_dep(name = "protobuf", version = "33.1", repo_name = "com_google_protobuf") +bazel_dep(name = "rules_jvm_external", version = "6.0") + +# Do not use this override in your own MODULE.bazel. It is unnecessary when +# using a version from BCR. Be aware the gRPC Java team does not update the +# BCR for new releases, so you may need to create a PR for the BCR to add the +# version. To not use the BCR, you could use: +# +# git_override( +# module_name = "grpc-java", +# remote = "https://github.com/grpc/grpc-java.git", +# tag = "v", +# ) +local_path_override( + module_name = "grpc-java", + path = "..", +) + +maven = use_extension("@rules_jvm_external//:extensions.bzl", "maven") +use_repo(maven, "maven") + +maven.install( + artifacts = [ + "com.google.api.grpc:grpc-google-cloud-pubsub-v1:0.1.24", + "com.google.api.grpc:proto-google-cloud-pubsub-v1:0.1.24", + ], + repositories = ["https://repo.maven.apache.org/maven2/"], + strict_visibility = True, +) diff --git a/examples/README.md b/examples/README.md index b51d560d7bb..91fde2c045c 100644 --- a/examples/README.md +++ b/examples/README.md @@ -27,114 +27,32 @@ before trying out the examples. - [Json serialization](src/main/java/io/grpc/examples/advanced) --

- Hedging - - The [hedging example](src/main/java/io/grpc/examples/hedging) demonstrates that enabling hedging - can reduce tail latency. (Users should note that enabling hedging may introduce other overhead; - and in some scenarios, such as when some server resource gets exhausted for a period of time and - almost every RPC during that time has high latency or fails, hedging may make things worse. - Setting a throttle in the service config is recommended to protect the server from too many - inappropriate retry or hedging requests.) - - The server and the client in the example are basically the same as those in the - [hello world](src/main/java/io/grpc/examples/helloworld) example, except that the server mimics a - long tail of latency, and the client sends 2000 requests and can turn on and off hedging. - - To mimic the latency, the server randomly delays the RPC handling by 2 seconds at 10% chance, 5 - seconds at 5% chance, and 10 seconds at 1% chance. - - When running the client enabling the following hedging policy - - ```json - "hedgingPolicy": { - "maxAttempts": 3, - "hedgingDelay": "1s" - } - ``` - Then the latency summary in the client log is like the following - - ```text - Total RPCs sent: 2,000. Total RPCs failed: 0 - [Hedging enabled] - ======================== - 50% latency: 0ms - 90% latency: 6ms - 95% latency: 1,003ms - 99% latency: 2,002ms - 99.9% latency: 2,011ms - Max latency: 5,272ms - ======================== - ``` - - See [the section below](#to-build-the-examples) for how to build and run the example. The - executables for the server and the client are `hedging-hello-world-server` and - `hedging-hello-world-client`. - - To disable hedging, set environment variable `DISABLE_HEDGING_IN_HEDGING_EXAMPLE=true` before - running the client. That produces a latency summary in the client log like the following - - ```text - Total RPCs sent: 2,000. Total RPCs failed: 0 - [Hedging disabled] - ======================== - 50% latency: 0ms - 90% latency: 2,002ms - 95% latency: 5,002ms - 99% latency: 10,004ms - 99.9% latency: 10,007ms - Max latency: 10,007ms - ======================== - ``` - -
- --
- Retrying - - The [retrying example](src/main/java/io/grpc/examples/retrying) provides a HelloWorld gRPC client & - server which demos the effect of client retry policy configured on the [ManagedChannel]( - ../api/src/main/java/io/grpc/ManagedChannel.java) via [gRPC ServiceConfig]( - https://github.com/grpc/grpc/blob/master/doc/service_config.md). Retry policy implementation & - configuration details are outlined in the [proposal](https://github.com/grpc/proposal/blob/master/A6-client-retries.md). - - This retrying example is very similar to the [hedging example](src/main/java/io/grpc/examples/hedging) in its setup. - The [RetryingHelloWorldServer](src/main/java/io/grpc/examples/retrying/RetryingHelloWorldServer.java) responds with - a status UNAVAILABLE error response to a specified percentage of requests to simulate server resource exhaustion and - general flakiness. The [RetryingHelloWorldClient](src/main/java/io/grpc/examples/retrying/RetryingHelloWorldClient.java) makes - a number of sequential requests to the server, several of which will be retried depending on the configured policy in - [retrying_service_config.json](src/main/resources/io/grpc/examples/retrying/retrying_service_config.json). Although - the requests are blocking unary calls for simplicity, these could easily be changed to future unary calls in order to - test the result of request concurrency with retry policy enabled. - - One can experiment with the [RetryingHelloWorldServer](src/main/java/io/grpc/examples/retrying/RetryingHelloWorldServer.java) - failure conditions to simulate server throttling, as well as alter policy values in the [retrying_service_config.json]( - src/main/resources/io/grpc/examples/retrying/retrying_service_config.json) to see their effects. To disable retrying - entirely, set environment variable `DISABLE_RETRYING_IN_RETRYING_EXAMPLE=true` before running the client. - Disabling the retry policy should produce many more failed gRPC calls as seen in the output log. - - See [the section below](#to-build-the-examples) for how to build and run the example. The - executables for the server and the client are `retrying-hello-world-server` and - `retrying-hello-world-client`. - -
- --
- Health Service - - The [health service example](src/main/java/io/grpc/examples/healthservice) - provides a HelloWorld gRPC server that doesn't like short names along with a - health service. It also provides a client application which makes HelloWorld - calls and checks the health status. - - The client application also shows how the round robin load balancer can - utilize the health status to avoid making calls to a service that is - not actively serving. -
+- [Hedging example](src/main/java/io/grpc/examples/hedging) +- [Retrying example](src/main/java/io/grpc/examples/retrying) + +- [Health Service example](src/main/java/io/grpc/examples/healthservice) - [Keep Alive](src/main/java/io/grpc/examples/keepalive) +- [Cancellation](src/main/java/io/grpc/examples/cancellation) + +- [Custom Load Balance](src/main/java/io/grpc/examples/customloadbalance) + +- [Deadline](src/main/java/io/grpc/examples/deadline) + +- [Error Details](src/main/java/io/grpc/examples/errordetails) + +- [GRPC Proxy](src/main/java/io/grpc/examples/grpcproxy) + +- [Load Balance](src/main/java/io/grpc/examples/loadbalance) + +- [Multiplex](src/main/java/io/grpc/examples/multiplex) + +- [Name Resolve](src/main/java/io/grpc/examples/nameresolve) + +- [Pre-Serialized Messages](src/main/java/io/grpc/examples/preserialized) + ### To build the examples 1. **[Install gRPC Java library SNAPSHOT locally, including code generation plugin](../COMPILING.md) (Only need this step for non-released versions, e.g. master HEAD).** @@ -235,9 +153,9 @@ Example bugs not caught by mocked stub tests include: For testing a gRPC client, create the client with a real stub using an -[InProcessChannel](../core/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java), +[InProcessChannel](../inprocess/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java), and test it against an -[InProcessServer](../core/src/main/java/io/grpc/inprocess/InProcessServerBuilder.java) +[InProcessServer](../inprocess/src/main/java/io/grpc/inprocess/InProcessServerBuilder.java) with a mock/fake service implementation. For testing a gRPC server, create the server as an InProcessServer, diff --git a/examples/WORKSPACE b/examples/WORKSPACE index 7291584b3fc..1387cc4cf12 100644 --- a/examples/WORKSPACE +++ b/examples/WORKSPACE @@ -14,6 +14,48 @@ local_repository( path = "..", ) +load("@io_grpc_grpc_java//:repositories.bzl", "IO_GRPC_GRPC_JAVA_ARTIFACTS", "IO_GRPC_GRPC_JAVA_OVERRIDE_TARGETS", "grpc_java_repositories") + +grpc_java_repositories() + +http_archive( + name = "rules_java", + sha256 = "47632cc506c858011853073449801d648e10483d4b50e080ec2549a4b2398960", + urls = [ + "https://github.com/bazelbuild/rules_java/releases/download/8.15.2/rules_java-8.15.2.tar.gz", + ], +) + +# Protobuf now requires C++14 or higher, which requires Bazel configuration +# outside the WORKSPACE. See .bazelrc in this directory. +load("@com_google_protobuf//:protobuf_deps.bzl", "PROTOBUF_MAVEN_ARTIFACTS", "protobuf_deps") + +protobuf_deps() + +load("@rules_java//java:rules_java_deps.bzl", "compatibility_proxy_repo", "rules_java_dependencies") + +rules_java_dependencies() + +load("@bazel_features//:deps.bzl", "bazel_features_deps") + +bazel_features_deps() + +compatibility_proxy_repo() + +load("@bazel_jar_jar//:jar_jar.bzl", "jar_jar_repositories") + +jar_jar_repositories() + +load("@rules_python//python:repositories.bzl", "py_repositories") + +py_repositories() + +load("@com_google_googleapis//:repository_rules.bzl", "switched_rules_by_language") + +switched_rules_by_language( + name = "com_google_googleapis_imports", +) + http_archive( name = "rules_jvm_external", sha256 = "d31e369b854322ca5098ea12c69d7175ded971435e55c18dd9dd5f29cc5249ac", @@ -22,31 +64,14 @@ http_archive( ) load("@rules_jvm_external//:defs.bzl", "maven_install") -load("@io_grpc_grpc_java//:repositories.bzl", "IO_GRPC_GRPC_JAVA_ARTIFACTS") -load("@io_grpc_grpc_java//:repositories.bzl", "IO_GRPC_GRPC_JAVA_OVERRIDE_TARGETS") -load("@io_grpc_grpc_java//:repositories.bzl", "grpc_java_repositories") - -grpc_java_repositories() - -# Protobuf now requires C++14 or higher, which requires Bazel configuration -# outside the WORKSPACE. See .bazelrc in this directory. -load("@com_google_protobuf//:protobuf_deps.bzl", "PROTOBUF_MAVEN_ARTIFACTS") -load("@com_google_protobuf//:protobuf_deps.bzl", "protobuf_deps") - -protobuf_deps() maven_install( artifacts = [ "com.google.api.grpc:grpc-google-cloud-pubsub-v1:0.1.24", "com.google.api.grpc:proto-google-cloud-pubsub-v1:0.1.24", ] + IO_GRPC_GRPC_JAVA_ARTIFACTS + PROTOBUF_MAVEN_ARTIFACTS, - generate_compat_repositories = True, override_targets = IO_GRPC_GRPC_JAVA_OVERRIDE_TARGETS, repositories = [ "https://repo.maven.apache.org/maven2/", ], ) - -load("@maven//:compat.bzl", "compat_repositories") - -compat_repositories() diff --git a/examples/WORKSPACE.bzlmod b/examples/WORKSPACE.bzlmod new file mode 100644 index 00000000000..4ecb9e5d985 --- /dev/null +++ b/examples/WORKSPACE.bzlmod @@ -0,0 +1 @@ +# When using bzlmod this makes sure nothing from the legacy WORKSPACE is loaded diff --git a/examples/android/clientcache/app/build.gradle b/examples/android/clientcache/app/build.gradle index 5bab4c97fbc..0219e73ff89 100644 --- a/examples/android/clientcache/app/build.gradle +++ b/examples/android/clientcache/app/build.gradle @@ -10,9 +10,8 @@ android { defaultConfig { applicationId "io.grpc.clientcacheexample" - minSdkVersion 21 + minSdkVersion 23 targetSdkVersion 33 - multiDexEnabled true versionCode 1 versionName "1.0" testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" @@ -34,7 +33,7 @@ android { protobuf { protoc { artifact = 'com.google.protobuf:protoc:3.25.1' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.63.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -54,12 +53,11 @@ dependencies { implementation 'androidx.appcompat:appcompat:1.0.0' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.63.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.63.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.63.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'org.apache.tomcat:annotations-api:6.0.53' + implementation 'io.grpc:grpc-okhttp:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION testImplementation 'junit:junit:4.13.2' - testImplementation 'com.google.truth:truth:1.1.5' - testImplementation 'io.grpc:grpc-testing:1.63.0-SNAPSHOT' // CURRENT_GRPC_VERSION + testImplementation 'com.google.truth:truth:1.4.5' + testImplementation 'io.grpc:grpc-testing:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION } diff --git a/examples/android/clientcache/build.gradle b/examples/android/clientcache/build.gradle index 67d25905bbc..6db6a9bced1 100644 --- a/examples/android/clientcache/build.gradle +++ b/examples/android/clientcache/build.gradle @@ -7,7 +7,7 @@ buildscript { } dependencies { classpath 'com.android.tools.build:gradle:7.4.0' - classpath "com.google.protobuf:protobuf-gradle-plugin:0.9.4" + classpath "com.google.protobuf:protobuf-gradle-plugin:0.9.5" // NOTE: Do not place your application dependencies here; they belong // in the individual module build.gradle files diff --git a/examples/android/clientcache/settings.gradle b/examples/android/clientcache/settings.gradle index e7b4def49cb..6208d70e838 100644 --- a/examples/android/clientcache/settings.gradle +++ b/examples/android/clientcache/settings.gradle @@ -1 +1,17 @@ +pluginManagement { + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. + buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") + } + } +} + include ':app' diff --git a/examples/android/helloworld/app/build.gradle b/examples/android/helloworld/app/build.gradle index 294aa19590d..1e81415e483 100644 --- a/examples/android/helloworld/app/build.gradle +++ b/examples/android/helloworld/app/build.gradle @@ -10,7 +10,7 @@ android { defaultConfig { applicationId "io.grpc.helloworldexample" - minSdkVersion 21 + minSdkVersion 23 targetSdkVersion 33 versionCode 1 versionName "1.0" @@ -32,7 +32,7 @@ android { protobuf { protoc { artifact = 'com.google.protobuf:protoc:3.25.1' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.63.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -52,8 +52,7 @@ dependencies { implementation 'androidx.appcompat:appcompat:1.0.0' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.63.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.63.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.63.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'org.apache.tomcat:annotations-api:6.0.53' + implementation 'io.grpc:grpc-okhttp:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION } diff --git a/examples/android/helloworld/build.gradle b/examples/android/helloworld/build.gradle index 67d25905bbc..6db6a9bced1 100644 --- a/examples/android/helloworld/build.gradle +++ b/examples/android/helloworld/build.gradle @@ -7,7 +7,7 @@ buildscript { } dependencies { classpath 'com.android.tools.build:gradle:7.4.0' - classpath "com.google.protobuf:protobuf-gradle-plugin:0.9.4" + classpath "com.google.protobuf:protobuf-gradle-plugin:0.9.5" // NOTE: Do not place your application dependencies here; they belong // in the individual module build.gradle files diff --git a/examples/android/helloworld/settings.gradle b/examples/android/helloworld/settings.gradle index e7b4def49cb..6208d70e838 100644 --- a/examples/android/helloworld/settings.gradle +++ b/examples/android/helloworld/settings.gradle @@ -1 +1,17 @@ +pluginManagement { + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. + buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") + } + } +} + include ':app' diff --git a/examples/android/routeguide/app/build.gradle b/examples/android/routeguide/app/build.gradle index a7b9375f13c..7152add7858 100644 --- a/examples/android/routeguide/app/build.gradle +++ b/examples/android/routeguide/app/build.gradle @@ -10,7 +10,7 @@ android { defaultConfig { applicationId "io.grpc.routeguideexample" - minSdkVersion 21 + minSdkVersion 23 targetSdkVersion 33 versionCode 1 versionName "1.0" @@ -32,7 +32,7 @@ android { protobuf { protoc { artifact = 'com.google.protobuf:protoc:3.25.1' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.63.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -52,8 +52,7 @@ dependencies { implementation 'androidx.appcompat:appcompat:1.0.0' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.63.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.63.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.63.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'org.apache.tomcat:annotations-api:6.0.53' + implementation 'io.grpc:grpc-okhttp:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION } diff --git a/examples/android/routeguide/build.gradle b/examples/android/routeguide/build.gradle index fd058a5d68e..8fc1d293228 100644 --- a/examples/android/routeguide/build.gradle +++ b/examples/android/routeguide/build.gradle @@ -7,7 +7,7 @@ buildscript { } dependencies { classpath 'com.android.tools.build:gradle:7.4.0' - classpath "com.google.protobuf:protobuf-gradle-plugin:0.9.4" + classpath "com.google.protobuf:protobuf-gradle-plugin:0.9.5" // NOTE: Do not place your application dependencies here; they belong // in the individual module build.gradle files diff --git a/examples/android/routeguide/settings.gradle b/examples/android/routeguide/settings.gradle index e7b4def49cb..6208d70e838 100644 --- a/examples/android/routeguide/settings.gradle +++ b/examples/android/routeguide/settings.gradle @@ -1 +1,17 @@ +pluginManagement { + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. + buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") + } + } +} + include ':app' diff --git a/examples/android/strictmode/app/build.gradle b/examples/android/strictmode/app/build.gradle index 23b8919a05c..cc54d274a29 100644 --- a/examples/android/strictmode/app/build.gradle +++ b/examples/android/strictmode/app/build.gradle @@ -33,7 +33,7 @@ android { protobuf { protoc { artifact = 'com.google.protobuf:protoc:3.25.1' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.63.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -53,8 +53,7 @@ dependencies { implementation 'androidx.appcompat:appcompat:1.0.0' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.63.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.63.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.63.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'org.apache.tomcat:annotations-api:6.0.53' + implementation 'io.grpc:grpc-okhttp:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION } diff --git a/examples/android/strictmode/build.gradle b/examples/android/strictmode/build.gradle index 67d25905bbc..6db6a9bced1 100644 --- a/examples/android/strictmode/build.gradle +++ b/examples/android/strictmode/build.gradle @@ -7,7 +7,7 @@ buildscript { } dependencies { classpath 'com.android.tools.build:gradle:7.4.0' - classpath "com.google.protobuf:protobuf-gradle-plugin:0.9.4" + classpath "com.google.protobuf:protobuf-gradle-plugin:0.9.5" // NOTE: Do not place your application dependencies here; they belong // in the individual module build.gradle files diff --git a/examples/android/strictmode/settings.gradle b/examples/android/strictmode/settings.gradle index e7b4def49cb..6208d70e838 100644 --- a/examples/android/strictmode/settings.gradle +++ b/examples/android/strictmode/settings.gradle @@ -1 +1,17 @@ +pluginManagement { + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. + buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") + } + } +} + include ':app' diff --git a/examples/build.gradle b/examples/build.gradle index b0d815c3fe9..0ad62bb9ef0 100644 --- a/examples/build.gradle +++ b/examples/build.gradle @@ -1,14 +1,12 @@ plugins { // Provide convenience executables for trying out the examples. id 'application' - id 'com.google.protobuf' version '0.9.4' + id 'com.google.protobuf' version '0.9.5' // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' } repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" } mavenCentral() mavenLocal() } @@ -23,15 +21,16 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.63.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protobufVersion = '3.25.1' +def grpcVersion = '1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def protobufVersion = '3.25.8' def protocVersion = protobufVersion dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" + // Even though client-side won't call grpc-services directly, it needs the + // dependency to enable the health-aware round_robin implementation implementation "io.grpc:grpc-services:${grpcVersion}" implementation "io.grpc:grpc-stub:${grpcVersion}" - compileOnly "org.apache.tomcat:annotations-api:6.0.53" // examples/advanced need this for JsonFormat implementation "com.google.protobuf:protobuf-java-util:${protobufVersion}" @@ -54,15 +53,14 @@ protobuf { } } -// Inform IDEs like IntelliJ IDEA, Eclipse or NetBeans about the generated code. -sourceSets { - main { - java { - srcDirs 'build/generated/source/proto/main/grpc' - srcDirs 'build/generated/source/proto/main/java' - } - } -} +// gRPC uses java.util.ServiceLoader, which reads class names from +// META-INF/services in jars. If you package your application as a "fat" jar +// that includes dependencies, you need to make sure the packaging tool +// concatenates duplicate files in META-INF/services. +// +// For the Shadow Gradle Plugin, use call mergeServiceFiles() within the +// shadowJar task. +// https://gradleup.com/shadow/configuration/merging/#merging-service-descriptor-files startScripts.enabled = false @@ -82,7 +80,9 @@ def createStartScripts(String mainClassName) { application { applicationDistribution.into('bin') { from(newTask) - fileMode = 0755 + filePermissions { + unix(0755) + } } } } @@ -107,6 +107,7 @@ createStartScripts('io.grpc.examples.keepalive.KeepAliveClient') createStartScripts('io.grpc.examples.keepalive.KeepAliveServer') createStartScripts('io.grpc.examples.loadbalance.LoadBalanceClient') createStartScripts('io.grpc.examples.loadbalance.LoadBalanceServer') +createStartScripts('io.grpc.examples.manualflowcontrol.BidiBlockingClient') createStartScripts('io.grpc.examples.manualflowcontrol.ManualFlowControlClient') createStartScripts('io.grpc.examples.manualflowcontrol.ManualFlowControlServer') createStartScripts('io.grpc.examples.multiplex.MultiplexingServer') diff --git a/examples/example-alts/BUILD.bazel b/examples/example-alts/BUILD.bazel index 0404dcccf81..4d66accfc19 100644 --- a/examples/example-alts/BUILD.bazel +++ b/examples/example-alts/BUILD.bazel @@ -1,5 +1,8 @@ -load("@rules_proto//proto:defs.bzl", "proto_library") +load("@com_google_protobuf//bazel:java_proto_library.bzl", "java_proto_library") +load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library") load("@io_grpc_grpc_java//:java_grpc_library.bzl", "java_grpc_library") +load("@rules_java//java:java_binary.bzl", "java_binary") +load("@rules_java//java:java_library.bzl", "java_library") proto_library( name = "helloworld_proto", diff --git a/examples/example-alts/example-alts/README.md b/examples/example-alts/README.md similarity index 100% rename from examples/example-alts/example-alts/README.md rename to examples/example-alts/README.md diff --git a/examples/example-alts/build.gradle b/examples/example-alts/build.gradle index 2844ec9e5c4..3fea622b923 100644 --- a/examples/example-alts/build.gradle +++ b/examples/example-alts/build.gradle @@ -1,15 +1,12 @@ plugins { // Provide convenience executables for trying out the examples. id 'application' - id 'com.google.protobuf' version '0.9.4' + id 'com.google.protobuf' version '0.9.5' // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' } repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" - } mavenCentral() mavenLocal() } @@ -24,13 +21,12 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.63.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protocVersion = '3.25.1' +def grpcVersion = '1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def protocVersion = '3.25.8' dependencies { // grpc-alts transitively depends on grpc-netty-shaded, grpc-protobuf, and grpc-stub implementation "io.grpc:grpc-alts:${grpcVersion}" - compileOnly "org.apache.tomcat:annotations-api:6.0.53" } protobuf { @@ -43,16 +39,6 @@ protobuf { } } -// Inform IDEs like IntelliJ IDEA, Eclipse or NetBeans about the generated code. -sourceSets { - main { - java { - srcDirs 'build/generated/source/proto/main/grpc' - srcDirs 'build/generated/source/proto/main/java' - } - } -} - startScripts.enabled = false @@ -74,6 +60,8 @@ application { applicationDistribution.into('bin') { from(helloWorldAltsServer) from(helloWorldAltsClient) - fileMode = 0755 + filePermissions { + unix(0755) + } } } diff --git a/examples/example-alts/settings.gradle b/examples/example-alts/settings.gradle index 273558dd9cf..6bd0f0cdc2d 100644 --- a/examples/example-alts/settings.gradle +++ b/examples/example-alts/settings.gradle @@ -1,8 +1,19 @@ pluginManagement { - repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. + buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") } + } + + repositories { gradlePluginPortal() } } diff --git a/examples/example-debug/build.gradle b/examples/example-debug/build.gradle index ffb54d4e64b..e4edc0704d0 100644 --- a/examples/example-debug/build.gradle +++ b/examples/example-debug/build.gradle @@ -2,15 +2,13 @@ plugins { id 'application' // Provide convenience executables for trying out the examples. id 'java' - id "com.google.protobuf" version "0.9.4" + id "com.google.protobuf" version "0.9.5" // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' } repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" } mavenCentral() mavenLocal() } @@ -25,14 +23,13 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.63.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protobufVersion = '3.25.1' +def grpcVersion = '1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def protobufVersion = '3.25.8' dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" implementation "io.grpc:grpc-stub:${grpcVersion}" implementation "io.grpc:grpc-services:${grpcVersion}" - compileOnly "org.apache.tomcat:annotations-api:6.0.53" runtimeOnly "io.grpc:grpc-netty-shaded:${grpcVersion}" testImplementation 'junit:junit:4.13.2' @@ -75,6 +72,8 @@ application { applicationDistribution.into('bin') { from(HelloWorldDebuggableClient) from(HostnameDebuggableServer) - fileMode = 0755 + filePermissions { + unix(0755) + } } } diff --git a/examples/example-debug/pom.xml b/examples/example-debug/pom.xml index d644e984013..ccb9977f679 100644 --- a/examples/example-debug/pom.xml +++ b/examples/example-debug/pom.xml @@ -6,14 +6,14 @@ jar - 1.63.0-SNAPSHOT + 1.82.0-SNAPSHOT example-debug https://github.com/grpc/grpc-java UTF-8 - 1.63.0-SNAPSHOT - 3.25.1 + 1.82.0-SNAPSHOT + 3.25.8 1.8 1.8 @@ -34,32 +34,21 @@ io.grpc - grpc-protobuf + grpc-services io.grpc - grpc-stub + grpc-protobuf io.grpc - grpc-services - - - org.apache.tomcat - annotations-api - 6.0.53 - provided + grpc-stub io.grpc grpc-netty-shaded runtime - - com.google.guava - guava - 32.1.3-jre - junit junit @@ -103,7 +92,7 @@ org.apache.maven.plugins maven-enforcer-plugin - 1.4.1 + 3.5.0 enforce diff --git a/examples/example-debug/settings.gradle b/examples/example-debug/settings.gradle index 3700c983b6c..48c08629ca9 100644 --- a/examples/example-debug/settings.gradle +++ b/examples/example-debug/settings.gradle @@ -1 +1,17 @@ +pluginManagement { + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. + buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") + } + } +} + rootProject.name = 'example-debug' diff --git a/examples/example-debug/src/main/java/io/grpc/examples/debug/HelloWorldDebuggableClient.java b/examples/example-debug/src/main/java/io/grpc/examples/debug/HelloWorldDebuggableClient.java index 61391b60415..ef1340cf259 100644 --- a/examples/example-debug/src/main/java/io/grpc/examples/debug/HelloWorldDebuggableClient.java +++ b/examples/example-debug/src/main/java/io/grpc/examples/debug/HelloWorldDebuggableClient.java @@ -27,7 +27,7 @@ import io.grpc.examples.helloworld.GreeterGrpc; import io.grpc.examples.helloworld.HelloReply; import io.grpc.examples.helloworld.HelloRequest; -import io.grpc.protobuf.services.ProtoReflectionService; +import io.grpc.protobuf.services.ProtoReflectionServiceV1; import io.grpc.services.AdminInterface; import java.util.concurrent.TimeUnit; import java.util.logging.Level; diff --git a/examples/example-debug/src/main/java/io/grpc/examples/debug/HostnameDebuggableServer.java b/examples/example-debug/src/main/java/io/grpc/examples/debug/HostnameDebuggableServer.java index 89ffc39b599..5525ba91d9c 100644 --- a/examples/example-debug/src/main/java/io/grpc/examples/debug/HostnameDebuggableServer.java +++ b/examples/example-debug/src/main/java/io/grpc/examples/debug/HostnameDebuggableServer.java @@ -21,7 +21,7 @@ import io.grpc.Server; import io.grpc.ServerBuilder; import io.grpc.health.v1.HealthCheckResponse.ServingStatus; -import io.grpc.protobuf.services.ProtoReflectionService; +import io.grpc.protobuf.services.ProtoReflectionServiceV1; import io.grpc.services.AdminInterface; import io.grpc.services.HealthStatusManager; import java.io.IOException; diff --git a/examples/example-dualstack/README.md b/examples/example-dualstack/README.md new file mode 100644 index 00000000000..5a26886e259 --- /dev/null +++ b/examples/example-dualstack/README.md @@ -0,0 +1,54 @@ +# gRPC Dualstack Example + +The dualstack example uses a custom name resolver that provides both IPv4 and IPv6 localhost +endpoints for each of 3 server instances. The client will first use the default name resolver and +load balancers which will only connect to the first server. It will then use the +custom name resolver with round robin to connect to each of the servers in turn. The 3 instances +of the server will bind respectively to: both IPv4 and IPv6, IPv4 only, and IPv6 only. + +The example requires grpc-java to already be built. You are strongly encouraged +to check out a git release tag, since there will already be a build of grpc +available. Otherwise, you must follow [COMPILING](../../COMPILING.md). + +### Build the example + +To build the dualstack example server and client. From the + `grpc-java/examples/example-dualstack` directory run: + +```bash +$ ../gradlew installDist +``` + +This creates the scripts +`build/install/example-dualstack/bin/dual-stack-server` + and `build/install/example-dualstack/bin/dual-stack-client`. + +To run the dualstack example, run the server with: + +```bash +$ ./build/install/example-dualstack/bin/dual-stack-server +``` + +And in a different terminal window run the client. + +```bash +$ ./build/install/example-dualstack/bin/dual-stack-client +``` + +### Maven + +If you prefer to use Maven: + +Run in the example-debug directory: + +```bash +$ mvn verify +$ # Run the server in one terminal +$ mvn exec:java -Dexec.mainClass=io.grpc.examples.dualstack.DualStackServer +``` + +```bash +$ # In another terminal run the client +$ mvn exec:java -Dexec.mainClass=io.grpc.examples.dualstack.DualStackClient +``` + diff --git a/examples/example-dualstack/build.gradle b/examples/example-dualstack/build.gradle new file mode 100644 index 00000000000..f79888831dc --- /dev/null +++ b/examples/example-dualstack/build.gradle @@ -0,0 +1,76 @@ +plugins { + id 'application' // Provide convenience executables for trying out the examples. + id 'java' + + id "com.google.protobuf" version "0.9.5" + + // Generate IntelliJ IDEA's .idea & .iml project files + id 'idea' +} + +repositories { + mavenCentral() + mavenLocal() +} + +java { + sourceCompatibility = JavaVersion.VERSION_1_8 + targetCompatibility = JavaVersion.VERSION_1_8 +} + +// IMPORTANT: You probably want the non-SNAPSHOT version of gRPC. Make sure you +// are looking at a tagged version of the example and not "master"! + +// Feel free to delete the comment at the next line. It is just for safely +// updating the version in our release process. +def grpcVersion = '1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def protobufVersion = '3.25.8' + +dependencies { + implementation "io.grpc:grpc-protobuf:${grpcVersion}" + implementation "io.grpc:grpc-netty:${grpcVersion}" + implementation "io.grpc:grpc-stub:${grpcVersion}" + implementation "io.grpc:grpc-services:${grpcVersion}" +} + +protobuf { + protoc { + artifact = "com.google.protobuf:protoc:${protobufVersion}" + } + plugins { + grpc { + artifact = "io.grpc:protoc-gen-grpc-java:${grpcVersion}" + } + } + generateProtoTasks { + all()*.plugins { + grpc {} + } + } +} + +startScripts.enabled = false + +task DualStackClient(type: CreateStartScripts) { + mainClass = 'io.grpc.examples.dualstack.DualStackClient' + applicationName = 'dual-stack-client' + outputDir = new File(project.buildDir, 'tmp/scripts/' + name) + classpath = startScripts.classpath +} + +task DualStackServer(type: CreateStartScripts) { + mainClass = 'io.grpc.examples.dualstack.DualStackServer' + applicationName = 'dual-stack-server' + outputDir = new File(project.buildDir, 'tmp/scripts/' + name) + classpath = startScripts.classpath +} + +application { + applicationDistribution.into('bin') { + from(DualStackClient) + from(DualStackServer) + filePermissions { + unix(0755) + } + } +} diff --git a/examples/example-dualstack/pom.xml b/examples/example-dualstack/pom.xml new file mode 100644 index 00000000000..99c0da77a22 --- /dev/null +++ b/examples/example-dualstack/pom.xml @@ -0,0 +1,116 @@ + + 4.0.0 + io.grpc + example-dualstack + jar + + 1.82.0-SNAPSHOT + example-dualstack + https://github.com/grpc/grpc-java + + + UTF-8 + 1.82.0-SNAPSHOT + 3.25.8 + + 1.8 + 1.8 + + + + + + io.grpc + grpc-bom + ${grpc.version} + pom + import + + + + + + + io.grpc + grpc-services + + + io.grpc + grpc-protobuf + + + io.grpc + grpc-stub + + + io.grpc + grpc-netty + + + io.grpc + grpc-netty-shaded + runtime + + + junit + junit + 4.13.2 + test + + + io.grpc + grpc-testing + test + + + + + + + kr.motd.maven + os-maven-plugin + 1.7.1 + + + + + org.xolstice.maven.plugins + protobuf-maven-plugin + 0.6.1 + + com.google.protobuf:protoc:${protoc.version}:exe:${os.detected.classifier} + grpc-java + io.grpc:protoc-gen-grpc-java:${grpc.version}:exe:${os.detected.classifier} + + + + + compile + compile-custom + + + + + + org.apache.maven.plugins + maven-enforcer-plugin + 3.5.0 + + + enforce + + enforce + + + + + + + + + + + + diff --git a/examples/example-dualstack/settings.gradle b/examples/example-dualstack/settings.gradle new file mode 100644 index 00000000000..160d5134334 --- /dev/null +++ b/examples/example-dualstack/settings.gradle @@ -0,0 +1,21 @@ +pluginManagement { + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. + buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") + } + } + + repositories { + gradlePluginPortal() + } +} + +rootProject.name = 'example-dualstack' diff --git a/examples/example-dualstack/src/main/java/io/grpc/examples/dualstack/DualStackClient.java b/examples/example-dualstack/src/main/java/io/grpc/examples/dualstack/DualStackClient.java new file mode 100644 index 00000000000..b9993a524d6 --- /dev/null +++ b/examples/example-dualstack/src/main/java/io/grpc/examples/dualstack/DualStackClient.java @@ -0,0 +1,95 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.examples.dualstack; + +import io.grpc.Channel; +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import io.grpc.NameResolverRegistry; +import io.grpc.StatusRuntimeException; +import io.grpc.examples.helloworld.GreeterGrpc; +import io.grpc.examples.helloworld.HelloReply; +import io.grpc.examples.helloworld.HelloRequest; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * A client that requests greetings from the {@link DualStackServer}. + * First it sends 5 requests using the default nameresolver and load balancer. + * Then it sends 10 requests using the example nameresolver and round robin load balancer. These + * requests are evenly distributed among the 3 servers rather than favoring the server listening + * on both addresses because the ExampleDualStackNameResolver groups the 3 servers as 3 endpoints + * each with 2 addresses. + */ +public class DualStackClient { + public static final String channelTarget = "example:///lb.example.grpc.io"; + private static final Logger logger = Logger.getLogger(DualStackClient.class.getName()); + private final GreeterGrpc.GreeterBlockingStub blockingStub; + + public DualStackClient(Channel channel) { + blockingStub = GreeterGrpc.newBlockingStub(channel); + } + + public static void main(String[] args) throws Exception { + NameResolverRegistry.getDefaultRegistry() + .register(new ExampleDualStackNameResolverProvider()); + + logger.info("\n **** Use default DNS resolver ****"); + ManagedChannel channel = ManagedChannelBuilder.forTarget("localhost:50051") + .usePlaintext() + .build(); + try { + DualStackClient client = new DualStackClient(channel); + for (int i = 0; i < 5; i++) { + client.greet("request:" + i); + } + } finally { + channel.shutdownNow().awaitTermination(5, TimeUnit.SECONDS); + } + + logger.info("\n **** Change to use example name resolver ****"); + /* + Dial to "example:///resolver.example.grpc.io", use {@link ExampleNameResolver} to create connection + "resolver.example.grpc.io" is converted to {@link java.net.URI.path} + */ + channel = ManagedChannelBuilder.forTarget(channelTarget) + .defaultLoadBalancingPolicy("round_robin") + .usePlaintext() + .build(); + try { + DualStackClient client = new DualStackClient(channel); + for (int i = 0; i < 10; i++) { + client.greet("request:" + i); + } + } finally { + channel.shutdownNow().awaitTermination(5, TimeUnit.SECONDS); + } + } + + public void greet(String name) { + HelloRequest request = HelloRequest.newBuilder().setName(name).build(); + HelloReply response; + try { + response = blockingStub.sayHello(request); + } catch (StatusRuntimeException e) { + logger.log(Level.WARNING, "RPC failed: {0}", e.getStatus()); + return; + } + logger.info("Greeting: " + response.getMessage()); + } +} diff --git a/examples/example-dualstack/src/main/java/io/grpc/examples/dualstack/DualStackServer.java b/examples/example-dualstack/src/main/java/io/grpc/examples/dualstack/DualStackServer.java new file mode 100644 index 00000000000..43b21e963f8 --- /dev/null +++ b/examples/example-dualstack/src/main/java/io/grpc/examples/dualstack/DualStackServer.java @@ -0,0 +1,126 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.examples.dualstack; + +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.examples.helloworld.GreeterGrpc; +import io.grpc.examples.helloworld.HelloReply; +import io.grpc.examples.helloworld.HelloRequest; +import io.grpc.netty.NettyServerBuilder; +import io.grpc.stub.StreamObserver; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.logging.Logger; + +/** + * Starts 3 different greeter services each on its own port, but all for localhost. + * The first service listens on both IPv4 and IPv6, + * the second on just IPv4, and the third on just IPv6. + */ +public class DualStackServer { + private static final Logger logger = Logger.getLogger(DualStackServer.class.getName()); + private List servers; + + public static void main(String[] args) throws IOException, InterruptedException { + final DualStackServer server = new DualStackServer(); + server.start(); + server.blockUntilShutdown(); + } + + private void start() throws IOException { + InetSocketAddress inetSocketAddress; + + servers = new ArrayList<>(); + int[] serverPorts = ExampleDualStackNameResolver.SERVER_PORTS; + for (int i = 0; i < serverPorts.length; i++ ) { + String addressType; + int port = serverPorts[i]; + ServerBuilder serverBuilder; + switch (i) { + case 0: + serverBuilder = ServerBuilder.forPort(port); // bind to both IPv4 and IPv6 + addressType = "both IPv4 and IPv6"; + break; + case 1: + // bind to IPv4 only + inetSocketAddress = new InetSocketAddress("127.0.0.1", port); + serverBuilder = NettyServerBuilder.forAddress(inetSocketAddress); + addressType = "IPv4 only"; + break; + case 2: + // bind to IPv6 only + inetSocketAddress = new InetSocketAddress("::1", port); + serverBuilder = NettyServerBuilder.forAddress(inetSocketAddress); + addressType = "IPv6 only"; + break; + default: + throw new IllegalStateException("Unexpected value: " + i); + } + + servers.add(serverBuilder + .addService(new GreeterImpl(port, addressType)) + .build() + .start()); + logger.info("Server started, listening on " + port); + } + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + System.err.println("*** shutting down gRPC server since JVM is shutting down"); + try { + DualStackServer.this.stop(); + } catch (InterruptedException e) { + e.printStackTrace(System.err); + } + System.err.println("*** server shut down"); + })); + } + + private void stop() throws InterruptedException { + for (Server server : servers) { + server.shutdown().awaitTermination(30, TimeUnit.SECONDS); + } + } + + private void blockUntilShutdown() throws InterruptedException { + for (Server server : servers) { + server.awaitTermination(); + } + } + + static class GreeterImpl extends GreeterGrpc.GreeterImplBase { + + int port; + String addressType; + + public GreeterImpl(int port, String addressType) { + this.port = port; + this.addressType = addressType; + } + + @Override + public void sayHello(HelloRequest req, StreamObserver responseObserver) { + String msg = String.format("Hello %s from server<%d> type: %s", + req.getName(), this.port, addressType); + HelloReply reply = HelloReply.newBuilder().setMessage(msg).build(); + responseObserver.onNext(reply); + responseObserver.onCompleted(); + } + } +} diff --git a/examples/example-dualstack/src/main/java/io/grpc/examples/dualstack/ExampleDualStackNameResolver.java b/examples/example-dualstack/src/main/java/io/grpc/examples/dualstack/ExampleDualStackNameResolver.java new file mode 100644 index 00000000000..70675b3de3d --- /dev/null +++ b/examples/example-dualstack/src/main/java/io/grpc/examples/dualstack/ExampleDualStackNameResolver.java @@ -0,0 +1,98 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.examples.dualstack; + +import com.google.common.collect.ImmutableMap; +import io.grpc.EquivalentAddressGroup; +import io.grpc.NameResolver; +import io.grpc.Status; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.net.URI; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +/** + * A fake name resolver that resolves to a hard-coded list of 3 endpoints (EquivalentAddressGropu) + * each with 2 addresses (one IPv4 and one IPv6). + */ +public class ExampleDualStackNameResolver extends NameResolver { + static public final int[] SERVER_PORTS = {50051, 50052, 50053}; + + // This is a fake name resolver, so we just hard code the address here. + private static final ImmutableMap>> addrStore = + ImmutableMap.>>builder() + .put("lb.example.grpc.io", + Arrays.stream(SERVER_PORTS) + .mapToObj(port -> getLocalAddrs(port)) + .collect(Collectors.toList()) + ) + .build(); + + private Listener2 listener; + + private final URI uri; + + public ExampleDualStackNameResolver(URI targetUri) { + this.uri = targetUri; + } + + private static List getLocalAddrs(int port) { + return Arrays.asList( + new InetSocketAddress("127.0.0.1", port), + new InetSocketAddress("::1", port)); + } + + @Override + public String getServiceAuthority() { + return uri.getPath().substring(1); + } + + @Override + public void shutdown() { + } + + @Override + public void start(Listener2 listener) { + this.listener = listener; + this.resolve(); + } + + @Override + public void refresh() { + this.resolve(); + } + + private void resolve() { + List> addresses = addrStore.get(uri.getPath().substring(1)); + try { + List eagList = new ArrayList<>(); + for (List endpoint : addresses) { + // every server is an EquivalentAddressGroup, so they can be accessed randomly + eagList.add(new EquivalentAddressGroup(endpoint)); + } + + this.listener.onResult(ResolutionResult.newBuilder().setAddresses(eagList).build()); + } catch (Exception e){ + // when error occurs, notify listener + this.listener.onError(Status.UNAVAILABLE.withDescription("Unable to resolve host ").withCause(e)); + } + } + +} diff --git a/examples/example-dualstack/src/main/java/io/grpc/examples/dualstack/ExampleDualStackNameResolverProvider.java b/examples/example-dualstack/src/main/java/io/grpc/examples/dualstack/ExampleDualStackNameResolverProvider.java new file mode 100644 index 00000000000..a01d68aca3e --- /dev/null +++ b/examples/example-dualstack/src/main/java/io/grpc/examples/dualstack/ExampleDualStackNameResolverProvider.java @@ -0,0 +1,47 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.examples.dualstack; + +import io.grpc.NameResolver; +import io.grpc.NameResolverProvider; + +import java.net.URI; + +public class ExampleDualStackNameResolverProvider extends NameResolverProvider { + public static final String exampleScheme = "example"; + + @Override + public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) { + return new ExampleDualStackNameResolver(targetUri); + } + + @Override + protected boolean isAvailable() { + return true; + } + + @Override + protected int priority() { + return 5; + } + + @Override + // gRPC choose the first NameResolverProvider that supports the target URI scheme. + public String getDefaultScheme() { + return exampleScheme; + } +} diff --git a/examples/example-dualstack/src/main/proto/helloworld/helloworld.proto b/examples/example-dualstack/src/main/proto/helloworld/helloworld.proto new file mode 100644 index 00000000000..c60d9416f1f --- /dev/null +++ b/examples/example-dualstack/src/main/proto/helloworld/helloworld.proto @@ -0,0 +1,37 @@ +// Copyright 2015 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +syntax = "proto3"; + +option java_multiple_files = true; +option java_package = "io.grpc.examples.helloworld"; +option java_outer_classname = "HelloWorldProto"; +option objc_class_prefix = "HLW"; + +package helloworld; + +// The greeting service definition. +service Greeter { + // Sends a greeting + rpc SayHello (HelloRequest) returns (HelloReply) {} +} + +// The request message containing the user's name. +message HelloRequest { + string name = 1; +} + +// The response message containing the greetings +message HelloReply { + string message = 1; +} diff --git a/examples/example-gauth/BUILD.bazel b/examples/example-gauth/BUILD.bazel index edc4a291e27..033c51f8856 100644 --- a/examples/example-gauth/BUILD.bazel +++ b/examples/example-gauth/BUILD.bazel @@ -1,4 +1,5 @@ -load("@io_grpc_grpc_java//:java_grpc_library.bzl", "java_grpc_library") +load("@rules_java//java:java_binary.bzl", "java_binary") +load("@rules_java//java:java_library.bzl", "java_library") java_library( name = "example-gauth", diff --git a/examples/example-gauth/README.md b/examples/example-gauth/README.md index 622c14cb57b..b49d346a9be 100644 --- a/examples/example-gauth/README.md +++ b/examples/example-gauth/README.md @@ -43,13 +43,13 @@ gcloud pubsub topics create Topic1 5. You will now need to set up [authentication](https://cloud.google.com/docs/authentication/) and a [service account](https://cloud.google.com/docs/authentication/#service_accounts) in order to access Pub/Sub via gRPC APIs as described [here](https://cloud.google.com/iam/docs/creating-managing-service-accounts). -Assign the [role](https://cloud.google.com/iam/docs/granting-roles-to-service-accounts) `Project -> Owner` +(**Note:** This step is unnecessary on Google platforms (Google App Engine / Google Cloud Shell / Google Compute Engine) as it will +automatically use the in-built Google credentials). Assign the [role](https://cloud.google.com/iam/docs/granting-roles-to-service-accounts) `Project -> Owner` and for Key type select JSON. Once you click `Create`, a JSON file containing your key is downloaded to your computer. Note down the path of this file or copy this file to the computer and file system where you will be running the example application as described later. Assume this JSON file is available at -`/path/to/JSON/file`. You can also use the `gcloud` shell commands to -[create the service account](https://cloud.google.com/iam/docs/creating-managing-service-accounts#iam-service-accounts-create-gcloud) -and [the JSON file](https://cloud.google.com/iam/docs/creating-managing-service-account-keys#iam-service-account-keys-create-gcloud). +`/path/to/JSON/file` Set the value of the environment variable GOOGLE_APPLICATION_CREDENTIALS to this file path. You can also use the `gcloud` shell commands to +[create the service account](https://cloud.google.com/iam/docs/creating-managing-service-accounts#iam-service-accounts-create-gcloud). #### To build the examples @@ -62,19 +62,18 @@ $ ../gradlew installDist #### How to run the example: -`google-auth-client` requires two command line arguments for the location of the JSON file and the project ID: +`google-auth-client` requires one command line argument for the project ID: ```text -USAGE: GoogleAuthClient +USAGE: GoogleAuthClient ``` -The first argument is the location of the JSON file you created in step 5 above. -The second argument is the project ID in the form "projects/xyz123" where "xyz123" is +The first argument is the project ID in the form "projects/xyz123" where "xyz123" is the project ID of the project you created (or used) in step 2 above. ```bash # Run the client -./build/install/example-gauth/bin/google-auth-client /path/to/JSON/file projects/xyz123 +./build/install/example-gauth/bin/google-auth-client projects/xyz123 ``` That's it! The client will show the list of Pub/Sub topics for the project as follows: @@ -93,7 +92,7 @@ the project ID of the project you created (or used) in step 2 above. ``` $ mvn verify $ # Run the client - $ mvn exec:java -Dexec.mainClass=io.grpc.examples.googleAuth.GoogleAuthClient -Dexec.args="/path/to/JSON/file projects/xyz123" + $ mvn exec:java -Dexec.mainClass=io.grpc.examples.googleAuth.GoogleAuthClient -Dexec.args="projects/xyz123" ``` ## Bazel @@ -101,5 +100,5 @@ the project ID of the project you created (or used) in step 2 above. ``` $ bazel build :google-auth-client $ # Run the client - $ ../bazel-bin/google-auth-client /path/to/JSON/file projects/xyz123 + $ ../bazel-bin/google-auth-client projects/xyz123 ``` \ No newline at end of file diff --git a/examples/example-gauth/build.gradle b/examples/example-gauth/build.gradle index 65cba558068..5ab563479a4 100644 --- a/examples/example-gauth/build.gradle +++ b/examples/example-gauth/build.gradle @@ -1,15 +1,12 @@ plugins { // Provide convenience executables for trying out the examples. id 'application' - id 'com.google.protobuf' version '0.9.4' + id 'com.google.protobuf' version '0.9.5' // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' } repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" - } mavenCentral() mavenLocal() } @@ -24,8 +21,8 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.63.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protobufVersion = '3.25.1' +def grpcVersion = '1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def protobufVersion = '3.25.8' def protocVersion = protobufVersion @@ -33,8 +30,7 @@ dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" implementation "io.grpc:grpc-stub:${grpcVersion}" implementation "io.grpc:grpc-auth:${grpcVersion}" - compileOnly "org.apache.tomcat:annotations-api:6.0.53" - implementation "com.google.auth:google-auth-library-oauth2-http:1.4.0" + implementation "com.google.auth:google-auth-library-oauth2-http:1.42.1" implementation "com.google.api.grpc:grpc-google-cloud-pubsub-v1:0.1.24" runtimeOnly "io.grpc:grpc-netty-shaded:${grpcVersion}" } @@ -49,16 +45,6 @@ protobuf { } } -// Inform IDEs like IntelliJ IDEA, Eclipse or NetBeans about the generated code. -sourceSets { - main { - java { - srcDirs 'build/generated/source/proto/main/grpc' - srcDirs 'build/generated/source/proto/main/java' - } - } -} - startScripts.enabled = false task googleAuthClient(type: CreateStartScripts) { @@ -71,6 +57,8 @@ task googleAuthClient(type: CreateStartScripts) { application { applicationDistribution.into('bin') { from(googleAuthClient) - fileMode = 0755 + filePermissions { + unix(0755) + } } } diff --git a/examples/example-gauth/pom.xml b/examples/example-gauth/pom.xml index 65a8976f317..66e0f3be563 100644 --- a/examples/example-gauth/pom.xml +++ b/examples/example-gauth/pom.xml @@ -6,14 +6,14 @@ jar - 1.63.0-SNAPSHOT + 1.82.0-SNAPSHOT example-gauth https://github.com/grpc/grpc-java UTF-8 - 1.63.0-SNAPSHOT - 3.25.1 + 1.82.0-SNAPSHOT + 3.25.8 1.8 1.8 @@ -28,6 +28,11 @@ pom import + + com.google.code.gson + gson + 2.13.2 + @@ -49,12 +54,6 @@ io.grpc grpc-auth - - org.apache.tomcat - annotations-api - 6.0.53 - provided - io.grpc grpc-testing @@ -63,7 +62,7 @@ com.google.auth google-auth-library-oauth2-http - 1.4.0 + 1.40.0 com.google.api.grpc @@ -96,7 +95,7 @@ org.apache.maven.plugins maven-enforcer-plugin - 1.4.1 + 3.5.0 enforce diff --git a/examples/example-gauth/settings.gradle b/examples/example-gauth/settings.gradle index 273558dd9cf..6bd0f0cdc2d 100644 --- a/examples/example-gauth/settings.gradle +++ b/examples/example-gauth/settings.gradle @@ -1,8 +1,19 @@ pluginManagement { - repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. + buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") } + } + + repositories { gradlePluginPortal() } } diff --git a/examples/example-gauth/src/main/java/io/grpc/examples/googleAuth/GoogleAuthClient.java b/examples/example-gauth/src/main/java/io/grpc/examples/googleAuth/GoogleAuthClient.java index 4d3dd044376..eb0d9feedfc 100644 --- a/examples/example-gauth/src/main/java/io/grpc/examples/googleAuth/GoogleAuthClient.java +++ b/examples/example-gauth/src/main/java/io/grpc/examples/googleAuth/GoogleAuthClient.java @@ -33,7 +33,7 @@ /** * Example to illustrate use of Google credentials as described in - * @see Google Auth Example README + * @see Google Auth Example README * * Also @see Google Cloud Pubsub via gRPC */ @@ -52,7 +52,7 @@ public class GoogleAuthClient { * * @param host host to connect to - typically "pubsub.googleapis.com" * @param port port to connect to - typically 443 - the TLS port - * @param callCredentials the Google call credentials created from a JSON file + * @param callCredentials the Google call credentials */ public GoogleAuthClient(String host, int port, CallCredentials callCredentials) { // Google API invocation requires a secure channel. Channels are secure by default (SSL/TLS) @@ -63,7 +63,7 @@ public GoogleAuthClient(String host, int port, CallCredentials callCredentials) * Construct our gRPC client that connects to the pubsub server using an existing channel. * * @param channel channel that has been built already - * @param callCredentials the Google call credentials created from a JSON file + * @param callCredentials the Google call credentials */ GoogleAuthClient(ManagedChannel channel, CallCredentials callCredentials) { this.channel = channel; @@ -101,32 +101,30 @@ public void getTopics(String projectID) { /** * The app requires 2 arguments as described in - * @see Google Auth Example README + * @see Google Auth Example README * - * arg0 = location of the JSON file for the service account you created in the GCP console - * arg1 = project name in the form "projects/balmy-cirrus-225307" where "balmy-cirrus-225307" is + * arg0 = project name in the form "projects/balmy-cirrus-225307" where "balmy-cirrus-225307" is * the project ID for the project you created. * + * On non-Google platforms, the GOOGLE_APPLICATION_CREDENTIALS env variable should be set to the + * location of the JSON file for the service account you created in the GCP console. */ public static void main(String[] args) throws Exception { - if (args.length < 2) { - logger.severe("Usage: please pass 2 arguments:\n" + - "arg0 = location of the JSON file for the service account you created in the GCP console\n" + - "arg1 = project name in the form \"projects/xyz\" where \"xyz\" is the project ID of the project you created.\n"); + if (args.length < 1) { + logger.severe("Usage: please pass 1 argument:\n" + + "arg0 = project name in the form \"projects/xyz\" where \"xyz\" is the project ID of the project you created.\n"); System.exit(1); } - GoogleCredentials credentials = GoogleCredentials.fromStream(new FileInputStream(args[0])); + GoogleCredentials credentials = GoogleCredentials.getApplicationDefault(); // We need to create appropriate scope as per https://cloud.google.com/storage/docs/authentication#oauth-scopes credentials = credentials.createScoped(Arrays.asList("https://www.googleapis.com/auth/cloud-platform")); - // credentials must be refreshed before the access token is available - credentials.refreshAccessToken(); GoogleAuthClient client = new GoogleAuthClient("pubsub.googleapis.com", 443, MoreCallCredentials.from(credentials)); try { - client.getTopics(args[1]); + client.getTopics(args[0]); } finally { client.shutdown(); } diff --git a/examples/example-gcp-csm-observability/README.md b/examples/example-gcp-csm-observability/README.md new file mode 100644 index 00000000000..cbb206fbb46 --- /dev/null +++ b/examples/example-gcp-csm-observability/README.md @@ -0,0 +1,43 @@ +gRPC GCP CSM Observability Example +================ + +The GCP CSM Observability example consists of a Hello World client and a Hello World server and shows how to configure CSM Observability +for gRPC client and gRPC server. + +## Configuration + +`CsmObservabilityClient` takes the following command-line arguments - +* user - Name to be greeted. +* target - Server address. Default value is `xds:///helloworld:50051`. + * When client tries to connect to target, gRPC would use xDS to resolve this target and connect to the server backend. +* prometheusPort - Port used for exposing prometheus metrics. Default value is `9464`. + + +`CsmObservabilityServer` takes the following command-line arguments - +* port - Port used for running Hello World server. Default value is `50051`. +* prometheusPort - Port used for exposing prometheus metrics. Default value is `9464`. + +## Build the example + +From the `grpc-java/examples/`directory i.e, +``` +cd grpc-java/examples +``` +Run the following to generate client and server images respectively. + +Client: +``` +docker build -f example-gcp-csm-observability/csm-client.Dockerfile . +``` +Server: +``` +docker build -f example-gcp-csm-observability/csm-server.Dockerfile . +``` + +To push to a registry, add a tag to the image either by adding a `-t` flag to `docker build` command above or run: + +``` +docker image tag ${sha from build command above} ${tag} +``` + +And then push the tagged image using `docker push`. diff --git a/examples/example-gcp-csm-observability/build.gradle b/examples/example-gcp-csm-observability/build.gradle new file mode 100644 index 00000000000..2ddfd995cd3 --- /dev/null +++ b/examples/example-gcp-csm-observability/build.gradle @@ -0,0 +1,76 @@ +plugins { + // Provide convenience executables for trying out the examples. + id 'application' + id 'com.google.protobuf' version '0.9.5' + // Generate IntelliJ IDEA's .idea & .iml project files + id 'idea' + id 'java' +} + +repositories { + mavenCentral() + mavenLocal() +} + +java { + sourceCompatibility = JavaVersion.VERSION_1_8 + targetCompatibility = JavaVersion.VERSION_1_8 +} + +// IMPORTANT: You probably want the non-SNAPSHOT version of gRPC. Make sure you +// are looking at a tagged version of the example and not "master"! + +// Feel free to delete the comment at the next line. It is just for safely +// updating the version in our release process. +def grpcVersion = '1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def protocVersion = '3.25.8' +def openTelemetryVersion = '1.56.0' +def openTelemetryPrometheusVersion = '1.56.0-alpha' + +dependencies { + implementation "io.grpc:grpc-protobuf:${grpcVersion}" + implementation "io.grpc:grpc-stub:${grpcVersion}" + implementation "io.grpc:grpc-gcp-csm-observability:${grpcVersion}" + implementation "io.grpc:grpc-xds:${grpcVersion}" + implementation "io.opentelemetry:opentelemetry-sdk:${openTelemetryVersion}" + implementation "io.opentelemetry:opentelemetry-sdk-metrics:${openTelemetryVersion}" + implementation "io.opentelemetry:opentelemetry-exporter-prometheus:${openTelemetryPrometheusVersion}" + runtimeOnly "io.grpc:grpc-xds:${grpcVersion}" + runtimeOnly "io.grpc:grpc-netty-shaded:${grpcVersion}" +} + +protobuf { + protoc { artifact = "com.google.protobuf:protoc:${protocVersion}" } + plugins { + grpc { artifact = "io.grpc:protoc-gen-grpc-java:${grpcVersion}" } + } + generateProtoTasks { + all()*.plugins { grpc {} } + } +} + +startScripts.enabled = false + +task CsmObservabilityHelloWorldServer(type: CreateStartScripts) { + mainClass = 'io.grpc.examples.csmobservability.CsmObservabilityServer' + applicationName = 'csm-observability-server' + outputDir = new File(project.buildDir, 'tmp/scripts/' + name) + classpath = startScripts.classpath +} + +task CsmObservabilityHelloWorldClient(type: CreateStartScripts) { + mainClass = 'io.grpc.examples.csmobservability.CsmObservabilityClient' + applicationName = 'csm-observability-client' + outputDir = new File(project.buildDir, 'tmp/scripts/' + name) + classpath = startScripts.classpath +} + +application { + applicationDistribution.into('bin') { + from(CsmObservabilityHelloWorldServer) + from(CsmObservabilityHelloWorldClient) + filePermissions { + unix(0755) + } + } +} diff --git a/examples/example-gcp-csm-observability/csm-client.Dockerfile b/examples/example-gcp-csm-observability/csm-client.Dockerfile new file mode 100644 index 00000000000..31a3262e863 --- /dev/null +++ b/examples/example-gcp-csm-observability/csm-client.Dockerfile @@ -0,0 +1,47 @@ +# Copyright 2024 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# +# Stage 1: Build CSM client +# + +FROM eclipse-temurin:11-jdk AS build + +WORKDIR /grpc-java/examples +COPY . . + +RUN cd example-gcp-csm-observability && ../gradlew installDist -PskipCodegen=true -PskipAndroid=true + +# +# Stage 2: +# +# - Copy only the necessary files to reduce Docker image size. +# - Have an ENTRYPOINT script which will launch the CSM client +# with the given parameters. +# + +FROM eclipse-temurin:11-jre + +WORKDIR /grpc-java/ +COPY --from=build /grpc-java/examples/example-gcp-csm-observability/build/install/example-gcp-csm-observability/. . + +# Intentionally after the COPY to force the update on each build. +# Update Ubuntu system packages: +RUN apt-get update \ + && apt-get -y upgrade \ + && apt-get -y autoremove \ + && rm -rf /var/lib/apt/lists/* + +# Client +ENTRYPOINT ["bin/csm-observability-client"] diff --git a/examples/example-gcp-csm-observability/csm-server.Dockerfile b/examples/example-gcp-csm-observability/csm-server.Dockerfile new file mode 100644 index 00000000000..675b450143f --- /dev/null +++ b/examples/example-gcp-csm-observability/csm-server.Dockerfile @@ -0,0 +1,47 @@ +# Copyright 2024 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# +# Stage 1: Build CSM server +# + +FROM eclipse-temurin:11-jdk AS build + +WORKDIR /grpc-java/examples +COPY . . + +RUN cd example-gcp-csm-observability && ../gradlew installDist -PskipCodegen=true -PskipAndroid=true + +# +# Stage 2: +# +# - Copy only the necessary files to reduce Docker image size. +# - Have an ENTRYPOINT script which will launch the CSM server +# with the given parameters. +# + +FROM eclipse-temurin:11-jre + +WORKDIR /grpc-java/ +COPY --from=build /grpc-java/examples/example-gcp-csm-observability/build/install/example-gcp-csm-observability/. . + +# Intentionally after the COPY to force the update on each build. +# Update Ubuntu system packages: +RUN apt-get update \ + && apt-get -y upgrade \ + && apt-get -y autoremove \ + && rm -rf /var/lib/apt/lists/* + +# Server +ENTRYPOINT ["bin/csm-observability-server"] diff --git a/examples/example-gcp-csm-observability/settings.gradle b/examples/example-gcp-csm-observability/settings.gradle new file mode 100644 index 00000000000..44e6f340ede --- /dev/null +++ b/examples/example-gcp-csm-observability/settings.gradle @@ -0,0 +1,17 @@ +pluginManagement { + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. + buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") + } + } +} + +rootProject.name = 'example-gcp-csm-observability' diff --git a/examples/example-gcp-csm-observability/src/main/java/io/grpc/examples/csmobservability/CsmObservabilityClient.java b/examples/example-gcp-csm-observability/src/main/java/io/grpc/examples/csmobservability/CsmObservabilityClient.java new file mode 100644 index 00000000000..dd0ab7eb546 --- /dev/null +++ b/examples/example-gcp-csm-observability/src/main/java/io/grpc/examples/csmobservability/CsmObservabilityClient.java @@ -0,0 +1,152 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.examples.csmobservability; + +import io.grpc.Channel; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; +import io.grpc.ManagedChannel; +import io.grpc.StatusRuntimeException; +import io.grpc.examples.helloworld.GreeterGrpc; +import io.grpc.examples.helloworld.HelloReply; +import io.grpc.examples.helloworld.HelloRequest; +import io.grpc.gcp.csm.observability.CsmObservability; +import io.grpc.xds.XdsChannelCredentials; +import io.opentelemetry.exporter.prometheus.PrometheusHttpServer; +import io.opentelemetry.sdk.OpenTelemetrySdk; +import io.opentelemetry.sdk.metrics.SdkMeterProvider; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * A simple CSM observability client that requests a greeting from the {@link HelloWorldServer} and + * generates CSM telemetry data based on the configuration. + */ +public class CsmObservabilityClient { + private static final Logger logger = Logger.getLogger(CsmObservabilityClient.class.getName()); + + private final GreeterGrpc.GreeterBlockingStub blockingStub; + + /** Construct client for accessing HelloWorld server using the existing channel. */ + public CsmObservabilityClient(Channel channel) { + blockingStub = GreeterGrpc.newBlockingStub(channel); + } + + /** Say hello to server. */ + public void greet(String name) { + logger.info("Will try to greet " + name + " ..."); + HelloRequest request = HelloRequest.newBuilder().setName(name).build(); + HelloReply response; + try { + response = blockingStub.sayHello(request); + } catch (StatusRuntimeException e) { + logger.log(Level.WARNING, "RPC failed: {0}", e.getStatus()); + return; + } + logger.info("Greeting: " + response.getMessage()); + } + + /** + * Greet server. If provided, the first element of {@code args} is the name to use in the + * greeting. The second argument is the target server. + */ + public static void main(String[] args) throws Exception { + String user = "world"; + // Use xDS to establish contact with the server "helloworld:50051". + String target = "xds:///helloworld:50051"; + // The port on which prometheus metrics will be exposed. + int prometheusPort = 9464; + AtomicBoolean sendRpcs = new AtomicBoolean(true); + if (args.length > 0) { + if ("--help".equals(args[0])) { + System.err.println("Usage: [name [target [prometheusPort]]]"); + System.err.println(""); + System.err.println(" name The name you wish to be greeted by. Defaults to " + user); + System.err.println(" target The server to connect to. Defaults to " + target); + System.err.println(" prometheusPort The port to expose prometheus metrics. Defaults to " + prometheusPort); + System.exit(1); + } + user = args[0]; + } + if (args.length > 1) { + target = args[1]; + } + if (args.length > 2) { + prometheusPort = Integer.parseInt(args[2]); + } + + Thread mainThread = Thread.currentThread(); + + Runtime.getRuntime().addShutdownHook(new Thread() { + @Override + public void run() { + // Use stderr here since the logger may have been reset by its JVM shutdown hook. + System.err.println("*** shutting down gRPC client since JVM is shutting down"); + + sendRpcs.set(false); + try { + mainThread.join(); + } catch (InterruptedException e) { + e.printStackTrace(System.err); + } + System.err.println("*** client shut down"); + } + }); + + // Adds a PrometheusHttpServer to convert OpenTelemetry metrics to Prometheus format and + // expose these via a HttpServer exporter to the SdkMeterProvider. + SdkMeterProvider sdkMeterProvider = SdkMeterProvider.builder() + .registerMetricReader( + PrometheusHttpServer.builder().setPort(prometheusPort).build()) + .build(); + + // Initialize OpenTelemetry SDK with MeterProvider configured with Prometeheus. + OpenTelemetrySdk openTelemetrySdk = + OpenTelemetrySdk.builder().setMeterProvider(sdkMeterProvider).build(); + + // Initialize CSM Observability. + CsmObservability observability = CsmObservability.newBuilder() + .sdk(openTelemetrySdk) + .build(); + // Registers CSM observabiity globally. + observability.registerGlobal(); + + // Create a communication channel to the server, known as a Channel. + ManagedChannel channel = + Grpc.newChannelBuilder( + target, XdsChannelCredentials.create(InsecureChannelCredentials.create())) + .build(); + CsmObservabilityClient client = new CsmObservabilityClient(channel); + + try { + // Run RPCs every second. + while (sendRpcs.get()) { + client.greet(user); + // Sleep for a bit before sending the next RPC. + Thread.sleep(1000); + } + } finally { + channel.shutdownNow().awaitTermination(5, TimeUnit.SECONDS); + // Shut down CSM Observability. + observability.close(); + // Shut down OpenTelemetry SDK. + openTelemetrySdk.close(); + } + } +} diff --git a/examples/example-gcp-csm-observability/src/main/java/io/grpc/examples/csmobservability/CsmObservabilityServer.java b/examples/example-gcp-csm-observability/src/main/java/io/grpc/examples/csmobservability/CsmObservabilityServer.java new file mode 100644 index 00000000000..589753b1a4c --- /dev/null +++ b/examples/example-gcp-csm-observability/src/main/java/io/grpc/examples/csmobservability/CsmObservabilityServer.java @@ -0,0 +1,143 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.examples.csmobservability; + +import io.grpc.Grpc; +import io.grpc.InsecureServerCredentials; +import io.grpc.Server; +import io.grpc.examples.helloworld.GreeterGrpc; +import io.grpc.examples.helloworld.HelloReply; +import io.grpc.examples.helloworld.HelloRequest; +import io.grpc.gcp.csm.observability.CsmObservability; +import io.grpc.stub.StreamObserver; +import io.grpc.xds.XdsServerBuilder; +import io.grpc.xds.XdsServerCredentials; +import io.opentelemetry.exporter.prometheus.PrometheusHttpServer; +import io.opentelemetry.sdk.OpenTelemetrySdk; +import io.opentelemetry.sdk.metrics.SdkMeterProvider; +import java.io.IOException; +import java.util.concurrent.TimeUnit; +import java.util.logging.Logger; + +/** + * CSM Observability server that manages startup/shutdown of a {@code Greeter} server and generates + * CSM telemetry based on the configuration. + */ +public class CsmObservabilityServer { + private static final Logger logger = Logger.getLogger(CsmObservabilityServer.class.getName()); + + private Server server; + private void start(int port) throws IOException { + server = + XdsServerBuilder.forPort( + port, XdsServerCredentials.create(InsecureServerCredentials.create())) + .addService(new GreeterImpl()) + .build() + .start(); + logger.info("Server started, listening on " + port); + } + + private void stop() throws InterruptedException { + if (server != null) { + server.shutdown().awaitTermination(30, TimeUnit.SECONDS); + } + } + + /** + * Await termination on the main thread since the grpc library uses daemon threads. + */ + private void blockUntilShutdown() throws InterruptedException { + if (server != null) { + server.awaitTermination(); + } + } + + /** + * Main launches the server from the command line. + */ + public static void main(String[] args) throws IOException, InterruptedException { + // The port on which the server should run. + int port = 50051; + // The port on which prometheus metrics will be exposed. + int prometheusPort = 9464; + + if (args.length > 0) { + if ("--help".equals(args[0])) { + System.err.println("Usage: [port [prometheus_port]]"); + System.err.println(""); + System.err.println(" port The port on which server will run. Defaults to " + port); + System.err.println(" prometheusPort The port to expose prometheus metrics. Defaults to " + prometheusPort); + System.exit(1); + } + port = Integer.parseInt(args[0]); + } + if (args.length > 1) { + prometheusPort = Integer.parseInt(args[1]); + } + + // Adds a PrometheusHttpServer to convert OpenTelemetry metrics to Prometheus format and + // expose these via a HttpServer exporter to the SdkMeterProvider. + SdkMeterProvider sdkMeterProvider = SdkMeterProvider.builder() + .registerMetricReader( + PrometheusHttpServer.builder().setPort(prometheusPort).build()) + .build(); + + // Initialize OpenTelemetry SDK with MeterProvider configured with Prometheus metrics exporter + OpenTelemetrySdk openTelemetrySdk = + OpenTelemetrySdk.builder().setMeterProvider(sdkMeterProvider).build(); + + // Initialize CSM Observability + CsmObservability observability = CsmObservability.newBuilder() + .sdk(openTelemetrySdk) + .build(); + // Registers CSM observabiity globally + observability.registerGlobal(); + + final CsmObservabilityServer server = new CsmObservabilityServer(); + server.start(port); + + Runtime.getRuntime().addShutdownHook(new Thread() { + @Override + public void run() { + System.err.println("*** shutting down gRPC server since JVM is shutting down"); + try { + server.stop(); + } catch (InterruptedException e) { + e.printStackTrace(System.err); + } + // Shut down CSM observability. + observability.close(); + // Shut down OpenTelemetry SDK. + openTelemetrySdk.close(); + + System.err.println("*** server shut down"); + } + }); + + server.blockUntilShutdown(); + } + + static class GreeterImpl extends GreeterGrpc.GreeterImplBase { + + @Override + public void sayHello(HelloRequest req, StreamObserver responseObserver) { + HelloReply reply = HelloReply.newBuilder().setMessage("Hello " + req.getName()).build(); + responseObserver.onNext(reply); + responseObserver.onCompleted(); + } + } +} diff --git a/examples/example-gcp-csm-observability/src/main/proto/helloworld/helloworld.proto b/examples/example-gcp-csm-observability/src/main/proto/helloworld/helloworld.proto new file mode 100644 index 00000000000..64a8c09ee16 --- /dev/null +++ b/examples/example-gcp-csm-observability/src/main/proto/helloworld/helloworld.proto @@ -0,0 +1,39 @@ +/* + * Copyright 2023 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +syntax = "proto3"; + +option java_multiple_files = true; +option java_package = "io.grpc.examples.helloworld"; +option java_outer_classname = "HelloWorldProto"; +option objc_class_prefix = "HLW"; + +package helloworld; + +// The greeting service definition. +service Greeter { + // Sends a greeting + rpc SayHello (HelloRequest) returns (HelloReply) {} +} + +// The request message containing the user's name. +message HelloRequest { + string name = 1; +} + +// The response message containing the greetings +message HelloReply { + string message = 1; +} diff --git a/examples/example-gcp-observability/build.gradle b/examples/example-gcp-observability/build.gradle index a666681d31c..531a5c2f9de 100644 --- a/examples/example-gcp-observability/build.gradle +++ b/examples/example-gcp-observability/build.gradle @@ -1,16 +1,13 @@ plugins { // Provide convenience executables for trying out the examples. id 'application' - id 'com.google.protobuf' version '0.9.4' + id 'com.google.protobuf' version '0.9.5' // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' id 'java' } repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" - } mavenCentral() mavenLocal() } @@ -25,14 +22,13 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.63.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protocVersion = '3.25.1' +def grpcVersion = '1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def protocVersion = '3.25.8' dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" implementation "io.grpc:grpc-stub:${grpcVersion}" implementation "io.grpc:grpc-gcp-observability:${grpcVersion}" - compileOnly "org.apache.tomcat:annotations-api:6.0.53" runtimeOnly "io.grpc:grpc-netty-shaded:${grpcVersion}" } @@ -66,6 +62,8 @@ application { applicationDistribution.into('bin') { from(ObservabilityHelloWorldServer) from(ObservabilityHelloWorldClient) - fileMode = 0755 + filePermissions { + unix(0755) + } } } diff --git a/examples/example-gcp-observability/settings.gradle b/examples/example-gcp-observability/settings.gradle index 1e4ba3812eb..39efc20a459 100644 --- a/examples/example-gcp-observability/settings.gradle +++ b/examples/example-gcp-observability/settings.gradle @@ -1 +1,17 @@ +pluginManagement { + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. + buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") + } + } +} + rootProject.name = 'example-gcp-observability' diff --git a/examples/example-hostname/BUILD.bazel b/examples/example-hostname/BUILD.bazel index 8b76f790983..d5bd3aba94c 100644 --- a/examples/example-hostname/BUILD.bazel +++ b/examples/example-hostname/BUILD.bazel @@ -1,5 +1,8 @@ -load("@rules_proto//proto:defs.bzl", "proto_library") +load("@com_google_protobuf//bazel:java_proto_library.bzl", "java_proto_library") +load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library") load("@io_grpc_grpc_java//:java_grpc_library.bzl", "java_grpc_library") +load("@rules_java//java:java_binary.bzl", "java_binary") +load("@rules_java//java:java_library.bzl", "java_library") proto_library( name = "helloworld_proto", diff --git a/examples/example-hostname/build.gradle b/examples/example-hostname/build.gradle index 113a6aaae98..f776de41511 100644 --- a/examples/example-hostname/build.gradle +++ b/examples/example-hostname/build.gradle @@ -2,13 +2,11 @@ plugins { id 'application' // Provide convenience executables for trying out the examples. id 'java' - id "com.google.protobuf" version "0.9.4" - id 'com.google.cloud.tools.jib' version '3.3.2' // For releasing to Docker Hub + id "com.google.protobuf" version "0.9.5" + id 'com.google.cloud.tools.jib' version '3.4.4' // For releasing to Docker Hub } repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" } mavenCentral() mavenLocal() } @@ -23,14 +21,13 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.63.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protobufVersion = '3.25.1' +def grpcVersion = '1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def protobufVersion = '3.25.8' dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" implementation "io.grpc:grpc-stub:${grpcVersion}" implementation "io.grpc:grpc-services:${grpcVersion}" - compileOnly "org.apache.tomcat:annotations-api:6.0.53" runtimeOnly "io.grpc:grpc-netty-shaded:${grpcVersion}" testImplementation 'junit:junit:4.13.2' diff --git a/examples/example-hostname/pom.xml b/examples/example-hostname/pom.xml index d2bf40e1f8e..8a3c231e3eb 100644 --- a/examples/example-hostname/pom.xml +++ b/examples/example-hostname/pom.xml @@ -6,14 +6,14 @@ jar - 1.63.0-SNAPSHOT + 1.82.0-SNAPSHOT example-hostname https://github.com/grpc/grpc-java UTF-8 - 1.63.0-SNAPSHOT - 3.25.1 + 1.82.0-SNAPSHOT + 3.25.8 1.8 1.8 @@ -34,32 +34,21 @@ io.grpc - grpc-protobuf + grpc-services io.grpc - grpc-stub + grpc-protobuf io.grpc - grpc-services - - - org.apache.tomcat - annotations-api - 6.0.53 - provided + grpc-stub io.grpc grpc-netty-shaded runtime - - com.google.guava - guava - 32.1.3-jre - junit junit @@ -103,7 +92,7 @@ org.apache.maven.plugins maven-enforcer-plugin - 1.4.1 + 3.5.0 enforce diff --git a/examples/example-hostname/settings.gradle b/examples/example-hostname/settings.gradle index aa159eb0946..5bd641b3fc1 100644 --- a/examples/example-hostname/settings.gradle +++ b/examples/example-hostname/settings.gradle @@ -1 +1,17 @@ +pluginManagement { + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. + buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") + } + } +} + rootProject.name = 'hostname' diff --git a/examples/example-hostname/src/main/java/io/grpc/examples/hostname/HostnameServer.java b/examples/example-hostname/src/main/java/io/grpc/examples/hostname/HostnameServer.java index 3c63296d7fa..7baa2d4733d 100644 --- a/examples/example-hostname/src/main/java/io/grpc/examples/hostname/HostnameServer.java +++ b/examples/example-hostname/src/main/java/io/grpc/examples/hostname/HostnameServer.java @@ -21,7 +21,7 @@ import io.grpc.Server; import io.grpc.ServerBuilder; import io.grpc.health.v1.HealthCheckResponse.ServingStatus; -import io.grpc.protobuf.services.ProtoReflectionService; +import io.grpc.protobuf.services.ProtoReflectionServiceV1; import io.grpc.services.HealthStatusManager; import java.io.IOException; import java.util.concurrent.TimeUnit; @@ -53,7 +53,7 @@ public static void main(String[] args) throws IOException, InterruptedException HealthStatusManager health = new HealthStatusManager(); final Server server = Grpc.newServerBuilderForPort(port, InsecureServerCredentials.create()) .addService(new HostnameGreeter(hostname)) - .addService(ProtoReflectionService.newInstance()) + .addService(ProtoReflectionServiceV1.newInstance()) .addService(health.getHealthService()) .build() .start(); @@ -64,17 +64,17 @@ public void run() { // Start graceful shutdown server.shutdown(); try { - // Wait for RPCs to complete processing - if (!server.awaitTermination(30, TimeUnit.SECONDS)) { - // That was plenty of time. Let's cancel the remaining RPCs - server.shutdownNow(); - // shutdownNow isn't instantaneous, so give a bit of time to clean resources up - // gracefully. Normally this will be well under a second. - server.awaitTermination(5, TimeUnit.SECONDS); - } + // Wait up to 30 seconds for RPCs to complete processing. + server.awaitTermination(30, TimeUnit.SECONDS); } catch (InterruptedException ex) { - server.shutdownNow(); + Thread.currentThread().interrupt(); } + // Cancel any remaining RPCs. If awaitTermination() returned true above, then there are no + // RPCs and the server is already terminated. But it is safe to call even when terminated. + server.shutdownNow(); + // shutdownNow isn't instantaneous, so you want an additional awaitTermination() to give + // time to clean resources up gracefully. Normally it will return in well under a second. In + // this example, the server.awaitTermination() in main() provides that delay. } }); // This would normally be tied to the service's dependencies. For example, if HostnameGreeter diff --git a/examples/example-hostname/src/test/java/io/grpc/examples/hostname/HostnameGreeterTest.java b/examples/example-hostname/src/test/java/io/grpc/examples/hostname/HostnameGreeterTest.java index 5420678d036..4165064b4b1 100644 --- a/examples/example-hostname/src/test/java/io/grpc/examples/hostname/HostnameGreeterTest.java +++ b/examples/example-hostname/src/test/java/io/grpc/examples/hostname/HostnameGreeterTest.java @@ -62,7 +62,7 @@ public void sayHello_dynamicHostname() throws Exception { InProcessServerBuilder.forName("hostname") .directExecutor().addService(new HostnameGreeter(null)).build().start()); - // Just verifing the service doesn't crash + // Just verifying the service doesn't crash HelloReply reply = blockingStub.sayHello(HelloRequest.newBuilder().setName("anonymous").build()); assertTrue(reply.getMessage(), reply.getMessage().startsWith("Hello anonymous, from ")); diff --git a/examples/example-jwt-auth/build.gradle b/examples/example-jwt-auth/build.gradle index df64207e865..36e6f08b3cc 100644 --- a/examples/example-jwt-auth/build.gradle +++ b/examples/example-jwt-auth/build.gradle @@ -1,15 +1,13 @@ plugins { // Provide convenience executables for trying out the examples. id 'application' - id 'com.google.protobuf' version '0.9.4' + id 'com.google.protobuf' version '0.9.5' // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' } repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" - } + mavenCentral() mavenLocal() } @@ -23,8 +21,8 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.63.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protobufVersion = '3.25.1' +def grpcVersion = '1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def protobufVersion = '3.25.8' def protocVersion = protobufVersion dependencies { @@ -33,8 +31,6 @@ dependencies { implementation "io.jsonwebtoken:jjwt:0.9.1" implementation "javax.xml.bind:jaxb-api:2.3.1" - compileOnly "org.apache.tomcat:annotations-api:6.0.53" - runtimeOnly "io.grpc:grpc-netty-shaded:${grpcVersion}" testImplementation "io.grpc:grpc-testing:${grpcVersion}" @@ -53,16 +49,6 @@ protobuf { } } -// Inform IDEs like IntelliJ IDEA, Eclipse or NetBeans about the generated code. -sourceSets { - main { - java { - srcDirs 'build/generated/source/proto/main/grpc' - srcDirs 'build/generated/source/proto/main/java' - } - } -} - startScripts.enabled = false task hellowWorldJwtAuthServer(type: CreateStartScripts) { @@ -83,6 +69,8 @@ application { applicationDistribution.into('bin') { from(hellowWorldJwtAuthServer) from(hellowWorldJwtAuthClient) - fileMode = 0755 + filePermissions { + unix(0755) + } } } diff --git a/examples/example-jwt-auth/pom.xml b/examples/example-jwt-auth/pom.xml index 9b6e084aaf4..2989f61d4a0 100644 --- a/examples/example-jwt-auth/pom.xml +++ b/examples/example-jwt-auth/pom.xml @@ -7,15 +7,15 @@ jar - 1.63.0-SNAPSHOT + 1.82.0-SNAPSHOT example-jwt-auth https://github.com/grpc/grpc-java UTF-8 - 1.63.0-SNAPSHOT - 3.25.1 - 3.25.1 + 1.82.0-SNAPSHOT + 3.25.8 + 3.25.8 1.8 1.8 @@ -57,12 +57,6 @@ jaxb-api 2.3.1 - - org.apache.tomcat - annotations-api - 6.0.53 - provided - io.grpc grpc-testing @@ -94,7 +88,7 @@ org.xolstice.maven.plugins protobuf-maven-plugin - 0.5.1 + 0.6.1 com.google.protobuf:protoc:${protoc.version}:exe:${os.detected.classifier} @@ -116,7 +110,7 @@ org.apache.maven.plugins maven-enforcer-plugin - 1.4.1 + 3.5.0 enforce diff --git a/examples/example-jwt-auth/settings.gradle b/examples/example-jwt-auth/settings.gradle index 273558dd9cf..6bd0f0cdc2d 100644 --- a/examples/example-jwt-auth/settings.gradle +++ b/examples/example-jwt-auth/settings.gradle @@ -1,8 +1,19 @@ pluginManagement { - repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. + buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") } + } + + repositories { gradlePluginPortal() } } diff --git a/examples/example-oauth/build.gradle b/examples/example-oauth/build.gradle index a5f2dac1395..3ad99a51d5d 100644 --- a/examples/example-oauth/build.gradle +++ b/examples/example-oauth/build.gradle @@ -1,15 +1,13 @@ plugins { // Provide convenience executables for trying out the examples. id 'application' - id 'com.google.protobuf' version '0.9.4' + id 'com.google.protobuf' version '0.9.5' // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' } repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" - } + mavenCentral() mavenLocal() } @@ -23,17 +21,15 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.63.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protobufVersion = '3.25.1' +def grpcVersion = '1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def protobufVersion = '3.25.8' def protocVersion = protobufVersion dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" implementation "io.grpc:grpc-stub:${grpcVersion}" implementation "io.grpc:grpc-auth:${grpcVersion}" - implementation "com.google.auth:google-auth-library-oauth2-http:1.18.0" - - compileOnly "org.apache.tomcat:annotations-api:6.0.53" + implementation "com.google.auth:google-auth-library-oauth2-http:1.42.1" runtimeOnly "io.grpc:grpc-netty-shaded:${grpcVersion}" @@ -53,16 +49,6 @@ protobuf { } } -// Inform IDEs like IntelliJ IDEA, Eclipse or NetBeans about the generated code. -sourceSets { - main { - java { - srcDirs 'build/generated/source/proto/main/grpc' - srcDirs 'build/generated/source/proto/main/java' - } - } -} - startScripts.enabled = false task hellowWorldOauthServer(type: CreateStartScripts) { @@ -83,6 +69,8 @@ application { applicationDistribution.into('bin') { from(hellowWorldOauthServer) from(hellowWorldOauthClient) - fileMode = 0755 + filePermissions { + unix(0755) + } } } diff --git a/examples/example-oauth/gradle/wrapper/gradle-wrapper.jar b/examples/example-oauth/gradle/wrapper/gradle-wrapper.jar deleted file mode 100644 index 249e5832f09..00000000000 Binary files a/examples/example-oauth/gradle/wrapper/gradle-wrapper.jar and /dev/null differ diff --git a/examples/example-oauth/gradle/wrapper/gradle-wrapper.properties b/examples/example-oauth/gradle/wrapper/gradle-wrapper.properties deleted file mode 100644 index ae04661ee73..00000000000 --- a/examples/example-oauth/gradle/wrapper/gradle-wrapper.properties +++ /dev/null @@ -1,5 +0,0 @@ -distributionBase=GRADLE_USER_HOME -distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-7.5.1-bin.zip -zipStoreBase=GRADLE_USER_HOME -zipStorePath=wrapper/dists diff --git a/examples/example-oauth/gradlew b/examples/example-oauth/gradlew deleted file mode 100755 index a69d9cb6c20..00000000000 --- a/examples/example-oauth/gradlew +++ /dev/null @@ -1,240 +0,0 @@ -#!/bin/sh - -# -# Copyright © 2015-2021 the original authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -############################################################################## -# -# Gradle start up script for POSIX generated by Gradle. -# -# Important for running: -# -# (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is -# noncompliant, but you have some other compliant shell such as ksh or -# bash, then to run this script, type that shell name before the whole -# command line, like: -# -# ksh Gradle -# -# Busybox and similar reduced shells will NOT work, because this script -# requires all of these POSIX shell features: -# * functions; -# * expansions «$var», «${var}», «${var:-default}», «${var+SET}», -# «${var#prefix}», «${var%suffix}», and «$( cmd )»; -# * compound commands having a testable exit status, especially «case»; -# * various built-in commands including «command», «set», and «ulimit». -# -# Important for patching: -# -# (2) This script targets any POSIX shell, so it avoids extensions provided -# by Bash, Ksh, etc; in particular arrays are avoided. -# -# The "traditional" practice of packing multiple parameters into a -# space-separated string is a well documented source of bugs and security -# problems, so this is (mostly) avoided, by progressively accumulating -# options in "$@", and eventually passing that to Java. -# -# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, -# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; -# see the in-line comments for details. -# -# There are tweaks for specific operating systems such as AIX, CygWin, -# Darwin, MinGW, and NonStop. -# -# (3) This script is generated from the Groovy template -# https://github.com/gradle/gradle/blob/master/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt -# within the Gradle project. -# -# You can find Gradle at https://github.com/gradle/gradle/. -# -############################################################################## - -# Attempt to set APP_HOME - -# Resolve links: $0 may be a link -app_path=$0 - -# Need this for daisy-chained symlinks. -while - APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path - [ -h "$app_path" ] -do - ls=$( ls -ld "$app_path" ) - link=${ls#*' -> '} - case $link in #( - /*) app_path=$link ;; #( - *) app_path=$APP_HOME$link ;; - esac -done - -APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit - -APP_NAME="Gradle" -APP_BASE_NAME=${0##*/} - -# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' - -# Use the maximum available, or set MAX_FD != -1 to use that value. -MAX_FD=maximum - -warn () { - echo "$*" -} >&2 - -die () { - echo - echo "$*" - echo - exit 1 -} >&2 - -# OS specific support (must be 'true' or 'false'). -cygwin=false -msys=false -darwin=false -nonstop=false -case "$( uname )" in #( - CYGWIN* ) cygwin=true ;; #( - Darwin* ) darwin=true ;; #( - MSYS* | MINGW* ) msys=true ;; #( - NONSTOP* ) nonstop=true ;; -esac - -CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar - - -# Determine the Java command to use to start the JVM. -if [ -n "$JAVA_HOME" ] ; then - if [ -x "$JAVA_HOME/jre/sh/java" ] ; then - # IBM's JDK on AIX uses strange locations for the executables - JAVACMD=$JAVA_HOME/jre/sh/java - else - JAVACMD=$JAVA_HOME/bin/java - fi - if [ ! -x "$JAVACMD" ] ; then - die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME - -Please set the JAVA_HOME variable in your environment to match the -location of your Java installation." - fi -else - JAVACMD=java - which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. - -Please set the JAVA_HOME variable in your environment to match the -location of your Java installation." -fi - -# Increase the maximum file descriptors if we can. -if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then - case $MAX_FD in #( - max*) - MAX_FD=$( ulimit -H -n ) || - warn "Could not query maximum file descriptor limit" - esac - case $MAX_FD in #( - '' | soft) :;; #( - *) - ulimit -n "$MAX_FD" || - warn "Could not set maximum file descriptor limit to $MAX_FD" - esac -fi - -# Collect all arguments for the java command, stacking in reverse order: -# * args from the command line -# * the main class name -# * -classpath -# * -D...appname settings -# * --module-path (only if needed) -# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. - -# For Cygwin or MSYS, switch paths to Windows format before running java -if "$cygwin" || "$msys" ; then - APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) - CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) - - JAVACMD=$( cygpath --unix "$JAVACMD" ) - - # Now convert the arguments - kludge to limit ourselves to /bin/sh - for arg do - if - case $arg in #( - -*) false ;; # don't mess with options #( - /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath - [ -e "$t" ] ;; #( - *) false ;; - esac - then - arg=$( cygpath --path --ignore --mixed "$arg" ) - fi - # Roll the args list around exactly as many times as the number of - # args, so each arg winds up back in the position where it started, but - # possibly modified. - # - # NB: a `for` loop captures its iteration list before it begins, so - # changing the positional parameters here affects neither the number of - # iterations, nor the values presented in `arg`. - shift # remove old arg - set -- "$@" "$arg" # push replacement arg - done -fi - -# Collect all arguments for the java command; -# * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of -# shell script including quotes and variable substitutions, so put them in -# double quotes to make sure that they get re-expanded; and -# * put everything else in single quotes, so that it's not re-expanded. - -set -- \ - "-Dorg.gradle.appname=$APP_BASE_NAME" \ - -classpath "$CLASSPATH" \ - org.gradle.wrapper.GradleWrapperMain \ - "$@" - -# Stop when "xargs" is not available. -if ! command -v xargs >/dev/null 2>&1 -then - die "xargs is not available" -fi - -# Use "xargs" to parse quoted args. -# -# With -n1 it outputs one arg per line, with the quotes and backslashes removed. -# -# In Bash we could simply go: -# -# readarray ARGS < <( xargs -n1 <<<"$var" ) && -# set -- "${ARGS[@]}" "$@" -# -# but POSIX shell has neither arrays nor command substitution, so instead we -# post-process each arg (as a line of input to sed) to backslash-escape any -# character that might be a shell metacharacter, then use eval to reverse -# that process (while maintaining the separation between arguments), and wrap -# the whole thing up as a single "set" statement. -# -# This will of course break if any of these variables contains a newline or -# an unmatched quote. -# - -eval "set -- $( - printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | - xargs -n1 | - sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | - tr '\n' ' ' - )" '"$@"' - -exec "$JAVACMD" "$@" diff --git a/examples/example-oauth/gradlew.bat b/examples/example-oauth/gradlew.bat deleted file mode 100644 index 53a6b238d41..00000000000 --- a/examples/example-oauth/gradlew.bat +++ /dev/null @@ -1,91 +0,0 @@ -@rem -@rem Copyright 2015 the original author or authors. -@rem -@rem Licensed under the Apache License, Version 2.0 (the "License"); -@rem you may not use this file except in compliance with the License. -@rem You may obtain a copy of the License at -@rem -@rem https://www.apache.org/licenses/LICENSE-2.0 -@rem -@rem Unless required by applicable law or agreed to in writing, software -@rem distributed under the License is distributed on an "AS IS" BASIS, -@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -@rem See the License for the specific language governing permissions and -@rem limitations under the License. -@rem - -@if "%DEBUG%"=="" @echo off -@rem ########################################################################## -@rem -@rem Gradle startup script for Windows -@rem -@rem ########################################################################## - -@rem Set local scope for the variables with windows NT shell -if "%OS%"=="Windows_NT" setlocal - -set DIRNAME=%~dp0 -if "%DIRNAME%"=="" set DIRNAME=. -set APP_BASE_NAME=%~n0 -set APP_HOME=%DIRNAME% - -@rem Resolve any "." and ".." in APP_HOME to make it shorter. -for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi - -@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" - -@rem Find java.exe -if defined JAVA_HOME goto findJavaFromJavaHome - -set JAVA_EXE=java.exe -%JAVA_EXE% -version >NUL 2>&1 -if %ERRORLEVEL% equ 0 goto execute - -echo. -echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. - -goto fail - -:findJavaFromJavaHome -set JAVA_HOME=%JAVA_HOME:"=% -set JAVA_EXE=%JAVA_HOME%/bin/java.exe - -if exist "%JAVA_EXE%" goto execute - -echo. -echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. - -goto fail - -:execute -@rem Setup the command line - -set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar - - -@rem Execute Gradle -"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* - -:end -@rem End local scope for the variables with windows NT shell -if %ERRORLEVEL% equ 0 goto mainEnd - -:fail -rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of -rem the _cmd.exe /c_ return code! -set EXIT_CODE=%ERRORLEVEL% -if %EXIT_CODE% equ 0 set EXIT_CODE=1 -if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% -exit /b %EXIT_CODE% - -:mainEnd -if "%OS%"=="Windows_NT" endlocal - -:omega diff --git a/examples/example-oauth/pom.xml b/examples/example-oauth/pom.xml index 7fc1b963702..3d88e732829 100644 --- a/examples/example-oauth/pom.xml +++ b/examples/example-oauth/pom.xml @@ -7,15 +7,15 @@ jar - 1.63.0-SNAPSHOT + 1.82.0-SNAPSHOT example-oauth https://github.com/grpc/grpc-java UTF-8 - 1.63.0-SNAPSHOT - 3.25.1 - 3.25.1 + 1.82.0-SNAPSHOT + 3.25.8 + 3.25.8 1.8 1.8 @@ -30,6 +30,11 @@ pom import + + com.google.code.gson + gson + 2.13.2 + @@ -50,23 +55,11 @@ io.grpc grpc-auth - - - com.google.auth - google-auth-library-credentials - - com.google.auth google-auth-library-oauth2-http - 1.18.0 - - - org.apache.tomcat - annotations-api - 6.0.53 - provided + 1.40.0 io.grpc @@ -99,7 +92,7 @@ org.xolstice.maven.plugins protobuf-maven-plugin - 0.5.1 + 0.6.1 com.google.protobuf:protoc:${protoc.version}:exe:${os.detected.classifier} @@ -121,7 +114,7 @@ org.apache.maven.plugins maven-enforcer-plugin - 1.4.1 + 3.5.0 enforce diff --git a/examples/example-oauth/settings.gradle b/examples/example-oauth/settings.gradle index 273558dd9cf..6bd0f0cdc2d 100644 --- a/examples/example-oauth/settings.gradle +++ b/examples/example-oauth/settings.gradle @@ -1,8 +1,19 @@ pluginManagement { - repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. + buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") } + } + + repositories { gradlePluginPortal() } } diff --git a/examples/example-opentelemetry/README.md b/examples/example-opentelemetry/README.md new file mode 100644 index 00000000000..4515d5423ff --- /dev/null +++ b/examples/example-opentelemetry/README.md @@ -0,0 +1,54 @@ +gRPC OpenTelemetry Example +================ + +The example extends the gRPC "hello world" example by modifying the client and server to +showcase a sample configuration for gRPC OpenTelemetry with a Prometheus exporter. + +The example requires grpc-java to be pre-built. Using a release tag will download the relevant binaries +from a maven repository. But if you need the latest SNAPSHOT binaries you will need to follow +[COMPILING](../../COMPILING.md) to build these. + +### Build the example + +The source code is [here](src/main/java/io/grpc/examples/opentelemetry). +To build the example, run in this directory: +``` +$ ../gradlew installDist +``` +The build creates scripts `opentelemetry-server` and `opentelemetry-client` in the `build/install/example-opentelemetry/bin/` directory +which can be used to run this example. The example requires the server to be running before starting the +client. + +### Run the example + +**opentelemetry-server**: + +The opentelemetry-server accepts optional arguments for server-port and prometheus-port: + +```text +USAGE: opentelemetry-server [server-port [prometheus-port]] +``` + +**opentelemetry-client**: + +The opentelemetry-client accepts optional arguments for user-name, target and prometheus-port: + +```text +USAGE: opentelemetry-client-client [user-name [target [prometheus-port]]] +``` + +The opentelemetry-client continuously sends an RPC to the server every second. + +To make sure that the server and client metrics are being exported properly, in +a separate terminal, run the following: + +``` +$ curl localhost:9464/metrics +``` + +``` +$ curl localhost:9465/metrics +``` + +> ***NOTE:*** If the prometheus endpoint configured is overridden, please update the target in the +> above curl command. diff --git a/examples/example-opentelemetry/build.gradle b/examples/example-opentelemetry/build.gradle new file mode 100644 index 00000000000..8515f015c92 --- /dev/null +++ b/examples/example-opentelemetry/build.gradle @@ -0,0 +1,90 @@ +plugins { + // Provide convenience executables for trying out the examples. + id 'application' + id 'com.google.protobuf' version '0.9.5' + // Generate IntelliJ IDEA's .idea & .iml project files + id 'idea' +} + +repositories { + mavenCentral() + mavenLocal() +} + +java { + sourceCompatibility = JavaVersion.VERSION_1_8 + targetCompatibility = JavaVersion.VERSION_1_8 +} + +// IMPORTANT: You probably want the non-SNAPSHOT version of gRPC. Make sure you +// are looking at a tagged version of the example and not "master"! + +// Feel free to delete the comment at the next line. It is just for safely +// updating the version in our release process. +def grpcVersion = '1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def protocVersion = '3.25.8' +def openTelemetryVersion = '1.56.0' +def openTelemetryPrometheusVersion = '1.56.0-alpha' + +dependencies { + implementation "io.grpc:grpc-protobuf:${grpcVersion}" + implementation "io.grpc:grpc-stub:${grpcVersion}" + implementation "io.grpc:grpc-opentelemetry:${grpcVersion}" + implementation "io.opentelemetry:opentelemetry-sdk:${openTelemetryVersion}" + implementation "io.opentelemetry:opentelemetry-sdk-metrics:${openTelemetryVersion}" + implementation "io.opentelemetry:opentelemetry-exporter-logging:${openTelemetryVersion}" + implementation "io.opentelemetry:opentelemetry-exporter-prometheus:${openTelemetryPrometheusVersion}" + runtimeOnly "io.grpc:grpc-netty-shaded:${grpcVersion}" +} + +protobuf { + protoc { artifact = "com.google.protobuf:protoc:${protocVersion}" } + plugins { + grpc { artifact = "io.grpc:protoc-gen-grpc-java:${grpcVersion}" } + } + generateProtoTasks { + all()*.plugins { grpc {} } + } +} + +startScripts.enabled = false + +task OpenTelemetryHelloWorldServer(type: CreateStartScripts) { + mainClass = 'io.grpc.examples.opentelemetry.OpenTelemetryServer' + applicationName = 'opentelemetry-server' + outputDir = new File(project.buildDir, 'tmp/scripts/' + name) + classpath = startScripts.classpath +} + +task OpenTelemetryHelloWorldClient(type: CreateStartScripts) { + mainClass = 'io.grpc.examples.opentelemetry.OpenTelemetryClient' + applicationName = 'opentelemetry-client' + outputDir = new File(project.buildDir, 'tmp/scripts/' + name) + classpath = startScripts.classpath +} + +task LoggingOpenTelemetryHelloWorldServer(type: CreateStartScripts) { + mainClass = 'io.grpc.examples.opentelemetry.logging.LoggingOpenTelemetryServer' + applicationName = 'logging-opentelemetry-server' + outputDir = new File(project.buildDir, 'tmp/scripts/' + name) + classpath = startScripts.classpath +} + +task LoggingOpenTelemetryHelloWorldClient(type: CreateStartScripts) { + mainClass = 'io.grpc.examples.opentelemetry.logging.LoggingOpenTelemetryClient' + applicationName = 'logging-opentelemetry-client' + outputDir = new File(project.buildDir, 'tmp/scripts/' + name) + classpath = startScripts.classpath +} + +application { + applicationDistribution.into('bin') { + from(OpenTelemetryHelloWorldServer) + from(OpenTelemetryHelloWorldClient) + from(LoggingOpenTelemetryHelloWorldServer) + from(LoggingOpenTelemetryHelloWorldClient) + filePermissions { + unix(0755) + } + } +} diff --git a/examples/example-opentelemetry/settings.gradle b/examples/example-opentelemetry/settings.gradle new file mode 100644 index 00000000000..26e3bea044b --- /dev/null +++ b/examples/example-opentelemetry/settings.gradle @@ -0,0 +1,17 @@ +pluginManagement { + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. + buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") + } + } +} + +rootProject.name = 'example-opentelemetry' diff --git a/examples/example-opentelemetry/src/main/java/io/grpc/example/opentelemetry/OpenTelemetryClient.java b/examples/example-opentelemetry/src/main/java/io/grpc/example/opentelemetry/OpenTelemetryClient.java new file mode 100644 index 00000000000..a21d711750f --- /dev/null +++ b/examples/example-opentelemetry/src/main/java/io/grpc/example/opentelemetry/OpenTelemetryClient.java @@ -0,0 +1,154 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.examples.opentelemetry; + +import io.grpc.Channel; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import io.grpc.StatusRuntimeException; +import io.grpc.examples.helloworld.GreeterGrpc; +import io.grpc.examples.helloworld.HelloReply; +import io.grpc.examples.helloworld.HelloRequest; +import io.grpc.opentelemetry.GrpcOpenTelemetry; +import io.opentelemetry.exporter.prometheus.PrometheusHttpServer; +import io.opentelemetry.sdk.OpenTelemetrySdk; +import io.opentelemetry.sdk.metrics.SdkMeterProvider; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * A simple gRPC client that requests a greeting from the {@link HelloWorldServer} and + * generates gRPC OpenTelmetry metrics data based on the configuration. + */ +public class OpenTelemetryClient { + private static final Logger logger = Logger.getLogger(OpenTelemetryClient.class.getName()); + + private final GreeterGrpc.GreeterBlockingStub blockingStub; + + /** Construct client for accessing HelloWorld server using the existing channel. */ + public OpenTelemetryClient(Channel channel) { + blockingStub = GreeterGrpc.newBlockingStub(channel); + } + + /** Say hello to server. */ + public void greet(String name) { + logger.info("Will try to greet " + name + " ..."); + HelloRequest request = HelloRequest.newBuilder().setName(name).build(); + HelloReply response; + try { + response = blockingStub.sayHello(request); + } catch (StatusRuntimeException e) { + logger.log(Level.WARNING, "RPC failed: {0}", e.getStatus()); + return; + } + logger.info("Greeting: " + response.getMessage()); + } + + /** + * Greet server. If provided, the first element of {@code args} is the name to use in the + * greeting. The second argument is the target server. + */ + public static void main(String[] args) throws Exception { + String user = "world"; + // Access a service running on the local machine on port 50051 + String target = "localhost:50051"; + // The port on which prometheus metrics are exposed. + int prometheusPort = 9465; + AtomicBoolean sendRpcs = new AtomicBoolean(true); + if (args.length > 0) { + if ("--help".equals(args[0])) { + System.err.println("Usage: [name [target [prometheusPort]]]"); + System.err.println(""); + System.err.println(" name The name you wish to be greeted by. Defaults to " + user); + System.err.println(" target The server to connect to. Defaults to " + target); + System.err.println(" prometheusPort The port to expose prometheus metrics. Defaults to " + prometheusPort); + System.exit(1); + } + user = args[0]; + } + if (args.length > 1) { + target = args[1]; + } + if (args.length > 2) { + prometheusPort = Integer.parseInt(args[2]); + } + + Thread mainThread = Thread.currentThread(); + + Runtime.getRuntime().addShutdownHook(new Thread() { + @Override + public void run() { + // Use stderr here since the logger may have been reset by its JVM shutdown hook. + System.err.println("*** shutting down gRPC client since JVM is shutting down"); + + sendRpcs.set(false); + try { + mainThread.join(); + } catch (InterruptedException e) { + e.printStackTrace(System.err); + } + System.err.println("*** client shut down"); + } + }); + + // Adds a PrometheusHttpServer to convert OpenTelemetry metrics to Prometheus format and + // expose these via a HttpServer exporter to the SdkMeterProvider. + SdkMeterProvider sdkMeterProvider = SdkMeterProvider.builder() + .registerMetricReader( + PrometheusHttpServer.builder().setPort(prometheusPort).build()) + .build(); + + // Initialize OpenTelemetry SDK with MeterProvider configured with Prometeheus. + OpenTelemetrySdk openTelemetrySdk = + OpenTelemetrySdk.builder().setMeterProvider(sdkMeterProvider).build(); + + // Initialize gRPC OpenTelemetry. + // Following client metrics are enabled by default : + // 1. grpc.client.attempt.started + // 2. grpc.client.attempt.sent_total_compressed_message_size + // 3. grpc.client.attempt.rcvd_total_compressed_message_size + // 4. grpc.client.attempt.duration + // 5. grpc.client.call.duration + GrpcOpenTelemetry grpcOpenTelmetry = GrpcOpenTelemetry.newBuilder() + .sdk(openTelemetrySdk) + .build(); + // Registers gRPC OpenTelemetry globally. + grpcOpenTelmetry.registerGlobal(); + + // Create a communication channel to the server, known as a Channel. + ManagedChannel channel = Grpc.newChannelBuilder(target, InsecureChannelCredentials.create()) + .build(); + OpenTelemetryClient client = new OpenTelemetryClient(channel); + + try { + // Run RPCs every second. + while (sendRpcs.get()) { + client.greet(user); + // Sleep for a bit before sending the next RPC. + Thread.sleep(1000); + } + } finally { + channel.shutdownNow().awaitTermination(5, TimeUnit.SECONDS); + // Shut down OpenTelemetry SDK. + openTelemetrySdk.close(); + } + } +} diff --git a/examples/example-opentelemetry/src/main/java/io/grpc/example/opentelemetry/OpenTelemetryServer.java b/examples/example-opentelemetry/src/main/java/io/grpc/example/opentelemetry/OpenTelemetryServer.java new file mode 100644 index 00000000000..3601572ab4b --- /dev/null +++ b/examples/example-opentelemetry/src/main/java/io/grpc/example/opentelemetry/OpenTelemetryServer.java @@ -0,0 +1,142 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.examples.opentelemetry; + +import io.grpc.Grpc; +import io.grpc.InsecureServerCredentials; +import io.grpc.Server; +import io.grpc.examples.helloworld.GreeterGrpc; +import io.grpc.examples.helloworld.HelloReply; +import io.grpc.examples.helloworld.HelloRequest; +import io.grpc.opentelemetry.GrpcOpenTelemetry; +import io.grpc.stub.StreamObserver; +import io.opentelemetry.exporter.prometheus.PrometheusHttpServer; +import io.opentelemetry.sdk.OpenTelemetrySdk; +import io.opentelemetry.sdk.metrics.SdkMeterProvider; +import java.io.IOException; +import java.util.concurrent.TimeUnit; +import java.util.logging.Logger; + +/** + * gRPC server that manages startup/shutdown of a {@code Greeter} server and generates + * gRPC OpenTelemetry metrics data based on the configuration. + */ +public class OpenTelemetryServer { + private static final Logger logger = Logger.getLogger(OpenTelemetryServer.class.getName()); + + private Server server; + private void start(int port) throws IOException { + server = Grpc.newServerBuilderForPort(port, InsecureServerCredentials.create()) + .addService(new GreeterImpl()) + .build() + .start(); + logger.info("Server started, listening on " + port); + } + + private void stop() throws InterruptedException { + if (server != null) { + server.shutdown().awaitTermination(30, TimeUnit.SECONDS); + } + } + + /** + * Await termination on the main thread since the grpc library uses daemon threads. + */ + private void blockUntilShutdown() throws InterruptedException { + if (server != null) { + server.awaitTermination(); + } + } + + /** + * Main launches the server from the command line. + */ + public static void main(String[] args) throws IOException, InterruptedException { + // The port on which the server should run. + int port = 50051; + // The port on which prometheus metrics are exposed. + int prometheusPort = 9464; + + if (args.length > 0) { + if ("--help".equals(args[0])) { + System.err.println("Usage: [port [prometheus_port]]"); + System.err.println(""); + System.err.println(" port The port on which server will run. Defaults to " + port); + System.err.println(" prometheusPort The port to expose prometheus metrics. Defaults to " + prometheusPort); + System.exit(1); + } + port = Integer.parseInt(args[0]); + } + if (args.length > 1) { + prometheusPort = Integer.parseInt(args[1]); + } + + // Adds a PrometheusHttpServer to convert OpenTelemetry metrics to Prometheus format and + // expose these via a HttpServer exporter to the SdkMeterProvider. + SdkMeterProvider sdkMeterProvider = SdkMeterProvider.builder() + .registerMetricReader( + PrometheusHttpServer.builder().setPort(prometheusPort).build()) + .build(); + + // Initialize OpenTelemetry SDK with MeterProvider configured with Prometheus metrics exporter + OpenTelemetrySdk openTelemetrySdk = + OpenTelemetrySdk.builder().setMeterProvider(sdkMeterProvider).build(); + + // Initialize gRPC OpenTelemetry. + // Following client metrics are enabled by default : + // 1. grpc.server.call.started + // 2. grpc.server.call.sent_total_compressed_message_size + // 3. grpc.server.call.rcvd_total_compressed_message_size + // 4. grpc.server.call.duration + GrpcOpenTelemetry grpcOpenTelmetry = GrpcOpenTelemetry.newBuilder() + .sdk(openTelemetrySdk) + .build(); + // Registers gRPC OpenTelemetry globally. + grpcOpenTelmetry.registerGlobal(); + + final OpenTelemetryServer server = new OpenTelemetryServer(); + server.start(port); + + Runtime.getRuntime().addShutdownHook(new Thread() { + @Override + public void run() { + System.err.println("*** shutting down gRPC server since JVM is shutting down"); + try { + server.stop(); + } catch (InterruptedException e) { + e.printStackTrace(System.err); + } + // Shut down OpenTelemetry SDK. + openTelemetrySdk.close(); + + System.err.println("*** server shut down"); + } + }); + + server.blockUntilShutdown(); + } + + static class GreeterImpl extends GreeterGrpc.GreeterImplBase { + + @Override + public void sayHello(HelloRequest req, StreamObserver responseObserver) { + HelloReply reply = HelloReply.newBuilder().setMessage("Hello " + req.getName()).build(); + responseObserver.onNext(reply); + responseObserver.onCompleted(); + } + } +} diff --git a/examples/example-opentelemetry/src/main/java/io/grpc/example/opentelemetry/logging/LoggingOpenTelemetryClient.java b/examples/example-opentelemetry/src/main/java/io/grpc/example/opentelemetry/logging/LoggingOpenTelemetryClient.java new file mode 100644 index 00000000000..1a6d4966145 --- /dev/null +++ b/examples/example-opentelemetry/src/main/java/io/grpc/example/opentelemetry/logging/LoggingOpenTelemetryClient.java @@ -0,0 +1,154 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.examples.opentelemetry.logging; + +import io.grpc.Channel; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import io.grpc.StatusRuntimeException; +import io.grpc.examples.helloworld.GreeterGrpc; +import io.grpc.examples.helloworld.HelloReply; +import io.grpc.examples.helloworld.HelloRequest; +import io.grpc.opentelemetry.GrpcOpenTelemetry; +import io.opentelemetry.exporter.logging.LoggingMetricExporter; +import io.opentelemetry.sdk.OpenTelemetrySdk; +import io.opentelemetry.sdk.metrics.SdkMeterProvider; +import io.opentelemetry.sdk.metrics.export.PeriodicMetricReader; +import java.time.Duration; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * A simple gRPC client that requests a greeting from the {@link HelloWorldServer} and + * exports gRPC OpenTelmetry metrics data using {@code java.util.logging}. + */ +public class LoggingOpenTelemetryClient { + private static final Logger logger = Logger.getLogger(LoggingOpenTelemetryClient.class.getName()); + + private final GreeterGrpc.GreeterBlockingStub blockingStub; + + /** Construct client for accessing HelloWorld server using the existing channel. */ + public LoggingOpenTelemetryClient(Channel channel) { + blockingStub = GreeterGrpc.newBlockingStub(channel); + } + + /** Say hello to server. */ + public void greet(String name) { + logger.info("Will try to greet " + name + " ..."); + HelloRequest request = HelloRequest.newBuilder().setName(name).build(); + HelloReply response; + try { + response = blockingStub.sayHello(request); + } catch (StatusRuntimeException e) { + logger.log(Level.WARNING, "RPC failed: {0}", e.getStatus()); + return; + } + logger.info("Greeting: " + response.getMessage()); + } + + /** + * Greet server. If provided, the first element of {@code args} is the name to use in the + * greeting. The second argument is the target server. + */ + public static void main(String[] args) throws Exception { + String user = "world"; + // Access a service running on the local machine on port 50051 + String target = "localhost:50051"; + // The number of milliseconds between metric exports. + long metricExportInterval = 800L; + AtomicBoolean sendRpcs = new AtomicBoolean(true); + if (args.length > 0) { + if ("--help".equals(args[0])) { + System.err.println("Usage: [name [target]]"); + System.err.println(""); + System.err.println(" name The name you wish to be greeted by. Defaults to " + user); + System.err.println(" target The server to connect to. Defaults to " + target); + System.exit(1); + } + user = args[0]; + } + if (args.length > 1) { + target = args[1]; + } + + Thread mainThread = Thread.currentThread(); + + Runtime.getRuntime().addShutdownHook(new Thread() { + @Override + public void run() { + // Use stderr here since the logger may have been reset by its JVM shutdown hook. + System.err.println("*** shutting down gRPC client since JVM is shutting down"); + + sendRpcs.set(false); + try { + mainThread.join(); + } catch (InterruptedException e) { + e.printStackTrace(System.err); + } + System.err.println("*** client shut down"); + } + }); + + // Create an instance of PeriodicMetricReader and configure it to export + // via a logging exporter to the SdkMeterProvider. + SdkMeterProvider sdkMeterProvider = SdkMeterProvider.builder() + .registerMetricReader( + PeriodicMetricReader.builder(LoggingMetricExporter.create()) + .setInterval(Duration.ofMillis(metricExportInterval)) + .build()) + .build(); + + // Initialize OpenTelemetry SDK with MeterProvider configured with Prometeheus. + OpenTelemetrySdk openTelemetrySdk = + OpenTelemetrySdk.builder().setMeterProvider(sdkMeterProvider).build(); + + // Initialize gRPC OpenTelemetry. + // Following client metrics are enabled by default : + // 1. grpc.client.attempt.started + // 2. grpc.client.attempt.sent_total_compressed_message_size + // 3. grpc.client.attempt.rcvd_total_compressed_message_size + // 4. grpc.client.attempt.duration + // 5. grpc.client.call.duration + GrpcOpenTelemetry grpcOpenTelmetry = GrpcOpenTelemetry.newBuilder() + .sdk(openTelemetrySdk) + .build(); + // Registers gRPC OpenTelemetry globally. + grpcOpenTelmetry.registerGlobal(); + + // Create a communication channel to the server, known as a Channel. + ManagedChannel channel = Grpc.newChannelBuilder(target, InsecureChannelCredentials.create()) + .build(); + LoggingOpenTelemetryClient client = new LoggingOpenTelemetryClient(channel); + + try { + // Run RPCs every second. + while (sendRpcs.get()) { + client.greet(user); + // Sleep for a bit before sending the next RPC. + Thread.sleep(1000); + } + } finally { + channel.shutdownNow().awaitTermination(5, TimeUnit.SECONDS); + // Shut down OpenTelemetry SDK. + openTelemetrySdk.close(); + } + } +} diff --git a/examples/example-opentelemetry/src/main/java/io/grpc/example/opentelemetry/logging/LoggingOpenTelemetryServer.java b/examples/example-opentelemetry/src/main/java/io/grpc/example/opentelemetry/logging/LoggingOpenTelemetryServer.java new file mode 100644 index 00000000000..121898c3ab5 --- /dev/null +++ b/examples/example-opentelemetry/src/main/java/io/grpc/example/opentelemetry/logging/LoggingOpenTelemetryServer.java @@ -0,0 +1,144 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.examples.opentelemetry.logging; + +import io.grpc.Grpc; +import io.grpc.InsecureServerCredentials; +import io.grpc.Server; +import io.grpc.examples.helloworld.GreeterGrpc; +import io.grpc.examples.helloworld.HelloReply; +import io.grpc.examples.helloworld.HelloRequest; +import io.grpc.opentelemetry.GrpcOpenTelemetry; +import io.grpc.stub.StreamObserver; +import io.opentelemetry.exporter.logging.LoggingMetricExporter; +import io.opentelemetry.sdk.OpenTelemetrySdk; +import io.opentelemetry.sdk.metrics.SdkMeterProvider; +import io.opentelemetry.sdk.metrics.export.PeriodicMetricReader; +import java.time.Duration; +import java.io.IOException; +import java.util.concurrent.TimeUnit; +import java.util.logging.Logger; + +/** + * gRPC server that manages startup/shutdown of a {@code Greeter} server and exports + * gRPC OpenTelemetry metrics data using {@code java.util.logging}. + */ +public class LoggingOpenTelemetryServer { + private static final Logger logger = Logger.getLogger(LoggingOpenTelemetryServer.class.getName()); + + private Server gRPCServer; + private void start(int port) throws IOException { + gRPCServer = Grpc.newServerBuilderForPort(port, InsecureServerCredentials.create()) + .addService(new GreeterImpl()) + .build() + .start(); + logger.info("Server started, listening on " + port); + } + + private void stop() throws InterruptedException { + if (gRPCServer != null) { + gRPCServer.shutdown().awaitTermination(30, TimeUnit.SECONDS); + } + } + + /** + * Await termination on the main thread since the grpc library uses daemon threads. + */ + private void blockUntilShutdown() throws InterruptedException { + if (gRPCServer != null) { + gRPCServer.awaitTermination(); + } + } + + /** + * Main launches the server from the command line. + */ + public static void main(String[] args) throws IOException, InterruptedException { + // The port on which the server should run. + int port = 50051; + // The port on which prometheus metrics are exposed. + int prometheusPort = 9464; + // The number of milliseconds between metric exports. + long metricExportInterval = 800L; + + if (args.length > 0) { + if ("--help".equals(args[0])) { + System.err.println("Usage: [port]"); + System.err.println(""); + System.err.println(" port The port on which server will run. Defaults to " + port); + System.exit(1); + } + port = Integer.parseInt(args[0]); + } + + // Create an instance of PeriodicMetricReader and configure it to export + // via a logging exporter to the SdkMeterProvider. + SdkMeterProvider sdkMeterProvider = SdkMeterProvider.builder() + .registerMetricReader( + PeriodicMetricReader.builder(LoggingMetricExporter.create()) + .setInterval(Duration.ofMillis(metricExportInterval)) + .build()) + .build(); + + // Initialize OpenTelemetry SDK with MeterProvider configured with Logging metrics exporter + OpenTelemetrySdk openTelemetrySdk = + OpenTelemetrySdk.builder().setMeterProvider(sdkMeterProvider).build(); + + // Initialize gRPC OpenTelemetry. + // Following client metrics are enabled by default : + // 1. grpc.server.call.started + // 2. grpc.server.call.sent_total_compressed_message_size + // 3. grpc.server.call.rcvd_total_compressed_message_size + // 4. grpc.server.call.duration + GrpcOpenTelemetry grpcOpenTelmetry = GrpcOpenTelemetry.newBuilder() + .sdk(openTelemetrySdk) + .build(); + // Registers gRPC OpenTelemetry globally. + grpcOpenTelmetry.registerGlobal(); + + final LoggingOpenTelemetryServer server = new LoggingOpenTelemetryServer(); + server.start(port); + + Runtime.getRuntime().addShutdownHook(new Thread() { + @Override + public void run() { + System.err.println("*** shutting down gRPC server since JVM is shutting down"); + try { + server.stop(); + } catch (InterruptedException e) { + e.printStackTrace(System.err); + } + // Shut down OpenTelemetry SDK. + openTelemetrySdk.close(); + + System.err.println("*** server shut down"); + } + }); + + server.blockUntilShutdown(); + } + + static class GreeterImpl extends GreeterGrpc.GreeterImplBase { + + @Override + public void sayHello(HelloRequest req, StreamObserver responseObserver) { + HelloReply reply = HelloReply.newBuilder().setMessage("Hello " + req.getName()).build(); + responseObserver.onNext(reply); + responseObserver.onCompleted(); + } + } +} diff --git a/examples/example-opentelemetry/src/main/proto/helloworld/helloworld.proto b/examples/example-opentelemetry/src/main/proto/helloworld/helloworld.proto new file mode 100644 index 00000000000..64a8c09ee16 --- /dev/null +++ b/examples/example-opentelemetry/src/main/proto/helloworld/helloworld.proto @@ -0,0 +1,39 @@ +/* + * Copyright 2023 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +syntax = "proto3"; + +option java_multiple_files = true; +option java_package = "io.grpc.examples.helloworld"; +option java_outer_classname = "HelloWorldProto"; +option objc_class_prefix = "HLW"; + +package helloworld; + +// The greeting service definition. +service Greeter { + // Sends a greeting + rpc SayHello (HelloRequest) returns (HelloReply) {} +} + +// The request message containing the user's name. +message HelloRequest { + string name = 1; +} + +// The response message containing the greetings +message HelloReply { + string message = 1; +} diff --git a/examples/example-orca/build.gradle b/examples/example-orca/build.gradle index a18f7e363fa..65627159c9c 100644 --- a/examples/example-orca/build.gradle +++ b/examples/example-orca/build.gradle @@ -1,14 +1,12 @@ plugins { id 'application' // Provide convenience executables for trying out the examples. - id 'com.google.protobuf' version '0.9.4' + id 'com.google.protobuf' version '0.9.5' // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' id 'java' } repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" } mavenCentral() mavenLocal() } @@ -18,16 +16,14 @@ java { targetCompatibility = JavaVersion.VERSION_1_8 } -def grpcVersion = '1.63.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protocVersion = '3.25.1' +def grpcVersion = '1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def protocVersion = '3.25.8' dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" implementation "io.grpc:grpc-services:${grpcVersion}" implementation "io.grpc:grpc-stub:${grpcVersion}" implementation "io.grpc:grpc-xds:${grpcVersion}" - compileOnly "org.apache.tomcat:annotations-api:6.0.53" - } protobuf { @@ -60,6 +56,8 @@ application { applicationDistribution.into('bin') { from(CustomBackendMetricsClient) from(CustomBackendMetricsServer) - fileMode = 0755 + filePermissions { + unix(0755) + } } } diff --git a/examples/example-orca/settings.gradle b/examples/example-orca/settings.gradle index 3c62dc663ce..12536c0ca8d 100644 --- a/examples/example-orca/settings.gradle +++ b/examples/example-orca/settings.gradle @@ -1 +1,17 @@ +pluginManagement { + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. + buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") + } + } +} + rootProject.name = 'example-orca' diff --git a/examples/example-reflection/README.md b/examples/example-reflection/README.md index 9bd91f3edb0..4bc30e84b3b 100644 --- a/examples/example-reflection/README.md +++ b/examples/example-reflection/README.md @@ -1,7 +1,7 @@ gRPC Reflection Example ================ -The reflection example has a Hello World server with `ProtoReflectionService` registered. +The reflection example has a Hello World server with `ProtoReflectionServiceV1` registered. ### Build the example @@ -45,7 +45,7 @@ Output ### List all the methods of a service ``` - $ grpcurl -plaintext localhost:50051 helloworld.Greeter + $ grpcurl -plaintext localhost:50051 list helloworld.Greeter ``` Output ``` diff --git a/examples/example-reflection/build.gradle b/examples/example-reflection/build.gradle index 8fc8c4e8cc8..7c54ea281d5 100644 --- a/examples/example-reflection/build.gradle +++ b/examples/example-reflection/build.gradle @@ -1,14 +1,12 @@ plugins { id 'application' // Provide convenience executables for trying out the examples. - id 'com.google.protobuf' version '0.9.4' + id 'com.google.protobuf' version '0.9.5' // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' id 'java' } repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" } mavenCentral() mavenLocal() } @@ -18,16 +16,14 @@ java { targetCompatibility = JavaVersion.VERSION_1_8 } -def grpcVersion = '1.63.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protocVersion = '3.25.1' +def grpcVersion = '1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def protocVersion = '3.25.8' dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" implementation "io.grpc:grpc-services:${grpcVersion}" implementation "io.grpc:grpc-stub:${grpcVersion}" implementation "io.grpc:grpc-netty-shaded:${grpcVersion}" - compileOnly "org.apache.tomcat:annotations-api:6.0.53" - } protobuf { @@ -52,6 +48,8 @@ task ReflectionServer(type: CreateStartScripts) { application { applicationDistribution.into('bin') { from(ReflectionServer) - fileMode = 0755 + filePermissions { + unix(0755) + } } } diff --git a/examples/example-reflection/settings.gradle b/examples/example-reflection/settings.gradle index dccb973085e..28e44b77905 100644 --- a/examples/example-reflection/settings.gradle +++ b/examples/example-reflection/settings.gradle @@ -1 +1,17 @@ +pluginManagement { + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. + buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") + } + } +} + rootProject.name = 'example-reflection' diff --git a/examples/example-reflection/src/main/java/io/grpc/examples/reflection/ReflectionServer.java b/examples/example-reflection/src/main/java/io/grpc/examples/reflection/ReflectionServer.java index ad702247ba7..8406317aad6 100644 --- a/examples/example-reflection/src/main/java/io/grpc/examples/reflection/ReflectionServer.java +++ b/examples/example-reflection/src/main/java/io/grpc/examples/reflection/ReflectionServer.java @@ -7,7 +7,7 @@ import io.grpc.examples.helloworld.GreeterGrpc; import io.grpc.examples.helloworld.HelloReply; import io.grpc.examples.helloworld.HelloRequest; -import io.grpc.protobuf.services.ProtoReflectionService; +import io.grpc.protobuf.services.ProtoReflectionServiceV1; import io.grpc.stub.StreamObserver; import java.io.IOException; import java.util.concurrent.TimeUnit; @@ -26,7 +26,7 @@ private void start() throws IOException { int port = 50051; server = Grpc.newServerBuilderForPort(port, InsecureServerCredentials.create()) .addService(new GreeterImpl()) - .addService(ProtoReflectionService.newInstance()) // add reflection service + .addService(ProtoReflectionServiceV1.newInstance()) // add reflection service .build() .start(); logger.info("Server started, listening on " + port); diff --git a/examples/example-servlet/build.gradle b/examples/example-servlet/build.gradle index ec59cdde6df..b83d38be5b5 100644 --- a/examples/example-servlet/build.gradle +++ b/examples/example-servlet/build.gradle @@ -1,13 +1,12 @@ plugins { - id 'com.google.protobuf' version '0.9.4' + id 'com.google.protobuf' version '0.9.5' // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' id 'war' } repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" } + mavenCentral() mavenLocal() } @@ -16,16 +15,15 @@ java { targetCompatibility = JavaVersion.VERSION_1_8 } -def grpcVersion = '1.63.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protocVersion = '3.25.1' +def grpcVersion = '1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def protocVersion = '3.25.8' dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}", "io.grpc:grpc-servlet:${grpcVersion}", "io.grpc:grpc-stub:${grpcVersion}" - compileOnly "javax.servlet:javax.servlet-api:4.0.1", - "org.apache.tomcat:annotations-api:6.0.53" + compileOnly "javax.servlet:javax.servlet-api:4.0.1" } protobuf { @@ -35,13 +33,3 @@ protobuf { all()*.plugins { grpc {} } } } - -// Inform IDEs like IntelliJ IDEA, Eclipse or NetBeans about the generated code. -sourceSets { - main { - java { - srcDirs 'build/generated/source/proto/main/grpc' - srcDirs 'build/generated/source/proto/main/java' - } - } -} diff --git a/examples/example-servlet/settings.gradle b/examples/example-servlet/settings.gradle index 273558dd9cf..6bd0f0cdc2d 100644 --- a/examples/example-servlet/settings.gradle +++ b/examples/example-servlet/settings.gradle @@ -1,8 +1,19 @@ pluginManagement { - repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. + buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") } + } + + repositories { gradlePluginPortal() } } diff --git a/examples/example-tls/BUILD.bazel b/examples/example-tls/BUILD.bazel index 81913836766..cb46ef5bb30 100644 --- a/examples/example-tls/BUILD.bazel +++ b/examples/example-tls/BUILD.bazel @@ -1,5 +1,8 @@ -load("@rules_proto//proto:defs.bzl", "proto_library") +load("@com_google_protobuf//bazel:java_proto_library.bzl", "java_proto_library") +load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library") load("@io_grpc_grpc_java//:java_grpc_library.bzl", "java_grpc_library") +load("@rules_java//java:java_binary.bzl", "java_binary") +load("@rules_java//java:java_library.bzl", "java_library") proto_library( name = "helloworld_proto", diff --git a/examples/example-tls/build.gradle b/examples/example-tls/build.gradle index 29246921b65..4fe0794d62b 100644 --- a/examples/example-tls/build.gradle +++ b/examples/example-tls/build.gradle @@ -1,15 +1,12 @@ plugins { // Provide convenience executables for trying out the examples. id 'application' - id 'com.google.protobuf' version '0.9.4' + id 'com.google.protobuf' version '0.9.5' // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' } repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" - } mavenCentral() mavenLocal() } @@ -24,13 +21,12 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.63.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protocVersion = '3.25.1' +def grpcVersion = '1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def protocVersion = '3.25.8' dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" implementation "io.grpc:grpc-stub:${grpcVersion}" - compileOnly "org.apache.tomcat:annotations-api:6.0.53" runtimeOnly "io.grpc:grpc-netty-shaded:${grpcVersion}" } @@ -44,16 +40,6 @@ protobuf { } } -// Inform IDEs like IntelliJ IDEA, Eclipse or NetBeans about the generated code. -sourceSets { - main { - java { - srcDirs 'build/generated/source/proto/main/grpc' - srcDirs 'build/generated/source/proto/main/java' - } - } -} - startScripts.enabled = false task helloWorldTlsServer(type: CreateStartScripts) { @@ -74,6 +60,8 @@ application { applicationDistribution.into('bin') { from(helloWorldTlsServer) from(helloWorldTlsClient) - fileMode = 0755 + filePermissions { + unix(0755) + } } } diff --git a/examples/example-tls/pom.xml b/examples/example-tls/pom.xml index 53b9a2e54c0..dfe611e4fe7 100644 --- a/examples/example-tls/pom.xml +++ b/examples/example-tls/pom.xml @@ -6,14 +6,14 @@ jar - 1.63.0-SNAPSHOT + 1.82.0-SNAPSHOT example-tls https://github.com/grpc/grpc-java UTF-8 - 1.63.0-SNAPSHOT - 3.25.1 + 1.82.0-SNAPSHOT + 3.25.8 1.8 1.8 @@ -40,12 +40,6 @@ io.grpc grpc-stub - - org.apache.tomcat - annotations-api - 6.0.53 - provided - io.grpc grpc-netty-shaded @@ -82,7 +76,7 @@ org.apache.maven.plugins maven-enforcer-plugin - 1.4.1 + 3.5.0 enforce diff --git a/examples/example-tls/settings.gradle b/examples/example-tls/settings.gradle index 273558dd9cf..6bd0f0cdc2d 100644 --- a/examples/example-tls/settings.gradle +++ b/examples/example-tls/settings.gradle @@ -1,8 +1,19 @@ pluginManagement { - repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. + buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") } + } + + repositories { gradlePluginPortal() } } diff --git a/examples/example-xds/build.gradle b/examples/example-xds/build.gradle index 4d25077a1e2..1974c86798e 100644 --- a/examples/example-xds/build.gradle +++ b/examples/example-xds/build.gradle @@ -1,14 +1,12 @@ plugins { id 'application' // Provide convenience executables for trying out the examples. - id 'com.google.protobuf' version '0.9.4' + id 'com.google.protobuf' version '0.9.5' // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' id 'java' } repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" } mavenCentral() mavenLocal() } @@ -23,15 +21,14 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.63.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protocVersion = '3.25.1' +def grpcVersion = '1.82.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def protocVersion = '3.25.8' dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" implementation "io.grpc:grpc-services:${grpcVersion}" implementation "io.grpc:grpc-stub:${grpcVersion}" implementation "io.grpc:grpc-xds:${grpcVersion}" - compileOnly "org.apache.tomcat:annotations-api:6.0.53" runtimeOnly "io.grpc:grpc-netty-shaded:${grpcVersion}" } @@ -66,6 +63,8 @@ application { applicationDistribution.into('bin') { from(xdsHelloWorldClient) from(xdsHelloWorldServer) - fileMode = 0755 + filePermissions { + unix(0755) + } } } diff --git a/examples/example-xds/settings.gradle b/examples/example-xds/settings.gradle index 878f1f23ae3..4197fa6760d 100644 --- a/examples/example-xds/settings.gradle +++ b/examples/example-xds/settings.gradle @@ -1 +1,17 @@ +pluginManagement { + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. + buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") + } + } +} + rootProject.name = 'example-xds' diff --git a/examples/example-xds/src/main/java/io/grpc/examples/helloworldxds/XdsHelloWorldServer.java b/examples/example-xds/src/main/java/io/grpc/examples/helloworldxds/XdsHelloWorldServer.java index 93317dda23e..c7c67f8d681 100644 --- a/examples/example-xds/src/main/java/io/grpc/examples/helloworldxds/XdsHelloWorldServer.java +++ b/examples/example-xds/src/main/java/io/grpc/examples/helloworldxds/XdsHelloWorldServer.java @@ -20,7 +20,7 @@ import io.grpc.Server; import io.grpc.ServerCredentials; import io.grpc.health.v1.HealthCheckResponse.ServingStatus; -import io.grpc.protobuf.services.ProtoReflectionService; +import io.grpc.protobuf.services.ProtoReflectionServiceV1; import io.grpc.services.HealthStatusManager; import io.grpc.xds.XdsServerBuilder; import io.grpc.xds.XdsServerCredentials; @@ -66,7 +66,7 @@ public static void main(String[] args) throws IOException, InterruptedException final HealthStatusManager health = new HealthStatusManager(); final Server server = XdsServerBuilder.forPort(port, credentials) .addService(new HostnameGreeter(hostname)) - .addService(ProtoReflectionService.newInstance()) // convenient for command line tools + .addService(ProtoReflectionServiceV1.newInstance()) // convenient for command line tools .addService(health.getHealthService()) // allow management servers to monitor health .build() .start(); diff --git a/examples/gradle/wrapper/gradle-wrapper.properties b/examples/gradle/wrapper/gradle-wrapper.properties index db9a6b825d7..1e2fbf0d458 100644 --- a/examples/gradle/wrapper/gradle-wrapper.properties +++ b/examples/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,5 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-8.3-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-8.10.2-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/examples/maven-assembly-jar-with-dependencies.xml b/examples/maven-assembly-jar-with-dependencies.xml new file mode 100644 index 00000000000..6c8abbfe7e8 --- /dev/null +++ b/examples/maven-assembly-jar-with-dependencies.xml @@ -0,0 +1,27 @@ + + + jar-with-dependencies + + jar + + false + + + / + true + true + runtime + + + + + metaInf-services + + + diff --git a/examples/pom.xml b/examples/pom.xml index 80a2d830fd4..943182b60fe 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -6,15 +6,15 @@ jar - 1.63.0-SNAPSHOT + 1.82.0-SNAPSHOT examples https://github.com/grpc/grpc-java UTF-8 - 1.63.0-SNAPSHOT - 3.25.1 - 3.25.1 + 1.82.0-SNAPSHOT + 3.25.8 + 3.25.8 1.8 1.8 @@ -35,16 +35,16 @@ io.grpc - grpc-netty-shaded - runtime + grpc-services io.grpc - grpc-protobuf + grpc-netty-shaded + runtime io.grpc - grpc-services + grpc-protobuf io.grpc @@ -55,16 +55,10 @@ protobuf-java-util ${protobuf.version} - - com.google.guava - guava - 32.1.3-jre - - - org.apache.tomcat - annotations-api - 6.0.53 - provided + + com.google.j2objc + j2objc-annotations + 3.1 io.grpc @@ -115,7 +109,7 @@ org.apache.maven.plugins maven-enforcer-plugin - 1.4.1 + 3.5.0 enforce @@ -130,6 +124,35 @@ + + + + + + + + maven-assembly-plugin + 3.7.1 + + ${project.basedir}/maven-assembly-jar-with-dependencies.xml + + + + make-assembly + package + + single + + + + diff --git a/examples/settings.gradle b/examples/settings.gradle index 0473750a54f..4d39e8b45ba 100644 --- a/examples/settings.gradle +++ b/examples/settings.gradle @@ -1,8 +1,19 @@ pluginManagement { - repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. + buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") } + } + + repositories { gradlePluginPortal() } } diff --git a/examples/src/main/java/io/grpc/examples/advanced/README.md b/examples/src/main/java/io/grpc/examples/advanced/README.md new file mode 100644 index 00000000000..f5b5c6cc7fc --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/advanced/README.md @@ -0,0 +1,16 @@ +gRPC JSON Serialization Example +===================== + +gRPC is a modern high-performance framework for building Remote Procedure Call (RPC) systems. +It commonly uses Protocol Buffers (Protobuf) as its serialization format, which is compact and efficient. +However, gRPC can also support JSON serialization when needed, typically for interoperability with +systems or clients that do not use Protobuf. +This is an advanced example of how to swap out the serialization logic, Normal users do not need to do this. +This code is not intended to be a production-ready implementation, since JSON encoding is slow. +Additionally, JSON serialization as implemented may be not resilient to malicious input. + +This advanced example uses Marshaller for JSON which marshals in the Protobuf 3 format described here +https://developers.google.com/protocol-buffers/docs/proto3#json + +If you are considering implementing your own serialization logic, contact the grpc team at +https://groups.google.com/forum/#!forum/grpc-io diff --git a/examples/src/main/java/io/grpc/examples/cancellation/README.md b/examples/src/main/java/io/grpc/examples/cancellation/README.md new file mode 100644 index 00000000000..6b11a17c517 --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/cancellation/README.md @@ -0,0 +1,18 @@ +gRPC Cancellation Example +===================== + +When a gRPC client is no longer interested in the result of an RPC call, +it may cancel to signal this discontinuation of interest to the server. + +Any abort of an ongoing RPC is considered "cancellation" of that RPC. +The common causes of cancellation are the client explicitly cancelling, the deadline expires, and I/O failures. +The service is not informed the reason for the cancellation. + +There are two APIs for services to be notified of RPC cancellation: io.grpc.Context and ServerCallStreamObserver + +Context listeners are called on a different thread, so need to be thread-safe. +The ServerCallStreamObserver cancellation callback is called like other StreamObserver callbacks, +so the application may not need thread-safe handling. +Both APIs have thread-safe isCancelled() polling methods. + +Refer the gRPC documentation for details on Cancellation of RPCs https://grpc.io/docs/guides/cancellation/ diff --git a/examples/src/main/java/io/grpc/examples/customloadbalance/README.md b/examples/src/main/java/io/grpc/examples/customloadbalance/README.md new file mode 100644 index 00000000000..20dbccb81ac --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/customloadbalance/README.md @@ -0,0 +1,19 @@ +gRPC Custom Load Balance Example +===================== + +One of the key features of gRPC is load balancing, which allows requests from clients to be distributed across multiple servers. +This helps prevent any one server from becoming overloaded and allows the system to scale up by adding more servers. + +A gRPC load balancing policy is given a list of server IP addresses by the name resolver. +The policy is responsible for maintaining connections (subchannels) to the servers and picking a connection to use when an RPC is sent. + +This example gives the details about how we can implement our own custom load balance policy, If the built-in policies does not meet your requirements +and follow below steps for the same. + + - Register your implementation in the load balancer registry so that it can be referred to from the service config + - Parse the JSON configuration object of your implementation. This allows your load balancer to be configured in the service config with any arbitrary JSON you choose to support + - Manage what backends to maintain a connection with + - Implement a picker that will choose which backend to connect to when an RPC is made. Note that this needs to be a fast operation as it is on the RPC call path + - To enable your load balancer, configure it in your service config + +Refer the gRPC documentation for more details https://grpc.io/docs/guides/custom-load-balancing/ diff --git a/examples/src/main/java/io/grpc/examples/customloadbalance/ShufflingPickFirstLoadBalancer.java b/examples/src/main/java/io/grpc/examples/customloadbalance/ShufflingPickFirstLoadBalancer.java index 4cf09170c8d..4715b551524 100644 --- a/examples/src/main/java/io/grpc/examples/customloadbalance/ShufflingPickFirstLoadBalancer.java +++ b/examples/src/main/java/io/grpc/examples/customloadbalance/ShufflingPickFirstLoadBalancer.java @@ -92,7 +92,7 @@ public void onSubchannelState(ConnectivityStateInfo stateInfo) { }); this.subchannel = subchannel; - helper.updateBalancingState(CONNECTING, new Picker(PickResult.withNoResult())); + helper.updateBalancingState(CONNECTING, new FixedResultPicker(PickResult.withNoResult())); subchannel.requestConnection(); } else { subchannel.updateAddresses(servers); @@ -107,7 +107,8 @@ public void handleNameResolutionError(Status error) { subchannel.shutdown(); subchannel = null; } - helper.updateBalancingState(TRANSIENT_FAILURE, new Picker(PickResult.withError(error))); + helper.updateBalancingState( + TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(error))); } private void processSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) { @@ -122,16 +123,16 @@ private void processSubchannelState(Subchannel subchannel, ConnectivityStateInfo SubchannelPicker picker; switch (currentState) { case IDLE: - picker = new RequestConnectionPicker(subchannel); + picker = new RequestConnectionPicker(); break; case CONNECTING: - picker = new Picker(PickResult.withNoResult()); + picker = new FixedResultPicker(PickResult.withNoResult()); break; case READY: - picker = new Picker(PickResult.withSubchannel(subchannel)); + picker = new FixedResultPicker(PickResult.withSubchannel(subchannel)); break; case TRANSIENT_FAILURE: - picker = new Picker(PickResult.withError(stateInfo.getStatus())); + picker = new FixedResultPicker(PickResult.withError(stateInfo.getStatus())); break; default: throw new IllegalArgumentException("Unsupported state:" + currentState); @@ -154,52 +155,20 @@ public void requestConnection() { } } - /** - * No-op picker which doesn't add any custom picking logic. It just passes already known result - * received in constructor. - */ - private static final class Picker extends SubchannelPicker { - - private final PickResult result; - - Picker(PickResult result) { - this.result = checkNotNull(result, "result"); - } - - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return result; - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(Picker.class).add("result", result).toString(); - } - } - /** * Picker that requests connection during the first pick, and returns noResult. */ private final class RequestConnectionPicker extends SubchannelPicker { - private final Subchannel subchannel; private final AtomicBoolean connectionRequested = new AtomicBoolean(false); - RequestConnectionPicker(Subchannel subchannel) { - this.subchannel = checkNotNull(subchannel, "subchannel"); - } - @Override public PickResult pickSubchannel(PickSubchannelArgs args) { if (connectionRequested.compareAndSet(false, true)) { - helper.getSynchronizationContext().execute(new Runnable() { - @Override - public void run() { - subchannel.requestConnection(); - } - }); + helper.getSynchronizationContext().execute( + ShufflingPickFirstLoadBalancer.this::requestConnection); } return PickResult.withNoResult(); } } -} \ No newline at end of file +} diff --git a/examples/src/main/java/io/grpc/examples/deadline/README.md b/examples/src/main/java/io/grpc/examples/deadline/README.md new file mode 100644 index 00000000000..3c7646f1e5f --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/deadline/README.md @@ -0,0 +1,15 @@ +gRPC Deadline Example +===================== + +A Deadline is used to specify a point in time past which a client is unwilling to wait for a response from a server. +This simple idea is very important in building robust distributed systems. +Clients that do not wait around unnecessarily and servers that know when to give up processing requests will improve the resource utilization and latency of your system. + +Note that while some language APIs have the concept of a deadline, others use the idea of a timeout. +When an API asks for a deadline, you provide a point in time which the call should not go past. +A timeout is the max duration of time that the call can take. +A timeout can be converted to a deadline by adding the timeout to the current time when the application starts a call. + +This Example gives usage and implementation of Deadline on Server, Client and Propagation. + +Refer the gRPC documentation for more details on Deadlines https://grpc.io/docs/guides/deadlines/ \ No newline at end of file diff --git a/examples/src/main/java/io/grpc/examples/errordetails/README.md b/examples/src/main/java/io/grpc/examples/errordetails/README.md new file mode 100644 index 00000000000..8f241ba37a7 --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/errordetails/README.md @@ -0,0 +1,16 @@ +gRPC Error Details Example +===================== + +If a gRPC call completes successfully the server returns an OK status to the client (depending on the language the OK status may or may not be directly used in your code). +But what happens if the call isn’t successful? + +This Example gives the usage and implementation of how return the error details if gRPC call not successful or fails +and how to set and read com.google.rpc.Status objects as google.rpc.Status error details. + +gRPC allows detailed error information to be encapsulated in protobuf messages, which are sent alongside the status codes. + +If an error occurs, gRPC returns one of its error status codes with error message that provides further error details about what happened. + +Refer the below links for more details on error details and status codes +- https://grpc.io/docs/guides/error/ +- https://github.com/grpc/grpc-java/blob/master/api/src/main/java/io/grpc/Status.java \ No newline at end of file diff --git a/examples/src/main/java/io/grpc/examples/errorhandling/ErrorHandlingClient.java b/examples/src/main/java/io/grpc/examples/errorhandling/ErrorHandlingClient.java index 7e310433a90..9f86e9ceac8 100644 --- a/examples/src/main/java/io/grpc/examples/errorhandling/ErrorHandlingClient.java +++ b/examples/src/main/java/io/grpc/examples/errorhandling/ErrorHandlingClient.java @@ -45,7 +45,7 @@ import javax.annotation.Nullable; /** - * Shows how to extract error information from a server response. + * Shows how to extract error information from a failed RPC. */ public class ErrorHandlingClient { public static void main(String [] args) throws Exception { @@ -60,6 +60,8 @@ void run() throws Exception { .addService(new GreeterGrpc.GreeterImplBase() { @Override public void sayHello(HelloRequest request, StreamObserver responseObserver) { + // The server will always fail, and we'll see this failure on client-side. The exception is + // not sent to the client, only the status code (i.e., INTERNAL) and description. responseObserver.onError(Status.INTERNAL .withDescription("Eggplant Xerxes Crybaby Overbite Narwhal").asRuntimeException()); } diff --git a/examples/src/main/java/io/grpc/examples/errorhandling/README.md b/examples/src/main/java/io/grpc/examples/errorhandling/README.md new file mode 100644 index 00000000000..a920e939c86 --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/errorhandling/README.md @@ -0,0 +1,27 @@ +gRPC Error Handling Example +===================== + +Error handling in gRPC is a critical aspect of designing reliable and robust distributed systems. +gRPC provides a standardized mechanism for handling errors using status codes, error details, and optional metadata. + +This Example gives the usage and implementation of how to handle the Errors/Exceptions in gRPC, +shows how to extract error information from a failed RPC and setting and reading RPC error details. + +If a gRPC call completes successfully the server returns an OK status to the client (depending on the language the OK status may or may not be directly used in your code). + +If an error occurs gRPC returns one of its error status codes with error message that provides further error details about what happened. + +Error Propagation: +- When an error occurs on the server, gRPC stops processing the RPC and sends the error (status code, description, and optional details) to the client. +- On the client side, the error can be handled based on the status code. + +Client Side Error Handling: + - The gRPC client typically throws an exception or returns an error object when an RPC fails. + +Server Side Error Handling: +- Servers use the gRPC API to return errors explicitly using the grpc library's status functions. + +gRPC uses predefined status codes to represent the outcome of an RPC call. These status codes are part of the Status object that is sent from the server to the client. +Each status code is accompanied by a human-readable description(Please refer https://github.com/grpc/grpc-java/blob/master/api/src/main/java/io/grpc/Status.java) + +Refer the gRPC documentation for more details on Error Handling https://grpc.io/docs/guides/error/ \ No newline at end of file diff --git a/examples/src/main/java/io/grpc/examples/experimental/README.md b/examples/src/main/java/io/grpc/examples/experimental/README.md new file mode 100644 index 00000000000..295b0801538 --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/experimental/README.md @@ -0,0 +1,13 @@ +gRPC Compression Example +===================== + +This example shows how clients can specify compression options when performing RPCs, +and how to enable compressed(i,e gzip) requests/responses for only particular method and in case of all methods by using the interceptors. + +Compression is used to reduce the amount of bandwidth used when communicating between client/server or peers and +can be enabled or disabled based on call or message level for all languages. + +gRPC allows asymmetrically compressed communication, whereby a response may be compressed differently with the request, +or not compressed at all. + +Refer the gRPC documentation for more details on Compression https://grpc.io/docs/guides/compression/ \ No newline at end of file diff --git a/examples/src/main/java/io/grpc/examples/grpcproxy/README.md b/examples/src/main/java/io/grpc/examples/grpcproxy/README.md new file mode 100644 index 00000000000..cc13dc3d9d0 --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/grpcproxy/README.md @@ -0,0 +1,22 @@ +gRPC Proxy Example +===================== + +A gRPC proxy is a component or tool that acts as an intermediary between gRPC clients and servers, +facilitating communication while offering additional capabilities. +Proxies are used in scenarios where you need to handle tasks like load balancing, routing, monitoring, +or providing a bridge between gRPC and other protocols. + +GrpcProxy itself can be used unmodified to proxy any service for both unary and streaming. +It doesn't care what type of messages are being used. +The Registry class causes it to be called for any inbound RPC, and uses plain bytes for messages which avoids marshalling +messages and the need for Protobuf schema information. + +We can run the Grpc Proxy with Route guide example to see how it works by running the below + +Route guide has unary and streaming RPCs which makes it a nice showcase, and we can run each in a separate terminal window. + +./build/install/examples/bin/route-guide-server +./build/install/examples/bin/grpc-proxy +./build/install/examples/bin/route-guide-client localhost:8981 + +you can verify the proxy is being used by shutting down the proxy and seeing the client fail. \ No newline at end of file diff --git a/examples/src/main/java/io/grpc/examples/header/HeaderClientInterceptor.java b/examples/src/main/java/io/grpc/examples/header/HeaderClientInterceptor.java index b9a73931299..2a60eeda6c4 100644 --- a/examples/src/main/java/io/grpc/examples/header/HeaderClientInterceptor.java +++ b/examples/src/main/java/io/grpc/examples/header/HeaderClientInterceptor.java @@ -52,7 +52,7 @@ public void start(Listener responseListener, Metadata headers) { public void onHeaders(Metadata headers) { /** * if you don't need receive header from server, - * you can use {@link io.grpc.stub.MetadataUtils#attachHeaders} + * you can use {@link io.grpc.stub.MetadataUtils#newAttachHeadersInterceptor} * directly to send header */ logger.info("header received from server:" + headers); diff --git a/examples/src/main/java/io/grpc/examples/header/README.md b/examples/src/main/java/io/grpc/examples/header/README.md new file mode 100644 index 00000000000..1563a2799cc --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/header/README.md @@ -0,0 +1,16 @@ +gRPC Custom Header Example +===================== + +This example gives the usage and implementation of how to create and process(send/receive) the custom headers between Client and Server +using the interceptors (HeaderServerInterceptor, ClientServerInterceptor) along with Metadata. + +Metadata is a side channel that allows clients and servers to provide information to each other that is associated with an RPC. +gRPC metadata is a key-value pair of data that is sent with initial or final gRPC requests or responses. +It is used to provide additional information about the call, such as authentication credentials, +tracing information, or custom headers. + +gRPC metadata can be used to send custom headers to the server or from the server to the client. +This can be used to implement application-specific features, such as load balancing, +rate limiting or providing detailed error messages from the server to the client. + +Refer the gRPC documentation for more on Metadata/Headers https://grpc.io/docs/guides/metadata/ \ No newline at end of file diff --git a/examples/src/main/java/io/grpc/examples/healthservice/HealthServiceClient.java b/examples/src/main/java/io/grpc/examples/healthservice/HealthServiceClient.java index 471084feab6..a7963630965 100644 --- a/examples/src/main/java/io/grpc/examples/healthservice/HealthServiceClient.java +++ b/examples/src/main/java/io/grpc/examples/healthservice/HealthServiceClient.java @@ -32,6 +32,7 @@ import io.grpc.health.v1.HealthCheckResponse.ServingStatus; import io.grpc.health.v1.HealthGrpc; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.concurrent.TimeUnit; @@ -45,25 +46,17 @@ public class HealthServiceClient { private static final Logger logger = Logger.getLogger(HealthServiceClient.class.getName()); private final GreeterGrpc.GreeterBlockingStub greeterBlockingStub; - private final HealthGrpc.HealthStub healthStub; private final HealthGrpc.HealthBlockingStub healthBlockingStub; - private final HealthCheckRequest healthRequest; - /** Construct client for accessing HelloWorld server using the existing channel. */ public HealthServiceClient(Channel channel) { greeterBlockingStub = GreeterGrpc.newBlockingStub(channel); - healthStub = HealthGrpc.newStub(channel); healthBlockingStub = HealthGrpc.newBlockingStub(channel); - healthRequest = HealthCheckRequest.getDefaultInstance(); - LoadBalancerProvider roundRobin = LoadBalancerRegistry.getDefaultRegistry() - .getProvider("round_robin"); - } private ServingStatus checkHealth(String prefix) { HealthCheckResponse response = - healthBlockingStub.check(healthRequest); + healthBlockingStub.check(HealthCheckRequest.getDefaultInstance()); logger.info(prefix + ", current health is: " + response.getStatus()); return response.getStatus(); } @@ -86,34 +79,35 @@ public void greet(String name) { } - private static void runTest(String target, String[] users, boolean useRoundRobin) + private static void runTest(String target, String[] users, boolean enableHealthChecking) throws InterruptedException { - ManagedChannelBuilder builder = - Grpc.newChannelBuilder(target, InsecureChannelCredentials.create()); - - // Round Robin, when a healthCheckConfig is present in the default service configuration, runs - // a watch on the health service and when picking an endpoint will - // consider a transport to a server whose service is not in SERVING state to be unavailable. - // Since we only have a single server we are connecting to, then the load balancer will - // return an error without sending the RPC. - if (useRoundRobin) { - builder = builder - .defaultLoadBalancingPolicy("round_robin") - .defaultServiceConfig(generateHealthConfig("")); + String healthServiceName; + if (enableHealthChecking) { + healthServiceName = ""; // requests the backend's "overall health status" + } else { + healthServiceName = null; // disables health checking in generateServiceConfig() } + ManagedChannel channel = + Grpc.newChannelBuilder(target, InsecureChannelCredentials.create()) + // Enable the round_robin load balancer, with or without health checking + .defaultServiceConfig(generateServiceConfig(healthServiceName)) + .build(); - ManagedChannel channel = builder.build(); + // Round Robin, when a healthCheckConfig is present in the service configuration, runs a watch + // on the health service and when picking an endpoint will consider a transport to a server + // whose service is not in SERVING state to be unavailable. Since we only have a single server + // we are connecting to, then the load balancer will return an error without sending the RPC. - System.out.println("\nDoing test with" + (useRoundRobin ? "" : "out") - + " the Round Robin load balancer\n"); + System.out.println("\nDoing test with" + (enableHealthChecking ? "" : "out") + + " health checking\n"); try { HealthServiceClient client = new HealthServiceClient(channel); - if (!useRoundRobin) { + if (!enableHealthChecking) { client.checkHealth("Before call"); } client.greet(users[0]); - if (!useRoundRobin) { + if (!enableHealthChecking) { client.checkHealth("After user " + users[0]); } @@ -122,7 +116,7 @@ private static void runTest(String target, String[] users, boolean useRoundRobin Thread.sleep(100); // Since the health update is asynchronous give it time to propagate } - if (!useRoundRobin) { + if (!enableHealthChecking) { client.checkHealth("After all users"); Thread.sleep(10000); client.checkHealth("After 10 second wait"); @@ -137,12 +131,17 @@ private static void runTest(String target, String[] users, boolean useRoundRobin channel.shutdownNow().awaitTermination(5, TimeUnit.SECONDS); } } - private static Map generateHealthConfig(String serviceName) { + private static Map generateServiceConfig(String healthServiceName) { Map config = new HashMap<>(); - Map serviceMap = new HashMap<>(); - - config.put("healthCheckConfig", serviceMap); - serviceMap.put("serviceName", serviceName); + if (healthServiceName != null) { + config.put("healthCheckConfig", Collections.singletonMap("serviceName", healthServiceName)); + } + // There is more than one round_robin implementation. If the client doesn't depend on + // io.grpc:grpc-services, then the round_robin implementation does not support health watching + // (to avoid a Protobuf dependency). When the client depends on grpc-services the + // health-supporting round_robin implementation is used instead. + config.put("loadBalancingConfig", Arrays.asList( + Collections.singletonMap("round_robin", Collections.emptyMap()))); return config; } diff --git a/examples/src/main/java/io/grpc/examples/healthservice/HealthServiceServer.java b/examples/src/main/java/io/grpc/examples/healthservice/HealthServiceServer.java index f6547c11103..2170c9d3e08 100644 --- a/examples/src/main/java/io/grpc/examples/healthservice/HealthServiceServer.java +++ b/examples/src/main/java/io/grpc/examples/healthservice/HealthServiceServer.java @@ -94,7 +94,7 @@ public static void main(String[] args) throws IOException, InterruptedException } private class GreeterImpl extends GreeterGrpc.GreeterImplBase { - boolean isServing = true; + private volatile boolean isServing = true; @Override public void sayHello(HelloRequest req, StreamObserver responseObserver) { @@ -134,7 +134,7 @@ public void run() { } private boolean isNameLongEnough(HelloRequest req) { - return isServing && req.getName().length() >= 5; + return req.getName().length() >= 5; } } } diff --git a/examples/src/main/java/io/grpc/examples/healthservice/README.md b/examples/src/main/java/io/grpc/examples/healthservice/README.md new file mode 100644 index 00000000000..9b17f96a624 --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/healthservice/README.md @@ -0,0 +1,13 @@ +gRPC Health Service Example +===================== + +The Health Service example provides a HelloWorld gRPC server that doesn't like +short names along with a health service. It also provides a client application +which makes HelloWorld calls and checks the health status. + +The client application also shows how the round robin load balancer can +utilize the health status to avoid making calls to a service that is +not actively serving. + +Note that clients must depend on `io.grpc:grpc-services` for the health-aware +round_robin implementation to be used. diff --git a/examples/src/main/java/io/grpc/examples/hedging/README.md b/examples/src/main/java/io/grpc/examples/hedging/README.md new file mode 100644 index 00000000000..0154e5c2cee --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/hedging/README.md @@ -0,0 +1,59 @@ +gRPC Hedging Example +===================== + +The Hedging example demonstrates that enabling hedging +can reduce tail latency. (Users should note that enabling hedging may introduce other overhead; +and in some scenarios, such as when some server resource gets exhausted for a period of time and +almost every RPC during that time has high latency or fails, hedging may make things worse. +Setting a throttle in the service config is recommended to protect the server from too many +inappropriate retry or hedging requests.) + +The server and the client in the example are basically the same as those in the +[hello world](src/main/java/io/grpc/examples/helloworld) example, except that the server mimics a +long tail of latency, and the client sends 2000 requests and can turn on and off hedging. + +To mimic the latency, the server randomly delays the RPC handling by 2 seconds at 10% chance, 5 +seconds at 5% chance, and 10 seconds at 1% chance. + +When running the client enabling the following hedging policy + + ```json + "hedgingPolicy": { + "maxAttempts": 3, + "hedgingDelay": "1s" + } + ``` +Then the latency summary in the client log is like the following + + ```text + Total RPCs sent: 2,000. Total RPCs failed: 0 + [Hedging enabled] + ======================== + 50% latency: 0ms + 90% latency: 6ms + 95% latency: 1,003ms + 99% latency: 2,002ms + 99.9% latency: 2,011ms + Max latency: 5,272ms + ======================== + ``` + +See [the section below](#to-build-the-examples) for how to build and run the example. The +executables for the server and the client are `hedging-hello-world-server` and +`hedging-hello-world-client`. + +To disable hedging, set environment variable `DISABLE_HEDGING_IN_HEDGING_EXAMPLE=true` before +running the client. That produces a latency summary in the client log like the following + + ```text + Total RPCs sent: 2,000. Total RPCs failed: 0 + [Hedging disabled] + ======================== + 50% latency: 0ms + 90% latency: 2,002ms + 95% latency: 5,002ms + 99% latency: 10,004ms + 99.9% latency: 10,007ms + Max latency: 10,007ms + ======================== + ``` diff --git a/examples/src/main/java/io/grpc/examples/helloworld/HelloWorldServer.java b/examples/src/main/java/io/grpc/examples/helloworld/HelloWorldServer.java index 81027587031..0e39581c98f 100644 --- a/examples/src/main/java/io/grpc/examples/helloworld/HelloWorldServer.java +++ b/examples/src/main/java/io/grpc/examples/helloworld/HelloWorldServer.java @@ -23,6 +23,8 @@ import java.io.IOException; import java.util.concurrent.TimeUnit; import java.util.logging.Logger; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; /** * Server that manages startup/shutdown of a {@code Greeter} server. @@ -31,11 +33,20 @@ public class HelloWorldServer { private static final Logger logger = Logger.getLogger(HelloWorldServer.class.getName()); private Server server; - private void start() throws IOException { /* The port on which the server should run */ int port = 50051; + /* + * By default gRPC uses a global, shared Executor.newCachedThreadPool() for gRPC callbacks into + * your application. This is convenient, but can cause an excessive number of threads to be + * created if there are many RPCs. It is often better to limit the number of threads your + * application uses for processing and let RPCs queue when the CPU is saturated. + * The appropriate number of threads varies heavily between applications. + * Async application code generally does not need more threads than CPU cores. + */ + ExecutorService executor = Executors.newFixedThreadPool(2); server = Grpc.newServerBuilderForPort(port, InsecureServerCredentials.create()) + .executor(executor) .addService(new GreeterImpl()) .build() .start(); @@ -48,7 +59,12 @@ public void run() { try { HelloWorldServer.this.stop(); } catch (InterruptedException e) { + if (server != null) { + server.shutdownNow(); + } e.printStackTrace(System.err); + } finally { + executor.shutdown(); } System.err.println("*** server shut down"); } diff --git a/examples/src/main/java/io/grpc/examples/helloworld/README.md b/examples/src/main/java/io/grpc/examples/helloworld/README.md new file mode 100644 index 00000000000..5b11d4945c2 --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/helloworld/README.md @@ -0,0 +1,7 @@ +gRPC Hello World Example +===================== +This Example gives the details about basic implementation of gRPC Client and Server along with +how the communication happens between them by sending a greeting message. + +Refer the gRPC documentation for more details on helloworld.proto specification, creation of gRPC services and +methods along with Execution process https://grpc.io/docs/languages/java/quickstart/ \ No newline at end of file diff --git a/examples/src/main/java/io/grpc/examples/keepalive/KeepAliveClient.java b/examples/src/main/java/io/grpc/examples/keepalive/KeepAliveClient.java index a7c59c3952f..414d92dea4c 100644 --- a/examples/src/main/java/io/grpc/examples/keepalive/KeepAliveClient.java +++ b/examples/src/main/java/io/grpc/examples/keepalive/KeepAliveClient.java @@ -78,7 +78,6 @@ public static void main(String[] args) throws Exception { // frames. // More details see: https://github.com/grpc/proposal/blob/master/A8-client-side-keepalive.md ManagedChannel channel = Grpc.newChannelBuilder(target, InsecureChannelCredentials.create()) - .keepAliveTime(5, TimeUnit.MINUTES) .keepAliveTime(10, TimeUnit.SECONDS) // Change to a larger value, e.g. 5min. .keepAliveTimeout(1, TimeUnit.SECONDS) // Change to a larger value, e.g. 10s. .keepAliveWithoutCalls(true)// You should normally avoid enabling this. diff --git a/examples/src/main/java/io/grpc/examples/keepalive/README.md b/examples/src/main/java/io/grpc/examples/keepalive/README.md new file mode 100644 index 00000000000..7b5b72665e7 --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/keepalive/README.md @@ -0,0 +1,16 @@ +gRPC Keepalive Example +===================== + +This example gives the usage and implementation of the Keepalives methods, configurations in gRPC Client and +Server and how the communication happens between them. + +HTTP/2 PING-based keepalives are a way to keep an HTTP/2 connection alive even when there is no data being transferred. +This is done by periodically sending a PING Frames to the other end of the connection. +HTTP/2 keepalives can improve performance and reliability of HTTP/2 connections, +but it is important to configure the keepalive interval carefully. + +gRPC sends http2 pings on the transport to detect if the connection is down. +If the ping is not acknowledged by the other side within a certain period, the connection will be closed. +Note that pings are only necessary when there's no activity on the connection. + +Refer the gRPC documentation for more on Keepalive details and configurations https://grpc.io/docs/guides/keepalive/ \ No newline at end of file diff --git a/examples/src/main/java/io/grpc/examples/loadbalance/ExampleNameResolver.java b/examples/src/main/java/io/grpc/examples/loadbalance/ExampleNameResolver.java index f562f0ac107..6ef327ade84 100644 --- a/examples/src/main/java/io/grpc/examples/loadbalance/ExampleNameResolver.java +++ b/examples/src/main/java/io/grpc/examples/loadbalance/ExampleNameResolver.java @@ -28,12 +28,12 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; -import java.util.stream.Stream; import static io.grpc.examples.loadbalance.LoadBalanceClient.exampleServiceName; public class ExampleNameResolver extends NameResolver { + static private final int[] SERVER_PORTS = {50051, 50052, 50053}; private Listener2 listener; private final URI uri; @@ -44,12 +44,11 @@ public ExampleNameResolver(URI targetUri) { this.uri = targetUri; // This is a fake name resolver, so we just hard code the address here. addrStore = ImmutableMap.>builder() - .put(exampleServiceName, - Stream.iterate(LoadBalanceServer.startPort,p->p+1) - .limit(LoadBalanceServer.serverCount) - .map(port->new InetSocketAddress("localhost",port)) - .collect(Collectors.toList()) - ) + .put(exampleServiceName, + Arrays.stream(SERVER_PORTS) + .mapToObj(port->new InetSocketAddress("localhost",port)) + .collect(Collectors.toList()) + ) .build(); } diff --git a/examples/src/main/java/io/grpc/examples/loadbalance/LoadBalanceServer.java b/examples/src/main/java/io/grpc/examples/loadbalance/LoadBalanceServer.java index c97d209497a..85ae92a537a 100644 --- a/examples/src/main/java/io/grpc/examples/loadbalance/LoadBalanceServer.java +++ b/examples/src/main/java/io/grpc/examples/loadbalance/LoadBalanceServer.java @@ -24,23 +24,24 @@ import io.grpc.stub.StreamObserver; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.TimeUnit; import java.util.logging.Logger; public class LoadBalanceServer { private static final Logger logger = Logger.getLogger(LoadBalanceServer.class.getName()); - static public final int serverCount = 3; - static public final int startPort = 50051; - private Server[] servers; + static final int[] SERVER_PORTS = {50051, 50052, 50053}; + private List servers; private void start() throws IOException { - servers = new Server[serverCount]; - for (int i = 0; i < serverCount; i++) { - int port = startPort + i; - servers[i] = ServerBuilder.forPort(port) + servers = new ArrayList<>(); + for (int port : SERVER_PORTS) { + servers.add( + ServerBuilder.forPort(port) .addService(new GreeterImpl(port)) .build() - .start(); + .start()); logger.info("Server started, listening on " + port); } Runtime.getRuntime().addShutdownHook(new Thread(() -> { @@ -55,18 +56,14 @@ private void start() throws IOException { } private void stop() throws InterruptedException { - for (int i = 0; i < serverCount; i++) { - if (servers[i] != null) { - servers[i].shutdown().awaitTermination(30, TimeUnit.SECONDS); - } + for (Server server : servers) { + server.shutdown().awaitTermination(30, TimeUnit.SECONDS); } } private void blockUntilShutdown() throws InterruptedException { - for (int i = 0; i < serverCount; i++) { - if (servers[i] != null) { - servers[i].awaitTermination(); - } + for (Server server : servers) { + server.awaitTermination(); } } @@ -86,7 +83,8 @@ public GreeterImpl(int port) { @Override public void sayHello(HelloRequest req, StreamObserver responseObserver) { - HelloReply reply = HelloReply.newBuilder().setMessage("Hello " + req.getName() + " from server<" + this.port + ">").build(); + HelloReply reply = HelloReply.newBuilder() + .setMessage("Hello " + req.getName() + " from server<" + this.port + ">").build(); responseObserver.onNext(reply); responseObserver.onCompleted(); } diff --git a/examples/src/main/java/io/grpc/examples/loadbalance/README.md b/examples/src/main/java/io/grpc/examples/loadbalance/README.md new file mode 100644 index 00000000000..0d19d2f3335 --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/loadbalance/README.md @@ -0,0 +1,20 @@ +gRPC Load Balance Example +===================== + +One of the key features of gRPC is load balancing, which allows requests from clients to be distributed across multiple servers. +This helps prevent any one server from becoming overloaded and allows the system to scale up by adding more servers. + +A gRPC load balancing policy is given a list of server IP addresses by the name resolver. +The policy is responsible for maintaining connections (subchannels) to the servers and picking a connection to use when an RPC is sent. + +By default, the pick_first policy will be used. +This policy actually does no load balancing but just tries each address it gets from the name resolver and uses the first one it can connect to. +By updating the gRPC service config you can also switch to using round_robin that connects to every address it gets and rotates through the connected backends for each RPC. +There are also some other load balancing policies available, but the exact set varies by language. + +This example gives the details about how to implement Load Balance in gRPC, If the built-in policies does not meet your requirements +you can implement your own custom load balance [Custom Load Balance](src/main/java/io/grpc/examples/customloadbalance) + +gRPC supports both client side and server side load balancing but by default gRPC uses client side load balancing. + +Refer the gRPC documentation for more details on Load Balancing https://grpc.io/blog/grpc-load-balancing/ \ No newline at end of file diff --git a/examples/src/main/java/io/grpc/examples/manualflowcontrol/BidiBlockingClient.java b/examples/src/main/java/io/grpc/examples/manualflowcontrol/BidiBlockingClient.java new file mode 100644 index 00000000000..902d46c8cc6 --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/manualflowcontrol/BidiBlockingClient.java @@ -0,0 +1,286 @@ +/* + * Copyright 2023 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.examples.manualflowcontrol; + +import com.google.protobuf.ByteString; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; +import io.grpc.ManagedChannel; +import io.grpc.StatusException; +import io.grpc.examples.manualflowcontrol.StreamingGreeterGrpc.StreamingGreeterBlockingV2Stub; +import io.grpc.stub.BlockingClientCall; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.logging.Logger; + + +/** + * A class that tries multiple ways to do blocking bidi streaming + * communication with an echo server + */ +public class BidiBlockingClient { + + private static final Logger logger = Logger.getLogger(BidiBlockingClient.class.getName()); + + /** + * Greet server. If provided, the first element of {@code args} is the name to use in the + * greeting. The second argument is the target server. You can see the multiplexing in the server + * logs. + */ + public static void main(String[] args) throws Exception { + System.setProperty("java.util.logging.SimpleFormatter.format", "%1$tH:%1$tM:%1$tS %5$s%6$s%n"); + + // Access a service running on the local machine on port 50051 + String target = "localhost:50051"; + // Allow passing in the user and target strings as command line arguments + if (args.length > 0) { + if ("--help".equals(args[0])) { + System.err.println("Usage: [target]\n"); + System.err.println(" target The server to connect to. Defaults to " + target); + System.exit(1); + } + target = args[0]; + } + + // Create a communication channel to the server, known as a Channel. Channels are thread-safe + // and reusable. It is common to create channels at the beginning of your application and reuse + // them until the application shuts down. + // + // For the example we use plaintext insecure credentials to avoid needing TLS certificates. To + // use TLS, use TlsChannelCredentials instead. + ManagedChannel channel = Grpc.newChannelBuilder(target, InsecureChannelCredentials.create()) + .build(); + StreamingGreeterBlockingV2Stub blockingStub = StreamingGreeterGrpc.newBlockingV2Stub(channel); + List echoInput = names(); + try { + long start = System.currentTimeMillis(); + List twoThreadResult = useTwoThreads(blockingStub, echoInput); + long finish = System.currentTimeMillis(); + + System.out.println("The echo requests and results were:"); + printResultMessage("Input", echoInput, 0L); + printResultMessage("2 threads", twoThreadResult, finish - start); + } finally { + // ManagedChannels use resources like threads and TCP connections. To prevent leaking these + // resources the channel should be shut down when it will no longer be used. If it may be used + // again leave it running. + channel.shutdownNow().awaitTermination(5, TimeUnit.SECONDS); + } + } + + private static void printResultMessage(String type, List result, long millis) { + String msg = String.format("%-32s: %2d, %.3f sec", type, result.size(), millis/1000.0); + logger.info(msg); + } + + private static void logMethodStart(String method) { + logger.info("--------------------- Starting to process using method: " + method); + } + + /** + * Create 2 threads, one that writes all values, and one that reads until the stream closes. + */ + private static List useTwoThreads(StreamingGreeterBlockingV2Stub blockingStub, + List valuesToWrite) throws InterruptedException { + logMethodStart("Two Threads"); + + List readValues = new ArrayList<>(); + final BlockingClientCall stream = blockingStub.sayHelloStreaming(); + + Thread reader = new Thread(null, + new Runnable() { + @Override + public void run() { + int count = 0; + try { + while (stream.hasNext()) { + readValues.add(stream.read().getMessage()); + if (++count % 10 == 0) { + logger.info("Finished " + count + " reads"); + } + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + stream.cancel("Interrupted", e); + } catch (StatusException e) { + logger.warning("Encountered error while reading: " + e); + } + } + },"reader"); + + Thread writer = new Thread(null, + new Runnable() { + @Override + public void run() { + ByteString padding = createPadding(); + int count = 0; + Iterator iterator = valuesToWrite.iterator(); + boolean hadProblem = false; + try { + while (iterator.hasNext()) { + if (!stream.write(HelloRequest.newBuilder().setName(iterator.next()).setPadding(padding) + .build())) { + logger.warning("Stream closed before writes completed"); + hadProblem = true; + break; + } + if (++count % 10 == 0) { + logger.info("Finished " + count + " writes"); + } + } + if (!hadProblem) { + logger.info("Completed writes"); + stream.halfClose(); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + stream.cancel("Interrupted", e); + } catch (StatusException e) { + logger.warning("Encountered error while writing: " + e); + } + } + }, "writer"); + + writer.start(); + reader.start(); + writer.join(); + reader.join(); + + return readValues; + } + + private static ByteString createPadding() { + int multiple = 50; + ByteBuffer data = ByteBuffer.allocate(1024 * multiple); + + for (int i = 0; i < multiple * 1024 / 4; i++) { + data.putInt(4 * i, 1111); + } + + return ByteString.copyFrom(data); + } + + + private static List names() { + return Arrays.asList( + "Sophia", + "Jackson", + "Emma", + "Aiden", + "Olivia", + "Lucas", + "Ava", + "Liam", + "Mia", + "Noah", + "Isabella", + "Ethan", + "Riley", + "Mason", + "Aria", + "Caden", + "Zoe", + "Oliver", + "Charlotte", + "Elijah", + "Lily", + "Grayson", + "Layla", + "Jacob", + "Amelia", + "Michael", + "Emily", + "Benjamin", + "Madelyn", + "Carter", + "Aubrey", + "James", + "Adalyn", + "Jayden", + "Madison", + "Logan", + "Chloe", + "Alexander", + "Harper", + "Caleb", + "Abigail", + "Ryan", + "Aaliyah", + "Luke", + "Avery", + "Daniel", + "Evelyn", + "Jack", + "Kaylee", + "William", + "Ella", + "Owen", + "Ellie", + "Gabriel", + "Scarlett", + "Matthew", + "Arianna", + "Connor", + "Hailey", + "Jayce", + "Nora", + "Isaac", + "Addison", + "Sebastian", + "Brooklyn", + "Henry", + "Hannah", + "Muhammad", + "Mila", + "Cameron", + "Leah", + "Wyatt", + "Elizabeth", + "Dylan", + "Sarah", + "Nathan", + "Eliana", + "Nicholas", + "Mackenzie", + "Julian", + "Peyton", + "Eli", + "Maria", + "Levi", + "Grace", + "Isaiah", + "Adeline", + "Landon", + "Elena", + "David", + "Anna", + "Christian", + "Victoria", + "Andrew", + "Camilla", + "Brayden", + "Lillian", + "John", + "Natalie", + "Lincoln" + ); + } +} diff --git a/examples/src/main/java/io/grpc/examples/manualflowcontrol/ManualFlowControlServer.java b/examples/src/main/java/io/grpc/examples/manualflowcontrol/ManualFlowControlServer.java index de8142596ea..3b7f980e08c 100644 --- a/examples/src/main/java/io/grpc/examples/manualflowcontrol/ManualFlowControlServer.java +++ b/examples/src/main/java/io/grpc/examples/manualflowcontrol/ManualFlowControlServer.java @@ -72,6 +72,7 @@ public void run() { // Give gRPC a StreamObserver that can observe and process incoming requests. return new StreamObserver() { + int cnt = 0; @Override public void onNext(HelloRequest request) { // Process the request and send a response or an error. @@ -81,7 +82,8 @@ public void onNext(HelloRequest request) { logger.info("--> " + name); // Simulate server "work" - Thread.sleep(100); + int sleepMillis = ++cnt % 20 == 0 ? 2000 : 100; + Thread.sleep(sleepMillis); // Send a response. String message = "Hello " + name; diff --git a/examples/src/main/java/io/grpc/examples/manualflowcontrol/README.md b/examples/src/main/java/io/grpc/examples/manualflowcontrol/README.md index a30688cea15..f700d428aca 100644 --- a/examples/src/main/java/io/grpc/examples/manualflowcontrol/README.md +++ b/examples/src/main/java/io/grpc/examples/manualflowcontrol/README.md @@ -1,5 +1,5 @@ -gRPC Manual Flow Control Example -===================== +# gRPC Manual Flow Control Example + Flow control is relevant for streaming RPC calls. By default, gRPC will handle dealing with flow control. However, for specific @@ -25,14 +25,13 @@ value. ### Outgoing Flow Control -The underlying layer (such as Netty) will make the write wait when there is no -space to write the next message. This causes the request stream to go into -a not ready state and the outgoing onNext method invocation waits. You can -explicitly check that the stream is ready for writing before calling onNext to -avoid blocking. This is done with `CallStreamObserver.isReady()`. You can -utilize this to start doing reads, which may allow -the other side of the channel to complete a write and then to do its own reads, -thereby avoiding deadlock. +The underlying layer (such as Netty) manages a buffer for outgoing messages. If +you write messages faster than they can be sent over the network, this buffer +will grow, which can eventually lead to an OutOfMemoryError. The outgoing onNext +method invocation does not block when this happens. Therefore, you should +explicitly check that the stream is ready for writing via +`CallStreamObserver.isReady()` before generating messages to avoid buffering +excessive amounts of data in memory. ### Incoming Manual Flow Control @@ -71,6 +70,7 @@ When you are ready to begin processing the next value from the stream call `serverCallStreamObserver.request(1)` ### Related documents + Also see [gRPC Flow Control Users Guide][user guide] - [user guide]: https://grpc.io/docs/guides/flow-control \ No newline at end of file +[user guide]: https://grpc.io/docs/guides/flow-control diff --git a/examples/src/main/java/io/grpc/examples/multiplex/README.md b/examples/src/main/java/io/grpc/examples/multiplex/README.md new file mode 100644 index 00000000000..fb24642a41b --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/multiplex/README.md @@ -0,0 +1,20 @@ +gRPC Multiplex Example +===================== + +gRPC multiplexing refers to the ability of a single gRPC connection to handle multiple independent streams of communication simultaneously. +This is part of the HTTP/2 protocol on which gRPC is built. +Each gRPC connection supports multiple streams that can carry different RPCs, making it highly efficient for high-throughput, low-latency communication. + +In gRPC, sharing resources like channels and servers can improve efficiency and resource utilization. + +- Sharing gRPC Channels and Servers + + 1. Shared gRPC Channel: + - A single gRPC channel can be used by multiple stubs, enabling different service clients to communicate over the same connection. + - This minimizes the overhead of establishing and managing multiple connections + + 2. Shared gRPC Server: + - A single gRPC channel can be used by multiple stubs, enabling different service clients to communicate over the same connection. + - This minimizes the overhead of establishing and managing multiple connections + +This example demonstrates how to implement a gRPC server that serves both a GreetingService and an EchoService, and a client that shares a single channel across multiple stubs for both services. \ No newline at end of file diff --git a/examples/src/main/java/io/grpc/examples/nameresolve/NameResolveClient.java b/examples/src/main/java/io/grpc/examples/nameresolve/NameResolveClient.java index ac6fdd32549..9aaccbe1096 100644 --- a/examples/src/main/java/io/grpc/examples/nameresolve/NameResolveClient.java +++ b/examples/src/main/java/io/grpc/examples/nameresolve/NameResolveClient.java @@ -26,8 +26,7 @@ import java.util.logging.Logger; public class NameResolveClient { - public static final String exampleScheme = "example"; - public static final String exampleServiceName = "lb.example.grpc.io"; + public static final String channelTarget = "example:///lb.example.grpc.io"; private static final Logger logger = Logger.getLogger(NameResolveClient.class.getName()); private final GreeterGrpc.GreeterBlockingStub blockingStub; @@ -56,11 +55,10 @@ public static void main(String[] args) throws Exception { Dial to "example:///resolver.example.grpc.io", use {@link ExampleNameResolver} to create connection "resolver.example.grpc.io" is converted to {@link java.net.URI.path} */ - channel = ManagedChannelBuilder.forTarget( - String.format("%s:///%s", exampleScheme, exampleServiceName)) - .defaultLoadBalancingPolicy("round_robin") - .usePlaintext() - .build(); + channel = ManagedChannelBuilder.forTarget(channelTarget) + .defaultLoadBalancingPolicy("round_robin") + .usePlaintext() + .build(); try { NameResolveClient client = new NameResolveClient(channel); for (int i = 0; i < 5; i++) { diff --git a/examples/src/main/java/io/grpc/examples/nameresolve/README.md b/examples/src/main/java/io/grpc/examples/nameresolve/README.md new file mode 100644 index 00000000000..36c8d7e2a6b --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/nameresolve/README.md @@ -0,0 +1,22 @@ +gRPC Name Resolve Example +===================== + +This example explains standard name resolution process and how to implement it using the Name Resolver component. + +Name Resolution is fundamentally about Service Discovery. +Name Resolution refers to the process of converting a name into an address and +Name Resolver is the component that implements the Name Resolution process. + +When sending gRPC Request, Client must determine the IP address of the Service Name, +By Default DNS Name Resolution will be used when request received from the gRPC client. + +The Name Resolver in gRPC is necessary because clients often don’t know the exact IP address or port of the server +they need to connect to. + +The client registers an implementation of a **name resolver provider** to a process-global **registry** close to the start of the process. +The name resolver provider will be called by the **gRPC library** with a **target strings** intended for the custom name resolver. +Given that target string, the name resolver provider will return an instance of a **name resolver**, +which will interact with the client connection to direct the request according to the target string. + +Refer the gRPC documentation for more on Name Resolution and Custom Name Resolution +https://grpc.io/docs/guides/custom-name-resolution/ \ No newline at end of file diff --git a/examples/src/main/java/io/grpc/examples/preserialized/README.md b/examples/src/main/java/io/grpc/examples/preserialized/README.md new file mode 100644 index 00000000000..d49b3507d03 --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/preserialized/README.md @@ -0,0 +1,18 @@ +gRPC Pre-Serialized Messages Example +===================== + +This example gives the usage and implementation of pre-serialized request and response messages +communication/exchange between grpc client and server by using ByteArrayMarshaller which produces +a byte[] instead of decoding into typical POJOs. + +This is a performance optimization that can be useful if you read the request/response from on-disk or a database +where it is already serialized, or if you need to send the same complicated message to many clients and servers. +The same approach can avoid deserializing requests/responses, to be stored in a database. + +It shows how to modify MethodDescriptor to use bytes as the response instead of HelloReply. By +adjusting toBuilder() you can choose which of the request and response are bytes. +The generated bindService() uses ServerCalls to make RPC handlers, Since the generated +bindService() won't expect byte[] in the AsyncService, this uses ServerCalls directly. + +Stubs use ClientCalls to send RPCs, Since the generated stub won't have byte[] in its +method signature, this uses ClientCalls directly. \ No newline at end of file diff --git a/examples/src/main/java/io/grpc/examples/retrying/README.md b/examples/src/main/java/io/grpc/examples/retrying/README.md new file mode 100644 index 00000000000..bb29ce75e43 --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/retrying/README.md @@ -0,0 +1,27 @@ +gRPC Retrying Example +===================== + +The Retrying example provides a HelloWorld gRPC client & +server which demos the effect of client retry policy configured on the [ManagedChannel]( +https://github.com/grpc/grpc-java/blob/master/api/src/main/java/io/grpc/ManagedChannel.java) via [gRPC ServiceConfig]( +https://github.com/grpc/grpc/blob/master/doc/service_config.md). Retry policy implementation & +configuration details are outlined in the [proposal](https://github.com/grpc/proposal/blob/master/A6-client-retries.md). + +This retrying example is very similar to the [hedging example](https://github.com/grpc/grpc-java/tree/master/examples/src/main/java/io/grpc/examples/hedging) in its setup. +The [RetryingHelloWorldServer](src/main/java/io/grpc/examples/retrying/RetryingHelloWorldServer.java) responds with +a status UNAVAILABLE error response to a specified percentage of requests to simulate server resource exhaustion and +general flakiness. The [RetryingHelloWorldClient](src/main/java/io/grpc/examples/retrying/RetryingHelloWorldClient.java) makes +a number of sequential requests to the server, several of which will be retried depending on the configured policy in +[retrying_service_config.json](https://github.com/grpc/grpc-java/blob/master/examples/src/main/resources/io/grpc/examples/retrying/retrying_service_config.json). Although +the requests are blocking unary calls for simplicity, these could easily be changed to future unary calls in order to +test the result of request concurrency with retry policy enabled. + +One can experiment with the [RetryingHelloWorldServer](src/main/java/io/grpc/examples/retrying/RetryingHelloWorldServer.java) +failure conditions to simulate server throttling, as well as alter policy values in the [retrying_service_config.json]( +https://github.com/grpc/grpc-java/blob/master/examples/src/main/resources/io/grpc/examples/retrying/retrying_service_config.json) to see their effects. To disable retrying +entirely, set environment variable `DISABLE_RETRYING_IN_RETRYING_EXAMPLE=true` before running the client. +Disabling the retry policy should produce many more failed gRPC calls as seen in the output log. + +See [the section](https://github.com/grpc/grpc-java/tree/master/examples#-to-build-the-examples) for how to build and run the example. The +executables for the server and the client are `retrying-hello-world-server` and +`retrying-hello-world-client`. diff --git a/examples/src/main/java/io/grpc/examples/routeguide/README.md b/examples/src/main/java/io/grpc/examples/routeguide/README.md new file mode 100644 index 00000000000..2528b26410c --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/routeguide/README.md @@ -0,0 +1,24 @@ +gRPC Route Guide Example +===================== + +This example illustrates how to implement and use a gRPC server and client for a RouteGuide service, +which demonstrates all 4 types of gRPC methods (unary, client streaming, server streaming, and bidirectional streaming). +Additionally, the service loads geographic features from a JSON file [route_guide_db.json](https://github.com/grpc/grpc-java/blob/master/examples/src/main/resources/io/grpc/examples/routeguide/route_guide_db.json) and retrieves features based on latitude and longitude. + +The route_guide.proto file defines a gRPC service with 4 types of RPC methods, showcasing different communication patterns between client and server. +1. Unary RPC + - rpc GetFeature(Point) returns (Feature) {} +2. Server-Side Streaming RPC + - rpc ListFeatures(Rectangle) returns (stream Feature) {} +3. Client-Side Streaming RPC + - rpc RecordRoute(stream Point) returns (RouteSummary) {} +4. Bidirectional Streaming RPC + - rpc RouteChat(stream RouteNote) returns (stream RouteNote) {} + +These RPC methods illustrate the versatility of gRPC in handling various communication patterns, +from simple request-response interactions to complex bidirectional streaming scenarios. + +For more details, refer to the full route_guide.proto file on GitHub: https://github.com/grpc/grpc-java/blob/master/examples/src/main/proto/route_guide.proto + +Refer the gRPC documentation for more details on creation, build and execution of route guide example with explanation +https://grpc.io/docs/languages/java/basics/ \ No newline at end of file diff --git a/examples/src/main/java/io/grpc/examples/routeguide/RouteGuideServer.java b/examples/src/main/java/io/grpc/examples/routeguide/RouteGuideServer.java index b39b06a6f92..1a3ecb0f882 100644 --- a/examples/src/main/java/io/grpc/examples/routeguide/RouteGuideServer.java +++ b/examples/src/main/java/io/grpc/examples/routeguide/RouteGuideServer.java @@ -251,7 +251,7 @@ public void onCompleted() { * Get the notes list for the given location. If missing, create it. */ private List getOrCreateNotes(Point location) { - List notes = Collections.synchronizedList(new ArrayList()); + List notes = Collections.synchronizedList(new ArrayList<>()); List prevNotes = routeNotes.putIfAbsent(location, notes); return prevNotes != null ? prevNotes : notes; } diff --git a/examples/src/main/java/io/grpc/examples/waitforready/README.md b/examples/src/main/java/io/grpc/examples/waitforready/README.md new file mode 100644 index 00000000000..1e294b453b6 --- /dev/null +++ b/examples/src/main/java/io/grpc/examples/waitforready/README.md @@ -0,0 +1,29 @@ +gRPC Wait-For-Ready Example +===================== + +This example gives the usage and implementation of the Wait-For-Ready feature. + +This feature can be activated on a client stub, ensuring that Remote Procedure Calls (RPCs) are held until the server is ready to receive them. +By waiting for the server to become available before sending requests, this mechanism enhances reliability, +particularly in situations where server availability may be delayed or unpredictable. + +When an RPC is initiated and the channel fails to connect to the server, its behavior depends on the Wait-for-Ready option: + +- Without Wait-for-Ready (Default Behavior): + + - The RPC will immediately fail if the channel cannot establish a connection, providing prompt feedback about the connectivity issue. + +- With Wait-for-Ready: + + - The RPC will not fail immediately. Instead, it will be queued and will wait until the connection is successfully established. + This approach is beneficial for handling temporary network disruptions more gracefully, ensuring the RPC is eventually executed once the connection is ready. + + +Example gives the Simple client that requests a greeting from the HelloWorldServer and defines waitForReady on the stub. + +To test this flow need to follow below steps: +- run this client without a server running(client rpc should hang) +- start the server (client rpc should complete) +- run this client again (client rpc should complete nearly immediately) + +Refer the gRPC documentation for more on Wait-For-Ready https://grpc.io/docs/guides/wait-for-ready/ \ No newline at end of file diff --git a/examples/src/main/proto/hello_streaming.proto b/examples/src/main/proto/hello_streaming.proto index 325b9093b0c..b4f0f5287dd 100644 --- a/examples/src/main/proto/hello_streaming.proto +++ b/examples/src/main/proto/hello_streaming.proto @@ -29,6 +29,7 @@ service StreamingGreeter { // The request message containing the user's name. message HelloRequest { string name = 1; + bytes padding = 2; } // The response message containing the greetings diff --git a/gae-interop-testing/gae-jdk8/build.gradle b/gae-interop-testing/gae-jdk8/build.gradle index f3ff765ddfb..07033f403de 100644 --- a/gae-interop-testing/gae-jdk8/build.gradle +++ b/gae-interop-testing/gae-jdk8/build.gradle @@ -14,10 +14,6 @@ buildscript { // Configuration for building - repositories { - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" } - } dependencies { classpath 'com.squareup.okhttp:okhttp:2.7.4' } @@ -28,20 +24,11 @@ plugins { id "war" id "ru.vyarus.animalsniffer" - id 'com.google.cloud.tools.appengine' version '2.3.0' + id 'com.google.cloud.tools.appengine' } description = 'gRPC: gae interop testing (jdk8)' -repositories { - // repositories for Jar's you access in your code - mavenLocal() - maven { // The google mirror is less flaky than mavenCentral() - url "https://maven-central.storage-download.googleapis.com/maven2/" } -} - -apply plugin: 'com.google.cloud.tools.appengine' // App Engine tasks - dependencies { providedCompile group: 'javax.servlet', name: 'servlet-api', version:'2.5' runtimeOnly 'com.google.appengine:appengine-api-1.0-sdk:1.9.59' @@ -55,7 +42,11 @@ dependencies { implementation libraries.junit implementation libraries.protobuf.java runtimeOnly libraries.netty.tcnative, libraries.netty.tcnative.classes - signature libraries.signature.java + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } } tasks.named("compileJava").configure { @@ -67,6 +58,7 @@ def createDefaultVersion() { return new java.text.SimpleDateFormat("yyyyMMdd't'HHmmss").format(new Date()) } +def nonShadowedProject = project // [START model] appengine { // App Engine tasks configuration @@ -76,13 +68,13 @@ appengine { deploy { // deploy configuration - projectId = 'GCLOUD_CONFIG' + projectId = nonShadowedProject.findProperty('gaeProjectId') ?: 'GCLOUD_CONFIG' // default - stop the current version - stopPreviousVersion = System.getProperty('gaeStopPreviousVersion') ?: true + stopPreviousVersion = nonShadowedProject.findProperty('gaeStopPreviousVersion') ?: true // default - do not make this the promoted version - promote = System.getProperty('gaePromote') ?: false - // Use -DgaeDeployVersion if set, otherwise the version is null and the plugin will generate it - version = System.getProperty('gaeDeployVersion', createDefaultVersion()) + promote = nonShadowedProject.findProperty('gaePromote') ?: false + // Use -PgaeDeployVersion if set, otherwise the version is null and the plugin will generate it + version = nonShadowedProject.findProperty('gaeDeployVersion') ?: createDefaultVersion() } } // [END model] @@ -92,6 +84,10 @@ version = '1.0-SNAPSHOT' // Version in generated output /** Returns the service name. */ String getGaeProject() { + def configuredProjectId = appengine.deploy.projectId + if (!"GCLOUD_CONFIG".equals(configuredProjectId)) { + return configuredProjectId + } def stream = new ByteArrayOutputStream() exec { executable 'gcloud' @@ -119,11 +115,8 @@ String getAppUrl(String project, String service, String version) { } tasks.register("runInteropTestRemote") { - dependsOn appengineDeploy + mustRunAfter appengineDeploy doLast { - // give remote app some time to settle down - sleep(20000) - def appUrl = getAppUrl( getGaeProject(), getService(project.getProjectDir().toPath()), diff --git a/gae-interop-testing/gae-jdk8/src/main/webapp/WEB-INF/appengine-web.xml b/gae-interop-testing/gae-jdk8/src/main/webapp/WEB-INF/appengine-web.xml index 2fcbe5d8221..715906ada47 100644 --- a/gae-interop-testing/gae-jdk8/src/main/webapp/WEB-INF/appengine-web.xml +++ b/gae-interop-testing/gae-jdk8/src/main/webapp/WEB-INF/appengine-web.xml @@ -14,6 +14,6 @@ java-gae-interop-test - java11 + java17 diff --git a/gcp-csm-observability/build.gradle b/gcp-csm-observability/build.gradle new file mode 100644 index 00000000000..bda54ca8146 --- /dev/null +++ b/gcp-csm-observability/build.gradle @@ -0,0 +1,36 @@ +plugins { + id "java-library" + id "maven-publish" + + id "ru.vyarus.animalsniffer" +} + +description = "gRPC: GCP CSM Observability" + +tasks.named("jar").configure { + manifest { + attributes('Automatic-Module-Name': 'io.grpc.gcp.csm.observability') + } +} + +dependencies { + implementation project(':grpc-api'), + project(':grpc-core'), + project(':grpc-opentelemetry'), + project(':grpc-protobuf'), + project(path: ':grpc-xds', configuration: 'shadow'), + libraries.guava.jre, // jre version pulled in via xds + libraries.protobuf.java, + libraries.opentelemetry.gcp.resources, + libraries.opentelemetry.sdk.extension.autoconfigure // opentelemetry.gcp.resources uses compileOnly for this dep + testImplementation project(":grpc-testing"), + project(":grpc-inprocess"), + libraries.opentelemetry.sdk.testing, + libraries.assertj.core // opentelemetry.sdk.testing uses compileOnly for this dep + + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } +} diff --git a/gcp-csm-observability/src/main/java/io/grpc/gcp/csm/observability/CsmObservability.java b/gcp-csm-observability/src/main/java/io/grpc/gcp/csm/observability/CsmObservability.java new file mode 100644 index 00000000000..c345fb35d0a --- /dev/null +++ b/gcp-csm-observability/src/main/java/io/grpc/gcp/csm/observability/CsmObservability.java @@ -0,0 +1,160 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.gcp.csm.observability; + +import com.google.common.annotations.VisibleForTesting; +import io.grpc.ExperimentalApi; +import io.grpc.InternalConfigurator; +import io.grpc.InternalConfiguratorRegistry; +import io.grpc.ManagedChannelBuilder; +import io.grpc.ServerBuilder; +import io.grpc.opentelemetry.GrpcOpenTelemetry; +import io.grpc.opentelemetry.InternalGrpcOpenTelemetry; +import io.opentelemetry.api.OpenTelemetry; +import java.io.Closeable; +import java.util.Collection; +import java.util.Collections; + +/** + * The entrypoint for GCP's CSM OpenTelemetry metrics functionality in gRPC. + * + *

CsmObservability uses {@link io.opentelemetry.api.OpenTelemetry} APIs for instrumentation. + * When no SDK is explicitly added no telemetry data will be collected. See + * {@code io.opentelemetry.sdk.OpenTelemetrySdk} for information on how to construct the SDK. + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/11249") +public final class CsmObservability implements Closeable { + private final GrpcOpenTelemetry delegate; + private final MetadataExchanger exchanger; + + public static Builder newBuilder() { + return new Builder(); + } + + private CsmObservability(Builder builder) { + this.delegate = builder.delegate.build(); + this.exchanger = builder.exchanger; + } + + /** + * Registers CsmObservability globally, applying its configuration to all subsequently created + * gRPC channels and servers. + * + *

Note: Only one of CsmObservability and GrpcOpenTelemetry instance can be registered + * globally. Any subsequent call to {@code registerGlobal()} will throw an {@code + * IllegalStateException}. + */ + public void registerGlobal() { + InternalConfiguratorRegistry.setConfigurators(Collections.singletonList( + new InternalConfigurator() { + @Override + public void configureChannelBuilder(ManagedChannelBuilder channelBuilder) { + CsmObservability.this.configureChannelBuilder(channelBuilder); + } + + @Override + public void configureServerBuilder(ServerBuilder serverBuilder) { + CsmObservability.this.configureServerBuilder(serverBuilder); + } + })); + } + + @VisibleForTesting + void configureChannelBuilder(ManagedChannelBuilder builder) { + delegate.configureChannelBuilder(builder); + } + + @VisibleForTesting + void configureServerBuilder(ServerBuilder serverBuilder) { + delegate.configureServerBuilder(serverBuilder); + exchanger.configureServerBuilder(serverBuilder); + } + + @Override + public void close() {} + + /** + * Builder for configuring {@link CsmObservability}. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/11249") + public static final class Builder { + private final GrpcOpenTelemetry.Builder delegate = GrpcOpenTelemetry.newBuilder(); + private final MetadataExchanger exchanger; + + private Builder() { + this(new MetadataExchanger()); + } + + @VisibleForTesting + Builder(MetadataExchanger exchanger) { + this.exchanger = exchanger; + InternalGrpcOpenTelemetry.builderPlugin(delegate, exchanger); + } + + /** + * Sets the {@link io.opentelemetry.api.OpenTelemetry} entrypoint to use. This can be used to + * configure OpenTelemetry by returning the instance created by a + * {@code io.opentelemetry.sdk.OpenTelemetrySdkBuilder}. + */ + public Builder sdk(OpenTelemetry sdk) { + delegate.sdk(sdk); + return this; + } + + /** + * Adds optionalLabelKey to all the metrics that can provide value for the + * optionalLabelKey. + */ + public Builder addOptionalLabel(String optionalLabelKey) { + delegate.addOptionalLabel(optionalLabelKey); + return this; + } + + /** + * Enables the specified metrics for collection and export. By default, only a subset of + * metrics are enabled. + */ + public Builder enableMetrics(Collection enableMetrics) { + delegate.enableMetrics(enableMetrics); + return this; + } + + /** + * Disables the specified metrics from being collected and exported. + */ + public Builder disableMetrics(Collection disableMetrics) { + delegate.disableMetrics(disableMetrics); + return this; + } + + /** + * Disable all metrics. If set to true all metrics must be explicitly enabled. + */ + public Builder disableAllMetrics() { + delegate.disableAllMetrics(); + return this; + } + + /** + * Returns a new {@link CsmObservability} built with the configuration of this {@link + * Builder}. + */ + public CsmObservability build() { + return new CsmObservability(this); + } + } +} diff --git a/gcp-csm-observability/src/main/java/io/grpc/gcp/csm/observability/MetadataExchanger.java b/gcp-csm-observability/src/main/java/io/grpc/gcp/csm/observability/MetadataExchanger.java new file mode 100644 index 00000000000..5f05d52c7e7 --- /dev/null +++ b/gcp-csm-observability/src/main/java/io/grpc/gcp/csm/observability/MetadataExchanger.java @@ -0,0 +1,347 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.gcp.csm.observability; + +import com.google.common.base.Preconditions; +import com.google.common.io.BaseEncoding; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import io.grpc.CallOptions; +import io.grpc.ForwardingServerCall.SimpleForwardingServerCall; +import io.grpc.Metadata; +import io.grpc.ServerBuilder; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.Status; +import io.grpc.opentelemetry.InternalOpenTelemetryPlugin; +import io.grpc.protobuf.ProtoUtils; +import io.grpc.xds.ClusterImplLoadBalancerProvider; +import io.opentelemetry.api.common.AttributeKey; +import io.opentelemetry.api.common.Attributes; +import io.opentelemetry.api.common.AttributesBuilder; +import io.opentelemetry.contrib.gcp.resource.GCPResourceProvider; +import io.opentelemetry.sdk.autoconfigure.ResourceConfiguration; +import java.net.URI; +import java.util.Map; +import java.util.function.Consumer; + +/** + * OpenTelemetryPlugin implementing metadata-based workload property exchange for both client and + * server. Is responsible for determining the metadata, communicating the metadata, and adding local + * and remote details to metrics. + */ +final class MetadataExchanger implements InternalOpenTelemetryPlugin { + + private static final AttributeKey CLOUD_PLATFORM = + AttributeKey.stringKey("cloud.platform"); + private static final AttributeKey K8S_NAMESPACE_NAME = + AttributeKey.stringKey("k8s.namespace.name"); + private static final AttributeKey K8S_CLUSTER_NAME = + AttributeKey.stringKey("k8s.cluster.name"); + private static final AttributeKey CLOUD_AVAILABILITY_ZONE = + AttributeKey.stringKey("cloud.availability_zone"); + private static final AttributeKey CLOUD_REGION = + AttributeKey.stringKey("cloud.region"); + private static final AttributeKey CLOUD_ACCOUNT_ID = + AttributeKey.stringKey("cloud.account.id"); + + private static final Metadata.Key SEND_KEY = + Metadata.Key.of("x-envoy-peer-metadata", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key RECV_KEY = + Metadata.Key.of("x-envoy-peer-metadata", new BinaryToAsciiMarshaller<>( + ProtoUtils.metadataMarshaller(Struct.getDefaultInstance()))); + + private static final String EXCHANGE_TYPE = "type"; + private static final String EXCHANGE_CANONICAL_SERVICE = "canonical_service"; + private static final String EXCHANGE_PROJECT_ID = "project_id"; + private static final String EXCHANGE_LOCATION = "location"; + private static final String EXCHANGE_CLUSTER_NAME = "cluster_name"; + private static final String EXCHANGE_NAMESPACE_NAME = "namespace_name"; + private static final String EXCHANGE_WORKLOAD_NAME = "workload_name"; + private static final String TYPE_GKE = "gcp_kubernetes_engine"; + private static final String TYPE_GCE = "gcp_compute_engine"; + + private final String localMetadata; + private final Attributes localAttributes; + + public MetadataExchanger() { + this( + addOtelResourceAttributes(new GCPResourceProvider().getAttributes()), + System::getenv); + } + + MetadataExchanger(Attributes platformAttributes, Lookup env) { + String type = platformAttributes.get(CLOUD_PLATFORM); + String canonicalService = env.get("CSM_CANONICAL_SERVICE_NAME"); + Struct.Builder struct = Struct.newBuilder(); + put(struct, EXCHANGE_TYPE, type); + put(struct, EXCHANGE_CANONICAL_SERVICE, canonicalService); + if (TYPE_GKE.equals(type)) { + String location = platformAttributes.get(CLOUD_AVAILABILITY_ZONE); + if (location == null) { + location = platformAttributes.get(CLOUD_REGION); + } + put(struct, EXCHANGE_WORKLOAD_NAME, env.get("CSM_WORKLOAD_NAME")); + put(struct, EXCHANGE_NAMESPACE_NAME, platformAttributes.get(K8S_NAMESPACE_NAME)); + put(struct, EXCHANGE_CLUSTER_NAME, platformAttributes.get(K8S_CLUSTER_NAME)); + put(struct, EXCHANGE_LOCATION, location); + put(struct, EXCHANGE_PROJECT_ID, platformAttributes.get(CLOUD_ACCOUNT_ID)); + } else if (TYPE_GCE.equals(type)) { + String location = platformAttributes.get(CLOUD_AVAILABILITY_ZONE); + if (location == null) { + location = platformAttributes.get(CLOUD_REGION); + } + put(struct, EXCHANGE_WORKLOAD_NAME, env.get("CSM_WORKLOAD_NAME")); + put(struct, EXCHANGE_LOCATION, location); + put(struct, EXCHANGE_PROJECT_ID, platformAttributes.get(CLOUD_ACCOUNT_ID)); + } + localMetadata = BaseEncoding.base64().encode(struct.build().toByteArray()); + + localAttributes = Attributes.builder() + .put("csm.mesh_id", nullIsUnknown(env.get("CSM_MESH_ID"))) + .put("csm.workload_canonical_service", nullIsUnknown(canonicalService)) + .build(); + } + + private static String nullIsUnknown(String value) { + return value == null ? "unknown" : value; + } + + private static void put(Struct.Builder struct, String key, String value) { + value = nullIsUnknown(value); + struct.putFields(key, Value.newBuilder().setStringValue(value).build()); + } + + private static void put(AttributesBuilder attributes, String key, Value value) { + attributes.put(key, nullIsUnknown(fromValue(value))); + } + + private static String fromValue(Value value) { + if (value == null) { + return null; + } + if (value.getKindCase() != Value.KindCase.STRING_VALUE) { + return null; + } + return value.getStringValue(); + } + + private static Attributes addOtelResourceAttributes(Attributes platformAttributes) { + // Can't inject env variables as ResourceConfiguration requires the large ConfigProperties API + // to inject our own values and a default implementation isn't provided. So this reads directly + // from System.getenv(). + Attributes envAttributes = ResourceConfiguration + .createEnvironmentResource() + .getAttributes(); + + AttributesBuilder builder = platformAttributes.toBuilder(); + builder.putAll(envAttributes); + return builder.build(); + } + + private void addLabels(AttributesBuilder to, Struct struct) { + to.putAll(localAttributes); + Map remote = struct.getFieldsMap(); + Value typeValue = remote.get(EXCHANGE_TYPE); + String type = fromValue(typeValue); + put(to, "csm.remote_workload_type", typeValue); + put(to, "csm.remote_workload_canonical_service", remote.get(EXCHANGE_CANONICAL_SERVICE)); + if (TYPE_GKE.equals(type)) { + put(to, "csm.remote_workload_project_id", remote.get(EXCHANGE_PROJECT_ID)); + put(to, "csm.remote_workload_location", remote.get(EXCHANGE_LOCATION)); + put(to, "csm.remote_workload_cluster_name", remote.get(EXCHANGE_CLUSTER_NAME)); + put(to, "csm.remote_workload_namespace_name", remote.get(EXCHANGE_NAMESPACE_NAME)); + put(to, "csm.remote_workload_name", remote.get(EXCHANGE_WORKLOAD_NAME)); + } else if (TYPE_GCE.equals(type)) { + put(to, "csm.remote_workload_project_id", remote.get(EXCHANGE_PROJECT_ID)); + put(to, "csm.remote_workload_location", remote.get(EXCHANGE_LOCATION)); + put(to, "csm.remote_workload_name", remote.get(EXCHANGE_WORKLOAD_NAME)); + } + } + + @Override + public boolean enablePluginForChannel(String target) { + URI uri; + try { + uri = new URI(target); + } catch (Exception ex) { + return false; + } + String authority = uri.getAuthority(); + return "xds".equals(uri.getScheme()) + && (authority == null || "traffic-director-global.xds.googleapis.com".equals(authority)); + } + + @Override + public ClientCallPlugin newClientCallPlugin() { + return new ClientCallState(); + } + + public void configureServerBuilder(ServerBuilder serverBuilder) { + serverBuilder.intercept(new ServerCallInterceptor()); + } + + @Override + public ServerStreamPlugin newServerStreamPlugin(Metadata inboundMetadata) { + return new ServerStreamState(inboundMetadata.get(RECV_KEY)); + } + + final class ClientCallState implements ClientCallPlugin { + private volatile Value serviceName; + private volatile Value serviceNamespace; + + @Override + public ClientStreamPlugin newClientStreamPlugin() { + return new ClientStreamState(); + } + + @Override + public CallOptions filterCallOptions(CallOptions options) { + Consumer> existingConsumer = + options.getOption(ClusterImplLoadBalancerProvider.FILTER_METADATA_CONSUMER); + return options.withOption( + ClusterImplLoadBalancerProvider.FILTER_METADATA_CONSUMER, + (Map clusterMetadata) -> { + metadataConsumer(clusterMetadata); + existingConsumer.accept(clusterMetadata); + }); + } + + private void metadataConsumer(Map clusterMetadata) { + Struct struct = clusterMetadata.get("com.google.csm.telemetry_labels"); + if (struct == null) { + struct = Struct.getDefaultInstance(); + } + serviceName = struct.getFieldsMap().get("service_name"); + serviceNamespace = struct.getFieldsMap().get("service_namespace"); + } + + @Override + public void addMetadata(Metadata toMetadata) { + toMetadata.put(SEND_KEY, localMetadata); + } + + class ClientStreamState implements ClientStreamPlugin { + private Struct receivedExchange; + + @Override + public void inboundHeaders(Metadata headers) { + setExchange(headers); + } + + @Override + public void inboundTrailers(Metadata trailers) { + if (receivedExchange != null) { + return; // Received headers + } + setExchange(trailers); + } + + private void setExchange(Metadata metadata) { + Struct received = metadata.get(RECV_KEY); + if (received == null) { + receivedExchange = Struct.getDefaultInstance(); + } else { + receivedExchange = received; + } + } + + @Override + public void addLabels(AttributesBuilder to) { + put(to, "csm.service_name", serviceName); + put(to, "csm.service_namespace_name", serviceNamespace); + Struct exchange = receivedExchange; + if (exchange == null) { + exchange = Struct.getDefaultInstance(); + } + MetadataExchanger.this.addLabels(to, exchange); + } + } + } + + final class ServerCallInterceptor implements ServerInterceptor { + @Override + public ServerCall.Listener interceptCall( + ServerCall call, Metadata headers, ServerCallHandler next) { + if (!headers.containsKey(RECV_KEY)) { + return next.startCall(call, headers); + } else { + return next.startCall(new SimpleForwardingServerCall(call) { + private boolean headersSent; + + @Override + public void sendHeaders(Metadata headers) { + headersSent = true; + headers.put(SEND_KEY, localMetadata); + super.sendHeaders(headers); + } + + @Override + public void close(Status status, Metadata trailers) { + if (!headersSent) { + trailers.put(SEND_KEY, localMetadata); + } + super.close(status, trailers); + } + }, headers); + } + } + } + + final class ServerStreamState implements ServerStreamPlugin { + private final Struct receivedExchange; + + ServerStreamState(Struct exchange) { + if (exchange == null) { + exchange = Struct.getDefaultInstance(); + } + receivedExchange = exchange; + } + + @Override + public void addLabels(AttributesBuilder to) { + MetadataExchanger.this.addLabels(to, receivedExchange); + } + } + + interface Lookup { + String get(String name); + } + + interface Supplier { + T get() throws Exception; + } + + static final class BinaryToAsciiMarshaller implements Metadata.AsciiMarshaller { + private final Metadata.BinaryMarshaller delegate; + + public BinaryToAsciiMarshaller(Metadata.BinaryMarshaller delegate) { + this.delegate = Preconditions.checkNotNull(delegate, "delegate"); + } + + @Override + public T parseAsciiString(String serialized) { + return delegate.parseBytes(BaseEncoding.base64().decode(serialized)); + } + + @Override + public String toAsciiString(T value) { + return BaseEncoding.base64().encode(delegate.toBytes(value)); + } + } +} diff --git a/gcp-csm-observability/src/test/java/io/grpc/gcp/csm/observability/CsmObservabilityTest.java b/gcp-csm-observability/src/test/java/io/grpc/gcp/csm/observability/CsmObservabilityTest.java new file mode 100644 index 00000000000..aba2c43c44f --- /dev/null +++ b/gcp-csm-observability/src/test/java/io/grpc/gcp/csm/observability/CsmObservabilityTest.java @@ -0,0 +1,607 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.gcp.csm.observability; + +import static com.google.common.truth.Truth.assertThat; +import static io.opentelemetry.api.common.AttributeKey.stringKey; +import static org.junit.Assert.assertThrows; + +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.ManagedChannelBuilder; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.NameResolverProvider; +import io.grpc.NameResolverRegistry; +import io.grpc.ServerBuilder; +import io.grpc.ServerCall; +import io.grpc.ServerServiceDefinition; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.inprocess.InProcessSocketAddress; +import io.grpc.internal.testing.FakeNameResolverProvider; +import io.grpc.stub.ClientCalls; +import io.grpc.testing.GrpcCleanupRule; +import io.grpc.testing.TestMethodDescriptors; +import io.grpc.xds.ClusterImplLoadBalancerProvider; +import io.opentelemetry.api.common.Attributes; +import io.opentelemetry.sdk.testing.assertj.OpenTelemetryAssertions; +import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; +import org.junit.After; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link CsmObservability}. */ +@RunWith(JUnit4.class) +public final class CsmObservabilityTest { + @Rule + public final OpenTelemetryRule openTelemetryTesting = OpenTelemetryRule.create(); + @Rule + public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule(); + + private NameResolverProvider fakeNameResolverProvider; + private InProcessSocketAddress socketAddress = new InProcessSocketAddress("csm-test-server"); + private ServerBuilder serverBuilder = InProcessServerBuilder.forAddress(socketAddress) + .addService(voidService(Status.OK)) + .directExecutor(); + + @After + public void tearDown() { + if (fakeNameResolverProvider != null) { + NameResolverRegistry.getDefaultRegistry().deregister(fakeNameResolverProvider); + } + } + + @Test + public void unknownDataExchange() throws Exception { + MetadataExchanger clientExchanger = new MetadataExchanger( + Attributes.builder().build(), + ImmutableMap.of()::get); + CsmObservability.Builder clientCsmBuilder = new CsmObservability.Builder(clientExchanger) + .sdk(openTelemetryTesting.getOpenTelemetry()); + MetadataExchanger serverExchanger = new MetadataExchanger( + Attributes.builder().build(), + ImmutableMap.of()::get); + CsmObservability.Builder serverCsmBuilder = new CsmObservability.Builder(serverExchanger) + .sdk(openTelemetryTesting.getOpenTelemetry()); + + String target = "xds:///csm-test"; + register(new FakeNameResolverProvider(target, socketAddress)); + serverCsmBuilder.build().configureServerBuilder(serverBuilder); + grpcCleanupRule.register(serverBuilder.build().start()); + + ManagedChannelBuilder channelBuilder = InProcessChannelBuilder.forTarget(target) + .directExecutor(); + clientCsmBuilder.build().configureChannelBuilder(channelBuilder); + Channel channel = grpcCleanupRule.register(channelBuilder.build()); + + ClientCalls.blockingUnaryCall( + channel, TestMethodDescriptors.voidMethod(), CallOptions.DEFAULT, null); + Attributes preexistingClientAttributes = Attributes.builder() + .put(stringKey("grpc.method"), "other") + .put(stringKey("grpc.target"), target) + .build(); + Attributes preexistingClientEndAttributes = preexistingClientAttributes.toBuilder() + .put(stringKey("grpc.status"), "OK") + .build(); + Attributes newClientAttributes = preexistingClientEndAttributes.toBuilder() + .put(stringKey("csm.remote_workload_canonical_service"), "unknown") + .put(stringKey("csm.remote_workload_type"), "unknown") + .put(stringKey("csm.service_name"), "unknown") + .put(stringKey("csm.service_namespace_name"), "unknown") + .put(stringKey("csm.workload_canonical_service"), "unknown") + .put(stringKey("csm.mesh_id"), "unknown") + .build(); + Attributes preexistingServerAttributes = Attributes.builder() + .put(stringKey("grpc.method"), "other") + .build(); + Attributes preexistingServerEndAttributes = preexistingServerAttributes.toBuilder() + .put(stringKey("grpc.status"), "OK") + .build(); + Attributes newServerAttributes = preexistingServerEndAttributes.toBuilder() + .put(stringKey("csm.remote_workload_canonical_service"), "unknown") + .put(stringKey("csm.remote_workload_type"), "unknown") + .put(stringKey("csm.workload_canonical_service"), "unknown") + .put(stringKey("csm.mesh_id"), "unknown") + .build(); + assertMetrics( + preexistingClientAttributes, + preexistingClientEndAttributes, + newClientAttributes, + preexistingServerAttributes, + newServerAttributes); + } + + @Test + public void nonCsmServer() throws Exception { + MetadataExchanger clientExchanger = new MetadataExchanger( + Attributes.builder().build(), + ImmutableMap.of()::get); + CsmObservability.Builder clientCsmBuilder = new CsmObservability.Builder(clientExchanger) + .sdk(openTelemetryTesting.getOpenTelemetry()); + + String target = "xds:///csm-test"; + register(new FakeNameResolverProvider(target, socketAddress)); + grpcCleanupRule.register(serverBuilder.build().start()); + + ManagedChannelBuilder channelBuilder = InProcessChannelBuilder.forTarget(target) + .directExecutor(); + clientCsmBuilder.build().configureChannelBuilder(channelBuilder); + Channel channel = grpcCleanupRule.register(channelBuilder.build()); + + ClientCalls.blockingUnaryCall( + channel, TestMethodDescriptors.voidMethod(), CallOptions.DEFAULT, null); + Attributes preexistingClientAttributes = Attributes.builder() + .put(stringKey("grpc.method"), "other") + .put(stringKey("grpc.target"), target) + .build(); + Attributes preexistingClientEndAttributes = preexistingClientAttributes.toBuilder() + .put(stringKey("grpc.status"), "OK") + .build(); + Attributes newClientAttributes = preexistingClientEndAttributes.toBuilder() + .put(stringKey("csm.remote_workload_canonical_service"), "unknown") + .put(stringKey("csm.remote_workload_type"), "unknown") + .put(stringKey("csm.service_name"), "unknown") + .put(stringKey("csm.service_namespace_name"), "unknown") + .put(stringKey("csm.workload_canonical_service"), "unknown") + .put(stringKey("csm.mesh_id"), "unknown") + .build(); + OpenTelemetryAssertions.assertThat(openTelemetryTesting.getMetrics()) + .satisfiesExactlyInAnyOrder( + metric -> OpenTelemetryAssertions.assertThat(metric) + .hasName("grpc.client.attempt.started") + .hasLongSumSatisfying( + longSum -> longSum.hasPointsSatisfying( + point -> point.hasAttributes(preexistingClientAttributes))), + metric -> OpenTelemetryAssertions.assertThat(metric) + .hasName("grpc.client.attempt.duration") + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(newClientAttributes))), + metric -> OpenTelemetryAssertions.assertThat(metric) + .hasName("grpc.client.attempt.sent_total_compressed_message_size") + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(newClientAttributes))), + metric -> OpenTelemetryAssertions.assertThat(metric) + .hasName("grpc.client.attempt.rcvd_total_compressed_message_size") + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(newClientAttributes))), + metric -> OpenTelemetryAssertions.assertThat(metric) + .hasName("grpc.client.call.duration") + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(preexistingClientEndAttributes)))); + } + + @Test + public void nonCsmClient() throws Exception { + MetadataExchanger clientExchanger = new MetadataExchanger( + Attributes.builder() + .put(stringKey("cloud.platform"), "gcp_kubernetes_engine") + .build(), + ImmutableMap.of()::get); + CsmObservability.Builder clientCsmBuilder = new CsmObservability.Builder(clientExchanger) + .sdk(openTelemetryTesting.getOpenTelemetry()); + MetadataExchanger serverExchanger = new MetadataExchanger( + Attributes.builder().build(), + ImmutableMap.of()::get); + CsmObservability.Builder serverCsmBuilder = new CsmObservability.Builder(serverExchanger) + .sdk(openTelemetryTesting.getOpenTelemetry()); + + String target = "xds://not-a-csm-authority/csm-test"; + register(new FakeNameResolverProvider(target, socketAddress)); + serverCsmBuilder.build().configureServerBuilder(serverBuilder); + grpcCleanupRule.register(serverBuilder.build().start()); + + ManagedChannelBuilder channelBuilder = InProcessChannelBuilder.forTarget(target) + .directExecutor(); + clientCsmBuilder.build().configureChannelBuilder(channelBuilder); + Channel channel = grpcCleanupRule.register(channelBuilder.build()); + + ClientCalls.blockingUnaryCall( + channel, TestMethodDescriptors.voidMethod(), CallOptions.DEFAULT, null); + Attributes preexistingClientAttributes = Attributes.builder() + .put(stringKey("grpc.method"), "other") + .put(stringKey("grpc.target"), target) + .build(); + Attributes preexistingClientEndAttributes = preexistingClientAttributes.toBuilder() + .put(stringKey("grpc.status"), "OK") + .build(); + Attributes preexistingServerAttributes = Attributes.builder() + .put(stringKey("grpc.method"), "other") + .build(); + Attributes preexistingServerEndAttributes = preexistingServerAttributes.toBuilder() + .put(stringKey("grpc.status"), "OK") + .build(); + Attributes newServerAttributes = preexistingServerEndAttributes.toBuilder() + .put(stringKey("csm.remote_workload_canonical_service"), "unknown") + .put(stringKey("csm.remote_workload_type"), "unknown") + .put(stringKey("csm.workload_canonical_service"), "unknown") + .put(stringKey("csm.mesh_id"), "unknown") + .build(); + assertMetrics( + preexistingClientAttributes, + preexistingClientEndAttributes, + preexistingClientEndAttributes, + preexistingServerAttributes, + newServerAttributes); + } + + @Test + public void k8sExchange() throws Exception { + MetadataExchanger clientExchanger = new MetadataExchanger( + Attributes.builder() + .put(stringKey("cloud.platform"), "gcp_kubernetes_engine") + .put(stringKey("k8s.namespace.name"), "namespace-aeiou") + .put(stringKey("k8s.cluster.name"), "mycluster1") + .put(stringKey("cloud.region"), "us-central1") + .put(stringKey("cloud.account.id"), "31415926") + .build(), + ImmutableMap.of( + "CSM_CANONICAL_SERVICE_NAME", "canon-service-is-a-client", + "CSM_WORKLOAD_NAME", "best-client", + "CSM_MESH_ID", "mymesh")::get); + CsmObservability.Builder clientCsmBuilder = new CsmObservability.Builder(clientExchanger) + .sdk(openTelemetryTesting.getOpenTelemetry()); + MetadataExchanger serverExchanger = new MetadataExchanger( + Attributes.builder() + .put(stringKey("cloud.platform"), "gcp_kubernetes_engine") + .put(stringKey("k8s.namespace.name"), "namespace-1e43c") + .put(stringKey("k8s.cluster.name"), "mycluster2") + .put(stringKey("cloud.availability_zone"), "us-east2-c") + .put(stringKey("cloud.region"), "us-east2") + .put(stringKey("cloud.account.id"), "11235813") + .build(), + ImmutableMap.of( + "CSM_CANONICAL_SERVICE_NAME", "server-has-a-single-name", + "CSM_WORKLOAD_NAME", "fast-server", + "CSM_MESH_ID", "meshhh")::get); + CsmObservability.Builder serverCsmBuilder = new CsmObservability.Builder(serverExchanger) + .sdk(openTelemetryTesting.getOpenTelemetry()); + + String target = "xds:///csm-test-k8s"; + register(new FakeNameResolverProvider(target, socketAddress)); + serverCsmBuilder.build().configureServerBuilder(serverBuilder); + grpcCleanupRule.register(serverBuilder.build().start()); + + ManagedChannelBuilder channelBuilder = InProcessChannelBuilder.forTarget(target) + .directExecutor() + .intercept(new ProvideFilterMetadataInterceptor( + ImmutableMap.of("com.google.csm.telemetry_labels", Struct.newBuilder() + .putFields("service_name", + Value.newBuilder().setStringValue("second-server-name").build()) + .putFields("service_namespace", + Value.newBuilder().setStringValue("namespace-0001").build()) + .build()))); + clientCsmBuilder.build().configureChannelBuilder(channelBuilder); + Channel channel = grpcCleanupRule.register(channelBuilder.build()); + + ClientCalls.blockingUnaryCall( + channel, TestMethodDescriptors.voidMethod(), CallOptions.DEFAULT, null); + Attributes preexistingClientAttributes = Attributes.builder() + .put(stringKey("grpc.method"), "other") + .put(stringKey("grpc.target"), target) + .build(); + Attributes preexistingClientEndAttributes = preexistingClientAttributes.toBuilder() + .put(stringKey("grpc.status"), "OK") + .build(); + Attributes newClientAttributes = preexistingClientEndAttributes.toBuilder() + .put(stringKey("csm.remote_workload_canonical_service"), "server-has-a-single-name") + .put(stringKey("csm.remote_workload_type"), "gcp_kubernetes_engine") + .put(stringKey("csm.remote_workload_project_id"), "11235813") + .put(stringKey("csm.remote_workload_location"), "us-east2-c") + .put(stringKey("csm.remote_workload_cluster_name"), "mycluster2") + .put(stringKey("csm.remote_workload_namespace_name"), "namespace-1e43c") + .put(stringKey("csm.remote_workload_name"), "fast-server") + .put(stringKey("csm.service_name"), "second-server-name") + .put(stringKey("csm.service_namespace_name"), "namespace-0001") + .put(stringKey("csm.workload_canonical_service"), "canon-service-is-a-client") + .put(stringKey("csm.mesh_id"), "mymesh") + .build(); + Attributes preexistingServerAttributes = Attributes.builder() + .put(stringKey("grpc.method"), "other") + .build(); + Attributes preexistingServerEndAttributes = preexistingServerAttributes.toBuilder() + .put(stringKey("grpc.status"), "OK") + .build(); + Attributes newServerAttributes = preexistingServerEndAttributes.toBuilder() + .put(stringKey("csm.remote_workload_canonical_service"), "canon-service-is-a-client") + .put(stringKey("csm.remote_workload_type"), "gcp_kubernetes_engine") + .put(stringKey("csm.remote_workload_project_id"), "31415926") + .put(stringKey("csm.remote_workload_location"), "us-central1") + .put(stringKey("csm.remote_workload_cluster_name"), "mycluster1") + .put(stringKey("csm.remote_workload_namespace_name"), "namespace-aeiou") + .put(stringKey("csm.remote_workload_name"), "best-client") + .put(stringKey("csm.workload_canonical_service"), "server-has-a-single-name") + .put(stringKey("csm.mesh_id"), "meshhh") + .build(); + assertMetrics( + preexistingClientAttributes, + preexistingClientEndAttributes, + newClientAttributes, + preexistingServerAttributes, + newServerAttributes); + } + + @Test + public void gceExchange() throws Exception { + MetadataExchanger clientExchanger = new MetadataExchanger( + Attributes.builder() + .put(stringKey("cloud.platform"), "gcp_compute_engine") + .put(stringKey("cloud.region"), "us-central1") + .put(stringKey("cloud.account.id"), "31415926") + .build(), + ImmutableMap.of( + "CSM_CANONICAL_SERVICE_NAME", "canon-service-is-a-client", + "CSM_WORKLOAD_NAME", "best-client", + "CSM_MESH_ID", "mymesh")::get); + CsmObservability.Builder clientCsmBuilder = new CsmObservability.Builder(clientExchanger) + .sdk(openTelemetryTesting.getOpenTelemetry()); + MetadataExchanger serverExchanger = new MetadataExchanger( + Attributes.builder() + .put(stringKey("cloud.platform"), "gcp_compute_engine") + .put(stringKey("cloud.availability_zone"), "us-east2-c") + .put(stringKey("cloud.region"), "us-east2") + .put(stringKey("cloud.account.id"), "11235813") + .build(), + ImmutableMap.of( + "CSM_CANONICAL_SERVICE_NAME", "server-has-a-single-name", + "CSM_WORKLOAD_NAME", "fast-server", + "CSM_MESH_ID", "meshhh")::get); + CsmObservability.Builder serverCsmBuilder = new CsmObservability.Builder(serverExchanger) + .sdk(openTelemetryTesting.getOpenTelemetry()); + + String target = "xds:///csm-test"; + register(new FakeNameResolverProvider(target, socketAddress)); + serverCsmBuilder.build().configureServerBuilder(serverBuilder); + grpcCleanupRule.register(serverBuilder.build().start()); + + ManagedChannelBuilder channelBuilder = InProcessChannelBuilder.forTarget(target) + .directExecutor() + .intercept(new ProvideFilterMetadataInterceptor(ImmutableMap.of())); + clientCsmBuilder.build().configureChannelBuilder(channelBuilder); + Channel channel = grpcCleanupRule.register(channelBuilder.build()); + + ClientCalls.blockingUnaryCall( + channel, TestMethodDescriptors.voidMethod(), CallOptions.DEFAULT, null); + Attributes preexistingClientAttributes = Attributes.builder() + .put(stringKey("grpc.method"), "other") + .put(stringKey("grpc.target"), target) + .build(); + Attributes preexistingClientEndAttributes = preexistingClientAttributes.toBuilder() + .put(stringKey("grpc.status"), "OK") + .build(); + Attributes newClientAttributes = preexistingClientEndAttributes.toBuilder() + .put(stringKey("csm.remote_workload_canonical_service"), "server-has-a-single-name") + .put(stringKey("csm.remote_workload_type"), "gcp_compute_engine") + .put(stringKey("csm.remote_workload_project_id"), "11235813") + .put(stringKey("csm.remote_workload_location"), "us-east2-c") + .put(stringKey("csm.remote_workload_name"), "fast-server") + .put(stringKey("csm.service_name"), "unknown") + .put(stringKey("csm.service_namespace_name"), "unknown") + .put(stringKey("csm.workload_canonical_service"), "canon-service-is-a-client") + .put(stringKey("csm.mesh_id"), "mymesh") + .build(); + Attributes preexistingServerAttributes = Attributes.builder() + .put(stringKey("grpc.method"), "other") + .build(); + Attributes preexistingServerEndAttributes = preexistingServerAttributes.toBuilder() + .put(stringKey("grpc.status"), "OK") + .build(); + Attributes newServerAttributes = preexistingServerEndAttributes.toBuilder() + .put(stringKey("csm.remote_workload_canonical_service"), "canon-service-is-a-client") + .put(stringKey("csm.remote_workload_type"), "gcp_compute_engine") + .put(stringKey("csm.remote_workload_project_id"), "31415926") + .put(stringKey("csm.remote_workload_location"), "us-central1") + .put(stringKey("csm.remote_workload_name"), "best-client") + .put(stringKey("csm.workload_canonical_service"), "server-has-a-single-name") + .put(stringKey("csm.mesh_id"), "meshhh") + .build(); + assertMetrics( + preexistingClientAttributes, + preexistingClientEndAttributes, + newClientAttributes, + preexistingServerAttributes, + newServerAttributes); + } + + @Test + public void trailersOnly() throws Exception { + MetadataExchanger clientExchanger = new MetadataExchanger( + Attributes.builder() + .put(stringKey("cloud.platform"), "gcp_compute_engine") + .put(stringKey("cloud.region"), "us-central1") + .put(stringKey("cloud.account.id"), "31415926") + .build(), + ImmutableMap.of( + "CSM_CANONICAL_SERVICE_NAME", "canon-service-is-a-client", + "CSM_WORKLOAD_NAME", "best-client", + "CSM_MESH_ID", "mymesh")::get); + CsmObservability.Builder clientCsmBuilder = new CsmObservability.Builder(clientExchanger) + .sdk(openTelemetryTesting.getOpenTelemetry()); + + MetadataExchanger serverExchanger = new MetadataExchanger( + Attributes.builder() + .put(stringKey("cloud.platform"), "gcp_compute_engine") + .put(stringKey("cloud.availability_zone"), "us-east2-c") + .put(stringKey("cloud.region"), "us-east2") + .put(stringKey("cloud.account.id"), "11235813") + .build(), + ImmutableMap.of( + "CSM_CANONICAL_SERVICE_NAME", "server-has-a-single-name", + "CSM_WORKLOAD_NAME", "fast-server", + "CSM_MESH_ID", "meshhh")::get); + CsmObservability.Builder serverCsmBuilder = new CsmObservability.Builder(serverExchanger) + .sdk(openTelemetryTesting.getOpenTelemetry()); + + String target = "xds:///csm-test"; + register(new FakeNameResolverProvider(target, socketAddress)); + // Trailers-only + serverBuilder.addService(voidService(Status.INVALID_ARGUMENT)); + serverCsmBuilder.build().configureServerBuilder(serverBuilder); + grpcCleanupRule.register(serverBuilder.build().start()); + + ManagedChannelBuilder channelBuilder = InProcessChannelBuilder.forTarget(target) + .directExecutor(); + clientCsmBuilder.build().configureChannelBuilder(channelBuilder); + Channel channel = grpcCleanupRule.register(channelBuilder.build()); + + assertThrows(StatusRuntimeException.class, () -> + ClientCalls.blockingUnaryCall( + channel, TestMethodDescriptors.voidMethod(), CallOptions.DEFAULT, null)); + Attributes preexistingClientAttributes = Attributes.builder() + .put(stringKey("grpc.method"), "other") + .put(stringKey("grpc.target"), target) + .build(); + Attributes preexistingClientEndAttributes = preexistingClientAttributes.toBuilder() + .put(stringKey("grpc.status"), "INVALID_ARGUMENT") + .build(); + Attributes newClientAttributes = preexistingClientEndAttributes.toBuilder() + .put(stringKey("csm.remote_workload_canonical_service"), "server-has-a-single-name") + .put(stringKey("csm.remote_workload_type"), "gcp_compute_engine") + .put(stringKey("csm.remote_workload_project_id"), "11235813") + .put(stringKey("csm.remote_workload_location"), "us-east2-c") + .put(stringKey("csm.remote_workload_name"), "fast-server") + .put(stringKey("csm.service_name"), "unknown") + .put(stringKey("csm.service_namespace_name"), "unknown") + .put(stringKey("csm.workload_canonical_service"), "canon-service-is-a-client") + .put(stringKey("csm.mesh_id"), "mymesh") + .build(); + Attributes preexistingServerAttributes = Attributes.builder() + .put(stringKey("grpc.method"), "other") + .build(); + Attributes preexistingServerEndAttributes = preexistingServerAttributes.toBuilder() + .put(stringKey("grpc.status"), "INVALID_ARGUMENT") + .build(); + Attributes newServerAttributes = preexistingServerEndAttributes.toBuilder() + .put(stringKey("csm.remote_workload_canonical_service"), "canon-service-is-a-client") + .put(stringKey("csm.remote_workload_type"), "gcp_compute_engine") + .put(stringKey("csm.remote_workload_project_id"), "31415926") + .put(stringKey("csm.remote_workload_location"), "us-central1") + .put(stringKey("csm.remote_workload_name"), "best-client") + .put(stringKey("csm.workload_canonical_service"), "server-has-a-single-name") + .put(stringKey("csm.mesh_id"), "meshhh") + .build(); + assertMetrics( + preexistingClientAttributes, + preexistingClientEndAttributes, + newClientAttributes, + preexistingServerAttributes, + newServerAttributes); + } + + private void register(NameResolverProvider provider) { + assertThat(fakeNameResolverProvider).isNull(); + fakeNameResolverProvider = provider; + NameResolverRegistry.getDefaultRegistry().register(provider); + } + + private static ServerServiceDefinition voidService(Status status) { + return ServerServiceDefinition.builder(TestMethodDescriptors.voidMethod().getServiceName()) + .addMethod(TestMethodDescriptors.voidMethod(), (call, headers) -> { + if (status.isOk()) { + call.sendHeaders(new Metadata()); + call.sendMessage(null); + } + call.close(status, new Metadata()); + return new ServerCall.Listener() {}; + }) + .build(); + } + + private void assertMetrics( + Attributes preexistingClientAttributes, + Attributes preexistingClientEndAttributes, + Attributes newClientAttributes, + Attributes preexistingServerAttributes, + Attributes newServerAttributes) { + OpenTelemetryAssertions.assertThat(openTelemetryTesting.getMetrics()) + .satisfiesExactlyInAnyOrder( + metric -> OpenTelemetryAssertions.assertThat(metric) + .hasName("grpc.client.attempt.started") + .hasLongSumSatisfying( + longSum -> longSum.hasPointsSatisfying( + point -> point.hasAttributes(preexistingClientAttributes))), + metric -> OpenTelemetryAssertions.assertThat(metric) + .hasName("grpc.client.attempt.duration") + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(newClientAttributes))), + metric -> OpenTelemetryAssertions.assertThat(metric) + .hasName("grpc.client.attempt.sent_total_compressed_message_size") + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(newClientAttributes))), + metric -> OpenTelemetryAssertions.assertThat(metric) + .hasName("grpc.client.attempt.rcvd_total_compressed_message_size") + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(newClientAttributes))), + metric -> OpenTelemetryAssertions.assertThat(metric) + .hasName("grpc.client.call.duration") + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(preexistingClientEndAttributes))), + metric -> OpenTelemetryAssertions.assertThat(metric) + .hasName("grpc.server.call.started") + .hasLongSumSatisfying( + longSum -> longSum.hasPointsSatisfying( + point -> point.hasAttributes(preexistingServerAttributes))), + metric -> OpenTelemetryAssertions.assertThat(metric) + .hasName("grpc.server.call.duration") + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(newServerAttributes))), + metric -> OpenTelemetryAssertions.assertThat(metric) + .hasName("grpc.server.call.sent_total_compressed_message_size") + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(newServerAttributes))), + metric -> OpenTelemetryAssertions.assertThat(metric) + .hasName("grpc.server.call.rcvd_total_compressed_message_size") + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(newServerAttributes)))); + } + + private static class ProvideFilterMetadataInterceptor implements ClientInterceptor { + private final ImmutableMap filterMetadata; + + public ProvideFilterMetadataInterceptor(ImmutableMap filterMetadata) { + this.filterMetadata = filterMetadata; + } + + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + callOptions.getOption(ClusterImplLoadBalancerProvider.FILTER_METADATA_CONSUMER) + .accept(filterMetadata); + return next.newCall(method, callOptions); + } + } +} diff --git a/gcp-csm-observability/src/test/java/io/grpc/gcp/csm/observability/MetadataExchangerTest.java b/gcp-csm-observability/src/test/java/io/grpc/gcp/csm/observability/MetadataExchangerTest.java new file mode 100644 index 00000000000..cc3472be182 --- /dev/null +++ b/gcp-csm-observability/src/test/java/io/grpc/gcp/csm/observability/MetadataExchangerTest.java @@ -0,0 +1,154 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.gcp.csm.observability; + +import static com.google.common.truth.Truth.assertThat; +import static io.opentelemetry.api.common.AttributeKey.stringKey; + +import com.google.common.collect.ImmutableMap; +import com.google.common.io.BaseEncoding; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import io.grpc.Metadata; +import io.opentelemetry.api.common.Attributes; +import io.opentelemetry.api.common.AttributesBuilder; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link MetadataExchanger}. */ +@RunWith(JUnit4.class) +public final class MetadataExchangerTest { + + @Test + public void enablePluginForChannel_matches() { + MetadataExchanger exchanger = + new MetadataExchanger(Attributes.builder().build(), (name) -> null); + assertThat(exchanger.enablePluginForChannel("xds:///testing")).isTrue(); + assertThat(exchanger.enablePluginForChannel("xds:/testing")).isTrue(); + assertThat(exchanger.enablePluginForChannel( + "xds://traffic-director-global.xds.googleapis.com/testing:123")).isTrue(); + } + + @Test + public void enablePluginForChannel_doesNotMatch() { + MetadataExchanger exchanger = + new MetadataExchanger(Attributes.builder().build(), (name) -> null); + assertThat(exchanger.enablePluginForChannel("dns:///localhost")).isFalse(); + assertThat(exchanger.enablePluginForChannel("xds:///[]")).isFalse(); + assertThat(exchanger.enablePluginForChannel("xds://my-xds-server/testing")).isFalse(); + } + + @Test + public void addLabels_receivedWrongType() { + MetadataExchanger exchanger = + new MetadataExchanger(Attributes.builder().build(), (name) -> null); + Metadata metadata = new Metadata(); + metadata.put(Metadata.Key.of("x-envoy-peer-metadata", Metadata.ASCII_STRING_MARSHALLER), + BaseEncoding.base64().encode(Struct.newBuilder() + .putFields("type", Value.newBuilder().setNumberValue(1).build()) + .build() + .toByteArray())); + AttributesBuilder builder = Attributes.builder(); + exchanger.newServerStreamPlugin(metadata).addLabels(builder); + + assertThat(builder.build()).isEqualTo(Attributes.builder() + .put(stringKey("csm.mesh_id"), "unknown") + .put(stringKey("csm.workload_canonical_service"), "unknown") + .put(stringKey("csm.remote_workload_type"), "unknown") + .put(stringKey("csm.remote_workload_canonical_service"), "unknown") + .build()); + } + + @Test + public void addLabelsFromExchange_unknownGcpType() { + MetadataExchanger exchanger = + new MetadataExchanger(Attributes.builder().build(), (name) -> null); + Metadata metadata = new Metadata(); + metadata.put(Metadata.Key.of("x-envoy-peer-metadata", Metadata.ASCII_STRING_MARSHALLER), + BaseEncoding.base64().encode(Struct.newBuilder() + .putFields("type", Value.newBuilder().setStringValue("gcp_surprise").build()) + .putFields("canonical_service", Value.newBuilder().setStringValue("myservice1").build()) + .build() + .toByteArray())); + AttributesBuilder builder = Attributes.builder(); + exchanger.newServerStreamPlugin(metadata).addLabels(builder); + + assertThat(builder.build()).isEqualTo(Attributes.builder() + .put(stringKey("csm.mesh_id"), "unknown") + .put(stringKey("csm.workload_canonical_service"), "unknown") + .put(stringKey("csm.remote_workload_type"), "gcp_surprise") + .put(stringKey("csm.remote_workload_canonical_service"), "myservice1") + .build()); + } + + @Test + public void addMetadata_k8s() throws Exception { + MetadataExchanger exchanger = new MetadataExchanger( + Attributes.builder() + .put(stringKey("cloud.platform"), "gcp_kubernetes_engine") + .put(stringKey("k8s.namespace.name"), "mynamespace1") + .put(stringKey("k8s.cluster.name"), "mycluster1") + .put(stringKey("cloud.availability_zone"), "myzone1") + .put(stringKey("cloud.account.id"), "0001") + .build(), + ImmutableMap.of( + "CSM_CANONICAL_SERVICE_NAME", "myservice1", + "CSM_WORKLOAD_NAME", "myworkload1")::get); + Metadata metadata = new Metadata(); + exchanger.newClientCallPlugin().addMetadata(metadata); + + Struct peer = Struct.parseFrom(BaseEncoding.base64().decode(metadata.get( + Metadata.Key.of("x-envoy-peer-metadata", Metadata.ASCII_STRING_MARSHALLER)))); + assertThat(peer).isEqualTo( + Struct.newBuilder() + .putFields("type", Value.newBuilder().setStringValue("gcp_kubernetes_engine").build()) + .putFields("canonical_service", Value.newBuilder().setStringValue("myservice1").build()) + .putFields("workload_name", Value.newBuilder().setStringValue("myworkload1").build()) + .putFields("namespace_name", Value.newBuilder().setStringValue("mynamespace1").build()) + .putFields("cluster_name", Value.newBuilder().setStringValue("mycluster1").build()) + .putFields("location", Value.newBuilder().setStringValue("myzone1").build()) + .putFields("project_id", Value.newBuilder().setStringValue("0001").build()) + .build()); + } + + @Test + public void addMetadata_gce() throws Exception { + MetadataExchanger exchanger = new MetadataExchanger( + Attributes.builder() + .put(stringKey("cloud.platform"), "gcp_compute_engine") + .put(stringKey("cloud.availability_zone"), "myzone1") + .put(stringKey("cloud.account.id"), "0001") + .build(), + ImmutableMap.of( + "CSM_CANONICAL_SERVICE_NAME", "myservice1", + "CSM_WORKLOAD_NAME", "myworkload1")::get); + Metadata metadata = new Metadata(); + exchanger.newClientCallPlugin().addMetadata(metadata); + + Struct peer = Struct.parseFrom(BaseEncoding.base64().decode(metadata.get( + Metadata.Key.of("x-envoy-peer-metadata", Metadata.ASCII_STRING_MARSHALLER)))); + assertThat(peer).isEqualTo( + Struct.newBuilder() + .putFields("type", Value.newBuilder().setStringValue("gcp_compute_engine").build()) + .putFields("canonical_service", Value.newBuilder().setStringValue("myservice1").build()) + .putFields("workload_name", Value.newBuilder().setStringValue("myworkload1").build()) + .putFields("location", Value.newBuilder().setStringValue("myzone1").build()) + .putFields("project_id", Value.newBuilder().setStringValue("0001").build()) + .build()); + } +} diff --git a/gcp-observability/build.gradle b/gcp-observability/build.gradle index 69bff88bf0f..1d8c7a9f961 100644 --- a/gcp-observability/build.gradle +++ b/gcp-observability/build.gradle @@ -34,6 +34,7 @@ dependencies { implementation project(':grpc-protobuf'), project(':grpc-stub'), project(':grpc-census'), + project(":grpc-context"), // Override opencensus dependency with our newer version libraries.opencensus.contrib.grpc.metrics // Avoid gradle using project dependencies without configuration: shadow implementation (libraries.google.cloud.logging) { @@ -58,14 +59,12 @@ dependencies { project(path: ':grpc-alts', configuration: 'shadow'), project(':grpc-auth'), // Align grpc versions project(':grpc-core'), // Align grpc versions - project(':grpc-grpclb'), // Align grpc versions project(':grpc-services'), // Align grpc versions libraries.animalsniffer.annotations, // Use our newer version libraries.auto.value.annotations, // Use our newer version libraries.guava.jre, // Use our newer version libraries.protobuf.java.util, // Use our newer version - libraries.re2j, // Use our newer version - libraries.j2objc.annotations // Explicit dependency to keep in step with version used by guava + libraries.re2j // Use our newer version testImplementation testFixtures(project(':grpc-api')), project(':grpc-testing'), @@ -74,7 +73,11 @@ dependencies { exclude group: 'junit', module: 'junit' } - signature libraries.signature.java + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } } configureProtoCompilation() diff --git a/gcp-observability/interop/build.gradle b/gcp-observability/interop/build.gradle index 4a78c056eac..7e17624995a 100644 --- a/gcp-observability/interop/build.gradle +++ b/gcp-observability/interop/build.gradle @@ -10,7 +10,11 @@ dependencies { implementation project(':grpc-interop-testing'), project(':grpc-gcp-observability') - signature libraries.signature.java + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } } application { diff --git a/gcp-observability/src/main/java/io/grpc/gcp/observability/GcpObservability.java b/gcp-observability/src/main/java/io/grpc/gcp/observability/GcpObservability.java index d7eaf43f94c..7fe4e3a8a3c 100644 --- a/gcp-observability/src/main/java/io/grpc/gcp/observability/GcpObservability.java +++ b/gcp-observability/src/main/java/io/grpc/gcp/observability/GcpObservability.java @@ -19,10 +19,14 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import io.grpc.ClientInterceptor; -import io.grpc.InternalGlobalInterceptors; +import io.grpc.InternalConfigurator; +import io.grpc.InternalConfiguratorRegistry; +import io.grpc.ManagedChannelBuilder; import io.grpc.ManagedChannelProvider.ProviderNotFoundException; +import io.grpc.ServerBuilder; import io.grpc.ServerInterceptor; import io.grpc.ServerStreamTracer; import io.grpc.census.InternalCensusStatsAccessor; @@ -55,7 +59,9 @@ import java.net.UnknownHostException; import java.security.SecureRandom; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; import java.util.logging.Level; @@ -121,6 +127,15 @@ static GcpObservability grpcInit( /** Un-initialize/shutdown grpc-observability. */ @Override public void close() { + closeWithSleepTime(2 * METRICS_EXPORT_INTERVAL, TimeUnit.SECONDS); + } + + /** + * Method to close along with sleep time explicitly. + * + * @param sleepTime sleepTime + */ + void closeWithSleepTime(long sleepTime, TimeUnit timeUnit) { synchronized (GcpObservability.class) { if (instance == null) { throw new IllegalStateException("GcpObservability already closed!"); @@ -129,8 +144,7 @@ public void close() { if (config.isEnableCloudMonitoring() || config.isEnableCloudTracing()) { try { // Sleeping before shutdown to ensure all metrics and traces are flushed - Thread.sleep( - TimeUnit.MILLISECONDS.convert(2 * METRICS_EXPORT_INTERVAL, TimeUnit.SECONDS)); + timeUnit.sleep(sleepTime); } catch (InterruptedException e) { Thread.currentThread().interrupt(); logger.log(Level.SEVERE, "Caught exception during sleep", e); @@ -161,8 +175,42 @@ private void setProducer( tracerFactories.add(InternalCensusTracingAccessor.getServerStreamTracerFactory()); } - InternalGlobalInterceptors.setInterceptorsTracers( - clientInterceptors, serverInterceptors, tracerFactories); + InternalConfiguratorRegistry.setConfigurators(Arrays.asList( + new ObservabilityConfigurator(clientInterceptors, serverInterceptors, tracerFactories))); + } + + @VisibleForTesting + static final class ObservabilityConfigurator implements InternalConfigurator { + final List clientInterceptors; + final List serverInterceptors; + final List tracerFactories; + + ObservabilityConfigurator( + List clientInterceptors, + List serverInterceptors, + List tracerFactories) { + this.clientInterceptors = ImmutableList.copyOf( + checkNotNull(clientInterceptors, "clientInterceptors")); + this.serverInterceptors = ImmutableList.copyOf( + checkNotNull(serverInterceptors, "serverInterceptors")); + this.tracerFactories = ImmutableList.copyOf( + checkNotNull(tracerFactories, "tracerFactories")); + } + + @Override + public void configureChannelBuilder(ManagedChannelBuilder builder) { + builder.intercept(clientInterceptors); + } + + @Override + public void configureServerBuilder(ServerBuilder builder) { + for (ServerInterceptor interceptor : serverInterceptors) { + builder.intercept(interceptor); + } + for (ServerStreamTracer.Factory factory : tracerFactories) { + builder.addStreamTracerFactory(factory); + } + } } static ConditionalClientInterceptor getConditionalInterceptor(ClientInterceptor interceptor) { diff --git a/gcp-observability/src/main/java/io/grpc/gcp/observability/ObservabilityConfigImpl.java b/gcp-observability/src/main/java/io/grpc/gcp/observability/ObservabilityConfigImpl.java index 2b0a44473d0..ae74bf10c43 100644 --- a/gcp-observability/src/main/java/io/grpc/gcp/observability/ObservabilityConfigImpl.java +++ b/gcp-observability/src/main/java/io/grpc/gcp/observability/ObservabilityConfigImpl.java @@ -19,7 +19,6 @@ import static com.google.common.base.Preconditions.checkArgument; import com.google.cloud.ServiceOptions; -import com.google.common.base.Charsets; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -28,6 +27,7 @@ import io.opencensus.trace.Sampler; import io.opencensus.trace.samplers.Samplers; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Paths; import java.util.Collections; @@ -75,7 +75,7 @@ static ObservabilityConfigImpl getInstance() throws IOException { void parseFile(String configFile) throws IOException { String configFileContent = - new String(Files.readAllBytes(Paths.get(configFile)), Charsets.UTF_8); + new String(Files.readAllBytes(Paths.get(configFile)), StandardCharsets.UTF_8); checkArgument(!configFileContent.isEmpty(), CONFIG_FILE_ENV_VAR_NAME + " is empty!"); parse(configFileContent); } diff --git a/gcp-observability/src/test/java/io/grpc/gcp/observability/GcpObservabilityTest.java b/gcp-observability/src/test/java/io/grpc/gcp/observability/GcpObservabilityTest.java index c42d7b65c08..25467839dd6 100644 --- a/gcp-observability/src/test/java/io/grpc/gcp/observability/GcpObservabilityTest.java +++ b/gcp-observability/src/test/java/io/grpc/gcp/observability/GcpObservabilityTest.java @@ -30,13 +30,14 @@ import io.grpc.Channel; import io.grpc.ClientCall; import io.grpc.ClientInterceptor; -import io.grpc.InternalGlobalInterceptors; +import io.grpc.InternalConfiguratorRegistry; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.ServerCall; import io.grpc.ServerCallHandler; import io.grpc.ServerInterceptor; import io.grpc.StaticTestingClassLoader; +import io.grpc.gcp.observability.GcpObservability.ObservabilityConfigurator; import io.grpc.gcp.observability.interceptors.ConditionalClientInterceptor; import io.grpc.gcp.observability.interceptors.InternalLoggingChannelInterceptor; import io.grpc.gcp.observability.interceptors.InternalLoggingServerInterceptor; @@ -44,6 +45,7 @@ import io.opencensus.trace.samplers.Samplers; import java.io.IOException; import java.util.List; +import java.util.concurrent.TimeUnit; import java.util.regex.Pattern; import org.junit.Test; import org.junit.runner.RunWith; @@ -56,7 +58,8 @@ public class GcpObservabilityTest { new StaticTestingClassLoader( getClass().getClassLoader(), Pattern.compile( - "io\\.grpc\\.InternalGlobalInterceptors|io\\.grpc\\.GlobalInterceptors|" + "io\\.grpc\\.InternalConfigurator|io\\.grpc\\.Configurator|" + + "io\\.grpc\\.InternalConfiguratorRegistry|io\\.grpc\\.ConfiguratorRegistry|" + "io\\.grpc\\.gcp\\.observability\\.[^.]+|" + "io\\.grpc\\.gcp\\.observability\\.interceptors\\.[^.]+|" + "io\\.grpc\\.gcp\\.observability\\.GcpObservabilityTest\\$.*")); @@ -194,18 +197,23 @@ public void run() { mock(InternalLoggingServerInterceptor.Factory.class); when(serverInterceptorFactory.create()).thenReturn(serverInterceptor); - try (GcpObservability unused = - GcpObservability.grpcInit( - sink, config, channelInterceptorFactory, serverInterceptorFactory)) { - List list = InternalGlobalInterceptors.getClientInterceptors(); + try { + GcpObservability gcpObservability = GcpObservability.grpcInit( + sink, config, channelInterceptorFactory, serverInterceptorFactory); + List configurators = InternalConfiguratorRegistry.getConfigurators(); + assertThat(configurators).hasSize(1); + ObservabilityConfigurator configurator = (ObservabilityConfigurator) configurators.get(0); + List list = configurator.clientInterceptors; assertThat(list).hasSize(3); assertThat(list.get(1)).isInstanceOf(ConditionalClientInterceptor.class); assertThat(list.get(2)).isInstanceOf(ConditionalClientInterceptor.class); - assertThat(InternalGlobalInterceptors.getServerInterceptors()).hasSize(1); - assertThat(InternalGlobalInterceptors.getServerStreamTracerFactories()).hasSize(2); + assertThat(configurator.serverInterceptors).hasSize(1); + assertThat(configurator.tracerFactories).hasSize(2); + gcpObservability.closeWithSleepTime(3000, TimeUnit.MILLISECONDS); } catch (Exception e) { fail("Encountered exception: " + e); } + verify(sink).close(); } } @@ -228,9 +236,12 @@ public void run() { try (GcpObservability unused = GcpObservability.grpcInit( sink, config, channelInterceptorFactory, serverInterceptorFactory)) { - assertThat(InternalGlobalInterceptors.getClientInterceptors()).isEmpty(); - assertThat(InternalGlobalInterceptors.getServerInterceptors()).isEmpty(); - assertThat(InternalGlobalInterceptors.getServerStreamTracerFactories()).isEmpty(); + List configurators = InternalConfiguratorRegistry.getConfigurators(); + assertThat(configurators).hasSize(1); + ObservabilityConfigurator configurator = (ObservabilityConfigurator) configurators.get(0); + assertThat(configurator.clientInterceptors).isEmpty(); + assertThat(configurator.serverInterceptors).isEmpty(); + assertThat(configurator.tracerFactories).isEmpty(); } catch (Exception e) { fail("Encountered exception: " + e); } diff --git a/gcp-observability/src/test/java/io/grpc/gcp/observability/LoggingTest.java b/gcp-observability/src/test/java/io/grpc/gcp/observability/LoggingTest.java index 6d575e58eed..92e67b01e01 100644 --- a/gcp-observability/src/test/java/io/grpc/gcp/observability/LoggingTest.java +++ b/gcp-observability/src/test/java/io/grpc/gcp/observability/LoggingTest.java @@ -71,9 +71,9 @@ public class LoggingTest { new StaticTestingClassLoader(getClass().getClassLoader(), Pattern.compile("io\\.grpc\\..*")); /** - * Cloud logging test using GlobalInterceptors. + * Cloud logging test using global interceptors. * - *

Ignoring test, because it calls external Cloud Logging APIs. + *

Ignoring test, because it calls external Cloud Logging APIs. * To test cloud logging setup locally, * 1. Set up Cloud auth credentials * 2. Assign permissions to service account to write logs to project specified by diff --git a/gcp-observability/src/test/java/io/grpc/gcp/observability/ObservabilityConfigImplTest.java b/gcp-observability/src/test/java/io/grpc/gcp/observability/ObservabilityConfigImplTest.java index d6f23fbcc9a..f409a149bf1 100644 --- a/gcp-observability/src/test/java/io/grpc/gcp/observability/ObservabilityConfigImplTest.java +++ b/gcp-observability/src/test/java/io/grpc/gcp/observability/ObservabilityConfigImplTest.java @@ -21,12 +21,12 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; -import com.google.common.base.Charsets; import io.grpc.gcp.observability.ObservabilityConfig.LogFilter; import io.opencensus.trace.Sampler; import io.opencensus.trace.samplers.Samplers; import java.io.File; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Paths; import java.util.Collections; @@ -108,8 +108,7 @@ public class ObservabilityConfigImplTest { private static final String PROJECT_ID = "{\n" + " \"project_id\": \"grpc-testing\",\n" - + " \"cloud_logging\": {},\n" - + " \"project_id\": \"grpc-testing\"\n" + + " \"cloud_logging\": {}\n" + "}"; private static final String EMPTY_CONFIG = "{}"; @@ -401,7 +400,8 @@ public void badProbabilisticSampler_error() throws IOException { public void configFileLogFilters() throws Exception { File configFile = tempFolder.newFile(); Files.write( - Paths.get(configFile.getAbsolutePath()), CLIENT_LOG_FILTERS.getBytes(Charsets.US_ASCII)); + Paths.get(configFile.getAbsolutePath()), + CLIENT_LOG_FILTERS.getBytes(StandardCharsets.US_ASCII)); observabilityConfig.parseFile(configFile.getAbsolutePath()); assertTrue(observabilityConfig.isEnableCloudLogging()); assertThat(observabilityConfig.getProjectId()).isEqualTo("grpc-testing"); diff --git a/googleapis/BUILD.bazel b/googleapis/BUILD.bazel index 77b0bcd93b9..5b62b21cb3a 100644 --- a/googleapis/BUILD.bazel +++ b/googleapis/BUILD.bazel @@ -1,3 +1,6 @@ +load("@rules_java//java:defs.bzl", "java_library") +load("@rules_jvm_external//:defs.bzl", "artifact") + java_library( name = "googleapis", srcs = glob([ @@ -9,6 +12,7 @@ java_library( "//api", "//core:internal", "//xds", - "@com_google_guava_guava//jar", + artifact("com.google.guava:guava"), + artifact("com.google.errorprone:error_prone_annotations"), ], ) diff --git a/googleapis/build.gradle b/googleapis/build.gradle index 435e552d47d..3a7a3a2766a 100644 --- a/googleapis/build.gradle +++ b/googleapis/build.gradle @@ -21,5 +21,9 @@ dependencies { libraries.guava.jre // JRE required by transitive protobuf-java-util testImplementation testFixtures(project(':grpc-core')) - signature libraries.signature.java + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } } diff --git a/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdExperimentalNameResolverProvider.java b/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdExperimentalNameResolverProvider.java index 349e1c94380..db674aeb2ee 100644 --- a/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdExperimentalNameResolverProvider.java +++ b/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdExperimentalNameResolverProvider.java @@ -20,6 +20,7 @@ import io.grpc.NameResolver; import io.grpc.NameResolver.Args; import io.grpc.NameResolverProvider; +import io.grpc.Uri; import java.net.URI; /** @@ -35,6 +36,11 @@ public NameResolver newNameResolver(URI targetUri, Args args) { return delegate.newNameResolver(targetUri, args); } + @Override + public NameResolver newNameResolver(Uri targetUri, Args args) { + return delegate.newNameResolver(targetUri, args); + } + @Override public String getDefaultScheme() { return delegate.getDefaultScheme(); diff --git a/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolver.java b/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolver.java index 64c2e0f9c86..10ba586ab47 100644 --- a/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolver.java +++ b/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolver.java @@ -19,29 +19,38 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Charsets; import com.google.common.base.Preconditions; -import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.io.CharStreams; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import io.grpc.MetricRecorder; import io.grpc.NameResolver; import io.grpc.NameResolverRegistry; +import io.grpc.QueryParams; import io.grpc.Status; import io.grpc.SynchronizationContext; +import io.grpc.Uri; import io.grpc.alts.InternalCheckGcpEnvironment; import io.grpc.internal.GrpcUtil; import io.grpc.internal.SharedResourceHolder; import io.grpc.internal.SharedResourceHolder.Resource; +import io.grpc.xds.InternalGrpcBootstrapperImpl; +import io.grpc.xds.InternalSharedXdsClientPoolProvider; +import io.grpc.xds.InternalSharedXdsClientPoolProvider.XdsClientResult; +import io.grpc.xds.XdsNameResolverProvider; +import io.grpc.xds.client.Bootstrapper.BootstrapInfo; +import io.grpc.xds.client.XdsClient; +import io.grpc.xds.client.XdsInitializationException; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.io.Reader; import java.net.HttpURLConnection; import java.net.URI; -import java.net.URISyntaxException; import java.net.URL; -import java.util.Map; +import java.nio.charset.StandardCharsets; +import java.util.List; import java.util.Random; import java.util.concurrent.Executor; import java.util.logging.Level; @@ -63,53 +72,72 @@ final class GoogleCloudToProdNameResolver extends NameResolver { static final String C2P_AUTHORITY = "traffic-director-c2p.xds.googleapis.com"; @VisibleForTesting static boolean isOnGcp = InternalCheckGcpEnvironment.isOnGcp(); - @VisibleForTesting - static boolean xdsBootstrapProvided = - System.getenv("GRPC_XDS_BOOTSTRAP") != null - || System.getProperty("io.grpc.xds.bootstrap") != null - || System.getenv("GRPC_XDS_BOOTSTRAP_CONFIG") != null - || System.getProperty("io.grpc.xds.bootstrapConfig") != null; - @VisibleForTesting - static boolean enableFederation = - Strings.isNullOrEmpty(System.getenv("GRPC_EXPERIMENTAL_XDS_FEDERATION")) - || Boolean.parseBoolean(System.getenv("GRPC_EXPERIMENTAL_XDS_FEDERATION")); private static final String serverUriOverride = System.getenv("GRPC_TEST_ONLY_GOOGLE_C2P_RESOLVER_TRAFFIC_DIRECTOR_URI"); - private HttpConnectionProvider httpConnectionProvider = HttpConnectionFactory.INSTANCE; + @GuardedBy("GoogleCloudToProdNameResolver.class") + private static BootstrapInfo bootstrapInfo; + private static HttpConnectionProvider httpConnectionProvider = HttpConnectionFactory.INSTANCE; + private static int c2pId = new Random().nextInt(); + + private static synchronized BootstrapInfo getBootstrapInfo(boolean isForcedXds) + throws XdsInitializationException, IOException { + if (bootstrapInfo != null) { + return bootstrapInfo; + } + BootstrapInfo newInfo; + if (isForcedXds) { + newInfo = InternalGrpcBootstrapperImpl.parseBootstrap( + generateBootstrap("", true)); + } else { + newInfo = InternalGrpcBootstrapperImpl.parseBootstrap( + generateBootstrap( + queryZoneMetadata(METADATA_URL_ZONE), + queryIpv6SupportMetadata(METADATA_URL_SUPPORT_IPV6))); + } + // Avoid setting global when testing + if (httpConnectionProvider == HttpConnectionFactory.INSTANCE) { + bootstrapInfo = newInfo; + } + return newInfo; + } + private final String authority; private final SynchronizationContext syncContext; private final Resource executorResource; - private final BootstrapSetter bootstrapSetter; + private final String target; + private final MetricRecorder metricRecorder; private final NameResolver delegate; - private final Random rand; private final boolean usingExecutorResource; - // It's not possible to use both PSM and DirectPath C2P in the same application. - // Delegate to DNS if user-provided bootstrap is found. - private final String schemeOverride = - !isOnGcp - || (xdsBootstrapProvided && !enableFederation) - ? "dns" : "xds"; + private final boolean forceXds; + private final String schemeOverride; + private XdsClientResult xdsClientPool; + private XdsClient xdsClient; private Executor executor; private Listener2 listener; private boolean succeeded; private boolean resolving; private boolean shutdown; - GoogleCloudToProdNameResolver(URI targetUri, Args args, Resource executorResource, - BootstrapSetter bootstrapSetter) { - this(targetUri, args, executorResource, new Random(), bootstrapSetter, + GoogleCloudToProdNameResolver(URI targetUri, Args args, Resource executorResource) { + this(targetUri, args, executorResource, NameResolverRegistry.getDefaultRegistry().asFactory()); } + // TODO(jdcormie): Remove after io.grpc.Uri migration. @VisibleForTesting GoogleCloudToProdNameResolver(URI targetUri, Args args, Resource executorResource, - Random rand, BootstrapSetter bootstrapSetter, NameResolver.Factory nameResolverFactory) { + NameResolver.Factory nameResolverFactory) { this.executorResource = checkNotNull(executorResource, "executorResource"); - this.bootstrapSetter = checkNotNull(bootstrapSetter, "bootstrapSetter"); - this.rand = checkNotNull(rand, "rand"); String targetPath = checkNotNull(checkNotNull(targetUri, "targetUri").getPath(), "targetPath"); + Uri grpcUri = Uri.create(targetUri.toString()); + QueryParams queryParams = QueryParams.fromRawQuery(grpcUri.getRawQuery()); + this.forceXds = checkForceXds(queryParams); + this.schemeOverride = (forceXds || isOnGcp) ? "xds" : "dns"; + stripForceXds(queryParams); + String newQuery = queryParams.toRawQuery(); + Preconditions.checkArgument( targetPath.startsWith("/"), "the path component (%s) of the target (%s) must start with '/'", @@ -117,16 +145,78 @@ final class GoogleCloudToProdNameResolver extends NameResolver { targetUri); authority = GrpcUtil.checkAuthority(targetPath.substring(1)); syncContext = checkNotNull(args, "args").getSynchronizationContext(); - targetUri = overrideUriScheme(targetUri, schemeOverride); - if (schemeOverride.equals("xds") && enableFederation) { - targetUri = overrideUriAuthority(targetUri, C2P_AUTHORITY); + + Uri.Builder modifiedTargetBuilder = grpcUri.toBuilder().setScheme(schemeOverride); + modifiedTargetBuilder.setRawQuery(newQuery); + if (schemeOverride.equals("xds")) { + modifiedTargetBuilder.setRawAuthority(C2P_AUTHORITY); } + targetUri = URI.create(modifiedTargetBuilder.build().toString()); + + if (schemeOverride.equals("xds")) { + args = args.toBuilder() + .setArg(XdsNameResolverProvider.XDS_CLIENT_SUPPLIER, () -> xdsClient) + .build(); + } + target = targetUri.toString(); + metricRecorder = args.getMetricRecorder(); delegate = checkNotNull(nameResolverFactory, "nameResolverFactory").newNameResolver( targetUri, args); executor = args.getOffloadExecutor(); usingExecutorResource = executor == null; } + GoogleCloudToProdNameResolver(Uri targetUri, Args args, Resource executorResource) { + this(targetUri, args, executorResource, NameResolverRegistry.getDefaultRegistry().asFactory()); + } + + @VisibleForTesting + GoogleCloudToProdNameResolver( + Uri targetUri, + Args args, + Resource executorResource, + NameResolver.Factory nameResolverFactory) { + this.executorResource = checkNotNull(executorResource, "executorResource"); + QueryParams queryParams = QueryParams.fromRawQuery(targetUri.getRawQuery()); + this.forceXds = checkForceXds(queryParams); + this.schemeOverride = (forceXds || isOnGcp) ? "xds" : "dns"; + stripForceXds(queryParams); + String newQuery = queryParams.toRawQuery(); + + Preconditions.checkArgument( + targetUri.isPathAbsolute(), + "the path component of the target (%s) must start with '/'", + targetUri); + List pathSegments = targetUri.getPathSegments(); + Preconditions.checkArgument( + pathSegments.size() == 1, + "the path component of the target (%s) must have exactly one segment", + targetUri); + authority = GrpcUtil.checkAuthority(pathSegments.get(0)); + syncContext = checkNotNull(args, "args").getSynchronizationContext(); + Uri.Builder modifiedTargetBuilder = targetUri.toBuilder().setScheme(schemeOverride); + if (newQuery != null) { + modifiedTargetBuilder.setRawQuery(newQuery); + } else { + modifiedTargetBuilder.setRawQuery(null); + } + + if (schemeOverride.equals("xds")) { + modifiedTargetBuilder.setRawAuthority(C2P_AUTHORITY); + args = + args.toBuilder() + .setArg(XdsNameResolverProvider.XDS_CLIENT_SUPPLIER, () -> xdsClient) + .build(); + } + targetUri = modifiedTargetBuilder.build(); + target = targetUri.toString(); + metricRecorder = args.getMetricRecorder(); + delegate = + checkNotNull(nameResolverFactory, "nameResolverFactory").newNameResolver(targetUri, args); + executor = args.getOffloadExecutor(); + usingExecutorResource = executor == null; + } + @Override public String getServiceAuthority() { return authority; @@ -150,7 +240,7 @@ private void resolve() { resolving = true; if (logger.isLoggable(Level.FINE)) { - logger.fine("resolve with schemaOverride = " + schemeOverride); + logger.log(Level.FINE, "start with schemaOverride = {0}", schemeOverride); } if (schemeOverride.equals("dns")) { @@ -168,28 +258,28 @@ private void resolve() { class Resolve implements Runnable { @Override public void run() { - ImmutableMap rawBootstrap = null; + BootstrapInfo bootstrapInfo = null; try { - // User provided bootstrap configs are only supported with federation. If federation is - // not enabled or there is no user provided config, we set a custom bootstrap override. - // Otherwise, we don't set the override, which will allow a user provided bootstrap config - // to take effect. - if (!enableFederation || !xdsBootstrapProvided) { - rawBootstrap = generateBootstrap(queryZoneMetadata(METADATA_URL_ZONE), - queryIpv6SupportMetadata(METADATA_URL_SUPPORT_IPV6)); - } + bootstrapInfo = getBootstrapInfo(forceXds); } catch (IOException e) { listener.onError( Status.INTERNAL.withDescription("Unable to get metadata").withCause(e)); + } catch (XdsInitializationException e) { + listener.onError( + Status.INTERNAL.withDescription("Unable to create c2p bootstrap").withCause(e)); + } catch (Throwable t) { + listener.onError( + Status.INTERNAL.withDescription("Unexpected error creating c2p bootstrap") + .withCause(t)); } finally { - final ImmutableMap finalRawBootstrap = rawBootstrap; + final BootstrapInfo finalBootstrapInfo = bootstrapInfo; syncContext.execute(new Runnable() { @Override public void run() { - if (!shutdown) { - if (finalRawBootstrap != null) { - bootstrapSetter.setBootstrap(finalRawBootstrap); - } + if (!shutdown && finalBootstrapInfo != null) { + xdsClientPool = InternalSharedXdsClientPoolProvider.getOrCreate( + target, finalBootstrapInfo, metricRecorder, null); + xdsClient = xdsClientPool.getObject(); delegate.start(listener); succeeded = true; } @@ -203,9 +293,11 @@ public void run() { executor.execute(new Resolve()); } - private ImmutableMap generateBootstrap(String zone, boolean supportIpv6) { + private static ImmutableMap generateBootstrap( + String zone, boolean supportIpv6) { ImmutableMap.Builder nodeBuilder = ImmutableMap.builder(); - nodeBuilder.put("id", "C2P-" + (rand.nextInt() & Integer.MAX_VALUE)); + String nodeIdPrefix = isOnGcp ? "C2P-" : "C2P-non-gcp-"; + nodeBuilder.put("id", nodeIdPrefix + (c2pId & Integer.MAX_VALUE)); if (!zone.isEmpty()) { nodeBuilder.put("locality", ImmutableMap.of("zone", zone)); } @@ -250,12 +342,15 @@ public void shutdown() { if (delegate != null) { delegate.shutdown(); } + if (xdsClient != null) { + xdsClient = xdsClientPool.returnObject(xdsClient); + } if (executor != null && usingExecutorResource) { executor = SharedResourceHolder.release(executorResource, executor); } } - private String queryZoneMetadata(String url) throws IOException { + private static String queryZoneMetadata(String url) throws IOException { HttpURLConnection con = null; String respBody; try { @@ -263,7 +358,7 @@ private String queryZoneMetadata(String url) throws IOException { if (con.getResponseCode() != 200) { return ""; } - try (Reader reader = new InputStreamReader(con.getInputStream(), Charsets.UTF_8)) { + try (Reader reader = new InputStreamReader(con.getInputStream(), StandardCharsets.UTF_8)) { respBody = CharStreams.toString(reader); } } finally { @@ -275,7 +370,7 @@ private String queryZoneMetadata(String url) throws IOException { return index == -1 ? "" : respBody.substring(index + 1); } - private boolean queryIpv6SupportMetadata(String url) throws IOException { + private static boolean queryIpv6SupportMetadata(String url) throws IOException { HttpURLConnection con = null; try { con = httpConnectionProvider.createConnection(url); @@ -294,28 +389,30 @@ private boolean queryIpv6SupportMetadata(String url) throws IOException { } @VisibleForTesting - void setHttpConnectionProvider(HttpConnectionProvider httpConnectionProvider) { - this.httpConnectionProvider = httpConnectionProvider; + static void setHttpConnectionProvider(HttpConnectionProvider httpConnectionProvider) { + if (httpConnectionProvider == null) { + GoogleCloudToProdNameResolver.httpConnectionProvider = HttpConnectionFactory.INSTANCE; + } else { + GoogleCloudToProdNameResolver.httpConnectionProvider = httpConnectionProvider; + } } - private static URI overrideUriScheme(URI uri, String scheme) { - URI res; - try { - res = new URI(scheme, uri.getAuthority(), uri.getPath(), uri.getQuery(), uri.getFragment()); - } catch (URISyntaxException ex) { - throw new IllegalArgumentException("Invalid scheme: " + scheme, ex); - } - return res; + @VisibleForTesting + static void setC2pId(int c2pId) { + GoogleCloudToProdNameResolver.c2pId = c2pId; } - private static URI overrideUriAuthority(URI uri, String authority) { - URI res; - try { - res = new URI(uri.getScheme(), authority, uri.getPath(), uri.getQuery(), uri.getFragment()); - } catch (URISyntaxException ex) { - throw new IllegalArgumentException("Invalid authority: " + authority, ex); + private static boolean checkForceXds(QueryParams params) { + for (QueryParams.Entry entry : params.asList()) { + if ("force-xds".equals(entry.getKey())) { + return true; + } } - return res; + return false; + } + + private static void stripForceXds(QueryParams params) { + params.asList().removeIf(entry -> "force-xds".equals(entry.getKey())); } private enum HttpConnectionFactory implements HttpConnectionProvider { @@ -335,8 +432,4 @@ public HttpURLConnection createConnection(String url) throws IOException { interface HttpConnectionProvider { HttpURLConnection createConnection(String url) throws IOException; } - - public interface BootstrapSetter { - void setBootstrap(Map bootstrap); - } } diff --git a/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProvider.java b/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProvider.java index 8ad292a3d98..f936de086e9 100644 --- a/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProvider.java +++ b/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProvider.java @@ -21,14 +21,13 @@ import io.grpc.NameResolver; import io.grpc.NameResolver.Args; import io.grpc.NameResolverProvider; +import io.grpc.Uri; import io.grpc.internal.GrpcUtil; -import io.grpc.xds.InternalSharedXdsClientPoolProvider; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.net.URI; import java.util.Collection; import java.util.Collections; -import java.util.Map; /** * A provider for {@link GoogleCloudToProdNameResolver}. @@ -48,12 +47,21 @@ public GoogleCloudToProdNameResolverProvider() { this.scheme = Preconditions.checkNotNull(scheme, "scheme"); } + // TODO(jdcormie): Remove after io.grpc.Uri migration is complete. @Override public NameResolver newNameResolver(URI targetUri, Args args) { if (scheme.equals(targetUri.getScheme())) { return new GoogleCloudToProdNameResolver( - targetUri, args, GrpcUtil.SHARED_CHANNEL_EXECUTOR, - new SharedXdsClientPoolProviderBootstrapSetter()); + targetUri, args, GrpcUtil.SHARED_CHANNEL_EXECUTOR); + } + return null; + } + + @Override + public NameResolver newNameResolver(Uri targetUri, Args args) { + if (scheme.equals(targetUri.getScheme())) { + return new GoogleCloudToProdNameResolver( + targetUri, args, GrpcUtil.SHARED_CHANNEL_EXECUTOR); } return null; } @@ -77,12 +85,4 @@ protected int priority() { public Collection> getProducedSocketAddressTypes() { return Collections.singleton(InetSocketAddress.class); } - - private static final class SharedXdsClientPoolProviderBootstrapSetter - implements GoogleCloudToProdNameResolver.BootstrapSetter { - @Override - public void setBootstrap(Map bootstrap) { - InternalSharedXdsClientPoolProvider.setDefaultProviderBootstrapOverride(bootstrap); - } - } } diff --git a/googleapis/src/test/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProviderTest.java b/googleapis/src/test/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProviderTest.java index 447b102c8c7..39468472985 100644 --- a/googleapis/src/test/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProviderTest.java +++ b/googleapis/src/test/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProviderTest.java @@ -23,20 +23,23 @@ import io.grpc.ChannelLogger; import io.grpc.InternalServiceProviders; import io.grpc.NameResolver; +import io.grpc.NameResolver.Args; import io.grpc.NameResolver.ServiceConfigParser; import io.grpc.NameResolverProvider; import io.grpc.SynchronizationContext; +import io.grpc.Uri; import io.grpc.internal.FakeClock; import io.grpc.internal.GrpcUtil; import java.net.URI; +import java.util.Arrays; import org.junit.Test; import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; -/** - * Unit tests for {@link GoogleCloudToProdNameResolverProvider}. - */ -@RunWith(JUnit4.class) +/** Unit tests for {@link GoogleCloudToProdNameResolverProvider}. */ +@RunWith(Parameterized.class) public class GoogleCloudToProdNameResolverProviderTest { private final SynchronizationContext syncContext = new SynchronizationContext( new Thread.UncaughtExceptionHandler() { @@ -59,6 +62,13 @@ public void uncaughtException(Thread t, Throwable e) { private GoogleCloudToProdNameResolverProvider provider = new GoogleCloudToProdNameResolverProvider(); + @Parameters(name = "enableRfc3986UrisParam={0}") + public static Iterable data() { + return Arrays.asList(new Object[][] {{true}, {false}}); + } + + @Parameter public boolean enableRfc3986UrisParam; + @Test public void provided() { for (NameResolverProvider current @@ -84,16 +94,24 @@ NameResolverProvider.class, getClass().getClassLoader())) { } @Test - public void newNameResolver() { - assertThat(provider - .newNameResolver(URI.create("google-c2p:///foo.googleapis.com"), args)) + public void shouldProvideNameResolverOfExpectedType() { + assertThat(newNameResolver(provider, "google-c2p:///foo.googleapis.com", args)) .isInstanceOf(GoogleCloudToProdNameResolver.class); } @Test - public void experimentalNewNameResolver() { - assertThat(new GoogleCloudToProdExperimentalNameResolverProvider() - .newNameResolver(URI.create("google-c2p-experimental:///foo.googleapis.com"), args)) + public void shouldProvideExperimentalNameResolverOfExpectedType() { + assertThat( + newNameResolver( + new GoogleCloudToProdExperimentalNameResolverProvider(), + "google-c2p-experimental:///foo.googleapis.com", + args)) .isInstanceOf(GoogleCloudToProdNameResolver.class); } + + private NameResolver newNameResolver(NameResolverProvider provider, String uri, Args args) { + return enableRfc3986UrisParam + ? provider.newNameResolver(Uri.create(uri), args) + : provider.newNameResolver(URI.create(uri), args); + } } diff --git a/googleapis/src/test/java/io/grpc/googleapis/GoogleCloudToProdNameResolverTest.java b/googleapis/src/test/java/io/grpc/googleapis/GoogleCloudToProdNameResolverTest.java index edb3126d1e3..bbd3ba3ef05 100644 --- a/googleapis/src/test/java/io/grpc/googleapis/GoogleCloudToProdNameResolverTest.java +++ b/googleapis/src/test/java/io/grpc/googleapis/GoogleCloudToProdNameResolverTest.java @@ -21,10 +21,9 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import io.grpc.ChannelLogger; +import io.grpc.MetricRecorder; import io.grpc.NameResolver; import io.grpc.NameResolver.Args; import io.grpc.NameResolver.ServiceConfigParser; @@ -33,6 +32,7 @@ import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.SynchronizationContext; +import io.grpc.Uri; import io.grpc.googleapis.GoogleCloudToProdNameResolver.HttpConnectionProvider; import io.grpc.internal.FakeClock; import io.grpc.internal.GrpcUtil; @@ -42,31 +42,32 @@ import java.net.HttpURLConnection; import java.net.URI; import java.nio.charset.StandardCharsets; +import java.util.Arrays; import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.Random; import java.util.concurrent.Executor; -import java.util.concurrent.atomic.AtomicReference; import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; -@RunWith(JUnit4.class) +@RunWith(Parameterized.class) public class GoogleCloudToProdNameResolverTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); - private static final URI TARGET_URI = URI.create("google-c2p:///googleapis.com"); + private static final String TARGET_URI = "google-c2p:///googleapis.com"; private static final String ZONE = "us-central1-a"; private static final int DEFAULT_PORT = 887; @@ -77,15 +78,16 @@ public void uncaughtException(Thread t, Throwable e) { throw new AssertionError(e); } }); + private final FakeClock fakeExecutor = new FakeClock(); private final NameResolver.Args args = NameResolver.Args.newBuilder() .setDefaultPort(DEFAULT_PORT) .setProxyDetector(GrpcUtil.DEFAULT_PROXY_DETECTOR) .setSynchronizationContext(syncContext) + .setScheduledExecutorService(fakeExecutor.getScheduledExecutorService()) .setServiceConfigParser(mock(ServiceConfigParser.class)) .setChannelLogger(mock(ChannelLogger.class)) + .setMetricRecorder(new MetricRecorder() {}) .build(); - private final FakeClock fakeExecutor = new FakeClock(); - private final FakeBootstrapSetter fakeBootstrapSetter = new FakeBootstrapSetter(); private final Resource fakeExecutorResource = new Resource() { @Override public Executor create() { @@ -98,37 +100,30 @@ public void close(Executor instance) {} private final NameResolverRegistry nsRegistry = new NameResolverRegistry(); private final Map delegatedResolver = new HashMap<>(); + private final Map delegatedUri = new HashMap<>(); + private final Map delegatedRfcUri = new HashMap<>(); @Mock private NameResolver.Listener2 mockListener; - private Random random = new Random(1); @Captor private ArgumentCaptor errorCaptor; private boolean originalIsOnGcp; - private boolean originalXdsBootstrapProvided; private GoogleCloudToProdNameResolver resolver; + private String responseToIpV6 = "1:1:1"; + + @Parameters(name = "enableRfc3986UrisParam={0}") + public static Iterable data() { + return Arrays.asList(new Object[][] {{true}, {false}}); + } + + @Parameter public boolean enableRfc3986UrisParam; @Before public void setUp() { nsRegistry.register(new FakeNsProvider("dns")); nsRegistry.register(new FakeNsProvider("xds")); originalIsOnGcp = GoogleCloudToProdNameResolver.isOnGcp; - originalXdsBootstrapProvided = GoogleCloudToProdNameResolver.xdsBootstrapProvided; - } - @After - public void tearDown() { - GoogleCloudToProdNameResolver.isOnGcp = originalIsOnGcp; - GoogleCloudToProdNameResolver.xdsBootstrapProvided = originalXdsBootstrapProvided; - resolver.shutdown(); - verify(Iterables.getOnlyElement(delegatedResolver.values())).shutdown(); - } - - private void createResolver() { - createResolver("1:1:1"); - } - - private void createResolver(String responseToIpV6) { HttpConnectionProvider httpConnections = new HttpConnectionProvider() { @Override public HttpURLConnection createConnection(String url) throws IOException { @@ -148,10 +143,28 @@ public HttpURLConnection createConnection(String url) throws IOException { throw new AssertionError("Unknown http query"); } }; - resolver = new GoogleCloudToProdNameResolver( - TARGET_URI, args, fakeExecutorResource, random, fakeBootstrapSetter, - nsRegistry.asFactory()); - resolver.setHttpConnectionProvider(httpConnections); + GoogleCloudToProdNameResolver.setHttpConnectionProvider(httpConnections); + + GoogleCloudToProdNameResolver.setC2pId(new Random(1).nextInt()); + } + + @After + public void tearDown() { + GoogleCloudToProdNameResolver.isOnGcp = originalIsOnGcp; + GoogleCloudToProdNameResolver.setHttpConnectionProvider(null); + if (resolver != null) { + resolver.shutdown(); + verify(Iterables.getOnlyElement(delegatedResolver.values())).shutdown(); + } + } + + private void createResolver() { + resolver = + enableRfc3986UrisParam + ? new GoogleCloudToProdNameResolver( + Uri.create(TARGET_URI), args, fakeExecutorResource, nsRegistry.asFactory()) + : new GoogleCloudToProdNameResolver( + URI.create(TARGET_URI), args, fakeExecutorResource, nsRegistry.asFactory()); } @Test @@ -164,137 +177,139 @@ public void notOnGcp_DelegateToDns() { } @Test - public void hasProvidedBootstrap_DelegateToDns() { + public void onGcpAndNoProvidedBootstrap_DelegateToXds() { GoogleCloudToProdNameResolver.isOnGcp = true; - GoogleCloudToProdNameResolver.xdsBootstrapProvided = true; - GoogleCloudToProdNameResolver.enableFederation = false; createResolver(); resolver.start(mockListener); - assertThat(delegatedResolver.keySet()).containsExactly("dns"); + fakeExecutor.runDueTasks(); + assertThat(delegatedResolver.keySet()).containsExactly("xds"); verify(Iterables.getOnlyElement(delegatedResolver.values())).start(mockListener); } - @SuppressWarnings("unchecked") @Test - public void onGcpAndNoProvidedBootstrap_DelegateToXds() { - GoogleCloudToProdNameResolver.isOnGcp = true; - GoogleCloudToProdNameResolver.xdsBootstrapProvided = false; - createResolver(); + public void notOnGcpButForceXds_DelegateToXds() { + GoogleCloudToProdNameResolver.isOnGcp = false; + String target = TARGET_URI + "?force-xds"; + resolver = + enableRfc3986UrisParam + ? new GoogleCloudToProdNameResolver( + Uri.create(target), args, fakeExecutorResource, nsRegistry.asFactory()) + : new GoogleCloudToProdNameResolver( + URI.create(target), args, fakeExecutorResource, nsRegistry.asFactory()); resolver.start(mockListener); fakeExecutor.runDueTasks(); assertThat(delegatedResolver.keySet()).containsExactly("xds"); - verify(Iterables.getOnlyElement(delegatedResolver.values())).start(mockListener); - Map bootstrap = fakeBootstrapSetter.bootstrapRef.get(); - Map node = (Map) bootstrap.get("node"); - assertThat(node).containsExactly( - "id", "C2P-991614323", - "locality", ImmutableMap.of("zone", ZONE), - "metadata", ImmutableMap.of("TRAFFICDIRECTOR_DIRECTPATH_C2P_IPV6_CAPABLE", true)); - Map server = Iterables.getOnlyElement( - (List>) bootstrap.get("xds_servers")); - assertThat(server).containsExactly( - "server_uri", "directpath-pa.googleapis.com", - "channel_creds", ImmutableList.of(ImmutableMap.of("type", "google_default")), - "server_features", ImmutableList.of("xds_v3", "ignore_resource_deletion")); - Map authorities = (Map) bootstrap.get("authorities"); - assertThat(authorities).containsExactly( - "traffic-director-c2p.xds.googleapis.com", - ImmutableMap.of("xds_servers", ImmutableList.of(server))); + + if (enableRfc3986UrisParam) { + Uri delegatedRfcUriValue = delegatedRfcUri.get("xds"); + assertThat(delegatedRfcUriValue).isNotNull(); + assertThat(delegatedRfcUriValue.getRawQuery()).isNull(); + } else { + URI delegatedUriValue = delegatedUri.get("xds"); + assertThat(delegatedUriValue).isNotNull(); + assertThat(delegatedUriValue.getQuery()).isNull(); + } } - @SuppressWarnings("unchecked") @Test - public void onGcpAndNoProvidedBootstrap_DelegateToXds_noIpV6() { - GoogleCloudToProdNameResolver.isOnGcp = true; - GoogleCloudToProdNameResolver.xdsBootstrapProvided = false; - createResolver(null); + public void notOnGcpButForceXds_KeyValueTrue_DelegateToXds() { + GoogleCloudToProdNameResolver.isOnGcp = false; + String target = TARGET_URI + "?force-xds=true"; + resolver = enableRfc3986UrisParam + ? new GoogleCloudToProdNameResolver( + Uri.create(target), args, fakeExecutorResource, nsRegistry.asFactory()) + : new GoogleCloudToProdNameResolver( + URI.create(target), args, fakeExecutorResource, nsRegistry.asFactory()); resolver.start(mockListener); fakeExecutor.runDueTasks(); assertThat(delegatedResolver.keySet()).containsExactly("xds"); - verify(Iterables.getOnlyElement(delegatedResolver.values())).start(mockListener); - Map bootstrap = fakeBootstrapSetter.bootstrapRef.get(); - Map node = (Map) bootstrap.get("node"); - assertThat(node).containsExactly( - "id", "C2P-991614323", - "locality", ImmutableMap.of("zone", ZONE)); - Map server = Iterables.getOnlyElement( - (List>) bootstrap.get("xds_servers")); - assertThat(server).containsExactly( - "server_uri", "directpath-pa.googleapis.com", - "channel_creds", ImmutableList.of(ImmutableMap.of("type", "google_default")), - "server_features", ImmutableList.of("xds_v3", "ignore_resource_deletion")); - Map authorities = (Map) bootstrap.get("authorities"); - assertThat(authorities).containsExactly( - "traffic-director-c2p.xds.googleapis.com", - ImmutableMap.of("xds_servers", ImmutableList.of(server))); + + if (enableRfc3986UrisParam) { + Uri delegatedRfcUriValue = delegatedRfcUri.get("xds"); + assertThat(delegatedRfcUriValue).isNotNull(); + assertThat(delegatedRfcUriValue.getRawQuery()).isNull(); + } else { + URI delegatedUriValue = delegatedUri.get("xds"); + assertThat(delegatedUriValue).isNotNull(); + assertThat(delegatedUriValue.getQuery()).isNull(); + } } - @SuppressWarnings("unchecked") + @Test - public void emptyResolverMeetadataValue() { - GoogleCloudToProdNameResolver.isOnGcp = true; - GoogleCloudToProdNameResolver.xdsBootstrapProvided = false; - createResolver(""); + public void notOnGcpButForceXds_WithMultipleParams_DelegateToXds() { + GoogleCloudToProdNameResolver.isOnGcp = false; + String target = TARGET_URI + "?foo=bar&force-xds&baz=qux"; + resolver = enableRfc3986UrisParam + ? new GoogleCloudToProdNameResolver( + Uri.create(target), args, fakeExecutorResource, nsRegistry.asFactory()) + : new GoogleCloudToProdNameResolver( + URI.create(target), args, fakeExecutorResource, nsRegistry.asFactory()); resolver.start(mockListener); fakeExecutor.runDueTasks(); assertThat(delegatedResolver.keySet()).containsExactly("xds"); - verify(Iterables.getOnlyElement(delegatedResolver.values())).start(mockListener); - Map bootstrap = fakeBootstrapSetter.bootstrapRef.get(); - Map node = (Map) bootstrap.get("node"); - assertThat(node).containsExactly( - "id", "C2P-991614323", - "locality", ImmutableMap.of("zone", ZONE)); + + if (enableRfc3986UrisParam) { + Uri delegatedRfcUriValue = delegatedRfcUri.get("xds"); + assertThat(delegatedRfcUriValue).isNotNull(); + assertThat(delegatedRfcUriValue.getRawQuery()).isEqualTo("foo=bar&baz=qux"); + } else { + URI delegatedUriValue = delegatedUri.get("xds"); + assertThat(delegatedUriValue).isNotNull(); + assertThat(delegatedUriValue.getQuery()).isEqualTo("foo=bar&baz=qux"); + } } - @SuppressWarnings("unchecked") @Test - public void onGcpAndNoProvidedBootstrapAndFederationEnabled_DelegateToXds() { - GoogleCloudToProdNameResolver.isOnGcp = true; - GoogleCloudToProdNameResolver.xdsBootstrapProvided = false; - GoogleCloudToProdNameResolver.enableFederation = true; - createResolver(); + public void notOnGcpButForceXds_WithEncodedAmpersand_DelegateToXds() { + GoogleCloudToProdNameResolver.isOnGcp = false; + String target = TARGET_URI + "?force-xds&foo=bar%26baz"; + resolver = enableRfc3986UrisParam + ? new GoogleCloudToProdNameResolver( + Uri.create(target), args, fakeExecutorResource, nsRegistry.asFactory()) + : new GoogleCloudToProdNameResolver( + URI.create(target), args, fakeExecutorResource, nsRegistry.asFactory()); resolver.start(mockListener); fakeExecutor.runDueTasks(); assertThat(delegatedResolver.keySet()).containsExactly("xds"); - verify(Iterables.getOnlyElement(delegatedResolver.values())).start(mockListener); - // check bootstrap - Map bootstrap = fakeBootstrapSetter.bootstrapRef.get(); - Map node = (Map) bootstrap.get("node"); - assertThat(node).containsExactly( - "id", "C2P-991614323", - "locality", ImmutableMap.of("zone", ZONE), - "metadata", ImmutableMap.of("TRAFFICDIRECTOR_DIRECTPATH_C2P_IPV6_CAPABLE", true)); - Map server = Iterables.getOnlyElement( - (List>) bootstrap.get("xds_servers")); - assertThat(server).containsExactly( - "server_uri", "directpath-pa.googleapis.com", - "channel_creds", ImmutableList.of(ImmutableMap.of("type", "google_default")), - "server_features", ImmutableList.of("xds_v3", "ignore_resource_deletion")); - Map authorities = (Map) bootstrap.get("authorities"); - assertThat(authorities).containsExactly( - "traffic-director-c2p.xds.googleapis.com", - ImmutableMap.of("xds_servers", ImmutableList.of(server))); + + if (enableRfc3986UrisParam) { + Uri delegatedRfcUriValue = delegatedRfcUri.get("xds"); + assertThat(delegatedRfcUriValue).isNotNull(); + assertThat(delegatedRfcUriValue.getRawQuery()).isEqualTo("foo=bar%26baz"); + } else { + URI delegatedUriValue = delegatedUri.get("xds"); + assertThat(delegatedUriValue).isNotNull(); + assertThat(delegatedUriValue.getRawQuery()).isEqualTo("foo=bar%26baz"); + } } - @SuppressWarnings("unchecked") @Test - public void onGcpAndProvidedBootstrapAndFederationEnabled_DontDelegateToXds() { - GoogleCloudToProdNameResolver.isOnGcp = true; - GoogleCloudToProdNameResolver.xdsBootstrapProvided = true; - GoogleCloudToProdNameResolver.enableFederation = true; - createResolver(); + public void notOnGcpButForceXds_CaseSensitive_DelegateToDns() { + GoogleCloudToProdNameResolver.isOnGcp = false; + String target = TARGET_URI + "?FORCE-XDS"; + resolver = enableRfc3986UrisParam + ? new GoogleCloudToProdNameResolver( + Uri.create(target), args, fakeExecutorResource, nsRegistry.asFactory()) + : new GoogleCloudToProdNameResolver( + URI.create(target), args, fakeExecutorResource, nsRegistry.asFactory()); resolver.start(mockListener); - fakeExecutor.runDueTasks(); - assertThat(delegatedResolver.keySet()).containsExactly("xds"); - verify(Iterables.getOnlyElement(delegatedResolver.values())).start(mockListener); - // Bootstrapper should not have been set, since there was no user provided config. - assertThat(fakeBootstrapSetter.bootstrapRef.get()).isNull(); + assertThat(delegatedResolver.keySet()).containsExactly("dns"); + + if (enableRfc3986UrisParam) { + Uri delegatedRfcUriValue = delegatedRfcUri.get("dns"); + assertThat(delegatedRfcUriValue).isNotNull(); + assertThat(delegatedRfcUriValue.getRawQuery()).isEqualTo("FORCE-XDS"); + } else { + URI delegatedUriValue = delegatedUri.get("dns"); + assertThat(delegatedUriValue).isNotNull(); + assertThat(delegatedUriValue.getQuery()).isEqualTo("FORCE-XDS"); + } } @Test public void failToQueryMetadata() { GoogleCloudToProdNameResolver.isOnGcp = true; - GoogleCloudToProdNameResolver.xdsBootstrapProvided = false; createResolver(); HttpConnectionProvider httpConnections = new HttpConnectionProvider() { @Override @@ -304,7 +319,7 @@ public HttpURLConnection createConnection(String url) throws IOException { return con; } }; - resolver.setHttpConnectionProvider(httpConnections); + GoogleCloudToProdNameResolver.setHttpConnectionProvider(httpConnections); resolver.start(mockListener); fakeExecutor.runDueTasks(); verify(mockListener).onError(errorCaptor.capture()); @@ -322,6 +337,18 @@ private FakeNsProvider(String scheme) { @Override public NameResolver newNameResolver(URI targetUri, Args args) { if (scheme.equals(targetUri.getScheme())) { + delegatedUri.put(scheme, targetUri); + NameResolver resolver = mock(NameResolver.class); + delegatedResolver.put(scheme, resolver); + return resolver; + } + return null; + } + + @Override + public NameResolver newNameResolver(Uri targetUri, Args args) { + if (scheme.equals(targetUri.getScheme())) { + delegatedRfcUri.put(scheme, targetUri); NameResolver resolver = mock(NameResolver.class); delegatedResolver.put(scheme, resolver); return resolver; @@ -344,14 +371,4 @@ public String getDefaultScheme() { return scheme; } } - - private static final class FakeBootstrapSetter - implements GoogleCloudToProdNameResolver.BootstrapSetter { - private final AtomicReference> bootstrapRef = new AtomicReference<>(); - - @Override - public void setBootstrap(Map bootstrap) { - bootstrapRef.set(bootstrap); - } - } } diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 282f012e77a..a951303d209 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -1,99 +1,153 @@ [versions] -googleauth = "1.22.0" -netty = '4.1.100.Final' -# Keep the following references of tcnative version in sync whenever it's updated: -# SECURITY.md -nettytcnative = '2.0.61.Final' opencensus = "0.31.1" -protobuf = "3.25.1" [libraries] android-annotations = "com.google.android:annotations:4.1.1.4" -androidx-annotation = "androidx.annotation:annotation:1.7.0" -androidx-core = "androidx.core:core:1.12.0" -androidx-lifecycle-common = "androidx.lifecycle:lifecycle-common:2.6.2" -androidx-lifecycle-service = "androidx.lifecycle:lifecycle-service:2.6.2" -androidx-test-core = "androidx.test:core:1.5.0" -androidx-test-ext-junit = "androidx.test.ext:junit:1.1.5" -androidx-test-rules = "androidx.test:rules:1.5.0" -animalsniffer = "org.codehaus.mojo:animal-sniffer:1.23" -animalsniffer-annotations = "org.codehaus.mojo:animal-sniffer-annotations:1.23" -auto-value = "com.google.auto.value:auto-value:1.10.4" -auto-value-annotations = "com.google.auto.value:auto-value-annotations:1.10.4" -checkstyle = "com.puppycrawl.tools:checkstyle:10.12.5" +# 1.9.1+ uses Kotlin and requires Android Gradle Plugin 9+ +# checkForUpdates: androidx-annotation:1.9.0 +androidx-annotation = "androidx.annotation:annotation:1.9.0" +# 1.14.x doesn't exist. +# 1.15.0+ requires compileSdkVersion 35 which officially requires AGP 8.6.0+. +# It might work before then, but AGP 7.4.1 fails with: +# RES_TABLE_TYPE_TYPE entry offsets overlap actual entry data. +# 1.16.0+ requires AGP 8.6.0+ +# checkForUpdates: androidx-core:1.13.+ +androidx-core = "androidx.core:core:1.13.1" +# 2.9+ requires AGP 8.1.1+ +# checkForUpdates: androidx-lifecycle-common:2.8.+ +androidx-lifecycle-common = "androidx.lifecycle:lifecycle-common:2.8.7" +# checkForUpdates: androidx-lifecycle-service:2.8.+ +androidx-lifecycle-service = "androidx.lifecycle:lifecycle-service:2.8.7" +androidx-test-core = "androidx.test:core:1.7.0" +androidx-test-ext-junit = "androidx.test.ext:junit:1.3.0" +androidx-test-rules = "androidx.test:rules:1.7.0" +androidx-test-runner = "androidx.test:runner:1.7.0" +animalsniffer = "org.codehaus.mojo:animal-sniffer:1.27" +animalsniffer-annotations = "org.codehaus.mojo:animal-sniffer-annotations:1.27" +assertj-core = "org.assertj:assertj-core:3.27.7" +# 1.11.1 started converting jsr305 @Nullable to jspecify +# checkForUpdates: auto-value:1.11.0 +auto-value = "com.google.auto.value:auto-value:1.11.0" +# checkForUpdates: auto-value-annotations:1.11.0 +auto-value-annotations = "com.google.auto.value:auto-value-annotations:1.11.0" +# 11.0+ requires Java 17+ +# https://checkstyle.sourceforge.io/releasenotes.html +# checkForUpdates: checkstyle:10.+ +checkstyle = "com.puppycrawl.tools:checkstyle:10.26.1" +# checkstyle 10.0+ requires Java 11+ +# See https://checkstyle.sourceforge.io/releasenotes_old_8-35_10-26.html#Release_10.0 +# checkForUpdates: checkstylejava8:9.+ +checkstylejava8 = "com.puppycrawl.tools:checkstyle:9.3" commons-math3 = "org.apache.commons:commons-math3:3.6.1" conscrypt = "org.conscrypt:conscrypt-openjdk-uber:2.5.2" -cronet-api = "org.chromium.net:cronet-api:108.5359.79" -cronet-embedded = "org.chromium.net:cronet-embedded:108.5359.79" -errorprone-annotations = "com.google.errorprone:error_prone_annotations:2.23.0" -errorprone-core = "com.google.errorprone:error_prone_core:2.23.0" -google-api-protos = "com.google.api.grpc:proto-google-common-protos:2.29.0" -google-auth-credentials = { module = "com.google.auth:google-auth-library-credentials", version.ref = "googleauth" } -google-auth-oauth2Http = { module = "com.google.auth:google-auth-library-oauth2-http", version.ref = "googleauth" } +# 141.7340.3+ requires Java 17+ +# checkForUpdates: cronet-api:119.6045.31 +cronet-api = "org.chromium.net:cronet-api:119.6045.31" +# checkForUpdates: cronet-embedded:119.6045.31 +cronet-embedded = "org.chromium.net:cronet-embedded:119.6045.31" +errorprone-annotations = "com.google.errorprone:error_prone_annotations:2.48.0" +# 2.32.0+ requires Java 17+ +# checkForUpdates: errorprone-core:2.31.+ +errorprone-core = "com.google.errorprone:error_prone_core:2.31.0" +# 2.11.0+ requires JDK 11+ (See https://github.com/google/error-prone/releases/tag/v2.11.0) +# checkForUpdates: errorprone-corejava8:2.10.+ +errorprone-corejava8 = "com.google.errorprone:error_prone_core:2.10.0" +# 2.65.0+ requires protobuf 4.x +# checkForUpdates: google-api-protos:2.64.+ +google-api-protos = "com.google.api.grpc:proto-google-common-protos:2.64.1" +# 1.43.0+ versions of google-auth-library requires protobuf 4.x +# checkForUpdates: google-auth-credentials:1.42.+ +google-auth-credentials = "com.google.auth:google-auth-library-credentials:1.42.1" +# checkForUpdates: google-auth-oauth2Http:1.42.+ +google-auth-oauth2Http = "com.google.auth:google-auth-library-oauth2-http:1.42.1" # Release notes: https://cloud.google.com/logging/docs/release-notes -google-cloud-logging = "com.google.cloud:google-cloud-logging:3.15.14" -gson = "com.google.code.gson:gson:2.10.1" -guava = "com.google.guava:guava:32.1.3-android" +# 3.23.11+ require protobuf 4.x +# checkForUpdates: google-cloud-logging:3.23.10 +google-cloud-logging = "com.google.cloud:google-cloud-logging:3.23.10" +gson = "com.google.code.gson:gson:2.13.2" +guava = "com.google.guava:guava:33.5.0-android" guava-betaChecker = "com.google.guava:guava-beta-checker:1.0" -guava-testlib = "com.google.guava:guava-testlib:32.1.3-android" +guava-testlib = "com.google.guava:guava-testlib:33.5.0-android" # JRE version is needed for projects where its a transitive dependency, f.e. gcp-observability. # May be different from the -android version. -guava-jre = "com.google.guava:guava:32.1.3-jre" -hdrhistogram = "org.hdrhistogram:HdrHistogram:2.1.12" -j2objc-annotations = " com.google.j2objc:j2objc-annotations:2.8" +guava-jre = "com.google.guava:guava:33.5.0-jre" +hdrhistogram = "org.hdrhistogram:HdrHistogram:2.2.2" +# 6.0.0+ use java.lang.Deprecated forRemoval and since from Java 9 +# checkForUpdates: jakarta-servlet-api:5.+ jakarta-servlet-api = "jakarta.servlet:jakarta.servlet-api:5.0.0" -javax-annotation = "org.apache.tomcat:annotations-api:6.0.53" javax-servlet-api = "javax.servlet:javax.servlet-api:4.0.1" -jetty-client = "org.eclipse.jetty:jetty-client:10.0.7" -jetty-http2-server = "org.eclipse.jetty.http2:http2-server:11.0.7" -jetty-http2-server10 = "org.eclipse.jetty.http2:http2-server:10.0.7" -jetty-servlet = "org.eclipse.jetty:jetty-servlet:11.0.7" -jetty-servlet10 = "org.eclipse.jetty:jetty-servlet:10.0.7" +# 12.0.0+ require Java 17+ +# checkForUpdates: jetty-client:11.+ +jetty-client = "org.eclipse.jetty:jetty-client:11.0.26" +jetty-http2-server = "org.eclipse.jetty.http2:jetty-http2-server:12.1.7" +# 10.0.25+ uses uses @Deprecated(since=/forRemoval=) from Java 9 +# checkForUpdates: jetty-http2-server10:10.0.24 +jetty-http2-server10 = "org.eclipse.jetty.http2:http2-server:10.0.24" +jetty-servlet = "org.eclipse.jetty.ee10:jetty-ee10-servlet:12.1.7" +# checkForUpdates: jetty-servlet10:10.0.24 +jetty-servlet10 = "org.eclipse.jetty:jetty-servlet:10.0.24" jsr305 = "com.google.code.findbugs:jsr305:3.0.2" junit = "junit:junit:4.13.2" -lincheck = "org.jetbrains.kotlinx:lincheck:2.14.1" +lincheck = "org.jetbrains.lincheck:lincheck:3.4" # Update notes / 2023-07-19 sergiitk: # Couldn't update to 5.4.0, updated to the last in 4.x line. Version 5.x breaks some tests. # Error log: https://github.com/grpc/grpc-java/pull/10359#issuecomment-1632834435 # Update notes / 2023-10-09 temawi: -# 4.11.0 Has been breaking the android integration tests as mockito now uses streams +# 4.5.0 Has been breaking the android integration tests as mockito now uses streams # (not available in API levels < 24). https://github.com/grpc/grpc-java/issues/10457 +# checkForUpdates: mockito-android:4.4.+ mockito-android = "org.mockito:mockito-android:4.4.0" +# checkForUpdates: mockito-core:4.4.+ mockito-core = "org.mockito:mockito-core:4.4.0" -netty-codec-http2 = { module = "io.netty:netty-codec-http2", version.ref = "netty" } -netty-handler-proxy = { module = "io.netty:netty-handler-proxy", version.ref = "netty" } -netty-tcnative = { module = "io.netty:netty-tcnative-boringssl-static", version.ref = "nettytcnative" } -netty-tcnative-classes = { module = "io.netty:netty-tcnative-classes", version.ref = "nettytcnative" } -netty-transport-epoll = { module = "io.netty:netty-transport-native-epoll", version.ref = "netty" } -netty-unix-common = { module = "io.netty:netty-transport-native-unix-common", version.ref = "netty" } +# Need to decide when we require users to absorb the breaking changes in 4.2 +# checkForUpdates: netty-codec-http2:4.1.+ +netty-codec-http2 = "io.netty:netty-codec-http2:4.1.133.Final" +# checkForUpdates: netty-handler-proxy:4.1.+ +netty-handler-proxy = "io.netty:netty-handler-proxy:4.1.133.Final" +# Keep the following references of tcnative version in sync whenever it's updated: +# SECURITY.md +netty-tcnative = "io.netty:netty-tcnative-boringssl-static:2.0.75.Final" +netty-tcnative-classes = "io.netty:netty-tcnative-classes:2.0.75.Final" +# checkForUpdates: netty-transport-epoll:4.1.+ +netty-transport-epoll = "io.netty:netty-transport-native-epoll:4.1.133.Final" +# checkForUpdates: netty-unix-common:4.1.+ +netty-unix-common = "io.netty:netty-transport-native-unix-common:4.1.133.Final" okhttp = "com.squareup.okhttp:okhttp:2.7.5" # okio 3.5+ uses Kotlin 1.9+ which requires Android Gradle Plugin 9+ +# checkForUpdates: okio:3.4.+ okio = "com.squareup.okio:okio:3.4.0" opencensus-api = { module = "io.opencensus:opencensus-api", version.ref = "opencensus" } opencensus-contrib-grpc-metrics = { module = "io.opencensus:opencensus-contrib-grpc-metrics", version.ref = "opencensus" } opencensus-exporter-stats-stackdriver = { module = "io.opencensus:opencensus-exporter-stats-stackdriver", version.ref = "opencensus" } opencensus-exporter-trace-stackdriver = { module = "io.opencensus:opencensus-exporter-trace-stackdriver", version.ref = "opencensus" } opencensus-impl = { module = "io.opencensus:opencensus-impl", version.ref = "opencensus" } -opencensus-proto = "io.opencensus:opencensus-proto:0.2.0" -opentelemetry-api = "io.opentelemetry:opentelemetry-api:1.32.0" -opentelemetry-sdk-testing = "io.opentelemetry:opentelemetry-sdk-testing:1.32.0" -perfmark-api = "io.perfmark:perfmark-api:0.26.0" -protobuf-java = { module = "com.google.protobuf:protobuf-java", version.ref = "protobuf" } -protobuf-java-util = { module = "com.google.protobuf:protobuf-java-util", version.ref = "protobuf" } -protobuf-javalite = { module = "com.google.protobuf:protobuf-javalite", version.ref = "protobuf" } -protobuf-protoc = { module = "com.google.protobuf:protoc", version.ref = "protobuf" } -re2j = "com.google.re2j:re2j:1.7" -robolectric = "org.robolectric:robolectric:4.11.1" -signature-android = "net.sf.androidscents.signature:android-api-level-19:4.4.2_r4" +opentelemetry-api = "io.opentelemetry:opentelemetry-api:1.60.1" +opentelemetry-exporter-prometheus = "io.opentelemetry:opentelemetry-exporter-prometheus:1.60.1-alpha" +opentelemetry-gcp-resources = "io.opentelemetry.contrib:opentelemetry-gcp-resources:1.54.0-alpha" +opentelemetry-sdk-extension-autoconfigure = "io.opentelemetry:opentelemetry-sdk-extension-autoconfigure:1.60.1" +opentelemetry-sdk-testing = "io.opentelemetry:opentelemetry-sdk-testing:1.60.1" +perfmark-api = "io.perfmark:perfmark-api:0.27.0" +# Not upgrading to 4.x as it is not yet ABI compatible. +# https://github.com/protocolbuffers/protobuf/issues/17247 +# checkForUpdates: protobuf-java:3.+ +protobuf-java = "com.google.protobuf:protobuf-java:3.25.8" +# checkForUpdates: protobuf-java-util:3.+ +protobuf-java-util = "com.google.protobuf:protobuf-java-util:3.25.8" +# checkForUpdates: protobuf-javalite:3.+ +protobuf-javalite = "com.google.protobuf:protobuf-javalite:3.25.8" +# checkForUpdates: protobuf-protoc:3.+ +protobuf-protoc = "com.google.protobuf:protoc:3.25.8" +re2j = "com.google.re2j:re2j:1.8" +robolectric = "org.robolectric:robolectric:4.16.1" +s2a-proto = "com.google.s2a.proto.v2:s2a-proto:0.1.3" +signature-android = "net.sf.androidscents.signature:android-api-level-21:5.0.1_r2" signature-java = "org.codehaus.mojo.signature:java18:1.0" -tomcat-embed-core = "org.apache.tomcat.embed:tomcat-embed-core:10.0.14" -tomcat-embed-core9 = "org.apache.tomcat.embed:tomcat-embed-core:9.0.56" -truth = "com.google.truth:truth:1.1.5" -undertow-servlet = "io.undertow:undertow-servlet:2.2.14.Final" -undertow-servlet-jakartaee9 = "io.undertow:undertow-servlet-jakartaee9:2.2.13.Final" - -# Do not update: Pinned to the last version supporting Java 8. -# See https://checkstyle.sourceforge.io/releasenotes.html#Release_10.1 -checkstylejava8 = "com.puppycrawl.tools:checkstyle:9.3" -# See https://github.com/google/error-prone/releases/tag/v2.11.0 -errorprone-corejava8 = "com.google.errorprone:error_prone_core:2.10.0" +# 11.0.0+ require Java 17+ +# checkForUpdates: tomcat-embed-core:10.+ +tomcat-embed-core = "org.apache.tomcat.embed:tomcat-embed-core:10.1.52" +# checkForUpdates: tomcat-embed-core9:9.+ +tomcat-embed-core9 = "org.apache.tomcat.embed:tomcat-embed-core:9.0.115" +truth = "com.google.truth:truth:1.4.5" +# checkForUpdates: undertow-servlet22:2.2.+ +undertow-servlet22 = "io.undertow:undertow-servlet:2.2.38.Final" +undertow-servlet = "io.undertow:undertow-servlet:2.3.20.Final" diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 1af9e0930b8..d4081da476b 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,6 +1,6 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-8.5-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-8.14.3-bin.zip networkTimeout=10000 validateDistributionUrl=true zipStoreBase=GRADLE_USER_HOME diff --git a/grpclb/BUILD.bazel b/grpclb/BUILD.bazel index e82d8022bd2..ca9975b7ce6 100644 --- a/grpclb/BUILD.bazel +++ b/grpclb/BUILD.bazel @@ -1,3 +1,5 @@ +load("@rules_java//java:defs.bzl", "java_library") +load("@rules_jvm_external//:defs.bzl", "artifact") load("//:java_grpc_library.bzl", "java_grpc_library") java_library( @@ -14,13 +16,13 @@ java_library( "//api", "//context", "//core:internal", - "//util", "//stub", - "@com_google_code_findbugs_jsr305//jar", - "@com_google_guava_guava//jar", - "@com_google_j2objc_j2objc_annotations//jar", + "//util", "@com_google_protobuf//:protobuf_java_util", "@io_grpc_grpc_proto//:grpclb_load_balancer_java_proto", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:guava"), ], ) diff --git a/grpclb/build.gradle b/grpclb/build.gradle index cea599828f5..e8896604f03 100644 --- a/grpclb/build.gradle +++ b/grpclb/build.gradle @@ -19,16 +19,20 @@ dependencies { implementation project(':grpc-core'), project(':grpc-protobuf'), project(':grpc-stub'), + project(':grpc-util'), + libraries.guava, libraries.protobuf.java, - libraries.protobuf.java.util, - libraries.guava + libraries.protobuf.java.util runtimeOnly libraries.errorprone.annotations - compileOnly libraries.javax.annotation testImplementation libraries.truth, project(':grpc-inprocess'), testFixtures(project(':grpc-core')) - signature libraries.signature.java + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } } configureProtoCompilation() diff --git a/grpclb/src/generated/main/grpc/io/grpc/lb/v1/LoadBalancerGrpc.java b/grpclb/src/generated/main/grpc/io/grpc/lb/v1/LoadBalancerGrpc.java index c96c5400aac..b730eff7b37 100644 --- a/grpclb/src/generated/main/grpc/io/grpc/lb/v1/LoadBalancerGrpc.java +++ b/grpclb/src/generated/main/grpc/io/grpc/lb/v1/LoadBalancerGrpc.java @@ -4,9 +4,6 @@ /** */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/lb/v1/load_balancer.proto") @io.grpc.stub.annotations.GrpcGenerated public final class LoadBalancerGrpc { @@ -60,6 +57,21 @@ public LoadBalancerStub newStub(io.grpc.Channel channel, io.grpc.CallOptions cal return LoadBalancerStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static LoadBalancerBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public LoadBalancerBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new LoadBalancerBlockingV2Stub(channel, callOptions); + } + }; + return LoadBalancerBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -147,6 +159,35 @@ public io.grpc.stub.StreamObserver balanceLoad /** * A stub to allow clients to do synchronous rpc calls to service LoadBalancer. */ + public static final class LoadBalancerBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private LoadBalancerBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected LoadBalancerBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new LoadBalancerBlockingV2Stub(channel, callOptions); + } + + /** + *

+     * Bidirectional rpc to get a list of servers.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + balanceLoad() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getBalanceLoadMethod(), getCallOptions()); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service LoadBalancer. + */ public static final class LoadBalancerBlockingStub extends io.grpc.stub.AbstractBlockingStub { private LoadBalancerBlockingStub( diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbClientLoadRecorder.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbClientLoadRecorder.java index d27c485dc13..fe928263ef9 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbClientLoadRecorder.java +++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbClientLoadRecorder.java @@ -18,6 +18,7 @@ import static com.google.common.base.Preconditions.checkNotNull; +import com.google.errorprone.annotations.concurrent.GuardedBy; import com.google.protobuf.util.Timestamps; import io.grpc.ClientStreamTracer; import io.grpc.Metadata; @@ -29,7 +30,6 @@ import java.util.HashMap; import java.util.Map; import java.util.concurrent.atomic.AtomicLongFieldUpdater; -import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; /** diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbConfig.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbConfig.java index 4395c8415dc..1476e3e2f83 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbConfig.java +++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbConfig.java @@ -55,7 +55,7 @@ long getFallbackTimeoutMs() { } /** - * If specified, it overrides the name of the sevice name to be sent to the balancer. if not, the + * If specified, it overrides the name of the service name to be sent to the balancer. if not, the * target to be sent to the balancer will continue to be obtained from the target URI passed * to the gRPC client channel. */ diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java index 872937b03c1..bf9eea69af0 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java +++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java @@ -154,6 +154,8 @@ public void handleNameResolutionError(Status error) { } @Override + @Deprecated + @SuppressWarnings("InlineMeSuggester") public boolean canHandleEmptyAddressListFromNameResolution() { return true; } diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbNameResolver.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbNameResolver.java index d17587fb14d..60d02220e64 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbNameResolver.java +++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbNameResolver.java @@ -21,6 +21,7 @@ import io.grpc.Attributes; import io.grpc.EquivalentAddressGroup; import io.grpc.NameResolver; +import io.grpc.StatusOr; import io.grpc.internal.DnsNameResolver; import io.grpc.internal.SharedResourceHolder.Resource; import java.net.InetAddress; @@ -58,14 +59,22 @@ final class GrpclbNameResolver extends DnsNameResolver { } @Override - protected InternalResolutionResult doResolve(boolean forceTxt) { + protected ResolutionResult doResolve() { + ResolutionResult result = super.doResolve(); List balancerAddrs = resolveBalancerAddresses(); - InternalResolutionResult result = super.doResolve(!balancerAddrs.isEmpty()); if (!balancerAddrs.isEmpty()) { - result.attributes = - Attributes.newBuilder() + ResolutionResult.Builder resultBuilder = result.toBuilder() + .setAttributes(result.getAttributes().toBuilder() .set(GrpclbConstants.ATTR_LB_ADDRS, balancerAddrs) - .build(); + .build()); + if (!result.getAddressesOrError().hasValue()) { + // While ResolutionResult is powerful enough to communicate attributes simultaneously with + // an address resolution failure, LoadBalancer.ResolvedAddresses isn't yet and so the + // attributes are lost if addresses fail. GrpclbLB will be able to handle the lack of + // addresses since there are LB addresses, so discard the failure for now. + resultBuilder.setAddressesOrError(StatusOr.fromValue(Collections.emptyList())); + } + result = resultBuilder.build(); } return result; } diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java index 49b74645ec8..5ed84ade2f8 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java +++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java @@ -37,13 +37,16 @@ import io.grpc.ConnectivityStateInfo; import io.grpc.Context; import io.grpc.EquivalentAddressGroup; -import io.grpc.LoadBalancer.CreateSubchannelArgs; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.FixedResultPicker; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; +import io.grpc.LoadBalancer.ResolvedAddresses; import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; -import io.grpc.LoadBalancer.SubchannelStateListener; +import io.grpc.LoadBalancerProvider; +import io.grpc.LoadBalancerRegistry; import io.grpc.ManagedChannel; import io.grpc.Metadata; import io.grpc.Status; @@ -62,6 +65,7 @@ import io.grpc.lb.v1.Server; import io.grpc.lb.v1.ServerList; import io.grpc.stub.StreamObserver; +import io.grpc.util.ForwardingLoadBalancerHelper; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.UnknownHostException; @@ -119,7 +123,7 @@ final class GrpclbState { @VisibleForTesting static final RoundRobinEntry BUFFER_ENTRY = new RoundRobinEntry() { @Override - public PickResult picked(Metadata headers) { + public PickResult picked(PickSubchannelArgs args) { return PickResult.withNoResult(); } @@ -183,10 +187,20 @@ enum Mode { private List dropList = Collections.emptyList(); // Contains only non-drop, i.e., backends from the round-robin list from the balancer. private List backendList = Collections.emptyList(); + private ConnectivityState currentState = ConnectivityState.CONNECTING; private RoundRobinPicker currentPicker = new RoundRobinPicker(Collections.emptyList(), Arrays.asList(BUFFER_ENTRY)); private boolean requestConnectionPending; + // Child LoadBalancer and state for PICK_FIRST mode delegation. + private final LoadBalancerProvider pickFirstLbProvider; + @Nullable + private LoadBalancer pickFirstLb; + private ConnectivityState pickFirstLbState = CONNECTING; + private SubchannelPicker pickFirstLbPicker = new FixedResultPicker(PickResult.withNoResult()); + @Nullable + private GrpclbClientLoadRecorder currentPickFirstLoadRecorder; + GrpclbState( GrpclbConfig config, Helper helper, @@ -212,6 +226,9 @@ public void onSubchannelState( } else { this.subchannelPool = null; } + this.pickFirstLbProvider = checkNotNull( + LoadBalancerRegistry.getDefaultRegistry().getProvider("pick_first"), + "pick_first balancer not available"); this.time = checkNotNull(time, "time provider"); this.stopwatch = checkNotNull(stopwatch, "stopwatch"); this.timerService = checkNotNull(helper.getScheduledExecutorService(), "timerService"); @@ -309,6 +326,12 @@ void handleAddresses( void requestConnection() { requestConnectionPending = true; + // For PICK_FIRST mode with delegation, forward to the child LB. + if (config.getMode() == Mode.PICK_FIRST && pickFirstLb != null) { + pickFirstLb.requestConnection(); + requestConnectionPending = false; + return; + } for (RoundRobinEntry entry : currentPicker.pickList) { if (entry instanceof IdleSubchannelEntry) { ((IdleSubchannelEntry) entry).subchannel.requestConnection(); @@ -323,15 +346,23 @@ private void maybeUseFallbackBackends() { } // Balancer RPC should have either been broken or timed out. checkState(fallbackReason != null, "no reason to fallback"); - for (Subchannel subchannel : subchannels.values()) { - ConnectivityStateInfo stateInfo = subchannel.getAttributes().get(STATE_INFO).get(); - if (stateInfo.getState() == READY) { + // For PICK_FIRST mode with delegation, check the child LB's state. + if (config.getMode() == Mode.PICK_FIRST) { + if (pickFirstLb != null && pickFirstLbState == READY) { return; } - // If we do have balancer-provided backends, use one of its error in the error message if - // fail to fallback. - if (stateInfo.getState() == TRANSIENT_FAILURE) { - fallbackReason = stateInfo.getStatus(); + // For PICK_FIRST, we don't have individual subchannel states to use as fallback reason. + } else { + for (Subchannel subchannel : subchannels.values()) { + ConnectivityStateInfo stateInfo = subchannel.getAttributes().get(STATE_INFO).get(); + if (stateInfo.getState() == READY) { + return; + } + // If we do have balancer-provided backends, use one of its error in the error message if + // fail to fallback. + if (stateInfo.getState() == TRANSIENT_FAILURE) { + fallbackReason = stateInfo.getStatus(); + } } } // Fallback conditions met @@ -355,11 +386,12 @@ private void useFallbackBackends() { } private void shutdownLbComm() { + shutdownLbRpc(); if (lbCommChannel != null) { - lbCommChannel.shutdown(); + // The channel should have no RPCs at this point + lbCommChannel.shutdownNow(); lbCommChannel = null; } - shutdownLbRpc(); } private void shutdownLbRpc() { @@ -438,9 +470,10 @@ void shutdown() { subchannelPool.clear(); break; case PICK_FIRST: - if (!subchannels.isEmpty()) { - checkState(subchannels.size() == 1, "Excessive Subchannels: %s", subchannels); - subchannels.values().iterator().next().shutdown(); + // Shutdown the child pick_first LB which manages its own subchannels. + if (pickFirstLb != null) { + pickFirstLb.shutdown(); + pickFirstLb = null; } break; default: @@ -517,22 +550,17 @@ private void updateServerList( subchannels = Collections.unmodifiableMap(newSubchannelMap); break; case PICK_FIRST: - checkState(subchannels.size() <= 1, "Unexpected Subchannel count: %s", subchannels); - final Subchannel subchannel; + // Delegate to child pick_first LB for address management. + // Shutdown existing child LB if addresses become empty. if (newBackendAddrList.isEmpty()) { - if (subchannels.size() == 1) { - subchannel = subchannels.values().iterator().next(); - subchannel.shutdown(); - subchannels = Collections.emptyMap(); + if (pickFirstLb != null) { + pickFirstLb.shutdown(); + pickFirstLb = null; } break; } List eagList = new ArrayList<>(); - // Because for PICK_FIRST, we create a single Subchannel for all addresses, we have to - // attach the tokens to the EAG attributes and use TokenAttachingLoadRecorder to put them on - // headers. - // - // The PICK_FIRST code path doesn't cache Subchannels. + // Attach tokens to EAG attributes for TokenAttachingTracerFactory to retrieve. for (BackendAddressGroup bag : newBackendAddrList) { EquivalentAddressGroup origEag = bag.getAddresses(); Attributes eagAttrs = origEag.getAttributes(); @@ -542,30 +570,22 @@ private void updateServerList( } eagList.add(new EquivalentAddressGroup(origEag.getAddresses(), eagAttrs)); } - if (subchannels.isEmpty()) { - subchannel = - helper.createSubchannel( - CreateSubchannelArgs.newBuilder() - .setAddresses(eagList) - .setAttributes(createSubchannelAttrs()) - .build()); - subchannel.start(new SubchannelStateListener() { - @Override - public void onSubchannelState(ConnectivityStateInfo newState) { - handleSubchannelState(subchannel, newState); - } - }); - if (requestConnectionPending) { - subchannel.requestConnection(); - requestConnectionPending = false; - } - } else { - subchannel = subchannels.values().iterator().next(); - subchannel.updateAddresses(eagList); + + if (pickFirstLb == null) { + pickFirstLb = pickFirstLbProvider.newLoadBalancer(new PickFirstLbHelper()); + } + + // Pass addresses to child LB. + pickFirstLb.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(eagList) + .build()); + if (requestConnectionPending) { + pickFirstLb.requestConnection(); + requestConnectionPending = false; } - subchannels = Collections.singletonMap(eagList, subchannel); - newBackendList.add( - new BackendEntry(subchannel, new TokenAttachingTracerFactory(loadRecorder))); + // Store the load recorder for token attachment. + currentPickFirstLoadRecorder = loadRecorder; break; default: throw new AssertionError("Missing case for " + config.getMode()); @@ -842,7 +862,11 @@ private void cleanUp() { private void maybeUpdatePicker() { List pickList; ConnectivityState state; - if (backendList.isEmpty()) { + // For PICK_FIRST mode with delegation, check if child LB exists instead of backendList. + boolean hasBackends = config.getMode() == Mode.PICK_FIRST + ? pickFirstLb != null + : !backendList.isEmpty(); + if (!hasBackends) { // Note balancer (is working) may enforce using fallback backends, and that fallback may // fail. So we should check if currently in fallback first. if (usingFallbackBackends) { @@ -894,26 +918,12 @@ private void maybeUpdatePicker() { } break; case PICK_FIRST: { - checkState(backendList.size() == 1, "Excessive backend entries: %s", backendList); - BackendEntry onlyEntry = backendList.get(0); - ConnectivityStateInfo stateInfo = - onlyEntry.subchannel.getAttributes().get(STATE_INFO).get(); - state = stateInfo.getState(); - switch (state) { - case READY: - pickList = Collections.singletonList(onlyEntry); - break; - case TRANSIENT_FAILURE: - pickList = - Collections.singletonList(new ErrorEntry(stateInfo.getStatus())); - break; - case CONNECTING: - pickList = Collections.singletonList(BUFFER_ENTRY); - break; - default: - pickList = Collections.singletonList( - new IdleSubchannelEntry(onlyEntry.subchannel, syncContext)); - } + // Use child LB's state and picker. Wrap the picker for token attachment. + state = pickFirstLbState; + TokenAttachingTracerFactory tracerFactory = + new TokenAttachingTracerFactory(currentPickFirstLoadRecorder); + pickList = Collections.singletonList( + new ChildLbPickerEntry(pickFirstLbPicker, tracerFactory)); break; } default: @@ -929,10 +939,12 @@ private void maybeUpdatePicker(ConnectivityState state, RoundRobinPicker picker) // Discard the new picker if we are sure it won't make any difference, in order to save // re-processing pending streams, and avoid unnecessary resetting of the pointer in // RoundRobinPicker. - if (picker.dropList.equals(currentPicker.dropList) + if (state.equals(currentState) + && picker.dropList.equals(currentPicker.dropList) && picker.pickList.equals(currentPicker.pickList)) { return; } + currentState = state; currentPicker = picker; helper.updateBalancingState(state, picker); } @@ -983,7 +995,7 @@ public boolean equals(Object other) { @VisibleForTesting interface RoundRobinEntry { - PickResult picked(Metadata headers); + PickResult picked(PickSubchannelArgs args); } @VisibleForTesting @@ -1024,7 +1036,8 @@ static final class BackendEntry implements RoundRobinEntry { } @Override - public PickResult picked(Metadata headers) { + public PickResult picked(PickSubchannelArgs args) { + Metadata headers = args.getHeaders(); headers.discardAll(GrpclbConstants.TOKEN_METADATA_KEY); if (token != null) { headers.put(GrpclbConstants.TOKEN_METADATA_KEY, token); @@ -1065,7 +1078,7 @@ static final class IdleSubchannelEntry implements RoundRobinEntry { } @Override - public PickResult picked(Metadata headers) { + public PickResult picked(PickSubchannelArgs args) { if (connectionRequested.compareAndSet(false, true)) { syncContext.execute(new Runnable() { @Override @@ -1108,7 +1121,7 @@ static final class ErrorEntry implements RoundRobinEntry { } @Override - public PickResult picked(Metadata headers) { + public PickResult picked(PickSubchannelArgs args) { return result; } @@ -1132,6 +1145,58 @@ public String toString() { } } + /** + * Entry that wraps a child LB's picker for PICK_FIRST mode delegation. + * Attaches TokenAttachingTracerFactory to the pick result for token propagation. + */ + @VisibleForTesting + static final class ChildLbPickerEntry implements RoundRobinEntry { + private final SubchannelPicker childPicker; + private final TokenAttachingTracerFactory tracerFactory; + + ChildLbPickerEntry(SubchannelPicker childPicker, TokenAttachingTracerFactory tracerFactory) { + this.childPicker = checkNotNull(childPicker, "childPicker"); + this.tracerFactory = checkNotNull(tracerFactory, "tracerFactory"); + } + + @Override + public PickResult picked(PickSubchannelArgs args) { + PickResult childResult = childPicker.pickSubchannel(args); + if (childResult.getSubchannel() == null) { + // No subchannel (e.g., buffer, error), return as-is. + return childResult; + } + // Wrap the pick result to attach tokens via the tracer factory. + return PickResult.withSubchannel( + childResult.getSubchannel(), tracerFactory, childResult.getAuthorityOverride()); + } + + @Override + public int hashCode() { + return Objects.hashCode(childPicker, tracerFactory); + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof ChildLbPickerEntry)) { + return false; + } + ChildLbPickerEntry that = (ChildLbPickerEntry) other; + return Objects.equal(childPicker, that.childPicker) + && Objects.equal(tracerFactory, that.tracerFactory); + } + + @Override + public String toString() { + return "ChildLbPickerEntry(" + childPicker + ")"; + } + + @VisibleForTesting + SubchannelPicker getChildPicker() { + return childPicker; + } + } + @VisibleForTesting static final class RoundRobinPicker extends SubchannelPicker { @VisibleForTesting @@ -1174,7 +1239,7 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { if (pickIndex == pickList.size()) { pickIndex = 0; } - return pick.picked(args.getHeaders()); + return pick.picked(args); } } @@ -1189,4 +1254,28 @@ public String toString() { return MoreObjects.toStringHelper(RoundRobinPicker.class).toString(); } } + + /** + * Helper for the child pick_first LB in PICK_FIRST mode. Intercepts updateBalancingState() + * to store state and trigger the grpclb picker update with drops and token attachment. + */ + private final class PickFirstLbHelper extends ForwardingLoadBalancerHelper { + + @Override + protected Helper delegate() { + return helper; + } + + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + pickFirstLbState = newState; + pickFirstLbPicker = newPicker; + // Trigger name resolution refresh on TRANSIENT_FAILURE or IDLE, similar to ROUND_ROBIN. + if (newState == TRANSIENT_FAILURE || newState == IDLE) { + helper.refreshNameResolution(); + } + maybeUseFallbackBackends(); + maybeUpdatePicker(); + } + } } diff --git a/grpclb/src/main/java/io/grpc/grpclb/SecretGrpclbNameResolverProvider.java b/grpclb/src/main/java/io/grpc/grpclb/SecretGrpclbNameResolverProvider.java index 8952ea1d8fb..f394c812b28 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/SecretGrpclbNameResolverProvider.java +++ b/grpclb/src/main/java/io/grpc/grpclb/SecretGrpclbNameResolverProvider.java @@ -19,14 +19,17 @@ import com.google.common.base.Preconditions; import com.google.common.base.Stopwatch; import io.grpc.InternalServiceProviders; +import io.grpc.NameResolver; import io.grpc.NameResolver.Args; import io.grpc.NameResolverProvider; +import io.grpc.Uri; import io.grpc.internal.GrpcUtil; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.net.URI; import java.util.Collection; import java.util.Collections; +import java.util.List; /** * A provider for {@code io.grpc.grpclb.GrpclbNameResolver}. @@ -56,27 +59,47 @@ public static final class Provider extends NameResolverProvider { private static final boolean IS_ANDROID = InternalServiceProviders .isAndroid(SecretGrpclbNameResolverProvider.class.getClassLoader()); + @Override + public NameResolver newNameResolver(Uri targetUri, final NameResolver.Args args) { + if (SCHEME.equals(targetUri.getScheme())) { + List pathSegments = targetUri.getPathSegments(); + Preconditions.checkArgument( + pathSegments.size() == 1, + "expected 1 path segment in target %s but found %s", + targetUri, + pathSegments); + return newNameResolver(targetUri.getAuthority(), pathSegments.get(0), args); + } else { + return null; + } + } + @Override public GrpclbNameResolver newNameResolver(URI targetUri, Args args) { + // TODO(jdcormie): Remove once RFC 3986 migration is complete. if (SCHEME.equals(targetUri.getScheme())) { String targetPath = Preconditions.checkNotNull(targetUri.getPath(), "targetPath"); Preconditions.checkArgument( targetPath.startsWith("/"), "the path component (%s) of the target (%s) must start with '/'", targetPath, targetUri); - String name = targetPath.substring(1); - return new GrpclbNameResolver( - targetUri.getAuthority(), - name, - args, - GrpcUtil.SHARED_CHANNEL_EXECUTOR, - Stopwatch.createUnstarted(), - IS_ANDROID); + return newNameResolver(targetUri.getAuthority(), targetPath.substring(1), args); } else { return null; } } + private GrpclbNameResolver newNameResolver( + String authority, String domainNameToResolve, final NameResolver.Args args) { + return new GrpclbNameResolver( + authority, + domainNameToResolve, + args, + GrpcUtil.SHARED_CHANNEL_EXECUTOR, + Stopwatch.createUnstarted(), + IS_ANDROID); + } + @Override public String getDefaultScheme() { return SCHEME; diff --git a/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java b/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java index e489129676a..ef31b318cb5 100644 --- a/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java +++ b/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java @@ -72,6 +72,7 @@ import io.grpc.Status.Code; import io.grpc.SynchronizationContext; import io.grpc.grpclb.GrpclbState.BackendEntry; +import io.grpc.grpclb.GrpclbState.ChildLbPickerEntry; import io.grpc.grpclb.GrpclbState.DropEntry; import io.grpc.grpclb.GrpclbState.ErrorEntry; import io.grpc.grpclb.GrpclbState.IdleSubchannelEntry; @@ -779,7 +780,9 @@ public void receiveNoBackendAndBalancerAddress() { verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); RoundRobinPicker picker = (RoundRobinPicker) pickerCaptor.getValue(); assertThat(picker.dropList).isEmpty(); - Status error = Iterables.getOnlyElement(picker.pickList).picked(new Metadata()).getStatus(); + PickSubchannelArgs args = mock(PickSubchannelArgs.class); + when(args.getHeaders()).thenReturn(new Metadata()); + Status error = Iterables.getOnlyElement(picker.pickList).picked(args).getStatus(); assertThat(error.getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(error.getDescription()).isEqualTo("No backend or balancer addresses found"); } @@ -1915,6 +1918,7 @@ public void grpclbWorking_pickFirstMode() throws Exception { lbResponseObserver.onNext(buildInitialResponse()); lbResponseObserver.onNext(buildLbResponse(backends1)); + // With delegation, the child pick_first creates the subchannel inOrder.verify(helper).createSubchannel(createSubchannelArgsCaptor.capture()); CreateSubchannelArgs createSubchannelArgs = createSubchannelArgsCaptor.getValue(); assertThat(createSubchannelArgs.getAddresses()) @@ -1922,42 +1926,41 @@ public void grpclbWorking_pickFirstMode() throws Exception { new EquivalentAddressGroup(backends1.get(0).addr, eagAttrsWithToken("token0001")), new EquivalentAddressGroup(backends1.get(1).addr, eagAttrsWithToken("token0002"))); - // Initially IDLE - inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); + // Child pick_first eagerly connects, so we start in CONNECTING + inOrder.verify(helper, atLeast(1)).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); RoundRobinPicker picker0 = (RoundRobinPicker) pickerCaptor.getValue(); - - // Only one subchannel is created + // Only one subchannel is created by the child LB assertThat(mockSubchannels).hasSize(1); Subchannel subchannel = mockSubchannels.poll(); assertThat(picker0.dropList).containsExactly(null, null); - assertThat(picker0.pickList).containsExactly(new IdleSubchannelEntry(subchannel, syncContext)); + assertThat(picker0.pickList).hasSize(1); + assertThat(picker0.pickList.get(0)).isInstanceOf(ChildLbPickerEntry.class); - // PICK_FIRST doesn't eagerly connect - verify(subchannel, never()).requestConnection(); - - // CONNECTING - deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(CONNECTING)); - inOrder.verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); - RoundRobinPicker picker1 = (RoundRobinPicker) pickerCaptor.getValue(); - assertThat(picker1.dropList).containsExactly(null, null); - assertThat(picker1.pickList).containsExactly(BUFFER_ENTRY); + // Child pick_first eagerly calls requestConnection() + verify(subchannel).requestConnection(); // TRANSIENT_FAILURE Status error = Status.UNAVAILABLE.withDescription("Simulated connection error"); deliverSubchannelState(subchannel, ConnectivityStateInfo.forTransientFailure(error)); - inOrder.verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); - RoundRobinPicker picker2 = (RoundRobinPicker) pickerCaptor.getValue(); - assertThat(picker2.dropList).containsExactly(null, null); - assertThat(picker2.pickList).containsExactly(new ErrorEntry(error)); + // The child LB will notify our helper, which updates grpclb state + inOrder.verify(helper, atLeast(1)) + .updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); + RoundRobinPicker picker1 = (RoundRobinPicker) pickerCaptor.getValue(); + assertThat(picker1.dropList).containsExactly(null, null); + ChildLbPickerEntry failureEntry = (ChildLbPickerEntry) picker1.pickList.get(0); + PickResult failureResult = + failureEntry.getChildPicker().pickSubchannel(mock(PickSubchannelArgs.class)); + assertThat(failureResult.getStatus()).isEqualTo(error); // READY deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); - inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); - RoundRobinPicker picker3 = (RoundRobinPicker) pickerCaptor.getValue(); - assertThat(picker3.dropList).containsExactly(null, null); - assertThat(picker3.pickList).containsExactly( - new BackendEntry(subchannel, new TokenAttachingTracerFactory(getLoadRecorder()))); - + inOrder.verify(helper, atLeast(1)).updateBalancingState(eq(READY), pickerCaptor.capture()); + RoundRobinPicker picker2 = (RoundRobinPicker) pickerCaptor.getValue(); + assertThat(picker2.dropList).containsExactly(null, null); + ChildLbPickerEntry readyEntry = (ChildLbPickerEntry) picker2.pickList.get(0); + PickResult readyResult = + readyEntry.getChildPicker().pickSubchannel(mock(PickSubchannelArgs.class)); + assertThat(readyResult.getSubchannel()).isEqualTo(subchannel); // New server list with drops List backends2 = Arrays.asList( @@ -1968,37 +1971,40 @@ public void grpclbWorking_pickFirstMode() throws Exception { .updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class)); lbResponseObserver.onNext(buildLbResponse(backends2)); - // new addresses will be updated to the existing subchannel - // createSubchannel() has ever been called only once + // Verify child LB is updated with new addresses, NOT recreated + inOrder.verify(helper, never()).createSubchannel(any(CreateSubchannelArgs.class)); verify(helper, times(1)).createSubchannel(any(CreateSubchannelArgs.class)); assertThat(mockSubchannels).isEmpty(); + + // The child LB policy internally calls updateAddresses on the subchannel verify(subchannel).updateAddresses( eq(Arrays.asList( new EquivalentAddressGroup(backends2.get(0).addr, eagAttrsWithToken("token0001")), new EquivalentAddressGroup(backends2.get(2).addr, eagAttrsWithToken("token0004"))))); - inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); - RoundRobinPicker picker4 = (RoundRobinPicker) pickerCaptor.getValue(); - assertThat(picker4.dropList).containsExactly( + inOrder.verify(helper, atLeast(1)).updateBalancingState(eq(READY), pickerCaptor.capture()); + RoundRobinPicker picker3 = (RoundRobinPicker) pickerCaptor.getValue(); + assertThat(picker3.dropList).containsExactly( null, new DropEntry(getLoadRecorder(), "token0003"), null); - assertThat(picker4.pickList).containsExactly( - new BackendEntry(subchannel, new TokenAttachingTracerFactory(getLoadRecorder()))); + ChildLbPickerEntry updatedEntry = (ChildLbPickerEntry) picker3.pickList.get(0); + PickResult updatedResult = + updatedEntry.getChildPicker().pickSubchannel(mock(PickSubchannelArgs.class)); + assertThat(updatedResult.getSubchannel()).isEqualTo(subchannel); - // Subchannel goes IDLE, but PICK_FIRST will not try to reconnect + // Subchannel goes IDLE, grpclb state should follow deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(IDLE)); inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); - RoundRobinPicker picker5 = (RoundRobinPicker) pickerCaptor.getValue(); - verify(subchannel, never()).requestConnection(); + RoundRobinPicker picker4 = (RoundRobinPicker) pickerCaptor.getValue(); - // ... until it's selected + // No new connection request should have happened yet (beyond the first eager one) + verify(subchannel, times(1)).requestConnection(); PickSubchannelArgs args = mock(PickSubchannelArgs.class); - PickResult pick = picker5.pickSubchannel(args); - assertThat(pick).isSameInstanceAs(PickResult.withNoResult()); - verify(subchannel).requestConnection(); - - // ... or requested by application - balancer.requestConnection(); + PickResult pick = picker4.pickSubchannel(args); + // Child pick_first picker returns withNoResult() when IDLE and requests connection + assertThat(pick.getSubchannel()).isNull(); verify(subchannel, times(2)).requestConnection(); + balancer.requestConnection(); + verify(subchannel, times(3)).requestConnection(); // PICK_FIRST doesn't use subchannelPool verify(subchannelPool, never()) @@ -2036,6 +2042,7 @@ public void grpclbWorking_pickFirstMode_lbSendsEmptyAddress() throws Exception { lbResponseObserver.onNext(buildInitialResponse()); lbResponseObserver.onNext(buildLbResponse(backends1)); + // The child pick_first creates the first subchannel inOrder.verify(helper).createSubchannel(createSubchannelArgsCaptor.capture()); CreateSubchannelArgs createSubchannelArgs = createSubchannelArgsCaptor.getValue(); assertThat(createSubchannelArgs.getAddresses()) @@ -2043,56 +2050,43 @@ public void grpclbWorking_pickFirstMode_lbSendsEmptyAddress() throws Exception { new EquivalentAddressGroup(backends1.get(0).addr, eagAttrsWithToken("token0001")), new EquivalentAddressGroup(backends1.get(1).addr, eagAttrsWithToken("token0002"))); - // Initially IDLE - inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); + // Child pick_first eagerly connects, so initial state is CONNECTING + inOrder.verify(helper, atLeast(1)).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); RoundRobinPicker picker0 = (RoundRobinPicker) pickerCaptor.getValue(); - - // Only one subchannel is created + // Verify subchannel creation by child LB assertThat(mockSubchannels).hasSize(1); Subchannel subchannel = mockSubchannels.poll(); assertThat(picker0.dropList).containsExactly(null, null); - assertThat(picker0.pickList).containsExactly(new IdleSubchannelEntry(subchannel, syncContext)); - - // PICK_FIRST doesn't eagerly connect - verify(subchannel, never()).requestConnection(); + assertThat(picker0.pickList).hasSize(1); + assertThat(picker0.pickList.get(0)).isInstanceOf(ChildLbPickerEntry.class); - // CONNECTING - deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(CONNECTING)); - inOrder.verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); - RoundRobinPicker picker1 = (RoundRobinPicker) pickerCaptor.getValue(); - assertThat(picker1.dropList).containsExactly(null, null); - assertThat(picker1.pickList).containsExactly(BUFFER_ENTRY); - - // TRANSIENT_FAILURE - Status error = Status.UNAVAILABLE.withDescription("Simulated connection error"); - deliverSubchannelState(subchannel, ConnectivityStateInfo.forTransientFailure(error)); - inOrder.verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); - RoundRobinPicker picker2 = (RoundRobinPicker) pickerCaptor.getValue(); - assertThat(picker2.dropList).containsExactly(null, null); - assertThat(picker2.pickList).containsExactly(new ErrorEntry(error)); + // Child pick_first eagerly calls requestConnection() + verify(subchannel).requestConnection(); // READY deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); - inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); - RoundRobinPicker picker3 = (RoundRobinPicker) pickerCaptor.getValue(); - assertThat(picker3.dropList).containsExactly(null, null); - assertThat(picker3.pickList).containsExactly( - new BackendEntry(subchannel, new TokenAttachingTracerFactory(getLoadRecorder()))); - + inOrder.verify(helper, atLeast(1)).updateBalancingState(eq(READY), pickerCaptor.capture()); + RoundRobinPicker pickerReady = (RoundRobinPicker) pickerCaptor.getValue(); + // Verify the subchannel in the delegated picker + ChildLbPickerEntry readyEntry = (ChildLbPickerEntry) pickerReady.pickList.get(0); + assertThat( + readyEntry.getChildPicker().pickSubchannel(mock(PickSubchannelArgs.class)).getSubchannel()) + .isEqualTo(subchannel); inOrder.verify(helper, never()) .updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class)); - // Empty addresses from LB + // Empty addresses from LB - child LB is shutdown lbResponseObserver.onNext(buildLbResponse(Collections.emptyList())); - // new addresses will be updated to the existing subchannel + // Child LB is shutdown (which shuts down its subchannel) // createSubchannel() has ever been called only once inOrder.verify(helper, never()).createSubchannel(any(CreateSubchannelArgs.class)); assertThat(mockSubchannels).isEmpty(); verify(subchannel).shutdown(); // RPC error status includes message of no backends provided by balancer - inOrder.verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); + inOrder.verify(helper, atLeast(1)) + .updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); RoundRobinPicker errorPicker = (RoundRobinPicker) pickerCaptor.getValue(); assertThat(errorPicker.pickList) .containsExactly(new ErrorEntry(GrpclbState.NO_AVAILABLE_BACKENDS_STATUS)); @@ -2109,18 +2103,22 @@ public void grpclbWorking_pickFirstMode_lbSendsEmptyAddress() throws Exception { .updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class)); lbResponseObserver.onNext(buildLbResponse(backends2)); - // new addresses will be updated to the existing subchannel - inOrder.verify(helper, times(1)).createSubchannel(any(CreateSubchannelArgs.class)); - inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); - subchannel = mockSubchannels.poll(); + // A NEW child LB and NEW subchannel are created upon recovery + inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); + assertThat(mockSubchannels).hasSize(1); + Subchannel subchannel2 = mockSubchannels.poll(); + inOrder.verify(helper, atLeast(1)).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); // Subchannel became READY - deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(CONNECTING)); - deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); - inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); - RoundRobinPicker picker4 = (RoundRobinPicker) pickerCaptor.getValue(); - assertThat(picker4.pickList).containsExactly( - new BackendEntry(subchannel, new TokenAttachingTracerFactory(getLoadRecorder()))); + deliverSubchannelState(subchannel2, ConnectivityStateInfo.forNonError(READY)); + inOrder.verify(helper, atLeast(1)).updateBalancingState(eq(READY), pickerCaptor.capture()); + RoundRobinPicker pickerFinal = (RoundRobinPicker) pickerCaptor.getValue(); + assertThat(pickerFinal.dropList).containsExactly( + null, new DropEntry(getLoadRecorder(), "token0003"), null); + ChildLbPickerEntry finalEntry = (ChildLbPickerEntry) pickerFinal.pickList.get(0); + assertThat( + finalEntry.getChildPicker().pickSubchannel(mock(PickSubchannelArgs.class)).getSubchannel()) + .isEqualTo(subchannel2); } @Test @@ -2179,7 +2177,7 @@ private void pickFirstModeFallback(long timeout) throws Exception { // Fallback timer expires with no response fakeClock.forwardTime(timeout, TimeUnit.MILLISECONDS); - // Entering fallback mode + // Entering fallback mode - child LB is created for fallback backends inOrder.verify(helper).createSubchannel(createSubchannelArgsCaptor.capture()); CreateSubchannelArgs createSubchannelArgs = createSubchannelArgsCaptor.getValue(); assertThat(createSubchannelArgs.getAddresses()) @@ -2188,23 +2186,24 @@ private void pickFirstModeFallback(long timeout) throws Exception { assertThat(mockSubchannels).hasSize(1); Subchannel subchannel = mockSubchannels.poll(); - // Initially IDLE - inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); + // child pick_first eagerly connects, so initial state is CONNECTING + inOrder.verify(helper, atLeast(1)).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); RoundRobinPicker picker0 = (RoundRobinPicker) pickerCaptor.getValue(); + assertThat(picker0.pickList.get(0)).isInstanceOf(ChildLbPickerEntry.class); - // READY + // Initial eager connection request + verify(subchannel).requestConnection(); + // READY transition in fallback deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); - inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); + inOrder.verify(helper, atLeast(1)).updateBalancingState(eq(READY), pickerCaptor.capture()); RoundRobinPicker picker1 = (RoundRobinPicker) pickerCaptor.getValue(); assertThat(picker1.dropList).containsExactly(null, null); - assertThat(picker1.pickList).containsExactly( - new BackendEntry(subchannel, new TokenAttachingTracerFactory(null))); + ChildLbPickerEntry readyEntry = (ChildLbPickerEntry) picker1.pickList.get(0); + assertThat( + readyEntry.getChildPicker().pickSubchannel(mock(PickSubchannelArgs.class)).getSubchannel()) + .isEqualTo(subchannel); - assertThat(picker0.dropList).containsExactly(null, null); - assertThat(picker0.pickList).containsExactly(new IdleSubchannelEntry(subchannel, syncContext)); - - - // Finally, an LB response, which brings us out of fallback + // Finally, an LB response arrives, which brings us out of fallback List backends1 = Arrays.asList( new ServerEntry("127.0.0.1", 2000, "token0001"), new ServerEntry("127.0.0.1", 2010, "token0002")); @@ -2213,20 +2212,42 @@ private void pickFirstModeFallback(long timeout) throws Exception { lbResponseObserver.onNext(buildInitialResponse()); lbResponseObserver.onNext(buildLbResponse(backends1)); - // new addresses will be updated to the existing subchannel - // createSubchannel() has ever been called only once + // subchannel should be updated, NOT recreated inOrder.verify(helper, never()).createSubchannel(any(CreateSubchannelArgs.class)); assertThat(mockSubchannels).isEmpty(); + // The child LB internally calls updateAddresses on the existing subchannel verify(subchannel).updateAddresses( eq(Arrays.asList( new EquivalentAddressGroup(backends1.get(0).addr, eagAttrsWithToken("token0001")), new EquivalentAddressGroup(backends1.get(1).addr, eagAttrsWithToken("token0002"))))); - inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); + inOrder.verify(helper, atLeast(1)).updateBalancingState(eq(READY), pickerCaptor.capture()); RoundRobinPicker picker2 = (RoundRobinPicker) pickerCaptor.getValue(); assertThat(picker2.dropList).containsExactly(null, null); - assertThat(picker2.pickList).containsExactly( - new BackendEntry(subchannel, new TokenAttachingTracerFactory(getLoadRecorder()))); + + // Verify subchannel is still the same via delegated picker + ChildLbPickerEntry updatedEntry = (ChildLbPickerEntry) picker2.pickList.get(0); + assertThat( + updatedEntry.getChildPicker().pickSubchannel(mock(PickSubchannelArgs.class)) + .getSubchannel()) + .isEqualTo(subchannel); + + // Subchannel goes IDLE, grpclb follows + deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(IDLE)); + inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); + RoundRobinPicker pickerIdle = (RoundRobinPicker) pickerCaptor.getValue(); + + // Verify connection is NOT eagerly requested again yet (still only the 1st request from start) + verify(subchannel, times(1)).requestConnection(); + + // Picking while IDLE triggers a new connection request + PickSubchannelArgs args = mock(PickSubchannelArgs.class); + PickResult pick = pickerIdle.pickSubchannel(args); + assertThat(pick.getSubchannel()).isNull(); // BUFFERing while IDLE + verify(subchannel, times(2)).requestConnection(); + + balancer.requestConnection(); + verify(subchannel, times(3)).requestConnection(); // PICK_FIRST doesn't use subchannelPool verify(subchannelPool, never()) @@ -2260,6 +2281,8 @@ public void switchMode() throws Exception { List backends1 = Arrays.asList( new ServerEntry("127.0.0.1", 2000, "token0001"), new ServerEntry("127.0.0.1", 2010, "token0002")); + + // RR Mode: Ensure no updates before initial response inOrder.verify(helper, never()) .updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class)); lbResponseObserver.onNext(buildInitialResponse()); @@ -2284,7 +2307,6 @@ public void switchMode() throws Exception { Collections.emptyList(), grpclbBalancerList, GrpclbConfig.create(Mode.PICK_FIRST)); - // GrpclbState will be shutdown, and a new one will be created assertThat(oobChannel.isShutdown()).isTrue(); verify(subchannelPool) @@ -2303,13 +2325,13 @@ public void switchMode() throws Exception { InitialLoadBalanceRequest.newBuilder().setName(SERVICE_AUTHORITY).build()) .build())); - // Simulate receiving LB response + // Simulate receiving LB response for PICK_FIRST inOrder.verify(helper, never()) .updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class)); lbResponseObserver.onNext(buildInitialResponse()); lbResponseObserver.onNext(buildLbResponse(backends1)); - // PICK_FIRST Subchannel + // PICK_FIRST Subchannel: child LB creates it inOrder.verify(helper).createSubchannel(createSubchannelArgsCaptor.capture()); CreateSubchannelArgs createSubchannelArgs = createSubchannelArgsCaptor.getValue(); assertThat(createSubchannelArgs.getAddresses()) @@ -2317,7 +2339,9 @@ public void switchMode() throws Exception { new EquivalentAddressGroup(backends1.get(0).addr, eagAttrsWithToken("token0001")), new EquivalentAddressGroup(backends1.get(1).addr, eagAttrsWithToken("token0002"))); - inOrder.verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); + // Child pick_first eagerly connects, so initial state is CONNECTING (not IDLE) + inOrder.verify(helper, atLeast(1)) + .updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); } private static Attributes eagAttrsWithToken(String token) { @@ -2344,7 +2368,7 @@ public void switchMode_nullLbPolicy() throws Exception { InitialLoadBalanceRequest.newBuilder().setName(SERVICE_AUTHORITY).build()) .build())); - // Simulate receiving LB response + // Simulate receiving LB response (Initial default mode: ROUND_ROBIN) List backends1 = Arrays.asList( new ServerEntry("127.0.0.1", 2000, "token0001"), new ServerEntry("127.0.0.1", 2010, "token0002")); @@ -2391,13 +2415,13 @@ public void switchMode_nullLbPolicy() throws Exception { InitialLoadBalanceRequest.newBuilder().setName(SERVICE_AUTHORITY).build()) .build())); - // Simulate receiving LB response + // Simulate receiving LB response for PICK_FIRST inOrder.verify(helper, never()) .updateBalancingState(any(ConnectivityState.class), any(SubchannelPicker.class)); lbResponseObserver.onNext(buildInitialResponse()); lbResponseObserver.onNext(buildLbResponse(backends1)); - // PICK_FIRST Subchannel + // PICK_FIRST Subchannel: with delegation, child LB creates the subchannel inOrder.verify(helper).createSubchannel(createSubchannelArgsCaptor.capture()); CreateSubchannelArgs createSubchannelArgs = createSubchannelArgsCaptor.getValue(); assertThat(createSubchannelArgs.getAddresses()) @@ -2405,7 +2429,9 @@ public void switchMode_nullLbPolicy() throws Exception { new EquivalentAddressGroup(backends1.get(0).addr, eagAttrsWithToken("token0001")), new EquivalentAddressGroup(backends1.get(1).addr, eagAttrsWithToken("token0002"))); - inOrder.verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); + // Child pick_first eagerly connects, so state is CONNECTING (not IDLE) + inOrder.verify(helper, atLeast(1)) + .updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); } @Test diff --git a/grpclb/src/test/java/io/grpc/grpclb/GrpclbNameResolverTest.java b/grpclb/src/test/java/io/grpc/grpclb/GrpclbNameResolverTest.java index 3e2cf22605f..a90556a01b0 100644 --- a/grpclb/src/test/java/io/grpc/grpclb/GrpclbNameResolverTest.java +++ b/grpclb/src/test/java/io/grpc/grpclb/GrpclbNameResolverTest.java @@ -20,7 +20,6 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -96,7 +95,6 @@ public void close(Executor instance) {} } @Captor private ArgumentCaptor resultCaptor; - @Captor private ArgumentCaptor errorCaptor; @Mock private ServiceConfigParser serviceConfigParser; @Mock private NameResolver.Listener2 mockListener; @@ -152,9 +150,9 @@ public List resolveSrv(String host) throws Exception { resolver.start(mockListener); assertThat(fakeClock.runDueTasks()).isEqualTo(1); - verify(mockListener).onResult(resultCaptor.capture()); + verify(mockListener).onResult2(resultCaptor.capture()); ResolutionResult result = resultCaptor.getValue(); - assertThat(result.getAddresses()).isEmpty(); + assertThat(result.getAddressesOrError().getValue()).isEmpty(); assertThat(result.getAttributes()).isEqualTo(Attributes.EMPTY); assertThat(result.getServiceConfig()).isNull(); } @@ -192,11 +190,11 @@ public ConfigOrError answer(InvocationOnMock invocation) { resolver.start(mockListener); assertThat(fakeClock.runDueTasks()).isEqualTo(1); - verify(mockListener).onResult(resultCaptor.capture()); + verify(mockListener).onResult2(resultCaptor.capture()); ResolutionResult result = resultCaptor.getValue(); InetSocketAddress resolvedBackendAddr = (InetSocketAddress) Iterables.getOnlyElement( - Iterables.getOnlyElement(result.getAddresses()).getAddresses()); + Iterables.getOnlyElement(result.getAddressesOrError().getValue()).getAddresses()); assertThat(resolvedBackendAddr.getAddress()).isEqualTo(backendAddr); EquivalentAddressGroup resolvedBalancerAddr = Iterables.getOnlyElement(result.getAttributes().get(GrpclbConstants.ATTR_LB_ADDRS)); @@ -225,9 +223,9 @@ public void resolve_nullResourceResolver() throws Exception { resolver.start(mockListener); assertThat(fakeClock.runDueTasks()).isEqualTo(1); - verify(mockListener).onResult(resultCaptor.capture()); + verify(mockListener).onResult2(resultCaptor.capture()); ResolutionResult result = resultCaptor.getValue(); - assertThat(result.getAddresses()) + assertThat(result.getAddressesOrError().getValue()) .containsExactly( new EquivalentAddressGroup(new InetSocketAddress(backendAddr, DEFAULT_PORT))); assertThat(result.getAttributes()).isEqualTo(Attributes.EMPTY); @@ -245,8 +243,8 @@ public void resolve_nullResourceResolver_addressFailure() throws Exception { resolver.start(mockListener); assertThat(fakeClock.runDueTasks()).isEqualTo(1); - verify(mockListener).onError(errorCaptor.capture()); - Status errorStatus = errorCaptor.getValue(); + verify(mockListener).onResult2(resultCaptor.capture()); + Status errorStatus = resultCaptor.getValue().getAddressesOrError().getStatus(); assertThat(errorStatus.getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(errorStatus.getCause()).hasMessageThat().contains("no addr"); } @@ -272,9 +270,9 @@ public void resolve_addressFailure_stillLookUpBalancersAndServiceConfig() throws resolver.start(mockListener); assertThat(fakeClock.runDueTasks()).isEqualTo(1); - verify(mockListener).onResult(resultCaptor.capture()); + verify(mockListener).onResult2(resultCaptor.capture()); ResolutionResult result = resultCaptor.getValue(); - assertThat(result.getAddresses()).isEmpty(); + assertThat(result.getAddressesOrError().getValue()).isEmpty(); EquivalentAddressGroup resolvedBalancerAddr = Iterables.getOnlyElement(result.getAttributes().get(GrpclbConstants.ATTR_LB_ADDRS)); assertThat(resolvedBalancerAddr.getAttributes().get(GrpclbConstants.ATTR_LB_ADDR_AUTHORITY)) @@ -306,12 +304,12 @@ public void resolveAll_balancerLookupFails_stillLookUpServiceConfig() throws Exc resolver.start(mockListener); assertThat(fakeClock.runDueTasks()).isEqualTo(1); - verify(mockListener).onResult(resultCaptor.capture()); + verify(mockListener).onResult2(resultCaptor.capture()); ResolutionResult result = resultCaptor.getValue(); InetSocketAddress resolvedBackendAddr = (InetSocketAddress) Iterables.getOnlyElement( - Iterables.getOnlyElement(result.getAddresses()).getAddresses()); + Iterables.getOnlyElement(result.getAddressesOrError().getValue()).getAddresses()); assertThat(resolvedBackendAddr.getAddress()).isEqualTo(backendAddr); assertThat(result.getAttributes().get(GrpclbConstants.ATTR_LB_ADDRS)).isNull(); verify(mockAddressResolver).resolveAddress(hostName); @@ -320,7 +318,7 @@ public void resolveAll_balancerLookupFails_stillLookUpServiceConfig() throws Exc } @Test - public void resolve_addressAndBalancersLookupFail_neverLookupServiceConfig() throws Exception { + public void resolve_addressAndBalancersLookupFail_stillLookupServiceConfig() throws Exception { AddressResolver mockAddressResolver = mock(AddressResolver.class); when(mockAddressResolver.resolveAddress(anyString())) .thenThrow(new UnknownHostException("I really tried")); @@ -335,11 +333,11 @@ public void resolve_addressAndBalancersLookupFail_neverLookupServiceConfig() thr resolver.start(mockListener); assertThat(fakeClock.runDueTasks()).isEqualTo(1); - verify(mockListener).onError(errorCaptor.capture()); - Status errorStatus = errorCaptor.getValue(); + verify(mockListener).onResult2(resultCaptor.capture()); + Status errorStatus = resultCaptor.getValue().getAddressesOrError().getStatus(); assertThat(errorStatus.getCode()).isEqualTo(Code.UNAVAILABLE); verify(mockAddressResolver).resolveAddress(hostName); - verify(mockResourceResolver, never()).resolveTxt("_grpc_config." + hostName); + verify(mockResourceResolver).resolveTxt("_grpc_config." + hostName); verify(mockResourceResolver).resolveSrv("_grpclb._tcp." + hostName); } } diff --git a/grpclb/src/test/java/io/grpc/grpclb/SecretGrpclbNameResolverProviderTest.java b/grpclb/src/test/java/io/grpc/grpclb/SecretGrpclbNameResolverProviderTest.java index 24b1c781f58..e9ed92a54d0 100644 --- a/grpclb/src/test/java/io/grpc/grpclb/SecretGrpclbNameResolverProviderTest.java +++ b/grpclb/src/test/java/io/grpc/grpclb/SecretGrpclbNameResolverProviderTest.java @@ -17,6 +17,8 @@ package io.grpc.grpclb; import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.TruthJUnit.assume; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; @@ -24,15 +26,19 @@ import io.grpc.NameResolver; import io.grpc.NameResolver.ServiceConfigParser; import io.grpc.SynchronizationContext; +import io.grpc.Uri; import io.grpc.internal.DnsNameResolverProvider; import io.grpc.internal.GrpcUtil; import java.net.URI; +import java.util.Arrays; import org.junit.Test; import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; /** Unit tests for {@link SecretGrpclbNameResolverProvider}. */ -@RunWith(JUnit4.class) +@RunWith(Parameterized.class) public class SecretGrpclbNameResolverProviderTest { private final SynchronizationContext syncContext = new SynchronizationContext( @@ -53,6 +59,13 @@ public void uncaughtException(Thread t, Throwable e) { private SecretGrpclbNameResolverProvider.Provider provider = new SecretGrpclbNameResolverProvider.Provider(); + @Parameters(name = "enableRfc3986UrisParam={0}") + public static Iterable data() { + return Arrays.asList(new Object[][] {{true}, {false}}); + } + + @Parameter public boolean enableRfc3986UrisParam; + @Test public void isAvailable() { assertThat(provider.isAvailable()).isTrue(); @@ -66,43 +79,65 @@ public void priority_shouldBeHigherThanDefaultDnsNameResolver() { } @Test - public void newNameResolver() { - assertThat(provider.newNameResolver(URI.create("dns:///localhost:443"), args)) + public void newNameResolverReturnsCorrectType() { + assertThat(newNameResolver("dns:///localhost:443", args)) .isInstanceOf(GrpclbNameResolver.class); - assertThat(provider.newNameResolver(URI.create("notdns:///localhost:443"), args)).isNull(); + assertThat(newNameResolver("notdns:///localhost:443", args)).isNull(); } @Test public void invalidDnsName() throws Exception { - testInvalidUri(new URI("dns", null, "/[invalid]", null)); + testInvalidUri("dns:/%5Binvalid%5D"); } @Test public void validIpv6() throws Exception { - testValidUri(new URI("dns", null, "/[::1]", null)); + testValidUri("dns:/%5B::1%5D"); } @Test public void validDnsNameWithoutPort() throws Exception { - testValidUri(new URI("dns", null, "/foo.googleapis.com", null)); + testValidUri("dns:/foo.googleapis.com"); } @Test public void validDnsNameWithPort() throws Exception { - testValidUri(new URI("dns", null, "/foo.googleapis.com:456", null)); + testValidUri("dns:/foo.googleapis.com:456"); + } + + @Test + public void newNameResolver_rejectsExtraPathSegments() { + assume().that(enableRfc3986UrisParam).isTrue(); + IllegalArgumentException iae = + assertThrows( + IllegalArgumentException.class, + () -> newNameResolver("dns:///localhost:443/extras", args)); + assertThat(iae).hasMessageThat().contains("expected 1 path segment in target"); } - private void testInvalidUri(URI uri) { + @Test + public void newNameResolver_toleratesExtraPathSegments() { + assume().that(enableRfc3986UrisParam).isFalse(); + newNameResolver("dns:///localhost:443/extras", args); + } + + private void testInvalidUri(String uri) { try { - provider.newNameResolver(uri, args); + newNameResolver(uri, args); fail("Should have failed"); } catch (IllegalArgumentException e) { // expected } } - private void testValidUri(URI uri) { - GrpclbNameResolver resolver = provider.newNameResolver(uri, args); + private void testValidUri(String uri) { + NameResolver resolver = newNameResolver(uri, args); assertThat(resolver).isNotNull(); } + + private NameResolver newNameResolver(String uriString, NameResolver.Args args) { + return enableRfc3986UrisParam + ? provider.newNameResolver(Uri.create(uriString), args) + : provider.newNameResolver(URI.create(uriString), args); + } } diff --git a/inprocess/BUILD.bazel b/inprocess/BUILD.bazel index 65f2adceda1..e9c5001c5ec 100644 --- a/inprocess/BUILD.bazel +++ b/inprocess/BUILD.bazel @@ -1,3 +1,6 @@ +load("@rules_java//java:defs.bzl", "java_library") +load("@rules_jvm_external//:defs.bzl", "artifact") + java_library( name = "inprocess", srcs = glob([ @@ -5,12 +8,11 @@ java_library( ]), visibility = ["//visibility:public"], deps = [ - "//core:internal", "//api", "//context", - "@com_google_code_findbugs_jsr305//jar", - "@com_google_errorprone_error_prone_annotations//jar", - "@com_google_guava_guava//jar", - "@com_google_j2objc_j2objc_annotations//jar", + "//core:internal", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:guava"), ], ) diff --git a/inprocess/build.gradle b/inprocess/build.gradle index edc97883b50..075968ccb9a 100644 --- a/inprocess/build.gradle +++ b/inprocess/build.gradle @@ -22,8 +22,16 @@ dependencies { testFixtures(project(':grpc-core')) testImplementation libraries.guava.testlib - signature libraries.signature.java - signature libraries.signature.android + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } + signature (libraries.signature.android) { + artifact { + extension = "signature" + } + } } tasks.named("javadoc").configure { diff --git a/inprocess/src/main/java/io/grpc/inprocess/AnonymousInProcessSocketAddress.java b/inprocess/src/main/java/io/grpc/inprocess/AnonymousInProcessSocketAddress.java index 5f6486e335d..c458857d70b 100644 --- a/inprocess/src/main/java/io/grpc/inprocess/AnonymousInProcessSocketAddress.java +++ b/inprocess/src/main/java/io/grpc/inprocess/AnonymousInProcessSocketAddress.java @@ -18,11 +18,13 @@ import static com.google.common.base.Preconditions.checkState; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.ExperimentalApi; import java.io.IOException; +import java.io.NotSerializableException; +import java.io.ObjectOutputStream; import java.net.SocketAddress; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; /** * Custom SocketAddress class for {@link InProcessTransport}, for @@ -34,8 +36,13 @@ public final class AnonymousInProcessSocketAddress extends SocketAddress { @Nullable @GuardedBy("this") + @SuppressWarnings("serial") private InProcessServer server; + private void writeObject(ObjectOutputStream out) throws IOException { + throw new NotSerializableException("AnonymousInProcessSocketAddress is not serializable"); + } + /** Creates a new AnonymousInProcessSocketAddress. */ public AnonymousInProcessSocketAddress() { } diff --git a/inprocess/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java b/inprocess/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java index c000b66b2a2..9b33b3d3618 100644 --- a/inprocess/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java +++ b/inprocess/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java @@ -18,6 +18,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.inprocess.InProcessTransport.isEnabledSupportTracingMessageSizes; import com.google.errorprone.annotations.DoNotCall; import io.grpc.ChannelCredentials; @@ -94,6 +95,7 @@ public static InProcessChannelBuilder forAddress(String name, int port) { private ScheduledExecutorService scheduledExecutorService; private int maxInboundMetadataSize = Integer.MAX_VALUE; private boolean transportIncludeStatusCause = false; + private long assumedMessageSize = -1; private InProcessChannelBuilder(@Nullable SocketAddress directAddress, @Nullable String target) { @@ -117,10 +119,9 @@ public ClientTransportFactory buildClientTransportFactory() { managedChannelImplBuilder.setStatsRecordStartedRpcs(false); managedChannelImplBuilder.setStatsRecordFinishedRpcs(false); managedChannelImplBuilder.setStatsRecordRetryMetrics(false); - - // By default, In-process transport should not be retriable as that leaks memory. Since - // there is no wire, bytes aren't calculated so buffer limit isn't respected - managedChannelImplBuilder.disableRetry(); + if (!isEnabledSupportTracingMessageSizes) { + managedChannelImplBuilder.disableRetry(); + } } @Internal @@ -225,9 +226,24 @@ public InProcessChannelBuilder propagateCauseWithStatus(boolean enable) { return this; } + /** + * Assumes RPC messages are the specified size. This avoids serializing + * messages for metrics and retry memory tracking. This can dramatically + * improve performance when accurate message sizes are not needed and if + * nothing else needs the serialized message. + * @param assumedMessageSize length of InProcess transport's messageSize. + * @return this + * @throws IllegalArgumentException if assumedMessageSize is negative. + */ + public InProcessChannelBuilder assumedMessageSize(long assumedMessageSize) { + checkArgument(assumedMessageSize >= 0, "assumedMessageSize must be >= 0"); + this.assumedMessageSize = assumedMessageSize; + return this; + } + ClientTransportFactory buildTransportFactory() { - return new InProcessClientTransportFactory( - scheduledExecutorService, maxInboundMetadataSize, transportIncludeStatusCause); + return new InProcessClientTransportFactory(scheduledExecutorService, + maxInboundMetadataSize, transportIncludeStatusCause, assumedMessageSize); } void setStatsEnabled(boolean value) { @@ -243,15 +259,17 @@ static final class InProcessClientTransportFactory implements ClientTransportFac private final int maxInboundMetadataSize; private boolean closed; private final boolean includeCauseWithStatus; + private long assumedMessageSize; private InProcessClientTransportFactory( @Nullable ScheduledExecutorService scheduledExecutorService, - int maxInboundMetadataSize, boolean includeCauseWithStatus) { + int maxInboundMetadataSize, boolean includeCauseWithStatus, long assumedMessageSize) { useSharedTimer = scheduledExecutorService == null; timerService = useSharedTimer ? SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE) : scheduledExecutorService; this.maxInboundMetadataSize = maxInboundMetadataSize; this.includeCauseWithStatus = includeCauseWithStatus; + this.assumedMessageSize = assumedMessageSize; } @Override @@ -263,7 +281,7 @@ public ConnectionClientTransport newClientTransport( // TODO(carl-mastrangelo): Pass channelLogger in. return new InProcessTransport( addr, maxInboundMetadataSize, options.getAuthority(), options.getUserAgent(), - options.getEagAttributes(), includeCauseWithStatus); + options.getEagAttributes(), includeCauseWithStatus, assumedMessageSize); } @Override diff --git a/inprocess/src/main/java/io/grpc/inprocess/InProcessServerBuilder.java b/inprocess/src/main/java/io/grpc/inprocess/InProcessServerBuilder.java index 190f67603c3..b2004426aae 100644 --- a/inprocess/src/main/java/io/grpc/inprocess/InProcessServerBuilder.java +++ b/inprocess/src/main/java/io/grpc/inprocess/InProcessServerBuilder.java @@ -24,6 +24,7 @@ import io.grpc.ExperimentalApi; import io.grpc.ForwardingServerBuilder; import io.grpc.Internal; +import io.grpc.MetricRecorder; import io.grpc.ServerBuilder; import io.grpc.ServerStreamTracer; import io.grpc.internal.FixedObjectPool; @@ -120,7 +121,8 @@ private InProcessServerBuilder(SocketAddress listenAddress) { final class InProcessClientTransportServersBuilder implements ClientTransportServersBuilder { @Override public InternalServer buildClientTransportServers( - List streamTracerFactories) { + List streamTracerFactories, + MetricRecorder metricRecorder) { return buildTransportServers(streamTracerFactories); } } diff --git a/inprocess/src/main/java/io/grpc/inprocess/InProcessTransport.java b/inprocess/src/main/java/io/grpc/inprocess/InProcessTransport.java index 91e519f9efc..a92f10fd5c5 100644 --- a/inprocess/src/main/java/io/grpc/inprocess/InProcessTransport.java +++ b/inprocess/src/main/java/io/grpc/inprocess/InProcessTransport.java @@ -18,12 +18,13 @@ import static com.google.common.base.Preconditions.checkNotNull; import static io.grpc.internal.GrpcUtil.TIMEOUT_KEY; -import static java.lang.Math.max; import com.google.common.base.MoreObjects; -import com.google.common.base.Optional; +import com.google.common.io.ByteStreams; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; +import com.google.errorprone.annotations.CheckReturnValue; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.ClientStreamTracer; @@ -35,6 +36,7 @@ import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalLogId; import io.grpc.InternalMetadata; +import io.grpc.KnownLength; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.SecurityLevel; @@ -52,13 +54,14 @@ import io.grpc.internal.ManagedClientTransport; import io.grpc.internal.NoopClientStream; import io.grpc.internal.ObjectPool; -import io.grpc.internal.ServerListener; import io.grpc.internal.ServerStream; import io.grpc.internal.ServerStreamListener; import io.grpc.internal.ServerTransport; import io.grpc.internal.ServerTransportListener; +import io.grpc.internal.SimpleDisconnectError; import io.grpc.internal.StatsTraceContext; import io.grpc.internal.StreamListener; +import java.io.ByteArrayInputStream; import java.io.InputStream; import java.net.SocketAddress; import java.util.ArrayDeque; @@ -73,21 +76,20 @@ import java.util.concurrent.TimeUnit; import java.util.logging.Level; import java.util.logging.Logger; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; @ThreadSafe final class InProcessTransport implements ServerTransport, ConnectionClientTransport { private static final Logger log = Logger.getLogger(InProcessTransport.class.getName()); + static boolean isEnabledSupportTracingMessageSizes = + GrpcUtil.getFlag("GRPC_EXPERIMENTAL_SUPPORT_TRACING_MESSAGE_SIZES", false); private final InternalLogId logId; private final SocketAddress address; private final int clientMaxInboundMetadataSize; private final String authority; private final String userAgent; - private final Optional optionalServerListener; private int serverMaxInboundMetadataSize; private final boolean includeCauseWithStatus; private ObjectPool serverSchedulerPool; @@ -95,6 +97,8 @@ final class InProcessTransport implements ServerTransport, ConnectionClientTrans private ServerTransportListener serverTransportListener; private Attributes serverStreamAttributes; private ManagedClientTransport.Listener clientTransportListener; + // The size is assumed from the sender's side. + private final long assumedMessageSize; @GuardedBy("this") private boolean shutdown; @GuardedBy("this") @@ -134,9 +138,9 @@ protected void handleNotInUse() { } }; - private InProcessTransport(SocketAddress address, int maxInboundMetadataSize, String authority, + public InProcessTransport(SocketAddress address, int maxInboundMetadataSize, String authority, String userAgent, Attributes eagAttrs, - Optional optionalServerListener, boolean includeCauseWithStatus) { + boolean includeCauseWithStatus, long assumedMessageSize) { this.address = address; this.clientMaxInboundMetadataSize = maxInboundMetadataSize; this.authority = authority; @@ -148,47 +152,23 @@ private InProcessTransport(SocketAddress address, int maxInboundMetadataSize, St .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, address) .set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, address) .build(); - this.optionalServerListener = optionalServerListener; logId = InternalLogId.allocate(getClass(), address.toString()); this.includeCauseWithStatus = includeCauseWithStatus; - } - - public InProcessTransport( - SocketAddress address, int maxInboundMetadataSize, String authority, String userAgent, - Attributes eagAttrs, boolean includeCauseWithStatus) { - this(address, maxInboundMetadataSize, authority, userAgent, eagAttrs, - Optional.absent(), includeCauseWithStatus); - } - - InProcessTransport( - String name, int maxInboundMetadataSize, String authority, String userAgent, - Attributes eagAttrs, ObjectPool serverSchedulerPool, - List serverStreamTracerFactories, - ServerListener serverListener, boolean includeCauseWithStatus) { - this(new InProcessSocketAddress(name), maxInboundMetadataSize, authority, userAgent, eagAttrs, - Optional.of(serverListener), includeCauseWithStatus); - this.serverMaxInboundMetadataSize = maxInboundMetadataSize; - this.serverSchedulerPool = serverSchedulerPool; - this.serverStreamTracerFactories = serverStreamTracerFactories; + this.assumedMessageSize = assumedMessageSize; } @CheckReturnValue @Override public synchronized Runnable start(ManagedClientTransport.Listener listener) { this.clientTransportListener = listener; - if (optionalServerListener.isPresent()) { + InProcessServer server = InProcessServer.findServer(address); + if (server != null) { + serverMaxInboundMetadataSize = server.getMaxInboundMetadataSize(); + serverSchedulerPool = server.getScheduledExecutorServicePool(); serverScheduler = serverSchedulerPool.getObject(); - serverTransportListener = optionalServerListener.get().transportCreated(this); - } else { - InProcessServer server = InProcessServer.findServer(address); - if (server != null) { - serverMaxInboundMetadataSize = server.getMaxInboundMetadataSize(); - serverSchedulerPool = server.getScheduledExecutorServicePool(); - serverScheduler = serverSchedulerPool.getObject(); - serverStreamTracerFactories = server.getStreamTracerFactories(); - // Must be semi-initialized; past this point, can begin receiving requests - serverTransportListener = server.register(this); - } + serverStreamTracerFactories = server.getStreamTracerFactories(); + // Must be semi-initialized; past this point, can begin receiving requests + serverTransportListener = server.register(this); } if (serverTransportListener == null) { shutdownStatus = Status.UNAVAILABLE.withDescription("Could not find server: " + address); @@ -203,21 +183,14 @@ public void run() { } }; } - return new Runnable() { - @Override - @SuppressWarnings("deprecation") - public void run() { - synchronized (InProcessTransport.this) { - Attributes serverTransportAttrs = Attributes.newBuilder() - .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, address) - .set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, address) - .build(); - serverStreamAttributes = serverTransportListener.transportReady(serverTransportAttrs); - attributes = clientTransportListener.filterTransport(attributes); - clientTransportListener.transportReady(); - } - } - }; + Attributes serverTransportAttrs = Attributes.newBuilder() + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, address) + .set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, address) + .build(); + serverStreamAttributes = serverTransportListener.transportReady(serverTransportAttrs); + attributes = clientTransportListener.filterTransport(attributes); + clientTransportListener.transportReady(); + return null; } @Override @@ -273,7 +246,7 @@ public synchronized void ping(final PingCallback callback, Executor executor) { executor.execute(new Runnable() { @Override public void run() { - callback.onFailure(shutdownStatus.asRuntimeException()); + callback.onFailure(shutdownStatus); } }); } else { @@ -355,7 +328,7 @@ private synchronized void notifyShutdown(Status s) { return; } shutdown = true; - clientTransportListener.transportShutdown(s); + clientTransportListener.transportShutdown(s, SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); } private synchronized void notifyTerminated() { @@ -514,6 +487,25 @@ private void clientCancelled(Status status) { @Override public void writeMessage(InputStream message) { + long messageLength = 0; + if (isEnabledSupportTracingMessageSizes) { + try { + if (assumedMessageSize != -1) { + messageLength = assumedMessageSize; + } else if (message instanceof KnownLength || message instanceof ByteArrayInputStream) { + messageLength = message.available(); + } else { + InputStream oldMessage = message; + byte[] payload = ByteStreams.toByteArray(message); + messageLength = payload.length; + message = new ByteArrayInputStream(payload); + oldMessage.close(); + } + } catch (Exception e) { + throw new RuntimeException("Error processing the message length", e); + } + } + synchronized (this) { if (closed) { return; @@ -522,6 +514,13 @@ public void writeMessage(InputStream message) { statsTraceCtx.outboundMessageSent(outboundSeqNo, -1, -1); clientStream.statsTraceCtx.inboundMessage(outboundSeqNo); clientStream.statsTraceCtx.inboundMessageRead(outboundSeqNo, -1, -1); + if (isEnabledSupportTracingMessageSizes) { + statsTraceCtx.outboundUncompressedSize(messageLength); + statsTraceCtx.outboundWireSize(messageLength); + // messageLength should be same at receiver's end as no actual wire is involved. + clientStream.statsTraceCtx.inboundUncompressedSize(messageLength); + clientStream.statsTraceCtx.inboundWireSize(messageLength); + } outboundSeqNo++; StreamListener.MessageProducer producer = new SingleMessageProducer(message); if (clientRequested > 0) { @@ -531,7 +530,6 @@ public void writeMessage(InputStream message) { clientReceiveQueue.add(producer); } } - syncContext.drain(); } @@ -571,7 +569,7 @@ public void writeHeaders(Metadata headers, boolean flush) { return; } - clientStream.statsTraceCtx.clientInboundHeaders(); + clientStream.statsTraceCtx.clientInboundHeaders(headers); syncContext.executeLater(() -> clientStreamListener.headersRead(headers)); } syncContext.drain(); @@ -608,7 +606,7 @@ public void close(Status status, Metadata trailers) { notifyClientClose(status, trailers); } - /** clientStream.serverClosed() must be called before this method */ + /** clientStream.serverClosed() must be called before this method. */ private void notifyClientClose(Status status, Metadata trailers) { Status clientStatus = cleanStatus(status, includeCauseWithStatus); synchronized (this) { @@ -697,6 +695,11 @@ public StatsTraceContext statsTraceContext() { public int streamId() { return -1; } + + @Override + public void setOnReadyThreshold(int numBytes) { + // noop + } } private class InProcessClientStream implements ClientStream { @@ -780,6 +783,24 @@ private void serverClosed(Status serverListenerStatus, Status serverTracerStatus @Override public void writeMessage(InputStream message) { + long messageLength = 0; + if (isEnabledSupportTracingMessageSizes) { + try { + if (assumedMessageSize != -1) { + messageLength = assumedMessageSize; + } else if (message instanceof KnownLength || message instanceof ByteArrayInputStream) { + messageLength = message.available(); + } else { + InputStream oldMessage = message; + byte[] payload = ByteStreams.toByteArray(message); + messageLength = payload.length; + message = new ByteArrayInputStream(payload); + oldMessage.close(); + } + } catch (Exception e) { + throw new RuntimeException("Error processing the message length", e); + } + } synchronized (this) { if (closed) { return; @@ -788,6 +809,13 @@ public void writeMessage(InputStream message) { statsTraceCtx.outboundMessageSent(outboundSeqNo, -1, -1); serverStream.statsTraceCtx.inboundMessage(outboundSeqNo); serverStream.statsTraceCtx.inboundMessageRead(outboundSeqNo, -1, -1); + if (isEnabledSupportTracingMessageSizes) { + statsTraceCtx.outboundUncompressedSize(messageLength); + statsTraceCtx.outboundWireSize(messageLength); + // messageLength should be same at receiver's end as no actual wire is involved. + serverStream.statsTraceCtx.inboundUncompressedSize(messageLength); + serverStream.statsTraceCtx.inboundWireSize(messageLength); + } outboundSeqNo++; StreamListener.MessageProducer producer = new SingleMessageProducer(message); if (serverRequested > 0) { @@ -911,8 +939,7 @@ public void setMaxOutboundMessageSize(int maxSize) {} @Override public void setDeadline(Deadline deadline) { headers.discardAll(TIMEOUT_KEY); - long effectiveTimeout = max(0, deadline.timeRemaining(TimeUnit.NANOSECONDS)); - headers.put(TIMEOUT_KEY, effectiveTimeout); + headers.put(TIMEOUT_KEY, deadline.timeRemaining(TimeUnit.NANOSECONDS)); } @Override diff --git a/inprocess/src/main/java/io/grpc/inprocess/InternalInProcess.java b/inprocess/src/main/java/io/grpc/inprocess/InternalInProcess.java deleted file mode 100644 index 680373533c8..00000000000 --- a/inprocess/src/main/java/io/grpc/inprocess/InternalInProcess.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright 2020 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.inprocess; - -import io.grpc.Attributes; -import io.grpc.Internal; -import io.grpc.ServerStreamTracer; -import io.grpc.internal.ConnectionClientTransport; -import io.grpc.internal.ObjectPool; -import io.grpc.internal.ServerListener; -import java.util.List; -import java.util.concurrent.ScheduledExecutorService; - -/** - * Internal {@link InProcessTransport} accessor. - * - *

This is intended for use by io.grpc.internal, and the specifically - * supported transport packages. - */ -@Internal -public final class InternalInProcess { - - private InternalInProcess() {} - - /** - * Creates a new InProcessTransport. - * - *

When started, the transport will be registered with the given - * {@link ServerListener}. - */ - @Internal - public static ConnectionClientTransport createInProcessTransport( - String name, - int maxInboundMetadataSize, - String authority, - String userAgent, - Attributes eagAttrs, - ObjectPool serverSchedulerPool, - List serverStreamTracerFactories, - ServerListener serverListener, - boolean includeCauseWithStatus) { - return new InProcessTransport( - name, - maxInboundMetadataSize, - authority, - userAgent, - eagAttrs, - serverSchedulerPool, - serverStreamTracerFactories, - serverListener, - includeCauseWithStatus); - } -} diff --git a/inprocess/src/test/java/io/grpc/inprocess/AnonymousInProcessTransportTest.java b/inprocess/src/test/java/io/grpc/inprocess/AnonymousInProcessTransportTest.java index a78a604eac3..7bf884c9ff9 100644 --- a/inprocess/src/test/java/io/grpc/inprocess/AnonymousInProcessTransportTest.java +++ b/inprocess/src/test/java/io/grpc/inprocess/AnonymousInProcessTransportTest.java @@ -52,6 +52,6 @@ protected InternalServer newServer( protected ManagedClientTransport newClientTransport(InternalServer server) { return new InProcessTransport( address, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, - testAuthority(server), USER_AGENT, eagAttrs(), false); + testAuthority(server), USER_AGENT, eagAttrs(), false, -1); } } diff --git a/inprocess/src/test/java/io/grpc/inprocess/InProcessTransportTest.java b/inprocess/src/test/java/io/grpc/inprocess/InProcessTransportTest.java index 420a9c4a8e7..d2220e05114 100644 --- a/inprocess/src/test/java/io/grpc/inprocess/InProcessTransportTest.java +++ b/inprocess/src/test/java/io/grpc/inprocess/InProcessTransportTest.java @@ -17,6 +17,7 @@ package io.grpc.inprocess; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import io.grpc.CallOptions; @@ -34,15 +35,25 @@ import io.grpc.Status.Code; import io.grpc.StatusRuntimeException; import io.grpc.internal.AbstractTransportTest; +import io.grpc.internal.ClientStream; +import io.grpc.internal.ClientStreamListenerBase; import io.grpc.internal.GrpcUtil; import io.grpc.internal.InternalServer; import io.grpc.internal.ManagedClientTransport; +import io.grpc.internal.MockServerTransportListener; +import io.grpc.internal.MockServerTransportListener.StreamCreation; +import io.grpc.internal.ServerStream; +import io.grpc.internal.ServerStreamListenerBase; +import io.grpc.internal.testing.TestStreamTracer; import io.grpc.stub.ClientCalls; import io.grpc.testing.GrpcCleanupRule; import io.grpc.testing.TestMethodDescriptors; +import java.io.InputStream; +import java.util.Arrays; import java.util.List; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import org.junit.Assert; import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; @@ -55,10 +66,18 @@ public class InProcessTransportTest extends AbstractTransportTest { private static final String TRANSPORT_NAME = "perfect-for-testing"; private static final String AUTHORITY = "a-testing-authority"; protected static final String USER_AGENT = "a-testing-user-agent"; + private static final int TIMEOUT_MS = 5000; + private static final long TEST_MESSAGE_LENGTH = 100; @Rule public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule(); + @Override + protected InternalServer newServer( + int port, List streamTracerFactories) { + return newServer(streamTracerFactories); + } + @Override protected InternalServer newServer( List streamTracerFactories) { @@ -68,12 +87,6 @@ protected InternalServer newServer( return new InProcessServer(builder, streamTracerFactories); } - @Override - protected InternalServer newServer( - int port, List streamTracerFactories) { - return newServer(streamTracerFactories); - } - @Override protected String testAuthority(InternalServer server) { return AUTHORITY; @@ -83,14 +96,13 @@ protected String testAuthority(InternalServer server) { protected ManagedClientTransport newClientTransport(InternalServer server) { return new InProcessTransport( new InProcessSocketAddress(TRANSPORT_NAME), GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, - testAuthority(server), USER_AGENT, eagAttrs(), false); + testAuthority(server), USER_AGENT, eagAttrs(), false, -1); } - @Override - protected boolean sizesReported() { - // TODO(zhangkun83): InProcessTransport doesn't record metrics for now - // (https://github.com/grpc/grpc-java/issues/2284) - return false; + private ManagedClientTransport newClientTransportWithAssumedMessageSize(InternalServer server) { + return new InProcessTransport( + new InProcessSocketAddress(TRANSPORT_NAME), GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, + testAuthority(server), USER_AGENT, eagAttrs(), false, TEST_MESSAGE_LENGTH); } @Test @@ -170,11 +182,67 @@ public Listener startCall(ServerCall call, Metadata headers) { .build(); ClientCall call = channel.newCall(nonMatchMethod, CallOptions.DEFAULT); try { - ClientCalls.futureUnaryCall(call, null).get(5, TimeUnit.SECONDS); + ClientCalls.futureUnaryCall(call, null).get(TIMEOUT_MS, TimeUnit.MILLISECONDS); fail("Call should fail."); } catch (ExecutionException ex) { StatusRuntimeException s = (StatusRuntimeException)ex.getCause(); assertEquals(Code.UNIMPLEMENTED, s.getStatus().getCode()); } } + + @Test + public void basicStreamInProcess() throws Exception { + InProcessServerBuilder builder = InProcessServerBuilder + .forName(TRANSPORT_NAME) + .maxInboundMetadataSize(GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE); + server = new InProcessServer(builder, Arrays.asList(serverStreamTracerFactory)); + server.start(serverListener); + client = newClientTransportWithAssumedMessageSize(server); + startTransport(client, mockClientTransportListener); + MockServerTransportListener serverTransportListener + = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); + serverTransport = serverTransportListener.transport; + // Set up client stream + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), CallOptions.DEFAULT, tracers); + ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); + clientStream.start(clientStreamListener); + StreamCreation serverStreamCreation + = serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); + ServerStream serverStream = serverStreamCreation.stream; + ServerStreamListenerBase serverStreamListener = serverStreamCreation.listener; + serverStream.request(1); + assertTrue(clientStream.isReady()); + // Send message from client to server + clientStream.writeMessage(methodDescriptor.streamRequest("Hello from client")); + clientStream.flush(); + // Verify server received the message and check its size + InputStream message = + serverStreamListener.messageQueue.poll(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertEquals("Hello from client", methodDescriptor.parseRequest(message)); + message.close(); + clientStream.halfClose(); + assertAssumedMessageSize(clientStreamTracer1, serverStreamTracer1); + + clientStream.request(1); + assertTrue(serverStream.isReady()); + serverStream.writeMessage(methodDescriptor.streamResponse("Hi from server")); + serverStream.flush(); + message = clientStreamListener.messageQueue.poll(TIMEOUT_MS, TimeUnit.MILLISECONDS); + assertEquals("Hi from server", methodDescriptor.parseResponse(message)); + assertAssumedMessageSize(serverStreamTracer1, clientStreamTracer1); + message.close(); + Status status = Status.OK.withDescription("That was normal"); + serverStream.close(status, new Metadata()); + } + + private void assertAssumedMessageSize( + TestStreamTracer streamTracerSender, TestStreamTracer streamTracerReceiver) { + if (isEnabledSupportTracingMessageSizes()) { + Assert.assertEquals(TEST_MESSAGE_LENGTH, streamTracerSender.getOutboundWireSize()); + Assert.assertEquals(TEST_MESSAGE_LENGTH, streamTracerSender.getOutboundUncompressedSize()); + Assert.assertEquals(TEST_MESSAGE_LENGTH, streamTracerReceiver.getInboundWireSize()); + Assert.assertEquals(TEST_MESSAGE_LENGTH, streamTracerReceiver.getInboundUncompressedSize()); + } + } } diff --git a/inprocess/src/test/java/io/grpc/inprocess/StandaloneInProcessTransportTest.java b/inprocess/src/test/java/io/grpc/inprocess/StandaloneInProcessTransportTest.java deleted file mode 100644 index b1d80d53b8b..00000000000 --- a/inprocess/src/test/java/io/grpc/inprocess/StandaloneInProcessTransportTest.java +++ /dev/null @@ -1,171 +0,0 @@ -/* - * Copyright 2020 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.inprocess; - -import io.grpc.InternalChannelz.SocketStats; -import io.grpc.InternalInstrumented; -import io.grpc.ServerStreamTracer; -import io.grpc.internal.AbstractTransportTest; -import io.grpc.internal.GrpcUtil; -import io.grpc.internal.InternalServer; -import io.grpc.internal.ManagedClientTransport; -import io.grpc.internal.ObjectPool; -import io.grpc.internal.ServerListener; -import io.grpc.internal.ServerTransport; -import io.grpc.internal.ServerTransportListener; -import io.grpc.internal.SharedResourcePool; -import java.io.IOException; -import java.net.SocketAddress; -import java.util.Collections; -import java.util.List; -import java.util.concurrent.ScheduledExecutorService; -import javax.annotation.Nullable; -import org.junit.Ignore; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** Unit tests for {@link InProcessTransport} when used with a separate {@link InternalServer}. */ -@RunWith(JUnit4.class) -public final class StandaloneInProcessTransportTest extends AbstractTransportTest { - private static final String TRANSPORT_NAME = "perfect-for-testing"; - private static final String AUTHORITY = "a-testing-authority"; - private static final String USER_AGENT = "a-testing-user-agent"; - - private final ObjectPool schedulerPool = - SharedResourcePool.forResource(GrpcUtil.TIMER_SERVICE); - - private TestServer currentServer; - - @Override - protected InternalServer newServer( - List streamTracerFactories) { - return new TestServer(streamTracerFactories); - } - - @Override - protected InternalServer newServer( - int port, List streamTracerFactories) { - return newServer(streamTracerFactories); - } - - @Override - protected String testAuthority(InternalServer server) { - return AUTHORITY; - } - - @Override - protected ManagedClientTransport newClientTransport(InternalServer server) { - TestServer testServer = (TestServer) server; - return InternalInProcess.createInProcessTransport( - TRANSPORT_NAME, - GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, - testAuthority(server), - USER_AGENT, - eagAttrs(), - schedulerPool, - testServer.streamTracerFactories, - testServer.serverListener, - false); - } - - @Override - protected boolean sizesReported() { - // TODO(zhangkun83): InProcessTransport doesn't record metrics for now - // (https://github.com/grpc/grpc-java/issues/2284) - return false; - } - - @Test - @Ignore - @Override - public void socketStats() throws Exception { - // test does not apply to in-process - } - - /** An internalserver just for this test. */ - private final class TestServer implements InternalServer { - - final List streamTracerFactories; - ServerListener serverListener; - - TestServer(List streamTracerFactories) { - this.streamTracerFactories = streamTracerFactories; - } - - @Override - public void start(ServerListener serverListener) throws IOException { - if (currentServer != null) { - throw new IOException("Server already present"); - } - currentServer = this; - this.serverListener = new ServerListenerWrapper(serverListener); - } - - @Override - public void shutdown() { - currentServer = null; - serverListener.serverShutdown(); - } - - @Override - public SocketAddress getListenSocketAddress() { - return new SocketAddress() {}; - } - - @Override - public List getListenSocketAddresses() { - return Collections.singletonList(getListenSocketAddress()); - } - - @Override - @Nullable - public InternalInstrumented getListenSocketStats() { - return null; - } - - @Override - @Nullable - public List> getListenSocketStatsList() { - return null; - } - } - - /** Wraps the server listener to ensure we don't accept new transports after shutdown. */ - private static final class ServerListenerWrapper implements ServerListener { - private final ServerListener delegateListener; - private boolean shutdown; - - ServerListenerWrapper(ServerListener delegateListener) { - this.delegateListener = delegateListener; - } - - @Override - public ServerTransportListener transportCreated(ServerTransport transport) { - if (shutdown) { - return null; - } - return delegateListener.transportCreated(transport); - } - - @Override - public void serverShutdown() { - shutdown = true; - delegateListener.serverShutdown(); - } - } -} diff --git a/interop-testing/build.gradle b/interop-testing/build.gradle index 88606ea08e5..5160759460c 100644 --- a/interop-testing/build.gradle +++ b/interop-testing/build.gradle @@ -13,11 +13,10 @@ dependencies { implementation project(path: ':grpc-alts', configuration: 'shadow'), project(':grpc-auth'), project(':grpc-census'), - project(':grpc-core'), - project(':grpc-googleapis'), + project(':grpc-opentelemetry'), + project(':grpc-gcp-csm-observability'), project(':grpc-netty'), project(':grpc-okhttp'), - project(':grpc-rls'), project(':grpc-services'), project(':grpc-testing'), project(':grpc-protobuf-lite'), @@ -26,12 +25,12 @@ dependencies { libraries.truth, libraries.opencensus.contrib.grpc.metrics, libraries.google.auth.oauth2Http, + libraries.opentelemetry.sdk.extension.autoconfigure, libraries.guava.jre // Fix checkUpperBoundDeps using -android api project(':grpc-api'), project(':grpc-stub'), project(':grpc-protobuf'), libraries.junit - compileOnly libraries.javax.annotation // TODO(sergiitk): replace with com.google.cloud:google-cloud-logging // Used instead of google-cloud-logging because it's failing // due to a circular dependency on grpc. @@ -42,6 +41,8 @@ dependencies { runtimeOnly libraries.opencensus.impl, libraries.netty.tcnative, libraries.netty.tcnative.classes, + libraries.opentelemetry.exporter.prometheus, // For xds interop client + project(':grpc-googleapis'), project(':grpc-grpclb'), project(':grpc-rls') testImplementation testFixtures(project(':grpc-api')), @@ -51,8 +52,16 @@ dependencies { libraries.mockito.core, libraries.okhttp - signature libraries.signature.java - signature libraries.signature.android + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } + signature (libraries.signature.android) { + artifact { + extension = "signature" + } + } } configureProtoCompilation() @@ -159,7 +168,9 @@ application { from(xds_test_client) from(xds_test_server) from(xds_federation_test_client) - fileMode = 0755 + filePermissions { + unix(0755) + } } } diff --git a/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/LoadBalancerStatsServiceGrpc.java b/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/LoadBalancerStatsServiceGrpc.java index 2f4dc69c0c6..22c64d12f33 100644 --- a/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/LoadBalancerStatsServiceGrpc.java +++ b/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/LoadBalancerStatsServiceGrpc.java @@ -7,9 +7,6 @@ * A service used to obtain stats for verifying LB behavior. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class LoadBalancerStatsServiceGrpc { @@ -94,6 +91,21 @@ public LoadBalancerStatsServiceStub newStub(io.grpc.Channel channel, io.grpc.Cal return LoadBalancerStatsServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static LoadBalancerStatsServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public LoadBalancerStatsServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new LoadBalancerStatsServiceBlockingV2Stub(channel, callOptions); + } + }; + return LoadBalancerStatsServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -214,6 +226,46 @@ public void getClientAccumulatedStats(io.grpc.testing.integration.Messages.LoadB * A service used to obtain stats for verifying LB behavior. * */ + public static final class LoadBalancerStatsServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private LoadBalancerStatsServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected LoadBalancerStatsServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new LoadBalancerStatsServiceBlockingV2Stub(channel, callOptions); + } + + /** + *

+     * Gets the backend distribution for RPCs sent by a test client.
+     * 
+ */ + public io.grpc.testing.integration.Messages.LoadBalancerStatsResponse getClientStats(io.grpc.testing.integration.Messages.LoadBalancerStatsRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getGetClientStatsMethod(), getCallOptions(), request); + } + + /** + *
+     * Gets the accumulated stats for RPCs sent by a test client.
+     * 
+ */ + public io.grpc.testing.integration.Messages.LoadBalancerAccumulatedStatsResponse getClientAccumulatedStats(io.grpc.testing.integration.Messages.LoadBalancerAccumulatedStatsRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getGetClientAccumulatedStatsMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service LoadBalancerStatsService. + *
+   * A service used to obtain stats for verifying LB behavior.
+   * 
+ */ public static final class LoadBalancerStatsServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private LoadBalancerStatsServiceBlockingStub( diff --git a/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/MetricsServiceGrpc.java b/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/MetricsServiceGrpc.java index 1650365bd52..980dee010f1 100644 --- a/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/MetricsServiceGrpc.java +++ b/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/MetricsServiceGrpc.java @@ -4,9 +4,6 @@ /** */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/metrics.proto") @io.grpc.stub.annotations.GrpcGenerated public final class MetricsServiceGrpc { @@ -91,6 +88,21 @@ public MetricsServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions c return MetricsServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static MetricsServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public MetricsServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new MetricsServiceBlockingV2Stub(channel, callOptions); + } + }; + return MetricsServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -201,6 +213,46 @@ public void getGauge(io.grpc.testing.integration.Metrics.GaugeRequest request, /** * A stub to allow clients to do synchronous rpc calls to service MetricsService. */ + public static final class MetricsServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private MetricsServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected MetricsServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new MetricsServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * Returns the values of all the gauges that are currently being maintained by
+     * the service
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + getAllGauges(io.grpc.testing.integration.Metrics.EmptyMessage request) { + return io.grpc.stub.ClientCalls.blockingV2ServerStreamingCall( + getChannel(), getGetAllGaugesMethod(), getCallOptions(), request); + } + + /** + *
+     * Returns the value of one gauge
+     * 
+ */ + public io.grpc.testing.integration.Metrics.GaugeResponse getGauge(io.grpc.testing.integration.Metrics.GaugeRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getGetGaugeMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service MetricsService. + */ public static final class MetricsServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private MetricsServiceBlockingStub( diff --git a/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java b/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java index d1887ee83c4..05d46ce8e95 100644 --- a/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java +++ b/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/ReconnectServiceGrpc.java @@ -7,9 +7,6 @@ * A service used to control reconnect server. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class ReconnectServiceGrpc { @@ -94,6 +91,21 @@ public ReconnectServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions return ReconnectServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static ReconnectServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public ReconnectServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ReconnectServiceBlockingV2Stub(channel, callOptions); + } + }; + return ReconnectServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -202,6 +214,40 @@ public void stop(io.grpc.testing.integration.EmptyProtos.Empty request, * A service used to control reconnect server. * */ + public static final class ReconnectServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private ReconnectServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected ReconnectServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ReconnectServiceBlockingV2Stub(channel, callOptions); + } + + /** + */ + public io.grpc.testing.integration.EmptyProtos.Empty start(io.grpc.testing.integration.Messages.ReconnectParams request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getStartMethod(), getCallOptions(), request); + } + + /** + */ + public io.grpc.testing.integration.Messages.ReconnectInfo stop(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getStopMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service ReconnectService. + *
+   * A service used to control reconnect server.
+   * 
+ */ public static final class ReconnectServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private ReconnectServiceBlockingStub( diff --git a/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/TestServiceGrpc.java b/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/TestServiceGrpc.java index 08071a3b653..a881c85c150 100644 --- a/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/TestServiceGrpc.java +++ b/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/TestServiceGrpc.java @@ -8,9 +8,6 @@ * performance with various types of payload. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class TestServiceGrpc { @@ -281,6 +278,21 @@ public TestServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions call return TestServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static TestServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public TestServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new TestServiceBlockingV2Stub(channel, callOptions); + } + }; + return TestServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -551,6 +563,125 @@ public void unimplementedCall(io.grpc.testing.integration.EmptyProtos.Empty requ * performance with various types of payload. * */ + public static final class TestServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private TestServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected TestServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new TestServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * One empty request followed by one empty response.
+     * 
+ */ + public io.grpc.testing.integration.EmptyProtos.Empty emptyCall(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getEmptyCallMethod(), getCallOptions(), request); + } + + /** + *
+     * One request followed by one response.
+     * 
+ */ + public io.grpc.testing.integration.Messages.SimpleResponse unaryCall(io.grpc.testing.integration.Messages.SimpleRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getUnaryCallMethod(), getCallOptions(), request); + } + + /** + *
+     * One request followed by one response. Response has cache control
+     * headers set such that a caching HTTP proxy (such as GFE) can
+     * satisfy subsequent requests.
+     * 
+ */ + public io.grpc.testing.integration.Messages.SimpleResponse cacheableUnaryCall(io.grpc.testing.integration.Messages.SimpleRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getCacheableUnaryCallMethod(), getCallOptions(), request); + } + + /** + *
+     * One request followed by a sequence of responses (streamed download).
+     * The server returns the payload with client desired type and sizes.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamingOutputCall(io.grpc.testing.integration.Messages.StreamingOutputCallRequest request) { + return io.grpc.stub.ClientCalls.blockingV2ServerStreamingCall( + getChannel(), getStreamingOutputCallMethod(), getCallOptions(), request); + } + + /** + *
+     * A sequence of requests followed by one response (streamed upload).
+     * The server returns the aggregated size of client payload as the result.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamingInputCall() { + return io.grpc.stub.ClientCalls.blockingClientStreamingCall( + getChannel(), getStreamingInputCallMethod(), getCallOptions()); + } + + /** + *
+     * A sequence of requests with each request served by the server immediately.
+     * As one request could lead to multiple responses, this interface
+     * demonstrates the idea of full duplexing.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + fullDuplexCall() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getFullDuplexCallMethod(), getCallOptions()); + } + + /** + *
+     * A sequence of requests followed by a sequence of responses.
+     * The server buffers all the client requests and then serves them in order. A
+     * stream of responses are returned to the client when the server starts with
+     * first request.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + halfDuplexCall() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getHalfDuplexCallMethod(), getCallOptions()); + } + + /** + *
+     * The test server will not implement this method. It will be used
+     * to test the behavior when clients call unimplemented methods.
+     * 
+ */ + public io.grpc.testing.integration.EmptyProtos.Empty unimplementedCall(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getUnimplementedCallMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service TestService. + *
+   * A simple service to test the various types of RPCs and experiment with
+   * performance with various types of payload.
+   * 
+ */ public static final class TestServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private TestServiceBlockingStub( diff --git a/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/UnimplementedServiceGrpc.java b/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/UnimplementedServiceGrpc.java index 9711386185e..fdd8d5650ed 100644 --- a/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/UnimplementedServiceGrpc.java +++ b/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/UnimplementedServiceGrpc.java @@ -8,9 +8,6 @@ * that case. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class UnimplementedServiceGrpc { @@ -64,6 +61,21 @@ public UnimplementedServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOpt return UnimplementedServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static UnimplementedServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public UnimplementedServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new UnimplementedServiceBlockingV2Stub(channel, callOptions); + } + }; + return UnimplementedServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -167,6 +179,37 @@ public void unimplementedCall(io.grpc.testing.integration.EmptyProtos.Empty requ * that case. * */ + public static final class UnimplementedServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private UnimplementedServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected UnimplementedServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new UnimplementedServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * A call that no server should implement
+     * 
+ */ + public io.grpc.testing.integration.EmptyProtos.Empty unimplementedCall(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getUnimplementedCallMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service UnimplementedService. + *
+   * A simple service NOT implemented at servers so clients can test for
+   * that case.
+   * 
+ */ public static final class UnimplementedServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private UnimplementedServiceBlockingStub( diff --git a/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/XdsUpdateClientConfigureServiceGrpc.java b/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/XdsUpdateClientConfigureServiceGrpc.java index 164119a29e7..6c019efefea 100644 --- a/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/XdsUpdateClientConfigureServiceGrpc.java +++ b/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/XdsUpdateClientConfigureServiceGrpc.java @@ -7,9 +7,6 @@ * A service to dynamically update the configuration of an xDS test client. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class XdsUpdateClientConfigureServiceGrpc { @@ -63,6 +60,21 @@ public XdsUpdateClientConfigureServiceStub newStub(io.grpc.Channel channel, io.g return XdsUpdateClientConfigureServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static XdsUpdateClientConfigureServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public XdsUpdateClientConfigureServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new XdsUpdateClientConfigureServiceBlockingV2Stub(channel, callOptions); + } + }; + return XdsUpdateClientConfigureServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -162,6 +174,36 @@ public void configure(io.grpc.testing.integration.Messages.ClientConfigureReques * A service to dynamically update the configuration of an xDS test client. * */ + public static final class XdsUpdateClientConfigureServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private XdsUpdateClientConfigureServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected XdsUpdateClientConfigureServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new XdsUpdateClientConfigureServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * Update the tes client's configuration.
+     * 
+ */ + public io.grpc.testing.integration.Messages.ClientConfigureResponse configure(io.grpc.testing.integration.Messages.ClientConfigureRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getConfigureMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service XdsUpdateClientConfigureService. + *
+   * A service to dynamically update the configuration of an xDS test client.
+   * 
+ */ public static final class XdsUpdateClientConfigureServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private XdsUpdateClientConfigureServiceBlockingStub( diff --git a/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/XdsUpdateHealthServiceGrpc.java b/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/XdsUpdateHealthServiceGrpc.java index dccd23ccbee..5531033ae5c 100644 --- a/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/XdsUpdateHealthServiceGrpc.java +++ b/interop-testing/src/generated/main/grpc/io/grpc/testing/integration/XdsUpdateHealthServiceGrpc.java @@ -7,9 +7,6 @@ * A service to remotely control health status of an xDS test server. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/testing/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class XdsUpdateHealthServiceGrpc { @@ -94,6 +91,21 @@ public XdsUpdateHealthServiceStub newStub(io.grpc.Channel channel, io.grpc.CallO return XdsUpdateHealthServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static XdsUpdateHealthServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public XdsUpdateHealthServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new XdsUpdateHealthServiceBlockingV2Stub(channel, callOptions); + } + }; + return XdsUpdateHealthServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -202,6 +214,40 @@ public void setNotServing(io.grpc.testing.integration.EmptyProtos.Empty request, * A service to remotely control health status of an xDS test server. * */ + public static final class XdsUpdateHealthServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private XdsUpdateHealthServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected XdsUpdateHealthServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new XdsUpdateHealthServiceBlockingV2Stub(channel, callOptions); + } + + /** + */ + public io.grpc.testing.integration.EmptyProtos.Empty setServing(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getSetServingMethod(), getCallOptions(), request); + } + + /** + */ + public io.grpc.testing.integration.EmptyProtos.Empty setNotServing(io.grpc.testing.integration.EmptyProtos.Empty request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getSetNotServingMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service XdsUpdateHealthService. + *
+   * A service to remotely control health status of an xDS test server.
+   * 
+ */ public static final class XdsUpdateHealthServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private XdsUpdateHealthServiceBlockingStub( diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java index d450ece7bcf..51295281a90 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java @@ -21,18 +21,12 @@ import static io.grpc.stub.ClientCalls.blockingServerStreamingCall; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; -import com.google.auth.oauth2.AccessToken; -import com.google.auth.oauth2.ComputeEngineCredentials; -import com.google.auth.oauth2.GoogleCredentials; -import com.google.auth.oauth2.OAuth2Credentials; -import com.google.auth.oauth2.ServiceAccountCredentials; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; @@ -45,7 +39,6 @@ import io.grpc.Channel; import io.grpc.ClientCall; import io.grpc.ClientInterceptor; -import io.grpc.ClientInterceptors; import io.grpc.ClientStreamTracer; import io.grpc.Context; import io.grpc.Grpc; @@ -62,7 +55,6 @@ import io.grpc.ServerStreamTracer; import io.grpc.Status; import io.grpc.StatusRuntimeException; -import io.grpc.auth.MoreCallCredentials; import io.grpc.census.InternalCensusStatsAccessor; import io.grpc.census.internal.DeprecatedCensusConstants; import io.grpc.internal.GrpcUtil; @@ -77,7 +69,6 @@ import io.grpc.internal.testing.TestServerStreamTracer; import io.grpc.internal.testing.TestStreamTracer; import io.grpc.stub.ClientCallStreamObserver; -import io.grpc.stub.ClientCalls; import io.grpc.stub.MetadataUtils; import io.grpc.stub.StreamObserver; import io.grpc.testing.TestUtils; @@ -92,7 +83,6 @@ import io.grpc.testing.integration.Messages.StreamingInputCallResponse; import io.grpc.testing.integration.Messages.StreamingOutputCallRequest; import io.grpc.testing.integration.Messages.StreamingOutputCallResponse; -import io.grpc.testing.integration.Messages.TestOrcaReport; import io.opencensus.contrib.grpc.metrics.RpcMeasureConstants; import io.opencensus.stats.Measure; import io.opencensus.stats.Measure.MeasureDouble; @@ -118,7 +108,6 @@ import java.util.Locale; import java.util.Map; import java.util.concurrent.ArrayBlockingQueue; -import java.util.concurrent.BlockingQueue; import java.util.concurrent.Executors; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ScheduledExecutorService; @@ -130,7 +119,6 @@ import javax.annotation.Nullable; import javax.net.ssl.SSLPeerUnverifiedException; import javax.net.ssl.SSLSession; -import org.HdrHistogram.Histogram; import org.junit.After; import org.junit.Assert; import org.junit.Assume; @@ -144,7 +132,7 @@ /** * Abstract base class for all GRPC transport tests. * - *

New tests should avoid using Mockito to support running on AppEngine.

+ *

New tests should avoid using Mockito to support running on AppEngine. */ public abstract class AbstractInteropTest { private static Logger logger = Logger.getLogger(AbstractInteropTest.class.getName()); @@ -191,11 +179,6 @@ public abstract class AbstractInteropTest { private final LinkedBlockingQueue serverStreamTracers = new LinkedBlockingQueue<>(); - static final CallOptions.Key> - ORCA_RPC_REPORT_KEY = CallOptions.Key.create("orca-rpc-report"); - static final CallOptions.Key> - ORCA_OOB_REPORT_KEY = CallOptions.Key.create("orca-oob-report"); - private static final class ServerStreamTracerInfo { final String fullMethodName; final InteropServerStreamTracer tracer; @@ -451,47 +434,6 @@ public void emptyUnaryWithRetriableStream() throws Exception { assertEquals(EMPTY, TestServiceGrpc.newBlockingStub(channel).emptyCall(EMPTY)); } - /** Sends a cacheable unary rpc using GET. Requires that the server is behind a caching proxy. */ - public void cacheableUnary() { - // THIS TEST IS BROKEN. Enabling safe just on the MethodDescriptor does nothing by itself. This - // test would need to enable GET on the channel. - // Set safe to true. - MethodDescriptor safeCacheableUnaryCallMethod = - TestServiceGrpc.getCacheableUnaryCallMethod().toBuilder().setSafe(true).build(); - // Set fake user IP since some proxies (GFE) won't cache requests from localhost. - Metadata.Key userIpKey = Metadata.Key.of("x-user-ip", Metadata.ASCII_STRING_MARSHALLER); - Metadata metadata = new Metadata(); - metadata.put(userIpKey, "1.2.3.4"); - Channel channelWithUserIpKey = - ClientInterceptors.intercept(channel, MetadataUtils.newAttachHeadersInterceptor(metadata)); - SimpleRequest requests1And2 = - SimpleRequest.newBuilder() - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFromUtf8(String.valueOf(System.nanoTime())))) - .build(); - SimpleRequest request3 = - SimpleRequest.newBuilder() - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFromUtf8(String.valueOf(System.nanoTime())))) - .build(); - - SimpleResponse response1 = - ClientCalls.blockingUnaryCall( - channelWithUserIpKey, safeCacheableUnaryCallMethod, CallOptions.DEFAULT, requests1And2); - SimpleResponse response2 = - ClientCalls.blockingUnaryCall( - channelWithUserIpKey, safeCacheableUnaryCallMethod, CallOptions.DEFAULT, requests1And2); - SimpleResponse response3 = - ClientCalls.blockingUnaryCall( - channelWithUserIpKey, safeCacheableUnaryCallMethod, CallOptions.DEFAULT, request3); - - assertEquals(response1, response2); - assertNotEquals(response1, response3); - // THIS TEST IS BROKEN. See comment at start of method. - } - @Test public void largeUnary() throws Exception { assumeEnoughMemory(); @@ -541,7 +483,7 @@ public void clientCompressedUnary(boolean probe) throws Exception { blockingStub.unaryCall(expectCompressedRequest); fail("expected INVALID_ARGUMENT"); } catch (StatusRuntimeException e) { - assertEquals(Status.INVALID_ARGUMENT.getCode(), e.getStatus().getCode()); + assertCodeEquals(Status.Code.INVALID_ARGUMENT, e.getStatus()); } assertStatsTrace("grpc.testing.TestService/UnaryCall", Status.Code.INVALID_ARGUMENT); } @@ -603,26 +545,6 @@ public void serverCompressedUnary() throws Exception { Collections.singleton(goldenResponse)); } - /** - * Assuming "pick_first" policy is used, tests that all requests are sent to the same server. - */ - public void pickFirstUnary() throws Exception { - SimpleRequest request = SimpleRequest.newBuilder() - .setResponseSize(1) - .setFillServerId(true) - .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(new byte[1]))) - .build(); - - SimpleResponse firstResponse = blockingStub.unaryCall(request); - // Increase the chance of all servers are connected, in case the channel should be doing - // round_robin instead. - Thread.sleep(5000); - for (int i = 0; i < 100; i++) { - SimpleResponse response = blockingStub.unaryCall(request); - assertThat(response.getServerId()).isEqualTo(firstResponse.getServerId()); - } - } - @Test public void serverStreaming() throws Exception { final StreamingOutputCallRequest request = StreamingOutputCallRequest.newBuilder() @@ -730,7 +652,7 @@ public void clientCompressedStreaming(boolean probe) throws Exception { responseObserver.awaitCompletion(operationTimeoutMillis(), TimeUnit.MILLISECONDS); Throwable e = responseObserver.getError(); assertNotNull("expected INVALID_ARGUMENT", e); - assertEquals(Status.INVALID_ARGUMENT.getCode(), Status.fromThrowable(e).getCode()); + assertCodeEquals(Status.Code.INVALID_ARGUMENT, Status.fromThrowable(e)); } // Start a new stream @@ -879,8 +801,7 @@ public void cancelAfterBegin() throws Exception { requestObserver.onError(new RuntimeException()); responseObserver.awaitCompletion(); assertEquals(Arrays.asList(), responseObserver.getValues()); - assertEquals(Status.Code.CANCELLED, - Status.fromThrowable(responseObserver.getError()).getCode()); + assertCodeEquals(Status.Code.CANCELLED, Status.fromThrowable(responseObserver.getError())); if (metricsExpected()) { MetricsRecord clientStartRecord = clientStatsRecorder.pollRecord(5, TimeUnit.SECONDS); @@ -917,8 +838,7 @@ public void cancelAfterFirstResponse() throws Exception { requestObserver.onError(new RuntimeException()); responseObserver.awaitCompletion(operationTimeoutMillis(), TimeUnit.MILLISECONDS); assertEquals(1, responseObserver.getValues().size()); - assertEquals(Status.Code.CANCELLED, - Status.fromThrowable(responseObserver.getError()).getCode()); + assertCodeEquals(Status.Code.CANCELLED, Status.fromThrowable(responseObserver.getError())); assertStatsTrace("grpc.testing.TestService/FullDuplexCall", Status.Code.CANCELLED); } @@ -1185,12 +1105,12 @@ public void deadlineExceeded() throws Exception { stub.streamingOutputCall(request).next(); fail("Expected deadline to be exceeded"); } catch (StatusRuntimeException ex) { - assertEquals(Status.DEADLINE_EXCEEDED.getCode(), ex.getStatus().getCode()); + assertCodeEquals(Status.Code.DEADLINE_EXCEEDED, ex.getStatus()); String desc = ex.getStatus().getDescription(); assertTrue(desc, // There is a race between client and server-side deadline expiration. // If client expires first, it'd generate this message - Pattern.matches("deadline exceeded after .*s. \\[.*\\]", desc) + Pattern.matches("CallOptions deadline exceeded after .*s. \\[.*\\]", desc) // If server expires first, it'd reset the stream and client would generate a different // message || desc.startsWith("ClientCall was cancelled at or after deadline.")); @@ -1231,8 +1151,7 @@ public void deadlineExceededServerStreaming() throws Exception { .withDeadlineAfter(30, TimeUnit.MILLISECONDS) .streamingOutputCall(request, recorder); recorder.awaitCompletion(); - assertEquals(Status.DEADLINE_EXCEEDED.getCode(), - Status.fromThrowable(recorder.getError()).getCode()); + assertCodeEquals(Status.Code.DEADLINE_EXCEEDED, Status.fromThrowable(recorder.getError())); if (metricsExpected()) { // Stream may not have been created when deadline is exceeded, thus we don't check tracer // stats. @@ -1257,7 +1176,7 @@ public void deadlineInPast() throws Exception { .emptyCall(Empty.getDefaultInstance()); fail("Should have thrown"); } catch (StatusRuntimeException ex) { - assertEquals(Status.Code.DEADLINE_EXCEEDED, ex.getStatus().getCode()); + assertCodeEquals(Status.Code.DEADLINE_EXCEEDED, ex.getStatus()); assertThat(ex.getStatus().getDescription()) .startsWith("ClientCall started after CallOptions deadline was exceeded"); } @@ -1290,7 +1209,7 @@ public void deadlineInPast() throws Exception { .emptyCall(Empty.getDefaultInstance()); fail("Should have thrown"); } catch (StatusRuntimeException ex) { - assertEquals(Status.Code.DEADLINE_EXCEEDED, ex.getStatus().getCode()); + assertCodeEquals(Status.Code.DEADLINE_EXCEEDED, ex.getStatus()); assertThat(ex.getStatus().getDescription()) .startsWith("ClientCall started after CallOptions deadline was exceeded"); } @@ -1356,8 +1275,7 @@ public void maxInboundSize_tooBig() { stub.streamingOutputCall(request).next(); fail(); } catch (StatusRuntimeException ex) { - Status s = ex.getStatus(); - assertWithMessage(s.toString()).that(s.getCode()).isEqualTo(Status.Code.RESOURCE_EXHAUSTED); + assertCodeEquals(Status.Code.RESOURCE_EXHAUSTED, ex.getStatus()); assertThat(Throwables.getStackTraceAsString(ex)).contains("exceeds maximum"); } } @@ -1412,8 +1330,7 @@ public void maxOutboundSize_tooBig() { stub.streamingOutputCall(request).next(); fail(); } catch (StatusRuntimeException ex) { - Status s = ex.getStatus(); - assertWithMessage(s.toString()).that(s.getCode()).isEqualTo(Status.Code.CANCELLED); + assertCodeEquals(Status.Code.CANCELLED, ex.getStatus()); assertThat(Throwables.getStackTraceAsString(ex)).contains("message too large"); } } @@ -1635,7 +1552,7 @@ public void statusCodeAndMessage() throws Exception { blockingStub.unaryCall(simpleRequest); fail(); } catch (StatusRuntimeException e) { - assertEquals(Status.UNKNOWN.getCode(), e.getStatus().getCode()); + assertCodeEquals(Status.Code.UNKNOWN, e.getStatus()); assertEquals(errorMessage, e.getStatus().getDescription()); } assertStatsTrace("grpc.testing.TestService/UnaryCall", Status.Code.UNKNOWN); @@ -1651,7 +1568,7 @@ public void statusCodeAndMessage() throws Exception { .isTrue(); assertThat(responseObserver.getError()).isNotNull(); Status status = Status.fromThrowable(responseObserver.getError()); - assertEquals(Status.UNKNOWN.getCode(), status.getCode()); + assertCodeEquals(Status.Code.UNKNOWN, status); assertEquals(errorMessage, status.getDescription()); assertStatsTrace("grpc.testing.TestService/FullDuplexCall", Status.Code.UNKNOWN); } @@ -1671,7 +1588,7 @@ public void specialStatusMessage() throws Exception { blockingStub.unaryCall(simpleRequest); fail(); } catch (StatusRuntimeException e) { - assertEquals(Status.UNKNOWN.getCode(), e.getStatus().getCode()); + assertCodeEquals(Status.Code.UNKNOWN, e.getStatus()); assertEquals(errorMessage, e.getStatus().getDescription()); } assertStatsTrace("grpc.testing.TestService/UnaryCall", Status.Code.UNKNOWN); @@ -1684,7 +1601,7 @@ public void unimplementedMethod() { blockingStub.unimplementedCall(Empty.getDefaultInstance()); fail(); } catch (StatusRuntimeException e) { - assertEquals(Status.UNIMPLEMENTED.getCode(), e.getStatus().getCode()); + assertCodeEquals(Status.Code.UNIMPLEMENTED, e.getStatus()); } assertClientStatsTrace("grpc.testing.TestService/UnimplementedCall", @@ -1700,7 +1617,7 @@ public void unimplementedService() { stub.unimplementedCall(Empty.getDefaultInstance()); fail(); } catch (StatusRuntimeException e) { - assertEquals(Status.UNIMPLEMENTED.getCode(), e.getStatus().getCode()); + assertCodeEquals(Status.Code.UNIMPLEMENTED, e.getStatus()); } assertStatsTrace("grpc.testing.UnimplementedService/UnimplementedCall", @@ -1708,7 +1625,6 @@ public void unimplementedService() { } /** Start a fullDuplexCall which the server will not respond, and verify the deadline expires. */ - @SuppressWarnings("MissingFail") @Test public void timeoutOnSleepingServer() throws Exception { TestServiceGrpc.TestServiceStub stub = @@ -1718,20 +1634,15 @@ public void timeoutOnSleepingServer() throws Exception { StreamObserver requestObserver = stub.fullDuplexCall(responseObserver); - StreamingOutputCallRequest request = StreamingOutputCallRequest.newBuilder() + requestObserver.onNext(StreamingOutputCallRequest.newBuilder() .setPayload(Payload.newBuilder() .setBody(ByteString.copyFrom(new byte[27182]))) - .build(); - try { - requestObserver.onNext(request); - } catch (IllegalStateException expected) { - // This can happen if the stream has already been terminated due to deadline exceeded. - } + .build()); assertTrue(responseObserver.awaitCompletion(operationTimeoutMillis(), TimeUnit.MILLISECONDS)); assertEquals(0, responseObserver.getValues().size()); - assertEquals(Status.DEADLINE_EXCEEDED.getCode(), - Status.fromThrowable(responseObserver.getError()).getCode()); + assertCodeEquals( + Status.Code.DEADLINE_EXCEEDED, Status.fromThrowable(responseObserver.getError())); if (metricsExpected()) { // CensusStreamTracerModule record final status in the interceptor, thus is guaranteed to be @@ -1757,389 +1668,6 @@ public void getServerAddressAndLocalAddressFromClient() { assertNotNull(obtainLocalClientAddr()); } - /** - * Test backend metrics per query reporting: expect the test client LB policy to receive load - * reports. - */ - public void testOrcaPerRpc() throws Exception { - AtomicReference reportHolder = new AtomicReference<>(); - TestOrcaReport answer = TestOrcaReport.newBuilder() - .setCpuUtilization(0.8210) - .setMemoryUtilization(0.5847) - .putRequestCost("cost", 3456.32) - .putUtilization("util", 0.30499) - .build(); - blockingStub.withOption(ORCA_RPC_REPORT_KEY, reportHolder).unaryCall( - SimpleRequest.newBuilder().setOrcaPerQueryReport(answer).build()); - assertThat(reportHolder.get()).isEqualTo(answer); - } - - /** - * Test backend metrics OOB reporting: expect the test client LB policy to receive load reports. - */ - public void testOrcaOob() throws Exception { - AtomicReference reportHolder = new AtomicReference<>(); - final TestOrcaReport answer = TestOrcaReport.newBuilder() - .setCpuUtilization(0.8210) - .setMemoryUtilization(0.5847) - .putUtilization("util", 0.30499) - .build(); - final TestOrcaReport answer2 = TestOrcaReport.newBuilder() - .setCpuUtilization(0.29309) - .setMemoryUtilization(0.2) - .putUtilization("util", 0.2039) - .build(); - - final int retryLimit = 5; - BlockingQueue queue = new LinkedBlockingQueue<>(); - final Object lastItem = new Object(); - StreamObserver streamObserver = - asyncStub.fullDuplexCall(new StreamObserver() { - - @Override - public void onNext(StreamingOutputCallResponse value) { - queue.add(value); - } - - @Override - public void onError(Throwable t) { - queue.add(t); - } - - @Override - public void onCompleted() { - queue.add(lastItem); - } - }); - - streamObserver.onNext(StreamingOutputCallRequest.newBuilder() - .setOrcaOobReport(answer) - .addResponseParameters(ResponseParameters.newBuilder().setSize(1).build()).build()); - assertThat(queue.take()).isInstanceOf(StreamingOutputCallResponse.class); - int i = 0; - for (; i < retryLimit; i++) { - Thread.sleep(1000); - blockingStub.withOption(ORCA_OOB_REPORT_KEY, reportHolder).emptyCall(EMPTY); - if (answer.equals(reportHolder.get())) { - break; - } - } - assertThat(i).isLessThan(retryLimit); - streamObserver.onNext(StreamingOutputCallRequest.newBuilder() - .setOrcaOobReport(answer2) - .addResponseParameters(ResponseParameters.newBuilder().setSize(1).build()).build()); - assertThat(queue.take()).isInstanceOf(StreamingOutputCallResponse.class); - - for (i = 0; i < retryLimit; i++) { - Thread.sleep(1000); - blockingStub.withOption(ORCA_OOB_REPORT_KEY, reportHolder).emptyCall(EMPTY); - if (reportHolder.get().equals(answer2)) { - break; - } - } - assertThat(i).isLessThan(retryLimit); - streamObserver.onCompleted(); - assertThat(queue.take()).isSameInstanceAs(lastItem); - } - - /** Sends a large unary rpc with service account credentials. */ - public void serviceAccountCreds(String jsonKey, InputStream credentialsStream, String authScope) - throws Exception { - // cast to ServiceAccountCredentials to double-check the right type of object was created. - GoogleCredentials credentials = - ServiceAccountCredentials.class.cast(GoogleCredentials.fromStream(credentialsStream)); - credentials = credentials.createScoped(Arrays.asList(authScope)); - TestServiceGrpc.TestServiceBlockingStub stub = blockingStub - .withCallCredentials(MoreCallCredentials.from(credentials)); - final SimpleRequest request = SimpleRequest.newBuilder() - .setFillUsername(true) - .setFillOauthScope(true) - .setResponseSize(314159) - .setPayload(Payload.newBuilder() - .setBody(ByteString.copyFrom(new byte[271828]))) - .build(); - - final SimpleResponse response = stub.unaryCall(request); - assertFalse(response.getUsername().isEmpty()); - assertTrue("Received username: " + response.getUsername(), - jsonKey.contains(response.getUsername())); - assertFalse(response.getOauthScope().isEmpty()); - assertTrue("Received oauth scope: " + response.getOauthScope(), - authScope.contains(response.getOauthScope())); - - final SimpleResponse goldenResponse = SimpleResponse.newBuilder() - .setOauthScope(response.getOauthScope()) - .setUsername(response.getUsername()) - .setPayload(Payload.newBuilder() - .setBody(ByteString.copyFrom(new byte[314159]))) - .build(); - assertResponse(goldenResponse, response); - } - - /** Sends a large unary rpc with compute engine credentials. */ - public void computeEngineCreds(String serviceAccount, String oauthScope) throws Exception { - ComputeEngineCredentials credentials = ComputeEngineCredentials.create(); - TestServiceGrpc.TestServiceBlockingStub stub = blockingStub - .withCallCredentials(MoreCallCredentials.from(credentials)); - final SimpleRequest request = SimpleRequest.newBuilder() - .setFillUsername(true) - .setFillOauthScope(true) - .setResponseSize(314159) - .setPayload(Payload.newBuilder() - .setBody(ByteString.copyFrom(new byte[271828]))) - .build(); - - final SimpleResponse response = stub.unaryCall(request); - assertEquals(serviceAccount, response.getUsername()); - assertFalse(response.getOauthScope().isEmpty()); - assertTrue("Received oauth scope: " + response.getOauthScope(), - oauthScope.contains(response.getOauthScope())); - - final SimpleResponse goldenResponse = SimpleResponse.newBuilder() - .setOauthScope(response.getOauthScope()) - .setUsername(response.getUsername()) - .setPayload(Payload.newBuilder() - .setBody(ByteString.copyFrom(new byte[314159]))) - .build(); - assertResponse(goldenResponse, response); - } - - /** Sends an unary rpc with ComputeEngineChannelBuilder. */ - public void computeEngineChannelCredentials( - String defaultServiceAccount, - TestServiceGrpc.TestServiceBlockingStub computeEngineStub) throws Exception { - final SimpleRequest request = SimpleRequest.newBuilder() - .setFillUsername(true) - .setResponseSize(314159) - .setPayload(Payload.newBuilder() - .setBody(ByteString.copyFrom(new byte[271828]))) - .build(); - final SimpleResponse response = computeEngineStub.unaryCall(request); - assertEquals(defaultServiceAccount, response.getUsername()); - final SimpleResponse goldenResponse = SimpleResponse.newBuilder() - .setUsername(defaultServiceAccount) - .setPayload(Payload.newBuilder() - .setBody(ByteString.copyFrom(new byte[314159]))) - .build(); - assertResponse(goldenResponse, response); - } - - /** Test JWT-based auth. */ - public void jwtTokenCreds(InputStream serviceAccountJson) throws Exception { - final SimpleRequest request = SimpleRequest.newBuilder() - .setResponseSize(314159) - .setPayload(Payload.newBuilder() - .setBody(ByteString.copyFrom(new byte[271828]))) - .setFillUsername(true) - .build(); - - ServiceAccountCredentials credentials = (ServiceAccountCredentials) - GoogleCredentials.fromStream(serviceAccountJson); - TestServiceGrpc.TestServiceBlockingStub stub = blockingStub - .withCallCredentials(MoreCallCredentials.from(credentials)); - SimpleResponse response = stub.unaryCall(request); - assertEquals(credentials.getClientEmail(), response.getUsername()); - assertEquals(314159, response.getPayload().getBody().size()); - } - - /** Sends a unary rpc with raw oauth2 access token credentials. */ - public void oauth2AuthToken(String jsonKey, InputStream credentialsStream, String authScope) - throws Exception { - GoogleCredentials utilCredentials = - GoogleCredentials.fromStream(credentialsStream); - utilCredentials = utilCredentials.createScoped(Arrays.asList(authScope)); - AccessToken accessToken = utilCredentials.refreshAccessToken(); - - OAuth2Credentials credentials = OAuth2Credentials.create(accessToken); - - TestServiceGrpc.TestServiceBlockingStub stub = blockingStub - .withCallCredentials(MoreCallCredentials.from(credentials)); - final SimpleRequest request = SimpleRequest.newBuilder() - .setFillUsername(true) - .setFillOauthScope(true) - .build(); - - final SimpleResponse response = stub.unaryCall(request); - assertFalse(response.getUsername().isEmpty()); - assertTrue("Received username: " + response.getUsername(), - jsonKey.contains(response.getUsername())); - assertFalse(response.getOauthScope().isEmpty()); - assertTrue("Received oauth scope: " + response.getOauthScope(), - authScope.contains(response.getOauthScope())); - } - - /** Sends a unary rpc with "per rpc" raw oauth2 access token credentials. */ - public void perRpcCreds(String jsonKey, InputStream credentialsStream, String oauthScope) - throws Exception { - // In gRpc Java, we don't have per Rpc credentials, user can use an intercepted stub only once - // for that purpose. - // So, this test is identical to oauth2_auth_token test. - oauth2AuthToken(jsonKey, credentialsStream, oauthScope); - } - - /** Sends an unary rpc with "google default credentials". */ - public void googleDefaultCredentials( - String defaultServiceAccount, - TestServiceGrpc.TestServiceBlockingStub googleDefaultStub) throws Exception { - final SimpleRequest request = SimpleRequest.newBuilder() - .setFillUsername(true) - .setResponseSize(314159) - .setPayload(Payload.newBuilder() - .setBody(ByteString.copyFrom(new byte[271828]))) - .build(); - final SimpleResponse response = googleDefaultStub.unaryCall(request); - assertEquals(defaultServiceAccount, response.getUsername()); - - final SimpleResponse goldenResponse = SimpleResponse.newBuilder() - .setUsername(defaultServiceAccount) - .setPayload(Payload.newBuilder() - .setBody(ByteString.copyFrom(new byte[314159]))) - .build(); - assertResponse(goldenResponse, response); - } - - private static class SoakIterationResult { - public SoakIterationResult(long latencyMs, Status status) { - this.latencyMs = latencyMs; - this.status = status; - } - - public long getLatencyMs() { - return latencyMs; - } - - public Status getStatus() { - return status; - } - - private long latencyMs = -1; - private Status status = Status.OK; - } - - private SoakIterationResult performOneSoakIteration( - TestServiceGrpc.TestServiceBlockingStub soakStub, int soakRequestSize, int soakResponseSize) - throws Exception { - long startNs = System.nanoTime(); - Status status = Status.OK; - try { - final SimpleRequest request = - SimpleRequest.newBuilder() - .setResponseSize(soakResponseSize) - .setPayload( - Payload.newBuilder().setBody(ByteString.copyFrom(new byte[soakRequestSize]))) - .build(); - final SimpleResponse goldenResponse = - SimpleResponse.newBuilder() - .setPayload( - Payload.newBuilder().setBody(ByteString.copyFrom(new byte[soakResponseSize]))) - .build(); - assertResponse(goldenResponse, soakStub.unaryCall(request)); - } catch (StatusRuntimeException e) { - status = e.getStatus(); - } - long elapsedNs = System.nanoTime() - startNs; - return new SoakIterationResult(TimeUnit.NANOSECONDS.toMillis(elapsedNs), status); - } - - /** - * Runs large unary RPCs in a loop with configurable failure thresholds - * and channel creation behavior. - */ - public void performSoakTest( - String serverUri, - boolean resetChannelPerIteration, - int soakIterations, - int maxFailures, - int maxAcceptablePerIterationLatencyMs, - int minTimeMsBetweenRpcs, - int overallTimeoutSeconds, - int soakRequestSize, - int soakResponseSize) - throws Exception { - int iterationsDone = 0; - int totalFailures = 0; - Histogram latencies = new Histogram(4 /* number of significant value digits */); - long startNs = System.nanoTime(); - ManagedChannel soakChannel = createChannel(); - TestServiceGrpc.TestServiceBlockingStub soakStub = TestServiceGrpc - .newBlockingStub(soakChannel) - .withInterceptors(recordClientCallInterceptor(clientCallCapture)); - for (int i = 0; i < soakIterations; i++) { - if (System.nanoTime() - startNs >= TimeUnit.SECONDS.toNanos(overallTimeoutSeconds)) { - break; - } - long earliestNextStartNs = System.nanoTime() - + TimeUnit.MILLISECONDS.toNanos(minTimeMsBetweenRpcs); - if (resetChannelPerIteration) { - soakChannel.shutdownNow(); - soakChannel.awaitTermination(10, TimeUnit.SECONDS); - soakChannel = createChannel(); - soakStub = TestServiceGrpc - .newBlockingStub(soakChannel) - .withInterceptors(recordClientCallInterceptor(clientCallCapture)); - } - SoakIterationResult result = - performOneSoakIteration(soakStub, soakRequestSize, soakResponseSize); - SocketAddress peer = clientCallCapture - .get().getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR); - StringBuilder logStr = new StringBuilder( - String.format( - Locale.US, - "soak iteration: %d elapsed_ms: %d peer: %s server_uri: %s", - i, result.getLatencyMs(), peer != null ? peer.toString() : "null", serverUri)); - if (!result.getStatus().equals(Status.OK)) { - totalFailures++; - logStr.append(String.format(" failed: %s", result.getStatus())); - } else if (result.getLatencyMs() > maxAcceptablePerIterationLatencyMs) { - totalFailures++; - logStr.append( - " exceeds max acceptable latency: " + maxAcceptablePerIterationLatencyMs); - } else { - logStr.append(" succeeded"); - } - System.err.println(logStr.toString()); - iterationsDone++; - latencies.recordValue(result.getLatencyMs()); - long remainingNs = earliestNextStartNs - System.nanoTime(); - if (remainingNs > 0) { - TimeUnit.NANOSECONDS.sleep(remainingNs); - } - } - soakChannel.shutdownNow(); - soakChannel.awaitTermination(10, TimeUnit.SECONDS); - System.err.println( - String.format( - Locale.US, - "(server_uri: %s) soak test ran: %d / %d iterations. total failures: %d. " - + "p50: %d ms, p90: %d ms, p100: %d ms", - serverUri, - iterationsDone, - soakIterations, - totalFailures, - latencies.getValueAtPercentile(50), - latencies.getValueAtPercentile(90), - latencies.getValueAtPercentile(100))); - // check if we timed out - String timeoutErrorMessage = - String.format( - Locale.US, - "(server_uri: %s) soak test consumed all %d seconds of time and quit early, " - + "only having ran %d out of desired %d iterations.", - serverUri, - overallTimeoutSeconds, - iterationsDone, - soakIterations); - assertEquals(timeoutErrorMessage, iterationsDone, soakIterations); - // check if we had too many failures - String tooManyFailuresErrorMessage = - String.format( - Locale.US, - "(server_uri: %s) soak test total failures: %d exceeds max failures " - + "threshold: %d.", - serverUri, totalFailures, maxFailures); - assertTrue(tooManyFailuresErrorMessage, totalFailures <= maxFailures); - } - private static void assertSuccess(StreamRecorder recorder) { if (recorder.getError() != null) { throw new AssertionError(recorder.getError()); @@ -2481,7 +2009,7 @@ private void assertResponse( } } - private void assertResponse(SimpleResponse expected, SimpleResponse actual) { + public void assertResponse(SimpleResponse expected, SimpleResponse actual) { assertPayload(expected.getPayload(), actual.getPayload()); assertEquals(expected.getUsername(), actual.getUsername()); assertEquals(expected.getOauthScope(), actual.getOauthScope()); @@ -2496,6 +2024,10 @@ private void assertPayload(Payload expected, Payload actual) { } } + private static void assertCodeEquals(Status.Code expected, Status actual) { + assertWithMessage("Unexpected status: %s", actual).that(actual.getCode()).isEqualTo(expected); + } + /** * Captures the request attributes. Useful for testing ServerCalls. * {@link ServerCall#getAttributes()} diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/CustomBackendMetricsLoadBalancerProvider.java b/interop-testing/src/main/java/io/grpc/testing/integration/CustomBackendMetricsLoadBalancerProvider.java index 87ecf308674..b9a89a01e3a 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/CustomBackendMetricsLoadBalancerProvider.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/CustomBackendMetricsLoadBalancerProvider.java @@ -16,8 +16,8 @@ package io.grpc.testing.integration; -import static io.grpc.testing.integration.AbstractInteropTest.ORCA_OOB_REPORT_KEY; -import static io.grpc.testing.integration.AbstractInteropTest.ORCA_RPC_REPORT_KEY; +import static io.grpc.testing.integration.TestServiceClient.ORCA_OOB_REPORT_KEY; +import static io.grpc.testing.integration.TestServiceClient.ORCA_RPC_REPORT_KEY; import io.grpc.ConnectivityState; import io.grpc.LoadBalancer; diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/GrpclbFallbackTestClient.java b/interop-testing/src/main/java/io/grpc/testing/integration/GrpclbFallbackTestClient.java index 9fc017c0e35..8ce83f73e6d 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/GrpclbFallbackTestClient.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/GrpclbFallbackTestClient.java @@ -16,7 +16,7 @@ package io.grpc.testing.integration; -import static com.google.common.base.Charsets.UTF_8; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertEquals; import com.google.common.io.CharStreams; diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/Http2TestCases.java b/interop-testing/src/main/java/io/grpc/testing/integration/Http2TestCases.java index b064ee74243..d79c6798cc2 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/Http2TestCases.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/Http2TestCases.java @@ -17,6 +17,7 @@ package io.grpc.testing.integration; import com.google.common.base.Preconditions; +import java.util.Locale; /** * Enum of HTTP/2 interop test cases. @@ -49,7 +50,7 @@ public String description() { public static Http2TestCases fromString(String s) { Preconditions.checkNotNull(s, "s"); try { - return Http2TestCases.valueOf(s.toUpperCase()); + return Http2TestCases.valueOf(s.toUpperCase(Locale.ROOT)); } catch (IllegalArgumentException ex) { throw new IllegalArgumentException("Invalid test case: " + s); } diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/RpcBehaviorLoadBalancerProvider.java b/interop-testing/src/main/java/io/grpc/testing/integration/RpcBehaviorLoadBalancerProvider.java index 83c416765ec..f1410142bff 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/RpcBehaviorLoadBalancerProvider.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/RpcBehaviorLoadBalancerProvider.java @@ -110,12 +110,20 @@ protected LoadBalancer delegate() { return delegateLb; } + @Deprecated @Override public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { helper.setRpcBehavior( ((RpcBehaviorConfig) resolvedAddresses.getLoadBalancingPolicyConfig()).rpcBehavior); delegateLb.handleResolvedAddresses(resolvedAddresses); } + + @Override + public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + helper.setRpcBehavior( + ((RpcBehaviorConfig) resolvedAddresses.getLoadBalancingPolicyConfig()).rpcBehavior); + return delegateLb.acceptResolvedAddresses(resolvedAddresses); + } } /** diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/SoakClient.java b/interop-testing/src/main/java/io/grpc/testing/integration/SoakClient.java new file mode 100644 index 00000000000..e119c826f09 --- /dev/null +++ b/interop-testing/src/main/java/io/grpc/testing/integration/SoakClient.java @@ -0,0 +1,300 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.testing.integration; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import com.google.common.base.Function; +import com.google.protobuf.ByteString; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.Grpc; +import io.grpc.ManagedChannel; +import io.grpc.MethodDescriptor; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import io.grpc.testing.integration.Messages.Payload; +import io.grpc.testing.integration.Messages.SimpleRequest; +import io.grpc.testing.integration.Messages.SimpleResponse; +import java.net.SocketAddress; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.logging.Logger; +import org.HdrHistogram.Histogram; + +/** + * Shared implementation for rpc_soak and channel_soak. Unlike the tests in AbstractInteropTest, + * these "test cases" are only intended to be run from the command line. They don't fit the regular + * test patterns of AbstractInteropTest. + * https://github.com/grpc/grpc/blob/master/doc/interop-test-descriptions.md#rpc_soak + */ +final class SoakClient { + private static final Logger logger = Logger.getLogger(SoakClient.class.getName()); + + private static class SoakIterationResult { + public SoakIterationResult(long latencyMs, Status status) { + this.latencyMs = latencyMs; + this.status = status; + } + + public long getLatencyMs() { + return latencyMs; + } + + public Status getStatus() { + return status; + } + + private long latencyMs = -1; + private Status status = Status.OK; + } + + private static class ThreadResults { + private int threadFailures = 0; + private int iterationsDone = 0; + private Histogram latencies = new Histogram(4); + + public int getThreadFailures() { + return threadFailures; + } + + public int getIterationsDone() { + return iterationsDone; + } + + public Histogram getLatencies() { + return latencies; + } + } + + private static SoakIterationResult performOneSoakIteration( + TestServiceGrpc.TestServiceBlockingStub soakStub, int soakRequestSize, int soakResponseSize) + throws InterruptedException { + long startNs = System.nanoTime(); + Status status = Status.OK; + try { + final SimpleRequest request = + SimpleRequest.newBuilder() + .setResponseSize(soakResponseSize) + .setPayload( + Payload.newBuilder().setBody(ByteString.copyFrom(new byte[soakRequestSize]))) + .build(); + final SimpleResponse goldenResponse = + SimpleResponse.newBuilder() + .setPayload( + Payload.newBuilder().setBody(ByteString.copyFrom(new byte[soakResponseSize]))) + .build(); + assertResponse(goldenResponse, soakStub.unaryCall(request)); + } catch (StatusRuntimeException e) { + status = e.getStatus(); + } + long elapsedNs = System.nanoTime() - startNs; + return new SoakIterationResult(TimeUnit.NANOSECONDS.toMillis(elapsedNs), status); + } + + /** + * Runs large unary RPCs in a loop with configurable failure thresholds + * and channel creation behavior. + */ + public static void performSoakTest( + String serverUri, + int soakIterations, + int maxFailures, + int maxAcceptablePerIterationLatencyMs, + int minTimeMsBetweenRpcs, + int overallTimeoutSeconds, + int soakRequestSize, + int soakResponseSize, + int numThreads, + ManagedChannel sharedChannel, + Function maybeCreateChannel) + throws InterruptedException { + if (soakIterations % numThreads != 0) { + throw new IllegalArgumentException("soakIterations must be evenly divisible by numThreads."); + } + long startNs = System.nanoTime(); + Thread[] threads = new Thread[numThreads]; + int soakIterationsPerThread = soakIterations / numThreads; + List threadResultsList = new ArrayList<>(numThreads); + for (int i = 0; i < numThreads; i++) { + threadResultsList.add(new ThreadResults()); + } + for (int threadInd = 0; threadInd < numThreads; threadInd++) { + final int currentThreadInd = threadInd; + threads[threadInd] = new Thread(() -> { + try { + executeSoakTestInThread( + soakIterationsPerThread, + startNs, + minTimeMsBetweenRpcs, + soakRequestSize, + soakResponseSize, + maxAcceptablePerIterationLatencyMs, + overallTimeoutSeconds, + serverUri, + threadResultsList.get(currentThreadInd), + sharedChannel, + maybeCreateChannel); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Thread interrupted: " + e.getMessage(), e); + } + }); + threads[threadInd].start(); + } + for (Thread thread : threads) { + thread.join(); + } + + int totalFailures = 0; + int iterationsDone = 0; + Histogram latencies = new Histogram(4); + for (ThreadResults threadResult :threadResultsList) { + totalFailures += threadResult.getThreadFailures(); + iterationsDone += threadResult.getIterationsDone(); + latencies.add(threadResult.getLatencies()); + } + logger.info( + String.format( + Locale.US, + "(server_uri: %s) soak test ran: %d / %d iterations. total failures: %d. " + + "p50: %d ms, p90: %d ms, p100: %d ms", + serverUri, + iterationsDone, + soakIterations, + totalFailures, + latencies.getValueAtPercentile(50), + latencies.getValueAtPercentile(90), + latencies.getValueAtPercentile(100))); + // check if we timed out + String timeoutErrorMessage = + String.format( + Locale.US, + "(server_uri: %s) soak test consumed all %d seconds of time and quit early, " + + "only having ran %d out of desired %d iterations.", + serverUri, + overallTimeoutSeconds, + iterationsDone, + soakIterations); + assertEquals(timeoutErrorMessage, iterationsDone, soakIterations); + // check if we had too many failures + String tooManyFailuresErrorMessage = + String.format( + Locale.US, + "(server_uri: %s) soak test total failures: %d exceeds max failures " + + "threshold: %d.", + serverUri, totalFailures, maxFailures); + assertTrue(tooManyFailuresErrorMessage, totalFailures <= maxFailures); + sharedChannel.shutdownNow(); + sharedChannel.awaitTermination(10, TimeUnit.SECONDS); + } + + private static void executeSoakTestInThread( + int soakIterationsPerThread, + long startNs, + int minTimeMsBetweenRpcs, + int soakRequestSize, + int soakResponseSize, + int maxAcceptablePerIterationLatencyMs, + int overallTimeoutSeconds, + String serverUri, + ThreadResults threadResults, + ManagedChannel sharedChannel, + Function maybeCreateChannel) throws InterruptedException { + ManagedChannel currentChannel = sharedChannel; + for (int i = 0; i < soakIterationsPerThread; i++) { + if (System.nanoTime() - startNs >= TimeUnit.SECONDS.toNanos(overallTimeoutSeconds)) { + break; + } + long earliestNextStartNs = System.nanoTime() + + TimeUnit.MILLISECONDS.toNanos(minTimeMsBetweenRpcs); + // recordClientCallInterceptor takes an AtomicReference. + AtomicReference> soakThreadClientCallCapture = new AtomicReference<>(); + currentChannel = maybeCreateChannel.apply(currentChannel); + TestServiceGrpc.TestServiceBlockingStub currentStub = TestServiceGrpc + .newBlockingStub(currentChannel) + .withInterceptors(recordClientCallInterceptor(soakThreadClientCallCapture)); + SoakIterationResult result = performOneSoakIteration(currentStub, + soakRequestSize, soakResponseSize); + SocketAddress peer = soakThreadClientCallCapture + .get().getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR); + StringBuilder logStr = new StringBuilder( + String.format( + Locale.US, + "thread id: %d soak iteration: %d elapsed_ms: %d peer: %s server_uri: %s", + Thread.currentThread().getId(), + i, result.getLatencyMs(), peer != null ? peer.toString() : "null", serverUri)); + if (!result.getStatus().equals(Status.OK)) { + threadResults.threadFailures++; + logStr.append(String.format(" failed: %s", result.getStatus())); + logger.warning(logStr.toString()); + } else if (result.getLatencyMs() > maxAcceptablePerIterationLatencyMs) { + threadResults.threadFailures++; + logStr.append( + " exceeds max acceptable latency: " + maxAcceptablePerIterationLatencyMs); + logger.warning(logStr.toString()); + } else { + logStr.append(" succeeded"); + logger.info(logStr.toString()); + } + threadResults.iterationsDone++; + threadResults.getLatencies().recordValue(result.getLatencyMs()); + long remainingNs = earliestNextStartNs - System.nanoTime(); + if (remainingNs > 0) { + TimeUnit.NANOSECONDS.sleep(remainingNs); + } + } + } + + private static void assertResponse(SimpleResponse expected, SimpleResponse actual) { + assertPayload(expected.getPayload(), actual.getPayload()); + assertEquals(expected.getUsername(), actual.getUsername()); + assertEquals(expected.getOauthScope(), actual.getOauthScope()); + } + + private static void assertPayload(Payload expected, Payload actual) { + // Compare non deprecated fields in Payload, to make this test forward compatible. + if (expected == null || actual == null) { + assertEquals(expected, actual); + } else { + assertEquals(expected.getBody(), actual.getBody()); + } + } + + /** + * Captures the ClientCall. Useful for testing {@link ClientCall#getAttributes()} + */ + private static ClientInterceptor recordClientCallInterceptor( + final AtomicReference> clientCallCapture) { + return new ClientInterceptor() { + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + ClientCall clientCall = next.newCall(method,callOptions); + clientCallCapture.set(clientCall); + return clientCall; + } + }; + } + +} diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/StressTestClient.java b/interop-testing/src/main/java/io/grpc/testing/integration/StressTestClient.java index 5739f7e7469..7fafe43be05 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/StressTestClient.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/StressTestClient.java @@ -308,7 +308,9 @@ void shutdown() { } try { - metricsServer.shutdownNow(); + if (metricsServer != null) { + metricsServer.shutdownNow(); + } } catch (Throwable t) { log.log(Level.WARNING, "Error shutting down metrics service!", t); } @@ -434,7 +436,7 @@ private static String serverAddressesToString(List addresses) private static String validTestCasesHelpText() { StringBuilder builder = new StringBuilder(); for (TestCases testCase : TestCases.values()) { - String strTestcase = testCase.name().toLowerCase(); + String strTestcase = testCase.toString(); builder.append("\n ") .append(strTestcase) .append(": ") diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/TestCases.java b/interop-testing/src/main/java/io/grpc/testing/integration/TestCases.java index 85e5c31a4cb..2d16065254a 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/TestCases.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/TestCases.java @@ -17,6 +17,7 @@ package io.grpc.testing.integration; import com.google.common.base.Preconditions; +import java.util.Locale; /** * Enum of interop test cases. @@ -79,6 +80,11 @@ public String description() { */ public static TestCases fromString(String s) { Preconditions.checkNotNull(s, "s"); - return TestCases.valueOf(s.toUpperCase()); + return TestCases.valueOf(s.toUpperCase(Locale.ROOT)); + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); } } diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java index e059e81ed54..125d876b705 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java @@ -16,10 +16,25 @@ package io.grpc.testing.integration; +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertTrue; + +import com.google.auth.oauth2.AccessToken; +import com.google.auth.oauth2.ComputeEngineCredentials; +import com.google.auth.oauth2.GoogleCredentials; +import com.google.auth.oauth2.OAuth2Credentials; +import com.google.auth.oauth2.ServiceAccountCredentials; import com.google.common.annotations.VisibleForTesting; import com.google.common.io.Files; +import com.google.protobuf.ByteString; +import io.grpc.CallOptions; +import io.grpc.Channel; import io.grpc.ChannelCredentials; import io.grpc.ClientInterceptor; +import io.grpc.ClientInterceptors; import io.grpc.Grpc; import io.grpc.InsecureChannelCredentials; import io.grpc.InsecureServerCredentials; @@ -28,11 +43,13 @@ import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; import io.grpc.Metadata; +import io.grpc.MethodDescriptor; import io.grpc.ServerBuilder; import io.grpc.TlsChannelCredentials; import io.grpc.alts.AltsChannelCredentials; import io.grpc.alts.ComputeEngineChannelCredentials; import io.grpc.alts.GoogleDefaultChannelCredentials; +import io.grpc.auth.MoreCallCredentials; import io.grpc.internal.GrpcUtil; import io.grpc.internal.JsonParser; import io.grpc.netty.InsecureFromHttp1ChannelCredentials; @@ -40,13 +57,27 @@ import io.grpc.netty.NettyChannelBuilder; import io.grpc.okhttp.InternalOkHttpChannelBuilder; import io.grpc.okhttp.OkHttpChannelBuilder; +import io.grpc.stub.ClientCalls; import io.grpc.stub.MetadataUtils; +import io.grpc.stub.StreamObserver; import io.grpc.testing.TlsTesting; +import io.grpc.testing.integration.Messages.Payload; +import io.grpc.testing.integration.Messages.ResponseParameters; +import io.grpc.testing.integration.Messages.SimpleRequest; +import io.grpc.testing.integration.Messages.SimpleResponse; +import io.grpc.testing.integration.Messages.StreamingOutputCallRequest; +import io.grpc.testing.integration.Messages.StreamingOutputCallResponse; +import io.grpc.testing.integration.Messages.TestOrcaReport; import java.io.File; import java.io.FileInputStream; +import java.io.InputStream; import java.nio.charset.Charset; +import java.util.Arrays; import java.util.Map; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import javax.annotation.Nullable; /** @@ -57,6 +88,11 @@ public class TestServiceClient { private static final Charset UTF_8 = Charset.forName("UTF-8"); + static final CallOptions.Key> + ORCA_RPC_REPORT_KEY = CallOptions.Key.create("orca-rpc-report"); + static final CallOptions.Key> + ORCA_OOB_REPORT_KEY = CallOptions.Key.create("orca-oob-report"); + /** * The main application allowing this client to be launched from the command line. */ @@ -98,6 +134,7 @@ public static void main(String[] args) throws Exception { soakIterations * soakPerIterationMaxAcceptableLatencyMs / 1000; private int soakRequestSize = 271828; private int soakResponseSize = 314159; + private int numThreads = 1; private String additionalMetadata = ""; private static LoadBalancerProvider customBackendMetricsLoadBalancerProvider; @@ -178,6 +215,8 @@ void parseArgs(String[] args) throws Exception { soakRequestSize = Integer.parseInt(value); } else if ("soak_response_size".equals(key)) { soakResponseSize = Integer.parseInt(value); + } else if ("soak_num_threads".equals(key)) { + numThreads = Integer.parseInt(value); } else if ("additional_metadata".equals(key)) { additionalMetadata = value; } else { @@ -254,6 +293,9 @@ void parseArgs(String[] args) throws Exception { + "\n --soak_response_size " + "\n The response size in a soak RPC. Default " + c.soakResponseSize + + "\n --soak_num_threads The number of threads for concurrent execution of the " + + "\n soak tests (rpc_soak or channel_soak). Default " + + c.numThreads + "\n --additional_metadata " + "\n Additional metadata to send in each request, as a " + "\n semicolon-separated list of key:value pairs. Default " @@ -481,32 +523,35 @@ private void runTest(TestCases testCase) throws Exception { } case RPC_SOAK: { - tester.performSoakTest( + SoakClient.performSoakTest( serverHost, - false /* resetChannelPerIteration */, soakIterations, soakMaxFailures, soakPerIterationMaxAcceptableLatencyMs, soakMinTimeMsBetweenRpcs, soakOverallTimeoutSeconds, soakRequestSize, - soakResponseSize); + soakResponseSize, + numThreads, + tester.createChannelBuilder().build(), + (currentChannel) -> currentChannel); break; } case CHANNEL_SOAK: { - tester.performSoakTest( + SoakClient.performSoakTest( serverHost, - true /* resetChannelPerIteration */, soakIterations, soakMaxFailures, soakPerIterationMaxAcceptableLatencyMs, soakMinTimeMsBetweenRpcs, soakOverallTimeoutSeconds, soakRequestSize, - soakResponseSize); + soakResponseSize, + numThreads, + tester.createChannelBuilder().build(), + (currentChannel) -> tester.createNewChannel(currentChannel)); break; - } case ORCA_PER_RPC: { @@ -526,7 +571,7 @@ private void runTest(TestCases testCase) throws Exception { /* Parses input string as a semi-colon-separated list of colon-separated key/value pairs. * Allow any character but semicolons in values. - * If the string is emtpy, return null. + * If the string is empty, return null. * Otherwise, return a client interceptor which inserts the provided metadata. */ @Nullable @@ -668,6 +713,323 @@ protected ManagedChannelBuilder createChannelBuilder() { return okBuilder.intercept(createCensusStatsClientInterceptor()); } + ManagedChannel createNewChannel(ManagedChannel currentChannel) { + currentChannel.shutdownNow(); + try { + currentChannel.awaitTermination(10, TimeUnit.SECONDS); + } catch (InterruptedException e) { + throw new RuntimeException("Interrupted while creating a new channel", e); + } + return createChannel(); + } + + /** + * Assuming "pick_first" policy is used, tests that all requests are sent to the same server. + */ + public void pickFirstUnary() throws Exception { + SimpleRequest request = SimpleRequest.newBuilder() + .setResponseSize(1) + .setFillServerId(true) + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(new byte[1]))) + .build(); + + SimpleResponse firstResponse = blockingStub.unaryCall(request); + // Increase the chance of all servers are connected, in case the channel should be doing + // round_robin instead. + Thread.sleep(5000); + for (int i = 0; i < 100; i++) { + SimpleResponse response = blockingStub.unaryCall(request); + assertThat(response.getServerId()).isEqualTo(firstResponse.getServerId()); + } + } + + /** + * Sends a cacheable unary rpc using GET. Requires that the server is behind a caching proxy. + */ + public void cacheableUnary() { + // THIS TEST IS BROKEN. Enabling safe just on the MethodDescriptor does nothing by itself. + // This test would need to enable GET on the channel. + // Set safe to true. + MethodDescriptor safeCacheableUnaryCallMethod = + TestServiceGrpc.getCacheableUnaryCallMethod().toBuilder().setSafe(true).build(); + // Set fake user IP since some proxies (GFE) won't cache requests from localhost. + Metadata.Key userIpKey = + Metadata.Key.of("x-user-ip", Metadata.ASCII_STRING_MARSHALLER); + Metadata metadata = new Metadata(); + metadata.put(userIpKey, "1.2.3.4"); + Channel channelWithUserIpKey = ClientInterceptors.intercept( + channel, MetadataUtils.newAttachHeadersInterceptor(metadata)); + SimpleRequest requests1And2 = + SimpleRequest.newBuilder() + .setPayload( + Payload.newBuilder() + .setBody(ByteString.copyFromUtf8(String.valueOf(System.nanoTime())))) + .build(); + SimpleRequest request3 = + SimpleRequest.newBuilder() + .setPayload( + Payload.newBuilder() + .setBody(ByteString.copyFromUtf8(String.valueOf(System.nanoTime())))) + .build(); + + SimpleResponse response1 = + ClientCalls.blockingUnaryCall( + channelWithUserIpKey, safeCacheableUnaryCallMethod, CallOptions.DEFAULT, + requests1And2); + SimpleResponse response2 = + ClientCalls.blockingUnaryCall( + channelWithUserIpKey, safeCacheableUnaryCallMethod, CallOptions.DEFAULT, + requests1And2); + SimpleResponse response3 = + ClientCalls.blockingUnaryCall( + channelWithUserIpKey, safeCacheableUnaryCallMethod, CallOptions.DEFAULT, request3); + + assertEquals(response1, response2); + assertNotEquals(response1, response3); + // THIS TEST IS BROKEN. See comment at start of method. + } + + /** Sends a large unary rpc with service account credentials. */ + public void serviceAccountCreds(String jsonKey, InputStream credentialsStream, String authScope) + throws Exception { + // cast to ServiceAccountCredentials to double-check the right type of object was created. + GoogleCredentials credentials = + ServiceAccountCredentials.class.cast(GoogleCredentials.fromStream(credentialsStream)); + credentials = credentials.createScoped(Arrays.asList(authScope)); + TestServiceGrpc.TestServiceBlockingStub stub = blockingStub + .withCallCredentials(MoreCallCredentials.from(credentials)); + final SimpleRequest request = SimpleRequest.newBuilder() + .setFillUsername(true) + .setFillOauthScope(true) + .setResponseSize(314159) + .setPayload(Payload.newBuilder() + .setBody(ByteString.copyFrom(new byte[271828]))) + .build(); + + final SimpleResponse response = stub.unaryCall(request); + assertFalse(response.getUsername().isEmpty()); + assertTrue("Received username: " + response.getUsername(), + jsonKey.contains(response.getUsername())); + assertFalse(response.getOauthScope().isEmpty()); + assertTrue("Received oauth scope: " + response.getOauthScope(), + authScope.contains(response.getOauthScope())); + + final SimpleResponse goldenResponse = SimpleResponse.newBuilder() + .setOauthScope(response.getOauthScope()) + .setUsername(response.getUsername()) + .setPayload(Payload.newBuilder() + .setBody(ByteString.copyFrom(new byte[314159]))) + .build(); + assertResponse(goldenResponse, response); + } + + /** Sends a large unary rpc with compute engine credentials. */ + public void computeEngineCreds(String serviceAccount, String oauthScope) throws Exception { + ComputeEngineCredentials credentials = ComputeEngineCredentials.create(); + TestServiceGrpc.TestServiceBlockingStub stub = blockingStub + .withCallCredentials(MoreCallCredentials.from(credentials)); + final SimpleRequest request = SimpleRequest.newBuilder() + .setFillUsername(true) + .setFillOauthScope(true) + .setResponseSize(314159) + .setPayload(Payload.newBuilder() + .setBody(ByteString.copyFrom(new byte[271828]))) + .build(); + + final SimpleResponse response = stub.unaryCall(request); + assertEquals(serviceAccount, response.getUsername()); + assertFalse(response.getOauthScope().isEmpty()); + assertTrue("Received oauth scope: " + response.getOauthScope(), + oauthScope.contains(response.getOauthScope())); + + final SimpleResponse goldenResponse = SimpleResponse.newBuilder() + .setOauthScope(response.getOauthScope()) + .setUsername(response.getUsername()) + .setPayload(Payload.newBuilder() + .setBody(ByteString.copyFrom(new byte[314159]))) + .build(); + assertResponse(goldenResponse, response); + } + + /** Sends an unary rpc with ComputeEngineChannelBuilder. */ + public void computeEngineChannelCredentials( + String defaultServiceAccount, + TestServiceGrpc.TestServiceBlockingStub computeEngineStub) throws Exception { + final SimpleRequest request = SimpleRequest.newBuilder() + .setFillUsername(true) + .setResponseSize(314159) + .setPayload(Payload.newBuilder() + .setBody(ByteString.copyFrom(new byte[271828]))) + .build(); + final SimpleResponse response = computeEngineStub.unaryCall(request); + assertEquals(defaultServiceAccount, response.getUsername()); + final SimpleResponse goldenResponse = SimpleResponse.newBuilder() + .setUsername(defaultServiceAccount) + .setPayload(Payload.newBuilder() + .setBody(ByteString.copyFrom(new byte[314159]))) + .build(); + assertResponse(goldenResponse, response); + } + + /** Test JWT-based auth. */ + public void jwtTokenCreds(InputStream serviceAccountJson) throws Exception { + final SimpleRequest request = SimpleRequest.newBuilder() + .setResponseSize(314159) + .setPayload(Payload.newBuilder() + .setBody(ByteString.copyFrom(new byte[271828]))) + .setFillUsername(true) + .build(); + + ServiceAccountCredentials credentials = (ServiceAccountCredentials) + GoogleCredentials.fromStream(serviceAccountJson); + TestServiceGrpc.TestServiceBlockingStub stub = blockingStub + .withCallCredentials(MoreCallCredentials.from(credentials)); + SimpleResponse response = stub.unaryCall(request); + assertEquals(credentials.getClientEmail(), response.getUsername()); + assertEquals(314159, response.getPayload().getBody().size()); + } + + /** Sends a unary rpc with raw oauth2 access token credentials. */ + public void oauth2AuthToken(String jsonKey, InputStream credentialsStream, String authScope) + throws Exception { + GoogleCredentials utilCredentials = + GoogleCredentials.fromStream(credentialsStream); + utilCredentials = utilCredentials.createScoped(Arrays.asList(authScope)); + AccessToken accessToken = utilCredentials.refreshAccessToken(); + + OAuth2Credentials credentials = OAuth2Credentials.create(accessToken); + + TestServiceGrpc.TestServiceBlockingStub stub = blockingStub + .withCallCredentials(MoreCallCredentials.from(credentials)); + final SimpleRequest request = SimpleRequest.newBuilder() + .setFillUsername(true) + .setFillOauthScope(true) + .build(); + + final SimpleResponse response = stub.unaryCall(request); + assertFalse(response.getUsername().isEmpty()); + assertTrue("Received username: " + response.getUsername(), + jsonKey.contains(response.getUsername())); + assertFalse(response.getOauthScope().isEmpty()); + assertTrue("Received oauth scope: " + response.getOauthScope(), + authScope.contains(response.getOauthScope())); + } + + /** Sends a unary rpc with "per rpc" raw oauth2 access token credentials. */ + public void perRpcCreds(String jsonKey, InputStream credentialsStream, String oauthScope) + throws Exception { + // In gRpc Java, we don't have per Rpc credentials, user can use an intercepted stub only once + // for that purpose. + // So, this test is identical to oauth2_auth_token test. + oauth2AuthToken(jsonKey, credentialsStream, oauthScope); + } + + /** Sends an unary rpc with "google default credentials". */ + public void googleDefaultCredentials( + String defaultServiceAccount, + TestServiceGrpc.TestServiceBlockingStub googleDefaultStub) throws Exception { + final SimpleRequest request = SimpleRequest.newBuilder() + .setFillUsername(true) + .setResponseSize(314159) + .setPayload(Payload.newBuilder() + .setBody(ByteString.copyFrom(new byte[271828]))) + .build(); + final SimpleResponse response = googleDefaultStub.unaryCall(request); + assertEquals(defaultServiceAccount, response.getUsername()); + + final SimpleResponse goldenResponse = SimpleResponse.newBuilder() + .setUsername(defaultServiceAccount) + .setPayload(Payload.newBuilder() + .setBody(ByteString.copyFrom(new byte[314159]))) + .build(); + assertResponse(goldenResponse, response); + } + + /** + * Test backend metrics per query reporting: expect the test client LB policy to receive load + * reports. + */ + public void testOrcaPerRpc() throws Exception { + AtomicReference reportHolder = new AtomicReference<>(); + TestOrcaReport answer = TestOrcaReport.newBuilder() + .setCpuUtilization(0.8210) + .setMemoryUtilization(0.5847) + .putRequestCost("cost", 3456.32) + .putUtilization("util", 0.30499) + .build(); + blockingStub.withOption(ORCA_RPC_REPORT_KEY, reportHolder).unaryCall( + SimpleRequest.newBuilder().setOrcaPerQueryReport(answer).build()); + assertThat(reportHolder.get()).isEqualTo(answer); + } + + /** + * Test backend metrics OOB reporting: expect the test client LB policy to receive load reports. + */ + public void testOrcaOob() throws Exception { + AtomicReference reportHolder = new AtomicReference<>(); + final TestOrcaReport answer = TestOrcaReport.newBuilder() + .setCpuUtilization(0.8210) + .setMemoryUtilization(0.5847) + .putUtilization("util", 0.30499) + .build(); + final TestOrcaReport answer2 = TestOrcaReport.newBuilder() + .setCpuUtilization(0.29309) + .setMemoryUtilization(0.2) + .putUtilization("util", 0.2039) + .build(); + + final int retryLimit = 5; + BlockingQueue queue = new LinkedBlockingQueue<>(); + final Object lastItem = new Object(); + StreamObserver streamObserver = + asyncStub.fullDuplexCall(new StreamObserver() { + + @Override + public void onNext(StreamingOutputCallResponse value) { + queue.add(value); + } + + @Override + public void onError(Throwable t) { + queue.add(t); + } + + @Override + public void onCompleted() { + queue.add(lastItem); + } + }); + + streamObserver.onNext(StreamingOutputCallRequest.newBuilder() + .setOrcaOobReport(answer) + .addResponseParameters(ResponseParameters.newBuilder().setSize(1).build()).build()); + assertThat(queue.take()).isInstanceOf(StreamingOutputCallResponse.class); + int i = 0; + for (; i < retryLimit; i++) { + Thread.sleep(1000); + blockingStub.withOption(ORCA_OOB_REPORT_KEY, reportHolder).emptyCall(EMPTY); + if (answer.equals(reportHolder.get())) { + break; + } + } + assertThat(i).isLessThan(retryLimit); + streamObserver.onNext(StreamingOutputCallRequest.newBuilder() + .setOrcaOobReport(answer2) + .addResponseParameters(ResponseParameters.newBuilder().setSize(1).build()).build()); + assertThat(queue.take()).isInstanceOf(StreamingOutputCallResponse.class); + + for (i = 0; i < retryLimit; i++) { + Thread.sleep(1000); + blockingStub.withOption(ORCA_OOB_REPORT_KEY, reportHolder).emptyCall(EMPTY); + if (reportHolder.get().equals(answer2)) { + break; + } + } + assertThat(i).isLessThan(retryLimit); + streamObserver.onCompleted(); + assertThat(queue.take()).isSameInstanceAs(lastItem); + } + @Override protected boolean metricsExpected() { // Exact message size doesn't match when testing with Go servers: @@ -697,7 +1059,7 @@ protected int operationTimeoutMillis() { private static String validTestCasesHelpText() { StringBuilder builder = new StringBuilder(); for (TestCases testCase : TestCases.values()) { - String strTestcase = testCase.name().toLowerCase(); + String strTestcase = testCase.toString(); builder.append("\n ") .append(strTestcase) .append(": ") diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceImpl.java b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceImpl.java index 8fa272122d0..4742675416b 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceImpl.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceImpl.java @@ -18,6 +18,7 @@ import com.google.common.base.Preconditions; import com.google.common.collect.Queues; +import com.google.errorprone.annotations.concurrent.GuardedBy; import com.google.protobuf.ByteString; import io.grpc.ForwardingServerCall.SimpleForwardingServerCall; import io.grpc.Metadata; @@ -44,17 +45,14 @@ import java.util.ArrayDeque; import java.util.Arrays; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Queue; import java.util.Random; -import java.util.Set; import java.util.concurrent.Future; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; -import javax.annotation.concurrent.GuardedBy; /** * Implementation of the business logic for the TestService. Uses an executor to schedule chunks @@ -511,27 +509,30 @@ public static List interceptors() { } /** - * Echo the request headers from a client into response headers and trailers. Useful for + * Echo a request header from a client into response headers and trailers. Useful for * testing end-to-end metadata propagation. */ - private static ServerInterceptor echoRequestHeadersInterceptor(final Metadata.Key... keys) { - final Set> keySet = new HashSet<>(Arrays.asList(keys)); + private static ServerInterceptor echoRequestHeadersInterceptor(final Metadata.Key key) { return new ServerInterceptor() { @Override public ServerCall.Listener interceptCall( ServerCall call, - final Metadata requestHeaders, + Metadata requestHeaders, ServerCallHandler next) { + if (!requestHeaders.containsKey(key)) { + return next.startCall(call, requestHeaders); + } + T value = requestHeaders.get(key); return next.startCall(new SimpleForwardingServerCall(call) { @Override public void sendHeaders(Metadata responseHeaders) { - responseHeaders.merge(requestHeaders, keySet); + responseHeaders.put(key, value); super.sendHeaders(responseHeaders); } @Override public void close(Status status, Metadata trailers) { - trailers.merge(requestHeaders, keySet); + trailers.put(key, value); super.close(status, trailers); } }, requestHeaders); @@ -540,52 +541,48 @@ public void close(Status status, Metadata trailers) { } /** - * Echoes request headers with the specified key(s) from a client into response headers only. + * Echoes request headers with the specified key from a client into response headers only. */ - private static ServerInterceptor echoRequestMetadataInHeaders(final Metadata.Key... keys) { - final Set> keySet = new HashSet<>(Arrays.asList(keys)); + private static ServerInterceptor echoRequestMetadataInHeaders(final Metadata.Key key) { return new ServerInterceptor() { @Override public ServerCall.Listener interceptCall( ServerCall call, final Metadata requestHeaders, ServerCallHandler next) { + if (!requestHeaders.containsKey(key)) { + return next.startCall(call, requestHeaders); + } + T value = requestHeaders.get(key); return next.startCall(new SimpleForwardingServerCall(call) { @Override public void sendHeaders(Metadata responseHeaders) { - responseHeaders.merge(requestHeaders, keySet); + responseHeaders.put(key, value); super.sendHeaders(responseHeaders); } - - @Override - public void close(Status status, Metadata trailers) { - super.close(status, trailers); - } }, requestHeaders); } }; } /** - * Echoes request headers with the specified key(s) from a client into response trailers only. + * Echoes request headers with the specified key from a client into response trailers only. */ - private static ServerInterceptor echoRequestMetadataInTrailers(final Metadata.Key... keys) { - final Set> keySet = new HashSet<>(Arrays.asList(keys)); + private static ServerInterceptor echoRequestMetadataInTrailers(final Metadata.Key key) { return new ServerInterceptor() { @Override public ServerCall.Listener interceptCall( ServerCall call, final Metadata requestHeaders, ServerCallHandler next) { + if (!requestHeaders.containsKey(key)) { + return next.startCall(call, requestHeaders); + } + T value = requestHeaders.get(key); return next.startCall(new SimpleForwardingServerCall(call) { - @Override - public void sendHeaders(Metadata responseHeaders) { - super.sendHeaders(responseHeaders); - } - @Override public void close(Status status, Metadata trailers) { - trailers.merge(requestHeaders, keySet); + trailers.put(key, value); super.close(status, trailers); } }, requestHeaders); diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java index f2f5b43ea1b..fc4cdf9178f 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java @@ -22,14 +22,20 @@ import io.grpc.Grpc; import io.grpc.InsecureServerCredentials; import io.grpc.Server; +import io.grpc.ServerBuilder; import io.grpc.ServerCredentials; import io.grpc.ServerInterceptors; import io.grpc.TlsServerCredentials; import io.grpc.alts.AltsServerCredentials; +import io.grpc.netty.NettyServerBuilder; import io.grpc.services.MetricRecorder; import io.grpc.testing.TlsTesting; import io.grpc.xds.orca.OrcaMetricReportingServerInterceptor; import io.grpc.xds.orca.OrcaServiceImpl; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.List; +import java.util.Locale; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; @@ -73,6 +79,7 @@ public void run() { private ScheduledExecutorService executor; private Server server; private int localHandshakerPort = -1; + private Util.AddressType addressType = Util.AddressType.IPV4_IPV6; @VisibleForTesting void parseArgs(String[] args) { @@ -103,6 +110,8 @@ void parseArgs(String[] args) { useAlts = Boolean.parseBoolean(value); } else if ("local_handshaker_port".equals(key)) { localHandshakerPort = Integer.parseInt(value); + } else if ("address_type".equals(key)) { + addressType = Util.AddressType.valueOf(value.toUpperCase(Locale.ROOT)); } else if ("grpc_version".equals(key)) { if (!"2".equals(value)) { System.err.println("Only grpc version 2 is supported"); @@ -130,11 +139,14 @@ void parseArgs(String[] args) { + "\n --local_handshaker_port=PORT" + "\n Use local ALTS handshaker service on the specified port " + "\n for testing. Only effective when --use_alts=true." + + "\n --address_type=IPV4|IPV6|IPV4_IPV6" + + "\n What type of addresses to listen on. Default IPV4_IPV6" ); System.exit(1); } } + @SuppressWarnings("AddressSelection") @VisibleForTesting void start() throws Exception { executor = Executors.newSingleThreadScheduledExecutor(); @@ -156,7 +168,40 @@ void start() throws Exception { MetricRecorder metricRecorder = MetricRecorder.newInstance(); BindableService orcaOobService = OrcaServiceImpl.createService(executor, metricRecorder, 1, TimeUnit.SECONDS); - server = Grpc.newServerBuilderForPort(port, serverCreds) + + // Create ServerBuilder with appropriate addresses + // - IPV4_IPV6: bind to wildcard which covers all addresses on all interfaces of both families + // - IPV4: bind to v4 address for local hostname + v4 localhost + // - IPV6: bind to all v6 addresses for local hostname + v6 localhost + ServerBuilder serverBuilder; + switch (addressType) { + case IPV4_IPV6: + serverBuilder = Grpc.newServerBuilderForPort(port, serverCreds); + break; + case IPV4: + SocketAddress v4Address = Util.getV4Address(port); + InetSocketAddress localV4Address = new InetSocketAddress("127.0.0.1", port); + serverBuilder = + NettyServerBuilder.forAddress(localV4Address, serverCreds); + if (v4Address != null && !v4Address.equals(localV4Address)) { + ((NettyServerBuilder) serverBuilder).addListenAddress(v4Address); + } + break; + case IPV6: + List v6Addresses = Util.getV6Addresses(port); + InetSocketAddress localV6Address = new InetSocketAddress("::1", port); + serverBuilder = + NettyServerBuilder.forAddress(localV6Address, serverCreds); + for (SocketAddress address : v6Addresses) { + if (!address.equals(localV6Address)) { + ((NettyServerBuilder) serverBuilder).addListenAddress(address); + } + } + break; + default: + throw new AssertionError("Unknown address type: " + addressType); + } + server = serverBuilder .maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE) .addService( ServerInterceptors.intercept( @@ -187,4 +232,5 @@ private void blockUntilShutdown() throws InterruptedException { server.awaitTermination(); } } + } diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/Util.java b/interop-testing/src/main/java/io/grpc/testing/integration/Util.java index b66114f12c0..50da0a6373d 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/Util.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/Util.java @@ -16,10 +16,17 @@ package io.grpc.testing.integration; +import com.google.common.annotations.VisibleForTesting; import com.google.protobuf.MessageLite; import com.google.protobuf.StringValue; import io.grpc.Metadata; import io.grpc.protobuf.lite.ProtoLiteUtils; +import java.io.IOException; +import java.net.InetAddress; +import java.net.ServerSocket; +import java.net.SocketAddress; +import java.net.UnknownHostException; +import java.util.ArrayList; import java.util.List; import org.junit.Assert; @@ -66,4 +73,51 @@ public static void assertEquals(List expected, } } } + + static List getV6Addresses(int port) throws UnknownHostException { + List v6addresses = new ArrayList<>(); + InetAddress[] addresses = InetAddress.getAllByName(InetAddress.getLocalHost().getHostName()); + for (InetAddress address : addresses) { + if (address.getAddress().length != 4) { + v6addresses.add(new java.net.InetSocketAddress(address, port)); + } + } + return v6addresses; + } + + static SocketAddress getV4Address(int port) throws UnknownHostException { + InetAddress[] addresses = InetAddress.getAllByName(InetAddress.getLocalHost().getHostName()); + for (InetAddress address : addresses) { + if (address.getAddress().length == 4) { + return new java.net.InetSocketAddress(address, port); + } + } + return null; // means it is v6 only + } + + + /** + * Picks a port that is not used right at this moment. + * Warning: Not thread safe. May see "BindException: Address already in use: bind" if using the + * returned port to create a new server socket when other threads/processes are concurrently + * creating new sockets without a specific port. + */ + public static int pickUnusedPort() { + try { + ServerSocket serverSocket = new ServerSocket(0); + int port = serverSocket.getLocalPort(); + serverSocket.close(); + return port; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + + @VisibleForTesting + enum AddressType { + IPV4, + IPV6, + IPV4_IPV6 + } } diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/XdsFederationTestClient.java b/interop-testing/src/main/java/io/grpc/testing/integration/XdsFederationTestClient.java index f55ccbdefa7..bba282b7b6f 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/XdsFederationTestClient.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/XdsFederationTestClient.java @@ -22,9 +22,10 @@ import io.grpc.ChannelCredentials; import io.grpc.Grpc; import io.grpc.InsecureChannelCredentials; -import io.grpc.ManagedChannelBuilder; +import io.grpc.ManagedChannel; import io.grpc.alts.ComputeEngineChannelCredentials; import java.util.ArrayList; +import java.util.concurrent.TimeUnit; import java.util.logging.Logger; /** @@ -44,26 +45,8 @@ public final class XdsFederationTestClient { public static void main(String[] args) throws Exception { final XdsFederationTestClient client = new XdsFederationTestClient(); client.parseArgs(args); - Runtime.getRuntime() - .addShutdownHook( - new Thread() { - @Override - @SuppressWarnings("CatchAndPrintStackTrace") - public void run() { - System.out.println("Shutting down"); - try { - client.tearDown(); - } catch (RuntimeException e) { - e.printStackTrace(); - } - } - }); client.setUp(); - try { - client.run(); - } finally { - client.tearDown(); - } + client.run(); System.exit(0); } @@ -209,22 +192,13 @@ void setUp() { for (int i = 0; i < uris.length; i++) { clients.add(new InnerClient(creds[i], uris[i])); } - for (InnerClient c : clients) { - c.setUp(); - } - } - - private synchronized void tearDown() { - for (InnerClient c : clients) { - c.tearDown(); - } } /** * Wraps a single client stub configuration and executes a * soak test case with that configuration. */ - class InnerClient extends AbstractInteropTest { + class InnerClient { private final String credentialsType; private final String serverUri; private boolean runSucceeded = false; @@ -245,29 +219,43 @@ public boolean runSucceeded() { /** * Run the intended soak test. */ - public void run() { - boolean resetChannelPerIteration; - switch (testCase) { - case "rpc_soak": - resetChannelPerIteration = false; - break; - case "channel_soak": - resetChannelPerIteration = true; - break; - default: - throw new RuntimeException("invalid testcase: " + testCase); - } + public void run() throws InterruptedException { try { - performSoakTest( - serverUri, - resetChannelPerIteration, - soakIterations, - soakMaxFailures, - soakPerIterationMaxAcceptableLatencyMs, - soakMinTimeMsBetweenRpcs, - soakOverallTimeoutSeconds, - soakRequestSize, - soakResponseSize); + switch (testCase) { + case "rpc_soak": { + SoakClient.performSoakTest( + serverUri, + soakIterations, + soakMaxFailures, + soakPerIterationMaxAcceptableLatencyMs, + soakMinTimeMsBetweenRpcs, + soakOverallTimeoutSeconds, + soakRequestSize, + soakResponseSize, + 1, + createChannel(), + (currentChannel) -> currentChannel); + } + break; + case "channel_soak": { + SoakClient.performSoakTest( + serverUri, + soakIterations, + soakMaxFailures, + soakPerIterationMaxAcceptableLatencyMs, + soakMinTimeMsBetweenRpcs, + soakOverallTimeoutSeconds, + soakRequestSize, + soakResponseSize, + 1, + createChannel(), + (currentChannel) -> createNewChannel(currentChannel)); + } + break; + default: + throw new RuntimeException("invalid testcase: " + testCase); + } + logger.info("Test case: " + testCase + " done for server: " + serverUri); runSucceeded = true; } catch (Exception e) { @@ -276,8 +264,7 @@ public void run() { } } - @Override - protected ManagedChannelBuilder createChannelBuilder() { + ManagedChannel createChannel() { ChannelCredentials channelCredentials; switch (credentialsType) { case "compute_engine_channel_creds": @@ -291,15 +278,33 @@ protected ManagedChannelBuilder createChannelBuilder() { } return Grpc.newChannelBuilder(serverUri, channelCredentials) .keepAliveTime(3600, SECONDS) - .keepAliveTimeout(20, SECONDS); + .keepAliveTimeout(20, SECONDS) + .build(); + } + + ManagedChannel createNewChannel(ManagedChannel currentChannel) { + currentChannel.shutdownNow(); + try { + currentChannel.awaitTermination(10, TimeUnit.SECONDS); + } catch (InterruptedException e) { + throw new RuntimeException("Interrupted while creating a new channel", e); + } + return createChannel(); } } - private void run() throws Exception { + private void run() throws InterruptedException { logger.info("Begin test case: " + testCase); ArrayList threads = new ArrayList<>(); for (InnerClient c : clients) { - Thread t = new Thread(c::run); + Thread t = new Thread(() -> { + try { + c.run(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); // Properly re-interrupt the thread + throw new RuntimeException("Thread was interrupted during execution", e); + } + }); t.start(); threads.add(t); } diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java b/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java index c38123cad64..89519041a79 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java @@ -19,6 +19,7 @@ import com.google.common.base.CaseFormat; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.primitives.Ints; import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; @@ -26,6 +27,8 @@ import com.google.common.util.concurrent.ListeningScheduledExecutorService; import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.SettableFuture; +import com.google.protobuf.ByteString; +import io.grpc.BindableService; import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.ClientCall; @@ -40,7 +43,9 @@ import io.grpc.MethodDescriptor; import io.grpc.Server; import io.grpc.Status; +import io.grpc.gcp.csm.observability.CsmObservability; import io.grpc.protobuf.services.ProtoReflectionService; +import io.grpc.protobuf.services.ProtoReflectionServiceV1; import io.grpc.services.AdminInterface; import io.grpc.stub.StreamObserver; import io.grpc.testing.integration.Messages.ClientConfigureRequest; @@ -51,9 +56,11 @@ import io.grpc.testing.integration.Messages.LoadBalancerAccumulatedStatsResponse.MethodStats; import io.grpc.testing.integration.Messages.LoadBalancerStatsRequest; import io.grpc.testing.integration.Messages.LoadBalancerStatsResponse; +import io.grpc.testing.integration.Messages.Payload; import io.grpc.testing.integration.Messages.SimpleRequest; import io.grpc.testing.integration.Messages.SimpleResponse; import io.grpc.xds.XdsChannelCredentials; +import io.opentelemetry.sdk.autoconfigure.AutoConfiguredOpenTelemetrySdk; import java.util.ArrayList; import java.util.Collections; import java.util.EnumMap; @@ -71,6 +78,7 @@ import java.util.logging.Logger; import javax.annotation.Nullable; import javax.annotation.concurrent.ThreadSafe; +import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; /** Client for xDS interop tests. */ public final class XdsTestClient { @@ -88,10 +96,14 @@ public final class XdsTestClient { private int rpcTimeoutSec = 20; private boolean secureMode = false; private String server = "localhost:8080"; + private int requestSize; + private int responseSize; + private boolean enableCsmObservability; private int statsPort = 8081; private Server statsServer; private long currentRequestId; private ListeningScheduledExecutorService exec; + private CsmObservability csmObservability; /** * The main application allowing this client to be launched from the command line. @@ -151,6 +163,12 @@ private void parseArgs(String[] args) { rpcTimeoutSec = Integer.valueOf(value); } else if ("server".equals(key)) { server = value; + } else if ("request_payload_size".equals(key)) { + requestSize = Integer.valueOf(value); + } else if ("response_payload_size".equals(key)) { + responseSize = Integer.valueOf(value); + } else if ("enable_csm_observability".equals(key)) { + enableCsmObservability = Boolean.valueOf(value); } else if ("stats_port".equals(key)) { statsPort = Integer.valueOf(value); } else if ("secure_mode".equals(key)) { @@ -194,6 +212,10 @@ private void parseArgs(String[] args) { + c.server + "\n --secure_mode=BOOLEAN Use true to enable XdsCredentials. Default: " + c.secureMode + + "\n --request_payload_size=INT Per-request size. Default: " + c.requestSize + + "\n --response_payload_size=INT Per-response size. Default: " + c.responseSize + + "\n --enable_csm_observability=BOOL Enable CSM observability reporting. Default: " + + c.enableCsmObservability + "\n --stats_port=INT Port to expose peer distribution stats service. " + "Default: " + c.statsPort); @@ -240,12 +262,28 @@ private static RpcType parseRpc(String rpc) { } } + @IgnoreJRERequirement // OpenTelemetry uses Java 8+ APIs private void run() { + if (enableCsmObservability) { + csmObservability = CsmObservability.newBuilder() + .sdk(AutoConfiguredOpenTelemetrySdk.builder() + .addPropertiesSupplier(() -> ImmutableMap.of( + "otel.logs.exporter", "none", + "otel.metrics.exporter", "prometheus", + "otel.traces.exporter", "none")) + .build() + .getOpenTelemetrySdk()) + .build(); + csmObservability.registerGlobal(); + } + @SuppressWarnings("deprecation") + BindableService oldReflectionService = ProtoReflectionService.newInstance(); statsServer = Grpc.newServerBuilderForPort(statsPort, InsecureServerCredentials.create()) .addService(new XdsStatsImpl()) .addService(new ConfigureUpdateServiceImpl()) - .addService(ProtoReflectionService.newInstance()) + .addService(oldReflectionService) + .addService(ProtoReflectionServiceV1.newInstance()) .addServices(AdminInterface.getStandardServices()) .build(); try { @@ -261,7 +299,10 @@ private void run() { .build()); } exec = MoreExecutors.listeningDecorator(Executors.newSingleThreadScheduledExecutor()); - runQps(); + Payload requestPayload = Payload.newBuilder() + .setBody(ByteString.copyFrom(new byte[requestSize])) + .build(); + runQps(requestPayload); } catch (Throwable t) { logger.log(Level.SEVERE, "Error running client", t); System.exit(1); @@ -281,10 +322,13 @@ private void stop() throws InterruptedException { if (exec != null) { exec.shutdownNow(); } + if (csmObservability != null) { + csmObservability.close(); + } } - private void runQps() throws InterruptedException, ExecutionException { + private void runQps(Payload requestPayload) throws InterruptedException, ExecutionException { final SettableFuture failure = SettableFuture.create(); final class PeriodicRpc implements Runnable { @@ -357,7 +401,11 @@ public void onError(Throwable t) { public void onNext(EmptyProtos.Empty response) {} }); } else if (config.rpcType == RpcType.UNARY_CALL) { - SimpleRequest request = SimpleRequest.newBuilder().setFillServerId(true).build(); + SimpleRequest request = SimpleRequest.newBuilder() + .setFillServerId(true) + .setPayload(requestPayload) + .setResponseSize(responseSize) + .build(); stub.unaryCall( request, new StreamObserver() { diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestServer.java b/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestServer.java index 9d74cfac542..88f1bf468b6 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestServer.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestServer.java @@ -17,33 +17,47 @@ package io.grpc.testing.integration; import com.google.common.base.Splitter; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; +import com.google.protobuf.ByteString; +import io.grpc.BindableService; import io.grpc.ForwardingServerCall.SimpleForwardingServerCall; import io.grpc.Grpc; import io.grpc.InsecureServerCredentials; import io.grpc.Metadata; import io.grpc.Server; +import io.grpc.ServerBuilder; import io.grpc.ServerCall; import io.grpc.ServerCallHandler; +import io.grpc.ServerCredentials; import io.grpc.ServerInterceptor; import io.grpc.ServerInterceptors; import io.grpc.Status; +import io.grpc.gcp.csm.observability.CsmObservability; import io.grpc.health.v1.HealthCheckResponse.ServingStatus; +import io.grpc.netty.NettyServerBuilder; import io.grpc.protobuf.services.HealthStatusManager; import io.grpc.protobuf.services.ProtoReflectionService; +import io.grpc.protobuf.services.ProtoReflectionServiceV1; import io.grpc.services.AdminInterface; import io.grpc.stub.StreamObserver; +import io.grpc.testing.integration.Messages.Payload; import io.grpc.testing.integration.Messages.SimpleRequest; import io.grpc.testing.integration.Messages.SimpleResponse; import io.grpc.xds.XdsServerBuilder; import io.grpc.xds.XdsServerCredentials; +import io.opentelemetry.sdk.autoconfigure.AutoConfiguredOpenTelemetrySdk; import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.SocketAddress; import java.net.UnknownHostException; import java.util.ArrayList; import java.util.List; +import java.util.Locale; import java.util.concurrent.TimeUnit; import java.util.logging.Level; import java.util.logging.Logger; +import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; /** Interop test server that implements the xDS testing service. */ public final class XdsTestServer { @@ -70,11 +84,14 @@ public final class XdsTestServer { private int port = 8080; private int maintenancePort = 8080; private boolean secureMode = false; + private boolean enableCsmObservability; private String serverId = "java_server"; private HealthStatusManager health; private Server server; private Server maintenanceServer; private String host; + private Util.AddressType addressType = Util.AddressType.IPV4_IPV6; + private CsmObservability csmObservability; /** * The main application allowing this client to be launched from the command line. @@ -101,7 +118,7 @@ public void run() { server.blockUntilShutdown(); } - private void parseArgs(String[] args) { + void parseArgs(String[] args) { boolean usage = false; for (String arg : args) { if (!arg.startsWith("--")) { @@ -127,8 +144,12 @@ private void parseArgs(String[] args) { maintenancePort = Integer.valueOf(value); } else if ("secure_mode".equals(key)) { secureMode = Boolean.parseBoolean(value); + } else if ("enable_csm_observability".equals(key)) { + enableCsmObservability = Boolean.valueOf(value); } else if ("server_id".equals(key)) { serverId = value; + } else if ("address_type".equals(key)) { + addressType = Util.AddressType.valueOf(value.toUpperCase(Locale.ROOT)); } else { System.err.println("Unknown argument: " + key); usage = true; @@ -160,14 +181,33 @@ private void parseArgs(String[] args) { + " port and maintenance_port should be different for secure mode." + "\n Default: " + s.secureMode + + "\n --enable_csm_observability=BOOL Enable CSM observability reporting. Default: " + + s.enableCsmObservability + "\n --server_id=STRING server ID for response." + "\n Default: " - + s.serverId); + + s.serverId + + "\n --address_type=STRING type of IP address to bind to (IPV4|IPV6|IPV4_IPV6)." + + "\n Default: " + + s.addressType); System.exit(1); } } - private void start() throws Exception { + @SuppressWarnings("AddressSelection") + @IgnoreJRERequirement // OpenTelemetry uses Java 8+ APIs + void start() throws Exception { + if (enableCsmObservability) { + csmObservability = CsmObservability.newBuilder() + .sdk(AutoConfiguredOpenTelemetrySdk.builder() + .addPropertiesSupplier(() -> ImmutableMap.of( + "otel.logs.exporter", "none", + "otel.metrics.exporter", "prometheus", + "otel.traces.exporter", "none")) + .build() + .getOpenTelemetrySdk()) + .build(); + csmObservability.registerGlobal(); + } try { host = InetAddress.getLocalHost().getHostName(); } catch (UnknownHostException e) { @@ -175,12 +215,18 @@ private void start() throws Exception { throw new RuntimeException(e); } health = new HealthStatusManager(); + @SuppressWarnings("deprecation") + BindableService oldReflectionService = ProtoReflectionService.newInstance(); if (secureMode) { + if (addressType != Util.AddressType.IPV4_IPV6) { + throw new IllegalArgumentException("Secure mode only supports IPV4_IPV6 address type"); + } maintenanceServer = Grpc.newServerBuilderForPort(maintenancePort, InsecureServerCredentials.create()) .addService(new XdsUpdateHealthServiceImpl(health)) .addService(health.getHealthService()) - .addService(ProtoReflectionService.newInstance()) + .addService(oldReflectionService) + .addService(ProtoReflectionServiceV1.newInstance()) .addServices(AdminInterface.getStandardServices()) .build(); maintenanceServer.start(); @@ -193,14 +239,46 @@ private void start() throws Exception { .build(); server.start(); } else { + ServerBuilder serverBuilder; + ServerCredentials insecureServerCreds = InsecureServerCredentials.create(); + switch (addressType) { + case IPV4_IPV6: + serverBuilder = Grpc.newServerBuilderForPort(port, insecureServerCreds); + break; + case IPV4: + SocketAddress v4Address = Util.getV4Address(port); + InetSocketAddress localV4Address = new InetSocketAddress("127.0.0.1", port); + serverBuilder = NettyServerBuilder.forAddress( + localV4Address, insecureServerCreds); + if (v4Address != null && !v4Address.equals(localV4Address) ) { + ((NettyServerBuilder) serverBuilder).addListenAddress(v4Address); + } + break; + case IPV6: + List v6Addresses = Util.getV6Addresses(port); + InetSocketAddress localV6Address = new InetSocketAddress("::1", port); + serverBuilder = NettyServerBuilder.forAddress(localV6Address, insecureServerCreds); + for (SocketAddress address : v6Addresses) { + if (!address.equals(localV6Address)) { + ((NettyServerBuilder) serverBuilder).addListenAddress(address); + } + } + break; + default: + throw new AssertionError("Unknown address type: " + addressType); + } + + logger.info("Starting server on port " + port + " with address type " + addressType); + server = - Grpc.newServerBuilderForPort(port, InsecureServerCredentials.create()) + serverBuilder .addService( ServerInterceptors.intercept( new TestServiceImpl(serverId, host), new TestInfoInterceptor(host))) .addService(new XdsUpdateHealthServiceImpl(health)) .addService(health.getHealthService()) - .addService(ProtoReflectionService.newInstance()) + .addService(oldReflectionService) + .addService(ProtoReflectionServiceV1.newInstance()) .addServices(AdminInterface.getStandardServices()) .build(); server.start(); @@ -209,7 +287,7 @@ private void start() throws Exception { health.setStatus("", ServingStatus.SERVING); } - private void stop() throws Exception { + void stop() throws Exception { server.shutdownNow(); if (maintenanceServer != null) { maintenanceServer.shutdownNow(); @@ -220,6 +298,9 @@ private void stop() throws Exception { if (maintenanceServer != null && !maintenanceServer.awaitTermination(5, TimeUnit.SECONDS)) { System.err.println("Timed out waiting for maintenanceServer shutdown"); } + if (csmObservability != null) { + csmObservability.close(); + } } private void blockUntilShutdown() throws InterruptedException { @@ -249,8 +330,13 @@ public void emptyCall( @Override public void unaryCall(SimpleRequest req, StreamObserver responseObserver) { - responseObserver.onNext( - SimpleResponse.newBuilder().setServerId(serverId).setHostname(host).build()); + responseObserver.onNext(SimpleResponse.newBuilder() + .setServerId(serverId) + .setHostname(host) + .setPayload(Payload.newBuilder() + .setBody(ByteString.copyFrom(new byte[req.getResponseSize()])) + .build()) + .build()); responseObserver.onCompleted(); } } diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/CompressionTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/CompressionTest.java index 208eb40c438..5307c26949b 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/CompressionTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/CompressionTest.java @@ -24,6 +24,8 @@ import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; import com.google.protobuf.ByteString; import io.grpc.CallOptions; import io.grpc.Channel; @@ -53,8 +55,6 @@ import io.grpc.testing.integration.TestServiceGrpc.TestServiceBlockingStub; import io.grpc.testing.integration.TransportCompressionTest.Fzip; import java.nio.charset.Charset; -import java.util.ArrayList; -import java.util.Collection; import java.util.List; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; @@ -146,25 +146,16 @@ public void tearDown() { * Parameters for test. */ @Parameters - public static Collection params() { - boolean[] bools = new boolean[]{false, true}; - List combos = new ArrayList<>(64); - for (boolean enableClientMessageCompression : bools) { - for (boolean clientAcceptEncoding : bools) { - for (boolean clientEncoding : bools) { - for (boolean enableServerMessageCompression : bools) { - for (boolean serverAcceptEncoding : bools) { - for (boolean serverEncoding : bools) { - combos.add(new Object[] { - enableClientMessageCompression, clientAcceptEncoding, clientEncoding, - enableServerMessageCompression, serverAcceptEncoding, serverEncoding}); - } - } - } - } - } - } - return combos; + public static Iterable params() { + List bools = Lists.newArrayList(false, true); + return Iterables.transform(Lists.cartesianProduct( + bools, // enableClientMessageCompression + bools, // clientAcceptEncoding + bools, // clientEncoding + bools, // enableServerMessageCompression + bools, // serverAcceptEncoding + bools // serverEncoding + ), List::toArray); } @Test diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/Http2Test.java b/interop-testing/src/test/java/io/grpc/testing/integration/Http2Test.java index 8d448a049cb..4c0d47be567 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/Http2Test.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/Http2Test.java @@ -18,6 +18,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertTrue; import io.grpc.ChannelCredentials; import io.grpc.ManagedChannelBuilder; @@ -38,7 +39,6 @@ import io.grpc.stub.MetadataUtils; import io.grpc.testing.TlsTesting; import java.io.IOException; -import java.net.InetAddress; import java.net.InetSocketAddress; import java.util.Arrays; import org.junit.Test; @@ -139,7 +139,7 @@ protected ManagedChannelBuilder createChannelBuilder() { @Test public void remoteAddr() { InetSocketAddress isa = (InetSocketAddress) obtainRemoteClientAddr(); - assertEquals(InetAddress.getLoopbackAddress(), isa.getAddress()); + assertTrue(isa.getAddress().isLoopbackAddress()); // It should not be the same as the server assertNotEquals(((InetSocketAddress) getListenAddress()).getPort(), isa.getPort()); } @@ -147,7 +147,7 @@ public void remoteAddr() { @Test public void localAddr() throws Exception { InetSocketAddress isa = (InetSocketAddress) obtainLocalServerAddr(); - assertEquals(InetAddress.getLoopbackAddress(), isa.getAddress()); + assertTrue(isa.getAddress().isLoopbackAddress()); assertEquals(((InetSocketAddress) getListenAddress()).getPort(), isa.getPort()); } diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/ManagedChannelImplIntegrationTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/ManagedChannelImplIntegrationTest.java new file mode 100644 index 00000000000..f09f196d7d8 --- /dev/null +++ b/interop-testing/src/test/java/io/grpc/testing/integration/ManagedChannelImplIntegrationTest.java @@ -0,0 +1,80 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.testing.integration; + +import static com.google.common.truth.Truth.assertThat; + +import io.grpc.ManagedChannel; +import io.grpc.ServerInterceptors; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.internal.FakeClock; +import io.grpc.internal.testing.StreamRecorder; +import io.grpc.stub.StreamObserver; +import io.grpc.testing.GrpcCleanupRule; +import io.grpc.testing.integration.EmptyProtos.Empty; +import io.grpc.testing.integration.Messages.ResponseParameters; +import io.grpc.testing.integration.Messages.StreamingOutputCallRequest; +import io.grpc.testing.integration.Messages.StreamingOutputCallResponse; +import java.util.concurrent.TimeUnit; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for ManagedChannelImpl that use a real transport. */ +@RunWith(JUnit4.class) +public final class ManagedChannelImplIntegrationTest { + private static final String SERVER_NAME = ManagedChannelImplIntegrationTest.class.getName(); + @Rule + public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + + @Test + public void idleWhileRpcInTransport_exitsIdleForNewRpc() throws Exception { + FakeClock fakeClock = new FakeClock(); + grpcCleanup.register(InProcessServerBuilder.forName(SERVER_NAME) + .directExecutor() + .addService( + ServerInterceptors.intercept( + new TestServiceImpl(fakeClock.getScheduledExecutorService()), + TestServiceImpl.interceptors())) + .build() + .start()); + ManagedChannel channel = grpcCleanup.register(InProcessChannelBuilder.forName(SERVER_NAME) + .directExecutor() + .build()); + + TestServiceGrpc.TestServiceBlockingStub blockingStub = TestServiceGrpc.newBlockingStub(channel); + TestServiceGrpc.TestServiceStub asyncStub = TestServiceGrpc.newStub(channel); + StreamRecorder responseObserver = StreamRecorder.create(); + StreamObserver requestObserver = + asyncStub.fullDuplexCall(responseObserver); + requestObserver.onNext(StreamingOutputCallRequest.newBuilder() + .addResponseParameters(ResponseParameters.newBuilder() + .setIntervalUs(Integer.MAX_VALUE)) + .build()); + try { + channel.enterIdle(); + assertThat(blockingStub + .withDeadlineAfter(10, TimeUnit.SECONDS) + .emptyCall(Empty.getDefaultInstance())) + .isEqualTo(Empty.getDefaultInstance()); + } finally { + requestObserver.onError(new RuntimeException("cleanup")); + } + } +} diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/OpenTelemetryContextPropagationTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/OpenTelemetryContextPropagationTest.java new file mode 100644 index 00000000000..3884d977a6e --- /dev/null +++ b/interop-testing/src/test/java/io/grpc/testing/integration/OpenTelemetryContextPropagationTest.java @@ -0,0 +1,191 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.testing.integration; + +import static org.junit.Assert.assertEquals; + +import io.grpc.ForwardingServerCallListener; +import io.grpc.InsecureServerCredentials; +import io.grpc.ManagedChannelBuilder; +import io.grpc.Metadata; +import io.grpc.ServerBuilder; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.netty.InternalNettyChannelBuilder; +import io.grpc.netty.NettyChannelBuilder; +import io.grpc.netty.NettyServerBuilder; +import io.grpc.opentelemetry.GrpcOpenTelemetry; +import io.grpc.opentelemetry.GrpcTraceBinContextPropagator; +import io.grpc.opentelemetry.InternalGrpcOpenTelemetry; +import io.grpc.testing.integration.Messages.SimpleRequest; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.api.trace.propagation.W3CTraceContextPropagator; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.Scope; +import io.opentelemetry.context.propagation.ContextPropagators; +import io.opentelemetry.context.propagation.TextMapPropagator; +import io.opentelemetry.sdk.OpenTelemetrySdk; +import io.opentelemetry.sdk.trace.SdkTracerProvider; +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.Assume; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class OpenTelemetryContextPropagationTest extends AbstractInteropTest { + private final OpenTelemetrySdk openTelemetrySdk; + private final Tracer tracer; + private final GrpcOpenTelemetry grpcOpenTelemetry; + private final AtomicReference applicationSpan = new AtomicReference<>(); + private final boolean censusClient; + + @Parameterized.Parameters(name = "ContextPropagator={0}, CensusClient={1}") + public static Iterable data() { + return Arrays.asList(new Object[][] { + {W3CTraceContextPropagator.getInstance(), false}, + {GrpcTraceBinContextPropagator.defaultInstance(), false}, + {GrpcTraceBinContextPropagator.defaultInstance(), true} + }); + } + + public OpenTelemetryContextPropagationTest(TextMapPropagator textMapPropagator, + boolean isCensusClient) { + this.openTelemetrySdk = OpenTelemetrySdk.builder() + .setTracerProvider(SdkTracerProvider.builder().build()) + .setPropagators(ContextPropagators.create(TextMapPropagator.composite( + textMapPropagator + ))) + .build(); + this.tracer = openTelemetrySdk + .getTracer("grpc-java-interop-test"); + GrpcOpenTelemetry.Builder grpcOpentelemetryBuilder = GrpcOpenTelemetry.newBuilder() + .sdk(openTelemetrySdk); + InternalGrpcOpenTelemetry.enableTracing(grpcOpentelemetryBuilder, true); + grpcOpenTelemetry = grpcOpentelemetryBuilder.build(); + this.censusClient = isCensusClient; + } + + @Override + protected ServerBuilder getServerBuilder() { + NettyServerBuilder builder = NettyServerBuilder.forPort(0, InsecureServerCredentials.create()) + .maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE); + builder.intercept(new ServerInterceptor() { + @Override + public ServerCall.Listener interceptCall(ServerCall call, + Metadata headers, ServerCallHandler next) { + ServerCall.Listener listener = next.startCall(call, headers); + return new ForwardingServerCallListener() { + @Override + protected ServerCall.Listener delegate() { + return listener; + } + + @Override + public void onMessage(ReqT request) { + applicationSpan.set(tracer.spanBuilder("InteropTest.Application.Span").startSpan()); + delegate().onMessage(request); + } + + @Override + public void onHalfClose() { + maybeCloseSpan(applicationSpan); + delegate().onHalfClose(); + } + + @Override + public void onCancel() { + maybeCloseSpan(applicationSpan); + delegate().onCancel(); + } + + @Override + public void onComplete() { + maybeCloseSpan(applicationSpan); + delegate().onComplete(); + } + }; + } + }); + // To ensure proper propagation of remote spans from gRPC to your application, this interceptor + // must be after any application interceptors that interact with spans. This allows the tracing + // information to be correctly passed along. However, it's fine for application-level onMessage + // handlers to access the span. + grpcOpenTelemetry.configureServerBuilder(builder); + return builder; + } + + private void maybeCloseSpan(AtomicReference applicationSpan) { + Span tmp = applicationSpan.get(); + if (tmp != null) { + tmp.end(); + } + } + + @Override + protected boolean metricsExpected() { + return false; + } + + @Override + protected ManagedChannelBuilder createChannelBuilder() { + NettyChannelBuilder builder = NettyChannelBuilder.forAddress(getListenAddress()) + .maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE) + .usePlaintext(); + if (!censusClient) { + // Disabling census-tracing is necessary to avoid trace ID mismatches. + // This is because census-tracing overrides the grpc-trace-bin header with + // OpenTelemetry's GrpcTraceBinPropagator. + InternalNettyChannelBuilder.setTracingEnabled(builder, false); + grpcOpenTelemetry.configureChannelBuilder(builder); + } + return builder; + } + + @Test + public void otelSpanContextPropagation() { + Assume.assumeFalse(censusClient); + Span parentSpan = tracer.spanBuilder("Test.interopTest").startSpan(); + try (Scope scope = Context.current().with(parentSpan).makeCurrent()) { + blockingStub.unaryCall(SimpleRequest.getDefaultInstance()); + } + assertEquals(parentSpan.getSpanContext().getTraceId(), + applicationSpan.get().getSpanContext().getTraceId()); + } + + @Test + @SuppressWarnings("deprecation") + public void censusToOtelGrpcTraceBinPropagator() { + Assume.assumeTrue(censusClient); + io.opencensus.trace.Tracer censusTracer = io.opencensus.trace.Tracing.getTracer(); + io.opencensus.trace.Span parentSpan = censusTracer.spanBuilder("Test.interopTest") + .startSpan(); + io.grpc.Context context = io.opencensus.trace.unsafe.ContextUtils.withValue( + io.grpc.Context.current(), parentSpan); + io.grpc.Context previous = context.attach(); + try { + blockingStub.unaryCall(SimpleRequest.getDefaultInstance()); + assertEquals(parentSpan.getContext().getTraceId().toLowerBase16(), + applicationSpan.get().getSpanContext().getTraceId()); + } finally { + context.detach(previous); + } + } +} diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/ProxyTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/ProxyTest.java index f550d657a12..725e98d0fe3 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/ProxyTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/ProxyTest.java @@ -62,7 +62,6 @@ public void shutdownTest() throws IOException { } @Test - @org.junit.Ignore // flaky. latency commonly too high public void smallLatency() throws Exception { server = new Server(); int serverPort = server.init(); diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java index edd2a57ab9d..669ce1c69db 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java @@ -303,7 +303,7 @@ public void retryUntilBufferLimitExceeded() throws Exception { serverCall.close( Status.UNAVAILABLE.withDescription("original attempt failed"), new Metadata()); - elapseBackoff(10, SECONDS); + elapseBackoff(12, SECONDS); // 2nd attempt received serverCall = serverCalls.poll(5, SECONDS); serverCall.request(2); @@ -348,7 +348,7 @@ public void statsRecorded() throws Exception { Status.UNAVAILABLE.withDescription("original attempt failed"), new Metadata()); assertRpcStatusRecorded(Status.Code.UNAVAILABLE, 1000, 1); - elapseBackoff(10, SECONDS); + elapseBackoff(12, SECONDS); assertRpcStartedRecorded(); assertOutboundMessageRecorded(); serverCall = serverCalls.poll(5, SECONDS); @@ -366,7 +366,7 @@ public void statsRecorded() throws Exception { call.request(1); assertInboundMessageRecorded(); assertInboundWireSizeRecorded(1); - assertRpcStatusRecorded(Status.Code.OK, 12000, 2); + assertRpcStatusRecorded(Status.Code.OK, 14000, 2); assertRetryStatsRecorded(1, 0, 0); } @@ -418,7 +418,7 @@ public void streamClosed(Status status) { Status.UNAVAILABLE.withDescription("original attempt failed"), new Metadata()); assertRpcStatusRecorded(Code.UNAVAILABLE, 5000, 1); - elapseBackoff(10, SECONDS); + elapseBackoff(12, SECONDS); assertRpcStartedRecorded(); assertOutboundMessageRecorded(); serverCall = serverCalls.poll(5, SECONDS); @@ -431,7 +431,7 @@ public void streamClosed(Status status) { streamClosedLatch.countDown(); // The call listener is closed. verify(mockCallListener, timeout(5000)).onClose(any(Status.class), any(Metadata.class)); - assertRpcStatusRecorded(Code.CANCELLED, 17_000, 1); + assertRpcStatusRecorded(Code.CANCELLED, 19_000, 1); assertRetryStatsRecorded(1, 0, 0); } diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/RpcBehaviorLoadBalancerProviderTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/RpcBehaviorLoadBalancerProviderTest.java index e19208b8883..4a43af67ac8 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/RpcBehaviorLoadBalancerProviderTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/RpcBehaviorLoadBalancerProviderTest.java @@ -78,6 +78,7 @@ public void parseInvalidConfig() { assertThat(status.getDescription()).contains("rpcBehavior"); } + @Deprecated @Test public void handleResolvedAddressesDelegated() { RpcBehaviorLoadBalancer lb = new RpcBehaviorLoadBalancer(new RpcBehaviorHelper(mockHelper), @@ -87,6 +88,15 @@ public void handleResolvedAddressesDelegated() { verify(mockDelegateLb).handleResolvedAddresses(resolvedAddresses); } + @Test + public void acceptResolvedAddressesDelegated() { + RpcBehaviorLoadBalancer lb = new RpcBehaviorLoadBalancer(new RpcBehaviorHelper(mockHelper), + mockDelegateLb); + ResolvedAddresses resolvedAddresses = buildResolvedAddresses(buildConfig()); + lb.acceptResolvedAddresses(resolvedAddresses); + verify(mockDelegateLb).acceptResolvedAddresses(resolvedAddresses); + } + @Test public void helperWrapsPicker() { RpcBehaviorHelper helper = new RpcBehaviorHelper(mockHelper); @@ -100,7 +110,7 @@ public void helperWrapsPicker() { @Test public void pickerAddsRpcBehaviorMetadata() { PickSubchannelArgsImpl args = new PickSubchannelArgsImpl(TestMethodDescriptors.voidMethod(), - new Metadata(), CallOptions.DEFAULT); + new Metadata(), CallOptions.DEFAULT, new LoadBalancer.PickDetailsConsumer() {}); new RpcBehaviorPicker(mockPicker, "error-code-15").pickSubchannel(args); assertThat(args.getHeaders() diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/StressTestClientTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/StressTestClientTest.java index c09a0cfeab9..a1a2cb9b5ea 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/StressTestClientTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/StressTestClientTest.java @@ -44,13 +44,13 @@ public class StressTestClientTest { @Rule - public final Timeout globalTimeout = Timeout.seconds(10); + public final Timeout globalTimeout = Timeout.seconds(15); @Test public void ipv6AddressesShouldBeSupported() { StressTestClient client = new StressTestClient(); - client.parseArgs(new String[] {"--server_addresses=[0:0:0:0:0:0:0:1]:8080," - + "[1:2:3:4:f:e:a:b]:8083"}); + client.parseArgs(new String[] { + "--server_addresses=[0:0:0:0:0:0:0:1]:8080,[1:2:3:4:f:e:a:b]:8083"}); assertEquals(2, client.addresses().size()); assertEquals(new InetSocketAddress("0:0:0:0:0:0:0:1", 8080), client.addresses().get(0)); diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/XdsTestServerTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/XdsTestServerTest.java new file mode 100644 index 00000000000..7bfa1a2cd7d --- /dev/null +++ b/interop-testing/src/test/java/io/grpc/testing/integration/XdsTestServerTest.java @@ -0,0 +1,113 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.testing.integration; + +import static org.junit.Assert.assertEquals; + +import io.grpc.Channel; +import io.grpc.ChannelCredentials; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import io.grpc.testing.GrpcCleanupRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests to make sure that the {@link XdsTestServer} is working as expected. + * Specifically, that for dualstack communication is handled correctly across address families + * and that the test server is correctly handling the address_type flag. + */ +@RunWith(JUnit4.class) +public class XdsTestServerTest { + protected static final EmptyProtos.Empty EMPTY = EmptyProtos.Empty.getDefaultInstance(); + + @Rule + public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); + + + @Test + public void check_ipv4() throws Exception { + checkConnectionWorks("127.0.0.1", "--address_type=IPV4"); + } + + @Test + public void check_ipv6() throws Exception { + checkConnectionWorks("::1", "--address_type=IPV6"); + } + + @Test + public void check_ipv4_ipv6() throws Exception { + checkConnectionWorks("localhost", "--address_type=IPV4_IPV6"); + } + + @Test + public void checkNoAddressType() throws Exception { + // This ensures that all of the other xds tests aren't broken by the address_type argument. + checkConnectionWorks("localhost", null); + } + + // Simple test to ensure that communication with the server works which includes starting and + // stopping the server, creating a channel and doing a unary rpc. + private void checkConnectionWorks(String targetServer, String addressTypeArg) + throws Exception { + + int port = Util.pickUnusedPort(); + + XdsTestServer server = getAndStartTestServiceServer(port, addressTypeArg); + + try { + ManagedChannel realChannel = createChannel(port, targetServer); + Channel channel = cleanupRule.register(realChannel); + TestServiceGrpc.TestServiceBlockingStub stub = TestServiceGrpc.newBlockingStub(channel); + + assertEquals(EMPTY, stub.emptyCall(EMPTY)); + } catch (Exception e) { + throw new AssertionError(e); + } finally { + server.stop(); + } + } + + private static ManagedChannel createChannel(int port, String target) { + ChannelCredentials creds = InsecureChannelCredentials.create(); + + ManagedChannelBuilder builder; + if (port == 0) { + builder = Grpc.newChannelBuilder(target, creds); + } else { + builder = Grpc.newChannelBuilderForAddress(target, port, creds); + } + + builder.overrideAuthority("foo.test.google.fr"); + return builder.build(); + } + + private static XdsTestServer getAndStartTestServiceServer(int port, String addressTypeArg) + throws Exception { + XdsTestServer server = new XdsTestServer(); + String[] args = addressTypeArg != null + ? new String[]{"--port=" + port, addressTypeArg} + : new String[]{"--port=" + port}; + server.parseArgs(args); + server.start(); + return server; + } + +} diff --git a/istio-interop-testing/build.gradle b/istio-interop-testing/build.gradle index e2fe228f13b..083d8fcb9bf 100644 --- a/istio-interop-testing/build.gradle +++ b/istio-interop-testing/build.gradle @@ -18,8 +18,6 @@ dependencies { project(':grpc-testing'), project(':grpc-xds') - compileOnly libraries.javax.annotation - runtimeOnly libraries.netty.tcnative, libraries.netty.tcnative.classes testImplementation testFixtures(project(':grpc-api')), @@ -28,7 +26,11 @@ dependencies { libraries.junit, libraries.truth - signature libraries.signature.java + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } } sourceSets { diff --git a/istio-interop-testing/src/generated/main/grpc/io/istio/test/EchoTestServiceGrpc.java b/istio-interop-testing/src/generated/main/grpc/io/istio/test/EchoTestServiceGrpc.java index 1f48c16aed3..61d20d2f7bb 100644 --- a/istio-interop-testing/src/generated/main/grpc/io/istio/test/EchoTestServiceGrpc.java +++ b/istio-interop-testing/src/generated/main/grpc/io/istio/test/EchoTestServiceGrpc.java @@ -4,9 +4,6 @@ /** */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: test/echo/proto/echo.proto") @io.grpc.stub.annotations.GrpcGenerated public final class EchoTestServiceGrpc { @@ -91,6 +88,21 @@ public EchoTestServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions return EchoTestServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static EchoTestServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public EchoTestServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new EchoTestServiceBlockingV2Stub(channel, callOptions); + } + }; + return EchoTestServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -187,6 +199,37 @@ public void forwardEcho(io.istio.test.Echo.ForwardEchoRequest request, /** * A stub to allow clients to do synchronous rpc calls to service EchoTestService. */ + public static final class EchoTestServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private EchoTestServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected EchoTestServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new EchoTestServiceBlockingV2Stub(channel, callOptions); + } + + /** + */ + public io.istio.test.Echo.EchoResponse echo(io.istio.test.Echo.EchoRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getEchoMethod(), getCallOptions(), request); + } + + /** + */ + public io.istio.test.Echo.ForwardEchoResponse forwardEcho(io.istio.test.Echo.ForwardEchoRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getForwardEchoMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service EchoTestService. + */ public static final class EchoTestServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private EchoTestServiceBlockingStub( diff --git a/java_grpc_library.bzl b/java_grpc_library.bzl index 22487954aef..e6afc028883 100644 --- a/java_grpc_library.bzl +++ b/java_grpc_library.bzl @@ -1,5 +1,8 @@ """Build rule for java_grpc_library.""" +load("@com_google_protobuf//bazel/common:proto_info.bzl", "ProtoInfo") +load("@rules_java//java:defs.bzl", "JavaInfo", "JavaPluginInfo", "java_common") + _JavaRpcToolchainInfo = provider( fields = [ "java_toolchain", @@ -89,15 +92,15 @@ def _java_rpc_library_impl(ctx): srcjar = ctx.actions.declare_file("%s-proto-gensrc.jar" % ctx.label.name) args = ctx.actions.args() - args.add(toolchain.plugin.files_to_run.executable, format = "--plugin=protoc-gen-rpc-plugin=%s") + args.add(toolchain.plugin[DefaultInfo].files_to_run.executable, format = "--plugin=protoc-gen-rpc-plugin=%s") args.add("--rpc-plugin_out={0}:{1}".format(toolchain.plugin_arg, srcjar.path)) args.add_joined("--descriptor_set_in", descriptor_set_in, join_with = ctx.configuration.host_path_separator) args.add_all(srcs, map_each = _path_ignoring_repository) ctx.actions.run( - inputs = depset(srcs, transitive = [descriptor_set_in, toolchain.plugin.files]), + inputs = depset(srcs, transitive = [descriptor_set_in, toolchain.plugin[DefaultInfo].files]), outputs = [srcjar], - executable = toolchain.protoc.files_to_run, + executable = toolchain.protoc[DefaultInfo].files_to_run, arguments = [args], use_default_shell_env = True, toolchain = None, @@ -145,6 +148,33 @@ _java_grpc_library = rule( implementation = _java_rpc_library_impl, ) +# A copy of _java_grpc_library, except with a neverlink=1 _toolchain +INTERNAL_java_grpc_library_for_xds = rule( + attrs = { + "srcs": attr.label_list( + mandatory = True, + allow_empty = False, + providers = [ProtoInfo], + ), + "deps": attr.label_list( + mandatory = True, + allow_empty = False, + providers = [JavaInfo], + ), + "_toolchain": attr.label( + default = Label("//xds:java_grpc_library_toolchain"), + ), + }, + toolchains = ["@bazel_tools//tools/jdk:toolchain_type"], + fragments = ["java"], + outputs = { + "jar": "lib%{name}.jar", + "srcjar": "lib%{name}-src.jar", + }, + provides = [JavaInfo], + implementation = _java_rpc_library_impl, +) + _java_lite_grpc_library = rule( attrs = { "srcs": attr.label_list( diff --git a/lint.xml b/lint.xml new file mode 100644 index 00000000000..5b35a8d151b --- /dev/null +++ b/lint.xml @@ -0,0 +1,13 @@ + + + + + + + + diff --git a/netty/BUILD.bazel b/netty/BUILD.bazel index d2497d065ec..8253d1f5bff 100644 --- a/netty/BUILD.bazel +++ b/netty/BUILD.bazel @@ -1,3 +1,6 @@ +load("@rules_java//java:defs.bzl", "java_library") +load("@rules_jvm_external//:defs.bzl", "artifact") + java_library( name = "netty", srcs = glob([ @@ -10,22 +13,22 @@ java_library( deps = [ "//api", "//core:internal", - "@com_google_code_findbugs_jsr305//jar", - "@com_google_errorprone_error_prone_annotations//jar", - "@com_google_guava_guava//jar", - "@com_google_j2objc_j2objc_annotations//jar", - "@io_netty_netty_buffer//jar", - "@io_netty_netty_codec//jar", - "@io_netty_netty_codec_http//jar", - "@io_netty_netty_codec_http2//jar", - "@io_netty_netty_codec_socks//jar", - "@io_netty_netty_common//jar", - "@io_netty_netty_handler//jar", - "@io_netty_netty_handler_proxy//jar", - "@io_netty_netty_resolver//jar", - "@io_netty_netty_transport//jar", - "@io_netty_netty_transport_native_unix_common//jar", - "@io_perfmark_perfmark_api//jar", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:guava"), + artifact("io.netty:netty-buffer"), + artifact("io.netty:netty-codec"), + artifact("io.netty:netty-codec-http"), + artifact("io.netty:netty-codec-http2"), + artifact("io.netty:netty-codec-socks"), + artifact("io.netty:netty-common"), + artifact("io.netty:netty-handler"), + artifact("io.netty:netty-handler-proxy"), + artifact("io.netty:netty-resolver"), + artifact("io.netty:netty-transport"), + artifact("io.netty:netty-transport-native-unix-common"), + artifact("io.perfmark:perfmark-api"), + artifact("org.codehaus.mojo:animal-sniffer-annotations"), ], ) diff --git a/netty/build.gradle b/netty/build.gradle index 7bff4bfd377..cb97ae10b55 100644 --- a/netty/build.gradle +++ b/netty/build.gradle @@ -17,6 +17,7 @@ tasks.named("jar").configure { dependencies { api project(':grpc-api'), + libraries.animalsniffer.annotations, libraries.netty.codec.http2 implementation project(':grpc-core'), libs.netty.handler.proxy, @@ -65,14 +66,22 @@ dependencies { classifier = "linux-x86_64" } } - signature libraries.signature.java - signature libraries.signature.android + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } + signature (libraries.signature.android) { + artifact { + extension = "signature" + } + } } import net.ltgt.gradle.errorprone.CheckSeverity [tasks.named("compileJava"), tasks.named("compileTestJava")]*.configure { - // Netty retuns a lot of futures that we mostly don't care about. + // Netty returns a lot of futures that we mostly don't care about. options.errorprone.check("FutureReturnValueIgnored", CheckSeverity.OFF) } diff --git a/netty/shaded/BUILD.bazel b/netty/shaded/BUILD.bazel index 657bf6aafa9..0a93907bd2f 100644 --- a/netty/shaded/BUILD.bazel +++ b/netty/shaded/BUILD.bazel @@ -1,12 +1,15 @@ +load("@rules_java//java:defs.bzl", "java_library") +load("@rules_jvm_external//:defs.bzl", "artifact") + # Publicly exposed in //netty package. Purposefully does not export any symbols. java_library( name = "shaded", visibility = ["//netty:__pkg__"], runtime_deps = [ "//netty", - "@io_netty_netty_tcnative_boringssl_static//jar", - "@io_netty_netty_tcnative_classes//jar", - "@io_netty_netty_transport_native_unix_common//jar", - "@io_netty_netty_transport_native_epoll_linux_x86_64//jar", + artifact("io.netty:netty-tcnative-boringssl-static"), + artifact("io.netty:netty-tcnative-classes"), + artifact("io.netty:netty-transport-native-unix-common"), + artifact("io.netty:netty-transport-native-epoll_linux_x86_64"), ], ) diff --git a/netty/shaded/build.gradle b/netty/shaded/build.gradle index 3e52c3e0d95..27816f9380b 100644 --- a/netty/shaded/build.gradle +++ b/netty/shaded/build.gradle @@ -9,7 +9,7 @@ plugins { id "java" id "maven-publish" - id "com.github.johnrengelman.shadow" + id "com.gradleup.shadow" id "ru.vyarus.animalsniffer" } @@ -17,6 +17,8 @@ description = "gRPC: Netty Shaded" sourceSets { testShadow {} } +evaluationDependsOn(':grpc-netty') + dependencies { implementation project(':grpc-netty') runtimeOnly libraries.netty.tcnative, @@ -63,8 +65,16 @@ dependencies { shadow project(':grpc-netty').configurations.runtimeClasspath.allDependencies.matching { it.group != 'io.netty' } - signature libraries.signature.java - signature libraries.signature.android + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } + signature (libraries.signature.android) { + artifact { + extension = "signature" + } + } } tasks.named("jar").configure { @@ -143,7 +153,11 @@ class NettyResourceTransformer implements Transformer { @Override boolean canTransformResource(FileTreeElement fileTreeElement) { - fileTreeElement.name.startsWith("META-INF/native-image/io.netty") + // io.netty.versions.properties can't actually be shaded successfully, + // as io.netty.util.Version still looks for the unshaded name. But we + // keep the file for manual inspection. + fileTreeElement.name.startsWith("META-INF/native-image/io.netty") || + fileTreeElement.name.startsWith("META-INF/io.netty.versions.properties") } @Override diff --git a/netty/shaded/src/testShadow/java/io/grpc/netty/shaded/ShadingTest.java b/netty/shaded/src/testShadow/java/io/grpc/netty/shaded/ShadingTest.java index 7a5e4b43c8b..89803998925 100644 --- a/netty/shaded/src/testShadow/java/io/grpc/netty/shaded/ShadingTest.java +++ b/netty/shaded/src/testShadow/java/io/grpc/netty/shaded/ShadingTest.java @@ -79,7 +79,7 @@ public void nettyResourcesUpdated() throws IOException { InputStream inputStream = NettyChannelBuilder.class.getClassLoader() .getResourceAsStream( "META-INF/native-image/io.grpc.netty.shaded.io.netty/netty-transport/" - + "reflection-config.json"); + + "reflect-config.json"); assertThat(inputStream).isNotNull(); Scanner s = new Scanner(inputStream, StandardCharsets.UTF_8.name()).useDelimiter("\\A"); diff --git a/netty/src/main/java/io/grpc/netty/AbstractNettyHandler.java b/netty/src/main/java/io/grpc/netty/AbstractNettyHandler.java index 7f088509c04..c4ec5913cde 100644 --- a/netty/src/main/java/io/grpc/netty/AbstractNettyHandler.java +++ b/netty/src/main/java/io/grpc/netty/AbstractNettyHandler.java @@ -42,13 +42,15 @@ abstract class AbstractNettyHandler extends GrpcHttp2ConnectionHandler { private final int initialConnectionWindow; private final FlowControlPinger flowControlPing; - + protected final int maxHeaderListSize; + protected final int softLimitHeaderListSize; private boolean autoTuneFlowControlOn; private ChannelHandlerContext ctx; private boolean initialWindowSent = false; private final Ticker ticker; private static final long BDP_MEASUREMENT_PING = 1234; + protected static final int MIN_ALLOCATED_CHUNK = 16 * 1024; AbstractNettyHandler( ChannelPromise channelUnused, @@ -58,7 +60,9 @@ abstract class AbstractNettyHandler extends GrpcHttp2ConnectionHandler { ChannelLogger negotiationLogger, boolean autoFlowControl, PingLimiter pingLimiter, - Ticker ticker) { + Ticker ticker, + int maxHeaderListSize, + int softLimitHeaderListSize) { super(channelUnused, decoder, encoder, initialSettings, negotiationLogger); // During a graceful shutdown, wait until all streams are closed. @@ -73,6 +77,8 @@ abstract class AbstractNettyHandler extends GrpcHttp2ConnectionHandler { } this.flowControlPing = new FlowControlPinger(pingLimiter); this.ticker = checkNotNull(ticker, "ticker"); + this.maxHeaderListSize = maxHeaderListSize; + this.softLimitHeaderListSize = softLimitHeaderListSize; } @Override diff --git a/netty/src/main/java/io/grpc/netty/CancelServerStreamCommand.java b/netty/src/main/java/io/grpc/netty/CancelServerStreamCommand.java index d9f5d96e06e..a3b29457670 100644 --- a/netty/src/main/java/io/grpc/netty/CancelServerStreamCommand.java +++ b/netty/src/main/java/io/grpc/netty/CancelServerStreamCommand.java @@ -27,10 +27,23 @@ final class CancelServerStreamCommand extends WriteQueue.AbstractQueuedCommand { private final NettyServerStream.TransportState stream; private final Status reason; + private final PeerNotify peerNotify; - CancelServerStreamCommand(NettyServerStream.TransportState stream, Status reason) { + private CancelServerStreamCommand( + NettyServerStream.TransportState stream, Status reason, PeerNotify peerNotify) { this.stream = Preconditions.checkNotNull(stream, "stream"); this.reason = Preconditions.checkNotNull(reason, "reason"); + this.peerNotify = Preconditions.checkNotNull(peerNotify, "peerNotify"); + } + + static CancelServerStreamCommand withReset( + NettyServerStream.TransportState stream, Status reason) { + return new CancelServerStreamCommand(stream, reason, PeerNotify.RESET); + } + + static CancelServerStreamCommand withReason( + NettyServerStream.TransportState stream, Status reason) { + return new CancelServerStreamCommand(stream, reason, PeerNotify.BEST_EFFORT_STATUS); } NettyServerStream.TransportState stream() { @@ -41,6 +54,10 @@ Status reason() { return reason; } + boolean wantsHeaders() { + return peerNotify == PeerNotify.BEST_EFFORT_STATUS; + } + @Override public boolean equals(Object o) { if (this == o) { @@ -52,13 +69,14 @@ public boolean equals(Object o) { CancelServerStreamCommand that = (CancelServerStreamCommand) o; - return Objects.equal(this.stream, that.stream) - && Objects.equal(this.reason, that.reason); + return this.stream.equals(that.stream) + && this.reason.equals(that.reason) + && this.peerNotify.equals(that.peerNotify); } @Override public int hashCode() { - return Objects.hashCode(stream, reason); + return Objects.hashCode(stream, reason, peerNotify); } @Override @@ -66,6 +84,14 @@ public String toString() { return MoreObjects.toStringHelper(this) .add("stream", stream) .add("reason", reason) + .add("peerNotify", peerNotify) .toString(); } + + private enum PeerNotify { + /** Notify the peer by sending a RST_STREAM with no other information. */ + RESET, + /** Notify the peer about the {@link #reason} by sending structured headers, if possible. */ + BEST_EFFORT_STATUS, + } } diff --git a/netty/src/main/java/io/grpc/netty/ClientTransportLifecycleManager.java b/netty/src/main/java/io/grpc/netty/ClientTransportLifecycleManager.java index 34f72ab97bd..01e7bc3ed12 100644 --- a/netty/src/main/java/io/grpc/netty/ClientTransportLifecycleManager.java +++ b/netty/src/main/java/io/grpc/netty/ClientTransportLifecycleManager.java @@ -19,6 +19,7 @@ import com.google.errorprone.annotations.CanIgnoreReturnValue; import io.grpc.Attributes; import io.grpc.Status; +import io.grpc.internal.DisconnectError; import io.grpc.internal.ManagedClientTransport; /** Maintainer of transport lifecycle status. */ @@ -30,7 +31,6 @@ final class ClientTransportLifecycleManager { /** null iff !transportShutdown. */ private Status shutdownStatus; /** null iff !transportShutdown. */ - private Throwable shutdownThrowable; private boolean transportTerminated; public ClientTransportLifecycleManager(ManagedClientTransport.Listener listener) { @@ -56,23 +56,22 @@ public void notifyReady() { * Marks transport as shutdown, but does not set the error status. This must eventually be * followed by a call to notifyShutdown. */ - public void notifyGracefulShutdown(Status s) { + public void notifyGracefulShutdown(Status s, DisconnectError disconnectError) { if (transportShutdown) { return; } transportShutdown = true; - listener.transportShutdown(s); + listener.transportShutdown(s, disconnectError); } /** Returns {@code true} if was the first shutdown. */ @CanIgnoreReturnValue - public boolean notifyShutdown(Status s) { - notifyGracefulShutdown(s); + public boolean notifyShutdown(Status s, DisconnectError disconnectError) { + notifyGracefulShutdown(s, disconnectError); if (shutdownStatus != null) { return false; } shutdownStatus = s; - shutdownThrowable = s.asException(); return true; } @@ -84,12 +83,12 @@ public void notifyInUse(boolean inUse) { listener.transportInUse(inUse); } - public void notifyTerminated(Status s) { + public void notifyTerminated(Status s, DisconnectError disconnectError) { if (transportTerminated) { return; } transportTerminated = true; - notifyShutdown(s); + notifyShutdown(s, disconnectError); listener.transportTerminated(); } @@ -97,7 +96,4 @@ public Status getShutdownStatus() { return shutdownStatus; } - public Throwable getShutdownThrowable() { - return shutdownThrowable; - } } diff --git a/netty/src/main/java/io/grpc/netty/GrpcHttp2ConnectionHandler.java b/netty/src/main/java/io/grpc/netty/GrpcHttp2ConnectionHandler.java index 13f55226483..ee5227484fb 100644 --- a/netty/src/main/java/io/grpc/netty/GrpcHttp2ConnectionHandler.java +++ b/netty/src/main/java/io/grpc/netty/GrpcHttp2ConnectionHandler.java @@ -34,14 +34,11 @@ */ @Internal public abstract class GrpcHttp2ConnectionHandler extends Http2ConnectionHandler { - static final int ADAPTIVE_CUMULATOR_COMPOSE_MIN_SIZE_DEFAULT = 1024; - static final Cumulator ADAPTIVE_CUMULATOR = - new NettyAdaptiveCumulator(ADAPTIVE_CUMULATOR_COMPOSE_MIN_SIZE_DEFAULT); - @Nullable protected final ChannelPromise channelUnused; private final ChannelLogger negotiationLogger; + @SuppressWarnings("this-escape") protected GrpcHttp2ConnectionHandler( ChannelPromise channelUnused, Http2ConnectionDecoder decoder, @@ -51,7 +48,6 @@ protected GrpcHttp2ConnectionHandler( super(decoder, encoder, initialSettings); this.channelUnused = channelUnused; this.negotiationLogger = negotiationLogger; - setCumulator(ADAPTIVE_CUMULATOR); } /** diff --git a/netty/src/main/java/io/grpc/netty/GrpcHttp2HeadersUtils.java b/netty/src/main/java/io/grpc/netty/GrpcHttp2HeadersUtils.java index c0d60721a1b..96c4310ae3d 100644 --- a/netty/src/main/java/io/grpc/netty/GrpcHttp2HeadersUtils.java +++ b/netty/src/main/java/io/grpc/netty/GrpcHttp2HeadersUtils.java @@ -31,12 +31,12 @@ package io.grpc.netty; -import static com.google.common.base.Charsets.US_ASCII; import static com.google.common.base.Preconditions.checkArgument; import static io.grpc.netty.Utils.TE_HEADER; import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; import static io.netty.handler.codec.http2.Http2Exception.connectionError; import static io.netty.util.AsciiString.isUpperCase; +import static java.nio.charset.StandardCharsets.US_ASCII; import com.google.common.io.BaseEncoding; import com.google.errorprone.annotations.CanIgnoreReturnValue; diff --git a/netty/src/main/java/io/grpc/netty/GrpcHttp2OutboundHeaders.java b/netty/src/main/java/io/grpc/netty/GrpcHttp2OutboundHeaders.java index 0489e135813..aabcd4fbaaa 100644 --- a/netty/src/main/java/io/grpc/netty/GrpcHttp2OutboundHeaders.java +++ b/netty/src/main/java/io/grpc/netty/GrpcHttp2OutboundHeaders.java @@ -66,6 +66,16 @@ private GrpcHttp2OutboundHeaders(AsciiString[] preHeaders, byte[][] serializedMe this.preHeaders = preHeaders; } + @Override + public CharSequence authority() { + for (int i = 0; i < preHeaders.length / 2; i++) { + if (preHeaders[i * 2].equals(Http2Headers.PseudoHeaderName.AUTHORITY.value())) { + return preHeaders[i * 2 + 1]; + } + } + return null; + } + @Override @SuppressWarnings("ReferenceEquality") // STATUS.value() never changes. public CharSequence status() { diff --git a/netty/src/main/java/io/grpc/netty/GrpcSslContexts.java b/netty/src/main/java/io/grpc/netty/GrpcSslContexts.java index 04a290165d7..f1f2c8aed71 100644 --- a/netty/src/main/java/io/grpc/netty/GrpcSslContexts.java +++ b/netty/src/main/java/io/grpc/netty/GrpcSslContexts.java @@ -84,6 +84,7 @@ private GrpcSslContexts() {} private static final String SUN_PROVIDER_NAME = "SunJSSE"; private static final String IBM_PROVIDER_NAME = "IBMJSSE2"; private static final String OPENJSSE_PROVIDER_NAME = "OpenJSSE"; + private static final String BCJSSE_PROVIDER_NAME = "BCJSSE"; /** * Creates an SslContextBuilder with ciphers and APN appropriate for gRPC. @@ -199,7 +200,8 @@ public static SslContextBuilder configure(SslContextBuilder builder, Provider jd jdkProvider.getName() + " selected, but Java 9+ and Jetty NPN/ALPN unavailable"); } } else if (IBM_PROVIDER_NAME.equals(jdkProvider.getName()) - || OPENJSSE_PROVIDER_NAME.equals(jdkProvider.getName())) { + || OPENJSSE_PROVIDER_NAME.equals(jdkProvider.getName()) + || BCJSSE_PROVIDER_NAME.equals(jdkProvider.getName())) { if (JettyTlsUtil.isJava9AlpnAvailable()) { apc = ALPN; } else { @@ -255,7 +257,8 @@ private static Provider findJdkProvider() { return provider; } } else if (IBM_PROVIDER_NAME.equals(provider.getName()) - || OPENJSSE_PROVIDER_NAME.equals(provider.getName())) { + || OPENJSSE_PROVIDER_NAME.equals(provider.getName()) + || BCJSSE_PROVIDER_NAME.equals(provider.getName())) { if (JettyTlsUtil.isJava9AlpnAvailable()) { return provider; } diff --git a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java index 0d309828c6d..35dc1bbc2e8 100644 --- a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java @@ -16,13 +16,17 @@ package io.grpc.netty; +import com.google.common.base.Optional; import io.grpc.ChannelLogger; +import io.grpc.internal.ObjectPool; import io.grpc.netty.ProtocolNegotiators.ClientTlsHandler; import io.grpc.netty.ProtocolNegotiators.GrpcNegotiationHandler; import io.grpc.netty.ProtocolNegotiators.WaitUntilActiveHandler; import io.netty.channel.ChannelHandler; import io.netty.handler.ssl.SslContext; import io.netty.util.AsciiString; +import java.util.concurrent.Executor; +import javax.net.ssl.X509TrustManager; /** * Internal accessor for {@link ProtocolNegotiators}. @@ -35,9 +39,15 @@ private InternalProtocolNegotiators() {} * Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel} * may happen immediately, even before the TLS Handshake is complete. + * @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks */ - public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext) { - final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext); + public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext, + ObjectPool executorPool, + Optional handshakeCompleteRunnable, + X509TrustManager extendedX509TrustManager, + String sni) { + final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext, + executorPool, handshakeCompleteRunnable, extendedX509TrustManager, sni); final class TlsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator { @Override @@ -55,10 +65,21 @@ public void close() { negotiator.close(); } } - + return new TlsNegotiator(); } + /** + * Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will + * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel} + * may happen immediately, even before the TLS Handshake is complete. + */ + public static InternalProtocolNegotiator.ProtocolNegotiator tls( + SslContext sslContext, String sni, + X509TrustManager extendedX509TrustManager) { + return tls(sslContext, null, Optional.absent(), extendedX509TrustManager, sni); + } + /** * Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will be * negotiated, the server TLS {@code handler} is added and writes to the {@link @@ -153,7 +174,8 @@ public static ChannelHandler grpcNegotiationHandler(GrpcHttp2ConnectionHandler n public static ChannelHandler clientTlsHandler( ChannelHandler next, SslContext sslContext, String authority, ChannelLogger negotiationLogger) { - return new ClientTlsHandler(next, sslContext, authority, null, negotiationLogger); + return new ClientTlsHandler(next, sslContext, authority, null, negotiationLogger, + Optional.absent(), null, null); } public static class ProtocolNegotiationHandler diff --git a/netty/src/main/java/io/grpc/netty/NettyAdaptiveCumulator.java b/netty/src/main/java/io/grpc/netty/NettyAdaptiveCumulator.java index 58eabb2cf8d..a1fe3fc2c38 100644 --- a/netty/src/main/java/io/grpc/netty/NettyAdaptiveCumulator.java +++ b/netty/src/main/java/io/grpc/netty/NettyAdaptiveCumulator.java @@ -23,6 +23,12 @@ import io.netty.buffer.CompositeByteBuf; import io.netty.handler.codec.ByteToMessageDecoder.Cumulator; + +/** + * "Adaptive" cumulator: cumulate {@link ByteBuf}s by dynamically switching between merge and + * compose strategies. + */ + class NettyAdaptiveCumulator implements Cumulator { private final int composeMinSize; @@ -152,6 +158,7 @@ static void mergeWithCompositeTail( try { if (tail.refCnt() == 1 && !tail.isReadOnly() && newTailSize <= tail.maxCapacity()) { // Ideal case: the tail isn't shared, and can be expanded to the required capacity. + // Take ownership of the tail. newTail = tail.retain(); @@ -188,6 +195,7 @@ static void mergeWithCompositeTail( * as pronounced because the capacity is doubled with each reallocation. */ newTail.writeBytes(in); + } else { // The tail is shared, or not expandable. Replace it with a new buffer of desired capacity. newTail = alloc.buffer(alloc.calculateNewCapacity(newTailSize, Integer.MAX_VALUE)); @@ -196,6 +204,7 @@ static void mergeWithCompositeTail( .writerIndex(newTailSize); in.readerIndex(in.writerIndex()); } + // Store readerIndex to avoid out of bounds writerIndex during component replacement. int prevReader = composite.readerIndex(); // Remove the old tail, reset writer index. diff --git a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java index 305ad128454..e64f1065681 100644 --- a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java +++ b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java @@ -23,6 +23,7 @@ import static io.grpc.internal.GrpcUtil.KEEPALIVE_TIME_NANOS_DISABLED; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Optional; import com.google.common.base.Ticker; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.CheckReturnValue; @@ -103,6 +104,7 @@ public final class NettyChannelBuilder extends ForwardingChannelBuilder2 0, "maxInboundMetadataSize must be > 0"); this.maxHeaderListSize = bytes; + // Clear the soft limit setting, by setting soft limit to maxInboundMetadataSize. The + // maxInboundMetadataSize will take precedence be applied before soft limit check. + this.softLimitHeaderListSize = bytes; + return this; + } + + /** + * Sets the size of metadata that clients are advised to not exceed. When a metadata with size + * larger than the soft limit is encountered there will be a probability the RPC will fail. The + * chance of failing increases as the metadata size approaches the hard limit. + * {@code Integer.MAX_VALUE} disables the enforcement. The default is implementation-dependent, + * but is not generally less than 8 KiB and may be unlimited. + * + *

This is cumulative size of the metadata. The precise calculation is + * implementation-dependent, but implementations are encouraged to follow the calculation used + * for + * HTTP/2's + * SETTINGS_MAX_HEADER_LIST_SIZE. It sums the bytes from each entry's key and value, plus 32 + * bytes of overhead per entry. + * + * @param soft the soft size limit of received metadata + * @param max the hard size limit of received metadata + * @return this + * @throws IllegalArgumentException if soft and/or max is non-positive, or max smaller than + * soft + * @since 1.68.0 + */ + @CanIgnoreReturnValue + public NettyChannelBuilder maxInboundMetadataSize(int soft, int max) { + checkArgument(soft > 0, "softLimitHeaderListSize must be > 0"); + checkArgument(max > soft, + "maxInboundMetadataSize must be greater than softLimitHeaderListSize"); + this.softLimitHeaderListSize = soft; + this.maxHeaderListSize = max; return this; } @@ -572,10 +608,22 @@ ClientTransportFactory buildTransportFactory() { ProtocolNegotiator negotiator = protocolNegotiatorFactory.newNegotiator(); return new NettyTransportFactory( - negotiator, channelFactory, channelOptions, - eventLoopGroupPool, autoFlowControl, flowControlWindow, maxInboundMessageSize, - maxHeaderListSize, keepAliveTimeNanos, keepAliveTimeoutNanos, keepAliveWithoutCalls, - transportTracerFactory, localSocketPicker, useGetForSafeMethods, transportSocketType); + negotiator, + channelFactory, + channelOptions, + eventLoopGroupPool, + autoFlowControl, + flowControlWindow, + maxInboundMessageSize, + maxHeaderListSize, + softLimitHeaderListSize, + keepAliveTimeNanos, + keepAliveTimeoutNanos, + keepAliveWithoutCalls, + transportTracerFactory, + localSocketPicker, + useGetForSafeMethods, + transportSocketType); } @VisibleForTesting @@ -604,7 +652,7 @@ static ProtocolNegotiator createProtocolNegotiatorByType( case PLAINTEXT_UPGRADE: return ProtocolNegotiators.plaintextUpgrade(); case TLS: - return ProtocolNegotiators.tls(sslContext, executorPool); + return ProtocolNegotiators.tls(sslContext, executorPool, Optional.absent(), null, null); default: throw new IllegalArgumentException("Unsupported negotiationType: " + negotiationType); } @@ -709,6 +757,7 @@ private static final class NettyTransportFactory implements ClientTransportFacto private final int flowControlWindow; private final int maxMessageSize; private final int maxHeaderListSize; + private final int softLimitHeaderListSize; private final long keepAliveTimeNanos; private final AtomicBackoff keepAliveBackoff; private final long keepAliveTimeoutNanos; @@ -723,11 +772,20 @@ private static final class NettyTransportFactory implements ClientTransportFacto NettyTransportFactory( ProtocolNegotiator protocolNegotiator, ChannelFactory channelFactory, - Map, ?> channelOptions, ObjectPool groupPool, - boolean autoFlowControl, int flowControlWindow, int maxMessageSize, int maxHeaderListSize, - long keepAliveTimeNanos, long keepAliveTimeoutNanos, boolean keepAliveWithoutCalls, - TransportTracer.Factory transportTracerFactory, LocalSocketPicker localSocketPicker, - boolean useGetForSafeMethods, Class transportSocketType) { + Map, ?> channelOptions, + ObjectPool groupPool, + boolean autoFlowControl, + int flowControlWindow, + int maxMessageSize, + int maxHeaderListSize, + int softLimitHeaderListSize, + long keepAliveTimeNanos, + long keepAliveTimeoutNanos, + boolean keepAliveWithoutCalls, + TransportTracer.Factory transportTracerFactory, + LocalSocketPicker localSocketPicker, + boolean useGetForSafeMethods, + Class transportSocketType) { this.protocolNegotiator = checkNotNull(protocolNegotiator, "protocolNegotiator"); this.channelFactory = channelFactory; this.channelOptions = new HashMap, Object>(channelOptions); @@ -737,6 +795,7 @@ private static final class NettyTransportFactory implements ClientTransportFacto this.flowControlWindow = flowControlWindow; this.maxMessageSize = maxMessageSize; this.maxHeaderListSize = maxHeaderListSize; + this.softLimitHeaderListSize = softLimitHeaderListSize; this.keepAliveTimeNanos = keepAliveTimeNanos; this.keepAliveBackoff = new AtomicBackoff("keepalive time nanos", keepAliveTimeNanos); this.keepAliveTimeoutNanos = keepAliveTimeoutNanos; @@ -759,6 +818,7 @@ public ConnectionClientTransport newClientTransport( serverAddress = proxiedAddr.getTargetAddress(); localNegotiator = ProtocolNegotiators.httpProxy( proxiedAddr.getProxyAddress(), + proxiedAddr.getHeaders(), proxiedAddr.getUsername(), proxiedAddr.getPassword(), protocolNegotiator); @@ -773,13 +833,31 @@ public void run() { }; // TODO(carl-mastrangelo): Pass channelLogger in. - NettyClientTransport transport = new NettyClientTransport( - serverAddress, channelFactory, channelOptions, group, - localNegotiator, autoFlowControl, flowControlWindow, - maxMessageSize, maxHeaderListSize, keepAliveTimeNanosState.get(), keepAliveTimeoutNanos, - keepAliveWithoutCalls, options.getAuthority(), options.getUserAgent(), - tooManyPingsRunnable, transportTracerFactory.create(), options.getEagAttributes(), - localSocketPicker, channelLogger, useGetForSafeMethods, Ticker.systemTicker()); + NettyClientTransport transport = + new NettyClientTransport( + serverAddress, + channelFactory, + channelOptions, + group, + localNegotiator, + autoFlowControl, + flowControlWindow, + maxMessageSize, + maxHeaderListSize, + softLimitHeaderListSize, + keepAliveTimeNanosState.get(), + keepAliveTimeoutNanos, + keepAliveWithoutCalls, + options.getAuthority(), + options.getUserAgent(), + tooManyPingsRunnable, + transportTracerFactory.create(), + options.getEagAttributes(), + localSocketPicker, + channelLogger, + useGetForSafeMethods, + options.getMetricRecorder(), + Ticker.systemTicker()); return transport; } @@ -795,11 +873,24 @@ public SwapChannelCredentialsResult swapChannelCredentials(ChannelCredentials ch if (result.error != null) { return null; } - ClientTransportFactory factory = new NettyTransportFactory( - result.negotiator.newNegotiator(), channelFactory, channelOptions, groupPool, - autoFlowControl, flowControlWindow, maxMessageSize, maxHeaderListSize, keepAliveTimeNanos, - keepAliveTimeoutNanos, keepAliveWithoutCalls, transportTracerFactory, localSocketPicker, - useGetForSafeMethods, transportSocketType); + ClientTransportFactory factory = + new NettyTransportFactory( + result.negotiator.newNegotiator(), + channelFactory, + channelOptions, + groupPool, + autoFlowControl, + flowControlWindow, + maxMessageSize, + maxHeaderListSize, + softLimitHeaderListSize, + keepAliveTimeNanos, + keepAliveTimeoutNanos, + keepAliveWithoutCalls, + transportTracerFactory, + localSocketPicker, + useGetForSafeMethods, + transportSocketType); return new SwapChannelCredentialsResult(factory, result.callCredentials); } diff --git a/netty/src/main/java/io/grpc/netty/NettyClientHandler.java b/netty/src/main/java/io/grpc/netty/NettyClientHandler.java index eb4dbf8cc66..14a1d7535ad 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientHandler.java @@ -28,16 +28,21 @@ import io.grpc.Attributes; import io.grpc.ChannelLogger; import io.grpc.InternalChannelz; +import io.grpc.InternalStatus; import io.grpc.Metadata; +import io.grpc.MetricRecorder; import io.grpc.Status; import io.grpc.StatusException; import io.grpc.internal.ClientStreamListener.RpcProgress; import io.grpc.internal.ClientTransport.PingCallback; +import io.grpc.internal.DisconnectError; +import io.grpc.internal.GoAwayDisconnectError; import io.grpc.internal.GrpcAttributes; import io.grpc.internal.GrpcUtil; import io.grpc.internal.Http2Ping; import io.grpc.internal.InUseStateAggregator; import io.grpc.internal.KeepAliveManager; +import io.grpc.internal.SimpleDisconnectError; import io.grpc.internal.TransportTracer; import io.grpc.netty.GrpcHttp2HeadersUtils.GrpcHttp2ClientHeadersDecoder; import io.netty.buffer.ByteBuf; @@ -77,12 +82,14 @@ import io.netty.handler.codec.http2.Http2Stream; import io.netty.handler.codec.http2.Http2StreamVisitor; import io.netty.handler.codec.http2.StreamBufferingEncoder; -import io.netty.handler.codec.http2.WeightedFairQueueByteDistributor; +import io.netty.handler.codec.http2.UniformStreamByteDistributor; import io.netty.handler.logging.LogLevel; import io.perfmark.PerfMark; import io.perfmark.Tag; import io.perfmark.TaskCloseable; import java.nio.channels.ClosedChannelException; +import java.util.LinkedHashMap; +import java.util.Map; import java.util.concurrent.Executor; import java.util.logging.Level; import java.util.logging.Logger; @@ -94,6 +101,8 @@ */ class NettyClientHandler extends AbstractNettyHandler { private static final Logger logger = Logger.getLogger(NettyClientHandler.class.getName()); + static boolean enablePerRpcAuthorityCheck = + GrpcUtil.getFlag("GRPC_ENABLE_PER_RPC_AUTHORITY_CHECK", false); /** * A message that simply passes through the channel without any real processing. It is useful to @@ -115,6 +124,7 @@ class NettyClientHandler extends AbstractNettyHandler { private final Supplier stopwatchFactory; private final TransportTracer transportTracer; private final Attributes eagAttributes; + private final TcpMetrics tcpMetrics; private final String authority; private final InUseStateAggregator inUseState = new InUseStateAggregator() { @@ -128,6 +138,13 @@ protected void handleNotInUse() { lifecycleManager.notifyInUse(false); } }; + private final Map peerVerificationResults = + new LinkedHashMap() { + @Override + protected boolean removeEldestEntry(Map.Entry eldest) { + return size() > 100; + } + }; private WriteQueue clientWriteQueue; private Http2Ping ping; @@ -142,13 +159,15 @@ static NettyClientHandler newHandler( boolean autoFlowControl, int flowControlWindow, int maxHeaderListSize, + int softLimitHeaderListSize, Supplier stopwatchFactory, Runnable tooManyPingsRunnable, TransportTracer transportTracer, Attributes eagAttributes, String authority, ChannelLogger negotiationLogger, - Ticker ticker) { + Ticker ticker, + MetricRecorder metricRecorder) { Preconditions.checkArgument(maxHeaderListSize > 0, "maxHeaderListSize must be positive"); Http2HeadersDecoder headersDecoder = new GrpcHttp2ClientHeadersDecoder(maxHeaderListSize); Http2FrameReader frameReader = new DefaultHttp2FrameReader(headersDecoder); @@ -156,8 +175,8 @@ static NettyClientHandler newHandler( Http2HeadersEncoder.NEVER_SENSITIVE, false, 16, Integer.MAX_VALUE); Http2FrameWriter frameWriter = new DefaultHttp2FrameWriter(encoder); Http2Connection connection = new DefaultHttp2Connection(false); - WeightedFairQueueByteDistributor dist = new WeightedFairQueueByteDistributor(connection); - dist.allocationQuantum(16 * 1024); // Make benchmarks fast again. + UniformStreamByteDistributor dist = new UniformStreamByteDistributor(connection); + dist.minAllocationChunk(MIN_ALLOCATED_CHUNK); // Increased for benchmarks performance. DefaultHttp2RemoteFlowController controller = new DefaultHttp2RemoteFlowController(connection, dist); connection.remote().flowController(controller); @@ -171,13 +190,15 @@ static NettyClientHandler newHandler( autoFlowControl, flowControlWindow, maxHeaderListSize, + softLimitHeaderListSize, stopwatchFactory, tooManyPingsRunnable, transportTracer, eagAttributes, authority, negotiationLogger, - ticker); + ticker, + metricRecorder); } @VisibleForTesting @@ -190,18 +211,22 @@ static NettyClientHandler newHandler( boolean autoFlowControl, int flowControlWindow, int maxHeaderListSize, + int softLimitHeaderListSize, Supplier stopwatchFactory, Runnable tooManyPingsRunnable, TransportTracer transportTracer, Attributes eagAttributes, String authority, ChannelLogger negotiationLogger, - Ticker ticker) { + Ticker ticker, + MetricRecorder metricRecorder) { Preconditions.checkNotNull(connection, "connection"); Preconditions.checkNotNull(frameReader, "frameReader"); Preconditions.checkNotNull(lifecycleManager, "lifecycleManager"); Preconditions.checkArgument(flowControlWindow > 0, "flowControlWindow must be positive"); Preconditions.checkArgument(maxHeaderListSize > 0, "maxHeaderListSize must be positive"); + Preconditions.checkArgument(softLimitHeaderListSize > 0, + "softLimitHeaderListSize must be positive"); Preconditions.checkNotNull(stopwatchFactory, "stopwatchFactory"); Preconditions.checkNotNull(tooManyPingsRunnable, "tooManyPingsRunnable"); Preconditions.checkNotNull(eagAttributes, "eagAttributes"); @@ -247,7 +272,10 @@ static NettyClientHandler newHandler( authority, autoFlowControl, pingCounter, - ticker); + ticker, + maxHeaderListSize, + softLimitHeaderListSize, + metricRecorder); } private NettyClientHandler( @@ -264,9 +292,21 @@ private NettyClientHandler( String authority, boolean autoFlowControl, PingLimiter pingLimiter, - Ticker ticker) { - super(/* channelUnused= */ null, decoder, encoder, settings, - negotiationLogger, autoFlowControl, pingLimiter, ticker); + Ticker ticker, + int maxHeaderListSize, + int softLimitHeaderListSize, + MetricRecorder metricRecorder) { + super( + /* channelUnused= */ null, + decoder, + encoder, + settings, + negotiationLogger, + autoFlowControl, + pingLimiter, + ticker, + maxHeaderListSize, + softLimitHeaderListSize); this.lifecycleManager = lifecycleManager; this.keepAliveManager = keepAliveManager; this.stopwatchFactory = stopwatchFactory; @@ -275,6 +315,7 @@ private NettyClientHandler( this.authority = authority; this.attributes = Attributes.newBuilder() .set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, eagAttributes).build(); + this.tcpMetrics = new TcpMetrics(metricRecorder); // Set the frame listener on the decoder. decoder().frameListener(new FrameListener()); @@ -380,6 +421,28 @@ private void onHeadersRead(int streamId, Http2Headers headers, boolean endStream if (streamId != Http2CodecUtil.HTTP_UPGRADE_STREAM_ID) { NettyClientStream.TransportState stream = clientStream(requireHttp2Stream(streamId)); PerfMark.event("NettyClientHandler.onHeadersRead", stream.tag()); + // check metadata size vs soft limit + int h2HeadersSize = Utils.getH2HeadersSize(headers); + boolean shouldFail = + Utils.shouldRejectOnMetadataSizeSoftLimitExceeded( + h2HeadersSize, softLimitHeaderListSize, maxHeaderListSize); + if (shouldFail && endStream) { + stream.transportReportStatus(Status.RESOURCE_EXHAUSTED + .withDescription( + String.format( + "Server Status + Trailers of size %d exceeded Metadata size soft limit: %d", + h2HeadersSize, + softLimitHeaderListSize)), true, new Metadata()); + return; + } else if (shouldFail) { + stream.transportReportStatus(Status.RESOURCE_EXHAUSTED + .withDescription( + String.format( + "Server Headers of size %d exceeded Metadata size soft limit: %d", + h2HeadersSize, + softLimitHeaderListSize)), true, new Metadata()); + return; + } stream.transportHeadersReceived(headers, endStream); } @@ -423,10 +486,12 @@ private void onRstStreamRead(int streamId, long errorCode) { @Override public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + tcpMetrics.recordTcpInfo(ctx.channel()); logger.fine("Network channel being closed by the application."); if (ctx.channel().isActive()) { // Ignore notification that the socket was closed lifecycleManager.notifyShutdown( - Status.UNAVAILABLE.withDescription("Transport closed for unknown reason")); + Status.UNAVAILABLE.withDescription("Transport closed for unknown reason"), + SimpleDisconnectError.UNKNOWN); } super.close(ctx, promise); } @@ -434,12 +499,19 @@ public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exce /** * Handler for the Channel shutting down. */ + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + tcpMetrics.channelActive(ctx.channel()); + super.channelActive(ctx); + } + @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception { try { logger.fine("Network channel is closed"); + tcpMetrics.channelInactive(ctx.channel()); Status status = Status.UNAVAILABLE.withDescription("Network closed for unknown reason"); - lifecycleManager.notifyShutdown(status); + lifecycleManager.notifyShutdown(status, SimpleDisconnectError.UNKNOWN); final Status streamStatus; if (channelInactiveReason != null) { streamStatus = channelInactiveReason; @@ -447,7 +519,7 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception { streamStatus = lifecycleManager.getShutdownStatus(); } try { - cancelPing(lifecycleManager.getShutdownThrowable()); + cancelPing(lifecycleManager.getShutdownStatus()); // Report status to the application layer for any open streams connection().forEachActiveStream(new Http2StreamVisitor() { @Override @@ -460,7 +532,7 @@ public boolean visit(Http2Stream stream) throws Http2Exception { } }); } finally { - lifecycleManager.notifyTerminated(status); + lifecycleManager.notifyTerminated(status, SimpleDisconnectError.UNKNOWN); } } finally { // Close any open streams @@ -508,7 +580,8 @@ InternalChannelz.Security getSecurityInfo() { protected void onConnectionError(ChannelHandlerContext ctx, boolean outbound, Throwable cause, Http2Exception http2Ex) { logger.log(Level.FINE, "Caught a connection error", cause); - lifecycleManager.notifyShutdown(Utils.statusFromThrowable(cause)); + lifecycleManager.notifyShutdown(Utils.statusFromThrowable(cause), + SimpleDisconnectError.SOCKET_ERROR); // Parent class will shut down the Channel super.onConnectionError(ctx, outbound, cause, http2Ex); } @@ -541,16 +614,67 @@ protected boolean isGracefulShutdownComplete() { */ private void createStream(CreateStreamCommand command, ChannelPromise promise) throws Exception { - if (lifecycleManager.getShutdownThrowable() != null) { + if (lifecycleManager.getShutdownStatus() != null) { command.stream().setNonExistent(); // The connection is going away (it is really the GOAWAY case), // just terminate the stream now. command.stream().transportReportStatus( lifecycleManager.getShutdownStatus(), RpcProgress.MISCARRIED, true, new Metadata()); - promise.setFailure(lifecycleManager.getShutdownThrowable()); + promise.setFailure(InternalStatus.asRuntimeExceptionWithoutStacktrace( + lifecycleManager.getShutdownStatus(), null)); return; } + CharSequence authorityHeader = command.headers().authority(); + if (authorityHeader == null) { + Status authorityVerificationStatus = Status.UNAVAILABLE.withDescription( + "Missing authority header"); + command.stream().setNonExistent(); + command.stream().transportReportStatus( + Status.UNAVAILABLE, RpcProgress.PROCESSED, true, new Metadata()); + promise.setFailure(InternalStatus.asRuntimeExceptionWithoutStacktrace( + authorityVerificationStatus, null)); + return; + } + // No need to verify authority for the rpc outgoing header if it is same as the authority + // for the transport + if (!authority.contentEquals(authorityHeader)) { + Status authorityVerificationStatus = peerVerificationResults.get( + authorityHeader.toString()); + if (authorityVerificationStatus == null) { + if (attributes.get(GrpcAttributes.ATTR_AUTHORITY_VERIFIER) == null) { + authorityVerificationStatus = Status.UNAVAILABLE.withDescription( + "Authority verifier not found to verify authority"); + command.stream().setNonExistent(); + command.stream().transportReportStatus( + authorityVerificationStatus, RpcProgress.PROCESSED, true, new Metadata()); + promise.setFailure(InternalStatus.asRuntimeExceptionWithoutStacktrace( + authorityVerificationStatus, null)); + return; + } + authorityVerificationStatus = attributes.get(GrpcAttributes.ATTR_AUTHORITY_VERIFIER) + .verifyAuthority(authorityHeader.toString()); + peerVerificationResults.put(authorityHeader.toString(), authorityVerificationStatus); + if (!authorityVerificationStatus.isOk() && !enablePerRpcAuthorityCheck) { + logger.log(Level.WARNING, String.format("%s.%s", + authorityVerificationStatus.getDescription(), + enablePerRpcAuthorityCheck + ? "" : " This will be an error in the future."), + InternalStatus.asRuntimeExceptionWithoutStacktrace( + authorityVerificationStatus, null)); + } + } + if (!authorityVerificationStatus.isOk()) { + if (enablePerRpcAuthorityCheck) { + command.stream().setNonExistent(); + command.stream().transportReportStatus( + authorityVerificationStatus, RpcProgress.PROCESSED, true, new Metadata()); + promise.setFailure(InternalStatus.asRuntimeExceptionWithoutStacktrace( + authorityVerificationStatus, null)); + return; + } + } + } // Get the stream ID for the new stream. int streamId; try { @@ -564,7 +688,7 @@ private void createStream(CreateStreamCommand command, ChannelPromise promise) if (!connection().goAwaySent()) { logger.fine("Stream IDs have been exhausted for this connection. " + "Initiating graceful shutdown of the connection."); - lifecycleManager.notifyShutdown(e.getStatus()); + lifecycleManager.notifyShutdown(e.getStatus(), SimpleDisconnectError.UNKNOWN); close(ctx(), ctx().newPromise()); } return; @@ -635,14 +759,19 @@ public void operationComplete(ChannelFuture future) throws Exception { // Attach the client stream to the HTTP/2 stream object as user data. stream.setHttp2Stream(http2Stream); + promise.setSuccess(); + } else { + // Otherwise, the stream has been cancelled and Netty is sending a + // RST_STREAM frame which causes it to purge pending writes from the + // flow-controller and delete the http2Stream. The stream listener has already + // been notified of cancellation so there is nothing to do. + // + // This process has been observed to fail in some circumstances, leaving listeners + // unanswered. Ensure that some exception has been delivered consistent with the + // implied RST_STREAM result above. + Status status = Status.INTERNAL.withDescription("unknown stream for connection"); + promise.setFailure(status.asRuntimeException()); } - // Otherwise, the stream has been cancelled and Netty is sending a - // RST_STREAM frame which causes it to purge pending writes from the - // flow-controller and delete the http2Stream. The stream listener has already - // been notified of cancellation so there is nothing to do. - - // Just forward on the success status to the original promise. - promise.setSuccess(); } else { Throwable cause = future.cause(); if (cause instanceof StreamBufferingEncoder.Http2GoAwayException) { @@ -665,6 +794,19 @@ public void operationComplete(ChannelFuture future) throws Exception { } } }); + // When the HEADERS are not buffered because of MAX_CONCURRENT_STREAMS in + // StreamBufferingEncoder, the stream is created immediately even if the bytes of the HEADERS + // are delayed because the OS may have too much buffered and isn't accepting the write. The + // write promise is also delayed until flush(). However, we need to associate the netty stream + // with the transport state so that goingAway() and forcefulClose() and able to notify the + // stream of failures. + // + // This leaves a hole when MAX_CONCURRENT_STREAMS is reached, as http2Stream will be null, but + // it is better than nothing. + Http2Stream http2Stream = connection().stream(streamId); + if (http2Stream != null) { + http2Stream.setProperty(streamKey, stream); + } } /** @@ -750,19 +892,21 @@ private void sendPingFrameTraced(ChannelHandlerContext ctx, SendPingCommand msg, public void operationComplete(ChannelFuture future) throws Exception { if (future.isSuccess()) { transportTracer.reportKeepAliveSent(); - } else { - Throwable cause = future.cause(); - if (cause instanceof ClosedChannelException) { - cause = lifecycleManager.getShutdownThrowable(); - if (cause == null) { - cause = Status.UNKNOWN.withDescription("Ping failed but for unknown reason.") - .withCause(future.cause()).asException(); - } - } - finalPing.failed(cause); - if (ping == finalPing) { - ping = null; + return; + } + Throwable cause = future.cause(); + Status status = lifecycleManager.getShutdownStatus(); + if (cause instanceof ClosedChannelException) { + if (status == null) { + status = Status.UNKNOWN.withDescription("Ping failed but for unknown reason.") + .withCause(future.cause()); } + } else { + status = Utils.statusFromThrowable(cause); + } + finalPing.failed(status); + if (ping == finalPing) { + ping = null; } } }); @@ -770,7 +914,7 @@ public void operationComplete(ChannelFuture future) throws Exception { private void gracefulClose(ChannelHandlerContext ctx, GracefulCloseCommand msg, ChannelPromise promise) throws Exception { - lifecycleManager.notifyShutdown(msg.getStatus()); + lifecycleManager.notifyShutdown(msg.getStatus(), SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); // Explicitly flush to create any buffered streams before sending GOAWAY. // TODO(ejona): determine if the need to flush is a bug in Netty flush(ctx); @@ -806,13 +950,15 @@ public boolean visit(Http2Stream stream) throws Http2Exception { private void goingAway(long errorCode, byte[] debugData) { Status finalStatus = statusFromH2Error( Status.Code.UNAVAILABLE, "GOAWAY shut down transport", errorCode, debugData); - lifecycleManager.notifyGracefulShutdown(finalStatus); + DisconnectError disconnectError = new GoAwayDisconnectError( + GrpcUtil.Http2Error.forCode(errorCode)); + lifecycleManager.notifyGracefulShutdown(finalStatus, disconnectError); abruptGoAwayStatus = statusFromH2Error( Status.Code.UNAVAILABLE, "Abrupt GOAWAY closed unsent stream", errorCode, debugData); // While this _should_ be UNAVAILABLE, Netty uses the wrong stream id in the GOAWAY when it // fails streams due to HPACK failures (e.g., header list too large). To be more conservative, // we assume any sent streams may be related to the GOAWAY. This should rarely impact users - // since the main time servers should use abrupt GOAWAYs is if there is a protocol error, and if + // since the main time servers should use abrupt GOAWAYs if there is a protocol error, and if // there wasn't a protocol error the error code was probably NO_ERROR which is mapped to // UNAVAILABLE. https://github.com/netty/netty/issues/10670 final Status abruptGoAwayStatusConservative = statusFromH2Error( @@ -827,7 +973,7 @@ private void goingAway(long errorCode, byte[] debugData) { // This can cause reentrancy, but should be minor since it is normal to handle writes in // response to a read. Also, the call stack is rather shallow at this point clientWriteQueue.drainNow(); - if (lifecycleManager.notifyShutdown(finalStatus)) { + if (lifecycleManager.notifyShutdown(finalStatus, disconnectError)) { // This is for the only RPCs that are actually covered by the GOAWAY error code. All other // RPCs were not observed by the remote and so should be UNAVAILABLE. channelInactiveReason = statusFromH2Error( @@ -861,9 +1007,9 @@ public boolean visit(Http2Stream stream) throws Http2Exception { } } - private void cancelPing(Throwable t) { + private void cancelPing(Status s) { if (ping != null) { - ping.failed(t); + ping.failed(s); ping = null; } } diff --git a/netty/src/main/java/io/grpc/netty/NettyClientStream.java b/netty/src/main/java/io/grpc/netty/NettyClientStream.java index 0c0bb7eeb8d..2939eed2e37 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientStream.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientStream.java @@ -182,20 +182,10 @@ private void writeFrameInternal( if (numBytes > 0) { // Add the bytes to outbound flow control. onSendingBytes(numBytes); + ChannelFutureListener failureListener = + future -> transportState().onWriteFrameData(future, numMessages, numBytes); writeQueue.enqueue(new SendGrpcFrameCommand(transportState(), bytebuf, endOfStream), flush) - .addListener(new ChannelFutureListener() { - @Override - public void operationComplete(ChannelFuture future) throws Exception { - // If the future succeeds when http2stream is null, the stream has been cancelled - // before it began and Netty is purging pending writes from the flow-controller. - if (future.isSuccess() && transportState().http2Stream() != null) { - // Remove the bytes from outbound flow control, optionally notifying - // the client that they can send more bytes. - transportState().onSentBytes(numBytes); - NettyClientStream.this.getTransportTracer().reportMessageSent(numMessages); - } - } - }); + .addListener(failureListener); } else { // The frame is empty and will not impact outbound flow control. Just send it. writeQueue.enqueue( @@ -237,8 +227,9 @@ protected TransportState( int maxMessageSize, StatsTraceContext statsTraceCtx, TransportTracer transportTracer, - String methodName) { - super(maxMessageSize, statsTraceCtx, transportTracer); + String methodName, + CallOptions options) { + super(maxMessageSize, statsTraceCtx, transportTracer, options); this.methodName = checkNotNull(methodName, "methodName"); this.handler = checkNotNull(handler, "handler"); this.eventLoop = checkNotNull(eventLoop, "eventLoop"); @@ -306,6 +297,29 @@ protected void http2ProcessingFailed(Status status, boolean stopDelivery, Metada handler.getWriteQueue().enqueue(new CancelClientStreamCommand(this, status), true); } + private void onWriteFrameData(ChannelFuture future, int numMessages, int numBytes) { + // If the future succeeds when http2stream is null, the stream has been cancelled + // before it began and Netty is purging pending writes from the flow-controller. + if (future.isSuccess() && http2Stream() == null) { + return; + } + + if (future.isSuccess()) { + // Remove the bytes from outbound flow control, optionally notifying + // the client that they can send more bytes. + onSentBytes(numBytes); + getTransportTracer().reportMessageSent(numMessages); + } else if (!isStreamDeallocated()) { + // Future failed, fail RPC. + // Normally we don't need to do anything here because the cause of a failed future + // while writing DATA frames would be an IO error and the stream is already closed. + // However, we still need handle any unexpected failures raised in Netty. + // Note: isStreamDeallocated() protects from spamming stream resets by scheduling multiple + // CancelClientStreamCommand commands. + http2ProcessingFailed(statusFromFailedFuture(future), true, new Metadata()); + } + } + @Override public void runOnTransportThread(final Runnable r) { if (eventLoop.inEventLoop()) { diff --git a/netty/src/main/java/io/grpc/netty/NettyClientTransport.java b/netty/src/main/java/io/grpc/netty/NettyClientTransport.java index 689dd847d5e..6585df42df3 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientTransport.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientTransport.java @@ -34,14 +34,17 @@ import io.grpc.InternalLogId; import io.grpc.Metadata; import io.grpc.MethodDescriptor; +import io.grpc.MetricRecorder; import io.grpc.Status; import io.grpc.internal.ClientStream; import io.grpc.internal.ConnectionClientTransport; +import io.grpc.internal.DisconnectError; import io.grpc.internal.FailingClientStream; import io.grpc.internal.GrpcUtil; import io.grpc.internal.Http2Ping; import io.grpc.internal.KeepAliveManager; import io.grpc.internal.KeepAliveManager.ClientKeepAlivePinger; +import io.grpc.internal.SimpleDisconnectError; import io.grpc.internal.StatsTraceContext; import io.grpc.internal.TransportTracer; import io.grpc.netty.NettyChannelBuilder.LocalSocketPicker; @@ -68,7 +71,8 @@ /** * A Netty-based {@link ConnectionClientTransport} implementation. */ -class NettyClientTransport implements ConnectionClientTransport { +class NettyClientTransport implements ConnectionClientTransport, + ClientKeepAlivePinger.TransportWithDisconnectReason { private final InternalLogId logId; private final Map, ?> channelOptions; @@ -83,6 +87,7 @@ class NettyClientTransport implements ConnectionClientTransport { private final int flowControlWindow; private final int maxMessageSize; private final int maxHeaderListSize; + private final int softLimitHeaderListSize; private KeepAliveManager keepAliveManager; private final long keepAliveTimeNanos; private final long keepAliveTimeoutNanos; @@ -104,17 +109,33 @@ class NettyClientTransport implements ConnectionClientTransport { private final ChannelLogger channelLogger; private final boolean useGetForSafeMethods; private final Ticker ticker; + private final MetricRecorder metricRecorder; + NettyClientTransport( - SocketAddress address, ChannelFactory channelFactory, - Map, ?> channelOptions, EventLoopGroup group, - ProtocolNegotiator negotiator, boolean autoFlowControl, int flowControlWindow, - int maxMessageSize, int maxHeaderListSize, - long keepAliveTimeNanos, long keepAliveTimeoutNanos, - boolean keepAliveWithoutCalls, String authority, @Nullable String userAgent, - Runnable tooManyPingsRunnable, TransportTracer transportTracer, Attributes eagAttributes, - LocalSocketPicker localSocketPicker, ChannelLogger channelLogger, - boolean useGetForSafeMethods, Ticker ticker) { + SocketAddress address, + ChannelFactory channelFactory, + Map, ?> channelOptions, + EventLoopGroup group, + ProtocolNegotiator negotiator, + boolean autoFlowControl, + int flowControlWindow, + int maxMessageSize, + int maxHeaderListSize, + int softLimitHeaderListSize, + long keepAliveTimeNanos, + long keepAliveTimeoutNanos, + boolean keepAliveWithoutCalls, + String authority, + @Nullable String userAgent, + Runnable tooManyPingsRunnable, + TransportTracer transportTracer, + Attributes eagAttributes, + LocalSocketPicker localSocketPicker, + ChannelLogger channelLogger, + boolean useGetForSafeMethods, + MetricRecorder metricRecorder, + Ticker ticker) { this.negotiator = Preconditions.checkNotNull(negotiator, "negotiator"); this.negotiationScheme = this.negotiator.scheme(); @@ -126,6 +147,7 @@ class NettyClientTransport implements ConnectionClientTransport { this.flowControlWindow = flowControlWindow; this.maxMessageSize = maxMessageSize; this.maxHeaderListSize = maxHeaderListSize; + this.softLimitHeaderListSize = softLimitHeaderListSize; this.keepAliveTimeNanos = keepAliveTimeNanos; this.keepAliveTimeoutNanos = keepAliveTimeoutNanos; this.keepAliveWithoutCalls = keepAliveWithoutCalls; @@ -140,6 +162,7 @@ class NettyClientTransport implements ConnectionClientTransport { this.logId = InternalLogId.allocate(getClass(), remoteAddress.toString()); this.channelLogger = Preconditions.checkNotNull(channelLogger, "channelLogger"); this.useGetForSafeMethods = useGetForSafeMethods; + this.metricRecorder = metricRecorder; this.ticker = Preconditions.checkNotNull(ticker, "ticker"); } @@ -149,7 +172,7 @@ public void ping(final PingCallback callback, final Executor executor) { executor.execute(new Runnable() { @Override public void run() { - callback.onFailure(statusExplainingWhyTheChannelIsNull.asException()); + callback.onFailure(statusExplainingWhyTheChannelIsNull); } }); return; @@ -161,7 +184,7 @@ public void run() { public void operationComplete(ChannelFuture future) throws Exception { if (!future.isSuccess()) { Status s = statusFromFailedFuture(future); - Http2Ping.notifyFailed(callback, executor, s.asException()); + Http2Ping.notifyFailed(callback, executor, s); } } }; @@ -188,7 +211,8 @@ public ClientStream newStream( maxMessageSize, statsTraceCtx, transportTracer, - method.getFullMethodName()) { + method.getFullMethodName(), + callOptions) { @Override protected Status statusFromFailedFuture(ChannelFuture f) { return NettyClientTransport.this.statusFromFailedFuture(f); @@ -214,23 +238,25 @@ public Runnable start(Listener transportListener) { EventLoop eventLoop = group.next(); if (keepAliveTimeNanos != KEEPALIVE_TIME_NANOS_DISABLED) { keepAliveManager = new KeepAliveManager( - new ClientKeepAlivePinger(this), eventLoop, keepAliveTimeNanos, keepAliveTimeoutNanos, - keepAliveWithoutCalls); + new ClientKeepAlivePinger(this), eventLoop, keepAliveTimeNanos, + keepAliveTimeoutNanos, keepAliveWithoutCalls); } handler = NettyClientHandler.newHandler( - lifecycleManager, - keepAliveManager, - autoFlowControl, - flowControlWindow, - maxHeaderListSize, - GrpcUtil.STOPWATCH_SUPPLIER, - tooManyPingsRunnable, - transportTracer, - eagAttributes, - authorityString, - channelLogger, - ticker); + lifecycleManager, + keepAliveManager, + autoFlowControl, + flowControlWindow, + maxHeaderListSize, + softLimitHeaderListSize, + GrpcUtil.STOPWATCH_SUPPLIER, + tooManyPingsRunnable, + transportTracer, + eagAttributes, + authorityString, + channelLogger, + ticker, + metricRecorder); ChannelHandler negotiationHandler = negotiator.newHandler(handler); @@ -240,13 +266,6 @@ public Runnable start(Listener transportListener) { b.channelFactory(channelFactory); // For non-socket based channel, the option will be ignored. b.option(SO_KEEPALIVE, true); - // For non-epoll based channel, the option will be ignored. - if (keepAliveTimeNanos != KEEPALIVE_TIME_NANOS_DISABLED) { - ChannelOption tcpUserTimeout = Utils.maybeGetTcpUserTimeoutOption(); - if (tcpUserTimeout != null) { - b.option(tcpUserTimeout, (int) TimeUnit.NANOSECONDS.toMillis(keepAliveTimeoutNanos)); - } - } for (Map.Entry, ?> entry : channelOptions.entrySet()) { // Every entry in the map is obtained from // NettyChannelBuilder#withOption(ChannelOption option, T value) @@ -280,11 +299,26 @@ public void run() { // could use GlobalEventExecutor (which is what regFuture would use for notifying // listeners in this case), but avoiding on-demand thread creation in an error case seems // a good idea and is probably clearer threading. - lifecycleManager.notifyTerminated(statusExplainingWhyTheChannelIsNull); + lifecycleManager.notifyTerminated(statusExplainingWhyTheChannelIsNull, + SimpleDisconnectError.UNKNOWN); } }; } channel = regFuture.channel(); + // For non-epoll based channel, the option will be ignored. + try { + if (keepAliveTimeNanos != KEEPALIVE_TIME_NANOS_DISABLED + && Class.forName("io.netty.channel.epoll.AbstractEpollChannel").isInstance(channel)) { + ChannelOption tcpUserTimeout = Utils.maybeGetTcpUserTimeoutOption(); + if (tcpUserTimeout != null) { + int tcpUserTimeoutMs = (int) TimeUnit.NANOSECONDS.toMillis(keepAliveTimeoutNanos); + channel.config().setOption(tcpUserTimeout, tcpUserTimeoutMs); + } + } + } catch (ClassNotFoundException ignored) { + // JVM did not load AbstractEpollChannel, so the current channel will not be of epoll type, + // so there is no need to set TCP_USER_TIMEOUT + } // Start the write queue as soon as the channel is constructed handler.startWriteQueue(channel); // This write will have no effect, yet it will only complete once the negotiationHandler @@ -298,7 +332,8 @@ public void operationComplete(ChannelFuture future) throws Exception { if (!future.isSuccess()) { // Need to notify of this failure, because NettyClientHandler may not have been added to // the pipeline before the error occurred. - lifecycleManager.notifyTerminated(Utils.statusFromThrowable(future.cause())); + lifecycleManager.notifyTerminated(Utils.statusFromThrowable(future.cause()), + SimpleDisconnectError.UNKNOWN); } } }); @@ -332,12 +367,17 @@ public void shutdown(Status reason) { @Override public void shutdownNow(final Status reason) { + shutdownNow(reason, SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); + } + + @Override + public void shutdownNow(final Status reason, DisconnectError disconnectError) { // Notifying of termination is automatically done when the channel closes. if (channel != null && channel.isOpen()) { handler.getWriteQueue().enqueue(new Runnable() { @Override public void run() { - lifecycleManager.notifyShutdown(reason); + lifecycleManager.notifyShutdown(reason, disconnectError); channel.write(new ForcefulCloseCommand(reason)); } }, true); diff --git a/netty/src/main/java/io/grpc/netty/NettyReadableBuffer.java b/netty/src/main/java/io/grpc/netty/NettyReadableBuffer.java index 7e180544de4..af5ec8d8bad 100644 --- a/netty/src/main/java/io/grpc/netty/NettyReadableBuffer.java +++ b/netty/src/main/java/io/grpc/netty/NettyReadableBuffer.java @@ -60,11 +60,6 @@ public void readBytes(byte[] dest, int index, int length) { buffer.readBytes(dest, index, length); } - @Override - public void readBytes(ByteBuffer dest) { - buffer.readBytes(dest); - } - @Override public void readBytes(OutputStream dest, int length) { try { diff --git a/netty/src/main/java/io/grpc/netty/NettyServer.java b/netty/src/main/java/io/grpc/netty/NettyServer.java index 2960604e5b5..2bb6b2c5921 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServer.java +++ b/netty/src/main/java/io/grpc/netty/NettyServer.java @@ -31,6 +31,7 @@ import io.grpc.InternalInstrumented; import io.grpc.InternalLogId; import io.grpc.InternalWithLogId; +import io.grpc.MetricRecorder; import io.grpc.ServerStreamTracer; import io.grpc.internal.InternalServer; import io.grpc.internal.ObjectPool; @@ -92,6 +93,8 @@ class NettyServer implements InternalServer, InternalWithLogId { private final int flowControlWindow; private final int maxMessageSize; private final int maxHeaderListSize; + private final int softLimitHeaderListSize; + private MetricRecorder metricRecorder; private final long keepAliveTimeInNanos; private final long keepAliveTimeoutInNanos; private final long maxConnectionIdleInNanos; @@ -123,15 +126,22 @@ class NettyServer implements InternalServer, InternalWithLogId { ProtocolNegotiator protocolNegotiator, List streamTracerFactories, TransportTracer.Factory transportTracerFactory, - int maxStreamsPerConnection, boolean autoFlowControl, int flowControlWindow, - int maxMessageSize, int maxHeaderListSize, - long keepAliveTimeInNanos, long keepAliveTimeoutInNanos, + int maxStreamsPerConnection, + boolean autoFlowControl, + int flowControlWindow, + int maxMessageSize, + int maxHeaderListSize, + int softLimitHeaderListSize, + long keepAliveTimeInNanos, + long keepAliveTimeoutInNanos, long maxConnectionIdleInNanos, long maxConnectionAgeInNanos, long maxConnectionAgeGraceInNanos, boolean permitKeepAliveWithoutCalls, long permitKeepAliveTimeInNanos, int maxRstCount, long maxRstPeriodNanos, - Attributes eagAttributes, InternalChannelz channelz) { + Attributes eagAttributes, InternalChannelz channelz, + MetricRecorder metricRecorder) { this.addresses = checkNotNull(addresses, "addresses"); + this.metricRecorder = metricRecorder; this.channelFactory = checkNotNull(channelFactory, "channelFactory"); checkNotNull(channelOptions, "channelOptions"); this.channelOptions = new HashMap, Object>(channelOptions); @@ -152,6 +162,7 @@ class NettyServer implements InternalServer, InternalWithLogId { this.flowControlWindow = flowControlWindow; this.maxMessageSize = maxMessageSize; this.maxHeaderListSize = maxHeaderListSize; + this.softLimitHeaderListSize = softLimitHeaderListSize; this.keepAliveTimeInNanos = keepAliveTimeInNanos; this.keepAliveTimeoutInNanos = keepAliveTimeoutInNanos; this.maxConnectionIdleInNanos = maxConnectionIdleInNanos; @@ -167,6 +178,7 @@ class NettyServer implements InternalServer, InternalWithLogId { String.valueOf(addresses)); } + @Override public SocketAddress getListenSocketAddress() { Iterator it = channelGroup.iterator(); @@ -243,28 +255,30 @@ public void initChannel(Channel ch) { (long) ((.9D + Math.random() * .2D) * maxConnectionAgeInNanos); } - NettyServerTransport transport = - new NettyServerTransport( - ch, - channelDone, - protocolNegotiator, - streamTracerFactories, - transportTracerFactory.create(), - maxStreamsPerConnection, - autoFlowControl, - flowControlWindow, - maxMessageSize, - maxHeaderListSize, - keepAliveTimeInNanos, - keepAliveTimeoutInNanos, - maxConnectionIdleInNanos, - maxConnectionAgeInNanos, - maxConnectionAgeGraceInNanos, - permitKeepAliveWithoutCalls, - permitKeepAliveTimeInNanos, - maxRstCount, - maxRstPeriodNanos, - eagAttributes); + NettyServerTransport transport = + new NettyServerTransport( + ch, + channelDone, + protocolNegotiator, + streamTracerFactories, + transportTracerFactory.create(), + maxStreamsPerConnection, + autoFlowControl, + flowControlWindow, + maxMessageSize, + maxHeaderListSize, + softLimitHeaderListSize, + keepAliveTimeInNanos, + keepAliveTimeoutInNanos, + maxConnectionIdleInNanos, + maxConnectionAgeInNanos, + maxConnectionAgeGraceInNanos, + permitKeepAliveWithoutCalls, + permitKeepAliveTimeInNanos, + maxRstCount, + maxRstPeriodNanos, + eagAttributes, + metricRecorder); ServerTransportListener transportListener; // This is to order callbacks on the listener, not to guard access to channel. synchronized (NettyServer.this) { diff --git a/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java b/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java index 3b82b193f61..4ef14b0e933 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerBuilder.java @@ -22,6 +22,7 @@ import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; import static io.grpc.internal.GrpcUtil.DEFAULT_SERVER_KEEPALIVE_TIMEOUT_NANOS; import static io.grpc.internal.GrpcUtil.DEFAULT_SERVER_KEEPALIVE_TIME_NANOS; +import static io.grpc.internal.GrpcUtil.DEFAULT_SERVER_PERMIT_KEEPALIVE_TIME_NANOS; import static io.grpc.internal.GrpcUtil.SERVER_KEEPALIVE_TIME_NANOS_DISABLED; import com.google.common.annotations.VisibleForTesting; @@ -32,6 +33,7 @@ import io.grpc.ExperimentalApi; import io.grpc.ForwardingServerBuilder; import io.grpc.Internal; +import io.grpc.MetricRecorder; import io.grpc.ServerBuilder; import io.grpc.ServerCredentials; import io.grpc.ServerStreamTracer; @@ -105,13 +107,14 @@ public final class NettyServerBuilder extends ForwardingServerBuilder streamTracerFactories) { - return buildTransportServers(streamTracerFactories); + List streamTracerFactories, + MetricRecorder metricRecorder) { + return buildTransportServers(streamTracerFactories, metricRecorder); } } @@ -492,6 +496,39 @@ public NettyServerBuilder maxHeaderListSize(int maxHeaderListSize) { public NettyServerBuilder maxInboundMetadataSize(int bytes) { checkArgument(bytes > 0, "maxInboundMetadataSize must be positive: %s", bytes); this.maxHeaderListSize = bytes; + // Clear the soft limit setting, by setting soft limit to maxInboundMetadataSize. The + // maxInboundMetadataSize will take precedence over soft limit check. + this.softLimitHeaderListSize = bytes; + return this; + } + + /** + * Sets the size of metadata that clients are advised to not exceed. When a metadata with size + * larger than the soft limit is encountered there will be a probability the RPC will fail. The + * chance of failing increases as the metadata size approaches the hard limit. + * {@code Integer.MAX_VALUE} disables the enforcement. The default is implementation-dependent, + * but is not generally less than 8 KiB and may be unlimited. + * + *

This is cumulative size of the metadata. The precise calculation is + * implementation-dependent, but implementations are encouraged to follow the calculation used + * for + * HTTP/2's + * SETTINGS_MAX_HEADER_LIST_SIZE. It sums the bytes from each entry's key and value, plus 32 + * bytes of overhead per entry. + * + * @param soft the soft size limit of received metadata + * @param max the hard size limit of received metadata + * @return this + * @throws IllegalArgumentException if soft and/or max is non-positive, or max smaller than soft + * @since 1.68.0 + */ + @CanIgnoreReturnValue + public NettyServerBuilder maxInboundMetadataSize(int soft, int max) { + checkArgument(soft > 0, "softLimitHeaderListSize must be positive: %s", soft); + checkArgument(max > soft, + "maxInboundMetadataSize: %s must be greater than softLimitHeaderListSize: %s", max, soft); + this.softLimitHeaderListSize = soft; + this.maxHeaderListSize = max; return this; } @@ -669,22 +706,44 @@ void eagAttributes(Attributes eagAttributes) { this.eagAttributes = checkNotNull(eagAttributes, "eagAttributes"); } + @VisibleForTesting NettyServer buildTransportServers( - List streamTracerFactories) { + List streamTracerFactories, + MetricRecorder metricRecorder) { assertEventLoopsAndChannelType(); ProtocolNegotiator negotiator = protocolNegotiatorFactory.newNegotiator( this.serverImplBuilder.getExecutorPool()); return new NettyServer( - listenAddresses, channelFactory, channelOptions, childChannelOptions, - bossEventLoopGroupPool, workerEventLoopGroupPool, forceHeapBuffer, negotiator, - streamTracerFactories, transportTracerFactory, maxConcurrentCallsPerConnection, - autoFlowControl, flowControlWindow, maxMessageSize, maxHeaderListSize, - keepAliveTimeInNanos, keepAliveTimeoutInNanos, - maxConnectionIdleInNanos, maxConnectionAgeInNanos, - maxConnectionAgeGraceInNanos, permitKeepAliveWithoutCalls, permitKeepAliveTimeInNanos, - maxRstCount, maxRstPeriodNanos, eagAttributes, this.serverImplBuilder.getChannelz()); + listenAddresses, + channelFactory, + channelOptions, + childChannelOptions, + bossEventLoopGroupPool, + workerEventLoopGroupPool, + forceHeapBuffer, + negotiator, + streamTracerFactories, + transportTracerFactory, + maxConcurrentCallsPerConnection, + autoFlowControl, + flowControlWindow, + maxMessageSize, + maxHeaderListSize, + softLimitHeaderListSize, + keepAliveTimeInNanos, + keepAliveTimeoutInNanos, + maxConnectionIdleInNanos, + maxConnectionAgeInNanos, + maxConnectionAgeGraceInNanos, + permitKeepAliveWithoutCalls, + permitKeepAliveTimeInNanos, + maxRstCount, + maxRstPeriodNanos, + eagAttributes, + this.serverImplBuilder.getChannelz(), + metricRecorder); } @VisibleForTesting diff --git a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java index 2b06a3fcf55..79715ca2996 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java @@ -42,6 +42,7 @@ import io.grpc.InternalMetadata; import io.grpc.InternalStatus; import io.grpc.Metadata; +import io.grpc.MetricRecorder; import io.grpc.ServerStreamTracer; import io.grpc.Status; import io.grpc.internal.GrpcUtil; @@ -60,6 +61,8 @@ import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http2.DecoratingHttp2ConnectionEncoder; import io.netty.handler.codec.http2.DecoratingHttp2FrameWriter; import io.netty.handler.codec.http2.DefaultHttp2Connection; import io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder; @@ -67,8 +70,10 @@ import io.netty.handler.codec.http2.DefaultHttp2FrameReader; import io.netty.handler.codec.http2.DefaultHttp2FrameWriter; import io.netty.handler.codec.http2.DefaultHttp2Headers; +import io.netty.handler.codec.http2.DefaultHttp2HeadersEncoder; import io.netty.handler.codec.http2.DefaultHttp2LocalFlowController; import io.netty.handler.codec.http2.DefaultHttp2RemoteFlowController; +import io.netty.handler.codec.http2.EmptyHttp2Headers; import io.netty.handler.codec.http2.Http2Connection; import io.netty.handler.codec.http2.Http2ConnectionAdapter; import io.netty.handler.codec.http2.Http2ConnectionDecoder; @@ -82,12 +87,14 @@ import io.netty.handler.codec.http2.Http2FrameWriter; import io.netty.handler.codec.http2.Http2Headers; import io.netty.handler.codec.http2.Http2HeadersDecoder; +import io.netty.handler.codec.http2.Http2HeadersEncoder; import io.netty.handler.codec.http2.Http2InboundFrameLogger; +import io.netty.handler.codec.http2.Http2LifecycleManager; import io.netty.handler.codec.http2.Http2OutboundFrameLogger; import io.netty.handler.codec.http2.Http2Settings; import io.netty.handler.codec.http2.Http2Stream; import io.netty.handler.codec.http2.Http2StreamVisitor; -import io.netty.handler.codec.http2.WeightedFairQueueByteDistributor; +import io.netty.handler.codec.http2.UniformStreamByteDistributor; import io.netty.handler.logging.LogLevel; import io.netty.util.AsciiString; import io.netty.util.ReferenceCountUtil; @@ -121,17 +128,16 @@ class NettyServerHandler extends AbstractNettyHandler { private final Http2Connection.PropertyKey streamKey; private final ServerTransportListener transportListener; private final int maxMessageSize; + private final TcpMetrics tcpMetrics; private final long keepAliveTimeInNanos; private final long keepAliveTimeoutInNanos; private final long maxConnectionAgeInNanos; private final long maxConnectionAgeGraceInNanos; - private final int maxRstCount; - private final long maxRstPeriodNanos; + private final RstStreamCounter rstStreamCounter; private final List streamTracerFactories; private final TransportTracer transportTracer; private final KeepAliveEnforcer keepAliveEnforcer; private final Attributes eagAttributes; - private final Ticker ticker; /** Incomplete attributes produced by negotiator. */ private Attributes negotiationAttributes; private InternalChannelz.Security securityInfo; @@ -149,9 +155,6 @@ class NettyServerHandler extends AbstractNettyHandler { private ScheduledFuture maxConnectionAgeMonitor; @CheckForNull private GracefulShutdown gracefulShutdown; - private int rstCount; - private long lastRstNanoTime; - static NettyServerHandler newHandler( ServerTransportListener transportListener, @@ -162,6 +165,7 @@ static NettyServerHandler newHandler( boolean autoFlowControl, int flowControlWindow, int maxHeaderListSize, + int softLimitHeaderListSize, int maxMessageSize, long keepAliveTimeInNanos, long keepAliveTimeoutInNanos, @@ -172,15 +176,18 @@ static NettyServerHandler newHandler( long permitKeepAliveTimeInNanos, int maxRstCount, long maxRstPeriodNanos, - Attributes eagAttributes) { + Attributes eagAttributes, + MetricRecorder metricRecorder) { Preconditions.checkArgument(maxHeaderListSize > 0, "maxHeaderListSize must be positive: %s", maxHeaderListSize); Http2FrameLogger frameLogger = new Http2FrameLogger(LogLevel.DEBUG, NettyServerHandler.class); Http2HeadersDecoder headersDecoder = new GrpcHttp2ServerHeadersDecoder(maxHeaderListSize); Http2FrameReader frameReader = new Http2InboundFrameLogger( new DefaultHttp2FrameReader(headersDecoder), frameLogger); + Http2HeadersEncoder encoder = new DefaultHttp2HeadersEncoder( + Http2HeadersEncoder.NEVER_SENSITIVE, false, 16, Integer.MAX_VALUE); Http2FrameWriter frameWriter = - new Http2OutboundFrameLogger(new DefaultHttp2FrameWriter(), frameLogger); + new Http2OutboundFrameLogger(new DefaultHttp2FrameWriter(encoder), frameLogger); return newHandler( channelUnused, frameReader, @@ -192,6 +199,7 @@ static NettyServerHandler newHandler( autoFlowControl, flowControlWindow, maxHeaderListSize, + softLimitHeaderListSize, maxMessageSize, keepAliveTimeInNanos, keepAliveTimeoutInNanos, @@ -203,7 +211,8 @@ static NettyServerHandler newHandler( maxRstCount, maxRstPeriodNanos, eagAttributes, - Ticker.systemTicker()); + Ticker.systemTicker(), + metricRecorder); } static NettyServerHandler newHandler( @@ -217,6 +226,7 @@ static NettyServerHandler newHandler( boolean autoFlowControl, int flowControlWindow, int maxHeaderListSize, + int softLimitHeaderListSize, int maxMessageSize, long keepAliveTimeInNanos, long keepAliveTimeoutInNanos, @@ -228,24 +238,34 @@ static NettyServerHandler newHandler( int maxRstCount, long maxRstPeriodNanos, Attributes eagAttributes, - Ticker ticker) { + Ticker ticker, + MetricRecorder metricRecorder) { Preconditions.checkArgument(maxStreams > 0, "maxStreams must be positive: %s", maxStreams); Preconditions.checkArgument(flowControlWindow > 0, "flowControlWindow must be positive: %s", flowControlWindow); Preconditions.checkArgument(maxHeaderListSize > 0, "maxHeaderListSize must be positive: %s", maxHeaderListSize); + Preconditions.checkArgument( + softLimitHeaderListSize > 0, "softLimitHeaderListSize must be positive: %s", + softLimitHeaderListSize); Preconditions.checkArgument(maxMessageSize > 0, "maxMessageSize must be positive: %s", maxMessageSize); final Http2Connection connection = new DefaultHttp2Connection(true); - WeightedFairQueueByteDistributor dist = new WeightedFairQueueByteDistributor(connection); - dist.allocationQuantum(16 * 1024); // Make benchmarks fast again. + UniformStreamByteDistributor dist = new UniformStreamByteDistributor(connection); + dist.minAllocationChunk(MIN_ALLOCATED_CHUNK); // Increased for benchmarks performance. DefaultHttp2RemoteFlowController controller = new DefaultHttp2RemoteFlowController(connection, dist); connection.remote().flowController(controller); final KeepAliveEnforcer keepAliveEnforcer = new KeepAliveEnforcer( permitKeepAliveWithoutCalls, permitKeepAliveTimeInNanos, TimeUnit.NANOSECONDS); + if (ticker == null) { + ticker = Ticker.systemTicker(); + } + + RstStreamCounter rstStreamCounter + = new RstStreamCounter(maxRstCount, maxRstPeriodNanos, ticker); // Create the local flow controller configured to auto-refill the connection window. connection.local().flowController( new DefaultHttp2LocalFlowController(connection, DEFAULT_WINDOW_UPDATE_RATIO, true)); @@ -253,6 +273,7 @@ static NettyServerHandler newHandler( Http2ConnectionEncoder encoder = new DefaultHttp2ConnectionEncoder(connection, frameWriter); encoder = new Http2ControlFrameLimitEncoder(encoder, 10000); + encoder = new Http2RstCounterEncoder(encoder, rstStreamCounter); Http2ConnectionDecoder decoder = new DefaultHttp2ConnectionDecoder(connection, encoder, frameReader); @@ -261,10 +282,6 @@ static NettyServerHandler newHandler( settings.maxConcurrentStreams(maxStreams); settings.maxHeaderListSize(maxHeaderListSize); - if (ticker == null) { - ticker = Ticker.systemTicker(); - } - return new NettyServerHandler( channelUnused, connection, @@ -273,14 +290,17 @@ static NettyServerHandler newHandler( transportTracer, decoder, encoder, settings, maxMessageSize, - keepAliveTimeInNanos, keepAliveTimeoutInNanos, + maxHeaderListSize, + softLimitHeaderListSize, + keepAliveTimeInNanos, + keepAliveTimeoutInNanos, maxConnectionIdleInNanos, maxConnectionAgeInNanos, maxConnectionAgeGraceInNanos, keepAliveEnforcer, autoFlowControl, - maxRstCount, - maxRstPeriodNanos, - eagAttributes, ticker); + rstStreamCounter, + eagAttributes, ticker, + metricRecorder); } private NettyServerHandler( @@ -293,6 +313,8 @@ private NettyServerHandler( Http2ConnectionEncoder encoder, Http2Settings settings, int maxMessageSize, + int maxHeaderListSize, + int softLimitHeaderListSize, long keepAliveTimeInNanos, long keepAliveTimeoutInNanos, long maxConnectionIdleInNanos, @@ -300,12 +322,21 @@ private NettyServerHandler( long maxConnectionAgeGraceInNanos, final KeepAliveEnforcer keepAliveEnforcer, boolean autoFlowControl, - int maxRstCount, - long maxRstPeriodNanos, + RstStreamCounter rstStreamCounter, Attributes eagAttributes, - Ticker ticker) { - super(channelUnused, decoder, encoder, settings, new ServerChannelLogger(), - autoFlowControl, null, ticker); + Ticker ticker, + MetricRecorder metricRecorder) { + super( + channelUnused, + decoder, + encoder, + settings, + new ServerChannelLogger(), + autoFlowControl, + null, + ticker, + maxHeaderListSize, + softLimitHeaderListSize); final MaxConnectionIdleManager maxConnectionIdleManager; if (maxConnectionIdleInNanos == MAX_CONNECTION_IDLE_NANOS_DISABLED) { @@ -338,18 +369,16 @@ public void onStreamClosed(Http2Stream stream) { checkArgument(maxMessageSize >= 0, "maxMessageSize must be non-negative: %s", maxMessageSize); this.maxMessageSize = maxMessageSize; + this.tcpMetrics = new TcpMetrics(metricRecorder); this.keepAliveTimeInNanos = keepAliveTimeInNanos; this.keepAliveTimeoutInNanos = keepAliveTimeoutInNanos; this.maxConnectionIdleManager = maxConnectionIdleManager; this.maxConnectionAgeInNanos = maxConnectionAgeInNanos; this.maxConnectionAgeGraceInNanos = maxConnectionAgeGraceInNanos; this.keepAliveEnforcer = checkNotNull(keepAliveEnforcer, "keepAliveEnforcer"); - this.maxRstCount = maxRstCount; - this.maxRstPeriodNanos = maxRstPeriodNanos; + this.rstStreamCounter = rstStreamCounter; this.eagAttributes = checkNotNull(eagAttributes, "eagAttributes"); - this.ticker = checkNotNull(ticker, "ticker"); - this.lastRstNanoTime = ticker.read(); streamKey = encoder.connection().newKey(); this.transportListener = checkNotNull(transportListener, "transportListener"); this.streamTracerFactories = checkNotNull(streamTracerFactories, "streamTracerFactories"); @@ -465,8 +494,20 @@ private void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers } if (!HTTP_METHOD.contentEquals(headers.method())) { + Http2Headers extraHeaders = new DefaultHttp2Headers(); + extraHeaders.add(HttpHeaderNames.ALLOW, HTTP_METHOD); respondWithHttpError(ctx, streamId, 405, Status.Code.INTERNAL, - String.format("Method '%s' is not supported", headers.method())); + String.format("Method '%s' is not supported", headers.method()), extraHeaders); + return; + } + + int h2HeadersSize = Utils.getH2HeadersSize(headers); + if (Utils.shouldRejectOnMetadataSizeSoftLimitExceeded( + h2HeadersSize, softLimitHeaderListSize, maxHeaderListSize)) { + respondWithHttpError(ctx, streamId, 431, Status.Code.RESOURCE_EXHAUSTED, String.format( + "Client Headers of size %d exceeded Metadata size soft limit: %d", + h2HeadersSize, + softLimitHeaderListSize)); return; } @@ -502,8 +543,7 @@ private void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers state, attributes, authority, - statsTraceCtx, - transportTracer); + statsTraceCtx); transportListener.streamCreated(stream, method, metadata); state.onStreamAllocated(); http2Stream.setProperty(streamKey, state); @@ -547,24 +587,9 @@ private void onDataRead(int streamId, ByteBuf data, int padding, boolean endOfSt } private void onRstStreamRead(int streamId, long errorCode) throws Http2Exception { - if (maxRstCount > 0) { - long now = ticker.read(); - if (now - lastRstNanoTime > maxRstPeriodNanos) { - lastRstNanoTime = now; - rstCount = 1; - } else { - rstCount++; - if (rstCount > maxRstCount) { - throw new Http2Exception(Http2Error.ENHANCE_YOUR_CALM, "too_many_rststreams") { - @SuppressWarnings("UnsynchronizedOverridesSynchronized") // No memory accesses - @Override - public Throwable fillInStackTrace() { - // Avoid the CPU cycles, since the resets may be a CPU consumption attack - return this; - } - }; - } - } + Http2Exception tooManyRstStream = rstStreamCounter.countRstStream(); + if (tooManyRstStream != null) { + throw tooManyRstStream; } try { @@ -644,9 +669,16 @@ void setKeepAliveManagerForTest(KeepAliveManager keepAliveManager) { /** * Handler for the Channel shutting down. */ + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + tcpMetrics.channelActive(ctx.channel()); + super.channelActive(ctx); + } + @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception { try { + tcpMetrics.channelInactive(ctx.channel()); if (keepAliveManager != null) { keepAliveManager.onTransportTermination(); } @@ -749,6 +781,7 @@ private void sendGrpcFrame( int streamId = cmd.stream().id(); Http2Stream stream = connection().stream(streamId); if (stream == null) { + cmd.release(); streamGone(streamId, promise); return; } @@ -788,9 +821,37 @@ private void cancelStream(ChannelHandlerContext ctx, CancelServerStreamCommand c PerfMark.linkIn(cmd.getLink()); // Notify the listener if we haven't already. cmd.stream().transportReportStatus(cmd.reason()); - // Terminate the stream. - encoder().writeRstStream(ctx, cmd.stream().id(), Http2Error.CANCEL.code(), promise); + + // Now we need to decide how we're going to notify the peer that this stream is closed. + // If possible, it's nice to inform the peer _why_ this stream was cancelled by sending + // a structured headers frame. + if (shouldCloseStreamWithHeaders(cmd, connection())) { + Metadata md = new Metadata(); + md.put(InternalStatus.CODE_KEY, cmd.reason()); + if (cmd.reason().getDescription() != null) { + md.put(InternalStatus.MESSAGE_KEY, cmd.reason().getDescription()); + } + Http2Headers headers = Utils.convertServerHeaders(md); + encoder().writeHeaders( + ctx, cmd.stream().id(), headers, /* padding = */ 0, /* endStream = */ true, promise); + } else { + // Terminate the stream. + encoder().writeRstStream(ctx, cmd.stream().id(), Http2Error.CANCEL.code(), promise); + } + } + } + + // Determine whether a CancelServerStreamCommand should try to close the stream with a + // HEADERS or a RST_STREAM frame. The caller has some influence over this (they can + // configure cmd.wantsHeaders()). The state of the stream also has an influence: we + // only try to send HEADERS if the stream exists and hasn't already sent any headers. + private static boolean shouldCloseStreamWithHeaders( + CancelServerStreamCommand cmd, Http2Connection conn) { + if (!cmd.wantsHeaders()) { + return false; } + Http2Stream stream = conn.stream(cmd.stream().id()); + return stream != null && !stream.isHeadersSent(); } private void gracefulClose(final ChannelHandlerContext ctx, final GracefulServerCloseCommand msg, @@ -831,6 +892,12 @@ public boolean visit(Http2Stream stream) throws Http2Exception { private void respondWithHttpError( ChannelHandlerContext ctx, int streamId, int code, Status.Code statusCode, String msg) { + respondWithHttpError(ctx, streamId, code, statusCode, msg, EmptyHttp2Headers.INSTANCE); + } + + private void respondWithHttpError( + ChannelHandlerContext ctx, int streamId, int code, Status.Code statusCode, String msg, + Http2Headers extraHeaders) { Metadata metadata = new Metadata(); metadata.put(InternalStatus.CODE_KEY, statusCode.toStatus()); metadata.put(InternalStatus.MESSAGE_KEY, msg); @@ -842,6 +909,7 @@ private void respondWithHttpError( for (int i = 0; i < serialized.length; i += 2) { headers.add(new AsciiString(serialized[i], false), new AsciiString(serialized[i + 1], false)); } + headers.add(extraHeaders); encoder().writeHeaders(ctx, streamId, headers, 0, false, ctx.newPromise()); ByteBuf msgBuf = ByteBufUtil.writeUtf8(ctx.alloc(), msg); encoder().writeData(ctx, streamId, msgBuf, 0, true, ctx.newPromise()); @@ -1123,6 +1191,81 @@ public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2 } } + private static final class Http2RstCounterEncoder extends DecoratingHttp2ConnectionEncoder { + private final RstStreamCounter rstStreamCounter; + private Http2LifecycleManager lifecycleManager; + + Http2RstCounterEncoder(Http2ConnectionEncoder encoder, RstStreamCounter rstStreamCounter) { + super(encoder); + this.rstStreamCounter = rstStreamCounter; + } + + @Override + public void lifecycleManager(Http2LifecycleManager lifecycleManager) { + this.lifecycleManager = lifecycleManager; + super.lifecycleManager(lifecycleManager); + } + + @Override + public ChannelFuture writeRstStream( + ChannelHandlerContext ctx, int streamId, long errorCode, ChannelPromise promise) { + ChannelFuture future = super.writeRstStream(ctx, streamId, errorCode, promise); + // We want to count "induced" RST_STREAM, where the server sent a reset because of a malformed + // frame. + boolean normalRst + = errorCode == Http2Error.NO_ERROR.code() || errorCode == Http2Error.CANCEL.code(); + if (!normalRst) { + Http2Exception tooManyRstStream = rstStreamCounter.countRstStream(); + if (tooManyRstStream != null) { + lifecycleManager.onError(ctx, true, tooManyRstStream); + ctx.close(); + } + } + return future; + } + } + + private static final class RstStreamCounter { + private final int maxRstCount; + private final long maxRstPeriodNanos; + private final Ticker ticker; + private int rstCount; + private long lastRstNanoTime; + + RstStreamCounter(int maxRstCount, long maxRstPeriodNanos, Ticker ticker) { + checkArgument(maxRstCount >= 0, "maxRstCount must be non-negative: %s", maxRstCount); + this.maxRstCount = maxRstCount; + this.maxRstPeriodNanos = maxRstPeriodNanos; + this.ticker = checkNotNull(ticker, "ticker"); + this.lastRstNanoTime = ticker.read(); + } + + /** Returns non-{@code null} when the connection should be killed by the caller. */ + private Http2Exception countRstStream() { + if (maxRstCount == 0) { + return null; + } + long now = ticker.read(); + if (now - lastRstNanoTime > maxRstPeriodNanos) { + lastRstNanoTime = now; + rstCount = 1; + } else { + rstCount++; + if (rstCount > maxRstCount) { + return new Http2Exception(Http2Error.ENHANCE_YOUR_CALM, "too_many_rststreams") { + @SuppressWarnings("UnsynchronizedOverridesSynchronized") // No memory accesses + @Override + public Throwable fillInStackTrace() { + // Avoid the CPU cycles, since the resets may be a CPU consumption attack + return this; + } + }; + } + } + return null; + } + } + private static class ServerChannelLogger extends ChannelLogger { private static final Logger log = Logger.getLogger(ChannelLogger.class.getName()); diff --git a/netty/src/main/java/io/grpc/netty/NettyServerStream.java b/netty/src/main/java/io/grpc/netty/NettyServerStream.java index a44d8b4a64f..836f39ddf19 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerStream.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerStream.java @@ -52,7 +52,6 @@ class NettyServerStream extends AbstractServerStream { private final WriteQueue writeQueue; private final Attributes attributes; private final String authority; - private final TransportTracer transportTracer; private final int streamId; public NettyServerStream( @@ -60,14 +59,12 @@ public NettyServerStream( TransportState state, Attributes transportAttrs, String authority, - StatsTraceContext statsTraceCtx, - TransportTracer transportTracer) { + StatsTraceContext statsTraceCtx) { super(new NettyWritableBufferAllocator(channel.alloc()), statsTraceCtx); this.state = checkNotNull(state, "transportState"); this.writeQueue = state.handler.getWriteQueue(); this.attributes = checkNotNull(transportAttrs); this.authority = authority; - this.transportTracer = checkNotNull(transportTracer, "transportTracer"); // Read the id early to avoid reading transportState later. this.streamId = transportState().id(); } @@ -96,38 +93,26 @@ private class Sink implements AbstractServerStream.Sink { @Override public void writeHeaders(Metadata headers, boolean flush) { try (TaskCloseable ignore = PerfMark.traceTask("NettyServerStream$Sink.writeHeaders")) { - writeQueue.enqueue( - SendResponseHeadersCommand.createHeaders( - transportState(), - Utils.convertServerHeaders(headers)), - flush); + Http2Headers http2headers = Utils.convertServerHeaders(headers); + SendResponseHeadersCommand headersCommand = + SendResponseHeadersCommand.createHeaders(transportState(), http2headers); + writeQueue.enqueue(headersCommand, flush) + .addListener((ChannelFutureListener) transportState()::handleWriteFutureFailures); } } - private void writeFrameInternal(WritableBuffer frame, boolean flush, final int numMessages) { - Preconditions.checkArgument(numMessages >= 0); - ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf().touch(); - final int numBytes = bytebuf.readableBytes(); - // Add the bytes to outbound flow control. - onSendingBytes(numBytes); - writeQueue.enqueue(new SendGrpcFrameCommand(transportState(), bytebuf, false), flush) - .addListener(new ChannelFutureListener() { - @Override - public void operationComplete(ChannelFuture future) throws Exception { - // Remove the bytes from outbound flow control, optionally notifying - // the client that they can send more bytes. - transportState().onSentBytes(numBytes); - if (future.isSuccess()) { - transportTracer.reportMessageSent(numMessages); - } - } - }); - } - @Override public void writeFrame(WritableBuffer frame, boolean flush, final int numMessages) { try (TaskCloseable ignore = PerfMark.traceTask("NettyServerStream$Sink.writeFrame")) { - writeFrameInternal(frame, flush, numMessages); + Preconditions.checkArgument(numMessages >= 0); + ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf().touch(); + final int numBytes = bytebuf.readableBytes(); + // Add the bytes to outbound flow control. + onSendingBytes(numBytes); + ChannelFutureListener failureListener = + future -> transportState().onWriteFrameData(future, numMessages, numBytes); + writeQueue.enqueue(new SendGrpcFrameCommand(transportState(), bytebuf, false), flush) + .addListener(failureListener); } } @@ -135,16 +120,17 @@ public void writeFrame(WritableBuffer frame, boolean flush, final int numMessage public void writeTrailers(Metadata trailers, boolean headersSent, Status status) { try (TaskCloseable ignore = PerfMark.traceTask("NettyServerStream$Sink.writeTrailers")) { Http2Headers http2Trailers = Utils.convertTrailers(trailers, headersSent); - writeQueue.enqueue( - SendResponseHeadersCommand.createTrailers(transportState(), http2Trailers, status), - true); + SendResponseHeadersCommand trailersCommand = + SendResponseHeadersCommand.createTrailers(transportState(), http2Trailers, status); + writeQueue.enqueue(trailersCommand, true) + .addListener((ChannelFutureListener) transportState()::handleWriteFutureFailures); } } @Override public void cancel(Status status) { try (TaskCloseable ignore = PerfMark.traceTask("NettyServerStream$Sink.cancel")) { - writeQueue.enqueue(new CancelServerStreamCommand(transportState(), status), true); + writeQueue.enqueue(CancelServerStreamCommand.withReset(transportState(), status), true); } } } @@ -203,7 +189,40 @@ public void deframeFailed(Throwable cause) { log.log(Level.WARNING, "Exception processing message", cause); Status status = Status.fromThrowable(cause); transportReportStatus(status); - handler.getWriteQueue().enqueue(new CancelServerStreamCommand(this, status), true); + handler.getWriteQueue().enqueue(CancelServerStreamCommand.withReason(this, status), true); + } + + private void onWriteFrameData(ChannelFuture future, int numMessages, int numBytes) { + // Remove the bytes from outbound flow control, optionally notifying + // the client that they can send more bytes. + if (future.isSuccess()) { + onSentBytes(numBytes); + getTransportTracer().reportMessageSent(numMessages); + } else { + handleWriteFutureFailures(future); + } + } + + private void handleWriteFutureFailures(ChannelFuture future) { + // isStreamDeallocated() check protects from spamming stream resets by scheduling multiple + // CancelServerStreamCommand commands. + if (future.isSuccess() || isStreamDeallocated()) { + return; + } + + // Future failed, fail RPC. + // Normally we don't need to do anything on frame write failures because the cause of + // the failed future would be an IO error that closed the stream. + // However, we still need handle any unexpected failures raised in Netty. + http2ProcessingFailed(Utils.statusFromThrowable(future.cause())); + } + + /** + * Called to process a failure in HTTP/2 processing. + */ + protected void http2ProcessingFailed(Status status) { + transportReportStatus(status); + handler.getWriteQueue().enqueue(CancelServerStreamCommand.withReset(this, status), true); } void inboundDataReceived(ByteBuf frame, boolean endOfStream) { diff --git a/netty/src/main/java/io/grpc/netty/NettyServerTransport.java b/netty/src/main/java/io/grpc/netty/NettyServerTransport.java index 9511927a09f..c0e52b75876 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerTransport.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerTransport.java @@ -25,6 +25,7 @@ import io.grpc.Attributes; import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalLogId; +import io.grpc.MetricRecorder; import io.grpc.ServerStreamTracer; import io.grpc.Status; import io.grpc.internal.ServerTransport; @@ -70,6 +71,7 @@ class NettyServerTransport implements ServerTransport { private final int flowControlWindow; private final int maxMessageSize; private final int maxHeaderListSize; + private final int softLimitHeaderListSize; private final long keepAliveTimeInNanos; private final long keepAliveTimeoutInNanos; private final long maxConnectionIdleInNanos; @@ -80,6 +82,7 @@ class NettyServerTransport implements ServerTransport { private final int maxRstCount; private final long maxRstPeriodNanos; private final Attributes eagAttributes; + private final MetricRecorder metricRecorder; private final List streamTracerFactories; private final TransportTracer transportTracer; @@ -94,6 +97,7 @@ class NettyServerTransport implements ServerTransport { int flowControlWindow, int maxMessageSize, int maxHeaderListSize, + int softLimitHeaderListSize, long keepAliveTimeInNanos, long keepAliveTimeoutInNanos, long maxConnectionIdleInNanos, @@ -103,7 +107,8 @@ class NettyServerTransport implements ServerTransport { long permitKeepAliveTimeInNanos, int maxRstCount, long maxRstPeriodNanos, - Attributes eagAttributes) { + Attributes eagAttributes, + MetricRecorder metricRecorder) { this.channel = Preconditions.checkNotNull(channel, "channel"); this.channelUnused = channelUnused; this.protocolNegotiator = Preconditions.checkNotNull(protocolNegotiator, "protocolNegotiator"); @@ -115,6 +120,7 @@ class NettyServerTransport implements ServerTransport { this.flowControlWindow = flowControlWindow; this.maxMessageSize = maxMessageSize; this.maxHeaderListSize = maxHeaderListSize; + this.softLimitHeaderListSize = softLimitHeaderListSize; this.keepAliveTimeInNanos = keepAliveTimeInNanos; this.keepAliveTimeoutInNanos = keepAliveTimeoutInNanos; this.maxConnectionIdleInNanos = maxConnectionIdleInNanos; @@ -125,6 +131,7 @@ class NettyServerTransport implements ServerTransport { this.maxRstCount = maxRstCount; this.maxRstPeriodNanos = maxRstPeriodNanos; this.eagAttributes = Preconditions.checkNotNull(eagAttributes, "eagAttributes"); + this.metricRecorder = metricRecorder; SocketAddress remote = channel.remoteAddress(); this.logId = InternalLogId.allocate(getClass(), remote != null ? remote.toString() : null); } @@ -275,6 +282,7 @@ private NettyServerHandler createHandler( autoFlowControl, flowControlWindow, maxHeaderListSize, + softLimitHeaderListSize, maxMessageSize, keepAliveTimeInNanos, keepAliveTimeoutInNanos, @@ -285,6 +293,7 @@ private NettyServerHandler createHandler( permitKeepAliveTimeInNanos, maxRstCount, maxRstPeriodNanos, - eagAttributes); + eagAttributes, + metricRecorder); } } diff --git a/netty/src/main/java/io/grpc/netty/NettySslContextChannelCredentials.java b/netty/src/main/java/io/grpc/netty/NettySslContextChannelCredentials.java index ede511b68f6..3d3fdc67e8e 100644 --- a/netty/src/main/java/io/grpc/netty/NettySslContextChannelCredentials.java +++ b/netty/src/main/java/io/grpc/netty/NettySslContextChannelCredentials.java @@ -34,6 +34,6 @@ public static ChannelCredentials create(SslContext sslContext) { Preconditions.checkArgument(sslContext.isClient(), "Server SSL context can not be used for client channel"); GrpcSslContexts.ensureAlpnAndH2Enabled(sslContext.applicationProtocolNegotiator()); - return NettyChannelCredentials.create(ProtocolNegotiators.tlsClientFactory(sslContext)); + return NettyChannelCredentials.create(ProtocolNegotiators.tlsClientFactory(sslContext, null)); } } diff --git a/netty/src/main/java/io/grpc/netty/NettyWritableBufferAllocator.java b/netty/src/main/java/io/grpc/netty/NettyWritableBufferAllocator.java index 9e93ee1155c..40b84717160 100644 --- a/netty/src/main/java/io/grpc/netty/NettyWritableBufferAllocator.java +++ b/netty/src/main/java/io/grpc/netty/NettyWritableBufferAllocator.java @@ -33,9 +33,6 @@ */ class NettyWritableBufferAllocator implements WritableBufferAllocator { - // Use 4k as our minimum buffer size. - private static final int MIN_BUFFER = 4 * 1024; - // Set the maximum buffer size to 1MB. private static final int MAX_BUFFER = 1024 * 1024; @@ -47,7 +44,7 @@ class NettyWritableBufferAllocator implements WritableBufferAllocator { @Override public WritableBuffer allocate(int capacityHint) { - capacityHint = Math.min(MAX_BUFFER, Math.max(MIN_BUFFER, capacityHint)); + capacityHint = Math.min(MAX_BUFFER, capacityHint); return new NettyWritableBuffer(allocator.buffer(capacityHint, capacityHint)); } } diff --git a/netty/src/main/java/io/grpc/netty/NoopSslEngine.java b/netty/src/main/java/io/grpc/netty/NoopSslEngine.java new file mode 100644 index 00000000000..7e14dbf0e79 --- /dev/null +++ b/netty/src/main/java/io/grpc/netty/NoopSslEngine.java @@ -0,0 +1,151 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.netty; + +import java.nio.ByteBuffer; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLSession; + +/** + * A no-op implementation of SslEngine, to facilitate overriding only the required methods in + * specific implementations. + */ +class NoopSslEngine extends SSLEngine { + @Override + public SSLEngineResult wrap(ByteBuffer[] srcs, int offset, int length, ByteBuffer dst) + throws SSLException { + return null; + } + + @Override + public SSLEngineResult unwrap(ByteBuffer src, ByteBuffer[] dsts, int offset, int length) + throws SSLException { + return null; + } + + @Override + public Runnable getDelegatedTask() { + return null; + } + + @Override + public void closeInbound() throws SSLException { + + } + + @Override + public boolean isInboundDone() { + return false; + } + + @Override + public void closeOutbound() { + + } + + @Override + public boolean isOutboundDone() { + return false; + } + + @Override + public String[] getSupportedCipherSuites() { + return new String[0]; + } + + @Override + public String[] getEnabledCipherSuites() { + return new String[0]; + } + + @Override + public void setEnabledCipherSuites(String[] suites) { + + } + + @Override + public String[] getSupportedProtocols() { + return new String[0]; + } + + @Override + public String[] getEnabledProtocols() { + return new String[0]; + } + + @Override + public void setEnabledProtocols(String[] protocols) { + + } + + @Override + public SSLSession getSession() { + return null; + } + + @Override + public void beginHandshake() throws SSLException { + + } + + @Override + public SSLEngineResult.HandshakeStatus getHandshakeStatus() { + return null; + } + + @Override + public void setUseClientMode(boolean mode) { + + } + + @Override + public boolean getUseClientMode() { + return false; + } + + @Override + public void setNeedClientAuth(boolean need) { + + } + + @Override + public boolean getNeedClientAuth() { + return false; + } + + @Override + public void setWantClientAuth(boolean want) { + + } + + @Override + public boolean getWantClientAuth() { + return false; + } + + @Override + public void setEnableSessionCreation(boolean flag) { + + } + + @Override + public boolean getEnableSessionCreation() { + return false; + } +} diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiationEvent.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiationEvent.java index 16da79e1af8..8103a2dc79f 100644 --- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiationEvent.java +++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiationEvent.java @@ -20,10 +20,10 @@ import com.google.common.base.MoreObjects; import com.google.common.base.Objects; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.Attributes; import io.grpc.Internal; import io.grpc.InternalChannelz.Security; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; /** diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiator.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiator.java index 8a2c6f104b2..4332fdf2919 100644 --- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiator.java +++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiator.java @@ -63,4 +63,5 @@ interface ServerFactory { */ ProtocolNegotiator newNegotiator(ObjectPool offloadExecutorPool); } + } diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java index 19d3e01b785..8faf3d0fae8 100644 --- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java @@ -16,10 +16,10 @@ package io.grpc.netty; -import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Optional; import com.google.common.base.Preconditions; import com.google.errorprone.annotations.ForOverride; import io.grpc.Attributes; @@ -41,18 +41,22 @@ import io.grpc.Status; import io.grpc.TlsChannelCredentials; import io.grpc.TlsServerCredentials; +import io.grpc.internal.CertificateUtils; import io.grpc.internal.GrpcAttributes; import io.grpc.internal.GrpcUtil; +import io.grpc.internal.NoopSslSession; import io.grpc.internal.ObjectPool; import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.DefaultHttpRequest; import io.netty.handler.codec.http.HttpClientCodec; import io.netty.handler.codec.http.HttpClientUpgradeHandler; import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpVersion; import io.netty.handler.codec.http2.Http2ClientUpgradeCodec; @@ -70,8 +74,12 @@ import java.net.SocketAddress; import java.net.URI; import java.nio.channels.ClosedChannelException; +import java.security.GeneralSecurityException; +import java.security.KeyStore; import java.util.Arrays; import java.util.EnumSet; +import java.util.List; +import java.util.Map; import java.util.Set; import java.util.concurrent.Executor; import java.util.logging.Level; @@ -81,6 +89,10 @@ import javax.net.ssl.SSLException; import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSession; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509TrustManager; +import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; /** * Common {@link ProtocolNegotiator}s used by gRPC. @@ -94,7 +106,6 @@ final class ProtocolNegotiators { EnumSet.of( TlsServerCredentials.Feature.MTLS, TlsServerCredentials.Feature.CUSTOM_MANAGERS); - private ProtocolNegotiators() { } @@ -116,14 +127,25 @@ public static FromChannelCredentialsResult from(ChannelCredentials creds) { new ByteArrayInputStream(tlsCreds.getPrivateKey()), tlsCreds.getPrivateKeyPassword()); } - if (tlsCreds.getTrustManagers() != null) { - builder.trustManager(new FixedTrustManagerFactory(tlsCreds.getTrustManagers())); - } else if (tlsCreds.getRootCertificates() != null) { - builder.trustManager(new ByteArrayInputStream(tlsCreds.getRootCertificates())); - } // else use system default try { - return FromChannelCredentialsResult.negotiator(tlsClientFactory(builder.build())); - } catch (SSLException ex) { + List trustManagers; + if (tlsCreds.getTrustManagers() != null) { + trustManagers = tlsCreds.getTrustManagers(); + } else if (tlsCreds.getRootCertificates() != null) { + trustManagers = Arrays.asList(CertificateUtils.createTrustManager( + new ByteArrayInputStream(tlsCreds.getRootCertificates()))); + } else { // else use system default + TrustManagerFactory tmf = TrustManagerFactory.getInstance( + TrustManagerFactory.getDefaultAlgorithm()); + tmf.init((KeyStore) null); + trustManagers = Arrays.asList(tmf.getTrustManagers()); + } + builder.trustManager(new FixedTrustManagerFactory(trustManagers)); + TrustManager x509ExtendedTrustManager = + CertificateUtils.getX509ExtendedTrustManager(trustManagers); + return FromChannelCredentialsResult.negotiator(tlsClientFactory(builder.build(), + (X509TrustManager) x509ExtendedTrustManager)); + } catch (SSLException | GeneralSecurityException ex) { log.log(Level.FINE, "Exception building SslContext", ex); return FromChannelCredentialsResult.error( "Unable to create SslContext: " + ex.getMessage()); @@ -160,41 +182,6 @@ public static FromChannelCredentialsResult from(ChannelCredentials creds) { } } - public static final class FromChannelCredentialsResult { - public final ProtocolNegotiator.ClientFactory negotiator; - public final CallCredentials callCredentials; - public final String error; - - private FromChannelCredentialsResult(ProtocolNegotiator.ClientFactory negotiator, - CallCredentials creds, String error) { - this.negotiator = negotiator; - this.callCredentials = creds; - this.error = error; - } - - public static FromChannelCredentialsResult error(String error) { - return new FromChannelCredentialsResult( - null, null, Preconditions.checkNotNull(error, "error")); - } - - public static FromChannelCredentialsResult negotiator( - ProtocolNegotiator.ClientFactory factory) { - return new FromChannelCredentialsResult( - Preconditions.checkNotNull(factory, "factory"), null, null); - } - - public FromChannelCredentialsResult withCallCredentials(CallCredentials callCreds) { - Preconditions.checkNotNull(callCreds, "callCreds"); - if (error != null) { - return this; - } - if (this.callCredentials != null) { - callCreds = new CompositeCallCredentials(this.callCredentials, callCreds); - } - return new FromChannelCredentialsResult(negotiator, callCreds, null); - } - } - public static FromServerCredentialsResult from(ServerCredentials creds) { if (creds instanceof TlsServerCredentials) { TlsServerCredentials tlsCreds = (TlsServerCredentials) creds; @@ -273,6 +260,41 @@ public static FromServerCredentialsResult from(ServerCredentials creds) { } } + public static final class FromChannelCredentialsResult { + public final ProtocolNegotiator.ClientFactory negotiator; + public final CallCredentials callCredentials; + public final String error; + + private FromChannelCredentialsResult(ProtocolNegotiator.ClientFactory negotiator, + CallCredentials creds, String error) { + this.negotiator = negotiator; + this.callCredentials = creds; + this.error = error; + } + + public static FromChannelCredentialsResult error(String error) { + return new FromChannelCredentialsResult( + null, null, Preconditions.checkNotNull(error, "error")); + } + + public static FromChannelCredentialsResult negotiator( + ProtocolNegotiator.ClientFactory factory) { + return new FromChannelCredentialsResult( + Preconditions.checkNotNull(factory, "factory"), null, null); + } + + public FromChannelCredentialsResult withCallCredentials(CallCredentials callCreds) { + Preconditions.checkNotNull(callCreds, "callCreds"); + if (error != null) { + return this; + } + if (this.callCredentials != null) { + callCreds = new CompositeCallCredentials(this.callCredentials, callCreds); + } + return new FromChannelCredentialsResult(negotiator, callCreds, null); + } + } + public static final class FromServerCredentialsResult { public final ProtocolNegotiator.ServerFactory negotiator; public final String error; @@ -409,8 +431,8 @@ static final class ServerTlsHandler extends ChannelInboundHandlerAdapter { ServerTlsHandler(ChannelHandler next, SslContext sslContext, final ObjectPool executorPool) { - this.sslContext = checkNotNull(sslContext, "sslContext"); - this.next = checkNotNull(next, "next"); + this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext"); + this.next = Preconditions.checkNotNull(next, "next"); if (executorPool != null) { this.executor = executorPool.getObject(); } @@ -465,18 +487,20 @@ private void fireProtocolNegotiationEvent(ChannelHandlerContext ctx, SSLSession * Returns a {@link ProtocolNegotiator} that does HTTP CONNECT proxy negotiation. */ public static ProtocolNegotiator httpProxy(final SocketAddress proxyAddress, - final @Nullable String proxyUsername, final @Nullable String proxyPassword, + final @Nullable Map headers, final @Nullable String proxyUsername, + final @Nullable String proxyPassword, final ProtocolNegotiator negotiator) { - checkNotNull(negotiator, "negotiator"); - checkNotNull(proxyAddress, "proxyAddress"); + Preconditions.checkNotNull(negotiator, "negotiator"); + Preconditions.checkNotNull(proxyAddress, "proxyAddress"); final AsciiString scheme = negotiator.scheme(); class ProxyNegotiator implements ProtocolNegotiator { @Override public ChannelHandler newHandler(GrpcHttp2ConnectionHandler http2Handler) { ChannelHandler protocolNegotiationHandler = negotiator.newHandler(http2Handler); ChannelLogger negotiationLogger = http2Handler.getNegotiationLogger(); + HttpHeaders httpHeaders = toHttpHeaders(headers); return new ProxyProtocolNegotiationHandler( - proxyAddress, proxyUsername, proxyPassword, protocolNegotiationHandler, + proxyAddress, httpHeaders, proxyUsername, proxyPassword, protocolNegotiationHandler, negotiationLogger); } @@ -496,6 +520,22 @@ public void close() { return new ProxyNegotiator(); } + /** + * Converts generic Map of headers to Netty's HttpHeaders. + * Returns null if the map is null or empty. + */ + @Nullable + private static HttpHeaders toHttpHeaders(@Nullable Map headers) { + if (headers == null || headers.isEmpty()) { + return null; + } + HttpHeaders httpHeaders = new DefaultHttpHeaders(); + for (Map.Entry entry : headers.entrySet()) { + httpHeaders.add(entry.getKey(), entry.getValue()); + } + return httpHeaders; + } + /** * A Proxy handler follows {@link ProtocolNegotiationHandler} pattern. Upon successful proxy * connection, this handler will install {@code next} handler which should be a handler from @@ -504,17 +544,20 @@ public void close() { static final class ProxyProtocolNegotiationHandler extends ProtocolNegotiationHandler { private final SocketAddress address; + @Nullable private final HttpHeaders httpHeaders; @Nullable private final String userName; @Nullable private final String password; public ProxyProtocolNegotiationHandler( SocketAddress address, + @Nullable HttpHeaders httpHeaders, @Nullable String userName, @Nullable String password, ChannelHandler next, ChannelLogger negotiationLogger) { super(next, negotiationLogger); - this.address = checkNotNull(address, "address"); + this.address = Preconditions.checkNotNull(address, "address"); + this.httpHeaders = httpHeaders; this.userName = userName; this.password = password; } @@ -523,9 +566,9 @@ public ProxyProtocolNegotiationHandler( protected void protocolNegotiationEventTriggered(ChannelHandlerContext ctx) { HttpProxyHandler nettyProxyHandler; if (userName == null || password == null) { - nettyProxyHandler = new HttpProxyHandler(address); + nettyProxyHandler = new HttpProxyHandler(address, httpHeaders); } else { - nettyProxyHandler = new HttpProxyHandler(address, userName, password); + nettyProxyHandler = new HttpProxyHandler(address, userName, password, httpHeaders); } ctx.pipeline().addBefore(ctx.name(), /* name= */ null, nettyProxyHandler); } @@ -543,16 +586,23 @@ protected void userEventTriggered0(ChannelHandlerContext ctx, Object evt) throws static final class ClientTlsProtocolNegotiator implements ProtocolNegotiator { public ClientTlsProtocolNegotiator(SslContext sslContext, - ObjectPool executorPool) { - this.sslContext = checkNotNull(sslContext, "sslContext"); + ObjectPool executorPool, Optional handshakeCompleteRunnable, + X509TrustManager x509ExtendedTrustManager, String sni) { + this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext"); this.executorPool = executorPool; if (this.executorPool != null) { this.executor = this.executorPool.getObject(); } + this.handshakeCompleteRunnable = handshakeCompleteRunnable; + this.x509ExtendedTrustManager = x509ExtendedTrustManager; + this.sni = sni; } private final SslContext sslContext; private final ObjectPool executorPool; + private final Optional handshakeCompleteRunnable; + private final X509TrustManager x509ExtendedTrustManager; + private final String sni; private Executor executor; @Override @@ -564,8 +614,17 @@ public AsciiString scheme() { public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { ChannelHandler gnh = new GrpcNegotiationHandler(grpcHandler); ChannelLogger negotiationLogger = grpcHandler.getNegotiationLogger(); - ChannelHandler cth = new ClientTlsHandler(gnh, sslContext, grpcHandler.getAuthority(), - this.executor, negotiationLogger); + String authority; + if ("".equals(sni)) { + authority = null; + } else if (sni != null) { + authority = sni; + } else { + authority = grpcHandler.getAuthority(); + } + ChannelHandler cth = new ClientTlsHandler(gnh, sslContext, + authority, this.executor, negotiationLogger, handshakeCompleteRunnable, this, + x509ExtendedTrustManager); return new WaitUntilActiveHandler(cth, negotiationLogger); } @@ -575,6 +634,11 @@ public void close() { this.executorPool.returnObject(this.executor); } } + + @VisibleForTesting + boolean hasX509ExtendedTrustManager() { + return x509ExtendedTrustManager != null; + } } static final class ClientTlsHandler extends ProtocolNegotiationHandler { @@ -583,20 +647,38 @@ static final class ClientTlsHandler extends ProtocolNegotiationHandler { private final String host; private final int port; private Executor executor; + private final Optional handshakeCompleteRunnable; + private final X509TrustManager x509TrustManager; + private SSLEngine sslEngine; ClientTlsHandler(ChannelHandler next, SslContext sslContext, String authority, - Executor executor, ChannelLogger negotiationLogger) { + Executor executor, ChannelLogger negotiationLogger, + Optional handshakeCompleteRunnable, + ClientTlsProtocolNegotiator clientTlsProtocolNegotiator, + X509TrustManager x509TrustManager) { super(next, negotiationLogger); - this.sslContext = checkNotNull(sslContext, "sslContext"); - HostPort hostPort = parseAuthority(authority); - this.host = hostPort.host; - this.port = hostPort.port; + this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext"); + if (authority != null) { + HostPort hostPort = parseAuthority(authority); + this.host = hostPort.host; + this.port = hostPort.port; + } else { + this.host = null; + this.port = 0; + } this.executor = executor; + this.handshakeCompleteRunnable = handshakeCompleteRunnable; + this.x509TrustManager = x509TrustManager; } @Override + @IgnoreJRERequirement protected void handlerAdded0(ChannelHandlerContext ctx) { - SSLEngine sslEngine = sslContext.newEngine(ctx.alloc(), host, port); + if (host != null) { + sslEngine = sslContext.newEngine(ctx.alloc(), host, port); + } else { + sslEngine = sslContext.newEngine(ctx.alloc()); + } SSLParameters sslParams = sslEngine.getSSLParameters(); sslParams.setEndpointIdentificationAlgorithm("HTTPS"); sslEngine.setSSLParameters(sslParams); @@ -636,6 +718,9 @@ protected void userEventTriggered0(ChannelHandlerContext ctx, Object evt) throws } ctx.fireExceptionCaught(t); } + if (handshakeCompleteRunnable.isPresent()) { + handshakeCompleteRunnable.get().run(); + } } else { super.userEventTriggered0(ctx, evt); } @@ -647,8 +732,13 @@ private void propagateTlsComplete(ChannelHandlerContext ctx, SSLSession session) Attributes attrs = existingPne.getAttributes().toBuilder() .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.PRIVACY_AND_INTEGRITY) .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, session) + .set(GrpcAttributes.ATTR_AUTHORITY_VERIFIER, new X509AuthorityVerifier( + sslEngine, x509TrustManager)) .build(); replaceProtocolNegotiationEvent(existingPne.withAttributes(attrs).withSecurity(security)); + if (handshakeCompleteRunnable.isPresent()) { + handshakeCompleteRunnable.get().run(); + } fireProtocolNegotiationEvent(ctx); } } @@ -680,11 +770,14 @@ static HostPort parseAuthority(String authority) { * Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel} * may happen immediately, even before the TLS Handshake is complete. + * * @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks */ public static ProtocolNegotiator tls(SslContext sslContext, - ObjectPool executorPool) { - return new ClientTlsProtocolNegotiator(sslContext, executorPool); + ObjectPool executorPool, Optional handshakeCompleteRunnable, + X509TrustManager x509ExtendedTrustManager, String sni) { + return new ClientTlsProtocolNegotiator(sslContext, executorPool, handshakeCompleteRunnable, + x509ExtendedTrustManager, sni); } /** @@ -692,25 +785,30 @@ public static ProtocolNegotiator tls(SslContext sslContext, * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel} * may happen immediately, even before the TLS Handshake is complete. */ - public static ProtocolNegotiator tls(SslContext sslContext) { - return tls(sslContext, null); + public static ProtocolNegotiator tls(SslContext sslContext, + X509TrustManager x509ExtendedTrustManager) { + return tls(sslContext, null, Optional.absent(), x509ExtendedTrustManager, null); } - public static ProtocolNegotiator.ClientFactory tlsClientFactory(SslContext sslContext) { - return new TlsProtocolNegotiatorClientFactory(sslContext); + public static ProtocolNegotiator.ClientFactory tlsClientFactory(SslContext sslContext, + X509TrustManager x509ExtendedTrustManager) { + return new TlsProtocolNegotiatorClientFactory(sslContext, x509ExtendedTrustManager); } @VisibleForTesting static final class TlsProtocolNegotiatorClientFactory implements ProtocolNegotiator.ClientFactory { private final SslContext sslContext; + private final X509TrustManager x509ExtendedTrustManager; - public TlsProtocolNegotiatorClientFactory(SslContext sslContext) { + public TlsProtocolNegotiatorClientFactory(SslContext sslContext, + X509TrustManager x509ExtendedTrustManager) { this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext"); + this.x509ExtendedTrustManager = x509ExtendedTrustManager; } @Override public ProtocolNegotiator newNegotiator() { - return tls(sslContext); + return tls(sslContext, x509ExtendedTrustManager); } @Override public int getDefaultPort() { @@ -763,7 +861,9 @@ public AsciiString scheme() { public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { ChannelHandler upgradeHandler = new Http2UpgradeAndGrpcHandler(grpcHandler.getAuthority(), grpcHandler); - return new WaitUntilActiveHandler(upgradeHandler, grpcHandler.getNegotiationLogger()); + ChannelHandler plaintextHandler = + new PlaintextHandler(upgradeHandler, grpcHandler.getNegotiationLogger()); + return new WaitUntilActiveHandler(plaintextHandler, grpcHandler.getNegotiationLogger()); } @Override @@ -784,8 +884,8 @@ static final class Http2UpgradeAndGrpcHandler extends ChannelInboundHandlerAdapt private ProtocolNegotiationEvent pne; Http2UpgradeAndGrpcHandler(String authority, GrpcHttp2ConnectionHandler next) { - this.authority = checkNotNull(authority, "authority"); - this.next = checkNotNull(next, "next"); + this.authority = Preconditions.checkNotNull(authority, "authority"); + this.next = Preconditions.checkNotNull(next, "next"); this.negotiationLogger = next.getNegotiationLogger(); } @@ -829,9 +929,9 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc } /** - * Returns a {@link ChannelHandler} that ensures that the {@code handler} is added to the - * pipeline writes to the {@link io.netty.channel.Channel} may happen immediately, even before it - * is active. + * Returns a {@link io.netty.channel.ChannelHandler} that ensures that the {@code handler} is + * added to the pipeline writes to the {@link io.netty.channel.Channel} may happen immediately, + * even before it is active. */ public static ProtocolNegotiator plaintext() { return new PlaintextProtocolNegotiator(); @@ -909,7 +1009,7 @@ static final class GrpcNegotiationHandler extends ChannelInboundHandlerAdapter { private final GrpcHttp2ConnectionHandler next; public GrpcNegotiationHandler(GrpcHttp2ConnectionHandler next) { - this.next = checkNotNull(next, "next"); + this.next = Preconditions.checkNotNull(next, "next"); } @Override @@ -960,7 +1060,9 @@ static final class PlaintextProtocolNegotiator implements ProtocolNegotiator { @Override public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { ChannelHandler grpcNegotiationHandler = new GrpcNegotiationHandler(grpcHandler); - ChannelHandler activeHandler = new WaitUntilActiveHandler(grpcNegotiationHandler, + ChannelHandler plaintextHandler = + new PlaintextHandler(grpcNegotiationHandler, grpcHandler.getNegotiationLogger()); + ChannelHandler activeHandler = new WaitUntilActiveHandler(plaintextHandler, grpcHandler.getNegotiationLogger()); return activeHandler; } @@ -974,6 +1076,22 @@ public AsciiString scheme() { } } + static final class PlaintextHandler extends ProtocolNegotiationHandler { + PlaintextHandler(ChannelHandler next, ChannelLogger negotiationLogger) { + super(next, negotiationLogger); + } + + @Override + protected void protocolNegotiationEventTriggered(ChannelHandlerContext ctx) { + ProtocolNegotiationEvent existingPne = getProtocolNegotiationEvent(); + Attributes attrs = existingPne.getAttributes().toBuilder() + .set(GrpcAttributes.ATTR_AUTHORITY_VERIFIER, (authority) -> Status.OK) + .build(); + replaceProtocolNegotiationEvent(existingPne.withAttributes(attrs)); + fireProtocolNegotiationEvent(ctx); + } + } + /** * Waits for the channel to be active, and then installs the next Handler. Using this allows * subsequent handlers to assume the channel is active and ready to send. Additionally, this a @@ -1031,15 +1149,15 @@ static class ProtocolNegotiationHandler extends ChannelDuplexHandler { protected ProtocolNegotiationHandler(ChannelHandler next, String negotiatorName, ChannelLogger negotiationLogger) { - this.next = checkNotNull(next, "next"); + this.next = Preconditions.checkNotNull(next, "next"); this.negotiatorName = negotiatorName; - this.negotiationLogger = checkNotNull(negotiationLogger, "negotiationLogger"); + this.negotiationLogger = Preconditions.checkNotNull(negotiationLogger, "negotiationLogger"); } protected ProtocolNegotiationHandler(ChannelHandler next, ChannelLogger negotiationLogger) { - this.next = checkNotNull(next, "next"); + this.next = Preconditions.checkNotNull(next, "next"); this.negotiatorName = getClass().getSimpleName().replace("Handler", ""); - this.negotiationLogger = checkNotNull(negotiationLogger, "negotiationLogger"); + this.negotiationLogger = Preconditions.checkNotNull(negotiationLogger, "negotiationLogger"); } @Override @@ -1080,7 +1198,7 @@ protected final ProtocolNegotiationEvent getProtocolNegotiationEvent() { protected final void replaceProtocolNegotiationEvent(ProtocolNegotiationEvent pne) { checkState(this.pne != null, "previous protocol negotiation event hasn't triggered"); - this.pne = checkNotNull(pne); + this.pne = Preconditions.checkNotNull(pne); } protected final void fireProtocolNegotiationEvent(ChannelHandlerContext ctx) { @@ -1090,4 +1208,42 @@ protected final void fireProtocolNegotiationEvent(ChannelHandlerContext ctx) { ctx.fireUserEventTriggered(pne); } } + + static final class SslEngineWrapper extends NoopSslEngine { + private final SSLEngine sslEngine; + private final String peerHost; + + SslEngineWrapper(SSLEngine sslEngine, String peerHost) { + this.sslEngine = sslEngine; + this.peerHost = peerHost; + } + + @Override + public String getPeerHost() { + return peerHost; + } + + @Override + public SSLSession getHandshakeSession() { + return new FakeSslSession(peerHost); + } + + @Override + public SSLParameters getSSLParameters() { + return sslEngine.getSSLParameters(); + } + } + + static final class FakeSslSession extends NoopSslSession { + private final String peerHost; + + FakeSslSession(String peerHost) { + this.peerHost = peerHost; + } + + @Override + public String getPeerHost() { + return peerHost; + } + } } diff --git a/netty/src/main/java/io/grpc/netty/TcpMetrics.java b/netty/src/main/java/io/grpc/netty/TcpMetrics.java new file mode 100644 index 00000000000..c5809a5677e --- /dev/null +++ b/netty/src/main/java/io/grpc/netty/TcpMetrics.java @@ -0,0 +1,227 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.netty; + +import com.google.common.annotations.VisibleForTesting; +import io.grpc.InternalTcpMetrics; +import io.grpc.MetricRecorder; +import io.netty.channel.Channel; +import io.netty.util.concurrent.ScheduledFuture; +import java.lang.reflect.Method; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Utility for collecting TCP metrics from Netty channels. + */ +final class TcpMetrics { + private static final Logger log = Logger.getLogger(TcpMetrics.class.getName()); + + static EpollInfo epollInfo = loadEpollInfo(); + + static final class EpollInfo { + final Class channelClass; + final java.lang.reflect.Constructor infoConstructor; + final Method tcpInfo; + final Method totalRetrans; + final Method retransmits; + final Method rtt; + + EpollInfo( + Class channelClass, + java.lang.reflect.Constructor infoConstructor, + Method tcpInfo, + Method totalRetrans, + Method retransmits, + Method rtt) { + this.channelClass = channelClass; + this.infoConstructor = infoConstructor; + this.tcpInfo = tcpInfo; + this.totalRetrans = totalRetrans; + this.retransmits = retransmits; + this.rtt = rtt; + } + } + + static EpollInfo loadEpollInfo() { + boolean epollAvailable = false; + try { + Class epollClass = Class.forName("io.netty.channel.epoll.Epoll"); + Method isAvailableMethod = epollClass.getDeclaredMethod("isAvailable"); + epollAvailable = (Boolean) isAvailableMethod.invoke(null); + if (epollAvailable) { + Class channelClass = Class.forName("io.netty.channel.epoll.EpollSocketChannel"); + Class infoClass = Class.forName("io.netty.channel.epoll.EpollTcpInfo"); + return new EpollInfo( + channelClass, + infoClass.getDeclaredConstructor(), + channelClass.getMethod("tcpInfo", infoClass), + infoClass.getMethod("totalRetrans"), + infoClass.getMethod("retrans"), + infoClass.getMethod("rtt")); + } + } catch (ReflectiveOperationException e) { + log.log(Level.FINE, "Failed to initialize Epoll tcp_info reflection", e); + } finally { + log.log(Level.INFO, "Epoll available during static init of TcpMetrics:" + + "{0}", epollAvailable); + } + return null; + } + + private static final long RECORD_INTERVAL_MILLIS = TimeUnit.MINUTES.toMillis(5); + private final MetricRecorder metricRecorder; + private final Object tcpInfo; + private long lastTotalRetrans = 0; + private ScheduledFuture reportTimer; + + TcpMetrics(MetricRecorder metricRecorder) { + this.metricRecorder = metricRecorder; + + Object tcpInfo = null; + if (epollInfo != null) { + try { + tcpInfo = epollInfo.infoConstructor.newInstance(); + } catch (ReflectiveOperationException e) { + log.log(Level.FINE, "Failed to instantiate EpollTcpInfo", e); + } + } + this.tcpInfo = tcpInfo; + } + + void channelActive(Channel channel) { + List labelValues = getLabelValues(channel); + metricRecorder.addLongCounter(InternalTcpMetrics.CONNECTIONS_CREATED_INSTRUMENT, 1, + Collections.emptyList(), labelValues); + metricRecorder.addLongUpDownCounter(InternalTcpMetrics.CONNECTION_COUNT_INSTRUMENT, 1, + Collections.emptyList(), labelValues); + scheduleNextReport(channel, true); + } + + private void scheduleNextReport(final Channel channel, boolean isInitial) { + if (epollInfo == null || !epollInfo.channelClass.isInstance(channel) || !channel.isActive()) { + return; + } + + // Initial report has a larger jitter range to spread out initial connections. + // Subsequent reports have a smaller jitter range to avoid drift. + double jitter = isInitial + ? 0.1 + ThreadLocalRandom.current().nextDouble() // 10% to 110% + : 0.9 + ThreadLocalRandom.current().nextDouble() * 0.2; // 90% to 110% + long rearmingDelay = (long) (RECORD_INTERVAL_MILLIS * jitter); + + reportTimer = channel.eventLoop().schedule(() -> { + if (channel.isActive()) { + recordTcpInfo(channel, false); + scheduleNextReport(channel, false); // Re-arm + } + }, rearmingDelay, TimeUnit.MILLISECONDS); + } + + void channelInactive(Channel channel) { + if (reportTimer != null) { + reportTimer.cancel(false); + } + List labelValues = getLabelValues(channel); + metricRecorder.addLongUpDownCounter(InternalTcpMetrics.CONNECTION_COUNT_INSTRUMENT, -1, + Collections.emptyList(), labelValues); + // Final collection on close + if (epollInfo != null && epollInfo.channelClass.isInstance(channel)) { + recordTcpInfo(channel, true); + } + } + + void recordTcpInfo(Channel channel) { + recordTcpInfo(channel, false); + } + + private void recordTcpInfo(Channel channel, boolean isClose) { + if (epollInfo == null || !epollInfo.channelClass.isInstance(channel)) { + return; + } + List labelValues = getLabelValues(channel); + long totalRetrans; + long retransmits; + long rtt; + try { + epollInfo.tcpInfo.invoke(channel, tcpInfo); + totalRetrans = (Long) epollInfo.totalRetrans.invoke(tcpInfo); + retransmits = (Long) epollInfo.retransmits.invoke(tcpInfo); + rtt = (Long) epollInfo.rtt.invoke(tcpInfo); + } catch (ReflectiveOperationException e) { + log.log(Level.FINE, "Error computing TCP metrics", e); + return; + } + + long deltaTotal = totalRetrans - lastTotalRetrans; + if (deltaTotal > 0) { + metricRecorder.addLongCounter(InternalTcpMetrics.PACKETS_RETRANSMITTED_INSTRUMENT, + deltaTotal, Collections.emptyList(), labelValues); + lastTotalRetrans = totalRetrans; + } + if (isClose && retransmits > 0) { + metricRecorder.addLongCounter(InternalTcpMetrics.RECURRING_RETRANSMITS_INSTRUMENT, + retransmits, Collections.emptyList(), labelValues); + } + metricRecorder.recordDoubleHistogram(InternalTcpMetrics.MIN_RTT_INSTRUMENT, + rtt / 1000000.0, // Convert microseconds to seconds + Collections.emptyList(), labelValues); + } + + @VisibleForTesting + ScheduledFuture getReportTimer() { + return reportTimer; + } + + private static List getLabelValues(Channel channel) { + String localAddress = ""; + String localPort = ""; + String peerAddress = ""; + String peerPort = ""; + + SocketAddress local = channel.localAddress(); + if (local instanceof InetSocketAddress) { + InetSocketAddress inetLocal = (InetSocketAddress) local; + if (inetLocal.getAddress() != null) { + localAddress = inetLocal.getAddress().getHostAddress(); + } else if (inetLocal.getHostString() != null) { + localAddress = inetLocal.getHostString(); + } + localPort = String.valueOf(inetLocal.getPort()); + } + + SocketAddress remote = channel.remoteAddress(); + if (remote instanceof InetSocketAddress) { + InetSocketAddress inetRemote = (InetSocketAddress) remote; + if (inetRemote.getAddress() != null) { + peerAddress = inetRemote.getAddress().getHostAddress(); + } else if (inetRemote.getHostString() != null) { + peerAddress = inetRemote.getHostString(); + } + peerPort = String.valueOf(inetRemote.getPort()); + } + + return Arrays.asList(localAddress, localPort, peerAddress, peerPort); + } +} diff --git a/netty/src/main/java/io/grpc/netty/UdsNameResolver.java b/netty/src/main/java/io/grpc/netty/UdsNameResolver.java index 8fa8ea06250..3477a458933 100644 --- a/netty/src/main/java/io/grpc/netty/UdsNameResolver.java +++ b/netty/src/main/java/io/grpc/netty/UdsNameResolver.java @@ -18,21 +18,32 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Strings.isNullOrEmpty; import com.google.common.base.Preconditions; import io.grpc.EquivalentAddressGroup; import io.grpc.NameResolver; +import io.grpc.StatusOr; import io.netty.channel.unix.DomainSocketAddress; import java.util.ArrayList; -import java.util.Collections; import java.util.List; final class UdsNameResolver extends NameResolver { private NameResolver.Listener2 listener; private final String authority; - UdsNameResolver(String authority, String targetPath) { - checkArgument(authority == null, "non-null authority not supported"); + /** + * Constructs a new instance of UdsNameResolver. + * + * @param authority authority of the 'unix:' URI to resolve, or null if target has no authority + * @param targetPath path of the 'unix:' URI to resolve + */ + UdsNameResolver(String authority, String targetPath, Args args) { + // UDS is inherently local. According to https://github.com/grpc/grpc/blob/master/doc/naming.md, + // this is expressed in the target URI either by using a blank authority, like "unix:///sock", + // or by omitting authority completely, e.g. "unix:/sock". + // TODO(jdcormie): Allow the explicit authority string "localhost"? + checkArgument(isNullOrEmpty(authority), "authority not supported: %s", authority); this.authority = targetPath; } @@ -57,8 +68,8 @@ private void resolve() { ResolutionResult.Builder resolutionResultBuilder = ResolutionResult.newBuilder(); List servers = new ArrayList<>(1); servers.add(new EquivalentAddressGroup(new DomainSocketAddress(authority))); - resolutionResultBuilder.setAddresses(Collections.unmodifiableList(servers)); - listener.onResult(resolutionResultBuilder.build()); + resolutionResultBuilder.setAddressesOrError(StatusOr.fromValue(servers)); + listener.onResult2(resolutionResultBuilder.build()); } @Override diff --git a/netty/src/main/java/io/grpc/netty/UdsNameResolverProvider.java b/netty/src/main/java/io/grpc/netty/UdsNameResolverProvider.java index 9f594193b4c..baf18e3d7de 100644 --- a/netty/src/main/java/io/grpc/netty/UdsNameResolverProvider.java +++ b/netty/src/main/java/io/grpc/netty/UdsNameResolverProvider.java @@ -20,6 +20,7 @@ import io.grpc.Internal; import io.grpc.NameResolver; import io.grpc.NameResolverProvider; +import io.grpc.Uri; import io.netty.channel.unix.DomainSocketAddress; import java.net.SocketAddress; import java.net.URI; @@ -31,10 +32,22 @@ public final class UdsNameResolverProvider extends NameResolverProvider { private static final String SCHEME = "unix"; + @Override + public NameResolver newNameResolver(Uri targetUri, NameResolver.Args args) { + if (SCHEME.equals(targetUri.getScheme())) { + return new UdsNameResolver(targetUri.getAuthority(), targetUri.getPath(), args); + } else { + return null; + } + } + @Override public UdsNameResolver newNameResolver(URI targetUri, NameResolver.Args args) { if (SCHEME.equals(targetUri.getScheme())) { - return new UdsNameResolver(targetUri.getAuthority(), getTargetPathFromUri(targetUri)); + // TODO(jdcormie): java.net.URI has a bug where getAuthority() returns null for both the + // undefined and zero-length authority. Doesn't matter for now because UdsNameResolver doesn't + // distinguish these cases. + return new UdsNameResolver(targetUri.getAuthority(), getTargetPathFromUri(targetUri), args); } else { return null; } @@ -44,6 +57,10 @@ static String getTargetPathFromUri(URI targetUri) { Preconditions.checkArgument(SCHEME.equals(targetUri.getScheme()), "scheme must be " + SCHEME); String targetPath = targetUri.getPath(); if (targetPath == null) { + // TODO(jdcormie): This incorrectly includes '?' and any characters that follow. In the + // hierarchical case ('unix:///path'), java.net.URI parses these into a query component that's + // distinct from the path. But in the present "opaque" case ('unix:/path'), what may look like + // a query is considered part of the SSP. targetPath = Preconditions.checkNotNull(targetUri.getSchemeSpecificPart(), "targetPath"); } return targetPath; diff --git a/netty/src/main/java/io/grpc/netty/Utils.java b/netty/src/main/java/io/grpc/netty/Utils.java index 96f19aab5e3..386df20ba0b 100644 --- a/netty/src/main/java/io/grpc/netty/Utils.java +++ b/netty/src/main/java/io/grpc/netty/Utils.java @@ -23,9 +23,11 @@ import static io.netty.channel.ChannelOption.SO_LINGER; import static io.netty.channel.ChannelOption.SO_TIMEOUT; import static io.netty.util.CharsetUtil.UTF_8; +import static java.nio.charset.StandardCharsets.US_ASCII; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.InternalChannelz; import io.grpc.InternalMetadata; import io.grpc.Metadata; @@ -67,7 +69,6 @@ import java.util.concurrent.TimeUnit; import java.util.logging.Level; import java.util.logging.Logger; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; import javax.net.ssl.SSLException; @@ -91,7 +92,9 @@ class Utils { = new DefaultEventLoopGroupResource(1, "grpc-nio-boss-ELG", EventLoopGroupType.NIO); public static final Resource NIO_WORKER_EVENT_LOOP_GROUP = new DefaultEventLoopGroupResource(0, "grpc-nio-worker-ELG", EventLoopGroupType.NIO); - + private static final int HEADER_ENTRY_OVERHEAD = 32; + private static final byte[] binaryHeaderSuffixBytes = + Metadata.BINARY_HEADER_SUFFIX.getBytes(US_ASCII); public static final Resource DEFAULT_BOSS_EVENT_LOOP_GROUP; public static final Resource DEFAULT_WORKER_EVENT_LOOP_GROUP; @@ -119,10 +122,10 @@ private static final class ByteBufAllocatorPreferHeapHolder { EPOLL_DOMAIN_CLIENT_CHANNEL_TYPE = epollDomainSocketChannelType(); DEFAULT_SERVER_CHANNEL_FACTORY = new ReflectiveChannelFactory<>(epollServerChannelType()); EPOLL_EVENT_LOOP_GROUP_CONSTRUCTOR = epollEventLoopGroupConstructor(); - DEFAULT_BOSS_EVENT_LOOP_GROUP - = new DefaultEventLoopGroupResource(1, "grpc-default-boss-ELG", EventLoopGroupType.EPOLL); - DEFAULT_WORKER_EVENT_LOOP_GROUP - = new DefaultEventLoopGroupResource(0,"grpc-default-worker-ELG", EventLoopGroupType.EPOLL); + DEFAULT_BOSS_EVENT_LOOP_GROUP = new DefaultEventLoopGroupResource( + 1, "grpc-default-boss-ELG", EventLoopGroupType.EPOLL); + DEFAULT_WORKER_EVENT_LOOP_GROUP = new DefaultEventLoopGroupResource( + 0, "grpc-default-worker-ELG", EventLoopGroupType.EPOLL); } else { logger.log(Level.FINE, "Epoll is not available, using Nio.", getEpollUnavailabilityCause()); DEFAULT_SERVER_CHANNEL_FACTORY = nioServerChannelFactory(); @@ -195,6 +198,61 @@ public static Metadata convertHeaders(Http2Headers http2Headers) { return InternalMetadata.newMetadata(convertHeadersToArray(http2Headers)); } + public static int getH2HeadersSize(Http2Headers http2Headers) { + if (http2Headers instanceof GrpcHttp2InboundHeaders) { + GrpcHttp2InboundHeaders h = (GrpcHttp2InboundHeaders) http2Headers; + int size = 0; + for (int i = 0; i < h.numHeaders(); i++) { + size += h.namesAndValues()[2 * i].length; + size += + maybeAddBinaryHeaderOverhead(h.namesAndValues()[2 * i], h.namesAndValues()[2 * i + 1]); + size += HEADER_ENTRY_OVERHEAD; + } + return size; + } + + // the binary header is not decoded yet, no need to add overhead. + int size = 0; + for (Map.Entry entry : http2Headers) { + size += entry.getKey().length(); + size += entry.getValue().length(); + size += HEADER_ENTRY_OVERHEAD; + } + return size; + } + + private static int maybeAddBinaryHeaderOverhead(byte[] name, byte[] value) { + if (endsWith(name, binaryHeaderSuffixBytes)) { + return value.length * 4 / 3; + } + return value.length; + } + + private static boolean endsWith(byte[] bytes, byte[] suffix) { + if (bytes == null || suffix == null || bytes.length < suffix.length) { + return false; + } + + for (int i = 0; i < suffix.length; i++) { + if (bytes[bytes.length - suffix.length + i] != suffix[i]) { + return false; + } + } + + return true; + } + + public static boolean shouldRejectOnMetadataSizeSoftLimitExceeded( + int h2HeadersSize, int softLimitHeaderListSize, int maxHeaderListSize) { + if (h2HeadersSize < softLimitHeaderListSize) { + return false; + } + double failProbability = + (double) (h2HeadersSize - softLimitHeaderListSize) / (double) (maxHeaderListSize + - softLimitHeaderListSize); + return Math.random() < failProbability; + } + @CheckReturnValue private static byte[][] convertHeadersToArray(Http2Headers http2Headers) { // The Netty AsciiString class is really just a wrapper around a byte[] and supports diff --git a/netty/src/main/java/io/grpc/netty/X509AuthorityVerifier.java b/netty/src/main/java/io/grpc/netty/X509AuthorityVerifier.java new file mode 100644 index 00000000000..a2df3dbc431 --- /dev/null +++ b/netty/src/main/java/io/grpc/netty/X509AuthorityVerifier.java @@ -0,0 +1,108 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.netty; + +import static com.google.common.base.Preconditions.checkNotNull; + +import io.grpc.Status; +import io.grpc.internal.AuthorityVerifier; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.security.cert.Certificate; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import javax.annotation.Nonnull; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.X509TrustManager; + +final class X509AuthorityVerifier implements AuthorityVerifier { + private final SSLEngine sslEngine; + private final X509TrustManager x509ExtendedTrustManager; + + private static final Method checkServerTrustedMethod; + + static { + Method method = null; + try { + Class x509ExtendedTrustManagerClass = + Class.forName("javax.net.ssl.X509ExtendedTrustManager"); + method = x509ExtendedTrustManagerClass.getMethod("checkServerTrusted", + X509Certificate[].class, String.class, SSLEngine.class); + } catch (ClassNotFoundException e) { + // Per-rpc authority overriding via call options will be disallowed. + } catch (NoSuchMethodException e) { + // Should never happen since X509ExtendedTrustManager was introduced in Android API level 24 + // along with checkServerTrusted. + } + checkServerTrustedMethod = method; + } + + public X509AuthorityVerifier(SSLEngine sslEngine, X509TrustManager x509ExtendedTrustManager) { + this.sslEngine = checkNotNull(sslEngine, "sslEngine"); + this.x509ExtendedTrustManager = x509ExtendedTrustManager; + } + + @Override + public Status verifyAuthority(@Nonnull String authority) { + if (x509ExtendedTrustManager == null) { + return Status.UNAVAILABLE.withDescription( + "Can't allow authority override in rpc when X509ExtendedTrustManager" + + " is not available"); + } + Status peerVerificationStatus; + try { + // Because the authority pseudo-header can contain a port number: + // https://www.rfc-editor.org/rfc/rfc7540#section-8.1.2.3 + verifyAuthorityAllowedForPeerCert(removeAnyPortNumber(authority)); + peerVerificationStatus = Status.OK; + } catch (SSLPeerUnverifiedException | CertificateException | InvocationTargetException + | IllegalAccessException | IllegalStateException e) { + peerVerificationStatus = Status.UNAVAILABLE.withDescription( + String.format("Peer hostname verification during rpc failed for authority '%s'", + authority)).withCause(e); + } + return peerVerificationStatus; + } + + private String removeAnyPortNumber(String authority) { + int closingSquareBracketIndex = authority.lastIndexOf(']'); + int portNumberSeperatorColonIndex = authority.lastIndexOf(':'); + if (portNumberSeperatorColonIndex > closingSquareBracketIndex) { + return authority.substring(0, portNumberSeperatorColonIndex); + } + return authority; + } + + private void verifyAuthorityAllowedForPeerCert(String authority) + throws SSLPeerUnverifiedException, CertificateException, InvocationTargetException, + IllegalAccessException { + SSLEngine sslEngineWrapper = new ProtocolNegotiators.SslEngineWrapper(sslEngine, authority); + // The typecasting of Certificate to X509Certificate should work because this method will only + // be called when using TLS and thus X509. + Certificate[] peerCertificates = sslEngine.getSession().getPeerCertificates(); + X509Certificate[] x509PeerCertificates = new X509Certificate[peerCertificates.length]; + for (int i = 0; i < peerCertificates.length; i++) { + x509PeerCertificates[i] = (X509Certificate) peerCertificates[i]; + } + if (checkServerTrustedMethod == null) { + throw new IllegalStateException("checkServerTrustedMethod not found"); + } + checkServerTrustedMethod.invoke( + x509ExtendedTrustManager, x509PeerCertificates, "UNKNOWN", sslEngineWrapper); + } +} diff --git a/netty/src/main/java/io/grpc/netty/package-info.java b/netty/src/main/java/io/grpc/netty/package-info.java index 54595b38573..d1d7b87cf51 100644 --- a/netty/src/main/java/io/grpc/netty/package-info.java +++ b/netty/src/main/java/io/grpc/netty/package-info.java @@ -18,5 +18,5 @@ * The main transport implementation based on Netty, * for both the client and the server. */ -@javax.annotation.CheckReturnValue +@com.google.errorprone.annotations.CheckReturnValue package io.grpc.netty; diff --git a/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java b/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java index c60cb4824dd..66591cda153 100644 --- a/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java +++ b/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java @@ -45,14 +45,11 @@ import io.grpc.util.CertificateUtils; import java.io.Closeable; import java.io.File; -import java.io.IOException; import java.net.Socket; import java.security.GeneralSecurityException; -import java.security.NoSuchAlgorithmException; import java.security.PrivateKey; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; -import java.security.spec.InvalidKeySpecException; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; @@ -65,13 +62,13 @@ @RunWith(JUnit4.class) public class AdvancedTlsTest { - public static final String SERVER_0_KEY_FILE = "server0.key"; - public static final String SERVER_0_PEM_FILE = "server0.pem"; - public static final String CLIENT_0_KEY_FILE = "client.key"; - public static final String CLIENT_0_PEM_FILE = "client.pem"; - public static final String CA_PEM_FILE = "ca.pem"; - public static final String SERVER_BAD_KEY_FILE = "badserver.key"; - public static final String SERVER_BAD_PEM_FILE = "badserver.pem"; + private static final String SERVER_0_KEY_FILE = "server0.key"; + private static final String SERVER_0_PEM_FILE = "server0.pem"; + private static final String CLIENT_0_KEY_FILE = "client.key"; + private static final String CLIENT_0_PEM_FILE = "client.pem"; + private static final String CA_PEM_FILE = "ca.pem"; + private static final String SERVER_BAD_KEY_FILE = "badserver.key"; + private static final String SERVER_BAD_PEM_FILE = "badserver.pem"; private ScheduledExecutorService executor; private Server server; @@ -92,7 +89,7 @@ public class AdvancedTlsTest { @Before public void setUp() - throws NoSuchAlgorithmException, IOException, CertificateException, InvalidKeySpecException { + throws Exception { executor = Executors.newSingleThreadScheduledExecutor(); caCertFile = TestUtils.loadCert(CA_PEM_FILE); serverKey0File = TestUtils.loadCert(SERVER_0_KEY_FILE); @@ -150,7 +147,7 @@ public void basicMutualTlsTest() throws Exception { public void advancedTlsKeyManagerTrustManagerMutualTlsTest() throws Exception { // Create a server with the key manager and trust manager. AdvancedTlsX509KeyManager serverKeyManager = new AdvancedTlsX509KeyManager(); - serverKeyManager.updateIdentityCredentials(serverKey0, serverCert0); + serverKeyManager.updateIdentityCredentials(serverCert0, serverKey0); AdvancedTlsX509TrustManager serverTrustManager = AdvancedTlsX509TrustManager.newBuilder() .setVerification(Verification.CERTIFICATE_ONLY_VERIFICATION) .build(); @@ -162,7 +159,7 @@ public void advancedTlsKeyManagerTrustManagerMutualTlsTest() throws Exception { new SimpleServiceImpl()).build().start(); // Create a client with the key manager and trust manager. AdvancedTlsX509KeyManager clientKeyManager = new AdvancedTlsX509KeyManager(); - clientKeyManager.updateIdentityCredentials(clientKey0, clientCert0); + clientKeyManager.updateIdentityCredentials(clientCert0, clientKey0); AdvancedTlsX509TrustManager clientTrustManager = AdvancedTlsX509TrustManager.newBuilder() .setVerification(Verification.CERTIFICATE_AND_HOST_NAME_VERIFICATION) .build(); @@ -185,7 +182,7 @@ public void advancedTlsKeyManagerTrustManagerMutualTlsTest() throws Exception { @Test public void trustManagerCustomVerifierMutualTlsTest() throws Exception { AdvancedTlsX509KeyManager serverKeyManager = new AdvancedTlsX509KeyManager(); - serverKeyManager.updateIdentityCredentials(serverKey0, serverCert0); + serverKeyManager.updateIdentityCredentials(serverCert0, serverKey0); // Set server's custom verification based on the information of clientCert0. AdvancedTlsX509TrustManager serverTrustManager = AdvancedTlsX509TrustManager.newBuilder() .setVerification(Verification.CERTIFICATE_ONLY_VERIFICATION) @@ -224,7 +221,7 @@ public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authTy new SimpleServiceImpl()).build().start(); AdvancedTlsX509KeyManager clientKeyManager = new AdvancedTlsX509KeyManager(); - clientKeyManager.updateIdentityCredentials(clientKey0, clientCert0); + clientKeyManager.updateIdentityCredentials(clientCert0, clientKey0); // Set client's custom verification based on the information of serverCert0. AdvancedTlsX509TrustManager clientTrustManager = AdvancedTlsX509TrustManager.newBuilder() .setVerification(Verification.CERTIFICATE_ONLY_VERIFICATION) @@ -278,18 +275,18 @@ public void trustManagerInsecurelySkipAllTest() throws Exception { AdvancedTlsX509KeyManager serverKeyManager = new AdvancedTlsX509KeyManager(); // Even if we provide bad credentials for the server, the test should still pass, because we // will configure the client to skip all checks later. - serverKeyManager.updateIdentityCredentials(serverKeyBad, serverCertBad); + serverKeyManager.updateIdentityCredentials(serverCertBad, serverKeyBad); AdvancedTlsX509TrustManager serverTrustManager = AdvancedTlsX509TrustManager.newBuilder() .setVerification(Verification.CERTIFICATE_ONLY_VERIFICATION) .setSslSocketAndEnginePeerVerifier( new SslSocketAndEnginePeerVerifier() { @Override public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, - Socket socket) throws CertificateException { } + Socket socket) { } @Override public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, - SSLEngine engine) throws CertificateException { } + SSLEngine engine) { } }) .build(); serverTrustManager.updateTrustCredentials(caCert); @@ -300,7 +297,7 @@ public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authTy new SimpleServiceImpl()).build().start(); AdvancedTlsX509KeyManager clientKeyManager = new AdvancedTlsX509KeyManager(); - clientKeyManager.updateIdentityCredentials(clientKey0, clientCert0); + clientKeyManager.updateIdentityCredentials(clientCert0, clientKey0); // Set the client to skip all checks, including traditional certificate verification. // Note this is very dangerous in production environment - only do so if you are confident on // what you are doing! @@ -310,11 +307,11 @@ public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authTy new SslSocketAndEnginePeerVerifier() { @Override public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, - Socket socket) throws CertificateException { } + Socket socket) { } @Override public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, - SSLEngine engine) throws CertificateException { } + SSLEngine engine) { } }) .build(); clientTrustManager.updateTrustCredentials(caCert); @@ -337,12 +334,12 @@ public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authTy public void onFileReloadingKeyManagerTrustManagerTest() throws Exception { // Create & start a server. AdvancedTlsX509KeyManager serverKeyManager = new AdvancedTlsX509KeyManager(); - Closeable serverKeyShutdown = serverKeyManager.updateIdentityCredentialsFromFile(serverKey0File, - serverCert0File, 100, TimeUnit.MILLISECONDS, executor); + Closeable serverKeyShutdown = serverKeyManager.updateIdentityCredentials(serverCert0File, + serverKey0File, 100, TimeUnit.MILLISECONDS, executor); AdvancedTlsX509TrustManager serverTrustManager = AdvancedTlsX509TrustManager.newBuilder() .setVerification(Verification.CERTIFICATE_ONLY_VERIFICATION) .build(); - Closeable serverTrustShutdown = serverTrustManager.updateTrustCredentialsFromFile(caCertFile, + Closeable serverTrustShutdown = serverTrustManager.updateTrustCredentials(caCertFile, 100, TimeUnit.MILLISECONDS, executor); ServerCredentials serverCredentials = TlsServerCredentials.newBuilder() .keyManager(serverKeyManager).trustManager(serverTrustManager) @@ -351,12 +348,12 @@ public void onFileReloadingKeyManagerTrustManagerTest() throws Exception { new SimpleServiceImpl()).build().start(); // Create a client to connect. AdvancedTlsX509KeyManager clientKeyManager = new AdvancedTlsX509KeyManager(); - Closeable clientKeyShutdown = clientKeyManager.updateIdentityCredentialsFromFile(clientKey0File, - clientCert0File,100, TimeUnit.MILLISECONDS, executor); + Closeable clientKeyShutdown = clientKeyManager.updateIdentityCredentials(clientCert0File, + clientKey0File, 100, TimeUnit.MILLISECONDS, executor); AdvancedTlsX509TrustManager clientTrustManager = AdvancedTlsX509TrustManager.newBuilder() .setVerification(Verification.CERTIFICATE_AND_HOST_NAME_VERIFICATION) .build(); - Closeable clientTrustShutdown = clientTrustManager.updateTrustCredentialsFromFile(caCertFile, + Closeable clientTrustShutdown = clientTrustManager.updateTrustCredentials(caCertFile, 100, TimeUnit.MILLISECONDS, executor); ChannelCredentials channelCredentials = TlsChannelCredentials.newBuilder() .keyManager(clientKeyManager).trustManager(clientTrustManager).build(); @@ -384,11 +381,11 @@ public void onFileReloadingKeyManagerTrustManagerTest() throws Exception { public void onFileLoadingKeyManagerTrustManagerTest() throws Exception { // Create & start a server. AdvancedTlsX509KeyManager serverKeyManager = new AdvancedTlsX509KeyManager(); - serverKeyManager.updateIdentityCredentialsFromFile(serverKey0File, serverCert0File); + serverKeyManager.updateIdentityCredentials(serverCert0File, serverKey0File); AdvancedTlsX509TrustManager serverTrustManager = AdvancedTlsX509TrustManager.newBuilder() .setVerification(Verification.CERTIFICATE_ONLY_VERIFICATION) .build(); - serverTrustManager.updateTrustCredentialsFromFile(caCertFile); + serverTrustManager.updateTrustCredentials(caCertFile); ServerCredentials serverCredentials = TlsServerCredentials.newBuilder() .keyManager(serverKeyManager).trustManager(serverTrustManager) .clientAuth(ClientAuth.REQUIRE).build(); @@ -396,11 +393,11 @@ public void onFileLoadingKeyManagerTrustManagerTest() throws Exception { new SimpleServiceImpl()).build().start(); // Create a client to connect. AdvancedTlsX509KeyManager clientKeyManager = new AdvancedTlsX509KeyManager(); - clientKeyManager.updateIdentityCredentialsFromFile(clientKey0File, clientCert0File); + clientKeyManager.updateIdentityCredentials(clientCert0File, clientKey0File); AdvancedTlsX509TrustManager clientTrustManager = AdvancedTlsX509TrustManager.newBuilder() .setVerification(Verification.CERTIFICATE_AND_HOST_NAME_VERIFICATION) .build(); - clientTrustManager.updateTrustCredentialsFromFile(caCertFile); + clientTrustManager.updateTrustCredentials(caCertFile); ChannelCredentials channelCredentials = TlsChannelCredentials.newBuilder() .keyManager(clientKeyManager).trustManager(clientTrustManager).build(); channel = Grpc.newChannelBuilderForAddress("localhost", server.getPort(), channelCredentials) @@ -419,12 +416,12 @@ public void onFileLoadingKeyManagerTrustManagerTest() throws Exception { } @Test - public void onFileReloadingKeyManagerBadInitialContentTest() throws Exception { + public void onFileReloadingKeyManagerBadInitialContentTest() { AdvancedTlsX509KeyManager keyManager = new AdvancedTlsX509KeyManager(); // We swap the order of key and certificates to intentionally create an exception. assertThrows(GeneralSecurityException.class, - () -> keyManager.updateIdentityCredentialsFromFile(serverCert0File, - serverKey0File, 100, TimeUnit.MILLISECONDS, executor)); + () -> keyManager.updateIdentityCredentials(serverKey0File, serverCert0File, + 100, TimeUnit.MILLISECONDS, executor)); } @Test @@ -434,21 +431,18 @@ public void onFileReloadingTrustManagerBadInitialContentTest() throws Exception .build(); // We pass in a key as the trust certificates to intentionally create an exception. assertThrows(GeneralSecurityException.class, - () -> trustManager.updateTrustCredentialsFromFile(serverKey0File, - 100, TimeUnit.MILLISECONDS, executor)); + () -> trustManager.updateTrustCredentials(serverKey0File, 100, TimeUnit.MILLISECONDS, + executor)); } @Test public void keyManagerAliasesTest() throws Exception { AdvancedTlsX509KeyManager km = new AdvancedTlsX509KeyManager(); - assertArrayEquals( - new String[] {"default"}, km.getClientAliases("", null)); - assertEquals( - "default", km.chooseClientAlias(new String[] {"default"}, null, null)); - assertArrayEquals( - new String[] {"default"}, km.getServerAliases("", null)); - assertEquals( - "default", km.chooseServerAlias("default", null, null)); + km.updateIdentityCredentials(serverCert0, serverKey0); + assertArrayEquals(new String[] {"key-1"}, km.getClientAliases("", null)); + assertEquals("key-1", km.chooseClientAlias(new String[] {"key-1"}, null, null)); + assertArrayEquals(new String[] {"key-1"}, km.getServerAliases("", null)); + assertEquals("key-1", km.chooseServerAlias("key-1", null, null)); } @Test diff --git a/netty/src/test/java/io/grpc/netty/NettyAdaptiveCumulatorTest.java b/netty/src/test/java/io/grpc/netty/NettyAdaptiveCumulatorTest.java index c4c2f95a2a9..b19f247b5cf 100644 --- a/netty/src/test/java/io/grpc/netty/NettyAdaptiveCumulatorTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyAdaptiveCumulatorTest.java @@ -40,7 +40,6 @@ import io.netty.buffer.UnpooledByteBufAllocator; import java.util.Collection; import java.util.List; -import java.util.stream.Collectors; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -53,9 +52,12 @@ @RunWith(Enclosed.class) public class NettyAdaptiveCumulatorTest { + private static boolean usingPre4_1_111_Netty() { + return false; // Disabled detection because it was unreliable + } private static Collection cartesianProductParams(List... lists) { - return Lists.cartesianProduct(lists).stream().map(List::toArray).collect(Collectors.toList()); + return Lists.transform(Lists.cartesianProduct(lists), List::toArray); } @RunWith(JUnit4.class) @@ -122,7 +124,7 @@ public void cumulate_contiguousCumulation_newCompositeFromContiguousAndInput() { @Test public void cumulate_compositeCumulation_inputAppendedAsANewComponent() { - CompositeByteBuf composite = alloc.compositeBuffer().addComponent(true, contiguous); + CompositeByteBuf composite = alloc.compositeBuffer().addFlattenedComponents(true, contiguous); assertSame(composite, cumulator.cumulate(alloc, composite, in)); assertEquals(DATA_INITIAL, composite.component(0).toString(US_ASCII)); assertEquals(DATA_INCOMING, composite.component(1).toString(US_ASCII)); @@ -136,7 +138,7 @@ public void cumulate_compositeCumulation_inputAppendedAsANewComponent() { @Test public void cumulate_compositeCumulation_inputReleasedOnError() { - CompositeByteBuf composite = alloc.compositeBuffer().addComponent(true, contiguous); + CompositeByteBuf composite = alloc.compositeBuffer().addFlattenedComponents(true, contiguous); try { throwingCumulator.cumulate(alloc, composite, in); fail("Cumulator didn't throw"); @@ -386,6 +388,8 @@ public void mergeWithCompositeTail_tailExpandable_reallocateInMemory() { } private void assertTailExpanded(String expectedTailReadableData, int expectedNewTailCapacity) { + assume().withMessage("Netty 4.1.111 doesn't work with NettyAdaptiveCumulator") + .that(usingPre4_1_111_Netty()).isTrue(); int originalNumComponents = composite.numComponents(); // Handle the case when reader index is beyond all readable bytes of the cumulation. @@ -528,7 +532,7 @@ public void mergeWithCompositeTail_tailExpandable_mergedReleaseOnThrow() { tail) { @Override public CompositeByteBuf addFlattenedComponents(boolean increaseWriterIndex, - ByteBuf buffer) { + ByteBuf buffer) { throw expectedError; } }; @@ -562,7 +566,7 @@ public void mergeWithCompositeTail_tailNotExpandable_mergedReleaseOnThrow() { tail.asReadOnly()) { @Override public CompositeByteBuf addFlattenedComponents(boolean increaseWriterIndex, - ByteBuf buffer) { + ByteBuf buffer) { throw expectedError; } }; @@ -626,6 +630,9 @@ public void mergeWithCompositeTail_outOfSyncComposite() { alloc.compositeBuffer(8).addFlattenedComponents(true, composite1); assertThat(composite2.toString(US_ASCII)).isEqualTo("01234"); + assume().withMessage("Netty 4.1.111 doesn't work with NettyAdaptiveCumulator") + .that(usingPre4_1_111_Netty()).isTrue(); + // The previous operation does not adjust the read indexes of the underlying buffers, // only the internal Component offsets. When the cumulator attempts to append the input to // the tail buffer, it extracts it from the cumulation, writes to it, and then adds it back. diff --git a/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java b/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java index 5789d275c07..95d54d13b82 100644 --- a/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java @@ -19,6 +19,7 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.mock; @@ -39,17 +40,13 @@ import java.net.SocketAddress; import java.util.concurrent.TimeUnit; import javax.net.ssl.SSLException; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @RunWith(JUnit4.class) public class NettyChannelBuilderTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); private final SslContext noSslContext = null; private void shutdown(ManagedChannel mc) throws Exception { @@ -107,10 +104,9 @@ private void overrideAuthorityIsReadableHelper(NettyChannelBuilder builder, public void failOverrideInvalidAuthority() { NettyChannelBuilder builder = new NettyChannelBuilder(getTestSocketAddress()); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Invalid authority:"); - - builder.overrideAuthority("[invalidauthority"); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.overrideAuthority("[invalidauthority")); + assertThat(e).hasMessageThat().isEqualTo("Invalid authority: [invalidauthority"); } @Test @@ -128,20 +124,18 @@ public void enableCheckAuthorityFailOverrideInvalidAuthority() { NettyChannelBuilder builder = new NettyChannelBuilder(getTestSocketAddress()) .disableCheckAuthority() .enableCheckAuthority(); - - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Invalid authority:"); - builder.overrideAuthority("[invalidauthority"); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.overrideAuthority("[invalidauthority")); + assertThat(e).hasMessageThat().isEqualTo("Invalid authority: [invalidauthority"); } @Test public void failInvalidAuthority() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Invalid host or port"); - @SuppressWarnings("AddressSelection") // We actually expect zero addresses! - Object unused = - NettyChannelBuilder.forAddress(new InetSocketAddress("invalid_authority", 1234)); + InetSocketAddress address = new InetSocketAddress("invalid_authority", 1234); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> NettyChannelBuilder.forAddress(address)); + assertThat(e).hasMessageThat().isEqualTo("Invalid host or port: invalid_authority 1234"); } @Test @@ -155,10 +149,10 @@ public void failIfSslContextIsNotClient() { SslContext sslContext = mock(SslContext.class); NettyChannelBuilder builder = new NettyChannelBuilder(getTestSocketAddress()); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Server SSL context can not be used for client channel"); - - builder.sslContext(sslContext); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.sslContext(sslContext)); + assertThat(e).hasMessageThat() + .isEqualTo("Server SSL context can not be used for client channel"); } @Test @@ -166,10 +160,10 @@ public void failNegotiationTypeWithChannelCredentials_target() { NettyChannelBuilder builder = NettyChannelBuilder.forTarget( "fakeTarget", InsecureChannelCredentials.create()); - thrown.expect(IllegalStateException.class); - thrown.expectMessage("Cannot change security when using ChannelCredentials"); - - builder.negotiationType(NegotiationType.TLS); + IllegalStateException e = assertThrows(IllegalStateException.class, + () -> builder.negotiationType(NegotiationType.TLS)); + assertThat(e).hasMessageThat() + .isEqualTo("Cannot change security when using ChannelCredentials"); } @Test @@ -177,10 +171,10 @@ public void failNegotiationTypeWithChannelCredentials_socketAddress() { NettyChannelBuilder builder = NettyChannelBuilder.forAddress( getTestSocketAddress(), InsecureChannelCredentials.create()); - thrown.expect(IllegalStateException.class); - thrown.expectMessage("Cannot change security when using ChannelCredentials"); - - builder.negotiationType(NegotiationType.TLS); + IllegalStateException e = assertThrows(IllegalStateException.class, + () -> builder.negotiationType(NegotiationType.TLS)); + assertThat(e).hasMessageThat() + .isEqualTo("Cannot change security when using ChannelCredentials"); } @Test @@ -205,10 +199,9 @@ public void createProtocolNegotiatorByType_plaintextUpgrade() { @Test public void createProtocolNegotiatorByType_tlsWithNoContext() { - thrown.expect(NullPointerException.class); - NettyChannelBuilder.createProtocolNegotiatorByType( - NegotiationType.TLS, - noSslContext, null); + assertThrows(NullPointerException.class, + () -> NettyChannelBuilder.createProtocolNegotiatorByType( + NegotiationType.TLS, noSslContext, null)); } @Test @@ -245,38 +238,40 @@ public void createProtocolNegotiatorByType_tlsWithAuthorityFallback() throws SSL public void negativeKeepAliveTime() { NettyChannelBuilder builder = NettyChannelBuilder.forTarget("fakeTarget"); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("keepalive time must be positive"); - builder.keepAliveTime(-1L, TimeUnit.HOURS); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.keepAliveTime(-1L, TimeUnit.HOURS)); + assertThat(e).hasMessageThat().isEqualTo("keepalive time must be positive"); } @Test public void negativeKeepAliveTimeout() { NettyChannelBuilder builder = NettyChannelBuilder.forTarget("fakeTarget"); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("keepalive timeout must be positive"); - builder.keepAliveTimeout(-1L, TimeUnit.HOURS); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.keepAliveTimeout(-1L, TimeUnit.HOURS)); + assertThat(e).hasMessageThat().isEqualTo("keepalive timeout must be positive"); } @Test public void assertEventLoopAndChannelType_onlyGroupProvided() { NettyChannelBuilder builder = NettyChannelBuilder.forTarget("fakeTarget"); builder.eventLoopGroup(mock(EventLoopGroup.class)); - thrown.expect(IllegalStateException.class); - thrown.expectMessage("Both EventLoopGroup and ChannelType should be provided"); - builder.assertEventLoopAndChannelType(); + IllegalStateException e = assertThrows(IllegalStateException.class, + builder::assertEventLoopAndChannelType); + assertThat(e).hasMessageThat() + .isEqualTo("Both EventLoopGroup and ChannelType should be provided or neither should be"); } @Test public void assertEventLoopAndChannelType_onlyTypeProvided() { NettyChannelBuilder builder = NettyChannelBuilder.forTarget("fakeTarget"); builder.channelType(LocalChannel.class, LocalAddress.class); - thrown.expect(IllegalStateException.class); - thrown.expectMessage("Both EventLoopGroup and ChannelType should be provided"); - builder.assertEventLoopAndChannelType(); + IllegalStateException e = assertThrows(IllegalStateException.class, + builder::assertEventLoopAndChannelType); + assertThat(e).hasMessageThat() + .isEqualTo("Both EventLoopGroup and ChannelType should be provided or neither should be"); } @Test @@ -288,10 +283,11 @@ public Channel newChannel() { return null; } }); - thrown.expect(IllegalStateException.class); - thrown.expectMessage("Both EventLoopGroup and ChannelType should be provided"); - builder.assertEventLoopAndChannelType(); + IllegalStateException e = assertThrows(IllegalStateException.class, + builder::assertEventLoopAndChannelType); + assertThat(e).hasMessageThat() + .isEqualTo("Both EventLoopGroup and ChannelType should be provided or neither should be"); } @Test diff --git a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java index 9cb2c043e54..9f6be9a2f3e 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java @@ -16,7 +16,6 @@ package io.grpc.netty; -import static com.google.common.base.Charsets.UTF_8; import static com.google.common.truth.Truth.assertThat; import static io.grpc.internal.ClientStreamListener.RpcProgress.MISCARRIED; import static io.grpc.internal.ClientStreamListener.RpcProgress.PROCESSED; @@ -29,13 +28,14 @@ import static io.grpc.netty.Utils.STATUS_OK; import static io.grpc.netty.Utils.TE_HEADER; import static io.grpc.netty.Utils.TE_TRAILERS; -import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.same; @@ -47,6 +47,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import com.google.common.base.Stopwatch; +import com.google.common.base.Strings; import com.google.common.base.Supplier; import com.google.common.base.Ticker; import com.google.common.collect.ImmutableList; @@ -56,16 +57,18 @@ import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.Metadata; +import io.grpc.MetricRecorder; import io.grpc.Status; -import io.grpc.StatusException; import io.grpc.internal.AbstractStream; import io.grpc.internal.ClientStreamListener; import io.grpc.internal.ClientStreamListener.RpcProgress; import io.grpc.internal.ClientTransport; import io.grpc.internal.ClientTransport.PingCallback; +import io.grpc.internal.GrpcAttributes; import io.grpc.internal.GrpcUtil; import io.grpc.internal.KeepAliveManager; import io.grpc.internal.ManagedClientTransport; +import io.grpc.internal.SimpleDisconnectError; import io.grpc.internal.StatsTraceContext; import io.grpc.internal.StreamListener; import io.grpc.internal.TransportTracer; @@ -89,10 +92,12 @@ import io.netty.handler.codec.http2.Http2Stream; import io.netty.util.AsciiString; import java.io.InputStream; +import java.security.cert.CertificateException; import java.text.MessageFormat; import java.util.LinkedList; import java.util.List; import java.util.Queue; +import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Handler; @@ -122,6 +127,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBaseany()); - + doAnswer((attributes) -> Attributes.newBuilder().set( + GrpcAttributes.ATTR_AUTHORITY_VERIFIER, + (authority) -> Status.OK).build()) + .when(listener) + .filterTransport(ArgumentMatchers.any(Attributes.class)); lifecycleManager = new ClientTransportLifecycleManager(listener); // This mocks the keepalive manager only for there's in which we verify it. For other tests // it'll be null which will be testing if we behave correctly when it's not present. @@ -215,6 +225,37 @@ public Void answer(InvocationOnMock invocation) throws Throwable { // Simulate receipt of initial remote settings. ByteBuf serializedSettings = serializeSettings(new Http2Settings()); channelRead(serializedSettings); + channel().releaseOutbound(); + } + + @Test + @SuppressWarnings("InlineMeInliner") + public void sendLargerThanSoftLimitHeaderMayFail() throws Exception { + maxHeaderListSize = 8000; + softLimitHeaderListSize = 2000; + manualSetUp(); + + createStream(); + // total head size of 7999, soft limit = 2000 and max = 8000. + // This header has 5999/6000 chance to be rejected. + Http2Headers headers = new DefaultHttp2Headers() + .scheme(HTTPS) + .authority(as("www.fake.com")) + .path(as("/fakemethod")) + .method(HTTP_METHOD) + .add(as("auth"), as("sometoken")) + .add(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC) + .add(TE_HEADER, TE_TRAILERS) + .add("large-field", Strings.repeat("a", 7620)); // String.repeat() requires Java 11 + + ByteBuf headersFrame = headersFrame(STREAM_ID, headers); + channelRead(headersFrame); + ArgumentCaptor statusArgumentCaptor = ArgumentCaptor.forClass(Status.class); + verify(streamListener).closed(statusArgumentCaptor.capture(), eq(PROCESSED), + any(Metadata.class)); + assertThat(statusArgumentCaptor.getValue().getCode()).isEqualTo(Status.Code.RESOURCE_EXHAUSTED); + assertThat(statusArgumentCaptor.getValue().getDescription()).contains( + "exceeded Metadata size soft limit"); } @Test @@ -228,7 +269,7 @@ public void cancelBufferedStreamShouldChangeClientStreamStatus() throws Exceptio // Cancel the stream. cancelStream(Status.CANCELLED); - assertTrue(createFuture.isSuccess()); + assertFalse(createFuture.isSuccess()); verify(streamListener).closed(eq(Status.CANCELLED), same(PROCESSED), any(Metadata.class)); } @@ -236,7 +277,7 @@ public void cancelBufferedStreamShouldChangeClientStreamStatus() throws Exceptio public void createStreamShouldSucceed() throws Exception { createStream(); verifyWrite().writeHeaders(eq(ctx()), eq(STREAM_ID), eq(grpcHeaders), eq(0), - eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), any(ChannelPromise.class)); + eq(false), any(ChannelPromise.class)); } @Test @@ -271,7 +312,7 @@ public void cancelWhileBufferedShouldSucceed() throws Exception { ChannelFuture cancelFuture = cancelStream(Status.CANCELLED); assertTrue(cancelFuture.isSuccess()); assertTrue(createFuture.isDone()); - assertTrue(createFuture.isSuccess()); + assertFalse(createFuture.isSuccess()); } /** @@ -310,11 +351,12 @@ public void sendFrameShouldSucceed() throws Exception { createStream(); // Send a frame and verify that it was written. + ByteBuf content = content(); ChannelFuture future - = enqueue(new SendGrpcFrameCommand(streamTransportState, content(), true)); + = enqueue(new SendGrpcFrameCommand(streamTransportState, content, true)); assertTrue(future.isSuccess()); - verifyWrite().writeData(eq(ctx()), eq(STREAM_ID), eq(content()), eq(0), eq(true), + verifyWrite().writeData(eq(ctx()), eq(STREAM_ID), same(content), eq(0), eq(true), any(ChannelPromise.class)); verify(mockKeepAliveManager, times(1)).onTransportActive(); // onStreamActive verifyNoMoreInteractions(mockKeepAliveManager); @@ -412,6 +454,26 @@ public void receivedAbruptGoAwayShouldFailRacingQueuedStreamid() throws Exceptio assertTrue(future.isDone()); } + @Test + public void receivedAbruptGoAwayShouldFailRacingQueuedIoStreamid() throws Exception { + // Purposefully avoid flush(), since we want the write to not actually complete. + // EmbeddedChannel doesn't support flow control, so this is the next closest approximation. + ChannelFuture future = channel().write( + newCreateStreamCommand(grpcHeaders, streamTransportState)); + // Read a GOAWAY that indicates our stream can't be sent + channelRead(goAwayFrame(0, 0 /* NO_ERROR */, Unpooled.copiedBuffer("this is a test", UTF_8))); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Status.class); + verify(streamListener).closed(captor.capture(), same(REFUSED), + ArgumentMatchers.notNull()); + assertEquals(Status.UNAVAILABLE.getCode(), captor.getValue().getCode()); + assertEquals( + "Abrupt GOAWAY closed sent stream. HTTP/2 error code: NO_ERROR, " + + "debug data: this is a test", + captor.getValue().getDescription()); + assertTrue(future.isDone()); + } + @Test public void receivedGoAway_shouldFailBufferedStreamsExceedingMaxConcurrentStreams() throws Exception { @@ -704,7 +766,7 @@ public void exhaustedStreamsShouldFail() throws Exception { public void nonExistentStream() throws Exception { Status status = Status.INTERNAL.withDescription("zz"); - lifecycleManager.notifyShutdown(status); + lifecycleManager.notifyShutdown(status, SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); // Stream creation can race with the transport shutting down, with the create command already // enqueued. ChannelFuture future1 = createStream(); @@ -770,9 +832,7 @@ public void ping_failsWhenChannelCloses() throws Exception { handler().channelInactive(ctx()); // ping failed on channel going inactive assertEquals(1, callback.invocationCount); - assertTrue(callback.failureCause instanceof StatusException); - assertEquals(Status.Code.UNAVAILABLE, - ((StatusException) callback.failureCause).getStatus().getCode()); + assertEquals(Status.Code.UNAVAILABLE, callback.failureCause.getCode()); // A failed ping is still counted assertEquals(1, transportTracer.getStats().keepAlivesSent); } @@ -885,6 +945,159 @@ public void exceptionCaughtShouldCloseConnection() throws Exception { assertFalse(channel().isOpen()); } + @Test + public void missingAuthorityHeader_streamCreationShouldFail() throws Exception { + Http2Headers grpcHeadersWithoutAuthority = new DefaultHttp2Headers() + .scheme(HTTPS) + .path(as("/fakemethod")) + .method(HTTP_METHOD) + .add(as("auth"), as("sometoken")) + .add(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC) + .add(TE_HEADER, TE_TRAILERS); + ChannelFuture channelFuture = enqueue(newCreateStreamCommand( + grpcHeadersWithoutAuthority, streamTransportState)); + try { + channelFuture.get(); + fail("Expected stream creation failure"); + } catch (ExecutionException e) { + assertThat(e.getCause().getMessage()).isEqualTo("UNAVAILABLE: Missing authority header"); + } + } + + @Test + public void missingAuthorityVerifierInAttributes_streamCreationShouldFail() throws Exception { + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + StreamListener.MessageProducer producer = + (StreamListener.MessageProducer) invocation.getArguments()[0]; + InputStream message; + while ((message = producer.next()) != null) { + streamListenerMessageQueue.add(message); + } + return null; + } + }) + .when(streamListener) + .messagesAvailable(ArgumentMatchers.any()); + doAnswer((attributes) -> Attributes.EMPTY) + .when(listener) + .filterTransport(ArgumentMatchers.any(Attributes.class)); + lifecycleManager = new ClientTransportLifecycleManager(listener); + // This mocks the keepalive manager only for there's in which we verify it. For other tests + // it'll be null which will be testing if we behave correctly when it's not present. + if (setKeepaliveManagerFor.contains(testNameRule.getMethodName())) { + mockKeepAliveManager = mock(KeepAliveManager.class); + } + + initChannel(new GrpcHttp2ClientHeadersDecoder(GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE)); + streamTransportState = new TransportStateImpl( + handler(), + channel().eventLoop(), + DEFAULT_MAX_MESSAGE_SIZE, + transportTracer); + streamTransportState.setListener(streamListener); + + grpcHeaders = new DefaultHttp2Headers() + .scheme(HTTPS) + .authority(as("www.fake.com")) + .path(as("/fakemethod")) + .method(HTTP_METHOD) + .add(as("auth"), as("sometoken")) + .add(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC) + .add(TE_HEADER, TE_TRAILERS); + + // Simulate receipt of initial remote settings. + ByteBuf serializedSettings = serializeSettings(new Http2Settings()); + channelRead(serializedSettings); + channel().releaseOutbound(); + + ChannelFuture channelFuture = createStream(); + try { + channelFuture.get(); + fail("Expected stream creation failure"); + } catch (ExecutionException e) { + assertThat(e.getCause().getMessage()).isEqualTo( + "UNAVAILABLE: Authority verifier not found to verify authority"); + } + } + + @Test + public void authorityVerificationSuccess_streamCreationSucceeds() throws Exception { + NettyClientHandler.enablePerRpcAuthorityCheck = true; + try { + ChannelFuture channelFuture = createStream(); + channelFuture.get(); + } finally { + NettyClientHandler.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void authorityVerificationFailure_streamCreationFails() throws Exception { + NettyClientHandler.enablePerRpcAuthorityCheck = true; + try { + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + StreamListener.MessageProducer producer = + (StreamListener.MessageProducer) invocation.getArguments()[0]; + InputStream message; + while ((message = producer.next()) != null) { + streamListenerMessageQueue.add(message); + } + return null; + } + }) + .when(streamListener) + .messagesAvailable(ArgumentMatchers.any()); + doAnswer((attributes) -> Attributes.newBuilder().set( + GrpcAttributes.ATTR_AUTHORITY_VERIFIER, + (authority) -> Status.UNAVAILABLE.withCause( + new CertificateException("Peer verification failed"))).build()) + .when(listener) + .filterTransport(ArgumentMatchers.any(Attributes.class)); + lifecycleManager = new ClientTransportLifecycleManager(listener); + // This mocks the keepalive manager only for there's in which we verify it. For other tests + // it'll be null which will be testing if we behave correctly when it's not present. + if (setKeepaliveManagerFor.contains(testNameRule.getMethodName())) { + mockKeepAliveManager = mock(KeepAliveManager.class); + } + + initChannel(new GrpcHttp2ClientHeadersDecoder(GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE)); + streamTransportState = new TransportStateImpl( + handler(), + channel().eventLoop(), + DEFAULT_MAX_MESSAGE_SIZE, + transportTracer); + streamTransportState.setListener(streamListener); + + grpcHeaders = new DefaultHttp2Headers() + .scheme(HTTPS) + .authority(as("www.fake.com")) + .path(as("/fakemethod")) + .method(HTTP_METHOD) + .add(as("auth"), as("sometoken")) + .add(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC) + .add(TE_HEADER, TE_TRAILERS); + + // Simulate receipt of initial remote settings. + ByteBuf serializedSettings = serializeSettings(new Http2Settings()); + channelRead(serializedSettings); + channel().releaseOutbound(); + + ChannelFuture channelFuture = createStream(); + try { + channelFuture.get(); + fail("Expected stream creation failure"); + } catch (ExecutionException e) { + assertThat(e.getMessage()).isEqualTo("io.grpc.InternalStatusRuntimeException: UNAVAILABLE"); + } + } finally { + NettyClientHandler.enablePerRpcAuthorityCheck = false; + } + } + @Override protected void makeStream() throws Exception { createStream(); @@ -946,13 +1159,15 @@ public Stopwatch get() { false, flowControlWindow, maxHeaderListSize, + softLimitHeaderListSize, stopwatchSupplier, tooManyPingsRunnable, transportTracer, Attributes.EMPTY, "someauthority", null, - fakeClock().getTicker()); + fakeClock().getTicker(), + new MetricRecorder() {}); } @Override @@ -973,7 +1188,7 @@ private static CreateStreamCommand newCreateStreamCommand( private static class PingCallbackImpl implements ClientTransport.PingCallback { int invocationCount; long roundTripTime; - Throwable failureCause; + Status failureCause; @Override public void onSuccess(long roundTripTimeNanos) { @@ -982,7 +1197,7 @@ public void onSuccess(long roundTripTimeNanos) { } @Override - public void onFailure(Throwable cause) { + public void onFailure(Status cause) { invocationCount++; this.failureCause = cause; } @@ -1000,7 +1215,8 @@ public TransportStateImpl( maxMessageSize, StatsTraceContext.NOOP, transportTracer, - "methodName"); + "methodName", + CallOptions.DEFAULT); } @Override diff --git a/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java b/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java index 96551d173a4..4dd24c3fd4d 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientStreamTest.java @@ -23,6 +23,8 @@ import static io.grpc.netty.Utils.CONTENT_TYPE_GRPC; import static io.grpc.netty.Utils.CONTENT_TYPE_HEADER; import static io.grpc.netty.Utils.STATUS_OK; +import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; +import static io.netty.handler.codec.http2.Http2Exception.connectionError; import static io.netty.util.CharsetUtil.UTF_8; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -34,6 +36,7 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -43,6 +46,7 @@ import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.Iterables; import com.google.common.io.BaseEncoding; import io.grpc.CallOptions; import io.grpc.InternalStatus; @@ -62,6 +66,7 @@ import io.netty.channel.ChannelPromise; import io.netty.channel.DefaultChannelPromise; import io.netty.handler.codec.http2.DefaultHttp2Headers; +import io.netty.handler.codec.http2.Http2Exception; import io.netty.handler.codec.http2.Http2Headers; import io.netty.util.AsciiString; import java.io.BufferedInputStream; @@ -75,6 +80,7 @@ import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; import org.mockito.ArgumentMatchers; +import org.mockito.InOrder; import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.invocation.InvocationOnMock; @@ -205,6 +211,52 @@ public void writeMessageShouldSendRequestUnknownLength() throws Exception { eq(true)); } + @Test + public void writeFrameFutureFailedShouldCancelRpc() { + Http2Exception h2Error = connectionError(PROTOCOL_ERROR, "Stream does not exist %d", STREAM_ID); + // Fail all SendGrpcFrameCommands command sent to the queue. + when(writeQueue.enqueue(any(SendGrpcFrameCommand.class), anyBoolean())).thenReturn( + new DefaultChannelPromise(channel).setFailure(h2Error)); + + // Write multiple messages to ensure multiple SendGrpcFrameCommand are enqueued. We set up all + // of them to fail, which allows us to assert that only a single cancel is sent, and the stream + // isn't spammed with multiple RST_STREAM. + stream().transportState().setId(STREAM_ID); + stream.writeMessage(new ByteArrayInputStream(smallMessage())); + stream.writeMessage(new ByteArrayInputStream(largeMessage())); + stream.flush(); + + InOrder inOrder = Mockito.inOrder(writeQueue); + // Normal stream create and write frame. + inOrder.verify(writeQueue).enqueue(any(CreateStreamCommand.class), eq(false)); + inOrder.verify(writeQueue).enqueue(any(SendGrpcFrameCommand.class), eq(false)); + // Verify that failed SendGrpcFrameCommand results in immediate CancelClientStreamCommand. + inOrder.verify(writeQueue).enqueue(any(CancelClientStreamCommand.class), eq(true)); + // Verify that any other failures do not produce another CancelClientStreamCommand in the queue. + inOrder.verify(writeQueue, atLeast(0)).enqueue(any(SendGrpcFrameCommand.class), eq(false)); + inOrder.verify(writeQueue).enqueue(any(SendGrpcFrameCommand.class), eq(true)); + inOrder.verifyNoMoreInteractions(); + + // Get the CancelClientStreamCommand written to the queue. Above we verified that there is + // only one CancelClientStreamCommand enqueued, and is the third enqueued command (create, + // frame write failure, cancel). + CancelClientStreamCommand cancelCommand = Iterables.get( + Iterables.filter( + Mockito.mockingDetails(writeQueue).getInvocations(), + // Get enqueue() innovations only + invocation -> invocation.getMethod().getName().equals("enqueue")), + // Get the third invocation of enqueue() + 2) + // Get the first argument (QueuedCommand command) + .getArgument(0); + + Status cancelReason = cancelCommand.reason(); + assertThat(cancelReason.getCode()).isEqualTo(Status.INTERNAL.getCode()); + assertThat(cancelReason.getCause()).isEqualTo(h2Error); + // Verify listener closed. + verify(listener).closed(same(cancelReason), eq(PROCESSED), any(Metadata.class)); + } + @Test public void setStatusWithOkShouldCloseStream() { stream().transportState().setId(STREAM_ID); @@ -563,7 +615,8 @@ public TransportStateImpl(NettyClientHandler handler, int maxMessageSize) { maxMessageSize, StatsTraceContext.NOOP, transportTracer, - "methodName"); + "methodName", + CallOptions.DEFAULT); } @Override diff --git a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java index f94960cbab3..7023acc947c 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java @@ -16,7 +16,6 @@ package io.grpc.netty; -import static com.google.common.base.Charsets.UTF_8; import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.TruthJUnit.assume; import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; @@ -29,6 +28,7 @@ import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_IDLE_NANOS_DISABLED; import static io.grpc.netty.NettyServerBuilder.MAX_RST_COUNT_DISABLED; import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_WINDOW_SIZE; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; @@ -37,12 +37,16 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.timeout; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.common.base.Optional; import com.google.common.base.Strings; import com.google.common.base.Ticker; import com.google.common.io.ByteStreams; import com.google.common.util.concurrent.SettableFuture; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.ChannelLogger; @@ -52,13 +56,16 @@ import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor.Marshaller; +import io.grpc.MetricRecorder; import io.grpc.ServerStreamTracer; import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.StatusException; +import io.grpc.TlsChannelCredentials; import io.grpc.internal.ClientStream; import io.grpc.internal.ClientStreamListener; import io.grpc.internal.ClientTransport; +import io.grpc.internal.DisconnectError; import io.grpc.internal.FakeClock; import io.grpc.internal.FixedObjectPool; import io.grpc.internal.GrpcUtil; @@ -73,6 +80,7 @@ import io.grpc.netty.NettyChannelBuilder.LocalSocketPicker; import io.grpc.netty.NettyTestUtil.TrackingObjectPoolForTest; import io.grpc.testing.TlsTesting; +import io.grpc.util.CertificateUtils; import io.netty.buffer.ByteBuf; import io.netty.channel.Channel; import io.netty.channel.ChannelConfig; @@ -82,6 +90,7 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelOption; import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultEventLoopGroup; import io.netty.channel.EventLoopGroup; import io.netty.channel.ReflectiveChannelFactory; import io.netty.channel.local.LocalChannel; @@ -93,12 +102,18 @@ import io.netty.handler.ssl.ClientAuth; import io.netty.handler.ssl.SslContext; import io.netty.util.AsciiString; +import io.netty.util.ReferenceCountUtil; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; +import java.lang.reflect.InvocationTargetException; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.nio.charset.StandardCharsets; +import java.security.GeneralSecurityException; +import java.security.KeyStore; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -112,6 +127,12 @@ import javax.annotation.Nullable; import javax.net.ssl.SSLException; import javax.net.ssl.SSLHandshakeException; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509ExtendedTrustManager; +import javax.net.ssl.X509TrustManager; +import javax.security.auth.x500.X500Principal; +import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -126,11 +147,15 @@ * Tests for {@link NettyClientTransport}. */ @RunWith(JUnit4.class) +@IgnoreJRERequirement public class NettyClientTransportTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); private static final SslContext SSL_CONTEXT = createSslContext(); + @SuppressWarnings("InlineMeInliner") // Requires Java 11 + private static final String LONG_STRING_OF_A = Strings.repeat("a", 128); + @Mock private ManagedClientTransport.Listener clientTransportListener; @@ -185,6 +210,7 @@ public void addDefaultUserAgent() throws Exception { startServer(); NettyClientTransport transport = newTransport(newNegotiator()); callMeMaybe(transport.start(clientTransportListener)); + verify(clientTransportListener, timeout(5000)).transportReady(); // Send a single RPC and wait for the response. new Rpc(transport).halfClose().waitForResponse(); @@ -197,18 +223,37 @@ public void addDefaultUserAgent() throws Exception { } @Test - public void setSoLingerChannelOption() throws IOException { + public void setSoLingerChannelOption() throws IOException, GeneralSecurityException { startServer(); Map, Object> channelOptions = new HashMap<>(); // set SO_LINGER option int soLinger = 123; channelOptions.put(ChannelOption.SO_LINGER, soLinger); NettyClientTransport transport = new NettyClientTransport( - address, new ReflectiveChannelFactory<>(NioSocketChannel.class), channelOptions, group, - newNegotiator(), false, DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE, - GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, KEEPALIVE_TIME_NANOS_DISABLED, 1L, false, authority, - null /* user agent */, tooManyPingsRunnable, new TransportTracer(), Attributes.EMPTY, - new SocketPicker(), new FakeChannelLogger(), false, Ticker.systemTicker()); + address, + new ReflectiveChannelFactory<>(NioSocketChannel.class), + channelOptions, + group, + newNegotiator(), + false, + DEFAULT_WINDOW_SIZE, + DEFAULT_MAX_MESSAGE_SIZE, + GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, + GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, + KEEPALIVE_TIME_NANOS_DISABLED, + 1L, + false, + authority, + null /* user agent */, + tooManyPingsRunnable, + new TransportTracer(), + Attributes.EMPTY, + new SocketPicker(), + new FakeChannelLogger(), + false, + new MetricRecorder() { + }, + Ticker.systemTicker()); transports.add(transport); callMeMaybe(transport.start(clientTransportListener)); @@ -224,6 +269,7 @@ public void overrideDefaultUserAgent() throws Exception { NettyClientTransport transport = newTransport(newNegotiator(), DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, "testUserAgent", true); callMeMaybe(transport.start(clientTransportListener)); + verify(clientTransportListener, timeout(5000)).transportReady(); new Rpc(transport, new Metadata()).halfClose().waitForResponse(); @@ -241,6 +287,7 @@ public void maxMessageSizeShouldBeEnforced() throws Throwable { NettyClientTransport transport = newTransport(newNegotiator(), 1, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, null, true); callMeMaybe(transport.start(clientTransportListener)); + verify(clientTransportListener, timeout(5000)).transportReady(); try { // Send a single RPC and wait for the response. @@ -267,6 +314,7 @@ public void creatingMultipleTlsTransportsShouldSucceed() throws Exception { NettyClientTransport transport = newTransport(negotiator); callMeMaybe(transport.start(clientTransportListener)); } + verify(clientTransportListener, timeout(5000).times(2)).transportReady(); // Send a single RPC on each transport. final List rpcs = new ArrayList<>(transports.size()); @@ -296,6 +344,7 @@ public void run() { failureStatus.asRuntimeException()); } }); + verify(clientTransportListener, timeout(5000)).transportTerminated(); Rpc rpc = new Rpc(transport).halfClose(); try { @@ -326,9 +375,10 @@ public void tlsNegotiationFailurePropagatesToStatus() throws Exception { .trustManager(caCert) .keyManager(clientCert, clientKey) .build(); - ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext); + ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, null); final NettyClientTransport transport = newTransport(negotiator); callMeMaybe(transport.start(clientTransportListener)); + verify(clientTransportListener, timeout(5000)).transportTerminated(); Rpc rpc = new Rpc(transport).halfClose(); try { @@ -358,6 +408,7 @@ public void channelExceptionDuringNegotiatonPropagatesToStatus() throws Exceptio callMeMaybe(transport.start(clientTransportListener)); final Status failureStatus = Status.UNAVAILABLE.withDescription("oh noes!"); transport.channel().pipeline().fireExceptionCaught(failureStatus.asRuntimeException()); + verify(clientTransportListener, timeout(5000)).transportTerminated(); Rpc rpc = new Rpc(transport).halfClose(); try { @@ -389,6 +440,7 @@ public void run() { } } }); + verify(clientTransportListener, timeout(5000)).transportTerminated(); Rpc rpc = new Rpc(transport).halfClose(); try { @@ -408,6 +460,7 @@ public void bufferedStreamsShouldBeClosedWhenConnectionTerminates() throws Excep NettyClientTransport transport = newTransport(newNegotiator()); callMeMaybe(transport.start(clientTransportListener)); + verify(clientTransportListener, timeout(5000)).transportReady(); // Send a dummy RPC in order to ensure that the updated SETTINGS_MAX_CONCURRENT_STREAMS // has been received by the remote endpoint. @@ -453,12 +506,30 @@ public void failingToConstructChannelShouldFailGracefully() throws Exception { address = TestUtils.testServerAddress(new InetSocketAddress(12345)); authority = GrpcUtil.authorityFromHostAndPort(address.getHostString(), address.getPort()); NettyClientTransport transport = new NettyClientTransport( - address, new ReflectiveChannelFactory<>(CantConstructChannel.class), - new HashMap, Object>(), group, - newNegotiator(), false, DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE, - GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, KEEPALIVE_TIME_NANOS_DISABLED, 1, false, authority, - null, tooManyPingsRunnable, new TransportTracer(), Attributes.EMPTY, new SocketPicker(), - new FakeChannelLogger(), false, Ticker.systemTicker()); + address, + new ReflectiveChannelFactory<>(CantConstructChannel.class), + new HashMap, Object>(), + group, + newNegotiator(), + false, + DEFAULT_WINDOW_SIZE, + DEFAULT_MAX_MESSAGE_SIZE, + GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, + GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, + KEEPALIVE_TIME_NANOS_DISABLED, + 1, + false, + authority, + null, + tooManyPingsRunnable, + new TransportTracer(), + Attributes.EMPTY, + new SocketPicker(), + new FakeChannelLogger(), + false, + new MetricRecorder() { + }, + Ticker.systemTicker()); transports.add(transport); // Should not throw @@ -484,8 +555,8 @@ public void onSuccess(long roundTripTimeNanos) { } @Override - public void onFailure(Throwable cause) { - pingResult.setException(cause); + public void onFailure(Status cause) { + pingResult.setException(cause.asException()); } }; transport.ping(pingCallback, clock.getScheduledExecutorService()); @@ -519,15 +590,20 @@ public void channelFactoryShouldSetSocketOptionKeepAlive() throws Exception { @Test public void channelFactoryShouldNNotSetSocketOptionKeepAlive() throws Exception { startServer(); - NettyClientTransport transport = newTransport(newNegotiator(), - DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, "testUserAgent", true, - TimeUnit.SECONDS.toNanos(10L), TimeUnit.SECONDS.toNanos(1L), - new ReflectiveChannelFactory<>(LocalChannel.class), group); + DefaultEventLoopGroup group = new DefaultEventLoopGroup(1); + try { + NettyClientTransport transport = newTransport(newNegotiator(), + DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, "testUserAgent", true, + TimeUnit.SECONDS.toNanos(10L), TimeUnit.SECONDS.toNanos(1L), + new ReflectiveChannelFactory<>(LocalChannel.class), group); - callMeMaybe(transport.start(clientTransportListener)); + callMeMaybe(transport.start(clientTransportListener)); - assertThat(transport.channel().config().getOption(ChannelOption.SO_KEEPALIVE)) - .isNull(); + assertThat(transport.channel().config().getOption(ChannelOption.SO_KEEPALIVE)) + .isNull(); + } finally { + group.shutdownGracefully(0, 10, TimeUnit.SECONDS); + } } @Test @@ -537,6 +613,7 @@ public void maxHeaderListSizeShouldBeEnforcedOnClient() throws Exception { NettyClientTransport transport = newTransport(newNegotiator(), DEFAULT_MAX_MESSAGE_SIZE, 1, null, true); callMeMaybe(transport.start(clientTransportListener)); + verify(clientTransportListener, timeout(5000)).transportReady(); try { // Send a single RPC and wait for the response. @@ -554,9 +631,6 @@ public void maxHeaderListSizeShouldBeEnforcedOnClient() throws Exception { @Test public void huffmanCodingShouldNotBePerformed() throws Exception { - @SuppressWarnings("InlineMeInliner") // Requires Java 11 - String longStringOfA = Strings.repeat("a", 128); - negotiator = ProtocolNegotiators.serverPlaintext(); startServer(); @@ -567,9 +641,10 @@ public void huffmanCodingShouldNotBePerformed() throws Exception { Metadata headers = new Metadata(); headers.put(Metadata.Key.of("test", Metadata.ASCII_STRING_MARSHALLER), - longStringOfA); + LONG_STRING_OF_A); callMeMaybe(transport.start(clientTransportListener)); + verify(clientTransportListener, timeout(5000)).transportReady(); AtomicBoolean foundExpectedHeaderBytes = new AtomicBoolean(false); @@ -578,7 +653,7 @@ public void huffmanCodingShouldNotBePerformed() throws Exception { public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { if (msg instanceof ByteBuf) { - if (((ByteBuf) msg).toString(StandardCharsets.UTF_8).contains(longStringOfA)) { + if (((ByteBuf) msg).toString(StandardCharsets.UTF_8).contains(LONG_STRING_OF_A)) { foundExpectedHeaderBytes.set(true); } } @@ -593,12 +668,54 @@ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) } } + @Test + public void huffmanCodingShouldNotBePerformedOnServer() throws Exception { + negotiator = ProtocolNegotiators.serverPlaintext(); + + Metadata responseHeaders = new Metadata(); + responseHeaders.put(Metadata.Key.of("test", Metadata.ASCII_STRING_MARSHALLER), + LONG_STRING_OF_A); + + startServer(new EchoServerListener(responseHeaders)); + + NettyClientTransport transport = newTransport(ProtocolNegotiators.plaintext(), + DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, null, false, + TimeUnit.SECONDS.toNanos(10L), TimeUnit.SECONDS.toNanos(1L), + new ReflectiveChannelFactory<>(NioSocketChannel.class), group); + + callMeMaybe(transport.start(clientTransportListener)); + verify(clientTransportListener, timeout(5000)).transportReady(); + + AtomicBoolean foundExpectedHeaderBytes = new AtomicBoolean(false); + + // Add a handler to the client pipeline to inspect server's response + transport.channel().pipeline().addFirst(new ChannelDuplexHandler() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof ByteBuf) { + String data = ((ByteBuf) msg).toString(StandardCharsets.UTF_8); + if (data.contains(LONG_STRING_OF_A)) { + foundExpectedHeaderBytes.set(true); + } + } + super.channelRead(ctx, msg); + } + }); + + new Rpc(transport).halfClose().waitForResponse(); + + if (!foundExpectedHeaderBytes.get()) { + fail("expected to find UTF-8 encoded 'a's in the response header sent by the server"); + } + } + @Test public void maxHeaderListSizeShouldBeEnforcedOnServer() throws Exception { startServer(100, 1); NettyClientTransport transport = newTransport(newNegotiator()); callMeMaybe(transport.start(clientTransportListener)); + verify(clientTransportListener, timeout(5000)).transportReady(); try { // Send a single RPC and wait for the response. @@ -643,6 +760,7 @@ public void clientStreamGetsAttributes() throws Exception { startServer(); NettyClientTransport transport = newTransport(newNegotiator()); callMeMaybe(transport.start(clientTransportListener)); + verify(clientTransportListener, timeout(5000)).transportReady(); Rpc rpc = new Rpc(transport).halfClose(); rpc.waitForResponse(); @@ -661,6 +779,7 @@ public void keepAliveEnabled() throws Exception { NettyClientTransport transport = newTransport(newNegotiator(), DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, null /* user agent */, true /* keep alive */); callMeMaybe(transport.start(clientTransportListener)); + verify(clientTransportListener, timeout(5000)).transportReady(); Rpc rpc = new Rpc(transport).halfClose(); rpc.waitForResponse(); @@ -673,6 +792,7 @@ public void keepAliveDisabled() throws Exception { NettyClientTransport transport = newTransport(newNegotiator(), DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, null /* user agent */, false /* keep alive */); callMeMaybe(transport.start(clientTransportListener)); + verify(clientTransportListener, timeout(5000)).transportReady(); Rpc rpc = new Rpc(transport).halfClose(); rpc.waitForResponse(); @@ -760,11 +880,13 @@ public void tlsNegotiationServerExecutorShouldSucceed() throws Exception { .trustManager(caCert) .keyManager(clientCert, clientKey) .build(); - ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, clientExecutorPool); + ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, clientExecutorPool, + Optional.absent(), null, null); // after starting the client, the Executor in the client pool should be used assertEquals(true, clientExecutorPool.isInUse()); final NettyClientTransport transport = newTransport(negotiator); callMeMaybe(transport.start(clientTransportListener)); + verify(clientTransportListener, timeout(5000)).transportReady(); Rpc rpc = new Rpc(transport).halfClose(); rpc.waitForResponse(); // closing the negotiators should return the executors back to pool, and release the resource @@ -774,6 +896,179 @@ public void tlsNegotiationServerExecutorShouldSucceed() throws Exception { assertEquals(false, serverExecutorPool.isInUse()); } + /** + * This test tests the case of TlsCredentials passed to ProtocolNegotiators not having an instance + * of X509ExtendedTrustManager (this is not testable in ProtocolNegotiatorsTest without creating + * accessors for the internal state of negotiator whether it has a X509ExtendedTrustManager, + * hence the need to test it in this class instead). To establish a successful handshake we create + * a fake X509TrustManager not implementing X509ExtendedTrustManager but wraps the real + * X509ExtendedTrustManager. + */ + @Test + public void authorityOverrideInCallOptions_noX509ExtendedTrustManager_newStreamCreationFails() + throws IOException, InterruptedException, GeneralSecurityException, ExecutionException, + TimeoutException { + NettyClientHandler.enablePerRpcAuthorityCheck = true; + try { + startServer(); + InputStream caCert = TlsTesting.loadCert("ca.pem"); + X509TrustManager x509ExtendedTrustManager = + (X509TrustManager) getX509ExtendedTrustManager(caCert); + ProtocolNegotiators.FromChannelCredentialsResult result = + ProtocolNegotiators.from(TlsChannelCredentials.newBuilder() + .trustManager(new FakeTrustManager(x509ExtendedTrustManager)).build()); + NettyClientTransport transport = newTransport(result.negotiator.newNegotiator()); + SettableFuture connected = SettableFuture.create(); + FakeClientTransportListener fakeClientTransportListener = + new FakeClientTransportListener(connected); + callMeMaybe(transport.start(fakeClientTransportListener)); + connected.get(10, TimeUnit.SECONDS); + assertThat(fakeClientTransportListener.isConnected()).isTrue(); + + Rpc rpc = new Rpc(transport, new Metadata(), "foo.test.google.in"); + try { + rpc.waitForClose(); + fail("Expected exception in starting stream"); + } catch (ExecutionException ex) { + Status status = ((StatusException) ex.getCause()).getStatus(); + assertThat(status.getDescription()).isEqualTo("Can't allow authority override in rpc " + + "when X509ExtendedTrustManager is not available"); + assertThat(status.getCode()).isEqualTo(Code.UNAVAILABLE); + } + } finally { + NettyClientHandler.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void authorityOverrideInCallOptions_doesntMatchServerPeerHost_newStreamCreationFails() + throws IOException, InterruptedException, GeneralSecurityException, ExecutionException, + TimeoutException { + NettyClientHandler.enablePerRpcAuthorityCheck = true; + try { + startServer(); + NettyClientTransport transport = newTransport(newNegotiator()); + SettableFuture connected = SettableFuture.create(); + FakeClientTransportListener fakeClientTransportListener = + new FakeClientTransportListener(connected); + callMeMaybe(transport.start(fakeClientTransportListener)); + connected.get(10, TimeUnit.SECONDS); + assertThat(fakeClientTransportListener.isConnected()).isTrue(); + + Rpc rpc = new Rpc(transport, new Metadata(), "foo.test.google.in"); + try { + rpc.waitForClose(); + fail("Expected exception in starting stream"); + } catch (ExecutionException ex) { + Status status = ((StatusException) ex.getCause()).getStatus(); + assertThat(status.getDescription()).isEqualTo("Peer hostname verification during rpc " + + "failed for authority 'foo.test.google.in'"); + assertThat(status.getCode()).isEqualTo(Code.UNAVAILABLE); + assertThat(((InvocationTargetException) ex.getCause().getCause()).getTargetException()) + .isInstanceOf(CertificateException.class); + assertThat(((InvocationTargetException) ex.getCause().getCause()).getTargetException() + .getMessage()).isEqualTo( + "No subject alternative DNS name matching foo.test.google.in found."); + } + } finally { + NettyClientHandler.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void authorityOverrideInCallOptions_matchesServerPeerHost_newStreamCreationSucceeds() + throws IOException, InterruptedException, GeneralSecurityException, ExecutionException, + TimeoutException { + NettyClientHandler.enablePerRpcAuthorityCheck = true; + try { + startServer(); + NettyClientTransport transport = newTransport(newNegotiator()); + SettableFuture connected = SettableFuture.create(); + FakeClientTransportListener fakeClientTransportListener = + new FakeClientTransportListener(connected); + callMeMaybe(transport.start(fakeClientTransportListener)); + connected.get(10, TimeUnit.SECONDS); + assertThat(fakeClientTransportListener.isConnected()).isTrue(); + + new Rpc(transport, new Metadata(), "foo.test.google.fr").waitForResponse(); + } finally { + NettyClientHandler.enablePerRpcAuthorityCheck = false; + } + } + + // Without removing the port number part that {@link X509AuthorityVerifier} does, there will be a + // java.security.cert.CertificateException: Illegal given domain name: foo.test.google.fr:12345 + @Test + public void authorityOverrideInCallOptions_portNumberInAuthority_isStrippedForPeerVerification() + throws IOException, InterruptedException, GeneralSecurityException, ExecutionException, + TimeoutException { + NettyClientHandler.enablePerRpcAuthorityCheck = true; + try { + startServer(); + NettyClientTransport transport = newTransport(newNegotiator()); + SettableFuture connected = SettableFuture.create(); + FakeClientTransportListener fakeClientTransportListener = + new FakeClientTransportListener(connected); + callMeMaybe(transport.start(fakeClientTransportListener)); + connected.get(10, TimeUnit.SECONDS); + assertThat(fakeClientTransportListener.isConnected()).isTrue(); + + new Rpc(transport, new Metadata(), "foo.test.google.fr:12345").waitForResponse(); + } finally { + NettyClientHandler.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void authorityOverrideInCallOptions_portNumberAndIpv6_isStrippedForPeerVerification() + throws IOException, InterruptedException, GeneralSecurityException, ExecutionException, + TimeoutException { + NettyClientHandler.enablePerRpcAuthorityCheck = true; + try { + startServer(); + NettyClientTransport transport = newTransport(newNegotiator()); + SettableFuture connected = SettableFuture.create(); + FakeClientTransportListener fakeClientTransportListener = + new FakeClientTransportListener(connected); + callMeMaybe(transport.start(fakeClientTransportListener)); + connected.get(10, TimeUnit.SECONDS); + assertThat(fakeClientTransportListener.isConnected()).isTrue(); + + new Rpc(transport, new Metadata(), "[2001:db8:3333:4444:5555:6666:1.2.3.4]:12345") + .waitForResponse(); + } catch (ExecutionException ex) { + Status status = ((StatusException) ex.getCause()).getStatus(); + assertThat(status.getDescription()).isEqualTo("Peer hostname verification during rpc " + + "failed for authority '[2001:db8:3333:4444:5555:6666:1.2.3.4]:12345'"); + assertThat(status.getCode()).isEqualTo(Code.UNAVAILABLE); + assertThat(((InvocationTargetException) ex.getCause().getCause()).getTargetException()) + .isInstanceOf(CertificateException.class); + // Port number is removed by {@link X509AuthorityVerifier}. + assertThat(((InvocationTargetException) ex.getCause().getCause()).getTargetException() + .getMessage()).isEqualTo( + "No subject alternative names matching IP address 2001:db8:3333:4444:5555:6666:1.2.3.4 " + + "found"); + } finally { + NettyClientHandler.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void authorityOverrideInCallOptions_notMatches_flagDisabled_createsStream() + throws IOException, InterruptedException, GeneralSecurityException, ExecutionException, + TimeoutException { + startServer(); + NettyClientTransport transport = newTransport(newNegotiator()); + SettableFuture connected = SettableFuture.create(); + FakeClientTransportListener fakeClientTransportListener = + new FakeClientTransportListener(connected); + callMeMaybe(transport.start(fakeClientTransportListener)); + connected.get(10, TimeUnit.SECONDS); + assertThat(fakeClientTransportListener.isConnected()).isTrue(); + + new Rpc(transport, new Metadata(), "foo.test.google.in").waitForResponse(); + } + private Throwable getRootCause(Throwable t) { if (t.getCause() == null) { return t; @@ -781,10 +1076,37 @@ private Throwable getRootCause(Throwable t) { return getRootCause(t.getCause()); } - private ProtocolNegotiator newNegotiator() throws IOException { + private ProtocolNegotiator newNegotiator() throws IOException, GeneralSecurityException { InputStream caCert = TlsTesting.loadCert("ca.pem"); SslContext clientContext = GrpcSslContexts.forClient().trustManager(caCert).build(); - return ProtocolNegotiators.tls(clientContext); + return ProtocolNegotiators.tls(clientContext, + (X509TrustManager) getX509ExtendedTrustManager(TlsTesting.loadCert("ca.pem"))); + } + + private static TrustManager getX509ExtendedTrustManager(InputStream rootCerts) + throws GeneralSecurityException { + KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType()); + try { + ks.load(null, null); + } catch (IOException ex) { + // Shouldn't really happen, as we're not loading any data. + throw new GeneralSecurityException(ex); + } + X509Certificate[] certs = CertificateUtils.getX509Certificates(rootCerts); + for (X509Certificate cert : certs) { + X500Principal principal = cert.getSubjectX500Principal(); + ks.setCertificateEntry(principal.getName("RFC2253"), cert); + } + + TrustManagerFactory trustManagerFactory = + TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init(ks); + for (TrustManager trustManager : trustManagerFactory.getTrustManagers()) { + if (trustManager instanceof X509ExtendedTrustManager) { + return trustManager; + } + } + return null; } private NettyClientTransport newTransport(ProtocolNegotiator negotiator) { @@ -807,11 +1129,29 @@ private NettyClientTransport newTransport(ProtocolNegotiator negotiator, int max keepAliveTimeNano = KEEPALIVE_TIME_NANOS_DISABLED; } NettyClientTransport transport = new NettyClientTransport( - address, channelFactory, new HashMap, Object>(), group, - negotiator, false, DEFAULT_WINDOW_SIZE, maxMsgSize, maxHeaderListSize, - keepAliveTimeNano, keepAliveTimeoutNano, - false, authority, userAgent, tooManyPingsRunnable, - new TransportTracer(), eagAttributes, new SocketPicker(), new FakeChannelLogger(), false, + address, + channelFactory, + new HashMap, Object>(), + group, + negotiator, + false, + DEFAULT_WINDOW_SIZE, + maxMsgSize, + maxHeaderListSize, + maxHeaderListSize, + keepAliveTimeNano, + keepAliveTimeoutNano, + false, + authority, + userAgent, + tooManyPingsRunnable, + new TransportTracer(), + eagAttributes, + new SocketPicker(), + new FakeChannelLogger(), + false, + new MetricRecorder() { + }, Ticker.systemTicker()); transports.add(transport); return transport; @@ -821,23 +1161,45 @@ private void startServer() throws IOException { startServer(100, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE); } + private void startServer(ServerListener serverListener) throws IOException { + startServer(100, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, serverListener); + } + private void startServer(int maxStreamsPerConnection, int maxHeaderListSize) throws IOException { + startServer(maxStreamsPerConnection, maxHeaderListSize, serverListener); + } + + private void startServer(int maxStreamsPerConnection, int maxHeaderListSize, + ServerListener serverListener) throws IOException { server = new NettyServer( TestUtils.testServerAddresses(new InetSocketAddress(0)), new ReflectiveChannelFactory<>(NioServerSocketChannel.class), new HashMap, Object>(), new HashMap, Object>(), - new FixedObjectPool<>(group), new FixedObjectPool<>(group), false, negotiator, + new FixedObjectPool<>(group), + new FixedObjectPool<>(group), + false, + negotiator, Collections.emptyList(), TransportTracer.getDefaultFactory(), maxStreamsPerConnection, false, - DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE, maxHeaderListSize, - DEFAULT_SERVER_KEEPALIVE_TIME_NANOS, DEFAULT_SERVER_KEEPALIVE_TIMEOUT_NANOS, + DEFAULT_WINDOW_SIZE, + DEFAULT_MAX_MESSAGE_SIZE, + maxHeaderListSize, + maxHeaderListSize, + DEFAULT_SERVER_KEEPALIVE_TIME_NANOS, + DEFAULT_SERVER_KEEPALIVE_TIMEOUT_NANOS, MAX_CONNECTION_IDLE_NANOS_DISABLED, - MAX_CONNECTION_AGE_NANOS_DISABLED, MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE, true, 0, - MAX_RST_COUNT_DISABLED, 0, Attributes.EMPTY, - channelz); + MAX_CONNECTION_AGE_NANOS_DISABLED, + MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE, + true, + 0, + MAX_RST_COUNT_DISABLED, + 0, + Attributes.EMPTY, + channelz, + new MetricRecorder() {}); server.start(serverListener); address = TestUtils.testServerAddress((InetSocketAddress) server.getListenSocketAddress()); authority = GrpcUtil.authorityFromHostAndPort(address.getHostString(), address.getPort()); @@ -873,13 +1235,20 @@ private static class Rpc { final TestClientStreamListener listener = new TestClientStreamListener(); Rpc(NettyClientTransport transport) { - this(transport, new Metadata()); + this(transport, new Metadata(), null); } Rpc(NettyClientTransport transport, Metadata headers) { + this(transport, headers, null); + } + + Rpc(NettyClientTransport transport, Metadata headers, String authorityOverride) { stream = transport.newStream( METHOD, headers, CallOptions.DEFAULT, new ClientStreamTracer[]{ new ClientStreamTracer() {} }); + if (authorityOverride != null) { + stream.setAuthority(authorityOverride); + } stream.start(listener); stream.request(1); stream.writeMessage(new ByteArrayInputStream(MESSAGE.getBytes(UTF_8))); @@ -969,6 +1338,15 @@ private final class EchoServerListener implements ServerListener { final List transports = new ArrayList<>(); final List streamListeners = Collections.synchronizedList(new ArrayList()); + Metadata responseHeaders; + + public EchoServerListener() { + this(new Metadata()); + } + + public EchoServerListener(Metadata responseHeaders) { + this.responseHeaders = responseHeaders; + } @Override public ServerTransportListener transportCreated(final ServerTransport transport) { @@ -978,7 +1356,7 @@ public ServerTransportListener transportCreated(final ServerTransport transport) public void streamCreated(ServerStream stream, String method, Metadata headers) { EchoServerStreamListener listener = new EchoServerStreamListener(stream, headers); stream.setListener(listener); - stream.writeHeaders(new Metadata(), true); + stream.writeHeaders(responseHeaders, true); stream.request(1); streamListeners.add(listener); } @@ -1025,9 +1403,15 @@ public NoopHandler(GrpcHttp2ConnectionHandler grpcHandler) { this.grpcHandler = grpcHandler; } + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + // Prevent any data being passed to NettyClientHandler + ReferenceCountUtil.release(msg); + } + @Override public void channelRegistered(ChannelHandlerContext ctx) throws Exception { - ctx.pipeline().addBefore(ctx.name(), null, grpcHandler); + ctx.pipeline().addAfter(ctx.name(), null, grpcHandler); } public void fail(ChannelHandlerContext ctx, Throwable cause) { @@ -1071,4 +1455,62 @@ public void log(ChannelLogLevel level, String message) {} @Override public void log(ChannelLogLevel level, String messageFormat, Object... args) {} } + + static class FakeClientTransportListener implements ManagedClientTransport.Listener { + private final SettableFuture connected; + + @GuardedBy("this") + private boolean isConnected = false; + + public FakeClientTransportListener(SettableFuture connected) { + this.connected = connected; + } + + @Override + public void transportShutdown(Status s, DisconnectError e) {} + + @Override + public void transportTerminated() {} + + @Override + public void transportReady() { + synchronized (this) { + isConnected = true; + } + connected.set(null); + } + + synchronized boolean isConnected() { + return isConnected; + } + + @Override + public void transportInUse(boolean inUse) {} + } + + private static class FakeTrustManager implements X509TrustManager { + + private final X509TrustManager delegate; + + public FakeTrustManager(X509TrustManager x509ExtendedTrustManager) { + this.delegate = x509ExtendedTrustManager; + } + + @Override + public void checkClientTrusted(X509Certificate[] x509Certificates, String s) + throws CertificateException { + delegate.checkClientTrusted(x509Certificates, s); + } + + @Override + public void checkServerTrusted(X509Certificate[] x509Certificates, String s) + throws CertificateException { + delegate.checkServerTrusted(x509Certificates, s); + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return delegate.getAcceptedIssuers(); + } + } } diff --git a/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java b/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java index fbab1ca5fae..c971294fbb6 100644 --- a/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java +++ b/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java @@ -16,8 +16,8 @@ package io.grpc.netty; -import static com.google.common.base.Charsets.UTF_8; import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_WINDOW_SIZE; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertEquals; import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; @@ -38,7 +38,6 @@ import io.grpc.internal.WritableBuffer; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.ByteBufUtil; import io.netty.buffer.CompositeByteBuf; import io.netty.buffer.Unpooled; import io.netty.buffer.UnpooledByteBufAllocator; @@ -68,6 +67,7 @@ import java.nio.ByteBuffer; import java.util.concurrent.Delayed; import java.util.concurrent.TimeUnit; +import org.junit.After; import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; @@ -84,7 +84,6 @@ public abstract class NettyHandlerTestBase { protected static final int STREAM_ID = 3; - private ByteBuf content; private EmbeddedChannel channel; @@ -106,18 +105,24 @@ protected void manualSetUp() throws Exception {} protected final TransportTracer transportTracer = new TransportTracer(); protected int flowControlWindow = DEFAULT_WINDOW_SIZE; protected boolean autoFlowControl = false; - private final FakeClock fakeClock = new FakeClock(); FakeClock fakeClock() { return fakeClock; } + @After + public void tearDown() throws Exception { + if (channel() != null) { + channel().releaseInbound(); + channel().releaseOutbound(); + } + } + /** * Must be called by subclasses to initialize the handler and channel. */ protected final void initChannel(Http2HeadersDecoder headersDecoder) throws Exception { - content = Unpooled.copiedBuffer("hello world", UTF_8); frameWriter = mock(Http2FrameWriter.class, delegatesTo(new DefaultHttp2FrameWriter())); frameReader = new DefaultHttp2FrameReader(headersDecoder); @@ -233,11 +238,11 @@ protected final Http2FrameReader frameReader() { } protected final ByteBuf content() { - return content; + return Unpooled.copiedBuffer(contentAsArray()); } protected final byte[] contentAsArray() { - return ByteBufUtil.getBytes(content()); + return "\000\000\000\000\rhello world".getBytes(UTF_8); } protected final Http2FrameWriter verifyWrite() { @@ -252,8 +257,8 @@ protected final void channelRead(Object obj) throws Exception { channel.writeInbound(obj); } - protected ByteBuf grpcDataFrame(int streamId, boolean endStream, byte[] content) { - final ByteBuf compressionFrame = Unpooled.buffer(content.length); + protected ByteBuf grpcFrame(byte[] message) { + final ByteBuf compressionFrame = Unpooled.buffer(message.length); MessageFramer framer = new MessageFramer( new MessageFramer.Sink() { @Override @@ -262,23 +267,22 @@ public void deliverFrame( if (frame != null) { ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf(); compressionFrame.writeBytes(bytebuf); + bytebuf.release(); } } }, new NettyWritableBufferAllocator(ByteBufAllocator.DEFAULT), StatsTraceContext.NOOP); - framer.writePayload(new ByteArrayInputStream(content)); - framer.flush(); - ChannelHandlerContext ctx = newMockContext(); - new DefaultHttp2FrameWriter().writeData(ctx, streamId, compressionFrame, 0, endStream, - newPromise()); - return captureWrite(ctx); + framer.writePayload(new ByteArrayInputStream(message)); + framer.close(); + return compressionFrame; } - protected final ByteBuf dataFrame(int streamId, boolean endStream, ByteBuf content) { - // Need to retain the content since the frameWriter releases it. - content.retain(); + protected final ByteBuf grpcDataFrame(int streamId, boolean endStream, byte[] content) { + return dataFrame(streamId, endStream, grpcFrame(content)); + } + protected final ByteBuf dataFrame(int streamId, boolean endStream, ByteBuf content) { ChannelHandlerContext ctx = newMockContext(); new DefaultHttp2FrameWriter().writeData(ctx, streamId, content, 0, endStream, newPromise()); return captureWrite(ctx); @@ -410,6 +414,7 @@ public void dataSizeSincePingAccumulates() throws Exception { channelRead(dataFrame(3, false, buff.copy())); assertEquals(length * 3, handler.flowControlPing().getDataSincePing()); + buff.release(); } @Test @@ -608,12 +613,14 @@ public void bdpPingWindowResizing() throws Exception { private void readPingAck(long pingData) throws Exception { channelRead(pingFrame(true, pingData)); + channel().releaseOutbound(); } private void readXCopies(int copies, byte[] data) throws Exception { for (int i = 0; i < copies; i++) { channelRead(grpcDataFrame(STREAM_ID, false, data)); // buffer it stream().request(1); // consume it + channel().releaseOutbound(); } } diff --git a/netty/src/test/java/io/grpc/netty/NettyReadableBufferTest.java b/netty/src/test/java/io/grpc/netty/NettyReadableBufferTest.java index 1a0ac229a89..4f10504c07d 100644 --- a/netty/src/test/java/io/grpc/netty/NettyReadableBufferTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyReadableBufferTest.java @@ -16,7 +16,7 @@ package io.grpc.netty; -import static com.google.common.base.Charsets.UTF_8; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; diff --git a/netty/src/test/java/io/grpc/netty/NettyServerBuilderTest.java b/netty/src/test/java/io/grpc/netty/NettyServerBuilderTest.java index 6d8192322aa..f3b73a515b5 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerBuilderTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerBuilderTest.java @@ -16,20 +16,19 @@ package io.grpc.netty; +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableList; -import com.google.common.truth.Truth; -import io.grpc.ServerStreamTracer; +import io.grpc.MetricRecorder; import io.netty.channel.EventLoopGroup; import io.netty.channel.local.LocalServerChannel; import io.netty.handler.ssl.SslContext; import java.net.InetSocketAddress; import java.util.concurrent.TimeUnit; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -39,18 +38,16 @@ @RunWith(JUnit4.class) public class NettyServerBuilderTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); - private NettyServerBuilder builder = NettyServerBuilder.forPort(8080); @Test public void addMultipleListenAddresses() { builder.addListenAddress(new InetSocketAddress(8081)); - NettyServer server = - builder.buildTransportServers(ImmutableList.of()); + NettyServer server = builder.buildTransportServers( + ImmutableList.of(), + new MetricRecorder() {}); - Truth.assertThat(server.getListenSocketAddresses()).hasSize(2); + assertThat(server.getListenSocketAddresses()).hasSize(2); } @Test @@ -63,105 +60,112 @@ public void failIfSslContextIsNotServer() { SslContext sslContext = mock(SslContext.class); when(sslContext.isClient()).thenReturn(true); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Client SSL context can not be used for server"); - builder.sslContext(sslContext); + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, () -> builder.sslContext(sslContext)); + assertThat(e).hasMessageThat().isEqualTo("Client SSL context can not be used for server"); } @Test public void failIfKeepAliveTimeNegative() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("keepalive time must be positive"); - - builder.keepAliveTime(-10L, TimeUnit.HOURS); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.keepAliveTime(-10L, TimeUnit.HOURS)); + assertThat(e).hasMessageThat().isEqualTo("keepalive time must be positive:-10"); } @Test public void failIfKeepAliveTimeoutNegative() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("keepalive timeout must be positive"); - - builder.keepAliveTimeout(-10L, TimeUnit.HOURS); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.keepAliveTimeout(-10L, TimeUnit.HOURS)); + assertThat(e).hasMessageThat().isEqualTo("keepalive timeout must be positive: -10"); } @Test public void failIfMaxConcurrentCallsPerConnectionNegative() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("max must be positive"); - - builder.maxConcurrentCallsPerConnection(0); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.maxConcurrentCallsPerConnection(0)); + assertThat(e).hasMessageThat().isEqualTo("max must be positive: 0"); } @Test public void failIfMaxInboundMetadataSizeNonPositive() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("maxInboundMetadataSize must be positive"); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.maxInboundMetadataSize(0)); + assertThat(e).hasMessageThat().isEqualTo("maxInboundMetadataSize must be positive: 0"); + } - builder.maxInboundMetadataSize(0); + @Test + public void failIfSoftInboundMetadataSizeNonPositive() { + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.maxInboundMetadataSize(0, 100)); + assertThat(e).hasMessageThat().isEqualTo("softLimitHeaderListSize must be positive: 0"); } @Test - public void failIfMaxConnectionIdleNegative() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("max connection idle must be positive"); + public void failIfMaxInboundMetadataSizeSmallerThanSoft() { + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.maxInboundMetadataSize(100, 80)); + assertThat(e).hasMessageThat().isEqualTo("maxInboundMetadataSize: 80 " + + "must be greater than softLimitHeaderListSize: 100"); + } - builder.maxConnectionIdle(-1, TimeUnit.HOURS); + @Test + public void failIfMaxConnectionIdleNegative() { + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.maxConnectionIdle(-1, TimeUnit.HOURS)); + assertThat(e).hasMessageThat().isEqualTo("max connection idle must be positive: -1"); } @Test public void failIfMaxConnectionAgeNegative() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("max connection age must be positive"); - - builder.maxConnectionAge(-1, TimeUnit.HOURS); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.maxConnectionAge(-1, TimeUnit.HOURS)); + assertThat(e).hasMessageThat().isEqualTo("max connection age must be positive: -1"); } @Test public void failIfMaxConnectionAgeGraceNegative() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("max connection age grace must be non-negative"); - - builder.maxConnectionAgeGrace(-1, TimeUnit.HOURS); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.maxConnectionAgeGrace(-1, TimeUnit.HOURS)); + assertThat(e).hasMessageThat().isEqualTo("max connection age grace must be non-negative: -1"); } @Test public void failIfPermitKeepAliveTimeNegative() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("permit keepalive time must be non-negative"); - - builder.permitKeepAliveTime(-1, TimeUnit.HOURS); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.permitKeepAliveTime(-1, TimeUnit.HOURS)); + assertThat(e).hasMessageThat().isEqualTo("permit keepalive time must be non-negative: -1"); } @Test public void assertEventLoopsAndChannelType_onlyBossGroupProvided() { EventLoopGroup mockEventLoopGroup = mock(EventLoopGroup.class); builder.bossEventLoopGroup(mockEventLoopGroup); - thrown.expect(IllegalStateException.class); - thrown.expectMessage( - "All of BossEventLoopGroup, WorkerEventLoopGroup and ChannelType should be provided"); - - builder.assertEventLoopsAndChannelType(); + IllegalStateException e = assertThrows(IllegalStateException.class, + builder::assertEventLoopsAndChannelType); + assertThat(e).hasMessageThat().isEqualTo( + "All of BossEventLoopGroup, WorkerEventLoopGroup and ChannelType should be provided " + + "or neither should be"); } @Test public void assertEventLoopsAndChannelType_onlyWorkerGroupProvided() { EventLoopGroup mockEventLoopGroup = mock(EventLoopGroup.class); builder.workerEventLoopGroup(mockEventLoopGroup); - thrown.expect(IllegalStateException.class); - thrown.expectMessage( - "All of BossEventLoopGroup, WorkerEventLoopGroup and ChannelType should be provided"); - - builder.assertEventLoopsAndChannelType(); + IllegalStateException e = assertThrows(IllegalStateException.class, + builder::assertEventLoopsAndChannelType); + assertThat(e).hasMessageThat().isEqualTo( + "All of BossEventLoopGroup, WorkerEventLoopGroup and ChannelType should be provided " + + "or neither should be"); } @Test public void assertEventLoopsAndChannelType_onlyTypeProvided() { builder.channelType(LocalServerChannel.class); - thrown.expect(IllegalStateException.class); - thrown.expectMessage( - "All of BossEventLoopGroup, WorkerEventLoopGroup and ChannelType should be provided"); - - builder.assertEventLoopsAndChannelType(); + IllegalStateException e = assertThrows(IllegalStateException.class, + builder::assertEventLoopsAndChannelType); + assertThat(e).hasMessageThat().isEqualTo( + "All of BossEventLoopGroup, WorkerEventLoopGroup and ChannelType should be provided " + + "or neither should be"); } @Test @@ -186,4 +190,5 @@ public void useNioTransport_shouldNotThrow() { builder.assertEventLoopsAndChannelType(); } + } diff --git a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java index 281ff3b17d6..1c8d2b5479d 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java @@ -16,7 +16,6 @@ package io.grpc.netty; -import static com.google.common.base.Charsets.UTF_8; import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; import static io.grpc.internal.GrpcUtil.DEFAULT_SERVER_KEEPALIVE_TIMEOUT_NANOS; import static io.grpc.internal.GrpcUtil.DEFAULT_SERVER_KEEPALIVE_TIME_NANOS; @@ -29,6 +28,7 @@ import static io.grpc.netty.Utils.HTTP_METHOD; import static io.grpc.netty.Utils.TE_HEADER; import static io.grpc.netty.Utils.TE_TRAILERS; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -43,6 +43,7 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; @@ -58,6 +59,7 @@ import io.grpc.Attributes; import io.grpc.InternalStatus; import io.grpc.Metadata; +import io.grpc.MetricRecorder; import io.grpc.ServerStreamTracer; import io.grpc.Status; import io.grpc.Status.Code; @@ -74,10 +76,10 @@ import io.grpc.internal.testing.TestServerStreamTracer; import io.grpc.netty.GrpcHttp2HeadersUtils.GrpcHttp2ServerHeadersDecoder; import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufUtil; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http2.DefaultHttp2Headers; import io.netty.handler.codec.http2.Http2CodecUtil; import io.netty.handler.codec.http2.Http2Error; @@ -89,8 +91,10 @@ import java.io.InputStream; import java.nio.channels.ClosedChannelException; import java.util.Arrays; +import java.util.HashMap; import java.util.LinkedList; import java.util.List; +import java.util.Map; import java.util.Queue; import java.util.concurrent.TimeUnit; import org.junit.Before; @@ -118,27 +122,22 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase streamListenerMessageQueue = new LinkedList<>(); private int maxConcurrentStreams = Integer.MAX_VALUE; private int maxHeaderListSize = Integer.MAX_VALUE; + private int softLimitHeaderListSize = Integer.MAX_VALUE; private boolean permitKeepAliveWithoutCalls = true; private long permitKeepAliveTimeInNanos = 0; private long maxConnectionIdleInNanos = MAX_CONNECTION_IDLE_NANOS_DISABLED; @@ -205,6 +204,19 @@ protected void manualSetUp() throws Exception { // Simulate receipt of initial remote settings. ByteBuf serializedSettings = serializeSettings(new Http2Settings()); channelRead(serializedSettings); + channel().releaseOutbound(); + } + + @Test + public void tcpMetrics_recorded() throws Exception { + manualSetUp(); + handler().channelActive(ctx()); + // Verify that channelActive triggered TcpMetrics + verify(metricRecorder, atLeastOnce()).addLongCounter( + eq(io.grpc.InternalTcpMetrics.CONNECTIONS_CREATED_INSTRUMENT), + eq(1L), + any(), + any()); } @Test @@ -226,10 +238,11 @@ public void sendFrameShouldSucceed() throws Exception { createStream(); // Send a frame and verify that it was written. + ByteBuf content = content(); ChannelFuture future = enqueue( - new SendGrpcFrameCommand(stream.transportState(), content(), false)); + new SendGrpcFrameCommand(stream.transportState(), content, false)); assertTrue(future.isSuccess()); - verifyWrite().writeData(eq(ctx()), eq(STREAM_ID), eq(content()), eq(0), eq(false), + verifyWrite().writeData(eq(ctx()), eq(STREAM_ID), same(content), eq(0), eq(false), any(ChannelPromise.class)); } @@ -264,10 +277,11 @@ private void inboundDataShouldForwardToStreamListener(boolean endStream) throws // Create a data frame and then trigger the handler to read it. ByteBuf frame = grpcDataFrame(STREAM_ID, endStream, contentAsArray()); channelRead(frame); + channel().releaseOutbound(); verify(streamListener, atLeastOnce()) .messagesAvailable(any(StreamListener.MessageProducer.class)); InputStream message = streamListenerMessageQueue.poll(); - assertArrayEquals(ByteBufUtil.getBytes(content()), ByteStreams.toByteArray(message)); + assertArrayEquals(contentAsArray(), ByteStreams.toByteArray(message)); message.close(); assertNull("no additional message expected", streamListenerMessageQueue.poll()); @@ -469,11 +483,41 @@ public void connectionWindowShouldBeOverridden() throws Exception { public void cancelShouldSendRstStream() throws Exception { manualSetUp(); createStream(); - enqueue(new CancelServerStreamCommand(stream.transportState(), Status.DEADLINE_EXCEEDED)); + enqueue(CancelServerStreamCommand.withReset(stream.transportState(), Status.DEADLINE_EXCEEDED)); verifyWrite().writeRstStream(eq(ctx()), eq(stream.transportState().id()), eq(Http2Error.CANCEL.code()), any(ChannelPromise.class)); } + @Test + public void cancelWithNotify_shouldSendHeaders() throws Exception { + manualSetUp(); + createStream(); + + enqueue(CancelServerStreamCommand.withReason( + stream.transportState(), + Status.RESOURCE_EXHAUSTED.withDescription("my custom description") + )); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Http2Headers.class); + verifyWrite() + .writeHeaders( + eq(ctx()), + eq(STREAM_ID), + captor.capture(), + eq(0), + eq(true), + any(ChannelPromise.class)); + + // For arcane reasons, the specific implementation of Http2Headers here doesn't actually support + // methods like `get(...)`, so we have to manually convert it into a map. + Map actualHeaders = new HashMap<>(); + for (Map.Entry entry : captor.getValue()) { + actualHeaders.put(entry.getKey().toString(), entry.getValue().toString()); + } + assertEquals("8", actualHeaders.get(InternalStatus.CODE_KEY.name())); + assertEquals("my custom description", actualHeaders.get(InternalStatus.MESSAGE_KEY.name())); + } + @Test public void headersWithInvalidContentTypeShouldFail() throws Exception { manualSetUp(); @@ -513,7 +557,8 @@ public void headersWithInvalidMethodShouldFail() throws Exception { .set(InternalStatus.CODE_KEY.name(), String.valueOf(Code.INTERNAL.value())) .set(InternalStatus.MESSAGE_KEY.name(), "Method 'FAKE' is not supported") .status("" + 405) - .set(CONTENT_TYPE_HEADER, "text/plain; charset=utf-8"); + .set(CONTENT_TYPE_HEADER, "text/plain; charset=utf-8") + .set(HttpHeaderNames.ALLOW, HTTP_METHOD); verifyWrite() .writeHeaders( @@ -837,7 +882,7 @@ public void keepAliveEnforcer_sendingDataResetsCounters() throws Exception { future.get(); for (int i = 0; i < 10; i++) { future = enqueue( - new SendGrpcFrameCommand(stream.transportState(), content().retainedSlice(), false)); + new SendGrpcFrameCommand(stream.transportState(), content(), false)); future.get(); channel().releaseOutbound(); channelRead(pingFrame(false /* isAck */, 1L)); @@ -1260,6 +1305,7 @@ public void maxRstCount_withinLimit_succeeds() throws Exception { maxRstPeriodNanos = TimeUnit.MILLISECONDS.toNanos(100); manualSetUp(); rapidReset(maxRstCount); + assertTrue(channel().isOpen()); } @@ -1269,10 +1315,13 @@ public void maxRstCount_exceedsLimit_fails() throws Exception { maxRstPeriodNanos = TimeUnit.MILLISECONDS.toNanos(100); manualSetUp(); assertThrows(ClosedChannelException.class, () -> rapidReset(maxRstCount + 1)); + assertFalse(channel().isOpen()); } private void rapidReset(int burstSize) throws Exception { + when(streamTracerFactory.newServerStreamTracer(anyString(), any(Metadata.class))) + .thenAnswer((args) -> new TestServerStreamTracer()); Http2Headers headers = new DefaultHttp2Headers() .method(HTTP_METHOD) .set(CONTENT_TYPE_HEADER, new AsciiString("application/grpc", UTF_8)) @@ -1292,6 +1341,48 @@ private void rapidReset(int burstSize) throws Exception { } } + @Test + public void maxRstCountSent_withinLimit_succeeds() throws Exception { + maxRstCount = 10; + maxRstPeriodNanos = TimeUnit.MILLISECONDS.toNanos(100); + manualSetUp(); + madeYouReset(maxRstCount); + + assertTrue(channel().isOpen()); + } + + @Test + public void maxRstCountSent_exceedsLimit_fails() throws Exception { + maxRstCount = 10; + maxRstPeriodNanos = TimeUnit.MILLISECONDS.toNanos(100); + manualSetUp(); + assertThrows(ClosedChannelException.class, () -> madeYouReset(maxRstCount + 1)); + + assertFalse(channel().isOpen()); + } + + private void madeYouReset(int burstSize) throws Exception { + when(streamTracerFactory.newServerStreamTracer(anyString(), any(Metadata.class))) + .thenAnswer((args) -> new TestServerStreamTracer()); + Http2Headers headers = new DefaultHttp2Headers() + .method(HTTP_METHOD) + .set(CONTENT_TYPE_HEADER, new AsciiString("application/grpc", UTF_8)) + .set(TE_HEADER, TE_TRAILERS) + .path(new AsciiString("/foo/bar")); + int streamId = 1; + long rpcTimeNanos = maxRstPeriodNanos / 2 / burstSize; + for (int period = 0; period < 3; period++) { + for (int i = 0; i < burstSize; i++) { + channelRead(headersFrame(streamId, headers)); + channelRead(windowUpdate(streamId, 0)); + streamId += 2; + fakeClock().forwardNanos(rpcTimeNanos); + } + while (channel().readOutbound() != null) {} + fakeClock().forwardNanos(maxRstPeriodNanos - rpcTimeNanos * burstSize + 1); + } + } + private void createStream() throws Exception { Http2Headers headers = new DefaultHttp2Headers() .method(HTTP_METHOD) @@ -1311,11 +1402,7 @@ private void createStream() throws Exception { private ByteBuf emptyGrpcFrame(int streamId, boolean endStream) throws Exception { ByteBuf buf = NettyTestUtil.messageFrame(""); - try { - return dataFrame(streamId, endStream, buf); - } finally { - buf.release(); - } + return dataFrame(streamId, endStream, buf); } @Override @@ -1331,6 +1418,7 @@ protected NettyServerHandler newHandler() { autoFlowControl, flowControlWindow, maxHeaderListSize, + softLimitHeaderListSize, DEFAULT_MAX_MESSAGE_SIZE, keepAliveTimeInNanos, keepAliveTimeoutInNanos, @@ -1342,7 +1430,8 @@ protected NettyServerHandler newHandler() { maxRstCount, maxRstPeriodNanos, Attributes.EMPTY, - fakeClock().getTicker()); + fakeClock().getTicker(), + metricRecorder); } @Override diff --git a/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java b/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java index e95a2a52bc9..2f2933ae103 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerStreamTest.java @@ -17,12 +17,16 @@ package io.grpc.netty; import static com.google.common.truth.Truth.assertThat; -import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; +import static com.google.common.truth.Truth.assertWithMessage; import static io.grpc.netty.NettyTestUtil.messageFrame; +import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; +import static io.netty.handler.codec.http2.Http2Exception.connectionError; import static org.junit.Assert.assertNull; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.never; @@ -32,8 +36,11 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; +import com.google.common.base.Strings; import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.Iterables; import com.google.common.collect.ListMultimap; +import com.google.common.collect.Lists; import io.grpc.Attributes; import io.grpc.Metadata; import io.grpc.Status; @@ -43,11 +50,14 @@ import io.grpc.internal.TransportTracer; import io.netty.buffer.EmptyByteBuf; import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.channel.DefaultChannelPromise; import io.netty.handler.codec.http2.DefaultHttp2Headers; +import io.netty.handler.codec.http2.Http2Exception; import io.netty.util.AsciiString; import java.io.ByteArrayInputStream; import java.io.InputStream; import java.util.LinkedList; +import java.util.List; import java.util.Queue; import org.junit.Before; import org.junit.Test; @@ -55,13 +65,17 @@ import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; import org.mockito.ArgumentMatchers; +import org.mockito.InOrder; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; /** Unit tests for {@link NettyServerStream}. */ @RunWith(JUnit4.class) public class NettyServerStreamTest extends NettyStreamTestBase { + private static final int TEST_MAX_MESSAGE_SIZE = 128; + @Mock protected ServerStreamListener serverListener; @@ -124,6 +138,100 @@ public void writeMessageShouldSendResponse() throws Exception { eq(true)); } + @Test + public void writeFrameFutureFailedShouldCancelRpc() { + Http2Exception h2Error = connectionError(PROTOCOL_ERROR, "Stream does not exist %d", STREAM_ID); + when(writeQueue.enqueue(any(SendGrpcFrameCommand.class), anyBoolean())).thenReturn( + new DefaultChannelPromise(channel).setFailure(h2Error)); + + // Write multiple messages to ensure multiple SendGrpcFrameCommand are enqueued. We set up all + // of them to fail, which allows us to assert that only a single cancel is sent, and the stream + // isn't spammed with multiple RST_STREAM. + stream.writeMessage(new ByteArrayInputStream(smallMessage())); + stream.writeMessage(new ByteArrayInputStream(largeMessage())); + stream.flush(); + + verifyWriteFutureFailure(h2Error); + // Verify CancelServerStreamCommand enqueued once, right after first SendGrpcFrameCommand. + InOrder inOrder = Mockito.inOrder(writeQueue); + inOrder.verify(writeQueue).enqueue(any(SendGrpcFrameCommand.class), eq(false)); + // Verify that failed SendGrpcFrameCommand results in immediate CancelServerStreamCommand. + inOrder.verify(writeQueue).enqueue(any(CancelServerStreamCommand.class), eq(true)); + // Verify that any other failures do not produce another CancelServerStreamCommand in the queue. + inOrder.verify(writeQueue, atLeast(0)).enqueue(any(SendGrpcFrameCommand.class), eq(false)); + inOrder.verify(writeQueue).enqueue(any(SendGrpcFrameCommand.class), eq(true)); + inOrder.verifyNoMoreInteractions(); + } + + @Test + public void writeHeadersFutureFailedShouldCancelRpc() { + Http2Exception h2Error = connectionError(PROTOCOL_ERROR, "Stream does not exist %d", STREAM_ID); + Class headersCommandClass = SendResponseHeadersCommand.class; + when(writeQueue.enqueue(any(headersCommandClass), anyBoolean())).thenReturn( + new DefaultChannelPromise(channel).setFailure(h2Error)); + + // Prepare different headers to make it easier to distinguish in the error message. + Metadata headers1 = new Metadata(); + headers1.put(Metadata.Key.of("writeHeaders", Metadata.ASCII_STRING_MARSHALLER), "1"); + Metadata headers2 = new Metadata(); + headers2.put(Metadata.Key.of("writeHeaders", Metadata.ASCII_STRING_MARSHALLER), "2"); + Metadata headers3 = new Metadata(); + headers3.put(Metadata.Key.of("writeHeaders", Metadata.ASCII_STRING_MARSHALLER), "3"); + + // Note writeHeaders flush argument shouldn't matter for this test. + stream().writeHeaders(headers1, false); + stream().writeHeaders(headers2, false); + stream().writeHeaders(headers3, true); + stream.flush(); + + verifyWriteFutureFailure(h2Error); + // Verify CancelServerStreamCommand enqueued once, right after first SendResponseHeadersCommand. + InOrder inOrder = Mockito.inOrder(writeQueue); + inOrder.verify(writeQueue).enqueue(any(headersCommandClass), anyBoolean()); + inOrder.verify(writeQueue).enqueue(any(CancelServerStreamCommand.class), eq(true)); + inOrder.verify(writeQueue, atLeast(1)).enqueue(any(headersCommandClass), anyBoolean()); + inOrder.verifyNoMoreInteractions(); + } + + @Test + public void writeTrailersFutureFailedShouldCancelRpc() { + Http2Exception h2Error = connectionError(PROTOCOL_ERROR, "Stream does not exist %d", STREAM_ID); + when(writeQueue.enqueue(any(SendResponseHeadersCommand.class), eq(true))).thenReturn( + new DefaultChannelPromise(channel).setFailure(h2Error)); + + stream().close(Status.OK, trailers); + + verifyWriteFutureFailure(h2Error); + verify(writeQueue).enqueue(any(CancelServerStreamCommand.class), eq(true)); + } + + private void verifyWriteFutureFailure(Http2Exception h2Error) { + // Check the error that caused the future write failure propagated via Status. + Status cancelReason = findCancelServerStreamCommand().reason(); + assertThat(cancelReason.getCode()).isEqualTo(Status.INTERNAL.getCode()); + assertThat(cancelReason.getCause()).isEqualTo(h2Error); + // Verify the listener has closed. + verify(serverListener).closed(same(cancelReason)); + } + + private CancelServerStreamCommand findCancelServerStreamCommand() { + // Ensure there's no CancelServerStreamCommand enqueued with flush=false. + verify(writeQueue, never()).enqueue(any(CancelServerStreamCommand.class), eq(false)); + + List commands = Lists.newArrayList( + Iterables.transform( + Iterables.filter( + Mockito.mockingDetails(writeQueue).getInvocations(), + // Get enqueue() innovations only + invocation -> invocation.getMethod().getName().equals("enqueue") + // Find the cancel commands. + && invocation.getArgument(0) instanceof CancelServerStreamCommand), + invocation -> invocation.getArgument(0, CancelServerStreamCommand.class))); + + assertWithMessage("Expected exactly one CancelClientStreamCommand").that(commands).hasSize(1); + return commands.get(0); + } + @Test public void writeHeadersShouldSendHeaders() throws Exception { Metadata headers = new Metadata(); @@ -276,10 +384,31 @@ public void emptyFramerShouldSendNoPayload() { public void cancelStreamShouldSucceed() { stream().cancel(Status.DEADLINE_EXCEEDED); verify(writeQueue).enqueue( - new CancelServerStreamCommand(stream().transportState(), Status.DEADLINE_EXCEEDED), + CancelServerStreamCommand.withReset(stream().transportState(), Status.DEADLINE_EXCEEDED), true); } + @Test + public void oversizedMessagesResultInResourceExhaustedTrailers() throws Exception { + @SuppressWarnings("InlineMeInliner") // Requires Java 11 + String oversizedMsg = Strings.repeat("a", TEST_MAX_MESSAGE_SIZE + 1); + stream.request(1); + stream.transportState().inboundDataReceived(messageFrame(oversizedMsg), false); + assertNull("message should have caused a deframer error", listenerMessageQueue().poll()); + + ArgumentCaptor cancelCmdCap = + ArgumentCaptor.forClass(CancelServerStreamCommand.class); + verify(writeQueue).enqueue(cancelCmdCap.capture(), eq(true)); + + Status status = Status.RESOURCE_EXHAUSTED + .withDescription("gRPC message exceeds maximum size 128: 129"); + + CancelServerStreamCommand actualCmd = cancelCmdCap.getValue(); + assertThat(actualCmd.reason().getCode()).isEqualTo(status.getCode()); + assertThat(actualCmd.reason().getDescription()).isEqualTo(status.getDescription()); + assertThat(actualCmd.wantsHeaders()).isTrue(); + } + @Override @SuppressWarnings("DirectInvocationOnMock") protected NettyServerStream createStream() { @@ -287,10 +416,10 @@ protected NettyServerStream createStream() { StatsTraceContext statsTraceCtx = StatsTraceContext.NOOP; TransportTracer transportTracer = new TransportTracer(); NettyServerStream.TransportState state = new NettyServerStream.TransportState( - handler, channel.eventLoop(), http2Stream, DEFAULT_MAX_MESSAGE_SIZE, statsTraceCtx, + handler, channel.eventLoop(), http2Stream, TEST_MAX_MESSAGE_SIZE, statsTraceCtx, transportTracer, "method"); NettyServerStream stream = new NettyServerStream(channel, state, Attributes.EMPTY, - "test-authority", statsTraceCtx, transportTracer); + "test-authority", statsTraceCtx); stream.transportState().setListener(serverListener); state.onStreamAllocated(); verify(serverListener, atLeastOnce()).onReady(); diff --git a/netty/src/test/java/io/grpc/netty/NettyServerTest.java b/netty/src/test/java/io/grpc/netty/NettyServerTest.java index 64d31070156..61c3f9e219e 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerTest.java @@ -37,6 +37,7 @@ import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalInstrumented; import io.grpc.Metadata; +import io.grpc.MetricRecorder; import io.grpc.ServerStreamTracer; import io.grpc.internal.FixedObjectPool; import io.grpc.internal.ServerListener; @@ -133,29 +134,35 @@ class NoHandlerProtocolNegotiator implements ProtocolNegotiator { } NoHandlerProtocolNegotiator protocolNegotiator = new NoHandlerProtocolNegotiator(); - NettyServer ns = new NettyServer( - Arrays.asList(addr), - new ReflectiveChannelFactory<>(NioServerSocketChannel.class), - new HashMap, Object>(), - new HashMap, Object>(), - new FixedObjectPool<>(eventLoop), - new FixedObjectPool<>(eventLoop), - false, - protocolNegotiator, - Collections.emptyList(), - TransportTracer.getDefaultFactory(), - 1, // ignore - false, // ignore - 1, // ignore - 1, // ignore - 1, // ignore - 1, // ignore - 1, 1, // ignore - 1, 1, // ignore - true, 0, // ignore - 0, 0, // ignore - Attributes.EMPTY, - channelz); + NettyServer ns = + new NettyServer( + Arrays.asList(addr), + new ReflectiveChannelFactory<>(NioServerSocketChannel.class), + new HashMap, Object>(), + new HashMap, Object>(), + new FixedObjectPool<>(eventLoop), + new FixedObjectPool<>(eventLoop), + false, + protocolNegotiator, + Collections.emptyList(), + TransportTracer.getDefaultFactory(), + 1, // ignore + false, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, + 1, // ignore + 1, + 1, // ignore + true, + 0, // ignore + 0, + 0, // ignore + Attributes.EMPTY, + channelz, mock(MetricRecorder.class)); final SettableFuture serverShutdownCalled = SettableFuture.create(); ns.start(new ServerListener() { @Override @@ -184,29 +191,35 @@ public void multiPortStartStopGet() throws Exception { InetSocketAddress addr1 = new InetSocketAddress(0); InetSocketAddress addr2 = new InetSocketAddress(0); - NettyServer ns = new NettyServer( - Arrays.asList(addr1, addr2), - new ReflectiveChannelFactory<>(NioServerSocketChannel.class), - new HashMap, Object>(), - new HashMap, Object>(), - new FixedObjectPool<>(eventLoop), - new FixedObjectPool<>(eventLoop), - false, - ProtocolNegotiators.plaintext(), - Collections.emptyList(), - TransportTracer.getDefaultFactory(), - 1, // ignore - false, // ignore - 1, // ignore - 1, // ignore - 1, // ignore - 1, // ignore - 1, 1, // ignore - 1, 1, // ignore - true, 0, // ignore - 0, 0, // ignore - Attributes.EMPTY, - channelz); + NettyServer ns = + new NettyServer( + Arrays.asList(addr1, addr2), + new ReflectiveChannelFactory<>(NioServerSocketChannel.class), + new HashMap, Object>(), + new HashMap, Object>(), + new FixedObjectPool<>(eventLoop), + new FixedObjectPool<>(eventLoop), + false, + ProtocolNegotiators.plaintext(), + Collections.emptyList(), + TransportTracer.getDefaultFactory(), + 1, // ignore + false, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, + 1, // ignore + 1, + 1, // ignore + true, + 0, // ignore + 0, + 0, // ignore + Attributes.EMPTY, + channelz, mock(MetricRecorder.class)); final SettableFuture shutdownCompleted = SettableFuture.create(); ns.start(new ServerListener() { @Override @@ -258,29 +271,35 @@ public void multiPortConnections() throws Exception { InetSocketAddress addr2 = new InetSocketAddress(0); final CountDownLatch allPortsConnectedCountDown = new CountDownLatch(2); - NettyServer ns = new NettyServer( - Arrays.asList(addr1, addr2), - new ReflectiveChannelFactory<>(NioServerSocketChannel.class), - new HashMap, Object>(), - new HashMap, Object>(), - new FixedObjectPool<>(eventLoop), - new FixedObjectPool<>(eventLoop), - false, - ProtocolNegotiators.plaintext(), - Collections.emptyList(), - TransportTracer.getDefaultFactory(), - 1, // ignore - false, // ignore - 1, // ignore - 1, // ignore - 1, // ignore - 1, // ignore - 1, 1, // ignore - 1, 1, // ignore - true, 0, // ignore - 0, 0, // ignore - Attributes.EMPTY, - channelz); + NettyServer ns = + new NettyServer( + Arrays.asList(addr1, addr2), + new ReflectiveChannelFactory<>(NioServerSocketChannel.class), + new HashMap, Object>(), + new HashMap, Object>(), + new FixedObjectPool<>(eventLoop), + new FixedObjectPool<>(eventLoop), + false, + ProtocolNegotiators.plaintext(), + Collections.emptyList(), + TransportTracer.getDefaultFactory(), + 1, // ignore + false, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, + 1, // ignore + 1, + 1, // ignore + true, + 0, // ignore + 0, + 0, // ignore + Attributes.EMPTY, + channelz, mock(MetricRecorder.class)); final SettableFuture shutdownCompleted = SettableFuture.create(); ns.start(new ServerListener() { @Override @@ -320,29 +339,35 @@ public void run() {} public void getPort_notStarted() { InetSocketAddress addr = new InetSocketAddress(0); List addresses = Collections.singletonList(addr); - NettyServer ns = new NettyServer( - addresses, - new ReflectiveChannelFactory<>(NioServerSocketChannel.class), - new HashMap, Object>(), - new HashMap, Object>(), - new FixedObjectPool<>(eventLoop), - new FixedObjectPool<>(eventLoop), - false, - ProtocolNegotiators.plaintext(), - Collections.emptyList(), - TransportTracer.getDefaultFactory(), - 1, // ignore - false, // ignore - 1, // ignore - 1, // ignore - 1, // ignore - 1, // ignore - 1, 1, // ignore - 1, 1, // ignore - true, 0, // ignore - 0, 0, // ignore - Attributes.EMPTY, - channelz); + NettyServer ns = + new NettyServer( + addresses, + new ReflectiveChannelFactory<>(NioServerSocketChannel.class), + new HashMap, Object>(), + new HashMap, Object>(), + new FixedObjectPool<>(eventLoop), + new FixedObjectPool<>(eventLoop), + false, + ProtocolNegotiators.plaintext(), + Collections.emptyList(), + TransportTracer.getDefaultFactory(), + 1, // ignore + false, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, + 1, // ignore + 1, + 1, // ignore + true, + 0, // ignore + 0, + 0, // ignore + Attributes.EMPTY, + channelz, mock(MetricRecorder.class)); assertThat(ns.getListenSocketAddress()).isEqualTo(addr); assertThat(ns.getListenSocketAddresses()).isEqualTo(addresses); @@ -395,29 +420,35 @@ class TestProtocolNegotiator implements ProtocolNegotiator { .build(); TestProtocolNegotiator protocolNegotiator = new TestProtocolNegotiator(); InetSocketAddress addr = new InetSocketAddress(0); - NettyServer ns = new NettyServer( - Arrays.asList(addr), - new ReflectiveChannelFactory<>(NioServerSocketChannel.class), - new HashMap, Object>(), - childChannelOptions, - new FixedObjectPool<>(eventLoop), - new FixedObjectPool<>(eventLoop), - false, - protocolNegotiator, - Collections.emptyList(), - TransportTracer.getDefaultFactory(), - 1, // ignore - false, // ignore - 1, // ignore - 1, // ignore - 1, // ignore - 1, // ignore - 1, 1, // ignore - 1, 1, // ignore - true, 0, // ignore - 0, 0, // ignore - eagAttributes, - channelz); + NettyServer ns = + new NettyServer( + Arrays.asList(addr), + new ReflectiveChannelFactory<>(NioServerSocketChannel.class), + new HashMap, Object>(), + childChannelOptions, + new FixedObjectPool<>(eventLoop), + new FixedObjectPool<>(eventLoop), + false, + protocolNegotiator, + Collections.emptyList(), + TransportTracer.getDefaultFactory(), + 1, // ignore + false, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, + 1, // ignore + 1, + 1, // ignore + true, + 0, // ignore + 0, + 0, // ignore + eagAttributes, + channelz, mock(MetricRecorder.class)); ns.start(new ServerListener() { @Override public ServerTransportListener transportCreated(ServerTransport transport) { @@ -443,29 +474,35 @@ public void serverShutdown() {} @Test public void channelzListenSocket() throws Exception { InetSocketAddress addr = new InetSocketAddress(0); - NettyServer ns = new NettyServer( - Arrays.asList(addr), - new ReflectiveChannelFactory<>(NioServerSocketChannel.class), - new HashMap, Object>(), - new HashMap, Object>(), - new FixedObjectPool<>(eventLoop), - new FixedObjectPool<>(eventLoop), - false, - ProtocolNegotiators.plaintext(), - Collections.emptyList(), - TransportTracer.getDefaultFactory(), - 1, // ignore - false, // ignore - 1, // ignore - 1, // ignore - 1, // ignore - 1, // ignore - 1, 1, // ignore - 1, 1, // ignore - true, 0, // ignore - 0, 0, // ignore - Attributes.EMPTY, - channelz); + NettyServer ns = + new NettyServer( + Arrays.asList(addr), + new ReflectiveChannelFactory<>(NioServerSocketChannel.class), + new HashMap, Object>(), + new HashMap, Object>(), + new FixedObjectPool<>(eventLoop), + new FixedObjectPool<>(eventLoop), + false, + ProtocolNegotiators.plaintext(), + Collections.emptyList(), + TransportTracer.getDefaultFactory(), + 1, // ignore + false, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, // ignore + 1, + 1, // ignore + 1, + 1, // ignore + true, + 0, // ignore + 0, + 0, // ignore + Attributes.EMPTY, + channelz, mock(MetricRecorder.class)); final SettableFuture shutdownCompleted = SettableFuture.create(); ns.start(new ServerListener() { @Override @@ -603,12 +640,17 @@ private NettyServer getServer(List addr, EventLoopGroup ev) { 1, // ignore 1, // ignore 1, // ignore - 1, 1, // ignore - 1, 1, // ignore - true, 0, // ignore - 0, 0, // ignore + 1, // ignore + 1, + 1, // ignore + 1, + 1, // ignore + true, + 0, // ignore + 0, + 0, // ignore Attributes.EMPTY, - channelz); + channelz, mock(MetricRecorder.class)); } private static class NoopServerTransportListener implements ServerTransportListener { diff --git a/netty/src/test/java/io/grpc/netty/NettyStreamTestBase.java b/netty/src/test/java/io/grpc/netty/NettyStreamTestBase.java index f073fb6b2e4..ce42e3d25df 100644 --- a/netty/src/test/java/io/grpc/netty/NettyStreamTestBase.java +++ b/netty/src/test/java/io/grpc/netty/NettyStreamTestBase.java @@ -16,8 +16,8 @@ package io.grpc.netty; -import static com.google.common.base.Charsets.US_ASCII; import static io.grpc.netty.NettyTestUtil.messageFrame; +import static java.nio.charset.StandardCharsets.US_ASCII; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; diff --git a/netty/src/test/java/io/grpc/netty/NettyTransportTest.java b/netty/src/test/java/io/grpc/netty/NettyTransportTest.java index b1c89e22f93..22758a8b727 100644 --- a/netty/src/test/java/io/grpc/netty/NettyTransportTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyTransportTest.java @@ -22,10 +22,12 @@ import com.google.common.util.concurrent.SettableFuture; import io.grpc.Attributes; import io.grpc.ChannelLogger; +import io.grpc.MetricRecorder; import io.grpc.ServerStreamTracer; import io.grpc.Status; import io.grpc.internal.AbstractTransportTest; import io.grpc.internal.ClientTransportFactory; +import io.grpc.internal.DisconnectError; import io.grpc.internal.FakeClock; import io.grpc.internal.InternalServer; import io.grpc.internal.ManagedClientTransport; @@ -70,7 +72,7 @@ protected InternalServer newServer( .forAddress(new InetSocketAddress("localhost", 0)) .flowControlWindow(AbstractTransportTest.TEST_FLOW_CONTROL_WINDOW) .setTransportTracerFactory(fakeClockTransportTracer) - .buildTransportServers(streamTracerFactories); + .buildTransportServers(streamTracerFactories, new MetricRecorder() {}); } @Override @@ -80,7 +82,7 @@ protected InternalServer newServer( .forAddress(new InetSocketAddress("localhost", port)) .flowControlWindow(AbstractTransportTest.TEST_FLOW_CONTROL_WINDOW) .setTransportTracerFactory(fakeClockTransportTracer) - .buildTransportServers(streamTracerFactories); + .buildTransportServers(streamTracerFactories, new MetricRecorder() {}); } @Override @@ -127,7 +129,7 @@ public void channelHasUnresolvedHostname() throws Exception { .setChannelLogger(logger), logger); Runnable runnable = transport.start(new ManagedClientTransport.Listener() { @Override - public void transportShutdown(Status s) { + public void transportShutdown(Status s, DisconnectError e) { future.set(s); } diff --git a/netty/src/test/java/io/grpc/netty/NettyWritableBufferAllocatorTest.java b/netty/src/test/java/io/grpc/netty/NettyWritableBufferAllocatorTest.java index d577ec46b03..0b741ae24b3 100644 --- a/netty/src/test/java/io/grpc/netty/NettyWritableBufferAllocatorTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyWritableBufferAllocatorTest.java @@ -40,13 +40,6 @@ protected WritableBufferAllocator allocator() { return allocator; } - @Test - public void testCapacityHasMinimum() { - WritableBuffer buffer = allocator().allocate(100); - assertEquals(0, buffer.readableBytes()); - assertEquals(4096, buffer.writableBytes()); - } - @Test public void testCapacityIsExactAboveMinimum() { WritableBuffer buffer = allocator().allocate(9000); diff --git a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java index 1852213da52..403b1b64329 100644 --- a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java +++ b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java @@ -16,14 +16,15 @@ package io.grpc.netty; -import static com.google.common.base.Charsets.UTF_8; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; import static com.google.common.truth.Truth.assertThat; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; @@ -31,6 +32,7 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import com.google.common.base.Optional; import io.grpc.Attributes; import io.grpc.CallCredentials; import io.grpc.ChannelCredentials; @@ -44,6 +46,7 @@ import io.grpc.InternalChannelz; import io.grpc.InternalChannelz.Security; import io.grpc.Metadata; +import io.grpc.MetricRecorder; import io.grpc.SecurityLevel; import io.grpc.ServerCredentials; import io.grpc.ServerStreamTracer; @@ -53,6 +56,7 @@ import io.grpc.TlsChannelCredentials; import io.grpc.TlsServerCredentials; import io.grpc.internal.ClientTransportFactory; +import io.grpc.internal.DisconnectError; import io.grpc.internal.GrpcAttributes; import io.grpc.internal.InternalServer; import io.grpc.internal.ManagedClientTransport; @@ -111,15 +115,20 @@ import io.netty.handler.ssl.SslHandler; import io.netty.handler.ssl.SslHandshakeCompletionEvent; import java.io.File; +import java.io.IOException; import java.io.InputStream; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.cert.CertificateException; import java.security.cert.X509Certificate; import java.util.ArrayDeque; import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Queue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -141,7 +150,6 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.DisableOnDebug; -import org.junit.rules.ExpectedException; import org.junit.rules.TestRule; import org.junit.rules.Timeout; import org.junit.runner.RunWith; @@ -169,8 +177,6 @@ public static void loadCerts() throws Exception { private static final int TIMEOUT_SECONDS = 60; @Rule public final TestRule globalTimeout = new DisableOnDebug(Timeout.seconds(TIMEOUT_SECONDS)); - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); private final EventLoopGroup group = new DefaultEventLoop(); private Channel chan; @@ -221,13 +227,52 @@ public ChannelCredentials withoutBearerTokens() { } @Test - public void fromClient_tls() { + public void fromClient_tls_trustManager() + throws KeyStoreException, CertificateException, IOException, NoSuchAlgorithmException { + KeyStore certStore = KeyStore.getInstance(KeyStore.getDefaultType()); + certStore.load(null); + TrustManagerFactory trustManagerFactory = + TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + try (InputStream ca = TlsTesting.loadCert("ca.pem")) { + for (X509Certificate cert : CertificateUtils.getX509Certificates(ca)) { + certStore.setCertificateEntry(cert.getSubjectX500Principal().getName("RFC2253"), cert); + } + } + trustManagerFactory.init(certStore); + ProtocolNegotiators.FromChannelCredentialsResult result = + ProtocolNegotiators.from(TlsChannelCredentials.newBuilder() + .trustManager(trustManagerFactory.getTrustManagers()).build()); + assertThat(result.error).isNull(); + assertThat(result.callCredentials).isNull(); + assertThat(result.negotiator) + .isInstanceOf(ProtocolNegotiators.TlsProtocolNegotiatorClientFactory.class); + assertThat(((ClientTlsProtocolNegotiator) result.negotiator.newNegotiator()) + .hasX509ExtendedTrustManager()).isTrue(); + } + + @Test + public void fromClient_tls_CaCertsInputStream() throws IOException { + ProtocolNegotiators.FromChannelCredentialsResult result = + ProtocolNegotiators.from(TlsChannelCredentials.newBuilder() + .trustManager(TlsTesting.loadCert("ca.pem")).build()); + assertThat(result.error).isNull(); + assertThat(result.callCredentials).isNull(); + assertThat(result.negotiator) + .isInstanceOf(ProtocolNegotiators.TlsProtocolNegotiatorClientFactory.class); + assertThat(((ClientTlsProtocolNegotiator) result.negotiator.newNegotiator()) + .hasX509ExtendedTrustManager()).isTrue(); + } + + @Test + public void fromClient_tls_systemDefault() { ProtocolNegotiators.FromChannelCredentialsResult result = ProtocolNegotiators.from(TlsChannelCredentials.create()); assertThat(result.error).isNull(); assertThat(result.callCredentials).isNull(); assertThat(result.negotiator) .isInstanceOf(ProtocolNegotiators.TlsProtocolNegotiatorClientFactory.class); + assertThat(((ClientTlsProtocolNegotiator) result.negotiator.newNegotiator()) + .hasX509ExtendedTrustManager()).isTrue(); } @Test @@ -345,7 +390,9 @@ private Object expectHandshake( .buildTransportFactory(); InternalServer server = NettyServerBuilder .forPort(0, serverCreds) - .buildTransportServers(Collections.emptyList()); + .buildTransportServers( + Collections.emptyList(), + new MetricRecorder() {}); server.start(serverListener); ManagedClientTransport.Listener clientTransportListener = @@ -366,7 +413,7 @@ private Object expectHandshake( } else { ArgumentCaptor captor = ArgumentCaptor.forClass(Status.class); verify(clientTransportListener, timeout(TIMEOUT_SECONDS * 1000)) - .transportShutdown(captor.capture()); + .transportShutdown(captor.capture(), any(DisconnectError.class)); result = captor.getValue(); } @@ -670,11 +717,10 @@ public void handlerAdded(ChannelHandlerContext ctx) throws Exception { } @Test - public void tlsHandler_failsOnNullEngine() throws Exception { - thrown.expect(NullPointerException.class); - thrown.expectMessage("ssl"); - - Object unused = ProtocolNegotiators.serverTls(null); + public void tlsHandler_failsOnNullEngine() { + NullPointerException e = assertThrows(NullPointerException.class, + () -> ProtocolNegotiators.serverTls(null)); + assertThat(e).hasMessageThat().isEqualTo("sslContext"); } @@ -876,7 +922,8 @@ public String applicationProtocol() { DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1); ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, - "authority", elg, noopLogger); + "authority", elg, noopLogger, Optional.absent(), + getClientTlsProtocolNegotiator(), null); pipeline.addLast(handler); pipeline.replace(SslHandler.class, null, goodSslHandler); pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT); @@ -914,7 +961,8 @@ public String applicationProtocol() { .applicationProtocolConfig(apn).build(); ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, - "authority", elg, noopLogger); + "authority", elg, noopLogger, Optional.absent(), + getClientTlsProtocolNegotiator(), null); pipeline.addLast(handler); pipeline.replace(SslHandler.class, null, goodSslHandler); pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT); @@ -938,7 +986,8 @@ public String applicationProtocol() { DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1); ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, - "authority", elg, noopLogger); + "authority", elg, noopLogger, Optional.absent(), + getClientTlsProtocolNegotiator(), null); pipeline.addLast(handler); final AtomicReference error = new AtomicReference<>(); @@ -966,7 +1015,8 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { @Test public void clientTlsHandler_closeDuringNegotiation() throws Exception { ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, - "authority", null, noopLogger); + "authority", null, noopLogger, Optional.absent(), + getClientTlsProtocolNegotiator(), null); pipeline.addLast(new WriteBufferingAndExceptionHandler(handler)); ChannelFuture pendingWrite = channel.writeAndFlush(NettyClientHandler.NOOP_MESSAGE); @@ -978,6 +1028,12 @@ public void clientTlsHandler_closeDuringNegotiation() throws Exception { .isEqualTo(Status.Code.UNAVAILABLE); } + private ClientTlsProtocolNegotiator getClientTlsProtocolNegotiator() throws SSLException { + return new ClientTlsProtocolNegotiator(GrpcSslContexts.forClient().trustManager( + TlsTesting.loadCert("ca.pem")).build(), + null, Optional.absent(), null, ""); + } + @Test public void engineLog() { ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext, null); @@ -1004,9 +1060,8 @@ public boolean isLoggable(LogRecord record) { @Test public void tls_failsOnNullSslContext() { - thrown.expect(NullPointerException.class); - - Object unused = ProtocolNegotiators.tls(null); + assertThrows(NullPointerException.class, + () -> ProtocolNegotiators.tls(null, null)); } @Test @@ -1036,23 +1091,23 @@ public void tls_invalidHost() throws SSLException { } @Test - public void httpProxy_nullAddressNpe() throws Exception { - thrown.expect(NullPointerException.class); - Object unused = - ProtocolNegotiators.httpProxy(null, "user", "pass", ProtocolNegotiators.plaintext()); + public void httpProxy_nullAddressNpe() { + assertThrows(NullPointerException.class, + () -> ProtocolNegotiators.httpProxy(null, null, "user", "pass", + ProtocolNegotiators.plaintext())); } @Test - public void httpProxy_nullNegotiatorNpe() throws Exception { - thrown.expect(NullPointerException.class); - Object unused = ProtocolNegotiators.httpProxy( - InetSocketAddress.createUnresolved("localhost", 80), "user", "pass", null); + public void httpProxy_nullNegotiatorNpe() { + assertThrows(NullPointerException.class, + () -> ProtocolNegotiators.httpProxy( + InetSocketAddress.createUnresolved("localhost", 80), null, "user", "pass", null)); } @Test public void httpProxy_nullUserPassNoException() throws Exception { assertNotNull(ProtocolNegotiators.httpProxy( - InetSocketAddress.createUnresolved("localhost", 80), null, null, + InetSocketAddress.createUnresolved("localhost", 80), null, null, null, ProtocolNegotiators.plaintext())); } @@ -1070,7 +1125,7 @@ public void httpProxy_completes() throws Exception { .bind(proxy).sync().channel(); ProtocolNegotiator nego = - ProtocolNegotiators.httpProxy(proxy, null, null, ProtocolNegotiators.plaintext()); + ProtocolNegotiators.httpProxy(proxy, null, null, null, ProtocolNegotiators.plaintext()); // normally NettyClientTransport will add WBAEH which kick start the ProtocolNegotiation, // mocking the behavior using KickStartHandler. ChannelHandler handler = @@ -1133,7 +1188,7 @@ public void httpProxy_500() throws Exception { .bind(proxy).sync().channel(); ProtocolNegotiator nego = - ProtocolNegotiators.httpProxy(proxy, null, null, ProtocolNegotiators.plaintext()); + ProtocolNegotiators.httpProxy(proxy, null, null, null, ProtocolNegotiators.plaintext()); // normally NettyClientTransport will add WBAEH which kick start the ProtocolNegotiation, // mocking the behavior using KickStartHandler. ChannelHandler handler = @@ -1164,14 +1219,84 @@ public void httpProxy_500() throws Exception { assertFalse(negotiationFuture.isDone()); String response = "HTTP/1.1 500 OMG\r\nContent-Length: 4\r\n\r\noops"; serverContext.writeAndFlush(bb(response, serverContext.channel())).sync(); - thrown.expect(ProxyConnectException.class); try { - negotiationFuture.sync(); + assertThrows(ProxyConnectException.class, negotiationFuture::sync); } finally { channel.close(); } } + @Test + public void httpProxy_customHeaders() throws Exception { + DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1); + // ProxyHandler is incompatible with EmbeddedChannel because when channelRegistered() is called + // the channel is already active. + LocalAddress proxy = new LocalAddress("httpProxy_customHeaders"); + SocketAddress host = InetSocketAddress.createUnresolved("example.com", 443); + + ChannelInboundHandler mockHandler = mock(ChannelInboundHandler.class); + Channel serverChannel = new ServerBootstrap().group(elg).channel(LocalServerChannel.class) + .childHandler(mockHandler) + .bind(proxy).sync().channel(); + + Map headers = new java.util.HashMap<>(); + headers.put("X-Custom-Header", "custom-value"); + headers.put("Proxy-Authorization", "Bearer token123"); + + ProtocolNegotiator nego = ProtocolNegotiators.httpProxy( + proxy, headers, null, null, ProtocolNegotiators.plaintext()); + // normally NettyClientTransport will add WBAEH which kick start the ProtocolNegotiation, + // mocking the behavior using KickStartHandler. + ChannelHandler handler = + new KickStartHandler(nego.newHandler(FakeGrpcHttp2ConnectionHandler.noopHandler())); + Channel channel = new Bootstrap().group(elg).channel(LocalChannel.class).handler(handler) + .register().sync().channel(); + pipeline = channel.pipeline(); + // Wait for initialization to complete + channel.eventLoop().submit(NOOP_RUNNABLE).sync(); + channel.connect(host).sync(); + serverChannel.close(); + ArgumentCaptor contextCaptor = + ArgumentCaptor.forClass(ChannelHandlerContext.class); + Mockito.verify(mockHandler).channelActive(contextCaptor.capture()); + ChannelHandlerContext serverContext = contextCaptor.getValue(); + + final String golden = "testData"; + ChannelFuture negotiationFuture = channel.writeAndFlush(bb(golden, channel)); + + // Wait for sending initial request to complete + channel.eventLoop().submit(NOOP_RUNNABLE).sync(); + ArgumentCaptor objectCaptor = ArgumentCaptor.forClass(Object.class); + Mockito.verify(mockHandler) + .channelRead(ArgumentMatchers.any(), objectCaptor.capture()); + ByteBuf b = (ByteBuf) objectCaptor.getValue(); + String request = b.toString(UTF_8); + b.release(); + + // Verify custom headers are present in the CONNECT request + assertTrue("No trailing newline: " + request, request.endsWith("\r\n\r\n")); + assertTrue("No CONNECT: " + request, request.startsWith("CONNECT example.com:443 ")); + assertTrue("No custom header: " + request, + request.contains("X-Custom-Header: custom-value")); + assertTrue("No proxy authorization: " + request, + request.contains("Proxy-Authorization: Bearer token123")); + + assertFalse(negotiationFuture.isDone()); + serverContext.writeAndFlush(bb("HTTP/1.1 200 OK\r\n\r\n", serverContext.channel())).sync(); + negotiationFuture.sync(); + + channel.eventLoop().submit(NOOP_RUNNABLE).sync(); + objectCaptor = ArgumentCaptor.forClass(Object.class); + Mockito.verify(mockHandler, times(2)) + .channelRead(ArgumentMatchers.any(), objectCaptor.capture()); + b = (ByteBuf) objectCaptor.getAllValues().get(1); + String preface = b.toString(UTF_8); + b.release(); + assertEquals(golden, preface); + + channel.close(); + } + @Test public void waitUntilActiveHandler_firesNegotiation() throws Exception { EventLoopGroup elg = new DefaultEventLoopGroup(1); @@ -1228,7 +1353,8 @@ public void clientTlsHandler_firesNegotiation() throws Exception { serverSslContext = GrpcSslContexts.forServer(server1Chain, server1Key).build(); } FakeGrpcHttp2ConnectionHandler gh = FakeGrpcHttp2ConnectionHandler.newHandler(); - ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext, null); + ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext, + null, Optional.absent(), null, null); WriteBufferingAndExceptionHandler clientWbaeh = new WriteBufferingAndExceptionHandler(pn.newHandler(gh)); diff --git a/netty/src/test/java/io/grpc/netty/TcpMetricsTest.java b/netty/src/test/java/io/grpc/netty/TcpMetricsTest.java new file mode 100644 index 00000000000..f75a98b46df --- /dev/null +++ b/netty/src/test/java/io/grpc/netty/TcpMetricsTest.java @@ -0,0 +1,616 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.netty; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +import io.grpc.InternalTcpMetrics; +import io.grpc.MetricRecorder; +import io.netty.util.concurrent.ScheduledFuture; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.TimeUnit; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class TcpMetricsTest { + + @Rule + public final MockitoRule mocks = MockitoJUnit.rule(); + + @Mock + private MetricRecorder metricRecorder; + + private ConfigurableFakeWithTcpInfo channel; + private TcpMetrics metrics; + + @Before + public void setUp() throws Exception { + FakeEpollTcpInfo dummyInfo = new FakeEpollTcpInfo(); + channel = new ConfigurableFakeWithTcpInfo(dummyInfo); + metrics = new TcpMetrics(metricRecorder); + } + + @After + public void tearDown() throws Exception { + TcpMetrics.epollInfo = TcpMetrics.loadEpollInfo(); + } + + @Test + public void metricsInitialization() { + + assertNotNull(InternalTcpMetrics.CONNECTIONS_CREATED_INSTRUMENT); + assertNotNull(InternalTcpMetrics.CONNECTION_COUNT_INSTRUMENT); + assertNotNull(InternalTcpMetrics.PACKETS_RETRANSMITTED_INSTRUMENT); + assertNotNull(InternalTcpMetrics.RECURRING_RETRANSMITS_INSTRUMENT); + assertNotNull(InternalTcpMetrics.MIN_RTT_INSTRUMENT); + } + + public static class FakeEpollTcpInfo { + long totalRetrans; + long retransmits; + long rtt; + + public void setValues(long totalRetrans, long retransmits, long rtt) { + this.totalRetrans = totalRetrans; + this.retransmits = retransmits; + this.rtt = rtt; + } + + @SuppressWarnings("unused") + public long totalRetrans() { + return totalRetrans; + } + + @SuppressWarnings("unused") + public long retrans() { + return retransmits; + } + + @SuppressWarnings("unused") + public long rtt() { + return rtt; + } + } + + @Test + public void tracker_recordTcpInfo_reflectionSuccess() throws Exception { + MetricRecorder recorder = mock(MetricRecorder.class); + TcpMetrics.epollInfo = new TcpMetrics.EpollInfo( + ConfigurableFakeWithTcpInfo.class, + FakeEpollTcpInfo.class.getConstructor(), + ConfigurableFakeWithTcpInfo.class.getMethod("tcpInfo", FakeEpollTcpInfo.class), + FakeEpollTcpInfo.class.getMethod("totalRetrans"), + FakeEpollTcpInfo.class.getMethod("retrans"), + FakeEpollTcpInfo.class.getMethod("rtt")); + TcpMetrics tracker = new TcpMetrics(recorder); + + FakeEpollTcpInfo infoSource = new FakeEpollTcpInfo(); + infoSource.setValues(123, 4, 5000); + ConfigurableFakeWithTcpInfo channel = new ConfigurableFakeWithTcpInfo(infoSource); + channel.writeInbound("dummy"); + + tracker.channelInactive(channel); + + verify(recorder).addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.PACKETS_RETRANSMITTED_INSTRUMENT)), + eq(123L), any(), any()); + verify(recorder).addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.RECURRING_RETRANSMITS_INSTRUMENT)), + eq(4L), any(), any()); + verify(recorder).recordDoubleHistogram( + eq(Objects.requireNonNull(InternalTcpMetrics.MIN_RTT_INSTRUMENT)), + eq(0.005), any(), any()); + } + + @Test + public void tracker_periodicRecord_doesNotRecordRecurringRetransmits() throws Exception { + MetricRecorder recorder = mock(MetricRecorder.class); + TcpMetrics.epollInfo = new TcpMetrics.EpollInfo( + ConfigurableFakeWithTcpInfo.class, + FakeEpollTcpInfo.class.getConstructor(), + ConfigurableFakeWithTcpInfo.class.getMethod("tcpInfo", FakeEpollTcpInfo.class), + FakeEpollTcpInfo.class.getMethod("totalRetrans"), + FakeEpollTcpInfo.class.getMethod("retrans"), + FakeEpollTcpInfo.class.getMethod("rtt")); + TcpMetrics tracker = new TcpMetrics(recorder); + + FakeEpollTcpInfo infoSource = new FakeEpollTcpInfo(); + infoSource.setValues(123, 4, 5000); + ConfigurableFakeWithTcpInfo channel = new ConfigurableFakeWithTcpInfo(infoSource); + + tracker.channelActive(channel); + + ScheduledFuture timer = tracker.getReportTimer(); + assertNotNull("Timer should be scheduled", timer); + + long delay = timer.getDelay(TimeUnit.MILLISECONDS); + channel.advanceTimeBy(delay + 1, TimeUnit.MILLISECONDS); + channel.runScheduledPendingTasks(); + + verify(recorder).addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.PACKETS_RETRANSMITTED_INSTRUMENT)), + eq(123L), any(), any()); + verify(recorder).recordDoubleHistogram( + eq(Objects.requireNonNull(InternalTcpMetrics.MIN_RTT_INSTRUMENT)), + eq(0.005), any(), any()); + // Should NOT record recurring retransmits during periodic polling + verify(recorder, org.mockito.Mockito.never()) + .addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.RECURRING_RETRANSMITS_INSTRUMENT)), + anyLong(), any(), any()); + } + + @Test + public void tracker_channelInactive_recordsRecurringRetransmits_raw_notDelta() throws Exception { + MetricRecorder recorder = mock(MetricRecorder.class); + TcpMetrics.epollInfo = new TcpMetrics.EpollInfo( + ConfigurableFakeWithTcpInfo.class, + FakeEpollTcpInfo.class.getConstructor(), + ConfigurableFakeWithTcpInfo.class.getMethod("tcpInfo", FakeEpollTcpInfo.class), + FakeEpollTcpInfo.class.getMethod("totalRetrans"), + FakeEpollTcpInfo.class.getMethod("retrans"), + FakeEpollTcpInfo.class.getMethod("rtt")); + TcpMetrics tracker = new TcpMetrics(recorder); + + FakeEpollTcpInfo infoSource = new FakeEpollTcpInfo(); + infoSource.setValues(123, 4, 5000); + ConfigurableFakeWithTcpInfo channel = new ConfigurableFakeWithTcpInfo(infoSource); + + tracker.channelActive(channel); + + ScheduledFuture timer = tracker.getReportTimer(); + assertNotNull("Timer should be scheduled", timer); + + long delay = timer.getDelay(TimeUnit.MILLISECONDS); + channel.advanceTimeBy(delay + 1, TimeUnit.MILLISECONDS); + channel.runScheduledPendingTasks(); + + org.mockito.Mockito.clearInvocations(recorder); + + // Let's just create a new channel instance where tcpInfo sets retrans=5. + FakeEpollTcpInfo infoSource2 = new FakeEpollTcpInfo(); + infoSource2.setValues(130, 5, 5000); + ConfigurableFakeWithTcpInfo channel2 = new ConfigurableFakeWithTcpInfo(infoSource2); + + tracker.channelInactive(channel2); + + // It should record delta for totalRetrans (130 - 123 = 7) + verify(recorder).addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.PACKETS_RETRANSMITTED_INSTRUMENT)), + eq(7L), any(), any()); + // But for recurringRetransmits it MUST record the raw value 5, not the delta! + verify(recorder).addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.RECURRING_RETRANSMITS_INSTRUMENT)), + eq(5L), any(), any()); + } + + @Test + public void tracker_periodicRecord_reportsDeltaForTotalRetrans() throws Exception { + MetricRecorder recorder = mock(MetricRecorder.class); + TcpMetrics.epollInfo = new TcpMetrics.EpollInfo( + ConfigurableFakeWithTcpInfo.class, + FakeEpollTcpInfo.class.getConstructor(), + ConfigurableFakeWithTcpInfo.class.getMethod("tcpInfo", FakeEpollTcpInfo.class), + FakeEpollTcpInfo.class.getMethod("totalRetrans"), + FakeEpollTcpInfo.class.getMethod("retrans"), + FakeEpollTcpInfo.class.getMethod("rtt")); + TcpMetrics tracker = new TcpMetrics(recorder); + + FakeEpollTcpInfo infoSource = new FakeEpollTcpInfo(); + infoSource.setValues(123, 4, 5000); + ConfigurableFakeWithTcpInfo channel = new ConfigurableFakeWithTcpInfo(infoSource); + + tracker.channelActive(channel); + + ScheduledFuture timer = tracker.getReportTimer(); + assertNotNull("Timer should be scheduled", timer); + + long delay = timer.getDelay(TimeUnit.MILLISECONDS); + channel.advanceTimeBy(delay + 1, TimeUnit.MILLISECONDS); + channel.runScheduledPendingTasks(); + + verify(recorder).addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.PACKETS_RETRANSMITTED_INSTRUMENT)), + eq(123L), any(), any()); + + org.mockito.Mockito.clearInvocations(recorder); + + // Change tcpInfo for second periodic record + infoSource.setValues(150, 2, 6000); // 150 - 123 = 27 + + ScheduledFuture newTimer = tracker.getReportTimer(); + assertNotNull("New timer should be scheduled", newTimer); + assertNotSame("Timer should be a new instance", timer, newTimer); + long newDelay = newTimer.getDelay(TimeUnit.MILLISECONDS); + channel.advanceTimeBy(newDelay + 1, TimeUnit.MILLISECONDS); + channel.runScheduledPendingTasks(); + + // Only the delta (150 - 123 = 27) should be recorded + verify(recorder).addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.PACKETS_RETRANSMITTED_INSTRUMENT)), + eq(27L), any(), any()); + verify(recorder).recordDoubleHistogram( + eq(Objects.requireNonNull(InternalTcpMetrics.MIN_RTT_INSTRUMENT)), + eq(0.006), any(), any()); + verify(recorder, org.mockito.Mockito.never()) + .addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.RECURRING_RETRANSMITS_INSTRUMENT)), + anyLong(), any(), any()); + } + + @Test + public void tracker_periodicRecord_doesNotReportZeroDeltaForTotalRetrans() throws Exception { + MetricRecorder recorder = mock(MetricRecorder.class); + TcpMetrics.epollInfo = new TcpMetrics.EpollInfo( + ConfigurableFakeWithTcpInfo.class, + FakeEpollTcpInfo.class.getConstructor(), + ConfigurableFakeWithTcpInfo.class.getMethod("tcpInfo", FakeEpollTcpInfo.class), + FakeEpollTcpInfo.class.getMethod("totalRetrans"), + FakeEpollTcpInfo.class.getMethod("retrans"), + FakeEpollTcpInfo.class.getMethod("rtt")); + TcpMetrics tracker = new TcpMetrics(recorder); + + FakeEpollTcpInfo infoSource = new FakeEpollTcpInfo(); + infoSource.setValues(123, 4, 5000); + ConfigurableFakeWithTcpInfo channel = new ConfigurableFakeWithTcpInfo(infoSource); + + tracker.channelActive(channel); + + ScheduledFuture timer = tracker.getReportTimer(); + assertNotNull("Timer should be scheduled", timer); + + long delay = timer.getDelay(TimeUnit.MILLISECONDS); + channel.advanceTimeBy(delay + 1, TimeUnit.MILLISECONDS); + channel.runScheduledPendingTasks(); + + verify(recorder).addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.PACKETS_RETRANSMITTED_INSTRUMENT)), + eq(123L), any(), any()); + + org.mockito.Mockito.clearInvocations(recorder); + + // Keep tcpInfo the same for second periodic record + ScheduledFuture newTimer = tracker.getReportTimer(); + assertNotNull("New timer should be scheduled", newTimer); + assertNotSame("Timer should be a new instance", timer, newTimer); + long newDelay = newTimer.getDelay(TimeUnit.MILLISECONDS); + channel.advanceTimeBy(newDelay + 1, TimeUnit.MILLISECONDS); + channel.runScheduledPendingTasks(); + + // NO delta (123 - 123 = 0), so it should not be recorded + verify(recorder, org.mockito.Mockito.never()) + .addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.PACKETS_RETRANSMITTED_INSTRUMENT)), + anyLong(), any(), any()); + + // MIN_RTT should be recorded again! + verify(recorder).recordDoubleHistogram( + eq(Objects.requireNonNull(InternalTcpMetrics.MIN_RTT_INSTRUMENT)), + eq(0.005), any(), any()); + } + + public static class ConfigurableFakeWithTcpInfo extends + io.netty.channel.embedded.EmbeddedChannel { + private final FakeEpollTcpInfo infoToCopy; + + public ConfigurableFakeWithTcpInfo(FakeEpollTcpInfo infoToCopy) { + this.infoToCopy = infoToCopy; + } + + public void tcpInfo(FakeEpollTcpInfo info) { + info.totalRetrans = infoToCopy.totalRetrans; + info.retransmits = infoToCopy.retransmits; + info.rtt = infoToCopy.rtt; + } + } + + private static class AddressOverrideEmbeddedChannel extends + io.netty.channel.embedded.EmbeddedChannel { + private final SocketAddress local; + private final SocketAddress remote; + + public AddressOverrideEmbeddedChannel(SocketAddress local, SocketAddress remote) { + this.local = local; + this.remote = remote; + } + + @Override + public SocketAddress localAddress() { + return local; + } + + @Override + public SocketAddress remoteAddress() { + return remote; + } + } + + @Test + public void tracker_reportsDeltas_correctly() throws Exception { + MetricRecorder recorder = mock(MetricRecorder.class); + + TcpMetrics.epollInfo = new TcpMetrics.EpollInfo( + ConfigurableFakeWithTcpInfo.class, + FakeEpollTcpInfo.class.getConstructor(), + ConfigurableFakeWithTcpInfo.class.getMethod("tcpInfo", FakeEpollTcpInfo.class), + FakeEpollTcpInfo.class.getMethod("totalRetrans"), + FakeEpollTcpInfo.class.getMethod("retrans"), + FakeEpollTcpInfo.class.getMethod("rtt")); + TcpMetrics tracker = new TcpMetrics(recorder); + + FakeEpollTcpInfo infoSource = new FakeEpollTcpInfo(); + ConfigurableFakeWithTcpInfo channel = new ConfigurableFakeWithTcpInfo(infoSource); + + // 10 retransmits total + infoSource.setValues(10, 2, 1000); + tracker.recordTcpInfo(channel); + + verify(recorder).addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.PACKETS_RETRANSMITTED_INSTRUMENT)), + eq(10L), any(), any()); + + // 15 retransmits total (delta 5) + infoSource.setValues(15, 0, 1000); + tracker.recordTcpInfo(channel); + + verify(recorder).addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.PACKETS_RETRANSMITTED_INSTRUMENT)), + eq(5L), any(), any()); + + // 15 retransmits total (delta 0) - should NOT report + // also set retransmits to 1 + infoSource.setValues(15, 1, 1000); + tracker.recordTcpInfo(channel); + // Verify no new interactions with this specific metric and value + // We can't easily verify "no interaction" for specific value without capturing. + verify(recorder, org.mockito.Mockito.times(1)).addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.PACKETS_RETRANSMITTED_INSTRUMENT)), + eq(10L), any(), any()); + verify(recorder, org.mockito.Mockito.times(1)).addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.PACKETS_RETRANSMITTED_INSTRUMENT)), + eq(5L), any(), any()); + // Total interactions for packetsRetransmitted should be 2 + verify(recorder, org.mockito.Mockito.times(2)).addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.PACKETS_RETRANSMITTED_INSTRUMENT)), + anyLong(), any(), any()); + + // recurringRetransmits should NOT have been reported yet (periodic calls) + verify(recorder, org.mockito.Mockito.times(0)).addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.RECURRING_RETRANSMITS_INSTRUMENT)), + anyLong(), any(), any()); + + // Close channel - should report recurringRetransmits + tracker.channelInactive(channel); + verify(recorder, org.mockito.Mockito.times(1)).addLongCounter( + eq(Objects.requireNonNull(InternalTcpMetrics.RECURRING_RETRANSMITS_INSTRUMENT)), + eq(1L), // From last infoSource setValues(15, 1, 1000) + any(), any()); + } + + @Test + public void tracker_recordTcpInfo_reflectionFailure() { + MetricRecorder recorder = mock(MetricRecorder.class); + + TcpMetrics.epollInfo = null; + TcpMetrics tracker = new TcpMetrics(recorder); + + io.netty.channel.embedded.EmbeddedChannel channel = new + io.netty.channel.embedded.EmbeddedChannel(); + + // Should catch exception and ignore + tracker.channelInactive(channel); + } + + @Test + public void registeredMetrics_haveCorrectOptionalLabels() { + List expectedOptionalLabels = Arrays.asList( + "network.local.address", + "network.local.port", + "network.peer.address", + "network.peer.port"); + + assertEquals( + expectedOptionalLabels, + InternalTcpMetrics.CONNECTIONS_CREATED_INSTRUMENT.getOptionalLabelKeys()); + assertEquals( + expectedOptionalLabels, + InternalTcpMetrics.CONNECTION_COUNT_INSTRUMENT.getOptionalLabelKeys()); + + assertEquals( + expectedOptionalLabels, + Objects.requireNonNull(InternalTcpMetrics.PACKETS_RETRANSMITTED_INSTRUMENT) + .getOptionalLabelKeys()); + assertEquals( + expectedOptionalLabels, + Objects.requireNonNull(InternalTcpMetrics.RECURRING_RETRANSMITS_INSTRUMENT) + .getOptionalLabelKeys()); + assertEquals( + expectedOptionalLabels, + Objects.requireNonNull(InternalTcpMetrics.MIN_RTT_INSTRUMENT).getOptionalLabelKeys()); + } + + @Test + public void channelActive_extractsLabels_ipv4() throws Exception { + InetAddress localInet = InetAddress.getByAddress(new byte[] {127, 0, 0, 1}); + InetAddress remoteInet = InetAddress.getByAddress(new byte[] {127, 0, 0, 2}); + + AddressOverrideEmbeddedChannel channel = new AddressOverrideEmbeddedChannel( + new InetSocketAddress(localInet, 8080), + new InetSocketAddress(remoteInet, 443)); + + metrics.channelActive(channel); + + verify(metricRecorder).addLongCounter( + eq(InternalTcpMetrics.CONNECTIONS_CREATED_INSTRUMENT), eq(1L), + eq(Collections.emptyList()), + eq(Arrays.asList("127.0.0.1", "8080", "127.0.0.2", "443"))); + verify(metricRecorder).addLongUpDownCounter( + eq(InternalTcpMetrics.CONNECTION_COUNT_INSTRUMENT), eq(1L), + eq(Collections.emptyList()), + eq(Arrays.asList("127.0.0.1", "8080", "127.0.0.2", "443"))); + verifyNoMoreInteractions(metricRecorder); + } + + @Test + public void channelInactive_extractsLabels_ipv6() throws Exception { + InetAddress localInet = InetAddress.getByAddress(new byte[] {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1}); + InetAddress remoteInet = InetAddress.getByAddress(new byte[] {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2}); + + AddressOverrideEmbeddedChannel channel = new AddressOverrideEmbeddedChannel( + new InetSocketAddress(localInet, 8080), + new InetSocketAddress(remoteInet, 443)); + + metrics.channelInactive(channel); + + verify(metricRecorder).addLongUpDownCounter( + eq(InternalTcpMetrics.CONNECTION_COUNT_INSTRUMENT), eq(-1L), + eq(Collections.emptyList()), + eq(Arrays.asList("0:0:0:0:0:0:0:1", "8080", "0:0:0:0:0:0:0:2", "443"))); + verifyNoMoreInteractions(metricRecorder); + } + + @Test + public void channelActive_extractsLabels_nonInetAddress() { + SocketAddress dummyAddress = new SocketAddress() { + }; + AddressOverrideEmbeddedChannel channel = new AddressOverrideEmbeddedChannel( + dummyAddress, dummyAddress); + + metrics.channelActive(channel); + + verify(metricRecorder).addLongCounter( + eq(InternalTcpMetrics.CONNECTIONS_CREATED_INSTRUMENT), eq(1L), + eq(Collections.emptyList()), + eq(Arrays.asList("", "", "", ""))); + verify(metricRecorder).addLongUpDownCounter( + eq(InternalTcpMetrics.CONNECTION_COUNT_INSTRUMENT), eq(1L), + eq(Collections.emptyList()), + eq(Arrays.asList("", "", "", ""))); + verifyNoMoreInteractions(metricRecorder); + } + + @Test + public void channelActive_incrementsCounts() { + metrics.channelActive(channel); + verify(metricRecorder).addLongCounter( + eq(InternalTcpMetrics.CONNECTIONS_CREATED_INSTRUMENT), eq(1L), + eq(Collections.emptyList()), + eq(Arrays.asList("", "", "", ""))); + verify(metricRecorder).addLongUpDownCounter( + eq(InternalTcpMetrics.CONNECTION_COUNT_INSTRUMENT), eq(1L), + eq(Collections.emptyList()), + eq(Arrays.asList("", "", "", ""))); + verifyNoMoreInteractions(metricRecorder); + } + + @Test + public void channelInactive_decrementsCount_noEpoll_noError() { + metrics.channelInactive(channel); + verify(metricRecorder).addLongUpDownCounter( + eq(InternalTcpMetrics.CONNECTION_COUNT_INSTRUMENT), eq(-1L), + eq(Collections.emptyList()), + eq(Arrays.asList("", "", "", ""))); + verifyNoMoreInteractions(metricRecorder); + } + + @Test + public void channelActive_schedulesReportTimer() throws Exception { + TcpMetrics.epollInfo = new TcpMetrics.EpollInfo( + ConfigurableFakeWithTcpInfo.class, + FakeEpollTcpInfo.class.getConstructor(), + ConfigurableFakeWithTcpInfo.class.getMethod("tcpInfo", FakeEpollTcpInfo.class), + FakeEpollTcpInfo.class.getMethod("totalRetrans"), + FakeEpollTcpInfo.class.getMethod("retrans"), + FakeEpollTcpInfo.class.getMethod("rtt")); + + metrics = new TcpMetrics(metricRecorder); + + FakeEpollTcpInfo infoSource = new FakeEpollTcpInfo(); + ConfigurableFakeWithTcpInfo channel = new ConfigurableFakeWithTcpInfo(infoSource); + + metrics.channelActive(channel); + + ScheduledFuture timer = metrics.getReportTimer(); + assertNotNull("Timer should be scheduled", timer); + + long delay = timer.getDelay(TimeUnit.MILLISECONDS); + assertTrue("Delay should be >= 30000 but was " + delay, delay >= 30_000); + assertTrue("Delay should be <= 330000 but was " + delay, delay <= 330_000); + + // Advance time to trigger the task + channel.advanceTimeBy(delay + 1, TimeUnit.MILLISECONDS); + channel.runScheduledPendingTasks(); + + // Verify rescheduling + ScheduledFuture newTimer = metrics.getReportTimer(); + assertNotNull("New timer should be scheduled", newTimer); + assertNotSame("Timer should be a new instance", timer, newTimer); + + long newDelay = newTimer.getDelay(TimeUnit.MILLISECONDS); + // Re-arming jitter is 90% to 110%, so 270,000 ms to 330,000 ms + assertTrue("Delay should be >= 270000 but was " + newDelay, newDelay >= 270_000); + assertTrue("Delay should be <= 330000 but was " + newDelay, newDelay <= 330_000); + } + + @Test + public void channelInactive_cancelsReportTimer() throws Exception { + TcpMetrics.epollInfo = new TcpMetrics.EpollInfo( + ConfigurableFakeWithTcpInfo.class, + FakeEpollTcpInfo.class.getConstructor(), + ConfigurableFakeWithTcpInfo.class.getMethod("tcpInfo", FakeEpollTcpInfo.class), + FakeEpollTcpInfo.class.getMethod("totalRetrans"), + FakeEpollTcpInfo.class.getMethod("retrans"), + FakeEpollTcpInfo.class.getMethod("rtt")); + + metrics = new TcpMetrics(metricRecorder); + + FakeEpollTcpInfo infoSource = new FakeEpollTcpInfo(); + ConfigurableFakeWithTcpInfo channel = new ConfigurableFakeWithTcpInfo(infoSource); + + metrics.channelActive(channel); + + ScheduledFuture timer = metrics.getReportTimer(); + assertNotNull("Timer should be scheduled", timer); + + metrics.channelInactive(channel); + + assertTrue("Timer should be cancelled", timer.isCancelled()); + } +} diff --git a/netty/src/test/java/io/grpc/netty/UdsNameResolverProviderTest.java b/netty/src/test/java/io/grpc/netty/UdsNameResolverProviderTest.java index 6a329c8fc68..1766a8e4134 100644 --- a/netty/src/test/java/io/grpc/netty/UdsNameResolverProviderTest.java +++ b/netty/src/test/java/io/grpc/netty/UdsNameResolverProviderTest.java @@ -17,19 +17,30 @@ package io.grpc.netty; import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.TruthJUnit.assume; import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import io.grpc.ChannelLogger; import io.grpc.EquivalentAddressGroup; import io.grpc.NameResolver; +import io.grpc.NameResolver.ServiceConfigParser; +import io.grpc.SynchronizationContext; +import io.grpc.Uri; +import io.grpc.internal.FakeClock; +import io.grpc.internal.GrpcUtil; import io.netty.channel.unix.DomainSocketAddress; import java.net.SocketAddress; import java.net.URI; +import java.util.Arrays; import java.util.List; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; @@ -37,8 +48,16 @@ import org.mockito.junit.MockitoRule; /** Unit tests for {@link UdsNameResolverProvider}. */ -@RunWith(JUnit4.class) +@RunWith(Parameterized.class) public class UdsNameResolverProviderTest { + private static final int DEFAULT_PORT = 887; + + @Parameters(name = "enableRfc3986UrisParam={0}") + public static Iterable data() { + return Arrays.asList(new Object[][] {{true}, {false}}); + } + + @Parameter public boolean enableRfc3986UrisParam; @Rule public final MockitoRule mocks = MockitoJUnit.rule(); @@ -51,56 +70,81 @@ public class UdsNameResolverProviderTest { UdsNameResolverProvider udsNameResolverProvider = new UdsNameResolverProvider(); + private final SynchronizationContext syncContext = new SynchronizationContext( + (t, e) -> { + throw new AssertionError(e); + }); + private final FakeClock fakeExecutor = new FakeClock(); + private final NameResolver.Args args = NameResolver.Args.newBuilder() + .setDefaultPort(DEFAULT_PORT) + .setProxyDetector(GrpcUtil.DEFAULT_PROXY_DETECTOR) + .setSynchronizationContext(syncContext) + .setServiceConfigParser(mock(ServiceConfigParser.class)) + .setChannelLogger(mock(ChannelLogger.class)) + .setScheduledExecutorService(fakeExecutor.getScheduledExecutorService()) + .build(); @Test public void testUnixRelativePath() { - UdsNameResolver udsNameResolver = - udsNameResolverProvider.newNameResolver(URI.create("unix:sock.sock"), null); - assertThat(udsNameResolver).isNotNull(); - udsNameResolver.start(mockListener); - verify(mockListener).onResult(resultCaptor.capture()); - NameResolver.ResolutionResult result = resultCaptor.getValue(); - List list = result.getAddresses(); - assertThat(list).isNotNull(); - assertThat(list).hasSize(1); - EquivalentAddressGroup eag = list.get(0); - assertThat(eag).isNotNull(); - List addresses = eag.getAddresses(); - assertThat(addresses).hasSize(1); - assertThat(addresses.get(0)).isInstanceOf(DomainSocketAddress.class); - DomainSocketAddress domainSocketAddress = (DomainSocketAddress) addresses.get(0); + UdsNameResolver udsNameResolver = newNameResolver("unix:sock.sock", args); + DomainSocketAddress domainSocketAddress = startAndGetUniqueResolvedAddress(udsNameResolver); assertThat(domainSocketAddress.path()).isEqualTo("sock.sock"); } @Test public void testUnixAbsolutePath() { - UdsNameResolver udsNameResolver = - udsNameResolverProvider.newNameResolver(URI.create("unix:/sock.sock"), null); - assertThat(udsNameResolver).isNotNull(); - udsNameResolver.start(mockListener); - verify(mockListener).onResult(resultCaptor.capture()); - NameResolver.ResolutionResult result = resultCaptor.getValue(); - List list = result.getAddresses(); - assertThat(list).isNotNull(); - assertThat(list).hasSize(1); - EquivalentAddressGroup eag = list.get(0); - assertThat(eag).isNotNull(); - List addresses = eag.getAddresses(); - assertThat(addresses).hasSize(1); - assertThat(addresses.get(0)).isInstanceOf(DomainSocketAddress.class); - DomainSocketAddress domainSocketAddress = (DomainSocketAddress) addresses.get(0); + UdsNameResolver udsNameResolver = newNameResolver("unix:/sock.sock", args); + DomainSocketAddress domainSocketAddress = startAndGetUniqueResolvedAddress(udsNameResolver); assertThat(domainSocketAddress.path()).isEqualTo("/sock.sock"); } @Test public void testUnixAbsoluteAlternatePath() { - UdsNameResolver udsNameResolver = - udsNameResolverProvider.newNameResolver(URI.create("unix:///sock.sock"), null); + UdsNameResolver udsNameResolver = newNameResolver("unix:///sock.sock", args); + DomainSocketAddress domainSocketAddress = startAndGetUniqueResolvedAddress(udsNameResolver); + assertThat(domainSocketAddress.path()).isEqualTo("/sock.sock"); + } + + @Test + public void testUnixPathWithAuthority() { + try { + newNameResolver("unix://localhost/sock.sock", args); + fail("exception expected"); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().isEqualTo("authority not supported: localhost"); + } + } + + @Test + public void testUnixAbsolutePathDoesNotIncludeQueryOrFragment() { + UdsNameResolver udsNameResolver = newNameResolver("unix:///sock.sock?query#fragment", args); + DomainSocketAddress domainSocketAddress = startAndGetUniqueResolvedAddress(udsNameResolver); + assertThat(domainSocketAddress.path()).isEqualTo("/sock.sock"); + } + + @Test + public void testUnixRelativePathDoesNotIncludeQueryOrFragment() { + // This test fails without RFC 3986 support because of a bug in the legacy java.net.URI-based + // NRP implementation. + assume().that(enableRfc3986UrisParam).isTrue(); + + UdsNameResolver udsNameResolver = newNameResolver("unix:sock.sock?query#fragment", args); + DomainSocketAddress domainSocketAddress = startAndGetUniqueResolvedAddress(udsNameResolver); + assertThat(domainSocketAddress.path()).isEqualTo("sock.sock"); + } + + private UdsNameResolver newNameResolver(String uriString, NameResolver.Args args) { + return enableRfc3986UrisParam + ? (UdsNameResolver) udsNameResolverProvider.newNameResolver(Uri.create(uriString), args) + : udsNameResolverProvider.newNameResolver(URI.create(uriString), args); + } + + private DomainSocketAddress startAndGetUniqueResolvedAddress(UdsNameResolver udsNameResolver) { assertThat(udsNameResolver).isNotNull(); udsNameResolver.start(mockListener); - verify(mockListener).onResult(resultCaptor.capture()); + verify(mockListener).onResult2(resultCaptor.capture()); NameResolver.ResolutionResult result = resultCaptor.getValue(); - List list = result.getAddresses(); + List list = result.getAddressesOrError().getValue(); assertThat(list).isNotNull(); assertThat(list).hasSize(1); EquivalentAddressGroup eag = list.get(0); @@ -108,17 +152,6 @@ public void testUnixAbsoluteAlternatePath() { List addresses = eag.getAddresses(); assertThat(addresses).hasSize(1); assertThat(addresses.get(0)).isInstanceOf(DomainSocketAddress.class); - DomainSocketAddress domainSocketAddress = (DomainSocketAddress) addresses.get(0); - assertThat(domainSocketAddress.path()).isEqualTo("/sock.sock"); - } - - @Test - public void testUnixPathWithAuthority() { - try { - udsNameResolverProvider.newNameResolver(URI.create("unix://localhost/sock.sock"), null); - fail("exception expected"); - } catch (IllegalArgumentException e) { - assertThat(e).hasMessageThat().isEqualTo("non-null authority not supported"); - } + return (DomainSocketAddress) addresses.get(0); } } diff --git a/netty/src/test/java/io/grpc/netty/UdsNameResolverTest.java b/netty/src/test/java/io/grpc/netty/UdsNameResolverTest.java index 8eb010e23e5..7bf808c18ce 100644 --- a/netty/src/test/java/io/grpc/netty/UdsNameResolverTest.java +++ b/netty/src/test/java/io/grpc/netty/UdsNameResolverTest.java @@ -18,10 +18,16 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import io.grpc.ChannelLogger; import io.grpc.EquivalentAddressGroup; import io.grpc.NameResolver; +import io.grpc.NameResolver.ServiceConfigParser; +import io.grpc.SynchronizationContext; +import io.grpc.internal.FakeClock; +import io.grpc.internal.GrpcUtil; import io.netty.channel.unix.DomainSocketAddress; import java.net.SocketAddress; import java.util.List; @@ -41,7 +47,20 @@ public class UdsNameResolverTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); - + private static final int DEFAULT_PORT = 887; + private final FakeClock fakeExecutor = new FakeClock(); + private final SynchronizationContext syncContext = new SynchronizationContext( + (t, e) -> { + throw new AssertionError(e); + }); + private final NameResolver.Args args = NameResolver.Args.newBuilder() + .setDefaultPort(DEFAULT_PORT) + .setProxyDetector(GrpcUtil.DEFAULT_PROXY_DETECTOR) + .setSynchronizationContext(syncContext) + .setServiceConfigParser(mock(ServiceConfigParser.class)) + .setChannelLogger(mock(ChannelLogger.class)) + .setScheduledExecutorService(fakeExecutor.getScheduledExecutorService()) + .build(); @Mock private NameResolver.Listener2 mockListener; @@ -52,11 +71,11 @@ public class UdsNameResolverTest { @Test public void testValidTargetPath() { - udsNameResolver = new UdsNameResolver(null, "sock.sock"); + udsNameResolver = new UdsNameResolver(null, "sock.sock", args); udsNameResolver.start(mockListener); - verify(mockListener).onResult(resultCaptor.capture()); + verify(mockListener).onResult2(resultCaptor.capture()); NameResolver.ResolutionResult result = resultCaptor.getValue(); - List list = result.getAddresses(); + List list = result.getAddressesOrError().getValue(); assertThat(list).isNotNull(); assertThat(list).hasSize(1); EquivalentAddressGroup eag = list.get(0); @@ -72,10 +91,10 @@ public void testValidTargetPath() { @Test public void testNonNullAuthority() { try { - udsNameResolver = new UdsNameResolver("authority", "sock.sock"); + udsNameResolver = new UdsNameResolver("somehost", "sock.sock", args); fail("exception expected"); } catch (IllegalArgumentException e) { - assertThat(e).hasMessageThat().isEqualTo("non-null authority not supported"); + assertThat(e).hasMessageThat().isEqualTo("authority not supported: somehost"); } } } diff --git a/okhttp/BUILD.bazel b/okhttp/BUILD.bazel index 30a77b11465..74a9f7a4300 100644 --- a/okhttp/BUILD.bazel +++ b/okhttp/BUILD.bazel @@ -1,3 +1,6 @@ +load("@rules_java//java:defs.bzl", "java_library") +load("@rules_jvm_external//:defs.bzl", "artifact") + java_library( name = "okhttp", srcs = glob([ @@ -12,12 +15,11 @@ java_library( "//api", "//core:internal", "//util", - "@com_google_code_findbugs_jsr305//jar", - "@com_google_errorprone_error_prone_annotations//jar", - "@com_google_guava_guava//jar", - "@com_google_j2objc_j2objc_annotations//jar", - "@com_squareup_okhttp_okhttp//jar", - "@com_squareup_okio_okio//jar", - "@io_perfmark_perfmark_api//jar", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:guava"), + artifact("com.squareup.okhttp:okhttp"), + artifact("com.squareup.okio:okio"), + artifact("io.perfmark:perfmark-api"), ], ) diff --git a/okhttp/build.gradle b/okhttp/build.gradle index 063e4775de1..6c542feec9c 100644 --- a/okhttp/build.gradle +++ b/okhttp/build.gradle @@ -31,8 +31,16 @@ dependencies { project(':grpc-testing-proto'), libraries.netty.codec.http2, libraries.okhttp - signature libraries.signature.java - signature libraries.signature.android + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } + signature (libraries.signature.android) { + artifact { + extension = "signature" + } + } } project.sourceSets { diff --git a/okhttp/src/main/java/io/grpc/okhttp/AsyncSink.java b/okhttp/src/main/java/io/grpc/okhttp/AsyncSink.java index 1ac64d7ebb5..01ee23b905c 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/AsyncSink.java +++ b/okhttp/src/main/java/io/grpc/okhttp/AsyncSink.java @@ -19,6 +19,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.internal.SerializingExecutor; import io.grpc.okhttp.ExceptionHandlingFrameWriter.TransportExceptionHandler; import io.grpc.okhttp.internal.framed.ErrorCode; @@ -30,7 +31,6 @@ import java.io.IOException; import java.net.Socket; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; import okio.Buffer; import okio.Sink; import okio.Timeout; diff --git a/okhttp/src/main/java/io/grpc/okhttp/NoopSslSocket.java b/okhttp/src/main/java/io/grpc/okhttp/NoopSslSocket.java new file mode 100644 index 00000000000..6e6a6f12a39 --- /dev/null +++ b/okhttp/src/main/java/io/grpc/okhttp/NoopSslSocket.java @@ -0,0 +1,117 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp; + +import java.io.IOException; +import javax.net.ssl.HandshakeCompletedListener; +import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSocket; + +/** A no-op ssl socket, to facilitate overriding only the required methods in specific + * implementations. + */ +class NoopSslSocket extends SSLSocket { + @Override + public String[] getSupportedCipherSuites() { + return new String[0]; + } + + @Override + public String[] getEnabledCipherSuites() { + return new String[0]; + } + + @Override + public void setEnabledCipherSuites(String[] suites) { + + } + + @Override + public String[] getSupportedProtocols() { + return new String[0]; + } + + @Override + public String[] getEnabledProtocols() { + return new String[0]; + } + + @Override + public void setEnabledProtocols(String[] protocols) { + + } + + @Override + public SSLSession getSession() { + return null; + } + + @Override + public void addHandshakeCompletedListener(HandshakeCompletedListener listener) { + + } + + @Override + public void removeHandshakeCompletedListener(HandshakeCompletedListener listener) { + + } + + @Override + public void startHandshake() throws IOException { + + } + + @Override + public void setUseClientMode(boolean mode) { + + } + + @Override + public boolean getUseClientMode() { + return false; + } + + @Override + public void setNeedClientAuth(boolean need) { + + } + + @Override + public boolean getNeedClientAuth() { + return false; + } + + @Override + public void setWantClientAuth(boolean want) { + + } + + @Override + public boolean getWantClientAuth() { + return false; + } + + @Override + public void setEnableSessionCreation(boolean flag) { + + } + + @Override + public boolean getEnableSessionCreation() { + return false; + } +} diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java index 15508110344..43bc92af092 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java @@ -17,11 +17,13 @@ package io.grpc.okhttp; import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.internal.CertificateUtils.createTrustManager; import static io.grpc.internal.GrpcUtil.DEFAULT_KEEPALIVE_TIMEOUT_NANOS; import static io.grpc.internal.GrpcUtil.KEEPALIVE_TIME_NANOS_DISABLED; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.CallCredentials; import io.grpc.ChannelCredentials; import io.grpc.ChannelLogger; @@ -72,7 +74,6 @@ import java.util.concurrent.TimeUnit; import java.util.logging.Level; import java.util.logging.Logger; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; import javax.net.SocketFactory; import javax.net.ssl.HostnameVerifier; @@ -81,8 +82,6 @@ import javax.net.ssl.SSLContext; import javax.net.ssl.SSLSocketFactory; import javax.net.ssl.TrustManager; -import javax.net.ssl.TrustManagerFactory; -import javax.security.auth.x500.X500Principal; /** Convenience class for building channels with the OkHttp transport. */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1785") @@ -91,6 +90,7 @@ public final class OkHttpChannelBuilder extends ForwardingChannelBuilder2> getSupportedSocketAddressTypes() { return Collections.singleton(InetSocketAddress.class); } @@ -799,6 +797,7 @@ static final class OkHttpTransportFactory implements ClientTransportFactory { private final boolean keepAliveWithoutCalls; final int maxInboundMetadataSize; final boolean useGetForSafeMethods; + private final ChannelCredentials channelCredentials; private boolean closed; private OkHttpTransportFactory( @@ -816,7 +815,8 @@ private OkHttpTransportFactory( boolean keepAliveWithoutCalls, int maxInboundMetadataSize, TransportTracer.Factory transportTracerFactory, - boolean useGetForSafeMethods) { + boolean useGetForSafeMethods, + ChannelCredentials channelCredentials) { this.executorPool = executorPool; this.executor = executorPool.getObject(); this.scheduledExecutorServicePool = scheduledExecutorServicePool; @@ -834,6 +834,7 @@ private OkHttpTransportFactory( this.keepAliveWithoutCalls = keepAliveWithoutCalls; this.maxInboundMetadataSize = maxInboundMetadataSize; this.useGetForSafeMethods = useGetForSafeMethods; + this.channelCredentials = channelCredentials; this.transportTracerFactory = Preconditions.checkNotNull(transportTracerFactory, "transportTracerFactory"); @@ -861,7 +862,8 @@ public void run() { options.getUserAgent(), options.getEagAttributes(), options.getHttpConnectProxiedSocketAddress(), - tooManyPingsRunnable); + tooManyPingsRunnable, + channelCredentials); if (enableKeepAlive) { transport.enableKeepAlive( true, keepAliveTimeNanosState.get(), keepAliveTimeoutNanos, keepAliveWithoutCalls); @@ -897,7 +899,8 @@ public SwapChannelCredentialsResult swapChannelCredentials(ChannelCredentials ch keepAliveWithoutCalls, maxInboundMetadataSize, transportTracerFactory, - useGetForSafeMethods); + useGetForSafeMethods, + channelCredentials); return new SwapChannelCredentialsResult(factory, result.callCredentials); } diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java index 50de8c7002f..8dd55d9f23e 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java @@ -21,6 +21,7 @@ import static io.grpc.internal.ClientStreamListener.RpcProgress.PROCESSED; import com.google.common.io.BaseEncoding; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.Metadata; @@ -37,7 +38,6 @@ import io.perfmark.Tag; import io.perfmark.TaskCloseable; import java.util.List; -import javax.annotation.concurrent.GuardedBy; import okio.Buffer; /** @@ -99,7 +99,8 @@ class OkHttpClientStream extends AbstractClientStream { outboundFlow, transport, initialWindowSize, - method.getFullMethodName()); + method.getFullMethodName(), + callOptions); } @Override @@ -222,8 +223,9 @@ public TransportState( OutboundFlowController outboundFlow, OkHttpClientTransport transport, int initialWindowSize, - String methodName) { - super(maxMessageSize, statsTraceCtx, OkHttpClientStream.this.getTransportTracer()); + String methodName, + CallOptions options) { + super(maxMessageSize, statsTraceCtx, OkHttpClientStream.this.getTransportTracer(), options); this.lock = checkNotNull(lock, "lock"); this.frameWriter = frameWriter; this.outboundFlow = outboundFlow; @@ -407,7 +409,7 @@ private void streamReady(Metadata metadata, String path) { transport.isUsingPlaintext()); // TODO(b/145386688): This access should be guarded by 'this.transport.lock'; instead found: // 'this.lock' - transport.streamReadyToStart(OkHttpClientStream.this); + transport.streamReadyToStart(OkHttpClientStream.this, authority); } Tag tag() { diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java index 3b9513eeb47..4764a6a1387 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java @@ -27,8 +27,10 @@ import com.google.common.base.Supplier; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.CallOptions; +import io.grpc.ChannelCredentials; import io.grpc.ClientStreamTracer; import io.grpc.Grpc; import io.grpc.HttpConnectProxiedSocketAddress; @@ -42,20 +44,28 @@ import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.StatusException; +import io.grpc.TlsChannelCredentials; +import io.grpc.internal.CertificateUtils; import io.grpc.internal.ClientStreamListener.RpcProgress; import io.grpc.internal.ConnectionClientTransport; +import io.grpc.internal.DisconnectError; +import io.grpc.internal.GoAwayDisconnectError; import io.grpc.internal.GrpcAttributes; import io.grpc.internal.GrpcUtil; import io.grpc.internal.Http2Ping; import io.grpc.internal.InUseStateAggregator; import io.grpc.internal.KeepAliveManager; import io.grpc.internal.KeepAliveManager.ClientKeepAlivePinger; +import io.grpc.internal.NoopSslSession; import io.grpc.internal.SerializingExecutor; +import io.grpc.internal.SimpleDisconnectError; import io.grpc.internal.StatsTraceContext; import io.grpc.internal.TransportTracer; import io.grpc.okhttp.ExceptionHandlingFrameWriter.TransportExceptionHandler; +import io.grpc.okhttp.OkHttpChannelBuilder.OkHttpTransportFactory; import io.grpc.okhttp.internal.ConnectionSpec; import io.grpc.okhttp.internal.Credentials; +import io.grpc.okhttp.internal.OkHostnameVerifier; import io.grpc.okhttp.internal.StatusLine; import io.grpc.okhttp.internal.framed.ErrorCode; import io.grpc.okhttp.internal.framed.FrameReader; @@ -70,31 +80,46 @@ import io.perfmark.PerfMark; import java.io.EOFException; import java.io.IOException; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; import java.net.InetSocketAddress; import java.net.Socket; import java.net.URI; +import java.security.GeneralSecurityException; +import java.security.KeyStore; +import java.security.cert.Certificate; +import java.security.cert.X509Certificate; import java.util.Collections; import java.util.Deque; import java.util.EnumMap; import java.util.HashMap; import java.util.Iterator; +import java.util.LinkedHashMap; import java.util.LinkedList; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Random; +import java.util.concurrent.BrokenBarrierException; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.CyclicBarrier; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; import javax.net.SocketFactory; import javax.net.ssl.HostnameVerifier; +import javax.net.ssl.SSLParameters; +import javax.net.ssl.SSLPeerUnverifiedException; import javax.net.ssl.SSLSession; import javax.net.ssl.SSLSocket; import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509TrustManager; import okio.Buffer; import okio.BufferedSink; import okio.BufferedSource; @@ -107,9 +132,15 @@ * A okhttp-based {@link ConnectionClientTransport} implementation. */ class OkHttpClientTransport implements ConnectionClientTransport, TransportExceptionHandler, - OutboundFlowController.Transport { + OutboundFlowController.Transport, ClientKeepAlivePinger.TransportWithDisconnectReason { private static final Map ERROR_CODE_TO_STATUS = buildErrorCodeToStatusMap(); private static final Logger log = Logger.getLogger(OkHttpClientTransport.class.getName()); + private static final String GRPC_ENABLE_PER_RPC_AUTHORITY_CHECK = + "GRPC_ENABLE_PER_RPC_AUTHORITY_CHECK"; + static boolean enablePerRpcAuthorityCheck = + GrpcUtil.getFlag(GRPC_ENABLE_PER_RPC_AUTHORITY_CHECK, false); + private Socket sock; + private SSLSession sslSession; private static Map buildErrorCodeToStatusMap() { Map errorToStatus = new EnumMap<>(ErrorCode.class); @@ -140,6 +171,26 @@ private static Map buildErrorCodeToStatusMap() { return Collections.unmodifiableMap(errorToStatus); } + private static final Class x509ExtendedTrustManagerClass; + private static final Method checkServerTrustedMethod; + + static { + Class x509ExtendedTrustManagerClass1 = null; + Method checkServerTrustedMethod1 = null; + try { + x509ExtendedTrustManagerClass1 = Class.forName("javax.net.ssl.X509ExtendedTrustManager"); + checkServerTrustedMethod1 = x509ExtendedTrustManagerClass1.getMethod("checkServerTrusted", + X509Certificate[].class, String.class, Socket.class); + } catch (ClassNotFoundException e) { + // Per-rpc authority override via call options will be disallowed. + } catch (NoSuchMethodException e) { + // Should never happen since X509ExtendedTrustManager was introduced in Android API level 24 + // along with checkServerTrusted. + } + x509ExtendedTrustManagerClass = x509ExtendedTrustManagerClass1; + checkServerTrustedMethod = checkServerTrustedMethod1; + } + private final InetSocketAddress address; private final String defaultAuthority; private final String userAgent; @@ -201,6 +252,19 @@ private static Map buildErrorCodeToStatusMap() { private final boolean useGetForSafeMethods; @GuardedBy("lock") private final TransportTracer transportTracer; + private final TrustManager x509TrustManager; + + @SuppressWarnings("serial") + private static class LruCache extends LinkedHashMap { + @Override + protected boolean removeEldestEntry(Map.Entry eldest) { + return size() > 100; + } + } + + @GuardedBy("lock") + private final Map authorityVerificationResults = new LruCache<>(); + @GuardedBy("lock") private final InUseStateAggregator inUseState = new InUseStateAggregator() { @@ -229,13 +293,14 @@ protected void handleNotInUse() { SettableFuture connectedFuture; public OkHttpClientTransport( - OkHttpChannelBuilder.OkHttpTransportFactory transportFactory, - InetSocketAddress address, - String authority, - @Nullable String userAgent, - Attributes eagAttrs, - @Nullable HttpConnectProxiedSocketAddress proxiedAddr, - Runnable tooManyPingsRunnable) { + OkHttpTransportFactory transportFactory, + InetSocketAddress address, + String authority, + @Nullable String userAgent, + Attributes eagAttrs, + @Nullable HttpConnectProxiedSocketAddress proxiedAddr, + Runnable tooManyPingsRunnable, + ChannelCredentials channelCredentials) { this( transportFactory, address, @@ -245,19 +310,21 @@ public OkHttpClientTransport( GrpcUtil.STOPWATCH_SUPPLIER, new Http2(), proxiedAddr, - tooManyPingsRunnable); + tooManyPingsRunnable, + channelCredentials); } private OkHttpClientTransport( - OkHttpChannelBuilder.OkHttpTransportFactory transportFactory, - InetSocketAddress address, - String authority, - @Nullable String userAgent, - Attributes eagAttrs, - Supplier stopwatchFactory, - Variant variant, - @Nullable HttpConnectProxiedSocketAddress proxiedAddr, - Runnable tooManyPingsRunnable) { + OkHttpTransportFactory transportFactory, + InetSocketAddress address, + String authority, + @Nullable String userAgent, + Attributes eagAttrs, + Supplier stopwatchFactory, + Variant variant, + @Nullable HttpConnectProxiedSocketAddress proxiedAddr, + Runnable tooManyPingsRunnable, + ChannelCredentials channelCredentials) { this.address = Preconditions.checkNotNull(address, "address"); this.defaultAuthority = authority; this.maxMessageSize = transportFactory.maxMessageSize; @@ -272,7 +339,8 @@ private OkHttpClientTransport( this.socketFactory = transportFactory.socketFactory == null ? SocketFactory.getDefault() : transportFactory.socketFactory; this.sslSocketFactory = transportFactory.sslSocketFactory; - this.hostnameVerifier = transportFactory.hostnameVerifier; + this.hostnameVerifier = transportFactory.hostnameVerifier != null + ? transportFactory.hostnameVerifier : OkHostnameVerifier.INSTANCE; this.connectionSpec = Preconditions.checkNotNull( transportFactory.connectionSpec, "connectionSpec"); this.stopwatchFactory = Preconditions.checkNotNull(stopwatchFactory, "stopwatchFactory"); @@ -288,6 +356,21 @@ private OkHttpClientTransport( .set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, eagAttrs).build(); this.useGetForSafeMethods = transportFactory.useGetForSafeMethods; initTransportTracer(); + TrustManager tempX509TrustManager; + if (channelCredentials instanceof TlsChannelCredentials + && x509ExtendedTrustManagerClass != null) { + try { + tempX509TrustManager = getTrustManager( + (TlsChannelCredentials) channelCredentials); + } catch (GeneralSecurityException e) { + tempX509TrustManager = null; + log.log(Level.WARNING, "Obtaining X509ExtendedTrustManager for the transport failed." + + "Per-rpc authority overrides will be disallowed.", e); + } + } else { + tempX509TrustManager = null; + } + x509TrustManager = tempX509TrustManager; } /** @@ -296,7 +379,7 @@ private OkHttpClientTransport( @SuppressWarnings("AddressSelection") // An IP address always returns one address @VisibleForTesting OkHttpClientTransport( - OkHttpChannelBuilder.OkHttpTransportFactory transportFactory, + OkHttpTransportFactory transportFactory, String userAgent, Supplier stopwatchFactory, Variant variant, @@ -312,7 +395,8 @@ private OkHttpClientTransport( stopwatchFactory, variant, null, - tooManyPingsRunnable); + tooManyPingsRunnable, + null); this.connectingCallback = connectingCallback; this.connectedFuture = Preconditions.checkNotNull(connectedFuture, "connectedFuture"); } @@ -392,6 +476,7 @@ public OkHttpClientStream newStream( Preconditions.checkNotNull(headers, "headers"); StatsTraceContext statsTraceContext = StatsTraceContext.newClientContext(tracers, getAttributes(), headers); + // FIXME: it is likely wrong to pass the transportTracer here as it'll exit the lock's scope synchronized (lock) { // to make @GuardedBy linter happy return new OkHttpClientStream( @@ -412,23 +497,116 @@ public OkHttpClientStream newStream( } } + private TrustManager getTrustManager(TlsChannelCredentials tlsCreds) + throws GeneralSecurityException { + TrustManager[] tm; + // Using the same way of creating TrustManager from OkHttpChannelBuilder.sslSocketFactoryFrom() + if (tlsCreds.getTrustManagers() != null) { + tm = tlsCreds.getTrustManagers().toArray(new TrustManager[0]); + } else if (tlsCreds.getRootCertificates() != null) { + tm = CertificateUtils.createTrustManager(tlsCreds.getRootCertificates()); + } else { // else use system default + TrustManagerFactory tmf = TrustManagerFactory.getInstance( + TrustManagerFactory.getDefaultAlgorithm()); + tmf.init((KeyStore) null); + tm = tmf.getTrustManagers(); + } + for (TrustManager trustManager: tm) { + if (trustManager instanceof X509TrustManager) { + return trustManager; + } + } + return null; + } + @GuardedBy("lock") - void streamReadyToStart(OkHttpClientStream clientStream) { + void streamReadyToStart(OkHttpClientStream clientStream, String authority) { if (goAwayStatus != null) { clientStream.transportState().transportReportStatus( goAwayStatus, RpcProgress.MISCARRIED, true, new Metadata()); - } else if (streams.size() >= maxConcurrentStreams) { - pendingStreams.add(clientStream); - setInUse(clientStream); } else { - startStream(clientStream); + if (socket instanceof SSLSocket && !authority.equals(defaultAuthority)) { + Status authorityVerificationResult; + if (authorityVerificationResults.containsKey(authority)) { + authorityVerificationResult = authorityVerificationResults.get(authority); + } else { + authorityVerificationResult = verifyAuthority(authority); + authorityVerificationResults.put(authority, authorityVerificationResult); + } + if (!authorityVerificationResult.isOk()) { + if (enablePerRpcAuthorityCheck) { + clientStream.transportState().transportReportStatus( + authorityVerificationResult, RpcProgress.PROCESSED, true, new Metadata()); + return; + } + } + } + if (streams.size() >= maxConcurrentStreams) { + pendingStreams.add(clientStream); + setInUse(clientStream); + } else { + startStream(clientStream); + } } } + private Status verifyAuthority(String authority) { + Status authorityVerificationResult; + if (hostnameVerifier.verify(authority, ((SSLSocket) socket).getSession())) { + authorityVerificationResult = Status.OK; + } else { + authorityVerificationResult = Status.UNAVAILABLE.withDescription(String.format( + "HostNameVerifier verification failed for authority '%s'", + authority)); + } + if (!authorityVerificationResult.isOk() && !enablePerRpcAuthorityCheck) { + log.log(Level.WARNING, String.format("HostNameVerifier verification failed for " + + "authority '%s'. This will be an error in the future.", + authority)); + } + if (authorityVerificationResult.isOk()) { + // The status is trivially assigned in this case, but we are still making use of the + // cache to keep track that a warning log had been logged for the authority when + // enablePerRpcAuthorityCheck is false. When we permanently enable the feature, the + // status won't need to be cached for case when x509TrustManager is null. + if (x509TrustManager == null) { + authorityVerificationResult = Status.UNAVAILABLE.withDescription( + String.format("Could not verify authority '%s' for the rpc with no " + + "X509TrustManager available", + authority)); + } else if (x509ExtendedTrustManagerClass.isInstance(x509TrustManager)) { + try { + Certificate[] peerCertificates = sslSession.getPeerCertificates(); + X509Certificate[] x509PeerCertificates = + new X509Certificate[peerCertificates.length]; + for (int i = 0; i < peerCertificates.length; i++) { + x509PeerCertificates[i] = (X509Certificate) peerCertificates[i]; + } + checkServerTrustedMethod.invoke(x509TrustManager, x509PeerCertificates, + "RSA", new SslSocketWrapper((SSLSocket) socket, authority)); + authorityVerificationResult = Status.OK; + } catch (SSLPeerUnverifiedException | InvocationTargetException + | IllegalAccessException e) { + authorityVerificationResult = Status.UNAVAILABLE.withCause(e).withDescription( + "Peer verification failed"); + } + if (authorityVerificationResult.getCause() != null) { + log.log(Level.WARNING, authorityVerificationResult.getDescription() + + ". This will be an error in the future.", + authorityVerificationResult.getCause()); + } else { + log.log(Level.WARNING, authorityVerificationResult.getDescription() + + ". This will be an error in the future."); + } + } + } + return authorityVerificationResult; + } + @SuppressWarnings("GuardedBy") @GuardedBy("lock") private void startStream(OkHttpClientStream stream) { - Preconditions.checkState( + checkState( stream.transportState().id() == OkHttpClientStream.ABSENT_ID, "StreamId already assigned"); streams.put(nextStreamId, stream); setInUse(stream); @@ -499,20 +677,18 @@ public Runnable start(Listener listener) { outboundFlow = new OutboundFlowController(this, frameWriter); } final CountDownLatch latch = new CountDownLatch(1); + final CountDownLatch latchForExtraThread = new CountDownLatch(1); + // The transport needs up to two threads to function once started, + // but only needs one during handshaking. Start another thread during handshaking + // to make sure there's still a free thread available. If the number of threads is exhausted, + // it is better to kill the transport than for all the transports to hang unable to send. + CyclicBarrier barrier = new CyclicBarrier(2); // Connecting in the serializingExecutor, so that some stream operations like synStream // will be executed after connected. + serializingExecutor.execute(new Runnable() { @Override public void run() { - // This is a hack to make sure the connection preface and initial settings to be sent out - // without blocking the start. By doing this essentially prevents potential deadlock when - // network is not available during startup while another thread holding lock to send the - // initial preface. - try { - latch.await(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } // Use closed source on failure so that the reader immediately shuts down. BufferedSource source = Okio.buffer(new Source() { @Override @@ -529,9 +705,23 @@ public Timeout timeout() { public void close() { } }); - Socket sock; - SSLSession sslSession = null; try { + // This is a hack to make sure the connection preface and initial settings to be sent out + // without blocking the start. By doing this essentially prevents potential deadlock when + // network is not available during startup while another thread holding lock to send the + // initial preface. + try { + latch.await(); + barrier.await(1000, TimeUnit.MILLISECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } catch (TimeoutException | BrokenBarrierException e) { + startGoAway(0, ErrorCode.INTERNAL_ERROR, Status.UNAVAILABLE + .withDescription("Timed out waiting for second handshake thread. " + + "The transport executor pool may have run out of threads")); + return; + } + if (proxiedAddr == null) { sock = socketFactory.createSocket(address.getAddress(), address.getPort()); } else { @@ -575,6 +765,7 @@ sslSocketFactory, hostnameVerifier, sock, getOverridenHost(), getOverridenPort() return; } finally { clientFrameHandler = new ClientFrameHandler(variant.newReader(source, true)); + latchForExtraThread.countDown(); } synchronized (lock) { socket = Preconditions.checkNotNull(sock, "socket"); @@ -584,6 +775,21 @@ sslSocketFactory, hostnameVerifier, sock, getOverridenHost(), getOverridenPort() } } }); + + executor.execute(new Runnable() { + @Override + public void run() { + try { + barrier.await(1000, TimeUnit.MILLISECONDS); + latchForExtraThread.await(); + } catch (BrokenBarrierException | TimeoutException e) { + // Something bad happened, maybe too few threads available! + // This will be handled in the handshake thread. + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + }); // Schedule to send connection preface & settings before any other write. try { sendConnectionPrefaceAndSettings(); @@ -597,13 +803,15 @@ public void run() { if (connectingCallback != null) { connectingCallback.run(); } - // ClientFrameHandler need to be started after connectionPreface / settings, otherwise it - // may send goAway immediately. - executor.execute(clientFrameHandler); synchronized (lock) { maxConcurrentStreams = Integer.MAX_VALUE; - startPendingStreams(); + checkState(pendingStreams.isEmpty(), + "Pending streams detected during transport start." + + " RPCs should not be started before transport is ready."); } + // ClientFrameHandler need to be started after connectionPreface / settings, otherwise it + // may send goAway immediately. + executor.execute(clientFrameHandler); if (connectedFuture != null) { connectedFuture.set(null); } @@ -787,13 +995,18 @@ public void shutdown(Status reason) { } goAwayStatus = reason; - listener.transportShutdown(goAwayStatus); + listener.transportShutdown(goAwayStatus, SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); stopIfNecessary(); } } @Override public void shutdownNow(Status reason) { + shutdownNow(reason, SimpleDisconnectError.SUBCHANNEL_SHUTDOWN); + } + + @Override + public void shutdownNow(Status reason, DisconnectError disconnectError) { shutdown(reason); synchronized (lock) { Iterator> it = streams.entrySet().iterator(); @@ -883,7 +1096,13 @@ private void startGoAway(int lastKnownStreamId, ErrorCode errorCode, Status stat synchronized (lock) { if (goAwayStatus == null) { goAwayStatus = status; - listener.transportShutdown(status); + GrpcUtil.Http2Error http2Error; + if (errorCode == null) { + http2Error = GrpcUtil.Http2Error.NO_ERROR; + } else { + http2Error = GrpcUtil.Http2Error.forCode(errorCode.httpCode); + } + listener.transportShutdown(status, new GoAwayDisconnectError(http2Error)); } if (errorCode != null && !goAwaySent) { // Send GOAWAY with lastGoodStreamId of 0, since we don't expect any server-initiated @@ -953,8 +1172,8 @@ void finishStream( } if (!startPendingStreams()) { stopIfNecessary(); - maybeClearInUse(stream); } + maybeClearInUse(stream); } } } @@ -1028,12 +1247,12 @@ private void setInUse(OkHttpClientStream stream) { } } - private Throwable getPingFailure() { + private Status getPingFailure() { synchronized (lock) { if (goAwayStatus != null) { - return goAwayStatus.asException(); + return goAwayStatus; } else { - return Status.UNAVAILABLE.withDescription("Connection closed").asException(); + return Status.UNAVAILABLE.withDescription("Connection closed"); } } } @@ -1426,4 +1645,50 @@ public void alternateService(int streamId, String origin, ByteString protocol, S // TODO(madongfly): Deal with alternateService propagation } } + + /** + * SSLSocket wrapper that provides a fake SSLSession for handshake session. + */ + static final class SslSocketWrapper extends NoopSslSocket { + + private final SSLSession sslSession; + private final SSLSocket sslSocket; + + SslSocketWrapper(SSLSocket sslSocket, String peerHost) { + this.sslSocket = sslSocket; + this.sslSession = new FakeSslSession(peerHost); + } + + @Override + public SSLSession getHandshakeSession() { + return this.sslSession; + } + + @Override + public boolean isConnected() { + return sslSocket.isConnected(); + } + + @Override + public SSLParameters getSSLParameters() { + return sslSocket.getSSLParameters(); + } + } + + /** + * Fake SSLSession instance that provides the peer host name to verify for per-rpc check. + */ + static class FakeSslSession extends NoopSslSession { + + private final String peerHost; + + FakeSslSession(String peerHost) { + this.peerHost = peerHost; + } + + @Override + public String getPeerHost() { + return peerHost; + } + } } diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpProtocolNegotiator.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpProtocolNegotiator.java index d09d6cccedd..3f5a4d8cb2b 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpProtocolNegotiator.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpProtocolNegotiator.java @@ -19,6 +19,8 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; +import com.google.common.net.HostAndPort; +import com.google.common.net.InetAddresses; import io.grpc.internal.GrpcUtil; import io.grpc.okhttp.internal.OptionalMethod; import io.grpc.okhttp.internal.Platform; @@ -120,7 +122,6 @@ public String getSelectedProtocol(SSLSocket socket) { return platform.getSelectedProtocol(socket); } - @VisibleForTesting static final class AndroidNegotiator extends OkHttpProtocolNegotiator { // setUseSessionTickets(boolean) private static final OptionalMethod SET_USE_SESSION_TICKETS = @@ -247,7 +248,9 @@ protected void configureTlsExtensions( } else { SET_USE_SESSION_TICKETS.invokeOptionalWithoutCheckedException(sslSocket, true); } - if (SET_SERVER_NAMES != null && SNI_HOST_NAME != null) { + if (SET_SERVER_NAMES != null + && SNI_HOST_NAME != null + && !InetAddresses.isInetAddress(HostAndPort.fromString(hostname).getHost())) { SET_SERVER_NAMES .invoke(sslParams, Collections.singletonList(SNI_HOST_NAME.newInstance(hostname))); } else { diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpReadableBuffer.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpReadableBuffer.java index 136ee8954a2..d65453722f0 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpReadableBuffer.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpReadableBuffer.java @@ -21,7 +21,6 @@ import java.io.EOFException; import java.io.IOException; import java.io.OutputStream; -import java.nio.ByteBuffer; /** * A {@link ReadableBuffer} implementation that is backed by an {@link okio.Buffer}. @@ -71,12 +70,6 @@ public void readBytes(byte[] dest, int destOffset, int length) { } } - @Override - public void readBytes(ByteBuffer dest) { - // We are not using it. - throw new UnsupportedOperationException(); - } - @Override public void readBytes(OutputStream dest, int length) throws IOException { buffer.writeTo(dest, length); diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java index 8269a8ddf0f..50097e1922e 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java @@ -17,6 +17,8 @@ package io.grpc.okhttp; import static com.google.common.base.Preconditions.checkArgument; +import static io.grpc.internal.CertificateUtils.createTrustManager; +import static io.grpc.internal.GrpcUtil.DEFAULT_SERVER_PERMIT_KEEPALIVE_TIME_NANOS; import com.google.common.base.Preconditions; import com.google.errorprone.annotations.CanIgnoreReturnValue; @@ -26,6 +28,7 @@ import io.grpc.ForwardingServerBuilder; import io.grpc.InsecureServerCredentials; import io.grpc.Internal; +import io.grpc.MetricRecorder; import io.grpc.ServerBuilder; import io.grpc.ServerCredentials; import io.grpc.ServerStreamTracer; @@ -74,6 +77,7 @@ public final class OkHttpServerBuilder extends ForwardingServerBuilder streamTracerFactories, + MetricRecorder metricRecorder) { + return buildTransportServers(streamTracerFactories); + } + }); final SocketAddress listenAddress; final HandshakerSocketFactory handshakerSocketFactory; TransportTracer.Factory transportTracerFactory = TransportTracer.getDefaultFactory(); @@ -126,9 +138,10 @@ public static OkHttpServerBuilder forPort(SocketAddress address, ServerCredentia int maxInboundMessageSize = GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; long maxConnectionIdleInNanos = MAX_CONNECTION_IDLE_NANOS_DISABLED; boolean permitKeepAliveWithoutCalls; - long permitKeepAliveTimeInNanos = TimeUnit.MINUTES.toNanos(5); + long permitKeepAliveTimeInNanos = DEFAULT_SERVER_PERMIT_KEEPALIVE_TIME_NANOS; long maxConnectionAgeInNanos = MAX_CONNECTION_AGE_NANOS_DISABLED; long maxConnectionAgeGraceInNanos = MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE; + int maxConcurrentCallsPerConnection = MAX_CONCURRENT_STREAMS; OkHttpServerBuilder( SocketAddress address, HandshakerSocketFactory handshakerSocketFactory) { @@ -350,6 +363,18 @@ public OkHttpServerBuilder maxInboundMetadataSize(int bytes) { return this; } + /** + * The maximum number of concurrent calls permitted for each incoming connection. Defaults to no + * limit. + */ + @CanIgnoreReturnValue + public OkHttpServerBuilder maxConcurrentCallsPerConnection(int maxConcurrentCallsPerConnection) { + checkArgument(maxConcurrentCallsPerConnection > 0, + "max must be positive: %s", maxConcurrentCallsPerConnection); + this.maxConcurrentCallsPerConnection = maxConcurrentCallsPerConnection; + return this; + } + /** * Sets the maximum message size allowed to be received on the server. If not called, defaults to * defaults to 4 MiB. The default provides protection to servers who haven't considered the @@ -411,7 +436,7 @@ static HandshakerSocketFactoryResult handshakerSocketFactoryFrom(ServerCredentia tm = tlsCreds.getTrustManagers().toArray(new TrustManager[0]); } else if (tlsCreds.getRootCertificates() != null) { try { - tm = OkHttpChannelBuilder.createTrustManager(tlsCreds.getRootCertificates()); + tm = createTrustManager(tlsCreds.getRootCertificates()); } catch (GeneralSecurityException gse) { log.log(Level.FINE, "Exception loading root certificates from credential", gse); return HandshakerSocketFactoryResult.error( diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerStream.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerStream.java index bcf8837b7eb..d1f1a3f4fe0 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerStream.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerStream.java @@ -17,6 +17,7 @@ package io.grpc.okhttp; import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.Metadata; import io.grpc.Status; @@ -30,7 +31,6 @@ import io.perfmark.Tag; import io.perfmark.TaskCloseable; import java.util.List; -import javax.annotation.concurrent.GuardedBy; import okio.Buffer; /** diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java index 8fb74d3f1b5..7d192b16943 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java @@ -20,8 +20,10 @@ import static io.grpc.okhttp.OkHttpServerBuilder.MAX_CONNECTION_IDLE_NANOS_DISABLED; import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Attributes; import io.grpc.InternalChannelz; import io.grpc.InternalLogId; @@ -50,6 +52,8 @@ import io.grpc.okhttp.internal.framed.Variant; import java.io.IOException; import java.net.Socket; +import java.net.SocketException; +import java.util.Collections; import java.util.List; import java.util.Locale; import java.util.Map; @@ -61,7 +65,6 @@ import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; import okio.Buffer; import okio.BufferedSource; import okio.ByteString; @@ -90,6 +93,7 @@ final class OkHttpServerTransport implements ServerTransport, private static final ByteString TE_TRAILERS = ByteString.encodeUtf8("trailers"); private static final ByteString CONTENT_TYPE = ByteString.encodeUtf8("content-type"); private static final ByteString CONTENT_LENGTH = ByteString.encodeUtf8("content-length"); + private static final ByteString ALLOW = ByteString.encodeUtf8("allow"); private final Config config; private final Variant variant = new Http2(); @@ -170,6 +174,13 @@ private void startIo(SerializingExecutor serializingExecutor) { HandshakerSocketFactory.HandshakeResult result = config.handshakerSocketFactory.handshake(socket, Attributes.EMPTY); synchronized (lock) { + if (socket.isClosed()) { + // The wrapped socket may not handle the underlying socket being closed by shutdown(). In + // particular, SSLSocket hangs future reads if the underlying socket is already closed at + // this point, even if you call sslSocket.close() later. + result.socket.close(); + throw new SocketException("Socket close raced with handshake"); + } this.socket = result.socket; } this.attributes = result.attributes; @@ -219,6 +230,10 @@ public void data(boolean outFinished, int streamId, Buffer source, int byteCount OkHttpSettingsUtil.INITIAL_WINDOW_SIZE, config.flowControlWindow); OkHttpSettingsUtil.set(settings, OkHttpSettingsUtil.MAX_HEADER_LIST_SIZE, config.maxInboundMetadataSize); + if (config.maxConcurrentStreams != Integer.MAX_VALUE) { + OkHttpSettingsUtil.set(settings, + OkHttpSettingsUtil.MAX_CONCURRENT_STREAMS, config.maxConcurrentStreams); + } frameWriter.settings(settings); if (config.flowControlWindow > Utils.DEFAULT_WINDOW_SIZE) { frameWriter.windowUpdate( @@ -520,6 +535,7 @@ static final class Config { final long permitKeepAliveTimeInNanos; final long maxConnectionAgeInNanos; final long maxConnectionAgeGraceInNanos; + final int maxConcurrentStreams; public Config( OkHttpServerBuilder builder, @@ -544,6 +560,7 @@ public Config( permitKeepAliveTimeInNanos = builder.permitKeepAliveTimeInNanos; maxConnectionAgeInNanos = builder.maxConnectionAgeInNanos; maxConnectionAgeGraceInNanos = builder.maxConnectionAgeGraceInNanos; + maxConcurrentStreams = builder.maxConcurrentCallsPerConnection; } } @@ -638,6 +655,11 @@ public void headers(boolean outFinished, newStream = streamId > lastStreamId; if (newStream) { lastStreamId = streamId; + if (config.maxConcurrentStreams <= streams.size()) { + streamError(streamId, ErrorCode.REFUSED_STREAM, + "Max concurrent stream reached. RFC7540 section 5.1.2"); + return; + } } } @@ -753,8 +775,9 @@ public void headers(boolean outFinished, } if (!POST_METHOD.equals(httpMethod)) { + List
extraHeaders = Lists.newArrayList(new Header(ALLOW, POST_METHOD)); respondWithHttpError(streamId, inFinished, 405, Status.Code.INTERNAL, - "HTTP Method is not supported: " + asciiString(httpMethod)); + "HTTP Method is not supported: " + asciiString(httpMethod), extraHeaders); return; } @@ -928,13 +951,13 @@ public void settings(boolean clearPrevious, Settings settings) { @Override public void ping(boolean ack, int payload1, int payload2) { - if (!keepAliveEnforcer.pingAcceptable()) { - abruptShutdown(ErrorCode.ENHANCE_YOUR_CALM, "too_many_pings", - Status.RESOURCE_EXHAUSTED.withDescription("Too many pings from client"), false); - return; - } long payload = (((long) payload1) << 32) | (payload2 & 0xffffffffL); if (!ack) { + if (!keepAliveEnforcer.pingAcceptable()) { + abruptShutdown(ErrorCode.ENHANCE_YOUR_CALM, "too_many_pings", + Status.RESOURCE_EXHAUSTED.withDescription("Too many pings from client"), false); + return; + } frameLogger.logPing(OkHttpFrameLogger.Direction.INBOUND, payload); synchronized (lock) { frameWriter.ping(true, payload1, payload2); @@ -1047,11 +1070,19 @@ private void streamError(int streamId, ErrorCode errorCode, String reason) { private void respondWithHttpError( int streamId, boolean inFinished, int httpCode, Status.Code statusCode, String msg) { + respondWithHttpError(streamId, inFinished, httpCode, statusCode, msg, + Collections.emptyList()); + } + + private void respondWithHttpError( + int streamId, boolean inFinished, int httpCode, Status.Code statusCode, String msg, + List
extraHeaders) { Metadata metadata = new Metadata(); metadata.put(InternalStatus.CODE_KEY, statusCode.toStatus()); metadata.put(InternalStatus.MESSAGE_KEY, msg); List
headers = Headers.createHttpResponseHeaders(httpCode, "text/plain; charset=utf-8", metadata); + headers.addAll(extraHeaders); Buffer data = new Buffer().writeUtf8(msg); synchronized (lock) { diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpTlsUpgrader.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpTlsUpgrader.java index 1004dcd93f9..a8b038c91f4 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpTlsUpgrader.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpTlsUpgrader.java @@ -19,13 +19,13 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import io.grpc.okhttp.internal.ConnectionSpec; -import io.grpc.okhttp.internal.OkHostnameVerifier; import io.grpc.okhttp.internal.Protocol; import java.io.IOException; import java.net.Socket; import java.util.Arrays; import java.util.Collections; import java.util.List; +import javax.annotation.Nonnull; import javax.net.ssl.HostnameVerifier; import javax.net.ssl.SSLPeerUnverifiedException; import javax.net.ssl.SSLSocket; @@ -52,7 +52,7 @@ final class OkHttpTlsUpgrader { * @throws RuntimeException if the upgrade negotiation failed. */ public static SSLSocket upgrade(SSLSocketFactory sslSocketFactory, - HostnameVerifier hostnameVerifier, Socket socket, String host, int port, + @Nonnull HostnameVerifier hostnameVerifier, Socket socket, String host, int port, ConnectionSpec spec) throws IOException { Preconditions.checkNotNull(sslSocketFactory, "sslSocketFactory"); Preconditions.checkNotNull(socket, "socket"); @@ -67,9 +67,6 @@ public static SSLSocket upgrade(SSLSocketFactory sslSocketFactory, "Only " + TLS_PROTOCOLS + " are supported, but negotiated protocol is %s", negotiatedProtocol); - if (hostnameVerifier == null) { - hostnameVerifier = OkHostnameVerifier.INSTANCE; - } if (!hostnameVerifier.verify(canonicalizeHost(host), sslSocket.getSession())) { throw new SSLPeerUnverifiedException("Cannot verify hostname: " + host); } diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpWritableBufferAllocator.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpWritableBufferAllocator.java index 481ada61c96..58896a5dbb0 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpWritableBufferAllocator.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpWritableBufferAllocator.java @@ -27,11 +27,9 @@ */ class OkHttpWritableBufferAllocator implements WritableBufferAllocator { - // Use 4k as our minimum buffer size. - private static final int MIN_BUFFER = 4096; - // Set the maximum buffer size to 1MB private static final int MAX_BUFFER = 1024 * 1024; + public static final int SEGMENT_SIZE_COPY = 8192; // Should equal Segment.SIZE /** * Construct a new instance. @@ -45,7 +43,9 @@ class OkHttpWritableBufferAllocator implements WritableBufferAllocator { */ @Override public WritableBuffer allocate(int capacityHint) { - capacityHint = Math.min(MAX_BUFFER, Math.max(MIN_BUFFER, capacityHint)); + // okio buffer uses fixed size Segments, round capacityHint up + capacityHint = Math.min(MAX_BUFFER, + (capacityHint + SEGMENT_SIZE_COPY - 1) / SEGMENT_SIZE_COPY * SEGMENT_SIZE_COPY); return new OkHttpWritableBuffer(new Buffer(), capacityHint); } } diff --git a/okhttp/src/main/java/io/grpc/okhttp/SslSocketFactoryServerCredentials.java b/okhttp/src/main/java/io/grpc/okhttp/SslSocketFactoryServerCredentials.java index 63c6f33ff79..ad9af056afc 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/SslSocketFactoryServerCredentials.java +++ b/okhttp/src/main/java/io/grpc/okhttp/SslSocketFactoryServerCredentials.java @@ -41,7 +41,7 @@ static final class ServerCredentials extends io.grpc.ServerCredentials { private final ConnectionSpec connectionSpec; ServerCredentials(SSLSocketFactory factory) { - this(factory, OkHttpChannelBuilder.INTERNAL_DEFAULT_CONNECTION_SPEC); + this(factory, OkHttpChannelBuilder.INTERNAL_LEGACY_CONNECTION_SPEC); } ServerCredentials(SSLSocketFactory factory, ConnectionSpec connectionSpec) { diff --git a/okhttp/src/main/java/io/grpc/okhttp/Utils.java b/okhttp/src/main/java/io/grpc/okhttp/Utils.java index 2dc5f1e1ec9..4546143cf3b 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/Utils.java +++ b/okhttp/src/main/java/io/grpc/okhttp/Utils.java @@ -17,6 +17,7 @@ package io.grpc.okhttp; import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.InternalChannelz; import io.grpc.InternalMetadata; import io.grpc.Metadata; @@ -29,7 +30,6 @@ import java.util.List; import java.util.logging.Level; import java.util.logging.Logger; -import javax.annotation.CheckReturnValue; /** * Common utility methods for OkHttp transport. diff --git a/okhttp/src/test/java/io/grpc/okhttp/AsyncSinkTest.java b/okhttp/src/test/java/io/grpc/okhttp/AsyncSinkTest.java index 46011588b16..478e18d0a2b 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/AsyncSinkTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/AsyncSinkTest.java @@ -30,11 +30,11 @@ import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.verify; -import com.google.common.base.Charsets; import io.grpc.internal.SerializingExecutor; import io.grpc.okhttp.ExceptionHandlingFrameWriter.TransportExceptionHandler; import java.io.IOException; import java.net.Socket; +import java.nio.charset.StandardCharsets; import java.util.Queue; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.Executor; @@ -73,8 +73,8 @@ public void noCoalesceRequired() throws IOException { @Test public void flushCoalescing_shouldNotMergeTwoDistinctFlushes() throws IOException { - byte[] firstData = "a string".getBytes(Charsets.UTF_8); - byte[] secondData = "a longer string".getBytes(Charsets.UTF_8); + byte[] firstData = "a string".getBytes(StandardCharsets.UTF_8); + byte[] secondData = "a longer string".getBytes(StandardCharsets.UTF_8); sink.becomeConnected(mockedSink, socket); Buffer buffer = new Buffer(); @@ -95,8 +95,8 @@ public void flushCoalescing_shouldNotMergeTwoDistinctFlushes() throws IOExceptio @Test public void flushCoalescing_shouldMergeTwoQueuedFlushesAndWrites() throws IOException { - byte[] firstData = "a string".getBytes(Charsets.UTF_8); - byte[] secondData = "a longer string".getBytes(Charsets.UTF_8); + byte[] firstData = "a string".getBytes(StandardCharsets.UTF_8); + byte[] secondData = "a longer string".getBytes(StandardCharsets.UTF_8); Buffer buffer = new Buffer().write(firstData); sink.becomeConnected(mockedSink, socket); sink.write(buffer, buffer.size()); @@ -115,8 +115,8 @@ public void flushCoalescing_shouldMergeTwoQueuedFlushesAndWrites() throws IOExce @Test public void flushCoalescing_shouldMergeWrites() throws IOException { - byte[] firstData = "a string".getBytes(Charsets.UTF_8); - byte[] secondData = "a longer string".getBytes(Charsets.UTF_8); + byte[] firstData = "a string".getBytes(StandardCharsets.UTF_8); + byte[] secondData = "a longer string".getBytes(StandardCharsets.UTF_8); Buffer buffer = new Buffer(); sink.becomeConnected(mockedSink, socket); sink.write(buffer.write(firstData), buffer.size()); diff --git a/okhttp/src/test/java/io/grpc/okhttp/ExceptionHandlingFrameWriterTest.java b/okhttp/src/test/java/io/grpc/okhttp/ExceptionHandlingFrameWriterTest.java index a9d39088844..8829abac034 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/ExceptionHandlingFrameWriterTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/ExceptionHandlingFrameWriterTest.java @@ -16,9 +16,9 @@ package io.grpc.okhttp; -import static com.google.common.base.Charsets.UTF_8; import static com.google.common.truth.Truth.assertThat; import static io.grpc.okhttp.ExceptionHandlingFrameWriter.getLogLevel; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java index 3670cd057c1..89d37536b70 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java @@ -22,6 +22,7 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.mockito.Mockito.mock; import com.google.common.util.concurrent.SettableFuture; @@ -34,6 +35,7 @@ import io.grpc.InsecureChannelCredentials; import io.grpc.ManagedChannel; import io.grpc.TlsChannelCredentials; +import io.grpc.internal.CertificateUtils; import io.grpc.internal.ClientTransportFactory; import io.grpc.internal.ClientTransportFactory.SwapChannelCredentialsResult; import io.grpc.internal.FakeClock; @@ -56,7 +58,6 @@ import javax.security.auth.x500.X500Principal; import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -66,8 +67,6 @@ @RunWith(JUnit4.class) public class OkHttpChannelBuilderTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); @Rule public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule(); @Test @@ -99,10 +98,9 @@ private void overrideAuthorityIsReadableHelper(OkHttpChannelBuilder builder, @Test public void failOverrideInvalidAuthority() { OkHttpChannelBuilder builder = OkHttpChannelBuilder.forAddress("good", 1234); - - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Invalid authority:"); - builder.overrideAuthority("[invalidauthority"); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.overrideAuthority("[invalidauthority")); + assertThat(e).hasMessageThat().isEqualTo("Invalid authority: [invalidauthority"); } @Test @@ -118,17 +116,16 @@ public void enableCheckAuthorityFailOverrideInvalidAuthority() { .disableCheckAuthority() .enableCheckAuthority(); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Invalid authority:"); - builder.overrideAuthority("[invalidauthority"); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.overrideAuthority("[invalidauthority")); + assertThat(e).hasMessageThat().isEqualTo("Invalid authority: [invalidauthority"); } @Test public void failInvalidAuthority() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Invalid host or port"); - - OkHttpChannelBuilder.forAddress("invalid_authority", 1234); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> OkHttpChannelBuilder.forAddress("invalid_authority", 1234)); + assertThat(e.getMessage()).isEqualTo("Invalid host or port: invalid_authority 1234"); } @Test @@ -212,7 +209,7 @@ public void sslSocketFactoryFrom_tls_mtls() throws Exception { TrustManager[] trustManagers; try (InputStream ca = TlsTesting.loadCert("ca.pem")) { - trustManagers = OkHttpChannelBuilder.createTrustManager(ca); + trustManagers = CertificateUtils.createTrustManager(ca); } SSLContext serverContext = SSLContext.getInstance("TLS"); @@ -257,7 +254,7 @@ public void sslSocketFactoryFrom_tls_mtls_keyFile() throws Exception { InputStream ca = TlsTesting.loadCert("ca.pem")) { serverContext.init( OkHttpChannelBuilder.createKeyManager(server1Chain, server1Key), - OkHttpChannelBuilder.createTrustManager(ca), + CertificateUtils.createTrustManager(ca), null); } final SSLServerSocket serverListenSocket = @@ -395,10 +392,10 @@ public ChannelCredentials withoutBearerTokens() { @Test public void failForUsingClearTextSpecDirectly() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("plaintext ConnectionSpec is not accepted"); - - OkHttpChannelBuilder.forAddress("host", 1234).connectionSpec(ConnectionSpec.CLEARTEXT); + OkHttpChannelBuilder builder = OkHttpChannelBuilder.forAddress("host", 1234); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.connectionSpec(ConnectionSpec.CLEARTEXT)); + assertThat(e).hasMessageThat().isEqualTo("plaintext ConnectionSpec is not accepted"); } @Test diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java index 1f716705968..1c98d6ee30d 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java @@ -20,6 +20,7 @@ import static io.grpc.internal.ClientStreamListener.RpcProgress.PROCESSED; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.times; @@ -244,12 +245,13 @@ public void getUnaryRequest() throws IOException { // GET streams send headers after halfClose is called. verify(mockedFrameWriter, times(0)).synStream( eq(false), eq(false), eq(3), eq(0), headersCaptor.capture()); - verify(transport, times(0)).streamReadyToStart(isA(OkHttpClientStream.class)); + verify(transport, times(0)).streamReadyToStart(isA(OkHttpClientStream.class), + isA(String.class)); byte[] msg = "request".getBytes(Charset.forName("UTF-8")); stream.writeMessage(new ByteArrayInputStream(msg)); stream.halfClose(); - verify(transport).streamReadyToStart(eq(stream)); + verify(transport).streamReadyToStart(eq(stream), any(String.class)); stream.transportState().start(3); verify(mockedFrameWriter) diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java index 7347399bfe5..0b571530db4 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java @@ -16,7 +16,6 @@ package io.grpc.okhttp; -import static com.google.common.base.Charsets.UTF_8; import static com.google.common.truth.Truth.assertThat; import static io.grpc.internal.ClientStreamListener.RpcProgress.MISCARRIED; import static io.grpc.internal.ClientStreamListener.RpcProgress.PROCESSED; @@ -25,6 +24,7 @@ import static io.grpc.okhttp.Headers.HTTP_SCHEME_HEADER; import static io.grpc.okhttp.Headers.METHOD_HEADER; import static io.grpc.okhttp.Headers.TE_HEADER; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; @@ -67,13 +67,16 @@ import io.grpc.MethodDescriptor.MethodType; import io.grpc.Status; import io.grpc.Status.Code; -import io.grpc.StatusException; import io.grpc.internal.AbstractStream; +import io.grpc.internal.ClientStream; import io.grpc.internal.ClientStreamListener; import io.grpc.internal.ClientTransport; +import io.grpc.internal.DisconnectError; import io.grpc.internal.FakeClock; +import io.grpc.internal.GoAwayDisconnectError; import io.grpc.internal.GrpcUtil; import io.grpc.internal.ManagedClientTransport; +import io.grpc.internal.SimpleDisconnectError; import io.grpc.okhttp.OkHttpClientTransport.ClientFrameHandler; import io.grpc.okhttp.OkHttpFrameLogger.Direction; import io.grpc.okhttp.internal.Protocol; @@ -116,6 +119,10 @@ import java.util.logging.Logger; import javax.annotation.Nullable; import javax.net.SocketFactory; +import javax.net.ssl.HandshakeCompletedListener; +import javax.net.ssl.HostnameVerifier; +import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSocket; import okio.Buffer; import okio.BufferedSink; import okio.BufferedSource; @@ -185,21 +192,35 @@ public class OkHttpClientTransportTest { @After public void tearDown() { - executor.shutdownNow(); + try { + executor.shutdownNow(); + executor.awaitTermination(10, TimeUnit.SECONDS); + } catch (InterruptedException e) { + // Ignore in a test and continue on as normal. + Thread.currentThread().interrupt(); + } } private void initTransport() throws Exception { startTransport( - DEFAULT_START_STREAM_ID, null, true, null); + DEFAULT_START_STREAM_ID, null, true, null, null); } private void initTransport(int startId) throws Exception { - startTransport(startId, null, true, null); + startTransport(startId, null, true, null, null); + } + + private void startTransport(int startId, @Nullable Runnable connectingCallback, + boolean waitingForConnected, String userAgent, + HostnameVerifier hostnameVerifier) throws Exception { + startTransport(startId, connectingCallback, waitingForConnected, userAgent, hostnameVerifier, + false); } private void startTransport(int startId, @Nullable Runnable connectingCallback, - boolean waitingForConnected, String userAgent) - throws Exception { + boolean waitingForConnected, String userAgent, + HostnameVerifier hostnameVerifier, boolean useSslSocket) + throws Exception { connectedFuture = SettableFuture.create(); final Ticker ticker = new Ticker() { @Override @@ -213,7 +234,11 @@ public Stopwatch get() { return Stopwatch.createUnstarted(ticker); } }; - channelBuilder.socketFactory(new FakeSocketFactory(socket)); + channelBuilder.socketFactory( + new FakeSocketFactory(useSslSocket ? new MockSslSocket(socket) : socket)); + if (hostnameVerifier != null) { + channelBuilder = channelBuilder.hostnameVerifier(hostnameVerifier); + } clientTransport = new OkHttpClientTransport( channelBuilder.buildTransportFactory(), userAgent, @@ -241,12 +266,37 @@ public void testToString() throws Exception { /*userAgent=*/ null, EAG_ATTRS, NO_PROXY, - tooManyPingsRunnable); + tooManyPingsRunnable, + null); String s = clientTransport.toString(); assertTrue("Unexpected: " + s, s.contains("OkHttpClientTransport")); assertTrue("Unexpected: " + s, s.contains(address.toString())); } + @Test + public void testTransportExecutorWithTooFewThreads() throws Exception { + ExecutorService fixedPoolExecutor = Executors.newFixedThreadPool(1); + channelBuilder.transportExecutor(fixedPoolExecutor); + InetSocketAddress address = InetSocketAddress.createUnresolved("hostname", 31415); + clientTransport = new OkHttpClientTransport( + channelBuilder.buildTransportFactory(), + address, + "hostname", + null, + EAG_ATTRS, + NO_PROXY, + tooManyPingsRunnable, + null); + clientTransport.start(transportListener); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown(statusCaptor.capture(), + eq(new GoAwayDisconnectError(GrpcUtil.Http2Error.INTERNAL_ERROR))); + Status capturedStatus = statusCaptor.getValue(); + assertEquals("Timed out waiting for second handshake thread. " + + "The transport executor pool may have run out of threads", + capturedStatus.getDescription()); + } + /** * Test logging is functioning correctly for client received Http/2 frames. Not intended to test * actual frame content being logged. @@ -278,7 +328,7 @@ public void close() throws SecurityException { assertThat(log.getLevel()).isEqualTo(Level.FINE); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -368,7 +418,7 @@ public void maxMessageSizeShouldBeEnforced() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -421,11 +471,11 @@ public void nextFrameThrowIoException() throws Exception { initTransport(); MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); stream1.request(1); - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); stream2.request(1); @@ -441,7 +491,8 @@ public void nextFrameThrowIoException() throws Exception { assertEquals(NETWORK_ISSUE_MESSAGE, listener1.status.getCause().getMessage()); assertEquals(Status.INTERNAL.getCode(), listener2.status.getCode()); assertEquals(NETWORK_ISSUE_MESSAGE, listener2.status.getCause().getMessage()); - verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown(isA(Status.class)); + verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown(isA(Status.class), + any(DisconnectError.class)); verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); shutdownAndVerify(); } @@ -455,7 +506,7 @@ public void nextFrameThrowIoException() throws Exception { public void nextFrameThrowsError() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -467,7 +518,8 @@ public void nextFrameThrowsError() throws Exception { assertEquals(0, activeStreamCount()); assertEquals(Status.INTERNAL.getCode(), listener.status.getCode()); assertEquals(ERROR_MESSAGE, listener.status.getCause().getMessage()); - verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown(isA(Status.class)); + verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown(isA(Status.class), + any(DisconnectError.class)); verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); shutdownAndVerify(); } @@ -476,14 +528,15 @@ public void nextFrameThrowsError() throws Exception { public void nextFrameReturnFalse() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); frameReader.nextFrameAtEndOfStream(); listener.waitUntilStreamClosed(); assertEquals(Status.UNAVAILABLE.getCode(), listener.status.getCode()); - verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown(isA(Status.class)); + verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown(isA(Status.class), + any(DisconnectError.class)); verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); shutdownAndVerify(); } @@ -494,7 +547,7 @@ public void readMessages() throws Exception { final int numMessages = 10; final String message = "Hello Client"; MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(numMessages); @@ -525,7 +578,8 @@ public void receivedHeadersForInvalidStreamShouldKillConnection() throws Excepti HeadersMode.HTTP_20_HEADERS); verify(frameWriter, timeout(TIME_OUT_MS)) .goAway(eq(0), eq(ErrorCode.PROTOCOL_ERROR), any(byte[].class)); - verify(transportListener).transportShutdown(isA(Status.class)); + verify(transportListener).transportShutdown(isA(Status.class), + any(DisconnectError.class)); verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); shutdownAndVerify(); } @@ -537,7 +591,8 @@ public void receivedDataForInvalidStreamShouldKillConnection() throws Exception 1000, 1000); verify(frameWriter, timeout(TIME_OUT_MS)) .goAway(eq(0), eq(ErrorCode.PROTOCOL_ERROR), any(byte[].class)); - verify(transportListener).transportShutdown(isA(Status.class)); + verify(transportListener).transportShutdown(isA(Status.class), + any(DisconnectError.class)); verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); shutdownAndVerify(); } @@ -546,7 +601,7 @@ public void receivedDataForInvalidStreamShouldKillConnection() throws Exception public void invalidInboundHeadersCancelStream() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -571,7 +626,7 @@ public void invalidInboundHeadersCancelStream() throws Exception { public void invalidInboundTrailersPropagateToMetadata() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -591,7 +646,7 @@ public void invalidInboundTrailersPropagateToMetadata() throws Exception { public void readStatus() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertContainStream(3); @@ -605,7 +660,7 @@ public void readStatus() throws Exception { public void receiveReset() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertContainStream(3); @@ -622,7 +677,7 @@ public void receiveReset() throws Exception { public void receiveResetNoError() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertContainStream(3); @@ -643,7 +698,7 @@ public void receiveResetNoError() throws Exception { public void cancelStream() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); getStream(3).cancel(Status.CANCELLED); @@ -658,7 +713,7 @@ public void cancelStream() throws Exception { public void addDefaultUserAgent() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); Header userAgentHeader = new Header(GrpcUtil.USER_AGENT_KEY.name(), @@ -675,9 +730,9 @@ public void addDefaultUserAgent() throws Exception { @Test public void overrideDefaultUserAgent() throws Exception { - startTransport(3, null, true, "fakeUserAgent"); + startTransport(3, null, true, "fakeUserAgent", null); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); List
expectedHeaders = Arrays.asList(HTTP_SCHEME_HEADER, METHOD_HEADER, @@ -696,7 +751,7 @@ public void overrideDefaultUserAgent() throws Exception { public void cancelStreamForDeadlineExceeded() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); getStream(3).cancel(Status.DEADLINE_EXCEEDED); @@ -710,7 +765,7 @@ public void writeMessage() throws Exception { initTransport(); final String message = "Hello Server"; MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); InputStream input = new ByteArrayInputStream(message.getBytes(UTF_8)); @@ -725,6 +780,65 @@ public void writeMessage() throws Exception { shutdownAndVerify(); } + @Test + public void perRpcAuthoritySpecified_verificationSkippedInPlainTextConnection() + throws Exception { + initTransport(); + final String message = "Hello Server"; + MockStreamListener listener = new MockStreamListener(); + ClientStream stream = + clientTransport.newStream(method, new Metadata(), + CallOptions.DEFAULT.withAuthority("some-authority"), tracers); + stream.start(listener); + InputStream input = new ByteArrayInputStream(message.getBytes(UTF_8)); + assertEquals(12, input.available()); + stream.writeMessage(input); + stream.flush(); + verify(frameWriter, timeout(TIME_OUT_MS)) + .data(eq(false), eq(3), any(Buffer.class), eq(12 + HEADER_LENGTH)); + Buffer sentFrame = capturedBuffer.poll(); + assertEquals(createMessageFrame(message), sentFrame); + stream.cancel(Status.CANCELLED); + shutdownAndVerify(); + } + + @Test + public void perRpcAuthoritySpecified_hostnameVerification_ignoredForNonSslSocket() + throws Exception { + startTransport( + DEFAULT_START_STREAM_ID, null, true, null, + (hostname, session) -> false, false); + ClientStream unused = + clientTransport.newStream(method, new Metadata(), + CallOptions.DEFAULT.withAuthority("some-authority"), tracers); + shutdownAndVerify(); + } + + @Test + public void perRpcAuthoritySpecified_hostnameVerification_SslSocket_successCase() + throws Exception { + startTransport( + DEFAULT_START_STREAM_ID, null, true, null, + (hostname, session) -> true, true); + ClientStream unused = + clientTransport.newStream(method, new Metadata(), + CallOptions.DEFAULT.withAuthority("some-authority"), tracers); + shutdownAndVerify(); + } + + @Test + public void perRpcAuthoritySpecified_hostnameVerification_SslSocket_flagDisabled() + throws Exception { + startTransport( + DEFAULT_START_STREAM_ID, null, true, null, + (hostname, session) -> false, true); + ClientStream clientStream = + clientTransport.newStream(method, new Metadata(), + CallOptions.DEFAULT.withAuthority("some-authority"), tracers); + assertThat(clientStream).isInstanceOf(OkHttpClientStream.class); + shutdownAndVerify(); + } + @Test public void transportTracer_windowSizeDefault() throws Exception { initTransport(); @@ -751,12 +865,12 @@ public void windowUpdate() throws Exception { initTransport(); MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); stream1.request(2); - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); stream2.request(2); @@ -821,7 +935,7 @@ public void windowUpdate() throws Exception { public void windowUpdateWithInboundFlowControl() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); int messageLength = INITIAL_WINDOW_SIZE / 2 + 1; @@ -858,7 +972,7 @@ public void windowUpdateWithInboundFlowControl() throws Exception { public void outboundFlowControl() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); @@ -904,7 +1018,7 @@ public void outboundFlowControl_smallWindowSize() throws Exception { setInitialWindowSize(initialOutboundWindowSize); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); @@ -947,7 +1061,7 @@ public void outboundFlowControl_bigWindowSize() throws Exception { frameHandler().windowUpdate(0, 65535); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); @@ -983,7 +1097,7 @@ public void outboundFlowControl_bigWindowSize() throws Exception { public void outboundFlowControlWithInitialWindowSizeChange() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); int messageLength = 20; @@ -1029,7 +1143,7 @@ public void outboundFlowControlWithInitialWindowSizeChange() throws Exception { public void outboundFlowControlWithInitialWindowSizeChangeInMiddleOfStream() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); int messageLength = 20; @@ -1064,17 +1178,18 @@ public void stopNormally() throws Exception { initTransport(); MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); assertEquals(2, activeStreamCount()); clientTransport.shutdown(SHUTDOWN_REASON); assertEquals(2, activeStreamCount()); - verify(transportListener).transportShutdown(same(SHUTDOWN_REASON)); + verify(transportListener).transportShutdown(same(SHUTDOWN_REASON), + eq(SimpleDisconnectError.SUBCHANNEL_SHUTDOWN)); stream1.cancel(Status.CANCELLED); stream2.cancel(Status.CANCELLED); @@ -1094,11 +1209,11 @@ public void receiveGoAway() throws Exception { // start 2 streams. MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); stream1.request(1); - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); stream2.request(1); @@ -1108,7 +1223,8 @@ public void receiveGoAway() throws Exception { frameHandler().goAway(3, ErrorCode.CANCEL, ByteString.EMPTY); // Transport should be in STOPPING state. - verify(transportListener).transportShutdown(isA(Status.class)); + verify(transportListener).transportShutdown(isA(Status.class), + any(DisconnectError.class)); verify(transportListener, never()).transportTerminated(); // Stream 2 should be closed. @@ -1121,7 +1237,7 @@ public void receiveGoAway() throws Exception { // But stream 1 should be able to send. final String sentMessage = "Should I also go away?"; - OkHttpClientStream stream = getStream(3); + ClientStream stream = getStream(3); InputStream input = new ByteArrayInputStream(sentMessage.getBytes(UTF_8)); assertEquals(22, input.available()); stream.writeMessage(input); @@ -1153,7 +1269,7 @@ public void streamIdExhausted() throws Exception { initTransport(startId); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -1178,7 +1294,8 @@ public void streamIdExhausted() throws Exception { // Should only have the first message delivered. assertEquals(message, listener.messages.get(0)); verify(frameWriter, timeout(TIME_OUT_MS)).rstStream(eq(startId), eq(ErrorCode.CANCEL)); - verify(transportListener).transportShutdown(isA(Status.class)); + verify(transportListener).transportShutdown(isA(Status.class), + any(DisconnectError.class)); verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); shutdownAndVerify(); } @@ -1189,11 +1306,11 @@ public void pendingStreamSucceed() throws Exception { setMaxConcurrentStreams(1); final MockStreamListener listener1 = new MockStreamListener(); final MockStreamListener listener2 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); // The second stream should be pending. - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); String sentMessage = "hello"; @@ -1226,7 +1343,7 @@ public void pendingStreamCancelled() throws Exception { initTransport(); setMaxConcurrentStreams(0); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); waitForStreamPending(1); @@ -1245,11 +1362,11 @@ public void pendingStreamFailedByGoAway() throws Exception { setMaxConcurrentStreams(1); final MockStreamListener listener1 = new MockStreamListener(); final MockStreamListener listener2 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); // The second stream should be pending. - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); @@ -1275,7 +1392,7 @@ public void pendingStreamSucceedAfterShutdown() throws Exception { setMaxConcurrentStreams(0); final MockStreamListener listener = new MockStreamListener(); // The second stream should be pending. - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); waitForStreamPending(1); @@ -1299,15 +1416,15 @@ public void pendingStreamFailedByIdExhausted() throws Exception { final MockStreamListener listener2 = new MockStreamListener(); final MockStreamListener listener3 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); // The second and third stream should be pending. - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); - OkHttpClientStream stream3 = + ClientStream stream3 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream3.start(listener3); @@ -1331,7 +1448,7 @@ public void pendingStreamFailedByIdExhausted() throws Exception { public void receivingWindowExceeded() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -1383,7 +1500,7 @@ public void duplexStreamingHeadersShouldNotBeFlushed() throws Exception { private void shouldHeadersBeFlushed(boolean shouldBeFlushed) throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); verify(frameWriter, timeout(TIME_OUT_MS)).synStream( @@ -1400,7 +1517,7 @@ private void shouldHeadersBeFlushed(boolean shouldBeFlushed) throws Exception { public void receiveDataWithoutHeader() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -1423,7 +1540,7 @@ public void receiveDataWithoutHeader() throws Exception { public void receiveDataWithoutHeaderAndTrailer() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -1447,7 +1564,7 @@ public void receiveDataWithoutHeaderAndTrailer() throws Exception { public void receiveLongEnoughDataWithoutHeaderAndTrailer() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -1469,7 +1586,7 @@ public void receiveLongEnoughDataWithoutHeaderAndTrailer() throws Exception { public void receiveDataForUnknownStreamUpdateConnectionWindow() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.cancel(Status.CANCELLED); @@ -1489,7 +1606,8 @@ public void receiveDataForUnknownStreamUpdateConnectionWindow() throws Exception (int) buffer.size()); verify(frameWriter, timeout(TIME_OUT_MS)) .goAway(eq(0), eq(ErrorCode.PROTOCOL_ERROR), any(byte[].class)); - verify(transportListener).transportShutdown(isA(Status.class)); + verify(transportListener).transportShutdown(isA(Status.class), + any(DisconnectError.class)); verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); shutdownAndVerify(); } @@ -1498,7 +1616,7 @@ public void receiveDataForUnknownStreamUpdateConnectionWindow() throws Exception public void receiveWindowUpdateForUnknownStream() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.cancel(Status.CANCELLED); @@ -1509,7 +1627,8 @@ public void receiveWindowUpdateForUnknownStream() throws Exception { frameHandler().windowUpdate(5, 73); verify(frameWriter, timeout(TIME_OUT_MS)) .goAway(eq(0), eq(ErrorCode.PROTOCOL_ERROR), any(byte[].class)); - verify(transportListener).transportShutdown(isA(Status.class)); + verify(transportListener).transportShutdown(isA(Status.class), + any(DisconnectError.class)); verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); shutdownAndVerify(); } @@ -1518,7 +1637,7 @@ public void receiveWindowUpdateForUnknownStream() throws Exception { public void shouldBeInitiallyReady() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertTrue(stream.isReady()); @@ -1536,7 +1655,7 @@ public void notifyOnReady() throws Exception { AbstractStream.TransportState.DEFAULT_ONREADY_THRESHOLD - HEADER_LENGTH - 1; setInitialWindowSize(0); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertTrue(stream.isReady()); @@ -1642,16 +1761,14 @@ public void ping_failsWhenTransportShutdown() throws Exception { clientTransport.shutdown(SHUTDOWN_REASON); // ping failed on channel shutdown assertEquals(1, callback.invocationCount); - assertTrue(callback.failureCause instanceof StatusException); - assertSame(SHUTDOWN_REASON, ((StatusException) callback.failureCause).getStatus()); + assertSame(SHUTDOWN_REASON, callback.failureCause); // now that handler is in terminal state, all future pings fail immediately callback = new PingCallbackImpl(); clientTransport.ping(callback, MoreExecutors.directExecutor()); assertEquals(1, getTransportStats(clientTransport).keepAlivesSent); assertEquals(1, callback.invocationCount); - assertTrue(callback.failureCause instanceof StatusException); - assertSame(SHUTDOWN_REASON, ((StatusException) callback.failureCause).getStatus()); + assertSame(SHUTDOWN_REASON, callback.failureCause); shutdownAndVerify(); } @@ -1666,18 +1783,14 @@ public void ping_failsIfTransportFails() throws Exception { clientTransport.onException(new IOException()); // ping failed on error assertEquals(1, callback.invocationCount); - assertTrue(callback.failureCause instanceof StatusException); - assertEquals(Status.Code.UNAVAILABLE, - ((StatusException) callback.failureCause).getStatus().getCode()); + assertEquals(Status.Code.UNAVAILABLE, callback.failureCause.getCode()); // now that handler is in terminal state, all future pings fail immediately callback = new PingCallbackImpl(); clientTransport.ping(callback, MoreExecutors.directExecutor()); assertEquals(1, getTransportStats(clientTransport).keepAlivesSent); assertEquals(1, callback.invocationCount); - assertTrue(callback.failureCause instanceof StatusException); - assertEquals(Status.Code.UNAVAILABLE, - ((StatusException) callback.failureCause).getStatus().getCode()); + assertEquals(Status.Code.UNAVAILABLE, callback.failureCause.getCode()); shutdownAndVerify(); } @@ -1689,7 +1802,7 @@ public void shutdownDuringConnecting() throws Exception { DEFAULT_START_STREAM_ID, connectingCallback, false, - null); + null, null); clientTransport.shutdown(SHUTDOWN_REASON); delayed.set(null); shutdownAndVerify(); @@ -1704,7 +1817,8 @@ public void invalidAuthorityPropagates() { "userAgent", EAG_ATTRS, NO_PROXY, - tooManyPingsRunnable); + tooManyPingsRunnable, + null); String host = clientTransport.getOverridenHost(); int port = clientTransport.getOverridenPort(); @@ -1722,13 +1836,15 @@ public void unreachableServer() throws Exception { "userAgent", EAG_ATTRS, NO_PROXY, - tooManyPingsRunnable); + tooManyPingsRunnable, + null); ManagedClientTransport.Listener listener = mock(ManagedClientTransport.Listener.class); clientTransport.start(listener); - ArgumentCaptor captor = ArgumentCaptor.forClass(Status.class); - verify(listener, timeout(TIME_OUT_MS)).transportShutdown(captor.capture()); - Status status = captor.getValue(); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + verify(listener, timeout(TIME_OUT_MS)).transportShutdown(statusCaptor.capture(), + eq(new GoAwayDisconnectError(GrpcUtil.Http2Error.INTERNAL_ERROR))); + Status status = statusCaptor.getValue(); assertEquals(Status.UNAVAILABLE.getCode(), status.getCode()); assertTrue(status.getCause().toString(), status.getCause() instanceof IOException); @@ -1752,13 +1868,15 @@ public void customSocketFactory() throws Exception { "userAgent", EAG_ATTRS, NO_PROXY, - tooManyPingsRunnable); + tooManyPingsRunnable, + null); ManagedClientTransport.Listener listener = mock(ManagedClientTransport.Listener.class); clientTransport.start(listener); - ArgumentCaptor captor = ArgumentCaptor.forClass(Status.class); - verify(listener, timeout(TIME_OUT_MS)).transportShutdown(captor.capture()); - Status status = captor.getValue(); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + verify(listener, timeout(TIME_OUT_MS)).transportShutdown(statusCaptor.capture(), + eq(new GoAwayDisconnectError(GrpcUtil.Http2Error.INTERNAL_ERROR))); + Status status = statusCaptor.getValue(); assertEquals(Status.UNAVAILABLE.getCode(), status.getCode()); assertSame(exception, status.getCause()); } @@ -1777,7 +1895,8 @@ public void proxy_200() throws Exception { .setTargetAddress(targetAddress) .setProxyAddress(new InetSocketAddress("localhost", serverSocket.getLocalPort())) .build(), - tooManyPingsRunnable); + tooManyPingsRunnable, + null); clientTransport.start(transportListener); Socket sock = serverSocket.accept(); @@ -1806,7 +1925,8 @@ public void proxy_200() throws Exception { }); sock.getOutputStream().flush(); - verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown(isA(Status.class)); + verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown(isA(Status.class), + any(DisconnectError.class)); while (sock.getInputStream().read() != -1) {} verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); sock.close(); @@ -1826,7 +1946,8 @@ public void proxy_500() throws Exception { .setTargetAddress(targetAddress) .setProxyAddress(new InetSocketAddress("localhost", serverSocket.getLocalPort())) .build(), - tooManyPingsRunnable); + tooManyPingsRunnable, + null); clientTransport.start(transportListener); Socket sock = serverSocket.accept(); @@ -1845,17 +1966,18 @@ public void proxy_500() throws Exception { assertEquals(-1, sock.getInputStream().read()); - ArgumentCaptor captor = ArgumentCaptor.forClass(Status.class); - verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown(captor.capture()); - Status error = captor.getValue(); - assertTrue("Status didn't contain error code: " + captor.getValue(), - error.getDescription().contains("500")); - assertTrue("Status didn't contain error description: " + captor.getValue(), - error.getDescription().contains("OH NO")); - assertTrue("Status didn't contain error text: " + captor.getValue(), - error.getDescription().contains(errorText)); - assertEquals("Not UNAVAILABLE: " + captor.getValue(), - Status.UNAVAILABLE.getCode(), error.getCode()); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown(statusCaptor.capture(), + eq(new GoAwayDisconnectError(GrpcUtil.Http2Error.INTERNAL_ERROR))); + Status status = statusCaptor.getValue(); + assertTrue("Status didn't contain error code: " + statusCaptor.getValue(), + status.getDescription().contains("500")); + assertTrue("Status didn't contain error description: " + statusCaptor.getValue(), + status.getDescription().contains("OH NO")); + assertTrue("Status didn't contain error text: " + statusCaptor.getValue(), + status.getDescription().contains(errorText)); + assertEquals("Not UNAVAILABLE: " + statusCaptor.getValue(), + Status.UNAVAILABLE.getCode(), status.getCode()); sock.close(); verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); } @@ -1874,20 +1996,22 @@ public void proxy_immediateServerClose() throws Exception { .setTargetAddress(targetAddress) .setProxyAddress(new InetSocketAddress("localhost", serverSocket.getLocalPort())) .build(), - tooManyPingsRunnable); + tooManyPingsRunnable, + null); clientTransport.start(transportListener); Socket sock = serverSocket.accept(); serverSocket.close(); sock.close(); - ArgumentCaptor captor = ArgumentCaptor.forClass(Status.class); - verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown(captor.capture()); - Status error = captor.getValue(); - assertTrue("Status didn't contain proxy: " + captor.getValue(), - error.getDescription().contains("proxy")); - assertEquals("Not UNAVAILABLE: " + captor.getValue(), - Status.UNAVAILABLE.getCode(), error.getCode()); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown(statusCaptor.capture(), + eq(new GoAwayDisconnectError(GrpcUtil.Http2Error.INTERNAL_ERROR))); + Status status = statusCaptor.getValue(); + assertTrue("Status didn't contain proxy: " + statusCaptor.getValue(), + status.getDescription().contains("proxy")); + assertEquals("Not UNAVAILABLE: " + statusCaptor.getValue(), + Status.UNAVAILABLE.getCode(), status.getCode()); verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); } @@ -1905,7 +2029,8 @@ public void proxy_serverHangs() throws Exception { .setTargetAddress(targetAddress) .setProxyAddress(new InetSocketAddress("localhost", serverSocket.getLocalPort())) .build(), - tooManyPingsRunnable); + tooManyPingsRunnable, + null); clientTransport.proxySocketTimeout = 10; clientTransport.start(transportListener); @@ -1917,7 +2042,8 @@ public void proxy_serverHangs() throws Exception { assertEquals("Host: theservice:80", reader.readLine()); while (!"".equals(reader.readLine())) {} - verify(transportListener, timeout(200)).transportShutdown(any(Status.class)); + verify(transportListener, timeout(200)).transportShutdown(any(Status.class), + any(DisconnectError.class)); verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); sock.close(); } @@ -1972,13 +2098,13 @@ public void goAway_streamListenerRpcProgress() throws Exception { MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); MockStreamListener listener3 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); - OkHttpClientStream stream3 = + ClientStream stream3 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream3.start(listener3); waitForStreamPending(1); @@ -2012,13 +2138,13 @@ public void reset_streamListenerRpcProgress() throws Exception { MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); MockStreamListener listener3 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); - OkHttpClientStream stream3 = + ClientStream stream3 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream3.start(listener3); @@ -2054,13 +2180,13 @@ public void shutdownNow_streamListenerRpcProgress() throws Exception { MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); MockStreamListener listener3 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); - OkHttpClientStream stream3 = + ClientStream stream3 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream3.start(listener3); waitForStreamPending(1); @@ -2080,6 +2206,26 @@ public void shutdownNow_streamListenerRpcProgress() throws Exception { assertEquals(MISCARRIED, listener3.rpcProgress); } + @Test + public void finishedStreamRemovedFromInUseState() throws Exception { + initTransport(); + setMaxConcurrentStreams(1); + final MockStreamListener listener = new MockStreamListener(); + OkHttpClientStream stream = clientTransport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); + stream.start(listener); + OkHttpClientStream pendingStream = clientTransport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); + pendingStream.start(listener); + waitForStreamPending(1); + clientTransport.finishStream(stream.transportState().id(), Status.OK, PROCESSED, + false, null, null); + verify(transportListener).transportInUse(true); + clientTransport.finishStream(pendingStream.transportState().id(), Status.OK, PROCESSED, + false, null, null); + verify(transportListener).transportInUse(false); + } + private int activeStreamCount() { return clientTransport.getActiveStreams().length; } @@ -2109,7 +2255,7 @@ private void waitForStreamPending(int expected) throws Exception { private void assertNewStreamFail() throws Exception { MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); listener.waitUntilStreamClosed(); @@ -2340,10 +2486,128 @@ public InputStream getInputStream() { } } + private static class MockSslSocket extends SSLSocket { + private Socket delegate; + + MockSslSocket(Socket socket) { + delegate = socket; + } + + @Override + public String[] getSupportedCipherSuites() { + return new String[0]; + } + + @Override + public String[] getEnabledCipherSuites() { + return new String[0]; + } + + @Override + public void setEnabledCipherSuites(String[] suites) { + + } + + @Override + public String[] getSupportedProtocols() { + return new String[0]; + } + + @Override + public String[] getEnabledProtocols() { + return new String[0]; + } + + @Override + public void setEnabledProtocols(String[] protocols) { + + } + + @Override + public SSLSession getSession() { + return null; + } + + @Override + public void addHandshakeCompletedListener(HandshakeCompletedListener listener) { + + } + + @Override + public void removeHandshakeCompletedListener(HandshakeCompletedListener listener) { + + } + + @Override + public void startHandshake() throws IOException { + + } + + @Override + public void setUseClientMode(boolean mode) { + + } + + @Override + public boolean getUseClientMode() { + return false; + } + + @Override + public void setNeedClientAuth(boolean need) { + + } + + @Override + public boolean getNeedClientAuth() { + return false; + } + + @Override + public void setWantClientAuth(boolean want) { + + } + + @Override + public boolean getWantClientAuth() { + return false; + } + + @Override + public void setEnableSessionCreation(boolean flag) { + + } + + @Override + public boolean getEnableSessionCreation() { + return false; + } + + @Override + public synchronized void close() throws IOException { + delegate.close(); + } + + @Override + public SocketAddress getLocalSocketAddress() { + return delegate.getLocalSocketAddress(); + } + + @Override + public OutputStream getOutputStream() throws IOException { + return delegate.getOutputStream(); + } + + @Override + public InputStream getInputStream() throws IOException { + return delegate.getInputStream(); + } + } + static class PingCallbackImpl implements ClientTransport.PingCallback { int invocationCount; long roundTripTime; - Throwable failureCause; + Status failureCause; @Override public void onSuccess(long roundTripTimeNanos) { @@ -2352,7 +2616,7 @@ public void onSuccess(long roundTripTimeNanos) { } @Override - public void onFailure(Throwable cause) { + public void onFailure(Status cause) { invocationCount++; this.failureCause = cause; } diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpProtocolNegotiatorTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpProtocolNegotiatorTest.java index 3a4a21b2467..4353dc2597b 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpProtocolNegotiatorTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpProtocolNegotiatorTest.java @@ -16,10 +16,12 @@ package io.grpc.okhttp; -import static com.google.common.base.Charsets.UTF_8; +import static com.google.common.truth.Truth.assertThat; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; @@ -37,9 +39,7 @@ import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSession; import javax.net.ssl.SSLSocket; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentMatchers; @@ -49,9 +49,6 @@ */ @RunWith(JUnit4.class) public class OkHttpProtocolNegotiatorTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); - private final SSLSocket sock = mock(SSLSocket.class); private final Platform platform = mock(Platform.class); @@ -118,21 +115,19 @@ public void negotiate_handshakeFails() throws IOException { OkHttpProtocolNegotiator negotiator = OkHttpProtocolNegotiator.get(); doReturn(parameters).when(sock).getSSLParameters(); doThrow(new IOException()).when(sock).startHandshake(); - thrown.expect(IOException.class); - - negotiator.negotiate(sock, "hostname", ImmutableList.of(Protocol.HTTP_2)); + assertThrows(IOException.class, + () -> negotiator.negotiate(sock, "hostname", ImmutableList.of(Protocol.HTTP_2))); } @Test - public void negotiate_noSelectedProtocol() throws Exception { + public void negotiate_noSelectedProtocol() { Platform platform = mock(Platform.class); OkHttpProtocolNegotiator negotiator = new OkHttpProtocolNegotiator(platform); - thrown.expect(RuntimeException.class); - thrown.expectMessage("TLS ALPN negotiation failed"); - - negotiator.negotiate(sock, "hostname", ImmutableList.of(Protocol.HTTP_2)); + RuntimeException e = assertThrows(RuntimeException.class, + () -> negotiator.negotiate(sock, "hostname", ImmutableList.of(Protocol.HTTP_2))); + assertThat(e).hasMessageThat().isEqualTo("TLS ALPN negotiation failed with protocols: [h2]"); } @Test @@ -150,7 +145,7 @@ public void negotiate_success() throws Exception { // Checks that the super class is properly invoked. @Test - public void negotiate_android_handshakeFails() throws Exception { + public void negotiate_android_handshakeFails() { when(platform.getTlsExtensionType()).thenReturn(TlsExtensionType.ALPN_AND_NPN); AndroidNegotiator negotiator = new AndroidNegotiator(platform); @@ -161,10 +156,9 @@ public void startHandshake() throws IOException { } }; - thrown.expect(IOException.class); - thrown.expectMessage("expected"); - - negotiator.negotiate(androidSock, "hostname", ImmutableList.of(Protocol.HTTP_2)); + IOException e = assertThrows(IOException.class, + () -> negotiator.negotiate(androidSock, "hostname", ImmutableList.of(Protocol.HTTP_2))); + assertThat(e).hasMessageThat().isEqualTo("expected"); } @VisibleForTesting diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpReadableBufferTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpReadableBufferTest.java index 4aeeae2fa8b..be8dbf0e62b 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpReadableBufferTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpReadableBufferTest.java @@ -44,18 +44,6 @@ public void setup() { } } - @Override - @Test - public void readToByteBufferShouldSucceed() { - // Not supported. - } - - @Override - @Test - public void partialReadToByteBufferShouldSucceed() { - // Not supported. - } - @Override @Test public void markAndResetWithReadShouldSucceed() { diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java index 3f88c35e017..00db6e1d339 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java @@ -16,12 +16,12 @@ package io.grpc.okhttp; -import static com.google.common.base.Charsets.UTF_8; import static com.google.common.truth.Truth.assertThat; import static io.grpc.okhttp.Headers.CONTENT_TYPE_HEADER; import static io.grpc.okhttp.Headers.HTTP_SCHEME_HEADER; import static io.grpc.okhttp.Headers.METHOD_HEADER; import static io.grpc.okhttp.Headers.TE_HEADER; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.mockito.AdditionalAnswers.answerVoid; import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; @@ -34,6 +34,7 @@ import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.verify; +import com.google.common.collect.Lists; import com.google.common.io.ByteStreams; import io.grpc.Attributes; import io.grpc.InternalChannelz.SocketStats; @@ -62,6 +63,7 @@ import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.Deque; import java.util.List; import java.util.concurrent.CountDownLatch; @@ -919,8 +921,9 @@ public void httpGet_failsWith405() throws Exception { CONTENT_TYPE_HEADER, TE_HEADER)); clientFrameWriter.flush(); - - verifyHttpError(1, 405, Status.Code.INTERNAL, "HTTP Method is not supported: GET"); + List
extraHeaders = Lists.newArrayList(new Header("allow", "POST")); + verifyHttpError(1, 405, Status.Code.INTERNAL, "HTTP Method is not supported: GET", + extraHeaders); shutdownAndTerminate(/*lastStreamId=*/ 1); } @@ -976,7 +979,8 @@ public void httpErrorsAdhereToFlowControl() throws Exception { new Header(":status", "405"), new Header("content-type", "text/plain; charset=utf-8"), new Header("grpc-status", "" + Status.Code.INTERNAL.value()), - new Header("grpc-message", errorDescription)); + new Header("grpc-message", errorDescription), + new Header("allow", "POST")); assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); verify(clientFramesRead) .headers(false, false, 1, -1, responseHeaders, HeadersMode.HTTP_20_HEADERS); @@ -1264,6 +1268,60 @@ public void keepAliveEnforcer_noticesActive() throws Exception { eq(ByteString.encodeString("too_many_pings", GrpcUtil.US_ASCII))); } + @Test + public void keepAliveEnforcer_doesNotEnforcePingAcks() throws Exception { + serverBuilder.permitKeepAliveTime(1, TimeUnit.HOURS) + .permitKeepAliveWithoutCalls(true); + initTransport(); + handshake(); + + for (int i = 0; i < KeepAliveEnforcer.MAX_PING_STRIKES + 2; i++) { + int serverPingId = 0xDEAD + i; + clientFrameWriter.ping(true, serverPingId, 0); + clientFrameWriter.flush(); + } + + for (int i = 0; i < KeepAliveEnforcer.MAX_PING_STRIKES; i++) { + pingPong(); + } + + pingPongId++; + clientFrameWriter.ping(false, pingPongId, 0); + clientFrameWriter.flush(); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).goAway(0, ErrorCode.ENHANCE_YOUR_CALM, + ByteString.encodeString("too_many_pings", GrpcUtil.US_ASCII)); + } + + @Test + public void maxConcurrentCallsPerConnection_failsWithRst() throws Exception { + int maxConcurrentCallsPerConnection = 1; + serverBuilder.maxConcurrentCallsPerConnection(maxConcurrentCallsPerConnection); + initTransport(); + handshake(); + + ArgumentCaptor settingsCaptor = ArgumentCaptor.forClass(Settings.class); + verify(clientFramesRead).settings(eq(false), settingsCaptor.capture()); + final Settings settings = settingsCaptor.getValue(); + assertThat(OkHttpSettingsUtil.get(settings, OkHttpSettingsUtil.MAX_CONCURRENT_STREAMS)) + .isEqualTo(maxConcurrentCallsPerConnection); + + final List
headers = Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService/doit"), + CONTENT_TYPE_HEADER, + TE_HEADER); + + clientFrameWriter.headers(1, headers); + clientFrameWriter.headers(3, headers); + clientFrameWriter.flush(); + + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).rstStream(3, ErrorCode.REFUSED_STREAM); + } + private void initTransport() throws Exception { serverTransport = new OkHttpServerTransport( new OkHttpServerTransport.Config(serverBuilder, Arrays.asList()), @@ -1369,11 +1427,18 @@ private void pingPong() throws IOException { private void verifyHttpError( int streamId, int httpCode, Status.Code grpcCode, String errorDescription) throws Exception { - List
responseHeaders = Arrays.asList( + verifyHttpError(streamId, httpCode, grpcCode, errorDescription, Collections.emptyList()); + } + + private void verifyHttpError( + int streamId, int httpCode, Status.Code grpcCode, String errorDescription, + List
extraHeaders) throws Exception { + List
responseHeaders = Lists.newArrayList( new Header(":status", "" + httpCode), new Header("content-type", "text/plain; charset=utf-8"), new Header("grpc-status", "" + grpcCode.value()), new Header("grpc-message", errorDescription)); + responseHeaders.addAll(extraHeaders); assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); verify(clientFramesRead) .headers(false, false, streamId, -1, responseHeaders, HeadersMode.HTTP_20_HEADERS); diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpWritableBufferAllocatorTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpWritableBufferAllocatorTest.java index e606b6b9a50..c19224822a8 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpWritableBufferAllocatorTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpWritableBufferAllocatorTest.java @@ -16,11 +16,13 @@ package io.grpc.okhttp; +import static io.grpc.okhttp.OkHttpWritableBufferAllocator.SEGMENT_SIZE_COPY; import static org.junit.Assert.assertEquals; import io.grpc.internal.WritableBuffer; import io.grpc.internal.WritableBufferAllocator; import io.grpc.internal.WritableBufferAllocatorTestBase; +import okio.Segment; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -38,11 +40,12 @@ protected WritableBufferAllocator allocator() { return allocator; } + @SuppressWarnings("KotlinInternal") @Test public void testCapacity() { WritableBuffer buffer = allocator().allocate(4096); assertEquals(0, buffer.readableBytes()); - assertEquals(4096, buffer.writableBytes()); + assertEquals(SEGMENT_SIZE_COPY, buffer.writableBytes()); } @Test @@ -54,8 +57,14 @@ public void testInitialCapacityHasMaximum() { @Test public void testIsExactBelowMaxCapacity() { - WritableBuffer buffer = allocator().allocate(4097); + WritableBuffer buffer = allocator().allocate(SEGMENT_SIZE_COPY + 1); assertEquals(0, buffer.readableBytes()); - assertEquals(4097, buffer.writableBytes()); + assertEquals(SEGMENT_SIZE_COPY * 2, buffer.writableBytes()); + } + + @SuppressWarnings("KotlinInternal") + @Test + public void testSegmentSizeMatchesKotlin() { + assertEquals(Segment.SIZE, SEGMENT_SIZE_COPY); } } diff --git a/okhttp/src/test/java/io/grpc/okhttp/TlsTest.java b/okhttp/src/test/java/io/grpc/okhttp/TlsTest.java index a21360a89ba..20a2f1a5ca7 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/TlsTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/TlsTest.java @@ -18,8 +18,10 @@ import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertWithMessage; +import static org.junit.Assert.fail; import com.google.common.base.Throwables; +import io.grpc.CallOptions; import io.grpc.ChannelCredentials; import io.grpc.ConnectivityState; import io.grpc.ManagedChannel; @@ -32,18 +34,34 @@ import io.grpc.TlsServerCredentials; import io.grpc.internal.testing.TestUtils; import io.grpc.okhttp.internal.Platform; +import io.grpc.stub.ClientCalls; import io.grpc.stub.StreamObserver; import io.grpc.testing.GrpcCleanupRule; import io.grpc.testing.TlsTesting; import io.grpc.testing.protobuf.SimpleRequest; import io.grpc.testing.protobuf.SimpleResponse; import io.grpc.testing.protobuf.SimpleServiceGrpc; +import io.grpc.util.CertificateUtils; import java.io.IOException; import java.io.InputStream; +import java.net.Socket; +import java.security.GeneralSecurityException; +import java.security.KeyStore; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.Arrays; +import java.util.Optional; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSocket; import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509ExtendedTrustManager; +import javax.net.ssl.X509TrustManager; +import javax.security.auth.x500.X500Principal; +import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; import org.junit.Assume; import org.junit.Before; import org.junit.Rule; @@ -53,6 +71,7 @@ /** Verify OkHttp's TLS integration. */ @RunWith(JUnit4.class) +@IgnoreJRERequirement public class TlsTest { @Rule public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule(); @@ -92,6 +111,325 @@ public void basicTls_succeeds() throws Exception { SimpleServiceGrpc.newBlockingStub(channel).unaryRpc(SimpleRequest.getDefaultInstance()); } + @Test + public void perRpcAuthorityOverride_hostnameVerifier_goodAuthority_succeeds() throws Exception { + OkHttpClientTransport.enablePerRpcAuthorityCheck = true; + try { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(caCert) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("good.test.google.fr"), + SimpleRequest.getDefaultInstance()); + } finally { + OkHttpClientTransport.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void perRpcAuthorityOverride_hostnameVerifier_badAuthority_fails() + throws Exception { + OkHttpClientTransport.enablePerRpcAuthorityCheck = true; + try { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(caCert) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + try { + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("disallowed.name.com"), + SimpleRequest.getDefaultInstance()); + fail("Expected exception for hostname verifier failure."); + } catch (StatusRuntimeException ex) { + assertThat(ex.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(ex.getStatus().getDescription()).isEqualTo( + "HostNameVerifier verification failed for authority 'disallowed.name.com'"); + } + } finally { + OkHttpClientTransport.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void perRpcAuthorityOverride_hostnameVerifier_badAuthority_flagDisabled_succeeds() + throws Exception { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(caCert) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("disallowed.name.com"), + SimpleRequest.getDefaultInstance()); + } + + @Test + public void perRpcAuthorityOverride_noTlsCredentialsUsedToBuildChannel_fails() throws Exception { + OkHttpClientTransport.enablePerRpcAuthorityCheck = true; + try { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + SSLSocketFactory sslSocketFactory = TestUtils.newSslSocketFactoryForCa( + Platform.get().getProvider(), TestUtils.loadCert("ca.pem")); + ManagedChannel channel = grpcCleanupRule.register( + OkHttpChannelBuilder.forAddress("localhost", server.getPort()) + .overrideAuthority(TestUtils.TEST_SERVER_HOST) + .directExecutor() + .sslSocketFactory(sslSocketFactory) + .build()); + + try { + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("bar.test.google.fr"), + SimpleRequest.getDefaultInstance()); + fail("Expected exception for authority verification failure."); + } catch (StatusRuntimeException ex) { + assertThat(ex.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(ex.getStatus().getDescription()).isEqualTo( + "Could not verify authority 'bar.test.google.fr' for the rpc with no " + + "X509TrustManager available"); + } + } finally { + OkHttpClientTransport.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void perRpcAuthorityOverride_trustManager_permitted_succeeds() throws Exception { + OkHttpClientTransport.enablePerRpcAuthorityCheck = true; + try { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + X509ExtendedTrustManager regularTrustManager = + (X509ExtendedTrustManager) getX509ExtendedTrustManager(caCert).get(); + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(new HostnameCheckingX509ExtendedTrustManager(regularTrustManager)) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("good.test.google.fr"), + SimpleRequest.getDefaultInstance()); + } finally { + OkHttpClientTransport.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void perRpcAuthorityOverride_trustManager_denied_fails() throws Exception { + OkHttpClientTransport.enablePerRpcAuthorityCheck = true; + try { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + X509ExtendedTrustManager regularTrustManager = + (X509ExtendedTrustManager) getX509ExtendedTrustManager(caCert).get(); + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(new HostnameCheckingX509ExtendedTrustManager(regularTrustManager)) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + try { + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("bad.test.google.fr"), + SimpleRequest.getDefaultInstance()); + fail("Expected exception for authority verification failure."); + } catch (StatusRuntimeException ex) { + assertThat(ex.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(ex.getCause().getCause()).isInstanceOf(CertificateException.class); + } + } finally { + OkHttpClientTransport.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void perRpcAuthorityOverride_trustManager_denied_flagDisabled_succeeds() + throws Exception { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + X509ExtendedTrustManager regularTrustManager = + (X509ExtendedTrustManager) getX509ExtendedTrustManager(caCert).get(); + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(new HostnameCheckingX509ExtendedTrustManager(regularTrustManager)) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("bad.test.google.fr"), + SimpleRequest.getDefaultInstance()); + } + + /** + * This test simulates the absence of X509ExtendedTrustManager while still using the + * real trust manager for the connection handshake to happen. When the TrustManager is not an + * X509ExtendedTrustManager, the per-rpc check ignores the trust manager. However, the + * HostnameVerifier is still used, so only valid authorities are permitted. + */ + @Test + public void perRpcAuthorityOverride_notX509ExtendedTrustManager_goodAuthority_succeeds() + throws Exception { + OkHttpClientTransport.enablePerRpcAuthorityCheck = true; + try { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + X509TrustManager x509ExtendedTrustManager = + (X509TrustManager) getX509ExtendedTrustManager(caCert).get(); + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(new FakeTrustManager(x509ExtendedTrustManager)) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("foo.test.google.fr"), + SimpleRequest.getDefaultInstance()); + } finally { + OkHttpClientTransport.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void perRpcAuthorityOverride_notX509ExtendedTrustManager_badAuthority_fails() + throws Exception { + OkHttpClientTransport.enablePerRpcAuthorityCheck = true; + try { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + X509TrustManager x509ExtendedTrustManager = + (X509TrustManager) getX509ExtendedTrustManager(caCert).get(); + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(new FakeTrustManager(x509ExtendedTrustManager)) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + try { + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("disallowed.name.com"), + SimpleRequest.getDefaultInstance()); + fail("Expected exception for authority verification failure."); + } catch (StatusRuntimeException ex) { + assertThat(ex.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(ex.getStatus().getDescription()) + .isEqualTo("HostNameVerifier verification failed for authority 'disallowed.name.com'"); + } + } finally { + OkHttpClientTransport.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void + perRpcAuthorityOverride_notX509ExtendedTrustManager_badAuthority_flagDisabled_succeeds() + throws Exception { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + X509TrustManager x509ExtendedTrustManager = + (X509TrustManager) getX509ExtendedTrustManager(caCert).get(); + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(new FakeTrustManager(x509ExtendedTrustManager)) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("disallowed.name.com"), + SimpleRequest.getDefaultInstance()); + } + @Test public void mtls_succeeds() throws Exception { ServerCredentials serverCreds; @@ -282,6 +620,127 @@ public void hostnameVerifierFails_fails() assertThat(status.getCause()).isInstanceOf(SSLPeerUnverifiedException.class); } + /** Used to simulate the case of X509ExtendedTrustManager not present. */ + private static class FakeTrustManager implements X509TrustManager { + + private final X509TrustManager delegate; + + public FakeTrustManager(X509TrustManager x509ExtendedTrustManager) { + this.delegate = x509ExtendedTrustManager; + } + + @Override + public void checkClientTrusted(X509Certificate[] x509Certificates, String s) + throws CertificateException { + delegate.checkClientTrusted(x509Certificates, s); + } + + @Override + public void checkServerTrusted(X509Certificate[] x509Certificates, String s) + throws CertificateException { + delegate.checkServerTrusted(x509Certificates, s); + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return delegate.getAcceptedIssuers(); + } + } + + /** + * Checks against a limited set of hostnames. In production, EndpointIdentificationAlgorithm is + * unset so the default trust manager will not fail based on the hostname. This class is used to + * test user-provided trust managers that may have their own behavior. + */ + private static class HostnameCheckingX509ExtendedTrustManager + extends ForwardingX509ExtendedTrustManager { + public HostnameCheckingX509ExtendedTrustManager(X509ExtendedTrustManager tm) { + super(tm); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket) + throws CertificateException { + String peer = ((SSLSocket) socket).getHandshakeSession().getPeerHost(); + if (!"foo.test.google.fr".equals(peer) && !"good.test.google.fr".equals(peer)) { + throw new CertificateException("Peer verification failed."); + } + super.checkServerTrusted(chain, authType, socket); + } + } + + @IgnoreJRERequirement + private static class ForwardingX509ExtendedTrustManager extends X509ExtendedTrustManager { + private final X509ExtendedTrustManager delegate; + + private ForwardingX509ExtendedTrustManager(X509ExtendedTrustManager delegate) { + this.delegate = delegate; + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket) + throws CertificateException { + delegate.checkServerTrusted(chain, authType, socket); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngine engine) + throws CertificateException { + delegate.checkServerTrusted(chain, authType, engine); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + delegate.checkServerTrusted(chain, authType); + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType, SSLEngine engine) + throws CertificateException { + delegate.checkClientTrusted(chain, authType, engine); + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + delegate.checkClientTrusted(chain, authType); + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType, Socket socket) + throws CertificateException { + delegate.checkClientTrusted(chain, authType, socket); + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return delegate.getAcceptedIssuers(); + } + } + + private static Optional getX509ExtendedTrustManager(InputStream rootCerts) + throws GeneralSecurityException { + KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType()); + try { + ks.load(null, null); + } catch (IOException ex) { + // Shouldn't really happen, as we're not loading any data. + throw new GeneralSecurityException(ex); + } + X509Certificate[] certs = CertificateUtils.getX509Certificates(rootCerts); + for (X509Certificate cert : certs) { + X500Principal principal = cert.getSubjectX500Principal(); + ks.setCertificateEntry(principal.getName("RFC2253"), cert); + } + + TrustManagerFactory trustManagerFactory = + TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init(ks); + return Arrays.stream(trustManagerFactory.getTrustManagers()) + .filter(trustManager -> trustManager instanceof X509ExtendedTrustManager).findFirst(); + } + private static Server server(ServerCredentials creds) throws IOException { return OkHttpServerBuilder.forPort(0, creds) .directExecutor() diff --git a/okhttp/src/test/java/io/grpc/okhttp/UtilsTest.java b/okhttp/src/test/java/io/grpc/okhttp/UtilsTest.java index 895ba7ff7c7..1c97e027b4a 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/UtilsTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/UtilsTest.java @@ -16,7 +16,9 @@ package io.grpc.okhttp; +import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import io.grpc.InternalChannelz.SocketOptions; @@ -25,9 +27,8 @@ import io.grpc.okhttp.internal.TlsVersion; import java.net.Socket; import java.util.List; -import org.junit.Rule; +import java.util.Locale; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -37,16 +38,12 @@ @RunWith(JUnit4.class) public class UtilsTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); - @Test public void convertSpecRejectsPlaintext() { com.squareup.okhttp.ConnectionSpec plaintext = com.squareup.okhttp.ConnectionSpec.CLEARTEXT; - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("plaintext ConnectionSpec is not accepted"); - Utils.convertSpec(plaintext); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> Utils.convertSpec(plaintext)); + assertThat(e).hasMessageThat().isEqualTo("plaintext ConnectionSpec is not accepted"); } @Test @@ -95,6 +92,9 @@ public void getSocketOptions() throws Exception { assertEquals("5000", socketOptions.others.get("SO_SNDBUF")); assertEquals("true", socketOptions.others.get("SO_KEEPALIVE")); assertEquals("true", socketOptions.others.get("SO_OOBINLINE")); - assertEquals("8", socketOptions.others.get("IP_TOS")); + String osName = System.getProperty("os.name").toLowerCase(Locale.ENGLISH); + if (!osName.startsWith("windows")) { + assertEquals("8", socketOptions.others.get("IP_TOS")); + } } } diff --git a/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/OkHostnameVerifier.java b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/OkHostnameVerifier.java index 34bb56ee2d6..f6efb2d90e7 100644 --- a/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/OkHostnameVerifier.java +++ b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/OkHostnameVerifier.java @@ -29,10 +29,13 @@ import java.util.List; import java.util.Locale; import java.util.regex.Pattern; +import java.nio.charset.StandardCharsets; import javax.net.ssl.HostnameVerifier; import javax.net.ssl.SSLException; import javax.net.ssl.SSLSession; import javax.security.auth.x500.X500Principal; +import com.google.common.base.Utf8; +import com.google.common.base.Ascii; /** * A HostnameVerifier consistent with altNames = getSubjectAltNames(certificate, ALT_DNS_NAME); for (int i = 0, size = altNames.size(); i < size; i++) { @@ -198,7 +204,7 @@ private boolean verifyHostName(String hostName, String pattern) { } // hostName and pattern are now absolute domain names. - pattern = pattern.toLowerCase(Locale.US); + pattern = Ascii.toLowerCase(pattern); // hostName and pattern are now in lower case -- domain names are case-insensitive. if (!pattern.contains("*")) { @@ -254,4 +260,13 @@ private boolean verifyHostName(String hostName, String pattern) { // hostName matches pattern return true; } + + /** + * Returns true if {@code input} is an ASCII string. + * @param input the string to check. + */ + private static boolean isAscii(String input) { + // Only ASCII characters are 1 byte in UTF-8. + return Utf8.encodedLength(input) == input.length(); + } } diff --git a/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/Platform.java b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/Platform.java index 6ed3bc50b81..29ea8055b26 100644 --- a/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/Platform.java +++ b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/Platform.java @@ -283,7 +283,7 @@ private static boolean isAtLeastAndroid41() { /** * Select the first recognized security provider according to the preference order returned by - * {@link Security#getProviders}. If a recognized provider is not found then warn but continue. + * {@link Security#getProviders}. */ private static Provider getAndroidSecurityProvider() { Provider[] providers = Security.getProviders(); @@ -295,7 +295,6 @@ private static Provider getAndroidSecurityProvider() { } } } - logger.log(Level.WARNING, "Unable to find Conscrypt"); return null; } diff --git a/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/framed/Hpack.java b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/framed/Hpack.java index 484cc5673dc..3155d6d533a 100644 --- a/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/framed/Hpack.java +++ b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/framed/Hpack.java @@ -354,6 +354,13 @@ int readInt(int firstByte, int prefixMask) throws IOException { if ((b & 0x80) != 0) { // Equivalent to (b >= 128) since b is in [0..255]. result += (b & 0x7f) << shift; shift += 7; + // We can safely store 31 bits, and then next byte will have 7 more bits. While the next + // byte may not have high bits set to cause an overflow, that's only useful for 256+ MiB + // values, which is excessive. This also gives us at least one bit of spare, which is + // necessary to store the carry from the addition. + if (shift >= 28) { + throw new IOException("Varint overflowed"); + } } else { result += b << shift; // Last byte. break; @@ -508,6 +515,9 @@ void writeInt(int value, int prefixMask, int bits) throws IOException { // Write the mask to start a multibyte value. out.writeByte(bits | prefixMask); value -= prefixMask; + if (value > 0xfffffff) { + throw new IOException("Varint would overflow reader"); + } // Write 7 bits at a time 'til we're done. while (value >= 0x80) { diff --git a/okhttp/third_party/okhttp/test/java/io/grpc/okhttp/internal/framed/HpackTest.java b/okhttp/third_party/okhttp/test/java/io/grpc/okhttp/internal/framed/HpackTest.java index 26580f85e54..dc5e030810f 100644 --- a/okhttp/third_party/okhttp/test/java/io/grpc/okhttp/internal/framed/HpackTest.java +++ b/okhttp/third_party/okhttp/test/java/io/grpc/okhttp/internal/framed/HpackTest.java @@ -455,7 +455,7 @@ public void theSameHeaderAfterOneIncrementalIndexed() throws IOException { hpackReader.readHeaders(); fail(); } catch (IOException e) { - assertEquals("Header index too large -2147483521", e.getMessage()); + assertEquals("Varint overflowed", e.getMessage()); } } @@ -497,7 +497,7 @@ public void theSameHeaderAfterOneIncrementalIndexed() throws IOException { hpackReader.readHeaders(); fail(); } catch (IOException e) { - assertEquals("Invalid dynamic table size update -2147483648", e.getMessage()); + assertEquals("Varint overflowed", e.getMessage()); } } @@ -856,11 +856,53 @@ private void checkReadThirdRequestWithHuffman() { assertBytes(0xe0 | 31, 154, 10); } - @Test public void max31BitValue() throws IOException { - hpackWriter.writeInt(0x7fffffff, 31, 0); - assertBytes(31, 224, 255, 255, 255, 7); - assertEquals(0x7fffffff, - newReader(byteStream(224, 255, 255, 255, 7)).readInt(31, 31)); + @Test public void max29BitValue() throws IOException { + hpackWriter.writeInt(0x100000fe, 0xff, 0xff); + assertBytes(0xff, 0xff, 0xff, 0xff, 0x7f); + assertEquals(0x100000fe, + newReader(byteStream(0xff, 0xff, 0xff, 0x7f)).readInt(0xff, 0xff)); + } + + @Test public void beyondMax29BitValue() throws IOException { + try { + hpackWriter.writeInt(0x100000ff, 0xff, 0xff); + fail(); + } catch (IOException e) { + assertEquals("Varint would overflow reader", e.getMessage()); + } + try { + newReader(byteStream(0xff, 0xff, 0xff, 0xff, 0x80)).readInt(0xff, 0xff); + fail(); + } catch (IOException e) { + assertEquals("Varint overflowed", e.getMessage()); + } + } + + @Test public void beyondMax29BitValue_smallPrefix() throws IOException { + try { + hpackWriter.writeInt(0x10000001, 1, 1); + fail(); + } catch (IOException e) { + assertEquals("Varint would overflow reader", e.getMessage()); + } + try { + newReader(byteStream(0xff, 0xff, 0xff, 0xff, 0x80)).readInt(1, 1); + fail(); + } catch (IOException e) { + assertEquals("Varint overflowed", e.getMessage()); + } + } + + @Test public void readerAbortsLongVarintsWithZeros() throws IOException { + try { + // The reader should fail before getting to the end, because it will overflow as soon as there + // is a 1 bit, and the only reason to use this many continuations is to eventually have a 1 + // bit. + newReader(byteStream(0x80, 0x80, 0x80, 0x80, 0x80)).readInt(31, 31); + fail(); + } catch (IOException e) { + assertEquals("Varint overflowed", e.getMessage()); + } } @Test public void prefixMask() throws IOException { diff --git a/opentelemetry/build.gradle b/opentelemetry/build.gradle index 1fb8180a2ce..594686294f0 100644 --- a/opentelemetry/build.gradle +++ b/opentelemetry/build.gradle @@ -1,5 +1,8 @@ plugins { id "java-library" + id "maven-publish" + + id "ru.vyarus.animalsniffer" } description = 'gRPC: OpenTelemetry' @@ -9,14 +12,29 @@ dependencies { implementation libraries.guava, project(':grpc-core'), libraries.opentelemetry.api, - libraries.auto.value.annotations + libraries.auto.value.annotations, + libraries.animalsniffer.annotations - testImplementation testFixtures(project(':grpc-core')), - project(':grpc-testing'), + testImplementation project(':grpc-testing'), + project(':grpc-testing-proto'), + project(':grpc-inprocess'), + testFixtures(project(':grpc-core')), + testFixtures(project(':grpc-api')), libraries.opentelemetry.sdk.testing, - "org.assertj:assertj-core:3.24.2" + libraries.assertj.core // opentelemetry.sdk.testing uses compileOnly for assertj annotationProcessor libraries.auto.value + + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } + signature (libraries.signature.android) { + artifact { + extension = "signature" + } + } } tasks.named("jar").configure { @@ -35,3 +53,7 @@ tasks.named("compileJava").configure { ".*/build/generated/sources/annotationProcessor/java/.*", "|") } + +tasks.named("javadoc").configure { + exclude 'io/grpc/opentelemetry/internal/**' +} diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/BinaryFormat.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/BinaryFormat.java new file mode 100644 index 00000000000..cdf27875903 --- /dev/null +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/BinaryFormat.java @@ -0,0 +1,143 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.opentelemetry; + + +import static com.google.common.base.Preconditions.checkNotNull; + +import io.grpc.Metadata; +import io.opentelemetry.api.trace.SpanContext; +import io.opentelemetry.api.trace.SpanId; +import io.opentelemetry.api.trace.TraceFlags; +import io.opentelemetry.api.trace.TraceId; +import io.opentelemetry.api.trace.TraceState; +import java.util.Arrays; + +/** + * Binary encoded {@link SpanContext} for context propagation. This is adapted from OpenCensus + * binary format. + * + *

BinaryFormat format: + * + *

    + *
  • Binary value: <version_id><version_format> + *
  • version_id: 1-byte representing the version id. + *
  • For version_id = 0: + *
      + *
    • version_format: <field><field> + *
    • field_format: <field_id><field_format> + *
    • Fields: + *
        + *
      • TraceId: (field_id = 0, len = 16, default = "0000000000000000") - + * 16-byte array representing the trace_id. + *
      • SpanId: (field_id = 1, len = 8, default = "00000000") - 8-byte array + * representing the span_id. + *
      • TraceFlags: (field_id = 2, len = 1, default = "0") - 1-byte array + * representing the trace_flags. + *
      + *
    • Fields MUST be encoded using the field id order (smaller to higher). + *
    • Valid value example: + *
        + *
      • {0, 0, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 1, 97, + * 98, 99, 100, 101, 102, 103, 104, 2, 1} + *
      • version_id = 0; + *
      • trace_id = {64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79} + *
      • span_id = {97, 98, 99, 100, 101, 102, 103, 104}; + *
      • trace_flags = {1}; + *
      + *
    + *
+ */ +final class BinaryFormat implements Metadata.BinaryMarshaller { + private static final byte VERSION_ID = 0; + private static final int VERSION_ID_OFFSET = 0; + private static final byte ID_SIZE = 1; + private static final byte TRACE_ID_FIELD_ID = 0; + + private static final int TRACE_ID_FIELD_ID_OFFSET = VERSION_ID_OFFSET + ID_SIZE; + private static final int TRACE_ID_OFFSET = TRACE_ID_FIELD_ID_OFFSET + ID_SIZE; + private static final int TRACE_ID_SIZE = TraceId.getLength() / 2; + + private static final byte SPAN_ID_FIELD_ID = 1; + private static final int SPAN_ID_FIELD_ID_OFFSET = TRACE_ID_OFFSET + TRACE_ID_SIZE; + private static final int SPAN_ID_OFFSET = SPAN_ID_FIELD_ID_OFFSET + ID_SIZE; + private static final int SPAN_ID_SIZE = SpanId.getLength() / 2; + + private static final byte TRACE_FLAG_FIELD_ID = 2; + private static final int TRACE_FLAG_FIELD_ID_OFFSET = SPAN_ID_OFFSET + SPAN_ID_SIZE; + private static final int TRACE_FLAG_OFFSET = TRACE_FLAG_FIELD_ID_OFFSET + ID_SIZE; + private static final int REQUIRED_FORMAT_LENGTH = 3 * ID_SIZE + TRACE_ID_SIZE + SPAN_ID_SIZE; + private static final int TRACE_FLAG_SIZE = TraceFlags.getLength() / 2; + private static final int ALL_FORMAT_LENGTH = REQUIRED_FORMAT_LENGTH + ID_SIZE + TRACE_FLAG_SIZE; + + private static final BinaryFormat INSTANCE = new BinaryFormat(); + + public static BinaryFormat getInstance() { + return INSTANCE; + } + + @Override + public byte[] toBytes(SpanContext spanContext) { + checkNotNull(spanContext, "spanContext"); + byte[] bytes = new byte[ALL_FORMAT_LENGTH]; + bytes[VERSION_ID_OFFSET] = VERSION_ID; + bytes[TRACE_ID_FIELD_ID_OFFSET] = TRACE_ID_FIELD_ID; + System.arraycopy(spanContext.getTraceIdBytes(), 0, bytes, TRACE_ID_OFFSET, TRACE_ID_SIZE); + bytes[SPAN_ID_FIELD_ID_OFFSET] = SPAN_ID_FIELD_ID; + System.arraycopy(spanContext.getSpanIdBytes(), 0, bytes, SPAN_ID_OFFSET, SPAN_ID_SIZE); + bytes[TRACE_FLAG_FIELD_ID_OFFSET] = TRACE_FLAG_FIELD_ID; + bytes[TRACE_FLAG_OFFSET] = spanContext.getTraceFlags().asByte(); + return bytes; + } + + + @Override + public SpanContext parseBytes(byte[] serialized) { + checkNotNull(serialized, "bytes"); + if (serialized.length == 0 || serialized[0] != VERSION_ID) { + throw new IllegalArgumentException("Unsupported version."); + } + if (serialized.length < REQUIRED_FORMAT_LENGTH) { + throw new IllegalArgumentException("Invalid input: truncated"); + } + String traceId; + String spanId; + TraceFlags traceFlags = TraceFlags.getDefault(); + int pos = 1; + if (serialized[pos] == TRACE_ID_FIELD_ID) { + traceId = TraceId.fromBytes( + Arrays.copyOfRange(serialized, pos + ID_SIZE, pos + ID_SIZE + TRACE_ID_SIZE)); + pos += ID_SIZE + TRACE_ID_SIZE; + } else { + throw new IllegalArgumentException("Invalid input: expected trace ID at offset " + pos); + } + if (serialized[pos] == SPAN_ID_FIELD_ID) { + spanId = SpanId.fromBytes( + Arrays.copyOfRange(serialized, pos + ID_SIZE, pos + ID_SIZE + SPAN_ID_SIZE)); + pos += ID_SIZE + SPAN_ID_SIZE; + } else { + throw new IllegalArgumentException("Invalid input: expected span ID at offset " + pos); + } + if (serialized.length > pos && serialized[pos] == TRACE_FLAG_FIELD_ID) { + if (serialized.length < ALL_FORMAT_LENGTH) { + throw new IllegalArgumentException("Invalid input: truncated"); + } + traceFlags = TraceFlags.fromByte(serialized[pos + ID_SIZE]); + } + return SpanContext.create(traceId, spanId, traceFlags, TraceState.getDefault()); + } +} diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/GrpcOpenTelemetry.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/GrpcOpenTelemetry.java new file mode 100644 index 00000000000..87ad61c9f27 --- /dev/null +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/GrpcOpenTelemetry.java @@ -0,0 +1,471 @@ +/* + * Copyright 2023 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.opentelemetry; + +import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.internal.GrpcUtil.IMPLEMENTATION_VERSION; +import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.HEDGE_BUCKETS; +import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.LATENCY_BUCKETS; +import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.RETRY_BUCKETS; +import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.SIZE_BUCKETS; +import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.TRANSPARENT_RETRY_BUCKETS; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Stopwatch; +import com.google.common.base.Supplier; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.grpc.ExperimentalApi; +import io.grpc.InternalConfigurator; +import io.grpc.InternalConfiguratorRegistry; +import io.grpc.InternalManagedChannelBuilder; +import io.grpc.ManagedChannelBuilder; +import io.grpc.MetricSink; +import io.grpc.ServerBuilder; +import io.grpc.internal.GrpcUtil; +import io.grpc.opentelemetry.internal.OpenTelemetryConstants; +import io.opentelemetry.api.OpenTelemetry; +import io.opentelemetry.api.metrics.Meter; +import io.opentelemetry.api.metrics.MeterProvider; +import io.opentelemetry.api.trace.Tracer; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Predicate; +import javax.annotation.Nullable; +import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; + +/** + * The entrypoint for OpenTelemetry metrics functionality in gRPC. + * + *

GrpcOpenTelemetry uses {@link io.opentelemetry.api.OpenTelemetry} APIs for instrumentation. + * When no SDK is explicitly added no telemetry data will be collected. See + * {@code io.opentelemetry.sdk.OpenTelemetrySdk} for information on how to construct the SDK. + * + */ +public final class GrpcOpenTelemetry { + + private static final Supplier STOPWATCH_SUPPLIER = new Supplier() { + @Override + public Stopwatch get() { + return Stopwatch.createUnstarted(); + } + }; + + @VisibleForTesting + static boolean ENABLE_OTEL_TRACING = + GrpcUtil.getFlag("GRPC_EXPERIMENTAL_ENABLE_OTEL_TRACING", false); + + private final OpenTelemetry openTelemetrySdk; + private final MeterProvider meterProvider; + private final Meter meter; + private final Map enableMetrics; + private final boolean disableDefault; + private final OpenTelemetryMetricsResource resource; + private final OpenTelemetryMetricsModule openTelemetryMetricsModule; + private final OpenTelemetryTracingModule openTelemetryTracingModule; + private final List optionalLabels; + private final MetricSink sink; + + public static Builder newBuilder() { + return new Builder(); + } + + private GrpcOpenTelemetry(Builder builder) { + this.openTelemetrySdk = checkNotNull(builder.openTelemetrySdk, "openTelemetrySdk"); + this.meterProvider = checkNotNull(openTelemetrySdk.getMeterProvider(), "meterProvider"); + this.meter = this.meterProvider + .meterBuilder(OpenTelemetryConstants.INSTRUMENTATION_SCOPE) + .setInstrumentationVersion(IMPLEMENTATION_VERSION) + .build(); + this.enableMetrics = ImmutableMap.copyOf(builder.enableMetrics); + this.disableDefault = builder.disableAll; + this.resource = createMetricInstruments(meter, enableMetrics, disableDefault); + this.optionalLabels = ImmutableList.copyOf(builder.optionalLabels); + this.openTelemetryMetricsModule = new OpenTelemetryMetricsModule( + STOPWATCH_SUPPLIER, resource, optionalLabels, builder.plugins, + builder.targetFilter); + this.openTelemetryTracingModule = new OpenTelemetryTracingModule(openTelemetrySdk); + this.sink = new OpenTelemetryMetricSink(meter, enableMetrics, disableDefault, optionalLabels); + } + + @VisibleForTesting + OpenTelemetry getOpenTelemetryInstance() { + return this.openTelemetrySdk; + } + + @VisibleForTesting + MeterProvider getMeterProvider() { + return this.meterProvider; + } + + @VisibleForTesting + Meter getMeter() { + return this.meter; + } + + @VisibleForTesting + OpenTelemetryMetricsResource getResource() { + return this.resource; + } + + @VisibleForTesting + Map getEnableMetrics() { + return this.enableMetrics; + } + + @VisibleForTesting + List getOptionalLabels() { + return optionalLabels; + } + + MetricSink getSink() { + return sink; + } + + @VisibleForTesting + Tracer getTracer() { + return this.openTelemetryTracingModule.getTracer(); + } + + @VisibleForTesting + TargetFilter getTargetAttributeFilter() { + return this.openTelemetryMetricsModule.getTargetAttributeFilter(); + } + + /** + * Registers GrpcOpenTelemetry globally, applying its configuration to all subsequently created + * gRPC channels and servers. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/10591") + public void registerGlobal() { + InternalConfiguratorRegistry.setConfigurators(Collections.singletonList( + new InternalConfigurator() { + @Override + public void configureChannelBuilder(ManagedChannelBuilder channelBuilder) { + GrpcOpenTelemetry.this.configureChannelBuilder(channelBuilder); + } + + @Override + public void configureServerBuilder(ServerBuilder serverBuilder) { + GrpcOpenTelemetry.this.configureServerBuilder(serverBuilder); + } + })); + } + + /** + * Configures the given {@link ManagedChannelBuilder} with OpenTelemetry metrics instrumentation. + */ + public void configureChannelBuilder(ManagedChannelBuilder builder) { + InternalManagedChannelBuilder.addMetricSink(builder, sink); + InternalManagedChannelBuilder.interceptWithTarget( + builder, openTelemetryMetricsModule::getClientInterceptor); + if (ENABLE_OTEL_TRACING) { + builder.intercept(openTelemetryTracingModule.getClientInterceptor()); + } + } + + /** + * Configures the given {@link ServerBuilder} with OpenTelemetry metrics instrumentation. + * + * @param serverBuilder the server builder to configure + */ + public void configureServerBuilder(ServerBuilder serverBuilder) { + /* To ensure baggage propagation to metrics, we need the tracing + tracers to be initialised before metrics */ + if (ENABLE_OTEL_TRACING) { + serverBuilder.addStreamTracerFactory( + openTelemetryTracingModule.getServerTracerFactory()); + serverBuilder.intercept(openTelemetryTracingModule.getServerSpanPropagationInterceptor()); + } + serverBuilder.addStreamTracerFactory(openTelemetryMetricsModule.getServerTracerFactory()); + serverBuilder.addMetricSink(sink); + } + + @VisibleForTesting + static OpenTelemetryMetricsResource createMetricInstruments(Meter meter, + Map enableMetrics, boolean disableDefault) { + OpenTelemetryMetricsResource.Builder builder = OpenTelemetryMetricsResource.builder(); + + if (isMetricEnabled("grpc.client.call.duration", enableMetrics, disableDefault)) { + builder.clientCallDurationCounter( + meter.histogramBuilder("grpc.client.call.duration") + .setUnit("s") + .setDescription( + "Time taken by gRPC to complete an RPC from application's perspective") + .setExplicitBucketBoundariesAdvice(LATENCY_BUCKETS) + .build()); + } + + if (isMetricEnabled("grpc.client.attempt.started", enableMetrics, disableDefault)) { + builder.clientAttemptCountCounter( + meter.counterBuilder("grpc.client.attempt.started") + .setUnit("{attempt}") + .setDescription("Number of client call attempts started") + .build()); + } + + if (isMetricEnabled("grpc.client.attempt.duration", enableMetrics, disableDefault)) { + builder.clientAttemptDurationCounter( + meter.histogramBuilder( + "grpc.client.attempt.duration") + .setUnit("s") + .setDescription("Time taken to complete a client call attempt") + .setExplicitBucketBoundariesAdvice(LATENCY_BUCKETS) + .build()); + } + + if (isMetricEnabled("grpc.client.attempt.sent_total_compressed_message_size", enableMetrics, + disableDefault)) { + builder.clientTotalSentCompressedMessageSizeCounter( + meter.histogramBuilder( + "grpc.client.attempt.sent_total_compressed_message_size") + .setUnit("By") + .setDescription("Compressed message bytes sent per client call attempt") + .ofLongs() + .setExplicitBucketBoundariesAdvice(SIZE_BUCKETS) + .build()); + } + + if (isMetricEnabled("grpc.client.attempt.rcvd_total_compressed_message_size", enableMetrics, + disableDefault)) { + builder.clientTotalReceivedCompressedMessageSizeCounter( + meter.histogramBuilder( + "grpc.client.attempt.rcvd_total_compressed_message_size") + .setUnit("By") + .setDescription("Compressed message bytes received per call attempt") + .ofLongs() + .setExplicitBucketBoundariesAdvice(SIZE_BUCKETS) + .build()); + } + + if (isMetricEnabled("grpc.client.call.retries", enableMetrics, disableDefault)) { + builder.clientCallRetriesCounter( + meter.histogramBuilder( + "grpc.client.call.retries") + .setUnit("{retry}") + .setDescription("Number of retries during the client call. " + + "If there were no retries, 0 is not reported.") + .ofLongs() + .setExplicitBucketBoundariesAdvice(RETRY_BUCKETS) + .build()); + } + + if (isMetricEnabled("grpc.client.call.transparent_retries", enableMetrics, + disableDefault)) { + builder.clientCallTransparentRetriesCounter( + meter.histogramBuilder( + "grpc.client.call.transparent_retries") + .setUnit("{transparent_retry}") + .setDescription("Number of transparent retries during the client call. " + + "If there were no transparent retries, 0 is not reported.") + .ofLongs() + .setExplicitBucketBoundariesAdvice(TRANSPARENT_RETRY_BUCKETS) + .build()); + } + + if (isMetricEnabled("grpc.client.call.hedges", enableMetrics, disableDefault)) { + builder.clientCallHedgesCounter( + meter.histogramBuilder( + "grpc.client.call.hedges") + .setUnit("{hedge}") + .setDescription("Number of hedges during the client call. " + + "If there were no hedges, 0 is not reported.") + .ofLongs() + .setExplicitBucketBoundariesAdvice(HEDGE_BUCKETS) + .build()); + } + + if (isMetricEnabled("grpc.client.call.retry_delay", enableMetrics, disableDefault)) { + builder.clientCallRetryDelayCounter( + meter.histogramBuilder( + "grpc.client.call.retry_delay") + .setUnit("s") + .setDescription("Total time of delay while there is no active attempt during the " + + "client call") + .setExplicitBucketBoundariesAdvice(LATENCY_BUCKETS) + .build()); + } + + if (isMetricEnabled("grpc.server.call.started", enableMetrics, disableDefault)) { + builder.serverCallCountCounter( + meter.counterBuilder("grpc.server.call.started") + .setUnit("{call}") + .setDescription("Number of server calls started") + .build()); + } + + if (isMetricEnabled("grpc.server.call.duration", enableMetrics, disableDefault)) { + builder.serverCallDurationCounter( + meter.histogramBuilder("grpc.server.call.duration") + .setUnit("s") + .setDescription( + "Time taken to complete a call from server transport's perspective") + .setExplicitBucketBoundariesAdvice(LATENCY_BUCKETS) + .build()); + } + + if (isMetricEnabled("grpc.server.call.sent_total_compressed_message_size", + enableMetrics, disableDefault)) { + builder.serverTotalSentCompressedMessageSizeCounter( + meter.histogramBuilder( + "grpc.server.call.sent_total_compressed_message_size") + .setUnit("By") + .setDescription("Compressed message bytes sent per server call") + .ofLongs() + .setExplicitBucketBoundariesAdvice(SIZE_BUCKETS) + .build()); + } + + if (isMetricEnabled("grpc.server.call.rcvd_total_compressed_message_size", + enableMetrics, disableDefault)) { + builder.serverTotalReceivedCompressedMessageSizeCounter( + meter.histogramBuilder( + "grpc.server.call.rcvd_total_compressed_message_size") + .setUnit("By") + .setDescription("Compressed message bytes received per server call") + .ofLongs() + .setExplicitBucketBoundariesAdvice(SIZE_BUCKETS) + .build()); + } + + return builder.build(); + } + + static boolean isMetricEnabled(String metricName, Map enableMetrics, + boolean disableDefault) { + Boolean explicitlyEnabled = enableMetrics.get(metricName); + if (explicitlyEnabled != null) { + return explicitlyEnabled; + } + return OpenTelemetryMetricsModule.DEFAULT_PER_CALL_METRICS_SET.contains(metricName) + && !disableDefault; + } + + /** + * Internal interface to avoid storing a {@link java.util.function.Predicate} directly, ensuring + * compatibility with Android devices (API level < 24) that do not use library desugaring. + */ + interface TargetFilter { + boolean test(String target); + } + + /** + * Builder for configuring {@link GrpcOpenTelemetry}. + */ + public static class Builder { + private OpenTelemetry openTelemetrySdk = OpenTelemetry.noop(); + private final List plugins = new ArrayList<>(); + private final Collection optionalLabels = new ArrayList<>(); + private final Map enableMetrics = new HashMap<>(); + private boolean disableAll; + @Nullable + private TargetFilter targetFilter; + + private Builder() {} + + /** + * Sets the {@link io.opentelemetry.api.OpenTelemetry} entrypoint to use. This can be used to + * configure OpenTelemetry by returning the instance created by a + * {@code io.opentelemetry.sdk.OpenTelemetrySdkBuilder}. + */ + public Builder sdk(OpenTelemetry sdk) { + this.openTelemetrySdk = sdk; + return this; + } + + Builder plugin(OpenTelemetryPlugin plugin) { + plugins.add(checkNotNull(plugin, "plugin")); + return this; + } + + /** + * Adds optionalLabelKey to all the metrics that can provide value for the + * optionalLabelKey. + */ + public Builder addOptionalLabel(String optionalLabelKey) { + this.optionalLabels.add(optionalLabelKey); + return this; + } + + /** + * Enables the specified metrics for collection and export. By default, only a subset of + * metrics are enabled. + */ + public Builder enableMetrics(Collection enableMetrics) { + for (String metric : enableMetrics) { + this.enableMetrics.put(metric, true); + } + return this; + } + + /** + * Disables the specified metrics from being collected and exported. + */ + public Builder disableMetrics(Collection disableMetrics) { + for (String metric : disableMetrics) { + this.enableMetrics.put(metric, false); + } + return this; + } + + /** + * Disable all metrics. If set to true all metrics must be explicitly enabled. + */ + public Builder disableAllMetrics() { + this.enableMetrics.clear(); + this.disableAll = true; + return this; + } + + Builder enableTracing(boolean enable) { + ENABLE_OTEL_TRACING = enable; + return this; + } + + /** + * Sets an optional filter to control recording of the {@code grpc.target} metric + * attribute. + * + *

If the predicate returns {@code true}, the original target is recorded. Otherwise, + * the target is recorded as {@code "other"} to limit metric cardinality. + * + *

If unset, all targets are recorded as-is. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/12595") + @IgnoreJRERequirement + public Builder targetAttributeFilter(@Nullable Predicate filter) { + if (filter == null) { + this.targetFilter = null; + } else { + this.targetFilter = filter::test; + } + return this; + } + + /** + * Returns a new {@link GrpcOpenTelemetry} built with the configuration of this {@link + * Builder}. + */ + public GrpcOpenTelemetry build() { + return new GrpcOpenTelemetry(this); + } + } +} diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/GrpcTraceBinContextPropagator.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/GrpcTraceBinContextPropagator.java new file mode 100644 index 00000000000..4825b203529 --- /dev/null +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/GrpcTraceBinContextPropagator.java @@ -0,0 +1,147 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.opentelemetry; + + +import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.InternalMetadata.BASE64_ENCODING_OMIT_PADDING; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.io.BaseEncoding; +import io.grpc.ExperimentalApi; +import io.grpc.Metadata; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.SpanContext; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.propagation.TextMapGetter; +import io.opentelemetry.context.propagation.TextMapPropagator; +import io.opentelemetry.context.propagation.TextMapSetter; +import java.util.Collection; +import java.util.Collections; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.Nullable; + +/** + * A {@link TextMapPropagator} for transmitting "grpc-trace-bin" span context. + * + *

This propagator can transmit the "grpc-trace-bin" context in either binary or Base64-encoded + * text format, depending on the capabilities of the provided {@link TextMapGetter} and + * {@link TextMapSetter}. + * + *

If the {@code TextMapGetter} and {@code TextMapSetter} only support text format, Base64 + * encoding and decoding will be used when communicating with the carrier API. But gRPC uses + * it with gRPC's metadata-based getter/setter, and the propagator can directly transmit the binary + * header, avoiding the need for Base64 encoding. + */ + +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/11400") +public final class GrpcTraceBinContextPropagator implements TextMapPropagator { + private static final Logger log = Logger.getLogger(GrpcTraceBinContextPropagator.class.getName()); + public static final String GRPC_TRACE_BIN_HEADER = "grpc-trace-bin"; + private final Metadata.BinaryMarshaller binaryFormat; + private static final GrpcTraceBinContextPropagator INSTANCE = + new GrpcTraceBinContextPropagator(BinaryFormat.getInstance()); + + public static GrpcTraceBinContextPropagator defaultInstance() { + return INSTANCE; + } + + @VisibleForTesting + GrpcTraceBinContextPropagator(Metadata.BinaryMarshaller binaryFormat) { + this.binaryFormat = checkNotNull(binaryFormat, "binaryFormat"); + } + + @Override + public Collection fields() { + return Collections.singleton(GRPC_TRACE_BIN_HEADER); + } + + @Override + public void inject(Context context, @Nullable C carrier, TextMapSetter setter) { + if (context == null || setter == null) { + return; + } + SpanContext spanContext = Span.fromContext(context).getSpanContext(); + if (!spanContext.isValid()) { + return; + } + try { + byte[] b = binaryFormat.toBytes(spanContext); + if (setter instanceof MetadataSetter) { + ((MetadataSetter) setter).set((Metadata) carrier, GRPC_TRACE_BIN_HEADER, b); + } else { + setter.set(carrier, GRPC_TRACE_BIN_HEADER, BASE64_ENCODING_OMIT_PADDING.encode(b)); + } + } catch (Exception e) { + log.log(Level.FINE, "Set grpc-trace-bin spanContext failed", e); + } + } + + @Override + public Context extract(Context context, @Nullable C carrier, TextMapGetter getter) { + if (context == null) { + return Context.root(); + } + if (getter == null) { + return context; + } + byte[] b; + if (getter instanceof MetadataGetter) { + try { + b = ((MetadataGetter) getter).getBinary((Metadata) carrier, GRPC_TRACE_BIN_HEADER); + if (b == null) { + log.log(Level.FINE, "No grpc-trace-bin present in carrier"); + return context; + } + } catch (Exception e) { + log.log(Level.FINE, "Get 'grpc-trace-bin' from MetadataGetter failed", e); + return context; + } + } else { + String value; + try { + value = getter.get(carrier, GRPC_TRACE_BIN_HEADER); + if (value == null) { + log.log(Level.FINE, "No grpc-trace-bin present in carrier"); + return context; + } + } catch (Exception e) { + log.log(Level.FINE, "Get 'grpc-trace-bin' from getter failed", e); + return context; + } + try { + b = BaseEncoding.base64().decode(value); + } catch (Exception e) { + log.log(Level.FINE, "Base64-decode spanContext bytes failed", e); + return context; + } + } + + SpanContext spanContext; + try { + spanContext = binaryFormat.parseBytes(b); + } catch (Exception e) { + log.log(Level.FINE, "Failed to parse tracing header", e); + return context; + } + if (!spanContext.isValid()) { + return context; + } + return context.with(Span.wrap(spanContext)); + } +} diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/InternalGrpcOpenTelemetry.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/InternalGrpcOpenTelemetry.java new file mode 100644 index 00000000000..ea1e7ab803f --- /dev/null +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/InternalGrpcOpenTelemetry.java @@ -0,0 +1,36 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.opentelemetry; + +import io.grpc.Internal; + +/** + * Internal accessor for {@link GrpcOpenTelemetry}. + */ +@Internal +public final class InternalGrpcOpenTelemetry { + private InternalGrpcOpenTelemetry() {} + + public static void builderPlugin( + GrpcOpenTelemetry.Builder builder, InternalOpenTelemetryPlugin plugin) { + builder.plugin(plugin); + } + + public static void enableTracing(GrpcOpenTelemetry.Builder builder, boolean enable) { + builder.enableTracing(enable); + } +} diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/InternalOpenTelemetryPlugin.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/InternalOpenTelemetryPlugin.java new file mode 100644 index 00000000000..38275506e1a --- /dev/null +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/InternalOpenTelemetryPlugin.java @@ -0,0 +1,36 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.opentelemetry; + +import io.grpc.Internal; + +/** + * Accessors for making plugins. + */ +@Internal +public interface InternalOpenTelemetryPlugin extends OpenTelemetryPlugin { + @Override + ClientCallPlugin newClientCallPlugin(); + + interface ClientCallPlugin extends OpenTelemetryPlugin.ClientCallPlugin { + @Override + ClientStreamPlugin newClientStreamPlugin(); + } + + interface ClientStreamPlugin extends OpenTelemetryPlugin.ClientStreamPlugin { + } +} diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/MetadataGetter.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/MetadataGetter.java new file mode 100644 index 00000000000..f49c029f2fb --- /dev/null +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/MetadataGetter.java @@ -0,0 +1,87 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.opentelemetry; + + +import static io.grpc.InternalMetadata.BASE64_ENCODING_OMIT_PADDING; + +import io.grpc.Metadata; +import io.opentelemetry.context.propagation.TextMapGetter; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.Nullable; + +/** + * A TextMapGetter that reads value from gRPC {@link Metadata}. Supports both text and binary + * headers. Supporting binary header is an optimization path for GrpcTraceBinContextPropagator + * to work around the lack of binary propagator API and thus avoid + * base64 (de)encoding when passing data between propagator API interfaces. + */ +final class MetadataGetter implements TextMapGetter { + private static final Logger logger = Logger.getLogger(MetadataGetter.class.getName()); + private static final MetadataGetter INSTANCE = new MetadataGetter(); + + public static MetadataGetter getInstance() { + return INSTANCE; + } + + @Override + public Iterable keys(Metadata carrier) { + return carrier.keys(); + } + + @Nullable + @Override + public String get(@Nullable Metadata carrier, String key) { + if (carrier == null) { + logger.log(Level.FINE, "Carrier is null, getting no data"); + return null; + } + try { + if (key.equals("grpc-trace-bin")) { + byte[] value = carrier.get(Metadata.Key.of(key, Metadata.BINARY_BYTE_MARSHALLER)); + if (value == null) { + return null; + } + return BASE64_ENCODING_OMIT_PADDING.encode(value); + } else { + return carrier.get(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER)); + } + } catch (Exception e) { + logger.log(Level.FINE, String.format("Failed to get metadata key %s", key), e); + return null; + } + } + + @Nullable + public byte[] getBinary(@Nullable Metadata carrier, String key) { + if (carrier == null) { + logger.log(Level.FINE, "Carrier is null, getting no data"); + return null; + } + if (!key.equals("grpc-trace-bin")) { + logger.log(Level.FINE, "Only support 'grpc-trace-bin' binary header. Get no data"); + return null; + } + try { + return carrier.get(Metadata.Key.of(key, Metadata.BINARY_BYTE_MARSHALLER)); + } catch (Exception e) { + logger.log(Level.FINE, String.format("Failed to get metadata key %s", key), e); + return null; + } + } +} diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/MetadataSetter.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/MetadataSetter.java new file mode 100644 index 00000000000..5892c7accfe --- /dev/null +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/MetadataSetter.java @@ -0,0 +1,74 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.opentelemetry; + + +import com.google.common.io.BaseEncoding; +import io.grpc.Metadata; +import io.opentelemetry.context.propagation.TextMapSetter; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.Nullable; + +/** + * A {@link TextMapSetter} that sets value to gRPC {@link Metadata}. Supports both text and binary + * headers. Supporting binary header is an optimization path for GrpcTraceBinContextPropagator + * to work around the lack of binary propagator API and thus avoid + * base64 (de)encoding when passing data between propagator API interfaces. + */ +final class MetadataSetter implements TextMapSetter { + private static final Logger logger = Logger.getLogger(MetadataSetter.class.getName()); + private static final MetadataSetter INSTANCE = new MetadataSetter(); + + public static MetadataSetter getInstance() { + return INSTANCE; + } + + @Override + public void set(@Nullable Metadata carrier, String key, String value) { + if (carrier == null) { + logger.log(Level.FINE, "Carrier is null, setting no data"); + return; + } + try { + if (key.equals("grpc-trace-bin")) { + carrier.put(Metadata.Key.of(key, Metadata.BINARY_BYTE_MARSHALLER), + BaseEncoding.base64().decode(value)); + } else { + carrier.put(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER), value); + } + } catch (Exception e) { + logger.log(Level.INFO, String.format("Failed to set metadata, key=%s", key), e); + } + } + + void set(@Nullable Metadata carrier, String key, byte[] value) { + if (carrier == null) { + logger.log(Level.FINE, "Carrier is null, setting no data"); + return; + } + if (!key.equals("grpc-trace-bin")) { + logger.log(Level.INFO, "Only support 'grpc-trace-bin' binary header. Set no data"); + return; + } + try { + carrier.put(Metadata.Key.of(key, Metadata.BINARY_BYTE_MARSHALLER), value); + } catch (Exception e) { + logger.log(Level.INFO, String.format("Failed to set metadata key=%s", key), e); + } + } +} diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricSink.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricSink.java new file mode 100644 index 00000000000..fd8af7f998f --- /dev/null +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricSink.java @@ -0,0 +1,338 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.opentelemetry; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.grpc.CallbackMetricInstrument; +import io.grpc.DoubleCounterMetricInstrument; +import io.grpc.DoubleHistogramMetricInstrument; +import io.grpc.LongCounterMetricInstrument; +import io.grpc.LongGaugeMetricInstrument; +import io.grpc.LongHistogramMetricInstrument; +import io.grpc.LongUpDownCounterMetricInstrument; +import io.grpc.MetricInstrument; +import io.grpc.MetricSink; +import io.opentelemetry.api.common.Attributes; +import io.opentelemetry.api.common.AttributesBuilder; +import io.opentelemetry.api.metrics.BatchCallback; +import io.opentelemetry.api.metrics.DoubleCounter; +import io.opentelemetry.api.metrics.DoubleHistogram; +import io.opentelemetry.api.metrics.LongCounter; +import io.opentelemetry.api.metrics.LongHistogram; +import io.opentelemetry.api.metrics.LongUpDownCounter; +import io.opentelemetry.api.metrics.Meter; +import io.opentelemetry.api.metrics.ObservableLongMeasurement; +import io.opentelemetry.api.metrics.ObservableMeasurement; +import java.util.ArrayList; +import java.util.BitSet; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.logging.Level; +import java.util.logging.Logger; + +final class OpenTelemetryMetricSink implements MetricSink { + private static final Logger logger = Logger.getLogger(OpenTelemetryMetricSink.class.getName()); + private final Object lock = new Object(); + private final Meter openTelemetryMeter; + private final Map enableMetrics; + private final boolean disableDefaultMetrics; + private final Set optionalLabels; + private volatile List measures = new ArrayList<>(); + + OpenTelemetryMetricSink(Meter meter, Map enableMetrics, + boolean disableDefaultMetrics, List optionalLabels) { + this.openTelemetryMeter = checkNotNull(meter, "meter"); + this.enableMetrics = ImmutableMap.copyOf(enableMetrics); + this.disableDefaultMetrics = disableDefaultMetrics; + this.optionalLabels = ImmutableSet.copyOf(optionalLabels); + } + + @Override + public Map getEnabledMetrics() { + return enableMetrics; + } + + @Override + public Set getOptionalLabels() { + return optionalLabels; + } + + @Override + public int getMeasuresSize() { + return measures.size(); + } + + @VisibleForTesting + List getMeasures() { + synchronized (lock) { + return Collections.unmodifiableList(measures); + } + } + + @Override + public void addDoubleCounter(DoubleCounterMetricInstrument metricInstrument, double value, + List requiredLabelValues, List optionalLabelValues) { + MeasuresData instrumentData = measures.get(metricInstrument.getIndex()); + if (instrumentData == null) { + // Disabled metric + return; + } + Attributes attributes = createAttributes(metricInstrument.getRequiredLabelKeys(), + metricInstrument.getOptionalLabelKeys(), requiredLabelValues, optionalLabelValues, + instrumentData.getOptionalLabelsBitSet()); + DoubleCounter counter = (DoubleCounter) instrumentData.getMeasure(); + counter.add(value, attributes); + } + + @Override + public void addLongCounter(LongCounterMetricInstrument metricInstrument, long value, + List requiredLabelValues, List optionalLabelValues) { + MeasuresData instrumentData = measures.get(metricInstrument.getIndex()); + if (instrumentData == null) { + // Disabled metric + return; + } + Attributes attributes = createAttributes(metricInstrument.getRequiredLabelKeys(), + metricInstrument.getOptionalLabelKeys(), requiredLabelValues, optionalLabelValues, + instrumentData.getOptionalLabelsBitSet()); + LongCounter counter = (LongCounter) instrumentData.getMeasure(); + counter.add(value, attributes); + } + + @Override + public void addLongUpDownCounter(LongUpDownCounterMetricInstrument metricInstrument, long value, + List requiredLabelValues, + List optionalLabelValues) { + MeasuresData instrumentData = measures.get(metricInstrument.getIndex()); + if (instrumentData == null) { + // Disabled metric + return; + } + Attributes attributes = createAttributes(metricInstrument.getRequiredLabelKeys(), + metricInstrument.getOptionalLabelKeys(), requiredLabelValues, optionalLabelValues, + instrumentData.getOptionalLabelsBitSet()); + LongUpDownCounter counter = (LongUpDownCounter) instrumentData.getMeasure(); + counter.add(value, attributes); + } + + @Override + public void recordDoubleHistogram(DoubleHistogramMetricInstrument metricInstrument, double value, + List requiredLabelValues, List optionalLabelValues) { + MeasuresData instrumentData = measures.get(metricInstrument.getIndex()); + if (instrumentData == null) { + // Disabled metric + return; + } + Attributes attributes = createAttributes(metricInstrument.getRequiredLabelKeys(), + metricInstrument.getOptionalLabelKeys(), requiredLabelValues, optionalLabelValues, + instrumentData.getOptionalLabelsBitSet()); + DoubleHistogram histogram = (DoubleHistogram) instrumentData.getMeasure(); + histogram.record(value, attributes); + } + + @Override + public void recordLongHistogram(LongHistogramMetricInstrument metricInstrument, long value, + List requiredLabelValues, List optionalLabelValues) { + MeasuresData instrumentData = measures.get(metricInstrument.getIndex()); + if (instrumentData == null) { + // Disabled metric + return; + } + Attributes attributes = createAttributes(metricInstrument.getRequiredLabelKeys(), + metricInstrument.getOptionalLabelKeys(), requiredLabelValues, optionalLabelValues, + instrumentData.getOptionalLabelsBitSet()); + LongHistogram histogram = (LongHistogram) instrumentData.getMeasure(); + histogram.record(value, attributes); + } + + @Override + public void recordLongGauge(LongGaugeMetricInstrument metricInstrument, long value, + List requiredLabelValues, List optionalLabelValues) { + MeasuresData instrumentData = measures.get(metricInstrument.getIndex()); + if (instrumentData == null) { + // Disabled metric + return; + } + Attributes attributes = createAttributes(metricInstrument.getRequiredLabelKeys(), + metricInstrument.getOptionalLabelKeys(), requiredLabelValues, optionalLabelValues, + instrumentData.getOptionalLabelsBitSet()); + ObservableLongMeasurement gauge = (ObservableLongMeasurement) instrumentData.getMeasure(); + gauge.record(value, attributes); + } + + @Override + public Registration registerBatchCallback(Runnable callback, + CallbackMetricInstrument... metricInstruments) { + List measurements = new ArrayList<>(metricInstruments.length); + for (CallbackMetricInstrument metricInstrument: metricInstruments) { + MeasuresData instrumentData = measures.get(metricInstrument.getIndex()); + if (instrumentData == null) { + // Disabled metric + continue; + } + if (!(instrumentData.getMeasure() instanceof ObservableMeasurement)) { + logger.log(Level.FINE, "Unsupported metric instrument type : {0} {1}", + new Object[] {metricInstrument, instrumentData.getMeasure().getClass()}); + continue; + } + measurements.add((ObservableMeasurement) instrumentData.getMeasure()); + } + if (measurements.isEmpty()) { + return () -> { }; + } + ObservableMeasurement first = measurements.get(0); + measurements.remove(0); + BatchCallback closeable = openTelemetryMeter.batchCallback( + callback, first, measurements.toArray(new ObservableMeasurement[0])); + return closeable::close; + } + + @Override + public void updateMeasures(List instruments) { + synchronized (lock) { + if (measures.size() >= instruments.size()) { + // Already up-to-date + return; + } + + List newMeasures = new ArrayList<>(instruments.size()); + // Reuse existing measures + newMeasures.addAll(measures); + + for (int i = measures.size(); i < instruments.size(); i++) { + MetricInstrument instrument = instruments.get(i); + // Check if the metric is disabled + if (!shouldEnableMetric(instrument)) { + // Adding null measure for disabled Metric + newMeasures.add(null); + continue; + } + + BitSet bitSet = new BitSet(instrument.getOptionalLabelKeys().size()); + if (optionalLabels.isEmpty()) { + // initialize an empty list + } else { + List labels = instrument.getOptionalLabelKeys(); + for (int j = 0; j < labels.size(); j++) { + if (optionalLabels.contains(labels.get(j))) { + bitSet.set(j); + } + } + } + + int index = instrument.getIndex(); + String name = instrument.getName(); + String unit = instrument.getUnit(); + String description = instrument.getDescription(); + + Object openTelemetryMeasure; + if (instrument instanceof DoubleCounterMetricInstrument) { + openTelemetryMeasure = openTelemetryMeter.counterBuilder(name) + .setUnit(unit) + .setDescription(description) + .ofDoubles() + .build(); + } else if (instrument instanceof LongCounterMetricInstrument) { + openTelemetryMeasure = openTelemetryMeter.counterBuilder(name) + .setUnit(unit) + .setDescription(description) + .build(); + } else if (instrument instanceof DoubleHistogramMetricInstrument) { + openTelemetryMeasure = openTelemetryMeter.histogramBuilder(name) + .setUnit(unit) + .setDescription(description) + .build(); + } else if (instrument instanceof LongHistogramMetricInstrument) { + openTelemetryMeasure = openTelemetryMeter.histogramBuilder(name) + .setUnit(unit) + .setDescription(description) + .ofLongs() + .build(); + } else if (instrument instanceof LongGaugeMetricInstrument) { + openTelemetryMeasure = openTelemetryMeter.gaugeBuilder(name) + .setUnit(unit) + .setDescription(description) + .ofLongs() + .buildObserver(); + } else if (instrument instanceof LongUpDownCounterMetricInstrument) { + openTelemetryMeasure = openTelemetryMeter.upDownCounterBuilder(name) + .setUnit(unit) + .setDescription(description) + .build(); + } else { + logger.log(Level.FINE, "Unsupported metric instrument type : {0}", instrument); + openTelemetryMeasure = null; + } + newMeasures.add(index, new MeasuresData(bitSet, openTelemetryMeasure)); + } + + measures = newMeasures; + } + } + + private boolean shouldEnableMetric(MetricInstrument instrument) { + Boolean explicitlyEnabled = enableMetrics.get(instrument.getName()); + if (explicitlyEnabled != null) { + return explicitlyEnabled; + } + return instrument.isEnableByDefault() && !disableDefaultMetrics; + } + + + private Attributes createAttributes(List requiredLabelKeys, + List optionalLabelKeys, + List requiredLabelValues, List optionalLabelValues, BitSet bitSet) { + AttributesBuilder builder = Attributes.builder(); + // Required Labels + for (int i = 0; i < requiredLabelKeys.size(); i++) { + builder.put(requiredLabelKeys.get(i), requiredLabelValues.get(i)); + } + // Optional labels + for (int i = bitSet.nextSetBit(0); i >= 0; i = bitSet.nextSetBit(i + 1)) { + if (i == Integer.MAX_VALUE) { + break; // or (i+1) would overflow + } + builder.put(optionalLabelKeys.get(i), optionalLabelValues.get(i)); + } + return builder.build(); + } + + static final class MeasuresData { + final BitSet optionalLabelsIndices; + final Object measure; + + MeasuresData(BitSet optionalLabelsIndices, Object measure) { + this.optionalLabelsIndices = optionalLabelsIndices; + this.measure = measure; + } + + public BitSet getOptionalLabelsBitSet() { + return optionalLabelsIndices; + } + + public Object getMeasure() { + return measure; + } + } + +} diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricsModule.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricsModule.java index 66513a5ca2f..f783b9495dd 100644 --- a/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricsModule.java +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricsModule.java @@ -17,12 +17,20 @@ package io.grpc.opentelemetry; import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.BACKEND_SERVICE_KEY; +import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.BAGGAGE_KEY; +import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.CUSTOM_LABEL_KEY; +import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.LOCALITY_KEY; import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.METHOD_KEY; import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.STATUS_KEY; +import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.TARGET_KEY; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Stopwatch; import com.google.common.base.Supplier; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.ClientCall; @@ -32,21 +40,29 @@ import io.grpc.Deadline; import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; +import io.grpc.Grpc; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.ServerStreamTracer; import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.StreamTracer; +import io.grpc.internal.StatsTraceContext.ServerCallMethodListener; +import io.grpc.opentelemetry.GrpcOpenTelemetry.TargetFilter; +import io.opentelemetry.api.baggage.Baggage; +import io.opentelemetry.api.common.AttributesBuilder; +import io.opentelemetry.context.Context; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLongFieldUpdater; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; /** * Provides factories for {@link StreamTracer} that records metrics to OpenTelemetry. @@ -56,12 +72,27 @@ * tracer. It's the tracer that reports per-attempt stats, and the factory that reports the stats * of the overall RPC, such as RETRIES_PER_CALL, to OpenTelemetry. * + *

This module optionally applies a target attribute filter to limit the cardinality of + * the {@code grpc.target} attribute in client-side metrics by mapping disallowed targets + * to a stable placeholder value. + * *

On the server-side, there is only one ServerStream per each ServerCall, and ServerStream * starts earlier than the ServerCall. Therefore, only one tracer is created per stream/call, and * it's the tracer that reports the summary to OpenTelemetry. */ final class OpenTelemetryMetricsModule { private static final Logger logger = Logger.getLogger(OpenTelemetryMetricsModule.class.getName()); + public static final ImmutableSet DEFAULT_PER_CALL_METRICS_SET = + ImmutableSet.of( + "grpc.client.attempt.started", + "grpc.client.attempt.duration", + "grpc.client.attempt.sent_total_compressed_message_size", + "grpc.client.attempt.rcvd_total_compressed_message_size", + "grpc.client.call.duration", + "grpc.server.call.started", + "grpc.server.call.duration", + "grpc.server.call.sent_total_compressed_message_size", + "grpc.server.call.rcvd_total_compressed_message_size"); // Using floating point because TimeUnit.NANOSECONDS.toSeconds would discard // fractional seconds. @@ -69,11 +100,35 @@ final class OpenTelemetryMetricsModule { private final OpenTelemetryMetricsResource resource; private final Supplier stopwatchSupplier; + private final boolean localityEnabled; + private final boolean backendServiceEnabled; + private final boolean customLabelEnabled; + private final ImmutableList plugins; + @Nullable + private final TargetFilter targetAttributeFilter; + + OpenTelemetryMetricsModule(Supplier stopwatchSupplier, + OpenTelemetryMetricsResource resource, + Collection optionalLabels, List plugins) { + this(stopwatchSupplier, resource, optionalLabels, plugins, null); + } OpenTelemetryMetricsModule(Supplier stopwatchSupplier, - OpenTelemetryMetricsResource resource) { + OpenTelemetryMetricsResource resource, + Collection optionalLabels, List plugins, + @Nullable TargetFilter targetAttributeFilter) { this.resource = checkNotNull(resource, "resource"); this.stopwatchSupplier = checkNotNull(stopwatchSupplier, "stopwatchSupplier"); + this.localityEnabled = optionalLabels.contains(LOCALITY_KEY.getKey()); + this.backendServiceEnabled = optionalLabels.contains(BACKEND_SERVICE_KEY.getKey()); + this.customLabelEnabled = optionalLabels.contains(CUSTOM_LABEL_KEY.getKey()); + this.plugins = ImmutableList.copyOf(plugins); + this.targetAttributeFilter = targetAttributeFilter; + } + + @VisibleForTesting + TargetFilter getTargetAttributeFilter() { + return targetAttributeFilter; } /** @@ -86,8 +141,23 @@ ServerStreamTracer.Factory getServerTracerFactory() { /** * Returns the client interceptor that facilitates OpenTelemetry metrics reporting. */ - ClientInterceptor getClientInterceptor() { - return new MetricsClientInterceptor(); + ClientInterceptor getClientInterceptor(String target) { + ImmutableList.Builder pluginBuilder = + ImmutableList.builderWithExpectedSize(plugins.size()); + for (OpenTelemetryPlugin plugin : plugins) { + if (plugin.enablePluginForChannel(target)) { + pluginBuilder.add(plugin); + } + } + String filteredTarget = recordTarget(target); + return new MetricsClientInterceptor(filteredTarget, pluginBuilder.build()); + } + + String recordTarget(String target) { + if (targetAttributeFilter == null || target == null) { + return target; + } + return targetAttributeFilter.test(target) ? target : "other"; } static String recordMethodName(String fullMethodName, boolean isGeneratedMethod) { @@ -122,24 +192,37 @@ private static final class ClientTracer extends ClientStreamTracer { final Stopwatch stopwatch; final CallAttemptsTracerFactory attemptsState; - final AtomicBoolean inboundReceivedOrClosed = new AtomicBoolean(); final OpenTelemetryMetricsModule module; final StreamInfo info; + final String target; final String fullMethodName; + final List streamPlugins; volatile long outboundWireSize; volatile long inboundWireSize; + volatile String locality; + volatile String backendService; long attemptNanos; Code statusCode; ClientTracer(CallAttemptsTracerFactory attemptsState, OpenTelemetryMetricsModule module, - StreamInfo info, String fullMethodName) { + StreamInfo info, String target, String fullMethodName, + List streamPlugins) { this.attemptsState = attemptsState; this.module = module; this.info = info; + this.target = target; this.fullMethodName = fullMethodName; + this.streamPlugins = streamPlugins; this.stopwatch = module.stopwatchSupplier.get().start(); } + @Override + public void inboundHeaders(Metadata headers) { + for (OpenTelemetryPlugin.ClientStreamPlugin plugin : streamPlugins) { + plugin.inboundHeaders(headers); + } + } + @Override @SuppressWarnings("NonAtomicVolatileUpdate") public void outboundWireSize(long bytes) { @@ -161,15 +244,21 @@ public void inboundWireSize(long bytes) { } @Override - @SuppressWarnings("NonAtomicVolatileUpdate") - public void inboundMessage(int seqNo) { - if (inboundReceivedOrClosed.compareAndSet(false, true)) { - // Because inboundUncompressedSize() might be called after streamClosed(), - // we will report stats in callEnded(). Note that this attempt is already committed. - attemptsState.inboundMetricTracer = this; + public void addOptionalLabel(String key, String value) { + if ("grpc.lb.locality".equals(key)) { + locality = value; + } + if ("grpc.lb.backend_service".equals(key)) { + backendService = value; } } + @Override + public void inboundTrailers(Metadata trailers) { + for (OpenTelemetryPlugin.ClientStreamPlugin plugin : streamPlugins) { + plugin.inboundTrailers(trailers); + } + } @Override public void streamClosed(Status status) { @@ -185,60 +274,103 @@ public void streamClosed(Status status) { statusCode = Code.DEADLINE_EXCEEDED; } } - attemptsState.attemptEnded(); - if (inboundReceivedOrClosed.compareAndSet(false, true)) { - // Stream is closed early. So no need to record metrics for any inbound events after this - // point. - recordFinishedAttempt(); - } // Otherwise will report metrics in callEnded() to guarantee all inbound metrics are - // recorded. + attemptsState.attemptEnded(info.getCallOptions()); + recordFinishedAttempt(); } void recordFinishedAttempt() { - // TODO(dnvindhya) : add target as an attribute - io.opentelemetry.api.common.Attributes attribute = - io.opentelemetry.api.common.Attributes.of(METHOD_KEY, fullMethodName, - STATUS_KEY, statusCode.toString()); - - module.resource.clientAttemptDurationCounter() - .record(attemptNanos * SECONDS_PER_NANO, attribute); - module.resource.clientTotalSentCompressedMessageSizeCounter() - .record(outboundWireSize, attribute); - module.resource.clientTotalReceivedCompressedMessageSizeCounter() - .record(inboundWireSize, attribute); + AttributesBuilder builder = io.opentelemetry.api.common.Attributes.builder() + .put(METHOD_KEY, fullMethodName) + .put(TARGET_KEY, target) + .put(STATUS_KEY, statusCode.toString()); + if (module.localityEnabled) { + String savedLocality = locality; + if (savedLocality == null) { + savedLocality = ""; + } + builder.put(LOCALITY_KEY, savedLocality); + } + if (module.backendServiceEnabled) { + String savedBackendService = backendService; + if (savedBackendService == null) { + savedBackendService = ""; + } + builder.put(BACKEND_SERVICE_KEY, savedBackendService); + } + if (module.customLabelEnabled) { + builder.put( + CUSTOM_LABEL_KEY, info.getCallOptions().getOption(Grpc.CALL_OPTION_CUSTOM_LABEL)); + } + for (OpenTelemetryPlugin.ClientStreamPlugin plugin : streamPlugins) { + plugin.addLabels(builder); + } + io.opentelemetry.api.common.Attributes attribute = builder.build(); + + if (module.resource.clientAttemptDurationCounter() != null ) { + module.resource.clientAttemptDurationCounter() + .record(attemptNanos * SECONDS_PER_NANO, attribute, attemptsState.otelContext); + } + if (module.resource.clientTotalSentCompressedMessageSizeCounter() != null) { + module.resource.clientTotalSentCompressedMessageSizeCounter() + .record(outboundWireSize, attribute, attemptsState.otelContext); + } + if (module.resource.clientTotalReceivedCompressedMessageSizeCounter() != null) { + module.resource.clientTotalReceivedCompressedMessageSizeCounter() + .record(inboundWireSize, attribute, attemptsState.otelContext); + } } } @VisibleForTesting static final class CallAttemptsTracerFactory extends ClientStreamTracer.Factory { - ClientTracer inboundMetricTracer; private final OpenTelemetryMetricsModule module; - private final Stopwatch attemptStopwatch; + private final String target; + private final Stopwatch attemptDelayStopwatch; private final Stopwatch callStopWatch; @GuardedBy("lock") private boolean callEnded; private final String fullMethodName; + private final List callPlugins; + private final Context otelContext; private Status status; + private long retryDelayNanos; private long callLatencyNanos; private final Object lock = new Object(); private final AtomicLong attemptsPerCall = new AtomicLong(); + private final AtomicLong hedgedAttemptsPerCall = new AtomicLong(); + private final AtomicLong transparentRetriesPerCall = new AtomicLong(); @GuardedBy("lock") private int activeStreams; @GuardedBy("lock") private boolean finishedCallToBeRecorded; - CallAttemptsTracerFactory(OpenTelemetryMetricsModule module, String fullMethodName) { + CallAttemptsTracerFactory( + OpenTelemetryMetricsModule module, + String target, + CallOptions callOptions, + String fullMethodName, + List callPlugins, Context otelContext) { this.module = checkNotNull(module, "module"); + this.target = checkNotNull(target, "target"); this.fullMethodName = checkNotNull(fullMethodName, "fullMethodName"); - this.attemptStopwatch = module.stopwatchSupplier.get(); + this.callPlugins = checkNotNull(callPlugins, "callPlugins"); + this.otelContext = checkNotNull(otelContext, "otelContext"); + this.attemptDelayStopwatch = module.stopwatchSupplier.get(); this.callStopWatch = module.stopwatchSupplier.get().start(); - // TODO(dnvindhya) : add target as an attribute - io.opentelemetry.api.common.Attributes attribute = - io.opentelemetry.api.common.Attributes.of(METHOD_KEY, fullMethodName); + AttributesBuilder builder = io.opentelemetry.api.common.Attributes.builder() + .put(METHOD_KEY, fullMethodName) + .put(TARGET_KEY, target); + if (module.customLabelEnabled) { + builder.put( + CUSTOM_LABEL_KEY, callOptions.getOption(Grpc.CALL_OPTION_CUSTOM_LABEL)); + } + io.opentelemetry.api.common.Attributes attribute = builder.build(); // Record here in case mewClientStreamTracer() would never be called. - module.resource.clientAttemptCountCounter().add(1, attribute); + if (module.resource.clientAttemptCountCounter() != null) { + module.resource.clientAttemptCountCounter().add(1, attribute, otelContext); + } } @Override @@ -248,31 +380,55 @@ public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata metada // This can be the case when the call is cancelled but a retry attempt is created. return new ClientStreamTracer() {}; } - if (++activeStreams == 1 && attemptStopwatch.isRunning()) { - attemptStopwatch.stop(); + if (++activeStreams == 1 && attemptDelayStopwatch.isRunning()) { + attemptDelayStopwatch.stop(); + retryDelayNanos = attemptDelayStopwatch.elapsed(TimeUnit.NANOSECONDS); } } // Skip recording for the first time, since it is already recorded in // CallAttemptsTracerFactory constructor. attemptsPerCall will be non-zero after the first // attempt, as first attempt cannot be a transparent retry. if (attemptsPerCall.get() > 0) { - // TODO(dnvindhya): Add target as an attribute - io.opentelemetry.api.common.Attributes attribute = - io.opentelemetry.api.common.Attributes.of(METHOD_KEY, fullMethodName); - module.resource.clientAttemptCountCounter().add(1, attribute); + AttributesBuilder builder = io.opentelemetry.api.common.Attributes.builder() + .put(METHOD_KEY, fullMethodName) + .put(TARGET_KEY, target); + if (module.customLabelEnabled) { + builder.put( + CUSTOM_LABEL_KEY, info.getCallOptions().getOption(Grpc.CALL_OPTION_CUSTOM_LABEL)); + } + io.opentelemetry.api.common.Attributes attribute = builder.build(); + if (module.resource.clientAttemptCountCounter() != null) { + module.resource.clientAttemptCountCounter().add(1, attribute, otelContext); + } } - if (!info.isTransparentRetry()) { + if (info.isTransparentRetry()) { + transparentRetriesPerCall.incrementAndGet(); + } else if (info.isHedging()) { + hedgedAttemptsPerCall.incrementAndGet(); + } else { attemptsPerCall.incrementAndGet(); } - return new ClientTracer(this, module, info, fullMethodName); + return newClientTracer(info); + } + + private ClientTracer newClientTracer(StreamInfo info) { + List streamPlugins = Collections.emptyList(); + if (!callPlugins.isEmpty()) { + streamPlugins = new ArrayList<>(callPlugins.size()); + for (OpenTelemetryPlugin.ClientCallPlugin plugin : callPlugins) { + streamPlugins.add(plugin.newClientStreamPlugin()); + } + streamPlugins = Collections.unmodifiableList(streamPlugins); + } + return new ClientTracer(this, module, info, target, fullMethodName, streamPlugins); } // Called whenever each attempt is ended. - void attemptEnded() { + void attemptEnded(CallOptions callOptions) { boolean shouldRecordFinishedCall = false; synchronized (lock) { if (--activeStreams == 0) { - attemptStopwatch.start(); + attemptDelayStopwatch.start(); if (callEnded && !finishedCallToBeRecorded) { shouldRecordFinishedCall = true; finishedCallToBeRecorded = true; @@ -280,11 +436,11 @@ void attemptEnded() { } } if (shouldRecordFinishedCall) { - recordFinishedCall(); + recordFinishedCall(callOptions); } } - void callEnded(Status status) { + void callEnded(Status status, CallOptions callOptions) { callStopWatch.stop(); this.status = status; boolean shouldRecordFinishedCall = false; @@ -300,33 +456,79 @@ void callEnded(Status status) { } } if (shouldRecordFinishedCall) { - recordFinishedCall(); + recordFinishedCall(callOptions); } } - void recordFinishedCall() { + void recordFinishedCall(CallOptions callOptions) { if (attemptsPerCall.get() == 0) { - ClientTracer tracer = new ClientTracer(this, module, null, fullMethodName); - tracer.attemptNanos = attemptStopwatch.elapsed(TimeUnit.NANOSECONDS); + ClientTracer tracer = newClientTracer(null); + tracer.attemptNanos = attemptDelayStopwatch.elapsed(TimeUnit.NANOSECONDS); tracer.statusCode = status.getCode(); tracer.recordFinishedAttempt(); - } else if (inboundMetricTracer != null) { - // activeStreams has been decremented to 0 by attemptEnded(), - // so inboundMetricTracer.statusCode is guaranteed to be assigned already. - inboundMetricTracer.recordFinishedAttempt(); } callLatencyNanos = callStopWatch.elapsed(TimeUnit.NANOSECONDS); - // TODO(dnvindhya): record target as an attribute - io.opentelemetry.api.common.Attributes attribute = - io.opentelemetry.api.common.Attributes.of(METHOD_KEY, fullMethodName, - STATUS_KEY, status.getCode().toString()); - module.resource.clientCallDurationCounter() - .record(callLatencyNanos * SECONDS_PER_NANO, attribute); + // Base attributes + AttributesBuilder builder = io.opentelemetry.api.common.Attributes.builder() + .put(METHOD_KEY, fullMethodName) + .put(TARGET_KEY, target); + if (module.customLabelEnabled) { + builder.put(CUSTOM_LABEL_KEY, callOptions.getOption(Grpc.CALL_OPTION_CUSTOM_LABEL)); + } + io.opentelemetry.api.common.Attributes baseAttributes = builder.build(); + + // Duration + if (module.resource.clientCallDurationCounter() != null) { + module.resource.clientCallDurationCounter().record( + callLatencyNanos * SECONDS_PER_NANO, + baseAttributes.toBuilder() + .put(STATUS_KEY, status.getCode().toString()) + .build(), + otelContext + ); + } + + // Retry counts + if (module.resource.clientCallRetriesCounter() != null) { + long retriesPerCall = Math.max(attemptsPerCall.get() - 1, 0); + if (retriesPerCall > 0) { + module.resource.clientCallRetriesCounter() + .record(retriesPerCall, baseAttributes, otelContext); + } + } + + // Hedge counts + if (module.resource.clientCallHedgesCounter() != null) { + long hedges = hedgedAttemptsPerCall.get(); + if (hedges > 0) { + module.resource.clientCallHedgesCounter() + .record(hedges, baseAttributes, otelContext); + } + } + + // Transparent Retry counts + if (module.resource.clientCallTransparentRetriesCounter() != null) { + long transparentRetries = transparentRetriesPerCall.get(); + if (transparentRetries > 0) { + module.resource.clientCallTransparentRetriesCounter() + .record(transparentRetries, baseAttributes, otelContext); + } + } + + // Retry delay + if (module.resource.clientCallRetryDelayCounter() != null) { + module.resource.clientCallRetryDelayCounter().record( + retryDelayNanos * SECONDS_PER_NANO, + baseAttributes, + otelContext + ); + } } } - private static final class ServerTracer extends ServerStreamTracer { + private static final class ServerTracer extends ServerStreamTracer + implements ServerCallMethodListener { @Nullable private static final AtomicIntegerFieldUpdater streamClosedUpdater; @Nullable private static final AtomicLongFieldUpdater outboundWireSizeUpdater; @Nullable private static final AtomicLongFieldUpdater inboundWireSizeUpdater; @@ -360,18 +562,38 @@ private static final class ServerTracer extends ServerStreamTracer { private final OpenTelemetryMetricsModule module; private final String fullMethodName; + private final List streamPlugins; + private Context otelContext = Context.root(); private volatile boolean isGeneratedMethod; private volatile int streamClosed; private final Stopwatch stopwatch; private volatile long outboundWireSize; private volatile long inboundWireSize; - ServerTracer(OpenTelemetryMetricsModule module, String fullMethodName) { + ServerTracer(OpenTelemetryMetricsModule module, String fullMethodName, + List streamPlugins) { this.module = checkNotNull(module, "module"); this.fullMethodName = fullMethodName; + this.streamPlugins = checkNotNull(streamPlugins, "streamPlugins"); this.stopwatch = module.stopwatchSupplier.get().start(); } + @Override + public io.grpc.Context filterContext(io.grpc.Context context) { + Baggage baggage = BAGGAGE_KEY.get(context); + if (baggage != null) { + otelContext = Context.current().with(baggage); + } else { + otelContext = Context.current(); + } + return context; + } + + @Override + public void serverCallMethodResolved(MethodDescriptor method) { + isGeneratedMethod = method.isSampledToLocalTracing(); + } + @Override public void serverCallStarted(ServerCallInfo callInfo) { // Only record method name as an attribute if isSampledToLocalTracing is set to true, @@ -379,11 +601,14 @@ public void serverCallStarted(ServerCallInfo callInfo) { // created methods result in high cardinality metrics. boolean isSampledToLocalTracing = callInfo.getMethodDescriptor().isSampledToLocalTracing(); isGeneratedMethod = isSampledToLocalTracing; + io.opentelemetry.api.common.Attributes attribute = io.opentelemetry.api.common.Attributes.of( METHOD_KEY, recordMethodName(fullMethodName, isSampledToLocalTracing)); - module.resource.serverCallCountCounter().add(1, attribute); + if (module.resource.serverCallCountCounter() != null) { + module.resource.serverCallCountCounter().add(1, attribute, otelContext); + } } @Override @@ -426,17 +651,41 @@ public void streamClosed(Status status) { } stopwatch.stop(); long elapsedTimeNanos = stopwatch.elapsed(TimeUnit.NANOSECONDS); - io.opentelemetry.api.common.Attributes attributes = - io.opentelemetry.api.common.Attributes.of( - METHOD_KEY, recordMethodName(fullMethodName, isGeneratedMethod), - STATUS_KEY, status.getCode().toString()); - - module.resource.serverCallDurationCounter() - .record(elapsedTimeNanos * SECONDS_PER_NANO, attributes); - module.resource.serverTotalSentCompressedMessageSizeCounter() - .record(outboundWireSize, attributes); - module.resource.serverTotalReceivedCompressedMessageSizeCounter() - .record(inboundWireSize, attributes); + recordClosedStream( + status, + elapsedTimeNanos, + outboundWireSize, + inboundWireSize, + isGeneratedMethod); + } + + private void recordClosedStream( + Status status, + long elapsedTimeNanos, + long closedOutboundWireSize, + long closedInboundWireSize, + boolean generatedMethod) { + AttributesBuilder builder = + io.opentelemetry.api.common.Attributes.builder() + .put(METHOD_KEY, recordMethodName(fullMethodName, generatedMethod)) + .put(STATUS_KEY, status.getCode().toString()); + for (OpenTelemetryPlugin.ServerStreamPlugin plugin : streamPlugins) { + plugin.addLabels(builder); + } + io.opentelemetry.api.common.Attributes attributes = builder.build(); + + if (module.resource.serverCallDurationCounter() != null) { + module.resource.serverCallDurationCounter() + .record(elapsedTimeNanos * SECONDS_PER_NANO, attributes, otelContext); + } + if (module.resource.serverTotalSentCompressedMessageSizeCounter() != null) { + module.resource.serverTotalSentCompressedMessageSizeCounter() + .record(closedOutboundWireSize, attributes, otelContext); + } + if (module.resource.serverTotalReceivedCompressedMessageSizeCounter() != null) { + module.resource.serverTotalReceivedCompressedMessageSizeCounter() + .record(closedInboundWireSize, attributes, otelContext); + } } } @@ -444,31 +693,70 @@ METHOD_KEY, recordMethodName(fullMethodName, isGeneratedMethod), final class ServerTracerFactory extends ServerStreamTracer.Factory { @Override public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata headers) { - return new ServerTracer(OpenTelemetryMetricsModule.this, fullMethodName); + final List streamPlugins; + if (plugins.isEmpty()) { + streamPlugins = Collections.emptyList(); + } else { + List streamPluginsMutable = + new ArrayList<>(plugins.size()); + for (OpenTelemetryPlugin plugin : plugins) { + streamPluginsMutable.add(plugin.newServerStreamPlugin(headers)); + } + streamPlugins = Collections.unmodifiableList(streamPluginsMutable); + } + return new ServerTracer(OpenTelemetryMetricsModule.this, fullMethodName, + streamPlugins); } } @VisibleForTesting final class MetricsClientInterceptor implements ClientInterceptor { + private final String target; + private final ImmutableList plugins; + + MetricsClientInterceptor(String target, ImmutableList plugins) { + this.target = checkNotNull(target, "target"); + this.plugins = checkNotNull(plugins, "plugins"); + } + @Override public ClientCall interceptCall( MethodDescriptor method, CallOptions callOptions, Channel next) { + final List callPlugins; + if (plugins.isEmpty()) { + callPlugins = Collections.emptyList(); + } else { + List callPluginsMutable = + new ArrayList<>(plugins.size()); + for (OpenTelemetryPlugin plugin : plugins) { + callPluginsMutable.add(plugin.newClientCallPlugin()); + } + callPlugins = Collections.unmodifiableList(callPluginsMutable); + for (OpenTelemetryPlugin.ClientCallPlugin plugin : callPlugins) { + callOptions = plugin.filterCallOptions(callOptions); + } + } + final CallOptions finalCallOptions = callOptions; // Only record method name as an attribute if isSampledToLocalTracing is set to true, // which is true for all generated methods. Otherwise, programatically // created methods result in high cardinality metrics. final CallAttemptsTracerFactory tracerFactory = new CallAttemptsTracerFactory( - OpenTelemetryMetricsModule.this, recordMethodName(method.getFullMethodName(), - method.isSampledToLocalTracing())); + OpenTelemetryMetricsModule.this, target, callOptions, + recordMethodName(method.getFullMethodName(), method.isSampledToLocalTracing()), + callPlugins, Context.current()); ClientCall call = next.newCall(method, callOptions.withStreamTracerFactory(tracerFactory)); return new SimpleForwardingClientCall(call) { @Override public void start(Listener responseListener, Metadata headers) { + for (OpenTelemetryPlugin.ClientCallPlugin plugin : callPlugins) { + plugin.addMetadata(headers); + } delegate().start( new SimpleForwardingClientCallListener(responseListener) { @Override public void onClose(Status status, Metadata trailers) { - tracerFactory.callEnded(status); + tracerFactory.callEnded(status, finalCallOptions); super.onClose(status, trailers); } }, diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricsResource.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricsResource.java index a435ec6bcaa..d32ae1e67f5 100644 --- a/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricsResource.java +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricsResource.java @@ -20,29 +20,50 @@ import io.opentelemetry.api.metrics.DoubleHistogram; import io.opentelemetry.api.metrics.LongCounter; import io.opentelemetry.api.metrics.LongHistogram; +import javax.annotation.Nullable; @AutoValue abstract class OpenTelemetryMetricsResource { /* Client Metrics */ + @Nullable abstract DoubleHistogram clientCallDurationCounter(); + @Nullable abstract LongCounter clientAttemptCountCounter(); + @Nullable abstract DoubleHistogram clientAttemptDurationCounter(); + @Nullable abstract LongHistogram clientTotalSentCompressedMessageSizeCounter(); + @Nullable abstract LongHistogram clientTotalReceivedCompressedMessageSizeCounter(); + @Nullable + abstract LongHistogram clientCallRetriesCounter(); + + @Nullable + abstract LongHistogram clientCallTransparentRetriesCounter(); + + @Nullable + abstract LongHistogram clientCallHedgesCounter(); + + @Nullable + abstract DoubleHistogram clientCallRetryDelayCounter(); /* Server Metrics */ + @Nullable abstract LongCounter serverCallCountCounter(); + @Nullable abstract DoubleHistogram serverCallDurationCounter(); + @Nullable abstract LongHistogram serverTotalSentCompressedMessageSizeCounter(); + @Nullable abstract LongHistogram serverTotalReceivedCompressedMessageSizeCounter(); static Builder builder() { @@ -63,6 +84,14 @@ abstract static class Builder { abstract Builder clientTotalReceivedCompressedMessageSizeCounter( LongHistogram counter); + abstract Builder clientCallRetriesCounter(LongHistogram counter); + + abstract Builder clientCallTransparentRetriesCounter(LongHistogram counter); + + abstract Builder clientCallHedgesCounter(LongHistogram counter); + + abstract Builder clientCallRetryDelayCounter(DoubleHistogram counter); + abstract Builder serverCallCountCounter(LongCounter counter); abstract Builder serverCallDurationCounter(DoubleHistogram counter); diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryModule.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryModule.java deleted file mode 100644 index f2deb60a9bf..00000000000 --- a/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryModule.java +++ /dev/null @@ -1,210 +0,0 @@ -/* - * Copyright 2023 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.opentelemetry; - -import static com.google.common.base.Preconditions.checkNotNull; -import static io.grpc.internal.GrpcUtil.IMPLEMENTATION_VERSION; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Stopwatch; -import com.google.common.base.Supplier; -import io.grpc.ClientInterceptor; -import io.grpc.ExperimentalApi; -import io.grpc.ServerStreamTracer; -import io.grpc.opentelemetry.internal.OpenTelemetryConstants; -import io.opentelemetry.api.OpenTelemetry; -import io.opentelemetry.api.metrics.Meter; -import io.opentelemetry.api.metrics.MeterProvider; - -/** - * The entrypoint for OpenTelemetry metrics functionality in gRPC. - * - *

OpenTelemetryModule uses {@link io.opentelemetry.api.OpenTelemetry} APIs for instrumentation. - * When no SDK is explicitly added no telemetry data will be collected. See - * {@link io.opentelemetry.sdk.OpenTelemetrySdk} for information on how to construct the SDK. - * - */ -@ExperimentalApi("https://github.com/grpc/grpc-java/issues/10591") -public final class OpenTelemetryModule { - - private static final Supplier STOPWATCH_SUPPLIER = new Supplier() { - @Override - public Stopwatch get() { - return Stopwatch.createUnstarted(); - } - }; - - private final OpenTelemetry openTelemetryInstance; - private final MeterProvider meterProvider; - private final Meter meter; - private final OpenTelemetryMetricsResource resource; - - public static Builder newBuilder() { - return new Builder(); - } - - private OpenTelemetryModule(Builder builder) { - this.openTelemetryInstance = checkNotNull(builder.openTelemetrySdk, "openTelemetrySdk"); - this.meterProvider = checkNotNull(openTelemetryInstance.getMeterProvider(), "meterProvider"); - this.meter = this.meterProvider - .meterBuilder(OpenTelemetryConstants.INSTRUMENTATION_SCOPE) - .setInstrumentationVersion(IMPLEMENTATION_VERSION) - .build(); - this.resource = createMetricInstruments(meter); - } - - @VisibleForTesting - OpenTelemetry getOpenTelemetryInstance() { - return this.openTelemetryInstance; - } - - @VisibleForTesting - MeterProvider getMeterProvider() { - return this.meterProvider; - } - - @VisibleForTesting - Meter getMeter() { - return this.meter; - } - - @VisibleForTesting - OpenTelemetryMetricsResource getResource() { - return this.resource; - } - - /** - * Returns a {@link ClientInterceptor} with metrics implementation. - */ - public ClientInterceptor getClientInterceptor() { - OpenTelemetryMetricsModule openTelemetryMetricsModule = - new OpenTelemetryMetricsModule( - STOPWATCH_SUPPLIER, - resource); - return openTelemetryMetricsModule.getClientInterceptor(); - } - - /** - * Returns a {@link ServerStreamTracer.Factory} with metrics implementation. - */ - public ServerStreamTracer.Factory getServerStreamTracerFactory() { - OpenTelemetryMetricsModule openTelemetryMetricsModule = - new OpenTelemetryMetricsModule( - STOPWATCH_SUPPLIER, - resource); - return openTelemetryMetricsModule.getServerTracerFactory(); - } - - @VisibleForTesting - static OpenTelemetryMetricsResource createMetricInstruments(Meter meter) { - OpenTelemetryMetricsResource.Builder builder = OpenTelemetryMetricsResource.builder(); - - builder.clientCallDurationCounter( - meter.histogramBuilder("grpc.client.call.duration") - .setUnit("s") - .setDescription( - "Time taken by gRPC to complete an RPC from application's perspective") - .build()); - - builder.clientAttemptCountCounter( - meter.counterBuilder("grpc.client.attempt.started") - .setUnit("{attempt}") - .setDescription("Number of client call attempts started") - .build()); - - builder.clientAttemptDurationCounter( - meter.histogramBuilder( - "grpc.client.attempt.duration") - .setUnit("s") - .setDescription("Time taken to complete a client call attempt") - .build()); - - builder.clientTotalSentCompressedMessageSizeCounter( - meter.histogramBuilder( - "grpc.client.attempt.sent_total_compressed_message_size") - .setUnit("By") - .setDescription("Compressed message bytes sent per client call attempt") - .ofLongs() - .build()); - - builder.clientTotalReceivedCompressedMessageSizeCounter( - meter.histogramBuilder( - "grpc.client.attempt.rcvd_total_compressed_message_size") - .setUnit("By") - .setDescription("Compressed message bytes received per call attempt") - .ofLongs() - .build()); - - builder.serverCallCountCounter( - meter.counterBuilder("grpc.server.call.started") - .setUnit("{call}") - .setDescription("Number of server calls started") - .build()); - - builder.serverCallDurationCounter( - meter.histogramBuilder("grpc.server.call.duration") - .setUnit("s") - .setDescription( - "Time taken to complete a call from server transport's perspective") - .build()); - - builder.serverTotalSentCompressedMessageSizeCounter( - meter.histogramBuilder( - "grpc.server.call.sent_total_compressed_message_size") - .setUnit("By") - .setDescription("Compressed message bytes sent per server call") - .ofLongs() - .build()); - - builder.serverTotalReceivedCompressedMessageSizeCounter( - meter.histogramBuilder( - "grpc.server.call.rcvd_total_compressed_message_size") - .setUnit("By") - .setDescription("Compressed message bytes received per server call") - .ofLongs() - .build()); - - return builder.build(); - } - - /** - * Builder for configuring {@link OpenTelemetryModule}. - */ - public static class Builder { - private OpenTelemetry openTelemetrySdk = OpenTelemetry.noop(); - - private Builder() {} - - /** - * Sets the {@link io.opentelemetry.api.OpenTelemetry} entrypoint to use. This can be used to - * configure OpenTelemetry by returning the instance created by a - * {@link io.opentelemetry.sdk.OpenTelemetrySdkBuilder}. - */ - public Builder sdk(OpenTelemetry sdk) { - this.openTelemetrySdk = sdk; - return this; - } - - /** - * Returns a new {@link OpenTelemetryModule} built with the configuration of this {@link - * Builder}. - */ - public OpenTelemetryModule build() { - return new OpenTelemetryModule(this); - } - } -} diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryPlugin.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryPlugin.java new file mode 100644 index 00000000000..3705b4b65e1 --- /dev/null +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryPlugin.java @@ -0,0 +1,65 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.opentelemetry; + +import io.grpc.CallOptions; +import io.grpc.Metadata; +import io.opentelemetry.api.common.AttributesBuilder; + +/** + * Injects behavior into {@link GrpcOpenTelemetry}. + */ +interface OpenTelemetryPlugin { + /** + * Limited ability to disable the plugin based on the target. This only has an effect for + * per-call metrics. + * + *

Ideally this method wouldn't exist and it'd be handled by wrapping GrpcOpenTelemetry and + * conditionally delegating to it. But this is needed by CSM until ChannelBuilders have a + * consistent target over their life; currently specifying nameResolverFactory can change the + * target's scheme. + */ + default boolean enablePluginForChannel(String target) { + return true; + } + + ClientCallPlugin newClientCallPlugin(); + + ServerStreamPlugin newServerStreamPlugin(Metadata inboundMetadata); + + interface ClientCallPlugin { + ClientStreamPlugin newClientStreamPlugin(); + + default void addMetadata(Metadata toMetadata) {} + + default CallOptions filterCallOptions(CallOptions options) { + return options; + } + } + + interface ClientStreamPlugin { + default void inboundHeaders(Metadata headers) {} + + default void inboundTrailers(Metadata trailers) {} + + default void addLabels(AttributesBuilder to) {} + } + + interface ServerStreamPlugin { + default void addLabels(AttributesBuilder to) {} + } +} diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryTracingModule.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryTracingModule.java new file mode 100644 index 00000000000..d214e99bd75 --- /dev/null +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryTracingModule.java @@ -0,0 +1,505 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.opentelemetry; + +import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.ClientStreamTracer.NAME_RESOLUTION_DELAYED; +import static io.grpc.internal.GrpcUtil.IMPLEMENTATION_VERSION; +import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.BAGGAGE_KEY; + +import com.google.common.annotations.VisibleForTesting; +import io.grpc.Attributes; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.ClientStreamTracer; +import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; +import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; +import io.grpc.ForwardingServerCallListener; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.ServerStreamTracer; +import io.grpc.internal.GrpcUtil; +import io.grpc.opentelemetry.internal.OpenTelemetryConstants; +import io.opentelemetry.api.OpenTelemetry; +import io.opentelemetry.api.baggage.Baggage; +import io.opentelemetry.api.common.AttributesBuilder; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.StatusCode; +import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.Scope; +import io.opentelemetry.context.propagation.ContextPropagators; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.Nullable; + +/** + * Provides factories for {@link io.grpc.StreamTracer} that records tracing to OpenTelemetry. + */ +final class OpenTelemetryTracingModule { + private static final Logger logger = Logger.getLogger(OpenTelemetryTracingModule.class.getName()); + + @VisibleForTesting + final io.grpc.Context.Key otelSpan = io.grpc.Context.key("opentelemetry-span-key"); + + @Nullable + private static final AtomicIntegerFieldUpdater callEndedUpdater; + @Nullable + private static final AtomicIntegerFieldUpdater streamClosedUpdater; + + /* + * When using Atomic*FieldUpdater, some Samsung Android 5.0.x devices encounter a bug in their JDK + * reflection API that triggers a NoSuchFieldException. When this occurs, we fallback to + * (potentially racy) direct updates of the volatile variables. + */ + static { + AtomicIntegerFieldUpdater tmpCallEndedUpdater; + AtomicIntegerFieldUpdater tmpStreamClosedUpdater; + try { + tmpCallEndedUpdater = + AtomicIntegerFieldUpdater.newUpdater(CallAttemptsTracerFactory.class, "callEnded"); + tmpStreamClosedUpdater = + AtomicIntegerFieldUpdater.newUpdater(ServerTracer.class, "streamClosed"); + } catch (Throwable t) { + logger.log(Level.SEVERE, "Creating atomic field updaters failed", t); + tmpCallEndedUpdater = null; + tmpStreamClosedUpdater = null; + } + callEndedUpdater = tmpCallEndedUpdater; + streamClosedUpdater = tmpStreamClosedUpdater; + } + + private final Tracer otelTracer; + private final ContextPropagators contextPropagators; + private final MetadataGetter metadataGetter = MetadataGetter.getInstance(); + private final MetadataSetter metadataSetter = MetadataSetter.getInstance(); + private final TracingClientInterceptor clientInterceptor = new TracingClientInterceptor(); + private final ServerInterceptor serverSpanPropagationInterceptor = + new TracingServerSpanPropagationInterceptor(); + private final ServerTracerFactory serverTracerFactory = new ServerTracerFactory(); + + OpenTelemetryTracingModule(OpenTelemetry openTelemetry) { + this.otelTracer = checkNotNull(openTelemetry.getTracerProvider(), "tracerProvider") + .tracerBuilder(OpenTelemetryConstants.INSTRUMENTATION_SCOPE) + .setInstrumentationVersion(IMPLEMENTATION_VERSION) + .build(); + this.contextPropagators = checkNotNull(openTelemetry.getPropagators(), "contextPropagators"); + } + + @VisibleForTesting + Tracer getTracer() { + return otelTracer; + } + + /** + * Creates a {@link CallAttemptsTracerFactory} for a new call. + */ + @VisibleForTesting + CallAttemptsTracerFactory newClientCallTracer(Span clientSpan, MethodDescriptor method) { + return new CallAttemptsTracerFactory(clientSpan, method); + } + + /** + * Returns the server tracer factory. + */ + ServerStreamTracer.Factory getServerTracerFactory() { + return serverTracerFactory; + } + + /** + * Returns the client interceptor that facilitates otel tracing reporting. + */ + ClientInterceptor getClientInterceptor() { + return clientInterceptor; + } + + ServerInterceptor getServerSpanPropagationInterceptor() { + return serverSpanPropagationInterceptor; + } + + @VisibleForTesting + final class CallAttemptsTracerFactory extends ClientStreamTracer.Factory { + volatile int callEnded; + private final Span clientSpan; + private final String fullMethodName; + + CallAttemptsTracerFactory(Span clientSpan, MethodDescriptor method) { + checkNotNull(method, "method"); + this.fullMethodName = checkNotNull(method.getFullMethodName(), "fullMethodName"); + this.clientSpan = checkNotNull(clientSpan, "clientSpan"); + } + + @Override + public ClientStreamTracer newClientStreamTracer( + ClientStreamTracer.StreamInfo info, Metadata headers) { + Span attemptSpan = otelTracer.spanBuilder( + "Attempt." + fullMethodName.replace('/', '.')) + .setParent(Context.current().with(clientSpan)) + .startSpan(); + attemptSpan.setAttribute( + "previous-rpc-attempts", info.getPreviousAttempts()); + attemptSpan.setAttribute( + "transparent-retry",info.isTransparentRetry()); + if (info.getCallOptions().getOption(NAME_RESOLUTION_DELAYED) != null) { + clientSpan.addEvent("Delayed name resolution complete"); + } + return new ClientTracer(attemptSpan, clientSpan); + } + + /** + * Record a finished call and mark the current time as the end time. + * + *

Can be called from any thread without synchronization. Calling it the second time or more + * is a no-op. + */ + void callEnded(io.grpc.Status status) { + if (callEndedUpdater != null) { + if (callEndedUpdater.getAndSet(this, 1) != 0) { + return; + } + } else { + if (callEnded != 0) { + return; + } + callEnded = 1; + } + endSpanWithStatus(clientSpan, status); + } + } + + private final class ClientTracer extends ClientStreamTracer { + private final Span span; + private final Span parentSpan; + volatile int seqNo; + boolean isPendingStream; + + ClientTracer(Span span, Span parentSpan) { + this.span = checkNotNull(span, "span"); + this.parentSpan = checkNotNull(parentSpan, "parent span"); + } + + @Override + public void streamCreated(Attributes transportAtts, Metadata headers) { + contextPropagators.getTextMapPropagator().inject(Context.current().with(span), headers, + metadataSetter); + if (isPendingStream) { + span.addEvent("Delayed LB pick complete"); + } + } + + @Override + public void createPendingStream() { + isPendingStream = true; + } + + @Override + public void outboundMessageSent( + int seqNo, long optionalWireSize, long optionalUncompressedSize) { + recordOutboundMessageSentEvent(span, seqNo, optionalWireSize, optionalUncompressedSize); + } + + @Override + public void inboundMessageRead( + int seqNo, long optionalWireSize, long optionalUncompressedSize) { + if (optionalWireSize != optionalUncompressedSize) { + recordInboundCompressedMessage(span, seqNo, optionalWireSize); + } + } + + @Override + public void inboundMessage(int seqNo) { + this.seqNo = seqNo; + } + + @Override + public void inboundUncompressedSize(long bytes) { + recordInboundMessageSize(parentSpan, seqNo, bytes); + } + + @Override + public void streamClosed(io.grpc.Status status) { + endSpanWithStatus(span, status); + } + } + + private final class ServerTracer extends ServerStreamTracer { + private final Span span; + volatile int streamClosed; + private int seqNo; + private Baggage baggage; + + ServerTracer(String fullMethodName, @Nullable Span remoteSpan, Baggage baggage) { + checkNotNull(fullMethodName, "fullMethodName"); + this.span = + otelTracer.spanBuilder(generateTraceSpanName(true, fullMethodName)) + .setParent(remoteSpan == null ? null : Context.current().with(remoteSpan)) + .startSpan(); + this.baggage = baggage; + } + + /** + * Record a finished stream and mark the current time as the end time. + * + *

Can be called from any thread without synchronization. Calling it the second time or more + * is a no-op. + */ + @Override + public void streamClosed(io.grpc.Status status) { + if (streamClosedUpdater != null) { + if (streamClosedUpdater.getAndSet(this, 1) != 0) { + return; + } + } else { + if (streamClosed != 0) { + return; + } + streamClosed = 1; + } + endSpanWithStatus(span, status); + } + + @Override + public io.grpc.Context filterContext(io.grpc.Context context) { + return context + .withValue(otelSpan, span) + .withValue(BAGGAGE_KEY, baggage); + } + + @Override + public void outboundMessageSent( + int seqNo, long optionalWireSize, long optionalUncompressedSize) { + recordOutboundMessageSentEvent(span, seqNo, optionalWireSize, optionalUncompressedSize); + } + + @Override + public void inboundMessageRead( + int seqNo, long optionalWireSize, long optionalUncompressedSize) { + if (optionalWireSize != optionalUncompressedSize) { + recordInboundCompressedMessage(span, seqNo, optionalWireSize); + } + } + + @Override + public void inboundMessage(int seqNo) { + this.seqNo = seqNo; + } + + @Override + public void inboundUncompressedSize(long bytes) { + recordInboundMessageSize(span, seqNo, bytes); + } + } + + @VisibleForTesting + final class ServerTracerFactory extends ServerStreamTracer.Factory { + @SuppressWarnings("ReferenceEquality") + @Override + public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata headers) { + Context context = contextPropagators.getTextMapPropagator().extract( + Context.current(), headers, metadataGetter + ); + Span remoteSpan = Span.fromContext(context); + if (remoteSpan == Span.getInvalid()) { + remoteSpan = null; + } + Baggage baggage = Baggage.fromContext(context); + return new ServerTracer(fullMethodName, remoteSpan, baggage); + } + } + + @VisibleForTesting + final class TracingServerSpanPropagationInterceptor implements ServerInterceptor { + @Override + public ServerCall.Listener interceptCall(ServerCall call, + Metadata headers, ServerCallHandler next) { + Span span = otelSpan.get(io.grpc.Context.current()); + if (span == null) { + logger.log(Level.FINE, "Server span not found. ServerTracerFactory for server " + + "tracing must be set."); + return next.startCall(call, headers); + } + Context serverCallContext = Context.current(); + serverCallContext = serverCallContext.with(span); + Baggage baggage = BAGGAGE_KEY.get(); + if (baggage != null) { + serverCallContext = serverCallContext.with(baggage); + } else { + logger.log(Level.WARNING, "Server baggage not found which is unexpected, " + + "as it is being added unconditionally in filterContext()."); + } + try (Scope scope = serverCallContext.makeCurrent()) { + return new ContextServerCallListener<>(next.startCall(call, headers), serverCallContext); + } + } + } + + private static class ContextServerCallListener extends + ForwardingServerCallListener.SimpleForwardingServerCallListener { + private final Context context; + + protected ContextServerCallListener(ServerCall.Listener delegate, Context context) { + super(delegate); + this.context = checkNotNull(context, "context"); + } + + @Override + public void onMessage(ReqT message) { + try (Scope scope = context.makeCurrent()) { + delegate().onMessage(message); + } + } + + @Override + public void onHalfClose() { + try (Scope scope = context.makeCurrent()) { + delegate().onHalfClose(); + } + } + + @Override + public void onCancel() { + try (Scope scope = context.makeCurrent()) { + delegate().onCancel(); + } + } + + @Override + public void onComplete() { + try (Scope scope = context.makeCurrent()) { + delegate().onComplete(); + } + } + + @Override + public void onReady() { + try (Scope scope = context.makeCurrent()) { + delegate().onReady(); + } + } + } + + @VisibleForTesting + final class TracingClientInterceptor implements ClientInterceptor { + + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + Span clientSpan = otelTracer.spanBuilder( + generateTraceSpanName(false, method.getFullMethodName())) + .startSpan(); + + final CallAttemptsTracerFactory tracerFactory = newClientCallTracer(clientSpan, method); + ClientCall call = + next.newCall( + method, + callOptions.withStreamTracerFactory(tracerFactory)); + return new SimpleForwardingClientCall(call) { + @Override + public void start(Listener responseListener, Metadata headers) { + delegate().start( + new SimpleForwardingClientCallListener(responseListener) { + @Override + public void onClose(io.grpc.Status status, Metadata trailers) { + tracerFactory.callEnded(status); + super.onClose(status, trailers); + } + }, + headers); + } + }; + } + } + + // Attribute named "message-size" always means the message size the application sees. + // If there was compression, additional event reports "message-size-compressed". + // + // An example trace with message compression: + // + // Sending: + // |-- Event 'Outbound message sent', attributes('sequence-numer' = 0, 'message-size' = 7854, + // 'message-size-compressed' = 5493) ----| + // + // Receiving: + // |-- Event 'Inbound compressed message', attributes('sequence-numer' = 0, + // 'message-size-compressed' = 5493 ) ----| + // |-- Event 'Inbound message received', attributes('sequence-numer' = 0, + // 'message-size' = 7854) ----| + // + // An example trace with no message compression: + // + // Sending: + // |-- Event 'Outbound message sent', attributes('sequence-numer' = 0, 'message-size' = 7854) ---| + // + // Receiving: + // |-- Event 'Inbound message received', attributes('sequence-numer' = 0, + // 'message-size' = 7854) ----| + private void recordOutboundMessageSentEvent(Span span, + int seqNo, long optionalWireSize, long optionalUncompressedSize) { + AttributesBuilder attributesBuilder = io.opentelemetry.api.common.Attributes.builder(); + attributesBuilder.put("sequence-number", seqNo); + if (optionalUncompressedSize != -1) { + attributesBuilder.put("message-size", optionalUncompressedSize); + } + if (optionalWireSize != -1 && optionalWireSize != optionalUncompressedSize) { + attributesBuilder.put("message-size-compressed", optionalWireSize); + } + span.addEvent("Outbound message", attributesBuilder.build()); + } + + private void recordInboundCompressedMessage(Span span, int seqNo, long optionalWireSize) { + AttributesBuilder attributesBuilder = io.opentelemetry.api.common.Attributes.builder(); + attributesBuilder.put("sequence-number", seqNo); + attributesBuilder.put("message-size-compressed", optionalWireSize); + span.addEvent("Inbound compressed message", attributesBuilder.build()); + } + + private void recordInboundMessageSize(Span span, int seqNo, long bytes) { + AttributesBuilder attributesBuilder = io.opentelemetry.api.common.Attributes.builder(); + attributesBuilder.put("sequence-number", seqNo); + attributesBuilder.put("message-size", bytes); + span.addEvent("Inbound message", attributesBuilder.build()); + } + + private void endSpanWithStatus(Span span, io.grpc.Status status) { + if (status.isOk()) { + span.setStatus(StatusCode.OK); + } else { + span.setStatus(StatusCode.ERROR, GrpcUtil.statusToPrettyString(status)); + } + span.end(); + } + + /** + * Convert a full method name to a tracing span name. + * + * @param isServer {@code false} if the span is on the client-side, {@code true} if on the + * server-side + * @param fullMethodName the method name as returned by + * {@link MethodDescriptor#getFullMethodName}. + */ + @VisibleForTesting + static String generateTraceSpanName(boolean isServer, String fullMethodName) { + String prefix = isServer ? "Recv" : "Sent"; + return prefix + "." + fullMethodName.replace('/', '.'); + } +} diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/internal/OpenTelemetryConstants.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/internal/OpenTelemetryConstants.java index af84caa8b4f..c09a1a2beca 100644 --- a/opentelemetry/src/main/java/io/grpc/opentelemetry/internal/OpenTelemetryConstants.java +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/internal/OpenTelemetryConstants.java @@ -16,7 +16,11 @@ package io.grpc.opentelemetry.internal; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import io.opentelemetry.api.baggage.Baggage; import io.opentelemetry.api.common.AttributeKey; +import java.util.List; public final class OpenTelemetryConstants { @@ -28,6 +32,45 @@ public final class OpenTelemetryConstants { public static final AttributeKey TARGET_KEY = AttributeKey.stringKey("grpc.target"); + public static final AttributeKey LOCALITY_KEY = + AttributeKey.stringKey("grpc.lb.locality"); + + public static final AttributeKey BACKEND_SERVICE_KEY = + AttributeKey.stringKey("grpc.lb.backend_service"); + + public static final AttributeKey CUSTOM_LABEL_KEY = + AttributeKey.stringKey("grpc.client.call.custom"); + + public static final AttributeKey DISCONNECT_ERROR_KEY = + AttributeKey.stringKey("grpc.disconnect_error"); + + public static final AttributeKey SECURITY_LEVEL_KEY = + AttributeKey.stringKey("grpc.security_level"); + + @VisibleForTesting + public static final io.grpc.Context.Key BAGGAGE_KEY = + io.grpc.Context.key("opentelemetry-baggage-key"); + + public static final List LATENCY_BUCKETS = + ImmutableList.of( + 0d, 0.00001d, 0.00005d, 0.0001d, 0.0003d, 0.0006d, 0.0008d, 0.001d, 0.002d, + 0.003d, 0.004d, 0.005d, 0.006d, 0.008d, 0.01d, 0.013d, 0.016d, 0.02d, + 0.025d, 0.03d, 0.04d, 0.05d, 0.065d, 0.08d, 0.1d, 0.13d, 0.16d, + 0.2d, 0.25d, 0.3d, 0.4d, 0.5d, 0.65d, 0.8d, 1d, 2d, + 5d, 10d, 20d, 50d, 100d); + + public static final List SIZE_BUCKETS = + ImmutableList.of( + 0L, 1024L, 2048L, 4096L, 16384L, 65536L, 262144L, 1048576L, 4194304L, 16777216L, + 67108864L, 268435456L, 1073741824L, 4294967296L); + + public static final List RETRY_BUCKETS = ImmutableList.of(1L, 2L, 3L, 4L, 5L); + + public static final List TRANSPARENT_RETRY_BUCKETS = + ImmutableList.of(1L, 2L, 3L, 4L, 5L, 10L); + + public static final List HEDGE_BUCKETS = ImmutableList.of(1L, 2L, 3L, 4L, 5L); + private OpenTelemetryConstants() { } } diff --git a/opentelemetry/src/test/java/io/grpc/opentelemetry/GrpcOpenTelemetryTest.java b/opentelemetry/src/test/java/io/grpc/opentelemetry/GrpcOpenTelemetryTest.java new file mode 100644 index 00000000000..f0bd6f93098 --- /dev/null +++ b/opentelemetry/src/test/java/io/grpc/opentelemetry/GrpcOpenTelemetryTest.java @@ -0,0 +1,173 @@ +/* + * Copyright 2023 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.opentelemetry; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +import com.google.common.collect.ImmutableList; +import io.grpc.ClientInterceptor; +import io.grpc.ManagedChannelBuilder; +import io.grpc.ServerBuilder; +import io.grpc.internal.GrpcUtil; +import io.grpc.opentelemetry.GrpcOpenTelemetry.TargetFilter; +import io.opentelemetry.api.OpenTelemetry; +import io.opentelemetry.sdk.OpenTelemetrySdk; +import io.opentelemetry.sdk.metrics.SdkMeterProvider; +import io.opentelemetry.sdk.testing.exporter.InMemoryMetricReader; +import io.opentelemetry.sdk.trace.SdkTracerProvider; +import java.util.Arrays; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GrpcOpenTelemetryTest { + private final InMemoryMetricReader inMemoryMetricReader = InMemoryMetricReader.create(); + private final SdkMeterProvider meterProvider = + SdkMeterProvider.builder().registerMetricReader(inMemoryMetricReader).build(); + private final SdkTracerProvider tracerProvider = SdkTracerProvider.builder().build(); + private final OpenTelemetry noopOpenTelemetry = OpenTelemetry.noop(); + private boolean originalEnableOtelTracing; + + @Before + public void setup() { + originalEnableOtelTracing = GrpcOpenTelemetry.ENABLE_OTEL_TRACING; + } + + @After + public void tearDown() { + GrpcOpenTelemetry.ENABLE_OTEL_TRACING = originalEnableOtelTracing; + } + + @Test + public void build() { + OpenTelemetrySdk sdk = + OpenTelemetrySdk.builder().setMeterProvider(meterProvider).build(); + + GrpcOpenTelemetry openTelemetryModule = GrpcOpenTelemetry.newBuilder() + .sdk(sdk) + .addOptionalLabel("version") + .build(); + + assertThat(openTelemetryModule.getOpenTelemetryInstance()).isSameInstanceAs(sdk); + assertThat(openTelemetryModule.getMeterProvider()).isNotNull(); + assertThat(openTelemetryModule.getMeter()).isSameInstanceAs( + meterProvider.meterBuilder("grpc-java") + .setInstrumentationVersion(GrpcUtil.IMPLEMENTATION_VERSION) + .build()); + assertThat(openTelemetryModule.getOptionalLabels()).isEqualTo(ImmutableList.of("version")); + } + + @Test + public void buildTracer() { + OpenTelemetrySdk sdk = + OpenTelemetrySdk.builder().setTracerProvider(tracerProvider).build(); + + GrpcOpenTelemetry grpcOpenTelemetry = GrpcOpenTelemetry.newBuilder() + .enableTracing(true) + .sdk(sdk).build(); + + assertThat(grpcOpenTelemetry.getOpenTelemetryInstance()).isSameInstanceAs(sdk); + assertThat(grpcOpenTelemetry.getTracer()).isSameInstanceAs( + tracerProvider.tracerBuilder("grpc-java") + .setInstrumentationVersion(GrpcUtil.IMPLEMENTATION_VERSION) + .build()); + ServerBuilder mockServerBuiler = mock(ServerBuilder.class); + grpcOpenTelemetry.configureServerBuilder(mockServerBuiler); + verify(mockServerBuiler, times(2)).addStreamTracerFactory(any()); + verify(mockServerBuiler).intercept(any()); + verify(mockServerBuiler).addMetricSink(any()); + verifyNoMoreInteractions(mockServerBuiler); + + ManagedChannelBuilder mockChannelBuilder = mock(ManagedChannelBuilder.class); + grpcOpenTelemetry.configureChannelBuilder(mockChannelBuilder); + verify(mockChannelBuilder).intercept(any(ClientInterceptor.class)); + } + + @Test + public void builderDefaults() { + GrpcOpenTelemetry module = GrpcOpenTelemetry.newBuilder().build(); + + assertThat(module.getOpenTelemetryInstance()).isNotNull(); + assertThat(module.getOpenTelemetryInstance()).isSameInstanceAs(noopOpenTelemetry); + assertThat(module.getMeterProvider()).isNotNull(); + assertThat(module.getMeterProvider()) + .isSameInstanceAs(noopOpenTelemetry.getMeterProvider()); + assertThat(module.getMeter()).isSameInstanceAs(noopOpenTelemetry + .getMeterProvider() + .meterBuilder("grpc-java") + .setInstrumentationVersion(GrpcUtil.IMPLEMENTATION_VERSION) + .build()); + assertThat(module.getEnableMetrics()).isEmpty(); + assertThat(module.getOptionalLabels()).isEmpty(); + + assertThat(module.getTracer()).isSameInstanceAs(noopOpenTelemetry + .getTracerProvider() + .tracerBuilder("grpc-java") + .setInstrumentationVersion(GrpcUtil.IMPLEMENTATION_VERSION) + .build() + ); + } + + @Test + public void builderTargetAttributeFilter() { + GrpcOpenTelemetry module = GrpcOpenTelemetry.newBuilder() + .targetAttributeFilter(t -> t.contains("allowed.com")) + .build(); + + TargetFilter internalFilter = module.getTargetAttributeFilter(); + + assertThat(internalFilter.test("allowed.com")).isTrue(); + assertThat(internalFilter.test("example.com")).isFalse(); + } + + @Test + public void enableDisableMetrics() { + GrpcOpenTelemetry.Builder builder = GrpcOpenTelemetry.newBuilder(); + builder.enableMetrics(Arrays.asList("metric1", "metric4")); + builder.disableMetrics(Arrays.asList("metric2", "metric3")); + + GrpcOpenTelemetry module = builder.build(); + + assertThat(module.getEnableMetrics().get("metric1")).isTrue(); + assertThat(module.getEnableMetrics().get("metric4")).isTrue(); + assertThat(module.getEnableMetrics().get("metric2")).isFalse(); + assertThat(module.getEnableMetrics().get("metric3")).isFalse(); + } + + @Test + public void disableAllMetrics() { + GrpcOpenTelemetry.Builder builder = GrpcOpenTelemetry.newBuilder(); + builder.enableMetrics(Arrays.asList("metric1", "metric4")); + builder.disableMetrics(Arrays.asList("metric2", "metric3")); + builder.disableAllMetrics(); + + GrpcOpenTelemetry module = builder.build(); + + assertThat(module.getEnableMetrics()).isEmpty(); + } + + // TODO(dnvindhya): Add tests for configurator + +} diff --git a/opentelemetry/src/test/java/io/grpc/opentelemetry/GrpcTraceBinContextPropagatorTest.java b/opentelemetry/src/test/java/io/grpc/opentelemetry/GrpcTraceBinContextPropagatorTest.java new file mode 100644 index 00000000000..f85b8067c26 --- /dev/null +++ b/opentelemetry/src/test/java/io/grpc/opentelemetry/GrpcTraceBinContextPropagatorTest.java @@ -0,0 +1,313 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.opentelemetry; + +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.InternalMetadata.BASE64_ENCODING_OMIT_PADDING; +import static org.junit.Assert.assertTrue; + +import com.google.common.collect.ImmutableMap; +import io.grpc.Metadata; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.SpanContext; +import io.opentelemetry.api.trace.TraceFlags; +import io.opentelemetry.api.trace.TraceState; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.propagation.TextMapGetter; +import io.opentelemetry.context.propagation.TextMapSetter; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import javax.annotation.Nullable; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GrpcTraceBinContextPropagatorTest { + private static final String TRACE_ID_BASE16 = "e384981d65129fa3e384981d65129fa3"; + private static final String SPAN_ID_BASE16 = "e384981d65129fa3"; + private static final String TRACE_HEADER_SAMPLED = + "0000" + TRACE_ID_BASE16 + "01" + SPAN_ID_BASE16 + "0201"; + private static final String TRACE_HEADER_NOT_SAMPLED = + "0000" + TRACE_ID_BASE16 + "01" + SPAN_ID_BASE16 + "0200"; + private final String goldenHeaderEncodedSampled = encode(TRACE_HEADER_SAMPLED); + private final String goldenHeaderEncodedNotSampled = encode(TRACE_HEADER_NOT_SAMPLED); + private static final TextMapSetter> setter = Map::put; + private static final TextMapGetter> getter = + new TextMapGetter>() { + @Override + public Iterable keys(Map carrier) { + return carrier.keySet(); + } + + @Nullable + @Override + public String get(Map carrier, String key) { + return carrier.get(key); + } + }; + private final GrpcTraceBinContextPropagator grpcTraceBinContextPropagator = + GrpcTraceBinContextPropagator.defaultInstance(); + + private static Context withSpanContext(SpanContext spanContext, Context context) { + return context.with(Span.wrap(spanContext)); + } + + private static SpanContext getSpanContext(Context context) { + return Span.fromContext(context).getSpanContext(); + } + + @Test + public void inject_map_Nothing() { + Map carrier = new HashMap<>(); + grpcTraceBinContextPropagator.inject(Context.current(), carrier, setter); + assertThat(carrier).hasSize(0); + } + + @Test + public void inject_map_invalidSpan() { + Map carrier = new HashMap<>(); + Context context = withSpanContext(SpanContext.getInvalid(), Context.current()); + grpcTraceBinContextPropagator.inject(context, carrier, setter); + assertThat(carrier).isEmpty(); + } + + @Test + public void inject_map_nullCarrier() { + Map carrier = new HashMap<>(); + Context context = + withSpanContext( + SpanContext.create( + TRACE_ID_BASE16, SPAN_ID_BASE16, TraceFlags.getSampled(), TraceState.getDefault()), + Context.current()); + grpcTraceBinContextPropagator.inject(context, null, + (TextMapSetter>) (ignored, key, value) -> carrier.put(key, value)); + assertThat(carrier) + .containsExactly( + GrpcTraceBinContextPropagator.GRPC_TRACE_BIN_HEADER, goldenHeaderEncodedSampled); + } + + @Test + public void inject_map_nullContext() { + Map carrier = new HashMap<>(); + grpcTraceBinContextPropagator.inject(null, carrier, setter); + assertThat(carrier).isEmpty(); + } + + @Test + public void inject_map_invalidBinaryFormat() { + GrpcTraceBinContextPropagator propagator = new GrpcTraceBinContextPropagator( + new Metadata.BinaryMarshaller() { + @Override + public byte[] toBytes(SpanContext value) { + throw new IllegalArgumentException("failed to byte"); + } + + @Override + public SpanContext parseBytes(byte[] serialized) { + return null; + } + }); + Map carrier = new HashMap<>(); + Context context = + withSpanContext( + SpanContext.create( + TRACE_ID_BASE16, SPAN_ID_BASE16, TraceFlags.getSampled(), TraceState.getDefault()), + Context.current()); + propagator.inject(context, carrier, setter); + assertThat(carrier).hasSize(0); + } + + @Test + public void inject_map_SampledContext() { + verify_inject_map(TraceFlags.getSampled(), goldenHeaderEncodedSampled); + } + + @Test + public void inject_map_NotSampledContext() { + verify_inject_map(TraceFlags.getDefault(), goldenHeaderEncodedNotSampled); + } + + private void verify_inject_map(TraceFlags traceFlags, String goldenHeader) { + Map carrier = new HashMap<>(); + Context context = + withSpanContext( + SpanContext.create( + TRACE_ID_BASE16, SPAN_ID_BASE16, traceFlags, TraceState.getDefault()), + Context.current()); + grpcTraceBinContextPropagator.inject(context, carrier, setter); + assertThat(carrier) + .containsExactly( + GrpcTraceBinContextPropagator.GRPC_TRACE_BIN_HEADER, goldenHeader); + } + + @Test + public void extract_map_nothing() { + Map carrier = new HashMap<>(); + assertThat(grpcTraceBinContextPropagator.extract(Context.current(), carrier, getter)) + .isSameInstanceAs(Context.current()); + } + + @Test + public void extract_map_SampledContext() { + verify_extract_map(TraceFlags.getSampled(), goldenHeaderEncodedSampled); + } + + @Test + public void extract_map_NotSampledContext() { + verify_extract_map(TraceFlags.getDefault(), goldenHeaderEncodedNotSampled); + } + + private void verify_extract_map(TraceFlags traceFlags, String goldenHeader) { + Map carrier = ImmutableMap.of( + GrpcTraceBinContextPropagator.GRPC_TRACE_BIN_HEADER, goldenHeader); + Context result = grpcTraceBinContextPropagator.extract(Context.current(), carrier, getter); + assertThat(getSpanContext(result)).isEqualTo(SpanContext.create( + TRACE_ID_BASE16, SPAN_ID_BASE16, traceFlags, TraceState.getDefault())); + } + + @Test + public void inject_metadata_Nothing() { + Metadata carrier = new Metadata(); + grpcTraceBinContextPropagator.inject(Context.current(), carrier, MetadataSetter.getInstance()); + assertThat(carrier.keys()).isEmpty(); + } + + @Test + public void inject_metadata_nullCarrier() { + Context context = + withSpanContext( + SpanContext.create( + TRACE_ID_BASE16, SPAN_ID_BASE16, TraceFlags.getSampled(), TraceState.getDefault()), + Context.current()); + grpcTraceBinContextPropagator.inject(context, null, MetadataSetter.getInstance()); + } + + @Test + public void inject_metadata_invalidSpan() { + Metadata carrier = new Metadata(); + Context context = withSpanContext(SpanContext.getInvalid(), Context.current()); + grpcTraceBinContextPropagator.inject(context, carrier, MetadataSetter.getInstance()); + assertThat(carrier.keys()).isEmpty(); + } + + @Test + public void inject_metadata_SampledContext() { + verify_inject_metadata(TraceFlags.getSampled(), hexStringToByteArray(TRACE_HEADER_SAMPLED)); + } + + @Test + public void inject_metadataSetter_NotSampledContext() { + verify_inject_metadata(TraceFlags.getDefault(), hexStringToByteArray(TRACE_HEADER_NOT_SAMPLED)); + } + + private void verify_inject_metadata(TraceFlags traceFlags, byte[] bytes) { + Metadata metadata = new Metadata(); + Context context = + withSpanContext( + SpanContext.create( + TRACE_ID_BASE16, SPAN_ID_BASE16, traceFlags, TraceState.getDefault()), + Context.current()); + grpcTraceBinContextPropagator.inject(context, metadata, MetadataSetter.getInstance()); + byte[] injected = metadata.get(Metadata.Key.of( + GrpcTraceBinContextPropagator.GRPC_TRACE_BIN_HEADER, Metadata.BINARY_BYTE_MARSHALLER)); + assertTrue(Arrays.equals(injected, bytes)); + } + + @Test + public void extract_metadata_nothing() { + assertThat(grpcTraceBinContextPropagator.extract( + Context.current(), new Metadata(), MetadataGetter.getInstance())) + .isSameInstanceAs(Context.current()); + } + + @Test + public void extract_metadata_nullCarrier() { + assertThat(grpcTraceBinContextPropagator.extract( + Context.current(), null, MetadataGetter.getInstance())) + .isSameInstanceAs(Context.current()); + } + + @Test + public void extract_metadata_SampledContext() { + verify_extract_metadata(TraceFlags.getSampled(), TRACE_HEADER_SAMPLED); + } + + @Test + public void extract_metadataGetter_NotSampledContext() { + verify_extract_metadata(TraceFlags.getDefault(), TRACE_HEADER_NOT_SAMPLED); + } + + private void verify_extract_metadata(TraceFlags traceFlags, String hex) { + Metadata carrier = new Metadata(); + carrier.put(Metadata.Key.of( + GrpcTraceBinContextPropagator.GRPC_TRACE_BIN_HEADER, Metadata.BINARY_BYTE_MARSHALLER), + hexStringToByteArray(hex)); + Context result = grpcTraceBinContextPropagator.extract(Context.current(), carrier, + MetadataGetter.getInstance()); + assertThat(getSpanContext(result)).isEqualTo(SpanContext.create( + TRACE_ID_BASE16, SPAN_ID_BASE16, traceFlags, TraceState.getDefault())); + } + + @Test + public void extract_metadata_invalidBinaryFormat() { + GrpcTraceBinContextPropagator propagator = new GrpcTraceBinContextPropagator( + new Metadata.BinaryMarshaller() { + @Override + public byte[] toBytes(SpanContext value) { + return new byte[0]; + } + + @Override + public SpanContext parseBytes(byte[] serialized) { + throw new IllegalArgumentException("failed to byte"); + } + }); + Metadata carrier = new Metadata(); + carrier.put(Metadata.Key.of( + GrpcTraceBinContextPropagator.GRPC_TRACE_BIN_HEADER, Metadata.BINARY_BYTE_MARSHALLER), + hexStringToByteArray(TRACE_HEADER_SAMPLED)); + assertThat(propagator.extract(Context.current(), carrier, MetadataGetter.getInstance())) + .isSameInstanceAs(Context.current()); + } + + @Test + public void extract_metadata_invalidBinaryFormatVersion() { + Metadata carrier = new Metadata(); + carrier.put(Metadata.Key.of( + GrpcTraceBinContextPropagator.GRPC_TRACE_BIN_HEADER, Metadata.BINARY_BYTE_MARSHALLER), + hexStringToByteArray("0100" + TRACE_ID_BASE16 + "01" + SPAN_ID_BASE16 + "0201")); + assertThat(grpcTraceBinContextPropagator.extract( + Context.current(), carrier, MetadataGetter.getInstance())) + .isSameInstanceAs(Context.current()); + } + + private static String encode(String hex) { + return BASE64_ENCODING_OMIT_PADDING.encode(hexStringToByteArray(hex)); + } + + private static byte[] hexStringToByteArray(String s) { + int len = s.length(); + byte[] data = new byte[len / 2]; + for (int i = 0; i < len; i += 2) { + data[i / 2] = (byte) ((Character.digit(s.charAt(i), 16) << 4) + + Character.digit(s.charAt(i + 1), 16)); + } + return data; + } +} diff --git a/opentelemetry/src/test/java/io/grpc/opentelemetry/MetadataGetterTest.java b/opentelemetry/src/test/java/io/grpc/opentelemetry/MetadataGetterTest.java new file mode 100644 index 00000000000..5934240e5c2 --- /dev/null +++ b/opentelemetry/src/test/java/io/grpc/opentelemetry/MetadataGetterTest.java @@ -0,0 +1,96 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.opentelemetry; + +import static io.grpc.InternalMetadata.BASE64_ENCODING_OMIT_PADDING; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import io.grpc.Metadata; +import java.nio.charset.Charset; +import java.util.Iterator; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class MetadataGetterTest { + private final MetadataGetter metadataGetter = MetadataGetter.getInstance(); + + @Test + public void getBinaryGrpcTraceBin() { + Metadata metadata = new Metadata(); + byte[] b = "generated".getBytes(Charset.defaultCharset()); + Metadata.Key grpc_trace_bin_key = + Metadata.Key.of("grpc-trace-bin", Metadata.BINARY_BYTE_MARSHALLER); + metadata.put(grpc_trace_bin_key, b); + assertArrayEquals(b, metadataGetter.getBinary(metadata, "grpc-trace-bin")); + } + + @Test + public void getBinaryEmptyMetadata() { + assertNull(metadataGetter.getBinary(new Metadata(), "grpc-trace-bin")); + } + + @Test + public void getBinaryNotGrpcTraceBin() { + Metadata metadata = new Metadata(); + byte[] b = "generated".getBytes(Charset.defaultCharset()); + Metadata.Key grpc_trace_bin_key = + Metadata.Key.of("another-bin", Metadata.BINARY_BYTE_MARSHALLER); + metadata.put(grpc_trace_bin_key, b); + assertNull(metadataGetter.getBinary(metadata, "another-bin")); + } + + @Test + public void getTextEmptyMetadata() { + assertNull(metadataGetter.get(new Metadata(), "a-key")); + } + + @Test + public void getTextBinHeader() { + assertNull(metadataGetter.get(new Metadata(), "a-key-bin")); + } + + @Test + public void getTestGrpcTraceBin() { + Metadata metadata = new Metadata(); + byte[] b = "generated".getBytes(Charset.defaultCharset()); + Metadata.Key grpc_trace_bin_key = + Metadata.Key.of("grpc-trace-bin", Metadata.BINARY_BYTE_MARSHALLER); + metadata.put(grpc_trace_bin_key, b); + assertEquals(BASE64_ENCODING_OMIT_PADDING.encode(b), + metadataGetter.get(metadata, "grpc-trace-bin")); + } + + @Test + public void getText() { + Metadata metadata = new Metadata(); + Metadata.Key other_key = + Metadata.Key.of("other", Metadata.ASCII_STRING_MARSHALLER); + metadata.put(other_key, "header-value"); + assertEquals("header-value", metadataGetter.get(metadata, "other")); + + Iterator iterator = metadataGetter.keys(metadata).iterator(); + assertTrue(iterator.hasNext()); + assertEquals("other", iterator.next()); + assertFalse(iterator.hasNext()); + } +} diff --git a/opentelemetry/src/test/java/io/grpc/opentelemetry/MetadataSetterTest.java b/opentelemetry/src/test/java/io/grpc/opentelemetry/MetadataSetterTest.java new file mode 100644 index 00000000000..fcd85480bb9 --- /dev/null +++ b/opentelemetry/src/test/java/io/grpc/opentelemetry/MetadataSetterTest.java @@ -0,0 +1,83 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.opentelemetry; + +import static io.grpc.InternalMetadata.BASE64_ENCODING_OMIT_PADDING; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +import io.grpc.Metadata; +import java.nio.charset.Charset; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class MetadataSetterTest { + private final MetadataSetter metadataSetter = MetadataSetter.getInstance(); + + @Test + public void setGrpcTraceBin() { + Metadata metadata = new Metadata(); + byte[] b = "generated".getBytes(Charset.defaultCharset()); + Metadata.Key grpc_trace_bin_key = + Metadata.Key.of("grpc-trace-bin", Metadata.BINARY_BYTE_MARSHALLER); + metadataSetter.set(metadata, "grpc-trace-bin", b); + assertArrayEquals(b, metadata.get(grpc_trace_bin_key)); + } + + @Test + public void setOtherBinaryKey() { + Metadata metadata = new Metadata(); + byte[] b = "generated".getBytes(Charset.defaultCharset()); + Metadata.Key other_key = + Metadata.Key.of("for-test-bin", Metadata.BINARY_BYTE_MARSHALLER); + metadataSetter.set(metadata, other_key.name(), b); + assertNull(metadata.get(other_key)); + } + + @Test + public void setText() { + Metadata metadata = new Metadata(); + String v = "generated"; + Metadata.Key textKey = + Metadata.Key.of("text-key", Metadata.ASCII_STRING_MARSHALLER); + metadataSetter.set(metadata, textKey.name(), v); + assertEquals(metadata.get(textKey), v); + } + + @Test + public void setTextBin() { + Metadata metadata = new Metadata(); + Metadata.Key other_key = + Metadata.Key.of("for-test-bin", Metadata.BINARY_BYTE_MARSHALLER); + metadataSetter.set(metadata, other_key.name(), "generated"); + assertNull(metadata.get(other_key)); + } + + @Test + public void setTextGrpcTraceBin() { + Metadata metadata = new Metadata(); + byte[] b = "generated".getBytes(Charset.defaultCharset()); + metadataSetter.set(metadata, "grpc-trace-bin", BASE64_ENCODING_OMIT_PADDING.encode(b)); + + Metadata.Key grpc_trace_bin_key = + Metadata.Key.of("grpc-trace-bin", Metadata.BINARY_BYTE_MARSHALLER); + assertArrayEquals(metadata.get(grpc_trace_bin_key), b); + } +} diff --git a/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryMetricSinkTest.java b/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryMetricSinkTest.java new file mode 100644 index 00000000000..cced4de3cb4 --- /dev/null +++ b/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryMetricSinkTest.java @@ -0,0 +1,480 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.opentelemetry; + +import static io.opentelemetry.sdk.testing.assertj.OpenTelemetryAssertions.assertThat; + +import com.google.common.collect.ImmutableList; +import io.grpc.DoubleCounterMetricInstrument; +import io.grpc.DoubleHistogramMetricInstrument; +import io.grpc.LongCounterMetricInstrument; +import io.grpc.LongGaugeMetricInstrument; +import io.grpc.LongHistogramMetricInstrument; +import io.grpc.LongUpDownCounterMetricInstrument; +import io.grpc.MetricInstrument; +import io.grpc.MetricSink; +import io.grpc.opentelemetry.internal.OpenTelemetryConstants; +import io.opentelemetry.api.common.AttributeKey; +import io.opentelemetry.api.metrics.DoubleCounter; +import io.opentelemetry.api.metrics.DoubleHistogram; +import io.opentelemetry.api.metrics.LongCounter; +import io.opentelemetry.api.metrics.Meter; +import io.opentelemetry.sdk.common.InstrumentationScopeInfo; +import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class OpenTelemetryMetricSinkTest { + + @Rule + public final OpenTelemetryRule openTelemetryTesting = OpenTelemetryRule.create(); + + private final Meter testMeter = openTelemetryTesting.getOpenTelemetry() + .getMeter(OpenTelemetryConstants.INSTRUMENTATION_SCOPE); + + private OpenTelemetryMetricSink sink; + + @Test + public void updateMeasures_enabledMetrics() { + Map enabledMetrics = new HashMap<>(); + enabledMetrics.put("client_calls_started", true); + enabledMetrics.put("server_calls_started", true); + + List optionalLabels = Arrays.asList("status"); + + List instruments = Arrays.asList( + new DoubleCounterMetricInstrument(0, "client_calls_started", + "Number of client calls started", "count", Collections.emptyList(), + Collections.emptyList(), + true), + new LongCounterMetricInstrument(1, "server_calls_started", "Number of server calls started", + "count", Collections.emptyList(), Collections.emptyList(), false), + new DoubleHistogramMetricInstrument(2, "client_message_size", "Sent message size", "bytes", + Collections.emptyList(), + Collections.emptyList(), Collections.emptyList(), true) + ); + + // Create sink + sink = new OpenTelemetryMetricSink(testMeter, enabledMetrics, false, optionalLabels); + + // Invoke updateMeasures + sink.updateMeasures(instruments); + + com.google.common.truth.Truth.assertThat(sink.getMeasuresSize()).isEqualTo(3); + // Metric is explicitly enabled for sink + com.google.common.truth.Truth.assertThat(sink.getMeasures().get(0).getMeasure()) + .isInstanceOf(DoubleCounter.class); + // Metric is explicitly enabled for sink + com.google.common.truth.Truth.assertThat(sink.getMeasures().get(1).getMeasure()) + .isInstanceOf(LongCounter.class); + // Metric is enabled by default + com.google.common.truth.Truth.assertThat(sink.getMeasures().get(2).getMeasure()) + .isInstanceOf(DoubleHistogram.class); + + } + + @Test + public void updateMeasure_disabledMetrics() { + Map enabledMetrics = new HashMap<>(); + enabledMetrics.put("client_calls_started", false); + enabledMetrics.put("server_calls_started", false); + + List optionalLabels = Arrays.asList("status"); + + List instruments = Arrays.asList( + new DoubleCounterMetricInstrument(0, "client_calls_started", + "Number of client calls started", "count", Collections.emptyList(), + Collections.emptyList(), true), + new LongCounterMetricInstrument(1, "server_calls_started", "Number of server calls started", + "count", Collections.emptyList(), Collections.emptyList(), true), + new DoubleHistogramMetricInstrument(2, "client_message_size", "Sent message size", "bytes", + Collections.emptyList(), + Collections.emptyList(), Collections.emptyList(), true) + ); + + // Create sink + sink = new OpenTelemetryMetricSink(testMeter, enabledMetrics, true, optionalLabels); + + // Invoke updateMeasures + sink.updateMeasures(instruments); + + com.google.common.truth.Truth.assertThat(sink.getMeasuresSize()).isEqualTo(3); + // Metric is explicitly disabled + com.google.common.truth.Truth.assertThat(sink.getMeasures().get(0)).isNull(); + // Metric is explicitly disabled + com.google.common.truth.Truth.assertThat(sink.getMeasures().get(1)).isNull(); + // Metric is enabled by default, but all default metrics are disabled + com.google.common.truth.Truth.assertThat(sink.getMeasures().get(2)).isNull(); + + } + + @Test + public void addCounter_enabledMetric() { + // set up sink with disabled metric + Map enabledMetrics = new HashMap<>(); + enabledMetrics.put("client_latency", true); + + LongCounterMetricInstrument longCounterInstrument = + new LongCounterMetricInstrument(0, "client_latency", "Client latency", "s", + Collections.emptyList(), + Collections.emptyList(), false); + DoubleCounterMetricInstrument doubleCounterInstrument = + new DoubleCounterMetricInstrument(1, "client_calls_started", + "Number of client calls started", "count", Collections.emptyList(), + Collections.emptyList(), + true); + LongUpDownCounterMetricInstrument longUpDownCounterInstrument = + new LongUpDownCounterMetricInstrument(2, "active_carrier_pigeons", + "Active Carrier Pigeons", "pigeons", + Collections.emptyList(), + Collections.emptyList(), true); + + // Create sink + sink = new OpenTelemetryMetricSink(testMeter, enabledMetrics, false, Collections.emptyList()); + + // Invoke updateMeasures + sink.updateMeasures(Arrays.asList(longCounterInstrument, doubleCounterInstrument, + longUpDownCounterInstrument)); + + sink.addLongCounter(longCounterInstrument, 123L, Collections.emptyList(), + Collections.emptyList()); + sink.addDoubleCounter(doubleCounterInstrument, 12.0, Collections.emptyList(), + Collections.emptyList()); + sink.addLongUpDownCounter(longUpDownCounterInstrument, -3L, Collections.emptyList(), + Collections.emptyList()); + + assertThat(openTelemetryTesting.getMetrics()) + .satisfiesExactlyInAnyOrder( + metric -> + assertThat(metric) + .hasInstrumentationScope(InstrumentationScopeInfo.create( + OpenTelemetryConstants.INSTRUMENTATION_SCOPE)) + .hasName("client_latency") + .hasDescription("Client latency") + .hasUnit("s") + .hasLongSumSatisfying( + longSum -> + longSum + .hasPointsSatisfying( + point -> + point + .hasValue(123L))), + metric -> + assertThat(metric) + .hasInstrumentationScope(InstrumentationScopeInfo.create( + OpenTelemetryConstants.INSTRUMENTATION_SCOPE)) + .hasName("client_calls_started") + .hasDescription("Number of client calls started") + .hasUnit("count") + .hasDoubleSumSatisfying( + doubleSum -> + doubleSum + .hasPointsSatisfying( + point -> + point + .hasValue(12.0D))), + metric -> + assertThat(metric) + .hasInstrumentationScope(InstrumentationScopeInfo.create( + OpenTelemetryConstants.INSTRUMENTATION_SCOPE)) + .hasName("active_carrier_pigeons") + .hasDescription("Active Carrier Pigeons") + .hasUnit("pigeons") + .hasLongSumSatisfying( + longSum -> + longSum + .hasPointsSatisfying( + point -> + point + .hasValue(-3L)))); + } + + @Test + public void addCounter_disabledMetric() { + // set up sink with disabled metric + Map enabledMetrics = new HashMap<>(); + enabledMetrics.put("client_latency", false); + enabledMetrics.put("active_carrier_pigeons", false); + + LongCounterMetricInstrument instrument = + new LongCounterMetricInstrument(0, "client_latency", "Client latency", "s", + Collections.emptyList(), + Collections.emptyList(), true); + LongUpDownCounterMetricInstrument longUpDownCounterInstrument = + new LongUpDownCounterMetricInstrument(1, "active_carrier_pigeons", + "Active Carrier Pigeons", "pigeons", + Collections.emptyList(), + Collections.emptyList(), false); + + // Create sink + sink = new OpenTelemetryMetricSink(testMeter, enabledMetrics, true, Collections.emptyList()); + + // Invoke updateMeasures + sink.updateMeasures(Arrays.asList(instrument, longUpDownCounterInstrument)); + + sink.addLongCounter(instrument, 123L, Collections.emptyList(), Collections.emptyList()); + sink.addLongUpDownCounter(longUpDownCounterInstrument, -13L, Collections.emptyList(), + Collections.emptyList()); + + assertThat(openTelemetryTesting.getMetrics()).isEmpty(); + } + + @Test + public void addHistogram_enabledMetric() { + // set up sink with disabled metric + Map enabledMetrics = new HashMap<>(); + enabledMetrics.put("client_message_size", true); + enabledMetrics.put("server_message_size", true); + + DoubleHistogramMetricInstrument doubleHistogramInstrument = + new DoubleHistogramMetricInstrument(0, "client_message_size", "Sent message size", "bytes", + Collections.emptyList(), + Collections.emptyList(), Collections.emptyList(), false); + LongHistogramMetricInstrument longHistogramInstrument = + new LongHistogramMetricInstrument(1, "server_message_size", "Received message size", + "bytes", + Collections.emptyList(), + Collections.emptyList(), Collections.emptyList(), true); + + // Create sink + sink = new OpenTelemetryMetricSink(testMeter, enabledMetrics, false, Collections.emptyList()); + + // Invoke updateMeasures + sink.updateMeasures(Arrays.asList(doubleHistogramInstrument, longHistogramInstrument)); + + sink.recordDoubleHistogram(doubleHistogramInstrument, 12.0, Collections.emptyList(), + Collections.emptyList()); + sink.recordLongHistogram(longHistogramInstrument, 123L, Collections.emptyList(), + Collections.emptyList()); + + assertThat(openTelemetryTesting.getMetrics()) + .satisfiesExactlyInAnyOrder( + metric -> + assertThat(metric) + .hasInstrumentationScope(InstrumentationScopeInfo.create( + OpenTelemetryConstants.INSTRUMENTATION_SCOPE)) + .hasName("client_message_size") + .hasDescription("Sent message size") + .hasUnit("bytes") + .hasHistogramSatisfying( + histogram -> + histogram.hasPointsSatisfying( + point -> + point + .hasCount(1) + .hasSum(12.0))), + + metric -> + assertThat(metric) + .hasInstrumentationScope(InstrumentationScopeInfo.create( + OpenTelemetryConstants.INSTRUMENTATION_SCOPE)) + .hasName("server_message_size") + .hasDescription("Received message size") + .hasUnit("bytes") + .hasHistogramSatisfying( + histogram -> + histogram.hasPointsSatisfying( + point -> + point + .hasCount(1) + .hasSum(123L)))); + } + + @Test + public void addHistogram_disabledMetric() { + // set up sink with disabled metric + Map enabledMetrics = new HashMap<>(); + enabledMetrics.put("client_message_size", false); + enabledMetrics.put("server_message_size", false); + + DoubleHistogramMetricInstrument doubleHistogramInstrument = + new DoubleHistogramMetricInstrument(0, "client_message_size", "Sent message size", "bytes", + Collections.emptyList(), + Collections.emptyList(), Collections.emptyList(), false); + LongHistogramMetricInstrument longHistogramInstrument = + new LongHistogramMetricInstrument(1, "server_message_size", "Received message size", + "bytes", + Collections.emptyList(), + Collections.emptyList(), Collections.emptyList(), true); + + // Create sink + sink = new OpenTelemetryMetricSink(testMeter, enabledMetrics, false, Collections.emptyList()); + + // Invoke updateMeasures + sink.updateMeasures(Arrays.asList(doubleHistogramInstrument, longHistogramInstrument)); + + sink.recordDoubleHistogram(doubleHistogramInstrument, 12.0, Collections.emptyList(), + Collections.emptyList()); + sink.recordLongHistogram(longHistogramInstrument, 123L, Collections.emptyList(), + Collections.emptyList()); + + assertThat(openTelemetryTesting.getMetrics()).isEmpty(); + } + + @Test + public void registerBatchCallback_allDisabled() { + // set up sink with disabled metric + Map enabledMetrics = new HashMap<>(); + + LongGaugeMetricInstrument longGaugeInstrumentDisabled = + new LongGaugeMetricInstrument(0, "disk", "Amount of disk used", "By", + Collections.emptyList(), Collections.emptyList(), false); + + // Create sink + sink = new OpenTelemetryMetricSink(testMeter, enabledMetrics, false, Collections.emptyList()); + + // Invoke updateMeasures + sink.updateMeasures(Arrays.asList(longGaugeInstrumentDisabled)); + + MetricSink.Registration registration = sink.registerBatchCallback(() -> { + sink.recordLongGauge( + longGaugeInstrumentDisabled, 999, Collections.emptyList(), Collections.emptyList()); + }, longGaugeInstrumentDisabled); + + assertThat(openTelemetryTesting.getMetrics()) + .satisfiesExactlyInAnyOrder(); + registration.close(); + } + + @Test + public void registerBatchCallback_bothEnabledAndDisabled() { + // set up sink with disabled metric + Map enabledMetrics = new HashMap<>(); + enabledMetrics.put("memory", true); + + LongGaugeMetricInstrument longGaugeInstrumentEnabled = + new LongGaugeMetricInstrument(0, "memory", "Amount of memory used", "By", + Collections.emptyList(), Collections.emptyList(), false); + LongGaugeMetricInstrument longGaugeInstrumentDisabled = + new LongGaugeMetricInstrument(1, "disk", "Amount of disk used", "By", + Collections.emptyList(), Collections.emptyList(), false); + + // Create sink + sink = new OpenTelemetryMetricSink(testMeter, enabledMetrics, false, Collections.emptyList()); + + // Invoke updateMeasures + sink.updateMeasures(Arrays.asList(longGaugeInstrumentEnabled, longGaugeInstrumentDisabled)); + + MetricSink.Registration registration = sink.registerBatchCallback(() -> { + sink.recordLongGauge( + longGaugeInstrumentEnabled, 99, Collections.emptyList(), Collections.emptyList()); + sink.recordLongGauge( + longGaugeInstrumentDisabled, 999, Collections.emptyList(), Collections.emptyList()); + }, longGaugeInstrumentEnabled, longGaugeInstrumentDisabled); + + assertThat(openTelemetryTesting.getMetrics()) + .satisfiesExactlyInAnyOrder( + metric -> + assertThat(metric) + .hasInstrumentationScope(InstrumentationScopeInfo.create( + OpenTelemetryConstants.INSTRUMENTATION_SCOPE)) + .hasName("memory") + .hasDescription("Amount of memory used") + .hasUnit("By") + .hasLongGaugeSatisfying( + gauge -> + gauge.hasPointsSatisfying( + point -> + point + .hasValue(99)))); + + // Gauge goes away after close + registration.close(); + assertThat(openTelemetryTesting.getMetrics()) + .satisfiesExactlyInAnyOrder(); + } + + @Test + public void recordLabels() { + Map enabledMetrics = new HashMap<>(); + enabledMetrics.put("client_latency", true); + enabledMetrics.put("ghosts_in_the_wire", true); + + List optionalLabels = Arrays.asList("optional_label_key_2"); + + LongCounterMetricInstrument longCounterInstrument = + new LongCounterMetricInstrument(0, "client_latency", "Client latency", "s", + ImmutableList.of("required_label_key_1", "required_label_key_2"), + ImmutableList.of("optional_label_key_1", "optional_label_key_2"), false); + LongUpDownCounterMetricInstrument longUpDownCounterInstrument = + new LongUpDownCounterMetricInstrument(1, "ghosts_in_the_wire", + "Number of Ghosts Haunting the Wire", "{ghosts}", + ImmutableList.of("required_label_key_1", "required_label_key_2"), + ImmutableList.of("optional_label_key_1", "optional_label_key_2"), false); + + // Create sink + sink = new OpenTelemetryMetricSink(testMeter, enabledMetrics, false, optionalLabels); + + // Invoke updateMeasures + sink.updateMeasures(Arrays.asList(longCounterInstrument, longUpDownCounterInstrument)); + + sink.addLongCounter(longCounterInstrument, 123L, + ImmutableList.of("required_label_value_1", "required_label_value_2"), + ImmutableList.of("optional_label_value_1", "optional_label_value_2")); + sink.addLongUpDownCounter(longUpDownCounterInstrument, -400L, + ImmutableList.of("required_label_value_1", "required_label_value_2"), + ImmutableList.of("optional_label_value_1", "optional_label_value_2")); + + io.opentelemetry.api.common.Attributes expectedAtrributes + = io.opentelemetry.api.common.Attributes.of( + AttributeKey.stringKey("required_label_key_1"), "required_label_value_1", + AttributeKey.stringKey("required_label_key_2"), "required_label_value_2", + AttributeKey.stringKey("optional_label_key_2"), "optional_label_value_2"); + + assertThat(openTelemetryTesting.getMetrics()) + .satisfiesExactlyInAnyOrder( + metric -> + assertThat(metric) + .hasInstrumentationScope(InstrumentationScopeInfo.create( + OpenTelemetryConstants.INSTRUMENTATION_SCOPE)) + .hasName("client_latency") + .hasDescription("Client latency") + .hasUnit("s") + .hasLongSumSatisfying( + longSum -> + longSum + .hasPointsSatisfying( + point -> + point + .hasAttributes(expectedAtrributes) + .hasValue(123L))), + metric -> + assertThat(metric) + .hasInstrumentationScope(InstrumentationScopeInfo.create( + OpenTelemetryConstants.INSTRUMENTATION_SCOPE)) + .hasName("ghosts_in_the_wire") + .hasDescription("Number of Ghosts Haunting the Wire") + .hasUnit("{ghosts}") + .hasLongSumSatisfying( + longSum -> + longSum + .hasPointsSatisfying( + point -> + point + .hasAttributes(expectedAtrributes) + .hasValue(-400L)))); + + } +} diff --git a/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryMetricsModuleTest.java b/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryMetricsModuleTest.java index 5217d66db20..7c9db875196 100644 --- a/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryMetricsModuleTest.java +++ b/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryMetricsModuleTest.java @@ -17,15 +17,22 @@ package io.grpc.opentelemetry; import static io.grpc.ClientStreamTracer.NAME_RESOLUTION_DELAYED; +import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.LOCALITY_KEY; import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.METHOD_KEY; import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.STATUS_KEY; +import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.TARGET_KEY; import static io.opentelemetry.sdk.testing.assertj.OpenTelemetryAssertions.assertThat; +import static java.util.Collections.emptyList; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyDouble; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import com.google.common.collect.ImmutableMap; import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.Channel; @@ -33,26 +40,53 @@ import io.grpc.ClientInterceptor; import io.grpc.ClientInterceptors; import io.grpc.ClientStreamTracer; +import io.grpc.Grpc; +import io.grpc.KnownLength; +import io.grpc.ManagedChannel; import io.grpc.Metadata; import io.grpc.MethodDescriptor; +import io.grpc.Server; import io.grpc.ServerCall; import io.grpc.ServerCallHandler; import io.grpc.ServerServiceDefinition; import io.grpc.ServerStreamTracer; import io.grpc.ServerStreamTracer.ServerCallInfo; +import io.grpc.ServiceDescriptor; import io.grpc.Status; import io.grpc.Status.Code; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; import io.grpc.internal.FakeClock; +import io.grpc.internal.StatsTraceContext.ServerCallMethodListener; +import io.grpc.opentelemetry.GrpcOpenTelemetry.TargetFilter; import io.grpc.opentelemetry.OpenTelemetryMetricsModule.CallAttemptsTracerFactory; import io.grpc.opentelemetry.internal.OpenTelemetryConstants; +import io.grpc.stub.ClientCalls; +import io.grpc.testing.GrpcCleanupRule; import io.grpc.testing.GrpcServerRule; +import io.opentelemetry.api.OpenTelemetry; +import io.opentelemetry.api.baggage.Baggage; +import io.opentelemetry.api.baggage.propagation.W3CBaggagePropagator; +import io.opentelemetry.api.common.AttributeKey; +import io.opentelemetry.api.metrics.DoubleHistogram; import io.opentelemetry.api.metrics.Meter; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.Scope; +import io.opentelemetry.context.propagation.ContextPropagators; +import io.opentelemetry.sdk.OpenTelemetrySdk; import io.opentelemetry.sdk.common.InstrumentationScopeInfo; +import io.opentelemetry.sdk.metrics.data.MetricData; import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; +import java.io.IOException; import java.io.InputStream; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Optional; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import javax.annotation.Nullable; +import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -73,10 +107,9 @@ public class OpenTelemetryMetricsModuleTest { private static final CallOptions.Key CUSTOM_OPTION = CallOptions.Key.createWithDefault("option1", "default"); private static final CallOptions CALL_OPTIONS = - CallOptions.DEFAULT.withOption(CUSTOM_OPTION, "customvalue"); + CallOptions.DEFAULT.withOption(NAME_RESOLUTION_DELAYED, 10L); private static final ClientStreamTracer.StreamInfo STREAM_INFO = - ClientStreamTracer.StreamInfo.newBuilder() - .setCallOptions(CallOptions.DEFAULT.withOption(NAME_RESOLUTION_DELAYED, 10L)).build(); + ClientStreamTracer.StreamInfo.newBuilder().setCallOptions(CALL_OPTIONS).build(); private static final String CLIENT_ATTEMPT_COUNT_INSTRUMENT_NAME = "grpc.client.attempt.started"; private static final String CLIENT_ATTEMPT_DURATION_INSTRUMENT_NAME = "grpc.client.attempt.duration"; @@ -85,14 +118,28 @@ public class OpenTelemetryMetricsModuleTest { private static final String CLIENT_ATTEMPT_RECV_TOTAL_COMPRESSED_MESSAGE_SIZE = "grpc.client.attempt.rcvd_total_compressed_message_size"; private static final String CLIENT_CALL_DURATION = "grpc.client.call.duration"; + private static final String CLIENT_CALL_RETRIES = "grpc.client.call.retries"; + private static final String CLIENT_CALL_TRANSPARENT_RETRIES = + "grpc.client.call.transparent_retries"; + private static final String CLIENT_CALL_HEDGES = "grpc.client.call.hedges"; + private static final String CLIENT_CALL_RETRY_DELAY = "grpc.client.call.retry_delay"; private static final String SERVER_CALL_COUNT = "grpc.server.call.started"; private static final String SERVER_CALL_DURATION = "grpc.server.call.duration"; private static final String SERVER_CALL_SENT_TOTAL_COMPRESSED_MESSAGE_SIZE = "grpc.server.call.sent_total_compressed_message_size"; private static final String SERVER_CALL_RECV_TOTAL_COMPRESSED_MESSAGE_SIZE = "grpc.server.call.rcvd_total_compressed_message_size"; + private static final double[] latencyBuckets = + { 0d, 0.00001d, 0.00005d, 0.0001d, 0.0003d, 0.0006d, 0.0008d, 0.001d, 0.002d, + 0.003d, 0.004d, 0.005d, 0.006d, 0.008d, 0.01d, 0.013d, 0.016d, 0.02d, + 0.025d, 0.03d, 0.04d, 0.05d, 0.065d, 0.08d, 0.1d, 0.13d, 0.16d, + 0.2d, 0.25d, 0.3d, 0.4d, 0.5d, 0.65d, 0.8d, 1d, 2d, + 5d, 10d, 20d, 50d, 100d }; + private static final double[] sizeBuckets = + { 0L, 1024L, 2048L, 4096L, 16384L, 65536L, 262144L, 1048576L, 4194304L, 16777216L, + 67108864L, 268435456L, 1073741824L, 4294967296L }; - private static final class StringInputStream extends InputStream { + private static final class StringInputStream extends InputStream implements KnownLength { final String string; StringInputStream(String string) { @@ -103,6 +150,11 @@ private static final class StringInputStream extends InputStream { public int read() { throw new UnsupportedOperationException("should not be called"); } + + @Override + public int available() throws IOException { + return string == null ? 0 : string.length(); + } } private static final MethodDescriptor.Marshaller MARSHALLER = @@ -121,6 +173,8 @@ public String parse(InputStream stream) { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); @Rule + public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + @Rule public final GrpcServerRule grpcServerRule = new GrpcServerRule().directExecutor(); @Rule public final OpenTelemetryRule openTelemetryTesting = OpenTelemetryRule.create(); @@ -131,6 +185,9 @@ public String parse(InputStream stream) { @Captor private ArgumentCaptor statusCaptor; + private Server server; + private ManagedChannel channel; + private final FakeClock fakeClock = new FakeClock(); private final MethodDescriptor method = MethodDescriptor.newBuilder() @@ -141,18 +198,32 @@ public String parse(InputStream stream) { .setSampledToLocalTracing(true) .build(); private Meter testMeter; + private final Map enabledMetricsMap = ImmutableMap.of(); + + private final boolean disableDefaultMetrics = false; @Before public void setUp() throws Exception { testMeter = openTelemetryTesting.getOpenTelemetry() .getMeter(OpenTelemetryConstants.INSTRUMENTATION_SCOPE); + + } + + @After + public void tearDown() { + if (channel != null) { + channel.shutdownNow(); + } + if (server != null) { + server.shutdownNow(); + } } @Test public void testClientInterceptors() { - OpenTelemetryMetricsResource resource = OpenTelemetryModule.createMetricInstruments(testMeter); - OpenTelemetryMetricsModule module = - new OpenTelemetryMetricsModule(fakeClock.getStopwatchSupplier(), resource); + OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, + enabledMetricsMap, disableDefaultMetrics); + OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(resource); grpcServerRule.getServiceRegistry().addService( ServerServiceDefinition.builder("package1.service2").addMethod( method, new ServerCallHandler() { @@ -168,7 +239,7 @@ public ServerCall.Listener startCall( }).build()); final AtomicReference capturedCallOptions = new AtomicReference<>(); - ClientInterceptor callOptionsCatureInterceptor = new ClientInterceptor() { + ClientInterceptor callOptionsCaptureInterceptor = new ClientInterceptor() { @Override public ClientCall interceptCall( MethodDescriptor method, CallOptions callOptions, Channel next) { @@ -178,10 +249,11 @@ public ClientCall interceptCall( }; Channel interceptedChannel = ClientInterceptors.intercept( - grpcServerRule.getChannel(), callOptionsCatureInterceptor, - module.getClientInterceptor()); + grpcServerRule.getChannel(), callOptionsCaptureInterceptor, + module.getClientInterceptor("target:///")); ClientCall call; - call = interceptedChannel.newCall(method, CALL_OPTIONS); + call = interceptedChannel.newCall( + method, CallOptions.DEFAULT.withOption(CUSTOM_OPTION, "customvalue")); assertEquals("customvalue", capturedCallOptions.get().getOption(CUSTOM_OPTION)); assertEquals(1, capturedCallOptions.get().getStreamTracerFactories().size()); @@ -205,15 +277,18 @@ public ClientCall interceptCall( @Test public void clientBasicMetrics() { - OpenTelemetryMetricsResource resource = OpenTelemetryModule.createMetricInstruments(testMeter);; - OpenTelemetryMetricsModule module = - new OpenTelemetryMetricsModule(fakeClock.getStopwatchSupplier(), resource); + String target = "target:///"; + OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, + enabledMetricsMap, disableDefaultMetrics); + OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(resource); OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = - new CallAttemptsTracerFactory(module, method.getFullMethodName()); + new CallAttemptsTracerFactory(module, target, CALL_OPTIONS, method.getFullMethodName(), + emptyList(), Context.root()); Metadata headers = new Metadata(); ClientStreamTracer tracer = callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, headers); io.opentelemetry.api.common.Attributes attributes = io.opentelemetry.api.common.Attributes.of( + TARGET_KEY, target, METHOD_KEY, method.getFullMethodName()); assertThat(openTelemetryTesting.getMetrics()) @@ -233,6 +308,8 @@ public void clientBasicMetrics() { .hasAttributes(attributes) .hasValue(1)))); + tracer.addOptionalLabel("grpc.lb.locality", "should-be-ignored"); + fakeClock.forwardTime(30, TimeUnit.MILLISECONDS); tracer.outboundHeaders(); @@ -251,10 +328,11 @@ public void clientBasicMetrics() { tracer.inboundMessage(1); tracer.inboundWireSize(154); tracer.streamClosed(Status.OK); - callAttemptsTracerFactory.callEnded(Status.OK); + callAttemptsTracerFactory.callEnded(Status.OK, CALL_OPTIONS); io.opentelemetry.api.common.Attributes clientAttributes = io.opentelemetry.api.common.Attributes.of( + TARGET_KEY, target, METHOD_KEY, method.getFullMethodName(), STATUS_KEY, Status.Code.OK.toString()); @@ -287,7 +365,11 @@ public void clientBasicMetrics() { point .hasCount(1) .hasSum(0.03 + 0.1 + 0.016 + 0.024) - .hasAttributes(clientAttributes))), + .hasAttributes(clientAttributes) + .hasBucketBoundaries(latencyBuckets) + .hasBucketCounts(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0))), metric -> assertThat(metric) .hasInstrumentationScope(InstrumentationScopeInfo.create( @@ -302,7 +384,10 @@ public void clientBasicMetrics() { point .hasCount(1) .hasSum(1028L + 99) - .hasAttributes(clientAttributes))), + .hasAttributes(clientAttributes) + .hasBucketBoundaries(sizeBuckets) + .hasBucketCounts(0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0))), metric -> assertThat(metric) .hasInstrumentationScope(InstrumentationScopeInfo.create( @@ -319,7 +404,9 @@ public void clientBasicMetrics() { point .hasCount(1) .hasSum(154) - .hasAttributes(clientAttributes))), + .hasAttributes(clientAttributes) + .hasBucketCounts(0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0))), metric -> assertThat(metric) .hasInstrumentationScope(InstrumentationScopeInfo.create( @@ -333,22 +420,111 @@ public void clientBasicMetrics() { point .hasCount(1) .hasSum(0.03 + 0.1 + 0.016 + 0.024) - .hasAttributes(clientAttributes)))); + .hasAttributes(clientAttributes) + .hasBucketBoundaries(latencyBuckets) + .hasBucketCounts(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0)))); + + assertThat(openTelemetryTesting.getMetrics()) + .extracting("name") + .doesNotContain( + CLIENT_CALL_RETRIES, + CLIENT_CALL_TRANSPARENT_RETRIES, + CLIENT_CALL_HEDGES, + CLIENT_CALL_RETRY_DELAY); + } + + @Test + public void clientBasicMetrics_withRetryMetricsEnabled_shouldRecordZeroOrBeAbsent() { + // Explicitly enable the retry metrics + Map enabledMetrics = ImmutableMap.of( + CLIENT_CALL_RETRIES, true, + CLIENT_CALL_TRANSPARENT_RETRIES, true, + CLIENT_CALL_HEDGES, true, + CLIENT_CALL_RETRY_DELAY, true + ); + + String target = "target:///"; + OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, + enabledMetrics, disableDefaultMetrics); + OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(resource); + OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = + new CallAttemptsTracerFactory(module, target, CALL_OPTIONS, method.getFullMethodName(), + emptyList(), Context.root()); + ClientStreamTracer tracer = + callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); + + fakeClock.forwardTime(30, TimeUnit.MILLISECONDS); + tracer.outboundHeaders(); + fakeClock.forwardTime(100, TimeUnit.MILLISECONDS); + tracer.outboundMessage(0); + tracer.streamClosed(Status.OK); + callAttemptsTracerFactory.callEnded(Status.OK, CALL_OPTIONS); + + io.opentelemetry.api.common.Attributes finalAttributes + = io.opentelemetry.api.common.Attributes.of( + TARGET_KEY, target, + METHOD_KEY, method.getFullMethodName()); + + assertThat(openTelemetryTesting.getMetrics()) + .satisfiesExactlyInAnyOrder( + metric -> assertThat(metric).hasName(CLIENT_ATTEMPT_COUNT_INSTRUMENT_NAME), + metric -> assertThat(metric).hasName(CLIENT_ATTEMPT_DURATION_INSTRUMENT_NAME), + metric -> assertThat(metric).hasName(CLIENT_ATTEMPT_SENT_TOTAL_COMPRESSED_MESSAGE_SIZE), + metric -> assertThat(metric).hasName(CLIENT_ATTEMPT_RECV_TOTAL_COMPRESSED_MESSAGE_SIZE), + metric -> assertThat(metric).hasName(CLIENT_CALL_DURATION), + metric -> assertThat(metric) + .hasName(CLIENT_CALL_RETRY_DELAY) + .hasHistogramSatisfying( + histogram -> + histogram.hasPointsSatisfying( + point -> + point + .hasSum(0) + .hasCount(1) + .hasAttributes(finalAttributes))) + + ); + + List optionalMetricNames = Arrays.asList( + CLIENT_CALL_RETRIES, + CLIENT_CALL_TRANSPARENT_RETRIES, + CLIENT_CALL_HEDGES); + + for (String metricName : optionalMetricNames) { + Optional metric = openTelemetryTesting.getMetrics().stream() + .filter(m -> m.getName().equals(metricName)) + .findFirst(); + if (metric.isPresent()) { + assertThat(metric.get()) + .hasHistogramSatisfying( + histogram -> + histogram.hasPointsSatisfying( + point -> + point + .hasSum(0) + .hasCount(1) + .hasAttributes(finalAttributes))); + } + } } // This test is only unit-testing the metrics recording logic. The retry behavior is faked. @Test public void recordAttemptMetrics() { - OpenTelemetryMetricsResource resource = OpenTelemetryModule.createMetricInstruments(testMeter); - OpenTelemetryMetricsModule module = - new OpenTelemetryMetricsModule(fakeClock.getStopwatchSupplier(), resource); + String target = "dns:///example.com"; + OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, + enabledMetricsMap, disableDefaultMetrics); + OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(resource); OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = - new OpenTelemetryMetricsModule.CallAttemptsTracerFactory(module, - method.getFullMethodName()); + new OpenTelemetryMetricsModule.CallAttemptsTracerFactory(module, target, CALL_OPTIONS, + method.getFullMethodName(), emptyList(), Context.root()); ClientStreamTracer tracer = callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); io.opentelemetry.api.common.Attributes attributes = io.opentelemetry.api.common.Attributes.of( + TARGET_KEY, target, METHOD_KEY, method.getFullMethodName()); assertThat(openTelemetryTesting.getMetrics()) @@ -379,6 +555,7 @@ public void recordAttemptMetrics() { io.opentelemetry.api.common.Attributes clientAttributes = io.opentelemetry.api.common.Attributes.of( + TARGET_KEY, target, METHOD_KEY, method.getFullMethodName(), STATUS_KEY, Code.UNAVAILABLE.toString()); @@ -411,7 +588,8 @@ public void recordAttemptMetrics() { point .hasCount(1) .hasSum(0.03 + 0.1 + 0.024) - .hasAttributes(clientAttributes))), + .hasAttributes(clientAttributes) + .hasBucketBoundaries(latencyBuckets))), metric -> assertThat(metric) .hasInstrumentationScope(InstrumentationScopeInfo.create( @@ -426,7 +604,8 @@ public void recordAttemptMetrics() { point .hasCount(1) .hasSum(1028L) - .hasAttributes(clientAttributes))), + .hasAttributes(clientAttributes) + .hasBucketBoundaries(sizeBuckets))), metric -> assertThat(metric) .hasInstrumentationScope(InstrumentationScopeInfo.create( @@ -443,7 +622,8 @@ public void recordAttemptMetrics() { point .hasCount(1) .hasSum(0) - .hasAttributes(clientAttributes)))); + .hasAttributes(clientAttributes) + .hasBucketBoundaries(sizeBuckets)))); // faking retry @@ -459,6 +639,7 @@ public void recordAttemptMetrics() { io.opentelemetry.api.common.Attributes clientAttributes1 = io.opentelemetry.api.common.Attributes.of( + TARGET_KEY, target, METHOD_KEY, method.getFullMethodName(), STATUS_KEY, Code.NOT_FOUND.toString()); @@ -492,12 +673,14 @@ public void recordAttemptMetrics() { point .hasCount(1) .hasSum(0.1) - .hasAttributes(clientAttributes1), + .hasAttributes(clientAttributes1) + .hasBucketBoundaries(latencyBuckets), point -> point .hasCount(1) .hasSum(0.154) - .hasAttributes(clientAttributes))), + .hasAttributes(clientAttributes) + .hasBucketBoundaries(latencyBuckets))), metric -> assertThat(metric) .hasInstrumentationScope(InstrumentationScopeInfo.create( @@ -514,12 +697,14 @@ public void recordAttemptMetrics() { point .hasCount(1) .hasSum(0) - .hasAttributes(clientAttributes1), + .hasAttributes(clientAttributes1) + .hasBucketBoundaries(sizeBuckets), point -> point .hasCount(1) .hasSum(0) - .hasAttributes(clientAttributes))), + .hasAttributes(clientAttributes) + .hasBucketBoundaries(sizeBuckets))), metric -> assertThat(metric) .hasInstrumentationScope(InstrumentationScopeInfo.create( @@ -534,12 +719,14 @@ public void recordAttemptMetrics() { point .hasCount(1) .hasSum(1028L) - .hasAttributes(clientAttributes1), + .hasAttributes(clientAttributes1) + .hasBucketBoundaries(sizeBuckets), point -> point .hasCount(1) .hasSum(1028L) - .hasAttributes(clientAttributes)))); + .hasAttributes(clientAttributes) + .hasBucketBoundaries(sizeBuckets)))); // fake transparent retry fakeClock.forwardTime(10, TimeUnit.MILLISECONDS); @@ -579,12 +766,14 @@ public void recordAttemptMetrics() { point .hasCount(1) .hasSum(0.1) - .hasAttributes(clientAttributes1), + .hasAttributes(clientAttributes1) + .hasBucketBoundaries(latencyBuckets), point -> point .hasCount(2) .hasSum(0.154 + 0.032) - .hasAttributes(clientAttributes))), + .hasAttributes(clientAttributes) + .hasBucketBoundaries(latencyBuckets))), metric -> assertThat(metric) .hasInstrumentationScope(InstrumentationScopeInfo.create( @@ -601,12 +790,14 @@ public void recordAttemptMetrics() { point .hasCount(1) .hasSum(0) - .hasAttributes(clientAttributes1), + .hasAttributes(clientAttributes1) + .hasBucketBoundaries(sizeBuckets), point -> point .hasCount(2) .hasSum(0 + 0) - .hasAttributes(clientAttributes))), + .hasAttributes(clientAttributes) + .hasBucketBoundaries(sizeBuckets))), metric -> assertThat(metric) .hasInstrumentationScope(InstrumentationScopeInfo.create( @@ -621,12 +812,14 @@ public void recordAttemptMetrics() { point .hasCount(1) .hasSum(1028L) - .hasAttributes(clientAttributes1), + .hasAttributes(clientAttributes1) + .hasBucketBoundaries(sizeBuckets), point -> point .hasCount(2) .hasSum(1028L + 0) - .hasAttributes(clientAttributes)))); + .hasAttributes(clientAttributes) + .hasBucketBoundaries(sizeBuckets)))); // fake another transparent retry fakeClock.forwardTime(10, MILLISECONDS); @@ -641,10 +834,11 @@ public void recordAttemptMetrics() { fakeClock.forwardTime(24, MILLISECONDS); // RPC succeeded tracer.streamClosed(Status.OK); - callAttemptsTracerFactory.callEnded(Status.OK); + callAttemptsTracerFactory.callEnded(Status.OK, CALL_OPTIONS); io.opentelemetry.api.common.Attributes clientAttributes2 = io.opentelemetry.api.common.Attributes.of( + TARGET_KEY, target, METHOD_KEY, method.getFullMethodName(), STATUS_KEY, Code.OK.toString()); @@ -678,17 +872,20 @@ public void recordAttemptMetrics() { point .hasCount(1) .hasSum(1028L) - .hasAttributes(clientAttributes1), + .hasAttributes(clientAttributes1) + .hasBucketBoundaries(sizeBuckets), point -> point .hasCount(2) .hasSum(1028L + 0) - .hasAttributes(clientAttributes), + .hasAttributes(clientAttributes) + .hasBucketBoundaries(sizeBuckets), point -> point .hasCount(1) .hasSum(1028L) - .hasAttributes(clientAttributes2))), + .hasAttributes(clientAttributes2) + .hasBucketBoundaries(sizeBuckets))), metric -> assertThat(metric) .hasInstrumentationScope(InstrumentationScopeInfo.create( @@ -703,7 +900,8 @@ public void recordAttemptMetrics() { .hasCount(1) .hasSum(0.03 + 0.1 + 0.024 + 1 + 0.1 + 0.01 + 0.032 + 0.01 + 0.024) - .hasAttributes(clientAttributes2))), + .hasAttributes(clientAttributes2) + .hasBucketBoundaries(latencyBuckets))), metric -> assertThat(metric) .hasInstrumentationScope(InstrumentationScopeInfo.create( @@ -717,17 +915,20 @@ public void recordAttemptMetrics() { point .hasCount(1) .hasSum(0.100) - .hasAttributes(clientAttributes1), + .hasAttributes(clientAttributes1) + .hasBucketBoundaries(latencyBuckets), point -> point .hasCount(2) .hasSum(0.154 + 0.032) - .hasAttributes(clientAttributes), + .hasAttributes(clientAttributes) + .hasBucketBoundaries(latencyBuckets), point -> point .hasCount(1) .hasSum(0.024) - .hasAttributes(clientAttributes2))), + .hasAttributes(clientAttributes2) + .hasBucketBoundaries(latencyBuckets))), metric -> assertThat(metric) .hasInstrumentationScope(InstrumentationScopeInfo.create( @@ -744,37 +945,219 @@ public void recordAttemptMetrics() { point .hasCount(1) .hasSum(0) - .hasAttributes(clientAttributes1), + .hasAttributes(clientAttributes1) + .hasBucketBoundaries(sizeBuckets), point -> point .hasCount(2) .hasSum(0 + 0) - .hasAttributes(clientAttributes), + .hasAttributes(clientAttributes) + .hasBucketBoundaries(sizeBuckets), point -> point .hasCount(1) .hasSum(33D) - .hasAttributes(clientAttributes2)))); + .hasAttributes(clientAttributes2) + .hasBucketBoundaries(sizeBuckets)))); + } + + @Test + public void recordAttemptMetrics_withRetryMetricsEnabled() { + Map enabledMetrics = ImmutableMap.of( + CLIENT_CALL_RETRIES, true, + CLIENT_CALL_TRANSPARENT_RETRIES, true, + CLIENT_CALL_HEDGES, true, + CLIENT_CALL_RETRY_DELAY, true + ); + + String target = "dns:///example.com"; + OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, + enabledMetrics, disableDefaultMetrics); + OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(resource); + OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = + new OpenTelemetryMetricsModule.CallAttemptsTracerFactory(module, target, CALL_OPTIONS, + method.getFullMethodName(), emptyList(), Context.root()); + + ClientStreamTracer tracer = + callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); + fakeClock.forwardTime(154, TimeUnit.MILLISECONDS); + tracer.streamClosed(Status.UNAVAILABLE); + + fakeClock.forwardTime(1000, TimeUnit.MILLISECONDS); + tracer = callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); + fakeClock.forwardTime(100, TimeUnit.MILLISECONDS); + tracer.streamClosed(Status.NOT_FOUND); + + fakeClock.forwardTime(10, TimeUnit.MILLISECONDS); + tracer = callAttemptsTracerFactory.newClientStreamTracer( + STREAM_INFO.toBuilder().setIsTransparentRetry(true).build(), new Metadata()); + fakeClock.forwardTime(32, MILLISECONDS); + tracer.streamClosed(Status.UNAVAILABLE); + + fakeClock.forwardTime(10, MILLISECONDS); + tracer = callAttemptsTracerFactory.newClientStreamTracer( + STREAM_INFO.toBuilder().setIsTransparentRetry(true).build(), new Metadata()); + tracer.inboundWireSize(33); + fakeClock.forwardTime(24, MILLISECONDS); + tracer.streamClosed(Status.OK); // RPC succeeded + + // --- The overall call ends --- + callAttemptsTracerFactory.callEnded(Status.OK, CALL_OPTIONS); + + // Define attributes for assertions + io.opentelemetry.api.common.Attributes finalAttributes + = io.opentelemetry.api.common.Attributes.of( + TARGET_KEY, target, + METHOD_KEY, method.getFullMethodName()); + + // FINAL ASSERTION BLOCK + assertThat(openTelemetryTesting.getMetrics()) + .satisfiesExactlyInAnyOrder( + // Default metrics + metric -> assertThat(metric).hasName(CLIENT_ATTEMPT_COUNT_INSTRUMENT_NAME), + metric -> assertThat(metric).hasName(CLIENT_ATTEMPT_DURATION_INSTRUMENT_NAME), + metric -> assertThat(metric).hasName(CLIENT_ATTEMPT_SENT_TOTAL_COMPRESSED_MESSAGE_SIZE), + metric -> assertThat(metric).hasName(CLIENT_ATTEMPT_RECV_TOTAL_COMPRESSED_MESSAGE_SIZE), + metric -> assertThat(metric).hasName(CLIENT_CALL_DURATION), + + // --- Assertions for the retry metrics --- + metric -> assertThat(metric) + .hasName(CLIENT_CALL_RETRIES) + .hasUnit("{retry}") + .hasHistogramSatisfying(histogram -> histogram.hasPointsSatisfying( + point -> point + .hasCount(1) + .hasSum(1) // We faked one standard retry + .hasAttributes(finalAttributes))), + metric -> assertThat(metric) + .hasName(CLIENT_CALL_TRANSPARENT_RETRIES) + .hasUnit("{transparent_retry}") + .hasHistogramSatisfying(histogram -> histogram.hasPointsSatisfying( + point -> point + .hasCount(1) + .hasSum(2) // We faked two transparent retries + .hasAttributes(finalAttributes))), + metric -> assertThat(metric) + .hasName(CLIENT_CALL_RETRY_DELAY) + .hasUnit("s") + .hasHistogramSatisfying(histogram -> histogram.hasPointsSatisfying( + point -> point + .hasCount(1) + .hasSum(1.02) // 1000ms + 10ms + 10ms + .hasAttributes(finalAttributes))) + ); + } + + @Test + public void recordAttemptMetrics_withHedgedCalls() { + // Enable the retry metrics, including hedges + Map enabledMetrics = ImmutableMap.of( + CLIENT_CALL_RETRIES, true, + CLIENT_CALL_TRANSPARENT_RETRIES, true, + CLIENT_CALL_HEDGES, true, + CLIENT_CALL_RETRY_DELAY, true + ); + + String target = "dns:///example.com"; + OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, + enabledMetrics, disableDefaultMetrics); + OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(resource); + OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = + new OpenTelemetryMetricsModule.CallAttemptsTracerFactory(module, target, CALL_OPTIONS, + method.getFullMethodName(), emptyList(), Context.root()); + + // Create a StreamInfo specifically for hedged attempts + final ClientStreamTracer.StreamInfo hedgedStreamInfo = + STREAM_INFO.toBuilder().setIsHedging(true).build(); + + // --- First attempt starts --- + ClientStreamTracer tracer = + callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); + + // --- Faking a hedged attempt --- + fakeClock.forwardTime(10, TimeUnit.MILLISECONDS); // Hedging delay + ClientStreamTracer hedgeTracer1 = + callAttemptsTracerFactory.newClientStreamTracer(hedgedStreamInfo, new Metadata()); + + // --- Faking a second hedged attempt --- + fakeClock.forwardTime(20, TimeUnit.MILLISECONDS); // Another hedging delay + ClientStreamTracer hedgeTracer2 = + callAttemptsTracerFactory.newClientStreamTracer(hedgedStreamInfo, new Metadata()); + + // --- Let the attempts resolve --- + fakeClock.forwardTime(50, TimeUnit.MILLISECONDS); + // Initial attempt is cancelled because a hedge will succeed + tracer.streamClosed(Status.CANCELLED); + hedgeTracer1.streamClosed(Status.UNAVAILABLE); // First hedge fails + + fakeClock.forwardTime(30, TimeUnit.MILLISECONDS); + hedgeTracer2.streamClosed(Status.OK); // Second hedge succeeds + + // --- The overall call ends --- + callAttemptsTracerFactory.callEnded(Status.OK, CALL_OPTIONS); + + // Define attributes for assertions + io.opentelemetry.api.common.Attributes finalAttributes + = io.opentelemetry.api.common.Attributes.of( + TARGET_KEY, target, + METHOD_KEY, method.getFullMethodName()); + + // FINAL ASSERTION BLOCK + // We expect 7 metrics: 5 default + hedges + retry_delay. + // Retries and transparent_retries are 0 and will not be reported. + assertThat(openTelemetryTesting.getMetrics()) + .satisfiesExactlyInAnyOrder( + // Default metrics + metric -> assertThat(metric).hasName(CLIENT_ATTEMPT_COUNT_INSTRUMENT_NAME), + metric -> assertThat(metric).hasName(CLIENT_ATTEMPT_DURATION_INSTRUMENT_NAME), + metric -> assertThat(metric).hasName(CLIENT_ATTEMPT_SENT_TOTAL_COMPRESSED_MESSAGE_SIZE), + metric -> assertThat(metric).hasName(CLIENT_ATTEMPT_RECV_TOTAL_COMPRESSED_MESSAGE_SIZE), + metric -> assertThat(metric).hasName(CLIENT_CALL_DURATION), + + // --- Assertions for the NEW metrics --- + metric -> assertThat(metric) + .hasName(CLIENT_CALL_HEDGES) + .hasUnit("{hedge}") + .hasHistogramSatisfying(histogram -> histogram.hasPointsSatisfying( + point -> point + .hasCount(1) + .hasSum(2) + .hasAttributes(finalAttributes))), + metric -> assertThat(metric) + .hasName(CLIENT_CALL_RETRY_DELAY) + .hasUnit("s") + .hasHistogramSatisfying( + histogram -> + histogram.hasPointsSatisfying( + point -> + point + .hasCount(1) + .hasSum(0) + .hasAttributes(finalAttributes))) + ); } @Test public void clientStreamNeverCreatedStillRecordMetrics() { - OpenTelemetryMetricsResource resource = OpenTelemetryModule.createMetricInstruments(testMeter); - OpenTelemetryMetricsModule module = - new OpenTelemetryMetricsModule(fakeClock.getStopwatchSupplier(), resource); + String target = "dns:///foo.example.com"; + OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, + enabledMetricsMap, disableDefaultMetrics); + OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(resource); OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = - new OpenTelemetryMetricsModule.CallAttemptsTracerFactory(module, - method.getFullMethodName()); + new OpenTelemetryMetricsModule.CallAttemptsTracerFactory(module, target, CALL_OPTIONS, + method.getFullMethodName(), emptyList(), Context.root()); fakeClock.forwardTime(3000, MILLISECONDS); Status status = Status.DEADLINE_EXCEEDED.withDescription("5 seconds"); - callAttemptsTracerFactory.callEnded(status); + callAttemptsTracerFactory.callEnded(status, CALL_OPTIONS); io.opentelemetry.api.common.Attributes attemptStartedAttributes = io.opentelemetry.api.common.Attributes.of( + TARGET_KEY, target, METHOD_KEY, method.getFullMethodName()); io.opentelemetry.api.common.Attributes clientAttributes = io.opentelemetry.api.common.Attributes.of( + TARGET_KEY, target, METHOD_KEY, method.getFullMethodName(), STATUS_KEY, Code.DEADLINE_EXCEEDED.toString()); @@ -809,7 +1192,8 @@ public void clientStreamNeverCreatedStillRecordMetrics() { point .hasCount(1) .hasSum(0) - .hasAttributes(clientAttributes))), + .hasAttributes(clientAttributes) + .hasBucketBoundaries(sizeBuckets))), metric -> assertThat(metric) .hasInstrumentationScope(InstrumentationScopeInfo.create( @@ -823,7 +1207,8 @@ public void clientStreamNeverCreatedStillRecordMetrics() { point .hasCount(1) .hasSum(3D) - .hasAttributes(clientAttributes))), + .hasAttributes(clientAttributes) + .hasBucketBoundaries(latencyBuckets))), metric -> assertThat(metric) .hasInstrumentationScope(InstrumentationScopeInfo.create( @@ -837,7 +1222,8 @@ public void clientStreamNeverCreatedStillRecordMetrics() { point .hasCount(1) .hasSum(0) - .hasAttributes(clientAttributes))), + .hasAttributes(clientAttributes) + .hasBucketBoundaries(latencyBuckets))), metric -> assertThat(metric) .hasInstrumentationScope(InstrumentationScopeInfo.create( @@ -854,15 +1240,382 @@ public void clientStreamNeverCreatedStillRecordMetrics() { point .hasCount(1) .hasSum(0) - .hasAttributes(clientAttributes)))); + .hasAttributes(clientAttributes) + .hasBucketBoundaries(sizeBuckets)))); } @Test - public void serverBasicMetrics() { - OpenTelemetryMetricsResource resource = OpenTelemetryModule.createMetricInstruments(testMeter); + public void clientLocalityMetrics_present() { + String target = "target:///"; + OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, + enabledMetricsMap, disableDefaultMetrics); OpenTelemetryMetricsModule module = new OpenTelemetryMetricsModule( - fakeClock.getStopwatchSupplier(), resource); + fakeClock.getStopwatchSupplier(), resource, Arrays.asList("grpc.lb.locality"), + emptyList()); + OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = + new CallAttemptsTracerFactory(module, target, CALL_OPTIONS, method.getFullMethodName(), + emptyList(), Context.root()); + + ClientStreamTracer tracer = + callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); + tracer.addOptionalLabel("grpc.lb.foo", "unimportant"); + tracer.addOptionalLabel("grpc.lb.locality", "should-be-overwritten"); + tracer.addOptionalLabel("grpc.lb.locality", "the-moon"); + tracer.addOptionalLabel("grpc.lb.foo", "thats-no-moon"); + tracer.streamClosed(Status.OK); + callAttemptsTracerFactory.callEnded(Status.OK, CALL_OPTIONS); + + io.opentelemetry.api.common.Attributes attributes = io.opentelemetry.api.common.Attributes.of( + TARGET_KEY, target, + METHOD_KEY, method.getFullMethodName()); + + io.opentelemetry.api.common.Attributes clientAttributes + = io.opentelemetry.api.common.Attributes.of( + TARGET_KEY, target, + METHOD_KEY, method.getFullMethodName(), + STATUS_KEY, Status.Code.OK.toString()); + + io.opentelemetry.api.common.Attributes clientAttributesWithLocality + = clientAttributes.toBuilder() + .put(LOCALITY_KEY, "the-moon") + .build(); + + assertThat(openTelemetryTesting.getMetrics()) + .satisfiesExactlyInAnyOrder( + metric -> + assertThat(metric) + .hasName(CLIENT_ATTEMPT_COUNT_INSTRUMENT_NAME) + .hasLongSumSatisfying( + longSum -> longSum.hasPointsSatisfying( + point -> point.hasAttributes(attributes))), + metric -> + assertThat(metric) + .hasName(CLIENT_ATTEMPT_DURATION_INSTRUMENT_NAME) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(clientAttributesWithLocality))), + metric -> + assertThat(metric) + .hasName(CLIENT_ATTEMPT_SENT_TOTAL_COMPRESSED_MESSAGE_SIZE) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(clientAttributesWithLocality))), + metric -> + assertThat(metric) + .hasName(CLIENT_ATTEMPT_RECV_TOTAL_COMPRESSED_MESSAGE_SIZE) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(clientAttributesWithLocality))), + metric -> + assertThat(metric) + .hasName(CLIENT_CALL_DURATION) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(clientAttributes)))); + } + + @Test + public void clientLocalityMetrics_missing() { + String target = "target:///"; + OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, + enabledMetricsMap, disableDefaultMetrics); + OpenTelemetryMetricsModule module = new OpenTelemetryMetricsModule( + fakeClock.getStopwatchSupplier(), resource, Arrays.asList("grpc.lb.locality"), + emptyList()); + OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = + new CallAttemptsTracerFactory(module, target, CALL_OPTIONS, method.getFullMethodName(), + emptyList(), Context.root()); + + ClientStreamTracer tracer = + callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); + tracer.streamClosed(Status.OK); + callAttemptsTracerFactory.callEnded(Status.OK, CALL_OPTIONS); + + io.opentelemetry.api.common.Attributes attributes = io.opentelemetry.api.common.Attributes.of( + TARGET_KEY, target, + METHOD_KEY, method.getFullMethodName()); + + io.opentelemetry.api.common.Attributes clientAttributes + = io.opentelemetry.api.common.Attributes.of( + TARGET_KEY, target, + METHOD_KEY, method.getFullMethodName(), + STATUS_KEY, Status.Code.OK.toString()); + + io.opentelemetry.api.common.Attributes clientAttributesWithLocality + = clientAttributes.toBuilder() + .put(LOCALITY_KEY, "") + .build(); + + assertThat(openTelemetryTesting.getMetrics()) + .satisfiesExactlyInAnyOrder( + metric -> + assertThat(metric) + .hasName(CLIENT_ATTEMPT_COUNT_INSTRUMENT_NAME) + .hasLongSumSatisfying( + longSum -> longSum.hasPointsSatisfying( + point -> point.hasAttributes(attributes))), + metric -> + assertThat(metric) + .hasName(CLIENT_ATTEMPT_DURATION_INSTRUMENT_NAME) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(clientAttributesWithLocality))), + metric -> + assertThat(metric) + .hasName(CLIENT_ATTEMPT_SENT_TOTAL_COMPRESSED_MESSAGE_SIZE) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(clientAttributesWithLocality))), + metric -> + assertThat(metric) + .hasName(CLIENT_ATTEMPT_RECV_TOTAL_COMPRESSED_MESSAGE_SIZE) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(clientAttributesWithLocality))), + metric -> + assertThat(metric) + .hasName(CLIENT_CALL_DURATION) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(clientAttributes)))); + } + + @Test + public void clientBackendServiceMetrics_present() { + String target = "target:///"; + OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, + enabledMetricsMap, disableDefaultMetrics); + OpenTelemetryMetricsModule module = new OpenTelemetryMetricsModule( + fakeClock.getStopwatchSupplier(), resource, Arrays.asList("grpc.lb.backend_service"), + emptyList()); + OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = + new CallAttemptsTracerFactory(module, target, CALL_OPTIONS, method.getFullMethodName(), + emptyList(), Context.root()); + + ClientStreamTracer tracer = + callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); + tracer.addOptionalLabel("grpc.lb.foo", "unimportant"); + tracer.addOptionalLabel("grpc.lb.backend_service", "should-be-overwritten"); + tracer.addOptionalLabel("grpc.lb.backend_service", "the-moon"); + tracer.addOptionalLabel("grpc.lb.foo", "thats-no-moon"); + tracer.streamClosed(Status.OK); + callAttemptsTracerFactory.callEnded(Status.OK, CALL_OPTIONS); + + io.opentelemetry.api.common.Attributes attributes = io.opentelemetry.api.common.Attributes.of( + TARGET_KEY, target, + METHOD_KEY, method.getFullMethodName()); + + io.opentelemetry.api.common.Attributes clientAttributes + = io.opentelemetry.api.common.Attributes.of( + TARGET_KEY, target, + METHOD_KEY, method.getFullMethodName(), + STATUS_KEY, Status.Code.OK.toString()); + + io.opentelemetry.api.common.Attributes clientAttributesWithBackendService + = clientAttributes.toBuilder() + .put(AttributeKey.stringKey("grpc.lb.backend_service"), "the-moon") + .build(); + + assertThat(openTelemetryTesting.getMetrics()) + .satisfiesExactlyInAnyOrder( + metric -> + assertThat(metric) + .hasName(CLIENT_ATTEMPT_COUNT_INSTRUMENT_NAME) + .hasLongSumSatisfying( + longSum -> longSum.hasPointsSatisfying( + point -> point.hasAttributes(attributes))), + metric -> + assertThat(metric) + .hasName(CLIENT_ATTEMPT_DURATION_INSTRUMENT_NAME) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(clientAttributesWithBackendService))), + metric -> + assertThat(metric) + .hasName(CLIENT_ATTEMPT_SENT_TOTAL_COMPRESSED_MESSAGE_SIZE) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(clientAttributesWithBackendService))), + metric -> + assertThat(metric) + .hasName(CLIENT_ATTEMPT_RECV_TOTAL_COMPRESSED_MESSAGE_SIZE) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(clientAttributesWithBackendService))), + metric -> + assertThat(metric) + .hasName(CLIENT_CALL_DURATION) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(clientAttributes)))); + } + + @Test + public void clientBackendServiceMetrics_missing() { + String target = "target:///"; + OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, + enabledMetricsMap, disableDefaultMetrics); + OpenTelemetryMetricsModule module = new OpenTelemetryMetricsModule( + fakeClock.getStopwatchSupplier(), resource, Arrays.asList("grpc.lb.backend_service"), + emptyList()); + OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = + new CallAttemptsTracerFactory(module, target, CALL_OPTIONS, method.getFullMethodName(), + emptyList(), Context.root()); + + ClientStreamTracer tracer = + callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); + tracer.streamClosed(Status.OK); + callAttemptsTracerFactory.callEnded(Status.OK, CALL_OPTIONS); + + io.opentelemetry.api.common.Attributes attributes = io.opentelemetry.api.common.Attributes.of( + TARGET_KEY, target, + METHOD_KEY, method.getFullMethodName()); + + io.opentelemetry.api.common.Attributes clientAttributes + = io.opentelemetry.api.common.Attributes.of( + TARGET_KEY, target, + METHOD_KEY, method.getFullMethodName(), + STATUS_KEY, Status.Code.OK.toString()); + + io.opentelemetry.api.common.Attributes clientAttributesWithBackendService + = clientAttributes.toBuilder() + .put(AttributeKey.stringKey("grpc.lb.backend_service"), "") + .build(); + + assertThat(openTelemetryTesting.getMetrics()) + .satisfiesExactlyInAnyOrder( + metric -> + assertThat(metric) + .hasName(CLIENT_ATTEMPT_COUNT_INSTRUMENT_NAME) + .hasLongSumSatisfying( + longSum -> longSum.hasPointsSatisfying( + point -> point.hasAttributes(attributes))), + metric -> + assertThat(metric) + .hasName(CLIENT_ATTEMPT_DURATION_INSTRUMENT_NAME) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(clientAttributesWithBackendService))), + metric -> + assertThat(metric) + .hasName(CLIENT_ATTEMPT_SENT_TOTAL_COMPRESSED_MESSAGE_SIZE) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(clientAttributesWithBackendService))), + metric -> + assertThat(metric) + .hasName(CLIENT_ATTEMPT_RECV_TOTAL_COMPRESSED_MESSAGE_SIZE) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(clientAttributesWithBackendService))), + metric -> + assertThat(metric) + .hasName(CLIENT_CALL_DURATION) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttributes(clientAttributes)))); + } + + @Test + public void customLabel_present() { + Map enabledMetrics = ImmutableMap.of( + CLIENT_CALL_HEDGES, true, + CLIENT_CALL_RETRIES, true, + CLIENT_CALL_RETRY_DELAY, true, + CLIENT_CALL_TRANSPARENT_RETRIES, true + ); + String target = "target:///"; + String customValue = "some-random-value"; + CallOptions callOptions = + STREAM_INFO.getCallOptions().withOption(Grpc.CALL_OPTION_CUSTOM_LABEL, customValue); + OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, + enabledMetrics, disableDefaultMetrics); + String customLabel = "grpc.client.call.custom"; + OpenTelemetryMetricsModule module = new OpenTelemetryMetricsModule( + fakeClock.getStopwatchSupplier(), resource, Arrays.asList(customLabel), + emptyList()); + OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = + new CallAttemptsTracerFactory( + module, target, callOptions, method.getFullMethodName(), emptyList(), Context.root()); + + ClientStreamTracer.StreamInfo streamInfo = + STREAM_INFO.toBuilder().setCallOptions(callOptions).build(); + ClientStreamTracer tracer = + callAttemptsTracerFactory.newClientStreamTracer(streamInfo, new Metadata()); + tracer.streamClosed(Status.UNAVAILABLE); + + tracer = callAttemptsTracerFactory.newClientStreamTracer(streamInfo, new Metadata()); + tracer.streamClosed(Status.UNAVAILABLE); + + tracer = callAttemptsTracerFactory.newClientStreamTracer( + streamInfo.toBuilder().setIsTransparentRetry(true).build(), new Metadata()); + tracer.streamClosed(Status.UNAVAILABLE); + + tracer = callAttemptsTracerFactory.newClientStreamTracer( + streamInfo.toBuilder().setIsHedging(true).build(), new Metadata()); + tracer.streamClosed(Status.OK); + callAttemptsTracerFactory.callEnded(Status.OK, callOptions); + + AttributeKey attributeKey = AttributeKey.stringKey(customLabel); + + assertThat(sortByName(openTelemetryTesting.getMetrics())) + .satisfiesExactly( + metric -> assertThat(metric) + .hasName(CLIENT_ATTEMPT_DURATION_INSTRUMENT_NAME) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttribute(attributeKey, customValue), + point -> point.hasAttribute(attributeKey, customValue))), + metric -> assertThat(metric) + .hasName(CLIENT_ATTEMPT_RECV_TOTAL_COMPRESSED_MESSAGE_SIZE) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttribute(attributeKey, customValue), + point -> point.hasAttribute(attributeKey, customValue))), + metric -> assertThat(metric) + .hasName(CLIENT_ATTEMPT_SENT_TOTAL_COMPRESSED_MESSAGE_SIZE) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttribute(attributeKey, customValue), + point -> point.hasAttribute(attributeKey, customValue))), + metric -> assertThat(metric) + .hasName(CLIENT_ATTEMPT_COUNT_INSTRUMENT_NAME) + .hasLongSumSatisfying( + longSum -> longSum.hasPointsSatisfying( + point -> point.hasAttribute(attributeKey, customValue))), + metric -> assertThat(metric) + .hasName(CLIENT_CALL_DURATION) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttribute(attributeKey, customValue))), + metric -> assertThat(metric) + .hasName(CLIENT_CALL_HEDGES) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttribute(attributeKey, customValue))), + metric -> assertThat(metric) + .hasName(CLIENT_CALL_RETRIES) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttribute(attributeKey, customValue))), + metric -> assertThat(metric) + .hasName(CLIENT_CALL_RETRY_DELAY) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttribute(attributeKey, customValue))), + metric -> assertThat(metric) + .hasName(CLIENT_CALL_TRANSPARENT_RETRIES) + .hasHistogramSatisfying( + histogram -> histogram.hasPointsSatisfying( + point -> point.hasAttribute(attributeKey, customValue)))); + } + + @Test + public void serverBasicMetrics() { + OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, + enabledMetricsMap, disableDefaultMetrics); + OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(resource); ServerStreamTracer.Factory tracerFactory = module.getServerTracerFactory(); ServerStreamTracer tracer = tracerFactory.newServerStreamTracer(method.getFullMethodName(), new Metadata()); @@ -923,7 +1676,10 @@ public void serverBasicMetrics() { point .hasCount(1) .hasSum(1028L + 99) - .hasAttributes(serverAttributes))), + .hasAttributes(serverAttributes) + .hasBucketBoundaries(sizeBuckets) + .hasBucketCounts(0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0))), metric -> assertThat(metric) .hasInstrumentationScope(InstrumentationScopeInfo.create( @@ -951,7 +1707,11 @@ public void serverBasicMetrics() { point .hasCount(1) .hasSum(0.1 + 0.016 + 0.024) - .hasAttributes(serverAttributes))), + .hasAttributes(serverAttributes) + .hasBucketBoundaries(latencyBuckets) + .hasBucketCounts(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0))), metric -> assertThat(metric) .hasInstrumentationScope(InstrumentationScopeInfo.create( @@ -968,8 +1728,279 @@ public void serverBasicMetrics() { point .hasCount(1) .hasSum(34L + 154) - .hasAttributes(serverAttributes)))); + .hasAttributes(serverAttributes) + .hasBucketBoundaries(sizeBuckets) + .hasBucketCounts(0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0)))); + + } + + @Test + public void serverMetrics_methodResolvedBeforeStreamClosed_generatedMethodRecordsName() { + OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, + enabledMetricsMap, disableDefaultMetrics); + OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(resource); + ServerStreamTracer.Factory tracerFactory = module.getServerTracerFactory(); + ServerStreamTracer tracer = + tracerFactory.newServerStreamTracer(method.getFullMethodName(), new Metadata()); + + ((ServerCallMethodListener) tracer).serverCallMethodResolved(method); + fakeClock.forwardTime(10, MILLISECONDS); + tracer.streamClosed(Status.CANCELLED); + + io.opentelemetry.api.common.Attributes serverAttributes = + io.opentelemetry.api.common.Attributes.of( + METHOD_KEY, method.getFullMethodName(), + STATUS_KEY, Code.CANCELLED.toString()); + + assertThat(openTelemetryTesting.getMetrics()) + .anySatisfy( + metric -> + assertThat(metric) + .hasName(SERVER_CALL_DURATION) + .hasUnit("s") + .hasHistogramSatisfying( + histogram -> + histogram.hasPointsSatisfying( + point -> + point + .hasCount(1) + .hasSum(0.01) + .hasAttributes(serverAttributes)))); + } + + @Test + public void serverMetrics_methodResolvedBeforeStreamClosed_nonGeneratedMethodRecordsOther() { + MethodDescriptor nonGeneratedMethod = + method.toBuilder().setSampledToLocalTracing(false).build(); + OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, + enabledMetricsMap, disableDefaultMetrics); + OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(resource); + ServerStreamTracer.Factory tracerFactory = module.getServerTracerFactory(); + ServerStreamTracer tracer = + tracerFactory.newServerStreamTracer(nonGeneratedMethod.getFullMethodName(), new Metadata()); + + ((ServerCallMethodListener) tracer).serverCallMethodResolved(nonGeneratedMethod); + fakeClock.forwardTime(10, MILLISECONDS); + tracer.streamClosed(Status.CANCELLED); + + io.opentelemetry.api.common.Attributes serverAttributes = + io.opentelemetry.api.common.Attributes.of( + METHOD_KEY, "other", + STATUS_KEY, Code.CANCELLED.toString()); + + assertThat(openTelemetryTesting.getMetrics()) + .anySatisfy( + metric -> + assertThat(metric) + .hasName(SERVER_CALL_DURATION) + .hasUnit("s") + .hasHistogramSatisfying( + histogram -> + histogram.hasPointsSatisfying( + point -> + point + .hasCount(1) + .hasSum(0.01) + .hasAttributes(serverAttributes)))); + } + + @Test + public void serverMetrics_serverCallStarted_nonGeneratedMethodRecordsOther() { + MethodDescriptor nonGeneratedMethod = + method.toBuilder().setSampledToLocalTracing(false).build(); + OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, + enabledMetricsMap, disableDefaultMetrics); + OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(resource); + ServerStreamTracer.Factory tracerFactory = module.getServerTracerFactory(); + ServerStreamTracer tracer = + tracerFactory.newServerStreamTracer(nonGeneratedMethod.getFullMethodName(), new Metadata()); + tracer.serverCallStarted( + new CallInfo<>(nonGeneratedMethod, Attributes.EMPTY, null)); + + io.opentelemetry.api.common.Attributes startedAttributes = + io.opentelemetry.api.common.Attributes.of(METHOD_KEY, "other"); + + assertThat(openTelemetryTesting.getMetrics()) + .anySatisfy( + metric -> + assertThat(metric) + .hasName(SERVER_CALL_COUNT) + .hasUnit("{call}") + .hasLongSumSatisfying( + longSum -> + longSum.hasPointsSatisfying( + point -> + point + .hasAttributes(startedAttributes) + .hasValue(1)))); + + fakeClock.forwardTime(10, MILLISECONDS); + tracer.streamClosed(Status.CANCELLED); + + io.opentelemetry.api.common.Attributes closedAttributes = + io.opentelemetry.api.common.Attributes.of( + METHOD_KEY, "other", + STATUS_KEY, Code.CANCELLED.toString()); + + assertThat(openTelemetryTesting.getMetrics()) + .anySatisfy( + metric -> + assertThat(metric) + .hasName(SERVER_CALL_DURATION) + .hasUnit("s") + .hasHistogramSatisfying( + histogram -> + histogram.hasPointsSatisfying( + point -> + point + .hasCount(1) + .hasSum(0.01) + .hasAttributes(closedAttributes)))); + } + + @Test + public void targetAttributeFilter_notSet_usesOriginalTarget() { + // Test that when no filter is set, the original target is used + String target = "dns:///example.com"; + OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, + enabledMetricsMap, disableDefaultMetrics); + OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(resource); + + Channel interceptedChannel = + ClientInterceptors.intercept( + grpcServerRule.getChannel(), module.getClientInterceptor(target)); + + ClientCall call = interceptedChannel.newCall(method, CALL_OPTIONS); + + // Make the call + Metadata headers = new Metadata(); + call.start(mockClientCallListener, headers); + + // End the call + call.halfClose(); + call.request(1); + + io.opentelemetry.api.common.Attributes attributes = io.opentelemetry.api.common.Attributes.of( + TARGET_KEY, target, + METHOD_KEY, method.getFullMethodName()); + + assertThat(openTelemetryTesting.getMetrics()) + .anySatisfy( + metric -> + assertThat(metric) + .hasInstrumentationScope(InstrumentationScopeInfo.create( + OpenTelemetryConstants.INSTRUMENTATION_SCOPE)) + .hasName(CLIENT_ATTEMPT_COUNT_INSTRUMENT_NAME) + .hasUnit("{attempt}") + .hasLongSumSatisfying( + longSum -> + longSum + .hasPointsSatisfying( + point -> + point + .hasAttributes(attributes)))); + } + + @Test + public void targetAttributeFilter_allowsTarget_usesOriginalTarget() { + // Test that when filter allows the target, the original target is used + String target = "dns:///example.com"; + OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, + enabledMetricsMap, disableDefaultMetrics); + OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(resource, + t -> t.contains("example.com")); + + Channel interceptedChannel = + ClientInterceptors.intercept( + grpcServerRule.getChannel(), module.getClientInterceptor(target)); + + ClientCall call = interceptedChannel.newCall(method, CALL_OPTIONS); + + // Make the call + Metadata headers = new Metadata(); + call.start(mockClientCallListener, headers); + + // End the call + call.halfClose(); + call.request(1); + + io.opentelemetry.api.common.Attributes attributes = io.opentelemetry.api.common.Attributes.of( + TARGET_KEY, target, + METHOD_KEY, method.getFullMethodName()); + + assertThat(openTelemetryTesting.getMetrics()) + .anySatisfy( + metric -> + assertThat(metric) + .hasInstrumentationScope(InstrumentationScopeInfo.create( + OpenTelemetryConstants.INSTRUMENTATION_SCOPE)) + .hasName(CLIENT_ATTEMPT_COUNT_INSTRUMENT_NAME) + .hasUnit("{attempt}") + .hasLongSumSatisfying( + longSum -> + longSum + .hasPointsSatisfying( + point -> + point + .hasAttributes(attributes)))); + } + + @Test + public void targetAttributeFilter_rejectsTarget_mapsToOther() { + // Test that when filter rejects the target, it is mapped to "other" + String target = "dns:///example.com"; + OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, + enabledMetricsMap, disableDefaultMetrics); + OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(resource, + t -> t.contains("allowed.com")); + + Channel interceptedChannel = + ClientInterceptors.intercept( + grpcServerRule.getChannel(), module.getClientInterceptor(target)); + + ClientCall call = interceptedChannel.newCall(method, CALL_OPTIONS); + + // Make the call + Metadata headers = new Metadata(); + call.start(mockClientCallListener, headers); + + // End the call + call.halfClose(); + call.request(1); + + io.opentelemetry.api.common.Attributes attributes = io.opentelemetry.api.common.Attributes.of( + TARGET_KEY, "other", + METHOD_KEY, method.getFullMethodName()); + + assertThat(openTelemetryTesting.getMetrics()) + .anySatisfy( + metric -> + assertThat(metric) + .hasInstrumentationScope(InstrumentationScopeInfo.create( + OpenTelemetryConstants.INSTRUMENTATION_SCOPE)) + .hasName(CLIENT_ATTEMPT_COUNT_INSTRUMENT_NAME) + .hasUnit("{attempt}") + .hasLongSumSatisfying( + longSum -> + longSum + .hasPointsSatisfying( + point -> + point + .hasAttributes(attributes)))); + } + + private OpenTelemetryMetricsModule newOpenTelemetryMetricsModule( + OpenTelemetryMetricsResource resource) { + return new OpenTelemetryMetricsModule( + fakeClock.getStopwatchSupplier(), resource, emptyList(), emptyList()); + } + private OpenTelemetryMetricsModule newOpenTelemetryMetricsModule( + OpenTelemetryMetricsResource resource, TargetFilter filter) { + return new OpenTelemetryMetricsModule( + fakeClock.getStopwatchSupplier(), resource, emptyList(), emptyList(), + filter); } static class CallInfo extends ServerCallInfo { @@ -1002,4 +2033,130 @@ public String getAuthority() { return authority; } } + + @Test + public void serverMetrics_recordsBaggage() { + DoubleHistogram mockDurationHistogram = mock(DoubleHistogram.class); + OpenTelemetryMetricsResource mockResource = OpenTelemetryMetricsResource.builder() + .serverCallDurationCounter(mockDurationHistogram) + .build(); + + OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(mockResource); + ServerStreamTracer.Factory tracerFactory = module.getServerTracerFactory(); + + Baggage baggage = Baggage.builder() + .put("baggage-key-1", "baggage-val-1") + .build(); + + io.grpc.Context grpcContext = io.grpc.Context.ROOT + .withValue(OpenTelemetryConstants.BAGGAGE_KEY, baggage); + io.grpc.Context previous = grpcContext.attach(); + + ServerStreamTracer tracer; + try { + tracer = tracerFactory.newServerStreamTracer( + method.getFullMethodName(), new Metadata()); + tracer.filterContext(grpcContext); + tracer.serverCallStarted( + new CallInfo<>(method, Attributes.EMPTY, null)); + } finally { + grpcContext.detach(previous); + } + + try (io.opentelemetry.context.Scope scope = Context.root().makeCurrent()) { + tracer.streamClosed(Status.CANCELLED); + } + + ArgumentCaptor contextCaptor = ArgumentCaptor.forClass(Context.class); + verify(mockDurationHistogram).record( + anyDouble(), + any(), + contextCaptor.capture()); + + Baggage capturedBaggage = Baggage.fromContext(contextCaptor.getValue()); + assertNotNull("Captured context should have baggage", capturedBaggage); + assertEquals( + "baggage-val-1", capturedBaggage.getEntryValue("baggage-key-1")); + } + + @Test + public void serverMetrics_recordsBaggage_endToEnd() throws Exception { + DoubleHistogram mockDurationHistogram = mock(DoubleHistogram.class); + OpenTelemetryMetricsResource mockResource = OpenTelemetryMetricsResource.builder() + .serverCallDurationCounter(mockDurationHistogram) + .build(); + + OpenTelemetry openTelemetry = OpenTelemetrySdk + .builder() + .setPropagators(ContextPropagators.create( + W3CBaggagePropagator.getInstance())) + .build(); + + OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(mockResource); + OpenTelemetryTracingModule tracingModule = new OpenTelemetryTracingModule(openTelemetry); + + String serverName = InProcessServerBuilder.generateName(); + InProcessServerBuilder serverBuilder = InProcessServerBuilder + .forName(serverName).directExecutor(); + + serverBuilder.addStreamTracerFactory(tracingModule.getServerTracerFactory()); + serverBuilder.intercept(tracingModule.getServerSpanPropagationInterceptor()); + serverBuilder.addStreamTracerFactory(module.getServerTracerFactory()); + + serverBuilder.addService(ServerServiceDefinition.builder( + ServiceDescriptor.newBuilder("package1.service2") + .addMethod(method) + .build()) + .addMethod(method, new ServerCallHandler() { + @Override + public ServerCall.Listener startCall( + ServerCall call, Metadata headers) { + call.sendHeaders(new Metadata()); + call.sendMessage("response"); + call.close(Status.OK, new Metadata()); + return new ServerCall.Listener() { + }; + } + }).build()); + grpcCleanup.register(serverBuilder.build().start()); + + InProcessChannelBuilder channelBuilder = InProcessChannelBuilder + .forName(serverName).directExecutor(); + channelBuilder.intercept(tracingModule.getClientInterceptor()); + channelBuilder.intercept(module.getClientInterceptor(serverName)); + Channel channel = grpcCleanup.register(channelBuilder.intercept(new ClientInterceptor() { + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + return next.newCall(method, callOptions); + } + }).build()); + + Baggage baggage = Baggage.builder() + .put("baggage-key-1", "baggage-val-1") + .build(); + + Context otelContext = Context.root().with(baggage); + + try (Scope scope = otelContext.makeCurrent()) { + ClientCalls.blockingUnaryCall(channel, + method, CallOptions.DEFAULT, "request"); + } + + ArgumentCaptor contextCaptor = ArgumentCaptor.forClass(Context.class); + verify(mockDurationHistogram).record( + anyDouble(), + any(), + contextCaptor.capture()); + + Baggage capturedBaggage = Baggage.fromContext(contextCaptor.getValue()); + assertNotNull("Captured context should have baggage", capturedBaggage); + assertEquals( + "baggage-val-1", capturedBaggage.getEntryValue("baggage-key-1")); + } + + private static List sortByName(List metrics) { + metrics.sort((m1, m2) -> m1.getName().compareTo(m2.getName())); + return metrics; + } } diff --git a/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryModuleTest.java b/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryModuleTest.java deleted file mode 100644 index 28d3026dd23..00000000000 --- a/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryModuleTest.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright 2023 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.opentelemetry; - -import static com.google.common.truth.Truth.assertThat; - -import io.grpc.internal.GrpcUtil; -import io.opentelemetry.api.OpenTelemetry; -import io.opentelemetry.sdk.OpenTelemetrySdk; -import io.opentelemetry.sdk.metrics.SdkMeterProvider; -import io.opentelemetry.sdk.testing.exporter.InMemoryMetricReader; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -@RunWith(JUnit4.class) -public class OpenTelemetryModuleTest { - private final InMemoryMetricReader inMemoryMetricReader = InMemoryMetricReader.create(); - private final SdkMeterProvider meterProvider = - SdkMeterProvider.builder().registerMetricReader(inMemoryMetricReader).build(); - private final OpenTelemetry noopOpenTelemetry = OpenTelemetry.noop(); - - @Test - public void build() { - OpenTelemetrySdk sdk = - OpenTelemetrySdk.builder().setMeterProvider(meterProvider).build(); - OpenTelemetryModule openTelemetryModule = OpenTelemetryModule.newBuilder() - .sdk(sdk) - .build(); - - assertThat(openTelemetryModule.getOpenTelemetryInstance()).isSameInstanceAs(sdk); - assertThat(openTelemetryModule.getMeterProvider()).isNotNull(); - assertThat(openTelemetryModule.getMeter()).isSameInstanceAs( - meterProvider.meterBuilder("grpc-java") - .setInstrumentationVersion(GrpcUtil.IMPLEMENTATION_VERSION) - .build()); - } - - @Test - public void builderDefaults() { - OpenTelemetryModule module = OpenTelemetryModule.newBuilder().build(); - - assertThat(module.getOpenTelemetryInstance()).isNotNull(); - assertThat(module.getOpenTelemetryInstance()).isSameInstanceAs(noopOpenTelemetry); - assertThat(module.getMeterProvider()).isNotNull(); - assertThat(module.getMeterProvider()) - .isSameInstanceAs(noopOpenTelemetry.getMeterProvider()); - assertThat(module.getMeter()).isSameInstanceAs(noopOpenTelemetry - .getMeterProvider() - .meterBuilder("grpc-java") - .setInstrumentationVersion(GrpcUtil.IMPLEMENTATION_VERSION) - .build()); - } -} diff --git a/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryTracingModuleTest.java b/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryTracingModuleTest.java new file mode 100644 index 00000000000..e6759aadb1e --- /dev/null +++ b/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryTracingModuleTest.java @@ -0,0 +1,878 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.opentelemetry; + +import static io.grpc.ClientStreamTracer.NAME_RESOLUTION_DELAYED; +import static io.grpc.opentelemetry.internal.OpenTelemetryConstants.BAGGAGE_KEY; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableSet; +import io.grpc.Attributes; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.ClientInterceptors; +import io.grpc.ClientStreamTracer; +import io.grpc.KnownLength; +import io.grpc.ManagedChannel; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.NoopServerCall; +import io.grpc.Server; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.ServerInterceptors; +import io.grpc.ServerServiceDefinition; +import io.grpc.ServerStreamTracer; +import io.grpc.Status; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.opentelemetry.OpenTelemetryTracingModule.CallAttemptsTracerFactory; +import io.grpc.opentelemetry.internal.OpenTelemetryConstants; +import io.grpc.testing.GrpcCleanupRule; +import io.grpc.testing.GrpcServerRule; +import io.opentelemetry.api.OpenTelemetry; +import io.opentelemetry.api.baggage.Baggage; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.SpanBuilder; +import io.opentelemetry.api.trace.SpanContext; +import io.opentelemetry.api.trace.SpanId; +import io.opentelemetry.api.trace.StatusCode; +import io.opentelemetry.api.trace.TraceFlags; +import io.opentelemetry.api.trace.TraceId; +import io.opentelemetry.api.trace.TraceState; +import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.api.trace.TracerBuilder; +import io.opentelemetry.api.trace.TracerProvider; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.Scope; +import io.opentelemetry.context.propagation.ContextPropagators; +import io.opentelemetry.context.propagation.TextMapPropagator; +import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; +import io.opentelemetry.sdk.trace.data.EventData; +import io.opentelemetry.sdk.trace.data.SpanData; +import java.io.IOException; +import java.io.InputStream; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.InOrder; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class OpenTelemetryTracingModuleTest { + @Rule + public final MockitoRule mocks = MockitoJUnit.rule(); + + private static final ClientStreamTracer.StreamInfo STREAM_INFO = + ClientStreamTracer.StreamInfo.newBuilder() + .setCallOptions(CallOptions.DEFAULT.withOption(NAME_RESOLUTION_DELAYED, 10L)).build(); + private static final CallOptions.Key CUSTOM_OPTION = + CallOptions.Key.createWithDefault("option1", "default"); + private static final CallOptions CALL_OPTIONS = + CallOptions.DEFAULT.withOption(CUSTOM_OPTION, "customvalue"); + + private static class StringInputStream extends InputStream implements KnownLength { + final String string; + + StringInputStream(String string) { + this.string = string; + } + + @Override + public int read() { + // InProcessTransport doesn't actually read bytes from the InputStream. The InputStream is + // passed to the InProcess server and consumed by MARSHALLER.parse(). + throw new UnsupportedOperationException("Should not be called"); + } + + @Override + public int available() throws IOException { + return string == null ? 0 : string.length(); + } + } + + private static final MethodDescriptor.Marshaller MARSHALLER = + new MethodDescriptor.Marshaller() { + @Override + public InputStream stream(String value) { + return new StringInputStream(value); + } + + @Override + public String parse(InputStream stream) { + return ((StringInputStream) stream).string; + } + }; + + private final MethodDescriptor method = + MethodDescriptor.newBuilder() + .setType(MethodDescriptor.MethodType.UNKNOWN) + .setRequestMarshaller(MARSHALLER) + .setResponseMarshaller(MARSHALLER) + .setFullMethodName("package1.service2/method3") + .build(); + + @Rule + public final OpenTelemetryRule openTelemetryRule = OpenTelemetryRule.create(); + @Rule + public final GrpcServerRule grpcServerRule = new GrpcServerRule().directExecutor(); + @Rule + public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule(); + private Tracer tracerRule; + @Mock + private Tracer mockTracer; + @Mock + TextMapPropagator mockPropagator; + @Mock + private Span mockClientSpan; + @Mock + private Span mockAttemptSpan; + @Mock + private ServerCall.Listener mockServerCallListener; + @Mock + private ClientCall.Listener mockClientCallListener; + @Mock + private SpanBuilder mockSpanBuilder; + @Mock + private OpenTelemetry mockOpenTelemetry; + @Captor + private ArgumentCaptor eventNameCaptor; + @Captor + private ArgumentCaptor attributesCaptor; + @Captor + private ArgumentCaptor statusCaptor; + + @Before + public void setUp() { + tracerRule = openTelemetryRule.getOpenTelemetry().getTracer( + OpenTelemetryConstants.INSTRUMENTATION_SCOPE); + TracerProvider mockTracerProvider = mock(TracerProvider.class); + when(mockOpenTelemetry.getTracerProvider()).thenReturn(mockTracerProvider); + TracerBuilder mockTracerBuilder = mock(TracerBuilder.class); + when(mockTracerProvider.tracerBuilder(OpenTelemetryConstants.INSTRUMENTATION_SCOPE)) + .thenReturn(mockTracerBuilder); + when(mockTracerBuilder.setInstrumentationVersion(any())).thenReturn(mockTracerBuilder); + when(mockTracerBuilder.build()).thenReturn(mockTracer); + when(mockOpenTelemetry.getPropagators()).thenReturn(ContextPropagators.create(mockPropagator)); + when(mockSpanBuilder.startSpan()).thenReturn(mockAttemptSpan); + when(mockSpanBuilder.setParent(any())).thenReturn(mockSpanBuilder); + when(mockTracer.spanBuilder(any())).thenReturn(mockSpanBuilder); + } + + // Use mock instead of OpenTelemetryRule to verify inOrder and propagator. + @Test + public void clientBasicTracingMocking() { + OpenTelemetryTracingModule tracingModule = new OpenTelemetryTracingModule(mockOpenTelemetry); + CallAttemptsTracerFactory callTracer = + tracingModule.newClientCallTracer(mockClientSpan, method); + Metadata headers = new Metadata(); + ClientStreamTracer clientStreamTracer = callTracer.newClientStreamTracer(STREAM_INFO, headers); + clientStreamTracer.createPendingStream(); + clientStreamTracer.streamCreated(Attributes.EMPTY, headers); + + verify(mockTracer).spanBuilder(eq("Attempt.package1.service2.method3")); + verify(mockPropagator).inject(any(), eq(headers), eq(MetadataSetter.getInstance())); + verify(mockClientSpan, never()).end(); + verify(mockAttemptSpan, never()).end(); + + clientStreamTracer.outboundMessage(0); + clientStreamTracer.outboundMessageSent(0, 882, -1); + clientStreamTracer.inboundMessage(0); + clientStreamTracer.outboundMessage(1); + clientStreamTracer.outboundMessageSent(1, -1, 27); + clientStreamTracer.inboundMessageRead(0, 255, 90); + + clientStreamTracer.streamClosed(Status.OK); + callTracer.callEnded(Status.OK); + + InOrder inOrder = inOrder(mockClientSpan, mockAttemptSpan); + inOrder.verify(mockAttemptSpan) + .setAttribute("previous-rpc-attempts", 0); + inOrder.verify(mockAttemptSpan) + .setAttribute("transparent-retry", false); + inOrder.verify(mockClientSpan).addEvent("Delayed name resolution complete"); + inOrder.verify(mockAttemptSpan).addEvent("Delayed LB pick complete"); + inOrder.verify(mockAttemptSpan, times(3)).addEvent( + eventNameCaptor.capture(), attributesCaptor.capture() + ); + List events = eventNameCaptor.getAllValues(); + List attributes = attributesCaptor.getAllValues(); + assertEquals( + "Outbound message" , + events.get(0)); + assertEquals( + io.opentelemetry.api.common.Attributes.builder() + .put("sequence-number", 0) + .put("message-size-compressed", 882) + .build(), + attributes.get(0)); + + assertEquals( + "Outbound message" , + events.get(1)); + assertEquals( + io.opentelemetry.api.common.Attributes.builder() + .put("sequence-number", 1) + .put("message-size", 27) + .build(), + attributes.get(1)); + + assertEquals( + "Inbound compressed message" , + events.get(2)); + assertEquals( + io.opentelemetry.api.common.Attributes.builder() + .put("sequence-number", 0) + .put("message-size-compressed", 255) + .build(), + attributes.get(2)); + + inOrder.verify(mockAttemptSpan).setStatus(StatusCode.OK); + inOrder.verify(mockAttemptSpan).end(); + inOrder.verify(mockClientSpan).setStatus(StatusCode.OK); + inOrder.verify(mockClientSpan).end(); + inOrder.verifyNoMoreInteractions(); + } + + @Test + public void clientBasicTracingRule() { + OpenTelemetryTracingModule tracingModule = new OpenTelemetryTracingModule( + openTelemetryRule.getOpenTelemetry()); + Span clientSpan = tracerRule.spanBuilder("test-client-span").startSpan(); + CallAttemptsTracerFactory callTracer = + tracingModule.newClientCallTracer(clientSpan, method); + Metadata headers = new Metadata(); + ClientStreamTracer clientStreamTracer = callTracer.newClientStreamTracer(STREAM_INFO, headers); + clientStreamTracer.createPendingStream(); + clientStreamTracer.streamCreated(Attributes.EMPTY, headers); + clientStreamTracer.outboundMessage(0); + clientStreamTracer.outboundMessageSent(0, 882, -1); + clientStreamTracer.inboundMessage(0); + clientStreamTracer.outboundMessage(1); + clientStreamTracer.outboundMessageSent(1, -1, 27); + clientStreamTracer.inboundMessageRead(0, 255, -1); + clientStreamTracer.inboundUncompressedSize(288); + clientStreamTracer.inboundMessageRead(1, 128, 128); + clientStreamTracer.inboundMessage(1); + clientStreamTracer.inboundUncompressedSize(128); + + clientStreamTracer.streamClosed(Status.OK); + callTracer.callEnded(Status.OK); + + List spans = openTelemetryRule.getSpans(); + assertEquals(spans.size(), 2); + SpanData attemptSpanData = spans.get(0); + SpanData clientSpanData = spans.get(1); + assertEquals(attemptSpanData.getName(), "Attempt.package1.service2.method3"); + assertEquals(clientSpanData.getName(), "test-client-span"); + assertEquals(headers.keys(), ImmutableSet.of("traceparent")); + String spanContext = headers.get( + Metadata.Key.of("traceparent", Metadata.ASCII_STRING_MARSHALLER)); + assertEquals(spanContext.substring(3, 3 + TraceId.getLength()), + spans.get(1).getSpanContext().getTraceId()); + + // parent(client) span data + List clientSpanEvents = clientSpanData.getEvents(); + assertEquals(clientSpanEvents.size(), 3); + assertEquals( + "Delayed name resolution complete", + clientSpanEvents.get(0).getName()); + assertTrue(clientSpanEvents.get(0).getAttributes().isEmpty()); + + assertEquals( + "Inbound message" , + clientSpanEvents.get(1).getName()); + assertEquals( + io.opentelemetry.api.common.Attributes.builder() + .put("sequence-number", 0) + .put("message-size", 288) + .build(), + clientSpanEvents.get(1).getAttributes()); + + assertEquals( + "Inbound message" , + clientSpanEvents.get(2).getName()); + assertEquals( + io.opentelemetry.api.common.Attributes.builder() + .put("sequence-number", 1) + .put("message-size", 128) + .build(), + clientSpanEvents.get(2).getAttributes()); + assertEquals(clientSpanData.hasEnded(), true); + + // child(attempt) span data + List attemptSpanEvents = attemptSpanData.getEvents(); + assertEquals(clientSpanEvents.size(), 3); + assertEquals( + "Delayed LB pick complete", + attemptSpanEvents.get(0).getName()); + assertTrue(clientSpanEvents.get(0).getAttributes().isEmpty()); + + assertEquals( + "Outbound message" , + attemptSpanEvents.get(1).getName()); + assertEquals( + io.opentelemetry.api.common.Attributes.builder() + .put("sequence-number", 0) + .put("message-size-compressed", 882) + .build(), + attemptSpanEvents.get(1).getAttributes()); + + assertEquals( + "Outbound message" , + attemptSpanEvents.get(2).getName()); + assertEquals( + io.opentelemetry.api.common.Attributes.builder() + .put("sequence-number", 1) + .put("message-size", 27) + .build(), + attemptSpanEvents.get(2).getAttributes()); + + assertEquals( + "Inbound compressed message" , + attemptSpanEvents.get(3).getName()); + assertEquals( + io.opentelemetry.api.common.Attributes.builder() + .put("sequence-number", 0) + .put("message-size-compressed", 255) + .build(), + attemptSpanEvents.get(3).getAttributes()); + + assertEquals(attemptSpanData.hasEnded(), true); + } + + @Test + public void clientInterceptor() { + testClientInterceptors(false); + } + + @Test + public void clientInterceptorNonDefaultOtelContext() { + testClientInterceptors(true); + } + + private void testClientInterceptors(boolean nonDefaultOtelContext) { + final AtomicReference capturedMetadata = new AtomicReference<>(); + grpcServerRule.getServiceRegistry().addService( + ServerServiceDefinition.builder("package1.service2").addMethod( + method, new ServerCallHandler() { + @Override + public ServerCall.Listener startCall( + ServerCall call, Metadata headers) { + capturedMetadata.set(headers); + call.sendHeaders(new Metadata()); + call.sendMessage("Hello"); + call.close( + Status.PERMISSION_DENIED.withDescription("No you don't"), new Metadata()); + return mockServerCallListener; + } + }).build()); + + final AtomicReference capturedCallOptions = new AtomicReference<>(); + ClientInterceptor callOptionsCaptureInterceptor = new ClientInterceptor() { + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + capturedCallOptions.set(callOptions); + return next.newCall(method, callOptions); + } + }; + OpenTelemetryTracingModule tracingModule = new OpenTelemetryTracingModule( + openTelemetryRule.getOpenTelemetry()); + Channel interceptedChannel = + ClientInterceptors.intercept( + grpcServerRule.getChannel(), callOptionsCaptureInterceptor, + tracingModule.getClientInterceptor()); + Span parentSpan = tracerRule.spanBuilder("test-parent-span").startSpan(); + ClientCall call; + + if (nonDefaultOtelContext) { + try (Scope scope = io.opentelemetry.context.Context.current().with(parentSpan) + .makeCurrent()) { + call = interceptedChannel.newCall(method, CALL_OPTIONS); + } + } else { + call = interceptedChannel.newCall(method, CALL_OPTIONS); + } + assertEquals("customvalue", capturedCallOptions.get().getOption(CUSTOM_OPTION)); + assertEquals(1, capturedCallOptions.get().getStreamTracerFactories().size()); + assertTrue( + capturedCallOptions.get().getStreamTracerFactories().get(0) + instanceof CallAttemptsTracerFactory); + + // Make the call + Metadata headers = new Metadata(); + call.start(mockClientCallListener, headers); + + // End the call + call.halfClose(); + call.request(1); + parentSpan.end(); + + verify(mockClientCallListener).onClose(statusCaptor.capture(), any(Metadata.class)); + Status status = statusCaptor.getValue(); + assertEquals(Status.Code.PERMISSION_DENIED, status.getCode()); + assertEquals("No you don't", status.getDescription()); + + List spans = openTelemetryRule.getSpans(); + assertEquals(spans.size(), 3); + + SpanData clientSpan = spans.get(1); + SpanData attemptSpan = spans.get(0); + if (nonDefaultOtelContext) { + assertEquals(clientSpan.getParentSpanContext(), parentSpan.getSpanContext()); + } else { + assertEquals(clientSpan.getParentSpanContext(), + Span.fromContext(Context.root()).getSpanContext()); + } + String spanContext = capturedMetadata.get().get( + Metadata.Key.of("traceparent", Metadata.ASCII_STRING_MARSHALLER)); + // W3C format: 00--- + assertEquals(spanContext.substring(3, 3 + TraceId.getLength()), + attemptSpan.getSpanContext().getTraceId()); + assertEquals(spanContext.substring(3 + TraceId.getLength() + 1, + 3 + TraceId.getLength() + 1 + SpanId.getLength()), + attemptSpan.getSpanContext().getSpanId()); + + assertEquals(attemptSpan.getParentSpanContext(), clientSpan.getSpanContext()); + assertTrue(clientSpan.hasEnded()); + assertEquals(clientSpan.getStatus().getStatusCode(), StatusCode.ERROR); + assertEquals(clientSpan.getStatus().getDescription(), "PERMISSION_DENIED: No you don't"); + assertTrue(attemptSpan.hasEnded()); + assertTrue(attemptSpan.hasEnded()); + assertEquals(attemptSpan.getStatus().getStatusCode(), StatusCode.ERROR); + assertEquals(attemptSpan.getStatus().getDescription(), "PERMISSION_DENIED: No you don't"); + } + + @Test + public void clientStreamNeverCreatedStillRecordTracing() { + OpenTelemetryTracingModule tracingModule = new OpenTelemetryTracingModule( + openTelemetryRule.getOpenTelemetry()); + CallAttemptsTracerFactory callTracer = + tracingModule.newClientCallTracer(mockClientSpan, method); + + callTracer.callEnded(Status.DEADLINE_EXCEEDED.withDescription("3 seconds")); + verify(mockClientSpan).end(); + verify(mockClientSpan).setStatus(eq(StatusCode.ERROR), + eq("DEADLINE_EXCEEDED: 3 seconds")); + verifyNoMoreInteractions(mockClientSpan); + } + + @Test + public void serverBasicTracingNoHeaders() { + OpenTelemetryTracingModule tracingModule = new OpenTelemetryTracingModule( + openTelemetryRule.getOpenTelemetry()); + ServerStreamTracer.Factory tracerFactory = tracingModule.getServerTracerFactory(); + ServerStreamTracer serverStreamTracer = + tracerFactory.newServerStreamTracer(method.getFullMethodName(), new Metadata()); + assertSame(Span.fromContext(Context.current()), Span.getInvalid()); + + serverStreamTracer.outboundMessage(0); + serverStreamTracer.outboundMessageSent(0, 882, 998); + serverStreamTracer.inboundMessage(0); + serverStreamTracer.outboundMessage(1); + serverStreamTracer.outboundMessageSent(1, -1, 27); + serverStreamTracer.inboundMessageRead(0, 90, -1); + serverStreamTracer.inboundUncompressedSize(255); + + serverStreamTracer.streamClosed(Status.CANCELLED); + + List spans = openTelemetryRule.getSpans(); + assertEquals(spans.size(), 1); + assertEquals(spans.get(0).getName(), "Recv.package1.service2.method3"); + assertEquals(spans.get(0).getParentSpanContext(), Span.getInvalid().getSpanContext()); + + List events = spans.get(0).getEvents(); + assertEquals(events.size(), 4); + assertEquals( + "Outbound message" , + events.get(0).getName()); + assertEquals( + io.opentelemetry.api.common.Attributes.builder() + .put("sequence-number", 0) + .put("message-size-compressed", 882) + .put("message-size", 998) + .build(), + events.get(0).getAttributes()); + + assertEquals( + "Outbound message" , + events.get(1).getName()); + assertEquals( + io.opentelemetry.api.common.Attributes.builder() + .put("sequence-number", 1) + .put("message-size", 27) + .build(), + events.get(1).getAttributes()); + + assertEquals( + "Inbound compressed message" , + events.get(2).getName()); + assertEquals( + io.opentelemetry.api.common.Attributes.builder() + .put("sequence-number", 0) + .put("message-size-compressed", 90) + .build(), + events.get(2).getAttributes()); + + assertEquals( + "Inbound message" , + events.get(3).getName()); + assertEquals( + io.opentelemetry.api.common.Attributes.builder() + .put("sequence-number", 0) + .put("message-size", 255) + .build(), + events.get(3).getAttributes()); + + assertEquals(spans.get(0).hasEnded(), true); + } + + @Test + public void grpcTraceBinPropagator() { + when(mockOpenTelemetry.getPropagators()).thenReturn( + ContextPropagators.create(GrpcTraceBinContextPropagator.defaultInstance())); + ArgumentCaptor contextArgumentCaptor = ArgumentCaptor.forClass(Context.class); + OpenTelemetryTracingModule tracingModule = new OpenTelemetryTracingModule(mockOpenTelemetry); + Span testClientSpan = tracerRule.spanBuilder("test-client-span").startSpan(); + CallAttemptsTracerFactory callTracer = + tracingModule.newClientCallTracer(testClientSpan, method); + Span testAttemptSpan = tracerRule.spanBuilder("test-attempt-span").startSpan(); + when(mockSpanBuilder.startSpan()).thenReturn(testAttemptSpan); + + Metadata headers = new Metadata(); + ClientStreamTracer clientStreamTracer = callTracer.newClientStreamTracer(STREAM_INFO, headers); + clientStreamTracer.streamCreated(Attributes.EMPTY, headers); + clientStreamTracer.streamClosed(Status.CANCELLED); + + Metadata.Key key = Metadata.Key.of( + GrpcTraceBinContextPropagator.GRPC_TRACE_BIN_HEADER, Metadata.BINARY_BYTE_MARSHALLER); + assertTrue(Arrays.equals(BinaryFormat.getInstance().toBytes(testAttemptSpan.getSpanContext()), + headers.get(key) + )); + verify(mockSpanBuilder).setParent(contextArgumentCaptor.capture()); + assertEquals(testClientSpan, Span.fromContext(contextArgumentCaptor.getValue())); + + Span serverSpan = tracerRule.spanBuilder("test-server-span").startSpan(); + when(mockSpanBuilder.startSpan()).thenReturn(serverSpan); + ServerStreamTracer.Factory tracerFactory = tracingModule.getServerTracerFactory(); + ServerStreamTracer serverStreamTracer = + tracerFactory.newServerStreamTracer(method.getFullMethodName(), headers); + serverStreamTracer.streamClosed(Status.CANCELLED); + + verify(mockSpanBuilder, times(2)) + .setParent(contextArgumentCaptor.capture()); + assertEquals(testAttemptSpan.getSpanContext(), + Span.fromContext(contextArgumentCaptor.getValue()).getSpanContext()); + } + + @Test + public void testServerParentSpanPropagation() throws Exception { + final AtomicReference applicationSpan = new AtomicReference<>(); + OpenTelemetryTracingModule tracingModule = new OpenTelemetryTracingModule( + openTelemetryRule.getOpenTelemetry()); + ServerServiceDefinition serviceDefinition = + ServerServiceDefinition.builder("package1.service2").addMethod( + method, new ServerCallHandler() { + @Override + public ServerCall.Listener startCall( + ServerCall call, Metadata headers) { + applicationSpan.set(Span.fromContext(Context.current())); + call.sendHeaders(new Metadata()); + call.sendMessage("Hello"); + call.close( + Status.PERMISSION_DENIED.withDescription("No you don't"), new Metadata()); + return mockServerCallListener; + } + }).build(); + + Server server = InProcessServerBuilder.forName("test-server-span") + .addService( + ServerInterceptors.intercept(serviceDefinition, + tracingModule.getServerSpanPropagationInterceptor())) + .addStreamTracerFactory(tracingModule.getServerTracerFactory()) + .directExecutor().build().start(); + grpcCleanupRule.register(server); + + ManagedChannel channel = InProcessChannelBuilder.forName("test-server-span") + .directExecutor().build(); + grpcCleanupRule.register(channel); + + Span parentSpan = tracerRule.spanBuilder("test-parent-span").startSpan(); + try (Scope scope = Context.current().with(parentSpan).makeCurrent()) { + Channel interceptedChannel = + ClientInterceptors.intercept( + channel, tracingModule.getClientInterceptor()); + ClientCall call = interceptedChannel.newCall(method, CALL_OPTIONS); + Metadata headers = new Metadata(); + call.start(mockClientCallListener, headers); + + // End the call + call.halfClose(); + call.request(1); + parentSpan.end(); + } + + verify(mockClientCallListener).onClose(statusCaptor.capture(), any(Metadata.class)); + Status rpcStatus = statusCaptor.getValue(); + assertEquals(rpcStatus.getCode(), Status.Code.PERMISSION_DENIED); + assertEquals(rpcStatus.getDescription(), "No you don't"); + assertEquals(applicationSpan.get().getSpanContext().getTraceId(), + parentSpan.getSpanContext().getTraceId()); + + List spans = openTelemetryRule.getSpans(); + assertEquals(spans.size(), 4); + SpanData clientSpan = spans.get(2); + SpanData attemptSpan = spans.get(1); + + assertEquals(clientSpan.getName(), "Sent.package1.service2.method3"); + assertTrue(clientSpan.hasEnded()); + assertEquals(clientSpan.getStatus().getStatusCode(), StatusCode.ERROR); + assertEquals(clientSpan.getStatus().getDescription(), "PERMISSION_DENIED: No you don't"); + + assertEquals(attemptSpan.getName(), "Attempt.package1.service2.method3"); + assertTrue(attemptSpan.hasEnded()); + assertEquals(attemptSpan.getStatus().getStatusCode(), StatusCode.ERROR); + assertEquals(attemptSpan.getStatus().getDescription(), "PERMISSION_DENIED: No you don't"); + + SpanData serverSpan = spans.get(0); + assertEquals(serverSpan.getName(), "Recv.package1.service2.method3"); + assertTrue(serverSpan.hasEnded()); + assertEquals(serverSpan.getStatus().getStatusCode(), StatusCode.ERROR); + assertEquals(serverSpan.getStatus().getDescription(), "PERMISSION_DENIED: No you don't"); + } + + @Test + public void serverSpanPropagationInterceptor() throws Exception { + OpenTelemetryTracingModule tracingModule = new OpenTelemetryTracingModule( + openTelemetryRule.getOpenTelemetry()); + Server server = InProcessServerBuilder.forName("test-span-propagation-interceptor") + .directExecutor().build().start(); + grpcCleanupRule.register(server); + final AtomicReference callbackSpan = new AtomicReference<>(); + ServerCall.Listener getContextListener = new ServerCall.Listener() { + @Override + public void onMessage(Integer message) { + callbackSpan.set(Span.fromContext(Context.current())); + } + + @Override + public void onHalfClose() { + callbackSpan.set(Span.fromContext(Context.current())); + } + + @Override + public void onCancel() { + callbackSpan.set(Span.fromContext(Context.current())); + } + + @Override + public void onComplete() { + callbackSpan.set(Span.fromContext(Context.current())); + } + }; + ServerInterceptor interceptor = tracingModule.getServerSpanPropagationInterceptor(); + @SuppressWarnings("unchecked") + ServerCallHandler handler = mock(ServerCallHandler.class); + when(handler.startCall(any(), any())).thenReturn(getContextListener); + ServerCall call = new NoopServerCall<>(); + Metadata metadata = new Metadata(); + ServerCall.Listener listener = interceptor.interceptCall(call, metadata, handler); + verify(handler).startCall(same(call), same(metadata)); + listener.onMessage(1); + assertEquals(callbackSpan.get(), Span.getInvalid()); + listener.onReady(); + assertEquals(callbackSpan.get(), Span.getInvalid()); + listener.onCancel(); + assertEquals(callbackSpan.get(), Span.getInvalid()); + listener.onHalfClose(); + assertEquals(callbackSpan.get(), Span.getInvalid()); + listener.onComplete(); + assertEquals(callbackSpan.get(), Span.getInvalid()); + + Span parentSpan = tracerRule.spanBuilder("parent-span").startSpan(); + io.grpc.Context context = io.grpc.Context.current().withValue( + tracingModule.otelSpan, parentSpan); + io.grpc.Context previous = context.attach(); + try { + listener = interceptor.interceptCall(call, metadata, handler); + verify(handler, times(2)).startCall(same(call), same(metadata)); + listener.onMessage(1); + assertEquals(callbackSpan.get().getSpanContext().getTraceId(), + parentSpan.getSpanContext().getTraceId()); + listener.onReady(); + assertEquals(callbackSpan.get().getSpanContext().getTraceId(), + parentSpan.getSpanContext().getTraceId()); + listener.onCancel(); + assertEquals(callbackSpan.get().getSpanContext().getTraceId(), + parentSpan.getSpanContext().getTraceId()); + listener.onHalfClose(); + assertEquals(callbackSpan.get().getSpanContext().getTraceId(), + parentSpan.getSpanContext().getTraceId()); + listener.onComplete(); + assertEquals(callbackSpan.get().getSpanContext().getTraceId(), + parentSpan.getSpanContext().getTraceId()); + } finally { + context.detach(previous); + } + } + + /** + * Tests that baggage from the initial context is propagated + * to the context active during the next handler's execution. + */ + @Test + public void testBaggageIsPropagatedToHandlerContext() { + // 1. ARRANGE + OpenTelemetryTracingModule tracingModule = new OpenTelemetryTracingModule( + openTelemetryRule.getOpenTelemetry()); + ServerInterceptor interceptor = tracingModule.getServerSpanPropagationInterceptor(); + + // Create mocks for the gRPC call chain + @SuppressWarnings("unchecked") + ServerCallHandler mockHandler = mock(ServerCallHandler.class); + @SuppressWarnings("unchecked") + ServerCall.Listener mockListener = mock(ServerCall.Listener.class); + ServerCall mockCall = new NoopServerCall<>(); + Metadata mockHeaders = new Metadata(); + + // Create a non-null Span (required to pass the first 'if' check) + Span testSpan = Span.wrap( + SpanContext.create("time-period", "star-wars", + TraceFlags.getSampled(), TraceState.getDefault())); + + // Create the test Baggage + Baggage testBaggage = Baggage.builder().put("best-bot", "R2D2").build(); + + // Create the initial gRPC context that the interceptor will read from + io.grpc.Context initialGrpcContext = io.grpc.Context.current() + .withValue(tracingModule.otelSpan, testSpan) + .withValue(BAGGAGE_KEY, testBaggage); + + // This AtomicReference will capture the Baggage from *within* the handler + final AtomicReference capturedBaggage = new AtomicReference<>(); + + // Stub the handler to capture the *current* context when it's called + doAnswer(invocation -> { + // Baggage.current() gets baggage from io.opentelemetry.context.Context.current() + capturedBaggage.set(Baggage.current()); + return mockListener; + }).when(mockHandler).startCall(any(), any()); + + // 2. ACT + // Run the interceptCall method within the prepared context + io.grpc.Context previous = initialGrpcContext.attach(); + try { + interceptor.interceptCall(mockCall, mockHeaders, mockHandler); + } finally { + initialGrpcContext.detach(previous); + } + + // 3. ASSERT + // Verify the next handler was called + verify(mockHandler).startCall(same(mockCall), same(mockHeaders)); + + // Check the baggage that was captured + assertNotNull("Baggage should not be null in handler context", capturedBaggage.get()); + assertEquals("Baggage was not correctly propagated to the handler's context", + "R2D2", capturedBaggage.get().getEntryValue("best-bot")); + } + + /** + * Tests that the interceptor proceeds correctly if baggage is null or empty. + */ + @Test + public void testNullBaggageIsHandledGracefully() { + // 1. ARRANGE + OpenTelemetryTracingModule tracingModule = new OpenTelemetryTracingModule( + openTelemetryRule.getOpenTelemetry()); + ServerInterceptor interceptor = tracingModule.getServerSpanPropagationInterceptor(); + + @SuppressWarnings("unchecked") + ServerCallHandler mockHandler = mock(ServerCallHandler.class); + @SuppressWarnings("unchecked") + ServerCall.Listener mockListener = mock(ServerCall.Listener.class); + ServerCall mockCall = new NoopServerCall<>(); + Metadata mockHeaders = new Metadata(); + + Span testSpan = Span.getInvalid(); // A non-null span + + // No baggage is set in the context + io.grpc.Context initialGrpcContext = io.grpc.Context.current() + .withValue(tracingModule.otelSpan, testSpan); + + final AtomicReference capturedBaggage = new AtomicReference<>(); + + // Stub the handler to capture the *current* context when it's called + doAnswer(invocation -> { + // Baggage.current() gets baggage from io.opentelemetry.context.Context.current() + capturedBaggage.set(Baggage.current()); + return mockListener; + }).when(mockHandler).startCall(any(), any()); + + // 2. ACT + io.grpc.Context previous = initialGrpcContext.attach(); + try { + interceptor.interceptCall(mockCall, mockHeaders, mockHandler); + } finally { + initialGrpcContext.detach(previous); + } + + // 3. ASSERT + verify(mockHandler).startCall(same(mockCall), same(mockHeaders)); + + // Baggage should be null in the downstream context + assertEquals("Baggage should be empty when not provided", + Baggage.empty(), capturedBaggage.get()); + } + + @Test + public void generateTraceSpanName() { + assertEquals( + "Sent.io.grpc.Foo", OpenTelemetryTracingModule.generateTraceSpanName( + false, "io.grpc/Foo")); + assertEquals( + "Recv.io.grpc.Bar", OpenTelemetryTracingModule.generateTraceSpanName( + true, "io.grpc/Bar")); + } +} diff --git a/protobuf-lite/BUILD.bazel b/protobuf-lite/BUILD.bazel index 85cd9669e1f..97a5e492d80 100644 --- a/protobuf-lite/BUILD.bazel +++ b/protobuf-lite/BUILD.bazel @@ -1,3 +1,6 @@ +load("@rules_java//java:defs.bzl", "java_library") +load("@rules_jvm_external//:defs.bzl", "artifact") + java_library( name = "protobuf-lite", srcs = glob([ @@ -6,11 +9,10 @@ java_library( visibility = ["//visibility:public"], deps = [ "//api", - "@com_google_code_findbugs_jsr305//jar", - "@com_google_guava_guava//jar", - "@com_google_j2objc_j2objc_annotations//jar", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.guava:guava"), ] + select({ - ":android": ["@com_google_protobuf_javalite//:protobuf_javalite"], + ":android": ["@com_google_protobuf//:protobuf_javalite"], "//conditions:default": ["@com_google_protobuf//:protobuf_java"], }), ) diff --git a/protobuf-lite/build.gradle b/protobuf-lite/build.gradle index 11a49d4816d..c1e5b51ae35 100644 --- a/protobuf-lite/build.gradle +++ b/protobuf-lite/build.gradle @@ -17,8 +17,16 @@ dependencies { testImplementation project(':grpc-core') - signature libraries.signature.java - signature libraries.signature.android + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } + signature (libraries.signature.android) { + artifact { + extension = "signature" + } + } } tasks.named("jar").configure { @@ -31,7 +39,7 @@ tasks.named("compileTestJava").configure { options.compilerArgs += [ "-Xlint:-cast" ] - options.errorprone.excludedPaths = ".*/build/generated/source/proto/.*" + options.errorprone.excludedPaths = ".*/build/generated/sources/proto/.*" } protobuf { diff --git a/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoLiteUtils.java b/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoLiteUtils.java index 7e33fc67622..ef4b16bd476 100644 --- a/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoLiteUtils.java +++ b/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoLiteUtils.java @@ -89,12 +89,11 @@ public static Marshaller marshaller(T defaultInstance /** * Creates a {@link Marshaller} for protos of the same type as {@code defaultInstance} and a - * custom limit for the recursion depth. Any negative number will leave the limit to its default + * custom limit for the recursion depth. Any negative number will leave the limit as its default * value as defined by the protobuf library. * * @since 1.56.0 */ - @ExperimentalApi("https://github.com/grpc/grpc-java/issues/10108") public static Marshaller marshallerWithRecursionLimit( T defaultInstance, int recursionLimit) { return new MessageMarshaller<>(defaultInstance, recursionLimit); diff --git a/protobuf-lite/src/test/java/io/grpc/protobuf/lite/ProtoLiteUtilsTest.java b/protobuf-lite/src/test/java/io/grpc/protobuf/lite/ProtoLiteUtilsTest.java index 5c25cb3b309..204264b016d 100644 --- a/protobuf-lite/src/test/java/io/grpc/protobuf/lite/ProtoLiteUtilsTest.java +++ b/protobuf-lite/src/test/java/io/grpc/protobuf/lite/ProtoLiteUtilsTest.java @@ -16,6 +16,7 @@ package io.grpc.protobuf.lite; +import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; @@ -43,9 +44,7 @@ import java.io.IOException; import java.io.InputStream; import java.util.Arrays; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -53,9 +52,6 @@ @RunWith(JUnit4.class) public class ProtoLiteUtilsTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); - private final Marshaller marshaller = ProtoLiteUtils.marshaller(Type.getDefaultInstance()); private Type proto = Type.newBuilder().setName("name").build(); @@ -214,10 +210,9 @@ public void metadataMarshaller_invalid() { @Test public void extensionRegistry_notNull() { - thrown.expect(NullPointerException.class); - thrown.expectMessage("newRegistry"); - - ProtoLiteUtils.setExtensionRegistry(null); + NullPointerException e = assertThrows(NullPointerException.class, + () -> ProtoLiteUtils.setExtensionRegistry(null)); + assertThat(e).hasMessageThat().isEqualTo("newRegistry"); } @Test diff --git a/protobuf/BUILD.bazel b/protobuf/BUILD.bazel index 42085eea583..a31f8b6f6f5 100644 --- a/protobuf/BUILD.bazel +++ b/protobuf/BUILD.bazel @@ -1,3 +1,6 @@ +load("@rules_java//java:defs.bzl", "java_library") +load("@rules_jvm_external//:defs.bzl", "artifact") + java_library( name = "protobuf", srcs = glob([ @@ -7,10 +10,10 @@ java_library( deps = [ "//api", "//protobuf-lite", - "@com_google_api_grpc_proto_google_common_protos//jar", - "@com_google_code_findbugs_jsr305//jar", - "@com_google_guava_guava//jar", - "@com_google_j2objc_j2objc_annotations//jar", "@com_google_protobuf//:protobuf_java", + artifact("com.google.api.grpc:proto-google-common-protos"), + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:guava"), ], ) diff --git a/protobuf/build.gradle b/protobuf/build.gradle index c88ae836e0f..c477e41dceb 100644 --- a/protobuf/build.gradle +++ b/protobuf/build.gradle @@ -31,8 +31,16 @@ dependencies { exclude group: 'com.google.protobuf', module: 'protobuf-javalite' } - signature libraries.signature.java - signature libraries.signature.android + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } + signature (libraries.signature.android) { + artifact { + extension = "signature" + } + } } tasks.named("javadoc").configure { diff --git a/protobuf/src/main/java/io/grpc/protobuf/ProtoMethodDescriptorSupplier.java b/protobuf/src/main/java/io/grpc/protobuf/ProtoMethodDescriptorSupplier.java index e5b2f38e3c0..e7cd3ed336f 100644 --- a/protobuf/src/main/java/io/grpc/protobuf/ProtoMethodDescriptorSupplier.java +++ b/protobuf/src/main/java/io/grpc/protobuf/ProtoMethodDescriptorSupplier.java @@ -16,8 +16,8 @@ package io.grpc.protobuf; +import com.google.errorprone.annotations.CheckReturnValue; import com.google.protobuf.Descriptors.MethodDescriptor; -import javax.annotation.CheckReturnValue; /** * Provides access to the underlying proto service method descriptor. diff --git a/protobuf/src/main/java/io/grpc/protobuf/ProtoUtils.java b/protobuf/src/main/java/io/grpc/protobuf/ProtoUtils.java index 933d598996c..d403789eb5f 100644 --- a/protobuf/src/main/java/io/grpc/protobuf/ProtoUtils.java +++ b/protobuf/src/main/java/io/grpc/protobuf/ProtoUtils.java @@ -18,7 +18,6 @@ import com.google.protobuf.ExtensionRegistry; import com.google.protobuf.Message; -import io.grpc.ExperimentalApi; import io.grpc.Metadata; import io.grpc.MethodDescriptor.Marshaller; import io.grpc.protobuf.lite.ProtoLiteUtils; @@ -58,12 +57,11 @@ public static Marshaller marshaller(final T defaultInstan /** * Creates a {@link Marshaller} for protos of the same type as {@code defaultInstance} and a - * custom limit for the recursion depth. Any negative number will leave the limit to its default + * custom limit for the recursion depth. Any negative number will leave the limit as its default * value as defined by the protobuf library. * * @since 1.56.0 */ - @ExperimentalApi("https://github.com/grpc/grpc-java/issues/10108") public static Marshaller marshallerWithRecursionLimit(T defaultInstance, int recursionLimit) { return ProtoLiteUtils.marshallerWithRecursionLimit(defaultInstance, recursionLimit); diff --git a/protobuf/src/main/java/io/grpc/protobuf/StatusProto.java b/protobuf/src/main/java/io/grpc/protobuf/StatusProto.java index 988e1938af0..0ebc1e714f6 100644 --- a/protobuf/src/main/java/io/grpc/protobuf/StatusProto.java +++ b/protobuf/src/main/java/io/grpc/protobuf/StatusProto.java @@ -103,6 +103,25 @@ public static StatusException toStatusException( return toStatus(statusProto).asException(toMetadata(statusProto, metadata)); } + /** + * Convert a {@link com.google.rpc.Status} instance to a {@link StatusException} with additional + * metadata and the root exception thrown. The exception isn't propagated over the wire. + * + *

The returned {@link StatusException} will wrap a {@link Status} whose code and description + * are set from the code and message in {@code statusProto}. {@code statusProto} will be + * serialized and added to {@code metadata}. {@code metadata} will be set as the metadata of the + * returned {@link StatusException}. The {@link Throwable} is the exception that is set as the + * {@code cause} of the returned {@link StatusException}. + * + * @throws IllegalArgumentException if the value of {@code statusProto.getCode()} is not a valid + * gRPC status code. + * @since 1.3.0 + */ + public static StatusException toStatusException( + com.google.rpc.Status statusProto, Metadata metadata, Throwable cause) { + return toStatus(statusProto).withCause(cause).asException(toMetadata(statusProto, metadata)); + } + private static Status toStatus(com.google.rpc.Status statusProto) { Status status = Status.fromCodeValue(statusProto.getCode()); checkArgument(status.getCode().value() == statusProto.getCode(), "invalid status code"); diff --git a/protobuf/src/test/java/io/grpc/protobuf/StatusProtoTest.java b/protobuf/src/test/java/io/grpc/protobuf/StatusProtoTest.java index cf9c2c564ab..47c045bf952 100644 --- a/protobuf/src/test/java/io/grpc/protobuf/StatusProtoTest.java +++ b/protobuf/src/test/java/io/grpc/protobuf/StatusProtoTest.java @@ -176,6 +176,14 @@ public void fromThrowable_shouldReturnNullIfNoEmbeddedStatus() { assertNull(StatusProto.fromThrowable(nestedSe)); } + @Test + public void toStatusExceptionWithMetadataAndCause_shouldCaptureCause() { + RuntimeException exc = new RuntimeException("This is a test exception."); + StatusException se = StatusProto.toStatusException(STATUS_PROTO, new Metadata(), exc); + + assertEquals(exc, se.getCause()); + } + private static final Metadata.Key METADATA_KEY = Metadata.Key.of("test-metadata", Metadata.ASCII_STRING_MARSHALLER); private static final String METADATA_VALUE = "test metadata value"; diff --git a/repositories.bzl b/repositories.bzl index 702daaf4be2..9691a12f286 100644 --- a/repositories.bzl +++ b/repositories.bzl @@ -9,43 +9,47 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") # # Your own deps # ] + IO_GRPC_GRPC_JAVA_ARTIFACTS, # ) +# GRPC_DEPS_START IO_GRPC_GRPC_JAVA_ARTIFACTS = [ "com.google.android:annotations:4.1.1.4", - "com.google.api.grpc:proto-google-common-protos:2.29.0", - "com.google.auth:google-auth-library-credentials:1.22.0", - "com.google.auth:google-auth-library-oauth2-http:1.22.0", - "com.google.auto.value:auto-value-annotations:1.10.4", - "com.google.auto.value:auto-value:1.10.4", + "com.google.api.grpc:proto-google-common-protos:2.64.1", + "com.google.auth:google-auth-library-credentials:1.42.1", + "com.google.auth:google-auth-library-oauth2-http:1.42.1", + "com.google.auto.value:auto-value-annotations:1.11.0", + "com.google.auto.value:auto-value:1.11.0", "com.google.code.findbugs:jsr305:3.0.2", - "com.google.code.gson:gson:2.10.1", - "com.google.errorprone:error_prone_annotations:2.23.0", + "com.google.code.gson:gson:2.13.2", + "com.google.errorprone:error_prone_annotations:2.48.0", "com.google.guava:failureaccess:1.0.1", - "com.google.guava:guava:32.1.3-android", - "com.google.re2j:re2j:1.7", - "com.google.truth:truth:1.1.5", + "com.google.guava:guava:33.5.0-android", + "com.google.re2j:re2j:1.8", + "com.google.s2a.proto.v2:s2a-proto:0.1.3", + "com.google.truth:truth:1.4.5", "com.squareup.okhttp:okhttp:2.7.5", "com.squareup.okio:okio:2.10.0", # 3.0+ needs swapping to -jvm; need work to avoid flag-day - "io.netty:netty-buffer:4.1.100.Final", - "io.netty:netty-codec-http2:4.1.100.Final", - "io.netty:netty-codec-http:4.1.100.Final", - "io.netty:netty-codec-socks:4.1.100.Final", - "io.netty:netty-codec:4.1.100.Final", - "io.netty:netty-common:4.1.100.Final", - "io.netty:netty-handler-proxy:4.1.100.Final", - "io.netty:netty-handler:4.1.100.Final", - "io.netty:netty-resolver:4.1.100.Final", - "io.netty:netty-tcnative-boringssl-static:2.0.61.Final", - "io.netty:netty-tcnative-classes:2.0.61.Final", - "io.netty:netty-transport-native-epoll:jar:linux-x86_64:4.1.100.Final", - "io.netty:netty-transport-native-unix-common:4.1.100.Final", - "io.netty:netty-transport:4.1.100.Final", + "io.netty:netty-buffer:4.1.133.Final", + "io.netty:netty-codec-http2:4.1.133.Final", + "io.netty:netty-codec-http:4.1.133.Final", + "io.netty:netty-codec-socks:4.1.133.Final", + "io.netty:netty-codec:4.1.133.Final", + "io.netty:netty-common:4.1.133.Final", + "io.netty:netty-handler-proxy:4.1.133.Final", + "io.netty:netty-handler:4.1.133.Final", + "io.netty:netty-resolver:4.1.133.Final", + "io.netty:netty-tcnative-boringssl-static:2.0.75.Final", + "io.netty:netty-tcnative-classes:2.0.75.Final", + "io.netty:netty-transport-native-epoll:jar:linux-x86_64:4.1.133.Final", + "io.netty:netty-transport-native-unix-common:4.1.133.Final", + "io.netty:netty-transport:4.1.133.Final", "io.opencensus:opencensus-api:0.31.0", "io.opencensus:opencensus-contrib-grpc-metrics:0.31.0", - "io.perfmark:perfmark-api:0.26.0", + "io.perfmark:perfmark-api:0.27.0", "junit:junit:4.13.2", - "org.apache.tomcat:annotations-api:6.0.53", - "org.codehaus.mojo:animal-sniffer-annotations:1.23", + "org.mockito:mockito-core:4.4.0", + "org.checkerframework:checker-qual:3.49.5", + "org.codehaus.mojo:animal-sniffer-annotations:1.27", ] +# GRPC_DEPS_END # For use with maven_install's override_targets. # maven_install( @@ -60,7 +64,7 @@ IO_GRPC_GRPC_JAVA_ARTIFACTS = [ IO_GRPC_GRPC_JAVA_OVERRIDE_TARGETS = { "com.google.protobuf:protobuf-java": "@com_google_protobuf//:protobuf_java", "com.google.protobuf:protobuf-java-util": "@com_google_protobuf//:protobuf_java_util", - "com.google.protobuf:protobuf-javalite": "@com_google_protobuf_javalite//:protobuf_javalite", + "com.google.protobuf:protobuf-javalite": "@com_google_protobuf//:protobuf_javalite", "io.grpc:grpc-alts": "@io_grpc_grpc_java//alts", "io.grpc:grpc-api": "@io_grpc_grpc_java//api", "io.grpc:grpc-auth": "@io_grpc_grpc_java//auth", @@ -78,6 +82,7 @@ IO_GRPC_GRPC_JAVA_OVERRIDE_TARGETS = { "io.grpc:grpc-rls": "@io_grpc_grpc_java//rls", "io.grpc:grpc-services": "@io_grpc_grpc_java//services:services_maven", "io.grpc:grpc-stub": "@io_grpc_grpc_java//stub", + "io.grpc:grpc-s2a": "@io_grpc_grpc_java//s2a", "io.grpc:grpc-testing": "@io_grpc_grpc_java//testing", "io.grpc:grpc-xds": "@io_grpc_grpc_java//xds:xds_maven", "io.grpc:grpc-util": "@io_grpc_grpc_java//util", @@ -85,56 +90,25 @@ IO_GRPC_GRPC_JAVA_OVERRIDE_TARGETS = { def grpc_java_repositories(): """Imports dependencies for grpc-java.""" - if not native.existing_rule("com_github_cncf_xds"): - http_archive( - name = "com_github_cncf_xds", - strip_prefix = "xds-e9ce68804cb4e64cab5a52e3c8baf840d4ff87b7", - sha256 = "0d33b83f8c6368954e72e7785539f0d272a8aba2f6e2e336ed15fd1514bc9899", - urls = [ - "https://github.com/cncf/xds/archive/e9ce68804cb4e64cab5a52e3c8baf840d4ff87b7.tar.gz", - ], - ) - if not native.existing_rule("com_github_grpc_grpc"): - http_archive( - name = "com_github_grpc_grpc", - strip_prefix = "grpc-1.46.0", - sha256 = "67423a4cd706ce16a88d1549297023f0f9f0d695a96dd684adc21e67b021f9bc", - urls = [ - "https://github.com/grpc/grpc/archive/v1.46.0.tar.gz", - ], - ) if not native.existing_rule("com_google_protobuf"): com_google_protobuf() - if not native.existing_rule("com_google_protobuf_javalite"): - com_google_protobuf_javalite() if not native.existing_rule("com_google_googleapis"): http_archive( name = "com_google_googleapis", - sha256 = "49930468563dd48283e8301e8d4e71436bf6d27ac27c235224cc1a098710835d", - strip_prefix = "googleapis-ca1372c6d7bcb199638ebfdb40d2b2660bab7b88", + sha256 = "397fd8eb8a1a62dcf144216d9775816fad7a3fcff0ced1614bee529003c30d9e", + strip_prefix = "googleapis-1dbb1a14e079f78d9214f8e48bf083f32e3ddb96", urls = [ - "https://github.com/googleapis/googleapis/archive/ca1372c6d7bcb199638ebfdb40d2b2660bab7b88.tar.gz", - ], - ) - if not native.existing_rule("io_bazel_rules_go"): - http_archive( - name = "io_bazel_rules_go", - sha256 = "ab21448cef298740765f33a7f5acee0607203e4ea321219f2a4c85a6e0fb0a27", - urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.32.0/rules_go-v0.32.0.zip", - "https://github.com/bazelbuild/rules_go/releases/download/v0.32.0/rules_go-v0.32.0.zip", + "https://github.com/googleapis/googleapis/archive/1dbb1a14e079f78d9214f8e48bf083f32e3ddb96.tar.gz", ], ) if not native.existing_rule("io_grpc_grpc_proto"): io_grpc_grpc_proto() - if not native.existing_rule("envoy_api"): + if not native.existing_rule("bazel_jar_jar"): http_archive( - name = "envoy_api", - sha256 = "b426904abf51ba21dd8947a05694bb3c861d6f5e436e4673e74d7d7bfb6d3188", - strip_prefix = "data-plane-api-268824e4eee3d7770a347a5dc5aaddc0b1b14e24", - urls = [ - "https://github.com/envoyproxy/data-plane-api/archive/268824e4eee3d7770a347a5dc5aaddc0b1b14e24.tar.gz", - ], + name = "bazel_jar_jar", + sha256 = "3117f913c732142a795551f530d02c9157b9ea895e6b2de0fbb5af54f03040a5", + strip_prefix = "bazel_jar_jar-0.1.6", + url = "https://github.com/bazeltools/bazel_jar_jar/releases/download/v0.1.6/bazel_jar_jar-v0.1.6.tar.gz", ) def com_google_protobuf(): @@ -143,24 +117,15 @@ def com_google_protobuf(): # This statement defines the @com_google_protobuf repo. http_archive( name = "com_google_protobuf", - sha256 = "9bd87b8280ef720d3240514f884e56a712f2218f0d693b48050c836028940a42", - strip_prefix = "protobuf-25.1", - urls = ["https://github.com/protocolbuffers/protobuf/releases/download/v25.1/protobuf-25.1.tar.gz"], - ) - -def com_google_protobuf_javalite(): - # java_lite_proto_library rules implicitly depend on @com_google_protobuf_javalite - http_archive( - name = "com_google_protobuf_javalite", - sha256 = "9bd87b8280ef720d3240514f884e56a712f2218f0d693b48050c836028940a42", - strip_prefix = "protobuf-25.1", - urls = ["https://github.com/protocolbuffers/protobuf/releases/download/v25.1/protobuf-25.1.tar.gz"], + sha256 = "bc670a4e34992c175137ddda24e76562bb928f849d712a0e3c2fb2e19249bea1", + strip_prefix = "protobuf-33.4", + urls = ["https://github.com/protocolbuffers/protobuf/releases/download/v33.4/protobuf-33.4.tar.gz"], ) def io_grpc_grpc_proto(): http_archive( name = "io_grpc_grpc_proto", - sha256 = "464e97a24d7d784d9c94c25fa537ba24127af5aae3edd381007b5b98705a0518", - strip_prefix = "grpc-proto-08911e9d585cbda3a55eb1dcc4b99c89aebccff8", - urls = ["https://github.com/grpc/grpc-proto/archive/08911e9d585cbda3a55eb1dcc4b99c89aebccff8.zip"], + sha256 = "729ac127a003836d539ed9da72a21e094aac4c4609e0481d6fc9e28a844e11af", + strip_prefix = "grpc-proto-4f245d272a28a680606c0739753506880cf33b5f", + urls = ["https://github.com/grpc/grpc-proto/archive/4f245d272a28a680606c0739753506880cf33b5f.zip"], ) diff --git a/rls/BUILD.bazel b/rls/BUILD.bazel index c67c7cd56be..70c17a9c8b6 100644 --- a/rls/BUILD.bazel +++ b/rls/BUILD.bazel @@ -1,3 +1,5 @@ +load("@rules_java//java:defs.bzl", "java_library") +load("@rules_jvm_external//:defs.bzl", "artifact") load("//:java_grpc_library.bzl", "java_grpc_library") java_library( @@ -12,13 +14,14 @@ java_library( "//api", "//core", "//core:internal", - "//util", "//stub", - "@com_google_auto_value_auto_value_annotations//jar", - "@com_google_code_findbugs_jsr305//jar", - "@com_google_guava_guava//jar", + "//util", "@io_grpc_grpc_proto//:rls_config_java_proto", "@io_grpc_grpc_proto//:rls_java_proto", + artifact("com.google.auto.value:auto-value-annotations"), + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:guava"), ], ) diff --git a/rls/build.gradle b/rls/build.gradle index 9e38ef3d868..10b1d5fc371 100644 --- a/rls/build.gradle +++ b/rls/build.gradle @@ -22,14 +22,18 @@ dependencies { libraries.auto.value.annotations, libraries.guava annotationProcessor libraries.auto.value - compileOnly libraries.javax.annotation testImplementation libraries.truth, project(':grpc-grpclb'), project(':grpc-inprocess'), project(':grpc-testing'), project(':grpc-testing-proto'), + testFixtures(project(':grpc-api')), testFixtures(project(':grpc-core')) - signature libraries.signature.java + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } } tasks.named("compileJava").configure { @@ -45,7 +49,7 @@ tasks.named("compileJava").configure { tasks.named("javadoc").configure { // Do not publish javadoc since currently there is no public API. - failOnError false // no public or protected classes found to document + failOnError = false // no public or protected classes found to document exclude 'io/grpc/lookup/v1/**' exclude 'io/grpc/rls/*Provider.java' exclude 'io/grpc/rls/internal/**' diff --git a/rls/src/generated/main/grpc/io/grpc/lookup/v1/RouteLookupServiceGrpc.java b/rls/src/generated/main/grpc/io/grpc/lookup/v1/RouteLookupServiceGrpc.java index d7334b942ff..be060e576a4 100644 --- a/rls/src/generated/main/grpc/io/grpc/lookup/v1/RouteLookupServiceGrpc.java +++ b/rls/src/generated/main/grpc/io/grpc/lookup/v1/RouteLookupServiceGrpc.java @@ -4,9 +4,6 @@ /** */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/lookup/v1/rls.proto") @io.grpc.stub.annotations.GrpcGenerated public final class RouteLookupServiceGrpc { @@ -60,6 +57,21 @@ public RouteLookupServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptio return RouteLookupServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static RouteLookupServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public RouteLookupServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new RouteLookupServiceBlockingV2Stub(channel, callOptions); + } + }; + return RouteLookupServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -147,6 +159,33 @@ public void routeLookup(io.grpc.lookup.v1.RouteLookupRequest request, /** * A stub to allow clients to do synchronous rpc calls to service RouteLookupService. */ + public static final class RouteLookupServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private RouteLookupServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected RouteLookupServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new RouteLookupServiceBlockingV2Stub(channel, callOptions); + } + + /** + *

+     * Lookup returns a target for a single key.
+     * 
+ */ + public io.grpc.lookup.v1.RouteLookupResponse routeLookup(io.grpc.lookup.v1.RouteLookupRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getRouteLookupMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service RouteLookupService. + */ public static final class RouteLookupServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private RouteLookupServiceBlockingStub( diff --git a/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java b/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java index 19fe2430934..a2846fd04c8 100644 --- a/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java +++ b/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java @@ -24,28 +24,36 @@ import com.google.common.base.MoreObjects; import com.google.common.base.MoreObjects.ToStringHelper; import com.google.common.base.Ticker; +import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.SettableFuture; +import com.google.errorprone.annotations.CheckReturnValue; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.ChannelLogger; import io.grpc.ChannelLogger.ChannelLogLevel; import io.grpc.ConnectivityState; +import io.grpc.Grpc; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.ResolvedAddresses; import io.grpc.LoadBalancer.SubchannelPicker; +import io.grpc.LongCounterMetricInstrument; +import io.grpc.LongGaugeMetricInstrument; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; import io.grpc.Metadata; +import io.grpc.MetricInstrumentRegistry; +import io.grpc.MetricRecorder.BatchCallback; +import io.grpc.MetricRecorder.BatchRecorder; +import io.grpc.MetricRecorder.Registration; import io.grpc.Status; -import io.grpc.SynchronizationContext; -import io.grpc.SynchronizationContext.ScheduledHandle; import io.grpc.internal.BackoffPolicy; import io.grpc.internal.ExponentialBackoffPolicy; import io.grpc.lookup.v1.RouteLookupServiceGrpc; import io.grpc.lookup.v1.RouteLookupServiceGrpc.RouteLookupServiceStub; import io.grpc.rls.ChildLoadBalancerHelper.ChildLoadBalancerHelperProvider; -import io.grpc.rls.LbPolicyConfiguration.ChildLbStatusListener; import io.grpc.rls.LbPolicyConfiguration.ChildPolicyWrapper; import io.grpc.rls.LbPolicyConfiguration.RefCountedChildPolicyWrapperFactory; import io.grpc.rls.LruCache.EvictionListener; @@ -53,20 +61,22 @@ import io.grpc.rls.RlsProtoConverters.RouteLookupResponseConverter; import io.grpc.rls.RlsProtoData.RouteLookupConfig; import io.grpc.rls.RlsProtoData.RouteLookupRequest; +import io.grpc.rls.RlsProtoData.RouteLookupRequestKey; import io.grpc.rls.RlsProtoData.RouteLookupResponse; -import io.grpc.rls.Throttler.ThrottledException; import io.grpc.stub.StreamObserver; import io.grpc.util.ForwardingLoadBalancerHelper; import java.net.URI; import java.net.URISyntaxException; +import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.UUID; +import java.util.concurrent.Future; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; /** @@ -87,16 +97,24 @@ final class CachingRlsLbClient { /** Minimum bytes for a Java Object. */ public static final int OBJ_OVERHEAD_B = 16; + private static final LongCounterMetricInstrument DEFAULT_TARGET_PICKS_COUNTER; + private static final LongCounterMetricInstrument TARGET_PICKS_COUNTER; + private static final LongCounterMetricInstrument FAILED_PICKS_COUNTER; + private static final LongGaugeMetricInstrument CACHE_ENTRIES_GAUGE; + private static final LongGaugeMetricInstrument CACHE_SIZE_GAUGE; + private final Registration gaugeRegistration; + private final String metricsInstanceUuid = UUID.randomUUID().toString(); + // All cache status changes (pending, backoff, success) must be under this lock private final Object lock = new Object(); // LRU cache based on access order (BACKOFF and actual data will be here) @GuardedBy("lock") private final RlsAsyncLruCache linkedHashLruCache; + private final Future periodicCleaner; // any RPC on the fly will cached in this map @GuardedBy("lock") - private final Map pendingCallCache = new HashMap<>(); + private final Map pendingCallCache = new HashMap<>(); - private final SynchronizationContext synchronizationContext; private final ScheduledExecutorService scheduledExecutorService; private final Ticker ticker; private final Throttler throttler; @@ -112,13 +130,47 @@ final class CachingRlsLbClient { private final RouteLookupServiceStub rlsStub; private final RlsPicker rlsPicker; private final ResolvedAddressFactory childLbResolvedAddressFactory; + @GuardedBy("lock") private final RefCountedChildPolicyWrapperFactory refCountedChildPolicyWrapperFactory; private final ChannelLogger logger; + private final ChildPolicyWrapper fallbackChildPolicyWrapper; + + static { + MetricInstrumentRegistry metricInstrumentRegistry + = MetricInstrumentRegistry.getDefaultRegistry(); + DEFAULT_TARGET_PICKS_COUNTER = metricInstrumentRegistry.registerLongCounter( + "grpc.lb.rls.default_target_picks", + "EXPERIMENTAL. Number of LB picks sent to the default target", "{pick}", + Arrays.asList("grpc.target", "grpc.lb.rls.server_target", + "grpc.lb.rls.data_plane_target", "grpc.lb.pick_result"), + Arrays.asList("grpc.client.call.custom"), + false); + TARGET_PICKS_COUNTER = metricInstrumentRegistry.registerLongCounter("grpc.lb.rls.target_picks", + "EXPERIMENTAL. Number of LB picks sent to each RLS target. Note that if the default " + + "target is also returned by the RLS server, RPCs sent to that target from the cache " + + "will be counted in this metric, not in grpc.rls.default_target_picks.", "{pick}", + Arrays.asList("grpc.target", "grpc.lb.rls.server_target", "grpc.lb.rls.data_plane_target", + "grpc.lb.pick_result"), + Arrays.asList("grpc.client.call.custom"), + false); + FAILED_PICKS_COUNTER = metricInstrumentRegistry.registerLongCounter("grpc.lb.rls.failed_picks", + "EXPERIMENTAL. Number of LB picks failed due to either a failed RLS request or the " + + "RLS channel being throttled", "{pick}", + Arrays.asList("grpc.target", "grpc.lb.rls.server_target"), + Arrays.asList("grpc.client.call.custom"), false); + CACHE_ENTRIES_GAUGE = metricInstrumentRegistry.registerLongGauge("grpc.lb.rls.cache_entries", + "EXPERIMENTAL. Number of entries in the RLS cache", "{entry}", + Arrays.asList("grpc.target", "grpc.lb.rls.server_target", "grpc.lb.rls.instance_uuid"), + Collections.emptyList(), false); + CACHE_SIZE_GAUGE = metricInstrumentRegistry.registerLongGauge("grpc.lb.rls.cache_size", + "EXPERIMENTAL. The current size of the RLS cache", "By", + Arrays.asList("grpc.target", "grpc.lb.rls.server_target", "grpc.lb.rls.instance_uuid"), + Collections.emptyList(), false); + } private CachingRlsLbClient(Builder builder) { helper = new RlsLbHelper(checkNotNull(builder.helper, "helper")); scheduledExecutorService = helper.getScheduledExecutorService(); - synchronizationContext = helper.getSynchronizationContext(); lbPolicyConfig = checkNotNull(builder.lbPolicyConfig, "lbPolicyConfig"); RouteLookupConfig rlsConfig = lbPolicyConfig.getRouteLookupConfig(); maxAgeNanos = rlsConfig.maxAgeInNanos(); @@ -129,10 +181,11 @@ private CachingRlsLbClient(Builder builder) { linkedHashLruCache = new RlsAsyncLruCache( rlsConfig.cacheSizeBytes(), - builder.evictionListener, - scheduledExecutorService, + new AutoCleaningEvictionListener(builder.evictionListener), ticker, - lock); + helper); + periodicCleaner = + scheduledExecutorService.scheduleAtFixedRate(this::periodicClean, 1, 1, TimeUnit.MINUTES); logger = helper.getChannelLogger(); String serverHost = null; try { @@ -147,7 +200,7 @@ private CachingRlsLbClient(Builder builder) { } RlsRequestFactory requestFactory = new RlsRequestFactory( lbPolicyConfig.getRouteLookupConfig(), serverHost); - rlsPicker = new RlsPicker(requestFactory); + rlsPicker = new RlsPicker(requestFactory, rlsConfig.lookupService()); // It is safe to use helper.getUnsafeChannelCredentials() because the client authenticates the // RLS server using the same authority as the backends, even though the RLS server’s addresses // will be looked up differently than the backends; overrideAuthority(helper.getAuthority()) is @@ -166,7 +219,35 @@ private CachingRlsLbClient(Builder builder) { rlsChannelBuilder.disableServiceConfigLookUp(); } rlsChannel = rlsChannelBuilder.build(); - helper.updateBalancingState(ConnectivityState.CONNECTING, rlsPicker); + Runnable rlsServerConnectivityStateChangeHandler = new Runnable() { + private boolean wasInTransientFailure; + @Override + public void run() { + ConnectivityState currentState = rlsChannel.getState(false); + if (currentState == ConnectivityState.TRANSIENT_FAILURE) { + wasInTransientFailure = true; + } else if (wasInTransientFailure && currentState == ConnectivityState.READY) { + wasInTransientFailure = false; + synchronized (lock) { + boolean anyBackoffsCanceled = false; + for (CacheEntry value : linkedHashLruCache.values()) { + if (value instanceof BackoffCacheEntry) { + if (((BackoffCacheEntry) value).scheduledFuture.cancel(false)) { + anyBackoffsCanceled = true; + } + } + } + if (anyBackoffsCanceled) { + // Cache updated. updateBalancingState() to reattempt picks + helper.triggerPendingRpcProcessing(); + } + } + } + rlsChannel.notifyWhenStateChanged(currentState, this); + } + }; + rlsChannel.notifyWhenStateChanged( + ConnectivityState.IDLE, rlsServerConnectivityStateChangeHandler); rlsStub = RouteLookupServiceGrpc.newStub(rlsChannel); childLbResolvedAddressFactory = checkNotNull(builder.resolvedAddressFactory, "resolvedAddressFactory"); @@ -176,13 +257,52 @@ private CachingRlsLbClient(Builder builder) { refCountedChildPolicyWrapperFactory = new RefCountedChildPolicyWrapperFactory( lbPolicyConfig.getLoadBalancingPolicy(), childLbResolvedAddressFactory, - childLbHelperProvider, - new BackoffRefreshListener()); + childLbHelperProvider); + // TODO(creamsoup) wait until lb is ready + String defaultTarget = lbPolicyConfig.getRouteLookupConfig().defaultTarget(); + if (defaultTarget != null && !defaultTarget.isEmpty()) { + fallbackChildPolicyWrapper = refCountedChildPolicyWrapperFactory.createOrGet(defaultTarget); + } else { + fallbackChildPolicyWrapper = null; + } + + gaugeRegistration = helper.getMetricRecorder() + .registerBatchCallback(new BatchCallback() { + @Override + public void accept(BatchRecorder recorder) { + int estimatedSize; + long estimatedSizeBytes; + synchronized (lock) { + estimatedSize = linkedHashLruCache.estimatedSize(); + estimatedSizeBytes = linkedHashLruCache.estimatedSizeBytes(); + } + recorder.recordLongGauge(CACHE_ENTRIES_GAUGE, estimatedSize, + Arrays.asList(helper.getChannelTarget(), rlsConfig.lookupService(), + metricsInstanceUuid), Collections.emptyList()); + recorder.recordLongGauge(CACHE_SIZE_GAUGE, estimatedSizeBytes, + Arrays.asList(helper.getChannelTarget(), rlsConfig.lookupService(), + metricsInstanceUuid), Collections.emptyList()); + } + }, CACHE_ENTRIES_GAUGE, CACHE_SIZE_GAUGE); + logger.log(ChannelLogLevel.DEBUG, "CachingRlsLbClient created"); } + void init() { + synchronized (lock) { + refCountedChildPolicyWrapperFactory.init(); + } + } + + Status acceptResolvedAddressFactory(ResolvedAddressFactory childLbResolvedAddressFactory) { + synchronized (lock) { + return refCountedChildPolicyWrapperFactory.acceptResolvedAddressFactory( + childLbResolvedAddressFactory); + } + } + /** - * Convert the status to UNAVAILBLE and enhance the error message. + * Convert the status to UNAVAILABLE and enhance the error message. * @param status status as provided by server * @param serverName Used for error description * @return Transformed status @@ -194,42 +314,57 @@ static Status convertRlsServerStatus(Status status, String serverName) { serverName, status.getCode(), status.getDescription())); } - @CheckReturnValue - private ListenableFuture asyncRlsCall(RouteLookupRequest request) { - logger.log(ChannelLogLevel.DEBUG, "Making an async call to RLS"); - final SettableFuture response = SettableFuture.create(); + private void periodicClean() { + synchronized (lock) { + linkedHashLruCache.cleanupExpiredEntries(); + } + } + + /** Populates async cache entry for new request. */ + @GuardedBy("lock") + private CachedRouteLookupResponse asyncRlsCall( + RouteLookupRequestKey routeLookupRequestKey, @Nullable BackoffPolicy backoffPolicy, + RouteLookupRequest.Reason routeLookupReason) { if (throttler.shouldThrottle()) { - logger.log(ChannelLogLevel.DEBUG, "Request is throttled"); - response.setException(new ThrottledException()); - return response; + logger.log(ChannelLogLevel.DEBUG, "[RLS Entry {0}] Throttled RouteLookup", + routeLookupRequestKey); + // Cache updated, but no need to call updateBalancingState because no RPCs were queued waiting + // on this result + return CachedRouteLookupResponse.backoffEntry(createBackOffEntry( + routeLookupRequestKey, Status.RESOURCE_EXHAUSTED.withDescription("RLS throttled"), + backoffPolicy)); } - io.grpc.lookup.v1.RouteLookupRequest routeLookupRequest = REQUEST_CONVERTER.convert(request); - logger.log(ChannelLogLevel.DEBUG, "Sending RouteLookupRequest: {0}", routeLookupRequest); + final SettableFuture response = SettableFuture.create(); + io.grpc.lookup.v1.RouteLookupRequest routeLookupRequest = REQUEST_CONVERTER.convert( + RouteLookupRequest.create(routeLookupRequestKey.keyMap(), routeLookupReason)); + logger.log(ChannelLogLevel.DEBUG, + "[RLS Entry {0}] Starting RouteLookup: {1}", routeLookupRequestKey, routeLookupRequest); rlsStub.withDeadlineAfter(callTimeoutNanos, TimeUnit.NANOSECONDS) .routeLookup( routeLookupRequest, new StreamObserver() { @Override public void onNext(io.grpc.lookup.v1.RouteLookupResponse value) { - logger.log(ChannelLogLevel.DEBUG, "Received RouteLookupResponse: {0}", value); + logger.log(ChannelLogLevel.DEBUG, + "[RLS Entry {0}] RouteLookup succeeded: {1}", routeLookupRequestKey, value); response.set(RESPONSE_CONVERTER.reverse().convert(value)); } @Override public void onError(Throwable t) { - logger.log(ChannelLogLevel.DEBUG, "Error looking up route:", t); + logger.log(ChannelLogLevel.DEBUG, + "[RLS Entry {0}] RouteLookup failed: {1}", routeLookupRequestKey, t); response.setException(t); throttler.registerBackendResponse(true); - helper.propagateRlsError(); } @Override public void onCompleted() { - logger.log(ChannelLogLevel.DEBUG, "routeLookup call completed"); throttler.registerBackendResponse(false); } }); - return response; + return CachedRouteLookupResponse.pendingResponse( + createPendingEntry(routeLookupRequestKey, response, backoffPolicy)); } /** @@ -238,28 +373,30 @@ public void onCompleted() { * changed after the return. */ @CheckReturnValue - final CachedRouteLookupResponse get(final RouteLookupRequest request) { - logger.log(ChannelLogLevel.DEBUG, "Acquiring lock to get cached entry"); + final CachedRouteLookupResponse get(final RouteLookupRequestKey routeLookupRequestKey) { synchronized (lock) { - logger.log(ChannelLogLevel.DEBUG, "Acquired lock to get cached entry"); final CacheEntry cacheEntry; - cacheEntry = linkedHashLruCache.read(request); - if (cacheEntry == null) { - logger.log(ChannelLogLevel.DEBUG, "No cache entry found, making a new lrs request"); - return handleNewRequest(request); + cacheEntry = linkedHashLruCache.read(routeLookupRequestKey); + if (cacheEntry == null + || (cacheEntry instanceof BackoffCacheEntry + && !((BackoffCacheEntry) cacheEntry).isInBackoffPeriod())) { + PendingCacheEntry pendingEntry = pendingCallCache.get(routeLookupRequestKey); + if (pendingEntry != null) { + return CachedRouteLookupResponse.pendingResponse(pendingEntry); + } + return asyncRlsCall(routeLookupRequestKey, cacheEntry instanceof BackoffCacheEntry + ? ((BackoffCacheEntry) cacheEntry).backoffPolicy : null, + RouteLookupRequest.Reason.REASON_MISS); } if (cacheEntry instanceof DataCacheEntry) { // cache hit, initiate async-refresh if entry is staled - logger.log(ChannelLogLevel.DEBUG, "Cache hit for the request"); DataCacheEntry dataEntry = ((DataCacheEntry) cacheEntry); if (dataEntry.isStaled(ticker.read())) { - logger.log(ChannelLogLevel.DEBUG, "Cache entry is stale"); dataEntry.maybeRefresh(); } return CachedRouteLookupResponse.dataEntry((DataCacheEntry) cacheEntry); } - logger.log(ChannelLogLevel.DEBUG, "Cache hit for a backup entry"); return CachedRouteLookupResponse.backoffEntry((BackoffCacheEntry) cacheEntry); } } @@ -268,51 +405,100 @@ final CachedRouteLookupResponse get(final RouteLookupRequest request) { void close() { logger.log(ChannelLogLevel.DEBUG, "CachingRlsLbClient closed"); synchronized (lock) { + periodicCleaner.cancel(false); // all childPolicyWrapper will be returned via AutoCleaningEvictionListener linkedHashLruCache.close(); // TODO(creamsoup) maybe cancel all pending requests pendingCallCache.clear(); rlsChannel.shutdownNow(); rlsPicker.close(); + gaugeRegistration.close(); } } - /** - * Populates async cache entry for new request. This is only methods directly modifies the cache, - * any status change is happening via event (async request finished, timed out, etc) in {@link - * PendingCacheEntry}, {@link DataCacheEntry} and {@link BackoffCacheEntry}. - */ - private CachedRouteLookupResponse handleNewRequest(RouteLookupRequest request) { + void requestConnection() { + rlsChannel.getState(true); + } + + @GuardedBy("lock") + private PendingCacheEntry createPendingEntry( + RouteLookupRequestKey routeLookupRequestKey, + ListenableFuture pendingCall, + @Nullable BackoffPolicy backoffPolicy) { + PendingCacheEntry entry = new PendingCacheEntry(routeLookupRequestKey, pendingCall, + backoffPolicy); + // Add the entry to the map before adding the Listener, because the listener removes the + // entry from the map + pendingCallCache.put(routeLookupRequestKey, entry); + // Beware that the listener can run immediately on the current thread + pendingCall.addListener(() -> pendingRpcComplete(entry), MoreExecutors.directExecutor()); + return entry; + } + + private void pendingRpcComplete(PendingCacheEntry entry) { synchronized (lock) { - PendingCacheEntry pendingEntry = pendingCallCache.get(request); - if (pendingEntry != null) { - return CachedRouteLookupResponse.pendingResponse(pendingEntry); + boolean clientClosed = pendingCallCache.remove(entry.routeLookupRequestKey) == null; + if (clientClosed) { + return; } - ListenableFuture asyncCall = asyncRlsCall(request); - if (!asyncCall.isDone()) { - pendingEntry = new PendingCacheEntry(request, asyncCall); - pendingCallCache.put(request, pendingEntry); - return CachedRouteLookupResponse.pendingResponse(pendingEntry); - } else { - // async call returned finished future is most likely throttled - try { - RouteLookupResponse response = asyncCall.get(); - DataCacheEntry dataEntry = new DataCacheEntry(request, response); - linkedHashLruCache.cacheAndClean(request, dataEntry); - return CachedRouteLookupResponse.dataEntry(dataEntry); - } catch (Exception e) { - BackoffCacheEntry backoffEntry = - new BackoffCacheEntry(request, Status.fromThrowable(e), backoffProvider.get()); - linkedHashLruCache.cacheAndClean(request, backoffEntry); - return CachedRouteLookupResponse.backoffEntry(backoffEntry); - } + try { + createDataEntry(entry.routeLookupRequestKey, Futures.getDone(entry.pendingCall)); + // Cache updated. DataCacheEntry constructor indirectly calls updateBalancingState() to + // reattempt picks when the child LB is done connecting + } catch (Exception e) { + createBackOffEntry(entry.routeLookupRequestKey, Status.fromThrowable(e), + entry.backoffPolicy); + // Cache updated. updateBalancingState() to reattempt picks + helper.triggerPendingRpcProcessing(); } } } - void requestConnection() { - rlsChannel.getState(true); + @GuardedBy("lock") + private DataCacheEntry createDataEntry( + RouteLookupRequestKey routeLookupRequestKey, RouteLookupResponse routeLookupResponse) { + logger.log( + ChannelLogLevel.DEBUG, + "[RLS Entry {0}] Transition to data cache: routeLookupResponse={1}", + routeLookupRequestKey, routeLookupResponse); + DataCacheEntry entry = new DataCacheEntry(routeLookupRequestKey, routeLookupResponse); + // Constructor for DataCacheEntry causes updateBalancingState, but the picks can't happen until + // this cache update because the lock is held + linkedHashLruCache.cacheAndClean(routeLookupRequestKey, entry); + return entry; + } + + @GuardedBy("lock") + private BackoffCacheEntry createBackOffEntry(RouteLookupRequestKey routeLookupRequestKey, + Status status, @Nullable BackoffPolicy backoffPolicy) { + if (backoffPolicy == null) { + backoffPolicy = backoffProvider.get(); + } + long delayNanos = backoffPolicy.nextBackoffNanos(); + logger.log( + ChannelLogLevel.DEBUG, + "[RLS Entry {0}] Transition to back off: status={1}, delayNanos={2}", + routeLookupRequestKey, status, delayNanos); + BackoffCacheEntry entry = new BackoffCacheEntry(routeLookupRequestKey, status, backoffPolicy, + ticker.read() + delayNanos * 2); + // Lock is held, so the task can't execute before the assignment + entry.scheduledFuture = scheduledExecutorService.schedule( + () -> refreshBackoffEntry(entry), delayNanos, TimeUnit.NANOSECONDS); + linkedHashLruCache.cacheAndClean(routeLookupRequestKey, entry); + return entry; + } + + private void refreshBackoffEntry(BackoffCacheEntry entry) { + synchronized (lock) { + // This checks whether the task has been cancelled and prevents a second execution. + if (!entry.scheduledFuture.cancel(false)) { + // Future was previously cancelled + return; + } + // Cache updated. updateBalancingState() to reattempt picks + helper.triggerPendingRpcProcessing(); + } } private static final class RlsLbHelper extends ForwardingLoadBalancerHelper { @@ -337,20 +523,10 @@ public void updateBalancingState(ConnectivityState newState, SubchannelPicker ne super.updateBalancingState(newState, newPicker); } - void propagateRlsError() { - getSynchronizationContext().execute(new Runnable() { - @Override - public void run() { - if (picker != null) { - // Refresh the channel state and let pending RPCs reprocess the picker. - updateBalancingState(state, picker); - } - } - }); - } - void triggerPendingRpcProcessing() { - super.updateBalancingState(state, picker); + checkState(state != null, "updateBalancingState hasn't yet been called"); + helper.getSynchronizationContext().execute( + () -> super.updateBalancingState(state, picker)); } } @@ -452,103 +628,46 @@ public String toString() { } /** A pending cache entry when the async RouteLookup RPC is still on the fly. */ - final class PendingCacheEntry { + static final class PendingCacheEntry { private final ListenableFuture pendingCall; - private final RouteLookupRequest request; + private final RouteLookupRequestKey routeLookupRequestKey; + @Nullable private final BackoffPolicy backoffPolicy; PendingCacheEntry( - RouteLookupRequest request, ListenableFuture pendingCall) { - this(request, pendingCall, null); - } - - PendingCacheEntry( - RouteLookupRequest request, + RouteLookupRequestKey routeLookupRequestKey, ListenableFuture pendingCall, @Nullable BackoffPolicy backoffPolicy) { - this.request = checkNotNull(request, "request"); - this.pendingCall = pendingCall; - this.backoffPolicy = backoffPolicy == null ? backoffProvider.get() : backoffPolicy; - pendingCall.addListener( - new Runnable() { - @Override - public void run() { - handleDoneFuture(); - } - }, - synchronizationContext); - } - - private void handleDoneFuture() { - synchronized (lock) { - pendingCallCache.remove(request); - if (pendingCall.isCancelled()) { - return; - } - - try { - transitionToDataEntry(pendingCall.get()); - } catch (Exception e) { - if (e instanceof ThrottledException) { - transitionToBackOff(Status.RESOURCE_EXHAUSTED.withCause(e)); - } else { - transitionToBackOff(Status.fromThrowable(e)); - } - } - } - } - - private void transitionToDataEntry(RouteLookupResponse routeLookupResponse) { - synchronized (lock) { - logger.log( - ChannelLogLevel.DEBUG, - "Transition to data cache: routeLookupResponse={0}", - routeLookupResponse); - linkedHashLruCache.cacheAndClean(request, new DataCacheEntry(request, routeLookupResponse)); - } - } - - private void transitionToBackOff(Status status) { - synchronized (lock) { - logger.log(ChannelLogLevel.DEBUG, "Transition to back off: status={0}", status); - linkedHashLruCache.cacheAndClean(request, - new BackoffCacheEntry(request, status, backoffPolicy)); - } + this.routeLookupRequestKey = checkNotNull(routeLookupRequestKey, "request"); + this.pendingCall = checkNotNull(pendingCall, "pendingCall"); + this.backoffPolicy = backoffPolicy; } @Override public String toString() { return MoreObjects.toStringHelper(this) - .add("request", request) + .add("routeLookupRequestKey", routeLookupRequestKey) .toString(); } } /** Common cache entry data for {@link RlsAsyncLruCache}. */ - abstract class CacheEntry { + abstract static class CacheEntry { - protected final RouteLookupRequest request; + protected final RouteLookupRequestKey routeLookupRequestKey; - CacheEntry(RouteLookupRequest request) { - this.request = checkNotNull(request, "request"); + CacheEntry(RouteLookupRequestKey routeLookupRequestKey) { + this.routeLookupRequestKey = checkNotNull(routeLookupRequestKey, "request"); } abstract int getSizeBytes(); - final boolean isExpired() { - return isExpired(ticker.read()); - } - abstract boolean isExpired(long now); abstract void cleanup(); - protected long getMinEvictionTime() { - return 0L; - } - - protected void triggerPendingRpcProcessing() { - helper.triggerPendingRpcProcessing(); + protected boolean isOldEnoughToBeEvicted(long now) { + return true; } } @@ -561,8 +680,9 @@ final class DataCacheEntry extends CacheEntry { private final List childPolicyWrappers; // GuardedBy CachingRlsLbClient.lock - DataCacheEntry(RouteLookupRequest request, final RouteLookupResponse response) { - super(request); + DataCacheEntry(RouteLookupRequestKey routeLookupRequestKey, + final RouteLookupResponse response) { + super(routeLookupRequestKey); this.response = checkNotNull(response, "response"); checkState(!response.targets().isEmpty(), "No targets returned by RLS"); childPolicyWrappers = @@ -589,36 +709,15 @@ final class DataCacheEntry extends CacheEntry { * */ void maybeRefresh() { - logger.log(ChannelLogLevel.DEBUG, "Acquiring lock to maybe refresh cache entry"); - synchronized (lock) { - logger.log(ChannelLogLevel.DEBUG, "Lock to maybe refresh cache entry acquired"); - if (pendingCallCache.containsKey(request)) { + synchronized (lock) { // Lock is already held, but ErrorProne can't tell + if (pendingCallCache.containsKey(routeLookupRequestKey)) { // pending already requested - logger.log(ChannelLogLevel.DEBUG, - "A pending refresh request already created, no need to proceed with refresh"); return; } - final ListenableFuture asyncCall = asyncRlsCall(request); - if (!asyncCall.isDone()) { - logger.log(ChannelLogLevel.DEBUG, - "Async call to rls not yet complete, adding a pending cache entry"); - pendingCallCache.put(request, new PendingCacheEntry(request, asyncCall)); - } else { - // async call returned finished future is most likely throttled - try { - logger.log(ChannelLogLevel.DEBUG, "Waiting for RLS call to return"); - RouteLookupResponse response = asyncCall.get(); - logger.log(ChannelLogLevel.DEBUG, "RLS call to returned"); - linkedHashLruCache.cacheAndClean(request, new DataCacheEntry(request, response)); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } catch (Exception e) { - logger.log(ChannelLogLevel.DEBUG, "RLS call failed, adding a backoff entry", e); - BackoffCacheEntry backoffEntry = - new BackoffCacheEntry(request, Status.fromThrowable(e), backoffProvider.get()); - linkedHashLruCache.cacheAndClean(request, backoffEntry); - } - } + logger.log(ChannelLogLevel.DEBUG, + "[RLS Entry {0}] Cache entry is stale, refreshing", routeLookupRequestKey); + asyncRlsCall(routeLookupRequestKey, /* backoffPolicy= */ null, + RouteLookupRequest.Reason.REASON_STALE); } } @@ -672,8 +771,8 @@ boolean isStaled(long now) { } @Override - protected long getMinEvictionTime() { - return minEvictionTime; + protected boolean isOldEnoughToBeEvicted(long now) { + return minEvictionTime - now <= 0; } @Override @@ -688,7 +787,7 @@ void cleanup() { @Override public String toString() { return MoreObjects.toStringHelper(this) - .add("request", request) + .add("request", routeLookupRequestKey) .add("response", response) .add("expireTime", expireTime) .add("staleTime", staleTime) @@ -701,77 +800,19 @@ public String toString() { * Implementation of {@link CacheEntry} contains error. This entry will transition to pending * status when the backoff time is expired. */ - private final class BackoffCacheEntry extends CacheEntry { + private static final class BackoffCacheEntry extends CacheEntry { private final Status status; - private final ScheduledHandle scheduledHandle; private final BackoffPolicy backoffPolicy; - private final long expireNanos; - private boolean shutdown = false; + private final long expiryTimeNanos; + private Future scheduledFuture; - BackoffCacheEntry(RouteLookupRequest request, Status status, BackoffPolicy backoffPolicy) { - super(request); + BackoffCacheEntry(RouteLookupRequestKey routeLookupRequestKey, Status status, + BackoffPolicy backoffPolicy, long expiryTimeNanos) { + super(routeLookupRequestKey); this.status = checkNotNull(status, "status"); this.backoffPolicy = checkNotNull(backoffPolicy, "backoffPolicy"); - long delayNanos = backoffPolicy.nextBackoffNanos(); - this.expireNanos = ticker.read() + delayNanos; - this.scheduledHandle = - synchronizationContext.schedule( - new Runnable() { - @Override - public void run() { - transitionToPending(); - } - }, - delayNanos, - TimeUnit.NANOSECONDS, - scheduledExecutorService); - logger.log(ChannelLogLevel.DEBUG, "BackoffCacheEntry created with a delay of {0} nanos", - delayNanos); - } - - /** Forcefully refreshes cache entry by ignoring the backoff timer. */ - void forceRefresh() { - logger.log(ChannelLogLevel.DEBUG, "Forcefully refreshing cache entry"); - if (scheduledHandle.isPending()) { - scheduledHandle.cancel(); - transitionToPending(); - } - } - - private void transitionToPending() { - logger.log(ChannelLogLevel.DEBUG, "Acquiring lock to transition to pending"); - synchronized (lock) { - logger.log(ChannelLogLevel.DEBUG, "Acquired lock to transition to pending"); - if (shutdown) { - logger.log(ChannelLogLevel.DEBUG, "Already shut down, not transitioning to pending"); - return; - } - logger.log(ChannelLogLevel.DEBUG, "Calling RLS for transition to pending"); - ListenableFuture call = asyncRlsCall(request); - if (!call.isDone()) { - logger.log(ChannelLogLevel.DEBUG, - "Transition to pending RLS call not done, adding a pending cache entry"); - PendingCacheEntry pendingEntry = new PendingCacheEntry(request, call, backoffPolicy); - pendingCallCache.put(request, pendingEntry); - linkedHashLruCache.invalidate(request); - } else { - try { - logger.log(ChannelLogLevel.DEBUG, - "Waiting for transition to pending RLS call response"); - RouteLookupResponse response = call.get(); - linkedHashLruCache.cacheAndClean(request, new DataCacheEntry(request, response)); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } catch (Exception e) { - logger.log(ChannelLogLevel.DEBUG, - "Transition to pending RLS call failed, creating a backoff entry", e); - linkedHashLruCache.cacheAndClean( - request, - new BackoffCacheEntry(request, Status.fromThrowable(e), backoffPolicy)); - } - } - } + this.expiryTimeNanos = expiryTimeNanos; } Status getStatus() { @@ -783,26 +824,24 @@ int getSizeBytes() { return OBJ_OVERHEAD_B * 3 + Long.SIZE + 8; // 3 java objects, 1 long and a boolean } + boolean isInBackoffPeriod() { + return !scheduledFuture.isDone(); + } + @Override - boolean isExpired(long now) { - return expireNanos - now <= 0; + boolean isExpired(long nowNanos) { + return nowNanos > expiryTimeNanos; } @Override void cleanup() { - if (shutdown) { - return; - } - shutdown = true; - if (!scheduledHandle.isPending()) { - scheduledHandle.cancel(); - } + scheduledFuture.cancel(false); } @Override public String toString() { return MoreObjects.toStringHelper(this) - .add("request", request) + .add("request", routeLookupRequestKey) .add("status", status) .toString(); } @@ -821,7 +860,7 @@ static final class Builder { private Throttler throttler = new HappyThrottler(); private ResolvedAddressFactory resolvedAddressFactory; private Ticker ticker = Ticker.systemTicker(); - private EvictionListener evictionListener; + private EvictionListener evictionListener; private BackoffPolicy.Provider backoffProvider = new ExponentialBackoffPolicy.Provider(); Builder setHelper(Helper helper) { @@ -855,7 +894,7 @@ Builder setTicker(Ticker ticker) { } Builder setEvictionListener( - @Nullable EvictionListener evictionListener) { + @Nullable EvictionListener evictionListener) { this.evictionListener = evictionListener; return this; } @@ -866,7 +905,9 @@ Builder setBackoffProvider(BackoffPolicy.Provider provider) { } CachingRlsLbClient build() { - return new CachingRlsLbClient(this); + CachingRlsLbClient client = new CachingRlsLbClient(this); + client.init(); + return client; } } @@ -875,17 +916,17 @@ CachingRlsLbClient build() { * CacheEntry#cleanup()} after original {@link EvictionListener} is finished. */ private static final class AutoCleaningEvictionListener - implements EvictionListener { + implements EvictionListener { - private final EvictionListener delegate; + private final EvictionListener delegate; AutoCleaningEvictionListener( - @Nullable EvictionListener delegate) { + @Nullable EvictionListener delegate) { this.delegate = delegate; } @Override - public void onEviction(RouteLookupRequest key, CacheEntry value, EvictionType cause) { + public void onEviction(RouteLookupRequestKey key, CacheEntry value, EvictionType cause) { if (delegate != null) { delegate.onEviction(key, value, cause); } @@ -910,35 +951,30 @@ public void registerBackendResponse(boolean throttled) { /** Implementation of {@link LinkedHashLruCache} for RLS. */ private static final class RlsAsyncLruCache - extends LinkedHashLruCache { + extends LinkedHashLruCache { + private final RlsLbHelper helper; RlsAsyncLruCache(long maxEstimatedSizeBytes, - @Nullable EvictionListener evictionListener, - ScheduledExecutorService ses, Ticker ticker, Object lock) { - super( - maxEstimatedSizeBytes, - new AutoCleaningEvictionListener(evictionListener), - 1, - TimeUnit.MINUTES, - ses, - ticker, - lock); + @Nullable EvictionListener evictionListener, + Ticker ticker, RlsLbHelper helper) { + super(maxEstimatedSizeBytes, evictionListener, ticker); + this.helper = checkNotNull(helper, "helper"); } @Override - protected boolean isExpired(RouteLookupRequest key, CacheEntry value, long nowNanos) { - return value.isExpired(); + protected boolean isExpired(RouteLookupRequestKey key, CacheEntry value, long nowNanos) { + return value.isExpired(nowNanos); } @Override - protected int estimateSizeOf(RouteLookupRequest key, CacheEntry value) { + protected int estimateSizeOf(RouteLookupRequestKey key, CacheEntry value) { return value.getSizeBytes(); } @Override protected boolean shouldInvalidateEldestEntry( - RouteLookupRequest eldestKey, CacheEntry eldestValue) { - if (eldestValue.getMinEvictionTime() > now()) { + RouteLookupRequestKey eldestKey, CacheEntry eldestValue, long now) { + if (!eldestValue.isOldEnoughToBeEvicted(now)) { return false; } @@ -946,46 +982,17 @@ protected boolean shouldInvalidateEldestEntry( return this.estimatedSizeBytes() > this.estimatedMaxSizeBytes(); } - public CacheEntry cacheAndClean(RouteLookupRequest key, CacheEntry value) { + public CacheEntry cacheAndClean(RouteLookupRequestKey key, CacheEntry value) { CacheEntry newEntry = cache(key, value); // force cleanup if new entry pushed cache over max size (in bytes) if (fitToLimit()) { - value.triggerPendingRpcProcessing(); + helper.triggerPendingRpcProcessing(); } return newEntry; } } - /** - * LbStatusListener refreshes {@link BackoffCacheEntry} when lb state is changed to {@link - * ConnectivityState#READY} from {@link ConnectivityState#TRANSIENT_FAILURE}. - */ - private final class BackoffRefreshListener implements ChildLbStatusListener { - - @Nullable - private ConnectivityState prevState = null; - - @Override - public void onStatusChanged(ConnectivityState newState) { - logger.log(ChannelLogLevel.DEBUG, "LB status changed to: {0}", newState); - if (prevState == ConnectivityState.TRANSIENT_FAILURE - && newState == ConnectivityState.READY) { - logger.log(ChannelLogLevel.DEBUG, "Transitioning from TRANSIENT_FAILURE to READY"); - logger.log(ChannelLogLevel.DEBUG, "Acquiring lock force refresh backoff cache entries"); - synchronized (lock) { - logger.log(ChannelLogLevel.DEBUG, "Lock acquired for refreshing backoff cache entries"); - for (CacheEntry value : linkedHashLruCache.values()) { - if (value instanceof BackoffCacheEntry) { - ((BackoffCacheEntry) value).forceRefresh(); - } - } - } - } - prevState = newState; - } - } - /** A header will be added when RLS server respond with additional header data. */ @VisibleForTesting static final Metadata.Key RLS_DATA_KEY = @@ -994,92 +1001,95 @@ public void onStatusChanged(ConnectivityState newState) { final class RlsPicker extends SubchannelPicker { private final RlsRequestFactory requestFactory; + private final String lookupService; - RlsPicker(RlsRequestFactory requestFactory) { + RlsPicker(RlsRequestFactory requestFactory, String lookupService) { this.requestFactory = checkNotNull(requestFactory, "requestFactory"); + this.lookupService = checkNotNull(lookupService, "rlsConfig"); } @Override public PickResult pickSubchannel(PickSubchannelArgs args) { String serviceName = args.getMethodDescriptor().getServiceName(); String methodName = args.getMethodDescriptor().getBareMethodName(); - RouteLookupRequest request = + RlsProtoData.RouteLookupRequestKey lookupRequestKey = requestFactory.create(serviceName, methodName, args.getHeaders()); - final CachedRouteLookupResponse response = CachingRlsLbClient.this.get(request); - logger.log(ChannelLogLevel.DEBUG, - "Got route lookup cache entry for service={0}, method={1}, headers={2}:\n {3}", - new Object[]{serviceName, methodName, args.getHeaders(), response}); + final CachedRouteLookupResponse response = CachingRlsLbClient.this.get(lookupRequestKey); if (response.getHeaderData() != null && !response.getHeaderData().isEmpty()) { - logger.log(ChannelLogLevel.DEBUG, "Updating LRS metadata from the LRS response headers"); Metadata headers = args.getHeaders(); headers.discardAll(RLS_DATA_KEY); headers.put(RLS_DATA_KEY, response.getHeaderData()); } String defaultTarget = lbPolicyConfig.getRouteLookupConfig().defaultTarget(); - logger.log(ChannelLogLevel.DEBUG, "defaultTarget = {0}", defaultTarget); boolean hasFallback = defaultTarget != null && !defaultTarget.isEmpty(); if (response.hasData()) { - logger.log(ChannelLogLevel.DEBUG, "LRS response has data, proceed with selecting a picker"); ChildPolicyWrapper childPolicyWrapper = response.getChildPolicyWrapper(); SubchannelPicker picker = (childPolicyWrapper != null) ? childPolicyWrapper.getPicker() : null; if (picker == null) { - logger.log(ChannelLogLevel.DEBUG, - "Child policy wrapper didn't return a picker, returning PickResult with no results"); return PickResult.withNoResult(); } // Happy path - logger.log(ChannelLogLevel.DEBUG, "Returning PickResult"); - return picker.pickSubchannel(args); + PickResult pickResult = picker.pickSubchannel(args); + if (pickResult.hasResult()) { + helper.getMetricRecorder().addLongCounter(TARGET_PICKS_COUNTER, 1, + Arrays.asList(helper.getChannelTarget(), lookupService, + childPolicyWrapper.getTarget(), determineMetricsPickResult(pickResult)), + Arrays.asList(determineCustomLabel(args))); + } + return pickResult; } else if (response.hasError()) { - logger.log(ChannelLogLevel.DEBUG, "RLS response has errors"); if (hasFallback) { - logger.log(ChannelLogLevel.DEBUG, "Using RLS fallback"); return useFallback(args); } - logger.log(ChannelLogLevel.DEBUG, "No RLS fallback, returning PickResult with an error"); + helper.getMetricRecorder().addLongCounter(FAILED_PICKS_COUNTER, 1, + Arrays.asList(helper.getChannelTarget(), lookupService), + Arrays.asList(determineCustomLabel(args))); return PickResult.withError( convertRlsServerStatus(response.getStatus(), lbPolicyConfig.getRouteLookupConfig().lookupService())); } else { - logger.log(ChannelLogLevel.DEBUG, - "RLS response had no data, return a PickResult with no data"); return PickResult.withNoResult(); } } - private ChildPolicyWrapper fallbackChildPolicyWrapper; - /** Uses Subchannel connected to default target. */ private PickResult useFallback(PickSubchannelArgs args) { - // TODO(creamsoup) wait until lb is ready - startFallbackChildPolicy(); SubchannelPicker picker = fallbackChildPolicyWrapper.getPicker(); if (picker == null) { return PickResult.withNoResult(); } - return picker.pickSubchannel(args); + PickResult pickResult = picker.pickSubchannel(args); + if (pickResult.hasResult()) { + helper.getMetricRecorder().addLongCounter(DEFAULT_TARGET_PICKS_COUNTER, 1, + Arrays.asList(helper.getChannelTarget(), lookupService, + fallbackChildPolicyWrapper.getTarget(), determineMetricsPickResult(pickResult)), + Arrays.asList(determineCustomLabel(args))); + } + return pickResult; } - private void startFallbackChildPolicy() { - String defaultTarget = lbPolicyConfig.getRouteLookupConfig().defaultTarget(); - logger.log(ChannelLogLevel.DEBUG, "starting fallback to {0}", defaultTarget); - logger.log(ChannelLogLevel.DEBUG, "Acquiring lock to start fallback child policy"); - synchronized (lock) { - logger.log(ChannelLogLevel.DEBUG, "Acquired lock for starting fallback child policy"); - if (fallbackChildPolicyWrapper != null) { - return; - } - fallbackChildPolicyWrapper = refCountedChildPolicyWrapperFactory.createOrGet(defaultTarget); + private String determineMetricsPickResult(PickResult pickResult) { + if (pickResult.getStatus().isOk()) { + return "complete"; + } else if (pickResult.isDrop()) { + return "drop"; + } else { + return "fail"; } } + private String determineCustomLabel(PickSubchannelArgs args) { + return args.getCallOptions().getOption(Grpc.CALL_OPTION_CUSTOM_LABEL); + } + // GuardedBy CachingRlsLbClient.lock void close() { - logger.log(ChannelLogLevel.DEBUG, "Closing RLS picker"); - if (fallbackChildPolicyWrapper != null) { - refCountedChildPolicyWrapperFactory.release(fallbackChildPolicyWrapper); + synchronized (lock) { // Lock is already held, but ErrorProne can't tell + if (fallbackChildPolicyWrapper != null) { + refCountedChildPolicyWrapperFactory.release(fallbackChildPolicyWrapper); + } } } diff --git a/rls/src/main/java/io/grpc/rls/ChildLoadBalancerHelper.java b/rls/src/main/java/io/grpc/rls/ChildLoadBalancerHelper.java index 3131aba7551..7a5d5dcc645 100644 --- a/rls/src/main/java/io/grpc/rls/ChildLoadBalancerHelper.java +++ b/rls/src/main/java/io/grpc/rls/ChildLoadBalancerHelper.java @@ -77,6 +77,10 @@ static final class ChildLoadBalancerHelperProvider { this.picker = checkNotNull(picker, "picker"); } + void init() { + helper.updateBalancingState(ConnectivityState.CONNECTING, picker); + } + ChildLoadBalancerHelper forTarget(String target) { return new ChildLoadBalancerHelper(target, helper, subchannelStateManager, picker); } diff --git a/rls/src/main/java/io/grpc/rls/LbPolicyConfiguration.java b/rls/src/main/java/io/grpc/rls/LbPolicyConfiguration.java index 37b7c2eb0be..77ed080e654 100644 --- a/rls/src/main/java/io/grpc/rls/LbPolicyConfiguration.java +++ b/rls/src/main/java/io/grpc/rls/LbPolicyConfiguration.java @@ -31,6 +31,7 @@ import io.grpc.LoadBalancerProvider; import io.grpc.LoadBalancerRegistry; import io.grpc.NameResolver.ConfigOrError; +import io.grpc.Status; import io.grpc.internal.ObjectPool; import io.grpc.rls.ChildLoadBalancerHelper.ChildLoadBalancerHelperProvider; import io.grpc.rls.RlsProtoData.RouteLookupConfig; @@ -41,7 +42,6 @@ import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.concurrent.atomic.AtomicLong; import javax.annotation.Nullable; /** Configuration for RLS load balancing policy. */ @@ -203,40 +203,52 @@ public String toString() { } } - /** Factory for {@link ChildPolicyWrapper}. */ + /** Factory for {@link ChildPolicyWrapper}. Not thread-safe. */ static final class RefCountedChildPolicyWrapperFactory { - // GuardedBy CachingRlsLbClient.lock @VisibleForTesting final Map childPolicyMap = new HashMap<>(); private final ChildLoadBalancerHelperProvider childLbHelperProvider; - private final ChildLbStatusListener childLbStatusListener; private final ChildLoadBalancingPolicy childPolicy; - private final ResolvedAddressFactory childLbResolvedAddressFactory; + private ResolvedAddressFactory childLbResolvedAddressFactory; public RefCountedChildPolicyWrapperFactory( ChildLoadBalancingPolicy childPolicy, ResolvedAddressFactory childLbResolvedAddressFactory, - ChildLoadBalancerHelperProvider childLbHelperProvider, - ChildLbStatusListener childLbStatusListener) { + ChildLoadBalancerHelperProvider childLbHelperProvider) { this.childPolicy = checkNotNull(childPolicy, "childPolicy"); this.childLbResolvedAddressFactory = checkNotNull(childLbResolvedAddressFactory, "childLbResolvedAddressFactory"); this.childLbHelperProvider = checkNotNull(childLbHelperProvider, "childLbHelperProvider"); - this.childLbStatusListener = checkNotNull(childLbStatusListener, "childLbStatusListener"); } - // GuardedBy CachingRlsLbClient.lock + void init() { + childLbHelperProvider.init(); + } + + Status acceptResolvedAddressFactory(ResolvedAddressFactory childLbResolvedAddressFactory) { + this.childLbResolvedAddressFactory = childLbResolvedAddressFactory; + Status status = Status.OK; + for (RefCountedChildPolicyWrapper wrapper : childPolicyMap.values()) { + Status newStatus = + wrapper.childPolicyWrapper.acceptResolvedAddressFactory(childLbResolvedAddressFactory); + if (!newStatus.isOk()) { + status = newStatus; + } + } + return status; + } + ChildPolicyWrapper createOrGet(String target) { // TODO(creamsoup) check if the target is valid or not RefCountedChildPolicyWrapper pooledChildPolicyWrapper = childPolicyMap.get(target); if (pooledChildPolicyWrapper == null) { ChildPolicyWrapper childPolicyWrapper = new ChildPolicyWrapper( - target, childPolicy, childLbResolvedAddressFactory, childLbHelperProvider, - childLbStatusListener); + target, childPolicy, childLbHelperProvider); pooledChildPolicyWrapper = RefCountedChildPolicyWrapper.of(childPolicyWrapper); childPolicyMap.put(target, pooledChildPolicyWrapper); + childPolicyWrapper.start(childLbResolvedAddressFactory); return pooledChildPolicyWrapper.getObject(); } else { ChildPolicyWrapper childPolicyWrapper = pooledChildPolicyWrapper.getObject(); @@ -247,7 +259,6 @@ ChildPolicyWrapper createOrGet(String target) { } } - // GuardedBy CachingRlsLbClient.lock List createOrGet(List targets) { List retVal = new ArrayList<>(); for (String target : targets) { @@ -256,7 +267,6 @@ List createOrGet(List targets) { return retVal; } - // GuardedBy CachingRlsLbClient.lock void release(ChildPolicyWrapper childPolicyWrapper) { checkNotNull(childPolicyWrapper, "childPolicyWrapper"); String target = childPolicyWrapper.getTarget(); @@ -278,32 +288,33 @@ static final class ChildPolicyWrapper { private final String target; private final ChildPolicyReportingHelper helper; private final LoadBalancer lb; + private final Object childLbConfig; private volatile SubchannelPicker picker; private ConnectivityState state; public ChildPolicyWrapper( String target, ChildLoadBalancingPolicy childPolicy, - final ResolvedAddressFactory childLbResolvedAddressFactory, - ChildLoadBalancerHelperProvider childLbHelperProvider, - ChildLbStatusListener childLbStatusListener) { + ChildLoadBalancerHelperProvider childLbHelperProvider) { this.target = target; - this.helper = - new ChildPolicyReportingHelper(childLbHelperProvider, childLbStatusListener); + this.helper = new ChildPolicyReportingHelper(childLbHelperProvider); LoadBalancerProvider lbProvider = childPolicy.getEffectiveLbProvider(); final ConfigOrError lbConfig = lbProvider .parseLoadBalancingPolicyConfig( childPolicy.getEffectiveChildPolicy(target)); this.lb = lbProvider.newLoadBalancer(helper); + this.childLbConfig = lbConfig.getConfig(); helper.getChannelLogger().log( - ChannelLogLevel.DEBUG, "RLS child lb created. config: {0}", lbConfig.getConfig()); + ChannelLogLevel.DEBUG, "RLS child lb created. config: {0}", childLbConfig); + } + + void start(ResolvedAddressFactory childLbResolvedAddressFactory) { helper.getSynchronizationContext().execute( new Runnable() { @Override public void run() { - if (!lb.acceptResolvedAddresses( - childLbResolvedAddressFactory.create(lbConfig.getConfig())).isOk()) { + if (!acceptResolvedAddressFactory(childLbResolvedAddressFactory).isOk()) { helper.refreshNameResolution(); } lb.requestConnection(); @@ -311,6 +322,11 @@ public void run() { }); } + Status acceptResolvedAddressFactory(ResolvedAddressFactory childLbResolvedAddressFactory) { + helper.getSynchronizationContext().throwIfNotInThisSynchronizationContext(); + return lb.acceptResolvedAddresses(childLbResolvedAddressFactory.create(childLbConfig)); + } + String getTarget() { return target; } @@ -367,14 +383,11 @@ public String toString() { final class ChildPolicyReportingHelper extends ForwardingLoadBalancerHelper { private final ChildLoadBalancerHelper delegate; - private final ChildLbStatusListener listener; ChildPolicyReportingHelper( - ChildLoadBalancerHelperProvider childHelperProvider, - ChildLbStatusListener listener) { + ChildLoadBalancerHelperProvider childHelperProvider) { checkNotNull(childHelperProvider, "childHelperProvider"); this.delegate = childHelperProvider.forTarget(getTarget()); - this.listener = checkNotNull(listener, "listener"); } @Override @@ -387,22 +400,14 @@ public void updateBalancingState(ConnectivityState newState, SubchannelPicker ne picker = newPicker; state = newState; super.updateBalancingState(newState, newPicker); - listener.onStatusChanged(newState); } } } - /** Listener for child lb status change events. */ - interface ChildLbStatusListener { - - /** Notifies when child lb status changes. */ - void onStatusChanged(ConnectivityState newState); - } - private static final class RefCountedChildPolicyWrapper implements ObjectPool { - private final AtomicLong refCnt = new AtomicLong(); + private long refCnt; @Nullable private ChildPolicyWrapper childPolicyWrapper; @@ -413,7 +418,7 @@ private RefCountedChildPolicyWrapper(ChildPolicyWrapper childPolicyWrapper) { @Override public ChildPolicyWrapper getObject() { checkState(!isReleased(), "ChildPolicyWrapper is already released"); - refCnt.getAndIncrement(); + refCnt++; return childPolicyWrapper; } @@ -426,7 +431,7 @@ public ChildPolicyWrapper returnObject(Object object) { checkState( childPolicyWrapper == object, "returned object doesn't match the pooled childPolicyWrapper"); - long newCnt = refCnt.decrementAndGet(); + long newCnt = --refCnt; checkState(newCnt != -1, "Cannot return never pooled childPolicyWrapper"); if (newCnt == 0) { childPolicyWrapper.shutdown(); @@ -447,7 +452,7 @@ static RefCountedChildPolicyWrapper of(ChildPolicyWrapper childPolicyWrapper) { public String toString() { return MoreObjects.toStringHelper(this) .add("object", childPolicyWrapper) - .add("refCnt", refCnt.get()) + .add("refCnt", refCnt) .toString(); } } diff --git a/rls/src/main/java/io/grpc/rls/LinkedHashLruCache.java b/rls/src/main/java/io/grpc/rls/LinkedHashLruCache.java index 5a4a2dab452..9a961759693 100644 --- a/rls/src/main/java/io/grpc/rls/LinkedHashLruCache.java +++ b/rls/src/main/java/io/grpc/rls/LinkedHashLruCache.java @@ -22,6 +22,7 @@ import com.google.common.base.MoreObjects; import com.google.common.base.Ticker; +import com.google.errorprone.annotations.CheckReturnValue; import java.util.ArrayList; import java.util.Collections; import java.util.Iterator; @@ -29,47 +30,31 @@ import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.ScheduledFuture; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicLong; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; /** * A LinkedHashLruCache implements least recently used caching where it supports access order lru * cache eviction while allowing entry level expiration time. When the cache reaches max capacity, * LruCache try to remove up to one already expired entries. If it doesn't find any expired entries, - * it will remove based on access order of entry. On top of this, LruCache also proactively removes - * expired entries based on configured time interval. + * it will remove based on access order of entry. To proactively clean up expired entries, call + * {@link #cleanupExpiredEntries()} (e.g., via a recurring timer). */ -@ThreadSafe abstract class LinkedHashLruCache implements LruCache { - private final Object lock; - - @GuardedBy("lock") private final LinkedHashMap delegate; - private final PeriodicCleaner periodicCleaner; private final Ticker ticker; - private final EvictionListener evictionListener; - private final AtomicLong estimatedSizeBytes = new AtomicLong(); + @Nullable + private final EvictionListener evictionListener; + private long estimatedSizeBytes; private long estimatedMaxSizeBytes; LinkedHashLruCache( final long estimatedMaxSizeBytes, @Nullable final EvictionListener evictionListener, - int cleaningInterval, - TimeUnit cleaningIntervalUnit, - ScheduledExecutorService ses, - final Ticker ticker, - Object lock) { + final Ticker ticker) { checkState(estimatedMaxSizeBytes > 0, "max estimated cache size should be positive"); this.estimatedMaxSizeBytes = estimatedMaxSizeBytes; - this.lock = checkNotNull(lock, "lock"); - this.evictionListener = new SizeHandlingEvictionListener(evictionListener); + this.evictionListener = evictionListener; this.ticker = checkNotNull(ticker, "ticker"); delegate = new LinkedHashMap( // rough estimate or minimum hashmap default @@ -78,15 +63,15 @@ abstract class LinkedHashLruCache implements LruCache { /* accessOrder= */ true) { @Override protected boolean removeEldestEntry(Map.Entry eldest) { - if (estimatedSizeBytes.get() <= LinkedHashLruCache.this.estimatedMaxSizeBytes) { + if (estimatedSizeBytes <= LinkedHashLruCache.this.estimatedMaxSizeBytes) { return false; } // first, remove at most 1 expired entry boolean removed = cleanupExpiredEntries(1, ticker.read()); // handles size based eviction if necessary no expired entry - boolean shouldRemove = - !removed && shouldInvalidateEldestEntry(eldest.getKey(), eldest.getValue().value); + boolean shouldRemove = !removed + && shouldInvalidateEldestEntry(eldest.getKey(), eldest.getValue().value, ticker.read()); if (shouldRemove) { // remove entry by us to make sure lruIterator and cache is in sync LinkedHashLruCache.this.invalidate(eldest.getKey(), EvictionType.SIZE); @@ -94,7 +79,6 @@ protected boolean removeEldestEntry(Map.Entry eldest) { return false; } }; - periodicCleaner = new PeriodicCleaner(ses, cleaningInterval, cleaningIntervalUnit).start(); } /** @@ -102,7 +86,7 @@ protected boolean removeEldestEntry(Map.Entry eldest) { * that LruCache is access level and the eldest is determined by access pattern. */ @SuppressWarnings("unused") - protected boolean shouldInvalidateEldestEntry(K eldestKey, V eldestValue) { + protected boolean shouldInvalidateEldestEntry(K eldestKey, V eldestValue, long now) { return true; } @@ -124,16 +108,14 @@ protected long estimatedMaxSizeBytes() { /** Updates size for given key if entry exists. It is useful if the cache value is mutated. */ public void updateEntrySize(K key) { - synchronized (lock) { - SizedValue entry = readInternal(key); - if (entry == null) { - return; - } - int prevSize = entry.size; - int newSize = estimateSizeOf(key, entry.value); - entry.size = newSize; - estimatedSizeBytes.addAndGet(newSize - prevSize); + SizedValue entry = readInternal(key); + if (entry == null) { + return; } + int prevSize = entry.size; + int newSize = estimateSizeOf(key, entry.value); + entry.size = newSize; + estimatedSizeBytes += newSize - prevSize; } /** @@ -141,7 +123,7 @@ public void updateEntrySize(K key) { * #estimateSizeOf(java.lang.Object, java.lang.Object)}. */ public long estimatedSizeBytes() { - return estimatedSizeBytes.get(); + return estimatedSizeBytes; } @Override @@ -151,12 +133,10 @@ public final V cache(K key, V value) { checkNotNull(value, "value"); SizedValue existing; int size = estimateSizeOf(key, value); - synchronized (lock) { - estimatedSizeBytes.addAndGet(size); - existing = delegate.put(key, new SizedValue(size, value)); - if (existing != null) { - evictionListener.onEviction(key, existing, EvictionType.REPLACED); - } + estimatedSizeBytes += size; + existing = delegate.put(key, new SizedValue(size, value)); + if (existing != null) { + fireOnEviction(key, existing, EvictionType.REPLACED); } return existing == null ? null : existing.value; } @@ -176,13 +156,11 @@ public final V read(K key) { @CheckReturnValue private SizedValue readInternal(K key) { checkNotNull(key, "key"); - synchronized (lock) { - SizedValue existing = delegate.get(key); - if (existing != null && isExpired(key, existing.value, ticker.read())) { - return null; - } - return existing; + SizedValue existing = delegate.get(key); + if (existing != null && isExpired(key, existing.value, ticker.read())) { + return null; } + return existing; } @Override @@ -195,26 +173,22 @@ public final V invalidate(K key) { private V invalidate(K key, EvictionType cause) { checkNotNull(key, "key"); checkNotNull(cause, "cause"); - synchronized (lock) { - SizedValue existing = delegate.remove(key); - if (existing != null) { - evictionListener.onEviction(key, existing, cause); - } - return existing == null ? null : existing.value; + SizedValue existing = delegate.remove(key); + if (existing != null) { + fireOnEviction(key, existing, cause); } + return existing == null ? null : existing.value; } @Override public final void invalidateAll() { - synchronized (lock) { - Iterator> iterator = delegate.entrySet().iterator(); - while (iterator.hasNext()) { - Map.Entry entry = iterator.next(); - if (entry.getValue() != null) { - evictionListener.onEviction(entry.getKey(), entry.getValue(), EvictionType.EXPLICIT); - } - iterator.remove(); + Iterator> iterator = delegate.entrySet().iterator(); + while (iterator.hasNext()) { + Map.Entry entry = iterator.next(); + if (entry.getValue() != null) { + fireOnEviction(entry.getKey(), entry.getValue(), EvictionType.EXPLICIT); } + iterator.remove(); } } @@ -227,17 +201,11 @@ public final boolean hasCacheEntry(K key) { /** Returns shallow copied values in the cache. */ public final List values() { - synchronized (lock) { - List list = new ArrayList<>(delegate.size()); - for (SizedValue value : delegate.values()) { - list.add(value.value); - } - return Collections.unmodifiableList(list); + List list = new ArrayList<>(delegate.size()); + for (SizedValue value : delegate.values()) { + list.add(value.value); } - } - - protected long now() { - return ticker.read(); + return Collections.unmodifiableList(list); } /** @@ -247,26 +215,24 @@ protected long now() { */ protected final boolean fitToLimit() { boolean removedAnyUnexpired = false; - synchronized (lock) { - if (estimatedSizeBytes.get() <= estimatedMaxSizeBytes) { - // new size is larger no need to do cleanup - return false; - } - // cleanup expired entries - cleanupExpiredEntries(now()); - - // cleanup eldest entry until new size limit - Iterator> lruIter = delegate.entrySet().iterator(); - while (lruIter.hasNext() && estimatedMaxSizeBytes < this.estimatedSizeBytes.get()) { - Map.Entry entry = lruIter.next(); - if (!shouldInvalidateEldestEntry(entry.getKey(), entry.getValue().value)) { - break; // Violates some constraint like minimum age so stop our cleanup - } - lruIter.remove(); - // eviction listener will update the estimatedSizeBytes - evictionListener.onEviction(entry.getKey(), entry.getValue(), EvictionType.SIZE); - removedAnyUnexpired = true; + if (estimatedSizeBytes <= estimatedMaxSizeBytes) { + return false; + } + // cleanup expired entries + long now = ticker.read(); + cleanupExpiredEntries(now); + + // cleanup eldest entry until the size of all entries fits within the limit + Iterator> lruIter = delegate.entrySet().iterator(); + while (lruIter.hasNext() && estimatedMaxSizeBytes < this.estimatedSizeBytes) { + Map.Entry entry = lruIter.next(); + if (!shouldInvalidateEldestEntry(entry.getKey(), entry.getValue().value, now)) { + break; // Violates some constraint like minimum age so stop our cleanup } + lruIter.remove(); + // fireOnEviction will update the estimatedSizeBytes + fireOnEviction(entry.getKey(), entry.getValue(), EvictionType.SIZE); + removedAnyUnexpired = true; } return removedAnyUnexpired; } @@ -276,18 +242,19 @@ protected final boolean fitToLimit() { * removing expired entries and removing oldest entries by LRU order. */ public final void resize(long newSizeBytes) { - synchronized (lock) { - this.estimatedMaxSizeBytes = newSizeBytes; - fitToLimit(); - } + this.estimatedMaxSizeBytes = newSizeBytes; + fitToLimit(); } @Override @CheckReturnValue public final int estimatedSize() { - synchronized (lock) { - return delegate.size(); - } + return delegate.size(); + } + + /** Returns {@code true} if any entries were removed. */ + public final boolean cleanupExpiredEntries() { + return cleanupExpiredEntries(ticker.read()); } private boolean cleanupExpiredEntries(long now) { @@ -298,16 +265,14 @@ private boolean cleanupExpiredEntries(long now) { private boolean cleanupExpiredEntries(int maxExpiredEntries, long now) { checkArgument(maxExpiredEntries > 0, "maxExpiredEntries must be positive"); boolean removedAny = false; - synchronized (lock) { - Iterator> lruIter = delegate.entrySet().iterator(); - while (lruIter.hasNext() && maxExpiredEntries > 0) { - Map.Entry entry = lruIter.next(); - if (isExpired(entry.getKey(), entry.getValue().value, now)) { - lruIter.remove(); - evictionListener.onEviction(entry.getKey(), entry.getValue(), EvictionType.EXPIRED); - removedAny = true; - maxExpiredEntries--; - } + Iterator> lruIter = delegate.entrySet().iterator(); + while (lruIter.hasNext() && maxExpiredEntries > 0) { + Map.Entry entry = lruIter.next(); + if (isExpired(entry.getKey(), entry.getValue().value, now)) { + lruIter.remove(); + fireOnEviction(entry.getKey(), entry.getValue(), EvictionType.EXPIRED); + removedAny = true; + maxExpiredEntries--; } } return removedAny; @@ -315,65 +280,13 @@ private boolean cleanupExpiredEntries(int maxExpiredEntries, long now) { @Override public final void close() { - synchronized (lock) { - periodicCleaner.stop(); - invalidateAll(); - } - } - - /** Periodically cleans up the AsyncRequestCache. */ - private final class PeriodicCleaner { - - private final ScheduledExecutorService ses; - private final int interval; - private final TimeUnit intervalUnit; - private ScheduledFuture scheduledFuture; - - PeriodicCleaner(ScheduledExecutorService ses, int interval, TimeUnit intervalUnit) { - this.ses = checkNotNull(ses, "ses"); - checkState(interval > 0, "interval must be positive"); - this.interval = interval; - this.intervalUnit = checkNotNull(intervalUnit, "intervalUnit"); - } - - PeriodicCleaner start() { - checkState(scheduledFuture == null, "cleaning task can be started only once"); - this.scheduledFuture = - ses.scheduleAtFixedRate(new CleaningTask(), interval, interval, intervalUnit); - return this; - } - - void stop() { - if (scheduledFuture != null) { - scheduledFuture.cancel(false); - scheduledFuture = null; - } - } - - private class CleaningTask implements Runnable { - - @Override - public void run() { - cleanupExpiredEntries(ticker.read()); - } - } + invalidateAll(); } - /** A {@link EvictionListener} keeps track of size. */ - private final class SizeHandlingEvictionListener implements EvictionListener { - - private final EvictionListener delegate; - - SizeHandlingEvictionListener(@Nullable EvictionListener delegate) { - this.delegate = delegate; - } - - @Override - public void onEviction(K key, SizedValue value, EvictionType cause) { - estimatedSizeBytes.addAndGet(-1L * value.size); - if (delegate != null) { - delegate.onEviction(key, value.value, cause); - } + private void fireOnEviction(K key, SizedValue value, EvictionType cause) { + estimatedSizeBytes -= value.size; + if (evictionListener != null) { + evictionListener.onEviction(key, value.value, cause); } } diff --git a/rls/src/main/java/io/grpc/rls/LruCache.java b/rls/src/main/java/io/grpc/rls/LruCache.java index 1ad5a958289..8fc4ae98472 100644 --- a/rls/src/main/java/io/grpc/rls/LruCache.java +++ b/rls/src/main/java/io/grpc/rls/LruCache.java @@ -16,7 +16,7 @@ package io.grpc.rls; -import javax.annotation.CheckReturnValue; +import com.google.errorprone.annotations.CheckReturnValue; import javax.annotation.Nullable; /** An LruCache is a cache with least recently used eviction. */ diff --git a/rls/src/main/java/io/grpc/rls/RlsLoadBalancer.java b/rls/src/main/java/io/grpc/rls/RlsLoadBalancer.java index d1e537f1482..848199f50a8 100644 --- a/rls/src/main/java/io/grpc/rls/RlsLoadBalancer.java +++ b/rls/src/main/java/io/grpc/rls/RlsLoadBalancer.java @@ -19,7 +19,6 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.MoreObjects; import io.grpc.ChannelLogger; import io.grpc.ChannelLogger.ChannelLogLevel; import io.grpc.ConnectivityState; @@ -50,12 +49,11 @@ final class RlsLoadBalancer extends LoadBalancer { @Override public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { - logger.log(ChannelLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses); LbPolicyConfiguration lbPolicyConfiguration = (LbPolicyConfiguration) resolvedAddresses.getLoadBalancingPolicyConfig(); checkNotNull(lbPolicyConfiguration, "Missing RLS LB config"); if (!lbPolicyConfiguration.equals(this.lbPolicyConfiguration)) { - logger.log(ChannelLogLevel.DEBUG, "A new RLS LB config received"); + logger.log(ChannelLogLevel.DEBUG, "A new RLS LB config received: {0}", lbPolicyConfiguration); boolean needToConnect = this.lbPolicyConfiguration == null || !this.lbPolicyConfiguration.getRouteLookupConfig().lookupService().equals( lbPolicyConfiguration.getRouteLookupConfig().lookupService()); @@ -80,50 +78,32 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { // not required. this.lbPolicyConfiguration = lbPolicyConfiguration; } - logger.log(ChannelLogLevel.DEBUG, "RLS LB accepted resolved addresses successfully"); - return Status.OK; + return routeLookupClient.acceptResolvedAddressFactory( + new ChildLbResolvedAddressFactory( + resolvedAddresses.getAddresses(), resolvedAddresses.getAttributes())); } @Override public void requestConnection() { - logger.log(ChannelLogLevel.DEBUG, "connection requested from RLS LB"); if (routeLookupClient != null) { - logger.log(ChannelLogLevel.DEBUG, "requesting a connection from the routeLookupClient"); routeLookupClient.requestConnection(); } } @Override public void handleNameResolutionError(final Status error) { - logger.log(ChannelLogLevel.DEBUG, "Received resolution error: {0}", error); - class ErrorPicker extends SubchannelPicker { - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return PickResult.withError(error); - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(this) - .add("error", error) - .toString(); - } - } - if (routeLookupClient != null) { logger.log(ChannelLogLevel.DEBUG, "closing the routeLookupClient on a name resolution error"); routeLookupClient.close(); routeLookupClient = null; lbPolicyConfiguration = null; } - logger.log(ChannelLogLevel.DEBUG, - "Updating balancing state to TRANSIENT_FAILURE with an error picker"); - helper.updateBalancingState(ConnectivityState.TRANSIENT_FAILURE, new ErrorPicker()); + helper.updateBalancingState( + ConnectivityState.TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(error))); } @Override public void shutdown() { - logger.log(ChannelLogLevel.DEBUG, "Rls lb shutdown"); if (routeLookupClient != null) { logger.log(ChannelLogLevel.DEBUG, "closing the routeLookupClient because of RLS LB shutdown"); routeLookupClient.close(); diff --git a/rls/src/main/java/io/grpc/rls/RlsProtoConverters.java b/rls/src/main/java/io/grpc/rls/RlsProtoConverters.java index cd164f5e2a7..70f9fb4d891 100644 --- a/rls/src/main/java/io/grpc/rls/RlsProtoConverters.java +++ b/rls/src/main/java/io/grpc/rls/RlsProtoConverters.java @@ -64,7 +64,9 @@ static final class RouteLookupRequestConverter @Override protected RlsProtoData.RouteLookupRequest doForward(RouteLookupRequest routeLookupRequest) { return RlsProtoData.RouteLookupRequest.create( - ImmutableMap.copyOf(routeLookupRequest.getKeyMapMap())); + ImmutableMap.copyOf(routeLookupRequest.getKeyMapMap()), + RlsProtoData.RouteLookupRequest.Reason.valueOf(routeLookupRequest.getReason().name()) + ); } @Override @@ -72,6 +74,7 @@ protected RouteLookupRequest doBackward(RlsProtoData.RouteLookupRequest routeLoo return RouteLookupRequest.newBuilder() .setTargetType("grpc") + .setReason(RouteLookupRequest.Reason.valueOf(routeLookupRequest.reason().name())) .putAllKeyMap(routeLookupRequest.keyMap()) .build(); } @@ -152,10 +155,15 @@ protected RouteLookupConfig doForward(Map json) { checkArgument(staleAge == null, "to specify staleAge, must have maxAge"); maxAge = MAX_AGE_NANOS; } - if (staleAge == null) { + // If staleAge is not set, clamp maxAge to <= 5. + if (staleAge == null && maxAge > MAX_AGE_NANOS) { + maxAge = MAX_AGE_NANOS; + } + // Clamp staleAge to <= 5 + if (staleAge == null || staleAge > MAX_AGE_NANOS) { staleAge = MAX_AGE_NANOS; } - maxAge = Math.min(maxAge, MAX_AGE_NANOS); + // Ignore staleAge if greater than maxAge. staleAge = Math.min(staleAge, maxAge); long cacheSize = orDefault(JsonUtil.getNumberAsLong(json, "cacheSizeBytes"), MAX_CACHE_SIZE); checkArgument(cacheSize > 0, "cacheSize must be positive"); diff --git a/rls/src/main/java/io/grpc/rls/RlsProtoData.java b/rls/src/main/java/io/grpc/rls/RlsProtoData.java index 49f32c6b6e3..39c404870f9 100644 --- a/rls/src/main/java/io/grpc/rls/RlsProtoData.java +++ b/rls/src/main/java/io/grpc/rls/RlsProtoData.java @@ -27,16 +27,42 @@ final class RlsProtoData { private RlsProtoData() {} + /** A key object for the Rls route lookup data cache. */ + @AutoValue + @Immutable + abstract static class RouteLookupRequestKey { + + /** Returns a map of key values extracted via key builders for the gRPC or HTTP request. */ + abstract ImmutableMap keyMap(); + + static RouteLookupRequestKey create(ImmutableMap keyMap) { + return new AutoValue_RlsProtoData_RouteLookupRequestKey(keyMap); + } + } + /** A request object sent to route lookup service. */ @AutoValue @Immutable abstract static class RouteLookupRequest { + /** Names should match those in {@link io.grpc.lookup.v1.RouteLookupRequest.Reason}. */ + enum Reason { + /** Unused. */ + REASON_UNKNOWN, + /** No data available in local cache. */ + REASON_MISS, + /** Data in local cache is stale. */ + REASON_STALE; + } + + /** Reason for making this request. */ + abstract Reason reason(); + /** Returns a map of key values extracted via key builders for the gRPC or HTTP request. */ abstract ImmutableMap keyMap(); - static RouteLookupRequest create(ImmutableMap keyMap) { - return new AutoValue_RlsProtoData_RouteLookupRequest(keyMap); + static RouteLookupRequest create(ImmutableMap keyMap, Reason reason) { + return new AutoValue_RlsProtoData_RouteLookupRequest(reason, keyMap); } } diff --git a/rls/src/main/java/io/grpc/rls/RlsRequestFactory.java b/rls/src/main/java/io/grpc/rls/RlsRequestFactory.java index a6ca0137ff1..1fed78f4df3 100644 --- a/rls/src/main/java/io/grpc/rls/RlsRequestFactory.java +++ b/rls/src/main/java/io/grpc/rls/RlsRequestFactory.java @@ -20,20 +20,20 @@ import com.google.common.base.MoreObjects; import com.google.common.collect.ImmutableMap; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.Metadata; import io.grpc.rls.RlsProtoData.ExtraKeys; import io.grpc.rls.RlsProtoData.GrpcKeyBuilder; import io.grpc.rls.RlsProtoData.GrpcKeyBuilder.Name; import io.grpc.rls.RlsProtoData.NameMatcher; import io.grpc.rls.RlsProtoData.RouteLookupConfig; -import io.grpc.rls.RlsProtoData.RouteLookupRequest; +import io.grpc.rls.RlsProtoData.RouteLookupRequestKey; import java.util.HashMap; import java.util.List; import java.util.Map; -import javax.annotation.CheckReturnValue; /** - * A RlsRequestFactory creates {@link RouteLookupRequest} using key builder map from {@link + * A RlsRequestFactory creates {@link RouteLookupRequestKey} using key builder map from {@link * RouteLookupConfig}. */ final class RlsRequestFactory { @@ -61,9 +61,9 @@ private static Map createKeyBuilderTable( return table; } - /** Creates a {@link RouteLookupRequest} for given request's metadata. */ + /** Creates a {@link RouteLookupRequestKey} for the given request lookup metadata. */ @CheckReturnValue - RouteLookupRequest create(String service, String method, Metadata metadata) { + RouteLookupRequestKey create(String service, String method, Metadata metadata) { checkNotNull(service, "service"); checkNotNull(method, "method"); String path = "/" + service + "/" + method; @@ -73,7 +73,7 @@ RouteLookupRequest create(String service, String method, Metadata metadata) { grpcKeyBuilder = keyBuilderTable.get("/" + service + "/*"); } if (grpcKeyBuilder == null) { - return RouteLookupRequest.create(ImmutableMap.of()); + return RouteLookupRequestKey.create(ImmutableMap.of()); } ImmutableMap.Builder rlsRequestHeaders = createRequestHeaders(metadata, grpcKeyBuilder.headers()); @@ -89,7 +89,7 @@ RouteLookupRequest create(String service, String method, Metadata metadata) { rlsRequestHeaders.put(extraKeys.method(), method); } rlsRequestHeaders.putAll(constantKeys); - return RouteLookupRequest.create(rlsRequestHeaders.buildOrThrow()); + return RouteLookupRequestKey.create(rlsRequestHeaders.buildOrThrow()); } private ImmutableMap.Builder createRequestHeaders( diff --git a/rls/src/main/java/io/grpc/rls/Throttler.java b/rls/src/main/java/io/grpc/rls/Throttler.java index 08f54c2e1b3..96d17e70adf 100644 --- a/rls/src/main/java/io/grpc/rls/Throttler.java +++ b/rls/src/main/java/io/grpc/rls/Throttler.java @@ -42,27 +42,4 @@ interface Throttler { * @param throttled specifies whether the request was throttled by the backend. */ void registerBackendResponse(boolean throttled); - - /** - * A ThrottledException indicates the call is throttled. This exception is meant to be used by - * caller of {@link Throttler}, the implementation of Throttler should not throw - * this exception when {@link #shouldThrottle()} is called. - */ - final class ThrottledException extends RuntimeException { - - static final long serialVersionUID = 1L; - - public ThrottledException() { - super(); - } - - public ThrottledException(String s) { - super(s); - } - - @Override - public synchronized Throwable fillInStackTrace() { - return this; - } - } } diff --git a/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java b/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java index 61cf4023779..b349aecdbf3 100644 --- a/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java +++ b/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java @@ -23,12 +23,15 @@ import static io.grpc.rls.CachingRlsLbClient.RLS_DATA_KEY; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertSame; +import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import com.google.common.base.Converter; import com.google.common.collect.ImmutableList; @@ -43,13 +46,20 @@ import io.grpc.ForwardingChannelBuilder2; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.PickDetailsConsumer; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancerProvider; +import io.grpc.LongGaugeMetricInstrument; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; import io.grpc.Metadata; +import io.grpc.MetricRecorder; +import io.grpc.MetricRecorder.BatchCallback; +import io.grpc.MetricRecorder.BatchRecorder; +import io.grpc.MetricRecorder.Registration; import io.grpc.NameResolver.ConfigOrError; +import io.grpc.Server; import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.SynchronizationContext; @@ -57,7 +67,10 @@ import io.grpc.inprocess.InProcessServerBuilder; import io.grpc.internal.BackoffPolicy; import io.grpc.internal.FakeClock; +import io.grpc.internal.GrpcUtil; +import io.grpc.internal.ObjectPool; import io.grpc.internal.PickSubchannelArgsImpl; +import io.grpc.internal.SharedResourcePool; import io.grpc.lookup.v1.RouteLookupServiceGrpc; import io.grpc.rls.CachingRlsLbClient.CacheEntry; import io.grpc.rls.CachingRlsLbClient.CachedRouteLookupResponse; @@ -87,18 +100,23 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; import javax.annotation.Nonnull; import org.junit.After; +import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -import org.mockito.AdditionalAnswers; import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatcher; +import org.mockito.Captor; import org.mockito.InOrder; import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; @@ -117,9 +135,18 @@ public class CachingRlsLbClientTest { public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule(); @Mock - private EvictionListener evictionListener; + private EvictionListener evictionListener; @Mock private SocketAddress socketAddress; + @Mock + private MetricRecorder mockMetricRecorder; + @Mock + private BatchRecorder mockBatchRecorder; + @Mock + private Registration mockGaugeRegistration; + @Captor + private ArgumentCaptor gaugeBatchCallbackCaptor; + private final SynchronizationContext syncContext = new SynchronizationContext(new UncaughtExceptionHandler() { @@ -140,8 +167,9 @@ public void uncaughtException(Thread t, Throwable e) { fakeClock.getScheduledExecutorService()); private final ChildLoadBalancingPolicy childLbPolicy = new ChildLoadBalancingPolicy("target", Collections.emptyMap(), lbProvider); + private final FakeHelper fakeHelper = new FakeHelper(); private final Helper helper = - mock(Helper.class, AdditionalAnswers.delegatesTo(new FakeHelper())); + mock(Helper.class, delegatesTo(fakeHelper)); private final FakeThrottler fakeThrottler = new FakeThrottler(); private final LbPolicyConfiguration lbPolicyConfiguration = new LbPolicyConfiguration(ROUTE_LOOKUP_CONFIG, null, childLbPolicy); @@ -164,23 +192,30 @@ private void setUpRlsLbClient() { .build(); } + @Before + public void setUpMockMetricRecorder() { + when(mockMetricRecorder.registerBatchCallback(any(), any())).thenReturn(mockGaugeRegistration); + } + @After public void tearDown() throws Exception { - rlsLbClient.close(); + if (rlsLbClient != null) { + rlsLbClient.close(); + } assertWithMessage( "On client shut down, RlsLoadBalancer must shut down with all its child loadbalancers.") .that(lbProvider.loadBalancers).isEmpty(); } private CachedRouteLookupResponse getInSyncContext( - final RouteLookupRequest request) + final RlsProtoData.RouteLookupRequestKey routeLookupRequestKey) throws ExecutionException, InterruptedException, TimeoutException { final SettableFuture responseSettableFuture = SettableFuture.create(); syncContext.execute(new Runnable() { @Override public void run() { - responseSettableFuture.set(rlsLbClient.get(request)); + responseSettableFuture.set(rlsLbClient.get(routeLookupRequestKey)); } }); return responseSettableFuture.get(5, TimeUnit.SECONDS); @@ -190,48 +225,53 @@ public void run() { public void get_noError_lifeCycle() throws Exception { setUpRlsLbClient(); InOrder inOrder = inOrder(evictionListener); - RouteLookupRequest routeLookupRequest = RouteLookupRequest.create(ImmutableMap.of( - "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey = + RlsProtoData.RouteLookupRequestKey.create( + ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); rlsServerImpl.setLookupTable( ImmutableMap.of( - routeLookupRequest, + routeLookupRequestKey, RouteLookupResponse.create(ImmutableList.of("target"), "header"))); // initial request - CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequest); + CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequestKey); assertThat(resp.isPending()).isTrue(); // server response fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); - resp = getInSyncContext(routeLookupRequest); + resp = getInSyncContext(routeLookupRequestKey); assertThat(resp.hasData()).isTrue(); // cache hit for staled entry fakeClock.forwardTime(ROUTE_LOOKUP_CONFIG.staleAgeInNanos(), TimeUnit.NANOSECONDS); - resp = getInSyncContext(routeLookupRequest); + rlsServerImpl.routeLookupReason = null; + resp = getInSyncContext(routeLookupRequestKey); assertThat(resp.hasData()).isTrue(); // async refresh finishes fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); inOrder .verify(evictionListener) - .onEviction(eq(routeLookupRequest), any(CacheEntry.class), eq(EvictionType.REPLACED)); + .onEviction(eq(routeLookupRequestKey), any(CacheEntry.class), eq(EvictionType.REPLACED)); - resp = getInSyncContext(routeLookupRequest); + resp = getInSyncContext(routeLookupRequestKey); + assertThat(rlsServerImpl.routeLookupReason).isEqualTo( + io.grpc.lookup.v1.RouteLookupRequest.Reason.REASON_STALE); assertThat(resp.hasData()).isTrue(); // existing cache expired fakeClock.forwardTime(ROUTE_LOOKUP_CONFIG.maxAgeInNanos(), TimeUnit.NANOSECONDS); - resp = getInSyncContext(routeLookupRequest); + resp = getInSyncContext(routeLookupRequestKey); assertThat(resp.isPending()).isTrue(); inOrder .verify(evictionListener) - .onEviction(eq(routeLookupRequest), any(CacheEntry.class), eq(EvictionType.EXPIRED)); + .onEviction(eq(routeLookupRequestKey), any(CacheEntry.class), eq(EvictionType.EXPIRED)); inOrder.verifyNoMoreInteractions(); } @@ -260,99 +300,275 @@ public void rls_withCustomRlsChannelServiceConfig() throws Exception { .setThrottler(fakeThrottler) .setTicker(fakeClock.getTicker()) .build(); - RouteLookupRequest routeLookupRequest = RouteLookupRequest.create(ImmutableMap.of( - "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey = + RlsProtoData.RouteLookupRequestKey.create( + ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); rlsServerImpl.setLookupTable( ImmutableMap.of( - routeLookupRequest, + routeLookupRequestKey, RouteLookupResponse.create(ImmutableList.of("target"), "header"))); + rlsServerImpl.routeLookupReason = null; // initial request - CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequest); + CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequestKey); assertThat(resp.isPending()).isTrue(); // server response fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); - resp = getInSyncContext(routeLookupRequest); + resp = getInSyncContext(routeLookupRequestKey); assertThat(resp.hasData()).isTrue(); + assertThat(rlsServerImpl.routeLookupReason).isEqualTo( + io.grpc.lookup.v1.RouteLookupRequest.Reason.REASON_MISS); assertThat(rlsChannelOverriddenAuthority).isEqualTo("bigtable.googleapis.com:443"); assertThat(rlsChannelServiceConfig).isEqualTo(routeLookupChannelServiceConfig); } @Test - public void get_throttledAndRecover() throws Exception { + public void backoffTimerEnd_updatesPicker() throws Exception { setUpRlsLbClient(); - RouteLookupRequest routeLookupRequest = RouteLookupRequest.create(ImmutableMap.of( - "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); + InOrder inOrder = inOrder(helper); + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey = + RlsProtoData.RouteLookupRequestKey.create( + ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); rlsServerImpl.setLookupTable( ImmutableMap.of( - routeLookupRequest, + routeLookupRequestKey, RouteLookupResponse.create(ImmutableList.of("target"), "header"))); fakeThrottler.nextResult = true; fakeBackoffProvider.nextPolicy = createBackoffPolicy(10, TimeUnit.MILLISECONDS); - CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequest); - + CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequestKey); assertThat(resp.hasError()).isTrue(); fakeClock.forwardTime(10, TimeUnit.MILLISECONDS); - // initially backed off entry is backed off again - verify(evictionListener) - .onEviction(eq(routeLookupRequest), any(CacheEntry.class), eq(EvictionType.REPLACED)); + // Assert that Rls LB policy picker was updated which picks the fallback target + ArgumentCaptor pickerCaptor = ArgumentCaptor.forClass(SubchannelPicker.class); + ArgumentCaptor stateCaptor = + ArgumentCaptor.forClass(ConnectivityState.class); + + inOrder.verify(helper, times(3)) + .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); + assertThat(new HashSet<>(pickerCaptor.getAllValues())).hasSize(1); + assertThat(stateCaptor.getAllValues()) + .containsExactly(ConnectivityState.TRANSIENT_FAILURE, ConnectivityState.CONNECTING, + ConnectivityState.CONNECTING); + Metadata headers = new Metadata(); + PickResult pickResult = getPickResultForCreate(pickerCaptor, headers); + assertThat(pickResult.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(pickResult.getStatus().getDescription()).isEqualTo("fallback not available"); + } - resp = getInSyncContext(routeLookupRequest); + @Test + public void get_throttledTwice_usesSameBackoffpolicy() throws Exception { + setUpRlsLbClient(); + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey = + RlsProtoData.RouteLookupRequestKey.create( + ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); + rlsServerImpl.setLookupTable( + ImmutableMap.of( + routeLookupRequestKey, + RouteLookupResponse.create(ImmutableList.of("target"), "header"))); + fakeThrottler.nextResult = true; + fakeBackoffProvider.nextPolicy = createBackoffPolicy(10, TimeUnit.MILLISECONDS); + + CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequestKey); + + assertThat(resp.hasError()).isTrue(); + + fakeClock.forwardTime(10, TimeUnit.MILLISECONDS); + + // Assert that the same backoff policy is still in effect for the cache entry. + // The below provider should not get used, so the back off time will still be set to 10ms. + fakeBackoffProvider.nextPolicy = createBackoffPolicy(20, TimeUnit.MILLISECONDS); + // let it be throttled again + resp = getInSyncContext(routeLookupRequestKey); assertThat(resp.hasError()).isTrue(); - // let it pass throttler + fakeClock.forwardTime(10, TimeUnit.MILLISECONDS); + + // Backoff entry's backoff timer has gone off, so next rpc should not be backed off. fakeThrottler.nextResult = false; + resp = getInSyncContext(routeLookupRequestKey); + assertThat(resp.isPending()).isTrue(); + + rlsServerImpl.routeLookupReason = null; + // server responses + fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); + assertThat(rlsServerImpl.routeLookupReason).isEqualTo( + io.grpc.lookup.v1.RouteLookupRequest.Reason.REASON_MISS); + } + + @Test + public void get_errorResponseTwice_usesSameBackoffPolicy() throws Exception { + setUpRlsLbClient(); + RlsProtoData.RouteLookupRequestKey invalidRouteLookupRequestKey = + RlsProtoData.RouteLookupRequestKey.create(ImmutableMap.of()); + CachedRouteLookupResponse resp = getInSyncContext(invalidRouteLookupRequestKey); + assertThat(resp.isPending()).isTrue(); + fakeBackoffProvider.nextPolicy = createBackoffPolicy(10, TimeUnit.MILLISECONDS); + fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); + assertThat(rlsServerImpl.routeLookupReason).isEqualTo( + io.grpc.lookup.v1.RouteLookupRequest.Reason.REASON_MISS); + + resp = getInSyncContext(invalidRouteLookupRequestKey); + assertThat(resp.hasError()).isTrue(); + + // Backoff time expiry fakeClock.forwardTime(10, TimeUnit.MILLISECONDS); + resp = getInSyncContext(invalidRouteLookupRequestKey); + assertThat(resp.isPending()).isTrue(); + // Assert that the same backoff policy is still in effect for the cache entry. + // The below provider should not get used, so the back off time will still be set to 10ms. + fakeBackoffProvider.nextPolicy = createBackoffPolicy(20, TimeUnit.MILLISECONDS); + // Gets error again and backed off again + fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); - resp = getInSyncContext(routeLookupRequest); + resp = getInSyncContext(invalidRouteLookupRequestKey); + assertThat(resp.hasError()).isTrue(); + // Backoff time expiry + fakeClock.forwardTime(10, TimeUnit.MILLISECONDS); + resp = getInSyncContext(invalidRouteLookupRequestKey); assertThat(resp.isPending()).isTrue(); + rlsServerImpl.routeLookupReason = null; // server responses fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); + assertThat(rlsServerImpl.routeLookupReason).isEqualTo( + io.grpc.lookup.v1.RouteLookupRequest.Reason.REASON_MISS); + } - resp = getInSyncContext(routeLookupRequest); + @Test + public void controlPlaneTransientToReady_backOffEntriesRemovedAndPickerUpdated() + throws Exception { + setUpRlsLbClient(); + InOrder inOrder = inOrder(helper); + final ConnectivityState[] rlsChannelState = new ConnectivityState[1]; + Runnable channelStateListener = new Runnable() { + @Override + public void run() { + rlsChannelState[0] = fakeHelper.oobChannel.getState(false); + fakeHelper.oobChannel.notifyWhenStateChanged(rlsChannelState[0], this); + synchronized (this) { + notify(); + } + } + }; + fakeHelper.oobChannel.notifyWhenStateChanged(fakeHelper.oobChannel.getState(false), + channelStateListener); + + fakeHelper.server.shutdown(); + // Channel goes to IDLE state from the shutdown listener handling. + try { + if (!fakeHelper.server.awaitTermination(10, TimeUnit.SECONDS)) { + fakeHelper.server.shutdownNow(); // Forceful shutdown if graceful timeout expires + } + } catch (InterruptedException e) { + fakeHelper.server.shutdownNow(); + } + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey = + RlsProtoData.RouteLookupRequestKey.create(ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); + // Rls channel will go to TRANSIENT_FAILURE (connection back-off). + CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequestKey); + assertThat(resp.isPending()).isTrue(); + assertThat(rlsChannelState[0]).isEqualTo(ConnectivityState.TRANSIENT_FAILURE); + // Throttle the next rpc call. + fakeThrottler.nextResult = true; + fakeBackoffProvider.nextPolicy = createBackoffPolicy(10, TimeUnit.MILLISECONDS); - assertThat(resp.hasData()).isTrue(); + // Cause two cache misses by using new request keys. This will create back-off Rls cache + // entries. RLS control plane state transitioning to READY should reset both back-offs but + // update picker only once. + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey2 = + RlsProtoData.RouteLookupRequestKey.create(ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "foo2", "method-key", "bar")); + resp = getInSyncContext(routeLookupRequestKey2); + assertThat(resp.hasError()).isTrue(); + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey3 = + RlsProtoData.RouteLookupRequestKey.create(ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "foo3", "method-key", "bar")); + resp = getInSyncContext(routeLookupRequestKey3); + assertThat(resp.hasError()).isTrue(); + + fakeHelper.createServerAndRegister("service1"); + // Wait for Rls control plane channel back-off expiry and its moving to READY + synchronized (channelStateListener) { + channelStateListener.wait(2000); + } + assertThat(rlsChannelState[0]).isEqualTo(ConnectivityState.READY); + final ObjectPool defaultExecutorPool = + SharedResourcePool.forResource(GrpcUtil.SHARED_CHANNEL_EXECUTOR); + AtomicBoolean isSuccess = new AtomicBoolean(false); + ((ExecutorService) defaultExecutorPool.getObject()).submit(() -> { + // Assert that Rls LB policy picker was updated which picks the fallback target + ArgumentCaptor pickerCaptor = + ArgumentCaptor.forClass(SubchannelPicker.class); + ArgumentCaptor stateCaptor = + ArgumentCaptor.forClass(ConnectivityState.class); + + inOrder.verify(helper, times(4)) + .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); + assertThat(new HashSet<>(pickerCaptor.getAllValues())).hasSize(1); + assertThat(stateCaptor.getAllValues()) + .containsExactly(ConnectivityState.TRANSIENT_FAILURE, ConnectivityState.CONNECTING, + ConnectivityState.CONNECTING, ConnectivityState.CONNECTING); + Metadata headers = new Metadata(); + PickResult pickResult = getPickResultForCreate(pickerCaptor, headers); + assertThat(pickResult.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(pickResult.getStatus().getDescription()).isEqualTo("fallback not available"); + isSuccess.set(true); + }).get(); + assertThat(isSuccess.get()).isTrue(); + + fakeThrottler.nextResult = false; + // Rpcs are not backed off now. + assertThat(getInSyncContext(routeLookupRequestKey2).isPending()).isTrue(); + assertThat(getInSyncContext(routeLookupRequestKey3).isPending()).isTrue(); } @Test public void get_updatesLbState() throws Exception { setUpRlsLbClient(); InOrder inOrder = inOrder(helper); - RouteLookupRequest routeLookupRequest = RouteLookupRequest.create(ImmutableMap.of( - "server", "bigtable.googleapis.com", "service-key", "service1", "method-key", "create")); + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey = + RlsProtoData.RouteLookupRequestKey.create( + ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "service1", + "method-key", "create")); rlsServerImpl.setLookupTable( ImmutableMap.of( - routeLookupRequest, + routeLookupRequestKey, RouteLookupResponse.create( ImmutableList.of("primary.cloudbigtable.googleapis.com"), "header-rls-data-value"))); // valid channel - CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequest); + CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequestKey); assertThat(resp.isPending()).isTrue(); fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); - resp = getInSyncContext(routeLookupRequest); + resp = getInSyncContext(routeLookupRequestKey); assertThat(resp.hasData()).isTrue(); ArgumentCaptor pickerCaptor = ArgumentCaptor.forClass(SubchannelPicker.class); ArgumentCaptor stateCaptor = ArgumentCaptor.forClass(ConnectivityState.class); - inOrder.verify(helper, times(2)) + inOrder.verify(helper, times(3)) .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); assertThat(new HashSet<>(pickerCaptor.getAllValues())).hasSize(1); + // TRANSIENT_FAILURE is because the test setup pretends fallback is not available. assertThat(stateCaptor.getAllValues()) - .containsExactly(ConnectivityState.CONNECTING, ConnectivityState.READY); + .containsExactly(ConnectivityState.TRANSIENT_FAILURE, ConnectivityState.CONNECTING, + ConnectivityState.READY); Metadata headers = new Metadata(); PickResult pickResult = getPickResultForCreate(pickerCaptor, headers); assertThat(pickResult.getStatus().isOk()).isTrue(); @@ -364,13 +580,13 @@ public void get_updatesLbState() throws Exception { // move backoff further back to only test error behavior fakeBackoffProvider.nextPolicy = createBackoffPolicy(100, TimeUnit.MILLISECONDS); // try to get invalid - RouteLookupRequest invalidRouteLookupRequest = - RouteLookupRequest.create(ImmutableMap.of()); - CachedRouteLookupResponse errorResp = getInSyncContext(invalidRouteLookupRequest); + RlsProtoData.RouteLookupRequestKey invalidRouteLookupRequestKey = + RlsProtoData.RouteLookupRequestKey.create(ImmutableMap.of()); + CachedRouteLookupResponse errorResp = getInSyncContext(invalidRouteLookupRequestKey); assertThat(errorResp.isPending()).isTrue(); fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); - errorResp = getInSyncContext(invalidRouteLookupRequest); + errorResp = getInSyncContext(invalidRouteLookupRequestKey); assertThat(errorResp.hasError()).isTrue(); // Channel is still READY because the subchannel for method /service1/create is still READY. @@ -383,7 +599,8 @@ public void get_updatesLbState() throws Exception { .setFullMethodName("doesn/exists") .build(), headers, - CallOptions.DEFAULT)); + CallOptions.DEFAULT, + new PickDetailsConsumer() {})); assertThat(pickResult.getStatus().getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(pickResult.getStatus().getDescription()).contains("fallback not available"); assertThat(fakeThrottler.getNumThrottled()).isEqualTo(1); @@ -393,27 +610,30 @@ public void get_updatesLbState() throws Exception { @Test public void timeout_not_changing_picked_subchannel() throws Exception { setUpRlsLbClient(); - RouteLookupRequest routeLookupRequest = RouteLookupRequest.create(ImmutableMap.of( - "server", "bigtable.googleapis.com", "service-key", "service1", "method-key", "create")); + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey = + RlsProtoData.RouteLookupRequestKey.create( + ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "service1", + "method-key", "create")); rlsServerImpl.setLookupTable( ImmutableMap.of( - routeLookupRequest, + routeLookupRequestKey, RouteLookupResponse.create( ImmutableList.of("primary.cloudbigtable.googleapis.com", "target2", "target3"), "header-rls-data-value"))); // valid channel - CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequest); + CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequestKey); assertThat(resp.hasData()).isFalse(); fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); - resp = getInSyncContext(routeLookupRequest); + resp = getInSyncContext(routeLookupRequestKey); assertThat(resp.hasData()).isTrue(); ArgumentCaptor pickerCaptor = ArgumentCaptor.forClass(SubchannelPicker.class); ArgumentCaptor stateCaptor = ArgumentCaptor.forClass(ConnectivityState.class); - verify(helper, times(4)).updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); + verify(helper, times(5)).updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); Metadata headers = new Metadata(); PickResult pickResult = getPickResultForCreate(pickerCaptor, headers); @@ -438,7 +658,8 @@ private static PickResult getPickResultForCreate(ArgumentCaptor pickerCaptor = ArgumentCaptor.forClass(SubchannelPicker.class); ArgumentCaptor stateCaptor = ArgumentCaptor.forClass(ConnectivityState.class); - inOrder.verify(helper, times(2)) + inOrder.verify(helper, times(3)) .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); Metadata headers = new Metadata(); @@ -493,13 +717,13 @@ public void get_withAdaptiveThrottler() throws Exception { // move backoff further back to only test error behavior fakeBackoffProvider.nextPolicy = createBackoffPolicy(100, TimeUnit.MILLISECONDS); // try to get invalid - RouteLookupRequest invalidRouteLookupRequest = - RouteLookupRequest.create(ImmutableMap.of()); - CachedRouteLookupResponse errorResp = getInSyncContext(invalidRouteLookupRequest); + RlsProtoData.RouteLookupRequestKey invalidRouteLookupRequestKey = + RlsProtoData.RouteLookupRequestKey.create(ImmutableMap.of()); + CachedRouteLookupResponse errorResp = getInSyncContext(invalidRouteLookupRequestKey); assertThat(errorResp.isPending()).isTrue(); fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); - errorResp = getInSyncContext(invalidRouteLookupRequest); + errorResp = getInSyncContext(invalidRouteLookupRequestKey); assertThat(errorResp.hasError()).isTrue(); // Channel is still READY because the subchannel for method /service1/create is still READY. @@ -521,29 +745,34 @@ private PickSubchannelArgsImpl getInvalidArgs(Metadata headers) { .setFullMethodName("doesn/exists") .build(), headers, - CallOptions.DEFAULT); + CallOptions.DEFAULT, + new PickDetailsConsumer() {}); return invalidArgs; } @Test public void get_childPolicyWrapper_reusedForSameTarget() throws Exception { setUpRlsLbClient(); - RouteLookupRequest routeLookupRequest = RouteLookupRequest.create(ImmutableMap.of( - "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); - RouteLookupRequest routeLookupRequest2 = RouteLookupRequest.create(ImmutableMap.of( - "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "baz")); + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey = + RlsProtoData.RouteLookupRequestKey.create( + ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey2 = + RlsProtoData.RouteLookupRequestKey.create( + ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "baz")); rlsServerImpl.setLookupTable( ImmutableMap.of( - routeLookupRequest, + routeLookupRequestKey, RouteLookupResponse.create(ImmutableList.of("target"), "header"), - routeLookupRequest2, + routeLookupRequestKey2, RouteLookupResponse.create(ImmutableList.of("target"), "header2"))); - CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequest); + CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequestKey); assertThat(resp.isPending()).isTrue(); fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); - resp = getInSyncContext(routeLookupRequest); + resp = getInSyncContext(routeLookupRequestKey); assertThat(resp.hasData()).isTrue(); assertThat(resp.getHeaderData()).isEqualTo("header"); @@ -553,11 +782,11 @@ public void get_childPolicyWrapper_reusedForSameTarget() throws Exception { assertThat(childPolicyWrapper.getPicker()).isNotInstanceOf(RlsPicker.class); // request2 has same target, it should reuse childPolicyWrapper - CachedRouteLookupResponse resp2 = getInSyncContext(routeLookupRequest2); + CachedRouteLookupResponse resp2 = getInSyncContext(routeLookupRequestKey2); assertThat(resp2.isPending()).isTrue(); fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); - resp2 = getInSyncContext(routeLookupRequest2); + resp2 = getInSyncContext(routeLookupRequestKey2); assertThat(resp2.hasData()).isTrue(); assertThat(resp2.getHeaderData()).isEqualTo("header2"); assertThat(resp2.getChildPolicyWrapper()).isEqualTo(resp.getChildPolicyWrapper()); @@ -566,20 +795,22 @@ public void get_childPolicyWrapper_reusedForSameTarget() throws Exception { @Test public void get_childPolicyWrapper_multiTarget() throws Exception { setUpRlsLbClient(); - RouteLookupRequest routeLookupRequest = RouteLookupRequest.create(ImmutableMap.of( - "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey = + RlsProtoData.RouteLookupRequestKey.create( + ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); rlsServerImpl.setLookupTable( ImmutableMap.of( - routeLookupRequest, + routeLookupRequestKey, RouteLookupResponse.create( ImmutableList.of("target1", "target2", "target3"), "header"))); - CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequest); + CachedRouteLookupResponse resp = getInSyncContext(routeLookupRequestKey); assertThat(resp.isPending()).isTrue(); fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); - resp = getInSyncContext(routeLookupRequest); + resp = getInSyncContext(routeLookupRequestKey); assertThat(resp.hasData()).isTrue(); List policyWrappers = new ArrayList<>(); @@ -629,6 +860,53 @@ private void setState(ChildPolicyWrapper policyWrapper, ConnectivityState newSta policyWrapper.getHelper().updateBalancingState(newState, policyWrapper.getPicker()); } + @Test + public void metricGauges() throws ExecutionException, InterruptedException, TimeoutException { + setUpRlsLbClient(); + + verify(mockMetricRecorder).registerBatchCallback(gaugeBatchCallbackCaptor.capture(), + any()); + + BatchCallback gaugeBatchCallback = gaugeBatchCallbackCaptor.getValue(); + + // Verify the correct cache gauge values when requested at this point. + InOrder inOrder = inOrder(mockBatchRecorder); + gaugeBatchCallback.accept(mockBatchRecorder); + inOrder.verify(mockBatchRecorder).recordLongGauge( + argThat(new LongGaugeInstrumentArgumentMatcher("grpc.lb.rls.cache_entries")), eq(0L), + any(), any()); + inOrder.verify(mockBatchRecorder) + .recordLongGauge(argThat(new LongGaugeInstrumentArgumentMatcher("grpc.lb.rls.cache_size")), + eq(0L), any(), any()); + + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey = + RlsProtoData.RouteLookupRequestKey.create( + ImmutableMap.of("server", "bigtable.googleapis.com", "service-key", "foo", "method-key", + "bar")); + rlsServerImpl.setLookupTable(ImmutableMap.of(routeLookupRequestKey, + RouteLookupResponse.create(ImmutableList.of("target"), "header"))); + + // Make a request that will populate the cache with an entry + getInSyncContext(routeLookupRequestKey); + fakeClock.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); + + // Gauge values should reflect the new cache entry. + gaugeBatchCallback.accept(mockBatchRecorder); + inOrder.verify(mockBatchRecorder).recordLongGauge( + argThat(new LongGaugeInstrumentArgumentMatcher("grpc.lb.rls.cache_entries")), eq(1L), + any(), any()); + inOrder.verify(mockBatchRecorder) + .recordLongGauge(argThat(new LongGaugeInstrumentArgumentMatcher("grpc.lb.rls.cache_size")), + eq(260L), any(), any()); + + inOrder.verifyNoMoreInteractions(); + + // Shutdown + rlsLbClient.close(); + rlsLbClient = null; + verify(mockGaugeRegistration).close(); + } + private static RouteLookupConfig getRouteLookupConfig() { return RouteLookupConfig.builder() .grpcKeybuilders(ImmutableList.of( @@ -660,6 +938,21 @@ public long nextBackoffNanos() { }; } + private static class LongGaugeInstrumentArgumentMatcher implements + ArgumentMatcher { + + private final String instrumentName; + + public LongGaugeInstrumentArgumentMatcher(String instrumentName) { + this.instrumentName = instrumentName; + } + + @Override + public boolean matches(LongGaugeMetricInstrument instrument) { + return instrument.getName().equals(instrumentName); + } + } + private static final class FakeBackoffProvider implements BackoffPolicy.Provider { private BackoffPolicy nextPolicy = createBackoffPolicy(100, TimeUnit.MILLISECONDS); @@ -732,14 +1025,9 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { @Override public void handleNameResolutionError(final Status error) { - class ErrorPicker extends SubchannelPicker { - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return PickResult.withError(error); - } - } - - helper.updateBalancingState(ConnectivityState.TRANSIENT_FAILURE, new ErrorPicker()); + helper.updateBalancingState( + ConnectivityState.TRANSIENT_FAILURE, + new FixedResultPicker(PickResult.withError(error))); } @Override @@ -764,7 +1052,9 @@ private static final class StaticFixedDelayRlsServerImpl private final long responseDelayNano; private final ScheduledExecutorService scheduledExecutorService; - private Map lookupTable = ImmutableMap.of(); + private Map lookupTable = + ImmutableMap.of(); + io.grpc.lookup.v1.RouteLookupRequest.Reason routeLookupReason; public StaticFixedDelayRlsServerImpl( long responseDelayNano, ScheduledExecutorService scheduledExecutorService) { @@ -774,7 +1064,8 @@ public StaticFixedDelayRlsServerImpl( checkNotNull(scheduledExecutorService, "scheduledExecutorService"); } - private void setLookupTable(Map lookupTable) { + private void setLookupTable(Map lookupTable) { this.lookupTable = checkNotNull(lookupTable, "lookupTable"); } @@ -786,8 +1077,11 @@ public void routeLookup(final io.grpc.lookup.v1.RouteLookupRequest request, new Runnable() { @Override public void run() { + routeLookupReason = request.getReason(); RouteLookupResponse response = - lookupTable.get(REQUEST_CONVERTER.convert(request)); + lookupTable.get( + RlsProtoData.RouteLookupRequestKey.create( + REQUEST_CONVERTER.convert(request).keyMap())); if (response == null) { responseObserver.onError(new RuntimeException("not found")); } else { @@ -801,16 +1095,23 @@ public void run() { private final class FakeHelper extends Helper { + Server server; + ManagedChannel oobChannel; + + void createServerAndRegister(String target) throws IOException { + server = InProcessServerBuilder.forName(target) + .addService(rlsServerImpl) + .directExecutor() + .build() + .start(); + grpcCleanupRule.register(server); + } + @Override public ManagedChannelBuilder createResolvingOobChannelBuilder( String target, ChannelCredentials creds) { try { - grpcCleanupRule.register( - InProcessServerBuilder.forName(target) - .addService(rlsServerImpl) - .directExecutor() - .build() - .start()); + createServerAndRegister(target); } catch (IOException e) { throw new RuntimeException("cannot create server: " + target, e); } @@ -826,7 +1127,8 @@ protected ManagedChannelBuilder delegate() { @Override public ManagedChannel build() { - return grpcCleanupRule.register(super.build()); + oobChannel = super.build(); + return grpcCleanupRule.register(oobChannel); } @Override @@ -855,7 +1157,6 @@ public ManagedChannel createOobChannel(EquivalentAddressGroup eag, String author @Override public void updateBalancingState( @Nonnull ConnectivityState newState, @Nonnull SubchannelPicker newPicker) { - // no-op } @Override @@ -888,6 +1189,16 @@ public SynchronizationContext getSynchronizationContext() { public ChannelLogger getChannelLogger() { return mock(ChannelLogger.class); } + + @Override + public MetricRecorder getMetricRecorder() { + return mockMetricRecorder; + } + + @Override + public String getChannelTarget() { + return "channelTarget"; + } } private static final class FakeThrottler implements Throttler { diff --git a/rls/src/test/java/io/grpc/rls/LbPolicyConfigurationTest.java b/rls/src/test/java/io/grpc/rls/LbPolicyConfigurationTest.java index d6025d5bad4..de41d0488fc 100644 --- a/rls/src/test/java/io/grpc/rls/LbPolicyConfigurationTest.java +++ b/rls/src/test/java/io/grpc/rls/LbPolicyConfigurationTest.java @@ -21,6 +21,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -39,14 +40,13 @@ import io.grpc.Status; import io.grpc.SynchronizationContext; import io.grpc.rls.ChildLoadBalancerHelper.ChildLoadBalancerHelperProvider; -import io.grpc.rls.LbPolicyConfiguration.ChildLbStatusListener; import io.grpc.rls.LbPolicyConfiguration.ChildLoadBalancingPolicy; import io.grpc.rls.LbPolicyConfiguration.ChildPolicyWrapper; import io.grpc.rls.LbPolicyConfiguration.ChildPolicyWrapper.ChildPolicyReportingHelper; import io.grpc.rls.LbPolicyConfiguration.InvalidChildPolicyConfigException; import io.grpc.rls.LbPolicyConfiguration.RefCountedChildPolicyWrapperFactory; -import java.lang.Thread.UncaughtExceptionHandler; import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -61,7 +61,9 @@ public class LbPolicyConfigurationTest { private final LoadBalancer lb = mock(LoadBalancer.class); private final SubchannelStateManager subchannelStateManager = new SubchannelStateManagerImpl(); private final SubchannelPicker picker = mock(SubchannelPicker.class); - private final ChildLbStatusListener childLbStatusListener = mock(ChildLbStatusListener.class); + private final SynchronizationContext syncContext = new SynchronizationContext((t, e) -> { + throw new AssertionError(e); + }); private final ResolvedAddressFactory resolvedAddressFactory = new ResolvedAddressFactory() { @Override @@ -78,21 +80,12 @@ public ResolvedAddresses create(Object childLbConfig) { ImmutableMap.of("foo", "bar"), lbProvider), resolvedAddressFactory, - new ChildLoadBalancerHelperProvider(helper, subchannelStateManager, picker), - childLbStatusListener); + new ChildLoadBalancerHelperProvider(helper, subchannelStateManager, picker)); @Before public void setUp() { doReturn(mock(ChannelLogger.class)).when(helper).getChannelLogger(); - doReturn( - new SynchronizationContext( - new UncaughtExceptionHandler() { - @Override - public void uncaughtException(Thread t, Throwable e) { - throw new AssertionError(e); - } - })) - .when(helper).getSynchronizationContext(); + doReturn(syncContext).when(helper).getSynchronizationContext(); doReturn(lb).when(lbProvider).newLoadBalancer(any(Helper.class)); doReturn(ConfigOrError.fromConfig(new Object())) .when(lbProvider).parseLoadBalancingPolicyConfig(ArgumentMatchers.>any()); @@ -185,9 +178,26 @@ public void updateBalancingState_triggersListener() { childPolicyReportingHelper.updateBalancingState(ConnectivityState.READY, childPicker); - verify(childLbStatusListener).onStatusChanged(ConnectivityState.READY); assertThat(childPolicyWrapper.getPicker()).isEqualTo(childPicker); // picker governs childPickers will be reported to parent LB verify(helper).updateBalancingState(ConnectivityState.READY, picker); } + + @Test + public void refCountedGetOrCreate_addsChildBeforeConfiguringChild() { + AtomicBoolean calledAlready = new AtomicBoolean(); + when(lb.acceptResolvedAddresses(any(ResolvedAddresses.class))).thenAnswer(i -> { + if (!calledAlready.get()) { + calledAlready.set(true); + // Should end up calling this function again, as this child should already be added to the + // list of children. In practice, this can be caused by CDS is_dynamic=true starting a watch + // when XdsClient already has the cluster cached (e.g., from another channel). + syncContext.execute(() -> + factory.acceptResolvedAddressFactory(resolvedAddressFactory)); + } + return Status.OK; + }); + ChildPolicyWrapper unused = factory.createOrGet("foo.google.com"); + verify(lb, times(2)).acceptResolvedAddresses(any(ResolvedAddresses.class)); + } } diff --git a/rls/src/test/java/io/grpc/rls/LinkedHashLruCacheTest.java b/rls/src/test/java/io/grpc/rls/LinkedHashLruCacheTest.java index a31f58f5365..23ffe6ec026 100644 --- a/rls/src/test/java/io/grpc/rls/LinkedHashLruCacheTest.java +++ b/rls/src/test/java/io/grpc/rls/LinkedHashLruCacheTest.java @@ -25,8 +25,10 @@ import io.grpc.internal.FakeClock; import io.grpc.rls.LruCache.EvictionListener; import io.grpc.rls.LruCache.EvictionType; +import java.util.Arrays; import java.util.Objects; import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -56,14 +58,10 @@ public void setUp() { this.cache = new LinkedHashLruCache( MAX_SIZE, evictionListener, - 10, - TimeUnit.NANOSECONDS, - fakeClock.getScheduledExecutorService(), - fakeClock.getTicker(), - new Object()) { + fakeClock.getTicker()) { @Override protected boolean isExpired(Integer key, Entry value, long nowNanos) { - return value.expireTime <= nowNanos; + return value.expireTime - nowNanos <= 0; } @Override @@ -107,9 +105,11 @@ public void eviction_expire() { cache.cache(1, survivor); fakeClock.forwardTime(10, TimeUnit.NANOSECONDS); + cache.cleanupExpiredEntries(); verify(evictionListener).onEviction(0, toBeEvicted, EvictionType.EXPIRED); fakeClock.forwardTime(10, TimeUnit.NANOSECONDS); + cache.cleanupExpiredEntries(); verify(evictionListener).onEviction(1, survivor, EvictionType.EXPIRED); } @@ -160,6 +160,7 @@ public void eviction_cleanupShouldRemoveAlreadyExpired() { assertThat(cache.estimatedSize()).isEqualTo(MAX_SIZE); fakeClock.forwardTime(1, TimeUnit.MINUTES); + cache.cleanupExpiredEntries(); assertThat(cache.read(MAX_SIZE)).isNull(); assertThat(cache.estimatedSize()).isEqualTo(MAX_SIZE - 1); verify(evictionListener).onEviction(eq(MAX_SIZE), any(Entry.class), eq(EvictionType.EXPIRED)); @@ -267,4 +268,91 @@ public int hashCode() { return Objects.hash(value, expireTime); } } + + @Test + public void testFitToLimitWithReSize() { + + Entry entry1 = new Entry("Entry1", ticker.read() + 10, 4); + Entry entry2 = new Entry("Entry2", ticker.read() + 20, 1); + Entry entry3 = new Entry("Entry3", ticker.read() + 30, 2); + + cache.cache(1, entry1); + cache.cache(2, entry2); + cache.cache(3, entry3); + + assertThat(cache.estimatedSize()).isEqualTo(2); + assertThat(cache.estimatedSizeBytes()).isEqualTo(3); + assertThat(cache.estimatedMaxSizeBytes()).isEqualTo(5); + + cache.resize(2); + assertThat(cache.estimatedSize()).isEqualTo(1); + assertThat(cache.estimatedSizeBytes()).isEqualTo(2); + assertThat(cache.estimatedMaxSizeBytes()).isEqualTo(2); + + assertThat(cache.fitToLimit()).isEqualTo(false); + } + + @Test + public void testFitToLimit() { + + TestFitToLimitEviction localCache = new TestFitToLimitEviction( + MAX_SIZE, + evictionListener, + fakeClock.getTicker() + ); + + Entry entry1 = new Entry("Entry1", ticker.read() + 10, 4); + Entry entry2 = new Entry("Entry2", ticker.read() + 20, 2); + Entry entry3 = new Entry("Entry3", ticker.read() + 30, 1); + + localCache.cache(1, entry1); + localCache.cache(2, entry2); + localCache.cache(3, entry3); + + assertThat(localCache.estimatedSize()).isEqualTo(3); + assertThat(localCache.estimatedSizeBytes()).isEqualTo(7); + assertThat(localCache.estimatedMaxSizeBytes()).isEqualTo(5); + + localCache.enableEviction(); + + assertThat(localCache.fitToLimit()).isEqualTo(true); + + assertThat(localCache.values().contains(entry1)).isFalse(); + assertThat(localCache.values().containsAll(Arrays.asList(entry2, entry3))).isTrue(); + + assertThat(localCache.estimatedSize()).isEqualTo(2); + assertThat(localCache.estimatedSizeBytes()).isEqualTo(3); + assertThat(localCache.estimatedMaxSizeBytes()).isEqualTo(5); + } + + private static class TestFitToLimitEviction extends LinkedHashLruCache { + + private boolean allowEviction = false; + + TestFitToLimitEviction( + long estimatedMaxSizeBytes, + @Nullable EvictionListener evictionListener, + Ticker ticker) { + super(estimatedMaxSizeBytes, evictionListener, ticker); + } + + @Override + protected boolean isExpired(Integer key, Entry value, long nowNanos) { + return value.expireTime - nowNanos <= 0; + } + + @Override + protected int estimateSizeOf(Integer key, Entry value) { + return value.size; + } + + @Override + protected boolean shouldInvalidateEldestEntry(Integer eldestKey, Entry eldestValue, long now) { + return allowEviction && super.shouldInvalidateEldestEntry(eldestKey, eldestValue, now); + } + + public void enableEviction() { + allowEviction = true; + } + } } diff --git a/rls/src/test/java/io/grpc/rls/RlsLoadBalancerTest.java b/rls/src/test/java/io/grpc/rls/RlsLoadBalancerTest.java index e8a857d884e..a52390743a6 100644 --- a/rls/src/test/java/io/grpc/rls/RlsLoadBalancerTest.java +++ b/rls/src/test/java/io/grpc/rls/RlsLoadBalancerTest.java @@ -18,17 +18,22 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.truth.Truth.assertThat; +import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import com.google.common.base.Converter; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.ChannelCredentials; @@ -37,9 +42,13 @@ import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; import io.grpc.ForwardingChannelBuilder2; +import io.grpc.Grpc; +import io.grpc.InternalManagedChannelBuilder; import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.PickDetailsConsumer; import io.grpc.LoadBalancer.PickResult; +import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.ResolvedAddresses; import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; @@ -48,26 +57,37 @@ import io.grpc.ManagedChannelBuilder; import io.grpc.Metadata; import io.grpc.MethodDescriptor; -import io.grpc.MethodDescriptor.Marshaller; import io.grpc.MethodDescriptor.MethodType; +import io.grpc.MetricInstrument; +import io.grpc.MetricRecorder; +import io.grpc.MetricRecorder.Registration; +import io.grpc.MetricSink; import io.grpc.NameResolver.ConfigOrError; +import io.grpc.NoopMetricSink; +import io.grpc.ServerCall; +import io.grpc.ServerServiceDefinition; import io.grpc.Status; +import io.grpc.Status.Code; import io.grpc.SynchronizationContext; import io.grpc.inprocess.InProcessChannelBuilder; import io.grpc.inprocess.InProcessServerBuilder; import io.grpc.internal.FakeClock; import io.grpc.internal.JsonParser; import io.grpc.internal.PickSubchannelArgsImpl; +import io.grpc.internal.testing.StreamRecorder; import io.grpc.lookup.v1.RouteLookupServiceGrpc; import io.grpc.rls.RlsLoadBalancer.CachingRlsLbClientBuilderProvider; import io.grpc.rls.RlsProtoConverters.RouteLookupResponseConverter; import io.grpc.rls.RlsProtoData.RouteLookupRequest; import io.grpc.rls.RlsProtoData.RouteLookupResponse; +import io.grpc.stub.ClientCalls; import io.grpc.stub.StreamObserver; import io.grpc.testing.GrpcCleanupRule; +import io.grpc.testing.TestMethodDescriptors; import java.io.IOException; import java.lang.Thread.UncaughtExceptionHandler; import java.net.SocketAddress; +import java.util.Arrays; import java.util.Collections; import java.util.Deque; import java.util.LinkedList; @@ -84,6 +104,7 @@ import org.junit.runners.JUnit4; import org.mockito.AdditionalAnswers; import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatcher; import org.mockito.Captor; import org.mockito.InOrder; import org.mockito.Mock; @@ -106,47 +127,58 @@ public void uncaughtException(Thread t, Throwable e) { throw new RuntimeException(e); } }); + @Mock + private MetricRecorder mockMetricRecorder; + @Mock + private Registration mockGaugeRegistration; + private final FakeHelper helperDelegate = new FakeHelper(); private final Helper helper = - mock(Helper.class, AdditionalAnswers.delegatesTo(new FakeHelper())); - private final FakeRlsServerImpl fakeRlsServerImpl = new FakeRlsServerImpl(); + mock(Helper.class, AdditionalAnswers.delegatesTo(helperDelegate)); + private final FakeRlsServerImpl fakeRlsServerImpl = new FakeRlsServerImpl( + fakeClock.getScheduledExecutorService()); private final Deque subchannels = new LinkedList<>(); private final FakeThrottler fakeThrottler = new FakeThrottler(); - @Mock - private Marshaller mockMarshaller; + private final String channelTarget = "channelTarget"; @Captor private ArgumentCaptor pickerCaptor; - private MethodDescriptor fakeSearchMethod; - private MethodDescriptor fakeRescueMethod; + private MethodDescriptor fakeSearchMethod; + private MethodDescriptor fakeRescueMethod; private RlsLoadBalancer rlsLb; private String defaultTarget = "defaultTarget"; + private PickSubchannelArgs searchSubchannelArgs; + private PickSubchannelArgs rescueSubchannelArgs; @Before public void setUp() { fakeSearchMethod = - MethodDescriptor.newBuilder() + MethodDescriptor.newBuilder() .setFullMethodName("com.google/Search") - .setRequestMarshaller(mockMarshaller) - .setResponseMarshaller(mockMarshaller) + .setRequestMarshaller(TestMethodDescriptors.voidMarshaller()) + .setResponseMarshaller(TestMethodDescriptors.voidMarshaller()) .setType(MethodType.CLIENT_STREAMING) .build(); fakeRescueMethod = - MethodDescriptor.newBuilder() + MethodDescriptor.newBuilder() .setFullMethodName("com.google/Rescue") - .setRequestMarshaller(mockMarshaller) - .setResponseMarshaller(mockMarshaller) + .setRequestMarshaller(TestMethodDescriptors.voidMarshaller()) + .setResponseMarshaller(TestMethodDescriptors.voidMarshaller()) .setType(MethodType.UNARY) .build(); fakeRlsServerImpl.setLookupTable( ImmutableMap.of( - RouteLookupRequest.create(ImmutableMap.of( + RouteLookupRequest.create( + ImmutableMap.of( "server", "fake-bigtable.googleapis.com", "service-key", "com.google", - "method-key", "Search")), + "method-key", "Search"), + RouteLookupRequest.Reason.REASON_MISS), RouteLookupResponse.create(ImmutableList.of("wilderness"), "where are you?"), - RouteLookupRequest.create(ImmutableMap.of( + RouteLookupRequest.create( + ImmutableMap.of( "server", "fake-bigtable.googleapis.com", "service-key", "com.google", - "method-key", "Rescue")), + "method-key", "Rescue"), + RouteLookupRequest.Reason.REASON_MISS), RouteLookupResponse.create(ImmutableList.of("civilization"), "you are safe"))); rlsLb = (RlsLoadBalancer) provider.newLoadBalancer(helper); @@ -159,6 +191,11 @@ public CachingRlsLbClient.Builder get() { .setTicker(fakeClock.getTicker()); } }; + + searchSubchannelArgs = newPickSubchannelArgs(fakeSearchMethod); + rescueSubchannelArgs = newPickSubchannelArgs(fakeRescueMethod); + + when(mockMetricRecorder.registerBatchCallback(any(), any())).thenReturn(mockGaugeRegistration); } @After @@ -168,22 +205,34 @@ public void tearDown() { @Test public void lb_serverStatusCodeConversion() throws Exception { - deliverResolvedAddresses(); + helper.getSynchronizationContext().execute(() -> { + try { + deliverResolvedAddresses(); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + assertThat(subchannels.poll()).isNotNull(); // default target + assertThat(subchannels.poll()).isNull(); + // Warm-up pick; will be queued InOrder inOrder = inOrder(helper); inOrder.verify(helper) .updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); SubchannelPicker picker = pickerCaptor.getValue(); - Metadata headers = new Metadata(); - PickSubchannelArgsImpl fakeSearchMethodArgs = - new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT); + PickSubchannelArgs fakeSearchMethodArgs = newPickSubchannelArgs(fakeSearchMethod); PickResult res = picker.pickSubchannel(fakeSearchMethodArgs); - FakeSubchannel subchannel = (FakeSubchannel) res.getSubchannel(); + assertThat(res.getStatus().isOk()).isTrue(); + assertThat(res.getSubchannel()).isNull(); + // Cache is warm, but still unconnected + picker.pickSubchannel(fakeSearchMethodArgs); // Will create the subchannel + FakeSubchannel subchannel = subchannels.peek(); assertThat(subchannel).isNotNull(); // Ensure happy path is unaffected subchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); res = picker.pickSubchannel(fakeSearchMethodArgs); assertThat(res.getStatus().getCode()).isEqualTo(Status.Code.OK); + verifyLongCounterAdd("grpc.lb.rls.target_picks", 1, 1, "wilderness", "complete"); // Check on conversion Throwable cause = new Throwable("cause"); @@ -198,47 +247,57 @@ public void lb_serverStatusCodeConversion() throws Exception { @Test public void lb_working_withDefaultTarget_rlsResponding() throws Exception { - deliverResolvedAddresses(); + helper.getSynchronizationContext().execute(() -> { + try { + deliverResolvedAddresses(); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); InOrder inOrder = inOrder(helper); inOrder.verify(helper) .updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); SubchannelPicker picker = pickerCaptor.getValue(); - Metadata headers = new Metadata(); - PickResult res = picker.pickSubchannel( - new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); + // Warm-up pick; will be queued + PickResult res = picker.pickSubchannel(searchSubchannelArgs); + assertThat(res.getStatus().isOk()).isTrue(); + assertThat(res.getSubchannel()).isNull(); + // Cache is warm, but still unconnected + res = picker.pickSubchannel(searchSubchannelArgs); inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); - inOrder.verify(helper) + inOrder.verify(helper, atLeast(0)) .updateBalancingState(eq(ConnectivityState.CONNECTING), any(SubchannelPicker.class)); + inOrder.verify(helper, atLeast(0)).getSynchronizationContext(); + inOrder.verify(helper, atLeast(0)).getScheduledExecutorService(); + inOrder.verify(helper, atLeast(0)).getMetricRecorder(); + inOrder.verify(helper, atLeast(0)).getChannelTarget(); inOrder.verifyNoMoreInteractions(); - assertThat(res.getStatus().isOk()).isTrue(); - assertThat(subchannelIsReady(res.getSubchannel())).isFalse(); - assertThat(subchannels).hasSize(1); + assertThat(res.getStatus().isOk()).isTrue(); + assertThat(subchannels).hasSize(2); // includes fallback sub-channel FakeSubchannel searchSubchannel = subchannels.getLast(); + assertThat(subchannelIsReady(searchSubchannel)).isFalse(); + searchSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); inOrder.verify(helper) .updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture()); inOrder.verifyNoMoreInteractions(); + res = picker.pickSubchannel(searchSubchannelArgs); assertThat(subchannelIsReady(res.getSubchannel())).isTrue(); - assertThat(res.getSubchannel().getAddresses()).isEqualTo(searchSubchannel.getAddresses()); - assertThat(res.getSubchannel().getAttributes()).isEqualTo(searchSubchannel.getAttributes()); + assertThat(res.getSubchannel()).isSameInstanceAs(searchSubchannel); + verifyLongCounterAdd("grpc.lb.rls.target_picks", 1, 1, "wilderness", "complete"); // rescue should be pending status although the overall channel state is READY - res = picker.pickSubchannel( - new PickSubchannelArgsImpl(fakeRescueMethod, headers, CallOptions.DEFAULT)); + res = picker.pickSubchannel(rescueSubchannelArgs); inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); // other rls picker itself is ready due to first channel. - inOrder.verify(helper) - .updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture()); - inOrder.verifyNoMoreInteractions(); assertThat(res.getStatus().isOk()).isTrue(); assertThat(subchannelIsReady(res.getSubchannel())).isFalse(); - assertThat(subchannels).hasSize(2); + assertThat(subchannels).hasSize(3); // includes fallback sub-channel FakeSubchannel rescueSubchannel = subchannels.getLast(); // search subchannel is down, rescue subchannel is connecting searchSubchannel.updateState(ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE)); - inOrder.verify(helper) .updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); @@ -248,14 +307,89 @@ public void lb_working_withDefaultTarget_rlsResponding() throws Exception { // search again, verify that it doesn't use fallback, since RLS server responded, even though // subchannel is in failure mode - res = picker.pickSubchannel( - new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); + res = picker.pickSubchannel(searchSubchannelArgs); assertThat(res.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); assertThat(subchannelIsReady(res.getSubchannel())).isFalse(); + verifyLongCounterAdd("grpc.lb.rls.target_picks", 1, 1, "wilderness", "fail"); } @Test - public void lb_working_withDefaultTarget_noRlsResponse() throws Exception { + public void fallbackWithDelay_succeeds() throws Exception { + fakeRlsServerImpl.setResponseDelay(100, TimeUnit.MILLISECONDS); + grpcCleanupRule.register( + InProcessServerBuilder.forName("fake-bigtable.googleapis.com") + .addService(ServerServiceDefinition.builder("com.google") + .addMethod(fakeSearchMethod, (call, headers) -> { + call.sendHeaders(new Metadata()); + call.sendMessage(null); + call.close(Status.OK, new Metadata()); + return new ServerCall.Listener() {}; + }) + .build()) + .addService(fakeRlsServerImpl) + .directExecutor() + .build() + .start()); + ManagedChannel channel = grpcCleanupRule.register( + InProcessChannelBuilder.forName("fake-bigtable.googleapis.com") + .defaultServiceConfig(parseJson(getServiceConfigJsonStr())) + .directExecutor() + .build()); + + StreamRecorder recorder = StreamRecorder.create(); + StreamObserver requestObserver = ClientCalls.asyncClientStreamingCall( + channel.newCall(fakeSearchMethod, CallOptions.DEFAULT), recorder); + requestObserver.onCompleted(); + fakeClock.forwardTime(100, TimeUnit.MILLISECONDS); + assertThat(recorder.awaitCompletion(10, TimeUnit.SECONDS)).isTrue(); + assertThat(recorder.getError()).isNull(); + } + + @Test + public void metricsWithRealChannel() throws Exception { + grpcCleanupRule.register( + InProcessServerBuilder.forName("fake-bigtable.googleapis.com") + .addService(ServerServiceDefinition.builder("com.google") + .addMethod(fakeSearchMethod, (call, headers) -> { + call.sendHeaders(new Metadata()); + call.sendMessage(null); + call.close(Status.OK, new Metadata()); + return new ServerCall.Listener() {}; + }) + .build()) + .addService(fakeRlsServerImpl) + .directExecutor() + .build() + .start()); + MetricSink metrics = mock(MetricSink.class, delegatesTo(new NoopMetricSink())); + ManagedChannel channel = grpcCleanupRule.register( + InternalManagedChannelBuilder.addMetricSink( + InProcessChannelBuilder.forName("fake-bigtable.googleapis.com") + .defaultServiceConfig(parseJson(getServiceConfigJsonStr())) + .directExecutor(), + metrics) + .build()); + + StreamRecorder recorder = StreamRecorder.create(); + CallOptions callOptions = CallOptions.DEFAULT + .withOption(Grpc.CALL_OPTION_CUSTOM_LABEL, "customvalue"); + StreamObserver requestObserver = ClientCalls.asyncClientStreamingCall( + channel.newCall(fakeSearchMethod, callOptions), recorder); + requestObserver.onCompleted(); + assertThat(recorder.awaitCompletion(10, TimeUnit.SECONDS)).isTrue(); + assertThat(recorder.getError()).isNull(); + + verify(metrics).addLongCounter( + eqMetricInstrumentName("grpc.lb.rls.default_target_picks"), + eq(1L), + eq(Arrays.asList("directaddress:///fake-bigtable.googleapis.com", "localhost:8972", + "defaultTarget", "complete")), + eq(Arrays.asList("customvalue"))); + } + + @Test + public void lb_working_withoutDefaultTarget_noRlsResponse() throws Exception { + defaultTarget = ""; fakeThrottler.nextResult = true; deliverResolvedAddresses(); @@ -263,52 +397,77 @@ public void lb_working_withDefaultTarget_noRlsResponse() throws Exception { inOrder.verify(helper) .updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); SubchannelPicker picker = pickerCaptor.getValue(); - Metadata headers = new Metadata(); - PickResult res; - // Search that when the RLS server doesn't respond, that fallback is used - res = picker.pickSubchannel( - new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); - FakeSubchannel fallbackSubchannel = (FakeSubchannel) res.getSubchannel(); - assertThat(fallbackSubchannel).isNotNull(); - - assertThat(res.getStatus().getCode()).isEqualTo(Status.Code.OK); - assertThat(subchannelIsReady(res.getSubchannel())).isFalse(); - inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); - fallbackSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); - inOrder.verify(helper, times(1)) - .updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture()); + // With no RLS response and no fallback, we should see a failure + PickResult res = picker.pickSubchannel(searchSubchannelArgs); // create subchannel + assertThat(res.getStatus().getCode()).isEqualTo(Code.UNAVAILABLE); + inOrder.verify(helper).getMetricRecorder(); + inOrder.verify(helper).getChannelTarget(); inOrder.verifyNoMoreInteractions(); + verifyFailedPicksCounterAdd(1, 1); + } - res = picker.pickSubchannel( - new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); - assertThat(subchannelIsReady(res.getSubchannel())).isTrue(); - assertThat(res.getSubchannel()).isSameInstanceAs(fallbackSubchannel); + @Test + public void lb_working_withDefaultTarget_noRlsResponse() throws Exception { + fakeThrottler.nextResult = true; - res = picker.pickSubchannel( - new PickSubchannelArgsImpl(fakeRescueMethod, headers, CallOptions.DEFAULT)); - assertThat(subchannelIsReady(res.getSubchannel())).isTrue(); - assertThat(res.getSubchannel()).isSameInstanceAs(fallbackSubchannel); + helper.getSynchronizationContext().execute(() -> { + try { + deliverResolvedAddresses(); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + InOrder inOrder = inOrder(helper); + inOrder.verify(helper) + .updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); + SubchannelPicker picker = pickerCaptor.getValue(); + + // Search that when the RLS server doesn't respond, that fallback is used + PickResult res = picker.pickSubchannel(searchSubchannelArgs); // create subchannel + assertThat(res.getStatus().getCode()).isEqualTo(Status.Code.OK); + FakeSubchannel fallbackSubchannel = + (FakeSubchannel) markReadyAndGetPickResult(inOrder, searchSubchannelArgs).getSubchannel(); + assertThat(fallbackSubchannel).isNotNull(); + assertThat(subchannelIsReady(fallbackSubchannel)).isTrue(); + inOrder.verify(helper).getMetricRecorder(); + inOrder.verify(helper).getChannelTarget(); + inOrder.verifyNoMoreInteractions(); + int times = 1; + verifyLongCounterAdd("grpc.lb.rls.default_target_picks", times, 1, + "defaultTarget", "complete"); + + Subchannel subchannel = picker.pickSubchannel(searchSubchannelArgs).getSubchannel(); + assertThat(subchannelIsReady(subchannel)).isTrue(); + assertThat(subchannel).isSameInstanceAs(fallbackSubchannel); + verifyLongCounterAdd("grpc.lb.rls.default_target_picks", ++times, 1, "defaultTarget", + "complete"); + + subchannel = picker.pickSubchannel(searchSubchannelArgs).getSubchannel(); + assertThat(subchannelIsReady(subchannel)).isTrue(); + assertThat(subchannel).isSameInstanceAs(fallbackSubchannel); + verifyLongCounterAdd("grpc.lb.rls.default_target_picks", ++times, 1, "defaultTarget", + "complete"); // Make sure that when RLS starts communicating that default stops being used fakeThrottler.nextResult = false; fakeClock.forwardTime(2, TimeUnit.SECONDS); // Expires backoff cache entries - // Create search subchannel - res = picker.pickSubchannel( - new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); - assertThat(res.getSubchannel()).isNotSameInstanceAs(fallbackSubchannel); - FakeSubchannel searchSubchannel = (FakeSubchannel) res.getSubchannel(); + + picker.pickSubchannel(searchSubchannelArgs);// Create search subchannel + FakeSubchannel searchSubchannel = + (FakeSubchannel) markReadyAndGetPickResult(inOrder, searchSubchannelArgs).getSubchannel(); assertThat(searchSubchannel).isNotNull(); - searchSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + assertThat(searchSubchannel).isNotSameInstanceAs(fallbackSubchannel); + verifyLongCounterAdd("grpc.lb.rls.target_picks", 1, 1, "wilderness", "complete"); // create rescue subchannel - res = picker.pickSubchannel( - new PickSubchannelArgsImpl(fakeRescueMethod, headers, CallOptions.DEFAULT)); - assertThat(res.getSubchannel()).isNotSameInstanceAs(fallbackSubchannel); - assertThat(res.getSubchannel()).isNotSameInstanceAs(searchSubchannel); - FakeSubchannel rescueSubchannel = (FakeSubchannel) res.getSubchannel(); + picker.pickSubchannel(rescueSubchannelArgs); + FakeSubchannel rescueSubchannel = + (FakeSubchannel) markReadyAndGetPickResult(inOrder, rescueSubchannelArgs).getSubchannel(); assertThat(rescueSubchannel).isNotNull(); - rescueSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + assertThat(rescueSubchannel).isNotSameInstanceAs(fallbackSubchannel); + assertThat(rescueSubchannel).isNotSameInstanceAs(searchSubchannel); + verifyLongCounterAdd("grpc.lb.rls.target_picks", 1, 1, "civilization", "complete"); // all channels are failed rescueSubchannel.updateState(ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE)); @@ -316,9 +475,10 @@ public void lb_working_withDefaultTarget_noRlsResponse() throws Exception { fallbackSubchannel.updateState(ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE)); res = picker.pickSubchannel( - new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); + searchSubchannelArgs); assertThat(res.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); assertThat(res.getSubchannel()).isNull(); + verifyLongCounterAdd("grpc.lb.rls.target_picks", 1, 1, "wilderness", "fail"); } @Test @@ -329,38 +489,40 @@ public void lb_working_withoutDefaultTarget() throws Exception { inOrder.verify(helper) .updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); SubchannelPicker picker = pickerCaptor.getValue(); - Metadata headers = new Metadata(); - PickResult res = picker.pickSubchannel( - new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); + // Warm-up pick; will be queued + PickResult res = picker.pickSubchannel(searchSubchannelArgs); + assertThat(res.getStatus().isOk()).isTrue(); + assertThat(res.getSubchannel()).isNull(); + // Cache is warm, but still unconnected + res = picker.pickSubchannel(searchSubchannelArgs); inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); - inOrder.verify(helper) + inOrder.verify(helper, atLeast(0)) .updateBalancingState(eq(ConnectivityState.CONNECTING), any(SubchannelPicker.class)); + inOrder.verify(helper, atLeast(0)).getSynchronizationContext(); + inOrder.verify(helper, atLeast(0)).getScheduledExecutorService(); + inOrder.verify(helper, atLeast(0)).getMetricRecorder(); + inOrder.verify(helper, atLeast(0)).getChannelTarget(); inOrder.verifyNoMoreInteractions(); assertThat(res.getStatus().isOk()).isTrue(); assertThat(subchannels).hasSize(1); - FakeSubchannel searchSubchannel = subchannels.getLast(); - searchSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); - inOrder.verify(helper) - .updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture()); + FakeSubchannel searchSubchannel = + (FakeSubchannel) markReadyAndGetPickResult(inOrder, searchSubchannelArgs).getSubchannel(); + inOrder.verify(helper).getMetricRecorder(); + inOrder.verify(helper).getChannelTarget(); inOrder.verifyNoMoreInteractions(); - assertThat(subchannelIsReady(res.getSubchannel())).isTrue(); - assertThat(res.getSubchannel().getAddresses()).isEqualTo(searchSubchannel.getAddresses()); - assertThat(res.getSubchannel().getAttributes()).isEqualTo(searchSubchannel.getAttributes()); + assertThat(subchannelIsReady(searchSubchannel)).isTrue(); + assertThat(subchannels.getLast()).isSameInstanceAs(searchSubchannel); // rescue should be pending status although the overall channel state is READY picker = pickerCaptor.getValue(); - res = picker.pickSubchannel( - new PickSubchannelArgsImpl(fakeRescueMethod, headers, CallOptions.DEFAULT)); + res = picker.pickSubchannel(rescueSubchannelArgs); inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); // other rls picker itself is ready due to first channel. - inOrder.verify(helper) - .updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture()); - inOrder.verifyNoMoreInteractions(); assertThat(res.getStatus().isOk()).isTrue(); - assertThat(subchannelIsReady(res.getSubchannel())).isFalse(); assertThat(subchannels).hasSize(2); FakeSubchannel rescueSubchannel = subchannels.getLast(); + assertThat(subchannelIsReady(rescueSubchannel)).isFalse(); // search subchannel is down, rescue subchannel is still connecting searchSubchannel.updateState(ConnectivityStateInfo.forTransientFailure(Status.NOT_FOUND)); @@ -373,43 +535,45 @@ public void lb_working_withoutDefaultTarget() throws Exception { // search method will fail because there is no fallback target. picker = pickerCaptor.getValue(); - res = picker.pickSubchannel( - new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); + res = picker.pickSubchannel(newPickSubchannelArgs(fakeSearchMethod)); assertThat(res.getStatus().isOk()).isFalse(); assertThat(subchannelIsReady(res.getSubchannel())).isFalse(); + verifyLongCounterAdd("grpc.lb.rls.target_picks", 1, 1, "wilderness", "complete"); - res = picker.pickSubchannel( - new PickSubchannelArgsImpl(fakeRescueMethod, headers, CallOptions.DEFAULT)); + res = picker.pickSubchannel(newPickSubchannelArgs(fakeRescueMethod)); assertThat(subchannelIsReady(res.getSubchannel())).isTrue(); assertThat(res.getSubchannel().getAddresses()).isEqualTo(rescueSubchannel.getAddresses()); assertThat(res.getSubchannel().getAttributes()).isEqualTo(rescueSubchannel.getAttributes()); + verifyLongCounterAdd("grpc.lb.rls.target_picks", 1, 1, "civilization", "complete"); // all channels are failed rescueSubchannel.updateState(ConnectivityStateInfo.forTransientFailure(Status.NOT_FOUND)); inOrder.verify(helper) .updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); + inOrder.verify(helper, atLeast(0)).refreshNameResolution(); inOrder.verifyNoMoreInteractions(); + verifyLongCounterAdd("grpc.lb.rls.target_picks", 1, 1, "wilderness", "fail"); } @Test public void lb_nameResolutionFailed() throws Exception { - deliverResolvedAddresses(); + helper.getSynchronizationContext().execute(() -> { + try { + deliverResolvedAddresses(); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); InOrder inOrder = inOrder(helper); inOrder.verify(helper) .updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); SubchannelPicker picker = pickerCaptor.getValue(); - Metadata headers = new Metadata(); - PickResult res = - picker.pickSubchannel( - new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); + PickResult res = picker.pickSubchannel(newPickSubchannelArgs(fakeSearchMethod)); assertThat(res.getStatus().isOk()).isTrue(); assertThat(subchannelIsReady(res.getSubchannel())).isFalse(); inOrder.verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); - inOrder.verify(helper) - .updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); - assertThat(subchannels).hasSize(1); - inOrder.verifyNoMoreInteractions(); + assertThat(subchannels).hasSize(2); // includes fallback sub-channel FakeSubchannel searchSubchannel = subchannels.getLast(); searchSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); @@ -418,13 +582,15 @@ public void lb_nameResolutionFailed() throws Exception { SubchannelPicker picker2 = pickerCaptor.getValue(); assertThat(picker2).isEqualTo(picker); - res = picker2.pickSubchannel( - new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); + res = picker2.pickSubchannel(newPickSubchannelArgs(fakeSearchMethod)); // verify success. Subchannel is wrapped, so checking attributes. assertThat(subchannelIsReady(res.getSubchannel())).isTrue(); assertThat(res.getSubchannel().getAddresses()).isEqualTo(searchSubchannel.getAddresses()); assertThat(res.getSubchannel().getAttributes()).isEqualTo(searchSubchannel.getAttributes()); + verifyLongCounterAdd("grpc.lb.rls.target_picks", 1, 1, "wilderness", "complete"); + inOrder.verify(helper).getMetricRecorder(); + inOrder.verify(helper).getChannelTarget(); inOrder.verifyNoMoreInteractions(); rlsLb.handleNameResolutionError(Status.UNAVAILABLE); @@ -432,15 +598,24 @@ public void lb_nameResolutionFailed() throws Exception { verify(helper) .updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); SubchannelPicker failedPicker = pickerCaptor.getValue(); - res = failedPicker.pickSubchannel( - new PickSubchannelArgsImpl(fakeSearchMethod, headers, CallOptions.DEFAULT)); + res = failedPicker.pickSubchannel(newPickSubchannelArgs(fakeSearchMethod)); assertThat(res.getStatus().isOk()).isFalse(); assertThat(subchannelIsReady(res.getSubchannel())).isFalse(); } + private PickResult markReadyAndGetPickResult(InOrder inOrder, + PickSubchannelArgs pickSubchannelArgs) { + subchannels.getLast().updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + inOrder.verify(helper, atLeast(1)) + .updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture()); + PickResult pickResult = pickerCaptor.getValue().pickSubchannel(pickSubchannelArgs); + inOrder.verify(helper, atLeast(0)).getChannelLogger(); + return pickResult; + } + private void deliverResolvedAddresses() throws Exception { ConfigOrError parsedConfigOrError = - provider.parseLoadBalancingPolicyConfig(getServiceConfig()); + provider.parseLoadBalancingPolicyConfig(parseJson(getLbConfigJsonStr())); assertThat(parsedConfigOrError.getConfig()).isNotNull(); rlsLb.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of(new EquivalentAddressGroup(mock(SocketAddress.class)))) @@ -450,13 +625,24 @@ private void deliverResolvedAddresses() throws Exception { } @SuppressWarnings("unchecked") - private Map getServiceConfig() throws IOException { - String serviceConfig = "{" + private Map parseJson(String json) throws IOException { + return (Map) JsonParser.parse(json); + } + + private String getServiceConfigJsonStr() { + return "{" + + " \"loadBalancingConfig\": [{" + + " \"rls_experimental\": " + getLbConfigJsonStr() + + " }]" + + "}"; + } + + private String getLbConfigJsonStr() { + return "{" + " \"routeLookupConfig\": " + getRlsConfigJsonStr() + ", " + " \"childPolicy\": [{\"pick_first\": {}}]," + " \"childPolicyConfigTargetFieldName\": \"serviceName\"" + "}"; - return (Map) JsonParser.parse(serviceConfig); } private String getRlsConfigJsonStr() { @@ -494,6 +680,41 @@ private String getRlsConfigJsonStr() { + "}"; } + // Verifies that the MetricRecorder has been called to record a long counter value of 1 for the + // given metric name, the given number of times + private void verifyLongCounterAdd(String name, int times, long value, + String dataPlaneTargetLabel, String pickResult) { + // TODO: support the "grpc.target" label once available. + verify(mockMetricRecorder, times(times)).addLongCounter( + eqMetricInstrumentName(name), eq(value), + eq(Lists.newArrayList(channelTarget, "localhost:8972", dataPlaneTargetLabel, pickResult)), + eq(Lists.newArrayList(""))); + } + + // This one is for verifying the failed_pick metric specifically. + private void verifyFailedPicksCounterAdd(int times, long value) { + // TODO: support the "grpc.target" label once available. + verify(mockMetricRecorder, times(times)).addLongCounter( + eqMetricInstrumentName("grpc.lb.rls.failed_picks"), eq(value), + eq(Lists.newArrayList(channelTarget, "localhost:8972")), + eq(Lists.newArrayList(""))); + } + + @SuppressWarnings("TypeParameterUnusedInFormals") + private T eqMetricInstrumentName(String name) { + return argThat(new ArgumentMatcher() { + @Override + public boolean matches(T instrument) { + return instrument.getName().equals(name); + } + }); + } + + private PickSubchannelArgs newPickSubchannelArgs(MethodDescriptor method) { + return new PickSubchannelArgsImpl( + method, new Metadata(), CallOptions.DEFAULT, new PickDetailsConsumer() {}); + } + private final class FakeHelper extends Helper { @Override @@ -582,6 +803,16 @@ public SynchronizationContext getSynchronizationContext() { public ChannelLogger getChannelLogger() { return mock(ChannelLogger.class); } + + @Override + public MetricRecorder getMetricRecorder() { + return mockMetricRecorder; + } + + @Override + public String getChannelTarget() { + return channelTarget; + } } private static final class FakeRlsServerImpl @@ -592,17 +823,41 @@ private static final class FakeRlsServerImpl private static final Converter RESPONSE_CONVERTER = new RouteLookupResponseConverter().reverse(); + private final ScheduledExecutorService scheduler; + private long delay; + private TimeUnit delayUnit; + + public FakeRlsServerImpl(ScheduledExecutorService scheduler) { + this.scheduler = scheduler; + } + private Map lookupTable = ImmutableMap.of(); private void setLookupTable(Map lookupTable) { this.lookupTable = checkNotNull(lookupTable, "lookupTable"); } + void setResponseDelay(long delay, TimeUnit unit) { + this.delay = delay; + this.delayUnit = unit; + } + @Override + @SuppressWarnings("FutureReturnValueIgnored") public void routeLookup(io.grpc.lookup.v1.RouteLookupRequest request, StreamObserver responseObserver) { RouteLookupResponse response = lookupTable.get(REQUEST_CONVERTER.convert(request)); + Runnable sendResponse = () -> sendResponse(response, responseObserver); + if (delay != 0) { + scheduler.schedule(sendResponse, delay, delayUnit); + } else { + sendResponse.run(); + } + } + + private void sendResponse(RouteLookupResponse response, + StreamObserver responseObserver) { if (response == null) { responseObserver.onError(new RuntimeException("not found")); } else { diff --git a/rls/src/test/java/io/grpc/rls/RlsProtoConvertersTest.java b/rls/src/test/java/io/grpc/rls/RlsProtoConvertersTest.java index 98b7101fd5e..82ad606c50d 100644 --- a/rls/src/test/java/io/grpc/rls/RlsProtoConvertersTest.java +++ b/rls/src/test/java/io/grpc/rls/RlsProtoConvertersTest.java @@ -61,12 +61,14 @@ public void convert_toRequestObject() { Converter converter = new RouteLookupRequestConverter().reverse(); RlsProtoData.RouteLookupRequest requestObject = - RlsProtoData.RouteLookupRequest.create(ImmutableMap.of("key1", "val1")); + RlsProtoData.RouteLookupRequest.create(ImmutableMap.of("key1", "val1"), + RlsProtoData.RouteLookupRequest.Reason.REASON_MISS); RouteLookupRequest proto = converter.convert(requestObject); assertThat(proto.getTargetType()).isEqualTo("grpc"); assertThat(proto.getKeyMapMap()).containsExactly("key1", "val1"); + assertThat(proto.getReason()).isEqualTo(RouteLookupRequest.Reason.REASON_MISS); } @Test @@ -469,6 +471,124 @@ public void convert_jsonRlsConfig_staleAgeGivenWithoutMaxAge() throws IOExceptio } } + @Test + public void convert_jsonRlsConfig_doNotClampMaxAgeIfStaleAgeIsSet() throws IOException { + String jsonStr = "{\n" + + " \"grpcKeybuilders\": [\n" + + " {\n" + + " \"names\": [\n" + + " {\n" + + " \"service\": \"service1\",\n" + + " \"method\": \"create\"\n" + + " }\n" + + " ],\n" + + " \"headers\": [\n" + + " {\n" + + " \"key\": \"user\"," + + " \"names\": [\"User\", \"Parent\"],\n" + + " \"optional\": true\n" + + " },\n" + + " {\n" + + " \"key\": \"id\"," + + " \"names\": [\"X-Google-Id\"],\n" + + " \"optional\": true\n" + + " }\n" + + " ]\n" + + " }\n" + + " ],\n" + + " \"lookupService\": \"service1\",\n" + + " \"lookupServiceTimeout\": \"2s\",\n" + + " \"maxAge\": \"350s\",\n" + + " \"staleAge\": \"310s\",\n" + + " \"validTargets\": [\"a valid target\"]," + + " \"cacheSizeBytes\": \"1000\",\n" + + " \"defaultTarget\": \"us_east_1.cloudbigtable.googleapis.com\"\n" + + "}"; + + RouteLookupConfig expectedConfig = + RouteLookupConfig.builder() + .grpcKeybuilders(ImmutableList.of( + GrpcKeyBuilder.create( + ImmutableList.of(Name.create("service1", "create")), + ImmutableList.of( + NameMatcher.create("user", ImmutableList.of("User", "Parent")), + NameMatcher.create("id", ImmutableList.of("X-Google-Id"))), + ExtraKeys.DEFAULT, + ImmutableMap.of()))) + .lookupService("service1") + .lookupServiceTimeoutInNanos(TimeUnit.SECONDS.toNanos(2)) + .maxAgeInNanos(TimeUnit.SECONDS.toNanos(350)) // Should not be clamped + .staleAgeInNanos(TimeUnit.SECONDS.toNanos(300)) // Should be clamped to max 300s + .cacheSizeBytes(1000) + .defaultTarget("us_east_1.cloudbigtable.googleapis.com") + .build(); + + RouteLookupConfigConverter converter = new RouteLookupConfigConverter(); + @SuppressWarnings("unchecked") + Map parsedJson = (Map) JsonParser.parse(jsonStr); + RouteLookupConfig converted = converter.convert(parsedJson); + assertThat(converted).isEqualTo(expectedConfig); + } + + @Test + public void convert_jsonRlsConfig_clampMaxAgeIfStaleAgeMissing() throws IOException { + String jsonStr = "{\n" + + " \"grpcKeybuilders\": [\n" + + " {\n" + + " \"names\": [\n" + + " {\n" + + " \"service\": \"service1\",\n" + + " \"method\": \"create\"\n" + + " }\n" + + " ],\n" + + " \"headers\": [\n" + + " {\n" + + " \"key\": \"user\"," + + " \"names\": [\"User\", \"Parent\"],\n" + + " \"optional\": true\n" + + " },\n" + + " {\n" + + " \"key\": \"id\"," + + " \"names\": [\"X-Google-Id\"],\n" + + " \"optional\": true\n" + + " }\n" + + " ]\n" + + " }\n" + + " ],\n" + + " \"lookupService\": \"service1\",\n" + + " \"lookupServiceTimeout\": \"2s\",\n" + + " \"maxAge\": \"350s\",\n" // Exceeds 5m limit + + " \"validTargets\": [\"a valid target\"]," + + " \"cacheSizeBytes\": \"1000\",\n" + + " \"defaultTarget\": \"us_east_1.cloudbigtable.googleapis.com\"\n" + + "}"; + + RouteLookupConfig expectedConfig = + RouteLookupConfig.builder() + .grpcKeybuilders(ImmutableList.of( + GrpcKeyBuilder.create( + ImmutableList.of(Name.create("service1", "create")), + ImmutableList.of( + NameMatcher.create("user", ImmutableList.of("User", "Parent")), + NameMatcher.create("id", ImmutableList.of("X-Google-Id"))), + ExtraKeys.DEFAULT, + ImmutableMap.of()))) + .lookupService("service1") + .lookupServiceTimeoutInNanos(TimeUnit.SECONDS.toNanos(2)) + // Should be clamped to 300s (5m) because staleAge is missing + .maxAgeInNanos(TimeUnit.MINUTES.toNanos(5)) + .staleAgeInNanos(TimeUnit.MINUTES.toNanos(5)) + .cacheSizeBytes(1000) + .defaultTarget("us_east_1.cloudbigtable.googleapis.com") + .build(); + + RouteLookupConfigConverter converter = new RouteLookupConfigConverter(); + @SuppressWarnings("unchecked") + Map parsedJson = (Map) JsonParser.parse(jsonStr); + RouteLookupConfig converted = converter.convert(parsedJson); + assertThat(converted).isEqualTo(expectedConfig); + } + @Test public void convert_jsonRlsConfig_keyBuilderWithoutName() throws IOException { String jsonStr = "{\n" diff --git a/rls/src/test/java/io/grpc/rls/RlsRequestFactoryTest.java b/rls/src/test/java/io/grpc/rls/RlsRequestFactoryTest.java index 6ee2c01af8a..2b900994ed9 100644 --- a/rls/src/test/java/io/grpc/rls/RlsRequestFactoryTest.java +++ b/rls/src/test/java/io/grpc/rls/RlsRequestFactoryTest.java @@ -26,7 +26,6 @@ import io.grpc.rls.RlsProtoData.GrpcKeyBuilder.Name; import io.grpc.rls.RlsProtoData.NameMatcher; import io.grpc.rls.RlsProtoData.RouteLookupConfig; -import io.grpc.rls.RlsProtoData.RouteLookupRequest; import java.util.concurrent.TimeUnit; import org.junit.Test; import org.junit.runner.RunWith; @@ -82,8 +81,9 @@ public void create_pathMatches() { metadata.put(Metadata.Key.of("X-Google-Id", Metadata.ASCII_STRING_MARSHALLER), "123"); metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar"); - RouteLookupRequest request = factory.create("com.google.service1", "Create", metadata); - assertThat(request.keyMap()).containsExactly( + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey = + factory.create("com.google.service1", "Create", metadata); + assertThat(routeLookupRequestKey.keyMap()).containsExactly( "user", "test", "id", "123", "server-1", "bigtable.googleapis.com", @@ -97,9 +97,10 @@ public void create_pathFallbackMatches() { metadata.put(Metadata.Key.of("Password", Metadata.ASCII_STRING_MARSHALLER), "hunter2"); metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar"); - RouteLookupRequest request = factory.create("com.google.service1" , "Update", metadata); + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey = + factory.create("com.google.service1" , "Update", metadata); - assertThat(request.keyMap()).containsExactly( + assertThat(routeLookupRequestKey.keyMap()).containsExactly( "user", "test", "password", "hunter2", "service-2", "com.google.service1", @@ -113,9 +114,10 @@ public void create_pathFallbackMatches_optionalHeaderMissing() { metadata.put(Metadata.Key.of("X-Google-Id", Metadata.ASCII_STRING_MARSHALLER), "123"); metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar"); - RouteLookupRequest request = factory.create("com.google.service1", "Update", metadata); + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey = + factory.create("com.google.service1", "Update", metadata); - assertThat(request.keyMap()).containsExactly( + assertThat(routeLookupRequestKey.keyMap()).containsExactly( "user", "test", "service-2", "com.google.service1", "const-key-2", "const-value-2"); @@ -128,8 +130,9 @@ public void create_unknownPath() { metadata.put(Metadata.Key.of("X-Google-Id", Metadata.ASCII_STRING_MARSHALLER), "123"); metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar"); - RouteLookupRequest request = factory.create("abc.def.service999", "Update", metadata); - assertThat(request.keyMap()).isEmpty(); + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey = + factory.create("abc.def.service999", "Update", metadata); + assertThat(routeLookupRequestKey.keyMap()).isEmpty(); } @Test @@ -139,9 +142,10 @@ public void create_noMethodInRlsConfig() { metadata.put(Metadata.Key.of("X-Google-Id", Metadata.ASCII_STRING_MARSHALLER), "123"); metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar"); - RouteLookupRequest request = factory.create("com.google.service3", "Update", metadata); + RlsProtoData.RouteLookupRequestKey routeLookupRequestKey = + factory.create("com.google.service3", "Update", metadata); - assertThat(request.keyMap()).containsExactly( + assertThat(routeLookupRequestKey.keyMap()).containsExactly( "user", "test", "const-key-4", "const-value-4"); } } diff --git a/s2a/BUILD.bazel b/s2a/BUILD.bazel new file mode 100644 index 00000000000..34387206ba5 --- /dev/null +++ b/s2a/BUILD.bazel @@ -0,0 +1,93 @@ +load("@rules_java//java:defs.bzl", "java_library") +load("@rules_jvm_external//:defs.bzl", "artifact") + +java_library( + name = "s2a_channel_pool", + srcs = glob([ + "src/main/java/io/grpc/s2a/internal/channel/*.java", + ]), + deps = [ + "//api", + "//core", + "//core:internal", + "//netty", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:guava"), + artifact("org.checkerframework:checker-qual"), + artifact("io.netty:netty-common"), + artifact("io.netty:netty-transport"), + ], +) + +java_library( + name = "s2a_identity", + srcs = ["src/main/java/io/grpc/s2a/internal/handshaker/S2AIdentity.java"], + deps = [ + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:guava"), + artifact("com.google.s2a.proto.v2:s2a-proto"), + ], +) + +java_library( + name = "token_manager", + srcs = glob([ + "src/main/java/io/grpc/s2a/internal/handshaker/tokenmanager/*.java", + ]), + deps = [ + ":s2a_identity", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.guava:guava"), + ], +) + +java_library( + name = "s2a_handshaker", + srcs = [ + "src/main/java/io/grpc/s2a/internal/handshaker/ConnectionClosedException.java", + "src/main/java/io/grpc/s2a/internal/handshaker/GetAuthenticationMechanisms.java", + "src/main/java/io/grpc/s2a/internal/handshaker/ProtoUtil.java", + "src/main/java/io/grpc/s2a/internal/handshaker/S2AConnectionException.java", + "src/main/java/io/grpc/s2a/internal/handshaker/S2APrivateKeyMethod.java", + "src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java", + "src/main/java/io/grpc/s2a/internal/handshaker/S2AStub.java", + "src/main/java/io/grpc/s2a/internal/handshaker/S2ATrustManager.java", + "src/main/java/io/grpc/s2a/internal/handshaker/SslContextFactory.java", + ], + deps = [ + ":s2a_identity", + ":token_manager", + "//api", + "//core:internal", + "//netty", + "//stub", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:guava"), + artifact("com.google.s2a.proto.v2:s2a-proto"), + artifact("org.checkerframework:checker-qual"), + "@com_google_protobuf//:protobuf_java", + artifact("io.netty:netty-common"), + artifact("io.netty:netty-handler"), + artifact("io.netty:netty-transport"), + ], +) + +java_library( + name = "s2a", + srcs = ["src/main/java/io/grpc/s2a/S2AChannelCredentials.java"], + visibility = ["//visibility:public"], + deps = [ + ":s2a_channel_pool", + ":s2a_handshaker", + ":s2a_identity", + "//api", + "//core:internal", + "//netty", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:guava"), + artifact("org.checkerframework:checker-qual"), + ], +) diff --git a/s2a/build.gradle b/s2a/build.gradle new file mode 100644 index 00000000000..c46993ec9c8 --- /dev/null +++ b/s2a/build.gradle @@ -0,0 +1,120 @@ +plugins { + id "java-library" + id "maven-publish" + + id "com.google.osdetector" + id "com.google.protobuf" + id "com.gradleup.shadow" + id "ru.vyarus.animalsniffer" +} + +description = "gRPC: S2A" + +dependencies { + implementation libraries.s2a.proto + implementation 'org.checkerframework:checker-qual:3.49.5' + + api project(':grpc-api') + implementation project(':grpc-stub'), + project(':grpc-protobuf'), + project(':grpc-core'), + libraries.protobuf.java, + libraries.guava.jre // JRE required by protobuf-java-util from grpclb + def nettyDependency = implementation project(':grpc-netty') + + shadow configurations.implementation.getDependencies().minus(nettyDependency) + shadow project(path: ':grpc-netty-shaded', configuration: 'shadow') + + testImplementation project(':grpc-benchmarks'), + project(':grpc-testing'), + project(':grpc-testing-proto'), + testFixtures(project(':grpc-core')), + libraries.guava + + testImplementation 'com.google.truth:truth:1.4.2' + testImplementation 'com.google.truth.extensions:truth-proto-extension:1.4.2' + testImplementation libraries.guava.testlib + + testRuntimeOnly libraries.netty.tcnative, + libraries.netty.tcnative.classes + testRuntimeOnly (libraries.netty.tcnative) { + artifact { + classifier = "linux-x86_64" + } + } + testRuntimeOnly (libraries.netty.tcnative) { + artifact { + classifier = "linux-aarch_64" + } + } + testRuntimeOnly (libraries.netty.tcnative) { + artifact { + classifier = "osx-x86_64" + } + } + testRuntimeOnly (libraries.netty.tcnative) { + artifact { + classifier = "osx-aarch_64" + } + } + testRuntimeOnly (libraries.netty.tcnative) { + artifact { + classifier = "windows-x86_64" + } + } + + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } +} + +configureProtoCompilation() + +tasks.named("javadoc").configure { + exclude 'io/grpc/s2a/**' +} + +tasks.named("jar").configure { + // Must use a different archiveClassifier to avoid conflicting with shadowJar + archiveClassifier = 'original' + manifest { + attributes('Automatic-Module-Name': 'io.grpc.s2a') + } +} + +// We want to use grpc-netty-shaded instead of grpc-netty. But we also want our +// source to work with Bazel, so we rewrite the code as part of the build. +tasks.named("shadowJar").configure { + archiveClassifier = null + dependencies { + exclude(dependency {true}) + } + relocate 'io.grpc.netty', 'io.grpc.netty.shaded.io.grpc.netty' + relocate 'io.netty', 'io.grpc.netty.shaded.io.netty' +} + +plugins.withId('maven-publish') { +publishing { + publications { + maven(MavenPublication) { + // We want this to throw an exception if it isn't working + def originalJar = artifacts.find { dep -> dep.classifier == 'original'} + artifacts.remove(originalJar) + + pom.withXml { + def dependenciesNode = new Node(null, 'dependencies') + project.configurations.shadow.allDependencies.each { dep -> + def dependencyNode = dependenciesNode.appendNode('dependency') + dependencyNode.appendNode('groupId', dep.group) + dependencyNode.appendNode('artifactId', dep.name) + dependencyNode.appendNode('version', dep.version) + dependencyNode.appendNode('scope', 'compile') + } + asNode().dependencies[0].replaceNode(dependenciesNode) + } + } + } +} +} diff --git a/s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java b/s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java new file mode 100644 index 00000000000..4be32475205 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java @@ -0,0 +1,135 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Strings.isNullOrEmpty; + +import com.google.common.annotations.VisibleForTesting; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import io.grpc.Channel; +import io.grpc.ChannelCredentials; +import io.grpc.ExperimentalApi; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.SharedResourcePool; +import io.grpc.netty.InternalNettyChannelCredentials; +import io.grpc.netty.InternalProtocolNegotiator; +import io.grpc.s2a.internal.channel.S2AHandshakerServiceChannel; +import io.grpc.s2a.internal.handshaker.S2AIdentity; +import io.grpc.s2a.internal.handshaker.S2AProtocolNegotiatorFactory; +import io.grpc.s2a.internal.handshaker.S2AStub; +import javax.annotation.concurrent.NotThreadSafe; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * Configures gRPC to use S2A for transport security when establishing a secure channel. Only for + * use on the client side of a gRPC connection. + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/11533") +public final class S2AChannelCredentials { + /** + * Creates a channel credentials builder for establishing an S2A-secured connection. + * + * @param s2aAddress the address of the S2A server used to secure the connection. + * @param s2aChannelCredentials the credentials to be used when connecting to the S2A. + * @return a {@code S2AChannelCredentials.Builder} instance. + */ + public static Builder newBuilder(String s2aAddress, ChannelCredentials s2aChannelCredentials) { + checkArgument(!isNullOrEmpty(s2aAddress), "S2A address must not be null or empty."); + checkNotNull(s2aChannelCredentials, "S2A channel credentials must not be null"); + return new Builder(s2aAddress, s2aChannelCredentials); + } + + /** Builds an {@code S2AChannelCredentials} instance. */ + @NotThreadSafe + public static final class Builder { + private final String s2aAddress; + private final ChannelCredentials s2aChannelCredentials; + private @Nullable S2AIdentity localIdentity = null; + private @Nullable S2AStub stub = null; + + Builder(String s2aAddress, ChannelCredentials s2aChannelCredentials) { + this.s2aAddress = s2aAddress; + this.s2aChannelCredentials = s2aChannelCredentials; + } + + /** + * Sets the local identity of the client in the form of a SPIFFE ID. The client may set at most + * 1 local identity. If no local identity is specified, then the S2A chooses a default local + * identity, if one exists. + */ + @CanIgnoreReturnValue + public Builder setLocalSpiffeId(String localSpiffeId) { + checkNotNull(localSpiffeId); + checkArgument(localIdentity == null, "localIdentity is already set."); + localIdentity = S2AIdentity.fromSpiffeId(localSpiffeId); + return this; + } + + /** + * Sets the local identity of the client in the form of a hostname. The client may set at most 1 + * local identity. If no local identity is specified, then the S2A chooses a default local + * identity, if one exists. + */ + @CanIgnoreReturnValue + public Builder setLocalHostname(String localHostname) { + checkNotNull(localHostname); + checkArgument(localIdentity == null, "localIdentity is already set."); + localIdentity = S2AIdentity.fromHostname(localHostname); + return this; + } + + /** + * Sets the local identity of the client in the form of a UID. The client may set at most 1 + * local identity. If no local identity is specified, then the S2A chooses a default local + * identity, if one exists. + */ + @CanIgnoreReturnValue + public Builder setLocalUid(String localUid) { + checkNotNull(localUid); + checkArgument(localIdentity == null, "localIdentity is already set."); + localIdentity = S2AIdentity.fromUid(localUid); + return this; + } + + /** + * Sets the stub to use to communicate with S2A. This is only used for testing that the + * stream to S2A gets closed. + */ + @VisibleForTesting + Builder setStub(S2AStub stub) { + checkNotNull(stub); + this.stub = stub; + return this; + } + + public ChannelCredentials build() { + return InternalNettyChannelCredentials.create(buildProtocolNegotiatorFactory()); + } + + InternalProtocolNegotiator.ClientFactory buildProtocolNegotiatorFactory() { + ObjectPool s2aChannelPool = + SharedResourcePool.forResource( + S2AHandshakerServiceChannel.getChannelResource(s2aAddress, s2aChannelCredentials)); + checkNotNull(s2aChannelPool, "s2aChannelPool"); + return S2AProtocolNegotiatorFactory.createClientFactory(localIdentity, s2aChannelPool, stub); + } + } + + private S2AChannelCredentials() {} +} diff --git a/s2a/src/main/java/io/grpc/s2a/internal/channel/S2AHandshakerServiceChannel.java b/s2a/src/main/java/io/grpc/s2a/internal/channel/S2AHandshakerServiceChannel.java new file mode 100644 index 00000000000..8453268efc0 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/internal/channel/S2AHandshakerServiceChannel.java @@ -0,0 +1,107 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.channel; + +import static com.google.common.base.Preconditions.checkNotNull; +import static java.util.concurrent.TimeUnit.SECONDS; + +import io.grpc.Channel; +import io.grpc.ChannelCredentials; +import io.grpc.ManagedChannel; +import io.grpc.internal.SharedResourceHolder.Resource; +import io.grpc.netty.NettyChannelBuilder; +import javax.annotation.concurrent.ThreadSafe; + +/** + * Provides APIs for managing gRPC channels to an S2A server. Each channel is local and plaintext. + * If credentials are provided, they are used to secure the channel. + * + *

This is done as follows: for an S2A server, provides an implementation of gRPC's {@link + * SharedResourceHolder.Resource} interface called a {@code Resource}. A {@code + * Resource} is a factory for creating gRPC channels to the S2A server at a given address, + * and a channel must be returned to the {@code Resource} when it is no longer needed. + * + *

Typical usage pattern is below: + * + *

{@code
+ * Resource resource = S2AHandshakerServiceChannel.getChannelResource("localhost:1234",
+ * creds);
+ * Channel channel = resource.create();
+ * // Send an RPC over the channel to the S2A server running at localhost:1234.
+ * resource.close(channel);
+ * }
+ */ +@ThreadSafe +public final class S2AHandshakerServiceChannel { + + /** + * Returns a {@link SharedResourceHolder.Resource} instance for managing channels to an S2A server + * running at {@code s2aAddress}. + * + * @param s2aAddress the address of the S2A, typically in the format {@code host:port}. + * @param s2aChannelCredentials the credentials to use when establishing a connection to the S2A. + * @return a {@link ChannelResource} instance that manages a {@link Channel} to the S2A server + * running at {@code s2aAddress}. + */ + public static Resource getChannelResource( + String s2aAddress, ChannelCredentials s2aChannelCredentials) { + checkNotNull(s2aAddress); + return new ChannelResource(s2aAddress, s2aChannelCredentials); + } + + /** + * Defines how to create and destroy a {@link Channel} instance that uses shared resources. A + * channel created by {@code ChannelResource} is a plaintext, local channel to the service running + * at {@code targetAddress}. + */ + private static class ChannelResource implements Resource { + private final String targetAddress; + private final ChannelCredentials channelCredentials; + + public ChannelResource(String targetAddress, ChannelCredentials channelCredentials) { + this.targetAddress = targetAddress; + this.channelCredentials = channelCredentials; + } + + /** + * Creates a {@code ManagedChannel} instance to the service running at {@code + * targetAddress}. + */ + @Override + public Channel create() { + return NettyChannelBuilder.forTarget(targetAddress, channelCredentials) + .directExecutor() + .idleTimeout(5, SECONDS) + .build(); + } + + /** Destroys a {@code ManagedChannel} instance. */ + @Override + public void close(Channel instanceChannel) { + checkNotNull(instanceChannel); + ManagedChannel channel = (ManagedChannel) instanceChannel; + channel.shutdownNow(); + } + + @Override + public String toString() { + return "grpc-s2a-channel"; + } + } + + private S2AHandshakerServiceChannel() {} +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/ConnectionClosedException.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/ConnectionClosedException.java new file mode 100644 index 00000000000..d6f1aa70f7c --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/ConnectionClosedException.java @@ -0,0 +1,27 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker; + +import java.io.IOException; + +/** Indicates that a connection has been closed. */ +@SuppressWarnings("serial") // This class is never serialized. +final class ConnectionClosedException extends IOException { + public ConnectionClosedException(String errorMessage) { + super(errorMessage); + } +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/GetAuthenticationMechanisms.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/GetAuthenticationMechanisms.java new file mode 100644 index 00000000000..cf632418e66 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/GetAuthenticationMechanisms.java @@ -0,0 +1,59 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker; + +import com.google.errorprone.annotations.Immutable; +import com.google.s2a.proto.v2.AuthenticationMechanism; +import io.grpc.s2a.internal.handshaker.S2AIdentity; +import io.grpc.s2a.internal.handshaker.tokenmanager.AccessTokenManager; +import java.util.Optional; + +/** Retrieves the authentication mechanism for a given local identity. */ +@Immutable +final class GetAuthenticationMechanisms { + static final Optional TOKEN_MANAGER = AccessTokenManager.create(); + + /** + * Retrieves the authentication mechanism for a given local identity. + * + * @param localIdentity the identity for which to fetch a token. + * @param tokenManager the token manager to use for fetching tokens. + * @return an {@link AuthenticationMechanism} for the given local identity. + */ + static Optional getAuthMechanism(Optional localIdentity, + Optional tokenManager) { + if (!tokenManager.isPresent()) { + return Optional.empty(); + } + AccessTokenManager manager = tokenManager.get(); + // If no identity is provided, fetch the default access token and DO NOT attach an identity + // to the request. + if (!localIdentity.isPresent()) { + return Optional.of( + AuthenticationMechanism.newBuilder().setToken(manager.getDefaultToken()).build()); + } else { + // Fetch an access token for the provided identity. + return Optional.of( + AuthenticationMechanism.newBuilder() + .setIdentity(localIdentity.get().getIdentity()) + .setToken(manager.getToken(localIdentity.get())) + .build()); + } + } + + private GetAuthenticationMechanisms() {} +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/ProtoUtil.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/ProtoUtil.java new file mode 100644 index 00000000000..0526ec154f9 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/ProtoUtil.java @@ -0,0 +1,78 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableSet; +import com.google.s2a.proto.v2.TLSVersion; + +/** Converts proto messages to Netty strings. */ +final class ProtoUtil { + + /** + * Converts a {@link TLSVersion} object to its {@link String} representation. + * + * @param tlsVersion the {@link TLSVersion} object to be converted. + * @return a {@link String} representation of the TLS version. + * @throws IllegalArgumentException if the {@code tlsVersion} is not one of + * the supported TLS versions. + */ + @VisibleForTesting + static String convertTlsProtocolVersion(TLSVersion tlsVersion) { + switch (tlsVersion) { + case TLS_VERSION_1_3: + return "TLSv1.3"; + case TLS_VERSION_1_2: + return "TLSv1.2"; + case TLS_VERSION_1_1: + return "TLSv1.1"; + case TLS_VERSION_1_0: + return "TLSv1"; + default: + throw new IllegalArgumentException( + String.format("TLS version %d is not supported.", tlsVersion.getNumber())); + } + } + + /** + * Builds a set of strings representing all {@link TLSVersion}s between {@code minTlsVersion} and + * {@code maxTlsVersion}. + */ + static ImmutableSet buildTlsProtocolVersionSet( + TLSVersion minTlsVersion, TLSVersion maxTlsVersion) { + ImmutableSet.Builder tlsVersions = ImmutableSet.builder(); + for (TLSVersion tlsVersion : TLSVersion.values()) { + int versionNumber; + try { + versionNumber = tlsVersion.getNumber(); + } catch (IllegalArgumentException e) { + continue; + } + if (versionNumber >= minTlsVersion.getNumber() + && versionNumber <= maxTlsVersion.getNumber()) { + try { + tlsVersions.add(convertTlsProtocolVersion(tlsVersion)); + } catch (IllegalArgumentException e) { + continue; + } + } + } + return tlsVersions.build(); + } + + private ProtoUtil() {} +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AConnectionException.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AConnectionException.java new file mode 100644 index 00000000000..9b6c244751b --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AConnectionException.java @@ -0,0 +1,25 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker; + +/** Exception that denotes a runtime error that was encountered when talking to the S2A server. */ +@SuppressWarnings("serial") // This class is never serialized. +public class S2AConnectionException extends RuntimeException { + S2AConnectionException(String message) { + super(message); + } +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AIdentity.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AIdentity.java new file mode 100644 index 00000000000..f4d6b88ce45 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AIdentity.java @@ -0,0 +1,63 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.errorprone.annotations.ThreadSafe; +import com.google.s2a.proto.v2.Identity; + +/** + * Stores an identity in such a way that it can be sent to the S2A handshaker service. The identity + * may be formatted as a SPIFFE ID or as a hostname. + */ +@ThreadSafe +public final class S2AIdentity { + private final Identity identity; + + /** Returns an {@link S2AIdentity} instance with SPIFFE ID set to {@code spiffeId}. */ + public static S2AIdentity fromSpiffeId(String spiffeId) { + checkNotNull(spiffeId); + return new S2AIdentity(Identity.newBuilder().setSpiffeId(spiffeId).build()); + } + + /** Returns an {@link S2AIdentity} instance with hostname set to {@code hostname}. */ + public static S2AIdentity fromHostname(String hostname) { + checkNotNull(hostname); + return new S2AIdentity(Identity.newBuilder().setHostname(hostname).build()); + } + + /** Returns an {@link S2AIdentity} instance with UID set to {@code uid}. */ + public static S2AIdentity fromUid(String uid) { + checkNotNull(uid); + return new S2AIdentity(Identity.newBuilder().setUid(uid).build()); + } + + /** Returns an {@link S2AIdentity} instance with {@code identity} set. */ + public static S2AIdentity fromIdentity(Identity identity) { + return new S2AIdentity(identity == null ? Identity.getDefaultInstance() : identity); + } + + private S2AIdentity(Identity identity) { + this.identity = identity; + } + + /** Returns the proto {@link Identity} representation of this identity instance. */ + public Identity getIdentity() { + return identity; + } +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2APrivateKeyMethod.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2APrivateKeyMethod.java new file mode 100644 index 00000000000..1a5c37eb989 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2APrivateKeyMethod.java @@ -0,0 +1,147 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.ByteString; +import com.google.s2a.proto.v2.OffloadPrivateKeyOperationReq; +import com.google.s2a.proto.v2.SessionReq; +import com.google.s2a.proto.v2.SessionResp; +import com.google.s2a.proto.v2.SignatureAlgorithm; +import io.grpc.s2a.internal.handshaker.S2AIdentity; +import io.netty.handler.ssl.OpenSslPrivateKeyMethod; +import java.io.IOException; +import java.util.Optional; +import javax.annotation.concurrent.NotThreadSafe; +import javax.net.ssl.SSLEngine; + +/** + * Handles requests on signing bytes with a private key designated by {@code stub}. + * + *

This is done by sending the to-be-signed bytes to an S2A server (designated by {@code stub}) + * and read the signature from the server. + * + *

OpenSSL libraries must be appropriately initialized before using this class. One possible way + * to initialize OpenSSL library is to call {@code + * GrpcSslContexts.configure(SslContextBuilder.forClient());}. + */ +@NotThreadSafe +final class S2APrivateKeyMethod implements OpenSslPrivateKeyMethod { + private final S2AStub stub; + private final Optional localIdentity; + private static final ImmutableMap + OPENSSL_TO_S2A_SIGNATURE_ALGORITHM_MAP = + ImmutableMap.of( + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA256, + SignatureAlgorithm.S2A_SSL_SIGN_RSA_PKCS1_SHA256, + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA384, + SignatureAlgorithm.S2A_SSL_SIGN_RSA_PKCS1_SHA384, + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA512, + SignatureAlgorithm.S2A_SSL_SIGN_RSA_PKCS1_SHA512, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP256R1_SHA256, + SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP256R1_SHA256, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP384R1_SHA384, + SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP384R1_SHA384, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP521R1_SHA512, + SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP521R1_SHA512, + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA256, + SignatureAlgorithm.S2A_SSL_SIGN_RSA_PSS_RSAE_SHA256, + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA384, + SignatureAlgorithm.S2A_SSL_SIGN_RSA_PSS_RSAE_SHA384, + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA512, + SignatureAlgorithm.S2A_SSL_SIGN_RSA_PSS_RSAE_SHA512); + + public static S2APrivateKeyMethod create(S2AStub stub, Optional localIdentity) { + checkNotNull(stub); + return new S2APrivateKeyMethod(stub, localIdentity); + } + + private S2APrivateKeyMethod(S2AStub stub, Optional localIdentity) { + this.stub = stub; + this.localIdentity = localIdentity; + } + + /** + * Converts the signature algorithm to an enum understood by S2A. + * + * @param signatureAlgorithm the int representation of the signature algorithm define by {@code + * OpenSslPrivateKeyMethod}. + * @return the signature algorithm enum defined by S2A proto. + * @throws UnsupportedOperationException if the algorithm is not supported by S2A. + */ + @VisibleForTesting + static SignatureAlgorithm convertOpenSslSignAlgToS2ASignAlg(int signatureAlgorithm) { + SignatureAlgorithm sig = OPENSSL_TO_S2A_SIGNATURE_ALGORITHM_MAP.get(signatureAlgorithm); + if (sig == null) { + throw new UnsupportedOperationException( + String.format("Signature Algorithm %d is not supported.", signatureAlgorithm)); + } + return sig; + } + + /** + * Signs the input bytes by sending the request to the S2A srever. + * + * @param engine not used. + * @param signatureAlgorithm the {@link OpenSslPrivateKeyMethod}'s signature algorithm + * representation + * @param input the bytes to be signed. + * @return the signature of the {@code input}. + * @throws IOException if the connection to the S2A server is corrupted. + * @throws InterruptedException if the connection to the S2A server is interrupted. + * @throws S2AConnectionException if the response from the S2A server does not contain valid data. + */ + @Override + public byte[] sign(SSLEngine engine, int signatureAlgorithm, byte[] input) + throws IOException, InterruptedException { + checkArgument(input.length > 0, "No bytes to sign."); + SignatureAlgorithm s2aSignatureAlgorithm = + convertOpenSslSignAlgToS2ASignAlg(signatureAlgorithm); + SessionReq.Builder reqBuilder = + SessionReq.newBuilder() + .setOffloadPrivateKeyOperationReq( + OffloadPrivateKeyOperationReq.newBuilder() + .setOperation(OffloadPrivateKeyOperationReq.PrivateKeyOperation.SIGN) + .setSignatureAlgorithm(s2aSignatureAlgorithm) + .setRawBytes(ByteString.copyFrom(input))); + if (localIdentity.isPresent()) { + reqBuilder.setLocalIdentity(localIdentity.get().getIdentity()); + } + + SessionResp resp = stub.send(reqBuilder.build()); + + if (resp.hasStatus() && resp.getStatus().getCode() != 0) { + throw new S2AConnectionException( + String.format( + "Error occurred in response from S2A, error code: %d, error message: \"%s\".", + resp.getStatus().getCode(), resp.getStatus().getDetails())); + } + if (!resp.hasOffloadPrivateKeyOperationResp()) { + throw new S2AConnectionException("No valid response received from S2A."); + } + return resp.getOffloadPrivateKeyOperationResp().getOutBytes().toByteArray(); + } + + @Override + public byte[] decrypt(SSLEngine engine, byte[] input) { + throw new UnsupportedOperationException("decrypt is not supported."); + } +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java new file mode 100644 index 00000000000..9dcbdcf0509 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java @@ -0,0 +1,282 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Strings.isNullOrEmpty; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.net.HostAndPort; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.ListeningExecutorService; +import com.google.common.util.concurrent.MoreExecutors; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.s2a.proto.v2.S2AServiceGrpc; +import io.grpc.Channel; +import io.grpc.internal.GrpcUtil; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.SharedResourcePool; +import io.grpc.netty.GrpcHttp2ConnectionHandler; +import io.grpc.netty.InternalProtocolNegotiator; +import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; +import io.grpc.netty.InternalProtocolNegotiators; +import io.grpc.netty.InternalProtocolNegotiators.ProtocolNegotiationHandler; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerAdapter; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.ssl.SslContext; +import io.netty.util.AsciiString; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.Executors; +import javax.annotation.Nullable; + +/** Factory for performing negotiation of a secure channel using the S2A. */ +@ThreadSafe +public final class S2AProtocolNegotiatorFactory { + @VisibleForTesting static final int DEFAULT_PORT = 443; + private static final AsciiString SCHEME = AsciiString.of("https"); + + /** + * Creates a {@code S2AProtocolNegotiatorFactory} configured for a client to establish secure + * connections using the S2A. + * + * @param localIdentity the identity of the client; if none is provided, the S2A will use the + * client's default identity. + * @param s2aChannelPool a pool of shared channels that can be used to connect to the S2A. + * @param stub the stub to use to communicate with S2A. If none is provided the channelPool + * will be used to create the stub. This is exposed for verifying the stream to S2A gets + * closed in tests. + * @return a factory for creating a client-side protocol negotiator. + */ + public static InternalProtocolNegotiator.ClientFactory createClientFactory( + @Nullable S2AIdentity localIdentity, ObjectPool s2aChannelPool, + @Nullable S2AStub stub) { + checkNotNull(s2aChannelPool, "S2A channel pool should not be null."); + return new S2AClientProtocolNegotiatorFactory(localIdentity, s2aChannelPool, stub); + } + + static final class S2AClientProtocolNegotiatorFactory + implements InternalProtocolNegotiator.ClientFactory { + private final @Nullable S2AIdentity localIdentity; + private final ObjectPool channelPool; + private final @Nullable S2AStub stub; + + S2AClientProtocolNegotiatorFactory( + @Nullable S2AIdentity localIdentity, ObjectPool channelPool, + @Nullable S2AStub stub) { + this.localIdentity = localIdentity; + this.channelPool = channelPool; + this.stub = stub; + } + + @Override + public ProtocolNegotiator newNegotiator() { + return S2AProtocolNegotiator.createForClient(channelPool, localIdentity, stub); + } + + @Override + public int getDefaultPort() { + return DEFAULT_PORT; + } + } + + /** Negotiates the TLS handshake using S2A. */ + @VisibleForTesting + static final class S2AProtocolNegotiator implements ProtocolNegotiator { + + private final ObjectPool channelPool; + private @Nullable Channel channel = null; + private final Optional localIdentity; + private final @Nullable S2AStub stub; + private final ListeningExecutorService service = + MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(1)); + + static S2AProtocolNegotiator createForClient( + ObjectPool channelPool, @Nullable S2AIdentity localIdentity, + @Nullable S2AStub stub) { + checkNotNull(channelPool, "Channel pool should not be null."); + if (localIdentity == null) { + return new S2AProtocolNegotiator(channelPool, Optional.empty(), stub); + } else { + return new S2AProtocolNegotiator(channelPool, Optional.of(localIdentity), stub); + } + } + + @VisibleForTesting + static @Nullable String getHostNameFromAuthority(@Nullable String authority) { + if (authority == null) { + return null; + } + return HostAndPort.fromString(authority).getHost(); + } + + private S2AProtocolNegotiator(ObjectPool channelPool, + Optional localIdentity, @Nullable S2AStub stub) { + this.channelPool = channelPool; + this.localIdentity = localIdentity; + this.stub = stub; + if (this.stub == null) { + this.channel = channelPool.getObject(); + } + } + + @Override + public AsciiString scheme() { + return SCHEME; + } + + @Override + public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { + checkNotNull(grpcHandler, "grpcHandler should not be null."); + String hostname = getHostNameFromAuthority(grpcHandler.getAuthority()); + checkArgument(!isNullOrEmpty(hostname), "hostname should not be null or empty."); + return new S2AProtocolNegotiationHandler( + grpcHandler, channel, localIdentity, hostname, service, stub); + } + + @Override + public void close() { + service.shutdown(); + if (channel != null) { + channelPool.returnObject(channel); + } + } + } + + @VisibleForTesting + static class BufferReadsHandler extends ChannelInboundHandlerAdapter { + private final List reads = new ArrayList<>(); + private boolean readComplete; + + public List getReads() { + return reads; + } + + @Override + public void channelRead(ChannelHandlerContext unused, Object msg) { + reads.add(msg); + } + + @Override + public void channelReadComplete(ChannelHandlerContext unused) { + readComplete = true; + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + for (Object msg : reads) { + super.channelRead(ctx, msg); + } + if (readComplete) { + super.channelReadComplete(ctx); + } + } + } + + private static final class S2AProtocolNegotiationHandler extends ProtocolNegotiationHandler { + private final @Nullable Channel channel; + private final Optional localIdentity; + private final String hostname; + private final GrpcHttp2ConnectionHandler grpcHandler; + private final ListeningExecutorService service; + private final @Nullable S2AStub stub; + + private S2AProtocolNegotiationHandler( + GrpcHttp2ConnectionHandler grpcHandler, + Channel channel, + Optional localIdentity, + String hostname, + ListeningExecutorService service, + @Nullable S2AStub stub) { + super( + // superclass (InternalProtocolNegotiators.ProtocolNegotiationHandler) expects 'next' + // handler but we don't have a next handler _yet_. So we "disable" superclass's behavior + // here and then manually add 'next' when we call fireProtocolNegotiationEvent() + new ChannelHandlerAdapter() { + @Override + public void handlerAdded(ChannelHandlerContext ctx) { + ctx.pipeline().remove(this); + } + }, + grpcHandler.getNegotiationLogger()); + this.grpcHandler = grpcHandler; + this.channel = channel; + this.localIdentity = localIdentity; + this.hostname = hostname; + checkNotNull(service, "service should not be null."); + this.service = service; + this.stub = stub; + } + + @Override + protected void handlerAdded0(ChannelHandlerContext ctx) { + // Buffer all reads until the TLS Handler is added. + BufferReadsHandler bufferReads = new BufferReadsHandler(); + ctx.pipeline().addBefore(ctx.name(), /* name= */ null, bufferReads); + + S2AStub s2aStub; + if (this.stub == null) { + checkNotNull(channel, "Channel to S2A should not be null"); + s2aStub = S2AStub.newInstance(S2AServiceGrpc.newStub(channel)); + } else { + s2aStub = this.stub; + } + + ListenableFuture sslContextFuture = + service.submit(() -> SslContextFactory.createForClient(s2aStub, hostname, localIdentity)); + Futures.addCallback( + sslContextFuture, + new FutureCallback() { + @Override + public void onSuccess(SslContext sslContext) { + ChannelHandler handler = + InternalProtocolNegotiators.tls( + sslContext, + SharedResourcePool.forResource(GrpcUtil.SHARED_CHANNEL_EXECUTOR), + com.google.common.base.Optional.of(new Runnable() { + @Override + public void run() { + s2aStub.close(); + } + }), + null, null) + .newHandler(grpcHandler); + + // Delegate the rest of the handshake to the TLS handler. and remove the + // bufferReads handler. + ctx.pipeline().addAfter(ctx.name(), /* name= */ null, handler); + fireProtocolNegotiationEvent(ctx); + ctx.pipeline().remove(bufferReads); + } + + @Override + public void onFailure(Throwable t) { + ctx.fireExceptionCaught(t); + } + }, + service); + } + } + + private S2AProtocolNegotiatorFactory() {} +} diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AStub.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AStub.java new file mode 100644 index 00000000000..37236f26f4b --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AStub.java @@ -0,0 +1,245 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Verify.verify; +import static java.util.concurrent.TimeUnit.SECONDS; + +import com.google.common.annotations.VisibleForTesting; +import com.google.s2a.proto.v2.S2AServiceGrpc; +import com.google.s2a.proto.v2.SessionReq; +import com.google.s2a.proto.v2.SessionResp; +import io.grpc.stub.StreamObserver; +import java.io.IOException; +import java.util.Optional; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.concurrent.NotThreadSafe; + +/** Reads and writes messages to and from the S2A. */ +@NotThreadSafe +public class S2AStub implements AutoCloseable { + private static final Logger logger = Logger.getLogger(S2AStub.class.getName()); + private static final long HANDSHAKE_RPC_DEADLINE_SECS = 20; + private final StreamObserver reader = new Reader(); + private final BlockingQueue responses = new ArrayBlockingQueue<>(10); + private S2AServiceGrpc.S2AServiceStub serviceStub; + private StreamObserver writer; + private long deadlineSeconds = HANDSHAKE_RPC_DEADLINE_SECS; + private boolean doneReading = false; + private boolean doneWriting = false; + private boolean isClosed = false; + + @VisibleForTesting + public static S2AStub newInstance(S2AServiceGrpc.S2AServiceStub serviceStub) { + checkNotNull(serviceStub); + return new S2AStub(serviceStub); + } + + @VisibleForTesting + static S2AStub newInstanceWithDeadline( + S2AServiceGrpc.S2AServiceStub serviceStub, long deadlineSeconds) { + checkNotNull(serviceStub); + checkArgument(deadlineSeconds > 0); + return new S2AStub(serviceStub, deadlineSeconds); + } + + @VisibleForTesting + static S2AStub newInstanceForTesting(StreamObserver writer) { + checkNotNull(writer); + return new S2AStub(writer); + } + + private S2AStub(S2AServiceGrpc.S2AServiceStub serviceStub) { + this.serviceStub = serviceStub; + } + + private S2AStub(S2AServiceGrpc.S2AServiceStub serviceStub, long deadlineSeconds) { + this.serviceStub = serviceStub; + this.deadlineSeconds = deadlineSeconds; + } + + private S2AStub(StreamObserver writer) { + this.writer = writer; + } + + @VisibleForTesting + StreamObserver getReader() { + return reader; + } + + @VisibleForTesting + BlockingQueue getResponses() { + return responses; + } + + /** + * Sends a request and returns the response. Caller must wait until this method executes prior to + * calling it again. If this method throws {@code ConnectionClosedException}, then it should not + * be called again, and both {@code reader} and {@code writer} are closed. + * + * @param req the {@code SessionReq} message to be sent to the S2A server. + * @return the {@code SessionResp} message received from the S2A server. + * @throws ConnectionClosedException if {@code reader} or {@code writer} calls their {@code + * onCompleted} method. + * @throws IOException if an unexpected response is received, or if the {@code reader} or {@code + * writer} calls their {@code onError} method. + */ + @SuppressWarnings("CheckReturnValue") + public SessionResp send(SessionReq req) throws IOException, InterruptedException { + if (doneWriting && doneReading) { + logger.log(Level.INFO, "Stream to the S2A is closed."); + throw new ConnectionClosedException("Stream to the S2A is closed."); + } + createWriterIfNull(); + if (!responses.isEmpty()) { + IOException exception = null; + try { + responses.take().getResultOrThrow(); + } catch (IOException e) { + exception = e; + } + responses.clear(); + if (exception != null) { + throw new IOException( + "Received an unexpected response from a host at the S2A's address. The S2A might be" + + " unavailable.", exception); + } else { + throw new IOException("Received an unexpected response from a host at the S2A's address."); + } + } + try { + writer.onNext(req); + } catch (RuntimeException e) { + writer.onError(e); + responses.add(Result.createWithThrowable(e)); + } + try { + return responses.take().getResultOrThrow(); + } catch (ConnectionClosedException e) { + // A ConnectionClosedException is thrown by getResultOrThrow when reader calls its + // onCompleted method. The close method is called to also close the writer, and then the + // ConnectionClosedException is re-thrown in order to indicate to the caller that send + // should not be called again. + close(); + throw e; + } + } + + @Override + public void close() { + if (doneWriting && doneReading) { + return; + } + verify(!doneWriting); + doneReading = true; + doneWriting = true; + if (writer != null) { + writer.onCompleted(); + } + isClosed = true; + } + + public boolean isClosed() { + return isClosed; + } + + /** Create a new writer if the writer is null. */ + private void createWriterIfNull() { + if (writer == null) { + writer = + serviceStub + .withWaitForReady() + .withDeadlineAfter(deadlineSeconds, SECONDS) + .setUpSession(reader); + } + } + + private class Reader implements StreamObserver { + /** + * Places a {@code SessionResp} message in the {@code responses} queue, or an {@code + * IOException} if reading is complete. + * + * @param resp the {@code SessionResp} message received from the S2A handshaker module. + */ + @Override + public void onNext(SessionResp resp) { + verify(!doneReading); + responses.add(Result.createWithResponse(resp)); + } + + /** + * Places a {@code Throwable} in the {@code responses} queue. + * + * @param t the {@code Throwable} caught when reading the stream to the S2A handshaker module. + */ + @Override + public void onError(Throwable t) { + responses.add(Result.createWithThrowable(t)); + } + + /** + * Sets {@code doneReading} to true, and places a {@code ConnectionClosedException} in the + * {@code responses} queue. + */ + @Override + public void onCompleted() { + logger.log(Level.INFO, "Reading from the S2A is complete."); + doneReading = true; + responses.add( + Result.createWithThrowable( + new ConnectionClosedException("Reading from the S2A is complete."))); + } + } + + private static final class Result { + private final Optional response; + private final Optional throwable; + + static Result createWithResponse(SessionResp response) { + return new Result(Optional.of(response), Optional.empty()); + } + + static Result createWithThrowable(Throwable throwable) { + return new Result(Optional.empty(), Optional.of(throwable)); + } + + private Result(Optional response, Optional throwable) { + checkArgument(response.isPresent() != throwable.isPresent()); + this.response = response; + this.throwable = throwable; + } + + /** Throws {@code throwable} if present, and returns {@code response} otherwise. */ + SessionResp getResultOrThrow() throws IOException { + if (throwable.isPresent()) { + if (throwable.get() instanceof ConnectionClosedException) { + ConnectionClosedException exception = (ConnectionClosedException) throwable.get(); + throw exception; + } else { + throw new IOException(throwable.get()); + } + } + verify(response.isPresent()); + return response.get(); + } + } +} diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2ATrustManager.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2ATrustManager.java new file mode 100644 index 00000000000..a7ffafd01f2 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2ATrustManager.java @@ -0,0 +1,159 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.collect.ImmutableList; +import com.google.protobuf.ByteString; +import com.google.s2a.proto.v2.SessionReq; +import com.google.s2a.proto.v2.SessionResp; +import com.google.s2a.proto.v2.ValidatePeerCertificateChainReq; +import com.google.s2a.proto.v2.ValidatePeerCertificateChainReq.VerificationMode; +import com.google.s2a.proto.v2.ValidatePeerCertificateChainResp; +import io.grpc.s2a.internal.handshaker.S2AIdentity; +import java.io.IOException; +import java.security.cert.CertificateEncodingException; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.Optional; +import javax.annotation.concurrent.NotThreadSafe; +import javax.net.ssl.X509TrustManager; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** Offloads verification of the peer certificate chain to S2A. */ +@NotThreadSafe +final class S2ATrustManager implements X509TrustManager { + private final Optional localIdentity; + private final S2AStub stub; + private final String hostname; + + static S2ATrustManager createForClient( + S2AStub stub, String hostname, Optional localIdentity) { + checkNotNull(stub); + checkNotNull(hostname); + return new S2ATrustManager(stub, hostname, localIdentity); + } + + private S2ATrustManager(S2AStub stub, String hostname, Optional localIdentity) { + this.stub = stub; + this.hostname = hostname; + this.localIdentity = localIdentity; + } + + /** + * Validates the given certificate chain provided by the peer. + * + * @param chain the peer certificate chain + * @param authType the authentication type based on the client certificate + * @throws IllegalArgumentException if null or zero-length chain is passed in for the chain + * parameter. + * @throws CertificateException if the certificate chain is not trusted by this TrustManager. + */ + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + checkPeerTrusted(chain, /* isCheckingClientCertificateChain= */ true); + } + + /** + * Validates the given certificate chain provided by the peer. + * + * @param chain the peer certificate chain + * @param authType the authentication type based on the client certificate + * @throws IllegalArgumentException if null or zero-length chain is passed in for the chain + * parameter. + * @throws CertificateException if the certificate chain is not trusted by this TrustManager. + */ + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + checkPeerTrusted(chain, /* isCheckingClientCertificateChain= */ false); + } + + /** + * Returns null because the accepted issuers are held in S2A and this class receives decision made + * from S2A on the fly about which to use to verify a given chain. + * + * @return null. + */ + @Override + public X509Certificate @Nullable [] getAcceptedIssuers() { + return null; + } + + private void checkPeerTrusted(X509Certificate[] chain, boolean isCheckingClientCertificateChain) + throws CertificateException { + checkNotNull(chain); + checkArgument(chain.length > 0, "Certificate chain has zero certificates."); + + ValidatePeerCertificateChainReq.Builder validatePeerCertificateChainReq = + ValidatePeerCertificateChainReq.newBuilder().setMode(VerificationMode.UNSPECIFIED); + if (isCheckingClientCertificateChain) { + validatePeerCertificateChainReq.setClientPeer( + ValidatePeerCertificateChainReq.ClientPeer.newBuilder() + .addAllCertificateChain(certificateChainToDerChain(chain))); + } else { + validatePeerCertificateChainReq.setServerPeer( + ValidatePeerCertificateChainReq.ServerPeer.newBuilder() + .addAllCertificateChain(certificateChainToDerChain(chain)) + .setServerHostname(hostname)); + } + + SessionReq.Builder reqBuilder = + SessionReq.newBuilder().setValidatePeerCertificateChainReq(validatePeerCertificateChainReq); + if (localIdentity.isPresent()) { + reqBuilder.setLocalIdentity(localIdentity.get().getIdentity()); + } + + SessionResp resp; + try { + resp = stub.send(reqBuilder.build()); + } catch (IOException e) { + throw new CertificateException("Failed to send request to S2A.", e); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new CertificateException("Failed to send request to S2A.", e); + } + if (resp.hasStatus() && resp.getStatus().getCode() != 0) { + throw new CertificateException( + String.format( + "Error occurred in response from S2A, error code: %d, error message: %s.", + resp.getStatus().getCode(), resp.getStatus().getDetails())); + } + + if (!resp.hasValidatePeerCertificateChainResp()) { + throw new CertificateException("No valid response received from S2A."); + } + + ValidatePeerCertificateChainResp validationResult = resp.getValidatePeerCertificateChainResp(); + if (validationResult.getValidationResult() + != ValidatePeerCertificateChainResp.ValidationResult.SUCCESS) { + throw new CertificateException(validationResult.getValidationDetails()); + } + } + + private static ImmutableList certificateChainToDerChain(X509Certificate[] chain) + throws CertificateEncodingException { + ImmutableList.Builder derChain = ImmutableList.builder(); + for (X509Certificate certificate : chain) { + derChain.add(ByteString.copyFrom(certificate.getEncoded())); + } + return derChain.build(); + } +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/SslContextFactory.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/SslContextFactory.java new file mode 100644 index 00000000000..5d4ef9eb667 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/SslContextFactory.java @@ -0,0 +1,187 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker; + +import static com.google.common.base.Preconditions.checkNotNull; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.common.collect.ImmutableSet; +import com.google.s2a.proto.v2.AuthenticationMechanism; +import com.google.s2a.proto.v2.ConnectionSide; +import com.google.s2a.proto.v2.GetTlsConfigurationReq; +import com.google.s2a.proto.v2.GetTlsConfigurationResp; +import com.google.s2a.proto.v2.SessionReq; +import com.google.s2a.proto.v2.SessionResp; +import io.grpc.netty.GrpcSslContexts; +import io.grpc.s2a.internal.handshaker.S2AIdentity; +import io.netty.handler.ssl.OpenSslContextOption; +import io.netty.handler.ssl.OpenSslSessionContext; +import io.netty.handler.ssl.OpenSslX509KeyManagerFactory; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.security.GeneralSecurityException; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.UnrecoverableKeyException; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.util.Optional; +import javax.net.ssl.KeyManager; +import javax.net.ssl.SSLSessionContext; + +/** Creates {@link SslContext} objects with TLS configurations from S2A server. */ +final class SslContextFactory { + + /** + * Creates {@link SslContext} objects for client with TLS configurations from S2A server. + * + * @param stub the {@link S2AStub} to talk to the S2A server. + * @param targetName the {@link String} of the server that this client makes connection to. + * @param localIdentity the {@link S2AIdentity} that should be used when talking to S2A server. + * Will use default identity if empty. + * @return a {@link SslContext} object. + * @throws NullPointerException if either {@code stub} or {@code targetName} is null. + * @throws IOException if an unexpected response from S2A server is received. + * @throws InterruptedException if {@code stub} is closed. + */ + static SslContext createForClient( + S2AStub stub, String targetName, Optional localIdentity) + throws IOException, + InterruptedException, + CertificateException, + KeyStoreException, + NoSuchAlgorithmException, + UnrecoverableKeyException, + GeneralSecurityException { + checkNotNull(stub, "stub should not be null."); + checkNotNull(targetName, "targetName should not be null on client side."); + GetTlsConfigurationResp.ClientTlsConfiguration clientTlsConfiguration; + try { + clientTlsConfiguration = getClientTlsConfigurationFromS2A(stub, localIdentity); + } catch (IOException | InterruptedException e) { + throw new GeneralSecurityException("Failed to get client TLS configuration from S2A.", e); + } + + // Use the default value for timeout. + // Use the smallest possible value for cache size. + // The Provider is by default OPENSSL. No need to manually set it. + SslContextBuilder sslContextBuilder = + GrpcSslContexts.configure(SslContextBuilder.forClient()) + .sessionCacheSize(1) + .sessionTimeout(0); + + configureSslContextWithClientTlsConfiguration(clientTlsConfiguration, sslContextBuilder); + sslContextBuilder.trustManager( + S2ATrustManager.createForClient(stub, targetName, localIdentity)); + sslContextBuilder.option( + OpenSslContextOption.PRIVATE_KEY_METHOD, S2APrivateKeyMethod.create(stub, localIdentity)); + + SslContext sslContext = sslContextBuilder.build(); + SSLSessionContext sslSessionContext = sslContext.sessionContext(); + if (sslSessionContext instanceof OpenSslSessionContext) { + OpenSslSessionContext openSslSessionContext = (OpenSslSessionContext) sslSessionContext; + openSslSessionContext.setSessionCacheEnabled(false); + } + + return sslContext; + } + + private static GetTlsConfigurationResp.ClientTlsConfiguration getClientTlsConfigurationFromS2A( + S2AStub stub, Optional localIdentity) throws IOException, InterruptedException { + checkNotNull(stub, "stub should not be null."); + SessionReq.Builder reqBuilder = SessionReq.newBuilder(); + if (localIdentity.isPresent()) { + reqBuilder.setLocalIdentity(localIdentity.get().getIdentity()); + } + Optional authMechanism = + GetAuthenticationMechanisms.getAuthMechanism(localIdentity, + GetAuthenticationMechanisms.TOKEN_MANAGER); + if (authMechanism.isPresent()) { + reqBuilder.addAuthenticationMechanisms(authMechanism.get()); + } + SessionResp resp = + stub.send( + reqBuilder + .setGetTlsConfigurationReq( + GetTlsConfigurationReq.newBuilder() + .setConnectionSide(ConnectionSide.CONNECTION_SIDE_CLIENT)) + .build()); + if (resp.hasStatus() && resp.getStatus().getCode() != 0) { + throw new S2AConnectionException( + String.format( + "response from S2A server has ean error %d with error message %s.", + resp.getStatus().getCode(), resp.getStatus().getDetails())); + } + if (!resp.getGetTlsConfigurationResp().hasClientTlsConfiguration()) { + throw new S2AConnectionException( + "Response from S2A server does NOT contain ClientTlsConfiguration."); + } + return resp.getGetTlsConfigurationResp().getClientTlsConfiguration(); + } + + private static void configureSslContextWithClientTlsConfiguration( + GetTlsConfigurationResp.ClientTlsConfiguration clientTlsConfiguration, + SslContextBuilder sslContextBuilder) + throws CertificateException, + IOException, + KeyStoreException, + NoSuchAlgorithmException, + UnrecoverableKeyException { + sslContextBuilder.keyManager(createKeylessManager(clientTlsConfiguration)); + ImmutableSet tlsVersions; + tlsVersions = + ProtoUtil.buildTlsProtocolVersionSet( + clientTlsConfiguration.getMinTlsVersion(), clientTlsConfiguration.getMaxTlsVersion()); + if (tlsVersions.isEmpty()) { + throw new S2AConnectionException( + "Set of TLS versions received from S2A server is empty or not supported."); + } + sslContextBuilder.protocols(tlsVersions); + } + + private static KeyManager createKeylessManager( + GetTlsConfigurationResp.ClientTlsConfiguration clientTlsConfiguration) + throws CertificateException, + IOException, + KeyStoreException, + NoSuchAlgorithmException, + UnrecoverableKeyException { + X509Certificate[] certificates = + new X509Certificate[clientTlsConfiguration.getCertificateChainCount()]; + for (int i = 0; i < clientTlsConfiguration.getCertificateChainCount(); ++i) { + certificates[i] = convertStringToX509Cert(clientTlsConfiguration.getCertificateChain(i)); + } + KeyManager[] keyManagers = + OpenSslX509KeyManagerFactory.newKeyless(certificates).getKeyManagers(); + if (keyManagers == null || keyManagers.length == 0) { + throw new IllegalStateException("No key managers created."); + } + return keyManagers[0]; + } + + private static X509Certificate convertStringToX509Cert(String certificate) + throws CertificateException { + return (X509Certificate) + CertificateFactory.getInstance("X509") + .generateCertificate(new ByteArrayInputStream(certificate.getBytes(UTF_8))); + } + + private SslContextFactory() {} +} diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/tokenmanager/AccessTokenManager.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/tokenmanager/AccessTokenManager.java new file mode 100644 index 00000000000..65fca46bbb2 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/tokenmanager/AccessTokenManager.java @@ -0,0 +1,49 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker.tokenmanager; + +import io.grpc.s2a.internal.handshaker.S2AIdentity; +import java.util.Optional; +import javax.annotation.concurrent.ThreadSafe; + +/** Manages access tokens for authenticating to the S2A. */ +@ThreadSafe +public final class AccessTokenManager { + private final TokenFetcher tokenFetcher; + + /** Creates an {@code AccessTokenManager} based on the environment where the application runs. */ + public static Optional create() { + Optional tokenFetcher = SingleTokenFetcher.create(); + return tokenFetcher.isPresent() + ? Optional.of(new AccessTokenManager(tokenFetcher.get())) + : Optional.empty(); + } + + private AccessTokenManager(TokenFetcher tokenFetcher) { + this.tokenFetcher = tokenFetcher; + } + + /** Returns an access token when no identity is specified. */ + public String getDefaultToken() { + return tokenFetcher.getDefaultToken(); + } + + /** Returns an access token for the given identity. */ + public String getToken(S2AIdentity identity) { + return tokenFetcher.getToken(identity); + } +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/tokenmanager/SingleTokenFetcher.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/tokenmanager/SingleTokenFetcher.java new file mode 100644 index 00000000000..28aa0f87ba1 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/tokenmanager/SingleTokenFetcher.java @@ -0,0 +1,62 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker.tokenmanager; + +import com.google.common.annotations.VisibleForTesting; +import io.grpc.s2a.internal.handshaker.S2AIdentity; +import java.util.Optional; + +/** Fetches a single access token via an environment variable. */ +@SuppressWarnings("NonFinalStaticField") +public final class SingleTokenFetcher implements TokenFetcher { + private static final String ENVIRONMENT_VARIABLE = "S2A_ACCESS_TOKEN"; + private static String accessToken = System.getenv(ENVIRONMENT_VARIABLE); + + private final String token; + + /** + * Creates a {@code SingleTokenFetcher} from {@code ENVIRONMENT_VARIABLE}, and returns an empty + * {@code Optional} instance if the token could not be fetched. + */ + public static Optional create() { + return Optional.ofNullable(accessToken).map(SingleTokenFetcher::new); + } + + @VisibleForTesting + public static void setAccessToken(String token) { + accessToken = token; + } + + @VisibleForTesting + public static String getAccessToken() { + return accessToken; + } + + private SingleTokenFetcher(String token) { + this.token = token; + } + + @Override + public String getDefaultToken() { + return token; + } + + @Override + public String getToken(S2AIdentity identity) { + return token; + } +} diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/tokenmanager/TokenFetcher.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/tokenmanager/TokenFetcher.java new file mode 100644 index 00000000000..6827f095afe --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/tokenmanager/TokenFetcher.java @@ -0,0 +1,28 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker.tokenmanager; + +import io.grpc.s2a.internal.handshaker.S2AIdentity; + +/** Fetches tokens used to authenticate to S2A. */ +interface TokenFetcher { + /** Returns an access token when no identity is specified. */ + String getDefaultToken(); + + /** Returns an access token for the given identity. */ + String getToken(S2AIdentity identity); +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/IntegrationTest.java b/s2a/src/test/java/io/grpc/s2a/IntegrationTest.java new file mode 100644 index 00000000000..1d3568808c6 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/IntegrationTest.java @@ -0,0 +1,256 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a; + +import static com.google.common.truth.Truth.assertThat; +import static java.util.concurrent.TimeUnit.SECONDS; + +import com.google.s2a.proto.v2.S2AServiceGrpc; +import io.grpc.Channel; +import io.grpc.ChannelCredentials; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; +import io.grpc.ManagedChannel; +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.ServerCredentials; +import io.grpc.TlsChannelCredentials; +import io.grpc.TlsServerCredentials; +import io.grpc.benchmarks.Utils; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.SharedResourcePool; +import io.grpc.netty.GrpcSslContexts; +import io.grpc.netty.NettyServerBuilder; +import io.grpc.s2a.S2AChannelCredentials; +import io.grpc.s2a.internal.channel.S2AHandshakerServiceChannel; +import io.grpc.s2a.internal.handshaker.FakeS2AServer; +import io.grpc.s2a.internal.handshaker.S2AStub; +import io.grpc.stub.StreamObserver; +import io.grpc.testing.protobuf.SimpleRequest; +import io.grpc.testing.protobuf.SimpleResponse; +import io.grpc.testing.protobuf.SimpleServiceGrpc; +import io.netty.handler.ssl.ClientAuth; +import io.netty.handler.ssl.OpenSslSessionContext; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.SslProvider; +import java.io.InputStream; +import java.util.concurrent.FutureTask; +import java.util.logging.Logger; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLSessionContext; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class IntegrationTest { + private static final Logger logger = Logger.getLogger(FakeS2AServer.class.getName()); + private String s2aAddress; + private Server s2aServer; + private String s2aDelayAddress; + private Server s2aDelayServer; + private String mtlsS2AAddress; + private Server mtlsS2AServer; + private String serverAddress; + private Server server; + + @Before + public void setUp() throws Exception { + s2aServer = ServerBuilder.forPort(0).addService(new FakeS2AServer()).build().start(); + int s2aPort = s2aServer.getPort(); + s2aAddress = "localhost:" + s2aPort; + logger.info("S2A service listening on localhost:" + s2aPort); + ClassLoader classLoader = IntegrationTest.class.getClassLoader(); + InputStream s2aCert = classLoader.getResourceAsStream("server_cert.pem"); + InputStream s2aKey = classLoader.getResourceAsStream("server_key.pem"); + InputStream rootCert = classLoader.getResourceAsStream("root_cert.pem"); + ServerCredentials s2aCreds = + TlsServerCredentials.newBuilder() + .keyManager(s2aCert, s2aKey) + .trustManager(rootCert) + .clientAuth(TlsServerCredentials.ClientAuth.REQUIRE) + .build(); + mtlsS2AServer = NettyServerBuilder.forPort(0, s2aCreds).addService(new FakeS2AServer()).build(); + mtlsS2AServer.start(); + int mtlsS2APort = mtlsS2AServer.getPort(); + mtlsS2AAddress = "localhost:" + mtlsS2APort; + logger.info("mTLS S2A service listening on localhost:" + mtlsS2APort); + + int s2aDelayPort = Utils.pickUnusedPort(); + s2aDelayAddress = "localhost:" + s2aDelayPort; + s2aDelayServer = ServerBuilder.forPort(s2aDelayPort).addService(new FakeS2AServer()).build(); + + server = + NettyServerBuilder.forPort(0) + .addService(new SimpleServiceImpl()) + .sslContext(buildSslContext()) + .build() + .start(); + int serverPort = server.getPort(); + serverAddress = "localhost:" + serverPort; + logger.info("Simple Service listening on localhost:" + serverPort); + } + + @After + public void tearDown() throws Exception { + server.shutdown(); + s2aServer.shutdown(); + s2aDelayServer.shutdown(); + mtlsS2AServer.shutdown(); + + server.awaitTermination(10, SECONDS); + s2aServer.awaitTermination(10, SECONDS); + s2aDelayServer.awaitTermination(10, SECONDS); + mtlsS2AServer.awaitTermination(10, SECONDS); + } + + @Test + public void clientCommunicateUsingS2ACredentials_succeeds() throws Exception { + ChannelCredentials credentials = + S2AChannelCredentials.newBuilder(s2aAddress, InsecureChannelCredentials.create()) + .setLocalSpiffeId("test-spiffe-id").build(); + ManagedChannel channel = Grpc.newChannelBuilder(serverAddress, credentials).build(); + + assertThat(doUnaryRpc(channel)).isTrue(); + } + + @Test + public void clientCommunicateUsingS2ACredentialsNoLocalIdentity_succeeds() throws Exception { + ChannelCredentials credentials = S2AChannelCredentials.newBuilder(s2aAddress, + InsecureChannelCredentials.create()).build(); + ManagedChannel channel = Grpc.newChannelBuilder(serverAddress, credentials).build(); + + assertThat(doUnaryRpc(channel)).isTrue(); + } + + @Test + public void clientCommunicateUsingS2ACredentialsSucceeds_verifyStreamToS2AClosed() + throws Exception { + ObjectPool s2aChannelPool = + SharedResourcePool.forResource( + S2AHandshakerServiceChannel.getChannelResource(s2aAddress, + InsecureChannelCredentials.create())); + Channel ch = s2aChannelPool.getObject(); + S2AStub stub = S2AStub.newInstance(S2AServiceGrpc.newStub(ch)); + ChannelCredentials credentials = + S2AChannelCredentials.newBuilder(s2aAddress, InsecureChannelCredentials.create()) + .setLocalSpiffeId("test-spiffe-id").setStub(stub).build(); + ManagedChannel channel = Grpc.newChannelBuilder(serverAddress, credentials).build(); + + s2aChannelPool.returnObject(ch); + assertThat(doUnaryRpc(channel)).isTrue(); + assertThat(stub.isClosed()).isTrue(); + } + + @Test + public void clientCommunicateUsingMtlsToS2ACredentials_succeeds() throws Exception { + ClassLoader classLoader = IntegrationTest.class.getClassLoader(); + InputStream privateKey = classLoader.getResourceAsStream("client_key.pem"); + InputStream certChain = classLoader.getResourceAsStream("client_cert.pem"); + InputStream trustBundle = classLoader.getResourceAsStream("root_cert.pem"); + ChannelCredentials s2aChannelCredentials = + TlsChannelCredentials.newBuilder() + .keyManager(certChain, privateKey) + .trustManager(trustBundle) + .build(); + + ChannelCredentials credentials = + S2AChannelCredentials.newBuilder(mtlsS2AAddress, s2aChannelCredentials) + .setLocalSpiffeId("test-spiffe-id") + .build(); + ManagedChannel channel = Grpc.newChannelBuilder(serverAddress, credentials).build(); + + assertThat(doUnaryRpc(channel)).isTrue(); + } + + @Test + public void clientCommunicateUsingS2ACredentials_s2AdelayStart_succeeds() throws Exception { + ChannelCredentials credentials = S2AChannelCredentials.newBuilder(s2aDelayAddress, + InsecureChannelCredentials.create()).build(); + ManagedChannel channel = Grpc.newChannelBuilder(serverAddress, credentials).build(); + + FutureTask rpc = new FutureTask<>(() -> doUnaryRpc(channel)); + new Thread(rpc).start(); + Thread.sleep(2000); + s2aDelayServer.start(); + assertThat(rpc.get()).isTrue(); + } + + public static boolean doUnaryRpc(ManagedChannel channel) throws InterruptedException { + try { + SimpleServiceGrpc.SimpleServiceBlockingStub stub = + SimpleServiceGrpc.newBlockingStub(channel); + SimpleResponse resp = stub.unaryRpc(SimpleRequest.newBuilder() + .setRequestMessage("S2A team") + .build()); + if (!resp.getResponseMessage().equals("Hello, S2A team!")) { + logger.info( + "Received unexpected message from the Simple Service: " + resp.getResponseMessage()); + throw new RuntimeException(); + } else { + System.out.println( + "We received this message from the Simple Service: " + resp.getResponseMessage()); + return true; + } + } finally { + channel.shutdown(); + channel.awaitTermination(1, SECONDS); + } + } + + private static SslContext buildSslContext() throws SSLException { + ClassLoader classLoader = IntegrationTest.class.getClassLoader(); + InputStream privateKey = classLoader.getResourceAsStream("leaf_key_ec.pem"); + InputStream rootCert = classLoader.getResourceAsStream("root_cert_ec.pem"); + InputStream certChain = classLoader.getResourceAsStream("cert_chain_ec.pem"); + SslContextBuilder sslServerContextBuilder = + SslContextBuilder.forServer(certChain, privateKey); + SslContext sslServerContext = + GrpcSslContexts.configure(sslServerContextBuilder, SslProvider.OPENSSL) + .protocols("TLSv1.3", "TLSv1.2") + .trustManager(rootCert) + .clientAuth(ClientAuth.REQUIRE) + .build(); + + // Enable TLS resumption. This requires using the OpenSSL provider, since the JDK provider does + // not allow a server to send session tickets. + SSLSessionContext sslSessionContext = sslServerContext.sessionContext(); + if (!(sslSessionContext instanceof OpenSslSessionContext)) { + throw new SSLException("sslSessionContext does not use OpenSSL."); + } + OpenSslSessionContext openSslSessionContext = (OpenSslSessionContext) sslSessionContext; + // Calling {@code setTicketKeys} without specifying any keys means that the SSL libraries will + // handle the generation of the resumption master secret. + openSslSessionContext.setTicketKeys(); + + return sslServerContext; + } + + public static class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase { + @Override + public void unaryRpc(SimpleRequest request, StreamObserver observer) { + observer.onNext( + SimpleResponse.newBuilder() + .setResponseMessage("Hello, " + request.getRequestMessage() + "!") + .build()); + observer.onCompleted(); + } + } +} diff --git a/s2a/src/test/java/io/grpc/s2a/S2AChannelCredentialsTest.java b/s2a/src/test/java/io/grpc/s2a/S2AChannelCredentialsTest.java new file mode 100644 index 00000000000..3e6eef7f470 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/S2AChannelCredentialsTest.java @@ -0,0 +1,136 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import io.grpc.ChannelCredentials; +import io.grpc.InsecureChannelCredentials; +import io.grpc.TlsChannelCredentials; +import java.io.InputStream; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@code S2AChannelCredentials}. */ +@RunWith(JUnit4.class) +public final class S2AChannelCredentialsTest { + @Test + public void newBuilder_nullAddress_throwsException() throws Exception { + assertThrows(IllegalArgumentException.class, () -> S2AChannelCredentials.newBuilder(null, + InsecureChannelCredentials.create())); + } + + @Test + public void newBuilder_emptyAddress_throwsException() throws Exception { + assertThrows(IllegalArgumentException.class, () -> S2AChannelCredentials.newBuilder("", + InsecureChannelCredentials.create())); + } + + @Test + public void newBuilder_nullChannelCreds_throwsException() throws Exception { + assertThrows(NullPointerException.class, () -> S2AChannelCredentials + .newBuilder("s2a_address", null)); + } + + @Test + public void setLocalSpiffeId_nullArgument_throwsException() throws Exception { + assertThrows( + NullPointerException.class, + () -> S2AChannelCredentials.newBuilder("s2a_address", + InsecureChannelCredentials.create()).setLocalSpiffeId(null)); + } + + @Test + public void setLocalHostname_nullArgument_throwsException() throws Exception { + assertThrows( + NullPointerException.class, + () -> S2AChannelCredentials.newBuilder("s2a_address", + InsecureChannelCredentials.create()).setLocalHostname(null)); + } + + @Test + public void setLocalUid_nullArgument_throwsException() throws Exception { + assertThrows( + NullPointerException.class, + () -> S2AChannelCredentials.newBuilder("s2a_address", + InsecureChannelCredentials.create()).setLocalUid(null)); + } + + @Test + public void build_withLocalSpiffeId_succeeds() throws Exception { + assertThat( + S2AChannelCredentials.newBuilder("s2a_address", InsecureChannelCredentials.create()) + .setLocalSpiffeId("spiffe://test") + .build()) + .isNotNull(); + } + + @Test + public void build_withLocalHostname_succeeds() throws Exception { + assertThat( + S2AChannelCredentials.newBuilder("s2a_address", InsecureChannelCredentials.create()) + .setLocalHostname("local_hostname") + .build()) + .isNotNull(); + } + + @Test + public void build_withLocalUid_succeeds() throws Exception { + assertThat(S2AChannelCredentials.newBuilder("s2a_address", + InsecureChannelCredentials.create()).setLocalUid("local_uid").build()) + .isNotNull(); + } + + @Test + public void build_withNoLocalIdentity_succeeds() throws Exception { + assertThat(S2AChannelCredentials.newBuilder("s2a_address", + InsecureChannelCredentials.create()).build()) + .isNotNull(); + } + + @Test + public void build_withUseMtlsToS2ANoLocalIdentity_success() throws Exception { + ChannelCredentials s2aChannelCredentials = getTlsChannelCredentials(); + assertThat( + S2AChannelCredentials.newBuilder("s2a_address", s2aChannelCredentials) + .build()) + .isNotNull(); + } + + @Test + public void build_withUseMtlsToS2AWithLocalUid_success() throws Exception { + ChannelCredentials s2aChannelCredentials = getTlsChannelCredentials(); + assertThat( + S2AChannelCredentials.newBuilder("s2a_address", s2aChannelCredentials) + .setLocalUid("local_uid") + .build()) + .isNotNull(); + } + + private static ChannelCredentials getTlsChannelCredentials() throws Exception { + ClassLoader classLoader = S2AChannelCredentialsTest.class.getClassLoader(); + InputStream privateKey = classLoader.getResourceAsStream("client_key.pem"); + InputStream certChain = classLoader.getResourceAsStream("client_cert.pem"); + InputStream trustBundle = classLoader.getResourceAsStream("root_cert.pem"); + return TlsChannelCredentials.newBuilder() + .keyManager(certChain, privateKey) + .trustManager(trustBundle) + .build(); + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/internal/channel/S2AHandshakerServiceChannelTest.java b/s2a/src/test/java/io/grpc/s2a/internal/channel/S2AHandshakerServiceChannelTest.java new file mode 100644 index 00000000000..9ba3caaf99e --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/internal/channel/S2AHandshakerServiceChannelTest.java @@ -0,0 +1,259 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.channel; + +import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.extensions.proto.ProtoTruth.assertThat; +import static org.junit.Assert.assertThrows; + +import io.grpc.Channel; +import io.grpc.ChannelCredentials; +import io.grpc.InsecureChannelCredentials; +import io.grpc.ManagedChannel; +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.ServerCredentials; +import io.grpc.StatusRuntimeException; +import io.grpc.TlsChannelCredentials; +import io.grpc.TlsServerCredentials; +import io.grpc.internal.SharedResourceHolder.Resource; +import io.grpc.netty.NettyServerBuilder; +import io.grpc.stub.StreamObserver; +import io.grpc.testing.GrpcCleanupRule; +import io.grpc.testing.protobuf.SimpleRequest; +import io.grpc.testing.protobuf.SimpleResponse; +import io.grpc.testing.protobuf.SimpleServiceGrpc; +import java.io.InputStream; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link S2AHandshakerServiceChannel}. */ +@RunWith(JUnit4.class) +public final class S2AHandshakerServiceChannelTest { + @ClassRule public static final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + private Server mtlsServer; + private Server plaintextServer; + + @Before + public void setUp() throws Exception { + mtlsServer = createMtlsServer(); + plaintextServer = createPlaintextServer(); + mtlsServer.start(); + plaintextServer.start(); + } + + /** + * Creates a {@code Resource} and verifies that it produces a {@code ChannelResource} + * instance by using its {@code toString()} method. + */ + @Test + public void getChannelResource_success() { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + plaintextServer.getPort(), + InsecureChannelCredentials.create()); + assertThat(resource.toString()).isEqualTo("grpc-s2a-channel"); + } + + /** Same as getChannelResource_success, but use mTLS. */ + @Test + public void getChannelResource_mtlsSuccess() throws Exception { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + mtlsServer.getPort(), getTlsChannelCredentials()); + assertThat(resource.toString()).isEqualTo("grpc-s2a-channel"); + } + + /** + * Creates two {@code Resoure}s for the same target address and verifies that they are + * distinct. + */ + @Test + public void getChannelResource_twoUnEqualChannels() { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + plaintextServer.getPort(), + InsecureChannelCredentials.create()); + Resource resourceTwo = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + plaintextServer.getPort(), + InsecureChannelCredentials.create()); + assertThat(resource).isNotEqualTo(resourceTwo); + } + + /** Same as getChannelResource_twoUnEqualChannels, but use mTLS. */ + @Test + public void getChannelResource_mtlsTwoUnEqualChannels() throws Exception { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + mtlsServer.getPort(), getTlsChannelCredentials()); + Resource resourceTwo = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + mtlsServer.getPort(), getTlsChannelCredentials()); + assertThat(resource).isNotEqualTo(resourceTwo); + } + + /** + * Creates two {@code Resoure}s for different target addresses and verifies that they are + * distinct. + */ + @Test + public void getChannelResource_twoDistinctChannels() { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + plaintextServer.getPort(), + InsecureChannelCredentials.create()); + Resource resourceTwo = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + plaintextServer.getPort() + 1, InsecureChannelCredentials.create()); + assertThat(resourceTwo).isNotEqualTo(resource); + } + + /** Same as getChannelResource_twoDistinctChannels, but use mTLS. */ + @Test + public void getChannelResource_mtlsTwoDistinctChannels() throws Exception { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + mtlsServer.getPort(), getTlsChannelCredentials()); + Resource resourceTwo = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + mtlsServer.getPort() + 1, getTlsChannelCredentials()); + assertThat(resourceTwo).isNotEqualTo(resource); + } + + /** + * Uses a {@code Resource} to create a channel, closes the channel, and verifies that the + * channel is closed by attempting to make a simple RPC. + */ + @Test + public void close_success() { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + plaintextServer.getPort(), + InsecureChannelCredentials.create()); + Channel channel = resource.create(); + resource.close(channel); + StatusRuntimeException expected = + assertThrows( + StatusRuntimeException.class, + () -> + SimpleServiceGrpc.newBlockingStub(channel) + .unaryRpc(SimpleRequest.getDefaultInstance())); + assertThat(expected).hasMessageThat().isEqualTo("UNAVAILABLE: Channel shutdown invoked"); + } + + /** Same as close_success, but use mTLS. */ + @Test + public void close_mtlsSuccess() throws Exception { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + mtlsServer.getPort(), getTlsChannelCredentials()); + Channel channel = resource.create(); + resource.close(channel); + StatusRuntimeException expected = + assertThrows( + StatusRuntimeException.class, + () -> + SimpleServiceGrpc.newBlockingStub(channel) + .unaryRpc(SimpleRequest.getDefaultInstance())); + assertThat(expected).hasMessageThat().isEqualTo("UNAVAILABLE: Channel shutdown invoked"); + } + + /** + * Creates and closes a {@code ManagedChannel}, creates a new channel from the same + * resource, and verifies that this second channel is useable. + */ + @Test + public void create_succeedsAfterCloseIsCalledOnce() throws Exception { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + plaintextServer.getPort(), + InsecureChannelCredentials.create()); + Channel channelOne = resource.create(); + resource.close(channelOne); + + Channel channelTwo = resource.create(); + assertThat(channelTwo).isInstanceOf(ManagedChannel.class); + assertThat( + SimpleServiceGrpc.newBlockingStub(channelTwo) + .unaryRpc(SimpleRequest.getDefaultInstance())) + .isEqualToDefaultInstance(); + resource.close(channelTwo); + } + + /** Same as create_succeedsAfterCloseIsCalledOnce, but use mTLS. */ + @Test + public void create_mtlsSucceedsAfterCloseIsCalledOnce() throws Exception { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + mtlsServer.getPort(), getTlsChannelCredentials()); + Channel channelOne = resource.create(); + resource.close(channelOne); + + Channel channelTwo = resource.create(); + assertThat(channelTwo).isInstanceOf(ManagedChannel.class); + assertThat( + SimpleServiceGrpc.newBlockingStub(channelTwo) + .unaryRpc(SimpleRequest.getDefaultInstance())) + .isEqualToDefaultInstance(); + resource.close(channelTwo); + } + + private static Server createMtlsServer() throws Exception { + SimpleServiceImpl service = new SimpleServiceImpl(); + ClassLoader classLoader = S2AHandshakerServiceChannelTest.class.getClassLoader(); + InputStream serverCert = classLoader.getResourceAsStream("server_cert.pem"); + InputStream serverKey = classLoader.getResourceAsStream("server_key.pem"); + InputStream rootCert = classLoader.getResourceAsStream("root_cert.pem"); + ServerCredentials creds = + TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverKey) + .trustManager(rootCert) + .clientAuth(TlsServerCredentials.ClientAuth.REQUIRE) + .build(); + return grpcCleanup.register( + NettyServerBuilder.forPort(0, creds).addService(service).build()); + } + + private static Server createPlaintextServer() { + SimpleServiceImpl service = new SimpleServiceImpl(); + return grpcCleanup.register( + ServerBuilder.forPort(0).addService(service).build()); + } + + private static ChannelCredentials getTlsChannelCredentials() throws Exception { + ClassLoader classLoader = S2AHandshakerServiceChannelTest.class.getClassLoader(); + InputStream clientCert = classLoader.getResourceAsStream("client_cert.pem"); + InputStream clientKey = classLoader.getResourceAsStream("client_key.pem"); + InputStream rootCert = classLoader.getResourceAsStream("root_cert.pem"); + return TlsChannelCredentials.newBuilder() + .keyManager(clientCert, clientKey) + .trustManager(rootCert) + .build(); + } + + private static class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase { + @Override + public void unaryRpc(SimpleRequest request, StreamObserver streamObserver) { + streamObserver.onNext(SimpleResponse.getDefaultInstance()); + streamObserver.onCompleted(); + } + } +} diff --git a/s2a/src/test/java/io/grpc/s2a/internal/handshaker/FakeS2AServer.java b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/FakeS2AServer.java new file mode 100644 index 00000000000..322397c93be --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/FakeS2AServer.java @@ -0,0 +1,63 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker; + +import com.google.s2a.proto.v2.S2AServiceGrpc; +import com.google.s2a.proto.v2.SessionReq; +import com.google.s2a.proto.v2.SessionResp; +import io.grpc.stub.StreamObserver; +import java.io.IOException; +import java.security.NoSuchAlgorithmException; +import java.security.spec.InvalidKeySpecException; +import java.util.logging.Logger; + +/** A fake S2Av2 server that should be used for testing only. */ +public final class FakeS2AServer extends S2AServiceGrpc.S2AServiceImplBase { + private static final Logger logger = Logger.getLogger(FakeS2AServer.class.getName()); + + private final FakeWriter writer; + + public FakeS2AServer() throws InvalidKeySpecException, NoSuchAlgorithmException, IOException { + this.writer = new FakeWriter(); + this.writer.setVerificationResult(FakeWriter.VerificationResult.SUCCESS).initializePrivateKey(); + } + + @Override + public StreamObserver setUpSession(StreamObserver responseObserver) { + return new StreamObserver() { + @Override + public void onNext(SessionReq req) { + logger.info("Received a request from client."); + try { + responseObserver.onNext(writer.handleResponse(req)); + } catch (IOException e) { + responseObserver.onError(e); + } + } + + @Override + public void onError(Throwable t) { + responseObserver.onError(t); + } + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + }; + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/internal/handshaker/FakeS2AServerTest.java b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/FakeS2AServerTest.java new file mode 100644 index 00000000000..c3155b864b3 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/FakeS2AServerTest.java @@ -0,0 +1,300 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker; + +import static com.google.common.truth.extensions.proto.ProtoTruth.assertThat; +import static java.util.concurrent.TimeUnit.SECONDS; + +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.SettableFuture; +import com.google.protobuf.ByteString; +import com.google.s2a.proto.v2.Ciphersuite; +import com.google.s2a.proto.v2.ConnectionSide; +import com.google.s2a.proto.v2.GetTlsConfigurationReq; +import com.google.s2a.proto.v2.GetTlsConfigurationResp; +import com.google.s2a.proto.v2.S2AServiceGrpc; +import com.google.s2a.proto.v2.SessionReq; +import com.google.s2a.proto.v2.SessionResp; +import com.google.s2a.proto.v2.TLSVersion; +import com.google.s2a.proto.v2.ValidatePeerCertificateChainReq; +import com.google.s2a.proto.v2.ValidatePeerCertificateChainReq.VerificationMode; +import com.google.s2a.proto.v2.ValidatePeerCertificateChainResp; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; +import io.grpc.ManagedChannel; +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.stub.StreamObserver; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeoutException; +import java.util.logging.Logger; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link FakeS2AServer}. */ +@RunWith(JUnit4.class) +public final class FakeS2AServerTest { + private static final Logger logger = Logger.getLogger(FakeS2AServerTest.class.getName()); + + private static final ImmutableList FAKE_CERT_DER_CHAIN = + ImmutableList.of(ByteString.copyFrom("fake-der-chain".getBytes(StandardCharsets.US_ASCII))); + private String serverAddress; + private Server fakeS2AServer; + + @Before + public void setUp() throws Exception { + fakeS2AServer = ServerBuilder.forPort(0).addService(new FakeS2AServer()).build(); + fakeS2AServer.start(); + serverAddress = String.format("localhost:%d", fakeS2AServer.getPort()); + } + + @After + public void tearDown() throws Exception { + fakeS2AServer.shutdown(); + fakeS2AServer.awaitTermination(10, SECONDS); + } + + @Test + public void callS2AServerOnce_getTlsConfiguration_returnsValidResult() + throws InterruptedException, + IOException, + java.util.concurrent.ExecutionException, + TimeoutException { + ExecutorService executor = Executors.newSingleThreadExecutor(); + logger.info("Client connecting to: " + serverAddress); + ManagedChannel channel = + Grpc.newChannelBuilder(serverAddress, InsecureChannelCredentials.create()) + .executor(executor) + .build(); + SettableFuture respFuture = SettableFuture.create(); + try { + S2AServiceGrpc.S2AServiceStub asyncStub = S2AServiceGrpc.newStub(channel); + StreamObserver requestObserver = + asyncStub.setUpSession( + new StreamObserver() { + SessionResp recvResp; + @Override + public void onNext(SessionResp resp) { + recvResp = resp; + } + + @Override + public void onError(Throwable t) { + respFuture.setException(t); + } + + @Override + public void onCompleted() { + respFuture.set(recvResp); + } + }); + try { + requestObserver.onNext( + SessionReq.newBuilder() + .setGetTlsConfigurationReq( + GetTlsConfigurationReq.newBuilder() + .setConnectionSide(ConnectionSide.CONNECTION_SIDE_CLIENT)) + .build()); + } catch (RuntimeException e) { + // Cancel the RPC. + requestObserver.onError(e); + throw e; + } + // Mark the end of requests. + requestObserver.onCompleted(); + // Wait for receiving to happen. + respFuture.get(5, SECONDS); + } finally { + channel.shutdown(); + channel.awaitTermination(1, SECONDS); + executor.shutdown(); + executor.awaitTermination(1, SECONDS); + } + + String leafCertString = ""; + String cert2String = ""; + String cert1String = ""; + ClassLoader classLoader = FakeS2AServerTest.class.getClassLoader(); + try ( + InputStream leafCert = classLoader.getResourceAsStream("leaf_cert_ec.pem"); + InputStream cert2 = classLoader.getResourceAsStream("int_cert2_ec.pem"); + InputStream cert1 = classLoader.getResourceAsStream("int_cert1_ec.pem"); + ) { + leafCertString = FakeWriter.convertInputStreamToString(leafCert); + cert2String = FakeWriter.convertInputStreamToString(cert2); + cert1String = FakeWriter.convertInputStreamToString(cert1); + } + + SessionResp expected = + SessionResp.newBuilder() + .setGetTlsConfigurationResp( + GetTlsConfigurationResp.newBuilder() + .setClientTlsConfiguration( + GetTlsConfigurationResp.ClientTlsConfiguration.newBuilder() + .addCertificateChain(leafCertString) + .addCertificateChain(cert1String) + .addCertificateChain(cert2String) + .setMinTlsVersion(TLSVersion.TLS_VERSION_1_3) + .setMaxTlsVersion(TLSVersion.TLS_VERSION_1_3) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256))) + .build(); + assertThat(respFuture.get()).ignoringRepeatedFieldOrder().isEqualTo(expected); + } + + @Test + public void callS2AServerOnce_validatePeerCertifiate_returnsValidResult() + throws InterruptedException, java.util.concurrent.ExecutionException, TimeoutException { + ExecutorService executor = Executors.newSingleThreadExecutor(); + logger.info("Client connecting to: " + serverAddress); + ManagedChannel channel = + Grpc.newChannelBuilder(serverAddress, InsecureChannelCredentials.create()) + .executor(executor) + .build(); + SettableFuture respFuture = SettableFuture.create(); + try { + S2AServiceGrpc.S2AServiceStub asyncStub = S2AServiceGrpc.newStub(channel); + StreamObserver requestObserver = + asyncStub.setUpSession( + new StreamObserver() { + private SessionResp recvResp; + @Override + public void onNext(SessionResp resp) { + recvResp = resp; + } + + @Override + public void onError(Throwable t) { + respFuture.setException(t); + } + + @Override + public void onCompleted() { + respFuture.set(recvResp); + } + }); + try { + requestObserver.onNext( + SessionReq.newBuilder() + .setValidatePeerCertificateChainReq( + ValidatePeerCertificateChainReq.newBuilder() + .setMode(VerificationMode.UNSPECIFIED) + .setClientPeer( + ValidatePeerCertificateChainReq.ClientPeer.newBuilder() + .addAllCertificateChain(FAKE_CERT_DER_CHAIN))) + .build()); + } catch (RuntimeException e) { + // Cancel the RPC. + requestObserver.onError(e); + throw e; + } + // Mark the end of requests. + requestObserver.onCompleted(); + // Wait for receiving to happen. + respFuture.get(5, SECONDS); + } finally { + channel.shutdown(); + channel.awaitTermination(1, SECONDS); + executor.shutdown(); + executor.awaitTermination(1, SECONDS); + } + + SessionResp expected = + SessionResp.newBuilder() + .setValidatePeerCertificateChainResp( + ValidatePeerCertificateChainResp.newBuilder() + .setValidationResult(ValidatePeerCertificateChainResp.ValidationResult.SUCCESS)) + .build(); + assertThat(respFuture.get()).ignoringRepeatedFieldOrder().isEqualTo(expected); + } + + @Test + public void callS2AServerRepeatedly_returnsValidResult() throws InterruptedException { + final int numberOfRequests = 10; + ExecutorService executor = Executors.newSingleThreadExecutor(); + logger.info("Client connecting to: " + serverAddress); + ManagedChannel channel = + Grpc.newChannelBuilder(serverAddress, InsecureChannelCredentials.create()) + .executor(executor) + .build(); + + try { + S2AServiceGrpc.S2AServiceStub asyncStub = S2AServiceGrpc.newStub(channel); + CountDownLatch finishLatch = new CountDownLatch(1); + StreamObserver requestObserver = + asyncStub.setUpSession( + new StreamObserver() { + private int expectedNumberOfReplies = numberOfRequests; + + @Override + public void onNext(SessionResp reply) { + System.out.println("Received a message from the S2AService service."); + expectedNumberOfReplies -= 1; + } + + @Override + public void onError(Throwable t) { + finishLatch.countDown(); + if (expectedNumberOfReplies != 0) { + throw new RuntimeException(t); + } + } + + @Override + public void onCompleted() { + finishLatch.countDown(); + if (expectedNumberOfReplies != 0) { + throw new RuntimeException(); + } + } + }); + try { + for (int i = 0; i < numberOfRequests; i++) { + requestObserver.onNext(SessionReq.getDefaultInstance()); + } + } catch (RuntimeException e) { + // Cancel the RPC. + requestObserver.onError(e); + throw e; + } + // Mark the end of requests. + requestObserver.onCompleted(); + // Wait for receiving to happen. + if (!finishLatch.await(10, SECONDS)) { + throw new RuntimeException(); + } + } finally { + channel.shutdown(); + channel.awaitTermination(1, SECONDS); + executor.shutdown(); + executor.awaitTermination(1, SECONDS); + } + } + +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/internal/handshaker/FakeWriter.java b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/FakeWriter.java new file mode 100644 index 00000000000..0b398638f92 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/FakeWriter.java @@ -0,0 +1,386 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker; + +import static com.google.s2a.proto.v2.TLSVersion.TLS_VERSION_1_2; +import static com.google.s2a.proto.v2.TLSVersion.TLS_VERSION_1_3; + +import com.google.common.collect.ImmutableMap; +import com.google.common.io.CharStreams; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import com.google.protobuf.ByteString; +import com.google.s2a.proto.v2.Ciphersuite; +import com.google.s2a.proto.v2.ConnectionSide; +import com.google.s2a.proto.v2.GetTlsConfigurationReq; +import com.google.s2a.proto.v2.GetTlsConfigurationResp; +import com.google.s2a.proto.v2.OffloadPrivateKeyOperationReq; +import com.google.s2a.proto.v2.OffloadPrivateKeyOperationResp; +import com.google.s2a.proto.v2.SessionReq; +import com.google.s2a.proto.v2.SessionResp; +import com.google.s2a.proto.v2.SignatureAlgorithm; +import com.google.s2a.proto.v2.Status; +import com.google.s2a.proto.v2.ValidatePeerCertificateChainReq; +import com.google.s2a.proto.v2.ValidatePeerCertificateChainResp; +import io.grpc.stub.StreamObserver; +import io.grpc.util.CertificateUtils; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.UnsupportedEncodingException; +import java.nio.charset.StandardCharsets; +import java.security.NoSuchAlgorithmException; +import java.security.PrivateKey; +import java.security.Signature; +import java.security.spec.InvalidKeySpecException; + +/** A fake Writer Class to mock the behavior of S2A server. */ +final class FakeWriter implements StreamObserver { + /** Fake behavior of S2A service. */ + enum Behavior { + OK_STATUS, + EMPTY_RESPONSE, + ERROR_STATUS, + ERROR_RESPONSE, + COMPLETE_STATUS, + BAD_TLS_VERSION_RESPONSE, + } + + enum VerificationResult { + UNSPECIFIED, + SUCCESS, + FAILURE + } + + private static final ClassLoader classLoader = FakeWriter.class.getClassLoader(); + private static final ImmutableMap + ALGORITHM_TO_SIGNATURE_INSTANCE_IDENTIFIER = + ImmutableMap.of( + SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP256R1_SHA256, + "SHA256withECDSA", + SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP384R1_SHA384, + "SHA384withECDSA", + SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP521R1_SHA512, + "SHA512withECDSA"); + + private boolean fakeWriterClosed = false; + private Behavior behavior = Behavior.OK_STATUS; + private StreamObserver reader; + private VerificationResult verificationResult = VerificationResult.UNSPECIFIED; + private String failureReason; + private PrivateKey privateKey; + + public static String convertInputStreamToString(InputStream is) throws IOException { + return CharStreams.toString(new InputStreamReader(is, StandardCharsets.UTF_8)); + } + + @CanIgnoreReturnValue + FakeWriter setReader(StreamObserver reader) { + this.reader = reader; + return this; + } + + @CanIgnoreReturnValue + FakeWriter setBehavior(Behavior behavior) { + this.behavior = behavior; + return this; + } + + @CanIgnoreReturnValue + FakeWriter setVerificationResult(VerificationResult verificationResult) { + this.verificationResult = verificationResult; + return this; + } + + @CanIgnoreReturnValue + FakeWriter setFailureReason(String failureReason) { + this.failureReason = failureReason; + return this; + } + + @CanIgnoreReturnValue + FakeWriter initializePrivateKey() throws InvalidKeySpecException, NoSuchAlgorithmException, + IOException, FileNotFoundException, UnsupportedEncodingException { + try ( + InputStream keyInputStream = classLoader.getResourceAsStream("leaf_key_ec.pem"); + ) { + privateKey = CertificateUtils.getPrivateKey(keyInputStream); + } + return this; + } + + @CanIgnoreReturnValue + FakeWriter resetPrivateKey() { + privateKey = null; + return this; + } + + void sendUnexpectedResponse() { + reader.onNext(SessionResp.getDefaultInstance()); + } + + void sendIoError() { + reader.onError(new IOException("Intended ERROR from FakeWriter.")); + } + + void sendGetTlsConfigResp() { + String leafCertString = ""; + String cert2String = ""; + String cert1String = ""; + try ( + InputStream leafCert = classLoader.getResourceAsStream("leaf_cert_ec.pem"); + InputStream cert2 = classLoader.getResourceAsStream("int_cert2_ec.pem"); + InputStream cert1 = classLoader.getResourceAsStream("int_cert1_ec.pem"); + ) { + leafCertString = FakeWriter.convertInputStreamToString(leafCert); + cert2String = FakeWriter.convertInputStreamToString(cert2); + cert1String = FakeWriter.convertInputStreamToString(cert1); + } catch (IOException e) { + reader.onError(e); + } + reader.onNext( + SessionResp.newBuilder() + .setGetTlsConfigurationResp( + GetTlsConfigurationResp.newBuilder() + .setClientTlsConfiguration( + GetTlsConfigurationResp.ClientTlsConfiguration.newBuilder() + .addCertificateChain(leafCertString) + .addCertificateChain(cert1String) + .addCertificateChain(cert2String) + .setMinTlsVersion(TLS_VERSION_1_3) + .setMaxTlsVersion(TLS_VERSION_1_3) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384) + .addCiphersuites( + Ciphersuite + .CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256))) + .build()); + } + + boolean isFakeWriterClosed() { + return fakeWriterClosed; + } + + @Override + public void onNext(SessionReq sessionReq) { + switch (behavior) { + case OK_STATUS: + try { + reader.onNext(handleResponse(sessionReq)); + } catch (IOException e) { + reader.onError(e); + } + break; + case EMPTY_RESPONSE: + reader.onNext(SessionResp.getDefaultInstance()); + break; + case ERROR_STATUS: + reader.onNext( + SessionResp.newBuilder() + .setStatus( + Status.newBuilder() + .setCode(1) + .setDetails("Intended ERROR Status from FakeWriter.")) + .build()); + break; + case ERROR_RESPONSE: + reader.onError(new S2AConnectionException("Intended ERROR from FakeWriter.")); + break; + case COMPLETE_STATUS: + reader.onCompleted(); + break; + case BAD_TLS_VERSION_RESPONSE: + String leafCertString = ""; + String cert2String = ""; + String cert1String = ""; + try ( + InputStream leafCert = classLoader.getResourceAsStream("leaf_cert_ec.pem"); + InputStream cert2 = classLoader.getResourceAsStream("int_cert2_ec.pem"); + InputStream cert1 = classLoader.getResourceAsStream("int_cert1_ec.pem"); + ) { + leafCertString = FakeWriter.convertInputStreamToString(leafCert); + cert2String = FakeWriter.convertInputStreamToString(cert2); + cert1String = FakeWriter.convertInputStreamToString(cert1); + } catch (IOException e) { + reader.onError(e); + } + reader.onNext( + SessionResp.newBuilder() + .setGetTlsConfigurationResp( + GetTlsConfigurationResp.newBuilder() + .setClientTlsConfiguration( + GetTlsConfigurationResp.ClientTlsConfiguration.newBuilder() + .addCertificateChain(leafCertString) + .addCertificateChain(cert1String) + .addCertificateChain(cert2String) + .setMinTlsVersion(TLS_VERSION_1_3) + .setMaxTlsVersion(TLS_VERSION_1_2))) + .build()); + break; + default: + try { + reader.onNext(handleResponse(sessionReq)); + } catch (IOException e) { + reader.onError(e); + } + } + } + + SessionResp handleResponse(SessionReq sessionReq) throws IOException { + if (sessionReq.hasGetTlsConfigurationReq()) { + return handleGetTlsConfigurationReq(sessionReq.getGetTlsConfigurationReq()); + } + + if (sessionReq.hasValidatePeerCertificateChainReq()) { + return handleValidatePeerCertificateChainReq(sessionReq.getValidatePeerCertificateChainReq()); + } + + if (sessionReq.hasOffloadPrivateKeyOperationReq()) { + return handleOffloadPrivateKeyOperationReq(sessionReq.getOffloadPrivateKeyOperationReq()); + } + + return SessionResp.newBuilder() + .setStatus( + Status.newBuilder().setCode(255).setDetails("No supported operation designated.")) + .build(); + } + + private SessionResp handleGetTlsConfigurationReq(GetTlsConfigurationReq req) + throws IOException { + if (!req.getConnectionSide().equals(ConnectionSide.CONNECTION_SIDE_CLIENT)) { + return SessionResp.newBuilder() + .setStatus( + Status.newBuilder() + .setCode(255) + .setDetails("No TLS configuration for the server side.")) + .build(); + } + String leafCertString = ""; + String cert2String = ""; + String cert1String = ""; + try ( + InputStream leafCert = classLoader.getResourceAsStream("leaf_cert_ec.pem"); + InputStream cert2 = classLoader.getResourceAsStream("int_cert2_ec.pem"); + InputStream cert1 = classLoader.getResourceAsStream("int_cert1_ec.pem"); + ) { + leafCertString = FakeWriter.convertInputStreamToString(leafCert); + cert2String = FakeWriter.convertInputStreamToString(cert2); + cert1String = FakeWriter.convertInputStreamToString(cert1); + } catch (IOException e) { + reader.onError(e); + } + return SessionResp.newBuilder() + .setGetTlsConfigurationResp( + GetTlsConfigurationResp.newBuilder() + .setClientTlsConfiguration( + GetTlsConfigurationResp.ClientTlsConfiguration.newBuilder() + .addCertificateChain(leafCertString) + .addCertificateChain(cert1String) + .addCertificateChain(cert2String) + .setMinTlsVersion(TLS_VERSION_1_3) + .setMaxTlsVersion(TLS_VERSION_1_3) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256))) + .build(); + } + + private SessionResp handleValidatePeerCertificateChainReq(ValidatePeerCertificateChainReq req) { + if (verifyValidatePeerCertificateChainReq(req) + && verificationResult == VerificationResult.SUCCESS) { + return SessionResp.newBuilder() + .setValidatePeerCertificateChainResp( + ValidatePeerCertificateChainResp.newBuilder() + .setValidationResult(ValidatePeerCertificateChainResp.ValidationResult.SUCCESS)) + .build(); + } + return SessionResp.newBuilder() + .setValidatePeerCertificateChainResp( + ValidatePeerCertificateChainResp.newBuilder() + .setValidationResult( + verificationResult == VerificationResult.FAILURE + ? ValidatePeerCertificateChainResp.ValidationResult.FAILURE + : ValidatePeerCertificateChainResp.ValidationResult.UNSPECIFIED) + .setValidationDetails(failureReason)) + .build(); + } + + private boolean verifyValidatePeerCertificateChainReq(ValidatePeerCertificateChainReq req) { + if (req.getMode() != ValidatePeerCertificateChainReq.VerificationMode.UNSPECIFIED) { + return false; + } + if (req.getClientPeer().getCertificateChainCount() > 0) { + return true; + } + if (req.getServerPeer().getCertificateChainCount() > 0 + && !req.getServerPeer().getServerHostname().isEmpty()) { + return true; + } + return false; + } + + private SessionResp handleOffloadPrivateKeyOperationReq(OffloadPrivateKeyOperationReq req) { + if (privateKey == null) { + return SessionResp.newBuilder() + .setStatus(Status.newBuilder().setCode(255).setDetails("No Private Key available.")) + .build(); + } + String signatureIdentifier = + ALGORITHM_TO_SIGNATURE_INSTANCE_IDENTIFIER.get(req.getSignatureAlgorithm()); + if (signatureIdentifier == null) { + return SessionResp.newBuilder() + .setStatus( + Status.newBuilder() + .setCode(255) + .setDetails("Only ECDSA key algorithms are supported.")) + .build(); + } + + byte[] signature; + try { + Signature sig = Signature.getInstance(signatureIdentifier); + sig.initSign(privateKey); + sig.update(req.getRawBytes().toByteArray()); + signature = sig.sign(); + } catch (Exception e) { + return SessionResp.newBuilder() + .setStatus(Status.newBuilder().setCode(255).setDetails(e.getMessage())) + .build(); + } + + return SessionResp.newBuilder() + .setOffloadPrivateKeyOperationResp( + OffloadPrivateKeyOperationResp.newBuilder().setOutBytes(ByteString.copyFrom(signature))) + .build(); + } + + @Override + public void onError(Throwable t) { + throw new UnsupportedOperationException("onError is not supported by FakeWriter."); + } + + @Override + public void onCompleted() { + fakeWriterClosed = true; + reader.onCompleted(); + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/internal/handshaker/GetAuthenticationMechanismsTest.java b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/GetAuthenticationMechanismsTest.java new file mode 100644 index 00000000000..c1c629366aa --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/GetAuthenticationMechanismsTest.java @@ -0,0 +1,78 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker; + +import com.google.common.truth.Expect; +import com.google.s2a.proto.v2.AuthenticationMechanism; +import io.grpc.s2a.internal.handshaker.S2AIdentity; +import io.grpc.s2a.internal.handshaker.tokenmanager.AccessTokenManager; +import io.grpc.s2a.internal.handshaker.tokenmanager.SingleTokenFetcher; +import java.util.Optional; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link GetAuthenticationMechanisms}. */ +@RunWith(JUnit4.class) +public final class GetAuthenticationMechanismsTest { + @Rule public final Expect expect = Expect.create(); + private static final String TOKEN = "access_token"; + private static String originalAccessToken; + private Optional tokenManager; + + @BeforeClass + public static void setUpClass() { + originalAccessToken = SingleTokenFetcher.getAccessToken(); + // Set the token that the client will use to authenticate to the S2A. + SingleTokenFetcher.setAccessToken(TOKEN); + } + + @Before + public void setUp() { + tokenManager = AccessTokenManager.create(); + } + + @AfterClass + public static void tearDownClass() { + SingleTokenFetcher.setAccessToken(originalAccessToken); + } + + @Test + public void getAuthMechanisms_emptyIdentity_success() { + expect + .that(GetAuthenticationMechanisms.getAuthMechanism(Optional.empty(), tokenManager)) + .isEqualTo( + Optional.of(AuthenticationMechanism.newBuilder().setToken("access_token").build())); + } + + @Test + public void getAuthMechanisms_nonEmptyIdentity_success() { + S2AIdentity fakeIdentity = S2AIdentity.fromSpiffeId("fake-spiffe-id"); + expect + .that(GetAuthenticationMechanisms.getAuthMechanism(Optional.of(fakeIdentity), tokenManager)) + .isEqualTo( + Optional.of( + AuthenticationMechanism.newBuilder() + .setIdentity(fakeIdentity.getIdentity()) + .setToken("access_token") + .build())); + } +} diff --git a/s2a/src/test/java/io/grpc/s2a/internal/handshaker/ProtoUtilTest.java b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/ProtoUtilTest.java new file mode 100644 index 00000000000..28dbf0e4d88 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/ProtoUtilTest.java @@ -0,0 +1,89 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker; + +import static org.junit.Assert.assertThrows; + +import com.google.common.collect.ImmutableSet; +import com.google.common.truth.Expect; +import com.google.s2a.proto.v2.TLSVersion; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link ProtoUtil}. */ +@RunWith(JUnit4.class) +public final class ProtoUtilTest { + @Rule public final Expect expect = Expect.create(); + + @Test + public void convertTlsProtocolVersion_success() { + expect + .that(ProtoUtil.convertTlsProtocolVersion(TLSVersion.TLS_VERSION_1_3)) + .isEqualTo("TLSv1.3"); + expect + .that(ProtoUtil.convertTlsProtocolVersion(TLSVersion.TLS_VERSION_1_2)) + .isEqualTo("TLSv1.2"); + expect + .that(ProtoUtil.convertTlsProtocolVersion(TLSVersion.TLS_VERSION_1_1)) + .isEqualTo("TLSv1.1"); + expect.that(ProtoUtil.convertTlsProtocolVersion(TLSVersion.TLS_VERSION_1_0)).isEqualTo("TLSv1"); + } + + @Test + public void convertTlsProtocolVersion_withUnknownTlsVersion_fails() { + IllegalArgumentException expected = + assertThrows( + IllegalArgumentException.class, + () -> ProtoUtil.convertTlsProtocolVersion(TLSVersion.TLS_VERSION_UNSPECIFIED)); + expect.that(expected).hasMessageThat().isEqualTo("TLS version 0 is not supported."); + } + + @Test + public void buildTlsProtocolVersionSet_success() { + expect + .that( + ProtoUtil.buildTlsProtocolVersionSet( + TLSVersion.TLS_VERSION_1_0, TLSVersion.TLS_VERSION_1_3)) + .isEqualTo(ImmutableSet.of("TLSv1", "TLSv1.1", "TLSv1.2", "TLSv1.3")); + expect + .that( + ProtoUtil.buildTlsProtocolVersionSet( + TLSVersion.TLS_VERSION_1_2, TLSVersion.TLS_VERSION_1_2)) + .isEqualTo(ImmutableSet.of("TLSv1.2")); + expect + .that( + ProtoUtil.buildTlsProtocolVersionSet( + TLSVersion.TLS_VERSION_1_3, TLSVersion.TLS_VERSION_1_3)) + .isEqualTo(ImmutableSet.of("TLSv1.3")); + expect + .that( + ProtoUtil.buildTlsProtocolVersionSet( + TLSVersion.TLS_VERSION_1_3, TLSVersion.TLS_VERSION_1_2)) + .isEmpty(); + } + + @Test + public void buildTlsProtocolVersionSet_failure() { + expect + .that( + ProtoUtil.buildTlsProtocolVersionSet( + TLSVersion.TLS_VERSION_UNSPECIFIED, TLSVersion.TLS_VERSION_1_3)) + .isEqualTo(ImmutableSet.of("TLSv1", "TLSv1.1", "TLSv1.2", "TLSv1.3")); + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/internal/handshaker/S2APrivateKeyMethodTest.java b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/S2APrivateKeyMethodTest.java new file mode 100644 index 00000000000..8f71496cab8 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/S2APrivateKeyMethodTest.java @@ -0,0 +1,318 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker; + +import static com.google.common.truth.Truth.assertThat; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.truth.Expect; +import com.google.protobuf.ByteString; +import com.google.s2a.proto.v2.OffloadPrivateKeyOperationReq; +import com.google.s2a.proto.v2.OffloadPrivateKeyOperationResp; +import com.google.s2a.proto.v2.SessionReq; +import com.google.s2a.proto.v2.SessionResp; +import com.google.s2a.proto.v2.SignatureAlgorithm; +import io.grpc.netty.GrpcSslContexts; +import io.grpc.s2a.internal.handshaker.S2AIdentity; +import io.netty.handler.ssl.OpenSslPrivateKeyMethod; +import io.netty.handler.ssl.SslContextBuilder; +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.security.PublicKey; +import java.security.Signature; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.util.Optional; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class S2APrivateKeyMethodTest { + @Rule public final Expect expect = Expect.create(); + private static final byte[] DATA_TO_SIGN = "random bytes for signing.".getBytes(UTF_8); + + private S2AStub stub; + private FakeWriter writer; + private S2APrivateKeyMethod keyMethod; + + private static PublicKey extractPublicKeyFromPem(String pem) throws Exception { + X509Certificate cert = + (X509Certificate) + CertificateFactory.getInstance("X.509") + .generateCertificate(new ByteArrayInputStream(pem.getBytes(UTF_8))); + return cert.getPublicKey(); + } + + private static boolean verifySignature( + byte[] dataToSign, byte[] signature, String signatureAlgorithm) throws Exception { + Signature sig = Signature.getInstance(signatureAlgorithm); + InputStream leafCert = + S2APrivateKeyMethodTest.class.getClassLoader().getResourceAsStream("leaf_cert_ec.pem"); + sig.initVerify(extractPublicKeyFromPem(FakeWriter.convertInputStreamToString( + leafCert))); + leafCert.close(); + sig.update(dataToSign); + return sig.verify(signature); + } + + @Before + public void setUp() { + // This is line is to ensure that JNI correctly links the necessary objects. Without this, we + // get `java.lang.UnsatisfiedLinkError` on + // `io.netty.internal.tcnative.NativeStaticallyReferencedJniMethods.sslSignRsaPkcsSha1()` + GrpcSslContexts.configure(SslContextBuilder.forClient()); + + writer = new FakeWriter(); + stub = S2AStub.newInstanceForTesting(writer); + writer.setReader(stub.getReader()); + keyMethod = S2APrivateKeyMethod.create(stub, /* localIdentity= */ Optional.empty()); + } + + @Test + public void signatureAlgorithmConversion_success() { + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA256)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_RSA_PKCS1_SHA256); + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA384)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_RSA_PKCS1_SHA384); + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA512)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_RSA_PKCS1_SHA512); + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP256R1_SHA256)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP256R1_SHA256); + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP384R1_SHA384)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP384R1_SHA384); + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP521R1_SHA512)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP521R1_SHA512); + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA256)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_RSA_PSS_RSAE_SHA256); + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA384)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_RSA_PSS_RSAE_SHA384); + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA512)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_RSA_PSS_RSAE_SHA512); + } + + @Test + public void signatureAlgorithmConversion_unsupportedOperation() { + UnsupportedOperationException e = + assertThrows( + UnsupportedOperationException.class, + () -> S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg(-1)); + + assertThat(e).hasMessageThat().contains("Signature Algorithm -1 is not supported."); + } + + @Test + public void createOnNullStub_returnsNullPointerException() { + assertThrows( + NullPointerException.class, + () -> S2APrivateKeyMethod.create(/* stub= */ null, /* localIdentity= */ Optional.empty())); + } + + @Test + public void decrypt_unsupportedOperation() { + UnsupportedOperationException e = + assertThrows( + UnsupportedOperationException.class, + () -> keyMethod.decrypt(/* engine= */ null, DATA_TO_SIGN)); + + assertThat(e).hasMessageThat().contains("decrypt is not supported."); + } + + @Test + public void fakelocalIdentity_signWithSha256_success() throws Exception { + S2AIdentity fakeIdentity = S2AIdentity.fromSpiffeId("fake-spiffe-id"); + S2AStub mockStub = mock(S2AStub.class); + OpenSslPrivateKeyMethod keyMethodWithFakeIdentity = + S2APrivateKeyMethod.create(mockStub, Optional.of(fakeIdentity)); + SessionReq req = + SessionReq.newBuilder() + .setLocalIdentity(fakeIdentity.getIdentity()) + .setOffloadPrivateKeyOperationReq( + OffloadPrivateKeyOperationReq.newBuilder() + .setOperation(OffloadPrivateKeyOperationReq.PrivateKeyOperation.SIGN) + .setSignatureAlgorithm(SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP256R1_SHA256) + .setRawBytes(ByteString.copyFrom(DATA_TO_SIGN))) + .build(); + byte[] expectedOutbytes = "fake out bytes".getBytes(UTF_8); + when(mockStub.send(req)) + .thenReturn( + SessionResp.newBuilder() + .setOffloadPrivateKeyOperationResp( + OffloadPrivateKeyOperationResp.newBuilder() + .setOutBytes(ByteString.copyFrom(expectedOutbytes))) + .build()); + + byte[] signature = + keyMethodWithFakeIdentity.sign( + /* engine= */ null, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP256R1_SHA256, + DATA_TO_SIGN); + verify(mockStub).send(req); + assertThat(signature).isEqualTo(expectedOutbytes); + } + + @Test + public void signWithSha256_success() throws Exception { + writer.initializePrivateKey().setBehavior(FakeWriter.Behavior.OK_STATUS); + + byte[] signature = + keyMethod.sign( + /* engine= */ null, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP256R1_SHA256, + DATA_TO_SIGN); + + assertThat(signature).isNotEmpty(); + assertThat(verifySignature(DATA_TO_SIGN, signature, "SHA256withECDSA")).isTrue(); + } + + @Test + public void signWithSha384_success() throws Exception { + writer.initializePrivateKey().setBehavior(FakeWriter.Behavior.OK_STATUS); + + byte[] signature = + keyMethod.sign( + /* engine= */ null, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP384R1_SHA384, + DATA_TO_SIGN); + + assertThat(signature).isNotEmpty(); + assertThat(verifySignature(DATA_TO_SIGN, signature, "SHA384withECDSA")).isTrue(); + } + + @Test + public void signWithSha512_success() throws Exception { + writer.initializePrivateKey().setBehavior(FakeWriter.Behavior.OK_STATUS); + + byte[] signature = + keyMethod.sign( + /* engine= */ null, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP521R1_SHA512, + DATA_TO_SIGN); + + assertThat(signature).isNotEmpty(); + assertThat(verifySignature(DATA_TO_SIGN, signature, "SHA512withECDSA")).isTrue(); + } + + @Test + public void sign_noKeyAvailable() throws Exception { + writer.resetPrivateKey().setBehavior(FakeWriter.Behavior.OK_STATUS); + + S2AConnectionException e = + assertThrows( + S2AConnectionException.class, + () -> + keyMethod.sign( + /* engine= */ null, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP256R1_SHA256, + DATA_TO_SIGN)); + + assertThat(e) + .hasMessageThat() + .contains( + "Error occurred in response from S2A, error code: 255, error message: \"No Private Key" + + " available.\"."); + } + + @Test + public void sign_algorithmNotSupported() throws Exception { + writer.initializePrivateKey().setBehavior(FakeWriter.Behavior.OK_STATUS); + + S2AConnectionException e = + assertThrows( + S2AConnectionException.class, + () -> + keyMethod.sign( + /* engine= */ null, + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA256, + DATA_TO_SIGN)); + + assertThat(e) + .hasMessageThat() + .contains( + "Error occurred in response from S2A, error code: 255, error message: \"Only ECDSA key" + + " algorithms are supported.\"."); + } + + @Test + public void sign_getsErrorResponse() throws Exception { + writer.initializePrivateKey().setBehavior(FakeWriter.Behavior.ERROR_STATUS); + + S2AConnectionException e = + assertThrows( + S2AConnectionException.class, + () -> + keyMethod.sign( + /* engine= */ null, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP256R1_SHA256, + DATA_TO_SIGN)); + + assertThat(e) + .hasMessageThat() + .contains( + "Error occurred in response from S2A, error code: 1, error message: \"Intended ERROR" + + " Status from FakeWriter.\"."); + } + + @Test + public void sign_getsEmptyResponse() throws Exception { + writer.initializePrivateKey().setBehavior(FakeWriter.Behavior.EMPTY_RESPONSE); + + S2AConnectionException e = + assertThrows( + S2AConnectionException.class, + () -> + keyMethod.sign( + /* engine= */ null, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP256R1_SHA256, + DATA_TO_SIGN)); + + assertThat(e).hasMessageThat().contains("No valid response received from S2A."); + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactoryTest.java b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactoryTest.java new file mode 100644 index 00000000000..7e776f16da2 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactoryTest.java @@ -0,0 +1,259 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +import com.google.common.testing.NullPointerTester; +import com.google.common.testing.NullPointerTester.Visibility; +import com.google.s2a.proto.v2.S2AServiceGrpc; +import com.google.s2a.proto.v2.SessionReq; +import com.google.s2a.proto.v2.SessionResp; +import io.grpc.Channel; +import io.grpc.InsecureChannelCredentials; +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.benchmarks.Utils; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.SharedResourcePool; +import io.grpc.internal.TestUtils.NoopChannelLogger; +import io.grpc.netty.GrpcHttp2ConnectionHandler; +import io.grpc.netty.InternalProtocolNegotiator; +import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; +import io.grpc.s2a.internal.channel.S2AHandshakerServiceChannel; +import io.grpc.s2a.internal.handshaker.S2AIdentity; +import io.grpc.s2a.internal.handshaker.S2AProtocolNegotiatorFactory.S2AProtocolNegotiator; +import io.grpc.stub.StreamObserver; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http2.Http2ConnectionDecoder; +import io.netty.handler.codec.http2.Http2ConnectionEncoder; +import io.netty.handler.codec.http2.Http2Settings; +import io.netty.util.AsciiString; +import java.io.IOException; +import java.util.Optional; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link S2AProtocolNegotiatorFactory}. */ +@RunWith(JUnit4.class) +public class S2AProtocolNegotiatorFactoryTest { + private static final S2AIdentity LOCAL_IDENTITY = S2AIdentity.fromSpiffeId("local identity"); + private final ChannelHandlerContext mockChannelHandlerContext = mock(ChannelHandlerContext.class); + private GrpcHttp2ConnectionHandler fakeConnectionHandler; + private String authority; + private int port; + private Server fakeS2AServer; + private ObjectPool channelPool; + + @Before + public void setUp() throws Exception { + port = Utils.pickUnusedPort(); + fakeS2AServer = ServerBuilder.forPort(port).addService(new S2AServiceImpl()).build(); + fakeS2AServer.start(); + channelPool = new FakeChannelPool(); + authority = "localhost:" + port; + fakeConnectionHandler = FakeConnectionHandler.create(authority); + } + + @After + public void tearDown() { + fakeS2AServer.shutdown(); + } + + @Test + public void handlerRemoved_success() throws Exception { + S2AProtocolNegotiatorFactory.BufferReadsHandler handler1 = + new S2AProtocolNegotiatorFactory.BufferReadsHandler(); + S2AProtocolNegotiatorFactory.BufferReadsHandler handler2 = + new S2AProtocolNegotiatorFactory.BufferReadsHandler(); + EmbeddedChannel channel = new EmbeddedChannel(handler1, handler2); + channel.writeInbound("message1"); + channel.writeInbound("message2"); + channel.writeInbound("message3"); + assertThat(handler1.getReads()).hasSize(3); + assertThat(handler2.getReads()).isEmpty(); + channel.pipeline().remove(handler1); + assertThat(handler2.getReads()).hasSize(3); + } + + @Test + public void createProtocolNegotiatorFactory_nullArgument() throws Exception { + NullPointerTester tester = new NullPointerTester().setDefault(Optional.class, Optional.empty()); + + tester.testStaticMethods(S2AProtocolNegotiatorFactory.class, Visibility.PUBLIC); + } + + @Test + public void createProtocolNegotiator_nullArgument() throws Exception { + ObjectPool pool = + SharedResourcePool.forResource( + S2AHandshakerServiceChannel.getChannelResource( + "localhost:8080", InsecureChannelCredentials.create())); + + NullPointerTester tester = + new NullPointerTester() + .setDefault(ObjectPool.class, pool) + .setDefault(Optional.class, Optional.empty()); + + tester.testStaticMethods(S2AProtocolNegotiator.class, Visibility.PACKAGE); + } + + @Test + public void createProtocolNegotiatorFactory_getsDefaultPort_succeeds() throws Exception { + InternalProtocolNegotiator.ClientFactory clientFactory = + S2AProtocolNegotiatorFactory.createClientFactory(LOCAL_IDENTITY, channelPool, null); + + assertThat(clientFactory.getDefaultPort()).isEqualTo(S2AProtocolNegotiatorFactory.DEFAULT_PORT); + } + + @Test + public void s2aProtocolNegotiator_getHostNameOnNull_returnsNull() throws Exception { + assertThat(S2AProtocolNegotiatorFactory.S2AProtocolNegotiator.getHostNameFromAuthority(null)) + .isNull(); + } + + @Test + public void s2aProtocolNegotiator_getHostNameOnValidAuthority_returnsValidHostname() + throws Exception { + assertThat( + S2AProtocolNegotiatorFactory.S2AProtocolNegotiator.getHostNameFromAuthority( + "hostname:80")) + .isEqualTo("hostname"); + } + + @Test + public void createProtocolNegotiatorFactory_buildsAnS2AProtocolNegotiatorOnClientSide_succeeds() + throws Exception { + InternalProtocolNegotiator.ClientFactory clientFactory = + S2AProtocolNegotiatorFactory.createClientFactory(LOCAL_IDENTITY, channelPool, null); + + ProtocolNegotiator clientNegotiator = clientFactory.newNegotiator(); + + assertThat(clientNegotiator).isInstanceOf(S2AProtocolNegotiator.class); + assertThat(clientNegotiator.scheme()).isEqualTo(AsciiString.of("https")); + } + + @Test + public void closeProtocolNegotiator_verifyProtocolNegotiatorIsClosedOnClientSide() + throws Exception { + InternalProtocolNegotiator.ClientFactory clientFactory = + S2AProtocolNegotiatorFactory.createClientFactory(LOCAL_IDENTITY, channelPool, null); + ProtocolNegotiator clientNegotiator = clientFactory.newNegotiator(); + + clientNegotiator.close(); + + assertThat(((FakeChannelPool) channelPool).isChannelCached()).isFalse(); + } + + @Test + public void createChannelHandler_addHandlerToMockContext() throws Exception { + ProtocolNegotiator clientNegotiator = + S2AProtocolNegotiatorFactory.S2AProtocolNegotiator.createForClient( + channelPool, LOCAL_IDENTITY, null); + + ChannelHandler channelHandler = clientNegotiator.newHandler(fakeConnectionHandler); + + ((ChannelDuplexHandler) channelHandler).userEventTriggered(mockChannelHandlerContext, "event"); + verify(mockChannelHandlerContext).fireUserEventTriggered("event"); + } + + /** A {@code GrpcHttp2ConnectionHandler} that does nothing. */ + private static class FakeConnectionHandler extends GrpcHttp2ConnectionHandler { + private static final Http2ConnectionDecoder DECODER = mock(Http2ConnectionDecoder.class); + private static final Http2ConnectionEncoder ENCODER = mock(Http2ConnectionEncoder.class); + private static final Http2Settings SETTINGS = new Http2Settings(); + private final String authority; + + static FakeConnectionHandler create(String authority) { + return new FakeConnectionHandler(null, DECODER, ENCODER, SETTINGS, authority); + } + + private FakeConnectionHandler( + ChannelPromise channelUnused, + Http2ConnectionDecoder decoder, + Http2ConnectionEncoder encoder, + Http2Settings initialSettings, + String authority) { + super(channelUnused, decoder, encoder, initialSettings, new NoopChannelLogger()); + this.authority = authority; + } + + @Override + public String getAuthority() { + return authority; + } + } + + /** An S2A server that handles GetTlsConfiguration request. */ + private static class S2AServiceImpl extends S2AServiceGrpc.S2AServiceImplBase { + static final FakeWriter writer = new FakeWriter(); + + @Override + public StreamObserver setUpSession(StreamObserver responseObserver) { + return new StreamObserver() { + @Override + public void onNext(SessionReq req) { + try { + responseObserver.onNext(writer.handleResponse(req)); + } catch (IOException e) { + responseObserver.onError(e); + } + } + + @Override + public void onError(Throwable t) {} + + @Override + public void onCompleted() {} + }; + } + } + + private static class FakeChannelPool implements ObjectPool { + private final Channel mockChannel = mock(Channel.class); + private @Nullable Channel cachedChannel = null; + + @Override + public Channel getObject() { + if (cachedChannel == null) { + cachedChannel = mockChannel; + } + return cachedChannel; + } + + @Override + public Channel returnObject(Object object) { + assertThat(object).isSameInstanceAs(mockChannel); + cachedChannel = null; + return null; + } + + public boolean isChannelCached() { + return (cachedChannel != null); + } + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/internal/handshaker/S2AStubTest.java b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/S2AStubTest.java new file mode 100644 index 00000000000..2c7a7dd8405 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/S2AStubTest.java @@ -0,0 +1,285 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker; + +import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.extensions.proto.ProtoTruth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.common.truth.Expect; +import com.google.s2a.proto.v2.Ciphersuite; +import com.google.s2a.proto.v2.ConnectionSide; +import com.google.s2a.proto.v2.GetTlsConfigurationReq; +import com.google.s2a.proto.v2.GetTlsConfigurationResp; +import com.google.s2a.proto.v2.S2AServiceGrpc; +import com.google.s2a.proto.v2.SessionReq; +import com.google.s2a.proto.v2.SessionResp; +import com.google.s2a.proto.v2.Status; +import com.google.s2a.proto.v2.TLSVersion; +import io.grpc.Channel; +import io.grpc.InsecureChannelCredentials; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.SharedResourcePool; +import io.grpc.s2a.internal.channel.S2AHandshakerServiceChannel; +import io.grpc.stub.StreamObserver; +import java.io.IOException; +import java.io.InputStream; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link S2AStub}. */ +@RunWith(JUnit4.class) +public class S2AStubTest { + @Rule public final Expect expect = Expect.create(); + private static final String S2A_ADDRESS = "localhost:8080"; + private S2AStub stub; + private FakeWriter writer; + + @Before + public void setUp() { + writer = new FakeWriter(); + stub = S2AStub.newInstanceForTesting(writer); + writer.setReader(stub.getReader()); + } + + @Test + public void send_receiveOkStatus() throws Exception { + SessionReq req = + SessionReq.newBuilder() + .setGetTlsConfigurationReq( + GetTlsConfigurationReq.newBuilder() + .setConnectionSide(ConnectionSide.CONNECTION_SIDE_CLIENT)) + .build(); + + SessionResp resp = stub.send(req); + + assertThat(resp.hasGetTlsConfigurationResp()).isTrue(); + assertThat(resp.getGetTlsConfigurationResp().hasClientTlsConfiguration()).isTrue(); + } + + @Test + public void send_clientTlsConfiguration_receiveOkStatus() throws Exception { + SessionReq req = + SessionReq.newBuilder() + .setGetTlsConfigurationReq( + GetTlsConfigurationReq.newBuilder() + .setConnectionSide(ConnectionSide.CONNECTION_SIDE_CLIENT)) + .build(); + + SessionResp resp = stub.send(req); + + String leafCertString = ""; + String cert2String = ""; + String cert1String = ""; + ClassLoader classLoader = S2AStubTest.class.getClassLoader(); + try ( + InputStream leafCert = classLoader.getResourceAsStream("leaf_cert_ec.pem"); + InputStream cert2 = classLoader.getResourceAsStream("int_cert2_ec.pem"); + InputStream cert1 = classLoader.getResourceAsStream("int_cert1_ec.pem"); + ) { + leafCertString = FakeWriter.convertInputStreamToString(leafCert); + cert2String = FakeWriter.convertInputStreamToString(cert2); + cert1String = FakeWriter.convertInputStreamToString(cert1); + } + + SessionResp expected = + SessionResp.newBuilder() + .setGetTlsConfigurationResp( + GetTlsConfigurationResp.newBuilder() + .setClientTlsConfiguration( + GetTlsConfigurationResp.ClientTlsConfiguration.newBuilder() + .addCertificateChain(leafCertString) + .addCertificateChain(cert1String) + .addCertificateChain(cert2String) + .setMinTlsVersion(TLSVersion.TLS_VERSION_1_3) + .setMaxTlsVersion(TLSVersion.TLS_VERSION_1_3) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256))) + .build(); + assertThat(resp).ignoringRepeatedFieldOrder().isEqualTo(expected); + } + + @Test + public void send_serverTlsConfiguration_receiveErrorStatus() throws Exception { + SessionReq req = + SessionReq.newBuilder() + .setGetTlsConfigurationReq( + GetTlsConfigurationReq.newBuilder() + .setConnectionSide(ConnectionSide.CONNECTION_SIDE_SERVER)) + .build(); + + SessionResp resp = stub.send(req); + + SessionResp expected = + SessionResp.newBuilder() + .setStatus( + Status.newBuilder() + .setCode(255) + .setDetails("No TLS configuration for the server side.")) + .build(); + assertThat(resp).isEqualTo(expected); + } + + @Test + public void send_receiveErrorStatus() throws Exception { + writer.setBehavior(FakeWriter.Behavior.ERROR_STATUS); + + SessionResp resp = stub.send(SessionReq.getDefaultInstance()); + + SessionResp expected = + SessionResp.newBuilder() + .setStatus( + Status.newBuilder().setCode(1).setDetails("Intended ERROR Status from FakeWriter.")) + .build(); + assertThat(resp).isEqualTo(expected); + } + + @Test + public void send_receiveErrorResponse() throws InterruptedException { + writer.setBehavior(FakeWriter.Behavior.ERROR_RESPONSE); + + IOException expected = + assertThrows(IOException.class, () -> stub.send(SessionReq.getDefaultInstance())); + + expect.that(expected).hasCauseThat().isInstanceOf(RuntimeException.class); + expect.that(expected).hasMessageThat().contains("Intended ERROR from FakeWriter."); + } + + @Test + public void send_receiveCompleteStatus() throws Exception { + writer.setBehavior(FakeWriter.Behavior.COMPLETE_STATUS); + + ConnectionClosedException expected = + assertThrows( + ConnectionClosedException.class, () -> stub.send(SessionReq.getDefaultInstance())); + + assertThat(expected).hasMessageThat().contains("Reading from the S2A is complete."); + } + + @Test + public void send_receiveUnexpectedResponse() throws Exception { + writer.sendIoError(); + + IOException expected = + assertThrows(IOException.class, () -> stub.send(SessionReq.getDefaultInstance())); + + assertThat(expected) + .hasMessageThat() + .contains( + "Received an unexpected response from a host at the S2A's address. The S2A might be" + + " unavailable."); + } + + @Test + public void send_receiveManyUnexpectedResponse_expectResponsesEmpty() throws Exception { + writer.sendIoError(); + writer.sendIoError(); + writer.sendIoError(); + + IOException expected = + assertThrows(IOException.class, () -> stub.send(SessionReq.getDefaultInstance())); + + assertThat(expected) + .hasMessageThat() + .contains( + "Received an unexpected response from a host at the S2A's address. The S2A might be" + + " unavailable."); + + assertThat(stub.getResponses()).isEmpty(); + } + + @Test + public void send_receiveDelayedResponse() throws Exception { + writer.sendGetTlsConfigResp(); + IOException expectedException = + assertThrows(IOException.class, () -> stub.send(SessionReq.getDefaultInstance())); + assertThat(expectedException) + .hasMessageThat() + .contains("Received an unexpected response from a host at the S2A's address."); + + assertThat(stub.getResponses()).isEmpty(); + } + + @Test + public void send_afterEarlyClose_receivesClosedException() throws InterruptedException { + stub.close(); + expect.that(writer.isFakeWriterClosed()).isTrue(); + + ConnectionClosedException expected = + assertThrows( + ConnectionClosedException.class, () -> stub.send(SessionReq.getDefaultInstance())); + + assertThat(expected).hasMessageThat().contains("Stream to the S2A is closed."); + } + + @Test + public void send_withUnavailableService_throwsDeadlineExceeded() throws Exception { + ObjectPool channelPool = + SharedResourcePool.forResource( + S2AHandshakerServiceChannel.getChannelResource( + S2A_ADDRESS, InsecureChannelCredentials.create())); + S2AServiceGrpc.S2AServiceStub serviceStub = S2AServiceGrpc.newStub(channelPool.getObject()); + S2AStub newStub = S2AStub.newInstanceWithDeadline(serviceStub, 1); + + IOException expected = + assertThrows(IOException.class, () -> newStub.send(SessionReq.getDefaultInstance())); + + assertThat(expected).hasMessageThat().contains("DEADLINE_EXCEEDED"); + } + + @Test + public void send_failToWrite() throws Exception { + FailWriter failWriter = new FailWriter(); + stub = S2AStub.newInstanceForTesting(failWriter); + + IOException expected = + assertThrows(IOException.class, () -> stub.send(SessionReq.getDefaultInstance())); + + expect.that(expected).hasCauseThat().isInstanceOf(S2AConnectionException.class); + expect + .that(expected) + .hasCauseThat() + .hasMessageThat() + .isEqualTo("Could not send request to S2A."); + } + + /** Fails whenever a write is attempted. */ + private static class FailWriter implements StreamObserver { + @Override + public void onNext(SessionReq req) { + assertThat(req).isNotNull(); + throw new S2AConnectionException("Could not send request to S2A."); + } + + @Override + public void onError(Throwable t) { + assertThat(t).isInstanceOf(S2AConnectionException.class); + } + + @Override + public void onCompleted() { + throw new UnsupportedOperationException(); + } + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/internal/handshaker/S2ATrustManagerTest.java b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/S2ATrustManagerTest.java new file mode 100644 index 00000000000..198001838aa --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/S2ATrustManagerTest.java @@ -0,0 +1,262 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import io.grpc.s2a.internal.handshaker.S2AIdentity; +import java.io.ByteArrayInputStream; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.util.Base64; +import java.util.Optional; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class S2ATrustManagerTest { + private S2AStub stub; + private FakeWriter writer; + private static final String FAKE_HOSTNAME = "Fake-Hostname"; + private static final String CLIENT_CERT_PEM = + "MIICKjCCAc+gAwIBAgIUC2GShcVO+5Zkml+7VO3OQ+B2c7EwCgYIKoZIzj0EAwIw" + + "HzEdMBsGA1UEAwwUcm9vdGNlcnQuZXhhbXBsZS5jb20wIBcNMjMwMTI2MTk0OTUx" + + "WhgPMjA1MDA2MTMxOTQ5NTFaMB8xHTAbBgNVBAMMFGxlYWZjZXJ0LmV4YW1wbGUu" + + "Y29tMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEeciYZgFAZjxyzTrklCRIWpad" + + "8wkyCZQzJSf0IfNn9NKtfzL2V/blteULO0o9Da8e2Avaj+XCKfFTc7salMo/waOB" + + "5jCB4zAOBgNVHQ8BAf8EBAMCB4AwIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwIGCCsG" + + "AQUFBwMBMAwGA1UdEwEB/wQCMAAwYQYDVR0RBFowWIYic3BpZmZlOi8vZm9vLnBy" + + "b2QuZ29vZ2xlLmNvbS9wMS9wMoIUZm9vLnByb2Quc3BpZmZlLmdvb2eCHG1hY2hp" + + "bmUtbmFtZS5wcm9kLmdvb2dsZS5jb20wHQYDVR0OBBYEFETY6Cu/aW924nfvUrOs" + + "yXCC1hrpMB8GA1UdIwQYMBaAFJLkXGlTYKISiGd+K/Ijh4IOEpHBMAoGCCqGSM49" + + "BAMCA0kAMEYCIQCZDW472c1/4jEOHES/88X7NTqsYnLtIpTjp5PZ62z3sAIhAN1J" + + "vxvbxt9ySdFO+cW7oLBEkCwUicBhxJi5VfQeQypT"; + + @Before + public void setUp() { + writer = new FakeWriter(); + stub = S2AStub.newInstanceForTesting(writer); + writer.setReader(stub.getReader()); + } + + @Test + public void createForClient_withNullStub_throwsError() { + NullPointerException expected = + assertThrows( + NullPointerException.class, + () -> + S2ATrustManager.createForClient( + /* stub= */ null, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty())); + + assertThat(expected).hasMessageThat().isNull(); + } + + @Test + public void createForClient_withNullHostname_throwsError() { + NullPointerException expected = + assertThrows( + NullPointerException.class, + () -> + S2ATrustManager.createForClient( + stub, /* hostname= */ null, /* localIdentity= */ Optional.empty())); + + assertThat(expected).hasMessageThat().isNull(); + } + + @Test + public void getAcceptedIssuers_returnsExpectedNullResult() { + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + assertThat(trustManager.getAcceptedIssuers()).isNull(); + } + + @Test + public void checkClientTrusted_withEmptyCertificateChain_throwsException() + throws CertificateException { + writer.setVerificationResult(FakeWriter.VerificationResult.SUCCESS); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + IllegalArgumentException expected = + assertThrows( + IllegalArgumentException.class, + () -> trustManager.checkClientTrusted(new X509Certificate[] {}, /* authType= */ "")); + + assertThat(expected).hasMessageThat().contains("Certificate chain has zero certificates."); + } + + @Test + public void checkServerTrusted_withEmptyCertificateChain_throwsException() + throws CertificateException { + writer.setVerificationResult(FakeWriter.VerificationResult.SUCCESS); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + IllegalArgumentException expected = + assertThrows( + IllegalArgumentException.class, + () -> trustManager.checkServerTrusted(new X509Certificate[] {}, /* authType= */ "")); + + assertThat(expected).hasMessageThat().contains("Certificate chain has zero certificates."); + } + + @Test + public void checkClientTrusted_getsSuccessResponse() throws CertificateException { + writer.setVerificationResult(FakeWriter.VerificationResult.SUCCESS); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + // Expect no exception. + trustManager.checkClientTrusted(getCerts(), /* authType= */ ""); + } + + @Test + public void checkClientTrusted_withLocalIdentity_getsSuccessResponse() + throws CertificateException { + writer.setVerificationResult(FakeWriter.VerificationResult.SUCCESS); + S2ATrustManager trustManager = + S2ATrustManager.createForClient( + stub, FAKE_HOSTNAME, Optional.of(S2AIdentity.fromSpiffeId("fake-spiffe-id"))); + + // Expect no exception. + trustManager.checkClientTrusted(getCerts(), /* authType= */ ""); + } + + @Test + public void checkServerTrusted_getsSuccessResponse() throws CertificateException { + writer.setVerificationResult(FakeWriter.VerificationResult.SUCCESS); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + // Expect no exception. + trustManager.checkServerTrusted(getCerts(), /* authType= */ ""); + } + + @Test + public void checkServerTrusted_withLocalIdentity_getsSuccessResponse() + throws CertificateException { + writer.setVerificationResult(FakeWriter.VerificationResult.SUCCESS); + S2ATrustManager trustManager = + S2ATrustManager.createForClient( + stub, FAKE_HOSTNAME, Optional.of(S2AIdentity.fromSpiffeId("fake-spiffe-id"))); + + // Expect no exception. + trustManager.checkServerTrusted(getCerts(), /* authType= */ ""); + } + + @Test + public void checkClientTrusted_getsIntendedFailureResponse() throws CertificateException { + writer + .setVerificationResult(FakeWriter.VerificationResult.FAILURE) + .setFailureReason("Intended failure."); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + CertificateException expected = + assertThrows( + CertificateException.class, + () -> trustManager.checkClientTrusted(getCerts(), /* authType= */ "")); + + assertThat(expected).hasMessageThat().contains("Intended failure."); + } + + @Test + public void checkClientTrusted_getsIntendedFailureStatusInResponse() throws CertificateException { + writer.setBehavior(FakeWriter.Behavior.ERROR_STATUS); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + CertificateException expected = + assertThrows( + CertificateException.class, + () -> trustManager.checkClientTrusted(getCerts(), /* authType= */ "")); + + assertThat(expected).hasMessageThat().contains("Error occurred in response from S2A"); + } + + @Test + public void checkClientTrusted_getsIntendedFailureFromServer() throws CertificateException { + writer.setBehavior(FakeWriter.Behavior.ERROR_RESPONSE); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + CertificateException expected = + assertThrows( + CertificateException.class, + () -> trustManager.checkClientTrusted(getCerts(), /* authType= */ "")); + + assertThat(expected).hasMessageThat().isEqualTo("Failed to send request to S2A."); + } + + @Test + public void checkServerTrusted_getsIntendedFailureResponse() throws CertificateException { + writer + .setVerificationResult(FakeWriter.VerificationResult.FAILURE) + .setFailureReason("Intended failure."); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + CertificateException expected = + assertThrows( + CertificateException.class, + () -> trustManager.checkServerTrusted(getCerts(), /* authType= */ "")); + + assertThat(expected).hasMessageThat().contains("Intended failure."); + } + + @Test + public void checkServerTrusted_getsIntendedFailureStatusInResponse() throws CertificateException { + writer.setBehavior(FakeWriter.Behavior.ERROR_STATUS); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + CertificateException expected = + assertThrows( + CertificateException.class, + () -> trustManager.checkServerTrusted(getCerts(), /* authType= */ "")); + + assertThat(expected).hasMessageThat().contains("Error occurred in response from S2A"); + } + + @Test + public void checkServerTrusted_getsIntendedFailureFromServer() throws CertificateException { + writer.setBehavior(FakeWriter.Behavior.ERROR_RESPONSE); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + CertificateException expected = + assertThrows( + CertificateException.class, + () -> trustManager.checkServerTrusted(getCerts(), /* authType= */ "")); + + assertThat(expected).hasMessageThat().isEqualTo("Failed to send request to S2A."); + } + + private X509Certificate[] getCerts() throws CertificateException { + byte[] decoded = Base64.getDecoder().decode(CLIENT_CERT_PEM); + return new X509Certificate[] { + (X509Certificate) + CertificateFactory.getInstance("X.509") + .generateCertificate(new ByteArrayInputStream(decoded)) + }; + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/internal/handshaker/SslContextFactoryTest.java b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/SslContextFactoryTest.java new file mode 100644 index 00000000000..17b834abf2a --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/SslContextFactoryTest.java @@ -0,0 +1,177 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.common.truth.Expect; +import io.grpc.s2a.internal.handshaker.S2AIdentity; +import io.netty.handler.ssl.OpenSslSessionContext; +import io.netty.handler.ssl.SslContext; +import java.security.GeneralSecurityException; +import java.util.Optional; +import javax.net.ssl.SSLSessionContext; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link SslContextFactory}. */ +@RunWith(JUnit4.class) +public final class SslContextFactoryTest { + @Rule public final Expect expect = Expect.create(); + private static final String FAKE_TARGET_NAME = "fake_target_name"; + private S2AStub stub; + private FakeWriter writer; + + @Before + public void setUp() { + writer = new FakeWriter(); + stub = S2AStub.newInstanceForTesting(writer); + writer.setReader(stub.getReader()); + } + + @Test + public void createForClient_returnsValidSslContext() throws Exception { + SslContext sslContext = + SslContextFactory.createForClient( + stub, FAKE_TARGET_NAME, /* localIdentity= */ Optional.empty()); + + expect.that(sslContext).isNotNull(); + expect.that(sslContext.sessionCacheSize()).isEqualTo(1); + expect.that(sslContext.sessionTimeout()).isEqualTo(300); + expect.that(sslContext.isClient()).isTrue(); + expect.that(sslContext.applicationProtocolNegotiator().protocols()).containsExactly("h2"); + SSLSessionContext sslSessionContext = sslContext.sessionContext(); + if (sslSessionContext instanceof OpenSslSessionContext) { + OpenSslSessionContext openSslSessionContext = (OpenSslSessionContext) sslSessionContext; + expect.that(openSslSessionContext.isSessionCacheEnabled()).isFalse(); + } + } + + @Test + public void createForClient_withLocalIdentity_returnsValidSslContext() throws Exception { + SslContext sslContext = + SslContextFactory.createForClient( + stub, FAKE_TARGET_NAME, Optional.of(S2AIdentity.fromSpiffeId("fake-spiffe-id"))); + + expect.that(sslContext).isNotNull(); + expect.that(sslContext.sessionCacheSize()).isEqualTo(1); + expect.that(sslContext.sessionTimeout()).isEqualTo(300); + expect.that(sslContext.isClient()).isTrue(); + expect.that(sslContext.applicationProtocolNegotiator().protocols()).containsExactly("h2"); + SSLSessionContext sslSessionContext = sslContext.sessionContext(); + if (sslSessionContext instanceof OpenSslSessionContext) { + OpenSslSessionContext openSslSessionContext = (OpenSslSessionContext) sslSessionContext; + expect.that(openSslSessionContext.isSessionCacheEnabled()).isFalse(); + } + } + + @Test + public void createForClient_returnsEmptyResponse_error() throws Exception { + writer.setBehavior(FakeWriter.Behavior.EMPTY_RESPONSE); + + S2AConnectionException expected = + assertThrows( + S2AConnectionException.class, + () -> + SslContextFactory.createForClient( + stub, FAKE_TARGET_NAME, /* localIdentity= */ Optional.empty())); + + assertThat(expected) + .hasMessageThat() + .contains("Response from S2A server does NOT contain ClientTlsConfiguration."); + } + + @Test + public void createForClient_returnsErrorStatus_error() throws Exception { + writer.setBehavior(FakeWriter.Behavior.ERROR_STATUS); + + S2AConnectionException expected = + assertThrows( + S2AConnectionException.class, + () -> + SslContextFactory.createForClient( + stub, FAKE_TARGET_NAME, /* localIdentity= */ Optional.empty())); + + assertThat(expected).hasMessageThat().contains("Intended ERROR Status from FakeWriter."); + } + + @Test + public void createForClient_getsErrorFromServer_throwsError() throws Exception { + writer.sendIoError(); + + GeneralSecurityException expected = + assertThrows( + GeneralSecurityException.class, + () -> + SslContextFactory.createForClient( + stub, FAKE_TARGET_NAME, /* localIdentity= */ Optional.empty())); + + assertThat(expected) + .hasMessageThat() + .contains("Failed to get client TLS configuration from S2A."); + } + + @Test + public void createForClient_getsBadTlsVersionsFromServer_throwsError() throws Exception { + writer.setBehavior(FakeWriter.Behavior.BAD_TLS_VERSION_RESPONSE); + + S2AConnectionException expected = + assertThrows( + S2AConnectionException.class, + () -> + SslContextFactory.createForClient( + stub, FAKE_TARGET_NAME, /* localIdentity= */ Optional.empty())); + + assertThat(expected) + .hasMessageThat() + .contains("Set of TLS versions received from S2A server is empty or not supported."); + } + + @Test + public void createForClient_nullStub_throwsError() throws Exception { + writer.sendUnexpectedResponse(); + + NullPointerException expected = + assertThrows( + NullPointerException.class, + () -> + SslContextFactory.createForClient( + /* stub= */ null, FAKE_TARGET_NAME, /* localIdentity= */ Optional.empty())); + + assertThat(expected).hasMessageThat().isEqualTo("stub should not be null."); + } + + @Test + public void createForClient_nullTargetName_throwsError() throws Exception { + writer.sendUnexpectedResponse(); + + NullPointerException expected = + assertThrows( + NullPointerException.class, + () -> + SslContextFactory.createForClient( + stub, /* targetName= */ null, /* localIdentity= */ Optional.empty())); + + assertThat(expected) + .hasMessageThat() + .isEqualTo("targetName should not be null on client side."); + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/internal/handshaker/tokenmanager/SingleTokenAccessTokenManagerTest.java b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/tokenmanager/SingleTokenAccessTokenManagerTest.java new file mode 100644 index 00000000000..9fd33fe9070 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/internal/handshaker/tokenmanager/SingleTokenAccessTokenManagerTest.java @@ -0,0 +1,80 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.internal.handshaker.tokenmanager; + +import static com.google.common.truth.Truth.assertThat; + +import io.grpc.s2a.internal.handshaker.S2AIdentity; +import java.util.Optional; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class SingleTokenAccessTokenManagerTest { + private static final S2AIdentity IDENTITY = S2AIdentity.fromSpiffeId("spiffe_id"); + private static final String TOKEN = "token"; + + private String originalAccessToken; + + @Before + public void setUp() { + originalAccessToken = SingleTokenFetcher.getAccessToken(); + SingleTokenFetcher.setAccessToken(null); + } + + @After + public void tearDown() { + SingleTokenFetcher.setAccessToken(originalAccessToken); + } + + @Test + public void getDefaultToken_success() throws Exception { + SingleTokenFetcher.setAccessToken(TOKEN); + Optional manager = AccessTokenManager.create(); + assertThat(manager).isPresent(); + assertThat(manager.get().getDefaultToken()).isEqualTo(TOKEN); + } + + @Test + public void getToken_success() throws Exception { + SingleTokenFetcher.setAccessToken(TOKEN); + Optional manager = AccessTokenManager.create(); + assertThat(manager).isPresent(); + assertThat(manager.get().getToken(IDENTITY)).isEqualTo(TOKEN); + } + + @Test + public void getToken_noEnvironmentVariable() throws Exception { + assertThat(SingleTokenFetcher.create()).isEmpty(); + } + + @Test + public void create_success() throws Exception { + SingleTokenFetcher.setAccessToken(TOKEN); + Optional manager = AccessTokenManager.create(); + assertThat(manager).isPresent(); + assertThat(manager.get().getToken(IDENTITY)).isEqualTo(TOKEN); + } + + @Test + public void create_noEnvironmentVariable() throws Exception { + assertThat(AccessTokenManager.create()).isEmpty(); + } +} diff --git a/s2a/src/test/resources/README.md b/s2a/src/test/resources/README.md new file mode 100644 index 00000000000..2250ffb1dec --- /dev/null +++ b/s2a/src/test/resources/README.md @@ -0,0 +1,69 @@ +# Generating certificates and keys for testing mTLS-S2A + +Content from: https://github.com/google/s2a-go/blob/main/testdata/README.md + +Create root CA + +``` +openssl req -x509 -sha256 -days 7305 -newkey rsa:2048 -keyout root_key.pem -out +root_cert.pem +``` + +Generate private keys for server and client + +``` +openssl genrsa -out server_key.pem 2048 +openssl genrsa -out client_key.pem 2048 +``` + +Generate CSRs for server and client (set Common Name to localhost, leave all +other fields blank) + +``` +openssl req -key server_key.pem -new -out server.csr -config config.cnf +openssl req -key client_key.pem -new -out client.csr -config config.cnf +``` + +Sign CSRs for server and client + +``` +openssl x509 -req -CA root_cert.pem -CAkey root_key.pem -in server.csr -out server_cert.pem -days 7305 -extfile config.cnf -extensions req_ext +openssl x509 -req -CA root_cert.pem -CAkey root_key.pem -in client.csr -out client_cert.pem -days 7305 +``` + +Generate self-signed ECDSA root cert + +``` +openssl ecparam -name prime256v1 -genkey -noout -out temp.pem +openssl pkcs8 -topk8 -in temp.pem -out root_key_ec.pem -nocrypt +rm temp.pem +openssl req -x509 -days 7305 -new -key root_key_ec.pem -nodes -out root_cert_ec.pem -config root_ec.cnf -extensions 'v3_req' +``` + +Generate a chain of ECDSA certs + +``` +openssl ecparam -name prime256v1 -genkey -noout -out temp.pem +openssl pkcs8 -topk8 -in temp.pem -out int_key2_ec.pem -nocrypt +rm temp.pem +openssl req -key int_key2_ec.pem -new -out temp.csr -config int_cert2.cnf +openssl x509 -req -days 7305 -in temp.csr -CA root_cert_ec.pem -CAkey root_key_ec.pem -CAcreateserial -out int_cert2_ec.pem -extfile int_cert2.cnf -extensions 'v3_req' + + +openssl ecparam -name prime256v1 -genkey -noout -out temp.pem +openssl pkcs8 -topk8 -in temp.pem -out int_key1_ec.pem -nocrypt +rm temp.pem +openssl req -key int_key1_ec.pem -new -out temp.csr -config int_cert1.cnf +openssl x509 -req -days 7305 -in temp.csr -CA int_cert2_ec.pem -CAkey int_key2_ec.pem -CAcreateserial -out int_cert1_ec.pem -extfile int_cert1.cnf -extensions 'v3_req' + + +openssl ecparam -name prime256v1 -genkey -noout -out temp.pem +openssl pkcs8 -topk8 -in temp.pem -out leaf_key_ec.pem -nocrypt +rm temp.pem +openssl req -key leaf_key_ec.pem -new -out temp.csr -config leaf.cnf +openssl x509 -req -days 7305 -in temp.csr -CA int_cert1_ec.pem -CAkey int_key1_ec.pem -CAcreateserial -out leaf_cert_ec.pem -extfile leaf.cnf -extensions 'v3_req' +``` + +``` +cat leaf_cert_ec.pem int_cert1_ec.pem int_cert2_ec.pem > cert_chain_ec.pem +``` \ No newline at end of file diff --git a/s2a/src/test/resources/cert_chain_ec.pem b/s2a/src/test/resources/cert_chain_ec.pem new file mode 100644 index 00000000000..a249904286c --- /dev/null +++ b/s2a/src/test/resources/cert_chain_ec.pem @@ -0,0 +1,39 @@ +-----BEGIN CERTIFICATE----- +MIIB6jCCAZCgAwIBAgIUA98F2JkYZAyz9BdIkBK3P8Df7OUwCgYIKoZIzj0EAwIw +MjEOMAwGA1UECgwFaW50MU8xDzANBgNVBAsMBmludDFPVTEPMA0GA1UEAwwGaW50 +MUNOMB4XDTI0MTAwMTIxNDIwMFoXDTQ0MTAwMTIxNDIwMFowMjEOMAwGA1UECgwF +bGVhZk8xDzANBgNVBAsMBmxlYWZPVTEPMA0GA1UEAwwGbGVhZkNOMFkwEwYHKoZI +zj0CAQYIKoZIzj0DAQcDQgAEtpTTzt2VDTP6gO4uUIpg8sB63Ff4T4YPMoIGrrn3 +tU3f9j0Ysa5/xblM0LkwRImcrKKchYDiNm1wHkWo+qDImaOBgzCBgDAOBgNVHQ8B +Af8EBAMCB4AwIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwIGCCsGAQUFBwMBMAwGA1Ud +EwEB/wQCMAAwHQYDVR0OBBYEFGzFBt/E6vDJRcH+Izy4MQ9AHycqMB8GA1UdIwQY +MBaAFBYs72Jv682/xzG3Tm8hItIFis//MAoGCCqGSM49BAMCA0gAMEUCIHUcqPTB +mQ4kXE0WoOUC8ZmzvthvfKjCNe0YogcjZgwWAiEAvapmWoQIO4qie25Ae9sYRCPq +5xAHztAquk5HLfwabow= +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIB8TCCAZagAwIBAgIUEXwpznJIlU+ELO7Qgb4UUGpfbj8wCgYIKoZIzj0EAwIw +MjEOMAwGA1UECgwFaW50Mk8xDzANBgNVBAsMBmludDJPVTEPMA0GA1UEAwwGaW50 +MkNOMB4XDTI0MTAwMTIxNDIwMFoXDTQ0MTAwMTIxNDIwMFowMjEOMAwGA1UECgwF +aW50MU8xDzANBgNVBAsMBmludDFPVTEPMA0GA1UEAwwGaW50MUNOMFkwEwYHKoZI +zj0CAQYIKoZIzj0DAQcDQgAEoenicrtL6ezEW2yLSXADscDJQ/fdbr+vJEU/aieV +wA2EnPbrdpvQZaz+pXtuZzBLZY50XI9y33E+/PvBFtZob6OBiTCBhjAOBgNVHQ8B +Af8EBAMCAQYwIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwIGCCsGAQUFBwMBMBIGA1Ud +EwEB/wQIMAYBAf8CAQEwHQYDVR0OBBYEFBYs72Jv682/xzG3Tm8hItIFis//MB8G +A1UdIwQYMBaAFPhN6eGgVc36Kc50rREZhMdBIkgGMAoGCCqGSM49BAMCA0kAMEYC +IQDiPcbihg1iDi0m9CUn96IbWOTh1X75RfVJYcR3Q5T78AIhAK/fxZauDeWPzk2r +2/ohCQOZFHtAi9VRpr/TqNi3SaYt +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIB8DCCAZagAwIBAgIUNOH4wQEoKHvaQ9Xgd36vh5TnhfUwCgYIKoZIzj0EAwIw +MjEOMAwGA1UECgwFcm9vdE8xDzANBgNVBAsMBnJvb3RPVTEPMA0GA1UEAwwGcm9v +dENOMB4XDTI0MTAwMTIxNDIwMFoXDTQ0MTAwMTIxNDIwMFowMjEOMAwGA1UECgwF +aW50Mk8xDzANBgNVBAsMBmludDJPVTEPMA0GA1UEAwwGaW50MkNOMFkwEwYHKoZI +zj0CAQYIKoZIzj0DAQcDQgAE44B/G4pzAvLpIUaPp8XNRtXuw8jeLgE40NjQMuqq +3jNs6ID/fv/jiRggLMXL3Tii1CisM4BRjg56/Owky1Fyv6OBiTCBhjAOBgNVHQ8B +Af8EBAMCAQYwIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwIGCCsGAQUFBwMBMBIGA1Ud +EwEB/wQIMAYBAf8CAQIwHQYDVR0OBBYEFPhN6eGgVc36Kc50rREZhMdBIkgGMB8G +A1UdIwQYMBaAFNHNBlllqi9koRtf7EBHjRMwVgWsMAoGCCqGSM49BAMCA0gAMEUC +IBd4bvqVeYSSUEGF1wB0KlYxn1L0Ub/LjgIUUQFAEwahAiEAgeArX63bnlI7u3dq +v/FGilvcLP3P3AvRozpHJiIZ860= +-----END CERTIFICATE----- \ No newline at end of file diff --git a/s2a/src/test/resources/client_cert.pem b/s2a/src/test/resources/client_cert.pem new file mode 100644 index 00000000000..837f8bb5019 --- /dev/null +++ b/s2a/src/test/resources/client_cert.pem @@ -0,0 +1,20 @@ +-----BEGIN CERTIFICATE----- +MIIDPTCCAiWgAwIBAgIUaarddwSWeE4jDC9kwxEr446ehqUwDQYJKoZIhvcNAQEL +BQAwWTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MB4X +DTI0MTAwMTIxNTk1NFoXDTQ0MTAwMTIxNTk1NFowFDESMBAGA1UEAwwJbG9jYWxo +b3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAxlNsldt7yAU4KRuS +2D2/FjNIE1US5olBm4HteTr++41WaELZJqNLRPPp052jEQU3aKSYNGZvUUO6buu7 +eFpz2SBNUVMyvmzzocjVAyyf4NQvDazYHWOb+/YCeUppTRWriz4V5sn47qJTQ8cd +CGrTFeLHxUjx4nh/OiqVXP/KnF3EqPEuqph0ky7+GirnJgPRe+C5ERuGkJye8dmP +yWGA2lSS6MeDe7JZTAMi08bAn7BuNpeBkOzz1msGGI9PnUanUs7GOPWTDdcQAVY8 +KMvHCuGaNMGpb4rOR2mm8LlbAbpTPz8Pkw4QtMCLkgsrz2CzXpVwnLsU7nDXJAIO +B155lQIDAQABo0IwQDAdBgNVHQ4EFgQUSZEyIHLzkIw7AwkBaUjYfIrGVR4wHwYD +VR0jBBgwFoAUcq3dtxAVA410YWyM0B4e+4umbiwwDQYJKoZIhvcNAQELBQADggEB +AAz0bZ4ayrZLhA45xn0yvdpdqiCtiWikCRtxgE7VXHg/ziZJVMpBpAhbIGO5tIyd +lttnRXHwz5DUwKiba4/bCEFe229BshQEql5qaqcbGbFfSly11WeqqnwR1N7c8Gpv +pD9sVrx22seN0rTUk87MY/S7mzCxHqAx35zm/LTW3pWcgCTMKFHy4Gt4mpTnXkNA +WkhP2OhW5RLiu6Whi0BEdb2TGG1+ctamgijKXb+gJeef5ehlHXG8eU862KF5UlEA +NeQKBm/PpQxOMe0NdpatjN8QRoczku0Itiodng+OZ1o+2iSNG988uFRb3CUSnjtE +R/HL6ULAFzo59EpIYxruU/w= +-----END CERTIFICATE----- \ No newline at end of file diff --git a/s2a/src/test/resources/client_key.pem b/s2a/src/test/resources/client_key.pem new file mode 100644 index 00000000000..38b93eb65c4 --- /dev/null +++ b/s2a/src/test/resources/client_key.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDGU2yV23vIBTgp +G5LYPb8WM0gTVRLmiUGbge15Ov77jVZoQtkmo0tE8+nTnaMRBTdopJg0Zm9RQ7pu +67t4WnPZIE1RUzK+bPOhyNUDLJ/g1C8NrNgdY5v79gJ5SmlNFauLPhXmyfjuolND +xx0IatMV4sfFSPHieH86KpVc/8qcXcSo8S6qmHSTLv4aKucmA9F74LkRG4aQnJ7x +2Y/JYYDaVJLox4N7sllMAyLTxsCfsG42l4GQ7PPWawYYj0+dRqdSzsY49ZMN1xAB +Vjwoy8cK4Zo0walvis5HaabwuVsBulM/Pw+TDhC0wIuSCyvPYLNelXCcuxTucNck +Ag4HXnmVAgMBAAECggEAKuW9jXaBgiS63o1jyFkmvWcPNntG0M2sfrXuRzQfFgse +vwOCk8xrSflWQNsOe+58ayp6746ekl3LdBWSIbiy6SqG/sm3pp/LXNmjVYHv/QH4 +QYV643R5t1ihdVnGiBFhXwdpVleme/tpdjYZzgnJKak5W69o/nrgzhSK5ShAy2xM +j0XXbgdqG+4JxPb5BZmjHHfXAXUfgSORMdfArkbgFBRc9wL/6JVTXjeAMy5WX9qe +5UQsSOYkwc9P2snifC/jdIhjHQOkkx59O0FgukJEFZPoagVG1duWQbnNDr7QVHCJ +jV6dg9tIT4SXD3uPSPbgNGlRUseIakCzrhHARJuA2wKBgQD/h8zoh0KaqKyViCYw +XKOFpm1pAFnp2GiDOblxNubNFAXEWnC+FlkvO/z1s0zVuYELUqfxcYMSXJFEVelK +rfjZtoC5oxqWGqLo9iCj7pa8t+ipulYcLt2SWc7eZPD4T4lzeEf1Qz77aKcz34sa +dv9lzQkDvhR/Mv1VeEGFHiq2VwKBgQDGsLcTGH5Yxs//LRSY8TigBkQEDrH5NvXu +2jtAzZhy1Yhsoa5eiZkhnnzM6+n05ovfZLcy6s7dnwP1Y+C79vs+DKMBsodtDG5z +YpsB0VrXYa6P6pCqkcz0Bz9xdo5sOhAK3AKnX6jd29XBDdeYsw/lxHLG24wProTD +cCYFqtaj8wKBgQCaqKT68DL9zK14a8lBaDCIyexaqx3AjXzkP+Hfhi03XrEG4P5v +7rLYBeTbCUSt7vMN2V9QoTWFvYUm6SCkVJvTmcRblz6WL1T+z0l+LwAJBP7LC77m +m+77j2PH8yxt/iXhP6G97o+GNxdMLDbTM8bs5KZaH4fkXQY73uc5HMMZTQKBgEZS +7blYhf+t/ph2wD+RwVUCYrh86wkmJs2veCFro3WhlnO8lhbn5Mc9bTaqmVgQ8ZjT +8POYoDdYvPHxs+1TcYF4v4kuQziZmc5FLE/sZZauADb38tQsXrpQhmgGakpsEpmF +XXsYJJDB6lo2KATn+8x7R5SSyHQUdPEnlI2U9ft5AoGBAJw0NJiM1EzRS8xq0DmO +AvQaPjo01o2hH6wghws8gDQwrj0eHraHgVi7zo0VkaHJbO7ahKPudset3N7owJhA +CUAPPRtv5wn0amAyNz77f1dz4Gys3AkcchflqhbEaQpzKYx4kX0adclur4WJ/DVm +P7DI977SHCVB4FVMbXMEkBjN +-----END PRIVATE KEY----- \ No newline at end of file diff --git a/s2a/src/test/resources/config.cnf b/s2a/src/test/resources/config.cnf new file mode 100644 index 00000000000..5f9a7710e92 --- /dev/null +++ b/s2a/src/test/resources/config.cnf @@ -0,0 +1,17 @@ +[req] +distinguished_name = req_distinguished_name +req_extensions = req_ext + +[req_distinguished_name] +countryName = Country Name (2 letter code) +stateOrProvinceName = State or Province Name (full name) +localityName = Locality Name (eg, city) +organizationalUnitName = Organizational Unit Name (eg, section) +commonName = Common Name (eg, your name or your server\'s hostname) +emailAddress = Email Address + +[req_ext] +subjectAltName = @alt_names + +[alt_names] +IP.1 = ::1 \ No newline at end of file diff --git a/s2a/src/test/resources/int_cert1_.cnf b/s2a/src/test/resources/int_cert1_.cnf new file mode 100644 index 00000000000..ba5a0f66a5e --- /dev/null +++ b/s2a/src/test/resources/int_cert1_.cnf @@ -0,0 +1,14 @@ +[req] +distinguished_name = req_distinguished_name +req_extensions = v3_req +prompt = no + +[req_distinguished_name] +O = int1O +OU = int1OU +CN = int1CN + +[v3_req] +keyUsage = critical, keyCertSign, cRLSign +extendedKeyUsage = critical, clientAuth, serverAuth +basicConstraints = critical, CA:true, pathlen: 1 \ No newline at end of file diff --git a/s2a/src/test/resources/int_cert1_ec.pem b/s2a/src/test/resources/int_cert1_ec.pem new file mode 100644 index 00000000000..de83c2aba79 --- /dev/null +++ b/s2a/src/test/resources/int_cert1_ec.pem @@ -0,0 +1,13 @@ +-----BEGIN CERTIFICATE----- +MIIB8TCCAZagAwIBAgIUEXwpznJIlU+ELO7Qgb4UUGpfbj8wCgYIKoZIzj0EAwIw +MjEOMAwGA1UECgwFaW50Mk8xDzANBgNVBAsMBmludDJPVTEPMA0GA1UEAwwGaW50 +MkNOMB4XDTI0MTAwMTIxNDIwMFoXDTQ0MTAwMTIxNDIwMFowMjEOMAwGA1UECgwF +aW50MU8xDzANBgNVBAsMBmludDFPVTEPMA0GA1UEAwwGaW50MUNOMFkwEwYHKoZI +zj0CAQYIKoZIzj0DAQcDQgAEoenicrtL6ezEW2yLSXADscDJQ/fdbr+vJEU/aieV +wA2EnPbrdpvQZaz+pXtuZzBLZY50XI9y33E+/PvBFtZob6OBiTCBhjAOBgNVHQ8B +Af8EBAMCAQYwIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwIGCCsGAQUFBwMBMBIGA1Ud +EwEB/wQIMAYBAf8CAQEwHQYDVR0OBBYEFBYs72Jv682/xzG3Tm8hItIFis//MB8G +A1UdIwQYMBaAFPhN6eGgVc36Kc50rREZhMdBIkgGMAoGCCqGSM49BAMCA0kAMEYC +IQDiPcbihg1iDi0m9CUn96IbWOTh1X75RfVJYcR3Q5T78AIhAK/fxZauDeWPzk2r +2/ohCQOZFHtAi9VRpr/TqNi3SaYt +-----END CERTIFICATE----- \ No newline at end of file diff --git a/s2a/src/test/resources/int_cert2.cnf b/s2a/src/test/resources/int_cert2.cnf new file mode 100644 index 00000000000..f48524effb2 --- /dev/null +++ b/s2a/src/test/resources/int_cert2.cnf @@ -0,0 +1,14 @@ +[req] +distinguished_name = req_distinguished_name +req_extensions = v3_req +prompt = no + +[req_distinguished_name] +O = int2O +OU = int2OU +CN = int2CN + +[v3_req] +keyUsage = critical, keyCertSign, cRLSign +extendedKeyUsage = critical, clientAuth, serverAuth +basicConstraints = critical, CA:true, pathlen: 2 \ No newline at end of file diff --git a/s2a/src/test/resources/int_cert2_ec.pem b/s2a/src/test/resources/int_cert2_ec.pem new file mode 100644 index 00000000000..4f502fda808 --- /dev/null +++ b/s2a/src/test/resources/int_cert2_ec.pem @@ -0,0 +1,13 @@ +-----BEGIN CERTIFICATE----- +MIIB8DCCAZagAwIBAgIUNOH4wQEoKHvaQ9Xgd36vh5TnhfUwCgYIKoZIzj0EAwIw +MjEOMAwGA1UECgwFcm9vdE8xDzANBgNVBAsMBnJvb3RPVTEPMA0GA1UEAwwGcm9v +dENOMB4XDTI0MTAwMTIxNDIwMFoXDTQ0MTAwMTIxNDIwMFowMjEOMAwGA1UECgwF +aW50Mk8xDzANBgNVBAsMBmludDJPVTEPMA0GA1UEAwwGaW50MkNOMFkwEwYHKoZI +zj0CAQYIKoZIzj0DAQcDQgAE44B/G4pzAvLpIUaPp8XNRtXuw8jeLgE40NjQMuqq +3jNs6ID/fv/jiRggLMXL3Tii1CisM4BRjg56/Owky1Fyv6OBiTCBhjAOBgNVHQ8B +Af8EBAMCAQYwIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwIGCCsGAQUFBwMBMBIGA1Ud +EwEB/wQIMAYBAf8CAQIwHQYDVR0OBBYEFPhN6eGgVc36Kc50rREZhMdBIkgGMB8G +A1UdIwQYMBaAFNHNBlllqi9koRtf7EBHjRMwVgWsMAoGCCqGSM49BAMCA0gAMEUC +IBd4bvqVeYSSUEGF1wB0KlYxn1L0Ub/LjgIUUQFAEwahAiEAgeArX63bnlI7u3dq +v/FGilvcLP3P3AvRozpHJiIZ860= +-----END CERTIFICATE----- \ No newline at end of file diff --git a/s2a/src/test/resources/int_key1_ec.pem b/s2a/src/test/resources/int_key1_ec.pem new file mode 100644 index 00000000000..909c119b60c --- /dev/null +++ b/s2a/src/test/resources/int_key1_ec.pem @@ -0,0 +1,5 @@ +-----BEGIN PRIVATE KEY----- +MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgnYGMzs4siZ7Fy3mI +rmsqBdP6We4Zt+ndtOYEGaZDj06hRANCAASh6eJyu0vp7MRbbItJcAOxwMlD991u +v68kRT9qJ5XADYSc9ut2m9BlrP6le25nMEtljnRcj3LfcT78+8EW1mhv +-----END PRIVATE KEY----- \ No newline at end of file diff --git a/s2a/src/test/resources/int_key2_ec.pem b/s2a/src/test/resources/int_key2_ec.pem new file mode 100644 index 00000000000..520300d2560 --- /dev/null +++ b/s2a/src/test/resources/int_key2_ec.pem @@ -0,0 +1,5 @@ +-----BEGIN PRIVATE KEY----- +MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgzLSoAcENXIiQfBS7 +meBDCohT1rofhWSfD0m55qi8V3WhRANCAATjgH8binMC8ukhRo+nxc1G1e7DyN4u +ATjQ2NAy6qreM2zogP9+/+OJGCAsxcvdOKLUKKwzgFGODnr87CTLUXK/ +-----END PRIVATE KEY----- \ No newline at end of file diff --git a/s2a/src/test/resources/leaf.cnf b/s2a/src/test/resources/leaf.cnf new file mode 100644 index 00000000000..c21cee5568f --- /dev/null +++ b/s2a/src/test/resources/leaf.cnf @@ -0,0 +1,14 @@ +[req] +distinguished_name = req_distinguished_name +req_extensions = v3_req +prompt = no + +[req_distinguished_name] +O = leafO +OU = leafOU +CN = leafCN + +[v3_req] +keyUsage = critical, digitalSignature +extendedKeyUsage = critical, clientAuth, serverAuth +basicConstraints = critical, CA:false \ No newline at end of file diff --git a/s2a/src/test/resources/leaf_cert_ec.pem b/s2a/src/test/resources/leaf_cert_ec.pem new file mode 100644 index 00000000000..ca48b821f60 --- /dev/null +++ b/s2a/src/test/resources/leaf_cert_ec.pem @@ -0,0 +1,13 @@ +-----BEGIN CERTIFICATE----- +MIIB6jCCAZCgAwIBAgIUA98F2JkYZAyz9BdIkBK3P8Df7OUwCgYIKoZIzj0EAwIw +MjEOMAwGA1UECgwFaW50MU8xDzANBgNVBAsMBmludDFPVTEPMA0GA1UEAwwGaW50 +MUNOMB4XDTI0MTAwMTIxNDIwMFoXDTQ0MTAwMTIxNDIwMFowMjEOMAwGA1UECgwF +bGVhZk8xDzANBgNVBAsMBmxlYWZPVTEPMA0GA1UEAwwGbGVhZkNOMFkwEwYHKoZI +zj0CAQYIKoZIzj0DAQcDQgAEtpTTzt2VDTP6gO4uUIpg8sB63Ff4T4YPMoIGrrn3 +tU3f9j0Ysa5/xblM0LkwRImcrKKchYDiNm1wHkWo+qDImaOBgzCBgDAOBgNVHQ8B +Af8EBAMCB4AwIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwIGCCsGAQUFBwMBMAwGA1Ud +EwEB/wQCMAAwHQYDVR0OBBYEFGzFBt/E6vDJRcH+Izy4MQ9AHycqMB8GA1UdIwQY +MBaAFBYs72Jv682/xzG3Tm8hItIFis//MAoGCCqGSM49BAMCA0gAMEUCIHUcqPTB +mQ4kXE0WoOUC8ZmzvthvfKjCNe0YogcjZgwWAiEAvapmWoQIO4qie25Ae9sYRCPq +5xAHztAquk5HLfwabow= +-----END CERTIFICATE----- \ No newline at end of file diff --git a/s2a/src/test/resources/leaf_key_ec.pem b/s2a/src/test/resources/leaf_key_ec.pem new file mode 100644 index 00000000000..b92b90ba1da --- /dev/null +++ b/s2a/src/test/resources/leaf_key_ec.pem @@ -0,0 +1,5 @@ +-----BEGIN PRIVATE KEY----- +MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgkvnGZBh3uIYfZiau +/0qN0YcQXlwwVVUh8EybjvKUlX2hRANCAAS2lNPO3ZUNM/qA7i5QimDywHrcV/hP +hg8yggauufe1Td/2PRixrn/FuUzQuTBEiZysopyFgOI2bXAeRaj6oMiZ +-----END PRIVATE KEY----- \ No newline at end of file diff --git a/s2a/src/test/resources/root_cert.pem b/s2a/src/test/resources/root_cert.pem new file mode 100644 index 00000000000..ccd0a46bc23 --- /dev/null +++ b/s2a/src/test/resources/root_cert.pem @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDkzCCAnugAwIBAgIUWemeXZdfqcqkP8/Eyj74oTJtoNQwDQYJKoZIhvcNAQEL +BQAwWTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MB4X +DTI0MTAwMTIxNTkxMVoXDTQ0MTAwMTIxNTkxMVowWTELMAkGA1UEBhMCQVUxEzAR +BgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5 +IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEAt3A04hy5lljv86Nu0LLQZ2hA+fcImHjt1p1Mxgcta/5oxfVLcerE +ZH+DAQLDtWzp9Up/vI57MM419GIL8Iszk7hnZRS/HWJ+2jewZJtz4i/g15dLr6+1 +uabMdPOWos60BwcLMxKEe6lJO1mV4z9d4NH4mAuMIHyM+ty0Klp9MfeDJtYEh0+z +AxJUHCixDTsnKJro7My7A3ZT7bvaMfXxS7XN6qlRgBfiCmXo/GKTFfmfBW/EZGkG +XOCxE2D79wYNhC41Q/ix0kwjEeOj2vgGFoiyblSdHdzvRXzsoQTEiZSM8lJDR2IT +ZbpgbBlknMU6efNWlS8P5damB9ZWXg3x4wIDAQABo1MwUTAdBgNVHQ4EFgQUcq3d +txAVA410YWyM0B4e+4umbiwwHwYDVR0jBBgwFoAUcq3dtxAVA410YWyM0B4e+4um +biwwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEApZvaI9y7vjX/ +RRdvwf2Db9KlTE9nuVQ3AsrmG9Ml0p2X6U5aTetxdYBo2PuaaYHheF03JOH8zjpL +UfFzvbi52DPbfFAaDw/6NIAenXlg492leNvUFNjGGRyJO9R5/aDfv40/fT3Em5G5 +DnR8SeGQ9tI1t6xBBT+d+/MilSiEKVu8IIF/p0SwvEyR4pKo6wFVZR0ZiIj2v/FZ +P5Qk0Xhb+slpmaR3Wtx/mPl9Wb3kpPD4CAwhWDqFkKJql9/n9FvMjdwlCQKQGB26 +ZDXY3C0UTdktK5biNWRgAUVJEWBX6Q2amrxQHIn2d9RJ8uxCME/KBAntK+VxZE78 +w0JOvQ4Dpw== +-----END CERTIFICATE----- \ No newline at end of file diff --git a/s2a/src/test/resources/root_cert_ec.pem b/s2a/src/test/resources/root_cert_ec.pem new file mode 100644 index 00000000000..3d20dcfe83c --- /dev/null +++ b/s2a/src/test/resources/root_cert_ec.pem @@ -0,0 +1,12 @@ +-----BEGIN CERTIFICATE----- +MIIBxzCCAW2gAwIBAgIUN+H7Td9dhyvMrrzZhanevAfCN34wCgYIKoZIzj0EAwIw +MjEOMAwGA1UECgwFcm9vdE8xDzANBgNVBAsMBnJvb3RPVTEPMA0GA1UEAwwGcm9v +dENOMB4XDTI0MTAwMTIxNDIwMFoXDTQ0MTAwMTIxNDIwMFowMjEOMAwGA1UECgwF +cm9vdE8xDzANBgNVBAsMBnJvb3RPVTEPMA0GA1UEAwwGcm9vdENOMFkwEwYHKoZI +zj0CAQYIKoZIzj0DAQcDQgAEGnS2gVv6Bs0GtuUAOebR9E0fqaj3zi9mD97B/dgi +MLENhtVPJQzeePv6Ccap+73O0BINRNOl8tlHX0YaXDeEHKNhMF8wDgYDVR0PAQH/ +BAQDAgEGMB0GA1UdJQQWMBQGCCsGAQUFBwMBBggrBgEFBQcDAjAPBgNVHRMBAf8E +BTADAQH/MB0GA1UdDgQWBBTRzQZZZaovZKEbX+xAR40TMFYFrDAKBggqhkjOPQQD +AgNIADBFAiEAgnIyLs7FsZNsJjFgYzlaut4h23RxrpUYVCVZt/+x1Q0CIG3U6WGz +YaEyKoCtBHH9cAy76+pP/NU2f7/QuHU9Vymd +-----END CERTIFICATE----- \ No newline at end of file diff --git a/s2a/src/test/resources/root_ec.cnf b/s2a/src/test/resources/root_ec.cnf new file mode 100644 index 00000000000..d736865c831 --- /dev/null +++ b/s2a/src/test/resources/root_ec.cnf @@ -0,0 +1,14 @@ +[req] +distinguished_name = req_distinguished_name +req_extensions = v3_req +prompt = no + +[req_distinguished_name] +O = rootO +OU = rootOU +CN = rootCN + +[v3_req] +keyUsage = critical, keyCertSign, cRLSign +extendedKeyUsage = serverAuth, clientAuth +basicConstraints = critical, CA:true \ No newline at end of file diff --git a/s2a/src/test/resources/root_key.pem b/s2a/src/test/resources/root_key.pem new file mode 100644 index 00000000000..34d0ffa61eb --- /dev/null +++ b/s2a/src/test/resources/root_key.pem @@ -0,0 +1,30 @@ +-----BEGIN ENCRYPTED PRIVATE KEY----- +MIIFJDBWBgkqhkiG9w0BBQ0wSTAxBgkqhkiG9w0BBQwwJAQQJXNe391O3gaNbKLw +o60XrQICCAAwDAYIKoZIhvcNAgkFADAUBggqhkiG9w0DBwQI4pf69+BBF8IEggTI +JuQ3p67U9k/NWMuYXaR9a6lv24YZ1qR6ieL5B6keCaCDVoQMb5V22O0vBqCVePgr +EG0yWIeeAsARMzAxE7Lnil6abSe7tij+LjEI9F7mV/1QSFt03PLVI+e7OcKNI+Nr +6vISEi8CaddekP8JDRhPMpgdWderZvogo3REpJ8GNIUddQzu1e3ZgDtOPquqcgqb +MH/HuPE3vjj4/l6ZpX+6DZKIvzjwtBQ4PMzSWLumzmYLItd3kz7UryN+9hKluSZp +D2KB24aUIQFbDxe2DMTi5c0QIiyzjwkv081ecNJOy2gYX3uiucr8/Ax3o21RNZtI +oKCmSPVEfYdrkdfkwuSOioVTbWBZBcSZo3L2bmCkSXTuheGurEw/TtQWXBgew0Bn +UQjEJgZy96PVsQeu3t+NRCacARQi4vfv7PVHlQW8fcfcC6CeNw7VIZ8aS7supqym +RJxzMY9ZnLwO9cgybXLYgosVZnvI7nOokJPfO1+KqBK01C1Sgc3tg8czKhRuztHu +qDO0GCZ7l+9/ku/WIy/5NiatNvRo5dMAOGxsSrjI9a7+EmenoIfd8/KREVX19D+R +gZRALVATHq83rF6BdsyTwya1QUr/J24EIlkOc4HbCBm5WxA2ZjNdDBZ+KhivYaS7 +l1qrbkFOhmBD9kYRbseBrxlzKUWJMGhOpw3xebut3HngLqyezLcjsXQuF3Iau5Hl +9QFcmSdLj2ZlNlQvmfNJX/r6a/K2LigruXCbvHWMqVsHd7XZdWJ/8wjm2AL97iON +mYFLP+ScfYom9qrF41jNkUKZiLk/ppvSHyWBAqbze+R9Zfpcf8ArCwuAL/JlEMzv +YkBv1DWKfzJpZHYX695MxrpS3C8m0IyXNxktBL3KTVvwZaIhSNBlNS3fdb9m8toR +Tz/LS8jseWpZ5D552/+KAa0Skhav3ZFpxmAS8BEyE/nI9Dwg9niYcZLWORWHAQPp +jraG0BkE7bn5No/k7E4rjFb+2N+36QxVacJI3neC8bQXVHP0BVUvrabOWFPnGivl +Ok91Eo8q5PUAsd15ZnKjTHzlD7zv7fF6ncBgj3P4L2Xrs6P34JOZEd4wixEUZYeC +Xe+SZrFyUr6CcNC45C6R3hDYqmrz0GK1ikkis3XcKT+C5flBYb9NRx8G9wyCuS6H +oHl0Rfbpc47wQTuajicMVO2El7syMPUAxjo3EfMzvjm7uCXLTHnXRnRt3Y5AkPGa +0kFE9Vm00PReRfQ7qbSUiOOHYa9NIsw1l2ZI+knP9XbY2HikELOpjgucrMxZF+ms +zit5YGD3NGZi5xcHZFZTs9L8kaJccXn5DtjA30eEiFzKqMtMKnwlrbSL55I1JXim +co1RLpRK2KQmtJHo1br3RH6jP7fePYzgDceDds5HKWz22pYFcVtlx4DeYH5vjdEp +i3yNQZ32jD2HYhgCK325QLP5S2UYmUOPWd4sEiwZMBPpPOlt0TqCdFKYgS2GHlSN +IYVBYelPUYsz9Kg0TFtLMZLNUmwsXJ+jqnLVtmFyoV6IIvbSCqQ9jxTbZQKxThK8 +A1G+nXBO41ZW8eQZUGx8CzbCj2JvtVThgErSRqAuYbvlUt7EI4Ac8veZC8rJIG0Q +ADkueb978o4OI6vpOdTYCmdTIoHWlpup +-----END ENCRYPTED PRIVATE KEY----- \ No newline at end of file diff --git a/s2a/src/test/resources/root_key_ec.pem b/s2a/src/test/resources/root_key_ec.pem new file mode 100644 index 00000000000..5560a66d414 --- /dev/null +++ b/s2a/src/test/resources/root_key_ec.pem @@ -0,0 +1,5 @@ +-----BEGIN PRIVATE KEY----- +MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgjfTyzPIlKV0zANQP +2s1C2FhbenE34QEsf83wjpuQrZWhRANCAAQadLaBW/oGzQa25QA55tH0TR+pqPfO +L2YP3sH92CIwsQ2G1U8lDN54+/oJxqn7vc7QEg1E06Xy2UdfRhpcN4Qc +-----END PRIVATE KEY----- \ No newline at end of file diff --git a/s2a/src/test/resources/server_cert.pem b/s2a/src/test/resources/server_cert.pem new file mode 100644 index 00000000000..909b83aa903 --- /dev/null +++ b/s2a/src/test/resources/server_cert.pem @@ -0,0 +1,20 @@ +-----BEGIN CERTIFICATE----- +MIIDWjCCAkKgAwIBAgIUAeWzyzIEetYf+ZWHj9NzH1JkLYkwDQYJKoZIhvcNAQEL +BQAwWTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MB4X +DTI0MTAwMTIxNTk0NloXDTQ0MTAwMTIxNTk0NlowFDESMBAGA1UEAwwJbG9jYWxo +b3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA1qnW7Pb06MgRNLzt +icv/ydl8W/lpPRjrJJb04/TtXbJ1hjnp7i796TfNGrJgHqEZnaR8q83lO0L38B2X +sJ04b3R+y+6HhH8+MbHejM7ybrTZRNQXip/Kxu4QLHBTQEsplycWLf42/R3cIk/X +vgxq5NsCsbk4xI4xwlcqC8FM1AHU0VrKxzHWVhZEM+/KovBAr/hRYln9CukeKjOf +UiVq58uuDAlJRC3yH2Rd/sqCDELvqRv17J6eYx2nJ3mSN5aBa0FwVjg6vr5Obddj +AWWIkgrlAr+a+OraxOrWElFfChBSvr/qHdJFWHeCdq/SAhow5uRhC69ScJf+7lrX +hsj1sQIDAQABo18wXTAbBgNVHREEFDAShxAAAAAAAAAAAAAAAAAAAAABMB0GA1Ud +DgQWBBRdDRg6GuDj8Sujmz4/rqfP0jZHbTAfBgNVHSMEGDAWgBRyrd23EBUDjXRh +bIzQHh77i6ZuLDANBgkqhkiG9w0BAQsFAAOCAQEAAEUS27+6p88CWYemMOY0iu0e +mp4YqG0XQSilbSnxrqnJb3N8pR3Yh6JJKnblQ6xdexfzrXlBA/v7nx+f8e9HS2QZ +KLtEIaEvNKL51JdOS6ebEzLVvhk98r2kpKM3wpT++/18HPlPK5W3rMQNsLOyAdvP +UX6TakhIfflRjz1DYXQ1ERvJOFw2HEmw6K6r2VwBhZKfwwzxmAHpVwniWXGbgyRF +79hG6rO1tv1K5LHAPIRs0h2Lh/VPxm2XiaNkdGyarUy5/NM+GoHErgxOBmYltn5Q +vAlZrgF2/mSXcUb7EHoXvoC9L4M7U/dRQD4Q1fQRJ/KjrhbDAC3gfZ4zorKoaQ== +-----END CERTIFICATE----- \ No newline at end of file diff --git a/s2a/src/test/resources/server_key.pem b/s2a/src/test/resources/server_key.pem new file mode 100644 index 00000000000..edc37cb3855 --- /dev/null +++ b/s2a/src/test/resources/server_key.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDWqdbs9vToyBE0 +vO2Jy//J2Xxb+Wk9GOsklvTj9O1dsnWGOenuLv3pN80asmAeoRmdpHyrzeU7Qvfw +HZewnThvdH7L7oeEfz4xsd6MzvJutNlE1BeKn8rG7hAscFNASymXJxYt/jb9Hdwi +T9e+DGrk2wKxuTjEjjHCVyoLwUzUAdTRWsrHMdZWFkQz78qi8ECv+FFiWf0K6R4q +M59SJWrny64MCUlELfIfZF3+yoIMQu+pG/Xsnp5jHacneZI3loFrQXBWODq+vk5t +12MBZYiSCuUCv5r46trE6tYSUV8KEFK+v+od0kVYd4J2r9ICGjDm5GELr1Jwl/7u +WteGyPWxAgMBAAECggEAFEAgcOlZME6TZPS/ueSfRET6mNieB2/+2sxM3OZhsBmi +QZ/cBCa1uFcVx8N1Et6iwn7ebfy199G4/xNjmHs0dDs6rPVbHnI8hUag1oq9TxlL +d9VERUUOxZZ2uyJ7kBCnI0XCL2OQf29eMXRzx093lBBfIDH3e39ojUtYwZQiMcuw +EPry0k4fVhymhKg9Wnmt5lMg4Mdc1TpPfmNFuTR0PZ1nAaVQglvH66qNKGVoWEhZ +paNLaKC4H2Jfa1AfAWl6Efy5JDMOfHF0ww0cDUrTzAeQ7jEh0UGyL1lX8W6kKRDa +0quUqxOJz9aQ8cyd27s2OQMlRtbXi/jhhVp7WLIrWQKBgQD9gKG5CgBO/L8nIj5o +EhHFhtfjEhdeXTAlenmxoBxUN7Pwkc2OvhNef7+T0+euwl50ieopWLoRxLZ2yY8l +E2b2+7EM6/8/wgt1bCVh5NCWrE63tLCx+wdht1oqciDXvuv5bJTf73sipgDTYYSV +gE+DHXq96mxVJXo1TLtQQpXMVQKBgQDYx0AbO0KP2TTNY5ChqVwthaETHjWs6z9p +U5WRgNYeXbUKg3l7JJk6zq72ZIBeqEr3d9mJqrk6HFKTh4c+LyjKyLjmY5wkmfHh +s6s1lCEgEoXKT3Fa+DxlsXltyxrJLzuf1h276jeL5bB6BmJNKLODcEoCx/ubrwOj +prdUSWqf7QKBgQCO/sg7AJE7/QY2pPJe8hJkQbP1unbEG/zUp0mOEKrqNqGhyh0R +r9ZtL9J5KMc/pRRy2Hjl6c7LxxLF3tyIJXGnUEKG73iEFokwK1jK569hzsB4j8w8 +GUYIsMyDtO0hxeiGQeGYkBX9bXZ5xkBrtH0lkLNz/ZAuV32gIzBmDalCIQKBgDGT +f+m6Z8KWHilKt+0A2n/eq7O/mO7u7hWcc/xOxqkzLRA2eTXcbN6yHfljiqgbPOnT +kwCU9r9/crMir59dEasutHqcFT2Zp2PCv0kFk33OPqLCAF6ZntZy/B5L8NhJ4Qzw +3uP28LUh1nZRt3GF+Wf56jMwoS49nEt0+UBhee0RAoGAS9YsJkbjBg2p3Gxvo5c0 +IjfZdcyS2ndTjXv+hFvkjMw0ULFT3dqpk+0asaCh5nrDUbVQyan+D8LgwSwNZy89 +e99bl//oliv/Om7lVFCKtBOhe+fIWHlrR0e2bemsQi/pgTURjYFuvjhR50dcKx96 +jLHvG4mTfStHaJ1gKGWvgWA= +-----END PRIVATE KEY----- \ No newline at end of file diff --git a/services/BUILD.bazel b/services/BUILD.bazel index fa708dd04b2..ba9d334a5c9 100644 --- a/services/BUILD.bazel +++ b/services/BUILD.bazel @@ -1,3 +1,5 @@ +load("@rules_java//java:defs.bzl", "java_library") +load("@rules_jvm_external//:defs.bzl", "artifact") load("//:java_grpc_library.bzl", "java_grpc_library") package(default_visibility = ["//visibility:public"]) @@ -9,6 +11,7 @@ java_library( name = "services_maven", exports = [ ":admin", + ":binarylog", ":channelz", ":health", ":healthlb", @@ -26,7 +29,7 @@ java_library( deps = [ ":channelz", "//api", - "@com_google_code_findbugs_jsr305//jar", + artifact("com.google.code.findbugs:jsr305"), ], ) @@ -35,15 +38,15 @@ java_library( srcs = [ "src/main/java/io/grpc/services/CallMetricRecorder.java", "src/main/java/io/grpc/services/MetricRecorder.java", - "src/main/java/io/grpc/services/MetricReport.java", "src/main/java/io/grpc/services/MetricRecorderHelper.java", + "src/main/java/io/grpc/services/MetricReport.java", ], deps = [ "//api", "//context", - "@com_google_code_findbugs_jsr305//jar", - "@com_google_errorprone_error_prone_annotations//jar", - "@com_google_guava_guava//jar", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:guava"), ], ) @@ -61,6 +64,28 @@ java_library( ], ) +java_library( + name = "binarylog", + srcs = [ + "src/main/java/io/grpc/protobuf/services/BinaryLogProvider.java", + "src/main/java/io/grpc/protobuf/services/BinaryLogProviderImpl.java", + "src/main/java/io/grpc/protobuf/services/BinaryLogSink.java", + "src/main/java/io/grpc/protobuf/services/BinaryLogs.java", + "src/main/java/io/grpc/protobuf/services/BinlogHelper.java", + "src/main/java/io/grpc/protobuf/services/InetAddressUtil.java", + "src/main/java/io/grpc/protobuf/services/TempFileSink.java", + ], + deps = [ + "//api", + "//context", + "@com_google_protobuf//:protobuf_java", + "@com_google_protobuf//:protobuf_java_util", + "@io_grpc_grpc_proto//:binarylog_java_proto", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.guava:guava"), + ], +) + java_library( name = "channelz", srcs = [ @@ -72,11 +97,11 @@ java_library( ":_channelz_java_grpc", "//api", "//stub", - "@com_google_code_findbugs_jsr305//jar", - "@com_google_guava_guava//jar", "@com_google_protobuf//:protobuf_java", "@com_google_protobuf//:protobuf_java_util", "@io_grpc_grpc_proto//:channelz_java_proto", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.guava:guava"), ], ) @@ -84,17 +109,21 @@ java_library( name = "reflection", srcs = [ "src/main/java/io/grpc/protobuf/services/ProtoReflectionService.java", + "src/main/java/io/grpc/protobuf/services/ProtoReflectionServiceV1.java", ], deps = [ ":_reflection_java_grpc", + ":_reflection_v1_java_grpc", "//api", "//protobuf", "//stub", - "@com_google_code_findbugs_jsr305//jar", - "@com_google_guava_guava//jar", "@com_google_protobuf//:protobuf_java", "@com_google_protobuf//:protobuf_java_util", + "@io_grpc_grpc_proto//:reflection_java_proto", "@io_grpc_grpc_proto//:reflection_java_proto_deprecated", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:guava"), ], ) @@ -110,9 +139,10 @@ java_library( "//api", "//context", "//stub", - "@com_google_code_findbugs_jsr305//jar", - "@com_google_guava_guava//jar", "@io_grpc_grpc_proto//:health_java_proto", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:guava"), ], ) @@ -131,9 +161,9 @@ java_library( "//api", "//core:internal", "//util", - "@com_google_code_findbugs_jsr305//jar", - "@com_google_guava_guava//jar", "@io_grpc_grpc_proto//:health_java_proto", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.guava:guava"), ], ) @@ -147,6 +177,13 @@ java_grpc_library( deps = ["@io_grpc_grpc_proto//:reflection_java_proto_deprecated"], ) +java_grpc_library( + name = "_reflection_v1_java_grpc", + srcs = ["@io_grpc_grpc_proto//:reflection_proto"], + visibility = ["//visibility:private"], + deps = ["@io_grpc_grpc_proto//:reflection_java_proto"], +) + java_grpc_library( name = "_channelz_java_grpc", srcs = ["@io_grpc_grpc_proto//:channelz_proto"], diff --git a/services/build.gradle b/services/build.gradle index de716c9fa1d..c30e1ba53bd 100644 --- a/services/build.gradle +++ b/services/build.gradle @@ -27,20 +27,21 @@ dependencies { implementation project(':grpc-core'), project(':grpc-protobuf'), project(':grpc-util'), - libraries.protobuf.java.util, - libraries.guava.jre // JRE required by protobuf-java-util + libraries.guava.jre, // JRE required by protobuf-java-util + libraries.protobuf.java.util runtimeOnly libraries.errorprone.annotations, - libraries.j2objc.annotations, // Explicit dependency to keep in step with version used by guava libraries.gson // to fix checkUpperBoundDeps error here - compileOnly libraries.javax.annotation testImplementation project(':grpc-testing'), project(':grpc-inprocess'), libraries.netty.transport.epoll, // for DomainSocketAddress testFixtures(project(':grpc-core')), testFixtures(project(':grpc-api')) - testCompileOnly libraries.javax.annotation - signature libraries.signature.java + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } } configureProtoCompilation() @@ -59,6 +60,7 @@ tasks.named("jacocoTestReport").configure { '**/io/grpc/binarylog/v1/**', '**/io/grpc/channelz/v1/**', '**/io/grpc/health/v1/**', + '**/io/grpc/reflection/v1/**', '**/io/grpc/reflection/v1alpha/**', ]) } diff --git a/services/src/generated/main/grpc/io/grpc/channelz/v1/ChannelzGrpc.java b/services/src/generated/main/grpc/io/grpc/channelz/v1/ChannelzGrpc.java index b3c1c285c8f..c4ac4076d22 100644 --- a/services/src/generated/main/grpc/io/grpc/channelz/v1/ChannelzGrpc.java +++ b/services/src/generated/main/grpc/io/grpc/channelz/v1/ChannelzGrpc.java @@ -8,9 +8,6 @@ * information. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/channelz/v1/channelz.proto") @io.grpc.stub.annotations.GrpcGenerated public final class ChannelzGrpc { @@ -250,6 +247,21 @@ public ChannelzStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOpt return ChannelzStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static ChannelzBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public ChannelzBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ChannelzBlockingV2Stub(channel, callOptions); + } + }; + return ChannelzBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -481,6 +493,98 @@ public void getSocket(io.grpc.channelz.v1.GetSocketRequest request, * information. * */ + public static final class ChannelzBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private ChannelzBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected ChannelzBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ChannelzBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * Gets all root channels (i.e. channels the application has directly
+     * created). This does not include subchannels nor non-top level channels.
+     * 
+ */ + public io.grpc.channelz.v1.GetTopChannelsResponse getTopChannels(io.grpc.channelz.v1.GetTopChannelsRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getGetTopChannelsMethod(), getCallOptions(), request); + } + + /** + *
+     * Gets all servers that exist in the process.
+     * 
+ */ + public io.grpc.channelz.v1.GetServersResponse getServers(io.grpc.channelz.v1.GetServersRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getGetServersMethod(), getCallOptions(), request); + } + + /** + *
+     * Returns a single Server, or else a NOT_FOUND code.
+     * 
+ */ + public io.grpc.channelz.v1.GetServerResponse getServer(io.grpc.channelz.v1.GetServerRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getGetServerMethod(), getCallOptions(), request); + } + + /** + *
+     * Gets all server sockets that exist in the process.
+     * 
+ */ + public io.grpc.channelz.v1.GetServerSocketsResponse getServerSockets(io.grpc.channelz.v1.GetServerSocketsRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getGetServerSocketsMethod(), getCallOptions(), request); + } + + /** + *
+     * Returns a single Channel, or else a NOT_FOUND code.
+     * 
+ */ + public io.grpc.channelz.v1.GetChannelResponse getChannel(io.grpc.channelz.v1.GetChannelRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getGetChannelMethod(), getCallOptions(), request); + } + + /** + *
+     * Returns a single Subchannel, or else a NOT_FOUND code.
+     * 
+ */ + public io.grpc.channelz.v1.GetSubchannelResponse getSubchannel(io.grpc.channelz.v1.GetSubchannelRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getGetSubchannelMethod(), getCallOptions(), request); + } + + /** + *
+     * Returns a single Socket or else a NOT_FOUND code.
+     * 
+ */ + public io.grpc.channelz.v1.GetSocketResponse getSocket(io.grpc.channelz.v1.GetSocketRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getGetSocketMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service Channelz. + *
+   * Channelz is a service exposed by gRPC servers that provides detailed debug
+   * information.
+   * 
+ */ public static final class ChannelzBlockingStub extends io.grpc.stub.AbstractBlockingStub { private ChannelzBlockingStub( diff --git a/services/src/generated/main/grpc/io/grpc/health/v1/HealthGrpc.java b/services/src/generated/main/grpc/io/grpc/health/v1/HealthGrpc.java index 73ddd4e0d23..b8e94ef7d20 100644 --- a/services/src/generated/main/grpc/io/grpc/health/v1/HealthGrpc.java +++ b/services/src/generated/main/grpc/io/grpc/health/v1/HealthGrpc.java @@ -4,9 +4,6 @@ /** */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/health/v1/health.proto") @io.grpc.stub.annotations.GrpcGenerated public final class HealthGrpc { @@ -91,6 +88,21 @@ public HealthStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptio return HealthStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static HealthBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public HealthBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new HealthBlockingV2Stub(channel, callOptions); + } + }; + return HealthBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -225,6 +237,58 @@ public void watch(io.grpc.health.v1.HealthCheckRequest request, /** * A stub to allow clients to do synchronous rpc calls to service Health. */ + public static final class HealthBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private HealthBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected HealthBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new HealthBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * If the requested service is unknown, the call will fail with status
+     * NOT_FOUND.
+     * 
+ */ + public io.grpc.health.v1.HealthCheckResponse check(io.grpc.health.v1.HealthCheckRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getCheckMethod(), getCallOptions(), request); + } + + /** + *
+     * Performs a watch for the serving status of the requested service.
+     * The server will immediately send back a message indicating the current
+     * serving status.  It will then subsequently send a new message whenever
+     * the service's serving status changes.
+     * If the requested service is unknown when the call is received, the
+     * server will send a message setting the serving status to
+     * SERVICE_UNKNOWN but will *not* terminate the call.  If at some
+     * future point, the serving status of the service becomes known, the
+     * server will send a new message with the service's serving status.
+     * If the call terminates with status UNIMPLEMENTED, then clients
+     * should assume this method is not supported and should not retry the
+     * call.  If the call terminates with any other status (including OK),
+     * clients should retry the call with appropriate exponential backoff.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + watch(io.grpc.health.v1.HealthCheckRequest request) { + return io.grpc.stub.ClientCalls.blockingV2ServerStreamingCall( + getChannel(), getWatchMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service Health. + */ public static final class HealthBlockingStub extends io.grpc.stub.AbstractBlockingStub { private HealthBlockingStub( diff --git a/services/src/generated/main/grpc/io/grpc/reflection/v1/ServerReflectionGrpc.java b/services/src/generated/main/grpc/io/grpc/reflection/v1/ServerReflectionGrpc.java new file mode 100644 index 00000000000..04f8dea3ace --- /dev/null +++ b/services/src/generated/main/grpc/io/grpc/reflection/v1/ServerReflectionGrpc.java @@ -0,0 +1,327 @@ +package io.grpc.reflection.v1; + +import static io.grpc.MethodDescriptor.generateFullMethodName; + +/** + */ +@io.grpc.stub.annotations.GrpcGenerated +public final class ServerReflectionGrpc { + + private ServerReflectionGrpc() {} + + public static final java.lang.String SERVICE_NAME = "grpc.reflection.v1.ServerReflection"; + + // Static method descriptors that strictly reflect the proto. + private static volatile io.grpc.MethodDescriptor getServerReflectionInfoMethod; + + @io.grpc.stub.annotations.RpcMethod( + fullMethodName = SERVICE_NAME + '/' + "ServerReflectionInfo", + requestType = io.grpc.reflection.v1.ServerReflectionRequest.class, + responseType = io.grpc.reflection.v1.ServerReflectionResponse.class, + methodType = io.grpc.MethodDescriptor.MethodType.BIDI_STREAMING) + public static io.grpc.MethodDescriptor getServerReflectionInfoMethod() { + io.grpc.MethodDescriptor getServerReflectionInfoMethod; + if ((getServerReflectionInfoMethod = ServerReflectionGrpc.getServerReflectionInfoMethod) == null) { + synchronized (ServerReflectionGrpc.class) { + if ((getServerReflectionInfoMethod = ServerReflectionGrpc.getServerReflectionInfoMethod) == null) { + ServerReflectionGrpc.getServerReflectionInfoMethod = getServerReflectionInfoMethod = + io.grpc.MethodDescriptor.newBuilder() + .setType(io.grpc.MethodDescriptor.MethodType.BIDI_STREAMING) + .setFullMethodName(generateFullMethodName(SERVICE_NAME, "ServerReflectionInfo")) + .setSampledToLocalTracing(true) + .setRequestMarshaller(io.grpc.protobuf.ProtoUtils.marshaller( + io.grpc.reflection.v1.ServerReflectionRequest.getDefaultInstance())) + .setResponseMarshaller(io.grpc.protobuf.ProtoUtils.marshaller( + io.grpc.reflection.v1.ServerReflectionResponse.getDefaultInstance())) + .setSchemaDescriptor(new ServerReflectionMethodDescriptorSupplier("ServerReflectionInfo")) + .build(); + } + } + } + return getServerReflectionInfoMethod; + } + + /** + * Creates a new async stub that supports all call types for the service + */ + public static ServerReflectionStub newStub(io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public ServerReflectionStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ServerReflectionStub(channel, callOptions); + } + }; + return ServerReflectionStub.newStub(factory, channel); + } + + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static ServerReflectionBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public ServerReflectionBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ServerReflectionBlockingV2Stub(channel, callOptions); + } + }; + return ServerReflectionBlockingV2Stub.newStub(factory, channel); + } + + /** + * Creates a new blocking-style stub that supports unary and streaming output calls on the service + */ + public static ServerReflectionBlockingStub newBlockingStub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public ServerReflectionBlockingStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ServerReflectionBlockingStub(channel, callOptions); + } + }; + return ServerReflectionBlockingStub.newStub(factory, channel); + } + + /** + * Creates a new ListenableFuture-style stub that supports unary calls on the service + */ + public static ServerReflectionFutureStub newFutureStub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public ServerReflectionFutureStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ServerReflectionFutureStub(channel, callOptions); + } + }; + return ServerReflectionFutureStub.newStub(factory, channel); + } + + /** + */ + public interface AsyncService { + + /** + *
+     * The reflection service is structured as a bidirectional stream, ensuring
+     * all related requests go to a single server.
+     * 
+ */ + default io.grpc.stub.StreamObserver serverReflectionInfo( + io.grpc.stub.StreamObserver responseObserver) { + return io.grpc.stub.ServerCalls.asyncUnimplementedStreamingCall(getServerReflectionInfoMethod(), responseObserver); + } + } + + /** + * Base class for the server implementation of the service ServerReflection. + */ + public static abstract class ServerReflectionImplBase + implements io.grpc.BindableService, AsyncService { + + @java.lang.Override public final io.grpc.ServerServiceDefinition bindService() { + return ServerReflectionGrpc.bindService(this); + } + } + + /** + * A stub to allow clients to do asynchronous rpc calls to service ServerReflection. + */ + public static final class ServerReflectionStub + extends io.grpc.stub.AbstractAsyncStub { + private ServerReflectionStub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected ServerReflectionStub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ServerReflectionStub(channel, callOptions); + } + + /** + *
+     * The reflection service is structured as a bidirectional stream, ensuring
+     * all related requests go to a single server.
+     * 
+ */ + public io.grpc.stub.StreamObserver serverReflectionInfo( + io.grpc.stub.StreamObserver responseObserver) { + return io.grpc.stub.ClientCalls.asyncBidiStreamingCall( + getChannel().newCall(getServerReflectionInfoMethod(), getCallOptions()), responseObserver); + } + } + + /** + * A stub to allow clients to do synchronous rpc calls to service ServerReflection. + */ + public static final class ServerReflectionBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private ServerReflectionBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected ServerReflectionBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ServerReflectionBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * The reflection service is structured as a bidirectional stream, ensuring
+     * all related requests go to a single server.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + serverReflectionInfo() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getServerReflectionInfoMethod(), getCallOptions()); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service ServerReflection. + */ + public static final class ServerReflectionBlockingStub + extends io.grpc.stub.AbstractBlockingStub { + private ServerReflectionBlockingStub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected ServerReflectionBlockingStub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ServerReflectionBlockingStub(channel, callOptions); + } + } + + /** + * A stub to allow clients to do ListenableFuture-style rpc calls to service ServerReflection. + */ + public static final class ServerReflectionFutureStub + extends io.grpc.stub.AbstractFutureStub { + private ServerReflectionFutureStub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected ServerReflectionFutureStub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ServerReflectionFutureStub(channel, callOptions); + } + } + + private static final int METHODID_SERVER_REFLECTION_INFO = 0; + + private static final class MethodHandlers implements + io.grpc.stub.ServerCalls.UnaryMethod, + io.grpc.stub.ServerCalls.ServerStreamingMethod, + io.grpc.stub.ServerCalls.ClientStreamingMethod, + io.grpc.stub.ServerCalls.BidiStreamingMethod { + private final AsyncService serviceImpl; + private final int methodId; + + MethodHandlers(AsyncService serviceImpl, int methodId) { + this.serviceImpl = serviceImpl; + this.methodId = methodId; + } + + @java.lang.Override + @java.lang.SuppressWarnings("unchecked") + public void invoke(Req request, io.grpc.stub.StreamObserver responseObserver) { + switch (methodId) { + default: + throw new AssertionError(); + } + } + + @java.lang.Override + @java.lang.SuppressWarnings("unchecked") + public io.grpc.stub.StreamObserver invoke( + io.grpc.stub.StreamObserver responseObserver) { + switch (methodId) { + case METHODID_SERVER_REFLECTION_INFO: + return (io.grpc.stub.StreamObserver) serviceImpl.serverReflectionInfo( + (io.grpc.stub.StreamObserver) responseObserver); + default: + throw new AssertionError(); + } + } + } + + public static final io.grpc.ServerServiceDefinition bindService(AsyncService service) { + return io.grpc.ServerServiceDefinition.builder(getServiceDescriptor()) + .addMethod( + getServerReflectionInfoMethod(), + io.grpc.stub.ServerCalls.asyncBidiStreamingCall( + new MethodHandlers< + io.grpc.reflection.v1.ServerReflectionRequest, + io.grpc.reflection.v1.ServerReflectionResponse>( + service, METHODID_SERVER_REFLECTION_INFO))) + .build(); + } + + private static abstract class ServerReflectionBaseDescriptorSupplier + implements io.grpc.protobuf.ProtoFileDescriptorSupplier, io.grpc.protobuf.ProtoServiceDescriptorSupplier { + ServerReflectionBaseDescriptorSupplier() {} + + @java.lang.Override + public com.google.protobuf.Descriptors.FileDescriptor getFileDescriptor() { + return io.grpc.reflection.v1.ServerReflectionProto.getDescriptor(); + } + + @java.lang.Override + public com.google.protobuf.Descriptors.ServiceDescriptor getServiceDescriptor() { + return getFileDescriptor().findServiceByName("ServerReflection"); + } + } + + private static final class ServerReflectionFileDescriptorSupplier + extends ServerReflectionBaseDescriptorSupplier { + ServerReflectionFileDescriptorSupplier() {} + } + + private static final class ServerReflectionMethodDescriptorSupplier + extends ServerReflectionBaseDescriptorSupplier + implements io.grpc.protobuf.ProtoMethodDescriptorSupplier { + private final java.lang.String methodName; + + ServerReflectionMethodDescriptorSupplier(java.lang.String methodName) { + this.methodName = methodName; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.MethodDescriptor getMethodDescriptor() { + return getServiceDescriptor().findMethodByName(methodName); + } + } + + private static volatile io.grpc.ServiceDescriptor serviceDescriptor; + + public static io.grpc.ServiceDescriptor getServiceDescriptor() { + io.grpc.ServiceDescriptor result = serviceDescriptor; + if (result == null) { + synchronized (ServerReflectionGrpc.class) { + result = serviceDescriptor; + if (result == null) { + serviceDescriptor = result = io.grpc.ServiceDescriptor.newBuilder(SERVICE_NAME) + .setSchemaDescriptor(new ServerReflectionFileDescriptorSupplier()) + .addMethod(getServerReflectionInfoMethod()) + .build(); + } + } + } + return result; + } +} diff --git a/services/src/generated/main/grpc/io/grpc/reflection/v1alpha/ServerReflectionGrpc.java b/services/src/generated/main/grpc/io/grpc/reflection/v1alpha/ServerReflectionGrpc.java index 7119e96d1f3..3cbb3a1d1b9 100644 --- a/services/src/generated/main/grpc/io/grpc/reflection/v1alpha/ServerReflectionGrpc.java +++ b/services/src/generated/main/grpc/io/grpc/reflection/v1alpha/ServerReflectionGrpc.java @@ -4,9 +4,6 @@ /** */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: grpc/reflection/v1alpha/reflection.proto") @io.grpc.stub.annotations.GrpcGenerated public final class ServerReflectionGrpc { @@ -60,6 +57,21 @@ public ServerReflectionStub newStub(io.grpc.Channel channel, io.grpc.CallOptions return ServerReflectionStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static ServerReflectionBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public ServerReflectionBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ServerReflectionBlockingV2Stub(channel, callOptions); + } + }; + return ServerReflectionBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -149,6 +161,36 @@ public io.grpc.stub.StreamObserver { + private ServerReflectionBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected ServerReflectionBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ServerReflectionBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * The reflection service is structured as a bidirectional stream, ensuring
+     * all related requests go to a single server.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + serverReflectionInfo() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getServerReflectionInfoMethod(), getCallOptions()); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service ServerReflection. + */ public static final class ServerReflectionBlockingStub extends io.grpc.stub.AbstractBlockingStub { private ServerReflectionBlockingStub( diff --git a/services/src/generated/test/grpc/io/grpc/reflection/testing/AnotherDynamicServiceGrpc.java b/services/src/generated/test/grpc/io/grpc/reflection/testing/AnotherDynamicServiceGrpc.java index 088d27b619c..978af2d887e 100644 --- a/services/src/generated/test/grpc/io/grpc/reflection/testing/AnotherDynamicServiceGrpc.java +++ b/services/src/generated/test/grpc/io/grpc/reflection/testing/AnotherDynamicServiceGrpc.java @@ -7,9 +7,6 @@ * AnotherDynamicService * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: io/grpc/reflection/testing/dynamic_reflection_test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class AnotherDynamicServiceGrpc { @@ -63,6 +60,21 @@ public AnotherDynamicServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOp return AnotherDynamicServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static AnotherDynamicServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public AnotherDynamicServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new AnotherDynamicServiceBlockingV2Stub(channel, callOptions); + } + }; + return AnotherDynamicServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -162,6 +174,36 @@ public void method(io.grpc.reflection.testing.DynamicRequest request, * AnotherDynamicService * */ + public static final class AnotherDynamicServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private AnotherDynamicServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected AnotherDynamicServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new AnotherDynamicServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * A method
+     * 
+ */ + public io.grpc.reflection.testing.DynamicReply method(io.grpc.reflection.testing.DynamicRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getMethodMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service AnotherDynamicService. + *
+   * AnotherDynamicService
+   * 
+ */ public static final class AnotherDynamicServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private AnotherDynamicServiceBlockingStub( diff --git a/services/src/generated/test/grpc/io/grpc/reflection/testing/AnotherReflectableServiceGrpc.java b/services/src/generated/test/grpc/io/grpc/reflection/testing/AnotherReflectableServiceGrpc.java index a84b95b2126..e688c3d5cca 100644 --- a/services/src/generated/test/grpc/io/grpc/reflection/testing/AnotherReflectableServiceGrpc.java +++ b/services/src/generated/test/grpc/io/grpc/reflection/testing/AnotherReflectableServiceGrpc.java @@ -4,9 +4,6 @@ /** */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: io/grpc/reflection/testing/reflection_test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class AnotherReflectableServiceGrpc { @@ -60,6 +57,21 @@ public AnotherReflectableServiceStub newStub(io.grpc.Channel channel, io.grpc.Ca return AnotherReflectableServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static AnotherReflectableServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public AnotherReflectableServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new AnotherReflectableServiceBlockingV2Stub(channel, callOptions); + } + }; + return AnotherReflectableServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -141,6 +153,30 @@ public void method(io.grpc.reflection.testing.Request request, /** * A stub to allow clients to do synchronous rpc calls to service AnotherReflectableService. */ + public static final class AnotherReflectableServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private AnotherReflectableServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected AnotherReflectableServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new AnotherReflectableServiceBlockingV2Stub(channel, callOptions); + } + + /** + */ + public io.grpc.reflection.testing.Reply method(io.grpc.reflection.testing.Request request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getMethodMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service AnotherReflectableService. + */ public static final class AnotherReflectableServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private AnotherReflectableServiceBlockingStub( diff --git a/services/src/generated/test/grpc/io/grpc/reflection/testing/DynamicServiceGrpc.java b/services/src/generated/test/grpc/io/grpc/reflection/testing/DynamicServiceGrpc.java index 338b67e684d..efef61be151 100644 --- a/services/src/generated/test/grpc/io/grpc/reflection/testing/DynamicServiceGrpc.java +++ b/services/src/generated/test/grpc/io/grpc/reflection/testing/DynamicServiceGrpc.java @@ -7,9 +7,6 @@ * A DynamicService * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: io/grpc/reflection/testing/dynamic_reflection_test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class DynamicServiceGrpc { @@ -63,6 +60,21 @@ public DynamicServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions c return DynamicServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static DynamicServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public DynamicServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new DynamicServiceBlockingV2Stub(channel, callOptions); + } + }; + return DynamicServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -162,6 +174,36 @@ public void method(io.grpc.reflection.testing.DynamicRequest request, * A DynamicService * */ + public static final class DynamicServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private DynamicServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected DynamicServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new DynamicServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * A method
+     * 
+ */ + public io.grpc.reflection.testing.DynamicReply method(io.grpc.reflection.testing.DynamicRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getMethodMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service DynamicService. + *
+   * A DynamicService
+   * 
+ */ public static final class DynamicServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private DynamicServiceBlockingStub( diff --git a/services/src/generated/test/grpc/io/grpc/reflection/testing/ReflectableServiceGrpc.java b/services/src/generated/test/grpc/io/grpc/reflection/testing/ReflectableServiceGrpc.java index 0b8954b5eb9..b5d130d6952 100644 --- a/services/src/generated/test/grpc/io/grpc/reflection/testing/ReflectableServiceGrpc.java +++ b/services/src/generated/test/grpc/io/grpc/reflection/testing/ReflectableServiceGrpc.java @@ -4,9 +4,6 @@ /** */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: io/grpc/reflection/testing/reflection_test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class ReflectableServiceGrpc { @@ -60,6 +57,21 @@ public ReflectableServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptio return ReflectableServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static ReflectableServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public ReflectableServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ReflectableServiceBlockingV2Stub(channel, callOptions); + } + }; + return ReflectableServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -141,6 +153,30 @@ public void method(io.grpc.reflection.testing.Request request, /** * A stub to allow clients to do synchronous rpc calls to service ReflectableService. */ + public static final class ReflectableServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private ReflectableServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected ReflectableServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ReflectableServiceBlockingV2Stub(channel, callOptions); + } + + /** + */ + public io.grpc.reflection.testing.Reply method(io.grpc.reflection.testing.Request request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getMethodMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service ReflectableService. + */ public static final class ReflectableServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private ReflectableServiceBlockingStub( diff --git a/services/src/main/java/io/grpc/protobuf/services/BinlogHelper.java b/services/src/main/java/io/grpc/protobuf/services/BinlogHelper.java index 845ec1036ad..e810c983beb 100644 --- a/services/src/main/java/io/grpc/protobuf/services/BinlogHelper.java +++ b/services/src/main/java/io/grpc/protobuf/services/BinlogHelper.java @@ -22,7 +22,6 @@ import static io.grpc.protobuf.services.BinaryLogProvider.BYTEARRAY_MARSHALLER; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Charsets; import com.google.common.base.Preconditions; import com.google.common.base.Splitter; import com.google.protobuf.ByteString; @@ -59,6 +58,7 @@ import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.SocketAddress; +import java.nio.charset.StandardCharsets; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -841,7 +841,7 @@ static MaybeTruncated createMetadataProto if (serialized != null) { int curBytes = 0; for (int i = 0; i < serialized.length; i += 2) { - String key = new String(serialized[i], Charsets.UTF_8); + String key = new String(serialized[i], StandardCharsets.UTF_8); byte[] value = serialized[i + 1]; if (NEVER_INCLUDED_METADATA.contains(key)) { continue; diff --git a/services/src/main/java/io/grpc/protobuf/services/ChannelzProtoUtil.java b/services/src/main/java/io/grpc/protobuf/services/ChannelzProtoUtil.java index cf003b2f881..74448a8c5bf 100644 --- a/services/src/main/java/io/grpc/protobuf/services/ChannelzProtoUtil.java +++ b/services/src/main/java/io/grpc/protobuf/services/ChannelzProtoUtil.java @@ -21,6 +21,7 @@ import com.google.protobuf.Any; import com.google.protobuf.ByteString; import com.google.protobuf.Int64Value; +import com.google.protobuf.MessageLite; import com.google.protobuf.util.Durations; import com.google.protobuf.util.Timestamps; import io.grpc.ConnectivityState; @@ -79,6 +80,8 @@ /** * A static utility class for turning internal data structures into protos. + * + *

Works with both regular and lite protos. */ final class ChannelzProtoUtil { private static final Logger logger = Logger.getLogger(ChannelzProtoUtil.class.getName()); @@ -254,22 +257,20 @@ static SocketOption toSocketOptionLinger(int lingerSeconds) { } else { lingerOpt = SocketOptionLinger.getDefaultInstance(); } - return SocketOption - .newBuilder() + return SocketOption.newBuilder() .setName(SO_LINGER) - .setAdditional(Any.pack(lingerOpt)) + .setAdditional(packToAny("SocketOptionLinger", lingerOpt)) .build(); } static SocketOption toSocketOptionTimeout(String name, int timeoutMillis) { Preconditions.checkNotNull(name); - return SocketOption - .newBuilder() + return SocketOption.newBuilder() .setName(name) .setAdditional( - Any.pack( - SocketOptionTimeout - .newBuilder() + packToAny( + "SocketOptionTimeout", + SocketOptionTimeout.newBuilder() .setDuration(Durations.fromMillis(timeoutMillis)) .build())) .build(); @@ -307,10 +308,9 @@ static SocketOption toSocketOptionTcpInfo(InternalChannelz.TcpInfo i) { .setTcpiAdvmss(i.advmss) .setTcpiReordering(i.reordering) .build(); - return SocketOption - .newBuilder() + return SocketOption.newBuilder() .setName(TCP_INFO) - .setAdditional(Any.pack(tcpInfo)) + .setAdditional(packToAny("SocketOptionTcpInfo", tcpInfo)) .build(); } @@ -380,10 +380,11 @@ private static ChannelTrace toChannelTrace(InternalChannelz.ChannelTrace channel private static List toChannelTraceEvents(List events) { List channelTraceEvents = new ArrayList<>(); for (Event event : events) { - ChannelTraceEvent.Builder builder = ChannelTraceEvent.newBuilder() - .setDescription(event.description) - .setSeverity(Severity.valueOf(event.severity.name())) - .setTimestamp(Timestamps.fromNanos(event.timestampNanos)); + ChannelTraceEvent.Builder builder = + ChannelTraceEvent.newBuilder() + .setDescription(event.description) + .setSeverity(toSeverity(event.severity)) + .setTimestamp(Timestamps.fromNanos(event.timestampNanos)); if (event.channelRef != null) { builder.setChannelRef(toChannelRef(event.channelRef)); } @@ -395,14 +396,39 @@ private static List toChannelTraceEvents(List events) return Collections.unmodifiableList(channelTraceEvents); } + static Severity toSeverity(Event.Severity severity) { + if (severity == null) { + return Severity.CT_UNKNOWN; + } + switch (severity) { + case CT_INFO: + return Severity.CT_INFO; + case CT_ERROR: + return Severity.CT_ERROR; + case CT_WARNING: + return Severity.CT_WARNING; + default: + return Severity.CT_UNKNOWN; + } + } + static State toState(ConnectivityState state) { if (state == null) { return State.UNKNOWN; } - try { - return Enum.valueOf(State.class, state.name()); - } catch (IllegalArgumentException e) { - return State.UNKNOWN; + switch (state) { + case IDLE: + return State.IDLE; + case READY: + return State.READY; + case CONNECTING: + return State.CONNECTING; + case SHUTDOWN: + return State.SHUTDOWN; + case TRANSIENT_FAILURE: + return State.TRANSIENT_FAILURE; + default: + return State.UNKNOWN; } } @@ -468,4 +494,12 @@ private static T getFuture(ListenableFuture future) { throw Status.INTERNAL.withCause(e).asRuntimeException(); } } + + // A version of Any.pack() that works with protolite. + private static Any packToAny(String typeName, MessageLite value) { + return Any.newBuilder() + .setTypeUrl("type.googleapis.com/grpc.channelz.v1." + typeName) + .setValue(value.toByteString()) + .build(); + } } diff --git a/services/src/main/java/io/grpc/protobuf/services/HealthCheckingLoadBalancerFactory.java b/services/src/main/java/io/grpc/protobuf/services/HealthCheckingLoadBalancerFactory.java index cac522caf9e..b9f235d0aff 100644 --- a/services/src/main/java/io/grpc/protobuf/services/HealthCheckingLoadBalancerFactory.java +++ b/services/src/main/java/io/grpc/protobuf/services/HealthCheckingLoadBalancerFactory.java @@ -144,6 +144,30 @@ void setHealthCheckedService(@Nullable String service) { public String toString() { return MoreObjects.toStringHelper(this).add("delegate", delegate()).toString(); } + + @Override + public void updateBalancingState( + io.grpc.ConnectivityState newState, LoadBalancer.SubchannelPicker newPicker) { + delegate().updateBalancingState(newState, new HealthCheckPicker(newPicker)); + } + + private final class HealthCheckPicker extends LoadBalancer.SubchannelPicker { + private final LoadBalancer.SubchannelPicker delegate; + + HealthCheckPicker(LoadBalancer.SubchannelPicker delegate) { + this.delegate = delegate; + } + + @Override + public LoadBalancer.PickResult pickSubchannel(LoadBalancer.PickSubchannelArgs args) { + LoadBalancer.PickResult result = delegate.pickSubchannel(args); + LoadBalancer.Subchannel subchannel = result.getSubchannel(); + if (subchannel instanceof SubchannelImpl) { + return result.copyWithSubchannel(((SubchannelImpl) subchannel).delegate()); + } + return result; + } + } } @VisibleForTesting @@ -194,7 +218,18 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { .get(LoadBalancer.ATTR_HEALTH_CHECKING_CONFIG); String serviceName = ServiceConfigUtil.getHealthCheckedServiceName(healthCheckingConfig); helper.setHealthCheckedService(serviceName); - super.handleResolvedAddresses(resolvedAddresses); + delegate.handleResolvedAddresses(resolvedAddresses); + } + + @Override + public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + Map healthCheckingConfig = + resolvedAddresses + .getAttributes() + .get(LoadBalancer.ATTR_HEALTH_CHECKING_CONFIG); + String serviceName = ServiceConfigUtil.getHealthCheckedServiceName(healthCheckingConfig); + helper.setHealthCheckedService(serviceName); + return delegate.acceptResolvedAddresses(resolvedAddresses); } @Override diff --git a/services/src/main/java/io/grpc/protobuf/services/HealthServiceImpl.java b/services/src/main/java/io/grpc/protobuf/services/HealthServiceImpl.java index 6ce602b9295..5cd294b4fbe 100644 --- a/services/src/main/java/io/grpc/protobuf/services/HealthServiceImpl.java +++ b/services/src/main/java/io/grpc/protobuf/services/HealthServiceImpl.java @@ -18,6 +18,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.util.concurrent.MoreExecutors; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.Context; import io.grpc.Context.CancellationListener; import io.grpc.Status; @@ -26,6 +27,7 @@ import io.grpc.health.v1.HealthCheckResponse; import io.grpc.health.v1.HealthCheckResponse.ServingStatus; import io.grpc.health.v1.HealthGrpc; +import io.grpc.stub.ServerCallStreamObserver; import io.grpc.stub.StreamObserver; import java.util.HashMap; import java.util.IdentityHashMap; @@ -34,7 +36,6 @@ import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; final class HealthServiceImpl extends HealthGrpc.HealthImplBase { @@ -83,6 +84,11 @@ public void watch(HealthCheckRequest request, final StreamObserver responseObserver) { final String service = request.getService(); synchronized (watchLock) { + if (responseObserver instanceof ServerCallStreamObserver) { + ((ServerCallStreamObserver) responseObserver).setOnCancelHandler(() -> { + removeWatcher(service, responseObserver); + }); + } ServingStatus status = statusMap.get(service); responseObserver.onNext(getResponseForWatch(status)); IdentityHashMap, Boolean> serviceWatchers = @@ -98,21 +104,25 @@ public void watch(HealthCheckRequest request, @Override // Called when the client has closed the stream public void cancelled(Context context) { - synchronized (watchLock) { - IdentityHashMap, Boolean> serviceWatchers = - watchers.get(service); - if (serviceWatchers != null) { - serviceWatchers.remove(responseObserver); - if (serviceWatchers.isEmpty()) { - watchers.remove(service); - } - } - } + removeWatcher(service, responseObserver); } }, MoreExecutors.directExecutor()); } + void removeWatcher(String service, StreamObserver responseObserver) { + synchronized (watchLock) { + IdentityHashMap, Boolean> serviceWatchers = + watchers.get(service); + if (serviceWatchers != null) { + serviceWatchers.remove(responseObserver); + if (serviceWatchers.isEmpty()) { + watchers.remove(service); + } + } + } + } + void setStatus(String service, ServingStatus status) { synchronized (watchLock) { if (terminal) { diff --git a/services/src/main/java/io/grpc/protobuf/services/ProtoReflectionService.java b/services/src/main/java/io/grpc/protobuf/services/ProtoReflectionService.java index 4a7840a3ad9..07008b682c3 100644 --- a/services/src/main/java/io/grpc/protobuf/services/ProtoReflectionService.java +++ b/services/src/main/java/io/grpc/protobuf/services/ProtoReflectionService.java @@ -16,43 +16,15 @@ package io.grpc.protobuf.services; -import static com.google.common.base.Preconditions.checkNotNull; -import static com.google.common.base.Preconditions.checkState; - -import com.google.protobuf.Descriptors.Descriptor; -import com.google.protobuf.Descriptors.FieldDescriptor; -import com.google.protobuf.Descriptors.FileDescriptor; -import com.google.protobuf.Descriptors.MethodDescriptor; -import com.google.protobuf.Descriptors.ServiceDescriptor; import io.grpc.BindableService; import io.grpc.ExperimentalApi; -import io.grpc.InternalServer; -import io.grpc.Server; +import io.grpc.MethodDescriptor; +import io.grpc.ServerCallHandler; import io.grpc.ServerServiceDefinition; -import io.grpc.Status; -import io.grpc.protobuf.ProtoFileDescriptorSupplier; -import io.grpc.reflection.v1alpha.ErrorResponse; -import io.grpc.reflection.v1alpha.ExtensionNumberResponse; -import io.grpc.reflection.v1alpha.ExtensionRequest; -import io.grpc.reflection.v1alpha.FileDescriptorResponse; -import io.grpc.reflection.v1alpha.ListServiceResponse; -import io.grpc.reflection.v1alpha.ServerReflectionGrpc; -import io.grpc.reflection.v1alpha.ServerReflectionRequest; -import io.grpc.reflection.v1alpha.ServerReflectionResponse; -import io.grpc.reflection.v1alpha.ServiceResponse; -import io.grpc.stub.ServerCallStreamObserver; -import io.grpc.stub.StreamObserver; -import java.util.ArrayDeque; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Queue; -import java.util.Set; -import java.util.WeakHashMap; -import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; +import io.grpc.ServiceDescriptor; +import io.grpc.reflection.v1.ServerReflectionGrpc; +import io.grpc.reflection.v1.ServerReflectionRequest; +import io.grpc.reflection.v1.ServerReflectionResponse; /** * Provides a reflection service for Protobuf services (including the reflection service itself). @@ -60,480 +32,56 @@ *

Separately tracks mutable and immutable services. Throws an exception if either group of * services contains multiple Protobuf files with declarations of the same service, method, type, or * extension. + * Uses the deprecated v1alpha proto. New users should use {@link ProtoReflectionServiceV1} instead. */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/2222") -public final class ProtoReflectionService extends ServerReflectionGrpc.ServerReflectionImplBase { - - private final Object lock = new Object(); - - @GuardedBy("lock") - private final Map serverReflectionIndexes = new WeakHashMap<>(); +public final class ProtoReflectionService implements BindableService { - private ProtoReflectionService() {} + private ProtoReflectionService() { + } - /** - * Creates a instance of {@link ProtoReflectionService}. - */ + @Deprecated public static BindableService newInstance() { return new ProtoReflectionService(); } - /** - * Retrieves the index for services of the server that dispatches the current call. Computes - * one if not exist. The index is updated if any changes to the server's mutable services are - * detected. A change is any addition or removal in the set of file descriptors attached to the - * mutable services or a change in the service names. - */ - private ServerReflectionIndex getRefreshedIndex() { - synchronized (lock) { - Server server = InternalServer.SERVER_CONTEXT_KEY.get(); - ServerReflectionIndex index = serverReflectionIndexes.get(server); - if (index == null) { - index = - new ServerReflectionIndex(server.getImmutableServices(), server.getMutableServices()); - serverReflectionIndexes.put(server, index); - return index; - } - - Set serverFileDescriptors = new HashSet<>(); - Set serverServiceNames = new HashSet<>(); - List serverMutableServices = server.getMutableServices(); - for (ServerServiceDefinition mutableService : serverMutableServices) { - io.grpc.ServiceDescriptor serviceDescriptor = mutableService.getServiceDescriptor(); - if (serviceDescriptor.getSchemaDescriptor() instanceof ProtoFileDescriptorSupplier) { - String serviceName = serviceDescriptor.getName(); - FileDescriptor fileDescriptor = - ((ProtoFileDescriptorSupplier) serviceDescriptor.getSchemaDescriptor()) - .getFileDescriptor(); - serverFileDescriptors.add(fileDescriptor); - serverServiceNames.add(serviceName); - } - } - - // Replace the index if the underlying mutable services have changed. Check both the file - // descriptors and the service names, because one file descriptor can define multiple - // services. - FileDescriptorIndex mutableServicesIndex = index.getMutableServicesIndex(); - if (!mutableServicesIndex.getServiceFileDescriptors().equals(serverFileDescriptors) - || !mutableServicesIndex.getServiceNames().equals(serverServiceNames)) { - index = - new ServerReflectionIndex(server.getImmutableServices(), serverMutableServices); - serverReflectionIndexes.put(server, index); - } - - return index; - } - } - @Override - public StreamObserver serverReflectionInfo( - final StreamObserver responseObserver) { - final ServerCallStreamObserver serverCallStreamObserver = - (ServerCallStreamObserver) responseObserver; - ProtoReflectionStreamObserver requestObserver = - new ProtoReflectionStreamObserver(getRefreshedIndex(), serverCallStreamObserver); - serverCallStreamObserver.setOnReadyHandler(requestObserver); - serverCallStreamObserver.disableAutoRequest(); - serverCallStreamObserver.request(1); - return requestObserver; + @SuppressWarnings("deprecation") + public ServerServiceDefinition bindService() { + ServerServiceDefinition serverServiceDefinitionV1 = ProtoReflectionServiceV1.newInstance() + .bindService(); + MethodDescriptor methodDescriptorV1 = + ServerReflectionGrpc.getServerReflectionInfoMethod(); + // Retain the v1 proto marshallers but change the method name and schema descriptor to v1alpha. + MethodDescriptor methodDescriptorV1AlphaGenerated = + io.grpc.reflection.v1alpha.ServerReflectionGrpc.getServerReflectionInfoMethod(); + MethodDescriptor methodDescriptorV1Alpha = + methodDescriptorV1.toBuilder() + .setFullMethodName(methodDescriptorV1AlphaGenerated.getFullMethodName()) + .setSchemaDescriptor(methodDescriptorV1AlphaGenerated.getSchemaDescriptor()) + .build(); + // Retain the v1 server call handler but change the service name schema descriptor in the + // service descriptor to v1alpha. + ServiceDescriptor serviceDescriptorV1AlphaGenerated = + io.grpc.reflection.v1alpha.ServerReflectionGrpc.getServiceDescriptor(); + ServiceDescriptor serviceDescriptorV1Alpha = + ServiceDescriptor.newBuilder(serviceDescriptorV1AlphaGenerated.getName()) + .setSchemaDescriptor(serviceDescriptorV1AlphaGenerated.getSchemaDescriptor()) + .addMethod(methodDescriptorV1Alpha) + .build(); + return ServerServiceDefinition.builder(serviceDescriptorV1Alpha) + .addMethod(methodDescriptorV1Alpha, createServerCallHandler(serverServiceDefinitionV1)) + .build(); } - private static class ProtoReflectionStreamObserver - implements Runnable, StreamObserver { - private final ServerReflectionIndex serverReflectionIndex; - private final ServerCallStreamObserver serverCallStreamObserver; - - private boolean closeAfterSend = false; - private ServerReflectionRequest request; - - ProtoReflectionStreamObserver( - ServerReflectionIndex serverReflectionIndex, - ServerCallStreamObserver serverCallStreamObserver) { - this.serverReflectionIndex = serverReflectionIndex; - this.serverCallStreamObserver = checkNotNull(serverCallStreamObserver, "observer"); - } - - @Override - public void run() { - if (request != null) { - handleReflectionRequest(); - } - } - - @Override - public void onNext(ServerReflectionRequest request) { - checkState(this.request == null); - this.request = checkNotNull(request); - handleReflectionRequest(); - } - - private void handleReflectionRequest() { - if (serverCallStreamObserver.isReady()) { - switch (request.getMessageRequestCase()) { - case FILE_BY_FILENAME: - getFileByName(request); - break; - case FILE_CONTAINING_SYMBOL: - getFileContainingSymbol(request); - break; - case FILE_CONTAINING_EXTENSION: - getFileByExtension(request); - break; - case ALL_EXTENSION_NUMBERS_OF_TYPE: - getAllExtensions(request); - break; - case LIST_SERVICES: - listServices(request); - break; - default: - sendErrorResponse( - request, - Status.Code.UNIMPLEMENTED, - "not implemented " + request.getMessageRequestCase()); - } - request = null; - if (closeAfterSend) { - serverCallStreamObserver.onCompleted(); - } else { - serverCallStreamObserver.request(1); - } - } - } - - @Override - public void onCompleted() { - if (request != null) { - closeAfterSend = true; - } else { - serverCallStreamObserver.onCompleted(); - } - } - - @Override - public void onError(Throwable cause) { - serverCallStreamObserver.onError(cause); - } - - private void getFileByName(ServerReflectionRequest request) { - String name = request.getFileByFilename(); - FileDescriptor fd = serverReflectionIndex.getFileDescriptorByName(name); - if (fd != null) { - serverCallStreamObserver.onNext(createServerReflectionResponse(request, fd)); - } else { - sendErrorResponse(request, Status.Code.NOT_FOUND, "File not found."); - } - } - - private void getFileContainingSymbol(ServerReflectionRequest request) { - String symbol = request.getFileContainingSymbol(); - FileDescriptor fd = serverReflectionIndex.getFileDescriptorBySymbol(symbol); - if (fd != null) { - serverCallStreamObserver.onNext(createServerReflectionResponse(request, fd)); - } else { - sendErrorResponse(request, Status.Code.NOT_FOUND, "Symbol not found."); - } - } - - private void getFileByExtension(ServerReflectionRequest request) { - ExtensionRequest extensionRequest = request.getFileContainingExtension(); - String type = extensionRequest.getContainingType(); - int extension = extensionRequest.getExtensionNumber(); - FileDescriptor fd = - serverReflectionIndex.getFileDescriptorByExtensionAndNumber(type, extension); - if (fd != null) { - serverCallStreamObserver.onNext(createServerReflectionResponse(request, fd)); - } else { - sendErrorResponse(request, Status.Code.NOT_FOUND, "Extension not found."); - } - } - - private void getAllExtensions(ServerReflectionRequest request) { - String type = request.getAllExtensionNumbersOfType(); - Set extensions = serverReflectionIndex.getExtensionNumbersOfType(type); - if (extensions != null) { - ExtensionNumberResponse.Builder builder = - ExtensionNumberResponse.newBuilder() - .setBaseTypeName(type) - .addAllExtensionNumber(extensions); - serverCallStreamObserver.onNext( - ServerReflectionResponse.newBuilder() - .setValidHost(request.getHost()) - .setOriginalRequest(request) - .setAllExtensionNumbersResponse(builder) - .build()); - } else { - sendErrorResponse(request, Status.Code.NOT_FOUND, "Type not found."); - } - } - - private void listServices(ServerReflectionRequest request) { - ListServiceResponse.Builder builder = ListServiceResponse.newBuilder(); - for (String serviceName : serverReflectionIndex.getServiceNames()) { - builder.addService(ServiceResponse.newBuilder().setName(serviceName)); - } - serverCallStreamObserver.onNext( - ServerReflectionResponse.newBuilder() - .setValidHost(request.getHost()) - .setOriginalRequest(request) - .setListServicesResponse(builder) - .build()); - } - - private void sendErrorResponse( - ServerReflectionRequest request, Status.Code code, String message) { - ServerReflectionResponse response = - ServerReflectionResponse.newBuilder() - .setValidHost(request.getHost()) - .setOriginalRequest(request) - .setErrorResponse( - ErrorResponse.newBuilder() - .setErrorCode(code.value()) - .setErrorMessage(message)) - .build(); - serverCallStreamObserver.onNext(response); - } - - private ServerReflectionResponse createServerReflectionResponse( - ServerReflectionRequest request, FileDescriptor fd) { - FileDescriptorResponse.Builder fdRBuilder = FileDescriptorResponse.newBuilder(); - - Set seenFiles = new HashSet<>(); - Queue frontier = new ArrayDeque<>(); - seenFiles.add(fd.getName()); - frontier.add(fd); - while (!frontier.isEmpty()) { - FileDescriptor nextFd = frontier.remove(); - fdRBuilder.addFileDescriptorProto(nextFd.toProto().toByteString()); - for (FileDescriptor dependencyFd : nextFd.getDependencies()) { - if (!seenFiles.contains(dependencyFd.getName())) { - seenFiles.add(dependencyFd.getName()); - frontier.add(dependencyFd); - } - } - } - return ServerReflectionResponse.newBuilder() - .setValidHost(request.getHost()) - .setOriginalRequest(request) - .setFileDescriptorResponse(fdRBuilder) - .build(); - } - } - - /** - * Indexes the server's services and allows lookups of file descriptors by filename, symbol, type, - * and extension number. - * - *

Internally, this stores separate indices for the immutable and mutable services. When - * queried, the immutable service index is checked for a matching value. Only if there is no match - * in the immutable service index are the mutable services checked. - */ - private static final class ServerReflectionIndex { - private final FileDescriptorIndex immutableServicesIndex; - private final FileDescriptorIndex mutableServicesIndex; - - public ServerReflectionIndex( - List immutableServices, - List mutableServices) { - immutableServicesIndex = new FileDescriptorIndex(immutableServices); - mutableServicesIndex = new FileDescriptorIndex(mutableServices); - } - - private FileDescriptorIndex getMutableServicesIndex() { - return mutableServicesIndex; - } - - private Set getServiceNames() { - Set immutableServiceNames = immutableServicesIndex.getServiceNames(); - Set mutableServiceNames = mutableServicesIndex.getServiceNames(); - Set serviceNames = - new HashSet<>(immutableServiceNames.size() + mutableServiceNames.size()); - serviceNames.addAll(immutableServiceNames); - serviceNames.addAll(mutableServiceNames); - return serviceNames; - } - - @Nullable - private FileDescriptor getFileDescriptorByName(String name) { - FileDescriptor fd = immutableServicesIndex.getFileDescriptorByName(name); - if (fd == null) { - fd = mutableServicesIndex.getFileDescriptorByName(name); - } - return fd; - } - - @Nullable - private FileDescriptor getFileDescriptorBySymbol(String symbol) { - FileDescriptor fd = immutableServicesIndex.getFileDescriptorBySymbol(symbol); - if (fd == null) { - fd = mutableServicesIndex.getFileDescriptorBySymbol(symbol); - } - return fd; - } - - @Nullable - private FileDescriptor getFileDescriptorByExtensionAndNumber(String type, int extension) { - FileDescriptor fd = - immutableServicesIndex.getFileDescriptorByExtensionAndNumber(type, extension); - if (fd == null) { - fd = mutableServicesIndex.getFileDescriptorByExtensionAndNumber(type, extension); - } - return fd; - } - - @Nullable - private Set getExtensionNumbersOfType(String type) { - Set extensionNumbers = immutableServicesIndex.getExtensionNumbersOfType(type); - if (extensionNumbers == null) { - extensionNumbers = mutableServicesIndex.getExtensionNumbersOfType(type); - } - return extensionNumbers; - } - } - - /** - * Provides a set of methods for answering reflection queries for the file descriptors underlying - * a set of services. Used by {@link ServerReflectionIndex} to separately index immutable and - * mutable services. - */ - private static final class FileDescriptorIndex { - private final Set serviceNames = new HashSet<>(); - private final Set serviceFileDescriptors = new HashSet<>(); - private final Map fileDescriptorsByName = - new HashMap<>(); - private final Map fileDescriptorsBySymbol = - new HashMap<>(); - private final Map> fileDescriptorsByExtensionAndNumber = - new HashMap<>(); - - FileDescriptorIndex(List services) { - Queue fileDescriptorsToProcess = new ArrayDeque<>(); - Set seenFiles = new HashSet<>(); - for (ServerServiceDefinition service : services) { - io.grpc.ServiceDescriptor serviceDescriptor = service.getServiceDescriptor(); - if (serviceDescriptor.getSchemaDescriptor() instanceof ProtoFileDescriptorSupplier) { - FileDescriptor fileDescriptor = - ((ProtoFileDescriptorSupplier) serviceDescriptor.getSchemaDescriptor()) - .getFileDescriptor(); - String serviceName = serviceDescriptor.getName(); - checkState( - !serviceNames.contains(serviceName), "Service already defined: %s", serviceName); - serviceFileDescriptors.add(fileDescriptor); - serviceNames.add(serviceName); - if (!seenFiles.contains(fileDescriptor.getName())) { - seenFiles.add(fileDescriptor.getName()); - fileDescriptorsToProcess.add(fileDescriptor); - } - } - } - - while (!fileDescriptorsToProcess.isEmpty()) { - FileDescriptor currentFd = fileDescriptorsToProcess.remove(); - processFileDescriptor(currentFd); - for (FileDescriptor dependencyFd : currentFd.getDependencies()) { - if (!seenFiles.contains(dependencyFd.getName())) { - seenFiles.add(dependencyFd.getName()); - fileDescriptorsToProcess.add(dependencyFd); - } - } - } - } - - /** - * Returns the file descriptors for the indexed services, but not their dependencies. This is - * used to check if the server's mutable services have changed. - */ - private Set getServiceFileDescriptors() { - return Collections.unmodifiableSet(serviceFileDescriptors); - } - - private Set getServiceNames() { - return Collections.unmodifiableSet(serviceNames); - } - - @Nullable - private FileDescriptor getFileDescriptorByName(String name) { - return fileDescriptorsByName.get(name); - } - - @Nullable - private FileDescriptor getFileDescriptorBySymbol(String symbol) { - return fileDescriptorsBySymbol.get(symbol); - } - - @Nullable - private FileDescriptor getFileDescriptorByExtensionAndNumber(String type, int number) { - if (fileDescriptorsByExtensionAndNumber.containsKey(type)) { - return fileDescriptorsByExtensionAndNumber.get(type).get(number); - } - return null; - } - - @Nullable - private Set getExtensionNumbersOfType(String type) { - if (fileDescriptorsByExtensionAndNumber.containsKey(type)) { - return Collections.unmodifiableSet(fileDescriptorsByExtensionAndNumber.get(type).keySet()); - } - return null; - } - - private void processFileDescriptor(FileDescriptor fd) { - String fdName = fd.getName(); - checkState(!fileDescriptorsByName.containsKey(fdName), "File name already used: %s", fdName); - fileDescriptorsByName.put(fdName, fd); - for (ServiceDescriptor service : fd.getServices()) { - processService(service, fd); - } - for (Descriptor type : fd.getMessageTypes()) { - processType(type, fd); - } - for (FieldDescriptor extension : fd.getExtensions()) { - processExtension(extension, fd); - } - } - - private void processService(ServiceDescriptor service, FileDescriptor fd) { - String serviceName = service.getFullName(); - checkState( - !fileDescriptorsBySymbol.containsKey(serviceName), - "Service already defined: %s", - serviceName); - fileDescriptorsBySymbol.put(serviceName, fd); - for (MethodDescriptor method : service.getMethods()) { - String methodName = method.getFullName(); - checkState( - !fileDescriptorsBySymbol.containsKey(methodName), - "Method already defined: %s", - methodName); - fileDescriptorsBySymbol.put(methodName, fd); - } - } - - private void processType(Descriptor type, FileDescriptor fd) { - String typeName = type.getFullName(); - checkState( - !fileDescriptorsBySymbol.containsKey(typeName), "Type already defined: %s", typeName); - fileDescriptorsBySymbol.put(typeName, fd); - for (FieldDescriptor extension : type.getExtensions()) { - processExtension(extension, fd); - } - for (Descriptor nestedType : type.getNestedTypes()) { - processType(nestedType, fd); - } - } - - private void processExtension(FieldDescriptor extension, FileDescriptor fd) { - String extensionName = extension.getContainingType().getFullName(); - int extensionNumber = extension.getNumber(); - if (!fileDescriptorsByExtensionAndNumber.containsKey(extensionName)) { - fileDescriptorsByExtensionAndNumber.put( - extensionName, new HashMap()); - } - checkState( - !fileDescriptorsByExtensionAndNumber.get(extensionName).containsKey(extensionNumber), - "Extension name and number already defined: %s, %s", - extensionName, - extensionNumber); - fileDescriptorsByExtensionAndNumber.get(extensionName).put(extensionNumber, fd); - } + @SuppressWarnings("unchecked") + private ServerCallHandler + createServerCallHandler( + ServerServiceDefinition serverServiceDefinition) { + return (ServerCallHandler) + serverServiceDefinition.getMethod( + ServerReflectionGrpc.getServerReflectionInfoMethod().getFullMethodName()) + .getServerCallHandler(); } } diff --git a/services/src/main/java/io/grpc/protobuf/services/ProtoReflectionServiceV1.java b/services/src/main/java/io/grpc/protobuf/services/ProtoReflectionServiceV1.java new file mode 100644 index 00000000000..59e9c33d279 --- /dev/null +++ b/services/src/main/java/io/grpc/protobuf/services/ProtoReflectionServiceV1.java @@ -0,0 +1,539 @@ +/* + * Copyright 2016 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.protobuf.services; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import com.google.errorprone.annotations.concurrent.GuardedBy; +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.Descriptors.FieldDescriptor; +import com.google.protobuf.Descriptors.FileDescriptor; +import com.google.protobuf.Descriptors.MethodDescriptor; +import com.google.protobuf.Descriptors.ServiceDescriptor; +import io.grpc.BindableService; +import io.grpc.ExperimentalApi; +import io.grpc.InternalServer; +import io.grpc.Server; +import io.grpc.ServerServiceDefinition; +import io.grpc.Status; +import io.grpc.protobuf.ProtoFileDescriptorSupplier; +import io.grpc.reflection.v1.ErrorResponse; +import io.grpc.reflection.v1.ExtensionNumberResponse; +import io.grpc.reflection.v1.ExtensionRequest; +import io.grpc.reflection.v1.FileDescriptorResponse; +import io.grpc.reflection.v1.ListServiceResponse; +import io.grpc.reflection.v1.ServerReflectionGrpc; +import io.grpc.reflection.v1.ServerReflectionRequest; +import io.grpc.reflection.v1.ServerReflectionResponse; +import io.grpc.reflection.v1.ServiceResponse; +import io.grpc.stub.ServerCallStreamObserver; +import io.grpc.stub.StreamObserver; +import java.util.ArrayDeque; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.Set; +import java.util.WeakHashMap; +import javax.annotation.Nullable; + +/** + * Provides a reflection service for Protobuf services (including the reflection service itself). + * + *

Separately tracks mutable and immutable services. Throws an exception if either group of + * services contains multiple Protobuf files with declarations of the same service, method, type, or + * extension. + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/2222") +public final class ProtoReflectionServiceV1 extends ServerReflectionGrpc.ServerReflectionImplBase { + + private final Object lock = new Object(); + + @GuardedBy("lock") + private final Map serverReflectionIndexes = new WeakHashMap<>(); + + private ProtoReflectionServiceV1() {} + + /** + * Creates a instance of {@link ProtoReflectionServiceV1}. + */ + public static BindableService newInstance() { + return new ProtoReflectionServiceV1(); + } + + /** + * Retrieves the index for services of the server that dispatches the current call. Computes + * one if not exist. The index is updated if any changes to the server's mutable services are + * detected. A change is any addition or removal in the set of file descriptors attached to the + * mutable services or a change in the service names. + */ + private ServerReflectionIndex getRefreshedIndex() { + synchronized (lock) { + Server server = InternalServer.SERVER_CONTEXT_KEY.get(); + ServerReflectionIndex index = serverReflectionIndexes.get(server); + if (index == null) { + index = + new ServerReflectionIndex(server.getImmutableServices(), server.getMutableServices()); + serverReflectionIndexes.put(server, index); + return index; + } + + Set serverFileDescriptors = new HashSet<>(); + Set serverServiceNames = new HashSet<>(); + List serverMutableServices = server.getMutableServices(); + for (ServerServiceDefinition mutableService : serverMutableServices) { + io.grpc.ServiceDescriptor serviceDescriptor = mutableService.getServiceDescriptor(); + if (serviceDescriptor.getSchemaDescriptor() instanceof ProtoFileDescriptorSupplier) { + String serviceName = serviceDescriptor.getName(); + FileDescriptor fileDescriptor = + ((ProtoFileDescriptorSupplier) serviceDescriptor.getSchemaDescriptor()) + .getFileDescriptor(); + serverFileDescriptors.add(fileDescriptor); + serverServiceNames.add(serviceName); + } + } + + // Replace the index if the underlying mutable services have changed. Check both the file + // descriptors and the service names, because one file descriptor can define multiple + // services. + FileDescriptorIndex mutableServicesIndex = index.getMutableServicesIndex(); + if (!mutableServicesIndex.getServiceFileDescriptors().equals(serverFileDescriptors) + || !mutableServicesIndex.getServiceNames().equals(serverServiceNames)) { + index = + new ServerReflectionIndex(server.getImmutableServices(), serverMutableServices); + serverReflectionIndexes.put(server, index); + } + + return index; + } + } + + @Override + public StreamObserver serverReflectionInfo( + final StreamObserver responseObserver) { + final ServerCallStreamObserver serverCallStreamObserver = + (ServerCallStreamObserver) responseObserver; + ProtoReflectionStreamObserver requestObserver = + new ProtoReflectionStreamObserver(getRefreshedIndex(), serverCallStreamObserver); + serverCallStreamObserver.setOnReadyHandler(requestObserver); + serverCallStreamObserver.disableAutoRequest(); + serverCallStreamObserver.request(1); + return requestObserver; + } + + private static class ProtoReflectionStreamObserver + implements Runnable, StreamObserver { + private final ServerReflectionIndex serverReflectionIndex; + private final ServerCallStreamObserver serverCallStreamObserver; + + private boolean closeAfterSend = false; + private ServerReflectionRequest request; + + ProtoReflectionStreamObserver( + ServerReflectionIndex serverReflectionIndex, + ServerCallStreamObserver serverCallStreamObserver) { + this.serverReflectionIndex = serverReflectionIndex; + this.serverCallStreamObserver = checkNotNull(serverCallStreamObserver, "observer"); + } + + @Override + public void run() { + if (request != null) { + handleReflectionRequest(); + } + } + + @Override + public void onNext(ServerReflectionRequest request) { + checkState(this.request == null); + this.request = checkNotNull(request); + handleReflectionRequest(); + } + + private void handleReflectionRequest() { + if (serverCallStreamObserver.isReady()) { + switch (request.getMessageRequestCase()) { + case FILE_BY_FILENAME: + getFileByName(request); + break; + case FILE_CONTAINING_SYMBOL: + getFileContainingSymbol(request); + break; + case FILE_CONTAINING_EXTENSION: + getFileByExtension(request); + break; + case ALL_EXTENSION_NUMBERS_OF_TYPE: + getAllExtensions(request); + break; + case LIST_SERVICES: + listServices(request); + break; + default: + sendErrorResponse( + request, + Status.Code.UNIMPLEMENTED, + "not implemented " + request.getMessageRequestCase()); + } + request = null; + if (closeAfterSend) { + serverCallStreamObserver.onCompleted(); + } else { + serverCallStreamObserver.request(1); + } + } + } + + @Override + public void onCompleted() { + if (request != null) { + closeAfterSend = true; + } else { + serverCallStreamObserver.onCompleted(); + } + } + + @Override + public void onError(Throwable cause) { + serverCallStreamObserver.onError(cause); + } + + private void getFileByName(ServerReflectionRequest request) { + String name = request.getFileByFilename(); + FileDescriptor fd = serverReflectionIndex.getFileDescriptorByName(name); + if (fd != null) { + serverCallStreamObserver.onNext(createServerReflectionResponse(request, fd)); + } else { + sendErrorResponse(request, Status.Code.NOT_FOUND, "File not found."); + } + } + + private void getFileContainingSymbol(ServerReflectionRequest request) { + String symbol = request.getFileContainingSymbol(); + FileDescriptor fd = serverReflectionIndex.getFileDescriptorBySymbol(symbol); + if (fd != null) { + serverCallStreamObserver.onNext(createServerReflectionResponse(request, fd)); + } else { + sendErrorResponse(request, Status.Code.NOT_FOUND, "Symbol not found."); + } + } + + private void getFileByExtension(ServerReflectionRequest request) { + ExtensionRequest extensionRequest = request.getFileContainingExtension(); + String type = extensionRequest.getContainingType(); + int extension = extensionRequest.getExtensionNumber(); + FileDescriptor fd = + serverReflectionIndex.getFileDescriptorByExtensionAndNumber(type, extension); + if (fd != null) { + serverCallStreamObserver.onNext(createServerReflectionResponse(request, fd)); + } else { + sendErrorResponse(request, Status.Code.NOT_FOUND, "Extension not found."); + } + } + + private void getAllExtensions(ServerReflectionRequest request) { + String type = request.getAllExtensionNumbersOfType(); + Set extensions = serverReflectionIndex.getExtensionNumbersOfType(type); + if (extensions != null) { + ExtensionNumberResponse.Builder builder = + ExtensionNumberResponse.newBuilder() + .setBaseTypeName(type) + .addAllExtensionNumber(extensions); + serverCallStreamObserver.onNext( + ServerReflectionResponse.newBuilder() + .setValidHost(request.getHost()) + .setOriginalRequest(request) + .setAllExtensionNumbersResponse(builder) + .build()); + } else { + sendErrorResponse(request, Status.Code.NOT_FOUND, "Type not found."); + } + } + + private void listServices(ServerReflectionRequest request) { + ListServiceResponse.Builder builder = ListServiceResponse.newBuilder(); + for (String serviceName : serverReflectionIndex.getServiceNames()) { + builder.addService(ServiceResponse.newBuilder().setName(serviceName)); + } + serverCallStreamObserver.onNext( + ServerReflectionResponse.newBuilder() + .setValidHost(request.getHost()) + .setOriginalRequest(request) + .setListServicesResponse(builder) + .build()); + } + + private void sendErrorResponse( + ServerReflectionRequest request, Status.Code code, String message) { + ServerReflectionResponse response = + ServerReflectionResponse.newBuilder() + .setValidHost(request.getHost()) + .setOriginalRequest(request) + .setErrorResponse( + ErrorResponse.newBuilder() + .setErrorCode(code.value()) + .setErrorMessage(message)) + .build(); + serverCallStreamObserver.onNext(response); + } + + private ServerReflectionResponse createServerReflectionResponse( + ServerReflectionRequest request, FileDescriptor fd) { + FileDescriptorResponse.Builder fdRBuilder = FileDescriptorResponse.newBuilder(); + + Set seenFiles = new HashSet<>(); + Queue frontier = new ArrayDeque<>(); + seenFiles.add(fd.getName()); + frontier.add(fd); + while (!frontier.isEmpty()) { + FileDescriptor nextFd = frontier.remove(); + fdRBuilder.addFileDescriptorProto(nextFd.toProto().toByteString()); + for (FileDescriptor dependencyFd : nextFd.getDependencies()) { + if (!seenFiles.contains(dependencyFd.getName())) { + seenFiles.add(dependencyFd.getName()); + frontier.add(dependencyFd); + } + } + } + return ServerReflectionResponse.newBuilder() + .setValidHost(request.getHost()) + .setOriginalRequest(request) + .setFileDescriptorResponse(fdRBuilder) + .build(); + } + } + + /** + * Indexes the server's services and allows lookups of file descriptors by filename, symbol, type, + * and extension number. + * + *

Internally, this stores separate indices for the immutable and mutable services. When + * queried, the immutable service index is checked for a matching value. Only if there is no match + * in the immutable service index are the mutable services checked. + */ + private static final class ServerReflectionIndex { + private final FileDescriptorIndex immutableServicesIndex; + private final FileDescriptorIndex mutableServicesIndex; + + public ServerReflectionIndex( + List immutableServices, + List mutableServices) { + immutableServicesIndex = new FileDescriptorIndex(immutableServices); + mutableServicesIndex = new FileDescriptorIndex(mutableServices); + } + + private FileDescriptorIndex getMutableServicesIndex() { + return mutableServicesIndex; + } + + private Set getServiceNames() { + Set immutableServiceNames = immutableServicesIndex.getServiceNames(); + Set mutableServiceNames = mutableServicesIndex.getServiceNames(); + Set serviceNames = + new HashSet<>(immutableServiceNames.size() + mutableServiceNames.size()); + serviceNames.addAll(immutableServiceNames); + serviceNames.addAll(mutableServiceNames); + return serviceNames; + } + + @Nullable + private FileDescriptor getFileDescriptorByName(String name) { + FileDescriptor fd = immutableServicesIndex.getFileDescriptorByName(name); + if (fd == null) { + fd = mutableServicesIndex.getFileDescriptorByName(name); + } + return fd; + } + + @Nullable + private FileDescriptor getFileDescriptorBySymbol(String symbol) { + FileDescriptor fd = immutableServicesIndex.getFileDescriptorBySymbol(symbol); + if (fd == null) { + fd = mutableServicesIndex.getFileDescriptorBySymbol(symbol); + } + return fd; + } + + @Nullable + private FileDescriptor getFileDescriptorByExtensionAndNumber(String type, int extension) { + FileDescriptor fd = + immutableServicesIndex.getFileDescriptorByExtensionAndNumber(type, extension); + if (fd == null) { + fd = mutableServicesIndex.getFileDescriptorByExtensionAndNumber(type, extension); + } + return fd; + } + + @Nullable + private Set getExtensionNumbersOfType(String type) { + Set extensionNumbers = immutableServicesIndex.getExtensionNumbersOfType(type); + if (extensionNumbers == null) { + extensionNumbers = mutableServicesIndex.getExtensionNumbersOfType(type); + } + return extensionNumbers; + } + } + + /** + * Provides a set of methods for answering reflection queries for the file descriptors underlying + * a set of services. Used by {@link ServerReflectionIndex} to separately index immutable and + * mutable services. + */ + private static final class FileDescriptorIndex { + private final Set serviceNames = new HashSet<>(); + private final Set serviceFileDescriptors = new HashSet<>(); + private final Map fileDescriptorsByName = + new HashMap<>(); + private final Map fileDescriptorsBySymbol = + new HashMap<>(); + private final Map> fileDescriptorsByExtensionAndNumber = + new HashMap<>(); + + FileDescriptorIndex(List services) { + Queue fileDescriptorsToProcess = new ArrayDeque<>(); + Set seenFiles = new HashSet<>(); + for (ServerServiceDefinition service : services) { + io.grpc.ServiceDescriptor serviceDescriptor = service.getServiceDescriptor(); + if (serviceDescriptor.getSchemaDescriptor() instanceof ProtoFileDescriptorSupplier) { + FileDescriptor fileDescriptor = + ((ProtoFileDescriptorSupplier) serviceDescriptor.getSchemaDescriptor()) + .getFileDescriptor(); + String serviceName = serviceDescriptor.getName(); + checkState( + !serviceNames.contains(serviceName), "Service already defined: %s", serviceName); + serviceFileDescriptors.add(fileDescriptor); + serviceNames.add(serviceName); + if (!seenFiles.contains(fileDescriptor.getName())) { + seenFiles.add(fileDescriptor.getName()); + fileDescriptorsToProcess.add(fileDescriptor); + } + } + } + + while (!fileDescriptorsToProcess.isEmpty()) { + FileDescriptor currentFd = fileDescriptorsToProcess.remove(); + processFileDescriptor(currentFd); + for (FileDescriptor dependencyFd : currentFd.getDependencies()) { + if (!seenFiles.contains(dependencyFd.getName())) { + seenFiles.add(dependencyFd.getName()); + fileDescriptorsToProcess.add(dependencyFd); + } + } + } + } + + /** + * Returns the file descriptors for the indexed services, but not their dependencies. This is + * used to check if the server's mutable services have changed. + */ + private Set getServiceFileDescriptors() { + return Collections.unmodifiableSet(serviceFileDescriptors); + } + + private Set getServiceNames() { + return Collections.unmodifiableSet(serviceNames); + } + + @Nullable + private FileDescriptor getFileDescriptorByName(String name) { + return fileDescriptorsByName.get(name); + } + + @Nullable + private FileDescriptor getFileDescriptorBySymbol(String symbol) { + return fileDescriptorsBySymbol.get(symbol); + } + + @Nullable + private FileDescriptor getFileDescriptorByExtensionAndNumber(String type, int number) { + if (fileDescriptorsByExtensionAndNumber.containsKey(type)) { + return fileDescriptorsByExtensionAndNumber.get(type).get(number); + } + return null; + } + + @Nullable + private Set getExtensionNumbersOfType(String type) { + if (fileDescriptorsByExtensionAndNumber.containsKey(type)) { + return Collections.unmodifiableSet(fileDescriptorsByExtensionAndNumber.get(type).keySet()); + } + return null; + } + + private void processFileDescriptor(FileDescriptor fd) { + String fdName = fd.getName(); + checkState(!fileDescriptorsByName.containsKey(fdName), "File name already used: %s", fdName); + fileDescriptorsByName.put(fdName, fd); + for (ServiceDescriptor service : fd.getServices()) { + processService(service, fd); + } + for (Descriptor type : fd.getMessageTypes()) { + processType(type, fd); + } + for (FieldDescriptor extension : fd.getExtensions()) { + processExtension(extension, fd); + } + } + + private void processService(ServiceDescriptor service, FileDescriptor fd) { + String serviceName = service.getFullName(); + checkState( + !fileDescriptorsBySymbol.containsKey(serviceName), + "Service already defined: %s", + serviceName); + fileDescriptorsBySymbol.put(serviceName, fd); + for (MethodDescriptor method : service.getMethods()) { + String methodName = method.getFullName(); + checkState( + !fileDescriptorsBySymbol.containsKey(methodName), + "Method already defined: %s", + methodName); + fileDescriptorsBySymbol.put(methodName, fd); + } + } + + private void processType(Descriptor type, FileDescriptor fd) { + String typeName = type.getFullName(); + checkState( + !fileDescriptorsBySymbol.containsKey(typeName), "Type already defined: %s", typeName); + fileDescriptorsBySymbol.put(typeName, fd); + for (FieldDescriptor extension : type.getExtensions()) { + processExtension(extension, fd); + } + for (Descriptor nestedType : type.getNestedTypes()) { + processType(nestedType, fd); + } + } + + private void processExtension(FieldDescriptor extension, FileDescriptor fd) { + String extensionName = extension.getContainingType().getFullName(); + int extensionNumber = extension.getNumber(); + if (!fileDescriptorsByExtensionAndNumber.containsKey(extensionName)) { + fileDescriptorsByExtensionAndNumber.put( + extensionName, new HashMap()); + } + checkState( + !fileDescriptorsByExtensionAndNumber.get(extensionName).containsKey(extensionNumber), + "Extension name and number already defined: %s, %s", + extensionName, + extensionNumber); + fileDescriptorsByExtensionAndNumber.get(extensionName).put(extensionNumber, fd); + } + } +} diff --git a/services/src/main/proto/grpc/binlog/v1/binarylog.proto b/services/src/main/proto/grpc/binlog/v1/binarylog.proto index 9ed1733e2d8..b18bd88ddc9 100644 --- a/services/src/main/proto/grpc/binlog/v1/binarylog.proto +++ b/services/src/main/proto/grpc/binlog/v1/binarylog.proto @@ -120,7 +120,7 @@ message ClientHeader { // A single process may be used to run multiple virtual // servers with different identities. - // The authority is the name of such a server identitiy. + // The authority is the name of such a server identity. // It is typically a portion of the URI in the form of // or : . string authority = 3; diff --git a/services/src/main/proto/grpc/reflection/v1/reflection.proto b/services/src/main/proto/grpc/reflection/v1/reflection.proto new file mode 100644 index 00000000000..1a2ceedc3d2 --- /dev/null +++ b/services/src/main/proto/grpc/reflection/v1/reflection.proto @@ -0,0 +1,147 @@ +// Copyright 2016 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Service exported by server reflection. A more complete description of how +// server reflection works can be found at +// https://github.com/grpc/grpc/blob/master/doc/server-reflection.md +// +// The canonical version of this proto can be found at +// https://github.com/grpc/grpc-proto/blob/master/grpc/reflection/v1/reflection.proto + +syntax = "proto3"; + +package grpc.reflection.v1; + +option go_package = "google.golang.org/grpc/reflection/grpc_reflection_v1"; +option java_multiple_files = true; +option java_package = "io.grpc.reflection.v1"; +option java_outer_classname = "ServerReflectionProto"; + +service ServerReflection { + // The reflection service is structured as a bidirectional stream, ensuring + // all related requests go to a single server. + rpc ServerReflectionInfo(stream ServerReflectionRequest) + returns (stream ServerReflectionResponse); +} + +// The message sent by the client when calling ServerReflectionInfo method. +message ServerReflectionRequest { + string host = 1; + // To use reflection service, the client should set one of the following + // fields in message_request. The server distinguishes requests by their + // defined field and then handles them using corresponding methods. + oneof message_request { + // Find a proto file by the file name. + string file_by_filename = 3; + + // Find the proto file that declares the given fully-qualified symbol name. + // This field should be a fully-qualified symbol name + // (e.g. .[.] or .). + string file_containing_symbol = 4; + + // Find the proto file which defines an extension extending the given + // message type with the given field number. + ExtensionRequest file_containing_extension = 5; + + // Finds the tag numbers used by all known extensions of the given message + // type, and appends them to ExtensionNumberResponse in an undefined order. + // Its corresponding method is best-effort: it's not guaranteed that the + // reflection service will implement this method, and it's not guaranteed + // that this method will provide all extensions. Returns + // StatusCode::UNIMPLEMENTED if it's not implemented. + // This field should be a fully-qualified type name. The format is + // . + string all_extension_numbers_of_type = 6; + + // List the full names of registered services. The content will not be + // checked. + string list_services = 7; + } +} + +// The type name and extension number sent by the client when requesting +// file_containing_extension. +message ExtensionRequest { + // Fully-qualified type name. The format should be . + string containing_type = 1; + int32 extension_number = 2; +} + +// The message sent by the server to answer ServerReflectionInfo method. +message ServerReflectionResponse { + string valid_host = 1; + ServerReflectionRequest original_request = 2; + // The server sets one of the following fields according to the message_request + // in the request. + oneof message_response { + // This message is used to answer file_by_filename, file_containing_symbol, + // file_containing_extension requests with transitive dependencies. + // As the repeated label is not allowed in oneof fields, we use a + // FileDescriptorResponse message to encapsulate the repeated fields. + // The reflection service is allowed to avoid sending FileDescriptorProtos + // that were previously sent in response to earlier requests in the stream. + FileDescriptorResponse file_descriptor_response = 4; + + // This message is used to answer all_extension_numbers_of_type requests. + ExtensionNumberResponse all_extension_numbers_response = 5; + + // This message is used to answer list_services requests. + ListServiceResponse list_services_response = 6; + + // This message is used when an error occurs. + ErrorResponse error_response = 7; + } +} + +// Serialized FileDescriptorProto messages sent by the server answering +// a file_by_filename, file_containing_symbol, or file_containing_extension +// request. +message FileDescriptorResponse { + // Serialized FileDescriptorProto messages. We avoid taking a dependency on + // descriptor.proto, which uses proto2 only features, by making them opaque + // bytes instead. + repeated bytes file_descriptor_proto = 1; +} + +// A list of extension numbers sent by the server answering +// all_extension_numbers_of_type request. +message ExtensionNumberResponse { + // Full name of the base type, including the package name. The format + // is . + string base_type_name = 1; + repeated int32 extension_number = 2; +} + +// A list of ServiceResponse sent by the server answering list_services request. +message ListServiceResponse { + // The information of each service may be expanded in the future, so we use + // ServiceResponse message to encapsulate it. + repeated ServiceResponse service = 1; +} + +// The information of a single service used by ListServiceResponse to answer +// list_services request. +message ServiceResponse { + // Full name of a registered service, including its package name. The format + // is . + string name = 1; +} + +// The error code and error message sent by the server when an error occurs. +message ErrorResponse { + // This field uses the error codes defined in grpc::StatusCode. + int32 error_code = 1; + string error_message = 2; +} + diff --git a/services/src/main/proto/grpc/reflection/v1alpha/reflection.proto b/services/src/main/proto/grpc/reflection/v1alpha/reflection.proto index 8c5e06fe148..a3984b55c2d 100644 --- a/services/src/main/proto/grpc/reflection/v1alpha/reflection.proto +++ b/services/src/main/proto/grpc/reflection/v1alpha/reflection.proto @@ -80,7 +80,7 @@ message ExtensionRequest { message ServerReflectionResponse { string valid_host = 1; ServerReflectionRequest original_request = 2; - // The server set one of the following fields accroding to the message_request + // The server set one of the following fields according to the message_request // in the request. oneof message_response { // This message is used to answer file_by_filename, file_containing_symbol, @@ -91,7 +91,7 @@ message ServerReflectionResponse { // that were previously sent in response to earlier requests in the stream. FileDescriptorResponse file_descriptor_response = 4; - // This message is used to answer all_extension_numbers_of_type requst. + // This message is used to answer all_extension_numbers_of_type request. ExtensionNumberResponse all_extension_numbers_response = 5; // This message is used to answer list_services request. diff --git a/services/src/test/java/io/grpc/protobuf/services/BinaryLogProviderTest.java b/services/src/test/java/io/grpc/protobuf/services/BinaryLogProviderTest.java index 2d2b7651c0a..67b187e9d7a 100644 --- a/services/src/test/java/io/grpc/protobuf/services/BinaryLogProviderTest.java +++ b/services/src/test/java/io/grpc/protobuf/services/BinaryLogProviderTest.java @@ -16,8 +16,8 @@ package io.grpc.protobuf.services; -import static com.google.common.base.Charsets.UTF_8; import static com.google.common.truth.Truth.assertThat; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; diff --git a/services/src/test/java/io/grpc/protobuf/services/BinlogHelperTest.java b/services/src/test/java/io/grpc/protobuf/services/BinlogHelperTest.java index 856029bc750..ca42c1691b1 100644 --- a/services/src/test/java/io/grpc/protobuf/services/BinlogHelperTest.java +++ b/services/src/test/java/io/grpc/protobuf/services/BinlogHelperTest.java @@ -888,7 +888,7 @@ public void logRpcMessage() throws Exception { verify(sink).write(base); } - // server messsage + // server message { sinkWriterImpl.logRpcMessage( seq, @@ -1433,16 +1433,16 @@ public ServerCall.Listener startCall( // send server header { - Metadata serverInital = new Metadata(); - interceptedCall.get().sendHeaders(serverInital); + Metadata serverInitial = new Metadata(); + interceptedCall.get().sendHeaders(serverInitial); verify(mockSinkWriter).logServerHeader( /*seq=*/ eq(2L), - same(serverInital), + same(serverInitial), eq(Logger.LOGGER_SERVER), eq(CALL_ID), ArgumentMatchers.isNull()); verifyNoMoreInteractions(mockSinkWriter); - assertSame(serverInital, actualServerInitial.get()); + assertSame(serverInitial, actualServerInitial.get()); } // receive client msg diff --git a/services/src/test/java/io/grpc/protobuf/services/ChannelzProtoUtilTest.java b/services/src/test/java/io/grpc/protobuf/services/ChannelzProtoUtilTest.java index 4098885fd0d..598a8625e58 100644 --- a/services/src/test/java/io/grpc/protobuf/services/ChannelzProtoUtilTest.java +++ b/services/src/test/java/io/grpc/protobuf/services/ChannelzProtoUtilTest.java @@ -22,13 +22,12 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import com.google.common.base.Charsets; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.protobuf.Any; import com.google.protobuf.ByteString; import com.google.protobuf.Int64Value; -import com.google.protobuf.Message; +import com.google.protobuf.MessageLite; import com.google.protobuf.util.Durations; import com.google.protobuf.util.Timestamps; import io.grpc.ConnectivityState; @@ -82,6 +81,7 @@ import java.net.Inet4Address; import java.net.InetSocketAddress; import java.net.SocketAddress; +import java.nio.charset.StandardCharsets; import java.security.cert.Certificate; import java.util.Arrays; import java.util.Collections; @@ -154,33 +154,44 @@ public final class ChannelzProtoUtilTest { .setData(serverData) .build(); - private final SocketOption sockOptLingerDisabled = SocketOption - .newBuilder() - .setName("SO_LINGER") - .setAdditional( - Any.pack(SocketOptionLinger.getDefaultInstance())) - .build(); - - private final SocketOption sockOptlinger10s = SocketOption - .newBuilder() - .setName("SO_LINGER") - .setAdditional( - Any.pack(SocketOptionLinger - .newBuilder() - .setActive(true) - .setDuration(Durations.fromSeconds(10)) - .build())) - .build(); - - private final SocketOption sockOptTimeout200ms = SocketOption - .newBuilder() - .setName("SO_TIMEOUT") - .setAdditional( - Any.pack(SocketOptionTimeout - .newBuilder() - .setDuration(Durations.fromMillis(200)) - .build()) - ).build(); + private final SocketOption sockOptLingerDisabled = + SocketOption.newBuilder() + .setName("SO_LINGER") + .setAdditional( + Any.newBuilder() + .setTypeUrl("type.googleapis.com/grpc.channelz.v1.SocketOptionLinger") + .setValue(SocketOptionLinger.getDefaultInstance().toByteString()) + .build()) + .build(); + + private final SocketOption sockOptlinger10s = + SocketOption.newBuilder() + .setName("SO_LINGER") + .setAdditional( + Any.newBuilder() + .setTypeUrl("type.googleapis.com/grpc.channelz.v1.SocketOptionLinger") + .setValue( + SocketOptionLinger.newBuilder() + .setActive(true) + .setDuration(Durations.fromSeconds(10)) + .build() + .toByteString()) + .build()) + .build(); + + private final SocketOption sockOptTimeout200ms = + SocketOption.newBuilder() + .setName("SO_TIMEOUT") + .setAdditional( + Any.newBuilder() + .setTypeUrl("type.googleapis.com/grpc.channelz.v1.SocketOptionTimeout") + .setValue( + SocketOptionTimeout.newBuilder() + .setDuration(Durations.fromMillis(200)) + .build() + .toByteString()) + .build()) + .build(); private final SocketOption sockOptAdditional = SocketOption .newBuilder() @@ -221,43 +232,46 @@ public final class ChannelzProtoUtilTest { .setReordering(728) .build(); - private final SocketOption socketOptionTcpInfo = SocketOption - .newBuilder() - .setName("TCP_INFO") - .setAdditional( - Any.pack( - SocketOptionTcpInfo.newBuilder() - .setTcpiState(70) - .setTcpiCaState(71) - .setTcpiRetransmits(72) - .setTcpiProbes(73) - .setTcpiBackoff(74) - .setTcpiOptions(75) - .setTcpiSndWscale(76) - .setTcpiRcvWscale(77) - .setTcpiRto(78) - .setTcpiAto(79) - .setTcpiSndMss(710) - .setTcpiRcvMss(711) - .setTcpiUnacked(712) - .setTcpiSacked(713) - .setTcpiLost(714) - .setTcpiRetrans(715) - .setTcpiFackets(716) - .setTcpiLastDataSent(717) - .setTcpiLastAckSent(718) - .setTcpiLastDataRecv(719) - .setTcpiLastAckRecv(720) - .setTcpiPmtu(721) - .setTcpiRcvSsthresh(722) - .setTcpiRtt(723) - .setTcpiRttvar(724) - .setTcpiSndSsthresh(725) - .setTcpiSndCwnd(726) - .setTcpiAdvmss(727) - .setTcpiReordering(728) - .build())) - .build(); + private final SocketOption socketOptionTcpInfo = + SocketOption.newBuilder() + .setName("TCP_INFO") + .setAdditional( + Any.newBuilder() + .setTypeUrl("type.googleapis.com/grpc.channelz.v1.SocketOptionTcpInfo") + .setValue( + SocketOptionTcpInfo.newBuilder() + .setTcpiState(70) + .setTcpiCaState(71) + .setTcpiRetransmits(72) + .setTcpiProbes(73) + .setTcpiBackoff(74) + .setTcpiOptions(75) + .setTcpiSndWscale(76) + .setTcpiRcvWscale(77) + .setTcpiRto(78) + .setTcpiAto(79) + .setTcpiSndMss(710) + .setTcpiRcvMss(711) + .setTcpiUnacked(712) + .setTcpiSacked(713) + .setTcpiLost(714) + .setTcpiRetrans(715) + .setTcpiFackets(716) + .setTcpiLastDataSent(717) + .setTcpiLastAckSent(718) + .setTcpiLastDataRecv(719) + .setTcpiLastAckRecv(720) + .setTcpiPmtu(721) + .setTcpiRcvSsthresh(722) + .setTcpiRtt(723) + .setTcpiRttvar(724) + .setTcpiSndSsthresh(725) + .setTcpiSndCwnd(726) + .setTcpiAdvmss(727) + .setTcpiReordering(728) + .build() + .toByteString())) + .build(); private final TestListenSocket listenSocket = new TestListenSocket(); private final SocketRef listenSocketRef = SocketRef @@ -336,6 +350,16 @@ public void toServerRef() { assertEquals(serverRef, ChannelzProtoUtil.toServerRef(server)); } + @Test + public void toSeverity() { + for (Severity severity : Severity.values()) { + assertEquals( + severity.name(), + ChannelzProtoUtil.toSeverity(severity).name()); // OK because test isn't proguarded. + } + assertEquals(ChannelTraceEvent.Severity.CT_UNKNOWN, ChannelzProtoUtil.toSeverity(null)); + } + @Test public void toSocketRef() { assertEquals(socketRef, ChannelzProtoUtil.toSocketRef(socket)); @@ -346,7 +370,7 @@ public void toState() { for (ConnectivityState connectivityState : ConnectivityState.values()) { assertEquals( connectivityState.name(), - ChannelzProtoUtil.toState(connectivityState).getValueDescriptor().getName()); + ChannelzProtoUtil.toState(connectivityState).name()); // OK because test isn't proguarded. } assertEquals(State.UNKNOWN, ChannelzProtoUtil.toState(null)); } @@ -437,8 +461,8 @@ public void toSocketData() throws Exception { public void socketSecurityTls() throws Exception { Certificate local = mock(Certificate.class); Certificate remote = mock(Certificate.class); - when(local.getEncoded()).thenReturn("localcert".getBytes(Charsets.UTF_8)); - when(remote.getEncoded()).thenReturn("remotecert".getBytes(Charsets.UTF_8)); + when(local.getEncoded()).thenReturn("localcert".getBytes(StandardCharsets.UTF_8)); + when(remote.getEncoded()).thenReturn("remotecert".getBytes(StandardCharsets.UTF_8)); socket.security = new InternalChannelz.Security( new InternalChannelz.Tls("TLS_NULL_WITH_NULL_NULL", local, remote)); @@ -446,8 +470,8 @@ public void socketSecurityTls() throws Exception { Security.newBuilder().setTls( Tls.newBuilder() .setStandardName("TLS_NULL_WITH_NULL_NULL") - .setLocalCertificate(ByteString.copyFrom("localcert", Charsets.UTF_8)) - .setRemoteCertificate(ByteString.copyFrom("remotecert", Charsets.UTF_8))) + .setLocalCertificate(ByteString.copyFrom("localcert", StandardCharsets.UTF_8)) + .setRemoteCertificate(ByteString.copyFrom("remotecert", StandardCharsets.UTF_8))) .build(), ChannelzProtoUtil.toSocket(socket).getSecurity()); @@ -457,7 +481,7 @@ public void socketSecurityTls() throws Exception { Security.newBuilder().setTls( Tls.newBuilder() .setStandardName("TLS_NULL_WITH_NULL_NULL") - .setRemoteCertificate(ByteString.copyFrom("remotecert", Charsets.UTF_8))) + .setRemoteCertificate(ByteString.copyFrom("remotecert", StandardCharsets.UTF_8))) .build(), ChannelzProtoUtil.toSocket(socket).getSecurity()); @@ -467,7 +491,7 @@ public void socketSecurityTls() throws Exception { Security.newBuilder().setTls( Tls.newBuilder() .setStandardName("TLS_NULL_WITH_NULL_NULL") - .setLocalCertificate(ByteString.copyFrom("localcert", Charsets.UTF_8))) + .setLocalCertificate(ByteString.copyFrom("localcert", StandardCharsets.UTF_8))) .build(), ChannelzProtoUtil.toSocket(socket).getSecurity()); } @@ -475,8 +499,12 @@ public void socketSecurityTls() throws Exception { @Test public void socketSecurityOther() throws Exception { // what is packed here is not important, just pick some proto message - Message contents = GetChannelRequest.newBuilder().setChannelId(1).build(); - Any packed = Any.pack(contents); + MessageLite contents = GetChannelRequest.newBuilder().setChannelId(1).build(); + Any packed = + Any.newBuilder() + .setTypeUrl("type.googleapis.com/grpc.channelz.v1.GetChannelRequest") + .setValue(contents.toByteString()) + .build(); socket.security = new InternalChannelz.Security( new InternalChannelz.OtherSecurity("other_security", packed)); diff --git a/services/src/test/java/io/grpc/protobuf/services/HealthCheckingLoadBalancerFactoryTest.java b/services/src/test/java/io/grpc/protobuf/services/HealthCheckingLoadBalancerFactoryTest.java index 4787010ebe1..a49c426f7e1 100644 --- a/services/src/test/java/io/grpc/protobuf/services/HealthCheckingLoadBalancerFactoryTest.java +++ b/services/src/test/java/io/grpc/protobuf/services/HealthCheckingLoadBalancerFactoryTest.java @@ -206,15 +206,16 @@ public void setup() throws Exception { boolean shutdown; @Override - public void handleResolvedAddresses(final ResolvedAddresses resolvedAddresses) { + public Status acceptResolvedAddresses(final ResolvedAddresses resolvedAddresses) { syncContext.execute(new Runnable() { @Override public void run() { if (!shutdown) { - hcLb.handleResolvedAddresses(resolvedAddresses); + hcLb.acceptResolvedAddresses(resolvedAddresses); } } }); + return Status.OK; } @Override @@ -264,16 +265,16 @@ public void typicalWorkflow() { .setAddresses(resolvedAddressList) .setAttributes(resolutionAttrs) .build(); - hcLbEventDelivery.handleResolvedAddresses(result); + hcLbEventDelivery.acceptResolvedAddresses(result); - verify(origLb).handleResolvedAddresses(result); + verify(origLb).acceptResolvedAddresses(result); verify(origHelper, atLeast(0)).getSynchronizationContext(); verify(origHelper, atLeast(0)).getScheduledExecutorService(); verifyNoMoreInteractions(origHelper); verifyNoMoreInteractions(origLb); Subchannel[] wrappedSubchannels = new Subchannel[NUM_SUBCHANNELS]; - // Simulate that the orignal LB creates Subchannels + // Simulate that the original LB creates Subchannels for (int i = 0; i < NUM_SUBCHANNELS; i++) { // Subchannel attributes set by origLb are correctly plumbed in String subchannelAttrValue = "eag attr " + i; @@ -404,9 +405,9 @@ public void healthCheckDisabledWhenServiceNotImplemented() { .setAddresses(resolvedAddressList) .setAttributes(resolutionAttrs) .build(); - hcLbEventDelivery.handleResolvedAddresses(result); + hcLbEventDelivery.acceptResolvedAddresses(result); - verify(origLb).handleResolvedAddresses(result); + verify(origLb).acceptResolvedAddresses(result); verifyNoMoreInteractions(origLb); // We create 2 Subchannels. One of them connects to a server that doesn't implement health check @@ -489,9 +490,9 @@ public void backoffRetriesWhenServerErroneouslyClosesRpcBeforeAnyResponse() { .setAddresses(resolvedAddressList) .setAttributes(resolutionAttrs) .build(); - hcLbEventDelivery.handleResolvedAddresses(result); + hcLbEventDelivery.acceptResolvedAddresses(result); - verify(origLb).handleResolvedAddresses(result); + verify(origLb).acceptResolvedAddresses(result); verifyNoMoreInteractions(origLb); SubchannelStateListener mockHealthListener = mockHealthListeners[0]; @@ -567,9 +568,9 @@ public void serverRespondResetsBackoff() { .setAddresses(resolvedAddressList) .setAttributes(resolutionAttrs) .build(); - hcLbEventDelivery.handleResolvedAddresses(result); + hcLbEventDelivery.acceptResolvedAddresses(result); - verify(origLb).handleResolvedAddresses(result); + verify(origLb).acceptResolvedAddresses(result); verifyNoMoreInteractions(origLb); SubchannelStateListener mockStateListener = mockStateListeners[0]; @@ -667,9 +668,9 @@ public void serviceConfigHasNoHealthCheckingInitiallyButDoesLater() { .setAddresses(resolvedAddressList) .setAttributes(Attributes.EMPTY) .build(); - hcLbEventDelivery.handleResolvedAddresses(result1); + hcLbEventDelivery.acceptResolvedAddresses(result1); - verify(origLb).handleResolvedAddresses(result1); + verify(origLb).acceptResolvedAddresses(result1); verifyNoMoreInteractions(origLb); // First, create Subchannels 0 @@ -688,8 +689,8 @@ public void serviceConfigHasNoHealthCheckingInitiallyButDoesLater() { .setAddresses(resolvedAddressList) .setAttributes(resolutionAttrs) .build(); - hcLbEventDelivery.handleResolvedAddresses(result2); - verify(origLb).handleResolvedAddresses(result2); + hcLbEventDelivery.acceptResolvedAddresses(result2); + verify(origLb).acceptResolvedAddresses(result2); // Health check started on existing Subchannel assertThat(healthImpls[0].calls).hasSize(1); @@ -711,9 +712,9 @@ public void serviceConfigDisablesHealthCheckWhenRpcActive() { .setAddresses(resolvedAddressList) .setAttributes(resolutionAttrs) .build(); - hcLbEventDelivery.handleResolvedAddresses(result1); + hcLbEventDelivery.acceptResolvedAddresses(result1); - verify(origLb).handleResolvedAddresses(result1); + verify(origLb).acceptResolvedAddresses(result1); verifyNoMoreInteractions(origLb); Subchannel subchannel = createSubchannel(0, Attributes.EMPTY, maybeGetMockListener()); @@ -738,7 +739,7 @@ public void serviceConfigDisablesHealthCheckWhenRpcActive() { .setAddresses(resolvedAddressList) .setAttributes(Attributes.EMPTY) .build(); - hcLbEventDelivery.handleResolvedAddresses(result2); + hcLbEventDelivery.acceptResolvedAddresses(result2); // Health check RPC cancelled. assertThat(serverCall.cancelled).isTrue(); @@ -746,7 +747,7 @@ public void serviceConfigDisablesHealthCheckWhenRpcActive() { inOrder.verify(getMockListener()).onSubchannelState( eq(ConnectivityStateInfo.forNonError(READY))); - inOrder.verify(origLb).handleResolvedAddresses(result2); + inOrder.verify(origLb).acceptResolvedAddresses(result2); verifyNoMoreInteractions(origLb, mockStateListeners[0]); assertThat(healthImpl.calls).isEmpty(); @@ -759,9 +760,9 @@ public void serviceConfigDisablesHealthCheckWhenRetryPending() { .setAddresses(resolvedAddressList) .setAttributes(resolutionAttrs) .build(); - hcLbEventDelivery.handleResolvedAddresses(result); + hcLbEventDelivery.acceptResolvedAddresses(result); - verify(origLb).handleResolvedAddresses(result); + verify(origLb).acceptResolvedAddresses(result); verifyNoMoreInteractions(origLb); SubchannelStateListener mockHealthListener = mockHealthListeners[0]; @@ -793,7 +794,7 @@ public void serviceConfigDisablesHealthCheckWhenRetryPending() { .setAddresses(resolvedAddressList) .setAttributes(Attributes.EMPTY) .build(); - hcLbEventDelivery.handleResolvedAddresses(result2); + hcLbEventDelivery.acceptResolvedAddresses(result2); // Retry timer is cancelled assertThat(clock.getPendingTasks()).isEmpty(); @@ -805,7 +806,7 @@ public void serviceConfigDisablesHealthCheckWhenRetryPending() { inOrder.verify(getMockListener()).onSubchannelState( eq(ConnectivityStateInfo.forNonError(READY))); - inOrder.verify(origLb).handleResolvedAddresses(result2); + inOrder.verify(origLb).acceptResolvedAddresses(result2); verifyNoMoreInteractions(origLb, mockStateListeners[0]); } @@ -817,9 +818,9 @@ public void serviceConfigDisablesHealthCheckWhenRpcInactive() { .setAddresses(resolvedAddressList) .setAttributes(resolutionAttrs) .build(); - hcLbEventDelivery.handleResolvedAddresses(result1); + hcLbEventDelivery.acceptResolvedAddresses(result1); - verify(origLb).handleResolvedAddresses(result1); + verify(origLb).acceptResolvedAddresses(result1); verifyNoMoreInteractions(origLb); Subchannel subchannel = createSubchannel(0, Attributes.EMPTY, maybeGetMockListener()); @@ -842,9 +843,9 @@ public void serviceConfigDisablesHealthCheckWhenRpcInactive() { .setAddresses(resolvedAddressList) .setAttributes(Attributes.EMPTY) .build(); - hcLbEventDelivery.handleResolvedAddresses(result2); + hcLbEventDelivery.acceptResolvedAddresses(result2); - inOrder.verify(origLb).handleResolvedAddresses(result2); + inOrder.verify(origLb).acceptResolvedAddresses(result2); // Underlying subchannel is now ready deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); @@ -870,9 +871,9 @@ public void serviceConfigChangesServiceNameWhenRpcActive() { .setAddresses(resolvedAddressList) .setAttributes(resolutionAttrs) .build(); - hcLbEventDelivery.handleResolvedAddresses(result1); + hcLbEventDelivery.acceptResolvedAddresses(result1); - verify(origLb).handleResolvedAddresses(result1); + verify(origLb).acceptResolvedAddresses(result1); verifyNoMoreInteractions(origLb); SubchannelStateListener mockHealthListener = mockHealthListeners[0]; @@ -900,9 +901,9 @@ public void serviceConfigChangesServiceNameWhenRpcActive() { eq(ConnectivityStateInfo.forNonError(READY))); // Service config returns with the same health check name. - hcLbEventDelivery.handleResolvedAddresses(result1); + hcLbEventDelivery.acceptResolvedAddresses(result1); // It's delivered to origLb, but nothing else happens - inOrder.verify(origLb).handleResolvedAddresses(result1); + inOrder.verify(origLb).acceptResolvedAddresses(result1); verifyNoMoreInteractions(origLb, mockListener); // Service config returns a different health check name. @@ -911,8 +912,8 @@ public void serviceConfigChangesServiceNameWhenRpcActive() { .setAddresses(resolvedAddressList) .setAttributes(resolutionAttrs) .build(); - hcLbEventDelivery.handleResolvedAddresses(result2); - inOrder.verify(origLb).handleResolvedAddresses(result2); + hcLbEventDelivery.acceptResolvedAddresses(result2); + inOrder.verify(origLb).acceptResolvedAddresses(result2); // Current health check RPC cancelled. assertThat(serverCall.cancelled).isTrue(); @@ -934,9 +935,9 @@ public void serviceConfigChangesServiceNameWhenRetryPending() { .setAddresses(resolvedAddressList) .setAttributes(resolutionAttrs) .build(); - hcLbEventDelivery.handleResolvedAddresses(result1); + hcLbEventDelivery.acceptResolvedAddresses(result1); - verify(origLb).handleResolvedAddresses(result1); + verify(origLb).acceptResolvedAddresses(result1); verifyNoMoreInteractions(origLb); SubchannelStateListener mockHealthListener = mockHealthListeners[0]; @@ -969,9 +970,9 @@ public void serviceConfigChangesServiceNameWhenRetryPending() { // Service config returns with the same health check name. - hcLbEventDelivery.handleResolvedAddresses(result1); + hcLbEventDelivery.acceptResolvedAddresses(result1); // It's delivered to origLb, but nothing else happens - inOrder.verify(origLb).handleResolvedAddresses(result1); + inOrder.verify(origLb).acceptResolvedAddresses(result1); verifyNoMoreInteractions(origLb, mockListener); assertThat(clock.getPendingTasks()).hasSize(1); assertThat(healthImpl.calls).isEmpty(); @@ -982,12 +983,12 @@ public void serviceConfigChangesServiceNameWhenRetryPending() { .setAddresses(resolvedAddressList) .setAttributes(resolutionAttrs) .build(); - hcLbEventDelivery.handleResolvedAddresses(result2); + hcLbEventDelivery.acceptResolvedAddresses(result2); // Concluded CONNECTING state inOrder.verify(getMockListener()).onSubchannelState( eq(ConnectivityStateInfo.forNonError(CONNECTING))); - inOrder.verify(origLb).handleResolvedAddresses(result2); + inOrder.verify(origLb).acceptResolvedAddresses(result2); // Current retry timer cancelled assertThat(clock.getPendingTasks()).isEmpty(); @@ -1008,9 +1009,9 @@ public void serviceConfigChangesServiceNameWhenRpcInactive() { .setAddresses(resolvedAddressList) .setAttributes(resolutionAttrs) .build(); - hcLbEventDelivery.handleResolvedAddresses(result1); + hcLbEventDelivery.acceptResolvedAddresses(result1); - verify(origLb).handleResolvedAddresses(result1); + verify(origLb).acceptResolvedAddresses(result1); verifyNoMoreInteractions(origLb); Subchannel subchannel = createSubchannel(0, Attributes.EMPTY, maybeGetMockListener()); @@ -1031,9 +1032,9 @@ public void serviceConfigChangesServiceNameWhenRpcInactive() { inOrder.verifyNoMoreInteractions(); // Service config returns with the same health check name. - hcLbEventDelivery.handleResolvedAddresses(result1); + hcLbEventDelivery.acceptResolvedAddresses(result1); // It's delivered to origLb, but nothing else happens - inOrder.verify(origLb).handleResolvedAddresses(result1); + inOrder.verify(origLb).acceptResolvedAddresses(result1); assertThat(healthImpl.calls).isEmpty(); verifyNoMoreInteractions(origLb); @@ -1043,9 +1044,9 @@ public void serviceConfigChangesServiceNameWhenRpcInactive() { .setAddresses(resolvedAddressList) .setAttributes(resolutionAttrs) .build(); - hcLbEventDelivery.handleResolvedAddresses(result2); + hcLbEventDelivery.acceptResolvedAddresses(result2); - inOrder.verify(origLb).handleResolvedAddresses(result2); + inOrder.verify(origLb).acceptResolvedAddresses(result2); // Underlying subchannel is now ready deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); @@ -1092,9 +1093,9 @@ public void balancerShutdown() { .setAddresses(resolvedAddressList) .setAttributes(resolutionAttrs) .build(); - hcLbEventDelivery.handleResolvedAddresses(result); + hcLbEventDelivery.acceptResolvedAddresses(result); - verify(origLb).handleResolvedAddresses(result); + verify(origLb).acceptResolvedAddresses(result); verifyNoMoreInteractions(origLb); ServerSideCall[] serverCalls = new ServerSideCall[NUM_SUBCHANNELS]; @@ -1172,8 +1173,8 @@ public LoadBalancer newLoadBalancer(Helper helper) { .setAddresses(resolvedAddressList) .setAttributes(resolutionAttrs) .build(); - hcLbEventDelivery.handleResolvedAddresses(result); - verify(origLb).handleResolvedAddresses(result); + hcLbEventDelivery.acceptResolvedAddresses(result); + verify(origLb).acceptResolvedAddresses(result); createSubchannel(0, Attributes.EMPTY); assertThat(healthImpls[0].calls).isEmpty(); deliverSubchannelState(0, ConnectivityStateInfo.forNonError(READY)); diff --git a/services/src/test/java/io/grpc/protobuf/services/HealthStatusManagerTest.java b/services/src/test/java/io/grpc/protobuf/services/HealthStatusManagerTest.java index 87d4ac29be8..b2652e92771 100644 --- a/services/src/test/java/io/grpc/protobuf/services/HealthStatusManagerTest.java +++ b/services/src/test/java/io/grpc/protobuf/services/HealthStatusManagerTest.java @@ -18,6 +18,11 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import io.grpc.BindableService; import io.grpc.Context; @@ -28,6 +33,7 @@ import io.grpc.health.v1.HealthCheckResponse; import io.grpc.health.v1.HealthCheckResponse.ServingStatus; import io.grpc.health.v1.HealthGrpc; +import io.grpc.stub.ServerCallStreamObserver; import io.grpc.stub.StreamObserver; import io.grpc.testing.GrpcServerRule; import java.util.ArrayDeque; @@ -109,6 +115,18 @@ public void enterTerminalState_watch() throws Exception { assertThat(obs.responses).isEmpty(); } + @Test + @SuppressWarnings("unchecked") + public void serverCallStreamObserver_watch() throws Exception { + manager.setStatus(SERVICE1, ServingStatus.SERVING); + ServerCallStreamObserver observer = mock(ServerCallStreamObserver.class); + service.watch(HealthCheckRequest.newBuilder().setService(SERVICE1).build(), observer); + + verify(observer, times(1)) + .onNext(eq(HealthCheckResponse.newBuilder().setStatus(ServingStatus.SERVING).build())); + verify(observer, times(1)).setOnCancelHandler(any(Runnable.class)); + } + @Test public void enterTerminalState_ignoreClear() throws Exception { manager.setStatus(SERVICE1, ServingStatus.SERVING); diff --git a/services/src/test/java/io/grpc/protobuf/services/ProtoReflectionServiceTest.java b/services/src/test/java/io/grpc/protobuf/services/ProtoReflectionServiceTest.java index c9dd1014141..115dd11b0f1 100644 --- a/services/src/test/java/io/grpc/protobuf/services/ProtoReflectionServiceTest.java +++ b/services/src/test/java/io/grpc/protobuf/services/ProtoReflectionServiceTest.java @@ -71,7 +71,8 @@ public class ProtoReflectionServiceTest { private static final String TEST_HOST = "localhost"; private MutableHandlerRegistry handlerRegistry = new MutableHandlerRegistry(); - private BindableService reflectionService; + @SuppressWarnings("deprecation") + private BindableService reflectionService = ProtoReflectionService.newInstance(); private ServerServiceDefinition dynamicService = new DynamicServiceGrpc.DynamicServiceImplBase() {}.bindService(); private ServerServiceDefinition anotherDynamicService = @@ -80,7 +81,6 @@ public class ProtoReflectionServiceTest { @Before public void setUp() throws Exception { - reflectionService = ProtoReflectionService.newInstance(); Server server = InProcessServerBuilder.forName("proto-reflection-test") .directExecutor() diff --git a/services/src/test/java/io/grpc/protobuf/services/ProtoReflectionServiceV1Test.java b/services/src/test/java/io/grpc/protobuf/services/ProtoReflectionServiceV1Test.java new file mode 100644 index 00000000000..47bd3e792ad --- /dev/null +++ b/services/src/test/java/io/grpc/protobuf/services/ProtoReflectionServiceV1Test.java @@ -0,0 +1,670 @@ +/* + * Copyright 2016 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.protobuf.services; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import com.google.protobuf.ByteString; +import io.grpc.BindableService; +import io.grpc.ManagedChannel; +import io.grpc.Server; +import io.grpc.ServerServiceDefinition; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.internal.testing.StreamRecorder; +import io.grpc.reflection.testing.AnotherDynamicServiceGrpc; +import io.grpc.reflection.testing.AnotherReflectableServiceGrpc; +import io.grpc.reflection.testing.DynamicReflectionTestDepthTwoProto; +import io.grpc.reflection.testing.DynamicServiceGrpc; +import io.grpc.reflection.testing.ReflectableServiceGrpc; +import io.grpc.reflection.testing.ReflectionTestDepthThreeProto; +import io.grpc.reflection.testing.ReflectionTestDepthTwoAlternateProto; +import io.grpc.reflection.testing.ReflectionTestDepthTwoProto; +import io.grpc.reflection.testing.ReflectionTestProto; +import io.grpc.reflection.v1.ExtensionNumberResponse; +import io.grpc.reflection.v1.ExtensionRequest; +import io.grpc.reflection.v1.FileDescriptorResponse; +import io.grpc.reflection.v1.ServerReflectionGrpc; +import io.grpc.reflection.v1.ServerReflectionRequest; +import io.grpc.reflection.v1.ServerReflectionResponse; +import io.grpc.reflection.v1.ServiceResponse; +import io.grpc.stub.ClientCallStreamObserver; +import io.grpc.stub.ClientResponseObserver; +import io.grpc.stub.StreamObserver; +import io.grpc.testing.GrpcCleanupRule; +import io.grpc.util.MutableHandlerRegistry; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.ExecutionException; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link ProtoReflectionServiceV1}. */ +@RunWith(JUnit4.class) +public class ProtoReflectionServiceV1Test { + @Rule + public GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule(); + + private static final String TEST_HOST = "localhost"; + private MutableHandlerRegistry handlerRegistry = new MutableHandlerRegistry(); + private BindableService reflectionService; + private ServerServiceDefinition dynamicService = + new DynamicServiceGrpc.DynamicServiceImplBase() {}.bindService(); + private ServerServiceDefinition anotherDynamicService = + new AnotherDynamicServiceGrpc.AnotherDynamicServiceImplBase() {}.bindService(); + private ServerReflectionGrpc.ServerReflectionStub stub; + + @Before + public void setUp() throws Exception { + reflectionService = ProtoReflectionServiceV1.newInstance(); + Server server = + InProcessServerBuilder.forName("proto-reflection-test") + .directExecutor() + .addService(reflectionService) + .addService(new ReflectableServiceGrpc.ReflectableServiceImplBase() {}) + .fallbackHandlerRegistry(handlerRegistry) + .build() + .start(); + grpcCleanupRule.register(server); + ManagedChannel channel = + grpcCleanupRule.register( + InProcessChannelBuilder.forName("proto-reflection-test").directExecutor().build()); + stub = ServerReflectionGrpc.newStub(channel); + } + + @Test + public void listServices() throws Exception { + Set originalServices = + new HashSet<>( + Arrays.asList( + ServiceResponse.newBuilder() + .setName("grpc.reflection.v1.ServerReflection") + .build(), + ServiceResponse.newBuilder() + .setName("grpc.reflection.testing.ReflectableService") + .build())); + assertServiceResponseEquals(originalServices); + + handlerRegistry.addService(dynamicService); + assertServiceResponseEquals( + new HashSet<>( + Arrays.asList( + ServiceResponse.newBuilder() + .setName("grpc.reflection.v1.ServerReflection") + .build(), + ServiceResponse.newBuilder() + .setName("grpc.reflection.testing.ReflectableService") + .build(), + ServiceResponse.newBuilder() + .setName("grpc.reflection.testing.DynamicService") + .build()))); + + handlerRegistry.addService(anotherDynamicService); + assertServiceResponseEquals( + new HashSet<>( + Arrays.asList( + ServiceResponse.newBuilder() + .setName("grpc.reflection.v1.ServerReflection") + .build(), + ServiceResponse.newBuilder() + .setName("grpc.reflection.testing.ReflectableService") + .build(), + ServiceResponse.newBuilder() + .setName("grpc.reflection.testing.DynamicService") + .build(), + ServiceResponse.newBuilder() + .setName("grpc.reflection.testing.AnotherDynamicService") + .build()))); + + handlerRegistry.removeService(dynamicService); + assertServiceResponseEquals( + new HashSet<>( + Arrays.asList( + ServiceResponse.newBuilder() + .setName("grpc.reflection.v1.ServerReflection") + .build(), + ServiceResponse.newBuilder() + .setName("grpc.reflection.testing.ReflectableService") + .build(), + ServiceResponse.newBuilder() + .setName("grpc.reflection.testing.AnotherDynamicService") + .build()))); + + handlerRegistry.removeService(anotherDynamicService); + assertServiceResponseEquals(originalServices); + } + + @Test + public void fileByFilename() throws Exception { + ServerReflectionRequest request = + ServerReflectionRequest.newBuilder() + .setHost(TEST_HOST) + .setFileByFilename("io/grpc/reflection/testing/reflection_test_depth_three.proto") + .build(); + + ServerReflectionResponse goldenResponse = + ServerReflectionResponse.newBuilder() + .setValidHost(TEST_HOST) + .setOriginalRequest(request) + .setFileDescriptorResponse( + FileDescriptorResponse.newBuilder() + .addFileDescriptorProto( + ReflectionTestDepthThreeProto.getDescriptor().toProto().toByteString()) + .build()) + .build(); + + StreamRecorder responseObserver = StreamRecorder.create(); + StreamObserver requestObserver = + stub.serverReflectionInfo(responseObserver); + requestObserver.onNext(request); + requestObserver.onCompleted(); + + assertEquals(goldenResponse, responseObserver.firstValue().get()); + } + + @Test + public void fileByFilenameConsistentForMutableServices() throws Exception { + ServerReflectionRequest request = + ServerReflectionRequest.newBuilder() + .setHost(TEST_HOST) + .setFileByFilename("io/grpc/reflection/testing/dynamic_reflection_test_depth_two.proto") + .build(); + ServerReflectionResponse goldenResponse = + ServerReflectionResponse.newBuilder() + .setValidHost(TEST_HOST) + .setOriginalRequest(request) + .setFileDescriptorResponse( + FileDescriptorResponse.newBuilder() + .addFileDescriptorProto( + DynamicReflectionTestDepthTwoProto.getDescriptor().toProto().toByteString()) + .build()) + .build(); + + StreamRecorder responseObserver = StreamRecorder.create(); + StreamObserver requestObserver = + stub.serverReflectionInfo(responseObserver); + handlerRegistry.addService(dynamicService); + requestObserver.onNext(request); + requestObserver.onCompleted(); + StreamRecorder responseObserver2 = StreamRecorder.create(); + StreamObserver requestObserver2 = + stub.serverReflectionInfo(responseObserver2); + handlerRegistry.removeService(dynamicService); + requestObserver2.onNext(request); + requestObserver2.onCompleted(); + StreamRecorder responseObserver3 = StreamRecorder.create(); + StreamObserver requestObserver3 = + stub.serverReflectionInfo(responseObserver3); + requestObserver3.onNext(request); + requestObserver3.onCompleted(); + + assertEquals( + ServerReflectionResponse.MessageResponseCase.ERROR_RESPONSE, + responseObserver.firstValue().get().getMessageResponseCase()); + assertEquals(goldenResponse, responseObserver2.firstValue().get()); + assertEquals( + ServerReflectionResponse.MessageResponseCase.ERROR_RESPONSE, + responseObserver3.firstValue().get().getMessageResponseCase()); + } + + @Test + public void fileContainingSymbol() throws Exception { + ServerReflectionRequest request = + ServerReflectionRequest.newBuilder() + .setHost(TEST_HOST) + .setFileContainingSymbol("grpc.reflection.testing.ReflectableService.Method") + .build(); + + List goldenResponse = + Arrays.asList( + ReflectionTestProto.getDescriptor().toProto().toByteString(), + ReflectionTestDepthTwoProto.getDescriptor().toProto().toByteString(), + ReflectionTestDepthTwoAlternateProto.getDescriptor().toProto().toByteString(), + ReflectionTestDepthThreeProto.getDescriptor().toProto().toByteString()); + + StreamRecorder responseObserver = StreamRecorder.create(); + StreamObserver requestObserver = + stub.serverReflectionInfo(responseObserver); + requestObserver.onNext(request); + requestObserver.onCompleted(); + + List response = + responseObserver + .firstValue() + .get() + .getFileDescriptorResponse() + .getFileDescriptorProtoList(); + assertEquals(goldenResponse.size(), response.size()); + assertEquals(new HashSet<>(goldenResponse), new HashSet<>(response)); + } + + @Test + public void fileContainingNestedSymbol() throws Exception { + ServerReflectionRequest request = + ServerReflectionRequest.newBuilder() + .setHost(TEST_HOST) + .setFileContainingSymbol("grpc.reflection.testing.NestedTypeOuter.Middle.Inner") + .build(); + + ServerReflectionResponse goldenResponse = + ServerReflectionResponse.newBuilder() + .setValidHost(TEST_HOST) + .setOriginalRequest(request) + .setFileDescriptorResponse( + FileDescriptorResponse.newBuilder() + .addFileDescriptorProto( + ReflectionTestDepthThreeProto.getDescriptor().toProto().toByteString()) + .build()) + .build(); + + StreamRecorder responseObserver = StreamRecorder.create(); + StreamObserver requestObserver = + stub.serverReflectionInfo(responseObserver); + requestObserver.onNext(request); + requestObserver.onCompleted(); + assertEquals(goldenResponse, responseObserver.firstValue().get()); + } + + @Test + public void fileContainingSymbolForMutableServices() throws Exception { + ServerReflectionRequest request = + ServerReflectionRequest.newBuilder() + .setHost(TEST_HOST) + .setFileContainingSymbol("grpc.reflection.testing.DynamicRequest") + .build(); + ServerReflectionResponse goldenResponse = + ServerReflectionResponse.newBuilder() + .setValidHost(TEST_HOST) + .setOriginalRequest(request) + .setFileDescriptorResponse( + FileDescriptorResponse.newBuilder() + .addFileDescriptorProto( + DynamicReflectionTestDepthTwoProto.getDescriptor().toProto().toByteString()) + .build()) + .build(); + + StreamRecorder responseObserver = StreamRecorder.create(); + StreamObserver requestObserver = + stub.serverReflectionInfo(responseObserver); + handlerRegistry.addService(dynamicService); + requestObserver.onNext(request); + requestObserver.onCompleted(); + StreamRecorder responseObserver2 = StreamRecorder.create(); + StreamObserver requestObserver2 = + stub.serverReflectionInfo(responseObserver2); + handlerRegistry.removeService(dynamicService); + requestObserver2.onNext(request); + requestObserver2.onCompleted(); + StreamRecorder responseObserver3 = StreamRecorder.create(); + StreamObserver requestObserver3 = + stub.serverReflectionInfo(responseObserver3); + requestObserver3.onNext(request); + requestObserver3.onCompleted(); + + assertEquals( + ServerReflectionResponse.MessageResponseCase.ERROR_RESPONSE, + responseObserver.firstValue().get().getMessageResponseCase()); + assertEquals(goldenResponse, responseObserver2.firstValue().get()); + assertEquals( + ServerReflectionResponse.MessageResponseCase.ERROR_RESPONSE, + responseObserver3.firstValue().get().getMessageResponseCase()); + } + + @Test + public void fileContainingExtension() throws Exception { + ServerReflectionRequest request = + ServerReflectionRequest.newBuilder() + .setHost(TEST_HOST) + .setFileContainingExtension( + ExtensionRequest.newBuilder() + .setContainingType("grpc.reflection.testing.ThirdLevelType") + .setExtensionNumber(100) + .build()) + .build(); + + List goldenResponse = + Arrays.asList( + ReflectionTestProto.getDescriptor().toProto().toByteString(), + ReflectionTestDepthTwoProto.getDescriptor().toProto().toByteString(), + ReflectionTestDepthTwoAlternateProto.getDescriptor().toProto().toByteString(), + ReflectionTestDepthThreeProto.getDescriptor().toProto().toByteString()); + + StreamRecorder responseObserver = StreamRecorder.create(); + StreamObserver requestObserver = + stub.serverReflectionInfo(responseObserver); + requestObserver.onNext(request); + requestObserver.onCompleted(); + + List response = + responseObserver + .firstValue() + .get() + .getFileDescriptorResponse() + .getFileDescriptorProtoList(); + assertEquals(goldenResponse.size(), response.size()); + assertEquals(new HashSet<>(goldenResponse), new HashSet<>(response)); + } + + @Test + public void fileContainingNestedExtension() throws Exception { + ServerReflectionRequest request = + ServerReflectionRequest.newBuilder() + .setHost(TEST_HOST) + .setFileContainingExtension( + ExtensionRequest.newBuilder() + .setContainingType("grpc.reflection.testing.ThirdLevelType") + .setExtensionNumber(101) + .build()) + .build(); + + ServerReflectionResponse goldenResponse = + ServerReflectionResponse.newBuilder() + .setValidHost(TEST_HOST) + .setOriginalRequest(request) + .setFileDescriptorResponse( + FileDescriptorResponse.newBuilder() + .addFileDescriptorProto( + ReflectionTestDepthTwoProto.getDescriptor().toProto().toByteString()) + .addFileDescriptorProto( + ReflectionTestDepthThreeProto.getDescriptor().toProto().toByteString()) + .build()) + .build(); + + StreamRecorder responseObserver = StreamRecorder.create(); + StreamObserver requestObserver = + stub.serverReflectionInfo(responseObserver); + requestObserver.onNext(request); + requestObserver.onCompleted(); + assertEquals(goldenResponse, responseObserver.firstValue().get()); + } + + @Test + public void fileContainingExtensionForMutableServices() throws Exception { + ServerReflectionRequest request = + ServerReflectionRequest.newBuilder() + .setHost(TEST_HOST) + .setFileContainingExtension( + ExtensionRequest.newBuilder() + .setContainingType("grpc.reflection.testing.TypeWithExtensions") + .setExtensionNumber(200) + .build()) + .build(); + ServerReflectionResponse goldenResponse = + ServerReflectionResponse.newBuilder() + .setValidHost(TEST_HOST) + .setOriginalRequest(request) + .setFileDescriptorResponse( + FileDescriptorResponse.newBuilder() + .addFileDescriptorProto( + DynamicReflectionTestDepthTwoProto.getDescriptor().toProto().toByteString()) + .build()) + .build(); + + StreamRecorder responseObserver = StreamRecorder.create(); + StreamObserver requestObserver = + stub.serverReflectionInfo(responseObserver); + handlerRegistry.addService(dynamicService); + requestObserver.onNext(request); + requestObserver.onCompleted(); + StreamRecorder responseObserver2 = StreamRecorder.create(); + StreamObserver requestObserver2 = + stub.serverReflectionInfo(responseObserver2); + handlerRegistry.removeService(dynamicService); + requestObserver2.onNext(request); + requestObserver2.onCompleted(); + StreamRecorder responseObserver3 = StreamRecorder.create(); + StreamObserver requestObserver3 = + stub.serverReflectionInfo(responseObserver3); + requestObserver3.onNext(request); + requestObserver3.onCompleted(); + + assertEquals( + ServerReflectionResponse.MessageResponseCase.ERROR_RESPONSE, + responseObserver.firstValue().get().getMessageResponseCase()); + assertEquals(goldenResponse, responseObserver2.firstValue().get()); + assertEquals( + ServerReflectionResponse.MessageResponseCase.ERROR_RESPONSE, + responseObserver3.firstValue().get().getMessageResponseCase()); + } + + @Test + public void allExtensionNumbersOfType() throws Exception { + ServerReflectionRequest request = + ServerReflectionRequest.newBuilder() + .setHost(TEST_HOST) + .setAllExtensionNumbersOfType("grpc.reflection.testing.ThirdLevelType") + .build(); + + Set goldenResponse = new HashSet<>(Arrays.asList(100, 101)); + + StreamRecorder responseObserver = StreamRecorder.create(); + StreamObserver requestObserver = + stub.serverReflectionInfo(responseObserver); + requestObserver.onNext(request); + requestObserver.onCompleted(); + Set extensionNumberResponseSet = + new HashSet<>( + responseObserver + .firstValue() + .get() + .getAllExtensionNumbersResponse() + .getExtensionNumberList()); + assertEquals(goldenResponse, extensionNumberResponseSet); + } + + @Test + public void allExtensionNumbersOfTypeForMutableServices() throws Exception { + String type = "grpc.reflection.testing.TypeWithExtensions"; + ServerReflectionRequest request = + ServerReflectionRequest.newBuilder() + .setHost(TEST_HOST) + .setAllExtensionNumbersOfType(type) + .build(); + ServerReflectionResponse goldenResponse = + ServerReflectionResponse.newBuilder() + .setValidHost(TEST_HOST) + .setOriginalRequest(request) + .setAllExtensionNumbersResponse( + ExtensionNumberResponse.newBuilder() + .setBaseTypeName(type) + .addExtensionNumber(200) + .build()) + .build(); + + StreamRecorder responseObserver = StreamRecorder.create(); + StreamObserver requestObserver = + stub.serverReflectionInfo(responseObserver); + handlerRegistry.addService(dynamicService); + requestObserver.onNext(request); + requestObserver.onCompleted(); + StreamRecorder responseObserver2 = StreamRecorder.create(); + StreamObserver requestObserver2 = + stub.serverReflectionInfo(responseObserver2); + handlerRegistry.removeService(dynamicService); + requestObserver2.onNext(request); + requestObserver2.onCompleted(); + StreamRecorder responseObserver3 = StreamRecorder.create(); + StreamObserver requestObserver3 = + stub.serverReflectionInfo(responseObserver3); + requestObserver3.onNext(request); + requestObserver3.onCompleted(); + + assertEquals( + ServerReflectionResponse.MessageResponseCase.ERROR_RESPONSE, + responseObserver.firstValue().get().getMessageResponseCase()); + assertEquals(goldenResponse, responseObserver2.firstValue().get()); + assertEquals( + ServerReflectionResponse.MessageResponseCase.ERROR_RESPONSE, + responseObserver3.firstValue().get().getMessageResponseCase()); + } + + @Test + public void sharedServiceBetweenServers() + throws IOException, ExecutionException, InterruptedException { + Server anotherServer = InProcessServerBuilder.forName("proto-reflection-test-2") + .directExecutor() + .addService(reflectionService) + .addService(new AnotherReflectableServiceGrpc.AnotherReflectableServiceImplBase() {}) + .build() + .start(); + grpcCleanupRule.register(anotherServer); + ManagedChannel anotherChannel = grpcCleanupRule.register( + InProcessChannelBuilder.forName("proto-reflection-test-2").directExecutor().build()); + ServerReflectionGrpc.ServerReflectionStub stub2 = ServerReflectionGrpc.newStub(anotherChannel); + + ServerReflectionRequest request = + ServerReflectionRequest.newBuilder().setHost(TEST_HOST).setListServices("services").build(); + StreamRecorder responseObserver = StreamRecorder.create(); + StreamObserver requestObserver = + stub2.serverReflectionInfo(responseObserver); + requestObserver.onNext(request); + requestObserver.onCompleted(); + List response = + responseObserver.firstValue().get().getListServicesResponse().getServiceList(); + assertEquals(new HashSet<>( + Arrays.asList( + ServiceResponse.newBuilder() + .setName("grpc.reflection.v1.ServerReflection") + .build(), + ServiceResponse.newBuilder() + .setName("grpc.reflection.testing.AnotherReflectableService") + .build())), + new HashSet<>(response)); + } + + @Test + public void flowControl() throws Exception { + FlowControlClientResponseObserver clientResponseObserver = + new FlowControlClientResponseObserver(); + ClientCallStreamObserver requestObserver = + (ClientCallStreamObserver) + stub.serverReflectionInfo(clientResponseObserver); + + // Verify we don't receive a response until we request it. + requestObserver.onNext(flowControlRequest); + assertEquals(0, clientResponseObserver.getResponses().size()); + + requestObserver.request(1); + assertEquals(1, clientResponseObserver.getResponses().size()); + assertEquals(flowControlGoldenResponse, clientResponseObserver.getResponses().get(0)); + + // Verify we don't receive an additional response until we request it. + requestObserver.onNext(flowControlRequest); + assertEquals(1, clientResponseObserver.getResponses().size()); + + requestObserver.request(1); + assertEquals(2, clientResponseObserver.getResponses().size()); + assertEquals(flowControlGoldenResponse, clientResponseObserver.getResponses().get(1)); + + requestObserver.onCompleted(); + assertTrue(clientResponseObserver.onCompleteCalled()); + } + + @Test + public void flowControlOnCompleteWithPendingRequest() throws Exception { + FlowControlClientResponseObserver clientResponseObserver = + new FlowControlClientResponseObserver(); + ClientCallStreamObserver requestObserver = + (ClientCallStreamObserver) + stub.serverReflectionInfo(clientResponseObserver); + + requestObserver.onNext(flowControlRequest); + requestObserver.onCompleted(); + assertEquals(0, clientResponseObserver.getResponses().size()); + assertFalse(clientResponseObserver.onCompleteCalled()); + + requestObserver.request(1); + assertTrue(clientResponseObserver.onCompleteCalled()); + assertEquals(1, clientResponseObserver.getResponses().size()); + assertEquals(flowControlGoldenResponse, clientResponseObserver.getResponses().get(0)); + } + + private final ServerReflectionRequest flowControlRequest = + ServerReflectionRequest.newBuilder() + .setHost(TEST_HOST) + .setFileByFilename("io/grpc/reflection/testing/reflection_test_depth_three.proto") + .build(); + private final ServerReflectionResponse flowControlGoldenResponse = + ServerReflectionResponse.newBuilder() + .setValidHost(TEST_HOST) + .setOriginalRequest(flowControlRequest) + .setFileDescriptorResponse( + FileDescriptorResponse.newBuilder() + .addFileDescriptorProto( + ReflectionTestDepthThreeProto.getDescriptor().toProto().toByteString()) + .build()) + .build(); + + private static class FlowControlClientResponseObserver + implements ClientResponseObserver { + private final List responses = + new ArrayList<>(); + private boolean onCompleteCalled = false; + + @Override + public void beforeStart(final ClientCallStreamObserver requestStream) { + requestStream.disableAutoRequestWithInitial(0); + } + + @Override + public void onNext(ServerReflectionResponse value) { + responses.add(value); + } + + @Override + public void onError(Throwable t) { + fail("onError called"); + } + + @Override + public void onCompleted() { + onCompleteCalled = true; + } + + public List getResponses() { + return responses; + } + + public boolean onCompleteCalled() { + return onCompleteCalled; + } + } + + private void assertServiceResponseEquals(Set goldenResponse) throws Exception { + ServerReflectionRequest request = + ServerReflectionRequest.newBuilder().setHost(TEST_HOST).setListServices("services").build(); + StreamRecorder responseObserver = StreamRecorder.create(); + StreamObserver requestObserver = + stub.serverReflectionInfo(responseObserver); + requestObserver.onNext(request); + requestObserver.onCompleted(); + List response = + responseObserver.firstValue().get().getListServicesResponse().getServiceList(); + assertEquals(goldenResponse.size(), response.size()); + assertEquals(goldenResponse, new HashSet<>(response)); + } +} diff --git a/servlet/build.gradle b/servlet/build.gradle index fe5914f5193..1367a72ab44 100644 --- a/servlet/build.gradle +++ b/servlet/build.gradle @@ -34,18 +34,16 @@ tasks.named("jar").configure { dependencies { api project(':grpc-api') - compileOnly libraries.javax.servlet.api, - libraries.javax.annotation // java 9, 10 needs it + compileOnly libraries.javax.servlet.api - implementation project(':grpc-util'), - project(':grpc-core'), + implementation project(':grpc-core'), libraries.guava testImplementation libraries.javax.servlet.api threadingTestImplementation project(':grpc-servlet'), - libraries.truth, - libraries.javax.servlet.api, + libraries.junit, + libraries.javax.servlet.api, libraries.lincheck itImplementation project(':grpc-servlet'), @@ -58,7 +56,7 @@ dependencies { exclude group: 'io.grpc', module: 'grpc-xds' } - undertowTestImplementation libraries.undertow.servlet + undertowTestImplementation libraries.undertow.servlet22 tomcatTestImplementation libraries.tomcat.embed.core9 @@ -70,19 +68,12 @@ dependencies { libraries.protobuf.java } -tasks.named("test").configure { - if (JavaVersion.current().isJava9Compatible()) { - jvmArgs += [ - // required for Lincheck - '--add-opens=java.base/jdk.internal.misc=ALL-UNNAMED', - '--add-exports=java.base/jdk.internal.util=ALL-UNNAMED', - ] - } -} - tasks.register('threadingTest', Test) { classpath = sourceSets.threadingTest.runtimeClasspath testClassesDirs = sourceSets.threadingTest.output.classesDirs + jacoco { + enabled = false + } } tasks.named("assemble").configure { diff --git a/servlet/jakarta/build.gradle b/servlet/jakarta/build.gradle index f548805bd2d..5cd213949f4 100644 --- a/servlet/jakarta/build.gradle +++ b/servlet/jakarta/build.gradle @@ -7,21 +7,24 @@ description = "gRPC: Jakarta Servlet" // Set up classpaths and source directories for different servlet tests sourceSets { - undertowTest { - java { - include '**/Undertow*.java' - } - } - tomcatTest { - java { - include '**/Tomcat*.java' + + // Only run these tests if the required minimum Java version is being used + if (JavaVersion.current().isCompatibleWith(JavaVersion.VERSION_17)) { + jettyTest { + java { + include '**/Jetty*.java' + } } } - // Only run these tests if java 11+ is being used if (JavaVersion.current().isJava11Compatible()) { - jettyTest { + tomcatTest { java { - include '**/Jetty*.java' + include '**/Tomcat*.java' + } + } + undertowTest { + java { + include '**/Undertow*.java' } } } @@ -44,11 +47,15 @@ def migrate(String name, String inputDir, SourceSet sourceSet) { def outputDir = layout.buildDirectory.dir('generated/sources/jakarta-' + name) sourceSet.java.srcDir tasks.register('migrateSources' + name.capitalize(), Sync) { task -> into(outputDir) + // Increment when changing the filter, to inform Gradle it needs to rebuild + inputs.property("filter-version", "1") from("$inputDir/io/grpc/servlet") { into('io/grpc/servlet/jakarta') filter { String line -> line.replace('javax.servlet', 'jakarta.servlet') .replace('io.grpc.servlet', 'io.grpc.servlet.jakarta') + .replace('org.eclipse.jetty.http2.parser', 'org.eclipse.jetty.http2') + .replace('org.eclipse.jetty.servlet', 'org.eclipse.jetty.ee10.servlet') } } } @@ -56,13 +63,14 @@ def migrate(String name, String inputDir, SourceSet sourceSet) { migrate('main', '../src/main/java', sourceSets.main) -// Build the set of sourceSets and classpaths to modify, since Jetty 11 requires Java 11 -// and must be skipped -migrate('undertowTest', '../src/undertowTest/java', sourceSets.undertowTest) -migrate('tomcatTest', '../src/tomcatTest/java', sourceSets.tomcatTest) -if (JavaVersion.current().isJava11Compatible()) { +// Only build sourceSets and classpaths for tests if using the required minimum Java version +if (JavaVersion.current().isCompatibleWith(JavaVersion.VERSION_17)) { migrate('jettyTest', '../src/jettyTest/java', sourceSets.jettyTest) } +if (JavaVersion.current().isJava11Compatible()) { + migrate('tomcatTest', '../src/tomcatTest/java', sourceSets.tomcatTest) + migrate('undertowTest', '../src/undertowTest/java', sourceSets.undertowTest) +} // Disable checkstyle for this project, since it consists only of generated code tasks.withType(Checkstyle).configureEach { @@ -77,8 +85,7 @@ tasks.named("jar").configure { dependencies { api project(':grpc-api') - compileOnly libraries.jakarta.servlet.api, - libraries.javax.annotation + compileOnly libraries.jakarta.servlet.api implementation project(':grpc-util'), project(':grpc-core'), @@ -99,46 +106,58 @@ dependencies { jettyTestImplementation libraries.jetty.servlet, libraries.jetty.http2.server - undertowTestImplementation libraries.undertow.servlet.jakartaee9 + undertowTestImplementation libraries.undertow.servlet } // Set up individual classpaths for each test, to avoid any mismatch, // and ensure they are only used when supported by the current jvm -def undertowTest = tasks.register('undertowTest', Test) { - classpath = sourceSets.undertowTest.runtimeClasspath - testClassesDirs = sourceSets.undertowTest.output.classesDirs -} -def tomcat10Test = tasks.register('tomcat10Test', Test) { - classpath = sourceSets.tomcatTest.runtimeClasspath - testClassesDirs = sourceSets.tomcatTest.output.classesDirs - - // Provide a temporary directory for tomcat to be deleted after test finishes - def tomcatTempDir = "$buildDir/tomcat_catalina_base" - systemProperty 'catalina.base', tomcatTempDir - doLast { - file(tomcatTempDir).deleteDir() +if (JavaVersion.current().isCompatibleWith(JavaVersion.VERSION_17)) { + def jetty11Test = tasks.register('jetty11Test', Test) { + classpath = sourceSets.jettyTest.runtimeClasspath + testClassesDirs = sourceSets.jettyTest.output.classesDirs } - - // tomcat-embed-core 10 presently performs illegal reflective access on - // java.io.ObjectStreamClass$Caches.localDescs and sun.rmi.transport.Target.ccl, - // see https://lists.apache.org/thread/s0xr7tk2kfkkxfjps9n7dhh4cypfdhyy - if (JavaVersion.current().isJava9Compatible()) { - jvmArgs += ['--add-opens=java.base/java.io=ALL-UNNAMED', '--add-opens=java.rmi/sun.rmi.transport=ALL-UNNAMED'] + tasks.named('compileJettyTestJava') { JavaCompile task -> + task.options.release.set 9 + } + tasks.named("check").configure { + dependsOn jetty11Test + } + tasks.named("jacocoTestReport").configure { + // Must use executionData(Task...) override. The executionData(Object...) override doesn't + // find execution data correctly for tasks. + executionData jetty11Test.get() } } - -tasks.named("check").configure { - dependsOn undertowTest, tomcat10Test -} - -// Only run these tests if java 11+ is being used if (JavaVersion.current().isJava11Compatible()) { - def jetty11Test = tasks.register('jetty11Test', Test) { - classpath = sourceSets.jettyTest.runtimeClasspath - testClassesDirs = sourceSets.jettyTest.output.classesDirs + def tomcat10Test = tasks.register('tomcat10Test', Test) { + classpath = sourceSets.tomcatTest.runtimeClasspath + testClassesDirs = sourceSets.tomcatTest.output.classesDirs + + // Provide a temporary directory for tomcat to be deleted after test finishes + def tomcatTempDir = "$buildDir/tomcat_catalina_base" + systemProperty 'catalina.base', tomcatTempDir + doLast { + file(tomcatTempDir).deleteDir() + } + } + tasks.named('compileTomcatTestJava') { JavaCompile task -> + task.options.release.set 11 + } + + def undertowTest = tasks.register('undertowTest', Test) { + classpath = sourceSets.undertowTest.runtimeClasspath + testClassesDirs = sourceSets.undertowTest.output.classesDirs + } + tasks.named('compileUndertowTestJava') { JavaCompile task -> + task.options.release.set 11 } tasks.named("check").configure { - dependsOn jetty11Test + dependsOn tomcat10Test, undertowTest + } + tasks.named("jacocoTestReport").configure { + // Must use executionData(Task...) override. The executionData(Object...) override doesn't + // find execution data correctly for tasks. + executionData tomcat10Test.get(), undertowTest.get() } } diff --git a/servlet/src/jettyTest/java/io/grpc/servlet/JettyTransportTest.java b/servlet/src/jettyTest/java/io/grpc/servlet/JettyTransportTest.java index f21754fb686..58143a8516c 100644 --- a/servlet/src/jettyTest/java/io/grpc/servlet/JettyTransportTest.java +++ b/servlet/src/jettyTest/java/io/grpc/servlet/JettyTransportTest.java @@ -69,6 +69,7 @@ public void start(ServerListener listener) throws IOException { listener.transportCreated(new ServletServerBuilder.ServerTransportImpl(scheduler)); ServletAdapter adapter = new ServletAdapter(serverTransportListener, streamTracerFactories, + ServletAdapter.DEFAULT_METHOD_NAME_RESOLVER, Integer.MAX_VALUE); GrpcServlet grpcServlet = new GrpcServlet(adapter); @@ -76,9 +77,7 @@ public void start(ServerListener listener) throws IOException { ServerConnector sc = (ServerConnector) jettyServer.getConnectors()[0]; HttpConfiguration httpConfiguration = new HttpConfiguration(); - // Must be set for several tests to pass, so that the request handling can begin before - // content arrives. - httpConfiguration.setDelayDispatchUntilContent(false); + setDelayDispatchUntilContent(httpConfiguration); HTTP2CServerConnectionFactory factory = new HTTP2CServerConnectionFactory(httpConfiguration); @@ -134,6 +133,16 @@ protected InternalServer newServer(int port, return newServer(streamTracerFactories); } + // The future default appears to be false as people are supposed to be migrate to + // EagerContentHandler, but the default is still true. Seems they messed up the migration + // process here by not flipping the default. + @SuppressWarnings("removal") + private static void setDelayDispatchUntilContent(HttpConfiguration httpConfiguration) { + // Must be set for several tests to pass, so that the request handling can begin before + // content arrives. + httpConfiguration.setDelayDispatchUntilContent(false); + } + @Override protected ManagedClientTransport newClientTransport(InternalServer server) { NettyChannelBuilder nettyChannelBuilder = NettyChannelBuilder @@ -252,4 +261,14 @@ public void clientCancel() { @Ignore("regression since bumping grpc v1.46 to v1.53") @Test public void messageProducerOnlyProducesRequestedMessages() {} + + @Override + @Ignore("https://github.com/jetty/jetty.project/issues/11822") + @Test + public void clientChecksInboundMetadataSize_header() {} + + @Override + @Ignore("https://github.com/jetty/jetty.project/issues/11822") + @Test + public void clientChecksInboundMetadataSize_trailer() {} } diff --git a/servlet/src/main/java/io/grpc/servlet/AsyncServletOutputStreamWriter.java b/servlet/src/main/java/io/grpc/servlet/AsyncServletOutputStreamWriter.java index 5ee5c02a128..3c8d3d07571 100644 --- a/servlet/src/main/java/io/grpc/servlet/AsyncServletOutputStreamWriter.java +++ b/servlet/src/main/java/io/grpc/servlet/AsyncServletOutputStreamWriter.java @@ -22,18 +22,19 @@ import static java.util.logging.Level.FINEST; import com.google.common.annotations.VisibleForTesting; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.InternalLogId; import io.grpc.servlet.ServletServerStream.ServletTransportState; import java.io.IOException; -import java.time.Duration; import java.util.Queue; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.locks.LockSupport; import java.util.function.BiFunction; import java.util.function.BooleanSupplier; +import java.util.logging.Level; import java.util.logging.Logger; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; import javax.servlet.AsyncContext; import javax.servlet.ServletOutputStream; @@ -86,6 +87,11 @@ final class AsyncServletOutputStreamWriter { InternalLogId logId) throws IOException { Logger logger = Logger.getLogger(AsyncServletOutputStreamWriter.class.getName()); this.log = new Log() { + @Override + public boolean isLoggable(Level level) { + return logger.isLoggable(level); + } + @Override public void fine(String str, Object... params) { if (logger.isLoggable(FINE)) { @@ -105,7 +111,9 @@ public void finest(String str, Object... params) { this.writeAction = (byte[] bytes, Integer numBytes) -> () -> { outputStream.write(bytes, 0, numBytes); transportState.runOnTransportThread(() -> transportState.onSentBytes(numBytes)); - log.finest("outbound data: length={0}, bytes={1}", numBytes, toHexString(bytes, numBytes)); + if (log.isLoggable(Level.FINEST)) { + log.finest("outbound data: length={0}, bytes={1}", numBytes, toHexString(bytes, numBytes)); + } }; this.flushAction = () -> { log.finest("flushBuffer"); @@ -120,7 +128,7 @@ public void finest(String str, Object... params) { log.fine("call completed"); }); }; - this.isReady = () -> outputStream.isReady(); + this.isReady = outputStream::isReady; } /** @@ -165,7 +173,9 @@ void complete() { /** Called from the container thread {@link javax.servlet.WriteListener#onWritePossible()}. */ void onWritePossible() throws IOException { log.finest("onWritePossible: ENTRY. The servlet output stream becomes ready"); - assureReadyAndDrainedTurnsFalse(); + if (writeState.get().readyAndDrained) { + assureReadyAndDrainedTurnsFalse(); + } while (isReady.getAsBoolean()) { WriteState curState = writeState.get(); @@ -192,11 +202,9 @@ private void assureReadyAndDrainedTurnsFalse() { // readyAndDrained should have been set to false already. // Just in case due to a race condition readyAndDrained is still true at this moment and is // being set to false by runOrBuffer() concurrently. + parkingThread = Thread.currentThread(); while (writeState.get().readyAndDrained) { - parkingThread = Thread.currentThread(); - // Try to sleep for an extremely long time to avoid writeState being changed at exactly - // the time when sleep time expires (in extreme scenario, such as #9917). - LockSupport.parkNanos(Duration.ofHours(1).toNanos()); // should return immediately + LockSupport.parkNanos(TimeUnit.MINUTES.toNanos(1)); // should return immediately } parkingThread = null; } @@ -245,6 +253,10 @@ interface ActionItem { @VisibleForTesting // Lincheck test can not run with java.util.logging dependency. interface Log { + default boolean isLoggable(Level level) { + return false; + } + default void fine(String str, Object...params) {} default void finest(String str, Object...params) {} diff --git a/servlet/src/main/java/io/grpc/servlet/GrpcServlet.java b/servlet/src/main/java/io/grpc/servlet/GrpcServlet.java index f68ed083506..8c1eb858ad1 100644 --- a/servlet/src/main/java/io/grpc/servlet/GrpcServlet.java +++ b/servlet/src/main/java/io/grpc/servlet/GrpcServlet.java @@ -37,6 +37,7 @@ public class GrpcServlet extends HttpServlet { private static final long serialVersionUID = 1L; + @SuppressWarnings("serial") private final ServletAdapter servletAdapter; GrpcServlet(ServletAdapter servletAdapter) { diff --git a/servlet/src/main/java/io/grpc/servlet/ServletAdapter.java b/servlet/src/main/java/io/grpc/servlet/ServletAdapter.java index 5a567916f99..668e82425cb 100644 --- a/servlet/src/main/java/io/grpc/servlet/ServletAdapter.java +++ b/servlet/src/main/java/io/grpc/servlet/ServletAdapter.java @@ -22,6 +22,7 @@ import static java.util.logging.Level.FINE; import static java.util.logging.Level.FINEST; +import com.google.common.annotations.VisibleForTesting; import com.google.common.io.BaseEncoding; import io.grpc.Attributes; import io.grpc.ExperimentalApi; @@ -45,6 +46,7 @@ import java.util.Enumeration; import java.util.List; import java.util.concurrent.TimeUnit; +import java.util.function.Function; import java.util.logging.Logger; import javax.servlet.AsyncContext; import javax.servlet.AsyncEvent; @@ -72,18 +74,23 @@ public final class ServletAdapter { static final Logger logger = Logger.getLogger(ServletAdapter.class.getName()); + static final Function DEFAULT_METHOD_NAME_RESOLVER = + req -> req.getRequestURI().substring(1); // remove the leading "/" private final ServerTransportListener transportListener; private final List streamTracerFactories; + private final Function methodNameResolver; private final int maxInboundMessageSize; private final Attributes attributes; ServletAdapter( ServerTransportListener transportListener, List streamTracerFactories, + Function methodNameResolver, int maxInboundMessageSize) { this.transportListener = transportListener; this.streamTracerFactories = streamTracerFactories; + this.methodNameResolver = methodNameResolver; this.maxInboundMessageSize = maxInboundMessageSize; attributes = transportListener.transportReady(Attributes.EMPTY); } @@ -119,7 +126,7 @@ public void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOEx AsyncContext asyncCtx = req.startAsync(req, resp); - String method = req.getRequestURI().substring(1); // remove the leading "/" + String method = methodNameResolver.apply(req); Metadata headers = getHeaders(req); if (logger.isLoggable(FINEST)) { @@ -128,10 +135,9 @@ public void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOEx } Long timeoutNanos = headers.get(TIMEOUT_KEY); - if (timeoutNanos == null) { - timeoutNanos = 0L; - } - asyncCtx.setTimeout(TimeUnit.NANOSECONDS.toMillis(timeoutNanos)); + asyncCtx.setTimeout(timeoutNanos != null + ? TimeUnit.NANOSECONDS.toMillis(timeoutNanos) + ASYNC_TIMEOUT_SAFETY_MARGIN + : 0); StatsTraceContext statsTraceCtx = StatsTraceContext.newServerContext(streamTracerFactories, method, headers); @@ -158,6 +164,12 @@ public void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOEx asyncCtx.addListener(new GrpcAsyncListener(stream, logId)); } + /** + * Deadlines are managed via Context, servlet async timeout is not supposed to happen. + */ + @VisibleForTesting + static final long ASYNC_TIMEOUT_SAFETY_MARGIN = 5_000; + // This method must use Enumeration and its members, since that is the only way to read headers // from the servlet api. @SuppressWarnings("JdkObsolete") @@ -215,7 +227,9 @@ private static final class GrpcAsyncListener implements AsyncListener { } @Override - public void onComplete(AsyncEvent event) {} + public void onComplete(AsyncEvent event) { + stream.asyncCompleted = true; + } @Override public void onTimeout(AsyncEvent event) { diff --git a/servlet/src/main/java/io/grpc/servlet/ServletServerBuilder.java b/servlet/src/main/java/io/grpc/servlet/ServletServerBuilder.java index 72c4383d273..5bea4c6e03b 100644 --- a/servlet/src/main/java/io/grpc/servlet/ServletServerBuilder.java +++ b/servlet/src/main/java/io/grpc/servlet/ServletServerBuilder.java @@ -49,8 +49,10 @@ import java.util.Collections; import java.util.List; import java.util.concurrent.ScheduledExecutorService; +import java.util.function.Function; import javax.annotation.Nullable; import javax.annotation.concurrent.NotThreadSafe; +import javax.servlet.http.HttpServletRequest; /** * Builder to build a gRPC server that can run as a servlet. This is for advanced custom settings. @@ -64,6 +66,8 @@ @NotThreadSafe public final class ServletServerBuilder extends ForwardingServerBuilder { List streamTracerFactories; + private Function methodNameResolver = + ServletAdapter.DEFAULT_METHOD_NAME_RESOLVER; int maxInboundMessageSize = DEFAULT_MAX_MESSAGE_SIZE; private final ServerImplBuilder serverImplBuilder; @@ -74,7 +78,9 @@ public final class ServletServerBuilder extends ForwardingServerBuilder + buildTransportServers(streamTracerFactories)); } /** @@ -98,7 +104,8 @@ public Server build() { * Creates a {@link ServletAdapter}. */ public ServletAdapter buildServletAdapter() { - return new ServletAdapter(buildAndStart(), streamTracerFactories, maxInboundMessageSize); + return new ServletAdapter(buildAndStart(), streamTracerFactories, methodNameResolver, + maxInboundMessageSize); } /** @@ -176,6 +183,18 @@ public ServletServerBuilder useTransportSecurity(File certChain, File privateKey throw new UnsupportedOperationException("TLS should be configured by the servlet container"); } + /** + * Specifies how to determine gRPC method name from servlet request. + * + *

The default strategy is using {@link HttpServletRequest#getRequestURI()} without the leading + * slash.

+ */ + public ServletServerBuilder methodNameResolver( + Function methodResolver) { + this.methodNameResolver = checkNotNull(methodResolver); + return this; + } + @Override public ServletServerBuilder maxInboundMessageSize(int bytes) { checkArgument(bytes >= 0, "bytes must be >= 0"); diff --git a/servlet/src/main/java/io/grpc/servlet/ServletServerStream.java b/servlet/src/main/java/io/grpc/servlet/ServletServerStream.java index b7ad6e0decc..0182f302698 100644 --- a/servlet/src/main/java/io/grpc/servlet/ServletServerStream.java +++ b/servlet/src/main/java/io/grpc/servlet/ServletServerStream.java @@ -30,7 +30,6 @@ import io.grpc.InternalLogId; import io.grpc.Metadata; import io.grpc.Status; -import io.grpc.Status.Code; import io.grpc.internal.AbstractServerStream; import io.grpc.internal.GrpcUtil; import io.grpc.internal.SerializingExecutor; @@ -43,8 +42,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; +import java.util.function.BiConsumer; import java.util.function.Supplier; import java.util.logging.Logger; import javax.annotation.Nullable; @@ -58,12 +56,15 @@ final class ServletServerStream extends AbstractServerStream { private final ServletTransportState transportState; private final Sink sink = new Sink(); - private final AsyncContext asyncCtx; private final HttpServletResponse resp; private final Attributes attributes; private final String authority; private final InternalLogId logId; private final AsyncServletOutputStreamWriter writer; + /** + * If the async servlet operation has been completed. + */ + volatile boolean asyncCompleted = false; ServletServerStream( AsyncContext asyncCtx, @@ -78,7 +79,6 @@ final class ServletServerStream extends AbstractServerStream { this.attributes = attributes; this.authority = authority; this.logId = logId; - this.asyncCtx = asyncCtx; this.resp = (HttpServletResponse) asyncCtx.getResponse(); this.writer = new AsyncServletOutputStreamWriter( asyncCtx, transportState, logId); @@ -123,9 +123,13 @@ private void writeHeadersToServletResponse(Metadata metadata) { resp.setStatus(HttpServletResponse.SC_OK); resp.setContentType(CONTENT_TYPE_GRPC); + serializeHeaders(metadata, resp::addHeader); + } + + private static void serializeHeaders(Metadata metadata, BiConsumer consumer) { byte[][] serializedHeaders = TransportFrameUtil.toHttp2Headers(metadata); for (int i = 0; i < serializedHeaders.length; i += 2) { - resp.addHeader( + consumer.accept( new String(serializedHeaders[i], StandardCharsets.US_ASCII), new String(serializedHeaders[i + 1], StandardCharsets.US_ASCII)); } @@ -154,8 +158,8 @@ public void bytesRead(int numBytes) { @Override public void deframeFailed(Throwable cause) { - if (logger.isLoggable(FINE)) { - logger.log(FINE, String.format("[{%s}] Exception processing message", logId), cause); + if (logger.isLoggable(WARNING)) { + logger.log(WARNING, String.format("[{%s}] Exception processing message", logId), cause); } cancel(Status.fromThrowable(cause)); } @@ -168,7 +172,7 @@ private static final class ByteArrayWritableBuffer implements WritableBuffer { private int index; ByteArrayWritableBuffer(int capacityHint) { - this.bytes = new byte[min(1024 * 1024, max(4096, capacityHint))]; + this.bytes = new byte[min(1024 * 1024, capacityHint)]; this.capacity = bytes.length; } @@ -278,13 +282,8 @@ public void writeTrailers(Metadata trailers, boolean headersSent, Status status) if (!headersSent) { writeHeadersToServletResponse(trailers); } else { - byte[][] serializedHeaders = TransportFrameUtil.toHttp2Headers(trailers); - for (int i = 0; i < serializedHeaders.length; i += 2) { - String key = new String(serializedHeaders[i], StandardCharsets.US_ASCII); - String newValue = new String(serializedHeaders[i + 1], StandardCharsets.US_ASCII); - trailerSupplier.get().computeIfPresent(key, (k, v) -> v + "," + newValue); - trailerSupplier.get().putIfAbsent(key, newValue); - } + serializeHeaders(trailers, + (k, v) -> trailerSupplier.get().merge(k, v, (oldV, newV) -> oldV + "," + newV)); } writer.complete(); @@ -292,22 +291,14 @@ public void writeTrailers(Metadata trailers, boolean headersSent, Status status) @Override public void cancel(Status status) { - if (resp.isCommitted() && Code.DEADLINE_EXCEEDED == status.getCode()) { - return; // let the servlet timeout, the container will sent RST_STREAM automatically - } transportState.runOnTransportThread(() -> transportState.transportReportStatus(status)); - // There is no way to RST_STREAM with CANCEL code, so write trailers instead - close(Status.CANCELLED.withCause(status.asRuntimeException()), new Metadata()); - CountDownLatch countDownLatch = new CountDownLatch(1); - transportState.runOnTransportThread(() -> { - asyncCtx.complete(); - countDownLatch.countDown(); - }); - try { - countDownLatch.await(5, TimeUnit.SECONDS); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); + if (asyncCompleted) { + logger.fine("ignore cancel as already completed"); + return; } + // There is no way to RST_STREAM with CANCEL code, so write trailers instead + close(status, new Metadata()); + // close() calls writeTrailers(), which calls AsyncContext.complete() } } diff --git a/servlet/src/test/java/io/grpc/servlet/ServletServerBuilderTest.java b/servlet/src/test/java/io/grpc/servlet/ServletServerBuilderTest.java index d571cfd45d5..7a8c5b91f25 100644 --- a/servlet/src/test/java/io/grpc/servlet/ServletServerBuilderTest.java +++ b/servlet/src/test/java/io/grpc/servlet/ServletServerBuilderTest.java @@ -80,7 +80,7 @@ public void scheduledExecutorService() throws Exception { ServletAdapter servletAdapter = serverBuilder.buildServletAdapter(); servletAdapter.doPost(request, response); - verify(asyncContext).setTimeout(1); + verify(asyncContext).setTimeout(1 + ServletAdapter.ASYNC_TIMEOUT_SAFETY_MARGIN); // The following just verifies that scheduler is populated to the transport. // It doesn't matter what tasks (such as handshake timeout and request deadline) are actually diff --git a/servlet/src/threadingTest/java/io/grpc/servlet/AsyncServletOutputStreamWriterConcurrencyTest.java b/servlet/src/threadingTest/java/io/grpc/servlet/AsyncServletOutputStreamWriterConcurrencyTest.java index 61da2bf4c69..b2891b6e47e 100644 --- a/servlet/src/threadingTest/java/io/grpc/servlet/AsyncServletOutputStreamWriterConcurrencyTest.java +++ b/servlet/src/threadingTest/java/io/grpc/servlet/AsyncServletOutputStreamWriterConcurrencyTest.java @@ -16,23 +16,22 @@ package io.grpc.servlet; -import static com.google.common.truth.Truth.assertWithMessage; -import static org.jetbrains.kotlinx.lincheck.strategy.managed.ManagedStrategyGuaranteeKt.forClasses; +import static org.jetbrains.lincheck.datastructures.ManagedStrategyGuaranteeKt.forClasses; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; import io.grpc.servlet.AsyncServletOutputStreamWriter.ActionItem; import io.grpc.servlet.AsyncServletOutputStreamWriter.Log; import java.io.IOException; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiFunction; -import org.jetbrains.kotlinx.lincheck.LinChecker; -import org.jetbrains.kotlinx.lincheck.annotations.OpGroupConfig; -import org.jetbrains.kotlinx.lincheck.annotations.Operation; -import org.jetbrains.kotlinx.lincheck.annotations.Param; -import org.jetbrains.kotlinx.lincheck.paramgen.BooleanGen; -import org.jetbrains.kotlinx.lincheck.strategy.managed.modelchecking.ModelCheckingCTest; -import org.jetbrains.kotlinx.lincheck.strategy.managed.modelchecking.ModelCheckingOptions; -import org.jetbrains.kotlinx.lincheck.verifier.VerifierState; +import org.jetbrains.lincheck.datastructures.BooleanGen; +import org.jetbrains.lincheck.datastructures.ModelCheckingOptions; +import org.jetbrains.lincheck.datastructures.Operation; +import org.jetbrains.lincheck.datastructures.Param; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -49,18 +48,19 @@ * test all possibly interleaves (on context switch) between the two threads, and then verify the * operations are linearizable in each interleave scenario. */ -@ModelCheckingCTest -@OpGroupConfig(name = "update", nonParallel = true) -@OpGroupConfig(name = "write", nonParallel = true) @Param(name = "keepReady", gen = BooleanGen.class) @RunWith(JUnit4.class) -public class AsyncServletOutputStreamWriterConcurrencyTest extends VerifierState { +public class AsyncServletOutputStreamWriterConcurrencyTest { private static final int OPERATIONS_PER_THREAD = 6; private final AsyncServletOutputStreamWriter writer; private final boolean[] keepReadyArray = new boolean[OPERATIONS_PER_THREAD]; private volatile boolean isReady; + /** + * The container initiates the first call shortly after {@code startAsync}. + */ + private final AtomicBoolean initialOnWritePossible = new AtomicBoolean(true); // when isReadyReturnedFalse, writer.onWritePossible() will be called. private volatile boolean isReadyReturnedFalse; private int producerIndex; @@ -71,17 +71,15 @@ public class AsyncServletOutputStreamWriterConcurrencyTest extends VerifierState public AsyncServletOutputStreamWriterConcurrencyTest() { BiFunction writeAction = (bytes, numBytes) -> () -> { - assertWithMessage("write should only be called while isReady() is true") - .that(isReady) - .isTrue(); + assertTrue("write should only be called while isReady() is true", isReady); // The byte to be written must equal to consumerIndex, otherwise execution order is wrong - assertWithMessage("write in wrong order").that(bytes[0]).isEqualTo((byte) consumerIndex); + assertEquals("write in wrong order", bytes[0], (byte) consumerIndex); bytesWritten++; writeOrFlush(); }; ActionItem flushAction = () -> { - assertWithMessage("flush must only be called while isReady() is true").that(isReady).isTrue(); + assertTrue("flush must only be called while isReady() is true", isReady); writeOrFlush(); }; @@ -102,12 +100,13 @@ private void writeOrFlush() { } private boolean isReady() { - if (!isReady) { - assertWithMessage("isReady() already returned false, onWritePossible() will be invoked") - .that(isReadyReturnedFalse).isFalse(); + boolean copyOfIsReady = isReady; + if (!copyOfIsReady) { + assertFalse("isReady() already returned false, onWritePossible() will be invoked", + isReadyReturnedFalse); isReadyReturnedFalse = true; } - return isReady; + return copyOfIsReady; } /** @@ -118,7 +117,7 @@ private boolean isReady() { * the ServletOutputStream should become unready if keepReady == false. */ // @com.google.errorprone.annotations.Keep - @Operation(group = "write") + @Operation(nonParallelGroup = "write") public void write(@Param(name = "keepReady") boolean keepReady) throws IOException { keepReadyArray[producerIndex] = keepReady; writer.writeBytes(new byte[]{(byte) producerIndex}, 1); @@ -133,7 +132,7 @@ public void write(@Param(name = "keepReady") boolean keepReady) throws IOExcepti * the ServletOutputStream should become unready if keepReady == false. */ // @com.google.errorprone.annotations.Keep // called by lincheck reflectively - @Operation(group = "write") + @Operation(nonParallelGroup = "write") public void flush(@Param(name = "keepReady") boolean keepReady) throws IOException { keepReadyArray[producerIndex] = keepReady; writer.flush(); @@ -142,9 +141,12 @@ public void flush(@Param(name = "keepReady") boolean keepReady) throws IOExcepti /** If the writer is not ready, let it turn ready and call writer.onWritePossible(). */ // @com.google.errorprone.annotations.Keep // called by lincheck reflectively - @Operation(group = "update") + @Operation(nonParallelGroup = "update") public void maybeOnWritePossible() throws IOException { - if (isReadyReturnedFalse) { + if (initialOnWritePossible.compareAndSet(true, false)) { + isReady = true; + writer.onWritePossible(); + } else if (isReadyReturnedFalse) { isReadyReturnedFalse = false; isReady = true; writer.onWritePossible(); @@ -152,7 +154,13 @@ public void maybeOnWritePossible() throws IOException { } @Override - protected Object extractState() { + public final boolean equals(Object o) { + return o instanceof AsyncServletOutputStreamWriterConcurrencyTest + && bytesWritten == ((AsyncServletOutputStreamWriterConcurrencyTest) o).bytesWritten; + } + + @Override + public int hashCode() { return bytesWritten; } @@ -169,6 +177,6 @@ public void linCheck() { AtomicReference.class.getName()) .allMethods() .treatAsAtomic()); - LinChecker.check(AsyncServletOutputStreamWriterConcurrencyTest.class, options); + options.check(AsyncServletOutputStreamWriterConcurrencyTest.class); } } diff --git a/servlet/src/tomcatTest/java/io/grpc/servlet/TomcatInteropTest.java b/servlet/src/tomcatTest/java/io/grpc/servlet/TomcatInteropTest.java index 1422b5388fd..d072fea93a1 100644 --- a/servlet/src/tomcatTest/java/io/grpc/servlet/TomcatInteropTest.java +++ b/servlet/src/tomcatTest/java/io/grpc/servlet/TomcatInteropTest.java @@ -113,27 +113,28 @@ protected boolean metricsExpected() { @Test public void gracefulShutdown() {} - // FIXME @Override @Ignore("Tomcat is not able to send trailer only") @Test public void specialStatusMessage() {} - // FIXME @Override @Ignore("Tomcat is not able to send trailer only") @Test public void unimplementedMethod() {} - // FIXME @Override @Ignore("Tomcat is not able to send trailer only") @Test public void statusCodeAndMessage() {} - // FIXME @Override @Ignore("Tomcat is not able to send trailer only") @Test public void emptyStream() {} + + @Override + @Ignore("Tomcat is not able to send trailer only") + @Test + public void timeoutOnSleepingServer() {} } diff --git a/servlet/src/tomcatTest/java/io/grpc/servlet/TomcatTransportTest.java b/servlet/src/tomcatTest/java/io/grpc/servlet/TomcatTransportTest.java index 262036883a9..cd73b096ccb 100644 --- a/servlet/src/tomcatTest/java/io/grpc/servlet/TomcatTransportTest.java +++ b/servlet/src/tomcatTest/java/io/grpc/servlet/TomcatTransportTest.java @@ -81,7 +81,9 @@ public void start(ServerListener listener) throws IOException { ServerTransportListener serverTransportListener = listener.transportCreated(new ServerTransportImpl(scheduler)); ServletAdapter adapter = - new ServletAdapter(serverTransportListener, streamTracerFactories, Integer.MAX_VALUE); + new ServletAdapter(serverTransportListener, streamTracerFactories, + ServletAdapter.DEFAULT_METHOD_NAME_RESOLVER, + Integer.MAX_VALUE); GrpcServlet grpcServlet = new GrpcServlet(adapter); tomcatServer = new Tomcat(); @@ -91,6 +93,10 @@ public void start(ServerListener listener) throws IOException { .setAsyncSupported(true); ctx.addServletMappingDecoded("/*", "TomcatTransportTest"); tomcatServer.getConnector().addUpgradeProtocol(new Http2Protocol()); + // Workaround for https://github.com/grpc/grpc-java/issues/12540 + // Prevent premature OutputBuffer recycling by disabling facade recycling. + // This should be revisited once the root cause is fixed. + tomcatServer.getConnector().setDiscardFacades(false); try { tomcatServer.start(); } catch (LifecycleException e) { diff --git a/servlet/src/undertowTest/java/io/grpc/servlet/UndertowTransportTest.java b/servlet/src/undertowTest/java/io/grpc/servlet/UndertowTransportTest.java index e14c11985de..ef897c87d70 100644 --- a/servlet/src/undertowTest/java/io/grpc/servlet/UndertowTransportTest.java +++ b/servlet/src/undertowTest/java/io/grpc/servlet/UndertowTransportTest.java @@ -100,7 +100,9 @@ public void start(ServerListener listener) throws IOException { ServerTransportListener serverTransportListener = listener.transportCreated(new ServerTransportImpl(scheduler)); ServletAdapter adapter = - new ServletAdapter(serverTransportListener, streamTracerFactories, Integer.MAX_VALUE); + new ServletAdapter(serverTransportListener, streamTracerFactories, + ServletAdapter.DEFAULT_METHOD_NAME_RESOLVER, + Integer.MAX_VALUE); GrpcServlet grpcServlet = new GrpcServlet(adapter); InstanceFactory instanceFactory = () -> new ImmediateInstanceHandle<>(grpcServlet); diff --git a/settings.gradle b/settings.gradle index d7aea83b3a4..51c4bdc0d3d 100644 --- a/settings.gradle +++ b/settings.gradle @@ -1,26 +1,44 @@ pluginManagement { + // https://issuetracker.google.com/issues/342522142#comment8 + // use D8/R8 8.0.44 or 8.1.44 with AGP 7.4 if needed. + buildscript { + repositories { + mavenCentral() + maven { + url = uri("https://storage.googleapis.com/r8-releases/raw") + } + } + dependencies { + classpath("com.android.tools:r8:8.1.44") + } + } plugins { // https://developer.android.com/build/releases/gradle-plugin - id "com.android.application" version "7.4.0" - id "com.android.library" version "7.4.0" - // https://github.com/johnrengelman/shadow/releases - id "com.github.johnrengelman.shadow" version "8.1.1" + // 8+ has many changes: https://github.com/grpc/grpc-java/issues/10152 + id "com.android.application" version "7.4.1" + id "com.android.library" version "7.4.1" + // https://github.com/kt3k/coveralls-gradle-plugin/tags id "com.github.kt3k.coveralls" version "2.12.2" - // https://github.com/GoogleCloudPlatform/app-gradle-plugin/releases - id "com.google.cloud.tools.appengine" version "2.4.5" + // https://github.com/GoogleCloudPlatform/appengine-plugins/releases + id "com.google.cloud.tools.appengine" version "2.8.6" // https://github.com/GoogleContainerTools/jib/blob/master/jib-gradle-plugin/CHANGELOG.md - id "com.google.cloud.tools.jib" version "3.3.2" + id "com.google.cloud.tools.jib" version "3.5.1" + // https://github.com/google/osdetector-gradle-plugin/tags id "com.google.osdetector" version "1.7.3" // https://github.com/google/protobuf-gradle-plugin/releases - id "com.google.protobuf" version "0.9.4" + id "com.google.protobuf" version "0.9.5" + // https://github.com/GradleUp/shadow/releases + // 8.3.2+ requires Java 11+ + // 8.3.1 breaks apache imports for netty/shaded, fixed in 8.3.2 + id "com.gradleup.shadow" version "8.3.0" // https://github.com/melix/japicmp-gradle-plugin/blob/master/CHANGELOG.txt - id "me.champeau.gradle.japicmp" version "0.4.1" + id "me.champeau.gradle.japicmp" version "0.4.2" // https://github.com/melix/jmh-gradle-plugin/releases - id "me.champeau.jmh" version "0.7.1" + id "me.champeau.jmh" version "0.7.3" // https://github.com/tbroyer/gradle-errorprone-plugin/releases - id "net.ltgt.errorprone" version "3.1.0" + id "net.ltgt.errorprone" version "4.3.0" // https://github.com/xvik/gradle-animalsniffer-plugin/releases - id "ru.vyarus.animalsniffer" version "1.7.1" + id "ru.vyarus.animalsniffer" version "2.0.1" } resolutionStrategy { eachPlugin { @@ -28,7 +46,7 @@ pluginManagement { useModule("com.android.tools.build:gradle:${target.version}") } if (requested.id.id.startsWith('com.google.cloud.tools.appengine')) { - useModule("com.google.cloud.tools:appengine-gradle-plugin:${requested.version}") + useModule("com.google.cloud.tools:appengine-gradle-plugin:${target.version}") } } } @@ -62,16 +80,19 @@ include ":grpc-benchmarks" include ":grpc-services" include ":grpc-servlet" include ":grpc-servlet-jakarta" +include ":grpc-s2a" include ":grpc-xds" include ":grpc-bom" include ":grpc-rls" include ":grpc-authz" +include ":grpc-gcp-csm-observability" include ":grpc-gcp-observability" include ":grpc-gcp-observability:interop" include ":grpc-istio-interop-testing" include ":grpc-inprocess" include ":grpc-util" include ":grpc-opentelemetry" +include ":grpc-context-override-opentelemetry" project(':grpc-api').projectDir = "$rootDir/api" as File project(':grpc-core').projectDir = "$rootDir/core" as File @@ -96,16 +117,19 @@ project(':grpc-benchmarks').projectDir = "$rootDir/benchmarks" as File project(':grpc-services').projectDir = "$rootDir/services" as File project(':grpc-servlet').projectDir = "$rootDir/servlet" as File project(':grpc-servlet-jakarta').projectDir = "$rootDir/servlet/jakarta" as File +project(':grpc-s2a').projectDir = "$rootDir/s2a" as File project(':grpc-xds').projectDir = "$rootDir/xds" as File project(':grpc-bom').projectDir = "$rootDir/bom" as File project(':grpc-rls').projectDir = "$rootDir/rls" as File project(':grpc-authz').projectDir = "$rootDir/authz" as File +project(':grpc-gcp-csm-observability').projectDir = "$rootDir/gcp-csm-observability" as File project(':grpc-gcp-observability').projectDir = "$rootDir/gcp-observability" as File project(':grpc-gcp-observability:interop').projectDir = "$rootDir/gcp-observability/interop" as File project(':grpc-istio-interop-testing').projectDir = "$rootDir/istio-interop-testing" as File project(':grpc-inprocess').projectDir = "$rootDir/inprocess" as File project(':grpc-util').projectDir = "$rootDir/util" as File project(':grpc-opentelemetry').projectDir = "$rootDir/opentelemetry" as File +project(':grpc-context-override-opentelemetry').projectDir = "$rootDir/contextstorage" as File if (settings.hasProperty('skipCodegen') && skipCodegen.toBoolean()) { println '*** Skipping the build of codegen and compilation of proto files because skipCodegen=true' diff --git a/stub/BUILD.bazel b/stub/BUILD.bazel index c65b01a23dc..f9188c27272 100644 --- a/stub/BUILD.bazel +++ b/stub/BUILD.bazel @@ -1,3 +1,6 @@ +load("@rules_java//java:defs.bzl", "java_library") +load("@rules_jvm_external//:defs.bzl", "artifact") + java_library( name = "stub", srcs = glob([ @@ -7,18 +10,9 @@ java_library( deps = [ "//api", "//context", - "@com_google_code_findbugs_jsr305//jar", - "@com_google_errorprone_error_prone_annotations//jar", - "@com_google_guava_guava//jar", - "@com_google_j2objc_j2objc_annotations//jar", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:guava"), + artifact("org.codehaus.mojo:animal-sniffer-annotations"), ], ) - -# javax.annotation.Generated is not included in the default root modules in 9, -# see: http://openjdk.java.net/jeps/320. -java_library( - name = "javax_annotation", - neverlink = 1, # @Generated is source-retention - visibility = ["//visibility:public"], - exports = ["@org_apache_tomcat_annotations_api//jar"], -) diff --git a/stub/build.gradle b/stub/build.gradle index 867936f3ea3..2dabd9e6202 100644 --- a/stub/build.gradle +++ b/stub/build.gradle @@ -16,14 +16,23 @@ tasks.named("jar").configure { dependencies { api project(':grpc-api'), + libraries.animalsniffer.annotations, libraries.guava implementation libraries.errorprone.annotations testImplementation libraries.truth, project(':grpc-inprocess'), project(':grpc-testing'), testFixtures(project(':grpc-api')) - signature libraries.signature.java - signature libraries.signature.android + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } + signature (libraries.signature.android) { + artifact { + extension = "signature" + } + } } tasks.named("javadoc").configure { diff --git a/stub/src/main/java/io/grpc/stub/AbstractAsyncStub.java b/stub/src/main/java/io/grpc/stub/AbstractAsyncStub.java index c6f912cb3a7..041f9ed08ed 100644 --- a/stub/src/main/java/io/grpc/stub/AbstractAsyncStub.java +++ b/stub/src/main/java/io/grpc/stub/AbstractAsyncStub.java @@ -16,11 +16,10 @@ package io.grpc.stub; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.stub.ClientCalls.StubType; -import javax.annotation.CheckReturnValue; -import javax.annotation.concurrent.ThreadSafe; /** * Stub implementations for async stubs. @@ -28,9 +27,10 @@ *

DO NOT MOCK: Customizing options doesn't work properly in mocks. Use InProcessChannelBuilder * to create a real channel suitable for testing. It is also possible to mock Channel instead. * + *

This class is thread-safe. + * * @since 1.26.0 */ -@ThreadSafe @CheckReturnValue public abstract class AbstractAsyncStub> extends AbstractStub { diff --git a/stub/src/main/java/io/grpc/stub/AbstractBlockingStub.java b/stub/src/main/java/io/grpc/stub/AbstractBlockingStub.java index 1cb919e67b0..49ecd1fca40 100644 --- a/stub/src/main/java/io/grpc/stub/AbstractBlockingStub.java +++ b/stub/src/main/java/io/grpc/stub/AbstractBlockingStub.java @@ -16,11 +16,10 @@ package io.grpc.stub; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.stub.ClientCalls.StubType; -import javax.annotation.CheckReturnValue; -import javax.annotation.concurrent.ThreadSafe; /** * Stub implementations for blocking stubs. @@ -28,9 +27,10 @@ *

DO NOT MOCK: Customizing options doesn't work properly in mocks. Use InProcessChannelBuilder * to create a real channel suitable for testing. It is also possible to mock Channel instead. * + *

This class is thread-safe. + * * @since 1.26.0 */ -@ThreadSafe @CheckReturnValue public abstract class AbstractBlockingStub> extends AbstractStub { diff --git a/stub/src/main/java/io/grpc/stub/AbstractFutureStub.java b/stub/src/main/java/io/grpc/stub/AbstractFutureStub.java index 66570bcd6ff..4aede0dcbbe 100644 --- a/stub/src/main/java/io/grpc/stub/AbstractFutureStub.java +++ b/stub/src/main/java/io/grpc/stub/AbstractFutureStub.java @@ -16,11 +16,10 @@ package io.grpc.stub; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.stub.ClientCalls.StubType; -import javax.annotation.CheckReturnValue; -import javax.annotation.concurrent.ThreadSafe; /** * Stub implementations for future stubs. @@ -28,9 +27,10 @@ *

DO NOT MOCK: Customizing options doesn't work properly in mocks. Use InProcessChannelBuilder * to create a real channel suitable for testing. It is also possible to mock Channel instead. * + *

This class is thread-safe. + * * @since 1.26.0 */ -@ThreadSafe @CheckReturnValue public abstract class AbstractFutureStub> extends AbstractStub { diff --git a/stub/src/main/java/io/grpc/stub/AbstractStub.java b/stub/src/main/java/io/grpc/stub/AbstractStub.java index efda8799d76..409f1e7ed53 100644 --- a/stub/src/main/java/io/grpc/stub/AbstractStub.java +++ b/stub/src/main/java/io/grpc/stub/AbstractStub.java @@ -17,7 +17,9 @@ package io.grpc.stub; import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.InternalTimeUtils.convert; +import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.CallCredentials; import io.grpc.CallOptions; import io.grpc.Channel; @@ -26,11 +28,11 @@ import io.grpc.Deadline; import io.grpc.ExperimentalApi; import io.grpc.ManagedChannelBuilder; +import java.time.Duration; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; -import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; -import javax.annotation.concurrent.ThreadSafe; +import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; /** * Common base type for stub implementations. Stub configuration is immutable; changing the @@ -43,10 +45,11 @@ *

DO NOT MOCK: Customizing options doesn't work properly in mocks. Use InProcessChannelBuilder * to create a real channel suitable for testing. It is also possible to mock Channel instead. * + *

This class is thread-safe. + * * @since 1.0.0 * @param the concrete type of this stub. */ -@ThreadSafe @CheckReturnValue public abstract class AbstractStub> { private final Channel channel; @@ -149,6 +152,12 @@ public final S withDeadlineAfter(long duration, TimeUnit unit) { return build(channel, callOptions.withDeadlineAfter(duration, unit)); } + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/11657") + @IgnoreJRERequirement + public final S withDeadlineAfter(Duration duration) { + return withDeadlineAfter(convert(duration), TimeUnit.NANOSECONDS); + } + /** * Returns a new stub with the given executor that is to be used instead of the default one * specified with {@link ManagedChannelBuilder#executor}. Note that setting this option may not @@ -252,6 +261,16 @@ public final S withMaxOutboundMessageSize(int maxSize) { return build(channel, callOptions.withMaxOutboundMessageSize(maxSize)); } + /** + * Returns a new stub that limits the maximum number of bytes per stream in the queue. + * + * @since 1.1.0 + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/11021") + public final S withOnReadyThreshold(int numBytes) { + return build(channel, callOptions.withOnReadyThreshold(numBytes)); + } + /** * A factory class for stub. * diff --git a/stub/src/main/java/io/grpc/stub/BlockingClientCall.java b/stub/src/main/java/io/grpc/stub/BlockingClientCall.java new file mode 100644 index 00000000000..6a52ce50776 --- /dev/null +++ b/stub/src/main/java/io/grpc/stub/BlockingClientCall.java @@ -0,0 +1,352 @@ +/* + * Copyright 2023 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.stub; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; +import io.grpc.ClientCall; +import io.grpc.ExperimentalApi; +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.StatusException; +import io.grpc.stub.ClientCalls.ThreadSafeThreadlessExecutor; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicReference; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Represents a bidirectional streaming call from a client. Allows in a blocking manner, sending + * over the stream and receiving from the stream. Also supports terminating the call. + * Wraps a ClientCall and converts from async communication to the sync paradigm used by the + * various blocking stream methods in {@link ClientCalls} which are used by the generated stubs. + * + *

Supports separate threads for reads and writes, but only 1 of each + * + *

Read methods consist of: + *

    + *
  • {@link #read()} + *
  • {@link #read(long timeout, TimeUnit unit)} + *
  • {@link #hasNext()} + *
  • {@link #cancel(String, Throwable)} + *
+ * + *

Write methods consist of: + *

    + *
  • {@link #write(Object)} + *
  • {@link #write(Object, long timeout, TimeUnit unit)} + *
  • {@link #halfClose()} + *
+ * + * @param Type of the Request Message + * @param Type of the Response Message + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") +public final class BlockingClientCall { + + private static final Logger logger = Logger.getLogger(BlockingClientCall.class.getName()); + + private final BlockingQueue buffer; + private final ClientCall call; + + private final ThreadSafeThreadlessExecutor executor; + + private boolean writeClosed; + private AtomicReference closeState = new AtomicReference<>(); + + BlockingClientCall(ClientCall call, ThreadSafeThreadlessExecutor executor) { + this.call = call; + this.executor = executor; + buffer = new ArrayBlockingQueue<>(1); + } + + /** + * Wait if necessary for a value to be available from the server. If there is an available value + * return it immediately, if the stream is closed return a null. Otherwise, wait for a value to be + * available or the stream to be closed + * + * @return value from server or null if stream has been closed + * @throws StatusException If the stream has closed in an error state + */ + public RespT read() throws InterruptedException, StatusException { + try { + return read(true, 0); + } catch (TimeoutException e) { + throw new AssertionError("should never happen", e); + } + } + + /** + * Wait with timeout, if necessary, for a value to be available from the server. If there is an + * available value, return it immediately. If the stream is closed return a null. Otherwise, wait + * for a value to be available, the stream to be closed or the timeout to expire. + * + * @param timeout how long to wait before giving up. Values <= 0 are no wait + * @param unit a TimeUnit determining how to interpret the timeout parameter + * @return value from server or null (if stream has been closed) + * @throws TimeoutException if no read becomes ready before the specified timeout expires + * @throws StatusException If the stream has closed in an error state + */ + public RespT read(long timeout, TimeUnit unit) throws InterruptedException, TimeoutException, + StatusException { + long endNanoTime = System.nanoTime() + unit.toNanos(timeout); + return read(false, endNanoTime); + } + + private RespT read(boolean waitForever, long endNanoTime) + throws InterruptedException, TimeoutException, StatusException { + Predicate> predicate = BlockingClientCall::skipWaitingForRead; + executor.waitAndDrainWithTimeout(waitForever, endNanoTime, predicate, this); + RespT bufferedValue = buffer.poll(); + + if (logger.isLoggable(Level.FINER)) { + logger.finer("Client Blocking read had value: " + bufferedValue); + } + + CloseState currentCloseState; + if (bufferedValue != null) { + call.request(1); + return bufferedValue; + } else if ((currentCloseState = closeState.get()) == null) { + throw new IllegalStateException( + "The message disappeared... are you reading from multiple threads?"); + } else if (!currentCloseState.status.isOk()) { + throw currentCloseState.status.asException(currentCloseState.trailers); + } else { + return null; + } + } + + boolean skipWaitingForRead() { + return closeState.get() != null || !buffer.isEmpty(); + } + + /** + * Wait for a value to be available from the server. If there is an + * available value, return true immediately. If the stream was closed with Status.OK, return + * false. If the stream was closed with an error status, throw a StatusException. Otherwise, wait + * for a value to be available or the stream to be closed. + * + * @return True when there is a value to read. Return false if stream closed cleanly. + * @throws StatusException If the stream was closed in an error state + */ + public boolean hasNext() throws InterruptedException, StatusException { + executor.waitAndDrain((x) -> !x.buffer.isEmpty() || x.closeState.get() != null, this); + + CloseState currentCloseState = closeState.get(); + if (currentCloseState != null && !currentCloseState.status.isOk()) { + throw currentCloseState.status.asException(currentCloseState.trailers); + } + + return !buffer.isEmpty(); + } + + /** + * Send a value to the stream for sending to server, wait if necessary for the grpc stream to be + * ready. + * + *

If write is not legal at the time of call, immediately returns false + * + *


NOTE: This method will return as soon as it passes the request to the grpc stream + * layer. It will not block while the message is being sent on the wire and returning true does + * not guarantee that the server gets the message. + * + *


WARNING: Doing only writes without reads can lead to deadlocks. This is because + * flow control, imposed by networks to protect intermediary routers and endpoints that are + * operating under resource constraints, requires reads to be done in order to progress writes. + * Furthermore, the server closing the stream will only be identified after + * the last sent value is read. + * + * @param request Message to send to the server + * @return true if the request is sent to stream, false if skipped + * @throws StatusException If the stream has closed in an error state + */ + public boolean write(ReqT request) throws InterruptedException, StatusException { + try { + return write(true, request, 0); + } catch (TimeoutException e) { + throw new RuntimeException(e); // should never happen + } + } + + /** + * Send a value to the stream for sending to server, wait if necessary for the grpc stream to be + * ready up to specified timeout. + * + *

If write is not legal at the time of call, immediately returns false + * + *


NOTE: This method will return as soon as it passes the request to the grpc stream + * layer. It will not block while the message is being sent on the wire and returning true does + * not guarantee that the server gets the message. + * + *


WARNING: Doing only writes without reads can lead to deadlocks as a result of + * flow control. Furthermore, the server closing the stream will only be identified after the + * last sent value is read. + * + * @param request Message to send to the server + * @param timeout How long to wait before giving up. Values <= 0 are no wait + * @param unit A TimeUnit determining how to interpret the timeout parameter + * @return true if the request is sent to stream, false if skipped + * @throws TimeoutException if write does not become ready before the specified timeout expires + * @throws StatusException If the stream has closed in an error state + */ + public boolean write(ReqT request, long timeout, TimeUnit unit) + throws InterruptedException, TimeoutException, StatusException { + long endNanoTime = System.nanoTime() + unit.toNanos(timeout); + return write(false, request, endNanoTime); + } + + private boolean write(boolean waitForever, ReqT request, long endNanoTime) + throws InterruptedException, TimeoutException, StatusException { + + if (writeClosed) { + throw new IllegalStateException("Writes cannot be done after calling halfClose or cancel"); + } + + Predicate> predicate = + (x) -> x.call.isReady() || x.closeState.get() != null; + executor.waitAndDrainWithTimeout(waitForever, endNanoTime, predicate, this); + CloseState savedCloseState = closeState.get(); + if (savedCloseState == null) { + call.sendMessage(request); + return true; + } else if (savedCloseState.status.isOk()) { + return false; + } else { + throw savedCloseState.status.asException(savedCloseState.trailers); + } + } + + void sendSingleRequest(ReqT request) { + call.sendMessage(request); + } + + /** + * Cancel stream and stop any further writes. Note that some reads that are in flight may still + * happen after the cancel. + * + * @param message if not {@code null}, will appear as the description of the CANCELLED status + * @param cause if not {@code null}, will appear as the cause of the CANCELLED status + */ + public void cancel(String message, Throwable cause) { + writeClosed = true; + call.cancel(message, cause); + } + + /** + * Indicate that no more writes will be done and the stream will be closed from the client side. + * + * @see ClientCall#halfClose() + */ + public void halfClose() { + if (writeClosed) { + throw new IllegalStateException( + "halfClose cannot be called after already half closed or cancelled"); + } + + writeClosed = true; + call.halfClose(); + } + + /** + * Status that server sent when closing channel from its side. + * + * @return null if stream not closed by server, otherwise Status sent by server + */ + @VisibleForTesting + Status getClosedStatus() { + executor.drain(); + CloseState state = closeState.get(); + return (state == null) ? null : state.status; + } + + /** + * Check for whether some action is ready. + * + * @return True if legal to write and writeOrRead can run without blocking + */ + @VisibleForTesting + boolean isEitherReadOrWriteReady() { + return (isWriteLegal() && isWriteReady()) || isReadReady(); + } + + /** + * Check whether there are any values waiting to be read. + * + * @return true if read will not block + */ + @VisibleForTesting + boolean isReadReady() { + executor.drain(); + + return !buffer.isEmpty(); + } + + /** + * Check that write hasn't been marked complete and stream is ready to receive a write (so will + * not block). + * + * @return true if legal to write and write will not block + */ + @VisibleForTesting + boolean isWriteReady() { + executor.drain(); + + return isWriteLegal() && call.isReady(); + } + + /** + * Check whether we'll ever be able to do writes or should terminate. + * @return True if writes haven't been closed and the server hasn't closed the stream + */ + private boolean isWriteLegal() { + return !writeClosed && closeState.get() == null; + } + + ClientCall.Listener getListener() { + return new QueuingListener(); + } + + private final class QueuingListener extends ClientCall.Listener { + @Override + public void onMessage(RespT value) { + Preconditions.checkState(closeState.get() == null, "ClientCall already closed"); + buffer.add(value); + } + + @Override + public void onClose(Status status, Metadata trailers) { + CloseState newCloseState = new CloseState(status, trailers); + boolean wasSet = closeState.compareAndSet(null, newCloseState); + Preconditions.checkState(wasSet, "ClientCall already closed"); + } + } + + private static final class CloseState { + final Status status; + final Metadata trailers; + + CloseState(Status status, Metadata trailers) { + this.status = Preconditions.checkNotNull(status, "status"); + this.trailers = trailers; + } + } +} diff --git a/stub/src/main/java/io/grpc/stub/ClientCalls.java b/stub/src/main/java/io/grpc/stub/ClientCalls.java index 13fb00d3b3e..ff2804a0a1f 100644 --- a/stub/src/main/java/io/grpc/stub/ClientCalls.java +++ b/stub/src/main/java/io/grpc/stub/ClientCalls.java @@ -22,12 +22,14 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; import com.google.common.base.Strings; import com.google.common.util.concurrent.AbstractFuture; import com.google.common.util.concurrent.ListenableFuture; import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.ClientCall; +import io.grpc.ExperimentalApi; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.Status; @@ -42,9 +44,14 @@ import java.util.concurrent.Executor; import java.util.concurrent.Future; import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.LockSupport; +import java.util.concurrent.locks.ReentrantLock; import java.util.logging.Level; import java.util.logging.Logger; +import javax.annotation.Nonnull; import javax.annotation.Nullable; /** @@ -175,6 +182,23 @@ public static RespT blockingUnaryCall( } } + /** + * Executes a unary call and blocks on the response, + * throws a checked {@link StatusException}. + * + * @return the single response message. + * @throws StatusException on error + */ + public static RespT blockingV2UnaryCall( + Channel channel, MethodDescriptor method, CallOptions callOptions, ReqT req) + throws StatusException { + try { + return blockingUnaryCall(channel, method, callOptions, req); + } catch (StatusRuntimeException e) { + throw e.getStatus().asException(e.getTrailers()); + } + } + /** * Executes a server-streaming call returning a blocking {@link Iterator} over the * response stream. The {@code call} should not be already started. After calling this method, @@ -184,7 +208,6 @@ public static RespT blockingUnaryCall( * * @return an iterator over the response stream. */ - // TODO(louiscryan): Not clear if we want to use this idiom for 'simple' stubs. public static Iterator blockingServerStreamingCall( ClientCall call, ReqT req) { BlockingResponseStream result = new BlockingResponseStream<>(call); @@ -194,11 +217,12 @@ public static Iterator blockingServerStreamingCall( /** * Executes a server-streaming call returning a blocking {@link Iterator} over the - * response stream. The {@code call} should not be already started. After calling this method, - * {@code call} should no longer be used. + * response stream. * *

The returned iterator may throw {@link StatusRuntimeException} on error. * + *

Warning: the iterator can result in leaks if not completely consumed. + * * @return an iterator over the response stream. */ public static Iterator blockingServerStreamingCall( @@ -211,6 +235,82 @@ public static Iterator blockingServerStreamingCall( return result; } + /** + * Initiates a client streaming call over the specified channel. It returns an + * object which can be used in a blocking manner to retrieve responses.. + * + *

The methods {@link BlockingClientCall#hasNext()} and {@link + * BlockingClientCall#cancel(String, Throwable)} can be used for more extensive control. + * + * @return A {@link BlockingClientCall} that has had the request sent and halfClose called + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public static BlockingClientCall blockingV2ServerStreamingCall( + Channel channel, MethodDescriptor method, CallOptions callOptions, ReqT req) { + BlockingClientCall call = + blockingBidiStreamingCall(channel, method, callOptions); + + call.sendSingleRequest(req); + call.halfClose(); + return call; + } + + /** + * Initiates a server streaming call and sends the specified request to the server. It returns an + * object which can be used in a blocking manner to retrieve values from the server. After the + * last value has been read, the next read call will return null. + * + *

Call {@link BlockingClientCall#read()} for + * retrieving values. A {@code null} will be returned after the server has closed the stream. + * + *

The methods {@link BlockingClientCall#hasNext()} and {@link + * BlockingClientCall#cancel(String, Throwable)} can be used for more extensive control. + * + *


Example usage: + *

 {@code  while ((response = call.read()) != null) { ... } } 
+ * or + *
 {@code
+   *   while (call.hasNext()) {
+   *     response = call.read();
+   *     ...
+   *   }
+   * } 
+ * + *

Note that this paradigm is different from the original + * {@link #blockingServerStreamingCall(Channel, MethodDescriptor, CallOptions, Object)} + * which returns an iterator, which would leave the stream open if not completely consumed. + * + * @return A {@link BlockingClientCall} which can be used by the client to write and receive + * messages over the grpc channel. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public static BlockingClientCall blockingClientStreamingCall( + Channel channel, MethodDescriptor method, CallOptions callOptions) { + return blockingBidiStreamingCall(channel, method, callOptions); + } + + /** + * Initiate a bidirectional-streaming {@link ClientCall} and returning a stream object + * ({@link BlockingClientCall}) which can be used by the client to send and receive messages over + * the grpc channel. + * + * @return an object representing the call which can be used to read, write and terminate it. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public static BlockingClientCall blockingBidiStreamingCall( + Channel channel, MethodDescriptor method, CallOptions callOptions) { + ThreadSafeThreadlessExecutor executor = new ThreadSafeThreadlessExecutor(); + ClientCall call = channel.newCall(method, callOptions.withExecutor(executor)); + + BlockingClientCall blockingClientCall = new BlockingClientCall<>(call, executor); + + // Get the call started + call.start(blockingClientCall.getListener(), new Metadata()); + call.request(1); + + return blockingClientCall; + } + /** * Executes a unary call and returns a {@link ListenableFuture} to the response. The * {@code call} should not be already started. After calling this method, {@code call} should no @@ -414,7 +514,7 @@ public void disableAutoRequestWithInitial(int request) { public void request(int count) { if (!streamingResponse && count == 1) { // Initially ask for two responses from flow-control so that if a misbehaving server - // sends more than one responses, we can catch it and fail it in the listener. + // sends more than one response, we can catch it and fail it in the listener. call.request(2); } else { call.request(count); @@ -637,7 +737,7 @@ public boolean hasNext() { public T next() { // Eagerly call request(1) so it can be processing the next message while we wait for the // current one, which reduces latency for the next message. With MigratingThreadDeframer and - // if the data has already been recieved, every other message can be delivered instantly. This + // if the data has already been received, every other message can be delivered instantly. This // can be run after hasNext(), but just would be slower. if (!(last instanceof StatusRuntimeException) && last != this) { call.request(1); @@ -726,6 +826,12 @@ public void waitAndDrain() throws InterruptedException { } while ((runnable = poll()) != null); } + private static void throwIfInterrupted() throws InterruptedException { + if (Thread.interrupted()) { + throw new InterruptedException(); + } + } + /** * Called after final call to {@link #waitAndDrain()}, from same thread. */ @@ -745,12 +851,6 @@ private static void runQuietly(Runnable runnable) { } } - private static void throwIfInterrupted() throws InterruptedException { - if (Thread.interrupted()) { - throw new InterruptedException(); - } - } - @Override public void execute(Runnable runnable) { add(runnable); @@ -763,6 +863,128 @@ public void execute(Runnable runnable) { } } + @SuppressWarnings("serial") + static final class ThreadSafeThreadlessExecutor extends ConcurrentLinkedQueue + implements Executor { + private static final Logger log = + Logger.getLogger(ThreadSafeThreadlessExecutor.class.getName()); + + private final Lock waiterLock = new ReentrantLock(); + private final Condition waiterCondition = waiterLock.newCondition(); + + // Non private to avoid synthetic class + ThreadSafeThreadlessExecutor() {} + + /** + * Waits until there is a Runnable, then executes it and all queued Runnables after it. + */ + public void waitAndDrain(Predicate predicate, T testTarget) throws InterruptedException { + try { + waitAndDrainWithTimeout(true, 0, predicate, testTarget); + } catch (TimeoutException e) { + throw new AssertionError(e); // Should never happen + } + } + + /** + * Waits for up to specified nanoseconds until there is a Runnable, then executes it and all + * queued Runnables after it. + * + *

his should always be called in a loop that checks whether the reason we are waiting has + * been satisfied.

T + * + * @param waitForever ignore the rest of the arguments and wait until there is a task to run + * @param end System.nanoTime() to stop waiting if haven't been woken up yet + * @param predicate non-null condition to test for skipping wake or waking up threads + * @param testTarget object to pass to predicate + */ + public void waitAndDrainWithTimeout(boolean waitForever, long end, + @Nonnull Predicate predicate, T testTarget) + throws InterruptedException, TimeoutException { + throwIfInterrupted(); + Runnable runnable; + + while (!predicate.apply(testTarget)) { + waiterLock.lock(); + try { + while ((runnable = poll()) == null) { + if (predicate.apply(testTarget)) { + return; // The condition for which we were waiting is now satisfied + } + + if (waitForever) { + waiterCondition.await(); + } else { + long waitNanos = end - System.nanoTime(); + if (waitNanos <= 0) { + throw new TimeoutException(); // Deadline is expired + } + waiterCondition.awaitNanos(waitNanos); + } + } + } finally { + waiterLock.unlock(); + } + + do { + runQuietly(runnable); + } while ((runnable = poll()) != null); + // Wake everything up now that we've done something and they can check in their outer loop + // if they can continue or need to wait again. + signalAll(); + } + } + + /** Executes all queued Runnables and if there were any wakes up any waiting threads. */ + void drain() { + Runnable runnable; + boolean didWork = false; + + while ((runnable = poll()) != null) { + runQuietly(runnable); + didWork = true; + } + + if (didWork) { + signalAll(); + } + } + + private void signalAll() { + waiterLock.lock(); + try { + waiterCondition.signalAll(); + } finally { + waiterLock.unlock(); + } + } + + private static void runQuietly(Runnable runnable) { + try { + runnable.run(); + } catch (Throwable t) { + log.log(Level.WARNING, "Runnable threw exception", t); + } + } + + private static void throwIfInterrupted() throws InterruptedException { + if (Thread.interrupted()) { + throw new InterruptedException(); + } + } + + @Override + public void execute(Runnable runnable) { + waiterLock.lock(); + try { + add(runnable); + waiterCondition.signalAll(); // If anything is waiting let it wake up and process this task + } finally { + waiterLock.unlock(); + } + } + } + enum StubType { BLOCKING, FUTURE, ASYNC } diff --git a/stub/src/main/java/io/grpc/stub/MetadataUtils.java b/stub/src/main/java/io/grpc/stub/MetadataUtils.java index addf54c0f81..4208d3ca652 100644 --- a/stub/src/main/java/io/grpc/stub/MetadataUtils.java +++ b/stub/src/main/java/io/grpc/stub/MetadataUtils.java @@ -22,10 +22,15 @@ import io.grpc.Channel; import io.grpc.ClientCall; import io.grpc.ClientInterceptor; +import io.grpc.ExperimentalApi; import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; +import io.grpc.ForwardingServerCall.SimpleForwardingServerCall; import io.grpc.Metadata; import io.grpc.MethodDescriptor; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; import io.grpc.Status; import java.util.concurrent.atomic.AtomicReference; @@ -143,4 +148,63 @@ public void onClose(Status status, Metadata trailers) { } } } + + /** + * Returns a ServerInterceptor that adds the specified Metadata to every response stream, one way + * or another. + * + *

If, absent this interceptor, a stream would have headers, 'extras' will be added to those + * headers. Otherwise, 'extras' will be sent as trailers. This pattern is useful when you have + * some fixed information, server identity say, that should be included no matter how the call + * turns out. The fallback to trailers avoids artificially committing clients to error responses + * that could otherwise be retried (see https://grpc.io/docs/guides/retry/ for more). + * + *

For correct operation, be sure to arrange for this interceptor to run *before* any others + * that might add headers. + * + * @param extras the Metadata to be added to each stream. Caller gives up ownership. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/11462") + public static ServerInterceptor newAttachMetadataServerInterceptor(Metadata extras) { + return new MetadataAttachingServerInterceptor(extras); + } + + private static final class MetadataAttachingServerInterceptor implements ServerInterceptor { + + private final Metadata extras; + + MetadataAttachingServerInterceptor(Metadata extras) { + this.extras = extras; + } + + @Override + public ServerCall.Listener interceptCall( + ServerCall call, Metadata headers, ServerCallHandler next) { + return next.startCall(new MetadataAttachingServerCall<>(call), headers); + } + + final class MetadataAttachingServerCall + extends SimpleForwardingServerCall { + boolean headersSent; + + MetadataAttachingServerCall(ServerCall delegate) { + super(delegate); + } + + @Override + public void sendHeaders(Metadata headers) { + headers.merge(extras); + headersSent = true; + super.sendHeaders(headers); + } + + @Override + public void close(Status status, Metadata trailers) { + if (!headersSent) { + trailers.merge(extras); + } + super.close(status, trailers); + } + } + } } diff --git a/stub/src/main/java/io/grpc/stub/ServerCallStreamObserver.java b/stub/src/main/java/io/grpc/stub/ServerCallStreamObserver.java index 8201a230546..6ffea3500cc 100644 --- a/stub/src/main/java/io/grpc/stub/ServerCallStreamObserver.java +++ b/stub/src/main/java/io/grpc/stub/ServerCallStreamObserver.java @@ -16,6 +16,8 @@ package io.grpc.stub; +import static com.google.common.base.Preconditions.checkArgument; + import io.grpc.ExperimentalApi; /** @@ -64,6 +66,21 @@ public abstract class ServerCallStreamObserver extends CallStreamObserver */ public abstract void setOnCancelHandler(Runnable onCancelHandler); + + /** + * A hint to the call that specifies how many bytes must be queued before + * {@link #isReady()} will return false. A call may ignore this property if + * unsupported. This may only be set during stream initialization before + * any messages are set. + * + * @param numBytes The number of bytes that must be queued. Must be a + * positive integer. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/11021") + public void setOnReadyThreshold(int numBytes) { + checkArgument(numBytes > 0, "numBytes must be positive: %s", numBytes); + } + /** * Sets the compression algorithm to use for the call. May only be called before sending any * messages. Default gRPC servers support the "gzip" compressor. diff --git a/stub/src/main/java/io/grpc/stub/ServerCalls.java b/stub/src/main/java/io/grpc/stub/ServerCalls.java index 83954af9670..9f0063713cc 100644 --- a/stub/src/main/java/io/grpc/stub/ServerCalls.java +++ b/stub/src/main/java/io/grpc/stub/ServerCalls.java @@ -382,9 +382,10 @@ public void onNext(RespT response) { @Override public void onError(Throwable t) { - Metadata metadata = Status.trailersFromThrowable(t); - if (metadata == null) { - metadata = new Metadata(); + Metadata metadata = new Metadata(); + Metadata trailers = Status.trailersFromThrowable(t); + if (trailers != null) { + metadata.merge(trailers); } call.close(Status.fromThrowable(t), metadata); aborted = true; @@ -395,7 +396,7 @@ public void onCompleted() { call.close(Status.OK, new Metadata()); completed = true; } - + @Override public boolean isReady() { return call.isReady(); @@ -422,6 +423,14 @@ public void setOnCancelHandler(Runnable onCancelHandler) { this.onCancelHandler = onCancelHandler; } + @Override + public void setOnReadyThreshold(int numBytes) { + checkState(!frozen, "Cannot alter setOnReadyThreshold after initialization. May only be " + + "called during the initial call to the application, before the service returns its " + + "StreamObserver"); + call.setOnReadyThreshold(numBytes); + } + @Override public void disableAutoInboundFlowControl() { disableAutoRequest(); diff --git a/stub/src/main/java/io/grpc/stub/StreamObservers.java b/stub/src/main/java/io/grpc/stub/StreamObservers.java index 2cc53ea0aa2..a421d3eca2f 100644 --- a/stub/src/main/java/io/grpc/stub/StreamObservers.java +++ b/stub/src/main/java/io/grpc/stub/StreamObservers.java @@ -23,12 +23,21 @@ /** * Utility functions for working with {@link StreamObserver} and it's common subclasses like * {@link CallStreamObserver}. - * - * @deprecated Of questionable utility and generally not used. */ -@Deprecated -@ExperimentalApi("https://github.com/grpc/grpc-java/issues/4694") public final class StreamObservers { + // Prevent instantiation + private StreamObservers() { } + + /** + * Utility method to call {@link StreamObserver#onNext(Object)} and + * {@link StreamObserver#onCompleted()} on the specified responseObserver. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/10957") + public static void nextAndComplete(StreamObserver responseObserver, T response) { + responseObserver.onNext(response); + responseObserver.onCompleted(); + } + /** * Copy the values of an {@link Iterator} to the target {@link CallStreamObserver} while properly * accounting for outbound flow-control. After calling this method, {@code target} should no @@ -40,7 +49,10 @@ public final class StreamObservers { * * @param source of values expressed as an {@link Iterator}. * @param target {@link CallStreamObserver} which accepts values from the source. + * @deprecated Of questionable utility and generally not used. */ + @Deprecated + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/4694") public static void copyWithFlowControl(final Iterator source, final CallStreamObserver target) { Preconditions.checkNotNull(source, "source"); @@ -80,7 +92,10 @@ public void run() { * * @param source of values expressed as an {@link Iterable}. * @param target {@link CallStreamObserver} which accepts values from the source. + * @deprecated Of questionable utility and generally not used. */ + @Deprecated + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/4694") public static void copyWithFlowControl(final Iterable source, CallStreamObserver target) { Preconditions.checkNotNull(source, "source"); diff --git a/stub/src/test/java/io/grpc/stub/AbstractStubTest.java b/stub/src/test/java/io/grpc/stub/AbstractStubTest.java index 9006b8679e4..352a2fb7fe2 100644 --- a/stub/src/test/java/io/grpc/stub/AbstractStubTest.java +++ b/stub/src/test/java/io/grpc/stub/AbstractStubTest.java @@ -16,12 +16,19 @@ package io.grpc.stub; +import static com.google.common.truth.Truth.assertAbout; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.testing.DeadlineSubject.deadline; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.MINUTES; import io.grpc.CallOptions; import io.grpc.Channel; +import io.grpc.Deadline; import io.grpc.stub.AbstractStub.StubFactory; import io.grpc.stub.AbstractStubTest.NoopStub; +import java.time.Duration; +import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -47,8 +54,23 @@ public NoopStub newStub(Channel channel, CallOptions callOptions) { .isNull(); } - class NoopStub extends AbstractStub { + @Test + @IgnoreJRERequirement + public void testDuration() { + NoopStub stub = NoopStub.newStub(new StubFactory() { + @Override + public NoopStub newStub(Channel channel, CallOptions callOptions) { + return create(channel, callOptions); + } + }, channel, CallOptions.DEFAULT); + NoopStub stubInstance = stub.withDeadlineAfter(Duration.ofMinutes(1L)); + Deadline actual = stubInstance.getCallOptions().getDeadline(); + Deadline expected = Deadline.after(1, MINUTES); + assertAbout(deadline()).that(actual).isWithin(10, MILLISECONDS).of(expected); + } + + class NoopStub extends AbstractStub { NoopStub(Channel channel, CallOptions options) { super(channel, options); } diff --git a/stub/src/test/java/io/grpc/stub/BaseAbstractStubTest.java b/stub/src/test/java/io/grpc/stub/BaseAbstractStubTest.java index cc5d5785449..9f7f10d8298 100644 --- a/stub/src/test/java/io/grpc/stub/BaseAbstractStubTest.java +++ b/stub/src/test/java/io/grpc/stub/BaseAbstractStubTest.java @@ -16,6 +16,7 @@ package io.grpc.stub; +import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; @@ -90,4 +91,16 @@ public void withExecutor() { assertEquals(callOptions.getExecutor(), executor); } + + @Test + public void withOnReadyThreshold() { + T stub = create(channel); + CallOptions callOptions = stub.getCallOptions(); + assertNull(callOptions.getOnReadyThreshold()); + + int onReadyThreshold = 1024; + stub = stub.withOnReadyThreshold(onReadyThreshold); + callOptions = stub.getCallOptions(); + assertThat(callOptions.getOnReadyThreshold()).isEqualTo(onReadyThreshold); + } } diff --git a/stub/src/test/java/io/grpc/stub/BlockingClientCallTest.java b/stub/src/test/java/io/grpc/stub/BlockingClientCallTest.java new file mode 100644 index 00000000000..e3a4f90e2c2 --- /dev/null +++ b/stub/src/test/java/io/grpc/stub/BlockingClientCallTest.java @@ -0,0 +1,499 @@ +/* + * Copyright 2023 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.stub; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import io.grpc.CallOptions; +import io.grpc.ManagedChannel; +import io.grpc.MethodDescriptor; +import io.grpc.MethodDescriptor.MethodType; +import io.grpc.Server; +import io.grpc.ServerServiceDefinition; +import io.grpc.ServiceDescriptor; +import io.grpc.Status; +import io.grpc.Status.Code; +import io.grpc.StatusException; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.stub.ServerCalls.BidiStreamingMethod; +import io.grpc.stub.ServerCallsTest.IntegerMarshaller; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.logging.Logger; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class BlockingClientCallTest { + private static final Logger logger = Logger.getLogger(BlockingClientCallTest.class.getName()); + + public static final int DELAY_MILLIS = 2000; + public static final long DELAY_NANOS = TimeUnit.MILLISECONDS.toNanos(DELAY_MILLIS); + private static final MethodDescriptor BIDI_STREAMING_METHOD = + MethodDescriptor.newBuilder() + .setType(MethodType.BIDI_STREAMING) + .setFullMethodName("some/method") + .setRequestMarshaller(new IntegerMarshaller()) + .setResponseMarshaller(new IntegerMarshaller()) + .build(); + + private Server server; + + private ManagedChannel channel; + + private IntegerTestMethod testMethod; + private BlockingClientCall biDiStream; + + @Before + public void setUp() throws Exception { + testMethod = new IntegerTestMethod(); + + ServerServiceDefinition service = ServerServiceDefinition.builder( + new ServiceDescriptor("some", BIDI_STREAMING_METHOD)) + .addMethod(BIDI_STREAMING_METHOD, ServerCalls.asyncBidiStreamingCall(testMethod)) + .build(); + long tag = System.nanoTime(); + + server = InProcessServerBuilder.forName("go-with-the-flow" + tag).directExecutor() + .addService(service).build().start(); + + channel = InProcessChannelBuilder.forName("go-with-the-flow" + tag).directExecutor().build(); + } + + @After + public void tearDown() { + if (server != null) { + server.shutdownNow(); + } + if (channel != null) { + channel.shutdownNow(); + } + if (biDiStream != null) { + biDiStream.cancel("In teardown", null); + } + } + + @Test + public void sanityTest() throws Exception { + Integer req = 2; + biDiStream = ClientCalls.blockingBidiStreamingCall(channel, BIDI_STREAMING_METHOD, + CallOptions.DEFAULT); + + // verify activity ready + assertTrue(biDiStream.isEitherReadOrWriteReady()); + assertTrue(biDiStream.isWriteReady()); + + // Have server send a value + testMethod.sendValueToClient(10); + + // Do a writeOrRead + biDiStream.write(req, 3, TimeUnit.SECONDS); + assertEquals(Integer.valueOf(10), biDiStream.read(DELAY_MILLIS, TimeUnit.MILLISECONDS)); + + // mark complete + biDiStream.halfClose(); + assertNull(biDiStream.read(2, TimeUnit.SECONDS)); + + // verify activity !ready and !writeable + assertFalse(biDiStream.isEitherReadOrWriteReady()); + assertFalse(biDiStream.isWriteReady()); + + assertEquals(Code.OK, biDiStream.getClosedStatus().getCode()); + } + + @Test + public void testReadSuccess_withoutBlocking() throws Exception { + biDiStream = ClientCalls.blockingBidiStreamingCall(channel, BIDI_STREAMING_METHOD, + CallOptions.DEFAULT); + + // Have server push a value + testMethod.sendValueToClient(11); + + long start = System.nanoTime(); + Integer value = biDiStream.read(100, TimeUnit.SECONDS); + assertNotNull(value); + long timeTaken = System.nanoTime() - start; + assertThat(timeTaken).isLessThan(TimeUnit.MILLISECONDS.toNanos(100)); + } + + @Test + public void testReadSuccess_withBlocking() throws Exception { + biDiStream = ClientCalls.blockingBidiStreamingCall(channel, BIDI_STREAMING_METHOD, + CallOptions.DEFAULT); + + try { + biDiStream.read(1, TimeUnit.SECONDS); + fail("Expected timeout"); + } catch (TimeoutException t) { + // ignore + } + + long start = System.nanoTime(); + delayedAddValue(DELAY_MILLIS, 12); + assertNotNull(biDiStream.read(DELAY_MILLIS * 2, TimeUnit.MILLISECONDS)); + long timeTaken = System.nanoTime() - start; + assertThat(timeTaken).isGreaterThan(DELAY_NANOS); + assertThat(timeTaken).isLessThan(DELAY_NANOS * 2); + + start = System.nanoTime(); + Integer[] values = {13, 14, 15, 16}; + delayedAddValue(DELAY_MILLIS, values); + for (Integer value : values) { + Integer readValue = biDiStream.read(DELAY_MILLIS * 2, TimeUnit.MILLISECONDS); + assertEquals(value, readValue); + } + timeTaken = System.nanoTime() - start; + assertThat(timeTaken).isLessThan(DELAY_NANOS * 2); + assertThat(timeTaken).isAtLeast(DELAY_NANOS); + + start = System.nanoTime(); + delayedVoidMethod(100, testMethod::halfClose); + assertNull(biDiStream.read(DELAY_MILLIS * 2, TimeUnit.MILLISECONDS)); + timeTaken = System.nanoTime() - start; + assertThat(timeTaken).isLessThan(DELAY_NANOS); + } + + @Test + public void testCancel() throws Exception { + testMethod.disableAutoRequest(); + biDiStream = ClientCalls.blockingBidiStreamingCall(channel, BIDI_STREAMING_METHOD, + CallOptions.DEFAULT); + + // read terminated + long start = System.currentTimeMillis(); + delayedCancel(biDiStream, "cancel read"); + try { + assertNull(biDiStream.read(2 * DELAY_MILLIS, TimeUnit.MILLISECONDS)); + fail("No exception thrown by read after cancel"); + } catch (StatusException e) { + assertEquals(Status.CANCELLED.getCode(), e.getStatus().getCode()); + assertThat(System.currentTimeMillis() - start).isLessThan(2 * DELAY_MILLIS); + } + + // after cancel tests + biDiStream = ClientCalls.blockingBidiStreamingCall(channel, BIDI_STREAMING_METHOD, + CallOptions.DEFAULT); + biDiStream.cancel("cancel write", new RuntimeException("Test requested close")); + + // Write after cancel should throw an exception + try { + start = System.currentTimeMillis(); + biDiStream.write(30); + fail("No exception doing write after cancel"); + } catch (IllegalStateException e) { + assertThat(System.currentTimeMillis() - start).isLessThan(200); + assertThat(e.getMessage()).contains("cancel"); + } + + // new read after cancel immediately throws an exception + try { + start = System.currentTimeMillis(); + assertNull(biDiStream.read(2, TimeUnit.SECONDS)); + } catch (StatusException e) { + assertEquals(Status.CANCELLED.getCode(), e.getStatus().getCode()); + assertThat(System.currentTimeMillis() - start).isLessThan(200); + } + + } + + @Test + public void testIsActivityReady() throws Exception { + biDiStream = ClientCalls.blockingBidiStreamingCall(channel, BIDI_STREAMING_METHOD, + CallOptions.DEFAULT); + + // write only ready + assertTrue(biDiStream.isEitherReadOrWriteReady()); + assertTrue(biDiStream.isWriteReady()); + assertFalse(biDiStream.isReadReady()); + + // both ready + testMethod.sendValueToClient(40); + assertTrue(biDiStream.isEitherReadOrWriteReady()); + assertTrue(biDiStream.isReadReady()); + assertTrue(biDiStream.isWriteReady()); + + // read only ready + biDiStream.halfClose(); + assertTrue(biDiStream.isEitherReadOrWriteReady()); + assertTrue(biDiStream.isReadReady()); + assertFalse(biDiStream.isWriteReady()); + + // Neither ready + assertNotNull(biDiStream.read(1, TimeUnit.MILLISECONDS)); + assertFalse(biDiStream.isEitherReadOrWriteReady()); + assertFalse(biDiStream.isReadReady()); + assertFalse(biDiStream.isWriteReady()); + } + + @Test + public void testWriteSuccess_withBlocking() throws Exception { + testMethod.disableAutoRequest(); + biDiStream = ClientCalls.blockingBidiStreamingCall(channel, BIDI_STREAMING_METHOD, + CallOptions.DEFAULT); + + assertFalse(biDiStream.isWriteReady()); + delayedWriteEnable(500); + assertTrue(biDiStream.write(40)); + + delayedWriteEnable(500); + assertTrue(biDiStream.write(41, 0, TimeUnit.NANOSECONDS)); + } + + + @Test + public void testReadNonblocking_whenWriteBlocked() throws Exception { + testMethod.disableAutoRequest(); + biDiStream = ClientCalls.blockingBidiStreamingCall(channel, BIDI_STREAMING_METHOD, + CallOptions.DEFAULT); + + // One value waiting + testMethod.sendValueToClient(50); + long start = System.currentTimeMillis(); + assertEquals(Integer.valueOf(50), biDiStream.read()); + assertThat(System.currentTimeMillis() - start).isLessThan(DELAY_MILLIS); + + // Two values waiting + start = System.currentTimeMillis(); + testMethod.sendValuesToClient(51, 52); + assertEquals(Integer.valueOf(51), biDiStream.read()); + assertEquals(Integer.valueOf(52), biDiStream.read()); + assertThat(System.currentTimeMillis() - start).isLessThan(DELAY_MILLIS); + } + + @Test + public void testReadsAndWritesInterleaved_withBlocking() throws Exception { + biDiStream = ClientCalls.blockingBidiStreamingCall(channel, BIDI_STREAMING_METHOD, + CallOptions.DEFAULT); + + Integer[] valuesOut = {1001, 10022, 1003}; + Integer[] valuesIn = new Integer[valuesOut.length]; + delayedAddValue(300, valuesOut); + int iteration = 0; + for (int i = 0; i < valuesOut.length && iteration++ < (20 + valuesOut.length); ) { + try { + if ((valuesIn[i] = biDiStream.read(50, TimeUnit.MILLISECONDS)) != null) { + i++; + } + } catch (TimeoutException e) { + logger.info("Read timed out for " + i); + } + } + assertArrayEquals(valuesOut, valuesIn); + } + + @Test + public void testReadsAndWritesInterleaved_BlockingWrites() throws Exception { + testMethod.disableAutoRequest(); + biDiStream = ClientCalls.blockingBidiStreamingCall(channel, BIDI_STREAMING_METHOD, + CallOptions.DEFAULT); + + testMethod.sendValuesToClient(10, 11, 12); + delayedWriteEnable(500); + long start = System.currentTimeMillis(); + boolean done = false; + int count = 0; + while (!done) { + count++; + if (!biDiStream.isWriteReady() && biDiStream.isReadReady()) { + biDiStream.read(100, TimeUnit.MILLISECONDS); + } else { + done = biDiStream.write(100, 1, TimeUnit.SECONDS); + } + } + assertEquals(4, count); + assertThat(System.currentTimeMillis() - start).isLessThan(700); + + testMethod.sendValuesToClient(20, 21, 22); + delayedWriteEnable(100); + while (!biDiStream.isWriteReady()) { + Thread.sleep(20); + } + + assertTrue(biDiStream.write(1000, 2 * DELAY_MILLIS, TimeUnit.MILLISECONDS)); + + assertEquals(Integer.valueOf(20), biDiStream.read(200, TimeUnit.MILLISECONDS)); + assertEquals(Integer.valueOf(21), biDiStream.read(200, TimeUnit.MILLISECONDS)); + assertEquals(Integer.valueOf(22), biDiStream.read(200, TimeUnit.MILLISECONDS)); + try { + Integer value = biDiStream.read(200, TimeUnit.MILLISECONDS); + fail("Unexpected read success instead of timeout. Value was: " + value); + } catch (TimeoutException ignore) { + // ignore since expected + } + } + + @Test + public void testWriteAfterCloseThrows() throws Exception { + testMethod.disableAutoRequest(); + biDiStream = ClientCalls.blockingBidiStreamingCall(channel, BIDI_STREAMING_METHOD, + CallOptions.DEFAULT); + + // verify new writes throw an illegalStateException + biDiStream.halfClose(); + try { + assertFalse(biDiStream.write(2)); + fail("write did not throw an exception when called after halfClose"); + } catch (IllegalStateException e) { + assertThat(e.getMessage()).containsMatch("after.*halfClose.*cancel"); + } + } + + @Test + public void testClose_withException() throws Exception { + biDiStream = ClientCalls.blockingBidiStreamingCall(channel, BIDI_STREAMING_METHOD, + CallOptions.DEFAULT); + + String descr = "too many small numbers"; + testMethod.sendError( + Status.FAILED_PRECONDITION.withDescription(descr).asRuntimeException()); + Status closedStatus = biDiStream.getClosedStatus(); + assertEquals(Code.FAILED_PRECONDITION, closedStatus.getCode()); + assertEquals(descr, closedStatus.getDescription()); + try { + assertFalse(biDiStream.write(1)); + } catch (StatusException e) { + assertThat(e.getMessage()).startsWith("FAILED_PRECONDITION"); + } + } + + private void delayedAddValue(int delayMillis, Integer... values) { + new Thread("delayedAddValue " + values.length) { + @Override + public void run() { + try { + Thread.sleep(delayMillis); + for (Integer cur : values) { + testMethod.sendValueToClient(cur); + } + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + }.start(); + } + + public interface Thunk { void apply(); } // supports passing void method w/out args + + private void delayedVoidMethod(int delayMillis, Thunk method) { + new Thread("delayedHalfClose") { + @Override + public void run() { + try { + Thread.sleep(delayMillis); + method.apply(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + }.start(); + } + + private void delayedWriteEnable(int delayMillis) { + delayedVoidMethod(delayMillis, testMethod::readValueFromClient); + } + + private void delayedCancel(BlockingClientCall biDiStream, String message) { + new Thread("delayedCancel") { + @Override + public void run() { + try { + Thread.sleep(BlockingClientCallTest.DELAY_MILLIS); + biDiStream.cancel(message, new RuntimeException("Test requested close")); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + }.start(); + } + + private static class IntegerTestMethod implements BidiStreamingMethod { + boolean autoRequest = true; + + void disableAutoRequest() { + assertNull("Can't disable auto request after invoke has been called", serverCallObserver); + autoRequest = false; + } + + ServerCallStreamObserver serverCallObserver; + + @Override + public StreamObserver invoke(StreamObserver responseObserver) { + serverCallObserver = (ServerCallStreamObserver) responseObserver; + if (!autoRequest) { + serverCallObserver.disableAutoRequest(); + } + + return new StreamObserver() { + @Override + public void onNext(Integer value) { + if (!autoRequest) { + serverCallObserver.request(1); + } + + // For testing ReqResp actions + if (value > 1000) { + serverCallObserver.onNext(value); + } + } + + @Override + public void onError(Throwable t) { + // no-op + } + + @Override + public void onCompleted() { + serverCallObserver.onCompleted(); + } + }; + } + + void readValueFromClient() { + serverCallObserver.request(1); + } + + void sendValueToClient(int value) { + serverCallObserver.onNext(value); + } + + private void sendValuesToClient(int ...values) { + for (int cur : values) { + sendValueToClient(cur); + } + } + + void halfClose() { + serverCallObserver.onCompleted(); + } + + void sendError(Throwable t) { + serverCallObserver.onError(t); + } + } + +} diff --git a/stub/src/test/java/io/grpc/stub/ClientCallsTest.java b/stub/src/test/java/io/grpc/stub/ClientCallsTest.java index 28801874ea1..b711b2a23b5 100644 --- a/stub/src/test/java/io/grpc/stub/ClientCallsTest.java +++ b/stub/src/test/java/io/grpc/stub/ClientCallsTest.java @@ -399,7 +399,7 @@ public void cancel(String message, Throwable cause) { future.get(); fail("Should fail"); } catch (CancellationException e) { - // Exepcted + // Expected } } @@ -971,8 +971,8 @@ public ClientCall interceptCall( } @Override public void halfClose() { - Thread.currentThread().interrupt(); super.halfClose(); + Thread.currentThread().interrupt(); } }; } diff --git a/stub/src/test/java/io/grpc/stub/MetadataUtilsTest.java b/stub/src/test/java/io/grpc/stub/MetadataUtilsTest.java new file mode 100644 index 00000000000..f9890ac0433 --- /dev/null +++ b/stub/src/test/java/io/grpc/stub/MetadataUtilsTest.java @@ -0,0 +1,175 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.stub; + +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.stub.MetadataUtils.newAttachMetadataServerInterceptor; +import static io.grpc.stub.MetadataUtils.newCaptureMetadataInterceptor; +import static org.junit.Assert.fail; + +import com.google.common.collect.ImmutableList; +import io.grpc.CallOptions; +import io.grpc.ManagedChannel; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptors; +import io.grpc.ServerMethodDefinition; +import io.grpc.ServerServiceDefinition; +import io.grpc.Status; +import io.grpc.Status.Code; +import io.grpc.StatusRuntimeException; +import io.grpc.StringMarshaller; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.testing.GrpcCleanupRule; +import java.io.IOException; +import java.util.Iterator; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class MetadataUtilsTest { + + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + + private static final String SERVER_NAME = "test"; + private static final Metadata.Key FOO_KEY = + Metadata.Key.of("foo-key", Metadata.ASCII_STRING_MARSHALLER); + + private final MethodDescriptor echoMethod = + MethodDescriptor.newBuilder(StringMarshaller.INSTANCE, StringMarshaller.INSTANCE) + .setFullMethodName("test/echo") + .setType(MethodDescriptor.MethodType.UNARY) + .build(); + + private final ServerCallHandler echoCallHandler = + ServerCalls.asyncUnaryCall( + (req, respObserver) -> { + respObserver.onNext(req); + respObserver.onCompleted(); + }); + + MethodDescriptor echoServerStreamingMethod = + MethodDescriptor.newBuilder(StringMarshaller.INSTANCE, StringMarshaller.INSTANCE) + .setFullMethodName("test/echoStream") + .setType(MethodDescriptor.MethodType.SERVER_STREAMING) + .build(); + + private final AtomicReference trailersCapture = new AtomicReference<>(); + private final AtomicReference headersCapture = new AtomicReference<>(); + + @Test + public void shouldAttachHeadersToResponse() throws IOException { + Metadata extras = new Metadata(); + extras.put(FOO_KEY, "foo-value"); + + ServerServiceDefinition serviceDef = + ServerInterceptors.intercept( + ServerServiceDefinition.builder("test").addMethod(echoMethod, echoCallHandler).build(), + ImmutableList.of(newAttachMetadataServerInterceptor(extras))); + + grpcCleanup.register(newInProcessServerBuilder().addService(serviceDef).build().start()); + ManagedChannel channel = + grpcCleanup.register( + newInProcessChannelBuilder() + .intercept(newCaptureMetadataInterceptor(headersCapture, trailersCapture)) + .build()); + + String response = + ClientCalls.blockingUnaryCall(channel, echoMethod, CallOptions.DEFAULT, "hello"); + assertThat(response).isEqualTo("hello"); + assertThat(trailersCapture.get() == null || !trailersCapture.get().containsKey(FOO_KEY)) + .isTrue(); + assertThat(headersCapture.get().get(FOO_KEY)).isEqualTo("foo-value"); + } + + @Test + public void shouldAttachTrailersWhenNoResponse() throws IOException { + Metadata extras = new Metadata(); + extras.put(FOO_KEY, "foo-value"); + + ServerServiceDefinition serviceDef = + ServerInterceptors.intercept( + ServerServiceDefinition.builder("test") + .addMethod( + ServerMethodDefinition.create( + echoServerStreamingMethod, + ServerCalls.asyncUnaryCall( + (req, respObserver) -> respObserver.onCompleted()))) + .build(), + ImmutableList.of(newAttachMetadataServerInterceptor(extras))); + grpcCleanup.register(newInProcessServerBuilder().addService(serviceDef).build().start()); + + ManagedChannel channel = + grpcCleanup.register( + newInProcessChannelBuilder() + .intercept(newCaptureMetadataInterceptor(headersCapture, trailersCapture)) + .build()); + + Iterator response = + ClientCalls.blockingServerStreamingCall( + channel, echoServerStreamingMethod, CallOptions.DEFAULT, "hello"); + assertThat(response.hasNext()).isFalse(); + assertThat(headersCapture.get() == null || !headersCapture.get().containsKey(FOO_KEY)).isTrue(); + assertThat(trailersCapture.get().get(FOO_KEY)).isEqualTo("foo-value"); + } + + @Test + public void shouldAttachTrailersToErrorResponse() throws IOException { + Metadata extras = new Metadata(); + extras.put(FOO_KEY, "foo-value"); + + ServerServiceDefinition serviceDef = + ServerInterceptors.intercept( + ServerServiceDefinition.builder("test") + .addMethod( + echoMethod, + ServerCalls.asyncUnaryCall( + (req, respObserver) -> + respObserver.onError(Status.INVALID_ARGUMENT.asRuntimeException()))) + .build(), + ImmutableList.of(newAttachMetadataServerInterceptor(extras))); + grpcCleanup.register(newInProcessServerBuilder().addService(serviceDef).build().start()); + + ManagedChannel channel = + grpcCleanup.register( + newInProcessChannelBuilder() + .intercept(newCaptureMetadataInterceptor(headersCapture, trailersCapture)) + .build()); + try { + ClientCalls.blockingUnaryCall(channel, echoMethod, CallOptions.DEFAULT, "hello"); + fail(); + } catch (StatusRuntimeException e) { + assertThat(e.getStatus()).isNotNull(); + assertThat(e.getStatus().getCode()).isEqualTo(Code.INVALID_ARGUMENT); + } + assertThat(headersCapture.get() == null || !headersCapture.get().containsKey(FOO_KEY)).isTrue(); + assertThat(trailersCapture.get().get(FOO_KEY)).isEqualTo("foo-value"); + } + + private static InProcessServerBuilder newInProcessServerBuilder() { + return InProcessServerBuilder.forName(SERVER_NAME).directExecutor(); + } + + private static InProcessChannelBuilder newInProcessChannelBuilder() { + return InProcessChannelBuilder.forName(SERVER_NAME).directExecutor(); + } +} diff --git a/stub/src/test/java/io/grpc/stub/ServerCallsTest.java b/stub/src/test/java/io/grpc/stub/ServerCallsTest.java index 7227d26c5b8..6f458facc5e 100644 --- a/stub/src/test/java/io/grpc/stub/ServerCallsTest.java +++ b/stub/src/test/java/io/grpc/stub/ServerCallsTest.java @@ -451,6 +451,31 @@ public void run() { assertEquals(2, onReadyCalled.get()); } + @Test + public void setOnReadyThreshold() throws Exception { + final int testThreshold = Integer.MAX_VALUE; + ServerCallHandler callHandler = + ServerCalls.asyncServerStreamingCall( + new ServerCalls.ServerStreamingMethod() { + @Override + public void invoke(Integer req, StreamObserver responseObserver) { + ServerCallStreamObserver serverCallObserver = + (ServerCallStreamObserver) responseObserver; + serverCallObserver.setOnReadyThreshold(req); + } + }); + ServerCall.Listener callListener = + callHandler.startCall(serverCall, new Metadata()); + serverCall.isReady = true; + serverCall.isCancelled = false; + callListener.onReady(); + callListener.onMessage(testThreshold); + // half-closing triggers the unary request delivery and onReady + callListener.onHalfClose(); + + assertEquals(testThreshold, serverCall.getOnReadyThreshold()); + } + @Test public void clientSendsOne_errorMissingRequest_unary() { ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD); @@ -530,6 +555,35 @@ public void invoke(Integer req, StreamObserver responseObserver) { listener.onHalfClose(); } + @Test + public void clientSendsOne_serverOnErrorWithTrailers_serverStreaming() { + Metadata trailers = new Metadata(); + Metadata.Key key = Metadata.Key.of("trailers-test-key1", + Metadata.ASCII_STRING_MARSHALLER); + trailers.put(key, "trailers-test-value1"); + + ServerCallRecorder serverCall = new ServerCallRecorder(SERVER_STREAMING_METHOD); + ServerCallHandler callHandler = ServerCalls.asyncServerStreamingCall( + new ServerCalls.ServerStreamingMethod() { + @Override + public void invoke(Integer req, StreamObserver responseObserver) { + responseObserver.onError( + Status.fromCode(Status.Code.INTERNAL) + .asRuntimeException(trailers) + ); + } + }); + ServerCall.Listener listener = callHandler.startCall(serverCall, new Metadata()); + serverCall.isReady = true; + serverCall.isCancelled = false; + listener.onReady(); + listener.onMessage(1); + listener.onHalfClose(); + // verify trailers key is set + assertTrue(serverCall.trailers.containsKey(key)); + assertTrue(serverCall.status.equals(Status.INTERNAL)); + } + @Test public void inprocessTransportManualFlow() throws Exception { final Semaphore semaphore = new Semaphore(1); @@ -626,6 +680,8 @@ private static class ServerCallRecorder extends ServerCall { private Status status; private boolean isCancelled; private boolean isReady; + private int onReadyThreshold; + private Metadata trailers; public ServerCallRecorder(MethodDescriptor methodDescriptor) { this.methodDescriptor = methodDescriptor; @@ -648,6 +704,7 @@ public void sendMessage(Integer message) { @Override public void close(Status status, Metadata trailers) { this.status = status; + this.trailers = trailers; } @Override @@ -660,9 +717,19 @@ public boolean isReady() { return isReady; } + @Override + public void setOnReadyThreshold(int numBytes) { + super.setOnReadyThreshold(numBytes); + onReadyThreshold = numBytes; + } + @Override public MethodDescriptor getMethodDescriptor() { return methodDescriptor; } + + public int getOnReadyThreshold() { + return onReadyThreshold; + } } } diff --git a/stub/src/test/java/io/grpc/stub/StreamObserversTest.java b/stub/src/test/java/io/grpc/stub/StreamObserversTest.java new file mode 100644 index 00000000000..237dd2e1434 --- /dev/null +++ b/stub/src/test/java/io/grpc/stub/StreamObserversTest.java @@ -0,0 +1,38 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.stub; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.InOrder; +import org.mockito.Mockito; + +@RunWith(JUnit4.class) +public class StreamObserversTest { + + @Test + public void nextAndComplete() { + @SuppressWarnings("unchecked") + StreamObserver observer = Mockito.mock(StreamObserver.class); + InOrder inOrder = Mockito.inOrder(observer); + StreamObservers.nextAndComplete(observer, "TEST"); + inOrder.verify(observer).onNext("TEST"); + inOrder.verify(observer).onCompleted(); + inOrder.verifyNoMoreInteractions(); + } +} diff --git a/testing-proto/BUILD.bazel b/testing-proto/BUILD.bazel new file mode 100644 index 00000000000..aa0fc9ee20b --- /dev/null +++ b/testing-proto/BUILD.bazel @@ -0,0 +1,22 @@ +load("@com_google_protobuf//bazel:java_proto_library.bzl", "java_proto_library") +load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library") +load("//:java_grpc_library.bzl", "java_grpc_library") + +proto_library( + name = "simpleservice_proto", + srcs = ["src/main/proto/io/grpc/testing/protobuf/simpleservice.proto"], + strip_import_prefix = "src/main/proto/", +) + +java_proto_library( + name = "simpleservice_java_proto", + visibility = ["//xds:__pkg__"], + deps = [":simpleservice_proto"], +) + +java_grpc_library( + name = "simpleservice_java_grpc", + srcs = [":simpleservice_proto"], + visibility = ["//xds:__pkg__"], + deps = [":simpleservice_java_proto"], +) diff --git a/testing-proto/build.gradle b/testing-proto/build.gradle index e6afce468f0..ee602bc5135 100644 --- a/testing-proto/build.gradle +++ b/testing-proto/build.gradle @@ -17,10 +17,12 @@ tasks.named("jar").configure { dependencies { api project(':grpc-protobuf'), project(':grpc-stub') - compileOnly libraries.javax.annotation testImplementation libraries.truth - testRuntimeOnly libraries.javax.annotation - signature libraries.signature.java + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } } configureProtoCompilation() diff --git a/testing-proto/src/generated/main/grpc/io/grpc/testing/protobuf/SimpleServiceGrpc.java b/testing-proto/src/generated/main/grpc/io/grpc/testing/protobuf/SimpleServiceGrpc.java index 8c58f2c5a2c..e242fd0f513 100644 --- a/testing-proto/src/generated/main/grpc/io/grpc/testing/protobuf/SimpleServiceGrpc.java +++ b/testing-proto/src/generated/main/grpc/io/grpc/testing/protobuf/SimpleServiceGrpc.java @@ -7,9 +7,6 @@ * A simple service for test. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: io/grpc/testing/protobuf/simpleservice.proto") @io.grpc.stub.annotations.GrpcGenerated public final class SimpleServiceGrpc { @@ -156,6 +153,21 @@ public SimpleServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions ca return SimpleServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static SimpleServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public SimpleServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new SimpleServiceBlockingV2Stub(channel, callOptions); + } + }; + return SimpleServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -318,6 +330,72 @@ public io.grpc.stub.StreamObserver bidiS * A simple service for test. * */ + public static final class SimpleServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private SimpleServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected SimpleServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new SimpleServiceBlockingV2Stub(channel, callOptions); + } + + /** + *

+     * Simple unary RPC.
+     * 
+ */ + public io.grpc.testing.protobuf.SimpleResponse unaryRpc(io.grpc.testing.protobuf.SimpleRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getUnaryRpcMethod(), getCallOptions(), request); + } + + /** + *
+     * Simple client-to-server streaming RPC.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + clientStreamingRpc() { + return io.grpc.stub.ClientCalls.blockingClientStreamingCall( + getChannel(), getClientStreamingRpcMethod(), getCallOptions()); + } + + /** + *
+     * Simple server-to-client streaming RPC.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + serverStreamingRpc(io.grpc.testing.protobuf.SimpleRequest request) { + return io.grpc.stub.ClientCalls.blockingV2ServerStreamingCall( + getChannel(), getServerStreamingRpcMethod(), getCallOptions(), request); + } + + /** + *
+     * Simple bidirectional streaming RPC.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + bidiStreamingRpc() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getBidiStreamingRpcMethod(), getCallOptions()); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service SimpleService. + *
+   * A simple service for test.
+   * 
+ */ public static final class SimpleServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private SimpleServiceBlockingStub( diff --git a/testing/BUILD.bazel b/testing/BUILD.bazel index 974cb32f752..d280ab97ee1 100644 --- a/testing/BUILD.bazel +++ b/testing/BUILD.bazel @@ -1,3 +1,6 @@ +load("@rules_java//java:defs.bzl", "java_library") +load("@rules_jvm_external//:defs.bzl", "artifact") + java_library( name = "testing", testonly = 1, @@ -12,12 +15,11 @@ java_library( "//api", "//context", "//inprocess", - "//util", "//stub", - "@com_google_code_findbugs_jsr305//jar", - "@com_google_guava_guava//jar", - "@com_google_j2objc_j2objc_annotations//jar", - "@com_google_truth_truth//jar", - "@junit_junit//jar", + "//util", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.guava:guava"), + artifact("com.google.truth:truth"), + artifact("junit:junit"), ], ) diff --git a/testing/build.gradle b/testing/build.gradle index a782a5fa1c6..b92e39279c6 100644 --- a/testing/build.gradle +++ b/testing/build.gradle @@ -15,7 +15,8 @@ dependencies { implementation project(':grpc-inprocess') implementation project(':grpc-core') // Only io.grpc.internal.testing.StatsTestUtils depends on opencensus_api, for internal use. - compileOnly libraries.opencensus.api + compileOnly libraries.opencensus.api, + project(":grpc-context") // Override opencensus dependency with our newer version runtimeOnly project(":grpc-api") // Pull in newer version than census-api testImplementation libraries.mockito.core @@ -23,8 +24,16 @@ dependencies { testImplementation project(':grpc-testing-proto'), testFixtures(project(':grpc-core')) - signature libraries.signature.java - signature libraries.signature.android + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } + signature (libraries.signature.android) { + artifact { + extension = "signature" + } + } } tasks.named("javadoc").configure { exclude 'io/grpc/internal/**' } diff --git a/testing/src/main/java/io/grpc/internal/testing/FakeNameResolverProvider.java b/testing/src/main/java/io/grpc/internal/testing/FakeNameResolverProvider.java index 52bbc8efb4c..c77f7f8945a 100644 --- a/testing/src/main/java/io/grpc/internal/testing/FakeNameResolverProvider.java +++ b/testing/src/main/java/io/grpc/internal/testing/FakeNameResolverProvider.java @@ -21,6 +21,7 @@ import io.grpc.NameResolver; import io.grpc.NameResolverProvider; import io.grpc.Status; +import io.grpc.StatusOr; import java.net.SocketAddress; import java.net.URI; import java.util.Collection; @@ -52,7 +53,7 @@ protected boolean isAvailable() { @Override protected int priority() { - return 5; // Default + return 10; // High priority } @Override @@ -81,9 +82,10 @@ public void start(Listener2 listener) { if (shutdown) { listener.onError(Status.FAILED_PRECONDITION.withDescription("Resolver is shutdown")); } else { - listener.onResult( + listener.onResult2( ResolutionResult.newBuilder() - .setAddresses(ImmutableList.of(new EquivalentAddressGroup(address))) + .setAddressesOrError( + StatusOr.fromValue(ImmutableList.of(new EquivalentAddressGroup(address)))) .build()); } } diff --git a/testing/src/main/java/io/grpc/internal/testing/StatsTestUtils.java b/testing/src/main/java/io/grpc/internal/testing/StatsTestUtils.java index cd525eeeeb9..a15559ed5cb 100644 --- a/testing/src/main/java/io/grpc/internal/testing/StatsTestUtils.java +++ b/testing/src/main/java/io/grpc/internal/testing/StatsTestUtils.java @@ -16,8 +16,8 @@ package io.grpc.internal.testing; -import static com.google.common.base.Charsets.UTF_8; import static com.google.common.base.Preconditions.checkNotNull; +import static java.nio.charset.StandardCharsets.UTF_8; import com.google.common.base.Function; import com.google.common.collect.ImmutableMap; diff --git a/testing/src/main/resources/certs/README b/testing/src/main/resources/certs/README index 1fa6b733950..13e375784c5 100644 --- a/testing/src/main/resources/certs/README +++ b/testing/src/main/resources/certs/README @@ -67,6 +67,35 @@ ecdsa.key is used to test keys with algorithm other than RSA: $ openssl ecparam -name secp256k1 -genkey -noout -out ecdsa.pem $ openssl pkcs8 -topk8 -in ecdsa.pem -out ecdsa.key -nocrypt +SPIFFE test credentials: +======================= + +The SPIFFE related extensions are listed in spiffe-openssl.cnf config. Both +client_spiffe.pem and server1_spiffe.pem are generated in the same way with +original client.pem and server1.pem but with using that config. Here are the +exact commands (we pass "-subj" as argument in this case): +---------------------- +$ openssl req -new -key client.key -out spiffe-cert.csr \ + -subj /C=US/ST=CA/L=SVL/O=gRPC/CN=testclient/ \ + -config spiffe-openssl.cnf -reqexts spiffe_client_e2e +$ openssl x509 -req -CA ca.pem -CAkey ca.key -CAcreateserial \ + -in spiffe-cert.csr -out client_spiffe.pem -extensions spiffe_client_e2e \ + -extfile spiffe-openssl.cnf -days 3650 -sha256 +$ openssl req -new -key server1.key -out spiffe-cert.csr \ + -subj /C=US/ST=CA/L=SVL/O=gRPC/CN=*.test.google.com/ \ + -config spiffe-openssl.cnf -reqexts spiffe_server_e2e +$ openssl x509 -req -CA ca.pem -CAkey ca.key -CAcreateserial \ + -in spiffe-cert.csr -out server1_spiffe.pem -extensions spiffe_server_e2e \ + -extfile spiffe-openssl.cnf -days 3650 -sha256 + +Additionally, SPIFFE trust bundle map files spiffebundle.json and \ +spiffebundle1.json are manually created for end to end testing. The \ +spiffebundle.json contains "example.com" trust domain (only this entry is used \ +in e2e tests) matching URI SAN of server1_spiffe.pem, and the CA certificate \ +there is ca.pem. The spiffebundle.json file contains "foo.bar.com" trust \ +domain (only this entry is used in e2e tests) matching URI SAN of \ +client_spiffe.pem, and the CA certificate there is also ca.pem. + Clean up: --------- $ rm *.rsa diff --git a/testing/src/main/resources/certs/client_spiffe.pem b/testing/src/main/resources/certs/client_spiffe.pem new file mode 100644 index 00000000000..c70981a4030 --- /dev/null +++ b/testing/src/main/resources/certs/client_spiffe.pem @@ -0,0 +1,25 @@ +-----BEGIN CERTIFICATE----- +MIIEMjCCAxqgAwIBAgIUVXGlXjNENtOZbI12epjgIhMaShUwDQYJKoZIhvcNAQEL +BQAwVjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEPMA0GA1UEAwwGdGVzdGNhMB4XDTI0 +MTAyNDE2NDAzN1oXDTM0MTAyMjE2NDAzN1owaDELMAkGA1UEBhMCQVUxEzARBgNV +BAgMClNvbWUtU3RhdGUxDDAKBgNVBAcMA1NWTDEhMB8GA1UECgwYSW50ZXJuZXQg +V2lkZ2l0cyBQdHkgTHRkMRMwEQYDVQQDDAp0ZXN0Y2xpZW50MIIBIjANBgkqhkiG +9w0BAQEFAAOCAQ8AMIIBCgKCAQEAsqmEafg11ae9jRW0B/IXYU2S8nGVzpSYZjLK +yZq459qe6SP/Jk2f9BQvkhlgRmVfhC4h65gl+c32iC6/SLsOxoa91c6Hn4vK+tqy +7qVTzYv6naso1pNnRAhwvWd/gINysyk8nq11oynL8ilZjNGcRNEV4Q1v0aEG6mbF +NhioNQdq4VFPCjdIFZip9KyRzsc0VUmHmC2KeWJ+yq7TyXCsqPWlbhK+3RgDc6ch +epYP52AVnPvUhsJKC3RbyrwAWCTMq2zYR1EH79H82mdD/OnX0xDaw8cwC68xp6nM +dyk68CY5Gf2kq9bcg9P7V77pERYj8VgSYYx0O9BqkxUGNfUW4QIDAQABo4HlMIHi +MEQGA1UdEQQ9MDuGOXNwaWZmZTovL2Zvby5iYXIuY29tLzllZWJjY2QyLTEyYmYt +NDBhNi1iMjYyLTY1ZmUwNDg3ZDQ1MzAdBgNVHQ4EFgQU28U8sUTGNEDyeCrvJDJd +AALabSMwewYDVR0jBHQwcqFapFgwVjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNv +bWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEPMA0G +A1UEAwwGdGVzdGNhghRas/RW8dzL4s/pS5g22Iv2AGEPmjANBgkqhkiG9w0BAQsF +AAOCAQEAE3LLE8GR283q/aE646SgAfltqpESP38NmYdJMdZgWRxbOqdWabYDfibt +9r8j+IRvVuuTWuH2eNS5wXJtS1BZ+z24wTLa+a2KjOV12gChP+3N7jhqId4eolSL +1fjscPY6luZP4Pm3D73lBvIoBvXpDGyrxleiUCEEkKXmTOA8doFvbrcbwm+yUJOP +VKUKvAzTNztb0BGDzKKU4E2yK5PSyv2n5m2NpzxYYfHoGeVcxvj7nCnSfoX/EWHb +d8ztJYDg9X0iNcfQXt7PZ+j6VcxfDpGCDxe2rFQoYvlWjhr3xOi/1e5A1zx1Ly07 +m9MB4hntu4e2656ZDWbgOHLpO0q1iQ== +-----END CERTIFICATE----- diff --git a/testing/src/main/resources/certs/server1_spiffe.pem b/testing/src/main/resources/certs/server1_spiffe.pem new file mode 100644 index 00000000000..76cb41d6922 --- /dev/null +++ b/testing/src/main/resources/certs/server1_spiffe.pem @@ -0,0 +1,26 @@ +-----BEGIN CERTIFICATE----- +MIIEZDCCA0ygAwIBAgIUVXGlXjNENtOZbI12epjgIhMaShMwDQYJKoZIhvcNAQEL +BQAwVjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEPMA0GA1UEAwwGdGVzdGNhMB4XDTI0 +MTAyMTAyMTQxNVoXDTM0MTAxOTAyMTQxNVowZTELMAkGA1UEBhMCVVMxETAPBgNV +BAgMCElsbGlub2lzMRAwDgYDVQQHDAdDaGljYWdvMRUwEwYDVQQKDAxFeGFtcGxl +LCBDby4xGjAYBgNVBAMMESoudGVzdC5nb29nbGUuY29tMIIBIjANBgkqhkiG9w0B +AQEFAAOCAQ8AMIIBCgKCAQEA5xOONxJJ8b8Qauvob5/7dPYZfIcd+uhAWL2ZlTPz +Qvu4oF0QI4iYgP5iGgry9zEtCM+YQS8UhiAlPlqa6ANxgiBSEyMHH/xE8lo/+caY +GeACqy640Jpl/JocFGo3xd1L8DCawjlaj6eu7T7T/tpAV2qq13b5710eNRbCAfFe +8yALiGQemx0IYhlZXNbIGWLBNhBhvVjJh7UvOqpADk4xtl8o5j0xgMIRg6WJGK6c +6ffSIg4eP1XmovNYZ9LLEJG68tF0Q/yIN43B4dt1oq4jzSdCbG4F1EiykT2TmwPV +YDi8tml6DfOCDGnit8svnMEmBv/fcPd31GSbXjF8M+KGGQIDAQABo4IBGTCCARUw +dwYDVR0RBHAwboIQKi50ZXN0Lmdvb2dsZS5mcoIYd2F0ZXJ6b29pLnRlc3QuZ29v +Z2xlLmJlghIqLnRlc3QueW91dHViZS5jb22HBMCoAQOGJnNwaWZmZTovL2V4YW1w +bGUuY29tL3dvcmtsb2FkLzllZWJjY2QyMB0GA1UdDgQWBBRvRpAYHQYP6dFPf5V7 +/MyCftnNjTB7BgNVHSMEdDByoVqkWDBWMQswCQYDVQQGEwJBVTETMBEGA1UECAwK +U29tZS1TdGF0ZTEhMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMQ8w +DQYDVQQDDAZ0ZXN0Y2GCFFqz9Fbx3Mviz+lLmDbYi/YAYQ+aMA0GCSqGSIb3DQEB +CwUAA4IBAQBJ1bnbBHa1n15vvhpGIzokuiJ+9q/zim63UuVDnkhrQM2N+RQbStGT +Tis2tNse1bh460dJFm6ArgHWogzx6fQZzgaDeCOAXvrAe4jM9IHr9K7lkq/33CZS +BDV+jCmm2sRsqSMkKUcX6JhyqWGFHuTDAKJzsEV2MlcswleKlGHDkeelAaxlLzpz +RHOSQd0N9xAs18lzx95SQEx90PtrBOmvIDDiI5o5z9Oz12Iy1toiksFl4jmknkDD +5VF3AyCRgN8NPW0uNC8D2vo4L+tgj9U6NPlmMOrjRsEH257LJ1wopAGr+yezkIId +QQodGSVm5cOuw/K7Ma4nBDjVJkjcdY3t +-----END CERTIFICATE----- diff --git a/testing/src/main/resources/certs/spiffe-openssl.cnf b/testing/src/main/resources/certs/spiffe-openssl.cnf new file mode 100644 index 00000000000..f03af40a782 --- /dev/null +++ b/testing/src/main/resources/certs/spiffe-openssl.cnf @@ -0,0 +1,28 @@ +[spiffe_client] +subjectAltName = @alt_names + +[spiffe_client_multi] +subjectAltName = @alt_names_multi + +[spiffe_server_e2e] +subjectAltName = @alt_names_server_e2e + +[spiffe_client_e2e] +subjectAltName = @alt_names_client_e2e + +[alt_names] +URI = spiffe://foo.bar.com/client/workload/1 + +[alt_names_multi] +URI.1 = spiffe://foo.bar.com/client/workload/1 +URI.2 = spiffe://foo.bar.com/client/workload/2 + +[alt_names_server_e2e] +DNS.1 = *.test.google.fr +DNS.2 = waterzooi.test.google.be +DNS.3 = *.test.youtube.com +IP.1 = "192.168.1.3" +URI = spiffe://example.com/workload/9eebccd2 + +[alt_names_client_e2e] +URI = spiffe://foo.bar.com/9eebccd2-12bf-40a6-b262-65fe0487d453 \ No newline at end of file diff --git a/testing/src/main/resources/certs/spiffe_cert.pem b/testing/src/main/resources/certs/spiffe_cert.pem new file mode 100644 index 00000000000..bc070042f69 --- /dev/null +++ b/testing/src/main/resources/certs/spiffe_cert.pem @@ -0,0 +1,33 @@ +-----BEGIN CERTIFICATE----- +MIIFsjCCA5qgAwIBAgIURygVMMzdr+Q7rsUaz189JozyHMwwDQYJKoZIhvcNAQEL +BQAwTjELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkNBMQwwCgYDVQQHDANTVkwxDTAL +BgNVBAoMBGdSUEMxFTATBgNVBAMMDHRlc3QtY2xpZW50MTAeFw0yMTEyMjMxODQy +NTJaFw0zMTEyMjExODQyNTJaME4xCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJDQTEM +MAoGA1UEBwwDU1ZMMQ0wCwYDVQQKDARnUlBDMRUwEwYDVQQDDAx0ZXN0LWNsaWVu +dDEwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQDJ4AqpGetyVSqGUuBJ +LVFla+7bEfca7UYzfVSSZLZ/X+JDmWIVN8UIPuFib5jhMEc3XaUnFXUmM7zEtz/Z +G5hapwLwOb2C3ZxOP6PQjYCJxbkLie+b43UQrFu1xxd3vMhVJgcj/AIxEpmszuqO +a6kUrkYifjJADQ+64kZgl66bsTdXMCzpxyFl9xUfff59L8OX+HUfAcoZz3emjg3Z +JPYURQEmjdZTOau1EjFilwHgd989Jt7NKgx30NXoHmw7nusVBIY94fL2VKN3f1XV +m0dHu5NI279Q6zr0ZBU7k5T3IeHnzsUesQS4NGlklDWoVTKk73Uv9Pna8yQsSW75 +7PEbHOGp9Knu4bnoGPOlsG81yIPipO6hTgGFK24pF97M9kpGbWqYX4+2vLlrCAfc +msHqaUPmQlYeRVTT6vw7ctYo2kyUYGtnODXk76LqewRBVvkzx75QUhfjAyb740Yc +DmIenc56Tq6gebJHjhEmVSehR6xIpXP7SVeurTyhPsEQnpJHtgs4dcwWOZp7BvPN +zHXmJqfr7vsshie3vS5kQ0u1e1yqAqXgyDjqKXOkx+dpgUTehSJHhPNHvTc5LXRs +vvXKYz6FrwR/DZ8t7BNEvPeLjFgxpH7QVJFLCvCbXs5K6yYbsnLfxFIBPRnrbJkI +sK+sQwnRdnsiUdPsTkG5B2lQfQIDAQABo4GHMIGEMB0GA1UdDgQWBBQ2lBp0PiRH +HvQ5IRURm8aHsj4RETAfBgNVHSMEGDAWgBQ2lBp0PiRHHvQ5IRURm8aHsj4RETAP +BgNVHRMBAf8EBTADAQH/MDEGA1UdEQQqMCiGJnNwaWZmZTovL2Zvby5iYXIuY29t +L2NsaWVudC93b3JrbG9hZC8xMA0GCSqGSIb3DQEBCwUAA4ICAQA1mSkgRclAl+E/ +aS9zJ7t8+Y4n3T24nOKKveSIjxXm/zjhWqVsLYBI6kglWtih2+PELvU8JdPqNZK3 +4Kl0Q6FWpVSGDdWN1i6NyORt2ocggL3ke3iXxRk3UpUKJmqwz81VhA2KUHnMlyE0 +IufFfZNwNWWHBv13uJfRbjeQpKPhU+yf4DeXrsWcvrZlGvAET+mcplafUzCp7Iv+ +PcISJtUerbxbVtuHVeZCLlgDXWkLAWJN8rf0dIG4x060LJ+j6j9uRVhb9sZn1HJV ++j4XdIYm1VKilluhOtNwP2d3Ox/JuTBxf7hFHXZPfMagQE5k5PzmxRaCAEMJ1l2D +vUbZw+shJfSNoWcBo2qadnUaWT3BmmJRBDh7ZReib/RQ1Rd4ygOyzP3E0vkV4/gq +yjLdApXh5PZP8KLQZ+1JN/sdWt7VfIt9wYOpkIqujdll51ESHzwQeAK9WVCB4UvV +z6zdhItB9CRbXPreWC+wCB1xDovIzFKOVsLs5+Gqs1m7VinG2LxbDqaKyo/FB0Hx +x0acBNzezLWoDwXYQrN0T0S4pnqhKD1CYPpdArBkNezUYAjS725FkApuK+mnBX3U +0msBffEaUEOkcyar1EW2m/33vpetD/k3eQQkmvQf4Hbiu9AF+9cNDm/hMuXEw5EX +GA91fn0891b5eEW8BJHXX0jri0aN8g== +-----END CERTIFICATE----- \ No newline at end of file diff --git a/testing/src/main/resources/certs/spiffe_multi_uri_san_cert.pem b/testing/src/main/resources/certs/spiffe_multi_uri_san_cert.pem new file mode 100644 index 00000000000..eb5c879abf8 --- /dev/null +++ b/testing/src/main/resources/certs/spiffe_multi_uri_san_cert.pem @@ -0,0 +1,25 @@ +-----BEGIN CERTIFICATE----- +MIIELTCCAxWgAwIBAgIUVXGlXjNENtOZbI12epjgIhMaShEwDQYJKoZIhvcNAQEL +BQAwVjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEPMA0GA1UEAwwGdGVzdGNhMB4XDTI0 +MDkxNzE2MTk0NFoXDTM0MDkxNTE2MTk0NFowTjELMAkGA1UEBhMCVVMxCzAJBgNV +BAgMAkNBMQwwCgYDVQQHDANTVkwxDTALBgNVBAoMBGdSUEMxFTATBgNVBAMMDHRl +c3QtY2xpZW50MTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAOcTjjcS +SfG/EGrr6G+f+3T2GXyHHfroQFi9mZUz80L7uKBdECOImID+YhoK8vcxLQjPmEEv +FIYgJT5amugDcYIgUhMjBx/8RPJaP/nGmBngAqsuuNCaZfyaHBRqN8XdS/AwmsI5 +Wo+nru0+0/7aQFdqqtd2+e9dHjUWwgHxXvMgC4hkHpsdCGIZWVzWyBliwTYQYb1Y +yYe1LzqqQA5OMbZfKOY9MYDCEYOliRiunOn30iIOHj9V5qLzWGfSyxCRuvLRdEP8 +iDeNweHbdaKuI80nQmxuBdRIspE9k5sD1WA4vLZpeg3zggxp4rfLL5zBJgb/33D3 +d9Rkm14xfDPihhkCAwEAAaOB+jCB9zBZBgNVHREEUjBQhiZzcGlmZmU6Ly9mb28u +YmFyLmNvbS9jbGllbnQvd29ya2xvYWQvMYYmc3BpZmZlOi8vZm9vLmJhci5jb20v +Y2xpZW50L3dvcmtsb2FkLzIwHQYDVR0OBBYEFG9GkBgdBg/p0U9/lXv8zIJ+2c2N +MHsGA1UdIwR0MHKhWqRYMFYxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0 +YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQxDzANBgNVBAMM +BnRlc3RjYYIUWrP0VvHcy+LP6UuYNtiL9gBhD5owDQYJKoZIhvcNAQELBQADggEB +AJ4Cbxv+02SpUgkEu4hP/1+8DtSBXUxNxI0VG4e3Ap2+Rhjm3YiFeS/UeaZhNrrw +UEjkSTPFODyXR7wI7UO9OO1StyD6CMkp3SEvevU5JsZtGL6mTiTLTi3Qkywa91Bt +GlyZdVMghA1bBJLBMwiD5VT5noqoJBD7hDy6v9yNmt1Sw2iYBJPqI3Gnf5bMjR3s +UICaxmFyqaMCZsPkfJh0DmZpInGJys3m4QqGz6ZE2DWgcSr1r/ML7/5bSPjjr8j4 +WFFSqFR3dMu8CbGnfZTCTXa4GTX/rARXbAO67Z/oJbJBK7VKayskL+PzKuohb9ox +jGL772hQMbwtFCOFXu5VP0s= +-----END CERTIFICATE----- \ No newline at end of file diff --git a/testing/src/main/resources/certs/spiffebundle.json b/testing/src/main/resources/certs/spiffebundle.json new file mode 100644 index 00000000000..5bc8fcfb432 --- /dev/null +++ b/testing/src/main/resources/certs/spiffebundle.json @@ -0,0 +1,101 @@ +{ + "trust_domains": { + "example.com": { + "spiffe_sequence": 12035488, + "keys": [ + { + "kty": "RSA", + "use": "x509-svid", + "x5c": ["MIIDWjCCAkKgAwIBAgIUWrP0VvHcy+LP6UuYNtiL9gBhD5owDQYJKoZIhvcNAQEL + BQAwVjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM + GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEPMA0GA1UEAwwGdGVzdGNhMB4XDTIw + MDMxNzE4NTk1MVoXDTMwMDMxNTE4NTk1MVowVjELMAkGA1UEBhMCQVUxEzARBgNV + BAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0 + ZDEPMA0GA1UEAwwGdGVzdGNhMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC + AQEAsGL0oXflF0LzoM+Bh+qUU9yhqzw2w8OOX5mu/iNCyUOBrqaHi7mGHx73GD01 + diNzCzvlcQqdNIH6NQSL7DTpBjca66jYT9u73vZe2MDrr1nVbuLvfu9850cdxiUO + Inv5xf8+sTHG0C+a+VAvMhsLiRjsq+lXKRJyk5zkbbsETybqpxoJ+K7CoSy3yc/k + QIY3TipwEtwkKP4hzyo6KiGd/DPexie4nBUInN3bS1BUeNZ5zeaIC2eg3bkeeW7c + qT55b+Yen6CxY0TEkzBK6AKt/WUialKMgT0wbTxRZO7kUCH3Sq6e/wXeFdJ+HvdV + LPlAg5TnMaNpRdQih/8nRFpsdwIDAQABoyAwHjAMBgNVHRMEBTADAQH/MA4GA1Ud + DwEB/wQEAwICBDANBgkqhkiG9w0BAQsFAAOCAQEAkTrKZjBrJXHps/HrjNCFPb5a + THuGPCSsepe1wkKdSp1h4HGRpLoCgcLysCJ5hZhRpHkRihhef+rFHEe60UePQO3S + CVTtdJB4CYWpcNyXOdqefrbJW5QNljxgi6Fhvs7JJkBqdXIkWXtFk2eRgOIP2Eo9 + /OHQHlYnwZFrk6sp4wPyR+A95S0toZBcyDVz7u+hOW0pGK3wviOe9lvRgj/H3Pwt + bewb0l+MhRig0/DVHamyVxrDRbqInU1/GTNCwcZkXKYFWSf92U+kIcTth24Q1gcw + eZiLl5FfrWokUNytFElXob0V0a5/kbhiLc3yWmvWqHTpqCALbVyF+rKJo2f5Kw=="], + "n": "", + "e": "AQAB" + } + ] + }, + "test.example.com": { + "keys": [ + { + "kty": "RSA", + "use": "x509-svid", + "x5c": ["MIIFsjCCA5qgAwIBAgIURygVMMzdr+Q7rsUaz189JozyHMwwDQYJKoZIhvcNAQEL + BQAwTjELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkNBMQwwCgYDVQQHDANTVkwxDTAL + BgNVBAoMBGdSUEMxFTATBgNVBAMMDHRlc3QtY2xpZW50MTAeFw0yMTEyMjMxODQy + NTJaFw0zMTEyMjExODQyNTJaME4xCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJDQTEM + MAoGA1UEBwwDU1ZMMQ0wCwYDVQQKDARnUlBDMRUwEwYDVQQDDAx0ZXN0LWNsaWVu + dDEwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQDJ4AqpGetyVSqGUuBJ + LVFla+7bEfca7UYzfVSSZLZ/X+JDmWIVN8UIPuFib5jhMEc3XaUnFXUmM7zEtz/Z + G5hapwLwOb2C3ZxOP6PQjYCJxbkLie+b43UQrFu1xxd3vMhVJgcj/AIxEpmszuqO + a6kUrkYifjJADQ+64kZgl66bsTdXMCzpxyFl9xUfff59L8OX+HUfAcoZz3emjg3Z + JPYURQEmjdZTOau1EjFilwHgd989Jt7NKgx30NXoHmw7nusVBIY94fL2VKN3f1XV + m0dHu5NI279Q6zr0ZBU7k5T3IeHnzsUesQS4NGlklDWoVTKk73Uv9Pna8yQsSW75 + 7PEbHOGp9Knu4bnoGPOlsG81yIPipO6hTgGFK24pF97M9kpGbWqYX4+2vLlrCAfc + msHqaUPmQlYeRVTT6vw7ctYo2kyUYGtnODXk76LqewRBVvkzx75QUhfjAyb740Yc + DmIenc56Tq6gebJHjhEmVSehR6xIpXP7SVeurTyhPsEQnpJHtgs4dcwWOZp7BvPN + zHXmJqfr7vsshie3vS5kQ0u1e1yqAqXgyDjqKXOkx+dpgUTehSJHhPNHvTc5LXRs + vvXKYz6FrwR/DZ8t7BNEvPeLjFgxpH7QVJFLCvCbXs5K6yYbsnLfxFIBPRnrbJkI + sK+sQwnRdnsiUdPsTkG5B2lQfQIDAQABo4GHMIGEMB0GA1UdDgQWBBQ2lBp0PiRH + HvQ5IRURm8aHsj4RETAfBgNVHSMEGDAWgBQ2lBp0PiRHHvQ5IRURm8aHsj4RETAP + BgNVHRMBAf8EBTADAQH/MDEGA1UdEQQqMCiGJnNwaWZmZTovL2Zvby5iYXIuY29t + L2NsaWVudC93b3JrbG9hZC8xMA0GCSqGSIb3DQEBCwUAA4ICAQA1mSkgRclAl+E/ + aS9zJ7t8+Y4n3T24nOKKveSIjxXm/zjhWqVsLYBI6kglWtih2+PELvU8JdPqNZK3 + 4Kl0Q6FWpVSGDdWN1i6NyORt2ocggL3ke3iXxRk3UpUKJmqwz81VhA2KUHnMlyE0 + IufFfZNwNWWHBv13uJfRbjeQpKPhU+yf4DeXrsWcvrZlGvAET+mcplafUzCp7Iv+ + PcISJtUerbxbVtuHVeZCLlgDXWkLAWJN8rf0dIG4x060LJ+j6j9uRVhb9sZn1HJV + +j4XdIYm1VKilluhOtNwP2d3Ox/JuTBxf7hFHXZPfMagQE5k5PzmxRaCAEMJ1l2D + vUbZw+shJfSNoWcBo2qadnUaWT3BmmJRBDh7ZReib/RQ1Rd4ygOyzP3E0vkV4/gq + yjLdApXh5PZP8KLQZ+1JN/sdWt7VfIt9wYOpkIqujdll51ESHzwQeAK9WVCB4UvV + z6zdhItB9CRbXPreWC+wCB1xDovIzFKOVsLs5+Gqs1m7VinG2LxbDqaKyo/FB0Hx + x0acBNzezLWoDwXYQrN0T0S4pnqhKD1CYPpdArBkNezUYAjS725FkApuK+mnBX3U + 0msBffEaUEOkcyar1EW2m/33vpetD/k3eQQkmvQf4Hbiu9AF+9cNDm/hMuXEw5EX + GA91fn0891b5eEW8BJHXX0jri0aN8g=="], + "n": "", + "e": "AQAB" + }, + { + "kty": "RSA", + "use": "x509-svid", + "x5c": ["MIIELTCCAxWgAwIBAgIUVXGlXjNENtOZbI12epjgIhMaShEwDQYJKoZIhvcNAQEL + BQAwVjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM + GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEPMA0GA1UEAwwGdGVzdGNhMB4XDTI0 + MDkxNzE2MTk0NFoXDTM0MDkxNTE2MTk0NFowTjELMAkGA1UEBhMCVVMxCzAJBgNV + BAgMAkNBMQwwCgYDVQQHDANTVkwxDTALBgNVBAoMBGdSUEMxFTATBgNVBAMMDHRl + c3QtY2xpZW50MTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAOcTjjcS + SfG/EGrr6G+f+3T2GXyHHfroQFi9mZUz80L7uKBdECOImID+YhoK8vcxLQjPmEEv + FIYgJT5amugDcYIgUhMjBx/8RPJaP/nGmBngAqsuuNCaZfyaHBRqN8XdS/AwmsI5 + Wo+nru0+0/7aQFdqqtd2+e9dHjUWwgHxXvMgC4hkHpsdCGIZWVzWyBliwTYQYb1Y + yYe1LzqqQA5OMbZfKOY9MYDCEYOliRiunOn30iIOHj9V5qLzWGfSyxCRuvLRdEP8 + iDeNweHbdaKuI80nQmxuBdRIspE9k5sD1WA4vLZpeg3zggxp4rfLL5zBJgb/33D3 + d9Rkm14xfDPihhkCAwEAAaOB+jCB9zBZBgNVHREEUjBQhiZzcGlmZmU6Ly9mb28u + YmFyLmNvbS9jbGllbnQvd29ya2xvYWQvMYYmc3BpZmZlOi8vZm9vLmJhci5jb20v + Y2xpZW50L3dvcmtsb2FkLzIwHQYDVR0OBBYEFG9GkBgdBg/p0U9/lXv8zIJ+2c2N + MHsGA1UdIwR0MHKhWqRYMFYxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0 + YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQxDzANBgNVBAMM + BnRlc3RjYYIUWrP0VvHcy+LP6UuYNtiL9gBhD5owDQYJKoZIhvcNAQELBQADggEB + AJ4Cbxv+02SpUgkEu4hP/1+8DtSBXUxNxI0VG4e3Ap2+Rhjm3YiFeS/UeaZhNrrw + UEjkSTPFODyXR7wI7UO9OO1StyD6CMkp3SEvevU5JsZtGL6mTiTLTi3Qkywa91Bt + GlyZdVMghA1bBJLBMwiD5VT5noqoJBD7hDy6v9yNmt1Sw2iYBJPqI3Gnf5bMjR3s + UICaxmFyqaMCZsPkfJh0DmZpInGJys3m4QqGz6ZE2DWgcSr1r/ML7/5bSPjjr8j4 + WFFSqFR3dMu8CbGnfZTCTXa4GTX/rARXbAO67Z/oJbJBK7VKayskL+PzKuohb9ox + jGL772hQMbwtFCOFXu5VP0s="] + } + ] + } + } +} \ No newline at end of file diff --git a/testing/src/main/resources/certs/spiffebundle1.json b/testing/src/main/resources/certs/spiffebundle1.json new file mode 100644 index 00000000000..f79af09a3e7 --- /dev/null +++ b/testing/src/main/resources/certs/spiffebundle1.json @@ -0,0 +1,59 @@ +{ + "trust_domains": { + "example.com": { + "spiffe_sequence": 12035488, + "keys": [ + { + "kty": "RSA", + "use": "x509-svid", + "x5c": ["MIIDWjCCAkKgAwIBAgIUWrP0VvHcy+LP6UuYNtiL9gBhD5owDQYJKoZIhvcNAQEL + BQAwVjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM + GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEPMA0GA1UEAwwGdGVzdGNhMB4XDTIw + MDMxNzE4NTk1MVoXDTMwMDMxNTE4NTk1MVowVjELMAkGA1UEBhMCQVUxEzARBgNV + BAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0 + ZDEPMA0GA1UEAwwGdGVzdGNhMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC + AQEAsGL0oXflF0LzoM+Bh+qUU9yhqzw2w8OOX5mu/iNCyUOBrqaHi7mGHx73GD01 + diNzCzvlcQqdNIH6NQSL7DTpBjca66jYT9u73vZe2MDrr1nVbuLvfu9850cdxiUO + Inv5xf8+sTHG0C+a+VAvMhsLiRjsq+lXKRJyk5zkbbsETybqpxoJ+K7CoSy3yc/k + QIY3TipwEtwkKP4hzyo6KiGd/DPexie4nBUInN3bS1BUeNZ5zeaIC2eg3bkeeW7c + qT55b+Yen6CxY0TEkzBK6AKt/WUialKMgT0wbTxRZO7kUCH3Sq6e/wXeFdJ+HvdV + LPlAg5TnMaNpRdQih/8nRFpsdwIDAQABoyAwHjAMBgNVHRMEBTADAQH/MA4GA1Ud + DwEB/wQEAwICBDANBgkqhkiG9w0BAQsFAAOCAQEAkTrKZjBrJXHps/HrjNCFPb5a + THuGPCSsepe1wkKdSp1h4HGRpLoCgcLysCJ5hZhRpHkRihhef+rFHEe60UePQO3S + CVTtdJB4CYWpcNyXOdqefrbJW5QNljxgi6Fhvs7JJkBqdXIkWXtFk2eRgOIP2Eo9 + /OHQHlYnwZFrk6sp4wPyR+A95S0toZBcyDVz7u+hOW0pGK3wviOe9lvRgj/H3Pwt + bewb0l+MhRig0/DVHamyVxrDRbqInU1/GTNCwcZkXKYFWSf92U+kIcTth24Q1gcw + eZiLl5FfrWokUNytFElXob0V0a5/kbhiLc3yWmvWqHTpqCALbVyF+rKJo2f5Kw=="], + "n": "", + "e": "AQAB" + } + ] + }, + "foo.bar.com": { + "keys": [ + { + "kty": "RSA", + "use": "x509-svid", + "x5c": ["MIIDWjCCAkKgAwIBAgIUWrP0VvHcy+LP6UuYNtiL9gBhD5owDQYJKoZIhvcNAQEL + BQAwVjELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM + GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEPMA0GA1UEAwwGdGVzdGNhMB4XDTIw + MDMxNzE4NTk1MVoXDTMwMDMxNTE4NTk1MVowVjELMAkGA1UEBhMCQVUxEzARBgNV + BAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0 + ZDEPMA0GA1UEAwwGdGVzdGNhMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC + AQEAsGL0oXflF0LzoM+Bh+qUU9yhqzw2w8OOX5mu/iNCyUOBrqaHi7mGHx73GD01 + diNzCzvlcQqdNIH6NQSL7DTpBjca66jYT9u73vZe2MDrr1nVbuLvfu9850cdxiUO + Inv5xf8+sTHG0C+a+VAvMhsLiRjsq+lXKRJyk5zkbbsETybqpxoJ+K7CoSy3yc/k + QIY3TipwEtwkKP4hzyo6KiGd/DPexie4nBUInN3bS1BUeNZ5zeaIC2eg3bkeeW7c + qT55b+Yen6CxY0TEkzBK6AKt/WUialKMgT0wbTxRZO7kUCH3Sq6e/wXeFdJ+HvdV + LPlAg5TnMaNpRdQih/8nRFpsdwIDAQABoyAwHjAMBgNVHRMEBTADAQH/MA4GA1Ud + DwEB/wQEAwICBDANBgkqhkiG9w0BAQsFAAOCAQEAkTrKZjBrJXHps/HrjNCFPb5a + THuGPCSsepe1wkKdSp1h4HGRpLoCgcLysCJ5hZhRpHkRihhef+rFHEe60UePQO3S + CVTtdJB4CYWpcNyXOdqefrbJW5QNljxgi6Fhvs7JJkBqdXIkWXtFk2eRgOIP2Eo9 + /OHQHlYnwZFrk6sp4wPyR+A95S0toZBcyDVz7u+hOW0pGK3wviOe9lvRgj/H3Pwt + bewb0l+MhRig0/DVHamyVxrDRbqInU1/GTNCwcZkXKYFWSf92U+kIcTth24Q1gcw + eZiLl5FfrWokUNytFElXob0V0a5/kbhiLc3yWmvWqHTpqCALbVyF+rKJo2f5Kw=="] + } + ] + } + } +} \ No newline at end of file diff --git a/testing/src/test/java/io/grpc/testing/GrpcCleanupRuleTest.java b/testing/src/test/java/io/grpc/testing/GrpcCleanupRuleTest.java index a5a6783d53f..8eb3edd3825 100644 --- a/testing/src/test/java/io/grpc/testing/GrpcCleanupRuleTest.java +++ b/testing/src/test/java/io/grpc/testing/GrpcCleanupRuleTest.java @@ -18,6 +18,7 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; @@ -35,9 +36,7 @@ import io.grpc.internal.FakeClock; import io.grpc.testing.GrpcCleanupRule.Resource; import java.util.concurrent.TimeUnit; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.junit.runners.model.MultipleFailureException; @@ -51,10 +50,6 @@ public class GrpcCleanupRuleTest { public static final FakeClock fakeClock = new FakeClock(); - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public ExpectedException thrown = ExpectedException.none(); - @Test public void registerChannelReturnSameChannel() { ManagedChannel channel = mock(ManagedChannel.class); @@ -72,10 +67,9 @@ public void registerNullChannelThrowsNpe() { ManagedChannel channel = null; GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); - thrown.expect(NullPointerException.class); - thrown.expectMessage("channel"); - - grpcCleanup.register(channel); + NullPointerException e = assertThrows(NullPointerException.class, + () -> grpcCleanup.register(channel)); + assertThat(e).hasMessageThat().isEqualTo("channel"); } @Test @@ -83,10 +77,9 @@ public void registerNullServerThrowsNpe() { Server server = null; GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); - thrown.expect(NullPointerException.class); - thrown.expectMessage("server"); - - grpcCleanup.register(server); + NullPointerException e = assertThrows(NullPointerException.class, + () -> grpcCleanup.register(server)); + assertThat(e).hasMessageThat().isEqualTo("server"); } @Test diff --git a/util/BUILD.bazel b/util/BUILD.bazel index b95e428f435..32d5a367b95 100644 --- a/util/BUILD.bazel +++ b/util/BUILD.bazel @@ -1,3 +1,6 @@ +load("@rules_java//java:defs.bzl", "java_library") +load("@rules_jvm_external//:defs.bzl", "artifact") + java_library( name = "util", srcs = glob([ @@ -10,9 +13,9 @@ java_library( deps = [ "//api", "//core:internal", - "@com_google_code_findbugs_jsr305//jar", - "@com_google_guava_guava//jar", - "@com_google_j2objc_j2objc_annotations//jar", - "@org_codehaus_mojo_animal_sniffer_annotations//jar", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:guava"), + artifact("org.codehaus.mojo:animal-sniffer-annotations"), ], ) diff --git a/util/build.gradle b/util/build.gradle index 932ca66883e..846b110b106 100644 --- a/util/build.gradle +++ b/util/build.gradle @@ -35,8 +35,16 @@ dependencies { project(':grpc-testing') jmh project(':grpc-testing') - signature libraries.signature.java - signature libraries.signature.android + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } + signature (libraries.signature.android) { + artifact { + extension = "signature" + } + } } animalsniffer { @@ -50,6 +58,7 @@ animalsniffer { tasks.named("javadoc").configure { exclude 'io/grpc/util/MultiChildLoadBalancer.java' exclude 'io/grpc/util/OutlierDetectionLoadBalancer*' + exclude 'io/grpc/util/RandomSubsettingLoadBalancer*' exclude 'io/grpc/util/RoundRobinLoadBalancer*' } diff --git a/util/src/main/java/io/grpc/util/AdvancedTlsX509KeyManager.java b/util/src/main/java/io/grpc/util/AdvancedTlsX509KeyManager.java index 1530834d609..eea664f2ad4 100644 --- a/util/src/main/java/io/grpc/util/AdvancedTlsX509KeyManager.java +++ b/util/src/main/java/io/grpc/util/AdvancedTlsX509KeyManager.java @@ -18,6 +18,7 @@ import static com.google.common.base.Preconditions.checkNotNull; +import com.google.errorprone.annotations.InlineMe; import io.grpc.ExperimentalApi; import java.io.File; import java.io.FileInputStream; @@ -26,12 +27,12 @@ import java.security.GeneralSecurityException; import java.security.Principal; import java.security.PrivateKey; -import java.security.cert.CertificateException; import java.security.cert.X509Certificate; import java.util.Arrays; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import java.util.logging.Level; import java.util.logging.Logger; import javax.net.ssl.SSLEngine; @@ -39,65 +40,87 @@ /** * AdvancedTlsX509KeyManager is an {@code X509ExtendedKeyManager} that allows users to configure - * advanced TLS features, such as private key and certificate chain reloading, etc. + * advanced TLS features, such as private key and certificate chain reloading. + * + *

The alias increments on every credential load (e.g. {@code "key-1"}, {@code "key-2"}, ...), + * so the same alias always maps to the same key material. The previous alias is retained for one + * rotation to allow in-progress handshakes to complete, ensuring alias-to-key-material consistency + * across credential reloads. */ -@ExperimentalApi("https://github.com/grpc/grpc-java/issues/8024") public final class AdvancedTlsX509KeyManager extends X509ExtendedKeyManager { private static final Logger log = Logger.getLogger(AdvancedTlsX509KeyManager.class.getName()); + // Minimum allowed period for refreshing files with credential information. + private static final int MINIMUM_REFRESH_PERIOD_IN_MINUTES = 1; + // Prefix for the key material alias; revision counter appended on each credential load. + static final String ALIAS_PREFIX = "key-"; - // The credential information sent to peers to prove our identity. - private volatile KeyInfo keyInfo; + private final AtomicInteger revision = new AtomicInteger(0); + // Snapshot of current and previous KeyInfo; previous is retained for in-progress handshakes + // after one rotation. + private volatile KeyInfoSnapshot snapshot = new KeyInfoSnapshot(null, null); - /** - * Constructs an AdvancedTlsX509KeyManager. - */ - public AdvancedTlsX509KeyManager() throws CertificateException { } + public AdvancedTlsX509KeyManager() {} + + private String alias() { + KeyInfo curr = this.snapshot.current; + return curr != null ? curr.alias : null; + } @Override public PrivateKey getPrivateKey(String alias) { - if (alias.equals("default")) { - return this.keyInfo.key; + KeyInfoSnapshot snap = this.snapshot; + if (snap.current != null && snap.current.alias.equals(alias)) { + return snap.current.key; + } + if (snap.previous != null && snap.previous.alias.equals(alias)) { + return snap.previous.key; } return null; } @Override public X509Certificate[] getCertificateChain(String alias) { - if (alias.equals("default")) { - return Arrays.copyOf(this.keyInfo.certs, this.keyInfo.certs.length); + KeyInfoSnapshot snap = this.snapshot; + if (snap.current != null && snap.current.alias.equals(alias)) { + return Arrays.copyOf(snap.current.certs, snap.current.certs.length); + } + if (snap.previous != null && snap.previous.alias.equals(alias)) { + return Arrays.copyOf(snap.previous.certs, snap.previous.certs.length); } return null; } @Override public String[] getClientAliases(String keyType, Principal[] issuers) { - return new String[] {"default"}; + String alias = alias(); + return alias != null ? new String[] {alias} : null; } @Override public String chooseClientAlias(String[] keyType, Principal[] issuers, Socket socket) { - return "default"; + return alias(); } @Override public String chooseEngineClientAlias(String[] keyType, Principal[] issuers, SSLEngine engine) { - return "default"; + return alias(); } @Override public String[] getServerAliases(String keyType, Principal[] issuers) { - return new String[] {"default"}; + String alias = alias(); + return alias != null ? new String[] {alias} : null; } @Override public String chooseServerAlias(String keyType, Principal[] issuers, Socket socket) { - return "default"; + return alias(); } @Override public String chooseEngineServerAlias(String keyType, Principal[] issuers, SSLEngine engine) { - return "default"; + return alias(); } /** @@ -105,106 +128,184 @@ public String chooseEngineServerAlias(String keyType, Principal[] issuers, * * @param key the private key that is going to be used * @param certs the certificate chain that is going to be used + * @deprecated Use {@link #updateIdentityCredentials(X509Certificate[], PrivateKey)} */ + @Deprecated + @InlineMe(replacement = "this.updateIdentityCredentials(certs, key)") + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/8024") public void updateIdentityCredentials(PrivateKey key, X509Certificate[] certs) { - // TODO(ZhenLian): explore possibilities to do a crypto check here. - this.keyInfo = new KeyInfo(checkNotNull(key, "key"), checkNotNull(certs, "certs")); + updateIdentityCredentials(certs, key); } /** - * Schedules a {@code ScheduledExecutorService} to read private key and certificate chains from + * Updates the current cached private key and cert chains. + * + * @param certs the certificate chain that is going to be used + * @param key the private key that is going to be used + */ + public void updateIdentityCredentials(X509Certificate[] certs, PrivateKey key) { + KeyInfo newInfo = new KeyInfo(checkNotNull(certs, "certs"), checkNotNull(key, "key"), + ALIAS_PREFIX + revision.incrementAndGet()); + this.snapshot = new KeyInfoSnapshot(newInfo, this.snapshot.current); + } + + /** + * Schedules a {@code ScheduledExecutorService} to read certificate chains and private key from * the local file paths periodically, and update the cached identity credentials if they are both - * updated. + * updated. You must close the returned Closeable before calling this method again or other update + * methods ({@link AdvancedTlsX509KeyManager#updateIdentityCredentials}, {@link + * AdvancedTlsX509KeyManager#updateIdentityCredentials(File, File)}). + * Before scheduling the task, the method synchronously executes {@code readAndUpdate} once. The + * minimum refresh period of 1 minute is enforced. * - * @param keyFile the file on disk holding the private key * @param certFile the file on disk holding the certificate chain + * @param keyFile the file on disk holding the private key * @param period the period between successive read-and-update executions * @param unit the time unit of the initialDelay and period parameters - * @param executor the execute service we use to read and update the credentials + * @param executor the executor service we use to read and update the credentials * @return an object that caller should close when the file refreshes are not needed */ - public Closeable updateIdentityCredentialsFromFile(File keyFile, File certFile, + public Closeable updateIdentityCredentials(File certFile, File keyFile, long period, TimeUnit unit, ScheduledExecutorService executor) throws IOException, GeneralSecurityException { - UpdateResult newResult = readAndUpdate(keyFile, certFile, 0, 0); + UpdateResult newResult = readAndUpdate(certFile, keyFile, 0, 0); if (!newResult.success) { throw new GeneralSecurityException( "Files were unmodified before their initial update. Probably a bug."); } + if (checkNotNull(unit, "unit").toMinutes(period) < MINIMUM_REFRESH_PERIOD_IN_MINUTES) { + log.log(Level.FINE, + "Provided refresh period of {0} {1} is too small. Default value of {2} minute(s) " + + "will be used.", new Object[] {period, unit.name(), MINIMUM_REFRESH_PERIOD_IN_MINUTES}); + period = MINIMUM_REFRESH_PERIOD_IN_MINUTES; + unit = TimeUnit.MINUTES; + } final ScheduledFuture future = - executor.scheduleWithFixedDelay( - new LoadFilePathExecution(keyFile, certFile), period, period, unit); - return new Closeable() { - @Override public void close() { - future.cancel(false); - } - }; + checkNotNull(executor, "executor").scheduleWithFixedDelay( + new LoadFilePathExecution(certFile, keyFile), period, period, unit); + return () -> future.cancel(false); } /** - * Updates the private key and certificate chains from the local file paths. + * Updates certificate chains and the private key from the local file paths. * - * @param keyFile the file on disk holding the private key * @param certFile the file on disk holding the certificate chain + * @param keyFile the file on disk holding the private key */ - public void updateIdentityCredentialsFromFile(File keyFile, File certFile) throws IOException, + public void updateIdentityCredentials(File certFile, File keyFile) throws IOException, GeneralSecurityException { - UpdateResult newResult = readAndUpdate(keyFile, certFile, 0, 0); + UpdateResult newResult = readAndUpdate(certFile, keyFile, 0, 0); if (!newResult.success) { throw new GeneralSecurityException( "Files were unmodified before their initial update. Probably a bug."); } } + /** + * Updates the private key and certificate chains from the local file paths. + * + * @param keyFile the file on disk holding the private key + * @param certFile the file on disk holding the certificate chain + * @deprecated Use {@link #updateIdentityCredentials(File, File)} instead. + */ + @Deprecated + @InlineMe(replacement = "this.updateIdentityCredentials(certFile, keyFile)") + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/8024") + public void updateIdentityCredentialsFromFile(File keyFile, File certFile) throws IOException, + GeneralSecurityException { + updateIdentityCredentials(certFile, keyFile); + } + + /** + * Schedules a {@code ScheduledExecutorService} to read private key and certificate chains from + * the local file paths periodically, and update the cached identity credentials if they are both + * updated. You must close the returned Closeable before calling this method again or other update + * methods ({@link AdvancedTlsX509KeyManager#updateIdentityCredentials}, {@link + * AdvancedTlsX509KeyManager#updateIdentityCredentials(File, File)}). + * Before scheduling the task, the method synchronously executes {@code readAndUpdate} once. The + * minimum refresh period of 1 minute is enforced. + * + * @param keyFile the file on disk holding the private key + * @param certFile the file on disk holding the certificate chain + * @param period the period between successive read-and-update executions + * @param unit the time unit of the initialDelay and period parameters + * @param executor the executor service we use to read and update the credentials + * @return an object that caller should close when the file refreshes are not needed + * @deprecated Use {@link + * #updateIdentityCredentials(File, File, long, TimeUnit, ScheduledExecutorService)} instead. + */ + @Deprecated + @InlineMe(replacement = + "this.updateIdentityCredentials(certFile, keyFile, period, unit, executor)") + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/8024") + public Closeable updateIdentityCredentialsFromFile(File keyFile, File certFile, + long period, TimeUnit unit, ScheduledExecutorService executor) throws IOException, + GeneralSecurityException { + return updateIdentityCredentials(certFile, keyFile, period, unit, executor); + } + private static class KeyInfo { // The private key and the cert chain we will use to send to peers to prove our identity. - final PrivateKey key; final X509Certificate[] certs; + final PrivateKey key; + final String alias; - public KeyInfo(PrivateKey key, X509Certificate[] certs) { - this.key = key; + public KeyInfo(X509Certificate[] certs, PrivateKey key, String alias) { this.certs = certs; + this.key = key; + this.alias = alias; + } + } + + private static class KeyInfoSnapshot { + final KeyInfo current; + final KeyInfo previous; + + KeyInfoSnapshot(KeyInfo current, KeyInfo previous) { + this.current = current; + this.previous = previous; } } private class LoadFilePathExecution implements Runnable { File keyFile; File certFile; - long currentKeyTime; long currentCertTime; + long currentKeyTime; - public LoadFilePathExecution(File keyFile, File certFile) { - this.keyFile = keyFile; + public LoadFilePathExecution(File certFile, File keyFile) { this.certFile = certFile; - this.currentKeyTime = 0; + this.keyFile = keyFile; this.currentCertTime = 0; + this.currentKeyTime = 0; } @Override public void run() { try { - UpdateResult newResult = readAndUpdate(this.keyFile, this.certFile, this.currentKeyTime, + UpdateResult newResult = readAndUpdate(this.certFile, this.keyFile, this.currentKeyTime, this.currentCertTime); if (newResult.success) { - this.currentKeyTime = newResult.keyTime; this.currentCertTime = newResult.certTime; + this.currentKeyTime = newResult.keyTime; } } catch (IOException | GeneralSecurityException e) { - log.log(Level.SEVERE, "Failed refreshing private key and certificate chain from files. " - + "Using previous ones", e); + log.log(Level.SEVERE, String.format("Failed refreshing certificate and private key" + + " chain from files. Using previous ones (certFile lastModified = %s, keyFile " + + "lastModified = %s)", certFile.lastModified(), keyFile.lastModified()), e); } } } private static class UpdateResult { boolean success; - long keyTime; long certTime; + long keyTime; - public UpdateResult(boolean success, long keyTime, long certTime) { + public UpdateResult(boolean success, long certTime, long keyTime) { this.success = success; - this.keyTime = keyTime; this.certTime = certTime; + this.keyTime = keyTime; } } @@ -212,16 +313,16 @@ public UpdateResult(boolean success, long keyTime, long certTime) { * Reads the private key and certificates specified in the path locations. Updates {@code key} and * {@code cert} if both of their modified time changed since last read. * - * @param keyFile the file on disk holding the private key * @param certFile the file on disk holding the certificate chain + * @param keyFile the file on disk holding the private key * @param oldKeyTime the time when the private key file is modified during last execution * @param oldCertTime the time when the certificate chain file is modified during last execution * @return the result of this update execution */ - private UpdateResult readAndUpdate(File keyFile, File certFile, long oldKeyTime, long oldCertTime) + private UpdateResult readAndUpdate(File certFile, File keyFile, long oldKeyTime, long oldCertTime) throws IOException, GeneralSecurityException { - long newKeyTime = keyFile.lastModified(); - long newCertTime = certFile.lastModified(); + long newKeyTime = checkNotNull(keyFile, "keyFile").lastModified(); + long newCertTime = checkNotNull(certFile, "certFile").lastModified(); // We only update when both the key and the certs are updated. if (newKeyTime != oldKeyTime && newCertTime != oldCertTime) { FileInputStream keyInputStream = new FileInputStream(keyFile); @@ -230,7 +331,7 @@ private UpdateResult readAndUpdate(File keyFile, File certFile, long oldKeyTime, FileInputStream certInputStream = new FileInputStream(certFile); try { X509Certificate[] certs = CertificateUtils.getX509Certificates(certInputStream); - updateIdentityCredentials(key, certs); + updateIdentityCredentials(certs, key); return new UpdateResult(true, newKeyTime, newCertTime); } finally { certInputStream.close(); @@ -250,4 +351,3 @@ public interface Closeable extends java.io.Closeable { void close(); } } - diff --git a/util/src/main/java/io/grpc/util/AdvancedTlsX509TrustManager.java b/util/src/main/java/io/grpc/util/AdvancedTlsX509TrustManager.java index 7465e632104..0739fa3d453 100644 --- a/util/src/main/java/io/grpc/util/AdvancedTlsX509TrustManager.java +++ b/util/src/main/java/io/grpc/util/AdvancedTlsX509TrustManager.java @@ -16,6 +16,9 @@ package io.grpc.util; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.errorprone.annotations.InlineMe; import io.grpc.ExperimentalApi; import java.io.File; import java.io.FileInputStream; @@ -42,14 +45,20 @@ /** * AdvancedTlsX509TrustManager is an {@code X509ExtendedTrustManager} that allows users to configure - * advanced TLS features, such as root certificate reloading, peer cert custom verification, etc. - * For Android users: this class is only supported in API level 24 and above. + * advanced TLS features, such as root certificate reloading and peer cert custom verification. + * The basic instantiation pattern is + * new Builder().build().useSystemDefaultTrustCerts(); + * + *

For Android users: this class is only supported in API level 24 and above. */ -@ExperimentalApi("https://github.com/grpc/grpc-java/issues/8024") @IgnoreJRERequirement public final class AdvancedTlsX509TrustManager extends X509ExtendedTrustManager { private static final Logger log = Logger.getLogger(AdvancedTlsX509TrustManager.class.getName()); + // Minimum allowed period for refreshing files with credential information. + private static final int MINIMUM_REFRESH_PERIOD_IN_MINUTES = 1; + private static final String NOT_ENOUGH_INFO_MESSAGE = + "Not enough information to validate peer. SSLEngine or Socket required."; private final Verification verification; private final SslSocketAndEnginePeerVerifier socketAndEnginePeerVerifier; @@ -57,7 +66,7 @@ public final class AdvancedTlsX509TrustManager extends X509ExtendedTrustManager private volatile X509ExtendedTrustManager delegateManager = null; private AdvancedTlsX509TrustManager(Verification verification, - SslSocketAndEnginePeerVerifier socketAndEnginePeerVerifier) throws CertificateException { + SslSocketAndEnginePeerVerifier socketAndEnginePeerVerifier) { this.verification = verification; this.socketAndEnginePeerVerifier = socketAndEnginePeerVerifier; } @@ -65,8 +74,7 @@ private AdvancedTlsX509TrustManager(Verification verification, @Override public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException { - throw new CertificateException( - "Not enough information to validate peer. SSLEngine or Socket required."); + throw new CertificateException(NOT_ENOUGH_INFO_MESSAGE); } @Override @@ -90,8 +98,7 @@ public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEng @Override public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException { - throw new CertificateException( - "Not enough information to validate peer. SSLEngine or Socket required."); + throw new CertificateException(NOT_ENOUGH_INFO_MESSAGE); } @Override @@ -111,7 +118,7 @@ public X509Certificate[] getAcceptedIssuers() { /** * Uses the default trust certificates stored on user's local system. * After this is used, functions that will provide new credential - * data(e.g. updateTrustCredentials(), updateTrustCredentialsFromFile()) should not be called. + * data(e.g. updateTrustCredentials) should not be called. */ public void useSystemDefaultTrustCerts() throws CertificateException, KeyStoreException, NoSuchAlgorithmException { @@ -120,25 +127,6 @@ public void useSystemDefaultTrustCerts() throws CertificateException, KeyStoreEx this.delegateManager = createDelegateTrustManager(null); } - /** - * Updates the current cached trust certificates as well as the key store. - * - * @param trustCerts the trust certificates that are going to be used - */ - public void updateTrustCredentials(X509Certificate[] trustCerts) throws IOException, - GeneralSecurityException { - KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType()); - keyStore.load(null, null); - int i = 1; - for (X509Certificate cert: trustCerts) { - String alias = Integer.toString(i); - keyStore.setCertificateEntry(alias, cert); - i++; - } - X509ExtendedTrustManager newDelegateManager = createDelegateTrustManager(keyStore); - this.delegateManager = newDelegateManager; - } - private static X509ExtendedTrustManager createDelegateTrustManager(KeyStore keyStore) throws CertificateException, KeyStoreException, NoSuchAlgorithmException { TrustManagerFactory tmf = TrustManagerFactory.getInstance( @@ -148,9 +136,9 @@ private static X509ExtendedTrustManager createDelegateTrustManager(KeyStore keyS TrustManager[] tms = tmf.getTrustManagers(); // Iterate over the returned trust managers, looking for an instance of X509TrustManager. // If found, use that as the delegate trust manager. - for (int j = 0; j < tms.length; j++) { - if (tms[j] instanceof X509ExtendedTrustManager) { - delegateManager = (X509ExtendedTrustManager) tms[j]; + for (TrustManager tm : tms) { + if (tm instanceof X509ExtendedTrustManager) { + delegateManager = (X509ExtendedTrustManager) tm; break; } } @@ -169,8 +157,7 @@ private void checkTrusted(X509Certificate[] chain, String authType, SSLEngine ss "Want certificate verification but got null or empty certificates"); } if (sslEngine == null && socket == null) { - throw new CertificateException( - "Not enough information to validate peer. SSLEngine or Socket required."); + throw new CertificateException(NOT_ENOUGH_INFO_MESSAGE); } if (this.verification != Verification.INSECURELY_SKIP_ALL_VERIFICATION) { X509ExtendedTrustManager currentDelegateManager = this.delegateManager; @@ -196,7 +183,11 @@ private void checkTrusted(X509Certificate[] chain, String authType, SSLEngine ss currentDelegateManager.checkServerTrusted(chain, authType, sslSocket); } } else { - currentDelegateManager.checkClientTrusted(chain, authType, sslEngine); + if (sslEngine != null) { + currentDelegateManager.checkClientTrusted(chain, authType, sslEngine); + } else { + currentDelegateManager.checkClientTrusted(chain, authType, socket); + } } } // Perform the additional peer cert check. @@ -209,40 +200,121 @@ private void checkTrusted(X509Certificate[] chain, String authType, SSLEngine ss } } + /** + * Updates the current cached trust certificates as well as the key store. + * + * @param trustCerts the trust certificates that are going to be used + */ + public void updateTrustCredentials(X509Certificate[] trustCerts) throws IOException, + GeneralSecurityException { + checkNotNull(trustCerts, "trustCerts"); + KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType()); + keyStore.load(null, null); + int i = 1; + for (X509Certificate cert: trustCerts) { + String alias = Integer.toString(i); + keyStore.setCertificateEntry(alias, cert); + i++; + } + this.delegateManager = createDelegateTrustManager(keyStore); + } + + /** + * Updates the trust certificates from a local file path. + * + * @param trustCertFile the file on disk holding the trust certificates + */ + public void updateTrustCredentials(File trustCertFile) throws IOException, + GeneralSecurityException { + long updatedTime = readAndUpdate(trustCertFile, 0); + if (updatedTime == 0) { + throw new GeneralSecurityException( + "Files were unmodified before their initial update. Probably a bug."); + } + } + /** * Schedules a {@code ScheduledExecutorService} to read trust certificates from a local file path - * periodically, and update the cached trust certs if there is an update. + * periodically, and updates the cached trust certs if there is an update. You must close the + * returned Closeable before calling this method again or other update methods + * ({@link AdvancedTlsX509TrustManager#useSystemDefaultTrustCerts()}, + * {@link AdvancedTlsX509TrustManager#updateTrustCredentials(X509Certificate[])}, + * {@link AdvancedTlsX509TrustManager#updateTrustCredentialsFromFile(File)}). + * Before scheduling the task, the method synchronously reads and updates trust certificates once. + * If the provided period is less than 1 minute, it is automatically adjusted to 1 minute. * * @param trustCertFile the file on disk holding the trust certificates * @param period the period between successive read-and-update executions * @param unit the time unit of the initialDelay and period parameters - * @param executor the execute service we use to read and update the credentials + * @param executor the executor service we use to read and update the credentials * @return an object that caller should close when the file refreshes are not needed */ - public Closeable updateTrustCredentialsFromFile(File trustCertFile, long period, TimeUnit unit, + public Closeable updateTrustCredentials(File trustCertFile, long period, TimeUnit unit, ScheduledExecutorService executor) throws IOException, GeneralSecurityException { long updatedTime = readAndUpdate(trustCertFile, 0); if (updatedTime == 0) { throw new GeneralSecurityException( "Files were unmodified before their initial update. Probably a bug."); } + if (checkNotNull(unit, "unit").toMinutes(period) < MINIMUM_REFRESH_PERIOD_IN_MINUTES) { + log.log(Level.FINE, + "Provided refresh period of {0} {1} is too small. Default value of {2} minute(s) " + + "will be used.", new Object[] {period, unit.name(), MINIMUM_REFRESH_PERIOD_IN_MINUTES}); + period = MINIMUM_REFRESH_PERIOD_IN_MINUTES; + unit = TimeUnit.MINUTES; + } final ScheduledFuture future = - executor.scheduleWithFixedDelay( - new LoadFilePathExecution(trustCertFile), period, period, unit); - return new Closeable() { - @Override public void close() { - future.cancel(false); - } - }; + checkNotNull(executor, "executor").scheduleWithFixedDelay( + new LoadFilePathExecution(trustCertFile, updatedTime), period, period, unit); + return () -> future.cancel(false); + } + + /** + * Schedules a {@code ScheduledExecutorService} to read trust certificates from a local file path + * periodically, and updates the cached trust certs if there is an update. You must close the + * returned Closeable before calling this method again or other update methods + * ({@link AdvancedTlsX509TrustManager#useSystemDefaultTrustCerts()}, + * {@link AdvancedTlsX509TrustManager#updateTrustCredentials(X509Certificate[])}, + * {@link AdvancedTlsX509TrustManager#updateTrustCredentialsFromFile(File)}). + * Before scheduling the task, the method synchronously reads and updates trust certificates once. + * If the provided period is less than 1 minute, it is automatically adjusted to 1 minute. + * + * @param trustCertFile the file on disk holding the trust certificates + * @param period the period between successive read-and-update executions + * @param unit the time unit of the initialDelay and period parameters + * @param executor the executor service we use to read and update the credentials + * @return an object that caller should close when the file refreshes are not needed + * @deprecated Use {@link #updateTrustCredentials(File, long ,TimeUnit, ScheduledExecutorService)} + */ + @Deprecated + @InlineMe(replacement = "this.updateTrustCredentials(trustCertFile, period, unit, executor)") + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/8024") + public Closeable updateTrustCredentialsFromFile(File trustCertFile, long period, TimeUnit unit, + ScheduledExecutorService executor) throws IOException, GeneralSecurityException { + return updateTrustCredentials(trustCertFile, period, unit, executor); + } + + /** + * Updates the trust certificates from a local file path. + * + * @param trustCertFile the file on disk holding the trust certificates + * @deprecated Use {@link #updateTrustCredentials(File)} + */ + @Deprecated + @InlineMe(replacement = "this.updateTrustCredentials(trustCertFile)") + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/8024") + public void updateTrustCredentialsFromFile(File trustCertFile) throws IOException, + GeneralSecurityException { + updateTrustCredentials(trustCertFile); } private class LoadFilePathExecution implements Runnable { File file; long currentTime; - public LoadFilePathExecution(File file) { + public LoadFilePathExecution(File file, long currentTime) { this.file = file; - this.currentTime = 0; + this.currentTime = currentTime; } @Override @@ -250,27 +322,14 @@ public void run() { try { this.currentTime = readAndUpdate(this.file, this.currentTime); } catch (IOException | GeneralSecurityException e) { - log.log(Level.SEVERE, "Failed refreshing trust CAs from file. Using previous CAs", e); + log.log(Level.SEVERE, String.format("Failed refreshing trust CAs from file. Using " + + "previous CAs (file lastModified = %s)", file.lastModified()), e); } } } /** - * Updates the trust certificates from a local file path. - * - * @param trustCertFile the file on disk holding the trust certificates - */ - public void updateTrustCredentialsFromFile(File trustCertFile) throws IOException, - GeneralSecurityException { - long updatedTime = readAndUpdate(trustCertFile, 0); - if (updatedTime == 0) { - throw new GeneralSecurityException( - "Files were unmodified before their initial update. Probably a bug."); - } - } - - /** - * Reads the trust certificates specified in the path location, and update the key store if the + * Reads the trust certificates specified in the path location, and updates the key store if the * modified time has changed since last read. * * @param trustCertFile the file on disk holding the trust certificates @@ -279,7 +338,11 @@ public void updateTrustCredentialsFromFile(File trustCertFile) throws IOExceptio */ private long readAndUpdate(File trustCertFile, long oldTime) throws IOException, GeneralSecurityException { - long newTime = trustCertFile.lastModified(); + long newTime = checkNotNull(trustCertFile, "trustCertFile").lastModified(); + if (newTime == 0) { + throw new IOException( + "Certificate file not found or not readable: " + trustCertFile.getAbsolutePath()); + } if (newTime == oldTime) { return oldTime; } @@ -303,27 +366,32 @@ public static Builder newBuilder() { return new Builder(); } - // The verification mode when authenticating the peer certificate. + /** + * The verification mode when authenticating the peer certificate. + */ public enum Verification { - // This is the DEFAULT and RECOMMENDED mode for most applications. - // Setting this on the client side will do the certificate and hostname verification, while - // setting this on the server side will only do the certificate verification. + /** + * This is the DEFAULT and RECOMMENDED mode for most applications. + * Setting this on the client side performs both certificate and hostname verification, while + * setting it on the server side only performs certificate verification. + */ CERTIFICATE_AND_HOST_NAME_VERIFICATION, - // This SHOULD be chosen only when you know what the implication this will bring, and have a - // basic understanding about TLS. - // It SHOULD be accompanied with proper additional peer identity checks set through - // {@code PeerVerifier}(nit: why this @code not working?). Failing to do so will leave - // applications to MITM attack. - // Also note that this will only take effect if the underlying SDK implementation invokes - // checkClientTrusted/checkServerTrusted with the {@code SSLEngine} parameter while doing - // verification. - // Setting this on either side will only do the certificate verification. + /** + * DANGEROUS: Use trusted credentials to verify the certificate, but clients will not verify the + * certificate is for the expected host. This setting is only appropriate when accompanied by + * proper additional peer identity checks set through SslSocketAndEnginePeerVerifier. Failing to + * do so will leave your applications vulnerable to MITM attacks. + * This setting has the same behavior on server-side as CERTIFICATE_AND_HOST_NAME_VERIFICATION. + */ CERTIFICATE_ONLY_VERIFICATION, - // Setting is very DANGEROUS. Please try to avoid this in a real production environment, unless - // you are a super advanced user intended to re-implement the whole verification logic on your - // own. A secure verification might include: - // 1. proper verification on the peer certificate chain - // 2. proper checks on the identity of the peer certificate + /** + * DANGEROUS: This SHOULD be used by advanced user intended to implement the entire verification + * logic themselves {@link SslSocketAndEnginePeerVerifier}) themselves. This includes:
+ * 1. Proper verification of the peer certificate chain
+ * 2. Proper checks of the identity of the peer certificate
+ * Failing to do so will leave your application without any TLS-related protection. Keep in mind + * that any loaded trust certificates will be ignored when using this mode. + */ INSECURELY_SKIP_ALL_VERIFICATION, } @@ -356,6 +424,14 @@ void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, SSL throws CertificateException; } + /** + * Builds a new {@link AdvancedTlsX509TrustManager}. By default, no trust certificates are loaded + * after the build. To load them, use one of the following methods: {@link + * AdvancedTlsX509TrustManager#updateTrustCredentials(X509Certificate[])}, {@link + * AdvancedTlsX509TrustManager#updateTrustCredentials(File, long, TimeUnit, + * ScheduledExecutorService)}, {@link AdvancedTlsX509TrustManager#updateTrustCredentials + * (File, long, TimeUnit, ScheduledExecutorService)}. + */ public static final class Builder { private Verification verification = Verification.CERTIFICATE_AND_HOST_NAME_VERIFICATION; @@ -363,11 +439,26 @@ public static final class Builder { private Builder() {} + /** + * Sets {@link Verification}, mode when authenticating the peer certificate. By default, {@link + * Verification#CERTIFICATE_AND_HOST_NAME_VERIFICATION} value is used. + * + * @param verification Verification mode used for the current AdvancedTlsX509TrustManager + * @return Builder with set verification + */ public Builder setVerification(Verification verification) { this.verification = verification; return this; } + /** + * Sets {@link SslSocketAndEnginePeerVerifier}, which methods will be called in addition to + * verifying certificates. + * + * @param verifier SslSocketAndEnginePeerVerifier used for the current + * AdvancedTlsX509TrustManager + * @return Builder with set verifier + */ public Builder setSslSocketAndEnginePeerVerifier(SslSocketAndEnginePeerVerifier verifier) { this.socketAndEnginePeerVerifier = verifier; return this; diff --git a/util/src/main/java/io/grpc/util/ForwardingClientStreamTracer.java b/util/src/main/java/io/grpc/util/ForwardingClientStreamTracer.java index 7317917887a..9c9998571e5 100644 --- a/util/src/main/java/io/grpc/util/ForwardingClientStreamTracer.java +++ b/util/src/main/java/io/grpc/util/ForwardingClientStreamTracer.java @@ -48,11 +48,21 @@ public void inboundHeaders() { delegate().inboundHeaders(); } + @Override + public void inboundHeaders(Metadata headers) { + delegate().inboundHeaders(headers); + } + @Override public void inboundTrailers(Metadata trailers) { delegate().inboundTrailers(trailers); } + @Override + public void addOptionalLabel(String key, String value) { + delegate().addOptionalLabel(key, value); + } + @Override public void streamClosed(Status status) { delegate().streamClosed(status); diff --git a/util/src/main/java/io/grpc/util/ForwardingLoadBalancer.java b/util/src/main/java/io/grpc/util/ForwardingLoadBalancer.java index cefcbf344ea..d52ff42e652 100644 --- a/util/src/main/java/io/grpc/util/ForwardingLoadBalancer.java +++ b/util/src/main/java/io/grpc/util/ForwardingLoadBalancer.java @@ -29,6 +29,7 @@ public abstract class ForwardingLoadBalancer extends LoadBalancer { */ protected abstract LoadBalancer delegate(); + @Deprecated @Override public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { delegate().handleResolvedAddresses(resolvedAddresses); @@ -52,6 +53,8 @@ public void shutdown() { } @Override + @Deprecated + @SuppressWarnings("InlineMeSuggester") public boolean canHandleEmptyAddressListFromNameResolution() { return delegate().canHandleEmptyAddressListFromNameResolution(); } diff --git a/util/src/main/java/io/grpc/util/ForwardingLoadBalancerHelper.java b/util/src/main/java/io/grpc/util/ForwardingLoadBalancerHelper.java index c684051f0b2..338903fc5fc 100644 --- a/util/src/main/java/io/grpc/util/ForwardingLoadBalancerHelper.java +++ b/util/src/main/java/io/grpc/util/ForwardingLoadBalancerHelper.java @@ -28,6 +28,7 @@ import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; +import io.grpc.MetricRecorder; import io.grpc.NameResolver; import io.grpc.NameResolverRegistry; import io.grpc.SynchronizationContext; @@ -105,6 +106,11 @@ public String getAuthority() { return delegate().getAuthority(); } + @Override + public String getChannelTarget() { + return delegate().getChannelTarget(); + } + @Override public ChannelCredentials getChannelCredentials() { return delegate().getChannelCredentials(); @@ -140,6 +146,11 @@ public NameResolverRegistry getNameResolverRegistry() { return delegate().getNameResolverRegistry(); } + @Override + public MetricRecorder getMetricRecorder() { + return delegate().getMetricRecorder(); + } + @Override public String toString() { return MoreObjects.toStringHelper(this).add("delegate", delegate()).toString(); diff --git a/util/src/main/java/io/grpc/util/ForwardingSubchannel.java b/util/src/main/java/io/grpc/util/ForwardingSubchannel.java index 51f2583186e..416be378162 100644 --- a/util/src/main/java/io/grpc/util/ForwardingSubchannel.java +++ b/util/src/main/java/io/grpc/util/ForwardingSubchannel.java @@ -74,11 +74,17 @@ public Object getInternalSubchannel() { return delegate().getInternalSubchannel(); } + @Override public void updateAddresses(List addrs) { delegate().updateAddresses(addrs); } + @Override + public Attributes getConnectedAddressAttributes() { + return delegate().getConnectedAddressAttributes(); + } + @Override public String toString() { return MoreObjects.toStringHelper(this).add("delegate", delegate()).toString(); diff --git a/util/src/main/java/io/grpc/util/GracefulSwitchLoadBalancer.java b/util/src/main/java/io/grpc/util/GracefulSwitchLoadBalancer.java index a07428a30b9..27dc080c71b 100644 --- a/util/src/main/java/io/grpc/util/GracefulSwitchLoadBalancer.java +++ b/util/src/main/java/io/grpc/util/GracefulSwitchLoadBalancer.java @@ -19,39 +19,39 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; -import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.MoreObjects; +import com.google.common.base.Objects; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.ExperimentalApi; import io.grpc.LoadBalancer; +import io.grpc.LoadBalancerRegistry; +import io.grpc.NameResolver.ConfigOrError; import io.grpc.Status; +import io.grpc.internal.ServiceConfigUtil; +import java.util.List; +import java.util.Map; import javax.annotation.Nullable; import javax.annotation.concurrent.NotThreadSafe; /** * A load balancer that gracefully swaps to a new lb policy. If the channel is currently in a state * other than READY, the new policy will be swapped into place immediately. Otherwise, the channel - * will keep using the old policy until the new policy reports READY or the old policy exits READY. + * will keep using the old policy until the new policy leaves CONNECTING or the old policy exits + * READY. * - *

The balancer must {@link #switchTo(LoadBalancer.Factory) switch to} a policy prior to {@link - * LoadBalancer#handleResolvedAddresses(ResolvedAddresses) handling resolved addresses} for the - * first time. + *

The child balancer and configuration is specified using service config. Config objects are + * generally created by calling {@link #parseLoadBalancingPolicyConfig(List)} from a + * {@link io.grpc.LoadBalancerProvider#parseLoadBalancingPolicyConfig + * provider's parseLoadBalancingPolicyConfig()} implementation. */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/5999") @NotThreadSafe // Must be accessed in SynchronizationContext public final class GracefulSwitchLoadBalancer extends ForwardingLoadBalancer { private final LoadBalancer defaultBalancer = new LoadBalancer() { @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { - // Most LB policies using this class will receive child policy configuration within the - // service config, so they are naturally calling switchTo() just before - // handleResolvedAddresses(), within their own handleResolvedAddresses(). If switchTo() is - // not called immediately after construction that does open up potential for bugs in the - // parent policies, where they fail to call switchTo(). So we will use the exception to try - // to notice those bugs quickly, as it will fail very loudly. - throw new IllegalStateException( - "GracefulSwitchLoadBalancer must switch to a load balancing policy before handling" - + " ResolvedAddresses"); + public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + throw new AssertionError("real LB is called instead"); } @Override @@ -65,19 +65,6 @@ public void handleNameResolutionError(final Status error) { public void shutdown() {} }; - @VisibleForTesting - static final SubchannelPicker BUFFER_PICKER = new SubchannelPicker() { - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return PickResult.withNoResult(); - } - - @Override - public String toString() { - return "BUFFER_PICKER"; - } - }; - private final Helper helper; // While the new policy is not fully switched on, the pendingLb is handling new updates from name @@ -97,11 +84,28 @@ public GracefulSwitchLoadBalancer(Helper helper) { this.helper = checkNotNull(helper, "helper"); } - /** - * Gracefully switch to a new policy defined by the given factory, if the given factory isn't - * equal to the current one. - */ - public void switchTo(LoadBalancer.Factory newBalancerFactory) { + @Deprecated + @Override + public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + Config config = (Config) resolvedAddresses.getLoadBalancingPolicyConfig(); + switchToInternal(config.childFactory); + delegate().handleResolvedAddresses( + resolvedAddresses.toBuilder() + .setLoadBalancingPolicyConfig(config.childConfig) + .build()); + } + + @Override + public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + Config config = (Config) resolvedAddresses.getLoadBalancingPolicyConfig(); + switchToInternal(config.childFactory); + return delegate().acceptResolvedAddresses( + resolvedAddresses.toBuilder() + .setLoadBalancingPolicyConfig(config.childConfig) + .build()); + } + + private void switchToInternal(LoadBalancer.Factory newBalancerFactory) { checkNotNull(newBalancerFactory, "newBalancerFactory"); if (newBalancerFactory.equals(pendingBalancerFactory)) { @@ -111,7 +115,7 @@ public void switchTo(LoadBalancer.Factory newBalancerFactory) { pendingLb = defaultBalancer; pendingBalancerFactory = null; pendingState = ConnectivityState.CONNECTING; - pendingPicker = BUFFER_PICKER; + pendingPicker = new FixedResultPicker(PickResult.withNoResult()); if (newBalancerFactory.equals(currentBalancerFactory)) { return; @@ -131,7 +135,7 @@ public void updateBalancingState(ConnectivityState newState, SubchannelPicker ne checkState(currentLbIsReady, "there's pending lb while current lb has been out of READY"); pendingState = newState; pendingPicker = newPicker; - if (newState == ConnectivityState.READY) { + if (newState != ConnectivityState.CONNECTING) { swap(); } } else if (lb == currentLb) { @@ -185,4 +189,86 @@ public void shutdown() { public String delegateType() { return delegate().getClass().getSimpleName(); } + + /** + * Provided a JSON list of LoadBalancingConfigs, parse it into a config to pass to GracefulSwitch. + */ + public static ConfigOrError parseLoadBalancingPolicyConfig( + List> loadBalancingConfigs) { + return parseLoadBalancingPolicyConfig( + loadBalancingConfigs, LoadBalancerRegistry.getDefaultRegistry()); + } + + /** + * Provided a JSON list of LoadBalancingConfigs, parse it into a config to pass to GracefulSwitch. + */ + public static ConfigOrError parseLoadBalancingPolicyConfig( + List> loadBalancingConfigs, LoadBalancerRegistry lbRegistry) { + List childConfigCandidates = + ServiceConfigUtil.unwrapLoadBalancingConfigList(loadBalancingConfigs); + if (childConfigCandidates == null || childConfigCandidates.isEmpty()) { + return ConfigOrError.fromError( + Status.UNAVAILABLE.withDescription("No child LB config specified")); + } + ConfigOrError selectedConfig = + ServiceConfigUtil.selectLbPolicyFromList(childConfigCandidates, lbRegistry); + if (selectedConfig.getError() != null) { + Status error = selectedConfig.getError(); + return ConfigOrError.fromError( + Status.UNAVAILABLE + .withCause(error.getCause()) + .withDescription(error.getDescription()) + .augmentDescription("Failed to select child config")); + } + ServiceConfigUtil.PolicySelection selection = + (ServiceConfigUtil.PolicySelection) selectedConfig.getConfig(); + return ConfigOrError.fromConfig( + createLoadBalancingPolicyConfig(selection.getProvider(), selection.getConfig())); + } + + /** + * Directly create a config to pass to GracefulSwitch. The object returned is the same as would be + * found in {@code ConfigOrError.getConfig()}. + */ + public static Object createLoadBalancingPolicyConfig( + LoadBalancer.Factory childFactory, @Nullable Object childConfig) { + return new Config(childFactory, childConfig); + } + + static final class Config { + final LoadBalancer.Factory childFactory; + @Nullable + final Object childConfig; + + public Config(LoadBalancer.Factory childFactory, @Nullable Object childConfig) { + this.childFactory = checkNotNull(childFactory, "childFactory"); + this.childConfig = childConfig; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof Config)) { + return false; + } + Config that = (Config) o; + return Objects.equal(childFactory, that.childFactory) + && Objects.equal(childConfig, that.childConfig); + } + + @Override + public int hashCode() { + return Objects.hashCode(childFactory, childConfig); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper("GracefulSwitchLoadBalancer.Config") + .add("childFactory", childFactory) + .add("childConfig", childConfig) + .toString(); + } + } } diff --git a/util/src/main/java/io/grpc/util/HealthProducerHelper.java b/util/src/main/java/io/grpc/util/HealthProducerHelper.java index b11864765ea..d871911d203 100644 --- a/util/src/main/java/io/grpc/util/HealthProducerHelper.java +++ b/util/src/main/java/io/grpc/util/HealthProducerHelper.java @@ -22,6 +22,7 @@ import com.google.common.annotations.VisibleForTesting; import io.grpc.Attributes; +import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.Internal; import io.grpc.LoadBalancer; @@ -84,6 +85,31 @@ protected LoadBalancer.Helper delegate() { return delegate; } + @Override + public void updateBalancingState( + ConnectivityState newState, LoadBalancer.SubchannelPicker newPicker) { + delegate.updateBalancingState(newState, new HealthProducerPicker(newPicker)); + } + + private static final class HealthProducerPicker extends LoadBalancer.SubchannelPicker { + private final LoadBalancer.SubchannelPicker delegate; + + HealthProducerPicker(LoadBalancer.SubchannelPicker delegate) { + this.delegate = delegate; + } + + @Override + public LoadBalancer.PickResult pickSubchannel(LoadBalancer.PickSubchannelArgs args) { + LoadBalancer.PickResult result = delegate.pickSubchannel(args); + LoadBalancer.Subchannel subchannel = result.getSubchannel(); + if (subchannel instanceof HealthProducerSubchannel) { + return result.copyWithSubchannel( + ((HealthProducerSubchannel) subchannel).delegate()); + } + return result; + } + } + // The parent subchannel in the health check producer LB chain. It duplicates subchannel state to // both the state listener and health listener. @VisibleForTesting diff --git a/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java b/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java index 862be71f125..acc186e3be6 100644 --- a/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java +++ b/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java @@ -16,7 +16,6 @@ package io.grpc.util; -import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import static io.grpc.ConnectivityState.CONNECTING; import static io.grpc.ConnectivityState.IDLE; @@ -25,8 +24,9 @@ import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; +import com.google.common.collect.Maps; +import com.google.common.primitives.UnsignedInts; import io.grpc.Attributes; import io.grpc.ConnectivityState; import io.grpc.EquivalentAddressGroup; @@ -37,14 +37,12 @@ import io.grpc.internal.PickFirstLoadBalancerProvider; import java.net.SocketAddress; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collection; import java.util.Collections; -import java.util.HashMap; -import java.util.LinkedHashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Set; +import java.util.Random; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; @@ -57,7 +55,9 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer { private static final Logger logger = Logger.getLogger(MultiChildLoadBalancer.class.getName()); - private final Map childLbStates = new LinkedHashMap<>(); + private static final int OFFSET_SEED = new Random().nextInt(); + // Modify by replacing the list to release memory when no longer used. + private List childLbStates = new ArrayList<>(0); private final Helper helper; // Set to true if currently in the process of handling resolved addresses. protected boolean resolvingAddresses; @@ -81,29 +81,29 @@ protected MultiChildLoadBalancer(Helper helper) { /** * Override to utilize parsing of the policy configuration or alternative helper/lb generation. + * Override this if keys are not Endpoints or if child policies have configuration. Null map + * values preserve the child without delivering the child an update. */ - protected Map createChildLbMap(ResolvedAddresses resolvedAddresses) { - Map childLbMap = new HashMap<>(); - List addresses = resolvedAddresses.getAddresses(); - for (EquivalentAddressGroup eag : addresses) { - Endpoint endpoint = new Endpoint(eag); // keys need to be just addresses - ChildLbState existingChildLbState = childLbStates.get(endpoint); - if (existingChildLbState != null) { - childLbMap.put(endpoint, existingChildLbState); - } else { - childLbMap.put(endpoint, - createChildLbState(endpoint, null, getInitialPicker(), resolvedAddresses)); - } - } - return childLbMap; + protected Map createChildAddressesMap( + ResolvedAddresses resolvedAddresses) { + Map childAddresses = + Maps.newLinkedHashMapWithExpectedSize(resolvedAddresses.getAddresses().size()); + for (EquivalentAddressGroup eag : resolvedAddresses.getAddresses()) { + ResolvedAddresses addresses = resolvedAddresses.toBuilder() + .setAddresses(Collections.singletonList(eag)) + .setAttributes(Attributes.newBuilder().set(IS_PETIOLE_POLICY, true).build()) + .setLoadBalancingPolicyConfig(null) + .build(); + childAddresses.put(new Endpoint(eag), addresses); + } + return childAddresses; } /** * Override to create an instance of a subclass. */ - protected ChildLbState createChildLbState(Object key, Object policyConfig, - SubchannelPicker initialPicker, ResolvedAddresses resolvedAddresses) { - return new ChildLbState(key, pickFirstLbProvider, policyConfig, initialPicker); + protected ChildLbState createChildLbState(Object key) { + return new ChildLbState(key, pickFirstLbProvider); } /** @@ -111,61 +111,27 @@ protected ChildLbState createChildLbState(Object key, Object policyConfig, */ @Override public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + logger.log(Level.FINE, "Received resolution result: {0}", resolvedAddresses); try { resolvingAddresses = true; // process resolvedAddresses to update children - AcceptResolvedAddrRetVal acceptRetVal = acceptResolvedAddressesInternal(resolvedAddresses); - if (!acceptRetVal.status.isOk()) { - return acceptRetVal.status; + Map newChildAddresses = createChildAddressesMap(resolvedAddresses); + + // Handle error case + if (newChildAddresses.isEmpty()) { + Status unavailableStatus = Status.UNAVAILABLE.withDescription( + "NameResolver returned no usable address. " + resolvedAddresses); + handleNameResolutionError(unavailableStatus); + return unavailableStatus; } - // Update the picker and our connectivity state - updateOverallBalancingState(); - - // shutdown removed children - shutdownRemoved(acceptRetVal.removedChildren); - return acceptRetVal.status; + return updateChildrenWithResolvedAddresses(newChildAddresses); } finally { resolvingAddresses = false; } } - /** - * Override this if your keys are not of type Endpoint. - * @param key Key to identify the ChildLbState - * @param resolvedAddresses list of addresses which include attributes - * @param childConfig a load balancing policy config. This field is optional. - * @return a fully loaded ResolvedAddresses object for the specified key - */ - protected ResolvedAddresses getChildAddresses(Object key, ResolvedAddresses resolvedAddresses, - Object childConfig) { - Endpoint endpointKey; - if (key instanceof EquivalentAddressGroup) { - endpointKey = new Endpoint((EquivalentAddressGroup) key); - } else { - checkArgument(key instanceof Endpoint, "key is wrong type"); - endpointKey = (Endpoint) key; - } - - // Retrieve the non-stripped version - EquivalentAddressGroup eagToUse = null; - for (EquivalentAddressGroup currEag : resolvedAddresses.getAddresses()) { - if (endpointKey.equals(new Endpoint(currEag))) { - eagToUse = currEag; - break; - } - } - - checkNotNull(eagToUse, key + " no longer present in load balancer children"); - - return resolvedAddresses.toBuilder() - .setAddresses(Collections.singletonList(eagToUse)) - .setAttributes(Attributes.newBuilder().set(IS_PETIOLE_POLICY, true).build()) - .setLoadBalancingPolicyConfig(childConfig) - .build(); - } - /** * Handle the name resolution error. * @@ -174,134 +140,75 @@ protected ResolvedAddresses getChildAddresses(Object key, ResolvedAddresses reso @Override public void handleNameResolutionError(Status error) { if (currentConnectivityState != READY) { - helper.updateBalancingState(TRANSIENT_FAILURE, getErrorPicker(error)); + helper.updateBalancingState( + TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(error))); } } - /** - * Handle the name resolution error only for the specified child. - * - *

Override if you need special handling. - */ - protected void handleNameResolutionError(ChildLbState child, Status error) { - child.lb.handleNameResolutionError(error); - } - - /** - * Creates a picker representing the state before any connections have been established. - * - *

Override to produce a custom picker. - */ - protected SubchannelPicker getInitialPicker() { - return new FixedResultPicker(PickResult.withNoResult()); - } - - /** - * Creates a new picker representing an error status. - * - *

Override to produce a custom picker when there are errors. - */ - protected SubchannelPicker getErrorPicker(Status error) { - return new FixedResultPicker(PickResult.withError(error)); - } - @Override public void shutdown() { logger.log(Level.FINE, "Shutdown"); - for (ChildLbState state : childLbStates.values()) { + for (ChildLbState state : childLbStates) { state.shutdown(); } childLbStates.clear(); } - /** - * This does the work to update the child map and calculate which children have been removed. - * You must call {@link #updateOverallBalancingState} to update the picker - * and call {@link #shutdownRemoved(List)} to shutdown the endpoints that have been removed. - */ - protected final AcceptResolvedAddrRetVal acceptResolvedAddressesInternal( - ResolvedAddresses resolvedAddresses) { - logger.log(Level.FINE, "Received resolution result: {0}", resolvedAddresses); - - // Subclass handles any special manipulation to create appropriate types of keyed ChildLbStates - Map newChildren = createChildLbMap(resolvedAddresses); - - // Handle error case - if (newChildren.isEmpty()) { - Status unavailableStatus = Status.UNAVAILABLE.withDescription( - "NameResolver returned no usable address. " + resolvedAddresses); - handleNameResolutionError(unavailableStatus); - return new AcceptResolvedAddrRetVal(unavailableStatus, null); - } - - Collection reusedChildren = addMissingChildrenAndIdReuse(newChildren); - - // Raactivate deactivated children - for (ChildLbState reusedChild : reusedChildren) { - reusedChild.reactivate(reusedChild.getPolicyProvider()); - } - - updateChildrenWithResolvedAddresses(resolvedAddresses, newChildren); - - return new AcceptResolvedAddrRetVal(Status.OK, getRemovedChildren(newChildren.keySet())); - } - - protected final Collection addMissingChildrenAndIdReuse( - Map newChildren) { - Collection reusedChildren = new ArrayList<>(); - - // Do adds and identify reused children - for (Map.Entry entry : newChildren.entrySet()) { - final Object key = entry.getKey(); - if (!childLbStates.containsKey(key)) { - childLbStates.put(key, entry.getValue()); - } else { - // Reuse the existing one - ChildLbState existingChildLbState = childLbStates.get(key); - if (existingChildLbState.isDeactivated()) { - reusedChildren.add(existingChildLbState); + private Status updateChildrenWithResolvedAddresses( + Map newChildAddresses) { + // Create a map with the old values + Map oldStatesMap = + Maps.newLinkedHashMapWithExpectedSize(childLbStates.size()); + for (ChildLbState state : childLbStates) { + oldStatesMap.put(state.getKey(), state); + } + + // Move ChildLbStates from the map to a new list (preserving the new map's order) + Status status = Status.OK; + List newChildLbStates = new ArrayList<>(newChildAddresses.size()); + for (Map.Entry entry : newChildAddresses.entrySet()) { + ChildLbState childLbState = oldStatesMap.remove(entry.getKey()); + if (childLbState == null) { + childLbState = createChildLbState(entry.getKey()); + } + newChildLbStates.add(childLbState); + } + // Use a random start position for child updates to weakly "shuffle" connection creation order. + // The network will often add noise to the creation order, but this avoids giving earlier + // children a consistent head start. + for (ChildLbState childLbState : offsetIterable(newChildLbStates, OFFSET_SEED)) { + ResolvedAddresses addresses = newChildAddresses.get(childLbState.getKey()); + if (addresses != null) { + // update child LB + Status newStatus = childLbState.lb.acceptResolvedAddresses(addresses); + if (!newStatus.isOk()) { + status = newStatus; } } } - return reusedChildren; - } - protected final void updateChildrenWithResolvedAddresses(ResolvedAddresses resolvedAddresses, - Map newChildren) { - for (Map.Entry entry : newChildren.entrySet()) { - Object childConfig = entry.getValue().getConfig(); - ChildLbState childLbState = childLbStates.get(entry.getKey()); - ResolvedAddresses childAddresses = - getChildAddresses(entry.getKey(), resolvedAddresses, childConfig); - childLbState.setResolvedAddresses(childAddresses); // update child - if (!childLbState.deactivated) { - childLbState.lb.handleResolvedAddresses(childAddresses); // update child LB - } - } - } + childLbStates = newChildLbStates; + // Update the picker and our connectivity state + updateOverallBalancingState(); - /** - * Identifies which children have been removed (are not part of the newChildKeys). - */ - protected final List getRemovedChildren(Set newChildKeys) { - List removedChildren = new ArrayList<>(); - // Do removals - for (Object key : ImmutableList.copyOf(childLbStates.keySet())) { - if (!newChildKeys.contains(key)) { - ChildLbState childLbState = childLbStates.get(key); - childLbState.deactivate(); - removedChildren.add(childLbState); - } + // Remaining entries in map are orphaned + for (ChildLbState childLbState : oldStatesMap.values()) { + childLbState.shutdown(); } - return removedChildren; + return status; } - protected final void shutdownRemoved(List removedChildren) { - // Do shutdowns after updating picker to reduce the chance of failing an RPC by picking a - // subchannel that has been shutdown. - for (ChildLbState childLbState : removedChildren) { - childLbState.shutdown(); + @VisibleForTesting + static Iterable offsetIterable(Collection c, int seed) { + int pos; + if (c.isEmpty()) { + pos = 0; + } else { + pos = UnsignedInts.remainder(seed, c.size()); } + return Iterables.concat( + Iterables.skip(c, pos), + Iterables.limit(c, pos)); } @Nullable @@ -326,43 +233,18 @@ protected final Helper getHelper() { return helper; } - protected final void removeChild(Object key) { - childLbStates.remove(key); - } - - @VisibleForTesting - public final ImmutableMap getImmutableChildMap() { - return ImmutableMap.copyOf(childLbStates); - } - @VisibleForTesting public final Collection getChildLbStates() { - return childLbStates.values(); - } - - @VisibleForTesting - public final ChildLbState getChildLbState(Object key) { - if (key == null) { - return null; - } - if (key instanceof EquivalentAddressGroup) { - key = new Endpoint((EquivalentAddressGroup) key); - } - return childLbStates.get(key); - } - - @VisibleForTesting - public final ChildLbState getChildLbStateEag(EquivalentAddressGroup eag) { - return getChildLbState(new Endpoint(eag)); + return childLbStates; } /** - * Filters out non-ready and deactivated child load balancers (subchannels). + * Filters out non-ready child load balancers (subchannels). */ protected final List getReadyChildren() { List activeChildren = new ArrayList<>(); for (ChildLbState child : getChildLbStates()) { - if (!child.isDeactivated() && child.getCurrentState() == READY) { + if (child.getCurrentState() == READY) { activeChildren.add(child); } } @@ -372,9 +254,7 @@ protected final List getReadyChildren() { /** * This represents the state of load balancer children. Each endpoint (represented by an * EquivalentAddressGroup or EDS string) will have a separate ChildLbState which in turn will - * define a GracefulSwitchLoadBalancer. When the GracefulSwitchLoadBalancer is activated, a - * single PickFirstLoadBalancer will be created which will then create a subchannel and start - * trying to connect to it. + * have a single child LoadBalancer created from the provided factory. * *

A ChildLbStateHelper is the glue between ChildLbState and the helpers associated with the * petiole policy above and the PickFirstLoadBalancer's helper below. @@ -384,68 +264,22 @@ protected final List getReadyChildren() { */ public class ChildLbState { private final Object key; - private ResolvedAddresses resolvedAddresses; - private final Object config; - - private final GracefulSwitchLoadBalancer lb; - private final LoadBalancerProvider policyProvider; + private final LoadBalancer lb; private ConnectivityState currentState; - private SubchannelPicker currentPicker; - private boolean deactivated; + private SubchannelPicker currentPicker = new FixedResultPicker(PickResult.withNoResult()); - public ChildLbState(Object key, LoadBalancerProvider policyProvider, Object childConfig, - SubchannelPicker initialPicker) { - this(key, policyProvider, childConfig, initialPicker, null, false); - } - - public ChildLbState(Object key, LoadBalancerProvider policyProvider, Object childConfig, - SubchannelPicker initialPicker, ResolvedAddresses resolvedAddrs, boolean deactivated) { + @SuppressWarnings("this-escape") + // TODO(okshiva): Fix 'this-escape' from the constructor before making the API public. + public ChildLbState(Object key, LoadBalancer.Factory policyFactory) { this.key = key; - this.policyProvider = policyProvider; - this.deactivated = deactivated; - this.currentPicker = initialPicker; - this.config = childConfig; - this.lb = new GracefulSwitchLoadBalancer(createChildHelper()); - this.currentState = deactivated ? IDLE : CONNECTING; - this.resolvedAddresses = resolvedAddrs; - if (!deactivated) { - lb.switchTo(policyProvider); - } + this.lb = policyFactory.newLoadBalancer(createChildHelper()); + this.currentState = CONNECTING; } protected ChildLbStateHelper createChildHelper() { return new ChildLbStateHelper(); } - /** - * The default implementation. This not only marks the lb policy as not active, it also removes - * this child from the map of children maintained by the petiole policy. - * - *

Note that this does not explicitly shutdown this child. That will generally be done by - * acceptResolvedAddresses on the LB, but can also be handled by an override such as is done - * in ClusterManagerLoadBalancer. - * - *

If you plan to reactivate, you will probably want to override this to not call - * childLbStates.remove() and handle that cleanup another way. - */ - protected void deactivate() { - if (deactivated) { - return; - } - - childLbStates.remove(key); // This means it can't be reactivated again - deactivated = true; - logger.log(Level.FINE, "Child balancer {0} deactivated", key); - } - - /** - * This base implementation does nothing but reset the flag. If you really want to both - * deactivate and reactivate you should override them both. - */ - protected void reactivate(LoadBalancerProvider policyProvider) { - deactivated = false; - } - /** * Override for unique behavior such as delayed shutdowns of subchannels. */ @@ -460,8 +294,7 @@ public String toString() { return "Address = " + key + ", state = " + currentState + ", picker type: " + currentPicker.getClass() - + ", lb: " + lb.delegate().getClass() - + (deactivated ? ", deactivated" : ""); + + ", lb: " + lb; } public final Object getKey() { @@ -469,7 +302,7 @@ public final Object getKey() { } @VisibleForTesting - public final GracefulSwitchLoadBalancer getLb() { + public final LoadBalancer getLb() { return lb; } @@ -478,17 +311,6 @@ public final SubchannelPicker getCurrentPicker() { return currentPicker; } - protected final LoadBalancerProvider getPolicyProvider() { - return policyProvider; - } - - protected final Subchannel getSubchannels(PickSubchannelArgs args) { - if (getCurrentPicker() == null) { - return null; - } - return getCurrentPicker().pickSubchannel(args).getSubchannel(); - } - public final ConnectivityState getCurrentState() { return currentState; } @@ -501,39 +323,6 @@ protected final void setCurrentPicker(SubchannelPicker newPicker) { currentPicker = newPicker; } - public final EquivalentAddressGroup getEag() { - if (resolvedAddresses == null || resolvedAddresses.getAddresses().isEmpty()) { - return null; - } - return resolvedAddresses.getAddresses().get(0); - } - - public final boolean isDeactivated() { - return deactivated; - } - - protected final void setDeactivated() { - deactivated = true; - } - - protected final void markReactivated() { - deactivated = false; - } - - protected final void setResolvedAddresses(ResolvedAddresses newAddresses) { - checkNotNull(newAddresses, "Missing address list for child"); - resolvedAddresses = newAddresses; - } - - private Object getConfig() { - return config; - } - - @VisibleForTesting - public final ResolvedAddresses getResolvedAddresses() { - return resolvedAddresses; - } - /** * ChildLbStateHelper is the glue between ChildLbState and the helpers associated with the * petiole policy above and the PickFirstLoadBalancer's helper below. @@ -546,26 +335,19 @@ protected class ChildLbStateHelper extends ForwardingLoadBalancerHelper { /** * Update current state and picker for this child and then use * {@link #updateOverallBalancingState()} for the parent LB. - * - *

Override this if you don't want to automatically request a connection when in IDLE */ @Override public void updateBalancingState(final ConnectivityState newState, final SubchannelPicker newPicker) { - // If we are already in the process of resolving addresses, the overall balancing state - // will be updated at the end of it, and we don't need to trigger that update here. - if (!childLbStates.containsKey(key)) { + if (currentState == SHUTDOWN) { return; } - // Subchannel picker and state are saved, but will only be propagated to the channel - // when the child instance exits deactivated state. currentState = newState; currentPicker = newPicker; - if (!deactivated && !resolvingAddresses) { - if (newState == IDLE) { - lb.requestConnection(); - } + // If we are already in the process of resolving addresses, the overall balancing state + // will be updated at the end of it, and we don't need to trigger that update here. + if (!resolvingAddresses) { updateOverallBalancingState(); } } @@ -579,25 +361,27 @@ protected Helper delegate() { /** * Endpoint is an optimization to quickly lookup and compare EquivalentAddressGroup address sets. - * Ignores the attributes, orders the addresses in a deterministic manner and converts each - * address into a string for easy comparison. Also caches the hashcode. - * Is used as a key for ChildLbState for most load balancers (ClusterManagerLB uses a String). + * It ignores the attributes. Is used as a key for ChildLbState for most load balancers + * (ClusterManagerLB uses a String). */ protected static class Endpoint { - final String[] addrs; + final Collection addrs; final int hashCode; public Endpoint(EquivalentAddressGroup eag) { checkNotNull(eag, "eag"); - addrs = new String[eag.getAddresses().size()]; - int i = 0; + if (eag.getAddresses().size() < 10) { + addrs = eag.getAddresses(); + } else { + // This is expected to be very unlikely in practice + addrs = new HashSet<>(eag.getAddresses()); + } + int sum = 0; for (SocketAddress address : eag.getAddresses()) { - addrs[i++] = address.toString(); + sum += address.hashCode(); } - Arrays.sort(addrs); - - hashCode = Arrays.hashCode(addrs); + hashCode = sum; } @Override @@ -610,34 +394,21 @@ public boolean equals(Object other) { if (this == other) { return true; } - if (other == null) { - return false; - } if (!(other instanceof Endpoint)) { return false; } Endpoint o = (Endpoint) other; - if (o.hashCode != hashCode || o.addrs.length != addrs.length) { + if (o.hashCode != hashCode || o.addrs.size() != addrs.size()) { return false; } - return Arrays.equals(o.addrs, this.addrs); + return o.addrs.containsAll(addrs); } @Override public String toString() { - return Arrays.toString(addrs); - } - } - - protected static class AcceptResolvedAddrRetVal { - public final Status status; - public final List removedChildren; - - public AcceptResolvedAddrRetVal(Status status, List removedChildren) { - this.status = status; - this.removedChildren = removedChildren; + return addrs.toString(); } } } diff --git a/util/src/main/java/io/grpc/util/OutlierDetectionLoadBalancer.java b/util/src/main/java/io/grpc/util/OutlierDetectionLoadBalancer.java index c24e2386466..dc61441bccd 100644 --- a/util/src/main/java/io/grpc/util/OutlierDetectionLoadBalancer.java +++ b/util/src/main/java/io/grpc/util/OutlierDetectionLoadBalancer.java @@ -22,6 +22,7 @@ import static java.util.concurrent.TimeUnit.NANOSECONDS; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Ticker; import com.google.common.collect.ForwardingMap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -39,8 +40,6 @@ import io.grpc.Status; import io.grpc.SynchronizationContext; import io.grpc.SynchronizationContext.ScheduledHandle; -import io.grpc.internal.ServiceConfigUtil.PolicySelection; -import io.grpc.internal.TimeProvider; import java.net.SocketAddress; import java.util.ArrayList; import java.util.Collection; @@ -83,7 +82,7 @@ public final class OutlierDetectionLoadBalancer extends LoadBalancer { private final SynchronizationContext syncContext; private final Helper childHelper; private final GracefulSwitchLoadBalancer switchLb; - private TimeProvider timeProvider; + private Ticker ticker; private final ScheduledExecutorService timeService; private ScheduledHandle detectionTimerHandle; private Long detectionTimerStartNanos; @@ -96,14 +95,14 @@ public final class OutlierDetectionLoadBalancer extends LoadBalancer { /** * Creates a new instance of {@link OutlierDetectionLoadBalancer}. */ - public OutlierDetectionLoadBalancer(Helper helper, TimeProvider timeProvider) { + public OutlierDetectionLoadBalancer(Helper helper, Ticker ticker) { logger = helper.getChannelLogger(); childHelper = new ChildHelper(checkNotNull(helper, "helper")); switchLb = new GracefulSwitchLoadBalancer(childHelper); endpointTrackerMap = new EndpointTrackerMap(); this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext"); this.timeService = checkNotNull(helper.getScheduledExecutorService(), "timeService"); - this.timeProvider = timeProvider; + this.ticker = ticker; logger.log(ChannelLogLevel.DEBUG, "OutlierDetection lb created."); } @@ -140,12 +139,10 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { addressMap.put(e.getKey(), endpointTrackerMap.get(e.getValue())); } - switchLb.switchTo(config.childPolicy.getProvider()); - // If outlier detection is actually configured, start a timer that will periodically try to // detect outliers. if (config.outlierDetectionEnabled()) { - Long initialDelayNanos; + long initialDelayNanos; if (detectionTimerStartNanos == null) { // On the first go we use the configured interval. @@ -154,7 +151,7 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { // If a timer has started earlier we cancel it and use the difference between the start // time and now as the interval. initialDelayNanos = Math.max(0L, - config.intervalNanos - (timeProvider.currentTimeNanos() - detectionTimerStartNanos)); + config.intervalNanos - (ticker.read() - detectionTimerStartNanos)); } // If a timer has been previously created we need to cancel it and reset all the call counters @@ -174,10 +171,8 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { endpointTrackerMap.cancelTracking(); } - switchLb.handleResolvedAddresses( - resolvedAddresses.toBuilder().setLoadBalancingPolicyConfig(config.childPolicy.getConfig()) - .build()); - return Status.OK; + return switchLb.acceptResolvedAddresses( + resolvedAddresses.toBuilder().setLoadBalancingPolicyConfig(config.childConfig).build()); } @Override @@ -194,7 +189,7 @@ public void shutdown() { * This timer will be invoked periodically, according to configuration, and it will look for any * outlier subchannels. */ - class DetectionTimer implements Runnable { + final class DetectionTimer implements Runnable { OutlierDetectionLoadBalancerConfig config; ChannelLogger logger; @@ -206,7 +201,7 @@ class DetectionTimer implements Runnable { @Override public void run() { - detectionTimerStartNanos = timeProvider.currentTimeNanos(); + detectionTimerStartNanos = ticker.read(); endpointTrackerMap.swapCounters(); @@ -222,7 +217,7 @@ public void run() { * This child helper wraps the provided helper so that it can hand out wrapped {@link * OutlierDetectionSubchannel}s and manage the address info map. */ - class ChildHelper extends ForwardingLoadBalancerHelper { + final class ChildHelper extends ForwardingLoadBalancerHelper { private Helper delegate; @@ -264,7 +259,7 @@ public void updateBalancingState(ConnectivityState newState, SubchannelPicker ne } } - class OutlierDetectionSubchannel extends ForwardingSubchannel { + final class OutlierDetectionSubchannel extends ForwardingSubchannel { private final Subchannel delegate; private EndpointTracker endpointTracker; @@ -377,8 +372,9 @@ void clearEndpointTracker() { void eject() { ejected = true; - subchannelStateListener.onSubchannelState( - ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE)); + subchannelStateListener.onSubchannelState(ConnectivityStateInfo.forTransientFailure( + Status.UNAVAILABLE.withDescription( + "The subchannel has been ejected by outlier detection"))); logger.log(ChannelLogLevel.INFO, "Subchannel ejected: {0}", this); } @@ -402,7 +398,7 @@ protected Subchannel delegate() { /** * Wraps the actual listener so that state changes from the actual one can be intercepted. */ - class OutlierDetectionSubchannelStateListener implements SubchannelStateListener { + final class OutlierDetectionSubchannelStateListener implements SubchannelStateListener { private final SubchannelStateListener delegate; @@ -432,7 +428,7 @@ public String toString() { * This picker delegates the actual picking logic to a wrapped delegate, but associates a {@link * ClientStreamTracer} with each pick to track the results of each subchannel stream. */ - class OutlierDetectionPicker extends SubchannelPicker { + final class OutlierDetectionPicker extends SubchannelPicker { private final SubchannelPicker delegate; @@ -446,9 +442,14 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { Subchannel subchannel = pickResult.getSubchannel(); if (subchannel != null) { - return PickResult.withSubchannel(subchannel, new ResultCountingClientStreamTracerFactory( - subchannel.getAttributes().get(ENDPOINT_TRACKER_KEY), - pickResult.getStreamTracerFactory())); + EndpointTracker tracker = subchannel.getAttributes().get(ENDPOINT_TRACKER_KEY); + if (subchannel instanceof OutlierDetectionSubchannel) { + subchannel = ((OutlierDetectionSubchannel) subchannel).delegate(); + } + return pickResult.copyWithSubchannel(subchannel) + .copyWithStreamTracerFactory(new ResultCountingClientStreamTracerFactory( + tracker, + pickResult.getStreamTracerFactory())); } return pickResult; @@ -458,7 +459,7 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { * Builds instances of a {@link ClientStreamTracer} that increments the call count in the * tracker for each closed stream. */ - class ResultCountingClientStreamTracerFactory extends ClientStreamTracer.Factory { + final class ResultCountingClientStreamTracerFactory extends ClientStreamTracer.Factory { private final EndpointTracker tracker; @@ -502,7 +503,7 @@ public void streamClosed(Status status) { /** * Tracks additional information about the endpoint needed for outlier detection. */ - static class EndpointTracker { + static final class EndpointTracker { private OutlierDetectionLoadBalancerConfig config; // Marked as volatile to assure that when the inactive counter is swapped in as the new active @@ -642,11 +643,11 @@ public boolean maxEjectionTimeElapsed(long currentTimeNanos) { config.baseEjectionTimeNanos * ejectionTimeMultiplier, maxEjectionDurationSecs); - return currentTimeNanos > maxEjectionTimeNanos; + return currentTimeNanos - maxEjectionTimeNanos > 0; } /** Tracks both successful and failed call counts. */ - private static class CallCounter { + private static final class CallCounter { AtomicLong successCount = new AtomicLong(); AtomicLong failureCount = new AtomicLong(); @@ -667,7 +668,7 @@ public String toString() { /** * Maintains a mapping from endpoint (a set of addresses) to their trackers. */ - static class EndpointTrackerMap extends ForwardingMap, EndpointTracker> { + static final class EndpointTrackerMap extends ForwardingMap, EndpointTracker> { private final Map, EndpointTracker> trackerMap; EndpointTrackerMap() { @@ -688,7 +689,11 @@ void updateTrackerConfigs(OutlierDetectionLoadBalancerConfig config) { /** Adds a new tracker for every given address. */ void putNewTrackers(OutlierDetectionLoadBalancerConfig config, Set> endpoints) { - endpoints.forEach(e -> trackerMap.putIfAbsent(e, new EndpointTracker(config))); + for (Set endpoint : endpoints) { + if (!trackerMap.containsKey(endpoint)) { + trackerMap.put(endpoint, new EndpointTracker(config)); + } + } } /** Resets the call counters for all the trackers in the map. */ @@ -723,7 +728,7 @@ void swapCounters() { * that don't have ejected subchannels and uneject ones that have spent the maximum ejection * time allowed. */ - void maybeUnejectOutliers(Long detectionTimerStartNanos) { + void maybeUnejectOutliers(long detectionTimerStartNanos) { for (EndpointTracker tracker : trackerMap.values()) { if (!tracker.subchannelsEjected()) { tracker.decrementEjectionTimeMultiplier(); @@ -784,7 +789,7 @@ static List forConfig(OutlierDetectionLoadBalancerConf * required rate is not fixed, but is based on the mean and standard deviation of the success * rates of all of the addresses. */ - static class SuccessRateOutlierEjectionAlgorithm implements OutlierEjectionAlgorithm { + static final class SuccessRateOutlierEjectionAlgorithm implements OutlierEjectionAlgorithm { private final OutlierDetectionLoadBalancerConfig config; @@ -869,7 +874,7 @@ static double standardDeviation(Collection values, double mean) { } } - static class FailurePercentageOutlierEjectionAlgorithm implements OutlierEjectionAlgorithm { + static final class FailurePercentageOutlierEjectionAlgorithm implements OutlierEjectionAlgorithm { private final OutlierDetectionLoadBalancerConfig config; @@ -951,64 +956,54 @@ private static boolean hasSingleAddress(List addressGrou */ public static final class OutlierDetectionLoadBalancerConfig { - public final Long intervalNanos; - public final Long baseEjectionTimeNanos; - public final Long maxEjectionTimeNanos; - public final Integer maxEjectionPercent; + public final long intervalNanos; + public final long baseEjectionTimeNanos; + public final long maxEjectionTimeNanos; + public final int maxEjectionPercent; public final SuccessRateEjection successRateEjection; public final FailurePercentageEjection failurePercentageEjection; - public final PolicySelection childPolicy; - - private OutlierDetectionLoadBalancerConfig(Long intervalNanos, - Long baseEjectionTimeNanos, - Long maxEjectionTimeNanos, - Integer maxEjectionPercent, - SuccessRateEjection successRateEjection, - FailurePercentageEjection failurePercentageEjection, - PolicySelection childPolicy) { - this.intervalNanos = intervalNanos; - this.baseEjectionTimeNanos = baseEjectionTimeNanos; - this.maxEjectionTimeNanos = maxEjectionTimeNanos; - this.maxEjectionPercent = maxEjectionPercent; - this.successRateEjection = successRateEjection; - this.failurePercentageEjection = failurePercentageEjection; - this.childPolicy = childPolicy; + public final Object childConfig; + + private OutlierDetectionLoadBalancerConfig(Builder builder) { + this.intervalNanos = builder.intervalNanos; + this.baseEjectionTimeNanos = builder.baseEjectionTimeNanos; + this.maxEjectionTimeNanos = builder.maxEjectionTimeNanos; + this.maxEjectionPercent = builder.maxEjectionPercent; + this.successRateEjection = builder.successRateEjection; + this.failurePercentageEjection = builder.failurePercentageEjection; + this.childConfig = builder.childConfig; } /** Builds a new {@link OutlierDetectionLoadBalancerConfig}. */ - public static class Builder { - Long intervalNanos = 10_000_000_000L; // 10s - Long baseEjectionTimeNanos = 30_000_000_000L; // 30s - Long maxEjectionTimeNanos = 300_000_000_000L; // 300s - Integer maxEjectionPercent = 10; + public static final class Builder { + long intervalNanos = 10_000_000_000L; // 10s + long baseEjectionTimeNanos = 30_000_000_000L; // 30s + long maxEjectionTimeNanos = 300_000_000_000L; // 300s + int maxEjectionPercent = 10; SuccessRateEjection successRateEjection; FailurePercentageEjection failurePercentageEjection; - PolicySelection childPolicy; + Object childConfig; /** The interval between outlier detection sweeps. */ - public Builder setIntervalNanos(Long intervalNanos) { - checkArgument(intervalNanos != null); + public Builder setIntervalNanos(long intervalNanos) { this.intervalNanos = intervalNanos; return this; } /** The base time an address is ejected for. */ - public Builder setBaseEjectionTimeNanos(Long baseEjectionTimeNanos) { - checkArgument(baseEjectionTimeNanos != null); + public Builder setBaseEjectionTimeNanos(long baseEjectionTimeNanos) { this.baseEjectionTimeNanos = baseEjectionTimeNanos; return this; } /** The longest time an address can be ejected. */ - public Builder setMaxEjectionTimeNanos(Long maxEjectionTimeNanos) { - checkArgument(maxEjectionTimeNanos != null); + public Builder setMaxEjectionTimeNanos(long maxEjectionTimeNanos) { this.maxEjectionTimeNanos = maxEjectionTimeNanos; return this; } /** The algorithm agnostic maximum percentage of addresses that can be ejected. */ - public Builder setMaxEjectionPercent(Integer maxEjectionPercent) { - checkArgument(maxEjectionPercent != null); + public Builder setMaxEjectionPercent(int maxEjectionPercent) { this.maxEjectionPercent = maxEjectionPercent; return this; } @@ -1027,74 +1022,70 @@ public Builder setFailurePercentageEjection( return this; } - /** Sets the child policy the {@link OutlierDetectionLoadBalancer} delegates to. */ - public Builder setChildPolicy(PolicySelection childPolicy) { - checkState(childPolicy != null); - this.childPolicy = childPolicy; + /** + * Sets the graceful child switch config the {@link OutlierDetectionLoadBalancer} delegates + * to. + */ + public Builder setChildConfig(Object childConfig) { + checkState(childConfig != null); + this.childConfig = childConfig; return this; } /** Builds a new instance of {@link OutlierDetectionLoadBalancerConfig}. */ public OutlierDetectionLoadBalancerConfig build() { - checkState(childPolicy != null); - return new OutlierDetectionLoadBalancerConfig(intervalNanos, baseEjectionTimeNanos, - maxEjectionTimeNanos, maxEjectionPercent, successRateEjection, - failurePercentageEjection, childPolicy); + checkState(childConfig != null); + return new OutlierDetectionLoadBalancerConfig(this); } } /** The configuration for success rate ejection. */ - public static class SuccessRateEjection { - - public final Integer stdevFactor; - public final Integer enforcementPercentage; - public final Integer minimumHosts; - public final Integer requestVolume; - - SuccessRateEjection(Integer stdevFactor, Integer enforcementPercentage, Integer minimumHosts, - Integer requestVolume) { - this.stdevFactor = stdevFactor; - this.enforcementPercentage = enforcementPercentage; - this.minimumHosts = minimumHosts; - this.requestVolume = requestVolume; + public static final class SuccessRateEjection { + + public final int stdevFactor; + public final int enforcementPercentage; + public final int minimumHosts; + public final int requestVolume; + + SuccessRateEjection(Builder builder) { + this.stdevFactor = builder.stdevFactor; + this.enforcementPercentage = builder.enforcementPercentage; + this.minimumHosts = builder.minimumHosts; + this.requestVolume = builder.requestVolume; } /** Builds new instances of {@link SuccessRateEjection}. */ public static final class Builder { - Integer stdevFactor = 1900; - Integer enforcementPercentage = 100; - Integer minimumHosts = 5; - Integer requestVolume = 100; + int stdevFactor = 1900; + int enforcementPercentage = 100; + int minimumHosts = 5; + int requestVolume = 100; /** The product of this and the standard deviation of success rates determine the ejection * threshold. */ - public Builder setStdevFactor(Integer stdevFactor) { - checkArgument(stdevFactor != null); + public Builder setStdevFactor(int stdevFactor) { this.stdevFactor = stdevFactor; return this; } /** Only eject this percentage of outliers. */ - public Builder setEnforcementPercentage(Integer enforcementPercentage) { - checkArgument(enforcementPercentage != null); + public Builder setEnforcementPercentage(int enforcementPercentage) { checkArgument(enforcementPercentage >= 0 && enforcementPercentage <= 100); this.enforcementPercentage = enforcementPercentage; return this; } /** The minimum amount of hosts needed for success rate ejection. */ - public Builder setMinimumHosts(Integer minimumHosts) { - checkArgument(minimumHosts != null); + public Builder setMinimumHosts(int minimumHosts) { checkArgument(minimumHosts >= 0); this.minimumHosts = minimumHosts; return this; } /** The minimum address request volume to be considered for success rate ejection. */ - public Builder setRequestVolume(Integer requestVolume) { - checkArgument(requestVolume != null); + public Builder setRequestVolume(int requestVolume) { checkArgument(requestVolume >= 0); this.requestVolume = requestVolume; return this; @@ -1102,53 +1093,48 @@ public Builder setRequestVolume(Integer requestVolume) { /** Builds a new instance of {@link SuccessRateEjection}. */ public SuccessRateEjection build() { - return new SuccessRateEjection(stdevFactor, enforcementPercentage, minimumHosts, - requestVolume); + return new SuccessRateEjection(this); } } } /** The configuration for failure percentage ejection. */ - public static class FailurePercentageEjection { - public final Integer threshold; - public final Integer enforcementPercentage; - public final Integer minimumHosts; - public final Integer requestVolume; - - FailurePercentageEjection(Integer threshold, Integer enforcementPercentage, - Integer minimumHosts, Integer requestVolume) { - this.threshold = threshold; - this.enforcementPercentage = enforcementPercentage; - this.minimumHosts = minimumHosts; - this.requestVolume = requestVolume; + public static final class FailurePercentageEjection { + public final int threshold; + public final int enforcementPercentage; + public final int minimumHosts; + public final int requestVolume; + + FailurePercentageEjection(Builder builder) { + this.threshold = builder.threshold; + this.enforcementPercentage = builder.enforcementPercentage; + this.minimumHosts = builder.minimumHosts; + this.requestVolume = builder.requestVolume; } /** For building new {@link FailurePercentageEjection} instances. */ public static class Builder { - Integer threshold = 85; - Integer enforcementPercentage = 100; - Integer minimumHosts = 5; - Integer requestVolume = 50; + int threshold = 85; + int enforcementPercentage = 100; + int minimumHosts = 5; + int requestVolume = 50; /** The failure percentage that will result in an address being considered an outlier. */ - public Builder setThreshold(Integer threshold) { - checkArgument(threshold != null); + public Builder setThreshold(int threshold) { checkArgument(threshold >= 0 && threshold <= 100); this.threshold = threshold; return this; } /** Only eject this percentage of outliers. */ - public Builder setEnforcementPercentage(Integer enforcementPercentage) { - checkArgument(enforcementPercentage != null); + public Builder setEnforcementPercentage(int enforcementPercentage) { checkArgument(enforcementPercentage >= 0 && enforcementPercentage <= 100); this.enforcementPercentage = enforcementPercentage; return this; } /** The minimum amount of host for failure percentage ejection to be enabled. */ - public Builder setMinimumHosts(Integer minimumHosts) { - checkArgument(minimumHosts != null); + public Builder setMinimumHosts(int minimumHosts) { checkArgument(minimumHosts >= 0); this.minimumHosts = minimumHosts; return this; @@ -1158,8 +1144,7 @@ public Builder setMinimumHosts(Integer minimumHosts) { * The request volume required for an address to be considered for failure percentage * ejection. */ - public Builder setRequestVolume(Integer requestVolume) { - checkArgument(requestVolume != null); + public Builder setRequestVolume(int requestVolume) { checkArgument(requestVolume >= 0); this.requestVolume = requestVolume; return this; @@ -1167,8 +1152,7 @@ public Builder setRequestVolume(Integer requestVolume) { /** Builds a new instance of {@link FailurePercentageEjection}. */ public FailurePercentageEjection build() { - return new FailurePercentageEjection(threshold, enforcementPercentage, minimumHosts, - requestVolume); + return new FailurePercentageEjection(this); } } } diff --git a/util/src/main/java/io/grpc/util/OutlierDetectionLoadBalancerProvider.java b/util/src/main/java/io/grpc/util/OutlierDetectionLoadBalancerProvider.java index 5d8233eb8ab..084898bc38f 100644 --- a/util/src/main/java/io/grpc/util/OutlierDetectionLoadBalancerProvider.java +++ b/util/src/main/java/io/grpc/util/OutlierDetectionLoadBalancerProvider.java @@ -16,22 +16,18 @@ package io.grpc.util; +import com.google.common.base.Ticker; import io.grpc.Internal; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancerProvider; -import io.grpc.LoadBalancerRegistry; import io.grpc.NameResolver.ConfigOrError; import io.grpc.Status; +import io.grpc.internal.GrpcUtil; import io.grpc.internal.JsonUtil; -import io.grpc.internal.ServiceConfigUtil; -import io.grpc.internal.ServiceConfigUtil.LbConfig; -import io.grpc.internal.ServiceConfigUtil.PolicySelection; -import io.grpc.internal.TimeProvider; import io.grpc.util.OutlierDetectionLoadBalancer.OutlierDetectionLoadBalancerConfig; import io.grpc.util.OutlierDetectionLoadBalancer.OutlierDetectionLoadBalancerConfig.FailurePercentageEjection; import io.grpc.util.OutlierDetectionLoadBalancer.OutlierDetectionLoadBalancerConfig.SuccessRateEjection; -import java.util.List; import java.util.Map; @Internal @@ -39,7 +35,7 @@ public final class OutlierDetectionLoadBalancerProvider extends LoadBalancerProv @Override public LoadBalancer newLoadBalancer(Helper helper) { - return new OutlierDetectionLoadBalancer(helper, TimeProvider.SYSTEM_TIME_PROVIDER); + return new OutlierDetectionLoadBalancer(helper, Ticker.systemTicker()); } @Override @@ -150,20 +146,15 @@ private ConfigOrError parseLoadBalancingPolicyConfigInternal(Map rawC } // Child load balancer configuration. - List childConfigCandidates = ServiceConfigUtil.unwrapLoadBalancingConfigList( + ConfigOrError childConfig = GracefulSwitchLoadBalancer.parseLoadBalancingPolicyConfig( JsonUtil.getListOfObjects(rawConfig, "childPolicy")); - if (childConfigCandidates == null || childConfigCandidates.isEmpty()) { - return ConfigOrError.fromError(Status.INTERNAL.withDescription( - "No child policy in outlier_detection_experimental LB policy: " - + rawConfig)); + if (childConfig.getError() != null) { + return ConfigOrError.fromError(GrpcUtil.statusWithDetails( + Status.Code.UNAVAILABLE, + "Failed to parse child in outlier_detection_experimental", + childConfig.getError())); } - ConfigOrError selectedConfig = - ServiceConfigUtil.selectLbPolicyFromList(childConfigCandidates, - LoadBalancerRegistry.getDefaultRegistry()); - if (selectedConfig.getError() != null) { - return selectedConfig; - } - configBuilder.setChildPolicy((PolicySelection) selectedConfig.getConfig()); + configBuilder.setChildConfig(childConfig.getConfig()); return ConfigOrError.fromConfig(configBuilder.build()); } diff --git a/util/src/main/java/io/grpc/util/RandomSubsettingLoadBalancer.java b/util/src/main/java/io/grpc/util/RandomSubsettingLoadBalancer.java new file mode 100644 index 00000000000..ad4de9e8921 --- /dev/null +++ b/util/src/main/java/io/grpc/util/RandomSubsettingLoadBalancer.java @@ -0,0 +1,161 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.util; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.hash.HashCode; +import com.google.common.hash.HashFunction; +import com.google.common.hash.Hashing; +import com.google.common.primitives.Ints; +import io.grpc.EquivalentAddressGroup; +import io.grpc.LoadBalancer; +import io.grpc.Status; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.Random; + + +/** + * Wraps a child {@code LoadBalancer}, separating the total set of backends into smaller subsets for + * the child balancer to balance across. + * + *

This implements random subsetting gRFC: + * https://https://github.com/grpc/proposal/blob/master/A68-random-subsetting.md + */ +final class RandomSubsettingLoadBalancer extends LoadBalancer { + private final GracefulSwitchLoadBalancer switchLb; + private final HashFunction hashFunc; + + public RandomSubsettingLoadBalancer(Helper helper) { + this(helper, new Random().nextInt()); + } + + @VisibleForTesting + RandomSubsettingLoadBalancer(Helper helper, int seed) { + switchLb = new GracefulSwitchLoadBalancer(checkNotNull(helper, "helper")); + hashFunc = Hashing.murmur3_128(seed); + } + + @Override + public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + RandomSubsettingLoadBalancerConfig config = + (RandomSubsettingLoadBalancerConfig) + resolvedAddresses.getLoadBalancingPolicyConfig(); + + ResolvedAddresses subsetAddresses = filterEndpoints(resolvedAddresses, config.subsetSize); + + return switchLb.acceptResolvedAddresses( + subsetAddresses.toBuilder() + .setLoadBalancingPolicyConfig(config.childConfig) + .build()); + } + + // implements the subsetting algorithm, as described in A68: + // https://github.com/grpc/proposal/pull/423 + private ResolvedAddresses filterEndpoints(ResolvedAddresses resolvedAddresses, int subsetSize) { + if (subsetSize >= resolvedAddresses.getAddresses().size()) { + return resolvedAddresses; + } + + ArrayList endpointWithHashList = + new ArrayList<>(resolvedAddresses.getAddresses().size()); + + for (EquivalentAddressGroup addressGroup : resolvedAddresses.getAddresses()) { + HashCode hashCode = hashFunc.hashString( + addressGroup.getAddresses().get(0).toString(), + StandardCharsets.UTF_8); + endpointWithHashList.add(new EndpointWithHash(addressGroup, hashCode.asLong())); + } + + Collections.sort(endpointWithHashList, new HashAddressComparator()); + + ArrayList addressGroups = new ArrayList<>(subsetSize); + + for (int idx = 0; idx < subsetSize; ++idx) { + addressGroups.add(endpointWithHashList.get(idx).addressGroup); + } + + return resolvedAddresses.toBuilder().setAddresses(addressGroups).build(); + } + + @Override + public void handleNameResolutionError(Status error) { + switchLb.handleNameResolutionError(error); + } + + @Override + public void shutdown() { + switchLb.shutdown(); + } + + private static final class EndpointWithHash { + public final EquivalentAddressGroup addressGroup; + public final long hashCode; + + public EndpointWithHash(EquivalentAddressGroup addressGroup, long hashCode) { + this.addressGroup = addressGroup; + this.hashCode = hashCode; + } + } + + private static final class HashAddressComparator implements Comparator { + @Override + public int compare(EndpointWithHash lhs, EndpointWithHash rhs) { + return Long.compare(lhs.hashCode, rhs.hashCode); + } + } + + public static final class RandomSubsettingLoadBalancerConfig { + public final int subsetSize; + public final Object childConfig; + + private RandomSubsettingLoadBalancerConfig(int subsetSize, Object childConfig) { + this.subsetSize = subsetSize; + this.childConfig = childConfig; + } + + public static class Builder { + int subsetSize; + Object childConfig; + + public Builder setSubsetSize(long subsetSize) { + checkArgument(subsetSize > 0L, "Subset size must be greater than 0"); + // clamping subset size to Integer.MAX_VALUE due to collection indexing limitations in JVM + this.subsetSize = Ints.saturatedCast(subsetSize); + return this; + } + + public Builder setChildConfig(Object childConfig) { + this.childConfig = checkNotNull(childConfig, "childConfig"); + return this; + } + + public RandomSubsettingLoadBalancerConfig build() { + checkState(subsetSize != 0L, "Subset size must be set before building the config"); + return new RandomSubsettingLoadBalancerConfig( + subsetSize, + checkNotNull(childConfig, "childConfig")); + } + } + } +} diff --git a/util/src/main/java/io/grpc/util/RandomSubsettingLoadBalancerProvider.java b/util/src/main/java/io/grpc/util/RandomSubsettingLoadBalancerProvider.java new file mode 100644 index 00000000000..edcbf48a201 --- /dev/null +++ b/util/src/main/java/io/grpc/util/RandomSubsettingLoadBalancerProvider.java @@ -0,0 +1,86 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.util; + +import io.grpc.Internal; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancerProvider; +import io.grpc.NameResolver.ConfigOrError; +import io.grpc.Status; +import io.grpc.internal.JsonUtil; +import java.util.Map; + +@Internal +public final class RandomSubsettingLoadBalancerProvider extends LoadBalancerProvider { + private static final String POLICY_NAME = "random_subsetting_experimental"; + + @Override + public LoadBalancer newLoadBalancer(LoadBalancer.Helper helper) { + return new RandomSubsettingLoadBalancer(helper); + } + + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int getPriority() { + return 5; + } + + @Override + public String getPolicyName() { + return POLICY_NAME; + } + + @Override + public ConfigOrError parseLoadBalancingPolicyConfig(Map rawConfig) { + try { + return parseLoadBalancingPolicyConfigInternal(rawConfig); + } catch (RuntimeException e) { + return ConfigOrError.fromError( + Status.UNAVAILABLE + .withCause(e) + .withDescription("Failed parsing configuration for " + getPolicyName())); + } + } + + private ConfigOrError parseLoadBalancingPolicyConfigInternal(Map rawConfig) { + Long subsetSize = JsonUtil.getNumberAsLong(rawConfig, "subsetSize"); + if (subsetSize == null) { + return ConfigOrError.fromError( + Status.UNAVAILABLE.withDescription( + "Subset size missing in " + getPolicyName() + ", LB policy config=" + rawConfig)); + } + + ConfigOrError childConfig = GracefulSwitchLoadBalancer.parseLoadBalancingPolicyConfig( + JsonUtil.getListOfObjects(rawConfig, "childPolicy")); + if (childConfig.getError() != null) { + return ConfigOrError.fromError(Status.UNAVAILABLE + .withDescription( + "Failed to parse child in " + getPolicyName() + ", LB policy config=" + rawConfig) + .withCause(childConfig.getError().asRuntimeException())); + } + + return ConfigOrError.fromConfig( + new RandomSubsettingLoadBalancer.RandomSubsettingLoadBalancerConfig.Builder() + .setSubsetSize(subsetSize) + .setChildConfig(childConfig.getConfig()) + .build()); + } +} diff --git a/util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java b/util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java index a06bae545df..22940e875ac 100644 --- a/util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java +++ b/util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java @@ -27,7 +27,6 @@ import com.google.common.base.Preconditions; import io.grpc.ConnectivityState; import io.grpc.EquivalentAddressGroup; -import io.grpc.Internal; import io.grpc.LoadBalancer; import io.grpc.NameResolver; import java.util.ArrayList; @@ -41,10 +40,9 @@ * A {@link LoadBalancer} that provides round-robin load-balancing over the {@link * EquivalentAddressGroup}s from the {@link NameResolver}. */ -@Internal -public class RoundRobinLoadBalancer extends MultiChildLoadBalancer { +final class RoundRobinLoadBalancer extends MultiChildLoadBalancer { private final AtomicInteger sequence = new AtomicInteger(new Random().nextInt()); - protected SubchannelPicker currentPicker = new EmptyPicker(); + private SubchannelPicker currentPicker = new FixedResultPicker(PickResult.withNoResult()); public RoundRobinLoadBalancer(Helper helper) { super(helper); @@ -70,7 +68,7 @@ protected void updateOverallBalancingState() { } if (isConnecting) { - updateBalancingState(CONNECTING, new EmptyPicker()); + updateBalancingState(CONNECTING, new FixedResultPicker(PickResult.withNoResult())); } else { updateBalancingState(TRANSIENT_FAILURE, createReadyPicker(getChildLbStates())); } @@ -87,7 +85,7 @@ private void updateBalancingState(ConnectivityState state, SubchannelPicker pick } } - protected SubchannelPicker createReadyPicker(Collection children) { + private SubchannelPicker createReadyPicker(Collection children) { List pickerList = new ArrayList<>(); for (ChildLbState child : children) { SubchannelPicker picker = child.getCurrentPicker(); @@ -97,6 +95,24 @@ protected SubchannelPicker createReadyPicker(Collection children) return new ReadyPicker(pickerList, sequence); } + @Override + protected ChildLbState createChildLbState(Object key) { + return new ChildLbState(key, pickFirstLbProvider) { + @Override + protected ChildLbStateHelper createChildHelper() { + return new ChildLbStateHelper() { + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + super.updateBalancingState(newState, newPicker); + if (!resolvingAddresses && newState == IDLE) { + getLb().requestConnection(); + } + } + }; + } + }; + } + @VisibleForTesting static class ReadyPicker extends SubchannelPicker { private final List subchannelPickers; // non-empty @@ -163,22 +179,4 @@ public boolean equals(Object o) { && new HashSet<>(subchannelPickers).containsAll(other.subchannelPickers); } } - - @VisibleForTesting - static final class EmptyPicker extends SubchannelPicker { - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return PickResult.withNoResult(); - } - - @Override - public int hashCode() { - return getClass().hashCode(); - } - - @Override - public boolean equals(Object o) { - return o instanceof EmptyPicker; - } - } } diff --git a/util/src/main/java/io/grpc/util/TransmitStatusRuntimeExceptionInterceptor.java b/util/src/main/java/io/grpc/util/TransmitStatusRuntimeExceptionInterceptor.java index bead2be4e9e..b477ae1fdfb 100644 --- a/util/src/main/java/io/grpc/util/TransmitStatusRuntimeExceptionInterceptor.java +++ b/util/src/main/java/io/grpc/util/TransmitStatusRuntimeExceptionInterceptor.java @@ -219,6 +219,17 @@ public void run() { }); } + @Override + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/11021") + public void setOnReadyThreshold(final int numBytes) { + serializingExecutor.execute(new Runnable() { + @Override + public void run() { + SerializingServerCall.super.setOnReadyThreshold(numBytes); + } + }); + } + @Override public void setCompression(final String compressor) { serializingExecutor.execute(new Runnable() { diff --git a/util/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider b/util/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider index 1fdd69cb00b..d973a6f6728 100644 --- a/util/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider +++ b/util/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider @@ -1,2 +1,3 @@ io.grpc.util.SecretRoundRobinLoadBalancerProvider$Provider io.grpc.util.OutlierDetectionLoadBalancerProvider +io.grpc.util.RandomSubsettingLoadBalancerProvider diff --git a/util/src/test/java/io/grpc/util/AdvancedTlsX509KeyManagerTest.java b/util/src/test/java/io/grpc/util/AdvancedTlsX509KeyManagerTest.java new file mode 100644 index 00000000000..b8431d4f991 --- /dev/null +++ b/util/src/test/java/io/grpc/util/AdvancedTlsX509KeyManagerTest.java @@ -0,0 +1,209 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.util; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import io.grpc.internal.FakeClock; +import io.grpc.internal.testing.TestUtils; +import io.grpc.testing.TlsTesting; +import java.io.File; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.logging.Handler; +import java.util.logging.Level; +import java.util.logging.LogRecord; +import java.util.logging.Logger; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link AdvancedTlsX509KeyManager}. */ +@RunWith(JUnit4.class) +public class AdvancedTlsX509KeyManagerTest { + private static final String SERVER_0_KEY_FILE = "server0.key"; + private static final String SERVER_0_PEM_FILE = "server0.pem"; + private static final String CLIENT_0_KEY_FILE = "client.key"; + private static final String CLIENT_0_PEM_FILE = "client.pem"; + + private ScheduledExecutorService executor; + + private File serverKey0File; + private File serverCert0File; + private File clientKey0File; + private File clientCert0File; + + private PrivateKey serverKey0; + private X509Certificate[] serverCert0; + private PrivateKey clientKey0; + private X509Certificate[] clientCert0; + + @Before + public void setUp() throws Exception { + executor = new FakeClock().getScheduledExecutorService(); + serverKey0File = TestUtils.loadCert(SERVER_0_KEY_FILE); + serverCert0File = TestUtils.loadCert(SERVER_0_PEM_FILE); + clientKey0File = TestUtils.loadCert(CLIENT_0_KEY_FILE); + clientCert0File = TestUtils.loadCert(CLIENT_0_PEM_FILE); + serverKey0 = CertificateUtils.getPrivateKey(TlsTesting.loadCert(SERVER_0_KEY_FILE)); + serverCert0 = CertificateUtils.getX509Certificates(TlsTesting.loadCert(SERVER_0_PEM_FILE)); + clientKey0 = CertificateUtils.getPrivateKey(TlsTesting.loadCert(CLIENT_0_KEY_FILE)); + clientCert0 = CertificateUtils.getX509Certificates(TlsTesting.loadCert(CLIENT_0_PEM_FILE)); + } + + @Test + public void updateTrustCredentials_replacesIssuers() throws Exception { + // Overall happy path checking of public API. + AdvancedTlsX509KeyManager serverKeyManager = new AdvancedTlsX509KeyManager(); + + serverKeyManager.updateIdentityCredentials(serverCert0, serverKey0); + String alias1 = serverKeyManager.chooseEngineServerAlias(null, null, null); + assertEquals(AdvancedTlsX509KeyManager.ALIAS_PREFIX + "1", alias1); + assertEquals(serverKey0, serverKeyManager.getPrivateKey(alias1)); + assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(alias1)); + + serverKeyManager.updateIdentityCredentials(clientCert0File, clientKey0File); + String alias2 = serverKeyManager.chooseEngineServerAlias(null, null, null); + assertEquals(AdvancedTlsX509KeyManager.ALIAS_PREFIX + "2", alias2); + assertEquals(clientKey0, serverKeyManager.getPrivateKey(alias2)); + assertArrayEquals(clientCert0, serverKeyManager.getCertificateChain(alias2)); + // Previous alias still resolves — retained to allow in-progress handshakes to complete. + assertEquals(serverKey0, serverKeyManager.getPrivateKey(alias1)); + assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(alias1)); + + serverKeyManager.updateIdentityCredentials(serverCert0File, serverKey0File, 1, + TimeUnit.MINUTES, executor); + String alias3 = serverKeyManager.chooseEngineServerAlias(null, null, null); + assertEquals(serverKey0, serverKeyManager.getPrivateKey(alias3)); + assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(alias3)); + // alias1 is now two rotations back — no longer retained. + assertNull(serverKeyManager.getPrivateKey(alias1)); + + serverKeyManager.updateIdentityCredentials(serverCert0, serverKey0); + String alias4 = serverKeyManager.chooseEngineServerAlias(null, null, null); + assertEquals(serverKey0, serverKeyManager.getPrivateKey(alias4)); + assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(alias4)); + } + + @Test + public void allAliasMethods_returnNullBeforeCredentialsLoaded() { + AdvancedTlsX509KeyManager keyManager = new AdvancedTlsX509KeyManager(); + + assertNull(keyManager.chooseClientAlias(null, null, null)); + assertNull(keyManager.chooseServerAlias(null, null, null)); + assertNull(keyManager.chooseEngineClientAlias(null, null, null)); + assertNull(keyManager.chooseEngineServerAlias(null, null, null)); + assertNull(keyManager.getClientAliases(null, null)); + assertNull(keyManager.getServerAliases(null, null)); + assertNull(keyManager.getPrivateKey("key-1")); + assertNull(keyManager.getCertificateChain("key-1")); + } + + @Test + public void allAliasMethods_agreeAfterCredentialLoad() throws Exception { + AdvancedTlsX509KeyManager keyManager = new AdvancedTlsX509KeyManager(); + keyManager.updateIdentityCredentials(serverCert0, serverKey0); + + String expectedAlias = AdvancedTlsX509KeyManager.ALIAS_PREFIX + "1"; + assertEquals(expectedAlias, keyManager.chooseClientAlias(null, null, null)); + assertEquals(expectedAlias, keyManager.chooseServerAlias(null, null, null)); + assertEquals(expectedAlias, keyManager.chooseEngineClientAlias(null, null, null)); + assertEquals(expectedAlias, keyManager.chooseEngineServerAlias(null, null, null)); + assertArrayEquals(new String[]{expectedAlias}, keyManager.getClientAliases(null, null)); + assertArrayEquals(new String[]{expectedAlias}, keyManager.getServerAliases(null, null)); + } + + @Test + public void credentialSettingParameterValidity() throws Exception { + // Checking edge cases of public API parameter setting. + AdvancedTlsX509KeyManager serverKeyManager = new AdvancedTlsX509KeyManager(); + NullPointerException npe = assertThrows(NullPointerException.class, () -> serverKeyManager + .updateIdentityCredentials(serverCert0, null)); + assertEquals("key", npe.getMessage()); + + npe = assertThrows(NullPointerException.class, () -> serverKeyManager + .updateIdentityCredentials(null, serverKey0)); + assertEquals("certs", npe.getMessage()); + + npe = assertThrows(NullPointerException.class, () -> serverKeyManager + .updateIdentityCredentials(null, serverKey0File)); + assertEquals("certFile", npe.getMessage()); + + npe = assertThrows(NullPointerException.class, () -> serverKeyManager + .updateIdentityCredentials(serverCert0File, null)); + assertEquals("keyFile", npe.getMessage()); + + npe = assertThrows(NullPointerException.class, () -> serverKeyManager + .updateIdentityCredentials(serverCert0File, serverKey0File, 1, null, + executor)); + assertEquals("unit", npe.getMessage()); + + npe = assertThrows(NullPointerException.class, () -> serverKeyManager + .updateIdentityCredentials(serverCert0File, serverKey0File, 1, + TimeUnit.MINUTES, null)); + assertEquals("executor", npe.getMessage()); + + Logger log = Logger.getLogger(AdvancedTlsX509KeyManager.class.getName()); + TestHandler handler = new TestHandler(); + log.addHandler(handler); + log.setUseParentHandlers(false); + log.setLevel(Level.FINE); + serverKeyManager.updateIdentityCredentials(serverCert0File, serverKey0File, -1, + TimeUnit.SECONDS, executor); + log.removeHandler(handler); + for (LogRecord record : handler.getRecords()) { + if (record.getMessage().contains("Default value of ")) { + assertTrue(true); + return; + } + } + fail("Log message related to setting default values not found"); + } + + + private static class TestHandler extends Handler { + private final List records = new ArrayList<>(); + + @Override + public void publish(LogRecord record) { + records.add(record); + } + + @Override + public void flush() { + } + + @Override + public void close() throws SecurityException { + } + + public List getRecords() { + return records; + } + } + +} diff --git a/util/src/test/java/io/grpc/util/AdvancedTlsX509TrustManagerTest.java b/util/src/test/java/io/grpc/util/AdvancedTlsX509TrustManagerTest.java new file mode 100644 index 00000000000..b9803b03570 --- /dev/null +++ b/util/src/test/java/io/grpc/util/AdvancedTlsX509TrustManagerTest.java @@ -0,0 +1,223 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.util; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.common.collect.Iterables; +import com.google.common.io.Files; +import io.grpc.internal.FakeClock; +import io.grpc.internal.testing.TestUtils; +import io.grpc.testing.TlsTesting; +import io.grpc.util.AdvancedTlsX509TrustManager.Verification; +import java.io.File; +import java.io.IOException; +import java.net.Socket; +import java.security.GeneralSecurityException; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.logging.Handler; +import java.util.logging.Level; +import java.util.logging.LogRecord; +import java.util.logging.Logger; +import javax.net.ssl.SSLSocket; +import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link AdvancedTlsX509TrustManager}. */ +@RunWith(JUnit4.class) +@IgnoreJRERequirement +public class AdvancedTlsX509TrustManagerTest { + + private static final String CA_PEM_FILE = "ca.pem"; + private static final String SERVER_0_PEM_FILE = "server0.pem"; + private static final String SERVER_1_PEM_FILE = "server1.pem"; + private File caCertFile; + private File serverCert0File; + private File serverCert1File; + + private X509Certificate[] caCert; + private X509Certificate[] serverCert0; + private X509Certificate[] serverCert1; + + private FakeClock fakeClock; + private ScheduledExecutorService executor; + + @Before + public void setUp() throws IOException, GeneralSecurityException { + fakeClock = new FakeClock(); + executor = fakeClock.getScheduledExecutorService(); + caCertFile = TestUtils.loadCert(CA_PEM_FILE); + caCert = CertificateUtils.getX509Certificates(TlsTesting.loadCert(CA_PEM_FILE)); + serverCert0File = TestUtils.loadCert(SERVER_0_PEM_FILE); + serverCert0 = CertificateUtils.getX509Certificates(TlsTesting.loadCert(SERVER_0_PEM_FILE)); + serverCert1File = TestUtils.loadCert(SERVER_1_PEM_FILE); + serverCert1 = CertificateUtils.getX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); + } + + @Test + public void updateTrustCredentials_replacesIssuers() throws Exception { + // Overall happy path checking of public API. + AdvancedTlsX509TrustManager trustManager = AdvancedTlsX509TrustManager.newBuilder().build(); + trustManager.updateTrustCredentials(serverCert0File); + assertArrayEquals(serverCert0, trustManager.getAcceptedIssuers()); + + trustManager.updateTrustCredentials(caCert); + assertArrayEquals(caCert, trustManager.getAcceptedIssuers()); + + trustManager.updateTrustCredentials(serverCert0File, 1, TimeUnit.MINUTES, + executor); + assertArrayEquals(serverCert0, trustManager.getAcceptedIssuers()); + + trustManager.updateTrustCredentials(serverCert0File); + assertArrayEquals(serverCert0, trustManager.getAcceptedIssuers()); + } + + @Test + public void systemDefaultDelegateManagerInstantiation() throws Exception { + AdvancedTlsX509TrustManager trustManager = AdvancedTlsX509TrustManager.newBuilder().build(); + trustManager.useSystemDefaultTrustCerts(); + CertificateException ce = assertThrows(CertificateException.class, () -> trustManager + .checkServerTrusted(serverCert0, "RSA", new Socket())); + assertEquals("socket is not a type of SSLSocket", ce.getMessage()); + } + + @Test + public void credentialSettingParameterValidity() throws Exception { + // Checking edge cases of public API parameter setting. + AdvancedTlsX509TrustManager trustManager = AdvancedTlsX509TrustManager.newBuilder().build(); + + NullPointerException npe = assertThrows(NullPointerException.class, () -> trustManager + .updateTrustCredentials(null, 1, null, null)); + assertEquals("trustCertFile", npe.getMessage()); + + npe = assertThrows(NullPointerException.class, () -> trustManager + .updateTrustCredentials(caCertFile, 1, null, null)); + assertEquals("unit", npe.getMessage()); + + npe = assertThrows(NullPointerException.class, () -> trustManager + .updateTrustCredentials(caCertFile, 1, TimeUnit.MINUTES, null)); + assertEquals("executor", npe.getMessage()); + + Logger log = Logger.getLogger(AdvancedTlsX509TrustManager.class.getName()); + TestHandler handler = new TestHandler(); + log.addHandler(handler); + log.setUseParentHandlers(false); + log.setLevel(Level.FINE); + trustManager.updateTrustCredentials(serverCert0File, -1, TimeUnit.SECONDS, executor); + log.removeHandler(handler); + try { + LogRecord logRecord = Iterables.find(handler.getRecords(), + record -> record.getMessage().contains("Default value of ")); + assertNotNull(logRecord); + } catch (NoSuchElementException e) { + throw new AssertionError("Log message related to setting default values not found"); + } + } + + @Test + public void missingFile_throwsFileNotFoundException() throws Exception { + AdvancedTlsX509TrustManager trustManager = AdvancedTlsX509TrustManager.newBuilder().build(); + File nonExistentFile = new File("missing_cert.pem"); + Exception thrown = + assertThrows(Exception.class, () -> trustManager.updateTrustCredentials(nonExistentFile)); + assertNotNull(thrown); + assertEquals(thrown.getMessage(), + "Certificate file not found or not readable: " + nonExistentFile.getAbsolutePath()); + } + + @Test + public void clientTrustedWithSocketTest() throws Exception { + AdvancedTlsX509TrustManager trustManager = AdvancedTlsX509TrustManager.newBuilder() + .setVerification(Verification.CERTIFICATE_ONLY_VERIFICATION).build(); + trustManager.updateTrustCredentials(caCert); + SSLSocket sslSocket = mock(SSLSocket.class); + when(sslSocket.isConnected()).thenReturn(true); + when(sslSocket.getHandshakeSession()).thenReturn(null); + CertificateException ce = assertThrows(CertificateException.class, () -> trustManager + .checkClientTrusted(serverCert0, "RSA", sslSocket)); + assertEquals("No handshake session", ce.getMessage()); + } + + @Test + public void updateTrustCredentials_rotate() throws GeneralSecurityException, IOException { + AdvancedTlsX509TrustManager trustManager = AdvancedTlsX509TrustManager.newBuilder().build(); + trustManager.updateTrustCredentials(serverCert0File); + assertArrayEquals(serverCert0, trustManager.getAcceptedIssuers()); + + trustManager.updateTrustCredentials(serverCert0File, 1, TimeUnit.MINUTES, + executor); + assertArrayEquals(serverCert0, trustManager.getAcceptedIssuers()); + + fakeClock.forwardTime(1, TimeUnit.MINUTES); + assertArrayEquals(serverCert0, trustManager.getAcceptedIssuers()); + + serverCert0File.setLastModified(serverCert0File.lastModified() - 2000); + + fakeClock.forwardTime(1, TimeUnit.MINUTES); + assertArrayEquals(serverCert0, trustManager.getAcceptedIssuers()); + + long beforeModify = serverCert0File.lastModified(); + Files.copy(serverCert1File, serverCert0File); + serverCert0File.setLastModified(beforeModify); + + // although file content changed, file modification time is not changed + fakeClock.forwardTime(1, TimeUnit.MINUTES); + assertArrayEquals(serverCert0, trustManager.getAcceptedIssuers()); + + serverCert0File.setLastModified(beforeModify + 2000); + + // file modification time changed + fakeClock.forwardTime(1, TimeUnit.MINUTES); + assertArrayEquals(serverCert1, trustManager.getAcceptedIssuers()); + } + + private static class TestHandler extends Handler { + private final List records = new ArrayList<>(); + + @Override + public void publish(LogRecord record) { + records.add(record); + } + + @Override + public void flush() { + } + + @Override + public void close() throws SecurityException { + } + + public List getRecords() { + return records; + } + } + +} diff --git a/util/src/test/java/io/grpc/util/CertificateUtilsTest.java b/util/src/test/java/io/grpc/util/CertificateUtilsTest.java index aef99c0f378..dbddd35bca3 100644 --- a/util/src/test/java/io/grpc/util/CertificateUtilsTest.java +++ b/util/src/test/java/io/grpc/util/CertificateUtilsTest.java @@ -18,12 +18,12 @@ import static com.google.common.truth.Truth.assertThat; -import com.google.common.base.Charsets; import io.grpc.internal.testing.TestUtils; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.math.BigInteger; +import java.nio.charset.StandardCharsets; import java.security.PrivateKey; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; @@ -80,7 +80,7 @@ public void readCaPemFile() throws CertificateException, IOException { @Test public void readBadFormatKeyFile() throws Exception { - InputStream in = new ByteArrayInputStream(BAD_PEM_FORMAT.getBytes(Charsets.UTF_8)); + InputStream in = new ByteArrayInputStream(BAD_PEM_FORMAT.getBytes(StandardCharsets.UTF_8)); try { CertificateUtils.getPrivateKey(in); Assert.fail("no exception thrown"); @@ -92,7 +92,7 @@ public void readBadFormatKeyFile() throws Exception { @Test public void readBadContentKeyFile() { - InputStream in = new ByteArrayInputStream(BAD_PEM_CONTENT.getBytes(Charsets.UTF_8)); + InputStream in = new ByteArrayInputStream(BAD_PEM_CONTENT.getBytes(StandardCharsets.UTF_8)); try { CertificateUtils.getPrivateKey(in); Assert.fail("no exception thrown"); diff --git a/util/src/test/java/io/grpc/util/GracefulSwitchLoadBalancerTest.java b/util/src/test/java/io/grpc/util/GracefulSwitchLoadBalancerTest.java index 6e89176e9c9..0467f9526f6 100644 --- a/util/src/test/java/io/grpc/util/GracefulSwitchLoadBalancerTest.java +++ b/util/src/test/java/io/grpc/util/GracefulSwitchLoadBalancerTest.java @@ -18,18 +18,22 @@ import static com.google.common.truth.Truth.assertThat; import static io.grpc.ConnectivityState.CONNECTING; +import static io.grpc.ConnectivityState.IDLE; import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; -import static io.grpc.util.GracefulSwitchLoadBalancer.BUFFER_PICKER; +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; +import com.google.common.testing.EqualsTester; +import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer; @@ -41,17 +45,16 @@ import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancerProvider; import io.grpc.LoadBalancerRegistry; +import io.grpc.NameResolver.ConfigOrError; import io.grpc.Status; import java.net.SocketAddress; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; @@ -62,91 +65,185 @@ */ @RunWith(JUnit4.class) public class GracefulSwitchLoadBalancerTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); - - private final LoadBalancerRegistry lbRegistry = new LoadBalancerRegistry(); - // maps policy name to lb provide - private final Map lbProviders = new HashMap<>(); - // maps policy name to lb - private final Map balancers = new HashMap<>(); + private static final Object FAKE_CONFIG = new Object(); + + private final Map balancers = new HashMap<>(); private final Map helpers = new HashMap<>(); private final Helper mockHelper = mock(Helper.class); private final GracefulSwitchLoadBalancer gracefulSwitchLb = new GracefulSwitchLoadBalancer(mockHelper); - private final String[] lbPolicies = {"lb_policy_0", "lb_policy_1", "lb_policy_2", "lb_policy_3"}; - - @Before - public void setUp() { - for (String lbPolicy : lbPolicies) { - LoadBalancerProvider lbProvider = new FakeLoadBalancerProvider(lbPolicy); - lbProviders.put(lbPolicy, lbProvider); - lbRegistry.register(lbProvider); - } + private final LoadBalancerProvider[] lbPolicies = { + new FakeLoadBalancerProvider("lb_policy_0"), + new FakeLoadBalancerProvider("lb_policy_1"), + new FakeLoadBalancerProvider("lb_policy_2"), + new FakeLoadBalancerProvider("lb_policy_3"), + }; + + @Test + public void transientFailureOnInitialResolutionError() { + gracefulSwitchLb.handleNameResolutionError(Status.DATA_LOSS); + ArgumentCaptor pickerCaptor = ArgumentCaptor.forClass(SubchannelPicker.class); + verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); + SubchannelPicker picker = pickerCaptor.getValue(); + assertThat(picker.pickSubchannel(mock(PickSubchannelArgs.class)).getStatus().getCode()) + .isEqualTo(Status.Code.DATA_LOSS); } + @Deprecated @Test + public void handleSubchannelState_shouldThrow() { + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[0], new Object())) + .build())); + Subchannel subchannel = mock(Subchannel.class); + ConnectivityStateInfo connectivityStateInfo = ConnectivityStateInfo.forNonError(READY); + assertThrows(UnsupportedOperationException.class, + () -> gracefulSwitchLb.handleSubchannelState(subchannel, connectivityStateInfo)); + } + + @Test + @Deprecated public void canHandleEmptyAddressListFromNameResolutionForwardedToLatestPolicy() { - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[0])); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[0], new Object())) + .build())); LoadBalancer lb0 = balancers.get(lbPolicies[0]); Helper helper0 = helpers.get(lb0); SubchannelPicker picker = mock(SubchannelPicker.class); helper0.updateBalancingState(READY, picker); assertThat(gracefulSwitchLb.canHandleEmptyAddressListFromNameResolution()).isFalse(); - doReturn(true).when(lb0).canHandleEmptyAddressListFromNameResolution(); + when(lb0.canHandleEmptyAddressListFromNameResolution()).thenReturn(true); assertThat(gracefulSwitchLb.canHandleEmptyAddressListFromNameResolution()).isTrue(); - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[1])); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[1], new Object())) + .build())); LoadBalancer lb1 = balancers.get(lbPolicies[1]); assertThat(gracefulSwitchLb.canHandleEmptyAddressListFromNameResolution()).isFalse(); - doReturn(true).when(lb1).canHandleEmptyAddressListFromNameResolution(); + when(lb1.canHandleEmptyAddressListFromNameResolution()).thenReturn(true); assertThat(gracefulSwitchLb.canHandleEmptyAddressListFromNameResolution()).isTrue(); - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[2])); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[2], new Object())) + .build())); LoadBalancer lb2 = balancers.get(lbPolicies[2]); assertThat(gracefulSwitchLb.canHandleEmptyAddressListFromNameResolution()).isFalse(); - doReturn(true).when(lb2).canHandleEmptyAddressListFromNameResolution(); + when(lb2.canHandleEmptyAddressListFromNameResolution()).thenReturn(true); assertThat(gracefulSwitchLb.canHandleEmptyAddressListFromNameResolution()).isTrue(); } + @Deprecated @Test public void handleResolvedAddressesAndNameResolutionErrorForwardedToLatestPolicy() { - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[0])); + ResolvedAddresses addresses = newFakeAddresses(); + Object child0Config = new Object(); + gracefulSwitchLb.handleResolvedAddresses(addresses.toBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[0], child0Config)) + .build()); LoadBalancer lb0 = balancers.get(lbPolicies[0]); + verify(lb0).handleResolvedAddresses(addresses.toBuilder() + .setLoadBalancingPolicyConfig(child0Config) + .build()); Helper helper0 = helpers.get(lb0); SubchannelPicker picker = mock(SubchannelPicker.class); helper0.updateBalancingState(READY, picker); + gracefulSwitchLb.handleNameResolutionError(Status.DATA_LOSS); + verify(lb0).handleNameResolutionError(Status.DATA_LOSS); + Object child1Config = new Object(); + addresses = newFakeAddresses(); + gracefulSwitchLb.handleResolvedAddresses(addresses.toBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[1], child1Config)) + .build()); + LoadBalancer lb1 = balancers.get(lbPolicies[1]); + verify(lb0, never()).handleResolvedAddresses(addresses.toBuilder() + .setLoadBalancingPolicyConfig(child1Config) + .build()); + verify(lb1).handleResolvedAddresses(addresses.toBuilder() + .setLoadBalancingPolicyConfig(child1Config) + .build()); + gracefulSwitchLb.handleNameResolutionError(Status.ALREADY_EXISTS); + verify(lb0, never()).handleNameResolutionError(Status.ALREADY_EXISTS); + verify(lb1).handleNameResolutionError(Status.ALREADY_EXISTS); + + Object child2Config = new Object(); + addresses = newFakeAddresses(); + gracefulSwitchLb.handleResolvedAddresses(addresses.toBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[2], child2Config)) + .build()); + verify(lb1).shutdown(); + LoadBalancer lb2 = balancers.get(lbPolicies[2]); + verify(lb0, never()).handleResolvedAddresses(addresses.toBuilder() + .setLoadBalancingPolicyConfig(child2Config) + .build()); + verify(lb1, never()).handleResolvedAddresses(addresses.toBuilder() + .setLoadBalancingPolicyConfig(child2Config) + .build()); + verify(lb2).handleResolvedAddresses(addresses.toBuilder() + .setLoadBalancingPolicyConfig(child2Config) + .build()); + gracefulSwitchLb.handleNameResolutionError(Status.CANCELLED); + verify(lb0, never()).handleNameResolutionError(Status.CANCELLED); + verify(lb1, never()).handleNameResolutionError(Status.CANCELLED); + verify(lb2).handleNameResolutionError(Status.CANCELLED); + + verifyNoMoreInteractions(lb0, lb1, lb2); + } + + @Test + public void acceptResolvedAddressesAndNameResolutionErrorForwardedToLatestPolicy() { ResolvedAddresses addresses = newFakeAddresses(); - gracefulSwitchLb.handleResolvedAddresses(addresses); - verify(lb0).handleResolvedAddresses(addresses); + Object child0Config = new Object(); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addresses.toBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[0], child0Config)) + .build())); + LoadBalancer lb0 = balancers.get(lbPolicies[0]); + verify(lb0).acceptResolvedAddresses(addresses.toBuilder() + .setLoadBalancingPolicyConfig(child0Config) + .build()); + Helper helper0 = helpers.get(lb0); + SubchannelPicker picker = mock(SubchannelPicker.class); + helper0.updateBalancingState(READY, picker); gracefulSwitchLb.handleNameResolutionError(Status.DATA_LOSS); verify(lb0).handleNameResolutionError(Status.DATA_LOSS); - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[1])); - LoadBalancer lb1 = balancers.get(lbPolicies[1]); + Object child1Config = new Object(); addresses = newFakeAddresses(); - gracefulSwitchLb.handleResolvedAddresses(addresses); - verify(lb0, never()).handleResolvedAddresses(addresses); - verify(lb1).handleResolvedAddresses(addresses); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addresses.toBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[1], child1Config)) + .build())); + LoadBalancer lb1 = balancers.get(lbPolicies[1]); + verify(lb0, never()).acceptResolvedAddresses(addresses.toBuilder() + .setLoadBalancingPolicyConfig(child1Config) + .build()); + verify(lb1).acceptResolvedAddresses(addresses.toBuilder() + .setLoadBalancingPolicyConfig(child1Config) + .build()); gracefulSwitchLb.handleNameResolutionError(Status.ALREADY_EXISTS); verify(lb0, never()).handleNameResolutionError(Status.ALREADY_EXISTS); verify(lb1).handleNameResolutionError(Status.ALREADY_EXISTS); - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[2])); + Object child2Config = new Object(); + addresses = newFakeAddresses(); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addresses.toBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[2], child2Config)) + .build())); verify(lb1).shutdown(); LoadBalancer lb2 = balancers.get(lbPolicies[2]); - addresses = newFakeAddresses(); - gracefulSwitchLb.handleResolvedAddresses(addresses); - verify(lb0, never()).handleResolvedAddresses(addresses); - verify(lb1, never()).handleResolvedAddresses(addresses); - verify(lb2).handleResolvedAddresses(addresses); + verify(lb0, never()).acceptResolvedAddresses(addresses.toBuilder() + .setLoadBalancingPolicyConfig(child2Config) + .build()); + verify(lb1, never()).acceptResolvedAddresses(addresses.toBuilder() + .setLoadBalancingPolicyConfig(child2Config) + .build()); + verify(lb2).acceptResolvedAddresses(addresses.toBuilder() + .setLoadBalancingPolicyConfig(child2Config) + .build()); gracefulSwitchLb.handleNameResolutionError(Status.CANCELLED); verify(lb0, never()).handleNameResolutionError(Status.CANCELLED); verify(lb1, never()).handleNameResolutionError(Status.CANCELLED); @@ -157,24 +254,32 @@ public void handleResolvedAddressesAndNameResolutionErrorForwardedToLatestPolicy @Test public void shutdownTriggeredWhenSwitchAndForwardedWhenSwitchLbShutdown() { - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[0])); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[0], new Object())) + .build())); LoadBalancer lb0 = balancers.get(lbPolicies[0]); Helper helper0 = helpers.get(lb0); SubchannelPicker picker = mock(SubchannelPicker.class); helper0.updateBalancingState(READY, picker); - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[1])); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[1], new Object())) + .build())); LoadBalancer lb1 = balancers.get(lbPolicies[1]); verify(lb1, never()).shutdown(); - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[2])); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[2], new Object())) + .build())); verify(lb1).shutdown(); LoadBalancer lb2 = balancers.get(lbPolicies[2]); verify(lb0, never()).shutdown(); helpers.get(lb2).updateBalancingState(READY, mock(SubchannelPicker.class)); verify(lb0).shutdown(); - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[3])); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[3], new Object())) + .build())); LoadBalancer lb3 = balancers.get(lbPolicies[3]); verify(lb2, never()).shutdown(); verify(lb3, never()).shutdown(); @@ -182,13 +287,13 @@ public void shutdownTriggeredWhenSwitchAndForwardedWhenSwitchLbShutdown() { gracefulSwitchLb.shutdown(); verify(lb2).shutdown(); verify(lb3).shutdown(); - - verifyNoMoreInteractions(lb0, lb1, lb2, lb3); } @Test public void requestConnectionForwardedToLatestPolicies() { - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[0])); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[0], new Object())) + .build())); LoadBalancer lb0 = balancers.get(lbPolicies[0]); Helper helper0 = helpers.get(lb0); SubchannelPicker picker = mock(SubchannelPicker.class); @@ -197,12 +302,16 @@ public void requestConnectionForwardedToLatestPolicies() { gracefulSwitchLb.requestConnection(); verify(lb0).requestConnection(); - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[1])); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[1], new Object())) + .build())); LoadBalancer lb1 = balancers.get(lbPolicies[1]); gracefulSwitchLb.requestConnection(); verify(lb1).requestConnection(); - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[2])); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[2], new Object())) + .build())); verify(lb1).shutdown(); LoadBalancer lb2 = balancers.get(lbPolicies[2]); gracefulSwitchLb.requestConnection(); @@ -215,17 +324,19 @@ public void requestConnectionForwardedToLatestPolicies() { gracefulSwitchLb.requestConnection(); verify(lb2, times(2)).requestConnection(); - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[3])); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[3], new Object())) + .build())); LoadBalancer lb3 = balancers.get(lbPolicies[3]); gracefulSwitchLb.requestConnection(); verify(lb3).requestConnection(); - - verifyNoMoreInteractions(lb0, lb1, lb2, lb3); } @Test public void createSubchannelForwarded() { - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[0])); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[0], new Object())) + .build())); LoadBalancer lb0 = balancers.get(lbPolicies[0]); Helper helper0 = helpers.get(lb0); SubchannelPicker picker = mock(SubchannelPicker.class); @@ -235,7 +346,9 @@ public void createSubchannelForwarded() { helper0.createSubchannel(createSubchannelArgs); verify(mockHelper).createSubchannel(createSubchannelArgs); - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[1])); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[1], new Object())) + .build())); LoadBalancer lb1 = balancers.get(lbPolicies[1]); Helper helper1 = helpers.get(lb1); createSubchannelArgs = newFakeCreateSubchannelArgs(); @@ -245,27 +358,45 @@ public void createSubchannelForwarded() { createSubchannelArgs = newFakeCreateSubchannelArgs(); helper0.createSubchannel(createSubchannelArgs); verify(mockHelper).createSubchannel(createSubchannelArgs); + } - verifyNoMoreInteractions(lb0, lb1); + @Test + public void updateBalancingStateIsGraceful_Ready() { + updateBalancingStateIsGraceful(READY); + } + + @Test + public void updateBalancingStateIsGraceful_TransientFailure() { + updateBalancingStateIsGraceful(TRANSIENT_FAILURE); } @Test - public void updateBalancingStateIsGraceful() { - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[0])); + public void updateBalancingStateIsGraceful_Idle() { + updateBalancingStateIsGraceful(IDLE); + } + + public void updateBalancingStateIsGraceful(ConnectivityState swapsOnState) { + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[0], new Object())) + .build())); LoadBalancer lb0 = balancers.get(lbPolicies[0]); Helper helper0 = helpers.get(lb0); SubchannelPicker picker = mock(SubchannelPicker.class); helper0.updateBalancingState(READY, picker); verify(mockHelper).updateBalancingState(READY, picker); - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[1])); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[1], new Object())) + .build())); LoadBalancer lb1 = balancers.get(lbPolicies[1]); Helper helper1 = helpers.get(lb1); picker = mock(SubchannelPicker.class); helper1.updateBalancingState(CONNECTING, picker); verify(mockHelper, never()).updateBalancingState(CONNECTING, picker); - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[2])); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[2], new Object())) + .build())); verify(lb1).shutdown(); LoadBalancer lb2 = balancers.get(lbPolicies[2]); Helper helper2 = helpers.get(lb2); @@ -273,20 +404,22 @@ public void updateBalancingStateIsGraceful() { helper2.updateBalancingState(CONNECTING, picker); verify(mockHelper, never()).updateBalancingState(CONNECTING, picker); - // lb2 reports READY + // lb2 reports swapsOnState SubchannelPicker picker2 = mock(SubchannelPicker.class); - helper2.updateBalancingState(READY, picker2); + helper2.updateBalancingState(swapsOnState, picker2); verify(lb0).shutdown(); - verify(mockHelper).updateBalancingState(READY, picker2); + verify(mockHelper).updateBalancingState(swapsOnState, picker2); - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[3])); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[3], new Object())) + .build())); LoadBalancer lb3 = balancers.get(lbPolicies[3]); Helper helper3 = helpers.get(lb3); SubchannelPicker picker3 = mock(SubchannelPicker.class); helper3.updateBalancingState(CONNECTING, picker3); verify(mockHelper, never()).updateBalancingState(CONNECTING, picker3); - // lb2 out of READY + // lb2 out of swapsOnState picker2 = mock(SubchannelPicker.class); helper2.updateBalancingState(CONNECTING, picker2); verify(mockHelper, never()).updateBalancingState(CONNECTING, picker2); @@ -296,13 +429,13 @@ public void updateBalancingStateIsGraceful() { picker3 = mock(SubchannelPicker.class); helper3.updateBalancingState(CONNECTING, picker3); verify(mockHelper).updateBalancingState(CONNECTING, picker3); - - verifyNoMoreInteractions(lb0, lb1, lb2, lb3); } @Test public void switchWhileOldPolicyIsNotReady() { - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[0])); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[0], new Object())) + .build())); LoadBalancer lb0 = balancers.get(lbPolicies[0]); Helper helper0 = helpers.get(lb0); SubchannelPicker picker = mock(SubchannelPicker.class); @@ -311,7 +444,9 @@ public void switchWhileOldPolicyIsNotReady() { helper0.updateBalancingState(CONNECTING, picker); verify(lb0, never()).shutdown(); - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[1])); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[1], new Object())) + .build())); verify(lb0).shutdown(); LoadBalancer lb1 = balancers.get(lbPolicies[1]); @@ -321,22 +456,25 @@ public void switchWhileOldPolicyIsNotReady() { verify(mockHelper).updateBalancingState(CONNECTING, picker); verify(lb1, never()).shutdown(); - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[2])); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[2], new Object())) + .build())); verify(lb1).shutdown(); - LoadBalancer lb2 = balancers.get(lbPolicies[2]); - - verifyNoMoreInteractions(lb0, lb1, lb2); } @Test public void switchWhileOldPolicyGoesFromReadyToNotReady() { - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[0])); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[0], new Object())) + .build())); LoadBalancer lb0 = balancers.get(lbPolicies[0]); Helper helper0 = helpers.get(lb0); SubchannelPicker picker = mock(SubchannelPicker.class); helper0.updateBalancingState(READY, picker); - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[1])); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[1], new Object())) + .build())); verify(lb0, never()).shutdown(); LoadBalancer lb1 = balancers.get(lbPolicies[1]); @@ -354,20 +492,22 @@ public void switchWhileOldPolicyGoesFromReadyToNotReady() { picker1 = mock(SubchannelPicker.class); helper1.updateBalancingState(READY, picker1); verify(mockHelper).updateBalancingState(READY, picker1); - - verifyNoMoreInteractions(lb0, lb1); } @Test public void switchWhileOldPolicyGoesFromReadyToNotReadyWhileNewPolicyStillIdle() { - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[0])); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[0], new Object())) + .build())); LoadBalancer lb0 = balancers.get(lbPolicies[0]); InOrder inOrder = inOrder(lb0, mockHelper); Helper helper0 = helpers.get(lb0); SubchannelPicker picker = mock(SubchannelPicker.class); helper0.updateBalancingState(READY, picker); - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[1])); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[1], new Object())) + .build())); verify(lb0, never()).shutdown(); LoadBalancer lb1 = balancers.get(lbPolicies[1]); @@ -377,54 +517,65 @@ public void switchWhileOldPolicyGoesFromReadyToNotReadyWhileNewPolicyStillIdle() helper0.updateBalancingState(CONNECTING, picker); verify(mockHelper, never()).updateBalancingState(CONNECTING, picker); - inOrder.verify(mockHelper).updateBalancingState(CONNECTING, BUFFER_PICKER); + ArgumentCaptor pickerCaptor = ArgumentCaptor.forClass(SubchannelPicker.class); + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + assertThat(pickerCaptor.getValue().pickSubchannel(mock(PickSubchannelArgs.class)).hasResult()) + .isFalse(); + inOrder.verify(lb0).shutdown(); // shutdown after update picker = mock(SubchannelPicker.class); helper1.updateBalancingState(CONNECTING, picker); inOrder.verify(mockHelper).updateBalancingState(CONNECTING, picker); - - inOrder.verifyNoMoreInteractions(); - verifyNoMoreInteractions(lb1); } @Test public void newPolicyNameTheSameAsPendingPolicy_shouldHaveNoEffect() { - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[0])); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[0], new Object())) + .build())); LoadBalancer lb0 = balancers.get(lbPolicies[0]); Helper helper0 = helpers.get(lb0); SubchannelPicker picker = mock(SubchannelPicker.class); helper0.updateBalancingState(READY, picker); - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[1])); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[1], new Object())) + .build())); LoadBalancer lb1 = balancers.get(lbPolicies[1]); - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[1])); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[1], new Object())) + .build())); assertThat(balancers.get(lbPolicies[1])).isSameInstanceAs(lb1); - - verifyNoMoreInteractions(lb0, lb1); } @Test public void newPolicyNameTheSameAsCurrentPolicy_shouldShutdownPendingLb() { - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[0])); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[0], new Object())) + .build())); LoadBalancer lb0 = balancers.get(lbPolicies[0]); - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[0])); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[0], new Object())) + .build())); assertThat(balancers.get(lbPolicies[0])).isSameInstanceAs(lb0); Helper helper0 = helpers.get(lb0); SubchannelPicker picker = mock(SubchannelPicker.class); helper0.updateBalancingState(READY, picker); - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[1])); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[1], new Object())) + .build())); LoadBalancer lb1 = balancers.get(lbPolicies[1]); - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[0])); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(lbPolicies[0], new Object())) + .build())); verify(lb1).shutdown(); assertThat(balancers.get(lbPolicies[0])).isSameInstanceAs(lb0); - - verifyNoMoreInteractions(lb0, lb1); } @@ -442,6 +593,7 @@ final class LoadBalancerFactoryWithId extends LoadBalancer.Factory { @Override public LoadBalancer newLoadBalancer(Helper helper) { LoadBalancer balancer = mock(LoadBalancer.class); + when(balancer.acceptResolvedAddresses(any())).thenReturn(Status.OK); balancers.add(balancer); return balancer; } @@ -461,39 +613,67 @@ public int hashCode() { } } - gracefulSwitchLb.switchTo(new LoadBalancerFactoryWithId(0)); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(new LoadBalancerFactoryWithId(0), new Object())) + .build())); assertThat(balancers).hasSize(1); LoadBalancer lb0 = balancers.get(0); - gracefulSwitchLb.switchTo(new LoadBalancerFactoryWithId(0)); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(new LoadBalancerFactoryWithId(0), new Object())) + .build())); assertThat(balancers).hasSize(1); - gracefulSwitchLb.switchTo(new LoadBalancerFactoryWithId(1)); + assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() + .setLoadBalancingPolicyConfig(createConfig(new LoadBalancerFactoryWithId(1), new Object())) + .build())); assertThat(balancers).hasSize(2); - LoadBalancer lb1 = balancers.get(1); verify(lb0).shutdown(); + } - verifyNoMoreInteractions(lb0, lb1); + @Test + public void configEquals() { + Object config = new Object(); + new EqualsTester() + .addEqualityGroup(createConfig(lbPolicies[0], config), createConfig(lbPolicies[0], config)) + .addEqualityGroup(createConfig(lbPolicies[1], config)) + .addEqualityGroup(createConfig(lbPolicies[0], new Object())) + .testEquals(); } @Test - public void transientFailureOnInitialResolutionError() { - gracefulSwitchLb.handleNameResolutionError(Status.DATA_LOSS); - ArgumentCaptor pickerCaptor = ArgumentCaptor.forClass(SubchannelPicker.class); - verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); - SubchannelPicker picker = pickerCaptor.getValue(); - assertThat(picker.pickSubchannel(mock(PickSubchannelArgs.class)).getStatus().getCode()) - .isEqualTo(Status.Code.DATA_LOSS); + public void parseLoadBalancingPolicyConfig_null_fails() { + ConfigOrError result = GracefulSwitchLoadBalancer.parseLoadBalancingPolicyConfig(null); + assertThat(result.getError()).isNotNull(); } - @Deprecated @Test - public void handleSubchannelState_shouldThrow() { - gracefulSwitchLb.switchTo(lbProviders.get(lbPolicies[0])); - Subchannel subchannel = mock(Subchannel.class); - ConnectivityStateInfo connectivityStateInfo = ConnectivityStateInfo.forNonError(READY); - thrown.expect(UnsupportedOperationException.class); - gracefulSwitchLb.handleSubchannelState(subchannel, connectivityStateInfo); + public void parseLoadBalancingPolicyConfig_empty_fails() { + ConfigOrError result = GracefulSwitchLoadBalancer.parseLoadBalancingPolicyConfig( + Arrays.asList()); + assertThat(result.getError()).isNotNull(); + } + + @Test + public void parseLoadBalancingPolicyConfig_missing_fails() { + LoadBalancerRegistry lbRegistry = new LoadBalancerRegistry(); + ConfigOrError result = GracefulSwitchLoadBalancer.parseLoadBalancingPolicyConfig( + Arrays.asList(Collections.singletonMap("lb_policy_0", Collections.emptyMap())), lbRegistry); + assertThat(result.getError()).isNotNull(); + } + + @Test + public void parseLoadBalancingPolicyConfig_succeeds() { + LoadBalancerRegistry lbRegistry = new LoadBalancerRegistry(); + lbRegistry.register(lbPolicies[0]); + ConfigOrError result = GracefulSwitchLoadBalancer.parseLoadBalancingPolicyConfig( + Arrays.asList(Collections.singletonMap("lb_policy_0", Collections.emptyMap())), lbRegistry); + assertThat(result.getError()).isNull(); + assertThat(result.getConfig()).isInstanceOf(GracefulSwitchLoadBalancer.Config.class); + GracefulSwitchLoadBalancer.Config config = + (GracefulSwitchLoadBalancer.Config) result.getConfig(); + assertThat(config.childFactory).isEqualTo(lbPolicies[0]); + assertThat(config.childConfig).isEqualTo(FAKE_CONFIG); } private final class FakeLoadBalancerProvider extends LoadBalancerProvider { @@ -522,10 +702,31 @@ public String getPolicyName() { @Override public LoadBalancer newLoadBalancer(Helper helper) { LoadBalancer balancer = mock(LoadBalancer.class); - balancers.put(policyName, balancer); + when(balancer.acceptResolvedAddresses(any())).thenReturn(Status.OK); + balancers.put(this, balancer); helpers.put(balancer, helper); return balancer; } + + @Override + public ConfigOrError parseLoadBalancingPolicyConfig(Map rawConfig) { + return ConfigOrError.fromConfig(FAKE_CONFIG); + } + } + + private static void assertIsOk(Status status) { + assertThat(status.isOk()).isTrue(); + } + + private ResolvedAddresses.Builder addressesBuilder() { + return ResolvedAddresses.newBuilder() + .setAddresses( + Collections.singletonList(new EquivalentAddressGroup(mock(SocketAddress.class)))); + } + + private static Object createConfig( + LoadBalancer.Factory childFactory, Object childConfig) { + return GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig(childFactory, childConfig); } private static ResolvedAddresses newFakeAddresses() { diff --git a/util/src/test/java/io/grpc/util/MultiChildLoadBalancerTest.java b/util/src/test/java/io/grpc/util/MultiChildLoadBalancerTest.java index d90c5eab92c..14dc8518756 100644 --- a/util/src/test/java/io/grpc/util/MultiChildLoadBalancerTest.java +++ b/util/src/test/java/io/grpc/util/MultiChildLoadBalancerTest.java @@ -21,7 +21,6 @@ import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.SHUTDOWN; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.AdditionalAnswers.delegatesTo; @@ -32,15 +31,16 @@ import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; import com.google.common.collect.Lists; +import com.google.common.testing.EqualsTester; import io.grpc.Attributes; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer; import io.grpc.Status; +import io.grpc.internal.PickFirstLoadBalancerProvider; import io.grpc.util.AbstractTestHelper.FakeSocketAddress; import io.grpc.util.MultiChildLoadBalancer.ChildLbState; import io.grpc.util.MultiChildLoadBalancer.Endpoint; @@ -52,7 +52,6 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import java.util.stream.Collectors; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -81,8 +80,8 @@ public class MultiChildLoadBalancerTest { private ArgumentCaptor stateCaptor; @Captor private ArgumentCaptor createArgsCaptor; - private TestHelper testHelperInst = new TestHelper(); - private LoadBalancer.Helper mockHelper = + private final TestHelper testHelperInst = new TestHelper(); + private final LoadBalancer.Helper mockHelper = mock(LoadBalancer.Helper.class, delegatesTo(testHelperInst)); private TestLb loadBalancer; @@ -99,7 +98,7 @@ public void setUp() { } @Test - public void pickAfterResolved() throws Exception { + public void pickAfterResolved() { Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( LoadBalancer.ResolvedAddresses.newBuilder().setAddresses(servers).build()); assertThat(addressesAcceptanceStatus.isOk()).isTrue(); @@ -127,11 +126,11 @@ public void pickAfterResolved() throws Exception { (TestLb.TestSubchannelPicker) pickerCaptor.getValue(); assertThat(subchannelPicker.getReadySubchannels()).containsExactly(readySubchannel); - verifyNoMoreInteractions(mockHelper); + AbstractTestHelper.verifyNoMoreMeaningfulInteractions(mockHelper); } @Test - public void pickAfterResolvedUpdatedHosts() throws Exception { + public void pickAfterResolvedUpdatedHosts() { Attributes.Key key = Attributes.Key.create("check-that-it-is-propagated"); FakeSocketAddress removedAddr = new FakeSocketAddress("removed"); EquivalentAddressGroup removedEag = new EquivalentAddressGroup(removedAddr); @@ -153,8 +152,7 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { LoadBalancer.Subchannel removedSubchannel = getSubchannel(removedEag); LoadBalancer.Subchannel oldSubchannel = getSubchannel(oldEag1); LoadBalancer.SubchannelStateListener removedListener = - testHelperInst.getSubchannelStateListeners() - .get(testHelperInst.getRealForMockSubChannel(removedSubchannel)); + testHelperInst.getSubchannelStateListener(removedSubchannel); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); @@ -168,8 +166,6 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { verify(removedSubchannel, times(1)).requestConnection(); verify(oldSubchannel, times(1)).requestConnection(); - assertThat(getChildEags(loadBalancer)).containsExactly(removedEag, oldEag1); - // This time with Attributes List latestServers = Lists.newArrayList(oldEag2, newEag); @@ -186,16 +182,16 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { removedListener.onSubchannelState(ConnectivityStateInfo.forNonError(SHUTDOWN)); deliverSubchannelState(newSubchannel, ConnectivityStateInfo.forNonError(READY)); - assertThat(getChildEags(loadBalancer)).containsExactly(oldEag2, newEag); - verify(mockHelper, times(3)).createSubchannel(any(LoadBalancer.CreateSubchannelArgs.class)); inOrder.verify(mockHelper, times(2)).updateBalancingState(eq(READY), pickerCaptor.capture()); + picker = pickerCaptor.getValue(); + assertThat(getList(picker)).containsExactly(oldSubchannel, newSubchannel); - verifyNoMoreInteractions(mockHelper); + AbstractTestHelper.verifyNoMoreMeaningfulInteractions(mockHelper); } @Test - public void pickFromMultiAddressEags() throws Exception { + public void pickFromMultiAddressEags() { List addressList1 = new ArrayList<>(); List addressList2 = new ArrayList<>(); for (int i = 0; i < 3; i++) { @@ -215,7 +211,7 @@ public void pickFromMultiAddressEags() throws Exception { LoadBalancer.ResolvedAddresses.newBuilder().setAddresses(multiGroups).build()); assertTrue(addressesAcceptanceStatus.isOk()); - LoadBalancer.Subchannel evens = subchannels.get(Collections.singletonList(eag1)); + LoadBalancer.Subchannel evens = getSubchannel(eag1); deliverSubchannelState(evens, ConnectivityStateInfo.forNonError(READY)); verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture()); assertThat(pickerCaptor.getValue()).isInstanceOf(TestLb.TestSubchannelPicker.class); @@ -244,37 +240,64 @@ public void testEndpoint_toString() { @Test public void testEndpoint_equals() { - assertEquals( - createEndpoint(Attributes.EMPTY, "addr1"), - createEndpoint(Attributes.EMPTY, "addr1")); - - assertEquals( - createEndpoint(Attributes.EMPTY, "addr1", "addr2"), - createEndpoint(Attributes.EMPTY, "addr2", "addr1")); - - assertEquals( - createEndpoint(Attributes.EMPTY, "addr1", "addr2"), - createEndpoint(affinity, "addr2", "addr1")); + new EqualsTester() + .addEqualityGroup( + createEndpoint(Attributes.EMPTY, "addr1"), + createEndpoint(Attributes.EMPTY, "addr1")) + .addEqualityGroup( + createEndpoint(Attributes.EMPTY, "addr1", "addr2"), + createEndpoint(Attributes.EMPTY, "addr2", "addr1"), + createEndpoint(affinity, "addr1", "addr2")) + .addEqualityGroup( + createEndpoint(Attributes.EMPTY, "addr1", "addr3")) + .addEqualityGroup( + createEndpoint(Attributes.EMPTY, "addr1", "addr2", "addr3", "addr4", "addr5", "addr6", + "addr7", "addr8", "addr9", "addr10"), + createEndpoint(Attributes.EMPTY, "addr2", "addr1", "addr3", "addr4", "addr5", "addr6", + "addr7", "addr8", "addr9", "addr10")) + .addEqualityGroup( + createEndpoint(Attributes.EMPTY, "addr1", "addr2", "addr3", "addr4", "addr5", "addr6", + "addr7", "addr8", "addr9", "addr11")) + .addEqualityGroup( + createEndpoint(Attributes.EMPTY, "addr1", "addr2", "addr3", "addr4", "addr5", "addr6", + "addr7", "addr8", "addr9", "addr10", "addr11")) + .testEquals(); + } - assertEquals( - createEndpoint(Attributes.EMPTY, "addr1", "addr2").hashCode(), - createEndpoint(affinity, "addr2", "addr1").hashCode()); + @Test + public void offsetIterable_positive() { + assertThat(MultiChildLoadBalancer.offsetIterable(Arrays.asList(1, 2, 3, 4), 9)) + .containsExactly(2, 3, 4, 1) + .inOrder(); + assertThat(MultiChildLoadBalancer.offsetIterable(Arrays.asList(1, 2, 3, 4, 5), 9)) + .containsExactly(5, 1, 2, 3, 4) + .inOrder(); + assertThat(MultiChildLoadBalancer.offsetIterable(Arrays.asList(1, 2, 3), 3)) + .containsExactly(1, 2, 3) + .inOrder(); + assertThat(MultiChildLoadBalancer.offsetIterable(Arrays.asList(1, 2, 3), 0)) + .containsExactly(1, 2, 3) + .inOrder(); + assertThat(MultiChildLoadBalancer.offsetIterable(Arrays.asList(1), 123)) + .containsExactly(1) + .inOrder(); + } + @Test + public void offsetIterable_negative() { + assertThat(MultiChildLoadBalancer.offsetIterable(Arrays.asList(1, 2, 3, 4), -1)) + .containsExactly(4, 1, 2, 3) + .inOrder(); } @Test - public void testEndpoint_notEquals() { - assertNotEquals( - createEndpoint(Attributes.EMPTY, "addr1", "addr2"), - createEndpoint(Attributes.EMPTY, "addr1", "addr3")); - - assertNotEquals( - createEndpoint(Attributes.EMPTY, "addr1"), - createEndpoint(Attributes.EMPTY, "addr1", "addr2")); - - assertNotEquals( - createEndpoint(Attributes.EMPTY, "addr1", "addr2"), - createEndpoint(Attributes.EMPTY, "addr1")); + public void offsetIterable_empty() { + assertThat(MultiChildLoadBalancer.offsetIterable(Arrays.asList(), 1)) + .isEmpty(); + assertThat(MultiChildLoadBalancer.offsetIterable(Arrays.asList(), 0)) + .isEmpty(); + assertThat(MultiChildLoadBalancer.offsetIterable(Arrays.asList(), -1)) + .isEmpty(); } private String addressesOnlyString(EquivalentAddressGroup eag) { @@ -321,14 +344,20 @@ private Endpoint createEndpoint(Attributes attr, String... names) { return new Endpoint(eag); } - private LoadBalancer.Subchannel getSubchannel(EquivalentAddressGroup removedEag) { - return subchannels.get(Collections.singletonList(removedEag)); - } + private LoadBalancer.Subchannel getSubchannel(EquivalentAddressGroup eag) { + if (PickFirstLoadBalancerProvider.isEnabledNewPickFirst()) { + for (SocketAddress addr : eag.getAddresses()) { + LoadBalancer.Subchannel subchannel = subchannels.get( + Arrays.asList(new EquivalentAddressGroup(addr, eag.getAttributes()))); + if (subchannel != null) { + return subchannel; + } + } + } else { + return subchannels.get(Collections.singletonList(eag)); + } - private static List getChildEags(MultiChildLoadBalancer loadBalancer) { - return loadBalancer.getChildLbStates().stream() - .map(ChildLbState::getEag) - .collect(Collectors.toList()); + return null; } private void deliverSubchannelState(LoadBalancer.Subchannel subchannel, @@ -345,16 +374,16 @@ protected TestLb(Helper mockHelper) { protected void updateOverallBalancingState() { ConnectivityState overallState = null; final Map childPickers = new HashMap<>(); + final Map childConnStates = new HashMap<>(); for (ChildLbState childLbState : getChildLbStates()) { - if (childLbState.isDeactivated()) { - continue; - } childPickers.put(childLbState.getKey(), childLbState.getCurrentPicker()); + childConnStates.put(childLbState.getKey(), childLbState.getCurrentState()); overallState = aggregateState(overallState, childLbState.getCurrentState()); } if (overallState != null) { - getHelper().updateBalancingState(overallState, new TestSubchannelPicker(childPickers)); + getHelper().updateBalancingState( + overallState, new TestSubchannelPicker(childPickers, childConnStates)); currentConnectivityState = overallState; } @@ -364,18 +393,17 @@ private class TestSubchannelPicker extends SubchannelPicker { Map childPickerMap; Map childStates = new HashMap<>(); - TestSubchannelPicker(Map childPickers) { - childPickerMap = childPickers; - for (Object key : childPickerMap.keySet()) { - childStates.put(key, getChildLbState(key).getCurrentState()); - } + TestSubchannelPicker( + Map childPickers, Map childStates) { + this.childPickerMap = childPickers; + this.childStates = childStates; } List getReadySubchannels() { List readySubchannels = new ArrayList<>(); for ( Map.Entry cur : childStates.entrySet()) { if (cur.getValue() == READY) { - Subchannel s = subchannels.get(Arrays.asList(getChildLbState(cur.getKey()).getEag())); + Subchannel s = childPickerMap.get(cur.getKey()).pickSubchannel(null).getSubchannel(); readySubchannels.add(s); } } diff --git a/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerProviderTest.java b/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerProviderTest.java index cf162ffaec0..d87aa85eb4f 100644 --- a/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerProviderTest.java +++ b/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerProviderTest.java @@ -89,7 +89,9 @@ public void parseLoadBalancingConfig_defaults() throws IOException { = (OutlierDetectionLoadBalancerConfig) configOrError.getConfig(); assertThat(config.successRateEjection).isNotNull(); assertThat(config.failurePercentageEjection).isNotNull(); - assertThat(config.childPolicy.getProvider().getPolicyName()).isEqualTo("round_robin"); + assertThat( + GracefulSwitchLoadBalancerAccessor.getChildProvider(config.childConfig).getPolicyName()) + .isEqualTo("round_robin"); } @Test @@ -135,7 +137,9 @@ public void parseLoadBalancingConfig_valuesSet() throws IOException { assertThat(config.failurePercentageEjection.minimumHosts).isEqualTo(100); assertThat(config.failurePercentageEjection.requestVolume).isEqualTo(100); - assertThat(config.childPolicy.getProvider().getPolicyName()).isEqualTo("round_robin"); + assertThat( + GracefulSwitchLoadBalancerAccessor.getChildProvider(config.childConfig).getPolicyName()) + .isEqualTo("round_robin"); } @SuppressWarnings("unchecked") diff --git a/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java b/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java index 87bd50a58bc..39f5b5fb7d6 100644 --- a/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java +++ b/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java @@ -19,8 +19,8 @@ import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertWithMessage; import static io.grpc.ConnectivityState.READY; +import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -50,10 +50,10 @@ import io.grpc.LoadBalancerProvider; import io.grpc.Metadata; import io.grpc.Status; +import io.grpc.Status.Code; import io.grpc.SynchronizationContext; import io.grpc.internal.FakeClock; import io.grpc.internal.FakeClock.ScheduledTask; -import io.grpc.internal.ServiceConfigUtil.PolicySelection; import io.grpc.internal.TestUtils.StandardLoadBalancerProvider; import io.grpc.util.OutlierDetectionLoadBalancer.EndpointTracker; import io.grpc.util.OutlierDetectionLoadBalancer.OutlierDetectionLoadBalancerConfig; @@ -226,7 +226,7 @@ public Void answer(InvocationOnMock invocation) throws Throwable { when(mockStreamTracerFactory.newClientStreamTracer(any(), any())).thenReturn(mockStreamTracer); - loadBalancer = new OutlierDetectionLoadBalancer(mockHelper, fakeClock.getTimeProvider()); + loadBalancer = new OutlierDetectionLoadBalancer(mockHelper, fakeClock.getTicker()); } @Test @@ -243,7 +243,7 @@ public void handleNameResolutionError_withChildLb() { loadBalancer.acceptResolvedAddresses(buildResolvedAddress( new OutlierDetectionLoadBalancerConfig.Builder() .setSuccessRateEjection(new SuccessRateEjection.Builder().build()) - .setChildPolicy(new PolicySelection(mockChildLbProvider, null)).build(), + .setChildConfig(newChildConfig(mockChildLbProvider, null)).build(), new EquivalentAddressGroup(mockSocketAddress))); loadBalancer.handleNameResolutionError(Status.DEADLINE_EXCEEDED); @@ -258,7 +258,7 @@ public void shutdown() { loadBalancer.acceptResolvedAddresses(buildResolvedAddress( new OutlierDetectionLoadBalancerConfig.Builder() .setSuccessRateEjection(new SuccessRateEjection.Builder().build()) - .setChildPolicy(new PolicySelection(mockChildLbProvider, null)).build(), + .setChildConfig(newChildConfig(mockChildLbProvider, null)).build(), new EquivalentAddressGroup(mockSocketAddress))); loadBalancer.shutdown(); verify(mockChildLb).shutdown(); @@ -269,18 +269,18 @@ public void shutdown() { */ @Test public void acceptResolvedAddresses() { + Object childConfig = "theConfig"; OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() .setSuccessRateEjection(new SuccessRateEjection.Builder().build()) - .setChildPolicy(new PolicySelection(mockChildLbProvider, null)).build(); + .setChildConfig(newChildConfig(mockChildLbProvider, childConfig)).build(); ResolvedAddresses resolvedAddresses = buildResolvedAddress(config, new EquivalentAddressGroup(mockSocketAddress)); loadBalancer.acceptResolvedAddresses(resolvedAddresses); // Handling of resolved addresses is delegated - verify(mockChildLb).handleResolvedAddresses( - resolvedAddresses.toBuilder().setLoadBalancingPolicyConfig(config.childPolicy.getConfig()) - .build()); + verify(mockChildLb).acceptResolvedAddresses( + resolvedAddresses.toBuilder().setLoadBalancingPolicyConfig(childConfig).build()); // There is a single pending task to run the outlier detection algorithm assertThat(fakeClock.getPendingTasks()).hasSize(1); @@ -299,7 +299,7 @@ public void acceptResolvedAddresses() { public void childLbRecreatesSubchannels() { OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() .setSuccessRateEjection(new SuccessRateEjection.Builder().build()) - .setChildPolicy(new PolicySelection(fakeLbProvider, null)).build(); + .setChildConfig(newChildConfig(fakeLbProvider, null)).build(); loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers.get(0))); @@ -321,7 +321,7 @@ public void childLbRecreatesSubchannels() { public void acceptResolvedAddresses_outlierDetectionDisabled() { OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() .setSuccessRateEjection(new SuccessRateEjection.Builder().build()) - .setChildPolicy(new PolicySelection(mockChildLbProvider, null)).build(); + .setChildConfig(newChildConfig(mockChildLbProvider, null)).build(); ResolvedAddresses resolvedAddresses = buildResolvedAddress(config, new EquivalentAddressGroup(mockSocketAddress)); @@ -332,8 +332,8 @@ public void acceptResolvedAddresses_outlierDetectionDisabled() { // There is a single pending task to run the outlier detection algorithm assertThat(fakeClock.getPendingTasks()).hasSize(1); - config = new OutlierDetectionLoadBalancerConfig.Builder().setChildPolicy( - new PolicySelection(mockChildLbProvider, null)).build(); + config = new OutlierDetectionLoadBalancerConfig.Builder().setChildConfig( + newChildConfig(mockChildLbProvider, null)).build(); loadBalancer.acceptResolvedAddresses( buildResolvedAddress(config, new EquivalentAddressGroup(mockSocketAddress))); @@ -349,7 +349,7 @@ public void acceptResolvedAddresses_outlierDetectionDisabled() { public void acceptResolvedAddresses_intervalUpdate() { OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() .setSuccessRateEjection(new SuccessRateEjection.Builder().build()) - .setChildPolicy(new PolicySelection(mockChildLbProvider, null)).build(); + .setChildConfig(newChildConfig(mockChildLbProvider, null)).build(); ResolvedAddresses resolvedAddresses = buildResolvedAddress(config, new EquivalentAddressGroup(mockSocketAddress)); @@ -359,7 +359,7 @@ public void acceptResolvedAddresses_intervalUpdate() { config = new OutlierDetectionLoadBalancerConfig.Builder() .setIntervalNanos(config.intervalNanos * 2) .setSuccessRateEjection(new SuccessRateEjection.Builder().build()) - .setChildPolicy(new PolicySelection(mockChildLbProvider, null)).build(); + .setChildConfig(newChildConfig(mockChildLbProvider, null)).build(); loadBalancer.acceptResolvedAddresses( buildResolvedAddress(config, new EquivalentAddressGroup(mockSocketAddress))); @@ -394,7 +394,7 @@ public void acceptResolvedAddresses_intervalUpdate() { public void delegatePick() throws Exception { OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() .setSuccessRateEjection(new SuccessRateEjection.Builder().build()) - .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + .setChildConfig(newChildConfig(roundRobinLbProvider, null)).build(); loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers.get(0))); @@ -408,7 +408,10 @@ public void delegatePick() throws Exception { // Make sure that we can pick the single READY subchannel. SubchannelPicker picker = pickerCaptor.getAllValues().get(2); PickResult pickResult = picker.pickSubchannel(mock(PickSubchannelArgs.class)); - Subchannel s = ((OutlierDetectionSubchannel) pickResult.getSubchannel()).delegate(); + Subchannel s = pickResult.getSubchannel(); + if (s instanceof HealthProducerHelper.HealthProducerSubchannel) { + s = ((HealthProducerHelper.HealthProducerSubchannel) s).delegate(); + } assertThat(s).isEqualTo(readySubchannel); } @@ -419,7 +422,7 @@ public void delegatePick() throws Exception { public void delegatePickTracerFactoryPreserved() { OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() .setSuccessRateEjection(new SuccessRateEjection.Builder().build()) - .setChildPolicy(new PolicySelection(fakeLbProvider, null)).build(); + .setChildConfig(newChildConfig(fakeLbProvider, null)).build(); loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers.get(0))); @@ -458,7 +461,7 @@ public void delegatePickTracerFactoryNotSet() throws Exception { OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() .setSuccessRateEjection(new SuccessRateEjection.Builder().build()) - .setChildPolicy(new PolicySelection(fakeLbProvider, null)).build(); + .setChildConfig(newChildConfig(fakeLbProvider, null)).build(); loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers.get(0))); @@ -491,7 +494,7 @@ public void successRateNoOutliers() { .setMaxEjectionPercent(50) .setSuccessRateEjection( new SuccessRateEjection.Builder().setMinimumHosts(3).setRequestVolume(10).build()) - .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + .setChildConfig(newChildConfig(roundRobinLbProvider, null)).build(); loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); @@ -515,7 +518,7 @@ public void successRateOneOutlier() { new SuccessRateEjection.Builder() .setMinimumHosts(3) .setRequestVolume(10).build()) - .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + .setChildConfig(newChildConfig(roundRobinLbProvider, null)).build(); loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); @@ -540,7 +543,7 @@ public void successRateOneOutlier_configChange() { new SuccessRateEjection.Builder() .setMinimumHosts(3) .setRequestVolume(10).build()) - .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + .setChildConfig(newChildConfig(roundRobinLbProvider, null)).build(); loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); @@ -560,11 +563,11 @@ public void successRateOneOutlier_configChange() { .setMinimumHosts(3) .setRequestVolume(10) .setEnforcementPercentage(0).build()) - .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + .setChildConfig(newChildConfig(roundRobinLbProvider, null)).build(); loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); - generateLoad(ImmutableMap.of(subchannel2, Status.DEADLINE_EXCEEDED), 12); + generateLoad(ImmutableMap.of(subchannel2, Status.DEADLINE_EXCEEDED), 8); // Move forward in time to a point where the detection timer has fired. forwardTime(config); @@ -585,7 +588,7 @@ public void successRateOneOutlier_unejected() { new SuccessRateEjection.Builder() .setMinimumHosts(3) .setRequestVolume(10).build()) - .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + .setChildConfig(newChildConfig(roundRobinLbProvider, null)).build(); loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); @@ -597,8 +600,8 @@ public void successRateOneOutlier_unejected() { // The one subchannel that was returning errors should be ejected. assertEjectedSubchannels(ImmutableSet.of(ImmutableSet.copyOf(servers.get(0).getAddresses()))); - // Now we produce more load, but the subchannel start working and is no longer an outlier. - generateLoad(ImmutableMap.of(), 12); + // Now we produce more load, but the subchannel has started working and is no longer an outlier. + generateLoad(ImmutableMap.of(), 8); // Move forward in time to a point where the detection timer has fired. fakeClock.forwardTime(config.maxEjectionTimeNanos + 1, TimeUnit.NANOSECONDS); @@ -618,7 +621,7 @@ public void successRateOneOutlier_notEnoughVolume() { new SuccessRateEjection.Builder() .setMinimumHosts(3) .setRequestVolume(20).build()) - .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + .setChildConfig(newChildConfig(roundRobinLbProvider, null)).build(); loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); @@ -646,7 +649,7 @@ public void successRateOneOutlier_notEnoughAddressesWithVolume() { new SuccessRateEjection.Builder() .setMinimumHosts(5) .setRequestVolume(20).build()) - .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + .setChildConfig(newChildConfig(roundRobinLbProvider, null)).build(); loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); @@ -676,7 +679,7 @@ public void successRateOneOutlier_enforcementPercentage() { .setRequestVolume(10) .setEnforcementPercentage(0) .build()) - .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + .setChildConfig(newChildConfig(roundRobinLbProvider, null)).build(); loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); @@ -701,7 +704,7 @@ public void successRateTwoOutliers() { .setMinimumHosts(3) .setRequestVolume(10) .setStdevFactor(1).build()) - .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + .setChildConfig(newChildConfig(roundRobinLbProvider, null)).build(); loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); @@ -730,7 +733,7 @@ public void successRateThreeOutliers_maxEjectionPercentage() { .setMinimumHosts(3) .setRequestVolume(10) .setStdevFactor(1).build()) - .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + .setChildConfig(newChildConfig(roundRobinLbProvider, null)).build(); loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); @@ -764,7 +767,7 @@ public void failurePercentageNoOutliers() { new FailurePercentageEjection.Builder() .setMinimumHosts(3) .setRequestVolume(10).build()) - .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + .setChildConfig(newChildConfig(roundRobinLbProvider, null)).build(); loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); @@ -789,7 +792,7 @@ public void failurePercentageOneOutlier() { new FailurePercentageEjection.Builder() .setMinimumHosts(3) .setRequestVolume(10).build()) - .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + .setChildConfig(newChildConfig(roundRobinLbProvider, null)).build(); loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); @@ -813,7 +816,7 @@ public void failurePercentageOneOutlier_notEnoughVolume() { new FailurePercentageEjection.Builder() .setMinimumHosts(3) .setRequestVolume(100).build()) // We won't produce this much volume... - .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + .setChildConfig(newChildConfig(roundRobinLbProvider, null)).build(); loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); @@ -838,7 +841,7 @@ public void failurePercentageOneOutlier_notEnoughAddressesWithVolume() { new FailurePercentageEjection.Builder() .setMinimumHosts(5) .setRequestVolume(20).build()) - .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + .setChildConfig(newChildConfig(roundRobinLbProvider, null)).build(); loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); @@ -868,7 +871,7 @@ public void failurePercentageOneOutlier_enforcementPercentage() { .setRequestVolume(10) .setEnforcementPercentage(0) .build()) - .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + .setChildConfig(newChildConfig(roundRobinLbProvider, null)).build(); loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); @@ -897,7 +900,7 @@ public void successRateAndFailurePercentageThreeOutliers() { .setMinimumHosts(3) .setRequestVolume(1) .build()) - .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + .setChildConfig(newChildConfig(roundRobinLbProvider, null)).build(); loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); @@ -934,7 +937,7 @@ public void subchannelUpdateAddress_singleReplaced() { new FailurePercentageEjection.Builder() .setMinimumHosts(3) .setRequestVolume(10).build()) - .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + .setChildConfig(newChildConfig(roundRobinLbProvider, null)).build(); loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); @@ -978,7 +981,7 @@ public void multipleAddressesEndpoint() { new FailurePercentageEjection.Builder() .setMinimumHosts(3) .setRequestVolume(10).build()) - .setChildPolicy(new PolicySelection(fakeLbProvider, null)).build(); + .setChildConfig(newChildConfig(fakeLbProvider, null)).build(); loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); EquivalentAddressGroup manyAddEndpoint = new EquivalentAddressGroup(Arrays.asList( @@ -1021,7 +1024,7 @@ public void subchannelUpdateAddress_singleReplacedWithMultiple() { new FailurePercentageEjection.Builder() .setMinimumHosts(3) .setRequestVolume(10).build()) - .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + .setChildConfig(newChildConfig(roundRobinLbProvider, null)).build(); loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); @@ -1065,7 +1068,7 @@ public void subchannelUpdateAddress_multipleReplacedWithSingle() { new FailurePercentageEjection.Builder() .setMinimumHosts(3) .setRequestVolume(10).build()) - .setChildPolicy(new PolicySelection(fakeLbProvider, null)).build(); + .setChildConfig(newChildConfig(fakeLbProvider, null)).build(); loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); @@ -1128,7 +1131,7 @@ public void successRateAndFailurePercentage_noOutliers() { new FailurePercentageEjection.Builder() .setMinimumHosts(3) .setRequestVolume(10).build()) - .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + .setChildConfig(newChildConfig(roundRobinLbProvider, null)).build(); loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); @@ -1155,7 +1158,7 @@ public void successRateAndFailurePercentage_successRateOutlier() { .setMinimumHosts(3) .setRequestVolume(10) .setEnforcementPercentage(0).build()) // Configured, but not enforcing. - .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + .setChildConfig(newChildConfig(roundRobinLbProvider, null)).build(); loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); @@ -1184,7 +1187,7 @@ public void successRateAndFailurePercentage_successRateOutlier_() { // with heal .setMinimumHosts(3) .setRequestVolume(10) .setEnforcementPercentage(0).build()) // Configured, but not enforcing. - .setChildPolicy(new PolicySelection(fakeLbProvider, null)).build(); + .setChildConfig(newChildConfig(fakeLbProvider, null)).build(); loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); @@ -1196,9 +1199,21 @@ public void successRateAndFailurePercentage_successRateOutlier_() { // with heal // The one subchannel that was returning errors should be ejected. assertEjectedSubchannels(ImmutableSet.of(ImmutableSet.copyOf(servers.get(0).getAddresses()))); if (hasHealthConsumer) { - verify(healthListeners.get(servers.get(0))).onSubchannelState(eq( - ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE) - )); + ArgumentCaptor csiCaptor = ArgumentCaptor.forClass( + ConnectivityStateInfo.class); + verify(healthListeners.get(servers.get(0)), times(2)).onSubchannelState(csiCaptor.capture()); + List connectivityStateInfos = csiCaptor.getAllValues(); + + // The subchannel went through two state transitions... + assertThat(connectivityStateInfos).hasSize(2); + // ...it first went to the READY state... + assertThat(connectivityStateInfos.get(0).getState()).isEqualTo(READY); + + // ...and then to TRANSIENT_FAILURE as outlier detection ejected it. + assertThat(connectivityStateInfos.get(1).getState()).isEqualTo(TRANSIENT_FAILURE); + assertThat(connectivityStateInfos.get(1).getStatus().getCode()).isEqualTo(Code.UNAVAILABLE); + assertThat(connectivityStateInfos.get(1).getStatus().getDescription()).isEqualTo( + "The subchannel has been ejected by outlier detection"); } } @@ -1216,7 +1231,7 @@ public void successRateAndFailurePercentage_errorPercentageOutlier() { new FailurePercentageEjection.Builder() .setMinimumHosts(3) .setRequestVolume(10).build()) // Configured, but not enforcing. - .setChildPolicy(new PolicySelection(roundRobinLbProvider, null)).build(); + .setChildConfig(newChildConfig(roundRobinLbProvider, null)).build(); loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); @@ -1245,7 +1260,7 @@ public void successRateAndFailurePercentage_errorPercentageOutlier_() { // with new FailurePercentageEjection.Builder() .setMinimumHosts(3) .setRequestVolume(10).build()) // Configured, but not enforcing. - .setChildPolicy(new PolicySelection(fakeLbProvider, null)).build(); + .setChildConfig(newChildConfig(fakeLbProvider, null)).build(); loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); @@ -1257,9 +1272,21 @@ public void successRateAndFailurePercentage_errorPercentageOutlier_() { // with // The one subchannel that was returning errors should be ejected. assertEjectedSubchannels(ImmutableSet.of(ImmutableSet.copyOf(servers.get(0).getAddresses()))); if (hasHealthConsumer) { - verify(healthListeners.get(servers.get(0))).onSubchannelState(eq( - ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE) - )); + ArgumentCaptor csiCaptor = ArgumentCaptor.forClass( + ConnectivityStateInfo.class); + verify(healthListeners.get(servers.get(0)), times(2)).onSubchannelState(csiCaptor.capture()); + List connectivityStateInfos = csiCaptor.getAllValues(); + + // The subchannel went through two state transitions... + assertThat(connectivityStateInfos).hasSize(2); + // ...it first went to the READY state... + assertThat(connectivityStateInfos.get(0).getState()).isEqualTo(READY); + + // ...and then to TRANSIENT_FAILURE as outlier detection ejected it. + assertThat(connectivityStateInfos.get(1).getState()).isEqualTo(TRANSIENT_FAILURE); + assertThat(connectivityStateInfos.get(1).getStatus().getCode()).isEqualTo(Code.UNAVAILABLE); + assertThat(connectivityStateInfos.get(1).getStatus().getDescription()).isEqualTo( + "The subchannel has been ejected by outlier detection"); } } @@ -1357,6 +1384,10 @@ void assertEjectedSubchannels(Collection> addresses) { } } + private Object newChildConfig(LoadBalancerProvider provider, Object config) { + return GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig(provider, config); + } + /** Round robin like fake load balancer. */ private final class FakeLoadBalancer extends LoadBalancer { private final Helper helper; diff --git a/util/src/test/java/io/grpc/util/RandomSubsettingLoadBalancerProviderTest.java b/util/src/test/java/io/grpc/util/RandomSubsettingLoadBalancerProviderTest.java new file mode 100644 index 00000000000..18a0766d4b2 --- /dev/null +++ b/util/src/test/java/io/grpc/util/RandomSubsettingLoadBalancerProviderTest.java @@ -0,0 +1,135 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.util; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; + +import io.grpc.InternalServiceProviders; +import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancerProvider; +import io.grpc.NameResolver.ConfigOrError; +import io.grpc.Status; +import io.grpc.internal.JsonParser; +import io.grpc.util.RandomSubsettingLoadBalancer.RandomSubsettingLoadBalancerConfig; +import java.io.IOException; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class RandomSubsettingLoadBalancerProviderTest { + private final RandomSubsettingLoadBalancerProvider provider = + new RandomSubsettingLoadBalancerProvider(); + + @Test + public void registered() { + for (LoadBalancerProvider current : + InternalServiceProviders.getCandidatesViaServiceLoader( + LoadBalancerProvider.class, getClass().getClassLoader())) { + if (current instanceof RandomSubsettingLoadBalancerProvider) { + return; + } + } + fail("RandomSubsettingLoadBalancerProvider not registered"); + } + + @Test + public void providesLoadBalancer() { + Helper helper = mock(Helper.class); + assertThat(provider.newLoadBalancer(helper)) + .isInstanceOf(RandomSubsettingLoadBalancer.class); + } + + @Test + public void parseConfigRequiresSubsetSize() throws IOException { + String emptyConfig = "{}"; + + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(emptyConfig)); + assertThat(configOrError.getError()).isNotNull(); + assertThat(configOrError.getError().toString()) + .isEqualTo( + Status.UNAVAILABLE + .withDescription( + "Subset size missing in random_subsetting_experimental, LB policy config={}") + .toString()); + } + + @Test + public void parseConfigReturnsErrorWhenChildPolicyMissing() throws IOException { + String missingChildPolicyConfig = "{\"subsetSize\": 3}"; + + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(missingChildPolicyConfig)); + assertThat(configOrError.getError()).isNotNull(); + + Status error = configOrError.getError(); + assertThat(error.getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(error.getDescription()).isEqualTo( + "Failed to parse child in random_subsetting_experimental" + + ", LB policy config={subsetSize=3.0}"); + assertThat(error.getCause().getMessage()).isEqualTo( + "UNAVAILABLE: No child LB config specified"); + } + + @Test + public void parseConfigReturnsErrorWhenChildPolicyInvalid() throws IOException { + String invalidChildPolicyConfig = + "{" + + "\"subsetSize\": 3, " + + "\"childPolicy\" : [{\"random_policy\" : {}}]" + + "}"; + + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(invalidChildPolicyConfig)); + assertThat(configOrError.getError()).isNotNull(); + + Status error = configOrError.getError(); + assertThat(error.getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(error.getDescription()).isEqualTo( + "Failed to parse child in random_subsetting_experimental, LB policy config=" + + "{subsetSize=3.0, childPolicy=[{random_policy={}}]}"); + assertThat(error.getCause().getMessage()).contains( + "UNAVAILABLE: None of [random_policy] specified by Service Config are available."); + } + + @Test + public void parseValidConfig() throws IOException { + String validConfig = + "{" + + "\"subsetSize\": 3, " + + "\"childPolicy\" : [{\"round_robin\" : {}}]" + + "}"; + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(validConfig)); + assertThat(configOrError.getConfig()).isNotNull(); + + RandomSubsettingLoadBalancerConfig actualConfig = + (RandomSubsettingLoadBalancerConfig) configOrError.getConfig(); + assertThat(GracefulSwitchLoadBalancerAccessor.getChildProvider( + actualConfig.childConfig).getPolicyName()).isEqualTo("round_robin"); + assertThat(actualConfig.subsetSize).isEqualTo(3); + } + + @SuppressWarnings("unchecked") + private static Map parseJsonObject(String json) throws IOException { + return (Map) JsonParser.parse(json); + } +} diff --git a/util/src/test/java/io/grpc/util/RandomSubsettingLoadBalancerTest.java b/util/src/test/java/io/grpc/util/RandomSubsettingLoadBalancerTest.java new file mode 100644 index 00000000000..2c43e8f4c3a --- /dev/null +++ b/util/src/test/java/io/grpc/util/RandomSubsettingLoadBalancerTest.java @@ -0,0 +1,333 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.util; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import io.grpc.ConnectivityState; +import io.grpc.ConnectivityStateInfo; +import io.grpc.EquivalentAddressGroup; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.CreateSubchannelArgs; +import io.grpc.LoadBalancer.ResolvedAddresses; +import io.grpc.LoadBalancer.Subchannel; +import io.grpc.LoadBalancer.SubchannelStateListener; +import io.grpc.LoadBalancerProvider; +import io.grpc.Status; +import io.grpc.internal.TestUtils; +import io.grpc.util.RandomSubsettingLoadBalancer.RandomSubsettingLoadBalancerConfig; +import java.net.SocketAddress; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; +import org.mockito.stubbing.Answer; + +@RunWith(JUnit4.class) +public class RandomSubsettingLoadBalancerTest { + @Rule + public final MockitoRule mockitoRule = MockitoJUnit.rule(); + + @Mock + private LoadBalancer.Helper mockHelper; + @Mock + private LoadBalancer mockChildLb; + @Mock + private SocketAddress mockSocketAddress; + + @Captor + private ArgumentCaptor resolvedAddrCaptor; + + private BackendDetails backendDetails; + + private RandomSubsettingLoadBalancer loadBalancer; + + private final LoadBalancerProvider mockChildLbProvider = + new TestUtils.StandardLoadBalancerProvider("foo_policy") { + @Override + public LoadBalancer newLoadBalancer(LoadBalancer.Helper helper) { + return mockChildLb; + } + }; + + private final LoadBalancerProvider roundRobinLbProvider = + new TestUtils.StandardLoadBalancerProvider("round_robin") { + @Override + public LoadBalancer newLoadBalancer(LoadBalancer.Helper helper) { + return new RoundRobinLoadBalancer(helper); + } + }; + + private Object newChildConfig(LoadBalancerProvider provider, Object config) { + return GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig(provider, config); + } + + private RandomSubsettingLoadBalancerConfig createRandomSubsettingLbConfig( + int subsetSize, LoadBalancerProvider childLbProvider, Object childConfig) { + return new RandomSubsettingLoadBalancer.RandomSubsettingLoadBalancerConfig.Builder() + .setSubsetSize(subsetSize) + .setChildConfig(newChildConfig(childLbProvider, childConfig)) + .build(); + } + + private BackendDetails setupBackends(int backendCount) { + List servers = Lists.newArrayList(); + Map, Subchannel> subchannels = Maps.newLinkedHashMap(); + + for (int i = 0; i < backendCount; i++) { + SocketAddress addr = new FakeSocketAddress("server" + i); + EquivalentAddressGroup addressGroup = new EquivalentAddressGroup(addr); + servers.add(addressGroup); + Subchannel subchannel = mock(Subchannel.class); + subchannels.put(Arrays.asList(addressGroup), subchannel); + } + + return new BackendDetails(servers, subchannels); + } + + @Before + public void setUp() { + int seed = 0; + loadBalancer = new RandomSubsettingLoadBalancer(mockHelper, seed); + + int backendSize = 5; + backendDetails = setupBackends(backendSize); + } + + @Test + public void handleNameResolutionError() { + int subsetSize = 2; + Object childConfig = "someConfig"; + + RandomSubsettingLoadBalancerConfig config = createRandomSubsettingLbConfig( + subsetSize, mockChildLbProvider, childConfig); + + loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of(new EquivalentAddressGroup(mockSocketAddress))) + .setLoadBalancingPolicyConfig(config) + .build()); + + loadBalancer.handleNameResolutionError(Status.DEADLINE_EXCEEDED); + verify(mockChildLb).handleNameResolutionError(Status.DEADLINE_EXCEEDED); + } + + @Test + public void shutdown() { + int subsetSize = 2; + Object childConfig = "someConfig"; + + RandomSubsettingLoadBalancerConfig config = createRandomSubsettingLbConfig( + subsetSize, mockChildLbProvider, childConfig); + + loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of(new EquivalentAddressGroup(mockSocketAddress))) + .setLoadBalancingPolicyConfig(config) + .build()); + + loadBalancer.shutdown(); + verify(mockChildLb).shutdown(); + } + + @Test + public void acceptResolvedAddresses_mockedChildLbPolicy() { + int subsetSize = 3; + Object childConfig = "someConfig"; + + RandomSubsettingLoadBalancerConfig config = createRandomSubsettingLbConfig( + subsetSize, mockChildLbProvider, childConfig); + + ResolvedAddresses resolvedAddresses = + ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.copyOf(backendDetails.servers)) + .setLoadBalancingPolicyConfig(config) + .build(); + + loadBalancer.acceptResolvedAddresses(resolvedAddresses); + + verify(mockChildLb).acceptResolvedAddresses(resolvedAddrCaptor.capture()); + assertThat(resolvedAddrCaptor.getValue().getAddresses().size()).isEqualTo(subsetSize); + assertThat(resolvedAddrCaptor.getValue().getLoadBalancingPolicyConfig()).isEqualTo(childConfig); + } + + @Test + public void acceptResolvedAddresses_roundRobinChildLbPolicy() { + int subsetSize = 3; + Object childConfig = null; + + RandomSubsettingLoadBalancerConfig config = createRandomSubsettingLbConfig( + subsetSize, roundRobinLbProvider, childConfig); + + ResolvedAddresses resolvedAddresses = + ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.copyOf(backendDetails.servers)) + .setLoadBalancingPolicyConfig(config) + .build(); + + loadBalancer.acceptResolvedAddresses(resolvedAddresses); + + int insubset = 0; + for (Subchannel subchannel : backendDetails.subchannels.values()) { + LoadBalancer.SubchannelStateListener ssl = + backendDetails.subchannelStateListeners.get(subchannel); + if (ssl != null) { // it might be null if it's not in the subset. + insubset += 1; + ssl.onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + } + } + + assertThat(insubset).isEqualTo(subsetSize); + } + + // verifies: https://github.com/grpc/proposal/blob/master/A68_graphics/subsetting100-100-5.png + @Test + public void backendsCanBeDistributedEvenly_subsetting100_100_5() { + verifyConnectionsByServer(100, 100, 5, 15); + } + + // verifies https://github.com/grpc/proposal/blob/master/A68_graphics/subsetting100-100-25.png + @Test + public void backendsCanBeDistributedEvenly_subsetting100_100_25() { + verifyConnectionsByServer(100, 100, 25, 40); + } + + // verifies: https://github.com/grpc/proposal/blob/master/A68_graphics/subsetting100-10-5.png + @Test + public void backendsCanBeDistributedEvenly_subsetting100_10_5() { + verifyConnectionsByServer(100, 10, 5, 65); + } + + // verifies: https://github.com/grpc/proposal/blob/master/A68_graphics/subsetting500-10-5.png + @Test + public void backendsCanBeDistributedEvenly_subsetting500_10_5() { + verifyConnectionsByServer(500, 10, 5, 600); + } + + // verifies: https://github.com/grpc/proposal/blob/master/A68_graphics/subsetting2000-10-5.png + @Test + public void backendsCanBeDistributedEvenly_subsetting2000_100_5() { + verifyConnectionsByServer(2000, 10, 5, 1200); + } + + public void verifyConnectionsByServer( + int clientsCount, int serversCount, int subsetSize, int expectedMaxConnections) { + backendDetails = setupBackends(serversCount); + Object childConfig = "someConfig"; + + List configs = Lists.newArrayList(); + for (int i = 0; i < clientsCount; i++) { + configs.add(createRandomSubsettingLbConfig(subsetSize, mockChildLbProvider, childConfig)); + } + + Map connectionsByServer = Maps.newLinkedHashMap(); + + for (int i = 0; i < clientsCount; i++) { + RandomSubsettingLoadBalancerConfig config = configs.get(i); + + ResolvedAddresses resolvedAddresses = + ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.copyOf(backendDetails.servers)) + .setLoadBalancingPolicyConfig(config) + .build(); + + loadBalancer = new RandomSubsettingLoadBalancer(mockHelper, i); + loadBalancer.acceptResolvedAddresses(resolvedAddresses); + + verify(mockChildLb, atLeastOnce()).acceptResolvedAddresses(resolvedAddrCaptor.capture()); + // Verify ChildLB is only getting subsetSize ResolvedAddresses each time + assertThat(resolvedAddrCaptor.getValue().getAddresses().size()).isEqualTo(config.subsetSize); + + for (EquivalentAddressGroup eag : resolvedAddrCaptor.getValue().getAddresses()) { + for (SocketAddress addr : eag.getAddresses()) { + Integer prev = connectionsByServer.getOrDefault(addr, 0); + connectionsByServer.put(addr, prev + 1); + } + } + } + + int maxConnections = Collections.max(connectionsByServer.values()); + + assertThat(maxConnections).isAtMost(expectedMaxConnections); + } + + private class BackendDetails { + private final List servers; + private final Map, Subchannel> subchannels; + private final Map subchannelStateListeners; + + BackendDetails(List servers, + Map, Subchannel> subchannels) { + this.servers = servers; + this.subchannels = subchannels; + this.subchannelStateListeners = Maps.newLinkedHashMap(); + + when(mockHelper.createSubchannel(any(LoadBalancer.CreateSubchannelArgs.class))).then( + new Answer() { + @Override + public Subchannel answer(InvocationOnMock invocation) throws Throwable { + CreateSubchannelArgs args = (CreateSubchannelArgs) invocation.getArguments()[0]; + final Subchannel subchannel = backendDetails.subchannels.get(args.getAddresses()); + when(subchannel.getAllAddresses()).thenReturn(args.getAddresses()); + when(subchannel.getAttributes()).thenReturn(args.getAttributes()); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + subchannelStateListeners.put(subchannel, + (SubchannelStateListener) invocation.getArguments()[0]); + return null; + } + }).when(subchannel).start(any(SubchannelStateListener.class)); + return subchannel; + } + }); + } + } + + private static class FakeSocketAddress extends SocketAddress { + final String name; + + FakeSocketAddress(String name) { + this.name = name; + } + + @Override + public String toString() { + return "FakeSocketAddress-" + name; + } + } +} diff --git a/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java b/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java index 6b935ec3aa9..18854ca1bb6 100644 --- a/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java +++ b/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java @@ -22,7 +22,6 @@ import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; -import static io.grpc.util.MultiChildLoadBalancer.IS_PETIOLE_POLICY; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.fail; @@ -30,6 +29,8 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -44,15 +45,17 @@ import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.CreateSubchannelArgs; +import io.grpc.LoadBalancer.FixedResultPicker; import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.ResolvedAddresses; import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.Status; +import io.grpc.internal.PickFirstLoadBalancerProvider; +import io.grpc.internal.PickFirstLoadBalancerProviderAccessor; import io.grpc.internal.TestUtils; -import io.grpc.util.MultiChildLoadBalancer.ChildLbState; -import io.grpc.util.RoundRobinLoadBalancer.EmptyPicker; import io.grpc.util.RoundRobinLoadBalancer.ReadyPicker; import java.net.SocketAddress; import java.util.ArrayList; @@ -82,6 +85,8 @@ @RunWith(JUnit4.class) public class RoundRobinLoadBalancerTest { private static final Attributes.Key MAJOR_KEY = Attributes.Key.create("major-key"); + private static final SubchannelPicker EMPTY_PICKER = + new FixedResultPicker(PickResult.withNoResult()); @Rule public final MockitoRule mocks = MockitoJUnit.rule(); @@ -100,6 +105,7 @@ public class RoundRobinLoadBalancerTest { private ArgumentCaptor createArgsCaptor; private TestHelper testHelperInst = new TestHelper(); private Helper mockHelper = mock(Helper.class, delegatesTo(testHelperInst)); + private boolean defaultNewPickFirst = PickFirstLoadBalancerProvider.isEnabledNewPickFirst(); @Mock // This LoadBalancer doesn't use any of the arg fields, as verified in tearDown(). private PickSubchannelArgs mockArgs; @@ -122,6 +128,7 @@ private Status acceptAddresses(List eagList, Attributes @After public void tearDown() throws Exception { + PickFirstLoadBalancerProviderAccessor.setEnableNewPickFirst(defaultNewPickFirst); verifyNoMoreInteractions(mockArgs); } @@ -151,7 +158,7 @@ public void pickAfterResolved() throws Exception { assertEquals(READY, stateCaptor.getAllValues().get(1)); assertThat(getList(pickerCaptor.getValue())).containsExactly(readySubchannel); - verifyNoMoreInteractions(mockHelper); + AbstractTestHelper.verifyNoMoreMeaningfulInteractions(mockHelper); } @Test @@ -197,16 +204,6 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { verify(removedSubchannel, times(1)).requestConnection(); verify(oldSubchannel, times(1)).requestConnection(); - assertThat(loadBalancer.getChildLbStates().size()).isEqualTo(2); - for (ChildLbState childLbState : loadBalancer.getChildLbStates()) { - assertThat(childLbState.getResolvedAddresses().getAttributes().get(IS_PETIOLE_POLICY)) - .isTrue(); - } - assertThat(loadBalancer.getChildLbStateEag(removedEag).getCurrentPicker().pickSubchannel(null) - .getSubchannel()).isEqualTo(removedSubchannel); - assertThat(loadBalancer.getChildLbStateEag(oldEag1).getCurrentPicker().pickSubchannel(null) - .getSubchannel()).isEqualTo(oldSubchannel); - // This time with Attributes List latestServers = Lists.newArrayList(oldEag2, newEag); @@ -220,56 +217,48 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { deliverSubchannelState(newSubchannel, ConnectivityStateInfo.forNonError(READY)); - assertThat(loadBalancer.getChildLbStates().size()).isEqualTo(2); - assertThat(loadBalancer.getChildLbStateEag(newEag).getCurrentPicker() - .pickSubchannel(null).getSubchannel()).isEqualTo(newSubchannel); - assertThat(loadBalancer.getChildLbStateEag(oldEag2).getCurrentPicker() - .pickSubchannel(null).getSubchannel()).isEqualTo(oldSubchannel); - verify(mockHelper, times(6)).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(mockHelper, times(2)).updateBalancingState(eq(READY), pickerCaptor.capture()); picker = pickerCaptor.getValue(); assertThat(getList(picker)).containsExactly(oldSubchannel, newSubchannel); - verifyNoMoreInteractions(mockHelper); + AbstractTestHelper.verifyNoMoreMeaningfulInteractions(mockHelper); } @Test public void pickAfterStateChange() throws Exception { InOrder inOrder = inOrder(mockHelper); - Status addressesAcceptanceStatus = acceptAddresses(servers, Attributes.EMPTY); + Status addressesAcceptanceStatus = + acceptAddresses(Arrays.asList(servers.get(0)), Attributes.EMPTY); assertThat(addressesAcceptanceStatus.isOk()).isTrue(); + inOrder.verify(mockHelper).createSubchannel(any(CreateSubchannelArgs.class)); // TODO figure out if this method testing the right things - ChildLbState childLbState = loadBalancer.getChildLbStates().iterator().next(); - Subchannel subchannel = childLbState.getCurrentPicker().pickSubchannel(null).getSubchannel(); + assertThat(subchannels).hasSize(1); + Subchannel subchannel = subchannels.values().iterator().next(); - inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); - assertThat(childLbState.getCurrentState()).isEqualTo(CONNECTING); + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), eq(EMPTY_PICKER)); deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture()); assertThat(pickerCaptor.getValue()).isInstanceOf(ReadyPicker.class); - assertThat(childLbState.getCurrentState()).isEqualTo(READY); Status error = Status.UNKNOWN.withDescription("¯\\_(ツ)_//¯"); deliverSubchannelState(subchannel, ConnectivityStateInfo.forTransientFailure(error)); - assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE); - inOrder.verify(mockHelper).refreshNameResolution(); - inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); - assertThat(pickerCaptor.getValue()).isInstanceOf(EmptyPicker.class); + AbstractTestHelper.refreshInvokedAndUpdateBS( + inOrder, TRANSIENT_FAILURE, mockHelper, pickerCaptor); + assertThat(pickerCaptor.getValue().pickSubchannel(mockArgs).getStatus()).isEqualTo(error); - deliverSubchannelState(subchannel, - ConnectivityStateInfo.forNonError(IDLE)); + deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(IDLE)); inOrder.verify(mockHelper).refreshNameResolution(); - assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE); + inOrder.verify(mockHelper, never()) + .updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class)); - verify(subchannel, times(2)).requestConnection(); - verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); - verifyNoMoreInteractions(mockHelper); + verify(subchannel, atLeastOnce()).requestConnection(); + AbstractTestHelper.verifyNoMoreMeaningfulInteractions(mockHelper); } @Test @@ -277,12 +266,12 @@ public void ignoreShutdownSubchannelStateChange() { InOrder inOrder = inOrder(mockHelper); Status addressesAcceptanceStatus = acceptAddresses(servers, Attributes.EMPTY); assertThat(addressesAcceptanceStatus.isOk()).isTrue(); - inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), eq(EMPTY_PICKER)); + List savedSubchannels = new ArrayList<>(subchannels.values()); loadBalancer.shutdown(); - for (ChildLbState child : loadBalancer.getChildLbStates()) { - Subchannel sc = child.getCurrentPicker().pickSubchannel(null).getSubchannel(); - verify(child).shutdown(); + for (Subchannel sc : savedSubchannels) { + verify(sc).shutdown(); // When the subchannel is being shut down, a SHUTDOWN connectivity state is delivered // back to the subchannel state listener. deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(SHUTDOWN)); @@ -297,35 +286,29 @@ public void stayTransientFailureUntilReady() { Status addressesAcceptanceStatus = acceptAddresses(servers, Attributes.EMPTY); assertThat(addressesAcceptanceStatus.isOk()).isTrue(); - inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); + inOrder.verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), eq(EMPTY_PICKER)); - Map childToSubChannelMap = new HashMap<>(); // Simulate state transitions for each subchannel individually. - for ( ChildLbState child : loadBalancer.getChildLbStates()) { - Subchannel sc = child.getSubchannels(mockArgs); - childToSubChannelMap.put(child, sc); + for (Subchannel sc : subchannels.values()) { Status error = Status.UNKNOWN.withDescription("connection broken"); deliverSubchannelState( sc, ConnectivityStateInfo.forTransientFailure(error)); - assertEquals(TRANSIENT_FAILURE, child.getCurrentState()); - inOrder.verify(mockHelper).refreshNameResolution(); deliverSubchannelState( sc, ConnectivityStateInfo.forNonError(CONNECTING)); - assertEquals(TRANSIENT_FAILURE, child.getCurrentState()); } inOrder.verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), isA(ReadyPicker.class)); + inOrder.verify(mockHelper, atLeast(0)).refreshNameResolution(); inOrder.verifyNoMoreInteractions(); - ChildLbState child = loadBalancer.getChildLbStates().iterator().next(); - Subchannel subchannel = childToSubChannelMap.get(child); + Subchannel subchannel = subchannels.values().iterator().next(); deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); - assertThat(child.getCurrentState()).isEqualTo(READY); inOrder.verify(mockHelper).updateBalancingState(eq(READY), isA(ReadyPicker.class)); - verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); - verifyNoMoreInteractions(mockHelper); + inOrder.verify(mockHelper, atLeast(0)).refreshNameResolution(); + inOrder.verifyNoMoreInteractions(); } @Test @@ -335,11 +318,10 @@ public void refreshNameResolutionWhenSubchannelConnectionBroken() { assertThat(addressesAcceptanceStatus.isOk()).isTrue(); verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); - inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), eq(EMPTY_PICKER)); // Simulate state transitions for each subchannel individually. - for (ChildLbState child : loadBalancer.getChildLbStates()) { - Subchannel sc = child.getSubchannels(mockArgs); + for (Subchannel sc : subchannels.values()) { verify(sc).requestConnection(); deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(CONNECTING)); Status error = Status.UNKNOWN.withDescription("connection broken"); @@ -351,10 +333,10 @@ public void refreshNameResolutionWhenSubchannelConnectionBroken() { deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(IDLE)); inOrder.verify(mockHelper).refreshNameResolution(); verify(sc, times(2)).requestConnection(); - inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), eq(EMPTY_PICKER)); } - verifyNoMoreInteractions(mockHelper); + AbstractTestHelper.verifyNoMoreMeaningfulInteractions(mockHelper); } @Test @@ -431,7 +413,7 @@ public void nameResolutionErrorWithActiveChannels() throws Exception { LoadBalancer.PickResult pickResult2 = pickerCaptor.getValue().pickSubchannel(mockArgs); assertEquals(readySubchannel, pickResult2.getSubchannel()); - verifyNoMoreInteractions(mockHelper); + AbstractTestHelper.verifyNoMoreMeaningfulInteractions(mockHelper); } @Test @@ -460,7 +442,7 @@ public void subchannelStateIsolation() throws Exception { Iterator pickers = pickerCaptor.getAllValues().iterator(); // The picker is incrementally updated as subchannels become READY assertEquals(CONNECTING, stateIterator.next()); - assertThat(pickers.next()).isInstanceOf(EmptyPicker.class); + assertThat(pickers.next()).isEqualTo(EMPTY_PICKER); assertEquals(READY, stateIterator.next()); assertThat(getList(pickers.next())).containsExactly(sc1); assertEquals(READY, stateIterator.next()); @@ -479,6 +461,60 @@ public void subchannelStateIsolation() throws Exception { assertThat(pickers.hasNext()).isFalse(); } + @Test + public void subchannelHealthObserved() throws Exception { + // Only the new PF policy observes the new separate listener for health + PickFirstLoadBalancerProviderAccessor.setEnableNewPickFirst(true); + // PickFirst does most of this work. If the test fails, check IS_PETIOLE_POLICY + Map healthListeners = new HashMap<>(); + loadBalancer = new RoundRobinLoadBalancer(new ForwardingLoadBalancerHelper() { + @Override + public Subchannel createSubchannel(CreateSubchannelArgs args) { + Subchannel subchannel = super.createSubchannel(args.toBuilder() + .setAttributes(args.getAttributes().toBuilder() + .set(LoadBalancer.HAS_HEALTH_PRODUCER_LISTENER_KEY, true) + .build()) + .build()); + healthListeners.put( + subchannel, args.getOption(LoadBalancer.HEALTH_CONSUMER_LISTENER_ARG_KEY)); + return subchannel; + } + + @Override + protected Helper delegate() { + return mockHelper; + } + }); + + InOrder inOrder = inOrder(mockHelper); + Status addressesAcceptanceStatus = acceptAddresses(servers, Attributes.EMPTY); + assertThat(addressesAcceptanceStatus.isOk()).isTrue(); + Subchannel subchannel0 = subchannels.get(Arrays.asList(servers.get(0))); + Subchannel subchannel1 = subchannels.get(Arrays.asList(servers.get(1))); + Subchannel subchannel2 = subchannels.get(Arrays.asList(servers.get(2))); + + // Subchannels go READY, but the LB waits for health + for (Subchannel subchannel : subchannels.values()) { + deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); + } + inOrder.verify(mockHelper, times(0)) + .updateBalancingState(eq(READY), any(SubchannelPicker.class)); + + // Health results lets subchannels go READY + healthListeners.get(subchannel0).onSubchannelState( + ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE.withDescription("oh no"))); + healthListeners.get(subchannel1).onSubchannelState(ConnectivityStateInfo.forNonError(READY)); + healthListeners.get(subchannel2).onSubchannelState(ConnectivityStateInfo.forNonError(READY)); + inOrder.verify(mockHelper, times(2)).updateBalancingState(eq(READY), pickerCaptor.capture()); + SubchannelPicker picker = pickerCaptor.getValue(); + List picks = Arrays.asList( + picker.pickSubchannel(mockArgs).getSubchannel(), + picker.pickSubchannel(mockArgs).getSubchannel(), + picker.pickSubchannel(mockArgs).getSubchannel(), + picker.pickSubchannel(mockArgs).getSubchannel()); + assertThat(picks).containsExactly(subchannel1, subchannel2, subchannel1, subchannel2); + } + @Test public void readyPicker_emptyList() { // ready picker list must be non-empty @@ -491,8 +527,8 @@ public void readyPicker_emptyList() { @Test public void internalPickerComparisons() { - SubchannelPicker empty1 = new EmptyPicker(); - SubchannelPicker empty2 = new EmptyPicker(); + SubchannelPicker empty1 = new FixedResultPicker(PickResult.withNoResult()); + SubchannelPicker empty2 = new FixedResultPicker(PickResult.withNoResult()); AtomicInteger seq = new AtomicInteger(0); acceptAddresses(servers, Attributes.EMPTY); // create subchannels diff --git a/util/src/test/java/io/grpc/util/UtilServerInterceptorsTest.java b/util/src/test/java/io/grpc/util/UtilServerInterceptorsTest.java index cfd1d1354fc..a4691d8bdec 100644 --- a/util/src/test/java/io/grpc/util/UtilServerInterceptorsTest.java +++ b/util/src/test/java/io/grpc/util/UtilServerInterceptorsTest.java @@ -43,7 +43,7 @@ @RunWith(JUnit4.class) public class UtilServerInterceptorsTest { private static class VoidCallListener extends ServerCall.Listener { - public void onCall(ServerCall call, Metadata headers) { } + public void onCall(ServerCall unused, Metadata unused2) { } } private MethodDescriptor flowMethod = TestMethodDescriptors.voidMethod(); diff --git a/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java b/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java index d8e75e939f8..837dc68c057 100644 --- a/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java +++ b/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java @@ -16,10 +16,14 @@ package io.grpc.util; +import static com.google.common.base.Preconditions.checkNotNull; import static org.mockito.AdditionalAnswers.delegatesTo; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; -import com.google.common.collect.Maps; import io.grpc.Attributes; import io.grpc.Channel; import io.grpc.ChannelLogger; @@ -31,11 +35,17 @@ import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancer.SubchannelStateListener; +import io.grpc.SynchronizationContext; +import io.grpc.internal.FakeClock; +import io.grpc.internal.PickFirstLoadBalancerProvider; import java.net.SocketAddress; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.ScheduledExecutorService; +import org.mockito.ArgumentCaptor; +import org.mockito.InOrder; /** * A real class that can be used as a delegate of a mock Helper to provide more real representation @@ -45,6 +55,7 @@ * To use it replace
* \@mock Helper mockHelper
* with
+ * *

Helper mockHelper = mock(Helper.class, delegatesTo(new TestHelper()));

*
* TestHelper will need to define accessors for the maps that information is store within as @@ -52,39 +63,58 @@ */ public abstract class AbstractTestHelper extends ForwardingLoadBalancerHelper { - private final Map mockToRealSubChannelMap = new HashMap<>(); + private final Map mockToRealSubChannelMap = new HashMap<>(); protected final Map realToMockSubChannelMap = new HashMap<>(); - private final Map subchannelStateListeners = - Maps.newLinkedHashMap(); + private final FakeClock fakeClock; + private final SynchronizationContext syncContext; public abstract Map, Subchannel> getSubchannelMap(); - public Map getMockToRealSubChannelMap() { - return mockToRealSubChannelMap; + public AbstractTestHelper() { + this(new FakeClock(), new SynchronizationContext(new Thread.UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + throw new RuntimeException(e); + } + })); + } + + public AbstractTestHelper(FakeClock fakeClock, SynchronizationContext syncContext) { + super(); + this.fakeClock = fakeClock; + this.syncContext = syncContext; } - public Subchannel getRealForMockSubChannel(Subchannel mock) { - Subchannel realSc = getMockToRealSubChannelMap().get(mock); + private TestSubchannel getRealForMockSubChannel(Subchannel mock) { + TestSubchannel realSc = mockToRealSubChannelMap.get(mock); if (realSc == null) { - realSc = mock; + realSc = (TestSubchannel) mock; } return realSc; } - public Map getSubchannelStateListeners() { - return subchannelStateListeners; + public static final FakeClock.TaskFilter NOT_START_NEXT_CONNECTION = + new FakeClock.TaskFilter() { + @Override + public boolean shouldAccept(Runnable command) { + return !command.toString().contains("StartNextConnection"); + } + }; + + public static int getNumFilteredPendingTasks(FakeClock fakeClock) { + return fakeClock.getPendingTasks(NOT_START_NEXT_CONNECTION).size(); } public void deliverSubchannelState(Subchannel subchannel, ConnectivityStateInfo newState) { - Subchannel realSc = getMockToRealSubChannelMap().get(subchannel); - if (realSc == null) { - realSc = subchannel; - } - SubchannelStateListener listener = getSubchannelStateListeners().get(realSc); + getSubchannelStateListener(subchannel).onSubchannelState(newState); + } + + public SubchannelStateListener getSubchannelStateListener(Subchannel subchannel) { + SubchannelStateListener listener = getRealForMockSubChannel(subchannel).listener; if (listener == null) { - throw new IllegalArgumentException("subchannel does not have a matching listener"); + throw new IllegalArgumentException("subchannel has not been started"); } - listener.onSubchannelState(newState); + return listener; } @Override @@ -104,7 +134,7 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { TestSubchannel delegate = createRealSubchannel(args); subchannel = mock(Subchannel.class, delegatesTo(delegate)); getSubchannelMap().put(args.getAddresses(), subchannel); - getMockToRealSubChannelMap().put(subchannel, delegate); + mockToRealSubChannelMap.put(subchannel, delegate); realToMockSubChannelMap.put(delegate, subchannel); } @@ -121,7 +151,17 @@ public void refreshNameResolution() { } public void setChannel(Subchannel subchannel, Channel channel) { - ((TestSubchannel)subchannel).channel = channel; + getRealForMockSubChannel(subchannel).channel = channel; + } + + @Override + public SynchronizationContext getSynchronizationContext() { + return syncContext; + } + + @Override + public ScheduledExecutorService getScheduledExecutorService() { + return fakeClock.getScheduledExecutorService(); } @Override @@ -129,8 +169,36 @@ public String toString() { return "Test Helper"; } + public static void refreshInvokedAndUpdateBS(InOrder inOrder, ConnectivityState state, + Helper helper, + ArgumentCaptor pickerCaptor) { + // Old PF and new PF reverse calling order of updateBlaancingState and refreshNameResolution + if (PickFirstLoadBalancerProvider.isEnabledNewPickFirst()) { + inOrder.verify(helper).updateBalancingState(eq(state), pickerCaptor.capture()); + } + + inOrder.verify(helper).refreshNameResolution(); + + if (!PickFirstLoadBalancerProvider.isEnabledNewPickFirst()) { + inOrder.verify(helper).updateBalancingState(eq(state), pickerCaptor.capture()); + } + } + + public static void verifyNoMoreMeaningfulInteractions(Helper helper) { + verify(helper, atLeast(0)).getSynchronizationContext(); + verify(helper, atLeast(0)).getScheduledExecutorService(); + verifyNoMoreInteractions(helper); + } + + public static void verifyNoMoreMeaningfulInteractions(Helper helper, InOrder inOrder) { + inOrder.verify(helper, atLeast(0)).getSynchronizationContext(); + inOrder.verify(helper, atLeast(0)).getScheduledExecutorService(); + inOrder.verifyNoMoreInteractions(); + } + protected class TestSubchannel extends ForwardingSubchannel { CreateSubchannelArgs args; + SubchannelStateListener listener; Channel channel; public TestSubchannel(CreateSubchannelArgs args) { @@ -173,12 +241,11 @@ public void updateAddresses(List addrs) { @Override public void start(SubchannelStateListener listener) { - getSubchannelStateListeners().put(this, listener); + this.listener = checkNotNull(listener, "listener"); } @Override public void shutdown() { - getSubchannelStateListeners().remove(this); for (EquivalentAddressGroup eag : getAllAddresses()) { getSubchannelMap().remove(Collections.singletonList(eag)); } @@ -200,7 +267,7 @@ public String toString() { } } - public static class FakeSocketAddress extends SocketAddress { + public static final class FakeSocketAddress extends SocketAddress { private static final long serialVersionUID = 0L; final String name; @@ -212,6 +279,20 @@ public static class FakeSocketAddress extends SocketAddress { public String toString() { return "FakeSocketAddress-" + name; } + + @Override + public boolean equals(Object o) { + if (!(o instanceof FakeSocketAddress)) { + return false; + } + FakeSocketAddress that = (FakeSocketAddress) o; + return this.name.equals(that.name); + } + + @Override + public int hashCode() { + return name.hashCode(); + } } } diff --git a/util/src/testFixtures/java/io/grpc/util/GracefulSwitchLoadBalancerAccessor.java b/util/src/testFixtures/java/io/grpc/util/GracefulSwitchLoadBalancerAccessor.java new file mode 100644 index 00000000000..8f62e66be4a --- /dev/null +++ b/util/src/testFixtures/java/io/grpc/util/GracefulSwitchLoadBalancerAccessor.java @@ -0,0 +1,36 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.util; + +import io.grpc.LoadBalancerProvider; + +/** + * Accessors for white-box testing involving GracefulSwitchLoadBalancer. + */ +public final class GracefulSwitchLoadBalancerAccessor { + private GracefulSwitchLoadBalancerAccessor() { + // Do not instantiate + } + + public static LoadBalancerProvider getChildProvider(Object config) { + return (LoadBalancerProvider) ((GracefulSwitchLoadBalancer.Config) config).childFactory; + } + + public static Object getChildConfig(Object config) { + return ((GracefulSwitchLoadBalancer.Config) config).childConfig; + } +} diff --git a/xds/BUILD.bazel b/xds/BUILD.bazel index d3b746e39fa..9a650485c6c 100644 --- a/xds/BUILD.bazel +++ b/xds/BUILD.bazel @@ -1,4 +1,9 @@ -load("//:java_grpc_library.bzl", "java_grpc_library") +load("@bazel_jar_jar//:jar_jar.bzl", "jar_jar") +load("@com_google_protobuf//bazel:java_proto_library.bzl", "java_proto_library") +load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library") +load("@rules_java//java:defs.bzl", "java_binary", "java_library", "java_test") +load("@rules_jvm_external//:defs.bzl", "artifact") +load("//:java_grpc_library.bzl", "INTERNAL_java_grpc_library_for_xds", "java_grpc_library", "java_rpc_toolchain") # Mirrors the dependencies included in the artifact on Maven Central for usage # with maven_install's override_targets. Should only be used as a dep for @@ -12,8 +17,60 @@ java_library( ], ) +# Ordinary deps for :xds java_library( - name = "xds", + name = "xds_deps_depend", + exports = [ + ":orca", + "//:auto_value_annotations", + "//alts", + "//api", + "//auth", + "//context", + "//core:internal", + "//netty", + "//services:metrics", + "//services:metrics_internal", + "//stub", + "//util", + "@com_google_protobuf//:protobuf_java", + "@com_google_protobuf//:protobuf_java_util", + "@maven//:com_google_auth_google_auth_library_oauth2_http", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.code.gson:gson"), + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:guava"), + artifact("com.google.re2j:re2j"), + artifact("io.netty:netty-buffer"), + artifact("io.netty:netty-codec"), + artifact("io.netty:netty-common"), + artifact("io.netty:netty-handler"), + artifact("io.netty:netty-transport"), + ], + runtime_deps = [ + "//compiler:java_grpc_library_deps__do_not_reference", + ], +) + +java_library( + name = "xds_deps_depend_neverlink", + neverlink = 1, + exports = [":xds_deps_depend"], +) + +# Deps to be combined into the :xds jar itself +java_library( + name = "xds_deps_embed", + exports = [ + ":envoy_java_grpc", + ":envoy_java_proto", + ":googleapis_rpc_java_proto", + ":xds_java_proto", + ], +) + +java_binary( + name = "xds_notjarjar", srcs = glob( [ "src/main/java/**/*.java", @@ -21,151 +78,288 @@ java_library( ], exclude = ["src/main/java/io/grpc/xds/orca/**"], ), + main_class = "unused", resources = glob([ "src/main/resources/**", ]), + deps = [ + # Do not add additional dependencies here; add them to one of these two deps instead + ":xds_deps_depend_neverlink", + ":xds_deps_embed", + ], +) + +JAR_JAR_RULES = [ + "zap com.google.protobuf.**", # Drop codegen dep + # Keep in sync with build.gradle's shadowJar + "rule com.github.udpa.** io.grpc.xds.shaded.com.github.udpa.@1", + "rule com.github.xds.** io.grpc.xds.shaded.com.github.xds.@1", + "rule com.google.api.expr.** io.grpc.xds.shaded.com.google.api.expr.@1", + "rule com.google.security.** io.grpc.xds.shaded.com.google.security.@1", + "rule dev.cel.expr.** io.grpc.xds.shaded.dev.cel.expr.@1", + "rule envoy.annotations.** io.grpc.xds.shaded.envoy.annotations.@1", + "rule io.envoyproxy.** io.grpc.xds.shaded.io.envoyproxy.@1", + "rule udpa.annotations.** io.grpc.xds.shaded.udpa.annotations.@1", + "rule xds.annotations.** io.grpc.xds.shaded.xds.annotations.@1", +] + +jar_jar( + name = "xds_jarjar", + inline_rules = JAR_JAR_RULES, + input_jar = ":xds_notjarjar_deploy.jar", +) + +java_library( + name = "xds", visibility = ["//visibility:public"], + exports = [":xds_jarjar"], + runtime_deps = [":xds_deps_depend"], +) + +java_proto_library( + name = "googleapis_rpc_java_proto", deps = [ - ":envoy_service_discovery_v2_java_grpc", - ":envoy_service_discovery_v3_java_grpc", - ":envoy_service_load_stats_v2_java_grpc", - ":envoy_service_load_stats_v3_java_grpc", - ":envoy_service_status_v3_java_grpc", - ":xds_protos_java", - ":orca", - "//:auto_value_annotations", - "//alts", + "@com_google_googleapis//google/rpc:code_proto", + "@com_google_googleapis//google/rpc:status_proto", + ], +) + +# Ordinary deps for :orca +java_library( + name = "orca_deps_depend", + exports = [ + ":xds_orca_java_grpc", + ":xds_orca_java_proto", "//api", "//context", "//core:internal", - "//util", - "//netty", - "//stub", + "//protobuf", "//services:metrics", "//services:metrics_internal", - "@com_google_code_findbugs_jsr305//jar", - "@com_google_code_gson_gson//jar", - "@com_google_errorprone_error_prone_annotations//jar", - "@com_google_googleapis//google/rpc:rpc_java_proto", - "@com_google_guava_guava//jar", - "@com_google_protobuf//:protobuf_java", + "//stub", + "//util", "@com_google_protobuf//:protobuf_java_util", - "@com_google_re2j_re2j//jar", - "@io_netty_netty_buffer//jar", - "@io_netty_netty_codec//jar", - "@io_netty_netty_common//jar", - "@io_netty_netty_handler//jar", - "@io_netty_netty_transport//jar", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.guava:guava"), ], ) -java_proto_library( - name = "xds_protos_java", +java_library( + name = "orca_deps_depend_neverlink", + neverlink = 1, + exports = [":orca_deps_depend"], +) + +# Deps to be combined into the :orca jar itself +java_library( + name = "orca_deps_embed", + exports = [ + ":xds_orca_java_grpc", + ":xds_orca_java_proto", + ], +) + +java_binary( + name = "orca_notjarjar", + srcs = glob([ + "src/main/java/io/grpc/xds/orca/*.java", + ]), + main_class = "unused", + visibility = ["//visibility:public"], deps = [ - "@com_github_cncf_udpa//udpa/type/v1:pkg", - "@com_github_cncf_xds//xds/data/orca/v3:pkg", - "@com_github_cncf_xds//xds/service/orca/v3:pkg", - "@com_github_cncf_xds//xds/type/v3:pkg", - "@envoy_api//envoy/admin/v3:pkg", - "@envoy_api//envoy/api/v2:pkg", - "@envoy_api//envoy/api/v2/core:pkg", - "@envoy_api//envoy/api/v2/endpoint:pkg", - "@envoy_api//envoy/config/cluster/aggregate/v2alpha:pkg", - "@envoy_api//envoy/config/cluster/v3:pkg", - "@envoy_api//envoy/config/core/v3:pkg", - "@envoy_api//envoy/config/endpoint/v3:pkg", - "@envoy_api//envoy/config/filter/http/fault/v2:pkg", - "@envoy_api//envoy/config/filter/http/router/v2:pkg", - "@envoy_api//envoy/config/filter/network/http_connection_manager/v2:pkg", - "@envoy_api//envoy/config/listener/v3:pkg", - "@envoy_api//envoy/config/rbac/v3:pkg", - "@envoy_api//envoy/config/route/v3:pkg", - "@envoy_api//envoy/extensions/clusters/aggregate/v3:pkg", - "@envoy_api//envoy/extensions/filters/common/fault/v3:pkg", - "@envoy_api//envoy/extensions/filters/http/fault/v3:pkg", - "@envoy_api//envoy/extensions/filters/http/rbac/v3:pkg", - "@envoy_api//envoy/extensions/filters/http/router/v3:pkg", - "@envoy_api//envoy/extensions/filters/network/http_connection_manager/v3:pkg", - "@envoy_api//envoy/extensions/load_balancing_policies/client_side_weighted_round_robin/v3:pkg", - "@envoy_api//envoy/extensions/load_balancing_policies/least_request/v3:pkg", - "@envoy_api//envoy/extensions/load_balancing_policies/pick_first/v3:pkg", - "@envoy_api//envoy/extensions/load_balancing_policies/ring_hash/v3:pkg", - "@envoy_api//envoy/extensions/load_balancing_policies/round_robin/v3:pkg", - "@envoy_api//envoy/extensions/load_balancing_policies/wrr_locality/v3:pkg", - "@envoy_api//envoy/extensions/transport_sockets/tls/v3:pkg", - "@envoy_api//envoy/service/discovery/v2:pkg", - "@envoy_api//envoy/service/discovery/v3:pkg", - "@envoy_api//envoy/service/load_stats/v2:pkg", - "@envoy_api//envoy/service/load_stats/v3:pkg", - "@envoy_api//envoy/service/status/v3:pkg", - "@envoy_api//envoy/type/matcher/v3:pkg", - "@envoy_api//envoy/type/v3:pkg", + # Do not add additional dependencies here; add them to one of these two deps instead + ":orca_deps_depend_neverlink", + ":orca_deps_embed", ], ) -java_grpc_library( - name = "envoy_service_discovery_v2_java_grpc", - srcs = ["@envoy_api//envoy/service/discovery/v2:pkg"], - deps = [":xds_protos_java"], +jar_jar( + name = "orca_jarjar", + inline_rules = JAR_JAR_RULES, + input_jar = ":orca_notjarjar_deploy.jar", ) -java_grpc_library( - name = "envoy_service_discovery_v3_java_grpc", - srcs = ["@envoy_api//envoy/service/discovery/v3:pkg"], - deps = [":xds_protos_java"], +java_library( + name = "orca", + visibility = ["//visibility:public"], + exports = [":orca_jarjar"], + runtime_deps = [":orca_deps_depend"], ) -java_grpc_library( - name = "envoy_service_load_stats_v2_java_grpc", - srcs = ["@envoy_api//envoy/service/load_stats/v2:pkg"], - deps = [":xds_protos_java"], +java_proto_library( + name = "orca_java_proto", + deps = [":xds_proto"], ) java_grpc_library( - name = "envoy_service_load_stats_v3_java_grpc", - srcs = ["@envoy_api//envoy/service/load_stats/v3:pkg"], - deps = [":xds_protos_java"], + name = "orca_java_grpc", + srcs = [":xds_proto"], + deps = [":orca_java_proto"], ) -java_grpc_library( - name = "envoy_service_status_v3_java_grpc", - srcs = ["@envoy_api//envoy/service/status/v3:pkg"], - deps = [":xds_protos_java"], +proto_library( + name = "cel_spec_proto", + srcs = glob(["third_party/cel-spec/src/main/proto/**/*.proto"]), + strip_import_prefix = "third_party/cel-spec/src/main/proto/", + deps = [ + "@com_google_protobuf//:duration_proto", + "@com_google_protobuf//:empty_proto", + "@com_google_protobuf//:struct_proto", + "@com_google_protobuf//:timestamp_proto", + ], ) -java_library( - name = "orca", +proto_library( + name = "envoy_proto", + srcs = glob(["third_party/envoy/src/main/proto/**/*.proto"]), + strip_import_prefix = "third_party/envoy/src/main/proto/", + deps = [ + ":googleapis_proto", + ":protoc_gen_validate_proto", + ":xds_proto", + "@com_google_googleapis//google/api:annotations_proto", + "@com_google_googleapis//google/rpc:status_proto", + "@com_google_protobuf//:any_proto", + "@com_google_protobuf//:descriptor_proto", + "@com_google_protobuf//:duration_proto", + "@com_google_protobuf//:empty_proto", + "@com_google_protobuf//:struct_proto", + "@com_google_protobuf//:timestamp_proto", + "@com_google_protobuf//:wrappers_proto", + ], +) + +java_proto_library( + name = "envoy_java_proto", + deps = [":envoy_proto"], +) + +INTERNAL_java_grpc_library_for_xds( + name = "envoy_java_grpc", + srcs = [":envoy_proto"], + deps = [":envoy_java_proto"], +) + +proto_library( + name = "googleapis_proto", + srcs = glob(["third_party/googleapis/src/main/proto/**/*.proto"]), + strip_import_prefix = "third_party/googleapis/src/main/proto/", + deps = [ + "@com_google_protobuf//:duration_proto", + "@com_google_protobuf//:empty_proto", + "@com_google_protobuf//:struct_proto", + "@com_google_protobuf//:timestamp_proto", + ], +) + +proto_library( + name = "protoc_gen_validate_proto", + srcs = glob(["third_party/protoc-gen-validate/src/main/proto/**/*.proto"]), + strip_import_prefix = "third_party/protoc-gen-validate/src/main/proto/", + deps = [ + "@com_google_protobuf//:descriptor_proto", + "@com_google_protobuf//:duration_proto", + "@com_google_protobuf//:timestamp_proto", + ], +) + +proto_library( + name = "xds_proto", + srcs = glob( + ["third_party/xds/src/main/proto/**/*.proto"], + exclude = [ + "third_party/xds/src/main/proto/xds/data/orca/v3/*.proto", + "third_party/xds/src/main/proto/xds/service/orca/v3/*.proto", + ], + ), + strip_import_prefix = "third_party/xds/src/main/proto/", + deps = [ + ":cel_spec_proto", + ":googleapis_proto", + ":protoc_gen_validate_proto", + "@com_google_protobuf//:any_proto", + "@com_google_protobuf//:descriptor_proto", + "@com_google_protobuf//:duration_proto", + "@com_google_protobuf//:struct_proto", + "@com_google_protobuf//:wrappers_proto", + ], +) + +java_proto_library( + name = "xds_java_proto", + deps = [":xds_proto"], +) + +proto_library( + name = "xds_orca_proto", srcs = glob([ - "src/main/java/io/grpc/xds/orca/*.java", + "third_party/xds/src/main/proto/xds/data/orca/v3/*.proto", + "third_party/xds/src/main/proto/xds/service/orca/v3/*.proto", ]), - visibility = ["//visibility:public"], + strip_import_prefix = "third_party/xds/src/main/proto/", deps = [ - ":orca_protos_java", - ":xds_service_orca_v3_java_grpc", - "//api", - "//context", - "//core:internal", - "//util", - "//protobuf", - "//services:metrics", - "//services:metrics_internal", - "//stub", - "@com_google_code_findbugs_jsr305//jar", - "@com_google_guava_guava//jar", - "@com_google_protobuf//:protobuf_java_util", + ":protoc_gen_validate_proto", + "@com_google_protobuf//:duration_proto", ], ) java_proto_library( - name = "orca_protos_java", + name = "xds_orca_java_proto", + deps = [":xds_orca_proto"], +) + +java_grpc_library( + name = "xds_orca_java_grpc", + srcs = [":xds_orca_proto"], + deps = [":xds_orca_java_proto"], +) + +java_rpc_toolchain( + name = "java_grpc_library_toolchain", + plugin = "//compiler:grpc_java_plugin", + runtime = [":java_grpc_library_deps"], +) + +java_library( + name = "java_grpc_library_deps", + neverlink = 1, + exports = ["//compiler:java_grpc_library_deps__do_not_reference"], +) + +java_library( + name = "testlib", + testonly = 1, + srcs = [ + "src/test/java/io/grpc/xds/ControlPlaneRule.java", + "src/test/java/io/grpc/xds/DataPlaneRule.java", + "src/test/java/io/grpc/xds/FakeControlPlaneXdsIntegrationTest.java", + "src/test/java/io/grpc/xds/MetadataLoadBalancerProvider.java", + "src/test/java/io/grpc/xds/XdsTestControlPlaneService.java", + "src/test/java/io/grpc/xds/XdsTestLoadReportingService.java", + ], deps = [ - "@com_github_cncf_xds//xds/data/orca/v3:pkg", - "@com_github_cncf_xds//xds/service/orca/v3:pkg", + ":envoy_java_grpc", + ":envoy_java_proto", + ":xds", + ":xds_java_proto", + "//api", + "//api:test_fixtures", + "//core:internal", + "//stub", + "//testing-proto:simpleservice_java_grpc", + "//testing-proto:simpleservice_java_proto", + "//util", + "@com_google_protobuf//java/core", + "@maven//:com_google_code_findbugs_jsr305", + "@maven//:com_google_guava_guava", + "@maven//:com_google_truth_truth", + "@maven//:junit_junit", ], ) -java_grpc_library( - name = "xds_service_orca_v3_java_grpc", - srcs = ["@com_github_cncf_xds//xds/service/orca/v3:pkg"], - deps = [":orca_protos_java"], +java_test( + name = "FakeControlPlaneXdsIntegrationTest", + size = "small", + test_class = "io.grpc.xds.FakeControlPlaneXdsIntegrationTest", + runtime_deps = [":testlib"], ) diff --git a/xds/build.gradle b/xds/build.gradle index fa6eb9f9017..8394fe12f6b 100644 --- a/xds/build.gradle +++ b/xds/build.gradle @@ -4,8 +4,8 @@ plugins { id "java" id "maven-publish" - id "com.github.johnrengelman.shadow" id "com.google.protobuf" + id "com.gradleup.shadow" id "ru.vyarus.animalsniffer" } @@ -17,11 +17,11 @@ sourceSets { srcDir "${projectDir}/third_party/zero-allocation-hashing/main/java" } proto { + srcDir 'third_party/cel-spec/src/main/proto' srcDir 'third_party/envoy/src/main/proto' + srcDir 'third_party/googleapis/src/main/proto' srcDir 'third_party/protoc-gen-validate/src/main/proto' srcDir 'third_party/xds/src/main/proto' - srcDir 'third_party/googleapis/src/main/proto' - srcDir 'third_party/istio/src/main/proto' } } main { @@ -41,23 +41,24 @@ configurations { } dependencies { - thirdpartyCompileOnly libraries.javax.annotation thirdpartyImplementation project(':grpc-protobuf'), - project(':grpc-stub'), - libraries.opencensus.proto + project(':grpc-stub') compileOnly sourceSets.thirdparty.output + testCompileOnly sourceSets.thirdparty.output implementation project(':grpc-stub'), project(':grpc-core'), project(':grpc-util'), project(':grpc-services'), project(':grpc-auth'), project(path: ':grpc-alts', configuration: 'shadow'), + libraries.guava, libraries.gson, libraries.re2j, libraries.auto.value.annotations, libraries.protobuf.java.util def nettyDependency = implementation project(':grpc-netty') + testImplementation project(':grpc-api') testImplementation project(':grpc-rls') testImplementation project(':grpc-inprocess') testImplementation testFixtures(project(':grpc-core')), @@ -80,7 +81,11 @@ dependencies { shadow configurations.implementation.getDependencies().minus([nettyDependency]) shadow project(path: ':grpc-netty-shaded', configuration: 'shadow') - signature libraries.signature.java + signature (libraries.signature.java) { + artifact { + extension = "signature" + } + } testRuntimeOnly libraries.netty.tcnative, libraries.netty.tcnative.classes testRuntimeOnly (libraries.netty.tcnative) { @@ -126,8 +131,6 @@ tasks.named("checkstyleThirdparty").configure { tasks.named("compileJava").configure { it.options.compilerArgs += [ - // TODO: remove - "-Xlint:-deprecation", // only has AutoValue annotation processor "-Xlint:-processing", ] @@ -181,11 +184,13 @@ tasks.named("shadowJar").configure { include(project(':grpc-xds')) } // Relocated packages commonly need exclusions in jacocoTestReport and javadoc + // Keep in sync with BUILD.bazel's JAR_JAR_RULES relocate 'com.github.udpa', "${prefixName}.shaded.com.github.udpa" relocate 'com.github.xds', "${prefixName}.shaded.com.github.xds" relocate 'com.google.api.expr', "${prefixName}.shaded.com.google.api.expr" relocate 'com.google.security', "${prefixName}.shaded.com.google.security" // TODO: missing java_package option in .proto + relocate 'dev.cel.expr', "${prefixName}.shaded.dev.cel.expr" relocate 'envoy.annotations', "${prefixName}.shaded.envoy.annotations" relocate 'io.envoyproxy', "${prefixName}.shaded.io.envoyproxy" relocate 'io.grpc.netty', 'io.grpc.netty.shaded.io.grpc.netty' @@ -213,6 +218,7 @@ tasks.named("jacocoTestReport").configure { '**/com/github/xds/**', '**/com/google/api/expr/**', '**/com/google/security/**', + '**/cel/expr/**', '**/envoy/annotations/**', '**/io/envoyproxy/**', '**/udpa/annotations/**', diff --git a/xds/src/generated/thirdparty/grpc/com/github/xds/service/orca/v3/OpenRcaServiceGrpc.java b/xds/src/generated/thirdparty/grpc/com/github/xds/service/orca/v3/OpenRcaServiceGrpc.java index de2c7424fca..e0e28ad4072 100644 --- a/xds/src/generated/thirdparty/grpc/com/github/xds/service/orca/v3/OpenRcaServiceGrpc.java +++ b/xds/src/generated/thirdparty/grpc/com/github/xds/service/orca/v3/OpenRcaServiceGrpc.java @@ -14,9 +14,6 @@ * a new call to change backend reporting frequency. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: xds/service/orca/v3/orca.proto") @io.grpc.stub.annotations.GrpcGenerated public final class OpenRcaServiceGrpc { @@ -70,6 +67,21 @@ public OpenRcaServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions c return OpenRcaServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static OpenRcaServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public OpenRcaServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new OpenRcaServiceBlockingV2Stub(channel, callOptions); + } + }; + return OpenRcaServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -191,6 +203,42 @@ public void streamCoreMetrics(com.github.xds.service.orca.v3.OrcaLoadReportReque * a new call to change backend reporting frequency. * */ + public static final class OpenRcaServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private OpenRcaServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected OpenRcaServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new OpenRcaServiceBlockingV2Stub(channel, callOptions); + } + + /** + */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamCoreMetrics(com.github.xds.service.orca.v3.OrcaLoadReportRequest request) { + return io.grpc.stub.ClientCalls.blockingV2ServerStreamingCall( + getChannel(), getStreamCoreMetricsMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service OpenRcaService. + *
+   * Out-of-band (OOB) load reporting service for the additional load reporting
+   * agent that does not sit in the request path. Reports are periodically sampled
+   * with sufficient frequency to provide temporal association with requests.
+   * OOB reporting compensates the limitation of in-band reporting in revealing
+   * costs for backends that do not provide a steady stream of telemetry such as
+   * long running stream operations and zero QPS services. This is a server
+   * streaming service, client needs to terminate current RPC and initiate
+   * a new call to change backend reporting frequency.
+   * 
+ */ public static final class OpenRcaServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private OpenRcaServiceBlockingStub( diff --git a/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/auth/v3/AuthorizationGrpc.java b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/auth/v3/AuthorizationGrpc.java new file mode 100644 index 00000000000..df9b7a3514b --- /dev/null +++ b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/auth/v3/AuthorizationGrpc.java @@ -0,0 +1,377 @@ +package io.envoyproxy.envoy.service.auth.v3; + +import static io.grpc.MethodDescriptor.generateFullMethodName; + +/** + *
+ * A generic interface for performing authorization check on incoming
+ * requests to a networked service.
+ * 
+ */ +@io.grpc.stub.annotations.GrpcGenerated +public final class AuthorizationGrpc { + + private AuthorizationGrpc() {} + + public static final java.lang.String SERVICE_NAME = "envoy.service.auth.v3.Authorization"; + + // Static method descriptors that strictly reflect the proto. + private static volatile io.grpc.MethodDescriptor getCheckMethod; + + @io.grpc.stub.annotations.RpcMethod( + fullMethodName = SERVICE_NAME + '/' + "Check", + requestType = io.envoyproxy.envoy.service.auth.v3.CheckRequest.class, + responseType = io.envoyproxy.envoy.service.auth.v3.CheckResponse.class, + methodType = io.grpc.MethodDescriptor.MethodType.UNARY) + public static io.grpc.MethodDescriptor getCheckMethod() { + io.grpc.MethodDescriptor getCheckMethod; + if ((getCheckMethod = AuthorizationGrpc.getCheckMethod) == null) { + synchronized (AuthorizationGrpc.class) { + if ((getCheckMethod = AuthorizationGrpc.getCheckMethod) == null) { + AuthorizationGrpc.getCheckMethod = getCheckMethod = + io.grpc.MethodDescriptor.newBuilder() + .setType(io.grpc.MethodDescriptor.MethodType.UNARY) + .setFullMethodName(generateFullMethodName(SERVICE_NAME, "Check")) + .setSampledToLocalTracing(true) + .setRequestMarshaller(io.grpc.protobuf.ProtoUtils.marshaller( + io.envoyproxy.envoy.service.auth.v3.CheckRequest.getDefaultInstance())) + .setResponseMarshaller(io.grpc.protobuf.ProtoUtils.marshaller( + io.envoyproxy.envoy.service.auth.v3.CheckResponse.getDefaultInstance())) + .setSchemaDescriptor(new AuthorizationMethodDescriptorSupplier("Check")) + .build(); + } + } + } + return getCheckMethod; + } + + /** + * Creates a new async stub that supports all call types for the service + */ + public static AuthorizationStub newStub(io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public AuthorizationStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new AuthorizationStub(channel, callOptions); + } + }; + return AuthorizationStub.newStub(factory, channel); + } + + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static AuthorizationBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public AuthorizationBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new AuthorizationBlockingV2Stub(channel, callOptions); + } + }; + return AuthorizationBlockingV2Stub.newStub(factory, channel); + } + + /** + * Creates a new blocking-style stub that supports unary and streaming output calls on the service + */ + public static AuthorizationBlockingStub newBlockingStub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public AuthorizationBlockingStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new AuthorizationBlockingStub(channel, callOptions); + } + }; + return AuthorizationBlockingStub.newStub(factory, channel); + } + + /** + * Creates a new ListenableFuture-style stub that supports unary calls on the service + */ + public static AuthorizationFutureStub newFutureStub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public AuthorizationFutureStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new AuthorizationFutureStub(channel, callOptions); + } + }; + return AuthorizationFutureStub.newStub(factory, channel); + } + + /** + *
+   * A generic interface for performing authorization check on incoming
+   * requests to a networked service.
+   * 
+ */ + public interface AsyncService { + + /** + *
+     * Performs authorization check based on the attributes associated with the
+     * incoming request, and returns status `OK` or not `OK`.
+     * 
+ */ + default void check(io.envoyproxy.envoy.service.auth.v3.CheckRequest request, + io.grpc.stub.StreamObserver responseObserver) { + io.grpc.stub.ServerCalls.asyncUnimplementedUnaryCall(getCheckMethod(), responseObserver); + } + } + + /** + * Base class for the server implementation of the service Authorization. + *
+   * A generic interface for performing authorization check on incoming
+   * requests to a networked service.
+   * 
+ */ + public static abstract class AuthorizationImplBase + implements io.grpc.BindableService, AsyncService { + + @java.lang.Override public final io.grpc.ServerServiceDefinition bindService() { + return AuthorizationGrpc.bindService(this); + } + } + + /** + * A stub to allow clients to do asynchronous rpc calls to service Authorization. + *
+   * A generic interface for performing authorization check on incoming
+   * requests to a networked service.
+   * 
+ */ + public static final class AuthorizationStub + extends io.grpc.stub.AbstractAsyncStub { + private AuthorizationStub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected AuthorizationStub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new AuthorizationStub(channel, callOptions); + } + + /** + *
+     * Performs authorization check based on the attributes associated with the
+     * incoming request, and returns status `OK` or not `OK`.
+     * 
+ */ + public void check(io.envoyproxy.envoy.service.auth.v3.CheckRequest request, + io.grpc.stub.StreamObserver responseObserver) { + io.grpc.stub.ClientCalls.asyncUnaryCall( + getChannel().newCall(getCheckMethod(), getCallOptions()), request, responseObserver); + } + } + + /** + * A stub to allow clients to do synchronous rpc calls to service Authorization. + *
+   * A generic interface for performing authorization check on incoming
+   * requests to a networked service.
+   * 
+ */ + public static final class AuthorizationBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private AuthorizationBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected AuthorizationBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new AuthorizationBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * Performs authorization check based on the attributes associated with the
+     * incoming request, and returns status `OK` or not `OK`.
+     * 
+ */ + public io.envoyproxy.envoy.service.auth.v3.CheckResponse check(io.envoyproxy.envoy.service.auth.v3.CheckRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getCheckMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service Authorization. + *
+   * A generic interface for performing authorization check on incoming
+   * requests to a networked service.
+   * 
+ */ + public static final class AuthorizationBlockingStub + extends io.grpc.stub.AbstractBlockingStub { + private AuthorizationBlockingStub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected AuthorizationBlockingStub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new AuthorizationBlockingStub(channel, callOptions); + } + + /** + *
+     * Performs authorization check based on the attributes associated with the
+     * incoming request, and returns status `OK` or not `OK`.
+     * 
+ */ + public io.envoyproxy.envoy.service.auth.v3.CheckResponse check(io.envoyproxy.envoy.service.auth.v3.CheckRequest request) { + return io.grpc.stub.ClientCalls.blockingUnaryCall( + getChannel(), getCheckMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do ListenableFuture-style rpc calls to service Authorization. + *
+   * A generic interface for performing authorization check on incoming
+   * requests to a networked service.
+   * 
+ */ + public static final class AuthorizationFutureStub + extends io.grpc.stub.AbstractFutureStub { + private AuthorizationFutureStub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected AuthorizationFutureStub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new AuthorizationFutureStub(channel, callOptions); + } + + /** + *
+     * Performs authorization check based on the attributes associated with the
+     * incoming request, and returns status `OK` or not `OK`.
+     * 
+ */ + public com.google.common.util.concurrent.ListenableFuture check( + io.envoyproxy.envoy.service.auth.v3.CheckRequest request) { + return io.grpc.stub.ClientCalls.futureUnaryCall( + getChannel().newCall(getCheckMethod(), getCallOptions()), request); + } + } + + private static final int METHODID_CHECK = 0; + + private static final class MethodHandlers implements + io.grpc.stub.ServerCalls.UnaryMethod, + io.grpc.stub.ServerCalls.ServerStreamingMethod, + io.grpc.stub.ServerCalls.ClientStreamingMethod, + io.grpc.stub.ServerCalls.BidiStreamingMethod { + private final AsyncService serviceImpl; + private final int methodId; + + MethodHandlers(AsyncService serviceImpl, int methodId) { + this.serviceImpl = serviceImpl; + this.methodId = methodId; + } + + @java.lang.Override + @java.lang.SuppressWarnings("unchecked") + public void invoke(Req request, io.grpc.stub.StreamObserver responseObserver) { + switch (methodId) { + case METHODID_CHECK: + serviceImpl.check((io.envoyproxy.envoy.service.auth.v3.CheckRequest) request, + (io.grpc.stub.StreamObserver) responseObserver); + break; + default: + throw new AssertionError(); + } + } + + @java.lang.Override + @java.lang.SuppressWarnings("unchecked") + public io.grpc.stub.StreamObserver invoke( + io.grpc.stub.StreamObserver responseObserver) { + switch (methodId) { + default: + throw new AssertionError(); + } + } + } + + public static final io.grpc.ServerServiceDefinition bindService(AsyncService service) { + return io.grpc.ServerServiceDefinition.builder(getServiceDescriptor()) + .addMethod( + getCheckMethod(), + io.grpc.stub.ServerCalls.asyncUnaryCall( + new MethodHandlers< + io.envoyproxy.envoy.service.auth.v3.CheckRequest, + io.envoyproxy.envoy.service.auth.v3.CheckResponse>( + service, METHODID_CHECK))) + .build(); + } + + private static abstract class AuthorizationBaseDescriptorSupplier + implements io.grpc.protobuf.ProtoFileDescriptorSupplier, io.grpc.protobuf.ProtoServiceDescriptorSupplier { + AuthorizationBaseDescriptorSupplier() {} + + @java.lang.Override + public com.google.protobuf.Descriptors.FileDescriptor getFileDescriptor() { + return io.envoyproxy.envoy.service.auth.v3.ExternalAuthProto.getDescriptor(); + } + + @java.lang.Override + public com.google.protobuf.Descriptors.ServiceDescriptor getServiceDescriptor() { + return getFileDescriptor().findServiceByName("Authorization"); + } + } + + private static final class AuthorizationFileDescriptorSupplier + extends AuthorizationBaseDescriptorSupplier { + AuthorizationFileDescriptorSupplier() {} + } + + private static final class AuthorizationMethodDescriptorSupplier + extends AuthorizationBaseDescriptorSupplier + implements io.grpc.protobuf.ProtoMethodDescriptorSupplier { + private final java.lang.String methodName; + + AuthorizationMethodDescriptorSupplier(java.lang.String methodName) { + this.methodName = methodName; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.MethodDescriptor getMethodDescriptor() { + return getServiceDescriptor().findMethodByName(methodName); + } + } + + private static volatile io.grpc.ServiceDescriptor serviceDescriptor; + + public static io.grpc.ServiceDescriptor getServiceDescriptor() { + io.grpc.ServiceDescriptor result = serviceDescriptor; + if (result == null) { + synchronized (AuthorizationGrpc.class) { + result = serviceDescriptor; + if (result == null) { + serviceDescriptor = result = io.grpc.ServiceDescriptor.newBuilder(SERVICE_NAME) + .setSchemaDescriptor(new AuthorizationFileDescriptorSupplier()) + .addMethod(getCheckMethod()) + .build(); + } + } + } + return result; + } +} diff --git a/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/discovery/v3/AggregatedDiscoveryServiceGrpc.java b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/discovery/v3/AggregatedDiscoveryServiceGrpc.java index e039c2193e8..94b2fd86b96 100644 --- a/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/discovery/v3/AggregatedDiscoveryServiceGrpc.java +++ b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/discovery/v3/AggregatedDiscoveryServiceGrpc.java @@ -12,9 +12,6 @@ * the multiplexed singleton APIs at the Envoy instance and management server. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: envoy/service/discovery/v3/ads.proto") @io.grpc.stub.annotations.GrpcGenerated public final class AggregatedDiscoveryServiceGrpc { @@ -99,6 +96,21 @@ public AggregatedDiscoveryServiceStub newStub(io.grpc.Channel channel, io.grpc.C return AggregatedDiscoveryServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static AggregatedDiscoveryServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public AggregatedDiscoveryServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new AggregatedDiscoveryServiceBlockingV2Stub(channel, callOptions); + } + }; + return AggregatedDiscoveryServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -233,6 +245,52 @@ public io.grpc.stub.StreamObserver */ + public static final class AggregatedDiscoveryServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private AggregatedDiscoveryServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected AggregatedDiscoveryServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new AggregatedDiscoveryServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * This is a gRPC-only API.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamAggregatedResources() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getStreamAggregatedResourcesMethod(), getCallOptions()); + } + + /** + */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + deltaAggregatedResources() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getDeltaAggregatedResourcesMethod(), getCallOptions()); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service AggregatedDiscoveryService. + *
+   * See https://github.com/envoyproxy/envoy-api#apis for a description of the role of
+   * ADS and how it is intended to be used by a management server. ADS requests
+   * have the same structure as their singleton xDS counterparts, but can
+   * multiplex many resource types on a single stream. The type_url in the
+   * DiscoveryRequest/DiscoveryResponse provides sufficient information to recover
+   * the multiplexed singleton APIs at the Envoy instance and management server.
+   * 
+ */ public static final class AggregatedDiscoveryServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private AggregatedDiscoveryServiceBlockingStub( diff --git a/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/load_stats/v3/LoadReportingServiceGrpc.java b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/load_stats/v3/LoadReportingServiceGrpc.java index 2adbf02e98a..4f12405be87 100644 --- a/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/load_stats/v3/LoadReportingServiceGrpc.java +++ b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/load_stats/v3/LoadReportingServiceGrpc.java @@ -4,9 +4,6 @@ /** */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: envoy/service/load_stats/v3/lrs.proto") @io.grpc.stub.annotations.GrpcGenerated public final class LoadReportingServiceGrpc { @@ -60,6 +57,21 @@ public LoadReportingServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOpt return LoadReportingServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static LoadReportingServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public LoadReportingServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new LoadReportingServiceBlockingV2Stub(channel, callOptions); + } + }; + return LoadReportingServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -199,6 +211,61 @@ public io.grpc.stub.StreamObserver { + private LoadReportingServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected LoadReportingServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new LoadReportingServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * Advanced API to allow for multi-dimensional load balancing by remote
+     * server. For receiving LB assignments, the steps are:
+     * 1, The management server is configured with per cluster/zone/load metric
+     *    capacity configuration. The capacity configuration definition is
+     *    outside of the scope of this document.
+     * 2. Envoy issues a standard {Stream,Fetch}Endpoints request for the clusters
+     *    to balance.
+     * Independently, Envoy will initiate a StreamLoadStats bidi stream with a
+     * management server:
+     * 1. Once a connection establishes, the management server publishes a
+     *    LoadStatsResponse for all clusters it is interested in learning load
+     *    stats about.
+     * 2. For each cluster, Envoy load balances incoming traffic to upstream hosts
+     *    based on per-zone weights and/or per-instance weights (if specified)
+     *    based on intra-zone LbPolicy. This information comes from the above
+     *    {Stream,Fetch}Endpoints.
+     * 3. When upstream hosts reply, they optionally add header <define header
+     *    name> with ASCII representation of EndpointLoadMetricStats.
+     * 4. Envoy aggregates load reports over the period of time given to it in
+     *    LoadStatsResponse.load_reporting_interval. This includes aggregation
+     *    stats Envoy maintains by itself (total_requests, rpc_errors etc.) as
+     *    well as load metrics from upstream hosts.
+     * 5. When the timer of load_reporting_interval expires, Envoy sends new
+     *    LoadStatsRequest filled with load reports for each cluster.
+     * 6. The management server uses the load reports from all reported Envoys
+     *    from around the world, computes global assignment and prepares traffic
+     *    assignment destined for each zone Envoys are located in. Goto 2.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamLoadStats() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getStreamLoadStatsMethod(), getCallOptions()); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service LoadReportingService. + */ public static final class LoadReportingServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private LoadReportingServiceBlockingStub( diff --git a/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/rate_limit_quota/v3/RateLimitQuotaServiceGrpc.java b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/rate_limit_quota/v3/RateLimitQuotaServiceGrpc.java new file mode 100644 index 00000000000..3f17bb54566 --- /dev/null +++ b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/rate_limit_quota/v3/RateLimitQuotaServiceGrpc.java @@ -0,0 +1,348 @@ +package io.envoyproxy.envoy.service.rate_limit_quota.v3; + +import static io.grpc.MethodDescriptor.generateFullMethodName; + +/** + *
+ * Defines the Rate Limit Quota Service (RLQS).
+ * 
+ */ +@io.grpc.stub.annotations.GrpcGenerated +public final class RateLimitQuotaServiceGrpc { + + private RateLimitQuotaServiceGrpc() {} + + public static final java.lang.String SERVICE_NAME = "envoy.service.rate_limit_quota.v3.RateLimitQuotaService"; + + // Static method descriptors that strictly reflect the proto. + private static volatile io.grpc.MethodDescriptor getStreamRateLimitQuotasMethod; + + @io.grpc.stub.annotations.RpcMethod( + fullMethodName = SERVICE_NAME + '/' + "StreamRateLimitQuotas", + requestType = io.envoyproxy.envoy.service.rate_limit_quota.v3.RateLimitQuotaUsageReports.class, + responseType = io.envoyproxy.envoy.service.rate_limit_quota.v3.RateLimitQuotaResponse.class, + methodType = io.grpc.MethodDescriptor.MethodType.BIDI_STREAMING) + public static io.grpc.MethodDescriptor getStreamRateLimitQuotasMethod() { + io.grpc.MethodDescriptor getStreamRateLimitQuotasMethod; + if ((getStreamRateLimitQuotasMethod = RateLimitQuotaServiceGrpc.getStreamRateLimitQuotasMethod) == null) { + synchronized (RateLimitQuotaServiceGrpc.class) { + if ((getStreamRateLimitQuotasMethod = RateLimitQuotaServiceGrpc.getStreamRateLimitQuotasMethod) == null) { + RateLimitQuotaServiceGrpc.getStreamRateLimitQuotasMethod = getStreamRateLimitQuotasMethod = + io.grpc.MethodDescriptor.newBuilder() + .setType(io.grpc.MethodDescriptor.MethodType.BIDI_STREAMING) + .setFullMethodName(generateFullMethodName(SERVICE_NAME, "StreamRateLimitQuotas")) + .setSampledToLocalTracing(true) + .setRequestMarshaller(io.grpc.protobuf.ProtoUtils.marshaller( + io.envoyproxy.envoy.service.rate_limit_quota.v3.RateLimitQuotaUsageReports.getDefaultInstance())) + .setResponseMarshaller(io.grpc.protobuf.ProtoUtils.marshaller( + io.envoyproxy.envoy.service.rate_limit_quota.v3.RateLimitQuotaResponse.getDefaultInstance())) + .setSchemaDescriptor(new RateLimitQuotaServiceMethodDescriptorSupplier("StreamRateLimitQuotas")) + .build(); + } + } + } + return getStreamRateLimitQuotasMethod; + } + + /** + * Creates a new async stub that supports all call types for the service + */ + public static RateLimitQuotaServiceStub newStub(io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public RateLimitQuotaServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new RateLimitQuotaServiceStub(channel, callOptions); + } + }; + return RateLimitQuotaServiceStub.newStub(factory, channel); + } + + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static RateLimitQuotaServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public RateLimitQuotaServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new RateLimitQuotaServiceBlockingV2Stub(channel, callOptions); + } + }; + return RateLimitQuotaServiceBlockingV2Stub.newStub(factory, channel); + } + + /** + * Creates a new blocking-style stub that supports unary and streaming output calls on the service + */ + public static RateLimitQuotaServiceBlockingStub newBlockingStub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public RateLimitQuotaServiceBlockingStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new RateLimitQuotaServiceBlockingStub(channel, callOptions); + } + }; + return RateLimitQuotaServiceBlockingStub.newStub(factory, channel); + } + + /** + * Creates a new ListenableFuture-style stub that supports unary calls on the service + */ + public static RateLimitQuotaServiceFutureStub newFutureStub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public RateLimitQuotaServiceFutureStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new RateLimitQuotaServiceFutureStub(channel, callOptions); + } + }; + return RateLimitQuotaServiceFutureStub.newStub(factory, channel); + } + + /** + *
+   * Defines the Rate Limit Quota Service (RLQS).
+   * 
+ */ + public interface AsyncService { + + /** + *
+     * Main communication channel: the data plane sends usage reports to the RLQS server,
+     * and the server asynchronously responding with the assignments.
+     * 
+ */ + default io.grpc.stub.StreamObserver streamRateLimitQuotas( + io.grpc.stub.StreamObserver responseObserver) { + return io.grpc.stub.ServerCalls.asyncUnimplementedStreamingCall(getStreamRateLimitQuotasMethod(), responseObserver); + } + } + + /** + * Base class for the server implementation of the service RateLimitQuotaService. + *
+   * Defines the Rate Limit Quota Service (RLQS).
+   * 
+ */ + public static abstract class RateLimitQuotaServiceImplBase + implements io.grpc.BindableService, AsyncService { + + @java.lang.Override public final io.grpc.ServerServiceDefinition bindService() { + return RateLimitQuotaServiceGrpc.bindService(this); + } + } + + /** + * A stub to allow clients to do asynchronous rpc calls to service RateLimitQuotaService. + *
+   * Defines the Rate Limit Quota Service (RLQS).
+   * 
+ */ + public static final class RateLimitQuotaServiceStub + extends io.grpc.stub.AbstractAsyncStub { + private RateLimitQuotaServiceStub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected RateLimitQuotaServiceStub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new RateLimitQuotaServiceStub(channel, callOptions); + } + + /** + *
+     * Main communication channel: the data plane sends usage reports to the RLQS server,
+     * and the server asynchronously responding with the assignments.
+     * 
+ */ + public io.grpc.stub.StreamObserver streamRateLimitQuotas( + io.grpc.stub.StreamObserver responseObserver) { + return io.grpc.stub.ClientCalls.asyncBidiStreamingCall( + getChannel().newCall(getStreamRateLimitQuotasMethod(), getCallOptions()), responseObserver); + } + } + + /** + * A stub to allow clients to do synchronous rpc calls to service RateLimitQuotaService. + *
+   * Defines the Rate Limit Quota Service (RLQS).
+   * 
+ */ + public static final class RateLimitQuotaServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private RateLimitQuotaServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected RateLimitQuotaServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new RateLimitQuotaServiceBlockingV2Stub(channel, callOptions); + } + + /** + *
+     * Main communication channel: the data plane sends usage reports to the RLQS server,
+     * and the server asynchronously responding with the assignments.
+     * 
+ */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamRateLimitQuotas() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getStreamRateLimitQuotasMethod(), getCallOptions()); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service RateLimitQuotaService. + *
+   * Defines the Rate Limit Quota Service (RLQS).
+   * 
+ */ + public static final class RateLimitQuotaServiceBlockingStub + extends io.grpc.stub.AbstractBlockingStub { + private RateLimitQuotaServiceBlockingStub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected RateLimitQuotaServiceBlockingStub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new RateLimitQuotaServiceBlockingStub(channel, callOptions); + } + } + + /** + * A stub to allow clients to do ListenableFuture-style rpc calls to service RateLimitQuotaService. + *
+   * Defines the Rate Limit Quota Service (RLQS).
+   * 
+ */ + public static final class RateLimitQuotaServiceFutureStub + extends io.grpc.stub.AbstractFutureStub { + private RateLimitQuotaServiceFutureStub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected RateLimitQuotaServiceFutureStub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new RateLimitQuotaServiceFutureStub(channel, callOptions); + } + } + + private static final int METHODID_STREAM_RATE_LIMIT_QUOTAS = 0; + + private static final class MethodHandlers implements + io.grpc.stub.ServerCalls.UnaryMethod, + io.grpc.stub.ServerCalls.ServerStreamingMethod, + io.grpc.stub.ServerCalls.ClientStreamingMethod, + io.grpc.stub.ServerCalls.BidiStreamingMethod { + private final AsyncService serviceImpl; + private final int methodId; + + MethodHandlers(AsyncService serviceImpl, int methodId) { + this.serviceImpl = serviceImpl; + this.methodId = methodId; + } + + @java.lang.Override + @java.lang.SuppressWarnings("unchecked") + public void invoke(Req request, io.grpc.stub.StreamObserver responseObserver) { + switch (methodId) { + default: + throw new AssertionError(); + } + } + + @java.lang.Override + @java.lang.SuppressWarnings("unchecked") + public io.grpc.stub.StreamObserver invoke( + io.grpc.stub.StreamObserver responseObserver) { + switch (methodId) { + case METHODID_STREAM_RATE_LIMIT_QUOTAS: + return (io.grpc.stub.StreamObserver) serviceImpl.streamRateLimitQuotas( + (io.grpc.stub.StreamObserver) responseObserver); + default: + throw new AssertionError(); + } + } + } + + public static final io.grpc.ServerServiceDefinition bindService(AsyncService service) { + return io.grpc.ServerServiceDefinition.builder(getServiceDescriptor()) + .addMethod( + getStreamRateLimitQuotasMethod(), + io.grpc.stub.ServerCalls.asyncBidiStreamingCall( + new MethodHandlers< + io.envoyproxy.envoy.service.rate_limit_quota.v3.RateLimitQuotaUsageReports, + io.envoyproxy.envoy.service.rate_limit_quota.v3.RateLimitQuotaResponse>( + service, METHODID_STREAM_RATE_LIMIT_QUOTAS))) + .build(); + } + + private static abstract class RateLimitQuotaServiceBaseDescriptorSupplier + implements io.grpc.protobuf.ProtoFileDescriptorSupplier, io.grpc.protobuf.ProtoServiceDescriptorSupplier { + RateLimitQuotaServiceBaseDescriptorSupplier() {} + + @java.lang.Override + public com.google.protobuf.Descriptors.FileDescriptor getFileDescriptor() { + return io.envoyproxy.envoy.service.rate_limit_quota.v3.RlqsProto.getDescriptor(); + } + + @java.lang.Override + public com.google.protobuf.Descriptors.ServiceDescriptor getServiceDescriptor() { + return getFileDescriptor().findServiceByName("RateLimitQuotaService"); + } + } + + private static final class RateLimitQuotaServiceFileDescriptorSupplier + extends RateLimitQuotaServiceBaseDescriptorSupplier { + RateLimitQuotaServiceFileDescriptorSupplier() {} + } + + private static final class RateLimitQuotaServiceMethodDescriptorSupplier + extends RateLimitQuotaServiceBaseDescriptorSupplier + implements io.grpc.protobuf.ProtoMethodDescriptorSupplier { + private final java.lang.String methodName; + + RateLimitQuotaServiceMethodDescriptorSupplier(java.lang.String methodName) { + this.methodName = methodName; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.MethodDescriptor getMethodDescriptor() { + return getServiceDescriptor().findMethodByName(methodName); + } + } + + private static volatile io.grpc.ServiceDescriptor serviceDescriptor; + + public static io.grpc.ServiceDescriptor getServiceDescriptor() { + io.grpc.ServiceDescriptor result = serviceDescriptor; + if (result == null) { + synchronized (RateLimitQuotaServiceGrpc.class) { + result = serviceDescriptor; + if (result == null) { + serviceDescriptor = result = io.grpc.ServiceDescriptor.newBuilder(SERVICE_NAME) + .setSchemaDescriptor(new RateLimitQuotaServiceFileDescriptorSupplier()) + .addMethod(getStreamRateLimitQuotasMethod()) + .build(); + } + } + } + return result; + } +} diff --git a/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/status/v3/ClientStatusDiscoveryServiceGrpc.java b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/status/v3/ClientStatusDiscoveryServiceGrpc.java index 3f8874248d0..cb166503566 100644 --- a/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/status/v3/ClientStatusDiscoveryServiceGrpc.java +++ b/xds/src/generated/thirdparty/grpc/io/envoyproxy/envoy/service/status/v3/ClientStatusDiscoveryServiceGrpc.java @@ -9,9 +9,6 @@ * also be used to get the current xDS states directly from the client. * */ -@javax.annotation.Generated( - value = "by gRPC proto compiler", - comments = "Source: envoy/service/status/v3/csds.proto") @io.grpc.stub.annotations.GrpcGenerated public final class ClientStatusDiscoveryServiceGrpc { @@ -96,6 +93,21 @@ public ClientStatusDiscoveryServiceStub newStub(io.grpc.Channel channel, io.grpc return ClientStatusDiscoveryServiceStub.newStub(factory, channel); } + /** + * Creates a new blocking-style stub that supports all types of calls on the service + */ + public static ClientStatusDiscoveryServiceBlockingV2Stub newBlockingV2Stub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public ClientStatusDiscoveryServiceBlockingV2Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ClientStatusDiscoveryServiceBlockingV2Stub(channel, callOptions); + } + }; + return ClientStatusDiscoveryServiceBlockingV2Stub.newStub(factory, channel); + } + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -212,6 +224,44 @@ public void fetchClientStatus(io.envoyproxy.envoy.service.status.v3.ClientStatus * also be used to get the current xDS states directly from the client. * */ + public static final class ClientStatusDiscoveryServiceBlockingV2Stub + extends io.grpc.stub.AbstractBlockingStub { + private ClientStatusDiscoveryServiceBlockingV2Stub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected ClientStatusDiscoveryServiceBlockingV2Stub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new ClientStatusDiscoveryServiceBlockingV2Stub(channel, callOptions); + } + + /** + */ + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/10918") + public io.grpc.stub.BlockingClientCall + streamClientStatus() { + return io.grpc.stub.ClientCalls.blockingBidiStreamingCall( + getChannel(), getStreamClientStatusMethod(), getCallOptions()); + } + + /** + */ + public io.envoyproxy.envoy.service.status.v3.ClientStatusResponse fetchClientStatus(io.envoyproxy.envoy.service.status.v3.ClientStatusRequest request) throws io.grpc.StatusException { + return io.grpc.stub.ClientCalls.blockingV2UnaryCall( + getChannel(), getFetchClientStatusMethod(), getCallOptions(), request); + } + } + + /** + * A stub to allow clients to do limited synchronous rpc calls to service ClientStatusDiscoveryService. + *
+   * CSDS is Client Status Discovery Service. It can be used to get the status of
+   * an xDS-compliant client from the management server's point of view. It can
+   * also be used to get the current xDS states directly from the client.
+   * 
+ */ public static final class ClientStatusDiscoveryServiceBlockingStub extends io.grpc.stub.AbstractBlockingStub { private ClientStatusDiscoveryServiceBlockingStub( diff --git a/xds/src/main/java/io/grpc/xds/AddressFilter.java b/xds/src/main/java/io/grpc/xds/AddressFilter.java index 841e96d06bb..8008f638565 100644 --- a/xds/src/main/java/io/grpc/xds/AddressFilter.java +++ b/xds/src/main/java/io/grpc/xds/AddressFilter.java @@ -24,11 +24,12 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.ListIterator; import javax.annotation.Nullable; final class AddressFilter { @ResolutionResultAttr - private static final Attributes.Key PATH_CHAIN_KEY = + static final Attributes.Key PATH_CHAIN_KEY = Attributes.Key.create("io.grpc.xds.AddressFilter.PATH_CHAIN_KEY"); // Prevent instantiation. @@ -41,19 +42,25 @@ private AddressFilter() {} static EquivalentAddressGroup setPathFilter(EquivalentAddressGroup address, List names) { checkNotNull(address, "address"); checkNotNull(names, "names"); - Attributes.Builder attrBuilder = address.getAttributes().toBuilder().discard(PATH_CHAIN_KEY); - PathChain pathChain = null; - for (String name : names) { - if (pathChain == null) { - pathChain = new PathChain(name); - attrBuilder.set(PATH_CHAIN_KEY, pathChain); - } else { - pathChain.next = new PathChain(name); - } - } + Attributes.Builder attrBuilder = address.getAttributes().toBuilder() + .set(PATH_CHAIN_KEY, createPathChain(names)); return new EquivalentAddressGroup(address.getAddresses(), attrBuilder.build()); } + /** + * Creates a PathChain that can be set in an EquivalentAddressGroup's Attributes as a value of + * PATH_CHAIN_KEY. + */ + @Nullable static PathChain createPathChain(List names) { + checkNotNull(names, "names"); + PathChain current = null; + ListIterator iter = names.listIterator(names.size()); + while (iter.hasPrevious()) { + current = new PathChain(iter.previous(), current); + } + return current; + } + /** * Returns the next level hierarchical addresses derived from the given hierarchical addresses * with the given filter name (any non-hierarchical addresses in the input will be ignored). @@ -75,12 +82,13 @@ static List filter(List addresse return Collections.unmodifiableList(filteredAddresses); } - private static final class PathChain { + static final class PathChain { final String name; - @Nullable PathChain next; + @Nullable final PathChain next; - PathChain(String name) { + PathChain(String name, @Nullable PathChain next) { this.name = checkNotNull(name, "name"); + this.next = next; } @Override diff --git a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java index fe73e1886f3..a52c6cba01b 100644 --- a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java +++ b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java @@ -18,41 +18,58 @@ import static com.google.common.base.Preconditions.checkNotNull; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; -import static io.grpc.xds.XdsLbPolicies.CLUSTER_RESOLVER_POLICY_NAME; - -import com.google.common.annotations.VisibleForTesting; +import static io.grpc.xds.XdsLbPolicies.CDS_POLICY_NAME; +import static io.grpc.xds.XdsLbPolicies.PRIORITY_POLICY_NAME; + +import com.google.common.collect.ImmutableMap; +import com.google.common.primitives.UnsignedInts; +import com.google.errorprone.annotations.CheckReturnValue; +import io.grpc.Attributes; +import io.grpc.EquivalentAddressGroup; +import io.grpc.HttpConnectProxiedSocketAddress; +import io.grpc.InternalEquivalentAddressGroup; import io.grpc.InternalLogId; import io.grpc.LoadBalancer; import io.grpc.LoadBalancerProvider; import io.grpc.LoadBalancerRegistry; import io.grpc.NameResolver; import io.grpc.Status; -import io.grpc.SynchronizationContext; -import io.grpc.internal.ObjectPool; -import io.grpc.internal.ServiceConfigUtil; -import io.grpc.internal.ServiceConfigUtil.LbConfig; -import io.grpc.internal.ServiceConfigUtil.PolicySelection; +import io.grpc.StatusOr; +import io.grpc.internal.GrpcUtil; +import io.grpc.util.GracefulSwitchLoadBalancer; +import io.grpc.util.OutlierDetectionLoadBalancer.OutlierDetectionLoadBalancerConfig; import io.grpc.xds.CdsLoadBalancerProvider.CdsConfig; -import io.grpc.xds.ClusterResolverLoadBalancerProvider.ClusterResolverConfig; -import io.grpc.xds.ClusterResolverLoadBalancerProvider.ClusterResolverConfig.DiscoveryMechanism; +import io.grpc.xds.ClusterImplLoadBalancerProvider.ClusterImplConfig; +import io.grpc.xds.Endpoints.DropOverload; +import io.grpc.xds.Endpoints.LbEndpoint; +import io.grpc.xds.Endpoints.LocalityLbEndpoints; +import io.grpc.xds.EnvoyServerProtoData.FailurePercentageEjection; +import io.grpc.xds.EnvoyServerProtoData.OutlierDetection; +import io.grpc.xds.EnvoyServerProtoData.SuccessRateEjection; +import io.grpc.xds.PriorityLoadBalancerProvider.PriorityLbConfig; +import io.grpc.xds.PriorityLoadBalancerProvider.PriorityLbConfig.PriorityChildConfig; import io.grpc.xds.XdsClusterResource.CdsUpdate; import io.grpc.xds.XdsClusterResource.CdsUpdate.ClusterType; -import io.grpc.xds.client.XdsClient; -import io.grpc.xds.client.XdsClient.ResourceWatcher; +import io.grpc.xds.XdsConfig.Subscription; +import io.grpc.xds.XdsConfig.XdsClusterConfig; +import io.grpc.xds.XdsConfig.XdsClusterConfig.AggregateConfig; +import io.grpc.xds.XdsConfig.XdsClusterConfig.EndpointConfig; +import io.grpc.xds.XdsEndpointResource.EdsUpdate; +import io.grpc.xds.client.Locality; import io.grpc.xds.client.XdsLogger; import io.grpc.xds.client.XdsLogger.XdsLogLevel; -import java.util.ArrayDeque; +import io.grpc.xds.internal.XdsInternalAttributes; +import java.net.InetSocketAddress; +import java.net.SocketAddress; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.Queue; import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import javax.annotation.Nullable; +import java.util.TreeMap; /** * Load balancer for cds_experimental LB policy. One instance per top-level cluster. @@ -60,50 +77,128 @@ * by a group of sub-clusters in a tree hierarchy. */ final class CdsLoadBalancer2 extends LoadBalancer { + static boolean pickFirstWeightedShuffling = + GrpcUtil.getFlag("GRPC_EXPERIMENTAL_PF_WEIGHTED_SHUFFLING", true); + private final XdsLogger logger; private final Helper helper; - private final SynchronizationContext syncContext; private final LoadBalancerRegistry lbRegistry; + private final ClusterState clusterState = new ClusterState(); + private GracefulSwitchLoadBalancer delegate; // Following fields are effectively final. - private ObjectPool xdsClientPool; - private XdsClient xdsClient; - private CdsLbState cdsLbState; - private ResolvedAddresses resolvedAddresses; - - CdsLoadBalancer2(Helper helper) { - this(helper, LoadBalancerRegistry.getDefaultRegistry()); - } + private String clusterName; + private Subscription clusterSubscription; - @VisibleForTesting CdsLoadBalancer2(Helper helper, LoadBalancerRegistry lbRegistry) { this.helper = checkNotNull(helper, "helper"); - this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext"); this.lbRegistry = checkNotNull(lbRegistry, "lbRegistry"); + this.delegate = new GracefulSwitchLoadBalancer(helper); logger = XdsLogger.withLogId(InternalLogId.allocate("cds-lb", helper.getAuthority())); logger.log(XdsLogLevel.INFO, "Created"); } @Override public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { - if (this.resolvedAddresses != null) { + logger.log(XdsLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses); + if (this.clusterName == null) { + CdsConfig config = (CdsConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); + logger.log(XdsLogLevel.INFO, "Config: {0}", config); + if (config.isDynamic) { + clusterSubscription = resolvedAddresses.getAttributes() + .get(XdsAttributes.XDS_CLUSTER_SUBSCRIPT_REGISTRY) + .subscribeToCluster(config.name); + } + this.clusterName = config.name; + } + XdsConfig xdsConfig = resolvedAddresses.getAttributes().get(XdsAttributes.XDS_CONFIG); + StatusOr clusterConfigOr = xdsConfig.getClusters().get(clusterName); + if (clusterConfigOr == null) { + if (clusterSubscription == null) { + // Should be impossible, because XdsDependencyManager wouldn't have generated this + return fail(Status.INTERNAL.withDescription( + errorPrefix() + "Unable to find non-dynamic cluster")); + } + // The dynamic cluster must not have loaded yet return Status.OK; } - logger.log(XdsLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses); - this.resolvedAddresses = resolvedAddresses; - xdsClientPool = resolvedAddresses.getAttributes().get(InternalXdsAttributes.XDS_CLIENT_POOL); - xdsClient = xdsClientPool.getObject(); - CdsConfig config = (CdsConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); - logger.log(XdsLogLevel.INFO, "Config: {0}", config); - cdsLbState = new CdsLbState(config.name); - cdsLbState.start(); - return Status.OK; + if (!clusterConfigOr.hasValue()) { + return fail(clusterConfigOr.getStatus()); + } + XdsClusterConfig clusterConfig = clusterConfigOr.getValue(); + + NameResolver.ConfigOrError configOrError; + if (clusterConfig.getChildren() instanceof EndpointConfig) { + // The LB policy config is provided in service_config.proto/JSON format. + configOrError = + GracefulSwitchLoadBalancer.parseLoadBalancingPolicyConfig( + Arrays.asList(clusterConfig.getClusterResource().lbPolicyConfig()), + lbRegistry); + if (configOrError.getError() != null) { + // Should be impossible, because XdsClusterResource validated this + return fail(Status.INTERNAL.withDescription( + errorPrefix() + "Unable to parse the LB config: " + configOrError.getError())); + } + + StatusOr edsUpdate = getEdsUpdate(xdsConfig, clusterName); + StatusOr statusOrResult = clusterState.edsUpdateToResult( + clusterName, + clusterConfig.getClusterResource(), + configOrError.getConfig(), + edsUpdate); + if (!statusOrResult.hasValue()) { + Status status = Status.UNAVAILABLE + .withDescription(statusOrResult.getStatus().getDescription()) + .withCause(statusOrResult.getStatus().getCause()); + delegate.handleNameResolutionError(status); + return status; + } + ClusterResolutionResult result = statusOrResult.getValue(); + List addresses = result.addresses; + if (addresses.isEmpty()) { + Status status = Status.UNAVAILABLE + .withDescription("No usable endpoint from cluster: " + clusterName); + delegate.handleNameResolutionError(status); + return status; + } + Object gracefulConfig = GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + lbRegistry.getProvider(PRIORITY_POLICY_NAME), + new PriorityLbConfig( + Collections.unmodifiableMap(result.priorityChildConfigs), + Collections.unmodifiableList(result.priorities))); + return delegate.acceptResolvedAddresses( + resolvedAddresses.toBuilder() + .setLoadBalancingPolicyConfig(gracefulConfig) + .setAddresses(Collections.unmodifiableList(addresses)) + .build()); + } else if (clusterConfig.getChildren() instanceof AggregateConfig) { + Map priorityChildConfigs = new HashMap<>(); + List leafClusters = ((AggregateConfig) clusterConfig.getChildren()).getLeafNames(); + for (String childCluster: leafClusters) { + priorityChildConfigs.put(childCluster, + new PriorityChildConfig( + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + lbRegistry.getProvider(CDS_POLICY_NAME), + new CdsConfig(childCluster)), + false)); + } + Object gracefulConfig = GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + lbRegistry.getProvider(PRIORITY_POLICY_NAME), + new PriorityLoadBalancerProvider.PriorityLbConfig( + Collections.unmodifiableMap(priorityChildConfigs), leafClusters)); + return delegate.acceptResolvedAddresses( + resolvedAddresses.toBuilder().setLoadBalancingPolicyConfig(gracefulConfig).build()); + } else { + return fail(Status.INTERNAL.withDescription( + errorPrefix() + "Unexpected cluster children type: " + + clusterConfig.getChildren().getClass())); + } } @Override public void handleNameResolutionError(Status error) { logger.log(XdsLogLevel.WARNING, "Received name resolution error: {0}", error); - if (cdsLbState != null && cdsLbState.childLb != null) { - cdsLbState.childLb.handleNameResolutionError(error); + if (delegate != null) { + delegate.handleNameResolutionError(error); } else { helper.updateBalancingState( TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(error))); @@ -113,315 +208,413 @@ public void handleNameResolutionError(Status error) { @Override public void shutdown() { logger.log(XdsLogLevel.INFO, "Shutdown"); - if (cdsLbState != null) { - cdsLbState.shutdown(); - } - if (xdsClientPool != null) { - xdsClientPool.returnObject(xdsClient); + delegate.shutdown(); + delegate = new GracefulSwitchLoadBalancer(helper); + if (clusterSubscription != null) { + clusterSubscription.close(); + clusterSubscription = null; } } + @CheckReturnValue // don't forget to return up the stack after the fail call + private Status fail(Status error) { + delegate.shutdown(); + helper.updateBalancingState( + TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(error))); + return Status.OK; // XdsNameResolver isn't a polling NR, so this value doesn't matter + } + + private String errorPrefix() { + return "CdsLb for " + clusterName + ": "; + } + /** - * The state of a CDS working session of {@link CdsLoadBalancer2}. Created and started when - * receiving the CDS LB policy config with the top-level cluster name. + * The number of bits assigned to the fractional part of fixed-point values. We normalize weights + * to a fixed-point number between 0 and 1, representing that item's proportion of traffic (1 == + * 100% of traffic). We reserve at least one bit for the whole number so that we don't need to + * special case a single item, and so that we can round up very low values without risking uint32 + * overflow of the sum of weights. */ - private final class CdsLbState { + private static final int FIXED_POINT_FRACTIONAL_BITS = 31; - private final ClusterState root; - private final Map clusterStates = new ConcurrentHashMap<>(); - private LoadBalancer childLb; + /** Divide two uint32s and produce a fixed-point uint32 result. */ + private static long fractionToFixedPoint(long numerator, long denominator) { + long one = 1L << FIXED_POINT_FRACTIONAL_BITS; + return numerator * one / denominator; + } - private CdsLbState(String rootCluster) { - root = new ClusterState(rootCluster); - } + /** Multiply two uint32 fixed-point numbers, returning a uint32 fixed-point. */ + private static long fixedPointMultiply(long a, long b) { + return (a * b) >> FIXED_POINT_FRACTIONAL_BITS; + } - private void start() { - root.start(); + private static StatusOr getEdsUpdate(XdsConfig xdsConfig, String cluster) { + StatusOr clusterConfig = xdsConfig.getClusters().get(cluster); + if (clusterConfig == null) { + return StatusOr.fromStatus(Status.INTERNAL + .withDescription("BUG: cluster resolver could not find cluster in xdsConfig")); } - - private void shutdown() { - root.shutdown(); - if (childLb != null) { - childLb.shutdown(); - } + if (!clusterConfig.hasValue()) { + return StatusOr.fromStatus(clusterConfig.getStatus()); } + if (!(clusterConfig.getValue().getChildren() instanceof XdsClusterConfig.EndpointConfig)) { + return StatusOr.fromStatus(Status.INTERNAL + .withDescription("BUG: cluster resolver cluster with children of unknown type")); + } + XdsClusterConfig.EndpointConfig endpointConfig = + (XdsClusterConfig.EndpointConfig) clusterConfig.getValue().getChildren(); + return endpointConfig.getEndpoint(); + } + + /** + * Generates a string that represents the priority in the LB policy config. The string is unique + * across priorities in all clusters and priorityName(c, p1) < priorityName(c, p2) iff p1 < p2. + * The ordering is undefined for priorities in different clusters. + */ + private static String priorityName(String cluster, int priority) { + return cluster + "[child" + priority + "]"; + } - private void handleClusterDiscovered() { - List instances = new ArrayList<>(); - - // Used for loop detection to break the infinite recursion that loops would cause - Map> parentClusters = new HashMap<>(); - Status loopStatus = null; - - // Level-order traversal. - // Collect configurations for all non-aggregate (leaf) clusters. - Queue queue = new ArrayDeque<>(); - queue.add(root); - while (!queue.isEmpty()) { - int size = queue.size(); - for (int i = 0; i < size; i++) { - ClusterState clusterState = queue.remove(); - if (!clusterState.discovered) { - return; // do not proceed until all clusters discovered + /** + * Generates a string that represents the locality in the LB policy config. The string is unique + * across all localities in all clusters. + */ + private static String localityName(Locality locality) { + return "{region=\"" + locality.region() + + "\", zone=\"" + locality.zone() + + "\", sub_zone=\"" + locality.subZone() + + "\"}"; + } + + private final class ClusterState { + private Map localityPriorityNames = Collections.emptyMap(); + int priorityNameGenId = 1; + + StatusOr edsUpdateToResult( + String clusterName, + CdsUpdate discovery, + Object lbConfig, + StatusOr updateOr) { + if (!updateOr.hasValue()) { + return StatusOr.fromStatus(updateOr.getStatus()); + } + EdsUpdate update = updateOr.getValue(); + logger.log(XdsLogLevel.DEBUG, "Received endpoint update {0}", update); + if (logger.isLoggable(XdsLogLevel.INFO)) { + logger.log(XdsLogLevel.INFO, "Cluster {0}: {1} localities, {2} drop categories", + clusterName, update.localityLbEndpointsMap.size(), + update.dropPolicies.size()); + } + Map localityLbEndpoints = + update.localityLbEndpointsMap; + List dropOverloads = update.dropPolicies; + List addresses = new ArrayList<>(); + Map> prioritizedLocalityWeights = new HashMap<>(); + List sortedPriorityNames = + generatePriorityNames(clusterName, localityLbEndpoints); + Map priorityLocalityWeightSums; + if (pickFirstWeightedShuffling) { + priorityLocalityWeightSums = new HashMap<>(sortedPriorityNames.size() * 2); + for (Locality locality : localityLbEndpoints.keySet()) { + LocalityLbEndpoints localityLbInfo = localityLbEndpoints.get(locality); + String priorityName = localityPriorityNames.get(locality); + Long sum = priorityLocalityWeightSums.get(priorityName); + if (sum == null) { + sum = 0L; } - if (clusterState.result == null) { // resource revoked or not exists - continue; + long weight = UnsignedInts.toLong(localityLbInfo.localityWeight()); + priorityLocalityWeightSums.put(priorityName, sum + weight); + } + } else { + priorityLocalityWeightSums = null; + } + + for (Locality locality : localityLbEndpoints.keySet()) { + LocalityLbEndpoints localityLbInfo = localityLbEndpoints.get(locality); + String priorityName = localityPriorityNames.get(locality); + String localityName = localityName(locality); + AddressFilter.PathChain pathChain = + AddressFilter.createPathChain(Arrays.asList(priorityName, localityName)); + + boolean discard = true; + // These sums _should_ fit in uint32, but XdsEndpointResource isn't actually verifying that + // is true today. Since we are using long to avoid signedness trouble, the math happens to + // still work if it turns out the sums exceed uint32. + long localityWeightSum = 0; + long endpointWeightSum = 0; + if (pickFirstWeightedShuffling) { + localityWeightSum = priorityLocalityWeightSums.get(priorityName); + for (LbEndpoint endpoint : localityLbInfo.endpoints()) { + if (endpoint.isHealthy()) { + endpointWeightSum += UnsignedInts.toLong(endpoint.loadBalancingWeight()); + } } - if (clusterState.isLeaf) { - if (instances.stream().map(inst -> inst.cluster).noneMatch(clusterState.name::equals)) { - DiscoveryMechanism instance; - if (clusterState.result.clusterType() == ClusterType.EDS) { - instance = DiscoveryMechanism.forEds( - clusterState.name, clusterState.result.edsServiceName(), - clusterState.result.lrsServerInfo(), - clusterState.result.maxConcurrentRequests(), - clusterState.result.upstreamTlsContext(), - clusterState.result.outlierDetection()); - } else { // logical DNS - instance = DiscoveryMechanism.forLogicalDns( - clusterState.name, clusterState.result.dnsHostName(), - clusterState.result.lrsServerInfo(), - clusterState.result.maxConcurrentRequests(), - clusterState.result.upstreamTlsContext()); + } + for (LbEndpoint endpoint : localityLbInfo.endpoints()) { + if (endpoint.isHealthy()) { + discard = false; + long weight; + if (pickFirstWeightedShuffling) { + // Combine locality and endpoint weights as defined by gRFC A113 + long localityWeight = fractionToFixedPoint( + UnsignedInts.toLong(localityLbInfo.localityWeight()), localityWeightSum); + long endpointWeight = fractionToFixedPoint( + UnsignedInts.toLong(endpoint.loadBalancingWeight()), endpointWeightSum); + weight = fixedPointMultiply(localityWeight, endpointWeight); + if (weight == 0) { + weight = 1; } - instances.add(instance); - } - } else { - if (clusterState.childClusterStates == null) { - continue; - } - // Do loop detection and break recursion if detected - List namesCausingLoops = identifyLoops(clusterState, parentClusters); - if (namesCausingLoops.isEmpty()) { - queue.addAll(clusterState.childClusterStates.values()); } else { - // Do cleanup - if (childLb != null) { - childLb.shutdown(); - childLb = null; + weight = localityLbInfo.localityWeight(); + if (endpoint.loadBalancingWeight() != 0) { + weight *= endpoint.loadBalancingWeight(); } - if (loopStatus != null) { - logger.log(XdsLogLevel.WARNING, - "Multiple loops in CDS config. Old msg: " + loopStatus.getDescription()); + } + + Attributes attr = + endpoint.eag().getAttributes().toBuilder() + .set(InternalEquivalentAddressGroup.ATTR_BACKEND_SERVICE, clusterName) + .set(io.grpc.xds.XdsAttributes.ATTR_LOCALITY, locality) + .set(EquivalentAddressGroup.ATTR_LOCALITY_NAME, localityName) + .set(io.grpc.xds.XdsAttributes.ATTR_LOCALITY_WEIGHT, + localityLbInfo.localityWeight()) + .set(io.grpc.xds.XdsAttributes.ATTR_SERVER_WEIGHT, weight) + .set(XdsInternalAttributes.ATTR_ADDRESS_NAME, endpoint.hostname()) + .set(AddressFilter.PATH_CHAIN_KEY, pathChain) + .build(); + EquivalentAddressGroup eag; + if (discovery.isHttp11ProxyAvailable()) { + List rewrittenAddresses = new ArrayList<>(); + for (SocketAddress addr : endpoint.eag().getAddresses()) { + rewrittenAddresses.add(rewriteAddress( + addr, endpoint.endpointMetadata(), localityLbInfo.localityMetadata())); } - loopStatus = Status.UNAVAILABLE.withDescription(String.format( - "CDS error: circular aggregate clusters directly under %s for " - + "root cluster %s, named %s", - clusterState.name, root.name, namesCausingLoops)); + eag = new EquivalentAddressGroup(rewrittenAddresses, attr); + } else { + eag = new EquivalentAddressGroup(endpoint.eag().getAddresses(), attr); } + addresses.add(eag); } } + if (discard) { + logger.log(XdsLogLevel.INFO, + "Discard locality {0} with 0 healthy endpoints", locality); + continue; + } + if (!prioritizedLocalityWeights.containsKey(priorityName)) { + prioritizedLocalityWeights.put(priorityName, new HashMap()); + } + prioritizedLocalityWeights.get(priorityName).put( + locality, localityLbInfo.localityWeight()); + } + if (prioritizedLocalityWeights.isEmpty()) { + // Will still update the result, as if the cluster resource is revoked. + logger.log(XdsLogLevel.INFO, + "Cluster {0} has no usable priority/locality/endpoint", clusterName); } + sortedPriorityNames.retainAll(prioritizedLocalityWeights.keySet()); + Map priorityChildConfigs = + generatePriorityChildConfigs( + clusterName, discovery, lbConfig, lbRegistry, + prioritizedLocalityWeights, dropOverloads); + return StatusOr.fromValue(new ClusterResolutionResult(addresses, priorityChildConfigs, + sortedPriorityNames)); + } - if (loopStatus != null) { - helper.updateBalancingState( - TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(loopStatus))); - return; + private SocketAddress rewriteAddress(SocketAddress addr, + ImmutableMap endpointMetadata, + ImmutableMap localityMetadata) { + if (!(addr instanceof InetSocketAddress)) { + return addr; } - if (instances.isEmpty()) { // none of non-aggregate clusters exists - if (childLb != null) { - childLb.shutdown(); - childLb = null; + SocketAddress proxyAddress; + try { + proxyAddress = (SocketAddress) endpointMetadata.get( + "envoy.http11_proxy_transport_socket.proxy_address"); + if (proxyAddress == null) { + proxyAddress = (SocketAddress) localityMetadata.get( + "envoy.http11_proxy_transport_socket.proxy_address"); } - Status unavailable = - Status.UNAVAILABLE.withDescription("CDS error: found 0 leaf (logical DNS or EDS) " - + "clusters for root cluster " + root.name); - helper.updateBalancingState( - TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(unavailable))); - return; + } catch (ClassCastException e) { + return addr; } - // The LB policy config is provided in service_config.proto/JSON format. It is unwrapped - // to determine the name of the policy in the load balancer registry. - LbConfig unwrappedLbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig( - root.result.lbPolicyConfig()); - LoadBalancerProvider lbProvider = lbRegistry.getProvider(unwrappedLbConfig.getPolicyName()); - if (lbProvider == null) { - throw NameResolver.ConfigOrError.fromError(Status.UNAVAILABLE.withDescription( - "No provider available for LB: " + unwrappedLbConfig.getPolicyName())).getError() - .asRuntimeException(); - } - NameResolver.ConfigOrError configOrError = lbProvider.parseLoadBalancingPolicyConfig( - unwrappedLbConfig.getRawConfigValue()); - if (configOrError.getError() != null) { - throw configOrError.getError().augmentDescription("Unable to parse the LB config") - .asRuntimeException(); + if (proxyAddress == null) { + return addr; } - ClusterResolverConfig config = new ClusterResolverConfig( - Collections.unmodifiableList(instances), - new PolicySelection(lbProvider, configOrError.getConfig())); - if (childLb == null) { - childLb = lbRegistry.getProvider(CLUSTER_RESOLVER_POLICY_NAME).newLoadBalancer(helper); - } - childLb.handleResolvedAddresses( - resolvedAddresses.toBuilder().setLoadBalancingPolicyConfig(config).build()); + return HttpConnectProxiedSocketAddress.newBuilder() + .setTargetAddress((InetSocketAddress) addr) + .setProxyAddress(proxyAddress) + .build(); } - /** - * Returns children that would cause loops and builds up the parentClusters map. - **/ - - private List identifyLoops(ClusterState clusterState, - Map> parentClusters) { - Set ancestors = new HashSet<>(); - ancestors.add(clusterState.name); - addAncestors(ancestors, clusterState, parentClusters); - - List namesCausingLoops = new ArrayList<>(); - for (ClusterState state : clusterState.childClusterStates.values()) { - if (ancestors.contains(state.name)) { - namesCausingLoops.add(state.name); + private List generatePriorityNames(String name, + Map localityLbEndpoints) { + TreeMap> todo = new TreeMap<>(); + for (Locality locality : localityLbEndpoints.keySet()) { + int priority = localityLbEndpoints.get(locality).priority(); + if (!todo.containsKey(priority)) { + todo.put(priority, new ArrayList<>()); } + todo.get(priority).add(locality); } + Map newNames = new HashMap<>(); + Set usedNames = new HashSet<>(); + List ret = new ArrayList<>(); + for (Integer priority: todo.keySet()) { + String foundName = ""; + for (Locality locality : todo.get(priority)) { + if (localityPriorityNames.containsKey(locality) + && usedNames.add(localityPriorityNames.get(locality))) { + foundName = localityPriorityNames.get(locality); + break; + } + } + if ("".equals(foundName)) { + foundName = priorityName(name, priorityNameGenId++); + } + for (Locality locality : todo.get(priority)) { + newNames.put(locality, foundName); + } + ret.add(foundName); + } + localityPriorityNames = newNames; + return ret; + } + } - // Update parent map with entries from remaining children to clusterState - clusterState.childClusterStates.values().stream() - .filter(child -> !namesCausingLoops.contains(child.name)) - .forEach( - child -> parentClusters.computeIfAbsent(child, k -> new ArrayList<>()) - .add(clusterState)); - - return namesCausingLoops; + private static class ClusterResolutionResult { + // Endpoint addresses. + private final List addresses; + // Config (include load balancing policy/config) for each priority in the cluster. + private final Map priorityChildConfigs; + // List of priority names ordered in descending priorities. + private final List priorities; + + ClusterResolutionResult(List addresses, + Map configs, List priorities) { + this.addresses = addresses; + this.priorityChildConfigs = configs; + this.priorities = priorities; } + } - /** Recursively add all parents to the ancestors list. **/ - private void addAncestors(Set ancestors, ClusterState clusterState, - Map> parentClusters) { - List directParents = parentClusters.get(clusterState); - if (directParents != null) { - directParents.stream().map(c -> c.name).forEach(ancestors::add); - directParents.forEach(p -> addAncestors(ancestors, p, parentClusters)); + /** + * Generates configs to be used in the priority LB policy for priorities in a cluster. + * + *

priority LB -> cluster_impl LB (one per priority) -> (weighted_target LB + * -> round_robin / least_request_experimental (one per locality)) / ring_hash_experimental + */ + private static Map generatePriorityChildConfigs( + String clusterName, + CdsUpdate discovery, + Object endpointLbConfig, + LoadBalancerRegistry lbRegistry, + Map> prioritizedLocalityWeights, + List dropOverloads) { + Map configs = new HashMap<>(); + for (String priority : prioritizedLocalityWeights.keySet()) { + ClusterImplConfig clusterImplConfig = + new ClusterImplConfig( + clusterName, discovery.edsServiceName(), discovery.lrsServerInfo(), + discovery.maxConcurrentRequests(), dropOverloads, endpointLbConfig, + discovery.upstreamTlsContext(), discovery.filterMetadata(), + discovery.backendMetricPropagation()); + LoadBalancerProvider clusterImplLbProvider = + lbRegistry.getProvider(XdsLbPolicies.CLUSTER_IMPL_POLICY_NAME); + Object priorityChildPolicy = GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + clusterImplLbProvider, clusterImplConfig); + + // If outlier detection has been configured we wrap the child policy in the outlier detection + // load balancer. + if (discovery.outlierDetection() != null) { + LoadBalancerProvider outlierDetectionProvider = lbRegistry.getProvider( + "outlier_detection_experimental"); + priorityChildPolicy = GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + outlierDetectionProvider, + buildOutlierDetectionLbConfig(discovery.outlierDetection(), priorityChildPolicy)); } + + boolean isEds = discovery.clusterType() == ClusterType.EDS; + PriorityChildConfig priorityChildConfig = + new PriorityChildConfig(priorityChildPolicy, isEds /* ignoreReresolution */); + configs.put(priority, priorityChildConfig); } + return configs; + } - private void handleClusterDiscoveryError(Status error) { - if (childLb != null) { - childLb.handleNameResolutionError(error); - } else { - helper.updateBalancingState( - TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(error))); - } + /** + * Converts {@link OutlierDetection} that represents the xDS configuration to {@link + * OutlierDetectionLoadBalancerConfig} that the {@link io.grpc.util.OutlierDetectionLoadBalancer} + * understands. + */ + private static OutlierDetectionLoadBalancerConfig buildOutlierDetectionLbConfig( + OutlierDetection outlierDetection, Object childConfig) { + OutlierDetectionLoadBalancerConfig.Builder configBuilder + = new OutlierDetectionLoadBalancerConfig.Builder(); + + configBuilder.setChildConfig(childConfig); + + if (outlierDetection.intervalNanos() != null) { + configBuilder.setIntervalNanos(outlierDetection.intervalNanos()); + } + if (outlierDetection.baseEjectionTimeNanos() != null) { + configBuilder.setBaseEjectionTimeNanos(outlierDetection.baseEjectionTimeNanos()); + } + if (outlierDetection.maxEjectionTimeNanos() != null) { + configBuilder.setMaxEjectionTimeNanos(outlierDetection.maxEjectionTimeNanos()); + } + if (outlierDetection.maxEjectionPercent() != null) { + configBuilder.setMaxEjectionPercent(outlierDetection.maxEjectionPercent()); } - private final class ClusterState implements ResourceWatcher { - private final String name; - @Nullable - private Map childClusterStates; - @Nullable - private CdsUpdate result; - // Following fields are effectively final. - private boolean isLeaf; - private boolean discovered; - private boolean shutdown; - - private ClusterState(String name) { - this.name = name; - } + SuccessRateEjection successRate = outlierDetection.successRateEjection(); + if (successRate != null) { + OutlierDetectionLoadBalancerConfig.SuccessRateEjection.Builder + successRateConfigBuilder = new OutlierDetectionLoadBalancerConfig + .SuccessRateEjection.Builder(); - private void start() { - shutdown = false; - xdsClient.watchXdsResource(XdsClusterResource.getInstance(), name, this, syncContext); + if (successRate.stdevFactor() != null) { + successRateConfigBuilder.setStdevFactor(successRate.stdevFactor()); } - - void shutdown() { - shutdown = true; - xdsClient.cancelXdsResourceWatch(XdsClusterResource.getInstance(), name, this); - if (childClusterStates != null) { - // recursively shut down all descendants - childClusterStates.values().stream() - .filter(state -> !state.shutdown) - .forEach(ClusterState::shutdown); - } + if (successRate.enforcementPercentage() != null) { + successRateConfigBuilder.setEnforcementPercentage(successRate.enforcementPercentage()); } - - @Override - public void onError(Status error) { - Status status = Status.UNAVAILABLE - .withDescription( - String.format("Unable to load CDS %s. xDS server returned: %s: %s", - name, error.getCode(), error.getDescription())) - .withCause(error.getCause()); - if (shutdown) { - return; - } - // All watchers should receive the same error, so we only propagate it once. - if (ClusterState.this == root) { - handleClusterDiscoveryError(status); - } + if (successRate.minimumHosts() != null) { + successRateConfigBuilder.setMinimumHosts(successRate.minimumHosts()); } - - @Override - public void onResourceDoesNotExist(String resourceName) { - if (shutdown) { - return; - } - discovered = true; - result = null; - if (childClusterStates != null) { - for (ClusterState state : childClusterStates.values()) { - state.shutdown(); - } - childClusterStates = null; - } - handleClusterDiscovered(); + if (successRate.requestVolume() != null) { + successRateConfigBuilder.setRequestVolume(successRate.requestVolume()); } - @Override - public void onChanged(final CdsUpdate update) { - if (shutdown) { - return; - } - logger.log(XdsLogLevel.DEBUG, "Received cluster update {0}", update); - discovered = true; - result = update; - if (update.clusterType() == ClusterType.AGGREGATE) { - isLeaf = false; - logger.log(XdsLogLevel.INFO, "Aggregate cluster {0}, underlying clusters: {1}", - update.clusterName(), update.prioritizedClusterNames()); - Map newChildStates = new LinkedHashMap<>(); - for (String cluster : update.prioritizedClusterNames()) { - if (newChildStates.containsKey(cluster)) { - logger.log(XdsLogLevel.WARNING, - String.format("duplicate cluster name %s in aggregate %s is being ignored", - cluster, update.clusterName())); - continue; - } - if (childClusterStates == null || !childClusterStates.containsKey(cluster)) { - ClusterState childState; - if (clusterStates.containsKey(cluster)) { - childState = clusterStates.get(cluster); - if (childState.shutdown) { - childState.start(); - } - } else { - childState = new ClusterState(cluster); - clusterStates.put(cluster, childState); - childState.start(); - } - newChildStates.put(cluster, childState); - } else { - newChildStates.put(cluster, childClusterStates.remove(cluster)); - } - } - if (childClusterStates != null) { // stop subscribing to revoked child clusters - for (ClusterState watcher : childClusterStates.values()) { - watcher.shutdown(); - } - } - childClusterStates = newChildStates; - } else if (update.clusterType() == ClusterType.EDS) { - isLeaf = true; - logger.log(XdsLogLevel.INFO, "EDS cluster {0}, edsServiceName: {1}", - update.clusterName(), update.edsServiceName()); - } else { // logical DNS - isLeaf = true; - logger.log(XdsLogLevel.INFO, "Logical DNS cluster {0}", update.clusterName()); - } - handleClusterDiscovered(); + configBuilder.setSuccessRateEjection(successRateConfigBuilder.build()); + } + + FailurePercentageEjection failurePercentage = outlierDetection.failurePercentageEjection(); + if (failurePercentage != null) { + OutlierDetectionLoadBalancerConfig.FailurePercentageEjection.Builder + failurePercentageConfigBuilder = new OutlierDetectionLoadBalancerConfig + .FailurePercentageEjection.Builder(); + + if (failurePercentage.threshold() != null) { + failurePercentageConfigBuilder.setThreshold(failurePercentage.threshold()); + } + if (failurePercentage.enforcementPercentage() != null) { + failurePercentageConfigBuilder.setEnforcementPercentage( + failurePercentage.enforcementPercentage()); + } + if (failurePercentage.minimumHosts() != null) { + failurePercentageConfigBuilder.setMinimumHosts(failurePercentage.minimumHosts()); + } + if (failurePercentage.requestVolume() != null) { + failurePercentageConfigBuilder.setRequestVolume(failurePercentage.requestVolume()); } + configBuilder.setFailurePercentageEjection(failurePercentageConfigBuilder.build()); } + + return configBuilder.build(); } } diff --git a/xds/src/main/java/io/grpc/xds/CdsLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/CdsLoadBalancerProvider.java index 01bd2ab27f6..875af9089ed 100644 --- a/xds/src/main/java/io/grpc/xds/CdsLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/CdsLoadBalancerProvider.java @@ -23,6 +23,7 @@ import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancerProvider; +import io.grpc.LoadBalancerRegistry; import io.grpc.NameResolver.ConfigOrError; import io.grpc.Status; import io.grpc.internal.JsonUtil; @@ -36,8 +37,6 @@ @Internal public class CdsLoadBalancerProvider extends LoadBalancerProvider { - private static final String CLUSTER_KEY = "cluster"; - @Override public boolean isAvailable() { return true; @@ -53,9 +52,24 @@ public String getPolicyName() { return XdsLbPolicies.CDS_POLICY_NAME; } + private final LoadBalancerRegistry loadBalancerRegistry; + + public CdsLoadBalancerProvider() { + this.loadBalancerRegistry = null; + } + + public CdsLoadBalancerProvider(LoadBalancerRegistry loadBalancerRegistry) { + this.loadBalancerRegistry = loadBalancerRegistry; + } + @Override public LoadBalancer newLoadBalancer(Helper helper) { - return new CdsLoadBalancer2(helper); + LoadBalancerRegistry loadBalancerRegistry = this.loadBalancerRegistry; + if (loadBalancerRegistry == null) { + loadBalancerRegistry = LoadBalancerRegistry.getDefaultRegistry(); + } + + return new CdsLoadBalancer2(helper, loadBalancerRegistry); } @Override @@ -70,9 +84,12 @@ public ConfigOrError parseLoadBalancingPolicyConfig( */ static ConfigOrError parseLoadBalancingConfigPolicy(Map rawLoadBalancingPolicyConfig) { try { - String cluster = - JsonUtil.getString(rawLoadBalancingPolicyConfig, CLUSTER_KEY); - return ConfigOrError.fromConfig(new CdsConfig(cluster)); + String cluster = JsonUtil.getString(rawLoadBalancingPolicyConfig, "cluster"); + Boolean isDynamic = JsonUtil.getBoolean(rawLoadBalancingPolicyConfig, "is_dynamic"); + if (isDynamic == null) { + isDynamic = Boolean.FALSE; + } + return ConfigOrError.fromConfig(new CdsConfig(cluster, isDynamic)); } catch (RuntimeException e) { return ConfigOrError.fromError( Status.UNAVAILABLE.withCause(e).withDescription( @@ -89,15 +106,28 @@ static final class CdsConfig { * Name of cluster to query CDS for. */ final String name; + /** + * Whether this cluster was dynamically chosen, so the XdsDependencyManager may be unaware of + * it without an explicit cluster subscription. + */ + final boolean isDynamic; CdsConfig(String name) { + this(name, false); + } + + CdsConfig(String name, boolean isDynamic) { checkArgument(name != null && !name.isEmpty(), "name is null or empty"); this.name = name; + this.isDynamic = isDynamic; } @Override public String toString() { - return MoreObjects.toStringHelper(this).add("name", name).toString(); + return MoreObjects.toStringHelper(this) + .add("name", name) + .add("isDynamic", isDynamic) + .toString(); } } } diff --git a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java index 871a317f832..64105144240 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java @@ -17,21 +17,26 @@ package io.grpc.xds; import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.xds.client.LoadStatsManager2.isEnabledOrcaLrsPropagation; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; import com.google.common.base.Strings; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.Struct; import io.grpc.Attributes; import io.grpc.ClientStreamTracer; import io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.ConnectivityState; +import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; import io.grpc.InternalLogId; import io.grpc.LoadBalancer; import io.grpc.Metadata; +import io.grpc.NameResolver; import io.grpc.Status; import io.grpc.internal.ForwardingClientStreamTracer; -import io.grpc.internal.ObjectPool; +import io.grpc.internal.GrpcUtil; import io.grpc.services.MetricReport; import io.grpc.util.ForwardingLoadBalancerHelper; import io.grpc.util.ForwardingSubchannel; @@ -41,6 +46,7 @@ import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.ThreadSafeRandom.ThreadSafeRandomImpl; import io.grpc.xds.XdsNameResolverProvider.CallCounterProvider; +import io.grpc.xds.client.BackendMetricPropagation; import io.grpc.xds.client.Bootstrapper.ServerInfo; import io.grpc.xds.client.LoadStatsManager2.ClusterDropStats; import io.grpc.xds.client.LoadStatsManager2.ClusterLocalityStats; @@ -48,14 +54,19 @@ import io.grpc.xds.client.XdsClient; import io.grpc.xds.client.XdsLogger; import io.grpc.xds.client.XdsLogger.XdsLogLevel; +import io.grpc.xds.internal.XdsInternalAttributes; +import io.grpc.xds.internal.security.SecurityProtocolNegotiators; import io.grpc.xds.internal.security.SslContextProviderSupplier; import io.grpc.xds.orca.OrcaPerRequestUtil; import io.grpc.xds.orca.OrcaPerRequestUtil.OrcaPerRequestReportListener; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Locale; +import java.util.Map; import java.util.Objects; import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; import javax.annotation.Nullable; /** @@ -74,8 +85,11 @@ final class ClusterImplLoadBalancer extends LoadBalancer { Strings.isNullOrEmpty(System.getenv("GRPC_XDS_EXPERIMENTAL_CIRCUIT_BREAKING")) || Boolean.parseBoolean(System.getenv("GRPC_XDS_EXPERIMENTAL_CIRCUIT_BREAKING")); - private static final Attributes.Key ATTR_CLUSTER_LOCALITY_STATS = - Attributes.Key.create("io.grpc.xds.ClusterImplLoadBalancer.clusterLocalityStats"); + private static final Attributes.Key> ATTR_CLUSTER_LOCALITY = + Attributes.Key.create("io.grpc.xds.ClusterImplLoadBalancer.clusterLocality"); + @VisibleForTesting + static final Attributes.Key ATTR_SUBCHANNEL_ADDRESS_NAME = + Attributes.Key.create("io.grpc.xds.ClusterImplLoadBalancer.addressName"); private final XdsLogger logger; private final Helper helper; @@ -84,7 +98,6 @@ final class ClusterImplLoadBalancer extends LoadBalancer { private String cluster; @Nullable private String edsServiceName; - private ObjectPool xdsClientPool; private XdsClient xdsClient; private CallCounterProvider callCounterProvider; private ClusterDropStats dropStats; @@ -107,13 +120,11 @@ final class ClusterImplLoadBalancer extends LoadBalancer { public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { logger.log(XdsLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses); Attributes attributes = resolvedAddresses.getAttributes(); - if (xdsClientPool == null) { - xdsClientPool = attributes.get(InternalXdsAttributes.XDS_CLIENT_POOL); - assert xdsClientPool != null; - xdsClient = xdsClientPool.getObject(); + if (xdsClient == null) { + xdsClient = checkNotNull(attributes.get(io.grpc.xds.XdsAttributes.XDS_CLIENT), "xdsClient"); } if (callCounterProvider == null) { - callCounterProvider = attributes.get(InternalXdsAttributes.CALL_COUNTER_PROVIDER); + callCounterProvider = attributes.get(io.grpc.xds.XdsAttributes.CALL_COUNTER_PROVIDER); } ClusterImplConfig config = @@ -138,14 +149,16 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { childLbHelper.updateDropPolicies(config.dropCategories); childLbHelper.updateMaxConcurrentRequests(config.maxConcurrentRequests); childLbHelper.updateSslContextProviderSupplier(config.tlsContext); + childLbHelper.updateFilterMetadata(config.filterMetadata); + childLbHelper.updateBackendMetricPropagation(config.backendMetricPropagation); - childSwitchLb.switchTo(config.childPolicy.getProvider()); - childSwitchLb.handleResolvedAddresses( + return childSwitchLb.acceptResolvedAddresses( resolvedAddresses.toBuilder() - .setAttributes(attributes) - .setLoadBalancingPolicyConfig(config.childPolicy.getConfig()) + .setAttributes(attributes.toBuilder() + .set(NameResolver.ATTR_BACKEND_SERVICE, cluster) + .build()) + .setLoadBalancingPolicyConfig(config.childConfig) .build()); - return Status.OK; } @Override @@ -158,6 +171,13 @@ public void handleNameResolutionError(Status error) { } } + @Override + public void requestConnection() { + if (childSwitchLb != null) { + childSwitchLb.requestConnection(); + } + } + @Override public void shutdown() { if (dropStats != null) { @@ -170,9 +190,7 @@ public void shutdown() { childLbHelper = null; } } - if (xdsClient != null) { - xdsClient = xdsClientPool.returnObject(xdsClient); - } + xdsClient = null; } /** @@ -187,8 +205,11 @@ private final class ClusterImplLbHelper extends ForwardingLoadBalancerHelper { private long maxConcurrentRequests = DEFAULT_PER_CLUSTER_MAX_CONCURRENT_REQUESTS; @Nullable private SslContextProviderSupplier sslContextProviderSupplier; + private Map filterMetadata = ImmutableMap.of(); @Nullable private final ServerInfo lrsServerInfo; + @Nullable + private BackendMetricPropagation backendMetricPropagation; private ClusterImplLbHelper(AtomicLong inFlights, @Nullable ServerInfo lrsServerInfo) { this.inFlights = checkNotNull(inFlights, "inFlights"); @@ -199,52 +220,86 @@ private ClusterImplLbHelper(AtomicLong inFlights, @Nullable ServerInfo lrsServer public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { currentState = newState; currentPicker = newPicker; - SubchannelPicker picker = - new RequestLimitingSubchannelPicker(newPicker, dropPolicies, maxConcurrentRequests); + SubchannelPicker picker = new RequestLimitingSubchannelPicker( + newPicker, dropPolicies, maxConcurrentRequests, filterMetadata); delegate().updateBalancingState(newState, picker); } @Override public Subchannel createSubchannel(CreateSubchannelArgs args) { List addresses = withAdditionalAttributes(args.getAddresses()); - Locality locality = args.getAddresses().get(0).getAttributes().get( - InternalXdsAttributes.ATTR_LOCALITY); // all addresses should be in the same locality - // Endpoint addresses resolved by ClusterResolverLoadBalancer should always contain - // attributes with its locality, including endpoints in LOGICAL_DNS clusters. - // In case of not (which really shouldn't), loads are aggregated under an empty locality. - if (locality == null) { - locality = Locality.create("", "", ""); + // This value for ClusterLocality is not recommended for general use. + // Currently, we extract locality data from the first address, even before the subchannel is + // READY. + // This is mainly to accommodate scenarios where a Load Balancing API (like "pick first") + // might return the subchannel before it is READY. Typically, we wouldn't report load for such + // selections because the channel will disregard the chosen (not-ready) subchannel. + // However, we needed to ensure this case is handled. + ClusterLocality clusterLocality = createClusterLocalityFromAttributes( + args.getAddresses().get(0).getAttributes()); + AtomicReference localityAtomicReference = new AtomicReference<>( + clusterLocality); + Attributes.Builder attrsBuilder = args.getAttributes().toBuilder() + .set(ATTR_CLUSTER_LOCALITY, localityAtomicReference); + if (GrpcUtil.getFlag("GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE", false)) { + String hostname = args.getAddresses().get(0).getAttributes() + .get(XdsInternalAttributes.ATTR_ADDRESS_NAME); + if (hostname != null) { + attrsBuilder.set(ATTR_SUBCHANNEL_ADDRESS_NAME, hostname); + } } - final ClusterLocalityStats localityStats = - (lrsServerInfo == null) - ? null - : xdsClient.addClusterLocalityStats(lrsServerInfo, cluster, - edsServiceName, locality); - - Attributes attrs = args.getAttributes().toBuilder().set( - ATTR_CLUSTER_LOCALITY_STATS, localityStats).build(); - args = args.toBuilder().setAddresses(addresses).setAttributes(attrs).build(); + args = args.toBuilder().setAddresses(addresses).setAttributes(attrsBuilder.build()).build(); final Subchannel subchannel = delegate().createSubchannel(args); - return new ForwardingSubchannel() { - @Override - public void shutdown() { - if (localityStats != null) { - localityStats.release(); - } - delegate().shutdown(); - } + return new ClusterImplSubchannel(subchannel, localityAtomicReference); + } - @Override - public void updateAddresses(List addresses) { - delegate().updateAddresses(withAdditionalAttributes(addresses)); - } + private final class ClusterImplSubchannel extends ForwardingSubchannel { + private final Subchannel delegate; + private final AtomicReference localityAtomicReference; - @Override - protected Subchannel delegate() { - return subchannel; - } - }; + private ClusterImplSubchannel( + Subchannel delegate, AtomicReference localityAtomicReference) { + this.delegate = delegate; + this.localityAtomicReference = localityAtomicReference; + } + + @Override + public void start(SubchannelStateListener listener) { + delegate().start( + new SubchannelStateListener() { + @Override + public void onSubchannelState(ConnectivityStateInfo newState) { + // Do nothing if LB has been shutdown + if (xdsClient != null && newState.getState().equals(ConnectivityState.READY)) { + // Get locality based on the connected address attributes + ClusterLocality updatedClusterLocality = + createClusterLocalityFromAttributes( + delegate.getConnectedAddressAttributes()); + ClusterLocality oldClusterLocality = + localityAtomicReference.getAndSet(updatedClusterLocality); + oldClusterLocality.release(); + } + listener.onSubchannelState(newState); + } + }); + } + + @Override + public void shutdown() { + localityAtomicReference.get().release(); + delegate().shutdown(); + } + + @Override + public void updateAddresses(List addresses) { + delegate().updateAddresses(withAdditionalAttributes(addresses)); + } + + @Override + protected Subchannel delegate() { + return delegate; + } } private List withAdditionalAttributes( @@ -252,10 +307,10 @@ private List withAdditionalAttributes( List newAddresses = new ArrayList<>(); for (EquivalentAddressGroup eag : addresses) { Attributes.Builder attrBuilder = eag.getAttributes().toBuilder().set( - InternalXdsAttributes.ATTR_CLUSTER_NAME, cluster); + io.grpc.xds.XdsAttributes.ATTR_CLUSTER_NAME, cluster); if (sslContextProviderSupplier != null) { attrBuilder.set( - InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, + SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, sslContextProviderSupplier); } newAddresses.add(new EquivalentAddressGroup(eag.getAddresses(), attrBuilder.build())); @@ -263,6 +318,28 @@ private List withAdditionalAttributes( return newAddresses; } + private ClusterLocality createClusterLocalityFromAttributes(Attributes addressAttributes) { + Locality locality = addressAttributes.get(io.grpc.xds.XdsAttributes.ATTR_LOCALITY); + String localityName = addressAttributes.get(EquivalentAddressGroup.ATTR_LOCALITY_NAME); + + // Endpoint addresses resolved by ClusterResolverLoadBalancer should always contain + // attributes with its locality, including endpoints in LOGICAL_DNS clusters. + // In case of not (which really shouldn't), loads are aggregated under an empty + // locality. + if (locality == null) { + locality = Locality.create("", "", ""); + localityName = ""; + } + + final ClusterLocalityStats localityStats = + (lrsServerInfo == null) + ? null + : xdsClient.addClusterLocalityStats(lrsServerInfo, cluster, + edsServiceName, locality, backendMetricPropagation); + + return new ClusterLocality(localityStats, localityName); + } + @Override protected Helper delegate() { return helper; @@ -304,20 +381,35 @@ private void updateSslContextProviderSupplier(@Nullable UpstreamTlsContext tlsCo : null; } + private void updateFilterMetadata(Map filterMetadata) { + this.filterMetadata = ImmutableMap.copyOf(filterMetadata); + } + + private void updateBackendMetricPropagation( + @Nullable BackendMetricPropagation backendMetricPropagation) { + this.backendMetricPropagation = backendMetricPropagation; + } + private class RequestLimitingSubchannelPicker extends SubchannelPicker { private final SubchannelPicker delegate; private final List dropPolicies; private final long maxConcurrentRequests; + private final Map filterMetadata; private RequestLimitingSubchannelPicker(SubchannelPicker delegate, - List dropPolicies, long maxConcurrentRequests) { + List dropPolicies, long maxConcurrentRequests, + Map filterMetadata) { this.delegate = delegate; this.dropPolicies = dropPolicies; this.maxConcurrentRequests = maxConcurrentRequests; + this.filterMetadata = checkNotNull(filterMetadata, "filterMetadata"); } @Override public PickResult pickSubchannel(PickSubchannelArgs args) { + args.getCallOptions().getOption(ClusterImplLoadBalancerProvider.FILTER_METADATA_CONSUMER) + .accept(filterMetadata); + args.getPickDetailsConsumer().addOptionalLabel("grpc.lb.backend_service", cluster); for (DropOverload dropOverload : dropPolicies) { int rand = random.nextInt(1_000_000); if (rand < dropOverload.dropsPerMillion()) { @@ -330,25 +422,46 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { Status.UNAVAILABLE.withDescription("Dropped: " + dropOverload.category())); } } - final PickResult result = delegate.pickSubchannel(args); + PickResult result = delegate.pickSubchannel(args); if (result.getStatus().isOk() && result.getSubchannel() != null) { + Subchannel subchannel = result.getSubchannel(); + if (subchannel instanceof ClusterImplLbHelper.ClusterImplSubchannel) { + subchannel = ((ClusterImplLbHelper.ClusterImplSubchannel) subchannel).delegate(); + result = result.copyWithSubchannel(subchannel); + } if (enableCircuitBreaking) { if (inFlights.get() >= maxConcurrentRequests) { if (dropStats != null) { dropStats.recordDroppedRequest(); } return PickResult.withDrop(Status.UNAVAILABLE.withDescription( - "Cluster max concurrent requests limit exceeded")); + String.format(Locale.US, "Cluster max concurrent requests limit of %d exceeded", + maxConcurrentRequests))); + } + } + final AtomicReference clusterLocality = + result.getSubchannel().getAttributes().get(ATTR_CLUSTER_LOCALITY); + + if (clusterLocality != null) { + ClusterLocalityStats stats = clusterLocality.get().getClusterLocalityStats(); + if (stats != null) { + String localityName = + result.getSubchannel().getAttributes().get(ATTR_CLUSTER_LOCALITY).get() + .getClusterLocalityName(); + args.getPickDetailsConsumer().addOptionalLabel("grpc.lb.locality", localityName); + + ClientStreamTracer.Factory tracerFactory = new CountingStreamTracerFactory( + stats, inFlights, result.getStreamTracerFactory()); + ClientStreamTracer.Factory orcaTracerFactory = OrcaPerRequestUtil.getInstance() + .newOrcaClientStreamTracerFactory(tracerFactory, new OrcaPerRpcListener(stats)); + result = result.copyWithStreamTracerFactory(orcaTracerFactory); } } - final ClusterLocalityStats stats = - result.getSubchannel().getAttributes().get(ATTR_CLUSTER_LOCALITY_STATS); - if (stats != null) { - ClientStreamTracer.Factory tracerFactory = new CountingStreamTracerFactory( - stats, inFlights, result.getStreamTracerFactory()); - ClientStreamTracer.Factory orcaTracerFactory = OrcaPerRequestUtil.getInstance() - .newOrcaClientStreamTracerFactory(tracerFactory, new OrcaPerRpcListener(stats)); - return PickResult.withSubchannel(result.getSubchannel(), orcaTracerFactory); + if (args.getCallOptions().getOption(XdsNameResolver.AUTO_HOST_REWRITE_KEY) != null + && args.getCallOptions().getOption(XdsNameResolver.AUTO_HOST_REWRITE_KEY)) { + result = PickResult.withSubchannel(result.getSubchannel(), + result.getStreamTracerFactory(), + result.getSubchannel().getAttributes().get(ATTR_SUBCHANNEL_ADDRESS_NAME)); } } return result; @@ -415,12 +528,49 @@ private OrcaPerRpcListener(ClusterLocalityStats stats) { } /** - * Copies {@link MetricReport#getNamedMetrics()} to {@link ClusterLocalityStats} such that it is - * included in the snapshot for the LRS report sent to the LRS server. + * Copies ORCA metrics from {@link MetricReport} to {@link ClusterLocalityStats} + * such that they are included in the snapshot for the LRS report sent to the LRS server. + * This includes both top-level metrics (CPU, memory, application utilization) and named + * metrics, filtered according to the backend metric propagation configuration. */ @Override public void onLoadReport(MetricReport report) { + if (isEnabledOrcaLrsPropagation) { + stats.recordTopLevelMetrics( + report.getCpuUtilization(), + report.getMemoryUtilization(), + report.getApplicationUtilization()); + } stats.recordBackendLoadMetricStats(report.getNamedMetrics()); } } + + /** + * Represents the {@link ClusterLocalityStats} and network locality name of a cluster. + */ + static final class ClusterLocality { + private final ClusterLocalityStats clusterLocalityStats; + private final String clusterLocalityName; + + @VisibleForTesting + ClusterLocality(ClusterLocalityStats localityStats, String localityName) { + this.clusterLocalityStats = localityStats; + this.clusterLocalityName = localityName; + } + + ClusterLocalityStats getClusterLocalityStats() { + return clusterLocalityStats; + } + + String getClusterLocalityName() { + return clusterLocalityName; + } + + @VisibleForTesting + void release() { + if (clusterLocalityStats != null) { + clusterLocalityStats.release(); + } + } + } } diff --git a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancerProvider.java index ff32779b0e6..f369c3b99b4 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancerProvider.java @@ -19,6 +19,9 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.base.MoreObjects; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.Struct; +import io.grpc.CallOptions; import io.grpc.Internal; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.Helper; @@ -26,14 +29,15 @@ import io.grpc.LoadBalancerRegistry; import io.grpc.NameResolver.ConfigOrError; import io.grpc.Status; -import io.grpc.internal.ServiceConfigUtil.PolicySelection; import io.grpc.xds.Endpoints.DropOverload; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; +import io.grpc.xds.client.BackendMetricPropagation; import io.grpc.xds.client.Bootstrapper.ServerInfo; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.function.Consumer; import javax.annotation.Nullable; /** @@ -43,6 +47,11 @@ */ @Internal public final class ClusterImplLoadBalancerProvider extends LoadBalancerProvider { + /** + * Consumer of filter metadata from the cluster used by the call. Consumer may not modify map. + */ + public static final CallOptions.Key>> FILTER_METADATA_CONSUMER = + CallOptions.Key.createWithDefault("io.grpc.xds.internalFilterMetadataConsumer", (m) -> { }); @Override public boolean isAvailable() { @@ -88,20 +97,26 @@ static final class ClusterImplConfig { // Drop configurations. final List dropCategories; // Provides the direct child policy and its config. - final PolicySelection childPolicy; + final Object childConfig; + final Map filterMetadata; + @Nullable + final BackendMetricPropagation backendMetricPropagation; ClusterImplConfig(String cluster, @Nullable String edsServiceName, @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests, - List dropCategories, PolicySelection childPolicy, - @Nullable UpstreamTlsContext tlsContext) { + List dropCategories, Object childConfig, + @Nullable UpstreamTlsContext tlsContext, Map filterMetadata, + @Nullable BackendMetricPropagation backendMetricPropagation) { this.cluster = checkNotNull(cluster, "cluster"); this.edsServiceName = edsServiceName; this.lrsServerInfo = lrsServerInfo; this.maxConcurrentRequests = maxConcurrentRequests; this.tlsContext = tlsContext; + this.filterMetadata = ImmutableMap.copyOf(filterMetadata); this.dropCategories = Collections.unmodifiableList( new ArrayList<>(checkNotNull(dropCategories, "dropCategories"))); - this.childPolicy = checkNotNull(childPolicy, "childPolicy"); + this.childConfig = checkNotNull(childConfig, "childConfig"); + this.backendMetricPropagation = backendMetricPropagation; } @Override @@ -113,7 +128,7 @@ public String toString() { .add("maxConcurrentRequests", maxConcurrentRequests) // Exclude tlsContext as its string representation is cumbersome. .add("dropCategories", dropCategories) - .add("childPolicy", childPolicy) + .add("childConfig", childConfig) .toString(); } } diff --git a/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java index f2e3833ae15..22b5aaa7d73 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java @@ -23,11 +23,11 @@ import com.google.common.base.MoreObjects; import io.grpc.ConnectivityState; import io.grpc.InternalLogId; -import io.grpc.LoadBalancerProvider; +import io.grpc.LoadBalancer; import io.grpc.Status; import io.grpc.SynchronizationContext; import io.grpc.SynchronizationContext.ScheduledHandle; -import io.grpc.internal.ServiceConfigUtil.PolicySelection; +import io.grpc.util.GracefulSwitchLoadBalancer; import io.grpc.util.MultiChildLoadBalancer; import io.grpc.xds.ClusterManagerLoadBalancerProvider.ClusterManagerConfig; import io.grpc.xds.client.XdsLogger; @@ -57,6 +57,7 @@ class ClusterManagerLoadBalancer extends MultiChildLoadBalancer { protected final SynchronizationContext syncContext; private final ScheduledExecutorService timeService; private final XdsLogger logger; + private ResolvedAddresses lastResolvedAddresses; ClusterManagerLoadBalancer(Helper helper) { super(helper); @@ -69,55 +70,48 @@ class ClusterManagerLoadBalancer extends MultiChildLoadBalancer { } @Override - protected ResolvedAddresses getChildAddresses(Object key, ResolvedAddresses resolvedAddresses, - Object childConfig) { - return resolvedAddresses.toBuilder().setLoadBalancingPolicyConfig(childConfig).build(); + protected ChildLbState createChildLbState(Object key) { + return new ClusterManagerLbState(key, GracefulSwitchLoadBalancerFactory.INSTANCE); } @Override - protected Map createChildLbMap(ResolvedAddresses resolvedAddresses) { + protected Map createChildAddressesMap( + ResolvedAddresses resolvedAddresses) { + lastResolvedAddresses = resolvedAddresses; + ClusterManagerConfig config = (ClusterManagerConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); - Map newChildPolicies = new HashMap<>(); - if (config != null) { - for (Entry entry : config.childPolicies.entrySet()) { - ChildLbState child = getChildLbState(entry.getKey()); - if (child == null) { - child = new ClusterManagerLbState(entry.getKey(), - entry.getValue().getProvider(), entry.getValue().getConfig(), getInitialPicker()); - } - newChildPolicies.put(entry.getKey(), child); - } - } logger.log( XdsLogLevel.INFO, - "Received cluster_manager lb config: child names={0}", newChildPolicies.keySet()); - return newChildPolicies; - } + "Received cluster_manager lb config: child names={0}", config.childPolicies.keySet()); + Map childAddresses = new HashMap<>(); - /** - * This is like the parent except that it doesn't shutdown the removed children since we want that - * to be done by the timer. - */ - @Override - public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { - try { - resolvingAddresses = true; - - // process resolvedAddresses to update children - AcceptResolvedAddrRetVal acceptRetVal = - acceptResolvedAddressesInternal(resolvedAddresses); - if (!acceptRetVal.status.isOk()) { - return acceptRetVal.status; + // Reactivate children with config; deactivate children without config + for (ChildLbState rawState : getChildLbStates()) { + ClusterManagerLbState state = (ClusterManagerLbState) rawState; + if (config.childPolicies.containsKey(state.getKey())) { + // Active child + if (state.deletionTimer != null) { + state.reactivateChild(); + } + } else { + // Inactive child + if (state.deletionTimer == null) { + state.deactivateChild(); + } + if (state.deletionTimer.isPending()) { + childAddresses.put(state.getKey(), null); // Preserve child, without config update + } } + } - // Update the picker - updateOverallBalancingState(); - - return acceptRetVal.status; - } finally { - resolvingAddresses = false; + for (Map.Entry childPolicy : config.childPolicies.entrySet()) { + ResolvedAddresses addresses = resolvedAddresses.toBuilder() + .setLoadBalancingPolicyConfig(childPolicy.getValue()) + .build(); + childAddresses.put(childPolicy.getKey(), addresses); } + return childAddresses; } /** @@ -130,7 +124,7 @@ protected void updateOverallBalancingState() { ConnectivityState overallState = null; final Map childPickers = new HashMap<>(); for (ChildLbState childLbState : getChildLbStates()) { - if (childLbState.isDeactivated()) { + if (((ClusterManagerLbState) childLbState).deletionTimer != null) { continue; } childPickers.put(childLbState.getKey(), childLbState.getCurrentPicker()); @@ -171,13 +165,14 @@ public void handleNameResolutionError(Status error) { logger.log(XdsLogLevel.WARNING, "Received name resolution error: {0}", error); boolean gotoTransientFailure = true; for (ChildLbState state : getChildLbStates()) { - if (!state.isDeactivated()) { + if (((ClusterManagerLbState) state).deletionTimer == null) { gotoTransientFailure = false; - handleNameResolutionError(state, error); + state.getLb().handleNameResolutionError(error); } } if (gotoTransientFailure) { - getHelper().updateBalancingState(TRANSIENT_FAILURE, getErrorPicker(error)); + getHelper().updateBalancingState( + TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(error))); } } @@ -191,9 +186,8 @@ private class ClusterManagerLbState extends ChildLbState { @Nullable ScheduledHandle deletionTimer; - public ClusterManagerLbState(Object key, LoadBalancerProvider policyProvider, - Object childConfig, SubchannelPicker initialPicker) { - super(key, policyProvider, childConfig, initialPicker); + public ClusterManagerLbState(Object key, LoadBalancer.Factory policyFactory) { + super(key, policyFactory); } @Override @@ -203,34 +197,28 @@ protected ChildLbStateHelper createChildHelper() { @Override protected void shutdown() { - if (deletionTimer != null && deletionTimer.isPending()) { + if (deletionTimer != null) { deletionTimer.cancel(); + deletionTimer = null; } super.shutdown(); } - @Override - protected void reactivate(LoadBalancerProvider policyProvider) { - if (deletionTimer != null && deletionTimer.isPending()) { - deletionTimer.cancel(); - logger.log(XdsLogLevel.DEBUG, "Child balancer {0} reactivated", getKey()); - } - - super.reactivate(policyProvider); + void reactivateChild() { + assert deletionTimer != null; + deletionTimer.cancel(); + deletionTimer = null; + logger.log(XdsLogLevel.DEBUG, "Child balancer {0} reactivated", getKey()); } - @Override - protected void deactivate() { - if (isDeactivated()) { - return; - } + void deactivateChild() { + assert deletionTimer == null; class DeletionTask implements Runnable { @Override public void run() { - shutdown(); - removeChild(getKey()); + acceptResolvedAddresses(lastResolvedAddresses); } } @@ -240,7 +228,6 @@ public void run() { DELAYED_CHILD_DELETION_TIME_MINUTES, TimeUnit.MINUTES, timeService); - setDeactivated(); logger.log(XdsLogLevel.DEBUG, "Child balancer {0} deactivated", getKey()); } @@ -248,9 +235,7 @@ private class ClusterManagerChildHelper extends ChildLbStateHelper { @Override public void updateBalancingState(final ConnectivityState newState, final SubchannelPicker newPicker) { - // If we are already in the process of resolving addresses, the overall balancing state - // will be updated at the end of it, and we don't need to trigger that update here. - if (getChildLbState(getKey()) == null) { + if (getCurrentState() == ConnectivityState.SHUTDOWN) { return; } @@ -258,10 +243,21 @@ public void updateBalancingState(final ConnectivityState newState, // when the child instance exits deactivated state. setCurrentState(newState); setCurrentPicker(newPicker); - if (!isDeactivated() && !resolvingAddresses) { + // If we are already in the process of resolving addresses, the overall balancing state + // will be updated at the end of it, and we don't need to trigger that update here. + if (deletionTimer == null && !resolvingAddresses) { updateOverallBalancingState(); } } } } + + static final class GracefulSwitchLoadBalancerFactory extends LoadBalancer.Factory { + static final LoadBalancer.Factory INSTANCE = new GracefulSwitchLoadBalancerFactory(); + + @Override + public LoadBalancer newLoadBalancer(LoadBalancer.Helper helper) { + return new GracefulSwitchLoadBalancer(helper); + } + } } diff --git a/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancerProvider.java index 9c97d3fe966..7a7e16286f8 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancerProvider.java @@ -26,12 +26,9 @@ import io.grpc.NameResolver.ConfigOrError; import io.grpc.Status; import io.grpc.internal.JsonUtil; -import io.grpc.internal.ServiceConfigUtil; -import io.grpc.internal.ServiceConfigUtil.LbConfig; -import io.grpc.internal.ServiceConfigUtil.PolicySelection; +import io.grpc.util.GracefulSwitchLoadBalancer; import java.util.Collections; import java.util.LinkedHashMap; -import java.util.List; import java.util.Map; import java.util.Objects; import javax.annotation.Nullable; @@ -73,7 +70,7 @@ public String getPolicyName() { @Override public ConfigOrError parseLoadBalancingPolicyConfig(Map rawConfig) { - Map parsedChildPolicies = new LinkedHashMap<>(); + Map parsedChildPolicies = new LinkedHashMap<>(); try { Map childPolicies = JsonUtil.getObject(rawConfig, "childPolicy"); if (childPolicies == null || childPolicies.isEmpty()) { @@ -86,27 +83,19 @@ public ConfigOrError parseLoadBalancingPolicyConfig(Map rawConfig) { return ConfigOrError.fromError(Status.INTERNAL.withDescription( "No config for child " + name + " in cluster_manager LB policy: " + rawConfig)); } - List childConfigCandidates = - ServiceConfigUtil.unwrapLoadBalancingConfigList( - JsonUtil.getListOfObjects(childPolicy, "lbPolicy")); - if (childConfigCandidates == null || childConfigCandidates.isEmpty()) { - return ConfigOrError.fromError(Status.INTERNAL.withDescription( - "No config specified for child " + name + " in cluster_manager Lb policy: " - + rawConfig)); - } LoadBalancerRegistry registry = lbRegistry != null ? lbRegistry : LoadBalancerRegistry.getDefaultRegistry(); - ConfigOrError selectedConfig = - ServiceConfigUtil.selectLbPolicyFromList(childConfigCandidates, registry); - if (selectedConfig.getError() != null) { - Status error = selectedConfig.getError(); + ConfigOrError childConfig = GracefulSwitchLoadBalancer.parseLoadBalancingPolicyConfig( + JsonUtil.getListOfObjects(childPolicy, "lbPolicy"), registry); + if (childConfig.getError() != null) { + Status error = childConfig.getError(); return ConfigOrError.fromError( Status.INTERNAL .withCause(error.getCause()) .withDescription(error.getDescription()) - .augmentDescription("Failed to select config for child " + name)); + .augmentDescription("Failed to parse config for child " + name)); } - parsedChildPolicies.put(name, (PolicySelection) selectedConfig.getConfig()); + parsedChildPolicies.put(name, childConfig.getConfig()); } } catch (RuntimeException e) { return ConfigOrError.fromError( @@ -122,9 +111,9 @@ public LoadBalancer newLoadBalancer(Helper helper) { } static class ClusterManagerConfig { - final Map childPolicies; + final Map childPolicies; - ClusterManagerConfig(Map childPolicies) { + ClusterManagerConfig(Map childPolicies) { this.childPolicies = Collections.unmodifiableMap(childPolicies); } diff --git a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java deleted file mode 100644 index 3ba08a23fbc..00000000000 --- a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java +++ /dev/null @@ -1,849 +0,0 @@ -/* - * Copyright 2020 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds; - -import static com.google.common.base.Preconditions.checkNotNull; -import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; -import static io.grpc.xds.XdsLbPolicies.PRIORITY_POLICY_NAME; - -import com.google.common.annotations.VisibleForTesting; -import io.grpc.Attributes; -import io.grpc.EquivalentAddressGroup; -import io.grpc.InternalLogId; -import io.grpc.LoadBalancer; -import io.grpc.LoadBalancerProvider; -import io.grpc.LoadBalancerRegistry; -import io.grpc.NameResolver; -import io.grpc.NameResolver.ResolutionResult; -import io.grpc.Status; -import io.grpc.SynchronizationContext; -import io.grpc.SynchronizationContext.ScheduledHandle; -import io.grpc.internal.BackoffPolicy; -import io.grpc.internal.ExponentialBackoffPolicy; -import io.grpc.internal.ObjectPool; -import io.grpc.internal.ServiceConfigUtil.PolicySelection; -import io.grpc.util.ForwardingLoadBalancerHelper; -import io.grpc.util.GracefulSwitchLoadBalancer; -import io.grpc.util.OutlierDetectionLoadBalancer.OutlierDetectionLoadBalancerConfig; -import io.grpc.xds.ClusterImplLoadBalancerProvider.ClusterImplConfig; -import io.grpc.xds.ClusterResolverLoadBalancerProvider.ClusterResolverConfig; -import io.grpc.xds.ClusterResolverLoadBalancerProvider.ClusterResolverConfig.DiscoveryMechanism; -import io.grpc.xds.Endpoints.DropOverload; -import io.grpc.xds.Endpoints.LbEndpoint; -import io.grpc.xds.Endpoints.LocalityLbEndpoints; -import io.grpc.xds.EnvoyServerProtoData.FailurePercentageEjection; -import io.grpc.xds.EnvoyServerProtoData.OutlierDetection; -import io.grpc.xds.EnvoyServerProtoData.SuccessRateEjection; -import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; -import io.grpc.xds.PriorityLoadBalancerProvider.PriorityLbConfig; -import io.grpc.xds.PriorityLoadBalancerProvider.PriorityLbConfig.PriorityChildConfig; -import io.grpc.xds.XdsEndpointResource.EdsUpdate; -import io.grpc.xds.client.Bootstrapper.ServerInfo; -import io.grpc.xds.client.Locality; -import io.grpc.xds.client.XdsClient; -import io.grpc.xds.client.XdsClient.ResourceWatcher; -import io.grpc.xds.client.XdsLogger; -import io.grpc.xds.client.XdsLogger.XdsLogLevel; -import java.net.URI; -import java.net.URISyntaxException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.Objects; -import java.util.Set; -import java.util.TreeMap; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.TimeUnit; -import javax.annotation.Nullable; - -/** - * Load balancer for cluster_resolver_experimental LB policy. This LB policy is the child LB policy - * of the cds_experimental LB policy and the parent LB policy of the priority_experimental LB - * policy in the xDS load balancing hierarchy. This policy resolves endpoints of non-aggregate - * clusters (e.g., EDS or Logical DNS) and groups endpoints in priorities and localities to be - * used in the downstream LB policies for fine-grained load balancing purposes. - */ -final class ClusterResolverLoadBalancer extends LoadBalancer { - // DNS-resolved endpoints do not have the definition of the locality it belongs to, just hardcode - // to an empty locality. - private static final Locality LOGICAL_DNS_CLUSTER_LOCALITY = Locality.create("", "", ""); - private final XdsLogger logger; - private final SynchronizationContext syncContext; - private final ScheduledExecutorService timeService; - private final LoadBalancerRegistry lbRegistry; - private final BackoffPolicy.Provider backoffPolicyProvider; - private final GracefulSwitchLoadBalancer delegate; - private ObjectPool xdsClientPool; - private XdsClient xdsClient; - private ClusterResolverConfig config; - - ClusterResolverLoadBalancer(Helper helper) { - this(helper, LoadBalancerRegistry.getDefaultRegistry(), - new ExponentialBackoffPolicy.Provider()); - } - - @VisibleForTesting - ClusterResolverLoadBalancer(Helper helper, LoadBalancerRegistry lbRegistry, - BackoffPolicy.Provider backoffPolicyProvider) { - this.lbRegistry = checkNotNull(lbRegistry, "lbRegistry"); - this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider"); - this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext"); - this.timeService = checkNotNull(helper.getScheduledExecutorService(), "timeService"); - delegate = new GracefulSwitchLoadBalancer(helper); - logger = XdsLogger.withLogId( - InternalLogId.allocate("cluster-resolver-lb", helper.getAuthority())); - logger.log(XdsLogLevel.INFO, "Created"); - } - - @Override - public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { - logger.log(XdsLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses); - if (xdsClientPool == null) { - xdsClientPool = resolvedAddresses.getAttributes().get(InternalXdsAttributes.XDS_CLIENT_POOL); - xdsClient = xdsClientPool.getObject(); - } - ClusterResolverConfig config = - (ClusterResolverConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); - if (!Objects.equals(this.config, config)) { - logger.log(XdsLogLevel.DEBUG, "Config: {0}", config); - delegate.switchTo(new ClusterResolverLbStateFactory()); - this.config = config; - delegate.handleResolvedAddresses(resolvedAddresses); - } - return Status.OK; - } - - @Override - public void handleNameResolutionError(Status error) { - logger.log(XdsLogLevel.WARNING, "Received name resolution error: {0}", error); - delegate.handleNameResolutionError(error); - } - - @Override - public void shutdown() { - logger.log(XdsLogLevel.INFO, "Shutdown"); - delegate.shutdown(); - if (xdsClientPool != null) { - xdsClientPool.returnObject(xdsClient); - } - } - - private final class ClusterResolverLbStateFactory extends LoadBalancer.Factory { - @Override - public LoadBalancer newLoadBalancer(Helper helper) { - return new ClusterResolverLbState(helper); - } - } - - /** - * The state of a cluster_resolver LB working session. A new instance is created whenever - * the cluster_resolver LB receives a new config. The old instance is replaced when the - * new one is ready to handle new RPCs. - */ - private final class ClusterResolverLbState extends LoadBalancer { - private final Helper helper; - private final List clusters = new ArrayList<>(); - private final Map clusterStates = new HashMap<>(); - private PolicySelection endpointLbPolicy; - private ResolvedAddresses resolvedAddresses; - private LoadBalancer childLb; - - ClusterResolverLbState(Helper helper) { - this.helper = new RefreshableHelper(checkNotNull(helper, "helper")); - logger.log(XdsLogLevel.DEBUG, "New ClusterResolverLbState"); - } - - @Override - public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { - this.resolvedAddresses = resolvedAddresses; - ClusterResolverConfig config = - (ClusterResolverConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); - endpointLbPolicy = config.lbPolicy; - for (DiscoveryMechanism instance : config.discoveryMechanisms) { - clusters.add(instance.cluster); - ClusterState state; - if (instance.type == DiscoveryMechanism.Type.EDS) { - state = new EdsClusterState(instance.cluster, instance.edsServiceName, - instance.lrsServerInfo, instance.maxConcurrentRequests, instance.tlsContext, - instance.outlierDetection); - } else { // logical DNS - state = new LogicalDnsClusterState(instance.cluster, instance.dnsHostName, - instance.lrsServerInfo, instance.maxConcurrentRequests, instance.tlsContext); - } - clusterStates.put(instance.cluster, state); - state.start(); - } - return Status.OK; - } - - @Override - public void handleNameResolutionError(Status error) { - if (childLb != null) { - childLb.handleNameResolutionError(error); - } else { - helper.updateBalancingState( - TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(error))); - } - } - - @Override - public void shutdown() { - for (ClusterState state : clusterStates.values()) { - state.shutdown(); - } - if (childLb != null) { - childLb.shutdown(); - } - } - - private void handleEndpointResourceUpdate() { - List addresses = new ArrayList<>(); - Map priorityChildConfigs = new HashMap<>(); - List priorities = new ArrayList<>(); // totally ordered priority list - - Status endpointNotFound = Status.OK; - for (String cluster : clusters) { - ClusterState state = clusterStates.get(cluster); - // Propagate endpoints to the child LB policy only after all clusters have been resolved. - if (!state.resolved && state.status.isOk()) { - return; - } - if (state.result != null) { - addresses.addAll(state.result.addresses); - priorityChildConfigs.putAll(state.result.priorityChildConfigs); - priorities.addAll(state.result.priorities); - } else { - endpointNotFound = state.status; - } - } - if (addresses.isEmpty()) { - if (endpointNotFound.isOk()) { - endpointNotFound = Status.UNAVAILABLE.withDescription( - "No usable endpoint from cluster(s): " + clusters); - } else { - endpointNotFound = - Status.UNAVAILABLE.withCause(endpointNotFound.getCause()) - .withDescription(endpointNotFound.getDescription()); - } - helper.updateBalancingState( - TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(endpointNotFound))); - if (childLb != null) { - childLb.shutdown(); - childLb = null; - } - return; - } - PriorityLbConfig childConfig = - new PriorityLbConfig(Collections.unmodifiableMap(priorityChildConfigs), - Collections.unmodifiableList(priorities)); - if (childLb == null) { - childLb = lbRegistry.getProvider(PRIORITY_POLICY_NAME).newLoadBalancer(helper); - } - childLb.handleResolvedAddresses( - resolvedAddresses.toBuilder() - .setLoadBalancingPolicyConfig(childConfig) - .setAddresses(Collections.unmodifiableList(addresses)) - .build()); - } - - private void handleEndpointResolutionError() { - boolean allInError = true; - Status error = null; - for (String cluster : clusters) { - ClusterState state = clusterStates.get(cluster); - if (state.status.isOk()) { - allInError = false; - } else { - error = state.status; - } - } - if (allInError) { - if (childLb != null) { - childLb.handleNameResolutionError(error); - } else { - helper.updateBalancingState( - TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(error))); - } - } - } - - /** - * Wires re-resolution requests from downstream LB policies with DNS resolver. - */ - private final class RefreshableHelper extends ForwardingLoadBalancerHelper { - private final Helper delegate; - - private RefreshableHelper(Helper delegate) { - this.delegate = checkNotNull(delegate, "delegate"); - } - - @Override - public void refreshNameResolution() { - for (ClusterState state : clusterStates.values()) { - if (state instanceof LogicalDnsClusterState) { - ((LogicalDnsClusterState) state).refresh(); - } - } - } - - @Override - protected Helper delegate() { - return delegate; - } - } - - /** - * Resolution state of an underlying cluster. - */ - private abstract class ClusterState { - // Name of the cluster to be resolved. - protected final String name; - @Nullable - protected final ServerInfo lrsServerInfo; - @Nullable - protected final Long maxConcurrentRequests; - @Nullable - protected final UpstreamTlsContext tlsContext; - @Nullable - protected final OutlierDetection outlierDetection; - // Resolution status, may contain most recent error encountered. - protected Status status = Status.OK; - // True if has received resolution result. - protected boolean resolved; - // Most recently resolved addresses and config, or null if resource not exists. - @Nullable - protected ClusterResolutionResult result; - - protected boolean shutdown; - - private ClusterState(String name, @Nullable ServerInfo lrsServerInfo, - @Nullable Long maxConcurrentRequests, @Nullable UpstreamTlsContext tlsContext, - @Nullable OutlierDetection outlierDetection) { - this.name = name; - this.lrsServerInfo = lrsServerInfo; - this.maxConcurrentRequests = maxConcurrentRequests; - this.tlsContext = tlsContext; - this.outlierDetection = outlierDetection; - } - - abstract void start(); - - void shutdown() { - shutdown = true; - } - } - - private final class EdsClusterState extends ClusterState implements ResourceWatcher { - @Nullable - private final String edsServiceName; - private Map localityPriorityNames = Collections.emptyMap(); - int priorityNameGenId = 1; - - private EdsClusterState(String name, @Nullable String edsServiceName, - @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests, - @Nullable UpstreamTlsContext tlsContext, @Nullable OutlierDetection outlierDetection) { - super(name, lrsServerInfo, maxConcurrentRequests, tlsContext, outlierDetection); - this.edsServiceName = edsServiceName; - } - - @Override - void start() { - String resourceName = edsServiceName != null ? edsServiceName : name; - logger.log(XdsLogLevel.INFO, "Start watching EDS resource {0}", resourceName); - xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), - resourceName, this, syncContext); - } - - @Override - protected void shutdown() { - super.shutdown(); - String resourceName = edsServiceName != null ? edsServiceName : name; - logger.log(XdsLogLevel.INFO, "Stop watching EDS resource {0}", resourceName); - xdsClient.cancelXdsResourceWatch(XdsEndpointResource.getInstance(), resourceName, this); - } - - @Override - public void onChanged(final EdsUpdate update) { - class EndpointsUpdated implements Runnable { - @Override - public void run() { - if (shutdown) { - return; - } - logger.log(XdsLogLevel.DEBUG, "Received endpoint update {0}", update); - if (logger.isLoggable(XdsLogLevel.INFO)) { - logger.log(XdsLogLevel.INFO, "Cluster {0}: {1} localities, {2} drop categories", - update.clusterName, update.localityLbEndpointsMap.size(), - update.dropPolicies.size()); - } - Map localityLbEndpoints = - update.localityLbEndpointsMap; - List dropOverloads = update.dropPolicies; - List addresses = new ArrayList<>(); - Map> prioritizedLocalityWeights = new HashMap<>(); - List sortedPriorityNames = generatePriorityNames(name, localityLbEndpoints); - for (Locality locality : localityLbEndpoints.keySet()) { - LocalityLbEndpoints localityLbInfo = localityLbEndpoints.get(locality); - String priorityName = localityPriorityNames.get(locality); - boolean discard = true; - for (LbEndpoint endpoint : localityLbInfo.endpoints()) { - if (endpoint.isHealthy()) { - discard = false; - long weight = localityLbInfo.localityWeight(); - if (endpoint.loadBalancingWeight() != 0) { - weight *= endpoint.loadBalancingWeight(); - } - Attributes attr = - endpoint.eag().getAttributes().toBuilder() - .set(InternalXdsAttributes.ATTR_LOCALITY, locality) - .set(InternalXdsAttributes.ATTR_LOCALITY_WEIGHT, - localityLbInfo.localityWeight()) - .set(InternalXdsAttributes.ATTR_SERVER_WEIGHT, weight) - .build(); - EquivalentAddressGroup eag = new EquivalentAddressGroup( - endpoint.eag().getAddresses(), attr); - eag = AddressFilter.setPathFilter( - eag, Arrays.asList(priorityName, localityName(locality))); - addresses.add(eag); - } - } - if (discard) { - logger.log(XdsLogLevel.INFO, - "Discard locality {0} with 0 healthy endpoints", locality); - continue; - } - if (!prioritizedLocalityWeights.containsKey(priorityName)) { - prioritizedLocalityWeights.put(priorityName, new HashMap()); - } - prioritizedLocalityWeights.get(priorityName).put( - locality, localityLbInfo.localityWeight()); - } - if (prioritizedLocalityWeights.isEmpty()) { - // Will still update the result, as if the cluster resource is revoked. - logger.log(XdsLogLevel.INFO, - "Cluster {0} has no usable priority/locality/endpoint", update.clusterName); - } - sortedPriorityNames.retainAll(prioritizedLocalityWeights.keySet()); - Map priorityChildConfigs = - generateEdsBasedPriorityChildConfigs( - name, edsServiceName, lrsServerInfo, maxConcurrentRequests, tlsContext, - outlierDetection, endpointLbPolicy, lbRegistry, prioritizedLocalityWeights, - dropOverloads); - status = Status.OK; - resolved = true; - result = new ClusterResolutionResult(addresses, priorityChildConfigs, - sortedPriorityNames); - handleEndpointResourceUpdate(); - } - } - - new EndpointsUpdated().run(); - } - - private List generatePriorityNames(String name, - Map localityLbEndpoints) { - TreeMap> todo = new TreeMap<>(); - for (Locality locality : localityLbEndpoints.keySet()) { - int priority = localityLbEndpoints.get(locality).priority(); - if (!todo.containsKey(priority)) { - todo.put(priority, new ArrayList<>()); - } - todo.get(priority).add(locality); - } - Map newNames = new HashMap<>(); - Set usedNames = new HashSet<>(); - List ret = new ArrayList<>(); - for (Integer priority: todo.keySet()) { - String foundName = ""; - for (Locality locality : todo.get(priority)) { - if (localityPriorityNames.containsKey(locality) - && usedNames.add(localityPriorityNames.get(locality))) { - foundName = localityPriorityNames.get(locality); - break; - } - } - if ("".equals(foundName)) { - foundName = String.format(Locale.US, "%s[child%d]", name, priorityNameGenId++); - } - for (Locality locality : todo.get(priority)) { - newNames.put(locality, foundName); - } - ret.add(foundName); - } - localityPriorityNames = newNames; - return ret; - } - - @Override - public void onResourceDoesNotExist(final String resourceName) { - if (shutdown) { - return; - } - logger.log(XdsLogLevel.INFO, "Resource {0} unavailable", resourceName); - status = Status.OK; - resolved = true; - result = null; // resource revoked - handleEndpointResourceUpdate(); - } - - @Override - public void onError(final Status error) { - if (shutdown) { - return; - } - String resourceName = edsServiceName != null ? edsServiceName : name; - status = Status.UNAVAILABLE - .withDescription(String.format("Unable to load EDS %s. xDS server returned: %s: %s", - resourceName, error.getCode(), error.getDescription())) - .withCause(error.getCause()); - logger.log(XdsLogLevel.WARNING, "Received EDS error: {0}", error); - handleEndpointResolutionError(); - } - } - - private final class LogicalDnsClusterState extends ClusterState { - private final String dnsHostName; - private final NameResolver.Factory nameResolverFactory; - private final NameResolver.Args nameResolverArgs; - private NameResolver resolver; - @Nullable - private BackoffPolicy backoffPolicy; - @Nullable - private ScheduledHandle scheduledRefresh; - - private LogicalDnsClusterState(String name, String dnsHostName, - @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests, - @Nullable UpstreamTlsContext tlsContext) { - super(name, lrsServerInfo, maxConcurrentRequests, tlsContext, null); - this.dnsHostName = checkNotNull(dnsHostName, "dnsHostName"); - nameResolverFactory = - checkNotNull(helper.getNameResolverRegistry().asFactory(), "nameResolverFactory"); - nameResolverArgs = checkNotNull(helper.getNameResolverArgs(), "nameResolverArgs"); - } - - @Override - void start() { - URI uri; - try { - uri = new URI("dns", "", "/" + dnsHostName, null); - } catch (URISyntaxException e) { - status = Status.INTERNAL.withDescription( - "Bug, invalid URI creation: " + dnsHostName).withCause(e); - handleEndpointResolutionError(); - return; - } - resolver = nameResolverFactory.newNameResolver(uri, nameResolverArgs); - if (resolver == null) { - status = Status.INTERNAL.withDescription("Xds cluster resolver lb for logical DNS " - + "cluster [" + name + "] cannot find DNS resolver with uri:" + uri); - handleEndpointResolutionError(); - return; - } - resolver.start(new NameResolverListener()); - } - - void refresh() { - if (resolver == null) { - return; - } - cancelBackoff(); - resolver.refresh(); - } - - @Override - void shutdown() { - super.shutdown(); - if (resolver != null) { - resolver.shutdown(); - } - cancelBackoff(); - } - - private void cancelBackoff() { - if (scheduledRefresh != null) { - scheduledRefresh.cancel(); - scheduledRefresh = null; - backoffPolicy = null; - } - } - - private class DelayedNameResolverRefresh implements Runnable { - @Override - public void run() { - scheduledRefresh = null; - if (!shutdown) { - resolver.refresh(); - } - } - } - - private class NameResolverListener extends NameResolver.Listener2 { - @Override - public void onResult(final ResolutionResult resolutionResult) { - class NameResolved implements Runnable { - @Override - public void run() { - if (shutdown) { - return; - } - backoffPolicy = null; // reset backoff sequence if succeeded - // Arbitrary priority notation for all DNS-resolved endpoints. - String priorityName = priorityName(name, 0); // value doesn't matter - List addresses = new ArrayList<>(); - for (EquivalentAddressGroup eag : resolutionResult.getAddresses()) { - // No weight attribute is attached, all endpoint-level LB policy should be able - // to handle such it. - Attributes attr = eag.getAttributes().toBuilder().set( - InternalXdsAttributes.ATTR_LOCALITY, LOGICAL_DNS_CLUSTER_LOCALITY).build(); - eag = new EquivalentAddressGroup(eag.getAddresses(), attr); - eag = AddressFilter.setPathFilter( - eag, Arrays.asList(priorityName, LOGICAL_DNS_CLUSTER_LOCALITY.toString())); - addresses.add(eag); - } - PriorityChildConfig priorityChildConfig = generateDnsBasedPriorityChildConfig( - name, lrsServerInfo, maxConcurrentRequests, tlsContext, lbRegistry, - Collections.emptyList()); - status = Status.OK; - resolved = true; - result = new ClusterResolutionResult(addresses, priorityName, priorityChildConfig); - handleEndpointResourceUpdate(); - } - } - - syncContext.execute(new NameResolved()); - } - - @Override - public void onError(final Status error) { - syncContext.execute(new Runnable() { - @Override - public void run() { - if (shutdown) { - return; - } - status = error; - // NameResolver.Listener API cannot distinguish between address-not-found and - // transient errors. If the error occurs in the first resolution, treat it as - // address not found. Otherwise, either there is previously resolved addresses - // previously encountered error, propagate the error to downstream/upstream and - // let downstream/upstream handle it. - if (!resolved) { - resolved = true; - handleEndpointResourceUpdate(); - } else { - handleEndpointResolutionError(); - } - if (scheduledRefresh != null && scheduledRefresh.isPending()) { - return; - } - if (backoffPolicy == null) { - backoffPolicy = backoffPolicyProvider.get(); - } - long delayNanos = backoffPolicy.nextBackoffNanos(); - logger.log(XdsLogLevel.DEBUG, - "Logical DNS resolver for cluster {0} encountered name resolution " - + "error: {1}, scheduling DNS resolution backoff for {2} ns", - name, error, delayNanos); - scheduledRefresh = - syncContext.schedule( - new DelayedNameResolverRefresh(), delayNanos, TimeUnit.NANOSECONDS, - timeService); - } - }); - } - } - } - } - - private static class ClusterResolutionResult { - // Endpoint addresses. - private final List addresses; - // Config (include load balancing policy/config) for each priority in the cluster. - private final Map priorityChildConfigs; - // List of priority names ordered in descending priorities. - private final List priorities; - - ClusterResolutionResult(List addresses, String priority, - PriorityChildConfig config) { - this(addresses, Collections.singletonMap(priority, config), - Collections.singletonList(priority)); - } - - ClusterResolutionResult(List addresses, - Map configs, List priorities) { - this.addresses = addresses; - this.priorityChildConfigs = configs; - this.priorities = priorities; - } - } - - /** - * Generates the config to be used in the priority LB policy for the single priority of - * logical DNS cluster. - * - *

priority LB -> cluster_impl LB (single hardcoded priority) -> pick_first - */ - private static PriorityChildConfig generateDnsBasedPriorityChildConfig( - String cluster, @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests, - @Nullable UpstreamTlsContext tlsContext, LoadBalancerRegistry lbRegistry, - List dropOverloads) { - // Override endpoint-level LB policy with pick_first for logical DNS cluster. - PolicySelection endpointLbPolicy = - new PolicySelection(lbRegistry.getProvider("pick_first"), null); - ClusterImplConfig clusterImplConfig = - new ClusterImplConfig(cluster, null, lrsServerInfo, maxConcurrentRequests, - dropOverloads, endpointLbPolicy, tlsContext); - LoadBalancerProvider clusterImplLbProvider = - lbRegistry.getProvider(XdsLbPolicies.CLUSTER_IMPL_POLICY_NAME); - PolicySelection clusterImplPolicy = - new PolicySelection(clusterImplLbProvider, clusterImplConfig); - return new PriorityChildConfig(clusterImplPolicy, false /* ignoreReresolution*/); - } - - /** - * Generates configs to be used in the priority LB policy for priorities in an EDS cluster. - * - *

priority LB -> cluster_impl LB (one per priority) -> (weighted_target LB - * -> round_robin / least_request_experimental (one per locality)) / ring_hash_experimental - */ - private static Map generateEdsBasedPriorityChildConfigs( - String cluster, @Nullable String edsServiceName, @Nullable ServerInfo lrsServerInfo, - @Nullable Long maxConcurrentRequests, @Nullable UpstreamTlsContext tlsContext, - @Nullable OutlierDetection outlierDetection, PolicySelection endpointLbPolicy, - LoadBalancerRegistry lbRegistry, Map> prioritizedLocalityWeights, List dropOverloads) { - Map configs = new HashMap<>(); - for (String priority : prioritizedLocalityWeights.keySet()) { - ClusterImplConfig clusterImplConfig = - new ClusterImplConfig(cluster, edsServiceName, lrsServerInfo, maxConcurrentRequests, - dropOverloads, endpointLbPolicy, tlsContext); - LoadBalancerProvider clusterImplLbProvider = - lbRegistry.getProvider(XdsLbPolicies.CLUSTER_IMPL_POLICY_NAME); - PolicySelection priorityChildPolicy = - new PolicySelection(clusterImplLbProvider, clusterImplConfig); - - // If outlier detection has been configured we wrap the child policy in the outlier detection - // load balancer. - if (outlierDetection != null) { - LoadBalancerProvider outlierDetectionProvider = lbRegistry.getProvider( - "outlier_detection_experimental"); - priorityChildPolicy = new PolicySelection(outlierDetectionProvider, - buildOutlierDetectionLbConfig(outlierDetection, priorityChildPolicy)); - } - - PriorityChildConfig priorityChildConfig = - new PriorityChildConfig(priorityChildPolicy, true /* ignoreReresolution */); - configs.put(priority, priorityChildConfig); - } - return configs; - } - - /** - * Converts {@link OutlierDetection} that represents the xDS configuration to {@link - * OutlierDetectionLoadBalancerConfig} that the {@link io.grpc.util.OutlierDetectionLoadBalancer} - * understands. - */ - private static OutlierDetectionLoadBalancerConfig buildOutlierDetectionLbConfig( - OutlierDetection outlierDetection, PolicySelection childPolicy) { - OutlierDetectionLoadBalancerConfig.Builder configBuilder - = new OutlierDetectionLoadBalancerConfig.Builder(); - - configBuilder.setChildPolicy(childPolicy); - - if (outlierDetection.intervalNanos() != null) { - configBuilder.setIntervalNanos(outlierDetection.intervalNanos()); - } - if (outlierDetection.baseEjectionTimeNanos() != null) { - configBuilder.setBaseEjectionTimeNanos(outlierDetection.baseEjectionTimeNanos()); - } - if (outlierDetection.maxEjectionTimeNanos() != null) { - configBuilder.setMaxEjectionTimeNanos(outlierDetection.maxEjectionTimeNanos()); - } - if (outlierDetection.maxEjectionPercent() != null) { - configBuilder.setMaxEjectionPercent(outlierDetection.maxEjectionPercent()); - } - - SuccessRateEjection successRate = outlierDetection.successRateEjection(); - if (successRate != null) { - OutlierDetectionLoadBalancerConfig.SuccessRateEjection.Builder - successRateConfigBuilder = new OutlierDetectionLoadBalancerConfig - .SuccessRateEjection.Builder(); - - if (successRate.stdevFactor() != null) { - successRateConfigBuilder.setStdevFactor(successRate.stdevFactor()); - } - if (successRate.enforcementPercentage() != null) { - successRateConfigBuilder.setEnforcementPercentage(successRate.enforcementPercentage()); - } - if (successRate.minimumHosts() != null) { - successRateConfigBuilder.setMinimumHosts(successRate.minimumHosts()); - } - if (successRate.requestVolume() != null) { - successRateConfigBuilder.setRequestVolume(successRate.requestVolume()); - } - - configBuilder.setSuccessRateEjection(successRateConfigBuilder.build()); - } - - FailurePercentageEjection failurePercentage = outlierDetection.failurePercentageEjection(); - if (failurePercentage != null) { - OutlierDetectionLoadBalancerConfig.FailurePercentageEjection.Builder - failurePercentageConfigBuilder = new OutlierDetectionLoadBalancerConfig - .FailurePercentageEjection.Builder(); - - if (failurePercentage.threshold() != null) { - failurePercentageConfigBuilder.setThreshold(failurePercentage.threshold()); - } - if (failurePercentage.enforcementPercentage() != null) { - failurePercentageConfigBuilder.setEnforcementPercentage( - failurePercentage.enforcementPercentage()); - } - if (failurePercentage.minimumHosts() != null) { - failurePercentageConfigBuilder.setMinimumHosts(failurePercentage.minimumHosts()); - } - if (failurePercentage.requestVolume() != null) { - failurePercentageConfigBuilder.setRequestVolume(failurePercentage.requestVolume()); - } - - configBuilder.setFailurePercentageEjection(failurePercentageConfigBuilder.build()); - } - - return configBuilder.build(); - } - - /** - * Generates a string that represents the priority in the LB policy config. The string is unique - * across priorities in all clusters and priorityName(c, p1) < priorityName(c, p2) iff p1 < p2. - * The ordering is undefined for priorities in different clusters. - */ - private static String priorityName(String cluster, int priority) { - return cluster + "[child" + priority + "]"; - } - - /** - * Generates a string that represents the locality in the LB policy config. The string is unique - * across all localities in all clusters. - */ - private static String localityName(Locality locality) { - return locality.toString(); - } -} diff --git a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancerProvider.java deleted file mode 100644 index 6488a719a1b..00000000000 --- a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancerProvider.java +++ /dev/null @@ -1,206 +0,0 @@ -/* - * Copyright 2020 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds; - -import static com.google.common.base.Preconditions.checkNotNull; - -import com.google.common.base.MoreObjects; -import io.grpc.Internal; -import io.grpc.LoadBalancer; -import io.grpc.LoadBalancer.Helper; -import io.grpc.LoadBalancerProvider; -import io.grpc.NameResolver.ConfigOrError; -import io.grpc.Status; -import io.grpc.internal.ServiceConfigUtil.PolicySelection; -import io.grpc.xds.EnvoyServerProtoData.OutlierDetection; -import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; -import io.grpc.xds.client.Bootstrapper.ServerInfo; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import javax.annotation.Nullable; - -/** - * The provider for the cluster_resolver load balancing policy. This class should not be directly - * referenced in code. The policy should be accessed through - * {@link io.grpc.LoadBalancerRegistry#getProvider} with the name "cluster_resolver_experimental". - */ -@Internal -public final class ClusterResolverLoadBalancerProvider extends LoadBalancerProvider { - - @Override - public boolean isAvailable() { - return true; - } - - @Override - public int getPriority() { - return 5; - } - - @Override - public String getPolicyName() { - return XdsLbPolicies.CLUSTER_RESOLVER_POLICY_NAME; - } - - @Override - public ConfigOrError parseLoadBalancingPolicyConfig(Map rawLoadBalancingPolicyConfig) { - return ConfigOrError.fromError( - Status.INTERNAL.withDescription(getPolicyName() + " cannot be used from service config")); - } - - @Override - public LoadBalancer newLoadBalancer(Helper helper) { - return new ClusterResolverLoadBalancer(helper); - } - - static final class ClusterResolverConfig { - // Ordered list of clusters to be resolved. - final List discoveryMechanisms; - // Endpoint-level load balancing policy with config - // (round_robin, least_request_experimental or ring_hash_experimental). - final PolicySelection lbPolicy; - - ClusterResolverConfig(List discoveryMechanisms, PolicySelection lbPolicy) { - this.discoveryMechanisms = checkNotNull(discoveryMechanisms, "discoveryMechanisms"); - this.lbPolicy = checkNotNull(lbPolicy, "lbPolicy"); - } - - @Override - public int hashCode() { - return Objects.hash(discoveryMechanisms, lbPolicy); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - ClusterResolverConfig that = (ClusterResolverConfig) o; - return discoveryMechanisms.equals(that.discoveryMechanisms) - && lbPolicy.equals(that.lbPolicy); - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(this) - .add("discoveryMechanisms", discoveryMechanisms) - .add("lbPolicy", lbPolicy) - .toString(); - } - - // Describes the mechanism for a specific cluster. - static final class DiscoveryMechanism { - // Name of the cluster to resolve. - final String cluster; - // Type of the cluster. - final Type type; - // Load reporting server info. Null if not enabled. - @Nullable - final ServerInfo lrsServerInfo; - // Cluster-level max concurrent request threshold. Null if not specified. - @Nullable - final Long maxConcurrentRequests; - // TLS context for connections to endpoints in the cluster. - @Nullable - final UpstreamTlsContext tlsContext; - // Resource name for resolving endpoints via EDS. Only valid for EDS clusters. - @Nullable - final String edsServiceName; - // Hostname for resolving endpoints via DNS. Only valid for LOGICAL_DNS clusters. - @Nullable - final String dnsHostName; - @Nullable - final OutlierDetection outlierDetection; - - enum Type { - EDS, - LOGICAL_DNS, - } - - private DiscoveryMechanism(String cluster, Type type, @Nullable String edsServiceName, - @Nullable String dnsHostName, @Nullable ServerInfo lrsServerInfo, - @Nullable Long maxConcurrentRequests, @Nullable UpstreamTlsContext tlsContext, - @Nullable OutlierDetection outlierDetection) { - this.cluster = checkNotNull(cluster, "cluster"); - this.type = checkNotNull(type, "type"); - this.edsServiceName = edsServiceName; - this.dnsHostName = dnsHostName; - this.lrsServerInfo = lrsServerInfo; - this.maxConcurrentRequests = maxConcurrentRequests; - this.tlsContext = tlsContext; - this.outlierDetection = outlierDetection; - } - - static DiscoveryMechanism forEds(String cluster, @Nullable String edsServiceName, - @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests, - @Nullable UpstreamTlsContext tlsContext, - OutlierDetection outlierDetection) { - return new DiscoveryMechanism(cluster, Type.EDS, edsServiceName, null, lrsServerInfo, - maxConcurrentRequests, tlsContext, outlierDetection); - } - - static DiscoveryMechanism forLogicalDns(String cluster, String dnsHostName, - @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests, - @Nullable UpstreamTlsContext tlsContext) { - return new DiscoveryMechanism(cluster, Type.LOGICAL_DNS, null, dnsHostName, - lrsServerInfo, maxConcurrentRequests, tlsContext, null); - } - - @Override - public int hashCode() { - return Objects.hash(cluster, type, lrsServerInfo, maxConcurrentRequests, tlsContext, - edsServiceName, dnsHostName); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - DiscoveryMechanism that = (DiscoveryMechanism) o; - return cluster.equals(that.cluster) - && type == that.type - && Objects.equals(edsServiceName, that.edsServiceName) - && Objects.equals(dnsHostName, that.dnsHostName) - && Objects.equals(lrsServerInfo, that.lrsServerInfo) - && Objects.equals(maxConcurrentRequests, that.maxConcurrentRequests) - && Objects.equals(tlsContext, that.tlsContext); - } - - @Override - public String toString() { - MoreObjects.ToStringHelper toStringHelper = - MoreObjects.toStringHelper(this) - .add("cluster", cluster) - .add("type", type) - .add("edsServiceName", edsServiceName) - .add("dnsHostName", dnsHostName) - .add("lrsServerInfo", lrsServerInfo) - // Exclude tlsContext as its string representation is cumbersome. - .add("maxConcurrentRequests", maxConcurrentRequests); - return toStringHelper.toString(); - } - } - } -} diff --git a/xds/src/main/java/io/grpc/xds/CsdsService.java b/xds/src/main/java/io/grpc/xds/CsdsService.java index 69aee71f17f..8c2fe333c15 100644 --- a/xds/src/main/java/io/grpc/xds/CsdsService.java +++ b/xds/src/main/java/io/grpc/xds/CsdsService.java @@ -28,7 +28,8 @@ import io.envoyproxy.envoy.service.status.v3.ClientStatusDiscoveryServiceGrpc; import io.envoyproxy.envoy.service.status.v3.ClientStatusRequest; import io.envoyproxy.envoy.service.status.v3.ClientStatusResponse; -import io.grpc.ExperimentalApi; +import io.grpc.BindableService; +import io.grpc.ServerServiceDefinition; import io.grpc.Status; import io.grpc.StatusException; import io.grpc.internal.ObjectPool; @@ -38,6 +39,8 @@ import io.grpc.xds.client.XdsClient.ResourceMetadata.ResourceMetadataStatus; import io.grpc.xds.client.XdsClient.ResourceMetadata.UpdateFailureState; import io.grpc.xds.client.XdsResourceType; +import java.util.ArrayList; +import java.util.List; import java.util.Map; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; @@ -55,11 +58,10 @@ * * @since 1.37.0 */ -@ExperimentalApi("https://github.com/grpc/grpc-java/issues/8016") -public final class CsdsService extends - ClientStatusDiscoveryServiceGrpc.ClientStatusDiscoveryServiceImplBase { +public final class CsdsService implements BindableService { private static final Logger logger = Logger.getLogger(CsdsService.class.getName()); private final XdsClientPoolFactory xdsClientPoolFactory; + private final CsdsServiceInternal delegate = new CsdsServiceInternal(); @VisibleForTesting CsdsService(XdsClientPoolFactory xdsClientPoolFactory) { @@ -76,75 +78,99 @@ public static CsdsService newInstance() { } @Override - public void fetchClientStatus( - ClientStatusRequest request, StreamObserver responseObserver) { - if (handleRequest(request, responseObserver)) { - responseObserver.onCompleted(); - } - // TODO(sergiitk): Add a case covering mutating handleRequest return false to true - to verify - // that responseObserver.onCompleted() isn't erroneously called on error. + public ServerServiceDefinition bindService() { + return delegate.bindService(); } - @Override - public StreamObserver streamClientStatus( - final StreamObserver responseObserver) { - return new StreamObserver() { - @Override - public void onNext(ClientStatusRequest request) { - handleRequest(request, responseObserver); + /** Hide protobuf from being exposed via the API. */ + private final class CsdsServiceInternal + extends ClientStatusDiscoveryServiceGrpc.ClientStatusDiscoveryServiceImplBase { + @Override + public void fetchClientStatus( + ClientStatusRequest request, StreamObserver responseObserver) { + if (handleRequest(request, responseObserver)) { + responseObserver.onCompleted(); } + // TODO(sergiitk): Add a case covering mutating handleRequest return false to true - to verify + // that responseObserver.onCompleted() isn't erroneously called on error. + } - @Override - public void onError(Throwable t) { - onCompleted(); - } + @Override + public StreamObserver streamClientStatus( + final StreamObserver responseObserver) { + return new StreamObserver() { + @Override + public void onNext(ClientStatusRequest request) { + handleRequest(request, responseObserver); + } - @Override - public void onCompleted() { - responseObserver.onCompleted(); - } - }; + @Override + public void onError(Throwable t) { + onCompleted(); + } + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + }; + } } private boolean handleRequest( ClientStatusRequest request, StreamObserver responseObserver) { - StatusException error; - try { - responseObserver.onNext(getConfigDumpForRequest(request)); - return true; - } catch (StatusException e) { - error = e; - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - logger.log(Level.FINE, "Server interrupted while building CSDS config dump", e); - error = Status.ABORTED.withDescription("Thread interrupted").withCause(e).asException(); - } catch (RuntimeException e) { - logger.log(Level.WARNING, "Unexpected error while building CSDS config dump", e); - error = - Status.INTERNAL.withDescription("Unexpected internal error").withCause(e).asException(); - } - responseObserver.onError(error); - return false; - } + StatusException error = null; - private ClientStatusResponse getConfigDumpForRequest(ClientStatusRequest request) - throws StatusException, InterruptedException { if (request.getNodeMatchersCount() > 0) { - throw new StatusException( + error = new StatusException( Status.INVALID_ARGUMENT.withDescription("node_matchers not supported")); + } else { + List targets = xdsClientPoolFactory.getTargets(); + List clientConfigs = new ArrayList<>(targets.size()); + + for (int i = 0; i < targets.size() && error == null; i++) { + try { + ClientConfig clientConfig = getConfigForRequest(targets.get(i)); + if (clientConfig != null) { + clientConfigs.add(clientConfig); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + logger.log(Level.FINE, "Server interrupted while building CSDS config dump", e); + error = Status.ABORTED.withDescription("Thread interrupted").withCause(e).asException(); + } catch (RuntimeException e) { + logger.log(Level.WARNING, "Unexpected error while building CSDS config dump", e); + error = Status.INTERNAL.withDescription("Unexpected internal error").withCause(e) + .asException(); + } + } + + try { + responseObserver.onNext(getStatusResponse(clientConfigs)); + } catch (RuntimeException e) { + logger.log(Level.WARNING, "Unexpected error while processing CSDS config dump", e); + error = Status.INTERNAL.withDescription("Unexpected internal error").withCause(e) + .asException(); + } } - ObjectPool xdsClientPool = xdsClientPoolFactory.get(); + if (error == null) { + return true; // All clients reported without error + } + responseObserver.onError(error); + return false; + } + + private ClientConfig getConfigForRequest(String target) throws InterruptedException { + ObjectPool xdsClientPool = xdsClientPoolFactory.get(target); if (xdsClientPool == null) { - return ClientStatusResponse.getDefaultInstance(); + return null; } XdsClient xdsClient = null; try { xdsClient = xdsClientPool.getObject(); - return ClientStatusResponse.newBuilder() - .addConfig(getClientConfigForXdsClient(xdsClient)) - .build(); + return getClientConfigForXdsClient(xdsClient, target); } finally { if (xdsClient != null) { xdsClientPool.returnObject(xdsClient); @@ -152,9 +178,18 @@ private ClientStatusResponse getConfigDumpForRequest(ClientStatusRequest request } } + private ClientStatusResponse getStatusResponse(List clientConfigs) { + if (clientConfigs.isEmpty()) { + return ClientStatusResponse.getDefaultInstance(); + } + return ClientStatusResponse.newBuilder().addAllConfig(clientConfigs).build(); + } + @VisibleForTesting - static ClientConfig getClientConfigForXdsClient(XdsClient xdsClient) throws InterruptedException { + static ClientConfig getClientConfigForXdsClient(XdsClient xdsClient, String target) + throws InterruptedException { ClientConfig.Builder builder = ClientConfig.newBuilder() + .setClientScope(target) .setNode(xdsClient.getBootstrapInfo().node().toEnvoyProtoNode()); Map, Map> metadataByType = @@ -214,6 +249,8 @@ static ClientResourceStatus metadataStatusToClientStatus(ResourceMetadataStatus return ClientResourceStatus.ACKED; case NACKED: return ClientResourceStatus.NACKED; + case TIMEOUT: + return ClientResourceStatus.TIMEOUT; default: throw new AssertionError("Unexpected ResourceMetadataStatus: " + status); } diff --git a/xds/src/main/java/io/grpc/xds/Endpoints.java b/xds/src/main/java/io/grpc/xds/Endpoints.java index 8b1715731df..558e3932ddc 100644 --- a/xds/src/main/java/io/grpc/xds/Endpoints.java +++ b/xds/src/main/java/io/grpc/xds/Endpoints.java @@ -21,6 +21,8 @@ import com.google.auto.value.AutoValue; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.net.InetAddresses; import io.grpc.EquivalentAddressGroup; import java.net.InetSocketAddress; import java.util.List; @@ -41,11 +43,13 @@ abstract static class LocalityLbEndpoints { // Locality's priority level. abstract int priority(); + abstract ImmutableMap localityMetadata(); + static LocalityLbEndpoints create(List endpoints, int localityWeight, - int priority) { + int priority, ImmutableMap localityMetadata) { checkArgument(localityWeight > 0, "localityWeight must be greater than 0"); return new AutoValue_Endpoints_LocalityLbEndpoints( - ImmutableList.copyOf(endpoints), localityWeight, priority); + ImmutableList.copyOf(endpoints), localityWeight, priority, localityMetadata); } } @@ -55,23 +59,32 @@ abstract static class LbEndpoint { // The endpoint address to be connected to. abstract EquivalentAddressGroup eag(); - // Endpoint's weight for load balancing. If unspecified, value of 0 is returned. + // Endpoint's weight for load balancing. Guaranteed not to be 0. abstract int loadBalancingWeight(); // Whether the endpoint is healthy. abstract boolean isHealthy(); + abstract String hostname(); + + abstract ImmutableMap endpointMetadata(); + static LbEndpoint create(EquivalentAddressGroup eag, int loadBalancingWeight, - boolean isHealthy) { - return new AutoValue_Endpoints_LbEndpoint(eag, loadBalancingWeight, isHealthy); + boolean isHealthy, String hostname, ImmutableMap endpointMetadata) { + if (loadBalancingWeight == 0) { + loadBalancingWeight = 1; + } + return new AutoValue_Endpoints_LbEndpoint( + eag, loadBalancingWeight, isHealthy, hostname, endpointMetadata); } // Only for testing. @VisibleForTesting - static LbEndpoint create( - String address, int port, int loadBalancingWeight, boolean isHealthy) { - return LbEndpoint.create(new EquivalentAddressGroup(new InetSocketAddress(address, port)), - loadBalancingWeight, isHealthy); + static LbEndpoint create(String address, int port, int loadBalancingWeight, boolean isHealthy, + String hostname, ImmutableMap endpointMetadata) { + return LbEndpoint.create( + new EquivalentAddressGroup(new InetSocketAddress(InetAddresses.forString(address), port)), + loadBalancingWeight, isHealthy, hostname, endpointMetadata); } } diff --git a/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java b/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java index 978e6663cbe..3cf28d23578 100644 --- a/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java +++ b/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java @@ -16,16 +16,18 @@ package io.grpc.xds; +import static com.google.common.base.Preconditions.checkNotNull; + import com.google.auto.value.AutoValue; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.protobuf.util.Durations; +import io.envoyproxy.envoy.config.core.v3.SocketAddress.Protocol; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.grpc.Internal; import io.grpc.xds.client.EnvoyProtoData; import io.grpc.xds.internal.security.SslContextProviderSupplier; import java.net.InetAddress; -import java.net.UnknownHostException; import java.util.Objects; import javax.annotation.Nullable; @@ -41,13 +43,13 @@ private EnvoyServerProtoData() { } public abstract static class BaseTlsContext { - @Nullable protected final CommonTlsContext commonTlsContext; + protected final CommonTlsContext commonTlsContext; - protected BaseTlsContext(@Nullable CommonTlsContext commonTlsContext) { - this.commonTlsContext = commonTlsContext; + protected BaseTlsContext(CommonTlsContext commonTlsContext) { + this.commonTlsContext = checkNotNull(commonTlsContext, "commonTlsContext cannot be null."); } - @Nullable public CommonTlsContext getCommonTlsContext() { + public CommonTlsContext getCommonTlsContext() { return commonTlsContext; } @@ -71,20 +73,81 @@ public int hashCode() { public static final class UpstreamTlsContext extends BaseTlsContext { + private final String sni; + private final boolean autoHostSni; + private final boolean autoSniSanValidation; + @VisibleForTesting public UpstreamTlsContext(CommonTlsContext commonTlsContext) { + this(commonTlsContext, "", false, false); + } + + @VisibleForTesting + public UpstreamTlsContext( + CommonTlsContext commonTlsContext, String sni, boolean autoHostSni, + boolean autoSniSanValidation) { super(commonTlsContext); + this.sni = sni == null ? "" : sni; + this.autoHostSni = autoHostSni; + this.autoSniSanValidation = autoSniSanValidation; + } + + @VisibleForTesting + public UpstreamTlsContext( + io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext + upstreamTlsContext) { + super(upstreamTlsContext.getCommonTlsContext()); + this.sni = upstreamTlsContext.getSni(); + this.autoHostSni = upstreamTlsContext.getAutoHostSni(); + this.autoSniSanValidation = upstreamTlsContext.getAutoSniSanValidation(); } public static UpstreamTlsContext fromEnvoyProtoUpstreamTlsContext( io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext upstreamTlsContext) { - return new UpstreamTlsContext(upstreamTlsContext.getCommonTlsContext()); + return new UpstreamTlsContext(upstreamTlsContext); + } + + public String getSni() { + return sni; + } + + public boolean getAutoHostSni() { + return autoHostSni; + } + + public boolean getAutoSniSanValidation() { + return autoSniSanValidation; } @Override public String toString() { - return "UpstreamTlsContext{" + "commonTlsContext=" + commonTlsContext + '}'; + return "UpstreamTlsContext{" + + "commonTlsContext=" + commonTlsContext + + "\nsni=" + sni + + "\nauto_host_sni=" + autoHostSni + + "\nauto_sni_san_validation=" + autoSniSanValidation + + "}"; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + UpstreamTlsContext that = (UpstreamTlsContext) o; + return autoHostSni == that.autoHostSni + && autoSniSanValidation == that.autoSniSanValidation + && Objects.equals(commonTlsContext, that.commonTlsContext) + && Objects.equals(sni, that.sni); + } + + @Override + public int hashCode() { + return Objects.hash(commonTlsContext, sni, autoHostSni, autoSniSanValidation); } } @@ -148,9 +211,9 @@ abstract static class CidrRange { abstract int prefixLen(); - static CidrRange create(String addressPrefix, int prefixLen) throws UnknownHostException { + static CidrRange create(InetAddress addressPrefix, int prefixLen) { return new AutoValue_EnvoyServerProtoData_CidrRange( - InetAddress.getByName(addressPrefix), prefixLen); + addressPrefix, prefixLen); } } @@ -205,7 +268,7 @@ public static FilterChainMatch create(int destinationPort, @AutoValue abstract static class FilterChain { - // possibly empty + // Must be unique per server instance (except the default chain). abstract String name(); // TODO(sanjaypujare): flatten structure by moving FilterChainMatch class members here. @@ -247,13 +310,17 @@ abstract static class Listener { @Nullable abstract FilterChain defaultFilterChain(); + @Nullable + abstract Protocol protocol(); + static Listener create( String name, @Nullable String address, ImmutableList filterChains, - @Nullable FilterChain defaultFilterChain) { + @Nullable FilterChain defaultFilterChain, + @Nullable Protocol protocol) { return new AutoValue_EnvoyServerProtoData_Listener(name, address, filterChains, - defaultFilterChain); + defaultFilterChain, protocol); } } @@ -322,7 +389,7 @@ static OutlierDetection fromEnvoyOutlierDetection( Integer minimumHosts = envoyOutlierDetection.hasSuccessRateMinimumHosts() ? envoyOutlierDetection.getSuccessRateMinimumHosts().getValue() : null; Integer requestVolume = envoyOutlierDetection.hasSuccessRateRequestVolume() - ? envoyOutlierDetection.getSuccessRateMinimumHosts().getValue() : null; + ? envoyOutlierDetection.getSuccessRateRequestVolume().getValue() : null; successRateEjection = SuccessRateEjection.create(stdevFactor, enforcementPercentage, minimumHosts, requestVolume); diff --git a/xds/src/main/java/io/grpc/xds/ExtAuthzConfigParser.java b/xds/src/main/java/io/grpc/xds/ExtAuthzConfigParser.java new file mode 100644 index 00000000000..853e8a5c03a --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/ExtAuthzConfigParser.java @@ -0,0 +1,103 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import com.google.common.collect.ImmutableList; +import io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz; +import io.grpc.internal.GrpcUtil; +import io.grpc.xds.client.Bootstrapper.BootstrapInfo; +import io.grpc.xds.client.Bootstrapper.ServerInfo; +import io.grpc.xds.internal.MatcherParser; +import io.grpc.xds.internal.extauthz.ExtAuthzConfig; +import io.grpc.xds.internal.extauthz.ExtAuthzParseException; +import io.grpc.xds.internal.grpcservice.GrpcServiceConfig; +import io.grpc.xds.internal.grpcservice.GrpcServiceParseException; +import io.grpc.xds.internal.headermutations.HeaderMutationRulesParseException; +import io.grpc.xds.internal.headermutations.HeaderMutationRulesParser; + + +/** + * Parser for {@link io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz}. + */ +final class ExtAuthzConfigParser { + + private ExtAuthzConfigParser() {} + + /** + * Parses the {@link io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz} proto to + * create an {@link ExtAuthzConfig} instance. + * + * @param extAuthzProto The ext_authz proto to parse. + * @return An {@link ExtAuthzConfig} instance. + * @throws ExtAuthzParseException if the proto is invalid or contains unsupported features. + */ + public static ExtAuthzConfig parse( + ExtAuthz extAuthzProto, BootstrapInfo bootstrapInfo, ServerInfo serverInfo) + throws ExtAuthzParseException { + if (!extAuthzProto.hasGrpcService()) { + throw new ExtAuthzParseException( + "unsupported ExtAuthz service type: only grpc_service is supported"); + } + GrpcServiceConfig grpcServiceConfig; + try { + grpcServiceConfig = + GrpcServiceConfigParser.parse(extAuthzProto.getGrpcService(), bootstrapInfo, serverInfo); + } catch (GrpcServiceParseException e) { + throw new ExtAuthzParseException("Failed to parse GrpcService config: " + e.getMessage(), e); + } + ExtAuthzConfig.Builder builder = ExtAuthzConfig.builder().grpcService(grpcServiceConfig) + .failureModeAllow(extAuthzProto.getFailureModeAllow()) + .failureModeAllowHeaderAdd(extAuthzProto.getFailureModeAllowHeaderAdd()) + .includePeerCertificate(extAuthzProto.getIncludePeerCertificate()) + .denyAtDisable(extAuthzProto.getDenyAtDisable().getDefaultValue().getValue()); + + if (extAuthzProto.hasFilterEnabled()) { + try { + builder.filterEnabled( + MatcherParser.parseFractionMatcher(extAuthzProto.getFilterEnabled().getDefaultValue())); + } catch (IllegalArgumentException e) { + throw new ExtAuthzParseException(e.getMessage()); + } + } + + if (extAuthzProto.hasStatusOnError()) { + builder.statusOnError( + GrpcUtil.httpStatusToGrpcStatus(extAuthzProto.getStatusOnError().getCodeValue())); + } + + if (extAuthzProto.hasAllowedHeaders()) { + builder.allowedHeaders(extAuthzProto.getAllowedHeaders().getPatternsList().stream() + .map(MatcherParser::parseStringMatcher).collect(ImmutableList.toImmutableList())); + } + + if (extAuthzProto.hasDisallowedHeaders()) { + builder.disallowedHeaders(extAuthzProto.getDisallowedHeaders().getPatternsList().stream() + .map(MatcherParser::parseStringMatcher).collect(ImmutableList.toImmutableList())); + } + + if (extAuthzProto.hasDecoderHeaderMutationRules()) { + try { + builder.decoderHeaderMutationRules( + HeaderMutationRulesParser.parse(extAuthzProto.getDecoderHeaderMutationRules())); + } catch (HeaderMutationRulesParseException e) { + throw new ExtAuthzParseException(e.getMessage(), e); + } + } + + return builder.build(); + } +} diff --git a/xds/src/main/java/io/grpc/xds/FaultFilter.java b/xds/src/main/java/io/grpc/xds/FaultFilter.java index d46b3d30f5a..ce764c7e943 100644 --- a/xds/src/main/java/io/grpc/xds/FaultFilter.java +++ b/xds/src/main/java/io/grpc/xds/FaultFilter.java @@ -37,7 +37,6 @@ import io.grpc.Deadline; import io.grpc.ForwardingClientCall; import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; -import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.Status; @@ -46,7 +45,6 @@ import io.grpc.internal.GrpcUtil; import io.grpc.xds.FaultConfig.FaultAbort; import io.grpc.xds.FaultConfig.FaultDelay; -import io.grpc.xds.Filter.ClientInterceptorBuilder; import io.grpc.xds.ThreadSafeRandom.ThreadSafeRandomImpl; import java.util.Locale; import java.util.concurrent.Executor; @@ -57,10 +55,11 @@ import javax.annotation.Nullable; /** HttpFault filter implementation. */ -final class FaultFilter implements Filter, ClientInterceptorBuilder { +final class FaultFilter implements Filter { - static final FaultFilter INSTANCE = + private static final FaultFilter INSTANCE = new FaultFilter(ThreadSafeRandomImpl.instance, new AtomicLong()); + @VisibleForTesting static final Metadata.Key HEADER_DELAY_KEY = Metadata.Key.of("x-envoy-fault-delay-request", Metadata.ASCII_STRING_MARSHALLER); @@ -88,196 +87,218 @@ final class FaultFilter implements Filter, ClientInterceptorBuilder { this.activeFaultCounter = activeFaultCounter; } - @Override - public String[] typeUrls() { - return new String[] { TYPE_URL }; - } - - @Override - public ConfigOrError parseFilterConfig(Message rawProtoMessage) { - HTTPFault httpFaultProto; - if (!(rawProtoMessage instanceof Any)) { - return ConfigOrError.fromError("Invalid config type: " + rawProtoMessage.getClass()); + static final class Provider implements Filter.Provider { + @Override + public String[] typeUrls() { + return new String[]{TYPE_URL}; } - Any anyMessage = (Any) rawProtoMessage; - try { - httpFaultProto = anyMessage.unpack(HTTPFault.class); - } catch (InvalidProtocolBufferException e) { - return ConfigOrError.fromError("Invalid proto: " + e); + + @Override + public boolean isClientFilter() { + return true; } - return parseHttpFault(httpFaultProto); - } - private static ConfigOrError parseHttpFault(HTTPFault httpFault) { - FaultDelay faultDelay = null; - FaultAbort faultAbort = null; - if (httpFault.hasDelay()) { - faultDelay = parseFaultDelay(httpFault.getDelay()); + @Override + public FaultFilter newInstance(String name) { + return INSTANCE; } - if (httpFault.hasAbort()) { - ConfigOrError faultAbortOrError = parseFaultAbort(httpFault.getAbort()); - if (faultAbortOrError.errorDetail != null) { - return ConfigOrError.fromError( - "HttpFault contains invalid FaultAbort: " + faultAbortOrError.errorDetail); + + @Override + public ConfigOrError parseFilterConfig( + Message rawProtoMessage, FilterConfigParseContext context) { + HTTPFault httpFaultProto; + if (!(rawProtoMessage instanceof Any)) { + return ConfigOrError.fromError("Invalid config type: " + rawProtoMessage.getClass()); } - faultAbort = faultAbortOrError.config; - } - Integer maxActiveFaults = null; - if (httpFault.hasMaxActiveFaults()) { - maxActiveFaults = httpFault.getMaxActiveFaults().getValue(); - if (maxActiveFaults < 0) { - maxActiveFaults = Integer.MAX_VALUE; + Any anyMessage = (Any) rawProtoMessage; + try { + httpFaultProto = anyMessage.unpack(HTTPFault.class); + } catch (InvalidProtocolBufferException e) { + return ConfigOrError.fromError("Invalid proto: " + e); } + return parseHttpFault(httpFaultProto); } - return ConfigOrError.fromConfig(FaultConfig.create(faultDelay, faultAbort, maxActiveFaults)); - } - private static FaultDelay parseFaultDelay( - io.envoyproxy.envoy.extensions.filters.common.fault.v3.FaultDelay faultDelay) { - FaultConfig.FractionalPercent percent = parsePercent(faultDelay.getPercentage()); - if (faultDelay.hasHeaderDelay()) { - return FaultDelay.forHeader(percent); + @Override + public ConfigOrError parseFilterConfigOverride( + Message rawProtoMessage, FilterConfigParseContext context) { + return parseFilterConfig(rawProtoMessage, context); } - return FaultDelay.forFixedDelay(Durations.toNanos(faultDelay.getFixedDelay()), percent); - } - @VisibleForTesting - static ConfigOrError parseFaultAbort( - io.envoyproxy.envoy.extensions.filters.http.fault.v3.FaultAbort faultAbort) { - FaultConfig.FractionalPercent percent = parsePercent(faultAbort.getPercentage()); - switch (faultAbort.getErrorTypeCase()) { - case HEADER_ABORT: - return ConfigOrError.fromConfig(FaultAbort.forHeader(percent)); - case HTTP_STATUS: - return ConfigOrError.fromConfig(FaultAbort.forStatus( - GrpcUtil.httpStatusToGrpcStatus(faultAbort.getHttpStatus()), percent)); - case GRPC_STATUS: - return ConfigOrError.fromConfig(FaultAbort.forStatus( - Status.fromCodeValue(faultAbort.getGrpcStatus()), percent)); - case ERRORTYPE_NOT_SET: - default: - return ConfigOrError.fromError( - "Unknown error type case: " + faultAbort.getErrorTypeCase()); + private static ConfigOrError parseHttpFault(HTTPFault httpFault) { + FaultDelay faultDelay = null; + FaultAbort faultAbort = null; + if (httpFault.hasDelay()) { + faultDelay = parseFaultDelay(httpFault.getDelay()); + } + if (httpFault.hasAbort()) { + ConfigOrError faultAbortOrError = parseFaultAbort(httpFault.getAbort()); + if (faultAbortOrError.errorDetail != null) { + return ConfigOrError.fromError( + "HttpFault contains invalid FaultAbort: " + faultAbortOrError.errorDetail); + } + faultAbort = faultAbortOrError.config; + } + Integer maxActiveFaults = null; + if (httpFault.hasMaxActiveFaults()) { + maxActiveFaults = httpFault.getMaxActiveFaults().getValue(); + if (maxActiveFaults < 0) { + maxActiveFaults = Integer.MAX_VALUE; + } + } + return ConfigOrError.fromConfig(FaultConfig.create(faultDelay, faultAbort, maxActiveFaults)); } - } - private static FaultConfig.FractionalPercent parsePercent(FractionalPercent proto) { - switch (proto.getDenominator()) { - case HUNDRED: - return FaultConfig.FractionalPercent.perHundred(proto.getNumerator()); - case TEN_THOUSAND: - return FaultConfig.FractionalPercent.perTenThousand(proto.getNumerator()); - case MILLION: - return FaultConfig.FractionalPercent.perMillion(proto.getNumerator()); - case UNRECOGNIZED: - default: - throw new IllegalArgumentException("Unknown denominator type: " + proto.getDenominator()); + private static FaultDelay parseFaultDelay( + io.envoyproxy.envoy.extensions.filters.common.fault.v3.FaultDelay faultDelay) { + FaultConfig.FractionalPercent percent = parsePercent(faultDelay.getPercentage()); + if (faultDelay.hasHeaderDelay()) { + return FaultDelay.forHeader(percent); + } + return FaultDelay.forFixedDelay(Durations.toNanos(faultDelay.getFixedDelay()), percent); } - } - @Override - public ConfigOrError parseFilterConfigOverride(Message rawProtoMessage) { - return parseFilterConfig(rawProtoMessage); + @VisibleForTesting + static ConfigOrError parseFaultAbort( + io.envoyproxy.envoy.extensions.filters.http.fault.v3.FaultAbort faultAbort) { + FaultConfig.FractionalPercent percent = parsePercent(faultAbort.getPercentage()); + switch (faultAbort.getErrorTypeCase()) { + case HEADER_ABORT: + return ConfigOrError.fromConfig(FaultAbort.forHeader(percent)); + case HTTP_STATUS: + return ConfigOrError.fromConfig(FaultAbort.forStatus( + GrpcUtil.httpStatusToGrpcStatus(faultAbort.getHttpStatus()), percent)); + case GRPC_STATUS: + return ConfigOrError.fromConfig(FaultAbort.forStatus( + Status.fromCodeValue(faultAbort.getGrpcStatus()), percent)); + case ERRORTYPE_NOT_SET: + default: + return ConfigOrError.fromError( + "Unknown error type case: " + faultAbort.getErrorTypeCase()); + } + } + + private static FaultConfig.FractionalPercent parsePercent(FractionalPercent proto) { + switch (proto.getDenominator()) { + case HUNDRED: + return FaultConfig.FractionalPercent.perHundred(proto.getNumerator()); + case TEN_THOUSAND: + return FaultConfig.FractionalPercent.perTenThousand(proto.getNumerator()); + case MILLION: + return FaultConfig.FractionalPercent.perMillion(proto.getNumerator()); + case UNRECOGNIZED: + default: + throw new IllegalArgumentException("Unknown denominator type: " + proto.getDenominator()); + } + } } @Nullable @Override public ClientInterceptor buildClientInterceptor( - FilterConfig config, @Nullable FilterConfig overrideConfig, PickSubchannelArgs args, + FilterConfig config, @Nullable FilterConfig overrideConfig, final ScheduledExecutorService scheduler) { checkNotNull(config, "config"); if (overrideConfig != null) { config = overrideConfig; } FaultConfig faultConfig = (FaultConfig) config; - Long delayNanos = null; - Status abortStatus = null; - if (faultConfig.maxActiveFaults() == null - || activeFaultCounter.get() < faultConfig.maxActiveFaults()) { - Metadata headers = args.getHeaders(); - if (faultConfig.faultDelay() != null) { - delayNanos = determineFaultDelayNanos(faultConfig.faultDelay(), headers); - } - if (faultConfig.faultAbort() != null) { - abortStatus = determineFaultAbortStatus(faultConfig.faultAbort(), headers); - } - } - if (delayNanos == null && abortStatus == null) { - return null; - } - final Long finalDelayNanos = delayNanos; - final Status finalAbortStatus = getAbortStatusWithDescription(abortStatus); final class FaultInjectionInterceptor implements ClientInterceptor { @Override public ClientCall interceptCall( final MethodDescriptor method, final CallOptions callOptions, final Channel next) { - Executor callExecutor = callOptions.getExecutor(); - if (callExecutor == null) { // This should never happen in practice because - // ManagedChannelImpl.ConfigSelectingClientCall always provides CallOptions with - // a callExecutor. - // TODO(https://github.com/grpc/grpc-java/issues/7868) - callExecutor = MoreExecutors.directExecutor(); + boolean checkFault = false; + if (faultConfig.maxActiveFaults() == null + || activeFaultCounter.get() < faultConfig.maxActiveFaults()) { + checkFault = faultConfig.faultDelay() != null || faultConfig.faultAbort() != null; } - if (finalDelayNanos != null) { - Supplier> callSupplier; - if (finalAbortStatus != null) { - callSupplier = Suppliers.ofInstance( - new FailingClientCall(finalAbortStatus, callExecutor)); - } else { - callSupplier = new Supplier>() { - @Override - public ClientCall get() { - return next.newCall(method, callOptions); - } - }; + if (!checkFault) { + return next.newCall(method, callOptions); + } + final class DeadlineInsightForwardingCall extends ForwardingClientCall { + private ClientCall delegate; + + @Override + protected ClientCall delegate() { + return delegate; } - final DelayInjectedCall delayInjectedCall = new DelayInjectedCall<>( - finalDelayNanos, callExecutor, scheduler, callOptions.getDeadline(), callSupplier); - final class DeadlineInsightForwardingCall extends ForwardingClientCall { - @Override - protected ClientCall delegate() { - return delayInjectedCall; + @Override + public void start(Listener listener, Metadata headers) { + Executor callExecutor = callOptions.getExecutor(); + if (callExecutor == null) { // This should never happen in practice because + // ManagedChannelImpl.ConfigSelectingClientCall always provides CallOptions with + // a callExecutor. + // TODO(https://github.com/grpc/grpc-java/issues/7868) + callExecutor = MoreExecutors.directExecutor(); } - @Override - public void start(Listener listener, Metadata headers) { - Listener finalListener = - new SimpleForwardingClientCallListener(listener) { - @Override - public void onClose(Status status, Metadata trailers) { - if (status.getCode().equals(Code.DEADLINE_EXCEEDED)) { - // TODO(zdapeng:) check effective deadline locally, and - // do the following only if the local deadline is exceeded. - // (If the server sends DEADLINE_EXCEEDED for its own deadline, then the - // injected delay does not contribute to the error, because the request is - // only sent out after the delay. There could be a race between local and - // remote, but it is rather rare.) - String description = String.format( - Locale.US, - "Deadline exceeded after up to %d ns of fault-injected delay", - finalDelayNanos); - if (status.getDescription() != null) { - description = description + ": " + status.getDescription(); - } - status = Status.DEADLINE_EXCEEDED - .withDescription(description).withCause(status.getCause()); - // Replace trailers to prevent mixing sources of status and trailers. - trailers = new Metadata(); + Long delayNanos; + Status abortStatus = null; + if (faultConfig.faultDelay() != null) { + delayNanos = determineFaultDelayNanos(faultConfig.faultDelay(), headers); + } else { + delayNanos = null; + } + if (faultConfig.faultAbort() != null) { + abortStatus = getAbortStatusWithDescription( + determineFaultAbortStatus(faultConfig.faultAbort(), headers)); + } + + Supplier> callSupplier; + if (abortStatus != null) { + callSupplier = Suppliers.ofInstance( + new FailingClientCall(abortStatus, callExecutor)); + } else { + callSupplier = new Supplier>() { + @Override + public ClientCall get() { + return next.newCall(method, callOptions); + } + }; + } + if (delayNanos == null) { + delegate = callSupplier.get(); + delegate().start(listener, headers); + return; + } + + delegate = new DelayInjectedCall<>( + delayNanos, callExecutor, scheduler, callOptions.getDeadline(), callSupplier); + + Listener finalListener = + new SimpleForwardingClientCallListener(listener) { + @Override + public void onClose(Status status, Metadata trailers) { + if (status.getCode().equals(Code.DEADLINE_EXCEEDED)) { + // TODO(zdapeng:) check effective deadline locally, and + // do the following only if the local deadline is exceeded. + // (If the server sends DEADLINE_EXCEEDED for its own deadline, then the + // injected delay does not contribute to the error, because the request is + // only sent out after the delay. There could be a race between local and + // remote, but it is rather rare.) + String description = String.format( + Locale.US, + "Deadline exceeded after up to %d ns of fault-injected delay", + delayNanos); + if (status.getDescription() != null) { + description = description + ": " + status.getDescription(); } - delegate().onClose(status, trailers); + status = Status.DEADLINE_EXCEEDED + .withDescription(description).withCause(status.getCause()); + // Replace trailers to prevent mixing sources of status and trailers. + trailers = new Metadata(); } - }; - delegate().start(finalListener, headers); - } + delegate().onClose(status, trailers); + } + }; + delegate().start(finalListener, headers); } - - return new DeadlineInsightForwardingCall(); - } else { - return new FailingClientCall<>(finalAbortStatus, callExecutor); } + + return new DeadlineInsightForwardingCall(); } } diff --git a/xds/src/main/java/io/grpc/xds/Filter.java b/xds/src/main/java/io/grpc/xds/Filter.java index 4b2767687f3..4fa56beb1de 100644 --- a/xds/src/main/java/io/grpc/xds/Filter.java +++ b/xds/src/main/java/io/grpc/xds/Filter.java @@ -16,58 +16,140 @@ package io.grpc.xds; + +import com.google.auto.value.AutoValue; import com.google.common.base.MoreObjects; import com.google.protobuf.Message; import io.grpc.ClientInterceptor; -import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.ServerInterceptor; +import io.grpc.xds.client.Bootstrapper.BootstrapInfo; +import io.grpc.xds.client.Bootstrapper.ServerInfo; +import java.io.Closeable; import java.util.Objects; import java.util.concurrent.ScheduledExecutorService; import javax.annotation.Nullable; /** - * Defines the parsing functionality of an HTTP filter. A Filter may optionally implement either - * {@link ClientInterceptorBuilder} or {@link ServerInterceptorBuilder} or both, indicating it is - * capable of working on the client side or server side or both, respectively. + * Defines the parsing functionality of an HTTP filter. + * + *

A Filter may optionally implement either {@link Filter#buildClientInterceptor} or + * {@link Filter#buildServerInterceptor} or both, and return true from corresponding + * {@link Provider#isClientFilter()}, {@link Provider#isServerFilter()} to indicate that the filter + * is capable of working on the client side or server side or both, respectively. */ -interface Filter { +interface Filter extends Closeable { - /** - * The proto message types supported by this filter. A filter will be registered by each of its - * supported message types. - */ - String[] typeUrls(); + /** Represents an opaque data structure holding configuration for a filter. */ + interface FilterConfig { + String typeUrl(); + } /** - * Parses the top-level filter config from raw proto message. The message may be either a {@link - * com.google.protobuf.Any} or a {@link com.google.protobuf.Struct}. + * Common interface for filter providers. */ - ConfigOrError parseFilterConfig(Message rawProtoMessage); + interface Provider { + /** + * The proto message types supported by this filter. A filter will be registered by each of its + * supported message types. + */ + String[] typeUrls(); - /** - * Parses the per-filter override filter config from raw proto message. The message may be either - * a {@link com.google.protobuf.Any} or a {@link com.google.protobuf.Struct}. - */ - ConfigOrError parseFilterConfigOverride(Message rawProtoMessage); + /** + * Whether the filter can be installed on the client side. + * + *

Returns true if the filter implements {@link Filter#buildClientInterceptor}. + */ + default boolean isClientFilter() { + return false; + } - /** Represents an opaque data structure holding configuration for a filter. */ - interface FilterConfig { - String typeUrl(); + /** + * Whether the filter can be installed into xDS-enabled servers. + * + *

Returns true if the filter implements {@link Filter#buildServerInterceptor}. + */ + default boolean isServerFilter() { + return false; + } + + /** + * Creates a new instance of the filter. + * + *

Returns a filter instance registered with the same typeUrls as the provider, + * capable of working with the same FilterConfig type returned by provider's parse functions. + * + *

For xDS gRPC clients, new filter instances are created per combination of: + *

    + *
  1. XdsNameResolver instance,
  2. + *
  3. Filter name+typeUrl in HttpConnectionManager (HCM) http_filters.
  4. + *
+ * + *

For xDS-enabled gRPC servers, new filter instances are created per combination of: + *

    + *
  1. Server instance,
  2. + *
  3. FilterChain name,
  4. + *
  5. Filter name+typeUrl in FilterChain's HCM.http_filters.
  6. + *
+ */ + Filter newInstance(String name); + + /** + * Parses the top-level filter config from raw proto message. The message may be either a {@link + * com.google.protobuf.Any} or a {@link com.google.protobuf.Struct}. + */ + ConfigOrError parseFilterConfig( + Message rawProtoMessage, FilterConfigParseContext context); + + /** + * Parses the per-filter override filter config from raw proto message. The message may be + * either a {@link com.google.protobuf.Any} or a {@link com.google.protobuf.Struct}. + */ + ConfigOrError parseFilterConfigOverride( + Message rawProtoMessage, FilterConfigParseContext context); } /** Uses the FilterConfigs produced above to produce an HTTP filter interceptor for clients. */ - interface ClientInterceptorBuilder { - @Nullable - ClientInterceptor buildClientInterceptor( - FilterConfig config, @Nullable FilterConfig overrideConfig, PickSubchannelArgs args, - ScheduledExecutorService scheduler); + @Nullable + default ClientInterceptor buildClientInterceptor( + FilterConfig config, @Nullable FilterConfig overrideConfig, + ScheduledExecutorService scheduler) { + return null; } /** Uses the FilterConfigs produced above to produce an HTTP filter interceptor for the server. */ - interface ServerInterceptorBuilder { - @Nullable - ServerInterceptor buildServerInterceptor( - FilterConfig config, @Nullable FilterConfig overrideConfig); + @Nullable + default ServerInterceptor buildServerInterceptor( + FilterConfig config, @Nullable FilterConfig overrideConfig) { + return null; + } + + /** + * Releases filter resources like shared resources and remote connections. + * + *

See {@link Provider#newInstance()} for details on filter instance creation. + */ + @Override + default void close() {} + + /** Context carrying dynamic metadata for a filter. */ + @AutoValue + abstract static class FilterConfigParseContext { + abstract BootstrapInfo bootstrapInfo(); + + abstract ServerInfo serverInfo(); + + static Builder builder() { + return new AutoValue_Filter_FilterConfigParseContext.Builder(); + } + + @AutoValue.Builder + abstract static class Builder { + abstract Builder bootstrapInfo(BootstrapInfo info); + + abstract Builder serverInfo(ServerInfo info); + + abstract FilterConfigParseContext build(); + } } /** Filter config with instance name. */ @@ -81,6 +163,10 @@ final class NamedFilterConfig { this.filterConfig = filterConfig; } + String filterStateKey() { + return name + "_" + filterConfig.typeUrl(); + } + @Override public boolean equals(Object o) { if (this == o) { diff --git a/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java index fa03b2add4d..77a66495614 100644 --- a/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java @@ -17,8 +17,8 @@ package io.grpc.xds; import static com.google.common.base.Preconditions.checkNotNull; -import static io.grpc.xds.InternalXdsAttributes.ATTR_DRAIN_GRACE_NANOS; -import static io.grpc.xds.InternalXdsAttributes.ATTR_FILTER_CHAIN_SELECTOR_MANAGER; +import static io.grpc.xds.XdsAttributes.ATTR_DRAIN_GRACE_NANOS; +import static io.grpc.xds.XdsAttributes.ATTR_FILTER_CHAIN_SELECTOR_MANAGER; import static io.grpc.xds.XdsServerWrapper.ATTR_SERVER_ROUTING_CONFIG; import static io.grpc.xds.internal.security.SecurityProtocolNegotiators.ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER; @@ -151,6 +151,10 @@ static final class FilterChainSelector { this.defaultRoutingConfig = checkNotNull(defaultRoutingConfig, "defaultRoutingConfig"); } + FilterChainSelector(Map> routingConfigs) { + this(routingConfigs, null, new AtomicReference<>()); + } + @VisibleForTesting Map> getRoutingConfigs() { return routingConfigs; @@ -329,7 +333,7 @@ private static int getMatchingPrefixLength( // use prefix_ranges (CIDR) and get the most specific matches private static Collection filterOnIpAddress( Collection filterChains, InetAddress address, boolean forDestination) { - // curent list of top ones + // current list of top ones ArrayList topOnes = new ArrayList<>(filterChains.size()); int topMatchingPrefixLen = -1; for (FilterChain filterChain : filterChains) { diff --git a/xds/src/main/java/io/grpc/xds/FilterChainSelectorManager.java b/xds/src/main/java/io/grpc/xds/FilterChainSelectorManager.java index 4295d75f59b..b3cc14c6484 100644 --- a/xds/src/main/java/io/grpc/xds/FilterChainSelectorManager.java +++ b/xds/src/main/java/io/grpc/xds/FilterChainSelectorManager.java @@ -18,11 +18,11 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; import java.util.Comparator; import java.util.TreeSet; import java.util.concurrent.atomic.AtomicLong; -import javax.annotation.concurrent.GuardedBy; /** * Maintains the current xDS selector and any resources using that selector. When the selector diff --git a/xds/src/main/java/io/grpc/xds/FilterRegistry.java b/xds/src/main/java/io/grpc/xds/FilterRegistry.java index 7f1fe82c6c3..da3a59fe8c1 100644 --- a/xds/src/main/java/io/grpc/xds/FilterRegistry.java +++ b/xds/src/main/java/io/grpc/xds/FilterRegistry.java @@ -23,21 +23,22 @@ /** * A registry for all supported {@link Filter}s. Filters can be queried from the registry - * by any of the {@link Filter#typeUrls() type URLs}. + * by any of the {@link Filter.Provider#typeUrls() type URLs}. */ final class FilterRegistry { private static FilterRegistry instance; - private final Map supportedFilters = new HashMap<>(); + private final Map supportedFilters = new HashMap<>(); private FilterRegistry() {} static synchronized FilterRegistry getDefaultRegistry() { if (instance == null) { instance = newRegistry().register( - FaultFilter.INSTANCE, - RouterFilter.INSTANCE, - RbacFilter.INSTANCE); + new FaultFilter.Provider(), + new RouterFilter.Provider(), + new RbacFilter.Provider(), + new GcpAuthenticationFilter.Provider()); } return instance; } @@ -48,8 +49,8 @@ static FilterRegistry newRegistry() { } @VisibleForTesting - FilterRegistry register(Filter... filters) { - for (Filter filter : filters) { + FilterRegistry register(Filter.Provider... filters) { + for (Filter.Provider filter : filters) { for (String typeUrl : filter.typeUrls()) { supportedFilters.put(typeUrl, filter); } @@ -58,7 +59,7 @@ FilterRegistry register(Filter... filters) { } @Nullable - Filter get(String typeUrl) { + Filter.Provider get(String typeUrl) { return supportedFilters.get(typeUrl); } } diff --git a/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java b/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java new file mode 100644 index 00000000000..e87c402fcb0 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java @@ -0,0 +1,328 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.xds.XdsNameResolver.CLUSTER_SELECTION_KEY; +import static io.grpc.xds.XdsNameResolver.XDS_CONFIG_CALL_OPTION_KEY; + +import com.google.auth.oauth2.ComputeEngineCredentials; +import com.google.auth.oauth2.IdTokenCredentials; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.primitives.UnsignedLongs; +import com.google.protobuf.Any; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.Message; +import io.envoyproxy.envoy.extensions.filters.http.gcp_authn.v3.Audience; +import io.envoyproxy.envoy.extensions.filters.http.gcp_authn.v3.GcpAuthnFilterConfig; +import io.envoyproxy.envoy.extensions.filters.http.gcp_authn.v3.TokenCacheConfig; +import io.grpc.CallCredentials; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.CompositeCallCredentials; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.Status; +import io.grpc.StatusOr; +import io.grpc.auth.MoreCallCredentials; +import io.grpc.xds.GcpAuthenticationFilter.AudienceMetadataParser.AudienceWrapper; +import io.grpc.xds.MetadataRegistry.MetadataValueParser; +import io.grpc.xds.XdsConfig.XdsClusterConfig; +import io.grpc.xds.client.XdsResourceType.ResourceInvalidException; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.concurrent.ScheduledExecutorService; +import java.util.function.Function; +import javax.annotation.Nullable; + +/** + * A {@link Filter} that injects a {@link CallCredentials} to handle + * authentication for xDS credentials. + */ +final class GcpAuthenticationFilter implements Filter { + + static final String TYPE_URL = + "type.googleapis.com/envoy.extensions.filters.http.gcp_authn.v3.GcpAuthnFilterConfig"; + private final LruCache callCredentialsCache; + final String filterInstanceName; + + GcpAuthenticationFilter(String name, int cacheSize) { + filterInstanceName = checkNotNull(name, "name"); + this.callCredentialsCache = new LruCache<>(cacheSize); + } + + static final class Provider implements Filter.Provider { + private final int cacheSize = 10; + + @Override + public String[] typeUrls() { + return new String[]{TYPE_URL}; + } + + @Override + public boolean isClientFilter() { + return true; + } + + @Override + public GcpAuthenticationFilter newInstance(String name) { + return new GcpAuthenticationFilter(name, cacheSize); + } + + @Override + public ConfigOrError parseFilterConfig( + Message rawProtoMessage, FilterConfigParseContext context) { + GcpAuthnFilterConfig gcpAuthnProto; + if (!(rawProtoMessage instanceof Any)) { + return ConfigOrError.fromError("Invalid config type: " + rawProtoMessage.getClass()); + } + Any anyMessage = (Any) rawProtoMessage; + + try { + gcpAuthnProto = anyMessage.unpack(GcpAuthnFilterConfig.class); + } catch (InvalidProtocolBufferException e) { + return ConfigOrError.fromError("Invalid proto: " + e); + } + + long cacheSize = 10; + // Validate cache_config + if (gcpAuthnProto.hasCacheConfig()) { + TokenCacheConfig cacheConfig = gcpAuthnProto.getCacheConfig(); + if (cacheConfig.hasCacheSize()) { + cacheSize = cacheConfig.getCacheSize().getValue(); + if (cacheSize == 0) { + return ConfigOrError.fromError( + "cache_config.cache_size must be greater than zero"); + } + } + + // LruCache's size is an int and briefly exceeds its maximum size before evicting entries + cacheSize = UnsignedLongs.min(cacheSize, Integer.MAX_VALUE - 1); + } + + GcpAuthenticationConfig config = new GcpAuthenticationConfig((int) cacheSize); + return ConfigOrError.fromConfig(config); + } + + @Override + public ConfigOrError parseFilterConfigOverride( + Message rawProtoMessage, FilterConfigParseContext context) { + return parseFilterConfig(rawProtoMessage, context); + } + } + + @Nullable + @Override + public ClientInterceptor buildClientInterceptor(FilterConfig config, + @Nullable FilterConfig overrideConfig, ScheduledExecutorService scheduler) { + + ComputeEngineCredentials credentials = ComputeEngineCredentials.create(); + synchronized (callCredentialsCache) { + callCredentialsCache.resizeCache(((GcpAuthenticationConfig) config).getCacheSize()); + } + return new ClientInterceptor() { + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + + String clusterName = callOptions.getOption(CLUSTER_SELECTION_KEY); + if (clusterName == null) { + return new FailingClientCall<>( + Status.UNAVAILABLE.withDescription( + String.format( + "GCP Authn for %s does not contain cluster resource", filterInstanceName))); + } + + if (!clusterName.startsWith("cluster:")) { + return next.newCall(method, callOptions); + } + XdsConfig xdsConfig = callOptions.getOption(XDS_CONFIG_CALL_OPTION_KEY); + if (xdsConfig == null) { + return new FailingClientCall<>( + Status.UNAVAILABLE.withDescription( + String.format( + "GCP Authn for %s with %s does not contain xds configuration", + filterInstanceName, clusterName))); + } + StatusOr xdsCluster = + xdsConfig.getClusters().get(clusterName.substring("cluster:".length())); + if (xdsCluster == null) { + return new FailingClientCall<>( + Status.UNAVAILABLE.withDescription( + String.format( + "GCP Authn for %s with %s - xds cluster config does not contain xds cluster", + filterInstanceName, clusterName))); + } + if (!xdsCluster.hasValue()) { + return new FailingClientCall<>(xdsCluster.getStatus()); + } + Object audienceObj = + xdsCluster.getValue().getClusterResource().parsedMetadata().get(filterInstanceName); + if (audienceObj == null) { + return next.newCall(method, callOptions); + } + if (!(audienceObj instanceof AudienceWrapper)) { + return new FailingClientCall<>( + Status.UNAVAILABLE.withDescription( + String.format("GCP Authn found wrong type in %s metadata: %s=%s", + clusterName, filterInstanceName, audienceObj.getClass()))); + } + AudienceWrapper audience = (AudienceWrapper) audienceObj; + CallCredentials existingCallCredentials = callOptions.getCredentials(); + CallCredentials newCallCredentials = + getCallCredentials(callCredentialsCache, audience.audience, credentials); + if (existingCallCredentials != null) { + callOptions = callOptions.withCallCredentials( + new CompositeCallCredentials(existingCallCredentials, newCallCredentials)); + } else { + callOptions = callOptions.withCallCredentials(newCallCredentials); + } + return next.newCall(method, callOptions); + } + }; + } + + private CallCredentials getCallCredentials(LruCache cache, + String audience, ComputeEngineCredentials credentials) { + + synchronized (cache) { + return cache.getOrInsert(audience, key -> { + IdTokenCredentials creds = IdTokenCredentials.newBuilder() + .setIdTokenProvider(credentials) + .setTargetAudience(audience) + .build(); + return MoreCallCredentials.from(creds); + }); + } + } + + static final class GcpAuthenticationConfig implements FilterConfig { + + private final int cacheSize; + + public GcpAuthenticationConfig(int cacheSize) { + this.cacheSize = cacheSize; + } + + public int getCacheSize() { + return cacheSize; + } + + @Override + public String typeUrl() { + return GcpAuthenticationFilter.TYPE_URL; + } + } + + /** An implementation of {@link ClientCall} that fails when started. */ + @VisibleForTesting + static final class FailingClientCall extends ClientCall { + + @VisibleForTesting + final Status error; + + public FailingClientCall(Status error) { + this.error = error; + } + + @Override + public void start(ClientCall.Listener listener, Metadata headers) { + listener.onClose(error, new Metadata()); + } + + @Override + public void request(int numMessages) {} + + @Override + public void cancel(String message, Throwable cause) {} + + @Override + public void halfClose() {} + + @Override + public void sendMessage(ReqT message) {} + } + + private static final class LruCache { + + private Map cache; + private int maxSize; + + LruCache(int maxSize) { + this.maxSize = maxSize; + this.cache = createEvictingMap(maxSize); + } + + V getOrInsert(K key, Function create) { + return cache.computeIfAbsent(key, create); + } + + private void resizeCache(int newSize) { + if (newSize >= maxSize) { + maxSize = newSize; + return; + } + Map newCache = createEvictingMap(newSize); + maxSize = newSize; + newCache.putAll(cache); + cache = newCache; + } + + private Map createEvictingMap(int size) { + return new LinkedHashMap(size, 0.75f, true) { + @Override + protected boolean removeEldestEntry(Map.Entry eldest) { + return size() > LruCache.this.maxSize; + } + }; + } + } + + static class AudienceMetadataParser implements MetadataValueParser { + + static final class AudienceWrapper { + final String audience; + + AudienceWrapper(String audience) { + this.audience = checkNotNull(audience); + } + } + + @Override + public String getTypeUrl() { + return "type.googleapis.com/envoy.extensions.filters.http.gcp_authn.v3.Audience"; + } + + @Override + public AudienceWrapper parse(Any any) throws ResourceInvalidException { + Audience audience; + try { + audience = any.unpack(Audience.class); + } catch (InvalidProtocolBufferException ex) { + throw new ResourceInvalidException("Invalid Resource in address proto", ex); + } + String url = audience.getUrl(); + if (url.isEmpty()) { + throw new ResourceInvalidException( + "Audience URL is empty. Metadata value must contain a valid URL."); + } + return new AudienceWrapper(url); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/GrpcBootstrapImplConfig.java b/xds/src/main/java/io/grpc/xds/GrpcBootstrapImplConfig.java new file mode 100644 index 00000000000..e119321fb6c --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/GrpcBootstrapImplConfig.java @@ -0,0 +1,34 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import com.google.auto.value.AutoValue; +import io.grpc.Internal; +import io.grpc.xds.client.AllowedGrpcServices; + +/** + * Custom configuration for gRPC xDS bootstrap implementation. + */ +@Internal +@AutoValue +public abstract class GrpcBootstrapImplConfig { + public abstract AllowedGrpcServices allowedGrpcServices(); + + public static GrpcBootstrapImplConfig create(AllowedGrpcServices services) { + return new AutoValue_GrpcBootstrapImplConfig(services); + } +} diff --git a/xds/src/main/java/io/grpc/xds/GrpcBootstrapperImpl.java b/xds/src/main/java/io/grpc/xds/GrpcBootstrapperImpl.java index f61fab42cae..00a2e0d48d6 100644 --- a/xds/src/main/java/io/grpc/xds/GrpcBootstrapperImpl.java +++ b/xds/src/main/java/io/grpc/xds/GrpcBootstrapperImpl.java @@ -18,14 +18,21 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import io.grpc.CallCredentials; import io.grpc.ChannelCredentials; import io.grpc.internal.JsonUtil; +import io.grpc.xds.client.AllowedGrpcServices; +import io.grpc.xds.client.AllowedGrpcServices.AllowedGrpcService; import io.grpc.xds.client.BootstrapperImpl; +import io.grpc.xds.client.ConfiguredChannelCredentials; +import io.grpc.xds.client.ConfiguredChannelCredentials.ChannelCredsConfig; import io.grpc.xds.client.XdsInitializationException; import io.grpc.xds.client.XdsLogger; import java.io.IOException; import java.util.List; import java.util.Map; +import java.util.Optional; import javax.annotation.Nullable; class GrpcBootstrapperImpl extends BootstrapperImpl { @@ -48,7 +55,11 @@ class GrpcBootstrapperImpl extends BootstrapperImpl { @Override public BootstrapInfo bootstrap(Map rawData) throws XdsInitializationException { - return super.bootstrap(rawData); + BootstrapInfo info = super.bootstrap(rawData); + if (info.servers().isEmpty()) { + throw new XdsInitializationException("Invalid bootstrap: 'xds_servers' is empty"); + } + return info; } /** @@ -92,29 +103,50 @@ protected String getJsonContent() throws XdsInitializationException, IOException @Override protected Object getImplSpecificConfig(Map serverConfig, String serverUri) throws XdsInitializationException { - return getChannelCredentials(serverConfig, serverUri); + ConfiguredChannelCredentials configuredChannel = getChannelCredentials(serverConfig, serverUri); + return configuredChannel != null ? configuredChannel.channelCredentials() : null; + } + + @GuardedBy("GrpcBootstrapperImpl.class") + private static Map defaultBootstrapOverride; + @GuardedBy("GrpcBootstrapperImpl.class") + private static BootstrapInfo defaultBootstrap; + + static synchronized void setDefaultBootstrapOverride(Map rawBootstrap) { + defaultBootstrapOverride = rawBootstrap; + } + + static synchronized BootstrapInfo defaultBootstrap() throws XdsInitializationException { + if (defaultBootstrap == null) { + if (defaultBootstrapOverride == null) { + defaultBootstrap = new GrpcBootstrapperImpl().bootstrap(); + } else { + defaultBootstrap = new GrpcBootstrapperImpl().bootstrap(defaultBootstrapOverride); + } + } + return defaultBootstrap; } - private static ChannelCredentials getChannelCredentials(Map serverConfig, - String serverUri) + private static ConfiguredChannelCredentials getChannelCredentials(Map serverConfig, + String serverUri) throws XdsInitializationException { List rawChannelCredsList = JsonUtil.getList(serverConfig, "channel_creds"); if (rawChannelCredsList == null || rawChannelCredsList.isEmpty()) { throw new XdsInitializationException( "Invalid bootstrap: server " + serverUri + " 'channel_creds' required"); } - ChannelCredentials channelCredentials = + ConfiguredChannelCredentials credentials = parseChannelCredentials(JsonUtil.checkObjectList(rawChannelCredsList), serverUri); - if (channelCredentials == null) { + if (credentials == null) { throw new XdsInitializationException( "Server " + serverUri + ": no supported channel credentials found"); } - return channelCredentials; + return credentials; } @Nullable - private static ChannelCredentials parseChannelCredentials(List> jsonList, - String serverUri) + private static ConfiguredChannelCredentials parseChannelCredentials(List> jsonList, + String serverUri) throws XdsInitializationException { for (Map channelCreds : jsonList) { String type = JsonUtil.getString(channelCreds, "type"); @@ -130,9 +162,95 @@ private static ChannelCredentials parseChannelCredentials(List> j config = ImmutableMap.of(); } - return provider.newChannelCredentials(config); + ChannelCredentials creds = provider.newChannelCredentials(config); + if (creds == null) { + return null; + } + return ConfiguredChannelCredentials.create(creds, new JsonChannelCredsConfig(type, config)); } } return null; } + + @Override + protected Optional parseImplSpecificObject( + @Nullable Map rawAllowedGrpcServices) + throws XdsInitializationException { + if (rawAllowedGrpcServices == null || rawAllowedGrpcServices.isEmpty()) { + return Optional.of(GrpcBootstrapImplConfig.create(AllowedGrpcServices.empty())); + } + + ImmutableMap.Builder builder = + ImmutableMap.builder(); + for (String targetUri : rawAllowedGrpcServices.keySet()) { + Map serviceConfig = JsonUtil.getObject(rawAllowedGrpcServices, targetUri); + if (serviceConfig == null) { + throw new XdsInitializationException( + "Invalid allowed_grpc_services config for " + targetUri); + } + ConfiguredChannelCredentials configuredChannel = + getChannelCredentials(serviceConfig, targetUri); + + Optional callCredentials = Optional.empty(); + List rawCallCredsList = JsonUtil.getList(serviceConfig, "call_creds"); + if (rawCallCredsList != null && !rawCallCredsList.isEmpty()) { + callCredentials = + parseCallCredentials(JsonUtil.checkObjectList(rawCallCredsList), targetUri); + } + + AllowedGrpcService.Builder b = AllowedGrpcService.builder() + .configuredChannelCredentials(configuredChannel); + callCredentials.ifPresent(b::callCredentials); + builder.put(targetUri, b.build()); + } + GrpcBootstrapImplConfig customConfig = + GrpcBootstrapImplConfig.create(AllowedGrpcServices.create(builder.build())); + return Optional.of(customConfig); + } + + @SuppressWarnings("unused") + private static Optional parseCallCredentials(List> jsonList, + String targetUri) + throws XdsInitializationException { + // TODO(sauravzg): Currently no xDS call credentials providers are implemented (no + // XdsCallCredentialsRegistry). + // As per A102/A97, we should just ignore unsupported call credentials types + // without throwing an exception. + return Optional.empty(); + } + + private static final class JsonChannelCredsConfig implements ChannelCredsConfig { + private final String type; + private final Map config; + + JsonChannelCredsConfig(String type, Map config) { + this.type = type; + this.config = config; + } + + @Override + public String type() { + return type; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + JsonChannelCredsConfig that = (JsonChannelCredsConfig) o; + return java.util.Objects.equals(type, that.type) + && java.util.Objects.equals(config, that.config); + } + + @Override + public int hashCode() { + return java.util.Objects.hash(type, config); + } + } + } + diff --git a/xds/src/main/java/io/grpc/xds/GrpcServiceConfigParser.java b/xds/src/main/java/io/grpc/xds/GrpcServiceConfigParser.java new file mode 100644 index 00000000000..1510924f74c --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/GrpcServiceConfigParser.java @@ -0,0 +1,339 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import com.google.auth.oauth2.AccessToken; +import com.google.auth.oauth2.OAuth2Credentials; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.Any; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.util.Durations; +import io.envoyproxy.envoy.config.core.v3.GrpcService; +import io.envoyproxy.envoy.extensions.grpc_service.call_credentials.access_token.v3.AccessTokenCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.xds.v3.XdsCredentials; +import io.grpc.CallCredentials; +import io.grpc.CompositeCallCredentials; +import io.grpc.InsecureChannelCredentials; +import io.grpc.Metadata; +import io.grpc.NameResolverRegistry; +import io.grpc.SecurityLevel; +import io.grpc.alts.GoogleDefaultChannelCredentials; +import io.grpc.auth.MoreCallCredentials; +import io.grpc.xds.client.AllowedGrpcServices; +import io.grpc.xds.client.AllowedGrpcServices.AllowedGrpcService; +import io.grpc.xds.client.Bootstrapper; +import io.grpc.xds.client.ConfiguredChannelCredentials; +import io.grpc.xds.internal.grpcservice.GrpcServiceConfig; +import io.grpc.xds.internal.grpcservice.GrpcServiceParseException; +import io.grpc.xds.internal.grpcservice.HeaderValue; +import io.grpc.xds.internal.grpcservice.HeaderValueValidationUtils; +import java.net.URI; +import java.net.URISyntaxException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Date; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.Executor; + +/** + * Parser for {@link io.envoyproxy.envoy.config.core.v3.GrpcService} and related protos. + */ +final class GrpcServiceConfigParser { + + static final String TLS_CREDENTIALS_TYPE_URL = + "type.googleapis.com/envoy.extensions.grpc_service.channel_credentials." + + "tls.v3.TlsCredentials"; + static final String LOCAL_CREDENTIALS_TYPE_URL = + "type.googleapis.com/envoy.extensions.grpc_service.channel_credentials." + + "local.v3.LocalCredentials"; + static final String XDS_CREDENTIALS_TYPE_URL = + "type.googleapis.com/envoy.extensions.grpc_service.channel_credentials." + + "xds.v3.XdsCredentials"; + static final String INSECURE_CREDENTIALS_TYPE_URL = + "type.googleapis.com/envoy.extensions.grpc_service.channel_credentials." + + "insecure.v3.InsecureCredentials"; + static final String GOOGLE_DEFAULT_CREDENTIALS_TYPE_URL = + "type.googleapis.com/envoy.extensions.grpc_service.channel_credentials." + + "google_default.v3.GoogleDefaultCredentials"; + + + + /** + * Parses the {@link io.envoyproxy.envoy.config.core.v3.GrpcService} proto to create a + * {@link GrpcServiceConfig} instance. + * + * @param grpcServiceProto The proto to parse. + * @return A {@link GrpcServiceConfig} instance. + * @throws GrpcServiceParseException if the proto is invalid or uses unsupported features. + */ + public static GrpcServiceConfig parse(GrpcService grpcServiceProto, + Bootstrapper.BootstrapInfo bootstrapInfo, Bootstrapper.ServerInfo serverInfo) + throws GrpcServiceParseException { + if (!grpcServiceProto.hasGoogleGrpc()) { + throw new GrpcServiceParseException( + "Unsupported: GrpcService must have GoogleGrpc, got: " + grpcServiceProto); + } + GrpcServiceConfig.GoogleGrpcConfig googleGrpcConfig = + parseGoogleGrpcConfig(grpcServiceProto.getGoogleGrpc(), bootstrapInfo, serverInfo); + + GrpcServiceConfig.Builder builder = GrpcServiceConfig.builder().googleGrpc(googleGrpcConfig); + + ImmutableList.Builder initialMetadata = ImmutableList.builder(); + for (io.envoyproxy.envoy.config.core.v3.HeaderValue header : grpcServiceProto + .getInitialMetadataList()) { + String key = header.getKey(); + HeaderValue headerValue; + if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + headerValue = HeaderValue.create(key, header.getRawValue()); + } else { + headerValue = HeaderValue.create(key, header.getValue()); + } + if (HeaderValueValidationUtils.isDisallowed(headerValue)) { + throw new GrpcServiceParseException("Invalid initial metadata header: " + key); + } + initialMetadata.add(headerValue); + } + builder.initialMetadata(initialMetadata.build()); + + if (grpcServiceProto.hasTimeout()) { + com.google.protobuf.Duration timeout = grpcServiceProto.getTimeout(); + if (!Durations.isValid(timeout) || Durations.compare(timeout, Durations.ZERO) <= 0) { + throw new GrpcServiceParseException("Timeout must be strictly positive and valid"); + } + builder.timeout(Duration.ofSeconds(timeout.getSeconds(), timeout.getNanos())); + } + return builder.build(); + } + + /** + * Parses the {@link io.envoyproxy.envoy.config.core.v3.GrpcService.GoogleGrpc} proto to create a + * {@link GrpcServiceConfig.GoogleGrpcConfig} instance. + * + * @param googleGrpcProto The proto to parse. + * @return A {@link GrpcServiceConfig.GoogleGrpcConfig} instance. + * @throws GrpcServiceParseException if the proto is invalid. + */ + public static GrpcServiceConfig.GoogleGrpcConfig parseGoogleGrpcConfig( + GrpcService.GoogleGrpc googleGrpcProto, Bootstrapper.BootstrapInfo bootstrapInfo, + Bootstrapper.ServerInfo serverInfo) throws GrpcServiceParseException { + + String targetUri = googleGrpcProto.getTargetUri(); + + AllowedGrpcServices allowedGrpcServices = + bootstrapInfo.implSpecificObject() + .filter(GrpcBootstrapImplConfig.class::isInstance) + .map(GrpcBootstrapImplConfig.class::cast) + .map(GrpcBootstrapImplConfig::allowedGrpcServices) + .orElse(AllowedGrpcServices.empty()); + + boolean isTrustedControlPlane = serverInfo.isTrustedXdsServer(); + Optional override = + Optional.ofNullable(allowedGrpcServices.services().get(targetUri)); + + boolean isTargetUriSchemeSupported = false; + try { + URI uri = new URI(targetUri); + String scheme = uri.getScheme(); + if (scheme == null) { + scheme = NameResolverRegistry.getDefaultRegistry().getDefaultScheme(); + } + if (scheme != null) { + isTargetUriSchemeSupported = + NameResolverRegistry.getDefaultRegistry().getProviderForScheme(scheme) != null; + } + } catch (URISyntaxException e) { + // Fallback or ignore if not a valid URI + } + + if (!isTargetUriSchemeSupported) { + throw new GrpcServiceParseException("Target URI scheme is not resolvable: " + targetUri); + } + + if (!isTrustedControlPlane) { + if (!override.isPresent()) { + throw new GrpcServiceParseException( + "Untrusted xDS server & URI not found in allowed_grpc_services: " + targetUri); + } + + GrpcServiceConfig.GoogleGrpcConfig.Builder builder = + GrpcServiceConfig.GoogleGrpcConfig.builder().target(targetUri) + .configuredChannelCredentials(override.get().configuredChannelCredentials()); + if (override.get().callCredentials().isPresent()) { + builder.callCredentials(override.get().callCredentials().get()); + } + return builder.build(); + } + + ConfiguredChannelCredentials channelCreds = + extractChannelCredentials(googleGrpcProto.getChannelCredentialsPluginList()); + + Optional callCreds = + extractCallCredentials(googleGrpcProto.getCallCredentialsPluginList()); + + GrpcServiceConfig.GoogleGrpcConfig.Builder builder = + GrpcServiceConfig.GoogleGrpcConfig.builder().target(googleGrpcProto.getTargetUri()) + .configuredChannelCredentials(channelCreds); + if (callCreds.isPresent()) { + builder.callCredentials(callCreds.get()); + } + return builder.build(); + } + + private static Optional channelCredsFromProto(Any cred) + throws GrpcServiceParseException { + String typeUrl = cred.getTypeUrl(); + try { + switch (typeUrl) { + case GOOGLE_DEFAULT_CREDENTIALS_TYPE_URL: + return Optional + .of(ConfiguredChannelCredentials.create(GoogleDefaultChannelCredentials.create(), + new ProtoChannelCredsConfig(typeUrl, cred))); + case INSECURE_CREDENTIALS_TYPE_URL: + return Optional.of(ConfiguredChannelCredentials.create( + InsecureChannelCredentials.create(), new ProtoChannelCredsConfig(typeUrl, cred))); + case XDS_CREDENTIALS_TYPE_URL: + XdsCredentials xdsConfig = cred.unpack(XdsCredentials.class); + Optional fallbackCreds = + channelCredsFromProto(xdsConfig.getFallbackCredentials()); + if (!fallbackCreds.isPresent()) { + throw new GrpcServiceParseException( + "Unsupported fallback credentials type for XdsCredentials"); + } + return Optional.of(ConfiguredChannelCredentials.create( + XdsChannelCredentials.create(fallbackCreds.get().channelCredentials()), + new ProtoChannelCredsConfig(typeUrl, cred))); + case LOCAL_CREDENTIALS_TYPE_URL: + throw new GrpcServiceParseException( + "LocalCredentials are not supported in grpc-java. " + + "See https://github.com/grpc/grpc-java/issues/8928"); + case TLS_CREDENTIALS_TYPE_URL: + // For this PR, we establish this structural skeleton, + // but throw an GrpcServiceParseException until the exact stream conversions are + // merged. + throw new GrpcServiceParseException( + "TlsCredentials input stream construction pending."); + default: + return Optional.empty(); + } + } catch (InvalidProtocolBufferException e) { + throw new GrpcServiceParseException("Failed to parse channel credentials: " + e.getMessage()); + } + } + + private static ConfiguredChannelCredentials extractChannelCredentials( + List channelCredentialPlugins) throws GrpcServiceParseException { + for (Any cred : channelCredentialPlugins) { + Optional parsed = channelCredsFromProto(cred); + if (parsed.isPresent()) { + return parsed.get(); + } + } + throw new GrpcServiceParseException("No valid supported channel_credentials found"); + } + + private static Optional callCredsFromProto(Any cred) + throws GrpcServiceParseException { + if (cred.is(AccessTokenCredentials.class)) { + try { + AccessTokenCredentials accessToken = cred.unpack(AccessTokenCredentials.class); + if (accessToken.getToken().isEmpty()) { + throw new GrpcServiceParseException("Missing or empty access token in call credentials."); + } + return Optional + .of(new SecurityAwareAccessTokenCredentials(MoreCallCredentials.from(OAuth2Credentials + .create(new AccessToken(accessToken.getToken(), new Date(Long.MAX_VALUE)))))); + } catch (InvalidProtocolBufferException e) { + throw new GrpcServiceParseException( + "Failed to parse access token credentials: " + e.getMessage()); + } + } + return Optional.empty(); + } + + private static Optional extractCallCredentials(List callCredentialPlugins) + throws GrpcServiceParseException { + List creds = new ArrayList<>(); + for (Any cred : callCredentialPlugins) { + Optional parsed = callCredsFromProto(cred); + if (parsed.isPresent()) { + creds.add(parsed.get()); + } + } + return creds.stream().reduce(CompositeCallCredentials::new); + } + + private static final class SecurityAwareAccessTokenCredentials extends CallCredentials { + + private final CallCredentials delegate; + + SecurityAwareAccessTokenCredentials(CallCredentials delegate) { + this.delegate = delegate; + } + + @Override + public void applyRequestMetadata(RequestInfo requestInfo, Executor appExecutor, + MetadataApplier applier) { + if (requestInfo.getSecurityLevel() == SecurityLevel.PRIVACY_AND_INTEGRITY) { + delegate.applyRequestMetadata(requestInfo, appExecutor, applier); + } else { + applier.apply(new Metadata()); + } + } + } + + static final class ProtoChannelCredsConfig + implements ConfiguredChannelCredentials.ChannelCredsConfig { + private final String type; + private final Any configProto; + + ProtoChannelCredsConfig(String type, Any configProto) { + this.type = type; + this.configProto = configProto; + } + + @Override + public String type() { + return type; + } + + Any configProto() { + return configProto; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ProtoChannelCredsConfig that = (ProtoChannelCredsConfig) o; + return java.util.Objects.equals(type, that.type) + && java.util.Objects.equals(configProto, that.configProto); + } + + @Override + public int hashCode() { + return java.util.Objects.hash(type, configProto); + } + } + + + +} diff --git a/xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java b/xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java index 74c28ba2d2d..5100537aea2 100644 --- a/xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java +++ b/xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java @@ -19,6 +19,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; +import io.grpc.CallCredentials; import io.grpc.CallOptions; import io.grpc.ChannelCredentials; import io.grpc.ClientCall; @@ -30,39 +31,94 @@ import io.grpc.Status; import io.grpc.xds.client.Bootstrapper; import io.grpc.xds.client.XdsTransportFactory; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; +/** + * A factory for creating gRPC-based transports for xDS communication. + * + *

WARNING: This class reuses channels when possible, based on the provided {@link + * Bootstrapper.ServerInfo} with important considerations. The {@link Bootstrapper.ServerInfo} + * includes {@link ChannelCredentials}, which is compared by reference equality. This means every + * {@link Bootstrapper.BootstrapInfo} would have non-equal copies of {@link + * Bootstrapper.ServerInfo}, even if they all represent the same xDS server configuration. For gRPC + * name resolution with the {@code xds} and {@code google-c2p} scheme, this transport sharing works + * as expected as it internally reuses a single {@link Bootstrapper.BootstrapInfo} instance. + * Otherwise, new transports would be created for each {@link Bootstrapper.ServerInfo} despite them + * possibly representing the same xDS server configuration and defeating the purpose of transport + * sharing. + */ final class GrpcXdsTransportFactory implements XdsTransportFactory { - static final GrpcXdsTransportFactory DEFAULT_XDS_TRANSPORT_FACTORY = - new GrpcXdsTransportFactory(); + private final CallCredentials callCredentials; + // The map of xDS server info to its corresponding gRPC xDS transport. + // This enables reusing and sharing the same underlying gRPC channel. + // + // NOTE: ConcurrentHashMap is used as a per-entry lock and all reads and writes must be a mutation + // via the ConcurrentHashMap APIs to acquire the per-entry lock in order to ensure thread safety + // for reference counting of each GrpcXdsTransport instance. + private static final Map xdsServerInfoToTransportMap = + new ConcurrentHashMap<>(); + + GrpcXdsTransportFactory(CallCredentials callCredentials) { + this.callCredentials = callCredentials; + } @Override public XdsTransport create(Bootstrapper.ServerInfo serverInfo) { - return new GrpcXdsTransport(serverInfo); + return xdsServerInfoToTransportMap.compute( + serverInfo, + (info, transport) -> { + if (transport == null) { + transport = new GrpcXdsTransport(serverInfo, callCredentials); + } + ++transport.refCount; + return transport; + }); } @VisibleForTesting public XdsTransport createForTest(ManagedChannel channel) { - return new GrpcXdsTransport(channel); + return new GrpcXdsTransport(channel, callCredentials, null); } @VisibleForTesting static class GrpcXdsTransport implements XdsTransport { private final ManagedChannel channel; + private final CallCredentials callCredentials; + private final Bootstrapper.ServerInfo serverInfo; + // Must only be accessed via the ConcurrentHashMap APIs which act as the locking methods. + private int refCount = 0; public GrpcXdsTransport(Bootstrapper.ServerInfo serverInfo) { + this(serverInfo, null); + } + + @VisibleForTesting + public GrpcXdsTransport(ManagedChannel channel) { + this(channel, null, null); + } + + public GrpcXdsTransport(Bootstrapper.ServerInfo serverInfo, CallCredentials callCredentials) { String target = serverInfo.target(); ChannelCredentials channelCredentials = (ChannelCredentials) serverInfo.implSpecificConfig(); this.channel = Grpc.newChannelBuilder(target, channelCredentials) .keepAliveTime(5, TimeUnit.MINUTES) .build(); + this.callCredentials = callCredentials; + this.serverInfo = serverInfo; } @VisibleForTesting - public GrpcXdsTransport(ManagedChannel channel) { + public GrpcXdsTransport( + ManagedChannel channel, + CallCredentials callCredentials, + Bootstrapper.ServerInfo serverInfo) { this.channel = checkNotNull(channel, "channel"); + this.callCredentials = callCredentials; + this.serverInfo = serverInfo; } @Override @@ -72,7 +128,8 @@ public StreamingCall createStreamingCall( MethodDescriptor.Marshaller respMarshaller) { Context prevContext = Context.ROOT.attach(); try { - return new XdsStreamingCall<>(fullMethodName, reqMarshaller, respMarshaller); + return new XdsStreamingCall<>( + fullMethodName, reqMarshaller, respMarshaller, callCredentials); } finally { Context.ROOT.detach(prevContext); } @@ -81,7 +138,19 @@ public StreamingCall createStreamingCall( @Override public void shutdown() { - channel.shutdown(); + if (serverInfo == null) { + channel.shutdown(); + return; + } + xdsServerInfoToTransportMap.computeIfPresent( + serverInfo, + (info, transport) -> { + if (--transport.refCount == 0) { // Prefix decrement and return the updated value. + transport.channel.shutdown(); + return null; // Remove mapping. + } + return transport; + }); } private class XdsStreamingCall implements @@ -89,16 +158,21 @@ private class XdsStreamingCall implements private final ClientCall call; - public XdsStreamingCall(String methodName, MethodDescriptor.Marshaller reqMarshaller, - MethodDescriptor.Marshaller respMarshaller) { - this.call = channel.newCall( - MethodDescriptor.newBuilder() - .setFullMethodName(methodName) - .setType(MethodDescriptor.MethodType.BIDI_STREAMING) - .setRequestMarshaller(reqMarshaller) - .setResponseMarshaller(respMarshaller) - .build(), - CallOptions.DEFAULT); // TODO(zivy): support waitForReady + public XdsStreamingCall( + String methodName, + MethodDescriptor.Marshaller reqMarshaller, + MethodDescriptor.Marshaller respMarshaller, + CallCredentials callCredentials) { + this.call = + channel.newCall( + MethodDescriptor.newBuilder() + .setFullMethodName(methodName) + .setType(MethodDescriptor.MethodType.BIDI_STREAMING) + .setRequestMarshaller(reqMarshaller) + .setResponseMarshaller(respMarshaller) + .build(), + CallOptions.DEFAULT.withCallCredentials( + callCredentials)); // TODO(zivy): support waitForReady } @Override diff --git a/xds/src/main/java/io/grpc/xds/InternalGrpcBootstrapperImpl.java b/xds/src/main/java/io/grpc/xds/InternalGrpcBootstrapperImpl.java new file mode 100644 index 00000000000..7bbc2a6dfca --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/InternalGrpcBootstrapperImpl.java @@ -0,0 +1,35 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import io.grpc.Internal; +import io.grpc.xds.client.Bootstrapper.BootstrapInfo; +import io.grpc.xds.client.XdsInitializationException; +import java.util.Map; + +/** + * Internal accessors for GrpcBootstrapperImpl. + */ +@Internal +public final class InternalGrpcBootstrapperImpl { + private InternalGrpcBootstrapperImpl() {} // prevent instantiation + + public static BootstrapInfo parseBootstrap(Map bootstrap) + throws XdsInitializationException { + return new GrpcBootstrapperImpl().bootstrap(bootstrap); + } +} diff --git a/xds/src/main/java/io/grpc/xds/InternalRbacFilter.java b/xds/src/main/java/io/grpc/xds/InternalRbacFilter.java index 54e6c748cd5..476adbf9cfd 100644 --- a/xds/src/main/java/io/grpc/xds/InternalRbacFilter.java +++ b/xds/src/main/java/io/grpc/xds/InternalRbacFilter.java @@ -19,8 +19,6 @@ import io.envoyproxy.envoy.extensions.filters.http.rbac.v3.RBAC; import io.grpc.Internal; import io.grpc.ServerInterceptor; -import io.grpc.xds.RbacConfig; -import io.grpc.xds.RbacFilter; /** This class exposes some functionality in RbacFilter to other packages. */ @Internal @@ -30,11 +28,12 @@ private InternalRbacFilter() {} /** Parses RBAC filter config and creates AuthorizationServerInterceptor. */ public static ServerInterceptor createInterceptor(RBAC rbac) { - ConfigOrError filterConfig = RbacFilter.parseRbacConfig(rbac); + ConfigOrError filterConfig = RbacFilter.Provider.parseRbacConfig(rbac); if (filterConfig.errorDetail != null) { throw new IllegalArgumentException( String.format("Failed to parse Rbac policy: %s", filterConfig.errorDetail)); } - return new RbacFilter().buildServerInterceptor(filterConfig.config, null); + return new RbacFilter.Provider().newInstance("internalRbacFilter") + .buildServerInterceptor(filterConfig.config, null); } } diff --git a/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java b/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java index 39b9ed0d095..cc5ff128274 100644 --- a/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java +++ b/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java @@ -16,8 +16,11 @@ package io.grpc.xds; +import io.grpc.CallCredentials; import io.grpc.Internal; +import io.grpc.MetricRecorder; import io.grpc.internal.ObjectPool; +import io.grpc.xds.client.Bootstrapper.BootstrapInfo; import io.grpc.xds.client.XdsClient; import io.grpc.xds.client.XdsInitializationException; import java.util.Map; @@ -30,12 +33,79 @@ public final class InternalSharedXdsClientPoolProvider { // Prevent instantiation private InternalSharedXdsClientPoolProvider() {} + /** + * Override the global bootstrap. + * + * @deprecated Use InternalGrpcBootstrapperImpl.parseBootstrap() and pass the result to + * getOrCreate(). + */ + @Deprecated public static void setDefaultProviderBootstrapOverride(Map bootstrap) { - SharedXdsClientPoolProvider.getDefaultProvider().setBootstrapOverride(bootstrap); + GrpcBootstrapperImpl.setDefaultBootstrapOverride(bootstrap); } + /** + * Get an XdsClient pool. + * + * @deprecated Use InternalGrpcBootstrapperImpl.parseBootstrap() and pass the result to the other + * getOrCreate(). + */ + @Deprecated public static ObjectPool getOrCreate(String target) throws XdsInitializationException { - return SharedXdsClientPoolProvider.getDefaultProvider().getOrCreate(); + return getOrCreate(target, new MetricRecorder() {}); + } + + /** + * Get an XdsClient pool. + * + * @deprecated Use InternalGrpcBootstrapperImpl.parseBootstrap() and pass the result to the other + * getOrCreate(). + */ + @Deprecated + public static ObjectPool getOrCreate(String target, MetricRecorder metricRecorder) + throws XdsInitializationException { + return getOrCreate(target, metricRecorder, null); + } + + /** + * Get an XdsClient pool. + * + * @deprecated Use InternalGrpcBootstrapperImpl.parseBootstrap() and pass the result to the other + * getOrCreate(). + */ + @Deprecated + public static ObjectPool getOrCreate( + String target, MetricRecorder metricRecorder, CallCredentials transportCallCredentials) + throws XdsInitializationException { + return SharedXdsClientPoolProvider.getDefaultProvider() + .getOrCreate(target, metricRecorder, transportCallCredentials); + } + + public static XdsClientResult getOrCreate( + String target, BootstrapInfo bootstrapInfo, MetricRecorder metricRecorder, + CallCredentials transportCallCredentials) { + return new XdsClientResult(SharedXdsClientPoolProvider.getDefaultProvider() + .getOrCreate(target, bootstrapInfo, metricRecorder, transportCallCredentials)); + } + + /** + * An ObjectPool, except without exposing io.grpc.internal, which must not be used for + * cross-package APIs. + */ + public static final class XdsClientResult { + private final ObjectPool xdsClientPool; + + XdsClientResult(ObjectPool xdsClientPool) { + this.xdsClientPool = xdsClientPool; + } + + public XdsClient getObject() { + return xdsClientPool.getObject(); + } + + public XdsClient returnObject(XdsClient xdsClient) { + return xdsClientPool.returnObject(xdsClient); + } } } diff --git a/xds/src/main/java/io/grpc/xds/InternalXdsAttributes.java b/xds/src/main/java/io/grpc/xds/InternalXdsAttributes.java index 1497eff048a..ed70e6f5e78 100644 --- a/xds/src/main/java/io/grpc/xds/InternalXdsAttributes.java +++ b/xds/src/main/java/io/grpc/xds/InternalXdsAttributes.java @@ -1,5 +1,5 @@ /* - * Copyright 2019 The gRPC Authors + * Copyright 2024 The gRPC Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,84 +18,19 @@ import io.grpc.Attributes; import io.grpc.EquivalentAddressGroup; -import io.grpc.Grpc; import io.grpc.Internal; -import io.grpc.NameResolver; -import io.grpc.internal.ObjectPool; -import io.grpc.xds.XdsNameResolverProvider.CallCounterProvider; -import io.grpc.xds.client.Locality; -import io.grpc.xds.client.XdsClient; -import io.grpc.xds.internal.security.SslContextProviderSupplier; /** * Internal attributes used for xDS implementation. Do not use. */ @Internal public final class InternalXdsAttributes { - - // TODO(sanjaypujare): move to xds internal package. - /** Attribute key for SslContextProviderSupplier (used from client) for a subchannel. */ - @Grpc.TransportAttr - public static final Attributes.Key - ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER = - Attributes.Key.create("io.grpc.xds.internal.security.SslContextProviderSupplier"); - - /** - * Attribute key for passing around the XdsClient object pool across NameResolver/LoadBalancers. - */ - @NameResolver.ResolutionResultAttr - static final Attributes.Key> XDS_CLIENT_POOL = - Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.xdsClientPool"); - - /** - * Attribute key for obtaining the global provider that provides atomics for aggregating - * outstanding RPCs sent to each cluster. - */ - @NameResolver.ResolutionResultAttr - static final Attributes.Key CALL_COUNTER_PROVIDER = - Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.callCounterProvider"); - - /** - * Map from localities to their weights. - */ - @NameResolver.ResolutionResultAttr - static final Attributes.Key ATTR_LOCALITY_WEIGHT = - Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.localityWeight"); - /** * Name of the cluster that provides this EquivalentAddressGroup. */ - @Internal @EquivalentAddressGroup.Attr public static final Attributes.Key ATTR_CLUSTER_NAME = - Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.clusterName"); - - /** - * The locality that this EquivalentAddressGroup is in. - */ - @EquivalentAddressGroup.Attr - static final Attributes.Key ATTR_LOCALITY = - Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.locality"); - - /** - * Endpoint weight for load balancing purposes. - */ - @EquivalentAddressGroup.Attr - static final Attributes.Key ATTR_SERVER_WEIGHT = - Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.serverWeight"); - - /** - * Filter chain match for network filters. - */ - @Grpc.TransportAttr - static final Attributes.Key - ATTR_FILTER_CHAIN_SELECTOR_MANAGER = Attributes.Key.create( - "io.grpc.xds.InternalXdsAttributes.filterChainSelectorManager"); - - /** Grace time to use when draining. Null for an infinite grace time. */ - @Grpc.TransportAttr - static final Attributes.Key ATTR_DRAIN_GRACE_NANOS = - Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.drainGraceTime"); + XdsAttributes.ATTR_CLUSTER_NAME; private InternalXdsAttributes() {} } diff --git a/xds/src/main/java/io/grpc/xds/LazyLoadBalancer.java b/xds/src/main/java/io/grpc/xds/LazyLoadBalancer.java new file mode 100644 index 00000000000..b5f09c4ea93 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/LazyLoadBalancer.java @@ -0,0 +1,139 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import com.google.common.base.Preconditions; +import io.grpc.ConnectivityState; +import io.grpc.LoadBalancer; +import io.grpc.Status; +import io.grpc.util.ForwardingLoadBalancer; + +/** + * A load balancer that starts in IDLE instead of CONNECTING. Once it starts connecting, it + * instantiates its delegate. + */ +final class LazyLoadBalancer extends ForwardingLoadBalancer { + private LoadBalancer delegate; + + public LazyLoadBalancer(Helper helper, LoadBalancer.Factory delegateFactory) { + this.delegate = new LazyDelegate(helper, delegateFactory); + } + + @Override + protected LoadBalancer delegate() { + return delegate; + } + + @Override + public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + return delegate.acceptResolvedAddresses(resolvedAddresses); + } + + private final class LazyDelegate extends LoadBalancer { + private final Helper helper; + private final LoadBalancer.Factory delegateFactory; + private ResolvedAddresses addresses; + private Status error; + private boolean updatedBalancingState; + + public LazyDelegate(Helper helper, LoadBalancer.Factory delegateFactory) { + this.helper = Preconditions.checkNotNull(helper, "helper"); + this.delegateFactory = Preconditions.checkNotNull(delegateFactory, "delegateFactory"); + } + + private LoadBalancer activate() { + if (delegate != this) { + return delegate; + } + delegate = delegateFactory.newLoadBalancer(helper); + if (addresses != null) { + delegate.acceptResolvedAddresses(addresses); + } + if (error != null) { + delegate.handleNameResolutionError(error); + } + return delegate; + } + + @Override + public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + this.addresses = resolvedAddresses; + this.error = null; + initializeBalancingState(); + return Status.OK; + } + + @Override + public void handleNameResolutionError(Status error) { + // Preserve addresses, because even old addresses may be used by the real policy + this.error = error; + initializeBalancingState(); + } + + private void initializeBalancingState() { + if (updatedBalancingState) { + return; + } + helper.updateBalancingState(ConnectivityState.IDLE, new LazyPicker()); + updatedBalancingState = true; + } + + @Override + public void requestConnection() { + activate().requestConnection(); + } + + @Override + public void shutdown() { + delegate = new NoopLoadBalancer(); + } + + private final class LazyPicker extends SubchannelPicker { + @Override + public PickResult pickSubchannel(PickSubchannelArgs args) { + // activate() is a no-op after shutdown() + helper.getSynchronizationContext().execute(LazyDelegate.this::activate); + return PickResult.withNoResult(); + } + } + } + + public static final class Factory extends LoadBalancer.Factory { + private final LoadBalancer.Factory delegate; + + public Factory(LoadBalancer.Factory delegate) { + this.delegate = Preconditions.checkNotNull(delegate, "delegate"); + } + + @Override public LoadBalancer newLoadBalancer(Helper helper) { + return new LazyLoadBalancer(helper, delegate); + } + } + + private static final class NoopLoadBalancer extends LoadBalancer { + @Override + public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + return Status.OK; + } + + @Override + public void handleNameResolutionError(Status error) {} + + @Override + public void shutdown() {} + } +} diff --git a/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java b/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java index f96c171ee9c..1f23f2a4af5 100644 --- a/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java @@ -32,7 +32,6 @@ import io.grpc.ClientStreamTracer; import io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.ConnectivityState; -import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer; import io.grpc.LoadBalancerProvider; import io.grpc.Metadata; @@ -55,7 +54,7 @@ final class LeastRequestLoadBalancer extends MultiChildLoadBalancer { private final ThreadSafeRandom random; - private SubchannelPicker currentPicker = new EmptyPicker(); + private SubchannelPicker currentPicker = new FixedResultPicker(PickResult.withNoResult()); private int choiceCount = DEFAULT_CHOICE_COUNT; LeastRequestLoadBalancer(Helper helper) { @@ -114,7 +113,7 @@ protected void updateOverallBalancingState() { } } if (isConnecting) { - updateBalancingState(CONNECTING, new EmptyPicker()); + updateBalancingState(CONNECTING, new FixedResultPicker(PickResult.withNoResult())); } else { // Give it all the failing children and let it randomly pick among them updateBalancingState(TRANSIENT_FAILURE, @@ -126,9 +125,8 @@ protected void updateOverallBalancingState() { } @Override - protected ChildLbState createChildLbState(Object key, Object policyConfig, - SubchannelPicker initialPicker, ResolvedAddresses unused) { - return new LeastRequestLbState(key, pickFirstLbProvider, policyConfig, initialPicker); + protected ChildLbState createChildLbState(Object key) { + return new LeastRequestLbState(key, pickFirstLbProvider); } private void updateBalancingState(ConnectivityState state, SubchannelPicker picker) { @@ -156,7 +154,6 @@ private static AtomicInteger getInFlights(ChildLbState childLbState) { static final class ReadyPicker extends SubchannelPicker { private final List childPickers; // non-empty private final List childInFlights; // 1:1 with childPickers - private final List childEags; // 1:1 with childPickers private final int choiceCount; private final ThreadSafeRandom random; private final int hashCode; @@ -165,11 +162,9 @@ static final class ReadyPicker extends SubchannelPicker { checkArgument(!childLbStates.isEmpty(), "empty list"); this.childPickers = new ArrayList<>(childLbStates.size()); this.childInFlights = new ArrayList<>(childLbStates.size()); - this.childEags = new ArrayList<>(childLbStates.size()); for (ChildLbState state : childLbStates) { childPickers.add(state.getCurrentPicker()); childInFlights.add(getInFlights(state)); - childEags.add(state.getEag()); } this.choiceCount = choiceCount; this.random = checkNotNull(random, "random"); @@ -225,11 +220,6 @@ List getChildPickers() { return childPickers; } - @VisibleForTesting - List getChildEags() { - return childEags; - } - @Override public int hashCode() { return hashCode; @@ -320,13 +310,25 @@ public String toString() { protected class LeastRequestLbState extends ChildLbState { private final AtomicInteger activeRequests = new AtomicInteger(0); - public LeastRequestLbState(Object key, LoadBalancerProvider policyProvider, - Object childConfig, SubchannelPicker initialPicker) { - super(key, policyProvider, childConfig, initialPicker); + public LeastRequestLbState(Object key, LoadBalancerProvider policyProvider) { + super(key, policyProvider); } int getActiveRequests() { return activeRequests.get(); } + + @Override + protected ChildLbStateHelper createChildHelper() { + return new ChildLbStateHelper() { + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + super.updateBalancingState(newState, newPicker); + if (!resolvingAddresses && newState == IDLE) { + getLb().requestConnection(); + } + } + }; + } } } diff --git a/xds/src/main/java/io/grpc/xds/LoadBalancerConfigFactory.java b/xds/src/main/java/io/grpc/xds/LoadBalancerConfigFactory.java index 526c18584e6..5fd8ec5526e 100644 --- a/xds/src/main/java/io/grpc/xds/LoadBalancerConfigFactory.java +++ b/xds/src/main/java/io/grpc/xds/LoadBalancerConfigFactory.java @@ -91,6 +91,7 @@ class LoadBalancerConfigFactory { static final String SHUFFLE_ADDRESS_LIST_FIELD_NAME = "shuffleAddressList"; static final String ERROR_UTILIZATION_PENALTY = "errorUtilizationPenalty"; + static final String METRIC_NAMES_FOR_COMPUTING_UTILIZATION = "metricNamesForComputingUtilization"; /** * Factory method for creating a new {link LoadBalancerConfigConverter} for a given xDS {@link @@ -98,15 +99,14 @@ class LoadBalancerConfigFactory { * * @throws ResourceInvalidException If the {@link Cluster} has an invalid LB configuration. */ - static ImmutableMap newConfig(Cluster cluster, boolean enableLeastRequest, - boolean enableWrr, boolean enablePickFirst) + static ImmutableMap newConfig(Cluster cluster, boolean enableLeastRequest) throws ResourceInvalidException { // The new load_balancing_policy will always be used if it is set, but for backward // compatibility we will fall back to using the old lb_policy field if the new field is not set. if (cluster.hasLoadBalancingPolicy()) { try { return LoadBalancingPolicyConverter.convertToServiceConfig(cluster.getLoadBalancingPolicy(), - 0, enableWrr, enablePickFirst); + 0); } catch (MaxRecursionReachedException e) { throw new ResourceInvalidException("Maximum LB config recursion depth reached", e); } @@ -135,11 +135,9 @@ class LoadBalancerConfigFactory { * the given config values. */ private static ImmutableMap buildWrrConfig(String blackoutPeriod, - String weightExpirationPeriod, - String oobReportingPeriod, - Boolean enableOobLoadReport, - String weightUpdatePeriod, - Float errorUtilizationPenalty) { + String weightExpirationPeriod, String oobReportingPeriod, Boolean enableOobLoadReport, + String weightUpdatePeriod, Float errorUtilizationPenalty, + ImmutableList metricNamesForComputingUtilization) { ImmutableMap.Builder configBuilder = ImmutableMap.builder(); if (blackoutPeriod != null) { configBuilder.put(BLACK_OUT_PERIOD, blackoutPeriod); @@ -159,6 +157,10 @@ class LoadBalancerConfigFactory { if (errorUtilizationPenalty != null) { configBuilder.put(ERROR_UTILIZATION_PENALTY, errorUtilizationPenalty); } + if (metricNamesForComputingUtilization != null + && !metricNamesForComputingUtilization.isEmpty()) { + configBuilder.put(METRIC_NAMES_FOR_COMPUTING_UTILIZATION, metricNamesForComputingUtilization); + } return ImmutableMap.of(WeightedRoundRobinLoadBalancerProvider.SCHEME, configBuilder.buildOrThrow()); } @@ -213,8 +215,7 @@ static class LoadBalancingPolicyConverter { * Converts a {@link LoadBalancingPolicy} object to a service config JSON object. */ private static ImmutableMap convertToServiceConfig( - LoadBalancingPolicy loadBalancingPolicy, int recursionDepth, boolean enableWrr, - boolean enablePickFirst) + LoadBalancingPolicy loadBalancingPolicy, int recursionDepth) throws ResourceInvalidException, MaxRecursionReachedException { if (recursionDepth > MAX_RECURSION) { throw new MaxRecursionReachedException(); @@ -228,20 +229,16 @@ static class LoadBalancingPolicyConverter { serviceConfig = convertRingHashConfig(typedConfig.unpack(RingHash.class)); } else if (typedConfig.is(WrrLocality.class)) { serviceConfig = convertWrrLocalityConfig(typedConfig.unpack(WrrLocality.class), - recursionDepth, enableWrr, enablePickFirst); + recursionDepth); } else if (typedConfig.is(RoundRobin.class)) { serviceConfig = convertRoundRobinConfig(); } else if (typedConfig.is(LeastRequest.class)) { serviceConfig = convertLeastRequestConfig(typedConfig.unpack(LeastRequest.class)); } else if (typedConfig.is(ClientSideWeightedRoundRobin.class)) { - if (enableWrr) { - serviceConfig = convertWeightedRoundRobinConfig( - typedConfig.unpack(ClientSideWeightedRoundRobin.class)); - } + serviceConfig = convertWeightedRoundRobinConfig( + typedConfig.unpack(ClientSideWeightedRoundRobin.class)); } else if (typedConfig.is(PickFirst.class)) { - if (enablePickFirst) { - serviceConfig = convertPickFirstConfig(typedConfig.unpack(PickFirst.class)); - } + serviceConfig = convertPickFirstConfig(typedConfig.unpack(PickFirst.class)); } else if (typedConfig.is(com.github.xds.type.v3.TypedStruct.class)) { serviceConfig = convertCustomConfig( typedConfig.unpack(com.github.xds.type.v3.TypedStruct.class)); @@ -290,7 +287,7 @@ static class LoadBalancingPolicyConverter { } private static ImmutableMap convertWeightedRoundRobinConfig( - ClientSideWeightedRoundRobin wrr) throws ResourceInvalidException { + ClientSideWeightedRoundRobin wrr) throws ResourceInvalidException { try { return buildWrrConfig( wrr.hasBlackoutPeriod() ? Durations.toString(wrr.getBlackoutPeriod()) : null, @@ -299,7 +296,8 @@ static class LoadBalancingPolicyConverter { wrr.hasOobReportingPeriod() ? Durations.toString(wrr.getOobReportingPeriod()) : null, wrr.hasEnableOobLoadReport() ? wrr.getEnableOobLoadReport().getValue() : null, wrr.hasWeightUpdatePeriod() ? Durations.toString(wrr.getWeightUpdatePeriod()) : null, - wrr.hasErrorUtilizationPenalty() ? wrr.getErrorUtilizationPenalty().getValue() : null); + wrr.hasErrorUtilizationPenalty() ? wrr.getErrorUtilizationPenalty().getValue() : null, + ImmutableList.copyOf(wrr.getMetricNamesForComputingUtilizationList())); } catch (IllegalArgumentException ex) { throw new ResourceInvalidException("Invalid duration in weighted round robin config: " + ex.getMessage()); @@ -310,12 +308,10 @@ static class LoadBalancingPolicyConverter { * Converts a wrr_locality {@link Any} configuration to service config format. */ private static ImmutableMap convertWrrLocalityConfig(WrrLocality wrrLocality, - int recursionDepth, boolean enableWrr, boolean enablePickFirst) - throws ResourceInvalidException, - MaxRecursionReachedException { + int recursionDepth) + throws ResourceInvalidException, MaxRecursionReachedException { return buildWrrLocalityConfig( - convertToServiceConfig(wrrLocality.getEndpointPickingPolicy(), - recursionDepth + 1, enableWrr, enablePickFirst)); + convertToServiceConfig(wrrLocality.getEndpointPickingPolicy(), recursionDepth + 1)); } /** diff --git a/xds/src/main/java/io/grpc/xds/MessagePrinter.java b/xds/src/main/java/io/grpc/xds/MessagePrinter.java index 5927bfd517e..d6fdaa81dd7 100644 --- a/xds/src/main/java/io/grpc/xds/MessagePrinter.java +++ b/xds/src/main/java/io/grpc/xds/MessagePrinter.java @@ -16,6 +16,7 @@ package io.grpc.xds; +import com.github.xds.type.v3.TypedStruct; import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; @@ -32,8 +33,11 @@ import io.envoyproxy.envoy.extensions.filters.http.rbac.v3.RBACPerRoute; import io.envoyproxy.envoy.extensions.filters.http.router.v3.Router; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager; +import io.envoyproxy.envoy.extensions.load_balancing_policies.round_robin.v3.RoundRobin; +import io.envoyproxy.envoy.extensions.load_balancing_policies.wrr_locality.v3.WrrLocality; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.DownstreamTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext; +import io.envoyproxy.envoy.service.discovery.v3.Resource; import io.grpc.xds.client.MessagePrettyPrinter; /** @@ -52,6 +56,7 @@ private static class LazyHolder { private static JsonFormat.Printer newPrinter() { TypeRegistry.Builder registry = TypeRegistry.newBuilder() + .add(Resource.getDescriptor()) .add(Listener.getDescriptor()) .add(HttpConnectionManager.getDescriptor()) .add(HTTPFault.getDescriptor()) @@ -65,7 +70,10 @@ private static JsonFormat.Printer newPrinter() { .add(RouteConfiguration.getDescriptor()) .add(Cluster.getDescriptor()) .add(ClusterConfig.getDescriptor()) - .add(ClusterLoadAssignment.getDescriptor()); + .add(ClusterLoadAssignment.getDescriptor()) + .add(WrrLocality.getDescriptor()) + .add(TypedStruct.getDescriptor()) + .add(RoundRobin.getDescriptor()); try { @SuppressWarnings("unchecked") Class routeLookupClusterSpecifierClass = diff --git a/xds/src/main/java/io/grpc/xds/MetadataRegistry.java b/xds/src/main/java/io/grpc/xds/MetadataRegistry.java new file mode 100644 index 00000000000..b79a61a261a --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/MetadataRegistry.java @@ -0,0 +1,125 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.Any; +import com.google.protobuf.Struct; +import io.envoyproxy.envoy.config.core.v3.Metadata; +import io.grpc.xds.GcpAuthenticationFilter.AudienceMetadataParser; +import io.grpc.xds.XdsEndpointResource.AddressMetadataParser; +import io.grpc.xds.client.XdsResourceType.ResourceInvalidException; +import io.grpc.xds.internal.ProtobufJsonConverter; +import java.util.HashMap; +import java.util.Map; + +/** + * Registry for parsing cluster metadata values. + * + *

This class maintains a mapping of type URLs to {@link MetadataValueParser} instances, + * allowing for the parsing of different metadata types. + */ +final class MetadataRegistry { + private static final MetadataRegistry INSTANCE = new MetadataRegistry(); + + private final Map supportedParsers = new HashMap<>(); + + private MetadataRegistry() { + registerParser(new AudienceMetadataParser()); + registerParser(new AddressMetadataParser()); + } + + static MetadataRegistry getInstance() { + return INSTANCE; + } + + MetadataValueParser findParser(String typeUrl) { + return supportedParsers.get(typeUrl); + } + + @VisibleForTesting + void registerParser(MetadataValueParser parser) { + supportedParsers.put(parser.getTypeUrl(), parser); + } + + void removeParser(MetadataValueParser parser) { + supportedParsers.remove(parser.getTypeUrl()); + } + + /** + * Parses cluster metadata into a structured map. + * + *

Values in {@code typed_filter_metadata} take precedence over + * {@code filter_metadata} when keys overlap, following Envoy API behavior. See + * + * Envoy metadata documentation for details. + * + * @param metadata the {@link Metadata} containing the fields to parse. + * @return an immutable map of parsed metadata. + * @throws ResourceInvalidException if parsing {@code typed_filter_metadata} fails. + */ + public ImmutableMap parseMetadata(Metadata metadata) + throws ResourceInvalidException { + ImmutableMap.Builder parsedMetadata = ImmutableMap.builder(); + + // Process typed_filter_metadata + for (Map.Entry entry : metadata.getTypedFilterMetadataMap().entrySet()) { + String key = entry.getKey(); + Any value = entry.getValue(); + MetadataValueParser parser = findParser(value.getTypeUrl()); + if (parser != null) { + try { + Object parsedValue = parser.parse(value); + parsedMetadata.put(key, parsedValue); + } catch (ResourceInvalidException e) { + throw new ResourceInvalidException( + String.format("Failed to parse metadata key: %s, type: %s. Error: %s", + key, value.getTypeUrl(), e.getMessage()), e); + } + } + } + // building once to reuse in the next loop + ImmutableMap intermediateParsedMetadata = parsedMetadata.build(); + + // Process filter_metadata for remaining keys + for (Map.Entry entry : metadata.getFilterMetadataMap().entrySet()) { + String key = entry.getKey(); + if (!intermediateParsedMetadata.containsKey(key)) { + Struct structValue = entry.getValue(); + Object jsonValue = ProtobufJsonConverter.convertToJson(structValue); + parsedMetadata.put(key, jsonValue); + } + } + + return parsedMetadata.build(); + } + + interface MetadataValueParser { + + String getTypeUrl(); + + /** + * Parses the given {@link Any} object into a specific metadata value. + * + * @param any the {@link Any} object to parse. + * @return the parsed metadata value. + * @throws ResourceInvalidException if the parsing fails. + */ + Object parse(Any any) throws ResourceInvalidException; + } +} diff --git a/xds/src/main/java/io/grpc/xds/PriorityLoadBalancer.java b/xds/src/main/java/io/grpc/xds/PriorityLoadBalancer.java index 96ab40020b7..ca142af0af3 100644 --- a/xds/src/main/java/io/grpc/xds/PriorityLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/PriorityLoadBalancer.java @@ -25,11 +25,10 @@ import io.grpc.ConnectivityState; import io.grpc.InternalLogId; import io.grpc.LoadBalancer; -import io.grpc.LoadBalancerProvider; import io.grpc.Status; import io.grpc.SynchronizationContext; import io.grpc.SynchronizationContext.ScheduledHandle; -import io.grpc.internal.ServiceConfigUtil.PolicySelection; +import io.grpc.internal.GrpcUtil; import io.grpc.util.ForwardingLoadBalancerHelper; import io.grpc.util.GracefulSwitchLoadBalancer; import io.grpc.xds.PriorityLoadBalancerProvider.PriorityLbConfig; @@ -75,6 +74,8 @@ final class PriorityLoadBalancer extends LoadBalancer { private SubchannelPicker currentPicker; // Set to true if currently in the process of handling resolved addresses. private boolean handlingResolvedAddresses; + static boolean enablePriorityLbChildPolicyCache = + GrpcUtil.getFlag("GRPC_EXPERIMENTAL_ENABLE_PRIORITY_LB_CHILD_POLICY_CACHE", false); PriorityLoadBalancer(Helper helper) { this.helper = checkNotNull(helper, "helper"); @@ -93,13 +94,19 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { checkNotNull(config, "missing priority lb config"); priorityNames = config.priorities; priorityConfigs = config.childConfigs; + Status status = Status.OK; Set prioritySet = new HashSet<>(config.priorities); ArrayList childKeys = new ArrayList<>(children.keySet()); for (String priority : childKeys) { if (!prioritySet.contains(priority)) { ChildLbState childLbState = children.get(priority); if (childLbState != null) { - childLbState.deactivate(); + if (enablePriorityLbChildPolicyCache) { + childLbState.deactivate(); + } else { + childLbState.tearDown(); + children.remove(priority); + } } } } @@ -107,12 +114,18 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { for (String priority : priorityNames) { ChildLbState childLbState = children.get(priority); if (childLbState != null) { - childLbState.updateResolvedAddresses(); + Status newStatus = childLbState.updateResolvedAddresses(); + if (!newStatus.isOk()) { + status = newStatus; + } } } handlingResolvedAddresses = false; - tryNextPriority(); - return Status.OK; + Status newStatus = tryNextPriority(); + if (!newStatus.isOk()) { + status = newStatus; + } + return status; } @Override @@ -142,19 +155,19 @@ public void shutdown() { children.clear(); } - private void tryNextPriority() { + private Status tryNextPriority() { for (int i = 0; i < priorityNames.size(); i++) { String priority = priorityNames.get(i); if (!children.containsKey(priority)) { ChildLbState child = new ChildLbState(priority, priorityConfigs.get(priority).ignoreReresolution); children.put(priority, child); - updateOverallState(priority, CONNECTING, new FixedResultPicker(PickResult.withNoResult())); + // Child is created in CONNECTING with pending failOverTimer + updateOverallState(priority, child.connectivityState, child.picker); // Calling the child's updateResolvedAddresses() can result in tryNextPriority() being // called recursively. We need to be sure to be done with processing here before it is // called. - child.updateResolvedAddresses(); - return; // Give priority i time to connect. + return child.updateResolvedAddresses(); // Give priority i time to connect. } ChildLbState child = children.get(priority); child.reactivate(); @@ -167,23 +180,26 @@ private void tryNextPriority() { children.get(p).deactivate(); } } - return; + return Status.OK; } - if (child.failOverTimer != null && child.failOverTimer.isPending()) { + if (child.failOverTimer.isPending()) { updateOverallState(priority, child.connectivityState, child.picker); - return; // Give priority i time to connect. + return Status.OK; // Give priority i time to connect. } - if (priority.equals(currentPriority) && child.connectivityState != TRANSIENT_FAILURE) { - // If the current priority is not changed into TRANSIENT_FAILURE, keep using it. + } + for (int i = 0; i < priorityNames.size(); i++) { + String priority = priorityNames.get(i); + ChildLbState child = children.get(priority); + if (child.connectivityState.equals(CONNECTING)) { updateOverallState(priority, child.connectivityState, child.picker); - return; + return Status.OK; } } - // TODO(zdapeng): Include error details of each priority. logger.log(XdsLogLevel.DEBUG, "All priority failed"); String lastPriority = priorityNames.get(priorityNames.size() - 1); - SubchannelPicker errorPicker = children.get(lastPriority).picker; - updateOverallState(lastPriority, TRANSIENT_FAILURE, errorPicker); + ChildLbState child = children.get(lastPriority); + updateOverallState(lastPriority, child.connectivityState, child.picker); + return Status.OK; } private void updateOverallState( @@ -208,7 +224,6 @@ private final class ChildLbState { // Timer to delay shutdown and deletion of the priority. Scheduled whenever the child is // deactivated. @Nullable ScheduledHandle deletionTimer; - @Nullable String policy; ConnectivityState connectivityState = CONNECTING; SubchannelPicker picker = new FixedResultPicker(PickResult.withNoResult()); @@ -227,11 +242,12 @@ public void run() { // The child is deactivated. return; } - picker = new FixedResultPicker(PickResult.withError( - Status.UNAVAILABLE.withDescription("Connection timeout for priority " + priority))); logger.log(XdsLogLevel.DEBUG, "Priority {0} failed over to next", priority); - currentPriority = null; // reset currentPriority to guarantee failover happen - tryNextPriority(); + Status status = tryNextPriority(); + if (!status.isOk()) { + // A child had a problem with the addresses/config. Request it to be refreshed + helper.refreshNameResolution(); + } } } @@ -282,20 +298,13 @@ void tearDown() { * resolvedAddresses}, or when priority lb receives a new resolved addresses while the child * already exists. */ - void updateResolvedAddresses() { + Status updateResolvedAddresses() { PriorityLbConfig config = (PriorityLbConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); - PolicySelection childPolicySelection = config.childConfigs.get(priority).policySelection; - LoadBalancerProvider lbProvider = childPolicySelection.getProvider(); - String newPolicy = lbProvider.getPolicyName(); - if (!newPolicy.equals(policy)) { - policy = newPolicy; - lb.switchTo(lbProvider); - } - lb.handleResolvedAddresses( + return lb.acceptResolvedAddresses( resolvedAddresses.toBuilder() .setAddresses(AddressFilter.filter(resolvedAddresses.getAddresses(), priority)) - .setLoadBalancingPolicyConfig(childPolicySelection.getConfig()) + .setLoadBalancingPolicyConfig(config.childConfigs.get(priority).childConfig) .build()); } @@ -319,13 +328,14 @@ public void updateBalancingState(final ConnectivityState newState, if (!children.containsKey(priority)) { return; } + ConnectivityState oldState = connectivityState; connectivityState = newState; picker = newPicker; if (deletionTimer != null && deletionTimer.isPending()) { return; } - if (newState.equals(CONNECTING)) { + if (newState.equals(CONNECTING) && !oldState.equals(newState)) { if (!failOverTimer.isPending() && seenReadyOrIdleSinceTransientFailure) { failOverTimer = syncContext.schedule(new FailOverTask(), 10, TimeUnit.SECONDS, executor); @@ -341,7 +351,11 @@ public void updateBalancingState(final ConnectivityState newState, // If we are currently handling newly resolved addresses, let's not try to reconfigure as // the address handling process will take care of that to provide an atomic config update. if (!handlingResolvedAddresses) { - tryNextPriority(); + Status status = tryNextPriority(); + if (!status.isOk()) { + // A child had a problem with the addresses/config. Request it to be refreshed + helper.refreshNameResolution(); + } } } diff --git a/xds/src/main/java/io/grpc/xds/PriorityLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/PriorityLoadBalancerProvider.java index 2bc561268fb..1aab6c31b08 100644 --- a/xds/src/main/java/io/grpc/xds/PriorityLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/PriorityLoadBalancerProvider.java @@ -26,7 +26,6 @@ import io.grpc.LoadBalancerProvider; import io.grpc.NameResolver.ConfigOrError; import io.grpc.Status; -import io.grpc.internal.ServiceConfigUtil.PolicySelection; import java.util.Collections; import java.util.HashSet; import java.util.List; @@ -90,18 +89,18 @@ public String toString() { } static final class PriorityChildConfig { - final PolicySelection policySelection; + final Object childConfig; final boolean ignoreReresolution; - PriorityChildConfig(PolicySelection policySelection, boolean ignoreReresolution) { - this.policySelection = checkNotNull(policySelection, "policySelection"); + PriorityChildConfig(Object childConfig, boolean ignoreReresolution) { + this.childConfig = checkNotNull(childConfig, "childConfig"); this.ignoreReresolution = ignoreReresolution; } @Override public String toString() { return MoreObjects.toStringHelper(this) - .add("policySelection", policySelection) + .add("childConfig", childConfig) .add("ignoreReresolution", ignoreReresolution) .toString(); } diff --git a/xds/src/main/java/io/grpc/xds/RbacFilter.java b/xds/src/main/java/io/grpc/xds/RbacFilter.java index 6a55f7f193e..58cc46fad3e 100644 --- a/xds/src/main/java/io/grpc/xds/RbacFilter.java +++ b/xds/src/main/java/io/grpc/xds/RbacFilter.java @@ -18,7 +18,6 @@ import static com.google.common.base.Preconditions.checkNotNull; -import com.google.common.annotations.VisibleForTesting; import com.google.protobuf.Any; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; @@ -34,7 +33,6 @@ import io.grpc.ServerCallHandler; import io.grpc.ServerInterceptor; import io.grpc.Status; -import io.grpc.xds.Filter.ServerInterceptorBuilder; import io.grpc.xds.internal.MatcherParser; import io.grpc.xds.internal.Matchers; import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine; @@ -66,10 +64,10 @@ import javax.annotation.Nullable; /** RBAC Http filter implementation. */ -final class RbacFilter implements Filter, ServerInterceptorBuilder { +final class RbacFilter implements Filter { private static final Logger logger = Logger.getLogger(RbacFilter.class.getName()); - static final RbacFilter INSTANCE = new RbacFilter(); + private static final RbacFilter INSTANCE = new RbacFilter(); static final String TYPE_URL = "type.googleapis.com/envoy.extensions.filters.http.rbac.v3.RBAC"; @@ -77,87 +75,101 @@ final class RbacFilter implements Filter, ServerInterceptorBuilder { private static final String TYPE_URL_OVERRIDE_CONFIG = "type.googleapis.com/envoy.extensions.filters.http.rbac.v3.RBACPerRoute"; - RbacFilter() {} + private RbacFilter() {} - @Override - public String[] typeUrls() { - return new String[] { TYPE_URL, TYPE_URL_OVERRIDE_CONFIG }; - } + static final class Provider implements Filter.Provider { + @Override + public String[] typeUrls() { + return new String[] {TYPE_URL, TYPE_URL_OVERRIDE_CONFIG}; + } - @Override - public ConfigOrError parseFilterConfig(Message rawProtoMessage) { - RBAC rbacProto; - if (!(rawProtoMessage instanceof Any)) { - return ConfigOrError.fromError("Invalid config type: " + rawProtoMessage.getClass()); + @Override + public boolean isServerFilter() { + return true; } - Any anyMessage = (Any) rawProtoMessage; - try { - rbacProto = anyMessage.unpack(RBAC.class); - } catch (InvalidProtocolBufferException e) { - return ConfigOrError.fromError("Invalid proto: " + e); + + @Override + public RbacFilter newInstance(String name) { + return INSTANCE; } - return parseRbacConfig(rbacProto); - } - @VisibleForTesting - static ConfigOrError parseRbacConfig(RBAC rbac) { - if (!rbac.hasRules()) { - return ConfigOrError.fromConfig(RbacConfig.create(null)); + @Override + public ConfigOrError parseFilterConfig( + Message rawProtoMessage, FilterConfigParseContext context) { + RBAC rbacProto; + if (!(rawProtoMessage instanceof Any)) { + return ConfigOrError.fromError("Invalid config type: " + rawProtoMessage.getClass()); + } + Any anyMessage = (Any) rawProtoMessage; + try { + rbacProto = anyMessage.unpack(RBAC.class); + } catch (InvalidProtocolBufferException e) { + return ConfigOrError.fromError("Invalid proto: " + e); + } + return parseRbacConfig(rbacProto); } - io.envoyproxy.envoy.config.rbac.v3.RBAC rbacConfig = rbac.getRules(); - GrpcAuthorizationEngine.Action authAction; - switch (rbacConfig.getAction()) { - case ALLOW: - authAction = GrpcAuthorizationEngine.Action.ALLOW; - break; - case DENY: - authAction = GrpcAuthorizationEngine.Action.DENY; - break; - case LOG: + + @Override + public ConfigOrError parseFilterConfigOverride( + Message rawProtoMessage, FilterConfigParseContext context) { + RBACPerRoute rbacPerRoute; + if (!(rawProtoMessage instanceof Any)) { + return ConfigOrError.fromError("Invalid config type: " + rawProtoMessage.getClass()); + } + Any anyMessage = (Any) rawProtoMessage; + try { + rbacPerRoute = anyMessage.unpack(RBACPerRoute.class); + } catch (InvalidProtocolBufferException e) { + return ConfigOrError.fromError("Invalid proto: " + e); + } + if (rbacPerRoute.hasRbac()) { + return parseRbacConfig(rbacPerRoute.getRbac()); + } else { return ConfigOrError.fromConfig(RbacConfig.create(null)); - case UNRECOGNIZED: - default: - return ConfigOrError.fromError("Unknown rbacConfig action type: " + rbacConfig.getAction()); + } } - List policyMatchers = new ArrayList<>(); - List> sortedPolicyEntries = rbacConfig.getPoliciesMap().entrySet() - .stream() - .sorted((a,b) -> a.getKey().compareTo(b.getKey())) - .collect(Collectors.toList()); - for (Map.Entry entry: sortedPolicyEntries) { - try { - Policy policy = entry.getValue(); - if (policy.hasCondition() || policy.hasCheckedCondition()) { + + static ConfigOrError parseRbacConfig(RBAC rbac) { + if (!rbac.hasRules()) { + return ConfigOrError.fromConfig(RbacConfig.create(null)); + } + io.envoyproxy.envoy.config.rbac.v3.RBAC rbacConfig = rbac.getRules(); + GrpcAuthorizationEngine.Action authAction; + switch (rbacConfig.getAction()) { + case ALLOW: + authAction = GrpcAuthorizationEngine.Action.ALLOW; + break; + case DENY: + authAction = GrpcAuthorizationEngine.Action.DENY; + break; + case LOG: + return ConfigOrError.fromConfig(RbacConfig.create(null)); + case UNRECOGNIZED: + default: return ConfigOrError.fromError( - "Policy.condition and Policy.checked_condition must not set: " + entry.getKey()); + "Unknown rbacConfig action type: " + rbacConfig.getAction()); + } + List policyMatchers = new ArrayList<>(); + List> sortedPolicyEntries = rbacConfig.getPoliciesMap().entrySet() + .stream() + .sorted((a,b) -> a.getKey().compareTo(b.getKey())) + .collect(Collectors.toList()); + for (Map.Entry entry: sortedPolicyEntries) { + try { + Policy policy = entry.getValue(); + if (policy.hasCondition() || policy.hasCheckedCondition()) { + return ConfigOrError.fromError( + "Policy.condition and Policy.checked_condition must not set: " + entry.getKey()); + } + policyMatchers.add(PolicyMatcher.create(entry.getKey(), + parsePermissionList(policy.getPermissionsList()), + parsePrincipalList(policy.getPrincipalsList()))); + } catch (Exception e) { + return ConfigOrError.fromError("Encountered error parsing policy: " + e); } - policyMatchers.add(PolicyMatcher.create(entry.getKey(), - parsePermissionList(policy.getPermissionsList()), - parsePrincipalList(policy.getPrincipalsList()))); - } catch (Exception e) { - return ConfigOrError.fromError("Encountered error parsing policy: " + e); } - } - return ConfigOrError.fromConfig(RbacConfig.create( - AuthConfig.create(policyMatchers, authAction))); - } - - @Override - public ConfigOrError parseFilterConfigOverride(Message rawProtoMessage) { - RBACPerRoute rbacPerRoute; - if (!(rawProtoMessage instanceof Any)) { - return ConfigOrError.fromError("Invalid config type: " + rawProtoMessage.getClass()); - } - Any anyMessage = (Any) rawProtoMessage; - try { - rbacPerRoute = anyMessage.unpack(RBACPerRoute.class); - } catch (InvalidProtocolBufferException e) { - return ConfigOrError.fromError("Invalid proto: " + e); - } - if (rbacPerRoute.hasRbac()) { - return parseRbacConfig(rbacPerRoute.getRbac()); - } else { - return ConfigOrError.fromConfig(RbacConfig.create(null)); + return ConfigOrError.fromConfig(RbacConfig.create( + AuthConfig.create(policyMatchers, authAction))); } } @@ -266,8 +278,13 @@ private static Matcher parsePrincipal(Principal principal) { return createSourceIpMatcher(principal.getDirectRemoteIp()); case REMOTE_IP: return createSourceIpMatcher(principal.getRemoteIp()); - case SOURCE_IP: - return createSourceIpMatcher(principal.getSourceIp()); + case SOURCE_IP: { + // gRFC A41 has identical handling of source_ip as remote_ip and direct_remote_ip and + // pre-dates the deprecation. + @SuppressWarnings("deprecation") + CidrRange sourceIp = principal.getSourceIp(); + return createSourceIpMatcher(sourceIp); + } case HEADER: return parseHeaderMatcher(principal.getHeader()); case NOT_ID: diff --git a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java index 76bac3118d8..513f4d643ea 100644 --- a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java @@ -25,9 +25,10 @@ import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Joiner; import com.google.common.base.MoreObjects; import com.google.common.collect.HashMultiset; -import com.google.common.collect.ImmutableMap; import com.google.common.collect.Multiset; import com.google.common.primitives.UnsignedInteger; import io.grpc.Attributes; @@ -35,19 +36,22 @@ import io.grpc.EquivalentAddressGroup; import io.grpc.InternalLogId; import io.grpc.LoadBalancer; -import io.grpc.LoadBalancerProvider; +import io.grpc.Metadata; import io.grpc.Status; import io.grpc.SynchronizationContext; import io.grpc.util.MultiChildLoadBalancer; +import io.grpc.xds.ThreadSafeRandom.ThreadSafeRandomImpl; import io.grpc.xds.client.XdsLogger; import io.grpc.xds.client.XdsLogger.XdsLogLevel; import java.net.SocketAddress; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -66,15 +70,25 @@ final class RingHashLoadBalancer extends MultiChildLoadBalancer { + " config selector always generates a hash."); private static final XxHash64 hashFunc = XxHash64.INSTANCE; + private final LoadBalancer.Factory lazyLbFactory = + new LazyLoadBalancer.Factory(pickFirstLbProvider); private final XdsLogger logger; private final SynchronizationContext syncContext; + private final ThreadSafeRandom random; private List ring; + @Nullable private Metadata.Key requestHashHeaderKey; RingHashLoadBalancer(Helper helper) { + this(helper, ThreadSafeRandomImpl.instance); + } + + @VisibleForTesting + RingHashLoadBalancer(Helper helper, ThreadSafeRandom random) { super(helper); syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext"); logger = XdsLogger.withLogId(InternalLogId.allocate("ring_hash_lb", helper.getAuthority())); logger.log(XdsLogLevel.INFO, "Created"); + this.random = checkNotNull(random, "random"); } @Override @@ -86,71 +100,50 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { return addressValidityStatus; } - try { - resolvingAddresses = true; - // Subclass handles any special manipulation to create appropriate types of ChildLbStates - Map newChildren = createChildLbMap(resolvedAddresses); - - if (newChildren.isEmpty()) { - addressValidityStatus = Status.UNAVAILABLE.withDescription( - "Ring hash lb error: EDS resolution was successful, but there were no valid addresses"); - handleNameResolutionError(addressValidityStatus); - return addressValidityStatus; - } - - // We don't care about reuse because we don't want to activate them - addMissingChildrenAndIdReuse(newChildren); - updateChildrenWithResolvedAddresses(resolvedAddresses, newChildren); - - // Now do the ringhash specific logic with weights and building the ring - RingHashConfig config = (RingHashConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); - if (config == null) { - throw new IllegalArgumentException("Missing RingHash configuration"); + // Now do the ringhash specific logic with weights and building the ring + RingHashConfig config = (RingHashConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); + if (config == null) { + throw new IllegalArgumentException("Missing RingHash configuration"); + } + requestHashHeaderKey = + config.requestHashHeader.isEmpty() + ? null + : Metadata.Key.of(config.requestHashHeader, Metadata.ASCII_STRING_MARSHALLER); + Map serverWeights = new HashMap<>(); + long totalWeight = 0L; + for (EquivalentAddressGroup eag : addrList) { + Long weight = eag.getAttributes().get(XdsAttributes.ATTR_SERVER_WEIGHT); + // Support two ways of server weighing: either multiple instances of the same address + // or each address contains a per-address weight attribute. If a weight is not provided, + // each occurrence of the address will be counted a weight value of one. + if (weight == null) { + weight = 1L; } - Map serverWeights = new HashMap<>(); - long totalWeight = 0L; - for (EquivalentAddressGroup eag : addrList) { - Long weight = eag.getAttributes().get(InternalXdsAttributes.ATTR_SERVER_WEIGHT); - // Support two ways of server weighing: either multiple instances of the same address - // or each address contains a per-address weight attribute. If a weight is not provided, - // each occurrence of the address will be counted a weight value of one. - if (weight == null) { - weight = 1L; - } - totalWeight += weight; - EquivalentAddressGroup addrKey = stripAttrs(eag); - if (serverWeights.containsKey(addrKey)) { - serverWeights.put(addrKey, serverWeights.get(addrKey) + weight); - } else { - serverWeights.put(addrKey, weight); - } + totalWeight += weight; + EquivalentAddressGroup addrKey = stripAttrs(eag); + if (serverWeights.containsKey(addrKey)) { + serverWeights.put(addrKey, serverWeights.get(addrKey) + weight); + } else { + serverWeights.put(addrKey, weight); } - // Calculate scale - long minWeight = Collections.min(serverWeights.values()); - double normalizedMinWeight = (double) minWeight / totalWeight; - // Scale up the number of hashes per host such that the least-weighted host gets a whole - // number of hashes on the the ring. Other hosts might not end up with whole numbers, and - // that's fine (the ring-building algorithm can handle this). This preserves the original - // implementation's behavior: when weights aren't provided, all hosts should get an equal - // number of hashes. In the case where this number exceeds the max_ring_size, it's scaled - // back down to fit. - double scale = Math.min( - Math.ceil(normalizedMinWeight * config.minRingSize) / normalizedMinWeight, - (double) config.maxRingSize); - - // Build the ring - ring = buildRing(serverWeights, totalWeight, scale); - - // Must update channel picker before return so that new RPCs will not be routed to deleted - // clusters and resolver can remove them in service config. - updateOverallBalancingState(); - - shutdownRemoved(getRemovedChildren(newChildren.keySet())); - } finally { - this.resolvingAddresses = false; } - - return Status.OK; + // Calculate scale + long minWeight = Collections.min(serverWeights.values()); + double normalizedMinWeight = (double) minWeight / totalWeight; + // Scale up the number of hashes per host such that the least-weighted host gets a whole + // number of hashes on the the ring. Other hosts might not end up with whole numbers, and + // that's fine (the ring-building algorithm can handle this). This preserves the original + // implementation's behavior: when weights aren't provided, all hosts should get an equal + // number of hashes. In the case where this number exceeds the max_ring_size, it's scaled + // back down to fit. + double scale = Math.min( + Math.ceil(normalizedMinWeight * config.minRingSize) / normalizedMinWeight, + (double) config.maxRingSize); + + // Build the ring + ring = buildRing(serverWeights, totalWeight, scale); + + return super.acceptResolvedAddresses(resolvedAddresses); } @@ -221,16 +214,35 @@ protected void updateOverallBalancingState() { overallState = TRANSIENT_FAILURE; } - RingHashPicker picker = new RingHashPicker(syncContext, ring, getImmutableChildMap()); + // gRFC A61: if the aggregated connectivity state is TRANSIENT_FAILURE or CONNECTING and + // there are no endpoints in CONNECTING state, the ring_hash policy will choose one of + // the endpoints in IDLE state (if any) to trigger a connection attempt on + if (numReady == 0 && numTF > 0 && numConnecting == 0 && numIdle > 0) { + triggerIdleChildConnection(); + } + + RingHashPicker picker = + new RingHashPicker(syncContext, ring, getChildLbStates(), requestHashHeaderKey, random); getHelper().updateBalancingState(overallState, picker); this.currentConnectivityState = overallState; } + + /** + * Triggers a connection attempt for the first IDLE child load balancer. + */ + private void triggerIdleChildConnection() { + for (ChildLbState child : getChildLbStates()) { + if (child.getCurrentState() == ConnectivityState.IDLE) { + child.getLb().requestConnection(); + return; + } + } + } + @Override - protected ChildLbState createChildLbState(Object key, Object policyConfig, - SubchannelPicker initialPicker, ResolvedAddresses resolvedAddresses) { - return new RingHashChildLbState((Endpoint)key, - getChildAddresses(key, resolvedAddresses, null)); + protected ChildLbState createChildLbState(Object key) { + return new ChildLbState(key, lazyLbFactory); } private Status validateAddrList(List addrList) { @@ -251,7 +263,7 @@ private Status validateAddrList(List addrList) { long totalWeight = 0; for (EquivalentAddressGroup eag : addrList) { - Long weight = eag.getAttributes().get(InternalXdsAttributes.ATTR_SERVER_WEIGHT); + Long weight = eag.getAttributes().get(XdsAttributes.ATTR_SERVER_WEIGHT); if (weight == null) { weight = 1L; @@ -351,22 +363,32 @@ private static final class RingHashPicker extends SubchannelPicker { // TODO(chengyuanzhang): can be more performance-friendly with // IdentityHashMap and RingEntry contains Subchannel. private final Map pickableSubchannels; // read-only + @Nullable private final Metadata.Key requestHashHeaderKey; + private final ThreadSafeRandom random; + private final boolean hasEndpointInConnectingState; private RingHashPicker( SynchronizationContext syncContext, List ring, - ImmutableMap subchannels) { + Collection children, Metadata.Key requestHashHeaderKey, + ThreadSafeRandom random) { this.syncContext = syncContext; this.ring = ring; - pickableSubchannels = new HashMap<>(subchannels.size()); - for (Map.Entry entry : subchannels.entrySet()) { - RingHashChildLbState childLbState = (RingHashChildLbState) entry.getValue(); - pickableSubchannels.put((Endpoint)entry.getKey(), + this.requestHashHeaderKey = requestHashHeaderKey; + this.random = random; + pickableSubchannels = new HashMap<>(children.size()); + boolean hasConnectingState = false; + for (ChildLbState childLbState : children) { + pickableSubchannels.put((Endpoint)childLbState.getKey(), new SubchannelView(childLbState, childLbState.getCurrentState())); + if (childLbState.getCurrentState() == CONNECTING) { + hasConnectingState = true; + } } + this.hasEndpointInConnectingState = hasConnectingState; } // Find the ring entry with hash next to (clockwise) the RPC's hash (binary search). - private int getTargetIndex(Long requestHash) { + private int getTargetIndex(long requestHash) { if (ring.size() <= 1) { return 0; } @@ -392,47 +414,85 @@ private int getTargetIndex(Long requestHash) { @Override public PickResult pickSubchannel(PickSubchannelArgs args) { - Long requestHash = args.getCallOptions().getOption(XdsNameResolver.RPC_HASH_KEY); - if (requestHash == null) { - return PickResult.withError(RPC_HASH_NOT_FOUND); + // Determine request hash. + boolean usingRandomHash = false; + long requestHash; + if (requestHashHeaderKey == null) { + // Set by the xDS config selector. + Long rpcHashFromCallOptions = args.getCallOptions().getOption(XdsNameResolver.RPC_HASH_KEY); + if (rpcHashFromCallOptions == null) { + return PickResult.withError(RPC_HASH_NOT_FOUND); + } + requestHash = rpcHashFromCallOptions; + } else { + Iterable headerValues = args.getHeaders().getAll(requestHashHeaderKey); + if (headerValues != null) { + requestHash = hashFunc.hashAsciiString(Joiner.on(",").join(headerValues)); + } else { + requestHash = random.nextLong(); + usingRandomHash = true; + } } int targetIndex = getTargetIndex(requestHash); - // Per gRFC A61, because of sticky-TF with PickFirst's auto reconnect on TF, we ignore - // all TF subchannels and find the first ring entry in READY, CONNECTING or IDLE. If - // CONNECTING or IDLE we return a pick with no results. Additionally, if that entry is in - // IDLE, we initiate a connection. - for (int i = 0; i < ring.size(); i++) { - int index = (targetIndex + i) % ring.size(); - SubchannelView subchannelView = pickableSubchannels.get(ring.get(index).addrKey); - RingHashChildLbState childLbState = subchannelView.childLbState; - - if (subchannelView.connectivityState == READY) { - return childLbState.getCurrentPicker().pickSubchannel(args); + if (!usingRandomHash) { + // Per gRFC A61, because of sticky-TF with PickFirst's auto reconnect on TF, we ignore + // all TF subchannels and find the first ring entry in READY, CONNECTING or IDLE. If + // CONNECTING or IDLE we return a pick with no results. Additionally, if that entry is in + // IDLE, we initiate a connection. + for (int i = 0; i < ring.size(); i++) { + int index = (targetIndex + i) % ring.size(); + SubchannelView subchannelView = pickableSubchannels.get(ring.get(index).addrKey); + ChildLbState childLbState = subchannelView.childLbState; + + if (subchannelView.connectivityState == READY) { + return childLbState.getCurrentPicker().pickSubchannel(args); + } + + // RPCs can be buffered if the next subchannel is pending (per A62). Otherwise, RPCs + // are failed unless there is a READY connection. + if (subchannelView.connectivityState == CONNECTING) { + return PickResult.withNoResult(); + } + + if (subchannelView.connectivityState == IDLE) { + syncContext.execute(() -> { + if (childLbState.getCurrentState() == IDLE) { + childLbState.getLb().requestConnection(); + } + }); + + return PickResult.withNoResult(); // Indicates that this should be retried after backoff + } } - - // RPCs can be buffered if the next subchannel is pending (per A62). Otherwise, RPCs - // are failed unless there is a READY connection. - if (subchannelView.connectivityState == CONNECTING) { - return PickResult.withNoResult(); + } else { + // Using a random hash. Find and use the first READY ring entry, triggering at most one + // entry to attempt connection. + boolean requestedConnection = hasEndpointInConnectingState; + for (int i = 0; i < ring.size(); i++) { + int index = (targetIndex + i) % ring.size(); + SubchannelView subchannelView = pickableSubchannels.get(ring.get(index).addrKey); + ChildLbState childLbState = subchannelView.childLbState; + if (subchannelView.connectivityState == READY) { + return childLbState.getCurrentPicker().pickSubchannel(args); + } + if (!requestedConnection && subchannelView.connectivityState == IDLE) { + syncContext.execute(() -> { + if (childLbState.getCurrentState() == IDLE) { + childLbState.getLb().requestConnection(); + } + }); + requestedConnection = true; + } } - - if (subchannelView.connectivityState == IDLE) { - syncContext.execute(() -> { - if (childLbState.isDeactivated()) { - childLbState.activate(); - } else { - childLbState.getLb().requestConnection(); - } - }); - - return PickResult.withNoResult(); // Indicates that this should be retried after backoff + if (requestedConnection) { + return PickResult.withNoResult(); } } // return the pick from the original subchannel hit by hash, which is probably an error - RingHashChildLbState originalSubchannel = + ChildLbState originalSubchannel = pickableSubchannels.get(ring.get(targetIndex).addrKey).childLbState; return originalSubchannel.getCurrentPicker().pickSubchannel(args); } @@ -444,10 +504,10 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { * state changes. */ private static final class SubchannelView { - private final RingHashChildLbState childLbState; + private final ChildLbState childLbState; private final ConnectivityState connectivityState; - private SubchannelView(RingHashChildLbState childLbState, ConnectivityState state) { + private SubchannelView(ChildLbState childLbState, ConnectivityState state) { this.childLbState = childLbState; this.connectivityState = state; } @@ -475,76 +535,41 @@ public int compareTo(RingEntry entry) { static final class RingHashConfig { final long minRingSize; final long maxRingSize; + final String requestHashHeader; - RingHashConfig(long minRingSize, long maxRingSize) { + RingHashConfig(long minRingSize, long maxRingSize, String requestHashHeader) { checkArgument(minRingSize > 0, "minRingSize <= 0"); checkArgument(maxRingSize > 0, "maxRingSize <= 0"); checkArgument(minRingSize <= maxRingSize, "minRingSize > maxRingSize"); + checkNotNull(requestHashHeader); this.minRingSize = minRingSize; this.maxRingSize = maxRingSize; + this.requestHashHeader = requestHashHeader; } @Override - public String toString() { - return MoreObjects.toStringHelper(this) - .add("minRingSize", minRingSize) - .add("maxRingSize", maxRingSize) - .toString(); - } - } - - class RingHashChildLbState extends MultiChildLoadBalancer.ChildLbState { - - public RingHashChildLbState(Endpoint key, ResolvedAddresses resolvedAddresses) { - super(key, pickFirstLbProvider, null, EMPTY_PICKER, resolvedAddresses, true); - } - - @Override - protected ChildLbStateHelper createChildHelper() { - return new RingHashChildHelper(); - } - - @Override - protected void reactivate(LoadBalancerProvider policyProvider) { - if (!isDeactivated()) { - return; + public boolean equals(Object o) { + if (!(o instanceof RingHashConfig)) { + return false; } - currentConnectivityState = CONNECTING; - getLb().switchTo(pickFirstLbProvider); - markReactivated(); - getLb().acceptResolvedAddresses(this.getResolvedAddresses()); - logger.log(XdsLogLevel.DEBUG, "Child balancer {0} reactivated", getKey()); - } - - public void activate() { - reactivate(pickFirstLbProvider); + RingHashConfig that = (RingHashConfig) o; + return this.minRingSize == that.minRingSize + && this.maxRingSize == that.maxRingSize + && Objects.equals(this.requestHashHeader, that.requestHashHeader); } - // Need to expose this to the LB class @Override - protected void shutdown() { - super.shutdown(); + public int hashCode() { + return Objects.hash(minRingSize, maxRingSize, requestHashHeader); } - private class RingHashChildHelper extends ChildLbStateHelper { - @Override - public void updateBalancingState(final ConnectivityState newState, - final SubchannelPicker newPicker) { - // If we are already in the process of resolving addresses, the overall balancing state - // will be updated at the end of it, and we don't need to trigger that update here. - if (getChildLbState(getKey()) == null) { - return; - } - - // Subchannel picker and state are saved, but will only be propagated to the channel - // when the child instance exits deactivated state. - setCurrentState(newState); - setCurrentPicker(newPicker); - if (!isDeactivated() && !resolvingAddresses) { - updateOverallBalancingState(); - } - } + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("minRingSize", minRingSize) + .add("maxRingSize", maxRingSize) + .add("requestHashHeader", requestHashHeader) + .toString(); } } - -} \ No newline at end of file +} diff --git a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java index dad79384569..bb4f8de5a5f 100644 --- a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java @@ -24,6 +24,7 @@ import io.grpc.LoadBalancerProvider; import io.grpc.NameResolver.ConfigOrError; import io.grpc.Status; +import io.grpc.internal.GrpcUtil; import io.grpc.internal.JsonUtil; import io.grpc.xds.RingHashLoadBalancer.RingHashConfig; import io.grpc.xds.RingHashOptions; @@ -81,6 +82,10 @@ private ConfigOrError parseLoadBalancingPolicyConfigInternal( Map rawLoadBalancingPolicyConfig) { Long minRingSize = JsonUtil.getNumberAsLong(rawLoadBalancingPolicyConfig, "minRingSize"); Long maxRingSize = JsonUtil.getNumberAsLong(rawLoadBalancingPolicyConfig, "maxRingSize"); + String requestHashHeader = ""; + if (GrpcUtil.getFlag("GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY", false)) { + requestHashHeader = JsonUtil.getString(rawLoadBalancingPolicyConfig, "requestHashHeader"); + } long maxRingSizeCap = RingHashOptions.getRingSizeCap(); if (minRingSize == null) { minRingSize = DEFAULT_MIN_RING_SIZE; @@ -88,6 +93,9 @@ private ConfigOrError parseLoadBalancingPolicyConfigInternal( if (maxRingSize == null) { maxRingSize = DEFAULT_MAX_RING_SIZE; } + if (requestHashHeader == null) { + requestHashHeader = ""; + } if (minRingSize > maxRingSizeCap) { minRingSize = maxRingSizeCap; } @@ -96,8 +104,9 @@ private ConfigOrError parseLoadBalancingPolicyConfigInternal( } if (minRingSize <= 0 || maxRingSize <= 0 || minRingSize > maxRingSize) { return ConfigOrError.fromError(Status.UNAVAILABLE.withDescription( - "Invalid 'mingRingSize'/'maxRingSize'")); + "Invalid 'minRingSize'/'maxRingSize'")); } - return ConfigOrError.fromConfig(new RingHashConfig(minRingSize, maxRingSize)); + return ConfigOrError.fromConfig( + new RingHashConfig(minRingSize, maxRingSize, requestHashHeader)); } } diff --git a/xds/src/main/java/io/grpc/xds/RouterFilter.java b/xds/src/main/java/io/grpc/xds/RouterFilter.java index 7f1adf86a6d..090d1cfabad 100644 --- a/xds/src/main/java/io/grpc/xds/RouterFilter.java +++ b/xds/src/main/java/io/grpc/xds/RouterFilter.java @@ -17,19 +17,12 @@ package io.grpc.xds; import com.google.protobuf.Message; -import io.grpc.ClientInterceptor; -import io.grpc.LoadBalancer.PickSubchannelArgs; -import io.grpc.ServerInterceptor; -import io.grpc.xds.Filter.ClientInterceptorBuilder; -import io.grpc.xds.Filter.ServerInterceptorBuilder; -import java.util.concurrent.ScheduledExecutorService; -import javax.annotation.Nullable; /** * Router filter implementation. Currently this filter does not parse any field in the config. */ -enum RouterFilter implements Filter, ClientInterceptorBuilder, ServerInterceptorBuilder { - INSTANCE; +final class RouterFilter implements Filter { + private static final RouterFilter INSTANCE = new RouterFilter(); static final String TYPE_URL = "type.googleapis.com/envoy.extensions.filters.http.router.v3.Router"; @@ -37,7 +30,7 @@ enum RouterFilter implements Filter, ClientInterceptorBuilder, ServerInterceptor static final FilterConfig ROUTER_CONFIG = new FilterConfig() { @Override public String typeUrl() { - return RouterFilter.TYPE_URL; + return TYPE_URL; } @Override @@ -46,33 +39,39 @@ public String toString() { } }; - @Override - public String[] typeUrls() { - return new String[] { TYPE_URL }; - } + static final class Provider implements Filter.Provider { + @Override + public String[] typeUrls() { + return new String[]{TYPE_URL}; + } - @Override - public ConfigOrError parseFilterConfig(Message rawProtoMessage) { - return ConfigOrError.fromConfig(ROUTER_CONFIG); - } + @Override + public boolean isClientFilter() { + return true; + } - @Override - public ConfigOrError parseFilterConfigOverride(Message rawProtoMessage) { - return ConfigOrError.fromError("Router Filter should not have override config"); - } + @Override + public boolean isServerFilter() { + return true; + } - @Nullable - @Override - public ClientInterceptor buildClientInterceptor( - FilterConfig config, @Nullable FilterConfig overrideConfig, PickSubchannelArgs args, - ScheduledExecutorService scheduler) { - return null; - } + @Override + public RouterFilter newInstance(String name) { + return INSTANCE; + } - @Nullable - @Override - public ServerInterceptor buildServerInterceptor( - FilterConfig config, @Nullable Filter.FilterConfig overrideConfig) { - return null; + @Override + public ConfigOrError parseFilterConfig( + Message rawProtoMessage, FilterConfigParseContext context) { + return ConfigOrError.fromConfig(ROUTER_CONFIG); + } + + @Override + public ConfigOrError parseFilterConfigOverride( + Message rawProtoMessage, FilterConfigParseContext context) { + return ConfigOrError.fromError("Router Filter should not have override config"); + } } + + private RouterFilter() {} } diff --git a/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java b/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java index a4f20f1b65d..45c379244af 100644 --- a/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java +++ b/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java @@ -17,9 +17,12 @@ package io.grpc.xds; import static com.google.common.base.Preconditions.checkNotNull; -import static io.grpc.xds.GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import io.grpc.CallCredentials; +import io.grpc.MetricRecorder; import io.grpc.internal.ExponentialBackoffPolicy; import io.grpc.internal.GrpcUtil; import io.grpc.internal.ObjectPool; @@ -32,10 +35,11 @@ import io.grpc.xds.client.XdsInitializationException; import io.grpc.xds.internal.security.TlsContextManagerImpl; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.atomic.AtomicReference; +import java.util.logging.Level; +import java.util.logging.Logger; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; /** @@ -44,19 +48,24 @@ */ @ThreadSafe final class SharedXdsClientPoolProvider implements XdsClientPoolFactory { + private static final boolean LOG_XDS_NODE_ID = Boolean.parseBoolean( + System.getenv("GRPC_LOG_XDS_NODE_ID")); + private static final Logger log = Logger.getLogger(XdsClientImpl.class.getName()); + private static final ExponentialBackoffPolicy.Provider BACKOFF_POLICY_PROVIDER = + new ExponentialBackoffPolicy.Provider(); + @Nullable private final Bootstrapper bootstrapper; private final Object lock = new Object(); - private final AtomicReference> bootstrapOverride = new AtomicReference<>(); - private volatile ObjectPool xdsClientPool; + private final Map> targetToXdsClientMap = new ConcurrentHashMap<>(); SharedXdsClientPoolProvider() { - this(new GrpcBootstrapperImpl()); + this(null); } @VisibleForTesting - SharedXdsClientPoolProvider(Bootstrapper bootstrapper) { - this.bootstrapper = checkNotNull(bootstrapper, "bootstrapper"); + SharedXdsClientPoolProvider(@Nullable Bootstrapper bootstrapper) { + this.bootstrapper = bootstrapper; } static SharedXdsClientPoolProvider getDefaultProvider() { @@ -64,48 +73,67 @@ static SharedXdsClientPoolProvider getDefaultProvider() { } @Override - public void setBootstrapOverride(Map bootstrap) { - bootstrapOverride.set(bootstrap); + @Nullable + public ObjectPool get(String target) { + return targetToXdsClientMap.get(target); } - @Override - @Nullable - public ObjectPool get() { - return xdsClientPool; + @Deprecated + public ObjectPool getOrCreate( + String target, MetricRecorder metricRecorder, CallCredentials transportCallCredentials) + throws XdsInitializationException { + BootstrapInfo bootstrapInfo; + if (bootstrapper != null) { + bootstrapInfo = bootstrapper.bootstrap(); + } else { + bootstrapInfo = GrpcBootstrapperImpl.defaultBootstrap(); + } + return getOrCreate(target, bootstrapInfo, metricRecorder, transportCallCredentials); } @Override - public ObjectPool getOrCreate() throws XdsInitializationException { - ObjectPool ref = xdsClientPool; + public ObjectPool getOrCreate( + String target, BootstrapInfo bootstrapInfo, MetricRecorder metricRecorder) { + return getOrCreate(target, bootstrapInfo, metricRecorder, null); + } + + public ObjectPool getOrCreate( + String target, + BootstrapInfo bootstrapInfo, + MetricRecorder metricRecorder, + CallCredentials transportCallCredentials) { + ObjectPool ref = targetToXdsClientMap.get(target); if (ref == null) { synchronized (lock) { - ref = xdsClientPool; + ref = targetToXdsClientMap.get(target); if (ref == null) { - BootstrapInfo bootstrapInfo; - Map rawBootstrap = bootstrapOverride.get(); - if (rawBootstrap != null) { - bootstrapInfo = bootstrapper.bootstrap(rawBootstrap); - } else { - bootstrapInfo = bootstrapper.bootstrap(); - } - if (bootstrapInfo.servers().isEmpty()) { - throw new XdsInitializationException("No xDS server provided"); - } - ref = xdsClientPool = new RefCountedXdsClientObjectPool(bootstrapInfo); + ref = + new RefCountedXdsClientObjectPool( + bootstrapInfo, target, metricRecorder, transportCallCredentials); + targetToXdsClientMap.put(target, ref); } } } return ref; } + @Override + public ImmutableList getTargets() { + return ImmutableList.copyOf(targetToXdsClientMap.keySet()); + } + private static class SharedXdsClientPoolProviderHolder { private static final SharedXdsClientPoolProvider instance = new SharedXdsClientPoolProvider(); } @ThreadSafe @VisibleForTesting - static class RefCountedXdsClientObjectPool implements ObjectPool { + class RefCountedXdsClientObjectPool implements ObjectPool { + private final BootstrapInfo bootstrapInfo; + private final String target; // The target associated with the xDS client. + private final MetricRecorder metricRecorder; + private final CallCredentials transportCallCredentials; private final Object lock = new Object(); @GuardedBy("lock") private ScheduledExecutorService scheduler; @@ -113,26 +141,50 @@ static class RefCountedXdsClientObjectPool implements ObjectPool { private XdsClient xdsClient; @GuardedBy("lock") private int refCount; + @GuardedBy("lock") + private XdsClientMetricReporterImpl metricReporter; + + @VisibleForTesting + RefCountedXdsClientObjectPool( + BootstrapInfo bootstrapInfo, String target, MetricRecorder metricRecorder) { + this(bootstrapInfo, target, metricRecorder, null); + } @VisibleForTesting - RefCountedXdsClientObjectPool(BootstrapInfo bootstrapInfo) { - this.bootstrapInfo = checkNotNull(bootstrapInfo); + RefCountedXdsClientObjectPool( + BootstrapInfo bootstrapInfo, + String target, + MetricRecorder metricRecorder, + CallCredentials transportCallCredentials) { + this.bootstrapInfo = checkNotNull(bootstrapInfo, "bootstrapInfo"); + this.target = target; + this.metricRecorder = checkNotNull(metricRecorder, "metricRecorder"); + this.transportCallCredentials = transportCallCredentials; } @Override public XdsClient getObject() { synchronized (lock) { if (refCount == 0) { + if (LOG_XDS_NODE_ID) { + log.log(Level.INFO, "xDS node ID: {0}", bootstrapInfo.node().getId()); + } scheduler = SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE); - xdsClient = new XdsClientImpl( - DEFAULT_XDS_TRANSPORT_FACTORY, - bootstrapInfo, - scheduler, - new ExponentialBackoffPolicy.Provider(), - GrpcUtil.STOPWATCH_SUPPLIER, - TimeProvider.SYSTEM_TIME_PROVIDER, - MessagePrinter.INSTANCE, - new TlsContextManagerImpl(bootstrapInfo)); + metricReporter = new XdsClientMetricReporterImpl(metricRecorder, target); + GrpcXdsTransportFactory xdsTransportFactory = + new GrpcXdsTransportFactory(transportCallCredentials); + xdsClient = + new XdsClientImpl( + xdsTransportFactory, + bootstrapInfo, + scheduler, + BACKOFF_POLICY_PROVIDER, + GrpcUtil.STOPWATCH_SUPPLIER, + TimeProvider.SYSTEM_TIME_PROVIDER, + MessagePrinter.INSTANCE, + new TlsContextManagerImpl(bootstrapInfo), + metricReporter); + metricReporter.setXdsClient(xdsClient); } refCount++; return xdsClient; @@ -146,7 +198,14 @@ public XdsClient returnObject(Object object) { if (refCount == 0) { xdsClient.shutdown(); xdsClient = null; + metricReporter.close(); + metricReporter = null; + targetToXdsClientMap.remove(target); scheduler = SharedResourceHolder.release(GrpcUtil.TIMER_SERVICE, scheduler); + } else if (refCount < 0) { + assert false; // We want our tests to fail + log.log(Level.SEVERE, "Negative reference count. File a bug", new Exception()); + refCount = 0; } return null; } @@ -159,5 +218,10 @@ XdsClient getXdsClientForTest() { return xdsClient; } } + + public String getTarget() { + return target; + } } + } diff --git a/xds/src/main/java/io/grpc/xds/StructOrError.java b/xds/src/main/java/io/grpc/xds/StructOrError.java new file mode 100644 index 00000000000..14f008d191e --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/StructOrError.java @@ -0,0 +1,72 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.annotations.VisibleForTesting; +import javax.annotation.Nullable; + +/** An object or a String error. */ +final class StructOrError { + + /** + * Returns a {@link StructOrError} for the successfully converted data object. + */ + public static StructOrError fromStruct(T struct) { + return new StructOrError<>(struct); + } + + /** + * Returns a {@link StructOrError} for the failure to convert the data object. + */ + public static StructOrError fromError(String errorDetail) { + return new StructOrError<>(errorDetail); + } + + private final String errorDetail; + private final T struct; + + private StructOrError(T struct) { + this.struct = checkNotNull(struct, "struct"); + this.errorDetail = null; + } + + private StructOrError(String errorDetail) { + this.struct = null; + this.errorDetail = checkNotNull(errorDetail, "errorDetail"); + } + + /** + * Returns struct if exists, otherwise null. + */ + @VisibleForTesting + @Nullable + public T getStruct() { + return struct; + } + + /** + * Returns error detail if exists, otherwise null. + */ + @VisibleForTesting + @Nullable + public String getErrorDetail() { + return errorDetail; + } +} + diff --git a/xds/src/main/java/io/grpc/xds/VirtualHost.java b/xds/src/main/java/io/grpc/xds/VirtualHost.java index d9f93dd3a07..5cc979984c6 100644 --- a/xds/src/main/java/io/grpc/xds/VirtualHost.java +++ b/xds/src/main/java/io/grpc/xds/VirtualHost.java @@ -166,29 +166,34 @@ abstract static class RouteAction { @Nullable abstract RetryPolicy retryPolicy(); + abstract boolean autoHostRewrite(); + static RouteAction forCluster( String cluster, List hashPolicies, @Nullable Long timeoutNano, - @Nullable RetryPolicy retryPolicy) { + @Nullable RetryPolicy retryPolicy, boolean autoHostRewrite) { checkNotNull(cluster, "cluster"); - return RouteAction.create(hashPolicies, timeoutNano, cluster, null, null, retryPolicy); + return RouteAction.create(hashPolicies, timeoutNano, cluster, null, null, retryPolicy, + autoHostRewrite); } static RouteAction forWeightedClusters( List weightedClusters, List hashPolicies, - @Nullable Long timeoutNano, @Nullable RetryPolicy retryPolicy) { + @Nullable Long timeoutNano, @Nullable RetryPolicy retryPolicy, boolean autoHostRewrite) { checkNotNull(weightedClusters, "weightedClusters"); checkArgument(!weightedClusters.isEmpty(), "empty cluster list"); return RouteAction.create( - hashPolicies, timeoutNano, null, weightedClusters, null, retryPolicy); + hashPolicies, timeoutNano, null, weightedClusters, null, retryPolicy, autoHostRewrite); } static RouteAction forClusterSpecifierPlugin( NamedPluginConfig namedConfig, List hashPolicies, @Nullable Long timeoutNano, - @Nullable RetryPolicy retryPolicy) { + @Nullable RetryPolicy retryPolicy, + boolean autoHostRewrite) { checkNotNull(namedConfig, "namedConfig"); - return RouteAction.create(hashPolicies, timeoutNano, null, null, namedConfig, retryPolicy); + return RouteAction.create(hashPolicies, timeoutNano, null, null, namedConfig, retryPolicy, + autoHostRewrite); } private static RouteAction create( @@ -197,26 +202,29 @@ private static RouteAction create( @Nullable String cluster, @Nullable List weightedClusters, @Nullable NamedPluginConfig namedConfig, - @Nullable RetryPolicy retryPolicy) { + @Nullable RetryPolicy retryPolicy, + boolean autoHostRewrite) { return new AutoValue_VirtualHost_Route_RouteAction( ImmutableList.copyOf(hashPolicies), timeoutNano, cluster, weightedClusters == null ? null : ImmutableList.copyOf(weightedClusters), namedConfig, - retryPolicy); + retryPolicy, + autoHostRewrite); } @AutoValue abstract static class ClusterWeight { abstract String name(); - abstract int weight(); + abstract long weight(); abstract ImmutableMap filterConfigOverrides(); static ClusterWeight create( - String name, int weight, Map filterConfigOverrides) { + String name, long weight, Map filterConfigOverrides) { + checkArgument(weight >= 0, "weight must not be negative"); return new AutoValue_VirtualHost_Route_RouteAction_ClusterWeight( name, weight, ImmutableMap.copyOf(filterConfigOverrides)); } diff --git a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java index c3f60623a95..6744903de35 100644 --- a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java @@ -17,37 +17,41 @@ package io.grpc.xds; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkElementIndex; import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.Deadline.Ticker; +import io.grpc.DoubleHistogramMetricInstrument; import io.grpc.EquivalentAddressGroup; -import io.grpc.ExperimentalApi; import io.grpc.LoadBalancer; import io.grpc.LoadBalancerProvider; +import io.grpc.LongCounterMetricInstrument; +import io.grpc.MetricInstrumentRegistry; import io.grpc.NameResolver; import io.grpc.Status; import io.grpc.SynchronizationContext; import io.grpc.SynchronizationContext.ScheduledHandle; import io.grpc.services.MetricReport; -import io.grpc.util.ForwardingLoadBalancerHelper; import io.grpc.util.ForwardingSubchannel; -import io.grpc.util.RoundRobinLoadBalancer; +import io.grpc.util.MultiChildLoadBalancer; +import io.grpc.xds.internal.MetricReportUtils; +import io.grpc.xds.internal.MetricReportUtils.ParsedMetricName; import io.grpc.xds.orca.OrcaOobUtil; import io.grpc.xds.orca.OrcaOobUtil.OrcaOobReportListener; import io.grpc.xds.orca.OrcaPerRequestUtil; import io.grpc.xds.orca.OrcaPerRequestUtil.OrcaPerRequestReportListener; +import java.util.ArrayList; import java.util.Collection; -import java.util.HashMap; import java.util.HashSet; import java.util.List; -import java.util.Map; +import java.util.Objects; +import java.util.OptionalDouble; import java.util.Random; import java.util.Set; import java.util.concurrent.ScheduledExecutorService; @@ -57,12 +61,40 @@ import java.util.logging.Logger; /** - * A {@link LoadBalancer} that provides weighted-round-robin load-balancing over - * the {@link EquivalentAddressGroup}s from the {@link NameResolver}. The subchannel weights are + * A {@link LoadBalancer} that provides weighted-round-robin load-balancing over the + * {@link EquivalentAddressGroup}s from the {@link NameResolver}. The subchannel weights are * determined by backend metrics using ORCA. + * To use WRR, users may configure through channel serviceConfig. Example config: + *

 {@code
+ *       String wrrConfig = "{\"loadBalancingConfig\":" +
+ *           "[{\"weighted_round_robin\":{\"enableOobLoadReport\":true, " +
+ *           "\"blackoutPeriod\":\"10s\"," +
+ *           "\"oobReportingPeriod\":\"10s\"," +
+ *           "\"weightExpirationPeriod\":\"180s\"," +
+ *           "\"errorUtilizationPenalty\":\"1.0\"," +
+ *           "\"weightUpdatePeriod\":\"1s\"}}]}";
+ *        serviceConfig = (Map) JsonParser.parse(wrrConfig);
+ *        channel = ManagedChannelBuilder.forTarget("test:///lb.test.grpc.io")
+ *            .defaultServiceConfig(serviceConfig)
+ *            .build();
+ *  }
+ *  
+ * Users may also configure through xDS control plane via custom lb policy. But that is much more + * complex to set up. Example config: + *
+ *  localityLbPolicies:
+ *   - customPolicy:
+ *       name: weighted_round_robin
+ *       data: '{ "enableOobLoadReport": true }'
+ *  
+ * See related documentation: https://cloud.google.com/service-mesh/legacy/load-balancing-apis/proxyless-configure-advanced-traffic-management#custom-lb-config */ -@ExperimentalApi("https://github.com/grpc/grpc-java/issues/9885") -final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer { +final class WeightedRoundRobinLoadBalancer extends MultiChildLoadBalancer { + + private static final LongCounterMetricInstrument RR_FALLBACK_COUNTER; + private static final LongCounterMetricInstrument ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER; + private static final LongCounterMetricInstrument ENDPOINT_WEIGHT_STALE_COUNTER; + private static final DoubleHistogramMetricInstrument ENDPOINT_WEIGHTS_HISTOGRAM; private static final Logger log = Logger.getLogger( WeightedRoundRobinLoadBalancer.class.getName()); private WeightedRoundRobinLoadBalancerConfig config; @@ -73,14 +105,55 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer { private final AtomicInteger sequence; private final long infTime; private final Ticker ticker; + private String locality = ""; + private String backendService = ""; + private SubchannelPicker currentPicker = new FixedResultPicker(PickResult.withNoResult()); + + // The metric instruments are only registered once and shared by all instances of this LB. + static { + MetricInstrumentRegistry metricInstrumentRegistry + = MetricInstrumentRegistry.getDefaultRegistry(); + RR_FALLBACK_COUNTER = metricInstrumentRegistry.registerLongCounter( + "grpc.lb.wrr.rr_fallback", + "EXPERIMENTAL. Number of scheduler updates in which there were not enough endpoints " + + "with valid weight, which caused the WRR policy to fall back to RR behavior", + "{update}", + Lists.newArrayList("grpc.target"), + Lists.newArrayList("grpc.lb.locality", "grpc.lb.backend_service"), + false); + ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER = metricInstrumentRegistry.registerLongCounter( + "grpc.lb.wrr.endpoint_weight_not_yet_usable", + "EXPERIMENTAL. Number of endpoints from each scheduler update that don't yet have usable " + + "weight information", + "{endpoint}", + Lists.newArrayList("grpc.target"), + Lists.newArrayList("grpc.lb.locality", "grpc.lb.backend_service"), + false); + ENDPOINT_WEIGHT_STALE_COUNTER = metricInstrumentRegistry.registerLongCounter( + "grpc.lb.wrr.endpoint_weight_stale", + "EXPERIMENTAL. Number of endpoints from each scheduler update whose latest weight is " + + "older than the expiration period", + "{endpoint}", + Lists.newArrayList("grpc.target"), + Lists.newArrayList("grpc.lb.locality", "grpc.lb.backend_service"), + false); + ENDPOINT_WEIGHTS_HISTOGRAM = metricInstrumentRegistry.registerDoubleHistogram( + "grpc.lb.wrr.endpoint_weights", + "EXPERIMENTAL. The histogram buckets will be endpoint weight ranges.", + "{weight}", + Lists.newArrayList(), + Lists.newArrayList("grpc.target"), + Lists.newArrayList("grpc.lb.locality", "grpc.lb.backend_service"), + false); + } public WeightedRoundRobinLoadBalancer(Helper helper, Ticker ticker) { - this(new WrrHelper(OrcaOobUtil.newOrcaReportingHelper(helper)), ticker, new Random()); + this(helper, ticker, new Random()); } - public WeightedRoundRobinLoadBalancer(WrrHelper helper, Ticker ticker, Random random) { - super(helper); - helper.setLoadBalancer(this); + @VisibleForTesting + WeightedRoundRobinLoadBalancer(Helper helper, Ticker ticker, Random random) { + super(OrcaOobUtil.newOrcaReportingHelper(helper)); this.ticker = checkNotNull(ticker, "ticker"); this.infTime = ticker.nanoTime() + Long.MAX_VALUE; this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext"); @@ -90,17 +163,9 @@ public WeightedRoundRobinLoadBalancer(WrrHelper helper, Ticker ticker, Random ra log.log(Level.FINE, "weighted_round_robin LB created"); } - @VisibleForTesting - WeightedRoundRobinLoadBalancer(Helper helper, Ticker ticker, Random random) { - this(new WrrHelper(OrcaOobUtil.newOrcaReportingHelper(helper)), ticker, random); - } - @Override - protected ChildLbState createChildLbState(Object key, Object policyConfig, - SubchannelPicker initialPicker, ResolvedAddresses unused) { - ChildLbState childLbState = new WeightedChildLbState(key, pickFirstLbProvider, policyConfig, - initialPicker); - return childLbState; + protected ChildLbState createChildLbState(Object key) { + return new WeightedChildLbState(key, pickFirstLbProvider); } @Override @@ -113,39 +178,114 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { handleNameResolutionError(unavailableStatus); return unavailableStatus; } + String locality = resolvedAddresses.getAttributes().get(WeightedTargetLoadBalancer.CHILD_NAME); + if (locality != null) { + this.locality = locality; + } else { + this.locality = ""; + } + String backendService + = resolvedAddresses.getAttributes().get(NameResolver.ATTR_BACKEND_SERVICE); + if (backendService != null) { + this.backendService = backendService; + } else { + this.backendService = ""; + } config = - (WeightedRoundRobinLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); - AcceptResolvedAddrRetVal acceptRetVal; - try { - resolvingAddresses = true; - acceptRetVal = acceptResolvedAddressesInternal(resolvedAddresses); - if (!acceptRetVal.status.isOk()) { - return acceptRetVal.status; - } + (WeightedRoundRobinLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); - if (weightUpdateTimer != null && weightUpdateTimer.isPending()) { - weightUpdateTimer.cancel(); - } - updateWeightTask.run(); + if (weightUpdateTimer != null && weightUpdateTimer.isPending()) { + weightUpdateTimer.cancel(); + } + updateWeightTask.run(); - createAndApplyOrcaListeners(); + Status status = super.acceptResolvedAddresses(resolvedAddresses); - // Must update channel picker before return so that new RPCs will not be routed to deleted - // clusters and resolver can remove them in service config. - updateOverallBalancingState(); + createAndApplyOrcaListeners(); - shutdownRemoved(acceptRetVal.removedChildren); - } finally { - resolvingAddresses = false; + return status; + } + + /** + * Updates picker with the list of active subchannels (state == READY). + */ + @Override + protected void updateOverallBalancingState() { + List activeList = getReadyChildren(); + if (activeList.isEmpty()) { + // No READY subchannels + + // MultiChildLB will request connection immediately on subchannel IDLE. + boolean isConnecting = false; + for (ChildLbState childLbState : getChildLbStates()) { + ConnectivityState state = childLbState.getCurrentState(); + if (state == ConnectivityState.CONNECTING || state == ConnectivityState.IDLE) { + isConnecting = true; + break; + } + } + + if (isConnecting) { + updateBalancingState( + ConnectivityState.CONNECTING, new FixedResultPicker(PickResult.withNoResult())); + } else { + updateBalancingState( + ConnectivityState.TRANSIENT_FAILURE, createReadyPicker(getChildLbStates())); + } + } else { + updateBalancingState(ConnectivityState.READY, createReadyPicker(activeList)); } + } - return acceptRetVal.status; + private SubchannelPicker createReadyPicker(Collection activeList) { + WeightedRoundRobinPicker picker = new WeightedRoundRobinPicker(ImmutableList.copyOf(activeList), + config.enableOobLoadReport, config.errorUtilizationPenalty, sequence, + config.parsedMetricNamesForComputingUtilization); + updateWeight(picker); + return picker; } - @Override - public SubchannelPicker createReadyPicker(Collection activeList) { - return new WeightedRoundRobinPicker(ImmutableList.copyOf(activeList), - config.enableOobLoadReport, config.errorUtilizationPenalty, sequence); + private void updateWeight(WeightedRoundRobinPicker picker) { + Helper helper = getHelper(); + float[] newWeights = new float[picker.children.size()]; + AtomicInteger staleEndpoints = new AtomicInteger(); + AtomicInteger notYetUsableEndpoints = new AtomicInteger(); + for (int i = 0; i < picker.children.size(); i++) { + double newWeight = ((WeightedChildLbState) picker.children.get(i)).getWeight(staleEndpoints, + notYetUsableEndpoints); + helper.getMetricRecorder() + .recordDoubleHistogram(ENDPOINT_WEIGHTS_HISTOGRAM, newWeight, + ImmutableList.of(helper.getChannelTarget()), + ImmutableList.of(locality, backendService)); + newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f; + } + + if (staleEndpoints.get() > 0) { + helper.getMetricRecorder() + .addLongCounter(ENDPOINT_WEIGHT_STALE_COUNTER, staleEndpoints.get(), + ImmutableList.of(helper.getChannelTarget()), + ImmutableList.of(locality, backendService)); + } + if (notYetUsableEndpoints.get() > 0) { + helper.getMetricRecorder() + .addLongCounter(ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER, notYetUsableEndpoints.get(), + ImmutableList.of(helper.getChannelTarget()), + ImmutableList.of(locality, backendService)); + } + boolean weightsEffective = picker.updateWeight(newWeights); + if (!weightsEffective) { + helper.getMetricRecorder() + .addLongCounter(RR_FALLBACK_COUNTER, 1, ImmutableList.of(helper.getChannelTarget()), + ImmutableList.of(locality, backendService)); + } + } + + private void updateBalancingState(ConnectivityState state, SubchannelPicker picker) { + if (state != currentConnectivityState || !picker.equals(currentPicker)) { + getHelper().updateBalancingState(state, picker); + currentConnectivityState = state; + currentPicker = picker; + } } @VisibleForTesting @@ -158,21 +298,27 @@ final class WeightedChildLbState extends ChildLbState { private OrcaReportListener orcaReportListener; - public WeightedChildLbState(Object key, LoadBalancerProvider policyProvider, Object childConfig, - SubchannelPicker initialPicker) { - super(key, policyProvider, childConfig, initialPicker); + public WeightedChildLbState(Object key, LoadBalancerProvider policyProvider) { + super(key, policyProvider); } - private double getWeight() { + @Override + protected ChildLbStateHelper createChildHelper() { + return new WrrChildLbStateHelper(); + } + + private double getWeight(AtomicInteger staleEndpoints, AtomicInteger notYetUsableEndpoints) { if (config == null) { return 0; } long now = ticker.nanoTime(); if (now - lastUpdated >= config.weightExpirationPeriodNanos) { nonEmptySince = infTime; + staleEndpoints.incrementAndGet(); return 0; } else if (now - nonEmptySince < config.blackoutPeriodNanos && config.blackoutPeriodNanos > 0) { + notYetUsableEndpoints.incrementAndGet(); return 0; } else { return weight; @@ -183,12 +329,16 @@ public void addSubchannel(WrrSubchannel wrrSubchannel) { subchannels.add(wrrSubchannel); } - public OrcaReportListener getOrCreateOrcaListener(float errorUtilizationPenalty) { + public OrcaReportListener getOrCreateOrcaListener(float errorUtilizationPenalty, + ImmutableList parsedMetricNamesForComputingUtilization) { if (orcaReportListener != null - && orcaReportListener.errorUtilizationPenalty == errorUtilizationPenalty) { + && orcaReportListener.errorUtilizationPenalty == errorUtilizationPenalty + && orcaReportListener.parsedMetricNamesForComputingUtilization + .equals(parsedMetricNamesForComputingUtilization)) { return orcaReportListener; } - orcaReportListener = new OrcaReportListener(errorUtilizationPenalty); + orcaReportListener = + new OrcaReportListener(errorUtilizationPenalty, parsedMetricNamesForComputingUtilization); return orcaReportListener; } @@ -196,20 +346,36 @@ public void removeSubchannel(WrrSubchannel wrrSubchannel) { subchannels.remove(wrrSubchannel); } + final class WrrChildLbStateHelper extends ChildLbStateHelper { + @Override + public Subchannel createSubchannel(CreateSubchannelArgs args) { + return new WrrSubchannel(super.createSubchannel(args), WeightedChildLbState.this); + } + + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + super.updateBalancingState(newState, newPicker); + if (!resolvingAddresses && newState == ConnectivityState.IDLE) { + getLb().requestConnection(); + } + } + } + final class OrcaReportListener implements OrcaPerRequestReportListener, OrcaOobReportListener { private final float errorUtilizationPenalty; + private final ImmutableList parsedMetricNamesForComputingUtilization; - OrcaReportListener(float errorUtilizationPenalty) { + OrcaReportListener(float errorUtilizationPenalty, + ImmutableList parsedMetricNamesForComputingUtilization) { this.errorUtilizationPenalty = errorUtilizationPenalty; + this.parsedMetricNamesForComputingUtilization = parsedMetricNamesForComputingUtilization; } @Override public void onLoadReport(MetricReport report) { + double utilization = getUtilization(report); + double newWeight = 0; - // Prefer application utilization and fallback to CPU utilization if unset. - double utilization = - report.getApplicationUtilization() > 0 ? report.getApplicationUtilization() - : report.getCpuUtilization(); if (utilization > 0 && report.getQps() > 0) { double penalty = 0; if (report.getEps() > 0 && errorUtilizationPenalty > 0) { @@ -226,6 +392,44 @@ public void onLoadReport(MetricReport report) { lastUpdated = ticker.nanoTime(); weight = newWeight; } + + /** + * Returns the utilization value computed from the specified metric names. If the custom + * metrics are present and valid, the maximum of the custom metrics is returned. Otherwise, + * if application utilization is > 0, it is returned. If neither are present, the CPU + * utilization is returned. + */ + private double getUtilization(MetricReport report) { + OptionalDouble customUtil = getCustomMetricUtilization(report); + if (customUtil.isPresent()) { + return customUtil.getAsDouble(); + } + double appUtil = report.getApplicationUtilization(); + if (appUtil > 0) { + return appUtil; + } + return report.getCpuUtilization(); + } + + /** + * Returns the maximum utilization value among the parsed metric names. + * Returns OptionalDouble.empty() if NONE of the specified metrics are present in the report, + * or if all present metrics are NaN or non positive. + */ + private OptionalDouble getCustomMetricUtilization(MetricReport report) { + OptionalDouble max = OptionalDouble.empty(); + for (int i = 0; i < parsedMetricNamesForComputingUtilization.size(); i++) { + OptionalDouble opt = MetricReportUtils.getMetricValue(report, + parsedMetricNamesForComputingUtilization.get(i)); + if (opt.isPresent()) { + double d = opt.getAsDouble(); + if (!Double.isNaN(d) && d > 0 && (!max.isPresent() || d > max.getAsDouble())) { + max = opt; + } + } + } + return max; + } } } @@ -233,7 +437,7 @@ private final class UpdateWeightTask implements Runnable { @Override public void run() { if (currentPicker != null && currentPicker instanceof WeightedRoundRobinPicker) { - ((WeightedRoundRobinPicker) currentPicker).updateWeight(); + updateWeight((WeightedRoundRobinPicker) currentPicker); } weightUpdateTimer = syncContext.schedule(this, config.weightUpdatePeriodNanos, TimeUnit.NANOSECONDS, timeService); @@ -246,10 +450,10 @@ private void createAndApplyOrcaListeners() { for (WrrSubchannel weightedSubchannel : wChild.subchannels) { if (config.enableOobLoadReport) { OrcaOobUtil.setListener(weightedSubchannel, - wChild.getOrCreateOrcaListener(config.errorUtilizationPenalty), + wChild.getOrCreateOrcaListener(config.errorUtilizationPenalty, + config.parsedMetricNamesForComputingUtilization), OrcaOobUtil.OrcaReportingConfig.newBuilder() - .setReportInterval(config.oobReportingPeriodNanos, TimeUnit.NANOSECONDS) - .build()); + .setReportInterval(config.oobReportingPeriodNanos, TimeUnit.NANOSECONDS).build()); } else { OrcaOobUtil.setListener(weightedSubchannel, null, null); } @@ -265,32 +469,6 @@ public void shutdown() { super.shutdown(); } - private static final class WrrHelper extends ForwardingLoadBalancerHelper { - private final Helper delegate; - private WeightedRoundRobinLoadBalancer wrr; - - WrrHelper(Helper helper) { - this.delegate = helper; - } - - void setLoadBalancer(WeightedRoundRobinLoadBalancer lb) { - this.wrr = lb; - } - - @Override - protected Helper delegate() { - return delegate; - } - - @Override - public Subchannel createSubchannel(CreateSubchannelArgs args) { - checkElementIndex(0, args.getAddresses().size(), "Empty address group"); - WeightedChildLbState childLbState = - (WeightedChildLbState) wrr.getChildLbStateEag(args.getAddresses().get(0)); - return wrr.new WrrSubchannel(delegate().createSubchannel(args), childLbState); - } - } - @VisibleForTesting final class WrrSubchannel extends ForwardingSubchannel { private final Subchannel delegate; @@ -329,9 +507,12 @@ public void shutdown() { @VisibleForTesting static final class WeightedRoundRobinPicker extends SubchannelPicker { - private final List children; - private final Map subchannelToReportListenerMap = - new HashMap<>(); + // Parallel lists (column-based storage instead of normal row-based storage of List). + // The ith element of children corresponds to the ith element of pickers, listeners, and even + // updateWeight(float[]). + private final List children; // May only be accessed from sync context + private final List pickers; + private final List reportListeners; private final boolean enableOobLoadReport; private final float errorUtilizationPenalty; private final AtomicInteger sequence; @@ -339,59 +520,59 @@ static final class WeightedRoundRobinPicker extends SubchannelPicker { private volatile StaticStrideScheduler scheduler; WeightedRoundRobinPicker(List children, boolean enableOobLoadReport, - float errorUtilizationPenalty, AtomicInteger sequence) { + float errorUtilizationPenalty, AtomicInteger sequence, + ImmutableList parsedMetricNamesForComputingUtilization) { checkNotNull(children, "children"); Preconditions.checkArgument(!children.isEmpty(), "empty child list"); this.children = children; + List pickers = new ArrayList<>(children.size()); + List reportListeners = new ArrayList<>(children.size()); for (ChildLbState child : children) { WeightedChildLbState wChild = (WeightedChildLbState) child; - for (WrrSubchannel subchannel : wChild.subchannels) { - this.subchannelToReportListenerMap - .put(subchannel, wChild.getOrCreateOrcaListener(errorUtilizationPenalty)); - } + pickers.add(wChild.getCurrentPicker()); + reportListeners.add(wChild.getOrCreateOrcaListener(errorUtilizationPenalty, + parsedMetricNamesForComputingUtilization)); } + this.pickers = pickers; + this.reportListeners = reportListeners; this.enableOobLoadReport = enableOobLoadReport; this.errorUtilizationPenalty = errorUtilizationPenalty; this.sequence = checkNotNull(sequence, "sequence"); - // For equality we treat children as a set; use hash code as defined by Set + // For equality we treat pickers as a set; use hash code as defined by Set int sum = 0; - for (ChildLbState child : children) { - sum += child.hashCode(); + for (SubchannelPicker picker : pickers) { + sum += picker.hashCode(); } this.hashCode = sum ^ Boolean.hashCode(enableOobLoadReport) ^ Float.hashCode(errorUtilizationPenalty); - - updateWeight(); } @Override public PickResult pickSubchannel(PickSubchannelArgs args) { - ChildLbState childLbState = children.get(scheduler.pick()); - WeightedChildLbState wChild = (WeightedChildLbState) childLbState; - PickResult pickResult = childLbState.getCurrentPicker().pickSubchannel(args); + int pick = scheduler.pick(); + PickResult pickResult = pickers.get(pick).pickSubchannel(args); Subchannel subchannel = pickResult.getSubchannel(); if (subchannel == null) { return pickResult; } + + subchannel = ((WrrSubchannel) subchannel).delegate(); if (!enableOobLoadReport) { - return PickResult.withSubchannel(subchannel, - OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory( - subchannelToReportListenerMap.getOrDefault(subchannel, - wChild.getOrCreateOrcaListener(errorUtilizationPenalty)))); + return pickResult.copyWithSubchannel(subchannel) + .copyWithStreamTracerFactory( + OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory( + reportListeners.get(pick))); } else { - return PickResult.withSubchannel(subchannel); + return pickResult.copyWithSubchannel(subchannel); } } - private void updateWeight() { - float[] newWeights = new float[children.size()]; - for (int i = 0; i < children.size(); i++) { - double newWeight = ((WeightedChildLbState)children.get(i)).getWeight(); - newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f; - } + /** Returns {@code true} if weights are different than round_robin. */ + private boolean updateWeight(float[] newWeights) { this.scheduler = new StaticStrideScheduler(newWeights, sequence); + return !this.scheduler.usesRoundRobin(); } @Override @@ -399,7 +580,8 @@ public String toString() { return MoreObjects.toStringHelper(WeightedRoundRobinPicker.class) .add("enableOobLoadReport", enableOobLoadReport) .add("errorUtilizationPenalty", errorUtilizationPenalty) - .add("list", children).toString(); + .add("pickers", pickers) + .toString(); } @VisibleForTesting @@ -426,8 +608,8 @@ public boolean equals(Object o) { && sequence == other.sequence && enableOobLoadReport == other.enableOobLoadReport && Float.compare(errorUtilizationPenalty, other.errorUtilizationPenalty) == 0 - && children.size() == other.children.size() - && new HashSet<>(children).containsAll(other.children); + && pickers.size() == other.pickers.size() + && new HashSet<>(pickers).containsAll(other.pickers); } } @@ -454,6 +636,7 @@ public boolean equals(Object o) { static final class StaticStrideScheduler { private final short[] scaledWeights; private final AtomicInteger sequence; + private final boolean usesRoundRobin; private static final int K_MAX_WEIGHT = 0xFFFF; // Assuming the mean of all known weights is M, StaticStrideScheduler will clamp @@ -495,10 +678,14 @@ static final class StaticStrideScheduler { unscaledMeanWeight = sumWeight / numWeightedChannels; unscaledMaxWeight = Math.min(unscaledMaxWeight, (float) (K_MAX_RATIO * unscaledMeanWeight)); } else { - // Fall back to round robin if all values are non-positives + // Fall back to round robin if all values are non-positives. Note that + // numWeightedChannels == 1 also behaves like RR because the weights are all the same, but + // the weights aren't 1, so it doesn't go through this path. unscaledMeanWeight = 1; unscaledMaxWeight = 1; } + // We need at least two weights for WRR to be distinguishable from round_robin. + usesRoundRobin = numWeightedChannels < 2; // Scales weights s.t. max(weights) == K_MAX_WEIGHT, meanWeight is scaled accordingly. // Note that, since we cap the weights to stay within K_MAX_RATIO, meanWeight might not @@ -521,7 +708,14 @@ static final class StaticStrideScheduler { this.sequence = sequence; } - /** Returns the next sequence number and atomically increases sequence with wraparound. */ + // Without properly weighted channels, we do plain vanilla round_robin. + boolean usesRoundRobin() { + return usesRoundRobin; + } + + /** + * Returns the next sequence number and atomically increases sequence with wraparound. + */ private long nextSequence() { return Integer.toUnsignedLong(sequence.getAndIncrement()); } @@ -578,32 +772,70 @@ static final class WeightedRoundRobinLoadBalancerConfig { final long oobReportingPeriodNanos; final long weightUpdatePeriodNanos; final float errorUtilizationPenalty; + final ImmutableList parsedMetricNamesForComputingUtilization; public static Builder newBuilder() { return new Builder(); } private WeightedRoundRobinLoadBalancerConfig(long blackoutPeriodNanos, - long weightExpirationPeriodNanos, - boolean enableOobLoadReport, - long oobReportingPeriodNanos, - long weightUpdatePeriodNanos, - float errorUtilizationPenalty) { + long weightExpirationPeriodNanos, boolean enableOobLoadReport, long oobReportingPeriodNanos, + long weightUpdatePeriodNanos, float errorUtilizationPenalty, + ImmutableList metricNamesForComputingUtilization) { this.blackoutPeriodNanos = blackoutPeriodNanos; this.weightExpirationPeriodNanos = weightExpirationPeriodNanos; this.enableOobLoadReport = enableOobLoadReport; this.oobReportingPeriodNanos = oobReportingPeriodNanos; this.weightUpdatePeriodNanos = weightUpdatePeriodNanos; this.errorUtilizationPenalty = errorUtilizationPenalty; + + ImmutableList.Builder builder = ImmutableList.builder(); + if (metricNamesForComputingUtilization != null) { + for (int i = 0; i < metricNamesForComputingUtilization.size(); i++) { + String metricName = metricNamesForComputingUtilization.get(i); + ParsedMetricName parsed = MetricReportUtils.ParsedMetricName.parse(metricName); + if (parsed.getMetricType() != MetricReportUtils.MetricType.INVALID) { + builder.add(parsed); + } else { + log.log(Level.FINE, "Invalid custom metric name configured and ignored: " + metricName); + } + } + } + this.parsedMetricNamesForComputingUtilization = builder.build(); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof WeightedRoundRobinLoadBalancerConfig)) { + return false; + } + WeightedRoundRobinLoadBalancerConfig that = (WeightedRoundRobinLoadBalancerConfig) o; + return this.blackoutPeriodNanos == that.blackoutPeriodNanos + && this.weightExpirationPeriodNanos == that.weightExpirationPeriodNanos + && this.enableOobLoadReport == that.enableOobLoadReport + && this.oobReportingPeriodNanos == that.oobReportingPeriodNanos + && this.weightUpdatePeriodNanos == that.weightUpdatePeriodNanos + // Float.compare considers NaNs equal + && Float.compare(this.errorUtilizationPenalty, that.errorUtilizationPenalty) == 0 + && Objects.equals(this.parsedMetricNamesForComputingUtilization, + that.parsedMetricNamesForComputingUtilization); + } + + @Override + public int hashCode() { + return Objects.hash(blackoutPeriodNanos, weightExpirationPeriodNanos, enableOobLoadReport, + oobReportingPeriodNanos, weightUpdatePeriodNanos, errorUtilizationPenalty, + parsedMetricNamesForComputingUtilization); } static final class Builder { long blackoutPeriodNanos = 10_000_000_000L; // 10s - long weightExpirationPeriodNanos = 180_000_000_000L; //3min + long weightExpirationPeriodNanos = 180_000_000_000L; // 3min boolean enableOobLoadReport = false; long oobReportingPeriodNanos = 10_000_000_000L; // 10s long weightUpdatePeriodNanos = 1_000_000_000L; // 1s float errorUtilizationPenalty = 1.0F; + ImmutableList metricNamesForComputingUtilization = ImmutableList.of(); private Builder() { @@ -641,10 +873,17 @@ Builder setErrorUtilizationPenalty(float errorUtilizationPenalty) { return this; } + Builder setMetricNamesForComputingUtilization( + List metricNamesForComputingUtilization) { + this.metricNamesForComputingUtilization = + ImmutableList.copyOf(metricNamesForComputingUtilization); + return this; + } + WeightedRoundRobinLoadBalancerConfig build() { return new WeightedRoundRobinLoadBalancerConfig(blackoutPeriodNanos, - weightExpirationPeriodNanos, enableOobLoadReport, oobReportingPeriodNanos, - weightUpdatePeriodNanos, errorUtilizationPenalty); + weightExpirationPeriodNanos, enableOobLoadReport, oobReportingPeriodNanos, + weightUpdatePeriodNanos, errorUtilizationPenalty, metricNamesForComputingUtilization); } } } diff --git a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProvider.java index 161e7c4ed0c..0f9fcf07c9a 100644 --- a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProvider.java @@ -18,21 +18,21 @@ import com.google.common.annotations.VisibleForTesting; import io.grpc.Deadline; -import io.grpc.ExperimentalApi; import io.grpc.Internal; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancerProvider; import io.grpc.NameResolver.ConfigOrError; import io.grpc.Status; +import io.grpc.internal.GrpcUtil; import io.grpc.internal.JsonUtil; import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancerConfig; +import java.util.List; import java.util.Map; /** * Provides a {@link WeightedRoundRobinLoadBalancer}. * */ -@ExperimentalApi("https://github.com/grpc/grpc-java/issues/9885") @Internal public final class WeightedRoundRobinLoadBalancerProvider extends LoadBalancerProvider { @@ -75,14 +75,16 @@ public ConfigOrError parseLoadBalancingPolicyConfig(Map rawConfig) { private ConfigOrError parseLoadBalancingPolicyConfigInternal(Map rawConfig) { Long blackoutPeriodNanos = JsonUtil.getStringAsDuration(rawConfig, "blackoutPeriod"); Long weightExpirationPeriodNanos = - JsonUtil.getStringAsDuration(rawConfig, "weightExpirationPeriod"); + JsonUtil.getStringAsDuration(rawConfig, "weightExpirationPeriod"); Long oobReportingPeriodNanos = JsonUtil.getStringAsDuration(rawConfig, "oobReportingPeriod"); Boolean enableOobLoadReport = JsonUtil.getBoolean(rawConfig, "enableOobLoadReport"); Long weightUpdatePeriodNanos = JsonUtil.getStringAsDuration(rawConfig, "weightUpdatePeriod"); Float errorUtilizationPenalty = JsonUtil.getNumberAsFloat(rawConfig, "errorUtilizationPenalty"); + List metricNamesForComputingUtilization = JsonUtil.getListOfStrings(rawConfig, + "metricNamesForComputingUtilization"); WeightedRoundRobinLoadBalancerConfig.Builder configBuilder = - WeightedRoundRobinLoadBalancerConfig.newBuilder(); + WeightedRoundRobinLoadBalancerConfig.newBuilder(); if (blackoutPeriodNanos != null) { configBuilder.setBlackoutPeriodNanos(blackoutPeriodNanos); } @@ -104,6 +106,10 @@ private ConfigOrError parseLoadBalancingPolicyConfigInternal(Map rawC if (errorUtilizationPenalty != null) { configBuilder.setErrorUtilizationPenalty(errorUtilizationPenalty); } + if (metricNamesForComputingUtilization != null + && GrpcUtil.getFlag("GRPC_EXPERIMENTAL_WRR_CUSTOM_METRICS", false)) { + configBuilder.setMetricNamesForComputingUtilization(metricNamesForComputingUtilization); + } return ConfigOrError.fromConfig(configBuilder.build()); } } diff --git a/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancer.java index f108ac899c1..9468a9daf9d 100644 --- a/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancer.java @@ -23,6 +23,7 @@ import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; import com.google.common.collect.ImmutableMap; +import io.grpc.Attributes; import io.grpc.ConnectivityState; import io.grpc.InternalLogId; import io.grpc.LoadBalancer; @@ -42,6 +43,8 @@ /** Load balancer for weighted_target policy. */ final class WeightedTargetLoadBalancer extends LoadBalancer { + public static final Attributes.Key CHILD_NAME = + Attributes.Key.create("io.grpc.xds.WeightedTargetLoadBalancer.CHILD_NAME"); private final XdsLogger logger; private final Map childBalancers = new HashMap<>(); @@ -76,26 +79,27 @@ public Status acceptResolvedAddressesInternal(ResolvedAddresses resolvedAddresse WeightedTargetConfig weightedTargetConfig = (WeightedTargetConfig) lbConfig; Map newTargets = weightedTargetConfig.targets; for (String targetName : newTargets.keySet()) { - WeightedPolicySelection weightedChildLbConfig = newTargets.get(targetName); if (!targets.containsKey(targetName)) { ChildHelper childHelper = new ChildHelper(targetName); GracefulSwitchLoadBalancer childBalancer = new GracefulSwitchLoadBalancer(childHelper); - childBalancer.switchTo(weightedChildLbConfig.policySelection.getProvider()); childHelpers.put(targetName, childHelper); childBalancers.put(targetName, childBalancer); - } else if (!weightedChildLbConfig.policySelection.getProvider().equals( - targets.get(targetName).policySelection.getProvider())) { - childBalancers.get(targetName) - .switchTo(weightedChildLbConfig.policySelection.getProvider()); } } targets = newTargets; + Status status = Status.OK; for (String targetName : targets.keySet()) { - childBalancers.get(targetName).handleResolvedAddresses( + Status newStatus = childBalancers.get(targetName).acceptResolvedAddresses( resolvedAddresses.toBuilder() .setAddresses(AddressFilter.filter(resolvedAddresses.getAddresses(), targetName)) - .setLoadBalancingPolicyConfig(targets.get(targetName).policySelection.getConfig()) + .setLoadBalancingPolicyConfig(targets.get(targetName).childConfig) + .setAttributes(resolvedAddresses.getAttributes().toBuilder() + .set(CHILD_NAME, targetName) + .build()) .build()); + if (!newStatus.isOk()) { + status = newStatus; + } } // Cleanup removed targets. @@ -108,7 +112,7 @@ public Status acceptResolvedAddressesInternal(ResolvedAddresses resolvedAddresse childBalancers.keySet().retainAll(targets.keySet()); childHelpers.keySet().retainAll(targets.keySet()); updateOverallBalancingState(); - return Status.OK; + return status; } @Override @@ -124,6 +128,8 @@ public void handleNameResolutionError(Status error) { } @Override + @Deprecated + @SuppressWarnings("InlineMeSuggester") public boolean canHandleEmptyAddressListFromNameResolution() { return true; } diff --git a/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancerProvider.java index c6a0893db02..15318693aca 100644 --- a/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancerProvider.java @@ -25,12 +25,10 @@ import io.grpc.LoadBalancerRegistry; import io.grpc.NameResolver.ConfigOrError; import io.grpc.Status; +import io.grpc.internal.GrpcUtil; import io.grpc.internal.JsonUtil; -import io.grpc.internal.ServiceConfigUtil; -import io.grpc.internal.ServiceConfigUtil.LbConfig; -import io.grpc.internal.ServiceConfigUtil.PolicySelection; +import io.grpc.util.GracefulSwitchLoadBalancer; import java.util.LinkedHashMap; -import java.util.List; import java.util.Map; import java.util.Objects; import javax.annotation.Nullable; @@ -97,22 +95,17 @@ public ConfigOrError parseLoadBalancingPolicyConfig(Map rawConfig) { return ConfigOrError.fromError(Status.INTERNAL.withDescription( "Wrong weight for target " + name + " in weighted_target LB policy:\n " + rawConfig)); } - List childConfigCandidates = ServiceConfigUtil.unwrapLoadBalancingConfigList( - JsonUtil.getListOfObjects(rawWeightedTarget, "childPolicy")); - if (childConfigCandidates == null || childConfigCandidates.isEmpty()) { - return ConfigOrError.fromError(Status.INTERNAL.withDescription( - "No child policy for target " + name + " in weighted_target LB policy:\n " - + rawConfig)); - } LoadBalancerRegistry lbRegistry = this.lbRegistry == null ? LoadBalancerRegistry.getDefaultRegistry() : this.lbRegistry; - ConfigOrError selectedConfig = - ServiceConfigUtil.selectLbPolicyFromList(childConfigCandidates, lbRegistry); - if (selectedConfig.getError() != null) { - return selectedConfig; + ConfigOrError childConfig = GracefulSwitchLoadBalancer.parseLoadBalancingPolicyConfig( + JsonUtil.getListOfObjects(rawWeightedTarget, "childPolicy"), lbRegistry); + if (childConfig.getError() != null) { + return ConfigOrError.fromError(GrpcUtil.statusWithDetails( + Status.Code.INTERNAL, + "Could not parse weighted_target's child policy: " + name, + childConfig.getError())); } - PolicySelection policySelection = (PolicySelection) selectedConfig.getConfig(); - parsedChildConfigs.put(name, new WeightedPolicySelection(weight, policySelection)); + parsedChildConfigs.put(name, new WeightedPolicySelection(weight, childConfig.getConfig())); } return ConfigOrError.fromConfig(new WeightedTargetConfig(parsedChildConfigs)); } catch (RuntimeException e) { @@ -125,11 +118,11 @@ public ConfigOrError parseLoadBalancingPolicyConfig(Map rawConfig) { static final class WeightedPolicySelection { final int weight; - final PolicySelection policySelection; + final Object childConfig; - WeightedPolicySelection(int weight, PolicySelection policySelection) { + WeightedPolicySelection(int weight, Object childConfig) { this.weight = weight; - this.policySelection = policySelection; + this.childConfig = childConfig; } @Override @@ -141,19 +134,19 @@ public boolean equals(Object o) { return false; } WeightedPolicySelection that = (WeightedPolicySelection) o; - return weight == that.weight && Objects.equals(policySelection, that.policySelection); + return weight == that.weight && Objects.equals(childConfig, that.childConfig); } @Override public int hashCode() { - return Objects.hash(weight, policySelection); + return Objects.hash(weight, childConfig); } @Override public String toString() { return MoreObjects.toStringHelper(this) .add("weight", weight) - .add("policySelection", policySelection) + .add("childConfig", childConfig) .toString(); } } diff --git a/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancer.java index e7463cd3710..1a12412f923 100644 --- a/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancer.java @@ -27,11 +27,9 @@ import io.grpc.LoadBalancer; import io.grpc.LoadBalancerRegistry; import io.grpc.Status; -import io.grpc.internal.ServiceConfigUtil.PolicySelection; import io.grpc.util.GracefulSwitchLoadBalancer; import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedPolicySelection; import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedTargetConfig; -import io.grpc.xds.client.Locality; import io.grpc.xds.client.XdsLogger; import io.grpc.xds.client.XdsLogger.XdsLogLevel; import java.util.HashMap; @@ -73,11 +71,11 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { = (WrrLocalityConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); // A map of locality weights is built up from the locality weight attributes in each address. - Map localityWeights = new HashMap<>(); + Map localityWeights = new HashMap<>(); for (EquivalentAddressGroup eag : resolvedAddresses.getAddresses()) { Attributes eagAttrs = eag.getAttributes(); - Locality locality = eagAttrs.get(InternalXdsAttributes.ATTR_LOCALITY); - Integer localityWeight = eagAttrs.get(InternalXdsAttributes.ATTR_LOCALITY_WEIGHT); + String locality = eagAttrs.get(EquivalentAddressGroup.ATTR_LOCALITY_NAME); + Integer localityWeight = eagAttrs.get(XdsAttributes.ATTR_LOCALITY_WEIGHT); if (locality == null) { Status unavailableStatus = Status.UNAVAILABLE.withDescription( @@ -106,19 +104,19 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { // Weighted target LB expects a WeightedPolicySelection for each locality as it will create a // child LB for each. Map weightedPolicySelections = new HashMap<>(); - for (Locality locality : localityWeights.keySet()) { - weightedPolicySelections.put(locality.toString(), + for (String locality : localityWeights.keySet()) { + weightedPolicySelections.put(locality, new WeightedPolicySelection(localityWeights.get(locality), - wrrLocalityConfig.childPolicy)); + wrrLocalityConfig.childConfig)); } - switchLb.switchTo(lbRegistry.getProvider(WEIGHTED_TARGET_POLICY_NAME)); - switchLb.handleResolvedAddresses( + Object switchConfig = GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + lbRegistry.getProvider(WEIGHTED_TARGET_POLICY_NAME), + new WeightedTargetConfig(weightedPolicySelections)); + return switchLb.acceptResolvedAddresses( resolvedAddresses.toBuilder() - .setLoadBalancingPolicyConfig(new WeightedTargetConfig(weightedPolicySelections)) + .setLoadBalancingPolicyConfig(switchConfig) .build()); - - return Status.OK; } @Override @@ -137,10 +135,10 @@ public void shutdown() { */ static final class WrrLocalityConfig { - final PolicySelection childPolicy; + final Object childConfig; - WrrLocalityConfig(PolicySelection childPolicy) { - this.childPolicy = childPolicy; + WrrLocalityConfig(Object childConfig) { + this.childConfig = childConfig; } @Override @@ -152,17 +150,17 @@ public boolean equals(Object o) { return false; } WrrLocalityConfig that = (WrrLocalityConfig) o; - return Objects.equals(childPolicy, that.childPolicy); + return Objects.equals(childConfig, that.childConfig); } @Override public int hashCode() { - return Objects.hashCode(childPolicy); + return Objects.hashCode(childConfig); } @Override public String toString() { - return MoreObjects.toStringHelper(this).add("childPolicy", childPolicy).toString(); + return MoreObjects.toStringHelper(this).add("childConfig", childConfig).toString(); } } } diff --git a/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancerProvider.java index 31a4e128140..3693df9208a 100644 --- a/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancerProvider.java @@ -23,12 +23,10 @@ import io.grpc.LoadBalancerRegistry; import io.grpc.NameResolver.ConfigOrError; import io.grpc.Status; +import io.grpc.internal.GrpcUtil; import io.grpc.internal.JsonUtil; -import io.grpc.internal.ServiceConfigUtil; -import io.grpc.internal.ServiceConfigUtil.LbConfig; -import io.grpc.internal.ServiceConfigUtil.PolicySelection; +import io.grpc.util.GracefulSwitchLoadBalancer; import io.grpc.xds.WrrLocalityLoadBalancer.WrrLocalityConfig; -import java.util.List; import java.util.Map; /** @@ -62,21 +60,15 @@ public String getPolicyName() { @Override public ConfigOrError parseLoadBalancingPolicyConfig(Map rawConfig) { try { - List childConfigCandidates = ServiceConfigUtil.unwrapLoadBalancingConfigList( + ConfigOrError childConfig = GracefulSwitchLoadBalancer.parseLoadBalancingPolicyConfig( JsonUtil.getListOfObjects(rawConfig, "childPolicy")); - if (childConfigCandidates == null || childConfigCandidates.isEmpty()) { - return ConfigOrError.fromError(Status.INTERNAL.withDescription( - "No child policy in wrr_locality LB policy: " - + rawConfig)); + if (childConfig.getError() != null) { + return ConfigOrError.fromError(GrpcUtil.statusWithDetails( + Status.Code.INTERNAL, + "Failed to parse child policy in wrr_locality LB policy", + childConfig.getError())); } - ConfigOrError selectedConfig = - ServiceConfigUtil.selectLbPolicyFromList(childConfigCandidates, - LoadBalancerRegistry.getDefaultRegistry()); - if (selectedConfig.getError() != null) { - return selectedConfig; - } - PolicySelection policySelection = (PolicySelection) selectedConfig.getConfig(); - return ConfigOrError.fromConfig(new WrrLocalityConfig(policySelection)); + return ConfigOrError.fromConfig(new WrrLocalityConfig(childConfig.getConfig())); } catch (RuntimeException e) { return ConfigOrError.fromError(Status.INTERNAL.withCause(e) .withDescription("Failed to parse wrr_locality LB config: " + rawConfig)); diff --git a/xds/src/main/java/io/grpc/xds/XdsAttributes.java b/xds/src/main/java/io/grpc/xds/XdsAttributes.java new file mode 100644 index 00000000000..d3fe8d4619c --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/XdsAttributes.java @@ -0,0 +1,104 @@ +/* + * Copyright 2019 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import io.grpc.Attributes; +import io.grpc.EquivalentAddressGroup; +import io.grpc.Grpc; +import io.grpc.InternalEquivalentAddressGroup; +import io.grpc.NameResolver; +import io.grpc.xds.XdsNameResolverProvider.CallCounterProvider; +import io.grpc.xds.client.Locality; +import io.grpc.xds.client.XdsClient; + +/** + * Attributes used for xDS implementation. + */ +final class XdsAttributes { + /** + * Attribute key for passing around the XdsClient object pool across NameResolver/LoadBalancers. + */ + @NameResolver.ResolutionResultAttr + static final Attributes.Key XDS_CLIENT = + Attributes.Key.create("io.grpc.xds.XdsAttributes.xdsClient"); + + /** + * Attribute key for passing around the latest XdsConfig across NameResolver/LoadBalancers. + */ + @NameResolver.ResolutionResultAttr + static final Attributes.Key XDS_CONFIG = + Attributes.Key.create("io.grpc.xds.XdsAttributes.xdsConfig"); + + + /** + * Attribute key for passing around the XdsDependencyManager across NameResolver/LoadBalancers. + */ + @NameResolver.ResolutionResultAttr + static final Attributes.Key + XDS_CLUSTER_SUBSCRIPT_REGISTRY = + Attributes.Key.create("io.grpc.xds.XdsAttributes.xdsConfig.XdsClusterSubscriptionRegistry"); + + /** + * Attribute key for obtaining the global provider that provides atomics for aggregating + * outstanding RPCs sent to each cluster. + */ + @NameResolver.ResolutionResultAttr + static final Attributes.Key CALL_COUNTER_PROVIDER = + Attributes.Key.create("io.grpc.xds.XdsAttributes.callCounterProvider"); + + /** + * Map from localities to their weights. + */ + @NameResolver.ResolutionResultAttr + static final Attributes.Key ATTR_LOCALITY_WEIGHT = + Attributes.Key.create("io.grpc.xds.XdsAttributes.localityWeight"); + + /** + * Name of the cluster that provides this EquivalentAddressGroup. + */ + @EquivalentAddressGroup.Attr + public static final Attributes.Key ATTR_CLUSTER_NAME = + Attributes.Key.create("io.grpc.xds.XdsAttributes.clusterName"); + + /** + * The locality that this EquivalentAddressGroup is in. + */ + @EquivalentAddressGroup.Attr + static final Attributes.Key ATTR_LOCALITY = + Attributes.Key.create("io.grpc.xds.XdsAttributes.locality"); + + /** + * Endpoint weight for load balancing purposes. + */ + @EquivalentAddressGroup.Attr + static final Attributes.Key ATTR_SERVER_WEIGHT = InternalEquivalentAddressGroup.ATTR_WEIGHT; + + /** + * Filter chain match for network filters. + */ + @Grpc.TransportAttr + static final Attributes.Key + ATTR_FILTER_CHAIN_SELECTOR_MANAGER = Attributes.Key.create( + "io.grpc.xds.XdsAttributes.filterChainSelectorManager"); + + /** Grace time to use when draining. Null for an infinite grace time. */ + @Grpc.TransportAttr + static final Attributes.Key ATTR_DRAIN_GRACE_NANOS = + Attributes.Key.create("io.grpc.xds.XdsAttributes.drainGraceTime"); + + private XdsAttributes() {} +} diff --git a/xds/src/main/java/io/grpc/xds/XdsClientMetricReporterImpl.java b/xds/src/main/java/io/grpc/xds/XdsClientMetricReporterImpl.java new file mode 100644 index 00000000000..5cfba11c065 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/XdsClientMetricReporterImpl.java @@ -0,0 +1,233 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.util.concurrent.ListenableFuture; +import io.grpc.LongCounterMetricInstrument; +import io.grpc.LongGaugeMetricInstrument; +import io.grpc.MetricInstrumentRegistry; +import io.grpc.MetricRecorder; +import io.grpc.MetricRecorder.BatchCallback; +import io.grpc.MetricRecorder.BatchRecorder; +import io.grpc.MetricRecorder.Registration; +import io.grpc.xds.client.XdsClient; +import io.grpc.xds.client.XdsClient.ResourceMetadata; +import io.grpc.xds.client.XdsClient.ResourceMetadata.ResourceMetadataStatus; +import io.grpc.xds.client.XdsClient.ServerConnectionCallback; +import io.grpc.xds.client.XdsClientMetricReporter; +import io.grpc.xds.client.XdsResourceType; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.Nullable; + +/** + * XdsClientMetricReporter implementation. + */ +final class XdsClientMetricReporterImpl implements XdsClientMetricReporter { + + private static final Logger logger = Logger.getLogger( + XdsClientMetricReporterImpl.class.getName()); + private static final LongCounterMetricInstrument SERVER_FAILURE_COUNTER; + private static final LongCounterMetricInstrument RESOURCE_UPDATES_VALID_COUNTER; + private static final LongCounterMetricInstrument RESOURCE_UPDATES_INVALID_COUNTER; + private static final LongGaugeMetricInstrument CONNECTED_GAUGE; + private static final LongGaugeMetricInstrument RESOURCES_GAUGE; + + private final MetricRecorder metricRecorder; + private final String target; + @Nullable + private Registration gaugeRegistration = null; + + static { + MetricInstrumentRegistry metricInstrumentRegistry + = MetricInstrumentRegistry.getDefaultRegistry(); + SERVER_FAILURE_COUNTER = metricInstrumentRegistry.registerLongCounter( + "grpc.xds_client.server_failure", + "EXPERIMENTAL. A counter of xDS servers going from healthy to unhealthy. A server goes" + + " unhealthy when we have a connectivity failure or when the ADS stream fails without" + + " seeing a response message, as per gRFC A57.", "{failure}", + Arrays.asList("grpc.target", "grpc.xds.server"), Collections.emptyList(), false); + RESOURCE_UPDATES_VALID_COUNTER = metricInstrumentRegistry.registerLongCounter( + "grpc.xds_client.resource_updates_valid", + "EXPERIMENTAL. A counter of resources received that were considered valid. The counter will" + + " be incremented even for resources that have not changed.", "{resource}", + Arrays.asList("grpc.target", "grpc.xds.server", "grpc.xds.resource_type"), + Collections.emptyList(), false); + RESOURCE_UPDATES_INVALID_COUNTER = metricInstrumentRegistry.registerLongCounter( + "grpc.xds_client.resource_updates_invalid", + "EXPERIMENTAL. A counter of resources received that were considered invalid.", "{resource}", + Arrays.asList("grpc.target", "grpc.xds.server", "grpc.xds.resource_type"), + Collections.emptyList(), false); + CONNECTED_GAUGE = metricInstrumentRegistry.registerLongGauge("grpc.xds_client.connected", + "EXPERIMENTAL. Whether or not the xDS client currently has a working ADS stream to the xDS" + + " server. For a given server, this will be set to 1 when the stream is initially" + + " created. It will be set to 0 when we have a connectivity failure or when the ADS" + + " stream fails without seeing a response message, as per gRFC A57. Once set to 0, it" + + " will be reset to 1 when we receive the first response on an ADS stream.", "{bool}", + Arrays.asList("grpc.target", "grpc.xds.server"), Collections.emptyList(), false); + RESOURCES_GAUGE = metricInstrumentRegistry.registerLongGauge("grpc.xds_client.resources", + "EXPERIMENTAL. Number of xDS resources.", "{resource}", + Arrays.asList("grpc.target", "grpc.xds.authority", "grpc.xds.cache_state", + "grpc.xds.resource_type"), Collections.emptyList(), false); + } + + XdsClientMetricReporterImpl(MetricRecorder metricRecorder, String target) { + this.metricRecorder = metricRecorder; + this.target = target; + } + + @Override + public void reportResourceUpdates(long validResourceCount, long invalidResourceCount, + String xdsServer, String resourceType) { + metricRecorder.addLongCounter(RESOURCE_UPDATES_VALID_COUNTER, validResourceCount, + Arrays.asList(target, xdsServer, resourceType), Collections.emptyList()); + metricRecorder.addLongCounter(RESOURCE_UPDATES_INVALID_COUNTER, invalidResourceCount, + Arrays.asList(target, xdsServer, resourceType), Collections.emptyList()); + } + + @Override + public void reportServerFailure(long serverFailure, String xdsServer) { + metricRecorder.addLongCounter(SERVER_FAILURE_COUNTER, serverFailure, + Arrays.asList(target, xdsServer), Collections.emptyList()); + } + + void setXdsClient(XdsClient xdsClient) { + assert gaugeRegistration == null; + // register gauge here + this.gaugeRegistration = metricRecorder.registerBatchCallback(new BatchCallback() { + @Override + public void accept(BatchRecorder recorder) { + reportCallbackMetrics(recorder, xdsClient); + } + }, CONNECTED_GAUGE, RESOURCES_GAUGE); + } + + void close() { + if (gaugeRegistration != null) { + gaugeRegistration.close(); + gaugeRegistration = null; + } + } + + void reportCallbackMetrics(BatchRecorder recorder, XdsClient xdsClient) { + MetricReporterCallback callback = new MetricReporterCallback(recorder, target); + try { + Future reportServerConnectionsCompleted = xdsClient.reportServerConnections(callback); + + ListenableFuture, Map>> + getResourceMetadataCompleted = xdsClient.getSubscribedResourcesMetadataSnapshot(); + + Map, Map> metadataByType = + getResourceMetadataCompleted.get(10, TimeUnit.SECONDS); + + computeAndReportResourceCounts(metadataByType, callback); + + // Normally this shouldn't take long, but adding a timeout to avoid indefinite blocking + Void unused = reportServerConnectionsCompleted.get(5, TimeUnit.SECONDS); + } catch (ExecutionException | TimeoutException | InterruptedException e) { + if (e instanceof InterruptedException) { + Thread.currentThread().interrupt(); // re-set the current thread's interruption state + } + logger.log(Level.WARNING, "Failed to report gauge metrics", e); + } + } + + private void computeAndReportResourceCounts( + Map, Map> metadataByType, + MetricReporterCallback callback) { + for (Map.Entry, Map> metadataByTypeEntry : + metadataByType.entrySet()) { + XdsResourceType type = metadataByTypeEntry.getKey(); + Map resources = metadataByTypeEntry.getValue(); + + Map> resourceCountsByAuthorityAndState = new HashMap<>(); + for (Map.Entry resourceEntry : resources.entrySet()) { + String resourceName = resourceEntry.getKey(); + ResourceMetadata metadata = resourceEntry.getValue(); + String authority = XdsClient.getAuthorityFromResourceName(resourceName); + String cacheState = cacheStateFromResourceStatus(metadata.getStatus(), metadata.isCached()); + resourceCountsByAuthorityAndState + .computeIfAbsent(authority, k -> new HashMap<>()) + .merge(cacheState, 1L, Long::sum); + } + + // Report metrics + for (Map.Entry> authorityEntry + : resourceCountsByAuthorityAndState.entrySet()) { + String authority = authorityEntry.getKey(); + Map stateCounts = authorityEntry.getValue(); + + for (Map.Entry stateEntry : stateCounts.entrySet()) { + String cacheState = stateEntry.getKey(); + Long count = stateEntry.getValue(); + + callback.reportResourceCountGauge(count, authority, cacheState, type.typeUrl()); + } + } + } + } + + private static String cacheStateFromResourceStatus(ResourceMetadataStatus metadataStatus, + boolean isResourceCached) { + switch (metadataStatus) { + case REQUESTED: + return "requested"; + case DOES_NOT_EXIST: + return "does_not_exist"; + case ACKED: + return "acked"; + case NACKED: + return isResourceCached ? "nacked_but_cached" : "nacked"; + default: + return "unknown"; + } + } + + @VisibleForTesting + static final class MetricReporterCallback implements ServerConnectionCallback { + private final BatchRecorder recorder; + private final String target; + + MetricReporterCallback(BatchRecorder recorder, String target) { + this.recorder = recorder; + this.target = target; + } + + void reportResourceCountGauge(long resourceCount, String authority, String cacheState, + String resourceType) { + // authority = #old, for non-xdstp resource names + recorder.recordLongGauge(RESOURCES_GAUGE, resourceCount, + Arrays.asList(target, authority == null ? "#old" : authority, cacheState, resourceType), + Collections.emptyList()); + } + + @Override + public void reportServerConnectionGauge(boolean isConnected, String xdsServer) { + recorder.recordLongGauge(CONNECTED_GAUGE, isConnected ? 1 : 0, + Arrays.asList(target, xdsServer), Collections.emptyList()); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/XdsClientPoolFactory.java b/xds/src/main/java/io/grpc/xds/XdsClientPoolFactory.java index c649b3b3069..6df8d566a7a 100644 --- a/xds/src/main/java/io/grpc/xds/XdsClientPoolFactory.java +++ b/xds/src/main/java/io/grpc/xds/XdsClientPoolFactory.java @@ -16,17 +16,19 @@ package io.grpc.xds; +import io.grpc.MetricRecorder; import io.grpc.internal.ObjectPool; +import io.grpc.xds.client.Bootstrapper.BootstrapInfo; import io.grpc.xds.client.XdsClient; -import io.grpc.xds.client.XdsInitializationException; -import java.util.Map; +import java.util.List; import javax.annotation.Nullable; interface XdsClientPoolFactory { - void setBootstrapOverride(Map bootstrap); - @Nullable - ObjectPool get(); + ObjectPool get(String target); + + ObjectPool getOrCreate( + String target, BootstrapInfo bootstrapInfo, MetricRecorder metricRecorder); - ObjectPool getOrCreate() throws XdsInitializationException; + List getTargets(); } diff --git a/xds/src/main/java/io/grpc/xds/XdsClusterResource.java b/xds/src/main/java/io/grpc/xds/XdsClusterResource.java index d5fe8a0ab97..10efc47be47 100644 --- a/xds/src/main/java/io/grpc/xds/XdsClusterResource.java +++ b/xds/src/main/java/io/grpc/xds/XdsClusterResource.java @@ -18,42 +18,64 @@ import static com.google.common.base.Preconditions.checkNotNull; import static io.grpc.xds.client.Bootstrapper.ServerInfo; +import static io.grpc.xds.client.LoadStatsManager2.isEnabledOrcaLrsPropagation; import com.google.auto.value.AutoValue; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; +import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.protobuf.Duration; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; +import com.google.protobuf.Struct; import com.google.protobuf.util.Durations; import io.envoyproxy.envoy.config.cluster.v3.CircuitBreakers.Thresholds; import io.envoyproxy.envoy.config.cluster.v3.Cluster; import io.envoyproxy.envoy.config.core.v3.RoutingPriority; import io.envoyproxy.envoy.config.core.v3.SocketAddress; +import io.envoyproxy.envoy.config.core.v3.TransportSocket; import io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment; +import io.envoyproxy.envoy.extensions.transport_sockets.http_11_proxy.v3.Http11ProxyUpstreamTransport; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.grpc.LoadBalancerRegistry; import io.grpc.NameResolver; +import io.grpc.internal.GrpcUtil; import io.grpc.internal.ServiceConfigUtil; import io.grpc.internal.ServiceConfigUtil.LbConfig; import io.grpc.xds.EnvoyServerProtoData.OutlierDetection; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.XdsClusterResource.CdsUpdate; +import io.grpc.xds.client.BackendMetricPropagation; import io.grpc.xds.client.XdsClient.ResourceUpdate; import io.grpc.xds.client.XdsResourceType; +import io.grpc.xds.internal.security.CommonTlsContextUtil; import java.util.List; import java.util.Locale; import java.util.Set; import javax.annotation.Nullable; class XdsClusterResource extends XdsResourceType { + @VisibleForTesting + static boolean enableLeastRequest = + !Strings.isNullOrEmpty(System.getenv("GRPC_EXPERIMENTAL_ENABLE_LEAST_REQUEST")) + ? Boolean.parseBoolean(System.getenv("GRPC_EXPERIMENTAL_ENABLE_LEAST_REQUEST")) + : Boolean.parseBoolean( + System.getProperty("io.grpc.xds.experimentalEnableLeastRequest", "true")); + @VisibleForTesting + public static boolean enableSystemRootCerts = + GrpcUtil.getFlag("GRPC_EXPERIMENTAL_XDS_SYSTEM_ROOT_CERTS", true); + static boolean isEnabledXdsHttpConnect = + GrpcUtil.getFlag("GRPC_EXPERIMENTAL_XDS_HTTP_CONNECT", false); + @VisibleForTesting static final String AGGREGATE_CLUSTER_TYPE_NAME = "envoy.clusters.aggregate"; static final String ADS_TYPE_URL_CDS = "type.googleapis.com/envoy.config.cluster.v3.Cluster"; + private static final String TYPE_URL_CLUSTER_CONFIG = + "type.googleapis.com/envoy.extensions.clusters.aggregate.v3.ClusterConfig"; private static final String TYPE_URL_UPSTREAM_TLS_CONTEXT = "type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext"; private static final String TYPE_URL_UPSTREAM_TLS_CONTEXT_V2 = @@ -141,7 +163,7 @@ static CdsUpdate processCluster(Cluster cluster, CdsUpdate.Builder updateBuilder = structOrError.getStruct(); ImmutableMap lbPolicyConfig = LoadBalancerConfigFactory.newConfig(cluster, - enableLeastRequest, enableWrr, enablePickFirst); + enableLeastRequest); // Validate the LB config by trying to parse it with the corresponding LB provider. LbConfig lbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(lbPolicyConfig); @@ -149,10 +171,25 @@ static CdsUpdate processCluster(Cluster cluster, lbConfig.getPolicyName()).parseLoadBalancingPolicyConfig( lbConfig.getRawConfigValue()); if (configOrError.getError() != null) { - throw new ResourceInvalidException(structOrError.getErrorDetail()); + throw new ResourceInvalidException( + "Failed to parse lb config for cluster '" + cluster.getName() + "': " + + configOrError.getError()); } updateBuilder.lbPolicyConfig(lbPolicyConfig); + updateBuilder.filterMetadata( + ImmutableMap.copyOf(cluster.getMetadata().getFilterMetadataMap())); + + try { + MetadataRegistry registry = MetadataRegistry.getInstance(); + ImmutableMap parsedFilterMetadata = + registry.parseMetadata(cluster.getMetadata()); + updateBuilder.parsedMetadata(parsedFilterMetadata); + } catch (ResourceInvalidException e) { + throw new ResourceInvalidException( + "Failed to parse xDS filter metadata for cluster '" + cluster.getName() + "': " + + e.getMessage(), e); + } return updateBuilder.build(); } @@ -173,6 +210,10 @@ private static StructOrError parseAggregateCluster(Cluster cl } catch (InvalidProtocolBufferException e) { return StructOrError.fromError("Cluster " + clusterName + ": malformed ClusterConfig: " + e); } + if (clusterConfig.getClustersList().isEmpty()) { + return StructOrError.fromError("Cluster " + clusterName + + ": aggregate ClusterConfig.clusters must not be empty"); + } return StructOrError.fromStruct(CdsUpdate.forAggregate( clusterName, clusterConfig.getClustersList())); } @@ -184,6 +225,13 @@ private static StructOrError parseNonAggregateCluster( Long maxConcurrentRequests = null; UpstreamTlsContext upstreamTlsContext = null; OutlierDetection outlierDetection = null; + boolean isHttp11ProxyAvailable = false; + BackendMetricPropagation backendMetricPropagation = null; + + if (isEnabledOrcaLrsPropagation) { + backendMetricPropagation = BackendMetricPropagation.fromMetricSpecs( + cluster.getLrsReportEndpointMetricsList()); + } if (cluster.hasLrsServer()) { if (!cluster.getLrsServer().hasSelf()) { return StructOrError.fromError( @@ -198,7 +246,7 @@ private static StructOrError parseNonAggregateCluster( continue; } if (threshold.hasMaxRequests()) { - maxConcurrentRequests = (long) threshold.getMaxRequests().getValue(); + maxConcurrentRequests = Integer.toUnsignedLong(threshold.getMaxRequests().getValue()); } } } @@ -206,17 +254,43 @@ private static StructOrError parseNonAggregateCluster( return StructOrError.fromError("Cluster " + clusterName + ": transport-socket-matches not supported."); } - if (cluster.hasTransportSocket()) { - if (!TRANSPORT_SOCKET_NAME_TLS.equals(cluster.getTransportSocket().getName())) { - return StructOrError.fromError("transport-socket with name " - + cluster.getTransportSocket().getName() + " not supported."); + boolean hasTransportSocket = cluster.hasTransportSocket(); + TransportSocket transportSocket = cluster.getTransportSocket(); + + if (hasTransportSocket && !TRANSPORT_SOCKET_NAME_TLS.equals(transportSocket.getName()) + && !(isEnabledXdsHttpConnect && transportSocket.getTypedConfig().is( + Http11ProxyUpstreamTransport.class))) { + return StructOrError.fromError( + "transport-socket with name " + transportSocket.getName() + " not supported."); + } + + if (hasTransportSocket && isEnabledXdsHttpConnect && transportSocket.getTypedConfig().is( + Http11ProxyUpstreamTransport.class)) { + isHttp11ProxyAvailable = true; + try { + Http11ProxyUpstreamTransport wrappedTransportSocket = transportSocket + .getTypedConfig().unpack(io.envoyproxy.envoy.extensions.transport_sockets + .http_11_proxy.v3.Http11ProxyUpstreamTransport.class); + hasTransportSocket = wrappedTransportSocket.hasTransportSocket(); + transportSocket = wrappedTransportSocket.getTransportSocket(); + } catch (InvalidProtocolBufferException e) { + return StructOrError.fromError( + "Cluster " + clusterName + ": malformed Http11ProxyUpstreamTransport: " + e); + } catch (ClassCastException e) { + return StructOrError.fromError( + "Cluster " + clusterName + + ": invalid transport_socket type in Http11ProxyUpstreamTransport"); } + } + + if (hasTransportSocket && TRANSPORT_SOCKET_NAME_TLS.equals(transportSocket.getName())) { try { upstreamTlsContext = UpstreamTlsContext.fromEnvoyProtoUpstreamTlsContext( validateUpstreamTlsContext( - unpackCompatibleType(cluster.getTransportSocket().getTypedConfig(), - io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext.class, - TYPE_URL_UPSTREAM_TLS_CONTEXT, TYPE_URL_UPSTREAM_TLS_CONTEXT_V2), + unpackCompatibleType(transportSocket.getTypedConfig(), + io.envoyproxy.envoy.extensions + .transport_sockets.tls.v3.UpstreamTlsContext.class, + TYPE_URL_UPSTREAM_TLS_CONTEXT, TYPE_URL_UPSTREAM_TLS_CONTEXT_V2), certProviderInstances)); } catch (InvalidProtocolBufferException | ResourceInvalidException e) { return StructOrError.fromError( @@ -250,13 +324,14 @@ private static StructOrError parseNonAggregateCluster( edsServiceName = edsClusterConfig.getServiceName(); } // edsServiceName is required if the CDS resource has an xdstp name. - if ((edsServiceName == null) && clusterName.toLowerCase().startsWith("xdstp:")) { + if ((edsServiceName == null) && clusterName.toLowerCase(Locale.ROOT).startsWith("xdstp:")) { return StructOrError.fromError( "EDS service_name must be set when Cluster resource has an xdstp name"); } + return StructOrError.fromStruct(CdsUpdate.forEds( clusterName, edsServiceName, lrsServerInfo, maxConcurrentRequests, upstreamTlsContext, - outlierDetection)); + outlierDetection, isHttp11ProxyAvailable, backendMetricPropagation)); } else if (type.equals(Cluster.DiscoveryType.LOGICAL_DNS)) { if (!cluster.hasLoadAssignment()) { return StructOrError.fromError( @@ -291,7 +366,8 @@ private static StructOrError parseNonAggregateCluster( String dnsHostName = String.format( Locale.US, "%s:%d", socketAddress.getAddress(), socketAddress.getPortValue()); return StructOrError.fromStruct(CdsUpdate.forLogicalDns( - clusterName, dnsHostName, lrsServerInfo, maxConcurrentRequests, upstreamTlsContext)); + clusterName, dnsHostName, lrsServerInfo, maxConcurrentRequests, + upstreamTlsContext, isHttp11ProxyAvailable, backendMetricPropagation)); } return StructOrError.fromError( "Cluster " + clusterName + ": unsupported built-in discovery type: " + type); @@ -386,15 +462,6 @@ static void validateCommonTlsContext( throw new ResourceInvalidException( "common-tls-context with validation_context_sds_secret_config is not supported"); } - if (commonTlsContext.hasValidationContextCertificateProvider()) { - throw new ResourceInvalidException( - "common-tls-context with validation_context_certificate_provider is not supported"); - } - if (commonTlsContext.hasValidationContextCertificateProviderInstance()) { - throw new ResourceInvalidException( - "common-tls-context with validation_context_certificate_provider_instance is not" - + " supported"); - } String certInstanceName = getIdentityCertInstanceName(commonTlsContext); if (certInstanceName == null) { if (server) { @@ -409,10 +476,6 @@ static void validateCommonTlsContext( throw new ResourceInvalidException( "tls_certificate_provider_instance is unset"); } - if (commonTlsContext.hasTlsCertificateCertificateProvider()) { - throw new ResourceInvalidException( - "tls_certificate_provider_instance is unset"); - } } else if (certProviderInstances == null || !certProviderInstances.contains(certInstanceName)) { throw new ResourceInvalidException( "CertificateProvider instance name '" + certInstanceName @@ -420,9 +483,11 @@ static void validateCommonTlsContext( } String rootCaInstanceName = getRootCertInstanceName(commonTlsContext); if (rootCaInstanceName == null) { - if (!server) { + if (!server && (!enableSystemRootCerts + || !CommonTlsContextUtil.isUsingSystemRootCerts(commonTlsContext))) { throw new ResourceInvalidException( - "ca_certificate_provider_instance is required in upstream-tls-context"); + "ca_certificate_provider_instance or system_root_certs is required in " + + "upstream-tls-context"); } } else { if (certProviderInstances == null || !certProviderInstances.contains(rootCaInstanceName)) { @@ -439,7 +504,9 @@ static void validateCommonTlsContext( .getDefaultValidationContext(); } if (certificateValidationContext != null) { - if (certificateValidationContext.getMatchSubjectAltNamesCount() > 0 && server) { + @SuppressWarnings("deprecation") // gRFC A29 predates match_typed_subject_alt_names + int matchSubjectAltNamesCount = certificateValidationContext.getMatchSubjectAltNamesCount(); + if (matchSubjectAltNamesCount > 0 && server) { throw new ResourceInvalidException( "match_subject_alt_names only allowed in upstream_tls_context"); } @@ -470,10 +537,13 @@ static void validateCommonTlsContext( private static String getIdentityCertInstanceName(CommonTlsContext commonTlsContext) { if (commonTlsContext.hasTlsCertificateProviderInstance()) { return commonTlsContext.getTlsCertificateProviderInstance().getInstanceName(); - } else if (commonTlsContext.hasTlsCertificateCertificateProviderInstance()) { - return commonTlsContext.getTlsCertificateCertificateProviderInstance().getInstanceName(); } - return null; + // Fall back to deprecated field (field 11) for backward compatibility with Istio + @SuppressWarnings("deprecation") + String instanceName = commonTlsContext.hasTlsCertificateCertificateProviderInstance() + ? commonTlsContext.getTlsCertificateCertificateProviderInstance().getInstanceName() + : null; + return instanceName; } private static String getRootCertInstanceName(CommonTlsContext commonTlsContext) { @@ -490,10 +560,16 @@ private static String getRootCertInstanceName(CommonTlsContext commonTlsContext) .hasCaCertificateProviderInstance()) { return combinedCertificateValidationContext.getDefaultValidationContext() .getCaCertificateProviderInstance().getInstanceName(); - } else if (combinedCertificateValidationContext - .hasValidationContextCertificateProviderInstance()) { - return combinedCertificateValidationContext - .getValidationContextCertificateProviderInstance().getInstanceName(); + } + // Fall back to deprecated field (field 4) in CombinedValidationContext + @SuppressWarnings("deprecation") + String instanceName = combinedCertificateValidationContext + .hasValidationContextCertificateProviderInstance() + ? combinedCertificateValidationContext.getValidationContextCertificateProviderInstance() + .getInstanceName() + : null; + if (instanceName != null) { + return instanceName; } } return null; @@ -543,6 +619,8 @@ abstract static class CdsUpdate implements ResourceUpdate { @Nullable abstract UpstreamTlsContext upstreamTlsContext(); + abstract boolean isHttp11ProxyAvailable(); + // List of underlying clusters making of this aggregate cluster. // Only valid for AGGREGATE cluster. @Nullable @@ -552,48 +630,63 @@ abstract static class CdsUpdate implements ResourceUpdate { @Nullable abstract OutlierDetection outlierDetection(); - static Builder forAggregate(String clusterName, List prioritizedClusterNames) { - checkNotNull(prioritizedClusterNames, "prioritizedClusterNames"); + abstract ImmutableMap filterMetadata(); + + abstract ImmutableMap parsedMetadata(); + + @Nullable + abstract BackendMetricPropagation backendMetricPropagation(); + + private static Builder newBuilder(String clusterName) { return new AutoValue_XdsClusterResource_CdsUpdate.Builder() .clusterName(clusterName) - .clusterType(ClusterType.AGGREGATE) .minRingSize(0) .maxRingSize(0) .choiceCount(0) + .filterMetadata(ImmutableMap.of()) + .parsedMetadata(ImmutableMap.of()) + .isHttp11ProxyAvailable(false) + .backendMetricPropagation(null); + } + + static Builder forAggregate(String clusterName, List prioritizedClusterNames) { + checkNotNull(prioritizedClusterNames, "prioritizedClusterNames"); + return newBuilder(clusterName) + .clusterType(ClusterType.AGGREGATE) .prioritizedClusterNames(ImmutableList.copyOf(prioritizedClusterNames)); } static Builder forEds(String clusterName, @Nullable String edsServiceName, @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests, @Nullable UpstreamTlsContext upstreamTlsContext, - @Nullable OutlierDetection outlierDetection) { - return new AutoValue_XdsClusterResource_CdsUpdate.Builder() - .clusterName(clusterName) + @Nullable OutlierDetection outlierDetection, + boolean isHttp11ProxyAvailable, + BackendMetricPropagation backendMetricPropagation) { + return newBuilder(clusterName) .clusterType(ClusterType.EDS) - .minRingSize(0) - .maxRingSize(0) - .choiceCount(0) .edsServiceName(edsServiceName) .lrsServerInfo(lrsServerInfo) .maxConcurrentRequests(maxConcurrentRequests) .upstreamTlsContext(upstreamTlsContext) - .outlierDetection(outlierDetection); + .outlierDetection(outlierDetection) + .isHttp11ProxyAvailable(isHttp11ProxyAvailable) + .backendMetricPropagation(backendMetricPropagation); } static Builder forLogicalDns(String clusterName, String dnsHostName, @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests, - @Nullable UpstreamTlsContext upstreamTlsContext) { - return new AutoValue_XdsClusterResource_CdsUpdate.Builder() - .clusterName(clusterName) + @Nullable UpstreamTlsContext upstreamTlsContext, + boolean isHttp11ProxyAvailable, + BackendMetricPropagation backendMetricPropagation) { + return newBuilder(clusterName) .clusterType(ClusterType.LOGICAL_DNS) - .minRingSize(0) - .maxRingSize(0) - .choiceCount(0) .dnsHostName(dnsHostName) .lrsServerInfo(lrsServerInfo) .maxConcurrentRequests(maxConcurrentRequests) - .upstreamTlsContext(upstreamTlsContext); + .upstreamTlsContext(upstreamTlsContext) + .isHttp11ProxyAvailable(isHttp11ProxyAvailable) + .backendMetricPropagation(backendMetricPropagation); } enum ClusterType { @@ -670,6 +763,8 @@ Builder leastRequestLbPolicy(Integer choiceCount) { // Private, use one of the static factory methods instead. protected abstract Builder maxConcurrentRequests(Long maxConcurrentRequests); + protected abstract Builder isHttp11ProxyAvailable(boolean isHttp11ProxyAvailable); + // Private, use one of the static factory methods instead. protected abstract Builder upstreamTlsContext(UpstreamTlsContext upstreamTlsContext); @@ -678,6 +773,13 @@ Builder leastRequestLbPolicy(Integer choiceCount) { protected abstract Builder outlierDetection(OutlierDetection outlierDetection); + protected abstract Builder filterMetadata(ImmutableMap filterMetadata); + + protected abstract Builder parsedMetadata(ImmutableMap parsedMetadata); + + protected abstract Builder backendMetricPropagation( + BackendMetricPropagation backendMetricPropagation); + abstract CdsUpdate build(); } } diff --git a/xds/src/main/java/io/grpc/xds/XdsConfig.java b/xds/src/main/java/io/grpc/xds/XdsConfig.java new file mode 100644 index 00000000000..d184f08de55 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/XdsConfig.java @@ -0,0 +1,265 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.grpc.StatusOr; +import io.grpc.xds.XdsClusterResource.CdsUpdate; +import io.grpc.xds.XdsEndpointResource.EdsUpdate; +import io.grpc.xds.XdsListenerResource.LdsUpdate; +import io.grpc.xds.XdsRouteConfigureResource.RdsUpdate; +import java.io.Closeable; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +/** + * Represents the xDS configuration tree for a specified Listener. + */ +final class XdsConfig { + private final LdsUpdate listener; + private final RdsUpdate route; + private final VirtualHost virtualHost; + private final ImmutableMap> clusters; + private final int hashCode; + + XdsConfig(LdsUpdate listener, RdsUpdate route, Map> clusters, + VirtualHost virtualHost) { + this(listener, route, virtualHost, ImmutableMap.copyOf(clusters)); + } + + public XdsConfig(LdsUpdate listener, RdsUpdate route, VirtualHost virtualHost, + ImmutableMap> clusters) { + this.listener = listener; + this.route = route; + this.virtualHost = virtualHost; + this.clusters = clusters; + + hashCode = Objects.hash(listener, route, virtualHost, clusters); + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof XdsConfig)) { + return false; + } + + XdsConfig o = (XdsConfig) obj; + + return hashCode() == o.hashCode() && Objects.equals(listener, o.listener) + && Objects.equals(route, o.route) && Objects.equals(virtualHost, o.virtualHost) + && Objects.equals(clusters, o.clusters); + } + + @Override + public int hashCode() { + return hashCode; + } + + @Override + public String toString() { + StringBuilder builder = new StringBuilder(); + builder.append("XdsConfig{") + .append("\n listener=").append(listener) + .append(",\n route=").append(route) + .append(",\n virtualHost=").append(virtualHost) + .append(",\n clusters=").append(clusters) + .append("\n}"); + return builder.toString(); + } + + public LdsUpdate getListener() { + return listener; + } + + public RdsUpdate getRoute() { + return route; + } + + public VirtualHost getVirtualHost() { + return virtualHost; + } + + public ImmutableMap> getClusters() { + return clusters; + } + + static final class XdsClusterConfig { + private final String clusterName; + private final CdsUpdate clusterResource; + private final ClusterChild children; // holds details + + XdsClusterConfig(String clusterName, CdsUpdate clusterResource, ClusterChild details) { + this.clusterName = checkNotNull(clusterName, "clusterName"); + this.clusterResource = checkNotNull(clusterResource, "clusterResource"); + this.children = checkNotNull(details, "details"); + } + + @Override + public int hashCode() { + return clusterName.hashCode() + clusterResource.hashCode() + children.hashCode(); + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof XdsClusterConfig)) { + return false; + } + XdsClusterConfig o = (XdsClusterConfig) obj; + return Objects.equals(clusterName, o.clusterName) + && Objects.equals(clusterResource, o.clusterResource) + && Objects.equals(children, o.children); + } + + @Override + public String toString() { + StringBuilder builder = new StringBuilder(); + builder.append("XdsClusterConfig{clusterName=").append(clusterName) + .append(", clusterResource=").append(clusterResource) + .append(", children={").append(children) + .append("}"); + return builder.toString(); + } + + public String getClusterName() { + return clusterName; + } + + public CdsUpdate getClusterResource() { + return clusterResource; + } + + public ClusterChild getChildren() { + return children; + } + + interface ClusterChild {} + + /** Endpoint info for EDS and LOGICAL_DNS clusters. If there was an + * error, endpoints will be null and resolution_note will be set. + */ + static final class EndpointConfig implements ClusterChild { + private final StatusOr endpoint; + + public EndpointConfig(StatusOr endpoint) { + this.endpoint = checkNotNull(endpoint, "endpoint"); + } + + @Override + public int hashCode() { + return endpoint.hashCode(); + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof EndpointConfig)) { + return false; + } + return Objects.equals(endpoint, ((EndpointConfig)obj).endpoint); + } + + public StatusOr getEndpoint() { + return endpoint; + } + + @Override + public String toString() { + if (endpoint.hasValue()) { + return "EndpointConfig{endpoint=" + endpoint.getValue() + "}"; + } else { + return "EndpointConfig{error=" + endpoint.getStatus() + "}"; + } + } + } + + // The list of leaf clusters for an aggregate cluster. + static final class AggregateConfig implements ClusterChild { + private final List leafNames; + + public AggregateConfig(List leafNames) { + this.leafNames = ImmutableList.copyOf(checkNotNull(leafNames, "leafNames")); + } + + public List getLeafNames() { + return leafNames; + } + + @Override + public int hashCode() { + return leafNames.hashCode(); + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof AggregateConfig)) { + return false; + } + return Objects.equals(leafNames, ((AggregateConfig) obj).leafNames); + } + } + } + + static final class XdsConfigBuilder { + private LdsUpdate listener; + private RdsUpdate route; + private Map> clusters = new HashMap<>(); + private VirtualHost virtualHost; + + XdsConfigBuilder setListener(LdsUpdate listener) { + this.listener = checkNotNull(listener, "listener"); + return this; + } + + XdsConfigBuilder setRoute(RdsUpdate route) { + this.route = checkNotNull(route, "route"); + return this; + } + + XdsConfigBuilder addCluster(String name, StatusOr clusterConfig) { + checkNotNull(name, "name"); + checkNotNull(clusterConfig, "clusterConfig"); + clusters.put(name, clusterConfig); + return this; + } + + XdsConfigBuilder setVirtualHost(VirtualHost virtualHost) { + this.virtualHost = checkNotNull(virtualHost, "virtualHost"); + return this; + } + + XdsConfig build() { + checkNotNull(listener, "listener"); + checkNotNull(route, "route"); + checkNotNull(virtualHost, "virtualHost"); + return new XdsConfig(listener, route, clusters, virtualHost); + } + } + + public interface XdsClusterSubscriptionRegistry { + Subscription subscribeToCluster(String clusterName); + } + + public interface Subscription extends Closeable { + /** Release resources without throwing exceptions. */ + @Override + void close(); + } +} diff --git a/xds/src/main/java/io/grpc/xds/XdsCredentialsRegistry.java b/xds/src/main/java/io/grpc/xds/XdsCredentialsRegistry.java index c33b3cd2f85..9dd77a400cd 100644 --- a/xds/src/main/java/io/grpc/xds/XdsCredentialsRegistry.java +++ b/xds/src/main/java/io/grpc/xds/XdsCredentialsRegistry.java @@ -21,6 +21,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.grpc.InternalServiceProviders; import java.util.ArrayList; import java.util.Collections; @@ -28,10 +29,10 @@ import java.util.LinkedHashSet; import java.util.List; import java.util.Map; +import java.util.ServiceLoader; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; /** @@ -109,8 +110,10 @@ public static synchronized XdsCredentialsRegistry getDefaultRegistry() { if (instance == null) { List providerList = InternalServiceProviders.loadAll( XdsCredentialsProvider.class, - getHardCodedClasses(), - XdsCredentialsProvider.class.getClassLoader(), + ServiceLoader + .load(XdsCredentialsProvider.class, XdsCredentialsProvider.class.getClassLoader()) + .iterator(), + XdsCredentialsRegistry::getHardCodedClasses, new XdsCredentialsProviderPriorityAccessor()); if (providerList.isEmpty()) { logger.warning("No XdsCredsRegistry found via ServiceLoader, including for GoogleDefault, " diff --git a/xds/src/main/java/io/grpc/xds/XdsDependencyManager.java b/xds/src/main/java/io/grpc/xds/XdsDependencyManager.java new file mode 100644 index 00000000000..a0af5974175 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/XdsDependencyManager.java @@ -0,0 +1,947 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; +import static io.grpc.xds.client.XdsClient.ResourceUpdate; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.grpc.EquivalentAddressGroup; +import io.grpc.NameResolver; +import io.grpc.NameResolverProvider; +import io.grpc.Status; +import io.grpc.StatusOr; +import io.grpc.SynchronizationContext; +import io.grpc.internal.GrpcUtil; +import io.grpc.internal.RetryingNameResolver; +import io.grpc.xds.Endpoints.LocalityLbEndpoints; +import io.grpc.xds.VirtualHost.Route.RouteAction.ClusterWeight; +import io.grpc.xds.XdsClusterResource.CdsUpdate.ClusterType; +import io.grpc.xds.XdsConfig.XdsClusterConfig.AggregateConfig; +import io.grpc.xds.XdsConfig.XdsClusterConfig.EndpointConfig; +import io.grpc.xds.XdsRouteConfigureResource.RdsUpdate; +import io.grpc.xds.client.Locality; +import io.grpc.xds.client.XdsClient; +import io.grpc.xds.client.XdsClient.ResourceWatcher; +import io.grpc.xds.client.XdsResourceType; +import java.net.SocketAddress; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.EnumMap; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import javax.annotation.Nullable; + +/** + * This class acts as a layer of indirection between the XdsClient and the NameResolver. It + * maintains the watchers for the xds resources and when an update is received, it either requests + * referenced resources or updates the XdsConfig and notifies the XdsConfigWatcher. Each instance + * applies to a single data plane authority. + */ +final class XdsDependencyManager implements XdsConfig.XdsClusterSubscriptionRegistry { + private enum TrackedWatcherTypeEnum { + LDS, RDS, CDS, EDS, DNS + } + + private static final TrackedWatcherType LDS_TYPE = + new TrackedWatcherType<>(TrackedWatcherTypeEnum.LDS); + private static final TrackedWatcherType RDS_TYPE = + new TrackedWatcherType<>(TrackedWatcherTypeEnum.RDS); + private static final TrackedWatcherType CDS_TYPE = + new TrackedWatcherType<>(TrackedWatcherTypeEnum.CDS); + private static final TrackedWatcherType EDS_TYPE = + new TrackedWatcherType<>(TrackedWatcherTypeEnum.EDS); + private static final TrackedWatcherType> DNS_TYPE = + new TrackedWatcherType<>(TrackedWatcherTypeEnum.DNS); + + // DNS-resolved endpoints do not have the definition of the locality it belongs to, just hardcode + // to an empty locality. + private static final Locality LOGICAL_DNS_CLUSTER_LOCALITY = Locality.create("", "", ""); + + private static final int MAX_CLUSTER_RECURSION_DEPTH = 16; // Specified by gRFC A37 + + static boolean enableLogicalDns = true; + + private final String listenerName; + private final XdsClient xdsClient; + private final SynchronizationContext syncContext; + private final String dataPlaneAuthority; + private final NameResolver.Args nameResolverArgs; + private XdsConfigWatcher xdsConfigWatcher; + + private StatusOr lastUpdate = null; + private final Map> resourceWatchers = + new EnumMap<>(TrackedWatcherTypeEnum.class); + private final Set subscriptions = new HashSet<>(); + + XdsDependencyManager( + XdsClient xdsClient, + SynchronizationContext syncContext, + String dataPlaneAuthority, + String listenerName, + NameResolver.Args nameResolverArgs) { + this.listenerName = checkNotNull(listenerName, "listenerName"); + this.xdsClient = checkNotNull(xdsClient, "xdsClient"); + this.syncContext = checkNotNull(syncContext, "syncContext"); + this.dataPlaneAuthority = checkNotNull(dataPlaneAuthority, "dataPlaneAuthority"); + this.nameResolverArgs = checkNotNull(nameResolverArgs, "nameResolverArgs"); + } + + public static String toContextStr(String typeName, String resourceName) { + return typeName + " resource " + resourceName; + } + + public void start(XdsConfigWatcher xdsConfigWatcher) { + checkState(this.xdsConfigWatcher == null, "dep manager may not be restarted"); + this.xdsConfigWatcher = checkNotNull(xdsConfigWatcher, "xdsConfigWatcher"); + // start the ball rolling + syncContext.execute(() -> addWatcher(LDS_TYPE, new LdsWatcher(listenerName))); + } + + @Override + public XdsConfig.Subscription subscribeToCluster(String clusterName) { + checkState(this.xdsConfigWatcher != null, "dep manager must first be started"); + checkNotNull(clusterName, "clusterName"); + ClusterSubscription subscription = new ClusterSubscription(clusterName); + + syncContext.execute(() -> { + if (getWatchers(LDS_TYPE).isEmpty()) { + subscription.closed = true; + return; // shutdown() called + } + subscriptions.add(subscription); + addClusterWatcher(clusterName); + }); + + return subscription; + } + + /** + * For all logical dns clusters refresh their results. + */ + public void requestReresolution() { + syncContext.execute(() -> { + for (TrackedWatcher> watcher : getWatchers(DNS_TYPE).values()) { + DnsWatcher dnsWatcher = (DnsWatcher) watcher; + dnsWatcher.refresh(); + } + }); + } + + private void addWatcher( + TrackedWatcherType watcherType, XdsWatcherBase watcher) { + syncContext.throwIfNotInThisSynchronizationContext(); + XdsResourceType type = watcher.type; + String resourceName = watcher.resourceName; + + getWatchers(watcherType).put(resourceName, watcher); + xdsClient.watchXdsResource(type, resourceName, watcher, syncContext); + } + + public void shutdown() { + syncContext.execute(() -> { + for (TypeWatchers watchers : resourceWatchers.values()) { + for (TrackedWatcher watcher : watchers.watchers.values()) { + watcher.close(); + } + } + resourceWatchers.clear(); + subscriptions.clear(); + }); + } + + private void releaseSubscription(ClusterSubscription subscription) { + checkNotNull(subscription, "subscription"); + syncContext.execute(() -> { + if (subscription.closed) { + return; + } + subscription.closed = true; + if (!subscriptions.remove(subscription)) { + return; // shutdown() called + } + maybePublishConfig(); + }); + } + + /** + * Check if all resources have results, and if so, generate a new XdsConfig and send it to all + * the watchers. + */ + private void maybePublishConfig() { + syncContext.throwIfNotInThisSynchronizationContext(); + if (getWatchers(LDS_TYPE).isEmpty()) { + return; // shutdown() called + } + boolean waitingOnResource = resourceWatchers.values().stream() + .flatMap(typeWatchers -> typeWatchers.watchers.values().stream()) + .anyMatch(TrackedWatcher::missingResult); + if (waitingOnResource) { + return; + } + + StatusOr newUpdate = buildUpdate(); + if (Objects.equals(newUpdate, lastUpdate)) { + return; + } + assert newUpdate.hasValue() + || (newUpdate.getStatus().getCode() == Status.Code.UNAVAILABLE + || newUpdate.getStatus().getCode() == Status.Code.INTERNAL); + lastUpdate = newUpdate; + xdsConfigWatcher.onUpdate(lastUpdate); + } + + @VisibleForTesting + StatusOr buildUpdate() { + // Create a config and discard any watchers not accessed + WatcherTracer tracer = new WatcherTracer(resourceWatchers); + StatusOr config = buildUpdate( + tracer, listenerName, dataPlaneAuthority, subscriptions); + tracer.closeUnusedWatchers(); + return config; + } + + private static StatusOr buildUpdate( + WatcherTracer tracer, + String listenerName, + String dataPlaneAuthority, + Set subscriptions) { + XdsConfig.XdsConfigBuilder builder = new XdsConfig.XdsConfigBuilder(); + + // Iterate watchers and build the XdsConfig + + TrackedWatcher ldsWatcher + = tracer.getWatcher(LDS_TYPE, listenerName); + if (ldsWatcher == null) { + return StatusOr.fromStatus(Status.UNAVAILABLE.withDescription( + "Bug: No listener watcher found for " + listenerName)); + } + if (!ldsWatcher.getData().hasValue()) { + return StatusOr.fromStatus(ldsWatcher.getData().getStatus()); + } + XdsListenerResource.LdsUpdate ldsUpdate = ldsWatcher.getData().getValue(); + builder.setListener(ldsUpdate); + + RdsUpdateSupplier routeSource = ((LdsWatcher) ldsWatcher).getRouteSource(tracer); + if (routeSource == null) { + return StatusOr.fromStatus(Status.UNAVAILABLE.withDescription( + "Bug: No route source found for listener " + dataPlaneAuthority)); + } + StatusOr statusOrRdsUpdate = routeSource.getRdsUpdate(); + if (!statusOrRdsUpdate.hasValue()) { + return StatusOr.fromStatus(statusOrRdsUpdate.getStatus()); + } + RdsUpdate rdsUpdate = statusOrRdsUpdate.getValue(); + builder.setRoute(rdsUpdate); + + VirtualHost activeVirtualHost = + RoutingUtils.findVirtualHostForHostName(rdsUpdate.virtualHosts, dataPlaneAuthority); + if (activeVirtualHost == null) { + String error = "Failed to find virtual host matching hostname: " + dataPlaneAuthority; + return StatusOr.fromStatus(Status.UNAVAILABLE.withDescription(error)); + } + builder.setVirtualHost(activeVirtualHost); + + Map> clusters = new HashMap<>(); + LinkedHashSet ancestors = new LinkedHashSet<>(); + for (String cluster : getClusterNamesFromVirtualHost(activeVirtualHost)) { + addConfigForCluster(clusters, cluster, ancestors, tracer); + } + for (ClusterSubscription subscription : subscriptions) { + addConfigForCluster(clusters, subscription.getClusterName(), ancestors, tracer); + } + for (Map.Entry> me : clusters.entrySet()) { + builder.addCluster(me.getKey(), me.getValue()); + } + + return StatusOr.fromValue(builder.build()); + } + + private Map> getWatchers(TrackedWatcherType watcherType) { + TypeWatchers typeWatchers = resourceWatchers.get(watcherType.typeEnum); + if (typeWatchers == null) { + typeWatchers = new TypeWatchers(watcherType); + resourceWatchers.put(watcherType.typeEnum, typeWatchers); + } + assert typeWatchers.watcherType == watcherType; + @SuppressWarnings("unchecked") + TypeWatchers tTypeWatchers = (TypeWatchers) typeWatchers; + return tTypeWatchers.watchers; + } + + private static void addConfigForCluster( + Map> clusters, + String clusterName, + @SuppressWarnings("NonApiType") // Need order-preserving set for errors + LinkedHashSet ancestors, + WatcherTracer tracer) { + if (clusters.containsKey(clusterName)) { + return; + } + if (ancestors.contains(clusterName)) { + clusters.put(clusterName, StatusOr.fromStatus( + Status.INTERNAL.withDescription( + "Aggregate cluster cycle detected: " + ancestors))); + return; + } + if (ancestors.size() > MAX_CLUSTER_RECURSION_DEPTH) { + clusters.put(clusterName, StatusOr.fromStatus( + Status.INTERNAL.withDescription("Recursion limit reached: " + ancestors))); + return; + } + + CdsWatcher cdsWatcher = (CdsWatcher) tracer.getWatcher(CDS_TYPE, clusterName); + StatusOr cdsWatcherDataOr = cdsWatcher.getData(); + if (!cdsWatcherDataOr.hasValue()) { + clusters.put(clusterName, StatusOr.fromStatus(cdsWatcherDataOr.getStatus())); + return; + } + + XdsClusterResource.CdsUpdate cdsUpdate = cdsWatcherDataOr.getValue(); + XdsConfig.XdsClusterConfig.ClusterChild child; + switch (cdsUpdate.clusterType()) { + case AGGREGATE: + // Re-inserting a present element into a LinkedHashSet does not reorder the entries, so it + // preserves the priority across all aggregate clusters + LinkedHashSet leafNames = new LinkedHashSet(); + ancestors.add(clusterName); + for (String childCluster : cdsUpdate.prioritizedClusterNames()) { + addConfigForCluster(clusters, childCluster, ancestors, tracer); + StatusOr config = clusters.get(childCluster); + if (!config.hasValue()) { + // gRFC A37 says: If any of a CDS policy's watchers reports that the resource does not + // exist the policy should report that it is in TRANSIENT_FAILURE. If any of the + // watchers reports a transient ADS stream error, the policy should report that it is in + // TRANSIENT_FAILURE if it has never passed a config to its child. + // + // But there's currently disagreement about whether that is actually what we want, and + // that was not originally implemented in gRPC Java. So we're keeping Java's old + // behavior for now and only failing the "leaves" (which is a bit arbitrary for a + // cycle). + leafNames.add(childCluster); + continue; + } + XdsConfig.XdsClusterConfig.ClusterChild children = config.getValue().getChildren(); + if (children instanceof AggregateConfig) { + leafNames.addAll(((AggregateConfig) children).getLeafNames()); + } else { + leafNames.add(childCluster); + } + } + ancestors.remove(clusterName); + + child = new AggregateConfig(ImmutableList.copyOf(leafNames)); + break; + case EDS: + TrackedWatcher edsWatcher = + tracer.getWatcher(EDS_TYPE, cdsWatcher.getEdsServiceName()); + if (edsWatcher != null) { + child = new EndpointConfig(edsWatcher.getData()); + } else { + child = new EndpointConfig(StatusOr.fromStatus(Status.INTERNAL.withDescription( + "EDS resource not found for cluster " + clusterName))); + } + break; + case LOGICAL_DNS: + if (enableLogicalDns) { + TrackedWatcher> dnsWatcher = + tracer.getWatcher(DNS_TYPE, cdsUpdate.dnsHostName()); + child = new EndpointConfig(dnsToEdsUpdate(dnsWatcher.getData(), cdsUpdate.dnsHostName())); + } else { + child = new EndpointConfig(StatusOr.fromStatus( + Status.INTERNAL.withDescription("Logical DNS in dependency manager unsupported"))); + } + break; + default: + child = new EndpointConfig(StatusOr.fromStatus(Status.UNAVAILABLE.withDescription( + "Unknown type in cluster " + clusterName + " " + cdsUpdate.clusterType()))); + } + if (clusters.containsKey(clusterName)) { + // If a cycle is detected, we'll have detected it while recursing, so now there will be a key + // present. We don't want to overwrite it with a non-error value. + return; + } + clusters.put(clusterName, StatusOr.fromValue( + new XdsConfig.XdsClusterConfig(clusterName, cdsUpdate, child))); + } + + private static StatusOr dnsToEdsUpdate( + StatusOr> dnsData, String dnsHostName) { + if (!dnsData.hasValue()) { + return StatusOr.fromStatus(dnsData.getStatus()); + } + + List addresses = new ArrayList<>(); + for (EquivalentAddressGroup eag : dnsData.getValue()) { + addresses.addAll(eag.getAddresses()); + } + EquivalentAddressGroup eag = new EquivalentAddressGroup(addresses); + List endpoints = ImmutableList.of( + Endpoints.LbEndpoint.create(eag, 1, true, dnsHostName, ImmutableMap.of())); + LocalityLbEndpoints lbEndpoints = + LocalityLbEndpoints.create(endpoints, 1, 0, ImmutableMap.of()); + return StatusOr.fromValue(new XdsEndpointResource.EdsUpdate( + "fakeEds_logicalDns", + Collections.singletonMap(LOGICAL_DNS_CLUSTER_LOCALITY, lbEndpoints), + new ArrayList<>())); + } + + private void addRdsWatcher(String resourceName) { + if (getWatchers(RDS_TYPE).containsKey(resourceName)) { + return; + } + + addWatcher(RDS_TYPE, new RdsWatcher(resourceName)); + } + + private void addEdsWatcher(String edsServiceName) { + if (getWatchers(EDS_TYPE).containsKey(edsServiceName)) { + return; + } + + addWatcher(EDS_TYPE, new EdsWatcher(edsServiceName)); + } + + private void addClusterWatcher(String clusterName) { + if (getWatchers(CDS_TYPE).containsKey(clusterName)) { + return; + } + + addWatcher(CDS_TYPE, new CdsWatcher(clusterName)); + } + + private void addDnsWatcher(String dnsHostName) { + syncContext.throwIfNotInThisSynchronizationContext(); + if (getWatchers(DNS_TYPE).containsKey(dnsHostName)) { + return; + } + + DnsWatcher watcher = new DnsWatcher(dnsHostName, nameResolverArgs); + getWatchers(DNS_TYPE).put(dnsHostName, watcher); + watcher.start(); + } + + private void updateRoutes(List virtualHosts) { + VirtualHost virtualHost = + RoutingUtils.findVirtualHostForHostName(virtualHosts, dataPlaneAuthority); + Set newClusters = getClusterNamesFromVirtualHost(virtualHost); + newClusters.forEach((cluster) -> addClusterWatcher(cluster)); + } + + private String nodeInfo() { + return " nodeID: " + xdsClient.getBootstrapInfo().node().getId(); + } + + private static Set getClusterNamesFromVirtualHost(VirtualHost virtualHost) { + if (virtualHost == null) { + return Collections.emptySet(); + } + + // Get all cluster names to which requests can be routed through the virtual host. + Set clusters = new HashSet<>(); + for (VirtualHost.Route route : virtualHost.routes()) { + VirtualHost.Route.RouteAction action = route.routeAction(); + if (action == null) { + continue; + } + if (action.cluster() != null) { + clusters.add(action.cluster()); + } else if (action.weightedClusters() != null) { + for (ClusterWeight weighedCluster : action.weightedClusters()) { + clusters.add(weighedCluster.name()); + } + } + } + + return clusters; + } + + private static NameResolver createNameResolver( + String dnsHostName, + NameResolver.Args nameResolverArgs) { + URI uri; + try { + uri = new URI("dns", "", "/" + dnsHostName, null); + } catch (URISyntaxException e) { + return new FailingNameResolver( + Status.INTERNAL.withDescription("Bug, invalid URI creation: " + dnsHostName) + .withCause(e)); + } + + NameResolverProvider provider = + nameResolverArgs.getNameResolverRegistry().getProviderForScheme("dns"); + if (provider == null) { + return new FailingNameResolver( + Status.INTERNAL.withDescription("Could not find dns name resolver")); + } + + NameResolver bareResolver = provider.newNameResolver(uri, nameResolverArgs); + if (bareResolver == null) { + return new FailingNameResolver( + Status.INTERNAL.withDescription("DNS name resolver provider returned null: " + uri)); + } + return RetryingNameResolver.wrap(bareResolver, nameResolverArgs); + } + + private static class TypeWatchers { + // Key is resource name + final Map> watchers = new HashMap<>(); + final TrackedWatcherType watcherType; + + TypeWatchers(TrackedWatcherType watcherType) { + this.watcherType = checkNotNull(watcherType, "watcherType"); + } + } + + public interface XdsConfigWatcher { + /** + * An updated XdsConfig or RPC-safe Status. The status code will be either UNAVAILABLE or + * INTERNAL. + */ + void onUpdate(StatusOr config); + } + + private final class ClusterSubscription implements XdsConfig.Subscription { + private final String clusterName; + boolean closed; // Accessed from syncContext + + public ClusterSubscription(String clusterName) { + this.clusterName = checkNotNull(clusterName, "clusterName"); + } + + String getClusterName() { + return clusterName; + } + + @Override + public void close() { + releaseSubscription(this); + } + } + + /** State for tracing garbage collector. */ + private static final class WatcherTracer { + private final Map> resourceWatchers; + private final Map> usedWatchers; + + public WatcherTracer(Map> resourceWatchers) { + this.resourceWatchers = resourceWatchers; + + this.usedWatchers = new EnumMap<>(TrackedWatcherTypeEnum.class); + for (Map.Entry> me : resourceWatchers.entrySet()) { + usedWatchers.put(me.getKey(), newTypeWatchers(me.getValue().watcherType)); + } + } + + private static TypeWatchers newTypeWatchers(TrackedWatcherType type) { + return new TypeWatchers(type); + } + + public TrackedWatcher getWatcher(TrackedWatcherType watcherType, String name) { + TypeWatchers typeWatchers = resourceWatchers.get(watcherType.typeEnum); + if (typeWatchers == null) { + return null; + } + assert typeWatchers.watcherType == watcherType; + @SuppressWarnings("unchecked") + TypeWatchers tTypeWatchers = (TypeWatchers) typeWatchers; + TrackedWatcher watcher = tTypeWatchers.watchers.get(name); + if (watcher == null) { + return null; + } + @SuppressWarnings("unchecked") + TypeWatchers usedTypeWatchers = (TypeWatchers) usedWatchers.get(watcherType.typeEnum); + usedTypeWatchers.watchers.put(name, watcher); + return watcher; + } + + /** Shut down unused watchers. */ + public void closeUnusedWatchers() { + boolean changed = false; // Help out the GC by preferring old objects + for (TrackedWatcherTypeEnum key : resourceWatchers.keySet()) { + TypeWatchers orig = resourceWatchers.get(key); + TypeWatchers used = usedWatchers.get(key); + for (String name : orig.watchers.keySet()) { + if (used.watchers.containsKey(name)) { + continue; + } + orig.watchers.get(name).close(); + changed = true; + } + } + if (changed) { + resourceWatchers.putAll(usedWatchers); + } + } + } + + @SuppressWarnings("UnusedTypeParameter") + private static final class TrackedWatcherType { + public final TrackedWatcherTypeEnum typeEnum; + + public TrackedWatcherType(TrackedWatcherTypeEnum typeEnum) { + this.typeEnum = checkNotNull(typeEnum, "typeEnum"); + } + } + + private interface TrackedWatcher { + @Nullable + StatusOr getData(); + + default boolean missingResult() { + return getData() == null; + } + + default boolean hasDataValue() { + StatusOr data = getData(); + return data != null && data.hasValue(); + } + + void close(); + } + + private abstract class XdsWatcherBase + implements ResourceWatcher, TrackedWatcher { + private final XdsResourceType type; + private final String resourceName; + boolean cancelled; + + @Nullable + private StatusOr data; + @Nullable + @SuppressWarnings("unused") + private Status ambientError; + + + private XdsWatcherBase(XdsResourceType type, String resourceName) { + this.type = checkNotNull(type, "type"); + this.resourceName = checkNotNull(resourceName, "resourceName"); + } + + @Override + public void onResourceChanged(StatusOr update) { + if (cancelled) { + return; + } + ambientError = null; + if (update.hasValue()) { + data = update; + subscribeToChildren(update.getValue()); + } else { + Status translatedStatus = GrpcUtil.statusWithDetails( + Status.Code.UNAVAILABLE, + "Error retrieving " + toContextString() + nodeInfo(), + update.getStatus()); + + data = StatusOr.fromStatus(translatedStatus); + } + maybePublishConfig(); + } + + @Override + public void onAmbientError(Status error) { + if (cancelled) { + return; + } + ambientError = error.withDescription( + String.format("Ambient error for %s: %s. Details: %s%s", + toContextString(), + error.getCode(), + error.getDescription() != null ? error.getDescription() : "", + nodeInfo())); + } + + protected abstract void subscribeToChildren(T update); + + @Override + public void close() { + cancelled = true; + xdsClient.cancelXdsResourceWatch(type, resourceName, this); + } + + @Override + @Nullable + public StatusOr getData() { + return data; + } + + public String toContextString() { + return toContextStr(type.typeName(), resourceName); + } + } + + private interface RdsUpdateSupplier { + StatusOr getRdsUpdate(); + } + + private class LdsWatcher extends XdsWatcherBase + implements RdsUpdateSupplier { + + private LdsWatcher(String resourceName) { + super(XdsListenerResource.getInstance(), resourceName); + } + + @Override + public void subscribeToChildren(XdsListenerResource.LdsUpdate update) { + HttpConnectionManager httpConnectionManager = update.httpConnectionManager(); + List virtualHosts; + if (httpConnectionManager == null) { + // TCP listener. Unsupported config + virtualHosts = Collections.emptyList(); // Not null, to not delegate to RDS + } else { + virtualHosts = httpConnectionManager.virtualHosts(); + } + if (virtualHosts != null) { + updateRoutes(virtualHosts); + } + + String rdsName = getRdsName(update); + if (rdsName != null) { + addRdsWatcher(rdsName); + } + } + + private String getRdsName(XdsListenerResource.LdsUpdate update) { + HttpConnectionManager httpConnectionManager = update.httpConnectionManager(); + if (httpConnectionManager == null) { + // TCP listener. Unsupported config + return null; + } + return httpConnectionManager.rdsName(); + } + + private RdsWatcher getRdsWatcher(XdsListenerResource.LdsUpdate update, WatcherTracer tracer) { + String rdsName = getRdsName(update); + if (rdsName == null) { + return null; + } + return (RdsWatcher) tracer.getWatcher(RDS_TYPE, rdsName); + } + + public RdsUpdateSupplier getRouteSource(WatcherTracer tracer) { + if (!hasDataValue()) { + return this; + } + HttpConnectionManager hcm = getData().getValue().httpConnectionManager(); + if (hcm == null) { + return this; + } + List virtualHosts = hcm.virtualHosts(); + if (virtualHosts != null) { + return this; + } + RdsWatcher rdsWatcher = getRdsWatcher(getData().getValue(), tracer); + assert rdsWatcher != null; + return rdsWatcher; + } + + @Override + public StatusOr getRdsUpdate() { + if (missingResult()) { + return StatusOr.fromStatus(Status.UNAVAILABLE.withDescription("Not yet loaded")); + } + if (!getData().hasValue()) { + return StatusOr.fromStatus(getData().getStatus()); + } + HttpConnectionManager hcm = getData().getValue().httpConnectionManager(); + if (hcm == null) { + return StatusOr.fromStatus( + Status.UNAVAILABLE.withDescription("Not an API listener" + nodeInfo())); + } + List virtualHosts = hcm.virtualHosts(); + if (virtualHosts == null) { + // Code shouldn't trigger this case, as it should be calling RdsWatcher instead. This would + // be easily implemented with getRdsWatcher().getRdsUpdate(), but getting here is likely a + // bug + return StatusOr.fromStatus(Status.INTERNAL.withDescription("Routes are in RDS, not LDS")); + } + return StatusOr.fromValue(new RdsUpdate(virtualHosts)); + } + } + + private class RdsWatcher extends XdsWatcherBase implements RdsUpdateSupplier { + + public RdsWatcher(String resourceName) { + super(XdsRouteConfigureResource.getInstance(), checkNotNull(resourceName, "resourceName")); + } + + @Override + public void subscribeToChildren(RdsUpdate update) { + updateRoutes(update.virtualHosts); + } + + @Override + public StatusOr getRdsUpdate() { + if (missingResult()) { + return StatusOr.fromStatus(Status.UNAVAILABLE.withDescription("Not yet loaded")); + } + return getData(); + } + } + + private class CdsWatcher extends XdsWatcherBase { + CdsWatcher(String resourceName) { + super(XdsClusterResource.getInstance(), checkNotNull(resourceName, "resourceName")); + } + + @Override + public void subscribeToChildren(XdsClusterResource.CdsUpdate update) { + switch (update.clusterType()) { + case EDS: + addEdsWatcher(getEdsServiceName()); + break; + case LOGICAL_DNS: + if (enableLogicalDns) { + addDnsWatcher(update.dnsHostName()); + } + break; + case AGGREGATE: + update.prioritizedClusterNames() + .forEach(name -> addClusterWatcher(name)); + break; + default: + } + } + + public String getEdsServiceName() { + XdsClusterResource.CdsUpdate cdsUpdate = getData().getValue(); + assert cdsUpdate.clusterType() == ClusterType.EDS; + String edsServiceName = cdsUpdate.edsServiceName(); + if (edsServiceName == null) { + edsServiceName = cdsUpdate.clusterName(); + } + return edsServiceName; + } + } + + private class EdsWatcher extends XdsWatcherBase { + private EdsWatcher(String resourceName) { + super(XdsEndpointResource.getInstance(), checkNotNull(resourceName, "resourceName")); + } + + @Override + public void subscribeToChildren(XdsEndpointResource.EdsUpdate update) {} + } + + private final class DnsWatcher implements TrackedWatcher> { + private final NameResolver resolver; + @Nullable + private StatusOr> data; + private boolean cancelled; + + public DnsWatcher(String dnsHostName, NameResolver.Args nameResolverArgs) { + this.resolver = createNameResolver(dnsHostName, nameResolverArgs); + } + + public void start() { + resolver.start(new NameResolverListener()); + } + + public void refresh() { + if (cancelled) { + return; + } + resolver.refresh(); + } + + @Override + @Nullable + public StatusOr> getData() { + return data; + } + + @Override + public void close() { + if (cancelled) { + return; + } + cancelled = true; + resolver.shutdown(); + } + + private class NameResolverListener extends NameResolver.Listener2 { + @Override + public void onResult(final NameResolver.ResolutionResult resolutionResult) { + syncContext.execute(() -> onResult2(resolutionResult)); + } + + @Override + public Status onResult2(final NameResolver.ResolutionResult resolutionResult) { + if (cancelled) { + return Status.OK; + } + data = resolutionResult.getAddressesOrError(); + maybePublishConfig(); + return resolutionResult.getAddressesOrError().getStatus(); + } + + @Override + public void onError(final Status error) { + syncContext.execute(new Runnable() { + @Override + public void run() { + if (cancelled) { + return; + } + // DnsNameResolver cannot distinguish between address-not-found and transient errors. + // Assume it is a transient error. + // TODO: Once the resolution note API is available, don't throw away the error if + // hasDataValue(); pass it as the note instead + if (!hasDataValue()) { + data = StatusOr.fromStatus(error); + maybePublishConfig(); + } + } + }); + } + } + } + + private static final class FailingNameResolver extends NameResolver { + private final Status status; + + public FailingNameResolver(Status status) { + checkNotNull(status, "status"); + checkArgument(!status.isOk(), "Status must not be OK"); + this.status = status; + } + + @Override + public void start(Listener2 listener) { + listener.onError(status); + } + + @Override + public String getServiceAuthority() { + return "bug-if-you-see-this-authority"; + } + + @Override + public void shutdown() {} + } +} diff --git a/xds/src/main/java/io/grpc/xds/XdsEndpointResource.java b/xds/src/main/java/io/grpc/xds/XdsEndpointResource.java index 010214cfcf8..9ad75595ea6 100644 --- a/xds/src/main/java/io/grpc/xds/XdsEndpointResource.java +++ b/xds/src/main/java/io/grpc/xds/XdsEndpointResource.java @@ -20,17 +20,27 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; -import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.net.InetAddresses; +import com.google.protobuf.Any; +import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; +import io.envoyproxy.envoy.config.core.v3.Address; +import io.envoyproxy.envoy.config.core.v3.HealthStatus; +import io.envoyproxy.envoy.config.core.v3.SocketAddress; import io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment; +import io.envoyproxy.envoy.config.endpoint.v3.Endpoint; import io.envoyproxy.envoy.type.v3.FractionalPercent; import io.grpc.EquivalentAddressGroup; +import io.grpc.internal.GrpcUtil; import io.grpc.xds.Endpoints.DropOverload; import io.grpc.xds.Endpoints.LocalityLbEndpoints; +import io.grpc.xds.MetadataRegistry.MetadataValueParser; import io.grpc.xds.XdsEndpointResource.EdsUpdate; import io.grpc.xds.client.Locality; import io.grpc.xds.client.XdsClient.ResourceUpdate; import io.grpc.xds.client.XdsResourceType; +import java.net.InetAddress; import java.net.InetSocketAddress; import java.util.ArrayList; import java.util.Collections; @@ -47,6 +57,9 @@ class XdsEndpointResource extends XdsResourceType { static final String ADS_TYPE_URL_EDS = "type.googleapis.com/envoy.config.endpoint.v3.ClusterLoadAssignment"; + public static final String GRPC_EXPERIMENTAL_XDS_DUALSTACK_ENDPOINTS = + "GRPC_EXPERIMENTAL_XDS_DUALSTACK_ENDPOINTS"; + private static final XdsEndpointResource instance = new XdsEndpointResource(); static XdsEndpointResource getInstance() { @@ -95,6 +108,10 @@ protected EdsUpdate doParse(Args args, Message unpackedMessage) throws ResourceI return processClusterLoadAssignment((ClusterLoadAssignment) unpackedMessage); } + private static boolean isEnabledXdsDualStack() { + return GrpcUtil.getFlag(GRPC_EXPERIMENTAL_XDS_DUALSTACK_ENDPOINTS, false); + } + private static EdsUpdate processClusterLoadAssignment(ClusterLoadAssignment assignment) throws ResourceInvalidException { Map> priorities = new HashMap<>(); @@ -175,7 +192,8 @@ private static int getRatePerMillion(FractionalPercent percent) { @VisibleForTesting @Nullable static StructOrError parseLocalityLbEndpoints( - io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints proto) { + io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints proto) + throws ResourceInvalidException { // Filter out localities without or with 0 weight. if (!proto.hasLoadBalancingWeight() || proto.getLoadBalancingWeight().getValue() < 1) { return null; @@ -183,6 +201,15 @@ static StructOrError parseLocalityLbEndpoints( if (proto.getPriority() < 0) { return StructOrError.fromError("negative priority"); } + + ImmutableMap localityMetadata; + MetadataRegistry registry = MetadataRegistry.getInstance(); + try { + localityMetadata = registry.parseMetadata(proto.getMetadata()); + } catch (ResourceInvalidException e) { + throw new ResourceInvalidException("Failed to parse Locality Endpoint metadata: " + + e.getMessage(), e); + } List endpoints = new ArrayList<>(proto.getLbEndpointsCount()); for (io.envoyproxy.envoy.config.endpoint.v3.LbEndpoint endpoint : proto.getLbEndpointsList()) { // The endpoint field of each lb_endpoints must be set. @@ -190,20 +217,45 @@ static StructOrError parseLocalityLbEndpoints( if (!endpoint.hasEndpoint() || !endpoint.getEndpoint().hasAddress()) { return StructOrError.fromError("LbEndpoint with no endpoint/address"); } - io.envoyproxy.envoy.config.core.v3.SocketAddress socketAddress = - endpoint.getEndpoint().getAddress().getSocketAddress(); - InetSocketAddress addr = - new InetSocketAddress(socketAddress.getAddress(), socketAddress.getPortValue()); - boolean isHealthy = - endpoint.getHealthStatus() == io.envoyproxy.envoy.config.core.v3.HealthStatus.HEALTHY - || endpoint.getHealthStatus() - == io.envoyproxy.envoy.config.core.v3.HealthStatus.UNKNOWN; + ImmutableMap endpointMetadata; + try { + endpointMetadata = registry.parseMetadata(endpoint.getMetadata()); + } catch (ResourceInvalidException e) { + throw new ResourceInvalidException("Failed to parse Endpoint metadata: " + + e.getMessage(), e); + } + List addresses = new ArrayList<>(); + addresses.add(getInetSocketAddress(endpoint.getEndpoint().getAddress())); + + if (isEnabledXdsDualStack()) { + for (Endpoint.AdditionalAddress additionalAddress + : endpoint.getEndpoint().getAdditionalAddressesList()) { + addresses.add(getInetSocketAddress(additionalAddress.getAddress())); + } + } + boolean isHealthy = (endpoint.getHealthStatus() == HealthStatus.HEALTHY) + || (endpoint.getHealthStatus() == HealthStatus.UNKNOWN); endpoints.add(Endpoints.LbEndpoint.create( - new EquivalentAddressGroup(ImmutableList.of(addr)), - endpoint.getLoadBalancingWeight().getValue(), isHealthy)); + new EquivalentAddressGroup(addresses), + endpoint.getLoadBalancingWeight().getValue(), isHealthy, + endpoint.getEndpoint().getHostname(), + endpointMetadata)); } return StructOrError.fromStruct(Endpoints.LocalityLbEndpoints.create( - endpoints, proto.getLoadBalancingWeight().getValue(), proto.getPriority())); + endpoints, proto.getLoadBalancingWeight().getValue(), + proto.getPriority(), localityMetadata)); + } + + private static InetSocketAddress getInetSocketAddress(Address address) + throws ResourceInvalidException { + io.envoyproxy.envoy.config.core.v3.SocketAddress socketAddress = address.getSocketAddress(); + InetAddress parsedAddress; + try { + parsedAddress = InetAddresses.forString(socketAddress.getAddress()); + } catch (IllegalArgumentException ex) { + throw new ResourceInvalidException("Address is not an IP", ex); + } + return new InetSocketAddress(parsedAddress, socketAddress.getPortValue()); } static final class EdsUpdate implements ResourceUpdate { @@ -250,4 +302,47 @@ public String toString() { .toString(); } } + + public static class AddressMetadataParser implements MetadataValueParser { + + @Override + public String getTypeUrl() { + return "type.googleapis.com/envoy.config.core.v3.Address"; + } + + @Override + public java.net.SocketAddress parse(Any any) throws ResourceInvalidException { + SocketAddress socketAddress; + try { + socketAddress = any.unpack(Address.class).getSocketAddress(); + } catch (InvalidProtocolBufferException ex) { + throw new ResourceInvalidException("Invalid Resource in address proto", ex); + } + validateAddress(socketAddress); + + String ip = socketAddress.getAddress(); + int port = socketAddress.getPortValue(); + + try { + return new InetSocketAddress(InetAddresses.forString(ip), port); + } catch (IllegalArgumentException e) { + throw createException("Invalid IP address or port: " + ip + ":" + port); + } + } + + private void validateAddress(SocketAddress socketAddress) throws ResourceInvalidException { + if (socketAddress.getAddress().isEmpty()) { + throw createException("Address field is empty or invalid."); + } + long port = Integer.toUnsignedLong(socketAddress.getPortValue()); + if (port > 65535) { + throw createException(String.format("Port value %d out of range 1-65535.", port)); + } + } + + private ResourceInvalidException createException(String message) { + return new ResourceInvalidException( + "Failed to parse envoy.config.core.v3.Address: " + message); + } + } } diff --git a/xds/src/main/java/io/grpc/xds/XdsLbPolicies.java b/xds/src/main/java/io/grpc/xds/XdsLbPolicies.java index dcca2fbfff3..ae5ac38b471 100644 --- a/xds/src/main/java/io/grpc/xds/XdsLbPolicies.java +++ b/xds/src/main/java/io/grpc/xds/XdsLbPolicies.java @@ -19,7 +19,6 @@ final class XdsLbPolicies { static final String CLUSTER_MANAGER_POLICY_NAME = "cluster_manager_experimental"; static final String CDS_POLICY_NAME = "cds_experimental"; - static final String CLUSTER_RESOLVER_POLICY_NAME = "cluster_resolver_experimental"; static final String PRIORITY_POLICY_NAME = "priority_experimental"; static final String CLUSTER_IMPL_POLICY_NAME = "cluster_impl_experimental"; static final String WEIGHTED_TARGET_POLICY_NAME = "weighted_target_experimental"; diff --git a/xds/src/main/java/io/grpc/xds/XdsListenerResource.java b/xds/src/main/java/io/grpc/xds/XdsListenerResource.java index af77d128ae7..ccb88a8e543 100644 --- a/xds/src/main/java/io/grpc/xds/XdsListenerResource.java +++ b/xds/src/main/java/io/grpc/xds/XdsListenerResource.java @@ -25,6 +25,7 @@ import com.google.auto.value.AutoValue; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; +import com.google.common.net.InetAddresses; import com.google.protobuf.Any; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; @@ -43,7 +44,6 @@ import io.grpc.xds.Filter.FilterConfig; import io.grpc.xds.XdsListenerResource.LdsUpdate; import io.grpc.xds.client.XdsResourceType; -import java.net.UnknownHostException; import java.util.ArrayList; import java.util.Collection; import java.util.HashSet; @@ -108,13 +108,13 @@ protected LdsUpdate doParse(Args args, Message unpackedMessage) Listener listener = (Listener) unpackedMessage; if (listener.hasApiListener()) { - return processClientSideListener(listener); + return processClientSideListener(listener, args); } else { return processServerSideListener(listener, args); } } - private LdsUpdate processClientSideListener(Listener listener) + private LdsUpdate processClientSideListener(Listener listener, XdsResourceType.Args args) throws ResourceInvalidException { // Unpack HttpConnectionManager from the Listener. HttpConnectionManager hcm; @@ -127,10 +127,10 @@ private LdsUpdate processClientSideListener(Listener listener) "Could not parse HttpConnectionManager config from ApiListener", e); } return LdsUpdate.forApiListener( - parseHttpConnectionManager(hcm, filterRegistry, true /* isForClient */)); + parseHttpConnectionManager(hcm, filterRegistry, true /* isForClient */, args)); } - private LdsUpdate processServerSideListener(Listener proto, Args args) + private LdsUpdate processServerSideListener(Listener proto, XdsResourceType.Args args) throws ResourceInvalidException { Set certProviderInstances = null; if (args.getBootstrapInfo() != null && args.getBootstrapInfo().certProviders() != null) { @@ -138,19 +138,19 @@ private LdsUpdate processServerSideListener(Listener proto, Args args) } return LdsUpdate.forTcpListener(parseServerSideListener(proto, (TlsContextManager) args.getSecurityConfig(), - filterRegistry, certProviderInstances)); + filterRegistry, certProviderInstances, args)); } @VisibleForTesting static EnvoyServerProtoData.Listener parseServerSideListener( Listener proto, TlsContextManager tlsContextManager, - FilterRegistry filterRegistry, Set certProviderInstances) + FilterRegistry filterRegistry, Set certProviderInstances, XdsResourceType.Args args) throws ResourceInvalidException { - if (!proto.getTrafficDirection().equals(TrafficDirection.INBOUND) - && !proto.getTrafficDirection().equals(TrafficDirection.UNSPECIFIED)) { + TrafficDirection trafficDirection = proto.getTrafficDirection(); + if (!trafficDirection.equals(TrafficDirection.INBOUND) + && !trafficDirection.equals(TrafficDirection.UNSPECIFIED)) { throw new ResourceInvalidException( - "Listener " + proto.getName() + " with invalid traffic direction: " - + proto.getTrafficDirection()); + "Listener " + proto.getName() + " with invalid traffic direction: " + trafficDirection); } if (!proto.getListenerFiltersList().isEmpty()) { throw new ResourceInvalidException( @@ -162,13 +162,16 @@ static EnvoyServerProtoData.Listener parseServerSideListener( } String address = null; + SocketAddress socketAddress = null; if (proto.getAddress().hasSocketAddress()) { - SocketAddress socketAddress = proto.getAddress().getSocketAddress(); + socketAddress = proto.getAddress().getSocketAddress(); address = socketAddress.getAddress(); + if (address.isEmpty()) { + throw new ResourceInvalidException("Invalid address: Empty address is not allowed."); + } switch (socketAddress.getPortSpecifierCase()) { case NAMED_PORT: - address = address + ":" + socketAddress.getNamedPort(); - break; + throw new ResourceInvalidException("NAMED_PORT is not supported in gRPC."); case PORT_VALUE: address = address + ":" + socketAddress.getPortValue(); break; @@ -178,56 +181,82 @@ static EnvoyServerProtoData.Listener parseServerSideListener( } ImmutableList.Builder filterChains = ImmutableList.builder(); - Set uniqueSet = new HashSet<>(); + Set filterChainNames = new HashSet<>(); + Set filterChainMatchSet = new HashSet<>(); + int i = 0; for (io.envoyproxy.envoy.config.listener.v3.FilterChain fc : proto.getFilterChainsList()) { + // May be empty. If it's not empty, required to be unique. + String filterChainName = fc.getName(); + if (filterChainName.isEmpty()) { + // Generate a name, so we can identify it in the logs. + filterChainName = "chain_" + i; + } + if (!filterChainNames.add(filterChainName)) { + throw new ResourceInvalidException("Filter chain names must be unique. " + + "Found duplicate: " + filterChainName); + } filterChains.add( - parseFilterChain(fc, tlsContextManager, filterRegistry, uniqueSet, - certProviderInstances)); + parseFilterChain(fc, filterChainName, tlsContextManager, filterRegistry, + filterChainMatchSet, certProviderInstances, args)); + i++; } + FilterChain defaultFilterChain = null; if (proto.hasDefaultFilterChain()) { + String defaultFilterChainName = proto.getDefaultFilterChain().getName(); + if (defaultFilterChainName.isEmpty()) { + defaultFilterChainName = "chain_default"; + } defaultFilterChain = parseFilterChain( - proto.getDefaultFilterChain(), tlsContextManager, filterRegistry, - null, certProviderInstances); + proto.getDefaultFilterChain(), defaultFilterChainName, tlsContextManager, filterRegistry, + null, certProviderInstances, args); } - return EnvoyServerProtoData.Listener.create( - proto.getName(), address, filterChains.build(), defaultFilterChain); + return EnvoyServerProtoData.Listener.create(proto.getName(), address, filterChains.build(), + defaultFilterChain, socketAddress == null ? null : socketAddress.getProtocol()); } @VisibleForTesting static FilterChain parseFilterChain( io.envoyproxy.envoy.config.listener.v3.FilterChain proto, - TlsContextManager tlsContextManager, FilterRegistry filterRegistry, - Set uniqueSet, Set certProviderInstances) + String filterChainName, + TlsContextManager tlsContextManager, + FilterRegistry filterRegistry, + // null disables FilterChainMatch uniqueness check, used for defaultFilterChain + @Nullable Set filterChainMatchSet, + Set certProviderInstances, + XdsResourceType.Args args) throws ResourceInvalidException { + // FilterChain contains L4 filters, so we ensure it contains only HCM. if (proto.getFiltersCount() != 1) { - throw new ResourceInvalidException("FilterChain " + proto.getName() + throw new ResourceInvalidException("FilterChain " + filterChainName + " should contain exact one HttpConnectionManager filter"); } - io.envoyproxy.envoy.config.listener.v3.Filter filter = proto.getFiltersList().get(0); - if (!filter.hasTypedConfig()) { + io.envoyproxy.envoy.config.listener.v3.Filter l4Filter = proto.getFiltersList().get(0); + if (!l4Filter.hasTypedConfig()) { throw new ResourceInvalidException( - "FilterChain " + proto.getName() + " contains filter " + filter.getName() + "FilterChain " + filterChainName + " contains filter " + l4Filter.getName() + " without typed_config"); } - Any any = filter.getTypedConfig(); - // HttpConnectionManager is the only supported network filter at the moment. + Any any = l4Filter.getTypedConfig(); if (!any.getTypeUrl().equals(TYPE_URL_HTTP_CONNECTION_MANAGER)) { throw new ResourceInvalidException( - "FilterChain " + proto.getName() + " contains filter " + filter.getName() + "FilterChain " + filterChainName + " contains filter " + l4Filter.getName() + " with unsupported typed_config type " + any.getTypeUrl()); } + + // Parse HCM. HttpConnectionManager hcmProto; try { hcmProto = any.unpack(HttpConnectionManager.class); } catch (InvalidProtocolBufferException e) { - throw new ResourceInvalidException("FilterChain " + proto.getName() + " with filter " - + filter.getName() + " failed to unpack message", e); + throw new ResourceInvalidException("FilterChain " + filterChainName + " with filter " + + l4Filter.getName() + " failed to unpack message", e); } io.grpc.xds.HttpConnectionManager httpConnectionManager = parseHttpConnectionManager( - hcmProto, filterRegistry, false /* isForClient */); + hcmProto, filterRegistry, false /* isForClient */, args); + // Parse Transport Socket. EnvoyServerProtoData.DownstreamTlsContext downstreamTlsContext = null; if (proto.hasTransportSocket()) { if (!TRANSPORT_SOCKET_NAME_TLS.equals(proto.getTransportSocket().getName())) { @@ -239,7 +268,7 @@ static FilterChain parseFilterChain( downstreamTlsContextProto = proto.getTransportSocket().getTypedConfig().unpack(DownstreamTlsContext.class); } catch (InvalidProtocolBufferException e) { - throw new ResourceInvalidException("FilterChain " + proto.getName() + throw new ResourceInvalidException("FilterChain " + filterChainName + " failed to unpack message", e); } downstreamTlsContext = @@ -247,10 +276,15 @@ static FilterChain parseFilterChain( validateDownstreamTlsContext(downstreamTlsContextProto, certProviderInstances)); } + // Parse FilterChainMatch. FilterChainMatch filterChainMatch = parseFilterChainMatch(proto.getFilterChainMatch()); - checkForUniqueness(uniqueSet, filterChainMatch); + // null used to skip this check for defaultFilterChain. + if (filterChainMatchSet != null) { + validateFilterChainMatchForUniqueness(filterChainMatchSet, filterChainMatch); + } + return FilterChain.create( - proto.getName(), + filterChainName, filterChainMatch, httpConnectionManager, downstreamTlsContext, @@ -284,15 +318,15 @@ static DownstreamTlsContext validateDownstreamTlsContext( return downstreamTlsContext; } - private static void checkForUniqueness(Set uniqueSet, + private static void validateFilterChainMatchForUniqueness( + Set filterChainMatchSet, FilterChainMatch filterChainMatch) throws ResourceInvalidException { - if (uniqueSet != null) { - List crossProduct = getCrossProduct(filterChainMatch); - for (FilterChainMatch cur : crossProduct) { - if (!uniqueSet.add(cur)) { - throw new ResourceInvalidException("FilterChainMatch must be unique. " - + "Found duplicate: " + cur); - } + // Flattens complex FilterChainMatch into a list of simple FilterChainMatch'es. + List crossProduct = getCrossProduct(filterChainMatch); + for (FilterChainMatch cur : crossProduct) { + if (!filterChainMatchSet.add(cur)) { + throw new ResourceInvalidException("FilterChainMatch must be unique. " + + "Found duplicate: " + cur); } } } @@ -420,16 +454,18 @@ private static FilterChainMatch parseFilterChainMatch( try { for (io.envoyproxy.envoy.config.core.v3.CidrRange range : proto.getPrefixRangesList()) { prefixRanges.add( - CidrRange.create(range.getAddressPrefix(), range.getPrefixLen().getValue())); + CidrRange.create(InetAddresses.forString(range.getAddressPrefix()), + range.getPrefixLen().getValue())); } for (io.envoyproxy.envoy.config.core.v3.CidrRange range : proto.getSourcePrefixRangesList()) { - sourcePrefixRanges.add( - CidrRange.create(range.getAddressPrefix(), range.getPrefixLen().getValue())); + sourcePrefixRanges.add(CidrRange.create( + InetAddresses.forString(range.getAddressPrefix()), range.getPrefixLen().getValue())); } - } catch (UnknownHostException e) { - throw new ResourceInvalidException("Failed to create CidrRange", e); + } catch (IllegalArgumentException ex) { + throw new ResourceInvalidException("Failed to create CidrRange", ex); } + ConnectionSourceType sourceType; switch (proto.getSourceType()) { case ANY: @@ -458,7 +494,7 @@ private static FilterChainMatch parseFilterChainMatch( @VisibleForTesting static io.grpc.xds.HttpConnectionManager parseHttpConnectionManager( HttpConnectionManager proto, FilterRegistry filterRegistry, - boolean isForClient) throws ResourceInvalidException { + boolean isForClient, XdsResourceType.Args args) throws ResourceInvalidException { if (proto.getXffNumTrustedHops() != 0) { throw new ResourceInvalidException( "HttpConnectionManager with xff_num_trusted_hops unsupported"); @@ -491,7 +527,7 @@ static io.grpc.xds.HttpConnectionManager parseHttpConnectionManager( "HttpConnectionManager contains duplicate HttpFilter: " + filterName); } StructOrError filterConfig = - parseHttpFilter(httpFilter, filterRegistry, isForClient); + parseHttpFilter(httpFilter, filterRegistry, isForClient, args); if ((i == proto.getHttpFiltersCount() - 1) && (filterConfig == null || !isTerminalFilter(filterConfig.getStruct()))) { throw new ResourceInvalidException("The last HttpFilter must be a terminal filter: " @@ -515,7 +551,7 @@ static io.grpc.xds.HttpConnectionManager parseHttpConnectionManager( // Parse inlined RouteConfiguration or RDS. if (proto.hasRouteConfig()) { List virtualHosts = extractVirtualHosts( - proto.getRouteConfig(), filterRegistry); + proto.getRouteConfig(), filterRegistry, args); return io.grpc.xds.HttpConnectionManager.forVirtualHosts( maxStreamDuration, virtualHosts, filterConfigs); } @@ -545,16 +581,13 @@ private static boolean isTerminalFilter(Filter.FilterConfig filterConfig) { @Nullable // Returns null if the filter is optional but not supported. static StructOrError parseHttpFilter( io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpFilter - httpFilter, FilterRegistry filterRegistry, boolean isForClient) { + httpFilter, FilterRegistry filterRegistry, boolean isForClient, + XdsResourceType.Args args) { String filterName = httpFilter.getName(); boolean isOptional = httpFilter.getIsOptional(); if (!httpFilter.hasTypedConfig()) { - if (isOptional) { - return null; - } else { - return StructOrError.fromError( - "HttpFilter [" + filterName + "] is not optional and has no typed config"); - } + return isOptional ? null : StructOrError.fromError( + "HttpFilter [" + filterName + "] is not optional and has no typed config"); } Message rawConfig = httpFilter.getTypedConfig(); String typeUrl = httpFilter.getTypedConfig().getTypeUrl(); @@ -574,18 +607,23 @@ static StructOrError parseHttpFilter( return StructOrError.fromError( "HttpFilter [" + filterName + "] contains invalid proto: " + e); } - Filter filter = filterRegistry.get(typeUrl); - if ((isForClient && !(filter instanceof Filter.ClientInterceptorBuilder)) - || (!isForClient && !(filter instanceof Filter.ServerInterceptorBuilder))) { - if (isOptional) { - return null; - } else { - return StructOrError.fromError( - "HttpFilter [" + filterName + "](" + typeUrl + ") is required but unsupported for " - + (isForClient ? "client" : "server")); - } + + Filter.Provider provider = filterRegistry.get(typeUrl); + if (provider == null + || (isForClient && !provider.isClientFilter()) + || (!isForClient && !provider.isServerFilter())) { + // Filter type not supported. + return isOptional ? null : StructOrError.fromError( + "HttpFilter [" + filterName + "](" + typeUrl + ") is required but unsupported for " + ( + isForClient ? "client" : "server")); } - ConfigOrError filterConfig = filter.parseFilterConfig(rawConfig); + + Filter.FilterConfigParseContext filterContext = Filter.FilterConfigParseContext.builder() + .bootstrapInfo(args.getBootstrapInfo()) + .serverInfo(args.getServerInfo()) + .build(); + ConfigOrError filterConfig = + provider.parseFilterConfig(rawConfig, filterContext); if (filterConfig.errorDetail != null) { return StructOrError.fromError( "Invalid filter config for HttpFilter [" + filterName + "]: " + filterConfig.errorDetail); diff --git a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java index 9ad9b6e82f0..69b0b824433 100644 --- a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java +++ b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java @@ -41,14 +41,15 @@ import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.Metadata; import io.grpc.MethodDescriptor; +import io.grpc.MetricRecorder; import io.grpc.NameResolver; import io.grpc.Status; import io.grpc.Status.Code; +import io.grpc.StatusOr; import io.grpc.SynchronizationContext; import io.grpc.internal.GrpcUtil; import io.grpc.internal.ObjectPool; import io.grpc.xds.ClusterSpecifierPlugin.PluginConfig; -import io.grpc.xds.Filter.ClientInterceptorBuilder; import io.grpc.xds.Filter.FilterConfig; import io.grpc.xds.Filter.NamedFilterConfig; import io.grpc.xds.RouteLookupServiceClusterSpecifierPlugin.RlsPluginConfig; @@ -58,12 +59,12 @@ import io.grpc.xds.VirtualHost.Route.RouteAction.ClusterWeight; import io.grpc.xds.VirtualHost.Route.RouteAction.HashPolicy; import io.grpc.xds.VirtualHost.Route.RouteAction.RetryPolicy; +import io.grpc.xds.VirtualHost.Route.RouteMatch; import io.grpc.xds.XdsNameResolverProvider.CallCounterProvider; -import io.grpc.xds.XdsRouteConfigureResource.RdsUpdate; import io.grpc.xds.client.Bootstrapper.AuthorityInfo; import io.grpc.xds.client.Bootstrapper.BootstrapInfo; import io.grpc.xds.client.XdsClient; -import io.grpc.xds.client.XdsClient.ResourceWatcher; +import io.grpc.xds.client.XdsInitializationException; import io.grpc.xds.client.XdsLogger; import io.grpc.xds.client.XdsLogger.XdsLogLevel; import java.util.ArrayList; @@ -79,6 +80,7 @@ import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; import javax.annotation.Nullable; /** @@ -90,11 +92,14 @@ * @see XdsNameResolverProvider */ final class XdsNameResolver extends NameResolver { - static final CallOptions.Key CLUSTER_SELECTION_KEY = CallOptions.Key.create("io.grpc.xds.CLUSTER_SELECTION_KEY"); + static final CallOptions.Key XDS_CONFIG_CALL_OPTION_KEY = + CallOptions.Key.create("io.grpc.xds.XDS_CONFIG_CALL_OPTION_KEY"); static final CallOptions.Key RPC_HASH_KEY = CallOptions.Key.create("io.grpc.xds.RPC_HASH_KEY"); + static final CallOptions.Key AUTO_HOST_REWRITE_KEY = + CallOptions.Key.create("io.grpc.xds.AUTO_HOST_REWRITE_KEY"); @VisibleForTesting static boolean enableTimeout = Strings.isNullOrEmpty(System.getenv("GRPC_XDS_EXPERIMENTAL_ENABLE_TIMEOUT")) @@ -112,7 +117,7 @@ final class XdsNameResolver extends NameResolver { private final ServiceConfigParser serviceConfigParser; private final SynchronizationContext syncContext; private final ScheduledExecutorService scheduler; - private final XdsClientPoolFactory xdsClientPoolFactory; + private final XdsClientPool xdsClientPool; private final ThreadSafeRandom random; private final FilterRegistry filterRegistry; private final XxHash64 hashFunc = XxHash64.INSTANCE; @@ -121,34 +126,48 @@ final class XdsNameResolver extends NameResolver { private final ConcurrentMap clusterRefs = new ConcurrentHashMap<>(); private final ConfigSelector configSelector = new ConfigSelector(); private final long randomChannelId; + private final Args nameResolverArgs; + // Must be accessed in syncContext. + // Filter instances are unique per channel, and per filter (name+typeUrl). + // NamedFilterConfig.filterStateKey -> filter_instance. + private final HashMap activeFilters = new HashMap<>(); - private volatile RoutingConfig routingConfig = RoutingConfig.empty; + private volatile RoutingConfig routingConfig; private Listener2 listener; - private ObjectPool xdsClientPool; private XdsClient xdsClient; private CallCounterProvider callCounterProvider; private ResolveState resolveState; - // Workaround for https://github.com/grpc/grpc-java/issues/8886 . This should be handled in - // XdsClient instead of here. - private boolean receivedConfig; + /** + * Constructs a new instance. + * + * @param target the target URI to resolve + * @param targetAuthority the authority component of `target`, possibly the empty string, or null + * if 'target' has no such component + */ XdsNameResolver( - @Nullable String targetAuthority, String name, @Nullable String overrideAuthority, - ServiceConfigParser serviceConfigParser, + String target, @Nullable String targetAuthority, String name, + @Nullable String overrideAuthority, ServiceConfigParser serviceConfigParser, SynchronizationContext syncContext, ScheduledExecutorService scheduler, - @Nullable Map bootstrapOverride) { - this(targetAuthority, name, overrideAuthority, serviceConfigParser, syncContext, scheduler, - SharedXdsClientPoolProvider.getDefaultProvider(), ThreadSafeRandomImpl.instance, - FilterRegistry.getDefaultRegistry(), bootstrapOverride); + @Nullable Map bootstrapOverride, + MetricRecorder metricRecorder, Args nameResolverArgs) { + this(target, targetAuthority, name, overrideAuthority, serviceConfigParser, + syncContext, scheduler, + bootstrapOverride == null + ? SharedXdsClientPoolProvider.getDefaultProvider() + : new SharedXdsClientPoolProvider(), + ThreadSafeRandomImpl.instance, FilterRegistry.getDefaultRegistry(), bootstrapOverride, + metricRecorder, nameResolverArgs); } @VisibleForTesting XdsNameResolver( - @Nullable String targetAuthority, String name, @Nullable String overrideAuthority, - ServiceConfigParser serviceConfigParser, + String target, @Nullable String targetAuthority, String name, + @Nullable String overrideAuthority, ServiceConfigParser serviceConfigParser, SynchronizationContext syncContext, ScheduledExecutorService scheduler, XdsClientPoolFactory xdsClientPoolFactory, ThreadSafeRandom random, - FilterRegistry filterRegistry, @Nullable Map bootstrapOverride) { + FilterRegistry filterRegistry, @Nullable Map bootstrapOverride, + MetricRecorder metricRecorder, Args nameResolverArgs) { this.targetAuthority = targetAuthority; // The name might have multiple slashes so encode it before verifying. @@ -160,11 +179,19 @@ final class XdsNameResolver extends NameResolver { this.serviceConfigParser = checkNotNull(serviceConfigParser, "serviceConfigParser"); this.syncContext = checkNotNull(syncContext, "syncContext"); this.scheduler = checkNotNull(scheduler, "scheduler"); - this.xdsClientPoolFactory = bootstrapOverride == null ? checkNotNull(xdsClientPoolFactory, - "xdsClientPoolFactory") : new SharedXdsClientPoolProvider(); - this.xdsClientPoolFactory.setBootstrapOverride(bootstrapOverride); + Supplier xdsClientSupplierArg = + nameResolverArgs.getArg(XdsNameResolverProvider.XDS_CLIENT_SUPPLIER); + if (xdsClientSupplierArg != null) { + this.xdsClientPool = new SupplierXdsClientPool(xdsClientSupplierArg); + } else { + checkNotNull(xdsClientPoolFactory, "xdsClientPoolFactory"); + this.xdsClientPool = new BootstrappingXdsClientPool( + xdsClientPoolFactory, target, bootstrapOverride, metricRecorder); + } this.random = checkNotNull(random, "random"); this.filterRegistry = checkNotNull(filterRegistry, "filterRegistry"); + this.nameResolverArgs = checkNotNull(nameResolverArgs, "nameResolverArgs"); + randomChannelId = random.nextLong(); logId = InternalLogId.allocate("xds-resolver", name); logger = XdsLogger.withLogId(logId); @@ -180,16 +207,17 @@ public String getServiceAuthority() { public void start(Listener2 listener) { this.listener = checkNotNull(listener, "listener"); try { - xdsClientPool = xdsClientPoolFactory.getOrCreate(); + xdsClient = xdsClientPool.getObject(); } catch (Exception e) { listener.onError( Status.UNAVAILABLE.withDescription("Failed to initialize xDS").withCause(e)); return; } - xdsClient = xdsClientPool.getObject(); BootstrapInfo bootstrapInfo = xdsClient.getBootstrapInfo(); String listenerNameTemplate; - if (targetAuthority == null) { + if (targetAuthority == null || targetAuthority.isEmpty()) { + // Both https://github.com/grpc/proposal/blob/master/A27-xds-global-load-balancing.md and + // A47-xds-federation.md seem to treat an empty authority the same as an undefined one. listenerNameTemplate = bootstrapInfo.clientDefaultListenerResourceNameTemplate(); } else { AuthorityInfo authorityInfo = bootstrapInfo.authorities().get(targetAuthority); @@ -213,10 +241,18 @@ public void start(Listener2 listener) { } ldsResourceName = XdsClient.canonifyResourceName(ldsResourceName); callCounterProvider = SharedCallCounterMap.getInstance(); + resolveState = new ResolveState(ldsResourceName); resolveState.start(); } + @Override + public void refresh() { + if (resolveState != null) { + resolveState.refresh(); + } + } + private static String expandPercentS(String template, String replacement) { return template.replace("%s", replacement); } @@ -225,7 +261,7 @@ private static String expandPercentS(String template, String replacement) { public void shutdown() { logger.log(XdsLogLevel.INFO, "Shutdown"); if (resolveState != null) { - resolveState.stop(); + resolveState.shutdown(); } if (xdsClient != null) { xdsClient = xdsClientPool.returnObject(xdsClient); @@ -274,7 +310,7 @@ XdsClient getXdsClient() { } // called in syncContext - private void updateResolutionResult() { + private void updateResolutionResult(XdsConfig xdsConfig) { syncContext.throwIfNotInThisSynchronizationContext(); ImmutableMap.Builder childPolicy = new ImmutableMap.Builder<>(); @@ -290,13 +326,15 @@ private void updateResolutionResult() { if (logger.isLoggable(XdsLogLevel.INFO)) { logger.log( - XdsLogLevel.INFO, "Generated service config:\n{0}", new Gson().toJson(rawServiceConfig)); + XdsLogLevel.INFO, "Generated service config: {0}", new Gson().toJson(rawServiceConfig)); } ConfigOrError parsedServiceConfig = serviceConfigParser.parseServiceConfig(rawServiceConfig); Attributes attrs = Attributes.newBuilder() - .set(InternalXdsAttributes.XDS_CLIENT_POOL, xdsClientPool) - .set(InternalXdsAttributes.CALL_COUNTER_PROVIDER, callCounterProvider) + .set(XdsAttributes.XDS_CLIENT, xdsClient) + .set(XdsAttributes.XDS_CONFIG, xdsConfig) + .set(XdsAttributes.XDS_CLUSTER_SUBSCRIPT_REGISTRY, resolveState.xdsDependencyManager) + .set(XdsAttributes.CALL_COUNTER_PROVIDER, callCounterProvider) .set(InternalConfigSelector.KEY, configSelector) .build(); ResolutionResult result = @@ -304,8 +342,9 @@ private void updateResolutionResult() { .setAttributes(attrs) .setServiceConfig(parsedServiceConfig) .build(); - listener.onResult(result); - receivedConfig = true; + if (!listener.onResult2(result).isOk()) { + resolveState.xdsDependencyManager.requestReresolution(); + } } /** @@ -371,21 +410,21 @@ static boolean matchHostName(String hostName, String pattern) { private final class ConfigSelector extends InternalConfigSelector { @Override public Result selectConfig(PickSubchannelArgs args) { - String cluster = null; - Route selectedRoute = null; RoutingConfig routingCfg; - Map selectedOverrideConfigs; - List filterInterceptors = new ArrayList<>(); + RouteData selectedRoute; + String cluster; + ClientInterceptor filters; Metadata headers = args.getHeaders(); + String path = "/" + args.getMethodDescriptor().getFullMethodName(); do { routingCfg = routingConfig; - selectedOverrideConfigs = new HashMap<>(routingCfg.virtualHostOverrideConfig); - for (Route route : routingCfg.routes) { - if (RoutingUtils.matchRoute( - route.routeMatch(), "/" + args.getMethodDescriptor().getFullMethodName(), - headers, random)) { + if (routingCfg.errorStatus != null) { + return Result.forError(routingCfg.errorStatus); + } + selectedRoute = null; + for (RouteData route : routingCfg.routes) { + if (RoutingUtils.matchRoute(route.routeMatch, path, headers, random)) { selectedRoute = route; - selectedOverrideConfigs.putAll(route.filterConfigOverrides()); break; } } @@ -393,38 +432,45 @@ public Result selectConfig(PickSubchannelArgs args) { return Result.forError( Status.UNAVAILABLE.withDescription("Could not find xDS route matching RPC")); } - if (selectedRoute.routeAction() == null) { + if (selectedRoute.routeAction == null) { return Result.forError(Status.UNAVAILABLE.withDescription( "Could not route RPC to Route with non-forwarding action")); } - RouteAction action = selectedRoute.routeAction(); + RouteAction action = selectedRoute.routeAction; if (action.cluster() != null) { cluster = prefixedClusterName(action.cluster()); + filters = selectedRoute.filterChoices.get(0); } else if (action.weightedClusters() != null) { + // XdsRouteConfigureResource verifies the total weight will not be 0 or exceed uint32 long totalWeight = 0; for (ClusterWeight weightedCluster : action.weightedClusters()) { totalWeight += weightedCluster.weight(); } long select = random.nextLong(totalWeight); long accumulator = 0; - for (ClusterWeight weightedCluster : action.weightedClusters()) { + for (int i = 0; ; i++) { + ClusterWeight weightedCluster = action.weightedClusters().get(i); accumulator += weightedCluster.weight(); if (select < accumulator) { cluster = prefixedClusterName(weightedCluster.name()); - selectedOverrideConfigs.putAll(weightedCluster.filterConfigOverrides()); + filters = selectedRoute.filterChoices.get(i); break; } } } else if (action.namedClusterSpecifierPluginConfig() != null) { cluster = prefixedClusterSpecifierPluginName(action.namedClusterSpecifierPluginConfig().name()); + filters = selectedRoute.filterChoices.get(0); + } else { + // updateRoutes() discards routes with unknown actions + throw new AssertionError(); } } while (!retainCluster(cluster)); + + final RouteAction routeAction = selectedRoute.routeAction; Long timeoutNanos = null; if (enableTimeout) { - if (selectedRoute != null) { - timeoutNanos = selectedRoute.routeAction().timeoutNano(); - } + timeoutNanos = routeAction.timeoutNano(); if (timeoutNanos == null) { timeoutNanos = routingCfg.fallbackTimeoutNano; } @@ -432,8 +478,7 @@ public Result selectConfig(PickSubchannelArgs args) { timeoutNanos = null; } } - RetryPolicy retryPolicy = - selectedRoute == null ? null : selectedRoute.routeAction().retryPolicy(); + RetryPolicy retryPolicy = routeAction.retryPolicy(); // TODO(chengyuanzhang): avoid service config generation and parsing for each call. Map rawServiceConfig = generateServiceConfigWithMethodConfig(timeoutNanos, retryPolicy); @@ -445,31 +490,21 @@ public Result selectConfig(PickSubchannelArgs args) { parsedServiceConfig.getError().augmentDescription( "Failed to parse service config (method config)")); } - if (routingCfg.filterChain != null) { - for (NamedFilterConfig namedFilter : routingCfg.filterChain) { - FilterConfig filterConfig = namedFilter.filterConfig; - Filter filter = filterRegistry.get(filterConfig.typeUrl()); - if (filter instanceof ClientInterceptorBuilder) { - ClientInterceptor interceptor = ((ClientInterceptorBuilder) filter) - .buildClientInterceptor( - filterConfig, selectedOverrideConfigs.get(namedFilter.name), - args, scheduler); - if (interceptor != null) { - filterInterceptors.add(interceptor); - } - } - } - } final String finalCluster = cluster; - final long hash = generateHash(selectedRoute.routeAction().hashPolicies(), headers); + final XdsConfig xdsConfig = routingCfg.xdsConfig; + final long hash = generateHash(routeAction.hashPolicies(), headers); class ClusterSelectionInterceptor implements ClientInterceptor { @Override public ClientCall interceptCall( final MethodDescriptor method, CallOptions callOptions, final Channel next) { - final CallOptions callOptionsForCluster = + CallOptions callOptionsForCluster = callOptions.withOption(CLUSTER_SELECTION_KEY, finalCluster) + .withOption(XDS_CONFIG_CALL_OPTION_KEY, xdsConfig) .withOption(RPC_HASH_KEY, hash); + if (routeAction.autoHostRewrite()) { + callOptionsForCluster = callOptionsForCluster.withOption(AUTO_HOST_REWRITE_KEY, true); + } return new SimpleForwardingClientCall( next.newCall(method, callOptionsForCluster)) { @Override @@ -498,11 +533,11 @@ public void onClose(Status status, Metadata trailers) { } } - filterInterceptors.add(new ClusterSelectionInterceptor()); return Result.newBuilder() .setConfig(config) - .setInterceptor(combineInterceptors(filterInterceptors)) + .setInterceptor(combineInterceptors( + ImmutableList.of(new ClusterSelectionInterceptor(), filters))) .build(); } @@ -524,13 +559,21 @@ private boolean retainCluster(String cluster) { private void releaseCluster(final String cluster) { int count = clusterRefs.get(cluster).refCount.decrementAndGet(); + if (count < 0) { + throw new AssertionError(); + } if (count == 0) { syncContext.execute(new Runnable() { @Override public void run() { - if (clusterRefs.get(cluster).refCount.get() == 0) { - clusterRefs.remove(cluster); - updateResolutionResult(); + if (clusterRefs.get(cluster).refCount.get() != 0) { + throw new AssertionError(); + } + clusterRefs.remove(cluster).close(); + if (resolveState.lastConfigOrStatus.hasValue()) { + updateResolutionResult(resolveState.lastConfigOrStatus.getValue()); + } else { + resolveState.cleanUpRoutes(resolveState.lastConfigOrStatus.getStatus()); } } }); @@ -568,8 +611,18 @@ private long generateHash(List hashPolicies, Metadata headers) { } } + static final class PassthroughClientInterceptor implements ClientInterceptor { + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + return next.newCall(method, callOptions); + } + } + private static ClientInterceptor combineInterceptors(final List interceptors) { - checkArgument(!interceptors.isEmpty(), "empty interceptors"); + if (interceptors.size() == 0) { + return new PassthroughClientInterceptor(); + } if (interceptors.size() == 1) { return interceptors.get(0); } @@ -609,103 +662,106 @@ private static String prefixedClusterSpecifierPluginName(String pluginName) { return "cluster_specifier_plugin:" + pluginName; } - private static final class FailingConfigSelector extends InternalConfigSelector { - private final Result result; - - public FailingConfigSelector(Status error) { - this.result = Result.forError(error); - } - - @Override - public Result selectConfig(PickSubchannelArgs args) { - return result; - } - } - - private class ResolveState implements ResourceWatcher { + class ResolveState implements XdsDependencyManager.XdsConfigWatcher { private final ConfigOrError emptyServiceConfig = serviceConfigParser.parseServiceConfig(Collections.emptyMap()); - private final String ldsResourceName; + private final String authority; + private final XdsDependencyManager xdsDependencyManager; private boolean stopped; @Nullable private Set existingClusters; // clusters to which new requests can be routed - @Nullable - private RouteDiscoveryState routeDiscoveryState; + private StatusOr lastConfigOrStatus; - ResolveState(String ldsResourceName) { - this.ldsResourceName = ldsResourceName; + private ResolveState(String ldsResourceName) { + authority = overrideAuthority != null ? overrideAuthority : encodedServiceAuthority; + xdsDependencyManager = + new XdsDependencyManager(xdsClient, syncContext, authority, ldsResourceName, + nameResolverArgs); } - @Override - public void onChanged(final XdsListenerResource.LdsUpdate update) { + void start() { + xdsDependencyManager.start(this); + } + + void refresh() { + xdsDependencyManager.requestReresolution(); + } + + private void shutdown() { if (stopped) { return; } - logger.log(XdsLogLevel.INFO, "Receive LDS resource update: {0}", update); - HttpConnectionManager httpConnectionManager = update.httpConnectionManager(); - List virtualHosts = httpConnectionManager.virtualHosts(); - String rdsName = httpConnectionManager.rdsName(); - cleanUpRouteDiscoveryState(); - if (virtualHosts != null) { - updateRoutes(virtualHosts, httpConnectionManager.httpMaxStreamDurationNano(), - httpConnectionManager.httpFilterConfigs()); - } else { - routeDiscoveryState = new RouteDiscoveryState( - rdsName, httpConnectionManager.httpMaxStreamDurationNano(), - httpConnectionManager.httpFilterConfigs()); - logger.log(XdsLogLevel.INFO, "Start watching RDS resource {0}", rdsName); - xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), - rdsName, routeDiscoveryState, syncContext); - } + + stopped = true; + xdsDependencyManager.shutdown(); + updateActiveFilters(null); } @Override - public void onError(final Status error) { - if (stopped || receivedConfig) { + public void onUpdate(StatusOr updateOrStatus) { + if (stopped) { return; } - listener.onError(Status.UNAVAILABLE.withCause(error.getCause()).withDescription( - String.format("Unable to load LDS %s. xDS server returned: %s: %s", - ldsResourceName, error.getCode(), error.getDescription()))); - } + logger.log(XdsLogLevel.INFO, "Receive XDS resource update: {0}", updateOrStatus); - @Override - public void onResourceDoesNotExist(final String resourceName) { - if (stopped) { + lastConfigOrStatus = updateOrStatus; + if (!updateOrStatus.hasValue()) { + updateActiveFilters(null); + cleanUpRoutes(updateOrStatus.getStatus()); return; } - String error = "LDS resource does not exist: " + resourceName; - logger.log(XdsLogLevel.INFO, error); - cleanUpRouteDiscoveryState(); - cleanUpRoutes(error); - } - private void start() { - logger.log(XdsLogLevel.INFO, "Start watching LDS resource {0}", ldsResourceName); - xdsClient.watchXdsResource(XdsListenerResource.getInstance(), - ldsResourceName, this, syncContext); - } + // Process Route + XdsConfig update = updateOrStatus.getValue(); + HttpConnectionManager httpConnectionManager = update.getListener().httpConnectionManager(); + if (httpConnectionManager == null) { + logger.log(XdsLogLevel.INFO, "API Listener: httpConnectionManager does not exist."); + updateActiveFilters(null); + cleanUpRoutes(updateOrStatus.getStatus()); + return; + } - private void stop() { - logger.log(XdsLogLevel.INFO, "Stop watching LDS resource {0}", ldsResourceName); - stopped = true; - cleanUpRouteDiscoveryState(); - xdsClient.cancelXdsResourceWatch(XdsListenerResource.getInstance(), ldsResourceName, this); + VirtualHost virtualHost = update.getVirtualHost(); + ImmutableList filterConfigs = httpConnectionManager.httpFilterConfigs(); + long streamDurationNano = httpConnectionManager.httpMaxStreamDurationNano(); + + updateActiveFilters(filterConfigs); + updateRoutes(update, virtualHost, streamDurationNano, filterConfigs); } // called in syncContext - private void updateRoutes(List virtualHosts, long httpMaxStreamDurationNano, - @Nullable List filterConfigs) { - String authority = overrideAuthority != null ? overrideAuthority : encodedServiceAuthority; - VirtualHost virtualHost = RoutingUtils.findVirtualHostForHostName(virtualHosts, authority); - if (virtualHost == null) { - String error = "Failed to find virtual host matching hostname: " + authority; - logger.log(XdsLogLevel.WARNING, error); - cleanUpRoutes(error); - return; + private void updateActiveFilters(@Nullable List filterConfigs) { + if (filterConfigs == null) { + filterConfigs = ImmutableList.of(); + } + Set filtersToShutdown = new HashSet<>(activeFilters.keySet()); + for (NamedFilterConfig namedFilter : filterConfigs) { + String typeUrl = namedFilter.filterConfig.typeUrl(); + String filterKey = namedFilter.filterStateKey(); + + Filter.Provider provider = filterRegistry.get(typeUrl); + checkNotNull(provider, "provider %s", typeUrl); + Filter filter = activeFilters.computeIfAbsent( + filterKey, k -> provider.newInstance(namedFilter.name)); + checkNotNull(filter, "filter %s", filterKey); + filtersToShutdown.remove(filterKey); + } + + // Shutdown filters not present in current HCM. + for (String filterKey : filtersToShutdown) { + Filter filterToShutdown = activeFilters.remove(filterKey); + checkNotNull(filterToShutdown, "filterToShutdown %s", filterKey); + filterToShutdown.close(); } + } + private void updateRoutes( + XdsConfig xdsConfig, + @Nullable VirtualHost virtualHost, + long httpMaxStreamDurationNano, + @Nullable List filterConfigs) { List routes = virtualHost.routes(); + ImmutableList.Builder routesData = ImmutableList.builder(); // Populate all clusters to which requests can be routed to through the virtual host. Set clusters = new HashSet<>(); @@ -716,26 +772,36 @@ private void updateRoutes(List virtualHosts, long httpMaxStreamDura for (Route route : routes) { RouteAction action = route.routeAction(); String prefixedName; - if (action != null) { - if (action.cluster() != null) { - prefixedName = prefixedClusterName(action.cluster()); + if (action == null) { + routesData.add(new RouteData(route.routeMatch(), null, ImmutableList.of())); + } else if (action.cluster() != null) { + prefixedName = prefixedClusterName(action.cluster()); + clusters.add(prefixedName); + clusterNameMap.put(prefixedName, action.cluster()); + ClientInterceptor filters = createFilters(filterConfigs, virtualHost, route, null); + routesData.add(new RouteData(route.routeMatch(), route.routeAction(), filters)); + } else if (action.weightedClusters() != null) { + ImmutableList.Builder filterList = ImmutableList.builder(); + for (ClusterWeight weightedCluster : action.weightedClusters()) { + prefixedName = prefixedClusterName(weightedCluster.name()); clusters.add(prefixedName); - clusterNameMap.put(prefixedName, action.cluster()); - } else if (action.weightedClusters() != null) { - for (ClusterWeight weighedCluster : action.weightedClusters()) { - prefixedName = prefixedClusterName(weighedCluster.name()); - clusters.add(prefixedName); - clusterNameMap.put(prefixedName, weighedCluster.name()); - } - } else if (action.namedClusterSpecifierPluginConfig() != null) { - PluginConfig pluginConfig = action.namedClusterSpecifierPluginConfig().config(); - if (pluginConfig instanceof RlsPluginConfig) { - prefixedName = prefixedClusterSpecifierPluginName( - action.namedClusterSpecifierPluginConfig().name()); - clusters.add(prefixedName); - rlsPluginConfigMap.put(prefixedName, (RlsPluginConfig) pluginConfig); - } + clusterNameMap.put(prefixedName, weightedCluster.name()); + filterList.add(createFilters(filterConfigs, virtualHost, route, weightedCluster)); } + routesData.add( + new RouteData(route.routeMatch(), route.routeAction(), filterList.build())); + } else if (action.namedClusterSpecifierPluginConfig() != null) { + PluginConfig pluginConfig = action.namedClusterSpecifierPluginConfig().config(); + if (pluginConfig instanceof RlsPluginConfig) { + prefixedName = prefixedClusterSpecifierPluginName( + action.namedClusterSpecifierPluginConfig().name()); + clusters.add(prefixedName); + rlsPluginConfigMap.put(prefixedName, (RlsPluginConfig) pluginConfig); + } + ClientInterceptor filters = createFilters(filterConfigs, virtualHost, route, null); + routesData.add(new RouteData(route.routeMatch(), route.routeAction(), filters)); + } else { + // Discard route } } @@ -752,9 +818,13 @@ private void updateRoutes(List virtualHosts, long httpMaxStreamDura clusterRefs.get(cluster).refCount.incrementAndGet(); } else { if (clusterNameMap.containsKey(cluster)) { + assert cluster.startsWith("cluster:"); + XdsConfig.Subscription subscription = + xdsDependencyManager.subscribeToCluster(cluster.substring("cluster:".length())); clusterRefs.put( cluster, - ClusterRefState.forCluster(new AtomicInteger(1), clusterNameMap.get(cluster))); + ClusterRefState.forCluster( + new AtomicInteger(1), clusterNameMap.get(cluster), subscription)); } if (rlsPluginConfigMap.containsKey(cluster)) { clusterRefs.put( @@ -775,108 +845,86 @@ private void updateRoutes(List virtualHosts, long httpMaxStreamDura } } // Update service config to include newly added clusters. - if (shouldUpdateResult) { - updateResolutionResult(); + if (shouldUpdateResult && routingConfig != null) { + updateResolutionResult(xdsConfig); + shouldUpdateResult = false; + } else { + // Need to update at least once + shouldUpdateResult = true; } // Make newly added clusters selectable by config selector and deleted clusters no longer // selectable. - routingConfig = - new RoutingConfig( - httpMaxStreamDurationNano, routes, filterConfigs, - virtualHost.filterConfigOverrides()); - shouldUpdateResult = false; + routingConfig = new RoutingConfig(xdsConfig, httpMaxStreamDurationNano, routesData.build()); for (String cluster : deletedClusters) { int count = clusterRefs.get(cluster).refCount.decrementAndGet(); if (count == 0) { - clusterRefs.remove(cluster); + clusterRefs.remove(cluster).close(); shouldUpdateResult = true; } } if (shouldUpdateResult) { - updateResolutionResult(); + updateResolutionResult(xdsConfig); } } - private void cleanUpRoutes(String error) { + private ClientInterceptor createFilters( + @Nullable List filterConfigs, + VirtualHost virtualHost, + Route route, + @Nullable ClusterWeight weightedCluster) { + if (filterConfigs == null) { + return new PassthroughClientInterceptor(); + } + + Map selectedOverrideConfigs = + new HashMap<>(virtualHost.filterConfigOverrides()); + selectedOverrideConfigs.putAll(route.filterConfigOverrides()); + if (weightedCluster != null) { + selectedOverrideConfigs.putAll(weightedCluster.filterConfigOverrides()); + } + + ImmutableList.Builder filterInterceptors = ImmutableList.builder(); + for (NamedFilterConfig namedFilter : filterConfigs) { + String name = namedFilter.name; + FilterConfig config = namedFilter.filterConfig; + FilterConfig overrideConfig = selectedOverrideConfigs.get(name); + String filterKey = namedFilter.filterStateKey(); + + Filter filter = activeFilters.get(filterKey); + checkNotNull(filter, "activeFilters.get(%s)", filterKey); + ClientInterceptor interceptor = + filter.buildClientInterceptor(config, overrideConfig, scheduler); + + if (interceptor != null) { + filterInterceptors.add(interceptor); + } + } + + // Combine interceptors produced by different filters into a single one that executes + // them sequentially. The order is preserved. + return combineInterceptors(filterInterceptors.build()); + } + + private void cleanUpRoutes(Status error) { + routingConfig = new RoutingConfig(error); if (existingClusters != null) { for (String cluster : existingClusters) { int count = clusterRefs.get(cluster).refCount.decrementAndGet(); if (count == 0) { - clusterRefs.remove(cluster); + clusterRefs.remove(cluster).close(); } } existingClusters = null; } - routingConfig = RoutingConfig.empty; + // Without addresses the default LB (normally pick_first) should become TRANSIENT_FAILURE, and - // the config selector handles the error message itself. Once the LB API allows providing - // failure information for addresses yet still providing a service config, the config seector - // could be avoided. - listener.onResult(ResolutionResult.newBuilder() + // the config selector handles the error message itself. + listener.onResult2(ResolutionResult.newBuilder() .setAttributes(Attributes.newBuilder() - .set(InternalConfigSelector.KEY, - new FailingConfigSelector(Status.UNAVAILABLE.withDescription(error))) + .set(InternalConfigSelector.KEY, configSelector) .build()) .setServiceConfig(emptyServiceConfig) .build()); - receivedConfig = true; - } - - private void cleanUpRouteDiscoveryState() { - if (routeDiscoveryState != null) { - String rdsName = routeDiscoveryState.resourceName; - logger.log(XdsLogLevel.INFO, "Stop watching RDS resource {0}", rdsName); - xdsClient.cancelXdsResourceWatch(XdsRouteConfigureResource.getInstance(), rdsName, - routeDiscoveryState); - routeDiscoveryState = null; - } - } - - /** - * Discovery state for RouteConfiguration resource. One instance for each Listener resource - * update. - */ - private class RouteDiscoveryState implements ResourceWatcher { - private final String resourceName; - private final long httpMaxStreamDurationNano; - @Nullable - private final List filterConfigs; - - private RouteDiscoveryState(String resourceName, long httpMaxStreamDurationNano, - @Nullable List filterConfigs) { - this.resourceName = resourceName; - this.httpMaxStreamDurationNano = httpMaxStreamDurationNano; - this.filterConfigs = filterConfigs; - } - - @Override - public void onChanged(final RdsUpdate update) { - if (RouteDiscoveryState.this != routeDiscoveryState) { - return; - } - logger.log(XdsLogLevel.INFO, "Received RDS resource update: {0}", update); - updateRoutes(update.virtualHosts, httpMaxStreamDurationNano, filterConfigs); - } - - @Override - public void onError(final Status error) { - if (RouteDiscoveryState.this != routeDiscoveryState || receivedConfig) { - return; - } - listener.onError(Status.UNAVAILABLE.withCause(error.getCause()).withDescription( - String.format("Unable to load RDS %s. xDS server returned: %s: %s", - resourceName, error.getCode(), error.getDescription()))); - } - - @Override - public void onResourceDoesNotExist(final String resourceName) { - if (RouteDiscoveryState.this != routeDiscoveryState) { - return; - } - String error = "RDS resource does not exist: " + resourceName; - logger.log(XdsLogLevel.INFO, error); - cleanUpRoutes(error); - } } } @@ -884,23 +932,62 @@ public void onResourceDoesNotExist(final String resourceName) { * VirtualHost-level configuration for request routing. */ private static class RoutingConfig { - private final long fallbackTimeoutNano; - final List routes; - // Null if HttpFilter is not supported. - @Nullable final List filterChain; - final Map virtualHostOverrideConfig; - - private static RoutingConfig empty = new RoutingConfig( - 0, Collections.emptyList(), null, Collections.emptyMap()); + final XdsConfig xdsConfig; + final long fallbackTimeoutNano; + final ImmutableList routes; + final Status errorStatus; private RoutingConfig( - long fallbackTimeoutNano, List routes, @Nullable List filterChain, - Map virtualHostOverrideConfig) { + XdsConfig xdsConfig, long fallbackTimeoutNano, ImmutableList routes) { + this.xdsConfig = checkNotNull(xdsConfig, "xdsConfig"); this.fallbackTimeoutNano = fallbackTimeoutNano; - this.routes = routes; - checkArgument(filterChain == null || !filterChain.isEmpty(), "filterChain is empty"); - this.filterChain = filterChain == null ? null : Collections.unmodifiableList(filterChain); - this.virtualHostOverrideConfig = Collections.unmodifiableMap(virtualHostOverrideConfig); + this.routes = checkNotNull(routes, "routes"); + this.errorStatus = null; + } + + private RoutingConfig(Status errorStatus) { + this.xdsConfig = null; + this.fallbackTimeoutNano = 0; + this.routes = null; + this.errorStatus = checkNotNull(errorStatus, "errorStatus"); + checkArgument(!errorStatus.isOk(), "errorStatus should not be okay"); + } + } + + static final class RouteData { + final RouteMatch routeMatch; + /** null implies non-forwarding action. */ + @Nullable + final RouteAction routeAction; + /** + * Only one of these interceptors should be used per-RPC. There are only multiple values in the + * list for weighted clusters, in which case the order of the list mirrors the weighted + * clusters. + */ + final ImmutableList filterChoices; + + RouteData(RouteMatch routeMatch, @Nullable RouteAction routeAction, ClientInterceptor filter) { + this(routeMatch, routeAction, ImmutableList.of(filter)); + } + + RouteData( + RouteMatch routeMatch, + @Nullable RouteAction routeAction, + ImmutableList filterChoices) { + this.routeMatch = checkNotNull(routeMatch, "routeMatch"); + checkArgument( + routeAction == null || !filterChoices.isEmpty(), + "filter may be empty only for non-forwarding action"); + this.routeAction = routeAction; + if (routeAction != null && routeAction.weightedClusters() != null) { + checkArgument( + routeAction.weightedClusters().size() == filterChoices.size(), + "filter choices must match size of weighted clusters"); + } + for (ClientInterceptor filter : filterChoices) { + checkNotNull(filter, "entry in filterChoices is null"); + } + this.filterChoices = checkNotNull(filterChoices, "filterChoices"); } } @@ -910,15 +997,18 @@ private static class ClusterRefState { final String traditionalCluster; @Nullable final RlsPluginConfig rlsPluginConfig; + @Nullable + final XdsConfig.Subscription subscription; private ClusterRefState( AtomicInteger refCount, @Nullable String traditionalCluster, - @Nullable RlsPluginConfig rlsPluginConfig) { + @Nullable RlsPluginConfig rlsPluginConfig, @Nullable XdsConfig.Subscription subscription) { this.refCount = refCount; checkArgument(traditionalCluster == null ^ rlsPluginConfig == null, "There must be exactly one non-null value in traditionalCluster and pluginConfig"); this.traditionalCluster = traditionalCluster; this.rlsPluginConfig = rlsPluginConfig; + this.subscription = subscription; } private Map toLbPolicy() { @@ -931,19 +1021,97 @@ private ClusterRefState( .put("routeLookupConfig", rlsPluginConfig.config()) .put( "childPolicy", - ImmutableList.of(ImmutableMap.of(XdsLbPolicies.CDS_POLICY_NAME, ImmutableMap.of()))) + ImmutableList.of(ImmutableMap.of(XdsLbPolicies.CDS_POLICY_NAME, ImmutableMap.of( + "is_dynamic", true)))) .put("childPolicyConfigTargetFieldName", "cluster") .buildOrThrow(); return ImmutableMap.of("rls_experimental", rlsConfig); } } - static ClusterRefState forCluster(AtomicInteger refCount, String name) { - return new ClusterRefState(refCount, name, null); + private void close() { + if (subscription != null) { + subscription.close(); + } + } + + static ClusterRefState forCluster( + AtomicInteger refCount, String name, XdsConfig.Subscription subscription) { + return new ClusterRefState(refCount, name, null, checkNotNull(subscription, "subscription")); + } + + static ClusterRefState forRlsPlugin( + AtomicInteger refCount, + RlsPluginConfig rlsPluginConfig) { + return new ClusterRefState(refCount, null, rlsPluginConfig, null); + } + } + + /** An ObjectPool, except it can throw an exception. */ + private interface XdsClientPool { + XdsClient getObject() throws XdsInitializationException; + + XdsClient returnObject(XdsClient xdsClient); + } + + private static final class BootstrappingXdsClientPool implements XdsClientPool { + private final XdsClientPoolFactory xdsClientPoolFactory; + private final String target; + private final @Nullable Map bootstrapOverride; + private final MetricRecorder metricRecorder; + private ObjectPool xdsClientPool; + + BootstrappingXdsClientPool( + XdsClientPoolFactory xdsClientPoolFactory, + String target, + @Nullable Map bootstrapOverride, + MetricRecorder metricRecorder) { + this.xdsClientPoolFactory = checkNotNull(xdsClientPoolFactory, "xdsClientPoolFactory"); + this.target = checkNotNull(target, "target"); + this.bootstrapOverride = bootstrapOverride; + this.metricRecorder = checkNotNull(metricRecorder, "metricRecorder"); + } + + @Override + public XdsClient getObject() throws XdsInitializationException { + if (xdsClientPool == null) { + BootstrapInfo bootstrapInfo; + if (bootstrapOverride == null) { + bootstrapInfo = GrpcBootstrapperImpl.defaultBootstrap(); + } else { + bootstrapInfo = new GrpcBootstrapperImpl().bootstrap(bootstrapOverride); + } + this.xdsClientPool = + xdsClientPoolFactory.getOrCreate(target, bootstrapInfo, metricRecorder); + } + return xdsClientPool.getObject(); + } + + @Override + public XdsClient returnObject(XdsClient xdsClient) { + return xdsClientPool.returnObject(xdsClient); } + } - static ClusterRefState forRlsPlugin(AtomicInteger refCount, RlsPluginConfig rlsPluginConfig) { - return new ClusterRefState(refCount, null, rlsPluginConfig); + private static final class SupplierXdsClientPool implements XdsClientPool { + private final Supplier xdsClientSupplier; + + SupplierXdsClientPool(Supplier xdsClientSupplier) { + this.xdsClientSupplier = checkNotNull(xdsClientSupplier, "xdsClientSupplier"); + } + + @Override + public XdsClient getObject() throws XdsInitializationException { + XdsClient xdsClient = xdsClientSupplier.get(); + if (xdsClient == null) { + throw new XdsInitializationException("Caller failed to initialize XDS_CLIENT_SUPPLIER"); + } + return xdsClient; + } + + @Override + public XdsClient returnObject(XdsClient xdsClient) { + return null; } } } diff --git a/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java b/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java index 598be07fcd8..51b1ff49bf0 100644 --- a/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java +++ b/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java @@ -22,6 +22,8 @@ import io.grpc.Internal; import io.grpc.NameResolver.Args; import io.grpc.NameResolverProvider; +import io.grpc.Uri; +import io.grpc.xds.client.XdsClient; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.net.URI; @@ -29,6 +31,7 @@ import java.util.Collections; import java.util.Map; import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Supplier; import javax.annotation.Nullable; /** @@ -43,6 +46,13 @@ */ @Internal public final class XdsNameResolverProvider extends NameResolverProvider { + /** + * If provided, the suppler must return non-null when lb.start() is called (which implies not + * throwing), and the XdsClient must remain alive until lb.shutdown() returns. It may only be + * called from the synchronization context. + */ + public static final Args.Key> XDS_CLIENT_SUPPLIER = + Args.Key.create("io.grpc.xds.XdsNameResolverProvider.XDS_CLIENT_SUPPLIER"); private static final String SCHEME = "xds"; private final String scheme; @@ -77,15 +87,43 @@ public XdsNameResolver newNameResolver(URI targetUri, Args args) { targetPath, targetUri); String name = targetPath.substring(1); - return new XdsNameResolver( - targetUri.getAuthority(), name, args.getOverrideAuthority(), - args.getServiceConfigParser(), args.getSynchronizationContext(), - args.getScheduledExecutorService(), - bootstrapOverride); + // TODO(jdcormie): java.net.URI#getAuthority incorrectly returns null for both xds:///service + // and xds:/service. This doesn't matter for now since XdsNameResolver treats them the same + // anyway and all this code will go away once newNameResolver(io.grpc.Uri) launches. + String targetAuthority = targetUri.getAuthority(); + return newNameResolver(targetUri.toString(), targetAuthority, name, args); + } + return null; + } + + @Override + public XdsNameResolver newNameResolver(Uri targetUri, Args args) { + if (scheme.equals(targetUri.getScheme())) { + Preconditions.checkArgument( + targetUri.isPathAbsolute(), + "the path component of the target (%s) must start with '/'", + targetUri); + return newNameResolver( + targetUri.toString(), targetUri.getAuthority(), targetUri.getPath().substring(1), args); } return null; } + private XdsNameResolver newNameResolver( + String targetUri, String targetAuthority, String name, Args args) { + return new XdsNameResolver( + targetUri.toString(), + targetAuthority, + name, + args.getOverrideAuthority(), + args.getServiceConfigParser(), + args.getSynchronizationContext(), + args.getScheduledExecutorService(), + bootstrapOverride, + args.getMetricRecorder(), + args); + } + @Override public String getDefaultScheme() { return scheme; diff --git a/xds/src/main/java/io/grpc/xds/XdsRouteConfigureResource.java b/xds/src/main/java/io/grpc/xds/XdsRouteConfigureResource.java index 22e65334390..730d301c3ec 100644 --- a/xds/src/main/java/io/grpc/xds/XdsRouteConfigureResource.java +++ b/xds/src/main/java/io/grpc/xds/XdsRouteConfigureResource.java @@ -36,8 +36,8 @@ import io.envoyproxy.envoy.config.route.v3.ClusterSpecifierPlugin; import io.envoyproxy.envoy.config.route.v3.RetryPolicy.RetryBackOff; import io.envoyproxy.envoy.config.route.v3.RouteConfiguration; -import io.envoyproxy.envoy.type.v3.FractionalPercent; import io.grpc.Status; +import io.grpc.internal.GrpcUtil; import io.grpc.xds.ClusterSpecifierPlugin.NamedPluginConfig; import io.grpc.xds.ClusterSpecifierPlugin.PluginConfig; import io.grpc.xds.Filter.FilterConfig; @@ -67,10 +67,18 @@ import javax.annotation.Nullable; class XdsRouteConfigureResource extends XdsResourceType { + + private static final boolean isXdsAuthorityRewriteEnabled = GrpcUtil.getFlag( + "GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE", true); + @VisibleForTesting + static boolean enableRouteLookup = GrpcUtil.getFlag("GRPC_EXPERIMENTAL_XDS_RLS_LB", true); + static final String ADS_TYPE_URL_RDS = "type.googleapis.com/envoy.config.route.v3.RouteConfiguration"; private static final String TYPE_URL_FILTER_CONFIG = "type.googleapis.com/envoy.config.route.v3.FilterConfig"; + @VisibleForTesting + static final String HASH_POLICY_FILTER_STATE_KEY = "io.grpc.channel_id"; // TODO(zdapeng): need to discuss how to handle unsupported values. private static final Set SUPPORTED_RETRYABLE_CODES = Collections.unmodifiableSet(EnumSet.of( @@ -124,17 +132,17 @@ protected RdsUpdate doParse(XdsResourceType.Args args, Message unpackedMessage) throw new ResourceInvalidException("Invalid message type: " + unpackedMessage.getClass()); } return processRouteConfiguration( - (RouteConfiguration) unpackedMessage, FilterRegistry.getDefaultRegistry()); + (RouteConfiguration) unpackedMessage, FilterRegistry.getDefaultRegistry(), args); } private static RdsUpdate processRouteConfiguration( - RouteConfiguration routeConfig, FilterRegistry filterRegistry) + RouteConfiguration routeConfig, FilterRegistry filterRegistry, XdsResourceType.Args args) throws ResourceInvalidException { - return new RdsUpdate(extractVirtualHosts(routeConfig, filterRegistry)); + return new RdsUpdate(extractVirtualHosts(routeConfig, filterRegistry, args)); } static List extractVirtualHosts( - RouteConfiguration routeConfig, FilterRegistry filterRegistry) + RouteConfiguration routeConfig, FilterRegistry filterRegistry, XdsResourceType.Args args) throws ResourceInvalidException { Map pluginConfigMap = new HashMap<>(); ImmutableSet.Builder optionalPlugins = ImmutableSet.builder(); @@ -160,7 +168,7 @@ static List extractVirtualHosts( : routeConfig.getVirtualHostsList()) { StructOrError virtualHost = parseVirtualHost(virtualHostProto, filterRegistry, pluginConfigMap, - optionalPlugins.build()); + optionalPlugins.build(), args); if (virtualHost.getErrorDetail() != null) { throw new ResourceInvalidException( "RouteConfiguration contains invalid virtual host: " + virtualHost.getErrorDetail()); @@ -173,12 +181,12 @@ static List extractVirtualHosts( private static StructOrError parseVirtualHost( io.envoyproxy.envoy.config.route.v3.VirtualHost proto, FilterRegistry filterRegistry, Map pluginConfigMap, - Set optionalPlugins) { + Set optionalPlugins, XdsResourceType.Args args) { String name = proto.getName(); List routes = new ArrayList<>(proto.getRoutesCount()); for (io.envoyproxy.envoy.config.route.v3.Route routeProto : proto.getRoutesList()) { StructOrError route = parseRoute( - routeProto, filterRegistry, pluginConfigMap, optionalPlugins); + routeProto, filterRegistry, pluginConfigMap, optionalPlugins, args); if (route == null) { continue; } @@ -189,7 +197,7 @@ private static StructOrError parseVirtualHost( routes.add(route.getStruct()); } StructOrError> overrideConfigs = - parseOverrideFilterConfigs(proto.getTypedPerFilterConfigMap(), filterRegistry); + parseOverrideFilterConfigs(proto.getTypedPerFilterConfigMap(), filterRegistry, args); if (overrideConfigs.getErrorDetail() != null) { return StructOrError.fromError( "VirtualHost [" + proto.getName() + "] contains invalid HttpFilter config: " @@ -201,7 +209,12 @@ private static StructOrError parseVirtualHost( @VisibleForTesting static StructOrError> parseOverrideFilterConfigs( - Map rawFilterConfigMap, FilterRegistry filterRegistry) { + Map rawFilterConfigMap, FilterRegistry filterRegistry, + XdsResourceType.Args args) { + Filter.FilterConfigParseContext context = Filter.FilterConfigParseContext.builder() + .bootstrapInfo(args.getBootstrapInfo()) + .serverInfo(args.getServerInfo()) + .build(); Map overrideConfigs = new HashMap<>(); for (String name : rawFilterConfigMap.keySet()) { Any anyConfig = rawFilterConfigMap.get(name); @@ -236,8 +249,8 @@ static StructOrError> parseOverrideFilterConfigs( return StructOrError.fromError( "FilterConfig [" + name + "] contains invalid proto: " + e); } - Filter filter = filterRegistry.get(typeUrl); - if (filter == null) { + Filter.Provider provider = filterRegistry.get(typeUrl); + if (provider == null) { if (isOptional) { continue; } @@ -245,7 +258,7 @@ static StructOrError> parseOverrideFilterConfigs( "HttpFilter [" + name + "](" + typeUrl + ") is required but unsupported"); } ConfigOrError filterConfig = - filter.parseFilterConfigOverride(rawConfig); + provider.parseFilterConfigOverride(rawConfig, context); if (filterConfig.errorDetail != null) { return StructOrError.fromError( "Invalid filter config for HttpFilter [" + name + "]: " + filterConfig.errorDetail); @@ -260,7 +273,7 @@ static StructOrError> parseOverrideFilterConfigs( static StructOrError parseRoute( io.envoyproxy.envoy.config.route.v3.Route proto, FilterRegistry filterRegistry, Map pluginConfigMap, - Set optionalPlugins) { + Set optionalPlugins, XdsResourceType.Args args) { StructOrError routeMatch = parseRouteMatch(proto.getMatch()); if (routeMatch == null) { return null; @@ -272,7 +285,7 @@ static StructOrError parseRoute( } StructOrError> overrideConfigsOrError = - parseOverrideFilterConfigs(proto.getTypedPerFilterConfigMap(), filterRegistry); + parseOverrideFilterConfigs(proto.getTypedPerFilterConfigMap(), filterRegistry, args); if (overrideConfigsOrError.getErrorDetail() != null) { return StructOrError.fromError( "Route [" + proto.getName() + "] contains invalid HttpFilter config: " @@ -284,7 +297,7 @@ static StructOrError parseRoute( case ROUTE: StructOrError routeAction = parseRouteAction(proto.getRoute(), filterRegistry, pluginConfigMap, - optionalPlugins); + optionalPlugins, args); if (routeAction == null) { return null; } @@ -322,12 +335,12 @@ static StructOrError parseRouteMatch( FractionMatcher fractionMatch = null; if (proto.hasRuntimeFraction()) { - StructOrError parsedFraction = - parseFractionMatcher(proto.getRuntimeFraction().getDefaultValue()); - if (parsedFraction.getErrorDetail() != null) { - return StructOrError.fromError(parsedFraction.getErrorDetail()); + try { + fractionMatch = + MatcherParser.parseFractionMatcher(proto.getRuntimeFraction().getDefaultValue()); + } catch (IllegalArgumentException e) { + return StructOrError.fromError(e.getMessage()); } - fractionMatch = parsedFraction.getStruct(); } List headerMatchers = new ArrayList<>(); @@ -368,26 +381,7 @@ static StructOrError parsePathMatcher( } } - private static StructOrError parseFractionMatcher(FractionalPercent proto) { - int numerator = proto.getNumerator(); - int denominator = 0; - switch (proto.getDenominator()) { - case HUNDRED: - denominator = 100; - break; - case TEN_THOUSAND: - denominator = 10_000; - break; - case MILLION: - denominator = 1_000_000; - break; - case UNRECOGNIZED: - default: - return StructOrError.fromError( - "Unrecognized fractional percent denominator: " + proto.getDenominator()); - } - return StructOrError.fromStruct(FractionMatcher.create(numerator, denominator)); - } + @VisibleForTesting static StructOrError parseHeaderMatcher( @@ -410,7 +404,7 @@ static StructOrError parseHeaderMatcher( static StructOrError parseRouteAction( io.envoyproxy.envoy.config.route.v3.RouteAction proto, FilterRegistry filterRegistry, Map pluginConfigMap, - Set optionalPlugins) { + Set optionalPlugins, XdsResourceType.Args args) { Long timeoutNano = null; if (proto.hasMaxStreamDuration()) { io.envoyproxy.envoy.config.route.v3.RouteAction.MaxStreamDuration maxStreamDuration @@ -442,8 +436,7 @@ static StructOrError parseRouteAction( config.getHeader(); Pattern regEx = null; String regExSubstitute = null; - if (headerCfg.hasRegexRewrite() && headerCfg.getRegexRewrite().hasPattern() - && headerCfg.getRegexRewrite().getPattern().hasGoogleRe2()) { + if (headerCfg.hasRegexRewrite() && headerCfg.getRegexRewrite().hasPattern()) { regEx = Pattern.compile(headerCfg.getRegexRewrite().getPattern().getRegex()); regExSubstitute = headerCfg.getRegexRewrite().getSubstitution(); } @@ -466,7 +459,9 @@ static StructOrError parseRouteAction( switch (proto.getClusterSpecifierCase()) { case CLUSTER: return StructOrError.fromStruct(RouteAction.forCluster( - proto.getCluster(), hashPolicies, timeoutNano, retryPolicy)); + proto.getCluster(), hashPolicies, timeoutNano, retryPolicy, + isXdsAuthorityRewriteEnabled && args.getServerInfo().isTrustedXdsServer() + && proto.getAutoHostRewrite().getValue())); case CLUSTER_HEADER: return null; case WEIGHTED_CLUSTERS: @@ -480,13 +475,14 @@ static StructOrError parseRouteAction( for (io.envoyproxy.envoy.config.route.v3.WeightedCluster.ClusterWeight clusterWeight : clusterWeights) { StructOrError clusterWeightOrError = - parseClusterWeight(clusterWeight, filterRegistry); + parseClusterWeight(clusterWeight, filterRegistry, args); if (clusterWeightOrError.getErrorDetail() != null) { return StructOrError.fromError("RouteAction contains invalid ClusterWeight: " + clusterWeightOrError.getErrorDetail()); } - clusterWeightSum += clusterWeight.getWeight().getValue(); - weightedClusters.add(clusterWeightOrError.getStruct()); + ClusterWeight parsedWeight = clusterWeightOrError.getStruct(); + clusterWeightSum += parsedWeight.weight(); + weightedClusters.add(parsedWeight); } if (clusterWeightSum <= 0) { return StructOrError.fromError("Sum of cluster weights should be above 0."); @@ -498,7 +494,9 @@ static StructOrError parseRouteAction( UnsignedInteger.MAX_VALUE.longValue(), clusterWeightSum)); } return StructOrError.fromStruct(VirtualHost.Route.RouteAction.forWeightedClusters( - weightedClusters, hashPolicies, timeoutNano, retryPolicy)); + weightedClusters, hashPolicies, timeoutNano, retryPolicy, + isXdsAuthorityRewriteEnabled && args.getServerInfo().isTrustedXdsServer() + && proto.getAutoHostRewrite().getValue())); case CLUSTER_SPECIFIER_PLUGIN: if (enableRouteLookup) { String pluginName = proto.getClusterSpecifierPlugin(); @@ -513,7 +511,9 @@ static StructOrError parseRouteAction( } NamedPluginConfig namedPluginConfig = NamedPluginConfig.create(pluginName, pluginConfig); return StructOrError.fromStruct(VirtualHost.Route.RouteAction.forClusterSpecifierPlugin( - namedPluginConfig, hashPolicies, timeoutNano, retryPolicy)); + namedPluginConfig, hashPolicies, timeoutNano, retryPolicy, + isXdsAuthorityRewriteEnabled && args.getServerInfo().isTrustedXdsServer() + && proto.getAutoHostRewrite().getValue())); } else { return null; } @@ -584,16 +584,18 @@ private static StructOrError parseRet @VisibleForTesting static StructOrError parseClusterWeight( io.envoyproxy.envoy.config.route.v3.WeightedCluster.ClusterWeight proto, - FilterRegistry filterRegistry) { + FilterRegistry filterRegistry, XdsResourceType.Args args) { StructOrError> overrideConfigs = - parseOverrideFilterConfigs(proto.getTypedPerFilterConfigMap(), filterRegistry); + parseOverrideFilterConfigs(proto.getTypedPerFilterConfigMap(), filterRegistry, args); if (overrideConfigs.getErrorDetail() != null) { return StructOrError.fromError( "ClusterWeight [" + proto.getName() + "] contains invalid HttpFilter config: " + overrideConfigs.getErrorDetail()); } return StructOrError.fromStruct(VirtualHost.Route.RouteAction.ClusterWeight.create( - proto.getName(), proto.getWeight().getValue(), overrideConfigs.getStruct())); + proto.getName(), + Integer.toUnsignedLong(proto.getWeight().getValue()), + overrideConfigs.getStruct())); } @Nullable // null if the plugin is not supported, but it's marked as optional. diff --git a/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java b/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java index b75d5755f6e..4a4fb71aa84 100644 --- a/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java +++ b/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java @@ -19,8 +19,8 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; -import static io.grpc.xds.InternalXdsAttributes.ATTR_DRAIN_GRACE_NANOS; -import static io.grpc.xds.InternalXdsAttributes.ATTR_FILTER_CHAIN_SELECTOR_MANAGER; +import static io.grpc.xds.XdsAttributes.ATTR_DRAIN_GRACE_NANOS; +import static io.grpc.xds.XdsAttributes.ATTR_FILTER_CHAIN_SELECTOR_MANAGER; import com.google.common.annotations.VisibleForTesting; import com.google.errorprone.annotations.DoNotCall; @@ -55,6 +55,7 @@ public final class XdsServerBuilder extends ForwardingServerBuilder bootstrapOverride; private long drainGraceTime = 10; private TimeUnit drainGraceTimeUnit = TimeUnit.MINUTES; @@ -127,7 +128,7 @@ public Server build() { } InternalNettyServerBuilder.eagAttributes(delegate, builder.build()); return new XdsServerWrapper("0.0.0.0:" + port, delegate, xdsServingStatusListener, - filterChainSelectorManager, xdsClientPoolFactory, filterRegistry); + filterChainSelectorManager, xdsClientPoolFactory, bootstrapOverride, filterRegistry); } @VisibleForTesting @@ -140,11 +141,10 @@ XdsServerBuilder xdsClientPoolFactory(XdsClientPoolFactory xdsClientPoolFactory) * Allows providing bootstrap override, useful for testing. */ public XdsServerBuilder overrideBootstrapForTest(Map bootstrapOverride) { - checkNotNull(bootstrapOverride, "bootstrapOverride"); + this.bootstrapOverride = checkNotNull(bootstrapOverride, "bootstrapOverride"); if (this.xdsClientPoolFactory == SharedXdsClientPoolProvider.getDefaultProvider()) { this.xdsClientPoolFactory = new SharedXdsClientPoolProvider(); } - this.xdsClientPoolFactory.setBootstrapOverride(bootstrapOverride); return this; } diff --git a/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java b/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java index bf8603fb3e4..5529f96c7a2 100644 --- a/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java +++ b/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java @@ -24,11 +24,15 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.net.HostAndPort; +import com.google.common.net.InetAddresses; import com.google.common.util.concurrent.SettableFuture; +import io.envoyproxy.envoy.config.core.v3.SocketAddress.Protocol; import io.grpc.Attributes; import io.grpc.InternalServerInterceptors; import io.grpc.Metadata; import io.grpc.MethodDescriptor; +import io.grpc.MetricRecorder; import io.grpc.Server; import io.grpc.ServerBuilder; import io.grpc.ServerCall; @@ -38,6 +42,7 @@ import io.grpc.ServerServiceDefinition; import io.grpc.Status; import io.grpc.StatusException; +import io.grpc.StatusOr; import io.grpc.SynchronizationContext; import io.grpc.SynchronizationContext.ScheduledHandle; import io.grpc.internal.GrpcUtil; @@ -46,20 +51,20 @@ import io.grpc.xds.EnvoyServerProtoData.FilterChain; import io.grpc.xds.Filter.FilterConfig; import io.grpc.xds.Filter.NamedFilterConfig; -import io.grpc.xds.Filter.ServerInterceptorBuilder; import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; import io.grpc.xds.ThreadSafeRandom.ThreadSafeRandomImpl; import io.grpc.xds.VirtualHost.Route; import io.grpc.xds.XdsListenerResource.LdsUpdate; import io.grpc.xds.XdsRouteConfigureResource.RdsUpdate; import io.grpc.xds.XdsServerBuilder.XdsServingStatusListener; +import io.grpc.xds.client.Bootstrapper.BootstrapInfo; import io.grpc.xds.client.XdsClient; import io.grpc.xds.client.XdsClient.ResourceWatcher; import io.grpc.xds.internal.security.SslContextProviderSupplier; import java.io.IOException; +import java.net.InetAddress; import java.net.SocketAddress; import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -100,6 +105,7 @@ public void uncaughtException(Thread t, Throwable e) { private final FilterRegistry filterRegistry; private final ThreadSafeRandom random = ThreadSafeRandomImpl.instance; private final XdsClientPoolFactory xdsClientPoolFactory; + private final @Nullable Map bootstrapOverride; private final XdsServingStatusListener listener; private final FilterChainSelectorManager filterChainSelectorManager; private final AtomicBoolean started = new AtomicBoolean(false); @@ -114,15 +120,31 @@ public void uncaughtException(Thread t, Throwable e) { private DiscoveryState discoveryState; private volatile Server delegate; + // Must be accessed in syncContext. + // Filter instances are unique per Server, per FilterChain, and per filter's name+typeUrl. + // FilterChain.name -> filter_instance>. + private final HashMap> activeFilters = new HashMap<>(); + // Default filter chain Filter instances are unique per Server, and per filter's name+typeUrl. + // NamedFilterConfig.filterStateKey -> filter_instance. + private final HashMap activeFiltersDefaultChain = new HashMap<>(); + XdsServerWrapper( String listenerAddress, ServerBuilder delegateBuilder, XdsServingStatusListener listener, FilterChainSelectorManager filterChainSelectorManager, XdsClientPoolFactory xdsClientPoolFactory, + @Nullable Map bootstrapOverride, FilterRegistry filterRegistry) { - this(listenerAddress, delegateBuilder, listener, filterChainSelectorManager, - xdsClientPoolFactory, filterRegistry, SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE)); + this( + listenerAddress, + delegateBuilder, + listener, + filterChainSelectorManager, + xdsClientPoolFactory, + bootstrapOverride, + filterRegistry, + SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE)); sharedTimeService = true; } @@ -133,6 +155,7 @@ public void uncaughtException(Thread t, Throwable e) { XdsServingStatusListener listener, FilterChainSelectorManager filterChainSelectorManager, XdsClientPoolFactory xdsClientPoolFactory, + @Nullable Map bootstrapOverride, FilterRegistry filterRegistry, ScheduledExecutorService timeService) { this.listenerAddress = checkNotNull(listenerAddress, "listenerAddress"); @@ -142,6 +165,7 @@ public void uncaughtException(Thread t, Throwable e) { this.filterChainSelectorManager = checkNotNull(filterChainSelectorManager, "filterChainSelectorManager"); this.xdsClientPoolFactory = checkNotNull(xdsClientPoolFactory, "xdsClientPoolFactory"); + this.bootstrapOverride = bootstrapOverride; this.timeService = checkNotNull(timeService, "timeService"); this.filterRegistry = checkNotNull(filterRegistry,"filterRegistry"); this.delegate = delegateBuilder.build(); @@ -171,7 +195,14 @@ public void run() { private void internalStart() { try { - xdsClientPool = xdsClientPoolFactory.getOrCreate(); + BootstrapInfo bootstrapInfo; + if (bootstrapOverride == null) { + bootstrapInfo = GrpcBootstrapperImpl.defaultBootstrap(); + } else { + bootstrapInfo = new GrpcBootstrapperImpl().bootstrap(bootstrapOverride); + } + xdsClientPool = xdsClientPoolFactory.getOrCreate( + "#server", bootstrapInfo, new MetricRecorder() {}); } catch (Exception e) { StatusException statusException = Status.UNAVAILABLE.withDescription( "Failed to initialize xDS").withCause(e).asException(); @@ -371,25 +402,55 @@ private DiscoveryState(String resourceName) { } @Override - public void onChanged(final LdsUpdate update) { + public void onResourceChanged(final StatusOr update) { if (stopped) { return; } - logger.log(Level.FINEST, "Received Lds update {0}", update); - checkNotNull(update.listener(), "update"); + + if (!update.hasValue()) { + Status status = update.getStatus(); + StatusException statusException = Status.UNAVAILABLE.withDescription( + String.format("Listener %s unavailable: %s", resourceName, status.getDescription())) + .withCause(status.asException()) + .asException(); + handleConfigNotFoundOrMismatch(statusException); + return; + } + + final LdsUpdate ldsUpdate = update.getValue(); + logger.log(Level.FINEST, "Received Lds update {0}", ldsUpdate); + if (ldsUpdate.listener() == null) { + handleConfigNotFoundOrMismatch( + Status.NOT_FOUND.withDescription("Listener is null in LdsUpdate").asException()); + return; + } + String ldsAddress = ldsUpdate.listener().address(); + if (ldsAddress == null || ldsUpdate.listener().protocol() != Protocol.TCP + || !ipAddressesMatch(ldsAddress)) { + handleConfigNotFoundOrMismatch( + Status.UNKNOWN.withDescription( + String.format( + "Listener address mismatch: expected %s, but got %s.", + listenerAddress, ldsAddress)).asException()); + return; + } + if (!pendingRds.isEmpty()) { // filter chain state has not yet been applied to filterChainSelectorManager and there - // are two sets of sslContextProviderSuppliers, so we release the old ones. releaseSuppliersInFlight(); pendingRds.clear(); } - filterChains = update.listener().filterChains(); - defaultFilterChain = update.listener().defaultFilterChain(); + + filterChains = ldsUpdate.listener().filterChains(); + defaultFilterChain = ldsUpdate.listener().defaultFilterChain(); + updateActiveFilters(); + List allFilterChains = filterChains; if (defaultFilterChain != null) { allFilterChains = new ArrayList<>(filterChains); allFilterChains.add(defaultFilterChain); } + Set allRds = new HashSet<>(); for (FilterChain filterChain : allFilterChains) { HttpConnectionManager hcm = filterChain.httpConnectionManager(); @@ -407,6 +468,7 @@ public void onChanged(final LdsUpdate update) { allRds.add(hcm.rdsName()); } } + for (Map.Entry entry: routeDiscoveryStates.entrySet()) { if (!allRds.contains(entry.getKey())) { xdsClient.cancelXdsResourceWatch(XdsRouteConfigureResource.getInstance(), @@ -420,31 +482,38 @@ public void onChanged(final LdsUpdate update) { } @Override - public void onResourceDoesNotExist(final String resourceName) { + public void onAmbientError(final Status error) { if (stopped) { return; } - StatusException statusException = Status.UNAVAILABLE.withDescription( - "Listener " + resourceName + " unavailable").asException(); - handleConfigNotFound(statusException); - } + String description = error.getDescription() == null ? "" : error.getDescription() + " "; + Status errorWithNodeId = error.withDescription( + description + "xDS node ID: " + xdsClient.getBootstrapInfo().node().getId()); + logger.log(Level.FINE, "Error from XdsClient", errorWithNodeId); - @Override - public void onError(final Status error) { - if (stopped) { - return; - } - logger.log(Level.FINE, "Error from XdsClient", error); if (!isServing) { - listener.onNotServing(error.asException()); + listener.onNotServing(errorWithNodeId.asException()); } } + private boolean ipAddressesMatch(String ldsAddress) { + HostAndPort ldsAddressHnP = HostAndPort.fromString(ldsAddress); + HostAndPort listenerAddressHnP = HostAndPort.fromString(listenerAddress); + if (!ldsAddressHnP.hasPort() || !listenerAddressHnP.hasPort() + || ldsAddressHnP.getPort() != listenerAddressHnP.getPort()) { + return false; + } + InetAddress listenerIp = InetAddresses.forString(listenerAddressHnP.getHost()); + InetAddress ldsIp = InetAddresses.forString(ldsAddressHnP.getHost()); + return listenerIp.equals(ldsIp); + } + private void shutdown() { stopped = true; cleanUpRouteDiscoveryStates(); logger.log(Level.FINE, "Stop watching LDS resource {0}", resourceName); xdsClient.cancelXdsResourceWatch(XdsListenerResource.getInstance(), resourceName, this); + shutdownActiveFilters(); List toRelease = getSuppliersInUse(); filterChainSelectorManager.updateSelector(FilterChainSelector.NO_FILTER_CHAIN); for (SslContextProviderSupplier s: toRelease) { @@ -454,81 +523,184 @@ private void shutdown() { } private void updateSelector() { - Map> filterChainRouting = new HashMap<>(); + // This is regenerated in generateRoutingConfig() calls below. savedRdsRoutingConfigRef.clear(); + + // Prepare server routing config map. + ImmutableMap.Builder> routingConfigs = + ImmutableMap.builder(); for (FilterChain filterChain: filterChains) { - filterChainRouting.put(filterChain, generateRoutingConfig(filterChain)); + HashMap chainFilters = activeFilters.get(filterChain.name()); + routingConfigs.put(filterChain, generateRoutingConfig(filterChain, chainFilters)); } - FilterChainSelector selector = new FilterChainSelector( - Collections.unmodifiableMap(filterChainRouting), - defaultFilterChain == null ? null : defaultFilterChain.sslContextProviderSupplier(), - defaultFilterChain == null ? new AtomicReference() : - generateRoutingConfig(defaultFilterChain)); - List toRelease = getSuppliersInUse(); + + // Prepare the new selector. + FilterChainSelector selector; + if (defaultFilterChain != null) { + selector = new FilterChainSelector( + routingConfigs.build(), + defaultFilterChain.sslContextProviderSupplier(), + generateRoutingConfig(defaultFilterChain, activeFiltersDefaultChain)); + } else { + selector = new FilterChainSelector(routingConfigs.build()); + } + + // Prepare the list of current selector's resources to close later. + List oldSslSuppliers = getSuppliersInUse(); + + // Swap the selectors, initiate a graceful shutdown of the old one. logger.log(Level.FINEST, "Updating selector {0}", selector); filterChainSelectorManager.updateSelector(selector); - for (SslContextProviderSupplier e: toRelease) { - e.close(); + + // Release old resources. + for (SslContextProviderSupplier supplier: oldSslSuppliers) { + supplier.close(); } + + // Now that we have valid Transport Socket config, we can start/restart listening on a port. startDelegateServer(); } - private AtomicReference generateRoutingConfig(FilterChain filterChain) { + // called in syncContext + private void updateActiveFilters() { + Set removedChains = new HashSet<>(activeFilters.keySet()); + for (FilterChain filterChain: filterChains) { + removedChains.remove(filterChain.name()); + updateActiveFiltersForChain( + activeFilters.computeIfAbsent(filterChain.name(), k -> new HashMap<>()), + filterChain.httpConnectionManager().httpFilterConfigs()); + } + + // Shutdown all filters of chains missing from the LDS. + for (String chainToShutdown : removedChains) { + HashMap filtersToShutdown = activeFilters.get(chainToShutdown); + checkNotNull(filtersToShutdown, "filtersToShutdown of chain %s", chainToShutdown); + updateActiveFiltersForChain(filtersToShutdown, null); + activeFilters.remove(chainToShutdown); + } + + // Default chain. + ImmutableList defaultChainConfigs = null; + if (defaultFilterChain != null) { + defaultChainConfigs = defaultFilterChain.httpConnectionManager().httpFilterConfigs(); + } + updateActiveFiltersForChain(activeFiltersDefaultChain, defaultChainConfigs); + } + + // called in syncContext + private void shutdownActiveFilters() { + for (HashMap chainFilters : activeFilters.values()) { + checkNotNull(chainFilters, "chainFilters"); + updateActiveFiltersForChain(chainFilters, null); + } + activeFilters.clear(); + updateActiveFiltersForChain(activeFiltersDefaultChain, null); + } + + // called in syncContext + private void updateActiveFiltersForChain( + Map chainFilters, @Nullable List filterConfigs) { + if (filterConfigs == null) { + filterConfigs = ImmutableList.of(); + } + + Set filtersToShutdown = new HashSet<>(chainFilters.keySet()); + for (NamedFilterConfig namedFilter : filterConfigs) { + String typeUrl = namedFilter.filterConfig.typeUrl(); + String filterKey = namedFilter.filterStateKey(); + + Filter.Provider provider = filterRegistry.get(typeUrl); + checkNotNull(provider, "provider %s", typeUrl); + Filter filter = chainFilters.computeIfAbsent( + filterKey, k -> provider.newInstance(namedFilter.name)); + checkNotNull(filter, "filter %s", filterKey); + filtersToShutdown.remove(filterKey); + } + + // Shutdown filters not present in current HCM. + for (String filterKey : filtersToShutdown) { + Filter filterToShutdown = chainFilters.remove(filterKey); + checkNotNull(filterToShutdown, "filterToShutdown %s", filterKey); + filterToShutdown.close(); + } + } + + private AtomicReference generateRoutingConfig( + FilterChain filterChain, Map chainFilters) { HttpConnectionManager hcm = filterChain.httpConnectionManager(); - if (hcm.virtualHosts() != null) { - ImmutableMap interceptors = generatePerRouteInterceptors( - hcm.httpFilterConfigs(), hcm.virtualHosts()); - return new AtomicReference<>(ServerRoutingConfig.create(hcm.virtualHosts(),interceptors)); + ServerRoutingConfig routingConfig; + + // Inlined routes. + ImmutableList vhosts = hcm.virtualHosts(); + if (vhosts != null) { + routingConfig = ServerRoutingConfig.create(vhosts, + generatePerRouteInterceptors(hcm.httpFilterConfigs(), vhosts, chainFilters)); + return new AtomicReference<>(routingConfig); + } + + // Routes from RDS. + RouteDiscoveryState rds = routeDiscoveryStates.get(hcm.rdsName()); + checkNotNull(rds, "rds"); + + ImmutableList savedVhosts = rds.savedVirtualHosts; + if (savedVhosts != null) { + routingConfig = ServerRoutingConfig.create(savedVhosts, + generatePerRouteInterceptors(hcm.httpFilterConfigs(), savedVhosts, chainFilters)); } else { - RouteDiscoveryState rds = routeDiscoveryStates.get(hcm.rdsName()); - checkNotNull(rds, "rds"); - AtomicReference serverRoutingConfigRef = new AtomicReference<>(); - if (rds.savedVirtualHosts != null) { - ImmutableMap interceptors = generatePerRouteInterceptors( - hcm.httpFilterConfigs(), rds.savedVirtualHosts); - ServerRoutingConfig serverRoutingConfig = - ServerRoutingConfig.create(rds.savedVirtualHosts, interceptors); - serverRoutingConfigRef.set(serverRoutingConfig); - } else { - serverRoutingConfigRef.set(ServerRoutingConfig.FAILING_ROUTING_CONFIG); - } - savedRdsRoutingConfigRef.put(filterChain, serverRoutingConfigRef); - return serverRoutingConfigRef; + routingConfig = ServerRoutingConfig.FAILING_ROUTING_CONFIG; } + AtomicReference routingConfigRef = new AtomicReference<>(routingConfig); + savedRdsRoutingConfigRef.put(filterChain, routingConfigRef); + return routingConfigRef; } private ImmutableMap generatePerRouteInterceptors( - List namedFilterConfigs, List virtualHosts) { + @Nullable List filterConfigs, + List virtualHosts, + Map chainFilters) { + syncContext.throwIfNotInThisSynchronizationContext(); + + checkNotNull(chainFilters, "chainFilters"); ImmutableMap.Builder perRouteInterceptors = new ImmutableMap.Builder<>(); + for (VirtualHost virtualHost : virtualHosts) { for (Route route : virtualHost.routes()) { - List filterInterceptors = new ArrayList<>(); - Map selectedOverrideConfigs = - new HashMap<>(virtualHost.filterConfigOverrides()); - selectedOverrideConfigs.putAll(route.filterConfigOverrides()); - if (namedFilterConfigs != null) { - for (NamedFilterConfig namedFilterConfig : namedFilterConfigs) { - FilterConfig filterConfig = namedFilterConfig.filterConfig; - Filter filter = filterRegistry.get(filterConfig.typeUrl()); - if (filter instanceof ServerInterceptorBuilder) { - ServerInterceptor interceptor = - ((ServerInterceptorBuilder) filter).buildServerInterceptor( - filterConfig, selectedOverrideConfigs.get(namedFilterConfig.name)); - if (interceptor != null) { - filterInterceptors.add(interceptor); - } - } else { - logger.log(Level.WARNING, "HttpFilterConfig(type URL: " - + filterConfig.typeUrl() + ") is not supported on server-side. " - + "Probably a bug at ClientXdsClient verification."); - } + // Short circuit. + if (filterConfigs == null) { + perRouteInterceptors.put(route, noopInterceptor); + continue; + } + + // Override vhost filter configs with more specific per-route configs. + Map perRouteOverrides = ImmutableMap.builder() + .putAll(virtualHost.filterConfigOverrides()) + .putAll(route.filterConfigOverrides()) + .buildKeepingLast(); + + // Interceptors for this vhost/route combo. + List interceptors = new ArrayList<>(filterConfigs.size()); + for (NamedFilterConfig namedFilter : filterConfigs) { + String name = namedFilter.name; + FilterConfig config = namedFilter.filterConfig; + FilterConfig overrideConfig = perRouteOverrides.get(name); + String filterKey = namedFilter.filterStateKey(); + + Filter filter = chainFilters.get(filterKey); + checkNotNull(filter, "chainFilters.get(%s)", filterKey); + ServerInterceptor interceptor = filter.buildServerInterceptor(config, overrideConfig); + + if (interceptor != null) { + interceptors.add(interceptor); } } - ServerInterceptor interceptor = combineInterceptors(filterInterceptors); - perRouteInterceptors.put(route, interceptor); + + // Combine interceptors produced by different filters into a single one that executes + // them sequentially. The order is preserved. + perRouteInterceptors.put(route, combineInterceptors(interceptors)); } } + return perRouteInterceptors.buildOrThrow(); } @@ -553,8 +725,9 @@ public Listener interceptCall(ServerCall call, }; } - private void handleConfigNotFound(StatusException exception) { + private void handleConfigNotFoundOrMismatch(StatusException exception) { cleanUpRouteDiscoveryStates(); + shutdownActiveFilters(); List toRelease = getSuppliersInUse(); filterChainSelectorManager.updateSelector(FilterChainSelector.NO_FILTER_CHAIN); for (SslContextProviderSupplier s: toRelease) { @@ -623,72 +796,65 @@ private RouteDiscoveryState(String resourceName) { } @Override - public void onChanged(final RdsUpdate update) { - syncContext.execute(new Runnable() { - @Override - public void run() { - if (!routeDiscoveryStates.containsKey(resourceName)) { - return; - } - if (savedVirtualHosts == null && !isPending) { - logger.log(Level.WARNING, "Received valid Rds {0} configuration.", resourceName); - } - savedVirtualHosts = ImmutableList.copyOf(update.virtualHosts); - updateRdsRoutingConfig(); - maybeUpdateSelector(); + public void onResourceChanged(final StatusOr update) { + syncContext.execute(() -> { + if (!routeDiscoveryStates.containsKey(resourceName)) { + return; // Watcher has been cancelled. } - }); - } - @Override - public void onResourceDoesNotExist(final String resourceName) { - syncContext.execute(new Runnable() { - @Override - public void run() { - if (!routeDiscoveryStates.containsKey(resourceName)) { - return; + if (update.hasValue()) { + if (savedVirtualHosts == null && !isPending) { + logger.log(Level.WARNING, "Received valid Rds {0} configuration.", resourceName); } - logger.log(Level.WARNING, "Rds {0} unavailable", resourceName); + savedVirtualHosts = ImmutableList.copyOf(update.getValue().virtualHosts); + } else { + logger.log(Level.WARNING, "Rds {0} unavailable: {1}", + new Object[]{resourceName, update.getStatus()}); savedVirtualHosts = null; - updateRdsRoutingConfig(); - maybeUpdateSelector(); } + // In both cases, a change has occurred that requires a config update. + updateRdsRoutingConfig(); + maybeUpdateSelector(); }); } @Override - public void onError(final Status error) { - syncContext.execute(new Runnable() { - @Override - public void run() { - if (!routeDiscoveryStates.containsKey(resourceName)) { - return; - } - logger.log(Level.WARNING, "Error loading RDS resource {0} from XdsClient: {1}.", - new Object[]{resourceName, error}); - maybeUpdateSelector(); + public void onAmbientError(final Status error) { + syncContext.execute(() -> { + if (!routeDiscoveryStates.containsKey(resourceName)) { + return; // Watcher has been cancelled. } + String description = error.getDescription() == null ? "" : error.getDescription() + " "; + Status errorWithNodeId = error.withDescription( + description + "xDS node ID: " + xdsClient.getBootstrapInfo().node().getId()); + logger.log(Level.WARNING, "Error loading RDS resource {0} from XdsClient: {1}.", + new Object[]{resourceName, errorWithNodeId}); + + // Per gRFC A88, ambient errors should not trigger a configuration change. + // Therefore, we do NOT call maybeUpdateSelector() here. }); } private void updateRdsRoutingConfig() { for (FilterChain filterChain : savedRdsRoutingConfigRef.keySet()) { - if (resourceName.equals(filterChain.httpConnectionManager().rdsName())) { - ServerRoutingConfig updatedRoutingConfig; - if (savedVirtualHosts == null) { - updatedRoutingConfig = ServerRoutingConfig.FAILING_ROUTING_CONFIG; - } else { - ImmutableMap updatedInterceptors = - generatePerRouteInterceptors( - filterChain.httpConnectionManager().httpFilterConfigs(), - savedVirtualHosts); - updatedRoutingConfig = ServerRoutingConfig.create(savedVirtualHosts, - updatedInterceptors); - } - logger.log(Level.FINEST, "Updating filter chain {0} rds routing config: {1}", - new Object[]{filterChain.name(), updatedRoutingConfig}); - savedRdsRoutingConfigRef.get(filterChain).set(updatedRoutingConfig); + HttpConnectionManager hcm = filterChain.httpConnectionManager(); + if (!resourceName.equals(hcm.rdsName())) { + continue; } + + ServerRoutingConfig updatedRoutingConfig; + if (savedVirtualHosts == null) { + updatedRoutingConfig = ServerRoutingConfig.FAILING_ROUTING_CONFIG; + } else { + HashMap chainFilters = activeFilters.get(filterChain.name()); + ImmutableMap interceptors = generatePerRouteInterceptors( + hcm.httpFilterConfigs(), savedVirtualHosts, chainFilters); + updatedRoutingConfig = ServerRoutingConfig.create(savedVirtualHosts, interceptors); + } + + logger.log(Level.FINEST, "Updating filter chain {0} rds routing config: {1}", + new Object[]{filterChain.name(), updatedRoutingConfig}); + savedRdsRoutingConfigRef.get(filterChain).set(updatedRoutingConfig); } } diff --git a/xds/src/main/java/io/grpc/xds/client/AllowedGrpcServices.java b/xds/src/main/java/io/grpc/xds/client/AllowedGrpcServices.java new file mode 100644 index 00000000000..e2d77689fca --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/client/AllowedGrpcServices.java @@ -0,0 +1,66 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.client; + +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableMap; +import io.grpc.CallCredentials; +import io.grpc.Internal; +import java.util.Map; +import java.util.Optional; + +/** + * Wrapper for allowed gRPC services keyed by target URI. + */ +@Internal +@AutoValue +public abstract class AllowedGrpcServices { + public abstract ImmutableMap services(); + + public static AllowedGrpcServices create(Map services) { + return new AutoValue_AllowedGrpcServices(ImmutableMap.copyOf(services)); + } + + public static AllowedGrpcServices empty() { + return create(ImmutableMap.of()); + } + + /** + * Represents an allowed gRPC service configuration with call credentials. + */ + @Internal + @AutoValue + public abstract static class AllowedGrpcService { + public abstract ConfiguredChannelCredentials configuredChannelCredentials(); + + public abstract Optional callCredentials(); + + public static Builder builder() { + return new AutoValue_AllowedGrpcServices_AllowedGrpcService.Builder(); + } + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder configuredChannelCredentials( + ConfiguredChannelCredentials credentials); + + public abstract Builder callCredentials(CallCredentials callCredentials); + + public abstract AllowedGrpcService build(); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/client/BackendMetricPropagation.java b/xds/src/main/java/io/grpc/xds/client/BackendMetricPropagation.java new file mode 100644 index 00000000000..f0e2c9484b4 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/client/BackendMetricPropagation.java @@ -0,0 +1,133 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.client; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.collect.ImmutableSet; +import io.grpc.Internal; +import java.util.Objects; +import javax.annotation.Nullable; + +/** + * Represents the configuration for which ORCA metrics should be propagated from backend + * to LRS load reports, as defined in gRFC A85. + */ +@Internal +public final class BackendMetricPropagation { + + public final boolean propagateCpuUtilization; + public final boolean propagateMemUtilization; + public final boolean propagateApplicationUtilization; + + private final boolean propagateAllNamedMetrics; + private final ImmutableSet namedMetricKeys; + + private BackendMetricPropagation( + boolean propagateCpuUtilization, + boolean propagateMemUtilization, + boolean propagateApplicationUtilization, + boolean propagateAllNamedMetrics, + ImmutableSet namedMetricKeys) { + this.propagateCpuUtilization = propagateCpuUtilization; + this.propagateMemUtilization = propagateMemUtilization; + this.propagateApplicationUtilization = propagateApplicationUtilization; + this.propagateAllNamedMetrics = propagateAllNamedMetrics; + this.namedMetricKeys = checkNotNull(namedMetricKeys, "namedMetricKeys"); + } + + /** + * Creates a BackendMetricPropagation from a list of metric specifications. + * + * @param metricSpecs list of metric specification strings from CDS resource + * @return BackendMetricPropagation instance + */ + public static BackendMetricPropagation fromMetricSpecs( + @Nullable java.util.List metricSpecs) { + if (metricSpecs == null || metricSpecs.isEmpty()) { + return new BackendMetricPropagation(false, false, false, false, ImmutableSet.of()); + } + + boolean propagateCpuUtilization = false; + boolean propagateMemUtilization = false; + boolean propagateApplicationUtilization = false; + boolean propagateAllNamedMetrics = false; + ImmutableSet.Builder namedMetricKeysBuilder = ImmutableSet.builder(); + for (String spec : metricSpecs) { + if (spec == null) { + continue; + } + switch (spec) { + case "cpu_utilization": + propagateCpuUtilization = true; + break; + case "mem_utilization": + propagateMemUtilization = true; + break; + case "application_utilization": + propagateApplicationUtilization = true; + break; + case "named_metrics.*": + propagateAllNamedMetrics = true; + break; + default: + if (spec.startsWith("named_metrics.")) { + String metricKey = spec.substring("named_metrics.".length()); + if (!metricKey.isEmpty()) { + namedMetricKeysBuilder.add(metricKey); + } + } + } + } + + return new BackendMetricPropagation( + propagateCpuUtilization, + propagateMemUtilization, + propagateApplicationUtilization, + propagateAllNamedMetrics, + namedMetricKeysBuilder.build()); + } + + /** + * Returns whether the given named metric key should be propagated. + */ + public boolean shouldPropagateNamedMetric(String metricKey) { + return propagateAllNamedMetrics || namedMetricKeys.contains(metricKey); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + BackendMetricPropagation that = (BackendMetricPropagation) o; + return propagateCpuUtilization == that.propagateCpuUtilization + && propagateMemUtilization == that.propagateMemUtilization + && propagateApplicationUtilization == that.propagateApplicationUtilization + && propagateAllNamedMetrics == that.propagateAllNamedMetrics + && Objects.equals(namedMetricKeys, that.namedMetricKeys); + } + + @Override + public int hashCode() { + return Objects.hash(propagateCpuUtilization, propagateMemUtilization, + propagateApplicationUtilization, propagateAllNamedMetrics, namedMetricKeys); + } +} \ No newline at end of file diff --git a/xds/src/main/java/io/grpc/xds/client/Bootstrapper.java b/xds/src/main/java/io/grpc/xds/client/Bootstrapper.java index fe0c0050b52..b8d6444e3b3 100644 --- a/xds/src/main/java/io/grpc/xds/client/Bootstrapper.java +++ b/xds/src/main/java/io/grpc/xds/client/Bootstrapper.java @@ -26,6 +26,7 @@ import io.grpc.xds.client.EnvoyProtoData.Node; import java.util.List; import java.util.Map; +import java.util.Optional; import javax.annotation.Nullable; /** @@ -61,16 +62,26 @@ public abstract static class ServerInfo { public abstract boolean ignoreResourceDeletion(); + public abstract boolean isTrustedXdsServer(); + + public abstract boolean resourceTimerIsTransientError(); + + public abstract boolean failOnDataErrors(); + @VisibleForTesting public static ServerInfo create(String target, @Nullable Object implSpecificConfig) { - return new AutoValue_Bootstrapper_ServerInfo(target, implSpecificConfig, false); + return new AutoValue_Bootstrapper_ServerInfo(target, implSpecificConfig, + false, false, false, false); } @VisibleForTesting public static ServerInfo create( - String target, Object implSpecificConfig, boolean ignoreResourceDeletion) { + String target, Object implSpecificConfig, + boolean ignoreResourceDeletion, boolean isTrustedXdsServer, + boolean resourceTimerIsTransientError, boolean failOnDataErrors) { return new AutoValue_Bootstrapper_ServerInfo(target, implSpecificConfig, - ignoreResourceDeletion); + ignoreResourceDeletion, isTrustedXdsServer, + resourceTimerIsTransientError, failOnDataErrors); } } @@ -195,11 +206,18 @@ public abstract static class BootstrapInfo { */ public abstract ImmutableMap authorities(); + /** + * Parsed configuration for implementation-specific extensions. + * Returns an opaque object containing the parsed configuration. + */ + public abstract Optional implSpecificObject(); + @VisibleForTesting public static Builder builder() { return new AutoValue_Bootstrapper_BootstrapInfo.Builder() .clientDefaultListenerResourceNameTemplate("%s") - .authorities(ImmutableMap.of()); + .authorities(ImmutableMap.of()) + .implSpecificObject(Optional.empty()); } @AutoValue.Builder @@ -221,7 +239,10 @@ public abstract Builder clientDefaultListenerResourceNameTemplate( public abstract Builder authorities(Map authorities); + public abstract Builder implSpecificObject(Optional implSpecificObject); + public abstract BootstrapInfo build(); } } + } diff --git a/xds/src/main/java/io/grpc/xds/client/BootstrapperImpl.java b/xds/src/main/java/io/grpc/xds/client/BootstrapperImpl.java index 7ef739c8048..3f4ea8eb5c6 100644 --- a/xds/src/main/java/io/grpc/xds/client/BootstrapperImpl.java +++ b/xds/src/main/java/io/grpc/xds/client/BootstrapperImpl.java @@ -34,6 +34,8 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; +import javax.annotation.Nullable; /** * A {@link Bootstrapper} implementation that reads xDS configurations from local file system. @@ -41,6 +43,11 @@ @Internal public abstract class BootstrapperImpl extends Bootstrapper { + public static final String GRPC_EXPERIMENTAL_XDS_FALLBACK = + "GRPC_EXPERIMENTAL_XDS_FALLBACK"; + public static final String GRPC_EXPERIMENTAL_XDS_DATA_ERROR_HANDLING = + "GRPC_EXPERIMENTAL_XDS_DATA_ERROR_HANDLING"; + // Client features. @VisibleForTesting public static final String CLIENT_FEATURE_DISABLE_OVERPROVISIONING = @@ -50,6 +57,17 @@ public abstract class BootstrapperImpl extends Bootstrapper { // Server features. private static final String SERVER_FEATURE_IGNORE_RESOURCE_DELETION = "ignore_resource_deletion"; + private static final String SERVER_FEATURE_TRUSTED_XDS_SERVER = "trusted_xds_server"; + private static final String + SERVER_FEATURE_RESOURCE_TIMER_IS_TRANSIENT_ERROR = "resource_timer_is_transient_error"; + private static final String SERVER_FEATURE_FAIL_ON_DATA_ERRORS = "fail_on_data_errors"; + + @VisibleForTesting + static boolean enableXdsFallback = GrpcUtil.getFlag(GRPC_EXPERIMENTAL_XDS_FALLBACK, true); + + @VisibleForTesting + public static boolean xdsDataErrorHandlingEnabled + = GrpcUtil.getFlag(GRPC_EXPERIMENTAL_XDS_DATA_ERROR_HANDLING, false); protected final XdsLogger logger; @@ -64,6 +82,7 @@ protected BootstrapperImpl() { protected abstract Object getImplSpecificConfig(Map serverConfig, String serverUri) throws XdsInitializationException; + /** * Reads and parses bootstrap config. The config is expected to be in JSON format. */ @@ -102,6 +121,9 @@ protected BootstrapInfo.Builder bootstrapBuilder(Map rawData) throw new XdsInitializationException("Invalid bootstrap: 'xds_servers' does not exist."); } List servers = parseServerInfos(rawServerConfigs, logger); + if (servers.size() > 1 && !enableXdsFallback) { + servers = ImmutableList.of(servers.get(0)); + } builder.servers(servers); Node.Builder nodeBuilder = Node.newBuilder(); @@ -208,6 +230,9 @@ protected BootstrapInfo.Builder bootstrapBuilder(Map rawData) if (rawAuthorityServers == null || rawAuthorityServers.isEmpty()) { authorityServers = servers; } else { + if (rawAuthorityServers.size() > 1 && !enableXdsFallback) { + rawAuthorityServers = ImmutableList.of(rawAuthorityServers.get(0)); + } authorityServers = parseServerInfos(rawAuthorityServers, logger); } authorityInfoMapBuilder.put( @@ -216,9 +241,18 @@ protected BootstrapInfo.Builder bootstrapBuilder(Map rawData) builder.authorities(authorityInfoMapBuilder.buildOrThrow()); } + Map rawAllowedGrpcServices = JsonUtil.getObject(rawData, "allowed_grpc_services"); + builder.implSpecificObject(parseImplSpecificObject(rawAllowedGrpcServices)); + return builder; } + protected Optional parseImplSpecificObject( + @Nullable Map rawAllowedGrpcServices) + throws XdsInitializationException { + return Optional.empty(); + } + private List parseServerInfos(List rawServerConfigs, XdsLogger logger) throws XdsInitializationException { logger.log(XdsLogLevel.INFO, "Configured with {0} xDS servers", rawServerConfigs.size()); @@ -233,14 +267,27 @@ private List parseServerInfos(List rawServerConfigs, XdsLogger lo Object implSpecificConfig = getImplSpecificConfig(serverConfig, serverUri); + boolean resourceTimerIsTransientError = false; boolean ignoreResourceDeletion = false; - List serverFeatures = JsonUtil.getListOfStrings(serverConfig, "server_features"); + boolean failOnDataErrors = false; + // "For forward compatibility reasons, the client will ignore any entry in the list that it + // does not understand, regardless of type." + List serverFeatures = JsonUtil.getList(serverConfig, "server_features"); if (serverFeatures != null) { logger.log(XdsLogLevel.INFO, "Server features: {0}", serverFeatures); - ignoreResourceDeletion = serverFeatures.contains(SERVER_FEATURE_IGNORE_RESOURCE_DELETION); + if (serverFeatures.contains(SERVER_FEATURE_IGNORE_RESOURCE_DELETION)) { + ignoreResourceDeletion = true; + } + resourceTimerIsTransientError = xdsDataErrorHandlingEnabled + && serverFeatures.contains(SERVER_FEATURE_RESOURCE_TIMER_IS_TRANSIENT_ERROR); + failOnDataErrors = xdsDataErrorHandlingEnabled + && serverFeatures.contains(SERVER_FEATURE_FAIL_ON_DATA_ERRORS); } servers.add( - ServerInfo.create(serverUri, implSpecificConfig, ignoreResourceDeletion)); + ServerInfo.create(serverUri, implSpecificConfig, ignoreResourceDeletion, + serverFeatures != null + && serverFeatures.contains(SERVER_FEATURE_TRUSTED_XDS_SERVER), + resourceTimerIsTransientError, failOnDataErrors)); } return servers.build(); } diff --git a/xds/src/main/java/io/grpc/xds/client/ConfiguredChannelCredentials.java b/xds/src/main/java/io/grpc/xds/client/ConfiguredChannelCredentials.java new file mode 100644 index 00000000000..c6b9d774b4d --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/client/ConfiguredChannelCredentials.java @@ -0,0 +1,48 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.client; + +import com.google.auto.value.AutoValue; +import io.grpc.ChannelCredentials; +import io.grpc.Internal; + +/** + * Composition of {@link ChannelCredentials} and {@link ChannelCredsConfig}. + */ +@Internal +@AutoValue +public abstract class ConfiguredChannelCredentials { + public abstract ChannelCredentials channelCredentials(); + + public abstract ChannelCredsConfig channelCredsConfig(); + + public static ConfiguredChannelCredentials create(ChannelCredentials creds, + ChannelCredsConfig config) { + return new AutoValue_ConfiguredChannelCredentials(creds, config); + } + + /** + * Configuration for channel credentials. + */ + @Internal + public interface ChannelCredsConfig { + /** + * Returns the type of the credentials. + */ + String type(); + } +} diff --git a/xds/src/main/java/io/grpc/xds/client/ControlPlaneClient.java b/xds/src/main/java/io/grpc/xds/client/ControlPlaneClient.java index 761c10ede6a..981db516e5b 100644 --- a/xds/src/main/java/io/grpc/xds/client/ControlPlaneClient.java +++ b/xds/src/main/java/io/grpc/xds/client/ControlPlaneClient.java @@ -16,7 +16,6 @@ package io.grpc.xds.client; -import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; @@ -40,7 +39,6 @@ import io.grpc.xds.client.XdsClient.ResourceStore; import io.grpc.xds.client.XdsClient.XdsResponseHandler; import io.grpc.xds.client.XdsLogger.XdsLogLevel; -import io.grpc.xds.client.XdsTransportFactory.EventHandler; import io.grpc.xds.client.XdsTransportFactory.StreamingCall; import io.grpc.xds.client.XdsTransportFactory.XdsTransport; import java.util.Collection; @@ -60,7 +58,6 @@ */ final class ControlPlaneClient { - public static final String CLOSED_BY_SERVER = "Closed by server"; private final SynchronizationContext syncContext; private final InternalLogId logId; private final XdsLogger logger; @@ -72,7 +69,6 @@ final class ControlPlaneClient { private final BackoffPolicy.Provider backoffPolicyProvider; private final Stopwatch stopwatch; private final Node bootstrapNode; - private final XdsClient xdsClient; // Last successfully applied version_info for each resource type. Starts with empty string. // A version_info is used to update management server with client's most recent knowledge of @@ -80,13 +76,15 @@ final class ControlPlaneClient { private final Map, String> versions = new HashMap<>(); private boolean shutdown; + private boolean inError; + @Nullable private AdsStream adsStream; @Nullable private BackoffPolicy retryBackoffPolicy; @Nullable private ScheduledHandle rpcRetryTimer; - private MessagePrettyPrinter messagePrinter; + private final MessagePrettyPrinter messagePrinter; /** An entity that manages ADS RPCs over a single channel. */ ControlPlaneClient( @@ -100,7 +98,6 @@ final class ControlPlaneClient { SynchronizationContext syncContext, BackoffPolicy.Provider backoffPolicyProvider, Supplier stopwatchSupplier, - XdsClient xdsClient, MessagePrettyPrinter messagePrinter) { this.serverInfo = checkNotNull(serverInfo, "serverInfo"); this.xdsTransport = checkNotNull(xdsTransport, "xdsTransport"); @@ -110,10 +107,9 @@ final class ControlPlaneClient { this.timeService = checkNotNull(timeService, "timeService"); this.syncContext = checkNotNull(syncContext, "syncContext"); this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider"); - this.xdsClient = checkNotNull(xdsClient, "xdsClient"); this.messagePrinter = checkNotNull(messagePrinter, "messagePrinter"); stopwatch = checkNotNull(stopwatchSupplier, "stopwatchSupplier").get(); - logId = InternalLogId.allocate("xds-client", serverInfo.target()); + logId = InternalLogId.allocate("xds-cp-client", serverInfo.target()); logger = XdsLogger.withLogId(logId); logger.log(XdsLogLevel.INFO, "Created"); } @@ -140,20 +136,66 @@ public String toString() { return logId.toString(); } + public ServerInfo getServerInfo() { + return serverInfo; + } + /** * Updates the resource subscription for the given resource type. */ // Must be synchronized. void adjustResourceSubscription(XdsResourceType resourceType) { - if (isInBackoff()) { + if (rpcRetryTimer != null && rpcRetryTimer.isPending()) { return; } if (adsStream == null) { startRpcStream(); + // when the stream becomes ready, it will send the discovery requests + return; } + + // We will do the rest of the method as part of the readyHandler when the stream is ready. + if (!isConnected()) { + return; + } + Collection resources = resourceStore.getSubscribedResources(serverInfo, resourceType); - if (resources != null) { - adsStream.sendDiscoveryRequest(resourceType, resources); + if (resources == null && !adsStream.sentTypes.contains(resourceType)) { + // No subscription for this type on this server, and we have never sent a DiscoveryRequest + // of this type on the current stream — the server has no subscription state to clear. + // + // Per the ResourceStore contract in XdsClient.java, a null return means "no subscription"; + // an empty collection means wildcard subscription, which is a real subscription and must + // not be skipped here. + // + // We track sent types per-stream rather than gating on `versions` because `versions` is + // only populated on ACK. If a watch is canceled after the initial DiscoveryRequest goes + // out but before any response is ACKed, `versions` would still have no entry for the + // type, and gating on it would suppress the empty unsubscribe — leaving the server with + // a stale subscription until the stream resets. + // + // Without this skip, sendDiscoveryRequests() iterates over every globally-subscribed + // resource type when a stream becomes ready and emits an empty DiscoveryRequest for types + // that have no subscription on this server. Per A47 (xDS Federation) servers may be + // authority-specific (e.g. an EDS-only control plane) and reject DiscoveryRequests for + // types they do not handle, tearing down the stream. + // + // Mirrors grpc-go's behavior in + // internal/xds/clients/xdsclient/ads_stream.go:sendExisting, which skips types with no + // subscription. + return; + } + if (resources == null) { + resources = Collections.emptyList(); + } + adsStream.sendDiscoveryRequest(resourceType, resources); + resourceStore.startMissingResourceTimers(resources, resourceType); + + if (resources.isEmpty()) { + // The resource type no longer has subscribing resources; clean up references to it, except + // for nonces. If the resource type becomes used again the control plane can ignore requests + // for old/missing nonces. Old type's nonces are dropped when the ADS stream is restarted. + versions.remove(resourceType); } } @@ -189,35 +231,42 @@ void nackResponse(XdsResourceType type, String nonce, String errorDetail) { adsStream.sendDiscoveryRequest(type, versionInfo, resources, nonce, errorDetail); } - /** - * Returns {@code true} if the resource discovery is currently in backoff. - */ // Must be synchronized. - boolean isInBackoff() { - return rpcRetryTimer != null && rpcRetryTimer.isPending(); + boolean isReady() { + return adsStream != null && adsStream.call != null + && adsStream.call.isReady() && !adsStream.closed; } - // Must be synchronized. - boolean isReady() { - return adsStream != null && adsStream.call != null && adsStream.call.isReady(); + boolean isConnected() { + return adsStream != null && adsStream.sentInitialRequest; } /** - * Starts a timer for each requested resource that hasn't been responded to and - * has been waiting for the channel to get ready. + * Used for identifying whether or not when getting a control plane for authority that this + * control plane should be skipped over if there is a fallback. + * + *

Also used by metric to consider this control plane to not be "active". + * + *

A ControlPlaneClient is considered to be in error during the time from when an + * {@link AdsStream} closed without having received a response to the time an AdsStream does + * receive a response. */ - // Must be synchronized. - void readyHandler() { - if (!isReady()) { - return; - } + boolean isInError() { + return inError; + } - if (isInBackoff()) { - rpcRetryTimer.cancel(); - rpcRetryTimer = null; - } - xdsClient.startSubscriberTimersIfNeeded(serverInfo); + /** + * Cleans up outstanding rpcRetryTimer if present, since we are communicating. + * If we haven't sent the initial discovery request for this RPC stream, we will delegate to + * xdsResponseHandler (in practice XdsClientImpl) to do any initialization for a new active + * stream such as starting timers. We then send the initial discovery request. + */ + // Must be synchronized. + void readyHandler(boolean shouldSendInitialRequest) { + if (shouldSendInitialRequest) { + sendDiscoveryRequests(); + } } /** @@ -227,28 +276,51 @@ void readyHandler() { // Must be synchronized. private void startRpcStream() { checkState(adsStream == null, "Previous adsStream has not been cleared yet"); + + if (rpcRetryTimer != null) { + rpcRetryTimer.cancel(); + rpcRetryTimer = null; + } + adsStream = new AdsStream(); + adsStream.start(); logger.log(XdsLogLevel.INFO, "ADS stream started"); stopwatch.reset().start(); } + void sendDiscoveryRequests() { + if (rpcRetryTimer != null && rpcRetryTimer.isPending()) { + return; + } + + if (adsStream == null) { + startRpcStream(); + // when the stream becomes ready, it will send the discovery requests + return; + } + + if (isConnected()) { + Set> subscribedResourceTypes = + new HashSet<>(resourceStore.getSubscribedResourceTypesWithTypeUrl().values()); + + for (XdsResourceType type : subscribedResourceTypes) { + adjustResourceSubscription(type); + } + } + } + @VisibleForTesting public final class RpcRetryTask implements Runnable { @Override public void run() { + logger.log(XdsLogLevel.DEBUG, "Retry timeout. Restart ADS stream {0}", logId); if (shutdown) { return; } + startRpcStream(); - Set> subscribedResourceTypes = - new HashSet<>(resourceStore.getSubscribedResourceTypesWithTypeUrl().values()); - for (XdsResourceType type : subscribedResourceTypes) { - Collection resources = resourceStore.getSubscribedResources(serverInfo, type); - if (resources != null) { - adsStream.sendDiscoveryRequest(type, resources); - } - } - xdsResponseHandler.handleStreamRestarted(serverInfo); + + // handling CPC management is triggered in readyHandler } } @@ -258,16 +330,25 @@ XdsResourceType fromTypeUrl(String typeUrl) { return resourceStore.getSubscribedResourceTypesWithTypeUrl().get(typeUrl); } - private class AdsStream implements EventHandler { + private class AdsStream implements XdsTransportFactory.EventHandler { private boolean responseReceived; + private boolean sentInitialRequest; private boolean closed; - // Response nonce for the most recently received discovery responses of each resource type. + // Response nonce for the most recently received discovery responses of each resource type URL. // Client initiated requests start response nonce with empty string. // Nonce in each response is echoed back in the following ACK/NACK request. It is // used for management server to identify which response the client is ACKing/NACking. // To avoid confusion, client-initiated requests will always use the nonce in - // most recently received responses of each resource type. - private final Map, String> respNonces = new HashMap<>(); + // most recently received responses of each resource type. Nonces are never deleted from the + // map; nonces are only discarded once the stream closes because xds_protocol says "the + // management server should not send a DiscoveryResponse for any DiscoveryRequest that has a + // stale nonce." + private final Map respNonces = new HashMap<>(); + // Resource types for which a DiscoveryRequest has been sent on this stream. Used by + // adjustResourceSubscription() to decide whether an empty unsubscribe must be sent on the + // wire: the server only has subscription state to clear for types we have actually sent a + // request for on this stream. Cleared implicitly when the stream is replaced. + private final Set> sentTypes = new HashSet<>(); private final StreamingCall call; private final MethodDescriptor methodDescriptor = AggregatedDiscoveryServiceGrpc.getStreamAggregatedResourcesMethod(); @@ -275,6 +356,9 @@ private class AdsStream implements EventHandler { private AdsStream() { this.call = xdsTransport.createStreamingCall(methodDescriptor.getFullMethodName(), methodDescriptor.getRequestMarshaller(), methodDescriptor.getResponseMarshaller()); + } + + void start() { call.start(this); } @@ -304,6 +388,7 @@ void sendDiscoveryRequest(XdsResourceType type, String versionInfo, } DiscoveryRequest request = builder.build(); call.sendMessage(request); + sentTypes.add(type); if (logger.isLoggable(XdsLogLevel.DEBUG)) { logger.log(XdsLogLevel.DEBUG, "Sent DiscoveryRequest\n{0}", messagePrinter.print(request)); } @@ -315,12 +400,24 @@ void sendDiscoveryRequest(XdsResourceType type, String versionInfo, final void sendDiscoveryRequest(XdsResourceType type, Collection resources) { logger.log(XdsLogLevel.INFO, "Sending {0} request for resources: {1}", type, resources); sendDiscoveryRequest(type, versions.getOrDefault(type, ""), resources, - respNonces.getOrDefault(type, ""), null); + respNonces.getOrDefault(type.typeUrl(), ""), null); } @Override public void onReady() { - syncContext.execute(ControlPlaneClient.this::readyHandler); + syncContext.execute(() -> { + if (!isReady()) { + logger.log(XdsLogLevel.DEBUG, + "ADS stream ready handler called, but not ready {0}", logId); + return; + } + + logger.log(XdsLogLevel.DEBUG, "ADS stream ready {0}", logId); + + boolean hadSentInitialRequest = sentInitialRequest; + sentInitialRequest = true; + readyHandler(!hadSentInitialRequest); + }); } @Override @@ -328,6 +425,14 @@ public void onRecvMessage(DiscoveryResponse response) { syncContext.execute(new Runnable() { @Override public void run() { + if (closed) { + return; + } + boolean isFirstResponse = !responseReceived; + responseReceived = true; + inError = false; + respNonces.put(response.getTypeUrl(), response.getNonce()); + XdsResourceType type = fromTypeUrl(response.getTypeUrl()); if (logger.isLoggable(XdsLogLevel.DEBUG)) { logger.log( @@ -344,7 +449,7 @@ public void run() { return; } handleRpcResponse(type, response.getVersionInfo(), response.getResourcesList(), - response.getNonce()); + response.getNonce(), isFirstResponse); } }); } @@ -352,30 +457,22 @@ public void run() { @Override public void onStatusReceived(final Status status) { syncContext.execute(() -> { - if (status.isOk()) { - handleRpcStreamClosed(Status.UNAVAILABLE.withDescription(CLOSED_BY_SERVER)); - } else { - handleRpcStreamClosed(status); - } + handleRpcStreamClosed(status); }); } final void handleRpcResponse(XdsResourceType type, String versionInfo, List resources, - String nonce) { + String nonce, boolean isFirstResponse) { checkNotNull(type, "type"); - if (closed) { - return; - } - responseReceived = true; - respNonces.put(type, nonce); + ProcessingTracker processingTracker = new ProcessingTracker( () -> call.startRecvMessage(), syncContext); xdsResponseHandler.handleResourceResponse(type, serverInfo, versionInfo, resources, nonce, - processingTracker); + isFirstResponse, processingTracker); processingTracker.onComplete(); } - private void handleRpcStreamClosed(Status error) { + private void handleRpcStreamClosed(Status status) { if (closed) { return; } @@ -384,27 +481,47 @@ private void handleRpcStreamClosed(Status error) { // Reset the backoff sequence if had received a response, or backoff sequence // has never been initialized. retryBackoffPolicy = backoffPolicyProvider.get(); + stopwatch.reset(); + } + + Status newStatus = status; + if (responseReceived) { + // A closed ADS stream after a successful response is not considered an error. Servers may + // close streams for various reasons during normal operation, such as load balancing or + // underlying connection hitting its max connection age limit (see gRFC A9). + if (!status.isOk()) { + newStatus = Status.OK; + logger.log(XdsLogLevel.DEBUG, "ADS stream closed with error {0}: {1}. However, a " + + "response was received, so this will not be treated as an error. Cause: {2}", + status.getCode(), status.getDescription(), status.getCause()); + } else { + logger.log(XdsLogLevel.DEBUG, + "ADS stream closed by server after a response was received"); + } + } else { + // If the ADS stream is closed without ever having received a response from the server, then + // the XdsClient should consider that a connectivity error (see gRFC A57). + inError = true; + if (status.isOk()) { + newStatus = Status.UNAVAILABLE.withDescription( + "ADS stream closed with OK before receiving a response"); + } + logger.log( + XdsLogLevel.ERROR, "ADS stream failed with status {0}: {1}. Cause: {2}", + newStatus.getCode(), newStatus.getDescription(), newStatus.getCause()); } + + close(newStatus.asException()); + // FakeClock in tests isn't thread-safe. Schedule the retry timer before notifying callbacks // to avoid TSAN races, since tests may wait until callbacks are called but then would run // concurrently with the stopwatch and schedule. long elapsed = stopwatch.elapsed(TimeUnit.NANOSECONDS); long delayNanos = Math.max(0, retryBackoffPolicy.nextBackoffNanos() - elapsed); - rpcRetryTimer = syncContext.schedule( - new RpcRetryTask(), delayNanos, TimeUnit.NANOSECONDS, timeService); - - checkArgument(!error.isOk(), "unexpected OK status"); - String errorMsg = error.getDescription() != null - && error.getDescription().equals(CLOSED_BY_SERVER) - ? "ADS stream closed with status {0}: {1}. Cause: {2}" - : "ADS stream failed with status {0}: {1}. Cause: {2}"; - logger.log( - XdsLogLevel.ERROR, errorMsg, error.getCode(), error.getDescription(), error.getCause()); - closed = true; - xdsResponseHandler.handleStreamClosed(error); - cleanUp(); + rpcRetryTimer = + syncContext.schedule(new RpcRetryTask(), delayNanos, TimeUnit.NANOSECONDS, timeService); - logger.log(XdsLogLevel.INFO, "Retry ADS stream in {0} ns", delayNanos); + xdsResponseHandler.handleStreamClosed(newStatus, !responseReceived); } private void close(Exception error) { @@ -422,4 +539,55 @@ private void cleanUp() { } } } + + @VisibleForTesting + static class FailingXdsTransport implements XdsTransport { + Status error; + + public FailingXdsTransport(Status error) { + this.error = error; + } + + @Override + public StreamingCall + createStreamingCall(String fullMethodName, + MethodDescriptor.Marshaller reqMarshaller, + MethodDescriptor.Marshaller respMarshaller) { + return new FailingXdsStreamingCall<>(); + } + + @Override + public void shutdown() { + // no-op + } + + private class FailingXdsStreamingCall implements StreamingCall { + + @Override + public void start(XdsTransportFactory.EventHandler eventHandler) { + eventHandler.onStatusReceived(error); + } + + @Override + public void sendMessage(ReqT message) { + // no-op + } + + @Override + public void startRecvMessage() { + // no-op + } + + @Override + public void sendError(Exception e) { + // no-op + } + + @Override + public boolean isReady() { + return false; + } + } + } + } diff --git a/xds/src/main/java/io/grpc/xds/client/LoadStatsManager2.java b/xds/src/main/java/io/grpc/xds/client/LoadStatsManager2.java index 393cce16194..cd858dccd99 100644 --- a/xds/src/main/java/io/grpc/xds/client/LoadStatsManager2.java +++ b/xds/src/main/java/io/grpc/xds/client/LoadStatsManager2.java @@ -25,6 +25,7 @@ import com.google.common.collect.Sets; import io.grpc.Internal; import io.grpc.Status; +import io.grpc.internal.GrpcUtil; import io.grpc.xds.client.Stats.BackendLoadMetricStats; import io.grpc.xds.client.Stats.ClusterStats; import io.grpc.xds.client.Stats.DroppedRequests; @@ -57,6 +58,8 @@ public final class LoadStatsManager2 { private final Map>>> allLoadStats = new HashMap<>(); private final Supplier stopwatchSupplier; + public static boolean isEnabledOrcaLrsPropagation = + GrpcUtil.getFlag("GRPC_EXPERIMENTAL_XDS_ORCA_LRS_PROPAGATION", false); @VisibleForTesting public LoadStatsManager2(Supplier stopwatchSupplier) { @@ -91,20 +94,27 @@ private synchronized void releaseClusterDropCounter( String cluster, @Nullable String edsServiceName) { checkState(allDropStats.containsKey(cluster) && allDropStats.get(cluster).containsKey(edsServiceName), - "stats for cluster %s, edsServiceName %s not exits", cluster, edsServiceName); + "stats for cluster %s, edsServiceName %s do not exist", cluster, edsServiceName); ReferenceCounted ref = allDropStats.get(cluster).get(edsServiceName); ref.release(); } /** * Gets or creates the stats object for recording loads for the specified locality (in the - * specified cluster with edsServiceName). The returned object is reference counted and the - * caller should use {@link ClusterLocalityStats#release} to release its hard reference + * specified cluster with edsServiceName) with the specified backend metric propagation + * configuration. The returned object is reference counted and the caller should + * use {@link ClusterLocalityStats#release} to release its hard reference * when it is safe to discard the future stats for the locality. */ @VisibleForTesting public synchronized ClusterLocalityStats getClusterLocalityStats( String cluster, @Nullable String edsServiceName, Locality locality) { + return getClusterLocalityStats(cluster, edsServiceName, locality, null); + } + + public synchronized ClusterLocalityStats getClusterLocalityStats( + String cluster, @Nullable String edsServiceName, Locality locality, + @Nullable BackendMetricPropagation backendMetricPropagation) { if (!allLoadStats.containsKey(cluster)) { allLoadStats.put( cluster, @@ -121,8 +131,8 @@ public synchronized ClusterLocalityStats getClusterLocalityStats( if (!localityStats.containsKey(locality)) { localityStats.put( locality, - ReferenceCounted.wrap(new ClusterLocalityStats( - cluster, edsServiceName, locality, stopwatchSupplier.get()))); + ReferenceCounted.wrap(new ClusterLocalityStats(cluster, edsServiceName, + locality, stopwatchSupplier.get(), backendMetricPropagation))); } ReferenceCounted ref = localityStats.get(locality); ref.retain(); @@ -325,6 +335,8 @@ public final class ClusterLocalityStats { private final String edsServiceName; private final Locality locality; private final Stopwatch stopwatch; + @Nullable + private final BackendMetricPropagation backendMetricPropagation; private final AtomicLong callsInProgress = new AtomicLong(); private final AtomicLong callsSucceeded = new AtomicLong(); private final AtomicLong callsFailed = new AtomicLong(); @@ -333,11 +345,12 @@ public final class ClusterLocalityStats { private ClusterLocalityStats( String clusterName, @Nullable String edsServiceName, Locality locality, - Stopwatch stopwatch) { + Stopwatch stopwatch, BackendMetricPropagation backendMetricPropagation) { this.clusterName = checkNotNull(clusterName, "clusterName"); this.edsServiceName = edsServiceName; this.locality = checkNotNull(locality, "locality"); this.stopwatch = checkNotNull(stopwatch, "stopwatch"); + this.backendMetricPropagation = backendMetricPropagation; stopwatch.reset().start(); } @@ -367,17 +380,51 @@ public void recordCallFinished(Status status) { * requests counter of 1 and the {@code value} if the key is not present in the map. Otherwise, * increments the finished requests counter and adds the {@code value} to the existing * {@link BackendLoadMetricStats}. + * Metrics are filtered based on the backend metric propagation configuration if configured. */ public synchronized void recordBackendLoadMetricStats(Map namedMetrics) { + if (!isEnabledOrcaLrsPropagation) { + namedMetrics.forEach((name, value) -> updateLoadMetricStats(name, value)); + return; + } + namedMetrics.forEach((name, value) -> { - if (!loadMetricStatsMap.containsKey(name)) { - loadMetricStatsMap.put(name, new BackendLoadMetricStats(1, value)); - } else { - loadMetricStatsMap.get(name).addMetricValueAndIncrementRequestsFinished(value); + if (backendMetricPropagation.shouldPropagateNamedMetric(name)) { + updateLoadMetricStats("named_metrics." + name, value); } }); } + private void updateLoadMetricStats(String metricName, double value) { + if (!loadMetricStatsMap.containsKey(metricName)) { + loadMetricStatsMap.put(metricName, new BackendLoadMetricStats(1, value)); + } else { + loadMetricStatsMap.get(metricName).addMetricValueAndIncrementRequestsFinished(value); + } + } + + /** + * Records top-level ORCA metrics (CPU, memory, application utilization) for per-call load + * reporting. Metrics are filtered based on the backend metric propagation configuration + * if configured. + * + * @param cpuUtilization CPU utilization metric value + * @param memUtilization Memory utilization metric value + * @param applicationUtilization Application utilization metric value + */ + public synchronized void recordTopLevelMetrics(double cpuUtilization, double memUtilization, + double applicationUtilization) { + if (backendMetricPropagation.propagateCpuUtilization && cpuUtilization > 0) { + updateLoadMetricStats("cpu_utilization", cpuUtilization); + } + if (backendMetricPropagation.propagateMemUtilization && memUtilization > 0) { + updateLoadMetricStats("mem_utilization", memUtilization); + } + if (backendMetricPropagation.propagateApplicationUtilization && applicationUtilization > 0) { + updateLoadMetricStats("application_utilization", applicationUtilization); + } + } + /** * Release the hard reference for this stats object (previously obtained via {@link * LoadStatsManager2#getClusterLocalityStats}). The object may still be diff --git a/xds/src/main/java/io/grpc/xds/client/XdsClient.java b/xds/src/main/java/io/grpc/xds/client/XdsClient.java index fc7e1777384..982fb6651a9 100644 --- a/xds/src/main/java/io/grpc/xds/client/XdsClient.java +++ b/xds/src/main/java/io/grpc/xds/client/XdsClient.java @@ -27,6 +27,7 @@ import com.google.protobuf.Any; import io.grpc.ExperimentalApi; import io.grpc.Status; +import io.grpc.StatusOr; import io.grpc.xds.client.Bootstrapper.ServerInfo; import java.net.URI; import java.net.URISyntaxException; @@ -36,6 +37,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.Executor; +import java.util.concurrent.Future; import java.util.concurrent.atomic.AtomicInteger; import javax.annotation.Nullable; @@ -117,34 +119,50 @@ public static String percentEncodePath(String input) { return Joiner.on('/').join(encodedSegs); } + /** + * Returns the authority from the resource name. + */ + public static String getAuthorityFromResourceName(String resourceNames) { + String authority; + if (resourceNames.startsWith(XDSTP_SCHEME)) { + URI uri = URI.create(resourceNames); + authority = uri.getAuthority(); + if (authority == null) { + authority = ""; + } + } else { + authority = null; + } + return authority; + } + public interface ResourceUpdate {} /** * Watcher interface for a single requested xDS resource. + * + *

Note that we expect that the implementer to: + * - Comply with the guarantee to not generate certain statuses by the library: + * https://grpc.github.io/grpc/core/md_doc_statuscodes.html. If the code needs to be + * propagated to the channel, override it with {@link io.grpc.Status.Code#UNAVAILABLE}. + * - Keep {@link Status} description in one form or another, as it contains valuable debugging + * information. */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/10862") public interface ResourceWatcher { /** - * Called when the resource discovery RPC encounters some transient error. - * - *

Note that we expect that the implementer to: - * - Comply with the guarantee to not generate certain statuses by the library: - * https://grpc.github.io/grpc/core/md_doc_statuscodes.html. If the code needs to be - * propagated to the channel, override it with {@link io.grpc.Status.Code#UNAVAILABLE}. - * - Keep {@link Status} description in one form or another, as it contains valuable debugging - * information. + * Called to deliver a resource update or an error. If an error is passed after a valid + * resource has been delivered, the watcher should stop using the previously delivered + * resource. */ - void onError(Status error); + void onResourceChanged(StatusOr update); /** - * Called when the requested resource is not available. - * - * @param resourceName name of the resource requested in discovery request. - */ - void onResourceDoesNotExist(String resourceName); - - void onChanged(T update); + * Called to deliver a transient error that should not affect the watcher's use of any + * previously received resource. + * */ + void onAmbientError(Status error); } /** @@ -154,44 +172,50 @@ public static final class ResourceMetadata { private final String version; private final ResourceMetadataStatus status; private final long updateTimeNanos; + private final boolean cached; @Nullable private final Any rawResource; @Nullable private final UpdateFailureState errorState; private ResourceMetadata( - ResourceMetadataStatus status, String version, long updateTimeNanos, + ResourceMetadataStatus status, String version, long updateTimeNanos, boolean cached, @Nullable Any rawResource, @Nullable UpdateFailureState errorState) { this.status = checkNotNull(status, "status"); this.version = checkNotNull(version, "version"); this.updateTimeNanos = updateTimeNanos; + this.cached = cached; this.rawResource = rawResource; this.errorState = errorState; } - static ResourceMetadata newResourceMetadataUnknown() { - return new ResourceMetadata(ResourceMetadataStatus.UNKNOWN, "", 0, null, null); + public static ResourceMetadata newResourceMetadataUnknown() { + return new ResourceMetadata(ResourceMetadataStatus.UNKNOWN, "", 0, false,null, null); + } + + public static ResourceMetadata newResourceMetadataRequested() { + return new ResourceMetadata(ResourceMetadataStatus.REQUESTED, "", 0, false, null, null); } - static ResourceMetadata newResourceMetadataRequested() { - return new ResourceMetadata(ResourceMetadataStatus.REQUESTED, "", 0, null, null); + public static ResourceMetadata newResourceMetadataDoesNotExist() { + return new ResourceMetadata(ResourceMetadataStatus.DOES_NOT_EXIST, "", 0, false, null, null); } - static ResourceMetadata newResourceMetadataDoesNotExist() { - return new ResourceMetadata(ResourceMetadataStatus.DOES_NOT_EXIST, "", 0, null, null); + public static ResourceMetadata newResourceMetadataTimeout() { + return new ResourceMetadata(ResourceMetadataStatus.TIMEOUT, "", 0, false, null, null); } public static ResourceMetadata newResourceMetadataAcked( Any rawResource, String version, long updateTimeNanos) { checkNotNull(rawResource, "rawResource"); return new ResourceMetadata( - ResourceMetadataStatus.ACKED, version, updateTimeNanos, rawResource, null); + ResourceMetadataStatus.ACKED, version, updateTimeNanos, true, rawResource, null); } - static ResourceMetadata newResourceMetadataNacked( + public static ResourceMetadata newResourceMetadataNacked( ResourceMetadata metadata, String failedVersion, long failedUpdateTime, - String failedDetails) { + String failedDetails, boolean cached) { checkNotNull(metadata, "metadata"); return new ResourceMetadata(ResourceMetadataStatus.NACKED, - metadata.getVersion(), metadata.getUpdateTimeNanos(), metadata.getRawResource(), + metadata.getVersion(), metadata.getUpdateTimeNanos(), cached, metadata.getRawResource(), new UpdateFailureState(failedVersion, failedUpdateTime, failedDetails)); } @@ -210,6 +234,11 @@ public long getUpdateTimeNanos() { return updateTimeNanos; } + /** Returns whether the resource was cached. */ + public boolean isCached() { + return cached; + } + /** The last successfully updated xDS resource as it was returned by the server. */ @Nullable public Any getRawResource() { @@ -231,7 +260,7 @@ public UpdateFailureState getErrorState() { * config_dump.proto */ public enum ResourceMetadataStatus { - UNKNOWN, REQUESTED, DOES_NOT_EXIST, ACKED, NACKED + UNKNOWN, REQUESTED, DOES_NOT_EXIST, ACKED, NACKED, TIMEOUT } /** @@ -298,14 +327,6 @@ public Object getSecurityConfig() { throw new UnsupportedOperationException(); } - /** - * For all subscriber's for the specified server, if the resource hasn't yet been - * resolved then start a timer for it. - */ - protected void startSubscriberTimersIfNeeded(ServerInfo serverInfo) { - throw new UnsupportedOperationException(); - } - /** * Returns a {@link ListenableFuture} to the snapshot of the subscribed resources as * they are at the moment of the call. @@ -367,6 +388,23 @@ public LoadStatsManager2.ClusterDropStats addClusterDropStats( public LoadStatsManager2.ClusterLocalityStats addClusterLocalityStats( Bootstrapper.ServerInfo serverInfo, String clusterName, @Nullable String edsServiceName, Locality locality) { + return addClusterLocalityStats(serverInfo, clusterName, edsServiceName, locality, null); + } + + /** + * Adds load stats for the specified locality (in the specified cluster with edsServiceName) by + * using the returned object to record RPCs. Load stats recorded with the returned object will + * be reported to the load reporting server. The returned object is reference counted and the + * caller should use {@link LoadStatsManager2.ClusterLocalityStats#release} to release its + * hard reference when it is safe to stop reporting RPC loads for the specified locality + * in the future. + * + * @param backendMetricPropagation Configuration for which backend metrics should be propagated + * to LRS load reports. If null, all metrics will be propagated (legacy behavior). + */ + public LoadStatsManager2.ClusterLocalityStats addClusterLocalityStats( + Bootstrapper.ServerInfo serverInfo, String clusterName, @Nullable String edsServiceName, + Locality locality, @Nullable BackendMetricPropagation backendMetricPropagation) { throw new UnsupportedOperationException(); } @@ -378,6 +416,23 @@ public Map getServerLrsClientMap() { throw new UnsupportedOperationException(); } + /** Callback used to report a gauge metric value for server connections. */ + public interface ServerConnectionCallback { + void reportServerConnectionGauge(boolean isConnected, String xdsServer); + } + + /** + * Reports whether xDS client has a "working" ADS stream to xDS server. The definition of a + * working stream is defined in gRFC A78. + * + * @see + * A78-grpc-metrics-wrr-pf-xds.md + */ + public Future reportServerConnections(ServerConnectionCallback callback) { + throw new UnsupportedOperationException(); + } + static final class ProcessingTracker { private final AtomicInteger pendingTask = new AtomicInteger(1); private final Executor executor; @@ -403,30 +458,39 @@ interface XdsResponseHandler { /** Called when a xds response is received. */ void handleResourceResponse( XdsResourceType resourceType, ServerInfo serverInfo, String versionInfo, - List resources, String nonce, ProcessingTracker processingTracker); + List resources, String nonce, boolean isFirstResponse, + ProcessingTracker processingTracker); /** Called when the ADS stream is closed passively. */ // Must be synchronized. - void handleStreamClosed(Status error); - - /** Called when the ADS stream has been recreated. */ - // Must be synchronized. - void handleStreamRestarted(ServerInfo serverInfo); + void handleStreamClosed(Status error, boolean shouldTryFallback); } - public interface ResourceStore { + interface ResourceStore { + /** - * Returns the collection of resources currently subscribing to or {@code null} if not - * subscribing to any resources for the given type. + * Returns the collection of resources currently subscribed to which have an authority matching + * one of those for which the ControlPlaneClient associated with the specified ServerInfo is + * the active one, or {@code null} if no such resources are currently subscribed to. * *

Note an empty collection indicates subscribing to resources of the given type with * wildcard mode. + * + * @param serverInfo the xds server to get the resources from + * @param type the type of the resources that should be retrieved */ // Must be synchronized. @Nullable - Collection getSubscribedResources(ServerInfo serverInfo, - XdsResourceType type); + Collection getSubscribedResources( + ServerInfo serverInfo, XdsResourceType type); Map> getSubscribedResourceTypesWithTypeUrl(); + + /** + * For any of the subscribers to one of the specified resources, if there isn't a result or + * an existing timer for the resource, start a timer for the resource. + */ + void startMissingResourceTimers(Collection resourceNames, + XdsResourceType resourceType); } } diff --git a/xds/src/main/java/io/grpc/xds/client/XdsClientImpl.java b/xds/src/main/java/io/grpc/xds/client/XdsClientImpl.java index d11808a7d8b..0584a3dbfdd 100644 --- a/xds/src/main/java/io/grpc/xds/client/XdsClientImpl.java +++ b/xds/src/main/java/io/grpc/xds/client/XdsClientImpl.java @@ -18,7 +18,6 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; -import static io.grpc.xds.client.Bootstrapper.XDSTP_SCHEME; import static io.grpc.xds.client.XdsResourceType.ParsedResource; import static io.grpc.xds.client.XdsResourceType.ValidatedResourceUpdate; @@ -26,14 +25,15 @@ import com.google.common.base.Joiner; import com.google.common.base.Stopwatch; import com.google.common.base.Supplier; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import com.google.protobuf.Any; import io.grpc.Internal; import io.grpc.InternalLogId; import io.grpc.Status; +import io.grpc.StatusOr; import io.grpc.SynchronizationContext; import io.grpc.SynchronizationContext.ScheduledHandle; import io.grpc.internal.BackoffPolicy; @@ -41,36 +41,34 @@ import io.grpc.xds.client.Bootstrapper.AuthorityInfo; import io.grpc.xds.client.Bootstrapper.ServerInfo; import io.grpc.xds.client.XdsClient.ResourceStore; -import io.grpc.xds.client.XdsClient.XdsResponseHandler; import io.grpc.xds.client.XdsLogger.XdsLogLevel; -import java.net.URI; +import java.io.IOException; +import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.concurrent.Executor; +import java.util.concurrent.Future; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; -import java.util.logging.Level; -import java.util.logging.Logger; +import java.util.stream.Collectors; import javax.annotation.Nullable; /** * XdsClient implementation. */ @Internal -public final class XdsClientImpl extends XdsClient implements XdsResponseHandler, ResourceStore { - - private static final boolean LOG_XDS_NODE_ID = Boolean.parseBoolean( - System.getenv("GRPC_LOG_XDS_NODE_ID")); - private static final Logger classLogger = Logger.getLogger(XdsClientImpl.class.getName()); +public final class XdsClientImpl extends XdsClient implements ResourceStore { // Longest time to wait, since the subscription to some resource, for concluding its absence. @VisibleForTesting public static final int INITIAL_RESOURCE_FETCH_TIMEOUT_SEC = 15; + public static final int EXTENDED_RESOURCE_FETCH_TIMEOUT_SEC = 30; private final SynchronizationContext syncContext = new SynchronizationContext( new Thread.UncaughtExceptionHandler() { @@ -80,21 +78,25 @@ public void uncaughtException(Thread t, Throwable e) { XdsLogLevel.ERROR, "Uncaught exception in XdsClient SynchronizationContext. Panic!", e); - // TODO(chengyuanzhang): better error handling. + // TODO: better error handling. throw new AssertionError(e); } }); - private final Map loadStatsManagerMap = - new HashMap<>(); - final Map serverLrsClientMap = - new HashMap<>(); - + private final Map loadStatsManagerMap = new HashMap<>(); + final Map serverLrsClientMap = new HashMap<>(); + /** Map of authority to its activated control plane client (affected by xds fallback). + * The last entry in the list for each value is the "active" CPC for the matching key */ + private final Map> activatedCpClients = new HashMap<>(); private final Map serverCpClientMap = new HashMap<>(); + + /** Maps resource type to the corresponding map of subscribers (keyed by resource name). */ private final Map, Map>> resourceSubscribers = new HashMap<>(); + /** Maps typeUrl to the corresponding XdsResourceType. */ private final Map> subscribedResourceTypeUrls = new HashMap<>(); + private final XdsTransportFactory xdsTransportFactory; private final Bootstrapper.BootstrapInfo bootstrapInfo; private final ScheduledExecutorService timeService; @@ -106,6 +108,7 @@ public void uncaughtException(Thread t, Throwable e) { private final XdsLogger logger; private volatile boolean isShutdown; private final MessagePrettyPrinter messagePrinter; + private final XdsClientMetricReporter metricReporter; public XdsClientImpl( XdsTransportFactory xdsTransportFactory, @@ -115,7 +118,8 @@ public XdsClientImpl( Supplier stopwatchSupplier, TimeProvider timeProvider, MessagePrettyPrinter messagePrinter, - Object securityConfig) { + Object securityConfig, + XdsClientMetricReporter metricReporter) { this.xdsTransportFactory = xdsTransportFactory; this.bootstrapInfo = bootstrapInfo; this.timeService = timeService; @@ -124,54 +128,10 @@ public XdsClientImpl( this.timeProvider = timeProvider; this.messagePrinter = messagePrinter; this.securityConfig = securityConfig; + this.metricReporter = metricReporter; logId = InternalLogId.allocate("xds-client", null); logger = XdsLogger.withLogId(logId); logger.log(XdsLogLevel.INFO, "Created"); - if (LOG_XDS_NODE_ID) { - classLogger.log(Level.INFO, "xDS node ID: {0}", bootstrapInfo.node().getId()); - } - } - - @Override - public void handleResourceResponse( - XdsResourceType xdsResourceType, ServerInfo serverInfo, String versionInfo, - List resources, String nonce, ProcessingTracker processingTracker) { - checkNotNull(xdsResourceType, "xdsResourceType"); - syncContext.throwIfNotInThisSynchronizationContext(); - Set toParseResourceNames = - xdsResourceType.shouldRetrieveResourceKeysForArgs() - ? getResourceKeys(xdsResourceType) - : null; - XdsResourceType.Args args = new XdsResourceType.Args(serverInfo, versionInfo, nonce, - bootstrapInfo, securityConfig, toParseResourceNames); - handleResourceUpdate(args, resources, xdsResourceType, processingTracker); - } - - @Override - public void handleStreamClosed(Status error) { - syncContext.throwIfNotInThisSynchronizationContext(); - cleanUpResourceTimers(); - for (Map> subscriberMap : - resourceSubscribers.values()) { - for (ResourceSubscriber subscriber : subscriberMap.values()) { - if (!subscriber.hasResult()) { - subscriber.onError(error, null); - } - } - } - } - - @Override - public void handleStreamRestarted(ServerInfo serverInfo) { - syncContext.throwIfNotInThisSynchronizationContext(); - for (Map> subscriberMap : - resourceSubscribers.values()) { - for (ResourceSubscriber subscriber : subscriberMap.values()) { - if (subscriber.serverInfo.equals(serverInfo)) { - subscriber.restartTimer(); - } - } - } } @Override @@ -190,7 +150,8 @@ public void run() { for (final LoadReportClient lrsClient : serverLrsClientMap.values()) { lrsClient.stopLoadReporting(); } - cleanUpResourceTimers(); + cleanUpResourceTimers(null); + activatedCpClients.clear(); } }); } @@ -205,20 +166,53 @@ public Map> getSubscribedResourceTypesWithTypeUrl() { return Collections.unmodifiableMap(subscribedResourceTypeUrls); } + private ControlPlaneClient getActiveCpc(String authority) { + List controlPlaneClients = activatedCpClients.get(authority); + if (controlPlaneClients == null || controlPlaneClients.isEmpty()) { + return null; + } + + return controlPlaneClients.get(controlPlaneClients.size() - 1); + } + @Nullable @Override - public Collection getSubscribedResources(ServerInfo serverInfo, - XdsResourceType type) { + public Collection getSubscribedResources( + ServerInfo serverInfo, XdsResourceType type) { + ControlPlaneClient targetCpc = serverCpClientMap.get(serverInfo); + if (targetCpc == null) { + return null; + } + + // This should include all of the authorities that targetCpc or a fallback from it is serving + List authorities = activatedCpClients.entrySet().stream() + .filter(entry -> entry.getValue().contains(targetCpc)) + .map(Map.Entry::getKey) + .collect(Collectors.toList()); + Map> resources = resourceSubscribers.getOrDefault(type, Collections.emptyMap()); - ImmutableSet.Builder builder = ImmutableSet.builder(); - for (String key : resources.keySet()) { - if (resources.get(key).serverInfo.equals(serverInfo)) { - builder.add(key); + + Collection retVal = resources.entrySet().stream() + .filter(entry -> authorities.contains(entry.getValue().authority)) + .map(Map.Entry::getKey) + .collect(Collectors.toList()); + + return retVal.isEmpty() ? null : retVal; + } + + @Override + public void startMissingResourceTimers(Collection resourceNames, + XdsResourceType resourceType) { + Map> subscriberMap = + resourceSubscribers.get(resourceType); + + for (String resourceName : resourceNames) { + ResourceSubscriber subscriber = subscriberMap.get(resourceName); + if (subscriber.respTimer == null && !subscriber.hasResult()) { + subscriber.restartTimer(); } } - Collection retVal = builder.build(); - return retVal.isEmpty() ? null : retVal; } // As XdsClient APIs becomes resource agnostic, subscribed resource types are dynamic. @@ -234,7 +228,7 @@ public void run() { // A map from a "resource type" to a map ("resource name": "resource metadata") ImmutableMap.Builder, Map> metadataSnapshot = ImmutableMap.builder(); - for (XdsResourceType resourceType: resourceSubscribers.keySet()) { + for (XdsResourceType resourceType : resourceSubscribers.keySet()) { ImmutableMap.Builder metadataMap = ImmutableMap.builder(); for (Map.Entry> resourceEntry : resourceSubscribers.get(resourceType).entrySet()) { @@ -255,9 +249,9 @@ public Object getSecurityConfig() { @Override public void watchXdsResource(XdsResourceType type, - String resourceName, - ResourceWatcher watcher, - Executor watcherExecutor) { + String resourceName, + ResourceWatcher watcher, + Executor watcherExecutor) { syncContext.execute(new Runnable() { @Override @SuppressWarnings("unchecked") @@ -268,36 +262,125 @@ public void run() { } ResourceSubscriber subscriber = (ResourceSubscriber) resourceSubscribers.get(type).get(resourceName); + if (subscriber == null) { logger.log(XdsLogLevel.INFO, "Subscribe {0} resource {1}", type, resourceName); subscriber = new ResourceSubscriber<>(type, resourceName); resourceSubscribers.get(type).put(resourceName, subscriber); - if (subscriber.controlPlaneClient != null) { - subscriber.controlPlaneClient.adjustResourceSubscription(type); + + if (subscriber.errorDescription == null) { + CpcWithFallbackState cpcToUse = manageControlPlaneClient(subscriber); + if (cpcToUse.cpc != null) { + cpcToUse.cpc.adjustResourceSubscription(type); + } } } + subscriber.addWatcher(watcher, watcherExecutor); } }); } + /** + * Gets a ControlPlaneClient for the subscriber's authority, creating one if necessary. + * If there already was an active CPC for this authority, and it is different from the one + * identified, then do fallback to the identified one (cpcToUse). + * + * @return identified CPC or {@code null} (if there are no valid ServerInfos associated with the + * subscriber's authority or CPC's for all are in backoff), and whether did a fallback. + */ + @VisibleForTesting + private CpcWithFallbackState manageControlPlaneClient( + ResourceSubscriber subscriber) { + + ControlPlaneClient cpcToUse; + boolean didFallback = false; + try { + cpcToUse = getOrCreateControlPlaneClient(subscriber.authority); + } catch (IllegalArgumentException e) { + if (subscriber.errorDescription == null) { + subscriber.errorDescription = "Bad configuration: " + e.getMessage(); + } + + subscriber.onError( + Status.INVALID_ARGUMENT.withDescription(subscriber.errorDescription), null); + return new CpcWithFallbackState(null, false); + } catch (IOException e) { + logger.log(XdsLogLevel.DEBUG, + "Could not create a control plane client for authority {0}: {1}", + subscriber.authority, e.getMessage()); + return new CpcWithFallbackState(null, false); + } + + ControlPlaneClient activeCpClient = getActiveCpc(subscriber.authority); + if (cpcToUse != activeCpClient) { + addCpcToAuthority(subscriber.authority, cpcToUse); // makes it active + if (activeCpClient != null) { + didFallback = cpcToUse != null && !cpcToUse.isInError(); + if (didFallback) { + logger.log(XdsLogLevel.INFO, "Falling back to XDS server {0}", + cpcToUse.getServerInfo().target()); + } else { + logger.log(XdsLogLevel.WARNING, "No working fallback XDS Servers found from {0}", + activeCpClient.getServerInfo().target()); + } + } + } + + return new CpcWithFallbackState(cpcToUse, didFallback); + } + + private void addCpcToAuthority(String authority, ControlPlaneClient cpcToUse) { + List controlPlaneClients = + activatedCpClients.computeIfAbsent(authority, k -> new ArrayList<>()); + + if (controlPlaneClients.contains(cpcToUse)) { + return; + } + + // if there are any missing CPCs between the last one and cpcToUse, add them + add cpcToUse + ImmutableList serverInfos = getServerInfos(authority); + for (int i = controlPlaneClients.size(); i < serverInfos.size(); i++) { + ServerInfo serverInfo = serverInfos.get(i); + ControlPlaneClient cpc = serverCpClientMap.get(serverInfo); + controlPlaneClients.add(cpc); + logger.log(XdsLogLevel.DEBUG, "Adding control plane client {0} to authority {1}", + cpc, authority); + cpcToUse.sendDiscoveryRequests(); + if (cpc == cpcToUse) { + break; + } + } + } + @Override public void cancelXdsResourceWatch(XdsResourceType type, - String resourceName, - ResourceWatcher watcher) { + String resourceName, + ResourceWatcher watcher) { syncContext.execute(new Runnable() { @Override @SuppressWarnings("unchecked") public void run() { ResourceSubscriber subscriber = - (ResourceSubscriber) resourceSubscribers.get(type).get(resourceName);; + (ResourceSubscriber) resourceSubscribers.get(type).get(resourceName); + if (subscriber == null) { + logger.log(XdsLogLevel.WARNING, "double cancel of resource watch for {0}:{1}", + type.typeName(), resourceName); + return; + } subscriber.removeWatcher(watcher); if (!subscriber.isWatched()) { subscriber.cancelResourceWatch(); resourceSubscribers.get(type).remove(resourceName); - if (subscriber.controlPlaneClient != null) { - subscriber.controlPlaneClient.adjustResourceSubscription(type); + + List controlPlaneClients = + activatedCpClients.get(subscriber.authority); + if (controlPlaneClients != null) { + controlPlaneClients.forEach((cpc) -> { + cpc.adjustResourceSubscription(type); + }); } + if (resourceSubscribers.get(type).isEmpty()) { resourceSubscribers.remove(type); subscribedResourceTypeUrls.remove(type.typeUrl()); @@ -327,9 +410,22 @@ public void run() { public LoadStatsManager2.ClusterLocalityStats addClusterLocalityStats( final ServerInfo serverInfo, String clusterName, @Nullable String edsServiceName, Locality locality) { + return addClusterLocalityStats(serverInfo, clusterName, edsServiceName, locality, null); + } + + @Override + public LoadStatsManager2.ClusterLocalityStats addClusterLocalityStats( + final ServerInfo serverInfo, + String clusterName, + @Nullable String edsServiceName, + Locality locality, + @Nullable BackendMetricPropagation backendMetricPropagation) { LoadStatsManager2 loadStatsManager = loadStatsManagerMap.get(serverInfo); + LoadStatsManager2.ClusterLocalityStats loadCounter = - loadStatsManager.getClusterLocalityStats(clusterName, edsServiceName, locality); + loadStatsManager.getClusterLocalityStats( + clusterName, edsServiceName, locality, backendMetricPropagation); + syncContext.execute(new Runnable() { @Override public void run() { @@ -350,30 +446,6 @@ public String toString() { return logId.toString(); } - @Override - protected void startSubscriberTimersIfNeeded(ServerInfo serverInfo) { - if (isShutDown()) { - return; - } - - syncContext.execute(new Runnable() { - @Override - public void run() { - if (isShutDown()) { - return; - } - - for (Map> subscriberMap : resourceSubscribers.values()) { - for (ResourceSubscriber subscriber : subscriberMap.values()) { - if (subscriber.serverInfo.equals(serverInfo) && subscriber.respTimer == null) { - subscriber.restartTimer(); - } - } - } - } - }); - } - private Set getResourceKeys(XdsResourceType xdsResourceType) { if (!resourceSubscribers.containsKey(xdsResourceType)) { return null; @@ -382,33 +454,77 @@ private Set getResourceKeys(XdsResourceType xdsResourceType) { return resourceSubscribers.get(xdsResourceType).keySet(); } - private void cleanUpResourceTimers() { + // cpcForThisStream is null when doing shutdown + private void cleanUpResourceTimers(ControlPlaneClient cpcForThisStream) { + Collection authoritiesForCpc = getActiveAuthorities(cpcForThisStream); + String target = cpcForThisStream == null ? "null" : cpcForThisStream.getServerInfo().target(); + logger.log(XdsLogLevel.DEBUG, "Cleaning up resource timers for CPC {0}, authorities {1}", + target, authoritiesForCpc); + for (Map> subscriberMap : resourceSubscribers.values()) { for (ResourceSubscriber subscriber : subscriberMap.values()) { - subscriber.stopTimer(); + if (cpcForThisStream == null || authoritiesForCpc.contains(subscriber.authority)) { + subscriber.stopTimer(); + } } } } - public ControlPlaneClient getOrCreateControlPlaneClient(ServerInfo serverInfo) { + private ControlPlaneClient getOrCreateControlPlaneClient(String authority) throws IOException { + // Optimize for the common case of a working ads stream already exists for the authority + ControlPlaneClient activeCpc = getActiveCpc(authority); + if (activeCpc != null && !activeCpc.isInError()) { + return activeCpc; + } + + ImmutableList serverInfos = getServerInfos(authority); + if (serverInfos == null) { + throw new IllegalArgumentException("No xds servers found for authority " + authority); + } + + for (ServerInfo serverInfo : serverInfos) { + ControlPlaneClient cpc = getOrCreateControlPlaneClient(serverInfo); + if (cpc.isInError()) { + continue; + } + return cpc; + } + + // Everything existed and is in backoff so throw + throw new IOException("All xds transports for authority " + authority + " are in backoff"); + } + + private ControlPlaneClient getOrCreateControlPlaneClient(ServerInfo serverInfo) { syncContext.throwIfNotInThisSynchronizationContext(); if (serverCpClientMap.containsKey(serverInfo)) { return serverCpClientMap.get(serverInfo); } - XdsTransportFactory.XdsTransport xdsTransport = xdsTransportFactory.create(serverInfo); + logger.log(XdsLogLevel.DEBUG, "Creating control plane client for {0}", serverInfo.target()); + XdsTransportFactory.XdsTransport xdsTransport; + try { + xdsTransport = xdsTransportFactory.create(serverInfo); + } catch (Exception e) { + String msg = String.format("Failed to create xds transport for %s: %s", + serverInfo.target(), e.getMessage()); + logger.log(XdsLogLevel.WARNING, msg); + xdsTransport = + new ControlPlaneClient.FailingXdsTransport(Status.UNAVAILABLE.withDescription(msg)); + } + ControlPlaneClient controlPlaneClient = new ControlPlaneClient( xdsTransport, serverInfo, bootstrapInfo.node(), - this, + new ResponseHandler(serverInfo), this, timeService, syncContext, backoffPolicyProvider, stopwatchSupplier, - this, - messagePrinter); + messagePrinter + ); + serverCpClientMap.put(serverInfo, controlPlaneClient); LoadStatsManager2 loadStatsManager = new LoadStatsManager2(stopwatchSupplier); @@ -428,45 +544,49 @@ public Map getServerLrsClientMap() { } @Nullable - private ServerInfo getServerInfo(String resource) { - if (resource.startsWith(XDSTP_SCHEME)) { - URI uri = URI.create(resource); - String authority = uri.getAuthority(); - if (authority == null) { - authority = ""; - } + private ImmutableList getServerInfos(String authority) { + if (authority != null) { AuthorityInfo authorityInfo = bootstrapInfo.authorities().get(authority); if (authorityInfo == null || authorityInfo.xdsServers().isEmpty()) { return null; } - return authorityInfo.xdsServers().get(0); + return authorityInfo.xdsServers(); } else { - return bootstrapInfo.servers().get(0); // use first server + return bootstrapInfo.servers(); } } @SuppressWarnings("unchecked") private void handleResourceUpdate( XdsResourceType.Args args, List resources, XdsResourceType xdsResourceType, - ProcessingTracker processingTracker) { + boolean isFirstResponse, ProcessingTracker processingTracker) { + ControlPlaneClient controlPlaneClient = serverCpClientMap.get(args.serverInfo); + + if (isFirstResponse) { + shutdownLowerPriorityCpcs(controlPlaneClient); + } + ValidatedResourceUpdate result = xdsResourceType.parse(args, resources); logger.log(XdsLogger.XdsLogLevel.INFO, "Received {0} Response version {1} nonce {2}. Parsed resources: {3}", - xdsResourceType.typeName(), args.versionInfo, args.nonce, result.unpackedResources); + xdsResourceType.typeName(), args.versionInfo, args.nonce, result.unpackedResources); Map> parsedResources = result.parsedResources; Set invalidResources = result.invalidResources; + metricReporter.reportResourceUpdates(Long.valueOf(parsedResources.size()), + Long.valueOf(invalidResources.size()), + args.getServerInfo().target(), xdsResourceType.typeUrl()); + List errors = result.errors; String errorDetail = null; if (errors.isEmpty()) { checkArgument(invalidResources.isEmpty(), "found invalid resources but missing errors"); - serverCpClientMap.get(args.serverInfo).ackResponse(xdsResourceType, args.versionInfo, - args.nonce); + controlPlaneClient.ackResponse(xdsResourceType, args.versionInfo, args.nonce); } else { errorDetail = Joiner.on('\n').join(errors); logger.log(XdsLogLevel.WARNING, "Failed processing {0} Response version {1} nonce {2}. Errors:\n{3}", xdsResourceType.typeName(), args.versionInfo, args.nonce, errorDetail); - serverCpClientMap.get(args.serverInfo).nackResponse(xdsResourceType, args.nonce, errorDetail); + controlPlaneClient.nackResponse(xdsResourceType, args.nonce, errorDetail); } long updateTime = timeProvider.currentTimeNanos(); @@ -483,8 +603,21 @@ private void handleResourceUpdate( } if (invalidResources.contains(resourceName)) { - // The resource update is invalid. Capture the error without notifying the watchers. + // The resource update is invalid (NACK). Handle as a data error. subscriber.onRejected(args.versionInfo, updateTime, errorDetail); + + // Handle data errors (NACKs) based on fail_on_data_errors server feature. + // When xdsDataErrorHandlingEnabled is true and fail_on_data_errors is present, + // delete cached data so onError will call onResourceChanged instead of onAmbientError. + // When xdsDataErrorHandlingEnabled is false, use old behavior (always keep cached data). + if (BootstrapperImpl.xdsDataErrorHandlingEnabled && subscriber.data != null + && args.serverInfo.failOnDataErrors()) { + subscriber.data = null; + } + // Call onError, which will decide whether to call onResourceChanged or onAmbientError + // based on whether data exists after the above deletion. + subscriber.onError(Status.UNAVAILABLE.withDescription(errorDetail), processingTracker); + continue; } // Nothing else to do for incremental ADS resources. @@ -492,73 +625,106 @@ private void handleResourceUpdate( continue; } - // Handle State of the World ADS: invalid resources. - if (invalidResources.contains(resourceName)) { - // The resource is missing. Reuse the cached resource if possible. - if (subscriber.data == null) { - // No cached data. Notify the watchers of an invalid update. - subscriber.onError(Status.UNAVAILABLE.withDescription(errorDetail), processingTracker); - } - continue; - } - // For State of the World services, notify watchers when their watched resource is missing // from the ADS update. Note that we can only do this if the resource update is coming from // the same xDS server that the ResourceSubscriber is subscribed to. - if (subscriber.serverInfo.equals(args.serverInfo)) { - subscriber.onAbsent(processingTracker); + if (getActiveCpc(subscriber.authority) == controlPlaneClient) { + subscriber.onAbsent(processingTracker, args.serverInfo); } } } - /** - * Tracks a single subscribed resource. - */ + @Override + public Future reportServerConnections(ServerConnectionCallback callback) { + SettableFuture future = SettableFuture.create(); + syncContext.execute(() -> { + serverCpClientMap.forEach((serverInfo, controlPlaneClient) -> + callback.reportServerConnectionGauge( + !controlPlaneClient.isInError(), serverInfo.target())); + future.set(null); + }); + return future; + } + + private void shutdownLowerPriorityCpcs(ControlPlaneClient activatedCpc) { + // For each authority, remove any control plane clients, with lower priority than the activated + // one, from activatedCpClients storing them all in cpcsToShutdown. + Set cpcsToShutdown = new HashSet<>(); + for ( List cpcsForAuth : activatedCpClients.values()) { + if (cpcsForAuth == null) { + continue; + } + int index = cpcsForAuth.indexOf(activatedCpc); + if (index > -1) { + cpcsToShutdown.addAll(cpcsForAuth.subList(index + 1, cpcsForAuth.size())); + cpcsForAuth.subList(index + 1, cpcsForAuth.size()).clear(); // remove lower priority cpcs + } + } + + // Shutdown any lower priority control plane clients identified above that aren't still being + // used by another authority. If they are still being used let the XDS server know that we + // no longer are interested in subscriptions for authorities we are no longer responsible for. + for (ControlPlaneClient cpc : cpcsToShutdown) { + if (activatedCpClients.values().stream().noneMatch(list -> list.contains(cpc))) { + cpc.shutdown(); + serverCpClientMap.remove(cpc.getServerInfo()); + } else { + cpc.sendDiscoveryRequests(); + } + } + } + + + /** Tracks a single subscribed resource. */ private final class ResourceSubscriber { - @Nullable private final ServerInfo serverInfo; - @Nullable private final ControlPlaneClient controlPlaneClient; + @Nullable + private final String authority; private final XdsResourceType type; private final String resource; private final Map, Executor> watchers = new HashMap<>(); - @Nullable private T data; + @Nullable + private T data; private boolean absent; // Tracks whether the deletion has been ignored per bootstrap server feature. // See https://github.com/grpc/proposal/blob/master/A53-xds-ignore-resource-deletion.md private boolean resourceDeletionIgnored; - @Nullable private ScheduledHandle respTimer; - @Nullable private ResourceMetadata metadata; - @Nullable private String errorDescription; + @Nullable + private ScheduledHandle respTimer; + @Nullable + private ResourceMetadata metadata; + @Nullable + private String errorDescription; + @Nullable + private Status lastError; ResourceSubscriber(XdsResourceType type, String resource) { syncContext.throwIfNotInThisSynchronizationContext(); this.type = type; this.resource = resource; - this.serverInfo = getServerInfo(resource); - if (serverInfo == null) { + this.authority = getAuthorityFromResourceName(resource); + if (getServerInfos(authority) == null) { this.errorDescription = "Wrong configuration: xds server does not exist for resource " + resource; - this.controlPlaneClient = null; return; } + // Initialize metadata in UNKNOWN state to cover the case when resource subscriber, // is created but not yet requested because the client is in backoff. this.metadata = ResourceMetadata.newResourceMetadataUnknown(); + } - ControlPlaneClient controlPlaneClient = null; - try { - controlPlaneClient = getOrCreateControlPlaneClient(serverInfo); - if (controlPlaneClient.isInBackoff()) { - return; - } - } catch (IllegalArgumentException e) { - controlPlaneClient = null; - this.errorDescription = "Bad configuration: " + e.getMessage(); - return; - } finally { - this.controlPlaneClient = controlPlaneClient; - } - - restartTimer(); + @Override + public String toString() { + return "ResourceSubscriber{" + + "resource='" + resource + '\'' + + ", authority='" + authority + '\'' + + ", type=" + type + + ", watchers=" + watchers.size() + + ", data=" + data + + ", absent=" + absent + + ", resourceDeletionIgnored=" + resourceDeletionIgnored + + ", errorDescription='" + errorDescription + '\'' + + '}'; } void addWatcher(ResourceWatcher watcher, Executor watcherExecutor) { @@ -566,20 +732,28 @@ void addWatcher(ResourceWatcher watcher, Executor watcherExecutor) { watchers.put(watcher, watcherExecutor); T savedData = data; boolean savedAbsent = absent; + Status savedError = lastError; watcherExecutor.execute(() -> { if (errorDescription != null) { - watcher.onError(Status.INVALID_ARGUMENT.withDescription(errorDescription)); + watcher.onResourceChanged(StatusOr.fromStatus( + Status.INVALID_ARGUMENT.withDescription(errorDescription))); return; } if (savedData != null) { - notifyWatcher(watcher, savedData); + watcher.onResourceChanged(StatusOr.fromValue(savedData)); + if (savedError != null) { + watcher.onAmbientError(savedError); + } + } else if (savedError != null) { + watcher.onResourceChanged(StatusOr.fromStatus(savedError)); } else if (savedAbsent) { - watcher.onResourceDoesNotExist(resource); + watcher.onResourceChanged(StatusOr.fromStatus( + Status.NOT_FOUND.withDescription("Resource " + resource + " does not exist"))); } }); } - void removeWatcher(ResourceWatcher watcher) { + void removeWatcher(ResourceWatcher watcher) { checkArgument(watchers.containsKey(watcher), "watcher %s not registered", watcher); watchers.remove(watcher); } @@ -588,17 +762,22 @@ void restartTimer() { if (data != null || absent) { // resource already resolved return; } - if (!controlPlaneClient.isReady()) { // When client becomes ready, it triggers a restartTimer + ControlPlaneClient activeCpc = getActiveCpc(authority); + if (activeCpc == null || !activeCpc.isReady()) { + // When client becomes ready, it triggers a restartTimer for all relevant subscribers. return; } + ServerInfo serverInfo = activeCpc.getServerInfo(); + int timeoutSec = serverInfo.resourceTimerIsTransientError() + ? EXTENDED_RESOURCE_FETCH_TIMEOUT_SEC : INITIAL_RESOURCE_FETCH_TIMEOUT_SEC; class ResourceNotFound implements Runnable { @Override public void run() { logger.log(XdsLogLevel.INFO, "{0} resource {1} initial fetch timeout", type, resource); + onAbsent(null, activeCpc.getServerInfo()); respTimer = null; - onAbsent(null); } @Override @@ -610,9 +789,11 @@ public String toString() { // Initial fetch scheduled or rescheduled, transition metadata state to REQUESTED. metadata = ResourceMetadata.newResourceMetadataRequested(); + if (respTimer != null) { + respTimer.cancel(); + } respTimer = syncContext.schedule( - new ResourceNotFound(), INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS, - timeService); + new ResourceNotFound(), timeoutSec, TimeUnit.SECONDS, timeService); } void stopTimer() { @@ -633,8 +814,7 @@ void cancelResourceWatch() { message += " for which we previously ignored a deletion"; logLevel = XdsLogLevel.FORCE_INFO; } - logger.log(logLevel, message, type, resource, - serverInfo != null ? serverInfo.target() : "unknown"); + logger.log(logLevel, message, type, resource, getTarget()); } boolean isWatched() { @@ -651,23 +831,25 @@ void onData(ParsedResource parsedResource, String version, long updateTime, respTimer.cancel(); respTimer = null; } - this.metadata = ResourceMetadata - .newResourceMetadataAcked(parsedResource.getRawResource(), version, updateTime); ResourceUpdate oldData = this.data; this.data = parsedResource.getResourceUpdate(); + this.metadata = ResourceMetadata.newResourceMetadataAcked( + parsedResource.getRawResource(), version, updateTime); absent = false; + lastError = null; if (resourceDeletionIgnored) { logger.log(XdsLogLevel.FORCE_INFO, "xds server {0}: server returned new version " + "of resource for which we previously ignored a deletion: type {1} name {2}", - serverInfo != null ? serverInfo.target() : "unknown", type, resource); + getTarget(), type, resource); resourceDeletionIgnored = false; } if (!Objects.equals(oldData, data)) { + StatusOr update = StatusOr.fromValue(data); for (ResourceWatcher watcher : watchers.keySet()) { processingTracker.startTask(); watchers.get(watcher).execute(() -> { try { - notifyWatcher(watcher, data); + watcher.onResourceChanged(update); } finally { processingTracker.onComplete(); } @@ -676,37 +858,85 @@ void onData(ParsedResource parsedResource, String version, long updateTime, } } - void onAbsent(@Nullable ProcessingTracker processingTracker) { + private String getTarget() { + ControlPlaneClient activeCpc = getActiveCpc(authority); + return (activeCpc != null) + ? activeCpc.getServerInfo().target() + : "unknown"; + } + + void onAbsent(@Nullable ProcessingTracker processingTracker, ServerInfo serverInfo) { if (respTimer != null && respTimer.isPending()) { // too early to conclude absence return; } - // Ignore deletion of State of the World resources when this feature is on, - // and the resource is reusable. - boolean ignoreResourceDeletionEnabled = - serverInfo != null && serverInfo.ignoreResourceDeletion(); - if (ignoreResourceDeletionEnabled && type.isFullStateOfTheWorld() && data != null) { - if (!resourceDeletionIgnored) { - logger.log(XdsLogLevel.FORCE_WARNING, - "xds server {0}: ignoring deletion for resource type {1} name {2}}", - serverInfo.target(), type, resource); - resourceDeletionIgnored = true; + // Handle data errors (resource deletions) based on fail_on_data_errors server feature. + // When xdsDataErrorHandlingEnabled is true and fail_on_data_errors is not present, + // we treat deletions as ambient errors and keep using the cached resource. + // When fail_on_data_errors is present, we delete the cached resource and fail. + // When xdsDataErrorHandlingEnabled is false, use the old behavior (ignore_resource_deletion). + boolean ignoreResourceDeletionEnabled = serverInfo.ignoreResourceDeletion(); + boolean failOnDataErrors = serverInfo.failOnDataErrors(); + boolean xdsDataErrorHandlingEnabled = BootstrapperImpl.xdsDataErrorHandlingEnabled; + + if (type.isFullStateOfTheWorld() && data != null) { + // New behavior (per gRFC A88): Default is to treat deletions as ambient errors + if (xdsDataErrorHandlingEnabled && !failOnDataErrors) { + if (!resourceDeletionIgnored) { + logger.log(XdsLogLevel.FORCE_WARNING, + "xds server {0}: ignoring deletion for resource type {1} name {2}}", + serverInfo.target(), type, resource); + resourceDeletionIgnored = true; + } + Status deletionStatus = Status.NOT_FOUND.withDescription( + "Resource " + resource + " deleted from server"); + onAmbientError(deletionStatus, processingTracker); + return; + } + // Old behavior: Use ignore_resource_deletion server feature + if (!xdsDataErrorHandlingEnabled && ignoreResourceDeletionEnabled) { + if (!resourceDeletionIgnored) { + logger.log(XdsLogLevel.FORCE_WARNING, + "xds server {0}: ignoring deletion for resource type {1} name {2}}", + serverInfo.target(), type, resource); + resourceDeletionIgnored = true; + } + Status deletionStatus = Status.NOT_FOUND.withDescription( + "Resource " + resource + " deleted from server"); + onAmbientError(deletionStatus, processingTracker); + return; } - return; } logger.log(XdsLogLevel.INFO, "Conclude {0} resource {1} not exist", type, resource); if (!absent) { data = null; absent = true; - metadata = ResourceMetadata.newResourceMetadataDoesNotExist(); - for (ResourceWatcher watcher : watchers.keySet()) { + lastError = null; + + Status status; + if (respTimer == null) { + status = Status.NOT_FOUND.withDescription("Resource " + resource + " does not exist"); + metadata = ResourceMetadata.newResourceMetadataDoesNotExist(); + } else { + status = serverInfo.resourceTimerIsTransientError() + ? Status.UNAVAILABLE.withDescription( + "Timed out waiting for resource " + resource + " from xDS server") + : Status.NOT_FOUND.withDescription( + "Timed out waiting for resource " + resource + " from xDS server"); + metadata = serverInfo.resourceTimerIsTransientError() + ? ResourceMetadata.newResourceMetadataTimeout() + : ResourceMetadata.newResourceMetadataDoesNotExist(); + } + + StatusOr update = StatusOr.fromStatus(status); + for (Map.Entry, Executor> entry : watchers.entrySet()) { if (processingTracker != null) { processingTracker.startTask(); } - watchers.get(watcher).execute(() -> { + entry.getValue().execute(() -> { try { - watcher.onResourceDoesNotExist(resource); + entry.getKey().onResourceChanged(update); } finally { if (processingTracker != null) { processingTracker.onComplete(); @@ -729,14 +959,39 @@ void onError(Status error, @Nullable ProcessingTracker tracker) { Status errorAugmented = Status.fromCode(error.getCode()) .withDescription(description + "nodeID: " + bootstrapInfo.node().getId()) .withCause(error.getCause()); + this.lastError = errorAugmented; + + if (data != null) { + // We have cached data, so this is an ambient error. + onAmbientError(errorAugmented, tracker); + } else { + // No data, this is a definitive resource error. + StatusOr update = StatusOr.fromStatus(errorAugmented); + for (Map.Entry, Executor> entry : watchers.entrySet()) { + if (tracker != null) { + tracker.startTask(); + } + entry.getValue().execute(() -> { + try { + entry.getKey().onResourceChanged(update); + } finally { + if (tracker != null) { + tracker.onComplete(); + } + } + }); + } + } + } - for (ResourceWatcher watcher : watchers.keySet()) { + private void onAmbientError(Status error, @Nullable ProcessingTracker tracker) { + for (Map.Entry, Executor> entry : watchers.entrySet()) { if (tracker != null) { tracker.startTask(); } - watchers.get(watcher).execute(() -> { + entry.getValue().execute(() -> { try { - watcher.onError(errorAugmented); + entry.getKey().onAmbientError(error); } finally { if (tracker != null) { tracker.onComplete(); @@ -748,12 +1003,100 @@ void onError(Status error, @Nullable ProcessingTracker tracker) { void onRejected(String rejectedVersion, long rejectedTime, String rejectedDetails) { metadata = ResourceMetadata - .newResourceMetadataNacked(metadata, rejectedVersion, rejectedTime, rejectedDetails); + .newResourceMetadataNacked(metadata, rejectedVersion, rejectedTime, rejectedDetails, + data != null); + } + } + + private class ResponseHandler implements XdsResponseHandler { + final ServerInfo serverInfo; + + ResponseHandler(ServerInfo serverInfo) { + this.serverInfo = serverInfo; + } + + @Override + public void handleResourceResponse( + XdsResourceType xdsResourceType, ServerInfo serverInfo, String versionInfo, + List resources, String nonce, boolean isFirstResponse, + ProcessingTracker processingTracker) { + checkNotNull(xdsResourceType, "xdsResourceType"); + syncContext.throwIfNotInThisSynchronizationContext(); + Set toParseResourceNames = + xdsResourceType.shouldRetrieveResourceKeysForArgs() + ? getResourceKeys(xdsResourceType) + : null; + XdsResourceType.Args args = new XdsResourceType.Args(serverInfo, versionInfo, nonce, + bootstrapInfo, securityConfig, toParseResourceNames); + handleResourceUpdate(args, resources, xdsResourceType, isFirstResponse, processingTracker); } - private void notifyWatcher(ResourceWatcher watcher, T update) { - watcher.onChanged(update); + @Override + public void handleStreamClosed(Status status, boolean shouldTryFallback) { + syncContext.throwIfNotInThisSynchronizationContext(); + + ControlPlaneClient cpcClosed = serverCpClientMap.get(serverInfo); + if (cpcClosed == null) { + logger.log(XdsLogLevel.DEBUG, + "Couldn't find closing CPC for {0}, so skipping cleanup and reporting", serverInfo); + return; + } + + cleanUpResourceTimers(cpcClosed); + + if (status.isOk()) { + return; // Not considered an error + } + + metricReporter.reportServerFailure(1L, serverInfo.target()); + + Collection authoritiesForClosedCpc = getActiveAuthorities(cpcClosed); + for (Map> subscriberMap : + resourceSubscribers.values()) { + for (ResourceSubscriber subscriber : subscriberMap.values()) { + if (!authoritiesForClosedCpc.contains(subscriber.authority)) { + continue; + } + // If subscriber already has data, this is an ambient error. + if (subscriber.hasResult()) { + subscriber.onError(status, null); + continue; + } + + // try to fallback to lower priority control plane client + if (shouldTryFallback && manageControlPlaneClient(subscriber).didFallback) { + authoritiesForClosedCpc.remove(subscriber.authority); + if (authoritiesForClosedCpc.isEmpty()) { + return; // optimization: no need to continue once all authorities have done fallback + } + continue; // since we did fallback, don't consider it an error + } + + subscriber.onError(status, null); + } + } + } + } + + private static class CpcWithFallbackState { + ControlPlaneClient cpc; + boolean didFallback; + + private CpcWithFallbackState(ControlPlaneClient cpc, boolean didFallback) { + this.cpc = cpc; + this.didFallback = didFallback; } } + private Collection getActiveAuthorities(ControlPlaneClient cpc) { + List asList = activatedCpClients.entrySet().stream() + .filter(entry -> !entry.getValue().isEmpty() + && cpc == entry.getValue().get(entry.getValue().size() - 1)) + .map(Map.Entry::getKey) + .collect(Collectors.toList()); + + // Since this is usually used for contains, use a set when the list is large + return (asList.size() < 100) ? asList : new HashSet<>(asList); + } + } diff --git a/xds/src/main/java/io/grpc/xds/client/XdsClientMetricReporter.java b/xds/src/main/java/io/grpc/xds/client/XdsClientMetricReporter.java new file mode 100644 index 00000000000..a044d501759 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/client/XdsClientMetricReporter.java @@ -0,0 +1,48 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.client; + +import io.grpc.Internal; + +/** + * Interface for reporting metrics from the xDS client. + */ +@Internal +public interface XdsClientMetricReporter { + + /** + * Reports number of valid and invalid resources. + * + * @param validResourceCount Number of resources that were valid. + * @param invalidResourceCount Number of resources that were invalid. + * @param xdsServer Target URI of the xDS server with which the XdsClient is communicating. + * @param resourceType Type of XDS resource (e.g., "envoy.config.listener.v3.Listener"). + */ + default void reportResourceUpdates(long validResourceCount, long invalidResourceCount, + String xdsServer, String resourceType) { + } + + /** + * Reports number of xDS servers going from healthy to unhealthy. + * + * @param serverFailure Number of xDS server failures. + * @param xdsServer Target URI of the xDS server with which the XdsClient is communicating. + */ + default void reportServerFailure(long serverFailure, String xdsServer) { + } + +} diff --git a/xds/src/main/java/io/grpc/xds/client/XdsResourceType.java b/xds/src/main/java/io/grpc/xds/client/XdsResourceType.java index f15d6524751..4d6e75b1809 100644 --- a/xds/src/main/java/io/grpc/xds/client/XdsResourceType.java +++ b/xds/src/main/java/io/grpc/xds/client/XdsResourceType.java @@ -20,8 +20,6 @@ import static io.grpc.xds.client.XdsClient.canonifyResourceName; import static io.grpc.xds.client.XdsClient.isResourceNameValid; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Strings; import com.google.protobuf.Any; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; @@ -35,38 +33,35 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.logging.Level; +import java.util.logging.Logger; import javax.annotation.Nullable; @ExperimentalApi("https://github.com/grpc/grpc-java/issues/10847") public abstract class XdsResourceType { + private static final Logger log = Logger.getLogger(XdsResourceType.class.getName()); + static final String TYPE_URL_RESOURCE = "type.googleapis.com/envoy.service.discovery.v3.Resource"; protected static final String TRANSPORT_SOCKET_NAME_TLS = "envoy.transport_sockets.tls"; - @VisibleForTesting - public static final String HASH_POLICY_FILTER_STATE_KEY = "io.grpc.channel_id"; - @VisibleForTesting - public static boolean enableRouteLookup = getFlag("GRPC_EXPERIMENTAL_XDS_RLS_LB", true); - @VisibleForTesting - public static boolean enableLeastRequest = - !Strings.isNullOrEmpty(System.getenv("GRPC_EXPERIMENTAL_ENABLE_LEAST_REQUEST")) - ? Boolean.parseBoolean(System.getenv("GRPC_EXPERIMENTAL_ENABLE_LEAST_REQUEST")) - : Boolean.parseBoolean(System.getProperty("io.grpc.xds.experimentalEnableLeastRequest")); - - @VisibleForTesting - public static boolean enableWrr = getFlag("GRPC_EXPERIMENTAL_XDS_WRR_LB", true); - - @VisibleForTesting - protected static boolean enablePickFirst = getFlag("GRPC_EXPERIMENTAL_PICKFIRST_LB_CONFIG", true); - - protected static final String TYPE_URL_CLUSTER_CONFIG = - "type.googleapis.com/envoy.extensions.clusters.aggregate.v3.ClusterConfig"; + protected static final String TYPE_URL_TYPED_STRUCT_UDPA = "type.googleapis.com/udpa.type.v1.TypedStruct"; protected static final String TYPE_URL_TYPED_STRUCT = "type.googleapis.com/xds.type.v3.TypedStruct"; + /** + * Extract the resource name from an older resource type that included the name within the + * resource contents itself. The newer approach has resources wrapped with {@code + * envoy.service.discovery.v3.Resource} which then provides the name. This method is only called + * for the old approach. + * + * @return the resource's name, or {@code null} if name is not stored within the resource contents + */ @Nullable - protected abstract String extractResourceName(Message unpackedResource); + protected String extractResourceName(Message unpackedResource) { + return null; + } protected abstract Class unpackedClassName(); @@ -148,15 +143,24 @@ ValidatedResourceUpdate parse(Args args, List resources) { Any resource = resources.get(i); Message unpackedMessage; + String name = ""; try { - resource = maybeUnwrapResources(resource); + if (resource.getTypeUrl().equals(TYPE_URL_RESOURCE)) { + Resource wrappedResource = unpackCompatibleType(resource, Resource.class, + TYPE_URL_RESOURCE, null); + resource = wrappedResource.getResource(); + name = wrappedResource.getName(); + } unpackedMessage = unpackCompatibleType(resource, unpackedClassName(), typeUrl(), null); } catch (InvalidProtocolBufferException e) { errors.add(String.format("%s response Resource index %d - can't decode %s: %s", typeName(), i, unpackedClassName().getSimpleName(), e.getMessage())); continue; } - String name = extractResourceName(unpackedMessage); + // Fallback to inner resource name if the outer resource didn't have a name. + if (name.isEmpty()) { + name = extractResourceName(unpackedMessage); + } if (name == null || !isResourceNameValid(name, resource.getTypeUrl())) { errors.add( "Unsupported resource name: " + name + " for type: " + typeName()); @@ -176,6 +180,16 @@ ValidatedResourceUpdate parse(Args args, List resources) { typeName(), unpackedClassName().getSimpleName(), cname, e.getMessage())); invalidResources.add(cname); continue; + } catch (Throwable t) { + log.log(Level.FINE, "Unexpected error in doParse()", t); + String errorMessage = t.getClass().getSimpleName(); + if (t.getMessage() != null) { + errorMessage = errorMessage + ": " + t.getMessage(); + } + errors.add(String.format("%s response '%s' unexpected error: %s", + typeName(), cname, errorMessage)); + invalidResources.add(cname); + continue; } // Resource parsed successfully. @@ -209,16 +223,6 @@ protected static T unpackCompatibleType( return any.unpack(clazz); } - private Any maybeUnwrapResources(Any resource) - throws InvalidProtocolBufferException { - if (resource.getTypeUrl().equals(TYPE_URL_RESOURCE)) { - return unpackCompatibleType(resource, Resource.class, TYPE_URL_RESOURCE, - null).getResource(); - } else { - return resource; - } - } - static final class ParsedResource { private final T resourceUpdate; private final Any rawResource; @@ -254,62 +258,4 @@ public ValidatedResourceUpdate(Map> parsedResources, this.errors = errors; } } - - private static boolean getFlag(String envVarName, boolean enableByDefault) { - String envVar = System.getenv(envVarName); - if (enableByDefault) { - return Strings.isNullOrEmpty(envVar) || Boolean.parseBoolean(envVar); - } else { - return !Strings.isNullOrEmpty(envVar) && Boolean.parseBoolean(envVar); - } - } - - @VisibleForTesting - public static final class StructOrError { - - /** - * Returns a {@link StructOrError} for the successfully converted data object. - */ - public static StructOrError fromStruct(T struct) { - return new StructOrError<>(struct); - } - - /** - * Returns a {@link StructOrError} for the failure to convert the data object. - */ - public static StructOrError fromError(String errorDetail) { - return new StructOrError<>(errorDetail); - } - - private final String errorDetail; - private final T struct; - - private StructOrError(T struct) { - this.struct = checkNotNull(struct, "struct"); - this.errorDetail = null; - } - - private StructOrError(String errorDetail) { - this.struct = null; - this.errorDetail = checkNotNull(errorDetail, "errorDetail"); - } - - /** - * Returns struct if exists, otherwise null. - */ - @VisibleForTesting - @Nullable - public T getStruct() { - return struct; - } - - /** - * Returns error detail if exists, otherwise null. - */ - @VisibleForTesting - @Nullable - public String getErrorDetail() { - return errorDetail; - } - } } diff --git a/xds/src/main/java/io/grpc/xds/internal/MatcherParser.java b/xds/src/main/java/io/grpc/xds/internal/MatcherParser.java index 39b80bbcc03..91b77b05d01 100644 --- a/xds/src/main/java/io/grpc/xds/internal/MatcherParser.java +++ b/xds/src/main/java/io/grpc/xds/internal/MatcherParser.java @@ -26,9 +26,12 @@ public static Matchers.HeaderMatcher parseHeaderMatcher( io.envoyproxy.envoy.config.route.v3.HeaderMatcher proto) { switch (proto.getHeaderMatchSpecifierCase()) { case EXACT_MATCH: + @SuppressWarnings("deprecation") // gRFC A63: support indefinitely + String exactMatch = proto.getExactMatch(); return Matchers.HeaderMatcher.forExactValue( - proto.getName(), proto.getExactMatch(), proto.getInvertMatch()); + proto.getName(), exactMatch, proto.getInvertMatch()); case SAFE_REGEX_MATCH: + @SuppressWarnings("deprecation") // gRFC A63: support indefinitely String rawPattern = proto.getSafeRegexMatch().getRegex(); Pattern safeRegExMatch; try { @@ -49,14 +52,20 @@ public static Matchers.HeaderMatcher parseHeaderMatcher( return Matchers.HeaderMatcher.forPresent( proto.getName(), proto.getPresentMatch(), proto.getInvertMatch()); case PREFIX_MATCH: + @SuppressWarnings("deprecation") // gRFC A63: support indefinitely + String prefixMatch = proto.getPrefixMatch(); return Matchers.HeaderMatcher.forPrefix( - proto.getName(), proto.getPrefixMatch(), proto.getInvertMatch()); + proto.getName(), prefixMatch, proto.getInvertMatch()); case SUFFIX_MATCH: + @SuppressWarnings("deprecation") // gRFC A63: support indefinitely + String suffixMatch = proto.getSuffixMatch(); return Matchers.HeaderMatcher.forSuffix( - proto.getName(), proto.getSuffixMatch(), proto.getInvertMatch()); + proto.getName(), suffixMatch, proto.getInvertMatch()); case CONTAINS_MATCH: + @SuppressWarnings("deprecation") // gRFC A63: support indefinitely + String containsMatch = proto.getContainsMatch(); return Matchers.HeaderMatcher.forContains( - proto.getName(), proto.getContainsMatch(), proto.getInvertMatch()); + proto.getName(), containsMatch, proto.getInvertMatch()); case STRING_MATCH: return Matchers.HeaderMatcher.forString( proto.getName(), parseStringMatcher(proto.getStringMatch()), proto.getInvertMatch()); @@ -88,4 +97,25 @@ public static Matchers.StringMatcher parseStringMatcher( "Unknown StringMatcher match pattern: " + proto.getMatchPatternCase()); } } + + /** Translates envoy proto FractionalPercent to internal FractionMatcher. */ + public static Matchers.FractionMatcher parseFractionMatcher( + io.envoyproxy.envoy.type.v3.FractionalPercent proto) { + int denominator; + switch (proto.getDenominator()) { + case HUNDRED: + denominator = 100; + break; + case TEN_THOUSAND: + denominator = 10_000; + break; + case MILLION: + denominator = 1_000_000; + break; + case UNRECOGNIZED: + default: + throw new IllegalArgumentException("Unknown denominator type: " + proto.getDenominator()); + } + return Matchers.FractionMatcher.create(proto.getNumerator(), denominator); + } } diff --git a/xds/src/main/java/io/grpc/xds/internal/Matchers.java b/xds/src/main/java/io/grpc/xds/internal/Matchers.java index f833fd2e480..228b20cfcd7 100644 --- a/xds/src/main/java/io/grpc/xds/internal/Matchers.java +++ b/xds/src/main/java/io/grpc/xds/internal/Matchers.java @@ -22,6 +22,7 @@ import com.google.re2j.Pattern; import java.math.BigInteger; import java.net.InetAddress; +import java.util.Locale; import javax.annotation.Nullable; /** @@ -273,11 +274,11 @@ public boolean matches(String args) { : exact().equals(args); } else if (prefix() != null) { return ignoreCase() - ? args.toLowerCase().startsWith(prefix().toLowerCase()) + ? args.toLowerCase(Locale.ROOT).startsWith(prefix().toLowerCase(Locale.ROOT)) : args.startsWith(prefix()); } else if (suffix() != null) { return ignoreCase() - ? args.toLowerCase().endsWith(suffix().toLowerCase()) + ? args.toLowerCase(Locale.ROOT).endsWith(suffix().toLowerCase(Locale.ROOT)) : args.endsWith(suffix()); } else if (contains() != null) { return args.contains(contains()); diff --git a/xds/src/main/java/io/grpc/xds/internal/MetricReportUtils.java b/xds/src/main/java/io/grpc/xds/internal/MetricReportUtils.java new file mode 100644 index 00000000000..4194cab76d3 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/MetricReportUtils.java @@ -0,0 +1,119 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal; + +import com.google.auto.value.AutoValue; +import io.grpc.services.MetricReport; +import java.util.Optional; +import java.util.OptionalDouble; + + +/** + * Utilities for parsing and resolving metrics from {@link MetricReport}. + */ +public final class MetricReportUtils { + + private MetricReportUtils() {} + + public enum MetricType { + CPU_UTILIZATION, + APPLICATION_UTILIZATION, + MEMORY_UTILIZATION, + UTILIZATION, + NAMED_METRICS, + INVALID + } + + @AutoValue + public abstract static class ParsedMetricName { + public abstract MetricType getMetricType(); + + public abstract Optional getKey(); + + public static ParsedMetricName create(MetricType metricType, Optional key) { + return new AutoValue_MetricReportUtils_ParsedMetricName(metricType, key); + } + + /** + * Pre-parses a custom metric name into a {@link ParsedMetricName}. + * + * @param name The custom metric name to parse. + * @return The parsed metric name. + */ + public static ParsedMetricName parse(String name) { + if (name.equals("cpu_utilization")) { + return create(MetricType.CPU_UTILIZATION, Optional.empty()); + } + if (name.equals("application_utilization")) { + return create(MetricType.APPLICATION_UTILIZATION, Optional.empty()); + } + if (name.equals("mem_utilization")) { + return create(MetricType.MEMORY_UTILIZATION, Optional.empty()); + } + if (name.startsWith("utilization.")) { + return create(MetricType.UTILIZATION, Optional.of(name.substring("utilization.".length()))); + } + if (name.startsWith("named_metrics.")) { + return create(MetricType.NAMED_METRICS, + Optional.of(name.substring("named_metrics.".length()))); + } + return create(MetricType.INVALID, Optional.empty()); + } + + } + + /** + * Resolves a custom metric value for `parsedMetric` + * Returns OptionalDouble.empty() if the metric is absent or invalid. + * + * @param report The metric report to query. + * @param parsedMetric The parsed metric to lookup. + * @return The metric value wrapped in an OptionalDouble, or empty if absent. + */ + + public static OptionalDouble getMetricValue(MetricReport report, ParsedMetricName parsedMetric) { + switch (parsedMetric.getMetricType()) { + case CPU_UTILIZATION: + return OptionalDouble.of(report.getCpuUtilization()); + case APPLICATION_UTILIZATION: + return OptionalDouble.of(report.getApplicationUtilization()); + case MEMORY_UTILIZATION: + return OptionalDouble.of(report.getMemoryUtilization()); + case UTILIZATION: + if (parsedMetric.getKey().isPresent()) { + String key = parsedMetric.getKey().get(); + Double val = report.getUtilizationMetrics().get(key); + if (val != null) { + return OptionalDouble.of(val); + } + } + return OptionalDouble.empty(); + case NAMED_METRICS: + if (parsedMetric.getKey().isPresent()) { + String key = parsedMetric.getKey().get(); + Double val = report.getNamedMetrics().get(key); + if (val != null) { + return OptionalDouble.of(val); + } + } + return OptionalDouble.empty(); + case INVALID: + default: + return OptionalDouble.empty(); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/ProtobufJsonConverter.java b/xds/src/main/java/io/grpc/xds/internal/ProtobufJsonConverter.java new file mode 100644 index 00000000000..964c28c57e0 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/ProtobufJsonConverter.java @@ -0,0 +1,61 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal; + +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import io.grpc.Internal; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * Converter for Protobuf {@link Struct} to JSON-like {@link Map}. + */ +@Internal +public final class ProtobufJsonConverter { + private ProtobufJsonConverter() {} + + public static Map convertToJson(Struct struct) { + Map result = new HashMap<>(); + for (Map.Entry entry : struct.getFieldsMap().entrySet()) { + result.put(entry.getKey(), convertValue(entry.getValue())); + } + return result; + } + + private static Object convertValue(Value value) { + switch (value.getKindCase()) { + case STRUCT_VALUE: + return convertToJson(value.getStructValue()); + case LIST_VALUE: + return value.getListValue().getValuesList().stream() + .map(ProtobufJsonConverter::convertValue) + .collect(Collectors.toList()); + case NUMBER_VALUE: + return value.getNumberValue(); + case STRING_VALUE: + return value.getStringValue(); + case BOOL_VALUE: + return value.getBoolValue(); + case NULL_VALUE: + return null; + default: + throw new IllegalArgumentException("Unknown Value type: " + value.getKindCase()); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/XdsInternalAttributes.java b/xds/src/main/java/io/grpc/xds/internal/XdsInternalAttributes.java new file mode 100644 index 00000000000..b05230ea30b --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/XdsInternalAttributes.java @@ -0,0 +1,27 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal; + +import io.grpc.Attributes; +import io.grpc.EquivalentAddressGroup; + +public final class XdsInternalAttributes { + /** Name associated with individual address, if available (e.g., DNS name). */ + @EquivalentAddressGroup.Attr + public static final Attributes.Key ATTR_ADDRESS_NAME = + Attributes.Key.create("io.grpc.xds.XdsAttributes.addressName"); +} diff --git a/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzConfig.java b/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzConfig.java new file mode 100644 index 00000000000..5aeb44c6e2a --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzConfig.java @@ -0,0 +1,145 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableList; +import io.grpc.Status; +import io.grpc.xds.internal.Matchers; +import io.grpc.xds.internal.grpcservice.GrpcServiceConfig; +import io.grpc.xds.internal.headermutations.HeaderMutationRulesConfig; +import java.util.Optional; + +/** + * Represents the configuration for the external authorization (ext_authz) filter. This class + * encapsulates the settings defined in the + * {@link io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz} proto, providing a + * structured, immutable representation for use within gRPC. It includes configurations for the gRPC + * service used for authorization, header mutation rules, and other filter behaviors. + */ +@AutoValue +public abstract class ExtAuthzConfig { + + /** Creates a new builder for creating {@link ExtAuthzConfig} instances. */ + public static Builder builder() { + return new AutoValue_ExtAuthzConfig.Builder().allowedHeaders(ImmutableList.of()) + .disallowedHeaders(ImmutableList.of()).statusOnError(Status.PERMISSION_DENIED) + .filterEnabled(Matchers.FractionMatcher.create(100, 100)); + } + + /** + * The gRPC service configuration for the external authorization service. This is a required + * field. + * + * @see ExtAuthz#getGrpcService() + */ + public abstract GrpcServiceConfig grpcService(); + + /** + * Changes the filter's behavior on errors from the authorization service. If {@code true}, the + * filter will accept the request even if the authorization service fails or returns an error. + * + * @see ExtAuthz#getFailureModeAllow() + */ + public abstract boolean failureModeAllow(); + + /** + * Determines if the {@code x-envoy-auth-failure-mode-allowed} header is added to the request when + * {@link #failureModeAllow()} is true. + * + * @see ExtAuthz#getFailureModeAllowHeaderAdd() + */ + public abstract boolean failureModeAllowHeaderAdd(); + + /** + * Specifies if the peer certificate is sent to the external authorization service. + * + * @see ExtAuthz#getIncludePeerCertificate() + */ + public abstract boolean includePeerCertificate(); + + /** + * The gRPC status returned to the client when the authorization server returns an error or is + * unreachable. Defaults to {@code PERMISSION_DENIED}. + * + * @see io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz#getStatusOnError() + */ + public abstract Status statusOnError(); + + /** + * Specifies whether to deny requests when the filter is disabled. Defaults to {@code false}. + * + * @see ExtAuthz#getDenyAtDisable() + */ + public abstract boolean denyAtDisable(); + + /** + * The fraction of requests that will be checked by the authorization service. Defaults to all + * requests. + * + * @see ExtAuthz#getFilterEnabled() + */ + public abstract Matchers.FractionMatcher filterEnabled(); + + /** + * Specifies which request headers are sent to the authorization service. If empty, all headers + * are sent. + * + * @see ExtAuthz#getAllowedHeaders() + */ + public abstract ImmutableList allowedHeaders(); + + /** + * Specifies which request headers are not sent to the authorization service. This overrides + * {@link #allowedHeaders()}. + * + * @see ExtAuthz#getDisallowedHeaders() + */ + public abstract ImmutableList disallowedHeaders(); + + /** + * Rules for what modifications an ext_authz server may make to request headers. + * + * @see ExtAuthz#getDecoderHeaderMutationRules() + */ + public abstract Optional decoderHeaderMutationRules(); + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder grpcService(GrpcServiceConfig grpcService); + + public abstract Builder failureModeAllow(boolean failureModeAllow); + + public abstract Builder failureModeAllowHeaderAdd(boolean failureModeAllowHeaderAdd); + + public abstract Builder includePeerCertificate(boolean includePeerCertificate); + + public abstract Builder statusOnError(Status statusOnError); + + public abstract Builder denyAtDisable(boolean denyAtDisable); + + public abstract Builder filterEnabled(Matchers.FractionMatcher filterEnabled); + + public abstract Builder allowedHeaders(Iterable allowedHeaders); + + public abstract Builder disallowedHeaders(Iterable disallowedHeaders); + + public abstract Builder decoderHeaderMutationRules(HeaderMutationRulesConfig rules); + + public abstract ExtAuthzConfig build(); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzParseException.java b/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzParseException.java new file mode 100644 index 00000000000..78edea5c305 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzParseException.java @@ -0,0 +1,34 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +/** + * A custom exception for signaling errors during the parsing of external authorization + * (ext_authz) configurations. + */ +public class ExtAuthzParseException extends Exception { + + private static final long serialVersionUID = 0L; + + public ExtAuthzParseException(String message) { + super(message); + } + + public ExtAuthzParseException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/grpcservice/GrpcServiceConfig.java b/xds/src/main/java/io/grpc/xds/internal/grpcservice/GrpcServiceConfig.java new file mode 100644 index 00000000000..cefc235e9eb --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/grpcservice/GrpcServiceConfig.java @@ -0,0 +1,87 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.grpcservice; + +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableList; +import io.grpc.CallCredentials; +import io.grpc.xds.client.ConfiguredChannelCredentials; +import java.time.Duration; +import java.util.Optional; + + +/** + * This class encapsulates the configuration for a gRPC service, including target URI, credentials, + * and other settings. This class is immutable and uses the AutoValue library for its + * implementation. + */ +@AutoValue +public abstract class GrpcServiceConfig { + + public static Builder builder() { + return new AutoValue_GrpcServiceConfig.Builder(); + } + + public abstract GoogleGrpcConfig googleGrpc(); + + public abstract Optional timeout(); + + public abstract ImmutableList initialMetadata(); + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder googleGrpc(GoogleGrpcConfig googleGrpc); + + public abstract Builder timeout(Duration timeout); + + public abstract Builder initialMetadata(ImmutableList initialMetadata); + + public abstract GrpcServiceConfig build(); + } + + /** + * This class encapsulates settings specific to Google's gRPC implementation, such as target URI + * and credentials. + */ + @AutoValue + public abstract static class GoogleGrpcConfig { + + public static Builder builder() { + return new AutoValue_GrpcServiceConfig_GoogleGrpcConfig.Builder(); + } + + public abstract String target(); + + public abstract ConfiguredChannelCredentials configuredChannelCredentials(); + + public abstract Optional callCredentials(); + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder target(String target); + + public abstract Builder configuredChannelCredentials( + ConfiguredChannelCredentials channelCredentials); + + public abstract Builder callCredentials(CallCredentials callCredentials); + + public abstract GoogleGrpcConfig build(); + } + } + + +} diff --git a/xds/src/main/java/io/grpc/xds/internal/grpcservice/GrpcServiceParseException.java b/xds/src/main/java/io/grpc/xds/internal/grpcservice/GrpcServiceParseException.java new file mode 100644 index 00000000000..319ad3d07e3 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/grpcservice/GrpcServiceParseException.java @@ -0,0 +1,33 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.grpcservice; + +/** + * Exception thrown when there is an error parsing the gRPC service config. + */ +public class GrpcServiceParseException extends Exception { + + private static final long serialVersionUID = 1L; + + public GrpcServiceParseException(String message) { + super(message); + } + + public GrpcServiceParseException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/grpcservice/HeaderValue.java b/xds/src/main/java/io/grpc/xds/internal/grpcservice/HeaderValue.java new file mode 100644 index 00000000000..1b7bb283744 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/grpcservice/HeaderValue.java @@ -0,0 +1,44 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.grpcservice; + +import com.google.auto.value.AutoValue; +import com.google.protobuf.ByteString; +import java.util.Optional; + +/** + * Represents a header to be mutated or added as part of xDS configuration. + * Avoids direct dependency on Envoy's proto objects while providing an immutable representation. + */ +@AutoValue +public abstract class HeaderValue { + + public static HeaderValue create(String key, String value) { + return new AutoValue_HeaderValue(key, Optional.of(value), Optional.empty()); + } + + public static HeaderValue create(String key, ByteString rawValue) { + return new AutoValue_HeaderValue(key, Optional.empty(), Optional.of(rawValue)); + } + + + public abstract String key(); + + public abstract Optional value(); + + public abstract Optional rawValue(); +} diff --git a/xds/src/main/java/io/grpc/xds/internal/grpcservice/HeaderValueValidationUtils.java b/xds/src/main/java/io/grpc/xds/internal/grpcservice/HeaderValueValidationUtils.java new file mode 100644 index 00000000000..ff0df11bdc5 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/grpcservice/HeaderValueValidationUtils.java @@ -0,0 +1,67 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.grpcservice; + +import java.util.Locale; + +/** + * Utility class for validating HTTP headers. + */ +public final class HeaderValueValidationUtils { + public static final int MAX_HEADER_LENGTH = 16384; + + private HeaderValueValidationUtils() {} + + /** + * Returns true if the header key is disallowed for mutations or validation. + * + * @param key The header key (e.g., "content-type") + */ + public static boolean isDisallowed(String key) { + if (key.isEmpty() || key.length() > MAX_HEADER_LENGTH) { + return true; + } + if (!key.equals(key.toLowerCase(Locale.ROOT))) { + return true; + } + if (key.startsWith("grpc-")) { + return true; + } + if (key.startsWith(":") || key.equals("host")) { + return true; + } + return false; + } + + /** + * Returns true if the header value is disallowed. + * + * @param header The HeaderValue containing key and values + */ + public static boolean isDisallowed(HeaderValue header) { + if (isDisallowed(header.key())) { + return true; + } + if (header.value().isPresent() && header.value().get().length() > MAX_HEADER_LENGTH) { + return true; + } + if (header.rawValue().isPresent() && header.rawValue().get().size() > MAX_HEADER_LENGTH) { + return true; + } + return false; + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationDisallowedException.java b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationDisallowedException.java new file mode 100644 index 00000000000..b8d4eb582fb --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationDisallowedException.java @@ -0,0 +1,32 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import io.grpc.Status; +import io.grpc.StatusException; + +/** + * Exception thrown when a header mutation is disallowed. + */ +public final class HeaderMutationDisallowedException extends StatusException { + + private static final long serialVersionUID = 1L; + + public HeaderMutationDisallowedException(String message) { + super(Status.INTERNAL.withDescription(message)); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationFilter.java b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationFilter.java new file mode 100644 index 00000000000..35cab17d928 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationFilter.java @@ -0,0 +1,114 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import com.google.common.collect.ImmutableList; +import io.grpc.xds.internal.grpcservice.HeaderValueValidationUtils; +import java.util.Collection; +import java.util.Optional; +import java.util.function.Predicate; + +/** + * The HeaderMutationFilter class is responsible for filtering header mutations based on a given set + * of rules. + */ +public class HeaderMutationFilter { + private final Optional mutationRules; + + + + public HeaderMutationFilter(Optional mutationRules) { + this.mutationRules = mutationRules; + } + + /** + * Filters the given header mutations based on the configured rules and returns the allowed + * mutations. + * + * @param mutations The header mutations to filter + * @return The allowed header mutations. + * @throws HeaderMutationDisallowedException if a disallowed mutation is encountered and the rules + * specify that this should be an error. + */ + public HeaderMutations filter(HeaderMutations mutations) + throws HeaderMutationDisallowedException { + ImmutableList allowedHeaders = + filterCollection(mutations.headers(), this::isDisallowed, this::isHeaderMutationAllowed); + ImmutableList allowedHeadersToRemove = + filterCollection(mutations.headersToRemove(), this::isDisallowed, + this::isHeaderMutationAllowed); + return HeaderMutations.create(allowedHeaders, allowedHeadersToRemove); + } + + /** + * A generic helper to filter a collection based on a predicate. + */ + private ImmutableList filterCollection(Collection items, + Predicate isIgnoredPredicate, Predicate isAllowedPredicate) + throws HeaderMutationDisallowedException { + ImmutableList.Builder allowed = ImmutableList.builder(); + for (T item : items) { + boolean isIgnored = isIgnoredPredicate.test(item); + boolean isAllowed = isAllowedPredicate.test(item); + + // TODO(sauravzg): The specification is ambiguous regarding whether system headers + // should be silently ignored or trigger an error when disallowIsError is enabled. + // We default to triggering errors matching Envoy's implementation. + // Ref: https://github.com/grpc/proposal/pull/481#discussion_r3124453674 + if (!isIgnored && isAllowed) { + allowed.add(item); + } else if (disallowIsError()) { + throw new HeaderMutationDisallowedException("Header mutation disallowed"); + } + } + return allowed.build(); + } + + private boolean isDisallowed(String key) { + return HeaderValueValidationUtils.isDisallowed(key); + } + + private boolean isDisallowed(HeaderValueOption option) { + return HeaderValueValidationUtils.isDisallowed(option.header()); + } + + private boolean isHeaderMutationAllowed(HeaderValueOption option) { + return isHeaderMutationAllowed(option.header().key()); + } + + private boolean isHeaderMutationAllowed(String headerName) { + return mutationRules.map(rules -> isHeaderMutationAllowed(headerName, rules)) + .orElse(true); + } + + private boolean isHeaderMutationAllowed(String headerName, + HeaderMutationRulesConfig rules) { + if (rules.disallowExpression().isPresent() + && rules.disallowExpression().get().matcher(headerName).matches()) { + return false; + } + if (rules.allowExpression().isPresent() + && rules.allowExpression().get().matcher(headerName).matches()) { + return true; + } + return !rules.disallowAll(); + } + + private boolean disallowIsError() { + return mutationRules.map(HeaderMutationRulesConfig::disallowIsError).orElse(false); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesConfig.java b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesConfig.java new file mode 100644 index 00000000000..b16ec7948ed --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesConfig.java @@ -0,0 +1,77 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import com.google.auto.value.AutoValue; +import com.google.re2j.Pattern; +import io.envoyproxy.envoy.config.common.mutation_rules.v3.HeaderMutationRules; +import java.util.Optional; + +/** + * Represents the configuration for header mutation rules, as defined in the + * {@link io.envoyproxy.envoy.config.common.mutation_rules.v3.HeaderMutationRules} proto. + */ +@AutoValue +public abstract class HeaderMutationRulesConfig { + /** Creates a new builder for creating {@link HeaderMutationRulesConfig} instances. */ + public static Builder builder() { + return new AutoValue_HeaderMutationRulesConfig.Builder().disallowAll(false) + .disallowIsError(false); + } + + /** + * If set, allows any header that matches this regular expression. + * + * @see HeaderMutationRules#getAllowExpression() + */ + public abstract Optional allowExpression(); + + /** + * If set, disallows any header that matches this regular expression. + * + * @see HeaderMutationRules#getDisallowExpression() + */ + public abstract Optional disallowExpression(); + + /** + * If true, disallows all header mutations. + * + * @see HeaderMutationRules#getDisallowAll() + */ + public abstract boolean disallowAll(); + + /** + * If true, a disallowed header mutation will result in an error instead of being ignored. + * + * @see HeaderMutationRules#getDisallowIsError() + */ + public abstract boolean disallowIsError(); + + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder allowExpression(Pattern matcher); + + public abstract Builder disallowExpression(Pattern matcher); + + public abstract Builder disallowAll(boolean disallowAll); + + public abstract Builder disallowIsError(boolean disallowIsError); + + public abstract HeaderMutationRulesConfig build(); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesParseException.java b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesParseException.java new file mode 100644 index 00000000000..3782e84a54b --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesParseException.java @@ -0,0 +1,32 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +/** + * Exception thrown when parsing header mutation rules fails. + */ +public final class HeaderMutationRulesParseException extends Exception { + private static final long serialVersionUID = 1L; + + public HeaderMutationRulesParseException(String message) { + super(message); + } + + public HeaderMutationRulesParseException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesParser.java b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesParser.java new file mode 100644 index 00000000000..f6bb2ec508d --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesParser.java @@ -0,0 +1,55 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import com.google.re2j.Pattern; +import com.google.re2j.PatternSyntaxException; +import io.envoyproxy.envoy.config.common.mutation_rules.v3.HeaderMutationRules; + +/** + * Parser for {@link io.envoyproxy.envoy.config.common.mutation_rules.v3.HeaderMutationRules}. + */ +public final class HeaderMutationRulesParser { + + private HeaderMutationRulesParser() {} + + public static HeaderMutationRulesConfig parse(HeaderMutationRules proto) + throws HeaderMutationRulesParseException { + HeaderMutationRulesConfig.Builder builder = HeaderMutationRulesConfig.builder(); + builder.disallowAll(proto.getDisallowAll().getValue()); + builder.disallowIsError(proto.getDisallowIsError().getValue()); + if (proto.hasAllowExpression()) { + builder.allowExpression( + parseRegex(proto.getAllowExpression().getRegex(), "allow_expression")); + } + if (proto.hasDisallowExpression()) { + builder.disallowExpression( + parseRegex(proto.getDisallowExpression().getRegex(), "disallow_expression")); + } + return builder.build(); + } + + private static Pattern parseRegex(String regex, String fieldName) + throws HeaderMutationRulesParseException { + try { + return Pattern.compile(regex); + } catch (PatternSyntaxException e) { + throw new HeaderMutationRulesParseException( + "Invalid regex pattern for " + fieldName + ": " + e.getMessage(), e); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutations.java b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutations.java new file mode 100644 index 00000000000..a456413c899 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutations.java @@ -0,0 +1,34 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableList; + +/** A collection of header mutations. */ +@AutoValue +public abstract class HeaderMutations { + + public static HeaderMutations create(ImmutableList headers, + ImmutableList headersToRemove) { + return new AutoValue_HeaderMutations(headers, headersToRemove); + } + + public abstract ImmutableList headers(); + + public abstract ImmutableList headersToRemove(); +} diff --git a/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutator.java b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutator.java new file mode 100644 index 00000000000..e6cdc126f22 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutator.java @@ -0,0 +1,123 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + + +import io.grpc.Metadata; +import io.grpc.xds.internal.grpcservice.HeaderValue; +import io.grpc.xds.internal.headermutations.HeaderValueOption.HeaderAppendAction; +import java.util.logging.Logger; + +/** + * The HeaderMutator provides methods to apply header mutations to a given set of headers based on a + * given set of rules. + */ +public class HeaderMutator { + + private static final Logger logger = Logger.getLogger(HeaderMutator.class.getName()); + + /** + * Creates a new instance of {@code HeaderMutator}. + */ + public static HeaderMutator create() { + return new HeaderMutator(); + } + + HeaderMutator() {} + + /** + * Applies the given header mutations to the provided metadata headers. + * + * @param mutations The header mutations to apply. + * @param headers The metadata headers to which the mutations will be applied. + */ + public void applyMutations(final HeaderMutations mutations, Metadata headers) { + // TODO(sauravzg): The specification is not clear on order of header removals and additions. + // in case of conflicts. Copying the order from Envoy here, which does removals at the end. + applyHeaderUpdates(mutations.headers(), headers); + for (String headerToRemove : mutations.headersToRemove()) { + Metadata.Key key = headerToRemove.endsWith(Metadata.BINARY_HEADER_SUFFIX) + ? Metadata.Key.of(headerToRemove, Metadata.BINARY_BYTE_MARSHALLER) + : Metadata.Key.of(headerToRemove, Metadata.ASCII_STRING_MARSHALLER); + headers.discardAll(key); + } + } + + private void applyHeaderUpdates(final Iterable headerOptions, + Metadata headers) { + for (HeaderValueOption headerOption : headerOptions) { + updateHeader(headerOption, headers); + } + } + + private void updateHeader(final HeaderValueOption option, Metadata mutableHeaders) { + HeaderValue header = option.header(); + HeaderAppendAction action = option.appendAction(); + boolean keepEmptyValue = option.keepEmptyValue(); + + if (header.key().endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + if (header.rawValue().isPresent()) { + byte[] value = header.rawValue().get().toByteArray(); + if (value.length > 0 || keepEmptyValue) { + updateHeader(action, Metadata.Key.of(header.key(), Metadata.BINARY_BYTE_MARSHALLER), + value, mutableHeaders); + } + } else { + logger.fine("Missing binary rawValue for header: " + header.key()); + } + } else { + if (header.value().isPresent()) { + String value = header.value().get(); + if (!value.isEmpty() || keepEmptyValue) { + updateHeader(action, Metadata.Key.of(header.key(), Metadata.ASCII_STRING_MARSHALLER), + value, mutableHeaders); + } + } else { + logger.fine("Missing value for header: " + header.key()); + } + } + } + + private void updateHeader(final HeaderAppendAction action, final Metadata.Key key, + final T value, Metadata mutableHeaders) { + switch (action) { + case APPEND_IF_EXISTS_OR_ADD: + mutableHeaders.put(key, value); + break; + case ADD_IF_ABSENT: + if (!mutableHeaders.containsKey(key)) { + mutableHeaders.put(key, value); + } + break; + case OVERWRITE_IF_EXISTS_OR_ADD: + mutableHeaders.discardAll(key); + mutableHeaders.put(key, value); + break; + case OVERWRITE_IF_EXISTS: + if (mutableHeaders.containsKey(key)) { + mutableHeaders.discardAll(key); + mutableHeaders.put(key, value); + } + break; + + default: + // Should be unreachable unless there's a proto schema mismatch. + logger.fine("Unknown HeaderAppendAction: " + action); + } + } +} + diff --git a/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderValueOption.java b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderValueOption.java new file mode 100644 index 00000000000..6cb96da864d --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderValueOption.java @@ -0,0 +1,50 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import com.google.auto.value.AutoValue; +import io.grpc.xds.internal.grpcservice.HeaderValue; + +/** + * Represents a header option to be appended or mutated as part of xDS configuration. + * Avoids direct dependency on Envoy's proto objects. + */ +@AutoValue +public abstract class HeaderValueOption { + + public static HeaderValueOption create( + HeaderValue header, HeaderAppendAction appendAction, boolean keepEmptyValue) { + return new AutoValue_HeaderValueOption(header, appendAction, keepEmptyValue); + } + + public abstract HeaderValue header(); + + public abstract HeaderAppendAction appendAction(); + + public abstract boolean keepEmptyValue(); + + /** + * Defines the action to take when appending headers. + * Mirrors io.envoyproxy.envoy.config.core.v3.HeaderValueOption.HeaderAppendAction. + */ + public enum HeaderAppendAction { + APPEND_IF_EXISTS_OR_ADD, + ADD_IF_ABSENT, + OVERWRITE_IF_EXISTS_OR_ADD, + OVERWRITE_IF_EXISTS + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/security/ClientSslContextProviderFactory.java b/xds/src/main/java/io/grpc/xds/internal/security/ClientSslContextProviderFactory.java index 90202b4820a..37d289c1c47 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/ClientSslContextProviderFactory.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/ClientSslContextProviderFactory.java @@ -16,8 +16,6 @@ package io.grpc.xds.internal.security; -import static com.google.common.base.Preconditions.checkNotNull; - import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.client.Bootstrapper.BootstrapInfo; import io.grpc.xds.internal.security.ReferenceCountingMap.ValueFactory; @@ -44,17 +42,9 @@ final class ClientSslContextProviderFactory /** Creates an SslContextProvider from the given UpstreamTlsContext. */ @Override public SslContextProvider create(UpstreamTlsContext upstreamTlsContext) { - checkNotNull(upstreamTlsContext, "upstreamTlsContext"); - checkNotNull( - upstreamTlsContext.getCommonTlsContext(), - "upstreamTlsContext should have CommonTlsContext"); - if (CommonTlsContextUtil.hasCertProviderInstance( - upstreamTlsContext.getCommonTlsContext())) { - return certProviderClientSslContextProviderFactory.getProvider( - upstreamTlsContext, - bootstrapInfo.node().toEnvoyProtoNode(), - bootstrapInfo.certProviders()); - } - throw new UnsupportedOperationException("Unsupported configurations in UpstreamTlsContext!"); + return certProviderClientSslContextProviderFactory.getProvider( + upstreamTlsContext, + bootstrapInfo.node().toEnvoyProtoNode(), + bootstrapInfo.certProviders()); } } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/CommonTlsContextUtil.java b/xds/src/main/java/io/grpc/xds/internal/security/CommonTlsContextUtil.java index d3003b4a792..bd8a423e683 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/CommonTlsContextUtil.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/CommonTlsContextUtil.java @@ -18,33 +18,21 @@ import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; -import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext; /** Class for utility functions for {@link CommonTlsContext}. */ public final class CommonTlsContextUtil { private CommonTlsContextUtil() {} - static boolean hasCertProviderInstance(CommonTlsContext commonTlsContext) { + public static boolean hasCertProviderInstance(CommonTlsContext commonTlsContext) { if (commonTlsContext == null) { return false; } - return hasIdentityCertificateProviderInstance(commonTlsContext) - || hasCertProviderValidationContext(commonTlsContext); - } - - private static boolean hasCertProviderValidationContext(CommonTlsContext commonTlsContext) { - if (commonTlsContext.hasCombinedValidationContext()) { - CombinedCertificateValidationContext combinedCertificateValidationContext = - commonTlsContext.getCombinedValidationContext(); - return combinedCertificateValidationContext.hasValidationContextCertificateProviderInstance(); - } - return hasValidationProviderInstance(commonTlsContext); - } - - private static boolean hasIdentityCertificateProviderInstance(CommonTlsContext commonTlsContext) { + @SuppressWarnings("deprecation") + boolean hasDeprecatedField = commonTlsContext.hasTlsCertificateCertificateProviderInstance(); return commonTlsContext.hasTlsCertificateProviderInstance() - || commonTlsContext.hasTlsCertificateCertificateProviderInstance(); + || hasDeprecatedField + || hasValidationProviderInstance(commonTlsContext); } private static boolean hasValidationProviderInstance(CommonTlsContext commonTlsContext) { @@ -52,7 +40,19 @@ private static boolean hasValidationProviderInstance(CommonTlsContext commonTlsC .hasCaCertificateProviderInstance()) { return true; } - return commonTlsContext.hasValidationContextCertificateProviderInstance(); + if (commonTlsContext.hasCombinedValidationContext()) { + CommonTlsContext.CombinedCertificateValidationContext combined = + commonTlsContext.getCombinedValidationContext(); + if (combined.hasDefaultValidationContext() + && combined.getDefaultValidationContext().hasCaCertificateProviderInstance()) { + return true; + } + // Check deprecated field (field 4) in CombinedValidationContext + @SuppressWarnings("deprecation") + boolean hasDeprecatedField = combined.hasValidationContextCertificateProviderInstance(); + return hasDeprecatedField; + } + return false; } /** @@ -65,4 +65,15 @@ public static CommonTlsContext.CertificateProviderInstance convert( .setInstanceName(pluginInstance.getInstanceName()) .setCertificateName(pluginInstance.getCertificateName()).build(); } + + public static boolean isUsingSystemRootCerts(CommonTlsContext commonTlsContext) { + if (commonTlsContext.hasCombinedValidationContext()) { + return commonTlsContext.getCombinedValidationContext().getDefaultValidationContext() + .hasSystemRootCerts(); + } + if (commonTlsContext.hasValidationContext()) { + return commonTlsContext.getValidationContext().hasSystemRootCerts(); + } + return false; + } } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/DynamicSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/DynamicSslContextProvider.java index 6bf66d022ff..e7b27cd644a 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/DynamicSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/DynamicSslContextProvider.java @@ -30,9 +30,11 @@ import java.io.IOException; import java.security.cert.CertStoreException; import java.security.cert.CertificateException; +import java.util.AbstractMap; import java.util.ArrayList; import java.util.List; import javax.annotation.Nullable; +import javax.net.ssl.X509TrustManager; /** Base class for dynamic {@link SslContextProvider}s. */ @Internal @@ -40,7 +42,8 @@ public abstract class DynamicSslContextProvider extends SslContextProvider { protected final List pendingCallbacks = new ArrayList<>(); @Nullable protected final CertificateValidationContext staticCertificateValidationContext; - @Nullable protected SslContext sslContext; + @Nullable protected AbstractMap.SimpleImmutableEntry + sslContextAndTrustManager; protected DynamicSslContextProvider( BaseTlsContext tlsContext, CertificateValidationContext staticCertValidationContext) { @@ -49,15 +52,17 @@ protected DynamicSslContextProvider( } @Nullable - public SslContext getSslContext() { - return sslContext; + public AbstractMap.SimpleImmutableEntry + getSslContextAndTrustManager() { + return sslContextAndTrustManager; } protected abstract CertificateValidationContext generateCertificateValidationContext(); /** Gets a server or client side SslContextBuilder. */ - protected abstract SslContextBuilder getSslContextBuilder( - CertificateValidationContext certificateValidationContext) + protected abstract AbstractMap.SimpleImmutableEntry + getSslContextBuilderAndTrustManager( + CertificateValidationContext certificateValidationContext) throws CertificateException, IOException, CertStoreException; // this gets called only when requested secrets are ready... @@ -65,7 +70,8 @@ protected final void updateSslContext() { try { CertificateValidationContext localCertValidationContext = generateCertificateValidationContext(); - SslContextBuilder sslContextBuilder = getSslContextBuilder(localCertValidationContext); + AbstractMap.SimpleImmutableEntry sslContextBuilderAndTm = + getSslContextBuilderAndTrustManager(localCertValidationContext); CommonTlsContext commonTlsContext = getCommonTlsContext(); if (commonTlsContext != null && commonTlsContext.getAlpnProtocolsCount() > 0) { List alpnList = commonTlsContext.getAlpnProtocolsList(); @@ -75,16 +81,18 @@ protected final void updateSslContext() { ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE, ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT, alpnList); - sslContextBuilder.applicationProtocolConfig(apn); + sslContextBuilderAndTm.getKey().applicationProtocolConfig(apn); } List pendingCallbacksCopy; - SslContext sslContextCopy; + AbstractMap.SimpleImmutableEntry + sslContextAndExtendedX09TrustManagerCopy; synchronized (pendingCallbacks) { - sslContext = sslContextBuilder.build(); - sslContextCopy = sslContext; + sslContextAndTrustManager = new AbstractMap.SimpleImmutableEntry<>( + sslContextBuilderAndTm.getKey().build(), sslContextBuilderAndTm.getValue()); + sslContextAndExtendedX09TrustManagerCopy = sslContextAndTrustManager; pendingCallbacksCopy = clonePendingCallbacksAndClear(); } - makePendingCallbacks(sslContextCopy, pendingCallbacksCopy); + makePendingCallbacks(sslContextAndExtendedX09TrustManagerCopy, pendingCallbacksCopy); } catch (Exception e) { onError(Status.fromThrowable(e)); throw new RuntimeException(e); @@ -92,12 +100,13 @@ protected final void updateSslContext() { } protected final void callPerformCallback( - Callback callback, final SslContext sslContextCopy) { + Callback callback, + final AbstractMap.SimpleImmutableEntry sslContextAndTmCopy) { performCallback( new SslContextGetter() { @Override - public SslContext get() { - return sslContextCopy; + public AbstractMap.SimpleImmutableEntry get() { + return sslContextAndTmCopy; } }, callback @@ -108,10 +117,10 @@ public SslContext get() { public final void addCallback(Callback callback) { checkNotNull(callback, "callback"); // if there is a computed sslContext just send it - SslContext sslContextCopy = null; + AbstractMap.SimpleImmutableEntry sslContextCopy = null; synchronized (pendingCallbacks) { - if (sslContext != null) { - sslContextCopy = sslContext; + if (sslContextAndTrustManager != null) { + sslContextCopy = sslContextAndTrustManager; } else { pendingCallbacks.add(callback); } @@ -122,9 +131,11 @@ public final void addCallback(Callback callback) { } private final void makePendingCallbacks( - SslContext sslContextCopy, List pendingCallbacksCopy) { + AbstractMap.SimpleImmutableEntry + sslContextAndExtendedX509TrustManagerCopy, + List pendingCallbacksCopy) { for (Callback callback : pendingCallbacksCopy) { - callPerformCallback(callback, sslContextCopy); + callPerformCallback(callback, sslContextAndExtendedX509TrustManagerCopy); } } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/ReferenceCountingMap.java b/xds/src/main/java/io/grpc/xds/internal/security/ReferenceCountingMap.java index b7f56492fa5..08b8f6a325b 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/ReferenceCountingMap.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/ReferenceCountingMap.java @@ -20,9 +20,9 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import com.google.errorprone.annotations.CheckReturnValue; import java.util.HashMap; import java.util.Map; -import javax.annotation.CheckReturnValue; import javax.annotation.concurrent.ThreadSafe; /** diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java index 00659e53de1..a93299de11c 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java @@ -19,7 +19,9 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Strings; import io.grpc.Attributes; +import io.grpc.Grpc; import io.grpc.internal.GrpcUtil; import io.grpc.internal.ObjectPool; import io.grpc.netty.GrpcHttp2ConnectionHandler; @@ -28,7 +30,10 @@ import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; import io.grpc.netty.InternalProtocolNegotiators; import io.grpc.netty.ProtocolNegotiationEvent; -import io.grpc.xds.InternalXdsAttributes; +import io.grpc.xds.EnvoyServerProtoData; +import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; +import io.grpc.xds.internal.XdsInternalAttributes; +import io.grpc.xds.internal.security.trust.CertificateUtils; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerAdapter; import io.netty.channel.ChannelHandlerContext; @@ -36,12 +41,14 @@ import io.netty.handler.ssl.SslContext; import io.netty.util.AsciiString; import java.security.cert.CertStoreException; +import java.util.AbstractMap; import java.util.ArrayList; import java.util.List; import java.util.concurrent.Executor; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; +import javax.net.ssl.X509TrustManager; /** * Provides client and server side gRPC {@link ProtocolNegotiator}s to provide the SSL @@ -60,8 +67,14 @@ private SecurityProtocolNegotiators() { private static final AsciiString SCHEME = AsciiString.of("http"); public static final Attributes.Key - ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER = - Attributes.Key.create("io.grpc.xds.internal.security.server.sslContextProviderSupplier"); + ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER = + Attributes.Key.create("io.grpc.xds.internal.security.server.sslContextProviderSupplier"); + + /** Attribute key for SslContextProviderSupplier (used from client) for a subchannel. */ + @Grpc.TransportAttr + public static final Attributes.Key + ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER = + Attributes.Key.create("io.grpc.xds.internal.security.SslContextProviderSupplier"); /** * Returns a {@link InternalProtocolNegotiator.ClientFactory}. @@ -130,14 +143,14 @@ public AsciiString scheme() { public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { // check if SslContextProviderSupplier was passed via attributes SslContextProviderSupplier localSslContextProviderSupplier = - grpcHandler.getEagAttributes().get( - InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER); + grpcHandler.getEagAttributes().get(ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER); if (localSslContextProviderSupplier == null) { checkNotNull( fallbackProtocolNegotiator, "No TLS config and no fallbackProtocolNegotiator!"); return fallbackProtocolNegotiator.newHandler(grpcHandler); } - return new ClientSecurityHandler(grpcHandler, localSslContextProviderSupplier); + return new ClientSecurityHandler(grpcHandler, localSslContextProviderSupplier, + grpcHandler.getEagAttributes().get(XdsInternalAttributes.ATTR_ADDRESS_NAME)); } @Override @@ -180,10 +193,13 @@ static final class ClientSecurityHandler extends InternalProtocolNegotiators.ProtocolNegotiationHandler { private final GrpcHttp2ConnectionHandler grpcHandler; private final SslContextProviderSupplier sslContextProviderSupplier; + private final String sni; + private final boolean autoSniSanValidationDoesNotApply; ClientSecurityHandler( GrpcHttp2ConnectionHandler grpcHandler, - SslContextProviderSupplier sslContextProviderSupplier) { + SslContextProviderSupplier sslContextProviderSupplier, + String endpointHostname) { super( // superclass (InternalProtocolNegotiators.ProtocolNegotiationHandler) expects 'next' // handler but we don't have a next handler _yet_. So we "disable" superclass's behavior @@ -197,6 +213,26 @@ public void handlerAdded(ChannelHandlerContext ctx) throws Exception { checkNotNull(grpcHandler, "grpcHandler"); this.grpcHandler = grpcHandler; this.sslContextProviderSupplier = sslContextProviderSupplier; + EnvoyServerProtoData.BaseTlsContext tlsContext = sslContextProviderSupplier.getTlsContext(); + UpstreamTlsContext upstreamTlsContext = ((UpstreamTlsContext) tlsContext); + + String sniToUse = upstreamTlsContext.getAutoHostSni() + && !Strings.isNullOrEmpty(endpointHostname) + ? endpointHostname : upstreamTlsContext.getSni(); + if (sniToUse.isEmpty()) { + if (CertificateUtils.useChannelAuthorityIfNoSniApplicable) { + sniToUse = grpcHandler.getAuthority(); + } + autoSniSanValidationDoesNotApply = true; + } else { + autoSniSanValidationDoesNotApply = false; + } + sni = sniToUse; + } + + @VisibleForTesting + String getSni() { + return sni; } @Override @@ -208,7 +244,8 @@ protected void handlerAdded0(final ChannelHandlerContext ctx) { new SslContextProvider.Callback(ctx.executor()) { @Override - public void updateSslContext(SslContext sslContext) { + public void updateSslContextAndExtendedX509TrustManager( + AbstractMap.SimpleImmutableEntry sslContextAndTm) { if (ctx.isRemoved()) { return; } @@ -217,7 +254,9 @@ public void updateSslContext(SslContext sslContext) { "ClientSecurityHandler.updateSslContext authority={0}, ctx.name={1}", new Object[]{grpcHandler.getAuthority(), ctx.name()}); ChannelHandler handler = - InternalProtocolNegotiators.tls(sslContext).newHandler(grpcHandler); + InternalProtocolNegotiators.tls( + sslContextAndTm.getKey(), sni, sslContextAndTm.getValue()) + .newHandler(grpcHandler); // Delegate rest of handshake to TLS handler ctx.pipeline().addAfter(ctx.name(), null, handler); @@ -229,8 +268,8 @@ public void updateSslContext(SslContext sslContext) { public void onException(Throwable throwable) { ctx.fireExceptionCaught(throwable); } - } - ); + }, + autoSniSanValidationDoesNotApply); } @Override @@ -351,9 +390,10 @@ protected void handlerAdded0(final ChannelHandlerContext ctx) { new SslContextProvider.Callback(ctx.executor()) { @Override - public void updateSslContext(SslContext sslContext) { - ChannelHandler handler = - InternalProtocolNegotiators.serverTls(sslContext).newHandler(grpcHandler); + public void updateSslContextAndExtendedX509TrustManager( + AbstractMap.SimpleImmutableEntry sslContextAndTm) { + ChannelHandler handler = InternalProtocolNegotiators.serverTls( + sslContextAndTm.getKey()).newHandler(grpcHandler); // Delegate rest of handshake to TLS handler if (!ctx.isRemoved()) { @@ -367,8 +407,8 @@ public void updateSslContext(SslContext sslContext) { public void onException(Throwable throwable) { ctx.fireExceptionCaught(throwable); } - } - ); + }, + false); } } } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java index a0c4ed37dfb..a5d14f72dc5 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java @@ -32,7 +32,9 @@ import java.io.IOException; import java.security.cert.CertStoreException; import java.security.cert.CertificateException; +import java.util.AbstractMap; import java.util.concurrent.Executor; +import javax.net.ssl.X509TrustManager; /** * A SslContextProvider is a "container" or provider of SslContext. This is used by gRPC-xds to @@ -57,7 +59,8 @@ protected Callback(Executor executor) { } /** Informs callee of new/updated SslContext. */ - @VisibleForTesting public abstract void updateSslContext(SslContext sslContext); + @VisibleForTesting public abstract void updateSslContextAndExtendedX509TrustManager( + AbstractMap.SimpleImmutableEntry sslContext); /** Informs callee of an exception that was generated. */ @VisibleForTesting protected abstract void onException(Throwable throwable); @@ -119,8 +122,9 @@ protected final void performCallback( @Override public void run() { try { - SslContext sslContext = sslContextGetter.get(); - callback.updateSslContext(sslContext); + AbstractMap.SimpleImmutableEntry sslContextAndTm = + sslContextGetter.get(); + callback.updateSslContextAndExtendedX509TrustManager(sslContextAndTm); } catch (Throwable e) { callback.onException(e); } @@ -130,6 +134,6 @@ public void run() { /** Allows implementations to compute or get SslContext. */ protected interface SslContextGetter { - SslContext get() throws Exception; + AbstractMap.SimpleImmutableEntry get() throws Exception; } } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java index 5f629273179..94fc423c202 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java @@ -25,7 +25,9 @@ import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.TlsContextManager; import io.netty.handler.ssl.SslContext; +import java.util.AbstractMap; import java.util.Objects; +import javax.net.ssl.X509TrustManager; /** * Enables Client or server side to initialize this object with the received {@link BaseTlsContext} @@ -52,22 +54,24 @@ public BaseTlsContext getTlsContext() { } /** Updates SslContext via the passed callback. */ - public synchronized void updateSslContext(final SslContextProvider.Callback callback) { + public synchronized void updateSslContext( + final SslContextProvider.Callback callback, boolean autoSniSanValidationDoesNotApply) { checkNotNull(callback, "callback"); try { if (!shutdown) { if (sslContextProvider == null) { - sslContextProvider = getSslContextProvider(); + sslContextProvider = getSslContextProvider(autoSniSanValidationDoesNotApply); } } // we want to increment the ref-count so call findOrCreate again... - final SslContextProvider toRelease = getSslContextProvider(); + final SslContextProvider toRelease = getSslContextProvider(autoSniSanValidationDoesNotApply); toRelease.addCallback( new SslContextProvider.Callback(callback.getExecutor()) { @Override - public void updateSslContext(SslContext sslContext) { - callback.updateSslContext(sslContext); + public void updateSslContextAndExtendedX509TrustManager( + AbstractMap.SimpleImmutableEntry sslContextAndTm) { + callback.updateSslContextAndExtendedX509TrustManager(sslContextAndTm); releaseSslContextProvider(toRelease); } @@ -95,10 +99,20 @@ private void releaseSslContextProvider(SslContextProvider toRelease) { } } - private SslContextProvider getSslContextProvider() { - return tlsContext instanceof UpstreamTlsContext - ? tlsContextManager.findOrCreateClientSslContextProvider((UpstreamTlsContext) tlsContext) - : tlsContextManager.findOrCreateServerSslContextProvider((DownstreamTlsContext) tlsContext); + private SslContextProvider getSslContextProvider(boolean autoSniSanValidationDoesNotApply) { + if (tlsContext instanceof UpstreamTlsContext) { + UpstreamTlsContext upstreamTlsContext = (UpstreamTlsContext) tlsContext; + if (autoSniSanValidationDoesNotApply && upstreamTlsContext.getAutoSniSanValidation()) { + upstreamTlsContext = new UpstreamTlsContext( + upstreamTlsContext.getCommonTlsContext(), + upstreamTlsContext.getSni(), + upstreamTlsContext.getAutoHostSni(), + false); + } + return tlsContextManager.findOrCreateClientSslContextProvider(upstreamTlsContext); + } + return tlsContextManager.findOrCreateServerSslContextProvider( + (DownstreamTlsContext) tlsContext); } @VisibleForTesting public boolean isShutdown() { diff --git a/xds/src/main/java/io/grpc/xds/internal/security/TlsContextManagerImpl.java b/xds/src/main/java/io/grpc/xds/internal/security/TlsContextManagerImpl.java index 34a8863c52b..f56524d50b7 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/TlsContextManagerImpl.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/TlsContextManagerImpl.java @@ -71,8 +71,6 @@ public SslContextProvider findOrCreateServerSslContextProvider( public SslContextProvider findOrCreateClientSslContextProvider( UpstreamTlsContext upstreamTlsContext) { checkNotNull(upstreamTlsContext, "upstreamTlsContext"); - CommonTlsContext.Builder builder = upstreamTlsContext.getCommonTlsContext().toBuilder(); - upstreamTlsContext = new UpstreamTlsContext(builder.build()); return mapForClients.get(upstreamTlsContext); } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java index d4080101c1a..8984efc9435 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java @@ -16,8 +16,6 @@ package io.grpc.xds.internal.security.certprovider; -import static com.google.common.base.Preconditions.checkNotNull; - import io.envoyproxy.envoy.config.core.v3.Node; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; @@ -28,8 +26,11 @@ import io.netty.handler.ssl.SslContextBuilder; import java.security.cert.CertStoreException; import java.security.cert.X509Certificate; +import java.util.AbstractMap; +import java.util.Arrays; import java.util.Map; import javax.annotation.Nullable; +import javax.net.ssl.X509TrustManager; /** A client SslContext provider using CertificateProviderInstance to fetch secrets. */ final class CertProviderClientSslContextProvider extends CertProviderSslContextProvider { @@ -46,26 +47,41 @@ final class CertProviderClientSslContextProvider extends CertProviderSslContextP node, certProviders, certInstance, - checkNotNull(rootCertInstance, "Client SSL requires rootCertInstance"), + rootCertInstance, staticCertValidationContext, upstreamTlsContext, certificateProviderStore); } @Override - protected final SslContextBuilder getSslContextBuilder( - CertificateValidationContext certificateValidationContextdationContext) - throws CertStoreException { + protected final AbstractMap.SimpleImmutableEntry + getSslContextBuilderAndTrustManager( + CertificateValidationContext certificateValidationContext) + throws CertStoreException { + UpstreamTlsContext upstreamTlsContext = (UpstreamTlsContext) tlsContext; + XdsTrustManagerFactory trustManagerFactory; + if (savedSpiffeTrustMap != null) { + trustManagerFactory = new XdsTrustManagerFactory( + savedSpiffeTrustMap, + certificateValidationContext, + upstreamTlsContext.getAutoSniSanValidation()); + } else if (savedTrustedRoots != null) { + trustManagerFactory = new XdsTrustManagerFactory( + savedTrustedRoots.toArray(new X509Certificate[0]), + certificateValidationContext, + upstreamTlsContext.getAutoSniSanValidation()); + } else { + // Should be impossible because of the check in CertProviderClientSslContextProviderFactory + throw new IllegalStateException("There must be trusted roots or a SPIFFE trust map"); + } + SslContextBuilder sslContextBuilder = - GrpcSslContexts.forClient() - .trustManager( - new XdsTrustManagerFactory( - savedTrustedRoots.toArray(new X509Certificate[0]), - certificateValidationContextdationContext)); + GrpcSslContexts.forClient().trustManager(trustManagerFactory); if (isMtls()) { sslContextBuilder.keyManager(savedKey, savedCertChain); } - return sslContextBuilder; + return new AbstractMap.SimpleImmutableEntry<>(sslContextBuilder, + io.grpc.internal.CertificateUtils.getX509ExtendedTrustManager( + Arrays.asList(trustManagerFactory.getTrustManagers()))); } - } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderFactory.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderFactory.java index 21782741c2c..6205c1c3a63 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderFactory.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderFactory.java @@ -25,6 +25,7 @@ import io.grpc.Internal; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.client.Bootstrapper.CertificateProviderInfo; +import io.grpc.xds.internal.security.CommonTlsContextUtil; import io.grpc.xds.internal.security.SslContextProvider; import java.util.Map; import javax.annotation.Nullable; @@ -64,13 +65,17 @@ public SslContextProvider getProvider( = CertProviderSslContextProvider.getRootCertProviderInstance(commonTlsContext); CommonTlsContext.CertificateProviderInstance certInstance = CertProviderSslContextProvider.getCertProviderInstance(commonTlsContext); - return new CertProviderClientSslContextProvider( - node, - certProviders, - certInstance, - rootCertInstance, - staticCertValidationContext, - upstreamTlsContext, - certificateProviderStore); + if (CommonTlsContextUtil.hasCertProviderInstance(upstreamTlsContext.getCommonTlsContext()) + || CommonTlsContextUtil.isUsingSystemRootCerts(upstreamTlsContext.getCommonTlsContext())) { + return new CertProviderClientSslContextProvider( + node, + certProviders, + certInstance, + rootCertInstance, + staticCertValidationContext, + upstreamTlsContext, + certificateProviderStore); + } + throw new UnsupportedOperationException("Unsupported configurations in UpstreamTlsContext!"); } } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProvider.java index e43452a53e1..3712b948142 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProvider.java @@ -30,8 +30,10 @@ import java.security.cert.CertStoreException; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; +import java.util.AbstractMap; import java.util.Map; import javax.annotation.Nullable; +import javax.net.ssl.X509TrustManager; /** A server SslContext provider using CertificateProviderInstance to fetch secrets. */ final class CertProviderServerSslContextProvider extends CertProviderSslContextProvider { @@ -55,19 +57,25 @@ final class CertProviderServerSslContextProvider extends CertProviderSslContextP } @Override - protected final SslContextBuilder getSslContextBuilder( - CertificateValidationContext certificateValidationContextdationContext) - throws CertStoreException, CertificateException, IOException { + protected final AbstractMap.SimpleImmutableEntry + getSslContextBuilderAndTrustManager( + CertificateValidationContext certificateValidationContextdationContext) + throws CertStoreException, CertificateException, IOException { SslContextBuilder sslContextBuilder = SslContextBuilder.forServer(savedKey, savedCertChain); - setClientAuthValues( - sslContextBuilder, - isMtls() - ? new XdsTrustManagerFactory( - savedTrustedRoots.toArray(new X509Certificate[0]), - certificateValidationContextdationContext) - : null); + XdsTrustManagerFactory trustManagerFactory = null; + if (isMtls() && savedSpiffeTrustMap != null) { + trustManagerFactory = new XdsTrustManagerFactory( + savedSpiffeTrustMap, + certificateValidationContextdationContext, false); + } else if (isMtls()) { + trustManagerFactory = new XdsTrustManagerFactory( + savedTrustedRoots.toArray(new X509Certificate[0]), + certificateValidationContextdationContext, false); + } + setClientAuthValues(sslContextBuilder, trustManagerFactory); sslContextBuilder = GrpcSslContexts.configure(sslContextBuilder); - return sslContextBuilder; + // TrustManager in the below return value is not used on the server side, so setting it to null + return new AbstractMap.SimpleImmutableEntry<>(sslContextBuilder, null); } } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java index 6570c619913..cb99ca6ad95 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java @@ -24,6 +24,7 @@ import io.grpc.xds.client.Bootstrapper.CertificateProviderInfo; import io.grpc.xds.internal.security.CommonTlsContextUtil; import io.grpc.xds.internal.security.DynamicSslContextProvider; +import java.io.Closeable; import java.security.PrivateKey; import java.security.cert.X509Certificate; import java.util.List; @@ -34,13 +35,15 @@ abstract class CertProviderSslContextProvider extends DynamicSslContextProvider implements CertificateProvider.Watcher { - @Nullable private final CertificateProviderStore.Handle certHandle; - @Nullable private final CertificateProviderStore.Handle rootCertHandle; + @Nullable private final NoExceptionCloseable certHandle; + @Nullable private final NoExceptionCloseable rootCertHandle; @Nullable private final CertificateProviderInstance certInstance; - @Nullable private final CertificateProviderInstance rootCertInstance; + @Nullable protected final CertificateProviderInstance rootCertInstance; @Nullable protected PrivateKey savedKey; @Nullable protected List savedCertChain; @Nullable protected List savedTrustedRoots; + @Nullable protected Map> savedSpiffeTrustMap; + private final boolean isUsingSystemRootCerts; protected CertProviderSslContextProvider( Node node, @@ -53,24 +56,33 @@ protected CertProviderSslContextProvider( super(tlsContext, staticCertValidationContext); this.certInstance = certInstance; this.rootCertInstance = rootCertInstance; - String certInstanceName = null; - if (certInstance != null && certInstance.isInitialized()) { - certInstanceName = certInstance.getInstanceName(); + this.isUsingSystemRootCerts = rootCertInstance == null + && CommonTlsContextUtil.isUsingSystemRootCerts(tlsContext.getCommonTlsContext()); + boolean createCertInstance = certInstance != null && certInstance.isInitialized(); + boolean createRootCertInstance = rootCertInstance != null && rootCertInstance.isInitialized(); + boolean sharedCertInstance = createCertInstance && createRootCertInstance + && rootCertInstance.getInstanceName().equals(certInstance.getInstanceName()); + if (createCertInstance) { CertificateProviderInfo certProviderInstanceConfig = - getCertProviderConfig(certProviders, certInstanceName); + getCertProviderConfig(certProviders, certInstance.getInstanceName()); + CertificateProvider.Watcher watcher = this; + if (!sharedCertInstance && !isUsingSystemRootCerts) { + watcher = new IgnoreUpdatesWatcher(watcher, /* ignoreRootCertUpdates= */ true); + } + // TODO: Previously we'd hang if certProviderInstanceConfig were null or + // certInstance.isInitialized() == false. Now we'll proceed. Those should be errors, or are + // they impossible and should be assertions? certHandle = certProviderInstanceConfig == null ? null : certificateProviderStore.createOrGetProvider( certInstance.getCertificateName(), certProviderInstanceConfig.pluginName(), certProviderInstanceConfig.config(), - this, - true); + watcher, + true)::close; } else { certHandle = null; } - if (rootCertInstance != null - && rootCertInstance.isInitialized() - && !rootCertInstance.getInstanceName().equals(certInstanceName)) { + if (createRootCertInstance && !sharedCertInstance) { CertificateProviderInfo certProviderInstanceConfig = getCertProviderConfig(certProviders, rootCertInstance.getInstanceName()); rootCertHandle = certProviderInstanceConfig == null ? null @@ -78,8 +90,13 @@ protected CertProviderSslContextProvider( rootCertInstance.getCertificateName(), certProviderInstanceConfig.pluginName(), certProviderInstanceConfig.config(), - this, - true); + new IgnoreUpdatesWatcher(this, /* ignoreRootCertUpdates= */ false), + false)::close; + } else if (rootCertInstance == null + && CommonTlsContextUtil.isUsingSystemRootCerts(tlsContext.getCommonTlsContext())) { + SystemRootCertificateProvider systemRootProvider = new SystemRootCertificateProvider(this); + systemRootProvider.start(); + rootCertHandle = systemRootProvider::close; } else { rootCertHandle = null; } @@ -95,10 +112,14 @@ protected static CertificateProviderInstance getCertProviderInstance( CommonTlsContext commonTlsContext) { if (commonTlsContext.hasTlsCertificateProviderInstance()) { return CommonTlsContextUtil.convert(commonTlsContext.getTlsCertificateProviderInstance()); - } else if (commonTlsContext.hasTlsCertificateCertificateProviderInstance()) { - return commonTlsContext.getTlsCertificateCertificateProviderInstance(); } - return null; + // Fall back to deprecated field for backward compatibility with Istio + @SuppressWarnings("deprecation") + CertificateProviderInstance deprecatedInstance = + commonTlsContext.hasTlsCertificateCertificateProviderInstance() + ? commonTlsContext.getTlsCertificateCertificateProviderInstance() + : null; + return deprecatedInstance; } @Nullable @@ -124,15 +145,6 @@ protected static CommonTlsContext.CertificateProviderInstance getRootCertProvide if (certValidationContext != null && certValidationContext.hasCaCertificateProviderInstance()) { return CommonTlsContextUtil.convert(certValidationContext.getCaCertificateProviderInstance()); } - if (commonTlsContext.hasCombinedValidationContext()) { - CommonTlsContext.CombinedCertificateValidationContext combinedValidationContext = - commonTlsContext.getCombinedValidationContext(); - if (combinedValidationContext.hasValidationContextCertificateProviderInstance()) { - return combinedValidationContext.getValidationContextCertificateProviderInstance(); - } - } else if (commonTlsContext.hasValidationContextCertificateProviderInstance()) { - return commonTlsContext.getValidationContextCertificateProviderInstance(); - } return null; } @@ -149,18 +161,24 @@ public final void updateTrustedRoots(List trustedRoots) { updateSslContextWhenReady(); } + @Override + public final void updateSpiffeTrustMap(Map> spiffeTrustMap) { + savedSpiffeTrustMap = spiffeTrustMap; + updateSslContextWhenReady(); + } + private void updateSslContextWhenReady() { if (isMtls()) { - if (savedKey != null && savedTrustedRoots != null) { + if (savedKey != null && (savedTrustedRoots != null || savedSpiffeTrustMap != null)) { updateSslContext(); clearKeysAndCerts(); } - } else if (isClientSideTls()) { - if (savedTrustedRoots != null) { + } else if (isRegularTlsAndClientSide()) { + if (savedTrustedRoots != null || savedSpiffeTrustMap != null) { updateSslContext(); clearKeysAndCerts(); } - } else if (isServerSideTls()) { + } else if (isRegularTlsAndServerSide()) { if (savedKey != null) { updateSslContext(); clearKeysAndCerts(); @@ -170,19 +188,22 @@ private void updateSslContextWhenReady() { private void clearKeysAndCerts() { savedKey = null; - savedTrustedRoots = null; + if (!isUsingSystemRootCerts) { + savedTrustedRoots = null; + savedSpiffeTrustMap = null; + } savedCertChain = null; } protected final boolean isMtls() { - return certInstance != null && rootCertInstance != null; + return certInstance != null && (rootCertInstance != null || isUsingSystemRootCerts); } - protected final boolean isClientSideTls() { - return rootCertInstance != null && certInstance == null; + protected final boolean isRegularTlsAndClientSide() { + return (rootCertInstance != null || isUsingSystemRootCerts) && certInstance == null; } - protected final boolean isServerSideTls() { + protected final boolean isRegularTlsAndServerSide() { return certInstance != null && rootCertInstance == null; } @@ -200,4 +221,9 @@ public final void close() { rootCertHandle.close(); } } + + interface NoExceptionCloseable extends Closeable { + @Override + void close(); + } } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertificateProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertificateProvider.java index a0d5d0fc69f..009bb7bf566 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertificateProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertificateProvider.java @@ -26,6 +26,7 @@ import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Set; /** @@ -45,6 +46,8 @@ public interface Watcher { void updateTrustedRoots(List trustedRoots); + void updateSpiffeTrustMap(Map> spiffeTrustMap); + void onError(Status errorStatus); } @@ -53,6 +56,7 @@ public static final class DistributorWatcher implements Watcher { private PrivateKey privateKey; private List certChain; private List trustedRoots; + private Map> spiffeTrustMap; @VisibleForTesting final Set downstreamWatchers = new HashSet<>(); @@ -65,6 +69,9 @@ synchronized void addWatcher(Watcher watcher) { if (trustedRoots != null) { sendLastTrustedRootsUpdate(watcher); } + if (spiffeTrustMap != null) { + sendLastSpiffeTrustMapUpdate(watcher); + } } synchronized void removeWatcher(Watcher watcher) { @@ -83,6 +90,10 @@ private void sendLastTrustedRootsUpdate(Watcher watcher) { watcher.updateTrustedRoots(trustedRoots); } + private void sendLastSpiffeTrustMapUpdate(Watcher watcher) { + watcher.updateSpiffeTrustMap(spiffeTrustMap); + } + @Override public synchronized void updateCertificate(PrivateKey key, List certChain) { checkNotNull(key, "key"); @@ -103,6 +114,14 @@ public synchronized void updateTrustedRoots(List trustedRoots) } } + @Override + public void updateSpiffeTrustMap(Map> spiffeTrustMap) { + this.spiffeTrustMap = spiffeTrustMap; + for (Watcher watcher : downstreamWatchers) { + sendLastSpiffeTrustMapUpdate(watcher); + } + } + @Override public synchronized void onError(Status errorStatus) { for (Watcher watcher : downstreamWatchers) { @@ -147,7 +166,7 @@ protected CertificateProvider(DistributorWatcher watcher, boolean notifyCertUpda @Override public abstract void close(); - /** Starts the cert refresh and watcher update cycle. */ + /** Starts the async cert refresh and watcher update cycle. */ public abstract void start(); private final DistributorWatcher watcher; diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProvider.java index dd945ce850e..9cb9a867118 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProvider.java @@ -16,10 +16,12 @@ package io.grpc.xds.internal.security.certprovider; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; import io.grpc.Status; +import io.grpc.internal.SpiffeUtil; import io.grpc.internal.TimeProvider; import io.grpc.xds.internal.security.trust.CertificateUtils; import java.io.ByteArrayInputStream; @@ -30,6 +32,7 @@ import java.security.PrivateKey; import java.security.cert.X509Certificate; import java.util.Arrays; +import java.util.HashMap; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; @@ -47,11 +50,13 @@ final class FileWatcherCertificateProvider extends CertificateProvider implement private final Path certFile; private final Path keyFile; private final Path trustFile; + private final Path spiffeTrustMapFile; private final long refreshIntervalInSeconds; @VisibleForTesting ScheduledFuture scheduledFuture; private FileTime lastModifiedTimeCert; private FileTime lastModifiedTimeKey; private FileTime lastModifiedTimeRoot; + private FileTime lastModifiedTimespiffeTrustMap; private boolean shutdown; FileWatcherCertificateProvider( @@ -60,6 +65,7 @@ final class FileWatcherCertificateProvider extends CertificateProvider implement String certFile, String keyFile, String trustFile, + String spiffeTrustMapFile, long refreshIntervalInSeconds, ScheduledExecutorService scheduledExecutorService, TimeProvider timeProvider) { @@ -69,7 +75,15 @@ final class FileWatcherCertificateProvider extends CertificateProvider implement this.timeProvider = checkNotNull(timeProvider, "timeProvider"); this.certFile = Paths.get(checkNotNull(certFile, "certFile")); this.keyFile = Paths.get(checkNotNull(keyFile, "keyFile")); - this.trustFile = Paths.get(checkNotNull(trustFile, "trustFile")); + checkArgument((trustFile != null || spiffeTrustMapFile != null), + "either trustFile or spiffeTrustMapFile must be present"); + if (spiffeTrustMapFile != null) { + this.spiffeTrustMapFile = Paths.get(spiffeTrustMapFile); + this.trustFile = null; + } else { + this.spiffeTrustMapFile = null; + this.trustFile = Paths.get(trustFile); + } this.refreshIntervalInSeconds = refreshIntervalInSeconds; } @@ -102,44 +116,53 @@ void checkAndReloadCertificates() { FileTime currentCertTime = Files.getLastModifiedTime(certFile); FileTime currentKeyTime = Files.getLastModifiedTime(keyFile); if (!currentCertTime.equals(lastModifiedTimeCert) - && !currentKeyTime.equals(lastModifiedTimeKey)) { + || !currentKeyTime.equals(lastModifiedTimeKey)) { byte[] certFileContents = Files.readAllBytes(certFile); byte[] keyFileContents = Files.readAllBytes(keyFile); FileTime currentCertTime2 = Files.getLastModifiedTime(certFile); FileTime currentKeyTime2 = Files.getLastModifiedTime(keyFile); - if (!currentCertTime2.equals(currentCertTime)) { - return; - } - if (!currentKeyTime2.equals(currentKeyTime)) { - return; - } - try (ByteArrayInputStream certStream = new ByteArrayInputStream(certFileContents); - ByteArrayInputStream keyStream = new ByteArrayInputStream(keyFileContents)) { - PrivateKey privateKey = CertificateUtils.getPrivateKey(keyStream); - X509Certificate[] certs = CertificateUtils.toX509Certificates(certStream); - getWatcher().updateCertificate(privateKey, Arrays.asList(certs)); + if (currentCertTime2.equals(currentCertTime) && currentKeyTime2.equals(currentKeyTime)) { + try (ByteArrayInputStream certStream = new ByteArrayInputStream(certFileContents); + ByteArrayInputStream keyStream = new ByteArrayInputStream(keyFileContents)) { + PrivateKey privateKey = CertificateUtils.getPrivateKey(keyStream); + X509Certificate[] certs = CertificateUtils.toX509Certificates(certStream); + getWatcher().updateCertificate(privateKey, Arrays.asList(certs)); + } + lastModifiedTimeCert = currentCertTime; + lastModifiedTimeKey = currentKeyTime; } - lastModifiedTimeCert = currentCertTime; - lastModifiedTimeKey = currentKeyTime; } } catch (Throwable t) { generateErrorIfCurrentCertExpired(t); } try { - FileTime currentRootTime = Files.getLastModifiedTime(trustFile); - if (currentRootTime.equals(lastModifiedTimeRoot)) { - return; - } - byte[] rootFileContents = Files.readAllBytes(trustFile); - FileTime currentRootTime2 = Files.getLastModifiedTime(trustFile); - if (!currentRootTime2.equals(currentRootTime)) { - return; + if (spiffeTrustMapFile != null) { + FileTime currentSpiffeTime = Files.getLastModifiedTime(spiffeTrustMapFile); + if (!currentSpiffeTime.equals(lastModifiedTimespiffeTrustMap)) { + SpiffeUtil.SpiffeBundle trustBundle = SpiffeUtil + .loadTrustBundleFromFile(spiffeTrustMapFile.toString()); + getWatcher().updateSpiffeTrustMap(new HashMap<>(trustBundle.getBundleMap())); + lastModifiedTimespiffeTrustMap = currentSpiffeTime; + } } - try (ByteArrayInputStream rootStream = new ByteArrayInputStream(rootFileContents)) { - X509Certificate[] caCerts = CertificateUtils.toX509Certificates(rootStream); - getWatcher().updateTrustedRoots(Arrays.asList(caCerts)); + } catch (Throwable t) { + getWatcher().onError(Status.fromThrowable(t)); + } + try { + if (trustFile != null) { + FileTime currentRootTime = Files.getLastModifiedTime(trustFile); + if (!currentRootTime.equals(lastModifiedTimeRoot)) { + byte[] rootFileContents = Files.readAllBytes(trustFile); + FileTime currentRootTime2 = Files.getLastModifiedTime(trustFile); + if (currentRootTime2.equals(currentRootTime)) { + try (ByteArrayInputStream rootStream = new ByteArrayInputStream(rootFileContents)) { + X509Certificate[] caCerts = CertificateUtils.toX509Certificates(rootStream); + getWatcher().updateTrustedRoots(Arrays.asList(caCerts)); + } + lastModifiedTimeRoot = currentRootTime; + } + } } - lastModifiedTimeRoot = currentRootTime; } catch (Throwable t) { getWatcher().onError(Status.fromThrowable(t)); } @@ -195,6 +218,7 @@ FileWatcherCertificateProvider create( String certFile, String keyFile, String trustFile, + String spiffeTrustMapFile, long refreshIntervalInSeconds, ScheduledExecutorService scheduledExecutorService, TimeProvider timeProvider) { @@ -204,6 +228,7 @@ FileWatcherCertificateProvider create( certFile, keyFile, trustFile, + spiffeTrustMapFile, refreshIntervalInSeconds, scheduledExecutorService, timeProvider); @@ -220,6 +245,7 @@ abstract FileWatcherCertificateProvider create( String certFile, String keyFile, String trustFile, + String spiffeTrustMapFile, long refreshIntervalInSeconds, ScheduledExecutorService scheduledExecutorService, TimeProvider timeProvider); diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProvider.java index c4b140442cb..e4871dc4c84 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProvider.java @@ -23,6 +23,7 @@ import com.google.common.util.concurrent.ThreadFactoryBuilder; import com.google.protobuf.Duration; import com.google.protobuf.util.Durations; +import io.grpc.internal.GrpcUtil; import io.grpc.internal.JsonUtil; import io.grpc.internal.TimeProvider; import java.text.ParseException; @@ -33,11 +34,16 @@ /** * Provider of {@link FileWatcherCertificateProvider}s. */ -final class FileWatcherCertificateProviderProvider implements CertificateProviderProvider { +public final class FileWatcherCertificateProviderProvider implements CertificateProviderProvider { + // TODO(lwge): Remove the old env var check once it's confirmed to be unused. + @VisibleForTesting + public static boolean enableSpiffe = GrpcUtil.getFlag("GRPC_EXPERIMENTAL_SPIFFE_TRUST_BUNDLE_MAP", + false) || GrpcUtil.getFlag("GRPC_EXPERIMENTAL_XDS_MTLS_SPIFFE", false); private static final String CERT_FILE_KEY = "certificate_file"; private static final String KEY_FILE_KEY = "private_key_file"; private static final String ROOT_FILE_KEY = "ca_certificate_file"; + private static final String SPIFFE_TRUST_MAP_FILE_KEY = "spiffe_trust_bundle_map_file"; private static final String REFRESH_INTERVAL_KEY = "refresh_interval"; @VisibleForTesting static final long REFRESH_INTERVAL_DEFAULT = 600L; @@ -82,6 +88,7 @@ public CertificateProvider createCertificateProvider( configObj.certFile, configObj.keyFile, configObj.rootFile, + configObj.spiffeTrustMapFile, configObj.refrehInterval, scheduledExecutorServiceFactory.create(), timeProvider); @@ -98,7 +105,20 @@ private static Config validateAndTranslateConfig(Object config) { Config configObj = new Config(); configObj.certFile = checkForNullAndGet(map, CERT_FILE_KEY); configObj.keyFile = checkForNullAndGet(map, KEY_FILE_KEY); - configObj.rootFile = checkForNullAndGet(map, ROOT_FILE_KEY); + if (enableSpiffe) { + if (!map.containsKey(ROOT_FILE_KEY) && !map.containsKey(SPIFFE_TRUST_MAP_FILE_KEY)) { + throw new NullPointerException( + String.format("either '%s' or '%s' is required in the config", + ROOT_FILE_KEY, SPIFFE_TRUST_MAP_FILE_KEY)); + } + if (map.containsKey(SPIFFE_TRUST_MAP_FILE_KEY)) { + configObj.spiffeTrustMapFile = JsonUtil.getString(map, SPIFFE_TRUST_MAP_FILE_KEY); + } else { + configObj.rootFile = JsonUtil.getString(map, ROOT_FILE_KEY); + } + } else { + configObj.rootFile = checkForNullAndGet(map, ROOT_FILE_KEY); + } String refreshIntervalString = JsonUtil.getString(map, REFRESH_INTERVAL_KEY); if (refreshIntervalString != null) { try { @@ -139,6 +159,7 @@ static class Config { String certFile; String keyFile; String rootFile; + String spiffeTrustMapFile; Long refrehInterval; } } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/IgnoreUpdatesWatcher.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/IgnoreUpdatesWatcher.java new file mode 100644 index 00000000000..cd9d88be41b --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/IgnoreUpdatesWatcher.java @@ -0,0 +1,68 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.security.certprovider; + +import static java.util.Objects.requireNonNull; + +import com.google.common.annotations.VisibleForTesting; +import io.grpc.Status; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.util.List; +import java.util.Map; + +public final class IgnoreUpdatesWatcher implements CertificateProvider.Watcher { + private final CertificateProvider.Watcher delegate; + private final boolean ignoreRootCertUpdates; + + public IgnoreUpdatesWatcher( + CertificateProvider.Watcher delegate, boolean ignoreRootCertUpdates) { + this.delegate = requireNonNull(delegate, "delegate"); + this.ignoreRootCertUpdates = ignoreRootCertUpdates; + } + + @Override + public void updateCertificate(PrivateKey key, List certChain) { + if (ignoreRootCertUpdates) { + delegate.updateCertificate(key, certChain); + } + } + + @Override + public void updateTrustedRoots(List trustedRoots) { + if (!ignoreRootCertUpdates) { + delegate.updateTrustedRoots(trustedRoots); + } + } + + @Override + public void updateSpiffeTrustMap(Map> spiffeTrustMap) { + if (!ignoreRootCertUpdates) { + delegate.updateSpiffeTrustMap(spiffeTrustMap); + } + } + + @Override + public void onError(Status errorStatus) { + delegate.onError(errorStatus); + } + + @VisibleForTesting + public CertificateProvider.Watcher getDelegate() { + return delegate; + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/SystemRootCertificateProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/SystemRootCertificateProvider.java new file mode 100644 index 00000000000..7c60f714e71 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/SystemRootCertificateProvider.java @@ -0,0 +1,71 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.security.certprovider; + +import io.grpc.Status; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.cert.X509Certificate; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.stream.Collectors; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509TrustManager; + +/** + * An non-registered provider for CertProviderSslContextProvider to use the same code path for + * system root certs as provider-obtained certs. + */ +final class SystemRootCertificateProvider extends CertificateProvider { + public SystemRootCertificateProvider(CertificateProvider.Watcher watcher) { + super(new DistributorWatcher(), false); + getWatcher().addWatcher(watcher); + } + + @Override + public void start() { + try { + TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance( + TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init((KeyStore) null); + + List trustManagers = Arrays.asList(trustManagerFactory.getTrustManagers()); + List rootCerts = trustManagers.stream() + .filter(X509TrustManager.class::isInstance) + .map(X509TrustManager.class::cast) + .map(trustManager -> Arrays.asList(trustManager.getAcceptedIssuers())) + .flatMap(Collection::stream) + .collect(Collectors.toList()); + getWatcher().updateTrustedRoots(rootCerts); + } catch (KeyStoreException | NoSuchAlgorithmException ex) { + getWatcher().onError(Status.UNAVAILABLE + .withDescription("Could not load system root certs") + .withCause(ex)); + } + } + + @Override + public void close() { + // Unnecessary because there's no more callbacks, but do it for good measure + for (Watcher watcher : getWatcher().getDownstreamWatchers()) { + getWatcher().removeWatcher(watcher); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/security/trust/CertificateUtils.java b/xds/src/main/java/io/grpc/xds/internal/security/trust/CertificateUtils.java index 86b6dd95c3e..41a3980c123 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/trust/CertificateUtils.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/trust/CertificateUtils.java @@ -16,6 +16,7 @@ package io.grpc.xds.internal.security.trust; +import io.grpc.internal.GrpcUtil; import java.io.BufferedInputStream; import java.io.File; import java.io.FileInputStream; @@ -29,6 +30,9 @@ * Contains certificate utility method(s). */ public final class CertificateUtils { + public static boolean useChannelAuthorityIfNoSniApplicable + = GrpcUtil.getFlag("GRPC_USE_CHANNEL_AUTHORITY_IF_NO_SNI_APPLICABLE", false); + /** * Generates X509Certificate array from a file on disk. * diff --git a/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactory.java b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactory.java index c9d83902ec2..664c5dd9362 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactory.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactory.java @@ -17,6 +17,7 @@ package io.grpc.xds.internal.security.trust; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; import com.google.common.annotations.VisibleForTesting; @@ -33,6 +34,9 @@ import java.security.cert.CertStoreException; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import java.util.logging.Level; import java.util.logging.Logger; import javax.net.ssl.ManagerFactoryParameters; @@ -54,26 +58,51 @@ public XdsTrustManagerFactory(CertificateValidationContext certificateValidation this( getTrustedCaFromCertContext(certificateValidationContext), certificateValidationContext, + false, false); } public XdsTrustManagerFactory( - X509Certificate[] certs, CertificateValidationContext staticCertificateValidationContext) - throws CertStoreException { - this(certs, staticCertificateValidationContext, true); + X509Certificate[] certs, CertificateValidationContext staticCertificateValidationContext, + boolean autoSniSanValidation) throws CertStoreException { + this(certs, staticCertificateValidationContext, true, autoSniSanValidation); + } + + public XdsTrustManagerFactory(Map> spiffeTrustMap, + CertificateValidationContext staticCertificateValidationContext, boolean autoSniSanValidation) + throws CertStoreException { + this(spiffeTrustMap, staticCertificateValidationContext, true, autoSniSanValidation); } private XdsTrustManagerFactory( X509Certificate[] certs, CertificateValidationContext certificateValidationContext, - boolean validationContextIsStatic) + boolean validationContextIsStatic, + boolean autoSniSanValidation) + throws CertStoreException { + if (validationContextIsStatic) { + checkArgument( + certificateValidationContext == null || !certificateValidationContext.hasTrustedCa() + || certificateValidationContext.hasSystemRootCerts(), + "only static certificateValidationContext expected"); + } + xdsX509TrustManager = createX509TrustManager( + certs, certificateValidationContext, autoSniSanValidation); + } + + private XdsTrustManagerFactory( + Map> spiffeTrustMap, + CertificateValidationContext certificateValidationContext, + boolean validationContextIsStatic, + boolean autoSniSanValidation) throws CertStoreException { if (validationContextIsStatic) { checkArgument( certificateValidationContext == null || !certificateValidationContext.hasTrustedCa(), "only static certificateValidationContext expected"); + xdsX509TrustManager = createX509TrustManager( + spiffeTrustMap, certificateValidationContext, autoSniSanValidation); } - xdsX509TrustManager = createX509TrustManager(certs, certificateValidationContext); } private static X509Certificate[] getTrustedCaFromCertContext( @@ -99,7 +128,28 @@ private static X509Certificate[] getTrustedCaFromCertContext( @VisibleForTesting static XdsX509TrustManager createX509TrustManager( - X509Certificate[] certs, CertificateValidationContext certContext) throws CertStoreException { + X509Certificate[] certs, CertificateValidationContext certContext, + boolean autoSniSanValidation) + throws CertStoreException { + return new XdsX509TrustManager(certContext, createTrustManager(certs), autoSniSanValidation); + } + + @VisibleForTesting + static XdsX509TrustManager createX509TrustManager( + Map> spiffeTrustMapFile, + CertificateValidationContext certContext, boolean autoSniSanValidation) + throws CertStoreException { + checkNotNull(spiffeTrustMapFile, "spiffeTrustMapFile"); + Map delegates = new HashMap<>(); + for (Map.Entry> entry:spiffeTrustMapFile.entrySet()) { + delegates.put(entry.getKey(), createTrustManager( + entry.getValue().toArray(new X509Certificate[0]))); + } + return new XdsX509TrustManager(certContext, delegates, autoSniSanValidation); + } + + private static X509ExtendedTrustManager createTrustManager(X509Certificate[] certs) + throws CertStoreException { TrustManagerFactory tmf = null; try { tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); @@ -131,7 +181,7 @@ static XdsX509TrustManager createX509TrustManager( if (myDelegate == null) { throw new CertStoreException("Native X509 TrustManager not found."); } - return new XdsX509TrustManager(certContext, myDelegate); + return myDelegate; } @Override diff --git a/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java index 4bb6f0520c4..01f25dda6c7 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java @@ -19,18 +19,29 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Optional; import com.google.common.base.Strings; +import com.google.common.collect.ImmutableMap; import com.google.re2j.Pattern; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.type.matcher.v3.RegexMatcher; import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; +import io.grpc.internal.SpiffeUtil; import java.net.Socket; import java.security.cert.CertificateException; import java.security.cert.CertificateParsingException; import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; +import java.util.HashSet; import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; import javax.annotation.Nullable; +import javax.net.ssl.SNIHostName; +import javax.net.ssl.SNIServerName; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSocket; @@ -50,13 +61,28 @@ final class XdsX509TrustManager extends X509ExtendedTrustManager implements X509 private static final int ALT_IPA_NAME = 7; private final X509ExtendedTrustManager delegate; + private final Map spiffeTrustMapDelegates; private final CertificateValidationContext certContext; + private final boolean autoSniSanValidation; XdsX509TrustManager(@Nullable CertificateValidationContext certContext, - X509ExtendedTrustManager delegate) { + X509ExtendedTrustManager delegate, + boolean autoSniSanValidation) { checkNotNull(delegate, "delegate"); this.certContext = certContext; this.delegate = delegate; + this.spiffeTrustMapDelegates = null; + this.autoSniSanValidation = autoSniSanValidation; + } + + XdsX509TrustManager(@Nullable CertificateValidationContext certContext, + Map spiffeTrustMapDelegates, + boolean autoSniSanValidation) { + checkNotNull(spiffeTrustMapDelegates, "spiffeTrustMapDelegates"); + this.spiffeTrustMapDelegates = ImmutableMap.copyOf(spiffeTrustMapDelegates); + this.certContext = certContext; + this.delegate = null; + this.autoSniSanValidation = autoSniSanValidation; } private static boolean verifyDnsNameInPattern( @@ -97,7 +123,8 @@ private static boolean verifyDnsNamePrefix( return false; } return ignoreCase - ? altNameFromCert.toLowerCase().startsWith(sanToVerifyPrefix.toLowerCase()) + ? altNameFromCert.toLowerCase(Locale.ROOT).startsWith( + sanToVerifyPrefix.toLowerCase(Locale.ROOT)) : altNameFromCert.startsWith(sanToVerifyPrefix); } @@ -107,7 +134,8 @@ private static boolean verifyDnsNameSuffix( return false; } return ignoreCase - ? altNameFromCert.toLowerCase().endsWith(sanToVerifySuffix.toLowerCase()) + ? altNameFromCert.toLowerCase(Locale.ROOT).endsWith( + sanToVerifySuffix.toLowerCase(Locale.ROOT)) : altNameFromCert.endsWith(sanToVerifySuffix); } @@ -117,7 +145,8 @@ private static boolean verifyDnsNameContains( return false; } return ignoreCase - ? altNameFromCert.toLowerCase().contains(sanToVerifySubstring.toLowerCase()) + ? altNameFromCert.toLowerCase(Locale.ROOT).contains( + sanToVerifySubstring.toLowerCase(Locale.ROOT)) : altNameFromCert.contains(sanToVerifySubstring); } @@ -126,6 +155,9 @@ private static boolean verifyDnsNameExact( if (Strings.isNullOrEmpty(sanToVerifyExact)) { return false; } + if (sanToVerifyExact.contains("*")) { + return verifyDnsNameWildcard(altNameFromCert, sanToVerifyExact, ignoreCase); + } return ignoreCase ? sanToVerifyExact.equalsIgnoreCase(altNameFromCert) : sanToVerifyExact.equals(altNameFromCert); @@ -182,11 +214,11 @@ private static void verifySubjectAltNameInLeaf( * This is called from various check*Trusted methods. */ @VisibleForTesting - void verifySubjectAltNameInChain(X509Certificate[] peerCertChain) throws CertificateException { + void verifySubjectAltNameInChain(X509Certificate[] peerCertChain, + List verifyList) throws CertificateException { if (certContext == null) { return; } - List verifyList = certContext.getMatchSubjectAltNamesList(); if (verifyList.isEmpty()) { return; } @@ -198,62 +230,171 @@ void verifySubjectAltNameInChain(X509Certificate[] peerCertChain) throws Certifi } @Override + @SuppressWarnings("deprecation") // gRFC A29 predates match_typed_subject_alt_names public void checkClientTrusted(X509Certificate[] chain, String authType, Socket socket) throws CertificateException { - delegate.checkClientTrusted(chain, authType, socket); - verifySubjectAltNameInChain(chain); + chooseDelegate(chain).checkClientTrusted(chain, authType, socket); + verifySubjectAltNameInChain(chain, certContext != null + ? certContext.getMatchSubjectAltNamesList() : new ArrayList<>()); } @Override + @SuppressWarnings("deprecation") // gRFC A29 predates match_typed_subject_alt_names public void checkClientTrusted(X509Certificate[] chain, String authType, SSLEngine sslEngine) throws CertificateException { - delegate.checkClientTrusted(chain, authType, sslEngine); - verifySubjectAltNameInChain(chain); + chooseDelegate(chain).checkClientTrusted(chain, authType, sslEngine); + verifySubjectAltNameInChain(chain, certContext != null + ? certContext.getMatchSubjectAltNamesList() : new ArrayList<>()); } @Override + @SuppressWarnings("deprecation") // gRFC A29 predates match_typed_subject_alt_names public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException { - delegate.checkClientTrusted(chain, authType); - verifySubjectAltNameInChain(chain); + chooseDelegate(chain).checkClientTrusted(chain, authType); + verifySubjectAltNameInChain(chain, certContext != null + ? certContext.getMatchSubjectAltNamesList() : new ArrayList<>()); } @Override public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket) throws CertificateException { + List sniMatchers = null; if (socket instanceof SSLSocket) { SSLSocket sslSocket = (SSLSocket) socket; SSLParameters sslParams = sslSocket.getSSLParameters(); if (sslParams != null) { - sslParams.setEndpointIdentificationAlgorithm(null); + sslParams.setEndpointIdentificationAlgorithm(""); sslSocket.setSSLParameters(sslParams); } + sniMatchers = getAutoSniSanMatchers(sslParams); } - delegate.checkServerTrusted(chain, authType, socket); - verifySubjectAltNameInChain(chain); + if (sniMatchers.isEmpty() && certContext != null) { + @SuppressWarnings("deprecation") // gRFC A29 predates match_typed_subject_alt_names + List sniMatchersTmp = certContext.getMatchSubjectAltNamesList(); + sniMatchers = sniMatchersTmp; + } + chooseDelegate(chain).checkServerTrusted(chain, authType, socket); + verifySubjectAltNameInChain(chain, sniMatchers); } @Override public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngine sslEngine) throws CertificateException { + List sniMatchers = null; SSLParameters sslParams = sslEngine.getSSLParameters(); if (sslParams != null) { - sslParams.setEndpointIdentificationAlgorithm(null); + sslParams.setEndpointIdentificationAlgorithm(""); sslEngine.setSSLParameters(sslParams); + sniMatchers = getAutoSniSanMatchers(sslParams); + } + if (sniMatchers.isEmpty() && certContext != null) { + @SuppressWarnings("deprecation") // gRFC A29 predates match_typed_subject_alt_names + List sniMatchersTmp = certContext.getMatchSubjectAltNamesList(); + sniMatchers = sniMatchersTmp; } - delegate.checkServerTrusted(chain, authType, sslEngine); - verifySubjectAltNameInChain(chain); + chooseDelegate(chain).checkServerTrusted(chain, authType, sslEngine); + verifySubjectAltNameInChain(chain, sniMatchers); } @Override + @SuppressWarnings("deprecation") // gRFC A29 predates match_typed_subject_alt_names public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException { - delegate.checkServerTrusted(chain, authType); - verifySubjectAltNameInChain(chain); + chooseDelegate(chain).checkServerTrusted(chain, authType); + verifySubjectAltNameInChain(chain, certContext != null + ? certContext.getMatchSubjectAltNamesList() : new ArrayList<>()); + } + + private List getAutoSniSanMatchers(SSLParameters sslParams) { + List sniNamesToMatch = new ArrayList<>(); + if (autoSniSanValidation) { + List serverNames = sslParams.getServerNames(); + if (serverNames != null) { + for (SNIServerName serverName : serverNames) { + if (serverName instanceof SNIHostName) { + SNIHostName sniHostName = (SNIHostName) serverName; + String hostName = sniHostName.getAsciiName(); + sniNamesToMatch.add(StringMatcher.newBuilder().setExact(hostName).build()); + } + } + } + } + return sniNamesToMatch; + } + + private X509ExtendedTrustManager chooseDelegate(X509Certificate[] chain) + throws CertificateException { + if (spiffeTrustMapDelegates != null) { + Optional spiffeId = SpiffeUtil.extractSpiffeId(chain); + if (!spiffeId.isPresent()) { + throw new CertificateException("Failed to extract SPIFFE ID from peer leaf certificate"); + } + String trustDomain = spiffeId.get().getTrustDomain(); + if (!spiffeTrustMapDelegates.containsKey(trustDomain)) { + throw new CertificateException(String.format("Spiffe Trust Map doesn't contain trust" + + " domain '%s' from peer leaf certificate", trustDomain)); + } + return spiffeTrustMapDelegates.get(trustDomain); + } else { + return delegate; + } } @Override public X509Certificate[] getAcceptedIssuers() { + if (spiffeTrustMapDelegates != null) { + Set result = new HashSet<>(); + for (X509ExtendedTrustManager tm: spiffeTrustMapDelegates.values()) { + result.addAll(Arrays.asList(tm.getAcceptedIssuers())); + } + return result.toArray(new X509Certificate[0]); + } return delegate.getAcceptedIssuers(); } + + private static boolean verifyDnsNameWildcard( + String altNameFromCert, String sanToVerify, boolean ignoreCase) { + String[] splitPattern = splitAtFirstDelimiter(ignoreCase + ? sanToVerify.toLowerCase(Locale.ROOT) : sanToVerify); + String[] splitDnsName = splitAtFirstDelimiter(ignoreCase + ? altNameFromCert.toLowerCase(Locale.ROOT) : altNameFromCert); + if (splitPattern == null || splitDnsName == null) { + return false; + } + if (splitDnsName[0].startsWith("xn--")) { + return false; + } + if (splitPattern[0].contains("*") + && !splitPattern[1].contains("*") + && !splitPattern[0].startsWith("xn--")) { + return splitDnsName[1].equals(splitPattern[1]) + && labelWildcardMatch(splitDnsName[0], splitPattern[0]); + } + return false; + } + + private static boolean labelWildcardMatch(String dnsLabel, String pattern) { + final char glob = '*'; + // Check the special case of a single * pattern, as it's common. + if (pattern.equals("*")) { + return !dnsLabel.isEmpty(); + } + int globIndex = pattern.indexOf(glob); + if (pattern.indexOf(glob, globIndex + 1) == -1) { + return dnsLabel.length() >= pattern.length() - 1 + && dnsLabel.startsWith(pattern.substring(0, globIndex)) + && dnsLabel.endsWith(pattern.substring(globIndex + 1)); + } + return false; + } + + @Nullable + private static String[] splitAtFirstDelimiter(String s) { + int index = s.indexOf('.'); + if (index == -1) { + return null; + } + return new String[]{s.substring(0, index), s.substring(index + 1)}; + } } diff --git a/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java b/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java index ba03140d627..b37b9bc42e3 100644 --- a/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java +++ b/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java @@ -36,12 +36,16 @@ import io.grpc.ChannelLogger; import io.grpc.ChannelLogger.ChannelLogLevel; import io.grpc.ClientCall; +import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.ExperimentalApi; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.PickResult; +import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.Subchannel; +import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.Metadata; import io.grpc.Status; @@ -83,7 +87,7 @@ private OrcaOobUtil() {} * class WrrLoadbalancer extends LoadBalancer { * private final Helper originHelper; // the original Helper * - * public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + * public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { * // listener implements the logic for WRR's usage of backend metrics. * OrcaReportingHelper orcaHelper = * OrcaOobUtil.newOrcaReportingHelper(originHelper); @@ -236,6 +240,30 @@ protected Helper delegate() { return delegate; } + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + delegate.updateBalancingState(newState, new OrcaOobPicker(newPicker)); + } + + @VisibleForTesting + static final class OrcaOobPicker extends SubchannelPicker { + final SubchannelPicker delegate; + + OrcaOobPicker(SubchannelPicker delegate) { + this.delegate = delegate; + } + + @Override + public PickResult pickSubchannel(PickSubchannelArgs args) { + PickResult result = delegate.pickSubchannel(args); + Subchannel subchannel = result.getSubchannel(); + if (subchannel instanceof SubchannelImpl) { + return result.copyWithSubchannel(((SubchannelImpl) subchannel).delegate()); + } + return result; + } + } + @Override public Subchannel createSubchannel(CreateSubchannelArgs args) { syncContext.throwIfNotInThisSynchronizationContext(); diff --git a/xds/src/main/java/io/grpc/xds/package-info.java b/xds/src/main/java/io/grpc/xds/package-info.java index 74fa88cfe38..9cc15cd5449 100644 --- a/xds/src/main/java/io/grpc/xds/package-info.java +++ b/xds/src/main/java/io/grpc/xds/package-info.java @@ -15,7 +15,7 @@ */ /** - * Library for gPRC proxyless service mesh using Envoy xDS protocol. + * Library for gRPC proxyless service mesh using Envoy xDS protocol. * *

The package currently includes a name resolver plugin and a family of load balancer plugins. * A gRPC channel for a target with {@code "xds:"} scheme will load the plugins and a diff --git a/xds/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider b/xds/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider index e1c4d4aa427..04a2d9cf7a8 100644 --- a/xds/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider +++ b/xds/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider @@ -2,7 +2,6 @@ io.grpc.xds.CdsLoadBalancerProvider io.grpc.xds.PriorityLoadBalancerProvider io.grpc.xds.WeightedTargetLoadBalancerProvider io.grpc.xds.ClusterManagerLoadBalancerProvider -io.grpc.xds.ClusterResolverLoadBalancerProvider io.grpc.xds.ClusterImplLoadBalancerProvider io.grpc.xds.LeastRequestLoadBalancerProvider io.grpc.xds.RingHashLoadBalancerProvider diff --git a/xds/src/test/java/io/grpc/xds/AddressFilterTest.java b/xds/src/test/java/io/grpc/xds/AddressFilterTest.java index 7d92e2ba1ce..36709ab8ee6 100644 --- a/xds/src/test/java/io/grpc/xds/AddressFilterTest.java +++ b/xds/src/test/java/io/grpc/xds/AddressFilterTest.java @@ -58,4 +58,29 @@ public void filterAddresses() { assertThat(filteredAddress0.getAttributes().get(key1)).isEqualTo("value1"); assertThat(filteredAddress1.getAddresses()).containsExactlyElementsIn(eag3.getAddresses()); } + + @Test + public void longerPathChain() { + List addresses = Arrays.asList( + newEag(new InetSocketAddress(8000), Arrays.asList("A", "B", "C")), + newEag(new InetSocketAddress(8001), Arrays.asList("Z", "B", "C")), + newEag(new InetSocketAddress(8002), Arrays.asList("A", "Z", "C")), + newEag(new InetSocketAddress(8003), Arrays.asList("A", "B", "Z"))); + addresses = AddressFilter.filter(addresses, "A"); + assertThat(addresses).hasSize(3); + + addresses = AddressFilter.filter(addresses, "B"); + assertThat(addresses).hasSize(2); + + addresses = AddressFilter.filter(addresses, "C"); + assertThat(addresses).hasSize(1); + assertThat(addresses.get(0).getAddresses()).containsExactly(new InetSocketAddress(8000)); + + addresses = AddressFilter.filter(addresses, "D"); + assertThat(addresses).hasSize(0); + } + + private static EquivalentAddressGroup newEag(InetSocketAddress address, List names) { + return AddressFilter.setPathFilter(new EquivalentAddressGroup(address), names); + } } diff --git a/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java b/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java index e42aa03d73c..ff4813fe6a8 100644 --- a/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java +++ b/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java @@ -17,58 +17,76 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; -import static io.grpc.xds.XdsLbPolicies.CLUSTER_RESOLVER_POLICY_NAME; -import static org.junit.Assert.fail; +import static io.grpc.xds.XdsLbPolicies.CLUSTER_IMPL_POLICY_NAME; +import static io.grpc.xds.XdsLbPolicies.PRIORITY_POLICY_NAME; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_CDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_EDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_LDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_RDS; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; -import static org.mockito.Mockito.reset; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import com.google.common.collect.ImmutableList; +import com.github.xds.type.v3.TypedStruct; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; +import com.google.protobuf.Any; +import com.google.protobuf.Struct; +import com.google.protobuf.UInt32Value; +import com.google.protobuf.Value; +import io.envoyproxy.envoy.config.cluster.v3.CircuitBreakers; +import io.envoyproxy.envoy.config.cluster.v3.Cluster; +import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy; +import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy.Policy; +import io.envoyproxy.envoy.config.cluster.v3.OutlierDetection; +import io.envoyproxy.envoy.config.core.v3.Address; +import io.envoyproxy.envoy.config.core.v3.AggregatedConfigSource; +import io.envoyproxy.envoy.config.core.v3.ConfigSource; +import io.envoyproxy.envoy.config.core.v3.RoutingPriority; +import io.envoyproxy.envoy.config.core.v3.SelfConfigSource; +import io.envoyproxy.envoy.config.core.v3.SocketAddress; +import io.envoyproxy.envoy.config.core.v3.TransportSocket; +import io.envoyproxy.envoy.config.core.v3.TypedExtensionConfig; +import io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment; +import io.envoyproxy.envoy.config.endpoint.v3.Endpoint; +import io.envoyproxy.envoy.config.endpoint.v3.LbEndpoint; +import io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints; +import io.envoyproxy.envoy.extensions.clusters.aggregate.v3.ClusterConfig; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext; import io.grpc.Attributes; +import io.grpc.ChannelLogger; import io.grpc.ConnectivityState; -import io.grpc.EquivalentAddressGroup; -import io.grpc.InsecureChannelCredentials; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.ResolvedAddresses; -import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancerProvider; import io.grpc.LoadBalancerRegistry; import io.grpc.NameResolver; +import io.grpc.NameResolverRegistry; import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.SynchronizationContext; -import io.grpc.internal.ObjectPool; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.internal.FakeClock; +import io.grpc.testing.GrpcCleanupRule; +import io.grpc.util.GracefulSwitchLoadBalancerAccessor; import io.grpc.xds.CdsLoadBalancerProvider.CdsConfig; -import io.grpc.xds.ClusterResolverLoadBalancerProvider.ClusterResolverConfig; -import io.grpc.xds.ClusterResolverLoadBalancerProvider.ClusterResolverConfig.DiscoveryMechanism; -import io.grpc.xds.EnvoyServerProtoData.OutlierDetection; -import io.grpc.xds.EnvoyServerProtoData.SuccessRateEjection; -import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; -import io.grpc.xds.LeastRequestLoadBalancer.LeastRequestConfig; -import io.grpc.xds.RingHashLoadBalancer.RingHashConfig; -import io.grpc.xds.XdsClusterResource.CdsUpdate; -import io.grpc.xds.client.Bootstrapper.ServerInfo; +import io.grpc.xds.ClusterImplLoadBalancerProvider.ClusterImplConfig; import io.grpc.xds.client.XdsClient; -import io.grpc.xds.client.XdsResourceType; import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.concurrent.Executor; -import javax.annotation.Nullable; +import java.util.concurrent.TimeUnit; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -87,630 +105,544 @@ @RunWith(JUnit4.class) public class CdsLoadBalancer2Test { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + @Rule + public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); + private static final String SERVER_NAME = "example.com"; private static final String CLUSTER = "cluster-foo.googleapis.com"; private static final String EDS_SERVICE_NAME = "backend-service-1.googleapis.com"; - private static final String DNS_HOST_NAME = "backend-service-dns.googleapis.com:443"; - private static final ServerInfo LRS_SERVER_INFO = - ServerInfo.create("lrs.googleapis.com", InsecureChannelCredentials.create()); - private final UpstreamTlsContext upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true); - private final OutlierDetection outlierDetection = OutlierDetection.create( - null, null, null, null, SuccessRateEjection.create(null, null, null, null), null); - - - private static final SynchronizationContext syncContext = new SynchronizationContext( - new Thread.UncaughtExceptionHandler() { - @Override - public void uncaughtException(Thread t, Throwable e) { - throw new RuntimeException(e); - //throw new AssertionError(e); - } - }); + private static final String NODE_ID = "node-id"; + private final io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil.buildUpstreamTlsContext("cert-instance-name", true); + private static final Cluster EDS_CLUSTER = Cluster.newBuilder() + .setName(CLUSTER) + .setType(Cluster.DiscoveryType.EDS) + .setEdsClusterConfig(Cluster.EdsClusterConfig.newBuilder() + .setServiceName(EDS_SERVICE_NAME) + .setEdsConfig(ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.newBuilder()))) + .build(); + + private final FakeClock fakeClock = new FakeClock(); private final LoadBalancerRegistry lbRegistry = new LoadBalancerRegistry(); private final List childBalancers = new ArrayList<>(); - private final FakeXdsClient xdsClient = new FakeXdsClient(); - private final ObjectPool xdsClientPool = new ObjectPool() { - @Override - public XdsClient getObject() { - xdsClientRefs++; - return xdsClient; - } - - @Override - public XdsClient returnObject(Object object) { - xdsClientRefs--; - return null; - } - }; + private final XdsTestControlPlaneService controlPlaneService = new XdsTestControlPlaneService(); + private final XdsClient xdsClient = XdsTestUtils.createXdsClient( + Arrays.asList("control-plane.example.com"), + serverInfo -> new GrpcXdsTransportFactory.GrpcXdsTransport( + InProcessChannelBuilder + .forName(serverInfo.target()) + .directExecutor() + .build()), + fakeClock); + private XdsDependencyManager xdsDepManager; @Mock private Helper helper; @Captor private ArgumentCaptor pickerCaptor; - private int xdsClientRefs; - private CdsLoadBalancer2 loadBalancer; + private CdsLoadBalancer2 loadBalancer; + private XdsConfig lastXdsConfig; @Before - public void setUp() { - when(helper.getSynchronizationContext()).thenReturn(syncContext); - lbRegistry.register(new FakeLoadBalancerProvider(CLUSTER_RESOLVER_POLICY_NAME)); + public void setUp() throws Exception { + lbRegistry.register(new FakeLoadBalancerProvider(PRIORITY_POLICY_NAME)); + lbRegistry.register(new FakeLoadBalancerProvider(CLUSTER_IMPL_POLICY_NAME)); lbRegistry.register(new FakeLoadBalancerProvider("round_robin")); + lbRegistry.register(new FakeLoadBalancerProvider("outlier_detection_experimental")); lbRegistry.register( new FakeLoadBalancerProvider("ring_hash_experimental", new RingHashLoadBalancerProvider())); lbRegistry.register(new FakeLoadBalancerProvider("least_request_experimental", new LeastRequestLoadBalancerProvider())); - loadBalancer = new CdsLoadBalancer2(helper, lbRegistry); - loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(Collections.emptyList()) - .setAttributes( - // Other attributes not used by cluster_resolver LB are omitted. - Attributes.newBuilder() - .set(InternalXdsAttributes.XDS_CLIENT_POOL, xdsClientPool) - .build()) - .setLoadBalancingPolicyConfig(new CdsConfig(CLUSTER)) - .build()); - assertThat(Iterables.getOnlyElement(xdsClient.watchers.keySet())).isEqualTo(CLUSTER); + lbRegistry.register(new FakeLoadBalancerProvider("wrr_locality_experimental", + new WrrLocalityLoadBalancerProvider())); + CdsLoadBalancerProvider cdsLoadBalancerProvider = new CdsLoadBalancerProvider(lbRegistry); + lbRegistry.register(cdsLoadBalancerProvider); + loadBalancer = (CdsLoadBalancer2) cdsLoadBalancerProvider.newLoadBalancer(helper); + + cleanupRule.register(InProcessServerBuilder + .forName("control-plane.example.com") + .addService(controlPlaneService) + .directExecutor() + .build() + .start()); + + SynchronizationContext syncContext = new SynchronizationContext((t, e) -> { + throw new AssertionError(e); + }); + when(helper.getSynchronizationContext()).thenReturn(syncContext); + when(helper.getScheduledExecutorService()).thenReturn(fakeClock.getScheduledExecutorService()); + + NameResolver.Args nameResolverArgs = NameResolver.Args.newBuilder() + .setDefaultPort(8080) + .setProxyDetector((address) -> null) + .setSynchronizationContext(syncContext) + .setServiceConfigParser(mock(NameResolver.ServiceConfigParser.class)) + .setChannelLogger(mock(ChannelLogger.class)) + .setScheduledExecutorService(fakeClock.getScheduledExecutorService()) + .setNameResolverRegistry(new NameResolverRegistry()) + .build(); + + xdsDepManager = new XdsDependencyManager( + xdsClient, + syncContext, + SERVER_NAME, + SERVER_NAME, + nameResolverArgs); + + controlPlaneService.setXdsConfig(ADS_TYPE_URL_LDS, ImmutableMap.of( + SERVER_NAME, ControlPlaneRule.buildClientListener(SERVER_NAME, "my-route"))); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_RDS, ImmutableMap.of( + "my-route", XdsTestUtils.buildRouteConfiguration(SERVER_NAME, "my-route", CLUSTER))); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, ImmutableMap.of( + EDS_SERVICE_NAME, ControlPlaneRule.buildClusterLoadAssignment( + "127.0.0.1", "", 1234, EDS_SERVICE_NAME))); } @After public void tearDown() { - loadBalancer.shutdown(); - assertThat(xdsClient.watchers).isEmpty(); - assertThat(xdsClientRefs).isEqualTo(0); + if (loadBalancer != null) { + shutdownLoadBalancer(); + } assertThat(childBalancers).isEmpty(); + + if (xdsDepManager != null) { + xdsDepManager.shutdown(); + } + xdsClient.shutdown(); } - @Test - public void discoverTopLevelEdsCluster() { - CdsUpdate update = - CdsUpdate.forEds(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, upstreamTlsContext, - outlierDetection) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - assertThat(childBalancers).hasSize(1); - FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - assertThat(childBalancer.name).isEqualTo(CLUSTER_RESOLVER_POLICY_NAME); - ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; - assertThat(childLbConfig.discoveryMechanisms).hasSize(1); - DiscoveryMechanism instance = Iterables.getOnlyElement(childLbConfig.discoveryMechanisms); - assertDiscoveryMechanism(instance, CLUSTER, DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, - null, LRS_SERVER_INFO, 100L, upstreamTlsContext, outlierDetection); - assertThat(childLbConfig.lbPolicy.getProvider().getPolicyName()).isEqualTo("round_robin"); + private void shutdownLoadBalancer() { + LoadBalancer lb = this.loadBalancer; + this.loadBalancer = null; // Must avoid calling acceptResolvedAddresses after shutdown + lb.shutdown(); } @Test - public void discoverTopLevelLogicalDnsCluster() { - CdsUpdate update = - CdsUpdate.forLogicalDns(CLUSTER, DNS_HOST_NAME, LRS_SERVER_INFO, 100L, upstreamTlsContext) - .leastRequestLbPolicy(3).build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); + public void discoverTopLevelCluster() { + Cluster cluster = Cluster.newBuilder() + .setName(CLUSTER) + .setType(Cluster.DiscoveryType.EDS) + .setEdsClusterConfig(Cluster.EdsClusterConfig.newBuilder() + .setServiceName(EDS_SERVICE_NAME) + .setEdsConfig(ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.newBuilder()))) + .setLbPolicy(Cluster.LbPolicy.ROUND_ROBIN) + .setLrsServer(ConfigSource.newBuilder() + .setSelf(SelfConfigSource.getDefaultInstance())) + .setCircuitBreakers(CircuitBreakers.newBuilder() + .addThresholds(CircuitBreakers.Thresholds.newBuilder() + .setPriority(RoutingPriority.DEFAULT) + .setMaxRequests(UInt32Value.newBuilder().setValue(100)))) + .setTransportSocket(TransportSocket.newBuilder() + .setName("envoy.transport_sockets.tls") + .setTypedConfig(Any.pack(UpstreamTlsContext.newBuilder() + .setCommonTlsContext(upstreamTlsContext.getCommonTlsContext()) + .build()))) + .setOutlierDetection(OutlierDetection.getDefaultInstance()) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of(CLUSTER, cluster)); + startXdsDepManager(); + + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - assertThat(childBalancer.name).isEqualTo(CLUSTER_RESOLVER_POLICY_NAME); - ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; - assertThat(childLbConfig.discoveryMechanisms).hasSize(1); - DiscoveryMechanism instance = Iterables.getOnlyElement(childLbConfig.discoveryMechanisms); - assertDiscoveryMechanism(instance, CLUSTER, DiscoveryMechanism.Type.LOGICAL_DNS, null, - DNS_HOST_NAME, LRS_SERVER_INFO, 100L, upstreamTlsContext, null); - assertThat(childLbConfig.lbPolicy.getProvider().getPolicyName()) - .isEqualTo("least_request_experimental"); - assertThat(((LeastRequestConfig) childLbConfig.lbPolicy.getConfig()).choiceCount).isEqualTo(3); + assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); } @Test public void nonAggregateCluster_resourceNotExist_returnErrorPicker() { - xdsClient.deliverResourceNotExist(CLUSTER); + startXdsDepManager(); verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - Status unavailable = Status.UNAVAILABLE.withDescription( - "CDS error: found 0 leaf (logical DNS or EDS) clusters for root cluster " + CLUSTER); - assertPicker(pickerCaptor.getValue(), unavailable, null); + String expectedDescription = "Error retrieving CDS resource " + CLUSTER + + " nodeID: " + NODE_ID + + ": NOT_FOUND: Timed out waiting for resource " + CLUSTER + " from xDS server"; + Status unavailable = Status.UNAVAILABLE.withDescription(expectedDescription); + assertPickerStatus(pickerCaptor.getValue(), unavailable); assertThat(childBalancers).isEmpty(); } @Test public void nonAggregateCluster_resourceUpdate() { - CdsUpdate update = - CdsUpdate.forEds(CLUSTER, null, null, 100L, upstreamTlsContext, outlierDetection) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); + lbRegistry.register(new PriorityLoadBalancerProvider()); + Cluster cluster = EDS_CLUSTER.toBuilder() + .setCircuitBreakers(CircuitBreakers.newBuilder() + .addThresholds(CircuitBreakers.Thresholds.newBuilder() + .setPriority(RoutingPriority.DEFAULT) + .setMaxRequests(UInt32Value.newBuilder().setValue(100)))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of(CLUSTER, cluster)); + startXdsDepManager(); + + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; - DiscoveryMechanism instance = Iterables.getOnlyElement(childLbConfig.discoveryMechanisms); - assertDiscoveryMechanism(instance, CLUSTER, DiscoveryMechanism.Type.EDS, null, null, null, - 100L, upstreamTlsContext, outlierDetection); - - update = CdsUpdate.forEds(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, 200L, null, - outlierDetection).roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - childLbConfig = (ClusterResolverConfig) childBalancer.config; - instance = Iterables.getOnlyElement(childLbConfig.discoveryMechanisms); - assertDiscoveryMechanism(instance, CLUSTER, DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, - null, LRS_SERVER_INFO, 200L, null, outlierDetection); + ClusterImplConfig childLbConfig = (ClusterImplConfig) childBalancer.config; + assertThat(childLbConfig.cluster).isEqualTo(CLUSTER); + assertThat(childLbConfig.maxConcurrentRequests).isEqualTo(100L); + + cluster = EDS_CLUSTER.toBuilder() + .setCircuitBreakers(CircuitBreakers.newBuilder() + .addThresholds(CircuitBreakers.Thresholds.newBuilder() + .setPriority(RoutingPriority.DEFAULT) + .setMaxRequests(UInt32Value.newBuilder().setValue(200)))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of(CLUSTER, cluster)); + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); + assertThat(childBalancers).hasSize(1); + childBalancer = Iterables.getOnlyElement(childBalancers); + childLbConfig = (ClusterImplConfig) childBalancer.config; + assertThat(childLbConfig.maxConcurrentRequests).isEqualTo(200L); } @Test public void nonAggregateCluster_resourceRevoked() { - CdsUpdate update = - CdsUpdate.forLogicalDns(CLUSTER, DNS_HOST_NAME, null, 100L, upstreamTlsContext) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); + lbRegistry.register(new PriorityLoadBalancerProvider()); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of(CLUSTER, EDS_CLUSTER)); + startXdsDepManager(); + + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; - DiscoveryMechanism instance = Iterables.getOnlyElement(childLbConfig.discoveryMechanisms); - assertDiscoveryMechanism(instance, CLUSTER, DiscoveryMechanism.Type.LOGICAL_DNS, null, - DNS_HOST_NAME, null, 100L, upstreamTlsContext, null); + ClusterImplConfig childLbConfig = (ClusterImplConfig) childBalancer.config; + assertThat(childLbConfig.cluster).isEqualTo(CLUSTER); + + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of()); - xdsClient.deliverResourceNotExist(CLUSTER); assertThat(childBalancer.shutdown).isTrue(); - Status unavailable = Status.UNAVAILABLE.withDescription( - "CDS error: found 0 leaf (logical DNS or EDS) clusters for root cluster " + CLUSTER); + String expectedDescription = "Error retrieving CDS resource " + CLUSTER + + " nodeID: " + NODE_ID + + ": NOT_FOUND: Resource " + CLUSTER + " does not exist"; + Status unavailable = Status.UNAVAILABLE.withDescription(expectedDescription); verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - assertPicker(pickerCaptor.getValue(), unavailable, null); + assertPickerStatus(pickerCaptor.getValue(), unavailable); assertThat(childBalancer.shutdown).isTrue(); assertThat(childBalancers).isEmpty(); } @Test - public void discoverAggregateCluster() { - String cluster1 = "cluster-01.googleapis.com"; - String cluster2 = "cluster-02.googleapis.com"; - // CLUSTER (aggr.) -> [cluster1 (aggr.), cluster2 (logical DNS)] - CdsUpdate update = - CdsUpdate.forAggregate(CLUSTER, Arrays.asList(cluster1, cluster2)) - .ringHashLbPolicy(100L, 1000L).build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1, cluster2); - assertThat(childBalancers).isEmpty(); - String cluster3 = "cluster-03.googleapis.com"; - String cluster4 = "cluster-04.googleapis.com"; - // cluster1 (aggr.) -> [cluster3 (EDS), cluster4 (EDS)] - CdsUpdate update1 = - CdsUpdate.forAggregate(cluster1, Arrays.asList(cluster3, cluster4)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster1, update1); - assertThat(xdsClient.watchers.keySet()).containsExactly( - CLUSTER, cluster1, cluster2, cluster3, cluster4); - assertThat(childBalancers).isEmpty(); - CdsUpdate update3 = CdsUpdate.forEds(cluster3, EDS_SERVICE_NAME, LRS_SERVER_INFO, 200L, - upstreamTlsContext, outlierDetection).roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster3, update3); - assertThat(childBalancers).isEmpty(); - CdsUpdate update2 = - CdsUpdate.forLogicalDns(cluster2, DNS_HOST_NAME, null, 100L, null) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster2, update2); - assertThat(childBalancers).isEmpty(); - CdsUpdate update4 = - CdsUpdate.forEds(cluster4, null, LRS_SERVER_INFO, 300L, null, outlierDetection) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster4, update4); - assertThat(childBalancers).hasSize(1); // all non-aggregate clusters discovered + public void dynamicCluster() { + String clusterName = "cluster2"; + Cluster cluster = EDS_CLUSTER.toBuilder() + .setName(clusterName) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + clusterName, cluster, + CLUSTER, Cluster.newBuilder().setName(CLUSTER).build())); + startXdsDepManager(new CdsConfig(clusterName, /*dynamic=*/ true)); + + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); + assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - assertThat(childBalancer.name).isEqualTo(CLUSTER_RESOLVER_POLICY_NAME); - ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; - assertThat(childLbConfig.discoveryMechanisms).hasSize(3); - // Clusters on higher level has higher priority: [cluster2, cluster3, cluster4] - assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(0), cluster2, - DiscoveryMechanism.Type.LOGICAL_DNS, null, DNS_HOST_NAME, null, 100L, null, null); - assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(1), cluster3, - DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, null, LRS_SERVER_INFO, 200L, - upstreamTlsContext, outlierDetection); - assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(2), cluster4, - DiscoveryMechanism.Type.EDS, null, null, LRS_SERVER_INFO, 300L, null, outlierDetection); - assertThat(childLbConfig.lbPolicy.getProvider().getPolicyName()) - .isEqualTo("ring_hash_experimental"); // dominated by top-level cluster's config - assertThat(((RingHashConfig) childLbConfig.lbPolicy.getConfig()).minRingSize).isEqualTo(100L); - assertThat(((RingHashConfig) childLbConfig.lbPolicy.getConfig()).maxRingSize).isEqualTo(1000L); - } - - @Test - public void aggregateCluster_noNonAggregateClusterExits_returnErrorPicker() { - String cluster1 = "cluster-01.googleapis.com"; - // CLUSTER (aggr.) -> [cluster1 (EDS)] - CdsUpdate update = - CdsUpdate.forAggregate(CLUSTER, Collections.singletonList(cluster1)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1); - xdsClient.deliverResourceNotExist(cluster1); - verify(helper).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - Status unavailable = Status.UNAVAILABLE.withDescription( - "CDS error: found 0 leaf (logical DNS or EDS) clusters for root cluster " + CLUSTER); - assertPicker(pickerCaptor.getValue(), unavailable, null); - assertThat(childBalancers).isEmpty(); - } + assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); - @Test - public void aggregateCluster_descendantClustersRevoked() { - String cluster1 = "cluster-01.googleapis.com"; - String cluster2 = "cluster-02.googleapis.com"; - // CLUSTER (aggr.) -> [cluster1 (EDS), cluster2 (logical DNS)] - CdsUpdate update = - CdsUpdate.forAggregate(CLUSTER, Arrays.asList(cluster1, cluster2)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1, cluster2); - CdsUpdate update1 = CdsUpdate.forEds(cluster1, EDS_SERVICE_NAME, LRS_SERVER_INFO, 200L, - upstreamTlsContext, outlierDetection).roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster1, update1); - CdsUpdate update2 = - CdsUpdate.forLogicalDns(cluster2, DNS_HOST_NAME, LRS_SERVER_INFO, 100L, null) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster2, update2); - FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; - assertThat(childLbConfig.discoveryMechanisms).hasSize(2); - assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(0), cluster1, - DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, null, LRS_SERVER_INFO, 200L, - upstreamTlsContext, outlierDetection); - assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(1), cluster2, - DiscoveryMechanism.Type.LOGICAL_DNS, null, DNS_HOST_NAME, LRS_SERVER_INFO, 100L, null, - null); - - // Revoke cluster1, should still be able to proceed with cluster2. - xdsClient.deliverResourceNotExist(cluster1); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1, cluster2); - childLbConfig = (ClusterResolverConfig) childBalancer.config; - assertThat(childLbConfig.discoveryMechanisms).hasSize(1); - assertDiscoveryMechanism(Iterables.getOnlyElement(childLbConfig.discoveryMechanisms), cluster2, - DiscoveryMechanism.Type.LOGICAL_DNS, null, DNS_HOST_NAME, LRS_SERVER_INFO, 100L, null, - null); - verify(helper, never()).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), any(SubchannelPicker.class)); - - // All revoked. - xdsClient.deliverResourceNotExist(cluster2); - verify(helper).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - Status unavailable = Status.UNAVAILABLE.withDescription( - "CDS error: found 0 leaf (logical DNS or EDS) clusters for root cluster " + CLUSTER); - assertPicker(pickerCaptor.getValue(), unavailable, null); - assertThat(childBalancer.shutdown).isTrue(); - assertThat(childBalancers).isEmpty(); + assertThat(this.lastXdsConfig.getClusters()).containsKey(clusterName); + shutdownLoadBalancer(); + assertThat(this.lastXdsConfig.getClusters()).doesNotContainKey(clusterName); } @Test - public void aggregateCluster_rootClusterRevoked() { - String cluster1 = "cluster-01.googleapis.com"; - String cluster2 = "cluster-02.googleapis.com"; - // CLUSTER (aggr.) -> [cluster1 (EDS), cluster2 (logical DNS)] - CdsUpdate update = - CdsUpdate.forAggregate(CLUSTER, Arrays.asList(cluster1, cluster2)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1, cluster2); - CdsUpdate update1 = CdsUpdate.forEds(cluster1, EDS_SERVICE_NAME, LRS_SERVER_INFO, 200L, - upstreamTlsContext, outlierDetection).roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster1, update1); - CdsUpdate update2 = - CdsUpdate.forLogicalDns(cluster2, DNS_HOST_NAME, LRS_SERVER_INFO, 100L, null) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster2, update2); - FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; - assertThat(childLbConfig.discoveryMechanisms).hasSize(2); - assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(0), cluster1, - DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, null, LRS_SERVER_INFO, 200L, - upstreamTlsContext, outlierDetection); - assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(1), cluster2, - DiscoveryMechanism.Type.LOGICAL_DNS, null, DNS_HOST_NAME, LRS_SERVER_INFO, 100L, null, - null); - - xdsClient.deliverResourceNotExist(CLUSTER); - assertThat(xdsClient.watchers.keySet()) - .containsExactly(CLUSTER); // subscription to all descendant clusters cancelled - verify(helper).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - Status unavailable = Status.UNAVAILABLE.withDescription( - "CDS error: found 0 leaf (logical DNS or EDS) clusters for root cluster " + CLUSTER); - assertPicker(pickerCaptor.getValue(), unavailable, null); - assertThat(childBalancer.shutdown).isTrue(); - assertThat(childBalancers).isEmpty(); - } + public void discoverAggregateCluster_createsPriorityLbPolicy() { + CdsLoadBalancerProvider cdsLoadBalancerProvider = new CdsLoadBalancerProvider(lbRegistry); + lbRegistry.register(cdsLoadBalancerProvider); + loadBalancer = (CdsLoadBalancer2) cdsLoadBalancerProvider.newLoadBalancer(helper); - @Test - public void aggregateCluster_intermediateClusterChanges() { String cluster1 = "cluster-01.googleapis.com"; - // CLUSTER (aggr.) -> [cluster1] - CdsUpdate update = - CdsUpdate.forAggregate(CLUSTER, Collections.singletonList(cluster1)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1); - - // CLUSTER (aggr.) -> [cluster2 (aggr.)] String cluster2 = "cluster-02.googleapis.com"; - update = - CdsUpdate.forAggregate(CLUSTER, Collections.singletonList(cluster2)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster2); - - // cluster2 (aggr.) -> [cluster3 (EDS)] String cluster3 = "cluster-03.googleapis.com"; - CdsUpdate update2 = - CdsUpdate.forAggregate(cluster2, Collections.singletonList(cluster3)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster2, update2); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster2, cluster3); - CdsUpdate update3 = CdsUpdate.forEds(cluster3, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, - upstreamTlsContext, outlierDetection).roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster3, update3); + String cluster4 = "cluster-04.googleapis.com"; + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + // CLUSTER (aggr.) -> [cluster1 (aggr.), cluster2 (logical DNS), cluster3 (EDS)] + CLUSTER, Cluster.newBuilder() + .setName(CLUSTER) + .setClusterType(Cluster.CustomClusterType.newBuilder() + .setName("envoy.clusters.aggregate") + .setTypedConfig(Any.pack(ClusterConfig.newBuilder() + .addClusters(cluster1) + .addClusters(cluster2) + .addClusters(cluster3) + .build()))) + .setLbPolicy(Cluster.LbPolicy.RING_HASH) + .build(), + // cluster1 (aggr.) -> [cluster3 (EDS), cluster4 (EDS)] + cluster1, Cluster.newBuilder() + .setName(cluster1) + .setClusterType(Cluster.CustomClusterType.newBuilder() + .setName("envoy.clusters.aggregate") + .setTypedConfig(Any.pack(ClusterConfig.newBuilder() + .addClusters(cluster3) + .addClusters(cluster4) + .build()))) + .build(), + cluster2, Cluster.newBuilder() + .setName(cluster2) + .setType(Cluster.DiscoveryType.LOGICAL_DNS) + .setLoadAssignment(ClusterLoadAssignment.newBuilder() + .addEndpoints(LocalityLbEndpoints.newBuilder() + .addLbEndpoints(LbEndpoint.newBuilder() + .setEndpoint(Endpoint.newBuilder() + .setAddress(Address.newBuilder() + .setSocketAddress(SocketAddress.newBuilder() + .setAddress("dns.example.com") + .setPortValue(1111))))))) + .build(), + cluster3, EDS_CLUSTER.toBuilder() + .setName(cluster3) + .setCircuitBreakers(CircuitBreakers.newBuilder() + .addThresholds(CircuitBreakers.Thresholds.newBuilder() + .setPriority(RoutingPriority.DEFAULT) + .setMaxRequests(UInt32Value.newBuilder().setValue(100)))) + .build(), + cluster4, EDS_CLUSTER.toBuilder().setName(cluster4).build())); + startXdsDepManager(); + + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); + assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; - assertThat(childLbConfig.discoveryMechanisms).hasSize(1); - DiscoveryMechanism instance = Iterables.getOnlyElement(childLbConfig.discoveryMechanisms); - assertDiscoveryMechanism(instance, cluster3, DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, - null, LRS_SERVER_INFO, 100L, upstreamTlsContext, outlierDetection); - - // cluster2 revoked - xdsClient.deliverResourceNotExist(cluster2); - assertThat(xdsClient.watchers.keySet()) - .containsExactly(CLUSTER, cluster2); // cancelled subscription to cluster3 - verify(helper).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - Status unavailable = Status.UNAVAILABLE.withDescription( - "CDS error: found 0 leaf (logical DNS or EDS) clusters for root cluster " + CLUSTER); - assertPicker(pickerCaptor.getValue(), unavailable, null); - assertThat(childBalancer.shutdown).isTrue(); - assertThat(childBalancers).isEmpty(); + assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); + PriorityLoadBalancerProvider.PriorityLbConfig childLbConfig = + (PriorityLoadBalancerProvider.PriorityLbConfig) childBalancer.config; + assertThat(childLbConfig.priorities).hasSize(3); + assertThat(childLbConfig.priorities.get(0)).isEqualTo(cluster3); + assertThat(childLbConfig.priorities.get(1)).isEqualTo(cluster4); + assertThat(childLbConfig.priorities.get(2)).isEqualTo(cluster2); + assertThat(childLbConfig.childConfigs).hasSize(3); + PriorityLoadBalancerProvider.PriorityLbConfig.PriorityChildConfig childConfig3 = + childLbConfig.childConfigs.get(cluster3); + assertThat( + GracefulSwitchLoadBalancerAccessor.getChildProvider(childConfig3.childConfig) + .getPolicyName()) + .isEqualTo("cds_experimental"); + PriorityLoadBalancerProvider.PriorityLbConfig.PriorityChildConfig childConfig4 = + childLbConfig.childConfigs.get(cluster4); + assertThat( + GracefulSwitchLoadBalancerAccessor.getChildProvider(childConfig4.childConfig) + .getPolicyName()) + .isEqualTo("cds_experimental"); + PriorityLoadBalancerProvider.PriorityLbConfig.PriorityChildConfig childConfig2 = + childLbConfig.childConfigs.get(cluster2); + assertThat( + GracefulSwitchLoadBalancerAccessor.getChildProvider(childConfig2.childConfig) + .getPolicyName()) + .isEqualTo("cds_experimental"); } @Test - public void aggregateCluster_withLoops() { - String cluster1 = "cluster-01.googleapis.com"; - // CLUSTER (aggr.) -> [cluster1] - CdsUpdate update = - CdsUpdate.forAggregate(CLUSTER, Collections.singletonList(cluster1)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1); - - // CLUSTER (aggr.) -> [cluster2 (aggr.)] - String cluster2 = "cluster-02.googleapis.com"; - update = - CdsUpdate.forAggregate(cluster1, Collections.singletonList(cluster2)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster1, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1, cluster2); + // Both priorities will get tried using real priority LB policy. + public void discoverAggregateCluster_testChildCdsLbPolicyParsing() { + lbRegistry.register(new PriorityLoadBalancerProvider()); + CdsLoadBalancerProvider cdsLoadBalancerProvider = new CdsLoadBalancerProvider(lbRegistry); + lbRegistry.register(cdsLoadBalancerProvider); + loadBalancer = (CdsLoadBalancer2) cdsLoadBalancerProvider.newLoadBalancer(helper); - // cluster2 (aggr.) -> [cluster3 (EDS), cluster1 (parent), cluster2 (self), cluster3 (dup)] - String cluster3 = "cluster-03.googleapis.com"; - CdsUpdate update2 = - CdsUpdate.forAggregate(cluster2, Arrays.asList(cluster3, cluster1, cluster2, cluster3)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster2, update2); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1, cluster2, cluster3); - - reset(helper); - CdsUpdate update3 = CdsUpdate.forEds(cluster3, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, - upstreamTlsContext, outlierDetection).roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster3, update3); - verify(helper).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - Status unavailable = Status.UNAVAILABLE.withDescription( - "CDS error: circular aggregate clusters directly under cluster-02.googleapis.com for root" - + " cluster cluster-foo.googleapis.com, named [cluster-01.googleapis.com," - + " cluster-02.googleapis.com]"); - assertPicker(pickerCaptor.getValue(), unavailable, null); - } - - @Test - public void aggregateCluster_withLoops_afterEds() { String cluster1 = "cluster-01.googleapis.com"; - // CLUSTER (aggr.) -> [cluster1] - CdsUpdate update = - CdsUpdate.forAggregate(CLUSTER, Collections.singletonList(cluster1)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1); - - // CLUSTER (aggr.) -> [cluster2 (aggr.)] String cluster2 = "cluster-02.googleapis.com"; - update = - CdsUpdate.forAggregate(cluster1, Collections.singletonList(cluster2)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster1, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1, cluster2); - - String cluster3 = "cluster-03.googleapis.com"; - CdsUpdate update2 = - CdsUpdate.forAggregate(cluster2, Arrays.asList(cluster3)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster2, update2); - CdsUpdate update3 = CdsUpdate.forEds(cluster3, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, - upstreamTlsContext, outlierDetection).roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster3, update3); - - // cluster2 (aggr.) -> [cluster3 (EDS)] - CdsUpdate update2a = - CdsUpdate.forAggregate(cluster2, Arrays.asList(cluster3, cluster1, cluster2, cluster3)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster2, update2a); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1, cluster2, cluster3); - verify(helper).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - Status unavailable = Status.UNAVAILABLE.withDescription( - "CDS error: circular aggregate clusters directly under cluster-02.googleapis.com for root" - + " cluster cluster-foo.googleapis.com, named [cluster-01.googleapis.com," - + " cluster-02.googleapis.com]"); - assertPicker(pickerCaptor.getValue(), unavailable, null); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + // CLUSTER (aggr.) -> [cluster1 (EDS), cluster2 (EDS)] + CLUSTER, Cluster.newBuilder() + .setName(CLUSTER) + .setClusterType(Cluster.CustomClusterType.newBuilder() + .setName("envoy.clusters.aggregate") + .setTypedConfig(Any.pack(ClusterConfig.newBuilder() + .addClusters(cluster1) + .addClusters(cluster2) + .build()))) + .build(), + cluster1, EDS_CLUSTER.toBuilder().setName(cluster1).build(), + cluster2, EDS_CLUSTER.toBuilder().setName(cluster2).build())); + startXdsDepManager(); + + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); + assertThat(childBalancers).hasSize(2); + ClusterImplConfig cluster1ImplConfig = + (ClusterImplConfig) childBalancers.get(0).config; + assertThat(cluster1ImplConfig.cluster) + .isEqualTo("cluster-01.googleapis.com"); + assertThat(cluster1ImplConfig.edsServiceName) + .isEqualTo("backend-service-1.googleapis.com"); + ClusterImplConfig cluster2ImplConfig = + (ClusterImplConfig) childBalancers.get(1).config; + assertThat(cluster2ImplConfig.cluster) + .isEqualTo("cluster-02.googleapis.com"); + assertThat(cluster2ImplConfig.edsServiceName) + .isEqualTo("backend-service-1.googleapis.com"); } @Test - public void aggregateCluster_duplicateChildren() { - String cluster1 = "cluster-01.googleapis.com"; - String cluster2 = "cluster-02.googleapis.com"; - String cluster3 = "cluster-03.googleapis.com"; - String cluster4 = "cluster-04.googleapis.com"; - - // CLUSTER (aggr.) -> [cluster1] - CdsUpdate update = - CdsUpdate.forAggregate(CLUSTER, Collections.singletonList(cluster1)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1); - - // cluster1 (aggr) -> [cluster3 (EDS), cluster2 (aggr), cluster4 (aggr)] - CdsUpdate update1 = - CdsUpdate.forAggregate(cluster1, Arrays.asList(cluster3, cluster2, cluster4, cluster3)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster1, update1); - assertThat(xdsClient.watchers.keySet()).containsExactly( - cluster3, cluster4, cluster2, cluster1, CLUSTER); - xdsClient.watchers.values().forEach(list -> assertThat(list.size()).isEqualTo(1)); - - // cluster2 (agg) -> [cluster3 (EDS), cluster4 {agg}] with dups - CdsUpdate update2 = - CdsUpdate.forAggregate(cluster2, Arrays.asList(cluster3, cluster4, cluster3)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster2, update2); - - // Define EDS cluster - CdsUpdate update3 = CdsUpdate.forEds(cluster3, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, - upstreamTlsContext, outlierDetection).roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster3, update3); - - // cluster4 (agg) -> [cluster3 (EDS)] with dups (3 copies) - CdsUpdate update4 = - CdsUpdate.forAggregate(cluster4, Arrays.asList(cluster3, cluster3, cluster3)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster4, update4); - xdsClient.watchers.values().forEach(list -> assertThat(list.size()).isEqualTo(1)); - - FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; - assertThat(childLbConfig.discoveryMechanisms).hasSize(1); - DiscoveryMechanism instance = Iterables.getOnlyElement(childLbConfig.discoveryMechanisms); - assertDiscoveryMechanism(instance, cluster3, DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, - null, LRS_SERVER_INFO, 100L, upstreamTlsContext, outlierDetection); + public void aggregateCluster_noChildren() { + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + // CLUSTER (aggr.) -> [] + CLUSTER, Cluster.newBuilder() + .setName(CLUSTER) + .setClusterType(Cluster.CustomClusterType.newBuilder() + .setName("envoy.clusters.aggregate") + .setTypedConfig(Any.pack(ClusterConfig.newBuilder() + .build()))) + .build())); + startXdsDepManager(); + + verify(helper) + .updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); + PickResult result = pickerCaptor.getValue().pickSubchannel(mock(PickSubchannelArgs.class)); + Status actualStatus = result.getStatus(); + assertThat(actualStatus.getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(actualStatus.getDescription()) + .contains("aggregate ClusterConfig.clusters must not be empty"); + assertThat(childBalancers).isEmpty(); } @Test - public void aggregateCluster_discoveryErrorBeforeChildLbCreated_returnErrorPicker() { + public void aggregateCluster_noNonAggregateClusterExits_returnErrorPicker() { + lbRegistry.register(new PriorityLoadBalancerProvider()); + CdsLoadBalancerProvider cdsLoadBalancerProvider = new CdsLoadBalancerProvider(lbRegistry); + lbRegistry.register(cdsLoadBalancerProvider); + loadBalancer = (CdsLoadBalancer2) cdsLoadBalancerProvider.newLoadBalancer(helper); + String cluster1 = "cluster-01.googleapis.com"; - // CLUSTER (aggr.) -> [cluster1] - CdsUpdate update = - CdsUpdate.forAggregate(CLUSTER, Collections.singletonList(cluster1)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1); - Status error = Status.RESOURCE_EXHAUSTED.withDescription("OOM"); - xdsClient.deliverError(error); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + // CLUSTER (aggr.) -> [cluster1 (missing)] + CLUSTER, Cluster.newBuilder() + .setName(CLUSTER) + .setClusterType(Cluster.CustomClusterType.newBuilder() + .setName("envoy.clusters.aggregate") + .setTypedConfig(Any.pack(ClusterConfig.newBuilder() + .addClusters(cluster1) + .build()))) + .setLbPolicy(Cluster.LbPolicy.RING_HASH) + .build())); + startXdsDepManager(); + verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - Status expectedError = Status.UNAVAILABLE.withDescription( - "Unable to load CDS cluster-foo.googleapis.com. xDS server returned: " - + "RESOURCE_EXHAUSTED: OOM"); - assertPicker(pickerCaptor.getValue(), expectedError, null); + String expectedDescription = "Error retrieving CDS resource " + cluster1 + + " nodeID: " + NODE_ID + + ": NOT_FOUND: Timed out waiting for resource " + cluster1 + " from xDS server"; + Status status = Status.UNAVAILABLE.withDescription(expectedDescription); + assertPickerStatus(pickerCaptor.getValue(), status); assertThat(childBalancers).isEmpty(); } @Test - public void aggregateCluster_discoveryErrorAfterChildLbCreated_propagateToChildLb() { - String cluster1 = "cluster-01.googleapis.com"; - // CLUSTER (aggr.) -> [cluster1 (logical DNS)] - CdsUpdate update = - CdsUpdate.forAggregate(CLUSTER, Collections.singletonList(cluster1)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - CdsUpdate update1 = - CdsUpdate.forLogicalDns(cluster1, DNS_HOST_NAME, LRS_SERVER_INFO, 200L, null) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster1, update1); - FakeLoadBalancer childLb = Iterables.getOnlyElement(childBalancers); - ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childLb.config; - assertThat(childLbConfig.discoveryMechanisms).hasSize(1); - - Status error = Status.RESOURCE_EXHAUSTED.withDescription("OOM"); - xdsClient.deliverError(error); - assertThat(childLb.upstreamError.getCode()).isEqualTo(Status.Code.UNAVAILABLE); - assertThat(childLb.upstreamError.getDescription()).contains("RESOURCE_EXHAUSTED: OOM"); - assertThat(childLb.shutdown).isFalse(); // child LB may choose to keep working - } - - @Test - public void handleNameResolutionErrorFromUpstream_beforeChildLbCreated_returnErrorPicker() { - Status upstreamError = Status.UNAVAILABLE.withDescription("unreachable"); - loadBalancer.handleNameResolutionError(upstreamError); + public void handleNameResolutionErrorFromUpstream_beforeChildLbCreated_failingPicker() { + Status status = Status.UNAVAILABLE.withDescription("unreachable"); + loadBalancer.handleNameResolutionError(status); verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - assertPicker(pickerCaptor.getValue(), upstreamError, null); + assertPickerStatus(pickerCaptor.getValue(), status); } @Test public void handleNameResolutionErrorFromUpstream_afterChildLbCreated_fallThrough() { - CdsUpdate update = CdsUpdate.forEds(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, - upstreamTlsContext, outlierDetection).roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); + Cluster cluster = Cluster.newBuilder() + .setName(CLUSTER) + .setType(Cluster.DiscoveryType.EDS) + .setEdsClusterConfig(Cluster.EdsClusterConfig.newBuilder() + .setServiceName(EDS_SERVICE_NAME) + .setEdsConfig(ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.newBuilder()))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of(CLUSTER, cluster)); + startXdsDepManager(); + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); assertThat(childBalancer.shutdown).isFalse(); + loadBalancer.handleNameResolutionError(Status.UNAVAILABLE.withDescription("unreachable")); assertThat(childBalancer.upstreamError.getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(childBalancer.upstreamError.getDescription()).isEqualTo("unreachable"); - verify(helper, never()).updateBalancingState( - any(ConnectivityState.class), any(SubchannelPicker.class)); + verify(helper).updateBalancingState( + eq(ConnectivityState.CONNECTING), any(SubchannelPicker.class)); } @Test public void unknownLbProvider() { - try { - xdsClient.deliverCdsUpdate(CLUSTER, - CdsUpdate.forEds(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, upstreamTlsContext, - outlierDetection) - .lbPolicyConfig(ImmutableMap.of("unknown", ImmutableMap.of("foo", "bar"))).build()); - } catch (Exception e) { - assertThat(e).hasMessageThat().contains("No provider available"); - return; - } - fail("Expected the unknown LB to cause an exception"); + Cluster cluster = Cluster.newBuilder() + .setName(CLUSTER) + .setType(Cluster.DiscoveryType.EDS) + .setEdsClusterConfig(Cluster.EdsClusterConfig.newBuilder() + .setServiceName(EDS_SERVICE_NAME) + .setEdsConfig(ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.newBuilder()))) + .setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder() + .addPolicies(Policy.newBuilder() + .setTypedExtensionConfig(TypedExtensionConfig.newBuilder() + .setTypedConfig(Any.pack(TypedStruct.newBuilder() + .setTypeUrl("type.googleapis.com/unknownLb") + .setValue(Struct.getDefaultInstance()) + .build()))))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of(CLUSTER, cluster)); + startXdsDepManager(); + verify(helper).updateBalancingState( + eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); + PickResult result = pickerCaptor.getValue().pickSubchannel(mock(PickSubchannelArgs.class)); + Status actualStatus = result.getStatus(); + assertThat(actualStatus.getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(actualStatus.getDescription()).contains("Invalid LoadBalancingPolicy"); } @Test public void invalidLbConfig() { - try { - xdsClient.deliverCdsUpdate(CLUSTER, - CdsUpdate.forEds(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, upstreamTlsContext, - outlierDetection).lbPolicyConfig( - ImmutableMap.of("ring_hash_experimental", ImmutableMap.of("minRingSize", "-1"))) + Cluster cluster = Cluster.newBuilder() + .setName(CLUSTER) + .setType(Cluster.DiscoveryType.EDS) + .setEdsClusterConfig(Cluster.EdsClusterConfig.newBuilder() + .setServiceName(EDS_SERVICE_NAME) + .setEdsConfig(ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.newBuilder()))) + .setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder() + .addPolicies(Policy.newBuilder() + .setTypedExtensionConfig(TypedExtensionConfig.newBuilder() + .setTypedConfig(Any.pack(TypedStruct.newBuilder() + .setTypeUrl("type.googleapis.com/ring_hash_experimental") + .setValue(Struct.newBuilder() + .putFields("minRingSize", Value.newBuilder().setNumberValue(-1).build())) + .build()))))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of(CLUSTER, cluster)); + startXdsDepManager(); + verify(helper).updateBalancingState( + eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); + PickResult result = pickerCaptor.getValue().pickSubchannel(mock(PickSubchannelArgs.class)); + Status actualStatus = result.getStatus(); + assertThat(actualStatus.getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(actualStatus.getDescription()).contains("Invalid 'minRingSize'"); + } + + private void startXdsDepManager() { + startXdsDepManager(new CdsConfig(CLUSTER)); + } + + private void startXdsDepManager(final CdsConfig cdsConfig) { + xdsDepManager.start( + xdsConfig -> { + if (!xdsConfig.hasValue()) { + throw new AssertionError("" + xdsConfig.getStatus()); + } + this.lastXdsConfig = xdsConfig.getValue(); + if (loadBalancer == null) { + return; + } + loadBalancer.acceptResolvedAddresses(ResolvedAddresses.newBuilder() + .setAddresses(Collections.emptyList()) + .setAttributes(Attributes.newBuilder() + .set(XdsAttributes.XDS_CONFIG, xdsConfig.getValue()) + .set(XdsAttributes.XDS_CLUSTER_SUBSCRIPT_REGISTRY, xdsDepManager) + .build()) + .setLoadBalancingPolicyConfig(cdsConfig) .build()); - } catch (Exception e) { - assertThat(e).hasMessageThat().contains("Unable to parse"); - return; - } - fail("Expected the invalid config to cause an exception"); + }); + // trigger does not exist timer, so broken config is more obvious + fakeClock.forwardTime(10, TimeUnit.MINUTES); } - private static void assertPicker(SubchannelPicker picker, Status expectedStatus, - @Nullable Subchannel expectedSubchannel) { + private static void assertPickerStatus(SubchannelPicker picker, Status expectedStatus) { PickResult result = picker.pickSubchannel(mock(PickSubchannelArgs.class)); Status actualStatus = result.getStatus(); assertThat(actualStatus.getCode()).isEqualTo(expectedStatus.getCode()); assertThat(actualStatus.getDescription()).isEqualTo(expectedStatus.getDescription()); - if (actualStatus.isOk()) { - assertThat(result.getSubchannel()).isSameInstanceAs(expectedSubchannel); - } - } - - private static void assertDiscoveryMechanism(DiscoveryMechanism instance, String name, - DiscoveryMechanism.Type type, @Nullable String edsServiceName, @Nullable String dnsHostName, - @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests, - @Nullable UpstreamTlsContext tlsContext, @Nullable OutlierDetection outlierDetection) { - assertThat(instance.cluster).isEqualTo(name); - assertThat(instance.type).isEqualTo(type); - assertThat(instance.edsServiceName).isEqualTo(edsServiceName); - assertThat(instance.dnsHostName).isEqualTo(dnsHostName); - assertThat(instance.lrsServerInfo).isEqualTo(lrsServerInfo); - assertThat(instance.maxConcurrentRequests).isEqualTo(maxConcurrentRequests); - assertThat(instance.tlsContext).isEqualTo(tlsContext); - assertThat(instance.outlierDetection).isEqualTo(outlierDetection); } private final class FakeLoadBalancerProvider extends LoadBalancerProvider { @@ -769,8 +701,9 @@ private final class FakeLoadBalancer extends LoadBalancer { } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { config = resolvedAddresses.getLoadBalancingPolicyConfig(); + return Status.OK; } @Override @@ -784,53 +717,4 @@ public void shutdown() { childBalancers.remove(this); } } - - private static final class FakeXdsClient extends XdsClient { - // watchers needs to support any non-cyclic shaped graphs - private final Map>> watchers = new HashMap<>(); - - @Override - @SuppressWarnings("unchecked") - public void watchXdsResource(XdsResourceType type, - String resourceName, - ResourceWatcher watcher, Executor syncContext) { - assertThat(type.typeName()).isEqualTo("CDS"); - watchers.computeIfAbsent(resourceName, k -> new ArrayList<>()) - .add((ResourceWatcher)watcher); - } - - @Override - public void cancelXdsResourceWatch(XdsResourceType type, - String resourceName, - ResourceWatcher watcher) { - assertThat(type.typeName()).isEqualTo("CDS"); - assertThat(watchers).containsKey(resourceName); - List> watcherList = watchers.get(resourceName); - assertThat(watcherList.remove(watcher)).isTrue(); - if (watcherList.isEmpty()) { - watchers.remove(resourceName); - } - } - - private void deliverCdsUpdate(String clusterName, CdsUpdate update) { - if (watchers.containsKey(clusterName)) { - List> resourceWatchers = - ImmutableList.copyOf(watchers.get(clusterName)); - resourceWatchers.forEach(w -> w.onChanged(update)); - } - } - - private void deliverResourceNotExist(String clusterName) { - if (watchers.containsKey(clusterName)) { - ImmutableList.copyOf(watchers.get(clusterName)) - .forEach(w -> w.onResourceDoesNotExist(clusterName)); - } - } - - private void deliverError(Status error) { - watchers.values().stream() - .flatMap(List::stream) - .forEach(w -> w.onError(error)); - } - } } diff --git a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java index 662430bef52..9277675385a 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java @@ -16,37 +16,51 @@ package io.grpc.xds; +import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.Truth.assertWithMessage; +import static io.grpc.xds.ClusterImplLoadBalancer.ATTR_SUBCHANNEL_ADDRESS_NAME; +import static io.grpc.xds.XdsNameResolver.AUTO_HOST_REWRITE_KEY; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import com.github.xds.data.orca.v3.OrcaLoadReport; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import io.grpc.Attributes; +import io.grpc.CallOptions; import io.grpc.ClientStreamTracer; import io.grpc.ConnectivityState; +import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; import io.grpc.InsecureChannelCredentials; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.CreateSubchannelArgs; +import io.grpc.LoadBalancer.FixedResultPicker; import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.PickDetailsConsumer; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.ResolvedAddresses; import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; +import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.LoadBalancerProvider; +import io.grpc.LoadBalancerRegistry; import io.grpc.ManagedChannel; import io.grpc.Metadata; +import io.grpc.NameResolver; import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.SynchronizationContext; import io.grpc.internal.FakeClock; -import io.grpc.internal.ObjectPool; -import io.grpc.internal.ServiceConfigUtil.PolicySelection; +import io.grpc.internal.PickFirstLoadBalancerProvider; +import io.grpc.internal.PickSubchannelArgsImpl; import io.grpc.protobuf.ProtoUtils; +import io.grpc.testing.TestMethodDescriptors; +import io.grpc.util.GracefulSwitchLoadBalancer; import io.grpc.xds.ClusterImplLoadBalancerProvider.ClusterImplConfig; import io.grpc.xds.Endpoints.DropOverload; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; @@ -54,6 +68,7 @@ import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedPolicySelection; import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedTargetConfig; import io.grpc.xds.XdsNameResolverProvider.CallCounterProvider; +import io.grpc.xds.client.BackendMetricPropagation; import io.grpc.xds.client.Bootstrapper.ServerInfo; import io.grpc.xds.client.LoadReportClient; import io.grpc.xds.client.LoadStatsManager2; @@ -63,7 +78,9 @@ import io.grpc.xds.client.Stats.ClusterStats; import io.grpc.xds.client.Stats.UpstreamLocalityStats; import io.grpc.xds.client.XdsClient; +import io.grpc.xds.internal.XdsInternalAttributes; import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; +import io.grpc.xds.internal.security.SecurityProtocolNegotiators; import io.grpc.xds.internal.security.SslContextProvider; import io.grpc.xds.internal.security.SslContextProviderSupplier; import java.net.SocketAddress; @@ -71,9 +88,11 @@ import java.util.Arrays; import java.util.Collections; import java.util.HashMap; +import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Queue; import java.util.concurrent.atomic.AtomicLong; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -114,33 +133,23 @@ public void uncaughtException(Thread t, Throwable e) { private final FakeClock fakeClock = new FakeClock(); private final Locality locality = Locality.create("test-region", "test-zone", "test-subzone"); - private final PolicySelection roundRobin = - new PolicySelection(new FakeLoadBalancerProvider("round_robin"), null); + private final Object roundRobin = GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + new FakeLoadBalancerProvider("round_robin"), null); private final List downstreamBalancers = new ArrayList<>(); private final FakeTlsContextManager tlsContextManager = new FakeTlsContextManager(); private final LoadStatsManager2 loadStatsManager = new LoadStatsManager2(fakeClock.getStopwatchSupplier()); private final FakeXdsClient xdsClient = new FakeXdsClient(); - private final ObjectPool xdsClientPool = new ObjectPool() { - @Override - public XdsClient getObject() { - xdsClientRefs++; - return xdsClient; - } - - @Override - public XdsClient returnObject(Object object) { - xdsClientRefs--; - return null; - } - }; private final CallCounterProvider callCounterProvider = new CallCounterProvider() { @Override public AtomicLong getOrCreate(String cluster, @Nullable String edsServiceName) { return new AtomicLong(); } }; - private final Helper helper = new FakeLbHelper(); + private final FakeLbHelper helper = new FakeLbHelper(); + private PickSubchannelArgs pickSubchannelArgs = new PickSubchannelArgsImpl( + TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT, + new PickDetailsConsumer() {}); @Mock private ThreadSafeRandom mockRandom; private int xdsClientRefs; @@ -169,14 +178,17 @@ public void handleResolvedAddresses_propagateToChildPolicy() { Object weightedTargetConfig = new Object(); ClusterImplConfig config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, null, Collections.emptyList(), - new PolicySelection(weightedTargetProvider, weightedTargetConfig), null); + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + weightedTargetProvider, weightedTargetConfig), + null, Collections.emptyMap(), null); EquivalentAddressGroup endpoint = makeAddress("endpoint-addr", locality); deliverAddressesAndConfig(Collections.singletonList(endpoint), config); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(downstreamBalancers); assertThat(Iterables.getOnlyElement(childBalancer.addresses)).isEqualTo(endpoint); assertThat(childBalancer.config).isSameInstanceAs(weightedTargetConfig); - assertThat(childBalancer.attributes.get(InternalXdsAttributes.XDS_CLIENT_POOL)) - .isSameInstanceAs(xdsClientPool); + assertThat(childBalancer.attributes.get(io.grpc.xds.XdsAttributes.XDS_CLIENT)) + .isSameInstanceAs(xdsClient); + assertThat(childBalancer.attributes.get(NameResolver.ATTR_BACKEND_SERVICE)).isEqualTo(CLUSTER); } /** @@ -194,7 +206,9 @@ public void handleResolvedAddresses_childPolicyChanges() { ClusterImplConfig configWithWeightedTarget = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, null, Collections.emptyList(), - new PolicySelection(weightedTargetProvider, weightedTargetConfig), null); + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + weightedTargetProvider, weightedTargetConfig), + null, Collections.emptyMap(), null); EquivalentAddressGroup endpoint = makeAddress("endpoint-addr", locality); deliverAddressesAndConfig(Collections.singletonList(endpoint), configWithWeightedTarget); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(downstreamBalancers); @@ -207,7 +221,9 @@ public void handleResolvedAddresses_childPolicyChanges() { ClusterImplConfig configWithWrrLocality = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, null, Collections.emptyList(), - new PolicySelection(wrrLocalityProvider, wrrLocalityConfig), null); + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + wrrLocalityProvider, wrrLocalityConfig), + null, Collections.emptyMap(), null); deliverAddressesAndConfig(Collections.singletonList(endpoint), configWithWrrLocality); childBalancer = Iterables.getOnlyElement(downstreamBalancers); assertThat(childBalancer.name).isEqualTo(XdsLbPolicies.WRR_LOCALITY_POLICY_NAME); @@ -218,7 +234,7 @@ public void handleResolvedAddresses_childPolicyChanges() { public void nameResolutionError_beforeChildPolicyInstantiated_returnErrorPickerToUpstream() { loadBalancer.handleNameResolutionError(Status.UNIMPLEMENTED.withDescription("not found")); assertThat(currentState).isEqualTo(ConnectivityState.TRANSIENT_FAILURE); - PickResult result = currentPicker.pickSubchannel(mock(PickSubchannelArgs.class)); + PickResult result = currentPicker.pickSubchannel(pickSubchannelArgs); assertThat(result.getStatus().isOk()).isFalse(); assertThat(result.getStatus().getCode()).isEqualTo(Code.UNIMPLEMENTED); assertThat(result.getStatus().getDescription()).isEqualTo("not found"); @@ -231,7 +247,9 @@ public void nameResolutionError_afterChildPolicyInstantiated_propagateToDownstre Object weightedTargetConfig = new Object(); ClusterImplConfig config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, null, Collections.emptyList(), - new PolicySelection(weightedTargetProvider, weightedTargetConfig), null); + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + weightedTargetProvider, weightedTargetConfig), + null, Collections.emptyMap(), null); EquivalentAddressGroup endpoint = makeAddress("endpoint-addr", locality); deliverAddressesAndConfig(Collections.singletonList(endpoint), config); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(downstreamBalancers); @@ -243,6 +261,61 @@ public void nameResolutionError_afterChildPolicyInstantiated_propagateToDownstre .isEqualTo("cannot reach server"); } + @Test + public void pick_addsOptionalLabels() { + LoadBalancerProvider weightedTargetProvider = new WeightedTargetLoadBalancerProvider(); + WeightedTargetConfig weightedTargetConfig = + buildWeightedTargetConfig(ImmutableMap.of(locality, 10)); + ClusterImplConfig config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, + null, Collections.emptyList(), + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + weightedTargetProvider, weightedTargetConfig), + null, Collections.emptyMap(), null); + EquivalentAddressGroup endpoint = makeAddress("endpoint-addr", locality); + deliverAddressesAndConfig(Collections.singletonList(endpoint), config); + FakeLoadBalancer leafBalancer = Iterables.getOnlyElement(downstreamBalancers); + leafBalancer.createSubChannel(); + FakeSubchannel fakeSubchannel = helper.subchannels.poll(); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.CONNECTING)); + fakeSubchannel.setConnectedEagIndex(0); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + assertThat(currentState).isEqualTo(ConnectivityState.READY); + + PickDetailsConsumer detailsConsumer = mock(PickDetailsConsumer.class); + pickSubchannelArgs = new PickSubchannelArgsImpl( + TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT, detailsConsumer); + PickResult result = currentPicker.pickSubchannel(pickSubchannelArgs); + assertThat(result.getStatus().isOk()).isTrue(); + // The value will be determined by the parent policy, so can be different than the value used in + // makeAddress() for the test. + verify(detailsConsumer).addOptionalLabel("grpc.lb.locality", locality.toString()); + verify(detailsConsumer).addOptionalLabel("grpc.lb.backend_service", CLUSTER); + } + + @Test + public void pick_noResult_addsClusterLabel() { + LoadBalancerProvider weightedTargetProvider = new WeightedTargetLoadBalancerProvider(); + WeightedTargetConfig weightedTargetConfig = + buildWeightedTargetConfig(ImmutableMap.of(locality, 10)); + ClusterImplConfig config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, + null, Collections.emptyList(), + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + weightedTargetProvider, weightedTargetConfig), + null, Collections.emptyMap(), null); + EquivalentAddressGroup endpoint = makeAddress("endpoint-addr", locality); + deliverAddressesAndConfig(Collections.singletonList(endpoint), config); + FakeLoadBalancer leafBalancer = Iterables.getOnlyElement(downstreamBalancers); + leafBalancer.deliverSubchannelState(PickResult.withNoResult(), ConnectivityState.CONNECTING); + assertThat(currentState).isEqualTo(ConnectivityState.CONNECTING); + + PickDetailsConsumer detailsConsumer = mock(PickDetailsConsumer.class); + pickSubchannelArgs = new PickSubchannelArgsImpl( + TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT, detailsConsumer); + PickResult result = currentPicker.pickSubchannel(pickSubchannelArgs); + assertThat(result.getStatus().isOk()).isTrue(); + verify(detailsConsumer).addOptionalLabel("grpc.lb.backend_service", CLUSTER); + } + @Test public void recordLoadStats() { LoadBalancerProvider weightedTargetProvider = new WeightedTargetLoadBalancerProvider(); @@ -250,15 +323,19 @@ public void recordLoadStats() { buildWeightedTargetConfig(ImmutableMap.of(locality, 10)); ClusterImplConfig config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, null, Collections.emptyList(), - new PolicySelection(weightedTargetProvider, weightedTargetConfig), null); + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + weightedTargetProvider, weightedTargetConfig), + null, Collections.emptyMap(), null); EquivalentAddressGroup endpoint = makeAddress("endpoint-addr", locality); deliverAddressesAndConfig(Collections.singletonList(endpoint), config); FakeLoadBalancer leafBalancer = Iterables.getOnlyElement(downstreamBalancers); - Subchannel subchannel = leafBalancer.helper.createSubchannel( - CreateSubchannelArgs.newBuilder().setAddresses(leafBalancer.addresses).build()); - leafBalancer.deliverSubchannelState(subchannel, ConnectivityState.READY); + Subchannel subchannel = leafBalancer.createSubChannel(); + FakeSubchannel fakeSubchannel = helper.subchannels.poll(); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.CONNECTING)); + fakeSubchannel.setConnectedEagIndex(0); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); assertThat(currentState).isEqualTo(ConnectivityState.READY); - PickResult result = currentPicker.pickSubchannel(mock(PickSubchannelArgs.class)); + PickResult result = currentPicker.pickSubchannel(pickSubchannelArgs); assertThat(result.getStatus().isOk()).isTrue(); ClientStreamTracer streamTracer1 = result.getStreamTracerFactory().newClientStreamTracer( ClientStreamTracer.StreamInfo.newBuilder().build(), new Metadata()); // first RPC call @@ -311,7 +388,7 @@ public void recordLoadStats() { TOLERANCE).of(0.009); streamTracer3.streamClosed(Status.OK); - subchannel.shutdown(); // stats recorder released + subchannel.shutdown(); // stats recorder released clusterStats = Iterables.getOnlyElement(loadStatsManager.getClusterStatsReports(CLUSTER)); // Locality load is reported for one last time in case of loads occurred since the previous // load report. @@ -327,6 +404,213 @@ public void recordLoadStats() { assertThat(clusterStats.upstreamLocalityStatsList()).isEmpty(); // no longer reported } + @Test + public void recordLoadStats_orcaLrsPropagationEnabled() { + boolean originalVal = LoadStatsManager2.isEnabledOrcaLrsPropagation; + LoadStatsManager2.isEnabledOrcaLrsPropagation = true; + BackendMetricPropagation backendMetricPropagation = BackendMetricPropagation.fromMetricSpecs( + Arrays.asList("application_utilization", "cpu_utilization", "named_metrics.named1")); + LoadBalancerProvider weightedTargetProvider = new WeightedTargetLoadBalancerProvider(); + WeightedTargetConfig weightedTargetConfig = + buildWeightedTargetConfig(ImmutableMap.of(locality, 10)); + ClusterImplConfig config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, + null, Collections.emptyList(), + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + weightedTargetProvider, weightedTargetConfig), + null, Collections.emptyMap(), backendMetricPropagation); + EquivalentAddressGroup endpoint = makeAddress("endpoint-addr", locality); + deliverAddressesAndConfig(Collections.singletonList(endpoint), config); + FakeLoadBalancer leafBalancer = Iterables.getOnlyElement(downstreamBalancers); + Subchannel subchannel = leafBalancer.createSubChannel(); + FakeSubchannel fakeSubchannel = helper.subchannels.poll(); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.CONNECTING)); + fakeSubchannel.setConnectedEagIndex(0); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + assertThat(currentState).isEqualTo(ConnectivityState.READY); + PickResult result = currentPicker.pickSubchannel(pickSubchannelArgs); + assertThat(result.getStatus().isOk()).isTrue(); + ClientStreamTracer streamTracer = result.getStreamTracerFactory().newClientStreamTracer( + ClientStreamTracer.StreamInfo.newBuilder().build(), new Metadata()); + Metadata trailersWithOrcaLoadReport = new Metadata(); + trailersWithOrcaLoadReport.put(ORCA_ENDPOINT_LOAD_METRICS_KEY, + OrcaLoadReport.newBuilder() + .setApplicationUtilization(1.414) + .setCpuUtilization(0.5) + .setMemUtilization(0.034) + .putNamedMetrics("named1", 3.14159) + .putNamedMetrics("named2", -1.618).build()); + streamTracer.inboundTrailers(trailersWithOrcaLoadReport); + streamTracer.streamClosed(Status.OK); + ClusterStats clusterStats = + Iterables.getOnlyElement(loadStatsManager.getClusterStatsReports(CLUSTER)); + UpstreamLocalityStats localityStats = + Iterables.getOnlyElement(clusterStats.upstreamLocalityStatsList()); + + assertThat(localityStats.loadMetricStatsMap()).containsKey("application_utilization"); + assertThat(localityStats.loadMetricStatsMap().get("application_utilization").totalMetricValue()) + .isWithin(TOLERANCE).of(1.414); + assertThat(localityStats.loadMetricStatsMap()).containsKey("cpu_utilization"); + assertThat(localityStats.loadMetricStatsMap().get("cpu_utilization").totalMetricValue()) + .isWithin(TOLERANCE).of(0.5); + assertThat(localityStats.loadMetricStatsMap()).doesNotContainKey("mem_utilization"); + assertThat(localityStats.loadMetricStatsMap()).containsKey("named_metrics.named1"); + assertThat(localityStats.loadMetricStatsMap().get("named_metrics.named1").totalMetricValue()) + .isWithin(TOLERANCE).of(3.14159); + assertThat(localityStats.loadMetricStatsMap()).doesNotContainKey("named_metrics.named2"); + subchannel.shutdown(); + LoadStatsManager2.isEnabledOrcaLrsPropagation = originalVal; + } + + @Test + public void recordLoadStats_orcaLrsPropagationDisabled() { + boolean originalVal = LoadStatsManager2.isEnabledOrcaLrsPropagation; + LoadStatsManager2.isEnabledOrcaLrsPropagation = false; + BackendMetricPropagation backendMetricPropagation = BackendMetricPropagation.fromMetricSpecs( + Arrays.asList("application_utilization", "cpu_utilization", "named_metrics.named1")); + LoadBalancerProvider weightedTargetProvider = new WeightedTargetLoadBalancerProvider(); + WeightedTargetConfig weightedTargetConfig = + buildWeightedTargetConfig(ImmutableMap.of(locality, 10)); + ClusterImplConfig config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, + null, Collections.emptyList(), + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + weightedTargetProvider, weightedTargetConfig), + null, Collections.emptyMap(), backendMetricPropagation); + EquivalentAddressGroup endpoint = makeAddress("endpoint-addr", locality); + deliverAddressesAndConfig(Collections.singletonList(endpoint), config); + FakeLoadBalancer leafBalancer = Iterables.getOnlyElement(downstreamBalancers); + Subchannel subchannel = leafBalancer.createSubChannel(); + FakeSubchannel fakeSubchannel = helper.subchannels.poll(); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.CONNECTING)); + fakeSubchannel.setConnectedEagIndex(0); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + assertThat(currentState).isEqualTo(ConnectivityState.READY); + PickResult result = currentPicker.pickSubchannel(pickSubchannelArgs); + assertThat(result.getStatus().isOk()).isTrue(); + ClientStreamTracer streamTracer = result.getStreamTracerFactory().newClientStreamTracer( + ClientStreamTracer.StreamInfo.newBuilder().build(), new Metadata()); + Metadata trailersWithOrcaLoadReport = new Metadata(); + trailersWithOrcaLoadReport.put(ORCA_ENDPOINT_LOAD_METRICS_KEY, + OrcaLoadReport.newBuilder() + .setApplicationUtilization(1.414) + .setCpuUtilization(0.5) + .setMemUtilization(0.034) + .putNamedMetrics("named1", 3.14159) + .putNamedMetrics("named2", -1.618).build()); + streamTracer.inboundTrailers(trailersWithOrcaLoadReport); + streamTracer.streamClosed(Status.OK); + ClusterStats clusterStats = + Iterables.getOnlyElement(loadStatsManager.getClusterStatsReports(CLUSTER)); + UpstreamLocalityStats localityStats = + Iterables.getOnlyElement(clusterStats.upstreamLocalityStatsList()); + + assertThat(localityStats.loadMetricStatsMap()).doesNotContainKey("application_utilization"); + assertThat(localityStats.loadMetricStatsMap()).doesNotContainKey("cpu_utilization"); + assertThat(localityStats.loadMetricStatsMap()).doesNotContainKey("mem_utilization"); + assertThat(localityStats.loadMetricStatsMap()).doesNotContainKey("named_metrics.named1"); + assertThat(localityStats.loadMetricStatsMap()).doesNotContainKey("named_metrics.named2"); + assertThat(localityStats.loadMetricStatsMap().containsKey("named1")).isTrue(); + assertThat(localityStats.loadMetricStatsMap().containsKey("named2")).isTrue(); + subchannel.shutdown(); + LoadStatsManager2.isEnabledOrcaLrsPropagation = originalVal; + } + + // Verifies https://github.com/grpc/grpc-java/issues/11434. + @Test + public void pickFirstLoadReport_onUpdateAddress() { + Locality locality1 = + Locality.create("test-region", "test-zone", "test-subzone"); + Locality locality2 = + Locality.create("other-region", "other-zone", "other-subzone"); + + LoadBalancerProvider pickFirstProvider = LoadBalancerRegistry + .getDefaultRegistry().getProvider("pick_first"); + Object pickFirstConfig = pickFirstProvider.parseLoadBalancingPolicyConfig(new HashMap<>()) + .getConfig(); + ClusterImplConfig config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, + null, Collections.emptyList(), + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig(pickFirstProvider, + pickFirstConfig), + null, Collections.emptyMap(), null); + EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr1", locality1); + EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr2", locality2); + deliverAddressesAndConfig(Arrays.asList(endpoint1, endpoint2), config); + + // Leaf balancer is created by Pick First. Get FakeSubchannel created to update attributes + // A real subchannel would get these attributes from the connected address's EAG locality. + FakeSubchannel fakeSubchannel = helper.subchannels.poll(); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.CONNECTING)); + fakeSubchannel.setConnectedEagIndex(0); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + assertThat(currentState).isEqualTo(ConnectivityState.READY); + PickResult result = currentPicker.pickSubchannel(pickSubchannelArgs); + assertThat(result.getStatus().isOk()).isTrue(); + + ClientStreamTracer streamTracer1 = result.getStreamTracerFactory().newClientStreamTracer( + ClientStreamTracer.StreamInfo.newBuilder().build(), new Metadata()); // first RPC call + streamTracer1.streamClosed(Status.OK); + + ClusterStats clusterStats = Iterables.getOnlyElement( + loadStatsManager.getClusterStatsReports(CLUSTER)); + UpstreamLocalityStats localityStats = Iterables.getOnlyElement( + clusterStats.upstreamLocalityStatsList()); + assertThat(localityStats.locality()).isEqualTo(locality1); + assertThat(localityStats.totalIssuedRequests()).isEqualTo(1L); + assertThat(localityStats.totalSuccessfulRequests()).isEqualTo(1L); + assertThat(localityStats.totalErrorRequests()).isEqualTo(0L); + + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.IDLE)); + loadBalancer.requestConnection(); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.CONNECTING)); + + // Faksubchannel mimics update address and returns different locality + if (PickFirstLoadBalancerProvider.isEnabledNewPickFirst()) { + fakeSubchannel.updateState(ConnectivityStateInfo.forTransientFailure( + Status.UNAVAILABLE.withDescription("Try second address instead"))); + fakeSubchannel = helper.subchannels.poll(); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.CONNECTING)); + fakeSubchannel.setConnectedEagIndex(0); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + } else { + fakeSubchannel.setConnectedEagIndex(1); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + } + result = currentPicker.pickSubchannel(pickSubchannelArgs); + assertThat(result.getStatus().isOk()).isTrue(); + ClientStreamTracer streamTracer2 = result.getStreamTracerFactory().newClientStreamTracer( + ClientStreamTracer.StreamInfo.newBuilder().build(), new Metadata()); // second RPC call + streamTracer2.streamClosed(Status.UNAVAILABLE); + + clusterStats = Iterables.getOnlyElement(loadStatsManager.getClusterStatsReports(CLUSTER)); + List upstreamLocalityStatsList = + clusterStats.upstreamLocalityStatsList(); + UpstreamLocalityStats localityStats1 = Iterables.find(upstreamLocalityStatsList, + upstreamLocalityStats -> upstreamLocalityStats.locality().equals(locality1)); + assertThat(localityStats1.totalIssuedRequests()).isEqualTo(0L); + assertThat(localityStats1.totalSuccessfulRequests()).isEqualTo(0L); + assertThat(localityStats1.totalErrorRequests()).isEqualTo(0L); + UpstreamLocalityStats localityStats2 = Iterables.find(upstreamLocalityStatsList, + upstreamLocalityStats -> upstreamLocalityStats.locality().equals(locality2)); + assertThat(localityStats2.totalIssuedRequests()).isEqualTo(1L); + assertThat(localityStats2.totalSuccessfulRequests()).isEqualTo(0L); + assertThat(localityStats2.totalErrorRequests()).isEqualTo(1L); + + loadBalancer.shutdown(); + loadBalancer = null; + // No more references are held for localityStats1 hence dropped. + // Locality load is reported for one last time in case of loads occurred since the previous + // load report. + clusterStats = Iterables.getOnlyElement(loadStatsManager.getClusterStatsReports(CLUSTER)); + localityStats2 = Iterables.getOnlyElement(clusterStats.upstreamLocalityStatsList()); + + assertThat(localityStats2.locality()).isEqualTo(locality2); + assertThat(localityStats2.totalIssuedRequests()).isEqualTo(0L); + assertThat(localityStats2.totalSuccessfulRequests()).isEqualTo(0L); + assertThat(localityStats2.totalErrorRequests()).isEqualTo(0L); + assertThat(localityStats2.totalRequestsInProgress()).isEqualTo(0L); + + assertThat(loadStatsManager.getClusterStatsReports(CLUSTER)).isEmpty(); + } + @Test public void dropRpcsWithRespectToLbConfigDropCategories() { LoadBalancerProvider weightedTargetProvider = new WeightedTargetLoadBalancerProvider(); @@ -334,7 +618,9 @@ public void dropRpcsWithRespectToLbConfigDropCategories() { buildWeightedTargetConfig(ImmutableMap.of(locality, 10)); ClusterImplConfig config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, null, Collections.singletonList(DropOverload.create("throttle", 500_000)), - new PolicySelection(weightedTargetProvider, weightedTargetConfig), null); + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + weightedTargetProvider, weightedTargetConfig), + null, Collections.emptyMap(), null); EquivalentAddressGroup endpoint = makeAddress("endpoint-addr", locality); deliverAddressesAndConfig(Collections.singletonList(endpoint), config); when(mockRandom.nextInt(anyInt())).thenReturn(499_999, 999_999, 1_000_000); @@ -343,11 +629,14 @@ public void dropRpcsWithRespectToLbConfigDropCategories() { assertThat(leafBalancer.name).isEqualTo("round_robin"); assertThat(Iterables.getOnlyElement(leafBalancer.addresses).getAddresses()) .isEqualTo(endpoint.getAddresses()); - Subchannel subchannel = leafBalancer.helper.createSubchannel( - CreateSubchannelArgs.newBuilder().setAddresses(leafBalancer.addresses).build()); - leafBalancer.deliverSubchannelState(subchannel, ConnectivityState.READY); + leafBalancer.createSubChannel(); + FakeSubchannel fakeSubchannel = helper.subchannels.poll(); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.CONNECTING)); + fakeSubchannel.setConnectedEagIndex(0); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + assertThat(currentState).isEqualTo(ConnectivityState.READY); - PickResult result = currentPicker.pickSubchannel(mock(PickSubchannelArgs.class)); + PickResult result = currentPicker.pickSubchannel(pickSubchannelArgs); assertThat(result.getStatus().isOk()).isFalse(); assertThat(result.getStatus().getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(result.getStatus().getDescription()).isEqualTo("Dropped: throttle"); @@ -363,17 +652,19 @@ public void dropRpcsWithRespectToLbConfigDropCategories() { // Config update updates drop policies. config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, null, Collections.singletonList(DropOverload.create("lb", 1_000_000)), - new PolicySelection(weightedTargetProvider, weightedTargetConfig), null); + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + weightedTargetProvider, weightedTargetConfig), + null, Collections.emptyMap(), null); loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(Collections.singletonList(endpoint)) .setAttributes( Attributes.newBuilder() - .set(InternalXdsAttributes.XDS_CLIENT_POOL, xdsClientPool) + .set(io.grpc.xds.XdsAttributes.XDS_CLIENT, xdsClient) .build()) .setLoadBalancingPolicyConfig(config) .build()); - result = currentPicker.pickSubchannel(mock(PickSubchannelArgs.class)); + result = currentPicker.pickSubchannel(pickSubchannelArgs); assertThat(result.getStatus().isOk()).isFalse(); assertThat(result.getStatus().getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(result.getStatus().getDescription()).isEqualTo("Dropped: lb"); @@ -386,7 +677,7 @@ public void dropRpcsWithRespectToLbConfigDropCategories() { .isEqualTo(1L); assertThat(clusterStats.totalDroppedRequests()).isEqualTo(1L); - result = currentPicker.pickSubchannel(mock(PickSubchannelArgs.class)); + result = currentPicker.pickSubchannel(pickSubchannelArgs); assertThat(result.getStatus().isOk()).isTrue(); } @@ -410,7 +701,9 @@ private void subtest_maxConcurrentRequests_appliedByLbConfig(boolean enableCircu buildWeightedTargetConfig(ImmutableMap.of(locality, 10)); ClusterImplConfig config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, maxConcurrentRequests, Collections.emptyList(), - new PolicySelection(weightedTargetProvider, weightedTargetConfig), null); + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + weightedTargetProvider, weightedTargetConfig), + null, Collections.emptyMap(), null); EquivalentAddressGroup endpoint = makeAddress("endpoint-addr", locality); deliverAddressesAndConfig(Collections.singletonList(endpoint), config); assertThat(downstreamBalancers).hasSize(1); // one leaf balancer @@ -418,12 +711,15 @@ private void subtest_maxConcurrentRequests_appliedByLbConfig(boolean enableCircu assertThat(leafBalancer.name).isEqualTo("round_robin"); assertThat(Iterables.getOnlyElement(leafBalancer.addresses).getAddresses()) .isEqualTo(endpoint.getAddresses()); - Subchannel subchannel = leafBalancer.helper.createSubchannel( - CreateSubchannelArgs.newBuilder().setAddresses(leafBalancer.addresses).build()); - leafBalancer.deliverSubchannelState(subchannel, ConnectivityState.READY); + leafBalancer.createSubChannel(); + FakeSubchannel fakeSubchannel = helper.subchannels.poll(); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.CONNECTING)); + fakeSubchannel.setConnectedEagIndex(0); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + assertThat(currentState).isEqualTo(ConnectivityState.READY); assertThat(currentState).isEqualTo(ConnectivityState.READY); for (int i = 0; i < maxConcurrentRequests; i++) { - PickResult result = currentPicker.pickSubchannel(mock(PickSubchannelArgs.class)); + PickResult result = currentPicker.pickSubchannel(pickSubchannelArgs); assertThat(result.getStatus().isOk()).isTrue(); ClientStreamTracer.Factory streamTracerFactory = result.getStreamTracerFactory(); streamTracerFactory.newClientStreamTracer( @@ -434,14 +730,14 @@ private void subtest_maxConcurrentRequests_appliedByLbConfig(boolean enableCircu assertThat(clusterStats.clusterServiceName()).isEqualTo(EDS_SERVICE_NAME); assertThat(clusterStats.totalDroppedRequests()).isEqualTo(0L); - PickResult result = currentPicker.pickSubchannel(mock(PickSubchannelArgs.class)); + PickResult result = currentPicker.pickSubchannel(pickSubchannelArgs); clusterStats = Iterables.getOnlyElement(loadStatsManager.getClusterStatsReports(CLUSTER)); assertThat(clusterStats.clusterServiceName()).isEqualTo(EDS_SERVICE_NAME); if (enableCircuitBreaking) { assertThat(result.getStatus().isOk()).isFalse(); assertThat(result.getStatus().getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(result.getStatus().getDescription()) - .isEqualTo("Cluster max concurrent requests limit exceeded"); + .isEqualTo("Cluster max concurrent requests limit of 100 exceeded"); assertThat(clusterStats.totalDroppedRequests()).isEqualTo(1L); } else { assertThat(result.getStatus().isOk()).isTrue(); @@ -452,10 +748,12 @@ private void subtest_maxConcurrentRequests_appliedByLbConfig(boolean enableCircu maxConcurrentRequests = 101L; config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, maxConcurrentRequests, Collections.emptyList(), - new PolicySelection(weightedTargetProvider, weightedTargetConfig), null); + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + weightedTargetProvider, weightedTargetConfig), + null, Collections.emptyMap(), null); deliverAddressesAndConfig(Collections.singletonList(endpoint), config); - result = currentPicker.pickSubchannel(mock(PickSubchannelArgs.class)); + result = currentPicker.pickSubchannel(pickSubchannelArgs); assertThat(result.getStatus().isOk()).isTrue(); result.getStreamTracerFactory().newClientStreamTracer( ClientStreamTracer.StreamInfo.newBuilder().build(), new Metadata()); // 101th request @@ -463,14 +761,14 @@ private void subtest_maxConcurrentRequests_appliedByLbConfig(boolean enableCircu assertThat(clusterStats.clusterServiceName()).isEqualTo(EDS_SERVICE_NAME); assertThat(clusterStats.totalDroppedRequests()).isEqualTo(0L); - result = currentPicker.pickSubchannel(mock(PickSubchannelArgs.class)); // 102th request + result = currentPicker.pickSubchannel(pickSubchannelArgs); // 102th request clusterStats = Iterables.getOnlyElement(loadStatsManager.getClusterStatsReports(CLUSTER)); assertThat(clusterStats.clusterServiceName()).isEqualTo(EDS_SERVICE_NAME); if (enableCircuitBreaking) { assertThat(result.getStatus().isOk()).isFalse(); assertThat(result.getStatus().getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(result.getStatus().getDescription()) - .isEqualTo("Cluster max concurrent requests limit exceeded"); + .isEqualTo("Cluster max concurrent requests limit of 101 exceeded"); assertThat(clusterStats.totalDroppedRequests()).isEqualTo(1L); } else { assertThat(result.getStatus().isOk()).isTrue(); @@ -498,7 +796,9 @@ private void subtest_maxConcurrentRequests_appliedWithDefaultValue( buildWeightedTargetConfig(ImmutableMap.of(locality, 10)); ClusterImplConfig config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, null, Collections.emptyList(), - new PolicySelection(weightedTargetProvider, weightedTargetConfig), null); + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + weightedTargetProvider, weightedTargetConfig), + null, Collections.emptyMap(), null); EquivalentAddressGroup endpoint = makeAddress("endpoint-addr", locality); deliverAddressesAndConfig(Collections.singletonList(endpoint), config); assertThat(downstreamBalancers).hasSize(1); // one leaf balancer @@ -506,12 +806,15 @@ private void subtest_maxConcurrentRequests_appliedWithDefaultValue( assertThat(leafBalancer.name).isEqualTo("round_robin"); assertThat(Iterables.getOnlyElement(leafBalancer.addresses).getAddresses()) .isEqualTo(endpoint.getAddresses()); - Subchannel subchannel = leafBalancer.helper.createSubchannel( - CreateSubchannelArgs.newBuilder().setAddresses(leafBalancer.addresses).build()); - leafBalancer.deliverSubchannelState(subchannel, ConnectivityState.READY); + leafBalancer.createSubChannel(); + FakeSubchannel fakeSubchannel = helper.subchannels.poll(); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.CONNECTING)); + fakeSubchannel.setConnectedEagIndex(0); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + assertThat(currentState).isEqualTo(ConnectivityState.READY); assertThat(currentState).isEqualTo(ConnectivityState.READY); for (int i = 0; i < ClusterImplLoadBalancer.DEFAULT_PER_CLUSTER_MAX_CONCURRENT_REQUESTS; i++) { - PickResult result = currentPicker.pickSubchannel(mock(PickSubchannelArgs.class)); + PickResult result = currentPicker.pickSubchannel(pickSubchannelArgs); assertThat(result.getStatus().isOk()).isTrue(); ClientStreamTracer.Factory streamTracerFactory = result.getStreamTracerFactory(); streamTracerFactory.newClientStreamTracer( @@ -522,14 +825,14 @@ private void subtest_maxConcurrentRequests_appliedWithDefaultValue( assertThat(clusterStats.clusterServiceName()).isEqualTo(EDS_SERVICE_NAME); assertThat(clusterStats.totalDroppedRequests()).isEqualTo(0L); - PickResult result = currentPicker.pickSubchannel(mock(PickSubchannelArgs.class)); + PickResult result = currentPicker.pickSubchannel(pickSubchannelArgs); clusterStats = Iterables.getOnlyElement(loadStatsManager.getClusterStatsReports(CLUSTER)); assertThat(clusterStats.clusterServiceName()).isEqualTo(EDS_SERVICE_NAME); if (enableCircuitBreaking) { assertThat(result.getStatus().isOk()).isFalse(); assertThat(result.getStatus().getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(result.getStatus().getDescription()) - .isEqualTo("Cluster max concurrent requests limit exceeded"); + .isEqualTo("Cluster max concurrent requests limit of 1024 exceeded"); assertThat(clusterStats.totalDroppedRequests()).isEqualTo(1L); } else { assertThat(result.getStatus().isOk()).isTrue(); @@ -544,7 +847,9 @@ public void endpointAddressesAttachedWithClusterName() { buildWeightedTargetConfig(ImmutableMap.of(locality, 10)); ClusterImplConfig config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, null, Collections.emptyList(), - new PolicySelection(weightedTargetProvider, weightedTargetConfig), null); + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + weightedTargetProvider, weightedTargetConfig), + null, Collections.emptyMap(), null); // One locality with two endpoints. EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr1", locality); EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr2", locality); @@ -560,28 +865,131 @@ public void endpointAddressesAttachedWithClusterName() { .build(); Subchannel subchannel = leafBalancer.helper.createSubchannel(args); for (EquivalentAddressGroup eag : subchannel.getAllAddresses()) { - assertThat(eag.getAttributes().get(InternalXdsAttributes.ATTR_CLUSTER_NAME)) + assertThat(eag.getAttributes().get(io.grpc.xds.XdsAttributes.ATTR_CLUSTER_NAME)) .isEqualTo(CLUSTER); } // An address update should also retain the cluster attribute. subchannel.updateAddresses(leafBalancer.addresses); for (EquivalentAddressGroup eag : subchannel.getAllAddresses()) { - assertThat(eag.getAttributes().get(InternalXdsAttributes.ATTR_CLUSTER_NAME)) + assertThat(eag.getAttributes().get(io.grpc.xds.XdsAttributes.ATTR_CLUSTER_NAME)) .isEqualTo(CLUSTER); } } + @Test + public void + endpointsWithAuthorityHostname_autoHostRewriteEnabled_pickResultHasAuthorityHostname() { + System.setProperty("GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE", "true"); + try { + LoadBalancerProvider weightedTargetProvider = new WeightedTargetLoadBalancerProvider(); + WeightedTargetConfig weightedTargetConfig = + buildWeightedTargetConfig(ImmutableMap.of(locality, 10)); + ClusterImplConfig config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, + null, Collections.emptyList(), + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + weightedTargetProvider, weightedTargetConfig), + null, Collections.emptyMap(), null); + EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr1", locality, + "authority-host-name"); + deliverAddressesAndConfig(Arrays.asList(endpoint1), config); + assertThat(downstreamBalancers).hasSize(1); // one leaf balancer + FakeLoadBalancer leafBalancer = Iterables.getOnlyElement(downstreamBalancers); + assertThat(leafBalancer.name).isEqualTo("round_robin"); + + // Simulates leaf load balancer creating subchannels. + CreateSubchannelArgs args = + CreateSubchannelArgs.newBuilder() + .setAddresses(leafBalancer.addresses) + .build(); + Subchannel subchannel = leafBalancer.helper.createSubchannel(args); + subchannel.start(infoObject -> { + if (infoObject.getState() == ConnectivityState.READY) { + helper.updateBalancingState( + ConnectivityState.READY, + new FixedResultPicker(PickResult.withSubchannel(subchannel))); + } + }); + assertThat(subchannel.getAttributes().get(ATTR_SUBCHANNEL_ADDRESS_NAME)).isEqualTo( + "authority-host-name"); + for (EquivalentAddressGroup eag : subchannel.getAllAddresses()) { + assertThat(eag.getAttributes().get(XdsInternalAttributes.ATTR_ADDRESS_NAME)) + .isEqualTo("authority-host-name"); + } + + leafBalancer.deliverSubchannelState(subchannel, ConnectivityState.READY); + assertThat(currentState).isEqualTo(ConnectivityState.READY); + PickDetailsConsumer detailsConsumer = mock(PickDetailsConsumer.class); + pickSubchannelArgs = new PickSubchannelArgsImpl( + TestMethodDescriptors.voidMethod(), new Metadata(), + CallOptions.DEFAULT.withOption(AUTO_HOST_REWRITE_KEY, true), detailsConsumer); + PickResult result = currentPicker.pickSubchannel(pickSubchannelArgs); + assertThat(result.getAuthorityOverride()).isEqualTo("authority-host-name"); + } finally { + System.clearProperty("GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE"); + } + } + + @Test + public void + endpointWithAuthorityHostname_autoHostRewriteNotEnabled_pickResultNoAuthorityHostname() { + LoadBalancerProvider weightedTargetProvider = new WeightedTargetLoadBalancerProvider(); + WeightedTargetConfig weightedTargetConfig = + buildWeightedTargetConfig(ImmutableMap.of(locality, 10)); + ClusterImplConfig config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, + null, Collections.emptyList(), + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + weightedTargetProvider, weightedTargetConfig), + null, Collections.emptyMap(), null); + EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr1", locality, + "authority-host-name"); + deliverAddressesAndConfig(Arrays.asList(endpoint1), config); + assertThat(downstreamBalancers).hasSize(1); // one leaf balancer + FakeLoadBalancer leafBalancer = Iterables.getOnlyElement(downstreamBalancers); + assertThat(leafBalancer.name).isEqualTo("round_robin"); + + // Simulates leaf load balancer creating subchannels. + CreateSubchannelArgs args = + CreateSubchannelArgs.newBuilder() + .setAddresses(leafBalancer.addresses) + .build(); + Subchannel subchannel = leafBalancer.helper.createSubchannel(args); + subchannel.start(infoObject -> { + if (infoObject.getState() == ConnectivityState.READY) { + helper.updateBalancingState( + ConnectivityState.READY, + new FixedResultPicker(PickResult.withSubchannel(subchannel))); + } + }); + // Sub Channel wrapper args won't have the address name although addresses will. + assertThat(subchannel.getAttributes().get(ATTR_SUBCHANNEL_ADDRESS_NAME)).isNull(); + for (EquivalentAddressGroup eag : subchannel.getAllAddresses()) { + assertThat(eag.getAttributes().get(XdsInternalAttributes.ATTR_ADDRESS_NAME)) + .isEqualTo("authority-host-name"); + } + + leafBalancer.deliverSubchannelState(subchannel, ConnectivityState.READY); + assertThat(currentState).isEqualTo(ConnectivityState.READY); + PickDetailsConsumer detailsConsumer = mock(PickDetailsConsumer.class); + pickSubchannelArgs = new PickSubchannelArgsImpl( + TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT, detailsConsumer); + PickResult result = currentPicker.pickSubchannel(pickSubchannelArgs); + assertThat(result.getAuthorityOverride()).isNull(); + } + @Test public void endpointAddressesAttachedWithTlsConfig_securityEnabledByDefault() { UpstreamTlsContext upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true); + CommonTlsContextTestsUtil.buildUpstreamTlsContext( + "google_cloud_private_spiffe", true); LoadBalancerProvider weightedTargetProvider = new WeightedTargetLoadBalancerProvider(); WeightedTargetConfig weightedTargetConfig = buildWeightedTargetConfig(ImmutableMap.of(locality, 10)); ClusterImplConfig config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, null, Collections.emptyList(), - new PolicySelection(weightedTargetProvider, weightedTargetConfig), upstreamTlsContext); + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + weightedTargetProvider, weightedTargetConfig), + upstreamTlsContext, Collections.emptyMap(), null); // One locality with two endpoints. EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr1", locality); EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr2", locality); @@ -597,41 +1005,46 @@ public void endpointAddressesAttachedWithTlsConfig_securityEnabledByDefault() { Subchannel subchannel = leafBalancer.helper.createSubchannel(args); for (EquivalentAddressGroup eag : subchannel.getAllAddresses()) { SslContextProviderSupplier supplier = - eag.getAttributes().get(InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER); + eag.getAttributes().get(SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER); assertThat(supplier.getTlsContext()).isEqualTo(upstreamTlsContext); } // Removes UpstreamTlsContext from the config. config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, null, Collections.emptyList(), - new PolicySelection(weightedTargetProvider, weightedTargetConfig), null); + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + weightedTargetProvider, weightedTargetConfig), + null, Collections.emptyMap(), null); deliverAddressesAndConfig(Arrays.asList(endpoint1, endpoint2), config); assertThat(Iterables.getOnlyElement(downstreamBalancers)).isSameInstanceAs(leafBalancer); subchannel = leafBalancer.helper.createSubchannel(args); // creates new connections for (EquivalentAddressGroup eag : subchannel.getAllAddresses()) { - assertThat(eag.getAttributes().get(InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER)) + assertThat( + eag.getAttributes().get(SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER)) .isNull(); } // Config with a new UpstreamTlsContext. - upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe1", true); + upstreamTlsContext = CommonTlsContextTestsUtil.buildUpstreamTlsContext( + "google_cloud_private_spiffe1", true); config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, null, Collections.emptyList(), - new PolicySelection(weightedTargetProvider, weightedTargetConfig), upstreamTlsContext); + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + weightedTargetProvider, weightedTargetConfig), + upstreamTlsContext, Collections.emptyMap(), null); deliverAddressesAndConfig(Arrays.asList(endpoint1, endpoint2), config); assertThat(Iterables.getOnlyElement(downstreamBalancers)).isSameInstanceAs(leafBalancer); subchannel = leafBalancer.helper.createSubchannel(args); // creates new connections for (EquivalentAddressGroup eag : subchannel.getAllAddresses()) { SslContextProviderSupplier supplier = - eag.getAttributes().get(InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER); + eag.getAttributes().get(SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER); assertThat(supplier.isShutdown()).isFalse(); assertThat(supplier.getTlsContext()).isEqualTo(upstreamTlsContext); } loadBalancer.shutdown(); for (EquivalentAddressGroup eag : subchannel.getAllAddresses()) { SslContextProviderSupplier supplier = - eag.getAttributes().get(InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER); + eag.getAttributes().get(SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER); assertThat(supplier.isShutdown()).isTrue(); } loadBalancer = null; @@ -644,8 +1057,8 @@ private void deliverAddressesAndConfig(List addresses, .setAddresses(addresses) .setAttributes( Attributes.newBuilder() - .set(InternalXdsAttributes.XDS_CLIENT_POOL, xdsClientPool) - .set(InternalXdsAttributes.CALL_COUNTER_PROVIDER, callCounterProvider) + .set(io.grpc.xds.XdsAttributes.XDS_CLIENT, xdsClient) + .set(io.grpc.xds.XdsAttributes.CALL_COUNTER_PROVIDER, callCounterProvider) .build()) .setLoadBalancingPolicyConfig(config) .build()); @@ -666,6 +1079,11 @@ private WeightedTargetConfig buildWeightedTargetConfig(Map lo * Create a locality-labeled address. */ private static EquivalentAddressGroup makeAddress(final String name, Locality locality) { + return makeAddress(name, locality, null); + } + + private static EquivalentAddressGroup makeAddress(final String name, Locality locality, + String authorityHostname) { class FakeSocketAddress extends SocketAddress { private final String name; @@ -696,8 +1114,15 @@ public String toString() { } } + Attributes.Builder attributes = Attributes.newBuilder() + .set(io.grpc.xds.XdsAttributes.ATTR_LOCALITY, locality) + // Unique but arbitrary string + .set(EquivalentAddressGroup.ATTR_LOCALITY_NAME, locality.toString()); + if (authorityHostname != null) { + attributes.set(XdsInternalAttributes.ATTR_ADDRESS_NAME, authorityHostname); + } EquivalentAddressGroup eag = new EquivalentAddressGroup(new FakeSocketAddress(name), - Attributes.newBuilder().set(InternalXdsAttributes.ATTR_LOCALITY, locality).build()); + attributes.build()); return AddressFilter.setPathFilter(eag, Collections.singletonList(locality.toString())); } @@ -763,18 +1188,38 @@ public void shutdown() { } void deliverSubchannelState(final Subchannel subchannel, ConnectivityState state) { + deliverSubchannelState(PickResult.withSubchannel(subchannel), state); + } + + void deliverSubchannelState(final PickResult result, ConnectivityState state) { SubchannelPicker picker = new SubchannelPicker() { @Override public PickResult pickSubchannel(PickSubchannelArgs args) { - return PickResult.withSubchannel(subchannel); + return result; } }; helper.updateBalancingState(state, picker); } + + Subchannel createSubChannel() { + Subchannel subchannel = helper.createSubchannel( + CreateSubchannelArgs.newBuilder().setAddresses(addresses).build()); + subchannel.start(infoObject -> { + if (infoObject.getState() == ConnectivityState.READY) { + helper.updateBalancingState( + ConnectivityState.READY, + new FixedResultPicker(PickResult.withSubchannel(subchannel))); + } + }); + subchannel.requestConnection(); + return subchannel; + } } private final class FakeLbHelper extends LoadBalancer.Helper { + private final Queue subchannels = new LinkedList<>(); + @Override public SynchronizationContext getSynchronizationContext() { return syncContext; @@ -789,7 +1234,9 @@ public void updateBalancingState( @Override public Subchannel createSubchannel(CreateSubchannelArgs args) { - return new FakeSubchannel(args.getAddresses(), args.getAttributes()); + FakeSubchannel subchannel = new FakeSubchannel(args.getAddresses(), args.getAttributes()); + subchannels.add(subchannel); + return subchannel; } @Override @@ -801,23 +1248,38 @@ public ManagedChannel createOobChannel(EquivalentAddressGroup eag, String author public String getAuthority() { return AUTHORITY; } + + @Override + public void refreshNameResolution() {} } private static final class FakeSubchannel extends Subchannel { private final List eags; private final Attributes attrs; + private SubchannelStateListener listener; + private Attributes connectedAttributes; + private ConnectivityStateInfo state = ConnectivityStateInfo.forNonError(ConnectivityState.IDLE); + private boolean connectionRequested; private FakeSubchannel(List eags, Attributes attrs) { this.eags = eags; this.attrs = attrs; } + @Override + public void start(SubchannelStateListener listener) { + this.listener = checkNotNull(listener, "listener"); + } + @Override public void shutdown() { } @Override public void requestConnection() { + if (state.getState() == ConnectivityState.IDLE) { + this.connectionRequested = true; + } } @Override @@ -833,9 +1295,43 @@ public Attributes getAttributes() { @Override public void updateAddresses(List addrs) { } + + @Override + public Attributes getConnectedAddressAttributes() { + return connectedAttributes; + } + + public void updateState(ConnectivityStateInfo newState) { + switch (newState.getState()) { + case IDLE: + assertThat(state.getState()).isEqualTo(ConnectivityState.READY); + break; + case CONNECTING: + assertThat(state.getState()) + .isIn(Arrays.asList(ConnectivityState.IDLE, ConnectivityState.TRANSIENT_FAILURE)); + if (state.getState() == ConnectivityState.IDLE) { + assertWithMessage("Connection requested").that(this.connectionRequested).isTrue(); + this.connectionRequested = false; + } + break; + case READY: + case TRANSIENT_FAILURE: + assertThat(state.getState()).isEqualTo(ConnectivityState.CONNECTING); + break; + default: + break; + } + this.state = newState; + listener.onSubchannelState(newState); + } + + public void setConnectedEagIndex(int eagIndex) { + this.connectedAttributes = eags.get(eagIndex).getAttributes(); + } } private final class FakeXdsClient extends XdsClient { + @Override public ClusterDropStats addClusterDropStats( ServerInfo lrsServerInfo, String clusterName, @Nullable String edsServiceName) { @@ -845,8 +1341,9 @@ public ClusterDropStats addClusterDropStats( @Override public ClusterLocalityStats addClusterLocalityStats( ServerInfo lrsServerInfo, String clusterName, @Nullable String edsServiceName, - Locality locality) { - return loadStatsManager.getClusterLocalityStats(clusterName, edsServiceName, locality); + Locality locality, BackendMetricPropagation backendMetricPropagation) { + return loadStatsManager.getClusterLocalityStats( + clusterName, edsServiceName, locality, backendMetricPropagation); } @Override diff --git a/xds/src/test/java/io/grpc/xds/ClusterManagerLoadBalancerProviderTest.java b/xds/src/test/java/io/grpc/xds/ClusterManagerLoadBalancerProviderTest.java index 515f6fef3ef..40943658520 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterManagerLoadBalancerProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterManagerLoadBalancerProviderTest.java @@ -26,7 +26,7 @@ import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.internal.JsonParser; -import io.grpc.internal.ServiceConfigUtil.PolicySelection; +import io.grpc.util.GracefulSwitchLoadBalancer; import io.grpc.xds.ClusterManagerLoadBalancerProvider.ClusterManagerConfig; import java.io.IOException; import java.util.Map; @@ -133,10 +133,9 @@ public ConfigOrError parseLoadBalancingPolicyConfig( assertThat(config.childPolicies) .containsExactly( "child1", - new PolicySelection( - lbProviderFoo, fooConfig), + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig(lbProviderFoo, fooConfig), "child2", - new PolicySelection(lbProviderBar, barConfig)); + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig(lbProviderBar, barConfig)); } @Test diff --git a/xds/src/test/java/io/grpc/xds/ClusterManagerLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterManagerLoadBalancerTest.java index 786962d0f1d..8856efd685f 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterManagerLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterManagerLoadBalancerTest.java @@ -37,6 +37,7 @@ import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.PickDetailsConsumer; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.ResolvedAddresses; @@ -51,10 +52,11 @@ import io.grpc.SynchronizationContext; import io.grpc.internal.FakeClock; import io.grpc.internal.PickSubchannelArgsImpl; -import io.grpc.internal.ServiceConfigUtil.PolicySelection; import io.grpc.testing.TestMethodDescriptors; +import io.grpc.util.GracefulSwitchLoadBalancer; import io.grpc.xds.ClusterManagerLoadBalancerProvider.ClusterManagerConfig; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashMap; @@ -116,7 +118,7 @@ public void tearDown() { } @Test - public void handleResolvedAddressesUpdatesChannelPicker() { + public void acceptResolvedAddressesUpdatesChannelPicker() { deliverResolvedAddresses(ImmutableMap.of("childA", "policy_a", "childB", "policy_b")); verify(helper, atLeastOnce()).updateBalancingState( @@ -287,16 +289,27 @@ private void deliverResolvedAddresses(final Map childPolicies, b .build()); } + // Prevent ClusterManagerLB from detecting different providers even when the configuration is the + // same. + private Map, FakeLoadBalancerProvider> fakeLoadBalancerProviderCache + = new HashMap<>(); + private ClusterManagerConfig buildConfig(Map childPolicies, boolean failing) { - Map childPolicySelections = new LinkedHashMap<>(); + Map childConfigs = new LinkedHashMap<>(); for (String name : childPolicies.keySet()) { String childPolicyName = childPolicies.get(name); Object childConfig = lbConfigInventory.get(name); - PolicySelection policy = - new PolicySelection(new FakeLoadBalancerProvider(childPolicyName, failing), childConfig); - childPolicySelections.put(name, policy); + FakeLoadBalancerProvider lbProvider = + fakeLoadBalancerProviderCache.get(Arrays.asList(childPolicyName, failing)); + if (lbProvider == null) { + lbProvider = new FakeLoadBalancerProvider(childPolicyName, failing); + fakeLoadBalancerProviderCache.put(Arrays.asList(childPolicyName, failing), lbProvider); + } + Object policy = + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig(lbProvider, childConfig); + childConfigs.put(name, policy); } - return new ClusterManagerConfig(childPolicySelections); + return new ClusterManagerConfig(childConfigs); } private static PickResult pickSubchannel(SubchannelPicker picker, String clusterName) { @@ -310,7 +323,8 @@ private static PickResult pickSubchannel(SubchannelPicker picker, String cluster .build(), new Metadata(), CallOptions.DEFAULT.withOption( - XdsNameResolver.CLUSTER_SELECTION_KEY, clusterName)); + XdsNameResolver.CLUSTER_SELECTION_KEY, clusterName), + new PickDetailsConsumer() {}); return picker.pickSubchannel(args); } diff --git a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerProviderTest.java b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerProviderTest.java deleted file mode 100644 index a201ecfaa4b..00000000000 --- a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerProviderTest.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Copyright 2020 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds; - -import static com.google.common.truth.Truth.assertThat; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -import io.grpc.ChannelLogger; -import io.grpc.LoadBalancer; -import io.grpc.LoadBalancer.Helper; -import io.grpc.LoadBalancerProvider; -import io.grpc.LoadBalancerRegistry; -import io.grpc.NameResolver; -import io.grpc.NameResolver.ServiceConfigParser; -import io.grpc.NameResolverRegistry; -import io.grpc.SynchronizationContext; -import io.grpc.internal.FakeClock; -import io.grpc.internal.GrpcUtil; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** Tests for {@link ClusterResolverLoadBalancerProvider}. */ -@RunWith(JUnit4.class) -public class ClusterResolverLoadBalancerProviderTest { - - @Test - public void provided() { - LoadBalancerProvider provider = - LoadBalancerRegistry.getDefaultRegistry().getProvider( - XdsLbPolicies.CLUSTER_RESOLVER_POLICY_NAME); - assertThat(provider).isInstanceOf(ClusterResolverLoadBalancerProvider.class); - } - - @Test - public void providesLoadBalancer() { - Helper helper = mock(Helper.class); - - SynchronizationContext syncContext = new SynchronizationContext( - new Thread.UncaughtExceptionHandler() { - @Override - public void uncaughtException(Thread t, Throwable e) { - throw new AssertionError(e); - } - }); - FakeClock fakeClock = new FakeClock(); - NameResolverRegistry nsRegistry = new NameResolverRegistry(); - NameResolver.Args args = NameResolver.Args.newBuilder() - .setDefaultPort(8080) - .setProxyDetector(GrpcUtil.NOOP_PROXY_DETECTOR) - .setSynchronizationContext(syncContext) - .setServiceConfigParser(mock(ServiceConfigParser.class)) - .setChannelLogger(mock(ChannelLogger.class)) - .build(); - when(helper.getNameResolverRegistry()).thenReturn(nsRegistry); - when(helper.getNameResolverArgs()).thenReturn(args); - when(helper.getSynchronizationContext()).thenReturn(syncContext); - when(helper.getScheduledExecutorService()).thenReturn(fakeClock.getScheduledExecutorService()); - when(helper.getAuthority()).thenReturn("api.google.com"); - LoadBalancerProvider provider = new ClusterResolverLoadBalancerProvider(); - LoadBalancer loadBalancer = provider.newLoadBalancer(helper); - assertThat(loadBalancer).isInstanceOf(ClusterResolverLoadBalancer.class); - } -} diff --git a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java index ddc7ef56d90..a508da34f88 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java @@ -21,21 +21,46 @@ import static io.grpc.xds.XdsLbPolicies.PRIORITY_POLICY_NAME; import static io.grpc.xds.XdsLbPolicies.WEIGHTED_TARGET_POLICY_NAME; import static io.grpc.xds.XdsLbPolicies.WRR_LOCALITY_POLICY_NAME; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_CDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_EDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_LDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_RDS; +import static java.util.stream.Collectors.toList; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; -import static org.mockito.Mockito.reset; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; +import com.google.protobuf.Any; +import com.google.protobuf.Duration; +import com.google.protobuf.UInt32Value; +import com.google.protobuf.UInt64Value; +import io.envoyproxy.envoy.config.cluster.v3.Cluster; +import io.envoyproxy.envoy.config.cluster.v3.OutlierDetection; +import io.envoyproxy.envoy.config.core.v3.Address; +import io.envoyproxy.envoy.config.core.v3.AggregatedConfigSource; +import io.envoyproxy.envoy.config.core.v3.ConfigSource; +import io.envoyproxy.envoy.config.core.v3.HealthStatus; +import io.envoyproxy.envoy.config.core.v3.Locality; +import io.envoyproxy.envoy.config.core.v3.Metadata; +import io.envoyproxy.envoy.config.core.v3.SocketAddress; +import io.envoyproxy.envoy.config.core.v3.TransportSocket; +import io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment; +import io.envoyproxy.envoy.config.endpoint.v3.Endpoint; +import io.envoyproxy.envoy.config.endpoint.v3.LbEndpoint; +import io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints; +import io.envoyproxy.envoy.extensions.transport_sockets.http_11_proxy.v3.Http11ProxyUpstreamTransport; import io.grpc.Attributes; import io.grpc.ChannelLogger; import io.grpc.ConnectivityState; import io.grpc.EquivalentAddressGroup; -import io.grpc.InsecureChannelCredentials; +import io.grpc.HttpConnectProxiedSocketAddress; +import io.grpc.InternalEquivalentAddressGroup; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.PickResult; @@ -50,48 +75,36 @@ import io.grpc.NameResolverProvider; import io.grpc.NameResolverRegistry; import io.grpc.Status; -import io.grpc.Status.Code; +import io.grpc.StatusOr; import io.grpc.SynchronizationContext; -import io.grpc.internal.BackoffPolicy; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; import io.grpc.internal.FakeClock; -import io.grpc.internal.FakeClock.ScheduledTask; import io.grpc.internal.GrpcUtil; -import io.grpc.internal.ObjectPool; -import io.grpc.internal.ServiceConfigUtil.PolicySelection; +import io.grpc.testing.GrpcCleanupRule; +import io.grpc.util.GracefulSwitchLoadBalancerAccessor; import io.grpc.util.OutlierDetectionLoadBalancer.OutlierDetectionLoadBalancerConfig; import io.grpc.util.OutlierDetectionLoadBalancerProvider; +import io.grpc.xds.CdsLoadBalancerProvider.CdsConfig; import io.grpc.xds.ClusterImplLoadBalancerProvider.ClusterImplConfig; -import io.grpc.xds.ClusterResolverLoadBalancerProvider.ClusterResolverConfig; -import io.grpc.xds.ClusterResolverLoadBalancerProvider.ClusterResolverConfig.DiscoveryMechanism; import io.grpc.xds.Endpoints.DropOverload; -import io.grpc.xds.Endpoints.LbEndpoint; -import io.grpc.xds.Endpoints.LocalityLbEndpoints; -import io.grpc.xds.EnvoyServerProtoData.FailurePercentageEjection; -import io.grpc.xds.EnvoyServerProtoData.OutlierDetection; -import io.grpc.xds.EnvoyServerProtoData.SuccessRateEjection; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; -import io.grpc.xds.LeastRequestLoadBalancer.LeastRequestConfig; import io.grpc.xds.PriorityLoadBalancerProvider.PriorityLbConfig; import io.grpc.xds.PriorityLoadBalancerProvider.PriorityLbConfig.PriorityChildConfig; import io.grpc.xds.RingHashLoadBalancer.RingHashConfig; import io.grpc.xds.WrrLocalityLoadBalancer.WrrLocalityConfig; -import io.grpc.xds.XdsEndpointResource.EdsUpdate; +import io.grpc.xds.client.BackendMetricPropagation; import io.grpc.xds.client.Bootstrapper.ServerInfo; -import io.grpc.xds.client.Locality; +import io.grpc.xds.client.LoadStatsManager2; import io.grpc.xds.client.XdsClient; -import io.grpc.xds.client.XdsResourceType; -import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; -import java.net.SocketAddress; +import io.grpc.xds.internal.XdsInternalAttributes; +import java.net.InetSocketAddress; import java.net.URI; -import java.net.URISyntaxException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; +import java.util.Iterator; import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; import org.junit.After; @@ -102,9 +115,7 @@ import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; import org.mockito.Captor; -import org.mockito.InOrder; import org.mockito.Mock; -import org.mockito.Mockito; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; @@ -112,38 +123,43 @@ @RunWith(JUnit4.class) public class ClusterResolverLoadBalancerTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + @Rule + public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); - private static final String AUTHORITY = "api.google.com"; - private static final String CLUSTER1 = "cluster-foo.googleapis.com"; - private static final String CLUSTER2 = "cluster-bar.googleapis.com"; - private static final String CLUSTER_DNS = "cluster-dns.googleapis.com"; - private static final String EDS_SERVICE_NAME1 = "backend-service-foo.googleapis.com"; - private static final String EDS_SERVICE_NAME2 = "backend-service-bar.googleapis.com"; + private static final String SERVER_NAME = "example.com"; + private static final String CLUSTER = "cluster-foo.googleapis.com"; + private static final String EDS_SERVICE_NAME = "backend-service-foo.googleapis.com"; private static final String DNS_HOST_NAME = "dns-service.googleapis.com"; - private static final ServerInfo LRS_SERVER_INFO = - ServerInfo.create("lrs.googleapis.com", InsecureChannelCredentials.create()); - private final Locality locality1 = - Locality.create("test-region-1", "test-zone-1", "test-subzone-1"); - private final Locality locality2 = - Locality.create("test-region-2", "test-zone-2", "test-subzone-2"); - private final Locality locality3 = - Locality.create("test-region-3", "test-zone-3", "test-subzone-3"); - private final UpstreamTlsContext tlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true); - private final OutlierDetection outlierDetection = OutlierDetection.create( - 100L, 100L, 100L, 100, SuccessRateEjection.create(100, 100, 100, 100), - FailurePercentageEjection.create(100, 100, 100, 100)); - private final DiscoveryMechanism edsDiscoveryMechanism1 = - DiscoveryMechanism.forEds(CLUSTER1, EDS_SERVICE_NAME1, LRS_SERVER_INFO, 100L, tlsContext, - null); - private final DiscoveryMechanism edsDiscoveryMechanism2 = - DiscoveryMechanism.forEds(CLUSTER2, EDS_SERVICE_NAME2, LRS_SERVER_INFO, 200L, tlsContext, - null); - private final DiscoveryMechanism edsDiscoveryMechanismWithOutlierDetection = - DiscoveryMechanism.forEds(CLUSTER1, EDS_SERVICE_NAME1, LRS_SERVER_INFO, 100L, tlsContext, - outlierDetection); - private final DiscoveryMechanism logicalDnsDiscoveryMechanism = - DiscoveryMechanism.forLogicalDns(CLUSTER_DNS, DNS_HOST_NAME, LRS_SERVER_INFO, 300L, null); + private static final Cluster EDS_CLUSTER = Cluster.newBuilder() + .setName(CLUSTER) + .setType(Cluster.DiscoveryType.EDS) + .setEdsClusterConfig(Cluster.EdsClusterConfig.newBuilder() + .setServiceName(EDS_SERVICE_NAME) + .setEdsConfig(ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.newBuilder()))) + .build(); + private static final Cluster LOGICAL_DNS_CLUSTER = Cluster.newBuilder() + .setName(CLUSTER) + .setType(Cluster.DiscoveryType.LOGICAL_DNS) + .setLoadAssignment(ClusterLoadAssignment.newBuilder() + .addEndpoints(LocalityLbEndpoints.newBuilder() + .addLbEndpoints(newSocketLbEndpoint(DNS_HOST_NAME, 9000)))) + .build(); + private static final Locality LOCALITY1 = Locality.newBuilder() + .setRegion("test-region-1") + .setZone("test-zone-1") + .setSubZone("test-subzone-1") + .build(); + private static final Locality LOCALITY2 = Locality.newBuilder() + .setRegion("test-region-2") + .setZone("test-zone-2") + .setSubZone("test-subzone-2") + .build(); + private static final Locality LOCALITY3 = Locality.newBuilder() + .setRegion("test-region-3") + .setZone("test-zone-3") + .setSubZone("test-subzone-3") + .build(); private final SynchronizationContext syncContext = new SynchronizationContext( new Thread.UncaughtExceptionHandler() { @@ -155,47 +171,31 @@ public void uncaughtException(Thread t, Throwable e) { private final FakeClock fakeClock = new FakeClock(); private final LoadBalancerRegistry lbRegistry = new LoadBalancerRegistry(); private final NameResolverRegistry nsRegistry = new NameResolverRegistry(); - private final PolicySelection roundRobin = new PolicySelection( - new FakeLoadBalancerProvider("wrr_locality_experimental"), new WrrLocalityConfig( - new PolicySelection(new FakeLoadBalancerProvider("round_robin"), null))); - private final PolicySelection ringHash = new PolicySelection( - new FakeLoadBalancerProvider("ring_hash_experimental"), new RingHashConfig(10L, 100L)); - private final PolicySelection leastRequest = new PolicySelection( - new FakeLoadBalancerProvider("wrr_locality_experimental"), new WrrLocalityConfig( - new PolicySelection(new FakeLoadBalancerProvider("least_request_experimental"), - new LeastRequestConfig(3)))); private final List childBalancers = new ArrayList<>(); private final List resolvers = new ArrayList<>(); - private final FakeXdsClient xdsClient = new FakeXdsClient(); - private final ObjectPool xdsClientPool = new ObjectPool() { - @Override - public XdsClient getObject() { - xdsClientRefs++; - return xdsClient; - } - - @Override - public XdsClient returnObject(Object object) { - xdsClientRefs--; - return null; - } - }; - + private final XdsTestControlPlaneService controlPlaneService = new XdsTestControlPlaneService(); + private final XdsClient xdsClient = XdsTestUtils.createXdsClient( + Arrays.asList("control-plane.example.com"), + serverInfo -> new GrpcXdsTransportFactory.GrpcXdsTransport( + InProcessChannelBuilder + .forName(serverInfo.target()) + .directExecutor() + .build()), + fakeClock); + + + private XdsDependencyManager xdsDepManager; @Mock private Helper helper; - @Mock - private BackoffPolicy.Provider backoffPolicyProvider; - @Mock - private BackoffPolicy backoffPolicy1; - @Mock - private BackoffPolicy backoffPolicy2; @Captor private ArgumentCaptor pickerCaptor; - private int xdsClientRefs; - private ClusterResolverLoadBalancer loadBalancer; + private CdsLoadBalancer2 loadBalancer; + private boolean originalIsEnabledXdsHttpConnect; @Before - public void setUp() throws URISyntaxException { + public void setUp() throws Exception { + lbRegistry.register(new RingHashLoadBalancerProvider()); + lbRegistry.register(new WrrLocalityLoadBalancerProvider()); lbRegistry.register(new FakeLoadBalancerProvider(PRIORITY_POLICY_NAME)); lbRegistry.register(new FakeLoadBalancerProvider(CLUSTER_IMPL_POLICY_NAME)); lbRegistry.register(new FakeLoadBalancerProvider(WEIGHTED_TARGET_POLICY_NAME)); @@ -208,341 +208,428 @@ public void setUp() throws URISyntaxException { .setSynchronizationContext(syncContext) .setServiceConfigParser(mock(ServiceConfigParser.class)) .setChannelLogger(mock(ChannelLogger.class)) + .setScheduledExecutorService(fakeClock.getScheduledExecutorService()) + .setNameResolverRegistry(nsRegistry) .build(); + + xdsDepManager = new XdsDependencyManager( + xdsClient, + syncContext, + SERVER_NAME, + SERVER_NAME, + args); + + cleanupRule.register(InProcessServerBuilder + .forName("control-plane.example.com") + .addService(controlPlaneService) + .directExecutor() + .build() + .start()); + + controlPlaneService.setXdsConfig(ADS_TYPE_URL_LDS, ImmutableMap.of( + SERVER_NAME, ControlPlaneRule.buildClientListener(SERVER_NAME, "my-route"))); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_RDS, ImmutableMap.of( + "my-route", XdsTestUtils.buildRouteConfiguration(SERVER_NAME, "my-route", CLUSTER))); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + CLUSTER, EDS_CLUSTER)); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, ImmutableMap.of( + EDS_SERVICE_NAME, ControlPlaneRule.buildClusterLoadAssignment( + "127.0.0.1", "", 8080, EDS_SERVICE_NAME))); + nsRegistry.register(new FakeNameResolverProvider()); - when(helper.getNameResolverRegistry()).thenReturn(nsRegistry); - when(helper.getNameResolverArgs()).thenReturn(args); - when(helper.getSynchronizationContext()).thenReturn(syncContext); - when(helper.getScheduledExecutorService()).thenReturn(fakeClock.getScheduledExecutorService()); - when(helper.getAuthority()).thenReturn(AUTHORITY); - when(backoffPolicyProvider.get()).thenReturn(backoffPolicy1, backoffPolicy2); - when(backoffPolicy1.nextBackoffNanos()) - .thenReturn(TimeUnit.SECONDS.toNanos(1L), TimeUnit.SECONDS.toNanos(10L)); - when(backoffPolicy2.nextBackoffNanos()) - .thenReturn(TimeUnit.SECONDS.toNanos(5L), TimeUnit.SECONDS.toNanos(50L)); - loadBalancer = new ClusterResolverLoadBalancer(helper, lbRegistry, backoffPolicyProvider); + when(helper.getAuthority()).thenReturn("api.google.com"); + doAnswer((inv) -> { + xdsDepManager.requestReresolution(); + return null; + }).when(helper).refreshNameResolution(); + loadBalancer = new CdsLoadBalancer2(helper, lbRegistry); + + originalIsEnabledXdsHttpConnect = XdsClusterResource.isEnabledXdsHttpConnect; } @After - public void tearDown() { + public void tearDown() throws Exception { + XdsClusterResource.isEnabledXdsHttpConnect = originalIsEnabledXdsHttpConnect; loadBalancer.shutdown(); + if (xdsDepManager != null) { + xdsDepManager.shutdown(); + } + assertThat(xdsClient.getSubscribedResourcesMetadataSnapshot().get()).isEmpty(); + xdsClient.shutdown(); + assertThat(childBalancers).isEmpty(); assertThat(resolvers).isEmpty(); - assertThat(xdsClient.watchers).isEmpty(); - assertThat(xdsClientRefs).isEqualTo(0); assertThat(fakeClock.getPendingTasks()).isEmpty(); } @Test - public void edsClustersWithRingHashEndpointLbPolicy() { - ClusterResolverConfig config = new ClusterResolverConfig( - Collections.singletonList(edsDiscoveryMechanism1), ringHash); - deliverLbConfig(config); - assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1); - assertThat(childBalancers).isEmpty(); + public void edsClustersWithRingHashEndpointLbPolicy_oppositePickFirstWeightedShuffling() + throws Exception { + boolean original = CdsLoadBalancer2.pickFirstWeightedShuffling; + CdsLoadBalancer2.pickFirstWeightedShuffling = !CdsLoadBalancer2.pickFirstWeightedShuffling; + try { + edsClustersWithRingHashEndpointLbPolicy(); + } finally { + CdsLoadBalancer2.pickFirstWeightedShuffling = original; + } + } + + @Test + public void edsClustersWithRingHashEndpointLbPolicy() throws Exception { + boolean originalVal = LoadStatsManager2.isEnabledOrcaLrsPropagation; + LoadStatsManager2.isEnabledOrcaLrsPropagation = true; + List metricSpecs = Arrays.asList("cpu_utilization"); + BackendMetricPropagation backendMetricPropagation = + BackendMetricPropagation.fromMetricSpecs(metricSpecs); + Cluster cluster = EDS_CLUSTER.toBuilder() + .setLbPolicy(Cluster.LbPolicy.RING_HASH) + .setRingHashLbConfig(Cluster.RingHashLbConfig.newBuilder() + .setMinimumRingSize(UInt64Value.of(10)) + .setMaximumRingSize(UInt64Value.of(100)) + .build()) + .addAllLrsReportEndpointMetrics(metricSpecs) + .build(); + ClusterLoadAssignment clusterLoadAssignment = ClusterLoadAssignment.newBuilder() + .setClusterName(EDS_SERVICE_NAME) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(10)) + .setLocality(LOCALITY1) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.1", 8080)) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.2", 8080))) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(50)) + .setLocality(LOCALITY2) + .addLbEndpoints(newSocketLbEndpoint("127.0.1.1", 8080) + .setLoadBalancingWeight(UInt32Value.of(60)))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + CLUSTER, cluster)); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, ImmutableMap.of( + EDS_SERVICE_NAME, clusterLoadAssignment)); + startXdsDepManager(); - // One priority with two localities of different weights. - EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr-1"); - EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr-2"); - EquivalentAddressGroup endpoint3 = makeAddress("endpoint-addr-3"); - LocalityLbEndpoints localityLbEndpoints1 = - LocalityLbEndpoints.create( - Arrays.asList( - LbEndpoint.create(endpoint1, 0 /* loadBalancingWeight */, true), - LbEndpoint.create(endpoint2, 0 /* loadBalancingWeight */, true)), - 10 /* localityWeight */, 1 /* priority */); - LocalityLbEndpoints localityLbEndpoints2 = - LocalityLbEndpoints.create( - Collections.singletonList( - LbEndpoint.create(endpoint3, 60 /* loadBalancingWeight */, true)), - 50 /* localityWeight */, 1 /* priority */); - xdsClient.deliverClusterLoadAssignment( - EDS_SERVICE_NAME1, - ImmutableMap.of(locality1, localityLbEndpoints1, locality2, localityLbEndpoints2)); + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); assertThat(childBalancer.addresses).hasSize(3); EquivalentAddressGroup addr1 = childBalancer.addresses.get(0); EquivalentAddressGroup addr2 = childBalancer.addresses.get(1); EquivalentAddressGroup addr3 = childBalancer.addresses.get(2); - // Endpoints in locality1 have no endpoint-level weight specified, so all endpoints within - // locality1 are equally weighted. - assertThat(addr1.getAddresses()).isEqualTo(endpoint1.getAddresses()); - assertThat(addr1.getAttributes().get(InternalXdsAttributes.ATTR_SERVER_WEIGHT)) - .isEqualTo(10); - assertThat(addr2.getAddresses()).isEqualTo(endpoint2.getAddresses()); - assertThat(addr2.getAttributes().get(InternalXdsAttributes.ATTR_SERVER_WEIGHT)) - .isEqualTo(10); - assertThat(addr3.getAddresses()).isEqualTo(endpoint3.getAddresses()); - assertThat(addr3.getAttributes().get(InternalXdsAttributes.ATTR_SERVER_WEIGHT)) - .isEqualTo(50 * 60); + // Endpoints in LOCALITY1 have no endpoint-level weight specified, so all endpoints within + // LOCALITY1 are equally weighted. + assertThat(addr1.getAddresses()) + .isEqualTo(Arrays.asList(newInetSocketAddress("127.0.0.1", 8080))); + assertThat(addr1.getAttributes().get(InternalEquivalentAddressGroup.ATTR_BACKEND_SERVICE)) + .isEqualTo(CLUSTER); + assertThat(addr1.getAttributes().get(io.grpc.xds.XdsAttributes.ATTR_SERVER_WEIGHT)) + .isEqualTo(CdsLoadBalancer2.pickFirstWeightedShuffling ? 0x0AAAAAAA /* 1/12 */ : 10); + assertThat(addr2.getAddresses()) + .isEqualTo(Arrays.asList(newInetSocketAddress("127.0.0.2", 8080))); + assertThat(addr2.getAttributes().get(InternalEquivalentAddressGroup.ATTR_BACKEND_SERVICE)) + .isEqualTo(CLUSTER); + assertThat(addr2.getAttributes().get(io.grpc.xds.XdsAttributes.ATTR_SERVER_WEIGHT)) + .isEqualTo(CdsLoadBalancer2.pickFirstWeightedShuffling ? 0x0AAAAAAA /* 1/12 */ : 10); + assertThat(addr3.getAddresses()) + .isEqualTo(Arrays.asList(newInetSocketAddress("127.0.1.1", 8080))); + assertThat(addr3.getAttributes().get(InternalEquivalentAddressGroup.ATTR_BACKEND_SERVICE)) + .isEqualTo(CLUSTER); + assertThat(addr3.getAttributes().get(io.grpc.xds.XdsAttributes.ATTR_SERVER_WEIGHT)) + .isEqualTo(CdsLoadBalancer2.pickFirstWeightedShuffling ? 0x6AAAAAAA /* 5/6 */ : 50 * 60); assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); PriorityLbConfig priorityLbConfig = (PriorityLbConfig) childBalancer.config; - assertThat(priorityLbConfig.priorities).containsExactly(CLUSTER1 + "[child1]"); + assertThat(priorityLbConfig.priorities).containsExactly(CLUSTER + "[child1]"); PriorityChildConfig priorityChildConfig = Iterables.getOnlyElement(priorityLbConfig.childConfigs.values()); assertThat(priorityChildConfig.ignoreReresolution).isTrue(); - assertThat(priorityChildConfig.policySelection.getProvider().getPolicyName()) + assertThat(GracefulSwitchLoadBalancerAccessor.getChildProvider(priorityChildConfig.childConfig) + .getPolicyName()) .isEqualTo(CLUSTER_IMPL_POLICY_NAME); - ClusterImplConfig clusterImplConfig = - (ClusterImplConfig) priorityChildConfig.policySelection.getConfig(); - assertClusterImplConfig(clusterImplConfig, CLUSTER1, EDS_SERVICE_NAME1, LRS_SERVER_INFO, 100L, - tlsContext, Collections.emptyList(), "ring_hash_experimental"); - RingHashConfig ringHashConfig = - (RingHashConfig) clusterImplConfig.childPolicy.getConfig(); + ClusterImplConfig clusterImplConfig = (ClusterImplConfig) + GracefulSwitchLoadBalancerAccessor.getChildConfig(priorityChildConfig.childConfig); + assertClusterImplConfig(clusterImplConfig, CLUSTER, EDS_SERVICE_NAME, null, null, + null, Collections.emptyList(), "ring_hash_experimental"); + assertThat(clusterImplConfig.backendMetricPropagation).isEqualTo(backendMetricPropagation); + LoadStatsManager2.isEnabledOrcaLrsPropagation = originalVal; + RingHashConfig ringHashConfig = (RingHashConfig) + GracefulSwitchLoadBalancerAccessor.getChildConfig(clusterImplConfig.childConfig); assertThat(ringHashConfig.minRingSize).isEqualTo(10L); assertThat(ringHashConfig.maxRingSize).isEqualTo(100L); } @Test public void edsClustersWithLeastRequestEndpointLbPolicy() { - ClusterResolverConfig config = new ClusterResolverConfig( - Collections.singletonList(edsDiscoveryMechanism1), leastRequest); - deliverLbConfig(config); - assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1); - assertThat(childBalancers).isEmpty(); - + Cluster cluster = EDS_CLUSTER.toBuilder() + .setLbPolicy(Cluster.LbPolicy.LEAST_REQUEST) + .build(); // Simple case with one priority and one locality - EquivalentAddressGroup endpoint = makeAddress("endpoint-addr-1"); - LocalityLbEndpoints localityLbEndpoints = - LocalityLbEndpoints.create( - Arrays.asList( - LbEndpoint.create(endpoint, 0 /* loadBalancingWeight */, true)), - 100 /* localityWeight */, 1 /* priority */); - xdsClient.deliverClusterLoadAssignment( - EDS_SERVICE_NAME1, - ImmutableMap.of(locality1, localityLbEndpoints)); + ClusterLoadAssignment clusterLoadAssignment = ClusterLoadAssignment.newBuilder() + .setClusterName(EDS_SERVICE_NAME) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(100)) + .setLocality(LOCALITY1) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.1", 8080))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + CLUSTER, cluster)); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, ImmutableMap.of( + EDS_SERVICE_NAME, clusterLoadAssignment)); + startXdsDepManager(); + + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); assertThat(childBalancer.addresses).hasSize(1); EquivalentAddressGroup addr = childBalancer.addresses.get(0); - assertThat(addr.getAddresses()).isEqualTo(endpoint.getAddresses()); + assertThat(addr.getAddresses()) + .isEqualTo(Arrays.asList(newInetSocketAddress("127.0.0.1", 8080))); assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); PriorityLbConfig priorityLbConfig = (PriorityLbConfig) childBalancer.config; - assertThat(priorityLbConfig.priorities).containsExactly(CLUSTER1 + "[child1]"); + assertThat(priorityLbConfig.priorities).containsExactly(CLUSTER + "[child1]"); PriorityChildConfig priorityChildConfig = Iterables.getOnlyElement(priorityLbConfig.childConfigs.values()); - assertThat(priorityChildConfig.policySelection.getProvider().getPolicyName()) + assertThat(GracefulSwitchLoadBalancerAccessor.getChildProvider(priorityChildConfig.childConfig) + .getPolicyName()) .isEqualTo(CLUSTER_IMPL_POLICY_NAME); - ClusterImplConfig clusterImplConfig = - (ClusterImplConfig) priorityChildConfig.policySelection.getConfig(); - assertClusterImplConfig(clusterImplConfig, CLUSTER1, EDS_SERVICE_NAME1, LRS_SERVER_INFO, 100L, - tlsContext, Collections.emptyList(), WRR_LOCALITY_POLICY_NAME); - WrrLocalityConfig wrrLocalityConfig = - (WrrLocalityConfig) clusterImplConfig.childPolicy.getConfig(); - assertThat(wrrLocalityConfig.childPolicy.getProvider().getPolicyName()).isEqualTo( - "least_request_experimental"); + ClusterImplConfig clusterImplConfig = (ClusterImplConfig) + GracefulSwitchLoadBalancerAccessor.getChildConfig(priorityChildConfig.childConfig); + assertClusterImplConfig(clusterImplConfig, CLUSTER, EDS_SERVICE_NAME, null, null, + null, Collections.emptyList(), WRR_LOCALITY_POLICY_NAME); + WrrLocalityConfig wrrLocalityConfig = (WrrLocalityConfig) + GracefulSwitchLoadBalancerAccessor.getChildConfig(clusterImplConfig.childConfig); + LoadBalancerProvider childProvider = + GracefulSwitchLoadBalancerAccessor.getChildProvider(wrrLocalityConfig.childConfig); + assertThat(childProvider.getPolicyName()).isEqualTo("least_request_experimental"); assertThat( childBalancer.addresses.get(0).getAttributes() - .get(InternalXdsAttributes.ATTR_LOCALITY_WEIGHT)).isEqualTo(100); + .get(io.grpc.xds.XdsAttributes.ATTR_LOCALITY_WEIGHT)).isEqualTo(100); } @Test - public void edsClustersWithOutlierDetection() { - ClusterResolverConfig config = new ClusterResolverConfig( - Collections.singletonList(edsDiscoveryMechanismWithOutlierDetection), leastRequest); - deliverLbConfig(config); - assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1); - assertThat(childBalancers).isEmpty(); - + public void edsClustersEndpointHostname_addedToAddressAttribute() { // Simple case with one priority and one locality - EquivalentAddressGroup endpoint = makeAddress("endpoint-addr-1"); - LocalityLbEndpoints localityLbEndpoints = - LocalityLbEndpoints.create( - Arrays.asList( - LbEndpoint.create(endpoint, 0 /* loadBalancingWeight */, true)), - 100 /* localityWeight */, 1 /* priority */); - xdsClient.deliverClusterLoadAssignment( - EDS_SERVICE_NAME1, - ImmutableMap.of(locality1, localityLbEndpoints)); + ClusterLoadAssignment clusterLoadAssignment = ClusterLoadAssignment.newBuilder() + .setClusterName(EDS_SERVICE_NAME) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(100)) + .setLocality(LOCALITY1) + .addLbEndpoints(LbEndpoint.newBuilder() + .setEndpoint(Endpoint.newBuilder() + .setHostname("hostname1") + .setAddress(newAddress("127.0.0.1", 8000))))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, ImmutableMap.of( + EDS_SERVICE_NAME, clusterLoadAssignment)); + startXdsDepManager(); + + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - assertThat(childBalancer.addresses).hasSize(1); - EquivalentAddressGroup addr = childBalancer.addresses.get(0); - assertThat(addr.getAddresses()).isEqualTo(endpoint.getAddresses()); - assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); - PriorityLbConfig priorityLbConfig = (PriorityLbConfig) childBalancer.config; - assertThat(priorityLbConfig.priorities).containsExactly(CLUSTER1 + "[child1]"); - PriorityChildConfig priorityChildConfig = - Iterables.getOnlyElement(priorityLbConfig.childConfigs.values()); - - // The child config for priority should be outlier detection. - assertThat(priorityChildConfig.policySelection.getProvider().getPolicyName()) - .isEqualTo("outlier_detection_experimental"); - OutlierDetectionLoadBalancerConfig outlierDetectionConfig = - (OutlierDetectionLoadBalancerConfig) priorityChildConfig.policySelection.getConfig(); - - // The outlier detection config should faithfully represent what came down from xDS. - assertThat(outlierDetectionConfig.intervalNanos).isEqualTo(outlierDetection.intervalNanos()); - assertThat(outlierDetectionConfig.baseEjectionTimeNanos).isEqualTo( - outlierDetection.baseEjectionTimeNanos()); - assertThat(outlierDetectionConfig.baseEjectionTimeNanos).isEqualTo( - outlierDetection.baseEjectionTimeNanos()); - assertThat(outlierDetectionConfig.maxEjectionTimeNanos).isEqualTo( - outlierDetection.maxEjectionTimeNanos()); - assertThat(outlierDetectionConfig.maxEjectionPercent).isEqualTo( - outlierDetection.maxEjectionPercent()); - - OutlierDetectionLoadBalancerConfig.SuccessRateEjection successRateEjection - = outlierDetectionConfig.successRateEjection; - assertThat(successRateEjection.stdevFactor).isEqualTo( - outlierDetection.successRateEjection().stdevFactor()); - assertThat(successRateEjection.enforcementPercentage).isEqualTo( - outlierDetection.successRateEjection().enforcementPercentage()); - assertThat(successRateEjection.minimumHosts).isEqualTo( - outlierDetection.successRateEjection().minimumHosts()); - assertThat(successRateEjection.requestVolume).isEqualTo( - outlierDetection.successRateEjection().requestVolume()); - - OutlierDetectionLoadBalancerConfig.FailurePercentageEjection failurePercentageEjection - = outlierDetectionConfig.failurePercentageEjection; - assertThat(failurePercentageEjection.threshold).isEqualTo( - outlierDetection.failurePercentageEjection().threshold()); - assertThat(failurePercentageEjection.enforcementPercentage).isEqualTo( - outlierDetection.failurePercentageEjection().enforcementPercentage()); - assertThat(failurePercentageEjection.minimumHosts).isEqualTo( - outlierDetection.failurePercentageEjection().minimumHosts()); - assertThat(failurePercentageEjection.requestVolume).isEqualTo( - outlierDetection.failurePercentageEjection().requestVolume()); - - // The wrapped configuration should not have been tampered with. - ClusterImplConfig clusterImplConfig = - (ClusterImplConfig) outlierDetectionConfig.childPolicy.getConfig(); - assertClusterImplConfig(clusterImplConfig, CLUSTER1, EDS_SERVICE_NAME1, LRS_SERVER_INFO, 100L, - tlsContext, Collections.emptyList(), WRR_LOCALITY_POLICY_NAME); - WrrLocalityConfig wrrLocalityConfig = - (WrrLocalityConfig) clusterImplConfig.childPolicy.getConfig(); - assertThat(wrrLocalityConfig.childPolicy.getProvider().getPolicyName()).isEqualTo( - "least_request_experimental"); assertThat( childBalancer.addresses.get(0).getAttributes() - .get(InternalXdsAttributes.ATTR_LOCALITY_WEIGHT)).isEqualTo(100); + .get(XdsInternalAttributes.ATTR_ADDRESS_NAME)).isEqualTo("hostname1"); } + @Test + public void endpointAddressRewritten_whenProxyMetadataIsInEndpointMetadata() { + XdsClusterResource.isEnabledXdsHttpConnect = true; + Cluster cluster = EDS_CLUSTER.toBuilder() + .setTransportSocket(TransportSocket.newBuilder() + .setName( + "type.googleapis.com/" + Http11ProxyUpstreamTransport.getDescriptor().getFullName()) + .setTypedConfig(Any.pack(Http11ProxyUpstreamTransport.getDefaultInstance()))) + .build(); + // Proxy address in endpointMetadata, and no proxy in locality metadata + ClusterLoadAssignment clusterLoadAssignment = ClusterLoadAssignment.newBuilder() + .setClusterName(EDS_SERVICE_NAME) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(100)) + .setLocality(LOCALITY1) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.1", 8080) + .setMetadata(Metadata.newBuilder() + .putTypedFilterMetadata( + "envoy.http11_proxy_transport_socket.proxy_address", + Any.pack(newAddress("127.0.0.2", 8081).build())))) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.3", 8082))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + CLUSTER, cluster)); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, ImmutableMap.of( + EDS_SERVICE_NAME, clusterLoadAssignment)); + startXdsDepManager(); + + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); + assertThat(childBalancers).hasSize(1); + FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); + + // Get the rewritten address + java.net.SocketAddress rewrittenAddress = + childBalancer.addresses.get(0).getAddresses().get(0); + assertThat(rewrittenAddress).isInstanceOf(HttpConnectProxiedSocketAddress.class); + HttpConnectProxiedSocketAddress proxiedSocket = + (HttpConnectProxiedSocketAddress) rewrittenAddress; + + // Assert that the target address is the original address + assertThat(proxiedSocket.getTargetAddress()).isEqualTo(newInetSocketAddress("127.0.0.1", 8080)); + + // Assert that the proxy address is correctly set + assertThat(proxiedSocket.getProxyAddress()).isEqualTo(newInetSocketAddress("127.0.0.2", 8081)); + + // Check the non-rewritten address + java.net.SocketAddress normalAddress = childBalancer.addresses.get(1).getAddresses().get(0); + assertThat(normalAddress).isEqualTo(newInetSocketAddress("127.0.0.3", 8082)); + } + + @Test + public void endpointAddressRewritten_whenProxyMetadataIsInLocalityMetadata() { + XdsClusterResource.isEnabledXdsHttpConnect = true; + Cluster cluster = EDS_CLUSTER.toBuilder() + .setTransportSocket(TransportSocket.newBuilder() + .setName( + "type.googleapis.com/" + Http11ProxyUpstreamTransport.getDescriptor().getFullName()) + .setTypedConfig(Any.pack(Http11ProxyUpstreamTransport.getDefaultInstance()))) + .build(); + // No proxy address in endpointMetadata, and proxy in locality metadata + ClusterLoadAssignment clusterLoadAssignment = ClusterLoadAssignment.newBuilder() + .setClusterName(EDS_SERVICE_NAME) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(100)) + .setLocality(LOCALITY1) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.1", 8080)) + .setMetadata(Metadata.newBuilder() + .putTypedFilterMetadata( + "envoy.http11_proxy_transport_socket.proxy_address", + Any.pack(newAddress("127.0.0.2", 8081).build())))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + CLUSTER, cluster)); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, ImmutableMap.of( + EDS_SERVICE_NAME, clusterLoadAssignment)); + startXdsDepManager(); + + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); + assertThat(childBalancers).hasSize(1); + FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); + + // Get the rewritten address + java.net.SocketAddress rewrittenAddress = childBalancer.addresses.get(0).getAddresses().get(0); + + // Assert that the address was rewritten + assertThat(rewrittenAddress).isInstanceOf(HttpConnectProxiedSocketAddress.class); + HttpConnectProxiedSocketAddress proxiedSocket = + (HttpConnectProxiedSocketAddress) rewrittenAddress; + + // Assert that the target address is the original address + assertThat(proxiedSocket.getTargetAddress()).isEqualTo(newInetSocketAddress("127.0.0.1", 8080)); + + // Assert that the proxy address is correctly set from locality metadata + assertThat(proxiedSocket.getProxyAddress()).isEqualTo(newInetSocketAddress("127.0.0.2", 8081)); + } @Test public void onlyEdsClusters_receivedEndpoints() { - ClusterResolverConfig config = new ClusterResolverConfig( - Arrays.asList(edsDiscoveryMechanism1, edsDiscoveryMechanism2), roundRobin); - deliverLbConfig(config); - assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1, EDS_SERVICE_NAME2); - assertThat(childBalancers).isEmpty(); - // CLUSTER1 has priority 1 (priority3), which has locality 2, which has endpoint3. - // CLUSTER2 has priority 1 (priority1) and 2 (priority2); priority1 has locality1, - // which has endpoint1 and endpoint2; priority2 has locality3, which has endpoint4. - EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr-1"); - EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr-2"); - EquivalentAddressGroup endpoint3 = makeAddress("endpoint-addr-3"); - EquivalentAddressGroup endpoint4 = makeAddress("endpoint-addr-4"); - LocalityLbEndpoints localityLbEndpoints1 = - LocalityLbEndpoints.create( - Arrays.asList( - LbEndpoint.create(endpoint1, 100, true), - LbEndpoint.create(endpoint2, 100, true)), - 70 /* localityWeight */, 1 /* priority */); - LocalityLbEndpoints localityLbEndpoints2 = - LocalityLbEndpoints.create( - Collections.singletonList(LbEndpoint.create(endpoint3, 100, true)), - 10 /* localityWeight */, 1 /* priority */); - LocalityLbEndpoints localityLbEndpoints3 = - LocalityLbEndpoints.create( - Collections.singletonList(LbEndpoint.create(endpoint4, 100, true)), - 20 /* localityWeight */, 2 /* priority */); - String priority1 = CLUSTER2 + "[child1]"; - String priority2 = CLUSTER2 + "[child2]"; - String priority3 = CLUSTER1 + "[child1]"; - - // CLUSTER2: locality1 with priority 1 and locality3 with priority 2. - xdsClient.deliverClusterLoadAssignment( - EDS_SERVICE_NAME2, - ImmutableMap.of(locality1, localityLbEndpoints1, locality3, localityLbEndpoints3)); - assertThat(childBalancers).isEmpty(); // not created until all clusters resolved - - // CLUSTER1: locality2 with priority 1. - xdsClient.deliverClusterLoadAssignment( - EDS_SERVICE_NAME1, Collections.singletonMap(locality2, localityLbEndpoints2)); - - // Endpoints of all clusters have been resolved. + // Has two localities with different priorities + ClusterLoadAssignment clusterLoadAssignment = ClusterLoadAssignment.newBuilder() + .setClusterName(EDS_SERVICE_NAME) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(70)) + .setPriority(0) + .setLocality(LOCALITY1) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.1", 8080)) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.2", 8080))) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(30)) + .setPriority(1) + .setLocality(LOCALITY2) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.3", 8080))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, ImmutableMap.of( + EDS_SERVICE_NAME, clusterLoadAssignment)); + startXdsDepManager(); + + String priority1 = CLUSTER + "[child1]"; + String priority2 = CLUSTER + "[child2]"; + + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); PriorityLbConfig priorityLbConfig = (PriorityLbConfig) childBalancer.config; assertThat(priorityLbConfig.priorities) - .containsExactly(priority3, priority1, priority2).inOrder(); + .containsExactly(priority1, priority2).inOrder(); PriorityChildConfig priorityChildConfig1 = priorityLbConfig.childConfigs.get(priority1); assertThat(priorityChildConfig1.ignoreReresolution).isTrue(); - assertThat(priorityChildConfig1.policySelection.getProvider().getPolicyName()) + assertThat(GracefulSwitchLoadBalancerAccessor.getChildProvider(priorityChildConfig1.childConfig) + .getPolicyName()) .isEqualTo(CLUSTER_IMPL_POLICY_NAME); - ClusterImplConfig clusterImplConfig1 = - (ClusterImplConfig) priorityChildConfig1.policySelection.getConfig(); - assertClusterImplConfig(clusterImplConfig1, CLUSTER2, EDS_SERVICE_NAME2, LRS_SERVER_INFO, 200L, - tlsContext, Collections.emptyList(), WRR_LOCALITY_POLICY_NAME); - assertThat(clusterImplConfig1.childPolicy.getConfig()).isInstanceOf(WrrLocalityConfig.class); - WrrLocalityConfig wrrLocalityConfig1 = - (WrrLocalityConfig) clusterImplConfig1.childPolicy.getConfig(); - assertThat(wrrLocalityConfig1.childPolicy.getProvider().getPolicyName()).isEqualTo( - "round_robin"); + ClusterImplConfig clusterImplConfig1 = (ClusterImplConfig) + GracefulSwitchLoadBalancerAccessor.getChildConfig(priorityChildConfig1.childConfig); + assertClusterImplConfig(clusterImplConfig1, CLUSTER, EDS_SERVICE_NAME, null, null, + null, Collections.emptyList(), WRR_LOCALITY_POLICY_NAME); + WrrLocalityConfig wrrLocalityConfig1 = (WrrLocalityConfig) + GracefulSwitchLoadBalancerAccessor.getChildConfig(clusterImplConfig1.childConfig); + LoadBalancerProvider childProvider1 = + GracefulSwitchLoadBalancerAccessor.getChildProvider(wrrLocalityConfig1.childConfig); + assertThat(childProvider1.getPolicyName()).isEqualTo("round_robin"); PriorityChildConfig priorityChildConfig2 = priorityLbConfig.childConfigs.get(priority2); assertThat(priorityChildConfig2.ignoreReresolution).isTrue(); - assertThat(priorityChildConfig2.policySelection.getProvider().getPolicyName()) + assertThat(GracefulSwitchLoadBalancerAccessor.getChildProvider(priorityChildConfig2.childConfig) + .getPolicyName()) .isEqualTo(CLUSTER_IMPL_POLICY_NAME); - ClusterImplConfig clusterImplConfig2 = - (ClusterImplConfig) priorityChildConfig2.policySelection.getConfig(); - assertClusterImplConfig(clusterImplConfig2, CLUSTER2, EDS_SERVICE_NAME2, LRS_SERVER_INFO, 200L, - tlsContext, Collections.emptyList(), WRR_LOCALITY_POLICY_NAME); - assertThat(clusterImplConfig2.childPolicy.getConfig()).isInstanceOf(WrrLocalityConfig.class); - WrrLocalityConfig wrrLocalityConfig2 = - (WrrLocalityConfig) clusterImplConfig1.childPolicy.getConfig(); - assertThat(wrrLocalityConfig2.childPolicy.getProvider().getPolicyName()).isEqualTo( - "round_robin"); - - PriorityChildConfig priorityChildConfig3 = priorityLbConfig.childConfigs.get(priority3); - assertThat(priorityChildConfig3.ignoreReresolution).isTrue(); - assertThat(priorityChildConfig3.policySelection.getProvider().getPolicyName()) - .isEqualTo(CLUSTER_IMPL_POLICY_NAME); - ClusterImplConfig clusterImplConfig3 = - (ClusterImplConfig) priorityChildConfig3.policySelection.getConfig(); - assertClusterImplConfig(clusterImplConfig3, CLUSTER1, EDS_SERVICE_NAME1, LRS_SERVER_INFO, 100L, - tlsContext, Collections.emptyList(), WRR_LOCALITY_POLICY_NAME); - assertThat(clusterImplConfig3.childPolicy.getConfig()).isInstanceOf(WrrLocalityConfig.class); - WrrLocalityConfig wrrLocalityConfig3 = - (WrrLocalityConfig) clusterImplConfig1.childPolicy.getConfig(); - assertThat(wrrLocalityConfig3.childPolicy.getProvider().getPolicyName()).isEqualTo( - "round_robin"); - + ClusterImplConfig clusterImplConfig2 = (ClusterImplConfig) + GracefulSwitchLoadBalancerAccessor.getChildConfig(priorityChildConfig2.childConfig); + assertClusterImplConfig(clusterImplConfig2, CLUSTER, EDS_SERVICE_NAME, null, null, + null, Collections.emptyList(), WRR_LOCALITY_POLICY_NAME); + WrrLocalityConfig wrrLocalityConfig2 = (WrrLocalityConfig) + GracefulSwitchLoadBalancerAccessor.getChildConfig(clusterImplConfig1.childConfig); + LoadBalancerProvider childProvider2 = + GracefulSwitchLoadBalancerAccessor.getChildProvider(wrrLocalityConfig2.childConfig); + assertThat(childProvider2.getPolicyName()).isEqualTo("round_robin"); + + WrrLocalityConfig wrrLocalityConfig3 = (WrrLocalityConfig) + GracefulSwitchLoadBalancerAccessor.getChildConfig(clusterImplConfig1.childConfig); + LoadBalancerProvider childProvider3 = + GracefulSwitchLoadBalancerAccessor.getChildProvider(wrrLocalityConfig3.childConfig); + assertThat(childProvider3.getPolicyName()).isEqualTo("round_robin"); + + io.grpc.xds.client.Locality locality1 = io.grpc.xds.client.Locality.create( + LOCALITY1.getRegion(), LOCALITY1.getZone(), LOCALITY1.getSubZone()); + io.grpc.xds.client.Locality locality2 = io.grpc.xds.client.Locality.create( + LOCALITY2.getRegion(), LOCALITY2.getZone(), LOCALITY2.getSubZone()); for (EquivalentAddressGroup eag : childBalancer.addresses) { - if (eag.getAttributes().get(InternalXdsAttributes.ATTR_LOCALITY) == locality1) { - assertThat(eag.getAttributes().get(InternalXdsAttributes.ATTR_LOCALITY_WEIGHT)) + io.grpc.xds.client.Locality locality = + eag.getAttributes().get(io.grpc.xds.XdsAttributes.ATTR_LOCALITY); + if (locality.equals(locality1)) { + assertThat(eag.getAttributes().get(io.grpc.xds.XdsAttributes.ATTR_LOCALITY_WEIGHT)) .isEqualTo(70); - } - if (eag.getAttributes().get(InternalXdsAttributes.ATTR_LOCALITY) == locality2) { - assertThat(eag.getAttributes().get(InternalXdsAttributes.ATTR_LOCALITY_WEIGHT)) - .isEqualTo(10); - } - if (eag.getAttributes().get(InternalXdsAttributes.ATTR_LOCALITY) == locality3) { - assertThat(eag.getAttributes().get(InternalXdsAttributes.ATTR_LOCALITY_WEIGHT)) - .isEqualTo(20); + } else if (locality.equals(locality2)) { + assertThat(eag.getAttributes().get(io.grpc.xds.XdsAttributes.ATTR_LOCALITY_WEIGHT)) + .isEqualTo(30); + } else { + throw new AssertionError("Unexpected locality region: " + locality.region()); } } } @SuppressWarnings("unchecked") - private void verifyEdsPriorityNames(List want, - Map... updates) { - ClusterResolverConfig config = new ClusterResolverConfig( - Arrays.asList(edsDiscoveryMechanism2), roundRobin); - deliverLbConfig(config); - assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME2); - assertThat(childBalancers).isEmpty(); - - for (Map update: updates) { - xdsClient.deliverClusterLoadAssignment( - EDS_SERVICE_NAME2, - update); + private void verifyEdsPriorityNames(List want, List... updates) { + Iterator edsUpdates = Arrays.asList(updates).stream() + .map(update -> ClusterLoadAssignment.newBuilder() + .setClusterName(EDS_SERVICE_NAME) + .addAllEndpoints(update) + .build()) + .iterator(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, ImmutableMap.of( + EDS_SERVICE_NAME, edsUpdates.next())); + startXdsDepManager(); + + while (edsUpdates.hasNext()) { + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, ImmutableMap.of( + EDS_SERVICE_NAME, edsUpdates.next())); } + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); @@ -553,218 +640,273 @@ private void verifyEdsPriorityNames(List want, @Test @SuppressWarnings("unchecked") public void edsUpdatePriorityName_twoPriorities() { - verifyEdsPriorityNames(Arrays.asList(CLUSTER2 + "[child1]", CLUSTER2 + "[child2]"), - ImmutableMap.of(locality1, createEndpoints(1), - locality2, createEndpoints(2) - )); + verifyEdsPriorityNames(Arrays.asList(CLUSTER + "[child1]", CLUSTER + "[child2]"), + Arrays.asList(createEndpoints(LOCALITY1, 0), createEndpoints(LOCALITY2, 1))); } @Test @SuppressWarnings("unchecked") public void edsUpdatePriorityName_addOnePriority() { - verifyEdsPriorityNames(Arrays.asList(CLUSTER2 + "[child2]"), - ImmutableMap.of(locality1, createEndpoints(1)), - ImmutableMap.of(locality2, createEndpoints(1) - )); + verifyEdsPriorityNames(Arrays.asList(CLUSTER + "[child2]"), + Arrays.asList(createEndpoints(LOCALITY1, 0)), + Arrays.asList(createEndpoints(LOCALITY2, 0))); } @Test @SuppressWarnings("unchecked") public void edsUpdatePriorityName_swapTwoPriorities() { - verifyEdsPriorityNames(Arrays.asList(CLUSTER2 + "[child2]", CLUSTER2 + "[child1]", - CLUSTER2 + "[child3]"), - ImmutableMap.of(locality1, createEndpoints(1), - locality2, createEndpoints(2), - locality3, createEndpoints(3) - ), - ImmutableMap.of(locality1, createEndpoints(2), - locality2, createEndpoints(1), - locality3, createEndpoints(3)) - ); + verifyEdsPriorityNames(Arrays.asList(CLUSTER + "[child2]", CLUSTER + "[child1]", + CLUSTER + "[child3]"), + Arrays.asList( + createEndpoints(LOCALITY1, 0), + createEndpoints(LOCALITY2, 1), + createEndpoints(LOCALITY3, 2)), + Arrays.asList( + createEndpoints(LOCALITY1, 1), + createEndpoints(LOCALITY2, 0), + createEndpoints(LOCALITY3, 2))); } @Test @SuppressWarnings("unchecked") public void edsUpdatePriorityName_mergeTwoPriorities() { - verifyEdsPriorityNames(Arrays.asList(CLUSTER2 + "[child3]", CLUSTER2 + "[child1]"), - ImmutableMap.of(locality1, createEndpoints(1), - locality3, createEndpoints(3), - locality2, createEndpoints(2)), - ImmutableMap.of(locality1, createEndpoints(2), - locality3, createEndpoints(1), - locality2, createEndpoints(1) - )); + verifyEdsPriorityNames(Arrays.asList(CLUSTER + "[child3]", CLUSTER + "[child1]"), + Arrays.asList( + createEndpoints(LOCALITY1, 0), + createEndpoints(LOCALITY3, 2), + createEndpoints(LOCALITY2, 1)), + Arrays.asList( + createEndpoints(LOCALITY1, 1), + createEndpoints(LOCALITY3, 0), + createEndpoints(LOCALITY2, 0))); } - private LocalityLbEndpoints createEndpoints(int priority) { - return LocalityLbEndpoints.create( - Arrays.asList( - LbEndpoint.create(makeAddress("endpoint-addr-1"), 100, true), - LbEndpoint.create(makeAddress("endpoint-addr-2"), 100, true)), - 70 /* localityWeight */, priority /* priority */); + private LocalityLbEndpoints createEndpoints(Locality locality, int priority) { + return LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(70)) + .setLocality(locality) + .setPriority(priority) + .addLbEndpoints(newSocketLbEndpoint("127.0." + priority + ".1", 8080)) + .build(); } @Test public void onlyEdsClusters_resourceNeverExist_returnErrorPicker() { - ClusterResolverConfig config = new ClusterResolverConfig( - Arrays.asList(edsDiscoveryMechanism1, edsDiscoveryMechanism2), roundRobin); - deliverLbConfig(config); - assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1, EDS_SERVICE_NAME2); - assertThat(childBalancers).isEmpty(); - reset(helper); - xdsClient.deliverResourceNotFound(EDS_SERVICE_NAME1); - verify(helper, never()).updateBalancingState( - any(ConnectivityState.class), any(SubchannelPicker.class)); // wait for CLUSTER2's results + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of()); + startXdsDepManager(); - xdsClient.deliverResourceNotFound(EDS_SERVICE_NAME2); verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - assertPicker( - pickerCaptor.getValue(), - Status.UNAVAILABLE.withDescription( - "No usable endpoint from cluster(s): " + Arrays.asList(CLUSTER1, CLUSTER2)), - null); + String expectedDescription = "Error retrieving CDS resource " + CLUSTER + " nodeID: node-id: " + + "NOT_FOUND: Timed out waiting for resource " + CLUSTER + " from xDS server"; + Status expectedError = Status.UNAVAILABLE.withDescription(expectedDescription); + assertPicker(pickerCaptor.getValue(), expectedError, null); } @Test - public void onlyEdsClusters_allResourcesRevoked_shutDownChildLbPolicy() { - ClusterResolverConfig config = new ClusterResolverConfig( - Arrays.asList(edsDiscoveryMechanism1, edsDiscoveryMechanism2), roundRobin); - deliverLbConfig(config); - assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1, EDS_SERVICE_NAME2); - assertThat(childBalancers).isEmpty(); - reset(helper); - EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr-1"); - EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr-2"); - LocalityLbEndpoints localityLbEndpoints1 = - LocalityLbEndpoints.create( - Collections.singletonList(LbEndpoint.create(endpoint1, 100, true)), - 10 /* localityWeight */, 1 /* priority */); - LocalityLbEndpoints localityLbEndpoints2 = - LocalityLbEndpoints.create( - Collections.singletonList(LbEndpoint.create(endpoint2, 100, true)), - 20 /* localityWeight */, 2 /* priority */); - xdsClient.deliverClusterLoadAssignment( - EDS_SERVICE_NAME1, Collections.singletonMap(locality1, localityLbEndpoints1)); - xdsClient.deliverClusterLoadAssignment( - EDS_SERVICE_NAME2, Collections.singletonMap(locality2, localityLbEndpoints2)); + public void cdsMissing_handledDirectly() { + ClusterLoadAssignment clusterLoadAssignment = ClusterLoadAssignment.newBuilder() + .setClusterName(EDS_SERVICE_NAME) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(100)) + .setLocality(LOCALITY1) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.1", 8000))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of()); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, ImmutableMap.of( + EDS_SERVICE_NAME, clusterLoadAssignment)); + + startXdsDepManager(); + assertThat(childBalancers).hasSize(0); // no child LB policy created + verify(helper).updateBalancingState( + eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); + String expectedDescription = "Error retrieving CDS resource " + CLUSTER + " nodeID: node-id: " + + "NOT_FOUND: Timed out waiting for resource " + CLUSTER + " from xDS server"; + Status expectedError = Status.UNAVAILABLE.withDescription(expectedDescription); + assertPicker(pickerCaptor.getValue(), expectedError, null); + assertPicker(pickerCaptor.getValue(), expectedError, null); + } + + @Test + public void cdsRevoked_handledDirectly() { + ClusterLoadAssignment clusterLoadAssignment = ClusterLoadAssignment.newBuilder() + .setClusterName(EDS_SERVICE_NAME) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(100)) + .setLocality(LOCALITY1) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.1", 8000))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, ImmutableMap.of( + EDS_SERVICE_NAME, clusterLoadAssignment)); + + startXdsDepManager(); assertThat(childBalancers).hasSize(1); // child LB policy created FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - assertThat(((PriorityLbConfig) childBalancer.config).priorities).hasSize(2); - assertAddressesEqual(Arrays.asList(endpoint1, endpoint2), childBalancer.addresses); + assertThat(((PriorityLbConfig) childBalancer.config).priorities).hasSize(1); + assertThat(childBalancer.addresses).hasSize(1); + assertAddressesEqual( + Arrays.asList(newInetSocketAddressEag("127.0.0.1", 8000)), + childBalancer.addresses); - xdsClient.deliverResourceNotFound(EDS_SERVICE_NAME2); - xdsClient.deliverResourceNotFound(EDS_SERVICE_NAME1); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of()); verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - Status expectedError = Status.UNAVAILABLE.withDescription( - "No usable endpoint from cluster(s): " + Arrays.asList(CLUSTER1, CLUSTER2)); + String expectedDescription = "Error retrieving CDS resource " + CLUSTER + " nodeID: node-id: " + + "NOT_FOUND: Resource " + CLUSTER + " does not exist"; + Status expectedError = Status.UNAVAILABLE.withDescription(expectedDescription); assertPicker(pickerCaptor.getValue(), expectedError, null); + assertThat(childBalancer.shutdown).isTrue(); + } + + @Test + public void edsMissing_failsRpcs() { + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, ImmutableMap.of()); + + startXdsDepManager(); + assertThat(childBalancers).hasSize(0); // Graceful switch handles it, so no child policies yet + verify(helper).updateBalancingState( + eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); + String expectedDescription = "Error retrieving EDS resource " + EDS_SERVICE_NAME + + " nodeID: node-id: " + + "NOT_FOUND: Timed out waiting for resource " + EDS_SERVICE_NAME + " from xDS server"; + Status expectedError = Status.UNAVAILABLE.withDescription(expectedDescription); + assertPicker(pickerCaptor.getValue(), expectedError, null); + } + + @Test + public void logicalDnsLookupFailed_failsRpcs() { + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + CLUSTER, LOGICAL_DNS_CLUSTER)); + startXdsDepManager(new CdsConfig(CLUSTER), /* forwardTime= */ false); + FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME + ":9000"); + assertThat(childBalancers).isEmpty(); + Status status = Status.UNAVAILABLE.withDescription("OH NO! Who would have guessed?"); + resolver.deliverError(status); + + assertThat(childBalancers).hasSize(0); // Graceful switch handles it, so no child policies yet + verify(helper).updateBalancingState( + eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); + assertPicker(pickerCaptor.getValue(), status, null); } @Test public void handleEdsResource_ignoreUnhealthyEndpoints() { - ClusterResolverConfig config = - new ClusterResolverConfig(Collections.singletonList(edsDiscoveryMechanism1), roundRobin); - deliverLbConfig(config); - EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr-1"); - EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr-2"); - LocalityLbEndpoints localityLbEndpoints = - LocalityLbEndpoints.create( - Arrays.asList( - LbEndpoint.create(endpoint1, 100, false /* isHealthy */), - LbEndpoint.create(endpoint2, 100, true /* isHealthy */)), - 10 /* localityWeight */, 1 /* priority */); - xdsClient.deliverClusterLoadAssignment( - EDS_SERVICE_NAME1, Collections.singletonMap(locality1, localityLbEndpoints)); + ClusterLoadAssignment clusterLoadAssignment = ClusterLoadAssignment.newBuilder() + .setClusterName(EDS_SERVICE_NAME) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(100)) + .setLocality(LOCALITY1) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.1", 8000) + .setHealthStatus(HealthStatus.UNHEALTHY)) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.2", 8000))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, ImmutableMap.of( + EDS_SERVICE_NAME, clusterLoadAssignment)); + startXdsDepManager(); + + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - assertThat(childBalancer.addresses).hasSize(1); - assertAddressesEqual(Collections.singletonList(endpoint2), childBalancer.addresses); + assertAddressesEqual( + Arrays.asList(new EquivalentAddressGroup(newInetSocketAddress("127.0.0.2", 8000))), + childBalancer.addresses); } @Test public void handleEdsResource_ignoreLocalitiesWithNoHealthyEndpoints() { - ClusterResolverConfig config = - new ClusterResolverConfig(Collections.singletonList(edsDiscoveryMechanism1), roundRobin); - deliverLbConfig(config); - EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr-1"); - EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr-2"); - LocalityLbEndpoints localityLbEndpoints1 = - LocalityLbEndpoints.create( - Collections.singletonList(LbEndpoint.create(endpoint1, 100, false /* isHealthy */)), - 10 /* localityWeight */, 1 /* priority */); - LocalityLbEndpoints localityLbEndpoints2 = - LocalityLbEndpoints.create( - Collections.singletonList(LbEndpoint.create(endpoint2, 100, true /* isHealthy */)), - 10 /* localityWeight */, 1 /* priority */); - xdsClient.deliverClusterLoadAssignment( - EDS_SERVICE_NAME1, - ImmutableMap.of(locality1, localityLbEndpoints1, locality2, localityLbEndpoints2)); + ClusterLoadAssignment clusterLoadAssignment = ClusterLoadAssignment.newBuilder() + .setClusterName(EDS_SERVICE_NAME) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(100)) + .setLocality(LOCALITY1) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.1", 8000) + .setHealthStatus(HealthStatus.UNHEALTHY))) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(100)) + .setLocality(LOCALITY2) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.2", 8000))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, ImmutableMap.of( + EDS_SERVICE_NAME, clusterLoadAssignment)); + startXdsDepManager(); + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); + io.grpc.xds.client.Locality locality2 = io.grpc.xds.client.Locality.create( + LOCALITY2.getRegion(), LOCALITY2.getZone(), LOCALITY2.getSubZone()); for (EquivalentAddressGroup eag : childBalancer.addresses) { - assertThat(eag.getAttributes().get(InternalXdsAttributes.ATTR_LOCALITY)).isEqualTo(locality2); + assertThat(eag.getAttributes().get(io.grpc.xds.XdsAttributes.ATTR_LOCALITY)) + .isEqualTo(locality2); } } @Test public void handleEdsResource_ignorePrioritiesWithNoHealthyEndpoints() { - ClusterResolverConfig config = - new ClusterResolverConfig(Collections.singletonList(edsDiscoveryMechanism1), roundRobin); - deliverLbConfig(config); - EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr-1"); - EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr-2"); - LocalityLbEndpoints localityLbEndpoints1 = - LocalityLbEndpoints.create( - Collections.singletonList(LbEndpoint.create(endpoint1, 100, false /* isHealthy */)), - 10 /* localityWeight */, 1 /* priority */); - LocalityLbEndpoints localityLbEndpoints2 = - LocalityLbEndpoints.create( - Collections.singletonList(LbEndpoint.create(endpoint2, 200, true /* isHealthy */)), - 10 /* localityWeight */, 2 /* priority */); - String priority2 = CLUSTER1 + "[child2]"; - xdsClient.deliverClusterLoadAssignment( - EDS_SERVICE_NAME1, - ImmutableMap.of(locality1, localityLbEndpoints1, locality2, localityLbEndpoints2)); + ClusterLoadAssignment clusterLoadAssignment = ClusterLoadAssignment.newBuilder() + .setClusterName(EDS_SERVICE_NAME) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(100)) + .setLocality(LOCALITY1) + .setPriority(0) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.1", 8000) + .setHealthStatus(HealthStatus.UNHEALTHY))) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(100)) + .setLocality(LOCALITY2) + .setPriority(1) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.2", 8000))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, ImmutableMap.of( + EDS_SERVICE_NAME, clusterLoadAssignment)); + startXdsDepManager(); + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); + String priority2 = CLUSTER + "[child2]"; FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); assertThat(((PriorityLbConfig) childBalancer.config).priorities).containsExactly(priority2); } @Test public void handleEdsResource_noHealthyEndpoint() { - ClusterResolverConfig config = - new ClusterResolverConfig(Collections.singletonList(edsDiscoveryMechanism1), roundRobin); - deliverLbConfig(config); - EquivalentAddressGroup endpoint = makeAddress("endpoint-addr-1"); - LocalityLbEndpoints localityLbEndpoints = - LocalityLbEndpoints.create( - Collections.singletonList(LbEndpoint.create(endpoint, 100, false /* isHealthy */)), - 10 /* localityWeight */, 1 /* priority */); - xdsClient.deliverClusterLoadAssignment(EDS_SERVICE_NAME1, - Collections.singletonMap(locality1, localityLbEndpoints)); // single endpoint, unhealthy + ClusterLoadAssignment clusterLoadAssignment = ClusterLoadAssignment.newBuilder() + .setClusterName(EDS_SERVICE_NAME) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(100)) + .setLocality(LOCALITY1) + .addLbEndpoints(newSocketLbEndpoint("127.0.0.1", 8000) + .setHealthStatus(HealthStatus.UNHEALTHY))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, ImmutableMap.of( + EDS_SERVICE_NAME, clusterLoadAssignment)); + startXdsDepManager(); - assertThat(childBalancers).isEmpty(); + assertThat(childBalancers).hasSize(0); // Graceful switch handles it, so no child policies yet verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - assertPicker( - pickerCaptor.getValue(), - Status.UNAVAILABLE.withDescription( - "No usable endpoint from cluster(s): " + Collections.singleton(CLUSTER1)), - null); + Status expectedStatus = Status.UNAVAILABLE + .withDescription("No usable endpoint from cluster: " + CLUSTER); + assertPicker(pickerCaptor.getValue(), expectedStatus, null); } @Test public void onlyLogicalDnsCluster_endpointsResolved() { - ClusterResolverConfig config = new ClusterResolverConfig( - Collections.singletonList(logicalDnsDiscoveryMechanism), roundRobin); - deliverLbConfig(config); - FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME); + boolean originalVal = LoadStatsManager2.isEnabledOrcaLrsPropagation; + LoadStatsManager2.isEnabledOrcaLrsPropagation = true; + List metricSpecs = Arrays.asList("cpu_utilization"); + BackendMetricPropagation backendMetricPropagation = + BackendMetricPropagation.fromMetricSpecs(metricSpecs); + Cluster logicalDnsClusterWithMetrics = LOGICAL_DNS_CLUSTER.toBuilder() + .addAllLrsReportEndpointMetrics(metricSpecs) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + CLUSTER, logicalDnsClusterWithMetrics)); + startXdsDepManager(new CdsConfig(CLUSTER), /* forwardTime= */ false); + FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME + ":9000"); assertThat(childBalancers).isEmpty(); - EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr-1"); - EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr-2"); - resolver.deliverEndpointAddresses(Arrays.asList(endpoint1, endpoint2)); + resolver.deliverEndpointAddresses(Arrays.asList( + newInetSocketAddressEag("127.0.2.1", 9000), newInetSocketAddressEag("127.0.2.2", 9000))); + fakeClock.forwardTime(10, TimeUnit.MINUTES); + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); @@ -772,340 +914,145 @@ public void onlyLogicalDnsCluster_endpointsResolved() { String priority = Iterables.getOnlyElement(priorityLbConfig.priorities); PriorityChildConfig priorityChildConfig = priorityLbConfig.childConfigs.get(priority); assertThat(priorityChildConfig.ignoreReresolution).isFalse(); - assertThat(priorityChildConfig.policySelection.getProvider().getPolicyName()) + assertThat(GracefulSwitchLoadBalancerAccessor.getChildProvider(priorityChildConfig.childConfig) + .getPolicyName()) .isEqualTo(CLUSTER_IMPL_POLICY_NAME); - ClusterImplConfig clusterImplConfig = - (ClusterImplConfig) priorityChildConfig.policySelection.getConfig(); - assertClusterImplConfig(clusterImplConfig, CLUSTER_DNS, null, LRS_SERVER_INFO, 300L, null, - Collections.emptyList(), "pick_first"); - assertAddressesEqual(Arrays.asList(endpoint1, endpoint2), childBalancer.addresses); + ClusterImplConfig clusterImplConfig = (ClusterImplConfig) + GracefulSwitchLoadBalancerAccessor.getChildConfig(priorityChildConfig.childConfig); + assertClusterImplConfig(clusterImplConfig, CLUSTER, null, null, null, null, + Collections.emptyList(), "wrr_locality_experimental"); + assertThat(clusterImplConfig.backendMetricPropagation).isEqualTo(backendMetricPropagation); + LoadStatsManager2.isEnabledOrcaLrsPropagation = originalVal; + assertAddressesEqual( + Arrays.asList(new EquivalentAddressGroup(Arrays.asList( + newInetSocketAddress("127.0.2.1", 9000), newInetSocketAddress("127.0.2.2", 9000)))), + childBalancer.addresses); + assertThat(childBalancer.addresses.get(0).getAttributes() + .get(InternalEquivalentAddressGroup.ATTR_BACKEND_SERVICE)).isEqualTo(CLUSTER); + assertThat(childBalancer.addresses.get(0).getAttributes() + .get(XdsInternalAttributes.ATTR_ADDRESS_NAME)).isEqualTo(DNS_HOST_NAME + ":9000"); } @Test public void onlyLogicalDnsCluster_handleRefreshNameResolution() { - ClusterResolverConfig config = new ClusterResolverConfig( - Collections.singletonList(logicalDnsDiscoveryMechanism), roundRobin); - deliverLbConfig(config); - FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + CLUSTER, LOGICAL_DNS_CLUSTER)); + startXdsDepManager(new CdsConfig(CLUSTER), /* forwardTime= */ false); + FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME + ":9000"); assertThat(childBalancers).isEmpty(); - EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr-1"); - EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr-2"); - resolver.deliverEndpointAddresses(Arrays.asList(endpoint1, endpoint2)); - assertThat(resolver.refreshCount).isEqualTo(0); - FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - childBalancer.helper.refreshNameResolution(); - assertThat(resolver.refreshCount).isEqualTo(1); - } + resolver.deliverEndpointAddresses(Arrays.asList(newInetSocketAddressEag("127.0.2.1", 9000))); + fakeClock.forwardTime(10, TimeUnit.MINUTES); - @Test - public void onlyLogicalDnsCluster_resolutionError_backoffAndRefresh() { - InOrder inOrder = Mockito.inOrder(helper, backoffPolicyProvider, - backoffPolicy1, backoffPolicy2); - ClusterResolverConfig config = new ClusterResolverConfig( - Collections.singletonList(logicalDnsDiscoveryMechanism), roundRobin); - deliverLbConfig(config); - FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME); - assertThat(childBalancers).isEmpty(); - Status error = Status.UNAVAILABLE.withDescription("cannot reach DNS server"); - resolver.deliverError(error); - inOrder.verify(helper).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - assertPicker(pickerCaptor.getValue(), error, null); + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); assertThat(resolver.refreshCount).isEqualTo(0); - inOrder.verify(backoffPolicyProvider).get(); - inOrder.verify(backoffPolicy1).nextBackoffNanos(); - assertThat(fakeClock.getPendingTasks()).hasSize(1); - assertThat(Iterables.getOnlyElement(fakeClock.getPendingTasks()).getDelay(TimeUnit.SECONDS)) - .isEqualTo(1L); - fakeClock.forwardTime(1L, TimeUnit.SECONDS); - assertThat(resolver.refreshCount).isEqualTo(1); - - error = Status.UNKNOWN.withDescription("I am lost"); - resolver.deliverError(error); - inOrder.verify(helper).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - inOrder.verify(backoffPolicy1).nextBackoffNanos(); - assertPicker(pickerCaptor.getValue(), error, null); - assertThat(fakeClock.getPendingTasks()).hasSize(1); - assertThat(Iterables.getOnlyElement(fakeClock.getPendingTasks()).getDelay(TimeUnit.SECONDS)) - .isEqualTo(10L); - fakeClock.forwardTime(10L, TimeUnit.SECONDS); - assertThat(resolver.refreshCount).isEqualTo(2); - - // Succeed. - EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr-1"); - EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr-2"); - resolver.deliverEndpointAddresses(Arrays.asList(endpoint1, endpoint2)); - assertThat(childBalancers).hasSize(1); - assertAddressesEqual(Arrays.asList(endpoint1, endpoint2), - Iterables.getOnlyElement(childBalancers).addresses); - - assertThat(fakeClock.getPendingTasks()).isEmpty(); - inOrder.verifyNoMoreInteractions(); - } - - @Test - public void onlyLogicalDnsCluster_refreshNameResolutionRaceWithResolutionError() { - InOrder inOrder = Mockito.inOrder(backoffPolicyProvider, backoffPolicy1, backoffPolicy2); - ClusterResolverConfig config = new ClusterResolverConfig( - Collections.singletonList(logicalDnsDiscoveryMechanism), roundRobin); - deliverLbConfig(config); - FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME); - assertThat(childBalancers).isEmpty(); - EquivalentAddressGroup endpoint = makeAddress("endpoint-addr"); - resolver.deliverEndpointAddresses(Collections.singletonList(endpoint)); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - assertAddressesEqual(Collections.singletonList(endpoint), childBalancer.addresses); - assertThat(resolver.refreshCount).isEqualTo(0); - childBalancer.helper.refreshNameResolution(); assertThat(resolver.refreshCount).isEqualTo(1); - resolver.deliverError(Status.UNAVAILABLE.withDescription("I am lost")); - inOrder.verify(backoffPolicyProvider).get(); - inOrder.verify(backoffPolicy1).nextBackoffNanos(); - assertThat(fakeClock.getPendingTasks()).hasSize(1); - ScheduledTask task = Iterables.getOnlyElement(fakeClock.getPendingTasks()); - assertThat(task.getDelay(TimeUnit.SECONDS)).isEqualTo(1L); - - fakeClock.forwardTime( 100L, TimeUnit.MILLISECONDS); - childBalancer.helper.refreshNameResolution(); - assertThat(resolver.refreshCount).isEqualTo(2); - assertThat(task.isCancelled()).isTrue(); - assertThat(fakeClock.getPendingTasks()).isEmpty(); - resolver.deliverError(Status.UNAVAILABLE.withDescription("I am still lost")); - inOrder.verify(backoffPolicyProvider).get(); // active refresh resets backoff sequence - inOrder.verify(backoffPolicy2).nextBackoffNanos(); - task = Iterables.getOnlyElement(fakeClock.getPendingTasks()); - assertThat(task.getDelay(TimeUnit.SECONDS)).isEqualTo(5L); - - fakeClock.forwardTime(5L, TimeUnit.SECONDS); - assertThat(resolver.refreshCount).isEqualTo(3); - inOrder.verifyNoMoreInteractions(); } @Test - public void edsClustersAndLogicalDnsCluster_receivedEndpoints() { - ClusterResolverConfig config = new ClusterResolverConfig( - Arrays.asList(edsDiscoveryMechanism1, logicalDnsDiscoveryMechanism), roundRobin); - deliverLbConfig(config); - assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1); - FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME); - assertThat(childBalancers).isEmpty(); - EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr-1"); // DNS endpoint - EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr-2"); // DNS endpoint - EquivalentAddressGroup endpoint3 = makeAddress("endpoint-addr-3"); // EDS endpoint - resolver.deliverEndpointAddresses(Arrays.asList(endpoint1, endpoint2)); - LocalityLbEndpoints localityLbEndpoints = - LocalityLbEndpoints.create( - Collections.singletonList(LbEndpoint.create(endpoint3, 100, true)), - 10 /* localityWeight */, 1 /* priority */); - xdsClient.deliverClusterLoadAssignment( - EDS_SERVICE_NAME1, Collections.singletonMap(locality1, localityLbEndpoints)); + public void outlierDetection_disabledConfig() { + Cluster cluster = EDS_CLUSTER.toBuilder() + .setOutlierDetection(OutlierDetection.newBuilder() + .setEnforcingSuccessRate(UInt32Value.of(0)) + .setEnforcingFailurePercentage(UInt32Value.of(0))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + CLUSTER, cluster)); + startXdsDepManager(); + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - assertThat(((PriorityLbConfig) childBalancer.config).priorities) - .containsExactly(CLUSTER1 + "[child1]", CLUSTER_DNS + "[child0]").inOrder(); - assertAddressesEqual(Arrays.asList(endpoint3, endpoint1, endpoint2), - childBalancer.addresses); // ordered by cluster then addresses - assertAddressesEqual(AddressFilter.filter(AddressFilter.filter( - childBalancer.addresses, CLUSTER1 + "[child1]"), locality1.toString()), - Collections.singletonList(endpoint3)); - assertAddressesEqual(AddressFilter.filter(AddressFilter.filter( - childBalancer.addresses, CLUSTER_DNS + "[child0]"), - Locality.create("", "", "").toString()), - Arrays.asList(endpoint1, endpoint2)); + assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); + PriorityLbConfig priorityLbConfig = (PriorityLbConfig) childBalancer.config; + PriorityChildConfig priorityChildConfig = + Iterables.getOnlyElement(priorityLbConfig.childConfigs.values()); + OutlierDetectionLoadBalancerConfig outlier = (OutlierDetectionLoadBalancerConfig) + GracefulSwitchLoadBalancerAccessor.getChildConfig(priorityChildConfig.childConfig); + assertThat(outlier.successRateEjection).isNull(); + assertThat(outlier.failurePercentageEjection).isNull(); } @Test - public void noEdsResourceExists_useDnsResolutionResults() { - ClusterResolverConfig config = new ClusterResolverConfig( - Arrays.asList(edsDiscoveryMechanism1, logicalDnsDiscoveryMechanism), roundRobin); - deliverLbConfig(config); - assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1); - FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME); - assertThat(childBalancers).isEmpty(); - reset(helper); - xdsClient.deliverResourceNotFound(EDS_SERVICE_NAME1); - verify(helper, never()).updateBalancingState( - any(ConnectivityState.class), any(SubchannelPicker.class)); // wait for DNS results - - EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr-1"); - EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr-2"); - resolver.deliverEndpointAddresses(Arrays.asList(endpoint1, endpoint2)); - assertThat(childBalancers).hasSize(1); - FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - String priority = Iterables.getOnlyElement( - ((PriorityLbConfig) childBalancer.config).priorities); - assertThat(priority).isEqualTo(CLUSTER_DNS + "[child0]"); - assertAddressesEqual(Arrays.asList(endpoint1, endpoint2), childBalancer.addresses); - } + public void outlierDetection_fullConfig() { + Cluster cluster = EDS_CLUSTER.toBuilder() + .setLbPolicy(Cluster.LbPolicy.ROUND_ROBIN) + .setOutlierDetection(OutlierDetection.newBuilder() + .setInterval(Duration.newBuilder().setNanos(101)) + .setBaseEjectionTime(Duration.newBuilder().setNanos(102)) + .setMaxEjectionTime(Duration.newBuilder().setNanos(103)) + .setMaxEjectionPercent(UInt32Value.of(80)) + .setSuccessRateStdevFactor(UInt32Value.of(105)) + .setEnforcingSuccessRate(UInt32Value.of(81)) + .setSuccessRateMinimumHosts(UInt32Value.of(107)) + .setSuccessRateRequestVolume(UInt32Value.of(108)) + .setFailurePercentageThreshold(UInt32Value.of(82)) + .setEnforcingFailurePercentage(UInt32Value.of(83)) + .setFailurePercentageMinimumHosts(UInt32Value.of(111)) + .setFailurePercentageRequestVolume(UInt32Value.of(112))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + CLUSTER, cluster)); + startXdsDepManager(); - @Test - public void edsResourceRevoked_dnsResolutionError_shutDownChildLbPolicyAndReturnErrorPicker() { - ClusterResolverConfig config = new ClusterResolverConfig( - Arrays.asList(edsDiscoveryMechanism1, logicalDnsDiscoveryMechanism), roundRobin); - deliverLbConfig(config); - assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1); - FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME); - assertThat(childBalancers).isEmpty(); - reset(helper); - EquivalentAddressGroup endpoint = makeAddress("endpoint-addr-1"); - LocalityLbEndpoints localityLbEndpoints = - LocalityLbEndpoints.create( - Collections.singletonList(LbEndpoint.create(endpoint, 100, true)), - 10 /* localityWeight */, 1 /* priority */); - xdsClient.deliverClusterLoadAssignment( - EDS_SERVICE_NAME1, Collections.singletonMap(locality1, localityLbEndpoints)); - resolver.deliverError(Status.UNKNOWN.withDescription("I am lost")); + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - assertThat(((PriorityLbConfig) childBalancer.config).priorities) - .containsExactly(CLUSTER1 + "[child1]"); - assertAddressesEqual(Collections.singletonList(endpoint), childBalancer.addresses); - assertThat(childBalancer.shutdown).isFalse(); - xdsClient.deliverResourceNotFound(EDS_SERVICE_NAME1); - assertThat(childBalancer.shutdown).isTrue(); - verify(helper).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - assertPicker(pickerCaptor.getValue(), - Status.UNAVAILABLE.withDescription("I am lost"), null); - } - - @Test - public void resolutionErrorAfterChildLbCreated_propagateErrorIfAllClustersEncounterError() { - ClusterResolverConfig config = new ClusterResolverConfig( - Arrays.asList(edsDiscoveryMechanism1, logicalDnsDiscoveryMechanism), roundRobin); - deliverLbConfig(config); - assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1); - FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME); - assertThat(childBalancers).isEmpty(); - reset(helper); - EquivalentAddressGroup endpoint = makeAddress("endpoint-addr-1"); - LocalityLbEndpoints localityLbEndpoints = - LocalityLbEndpoints.create( - Collections.singletonList(LbEndpoint.create(endpoint, 100, true)), - 10 /* localityWeight */, 1 /* priority */); - xdsClient.deliverClusterLoadAssignment( - EDS_SERVICE_NAME1, Collections.singletonMap(locality1, localityLbEndpoints)); - assertThat(childBalancers).isEmpty(); // not created until all clusters resolved. - - resolver.deliverError(Status.UNKNOWN.withDescription("I am lost")); - - // DNS resolution failed, but there are EDS endpoints can be used. - assertThat(childBalancers).hasSize(1); - FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); // child LB created - assertThat(childBalancer.upstreamError).isNull(); // should not propagate error to child LB - assertAddressesEqual(Collections.singletonList(endpoint), childBalancer.addresses); - - xdsClient.deliverError(Status.RESOURCE_EXHAUSTED.withDescription("out of memory")); - assertThat(childBalancer.upstreamError).isNotNull(); // last cluster's (DNS) error propagated - assertThat(childBalancer.upstreamError.getCode()).isEqualTo(Code.UNKNOWN); - assertThat(childBalancer.upstreamError.getDescription()).isEqualTo("I am lost"); - assertThat(childBalancer.shutdown).isFalse(); - verify(helper, never()).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), any(SubchannelPicker.class)); - } - - @Test - public void resolutionErrorBeforeChildLbCreated_returnErrorPickerIfAllClustersEncounterError() { - ClusterResolverConfig config = new ClusterResolverConfig( - Arrays.asList(edsDiscoveryMechanism1, logicalDnsDiscoveryMechanism), roundRobin); - deliverLbConfig(config); - assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1); - FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME); - assertThat(childBalancers).isEmpty(); - reset(helper); - xdsClient.deliverError(Status.UNIMPLEMENTED.withDescription("not found")); - assertThat(childBalancers).isEmpty(); - verify(helper, never()).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), any(SubchannelPicker.class)); // wait for DNS - Status dnsError = Status.UNKNOWN.withDescription("I am lost"); - resolver.deliverError(dnsError); - verify(helper).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - assertPicker( - pickerCaptor.getValue(), - Status.UNAVAILABLE.withDescription(dnsError.getDescription()), - null); - } - - @Test - public void resolutionErrorBeforeChildLbCreated_edsOnly_returnErrorPicker() { - ClusterResolverConfig config = new ClusterResolverConfig( - Arrays.asList(edsDiscoveryMechanism1), roundRobin); - deliverLbConfig(config); - assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1); - assertThat(childBalancers).isEmpty(); - reset(helper); - xdsClient.deliverError(Status.RESOURCE_EXHAUSTED.withDescription("OOM")); - assertThat(childBalancers).isEmpty(); - verify(helper).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - PickResult result = pickerCaptor.getValue().pickSubchannel(mock(PickSubchannelArgs.class)); - Status actualStatus = result.getStatus(); - assertThat(actualStatus.getCode()).isEqualTo(Status.Code.UNAVAILABLE); - assertThat(actualStatus.getDescription()).contains("RESOURCE_EXHAUSTED: OOM"); + assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); + PriorityLbConfig priorityLbConfig = (PriorityLbConfig) childBalancer.config; + PriorityChildConfig priorityChildConfig = + Iterables.getOnlyElement(priorityLbConfig.childConfigs.values()); + OutlierDetectionLoadBalancerConfig outlier = (OutlierDetectionLoadBalancerConfig) + GracefulSwitchLoadBalancerAccessor.getChildConfig(priorityChildConfig.childConfig); + assertThat(outlier.intervalNanos).isEqualTo(101); + assertThat(outlier.baseEjectionTimeNanos).isEqualTo(102); + assertThat(outlier.maxEjectionTimeNanos).isEqualTo(103); + assertThat(outlier.maxEjectionPercent).isEqualTo(80); + assertThat(outlier.successRateEjection.stdevFactor).isEqualTo(105); + assertThat(outlier.successRateEjection.enforcementPercentage).isEqualTo(81); + assertThat(outlier.successRateEjection.minimumHosts).isEqualTo(107); + assertThat(outlier.successRateEjection.requestVolume).isEqualTo(108); + assertThat(outlier.failurePercentageEjection.threshold).isEqualTo(82); + assertThat(outlier.failurePercentageEjection.enforcementPercentage).isEqualTo(83); + assertThat(outlier.failurePercentageEjection.minimumHosts).isEqualTo(111); + assertThat(outlier.failurePercentageEjection.requestVolume).isEqualTo(112); + assertClusterImplConfig( + (ClusterImplConfig) GracefulSwitchLoadBalancerAccessor.getChildConfig(outlier.childConfig), + CLUSTER, EDS_SERVICE_NAME, null, null, null, Collections.emptyList(), + "wrr_locality_experimental"); } - @Test - public void handleNameResolutionErrorFromUpstream_beforeChildLbCreated_returnErrorPicker() { - ClusterResolverConfig config = new ClusterResolverConfig( - Arrays.asList(edsDiscoveryMechanism1, logicalDnsDiscoveryMechanism), roundRobin); - deliverLbConfig(config); - assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1); - assertResolverCreated("/" + DNS_HOST_NAME); - assertThat(childBalancers).isEmpty(); - reset(helper); - Status upstreamError = Status.UNAVAILABLE.withDescription("unreachable"); - loadBalancer.handleNameResolutionError(upstreamError); - verify(helper).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - assertPicker(pickerCaptor.getValue(), upstreamError, null); + private void startXdsDepManager() { + startXdsDepManager(new CdsConfig(CLUSTER)); } - @Test - public void handleNameResolutionErrorFromUpstream_afterChildLbCreated_fallThrough() { - ClusterResolverConfig config = new ClusterResolverConfig( - Arrays.asList(edsDiscoveryMechanism1, logicalDnsDiscoveryMechanism), roundRobin); - deliverLbConfig(config); - assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1); - FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME); - assertThat(childBalancers).isEmpty(); - reset(helper); - EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr-1"); - EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr-2"); - LocalityLbEndpoints localityLbEndpoints = - LocalityLbEndpoints.create( - Collections.singletonList(LbEndpoint.create(endpoint1, 100, true)), - 10 /* localityWeight */, 1 /* priority */); - xdsClient.deliverClusterLoadAssignment( - EDS_SERVICE_NAME1, Collections.singletonMap(locality1, localityLbEndpoints)); - resolver.deliverEndpointAddresses(Collections.singletonList(endpoint2)); - assertThat(childBalancers).hasSize(1); - FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - assertThat(((PriorityLbConfig) childBalancer.config).priorities) - .containsExactly(CLUSTER1 + "[child1]", CLUSTER_DNS + "[child0]"); - assertAddressesEqual(Arrays.asList(endpoint1, endpoint2), childBalancer.addresses); - - loadBalancer.handleNameResolutionError(Status.UNAVAILABLE.withDescription("unreachable")); - assertThat(childBalancer.upstreamError.getCode()).isEqualTo(Code.UNAVAILABLE); - assertThat(childBalancer.upstreamError.getDescription()).isEqualTo("unreachable"); - verify(helper, never()).updateBalancingState( - any(ConnectivityState.class), any(SubchannelPicker.class)); + private void startXdsDepManager(final CdsConfig cdsConfig) { + startXdsDepManager(cdsConfig, true); } - private void deliverLbConfig(ClusterResolverConfig config) { - loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(Collections.emptyList()) - .setAttributes( - // Other attributes not used by cluster_resolver LB are omitted. - Attributes.newBuilder() - .set(InternalXdsAttributes.XDS_CLIENT_POOL, xdsClientPool) - .build()) - .setLoadBalancingPolicyConfig(config) - .build()); + private void startXdsDepManager(final CdsConfig cdsConfig, boolean forwardTime) { + xdsDepManager.start( + xdsConfig -> { + if (!xdsConfig.hasValue()) { + throw new AssertionError("" + xdsConfig.getStatus()); + } + if (loadBalancer == null) { + return; + } + loadBalancer.acceptResolvedAddresses(ResolvedAddresses.newBuilder() + .setAddresses(Collections.emptyList()) + .setAttributes(Attributes.newBuilder() + .set(io.grpc.xds.XdsAttributes.XDS_CONFIG, xdsConfig.getValue()) + .set(io.grpc.xds.XdsAttributes.XDS_CLUSTER_SUBSCRIPT_REGISTRY, xdsDepManager) + .build()) + .setLoadBalancingPolicyConfig(cdsConfig) + .build()); + }); + if (forwardTime) { + // trigger does not exist timer, so broken config is more obvious + fakeClock.forwardTime(10, TimeUnit.MINUTES); + } } private FakeNameResolver assertResolverCreated(String uriPath) { @@ -1136,101 +1083,42 @@ private static void assertClusterImplConfig(ClusterImplConfig config, String clu assertThat(config.maxConcurrentRequests).isEqualTo(maxConcurrentRequests); assertThat(config.tlsContext).isEqualTo(tlsContext); assertThat(config.dropCategories).isEqualTo(dropCategories); - assertThat(config.childPolicy.getProvider().getPolicyName()).isEqualTo(childPolicy); + assertThat( + GracefulSwitchLoadBalancerAccessor.getChildProvider(config.childConfig).getPolicyName()) + .isEqualTo(childPolicy); } /** Asserts two list of EAGs contains same addresses, regardless of attributes. */ private static void assertAddressesEqual( List expected, List actual) { - assertThat(actual.size()).isEqualTo(expected.size()); - for (int i = 0; i < actual.size(); i++) { - assertThat(actual.get(i).getAddresses()).isEqualTo(expected.get(i).getAddresses()); - } + List> expectedAddresses + = expected.stream().map(EquivalentAddressGroup::getAddresses).collect(toList()); + List> actualAddresses + = actual.stream().map(EquivalentAddressGroup::getAddresses).collect(toList()); + assertThat(actualAddresses).isEqualTo(expectedAddresses); } - private static EquivalentAddressGroup makeAddress(final String name) { - class FakeSocketAddress extends SocketAddress { - private final String name; - - private FakeSocketAddress(String name) { - this.name = name; - } - - @Override - public int hashCode() { - return Objects.hash(name); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof FakeSocketAddress)) { - return false; - } - FakeSocketAddress that = (FakeSocketAddress) o; - return Objects.equals(name, that.name); - } - - @Override - public String toString() { - return name; - } - } - - return new EquivalentAddressGroup(new FakeSocketAddress(name)); + @SuppressWarnings("AddressSelection") + private static InetSocketAddress newInetSocketAddress(String ip, int port) { + return new InetSocketAddress(ip, port); } - private static final class FakeXdsClient extends XdsClient { - private final Map> watchers = new HashMap<>(); - - @Override - @SuppressWarnings("unchecked") - public void watchXdsResource(XdsResourceType type, - String resourceName, - ResourceWatcher watcher, - Executor syncContext) { - assertThat(type.typeName()).isEqualTo("EDS"); - assertThat(watchers).doesNotContainKey(resourceName); - watchers.put(resourceName, (ResourceWatcher) watcher); - } - - @Override - @SuppressWarnings("unchecked") - public void cancelXdsResourceWatch(XdsResourceType type, - String resourceName, - ResourceWatcher watcher) { - assertThat(type.typeName()).isEqualTo("EDS"); - assertThat(watchers).containsKey(resourceName); - watchers.remove(resourceName); - } - - void deliverClusterLoadAssignment( - String resource, Map localityLbEndpointsMap) { - deliverClusterLoadAssignment( - resource, Collections.emptyList(), localityLbEndpointsMap); - } - - void deliverClusterLoadAssignment(String resource, List dropOverloads, - Map localityLbEndpointsMap) { - if (watchers.containsKey(resource)) { - watchers.get(resource).onChanged( - new XdsEndpointResource.EdsUpdate(resource, localityLbEndpointsMap, dropOverloads)); - } - } + private static EquivalentAddressGroup newInetSocketAddressEag(String ip, int port) { + return new EquivalentAddressGroup(newInetSocketAddress(ip, port)); + } - void deliverResourceNotFound(String resource) { - if (watchers.containsKey(resource)) { - watchers.get(resource).onResourceDoesNotExist(resource); - } - } + private static LbEndpoint.Builder newSocketLbEndpoint(String ip, int port) { + return LbEndpoint.newBuilder() + .setEndpoint(Endpoint.newBuilder() + .setAddress(newAddress(ip, port))) + .setHealthStatus(HealthStatus.HEALTHY); + } - void deliverError(Status error) { - for (ResourceWatcher watcher : watchers.values()) { - watcher.onError(error); - } - } + private static Address.Builder newAddress(String ip, int port) { + return Address.newBuilder() + .setSocketAddress(SocketAddress.newBuilder() + .setAddress(ip) + .setPortValue(port)); } private class FakeNameResolverProvider extends NameResolverProvider { @@ -1258,9 +1146,10 @@ protected int priority() { } } + private class FakeNameResolver extends NameResolver { private final URI targetUri; - private Listener2 listener; + protected Listener2 listener; private int refreshCount; private FakeNameResolver(URI targetUri) { @@ -1287,12 +1176,17 @@ public void shutdown() { resolvers.remove(this); } - private void deliverEndpointAddresses(List addresses) { - listener.onResult(ResolutionResult.newBuilder().setAddresses(addresses).build()); + protected void deliverEndpointAddresses(List addresses) { + syncContext.execute(() -> { + Status ret = listener.onResult2(ResolutionResult.newBuilder() + .setAddressesOrError(StatusOr.fromValue(addresses)).build()); + assertThat(ret.getCode()).isEqualTo(Status.Code.OK); + }); } - private void deliverError(Status error) { - listener.onError(error); + protected void deliverError(Status error) { + syncContext.execute(() -> listener.onResult2(ResolutionResult.newBuilder() + .setAddressesOrError(StatusOr.fromStatus(error)).build())); } } @@ -1331,7 +1225,6 @@ private final class FakeLoadBalancer extends LoadBalancer { private final Helper helper; private List addresses; private Object config; - private Status upstreamError; private boolean shutdown; FakeLoadBalancer(String name, Helper helper) { @@ -1348,7 +1241,6 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { @Override public void handleNameResolutionError(Status error) { - upstreamError = error; } @Override diff --git a/xds/src/test/java/io/grpc/xds/ControlPlaneRule.java b/xds/src/test/java/io/grpc/xds/ControlPlaneRule.java index 1ddf9620434..3665e16b6bf 100644 --- a/xds/src/test/java/io/grpc/xds/ControlPlaneRule.java +++ b/xds/src/test/java/io/grpc/xds/ControlPlaneRule.java @@ -22,7 +22,9 @@ import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_RDS; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; import com.google.protobuf.Any; +import com.google.protobuf.BoolValue; import com.google.protobuf.Message; import com.google.protobuf.UInt32Value; import io.envoyproxy.envoy.config.cluster.v3.Cluster; @@ -55,6 +57,7 @@ import io.grpc.InsecureServerCredentials; import io.grpc.NameResolverRegistry; import io.grpc.Server; +import java.io.IOException; import java.util.Collections; import java.util.Map; import java.util.UUID; @@ -86,9 +89,11 @@ public class ControlPlaneRule extends TestWatcher { private XdsTestControlPlaneService controlPlaneService; private XdsTestLoadReportingService loadReportingService; private XdsNameResolverProvider nameResolverProvider; + private int port; // Only change from 0 to actual port used in the server. public ControlPlaneRule() { serverHostName = "test-server"; + this.port = 0; } public ControlPlaneRule setServerHostName(String serverHostName) { @@ -115,11 +120,7 @@ public Server getServer() { try { controlPlaneService = new XdsTestControlPlaneService(); loadReportingService = new XdsTestLoadReportingService(); - server = Grpc.newServerBuilderForPort(0, InsecureServerCredentials.create()) - .addService(controlPlaneService) - .addService(loadReportingService) - .build() - .start(); + createAndStartXdsServer(); } catch (Exception e) { throw new AssertionError("unable to start the control plane server", e); } @@ -144,6 +145,42 @@ public Server getServer() { NameResolverRegistry.getDefaultRegistry().deregister(nameResolverProvider); } + /** + * Will shutdown existing server if needed. + * Then creates a new server in the same way as {@link #starting(Description)} and starts it. + */ + public void restartXdsServer() { + + if (getServer() != null && !getServer().isTerminated()) { + getServer().shutdownNow(); + try { + if (!getServer().awaitTermination(5, TimeUnit.SECONDS)) { + logger.log(Level.SEVERE, "Timed out waiting for server shutdown"); + } + } catch (InterruptedException e) { + throw new AssertionError("unable to shut down control plane server", e); + } + } + + try { + createAndStartXdsServer(); + } catch (Exception e) { + throw new AssertionError("unable to restart the control plane server", e); + } + } + + private void createAndStartXdsServer() throws IOException { + server = Grpc.newServerBuilderForPort(port, InsecureServerCredentials.create()) + .addService(controlPlaneService) + .addService(loadReportingService) + .build() + .start(); + + if (port == 0) { + port = server.getPort(); + } + } + /** * For test purpose, use boostrapOverride to programmatically provide bootstrap info. */ @@ -159,7 +196,7 @@ public Server getServer() { "channel_creds", Collections.singletonList( ImmutableMap.of("type", "insecure") ), - "server_features", Collections.singletonList("xds_v3") + "server_features", Lists.newArrayList("xds_v3", "trusted_xds_server") ) ), "server_listener_resource_name_template", SERVER_LISTENER_TEMPLATE_NO_REPLACEMENT @@ -173,44 +210,70 @@ void setLdsConfig(Listener serverListener, Listener clientListener) { } void setRdsConfig(RouteConfiguration routeConfiguration) { - getService().setXdsConfig(ADS_TYPE_URL_RDS, ImmutableMap.of(RDS_NAME, routeConfiguration)); + setRdsConfig(RDS_NAME, routeConfiguration); + } + + public void setRdsConfig(String rdsName, RouteConfiguration routeConfiguration) { + getService().setXdsConfig(ADS_TYPE_URL_RDS, ImmutableMap.of(rdsName, routeConfiguration)); } void setCdsConfig(Cluster cluster) { + setCdsConfig(CLUSTER_NAME, cluster); + } + + void setCdsConfig(String clusterName, Cluster cluster) { getService().setXdsConfig(ADS_TYPE_URL_CDS, - ImmutableMap.of(CLUSTER_NAME, cluster)); + ImmutableMap.of(clusterName, cluster)); } void setEdsConfig(ClusterLoadAssignment clusterLoadAssignment) { + setEdsConfig(EDS_NAME, clusterLoadAssignment); + } + + void setEdsConfig(String edsName, ClusterLoadAssignment clusterLoadAssignment) { getService().setXdsConfig(ADS_TYPE_URL_EDS, - ImmutableMap.of(EDS_NAME, clusterLoadAssignment)); + ImmutableMap.of(edsName, clusterLoadAssignment)); } /** * Builds a new default RDS configuration. */ static RouteConfiguration buildRouteConfiguration(String authority) { - io.envoyproxy.envoy.config.route.v3.VirtualHost virtualHost = VirtualHost.newBuilder() - .addDomains(authority) - .addRoutes( - Route.newBuilder() - .setMatch( - RouteMatch.newBuilder().setPrefix("/").build()) - .setRoute( - RouteAction.newBuilder().setCluster(CLUSTER_NAME).build()).build()).build(); - return RouteConfiguration.newBuilder().setName(RDS_NAME).addVirtualHosts(virtualHost).build(); + return buildRouteConfiguration(authority, RDS_NAME, CLUSTER_NAME); + } + + static RouteConfiguration buildRouteConfiguration(String authority, String rdsName, + String clusterName) { + io.envoyproxy.envoy.config.route.v3.VirtualHost.Builder vhBuilder = + io.envoyproxy.envoy.config.route.v3.VirtualHost.newBuilder() + .setName(rdsName) + .addDomains(authority) + .addRoutes( + Route.newBuilder() + .setMatch( + RouteMatch.newBuilder().setPrefix("/").build()) + .setRoute( + RouteAction.newBuilder().setCluster(clusterName) + .setAutoHostRewrite(BoolValue.newBuilder().setValue(true).build()) + .build())); + io.envoyproxy.envoy.config.route.v3.VirtualHost virtualHost = vhBuilder.build(); + return RouteConfiguration.newBuilder().setName(rdsName).addVirtualHosts(virtualHost).build(); } /** * Builds a new default CDS configuration. */ static Cluster buildCluster() { + return buildCluster(CLUSTER_NAME, EDS_NAME); + } + + static Cluster buildCluster(String clusterName, String edsName) { return Cluster.newBuilder() - .setName(CLUSTER_NAME) + .setName(clusterName) .setType(Cluster.DiscoveryType.EDS) .setEdsClusterConfig( Cluster.EdsClusterConfig.newBuilder() - .setServiceName(EDS_NAME) + .setServiceName(edsName) .setEdsConfig( ConfigSource.newBuilder() .setAds(AggregatedConfigSource.newBuilder().build()) @@ -223,21 +286,29 @@ static Cluster buildCluster() { /** * Builds a new default EDS configuration. */ - static ClusterLoadAssignment buildClusterLoadAssignment(String hostName, int port) { + static ClusterLoadAssignment buildClusterLoadAssignment( + String hostAddress, String endpointHostname, int port) { + return buildClusterLoadAssignment(hostAddress, endpointHostname, port, EDS_NAME); + } + + static ClusterLoadAssignment buildClusterLoadAssignment( + String hostAddress, String endpointHostname, int port, String edsName) { + Address address = Address.newBuilder() .setSocketAddress( - SocketAddress.newBuilder().setAddress(hostName).setPortValue(port).build()).build(); + SocketAddress.newBuilder().setAddress(hostAddress).setPortValue(port).build()).build(); LocalityLbEndpoints endpoints = LocalityLbEndpoints.newBuilder() .setLoadBalancingWeight(UInt32Value.of(10)) .setPriority(0) .addLbEndpoints( LbEndpoint.newBuilder() .setEndpoint( - Endpoint.newBuilder().setAddress(address).build()) + Endpoint.newBuilder() + .setAddress(address).setHostname(endpointHostname).build()) .setHealthStatus(HealthStatus.HEALTHY) .build()).build(); return ClusterLoadAssignment.newBuilder() - .setClusterName(EDS_NAME) + .setClusterName(edsName) .addEndpoints(endpoints) .build(); } @@ -246,6 +317,10 @@ static ClusterLoadAssignment buildClusterLoadAssignment(String hostName, int por * Builds a new client listener. */ static Listener buildClientListener(String name) { + return buildClientListener(name, RDS_NAME); + } + + static Listener buildClientListener(String name, String rdsName) { HttpFilter httpFilter = HttpFilter.newBuilder() .setName("terminal-filter") .setTypedConfig(Any.pack(Router.newBuilder().build())) @@ -256,7 +331,7 @@ static Listener buildClientListener(String name) { .HttpConnectionManager.newBuilder() .setRds( Rds.newBuilder() - .setRouteConfigName(RDS_NAME) + .setRouteConfigName(rdsName) .setConfigSource( ConfigSource.newBuilder() .setAds(AggregatedConfigSource.getDefaultInstance()))) @@ -306,10 +381,14 @@ static Listener buildServerListener() { .setFilterChainMatch(filterChainMatch) .addFilters(filter) .build(); + Address address = Address.newBuilder() + .setSocketAddress(SocketAddress.newBuilder().setAddress("0.0.0.0").setPortValue(0)) + .build(); return Listener.newBuilder() .setName(SERVER_LISTENER_TEMPLATE_NO_REPLACEMENT) .setTrafficDirection(TrafficDirection.INBOUND) .addFilterChains(filterChain) + .setAddress(address) .build(); } } diff --git a/xds/src/test/java/io/grpc/xds/CsdsServiceTest.java b/xds/src/test/java/io/grpc/xds/CsdsServiceTest.java index 0ac024d1e48..e8bd7461736 100644 --- a/xds/src/test/java/io/grpc/xds/CsdsServiceTest.java +++ b/xds/src/test/java/io/grpc/xds/CsdsServiceTest.java @@ -39,6 +39,7 @@ import io.envoyproxy.envoy.type.matcher.v3.NodeMatcher; import io.grpc.Deadline; import io.grpc.InsecureChannelCredentials; +import io.grpc.MetricRecorder; import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.StatusRuntimeException; @@ -54,6 +55,7 @@ import io.grpc.xds.client.XdsClient.ResourceMetadata; import io.grpc.xds.client.XdsClient.ResourceMetadata.ResourceMetadataStatus; import io.grpc.xds.client.XdsResourceType; +import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.List; @@ -85,6 +87,7 @@ public class CsdsServiceTest { private static final XdsResourceType CDS = XdsClusterResource.getInstance(); private static final XdsResourceType RDS = XdsRouteConfigureResource.getInstance(); private static final XdsResourceType EDS = XdsEndpointResource.getInstance(); + public static final String FAKE_CLIENT_SCOPE = "fake"; @RunWith(JUnit4.class) public static class ServiceTests { @@ -105,7 +108,7 @@ public void setUp() { // because true->false return mutation prevents fetchClientStatus from completing the request. csdsStub = ClientStatusDiscoveryServiceGrpc .newBlockingStub(grpcServerRule.getChannel()) - .withDeadline(Deadline.after(3, TimeUnit.SECONDS)); + .withDeadline(Deadline.after(30, TimeUnit.SECONDS)); csdsAsyncStub = ClientStatusDiscoveryServiceGrpc.newStub(grpcServerRule.getChannel()); } @@ -169,6 +172,9 @@ public void fetchClientConfig_interruptedException() { grpcServerRule.getServiceRegistry() .addService(new CsdsService(new FakeXdsClientPoolFactory(throwingXdsClient))); + // Hack to prevent the interrupted exception from propagating through to the client stub. + grpcServerRule.getChannel().getState(true); + try { ClientStatusResponse response = csdsStub.fetchClientStatus(REQUEST); fail("Should've failed, got response: " + response); @@ -195,13 +201,13 @@ public void streamClientStatus_happyPath() { @Override @Nullable - public ObjectPool get() { + public ObjectPool get(String target) { // xDS client not ready on the first call, then becomes ready. if (!calledOnce) { calledOnce = true; return null; } else { - return super.get(); + return super.get(target); } } }); @@ -264,11 +270,51 @@ public void streamClientStatus_onClientError() { assertThat(responseObserver.getError()).isNull(); } + @Test + public void multipleXdsClients() { + FakeXdsClient xdsClient1 = new FakeXdsClient(); + FakeXdsClient xdsClient2 = new FakeXdsClient(); + Map clientMap = new HashMap<>(); + clientMap.put("target1", xdsClient1); + clientMap.put("target2", xdsClient2); + FakeXdsClientPoolFactory factory = new FakeXdsClientPoolFactory(clientMap); + CsdsService csdsService = new CsdsService(factory); + grpcServerRule.getServiceRegistry().addService(csdsService); + + StreamRecorder responseObserver = StreamRecorder.create(); + StreamObserver requestObserver = + csdsAsyncStub.streamClientStatus(responseObserver); + + requestObserver.onNext(REQUEST); + requestObserver.onCompleted(); + + List responses = responseObserver.getValues(); + assertThat(responses).hasSize(1); + Collection targets = verifyMultiResponse(responses.get(0), 2); + assertThat(targets).containsExactly("target1", "target2"); + responseObserver.onCompleted(); + } + private void verifyResponse(ClientStatusResponse response) { assertThat(response.getConfigCount()).isEqualTo(1); ClientConfig clientConfig = response.getConfig(0); verifyClientConfigNode(clientConfig); - verifyClientConfigNoResources(XDS_CLIENT_NO_RESOURCES, clientConfig); + assertThat(clientConfig.getGenericXdsConfigsList()).isEmpty(); + assertThat(clientConfig.getClientScope()).isEmpty(); + } + + private Collection verifyMultiResponse(ClientStatusResponse response, int numExpected) { + assertThat(response.getConfigCount()).isEqualTo(numExpected); + + List clientScopes = new ArrayList<>(); + for (int i = 0; i < numExpected; i++) { + ClientConfig clientConfig = response.getConfig(i); + verifyClientConfigNode(clientConfig); + assertThat(clientConfig.getGenericXdsConfigsList()).isEmpty(); + clientScopes.add(clientConfig.getClientScope()); + } + + return clientScopes; } private void verifyRequestInvalidResponseStatus(Status status) { @@ -320,6 +366,8 @@ public void metadataStatusToClientStatus() { .isEqualTo(ClientResourceStatus.ACKED); assertThat(CsdsService.metadataStatusToClientStatus(ResourceMetadataStatus.NACKED)) .isEqualTo(ClientResourceStatus.NACKED); + assertThat(CsdsService.metadataStatusToClientStatus(ResourceMetadataStatus.TIMEOUT)) + .isEqualTo(ClientResourceStatus.TIMEOUT); } @Test @@ -336,50 +384,42 @@ public void getClientConfigForXdsClient_subscribedResourcesToGenericXdsConfig() .put(EDS, ImmutableMap.of("subscribedResourceName.EDS", METADATA_ACKED_EDS)) .buildOrThrow(); } - - @Override - public Map> getSubscribedResourceTypesWithTypeUrl() { - return ImmutableMap.of( - LDS.typeUrl(), LDS, - RDS.typeUrl(), RDS, - CDS.typeUrl(), CDS, - EDS.typeUrl(), EDS - ); - } }; - ClientConfig clientConfig = CsdsService.getClientConfigForXdsClient(fakeXdsClient); + ClientConfig clientConfig = CsdsService.getClientConfigForXdsClient(fakeXdsClient, + FAKE_CLIENT_SCOPE); verifyClientConfigNode(clientConfig); + assertThat(clientConfig.getClientScope()).isEqualTo(FAKE_CLIENT_SCOPE); // Minimal verification to confirm that the data/metadata XdsClient provides, // is propagated to the correct resource types. int xdsConfigCount = clientConfig.getGenericXdsConfigsCount(); assertThat(xdsConfigCount).isEqualTo(4); - Map, GenericXdsConfig> configDumps = mapConfigDumps(fakeXdsClient, - clientConfig); - assertThat(configDumps.keySet()).containsExactly(LDS, RDS, CDS, EDS); + Map configDumps = mapConfigDumps(clientConfig); + assertThat(configDumps.keySet()) + .containsExactly(LDS.typeUrl(), RDS.typeUrl(), CDS.typeUrl(), EDS.typeUrl()); // LDS. - GenericXdsConfig genericXdsConfigLds = configDumps.get(LDS); + GenericXdsConfig genericXdsConfigLds = configDumps.get(LDS.typeUrl()); assertThat(genericXdsConfigLds.getName()).isEqualTo("subscribedResourceName.LDS"); assertThat(genericXdsConfigLds.getClientStatus()).isEqualTo(ClientResourceStatus.ACKED); assertThat(genericXdsConfigLds.getVersionInfo()).isEqualTo(VERSION_ACK_LDS); assertThat(genericXdsConfigLds.getXdsConfig()).isEqualTo(RAW_LISTENER); // RDS. - GenericXdsConfig genericXdsConfigRds = configDumps.get(RDS); + GenericXdsConfig genericXdsConfigRds = configDumps.get(RDS.typeUrl()); assertThat(genericXdsConfigRds.getClientStatus()).isEqualTo(ClientResourceStatus.ACKED); assertThat(genericXdsConfigRds.getVersionInfo()).isEqualTo(VERSION_ACK_RDS); assertThat(genericXdsConfigRds.getXdsConfig()).isEqualTo(RAW_ROUTE_CONFIGURATION); // CDS. - GenericXdsConfig genericXdsConfigCds = configDumps.get(CDS); + GenericXdsConfig genericXdsConfigCds = configDumps.get(CDS.typeUrl()); assertThat(genericXdsConfigCds.getClientStatus()).isEqualTo(ClientResourceStatus.ACKED); assertThat(genericXdsConfigCds.getVersionInfo()).isEqualTo(VERSION_ACK_CDS); assertThat(genericXdsConfigCds.getXdsConfig()).isEqualTo(RAW_CLUSTER); // RDS. - GenericXdsConfig genericXdsConfigEds = configDumps.get(EDS); + GenericXdsConfig genericXdsConfigEds = configDumps.get(EDS.typeUrl()); assertThat(genericXdsConfigEds.getClientStatus()).isEqualTo(ClientResourceStatus.ACKED); assertThat(genericXdsConfigEds.getVersionInfo()).isEqualTo(VERSION_ACK_EDS); assertThat(genericXdsConfigEds.getXdsConfig()).isEqualTo(RAW_CLUSTER_LOAD_ASSIGNMENT); @@ -387,24 +427,14 @@ public Map> getSubscribedResourceTypesWithTypeUrl() { @Test public void getClientConfigForXdsClient_noSubscribedResources() throws InterruptedException { - ClientConfig clientConfig = CsdsService.getClientConfigForXdsClient(XDS_CLIENT_NO_RESOURCES); + ClientConfig clientConfig = + CsdsService.getClientConfigForXdsClient(XDS_CLIENT_NO_RESOURCES, FAKE_CLIENT_SCOPE); verifyClientConfigNode(clientConfig); - verifyClientConfigNoResources(XDS_CLIENT_NO_RESOURCES, clientConfig); + assertThat(clientConfig.getGenericXdsConfigsList()).isEmpty(); + assertThat(clientConfig.getClientScope()).isEqualTo(FAKE_CLIENT_SCOPE); } } - /** - * Assuming {@link MetadataToProtoTests} passes, and metadata converted to corresponding - * config dumps correctly, perform a minimal verification of the general shape of ClientConfig. - */ - private static void verifyClientConfigNoResources(FakeXdsClient xdsClient, - ClientConfig clientConfig) { - int xdsConfigCount = clientConfig.getGenericXdsConfigsCount(); - assertThat(xdsConfigCount).isEqualTo(0); - Map, GenericXdsConfig> configDumps = mapConfigDumps(xdsClient, clientConfig); - assertThat(configDumps).isEmpty(); - } - /** * Assuming {@link EnvoyProtoDataTest#convertNode} passes, perform a minimal check, * just verify the node itself is the one we expect. @@ -415,21 +445,17 @@ private static void verifyClientConfigNode(ClientConfig clientConfig) { assertThat(node).isEqualTo(BOOTSTRAP_NODE.toEnvoyProtoNode()); } - private static Map, GenericXdsConfig> mapConfigDumps(FakeXdsClient client, - ClientConfig config) { - Map, GenericXdsConfig> xdsConfigMap = new HashMap<>(); + private static Map mapConfigDumps(ClientConfig config) { + Map xdsConfigMap = new HashMap<>(); List xdsConfigList = config.getGenericXdsConfigsList(); for (GenericXdsConfig genericXdsConfig : xdsConfigList) { - XdsResourceType type = client.getSubscribedResourceTypesWithTypeUrl() - .get(genericXdsConfig.getTypeUrl()); - assertThat(type).isNotNull(); - assertThat(xdsConfigMap).doesNotContainKey(type); - xdsConfigMap.put(type, genericXdsConfig); + assertThat(xdsConfigMap).doesNotContainKey(genericXdsConfig.getTypeUrl()); + xdsConfigMap.put(genericXdsConfig.getTypeUrl(), genericXdsConfig); } return xdsConfigMap; } - private static class FakeXdsClient extends XdsClient implements XdsClient.ResourceStore { + private static class FakeXdsClient extends XdsClient { protected Map, Map> getSubscribedResourcesMetadata() { return ImmutableMap.of(); @@ -445,34 +471,33 @@ private static class FakeXdsClient extends XdsClient implements XdsClient.Resour public BootstrapInfo getBootstrapInfo() { return BOOTSTRAP_INFO; } - - @Nullable - @Override - public Collection getSubscribedResources(ServerInfo serverInfo, - XdsResourceType type) { - return null; - } - - @Override - public Map> getSubscribedResourceTypesWithTypeUrl() { - return ImmutableMap.of(); - } } private static class FakeXdsClientPoolFactory implements XdsClientPoolFactory { - @Nullable private final XdsClient xdsClient; + private final Map xdsClientMap = new HashMap<>(); + private boolean isOldStyle; private FakeXdsClientPoolFactory(@Nullable XdsClient xdsClient) { - this.xdsClient = xdsClient; + if (xdsClient != null) { + xdsClientMap.put("", xdsClient); + } + isOldStyle = true; + } + + private FakeXdsClientPoolFactory(Map xdsClientMap) { + this.xdsClientMap.putAll(xdsClientMap); + isOldStyle = false; } @Override @Nullable - public ObjectPool get() { + public ObjectPool get(String target) { + String targetToUse = isOldStyle ? "" : target; + return new ObjectPool() { @Override public XdsClient getObject() { - return xdsClient; + return xdsClientMap.get(targetToUse); } @Override @@ -483,12 +508,13 @@ public XdsClient returnObject(Object object) { } @Override - public void setBootstrapOverride(Map bootstrap) { - throw new UnsupportedOperationException("Should not be called"); + public List getTargets() { + return new ArrayList<>(xdsClientMap.keySet()); } @Override - public ObjectPool getOrCreate() { + public ObjectPool getOrCreate( + String target, BootstrapInfo bootstrapInfo, MetricRecorder metricRecorder) { throw new UnsupportedOperationException("Should not be called"); } } diff --git a/xds/src/test/java/io/grpc/xds/DataPlaneRule.java b/xds/src/test/java/io/grpc/xds/DataPlaneRule.java index faa79444071..b308419d142 100644 --- a/xds/src/test/java/io/grpc/xds/DataPlaneRule.java +++ b/xds/src/test/java/io/grpc/xds/DataPlaneRule.java @@ -48,6 +48,7 @@ public class DataPlaneRule extends TestWatcher { private static final Logger logger = Logger.getLogger(DataPlaneRule.class.getName()); private static final String SERVER_HOST_NAME = "test-server"; + static final String ENDPOINT_HOST_NAME = "endpoint-host-name"; private static final String SCHEME = "test-xds"; private final ControlPlaneRule controlPlane; @@ -73,7 +74,8 @@ public Server getServer() { */ public ManagedChannel getManagedChannel() { ManagedChannel channel = Grpc.newChannelBuilder(SCHEME + ":///" + SERVER_HOST_NAME, - InsecureChannelCredentials.create()).build(); + InsecureChannelCredentials.create()) + .build(); channels.add(channel); return channel; } @@ -98,7 +100,7 @@ protected void starting(Description description) { InetSocketAddress edsInetSocketAddress = (InetSocketAddress) server.getListenSockets().get(0); controlPlane.setEdsConfig( ControlPlaneRule.buildClusterLoadAssignment(edsInetSocketAddress.getHostName(), - edsInetSocketAddress.getPort())); + ENDPOINT_HOST_NAME, edsInetSocketAddress.getPort())); } @Override @@ -124,10 +126,12 @@ protected void finished(Description description) { } private void startServer(Map bootstrapOverride) throws Exception { + final String[] authority = new String[1]; ServerInterceptor metadataInterceptor = new ServerInterceptor() { @Override public ServerCall.Listener interceptCall(ServerCall call, Metadata requestHeaders, ServerCallHandler next) { + authority[0] = call.getAuthority(); logger.fine("Received following metadata: " + requestHeaders); // Make a copy of the headers so that it can be read in a thread-safe manner when copying @@ -155,8 +159,12 @@ public void close(Status status, Metadata trailers) { @Override public void unaryRpc( SimpleRequest request, StreamObserver responseObserver) { + String responseMsg = "Hi, xDS!"; + if (authority[0] != null) { + responseMsg += " Authority= " + authority[0]; + } SimpleResponse response = - SimpleResponse.newBuilder().setResponseMessage("Hi, xDS!").build(); + SimpleResponse.newBuilder().setResponseMessage(responseMsg).build(); responseObserver.onNext(response); responseObserver.onCompleted(); } diff --git a/xds/src/test/java/io/grpc/xds/ExtAuthzConfigParserTest.java b/xds/src/test/java/io/grpc/xds/ExtAuthzConfigParserTest.java new file mode 100644 index 00000000000..fa2718cbe63 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/ExtAuthzConfigParserTest.java @@ -0,0 +1,297 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; + +import com.google.protobuf.Any; +import com.google.protobuf.BoolValue; +import io.envoyproxy.envoy.config.common.mutation_rules.v3.HeaderMutationRules; +import io.envoyproxy.envoy.config.core.v3.GrpcService; +import io.envoyproxy.envoy.config.core.v3.HeaderValue; +import io.envoyproxy.envoy.config.core.v3.RuntimeFeatureFlag; +import io.envoyproxy.envoy.config.core.v3.RuntimeFractionalPercent; +import io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz; +import io.envoyproxy.envoy.extensions.grpc_service.call_credentials.access_token.v3.AccessTokenCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.google_default.v3.GoogleDefaultCredentials; +import io.envoyproxy.envoy.type.matcher.v3.ListStringMatcher; +import io.envoyproxy.envoy.type.matcher.v3.RegexMatcher; +import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; +import io.envoyproxy.envoy.type.v3.FractionalPercent; +import io.envoyproxy.envoy.type.v3.FractionalPercent.DenominatorType; +import io.grpc.Status; +import io.grpc.xds.client.Bootstrapper.BootstrapInfo; +import io.grpc.xds.client.Bootstrapper.ServerInfo; +import io.grpc.xds.client.EnvoyProtoData.Node; +import io.grpc.xds.internal.Matchers; +import io.grpc.xds.internal.extauthz.ExtAuthzConfig; +import io.grpc.xds.internal.extauthz.ExtAuthzParseException; +import io.grpc.xds.internal.headermutations.HeaderMutationRulesConfig; +import java.util.Collections; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ExtAuthzConfigParserTest { + + private static final Any GOOGLE_DEFAULT_CHANNEL_CREDS = + Any.pack(GoogleDefaultCredentials.newBuilder().build()); + private static final Any FAKE_ACCESS_TOKEN_CALL_CREDS = + Any.pack(AccessTokenCredentials.newBuilder().setToken("fake-token").build()); + + private static BootstrapInfo dummyBootstrapInfo() { + return BootstrapInfo.builder() + .servers( + Collections.singletonList(ServerInfo.create("test_target", Collections.emptyMap()))) + .node(Node.newBuilder().build()).build(); + } + + private static ServerInfo dummyServerInfo() { + return ServerInfo.create("test_target", Collections.emptyMap(), false, true, false, false); + } + + private ExtAuthz.Builder extAuthzBuilder; + + @Before + public void setUp() { + extAuthzBuilder = ExtAuthz.newBuilder() + .setGrpcService(GrpcService.newBuilder().setGoogleGrpc(GrpcService.GoogleGrpc.newBuilder() + .setTargetUri("test-cluster") + .addChannelCredentialsPlugin(GOOGLE_DEFAULT_CHANNEL_CREDS) + .addCallCredentialsPlugin(FAKE_ACCESS_TOKEN_CALL_CREDS).build()) + .build()); + } + + @Test + public void parse_missingGrpcService_throws() { + ExtAuthz extAuthz = ExtAuthz.newBuilder().build(); + try { + ExtAuthzConfigParser.parse(extAuthz, + dummyBootstrapInfo(), + dummyServerInfo()); + fail("Expected ExtAuthzParseException"); + } catch (ExtAuthzParseException e) { + assertThat(e).hasMessageThat() + .isEqualTo("unsupported ExtAuthz service type: only grpc_service is supported"); + } + } + + @Test + public void parse_invalidGrpcService_throws() { + ExtAuthz extAuthz = ExtAuthz.newBuilder() + .setGrpcService(GrpcService.newBuilder().build()) + .build(); + try { + ExtAuthzConfigParser.parse(extAuthz, + dummyBootstrapInfo(), + dummyServerInfo()); + fail("Expected ExtAuthzParseException"); + } catch (ExtAuthzParseException e) { + assertThat(e).hasMessageThat().startsWith("Failed to parse GrpcService config:"); + } + } + + @Test + public void parse_invalidAllowExpression_throws() { + ExtAuthz extAuthz = extAuthzBuilder + .setDecoderHeaderMutationRules(HeaderMutationRules.newBuilder() + .setAllowExpression(RegexMatcher.newBuilder().setRegex("[invalid").build()).build()) + .build(); + try { + ExtAuthzConfigParser.parse(extAuthz, + dummyBootstrapInfo(), + dummyServerInfo()); + fail("Expected ExtAuthzParseException"); + } catch (ExtAuthzParseException e) { + assertThat(e).hasMessageThat().startsWith("Invalid regex pattern for allow_expression:"); + } + } + + @Test + public void parse_invalidDisallowExpression_throws() { + ExtAuthz extAuthz = extAuthzBuilder + .setDecoderHeaderMutationRules(HeaderMutationRules.newBuilder() + .setDisallowExpression(RegexMatcher.newBuilder().setRegex("[invalid").build()).build()) + .build(); + try { + ExtAuthzConfigParser.parse(extAuthz, + dummyBootstrapInfo(), + dummyServerInfo()); + fail("Expected ExtAuthzParseException"); + } catch (ExtAuthzParseException e) { + assertThat(e).hasMessageThat().startsWith("Invalid regex pattern for disallow_expression:"); + } + } + + @Test + public void parse_success() throws ExtAuthzParseException { + ExtAuthz extAuthz = + extAuthzBuilder + .setGrpcService(extAuthzBuilder.getGrpcServiceBuilder() + .setTimeout(com.google.protobuf.Duration.newBuilder().setSeconds(5).build()) + .addInitialMetadata( + HeaderValue.newBuilder().setKey("key").setValue("value").build()) + .build()) + .setFailureModeAllow(true).setFailureModeAllowHeaderAdd(true) + .setIncludePeerCertificate(true) + .setStatusOnError( + io.envoyproxy.envoy.type.v3.HttpStatus.newBuilder().setCodeValue(403).build()) + .setDenyAtDisable( + RuntimeFeatureFlag.newBuilder().setDefaultValue(BoolValue.of(true)).build()) + .setFilterEnabled(RuntimeFractionalPercent.newBuilder() + .setDefaultValue(FractionalPercent.newBuilder().setNumerator(50) + .setDenominator(DenominatorType.TEN_THOUSAND).build()) + .build()) + .setAllowedHeaders(ListStringMatcher.newBuilder() + .addPatterns(StringMatcher.newBuilder().setExact("allowed-header").build()).build()) + .setDisallowedHeaders(ListStringMatcher.newBuilder() + .addPatterns(StringMatcher.newBuilder().setPrefix("disallowed-").build()).build()) + .setDecoderHeaderMutationRules(HeaderMutationRules.newBuilder() + .setAllowExpression(RegexMatcher.newBuilder().setRegex("allow.*").build()) + .setDisallowExpression(RegexMatcher.newBuilder().setRegex("disallow.*").build()) + .setDisallowAll(BoolValue.of(true)).setDisallowIsError(BoolValue.of(true)).build()) + .build(); + + ExtAuthzConfig config = ExtAuthzConfigParser.parse(extAuthz, + dummyBootstrapInfo(), + dummyServerInfo()); + + assertThat(config.grpcService().googleGrpc().target()).isEqualTo("test-cluster"); + assertThat(config.grpcService().timeout().get().getSeconds()).isEqualTo(5); + assertThat(config.grpcService().initialMetadata()).isNotEmpty(); + assertThat(config.failureModeAllow()).isTrue(); + assertThat(config.failureModeAllowHeaderAdd()).isTrue(); + assertThat(config.includePeerCertificate()).isTrue(); + assertThat(config.statusOnError().getCode()).isEqualTo(Status.PERMISSION_DENIED.getCode()); + assertThat(config.statusOnError().getDescription()).isEqualTo("HTTP status code 403"); + assertThat(config.denyAtDisable()).isTrue(); + assertThat(config.filterEnabled()).isEqualTo(Matchers.FractionMatcher.create(50, 10_000)); + assertThat(config.allowedHeaders()).hasSize(1); + assertThat(config.allowedHeaders().get(0).matches("allowed-header")).isTrue(); + assertThat(config.disallowedHeaders()).hasSize(1); + assertThat(config.disallowedHeaders().get(0).matches("disallowed-foo")).isTrue(); + assertThat(config.decoderHeaderMutationRules().isPresent()).isTrue(); + HeaderMutationRulesConfig rules = config.decoderHeaderMutationRules().get(); + assertThat(rules.allowExpression().get().pattern()).isEqualTo("allow.*"); + assertThat(rules.disallowExpression().get().pattern()).isEqualTo("disallow.*"); + assertThat(rules.disallowAll()).isTrue(); + assertThat(rules.disallowIsError()).isTrue(); + } + + @Test + public void parse_saneDefaults() throws ExtAuthzParseException { + ExtAuthz extAuthz = extAuthzBuilder.build(); + + ExtAuthzConfig config = ExtAuthzConfigParser.parse(extAuthz, + dummyBootstrapInfo(), + dummyServerInfo()); + + assertThat(config.failureModeAllow()).isFalse(); + assertThat(config.failureModeAllowHeaderAdd()).isFalse(); + assertThat(config.includePeerCertificate()).isFalse(); + assertThat(config.statusOnError()).isEqualTo(Status.PERMISSION_DENIED); + assertThat(config.denyAtDisable()).isFalse(); + assertThat(config.filterEnabled()).isEqualTo(Matchers.FractionMatcher.create(100, 100)); + assertThat(config.allowedHeaders()).isEmpty(); + assertThat(config.disallowedHeaders()).isEmpty(); + assertThat(config.decoderHeaderMutationRules().isPresent()).isFalse(); + } + + @Test + public void parse_headerMutationRules_allowExpressionOnly() throws ExtAuthzParseException { + ExtAuthz extAuthz = extAuthzBuilder + .setDecoderHeaderMutationRules(HeaderMutationRules.newBuilder() + .setAllowExpression(RegexMatcher.newBuilder().setRegex("allow.*").build()).build()) + .build(); + + ExtAuthzConfig config = ExtAuthzConfigParser.parse(extAuthz, + dummyBootstrapInfo(), + dummyServerInfo()); + + assertThat(config.decoderHeaderMutationRules().isPresent()).isTrue(); + HeaderMutationRulesConfig rules = config.decoderHeaderMutationRules().get(); + assertThat(rules.allowExpression().get().pattern()).isEqualTo("allow.*"); + assertThat(rules.disallowExpression().isPresent()).isFalse(); + } + + @Test + public void parse_headerMutationRules_disallowExpressionOnly() throws ExtAuthzParseException { + ExtAuthz extAuthz = + extAuthzBuilder.setDecoderHeaderMutationRules(HeaderMutationRules.newBuilder() + .setDisallowExpression(RegexMatcher.newBuilder().setRegex("disallow.*").build()) + .build()).build(); + + ExtAuthzConfig config = ExtAuthzConfigParser.parse(extAuthz, + dummyBootstrapInfo(), + dummyServerInfo()); + + assertThat(config.decoderHeaderMutationRules().isPresent()).isTrue(); + HeaderMutationRulesConfig rules = config.decoderHeaderMutationRules().get(); + assertThat(rules.allowExpression().isPresent()).isFalse(); + assertThat(rules.disallowExpression().get().pattern()).isEqualTo("disallow.*"); + } + + @Test + public void parse_filterEnabled_hundred() throws ExtAuthzParseException { + ExtAuthz extAuthz = extAuthzBuilder + .setFilterEnabled(RuntimeFractionalPercent.newBuilder().setDefaultValue(FractionalPercent + .newBuilder().setNumerator(25).setDenominator(DenominatorType.HUNDRED).build()).build()) + .build(); + + ExtAuthzConfig config = ExtAuthzConfigParser.parse(extAuthz, + dummyBootstrapInfo(), + dummyServerInfo()); + + assertThat(config.filterEnabled()).isEqualTo(Matchers.FractionMatcher.create(25, 100)); + } + + @Test + public void parse_filterEnabled_million() throws ExtAuthzParseException { + ExtAuthz extAuthz = extAuthzBuilder + .setFilterEnabled( + RuntimeFractionalPercent.newBuilder().setDefaultValue(FractionalPercent.newBuilder() + .setNumerator(123456).setDenominator(DenominatorType.MILLION).build()).build()) + .build(); + + ExtAuthzConfig config = ExtAuthzConfigParser.parse(extAuthz, + dummyBootstrapInfo(), + dummyServerInfo()); + + assertThat(config.filterEnabled()) + .isEqualTo(Matchers.FractionMatcher.create(123456, 1_000_000)); + } + + @Test + public void parse_filterEnabled_unrecognizedDenominator() { + ExtAuthz extAuthz = extAuthzBuilder.setFilterEnabled(RuntimeFractionalPercent.newBuilder() + .setDefaultValue( + FractionalPercent.newBuilder().setNumerator(1).setDenominatorValue(4).build()) + .build()).build(); + + try { + ExtAuthzConfigParser.parse(extAuthz, + dummyBootstrapInfo(), + dummyServerInfo()); + fail("Expected ExtAuthzParseException"); + } catch (ExtAuthzParseException e) { + assertThat(e).hasMessageThat().isEqualTo("Unknown denominator type: UNRECOGNIZED"); + } + } +} diff --git a/xds/src/test/java/io/grpc/xds/FailingClientInterceptor.java b/xds/src/test/java/io/grpc/xds/FailingClientInterceptor.java new file mode 100644 index 00000000000..c8b32f376ee --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/FailingClientInterceptor.java @@ -0,0 +1,50 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static java.util.Objects.requireNonNull; + +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.NoopClientCall; +import io.grpc.Status; + +/** + * An interceptor that fails all RPCs with the provided status. + */ +final class FailingClientInterceptor implements ClientInterceptor { + private final Status status; + + public FailingClientInterceptor(Status status) { + this.status = requireNonNull(status, "status"); + } + + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + return new NoopClientCall() { + @Override + public void start(Listener responseListener, Metadata headers) { + responseListener.onClose(status, new Metadata()); + } + }; + } +} diff --git a/xds/src/test/java/io/grpc/xds/FakeControlPlaneXdsIntegrationTest.java b/xds/src/test/java/io/grpc/xds/FakeControlPlaneXdsIntegrationTest.java index 16e6d22631f..a273c6f3ebf 100644 --- a/xds/src/test/java/io/grpc/xds/FakeControlPlaneXdsIntegrationTest.java +++ b/xds/src/test/java/io/grpc/xds/FakeControlPlaneXdsIntegrationTest.java @@ -18,23 +18,42 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.DataPlaneRule.ENDPOINT_HOST_NAME; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_CDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_EDS; import static org.junit.Assert.assertEquals; import com.github.xds.type.v3.TypedStruct; +import com.google.common.collect.ImmutableMap; import com.google.protobuf.Any; import com.google.protobuf.Struct; import com.google.protobuf.Value; +import io.envoyproxy.envoy.config.cluster.v3.Cluster; import io.envoyproxy.envoy.config.cluster.v3.Cluster.LbPolicy; import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy; import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy.Policy; +import io.envoyproxy.envoy.config.core.v3.Address; +import io.envoyproxy.envoy.config.core.v3.SocketAddress; import io.envoyproxy.envoy.config.core.v3.TypedExtensionConfig; +import io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment; +import io.envoyproxy.envoy.config.endpoint.v3.Endpoint; +import io.envoyproxy.envoy.config.endpoint.v3.LbEndpoint; +import io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints; +import io.envoyproxy.envoy.config.route.v3.Route; +import io.envoyproxy.envoy.config.route.v3.RouteAction; +import io.envoyproxy.envoy.config.route.v3.RouteConfiguration; +import io.envoyproxy.envoy.config.route.v3.RouteMatch; +import io.envoyproxy.envoy.config.route.v3.VirtualHost; import io.envoyproxy.envoy.extensions.load_balancing_policies.wrr_locality.v3.WrrLocality; import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.ClientCall; import io.grpc.ClientInterceptor; +import io.grpc.ClientStreamTracer; +import io.grpc.FlagResetRule; import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; import io.grpc.ForwardingClientCallListener; +import io.grpc.InternalFeatureFlags; import io.grpc.LoadBalancerRegistry; import io.grpc.ManagedChannel; import io.grpc.Metadata; @@ -42,10 +61,15 @@ import io.grpc.testing.protobuf.SimpleRequest; import io.grpc.testing.protobuf.SimpleResponse; import io.grpc.testing.protobuf.SimpleServiceGrpc; +import java.net.InetSocketAddress; +import java.util.Arrays; +import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; /** * Xds integration tests using a local control plane, implemented in {@link @@ -67,27 +91,58 @@ * 3) Construct EDS config w/ test server address from 2). Set CDS and EDS Config at the Control * Plane. Then start the test xDS client (requires EDS to do xDS name resolution). */ -@RunWith(JUnit4.class) +@RunWith(Parameterized.class) public class FakeControlPlaneXdsIntegrationTest { @Rule(order = 0) public ControlPlaneRule controlPlane = new ControlPlaneRule(); @Rule(order = 1) public DataPlaneRule dataPlane = new DataPlaneRule(controlPlane); + @Rule(order = 2) + public final FlagResetRule flagResetRule = new FlagResetRule(); + + @Parameters(name = "enableRfc3986UrisParam={0}") + public static Iterable data() { + return Arrays.asList(new Object[][] {{true}, {false}}); + } + + @Parameter public boolean enableRfc3986UrisParam; + + @Before + public void setupRfc3986UrisFeatureFlag() throws Exception { + flagResetRule.setFlagForTest( + InternalFeatureFlags::setRfc3986UrisEnabled, enableRfc3986UrisParam); + } @Test public void pingPong() throws Exception { ManagedChannel channel = dataPlane.getManagedChannel(); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = SimpleServiceGrpc.newBlockingStub( channel); - SimpleRequest request = SimpleRequest.newBuilder() - .build(); + SimpleRequest request = SimpleRequest.getDefaultInstance(); SimpleResponse goldenResponse = SimpleResponse.newBuilder() - .setResponseMessage("Hi, xDS!") + .setResponseMessage("Hi, xDS! Authority= test-server") .build(); assertEquals(goldenResponse, blockingStub.unaryRpc(request)); } + @Test + public void pingPong_edsEndpoint_authorityOverride() throws Exception { + System.setProperty("GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE", "true"); + try { + ManagedChannel channel = dataPlane.getManagedChannel(); + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = SimpleServiceGrpc.newBlockingStub( + channel); + SimpleRequest request = SimpleRequest.getDefaultInstance(); + SimpleResponse goldenResponse = SimpleResponse.newBuilder() + .setResponseMessage("Hi, xDS! Authority= " + ENDPOINT_HOST_NAME) + .build(); + assertEquals(goldenResponse, blockingStub.unaryRpc(request)); + } finally { + System.clearProperty("GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE"); + } + } + @Test public void pingPong_metadataLoadBalancer() throws Exception { MetadataLoadBalancerProvider metadataLbProvider = new MetadataLoadBalancerProvider(); @@ -118,10 +173,9 @@ public void pingPong_metadataLoadBalancer() throws Exception { // We add an interceptor to catch the response headers from the server. SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = SimpleServiceGrpc.newBlockingStub( dataPlane.getManagedChannel()).withInterceptors(responseHeaderInterceptor); - SimpleRequest request = SimpleRequest.newBuilder() - .build(); + SimpleRequest request = SimpleRequest.getDefaultInstance(); SimpleResponse goldenResponse = SimpleResponse.newBuilder() - .setResponseMessage("Hi, xDS!") + .setResponseMessage("Hi, xDS! Authority= test-server") .build(); assertEquals(goldenResponse, blockingStub.unaryRpc(request)); @@ -133,6 +187,100 @@ public void pingPong_metadataLoadBalancer() throws Exception { } } + // Try to trigger "UNAVAILABLE: CDS encountered error: unable to find available subchannel for + // cluster cluster:cluster1" race, if XdsNameResolver updates its ConfigSelector before + // cluster_manager config. + @Test + public void changeClusterForRoute() throws Exception { + // Start with route to cluster0 + InetSocketAddress edsInetSocketAddress + = (InetSocketAddress) dataPlane.getServer().getListenSockets().get(0); + controlPlane.getService().setXdsConfig( + ADS_TYPE_URL_EDS, + ImmutableMap.of( + "eds-service-0", + ControlPlaneRule.buildClusterLoadAssignment( + edsInetSocketAddress.getHostName(), "", edsInetSocketAddress.getPort(), + "eds-service-0"), + "eds-service-1", + ControlPlaneRule.buildClusterLoadAssignment( + edsInetSocketAddress.getHostName(), "", edsInetSocketAddress.getPort(), + "eds-service-1"))); + controlPlane.getService().setXdsConfig( + ADS_TYPE_URL_CDS, + ImmutableMap.of( + "cluster0", + ControlPlaneRule.buildCluster("cluster0", "eds-service-0"), + "cluster1", + ControlPlaneRule.buildCluster("cluster1", "eds-service-1"))); + controlPlane.setRdsConfig(RouteConfiguration.newBuilder() + .setName("route-config.googleapis.com") + .addVirtualHosts(VirtualHost.newBuilder() + .addDomains("test-server") + .addRoutes(Route.newBuilder() + .setMatch(RouteMatch.newBuilder().setPrefix("/").build()) + .setRoute(RouteAction.newBuilder().setCluster("cluster0").build()) + .build()) + .build()) + .build()); + + class ClusterClientStreamTracer extends ClientStreamTracer { + boolean usedCluster1; + + @Override + public void addOptionalLabel(String key, String value) { + if ("grpc.lb.backend_service".equals(key)) { + usedCluster1 = "cluster1".equals(value); + } + } + } + + ClusterClientStreamTracer tracer = new ClusterClientStreamTracer(); + ClientStreamTracer.Factory tracerFactory = new ClientStreamTracer.Factory() { + @Override + public ClientStreamTracer newClientStreamTracer( + ClientStreamTracer.StreamInfo info, Metadata headers) { + return tracer; + } + }; + ClientInterceptor tracerInterceptor = new ClientInterceptor() { + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + return next.newCall(method, callOptions.withStreamTracerFactory(tracerFactory)); + } + }; + SimpleServiceGrpc.SimpleServiceBlockingStub stub = SimpleServiceGrpc + .newBlockingStub(dataPlane.getManagedChannel()) + .withInterceptors(tracerInterceptor); + SimpleRequest request = SimpleRequest.getDefaultInstance(); + SimpleResponse goldenResponse = SimpleResponse.newBuilder() + .setResponseMessage("Hi, xDS! Authority= test-server") + .build(); + assertThat(stub.unaryRpc(request)).isEqualTo(goldenResponse); + assertThat(tracer.usedCluster1).isFalse(); + + // Check for errors when swapping route to cluster1 + controlPlane.setRdsConfig(RouteConfiguration.newBuilder() + .setName("route-config.googleapis.com") + .addVirtualHosts(VirtualHost.newBuilder() + .addDomains("test-server") + .addRoutes(Route.newBuilder() + .setMatch(RouteMatch.newBuilder().setPrefix("/").build()) + .setRoute(RouteAction.newBuilder().setCluster("cluster1").build()) + .build()) + .build()) + .build()); + + for (int j = 0; j < 10; j++) { + stub.unaryRpc(request); + if (tracer.usedCluster1) { + break; + } + } + assertThat(tracer.usedCluster1).isTrue(); + } + // Captures response headers from the server. private static class ResponseHeaderClientInterceptor implements ClientInterceptor { Metadata reponseHeaders; @@ -172,11 +320,44 @@ public void pingPong_ringHash() { ManagedChannel channel = dataPlane.getManagedChannel(); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = SimpleServiceGrpc.newBlockingStub( channel); - SimpleRequest request = SimpleRequest.newBuilder() - .build(); + SimpleRequest request = SimpleRequest.getDefaultInstance(); SimpleResponse goldenResponse = SimpleResponse.newBuilder() - .setResponseMessage("Hi, xDS!") + .setResponseMessage("Hi, xDS! Authority= test-server") .build(); assertEquals(goldenResponse, blockingStub.unaryRpc(request)); } + + @Test + public void pingPong_logicalDns_authorityOverride() { + System.setProperty("GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE", "true"); + try { + InetSocketAddress serverAddress = + (InetSocketAddress) dataPlane.getServer().getListenSockets().get(0); + controlPlane.setCdsConfig( + ControlPlaneRule.buildCluster().toBuilder() + .setType(Cluster.DiscoveryType.LOGICAL_DNS) + .setLoadAssignment( + ClusterLoadAssignment.newBuilder().addEndpoints( + LocalityLbEndpoints.newBuilder().addLbEndpoints( + LbEndpoint.newBuilder().setEndpoint( + Endpoint.newBuilder().setAddress( + Address.newBuilder().setSocketAddress( + SocketAddress.newBuilder() + .setAddress("localhost") + .setPortValue(serverAddress.getPort())))))) + .build()) + .build()); + + ManagedChannel channel = dataPlane.getManagedChannel(); + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = SimpleServiceGrpc.newBlockingStub( + channel); + SimpleRequest request = SimpleRequest.getDefaultInstance(); + SimpleResponse goldenResponse = SimpleResponse.newBuilder() + .setResponseMessage("Hi, xDS! Authority= localhost:" + serverAddress.getPort()) + .build(); + assertEquals(goldenResponse, blockingStub.unaryRpc(request)); + } finally { + System.clearProperty("GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE"); + } + } } diff --git a/xds/src/test/java/io/grpc/xds/FaultFilterTest.java b/xds/src/test/java/io/grpc/xds/FaultFilterTest.java index f85f29ec0a3..9033d1e636e 100644 --- a/xds/src/test/java/io/grpc/xds/FaultFilterTest.java +++ b/xds/src/test/java/io/grpc/xds/FaultFilterTest.java @@ -26,6 +26,10 @@ import io.envoyproxy.envoy.type.v3.FractionalPercent.DenominatorType; import io.grpc.Status.Code; import io.grpc.internal.GrpcUtil; +import io.grpc.xds.client.Bootstrapper.BootstrapInfo; +import io.grpc.xds.client.Bootstrapper.ServerInfo; +import io.grpc.xds.client.EnvoyProtoData.Node; +import java.util.Collections; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -33,16 +37,28 @@ /** Tests for {@link FaultFilter}. */ @RunWith(JUnit4.class) public class FaultFilterTest { + private static final FaultFilter.Provider FILTER_PROVIDER = new FaultFilter.Provider(); + + @Test + public void filterType_clientOnly() { + assertThat(FILTER_PROVIDER.isClientFilter()).isTrue(); + assertThat(FILTER_PROVIDER.isServerFilter()).isFalse(); + } @Test public void parseFaultAbort_convertHttpStatus() { Any rawConfig = Any.pack( HTTPFault.newBuilder().setAbort(FaultAbort.newBuilder().setHttpStatus(404)).build()); - FaultConfig faultConfig = FaultFilter.INSTANCE.parseFilterConfig(rawConfig).config; + FaultConfig faultConfig = FILTER_PROVIDER.parseFilterConfig( + rawConfig, getFilterContext()).config; + assertThat(faultConfig.faultAbort()).isNotNull(); assertThat(faultConfig.faultAbort().status().getCode()) .isEqualTo(GrpcUtil.httpStatusToGrpcStatus(404).getCode()); + FaultConfig faultConfigOverride = - FaultFilter.INSTANCE.parseFilterConfigOverride(rawConfig).config; + FILTER_PROVIDER.parseFilterConfigOverride( + rawConfig, getFilterContext()).config; + assertThat(faultConfigOverride.faultAbort()).isNotNull(); assertThat(faultConfigOverride.faultAbort().status().getCode()) .isEqualTo(GrpcUtil.httpStatusToGrpcStatus(404).getCode()); } @@ -54,7 +70,7 @@ public void parseFaultAbort_withHeaderAbort() { .setPercentage(FractionalPercent.newBuilder() .setNumerator(20).setDenominator(DenominatorType.HUNDRED)) .setHeaderAbort(HeaderAbort.getDefaultInstance()).build(); - FaultConfig.FaultAbort faultAbort = FaultFilter.parseFaultAbort(proto).config; + FaultConfig.FaultAbort faultAbort = FaultFilter.Provider.parseFaultAbort(proto).config; assertThat(faultAbort.headerAbort()).isTrue(); assertThat(faultAbort.percent().numerator()).isEqualTo(20); assertThat(faultAbort.percent().denominatorType()) @@ -68,7 +84,7 @@ public void parseFaultAbort_withHttpStatus() { .setPercentage(FractionalPercent.newBuilder() .setNumerator(100).setDenominator(DenominatorType.TEN_THOUSAND)) .setHttpStatus(400).build(); - FaultConfig.FaultAbort res = FaultFilter.parseFaultAbort(proto).config; + FaultConfig.FaultAbort res = FaultFilter.Provider.parseFaultAbort(proto).config; assertThat(res.percent().numerator()).isEqualTo(100); assertThat(res.percent().denominatorType()) .isEqualTo(FaultConfig.FractionalPercent.DenominatorType.TEN_THOUSAND); @@ -82,10 +98,23 @@ public void parseFaultAbort_withGrpcStatus() { .setPercentage(FractionalPercent.newBuilder() .setNumerator(600).setDenominator(DenominatorType.MILLION)) .setGrpcStatus(Code.DEADLINE_EXCEEDED.value()).build(); - FaultConfig.FaultAbort faultAbort = FaultFilter.parseFaultAbort(proto).config; + FaultConfig.FaultAbort faultAbort = FaultFilter.Provider.parseFaultAbort(proto).config; assertThat(faultAbort.percent().numerator()).isEqualTo(600); assertThat(faultAbort.percent().denominatorType()) .isEqualTo(FaultConfig.FractionalPercent.DenominatorType.MILLION); assertThat(faultAbort.status().getCode()).isEqualTo(Code.DEADLINE_EXCEEDED); } + + private static Filter.FilterConfigParseContext getFilterContext() { + return Filter.FilterConfigParseContext.builder() + .bootstrapInfo(BootstrapInfo.builder() + .servers(Collections.singletonList( + ServerInfo.create( + "test_target", Collections.emptyMap()))) + .node(Node.newBuilder().build()) + .build()) + .serverInfo(ServerInfo.create( + "test_target", Collections.emptyMap(), false, true, false, false)) + .build(); + } } diff --git a/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java index 685102477cc..722f915dbea 100644 --- a/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java @@ -25,6 +25,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.net.InetAddresses; import com.google.common.util.concurrent.SettableFuture; import io.grpc.ServerInterceptor; import io.grpc.internal.TestUtils.NoopChannelLogger; @@ -58,7 +59,6 @@ import io.netty.handler.codec.http2.Http2Settings; import java.net.InetSocketAddress; import java.net.SocketAddress; -import java.net.UnknownHostException; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -318,7 +318,8 @@ public void destPrefixRangeMatch() throws Exception { EnvoyServerProtoData.FilterChainMatch filterChainMatchWithMatch = EnvoyServerProtoData.FilterChainMatch.create( 0, - ImmutableList.of(EnvoyServerProtoData.CidrRange.create("10.1.2.0", 24)), + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.1.2.0"), 24)), ImmutableList.of(), ImmutableList.of(), EnvoyServerProtoData.ConnectionSourceType.ANY, @@ -360,7 +361,8 @@ public void destPrefixRangeMismatch_returnDefaultFilterChain() EnvoyServerProtoData.FilterChainMatch filterChainMatchWithMismatch = EnvoyServerProtoData.FilterChainMatch.create( 0, - ImmutableList.of(EnvoyServerProtoData.CidrRange.create("10.2.2.0", 24)), + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.2.2.0"), 24)), ImmutableList.of(), ImmutableList.of(), EnvoyServerProtoData.ConnectionSourceType.ANY, @@ -403,7 +405,8 @@ public void dest0LengthPrefixRange() EnvoyServerProtoData.FilterChainMatch filterChainMatch0Length = EnvoyServerProtoData.FilterChainMatch.create( 0, - ImmutableList.of(EnvoyServerProtoData.CidrRange.create("10.2.2.0", 0)), + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.2.2.0"), 0)), ImmutableList.of(), ImmutableList.of(), EnvoyServerProtoData.ConnectionSourceType.ANY, @@ -444,7 +447,8 @@ public void destPrefixRange_moreSpecificWins() EnvoyServerProtoData.FilterChainMatch filterChainMatchLessSpecific = EnvoyServerProtoData.FilterChainMatch.create( 0, - ImmutableList.of(EnvoyServerProtoData.CidrRange.create("10.1.2.0", 24)), + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.1.2.0"), 24)), ImmutableList.of(), ImmutableList.of(), EnvoyServerProtoData.ConnectionSourceType.ANY, @@ -461,7 +465,8 @@ public void destPrefixRange_moreSpecificWins() EnvoyServerProtoData.FilterChainMatch filterChainMatchMoreSpecific = EnvoyServerProtoData.FilterChainMatch.create( 0, - ImmutableList.of(EnvoyServerProtoData.CidrRange.create("10.1.2.2", 31)), + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.1.2.2"), 31)), ImmutableList.of(), ImmutableList.of(), EnvoyServerProtoData.ConnectionSourceType.ANY, @@ -519,7 +524,8 @@ public void destPrefixRange_emptyListLessSpecific() EnvoyServerProtoData.FilterChainMatch filterChainMatchMoreSpecific = EnvoyServerProtoData.FilterChainMatch.create( 0, - ImmutableList.of(EnvoyServerProtoData.CidrRange.create("8.0.0.0", 5)), + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("8.0.0.0"), 5)), ImmutableList.of(), ImmutableList.of(), EnvoyServerProtoData.ConnectionSourceType.ANY, @@ -559,7 +565,8 @@ public void destPrefixRangeIpv6_moreSpecificWins() EnvoyServerProtoData.FilterChainMatch filterChainMatchLessSpecific = EnvoyServerProtoData.FilterChainMatch.create( 0, - ImmutableList.of(EnvoyServerProtoData.CidrRange.create("FE80:0:0:0:0:0:0:0", 60)), + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("FE80:0:0:0:0:0:0:0"), 60)), ImmutableList.of(), ImmutableList.of(), EnvoyServerProtoData.ConnectionSourceType.ANY, @@ -577,7 +584,8 @@ public void destPrefixRangeIpv6_moreSpecificWins() EnvoyServerProtoData.FilterChainMatch.create( 0, ImmutableList.of( - EnvoyServerProtoData.CidrRange.create("FE80:0000:0000:0000:0202:0:0:0", 80)), + EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("FE80:0000:0000:0000:0202:0:0:0"), 80)), ImmutableList.of(), ImmutableList.of(), EnvoyServerProtoData.ConnectionSourceType.ANY, @@ -620,8 +628,10 @@ public void destPrefixRange_moreSpecificWith2Wins() EnvoyServerProtoData.FilterChainMatch.create( 0, ImmutableList.of( - EnvoyServerProtoData.CidrRange.create("10.1.2.0", 24), - EnvoyServerProtoData.CidrRange.create(LOCAL_IP, 32)), + EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.1.2.0"), 24), + EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString(LOCAL_IP), 32)), ImmutableList.of(), ImmutableList.of(), EnvoyServerProtoData.ConnectionSourceType.ANY, @@ -638,7 +648,8 @@ public void destPrefixRange_moreSpecificWith2Wins() EnvoyServerProtoData.FilterChainMatch filterChainMatchLessSpecific = EnvoyServerProtoData.FilterChainMatch.create( 0, - ImmutableList.of(EnvoyServerProtoData.CidrRange.create("10.1.2.2", 31)), + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.1.2.2"), 31)), ImmutableList.of(), ImmutableList.of(), EnvoyServerProtoData.ConnectionSourceType.ANY, @@ -763,8 +774,10 @@ public void sourcePrefixRange_moreSpecificWith2Wins() ImmutableList.of(), ImmutableList.of(), ImmutableList.of( - EnvoyServerProtoData.CidrRange.create("10.4.2.0", 24), - EnvoyServerProtoData.CidrRange.create(REMOTE_IP, 32)), + EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.4.2.0"), 24), + EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString(REMOTE_IP), 32)), EnvoyServerProtoData.ConnectionSourceType.ANY, ImmutableList.of(), ImmutableList.of(), @@ -781,7 +794,8 @@ public void sourcePrefixRange_moreSpecificWith2Wins() 0, ImmutableList.of(), ImmutableList.of(), - ImmutableList.of(EnvoyServerProtoData.CidrRange.create("10.4.2.2", 31)), + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.4.2.2"), 31)), EnvoyServerProtoData.ConnectionSourceType.ANY, ImmutableList.of(), ImmutableList.of(), @@ -811,8 +825,7 @@ filterChainLessSpecific, randomConfig("no-match")), } @Test - public void sourcePrefixRange_2Matchers_expectException() - throws UnknownHostException { + public void sourcePrefixRange_2Matchers_expectException() { ChannelHandler next = new ChannelInboundHandlerAdapter() { @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { @@ -831,8 +844,10 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { ImmutableList.of(), ImmutableList.of(), ImmutableList.of( - EnvoyServerProtoData.CidrRange.create("10.4.2.0", 24), - EnvoyServerProtoData.CidrRange.create("192.168.10.2", 32)), + EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.4.2.0"), 24), + EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("192.168.10.2"), 32)), EnvoyServerProtoData.ConnectionSourceType.ANY, ImmutableList.of(), ImmutableList.of(), @@ -848,7 +863,8 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { 0, ImmutableList.of(), ImmutableList.of(), - ImmutableList.of(EnvoyServerProtoData.CidrRange.create("10.4.2.0", 24)), + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.4.2.0"), 24)), EnvoyServerProtoData.ConnectionSourceType.ANY, ImmutableList.of(), ImmutableList.of(), @@ -890,8 +906,10 @@ public void sourcePortMatch_exactMatchWinsOverEmptyList() throws Exception { ImmutableList.of(), ImmutableList.of(), ImmutableList.of( - EnvoyServerProtoData.CidrRange.create("10.4.2.0", 24), - EnvoyServerProtoData.CidrRange.create("10.4.2.2", 31)), + EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.4.2.0"), 24), + EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.4.2.2"), 31)), EnvoyServerProtoData.ConnectionSourceType.ANY, ImmutableList.of(), ImmutableList.of(), @@ -908,7 +926,8 @@ public void sourcePortMatch_exactMatchWinsOverEmptyList() throws Exception { 0, ImmutableList.of(), ImmutableList.of(), - ImmutableList.of(EnvoyServerProtoData.CidrRange.create("10.4.2.2", 31)), + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.4.2.2"), 31)), EnvoyServerProtoData.ConnectionSourceType.ANY, ImmutableList.of(7000, 15000), ImmutableList.of(), @@ -966,7 +985,8 @@ public void filterChain_5stepMatch() throws Exception { PORT, ImmutableList.of(), ImmutableList.of(), - ImmutableList.of(EnvoyServerProtoData.CidrRange.create(REMOTE_IP, 32)), + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString(REMOTE_IP), 32)), EnvoyServerProtoData.ConnectionSourceType.ANY, ImmutableList.of(), ImmutableList.of(), @@ -981,9 +1001,11 @@ public void filterChain_5stepMatch() throws Exception { EnvoyServerProtoData.FilterChainMatch filterChainMatch2 = EnvoyServerProtoData.FilterChainMatch.create( 0, - ImmutableList.of(EnvoyServerProtoData.CidrRange.create("10.1.2.0", 30)), + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.1.2.0"), 30)), ImmutableList.of(), - ImmutableList.of(EnvoyServerProtoData.CidrRange.create("10.4.0.0", 16)), + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.4.0.0"), 16)), EnvoyServerProtoData.ConnectionSourceType.ANY, ImmutableList.of(), ImmutableList.of(), @@ -997,8 +1019,10 @@ public void filterChain_5stepMatch() throws Exception { EnvoyServerProtoData.FilterChainMatch.create( 0, ImmutableList.of( - EnvoyServerProtoData.CidrRange.create("192.168.2.0", 24), - EnvoyServerProtoData.CidrRange.create("10.1.2.0", 30)), + EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("192.168.2.0"), 24), + EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.1.2.0"), 30)), ImmutableList.of(), ImmutableList.of(), EnvoyServerProtoData.ConnectionSourceType.SAME_IP_OR_LOOPBACK, @@ -1015,10 +1039,13 @@ public void filterChain_5stepMatch() throws Exception { EnvoyServerProtoData.FilterChainMatch.create( 0, ImmutableList.of( - EnvoyServerProtoData.CidrRange.create("10.1.0.0", 16), - EnvoyServerProtoData.CidrRange.create("10.1.2.0", 30)), + EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.1.0.0"), 16), + EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.1.2.0"), 30)), ImmutableList.of(), - ImmutableList.of(EnvoyServerProtoData.CidrRange.create("10.4.2.0", 24)), + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.4.2.0"), 24)), EnvoyServerProtoData.ConnectionSourceType.EXTERNAL, ImmutableList.of(16000, 9000), ImmutableList.of(), @@ -1034,12 +1061,16 @@ public void filterChain_5stepMatch() throws Exception { EnvoyServerProtoData.FilterChainMatch.create( 0, ImmutableList.of( - EnvoyServerProtoData.CidrRange.create("10.1.0.0", 16), - EnvoyServerProtoData.CidrRange.create("10.1.2.0", 30)), + EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.1.0.0"), 16), + EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.1.2.0"), 30)), ImmutableList.of(), ImmutableList.of( - EnvoyServerProtoData.CidrRange.create("10.4.2.0", 24), - EnvoyServerProtoData.CidrRange.create("192.168.2.0", 24)), + EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.4.2.0"), 24), + EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("192.168.2.0"), 24)), EnvoyServerProtoData.ConnectionSourceType.ANY, ImmutableList.of(15000, 8000), ImmutableList.of(), @@ -1053,7 +1084,8 @@ public void filterChain_5stepMatch() throws Exception { EnvoyServerProtoData.FilterChainMatch filterChainMatch6 = EnvoyServerProtoData.FilterChainMatch.create( 0, - ImmutableList.of(EnvoyServerProtoData.CidrRange.create("10.1.2.0", 29)), + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.1.2.0"), 29)), ImmutableList.of(), ImmutableList.of(), EnvoyServerProtoData.ConnectionSourceType.ANY, @@ -1093,7 +1125,6 @@ public void filterChain_5stepMatch() throws Exception { } @Test - @SuppressWarnings("deprecation") public void filterChainMatch_unsupportedMatchers() throws Exception { EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "ROOTCA"); @@ -1105,8 +1136,8 @@ public void filterChainMatch_unsupportedMatchers() throws Exception { EnvoyServerProtoData.FilterChainMatch filterChainMatch1 = EnvoyServerProtoData.FilterChainMatch.create( 0 /* destinationPort */, - ImmutableList.of( - EnvoyServerProtoData.CidrRange.create("10.1.0.0", 16)) /* prefixRange */, + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.1.0.0"), 16)) /* prefixRange */, ImmutableList.of("managed-mtls", "h2") /* applicationProtocol */, ImmutableList.of() /* sourcePrefixRanges */, EnvoyServerProtoData.ConnectionSourceType.ANY /* sourceType */, @@ -1117,8 +1148,8 @@ public void filterChainMatch_unsupportedMatchers() throws Exception { EnvoyServerProtoData.FilterChainMatch filterChainMatch2 = EnvoyServerProtoData.FilterChainMatch.create( 0 /* destinationPort */, - ImmutableList.of( - EnvoyServerProtoData.CidrRange.create("10.0.0.0", 8)) /* prefixRange */, + ImmutableList.of(EnvoyServerProtoData.CidrRange.create( + InetAddresses.forString("10.0.0.0"), 8)) /* prefixRange */, ImmutableList.of() /* applicationProtocol */, ImmutableList.of() /* sourcePrefixRanges */, EnvoyServerProtoData.ConnectionSourceType.ANY /* sourceType */, @@ -1162,7 +1193,7 @@ public void filterChainMatch_unsupportedMatchers() throws Exception { assertThat(sslSet.get()).isEqualTo(defaultFilterChain.sslContextProviderSupplier()); assertThat(routingSettable.get()).isEqualTo(noopConfig); assertThat(sslSet.get().getTlsContext().getCommonTlsContext() - .getTlsCertificateCertificateProviderInstance() + .getTlsCertificateProviderInstance() .getCertificateName()).isEqualTo("CERT3"); } diff --git a/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java b/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java new file mode 100644 index 00000000000..788ab54726c --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java @@ -0,0 +1,541 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.XdsNameResolver.CLUSTER_SELECTION_KEY; +import static io.grpc.xds.XdsNameResolver.XDS_CONFIG_CALL_OPTION_KEY; +import static io.grpc.xds.XdsTestUtils.CLUSTER_NAME; +import static io.grpc.xds.XdsTestUtils.EDS_NAME; +import static io.grpc.xds.XdsTestUtils.ENDPOINT_HOSTNAME; +import static io.grpc.xds.XdsTestUtils.ENDPOINT_PORT; +import static io.grpc.xds.XdsTestUtils.RDS_NAME; +import static io.grpc.xds.XdsTestUtils.buildRouteConfiguration; +import static io.grpc.xds.XdsTestUtils.getWrrLbConfigAsMap; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.Any; +import com.google.protobuf.Empty; +import com.google.protobuf.Message; +import com.google.protobuf.UInt64Value; +import io.envoyproxy.envoy.config.route.v3.RouteConfiguration; +import io.envoyproxy.envoy.extensions.filters.http.gcp_authn.v3.GcpAuthnFilterConfig; +import io.envoyproxy.envoy.extensions.filters.http.gcp_authn.v3.TokenCacheConfig; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.MethodDescriptor; +import io.grpc.Status; +import io.grpc.StatusOr; +import io.grpc.testing.TestMethodDescriptors; +import io.grpc.xds.Endpoints.LbEndpoint; +import io.grpc.xds.Endpoints.LocalityLbEndpoints; +import io.grpc.xds.GcpAuthenticationFilter.AudienceMetadataParser.AudienceWrapper; +import io.grpc.xds.GcpAuthenticationFilter.FailingClientCall; +import io.grpc.xds.GcpAuthenticationFilter.GcpAuthenticationConfig; +import io.grpc.xds.XdsClusterResource.CdsUpdate; +import io.grpc.xds.XdsConfig.XdsClusterConfig; +import io.grpc.xds.XdsConfig.XdsClusterConfig.EndpointConfig; +import io.grpc.xds.XdsEndpointResource.EdsUpdate; +import io.grpc.xds.XdsListenerResource.LdsUpdate; +import io.grpc.xds.XdsRouteConfigureResource.RdsUpdate; +import io.grpc.xds.client.Bootstrapper.BootstrapInfo; +import io.grpc.xds.client.Bootstrapper.ServerInfo; +import io.grpc.xds.client.EnvoyProtoData.Node; +import io.grpc.xds.client.Locality; +import io.grpc.xds.client.XdsResourceType; +import io.grpc.xds.client.XdsResourceType.ResourceInvalidException; +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; + +@RunWith(JUnit4.class) +public class GcpAuthenticationFilterTest { + private static final GcpAuthenticationFilter.Provider FILTER_PROVIDER = + new GcpAuthenticationFilter.Provider(); + private static final LdsUpdate ldsUpdate = getLdsUpdate(); + private static final EdsUpdate edsUpdate = getEdsUpdate(); + private static final RdsUpdate rdsUpdate = getRdsUpdate(); + private static final CdsUpdate cdsUpdate = getCdsUpdate(); + + @Before + public void setUp() { + System.setProperty("GRPC_EXPERIMENTAL_XDS_GCP_AUTHENTICATION_FILTER", "true"); + } + + @Test + public void testNewFilterInstancesPerFilterName() { + assertThat(new GcpAuthenticationFilter("FILTER_INSTANCE_NAME1", 10)) + .isNotEqualTo(new GcpAuthenticationFilter("FILTER_INSTANCE_NAME1", 10)); + } + + @Test + public void filterType_clientOnly() { + assertThat(FILTER_PROVIDER.isClientFilter()).isTrue(); + assertThat(FILTER_PROVIDER.isServerFilter()).isFalse(); + } + + @Test + public void testParseFilterConfig_withValidConfig() { + GcpAuthnFilterConfig config = GcpAuthnFilterConfig.newBuilder() + .setCacheConfig(TokenCacheConfig.newBuilder().setCacheSize(UInt64Value.of(20))) + .build(); + Any anyMessage = Any.pack(config); + + ConfigOrError result = + FILTER_PROVIDER.parseFilterConfig(anyMessage, getFilterContext()); + assertNotNull(result.config); + assertNull(result.errorDetail); + assertEquals(20L, result.config.getCacheSize()); + } + + @Test + public void testParseFilterConfig_withZeroCacheSize() { + GcpAuthnFilterConfig config = GcpAuthnFilterConfig.newBuilder() + .setCacheConfig(TokenCacheConfig.newBuilder().setCacheSize(UInt64Value.of(0))) + .build(); + Any anyMessage = Any.pack(config); + + ConfigOrError result = + FILTER_PROVIDER.parseFilterConfig(anyMessage, getFilterContext()); + assertNull(result.config); + assertNotNull(result.errorDetail); + assertTrue(result.errorDetail.contains("cache_config.cache_size must be greater than zero")); + } + + @Test + public void testParseFilterConfig_withInvalidMessageType() { + Message invalidMessage = Empty.getDefaultInstance(); + ConfigOrError result = + FILTER_PROVIDER.parseFilterConfig(invalidMessage, getFilterContext()); + + assertNull(result.config); + assertThat(result.errorDetail).contains("Invalid config type"); + } + + @Test + public void testClientInterceptor_success() throws IOException, ResourceInvalidException { + XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, + cdsUpdate, + new EndpointConfig(StatusOr.fromValue(edsUpdate))); + XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build(); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); + GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + Channel mockChannel = Mockito.mock(Channel.class); + ArgumentCaptor callOptionsCaptor = ArgumentCaptor.forClass(CallOptions.class); + + interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + + verify(mockChannel).newCall(eq(methodDescriptor), callOptionsCaptor.capture()); + CallOptions capturedOptions = callOptionsCaptor.getAllValues().get(0); + assertNotNull(capturedOptions.getCredentials()); + } + + @Test + public void testClientInterceptor_createsAndReusesCachedCredentials() + throws IOException, ResourceInvalidException { + XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, + cdsUpdate, + new EndpointConfig(StatusOr.fromValue(edsUpdate))); + XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build(); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); + GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + Channel mockChannel = Mockito.mock(Channel.class); + ArgumentCaptor callOptionsCaptor = ArgumentCaptor.forClass(CallOptions.class); + + interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + + verify(mockChannel, times(2)) + .newCall(eq(methodDescriptor), callOptionsCaptor.capture()); + CallOptions firstCapturedOptions = callOptionsCaptor.getAllValues().get(0); + CallOptions secondCapturedOptions = callOptionsCaptor.getAllValues().get(1); + assertNotNull(firstCapturedOptions.getCredentials()); + assertNotNull(secondCapturedOptions.getCredentials()); + assertSame(firstCapturedOptions.getCredentials(), secondCapturedOptions.getCredentials()); + } + + @Test + public void testClientInterceptor_withoutClusterSelectionKey() throws Exception { + GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + Channel mockChannel = mock(Channel.class); + CallOptions callOptionsWithXds = CallOptions.DEFAULT; + + ClientCall call = interceptor.interceptCall( + methodDescriptor, callOptionsWithXds, mockChannel); + + assertTrue(call instanceof FailingClientCall); + FailingClientCall clientCall = (FailingClientCall) call; + assertThat(clientCall.error.getDescription()).contains("does not contain cluster resource"); + } + + @Test + public void testClientInterceptor_clusterSelectionKeyWithoutPrefix() throws Exception { + XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, + cdsUpdate, + new EndpointConfig(StatusOr.fromValue(edsUpdate))); + XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build(); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster0") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); + Channel mockChannel = mock(Channel.class); + + GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + + verify(mockChannel).newCall(methodDescriptor, callOptionsWithXds); + } + + @Test + public void testClientInterceptor_xdsConfigDoesNotExist() throws Exception { + GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + Channel mockChannel = mock(Channel.class); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0"); + + ClientCall call = + interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + + assertTrue(call instanceof FailingClientCall); + FailingClientCall clientCall = (FailingClientCall) call; + assertThat(clientCall.error.getDescription()).contains("does not contain xds configuration"); + } + + @Test + public void testClientInterceptor_incorrectClusterName() throws Exception { + XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, + cdsUpdate, + new EndpointConfig(StatusOr.fromValue(edsUpdate))); + XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster("custer0", StatusOr.fromValue(clusterConfig)).build(); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); + GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + Channel mockChannel = mock(Channel.class); + + ClientCall call = + interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + + assertTrue(call instanceof FailingClientCall); + FailingClientCall clientCall = (FailingClientCall) call; + assertThat(clientCall.error.getDescription()).contains("does not contain xds cluster"); + } + + @Test + public void testClientInterceptor_statusOrError() throws Exception { + StatusOr errorCluster = + StatusOr.fromStatus(Status.NOT_FOUND.withDescription("Cluster resource not found")); + XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster(CLUSTER_NAME, errorCluster).build(); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); + GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + Channel mockChannel = mock(Channel.class); + + ClientCall call = + interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + + assertTrue(call instanceof FailingClientCall); + FailingClientCall clientCall = (FailingClientCall) call; + assertThat(clientCall.error.getDescription()).contains("Cluster resource not found"); + } + + @Test + public void testClientInterceptor_notAudienceWrapper() + throws IOException, ResourceInvalidException { + XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, + getCdsUpdateWithIncorrectAudienceWrapper(), + new EndpointConfig(StatusOr.fromValue(edsUpdate))); + XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build(); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); + GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + Channel mockChannel = Mockito.mock(Channel.class); + + ClientCall call = + interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + + assertTrue(call instanceof FailingClientCall); + FailingClientCall clientCall = (FailingClientCall) call; + assertThat(clientCall.error.getDescription()).contains("GCP Authn found wrong type"); + } + + @Test + public void testLruCacheAcrossInterceptors() throws IOException, ResourceInvalidException { + XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, cdsUpdate, new EndpointConfig(StatusOr.fromValue(edsUpdate))); + XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build(); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 2); + ClientInterceptor interceptor1 + = filter.buildClientInterceptor(new GcpAuthenticationConfig(2), null, null); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + Channel mockChannel = Mockito.mock(Channel.class); + ArgumentCaptor callOptionsCaptor = ArgumentCaptor.forClass(CallOptions.class); + + interceptor1.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + verify(mockChannel).newCall(eq(methodDescriptor), callOptionsCaptor.capture()); + CallOptions capturedOptions1 = callOptionsCaptor.getAllValues().get(0); + assertNotNull(capturedOptions1.getCredentials()); + ClientInterceptor interceptor2 + = filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null); + interceptor2.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + verify(mockChannel, times(2)) + .newCall(eq(methodDescriptor), callOptionsCaptor.capture()); + CallOptions capturedOptions2 = callOptionsCaptor.getAllValues().get(1); + assertNotNull(capturedOptions2.getCredentials()); + + assertSame(capturedOptions1.getCredentials(), capturedOptions2.getCredentials()); + } + + @Test + public void testLruCacheEvictionOnResize() throws IOException, ResourceInvalidException { + XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, cdsUpdate, new EndpointConfig(StatusOr.fromValue(edsUpdate))); + XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build(); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 2); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + + ClientInterceptor interceptor1 = + filter.buildClientInterceptor(new GcpAuthenticationConfig(2), null, null); + Channel mockChannel1 = Mockito.mock(Channel.class); + ArgumentCaptor captor = ArgumentCaptor.forClass(CallOptions.class); + interceptor1.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel1); + verify(mockChannel1).newCall(eq(methodDescriptor), captor.capture()); + CallOptions options1 = captor.getValue(); + // This will recreate the cache with max size of 1 and copy the credential for audience1. + ClientInterceptor interceptor2 = + filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null); + Channel mockChannel2 = Mockito.mock(Channel.class); + interceptor2.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel2); + verify(mockChannel2).newCall(eq(methodDescriptor), captor.capture()); + CallOptions options2 = captor.getValue(); + + assertSame(options1.getCredentials(), options2.getCredentials()); + + clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, getCdsUpdate2(), new EndpointConfig(StatusOr.fromValue(edsUpdate))); + defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build(); + callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); + + // This will evict the credential for audience1 and add new credential for audience2 + ClientInterceptor interceptor3 = + filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null); + Channel mockChannel3 = Mockito.mock(Channel.class); + interceptor3.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel3); + verify(mockChannel3).newCall(eq(methodDescriptor), captor.capture()); + CallOptions options3 = captor.getValue(); + + assertNotSame(options1.getCredentials(), options3.getCredentials()); + + clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, cdsUpdate, new EndpointConfig(StatusOr.fromValue(edsUpdate))); + defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build(); + callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); + + // This will create new credential for audience1 because it has been evicted + ClientInterceptor interceptor4 = + filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null); + Channel mockChannel4 = Mockito.mock(Channel.class); + interceptor4.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel4); + verify(mockChannel4).newCall(eq(methodDescriptor), captor.capture()); + CallOptions options4 = captor.getValue(); + + assertNotSame(options1.getCredentials(), options4.getCredentials()); + } + + private static LdsUpdate getLdsUpdate() { + Filter.NamedFilterConfig routerFilterConfig = new Filter.NamedFilterConfig( + "router", RouterFilter.ROUTER_CONFIG); + HttpConnectionManager httpConnectionManager = HttpConnectionManager.forRdsName( + 0L, RDS_NAME, Collections.singletonList(routerFilterConfig)); + return XdsListenerResource.LdsUpdate.forApiListener(httpConnectionManager); + } + + private static RdsUpdate getRdsUpdate() { + RouteConfiguration routeConfiguration = + buildRouteConfiguration("my-server", RDS_NAME, CLUSTER_NAME); + XdsResourceType.Args args = + new XdsResourceType.Args(XdsTestUtils.EMPTY_BOOTSTRAPPER_SERVER_INFO, "0", "0", + XdsTestUtils.EMPTY_BOOTSTRAP, null, null); + try { + return XdsRouteConfigureResource.getInstance().doParse(args, routeConfiguration); + } catch (ResourceInvalidException ex) { + return null; + } + } + + private static EdsUpdate getEdsUpdate() { + Map lbEndpointsMap = new HashMap<>(); + LbEndpoint lbEndpoint = LbEndpoint.create( + "127.0.0.5", ENDPOINT_PORT, 0, true, ENDPOINT_HOSTNAME, ImmutableMap.of()); + lbEndpointsMap.put( + Locality.create("", "", ""), + LocalityLbEndpoints.create(ImmutableList.of(lbEndpoint), 10, 0, ImmutableMap.of())); + return new XdsEndpointResource.EdsUpdate(EDS_NAME, lbEndpointsMap, Collections.emptyList()); + } + + private static CdsUpdate getCdsUpdate() { + ImmutableMap.Builder parsedMetadata = ImmutableMap.builder(); + parsedMetadata.put("FILTER_INSTANCE_NAME", new AudienceWrapper("TEST_AUDIENCE")); + try { + CdsUpdate.Builder cdsUpdate = CdsUpdate.forEds( + CLUSTER_NAME, EDS_NAME, null, null, null, null, false, null) + .lbPolicyConfig(getWrrLbConfigAsMap()); + return cdsUpdate.parsedMetadata(parsedMetadata.build()).build(); + } catch (IOException ex) { + return null; + } + } + + private static CdsUpdate getCdsUpdate2() { + ImmutableMap.Builder parsedMetadata = ImmutableMap.builder(); + parsedMetadata.put("FILTER_INSTANCE_NAME", new AudienceWrapper("NEW_TEST_AUDIENCE")); + try { + CdsUpdate.Builder cdsUpdate = CdsUpdate.forEds( + CLUSTER_NAME, EDS_NAME, null, null, null, null, false, null) + .lbPolicyConfig(getWrrLbConfigAsMap()); + return cdsUpdate.parsedMetadata(parsedMetadata.build()).build(); + } catch (IOException ex) { + return null; + } + } + + private static CdsUpdate getCdsUpdateWithIncorrectAudienceWrapper() throws IOException { + ImmutableMap.Builder parsedMetadata = ImmutableMap.builder(); + parsedMetadata.put("FILTER_INSTANCE_NAME", "TEST_AUDIENCE"); + CdsUpdate.Builder cdsUpdate = CdsUpdate.forEds( + CLUSTER_NAME, EDS_NAME, null, null, null, null, false, null) + .lbPolicyConfig(getWrrLbConfigAsMap()); + return cdsUpdate.parsedMetadata(parsedMetadata.build()).build(); + } + + private static Filter.FilterConfigParseContext getFilterContext() { + return Filter.FilterConfigParseContext.builder() + .bootstrapInfo(BootstrapInfo.builder() + .servers(Collections.singletonList( + ServerInfo.create( + "test_target", Collections.emptyMap()))) + .node(Node.newBuilder().build()) + .build()) + .serverInfo(ServerInfo.create( + "test_target", Collections.emptyMap(), false, true, false, false)) + .build(); + } +} diff --git a/xds/src/test/java/io/grpc/xds/GrpcBootstrapperImplTest.java b/xds/src/test/java/io/grpc/xds/GrpcBootstrapperImplTest.java index 30ea76b54f2..d4ee4159bc2 100644 --- a/xds/src/test/java/io/grpc/xds/GrpcBootstrapperImplTest.java +++ b/xds/src/test/java/io/grpc/xds/GrpcBootstrapperImplTest.java @@ -17,6 +17,7 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verifyNoInteractions; @@ -27,11 +28,14 @@ import io.grpc.TlsChannelCredentials; import io.grpc.internal.GrpcUtil; import io.grpc.internal.GrpcUtil.GrpcBuildVersion; +import io.grpc.xds.client.AllowedGrpcServices; +import io.grpc.xds.client.AllowedGrpcServices.AllowedGrpcService; import io.grpc.xds.client.Bootstrapper; import io.grpc.xds.client.Bootstrapper.AuthorityInfo; import io.grpc.xds.client.Bootstrapper.BootstrapInfo; import io.grpc.xds.client.Bootstrapper.ServerInfo; import io.grpc.xds.client.BootstrapperImpl; +import io.grpc.xds.client.CommonBootstrapperTestUtils; import io.grpc.xds.client.EnvoyProtoData.Node; import io.grpc.xds.client.Locality; import io.grpc.xds.client.XdsInitializationException; @@ -39,10 +43,9 @@ import java.util.List; import java.util.Map; import org.junit.After; +import org.junit.Assert; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -52,19 +55,18 @@ public class GrpcBootstrapperImplTest { private static final String BOOTSTRAP_FILE_PATH = "/fake/fs/path/bootstrap.json"; private static final String SERVER_URI = "trafficdirector.googleapis.com:443"; - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); private final GrpcBootstrapperImpl bootstrapper = new GrpcBootstrapperImpl(); private String originalBootstrapPathFromEnvVar; private String originalBootstrapPathFromSysProp; private String originalBootstrapConfigFromEnvVar; private String originalBootstrapConfigFromSysProp; + private boolean originalExperimentalXdsFallbackFlag; @Before public void setUp() { saveEnvironment(); + originalExperimentalXdsFallbackFlag = CommonBootstrapperTestUtils.setEnableXdsFallback(true); bootstrapper.bootstrapPathFromEnvVar = BOOTSTRAP_FILE_PATH; } @@ -81,6 +83,73 @@ public void restoreEnvironment() { bootstrapper.bootstrapPathFromSysProp = originalBootstrapPathFromSysProp; bootstrapper.bootstrapConfigFromEnvVar = originalBootstrapConfigFromEnvVar; bootstrapper.bootstrapConfigFromSysProp = originalBootstrapConfigFromSysProp; + CommonBootstrapperTestUtils.setEnableXdsFallback(originalExperimentalXdsFallbackFlag); + } + + @Test + public void parseBootstrap_emptyServers_throws() { + String rawData = "{\n" + + " \"xds_servers\": [\n" + + " ]\n" + + "}"; + + bootstrapper.setFileReader(createFileReader(BOOTSTRAP_FILE_PATH, rawData)); + XdsInitializationException e = Assert.assertThrows(XdsInitializationException.class, + bootstrapper::bootstrap); + assertThat(e).hasMessageThat().isEqualTo("Invalid bootstrap: 'xds_servers' is empty"); + } + + @Test + public void parseBootstrap_allowedGrpcServices() throws XdsInitializationException { + String rawData = "{\n" + + " \"xds_servers\": [\n" + + " {\n" + + " \"server_uri\": \"" + SERVER_URI + "\",\n" + + " \"channel_creds\": [{\"type\": \"insecure\"}]\n" + + " }\n" + + " ],\n" + + " \"allowed_grpc_services\": {\n" + + " \"dns:///foo.com:443\": {\n" + + " \"channel_creds\": [{\"type\": \"insecure\"}],\n" + + " \"call_creds\": [{\"type\": \"access_token\"}]\n" + + " }\n" + + " }\n" + + "}"; + + bootstrapper.setFileReader(createFileReader(BOOTSTRAP_FILE_PATH, rawData)); + BootstrapInfo info = bootstrapper.bootstrap(); + GrpcBootstrapImplConfig customConfig = + (GrpcBootstrapImplConfig) info.implSpecificObject().get(); + AllowedGrpcServices allowed = customConfig.allowedGrpcServices(); + assertThat(allowed).isNotNull(); + assertThat(allowed.services()).containsKey("dns:///foo.com:443"); + AllowedGrpcService service = allowed.services().get("dns:///foo.com:443"); + assertThat(service.configuredChannelCredentials().channelCredentials()) + .isInstanceOf(InsecureChannelCredentials.class); + assertThat(service.callCredentials().isPresent()).isFalse(); + } + + @Test + public void parseBootstrap_allowedGrpcServices_invalidChannelCreds() { + String rawData = "{\n" + + " \"xds_servers\": [\n" + + " {\n" + + " \"server_uri\": \"" + SERVER_URI + "\",\n" + + " \"channel_creds\": [{\"type\": \"insecure\"}]\n" + + " }\n" + + " ],\n" + + " \"allowed_grpc_services\": {\n" + + " \"dns:///foo.com:443\": {\n" + + " \"channel_creds\": []\n" + + " }\n" + + " }\n" + + "}"; + + bootstrapper.setFileReader(createFileReader(BOOTSTRAP_FILE_PATH, rawData)); + XdsInitializationException e = assertThrows(XdsInitializationException.class, + bootstrapper::bootstrap); + assertThat(e).hasMessageThat() + .isEqualTo("Invalid bootstrap: server dns:///foo.com:443 'channel_creds' required"); } @Test @@ -232,7 +301,7 @@ public void parseBootstrap_IgnoreIrrelevantFields() throws XdsInitializationExce } @Test - public void parseBootstrap_missingServerChannelCreds() throws XdsInitializationException { + public void parseBootstrap_missingServerChannelCreds() { String rawData = "{\n" + " \"xds_servers\": [\n" + " {\n" @@ -242,13 +311,14 @@ public void parseBootstrap_missingServerChannelCreds() throws XdsInitializationE + "}"; bootstrapper.setFileReader(createFileReader(BOOTSTRAP_FILE_PATH, rawData)); - thrown.expect(XdsInitializationException.class); - thrown.expectMessage("Invalid bootstrap: server " + SERVER_URI + " 'channel_creds' required"); - bootstrapper.bootstrap(); + XdsInitializationException e = Assert.assertThrows(XdsInitializationException.class, + bootstrapper::bootstrap); + assertThat(e).hasMessageThat() + .isEqualTo("Invalid bootstrap: server " + SERVER_URI + " 'channel_creds' required"); } @Test - public void parseBootstrap_unsupportedServerChannelCreds() throws XdsInitializationException { + public void parseBootstrap_unsupportedServerChannelCreds() { String rawData = "{\n" + " \"xds_servers\": [\n" + " {\n" @@ -261,9 +331,10 @@ public void parseBootstrap_unsupportedServerChannelCreds() throws XdsInitializat + "}"; bootstrapper.setFileReader(createFileReader(BOOTSTRAP_FILE_PATH, rawData)); - thrown.expect(XdsInitializationException.class); - thrown.expectMessage("Server " + SERVER_URI + ": no supported channel credentials found"); - bootstrapper.bootstrap(); + XdsInitializationException e = assertThrows(XdsInitializationException.class, + bootstrapper::bootstrap); + assertThat(e).hasMessageThat() + .isEqualTo("Server " + SERVER_URI + ": no supported channel credentials found"); } @Test @@ -290,7 +361,7 @@ public void parseBootstrap_useFirstSupportedChannelCredentials() } @Test - public void parseBootstrap_noXdsServers() throws XdsInitializationException { + public void parseBootstrap_noXdsServers() { String rawData = "{\n" + " \"node\": {\n" + " \"id\": \"ENVOY_NODE_ID\",\n" @@ -308,9 +379,10 @@ public void parseBootstrap_noXdsServers() throws XdsInitializationException { + "}"; bootstrapper.setFileReader(createFileReader(BOOTSTRAP_FILE_PATH, rawData)); - thrown.expect(XdsInitializationException.class); - thrown.expectMessage("Invalid bootstrap: 'xds_servers' does not exist."); - bootstrapper.bootstrap(); + XdsInitializationException e = assertThrows(XdsInitializationException.class, + bootstrapper::bootstrap); + assertThat(e).hasMessageThat() + .isEqualTo("Invalid bootstrap: 'xds_servers' does not exist."); } @Test @@ -339,15 +411,23 @@ public void parseBootstrap_serverWithoutServerUri() throws XdsInitializationExce + "}"; bootstrapper.setFileReader(createFileReader(BOOTSTRAP_FILE_PATH, rawData)); - thrown.expectMessage("Invalid bootstrap: missing 'server_uri'"); - bootstrapper.bootstrap(); + XdsInitializationException e = assertThrows(XdsInitializationException.class, + bootstrapper::bootstrap); + assertThat(e).hasMessageThat().isEqualTo("Invalid bootstrap: missing 'server_uri'"); } @Test public void parseBootstrap_certProviderInstances() throws XdsInitializationException { String rawData = "{\n" - + " \"xds_servers\": [],\n" + + " \"xds_servers\": [\n" + + " {\n" + + " \"server_uri\": \"" + SERVER_URI + "\",\n" + + " \"channel_creds\": [\n" + + " {\"type\": \"insecure\"}\n" + + " ]\n" + + " }\n" + + " ],\n" + " \"certificate_providers\": {\n" + " \"gcp_id\": {\n" + " \"plugin_name\": \"meshca\",\n" @@ -384,7 +464,6 @@ public void parseBootstrap_certProviderInstances() throws XdsInitializationExcep bootstrapper.setFileReader(createFileReader(BOOTSTRAP_FILE_PATH, rawData)); BootstrapInfo info = bootstrapper.bootstrap(); - assertThat(info.servers()).isEmpty(); assertThat(info.node()).isEqualTo(getNodeBuilder().build()); Map certProviders = info.certProviders(); assertThat(certProviders).isNotNull(); @@ -551,7 +630,14 @@ public void parseBootstrap_missingPluginName() { @Test public void parseBootstrap_grpcServerResourceId() throws XdsInitializationException { String rawData = "{\n" - + " \"xds_servers\": [],\n" + + " \"xds_servers\": [\n" + + " {\n" + + " \"server_uri\": \"" + SERVER_URI + "\",\n" + + " \"channel_creds\": [\n" + + " {\"type\": \"insecure\"}\n" + + " ]\n" + + " }\n" + + " ],\n" + " \"server_listener_resource_name_template\": \"grpc/serverx=%s\"\n" + "}"; @@ -627,6 +713,28 @@ public void serverFeatureIgnoreResourceDeletion() throws XdsInitializationExcept assertThat(serverInfo.ignoreResourceDeletion()).isTrue(); } + @Test + public void serverFeatureTrustedXdsServer() throws XdsInitializationException { + String rawData = "{\n" + + " \"xds_servers\": [\n" + + " {\n" + + " \"server_uri\": \"" + SERVER_URI + "\",\n" + + " \"channel_creds\": [\n" + + " {\"type\": \"insecure\"}\n" + + " ],\n" + + " \"server_features\": [\"trusted_xds_server\"]\n" + + " }\n" + + " ]\n" + + "}"; + + bootstrapper.setFileReader(createFileReader(BOOTSTRAP_FILE_PATH, rawData)); + BootstrapInfo info = bootstrapper.bootstrap(); + ServerInfo serverInfo = Iterables.getOnlyElement(info.servers()); + assertThat(serverInfo.target()).isEqualTo(SERVER_URI); + assertThat(serverInfo.implSpecificConfig()).isInstanceOf(InsecureChannelCredentials.class); + assertThat(serverInfo.isTrustedXdsServer()).isTrue(); + } + @Test public void serverFeatureIgnoreResourceDeletion_xdsV3() throws XdsInitializationException { String rawData = "{\n" @@ -650,6 +758,72 @@ public void serverFeatureIgnoreResourceDeletion_xdsV3() throws XdsInitialization assertThat(serverInfo.ignoreResourceDeletion()).isTrue(); } + @Test + public void serverFeatures_ignoresUnknownValues() throws XdsInitializationException { + String rawData = "{\n" + + " \"xds_servers\": [\n" + + " {\n" + + " \"server_uri\": \"" + SERVER_URI + "\",\n" + + " \"channel_creds\": [\n" + + " {\"type\": \"insecure\"}\n" + + " ],\n" + + " \"server_features\": [null, {}, 3, true, \"unexpected\", \"trusted_xds_server\"]\n" + + " }\n" + + " ]\n" + + "}"; + + bootstrapper.setFileReader(createFileReader(BOOTSTRAP_FILE_PATH, rawData)); + BootstrapInfo info = bootstrapper.bootstrap(); + ServerInfo serverInfo = Iterables.getOnlyElement(info.servers()); + assertThat(serverInfo.isTrustedXdsServer()).isTrue(); + } + + @Test + public void serverFeature_failOnDataErrors() throws XdsInitializationException { + BootstrapperImpl.xdsDataErrorHandlingEnabled = true; + String rawData = "{\n" + + " \"xds_servers\": [\n" + + " {\n" + + " \"server_uri\": \"" + SERVER_URI + "\",\n" + + " \"channel_creds\": [\n" + + " {\"type\": \"insecure\"}\n" + + " ],\n" + + " \"server_features\": [\"fail_on_data_errors\"]\n" + + " }\n" + + " ]\n" + + "}"; + + bootstrapper.setFileReader(createFileReader(BOOTSTRAP_FILE_PATH, rawData)); + BootstrapInfo info = bootstrapper.bootstrap(); + ServerInfo serverInfo = Iterables.getOnlyElement(info.servers()); + assertThat(serverInfo.target()).isEqualTo(SERVER_URI); + assertThat(serverInfo.implSpecificConfig()).isInstanceOf(InsecureChannelCredentials.class); + assertThat(serverInfo.failOnDataErrors()).isTrue(); + BootstrapperImpl.xdsDataErrorHandlingEnabled = false; + } + + @Test + public void serverFeature_failOnDataErrors_requiresEnvVar() throws XdsInitializationException { + BootstrapperImpl.xdsDataErrorHandlingEnabled = false; + String rawData = "{\n" + + " \"xds_servers\": [\n" + + " {\n" + + " \"server_uri\": \"" + SERVER_URI + "\",\n" + + " \"channel_creds\": [\n" + + " {\"type\": \"insecure\"}\n" + + " ],\n" + + " \"server_features\": [\"fail_on_data_errors\"]\n" + + " }\n" + + " ]\n" + + "}"; + + bootstrapper.setFileReader(createFileReader(BOOTSTRAP_FILE_PATH, rawData)); + BootstrapInfo info = bootstrapper.bootstrap(); + ServerInfo serverInfo = Iterables.getOnlyElement(info.servers()); + // Should be false when env var is not enabled + assertThat(serverInfo.failOnDataErrors()).isFalse(); + } + @Test public void notFound() { bootstrapper.bootstrapPathFromEnvVar = null; @@ -732,6 +906,12 @@ public void fallbackToConfigFromSysProp() throws XdsInitializationException { public void parseClientDefaultListenerResourceNameTemplate() throws Exception { String rawData = "{\n" + " \"xds_servers\": [\n" + + " {\n" + + " \"server_uri\": \"" + SERVER_URI + "\",\n" + + " \"channel_creds\": [\n" + + " {\"type\": \"insecure\"}\n" + + " ]\n" + + " }\n" + " ]\n" + "}"; bootstrapper.setFileReader(createFileReader(BOOTSTRAP_FILE_PATH, rawData)); @@ -741,6 +921,12 @@ public void parseClientDefaultListenerResourceNameTemplate() throws Exception { rawData = "{\n" + " \"client_default_listener_resource_name_template\": \"xdstp://a.com/faketype/%s\",\n" + " \"xds_servers\": [\n" + + " {\n" + + " \"server_uri\": \"" + SERVER_URI + "\",\n" + + " \"channel_creds\": [\n" + + " {\"type\": \"insecure\"}\n" + + " ]\n" + + " }\n" + " ]\n" + "}"; bootstrapper.setFileReader(createFileReader(BOOTSTRAP_FILE_PATH, rawData)); @@ -824,7 +1010,7 @@ public void parseAuthorities() throws Exception { } @Test - public void badFederationConfig() throws Exception { + public void badFederationConfig() { String rawData = "{\n" + " \"authorities\": {\n" + " \"a.com\": {\n" diff --git a/xds/src/test/java/io/grpc/xds/GrpcServiceConfigParserTest.java b/xds/src/test/java/io/grpc/xds/GrpcServiceConfigParserTest.java new file mode 100644 index 00000000000..ddfd0f19498 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/GrpcServiceConfigParserTest.java @@ -0,0 +1,757 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.Any; +import com.google.protobuf.ByteString; +import com.google.protobuf.Duration; +import io.envoyproxy.envoy.config.core.v3.GrpcService; +import io.envoyproxy.envoy.config.core.v3.HeaderValue; +import io.envoyproxy.envoy.extensions.grpc_service.call_credentials.access_token.v3.AccessTokenCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.google_default.v3.GoogleDefaultCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.insecure.v3.InsecureCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.local.v3.LocalCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.xds.v3.XdsCredentials; +import io.grpc.Attributes; +import io.grpc.CallCredentials; +import io.grpc.CompositeCallCredentials; +import io.grpc.CompositeChannelCredentials; +import io.grpc.InsecureChannelCredentials; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.SecurityLevel; +import io.grpc.Status; +import io.grpc.alts.GoogleDefaultChannelCredentials; +import io.grpc.xds.client.AllowedGrpcServices; +import io.grpc.xds.client.AllowedGrpcServices.AllowedGrpcService; +import io.grpc.xds.client.Bootstrapper.BootstrapInfo; +import io.grpc.xds.client.Bootstrapper.ServerInfo; +import io.grpc.xds.client.ConfiguredChannelCredentials; +import io.grpc.xds.client.EnvoyProtoData.Node; +import io.grpc.xds.internal.grpcservice.GrpcServiceConfig; +import io.grpc.xds.internal.grpcservice.GrpcServiceParseException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mockito; + +@RunWith(JUnit4.class) +public class GrpcServiceConfigParserTest { + + private static final String CALL_CREDENTIALS_CLASS_NAME = + "io.grpc.xds.GrpcServiceConfigParser$SecurityAwareAccessTokenCredentials"; + + private static BootstrapInfo dummyBootstrapInfo() { + return dummyBootstrapInfo(Optional.empty()); + } + + private static BootstrapInfo dummyBootstrapInfo(Optional implSpecificObject) { + return BootstrapInfo.builder() + .servers(Collections + .singletonList(ServerInfo.create("test_target", Collections.emptyMap()))) + .node(Node.newBuilder().build()).implSpecificObject(implSpecificObject).build(); + } + + private static ServerInfo dummyServerInfo() { + return dummyServerInfo(true); + } + + private static ServerInfo dummyServerInfo(boolean isTrusted) { + return ServerInfo.create("test_target", Collections.emptyMap(), false, isTrusted, false, + false); + } + + private static GrpcServiceConfig parse( + GrpcService grpcServiceProto, BootstrapInfo bootstrapInfo, + ServerInfo serverInfo) + throws GrpcServiceParseException { + return GrpcServiceConfigParser.parse(grpcServiceProto, bootstrapInfo, serverInfo); + } + + private static GrpcServiceConfig.GoogleGrpcConfig parseGoogleGrpcConfig( + GrpcService.GoogleGrpc googleGrpcProto, BootstrapInfo bootstrapInfo, + ServerInfo serverInfo) + throws GrpcServiceParseException { + return GrpcServiceConfigParser.parseGoogleGrpcConfig( + googleGrpcProto, bootstrapInfo, serverInfo); + } + + @Test + public void parse_success() throws GrpcServiceParseException { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).addCallCredentialsPlugin(accessTokenCreds) + .build(); + HeaderValue asciiHeader = + HeaderValue.newBuilder().setKey("test_key").setValue("test_value").build(); + HeaderValue binaryHeader = + HeaderValue.newBuilder().setKey("test_key-bin").setRawValue(ByteString + .copyFrom("test_value_binary".getBytes(StandardCharsets.UTF_8))).build(); + Duration timeout = Duration.newBuilder().setSeconds(10).build(); + GrpcService grpcService = + GrpcService.newBuilder().setGoogleGrpc(googleGrpc).addInitialMetadata(asciiHeader) + .addInitialMetadata(binaryHeader).setTimeout(timeout).build(); + + GrpcServiceConfig config = parse(grpcService, + dummyBootstrapInfo(), + dummyServerInfo()); + + // Assert target URI + assertThat(config.googleGrpc().target()).isEqualTo("test_uri"); + + // Assert channel credentials + assertThat(config.googleGrpc().configuredChannelCredentials().channelCredentials()) + .isInstanceOf(InsecureChannelCredentials.class); + GrpcServiceConfigParser.ProtoChannelCredsConfig credsConfig = + (GrpcServiceConfigParser.ProtoChannelCredsConfig) + config.googleGrpc().configuredChannelCredentials().channelCredsConfig(); + assertThat(credsConfig.configProto()).isEqualTo(insecureCreds); + + // Assert call credentials + assertThat(config.googleGrpc().callCredentials().isPresent()).isTrue(); + assertThat(config.googleGrpc().callCredentials().get().getClass().getName()) + .isEqualTo(CALL_CREDENTIALS_CLASS_NAME); + + // Assert initial metadata + assertThat(config.initialMetadata()).isNotEmpty(); + assertThat(config.initialMetadata().get(0).key()).isEqualTo("test_key"); + assertThat(config.initialMetadata().get(0).value().get()).isEqualTo("test_value"); + assertThat(config.initialMetadata().get(1).key()).isEqualTo("test_key-bin"); + assertThat(config.initialMetadata().get(1).rawValue().get().toByteArray()) + .isEqualTo("test_value_binary".getBytes(StandardCharsets.UTF_8)); + + // Assert timeout + assertThat(config.timeout().isPresent()).isTrue(); + assertThat(config.timeout().get()).isEqualTo(java.time.Duration.ofSeconds(10)); + } + + @Test + public void parse_minimalSuccess_defaults() throws GrpcServiceParseException { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).addCallCredentialsPlugin(accessTokenCreds) + .build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceConfig config = parse(grpcService, + dummyBootstrapInfo(), + dummyServerInfo()); + + assertThat(config.googleGrpc().target()).isEqualTo("test_uri"); + assertThat(config.initialMetadata()).isEmpty(); + assertThat(config.timeout().isPresent()).isFalse(); + } + + @Test + public void parse_missingGoogleGrpc() { + GrpcService grpcService = GrpcService.newBuilder().build(); + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> parse(grpcService, + dummyBootstrapInfo(), + dummyServerInfo())); + assertThat(exception).hasMessageThat() + .startsWith("Unsupported: GrpcService must have GoogleGrpc, got: "); + } + + @Test + public void parse_emptyCallCredentials() throws GrpcServiceParseException { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceConfig config = parse(grpcService, + dummyBootstrapInfo(), + dummyServerInfo()); + + assertThat(config.googleGrpc().callCredentials().isPresent()).isFalse(); + } + + @Test + public void parse_emptyChannelCredentials() { + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addCallCredentialsPlugin(accessTokenCreds).build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> parse(grpcService, + dummyBootstrapInfo(), + dummyServerInfo())); + assertThat(exception).hasMessageThat() + .isEqualTo("No valid supported channel_credentials found"); + } + + @Test + public void parse_googleDefaultCredentials() throws GrpcServiceParseException { + Any googleDefaultCreds = Any.pack(GoogleDefaultCredentials.getDefaultInstance()); + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(googleDefaultCreds).addCallCredentialsPlugin(accessTokenCreds) + .build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceConfig config = parse(grpcService, + dummyBootstrapInfo(), + dummyServerInfo()); + + assertThat(config.googleGrpc().configuredChannelCredentials().channelCredentials()) + .isInstanceOf(CompositeChannelCredentials.class); + GrpcServiceConfigParser.ProtoChannelCredsConfig credsConfig = + (GrpcServiceConfigParser.ProtoChannelCredsConfig) + config.googleGrpc().configuredChannelCredentials().channelCredsConfig(); + assertThat(credsConfig.configProto()).isEqualTo(googleDefaultCreds); + } + + @Test + public void parse_localCredentials() throws GrpcServiceParseException { + Any localCreds = Any.pack(LocalCredentials.getDefaultInstance()); + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(localCreds).addCallCredentialsPlugin(accessTokenCreds).build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> parse(grpcService, + dummyBootstrapInfo(), + dummyServerInfo())); + assertThat(exception).hasMessageThat() + .contains("LocalCredentials are not supported in grpc-java"); + } + + @Test + public void parse_xdsCredentials_withInsecureFallback() throws GrpcServiceParseException { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + XdsCredentials xdsCreds = + XdsCredentials.newBuilder().setFallbackCredentials(insecureCreds).build(); + Any xdsCredsAny = Any.pack(xdsCreds); + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(xdsCredsAny).addCallCredentialsPlugin(accessTokenCreds) + .build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceConfig config = GrpcServiceConfigParser.parse(grpcService, + dummyBootstrapInfo(), + dummyServerInfo()); + + assertThat(config.googleGrpc().configuredChannelCredentials().channelCredentials()) + .isNotNull(); + GrpcServiceConfigParser.ProtoChannelCredsConfig credsConfig = + (GrpcServiceConfigParser.ProtoChannelCredsConfig) + config.googleGrpc().configuredChannelCredentials().channelCredsConfig(); + assertThat(credsConfig.configProto()).isEqualTo(xdsCredsAny); + } + + @Test + public void parse_tlsCredentials_notSupported() { + Any tlsCreds = Any + .pack(io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.tls.v3.TlsCredentials + .getDefaultInstance()); + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(tlsCreds).addCallCredentialsPlugin(accessTokenCreds).build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> parse(grpcService, + dummyBootstrapInfo(), + dummyServerInfo())); + assertThat(exception).hasMessageThat() + .contains("TlsCredentials input stream construction pending"); + } + + @Test + public void parse_invalidChannelCredentialsProto() { + // Pack a Duration proto, but try to unpack it as GoogleDefaultCredentials + Any invalidCreds = Any.pack(Duration.getDefaultInstance()); + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(invalidCreds).addCallCredentialsPlugin(accessTokenCreds) + .build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> parse(grpcService, + dummyBootstrapInfo(), + dummyServerInfo())); + assertThat(exception).hasMessageThat().contains("No valid supported channel_credentials found"); + } + + @Test + public void parse_ignoredUnsupportedCallCredentialsProto() throws GrpcServiceParseException { + // Pack a Duration proto, but try to unpack it as AccessTokenCredentials + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + Any invalidCallCredentials = Any.pack(Duration.getDefaultInstance()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).addCallCredentialsPlugin(invalidCallCredentials) + .build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceConfig config = parse(grpcService, + dummyBootstrapInfo(), + dummyServerInfo()); + assertThat(config.googleGrpc().callCredentials().isPresent()).isFalse(); + } + + @Test + public void parse_invalidAccessTokenCallCredentialsProto() { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + Any invalidCallCredentials = Any.pack(AccessTokenCredentials.newBuilder().setToken("").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).addCallCredentialsPlugin(invalidCallCredentials) + .build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> parse(grpcService, + dummyBootstrapInfo(), + dummyServerInfo())); + assertThat(exception).hasMessageThat() + .contains("Missing or empty access token in call credentials"); + } + + @Test + public void parse_multipleCallCredentials() throws GrpcServiceParseException { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + Any accessTokenCreds1 = + Any.pack(AccessTokenCredentials.newBuilder().setToken("token1").build()); + Any accessTokenCreds2 = + Any.pack(AccessTokenCredentials.newBuilder().setToken("token2").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).addCallCredentialsPlugin(accessTokenCreds1) + .addCallCredentialsPlugin(accessTokenCreds2).build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceConfig config = parse(grpcService, + dummyBootstrapInfo(), + dummyServerInfo()); + + assertThat(config.googleGrpc().callCredentials().isPresent()).isTrue(); + assertThat(config.googleGrpc().callCredentials().get()) + .isInstanceOf(CompositeCallCredentials.class); + } + + @Test + public void parse_untrustedControlPlane_withoutOverride() { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + BootstrapInfo untrustedBootstrapInfo = dummyBootstrapInfo(Optional.empty()); + ServerInfo untrustedServerInfo = + dummyServerInfo(false); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> parse( + grpcService, untrustedBootstrapInfo, untrustedServerInfo)); + assertThat(exception).hasMessageThat() + .contains("Untrusted xDS server & URI not found in allowed_grpc_services"); + } + + @Test + public void parse_untrustedControlPlane_withOverride() throws GrpcServiceParseException { + // The proto credentials (insecure) should be ignored in favor of the override (google default) + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + ConfiguredChannelCredentials overrideChannelCreds = ConfiguredChannelCredentials.create( + GoogleDefaultChannelCredentials.create(), + new GrpcServiceConfigParser.ProtoChannelCredsConfig( + GrpcServiceConfigParser.GOOGLE_DEFAULT_CREDENTIALS_TYPE_URL, + Any.pack(GoogleDefaultCredentials.getDefaultInstance()))); + AllowedGrpcService override = AllowedGrpcService.builder() + .configuredChannelCredentials(overrideChannelCreds).build(); + AllowedGrpcServices servicesMap = + AllowedGrpcServices.create( + ImmutableMap.of("test_uri", override)); + + BootstrapInfo untrustedBootstrapInfo = + dummyBootstrapInfo(Optional.of(GrpcBootstrapImplConfig.create(servicesMap))); + ServerInfo untrustedServerInfo = + dummyServerInfo(false); + + GrpcServiceConfig config = + parse(grpcService, untrustedBootstrapInfo, untrustedServerInfo); + + // Assert channel credentials are the override, not the proto's insecure creds + assertThat(config.googleGrpc().configuredChannelCredentials().channelCredentials()) + .isInstanceOf(CompositeChannelCredentials.class); + } + + @Test + public void parse_invalidTimeout() { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).build(); + + // Negative timeout + Duration timeout = Duration.newBuilder().setSeconds(-10).build(); + GrpcService grpcService = GrpcService.newBuilder() + .setGoogleGrpc(googleGrpc).setTimeout(timeout).build(); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> parse(grpcService, + dummyBootstrapInfo(), + dummyServerInfo())); + assertThat(exception).hasMessageThat() + .contains("Timeout must be strictly positive"); + + // Zero timeout + timeout = Duration.newBuilder().setSeconds(0).setNanos(0).build(); + GrpcService grpcServiceZero = GrpcService.newBuilder() + .setGoogleGrpc(googleGrpc).setTimeout(timeout).build(); + + exception = assertThrows(GrpcServiceParseException.class, + () -> parse(grpcServiceZero, + dummyBootstrapInfo(), + dummyServerInfo())); + assertThat(exception).hasMessageThat() + .contains("Timeout must be strictly positive"); + } + + @Test + public void parseGoogleGrpcConfig_unsupportedScheme() { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder() + .setTargetUri("unknown://test") + .addChannelCredentialsPlugin(insecureCreds).build(); + + BootstrapInfo bootstrapInfo = dummyBootstrapInfo(); + ServerInfo serverInfo = dummyServerInfo(); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> parseGoogleGrpcConfig( + googleGrpc, bootstrapInfo, serverInfo)); + assertThat(exception).hasMessageThat() + .contains("Target URI scheme is not resolvable"); + } + + @Test + public void parse_disallowedInitialMetadata() { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).build(); + HeaderValue disallowedHeader = + HeaderValue.newBuilder().setKey("host").setValue("test_value").build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc) + .addInitialMetadata(disallowedHeader).build(); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> parse(grpcService, dummyBootstrapInfo(), dummyServerInfo())); + assertThat(exception).hasMessageThat().contains("Invalid initial metadata header: host"); + } + + @Test + public void parse_invalidDuration() { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).build(); + + Duration timeout = Duration.newBuilder().setSeconds(10).setNanos(1_000_000_000).build(); + GrpcService grpcService = GrpcService.newBuilder() + .setGoogleGrpc(googleGrpc).setTimeout(timeout).build(); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> parse(grpcService, dummyBootstrapInfo(), dummyServerInfo())); + assertThat(exception).hasMessageThat() + .contains("Timeout must be strictly positive and valid"); + } + + @Test + public void parse_invalidChannelCredsProto() { + Any invalidCreds = Any.newBuilder() + .setTypeUrl(GrpcServiceConfigParser.XDS_CREDENTIALS_TYPE_URL) + .setValue(ByteString.copyFrom(new byte[]{1, 2, 3})).build(); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(invalidCreds).build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> parse(grpcService, dummyBootstrapInfo(), dummyServerInfo())); + assertThat(exception).hasMessageThat().contains("Failed to parse channel credentials"); + } + + @Test + public void parse_unsupportedXdsFallbackCreds() { + Any unsupportedFallback = Any.pack(Duration.getDefaultInstance()); + XdsCredentials xds = + XdsCredentials.newBuilder().setFallbackCredentials(unsupportedFallback).build(); + Any xdsCredsAny = Any.newBuilder() + .setTypeUrl(GrpcServiceConfigParser.XDS_CREDENTIALS_TYPE_URL) + .setValue(xds.toByteString()).build(); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(xdsCredsAny).build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> parse(grpcService, dummyBootstrapInfo(), dummyServerInfo())); + assertThat(exception).hasMessageThat() + .contains("Unsupported fallback credentials type for XdsCredentials"); + } + + @Test + public void parse_invalidCallCredsProto() { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + // We just create an Any representing AccessTokenCredentials but with invalid bytes + Any invalidCallCreds = Any.newBuilder() + .setTypeUrl(Any.pack(AccessTokenCredentials.getDefaultInstance()).getTypeUrl()) + .setValue(ByteString.copyFrom(new byte[]{1, 2, 3})).build(); + + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).addCallCredentialsPlugin(invalidCallCreds) + .build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> parse(grpcService, dummyBootstrapInfo(), dummyServerInfo())); + assertThat(exception).hasMessageThat().contains("Failed to parse access token credentials"); + } + + @Test + public void parseGoogleGrpcConfig_malformedUriThrows() { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri(":::::") + .addChannelCredentialsPlugin(insecureCreds).build(); + + BootstrapInfo bootstrapInfo = dummyBootstrapInfo(); + ServerInfo serverInfo = dummyServerInfo(); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> parseGoogleGrpcConfig(googleGrpc, bootstrapInfo, serverInfo)); + assertThat(exception).hasMessageThat().contains("Target URI scheme is not resolvable"); + } + + @Test + public void parseGoogleGrpcConfig_untrustedWithCallCredentialsOverride() throws Exception { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).build(); + + ConfiguredChannelCredentials overrideChannelCreds = + ConfiguredChannelCredentials.create(GoogleDefaultChannelCredentials.create(), + new GrpcServiceConfigParser.ProtoChannelCredsConfig( + GrpcServiceConfigParser.GOOGLE_DEFAULT_CREDENTIALS_TYPE_URL, + Any.pack(GoogleDefaultCredentials.getDefaultInstance()))); + + CallCredentials fakeCallCreds = Mockito.mock(CallCredentials.class); + AllowedGrpcService override = AllowedGrpcService.builder() + .configuredChannelCredentials(overrideChannelCreds).callCredentials(fakeCallCreds).build(); + + AllowedGrpcServices servicesMap = + AllowedGrpcServices + .create(ImmutableMap.of("test_uri", override)); + + BootstrapInfo untrustedBootstrapInfo = + dummyBootstrapInfo(Optional.of(GrpcBootstrapImplConfig.create(servicesMap))); + ServerInfo untrustedServerInfo = dummyServerInfo(false); + + GrpcServiceConfig.GoogleGrpcConfig config = + parseGoogleGrpcConfig(googleGrpc, untrustedBootstrapInfo, untrustedServerInfo); + + assertThat(config.callCredentials().isPresent()).isTrue(); + assertThat(config.callCredentials().get()).isSameInstanceAs(fakeCallCreds); + } + + @Test + public void protoChannelCredsConfig_equalsAndHashCode() { + Any insecureCreds1 = Any.pack(InsecureCredentials.getDefaultInstance()); + Any insecureCreds2 = Any.pack(InsecureCredentials.getDefaultInstance()); + Any localCreds = Any.pack(LocalCredentials.getDefaultInstance()); + + GrpcServiceConfigParser.ProtoChannelCredsConfig config1 = + new GrpcServiceConfigParser.ProtoChannelCredsConfig("type1", insecureCreds1); + GrpcServiceConfigParser.ProtoChannelCredsConfig config1Equivalent = + new GrpcServiceConfigParser.ProtoChannelCredsConfig("type1", insecureCreds2); + GrpcServiceConfigParser.ProtoChannelCredsConfig configDifferentType = + new GrpcServiceConfigParser.ProtoChannelCredsConfig("type2", insecureCreds1); + GrpcServiceConfigParser.ProtoChannelCredsConfig configDifferentProto = + new GrpcServiceConfigParser.ProtoChannelCredsConfig("type1", localCreds); + + assertThat(config1.type()).isEqualTo("type1"); + assertThat(config1.equals(config1)).isTrue(); + assertThat(config1.equals(null)).isFalse(); + assertThat(config1.equals(new Object())).isFalse(); + assertThat(config1.equals(config1Equivalent)).isTrue(); + assertThat(config1.hashCode()).isEqualTo(config1Equivalent.hashCode()); + assertThat(config1.equals(configDifferentType)).isFalse(); + assertThat(config1.equals(configDifferentProto)).isFalse(); + } + + static class RecordingMetadataApplier extends CallCredentials.MetadataApplier { + boolean applied = false; + boolean failed = false; + Metadata appliedHeaders = null; + + @Override + public void apply(Metadata headers) { + applied = true; + appliedHeaders = headers; + } + + @Override + public void fail(Status status) { + failed = true; + } + } + + static class FakeRequestInfo extends CallCredentials.RequestInfo { + private final SecurityLevel securityLevel; + private final MethodDescriptor methodDescriptor; + + FakeRequestInfo(SecurityLevel securityLevel) { + this.securityLevel = securityLevel; + this.methodDescriptor = MethodDescriptor.newBuilder() + .setType(MethodDescriptor.MethodType.UNARY) + .setFullMethodName("test_service/test_method") + .setRequestMarshaller(new NoopMarshaller()) + .setResponseMarshaller(new NoopMarshaller()) + .build(); + } + + private static class NoopMarshaller implements MethodDescriptor.Marshaller { + @Override + public InputStream stream(T value) { + return null; + } + + @Override + public T parse(InputStream stream) { + return null; + } + } + + @Override + public MethodDescriptor getMethodDescriptor() { + return methodDescriptor; + } + + @Override + public SecurityLevel getSecurityLevel() { + return securityLevel; + } + + @Override + public String getAuthority() { + return "dummy-authority"; + } + + @Override + public Attributes getTransportAttrs() { + return Attributes.EMPTY; + } + } + + + @Test + public void securityAwareCredentials_secureConnection_appliesToken() throws Exception { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder() + .setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds) + .addCallCredentialsPlugin(accessTokenCreds) + .build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceConfig config = parse(grpcService, + dummyBootstrapInfo(), + dummyServerInfo()); + + CallCredentials creds = config.googleGrpc().callCredentials().get(); + RecordingMetadataApplier applier = new RecordingMetadataApplier(); + CountDownLatch latch = new CountDownLatch(1); + + creds.applyRequestMetadata( + new FakeRequestInfo(SecurityLevel.PRIVACY_AND_INTEGRITY), + Runnable::run, // Use direct executor to avoid async issues in test + new CallCredentials.MetadataApplier() { + @Override + public void apply(Metadata headers) { + applier.apply(headers); + latch.countDown(); + } + + @Override + public void fail(Status status) { + applier.fail(status); + latch.countDown(); + } + }); + + latch.await(5, TimeUnit.SECONDS); + assertThat(applier.applied).isTrue(); + assertThat(applier.appliedHeaders.get( + Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo("Bearer test_token"); + } + + @Test + public void securityAwareCredentials_insecureConnection_appliesEmptyMetadata() throws Exception { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder() + .setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds) + .addCallCredentialsPlugin(accessTokenCreds) + .build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceConfig config = parse(grpcService, + dummyBootstrapInfo(), + dummyServerInfo()); + + CallCredentials creds = config.googleGrpc().callCredentials().get(); + RecordingMetadataApplier applier = new RecordingMetadataApplier(); + + creds.applyRequestMetadata( + new FakeRequestInfo(SecurityLevel.NONE), + Runnable::run, + applier); + + assertThat(applier.applied).isTrue(); + assertThat(applier.appliedHeaders.get( + Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER))) + .isNull(); + } + + +} diff --git a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplDataTest.java b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplDataTest.java index 47dad474c3f..aa8ff68b760 100644 --- a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplDataTest.java +++ b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplDataTest.java @@ -18,9 +18,12 @@ import static com.google.common.truth.Truth.assertThat; import static io.envoyproxy.envoy.config.route.v3.RouteAction.ClusterSpecifierCase.CLUSTER_SPECIFIER_PLUGIN; +import static io.grpc.xds.XdsEndpointResource.GRPC_EXPERIMENTAL_XDS_DUALSTACK_ENDPOINTS; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.fail; import com.github.udpa.udpa.type.v1.TypedStruct; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; @@ -49,6 +52,7 @@ import io.envoyproxy.envoy.config.core.v3.DataSource; import io.envoyproxy.envoy.config.core.v3.HttpProtocolOptions; import io.envoyproxy.envoy.config.core.v3.Locality; +import io.envoyproxy.envoy.config.core.v3.Metadata; import io.envoyproxy.envoy.config.core.v3.PathConfigSource; import io.envoyproxy.envoy.config.core.v3.RuntimeFractionalPercent; import io.envoyproxy.envoy.config.core.v3.SelfConfigSource; @@ -83,6 +87,7 @@ import io.envoyproxy.envoy.extensions.filters.common.fault.v3.FaultDelay; import io.envoyproxy.envoy.extensions.filters.http.fault.v3.FaultAbort; import io.envoyproxy.envoy.extensions.filters.http.fault.v3.HTTPFault; +import io.envoyproxy.envoy.extensions.filters.http.gcp_authn.v3.Audience; import io.envoyproxy.envoy.extensions.filters.http.rbac.v3.RBACPerRoute; import io.envoyproxy.envoy.extensions.filters.http.router.v3.Router; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager; @@ -90,10 +95,10 @@ import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.Rds; import io.envoyproxy.envoy.extensions.load_balancing_policies.client_side_weighted_round_robin.v3.ClientSideWeightedRoundRobin; import io.envoyproxy.envoy.extensions.load_balancing_policies.wrr_locality.v3.WrrLocality; +import io.envoyproxy.envoy.extensions.transport_sockets.http_11_proxy.v3.Http11ProxyUpstreamTransport; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; -import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CertificateProviderInstance; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.DownstreamTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.SdsSecretConfig; @@ -107,9 +112,8 @@ import io.envoyproxy.envoy.type.v3.FractionalPercent; import io.envoyproxy.envoy.type.v3.FractionalPercent.DenominatorType; import io.envoyproxy.envoy.type.v3.Int64Range; -import io.grpc.ClientInterceptor; +import io.grpc.EquivalentAddressGroup; import io.grpc.InsecureChannelCredentials; -import io.grpc.LoadBalancer; import io.grpc.LoadBalancerRegistry; import io.grpc.Status.Code; import io.grpc.internal.JsonUtil; @@ -125,6 +129,8 @@ import io.grpc.xds.Endpoints.LbEndpoint; import io.grpc.xds.Endpoints.LocalityLbEndpoints; import io.grpc.xds.Filter.FilterConfig; +import io.grpc.xds.GcpAuthenticationFilter.AudienceMetadataParser.AudienceWrapper; +import io.grpc.xds.MetadataRegistry.MetadataValueParser; import io.grpc.xds.RouteLookupServiceClusterSpecifierPlugin.RlsPluginConfig; import io.grpc.xds.VirtualHost.Route; import io.grpc.xds.VirtualHost.Route.RouteAction; @@ -134,26 +140,24 @@ import io.grpc.xds.VirtualHost.Route.RouteMatch.PathMatcher; import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancerConfig; import io.grpc.xds.XdsClusterResource.CdsUpdate; +import io.grpc.xds.client.BackendMetricPropagation; import io.grpc.xds.client.Bootstrapper.ServerInfo; +import io.grpc.xds.client.LoadStatsManager2; import io.grpc.xds.client.XdsClient; import io.grpc.xds.client.XdsResourceType; import io.grpc.xds.client.XdsResourceType.ResourceInvalidException; -import io.grpc.xds.client.XdsResourceType.StructOrError; import io.grpc.xds.internal.Matchers; import io.grpc.xds.internal.Matchers.FractionMatcher; import io.grpc.xds.internal.Matchers.HeaderMatcher; +import java.net.InetSocketAddress; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; -import javax.annotation.Nullable; import org.junit.After; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -161,31 +165,32 @@ @RunWith(JUnit4.class) public class GrpcXdsClientImplDataTest { + private static final FaultFilter.Provider FAULT_FILTER_PROVIDER = new FaultFilter.Provider(); + private static final RbacFilter.Provider RBAC_FILTER_PROVIDER = new RbacFilter.Provider(); + private static final RouterFilter.Provider ROUTER_FILTER_PROVIDER = new RouterFilter.Provider(); + private static final ServerInfo LRS_SERVER_INFO = ServerInfo.create("lrs.googleapis.com", InsecureChannelCredentials.create()); + private static final String GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE = + "GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE"; - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); private final FilterRegistry filterRegistry = FilterRegistry.getDefaultRegistry(); private boolean originalEnableRouteLookup; private boolean originalEnableLeastRequest; - private boolean originalEnableWrr; + private boolean originalEnableUseSystemRootCerts; @Before public void setUp() { - originalEnableRouteLookup = XdsResourceType.enableRouteLookup; - originalEnableLeastRequest = XdsResourceType.enableLeastRequest; - assertThat(originalEnableLeastRequest).isFalse(); - originalEnableWrr = XdsResourceType.enableWrr; - assertThat(originalEnableWrr).isTrue(); + originalEnableRouteLookup = XdsRouteConfigureResource.enableRouteLookup; + originalEnableLeastRequest = XdsClusterResource.enableLeastRequest; + originalEnableUseSystemRootCerts = XdsClusterResource.enableSystemRootCerts; } @After public void tearDown() { - XdsResourceType.enableRouteLookup = originalEnableRouteLookup; - XdsResourceType.enableLeastRequest = originalEnableLeastRequest; - XdsResourceType.enableWrr = originalEnableWrr; + XdsRouteConfigureResource.enableRouteLookup = originalEnableRouteLookup; + XdsClusterResource.enableLeastRequest = originalEnableLeastRequest; + XdsClusterResource.enableSystemRootCerts = originalEnableUseSystemRootCerts; } @Test @@ -201,7 +206,7 @@ public void parseRoute_withRouteAction() { .setCluster("cluster-foo")) .build(); StructOrError struct = XdsRouteConfigureResource.parseRoute( - proto, filterRegistry, ImmutableMap.of(), ImmutableSet.of()); + proto, filterRegistry, ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getErrorDetail()).isNull(); assertThat(struct.getStruct()) .isEqualTo( @@ -209,7 +214,7 @@ public void parseRoute_withRouteAction() { RouteMatch.create(PathMatcher.fromPath("/service/method", false), Collections.emptyList(), null), RouteAction.forCluster( - "cluster-foo", Collections.emptyList(), null, null), + "cluster-foo", Collections.emptyList(), null, null, false), ImmutableMap.of())); } @@ -224,7 +229,7 @@ public void parseRoute_withNonForwardingAction() { .setNonForwardingAction(NonForwardingAction.getDefaultInstance()) .build(); StructOrError struct = XdsRouteConfigureResource.parseRoute( - proto, filterRegistry, ImmutableMap.of(), ImmutableSet.of()); + proto, filterRegistry, ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getStruct()) .isEqualTo( Route.forNonForwardingAction( @@ -243,7 +248,8 @@ public void parseRoute_withUnsupportedActionTypes() { .setRedirect(RedirectAction.getDefaultInstance()) .build(); res = XdsRouteConfigureResource.parseRoute( - redirectRoute, filterRegistry, ImmutableMap.of(), ImmutableSet.of()); + redirectRoute, filterRegistry, ImmutableMap.of(), ImmutableSet.of(), + getXdsResourceTypeArgs(true)); assertThat(res.getStruct()).isNull(); assertThat(res.getErrorDetail()) .isEqualTo("Route [route-blade] with unknown action type: REDIRECT"); @@ -255,7 +261,8 @@ public void parseRoute_withUnsupportedActionTypes() { .setDirectResponse(DirectResponseAction.getDefaultInstance()) .build(); res = XdsRouteConfigureResource.parseRoute( - directResponseRoute, filterRegistry, ImmutableMap.of(), ImmutableSet.of()); + directResponseRoute, filterRegistry, ImmutableMap.of(), ImmutableSet.of(), + getXdsResourceTypeArgs(true)); assertThat(res.getStruct()).isNull(); assertThat(res.getErrorDetail()) .isEqualTo("Route [route-blade] with unknown action type: DIRECT_RESPONSE"); @@ -267,7 +274,8 @@ public void parseRoute_withUnsupportedActionTypes() { .setFilterAction(FilterAction.getDefaultInstance()) .build(); res = XdsRouteConfigureResource.parseRoute( - filterRoute, filterRegistry, ImmutableMap.of(), ImmutableSet.of()); + filterRoute, filterRegistry, ImmutableMap.of(), ImmutableSet.of(), + getXdsResourceTypeArgs(true)); assertThat(res.getStruct()).isNull(); assertThat(res.getErrorDetail()) .isEqualTo("Route [route-blade] with unknown action type: FILTER_ACTION"); @@ -289,7 +297,8 @@ public void parseRoute_skipRouteWithUnsupportedMatcher() { .setCluster("cluster-foo")) .build(); assertThat(XdsRouteConfigureResource.parseRoute( - proto, filterRegistry, ImmutableMap.of(), ImmutableSet.of())) + proto, filterRegistry, ImmutableMap.of(), ImmutableSet.of(), + getXdsResourceTypeArgs(true))) .isNull(); } @@ -306,7 +315,8 @@ public void parseRoute_skipRouteWithUnsupportedAction() { .setClusterHeader("cluster header")) // cluster_header action not supported .build(); assertThat(XdsRouteConfigureResource.parseRoute( - proto, filterRegistry, ImmutableMap.of(), ImmutableSet.of())) + proto, filterRegistry, ImmutableMap.of(), ImmutableSet.of(), + getXdsResourceTypeArgs(true))) .isNull(); } @@ -516,10 +526,48 @@ public void parseRouteAction_withCluster() { .build(); StructOrError struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getErrorDetail()).isNull(); assertThat(struct.getStruct().cluster()).isEqualTo("cluster-foo"); assertThat(struct.getStruct().weightedClusters()).isNull(); + assertThat(struct.getStruct().autoHostRewrite()).isFalse(); + } + + @Test + public void parseRouteAction_withCluster_autoHostRewriteEnabled() { + System.setProperty(GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE, "true"); + try { + io.envoyproxy.envoy.config.route.v3.RouteAction proto = + io.envoyproxy.envoy.config.route.v3.RouteAction.newBuilder() + .setCluster("cluster-foo") + .setAutoHostRewrite(BoolValue.of(true)) + .build(); + StructOrError struct = + XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); + assertThat(struct.getErrorDetail()).isNull(); + assertThat(struct.getStruct().cluster()).isEqualTo("cluster-foo"); + assertThat(struct.getStruct().weightedClusters()).isNull(); + assertThat(struct.getStruct().autoHostRewrite()).isTrue(); + } finally { + System.clearProperty(GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE); + } + } + + @Test + public void parseRouteAction_withCluster_flagDisabled_autoHostRewriteNotEnabled() { + io.envoyproxy.envoy.config.route.v3.RouteAction proto = + io.envoyproxy.envoy.config.route.v3.RouteAction.newBuilder() + .setCluster("cluster-foo") + .setAutoHostRewrite(BoolValue.of(true)) + .build(); + StructOrError struct = + XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); + assertThat(struct.getErrorDetail()).isNull(); + assertThat(struct.getStruct().cluster()).isEqualTo("cluster-foo"); + assertThat(struct.getStruct().weightedClusters()).isNull(); + assertThat(struct.getStruct().autoHostRewrite()).isTrue(); } @Test @@ -540,12 +588,74 @@ public void parseRouteAction_withWeightedCluster() { .build(); StructOrError struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getErrorDetail()).isNull(); assertThat(struct.getStruct().cluster()).isNull(); assertThat(struct.getStruct().weightedClusters()).containsExactly( ClusterWeight.create("cluster-foo", 30, ImmutableMap.of()), ClusterWeight.create("cluster-bar", 70, ImmutableMap.of())); + assertThat(struct.getStruct().autoHostRewrite()).isFalse(); + } + + @Test + public void parseRouteAction_withWeightedCluster_autoHostRewriteEnabled() { + System.setProperty(GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE, "true"); + try { + io.envoyproxy.envoy.config.route.v3.RouteAction proto = + io.envoyproxy.envoy.config.route.v3.RouteAction.newBuilder() + .setWeightedClusters( + WeightedCluster.newBuilder() + .addClusters( + WeightedCluster.ClusterWeight + .newBuilder() + .setName("cluster-foo") + .setWeight(UInt32Value.newBuilder().setValue(30))) + .addClusters(WeightedCluster.ClusterWeight + .newBuilder() + .setName("cluster-bar") + .setWeight(UInt32Value.newBuilder().setValue(70)))) + .setAutoHostRewrite(BoolValue.of(true)) + .build(); + StructOrError struct = + XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); + assertThat(struct.getErrorDetail()).isNull(); + assertThat(struct.getStruct().cluster()).isNull(); + assertThat(struct.getStruct().weightedClusters()).containsExactly( + ClusterWeight.create("cluster-foo", 30, ImmutableMap.of()), + ClusterWeight.create("cluster-bar", 70, ImmutableMap.of())); + assertThat(struct.getStruct().autoHostRewrite()).isTrue(); + } finally { + System.clearProperty(GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE); + } + } + + @Test + public void parseRouteAction_withWeightedCluster_flagDisabled_autoHostRewriteDisabled() { + io.envoyproxy.envoy.config.route.v3.RouteAction proto = + io.envoyproxy.envoy.config.route.v3.RouteAction.newBuilder() + .setWeightedClusters( + WeightedCluster.newBuilder() + .addClusters( + WeightedCluster.ClusterWeight + .newBuilder() + .setName("cluster-foo") + .setWeight(UInt32Value.newBuilder().setValue(30))) + .addClusters(WeightedCluster.ClusterWeight + .newBuilder() + .setName("cluster-bar") + .setWeight(UInt32Value.newBuilder().setValue(70)))) + .setAutoHostRewrite(BoolValue.of(true)) + .build(); + StructOrError struct = + XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); + assertThat(struct.getErrorDetail()).isNull(); + assertThat(struct.getStruct().cluster()).isNull(); + assertThat(struct.getStruct().weightedClusters()).containsExactly( + ClusterWeight.create("cluster-foo", 30, ImmutableMap.of()), + ClusterWeight.create("cluster-bar", 70, ImmutableMap.of())); + assertThat(struct.getStruct().autoHostRewrite()).isTrue(); } @Test @@ -566,7 +676,7 @@ public void parseRouteAction_weightedClusterSum() { .build(); StructOrError struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getErrorDetail()).isEqualTo("Sum of cluster weights should be above 0."); } @@ -582,7 +692,7 @@ public void parseRouteAction_withTimeoutByGrpcTimeoutHeaderMax() { .build(); StructOrError struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getStruct().timeoutNano()).isEqualTo(TimeUnit.SECONDS.toNanos(5L)); } @@ -597,7 +707,7 @@ public void parseRouteAction_withTimeoutByMaxStreamDuration() { .build(); StructOrError struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getStruct().timeoutNano()).isEqualTo(TimeUnit.SECONDS.toNanos(5L)); } @@ -609,7 +719,7 @@ public void parseRouteAction_withTimeoutUnset() { .build(); StructOrError struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getStruct().timeoutNano()).isNull(); } @@ -631,7 +741,7 @@ public void parseRouteAction_withRetryPolicy() { .build(); StructOrError struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); RouteAction.RetryPolicy retryPolicy = struct.getStruct().retryPolicy(); assertThat(retryPolicy.maxAttempts()).isEqualTo(4); assertThat(retryPolicy.initialBackoff()).isEqualTo(Durations.fromMillis(500)); @@ -655,7 +765,7 @@ public void parseRouteAction_withRetryPolicy() { .setRetryPolicy(builder.build()) .build(); struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getStruct().retryPolicy()).isNotNull(); assertThat(struct.getStruct().retryPolicy().retryableStatusCodes()).isEmpty(); @@ -668,7 +778,7 @@ public void parseRouteAction_withRetryPolicy() { .setRetryPolicy(builder) .build(); struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getErrorDetail()).isEqualTo("No base_interval specified in retry_backoff"); // max_interval unset @@ -678,7 +788,7 @@ public void parseRouteAction_withRetryPolicy() { .setRetryPolicy(builder) .build(); struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); retryPolicy = struct.getStruct().retryPolicy(); assertThat(retryPolicy.maxBackoff()).isEqualTo(Durations.fromMillis(500 * 10)); @@ -689,7 +799,7 @@ public void parseRouteAction_withRetryPolicy() { .setRetryPolicy(builder) .build(); struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getErrorDetail()) .isEqualTo("base_interval in retry_backoff must be positive"); @@ -702,7 +812,7 @@ public void parseRouteAction_withRetryPolicy() { .setRetryPolicy(builder) .build(); struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getErrorDetail()) .isEqualTo("max_interval in retry_backoff cannot be less than base_interval"); @@ -715,7 +825,7 @@ public void parseRouteAction_withRetryPolicy() { .setRetryPolicy(builder) .build(); struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getErrorDetail()) .isEqualTo("max_interval in retry_backoff cannot be less than base_interval"); @@ -728,7 +838,7 @@ public void parseRouteAction_withRetryPolicy() { .setRetryPolicy(builder) .build(); struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getStruct().retryPolicy().initialBackoff()) .isEqualTo(Durations.fromMillis(1)); assertThat(struct.getStruct().retryPolicy().maxBackoff()) @@ -744,7 +854,7 @@ public void parseRouteAction_withRetryPolicy() { .setRetryPolicy(builder) .build(); struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); retryPolicy = struct.getStruct().retryPolicy(); assertThat(retryPolicy.initialBackoff()).isEqualTo(Durations.fromMillis(25)); assertThat(retryPolicy.maxBackoff()).isEqualTo(Durations.fromMillis(250)); @@ -763,7 +873,7 @@ public void parseRouteAction_withRetryPolicy() { .setRetryPolicy(builder) .build(); struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getStruct().retryPolicy().retryableStatusCodes()) .containsExactly(Code.CANCELLED); @@ -781,7 +891,7 @@ public void parseRouteAction_withRetryPolicy() { .setRetryPolicy(builder) .build(); struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getStruct().retryPolicy().retryableStatusCodes()) .containsExactly(Code.CANCELLED); @@ -799,7 +909,7 @@ public void parseRouteAction_withRetryPolicy() { .setRetryPolicy(builder) .build(); struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct.getStruct().retryPolicy().retryableStatusCodes()) .containsExactly(Code.CANCELLED); } @@ -830,7 +940,7 @@ public void parseRouteAction_withHashPolicies() { io.envoyproxy.envoy.config.route.v3.RouteAction.HashPolicy.newBuilder() .setFilterState( FilterState.newBuilder() - .setKey(XdsResourceType.HASH_POLICY_FILTER_STATE_KEY))) + .setKey(XdsRouteConfigureResource.HASH_POLICY_FILTER_STATE_KEY))) .addHashPolicy( io.envoyproxy.envoy.config.route.v3.RouteAction.HashPolicy.newBuilder() .setQueryParameter( @@ -838,7 +948,7 @@ public void parseRouteAction_withHashPolicies() { .build(); StructOrError struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); List policies = struct.getStruct().hashPolicies(); assertThat(policies).hasSize(2); assertThat(policies.get(0).type()).isEqualTo(HashPolicy.Type.HEADER); @@ -858,23 +968,78 @@ public void parseRouteAction_custerSpecifierNotSet() { .build(); StructOrError struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct).isNull(); } @Test public void parseRouteAction_clusterSpecifier_routeLookupDisabled() { - XdsResourceType.enableRouteLookup = false; + XdsRouteConfigureResource.enableRouteLookup = false; io.envoyproxy.envoy.config.route.v3.RouteAction proto = io.envoyproxy.envoy.config.route.v3.RouteAction.newBuilder() .setClusterSpecifierPlugin(CLUSTER_SPECIFIER_PLUGIN.name()) .build(); StructOrError struct = XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, - ImmutableMap.of(), ImmutableSet.of()); + ImmutableMap.of(), ImmutableSet.of(), getXdsResourceTypeArgs(true)); assertThat(struct).isNull(); } + @Test + public void parseRouteAction_clusterSpecifier() { + XdsRouteConfigureResource.enableRouteLookup = true; + io.envoyproxy.envoy.config.route.v3.RouteAction proto = + io.envoyproxy.envoy.config.route.v3.RouteAction.newBuilder() + .setClusterSpecifierPlugin(CLUSTER_SPECIFIER_PLUGIN.name()) + .build(); + StructOrError struct = + XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, + ImmutableMap.of(CLUSTER_SPECIFIER_PLUGIN.name(), RlsPluginConfig.create( + ImmutableMap.of("lookupService", "rls-cbt.googleapis.com"))), ImmutableSet.of(), + getXdsResourceTypeArgs(true)); + assertThat(struct.getStruct()).isNotNull(); + assertThat(struct.getStruct().autoHostRewrite()).isFalse(); + } + + @Test + public void parseRouteAction_clusterSpecifier_autoHostRewriteEnabled() { + System.setProperty(GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE, "true"); + try { + XdsRouteConfigureResource.enableRouteLookup = true; + io.envoyproxy.envoy.config.route.v3.RouteAction proto = + io.envoyproxy.envoy.config.route.v3.RouteAction.newBuilder() + .setClusterSpecifierPlugin(CLUSTER_SPECIFIER_PLUGIN.name()) + .setAutoHostRewrite(BoolValue.of(true)) + .build(); + StructOrError struct = + XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, + ImmutableMap.of(CLUSTER_SPECIFIER_PLUGIN.name(), RlsPluginConfig.create( + ImmutableMap.of("lookupService", "rls-cbt.googleapis.com"))), ImmutableSet.of(), + getXdsResourceTypeArgs(true)); + assertThat(struct.getStruct()).isNotNull(); + assertThat(struct.getStruct().autoHostRewrite()).isTrue(); + } finally { + System.clearProperty(GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE); + } + } + + @Test + public void parseRouteAction_clusterSpecifier_flagDisabled_autoHostRewriteDisabled() { + XdsRouteConfigureResource.enableRouteLookup = true; + io.envoyproxy.envoy.config.route.v3.RouteAction proto = + io.envoyproxy.envoy.config.route.v3.RouteAction.newBuilder() + .setClusterSpecifierPlugin(CLUSTER_SPECIFIER_PLUGIN.name()) + .setAutoHostRewrite(BoolValue.of(true)) + .build(); + StructOrError struct = + XdsRouteConfigureResource.parseRouteAction(proto, filterRegistry, + ImmutableMap.of(CLUSTER_SPECIFIER_PLUGIN.name(), RlsPluginConfig.create( + ImmutableMap.of("lookupService", "rls-cbt.googleapis.com"))), ImmutableSet.of(), + getXdsResourceTypeArgs(true)); + assertThat(struct.getStruct()).isNotNull(); + assertThat(struct.getStruct().autoHostRewrite()).isTrue(); + } + @Test public void parseClusterWeight() { io.envoyproxy.envoy.config.route.v3.WeightedCluster.ClusterWeight proto = @@ -883,13 +1048,15 @@ public void parseClusterWeight() { .setWeight(UInt32Value.newBuilder().setValue(30)) .build(); ClusterWeight clusterWeight = - XdsRouteConfigureResource.parseClusterWeight(proto, filterRegistry).getStruct(); + XdsRouteConfigureResource + .parseClusterWeight(proto, filterRegistry, getXdsResourceTypeArgs(true)) + .getStruct(); assertThat(clusterWeight.name()).isEqualTo("cluster-foo"); assertThat(clusterWeight.weight()).isEqualTo(30); } @Test - public void parseLocalityLbEndpoints_withHealthyEndpoints() { + public void parseLocalityLbEndpoints_withHealthyEndpoints() throws ResourceInvalidException { io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints proto = io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints.newBuilder() .setLocality(Locality.newBuilder() @@ -909,11 +1076,38 @@ public void parseLocalityLbEndpoints_withHealthyEndpoints() { assertThat(struct.getErrorDetail()).isNull(); assertThat(struct.getStruct()).isEqualTo( LocalityLbEndpoints.create( - Collections.singletonList(LbEndpoint.create("172.14.14.5", 8888, 20, true)), 100, 1)); + Collections.singletonList(LbEndpoint.create("172.14.14.5", 8888, + 20, true, "", ImmutableMap.of())), + 100, 1, ImmutableMap.of())); } @Test - public void parseLocalityLbEndpoints_treatUnknownHealthAsHealthy() { + public void parseLocalityLbEndpoints_onlyPermitIp() { + io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints proto = + io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints.newBuilder() + .setLocality(Locality.newBuilder() + .setRegion("region-foo").setZone("zone-foo").setSubZone("subZone-foo")) + .setLoadBalancingWeight(UInt32Value.newBuilder().setValue(100)) // locality weight + .setPriority(1) + .addLbEndpoints(io.envoyproxy.envoy.config.endpoint.v3.LbEndpoint.newBuilder() + .setEndpoint(Endpoint.newBuilder() + .setAddress(Address.newBuilder() + .setSocketAddress( + SocketAddress.newBuilder() + .setAddress("example.com").setPortValue(8888)))) + .setHealthStatus(io.envoyproxy.envoy.config.core.v3.HealthStatus.HEALTHY) + .setLoadBalancingWeight(UInt32Value.newBuilder().setValue(20))) // endpoint weight + .build(); + ResourceInvalidException ex = assertThrows( + ResourceInvalidException.class, + () -> XdsEndpointResource.parseLocalityLbEndpoints(proto)); + assertThat(ex.getMessage()).contains("IP"); + assertThat(ex.getMessage()).contains("example.com"); + } + + @Test + public void parseLocalityLbEndpoints_treatUnknownHealthAsHealthy() + throws ResourceInvalidException { io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints proto = io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints.newBuilder() .setLocality(Locality.newBuilder() @@ -933,11 +1127,13 @@ public void parseLocalityLbEndpoints_treatUnknownHealthAsHealthy() { assertThat(struct.getErrorDetail()).isNull(); assertThat(struct.getStruct()).isEqualTo( LocalityLbEndpoints.create( - Collections.singletonList(LbEndpoint.create("172.14.14.5", 8888, 20, true)), 100, 1)); + Collections.singletonList(LbEndpoint.create("172.14.14.5", 8888, + 20, true, "", ImmutableMap.of())), + 100, 1, ImmutableMap.of())); } @Test - public void parseLocalityLbEndpoints_withUnHealthyEndpoints() { + public void parseLocalityLbEndpoints_withUnHealthyEndpoints() throws ResourceInvalidException { io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints proto = io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints.newBuilder() .setLocality(Locality.newBuilder() @@ -957,11 +1153,13 @@ public void parseLocalityLbEndpoints_withUnHealthyEndpoints() { assertThat(struct.getErrorDetail()).isNull(); assertThat(struct.getStruct()).isEqualTo( LocalityLbEndpoints.create( - Collections.singletonList(LbEndpoint.create("172.14.14.5", 8888, 20, false)), 100, 1)); + Collections.singletonList(LbEndpoint.create("172.14.14.5", 8888, 20, + false, "", ImmutableMap.of())), + 100, 1, ImmutableMap.of())); } @Test - public void parseLocalityLbEndpoints_ignorZeroWeightLocality() { + public void parseLocalityLbEndpoints_ignorZeroWeightLocality() throws ResourceInvalidException { io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints proto = io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints.newBuilder() .setLocality(Locality.newBuilder() @@ -981,7 +1179,58 @@ public void parseLocalityLbEndpoints_ignorZeroWeightLocality() { } @Test - public void parseLocalityLbEndpoints_invalidPriority() { + public void parseLocalityLbEndpoints_withDualStackEndpoints() { + String originalDualStackProp = + System.setProperty(GRPC_EXPERIMENTAL_XDS_DUALSTACK_ENDPOINTS, "true"); + String v4Address = "172.14.14.5"; + String v6Address = "2001:db8::1"; + int port = 8888; + + try { + io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints proto = + io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints.newBuilder() + .setLocality(Locality.newBuilder() + .setRegion("region-foo").setZone("zone-foo").setSubZone("subZone-foo")) + .setLoadBalancingWeight(UInt32Value.newBuilder().setValue(100)) // locality weight + .setPriority(1) + .addLbEndpoints(io.envoyproxy.envoy.config.endpoint.v3.LbEndpoint.newBuilder() + .setEndpoint(Endpoint.newBuilder() + .setAddress(Address.newBuilder() + .setSocketAddress( + SocketAddress.newBuilder() + .setAddress(v4Address).setPortValue(port))) + .addAdditionalAddresses(Endpoint.AdditionalAddress.newBuilder() + .setAddress(Address.newBuilder() + .setSocketAddress( + SocketAddress.newBuilder() + .setAddress(v6Address).setPortValue(port))))) + .setHealthStatus(io.envoyproxy.envoy.config.core.v3.HealthStatus.HEALTHY) + .setLoadBalancingWeight(UInt32Value.newBuilder().setValue(20))) + .build(); + + StructOrError struct = + XdsEndpointResource.parseLocalityLbEndpoints(proto); + assertThat(struct.getErrorDetail()).isNull(); + List socketAddressList = Arrays.asList( + new InetSocketAddress(v4Address, port), new InetSocketAddress(v6Address, port)); + EquivalentAddressGroup expectedEag = new EquivalentAddressGroup(socketAddressList); + assertThat(struct.getStruct()).isEqualTo( + LocalityLbEndpoints.create( + Collections.singletonList(LbEndpoint.create( + expectedEag, 20, true, "", ImmutableMap.of())), 100, 1, ImmutableMap.of())); + } catch (ResourceInvalidException e) { + throw new RuntimeException(e); + } finally { + if (originalDualStackProp != null) { + System.setProperty(GRPC_EXPERIMENTAL_XDS_DUALSTACK_ENDPOINTS, originalDualStackProp); + } else { + System.clearProperty(GRPC_EXPERIMENTAL_XDS_DUALSTACK_ENDPOINTS); + } + } + } + + @Test + public void parseLocalityLbEndpoints_invalidPriority() throws ResourceInvalidException { io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints proto = io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints.newBuilder() .setLocality(Locality.newBuilder() @@ -1007,7 +1256,8 @@ public void parseHttpFilter_unsupportedButOptional() { .setIsOptional(true) .setTypedConfig(Any.pack(StringValue.of("unsupported"))) .build(); - assertThat(XdsListenerResource.parseHttpFilter(httpFilter, filterRegistry, true)).isNull(); + assertThat(XdsListenerResource.parseHttpFilter(httpFilter, filterRegistry, true, + getXdsResourceTypeArgs(true))).isNull(); } private static class SimpleFilterConfig implements FilterConfig { @@ -1027,37 +1277,41 @@ public String typeUrl() { } } - private static class TestFilter implements io.grpc.xds.Filter, - io.grpc.xds.Filter.ClientInterceptorBuilder { - @Override - public String[] typeUrls() { - return new String[]{"test-url"}; - } + private static class TestFilter implements io.grpc.xds.Filter { - @Override - public ConfigOrError parseFilterConfig(Message rawProtoMessage) { - return ConfigOrError.fromConfig(new SimpleFilterConfig(rawProtoMessage)); - } + static final class Provider implements io.grpc.xds.Filter.Provider { + @Override + public String[] typeUrls() { + return new String[]{"test-url"}; + } - @Override - public ConfigOrError parseFilterConfigOverride( - Message rawProtoMessage) { - return ConfigOrError.fromConfig(new SimpleFilterConfig(rawProtoMessage)); - } + @Override + public boolean isClientFilter() { + return true; + } - @Nullable - @Override - public ClientInterceptor buildClientInterceptor(FilterConfig config, - @Nullable FilterConfig overrideConfig, - LoadBalancer.PickSubchannelArgs args, - ScheduledExecutorService scheduler) { - return null; + @Override + public TestFilter newInstance(String name) { + return new TestFilter(); + } + + @Override + public ConfigOrError parseFilterConfig(Message rawProtoMessage, + FilterConfigParseContext context) { + return ConfigOrError.fromConfig(new SimpleFilterConfig(rawProtoMessage)); + } + + @Override + public ConfigOrError parseFilterConfigOverride(Message rawProtoMessage, + FilterConfigParseContext context) { + return ConfigOrError.fromConfig(new SimpleFilterConfig(rawProtoMessage)); + } } } @Test public void parseHttpFilter_typedStructMigration() { - filterRegistry.register(new TestFilter()); + filterRegistry.register(new TestFilter.Provider()); Struct rawStruct = Struct.newBuilder() .putFields("name", Value.newBuilder().setStringValue("default").build()) .build(); @@ -1069,7 +1323,7 @@ public void parseHttpFilter_typedStructMigration() { .setValue(rawStruct) .build())).build(); FilterConfig config = XdsListenerResource.parseHttpFilter(httpFilter, filterRegistry, - true).getStruct(); + true, getXdsResourceTypeArgs(true)).getStruct(); assertThat(((SimpleFilterConfig)config).getConfig()).isEqualTo(rawStruct); HttpFilter httpFilterNewTypeStruct = HttpFilter.newBuilder() @@ -1080,13 +1334,13 @@ public void parseHttpFilter_typedStructMigration() { .setValue(rawStruct) .build())).build(); config = XdsListenerResource.parseHttpFilter(httpFilterNewTypeStruct, filterRegistry, - true).getStruct(); + true, getXdsResourceTypeArgs(true)).getStruct(); assertThat(((SimpleFilterConfig)config).getConfig()).isEqualTo(rawStruct); } @Test public void parseOverrideHttpFilter_typedStructMigration() { - filterRegistry.register(new TestFilter()); + filterRegistry.register(new TestFilter.Provider()); Struct rawStruct0 = Struct.newBuilder() .putFields("name", Value.newBuilder().setStringValue("default0").build()) .build(); @@ -1106,7 +1360,7 @@ public void parseOverrideHttpFilter_typedStructMigration() { .build()) ); Map map = XdsRouteConfigureResource.parseOverrideFilterConfigs( - rawFilterMap, filterRegistry).getStruct(); + rawFilterMap, filterRegistry, getXdsResourceTypeArgs(true)).getStruct(); assertThat(((SimpleFilterConfig)map.get("struct-0")).getConfig()).isEqualTo(rawStruct0); assertThat(((SimpleFilterConfig)map.get("struct-1")).getConfig()).isEqualTo(rawStruct1); } @@ -1118,7 +1372,8 @@ public void parseHttpFilter_unsupportedAndRequired() { .setName("unsupported.filter") .setTypedConfig(Any.pack(StringValue.of("string value"))) .build(); - assertThat(XdsListenerResource.parseHttpFilter(httpFilter, filterRegistry, true) + assertThat(XdsListenerResource + .parseHttpFilter(httpFilter, filterRegistry, true, getXdsResourceTypeArgs(true)) .getErrorDetail()).isEqualTo( "HttpFilter [unsupported.filter]" + "(type.googleapis.com/google.protobuf.StringValue) is required but unsupported " @@ -1127,7 +1382,7 @@ public void parseHttpFilter_unsupportedAndRequired() { @Test public void parseHttpFilter_routerFilterForClient() { - filterRegistry.register(RouterFilter.INSTANCE); + filterRegistry.register(ROUTER_FILTER_PROVIDER); HttpFilter httpFilter = HttpFilter.newBuilder() .setIsOptional(false) @@ -1135,13 +1390,14 @@ public void parseHttpFilter_routerFilterForClient() { .setTypedConfig(Any.pack(Router.getDefaultInstance())) .build(); FilterConfig config = XdsListenerResource.parseHttpFilter( - httpFilter, filterRegistry, true /* isForClient */).getStruct(); + httpFilter, filterRegistry, true /* isForClient */, getXdsResourceTypeArgs(true)) + .getStruct(); assertThat(config.typeUrl()).isEqualTo(RouterFilter.TYPE_URL); } @Test public void parseHttpFilter_routerFilterForServer() { - filterRegistry.register(RouterFilter.INSTANCE); + filterRegistry.register(ROUTER_FILTER_PROVIDER); HttpFilter httpFilter = HttpFilter.newBuilder() .setIsOptional(false) @@ -1149,13 +1405,14 @@ public void parseHttpFilter_routerFilterForServer() { .setTypedConfig(Any.pack(Router.getDefaultInstance())) .build(); FilterConfig config = XdsListenerResource.parseHttpFilter( - httpFilter, filterRegistry, false /* isForClient */).getStruct(); + httpFilter, filterRegistry, false /* isForClient */, getXdsResourceTypeArgs(false)) + .getStruct(); assertThat(config.typeUrl()).isEqualTo(RouterFilter.TYPE_URL); } @Test public void parseHttpFilter_faultConfigForClient() { - filterRegistry.register(FaultFilter.INSTANCE); + filterRegistry.register(FAULT_FILTER_PROVIDER); HttpFilter httpFilter = HttpFilter.newBuilder() .setIsOptional(false) @@ -1176,13 +1433,14 @@ public void parseHttpFilter_faultConfigForClient() { .build())) .build(); FilterConfig config = XdsListenerResource.parseHttpFilter( - httpFilter, filterRegistry, true /* isForClient */).getStruct(); + httpFilter, filterRegistry, true /* isForClient */, getXdsResourceTypeArgs(true)) + .getStruct(); assertThat(config).isInstanceOf(FaultConfig.class); } @Test public void parseHttpFilter_faultConfigUnsupportedForServer() { - filterRegistry.register(FaultFilter.INSTANCE); + filterRegistry.register(FAULT_FILTER_PROVIDER); HttpFilter httpFilter = HttpFilter.newBuilder() .setIsOptional(false) @@ -1203,7 +1461,8 @@ public void parseHttpFilter_faultConfigUnsupportedForServer() { .build())) .build(); StructOrError config = - XdsListenerResource.parseHttpFilter(httpFilter, filterRegistry, false /* isForClient */); + XdsListenerResource.parseHttpFilter(httpFilter, filterRegistry, false /* isForClient */, + getXdsResourceTypeArgs(false)); assertThat(config.getErrorDetail()).isEqualTo( "HttpFilter [envoy.fault](" + FaultFilter.TYPE_URL + ") is required but " + "unsupported for server"); @@ -1211,7 +1470,7 @@ public void parseHttpFilter_faultConfigUnsupportedForServer() { @Test public void parseHttpFilter_rbacConfigForServer() { - filterRegistry.register(RbacFilter.INSTANCE); + filterRegistry.register(RBAC_FILTER_PROVIDER); HttpFilter httpFilter = HttpFilter.newBuilder() .setIsOptional(false) @@ -1232,13 +1491,14 @@ public void parseHttpFilter_rbacConfigForServer() { .build())) .build(); FilterConfig config = XdsListenerResource.parseHttpFilter( - httpFilter, filterRegistry, false /* isForClient */).getStruct(); + httpFilter, filterRegistry, false /* isForClient */, getXdsResourceTypeArgs(false)) + .getStruct(); assertThat(config).isInstanceOf(RbacConfig.class); } @Test public void parseHttpFilter_rbacConfigUnsupportedForClient() { - filterRegistry.register(RbacFilter.INSTANCE); + filterRegistry.register(RBAC_FILTER_PROVIDER); HttpFilter httpFilter = HttpFilter.newBuilder() .setIsOptional(false) @@ -1259,7 +1519,8 @@ public void parseHttpFilter_rbacConfigUnsupportedForClient() { .build())) .build(); StructOrError config = - XdsListenerResource.parseHttpFilter(httpFilter, filterRegistry, true /* isForClient */); + XdsListenerResource.parseHttpFilter(httpFilter, filterRegistry, true /* isForClient */, + getXdsResourceTypeArgs(true)); assertThat(config.getErrorDetail()).isEqualTo( "HttpFilter [envoy.auth](" + RbacFilter.TYPE_URL + ") is required but " + "unsupported for client"); @@ -1267,7 +1528,7 @@ public void parseHttpFilter_rbacConfigUnsupportedForClient() { @Test public void parseOverrideRbacFilterConfig() { - filterRegistry.register(RbacFilter.INSTANCE); + filterRegistry.register(RBAC_FILTER_PROVIDER); RBACPerRoute rbacPerRoute = RBACPerRoute.newBuilder() .setRbac( @@ -1284,7 +1545,8 @@ public void parseOverrideRbacFilterConfig() { .build(); Map configOverrides = ImmutableMap.of("envoy.auth", Any.pack(rbacPerRoute)); Map parsedConfigs = - XdsRouteConfigureResource.parseOverrideFilterConfigs(configOverrides, filterRegistry) + XdsRouteConfigureResource.parseOverrideFilterConfigs(configOverrides, filterRegistry, + getXdsResourceTypeArgs(true)) .getStruct(); assertThat(parsedConfigs).hasSize(1); assertThat(parsedConfigs).containsKey("envoy.auth"); @@ -1293,7 +1555,7 @@ public void parseOverrideRbacFilterConfig() { @Test public void parseOverrideFilterConfigs_unsupportedButOptional() { - filterRegistry.register(FaultFilter.INSTANCE); + filterRegistry.register(FAULT_FILTER_PROVIDER); HTTPFault httpFault = HTTPFault.newBuilder() .setDelay(FaultDelay.newBuilder().setFixedDelay(Durations.fromNanos(3000))) .build(); @@ -1305,7 +1567,8 @@ public void parseOverrideFilterConfigs_unsupportedButOptional() { .setIsOptional(true).setConfig(Any.pack(StringValue.of("string value"))) .build())); Map parsedConfigs = - XdsRouteConfigureResource.parseOverrideFilterConfigs(configOverrides, filterRegistry) + XdsRouteConfigureResource.parseOverrideFilterConfigs(configOverrides, filterRegistry, + getXdsResourceTypeArgs(true)) .getStruct(); assertThat(parsedConfigs).hasSize(1); assertThat(parsedConfigs).containsKey("envoy.fault"); @@ -1313,7 +1576,7 @@ public void parseOverrideFilterConfigs_unsupportedButOptional() { @Test public void parseOverrideFilterConfigs_unsupportedAndRequired() { - filterRegistry.register(FaultFilter.INSTANCE); + filterRegistry.register(FAULT_FILTER_PROVIDER); HTTPFault httpFault = HTTPFault.newBuilder() .setDelay(FaultDelay.newBuilder().setFixedDelay(Durations.fromNanos(3000))) .build(); @@ -1324,7 +1587,9 @@ public void parseOverrideFilterConfigs_unsupportedAndRequired() { Any.pack(io.envoyproxy.envoy.config.route.v3.FilterConfig.newBuilder() .setIsOptional(false).setConfig(Any.pack(StringValue.of("string value"))) .build())); - assertThat(XdsRouteConfigureResource.parseOverrideFilterConfigs(configOverrides, filterRegistry) + assertThat(XdsRouteConfigureResource + .parseOverrideFilterConfigs(configOverrides, filterRegistry, + getXdsResourceTypeArgs(true)) .getErrorDetail()).isEqualTo( "HttpFilter [unsupported.filter]" + "(type.googleapis.com/google.protobuf.StringValue) is required but unsupported"); @@ -1334,7 +1599,9 @@ public void parseOverrideFilterConfigs_unsupportedAndRequired() { Any.pack(httpFault), "unsupported.filter", Any.pack(StringValue.of("string value"))); - assertThat(XdsRouteConfigureResource.parseOverrideFilterConfigs(configOverrides, filterRegistry) + assertThat(XdsRouteConfigureResource + .parseOverrideFilterConfigs(configOverrides, filterRegistry, + getXdsResourceTypeArgs(true)) .getErrorDetail()).isEqualTo( "HttpFilter [unsupported.filter]" + "(type.googleapis.com/google.protobuf.StringValue) is required but unsupported"); @@ -1345,11 +1612,12 @@ public void parseHttpConnectionManager_xffNumTrustedHopsUnsupported() throws ResourceInvalidException { @SuppressWarnings("deprecation") HttpConnectionManager hcm = HttpConnectionManager.newBuilder().setXffNumTrustedHops(2).build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("HttpConnectionManager with xff_num_trusted_hops unsupported"); - XdsListenerResource.parseHttpConnectionManager( - hcm, filterRegistry, - true /* does not matter */); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, + () -> XdsListenerResource.parseHttpConnectionManager( + hcm, filterRegistry, + true /* does not matter */, getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat() + .isEqualTo("HttpConnectionManager with xff_num_trusted_hops unsupported"); } @Test @@ -1359,12 +1627,13 @@ public void parseHttpConnectionManager_OriginalIpDetectionExtensionsMustEmpty() HttpConnectionManager hcm = HttpConnectionManager.newBuilder() .addOriginalIpDetectionExtensions(TypedExtensionConfig.newBuilder().build()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("HttpConnectionManager with original_ip_detection_extensions unsupported"); - XdsListenerResource.parseHttpConnectionManager( - hcm, filterRegistry, false); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseHttpConnectionManager( + hcm, filterRegistry, false, getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat() + .isEqualTo("HttpConnectionManager with original_ip_detection_extensions unsupported"); } - + @Test public void parseHttpConnectionManager_missingRdsAndInlinedRouteConfiguration() throws ResourceInvalidException { @@ -1377,11 +1646,12 @@ public void parseHttpConnectionManager_missingRdsAndInlinedRouteConfiguration() HttpFilter.newBuilder().setName("terminal").setTypedConfig( Any.pack(Router.newBuilder().build())).setIsOptional(true)) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("HttpConnectionManager neither has inlined route_config nor RDS"); - XdsListenerResource.parseHttpConnectionManager( - hcm, filterRegistry, - true /* does not matter */); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseHttpConnectionManager( + hcm, filterRegistry, + true /* does not matter */, getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat() + .isEqualTo("HttpConnectionManager neither has inlined route_config nor RDS"); } @Test @@ -1396,16 +1666,17 @@ public void parseHttpConnectionManager_duplicateHttpFilters() throws ResourceInv HttpFilter.newBuilder().setName("terminal").setTypedConfig( Any.pack(Router.newBuilder().build())).setIsOptional(true)) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("HttpConnectionManager contains duplicate HttpFilter: envoy.filter.foo"); - XdsListenerResource.parseHttpConnectionManager( - hcm, filterRegistry, - true /* does not matter */); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseHttpConnectionManager( + hcm, filterRegistry, + true /* does not matter */, getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat() + .isEqualTo("HttpConnectionManager contains duplicate HttpFilter: envoy.filter.foo"); } @Test public void parseHttpConnectionManager_lastNotTerminal() throws ResourceInvalidException { - filterRegistry.register(FaultFilter.INSTANCE); + filterRegistry.register(FAULT_FILTER_PROVIDER); HttpConnectionManager hcm = HttpConnectionManager.newBuilder() .addHttpFilters( @@ -1414,16 +1685,17 @@ public void parseHttpConnectionManager_lastNotTerminal() throws ResourceInvalidE HttpFilter.newBuilder().setName("envoy.filter.bar").setIsOptional(true) .setTypedConfig(Any.pack(HTTPFault.newBuilder().build()))) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("The last HttpFilter must be a terminal filter: envoy.filter.bar"); - XdsListenerResource.parseHttpConnectionManager( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseHttpConnectionManager( hcm, filterRegistry, - true /* does not matter */); + true /* does not matter */, getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat() + .isEqualTo("The last HttpFilter must be a terminal filter: envoy.filter.bar"); } @Test public void parseHttpConnectionManager_terminalNotLast() throws ResourceInvalidException { - filterRegistry.register(RouterFilter.INSTANCE); + filterRegistry.register(ROUTER_FILTER_PROVIDER); HttpConnectionManager hcm = HttpConnectionManager.newBuilder() .addHttpFilters( @@ -1432,11 +1704,12 @@ public void parseHttpConnectionManager_terminalNotLast() throws ResourceInvalidE .addHttpFilters( HttpFilter.newBuilder().setName("envoy.filter.foo").setIsOptional(true)) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("A terminal HttpFilter must be the last filter: terminal"); - XdsListenerResource.parseHttpConnectionManager( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseHttpConnectionManager( hcm, filterRegistry, - true); + true, getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat() + .isEqualTo("A terminal HttpFilter must be the last filter: terminal"); } @Test @@ -1448,11 +1721,12 @@ public void parseHttpConnectionManager_unknownFilters() throws ResourceInvalidEx .addHttpFilters( HttpFilter.newBuilder().setName("envoy.filter.bar").setIsOptional(true)) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("The last HttpFilter must be a terminal filter: envoy.filter.bar"); - XdsListenerResource.parseHttpConnectionManager( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseHttpConnectionManager( hcm, filterRegistry, - true /* does not matter */); + true /* does not matter */, getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat() + .isEqualTo("The last HttpFilter must be a terminal filter: envoy.filter.bar"); } @Test @@ -1460,16 +1734,17 @@ public void parseHttpConnectionManager_emptyFilters() throws ResourceInvalidExce HttpConnectionManager hcm = HttpConnectionManager.newBuilder() .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("Missing HttpFilter in HttpConnectionManager."); - XdsListenerResource.parseHttpConnectionManager( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseHttpConnectionManager( hcm, filterRegistry, - true /* does not matter */); + true /* does not matter */, getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat() + .isEqualTo("Missing HttpFilter in HttpConnectionManager."); } @Test public void parseHttpConnectionManager_clusterSpecifierPlugin() throws Exception { - XdsResourceType.enableRouteLookup = true; + XdsRouteConfigureResource.enableRouteLookup = true; RouteLookupConfig routeLookupConfig = RouteLookupConfig.newBuilder() .addGrpcKeybuilders( GrpcKeyBuilder.newBuilder() @@ -1513,7 +1788,7 @@ public void parseHttpConnectionManager_clusterSpecifierPlugin() throws Exception io.grpc.xds.HttpConnectionManager parsedHcm = XdsListenerResource.parseHttpConnectionManager( hcm, filterRegistry, - true /* does not matter */); + true /* does not matter */, getXdsResourceTypeArgs(true)); VirtualHost virtualHost = Iterables.getOnlyElement(parsedHcm.virtualHosts()); Route parsedRoute = Iterables.getOnlyElement(virtualHost.routes()); @@ -1525,7 +1800,7 @@ public void parseHttpConnectionManager_clusterSpecifierPlugin() throws Exception @Test public void parseHttpConnectionManager_duplicatePluginName() throws Exception { - XdsResourceType.enableRouteLookup = true; + XdsRouteConfigureResource.enableRouteLookup = true; RouteLookupConfig routeLookupConfig1 = RouteLookupConfig.newBuilder() .addGrpcKeybuilders( GrpcKeyBuilder.newBuilder() @@ -1588,17 +1863,17 @@ public void parseHttpConnectionManager_duplicatePluginName() throws Exception { Any.pack(Router.newBuilder().build())).setIsOptional(true)) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("Multiple ClusterSpecifierPlugins with the same name: rls-plugin-1"); - - XdsListenerResource.parseHttpConnectionManager( - hcm, filterRegistry, - true /* does not matter */); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseHttpConnectionManager( + hcm, filterRegistry, + true /* does not matter */, getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat() + .isEqualTo("Multiple ClusterSpecifierPlugins with the same name: rls-plugin-1"); } @Test public void parseHttpConnectionManager_pluginNameNotFound() throws Exception { - XdsResourceType.enableRouteLookup = true; + XdsRouteConfigureResource.enableRouteLookup = true; RouteLookupConfig routeLookupConfig = RouteLookupConfig.newBuilder() .addGrpcKeybuilders( GrpcKeyBuilder.newBuilder() @@ -1640,18 +1915,18 @@ public void parseHttpConnectionManager_pluginNameNotFound() throws Exception { Any.pack(Router.newBuilder().build())).setIsOptional(true)) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("ClusterSpecifierPlugin for [invalid-plugin-name] not found"); - - XdsListenerResource.parseHttpConnectionManager( - hcm, filterRegistry, - true /* does not matter */); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseHttpConnectionManager( + hcm, filterRegistry, + true /* does not matter */, getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat() + .contains("ClusterSpecifierPlugin for [invalid-plugin-name] not found"); } @Test public void parseHttpConnectionManager_optionalPlugin() throws ResourceInvalidException { - XdsResourceType.enableRouteLookup = true; + XdsRouteConfigureResource.enableRouteLookup = true; // RLS Plugin, and a route to it. RouteLookupConfig routeLookupConfig = RouteLookupConfig.newBuilder() @@ -1719,7 +1994,7 @@ public void parseHttpConnectionManager_optionalPlugin() throws ResourceInvalidEx HttpFilter.newBuilder().setName("terminal").setTypedConfig( Any.pack(Router.newBuilder().build())).setIsOptional(true)) .build(), filterRegistry, - true /* does not matter */); + true /* does not matter */, getXdsResourceTypeArgs(true)); // Verify that the only route left is the one with the registered RLS plugin `rls-plugin-1`, // while the route with unregistered optional `optional-plugin-`1 has been skipped. @@ -1733,7 +2008,7 @@ public void parseHttpConnectionManager_optionalPlugin() throws ResourceInvalidEx @Test public void parseHttpConnectionManager_validateRdsConfigSource() throws Exception { - XdsResourceType.enableRouteLookup = true; + XdsRouteConfigureResource.enableRouteLookup = true; HttpConnectionManager hcm1 = HttpConnectionManager.newBuilder() @@ -1747,7 +2022,7 @@ public void parseHttpConnectionManager_validateRdsConfigSource() throws Exceptio .build(); XdsListenerResource.parseHttpConnectionManager( hcm1, filterRegistry, - true /* does not matter */); + true /* does not matter */, getXdsResourceTypeArgs(true)); HttpConnectionManager hcm2 = HttpConnectionManager.newBuilder() @@ -1761,7 +2036,7 @@ public void parseHttpConnectionManager_validateRdsConfigSource() throws Exceptio .build(); XdsListenerResource.parseHttpConnectionManager( hcm2, filterRegistry, - true /* does not matter */); + true /* does not matter */, getXdsResourceTypeArgs(true)); HttpConnectionManager hcm3 = HttpConnectionManager.newBuilder() @@ -1774,12 +2049,12 @@ public void parseHttpConnectionManager_validateRdsConfigSource() throws Exceptio HttpFilter.newBuilder().setName("terminal").setTypedConfig( Any.pack(Router.newBuilder().build())).setIsOptional(true)) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseHttpConnectionManager( + hcm3, filterRegistry, + true /* does not matter */, getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat().isEqualTo( "HttpConnectionManager contains invalid RDS: must specify ADS or self ConfigSource"); - XdsListenerResource.parseHttpConnectionManager( - hcm3, filterRegistry, - true /* does not matter */); } @Test @@ -1869,11 +2144,10 @@ public void parseClusterSpecifierPlugin_unregisteredPlugin() throws Exception { .setTypedConfig(Any.pack(StringValue.of("unregistered")))) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsRouteConfigureResource.parseClusterSpecifierPlugin(pluginProto, registry)); + assertThat(e).hasMessageThat().isEqualTo( "Unsupported ClusterSpecifierPlugin type: type.googleapis.com/google.protobuf.StringValue"); - - XdsRouteConfigureResource.parseClusterSpecifierPlugin(pluginProto, registry); } @Test @@ -1971,7 +2245,7 @@ public void parseCluster_ringHashLbPolicy_defaultLbConfig() throws ResourceInval @Test public void parseCluster_leastRequestLbPolicy_defaultLbConfig() throws ResourceInvalidException { - XdsResourceType.enableLeastRequest = true; + XdsClusterResource.enableLeastRequest = true; Cluster cluster = Cluster.newBuilder() .setName("cluster-foo.googleapis.com") .setType(DiscoveryType.EDS) @@ -2070,11 +2344,11 @@ public void parseCluster_transportSocketMatches_exception() throws ResourceInval Cluster.TransportSocketMatch.newBuilder().setName("match1").build()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.processCluster(cluster, null, LRS_SERVER_INFO, + LoadBalancerRegistry.getDefaultRegistry())); + assertThat(e).hasMessageThat().isEqualTo( "Cluster cluster-foo.googleapis.com: transport-socket-matches not supported."); - XdsClusterResource.processCluster(cluster, null, LRS_SERVER_INFO, - LoadBalancerRegistry.getDefaultRegistry()); } @Test @@ -2119,12 +2393,303 @@ public void parseCluster_validateEdsSourceConfig() throws ResourceInvalidExcepti .setLbPolicy(LbPolicy.ROUND_ROBIN) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.processCluster(cluster3, null, LRS_SERVER_INFO, + LoadBalancerRegistry.getDefaultRegistry())); + assertThat(e).hasMessageThat().isEqualTo( "Cluster cluster-foo.googleapis.com: field eds_cluster_config must be set to indicate to" + " use EDS over ADS or self ConfigSource"); - XdsClusterResource.processCluster(cluster3, null, LRS_SERVER_INFO, + } + + @Test + public void processCluster_parsesMetadata() + throws ResourceInvalidException, InvalidProtocolBufferException { + MetadataRegistry metadataRegistry = MetadataRegistry.getInstance(); + + MetadataValueParser testParser = + new MetadataValueParser() { + @Override + public String getTypeUrl() { + return "type.googleapis.com/test.Type"; + } + + @Override + public Object parse(Any value) { + assertThat(value.getValue().toStringUtf8()).isEqualTo("test"); + return value.getValue().toStringUtf8() + "_processed"; + } + }; + metadataRegistry.registerParser(testParser); + + Any typedFilterMetadata = Any.newBuilder() + .setTypeUrl("type.googleapis.com/test.Type") + .setValue(ByteString.copyFromUtf8("test")) + .build(); + + Struct filterMetadata = Struct.newBuilder() + .putFields("key1", Value.newBuilder().setStringValue("value1").build()) + .putFields("key2", Value.newBuilder().setNumberValue(42).build()) + .build(); + + Metadata metadata = Metadata.newBuilder() + .putTypedFilterMetadata("TYPED_FILTER_METADATA", typedFilterMetadata) + .putFilterMetadata("FILTER_METADATA", filterMetadata) + .build(); + + Cluster cluster = Cluster.newBuilder() + .setName("cluster-foo.googleapis.com") + .setType(DiscoveryType.EDS) + .setEdsClusterConfig( + EdsClusterConfig.newBuilder() + .setEdsConfig( + ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.getDefaultInstance())) + .setServiceName("service-foo.googleapis.com")) + .setLbPolicy(LbPolicy.ROUND_ROBIN) + .setMetadata(metadata) + .build(); + + CdsUpdate update = XdsClusterResource.processCluster( + cluster, null, LRS_SERVER_INFO, LoadBalancerRegistry.getDefaultRegistry()); + + ImmutableMap expectedParsedMetadata = ImmutableMap.of( + "TYPED_FILTER_METADATA", "test_processed", + "FILTER_METADATA", ImmutableMap.of( + "key1", "value1", + "key2", 42.0)); + assertThat(update.parsedMetadata()).isEqualTo(expectedParsedMetadata); + metadataRegistry.removeParser(testParser); + } + + @Test + public void processCluster_parsesAudienceMetadata() throws Exception { + MetadataRegistry.getInstance(); + + Audience audience = Audience.newBuilder() + .setUrl("https://example.com") + .build(); + + Any audienceMetadata = Any.newBuilder() + .setTypeUrl("type.googleapis.com/envoy.extensions.filters.http.gcp_authn.v3.Audience") + .setValue(audience.toByteString()) + .build(); + + Struct filterMetadata = Struct.newBuilder() + .putFields("key1", Value.newBuilder().setStringValue("value1").build()) + .putFields("key2", Value.newBuilder().setNumberValue(42).build()) + .build(); + + Metadata metadata = Metadata.newBuilder() + .putTypedFilterMetadata("AUDIENCE_METADATA", audienceMetadata) + .putFilterMetadata("FILTER_METADATA", filterMetadata) + .build(); + + Cluster cluster = Cluster.newBuilder() + .setName("cluster-foo.googleapis.com") + .setType(DiscoveryType.EDS) + .setEdsClusterConfig( + EdsClusterConfig.newBuilder() + .setEdsConfig( + ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.getDefaultInstance())) + .setServiceName("service-foo.googleapis.com")) + .setLbPolicy(LbPolicy.ROUND_ROBIN) + .setMetadata(metadata) + .build(); + + CdsUpdate update = XdsClusterResource.processCluster( + cluster, null, LRS_SERVER_INFO, + LoadBalancerRegistry.getDefaultRegistry()); + + ImmutableMap expectedParsedMetadata = ImmutableMap.of( + "AUDIENCE_METADATA", "https://example.com", + "FILTER_METADATA", ImmutableMap.of( + "key1", "value1", + "key2", 42.0)); + + assertThat(update.parsedMetadata().get("FILTER_METADATA")) + .isEqualTo(expectedParsedMetadata.get("FILTER_METADATA")); + assertThat(update.parsedMetadata().get("AUDIENCE_METADATA")) + .isInstanceOf(AudienceWrapper.class); + } + + @Test + public void processCluster_parsesAddressMetadata() throws Exception { + + // Create an Address message + Address address = Address.newBuilder() + .setSocketAddress(SocketAddress.newBuilder() + .setAddress("192.168.1.1") + .setPortValue(8080) + .build()) + .build(); + + // Wrap the Address in Any + Any addressMetadata = Any.newBuilder() + .setTypeUrl("type.googleapis.com/envoy.config.core.v3.Address") + .setValue(address.toByteString()) + .build(); + + Struct filterMetadata = Struct.newBuilder() + .putFields("key1", Value.newBuilder().setStringValue("value1").build()) + .putFields("key2", Value.newBuilder().setNumberValue(42).build()) + .build(); + + Metadata metadata = Metadata.newBuilder() + .putTypedFilterMetadata("ADDRESS_METADATA", addressMetadata) + .putFilterMetadata("FILTER_METADATA", filterMetadata) + .build(); + + Cluster cluster = Cluster.newBuilder() + .setName("cluster-foo.googleapis.com") + .setType(DiscoveryType.EDS) + .setEdsClusterConfig( + EdsClusterConfig.newBuilder() + .setEdsConfig( + ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.getDefaultInstance())) + .setServiceName("service-foo.googleapis.com")) + .setLbPolicy(LbPolicy.ROUND_ROBIN) + .setMetadata(metadata) + .build(); + + CdsUpdate update = XdsClusterResource.processCluster( + cluster, null, LRS_SERVER_INFO, + LoadBalancerRegistry.getDefaultRegistry()); + + ImmutableMap expectedParsedMetadata = ImmutableMap.of( + "ADDRESS_METADATA", new InetSocketAddress("192.168.1.1", 8080), + "FILTER_METADATA", ImmutableMap.of( + "key1", "value1", + "key2", 42.0)); + + assertThat(update.parsedMetadata()).isEqualTo(expectedParsedMetadata); + } + + @Test + public void processCluster_metadataKeyCollision_resolvesToTypedMetadata() throws Exception { + MetadataRegistry metadataRegistry = MetadataRegistry.getInstance(); + + MetadataValueParser testParser = + new MetadataValueParser() { + @Override + public String getTypeUrl() { + return "type.googleapis.com/test.Type"; + } + + @Override + public Object parse(Any value) { + return "typedMetadataValue"; + } + }; + metadataRegistry.registerParser(testParser); + + Any typedFilterMetadata = Any.newBuilder() + .setTypeUrl("type.googleapis.com/test.Type") + .setValue(ByteString.copyFromUtf8("test")) + .build(); + + Struct filterMetadata = Struct.newBuilder() + .putFields("key1", Value.newBuilder().setStringValue("filterMetadataValue").build()) + .build(); + + Metadata metadata = Metadata.newBuilder() + .putTypedFilterMetadata("key1", typedFilterMetadata) + .putFilterMetadata("key1", filterMetadata) + .build(); + + Cluster cluster = Cluster.newBuilder() + .setName("cluster-foo.googleapis.com") + .setType(DiscoveryType.EDS) + .setEdsClusterConfig( + EdsClusterConfig.newBuilder() + .setEdsConfig( + ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.getDefaultInstance())) + .setServiceName("service-foo.googleapis.com")) + .setLbPolicy(LbPolicy.ROUND_ROBIN) + .setMetadata(metadata) + .build(); + + CdsUpdate update = XdsClusterResource.processCluster( + cluster, null, LRS_SERVER_INFO, + LoadBalancerRegistry.getDefaultRegistry()); + + ImmutableMap expectedParsedMetadata = ImmutableMap.of( + "key1", "typedMetadataValue"); + assertThat(update.parsedMetadata()).isEqualTo(expectedParsedMetadata); + metadataRegistry.removeParser(testParser); + } + + @Test + public void parseNonAggregateCluster_withHttp11ProxyTransportSocket() throws Exception { + XdsClusterResource.isEnabledXdsHttpConnect = true; + + Http11ProxyUpstreamTransport http11ProxyUpstreamTransport = + Http11ProxyUpstreamTransport.newBuilder() + .setTransportSocket(TransportSocket.getDefaultInstance()) + .build(); + + TransportSocket transportSocket = TransportSocket.newBuilder() + .setName("envoy.transport_sockets.http_11_proxy") + .setTypedConfig(Any.pack(http11ProxyUpstreamTransport)) + .build(); + + Cluster cluster = Cluster.newBuilder() + .setName("cluster-http11-proxy.googleapis.com") + .setType(DiscoveryType.EDS) + .setEdsClusterConfig( + EdsClusterConfig.newBuilder() + .setEdsConfig( + ConfigSource.newBuilder().setAds(AggregatedConfigSource.getDefaultInstance())) + .setServiceName("service-http11-proxy.googleapis.com")) + .setLbPolicy(LbPolicy.ROUND_ROBIN) + .setTransportSocket(transportSocket) + .build(); + + CdsUpdate result = + XdsClusterResource.processCluster(cluster, null, LRS_SERVER_INFO, + LoadBalancerRegistry.getDefaultRegistry()); + + assertThat(result).isNotNull(); + assertThat(result.isHttp11ProxyAvailable()).isTrue(); + } + + @Test + public void processCluster_parsesOrcaLrsPropagationMetrics() throws ResourceInvalidException { + LoadStatsManager2.isEnabledOrcaLrsPropagation = true; + + ImmutableList metricSpecs = ImmutableList.of( + "cpu_utilization", + "named_metrics.foo", + "unknown_metric_spec" + ); + Cluster cluster = Cluster.newBuilder() + .setName("cluster-orca.googleapis.com") + .setType(DiscoveryType.EDS) + .setEdsClusterConfig( + EdsClusterConfig.newBuilder() + .setEdsConfig( + ConfigSource.newBuilder().setAds(AggregatedConfigSource.getDefaultInstance())) + .setServiceName("service-orca.googleapis.com")) + .setLbPolicy(LbPolicy.ROUND_ROBIN) + .addAllLrsReportEndpointMetrics(metricSpecs) + .build(); + + CdsUpdate update = XdsClusterResource.processCluster( + cluster, null, LRS_SERVER_INFO, LoadBalancerRegistry.getDefaultRegistry()); + + BackendMetricPropagation propagationConfig = update.backendMetricPropagation(); + assertThat(propagationConfig).isNotNull(); + assertThat(propagationConfig.propagateCpuUtilization).isTrue(); + assertThat(propagationConfig.propagateMemUtilization).isFalse(); + assertThat(propagationConfig.shouldPropagateNamedMetric("foo")).isTrue(); + assertThat(propagationConfig.shouldPropagateNamedMetric("bar")).isFalse(); + assertThat(propagationConfig.shouldPropagateNamedMetric("unknown_metric_spec")) + .isFalse(); + + LoadStatsManager2.isEnabledOrcaLrsPropagation = false; } @Test @@ -2134,10 +2699,11 @@ public void parseServerSideListener_invalidTrafficDirection() throws ResourceInv .setName("listener1") .setTrafficDirection(TrafficDirection.OUTBOUND) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("Listener listener1 with invalid traffic direction: OUTBOUND"); - XdsListenerResource.parseServerSideListener( - listener, null, filterRegistry, null); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseServerSideListener( + listener, null, filterRegistry, null, getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat() + .isEqualTo("Listener listener1 with invalid traffic direction: OUTBOUND"); } @Test @@ -2147,7 +2713,7 @@ public void parseServerSideListener_noTrafficDirection() throws ResourceInvalidE .setName("listener1") .build(); XdsListenerResource.parseServerSideListener( - listener, null, filterRegistry, null); + listener, null, filterRegistry, null, getXdsResourceTypeArgs(true)); } @Test @@ -2158,10 +2724,11 @@ public void parseServerSideListener_listenerFiltersPresent() throws ResourceInva .setTrafficDirection(TrafficDirection.INBOUND) .addListenerFilters(ListenerFilter.newBuilder().build()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("Listener listener1 cannot have listener_filters"); - XdsListenerResource.parseServerSideListener( - listener, null, filterRegistry, null); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseServerSideListener(listener, null, filterRegistry, null, + getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat() + .isEqualTo("Listener listener1 cannot have listener_filters"); } @Test @@ -2172,10 +2739,44 @@ public void parseServerSideListener_useOriginalDst() throws ResourceInvalidExcep .setTrafficDirection(TrafficDirection.INBOUND) .setUseOriginalDst(BoolValue.of(true)) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("Listener listener1 cannot have use_original_dst set to true"); - XdsListenerResource.parseServerSideListener( - listener,null, filterRegistry, null); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseServerSideListener(listener, null, filterRegistry, null, + getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat() + .isEqualTo("Listener listener1 cannot have use_original_dst set to true"); + } + + @Test + public void parseServerSideListener_emptyAddress() throws ResourceInvalidException { + Listener listener = + Listener.newBuilder() + .setName("listener1") + .setTrafficDirection(TrafficDirection.INBOUND) + .setAddress(Address.newBuilder() + .setSocketAddress( + SocketAddress.newBuilder())) + .build(); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseServerSideListener( + listener, null, filterRegistry, null, getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat().isEqualTo("Invalid address: Empty address is not allowed."); + } + + @Test + public void parseServerSideListener_namedPort() throws ResourceInvalidException { + Listener listener = + Listener.newBuilder() + .setName("listener1") + .setTrafficDirection(TrafficDirection.INBOUND) + .setAddress(Address.newBuilder() + .setSocketAddress( + SocketAddress.newBuilder() + .setAddress("172.14.14.5").setNamedPort(""))) + .build(); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseServerSideListener( + listener, null, filterRegistry, null, getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat().isEqualTo("NAMED_PORT is not supported in gRPC."); } @Test @@ -2221,10 +2822,11 @@ public void parseServerSideListener_nonUniqueFilterChainMatch() throws ResourceI .setTrafficDirection(TrafficDirection.INBOUND) .addAllFilterChains(Arrays.asList(filterChain1, filterChain2)) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("FilterChainMatch must be unique. Found duplicate:"); - XdsListenerResource.parseServerSideListener( - listener, null, filterRegistry, null); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseServerSideListener( + listener, null, filterRegistry, null, getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat() + .startsWith("FilterChainMatch must be unique. Found duplicate:"); } @Test @@ -2270,10 +2872,11 @@ public void parseServerSideListener_nonUniqueFilterChainMatch_sameFilter() .setTrafficDirection(TrafficDirection.INBOUND) .addAllFilterChains(Arrays.asList(filterChain1, filterChain2)) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("FilterChainMatch must be unique. Found duplicate:"); - XdsListenerResource.parseServerSideListener( - listener,null, filterRegistry, null); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseServerSideListener( + listener, null, filterRegistry, null, getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat() + .startsWith("FilterChainMatch must be unique. Found duplicate:"); } @Test @@ -2322,7 +2925,7 @@ public void parseServerSideListener_uniqueFilterChainMatch() throws ResourceInva .addAllFilterChains(Arrays.asList(filterChain1, filterChain2)) .build(); XdsListenerResource.parseServerSideListener( - listener, null, filterRegistry, null); + listener, null, filterRegistry, null, getXdsResourceTypeArgs(true)); } @Test @@ -2333,11 +2936,12 @@ public void parseFilterChain_noHcm() throws ResourceInvalidException { .setFilterChainMatch(FilterChainMatch.getDefaultInstance()) .setTransportSocket(TransportSocket.getDefaultInstance()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseFilterChain( + filterChain, "filter-chain-foo", null, filterRegistry, null, null, + getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat().isEqualTo( "FilterChain filter-chain-foo should contain exact one HttpConnectionManager filter"); - XdsListenerResource.parseFilterChain( - filterChain, null, filterRegistry, null, null); } @Test @@ -2351,11 +2955,12 @@ public void parseFilterChain_duplicateFilter() throws ResourceInvalidException { .setTransportSocket(TransportSocket.getDefaultInstance()) .addAllFilters(Arrays.asList(filter, filter)) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseFilterChain( + filterChain, "filter-chain-foo", null, filterRegistry, null, null, + getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat().isEqualTo( "FilterChain filter-chain-foo should contain exact one HttpConnectionManager filter"); - XdsListenerResource.parseFilterChain( - filterChain, null, filterRegistry, null, null); } @Test @@ -2368,12 +2973,13 @@ public void parseFilterChain_filterMissingTypedConfig() throws ResourceInvalidEx .setTransportSocket(TransportSocket.getDefaultInstance()) .addFilters(filter) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseFilterChain( + filterChain, "filter-chain-foo", null, filterRegistry, null, null, + getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat().isEqualTo( "FilterChain filter-chain-foo contains filter envoy.http_connection_manager " + "without typed_config"); - XdsListenerResource.parseFilterChain( - filterChain, null, filterRegistry, null, null); } @Test @@ -2390,17 +2996,18 @@ public void parseFilterChain_unsupportedFilter() throws ResourceInvalidException .setTransportSocket(TransportSocket.getDefaultInstance()) .addFilters(filter) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseFilterChain( + filterChain, "filter-chain-foo", null, filterRegistry, null, null, + getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat().isEqualTo( "FilterChain filter-chain-foo contains filter unsupported with unsupported " + "typed_config type unsupported-type-url"); - XdsListenerResource.parseFilterChain( - filterChain, null, filterRegistry, null, null); } @Test public void parseFilterChain_noName() throws ResourceInvalidException { - FilterChain filterChain1 = + FilterChain filterChain0 = FilterChain.newBuilder() .setFilterChainMatch(FilterChainMatch.getDefaultInstance()) .addFilters(buildHttpConnectionManagerFilter( @@ -2410,9 +3017,53 @@ public void parseFilterChain_noName() throws ResourceInvalidException { .setTypedConfig(Any.pack(Router.newBuilder().build())) .build())) .build(); - FilterChain filterChain2 = + + FilterChain filterChain1 = + FilterChain.newBuilder() + .setFilterChainMatch( + FilterChainMatch.newBuilder().addAllSourcePorts(Arrays.asList(443, 8080))) + .addFilters(buildHttpConnectionManagerFilter( + HttpFilter.newBuilder() + .setName("http-filter-bar") + .setTypedConfig(Any.pack(Router.newBuilder().build())) + .setIsOptional(true) + .build())) + .build(); + + Listener listenerProto = + Listener.newBuilder() + .setName("listener1") + .setTrafficDirection(TrafficDirection.INBOUND) + .addAllFilterChains(Arrays.asList(filterChain0, filterChain1)) + .setDefaultFilterChain(filterChain0) + .build(); + EnvoyServerProtoData.Listener listener = XdsListenerResource.parseServerSideListener( + listenerProto, null, filterRegistry, null, getXdsResourceTypeArgs(true)); + + assertThat(listener.filterChains().get(0).name()).isEqualTo("chain_0"); + assertThat(listener.filterChains().get(1).name()).isEqualTo("chain_1"); + assertThat(listener.defaultFilterChain().name()).isEqualTo("chain_default"); + } + + @Test + public void parseFilterChain_duplicateName() throws ResourceInvalidException { + FilterChain filterChain0 = FilterChain.newBuilder() + .setName("filter_chain") .setFilterChainMatch(FilterChainMatch.getDefaultInstance()) + .addFilters(buildHttpConnectionManagerFilter( + HttpFilter.newBuilder() + .setName("http-filter-foo") + .setIsOptional(true) + .setTypedConfig(Any.pack(Router.newBuilder().build())) + .build())) + .build(); + + FilterChain filterChain1 = + FilterChain.newBuilder() + .setName("filter_chain") + .setFilterChainMatch( + FilterChainMatch.newBuilder().addAllSourcePorts(Arrays.asList(443, 8080))) .addFilters(buildHttpConnectionManagerFilter( HttpFilter.newBuilder() .setName("http-filter-bar") @@ -2421,204 +3072,273 @@ public void parseFilterChain_noName() throws ResourceInvalidException { .build())) .build(); - EnvoyServerProtoData.FilterChain parsedFilterChain1 = XdsListenerResource.parseFilterChain( - filterChain1, null, filterRegistry, null, - null); - EnvoyServerProtoData.FilterChain parsedFilterChain2 = XdsListenerResource.parseFilterChain( - filterChain2, null, filterRegistry, null, - null); - assertThat(parsedFilterChain1.name()).isEqualTo(parsedFilterChain2.name()); + Listener listenerProto = + Listener.newBuilder() + .setName("listener1") + .setTrafficDirection(TrafficDirection.INBOUND) + .addAllFilterChains(Arrays.asList(filterChain0, filterChain1)) + .build(); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.parseServerSideListener( + listenerProto, null, filterRegistry, null, getXdsResourceTypeArgs(true))); + assertThat(e).hasMessageThat() + .isEqualTo("Filter chain names must be unique. Found duplicate: filter_chain"); } @Test - public void validateCommonTlsContext_tlsParams() throws ResourceInvalidException { + public void validateCommonTlsContext_tlsParams() { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .setTlsParams(TlsParameters.getDefaultInstance()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("common-tls-context with tls_params is not supported"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false)); + assertThat(e).hasMessageThat().isEqualTo("common-tls-context with tls_params is not supported"); } @Test - public void validateCommonTlsContext_customHandshaker() throws ResourceInvalidException { + public void validateCommonTlsContext_customHandshaker() { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .setCustomHandshaker(TypedExtensionConfig.getDefaultInstance()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("common-tls-context with custom_handshaker is not supported"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false)); + assertThat(e).hasMessageThat().isEqualTo( + "common-tls-context with custom_handshaker is not supported"); } @Test - public void validateCommonTlsContext_validationContext() throws ResourceInvalidException { + public void validateCommonTlsContext_validationContext() { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .setValidationContext(CertificateValidationContext.getDefaultInstance()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("ca_certificate_provider_instance is required in upstream-tls-context"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false)); + assertThat(e).hasMessageThat().isEqualTo( + "ca_certificate_provider_instance or system_root_certs is required " + + "in upstream-tls-context"); } @Test - public void validateCommonTlsContext_validationContextSdsSecretConfig() - throws ResourceInvalidException { + public void validateCommonTlsContext_validationContextSdsSecretConfig() { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .setValidationContextSdsSecretConfig(SdsSecretConfig.getDefaultInstance()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false)); + assertThat(e).hasMessageThat().isEqualTo( "common-tls-context with validation_context_sds_secret_config is not supported"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false); - } - - @Test - @SuppressWarnings("deprecation") - public void validateCommonTlsContext_validationContextCertificateProvider() - throws ResourceInvalidException { - CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .setValidationContextCertificateProvider( - CommonTlsContext.CertificateProvider.getDefaultInstance()) - .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( - "common-tls-context with validation_context_certificate_provider is not supported"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false); } @Test - @SuppressWarnings("deprecation") - public void validateCommonTlsContext_validationContextCertificateProviderInstance() + public void validateCommonTlsContext_tlsCertificateProviderInstance_isRequiredForServer() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( - "common-tls-context with validation_context_certificate_provider_instance is not " - + "supported"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, true)); + assertThat(e).hasMessageThat().isEqualTo( + "tls_certificate_provider_instance is required in downstream-tls-context"); } @Test - public void validateCommonTlsContext_tlsCertificateProviderInstance_isRequiredForServer() + public void validateCommonTlsContext_tlsNewCertificateProviderInstance() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() + .setTlsCertificateProviderInstance( + CertificateProviderPluginInstance.newBuilder().setInstanceName("name1")) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( - "tls_certificate_provider_instance is required in downstream-tls-context"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, true); + XdsClusterResource + .validateCommonTlsContext(commonTlsContext, ImmutableSet.of("name1", "name2"), true); } @Test @SuppressWarnings("deprecation") - public void validateCommonTlsContext_tlsNewCertificateProviderInstance() + public void validateCommonTlsContext_tlsDeprecatedCertificateProviderInstance() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .setTlsCertificateProviderInstance( - CertificateProviderPluginInstance.newBuilder().setInstanceName("name1").build()) + .setTlsCertificateCertificateProviderInstance( + CommonTlsContext.CertificateProviderInstance.newBuilder().setInstanceName("name1")) .build(); XdsClusterResource .validateCommonTlsContext(commonTlsContext, ImmutableSet.of("name1", "name2"), true); } @Test - @SuppressWarnings("deprecation") public void validateCommonTlsContext_tlsCertificateProviderInstance() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .setTlsCertificateCertificateProviderInstance( - CertificateProviderInstance.newBuilder().setInstanceName("name1").build()) + .setTlsCertificateProviderInstance( + CertificateProviderPluginInstance.newBuilder().setInstanceName("name1")) .build(); XdsClusterResource .validateCommonTlsContext(commonTlsContext, ImmutableSet.of("name1", "name2"), true); } @Test - @SuppressWarnings("deprecation") public void validateCommonTlsContext_tlsCertificateProviderInstance_absentInBootstrapFile() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .setTlsCertificateCertificateProviderInstance( - CertificateProviderInstance.newBuilder().setInstanceName("bad-name").build()) + .setTlsCertificateProviderInstance( + CertificateProviderPluginInstance.newBuilder().setInstanceName("bad-name")) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, + ImmutableSet.of("name1", "name2"), true)); + assertThat(e).hasMessageThat().isEqualTo( "CertificateProvider instance name 'bad-name' not defined in the bootstrap file."); - XdsClusterResource - .validateCommonTlsContext(commonTlsContext, ImmutableSet.of("name1", "name2"), true); } @Test - @SuppressWarnings("deprecation") public void validateCommonTlsContext_validationContextProviderInstance() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .setCombinedValidationContext( CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CertificateProviderInstance.newBuilder().setInstanceName("name1").build()) - .build()) + .setDefaultValidationContext(CertificateValidationContext.newBuilder() + .setCaCertificateProviderInstance(CertificateProviderPluginInstance.newBuilder() + .setInstanceName("name1")))) .build(); XdsClusterResource .validateCommonTlsContext(commonTlsContext, ImmutableSet.of("name1", "name2"), false); } + @Test + public void + validateCommonTlsContext_combinedValidationContextSystemRootCerts_envVarNotSet_throws() { + XdsClusterResource.enableSystemRootCerts = false; + CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() + .setCombinedValidationContext( + CommonTlsContext.CombinedCertificateValidationContext.newBuilder() + .setDefaultValidationContext( + CertificateValidationContext.newBuilder() + .setSystemRootCerts( + CertificateValidationContext.SystemRootCerts.newBuilder().build()) + .build() + ) + .build()) + .build(); + try { + XdsClusterResource + .validateCommonTlsContext(commonTlsContext, ImmutableSet.of(), false); + fail("Expected exception"); + } catch (ResourceInvalidException ex) { + assertThat(ex.getMessage()).isEqualTo( + "ca_certificate_provider_instance or system_root_certs is required in" + + " upstream-tls-context"); + } + } + + @Test + public void validateCommonTlsContext_combinedValidationContextSystemRootCerts() + throws ResourceInvalidException { + XdsClusterResource.enableSystemRootCerts = true; + CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() + .setCombinedValidationContext( + CommonTlsContext.CombinedCertificateValidationContext.newBuilder() + .setDefaultValidationContext( + CertificateValidationContext.newBuilder() + .setSystemRootCerts( + CertificateValidationContext.SystemRootCerts.newBuilder().build()) + .build() + ) + .build()) + .build(); + XdsClusterResource + .validateCommonTlsContext(commonTlsContext, ImmutableSet.of(), false); + } + @Test @SuppressWarnings("deprecation") - public void validateCommonTlsContext_validationContextProviderInstance_absentInBootstrapFile() - throws ResourceInvalidException { + public void validateCommonTlsContext_combinedValidationContextDeprecatedCertProvider() + throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() + .setTlsCertificateProviderInstance( + CertificateProviderPluginInstance.newBuilder().setInstanceName("cert1")) .setCombinedValidationContext( CommonTlsContext.CombinedCertificateValidationContext.newBuilder() .setValidationContextCertificateProviderInstance( - CertificateProviderInstance.newBuilder().setInstanceName("bad-name").build()) + CommonTlsContext.CertificateProviderInstance.newBuilder() + .setInstanceName("root1")) .build()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( - "ca_certificate_provider_instance name 'bad-name' not defined in the bootstrap file."); XdsClusterResource - .validateCommonTlsContext(commonTlsContext, ImmutableSet.of("name1", "name2"), false); + .validateCommonTlsContext(commonTlsContext, ImmutableSet.of("cert1", "root1"), true); } - @Test - public void validateCommonTlsContext_tlsCertificatesCount() throws ResourceInvalidException { + public void validateCommonTlsContext_validationContextSystemRootCerts_envVarNotSet_throws() { + XdsClusterResource.enableSystemRootCerts = false; CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .addTlsCertificates(TlsCertificate.getDefaultInstance()) - .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("tls_certificate_provider_instance is unset"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false); + .setValidationContext( + CertificateValidationContext.newBuilder() + .setSystemRootCerts( + CertificateValidationContext.SystemRootCerts.newBuilder().build()) + .build()) + .build(); + try { + XdsClusterResource + .validateCommonTlsContext(commonTlsContext, ImmutableSet.of(), false); + fail("Expected exception"); + } catch (ResourceInvalidException ex) { + assertThat(ex.getMessage()).isEqualTo( + "ca_certificate_provider_instance or system_root_certs is required in " + + "upstream-tls-context"); + } } @Test - public void validateCommonTlsContext_tlsCertificateSdsSecretConfigsCount() + public void validateCommonTlsContext_validationContextSystemRootCerts() throws ResourceInvalidException { + XdsClusterResource.enableSystemRootCerts = true; CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .addTlsCertificateSdsSecretConfigs(SdsSecretConfig.getDefaultInstance()) + .setValidationContext( + CertificateValidationContext.newBuilder() + .setSystemRootCerts( + CertificateValidationContext.SystemRootCerts.newBuilder().build()) + .build()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( - "tls_certificate_provider_instance is unset"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false); + XdsClusterResource + .validateCommonTlsContext(commonTlsContext, ImmutableSet.of(), false); } @Test - @SuppressWarnings("deprecation") - public void validateCommonTlsContext_tlsCertificateCertificateProvider() + public void validateCommonTlsContext_validationContextProviderInstance_absentInBootstrapFile() + throws ResourceInvalidException { + CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() + .setCombinedValidationContext( + CommonTlsContext.CombinedCertificateValidationContext.newBuilder() + .setDefaultValidationContext(CertificateValidationContext.newBuilder() + .setCaCertificateProviderInstance(CertificateProviderPluginInstance.newBuilder() + .setInstanceName("bad-name")))) + .build(); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, + ImmutableSet.of("name1", "name2"), false)); + assertThat(e).hasMessageThat().isEqualTo( + "ca_certificate_provider_instance name 'bad-name' not defined in the bootstrap file."); + } + + + @Test + public void validateCommonTlsContext_tlsCertificatesCount() throws ResourceInvalidException { + CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() + .addTlsCertificates(TlsCertificate.getDefaultInstance()) + .build(); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false)); + assertThat(e).hasMessageThat().isEqualTo("tls_certificate_provider_instance is unset"); + } + + @Test + public void validateCommonTlsContext_tlsCertificateSdsSecretConfigsCount() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .setTlsCertificateCertificateProvider( - CommonTlsContext.CertificateProvider.getDefaultInstance()) + .addTlsCertificateSdsSecretConfigs(SdsSecretConfig.getDefaultInstance()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false)); + assertThat(e).hasMessageThat().isEqualTo( "tls_certificate_provider_instance is unset"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false); } @Test @@ -2626,9 +3346,11 @@ public void validateCommonTlsContext_combinedValidationContext_isRequiredForClie throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("ca_certificate_provider_instance is required in upstream-tls-context"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false)); + assertThat(e).hasMessageThat().isEqualTo( + "ca_certificate_provider_instance or system_root_certs is required " + + "in upstream-tls-context"); } @Test @@ -2638,10 +3360,11 @@ public void validateCommonTlsContext_combinedValidationContextWithoutCertProvide .setCombinedValidationContext( CommonTlsContext.CombinedCertificateValidationContext.getDefaultInstance()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( - "ca_certificate_provider_instance is required in upstream-tls-context"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, null, false)); + assertThat(e).hasMessageThat().isEqualTo( + "ca_certificate_provider_instance or system_root_certs is required in " + + "upstream-tls-context"); } @Test @@ -2651,174 +3374,169 @@ public void validateCommonTlsContext_combinedValContextWithDefaultValContextForS CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .setCombinedValidationContext( CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CertificateProviderInstance.getDefaultInstance()) .setDefaultValidationContext(CertificateValidationContext.newBuilder() + .setCaCertificateProviderInstance( + CertificateProviderPluginInstance.getDefaultInstance()) .addMatchSubjectAltNames(StringMatcher.newBuilder().setExact("foo.com").build()) .build())) - .setTlsCertificateCertificateProviderInstance( - CertificateProviderInstance.getDefaultInstance()) + .setTlsCertificateProviderInstance( + CertificateProviderPluginInstance.getDefaultInstance()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("match_subject_alt_names only allowed in upstream_tls_context"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), true); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), true)); + assertThat(e).hasMessageThat().isEqualTo( + "match_subject_alt_names only allowed in upstream_tls_context"); } @Test - @SuppressWarnings("deprecation") public void validateCommonTlsContext_combinedValContextWithDefaultValContextVerifyCertSpki() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .setCombinedValidationContext( CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) - .setDefaultValidationContext( - CertificateValidationContext.newBuilder().addVerifyCertificateSpki("foo"))) - .setTlsCertificateCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) + .setDefaultValidationContext(CertificateValidationContext.newBuilder() + .setCaCertificateProviderInstance( + CertificateProviderPluginInstance.getDefaultInstance()) + .addVerifyCertificateSpki("foo"))) + .setTlsCertificateProviderInstance(CertificateProviderPluginInstance.getDefaultInstance()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("verify_certificate_spki in default_validation_context is not " - + "supported"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false)); + assertThat(e).hasMessageThat().isEqualTo( + "verify_certificate_spki in default_validation_context is not supported"); } @Test - @SuppressWarnings("deprecation") public void validateCommonTlsContext_combinedValContextWithDefaultValContextVerifyCertHash() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .setCombinedValidationContext( CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) - .setDefaultValidationContext( - CertificateValidationContext.newBuilder().addVerifyCertificateHash("foo"))) - .setTlsCertificateCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) + .setDefaultValidationContext(CertificateValidationContext.newBuilder() + .setCaCertificateProviderInstance( + CertificateProviderPluginInstance.getDefaultInstance()) + .addVerifyCertificateHash("foo"))) + .setTlsCertificateProviderInstance(CertificateProviderPluginInstance.getDefaultInstance()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("verify_certificate_hash in default_validation_context is not " - + "supported"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false)); + assertThat(e).hasMessageThat().isEqualTo( + "verify_certificate_hash in default_validation_context is not supported"); } @Test - @SuppressWarnings("deprecation") public void validateCommonTlsContext_combinedValContextDfltValContextRequireSignedCertTimestamp() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .setCombinedValidationContext( CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) .setDefaultValidationContext(CertificateValidationContext.newBuilder() + .setCaCertificateProviderInstance( + CertificateProviderPluginInstance.getDefaultInstance()) .setRequireSignedCertificateTimestamp(BoolValue.of(true)))) - .setTlsCertificateCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) + .setTlsCertificateProviderInstance( + CertificateProviderPluginInstance.getDefaultInstance()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false)); + assertThat(e).hasMessageThat().isEqualTo( "require_signed_certificate_timestamp in default_validation_context is not " + "supported"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false); } @Test - @SuppressWarnings("deprecation") public void validateCommonTlsContext_combinedValidationContextWithDefaultValidationContextCrl() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .setCombinedValidationContext( CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) .setDefaultValidationContext(CertificateValidationContext.newBuilder() + .setCaCertificateProviderInstance( + CertificateProviderPluginInstance.getDefaultInstance()) .setCrl(DataSource.getDefaultInstance()))) - .setTlsCertificateCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) + .setTlsCertificateProviderInstance(CertificateProviderPluginInstance.getDefaultInstance()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("crl in default_validation_context is not supported"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false)); + assertThat(e).hasMessageThat().isEqualTo("crl in default_validation_context is not supported"); } @Test - @SuppressWarnings("deprecation") public void validateCommonTlsContext_combinedValContextWithDfltValContextCustomValidatorConfig() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .setCombinedValidationContext( CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) .setDefaultValidationContext(CertificateValidationContext.newBuilder() + .setCaCertificateProviderInstance( + CertificateProviderPluginInstance.getDefaultInstance()) .setCustomValidatorConfig(TypedExtensionConfig.getDefaultInstance()))) - .setTlsCertificateCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) + .setTlsCertificateProviderInstance(CertificateProviderPluginInstance.getDefaultInstance()) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("custom_validator_config in default_validation_context is not " - + "supported"); - XdsClusterResource.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false)); + assertThat(e).hasMessageThat().isEqualTo( + "custom_validator_config in default_validation_context is not supported"); } @Test public void validateDownstreamTlsContext_noCommonTlsContext() throws ResourceInvalidException { DownstreamTlsContext downstreamTlsContext = DownstreamTlsContext.getDefaultInstance(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("common-tls-context is required in downstream-tls-context"); - XdsListenerResource.validateDownstreamTlsContext(downstreamTlsContext, null); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.validateDownstreamTlsContext(downstreamTlsContext, null)); + assertThat(e).hasMessageThat().isEqualTo( + "common-tls-context is required in downstream-tls-context"); } @Test - @SuppressWarnings("deprecation") public void validateDownstreamTlsContext_hasRequireSni() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .setCombinedValidationContext( CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance())) - .setTlsCertificateCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) + .setDefaultValidationContext(CertificateValidationContext.newBuilder() + .setCaCertificateProviderInstance( + CertificateProviderPluginInstance.getDefaultInstance()))) + .setTlsCertificateProviderInstance(CertificateProviderPluginInstance.getDefaultInstance()) .build(); DownstreamTlsContext downstreamTlsContext = DownstreamTlsContext.newBuilder() .setCommonTlsContext(commonTlsContext) .setRequireSni(BoolValue.of(true)) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("downstream-tls-context with require-sni is not supported"); - XdsListenerResource.validateDownstreamTlsContext(downstreamTlsContext, ImmutableSet.of("")); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.validateDownstreamTlsContext(downstreamTlsContext, + ImmutableSet.of(""))); + assertThat(e).hasMessageThat().isEqualTo( + "downstream-tls-context with require-sni is not supported"); } @Test - @SuppressWarnings("deprecation") public void validateDownstreamTlsContext_hasOcspStaplePolicy() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .setCombinedValidationContext( CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance())) - .setTlsCertificateCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) + .setDefaultValidationContext(CertificateValidationContext.newBuilder() + .setCaCertificateProviderInstance( + CertificateProviderPluginInstance.getDefaultInstance()))) + .setTlsCertificateProviderInstance(CertificateProviderPluginInstance.getDefaultInstance()) .build(); DownstreamTlsContext downstreamTlsContext = DownstreamTlsContext.newBuilder() .setCommonTlsContext(commonTlsContext) .setOcspStaplePolicy(DownstreamTlsContext.OcspStaplePolicy.STRICT_STAPLING) .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsListenerResource.validateDownstreamTlsContext(downstreamTlsContext, + ImmutableSet.of(""))); + assertThat(e).hasMessageThat().isEqualTo( "downstream-tls-context with ocsp_staple_policy value STRICT_STAPLING is not supported"); - XdsListenerResource.validateDownstreamTlsContext(downstreamTlsContext, ImmutableSet.of("")); } @Test public void validateUpstreamTlsContext_noCommonTlsContext() throws ResourceInvalidException { UpstreamTlsContext upstreamTlsContext = UpstreamTlsContext.getDefaultInstance(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("common-tls-context is required in upstream-tls-context"); - XdsClusterResource.validateUpstreamTlsContext(upstreamTlsContext, null); + ResourceInvalidException e = assertThrows(ResourceInvalidException.class, () -> + XdsClusterResource.validateUpstreamTlsContext(upstreamTlsContext, null)); + assertThat(e).hasMessageThat().isEqualTo( + "common-tls-context is required in upstream-tls-context"); } @Test @@ -2870,7 +3588,7 @@ public void canonifyResourceName() { /** * Tests compliance with RFC 3986 section 3.3 - * https://datatracker.ietf.org/doc/html/rfc3986#section-3.3 + * https://datatracker.ietf.org/doc/html/rfc3986#section-3.3 . */ @Test public void percentEncodePath() { @@ -2910,4 +3628,10 @@ private static Filter buildHttpConnectionManagerFilter(HttpFilter... httpFilters "type.googleapis.com")) .build(); } + + private XdsResourceType.Args getXdsResourceTypeArgs(boolean isTrustedServer) { + return new XdsResourceType.Args( + ServerInfo.create("http://td", "", false, isTrustedServer, false, false), "1.0", null, XdsTestUtils.EMPTY_BOOTSTRAP, null, null + ); + } } diff --git a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java index 97c82731cf5..af55e572811 100644 --- a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java +++ b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java @@ -18,14 +18,16 @@ import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertWithMessage; -import static io.grpc.xds.GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY; +import static io.grpc.StatusMatcher.statusHasCode; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; @@ -41,6 +43,7 @@ import com.google.protobuf.Duration; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; +import com.google.protobuf.StringValue; import com.google.protobuf.UInt32Value; import com.google.protobuf.util.Durations; import io.envoyproxy.envoy.config.cluster.v3.OutlierDetection; @@ -48,7 +51,6 @@ import io.envoyproxy.envoy.config.route.v3.WeightedCluster; import io.envoyproxy.envoy.extensions.filters.http.router.v3.Router; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance; -import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.grpc.BindableService; import io.grpc.ChannelCredentials; import io.grpc.Context; @@ -58,6 +60,8 @@ import io.grpc.Server; import io.grpc.Status; import io.grpc.Status.Code; +import io.grpc.StatusOr; +import io.grpc.StatusOrMatcher; import io.grpc.inprocess.InProcessChannelBuilder; import io.grpc.inprocess.InProcessServerBuilder; import io.grpc.internal.BackoffPolicy; @@ -87,6 +91,7 @@ import io.grpc.xds.client.Bootstrapper.BootstrapInfo; import io.grpc.xds.client.Bootstrapper.CertificateProviderInfo; import io.grpc.xds.client.Bootstrapper.ServerInfo; +import io.grpc.xds.client.BootstrapperImpl; import io.grpc.xds.client.EnvoyProtoData.Node; import io.grpc.xds.client.LoadStatsManager2.ClusterDropStats; import io.grpc.xds.client.Locality; @@ -95,7 +100,9 @@ import io.grpc.xds.client.XdsClient.ResourceMetadata.UpdateFailureState; import io.grpc.xds.client.XdsClient.ResourceUpdate; import io.grpc.xds.client.XdsClient.ResourceWatcher; +import io.grpc.xds.client.XdsClient.ServerConnectionCallback; import io.grpc.xds.client.XdsClientImpl; +import io.grpc.xds.client.XdsClientMetricReporter; import io.grpc.xds.client.XdsResourceType; import io.grpc.xds.client.XdsResourceType.ResourceInvalidException; import io.grpc.xds.client.XdsTransportFactory; @@ -107,11 +114,13 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Queue; import java.util.concurrent.BlockingDeque; import java.util.concurrent.CountDownLatch; import java.util.concurrent.CyclicBarrier; import java.util.concurrent.Executor; +import java.util.concurrent.Future; import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; @@ -124,7 +133,6 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; -import org.mockito.ArgumentMatchers; import org.mockito.Captor; import org.mockito.InOrder; import org.mockito.Mock; @@ -133,6 +141,7 @@ import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; import org.mockito.stubbing.Answer; +import org.mockito.verification.VerificationMode; /** * Tests for {@link XdsClientImpl}. @@ -141,8 +150,9 @@ // The base class was used to test both xds v2 and v3. V2 is dropped now so the base class is not // necessary. Still keep it for future version usage. Remove if too much trouble to maintain. public abstract class GrpcXdsClientImplTestBase { + private static final String SERVER_URI = "trafficdirector.googleapis.com"; - private static final String SERVER_URI_CUSTOME_AUTHORITY = "trafficdirector2.googleapis.com"; + private static final String SERVER_URI_CUSTOM_AUTHORITY = "trafficdirector2.googleapis.com"; private static final String SERVER_URI_EMPTY_AUTHORITY = "trafficdirector3.googleapis.com"; private static final String LDS_RESOURCE = "listener.googleapis.com"; private static final String RDS_RESOURCE = "route-configuration.googleapis.com"; @@ -216,7 +226,7 @@ public boolean shouldAccept(Runnable command) { protected final Queue loadReportCalls = new ArrayDeque<>(); protected final AtomicBoolean adsEnded = new AtomicBoolean(true); protected final AtomicBoolean lrsEnded = new AtomicBoolean(true); - private final MessageFactory mf = createMessageFactory(); + protected MessageFactory mf; private static final long TIME_INCREMENT = TimeUnit.SECONDS.toNanos(1); /** Fake time provider increments time TIME_INCREMENT each call. */ @@ -230,47 +240,30 @@ public long currentTimeNanos() { private static final int VHOST_SIZE = 2; // LDS test resources. - private final Any testListenerVhosts = Any.pack(mf.buildListenerWithApiListener(LDS_RESOURCE, - mf.buildRouteConfiguration("do not care", mf.buildOpaqueVirtualHosts(VHOST_SIZE)))); - private final Any testListenerRds = - Any.pack(mf.buildListenerWithApiListenerForRds(LDS_RESOURCE, RDS_RESOURCE)); + private Any testListenerVhosts; + private Any testListenerRds; // RDS test resources. - private final Any testRouteConfig = - Any.pack(mf.buildRouteConfiguration(RDS_RESOURCE, mf.buildOpaqueVirtualHosts(VHOST_SIZE))); + private Any testRouteConfig; // CDS test resources. - private final Any testClusterRoundRobin = - Any.pack(mf.buildEdsCluster(CDS_RESOURCE, null, "round_robin", null, - null, false, null, "envoy.transport_sockets.tls", null, null - )); + private Any testClusterRoundRobin; // EDS test resources. - private final Message lbEndpointHealthy = - mf.buildLocalityLbEndpoints("region1", "zone1", "subzone1", - mf.buildLbEndpoint("192.168.0.1", 8080, "healthy", 2), 1, 0); + private Message lbEndpointHealthy; // Locality with 0 endpoints - private final Message lbEndpointEmpty = - mf.buildLocalityLbEndpoints("region3", "zone3", "subzone3", - ImmutableList.of(), 2, 1); + private Message lbEndpointEmpty; // Locality with 0-weight endpoint - private final Message lbEndpointZeroWeight = - mf.buildLocalityLbEndpoints("region4", "zone4", "subzone4", - mf.buildLbEndpoint("192.168.142.5", 80, "unknown", 5), 0, 2); - private final Any testClusterLoadAssignment = Any.pack(mf.buildClusterLoadAssignment(EDS_RESOURCE, - ImmutableList.of(lbEndpointHealthy, lbEndpointEmpty, lbEndpointZeroWeight), - ImmutableList.of(mf.buildDropOverload("lb", 200), mf.buildDropOverload("throttle", 1000)))); - - @Captor - private ArgumentCaptor ldsUpdateCaptor; + private Message lbEndpointZeroWeight; + private Any testClusterLoadAssignment; @Captor - private ArgumentCaptor rdsUpdateCaptor; + private ArgumentCaptor> ldsUpdateCaptor; @Captor - private ArgumentCaptor cdsUpdateCaptor; + private ArgumentCaptor> rdsUpdateCaptor; @Captor - private ArgumentCaptor edsUpdateCaptor; + private ArgumentCaptor> cdsUpdateCaptor; @Captor - private ArgumentCaptor errorCaptor; + private ArgumentCaptor> edsUpdateCaptor; @Mock private BackoffPolicy.Provider backoffPolicyProvider; @@ -281,11 +274,19 @@ public long currentTimeNanos() { @Mock private ResourceWatcher ldsResourceWatcher; @Mock + private ResourceWatcher ldsResourceWatcher2; + @Mock private ResourceWatcher rdsResourceWatcher; @Mock private ResourceWatcher cdsResourceWatcher; @Mock private ResourceWatcher edsResourceWatcher; + @Mock + private ResourceWatcher stringResourceWatcher; + @Mock + private XdsClientMetricReporter xdsClientMetricReporter; + @Mock + private ServerConnectionCallback serverConnectionCallback; private ManagedChannel channel; private ManagedChannel channelForCustomAuthority; @@ -294,18 +295,68 @@ public long currentTimeNanos() { private boolean originalEnableLeastRequest; private Server xdsServer; private final String serverName = InProcessServerBuilder.generateName(); - private final BindableService adsService = createAdsService(); - private final BindableService lrsService = createLrsService(); + private BindableService adsService; + private BindableService lrsService; + + private XdsTransportFactory xdsTransportFactory = new XdsTransportFactory() { + @Override + public XdsTransport create(ServerInfo serverInfo) { + if (serverInfo.target().equals(SERVER_URI)) { + return new GrpcXdsTransport(channel); + } + if (serverInfo.target().equals(SERVER_URI_CUSTOM_AUTHORITY)) { + if (channelForCustomAuthority == null) { + channelForCustomAuthority = cleanupRule.register( + InProcessChannelBuilder.forName(serverName).directExecutor().build()); + } + return new GrpcXdsTransport(channelForCustomAuthority); + } + if (serverInfo.target().equals(SERVER_URI_EMPTY_AUTHORITY)) { + if (channelForEmptyAuthority == null) { + channelForEmptyAuthority = cleanupRule.register( + InProcessChannelBuilder.forName(serverName).directExecutor().build()); + } + return new GrpcXdsTransport(channelForEmptyAuthority); + } + throw new IllegalArgumentException("Can not create channel for " + serverInfo); + } + }; @Before public void setUp() throws IOException { + mf = createMessageFactory(); + testListenerVhosts = Any.pack(mf.buildListenerWithApiListener(LDS_RESOURCE, + mf.buildRouteConfiguration("do not care", mf.buildOpaqueVirtualHosts(VHOST_SIZE)))); + testListenerRds = + Any.pack(mf.buildListenerWithApiListenerForRds(LDS_RESOURCE, RDS_RESOURCE)); + testRouteConfig = + Any.pack(mf.buildRouteConfiguration(RDS_RESOURCE, mf.buildOpaqueVirtualHosts(VHOST_SIZE))); + testClusterRoundRobin = + Any.pack(mf.buildEdsCluster(CDS_RESOURCE, null, "round_robin", null, + null, false, null, "envoy.transport_sockets.tls", null, null + )); + lbEndpointHealthy = + mf.buildLocalityLbEndpoints("region1", "zone1", "subzone1", + mf.buildLbEndpoint("192.168.0.1", 8080, "healthy", 2, "endpoint-host-name"), 1, 0); + lbEndpointEmpty = + mf.buildLocalityLbEndpoints("region3", "zone3", "subzone3", + ImmutableList.of(), 2, 1); + lbEndpointZeroWeight = + mf.buildLocalityLbEndpoints("region4", "zone4", "subzone4", + mf.buildLbEndpoint("192.168.142.5", 80, "unknown", 5, "endpoint-host-name"), 0, 2); + testClusterLoadAssignment = Any.pack(mf.buildClusterLoadAssignment(EDS_RESOURCE, + ImmutableList.of(lbEndpointHealthy, lbEndpointEmpty, lbEndpointZeroWeight), + ImmutableList.of(mf.buildDropOverload("lb", 200), mf.buildDropOverload("throttle", 1000)))); + adsService = createAdsService(); + lrsService = createLrsService(); + when(backoffPolicyProvider.get()).thenReturn(backoffPolicy1, backoffPolicy2); when(backoffPolicy1.nextBackoffNanos()).thenReturn(10L, 100L); when(backoffPolicy2.nextBackoffNanos()).thenReturn(20L, 200L); // Start the server and the client. - originalEnableLeastRequest = XdsResourceType.enableLeastRequest; - XdsResourceType.enableLeastRequest = true; + originalEnableLeastRequest = XdsClusterResource.enableLeastRequest; + XdsClusterResource.enableLeastRequest = true; xdsServer = cleanupRule.register(InProcessServerBuilder .forName(serverName) .addService(adsService) @@ -315,32 +366,9 @@ public void setUp() throws IOException { .start()); channel = cleanupRule.register(InProcessChannelBuilder.forName(serverName).directExecutor().build()); - XdsTransportFactory xdsTransportFactory = new XdsTransportFactory() { - @Override - public XdsTransport create(ServerInfo serverInfo) { - if (serverInfo.target().equals(SERVER_URI)) { - return new GrpcXdsTransport(channel); - } - if (serverInfo.target().equals(SERVER_URI_CUSTOME_AUTHORITY)) { - if (channelForCustomAuthority == null) { - channelForCustomAuthority = cleanupRule.register( - InProcessChannelBuilder.forName(serverName).directExecutor().build()); - } - return new GrpcXdsTransport(channelForCustomAuthority); - } - if (serverInfo.target().equals(SERVER_URI_EMPTY_AUTHORITY)) { - if (channelForEmptyAuthority == null) { - channelForEmptyAuthority = cleanupRule.register( - InProcessChannelBuilder.forName(serverName).directExecutor().build()); - } - return new GrpcXdsTransport(channelForEmptyAuthority); - } - throw new IllegalArgumentException("Can not create channel for " + serverInfo); - } - }; - xdsServerInfo = ServerInfo.create(SERVER_URI, CHANNEL_CREDENTIALS, - ignoreResourceDeletion()); + xdsServerInfo = ServerInfo.create(SERVER_URI, CHANNEL_CREDENTIALS, ignoreResourceDeletion(), + true, false, false); BootstrapInfo bootstrapInfo = Bootstrapper.BootstrapInfo.builder() .servers(Collections.singletonList(xdsServerInfo)) @@ -350,7 +378,7 @@ public XdsTransport create(ServerInfo serverInfo) { AuthorityInfo.create( "xdstp://authority.xds.com/envoy.config.listener.v3.Listener/%s", ImmutableList.of(Bootstrapper.ServerInfo.create( - SERVER_URI_CUSTOME_AUTHORITY, CHANNEL_CREDENTIALS))), + SERVER_URI_CUSTOM_AUTHORITY, CHANNEL_CREDENTIALS))), "", AuthorityInfo.create( "xdstp:///envoy.config.listener.v3.Listener/%s", @@ -368,7 +396,8 @@ public XdsTransport create(ServerInfo serverInfo) { fakeClock.getStopwatchSupplier(), timeProvider, MessagePrinter.INSTANCE, - new TlsContextManagerImpl(bootstrapInfo)); + new TlsContextManagerImpl(bootstrapInfo), + xdsClientMetricReporter); assertThat(resourceDiscoveryCalls).isEmpty(); assertThat(loadReportCalls).isEmpty(); @@ -376,7 +405,7 @@ public XdsTransport create(ServerInfo serverInfo) { @After public void tearDown() { - XdsResourceType.enableLeastRequest = originalEnableLeastRequest; + XdsClusterResource.enableLeastRequest = originalEnableLeastRequest; xdsClient.shutdown(); channel.shutdown(); // channel not owned by XdsClient assertThat(adsEnded.get()).isTrue(); @@ -475,10 +504,11 @@ private void verifyResourceMetadataAcked( private void verifyResourceMetadataNacked( XdsResourceType type, String resourceName, Any rawResource, String versionInfo, long updateTime, String failedVersion, long failedUpdateTimeNanos, - List failedDetails) { + List failedDetails, boolean cached) { ResourceMetadata resourceMetadata = verifyResourceMetadata(type, resourceName, rawResource, ResourceMetadataStatus.NACKED, versionInfo, updateTime, true); + assertThat(resourceMetadata.isCached()).isEqualTo(cached); UpdateFailureState errorState = resourceMetadata.getErrorState(); assertThat(errorState).isNotNull(); @@ -595,9 +625,104 @@ private void validateGoldenClusterLoadAssignment(EdsUpdate edsUpdate) { .containsExactly( Locality.create("region1", "zone1", "subzone1"), LocalityLbEndpoints.create( - ImmutableList.of(LbEndpoint.create("192.168.0.1", 8080, 2, true)), 1, 0), + ImmutableList.of(LbEndpoint.create("192.168.0.1", 8080, 2, true, + "endpoint-host-name", ImmutableMap.of())), 1, 0, ImmutableMap.of()), Locality.create("region3", "zone3", "subzone3"), - LocalityLbEndpoints.create(ImmutableList.of(), 2, 1)); + LocalityLbEndpoints.create(ImmutableList.of(), 2, 1, ImmutableMap.of())); + } + + /** + * Verifies that the {@link XdsClientMetricReporter#reportResourceUpdates} method has been called + * the expected number of times with the expected values for valid resource count, invalid + * resource count, and corresponding metric labels. + */ + private void verifyResourceValidInvalidCount(int times, long validResourceCount, + long invalidResourceCount, String xdsServerTargetLabel, + String resourceType) { + verify(xdsClientMetricReporter, times(times)).reportResourceUpdates( + eq(validResourceCount), + eq(invalidResourceCount), + eq(xdsServerTargetLabel), + eq(resourceType)); + } + + private void verifyServerFailureCount(int times, long serverFailureCount, String xdsServer) { + verify(xdsClientMetricReporter, times(times)).reportServerFailure( + eq(serverFailureCount), + eq(xdsServer)); + } + + /** + * Invokes the callback, which will be called by {@link XdsClientMetricReporter} to record + * whether XdsClient has a working ADS stream. + */ + private void callback_ReportServerConnection() { + try { + Future unused = xdsClient.reportServerConnections(serverConnectionCallback); + } catch (Exception e) { + if (e instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + throw new AssertionError(e); + } + } + + private void verifyServerConnection(int times, boolean isConnected, String xdsServer) { + verify(serverConnectionCallback, times(times)).reportServerConnectionGauge( + eq(isConnected), + eq(xdsServer)); + } + + @Test + public void doParse_returnsSuccessfully() { + XdsStringResource resourceType = new XdsStringResource(); + xdsClient.watchXdsResource( + resourceType, "resource1", stringResourceWatcher, MoreExecutors.directExecutor()); + DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); + + Any resource = Any.pack(StringValue.newBuilder().setValue("resource1").build()); + call.sendResponse(resourceType, resource, VERSION_1, "0000"); + verify(stringResourceWatcher).onResourceChanged(argThat(StatusOrMatcher.hasValue( + (StringUpdate arg) -> new StringUpdate("resource1").equals(arg)))); + } + + @Test + public void doParse_throwsResourceInvalidException_resourceInvalid() { + XdsStringResource resourceType = new XdsStringResource() { + @Override + protected StringUpdate doParse(Args args, Message unpackedMessage) + throws ResourceInvalidException { + throw new ResourceInvalidException("some bad input"); + } + }; + xdsClient.watchXdsResource( + resourceType, "resource1", stringResourceWatcher, MoreExecutors.directExecutor()); + DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); + + Any resource = Any.pack(StringValue.newBuilder().setValue("resource1").build()); + call.sendResponse(resourceType, resource, VERSION_1, "0000"); + verify(stringResourceWatcher).onResourceChanged(argThat(StatusOrMatcher.hasStatus( + statusHasCode(Status.Code.UNAVAILABLE) + .andDescriptionContains("validation error: some bad input")))); + } + + @Test + public void doParse_throwsError_resourceInvalid() throws Exception { + XdsStringResource resourceType = new XdsStringResource() { + @Override + protected StringUpdate doParse(Args args, Message unpackedMessage) { + throw new AssertionError("something bad happened"); + } + }; + xdsClient.watchXdsResource( + resourceType, "resource1", stringResourceWatcher, MoreExecutors.directExecutor()); + DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); + + Any resource = Any.pack(StringValue.newBuilder().setValue("resource1").build()); + call.sendResponse(resourceType, resource, VERSION_1, "0000"); + verify(stringResourceWatcher).onResourceChanged(argThat(StatusOrMatcher.hasStatus( + statusHasCode(Status.Code.UNAVAILABLE) + .andDescriptionContains("unexpected error: AssertionError: something bad happened")))); } @Test @@ -616,9 +741,13 @@ public void ldsResourceNotFound() { verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); // Server failed to return subscribed resource within expected time window. fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); - verify(ldsResourceWatcher).onResourceDoesNotExist(LDS_RESOURCE); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isFalse(); + assertThat(statusOrUpdate.getStatus().getCode()).isEqualTo(Status.Code.NOT_FOUND); assertThat(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataDoesNotExist(LDS, LDS_RESOURCE); + // Check metric data. verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); } @@ -626,16 +755,41 @@ public void ldsResourceNotFound() { public void ldsResourceUpdated_withXdstpResourceName_withUnknownAuthority() { String ldsResourceName = "xdstp://unknown.example.com/envoy.config.listener.v3.Listener/listener1"; - xdsClient.watchXdsResource(XdsListenerResource.getInstance(),ldsResourceName, + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), ldsResourceName, + ldsResourceWatcher); + verify(ldsResourceWatcher).onResourceChanged(argThat(statusOr -> + !statusOr.hasValue() + && statusOr.getStatus().getCode() == Status.Code.INVALID_ARGUMENT + && statusOr.getStatus().getDescription().equals( + "Wrong configuration: xds server does not exist for resource " + ldsResourceName))); + assertThat(resourceDiscoveryCalls.poll()).isNull(); + xdsClient.cancelXdsResourceWatch(XdsListenerResource.getInstance(), ldsResourceName, ldsResourceWatcher); - verify(ldsResourceWatcher).onError(errorCaptor.capture()); - Status error = errorCaptor.getValue(); - assertThat(error.getCode()).isEqualTo(Code.INVALID_ARGUMENT); - assertThat(error.getDescription()).isEqualTo( - "Wrong configuration: xds server does not exist for resource " + ldsResourceName); assertThat(resourceDiscoveryCalls.poll()).isNull(); - xdsClient.cancelXdsResourceWatch(XdsListenerResource.getInstance(),ldsResourceName, + } + + @Test + public void ldsResource_onError_cachedForNewWatcher() { + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); + DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); + call.sendCompleted(); + @SuppressWarnings("unchecked") + ArgumentCaptor> errorCaptor = + ArgumentCaptor.forClass(StatusOr.class); + verify(ldsResourceWatcher, timeout(1000)).onResourceChanged(errorCaptor.capture()); + StatusOr initialError = errorCaptor.getValue(); + assertThat(initialError.hasValue()).isFalse(); + + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), LDS_RESOURCE, + ldsResourceWatcher2); + @SuppressWarnings("unchecked") + ArgumentCaptor> secondErrorCaptor = + ArgumentCaptor.forClass(StatusOr.class); + verify(ldsResourceWatcher2, timeout(1000)).onResourceChanged(secondErrorCaptor.capture()); + StatusOr cachedError = secondErrorCaptor.getValue(); + + assertThat(cachedError).isEqualTo(initialError); assertThat(resourceDiscoveryCalls.poll()).isNull(); } @@ -674,21 +828,22 @@ public void ldsResponseErrorHandling_someResourcesFailedUnpack() { verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); // The response is NACKed with the same error message. call.verifyRequestNack(LDS, LDS_RESOURCE, "", "0000", NODE, errors); - verify(ldsResourceWatcher).onChanged(any(LdsUpdate.class)); + verify(ldsResourceWatcher).onResourceChanged(any()); } /** * Tests a subscribed LDS resource transitioned to and from the invalid state. * - * @see - * A40-csds-support.md + * @see + * A40-csds-support.md */ @Test public void ldsResponseErrorHandling_subscribedResourceInvalid() { List subscribedResourceNames = ImmutableList.of("A", "B", "C"); - xdsClient.watchXdsResource(XdsListenerResource.getInstance(),"A", ldsResourceWatcher); - xdsClient.watchXdsResource(XdsListenerResource.getInstance(),"B", ldsResourceWatcher); - xdsClient.watchXdsResource(XdsListenerResource.getInstance(),"C", ldsResourceWatcher); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), "A", ldsResourceWatcher); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), "B", ldsResourceWatcher); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), "C", ldsResourceWatcher); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); assertThat(call).isNotNull(); verifyResourceMetadataRequested(LDS, "A"); @@ -706,6 +861,8 @@ public void ldsResponseErrorHandling_subscribedResourceInvalid() { verifyResourceMetadataAcked(LDS, "A", resourcesV1.get("A"), VERSION_1, TIME_INCREMENT); verifyResourceMetadataAcked(LDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT); verifyResourceMetadataAcked(LDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); + // Check metric data. + verifyResourceValidInvalidCount(1, 3, 0, xdsServerInfo.target(), LDS.typeUrl()); call.verifyRequest(LDS, subscribedResourceNames, VERSION_1, "0000", NODE); // LDS -> {A, B}, version 2 @@ -720,7 +877,9 @@ public void ldsResponseErrorHandling_subscribedResourceInvalid() { List errorsV2 = ImmutableList.of("LDS response Listener 'B' validation error: "); verifyResourceMetadataAcked(LDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 2); verifyResourceMetadataNacked(LDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT, - VERSION_2, TIME_INCREMENT * 2, errorsV2); + VERSION_2, TIME_INCREMENT * 2, errorsV2, true); + // Check metric data. + verifyResourceValidInvalidCount(1, 1, 1, xdsServerInfo.target(), LDS.typeUrl()); if (!ignoreResourceDeletion()) { verifyResourceMetadataDoesNotExist(LDS, "C"); } else { @@ -736,6 +895,8 @@ public void ldsResponseErrorHandling_subscribedResourceInvalid() { call.sendResponse(LDS, resourcesV3.values().asList(), VERSION_3, "0002"); // {A} -> does not exist // {B, C} -> ACK, version 3 + // Check metric data. + verifyResourceValidInvalidCount(1, 2, 0, xdsServerInfo.target(), LDS.typeUrl()); if (!ignoreResourceDeletion()) { verifyResourceMetadataDoesNotExist(LDS, "A"); } else { @@ -748,15 +909,61 @@ public void ldsResponseErrorHandling_subscribedResourceInvalid() { verifySubscribedResourcesMetadataSizes(3, 0, 0, 0); } + @Test + public void ldsResponseErrorHandling_subscribedResourceInvalid_withDataErrorHandlingEnabled() { + BootstrapperImpl.xdsDataErrorHandlingEnabled = true; + + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), "A", ldsResourceWatcher); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), "B", ldsResourceWatcher); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), "C", ldsResourceWatcher); + DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); + assertThat(call).isNotNull(); + verifyResourceMetadataRequested(LDS, "A"); + verifyResourceMetadataRequested(LDS, "B"); + verifyResourceMetadataRequested(LDS, "C"); + ImmutableMap resourcesV1 = ImmutableMap.of( + "A", Any.pack(mf.buildListenerWithApiListenerForRds("A", "A.1")), + "B", Any.pack(mf.buildListenerWithApiListenerForRds("B", "B.1")), + "C", Any.pack(mf.buildListenerWithApiListenerForRds("C", "C.1"))); + call.sendResponse(LDS, resourcesV1.values().asList(), VERSION_1, "0000"); + verify(ldsResourceWatcher, times(3)).onResourceChanged(any()); + ImmutableMap resourcesV2 = ImmutableMap.of( + "A", Any.pack(mf.buildListenerWithApiListenerForRds("A", "A.2")), + "B", Any.pack(mf.buildListenerWithApiListenerInvalid("B"))); + call.sendResponse(LDS, resourcesV2.values().asList(), VERSION_2, "0001"); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + verify(ldsResourceWatcher, times(2)).onAmbientError(statusCaptor.capture()); + List receivedStatuses = statusCaptor.getAllValues(); + assertThat(receivedStatuses).hasSize(2); + + assertThat( + receivedStatuses.stream().anyMatch( + status -> status.getCode() == Status.Code.UNAVAILABLE + && status.getDescription().contains("LDS response Listener 'B' validation error"))) + .isTrue(); + assertThat( + receivedStatuses.stream().anyMatch( + status -> status.getCode() == Status.Code.NOT_FOUND + && status.getDescription().contains("Resource C deleted from server"))) + .isTrue(); + List errorsV2 = ImmutableList.of("LDS response Listener 'B' validation error: "); + verifyResourceMetadataAcked(LDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 2); + verifyResourceMetadataNacked(LDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT, + VERSION_2, TIME_INCREMENT * 2, errorsV2, true); + verifyResourceMetadataAcked(LDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); + + BootstrapperImpl.xdsDataErrorHandlingEnabled = false; + } + @Test public void ldsResponseErrorHandling_subscribedResourceInvalid_withRdsSubscription() { List subscribedResourceNames = ImmutableList.of("A", "B", "C"); - xdsClient.watchXdsResource(XdsListenerResource.getInstance(),"A", ldsResourceWatcher); - xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),"A.1", rdsResourceWatcher); - xdsClient.watchXdsResource(XdsListenerResource.getInstance(),"B", ldsResourceWatcher); - xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),"B.1", rdsResourceWatcher); - xdsClient.watchXdsResource(XdsListenerResource.getInstance(),"C", ldsResourceWatcher); - xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),"C.1", rdsResourceWatcher); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), "A", ldsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), "A.1", rdsResourceWatcher); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), "B", ldsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), "B.1", rdsResourceWatcher); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), "C", ldsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), "C.1", rdsResourceWatcher); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); assertThat(call).isNotNull(); verifyResourceMetadataRequested(LDS, "A"); @@ -774,6 +981,7 @@ public void ldsResponseErrorHandling_subscribedResourceInvalid_withRdsSubscripti "C", Any.pack(mf.buildListenerWithApiListenerForRds("C", "C.1"))); call.sendResponse(LDS, resourcesV1.values().asList(), VERSION_1, "0000"); // {A, B, C} -> ACK, version 1 + verifyResourceValidInvalidCount(1, 3, 0, xdsServerInfo.target(), LDS.typeUrl()); verifyResourceMetadataAcked(LDS, "A", resourcesV1.get("A"), VERSION_1, TIME_INCREMENT); verifyResourceMetadataAcked(LDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT); verifyResourceMetadataAcked(LDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); @@ -790,6 +998,8 @@ public void ldsResponseErrorHandling_subscribedResourceInvalid_withRdsSubscripti verifyResourceMetadataAcked(RDS, "A.1", resourcesV11.get("A.1"), VERSION_1, TIME_INCREMENT * 2); verifyResourceMetadataAcked(RDS, "B.1", resourcesV11.get("B.1"), VERSION_1, TIME_INCREMENT * 2); verifyResourceMetadataAcked(RDS, "C.1", resourcesV11.get("C.1"), VERSION_1, TIME_INCREMENT * 2); + // Check metric data. + verifyResourceValidInvalidCount(1, 3, 0, xdsServerInfo.target(), RDS.typeUrl()); // LDS -> {A, B}, version 2 // Failed to parse endpoint B @@ -800,11 +1010,13 @@ public void ldsResponseErrorHandling_subscribedResourceInvalid_withRdsSubscripti // {A} -> ACK, version 2 // {B} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B // {C} -> does not exist + // Check metric data. + verifyResourceValidInvalidCount(1, 1, 1, xdsServerInfo.target(), LDS.typeUrl()); List errorsV2 = ImmutableList.of("LDS response Listener 'B' validation error: "); verifyResourceMetadataAcked(LDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 3); verifyResourceMetadataNacked( LDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT, VERSION_2, TIME_INCREMENT * 3, - errorsV2); + errorsV2, true); if (!ignoreResourceDeletion()) { verifyResourceMetadataDoesNotExist(LDS, "C"); } else { @@ -830,8 +1042,10 @@ public void ldsResourceFound_containsVirtualHosts() { // Client sends an ACK LDS request. call.sendResponse(LDS, testListenerVhosts, VERSION_1, "0000"); call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerVhosts(statusOrUpdate.getValue()); assertThat(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); @@ -845,13 +1059,36 @@ public void wrappedLdsResource() { // Client sends an ACK LDS request. call.sendResponse(LDS, mf.buildWrappedResource(testListenerVhosts), VERSION_1, "0000"); call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerVhosts(statusOrUpdate.getValue()); assertThat(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); } + @Test + public void wrappedLdsResource_preferWrappedName() { + DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, + ldsResourceWatcher); + + Any innerResource = Any.pack(mf.buildListenerWithApiListener("random_name" /* name */, + mf.buildRouteConfiguration("do not care", mf.buildOpaqueVirtualHosts(VHOST_SIZE)))); + + // Client sends an ACK LDS request. + call.sendResponse(LDS, mf.buildWrappedResourceWithName(innerResource, LDS_RESOURCE), VERSION_1, + "0000"); + call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerVhosts(statusOrUpdate.getValue()); + assertThat(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); + verifyResourceMetadataAcked(LDS, LDS_RESOURCE, innerResource, VERSION_1, TIME_INCREMENT); + verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); + } + @Test public void ldsResourceFound_containsRdsName() { DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, @@ -860,8 +1097,10 @@ public void ldsResourceFound_containsRdsName() { // Client sends an ACK LDS request. call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerRds(ldsUpdateCaptor.getValue()); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerRds(statusOrUpdate.getValue()); assertThat(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerRds, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); @@ -878,9 +1117,11 @@ public void cachedLdsResource_data() { call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); ResourceWatcher watcher = mock(ResourceWatcher.class); - xdsClient.watchXdsResource(XdsListenerResource.getInstance(),LDS_RESOURCE, watcher); - verify(watcher).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerRds(ldsUpdateCaptor.getValue()); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), LDS_RESOURCE, watcher); + verify(watcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerRds(statusOrUpdate.getValue()); call.verifyNoMoreRequest(); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerRds, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); @@ -892,11 +1133,17 @@ public void cachedLdsResource_absent() { DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); - verify(ldsResourceWatcher).onResourceDoesNotExist(LDS_RESOURCE); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isFalse(); + assertThat(statusOrUpdate.getStatus().getCode()).isEqualTo(Status.Code.NOT_FOUND); // Add another watcher. ResourceWatcher watcher = mock(ResourceWatcher.class); - xdsClient.watchXdsResource(XdsListenerResource.getInstance(),LDS_RESOURCE, watcher); - verify(watcher).onResourceDoesNotExist(LDS_RESOURCE); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), LDS_RESOURCE, watcher); + verify(watcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate1 = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate1.hasValue()).isFalse(); + assertThat(statusOrUpdate1.getStatus().getCode()).isEqualTo(Status.Code.NOT_FOUND); call.verifyNoMoreRequest(); verifyResourceMetadataDoesNotExist(LDS, LDS_RESOURCE); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); @@ -911,15 +1158,19 @@ public void ldsResourceUpdated() { // Initial LDS response. call.sendResponse(LDS, testListenerVhosts, VERSION_1, "0000"); call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerVhosts(statusOrUpdate.getValue()); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_1, TIME_INCREMENT); // Updated LDS response. call.sendResponse(LDS, testListenerRds, VERSION_2, "0001"); call.verifyRequest(LDS, LDS_RESOURCE, VERSION_2, "0001", NODE); - verify(ldsResourceWatcher, times(2)).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerRds(ldsUpdateCaptor.getValue()); + verify(ldsResourceWatcher, times(2)).onResourceChanged(ldsUpdateCaptor.capture()); + statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerRds(statusOrUpdate.getValue()); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerRds, VERSION_2, TIME_INCREMENT * 2); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); assertThat(channelForCustomAuthority).isNull(); @@ -935,8 +1186,10 @@ public void cancelResourceWatcherNotRemoveUrlSubscribers() { // Initial LDS response. call.sendResponse(LDS, testListenerVhosts, VERSION_1, "0000"); call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerVhosts(statusOrUpdate.getValue()); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_1, TIME_INCREMENT); xdsClient.watchXdsResource(XdsListenerResource.getInstance(), @@ -949,8 +1202,10 @@ public void cancelResourceWatcherNotRemoveUrlSubscribers() { mf.buildRouteConfiguration("new", mf.buildOpaqueVirtualHosts(VHOST_SIZE)))); call.sendResponse(LDS, testListenerVhosts2, VERSION_2, "0001"); call.verifyRequest(LDS, LDS_RESOURCE, VERSION_2, "0001", NODE); - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerVhosts(statusOrUpdate.getValue()); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts2, VERSION_2, TIME_INCREMENT * 2); } @@ -968,8 +1223,10 @@ public void ldsResourceUpdated_withXdstpResourceName() { mf.buildRouteConfiguration("do not care", mf.buildOpaqueVirtualHosts(VHOST_SIZE)))); call.sendResponse(LDS, testListenerVhosts, VERSION_1, "0000"); call.verifyRequest(LDS, ldsResourceName, VERSION_1, "0000", NODE); - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerVhosts(statusOrUpdate.getValue()); verifyResourceMetadataAcked( LDS, ldsResourceName, testListenerVhosts, VERSION_1, TIME_INCREMENT); } @@ -986,8 +1243,10 @@ public void ldsResourceUpdated_withXdstpResourceName_withEmptyAuthority() { mf.buildRouteConfiguration("do not care", mf.buildOpaqueVirtualHosts(VHOST_SIZE)))); call.sendResponse(LDS, testListenerVhosts, VERSION_1, "0000"); call.verifyRequest(LDS, ldsResourceName, VERSION_1, "0000", NODE); - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerVhosts(statusOrUpdate.getValue()); verifyResourceMetadataAcked( LDS, ldsResourceName, testListenerVhosts, VERSION_1, TIME_INCREMENT); } @@ -1027,7 +1286,7 @@ public void ldsResourceUpdated_withXdstpResourceName_withWrongType() { call.verifyRequestNack( LDS, ldsResourceName, "", "0000", NODE, ImmutableList.of( - "Unsupported resource name: " + ldsResourceNameWithWrongType + " for type: LDS")); + "Unsupported resource name: " + ldsResourceNameWithWrongType + " for type: LDS")); } @Test @@ -1053,16 +1312,20 @@ public void rdsResourceUpdated_withXdstpResourceName_withWrongType() { public void rdsResourceUpdated_withXdstpResourceName_unknownAuthority() { String rdsResourceName = "xdstp://unknown.example.com/envoy.config.route.v3.RouteConfiguration/route1"; - xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),rdsResourceName, + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), rdsResourceName, rdsResourceWatcher); - verify(rdsResourceWatcher).onError(errorCaptor.capture()); - Status error = errorCaptor.getValue(); + @SuppressWarnings("unchecked") + ArgumentCaptor> rdsUpdateCaptor = ArgumentCaptor.forClass(StatusOr.class); + verify(rdsResourceWatcher).onResourceChanged(rdsUpdateCaptor.capture()); + StatusOr capturedUpdate = rdsUpdateCaptor.getValue(); + assertThat(capturedUpdate.hasValue()).isFalse(); + Status error = capturedUpdate.getStatus(); assertThat(error.getCode()).isEqualTo(Code.INVALID_ARGUMENT); assertThat(error.getDescription()).isEqualTo( "Wrong configuration: xds server does not exist for resource " + rdsResourceName); assertThat(resourceDiscoveryCalls.size()).isEqualTo(0); xdsClient.cancelXdsResourceWatch( - XdsRouteConfigureResource.getInstance(),rdsResourceName, rdsResourceWatcher); + XdsRouteConfigureResource.getInstance(), rdsResourceName, rdsResourceWatcher); assertThat(resourceDiscoveryCalls.size()).isEqualTo(0); } @@ -1088,15 +1351,19 @@ public void cdsResourceUpdated_withXdstpResourceName_withWrongType() { @Test public void cdsResourceUpdated_withXdstpResourceName_unknownAuthority() { String cdsResourceName = "xdstp://unknown.example.com/envoy.config.cluster.v3.Cluster/cluster1"; - xdsClient.watchXdsResource(XdsClusterResource.getInstance(),cdsResourceName, + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), cdsResourceName, cdsResourceWatcher); - verify(cdsResourceWatcher).onError(errorCaptor.capture()); - Status error = errorCaptor.getValue(); + @SuppressWarnings("unchecked") + ArgumentCaptor> cdsUpdateCaptor = ArgumentCaptor.forClass(StatusOr.class); + verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr capturedUpdate = cdsUpdateCaptor.getValue(); + assertThat(capturedUpdate.hasValue()).isFalse(); + Status error = capturedUpdate.getStatus(); assertThat(error.getCode()).isEqualTo(Code.INVALID_ARGUMENT); assertThat(error.getDescription()).isEqualTo( "Wrong configuration: xds server does not exist for resource " + cdsResourceName); assertThat(resourceDiscoveryCalls.poll()).isNull(); - xdsClient.cancelXdsResourceWatch(XdsClusterResource.getInstance(),cdsResourceName, + xdsClient.cancelXdsResourceWatch(XdsClusterResource.getInstance(), cdsResourceName, cdsResourceWatcher); assertThat(resourceDiscoveryCalls.poll()).isNull(); } @@ -1115,7 +1382,7 @@ public void edsResourceUpdated_withXdstpResourceName_withWrongType() { edsResourceNameWithWrongType, ImmutableList.of(mf.buildLocalityLbEndpoints( "region2", "zone2", "subzone2", - mf.buildLbEndpoint("172.44.2.2", 8000, "unknown", 3), 2, 0)), + mf.buildLbEndpoint("172.44.2.2", 8000, "unknown", 3, "endpoint-host-name"), 2, 0)), ImmutableList.of())); call.sendResponse(EDS, testEdsConfig, VERSION_1, "0000"); call.verifyRequestNack( @@ -1130,8 +1397,12 @@ public void edsResourceUpdated_withXdstpResourceName_unknownAuthority() { "xdstp://unknown.example.com/envoy.config.endpoint.v3.ClusterLoadAssignment/cluster1"; xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), edsResourceName, edsResourceWatcher); - verify(edsResourceWatcher).onError(errorCaptor.capture()); - Status error = errorCaptor.getValue(); + @SuppressWarnings("unchecked") + ArgumentCaptor> edsUpdateCaptor = ArgumentCaptor.forClass(StatusOr.class); + verify(edsResourceWatcher).onResourceChanged(edsUpdateCaptor.capture()); + StatusOr capturedUpdate = edsUpdateCaptor.getValue(); + assertThat(capturedUpdate.hasValue()).isFalse(); + Status error = capturedUpdate.getStatus(); assertThat(error.getCode()).isEqualTo(Code.INVALID_ARGUMENT); assertThat(error.getDescription()).isEqualTo( "Wrong configuration: xds server does not exist for resource " + edsResourceName); @@ -1180,11 +1451,13 @@ public void ldsResourceUpdate_withFaultInjection() { // Client sends an ACK LDS request. call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, listener, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); - LdsUpdate ldsUpdate = ldsUpdateCaptor.getValue(); + LdsUpdate ldsUpdate = statusOrUpdate.getValue(); assertThat(ldsUpdate.httpConnectionManager().virtualHosts()).hasSize(2); assertThat(ldsUpdate.httpConnectionManager().httpFilterConfigs().get(0).name) .isEqualTo("envoy.fault"); @@ -1209,6 +1482,7 @@ public void ldsResourceUpdate_withFaultInjection() { @Test public void ldsResourceDeleted() { Assume.assumeFalse(ignoreResourceDeletion()); + InOrder inOrder = inOrder(ldsResourceWatcher); DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); @@ -1217,15 +1491,20 @@ public void ldsResourceDeleted() { // Initial LDS response. call.sendResponse(LDS, testListenerVhosts, VERSION_1, "0000"); call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + inOrder.verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerVhosts(statusOrUpdate.getValue()); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); // Empty LDS response deletes the listener. call.sendResponse(LDS, Collections.emptyList(), VERSION_2, "0001"); call.verifyRequest(LDS, LDS_RESOURCE, VERSION_2, "0001", NODE); - verify(ldsResourceWatcher).onResourceDoesNotExist(LDS_RESOURCE); + inOrder.verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate1 = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate1.hasValue()).isFalse(); + assertThat(statusOrUpdate1.getStatus().getCode()).isEqualTo(Status.Code.NOT_FOUND); verifyResourceMetadataDoesNotExist(LDS, LDS_RESOURCE); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); } @@ -1233,7 +1512,7 @@ public void ldsResourceDeleted() { /** * When ignore_resource_deletion server feature is on, xDS client should keep the deleted listener * on empty response, and resume the normal work when LDS contains the listener again. - * */ + */ @Test public void ldsResourceDeleted_ignoreResourceDeletion() { Assume.assumeTrue(ignoreResourceDeletion()); @@ -1245,8 +1524,8 @@ public void ldsResourceDeleted_ignoreResourceDeletion() { // Initial LDS response. call.sendResponse(LDS, testListenerVhosts, VERSION_1, "0000"); call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue().getValue()); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); @@ -1254,32 +1533,204 @@ public void ldsResourceDeleted_ignoreResourceDeletion() { call.sendResponse(LDS, Collections.emptyList(), VERSION_2, "0001"); call.verifyRequest(LDS, LDS_RESOURCE, VERSION_2, "0001", NODE); // The resource is still ACKED at VERSION_1 (no changes). + verify(ldsResourceWatcher).onAmbientError( + argThat(status -> status.getCode() == Status.Code.NOT_FOUND)); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); - // onResourceDoesNotExist not called - verify(ldsResourceWatcher, never()).onResourceDoesNotExist(LDS_RESOURCE); // Next update is correct, and contains the listener again. - call.sendResponse(LDS, testListenerVhosts, VERSION_3, "0003"); + Any updatedListener = Any.pack(mf.buildListenerWithApiListener(LDS_RESOURCE, + mf.buildRouteConfiguration("do not care", mf.buildOpaqueVirtualHosts(VHOST_SIZE + 1)))); + call.sendResponse(LDS, updatedListener, VERSION_3, "0003"); call.verifyRequest(LDS, LDS_RESOURCE, VERSION_3, "0003", NODE); - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + verify(ldsResourceWatcher, times(2)).onResourceChanged(ldsUpdateCaptor.capture()); + assertThat(ldsUpdateCaptor.getValue().getValue().httpConnectionManager().virtualHosts()) + .hasSize(VHOST_SIZE + 1); // LDS is now ACKEd at VERSION_3. - verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_3, - TIME_INCREMENT * 3); + verifyResourceMetadataAcked(LDS, LDS_RESOURCE, updatedListener, VERSION_3, TIME_INCREMENT * 3); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); verifyNoMoreInteractions(ldsResourceWatcher); } + /** + * When fail_on_data_errors server feature is on, xDS client should delete the cached listener + * and fail RPCs when LDS resource is deleted. + */ + @Test + public void ldsResourceDeleted_failOnDataErrors_true() { + BootstrapperImpl.xdsDataErrorHandlingEnabled = true; + xdsServerInfo = ServerInfo.create(SERVER_URI, CHANNEL_CREDENTIALS, false, + true, false, true); + BootstrapInfo bootstrapInfo = + Bootstrapper.BootstrapInfo.builder() + .servers(Collections.singletonList(xdsServerInfo)) + .node(NODE) + .authorities(ImmutableMap.of( + "", + AuthorityInfo.create( + "xdstp:///envoy.config.listener.v3.Listener/%s", + ImmutableList.of(Bootstrapper.ServerInfo.create( + SERVER_URI_EMPTY_AUTHORITY, CHANNEL_CREDENTIALS))))) + .certProviders(ImmutableMap.of()) + .build(); + xdsClient = new XdsClientImpl( + xdsTransportFactory, + bootstrapInfo, + fakeClock.getScheduledExecutorService(), + backoffPolicyProvider, + fakeClock.getStopwatchSupplier(), + timeProvider, + MessagePrinter.INSTANCE, + new TlsContextManagerImpl(bootstrapInfo), + xdsClientMetricReporter); + + InOrder inOrder = inOrder(ldsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, + ldsResourceWatcher); + verifyResourceMetadataRequested(LDS, LDS_RESOURCE); + + // Initial LDS response. + call.sendResponse(LDS, testListenerVhosts, VERSION_1, "0000"); + call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); + inOrder.verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerVhosts(statusOrUpdate.getValue()); + verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_1, TIME_INCREMENT); + verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); + + // Empty LDS response deletes the listener and fails RPCs. + call.sendResponse(LDS, Collections.emptyList(), VERSION_2, "0001"); + call.verifyRequest(LDS, LDS_RESOURCE, VERSION_2, "0001", NODE); + inOrder.verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate1 = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate1.hasValue()).isFalse(); + assertThat(statusOrUpdate1.getStatus().getCode()).isEqualTo(Status.Code.NOT_FOUND); + verifyResourceMetadataDoesNotExist(LDS, LDS_RESOURCE); + verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); + + BootstrapperImpl.xdsDataErrorHandlingEnabled = false; + } + + /** + * When the fail_on_data_errors server feature is not present, the default behavior + * is to treat a resource deletion as an ambient error and preserve the cached resource. + */ + @Test + public void ldsResourceDeleted_failOnDataErrors_false() { + BootstrapperImpl.xdsDataErrorHandlingEnabled = true; + + xdsServerInfo = ServerInfo.create(SERVER_URI, CHANNEL_CREDENTIALS, false, + true, false, false); + BootstrapInfo bootstrapInfo = + Bootstrapper.BootstrapInfo.builder() + .servers(Collections.singletonList(xdsServerInfo)) + .node(NODE) + .authorities(ImmutableMap.of( + "", + AuthorityInfo.create( + "xdstp:///envoy.config.listener.v3.Listener/%s", + ImmutableList.of(Bootstrapper.ServerInfo.create( + SERVER_URI_EMPTY_AUTHORITY, CHANNEL_CREDENTIALS))))) + .certProviders(ImmutableMap.of()) + .build(); + xdsClient = new XdsClientImpl( + xdsTransportFactory, + bootstrapInfo, + fakeClock.getScheduledExecutorService(), + backoffPolicyProvider, + fakeClock.getStopwatchSupplier(), + timeProvider, + MessagePrinter.INSTANCE, + new TlsContextManagerImpl(bootstrapInfo), + xdsClientMetricReporter); + + InOrder inOrder = inOrder(ldsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, + ldsResourceWatcher); + verifyResourceMetadataRequested(LDS, LDS_RESOURCE); + + // Initial LDS response. + call.sendResponse(LDS, testListenerVhosts, VERSION_1, "0000"); + call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); + inOrder.verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerVhosts(statusOrUpdate.getValue()); + verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_1, TIME_INCREMENT); + verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); + + // Empty LDS response deletes the listener and fails RPCs. + call.sendResponse(LDS, Collections.emptyList(), VERSION_2, "0001"); + call.verifyRequest(LDS, LDS_RESOURCE, VERSION_2, "0001", NODE); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + inOrder.verify(ldsResourceWatcher).onAmbientError(statusCaptor.capture()); + Status receivedStatus = statusCaptor.getValue(); + assertThat(receivedStatus.getCode()).isEqualTo(Status.Code.NOT_FOUND); + assertThat(receivedStatus.getDescription()).contains( + "Resource " + LDS_RESOURCE + " deleted from server"); + inOrder.verify(ldsResourceWatcher, never()).onResourceChanged(any()); + verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); + + BootstrapperImpl.xdsDataErrorHandlingEnabled = false; + } + + /** + * Tests that fail_on_data_errors feature is ignored if the env var is not enabled, + * and the old behavior (dropping the resource) is used. + */ + @Test + public void ldsResourceDeleted_failOnDataErrorsIgnoredWithoutEnvVar() { + BootstrapperImpl.xdsDataErrorHandlingEnabled = false; + + xdsServerInfo = ServerInfo.create(SERVER_URI, CHANNEL_CREDENTIALS, false, + true, false, true); + BootstrapInfo bootstrapInfo = + Bootstrapper.BootstrapInfo.builder() + .servers(Collections.singletonList(xdsServerInfo)) + .node(NODE) + .authorities(ImmutableMap.of( + "", + AuthorityInfo.create( + "xdstp:///envoy.config.listener.v3.Listener/%s", + ImmutableList.of(Bootstrapper.ServerInfo.create( + SERVER_URI_EMPTY_AUTHORITY, CHANNEL_CREDENTIALS))))) + .certProviders(ImmutableMap.of()) + .build(); + xdsClient = new XdsClientImpl( + xdsTransportFactory, + bootstrapInfo, + fakeClock.getScheduledExecutorService(), + backoffPolicyProvider, + fakeClock.getStopwatchSupplier(), + timeProvider, + MessagePrinter.INSTANCE, + new TlsContextManagerImpl(bootstrapInfo), + xdsClientMetricReporter); + + InOrder inOrder = inOrder(ldsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, + ldsResourceWatcher); + call.sendResponse(LDS, testListenerVhosts, VERSION_1, "0000"); + inOrder.verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + assertThat(ldsUpdateCaptor.getValue().hasValue()).isTrue(); + call.sendResponse(LDS, Collections.emptyList(), VERSION_2, "0001"); + + inOrder.verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isFalse(); + assertThat(statusOrUpdate.getStatus().getCode()).isEqualTo(Status.Code.NOT_FOUND); + } + @Test @SuppressWarnings("unchecked") public void multipleLdsWatchers() { String ldsResourceTwo = "bar.googleapis.com"; ResourceWatcher watcher1 = mock(ResourceWatcher.class); ResourceWatcher watcher2 = mock(ResourceWatcher.class); - xdsClient.watchXdsResource(XdsListenerResource.getInstance(),LDS_RESOURCE, ldsResourceWatcher); - xdsClient.watchXdsResource(XdsListenerResource.getInstance(),ldsResourceTwo, watcher1); - xdsClient.watchXdsResource(XdsListenerResource.getInstance(),ldsResourceTwo, watcher2); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), ldsResourceTwo, watcher1); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), ldsResourceTwo, watcher2); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); call.verifyRequest(LDS, ImmutableList.of(LDS_RESOURCE, ldsResourceTwo), "", "", NODE); // Both LDS resources were requested. @@ -1288,9 +1739,12 @@ public void multipleLdsWatchers() { verifySubscribedResourcesMetadataSizes(2, 0, 0, 0); fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); - verify(ldsResourceWatcher).onResourceDoesNotExist(LDS_RESOURCE); - verify(watcher1).onResourceDoesNotExist(ldsResourceTwo); - verify(watcher2).onResourceDoesNotExist(ldsResourceTwo); + verify(ldsResourceWatcher).onResourceChanged(argThat(statusOr -> + !statusOr.hasValue() && statusOr.getStatus().getDescription().contains(LDS_RESOURCE))); + verify(watcher1).onResourceChanged(argThat(statusOr -> + !statusOr.hasValue() && statusOr.getStatus().getDescription().contains(ldsResourceTwo))); + verify(watcher2).onResourceChanged(argThat(statusOr -> + !statusOr.hasValue() && statusOr.getStatus().getDescription().contains(ldsResourceTwo))); verifyResourceMetadataDoesNotExist(LDS, LDS_RESOURCE); verifyResourceMetadataDoesNotExist(LDS, ldsResourceTwo); verifySubscribedResourcesMetadataSizes(2, 0, 0, 0); @@ -1298,16 +1752,22 @@ public void multipleLdsWatchers() { Any listenerTwo = Any.pack(mf.buildListenerWithApiListenerForRds(ldsResourceTwo, RDS_RESOURCE)); call.sendResponse(LDS, ImmutableList.of(testListenerVhosts, listenerTwo), VERSION_1, "0000"); // ResourceWatcher called with listenerVhosts. - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + verify(ldsResourceWatcher, times(2)).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerVhosts(statusOrUpdate.getValue()); // watcher1 called with listenerTwo. - verify(watcher1).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerRds(ldsUpdateCaptor.getValue()); - assertThat(ldsUpdateCaptor.getValue().httpConnectionManager().virtualHosts()).isNull(); + verify(watcher1, times(2)).onResourceChanged(ldsUpdateCaptor.capture()); + statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerRds(statusOrUpdate.getValue()); + assertThat(statusOrUpdate.getValue().httpConnectionManager().virtualHosts()).isNull(); // watcher2 called with listenerTwo. - verify(watcher2).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerRds(ldsUpdateCaptor.getValue()); - assertThat(ldsUpdateCaptor.getValue().httpConnectionManager().virtualHosts()).isNull(); + verify(watcher2, times(2)).onResourceChanged(ldsUpdateCaptor.capture()); + statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerRds(statusOrUpdate.getValue()); + assertThat(statusOrUpdate.getValue().httpConnectionManager().virtualHosts()).isNull(); // Metadata of both listeners is stored. verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_1, TIME_INCREMENT); verifyResourceMetadataAcked(LDS, ldsResourceTwo, listenerTwo, VERSION_1, TIME_INCREMENT); @@ -1319,7 +1779,7 @@ public void rdsResourceNotFound() { DiscoveryRpcCall call = startResourceWatcher(XdsRouteConfigureResource.getInstance(), RDS_RESOURCE, rdsResourceWatcher); Any routeConfig = Any.pack(mf.buildRouteConfiguration("route-bar.googleapis.com", - mf.buildOpaqueVirtualHosts(2))); + mf.buildOpaqueVirtualHosts(2))); call.sendResponse(RDS, routeConfig, VERSION_1, "0000"); // Client sends an ACK RDS request. @@ -1329,7 +1789,8 @@ public void rdsResourceNotFound() { verifySubscribedResourcesMetadataSizes(0, 0, 1, 0); // Server failed to return subscribed resource within expected time window. fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); - verify(rdsResourceWatcher).onResourceDoesNotExist(RDS_RESOURCE); + verify(rdsResourceWatcher).onResourceChanged(argThat( + arg -> !arg.hasValue() && arg.getStatus().getDescription().contains(RDS_RESOURCE))); assertThat(fakeClock.getPendingTasks(RDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataDoesNotExist(RDS, RDS_RESOURCE); verifySubscribedResourcesMetadataSizes(0, 0, 1, 0); @@ -1370,7 +1831,7 @@ public void rdsResponseErrorHandling_someResourcesFailedUnpack() { verifySubscribedResourcesMetadataSizes(0, 0, 1, 0); // The response is NACKed with the same error message. call.verifyRequestNack(RDS, RDS_RESOURCE, "", "0000", NODE, errors); - verify(rdsResourceWatcher).onChanged(any(RdsUpdate.class)); + verify(rdsResourceWatcher).onResourceChanged(any()); } @Test @@ -1378,6 +1839,7 @@ public void rdsResponseErrorHandling_nackWeightedSumZero() { DiscoveryRpcCall call = startResourceWatcher(XdsRouteConfigureResource.getInstance(), RDS_RESOURCE, rdsResourceWatcher); verifyResourceMetadataRequested(RDS, RDS_RESOURCE); + String expectedErrorDetail = "Sum of cluster weights should be above 0"; io.envoyproxy.envoy.config.route.v3.RouteAction routeAction = io.envoyproxy.envoy.config.route.v3.RouteAction.newBuilder() @@ -1411,25 +1873,29 @@ public void rdsResponseErrorHandling_nackWeightedSumZero() { "RDS response RouteConfiguration \'route-configuration.googleapis.com\' validation error: " + "RouteConfiguration contains invalid virtual host: Virtual host [do not care] " + "contains invalid route : Route [route-blade] contains invalid RouteAction: " - + "Sum of cluster weights should be above 0."); + + expectedErrorDetail); verifySubscribedResourcesMetadataSizes(0, 0, 1, 0); // The response is NACKed with the same error message. call.verifyRequestNack(RDS, RDS_RESOURCE, "", "0000", NODE, errors); - verify(rdsResourceWatcher, never()).onChanged(any(RdsUpdate.class)); + verify(rdsResourceWatcher).onResourceChanged(argThat( + statusOr -> !statusOr.hasValue() && statusOr.getStatus().getDescription() + .contains(expectedErrorDetail))); + verify(rdsResourceWatcher, never()).onResourceChanged(argThat(StatusOr::hasValue)); } /** * Tests a subscribed RDS resource transitioned to and from the invalid state. * - * @see - * A40-csds-support.md + * @see + * A40-csds-support.md */ @Test public void rdsResponseErrorHandling_subscribedResourceInvalid() { List subscribedResourceNames = ImmutableList.of("A", "B", "C"); - xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),"A", rdsResourceWatcher); - xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),"B", rdsResourceWatcher); - xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),"C", rdsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), "A", rdsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), "B", rdsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), "C", rdsResourceWatcher); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); assertThat(call).isNotNull(); verifyResourceMetadataRequested(RDS, "A"); @@ -1448,6 +1914,8 @@ public void rdsResponseErrorHandling_subscribedResourceInvalid() { verifyResourceMetadataAcked(RDS, "A", resourcesV1.get("A"), VERSION_1, TIME_INCREMENT); verifyResourceMetadataAcked(RDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT); verifyResourceMetadataAcked(RDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); + // Check metric data. + verifyResourceValidInvalidCount(1, 3, 0, xdsServerInfo.target(), RDS.typeUrl()); call.verifyRequest(RDS, subscribedResourceNames, VERSION_1, "0000", NODE); // RDS -> {A, B}, version 2 @@ -1459,11 +1927,13 @@ public void rdsResponseErrorHandling_subscribedResourceInvalid() { // {A} -> ACK, version 2 // {B} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B // {C} -> ACK, version 1 + verifyResourceValidInvalidCount(1, 1, 1, xdsServerInfo.target(), + RDS.typeUrl()); List errorsV2 = ImmutableList.of("RDS response RouteConfiguration 'B' validation error: "); verifyResourceMetadataAcked(RDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 2); verifyResourceMetadataNacked(RDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT, - VERSION_2, TIME_INCREMENT * 2, errorsV2); + VERSION_2, TIME_INCREMENT * 2, errorsV2, true); verifyResourceMetadataAcked(RDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); call.verifyRequestNack(RDS, subscribedResourceNames, VERSION_1, "0001", NODE, errorsV2); @@ -1475,6 +1945,8 @@ public void rdsResponseErrorHandling_subscribedResourceInvalid() { call.sendResponse(RDS, resourcesV3.values().asList(), VERSION_3, "0002"); // {A} -> ACK, version 2 // {B, C} -> ACK, version 3 + verifyResourceValidInvalidCount(1, 2, 0, xdsServerInfo.target(), + RDS.typeUrl()); verifyResourceMetadataAcked(RDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 2); verifyResourceMetadataAcked(RDS, "B", resourcesV3.get("B"), VERSION_3, TIME_INCREMENT * 3); verifyResourceMetadataAcked(RDS, "C", resourcesV3.get("C"), VERSION_3, TIME_INCREMENT * 3); @@ -1490,8 +1962,10 @@ public void rdsResourceFound() { // Client sends an ACK RDS request. call.verifyRequest(RDS, RDS_RESOURCE, VERSION_1, "0000", NODE); - verify(rdsResourceWatcher).onChanged(rdsUpdateCaptor.capture()); - verifyGoldenRouteConfig(rdsUpdateCaptor.getValue()); + verify(rdsResourceWatcher).onResourceChanged(rdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = rdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenRouteConfig(statusOrUpdate.getValue()); assertThat(fakeClock.getPendingTasks(RDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataAcked(RDS, RDS_RESOURCE, testRouteConfig, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(0, 0, 1, 0); @@ -1505,8 +1979,10 @@ public void wrappedRdsResource() { // Client sends an ACK RDS request. call.verifyRequest(RDS, RDS_RESOURCE, VERSION_1, "0000", NODE); - verify(rdsResourceWatcher).onChanged(rdsUpdateCaptor.capture()); - verifyGoldenRouteConfig(rdsUpdateCaptor.getValue()); + verify(rdsResourceWatcher).onResourceChanged(rdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = rdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenRouteConfig(statusOrUpdate.getValue()); assertThat(fakeClock.getPendingTasks(RDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataAcked(RDS, RDS_RESOURCE, testRouteConfig, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(0, 0, 1, 0); @@ -1523,9 +1999,11 @@ public void cachedRdsResource_data() { call.verifyRequest(RDS, RDS_RESOURCE, VERSION_1, "0000", NODE); ResourceWatcher watcher = mock(ResourceWatcher.class); - xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),RDS_RESOURCE, watcher); - verify(watcher).onChanged(rdsUpdateCaptor.capture()); - verifyGoldenRouteConfig(rdsUpdateCaptor.getValue()); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), RDS_RESOURCE, watcher); + verify(watcher).onResourceChanged(rdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = rdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenRouteConfig(statusOrUpdate.getValue()); call.verifyNoMoreRequest(); verifyResourceMetadataAcked(RDS, RDS_RESOURCE, testRouteConfig, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(0, 0, 1, 0); @@ -1537,11 +2015,15 @@ public void cachedRdsResource_absent() { DiscoveryRpcCall call = startResourceWatcher(XdsRouteConfigureResource.getInstance(), RDS_RESOURCE, rdsResourceWatcher); fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); - verify(rdsResourceWatcher).onResourceDoesNotExist(RDS_RESOURCE); + verify(rdsResourceWatcher).onResourceChanged(argThat(statusOr -> + !statusOr.hasValue() && statusOr.getStatus().getDescription().contains(RDS_RESOURCE) + && statusOr.getStatus().getDescription().contains(RDS_RESOURCE))); // Add another watcher. ResourceWatcher watcher = mock(ResourceWatcher.class); - xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),RDS_RESOURCE, watcher); - verify(watcher).onResourceDoesNotExist(RDS_RESOURCE); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), RDS_RESOURCE, watcher); + verify(watcher).onResourceChanged(argThat(statusOr -> + !statusOr.hasValue() && statusOr.getStatus().getDescription().contains(RDS_RESOURCE) + && statusOr.getStatus().getDescription().contains(RDS_RESOURCE))); call.verifyNoMoreRequest(); verifyResourceMetadataDoesNotExist(RDS, RDS_RESOURCE); verifySubscribedResourcesMetadataSizes(0, 0, 1, 0); @@ -1556,8 +2038,10 @@ public void rdsResourceUpdated() { // Initial RDS response. call.sendResponse(RDS, testRouteConfig, VERSION_1, "0000"); call.verifyRequest(RDS, RDS_RESOURCE, VERSION_1, "0000", NODE); - verify(rdsResourceWatcher).onChanged(rdsUpdateCaptor.capture()); - verifyGoldenRouteConfig(rdsUpdateCaptor.getValue()); + verify(rdsResourceWatcher).onResourceChanged(rdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = rdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenRouteConfig(statusOrUpdate.getValue()); verifyResourceMetadataAcked(RDS, RDS_RESOURCE, testRouteConfig, VERSION_1, TIME_INCREMENT); // Updated RDS response. @@ -1567,18 +2051,49 @@ public void rdsResourceUpdated() { // Client sends an ACK RDS request. call.verifyRequest(RDS, RDS_RESOURCE, VERSION_2, "0001", NODE); - verify(rdsResourceWatcher, times(2)).onChanged(rdsUpdateCaptor.capture()); - assertThat(rdsUpdateCaptor.getValue().virtualHosts).hasSize(4); + verify(rdsResourceWatcher, times(2)).onResourceChanged(rdsUpdateCaptor.capture()); + statusOrUpdate = rdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + assertThat(statusOrUpdate.getValue().virtualHosts).hasSize(4); verifyResourceMetadataAcked(RDS, RDS_RESOURCE, routeConfigUpdated, VERSION_2, TIME_INCREMENT * 2); - verifySubscribedResourcesMetadataSizes(0, 0, 1, 0); + } + + @Test + public void rdsResourceInvalid() { + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), "A", rdsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), "B", rdsResourceWatcher); + DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); + assertThat(call).isNotNull(); + verifyResourceMetadataRequested(RDS, "A"); + verifyResourceMetadataRequested(RDS, "B"); + verifySubscribedResourcesMetadataSizes(0, 0, 2, 0); + + // RDS -> {A, B}, version 1 + // Failed to parse endpoint B + List vhostsV1 = mf.buildOpaqueVirtualHosts(1); + ImmutableMap resourcesV1 = ImmutableMap.of( + "A", Any.pack(mf.buildRouteConfiguration("A", vhostsV1)), + "B", Any.pack(mf.buildRouteConfigurationInvalid("B"))); + call.sendResponse(RDS, resourcesV1.values().asList(), VERSION_1, "0000"); + + // {A} -> ACK, version 1 + // {B} -> NACK, version 1, rejected version 1, rejected reason: Failed to parse B + List errorsV1 = + ImmutableList.of("RDS response RouteConfiguration 'B' validation error: "); + verifyResourceMetadataAcked(RDS, "A", resourcesV1.get("A"), VERSION_1, TIME_INCREMENT); + verifyResourceMetadataNacked(RDS, "B", null, "", 0, + VERSION_1, TIME_INCREMENT, errorsV1, false); + // Check metric data. + verifyResourceValidInvalidCount(1, 1, 1, xdsServerInfo.target(), RDS.typeUrl()); + verifySubscribedResourcesMetadataSizes(0, 0, 2, 0); } @Test public void rdsResourceDeletedByLdsApiListener() { - xdsClient.watchXdsResource(XdsListenerResource.getInstance(),LDS_RESOURCE, + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); - xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),RDS_RESOURCE, + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), RDS_RESOURCE, rdsResourceWatcher); verifyResourceMetadataRequested(LDS, LDS_RESOURCE); verifyResourceMetadataRequested(RDS, RDS_RESOURCE); @@ -1586,15 +2101,19 @@ public void rdsResourceDeletedByLdsApiListener() { DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); call.sendResponse(LDS, testListenerRds, VERSION_1, "0000"); - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerRds(ldsUpdateCaptor.getValue()); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerRds(statusOrUpdate.getValue()); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerRds, VERSION_1, TIME_INCREMENT); verifyResourceMetadataRequested(RDS, RDS_RESOURCE); verifySubscribedResourcesMetadataSizes(1, 0, 1, 0); call.sendResponse(RDS, testRouteConfig, VERSION_1, "0000"); - verify(rdsResourceWatcher).onChanged(rdsUpdateCaptor.capture()); - verifyGoldenRouteConfig(rdsUpdateCaptor.getValue()); + verify(rdsResourceWatcher).onResourceChanged(rdsUpdateCaptor.capture()); + StatusOr statusOrUpdate1 = rdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenRouteConfig(statusOrUpdate1.getValue()); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerRds, VERSION_1, TIME_INCREMENT); verifyResourceMetadataAcked(RDS, RDS_RESOURCE, testRouteConfig, VERSION_1, TIME_INCREMENT * 2); verifySubscribedResourcesMetadataSizes(1, 0, 1, 0); @@ -1604,8 +2123,10 @@ public void rdsResourceDeletedByLdsApiListener() { // Note that this must work the same despite the ignore_resource_deletion feature is on. // This happens because the Listener is getting replaced, and not deleted. call.sendResponse(LDS, testListenerVhosts, VERSION_2, "0001"); - verify(ldsResourceWatcher, times(2)).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + verify(ldsResourceWatcher, times(2)).onResourceChanged(ldsUpdateCaptor.capture()); + statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenListenerVhosts(statusOrUpdate.getValue()); verifyNoMoreInteractions(rdsResourceWatcher); verifyResourceMetadataAcked(RDS, RDS_RESOURCE, testRouteConfig, VERSION_1, TIME_INCREMENT * 2); verifyResourceMetadataAcked( @@ -1637,11 +2158,13 @@ public void rdsResourcesDeletedByLdsTcpListener() { // referencing RDS_RESOURCE. DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); call.sendResponse(LDS, packedListener, VERSION_1, "0000"); - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); - assertThat(ldsUpdateCaptor.getValue().listener().filterChains()).hasSize(1); + assertThat(statusOrUpdate.getValue().listener().filterChains()).hasSize(1); FilterChain parsedFilterChain = Iterables.getOnlyElement( - ldsUpdateCaptor.getValue().listener().filterChains()); + statusOrUpdate.getValue().listener().filterChains()); assertThat(parsedFilterChain.httpConnectionManager().rdsName()).isEqualTo(RDS_RESOURCE); verifyResourceMetadataAcked(LDS, LISTENER_RESOURCE, packedListener, VERSION_1, TIME_INCREMENT); verifyResourceMetadataRequested(RDS, RDS_RESOURCE); @@ -1649,8 +2172,10 @@ public void rdsResourcesDeletedByLdsTcpListener() { // Simulates receiving the requested RDS resource. call.sendResponse(RDS, testRouteConfig, VERSION_1, "0000"); - verify(rdsResourceWatcher).onChanged(rdsUpdateCaptor.capture()); - verifyGoldenRouteConfig(rdsUpdateCaptor.getValue()); + verify(rdsResourceWatcher).onResourceChanged(rdsUpdateCaptor.capture()); + StatusOr statusOrUpdate1 = rdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenRouteConfig(statusOrUpdate1.getValue()); verifyResourceMetadataAcked(RDS, RDS_RESOURCE, testRouteConfig, VERSION_1, TIME_INCREMENT * 2); // Simulates receiving an updated version of the requested LDS resource as a TCP listener @@ -1668,12 +2193,15 @@ public void rdsResourcesDeletedByLdsTcpListener() { packedListener = Any.pack(mf.buildListenerWithFilterChain(LISTENER_RESOURCE, 7000, "0.0.0.0", filterChain)); call.sendResponse(LDS, packedListener, VERSION_2, "0001"); - verify(ldsResourceWatcher, times(2)).onChanged(ldsUpdateCaptor.capture()); - assertThat(ldsUpdateCaptor.getValue().listener().filterChains()).hasSize(1); + verify(ldsResourceWatcher, times(2)).onResourceChanged(ldsUpdateCaptor.capture()); + statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + assertThat(statusOrUpdate.getValue().listener().filterChains()).hasSize(1); parsedFilterChain = Iterables.getOnlyElement( - ldsUpdateCaptor.getValue().listener().filterChains()); + statusOrUpdate.getValue().listener().filterChains()); assertThat(parsedFilterChain.httpConnectionManager().virtualHosts()).hasSize(VHOST_SIZE); - verify(rdsResourceWatcher, never()).onResourceDoesNotExist(RDS_RESOURCE); + verify(rdsResourceWatcher, never()).onResourceChanged(argThat(statusOr -> + !statusOr.hasValue() && statusOr.getStatus().getDescription().equals(RDS_RESOURCE))); verifyResourceMetadataAcked(RDS, RDS_RESOURCE, testRouteConfig, VERSION_1, TIME_INCREMENT * 2); verifyResourceMetadataAcked( LDS, LISTENER_RESOURCE, packedListener, VERSION_2, TIME_INCREMENT * 3); @@ -1686,10 +2214,10 @@ public void multipleRdsWatchers() { String rdsResourceTwo = "route-bar.googleapis.com"; ResourceWatcher watcher1 = mock(ResourceWatcher.class); ResourceWatcher watcher2 = mock(ResourceWatcher.class); - xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),RDS_RESOURCE, + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), RDS_RESOURCE, rdsResourceWatcher); - xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),rdsResourceTwo, watcher1); - xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),rdsResourceTwo, watcher2); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), rdsResourceTwo, watcher1); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), rdsResourceTwo, watcher2); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); call.verifyRequest(RDS, Arrays.asList(RDS_RESOURCE, rdsResourceTwo), "", "", NODE); // Both RDS resources were requested. @@ -1698,16 +2226,25 @@ public void multipleRdsWatchers() { verifySubscribedResourcesMetadataSizes(0, 0, 2, 0); fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); - verify(rdsResourceWatcher).onResourceDoesNotExist(RDS_RESOURCE); - verify(watcher1).onResourceDoesNotExist(rdsResourceTwo); - verify(watcher2).onResourceDoesNotExist(rdsResourceTwo); + verify(rdsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Status.Code.NOT_FOUND)); + verify(watcher1).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Status.Code.NOT_FOUND)); + verify(watcher2).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Status.Code.NOT_FOUND)); verifyResourceMetadataDoesNotExist(RDS, RDS_RESOURCE); verifyResourceMetadataDoesNotExist(RDS, rdsResourceTwo); verifySubscribedResourcesMetadataSizes(0, 0, 2, 0); call.sendResponse(RDS, testRouteConfig, VERSION_1, "0000"); - verify(rdsResourceWatcher).onChanged(rdsUpdateCaptor.capture()); - verifyGoldenRouteConfig(rdsUpdateCaptor.getValue()); + ArgumentCaptor> rdsUpdateCaptor = ArgumentCaptor.forClass(StatusOr.class); + verify(rdsResourceWatcher, times(2)).onResourceChanged(rdsUpdateCaptor.capture()); + StatusOr capturedUpdate1 = rdsUpdateCaptor.getAllValues().get(1); + assertThat(capturedUpdate1.hasValue()).isTrue(); + verifyGoldenRouteConfig(capturedUpdate1.getValue()); verifyNoMoreInteractions(watcher1, watcher2); verifyResourceMetadataAcked(RDS, RDS_RESOURCE, testRouteConfig, VERSION_1, TIME_INCREMENT); verifyResourceMetadataDoesNotExist(RDS, rdsResourceTwo); @@ -1716,13 +2253,22 @@ public void multipleRdsWatchers() { Any routeConfigTwo = Any.pack(mf.buildRouteConfiguration(rdsResourceTwo, mf.buildOpaqueVirtualHosts(4))); call.sendResponse(RDS, routeConfigTwo, VERSION_2, "0002"); - verify(watcher1).onChanged(rdsUpdateCaptor.capture()); - assertThat(rdsUpdateCaptor.getValue().virtualHosts).hasSize(4); - verify(watcher2).onChanged(rdsUpdateCaptor.capture()); - assertThat(rdsUpdateCaptor.getValue().virtualHosts).hasSize(4); + ArgumentCaptor> watcher1Captor = + ArgumentCaptor.forClass(StatusOr.class); + verify(watcher1, times(2)).onResourceChanged(watcher1Captor.capture()); + StatusOr capturedUpdate2 = watcher1Captor.getAllValues().get(1); + assertThat(capturedUpdate2.hasValue()).isTrue(); + assertThat(capturedUpdate2.getValue().virtualHosts).hasSize(4); + ArgumentCaptor> watcher2Captor = + ArgumentCaptor.forClass(StatusOr.class); + verify(watcher2, times(2)).onResourceChanged(watcher2Captor.capture()); + StatusOr capturedUpdate3 = watcher2Captor.getAllValues().get(1); + assertThat(capturedUpdate3.hasValue()).isTrue(); + assertThat(capturedUpdate3.getValue().virtualHosts).hasSize(4); verifyNoMoreInteractions(rdsResourceWatcher); verifyResourceMetadataAcked(RDS, RDS_RESOURCE, testRouteConfig, VERSION_1, TIME_INCREMENT); - verifyResourceMetadataAcked(RDS, rdsResourceTwo, routeConfigTwo, VERSION_2, TIME_INCREMENT * 2); + verifyResourceMetadataAcked(RDS, rdsResourceTwo, routeConfigTwo, VERSION_2, + TIME_INCREMENT * 2); verifySubscribedResourcesMetadataSizes(0, 0, 2, 0); } @@ -1745,7 +2291,8 @@ public void cdsResourceNotFound() { verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); // Server failed to return subscribed resource within expected time window. fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); - verify(cdsResourceWatcher).onResourceDoesNotExist(CDS_RESOURCE); + verify(cdsResourceWatcher).onResourceChanged(argThat( + arg -> !arg.hasValue() && arg.getStatus().getDescription().contains(CDS_RESOURCE))); assertThat(fakeClock.getPendingTasks(CDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataDoesNotExist(CDS, CDS_RESOURCE); verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); @@ -1787,21 +2334,22 @@ public void cdsResponseErrorHandling_someResourcesFailedUnpack() { verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); // The response is NACKed with the same error message. call.verifyRequestNack(CDS, CDS_RESOURCE, "", "0000", NODE, errors); - verify(cdsResourceWatcher).onChanged(any(CdsUpdate.class)); + verify(cdsResourceWatcher).onResourceChanged(any()); } /** * Tests a subscribed CDS resource transitioned to and from the invalid state. * - * @see - * A40-csds-support.md + * @see + * A40-csds-support.md */ @Test public void cdsResponseErrorHandling_subscribedResourceInvalid() { List subscribedResourceNames = ImmutableList.of("A", "B", "C"); - xdsClient.watchXdsResource(XdsClusterResource.getInstance(),"A", cdsResourceWatcher); - xdsClient.watchXdsResource(XdsClusterResource.getInstance(),"B", cdsResourceWatcher); - xdsClient.watchXdsResource(XdsClusterResource.getInstance(),"C", cdsResourceWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), "A", cdsResourceWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), "B", cdsResourceWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), "C", cdsResourceWatcher); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); assertThat(call).isNotNull(); verifyResourceMetadataRequested(CDS, "A"); @@ -1822,6 +2370,8 @@ public void cdsResponseErrorHandling_subscribedResourceInvalid() { ))); call.sendResponse(CDS, resourcesV1.values().asList(), VERSION_1, "0000"); // {A, B, C} -> ACK, version 1 + verifyResourceValidInvalidCount(1, 3, 0, xdsServerInfo.target(), + CDS.typeUrl()); verifyResourceMetadataAcked(CDS, "A", resourcesV1.get("A"), VERSION_1, TIME_INCREMENT); verifyResourceMetadataAcked(CDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT); verifyResourceMetadataAcked(CDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); @@ -1838,10 +2388,12 @@ public void cdsResponseErrorHandling_subscribedResourceInvalid() { // {A} -> ACK, version 2 // {B} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B // {C} -> does not exist + verifyResourceValidInvalidCount(1, 1, 1, xdsServerInfo.target(), + CDS.typeUrl()); List errorsV2 = ImmutableList.of("CDS response Cluster 'B' validation error: "); verifyResourceMetadataAcked(CDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 2); verifyResourceMetadataNacked(CDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT, - VERSION_2, TIME_INCREMENT * 2, errorsV2); + VERSION_2, TIME_INCREMENT * 2, errorsV2, true); if (!ignoreResourceDeletion()) { verifyResourceMetadataDoesNotExist(CDS, "C"); } else { @@ -1861,6 +2413,8 @@ public void cdsResponseErrorHandling_subscribedResourceInvalid() { call.sendResponse(CDS, resourcesV3.values().asList(), VERSION_3, "0002"); // {A} -> does not exit // {B, C} -> ACK, version 3 + verifyResourceValidInvalidCount(1, 2, 0, xdsServerInfo.target(), + CDS.typeUrl()); if (!ignoreResourceDeletion()) { verifyResourceMetadataDoesNotExist(CDS, "A"); } else { @@ -1869,18 +2423,19 @@ public void cdsResponseErrorHandling_subscribedResourceInvalid() { } verifyResourceMetadataAcked(CDS, "B", resourcesV3.get("B"), VERSION_3, TIME_INCREMENT * 3); verifyResourceMetadataAcked(CDS, "C", resourcesV3.get("C"), VERSION_3, TIME_INCREMENT * 3); + call.verifyRequest(CDS, subscribedResourceNames, VERSION_3, "0002", NODE); } @Test public void cdsResponseErrorHandling_subscribedResourceInvalid_withEdsSubscription() { List subscribedResourceNames = ImmutableList.of("A", "B", "C"); - xdsClient.watchXdsResource(XdsClusterResource.getInstance(),"A", cdsResourceWatcher); - xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),"A.1", edsResourceWatcher); - xdsClient.watchXdsResource(XdsClusterResource.getInstance(),"B", cdsResourceWatcher); - xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),"B.1", edsResourceWatcher); - xdsClient.watchXdsResource(XdsClusterResource.getInstance(),"C", cdsResourceWatcher); - xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),"C.1", edsResourceWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), "A", cdsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), "A.1", edsResourceWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), "B", cdsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), "B.1", edsResourceWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), "C", cdsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), "C.1", edsResourceWatcher); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); assertThat(call).isNotNull(); verifyResourceMetadataRequested(CDS, "A"); @@ -1904,6 +2459,8 @@ public void cdsResponseErrorHandling_subscribedResourceInvalid_withEdsSubscripti ))); call.sendResponse(CDS, resourcesV1.values().asList(), VERSION_1, "0000"); // {A, B, C} -> ACK, version 1 + verifyResourceValidInvalidCount(1, 3, 0, xdsServerInfo.target(), + CDS.typeUrl()); verifyResourceMetadataAcked(CDS, "A", resourcesV1.get("A"), VERSION_1, TIME_INCREMENT); verifyResourceMetadataAcked(CDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT); verifyResourceMetadataAcked(CDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); @@ -1918,6 +2475,8 @@ public void cdsResponseErrorHandling_subscribedResourceInvalid_withEdsSubscripti "C.1", Any.pack(mf.buildClusterLoadAssignment("C.1", endpointsV1, dropOverloads))); call.sendResponse(EDS, resourcesV11.values().asList(), VERSION_1, "0000"); // {A.1, B.1, C.1} -> ACK, version 1 + verifyResourceValidInvalidCount(1, 3, 0, xdsServerInfo.target(), + EDS.typeUrl()); verifyResourceMetadataAcked(EDS, "A.1", resourcesV11.get("A.1"), VERSION_1, TIME_INCREMENT * 2); verifyResourceMetadataAcked(EDS, "B.1", resourcesV11.get("B.1"), VERSION_1, TIME_INCREMENT * 2); verifyResourceMetadataAcked(EDS, "C.1", resourcesV11.get("C.1"), VERSION_1, TIME_INCREMENT * 2); @@ -1933,11 +2492,13 @@ public void cdsResponseErrorHandling_subscribedResourceInvalid_withEdsSubscripti // {A} -> ACK, version 2 // {B} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B // {C} -> does not exist + // Check metric data. + verifyResourceValidInvalidCount(1, 1, 1, xdsServerInfo.target(), CDS.typeUrl()); List errorsV2 = ImmutableList.of("CDS response Cluster 'B' validation error: "); verifyResourceMetadataAcked(CDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 3); verifyResourceMetadataNacked( CDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT, VERSION_2, TIME_INCREMENT * 3, - errorsV2); + errorsV2, true); if (!ignoreResourceDeletion()) { verifyResourceMetadataDoesNotExist(CDS, "C"); } else { @@ -1963,8 +2524,10 @@ public void cdsResourceFound() { // Client sent an ACK CDS request. call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); - verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); - verifyGoldenClusterRoundRobin(cdsUpdateCaptor.getValue()); + verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenClusterRoundRobin(statusOrUpdate.getValue()); assertThat(fakeClock.getPendingTasks(CDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataAcked(CDS, CDS_RESOURCE, testClusterRoundRobin, VERSION_1, TIME_INCREMENT); @@ -1979,8 +2542,10 @@ public void wrappedCdsResource() { // Client sent an ACK CDS request. call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); - verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); - verifyGoldenClusterRoundRobin(cdsUpdateCaptor.getValue()); + verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenClusterRoundRobin(statusOrUpdate.getValue()); assertThat(fakeClock.getPendingTasks(CDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataAcked(CDS, CDS_RESOURCE, testClusterRoundRobin, VERSION_1, TIME_INCREMENT); @@ -2000,8 +2565,10 @@ public void cdsResourceFound_leastRequestLbPolicy() { // Client sent an ACK CDS request. call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); - verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); - CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); + verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + CdsUpdate cdsUpdate = statusOrUpdate.getValue(); assertThat(cdsUpdate.clusterName()).isEqualTo(CDS_RESOURCE); assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.EDS); assertThat(cdsUpdate.edsServiceName()).isNull(); @@ -2032,8 +2599,10 @@ public void cdsResourceFound_ringHashLbPolicy() { // Client sent an ACK CDS request. call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); - verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); - CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); + verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + CdsUpdate cdsUpdate = statusOrUpdate.getValue(); assertThat(cdsUpdate.clusterName()).isEqualTo(CDS_RESOURCE); assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.EDS); assertThat(cdsUpdate.edsServiceName()).isNull(); @@ -2063,8 +2632,10 @@ public void cdsResponseWithAggregateCluster() { // Client sent an ACK CDS request. call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); - verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); - CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); + verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + CdsUpdate cdsUpdate = statusOrUpdate.getValue(); assertThat(cdsUpdate.clusterName()).isEqualTo(CDS_RESOURCE); assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.AGGREGATE); LbConfig lbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(cdsUpdate.lbPolicyConfig()); @@ -2077,6 +2648,23 @@ public void cdsResponseWithAggregateCluster() { verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); } + @Test + public void cdsResponseWithEmptyAggregateCluster() { + DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, + cdsResourceWatcher); + List candidates = Arrays.asList(); + Any clusterAggregate = + Any.pack(mf.buildAggregateCluster(CDS_RESOURCE, "round_robin", null, null, candidates)); + call.sendResponse(CDS, clusterAggregate, VERSION_1, "0000"); + + // Client sent an ACK CDS request. + String errorMsg = "CDS response Cluster 'cluster.googleapis.com' validation error: " + + "Cluster cluster.googleapis.com: aggregate ClusterConfig.clusters must not be empty"; + call.verifyRequestNack(CDS, CDS_RESOURCE, "", "0000", NODE, ImmutableList.of(errorMsg)); + verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + verifyStatusWithNodeId(cdsUpdateCaptor.getValue().getStatus(), Code.UNAVAILABLE, errorMsg); + } + @Test public void cdsResponseWithCircuitBreakers() { DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, @@ -2088,8 +2676,10 @@ public void cdsResponseWithCircuitBreakers() { // Client sent an ACK CDS request. call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); - verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); - CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); + verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + CdsUpdate cdsUpdate = statusOrUpdate.getValue(); assertThat(cdsUpdate.clusterName()).isEqualTo(CDS_RESOURCE); assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.EDS); assertThat(cdsUpdate.edsServiceName()).isNull(); @@ -2110,7 +2700,6 @@ public void cdsResponseWithCircuitBreakers() { * CDS response containing UpstreamTlsContext for a cluster. */ @Test - @SuppressWarnings("deprecation") public void cdsResponseWithUpstreamTlsContext() { DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, cdsResourceWatcher); @@ -2123,7 +2712,7 @@ public void cdsResponseWithUpstreamTlsContext() { "envoy.transport_sockets.tls", null, null)); List clusters = ImmutableList.of( Any.pack(mf.buildLogicalDnsCluster("cluster-bar.googleapis.com", - "dns-service-bar.googleapis.com", 443, "round_robin", null, null,false, null, null)), + "dns-service-bar.googleapis.com", 443, "round_robin", null, null, false, null, null)), clusterEds, Any.pack(mf.buildEdsCluster("cluster-baz.googleapis.com", null, "round_robin", null, null, false, null, "envoy.transport_sockets.tls", null, null))); @@ -2132,11 +2721,13 @@ public void cdsResponseWithUpstreamTlsContext() { // Client sent an ACK CDS request. call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); verify(cdsResourceWatcher, times(1)) - .onChanged(cdsUpdateCaptor.capture()); - CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); - CommonTlsContext.CertificateProviderInstance certificateProviderInstance = + .onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + CdsUpdate cdsUpdate = statusOrUpdate.getValue(); + CertificateProviderPluginInstance certificateProviderInstance = cdsUpdate.upstreamTlsContext().getCommonTlsContext().getCombinedValidationContext() - .getValidationContextCertificateProviderInstance(); + .getDefaultValidationContext().getCaCertificateProviderInstance(); assertThat(certificateProviderInstance.getInstanceName()).isEqualTo("cert-instance-name"); assertThat(certificateProviderInstance.getCertificateName()).isEqualTo("cert1"); verifyResourceMetadataAcked(CDS, CDS_RESOURCE, clusterEds, VERSION_1, TIME_INCREMENT); @@ -2147,7 +2738,6 @@ public void cdsResponseWithUpstreamTlsContext() { * CDS response containing new UpstreamTlsContext for a cluster. */ @Test - @SuppressWarnings("deprecation") public void cdsResponseWithNewUpstreamTlsContext() { DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, cdsResourceWatcher); @@ -2155,7 +2745,7 @@ public void cdsResponseWithNewUpstreamTlsContext() { // Management server sends back CDS response with UpstreamTlsContext. Any clusterEds = Any.pack(mf.buildEdsCluster(CDS_RESOURCE, "eds-cluster-foo.googleapis.com", "round_robin", - null, null,true, + null, null, true, mf.buildNewUpstreamTlsContext("cert-instance-name", "cert1"), "envoy.transport_sockets.tls", null, null)); List clusters = ImmutableList.of( @@ -2168,8 +2758,10 @@ public void cdsResponseWithNewUpstreamTlsContext() { // Client sent an ACK CDS request. call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); - verify(cdsResourceWatcher, times(1)).onChanged(cdsUpdateCaptor.capture()); - CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); + verify(cdsResourceWatcher, times(1)).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + CdsUpdate cdsUpdate = statusOrUpdate.getValue(); CertificateProviderPluginInstance certificateProviderInstance = cdsUpdate.upstreamTlsContext().getCommonTlsContext().getValidationContext() .getCaCertificateProviderInstance(); @@ -2196,19 +2788,19 @@ public void cdsResponseErrorHandling_badUpstreamTlsContext() { // The response NACKed with errors indicating indices of the failed resources. String errorMsg = "CDS response Cluster 'cluster.googleapis.com' validation error: " - + "Cluster cluster.googleapis.com: malformed UpstreamTlsContext: " - + "io.grpc.xds.client.XdsResourceType$ResourceInvalidException: " - + "ca_certificate_provider_instance is required in upstream-tls-context"; + + "Cluster cluster.googleapis.com: malformed UpstreamTlsContext: " + + "io.grpc.xds.client.XdsResourceType$ResourceInvalidException: " + + "ca_certificate_provider_instance or system_root_certs is required in " + + "upstream-tls-context"; call.verifyRequestNack(CDS, CDS_RESOURCE, "", "0000", NODE, ImmutableList.of(errorMsg)); - verify(cdsResourceWatcher).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, errorMsg); + verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + verifyStatusWithNodeId(cdsUpdateCaptor.getValue().getStatus(), Code.UNAVAILABLE, errorMsg); } /** * CDS response containing OutlierDetection for a cluster. */ @Test - @SuppressWarnings("deprecation") public void cdsResponseWithOutlierDetection() { DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, cdsResourceWatcher); @@ -2235,7 +2827,7 @@ public void cdsResponseWithOutlierDetection() { "envoy.transport_sockets.tls", null, outlierDetectionXds)); List clusters = ImmutableList.of( Any.pack(mf.buildLogicalDnsCluster("cluster-bar.googleapis.com", - "dns-service-bar.googleapis.com", 443, "round_robin", null, null,false, null, null)), + "dns-service-bar.googleapis.com", 443, "round_robin", null, null, false, null, null)), clusterEds, Any.pack(mf.buildEdsCluster("cluster-baz.googleapis.com", null, "round_robin", null, null, false, null, "envoy.transport_sockets.tls", null, outlierDetectionXds))); @@ -2243,8 +2835,10 @@ public void cdsResponseWithOutlierDetection() { // Client sent an ACK CDS request. call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); - verify(cdsResourceWatcher, times(1)).onChanged(cdsUpdateCaptor.capture()); - CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); + verify(cdsResourceWatcher, times(1)).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + CdsUpdate cdsUpdate = statusOrUpdate.getValue(); // The outlier detection config in CdsUpdate should match what we get from xDS. EnvoyServerProtoData.OutlierDetection outlierDetection = cdsUpdate.outlierDetection(); @@ -2277,7 +2871,6 @@ public void cdsResponseWithOutlierDetection() { * CDS response containing OutlierDetection for a cluster. */ @Test - @SuppressWarnings("deprecation") public void cdsResponseWithInvalidOutlierDetectionNacks() { DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, @@ -2294,7 +2887,7 @@ public void cdsResponseWithInvalidOutlierDetectionNacks() { "envoy.transport_sockets.tls", null, outlierDetectionXds)); List clusters = ImmutableList.of( Any.pack(mf.buildLogicalDnsCluster("cluster-bar.googleapis.com", - "dns-service-bar.googleapis.com", 443, "round_robin", null, null,false, null, null)), + "dns-service-bar.googleapis.com", 443, "round_robin", null, null, false, null, null)), clusterEds, Any.pack(mf.buildEdsCluster("cluster-baz.googleapis.com", null, "round_robin", null, null, false, null, "envoy.transport_sockets.tls", null, outlierDetectionXds))); @@ -2305,8 +2898,8 @@ public void cdsResponseWithInvalidOutlierDetectionNacks() { + "io.grpc.xds.client.XdsResourceType$ResourceInvalidException: outlier_detection " + "max_ejection_percent is > 100"; call.verifyRequestNack(CDS, CDS_RESOURCE, "", "0000", NODE, ImmutableList.of(errorMsg)); - verify(cdsResourceWatcher).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, errorMsg); + verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + verifyStatusWithNodeId(cdsUpdateCaptor.getValue().getStatus(), Code.UNAVAILABLE, errorMsg); } @Test(expected = ResourceInvalidException.class) @@ -2400,8 +2993,8 @@ public void cdsResponseErrorHandling_badTransportSocketName() { String errorMsg = "CDS response Cluster 'cluster.googleapis.com' validation error: " + "transport-socket with name envoy.transport_sockets.bad not supported."; call.verifyRequestNack(CDS, CDS_RESOURCE, "", "0000", NODE, ImmutableList.of(errorMsg)); - verify(cdsResourceWatcher).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, errorMsg); + verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + verifyStatusWithNodeId(cdsUpdateCaptor.getValue().getStatus(), Code.UNAVAILABLE, errorMsg); } @Test @@ -2414,8 +3007,7 @@ public void cdsResponseErrorHandling_xdstpWithoutEdsConfig() { )); final Any okClusterRoundRobin = Any.pack(mf.buildEdsCluster(cdsResourceName, "eds-service-bar.googleapis.com", - "round_robin", null,null, false, null, "envoy.transport_sockets.tls", null, null)); - + "round_robin", null, null, false, null, "envoy.transport_sockets.tls", null, null)); DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), cdsResourceName, cdsResourceWatcher); @@ -2444,8 +3036,10 @@ public void cachedCdsResource_data() { ResourceWatcher watcher = mock(ResourceWatcher.class); xdsClient.watchXdsResource(XdsClusterResource.getInstance(), CDS_RESOURCE, watcher); - verify(watcher).onChanged(cdsUpdateCaptor.capture()); - verifyGoldenClusterRoundRobin(cdsUpdateCaptor.getValue()); + verify(watcher).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenClusterRoundRobin(statusOrUpdate.getValue()); call.verifyNoMoreRequest(); verifyResourceMetadataAcked(CDS, CDS_RESOURCE, testClusterRoundRobin, VERSION_1, TIME_INCREMENT); @@ -2459,10 +3053,12 @@ public void cachedCdsResource_absent() { DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, cdsResourceWatcher); fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); - verify(cdsResourceWatcher).onResourceDoesNotExist(CDS_RESOURCE); + verify(cdsResourceWatcher).onResourceChanged(argThat( + arg -> !arg.hasValue() && arg.getStatus().getDescription().contains(CDS_RESOURCE))); ResourceWatcher watcher = mock(ResourceWatcher.class); - xdsClient.watchXdsResource(XdsClusterResource.getInstance(),CDS_RESOURCE, watcher); - verify(watcher).onResourceDoesNotExist(CDS_RESOURCE); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), CDS_RESOURCE, watcher); + verify(watcher).onResourceChanged(argThat( + arg -> !arg.hasValue() && arg.getStatus().getDescription().contains(CDS_RESOURCE))); call.verifyNoMoreRequest(); verifyResourceMetadataDoesNotExist(CDS, CDS_RESOURCE); verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); @@ -2482,8 +3078,10 @@ public void cdsResourceUpdated() { null, null, false, null, null)); call.sendResponse(CDS, clusterDns, VERSION_1, "0000"); call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); - verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); - CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); + verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + CdsUpdate cdsUpdate = statusOrUpdate.getValue(); assertThat(cdsUpdate.clusterName()).isEqualTo(CDS_RESOURCE); assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.LOGICAL_DNS); assertThat(cdsUpdate.dnsHostName()).isEqualTo(dnsHostAddr + ":" + dnsHostPort); @@ -2505,8 +3103,10 @@ public void cdsResourceUpdated() { )); call.sendResponse(CDS, clusterEds, VERSION_2, "0001"); call.verifyRequest(CDS, CDS_RESOURCE, VERSION_2, "0001", NODE); - verify(cdsResourceWatcher, times(2)).onChanged(cdsUpdateCaptor.capture()); - cdsUpdate = cdsUpdateCaptor.getValue(); + verify(cdsResourceWatcher, times(2)).onResourceChanged(cdsUpdateCaptor.capture()); + statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + cdsUpdate = statusOrUpdate.getValue(); assertThat(cdsUpdate.clusterName()).isEqualTo(CDS_RESOURCE); assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.EDS); assertThat(cdsUpdate.edsServiceName()).isEqualTo(edsService); @@ -2547,27 +3147,27 @@ public void cdsResourceUpdatedWithDuplicate() { // Configure with round robin, the update should be sent to the watcher. call.sendResponse(CDS, roundRobinConfig, VERSION_2, "0001"); - verify(cdsResourceWatcher, times(1)).onChanged(isA(CdsUpdate.class)); + verify(cdsResourceWatcher, times(1)).onResourceChanged(argThat(StatusOr::hasValue)); // Second update is identical, watcher should not get an additional update. call.sendResponse(CDS, roundRobinConfig, VERSION_2, "0002"); - verify(cdsResourceWatcher, times(1)).onChanged(isA(CdsUpdate.class)); + verify(cdsResourceWatcher, times(1)).onResourceChanged(any()); // Now we switch to ring hash so the watcher should be notified. call.sendResponse(CDS, ringHashConfig, VERSION_2, "0003"); - verify(cdsResourceWatcher, times(2)).onChanged(isA(CdsUpdate.class)); + verify(cdsResourceWatcher, times(2)).onResourceChanged(argThat(StatusOr::hasValue)); // Second update to ring hash should not result in watcher being notified. call.sendResponse(CDS, ringHashConfig, VERSION_2, "0004"); - verify(cdsResourceWatcher, times(2)).onChanged(isA(CdsUpdate.class)); + verify(cdsResourceWatcher, times(2)).onResourceChanged(any()); // Now we switch to least request so the watcher should be notified. call.sendResponse(CDS, leastRequestConfig, VERSION_2, "0005"); - verify(cdsResourceWatcher, times(3)).onChanged(isA(CdsUpdate.class)); + verify(cdsResourceWatcher, times(3)).onResourceChanged(argThat(StatusOr::hasValue)); // Second update to least request should not result in watcher being notified. call.sendResponse(CDS, leastRequestConfig, VERSION_2, "0006"); - verify(cdsResourceWatcher, times(3)).onChanged(isA(CdsUpdate.class)); + verify(cdsResourceWatcher, times(3)).onResourceChanged(any()); } @Test @@ -2581,8 +3181,10 @@ public void cdsResourceDeleted() { // Initial CDS response. call.sendResponse(CDS, testClusterRoundRobin, VERSION_1, "0000"); call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); - verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); - verifyGoldenClusterRoundRobin(cdsUpdateCaptor.getValue()); + verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenClusterRoundRobin(statusOrUpdate.getValue()); verifyResourceMetadataAcked(CDS, CDS_RESOURCE, testClusterRoundRobin, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); @@ -2590,7 +3192,8 @@ public void cdsResourceDeleted() { // Empty CDS response deletes the cluster. call.sendResponse(CDS, Collections.emptyList(), VERSION_2, "0001"); call.verifyRequest(CDS, CDS_RESOURCE, VERSION_2, "0001", NODE); - verify(cdsResourceWatcher).onResourceDoesNotExist(CDS_RESOURCE); + verify(cdsResourceWatcher).onResourceChanged(argThat( + arg -> !arg.hasValue() && arg.getStatus().getDescription().contains(CDS_RESOURCE))); verifyResourceMetadataDoesNotExist(CDS, CDS_RESOURCE); verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); } @@ -2598,7 +3201,7 @@ public void cdsResourceDeleted() { /** * When ignore_resource_deletion server feature is on, xDS client should keep the deleted cluster * on empty response, and resume the normal work when CDS contains the cluster again. - * */ + */ @Test public void cdsResourceDeleted_ignoreResourceDeletion() { Assume.assumeTrue(ignoreResourceDeletion()); @@ -2610,8 +3213,10 @@ public void cdsResourceDeleted_ignoreResourceDeletion() { // Initial CDS response. call.sendResponse(CDS, testClusterRoundRobin, VERSION_1, "0000"); call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); - verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); - verifyGoldenClusterRoundRobin(cdsUpdateCaptor.getValue()); + verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenClusterRoundRobin(statusOrUpdate.getValue()); verifyResourceMetadataAcked(CDS, CDS_RESOURCE, testClusterRoundRobin, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); @@ -2625,28 +3230,253 @@ public void cdsResourceDeleted_ignoreResourceDeletion() { TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); // onResourceDoesNotExist must not be called. - verify(ldsResourceWatcher, never()).onResourceDoesNotExist(CDS_RESOURCE); + verify(ldsResourceWatcher, never()).onResourceChanged(argThat( + arg -> !arg.hasValue() && arg.getStatus().getDescription().contains(CDS_RESOURCE))); // Next update is correct, and contains the cluster again. call.sendResponse(CDS, testClusterRoundRobin, VERSION_3, "0003"); call.verifyRequest(CDS, CDS_RESOURCE, VERSION_3, "0003", NODE); - verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); - verifyGoldenClusterRoundRobin(cdsUpdateCaptor.getValue()); + verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenClusterRoundRobin(statusOrUpdate.getValue()); verifyResourceMetadataAcked(CDS, CDS_RESOURCE, testClusterRoundRobin, VERSION_3, TIME_INCREMENT * 3); verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); verifyNoMoreInteractions(ldsResourceWatcher); } + /** + * When fail_on_data_errors server feature is on, xDS client should delete the cached cluster + * and fail RPCs when CDS resource is deleted. + */ + @Test + public void cdsResourceDeleted_failOnDataErrors_true() { + BootstrapperImpl.xdsDataErrorHandlingEnabled = true; + xdsServerInfo = ServerInfo.create(SERVER_URI, CHANNEL_CREDENTIALS, false, + true, false, true); + BootstrapInfo bootstrapInfo = + Bootstrapper.BootstrapInfo.builder() + .servers(Collections.singletonList(xdsServerInfo)) + .node(NODE) + .authorities(ImmutableMap.of( + "", + AuthorityInfo.create( + "xdstp:///envoy.config.listener.v3.Listener/%s", + ImmutableList.of(Bootstrapper.ServerInfo.create( + SERVER_URI_EMPTY_AUTHORITY, CHANNEL_CREDENTIALS))))) + .certProviders(ImmutableMap.of()) + .build(); + xdsClient = new XdsClientImpl( + xdsTransportFactory, + bootstrapInfo, + fakeClock.getScheduledExecutorService(), + backoffPolicyProvider, + fakeClock.getStopwatchSupplier(), + timeProvider, + MessagePrinter.INSTANCE, + new TlsContextManagerImpl(bootstrapInfo), + xdsClientMetricReporter); + + DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, + cdsResourceWatcher); + verifyResourceMetadataRequested(CDS, CDS_RESOURCE); + + // Initial CDS response. + call.sendResponse(CDS, testClusterRoundRobin, VERSION_1, "0000"); + call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); + verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenClusterRoundRobin(statusOrUpdate.getValue()); + verifyResourceMetadataAcked(CDS, CDS_RESOURCE, testClusterRoundRobin, VERSION_1, + TIME_INCREMENT); + verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); + + // Empty CDS response deletes the cluster and fails RPCs. + call.sendResponse(CDS, Collections.emptyList(), VERSION_2, "0001"); + call.verifyRequest(CDS, CDS_RESOURCE, VERSION_2, "0001", NODE); + verify(cdsResourceWatcher).onResourceChanged(argThat( + arg -> !arg.hasValue() && arg.getStatus().getDescription().contains(CDS_RESOURCE))); + verifyResourceMetadataDoesNotExist(CDS, CDS_RESOURCE); + verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); + + BootstrapperImpl.xdsDataErrorHandlingEnabled = false; + } + + /** + * When fail_on_data_errors server feature is on, xDS client should delete the cached cluster + * and fail RPCs when CDS resource is deleted. + */ + @Test + public void cdsResourceDeleted_failOnDataErrors_false() { + BootstrapperImpl.xdsDataErrorHandlingEnabled = true; + // Set failOnDataErrors to false for this test case. + xdsServerInfo = ServerInfo.create(SERVER_URI, CHANNEL_CREDENTIALS, false, + true, false, false); + BootstrapInfo bootstrapInfo = + Bootstrapper.BootstrapInfo.builder() + .servers(Collections.singletonList(xdsServerInfo)) + .node(NODE) + .authorities(ImmutableMap.of( + "", + AuthorityInfo.create( + "xdstp:///envoy.config.listener.v3.Listener/%s", + ImmutableList.of(Bootstrapper.ServerInfo.create( + SERVER_URI_EMPTY_AUTHORITY, CHANNEL_CREDENTIALS))))) + .certProviders(ImmutableMap.of()) + .build(); + xdsClient = new XdsClientImpl( + xdsTransportFactory, + bootstrapInfo, + fakeClock.getScheduledExecutorService(), + backoffPolicyProvider, + fakeClock.getStopwatchSupplier(), + timeProvider, + MessagePrinter.INSTANCE, + new TlsContextManagerImpl(bootstrapInfo), + xdsClientMetricReporter); + + InOrder inOrder = inOrder(cdsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, + cdsResourceWatcher); + verifyResourceMetadataRequested(CDS, CDS_RESOURCE); + + // Initial CDS response. + call.sendResponse(CDS, testClusterRoundRobin, VERSION_1, "0000"); + call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); + inOrder.verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + verifyGoldenClusterRoundRobin(statusOrUpdate.getValue()); + verifyResourceMetadataAcked(CDS, CDS_RESOURCE, testClusterRoundRobin, VERSION_1, + TIME_INCREMENT); + verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); + + // Empty CDS response should trigger an ambient error. + call.sendResponse(CDS, Collections.emptyList(), VERSION_2, "0001"); + call.verifyRequest(CDS, CDS_RESOURCE, VERSION_2, "0001", NODE); + + // Verify that onAmbientError() is called. + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + inOrder.verify(cdsResourceWatcher).onAmbientError(statusCaptor.capture()); + Status receivedStatus = statusCaptor.getValue(); + assertThat(receivedStatus.getCode()).isEqualTo(Status.Code.NOT_FOUND); + assertThat(receivedStatus.getDescription()).contains( + "Resource " + CDS_RESOURCE + " deleted from server"); + + // Verify that onResourceChanged() is NOT called again. + inOrder.verify(cdsResourceWatcher, never()).onResourceChanged(any()); + verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); + + BootstrapperImpl.xdsDataErrorHandlingEnabled = false; + } + + /** + * Tests that a NACKed LDS resource update drops the cached resource when fail_on_data_errors + * is enabled. + */ + @Test + public void ldsResourceNacked_withFailOnDataErrors_dropsResource() { + BootstrapperImpl.xdsDataErrorHandlingEnabled = true; + xdsServerInfo = ServerInfo.create(SERVER_URI, CHANNEL_CREDENTIALS, false, + true, false, true); + BootstrapInfo bootstrapInfo = + Bootstrapper.BootstrapInfo.builder() + .servers(Collections.singletonList(xdsServerInfo)) + .node(NODE) + .build(); + xdsClient = new XdsClientImpl( + xdsTransportFactory, + bootstrapInfo, + fakeClock.getScheduledExecutorService(), + backoffPolicyProvider, + fakeClock.getStopwatchSupplier(), + timeProvider, + MessagePrinter.INSTANCE, + new TlsContextManagerImpl(bootstrapInfo), + xdsClientMetricReporter); + + InOrder inOrder = inOrder(ldsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, + ldsResourceWatcher); + call.sendResponse(LDS, testListenerVhosts, VERSION_1, "0000"); + call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); + inOrder.verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr initialUpdate = ldsUpdateCaptor.getValue(); + assertThat(initialUpdate.hasValue()).isTrue(); + verifyGoldenListenerVhosts(initialUpdate.getValue()); + Message invalidListener = mf.buildListenerWithApiListenerInvalid(LDS_RESOURCE); + call.sendResponse(LDS, Collections.singletonList(Any.pack(invalidListener)), VERSION_2, "0001"); + String expectedError = "LDS response Listener '" + LDS_RESOURCE + "' validation error"; + call.verifyRequestNack(LDS, LDS_RESOURCE, VERSION_1, "0001", NODE, + Collections.singletonList(expectedError)); + + inOrder.verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr finalUpdate = ldsUpdateCaptor.getValue(); + assertThat(finalUpdate.hasValue()).isFalse(); + assertThat(finalUpdate.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(finalUpdate.getStatus().getDescription()).contains(expectedError); + + BootstrapperImpl.xdsDataErrorHandlingEnabled = false; + } + + /** + * Tests that a NACKed LDS resource update is treated as an ambient error when + * fail_on_data_errors is disabled. + */ + @Test + public void ldsResourceNacked_withFailOnDataErrorsDisabled_isAmbientError() { + BootstrapperImpl.xdsDataErrorHandlingEnabled = true; + xdsServerInfo = ServerInfo.create(SERVER_URI, CHANNEL_CREDENTIALS, false, + true, false, false); + BootstrapInfo bootstrapInfo = + Bootstrapper.BootstrapInfo.builder() + .servers(Collections.singletonList(xdsServerInfo)) + .node(NODE) + .build(); + xdsClient = new XdsClientImpl( + xdsTransportFactory, + bootstrapInfo, + fakeClock.getScheduledExecutorService(), + backoffPolicyProvider, + fakeClock.getStopwatchSupplier(), + timeProvider, + MessagePrinter.INSTANCE, + new TlsContextManagerImpl(bootstrapInfo), + xdsClientMetricReporter); + InOrder inOrder = inOrder(ldsResourceWatcher); + DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, + ldsResourceWatcher); + + call.sendResponse(LDS, testListenerVhosts, VERSION_1, "0000"); + call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0000", NODE); + inOrder.verify(ldsResourceWatcher).onResourceChanged(any()); + Message invalidListener = mf.buildListenerWithApiListenerInvalid(LDS_RESOURCE); + call.sendResponse(LDS, Collections.singletonList(Any.pack(invalidListener)), VERSION_2, "0001"); + + String expectedError = "LDS response Listener '" + LDS_RESOURCE + "' validation error"; + call.verifyRequestNack(LDS, LDS_RESOURCE, VERSION_1, "0001", NODE, + Collections.singletonList(expectedError)); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + inOrder.verify(ldsResourceWatcher).onAmbientError(statusCaptor.capture()); + Status receivedStatus = statusCaptor.getValue(); + assertThat(receivedStatus.getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(receivedStatus.getDescription()).contains(expectedError); + inOrder.verify(ldsResourceWatcher, never()).onResourceChanged(any()); + + BootstrapperImpl.xdsDataErrorHandlingEnabled = false; + } + @Test @SuppressWarnings("unchecked") public void multipleCdsWatchers() { String cdsResourceTwo = "cluster-bar.googleapis.com"; ResourceWatcher watcher1 = mock(ResourceWatcher.class); ResourceWatcher watcher2 = mock(ResourceWatcher.class); - xdsClient.watchXdsResource(XdsClusterResource.getInstance(),CDS_RESOURCE, cdsResourceWatcher); - xdsClient.watchXdsResource(XdsClusterResource.getInstance(),cdsResourceTwo, watcher1); - xdsClient.watchXdsResource(XdsClusterResource.getInstance(),cdsResourceTwo, watcher2); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), CDS_RESOURCE, cdsResourceWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), cdsResourceTwo, watcher1); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), cdsResourceTwo, watcher2); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); call.verifyRequest(CDS, Arrays.asList(CDS_RESOURCE, cdsResourceTwo), "", "", NODE); verifyResourceMetadataRequested(CDS, CDS_RESOURCE); @@ -2654,9 +3484,12 @@ public void multipleCdsWatchers() { verifySubscribedResourcesMetadataSizes(0, 2, 0, 0); fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); - verify(cdsResourceWatcher).onResourceDoesNotExist(CDS_RESOURCE); - verify(watcher1).onResourceDoesNotExist(cdsResourceTwo); - verify(watcher2).onResourceDoesNotExist(cdsResourceTwo); + verify(cdsResourceWatcher).onResourceChanged(argThat(statusOr -> + statusOr.getStatus().getDescription().contains(CDS_RESOURCE))); + verify(watcher1).onResourceChanged(argThat(statusOr -> + statusOr.getStatus().getDescription().contains(cdsResourceTwo))); + verify(watcher2).onResourceChanged(argThat(statusOr -> + statusOr.getStatus().getDescription().contains(cdsResourceTwo))); verifyResourceMetadataDoesNotExist(CDS, CDS_RESOURCE); verifyResourceMetadataDoesNotExist(CDS, cdsResourceTwo); verifySubscribedResourcesMetadataSizes(0, 2, 0, 0); @@ -2670,45 +3503,54 @@ public void multipleCdsWatchers() { Any.pack(mf.buildEdsCluster(cdsResourceTwo, edsService, "round_robin", null, null, true, null, "envoy.transport_sockets.tls", null, null))); call.sendResponse(CDS, clusters, VERSION_1, "0000"); - verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); - CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); - assertThat(cdsUpdate.clusterName()).isEqualTo(CDS_RESOURCE); - assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.LOGICAL_DNS); - assertThat(cdsUpdate.dnsHostName()).isEqualTo(dnsHostAddr + ":" + dnsHostPort); - LbConfig lbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(cdsUpdate.lbPolicyConfig()); + ArgumentCaptor> cdsUpdateCaptor = ArgumentCaptor.forClass(StatusOr.class); + verify(cdsResourceWatcher, times(2)).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr capturedUpdate1 = cdsUpdateCaptor.getAllValues().get(1); + assertThat(capturedUpdate1.hasValue()).isTrue(); + CdsUpdate cdsUpdate1 = capturedUpdate1.getValue(); + assertThat(cdsUpdate1.clusterName()).isEqualTo(CDS_RESOURCE); + assertThat(cdsUpdate1.clusterType()).isEqualTo(ClusterType.LOGICAL_DNS); + assertThat(cdsUpdate1.dnsHostName()).isEqualTo(dnsHostAddr + ":" + dnsHostPort); + LbConfig lbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(cdsUpdate1.lbPolicyConfig()); assertThat(lbConfig.getPolicyName()).isEqualTo("wrr_locality_experimental"); List childConfigs = ServiceConfigUtil.unwrapLoadBalancingConfigList( JsonUtil.getListOfObjects(lbConfig.getRawConfigValue(), "childPolicy")); assertThat(childConfigs.get(0).getPolicyName()).isEqualTo("round_robin"); - assertThat(cdsUpdate.lrsServerInfo()).isNull(); - assertThat(cdsUpdate.maxConcurrentRequests()).isNull(); - assertThat(cdsUpdate.upstreamTlsContext()).isNull(); - verify(watcher1).onChanged(cdsUpdateCaptor.capture()); - cdsUpdate = cdsUpdateCaptor.getValue(); - assertThat(cdsUpdate.clusterName()).isEqualTo(cdsResourceTwo); - assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.EDS); - assertThat(cdsUpdate.edsServiceName()).isEqualTo(edsService); - lbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(cdsUpdate.lbPolicyConfig()); + assertThat(cdsUpdate1.lrsServerInfo()).isNull(); + assertThat(cdsUpdate1.maxConcurrentRequests()).isNull(); + assertThat(cdsUpdate1.upstreamTlsContext()).isNull(); + ArgumentCaptor> watcher1Captor = ArgumentCaptor.forClass(StatusOr.class); + verify(watcher1, times(2)).onResourceChanged(watcher1Captor.capture()); + StatusOr capturedUpdate2 = watcher1Captor.getAllValues().get(1); + assertThat(capturedUpdate2.hasValue()).isTrue(); + CdsUpdate cdsUpdate2 = capturedUpdate2.getValue(); + assertThat(cdsUpdate2.clusterName()).isEqualTo(cdsResourceTwo); + assertThat(cdsUpdate2.clusterType()).isEqualTo(ClusterType.EDS); + assertThat(cdsUpdate2.edsServiceName()).isEqualTo(edsService); + lbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(cdsUpdate2.lbPolicyConfig()); assertThat(lbConfig.getPolicyName()).isEqualTo("wrr_locality_experimental"); childConfigs = ServiceConfigUtil.unwrapLoadBalancingConfigList( JsonUtil.getListOfObjects(lbConfig.getRawConfigValue(), "childPolicy")); assertThat(childConfigs.get(0).getPolicyName()).isEqualTo("round_robin"); - assertThat(cdsUpdate.lrsServerInfo()).isEqualTo(xdsServerInfo); - assertThat(cdsUpdate.maxConcurrentRequests()).isNull(); - assertThat(cdsUpdate.upstreamTlsContext()).isNull(); - verify(watcher2).onChanged(cdsUpdateCaptor.capture()); - cdsUpdate = cdsUpdateCaptor.getValue(); - assertThat(cdsUpdate.clusterName()).isEqualTo(cdsResourceTwo); - assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.EDS); - assertThat(cdsUpdate.edsServiceName()).isEqualTo(edsService); - lbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(cdsUpdate.lbPolicyConfig()); + assertThat(cdsUpdate2.lrsServerInfo()).isEqualTo(xdsServerInfo); + assertThat(cdsUpdate2.maxConcurrentRequests()).isNull(); + assertThat(cdsUpdate2.upstreamTlsContext()).isNull(); + ArgumentCaptor> watcher2Captor = ArgumentCaptor.forClass(StatusOr.class); + verify(watcher2, times(2)).onResourceChanged(watcher2Captor.capture()); + StatusOr capturedUpdate3 = watcher2Captor.getAllValues().get(1); + assertThat(capturedUpdate3.hasValue()).isTrue(); + CdsUpdate cdsUpdate3 = capturedUpdate3.getValue(); + assertThat(cdsUpdate3.clusterName()).isEqualTo(cdsResourceTwo); + assertThat(cdsUpdate3.clusterType()).isEqualTo(ClusterType.EDS); + assertThat(cdsUpdate3.edsServiceName()).isEqualTo(edsService); + lbConfig = ServiceConfigUtil.unwrapLoadBalancingConfig(cdsUpdate3.lbPolicyConfig()); assertThat(lbConfig.getPolicyName()).isEqualTo("wrr_locality_experimental"); childConfigs = ServiceConfigUtil.unwrapLoadBalancingConfigList( JsonUtil.getListOfObjects(lbConfig.getRawConfigValue(), "childPolicy")); assertThat(childConfigs.get(0).getPolicyName()).isEqualTo("round_robin"); - assertThat(cdsUpdate.lrsServerInfo()).isEqualTo(xdsServerInfo); - assertThat(cdsUpdate.maxConcurrentRequests()).isNull(); - assertThat(cdsUpdate.upstreamTlsContext()).isNull(); + assertThat(cdsUpdate3.lrsServerInfo()).isEqualTo(xdsServerInfo); + assertThat(cdsUpdate3.maxConcurrentRequests()).isNull(); + assertThat(cdsUpdate3.upstreamTlsContext()).isNull(); // Metadata of both clusters is stored. verifyResourceMetadataAcked(CDS, CDS_RESOURCE, clusters.get(0), VERSION_1, TIME_INCREMENT); verifyResourceMetadataAcked(CDS, cdsResourceTwo, clusters.get(1), VERSION_1, TIME_INCREMENT); @@ -2732,12 +3574,46 @@ public void edsResourceNotFound() { verifySubscribedResourcesMetadataSizes(0, 0, 0, 1); // Server failed to return subscribed resource within expected time window. fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); - verify(edsResourceWatcher).onResourceDoesNotExist(EDS_RESOURCE); + verify(edsResourceWatcher).onResourceChanged(argThat(statusOr -> + statusOr.getStatus().getDescription().contains(EDS_RESOURCE))); assertThat(fakeClock.getPendingTasks(EDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataDoesNotExist(EDS, EDS_RESOURCE); verifySubscribedResourcesMetadataSizes(0, 0, 0, 1); } + @Test + public void edsCleanupNonceAfterUnsubscription() { + Assume.assumeFalse(ignoreResourceDeletion()); + + // Suppose we have an EDS subscription A.1 + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), "A.1", edsResourceWatcher); + DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); + assertThat(call).isNotNull(); + call.verifyRequest(EDS, "A.1", "", "", NODE); + + // EDS -> {A.1}, version 1 + List dropOverloads = ImmutableList.of(); + List endpointsV1 = ImmutableList.of(lbEndpointHealthy); + ImmutableMap resourcesV1 = ImmutableMap.of( + "A.1", Any.pack(mf.buildClusterLoadAssignment("A.1", endpointsV1, dropOverloads))); + call.sendResponse(EDS, resourcesV1.values().asList(), VERSION_1, "0000"); + // {A.1} -> ACK, version 1 + call.verifyRequest(EDS, "A.1", VERSION_1, "0000", NODE); + verify(edsResourceWatcher, times(1)).onResourceChanged(any()); + + // trigger an EDS resource unsubscription. + xdsClient.cancelXdsResourceWatch(XdsEndpointResource.getInstance(), "A.1", edsResourceWatcher); + verifySubscribedResourcesMetadataSizes(0, 0, 0, 0); + call.verifyRequest(EDS, Arrays.asList(), VERSION_1, "0000", NODE); + // The control plane can send an updated response for the empty subscription list, with a new + // nonce. + call.sendResponse(EDS, Arrays.asList(), VERSION_1, "0001"); + + // When re-subscribing, the version was forgotten but not the nonce + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), "A.1", edsResourceWatcher); + call.verifyRequest(EDS, "A.1", "", "0001", NODE, Mockito.timeout(2000)); + } + @Test public void edsResponseErrorHandling_allResourcesFailedUnpack() { DiscoveryRpcCall call = startResourceWatcher(XdsEndpointResource.getInstance(), EDS_RESOURCE, @@ -2774,29 +3650,31 @@ public void edsResponseErrorHandling_someResourcesFailedUnpack() { verifySubscribedResourcesMetadataSizes(0, 0, 0, 1); // The response is NACKed with the same error message. call.verifyRequestNack(EDS, EDS_RESOURCE, "", "0000", NODE, errors); - verify(edsResourceWatcher).onChanged(edsUpdateCaptor.capture()); - EdsUpdate edsUpdate = edsUpdateCaptor.getValue(); + verify(edsResourceWatcher).onResourceChanged(edsUpdateCaptor.capture()); + StatusOr statusOrUpdate = edsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + EdsUpdate edsUpdate = statusOrUpdate.getValue(); assertThat(edsUpdate.clusterName).isEqualTo(EDS_RESOURCE); } /** * Tests a subscribed EDS resource transitioned to and from the invalid state. * - * @see - * A40-csds-support.md + * @see + * A40-csds-support.md */ @Test public void edsResponseErrorHandling_subscribedResourceInvalid() { List subscribedResourceNames = ImmutableList.of("A", "B", "C"); - xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),"A", edsResourceWatcher); - xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),"B", edsResourceWatcher); - xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),"C", edsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), "A", edsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), "B", edsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), "C", edsResourceWatcher); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); assertThat(call).isNotNull(); verifyResourceMetadataRequested(EDS, "A"); verifyResourceMetadataRequested(EDS, "B"); verifyResourceMetadataRequested(EDS, "C"); - verifySubscribedResourcesMetadataSizes(0, 0, 0, 3); // EDS -> {A, B, C}, version 1 List dropOverloads = ImmutableList.of(mf.buildDropOverload("lb", 200)); @@ -2807,6 +3685,7 @@ public void edsResponseErrorHandling_subscribedResourceInvalid() { "C", Any.pack(mf.buildClusterLoadAssignment("C", endpointsV1, dropOverloads))); call.sendResponse(EDS, resourcesV1.values().asList(), VERSION_1, "0000"); // {A, B, C} -> ACK, version 1 + verifyResourceValidInvalidCount(1, 3, 0, xdsServerInfo.target(), EDS.typeUrl()); verifyResourceMetadataAcked(EDS, "A", resourcesV1.get("A"), VERSION_1, TIME_INCREMENT); verifyResourceMetadataAcked(EDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT); verifyResourceMetadataAcked(EDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); @@ -2822,11 +3701,13 @@ public void edsResponseErrorHandling_subscribedResourceInvalid() { // {A} -> ACK, version 2 // {B} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B // {C} -> ACK, version 1 + // Check metric data. + verifyResourceValidInvalidCount(1, 1, 1, xdsServerInfo.target(), EDS.typeUrl()); List errorsV2 = ImmutableList.of("EDS response ClusterLoadAssignment 'B' validation error: "); verifyResourceMetadataAcked(EDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 2); verifyResourceMetadataNacked(EDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT, - VERSION_2, TIME_INCREMENT * 2, errorsV2); + VERSION_2, TIME_INCREMENT * 2, errorsV2, true); verifyResourceMetadataAcked(EDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); call.verifyRequestNack(EDS, subscribedResourceNames, VERSION_1, "0001", NODE, errorsV2); @@ -2839,6 +3720,8 @@ public void edsResponseErrorHandling_subscribedResourceInvalid() { call.sendResponse(EDS, resourcesV3.values().asList(), VERSION_3, "0002"); // {A} -> ACK, version 2 // {B, C} -> ACK, version 3 + // Check metric data. + verifyResourceValidInvalidCount(1, 2, 0, xdsServerInfo.target(), EDS.typeUrl()); verifyResourceMetadataAcked(EDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 2); verifyResourceMetadataAcked(EDS, "B", resourcesV3.get("B"), VERSION_3, TIME_INCREMENT * 3); verifyResourceMetadataAcked(EDS, "C", resourcesV3.get("C"), VERSION_3, TIME_INCREMENT * 3); @@ -2854,8 +3737,10 @@ public void edsResourceFound() { // Client sent an ACK EDS request. call.verifyRequest(EDS, EDS_RESOURCE, VERSION_1, "0000", NODE); - verify(edsResourceWatcher).onChanged(edsUpdateCaptor.capture()); - validateGoldenClusterLoadAssignment(edsUpdateCaptor.getValue()); + verify(edsResourceWatcher).onResourceChanged(edsUpdateCaptor.capture()); + StatusOr statusOrUpdate = edsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + validateGoldenClusterLoadAssignment(statusOrUpdate.getValue()); verifyResourceMetadataAcked(EDS, EDS_RESOURCE, testClusterLoadAssignment, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(0, 0, 0, 1); @@ -2869,8 +3754,10 @@ public void wrappedEdsResourceFound() { // Client sent an ACK EDS request. call.verifyRequest(EDS, EDS_RESOURCE, VERSION_1, "0000", NODE); - verify(edsResourceWatcher).onChanged(edsUpdateCaptor.capture()); - validateGoldenClusterLoadAssignment(edsUpdateCaptor.getValue()); + verify(edsResourceWatcher).onResourceChanged(edsUpdateCaptor.capture()); + StatusOr statusOrUpdate = edsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + validateGoldenClusterLoadAssignment(statusOrUpdate.getValue()); verifyResourceMetadataAcked(EDS, EDS_RESOURCE, testClusterLoadAssignment, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(0, 0, 0, 1); @@ -2887,9 +3774,11 @@ public void cachedEdsResource_data() { call.verifyRequest(EDS, EDS_RESOURCE, VERSION_1, "0000", NODE); // Add another watcher. ResourceWatcher watcher = mock(ResourceWatcher.class); - xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),EDS_RESOURCE, watcher); - verify(watcher).onChanged(edsUpdateCaptor.capture()); - validateGoldenClusterLoadAssignment(edsUpdateCaptor.getValue()); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), EDS_RESOURCE, watcher); + verify(watcher).onResourceChanged(edsUpdateCaptor.capture()); + StatusOr statusOrUpdate = edsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + validateGoldenClusterLoadAssignment(statusOrUpdate.getValue()); call.verifyNoMoreRequest(); verifyResourceMetadataAcked(EDS, EDS_RESOURCE, testClusterLoadAssignment, VERSION_1, TIME_INCREMENT); @@ -2902,10 +3791,12 @@ public void cachedEdsResource_absent() { DiscoveryRpcCall call = startResourceWatcher(XdsEndpointResource.getInstance(), EDS_RESOURCE, edsResourceWatcher); fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); - verify(edsResourceWatcher).onResourceDoesNotExist(EDS_RESOURCE); + verify(edsResourceWatcher).onResourceChanged(argThat(statusOr -> + statusOr.getStatus().getDescription().contains(EDS_RESOURCE))); ResourceWatcher watcher = mock(ResourceWatcher.class); - xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),EDS_RESOURCE, watcher); - verify(watcher).onResourceDoesNotExist(EDS_RESOURCE); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), EDS_RESOURCE, watcher); + verify(watcher).onResourceChanged(argThat(statusOr -> + statusOr.getStatus().getDescription().contains(EDS_RESOURCE))); call.verifyNoMoreRequest(); verifyResourceMetadataDoesNotExist(EDS, EDS_RESOURCE); verifySubscribedResourcesMetadataSizes(0, 0, 0, 1); @@ -2923,6 +3814,7 @@ public void flowControlAbsent() throws Exception { anotherWatcher, fakeWatchClock.getScheduledExecutorService()); verifyResourceMetadataRequested(CDS, CDS_RESOURCE); verifyResourceMetadataRequested(CDS, anotherCdsResource); + DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); call.verifyRequest(CDS, Arrays.asList(CDS_RESOURCE, anotherCdsResource), "", "", NODE); assertThat(fakeWatchClock.runDueTasks()).isEqualTo(2); @@ -2935,7 +3827,7 @@ public void flowControlAbsent() throws Exception { fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); assertThat(fakeWatchClock.getPendingTasks().size()).isEqualTo(2); CyclicBarrier barrier = new CyclicBarrier(2); - doAnswer(blockUpdate(barrier)).when(cdsResourceWatcher).onChanged(any(CdsUpdate.class)); + doAnswer(blockUpdate(barrier)).when(cdsResourceWatcher).onResourceChanged(any()); CountDownLatch latch = new CountDownLatch(1); new Thread(() -> { @@ -2957,16 +3849,120 @@ public void flowControlAbsent() throws Exception { verifyResourceMetadataAcked( CDS, CDS_RESOURCE, testClusterRoundRobin, VERSION_1, TIME_INCREMENT); barrier.await(); - verify(cdsResourceWatcher, atLeastOnce()).onChanged(any()); + verify(cdsResourceWatcher, atLeastOnce()).onResourceChanged(any()); String errorMsg = "CDS response Cluster 'cluster.googleapis.com2' validation error: " + "Cluster cluster.googleapis.com2: unspecified cluster discovery type"; call.verifyRequestNack(CDS, Arrays.asList(CDS_RESOURCE, anotherCdsResource), VERSION_1, "0001", NODE, Arrays.asList(errorMsg)); barrier.await(); latch.await(10, TimeUnit.SECONDS); - verify(cdsResourceWatcher, times(2)).onChanged(any()); - verify(anotherWatcher).onResourceDoesNotExist(eq(anotherCdsResource)); - verify(anotherWatcher).onError(any()); + verify(cdsResourceWatcher, times(2)).onResourceChanged(any()); + verify(anotherWatcher, times(2)).onResourceChanged( + argThat(statusOr -> statusOr.getStatus().getDescription().contains(anotherCdsResource))); + } + + @Test + public void resourceTimerIsTransientError_schedulesExtendedTimeout() { + BootstrapperImpl.xdsDataErrorHandlingEnabled = true; + ServerInfo serverInfo = ServerInfo.create(SERVER_URI, CHANNEL_CREDENTIALS, + false, true, true, false); + BootstrapInfo bootstrapInfo = + Bootstrapper.BootstrapInfo.builder() + .servers(Collections.singletonList(serverInfo)) + .node(NODE) + .authorities(ImmutableMap.of( + "", + AuthorityInfo.create( + "xdstp:///envoy.config.listener.v3.Listener/%s", + ImmutableList.of(Bootstrapper.ServerInfo.create( + SERVER_URI_EMPTY_AUTHORITY, CHANNEL_CREDENTIALS))))) + .certProviders(ImmutableMap.of()) + .build(); + xdsClient = new XdsClientImpl( + xdsTransportFactory, + bootstrapInfo, + fakeClock.getScheduledExecutorService(), + backoffPolicyProvider, + fakeClock.getStopwatchSupplier(), + timeProvider, + MessagePrinter.INSTANCE, + new TlsContextManagerImpl(bootstrapInfo), + xdsClientMetricReporter); + @SuppressWarnings("unchecked") + ResourceWatcher watcher = mock(ResourceWatcher.class); + String resourceName = "cluster.googleapis.com"; + + xdsClient.watchXdsResource( + XdsClusterResource.getInstance(), + resourceName, + watcher, + fakeClock.getScheduledExecutorService()); + + ScheduledTask task = Iterables.getOnlyElement( + fakeClock.getPendingTasks(CDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)); + assertThat(task.getDelay(TimeUnit.SECONDS)) + .isEqualTo(XdsClientImpl.EXTENDED_RESOURCE_FETCH_TIMEOUT_SEC); + fakeClock.runDueTasks(); + BootstrapperImpl.xdsDataErrorHandlingEnabled = false; + } + + @Test + public void resourceTimerIsTransientError_callsOnErrorUnavailable() { + BootstrapperImpl.xdsDataErrorHandlingEnabled = true; + xdsServerInfo = ServerInfo.create(SERVER_URI, CHANNEL_CREDENTIALS, ignoreResourceDeletion(), + true, true, false); + BootstrapInfo bootstrapInfo = + Bootstrapper.BootstrapInfo.builder() + .servers(Collections.singletonList(xdsServerInfo)) + .node(NODE) + .authorities(ImmutableMap.of( + "authority.xds.com", + AuthorityInfo.create( + "xdstp://authority.xds.com/envoy.config.listener.v3.Listener/%s", + ImmutableList.of(Bootstrapper.ServerInfo.create( + SERVER_URI_CUSTOM_AUTHORITY, CHANNEL_CREDENTIALS))), + "", + AuthorityInfo.create( + "xdstp:///envoy.config.listener.v3.Listener/%s", + ImmutableList.of(Bootstrapper.ServerInfo.create( + SERVER_URI_EMPTY_AUTHORITY, CHANNEL_CREDENTIALS))))) + .certProviders(ImmutableMap.of("cert-instance-name", + CertificateProviderInfo.create("file-watcher", ImmutableMap.of()))) + .build(); + xdsClient = new XdsClientImpl( + xdsTransportFactory, + bootstrapInfo, + fakeClock.getScheduledExecutorService(), + backoffPolicyProvider, + fakeClock.getStopwatchSupplier(), + timeProvider, + MessagePrinter.INSTANCE, + new TlsContextManagerImpl(bootstrapInfo), + xdsClientMetricReporter); + String timeoutResource = CDS_RESOURCE + "_timeout"; + @SuppressWarnings("unchecked") + ResourceWatcher timeoutWatcher = mock(ResourceWatcher.class); + + xdsClient.watchXdsResource( + XdsClusterResource.getInstance(), + timeoutResource, + timeoutWatcher, + fakeClock.getScheduledExecutorService()); + + assertThat(resourceDiscoveryCalls).hasSize(1); + DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); + call.verifyRequest(CDS, ImmutableList.of(timeoutResource), "", "", NODE); + fakeClock.forwardTime(XdsClientImpl.EXTENDED_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); + fakeClock.runDueTasks(); + @SuppressWarnings("unchecked") + ArgumentCaptor> statusOrCaptor = ArgumentCaptor.forClass(StatusOr.class); + verify(timeoutWatcher).onResourceChanged(statusOrCaptor.capture()); + StatusOr statusOr = statusOrCaptor.getValue(); + Status error = statusOr.getStatus(); + assertThat(error.getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(error.getDescription()).isEqualTo( + "Timed out waiting for resource " + timeoutResource + " from xDS server"); + BootstrapperImpl.xdsDataErrorHandlingEnabled = false; } private Answer blockUpdate(CyclicBarrier barrier) { @@ -2997,7 +3993,7 @@ public void simpleFlowControl() throws Exception { // Updated EDS response. Any updatedClusterLoadAssignment = Any.pack(mf.buildClusterLoadAssignment(EDS_RESOURCE, ImmutableList.of(mf.buildLocalityLbEndpoints("region2", "zone2", "subzone2", - mf.buildLbEndpoint("172.44.2.2", 8000, "unknown", 3), 2, 0)), + mf.buildLbEndpoint("172.44.2.2", 8000, "unknown", 3, "endpoint-host-name"), 2, 0)), ImmutableList.of())); call.sendResponse(EDS, updatedClusterLoadAssignment, VERSION_2, "0001"); // message not processed due to flow control @@ -3005,7 +4001,7 @@ public void simpleFlowControl() throws Exception { assertThat(call.isReady()).isFalse(); CyclicBarrier barrier = new CyclicBarrier(2); - doAnswer(blockUpdate(barrier)).when(edsResourceWatcher).onChanged(any(EdsUpdate.class)); + doAnswer(blockUpdate(barrier)).when(edsResourceWatcher).onResourceChanged(any()); CountDownLatch latch = new CountDownLatch(1); new Thread(() -> { @@ -3020,12 +4016,14 @@ public void simpleFlowControl() throws Exception { verifyResourceMetadataAcked(EDS, EDS_RESOURCE, testClusterLoadAssignment, VERSION_1, TIME_INCREMENT); barrier.await(); - verify(edsResourceWatcher, atLeastOnce()).onChanged(edsUpdateCaptor.capture()); - EdsUpdate edsUpdate = edsUpdateCaptor.getAllValues().get(0); + verify(edsResourceWatcher, atLeastOnce()).onResourceChanged(edsUpdateCaptor.capture()); + StatusOr statusOrUpdate = edsUpdateCaptor.getAllValues().get(0); + assertThat(statusOrUpdate.hasValue()).isTrue(); + EdsUpdate edsUpdate = statusOrUpdate.getValue(); validateGoldenClusterLoadAssignment(edsUpdate); barrier.await(); latch.await(10, TimeUnit.SECONDS); - verify(edsResourceWatcher, times(2)).onChanged(any()); + verify(edsResourceWatcher, times(2)).onResourceChanged(any()); verifyResourceMetadataAcked(EDS, EDS_RESOURCE, updatedClusterLoadAssignment, VERSION_2, TIME_INCREMENT * 2); } @@ -3037,7 +4035,7 @@ public void flowControlUnknownType() { call.sendResponse(CDS, testClusterRoundRobin, VERSION_1, "0000"); call.sendResponse(EDS, testClusterLoadAssignment, VERSION_1, "0000"); call.verifyRequest(EDS, EDS_RESOURCE, VERSION_1, "0000", NODE); - verify(edsResourceWatcher).onChanged(any()); + verify(edsResourceWatcher).onResourceChanged(any()); } @Test @@ -3049,8 +4047,10 @@ public void edsResourceUpdated() { // Initial EDS response. call.sendResponse(EDS, testClusterLoadAssignment, VERSION_1, "0000"); call.verifyRequest(EDS, EDS_RESOURCE, VERSION_1, "0000", NODE); - verify(edsResourceWatcher).onChanged(edsUpdateCaptor.capture()); - EdsUpdate edsUpdate = edsUpdateCaptor.getValue(); + verify(edsResourceWatcher).onResourceChanged(edsUpdateCaptor.capture()); + StatusOr statusOrUpdate = edsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + EdsUpdate edsUpdate = statusOrUpdate.getValue(); validateGoldenClusterLoadAssignment(edsUpdate); verifyResourceMetadataAcked(EDS, EDS_RESOURCE, testClusterLoadAssignment, VERSION_1, TIME_INCREMENT); @@ -3058,12 +4058,14 @@ public void edsResourceUpdated() { // Updated EDS response. Any updatedClusterLoadAssignment = Any.pack(mf.buildClusterLoadAssignment(EDS_RESOURCE, ImmutableList.of(mf.buildLocalityLbEndpoints("region2", "zone2", "subzone2", - mf.buildLbEndpoint("172.44.2.2", 8000, "unknown", 3), 2, 0)), + mf.buildLbEndpoint("172.44.2.2", 8000, "unknown", 3, "endpoint-host-name"), 2, 0)), ImmutableList.of())); call.sendResponse(EDS, updatedClusterLoadAssignment, VERSION_2, "0001"); - verify(edsResourceWatcher, times(2)).onChanged(edsUpdateCaptor.capture()); - edsUpdate = edsUpdateCaptor.getValue(); + verify(edsResourceWatcher, times(2)).onResourceChanged(edsUpdateCaptor.capture()); + statusOrUpdate = edsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + edsUpdate = statusOrUpdate.getValue(); assertThat(edsUpdate.clusterName).isEqualTo(EDS_RESOURCE); assertThat(edsUpdate.dropPolicies).isEmpty(); assertThat(edsUpdate.localityLbEndpointsMap) @@ -3071,7 +4073,9 @@ public void edsResourceUpdated() { Locality.create("region2", "zone2", "subzone2"), LocalityLbEndpoints.create( ImmutableList.of( - LbEndpoint.create("172.44.2.2", 8000, 3, true)), 2, 0)); + LbEndpoint.create("172.44.2.2", 8000, 3, + true, "endpoint-host-name", ImmutableMap.of())), + 2, 0, ImmutableMap.of())); verifyResourceMetadataAcked(EDS, EDS_RESOURCE, updatedClusterLoadAssignment, VERSION_2, TIME_INCREMENT * 2); verifySubscribedResourcesMetadataSizes(0, 0, 0, 1); @@ -3087,9 +4091,9 @@ public void edsDuplicateLocalityInTheSamePriority() { Any updatedClusterLoadAssignment = Any.pack(mf.buildClusterLoadAssignment(EDS_RESOURCE, ImmutableList.of( mf.buildLocalityLbEndpoints("region2", "zone2", "subzone2", - mf.buildLbEndpoint("172.44.2.2", 8000, "unknown", 3), 2, 1), + mf.buildLbEndpoint("172.44.2.2", 8000, "unknown", 3, "endpoint-host-name"), 2, 1), mf.buildLocalityLbEndpoints("region2", "zone2", "subzone2", - mf.buildLbEndpoint("172.44.2.3", 8080, "healthy", 10), 2, 1) + mf.buildLbEndpoint("172.44.2.3", 8080, "healthy", 10, "endpoint-host-name"), 2, 1) ), ImmutableList.of())); call.sendResponse(EDS, updatedClusterLoadAssignment, "0", "0001"); @@ -3099,6 +4103,12 @@ public void edsDuplicateLocalityInTheSamePriority() { + "locality:Locality{region=region2, zone=zone2, subZone=subzone2} for priority:1"; call.verifyRequestNack(EDS, EDS_RESOURCE, "", "0001", NODE, ImmutableList.of( errorMsg)); + @SuppressWarnings("unchecked") + ArgumentCaptor> captor = ArgumentCaptor.forClass(StatusOr.class); + verify(edsResourceWatcher).onResourceChanged(captor.capture()); + StatusOr statusOrUpdate = captor.getValue(); + assertThat(statusOrUpdate.hasValue()).isFalse(); + assertThat(statusOrUpdate.getStatus().getDescription()).contains(errorMsg); } @Test @@ -3107,10 +4117,10 @@ public void edsResourceDeletedByCds() { String resource = "backend-service.googleapis.com"; ResourceWatcher cdsWatcher = mock(ResourceWatcher.class); ResourceWatcher edsWatcher = mock(ResourceWatcher.class); - xdsClient.watchXdsResource(XdsClusterResource.getInstance(),resource, cdsWatcher); - xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),resource, edsWatcher); - xdsClient.watchXdsResource(XdsClusterResource.getInstance(),CDS_RESOURCE, cdsResourceWatcher); - xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),EDS_RESOURCE, edsResourceWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), resource, cdsWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), resource, edsWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), CDS_RESOURCE, cdsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), EDS_RESOURCE, edsResourceWatcher); verifyResourceMetadataRequested(CDS, CDS_RESOURCE); verifyResourceMetadataRequested(CDS, resource); verifyResourceMetadataRequested(EDS, EDS_RESOURCE); @@ -3125,12 +4135,13 @@ public void edsResourceDeletedByCds() { Any.pack(mf.buildEdsCluster(CDS_RESOURCE, EDS_RESOURCE, "round_robin", null, null, false, null, "envoy.transport_sockets.tls", null, null))); call.sendResponse(CDS, clusters, VERSION_1, "0000"); - verify(cdsWatcher).onChanged(cdsUpdateCaptor.capture()); - CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); + ArgumentCaptor> cdsUpdateCaptor = ArgumentCaptor.forClass(StatusOr.class); + verify(cdsWatcher, times(1)).onResourceChanged(cdsUpdateCaptor.capture()); + CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue().getValue(); assertThat(cdsUpdate.edsServiceName()).isEqualTo(null); assertThat(cdsUpdate.lrsServerInfo()).isEqualTo(xdsServerInfo); - verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); - cdsUpdate = cdsUpdateCaptor.getValue(); + verify(cdsResourceWatcher, times(1)).onResourceChanged(cdsUpdateCaptor.capture()); + cdsUpdate = cdsUpdateCaptor.getValue().getValue(); assertThat(cdsUpdate.edsServiceName()).isEqualTo(EDS_RESOURCE); assertThat(cdsUpdate.lrsServerInfo()).isNull(); verifyResourceMetadataAcked(CDS, resource, clusters.get(0), VERSION_1, TIME_INCREMENT); @@ -3150,13 +4161,15 @@ public void edsResourceDeletedByCds() { mf.buildClusterLoadAssignment(resource, ImmutableList.of( mf.buildLocalityLbEndpoints("region2", "zone2", "subzone2", - mf.buildLbEndpoint("192.168.0.2", 9090, "healthy", 3), 1, 0)), + mf.buildLbEndpoint("192.168.0.2", 9090, "healthy", 3, + "endpoint-host-name"), 1, 0)), ImmutableList.of(mf.buildDropOverload("lb", 100))))); call.sendResponse(EDS, clusterLoadAssignments, VERSION_1, "0000"); - verify(edsWatcher).onChanged(edsUpdateCaptor.capture()); - assertThat(edsUpdateCaptor.getValue().clusterName).isEqualTo(resource); - verify(edsResourceWatcher).onChanged(edsUpdateCaptor.capture()); - assertThat(edsUpdateCaptor.getValue().clusterName).isEqualTo(EDS_RESOURCE); + ArgumentCaptor> edsUpdateCaptor = ArgumentCaptor.forClass(StatusOr.class); + verify(edsWatcher, times(1)).onResourceChanged(edsUpdateCaptor.capture()); + assertThat(edsUpdateCaptor.getValue().getValue().clusterName).isEqualTo(resource); + verify(edsResourceWatcher, times(1)).onResourceChanged(edsUpdateCaptor.capture()); + assertThat(edsUpdateCaptor.getValue().getValue().clusterName).isEqualTo(EDS_RESOURCE); verifyResourceMetadataAcked( EDS, EDS_RESOURCE, clusterLoadAssignments.get(0), VERSION_1, TIME_INCREMENT * 2); @@ -3174,12 +4187,8 @@ public void edsResourceDeletedByCds() { "envoy.transport_sockets.tls", null, null ))); call.sendResponse(CDS, clusters, VERSION_2, "0001"); - verify(cdsResourceWatcher, times(2)).onChanged(cdsUpdateCaptor.capture()); - assertThat(cdsUpdateCaptor.getValue().edsServiceName()).isNull(); - // Note that the endpoint must be deleted even if the ignore_resource_deletion feature. - // This happens because the cluster CDS_RESOURCE is getting replaced, and not deleted. - verify(edsResourceWatcher, never()).onResourceDoesNotExist(EDS_RESOURCE); - verify(edsResourceWatcher, never()).onResourceDoesNotExist(resource); + verify(cdsResourceWatcher, times(2)).onResourceChanged(cdsUpdateCaptor.capture()); + assertThat(cdsUpdateCaptor.getValue().getValue().edsServiceName()).isNull(); verifyNoMoreInteractions(cdsWatcher, edsWatcher); verifyResourceMetadataAcked( EDS, EDS_RESOURCE, clusterLoadAssignments.get(0), VERSION_1, TIME_INCREMENT * 2); @@ -3187,7 +4196,6 @@ public void edsResourceDeletedByCds() { EDS, resource, clusterLoadAssignments.get(1), VERSION_1, TIME_INCREMENT * 2); // no change verifyResourceMetadataAcked(CDS, resource, clusters.get(0), VERSION_2, TIME_INCREMENT * 3); verifyResourceMetadataAcked(CDS, CDS_RESOURCE, clusters.get(1), VERSION_2, TIME_INCREMENT * 3); - verifySubscribedResourcesMetadataSizes(0, 2, 0, 2); } @Test @@ -3196,9 +4204,9 @@ public void multipleEdsWatchers() { String edsResourceTwo = "cluster-load-assignment-bar.googleapis.com"; ResourceWatcher watcher1 = mock(ResourceWatcher.class); ResourceWatcher watcher2 = mock(ResourceWatcher.class); - xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),EDS_RESOURCE, edsResourceWatcher); - xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),edsResourceTwo, watcher1); - xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),edsResourceTwo, watcher2); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), EDS_RESOURCE, edsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), edsResourceTwo, watcher1); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), edsResourceTwo, watcher2); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); call.verifyRequest(EDS, Arrays.asList(EDS_RESOURCE, edsResourceTwo), "", "", NODE); verifyResourceMetadataRequested(EDS, EDS_RESOURCE); @@ -3206,16 +4214,24 @@ public void multipleEdsWatchers() { verifySubscribedResourcesMetadataSizes(0, 0, 0, 2); fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); - verify(edsResourceWatcher).onResourceDoesNotExist(EDS_RESOURCE); - verify(watcher1).onResourceDoesNotExist(edsResourceTwo); - verify(watcher2).onResourceDoesNotExist(edsResourceTwo); + verify(edsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getDescription().contains(EDS_RESOURCE))); + verify(watcher1).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getDescription().contains(edsResourceTwo))); + verify(watcher2).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getDescription().contains(edsResourceTwo))); verifyResourceMetadataDoesNotExist(EDS, EDS_RESOURCE); verifyResourceMetadataDoesNotExist(EDS, edsResourceTwo); verifySubscribedResourcesMetadataSizes(0, 0, 0, 2); call.sendResponse(EDS, testClusterLoadAssignment, VERSION_1, "0000"); - verify(edsResourceWatcher).onChanged(edsUpdateCaptor.capture()); - EdsUpdate edsUpdate = edsUpdateCaptor.getValue(); + verify(edsResourceWatcher, times(2)).onResourceChanged(edsUpdateCaptor.capture()); + StatusOr statusOrUpdate = edsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + EdsUpdate edsUpdate = statusOrUpdate.getValue(); validateGoldenClusterLoadAssignment(edsUpdate); verifyNoMoreInteractions(watcher1, watcher2); verifyResourceMetadataAcked( @@ -3227,12 +4243,15 @@ public void multipleEdsWatchers() { mf.buildClusterLoadAssignment(edsResourceTwo, ImmutableList.of( mf.buildLocalityLbEndpoints("region2", "zone2", "subzone2", - mf.buildLbEndpoint("172.44.2.2", 8000, "healthy", 3), 2, 0)), + mf.buildLbEndpoint("172.44.2.2", 8000, "healthy", 3, "endpoint-host-name"), + 2, 0)), ImmutableList.of())); call.sendResponse(EDS, clusterLoadAssignmentTwo, VERSION_2, "0001"); - verify(watcher1).onChanged(edsUpdateCaptor.capture()); - edsUpdate = edsUpdateCaptor.getValue(); + verify(watcher1, times(2)).onResourceChanged(edsUpdateCaptor.capture()); + statusOrUpdate = edsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + edsUpdate = statusOrUpdate.getValue(); assertThat(edsUpdate.clusterName).isEqualTo(edsResourceTwo); assertThat(edsUpdate.dropPolicies).isEmpty(); assertThat(edsUpdate.localityLbEndpointsMap) @@ -3240,9 +4259,13 @@ public void multipleEdsWatchers() { Locality.create("region2", "zone2", "subzone2"), LocalityLbEndpoints.create( ImmutableList.of( - LbEndpoint.create("172.44.2.2", 8000, 3, true)), 2, 0)); - verify(watcher2).onChanged(edsUpdateCaptor.capture()); - edsUpdate = edsUpdateCaptor.getValue(); + LbEndpoint.create("172.44.2.2", 8000, 3, + true, "endpoint-host-name", ImmutableMap.of())), + 2, 0, ImmutableMap.of())); + verify(watcher2, times(2)).onResourceChanged(edsUpdateCaptor.capture()); + statusOrUpdate = edsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + edsUpdate = statusOrUpdate.getValue(); assertThat(edsUpdate.clusterName).isEqualTo(edsResourceTwo); assertThat(edsUpdate.dropPolicies).isEmpty(); assertThat(edsUpdate.localityLbEndpointsMap) @@ -3250,7 +4273,9 @@ public void multipleEdsWatchers() { Locality.create("region2", "zone2", "subzone2"), LocalityLbEndpoints.create( ImmutableList.of( - LbEndpoint.create("172.44.2.2", 8000, 3, true)), 2, 0)); + LbEndpoint.create("172.44.2.2", 8000, 3, + true, "endpoint-host-name", ImmutableMap.of())), + 2, 0, ImmutableMap.of())); verifyNoMoreInteractions(edsResourceWatcher); verifyResourceMetadataAcked( EDS, edsResourceTwo, clusterLoadAssignmentTwo, VERSION_2, TIME_INCREMENT * 2); @@ -3271,43 +4296,113 @@ public void useIndependentRpcContext() { // The inbound RPC finishes and closes its context. The outbound RPC's control plane RPC // should not be impacted. cancellableContext.close(); - verify(ldsResourceWatcher, never()).onError(any(Status.class)); + verify(ldsResourceWatcher, never()).onAmbientError(any(Status.class)); + verify(ldsResourceWatcher, never()).onResourceChanged(argThat( + statusOr -> !statusOr.hasValue() + )); call.sendResponse(LDS, testListenerRds, VERSION_1, "0000"); - verify(ldsResourceWatcher).onChanged(any(LdsUpdate.class)); + verify(ldsResourceWatcher).onResourceChanged(any()); } finally { cancellableContext.detach(prevContext); } } + @Test + public void streamClosedWithNoResponse() { + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), RDS_RESOURCE, + rdsResourceWatcher); + DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); + // Check metric data. + callback_ReportServerConnection(); + verifyServerConnection(1, true, xdsServerInfo.target()); + // Management server closes the RPC stream before sending any response. + call.sendCompleted(); + // Check metric data. + callback_ReportServerConnection(); + verifyServerConnection(1, false, xdsServerInfo.target()); + verify(ldsResourceWatcher, Mockito.timeout(1000)).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr ldsStatusOr = ldsUpdateCaptor.getValue(); + assertThat(ldsStatusOr.hasValue()).isFalse(); + verifyStatusWithNodeId(ldsStatusOr.getStatus(), Code.UNAVAILABLE, + "ADS stream closed with OK before receiving a response"); + verify(rdsResourceWatcher, Mockito.timeout(1000)).onResourceChanged(rdsUpdateCaptor.capture()); + StatusOr rdsStatusOr = rdsUpdateCaptor.getValue(); + assertThat(rdsStatusOr.hasValue()).isFalse(); + verifyStatusWithNodeId(rdsStatusOr.getStatus(), Code.UNAVAILABLE, + "ADS stream closed with OK before receiving a response"); + } + + @Test + public void streamClosedAfterSendingResponses() { + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), RDS_RESOURCE, + rdsResourceWatcher); + DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); + // Check metric data. + callback_ReportServerConnection(); + verifyServerConnection(1, true, xdsServerInfo.target()); + ScheduledTask ldsResourceTimeout = + Iterables.getOnlyElement(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)); + ScheduledTask rdsResourceTimeout = + Iterables.getOnlyElement(fakeClock.getPendingTasks(RDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)); + call.sendResponse(LDS, testListenerRds, VERSION_1, "0000"); + // Check metric data. + callback_ReportServerConnection(); + verifyServerConnection(2, true, xdsServerInfo.target()); + assertThat(ldsResourceTimeout.isCancelled()).isTrue(); + call.sendResponse(RDS, testRouteConfig, VERSION_1, "0000"); + assertThat(rdsResourceTimeout.isCancelled()).isTrue(); + // Management server closes the RPC stream after sending responses. + call.sendCompleted(); + // Check metric data. + callback_ReportServerConnection(); + verifyServerConnection(3, true, xdsServerInfo.target()); + verify(ldsResourceWatcher, never()).onAmbientError(any(Status.class)); + verify(rdsResourceWatcher, never()).onAmbientError(any(Status.class)); + verify(ldsResourceWatcher, times(1)).onResourceChanged(any()); + verify(rdsResourceWatcher, times(1)).onResourceChanged(any()); + } + @Test public void streamClosedAndRetryWithBackoff() { - InOrder inOrder = Mockito.inOrder(backoffPolicyProvider, backoffPolicy1, backoffPolicy2); - xdsClient.watchXdsResource(XdsListenerResource.getInstance(),LDS_RESOURCE, ldsResourceWatcher); - xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(),RDS_RESOURCE, + InOrder inOrder = inOrder(backoffPolicyProvider, backoffPolicy1, backoffPolicy2); + InOrder ldsWatcherInOrder = inOrder(ldsResourceWatcher); + InOrder rdsWatcherInOrder = inOrder(rdsResourceWatcher); + InOrder cdsWatcherInOrder = inOrder(cdsResourceWatcher); + InOrder edsWatcherInOrder = inOrder(edsResourceWatcher); + when(backoffPolicyProvider.get()).thenReturn(backoffPolicy1, backoffPolicy2, backoffPolicy2); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), RDS_RESOURCE, rdsResourceWatcher); - xdsClient.watchXdsResource(XdsClusterResource.getInstance(),CDS_RESOURCE, cdsResourceWatcher); - xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),EDS_RESOURCE, edsResourceWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), CDS_RESOURCE, cdsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), EDS_RESOURCE, edsResourceWatcher); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); call.verifyRequest(LDS, LDS_RESOURCE, "", "", NODE); call.verifyRequest(RDS, RDS_RESOURCE, "", "", NODE); call.verifyRequest(CDS, CDS_RESOURCE, "", "", NODE); call.verifyRequest(EDS, EDS_RESOURCE, "", "", NODE); - // Management server closes the RPC stream with an error. + // Management server closes the RPC stream with an error. No response received yet. + fakeClock.forwardNanos(1000L); // Make sure retry isn't based on stopwatch 0 call.sendError(Status.UNKNOWN.asException()); - verify(ldsResourceWatcher, Mockito.timeout(1000).times(1)) - .onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNKNOWN, ""); - verify(rdsResourceWatcher).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNKNOWN, ""); - verify(cdsResourceWatcher).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNKNOWN, ""); - verify(edsResourceWatcher).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNKNOWN, ""); + ldsWatcherInOrder.verify(ldsResourceWatcher, timeout(1000)).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNKNOWN)); + rdsWatcherInOrder.verify(rdsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNKNOWN)); + cdsWatcherInOrder.verify(cdsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNKNOWN)); + edsWatcherInOrder.verify(edsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNKNOWN)); + + verifyServerFailureCount(1, 1, xdsServerInfo.target()); // Retry after backoff. - inOrder.verify(backoffPolicyProvider).get(); inOrder.verify(backoffPolicy1).nextBackoffNanos(); ScheduledTask retryTask = Iterables.getOnlyElement(fakeClock.getPendingTasks(RPC_RETRY_TASK_FILTER)); @@ -3319,17 +4414,23 @@ public void streamClosedAndRetryWithBackoff() { call.verifyRequest(CDS, CDS_RESOURCE, "", "", NODE); call.verifyRequest(EDS, EDS_RESOURCE, "", "", NODE); - // Management server becomes unreachable. + // Management server becomes unreachable. No response received on this stream either. String errorMsg = "my fault"; call.sendError(Status.UNAVAILABLE.withDescription(errorMsg).asException()); - verify(ldsResourceWatcher, times(2)).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, errorMsg); - verify(rdsResourceWatcher, times(2)).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, errorMsg); - verify(cdsResourceWatcher, times(2)).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, errorMsg); - verify(edsResourceWatcher, times(2)).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, errorMsg); + ldsWatcherInOrder.verify(ldsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNAVAILABLE)); + rdsWatcherInOrder.verify(rdsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNAVAILABLE)); + cdsWatcherInOrder.verify(cdsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNAVAILABLE)); + edsWatcherInOrder.verify(edsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNAVAILABLE)); + + verifyServerFailureCount(2, 1, xdsServerInfo.target()); // Retry after backoff. inOrder.verify(backoffPolicy1).nextBackoffNanos(); @@ -3348,41 +4449,49 @@ public void streamClosedAndRetryWithBackoff() { mf.buildRouteConfiguration("do not care", mf.buildOpaqueVirtualHosts(2))))); call.sendResponse(LDS, listeners, "63", "3242"); call.verifyRequest(LDS, LDS_RESOURCE, "63", "3242", NODE); + ldsWatcherInOrder.verify(ldsResourceWatcher).onResourceChanged( + argThat(statusOr -> statusOr.hasValue())); List routeConfigs = ImmutableList.of( Any.pack(mf.buildRouteConfiguration(RDS_RESOURCE, mf.buildOpaqueVirtualHosts(2)))); call.sendResponse(RDS, routeConfigs, "5", "6764"); call.verifyRequest(RDS, RDS_RESOURCE, "5", "6764", NODE); + rdsWatcherInOrder.verify(rdsResourceWatcher).onResourceChanged( + argThat(statusOr -> statusOr.hasValue())); + // Stream fails AFTER a response. Error is suppressed and no watcher notification occurs. call.sendError(Status.DEADLINE_EXCEEDED.asException()); - verify(ldsResourceWatcher, times(2)).onError(errorCaptor.capture()); - verify(rdsResourceWatcher, times(2)).onError(errorCaptor.capture()); - verify(cdsResourceWatcher, times(3)).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.DEADLINE_EXCEEDED, ""); - verify(edsResourceWatcher, times(3)).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.DEADLINE_EXCEEDED, ""); + + // Failure count does NOT increase. + verifyServerFailureCount(2, 1, xdsServerInfo.target()); // Reset backoff sequence and retry after backoff. inOrder.verify(backoffPolicyProvider).get(); inOrder.verify(backoffPolicy2).nextBackoffNanos(); retryTask = Iterables.getOnlyElement(fakeClock.getPendingTasks(RPC_RETRY_TASK_FILTER)); - assertThat(retryTask.getDelay(TimeUnit.NANOSECONDS)).isEqualTo(20L); - fakeClock.forwardNanos(20L); + fakeClock.forwardNanos(retryTask.getDelay(TimeUnit.NANOSECONDS)); call = resourceDiscoveryCalls.poll(); call.verifyRequest(LDS, LDS_RESOURCE, "63", "", NODE); call.verifyRequest(RDS, RDS_RESOURCE, "5", "", NODE); call.verifyRequest(CDS, CDS_RESOURCE, "", "", NODE); call.verifyRequest(EDS, EDS_RESOURCE, "", "", NODE); - // Management server becomes unreachable again. + // Management server becomes unreachable again. This is on a new stream, so error propagates. call.sendError(Status.UNAVAILABLE.asException()); - verify(ldsResourceWatcher, times(2)).onError(errorCaptor.capture()); - verify(rdsResourceWatcher, times(2)).onError(errorCaptor.capture()); - verify(cdsResourceWatcher, times(4)).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, ""); - verify(edsResourceWatcher, times(4)).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, ""); + ldsWatcherInOrder.verify(ldsResourceWatcher).onAmbientError( + argThat(status -> status.getCode() == Code.UNAVAILABLE)); + rdsWatcherInOrder.verify(rdsResourceWatcher).onAmbientError( + argThat(status -> status.getCode() == Code.UNAVAILABLE)); + cdsWatcherInOrder.verify(cdsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNAVAILABLE)); + edsWatcherInOrder.verify(edsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNAVAILABLE)); + + // Failure count is now 3. + verifyServerFailureCount(3, 1, xdsServerInfo.target()); // Retry after backoff. inOrder.verify(backoffPolicy2).nextBackoffNanos(); @@ -3396,7 +4505,41 @@ public void streamClosedAndRetryWithBackoff() { call.verifyRequest(CDS, CDS_RESOURCE, "", "", NODE); call.verifyRequest(EDS, EDS_RESOURCE, "", "", NODE); - inOrder.verifyNoMoreInteractions(); + // Send a response so CPC is considered working and close gracefully. + call.sendResponse(LDS, listeners, "63", "3242"); + call.sendCompleted(); + + // Final failure count is still 3. + verifyServerFailureCount(3, 1, xdsServerInfo.target()); + } + + @Test + public void newWatcher_receivesCachedDataAndAmbientError() throws Exception { + InOrder inOrder = inOrder(ldsResourceWatcher); + DiscoveryRpcCall call1 = startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, + ldsResourceWatcher); + call1.sendResponse(LDS, testListenerRds, VERSION_1, "0000"); + inOrder.verify(ldsResourceWatcher, timeout(5000)) + .onResourceChanged(argThat(statusOr -> statusOr.hasValue())); + + call1.sendError(Status.DEADLINE_EXCEEDED.asException()); + ScheduledTask retryTask = + Iterables.getOnlyElement(fakeClock.getPendingTasks(RPC_RETRY_TASK_FILTER)); + fakeClock.forwardNanos(retryTask.getDelay(TimeUnit.NANOSECONDS)); + DiscoveryRpcCall call2 = resourceDiscoveryCalls.poll(); + Status propagatedError = Status.UNAVAILABLE.withDescription("real failure"); + call2.sendError(propagatedError.asException()); + inOrder.verify(ldsResourceWatcher, timeout(5000)).onAmbientError( + argThat(status -> status.getCode() == Code.UNAVAILABLE)); + @SuppressWarnings("unchecked") + ResourceWatcher ldsResourceWatcher2 = mock(ResourceWatcher.class); + xdsClient.watchXdsResource( + XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher2); + + verify(ldsResourceWatcher2, timeout(5000)).onResourceChanged( + argThat(statusOr -> statusOr.hasValue())); + verify(ldsResourceWatcher2, timeout(5000)).onAmbientError( + argThat(status -> status.getCode() == Code.UNAVAILABLE)); } @Test @@ -3406,16 +4549,23 @@ public void streamClosedAndRetryRaceWithAddRemoveWatchers() { xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), RDS_RESOURCE, rdsResourceWatcher); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); + // Check metric data. + callback_ReportServerConnection(); + verifyServerConnection(1, true, xdsServerInfo.target()); call.sendError(Status.UNAVAILABLE.asException()); verify(ldsResourceWatcher, Mockito.timeout(1000).times(1)) - .onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, ""); - verify(rdsResourceWatcher).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, ""); + .onResourceChanged(ldsUpdateCaptor.capture()); + verifyStatusWithNodeId(ldsUpdateCaptor.getValue().getStatus(), Code.UNAVAILABLE, ""); + verify(rdsResourceWatcher).onResourceChanged(rdsUpdateCaptor.capture()); + verifyStatusWithNodeId(rdsUpdateCaptor.getValue().getStatus(), Code.UNAVAILABLE, ""); ScheduledTask retryTask = Iterables.getOnlyElement(fakeClock.getPendingTasks(RPC_RETRY_TASK_FILTER)); assertThat(retryTask.getDelay(TimeUnit.NANOSECONDS)).isEqualTo(10L); + // Check metric data. + callback_ReportServerConnection(); + verifyServerConnection(1, false, xdsServerInfo.target()); + xdsClient.cancelXdsResourceWatch(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); xdsClient.cancelXdsResourceWatch(XdsRouteConfigureResource.getInstance(), @@ -3430,11 +4580,19 @@ public void streamClosedAndRetryRaceWithAddRemoveWatchers() { call.verifyRequest(EDS, EDS_RESOURCE, "", "", NODE); call.verifyNoMoreRequest(); + // Check metric data. + callback_ReportServerConnection(); + verifyServerConnection(2,false, xdsServerInfo.target()); + call.sendResponse(LDS, testListenerRds, VERSION_1, "0000"); List routeConfigs = ImmutableList.of( Any.pack(mf.buildRouteConfiguration(RDS_RESOURCE, mf.buildOpaqueVirtualHosts(VHOST_SIZE)))); call.sendResponse(RDS, routeConfigs, VERSION_1, "0000"); + // Check metric data. + callback_ReportServerConnection(); + verifyServerConnection(2, true, xdsServerInfo.target()); + verifyNoMoreInteractions(ldsResourceWatcher, rdsResourceWatcher); } @@ -3446,6 +4604,9 @@ public void streamClosedAndRetryRestartsResourceInitialFetchTimerForUnresolvedRe xdsClient.watchXdsResource(XdsClusterResource.getInstance(), CDS_RESOURCE, cdsResourceWatcher); xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), EDS_RESOURCE, edsResourceWatcher); DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); + // Check metric data. + callback_ReportServerConnection(); + verifyServerConnection(1, true, xdsServerInfo.target()); ScheduledTask ldsResourceTimeout = Iterables.getOnlyElement(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)); ScheduledTask rdsResourceTimeout = @@ -3456,19 +4617,46 @@ public void streamClosedAndRetryRestartsResourceInitialFetchTimerForUnresolvedRe Iterables.getOnlyElement(fakeClock.getPendingTasks(EDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)); call.sendResponse(LDS, testListenerRds, VERSION_1, "0000"); assertThat(ldsResourceTimeout.isCancelled()).isTrue(); + // Check metric data. + callback_ReportServerConnection(); + verifyServerConnection(2, true, xdsServerInfo.target()); call.sendResponse(RDS, testRouteConfig, VERSION_1, "0000"); assertThat(rdsResourceTimeout.isCancelled()).isTrue(); + // Check metric data. + callback_ReportServerConnection(); + verifyServerConnection(3, true, xdsServerInfo.target()); call.sendError(Status.UNAVAILABLE.asException()); assertThat(cdsResourceTimeout.isCancelled()).isTrue(); assertThat(edsResourceTimeout.isCancelled()).isTrue(); - verify(ldsResourceWatcher, never()).onError(errorCaptor.capture()); - verify(rdsResourceWatcher, never()).onError(errorCaptor.capture()); - verify(cdsResourceWatcher).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, ""); - verify(edsResourceWatcher).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, ""); + + // With the reverted logic, the first error is suppressed because a response was received. + // We verify that no error callbacks are invoked at this point. + verify(ldsResourceWatcher, never()).onAmbientError(any(Status.class)); + verify(rdsResourceWatcher, never()).onAmbientError(any(Status.class)); + + // The metric report for a failed server connection is also suppressed. + callback_ReportServerConnection(); + verifyServerConnection(4, true, xdsServerInfo.target()); + + fakeClock.forwardTime(5, TimeUnit.SECONDS); + DiscoveryRpcCall call2 = resourceDiscoveryCalls.poll(); + call2.sendError(Status.UNAVAILABLE.asException()); + + // Now, verify the watchers are notified as expected. + verify(ldsResourceWatcher).onAmbientError(any(Status.class)); + verify(rdsResourceWatcher).onAmbientError(any(Status.class)); + verify(cdsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNAVAILABLE)); + verify(edsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNAVAILABLE)); + + fakeClock.forwardTime(5, TimeUnit.SECONDS); + DiscoveryRpcCall call3 = resourceDiscoveryCalls.poll(); + assertThat(call3).isNotNull(); fakeClock.forwardNanos(10L); assertThat(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).hasSize(0); @@ -3487,13 +4675,13 @@ public void reportLoadStatsToServer() { lrsCall.sendResponse(Collections.singletonList(clusterName), 1000L); fakeClock.forwardNanos(1000L); - lrsCall.verifyNextReportClusters(Collections.singletonList(new String[] {clusterName, null})); + lrsCall.verifyNextReportClusters(Collections.singletonList(new String[]{clusterName, null})); dropStats.release(); fakeClock.forwardNanos(1000L); // In case of having unreported cluster stats, one last report will be sent after corresponding // stats object released. - lrsCall.verifyNextReportClusters(Collections.singletonList(new String[] {clusterName, null})); + lrsCall.verifyNextReportClusters(Collections.singletonList(new String[]{clusterName, null})); fakeClock.forwardNanos(1000L); // Currently load reporting continues (with empty stats) even if all stats objects have been @@ -3522,8 +4710,10 @@ public void serverSideListenerFound() { call.sendResponse(LDS, listeners, "0", "0000"); // Client sends an ACK LDS request. call.verifyRequest(LDS, Collections.singletonList(LISTENER_RESOURCE), "0", "0000", NODE); - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - EnvoyServerProtoData.Listener parsedListener = ldsUpdateCaptor.getValue().listener(); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + EnvoyServerProtoData.Listener parsedListener = statusOrUpdate.getValue().listener(); assertThat(parsedListener.name()).isEqualTo(LISTENER_RESOURCE); assertThat(parsedListener.address()).isEqualTo("0.0.0.0:7000"); assertThat(parsedListener.defaultFilterChain()).isNull(); @@ -3560,25 +4750,26 @@ public void serverSideListenerNotFound() { verifyNoInteractions(ldsResourceWatcher); fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); - verify(ldsResourceWatcher).onResourceDoesNotExist(LISTENER_RESOURCE); + verify(ldsResourceWatcher).onResourceChanged(argThat( + statusOr -> statusOr.getStatus().getDescription().contains(LISTENER_RESOURCE))); assertThat(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); } @Test public void serverSideListenerResponseErrorHandling_badDownstreamTlsContext() { GrpcXdsClientImplTestBase.DiscoveryRpcCall call = - startResourceWatcher(XdsListenerResource.getInstance(), LISTENER_RESOURCE, - ldsResourceWatcher); + startResourceWatcher(XdsListenerResource.getInstance(), LISTENER_RESOURCE, + ldsResourceWatcher); Message hcmFilter = mf.buildHttpConnectionManagerFilter( - "route-foo.googleapis.com", null, + "route-foo.googleapis.com", null, Collections.singletonList(mf.buildTerminalFilter())); Message downstreamTlsContext = CommonTlsContextTestsUtil.buildTestDownstreamTlsContext( - null, null,false); + null, null, false); Message filterChain = mf.buildFilterChain( - Collections.emptyList(), downstreamTlsContext, "envoy.transport_sockets.tls", + Collections.emptyList(), downstreamTlsContext, "envoy.transport_sockets.tls", hcmFilter); Message listener = - mf.buildListenerWithFilterChain(LISTENER_RESOURCE, 7000, "0.0.0.0", filterChain); + mf.buildListenerWithFilterChain(LISTENER_RESOURCE, 7000, "0.0.0.0", filterChain); List listeners = ImmutableList.of(Any.pack(listener)); call.sendResponse(LDS, listeners, "0", "0000"); // The response NACKed with errors indicating indices of the failed resources. @@ -3586,8 +4777,10 @@ public void serverSideListenerResponseErrorHandling_badDownstreamTlsContext() { + "0.0.0.0:7000\' validation error: " + "common-tls-context is required in downstream-tls-context"; call.verifyRequestNack(LDS, LISTENER_RESOURCE, "", "0000", NODE, ImmutableList.of(errorMsg)); - verify(ldsResourceWatcher).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, errorMsg); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isFalse(); + verifyStatusWithNodeId(statusOrUpdate.getStatus(), Code.UNAVAILABLE, errorMsg); } @Test @@ -3599,7 +4792,7 @@ public void serverSideListenerResponseErrorHandling_badTransportSocketName() { "route-foo.googleapis.com", null, Collections.singletonList(mf.buildTerminalFilter())); Message downstreamTlsContext = CommonTlsContextTestsUtil.buildTestDownstreamTlsContext( - "cert1", "cert2",false); + "cert1", "cert2", false); Message filterChain = mf.buildFilterChain( Collections.emptyList(), downstreamTlsContext, "envoy.transport_sockets.bad1", hcmFilter); @@ -3613,8 +4806,8 @@ public void serverSideListenerResponseErrorHandling_badTransportSocketName() { + "transport-socket with name envoy.transport_sockets.bad1 not supported."; call.verifyRequestNack(LDS, LISTENER_RESOURCE, "", "0000", NODE, ImmutableList.of( errorMsg)); - verify(ldsResourceWatcher).onError(errorCaptor.capture()); - verifyStatusWithNodeId(errorCaptor.getValue(), Code.UNAVAILABLE, errorMsg); + verify(ldsResourceWatcher).onResourceChanged(ldsUpdateCaptor.capture()); + verifyStatusWithNodeId(ldsUpdateCaptor.getValue().getStatus(), Code.UNAVAILABLE, errorMsg); } @Test @@ -3630,6 +4823,9 @@ public void sendingToStoppedServer() throws Exception { xdsClient.watchXdsResource(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); fakeClock.forwardTime(14, TimeUnit.SECONDS); + // Check metric data. + callback_ReportServerConnection(); + verifyServerConnection(1, false, xdsServerInfo.target()); // Restart the server xdsServer = cleanupRule.register( @@ -3641,13 +4837,21 @@ public void sendingToStoppedServer() throws Exception { .build() .start()); fakeClock.forwardTime(5, TimeUnit.SECONDS); - verify(ldsResourceWatcher, never()).onResourceDoesNotExist(LDS_RESOURCE); + verify(ldsResourceWatcher, never()).onResourceChanged(argThat( + statusOr -> statusOr.getStatus().getDescription().contains(LDS_RESOURCE))); fakeClock.forwardTime(20, TimeUnit.SECONDS); // Trigger rpcRetryTimer DiscoveryRpcCall call = resourceDiscoveryCalls.poll(3, TimeUnit.SECONDS); + // Check metric data. + callback_ReportServerConnection(); if (call == null) { // The first rpcRetry may have happened before the channel was ready fakeClock.forwardTime(50, TimeUnit.SECONDS); call = resourceDiscoveryCalls.poll(3, TimeUnit.SECONDS); } + verifyServerConnection(2, false, xdsServerInfo.target()); + + // Check metric data. + callback_ReportServerConnection(); + verifyServerConnection(3, false, xdsServerInfo.target()); // NOTE: There is a ScheduledExecutorService that may get involved due to the reconnect // so you cannot rely on the logic being single threaded. The timeout() in verifyRequest @@ -3655,11 +4859,18 @@ public void sendingToStoppedServer() throws Exception { // Send a response and do verifications call.sendResponse(LDS, mf.buildWrappedResource(testListenerVhosts), VERSION_1, "0001"); call.verifyRequest(LDS, LDS_RESOURCE, VERSION_1, "0001", NODE); - verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); - verifyGoldenListenerVhosts(ldsUpdateCaptor.getValue()); + @SuppressWarnings("unchecked") + ArgumentCaptor> captor = ArgumentCaptor.forClass(StatusOr.class); + verify(ldsResourceWatcher, timeout(1000).atLeast(2)).onResourceChanged(captor.capture()); + StatusOr lastValue = captor.getAllValues().get(captor.getAllValues().size() - 1); + assertThat(lastValue.hasValue()).isTrue(); + verifyGoldenListenerVhosts(lastValue.getValue()); assertThat(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerVhosts, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(1, 1, 0, 0); + // Check metric data. + callback_ReportServerConnection(); + verifyServerConnection(1, true, xdsServerInfo.target()); } catch (Throwable t) { throw t; // This allows putting a breakpoint here for debugging } @@ -3668,14 +4879,38 @@ public void sendingToStoppedServer() throws Exception { @Test public void sendToBadUrl() throws Exception { // Setup xdsClient to fail on stream creation - XdsClientImpl client = createXdsClient("some. garbage"); + String garbageUri = "some. garbage"; + XdsClientImpl client = createXdsClient(garbageUri); client.watchXdsResource(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); fakeClock.forwardTime(20, TimeUnit.SECONDS); - verify(ldsResourceWatcher, Mockito.timeout(5000).times(1)).onError(ArgumentMatchers.any()); + verify(ldsResourceWatcher, Mockito.timeout(5000).atLeastOnce()) + .onResourceChanged(ldsUpdateCaptor.capture()); + assertThat(ldsUpdateCaptor.getValue().getStatus().getDescription()).contains(garbageUri); client.shutdown(); } + @Test + public void circuitBreakingConversionOf32bitIntTo64bitLongForMaxRequestNegativeValue() { + DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, + cdsResourceWatcher); + Any clusterCircuitBreakers = Any.pack( + mf.buildEdsCluster(CDS_RESOURCE, null, "round_robin", null, null, false, null, + "envoy.transport_sockets.tls", mf.buildCircuitBreakers(50, -1), null)); + call.sendResponse(CDS, clusterCircuitBreakers, VERSION_1, "0000"); + + // Client sent an ACK CDS request. + call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); + verify(cdsResourceWatcher).onResourceChanged(cdsUpdateCaptor.capture()); + StatusOr statusOrUpdate = cdsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isTrue(); + CdsUpdate cdsUpdate = statusOrUpdate.getValue(); + + assertThat(cdsUpdate.clusterName()).isEqualTo(CDS_RESOURCE); + assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.EDS); + assertThat(cdsUpdate.maxConcurrentRequests()).isEqualTo(4294967295L); + } + @Test public void sendToNonexistentServer() throws Exception { // Setup xdsClient to fail on stream creation @@ -3684,30 +4919,228 @@ public void sendToNonexistentServer() throws Exception { // file. Assume localhost doesn't speak HTTP/2 on the finger port XdsClientImpl client = createXdsClient("localhost:79"); client.watchXdsResource(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); - verify(ldsResourceWatcher, Mockito.timeout(5000).times(1)).onError(ArgumentMatchers.any()); + verify(ldsResourceWatcher, Mockito.timeout(5000)).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOrUpdate = ldsUpdateCaptor.getValue(); + assertThat(statusOrUpdate.hasValue()).isFalse(); + assertThat(statusOrUpdate.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); assertThat(fakeClock.numPendingTasks()).isEqualTo(1); //retry assertThat(fakeClock.getPendingTasks().iterator().next().toString().contains("RpcRetryTask")) .isTrue(); client.shutdown(); } + @Test + public void validAndInvalidResourceMetricReport() { + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), "A", cdsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), "A.1", edsResourceWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), "B", cdsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), "B.1", edsResourceWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), "C", cdsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), "C.1", edsResourceWatcher); + DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); + assertThat(call).isNotNull(); + + // CDS -> {A, B, C}, version 1 + ImmutableMap resourcesV1 = ImmutableMap.of( + "A", Any.pack(mf.buildEdsCluster("A", "A.1", "round_robin", null, null, false, null, + "envoy.transport_sockets.tls", null, null + )), + "B", Any.pack(mf.buildEdsCluster("B", "B.1", "round_robin", null, null, false, null, + "envoy.transport_sockets.tls", null, null + )), + "C", Any.pack(mf.buildEdsCluster("C", "C.1", "round_robin", null, null, false, null, + "envoy.transport_sockets.tls", null, null + ))); + call.sendResponse(CDS, resourcesV1.values().asList(), VERSION_1, "0000"); + // {A, B, C} -> ACK, version 1 + verifyResourceValidInvalidCount(1, 3, 0, xdsServerInfo.target(), CDS.typeUrl()); + + // EDS -> {A.1, B.1, C.1}, version 1 + List dropOverloads = ImmutableList.of(); + List endpointsV1 = ImmutableList.of(lbEndpointHealthy); + ImmutableMap resourcesV11 = ImmutableMap.of( + "A.1", Any.pack(mf.buildClusterLoadAssignment("A.1", endpointsV1, dropOverloads)), + "B.1", Any.pack(mf.buildClusterLoadAssignment("B.1", endpointsV1, dropOverloads)), + "C.1", Any.pack(mf.buildClusterLoadAssignment("C.1", endpointsV1, dropOverloads))); + call.sendResponse(EDS, resourcesV11.values().asList(), VERSION_1, "0000"); + // {A.1, B.1, C.1} -> ACK, version 1 + verifyResourceValidInvalidCount(1, 3, 0, xdsServerInfo.target(), EDS.typeUrl()); + + // CDS -> {A, B}, version 2 + // Failed to parse endpoint B + ImmutableMap resourcesV2 = ImmutableMap.of( + "A", Any.pack(mf.buildEdsCluster("A", "A.2", "round_robin", null, null, false, null, + "envoy.transport_sockets.tls", null, null + )), + "B", Any.pack(mf.buildClusterInvalid("B"))); + call.sendResponse(CDS, resourcesV2.values().asList(), VERSION_2, "0001"); + // {A} -> ACK, version 2 + // {B} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B + // {C} -> does not exist + verifyResourceValidInvalidCount(1, 1, 1, xdsServerInfo.target(), CDS.typeUrl()); + } + + @Test + public void serverFailureMetricReport() { + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), RDS_RESOURCE, + rdsResourceWatcher); + DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); + // Management server closes the RPC stream before sending any response. + call.sendCompleted(); + verify(ldsResourceWatcher, Mockito.timeout(1000)).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr ldsStatusOr = ldsUpdateCaptor.getValue(); + assertThat(ldsStatusOr.hasValue()).isFalse(); + verifyStatusWithNodeId(ldsStatusOr.getStatus(), Code.UNAVAILABLE, + "ADS stream closed with OK before receiving a response"); + verify(rdsResourceWatcher).onResourceChanged(rdsUpdateCaptor.capture()); + StatusOr rdsStatusOr = rdsUpdateCaptor.getValue(); + assertThat(rdsStatusOr.hasValue()).isFalse(); + verifyStatusWithNodeId(rdsStatusOr.getStatus(), Code.UNAVAILABLE, + "ADS stream closed with OK before receiving a response"); + verifyServerFailureCount(1, 1, xdsServerInfo.target()); + } + + @Test + public void serverFailureMetricReport_forRetryAndBackoff() { + InOrder inOrder = inOrder(backoffPolicyProvider, backoffPolicy1, backoffPolicy2); + InOrder ldsWatcherInOrder = inOrder(ldsResourceWatcher); + InOrder rdsWatcherInOrder = inOrder(rdsResourceWatcher); + InOrder cdsWatcherInOrder = inOrder(cdsResourceWatcher); + InOrder edsWatcherInOrder = inOrder(edsResourceWatcher); + when(backoffPolicyProvider.get()).thenReturn(backoffPolicy1, backoffPolicy2, backoffPolicy2); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), RDS_RESOURCE, + rdsResourceWatcher); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), CDS_RESOURCE, cdsResourceWatcher); + xdsClient.watchXdsResource(XdsEndpointResource.getInstance(), EDS_RESOURCE, edsResourceWatcher); + DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); + + // Management server closes the RPC stream with an error. + call.sendError(Status.UNKNOWN.asException()); + ldsWatcherInOrder.verify(ldsResourceWatcher, timeout(1000)).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNKNOWN)); + rdsWatcherInOrder.verify(rdsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNKNOWN)); + cdsWatcherInOrder.verify(cdsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNKNOWN)); + edsWatcherInOrder.verify(edsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNKNOWN)); + verifyServerFailureCount(1, 1, xdsServerInfo.target()); + + // Retry after backoff. + inOrder.verify(backoffPolicyProvider).get(); + inOrder.verify(backoffPolicy1).nextBackoffNanos(); + ScheduledTask retryTask = + Iterables.getOnlyElement(fakeClock.getPendingTasks(RPC_RETRY_TASK_FILTER)); + assertThat(retryTask.getDelay(TimeUnit.NANOSECONDS)).isEqualTo(10L); + fakeClock.forwardNanos(10L); + call = resourceDiscoveryCalls.poll(); + + // Management server becomes unreachable. + String errorMsg = "my fault"; + call.sendError(Status.UNAVAILABLE.withDescription(errorMsg).asException()); + ldsWatcherInOrder.verify(ldsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNAVAILABLE)); + rdsWatcherInOrder.verify(rdsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNAVAILABLE)); + cdsWatcherInOrder.verify(cdsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNAVAILABLE)); + edsWatcherInOrder.verify(edsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNAVAILABLE)); + verifyServerFailureCount(2, 1, xdsServerInfo.target()); + + // Retry after backoff. + inOrder.verify(backoffPolicy1).nextBackoffNanos(); + retryTask = + Iterables.getOnlyElement(fakeClock.getPendingTasks(RPC_RETRY_TASK_FILTER)); + assertThat(retryTask.getDelay(TimeUnit.NANOSECONDS)).isEqualTo(100L); + fakeClock.forwardNanos(100L); + call = resourceDiscoveryCalls.poll(); + + List resources = ImmutableList.of(FAILING_ANY, testListenerRds, FAILING_ANY); + call.sendResponse(LDS, resources, "63", "3242"); + ldsWatcherInOrder.verify(ldsResourceWatcher).onResourceChanged( + argThat(statusOr -> statusOr.hasValue())); + + List routeConfigs = ImmutableList.of(FAILING_ANY, testRouteConfig, FAILING_ANY); + call.sendResponse(RDS, routeConfigs, "5", "6764"); + rdsWatcherInOrder.verify(rdsResourceWatcher).onResourceChanged( + argThat(statusOr -> statusOr.hasValue())); + + // Stream fails AFTER a response. Error is suppressed and no watcher notification occurs. + call.sendError(Status.DEADLINE_EXCEEDED.asException()); + + // Failure count does NOT increase because the error was suppressed. It is still 2. + verifyServerFailureCount(2, 1, xdsServerInfo.target()); + + // Reset backoff sequence and retry after backoff. + inOrder.verify(backoffPolicyProvider).get(); + inOrder.verify(backoffPolicy2).nextBackoffNanos(); + retryTask = + Iterables.getOnlyElement(fakeClock.getPendingTasks(RPC_RETRY_TASK_FILTER)); + assertThat(retryTask.getDelay(TimeUnit.NANOSECONDS)).isEqualTo(20L); + fakeClock.forwardNanos(20L); + call = resourceDiscoveryCalls.poll(); + + // Management server becomes unreachable again. This is on a new stream, so error propagates. + call.sendError(Status.UNAVAILABLE.asException()); + ldsWatcherInOrder.verify(ldsResourceWatcher).onAmbientError( + argThat(status -> status.getCode() == Code.UNAVAILABLE)); + rdsWatcherInOrder.verify(rdsResourceWatcher).onAmbientError( + argThat(status -> status.getCode() == Code.UNAVAILABLE)); + cdsWatcherInOrder.verify(cdsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNAVAILABLE)); + edsWatcherInOrder.verify(edsResourceWatcher).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Code.UNAVAILABLE)); + + // Server failure count is now 3. + verifyServerFailureCount(3, 1, xdsServerInfo.target()); + + // Retry after backoff. + inOrder.verify(backoffPolicy2).nextBackoffNanos(); + retryTask = + Iterables.getOnlyElement(fakeClock.getPendingTasks(RPC_RETRY_TASK_FILTER)); + assertThat(retryTask.getDelay(TimeUnit.NANOSECONDS)).isEqualTo(200L); + fakeClock.forwardNanos(200L); + call = resourceDiscoveryCalls.poll(); + + List clusters = ImmutableList.of(FAILING_ANY, testClusterRoundRobin); + call.sendResponse(CDS, clusters, VERSION_1, "0000"); + call.sendCompleted(); + + // Final failure count is still 3 as the stream closed gracefully. + verifyServerFailureCount(3, 1, xdsServerInfo.target()); + } + private XdsClientImpl createXdsClient(String serverUri) { BootstrapInfo bootstrapInfo = buildBootStrap(serverUri); return new XdsClientImpl( - DEFAULT_XDS_TRANSPORT_FACTORY, + new GrpcXdsTransportFactory(null), bootstrapInfo, fakeClock.getScheduledExecutorService(), backoffPolicyProvider, fakeClock.getStopwatchSupplier(), timeProvider, MessagePrinter.INSTANCE, - new TlsContextManagerImpl(bootstrapInfo)); + new TlsContextManagerImpl(bootstrapInfo), + xdsClientMetricReporter); } - private BootstrapInfo buildBootStrap(String serverUri) { + private BootstrapInfo buildBootStrap(String serverUri) { ServerInfo xdsServerInfo = ServerInfo.create(serverUri, CHANNEL_CREDENTIALS, - ignoreResourceDeletion()); + ignoreResourceDeletion(), true, false, false); return Bootstrapper.BootstrapInfo.builder() .servers(Collections.singletonList(xdsServerInfo)) @@ -3717,7 +5150,7 @@ private BootstrapInfo buildBootStrap(String serverUri) { AuthorityInfo.create( "xdstp://authority.xds.com/envoy.config.listener.v3.Listener/%s", ImmutableList.of(Bootstrapper.ServerInfo.create( - SERVER_URI_CUSTOME_AUTHORITY, CHANNEL_CREDENTIALS))), + SERVER_URI_CUSTOM_AUTHORITY, CHANNEL_CREDENTIALS))), "", AuthorityInfo.create( "xdstp:///envoy.config.listener.v3.Listener/%s", @@ -3768,10 +5201,22 @@ protected abstract static class DiscoveryRpcCall { protected void verifyRequest( XdsResourceType type, List resources, String versionInfo, String nonce, - Node node) { + Node node, VerificationMode verificationMode) { throw new UnsupportedOperationException(); } + protected void verifyRequest( + XdsResourceType type, List resources, String versionInfo, String nonce, + Node node) { + verifyRequest(type, resources, versionInfo, nonce, node, Mockito.timeout(2000)); + } + + protected void verifyRequest( + XdsResourceType type, String resource, String versionInfo, String nonce, + Node node, VerificationMode verificationMode) { + verifyRequest(type, ImmutableList.of(resource), versionInfo, nonce, node, verificationMode); + } + protected void verifyRequest( XdsResourceType type, String resource, String versionInfo, String nonce, Node node) { verifyRequest(type, ImmutableList.of(resource), versionInfo, nonce, node); @@ -3799,7 +5244,7 @@ protected void sendResponse( } protected void sendResponse(XdsResourceType type, Any resource, String versionInfo, - String nonce) { + String nonce) { sendResponse(type, ImmutableList.of(resource), versionInfo, nonce); } @@ -3831,11 +5276,14 @@ protected void sendResponse(List clusters, long loadReportIntervalNano) } protected abstract static class MessageFactory { + /** Throws {@link InvalidProtocolBufferException} on {@link Any#unpack(Class)}. */ protected static final Any FAILING_ANY = Any.newBuilder().setTypeUrl("fake").build(); protected abstract Any buildWrappedResource(Any originalResource); + protected abstract Any buildWrappedResourceWithName(Any originalResource, String name); + protected Message buildListenerWithApiListener(String name, Message routeConfiguration) { return buildListenerWithApiListener( name, routeConfiguration, Collections.emptyList()); @@ -3912,7 +5360,7 @@ protected Message buildLocalityLbEndpoints(String region, String zone, String su } protected abstract Message buildLbEndpoint(String address, int port, String healthStatus, - int lbWeight); + int lbWeight, String endpointHostname); protected abstract Message buildDropOverload(String category, int dropPerMillion); @@ -3928,4 +5376,70 @@ protected abstract Message buildHttpConnectionManagerFilter( protected abstract Message buildTerminalFilter(); } + + private static class XdsStringResource extends XdsResourceType { + @Override + @SuppressWarnings("unchecked") + protected Class unpackedClassName() { + return StringValue.class; + } + + @Override + public String typeName() { + return "EMPTY"; + } + + @Override + public String typeUrl() { + return "type.googleapis.com/google.protobuf.StringValue"; + } + + @Override + public boolean shouldRetrieveResourceKeysForArgs() { + return false; + } + + @Override + protected boolean isFullStateOfTheWorld() { + return false; + } + + @Override + @Nullable + protected String extractResourceName(Message unpackedResource) { + if (!(unpackedResource instanceof StringValue)) { + return null; + } + return ((StringValue) unpackedResource).getValue(); + } + + @Override + protected StringUpdate doParse(Args args, Message unpackedMessage) + throws ResourceInvalidException { + return new StringUpdate(((StringValue) unpackedMessage).getValue()); + } + } + + private static final class StringUpdate implements ResourceUpdate { + @SuppressWarnings("UnusedVariable") + public final String value; + + public StringUpdate(String value) { + this.value = value; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof StringUpdate)) { + return false; + } + StringUpdate that = (StringUpdate) o; + return Objects.equals(this.value, that.value); + } + + @Override + public int hashCode() { + return Objects.hash(value); + } + } } diff --git a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplV3Test.java b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplV3Test.java index 91a5fefaa59..3966fae7f20 100644 --- a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplV3Test.java +++ b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplV3Test.java @@ -17,6 +17,7 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; @@ -118,6 +119,7 @@ import org.mockito.ArgumentMatcher; import org.mockito.InOrder; import org.mockito.Mockito; +import org.mockito.verification.VerificationMode; /** * Tests for {@link XdsClientImpl} with protocol version v3. @@ -143,7 +145,8 @@ public StreamObserver streamAggregatedResources( assertThat(adsEnded.get()).isTrue(); // ensure previous call was ended adsEnded.set(false); @SuppressWarnings("unchecked") - StreamObserver requestObserver = mock(StreamObserver.class); + StreamObserver requestObserver = + mock(StreamObserver.class, delegatesTo(new MockStreamObserver())); DiscoveryRpcCall call = new DiscoveryRpcCallV3(requestObserver, responseObserver); resourceDiscoveryCalls.offer(call); Context.current().addListener( @@ -205,8 +208,8 @@ private DiscoveryRpcCallV3(StreamObserver requestObserver, @Override protected void verifyRequest( XdsResourceType type, List resources, String versionInfo, String nonce, - EnvoyProtoData.Node node) { - verify(requestObserver, Mockito.timeout(2000)).onNext(argThat(new DiscoveryRequestMatcher( + EnvoyProtoData.Node node, VerificationMode verificationMode) { + verify(requestObserver, verificationMode).onNext(argThat(new DiscoveryRequestMatcher( node.toEnvoyProtoNode(), versionInfo, resources, type.typeUrl(), nonce, null, null))); } @@ -290,6 +293,14 @@ protected Any buildWrappedResource(Any originalResource) { .build()); } + @Override + protected Any buildWrappedResourceWithName(Any originalResource, String name) { + return Any.pack(Resource.newBuilder() + .setResource(originalResource) + .setName(name) + .build()); + } + @SuppressWarnings("unchecked") @Override protected Message buildListenerWithApiListener( @@ -602,18 +613,15 @@ protected Message buildLeastRequestLbConfig(int choiceCount) { } @Override - @SuppressWarnings("deprecation") protected Message buildUpstreamTlsContext(String instanceName, String certName) { CommonTlsContext.Builder commonTlsContextBuilder = CommonTlsContext.newBuilder(); if (instanceName != null && certName != null) { - CommonTlsContext.CertificateProviderInstance providerInstance = - CommonTlsContext.CertificateProviderInstance.newBuilder() - .setInstanceName(instanceName) - .setCertificateName(certName) - .build(); CommonTlsContext.CombinedCertificateValidationContext combined = CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance(providerInstance) + .setDefaultValidationContext(CertificateValidationContext.newBuilder() + .setCaCertificateProviderInstance(CertificateProviderPluginInstance.newBuilder() + .setInstanceName(instanceName) + .setCertificateName(certName))) .build(); commonTlsContextBuilder.setCombinedValidationContext(combined); } @@ -694,7 +702,7 @@ protected Message buildLocalityLbEndpoints(String region, String zone, String su @Override protected Message buildLbEndpoint(String address, int port, String healthStatus, - int lbWeight) { + int lbWeight, String endpointHostname) { HealthStatus status; switch (healthStatus) { case "unknown": @@ -722,7 +730,8 @@ protected Message buildLbEndpoint(String address, int port, String healthStatus, .setEndpoint( Endpoint.newBuilder().setAddress( Address.newBuilder().setSocketAddress( - SocketAddress.newBuilder().setAddress(address).setPortValue(port)))) + SocketAddress.newBuilder().setAddress(address).setPortValue(port))) + .setHostname(endpointHostname)) .setHealthStatus(status) .setLoadBalancingWeight(UInt32Value.of(lbWeight)) .build(); @@ -739,7 +748,6 @@ protected Message buildDropOverload(String category, int dropPerMillion) { .build(); } - @SuppressWarnings("deprecation") @Override protected FilterChain buildFilterChain( List alpn, Message tlsContext, String transportSocketName, @@ -865,6 +873,19 @@ public boolean matches(DiscoveryRequest argument) { } return node.equals(argument.getNode()); } + + @Override + public String toString() { + return "DiscoveryRequestMatcher{" + + "node=" + node + + ", versionInfo='" + versionInfo + '\'' + + ", typeUrl='" + typeUrl + '\'' + + ", resources=" + resources + + ", responseNonce='" + responseNonce + '\'' + + ", errorCode=" + errorCode + + ", errorMessages=" + errorMessages + + '}'; + } } /** @@ -892,4 +913,23 @@ public boolean matches(LoadStatsRequest argument) { return actual.equals(expected); } } + + private static class MockStreamObserver implements StreamObserver { + private final List requests = new ArrayList<>(); + + @Override + public void onNext(DiscoveryRequest value) { + requests.add(value); + } + + @Override + public void onError(Throwable t) { + // Ignore + } + + @Override + public void onCompleted() { + // Ignore + } + } } diff --git a/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java b/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java index 703e429fa23..9c606a962f6 100644 --- a/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java @@ -30,6 +30,7 @@ import io.grpc.Server; import io.grpc.Status; import io.grpc.stub.StreamObserver; +import io.grpc.testing.GrpcCleanupRule; import io.grpc.xds.client.Bootstrapper; import io.grpc.xds.client.XdsTransportFactory; import java.util.concurrent.BlockingQueue; @@ -37,6 +38,7 @@ import java.util.concurrent.TimeUnit; import org.junit.After; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -44,6 +46,8 @@ @RunWith(JUnit4.class) public class GrpcXdsTransportFactoryTest { + @Rule public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule(); + private Server server; @Before @@ -92,9 +96,10 @@ public void onCompleted() { @Test public void callApis() throws Exception { XdsTransportFactory.XdsTransport xdsTransport = - GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY.create( - Bootstrapper.ServerInfo.create("localhost:" + server.getPort(), - InsecureChannelCredentials.create())); + new GrpcXdsTransportFactory(null) + .create( + Bootstrapper.ServerInfo.create( + "localhost:" + server.getPort(), InsecureChannelCredentials.create())); MethodDescriptor methodDescriptor = AggregatedDiscoveryServiceGrpc.getStreamAggregatedResourcesMethod(); XdsTransportFactory.StreamingCall streamingCall = @@ -117,6 +122,59 @@ public void callApis() throws Exception { xdsTransport.shutdown(); } + @Test + public void refCountedXdsTransport_sameXdsServerAddress_returnsExistingTransport() { + Bootstrapper.ServerInfo xdsServerInfo = + Bootstrapper.ServerInfo.create( + "localhost:" + server.getPort(), InsecureChannelCredentials.create()); + GrpcXdsTransportFactory xdsTransportFactory = new GrpcXdsTransportFactory(null); + // Calling create() for the first time creates a new GrpcXdsTransport instance. + // The ref count was previously 0 and now is 1. + XdsTransportFactory.XdsTransport transport1 = xdsTransportFactory.create(xdsServerInfo); + // Calling create() for the second time to the same xDS server address returns the same + // GrpcXdsTransport instance. The ref count was previously 1 and now is 2. + XdsTransportFactory.XdsTransport transport2 = xdsTransportFactory.create(xdsServerInfo); + assertThat(transport1).isSameInstanceAs(transport2); + // Calling shutdown() for the first time does not shut down the GrpcXdsTransport instance. + // The ref count was previously 2 and now is 1. + transport1.shutdown(); + // Calling shutdown() for the second time shuts down the GrpcXdsTransport instance. + // The ref count was previously 1 and now is 0. + transport2.shutdown(); + } + + @Test + public void refCountedXdsTransport_differentXdsServerAddress_returnsDifferentTransport() + throws Exception { + // Create and start a second xDS server on a different port. + Server server2 = + grpcCleanupRule.register( + Grpc.newServerBuilderForPort(0, InsecureServerCredentials.create()) + .addService(echoAdsService()) + .build() + .start()); + Bootstrapper.ServerInfo xdsServerInfo1 = + Bootstrapper.ServerInfo.create( + "localhost:" + server.getPort(), InsecureChannelCredentials.create()); + Bootstrapper.ServerInfo xdsServerInfo2 = + Bootstrapper.ServerInfo.create( + "localhost:" + server2.getPort(), InsecureChannelCredentials.create()); + GrpcXdsTransportFactory xdsTransportFactory = new GrpcXdsTransportFactory(null); + // Calling create() to the first xDS server creates a new GrpcXdsTransport instance. + // The ref count was previously 0 and now is 1. + XdsTransportFactory.XdsTransport transport1 = xdsTransportFactory.create(xdsServerInfo1); + // Calling create() to the second xDS server creates a different GrpcXdsTransport instance. + // The ref count was previously 0 and now is 1. + XdsTransportFactory.XdsTransport transport2 = xdsTransportFactory.create(xdsServerInfo2); + assertThat(transport1).isNotSameInstanceAs(transport2); + // Calling shutdown() shuts down the GrpcXdsTransport instance for the first xDS server. + // The ref count was previously 1 and now is 0. + transport1.shutdown(); + // Calling shutdown() shuts down the GrpcXdsTransport instance for the second xDS server. + // The ref count was previously 1 and now is 0. + transport2.shutdown(); + } + private static class FakeEventHandler implements XdsTransportFactory.EventHandler { private final BlockingQueue respQ = new LinkedBlockingQueue<>(); diff --git a/xds/src/test/java/io/grpc/xds/LazyLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/LazyLoadBalancerTest.java new file mode 100644 index 00000000000..c79d048c9d3 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/LazyLoadBalancerTest.java @@ -0,0 +1,94 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; + +import io.grpc.CallOptions; +import io.grpc.ConnectivityState; +import io.grpc.EquivalentAddressGroup; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.ResolvedAddresses; +import io.grpc.LoadBalancer.SubchannelPicker; +import io.grpc.ManagedChannel; +import io.grpc.Metadata; +import io.grpc.SynchronizationContext; +import io.grpc.internal.PickSubchannelArgsImpl; +import io.grpc.testing.TestMethodDescriptors; +import java.util.Arrays; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit test for {@link io.grpc.xds.LazyLoadBalancer}. */ +@RunWith(JUnit4.class) +public final class LazyLoadBalancerTest { + private SynchronizationContext syncContext = + new SynchronizationContext((t, e) -> { + throw new AssertionError(e); + }); + private LoadBalancer.PickSubchannelArgs args = new PickSubchannelArgsImpl( + TestMethodDescriptors.voidMethod(), + new Metadata(), + CallOptions.DEFAULT, + new LoadBalancer.PickDetailsConsumer() {}); + private FakeHelper helper = new FakeHelper(); + + @Test + public void pickerIsNoopAfterEarlyShutdown() { + LazyLoadBalancer lb = new LazyLoadBalancer(helper, new LoadBalancer.Factory() { + @Override + public LoadBalancer newLoadBalancer(LoadBalancer.Helper helper) { + throw new AssertionError("unexpected"); + } + }); + lb.acceptResolvedAddresses(ResolvedAddresses.newBuilder() + .setAddresses(Arrays.asList()) + .build()); + SubchannelPicker picker = helper.picker; + assertThat(picker).isNotNull(); + lb.shutdown(); + + picker.pickSubchannel(args); + } + + class FakeHelper extends LoadBalancer.Helper { + ConnectivityState state; + SubchannelPicker picker; + + @Override + public ManagedChannel createOobChannel(EquivalentAddressGroup eag, String authority) { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + this.state = newState; + this.picker = newPicker; + } + + @Override + public SynchronizationContext getSynchronizationContext() { + return syncContext; + } + + @Override + public String getAuthority() { + return "localhost"; + } + } +} diff --git a/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerTest.java index 9afe82d04e8..302faed95a4 100644 --- a/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerTest.java @@ -22,6 +22,7 @@ import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; +import static io.grpc.LoadBalancerMatchers.pickerReturns; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; @@ -31,6 +32,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -49,6 +51,7 @@ import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.CreateSubchannelArgs; +import io.grpc.LoadBalancer.FixedResultPicker; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; @@ -58,9 +61,9 @@ import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.Metadata; import io.grpc.Status; +import io.grpc.internal.PickFirstLoadBalancerProvider; import io.grpc.util.AbstractTestHelper; import io.grpc.util.MultiChildLoadBalancer.ChildLbState; -import io.grpc.xds.LeastRequestLoadBalancer.EmptyPicker; import io.grpc.xds.LeastRequestLoadBalancer.LeastRequestConfig; import io.grpc.xds.LeastRequestLoadBalancer.LeastRequestLbState; import io.grpc.xds.LeastRequestLoadBalancer.ReadyPicker; @@ -157,7 +160,7 @@ public void pickAfterResolved() throws Exception { assertEquals(READY, stateCaptor.getAllValues().get(1)); assertThat(getList(pickerCaptor.getValue())).containsExactly(readySubchannel); - verifyNoMoreInteractions(helper); + AbstractTestHelper.verifyNoMoreMeaningfulInteractions(helper); } @Test @@ -184,8 +187,7 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { Subchannel removedSubchannel = getSubchannel(removedEag); Subchannel oldSubchannel = getSubchannel(oldEag1); SubchannelStateListener removedListener = - testHelperInstance.getSubchannelStateListeners() - .get(testHelperInstance.getRealForMockSubChannel(removedSubchannel)); + testHelperInstance.getSubchannelStateListener(removedSubchannel); inOrder.verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); @@ -199,8 +201,6 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { verify(removedSubchannel, times(1)).requestConnection(); verify(oldSubchannel, times(1)).requestConnection(); - assertThat(getChildEags(loadBalancer)).containsExactly(removedEag, oldEag1); - // This time with Attributes List latestServers = Lists.newArrayList(oldEag2, newEag); @@ -217,42 +217,18 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { removedListener.onSubchannelState(ConnectivityStateInfo.forNonError(SHUTDOWN)); deliverSubchannelState(newSubchannel, ConnectivityStateInfo.forNonError(READY)); - assertThat(getChildEags(loadBalancer)).containsExactly(oldEag2, newEag); - verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(helper, times(2)).updateBalancingState(eq(READY), pickerCaptor.capture()); assertThat(getList(pickerCaptor.getValue())).containsExactly(oldSubchannel, newSubchannel); - verifyNoMoreInteractions(helper); + AbstractTestHelper.verifyNoMoreMeaningfulInteractions(helper); } private Subchannel getSubchannel(EquivalentAddressGroup removedEag) { return subchannels.get(Collections.singletonList(removedEag)); } - private Subchannel getSubchannel(ChildLbState childLbState) { - return subchannels.get(Collections.singletonList(childLbState.getEag())); - } - - private static List getChildEags(LeastRequestLoadBalancer loadBalancer) { - return loadBalancer.getChildLbStates().stream() - .map(ChildLbState::getEag) - // .map(EquivalentAddressGroup::getAddresses) - .collect(Collectors.toList()); - } - - private List getSubchannels(LeastRequestLoadBalancer lb) { - return lb.getChildLbStates().stream() - .map(this::getSubchannel) - .collect(Collectors.toList()); - } - - private LeastRequestLbState getChildLbState(PickResult pickResult) { - EquivalentAddressGroup eag = pickResult.getSubchannel().getAddresses(); - return (LeastRequestLbState) loadBalancer.getChildLbState(eag); - } - @Test public void pickAfterStateChange() throws Exception { InOrder inOrder = inOrder(helper); @@ -261,35 +237,34 @@ public void pickAfterStateChange() throws Exception { .build()); assertThat(addressesAcceptanceStatus.isOk()).isTrue(); ChildLbState childLbState = loadBalancer.getChildLbStates().iterator().next(); - Subchannel subchannel = getSubchannel(childLbState); + Subchannel subchannel = getSubchannel(servers.get(0)); - inOrder.verify(helper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); + inOrder.verify(helper) + .updateBalancingState(eq(CONNECTING), pickerReturns(PickResult.withNoResult())); assertThat(childLbState.getCurrentState()).isEqualTo(CONNECTING); - deliverSubchannelState(subchannel, - ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); assertThat(pickerCaptor.getValue()).isInstanceOf(ReadyPicker.class); assertThat(childLbState.getCurrentState()).isEqualTo(READY); Status error = Status.UNKNOWN.withDescription("¯\\_(ツ)_//¯"); - deliverSubchannelState(subchannel, - ConnectivityStateInfo.forTransientFailure(error)); + deliverSubchannelState(subchannel, ConnectivityStateInfo.forTransientFailure(error)); assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE); assertThat(childLbState.getCurrentPicker().toString()).contains(error.toString()); - inOrder.verify(helper).refreshNameResolution(); - inOrder.verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); - assertThat(pickerCaptor.getValue()).isInstanceOf(EmptyPicker.class); + refreshInvokedAndUpdateBS(inOrder, CONNECTING); + assertThat(pickerCaptor.getValue().pickSubchannel(mockArgs)) + .isEqualTo(PickResult.withNoResult()); - deliverSubchannelState(subchannel, - ConnectivityStateInfo.forNonError(IDLE)); + deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(IDLE)); inOrder.verify(helper).refreshNameResolution(); assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE); assertThat(childLbState.getCurrentPicker().toString()).contains(error.toString()); - verify(subchannel, times(2)).requestConnection(); + int expectedCount = PickFirstLoadBalancerProvider.isEnabledNewPickFirst() ? 1 : 2; + verify(subchannel, times(expectedCount)).requestConnection(); verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); - verifyNoMoreInteractions(helper); + AbstractTestHelper.verifyNoMoreMeaningfulInteractions(helper); } @Test @@ -320,7 +295,7 @@ public void pickAfterConfigChange() { // At this point it should use a ReadyPicker with newConfig pickerCaptor.getValue().pickSubchannel(mockArgs); verify(mockRandom, times(oldConfig.choiceCount + newConfig.choiceCount)).nextInt(1); - verifyNoMoreInteractions(helper); + AbstractTestHelper.verifyNoMoreMeaningfulInteractions(helper); } @Test @@ -330,10 +305,12 @@ public void ignoreShutdownSubchannelStateChange() { ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) .build()); assertThat(addressesAcceptanceStatus.isOk()).isTrue(); - inOrder.verify(helper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); + inOrder.verify(helper) + .updateBalancingState(eq(CONNECTING), pickerReturns(PickResult.withNoResult())); + List savedSubchannels = new ArrayList<>(subchannels.values()); loadBalancer.shutdown(); - for (Subchannel sc : getSubchannels(loadBalancer)) { + for (Subchannel sc : savedSubchannels) { verify(sc).shutdown(); // When the subchannel is being shut down, a SHUTDOWN connectivity state is delivered // back to the subchannel state listener. @@ -351,31 +328,35 @@ public void stayTransientFailureUntilReady() { .build()); assertThat(addressesAcceptanceStatus.isOk()).isTrue(); - inOrder.verify(helper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); + inOrder.verify(helper) + .updateBalancingState(eq(CONNECTING), pickerReturns(PickResult.withNoResult())); // Simulate state transitions for each subchannel individually. - for (ChildLbState childLbState : loadBalancer.getChildLbStates()) { - Subchannel sc = getSubchannel(childLbState); + List children = new ArrayList<>(loadBalancer.getChildLbStates()); + for (int i = 0; i < children.size(); i++) { + ChildLbState childLbState = children.get(i); + Subchannel sc = getSubchannel(servers.get(i)); Status error = Status.UNKNOWN.withDescription("connection broken"); deliverSubchannelState(sc, ConnectivityStateInfo.forTransientFailure(error)); - inOrder.verify(helper).refreshNameResolution(); deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(CONNECTING)); assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE); } - inOrder.verify(helper) - .updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); + + verify(helper, atLeast(loadBalancer.getChildLbStates().size())).refreshNameResolution(); + inOrder.verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); assertThat(getStatusString(pickerCaptor.getValue())) .contains("Status{code=UNKNOWN, description=connection broken"); + inOrder.verify(helper, atLeast(0)).refreshNameResolution(); inOrder.verifyNoMoreInteractions(); ChildLbState childLbState = loadBalancer.getChildLbStates().iterator().next(); - Subchannel subchannel = getSubchannel(childLbState); + Subchannel subchannel = getSubchannel(servers.get(0)); deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); assertThat(childLbState.getCurrentState()).isEqualTo(READY); inOrder.verify(helper).updateBalancingState(eq(READY), isA(ReadyPicker.class)); verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); - verifyNoMoreInteractions(helper); + AbstractTestHelper.verifyNoMoreMeaningfulInteractions(helper); } private String getStatusString(SubchannelPicker picker) { @@ -408,10 +389,11 @@ public void refreshNameResolutionWhenSubchannelConnectionBroken() { assertThat(addressesAcceptanceStatus.isOk()).isTrue(); verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); - inOrder.verify(helper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); + inOrder.verify(helper) + .updateBalancingState(eq(CONNECTING), pickerReturns(PickResult.withNoResult())); // Simulate state transitions for each subchannel individually. - for (Subchannel sc : getSubchannels(loadBalancer)) { + for (Subchannel sc : subchannels.values()) { verify(sc).requestConnection(); deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(CONNECTING)); Status error = Status.UNKNOWN.withDescription("connection broken"); @@ -423,10 +405,11 @@ public void refreshNameResolutionWhenSubchannelConnectionBroken() { deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(IDLE)); inOrder.verify(helper).refreshNameResolution(); verify(sc, times(2)).requestConnection(); - inOrder.verify(helper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); + inOrder.verify(helper) + .updateBalancingState(eq(CONNECTING), pickerReturns(PickResult.withNoResult())); } - verifyNoMoreInteractions(helper); + AbstractTestHelper.verifyNoMoreMeaningfulInteractions(helper); } @Test @@ -449,8 +432,8 @@ public void pickerLeastRequest() throws Exception { ((LeastRequestLbState) childLbStates.get(i)).getActiveRequests()); } - for (ChildLbState cs : childLbStates) { - deliverSubchannelState(getSubchannel(cs), ConnectivityStateInfo.forNonError(READY)); + for (Subchannel sc : subchannels.values()) { + deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(READY)); } // Capture the active ReadyPicker once all subchannels are READY @@ -460,45 +443,37 @@ public void pickerLeastRequest() throws Exception { ReadyPicker picker = (ReadyPicker) pickerCaptor.getValue(); - assertThat(picker.getChildEags()) - .containsExactlyElementsIn(childLbStates.stream().map(ChildLbState::getEag).toArray()); + assertThat(picker.getChildPickers()).containsExactlyElementsIn( + childLbStates.stream().map(ChildLbState::getCurrentPicker).toArray()); // Make random return 0, then 2 for the sample indexes. when(mockRandom.nextInt(childLbStates.size())).thenReturn(0, 2); PickResult pickResult1 = picker.pickSubchannel(mockArgs); verify(mockRandom, times(choiceCount)).nextInt(childLbStates.size()); - assertEquals(childLbStates.get(0), getChildLbState(pickResult1)); + assertThat(pickResult1.getSubchannel()).isEqualTo(getSubchannel(servers.get(0))); // This simulates sending the actual RPC on the picked channel ClientStreamTracer streamTracer1 = pickResult1.getStreamTracerFactory() .newClientStreamTracer(StreamInfo.newBuilder().build(), new Metadata()); streamTracer1.streamCreated(Attributes.EMPTY, new Metadata()); - assertEquals(1, getChildLbState(pickResult1).getActiveRequests()); + assertEquals(1, ((LeastRequestLbState) childLbStates.get(0)).getActiveRequests()); // For the second pick it should pick the one with lower inFlight. when(mockRandom.nextInt(childLbStates.size())).thenReturn(0, 2); PickResult pickResult2 = picker.pickSubchannel(mockArgs); // Since this is the second pick we expect the total random samples to be choiceCount * 2 verify(mockRandom, times(choiceCount * 2)).nextInt(childLbStates.size()); - assertEquals(childLbStates.get(2), getChildLbState(pickResult2)); + assertThat(pickResult2.getSubchannel()).isEqualTo(getSubchannel(servers.get(2))); // For the third pick we unavoidably pick subchannel with index 1. when(mockRandom.nextInt(childLbStates.size())).thenReturn(1, 1); PickResult pickResult3 = picker.pickSubchannel(mockArgs); verify(mockRandom, times(choiceCount * 3)).nextInt(childLbStates.size()); - assertEquals(childLbStates.get(1), getChildLbState(pickResult3)); + assertThat(pickResult3.getSubchannel()).isEqualTo(getSubchannel(servers.get(1))); // Finally ensure a finished RPC decreases inFlight streamTracer1.streamClosed(Status.OK); - assertEquals(0, getChildLbState(pickResult1).getActiveRequests()); - } - - @Test - public void pickerEmptyList() throws Exception { - SubchannelPicker picker = new EmptyPicker(); - - assertNull(picker.pickSubchannel(mockArgs).getSubchannel()); - assertEquals(Status.OK, picker.pickSubchannel(mockArgs).getStatus()); + assertEquals(0, ((LeastRequestLbState) childLbStates.get(0)).getActiveRequests()); } @Test @@ -548,7 +523,7 @@ public void nameResolutionErrorWithActiveChannels() throws Exception { LoadBalancer.PickResult pickResult2 = pickerCaptor.getValue().pickSubchannel(mockArgs); verify(mockRandom, times(choiceCount * 2)).nextInt(1); assertEquals(readySubchannel, pickResult2.getSubchannel()); - verifyNoMoreInteractions(helper); + AbstractTestHelper.verifyNoMoreMeaningfulInteractions(helper); } @Test @@ -578,7 +553,8 @@ public void subchannelStateIsolation() throws Exception { Iterator pickers = pickerCaptor.getAllValues().iterator(); // The picker is incrementally updated as subchannels become READY assertEquals(CONNECTING, stateIterator.next()); - assertThat(pickers.next()).isInstanceOf(EmptyPicker.class); + assertThat(pickers.next().pickSubchannel(mockArgs)) + .isEqualTo(PickResult.withNoResult()); assertEquals(READY, stateIterator.next()); assertThat(getList(pickers.next())).containsExactly(sc1); assertEquals(READY, stateIterator.next()); @@ -609,8 +585,8 @@ public void readyPicker_emptyList() { @Test public void internalPickerComparisons() { - EmptyPicker empty1 = new EmptyPicker(); - EmptyPicker empty2 = new EmptyPicker(); + FixedResultPicker empty1 = new FixedResultPicker(PickResult.withNoResult()); + FixedResultPicker empty2 = new FixedResultPicker(PickResult.withNoResult()); loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build()); @@ -648,8 +624,8 @@ public void emptyAddresses() { private List getList(SubchannelPicker picker) { if (picker instanceof ReadyPicker) { - return ((ReadyPicker) picker).getChildEags().stream() - .map(this::getSubchannel) + return ((ReadyPicker) picker).getChildPickers().stream() + .map((p) -> p.pickSubchannel(mockArgs).getSubchannel()) .collect(Collectors.toList()); } else { return Collections.emptyList(); @@ -660,6 +636,19 @@ private void deliverSubchannelState(Subchannel subchannel, ConnectivityStateInfo testHelperInstance.deliverSubchannelState(subchannel, newState); } + // Old PF and new PF reverse calling order of updateBlaancingState and refreshNameResolution + private void refreshInvokedAndUpdateBS(InOrder inOrder, ConnectivityState state) { + if (PickFirstLoadBalancerProvider.isEnabledNewPickFirst()) { + inOrder.verify(helper).updateBalancingState(eq(state), pickerCaptor.capture()); + } + + inOrder.verify(helper).refreshNameResolution(); + + if (!PickFirstLoadBalancerProvider.isEnabledNewPickFirst()) { + inOrder.verify(helper).updateBalancingState(eq(state), pickerCaptor.capture()); + } + } + private static class FakeSocketAddress extends SocketAddress { final String name; diff --git a/xds/src/test/java/io/grpc/xds/LoadBalancerConfigFactoryTest.java b/xds/src/test/java/io/grpc/xds/LoadBalancerConfigFactoryTest.java index 09ee670ee38..b8b20248026 100644 --- a/xds/src/test/java/io/grpc/xds/LoadBalancerConfigFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/LoadBalancerConfigFactoryTest.java @@ -101,6 +101,22 @@ public class LoadBalancerConfigFactoryTest { .build())) .build()) .build(); + + private static final Policy WRR_POLICY_WITH_METRICS = Policy.newBuilder() + .setTypedExtensionConfig(TypedExtensionConfig.newBuilder() + .setName("backend") + .setTypedConfig( + Any.pack(ClientSideWeightedRoundRobin.newBuilder() + .setBlackoutPeriod(Duration.newBuilder().setSeconds(287).build()) + .setEnableOobLoadReport( + BoolValue.newBuilder().setValue(true).build()) + .setErrorUtilizationPenalty( + FloatValue.newBuilder().setValue(1.75F).build()) + .addMetricNamesForComputingUtilization("foo") + .addMetricNamesForComputingUtilization("bar") + .build())) + .build()) + .build(); private static final String CUSTOM_POLICY_NAME = "myorg.MyCustomLeastRequestPolicy"; private static final String CUSTOM_POLICY_FIELD_KEY = "choiceCount"; private static final double CUSTOM_POLICY_FIELD_VALUE = 2; @@ -130,6 +146,15 @@ public class LoadBalancerConfigFactoryTest { ImmutableMap.of("weighted_round_robin", ImmutableMap.of("blackoutPeriod","287s", "enableOobLoadReport", true, "errorUtilizationPenalty", 1.75F ))))); + + private static final LbConfig VALID_WRR_CONFIG_WITH_METRICS = + new LbConfig("wrr_locality_experimental", + ImmutableMap.of("childPolicy", + ImmutableList.of(ImmutableMap.of("weighted_round_robin", + ImmutableMap.of("blackoutPeriod", "287s", "enableOobLoadReport", true, + "errorUtilizationPenalty", 1.75F, + LoadBalancerConfigFactory.METRIC_NAMES_FOR_COMPUTING_UTILIZATION, + ImmutableList.of("foo", "bar")))))); private static final LbConfig VALID_RING_HASH_CONFIG = new LbConfig("ring_hash_experimental", ImmutableMap.of("minRingSize", (double) RING_HASH_MIN_RING_SIZE, "maxRingSize", (double) RING_HASH_MAX_RING_SIZE)); @@ -155,14 +180,21 @@ public void deregisterCustomProvider() { public void roundRobin() throws ResourceInvalidException { Cluster cluster = newCluster(buildWrrPolicy(ROUND_ROBIN_POLICY)); - assertThat(newLbConfig(cluster, true, true, true)).isEqualTo(VALID_ROUND_ROBIN_CONFIG); + assertThat(newLbConfig(cluster, true)).isEqualTo(VALID_ROUND_ROBIN_CONFIG); } @Test public void weightedRoundRobin() throws ResourceInvalidException { Cluster cluster = newCluster(buildWrrPolicy(WRR_POLICY)); - assertThat(newLbConfig(cluster, true, true, true)).isEqualTo(VALID_WRR_CONFIG); + assertThat(newLbConfig(cluster, true)).isEqualTo(VALID_WRR_CONFIG); + } + + @Test + public void weightedRoundRobin_withMetrics() throws ResourceInvalidException { + Cluster cluster = newCluster(buildWrrPolicy(WRR_POLICY_WITH_METRICS)); + + assertThat(newLbConfig(cluster, true)).isEqualTo(VALID_WRR_CONFIG_WITH_METRICS); } @Test @@ -179,22 +211,15 @@ public void weightedRoundRobin_invalid() throws ResourceInvalidException { .build()) .build())); - assertResourceInvalidExceptionThrown(cluster, true, true, true, + assertResourceInvalidExceptionThrown(cluster, true, "Invalid duration in weighted round robin config"); } - @Test - public void weightedRoundRobin_fallback_roundrobin() throws ResourceInvalidException { - Cluster cluster = newCluster(buildWrrPolicy(WRR_POLICY, ROUND_ROBIN_POLICY)); - - assertThat(newLbConfig(cluster, true, false, true)).isEqualTo(VALID_ROUND_ROBIN_CONFIG); - } - @Test public void roundRobin_legacy() throws ResourceInvalidException { Cluster cluster = Cluster.newBuilder().setLbPolicy(LbPolicy.ROUND_ROBIN).build(); - assertThat(newLbConfig(cluster, true, true, true)).isEqualTo(VALID_ROUND_ROBIN_CONFIG); + assertThat(newLbConfig(cluster, true)).isEqualTo(VALID_ROUND_ROBIN_CONFIG); } @Test @@ -203,7 +228,7 @@ public void ringHash() throws ResourceInvalidException { .setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder().addPolicies(RING_HASH_POLICY)) .build(); - assertThat(newLbConfig(cluster, true, true, true)).isEqualTo(VALID_RING_HASH_CONFIG); + assertThat(newLbConfig(cluster, true)).isEqualTo(VALID_RING_HASH_CONFIG); } @Test @@ -213,7 +238,7 @@ public void ringHash_legacy() throws ResourceInvalidException { .setMaximumRingSize(UInt64Value.of(RING_HASH_MAX_RING_SIZE)) .setHashFunction(HashFunction.XX_HASH)).build(); - assertThat(newLbConfig(cluster, true, true, true)).isEqualTo(VALID_RING_HASH_CONFIG); + assertThat(newLbConfig(cluster, true)).isEqualTo(VALID_RING_HASH_CONFIG); } @Test @@ -225,7 +250,7 @@ public void ringHash_invalidHash() { .setMaximumRingSize(UInt64Value.of(RING_HASH_MAX_RING_SIZE)) .setHashFunction(RingHash.HashFunction.MURMUR_HASH_2).build()))).build()); - assertResourceInvalidExceptionThrown(cluster, true, true, true, "Invalid ring hash function"); + assertResourceInvalidExceptionThrown(cluster, true, "Invalid ring hash function"); } @Test @@ -233,7 +258,7 @@ public void ringHash_invalidHash_legacy() { Cluster cluster = Cluster.newBuilder().setLbPolicy(LbPolicy.RING_HASH).setRingHashLbConfig( RingHashLbConfig.newBuilder().setHashFunction(HashFunction.MURMUR_HASH_2)).build(); - assertResourceInvalidExceptionThrown(cluster, true, true, true, "invalid ring hash function"); + assertResourceInvalidExceptionThrown(cluster, true, "invalid ring hash function"); } @Test @@ -242,7 +267,7 @@ public void leastRequest() throws ResourceInvalidException { .setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder().addPolicies(LEAST_REQUEST_POLICY)) .build(); - assertThat(newLbConfig(cluster, true, true, true)).isEqualTo(VALID_LEAST_REQUEST_CONFIG); + assertThat(newLbConfig(cluster, true)).isEqualTo(VALID_LEAST_REQUEST_CONFIG); } @Test @@ -254,7 +279,7 @@ public void leastRequest_legacy() throws ResourceInvalidException { LeastRequestLbConfig.newBuilder() .setChoiceCount(UInt32Value.of(LEAST_REQUEST_CHOICE_COUNT))).build(); - LbConfig lbConfig = newLbConfig(cluster, true, true, true); + LbConfig lbConfig = newLbConfig(cluster, true); assertThat(lbConfig.getPolicyName()).isEqualTo("wrr_locality_experimental"); List childConfigs = ServiceConfigUtil.unwrapLoadBalancingConfigList( @@ -269,7 +294,7 @@ public void leastRequest_legacy() throws ResourceInvalidException { public void leastRequest_notEnabled() { Cluster cluster = Cluster.newBuilder().setLbPolicy(LbPolicy.LEAST_REQUEST).build(); - assertResourceInvalidExceptionThrown(cluster, false, true, true, "unsupported lb policy"); + assertResourceInvalidExceptionThrown(cluster, false, "unsupported lb policy"); } @Test @@ -278,24 +303,14 @@ public void pickFirst() throws ResourceInvalidException { .setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder().addPolicies(PICK_FIRST_POLICY)) .build(); - assertThat(newLbConfig(cluster, true, true, true)).isEqualTo(VALID_PICK_FIRST_CONFIG); - } - - @Test - public void pickFirst_notEnabled() throws ResourceInvalidException { - Cluster cluster = Cluster.newBuilder() - .setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder().addPolicies(PICK_FIRST_POLICY)) - .build(); - - assertResourceInvalidExceptionThrown(cluster, true, true, false, "Invalid LoadBalancingPolicy"); + assertThat(newLbConfig(cluster, true)).isEqualTo(VALID_PICK_FIRST_CONFIG); } @Test public void customRootLb_providerRegistered() throws ResourceInvalidException { LoadBalancerRegistry.getDefaultRegistry().register(CUSTOM_POLICY_PROVIDER); - assertThat(newLbConfig(newCluster(CUSTOM_POLICY), false, - true, true)).isEqualTo(VALID_CUSTOM_CONFIG); + assertThat(newLbConfig(newCluster(CUSTOM_POLICY), false)).isEqualTo(VALID_CUSTOM_CONFIG); } @Test @@ -304,7 +319,7 @@ public void customRootLb_providerNotRegistered() throws ResourceInvalidException .setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder().addPolicies(CUSTOM_POLICY)) .build(); - assertResourceInvalidExceptionThrown(cluster, false, true, true, "Invalid LoadBalancingPolicy"); + assertResourceInvalidExceptionThrown(cluster, false, "Invalid LoadBalancingPolicy"); } // When a provider for the endpoint picking custom policy is available, the configuration should @@ -316,7 +331,7 @@ public void customLbInWrr_providerRegistered() throws ResourceInvalidException { Cluster cluster = Cluster.newBuilder().setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder() .addPolicies(buildWrrPolicy(CUSTOM_POLICY, ROUND_ROBIN_POLICY))).build(); - assertThat(newLbConfig(cluster, false, true, true)).isEqualTo(VALID_CUSTOM_CONFIG_IN_WRR); + assertThat(newLbConfig(cluster, false)).isEqualTo(VALID_CUSTOM_CONFIG_IN_WRR); } // When a provider for the endpoint picking custom policy is available, the configuration should @@ -328,7 +343,7 @@ public void customLbInWrr_providerRegistered_udpa() throws ResourceInvalidExcept Cluster cluster = Cluster.newBuilder().setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder() .addPolicies(buildWrrPolicy(CUSTOM_POLICY_UDPA, ROUND_ROBIN_POLICY))).build(); - assertThat(newLbConfig(cluster, false, true, true)).isEqualTo(VALID_CUSTOM_CONFIG_IN_WRR); + assertThat(newLbConfig(cluster, false)).isEqualTo(VALID_CUSTOM_CONFIG_IN_WRR); } // When a provider for the custom wrr_locality child policy is NOT available, we should fall back @@ -338,7 +353,7 @@ public void customLbInWrr_providerNotRegistered() throws ResourceInvalidExceptio Cluster cluster = Cluster.newBuilder().setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder() .addPolicies(buildWrrPolicy(CUSTOM_POLICY, ROUND_ROBIN_POLICY))).build(); - assertThat(newLbConfig(cluster, false, true, true)).isEqualTo(VALID_ROUND_ROBIN_CONFIG); + assertThat(newLbConfig(cluster, false)).isEqualTo(VALID_ROUND_ROBIN_CONFIG); } // When a provider for the custom wrr_locality child policy is NOT available and no alternative @@ -348,7 +363,7 @@ public void customLbInWrr_providerNotRegistered_noFallback() throws ResourceInva Cluster cluster = Cluster.newBuilder().setLoadBalancingPolicy( LoadBalancingPolicy.newBuilder().addPolicies(buildWrrPolicy(CUSTOM_POLICY))).build(); - assertResourceInvalidExceptionThrown(cluster, false, true, true, "Invalid LoadBalancingPolicy"); + assertResourceInvalidExceptionThrown(cluster, false, "Invalid LoadBalancingPolicy"); } @Test @@ -375,7 +390,7 @@ public void maxRecursion() { buildWrrPolicy( ROUND_ROBIN_POLICY))))))))))))))))))).build(); - assertResourceInvalidExceptionThrown(cluster, false, true, true, + assertResourceInvalidExceptionThrown(cluster, false, "Maximum LB config recursion depth reached"); } @@ -391,18 +406,16 @@ private static Policy buildWrrPolicy(Policy... childPolicy) { .build()))).build(); } - private LbConfig newLbConfig(Cluster cluster, boolean enableLeastRequest, boolean enableWrr, - boolean enablePickFirst) + private LbConfig newLbConfig(Cluster cluster, boolean enableLeastRequest) throws ResourceInvalidException { return ServiceConfigUtil.unwrapLoadBalancingConfig( - LoadBalancerConfigFactory.newConfig(cluster, enableLeastRequest, - enableWrr, enablePickFirst)); + LoadBalancerConfigFactory.newConfig(cluster, enableLeastRequest)); } private void assertResourceInvalidExceptionThrown(Cluster cluster, boolean enableLeastRequest, - boolean enableWrr, boolean enablePickFirst, String expectedMessage) { + String expectedMessage) { try { - newLbConfig(cluster, enableLeastRequest, enableWrr, enablePickFirst); + newLbConfig(cluster, enableLeastRequest); } catch (ResourceInvalidException e) { assertThat(e).hasMessageThat().contains(expectedMessage); return; diff --git a/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java b/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java index c11a3a6e0d2..9bdf86132b6 100644 --- a/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java +++ b/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java @@ -178,11 +178,15 @@ public void cancelled(Context context) { when(backoffPolicy2.nextBackoffNanos()) .thenReturn(TimeUnit.SECONDS.toNanos(2L), TimeUnit.SECONDS.toNanos(20L)); addFakeStatsData(); - lrsClient = new LoadReportClient(loadStatsManager, - GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY.createForTest(channel), - NODE, - syncContext, fakeClock.getScheduledExecutorService(), backoffPolicyProvider, - fakeClock.getStopwatchSupplier()); + lrsClient = + new LoadReportClient( + loadStatsManager, + new GrpcXdsTransportFactory(null).createForTest(channel), + NODE, + syncContext, + fakeClock.getScheduledExecutorService(), + backoffPolicyProvider, + fakeClock.getStopwatchSupplier()); syncContext.execute(new Runnable() { @Override public void run() { diff --git a/xds/src/test/java/io/grpc/xds/MetadataLoadBalancerProvider.java b/xds/src/test/java/io/grpc/xds/MetadataLoadBalancerProvider.java index ecc0112a2e0..0499bafdb23 100644 --- a/xds/src/test/java/io/grpc/xds/MetadataLoadBalancerProvider.java +++ b/xds/src/test/java/io/grpc/xds/MetadataLoadBalancerProvider.java @@ -107,6 +107,7 @@ protected LoadBalancer delegate() { return delegateLb; } + @Deprecated @Override public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { MetadataLoadBalancerConfig config @@ -114,6 +115,14 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { helper.setMetadata(config.metadataKey, config.metadataValue); delegateLb.handleResolvedAddresses(resolvedAddresses); } + + @Override + public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + MetadataLoadBalancerConfig config + = (MetadataLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); + helper.setMetadata(config.metadataKey, config.metadataValue); + return delegateLb.acceptResolvedAddresses(resolvedAddresses); + } } /** diff --git a/xds/src/test/java/io/grpc/xds/PriorityLoadBalancerProviderTest.java b/xds/src/test/java/io/grpc/xds/PriorityLoadBalancerProviderTest.java index 5d96ed87949..37ea24b2aa9 100644 --- a/xds/src/test/java/io/grpc/xds/PriorityLoadBalancerProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/PriorityLoadBalancerProviderTest.java @@ -16,27 +16,24 @@ package io.grpc.xds; +import static org.junit.Assert.assertThrows; import static org.mockito.Mockito.mock; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.grpc.LoadBalancerProvider; -import io.grpc.internal.ServiceConfigUtil.PolicySelection; +import io.grpc.util.GracefulSwitchLoadBalancer; import io.grpc.xds.PriorityLoadBalancerProvider.PriorityLbConfig; import io.grpc.xds.PriorityLoadBalancerProvider.PriorityLbConfig.PriorityChildConfig; import java.util.List; import java.util.Map; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; /** Tests for {@link PriorityLoadBalancerProvider}. */ @RunWith(JUnit4.class) public class PriorityLoadBalancerProviderTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); @SuppressWarnings("ExpectedExceptionChecker") @Test @@ -45,11 +42,11 @@ public void priorityLbConfig_emptyPriorities() { ImmutableMap.of( "p0", new PriorityChildConfig( - new PolicySelection(mock(LoadBalancerProvider.class), null), true)); + newChildConfig(mock(LoadBalancerProvider.class), null), true)); List priorities = ImmutableList.of(); - thrown.expect(IllegalArgumentException.class); - new PriorityLbConfig(childConfigs, priorities); + assertThrows(IllegalArgumentException.class, + () -> new PriorityLbConfig(childConfigs, priorities)); } @SuppressWarnings("ExpectedExceptionChecker") @@ -59,10 +56,14 @@ public void priorityLbConfig_missingChildConfig() { ImmutableMap.of( "p1", new PriorityChildConfig( - new PolicySelection(mock(LoadBalancerProvider.class), null), true)); + newChildConfig(mock(LoadBalancerProvider.class), null), true)); List priorities = ImmutableList.of("p0", "p1"); - thrown.expect(IllegalArgumentException.class); - new PriorityLbConfig(childConfigs, priorities); + assertThrows(IllegalArgumentException.class, + () -> new PriorityLbConfig(childConfigs, priorities)); + } + + private Object newChildConfig(LoadBalancerProvider provider, Object config) { + return GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig(provider, config); } } diff --git a/xds/src/test/java/io/grpc/xds/PriorityLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/PriorityLoadBalancerTest.java index c037ab5034c..988bc720e45 100644 --- a/xds/src/test/java/io/grpc/xds/PriorityLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/PriorityLoadBalancerTest.java @@ -28,11 +28,13 @@ import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.clearInvocations; import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -52,8 +54,8 @@ import io.grpc.Status; import io.grpc.SynchronizationContext; import io.grpc.internal.FakeClock; -import io.grpc.internal.ServiceConfigUtil.PolicySelection; import io.grpc.internal.TestUtils.StandardLoadBalancerProvider; +import io.grpc.util.GracefulSwitchLoadBalancer; import io.grpc.xds.PriorityLoadBalancerProvider.PriorityLbConfig; import io.grpc.xds.PriorityLoadBalancerProvider.PriorityLbConfig.PriorityChildConfig; import java.net.InetSocketAddress; @@ -69,6 +71,7 @@ import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; import org.mockito.Captor; +import org.mockito.InOrder; import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; @@ -97,6 +100,8 @@ public void uncaughtException(Thread t, Throwable e) { public LoadBalancer newLoadBalancer(Helper helper) { fooHelpers.add(helper); LoadBalancer childBalancer = mock(LoadBalancer.class); + when(childBalancer.acceptResolvedAddresses(any(ResolvedAddresses.class))) + .thenReturn(Status.OK); fooBalancers.add(childBalancer); return childBalancer; } @@ -107,6 +112,8 @@ public LoadBalancer newLoadBalancer(Helper helper) { @Override public LoadBalancer newLoadBalancer(Helper helper) { LoadBalancer childBalancer = mock(LoadBalancer.class); + when(childBalancer.acceptResolvedAddresses(any(ResolvedAddresses.class))) + .thenReturn(Status.OK); barBalancers.add(childBalancer); return childBalancer; } @@ -141,7 +148,107 @@ public void tearDown() { } @Test - public void handleResolvedAddresses() { + public void acceptResolvedAddresses() { + boolean originalFlagVal = PriorityLoadBalancer.enablePriorityLbChildPolicyCache; + PriorityLoadBalancer.enablePriorityLbChildPolicyCache = true; + try { + SocketAddress socketAddress = new InetSocketAddress(8080); + EquivalentAddressGroup eag = new EquivalentAddressGroup(socketAddress); + eag = AddressFilter.setPathFilter(eag, ImmutableList.of("p1")); + List addresses = ImmutableList.of(eag); + Attributes attributes = + Attributes.newBuilder().set(Attributes.Key.create("fakeKey"), "fakeValue").build(); + Object fooConfig0 = new Object(); + PriorityChildConfig priorityChildConfig0 = + new PriorityChildConfig(newChildConfig(fooLbProvider, fooConfig0), true); + Object barConfig0 = new Object(); + PriorityChildConfig priorityChildConfig1 = + new PriorityChildConfig(newChildConfig(barLbProvider, barConfig0), true); + Object fooConfig1 = new Object(); + PriorityChildConfig priorityChildConfig2 = + new PriorityChildConfig(newChildConfig(fooLbProvider, fooConfig1), true); + PriorityLbConfig priorityLbConfig = + new PriorityLbConfig( + ImmutableMap.of("p0", priorityChildConfig0, "p1", priorityChildConfig1, + "p2", priorityChildConfig2), + ImmutableList.of("p0", "p1", "p2")); + Status status = priorityLb.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(addresses) + .setAttributes(attributes) + .setLoadBalancingPolicyConfig(priorityLbConfig) + .build()); + assertThat(status.getCode()).isEqualTo(Status.Code.OK); + assertThat(fooBalancers).hasSize(1); + assertThat(barBalancers).isEmpty(); + LoadBalancer fooBalancer0 = Iterables.getOnlyElement(fooBalancers); + verify(fooBalancer0).acceptResolvedAddresses(resolvedAddressesCaptor.capture()); + ResolvedAddresses addressesReceived = resolvedAddressesCaptor.getValue(); + assertThat(addressesReceived.getAddresses()).isEmpty(); + assertThat(addressesReceived.getAttributes()).isEqualTo(attributes); + assertThat(addressesReceived.getLoadBalancingPolicyConfig()).isEqualTo(fooConfig0); + + // Fail over to p1. + fakeClock.forwardTime(10, TimeUnit.SECONDS); + assertThat(fooBalancers).hasSize(1); + assertThat(barBalancers).hasSize(1); + LoadBalancer barBalancer0 = Iterables.getOnlyElement(barBalancers); + verify(barBalancer0).acceptResolvedAddresses(resolvedAddressesCaptor.capture()); + addressesReceived = resolvedAddressesCaptor.getValue(); + assertThat(Iterables.getOnlyElement(addressesReceived.getAddresses()).getAddresses()) + .containsExactly(socketAddress); + assertThat(addressesReceived.getAttributes()).isEqualTo(attributes); + assertThat(addressesReceived.getLoadBalancingPolicyConfig()).isEqualTo(barConfig0); + + // Fail over to p2. + fakeClock.forwardTime(10, TimeUnit.SECONDS); + assertThat(fooBalancers).hasSize(2); + assertThat(barBalancers).hasSize(1); + LoadBalancer fooBalancer1 = Iterables.getLast(fooBalancers); + verify(fooBalancer1).acceptResolvedAddresses(resolvedAddressesCaptor.capture()); + addressesReceived = resolvedAddressesCaptor.getValue(); + assertThat(addressesReceived.getAddresses()).isEmpty(); + assertThat(addressesReceived.getAttributes()).isEqualTo(attributes); + assertThat(addressesReceived.getLoadBalancingPolicyConfig()).isEqualTo(fooConfig1); + + // New update: p0 and p2 deleted; p1 config changed. + SocketAddress newSocketAddress = new InetSocketAddress(8081); + EquivalentAddressGroup newEag = new EquivalentAddressGroup(newSocketAddress); + newEag = AddressFilter.setPathFilter(newEag, ImmutableList.of("p1")); + List newAddresses = ImmutableList.of(newEag); + Object newBarConfig = new Object(); + PriorityLbConfig newPriorityLbConfig = + new PriorityLbConfig( + ImmutableMap.of("p1", + new PriorityChildConfig(newChildConfig(barLbProvider, newBarConfig), true)), + ImmutableList.of("p1")); + status = priorityLb.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(newAddresses) + .setLoadBalancingPolicyConfig(newPriorityLbConfig) + .build()); + assertThat(status.getCode()).isEqualTo(Status.Code.OK); + assertThat(fooBalancers).hasSize(2); + assertThat(barBalancers).hasSize(1); + verify(barBalancer0, times(2)).acceptResolvedAddresses(resolvedAddressesCaptor.capture()); + addressesReceived = resolvedAddressesCaptor.getValue(); + assertThat(Iterables.getOnlyElement(addressesReceived.getAddresses()).getAddresses()) + .containsExactly(newSocketAddress); + assertThat(addressesReceived.getAttributes()).isEqualTo(Attributes.EMPTY); + assertThat(addressesReceived.getLoadBalancingPolicyConfig()).isEqualTo(newBarConfig); + verify(fooBalancer0, never()).shutdown(); + verify(fooBalancer1, never()).shutdown(); + fakeClock.forwardTime(15, TimeUnit.MINUTES); + verify(fooBalancer0).shutdown(); + verify(fooBalancer1).shutdown(); + verify(barBalancer0, never()).shutdown(); + } finally { + PriorityLoadBalancer.enablePriorityLbChildPolicyCache = originalFlagVal; + } + } + + @Test + public void acceptResolvedAddresses_cacheDisabled() { SocketAddress socketAddress = new InetSocketAddress(8080); EquivalentAddressGroup eag = new EquivalentAddressGroup(socketAddress); eag = AddressFilter.setPathFilter(eag, ImmutableList.of("p1")); @@ -150,28 +257,29 @@ public void handleResolvedAddresses() { Attributes.newBuilder().set(Attributes.Key.create("fakeKey"), "fakeValue").build(); Object fooConfig0 = new Object(); PriorityChildConfig priorityChildConfig0 = - new PriorityChildConfig(new PolicySelection(fooLbProvider, fooConfig0), true); + new PriorityChildConfig(newChildConfig(fooLbProvider, fooConfig0), true); Object barConfig0 = new Object(); PriorityChildConfig priorityChildConfig1 = - new PriorityChildConfig(new PolicySelection(barLbProvider, barConfig0), true); + new PriorityChildConfig(newChildConfig(barLbProvider, barConfig0), true); Object fooConfig1 = new Object(); PriorityChildConfig priorityChildConfig2 = - new PriorityChildConfig(new PolicySelection(fooLbProvider, fooConfig1), true); + new PriorityChildConfig(newChildConfig(fooLbProvider, fooConfig1), true); PriorityLbConfig priorityLbConfig = new PriorityLbConfig( ImmutableMap.of("p0", priorityChildConfig0, "p1", priorityChildConfig1, "p2", priorityChildConfig2), ImmutableList.of("p0", "p1", "p2")); - priorityLb.handleResolvedAddresses( + Status status = priorityLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(addresses) .setAttributes(attributes) .setLoadBalancingPolicyConfig(priorityLbConfig) .build()); + assertThat(status.getCode()).isEqualTo(Status.Code.OK); assertThat(fooBalancers).hasSize(1); assertThat(barBalancers).isEmpty(); LoadBalancer fooBalancer0 = Iterables.getOnlyElement(fooBalancers); - verify(fooBalancer0).handleResolvedAddresses(resolvedAddressesCaptor.capture()); + verify(fooBalancer0).acceptResolvedAddresses(resolvedAddressesCaptor.capture()); ResolvedAddresses addressesReceived = resolvedAddressesCaptor.getValue(); assertThat(addressesReceived.getAddresses()).isEmpty(); assertThat(addressesReceived.getAttributes()).isEqualTo(attributes); @@ -182,7 +290,7 @@ public void handleResolvedAddresses() { assertThat(fooBalancers).hasSize(1); assertThat(barBalancers).hasSize(1); LoadBalancer barBalancer0 = Iterables.getOnlyElement(barBalancers); - verify(barBalancer0).handleResolvedAddresses(resolvedAddressesCaptor.capture()); + verify(barBalancer0).acceptResolvedAddresses(resolvedAddressesCaptor.capture()); addressesReceived = resolvedAddressesCaptor.getValue(); assertThat(Iterables.getOnlyElement(addressesReceived.getAddresses()).getAddresses()) .containsExactly(socketAddress); @@ -194,7 +302,7 @@ public void handleResolvedAddresses() { assertThat(fooBalancers).hasSize(2); assertThat(barBalancers).hasSize(1); LoadBalancer fooBalancer1 = Iterables.getLast(fooBalancers); - verify(fooBalancer1).handleResolvedAddresses(resolvedAddressesCaptor.capture()); + verify(fooBalancer1).acceptResolvedAddresses(resolvedAddressesCaptor.capture()); addressesReceived = resolvedAddressesCaptor.getValue(); assertThat(addressesReceived.getAddresses()).isEmpty(); assertThat(addressesReceived.getAttributes()).isEqualTo(attributes); @@ -209,84 +317,142 @@ public void handleResolvedAddresses() { PriorityLbConfig newPriorityLbConfig = new PriorityLbConfig( ImmutableMap.of("p1", - new PriorityChildConfig(new PolicySelection(barLbProvider, newBarConfig), true)), + new PriorityChildConfig(newChildConfig(barLbProvider, newBarConfig), true)), ImmutableList.of("p1")); - priorityLb.handleResolvedAddresses( + status = priorityLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(newAddresses) .setLoadBalancingPolicyConfig(newPriorityLbConfig) .build()); + assertThat(status.getCode()).isEqualTo(Status.Code.OK); assertThat(fooBalancers).hasSize(2); assertThat(barBalancers).hasSize(1); - verify(barBalancer0, times(2)).handleResolvedAddresses(resolvedAddressesCaptor.capture()); + verify(barBalancer0, times(2)).acceptResolvedAddresses(resolvedAddressesCaptor.capture()); addressesReceived = resolvedAddressesCaptor.getValue(); assertThat(Iterables.getOnlyElement(addressesReceived.getAddresses()).getAddresses()) .containsExactly(newSocketAddress); assertThat(addressesReceived.getAttributes()).isEqualTo(Attributes.EMPTY); assertThat(addressesReceived.getLoadBalancingPolicyConfig()).isEqualTo(newBarConfig); - verify(fooBalancer0, never()).shutdown(); - verify(fooBalancer1, never()).shutdown(); - fakeClock.forwardTime(15, TimeUnit.MINUTES); verify(fooBalancer0).shutdown(); verify(fooBalancer1).shutdown(); verify(barBalancer0, never()).shutdown(); } @Test - public void handleNameResolutionError() { - Object fooConfig0 = new Object(); - PriorityChildConfig priorityChildConfig0 = - new PriorityChildConfig(new PolicySelection(fooLbProvider, fooConfig0), true); - Object fooConfig1 = new Object(); - PriorityChildConfig priorityChildConfig1 = - new PriorityChildConfig(new PolicySelection(fooLbProvider, fooConfig1), true); - - PriorityLbConfig priorityLbConfig = - new PriorityLbConfig(ImmutableMap.of("p0", priorityChildConfig0), ImmutableList.of("p0")); - priorityLb.handleResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(ImmutableList.of()) - .setLoadBalancingPolicyConfig(priorityLbConfig) - .build()); - LoadBalancer fooLb0 = Iterables.getOnlyElement(fooBalancers); - Status status = Status.DATA_LOSS.withDescription("fake error"); - priorityLb.handleNameResolutionError(status); - verify(fooLb0).handleNameResolutionError(status); + public void acceptResolvedAddresses_propagatesChildFailures() { + LoadBalancerProvider lbProvider = new CannedLoadBalancer.Provider(); + CannedLoadBalancer.Config internalTf = new CannedLoadBalancer.Config( + Status.INTERNAL, TRANSIENT_FAILURE); + CannedLoadBalancer.Config okTf = new CannedLoadBalancer.Config(Status.OK, TRANSIENT_FAILURE); + ResolvedAddresses resolvedAddresses = ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of()) + .setAttributes(Attributes.EMPTY) + .build(); + + // tryNewPriority() propagates status + Status status = priorityLb.acceptResolvedAddresses( + resolvedAddresses.toBuilder() + .setLoadBalancingPolicyConfig(new PriorityLbConfig( + ImmutableMap.of( + "p0", newPriorityChildConfig(lbProvider, internalTf, true)), + ImmutableList.of("p0"))) + .build()); + assertThat(status.getCode()).isNotEqualTo(Status.Code.OK); + + // Updating a child propagates status + status = priorityLb.acceptResolvedAddresses( + resolvedAddresses.toBuilder() + .setLoadBalancingPolicyConfig(new PriorityLbConfig( + ImmutableMap.of( + "p0", newPriorityChildConfig(lbProvider, internalTf, true)), + ImmutableList.of("p0"))) + .build()); + assertThat(status.getCode()).isNotEqualTo(Status.Code.OK); + + // A single pre-existing child failure propagates + status = priorityLb.acceptResolvedAddresses( + resolvedAddresses.toBuilder() + .setLoadBalancingPolicyConfig(new PriorityLbConfig( + ImmutableMap.of( + "p0", newPriorityChildConfig(lbProvider, okTf, true), + "p1", newPriorityChildConfig(lbProvider, okTf, true), + "p2", newPriorityChildConfig(lbProvider, okTf, true)), + ImmutableList.of("p0", "p1", "p2"))) + .build()); + assertThat(status.getCode()).isEqualTo(Status.Code.OK); + status = priorityLb.acceptResolvedAddresses( + resolvedAddresses.toBuilder() + .setLoadBalancingPolicyConfig(new PriorityLbConfig( + ImmutableMap.of( + "p0", newPriorityChildConfig(lbProvider, okTf, true), + "p1", newPriorityChildConfig(lbProvider, internalTf, true), + "p2", newPriorityChildConfig(lbProvider, okTf, true)), + ImmutableList.of("p0", "p1", "p2"))) + .build()); + assertThat(status.getCode()).isNotEqualTo(Status.Code.OK); + } - priorityLbConfig = - new PriorityLbConfig(ImmutableMap.of("p1", priorityChildConfig1), ImmutableList.of("p1")); - priorityLb.handleResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(ImmutableList.of()) - .setLoadBalancingPolicyConfig(priorityLbConfig) - .build()); - assertThat(fooBalancers).hasSize(2); - LoadBalancer fooLb1 = Iterables.getLast(fooBalancers); - status = Status.UNAVAILABLE.withDescription("fake error"); - priorityLb.handleNameResolutionError(status); - // fooLb0 is deactivated but not yet deleted. However, because it is delisted by the latest - // address update, name resolution error will not be propagated to it. - verify(fooLb0, never()).shutdown(); - verify(fooLb0, never()).handleNameResolutionError(status); - verify(fooLb1).handleNameResolutionError(status); + @Test + public void handleNameResolutionError() { + boolean originalFlagVal = PriorityLoadBalancer.enablePriorityLbChildPolicyCache; + PriorityLoadBalancer.enablePriorityLbChildPolicyCache = true; + try { + Object fooConfig0 = new Object(); + PriorityChildConfig priorityChildConfig0 = + new PriorityChildConfig(newChildConfig(fooLbProvider, fooConfig0), true); + Object fooConfig1 = new Object(); + PriorityChildConfig priorityChildConfig1 = + new PriorityChildConfig(newChildConfig(fooLbProvider, fooConfig1), true); + + PriorityLbConfig priorityLbConfig = + new PriorityLbConfig(ImmutableMap.of("p0", priorityChildConfig0), ImmutableList.of("p0")); + priorityLb.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of()) + .setLoadBalancingPolicyConfig(priorityLbConfig) + .build()); + LoadBalancer fooLb0 = Iterables.getOnlyElement(fooBalancers); + Status status = Status.DATA_LOSS.withDescription("fake error"); + priorityLb.handleNameResolutionError(status); + verify(fooLb0).handleNameResolutionError(status); + + priorityLbConfig = + new PriorityLbConfig(ImmutableMap.of("p1", priorityChildConfig1), ImmutableList.of("p1")); + priorityLb.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of()) + .setLoadBalancingPolicyConfig(priorityLbConfig) + .build()); + assertThat(fooBalancers).hasSize(2); + LoadBalancer fooLb1 = Iterables.getLast(fooBalancers); + status = Status.UNAVAILABLE.withDescription("fake error"); + priorityLb.handleNameResolutionError(status); + // fooLb0 is deactivated but not yet deleted. However, because it is delisted by the latest + // address update, name resolution error will not be propagated to it. + verify(fooLb0, never()).shutdown(); + verify(fooLb0, never()).handleNameResolutionError(status); + verify(fooLb1).handleNameResolutionError(status); + } finally { + PriorityLoadBalancer.enablePriorityLbChildPolicyCache = originalFlagVal; + } } @Test public void typicalPriorityFailOverFlow() { PriorityChildConfig priorityChildConfig0 = - new PriorityChildConfig(new PolicySelection(fooLbProvider, new Object()), true); + new PriorityChildConfig(newChildConfig(fooLbProvider, new Object()), true); PriorityChildConfig priorityChildConfig1 = - new PriorityChildConfig(new PolicySelection(fooLbProvider, new Object()), true); + new PriorityChildConfig(newChildConfig(fooLbProvider, new Object()), true); PriorityChildConfig priorityChildConfig2 = - new PriorityChildConfig(new PolicySelection(fooLbProvider, new Object()), true); + new PriorityChildConfig(newChildConfig(fooLbProvider, new Object()), true); PriorityChildConfig priorityChildConfig3 = - new PriorityChildConfig(new PolicySelection(fooLbProvider, new Object()), true); + new PriorityChildConfig(newChildConfig(fooLbProvider, new Object()), true); PriorityLbConfig priorityLbConfig = new PriorityLbConfig( ImmutableMap.of("p0", priorityChildConfig0, "p1", priorityChildConfig1, "p2", priorityChildConfig2, "p3", priorityChildConfig3), ImmutableList.of("p0", "p1", "p2", "p3")); - priorityLb.handleResolvedAddresses( + priorityLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(priorityLbConfig) @@ -315,6 +481,7 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { assertThat(fooBalancers).hasSize(2); assertThat(fooHelpers).hasSize(2); LoadBalancer balancer1 = Iterables.getLast(fooBalancers); + Helper helper1 = Iterables.getLast(fooHelpers); // p1 timeout, and fails over to p2 fakeClock.forwardTime(10, TimeUnit.SECONDS); @@ -362,14 +529,20 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { LoadBalancer balancer3 = Iterables.getLast(fooBalancers); Helper helper3 = Iterables.getLast(fooHelpers); - // p3 timeout then the channel should go to TRANSIENT_FAILURE + // p3 timeout then the channel should stay in CONNECTING fakeClock.forwardTime(10, TimeUnit.SECONDS); - assertCurrentPickerReturnsError(Status.Code.UNAVAILABLE, "timeout"); + assertCurrentPicker(CONNECTING, PickResult.withNoResult()); - // p3 fails then the picker should have error status updated + // p3 fails then the picker should still be waiting on p1 helper3.updateBalancingState( TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(Status.DATA_LOSS.withDescription("foo")))); + assertCurrentPicker(CONNECTING, PickResult.withNoResult()); + + // p1 fails then the picker should have error status updated to p3 + helper1.updateBalancingState( + TRANSIENT_FAILURE, + new FixedResultPicker(PickResult.withError(Status.DATA_LOSS.withDescription("bar")))); assertCurrentPickerReturnsError(Status.Code.DATA_LOSS, "foo"); // p2 gets back to READY @@ -412,14 +585,14 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { @Test public void idleToConnectingDoesNotTriggerFailOver() { PriorityChildConfig priorityChildConfig0 = - new PriorityChildConfig(new PolicySelection(fooLbProvider, new Object()), true); + new PriorityChildConfig(newChildConfig(fooLbProvider, new Object()), true); PriorityChildConfig priorityChildConfig1 = - new PriorityChildConfig(new PolicySelection(fooLbProvider, new Object()), true); + new PriorityChildConfig(newChildConfig(fooLbProvider, new Object()), true); PriorityLbConfig priorityLbConfig = new PriorityLbConfig( ImmutableMap.of("p0", priorityChildConfig0, "p1", priorityChildConfig1), ImmutableList.of("p0", "p1")); - priorityLb.handleResolvedAddresses( + priorityLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(priorityLbConfig) @@ -448,21 +621,20 @@ public void idleToConnectingDoesNotTriggerFailOver() { @Test public void connectingResetFailOverIfSeenReadyOrIdleSinceTransientFailure() { PriorityChildConfig priorityChildConfig0 = - new PriorityChildConfig(new PolicySelection(fooLbProvider, new Object()), true); + new PriorityChildConfig(newChildConfig(fooLbProvider, new Object()), true); PriorityChildConfig priorityChildConfig1 = - new PriorityChildConfig(new PolicySelection(fooLbProvider, new Object()), true); + new PriorityChildConfig(newChildConfig(fooLbProvider, new Object()), true); PriorityLbConfig priorityLbConfig = new PriorityLbConfig( ImmutableMap.of("p0", priorityChildConfig0, "p1", priorityChildConfig1), ImmutableList.of("p0", "p1")); - priorityLb.handleResolvedAddresses( + priorityLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(priorityLbConfig) .build()); // Nothing important about this verify, other than to provide a baseline - verify(helper, times(2)) - .updateBalancingState(eq(CONNECTING), pickerReturns(PickResult.withNoResult())); + verify(helper).updateBalancingState(eq(CONNECTING), pickerReturns(PickResult.withNoResult())); assertThat(fooBalancers).hasSize(1); assertThat(fooHelpers).hasSize(1); Helper helper0 = Iterables.getOnlyElement(fooHelpers); @@ -478,7 +650,7 @@ public void connectingResetFailOverIfSeenReadyOrIdleSinceTransientFailure() { helper0.updateBalancingState( CONNECTING, EMPTY_PICKER); - verify(helper, times(3)) + verify(helper, times(2)) .updateBalancingState(eq(CONNECTING), pickerReturns(PickResult.withNoResult())); // failover happens @@ -487,17 +659,66 @@ public void connectingResetFailOverIfSeenReadyOrIdleSinceTransientFailure() { assertThat(fooHelpers).hasSize(2); } + @Test + public void failoverTimerNotRestartedOnDupConnecting() { + InOrder inOrder = inOrder(helper); + PriorityChildConfig priorityChildConfig0 = + new PriorityChildConfig(newChildConfig(fooLbProvider, new Object()), true); + PriorityChildConfig priorityChildConfig1 = + new PriorityChildConfig(newChildConfig(fooLbProvider, new Object()), true); + PriorityLbConfig priorityLbConfig = + new PriorityLbConfig( + ImmutableMap.of("p0", priorityChildConfig0, "p1", priorityChildConfig1), + ImmutableList.of("p0", "p1")); + priorityLb.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of()) + .setLoadBalancingPolicyConfig(priorityLbConfig) + .build()); + // Nothing important about this verify, other than to provide a baseline + inOrder.verify(helper) + .updateBalancingState(eq(CONNECTING), pickerReturns(PickResult.withNoResult())); + assertThat(fooBalancers).hasSize(1); + assertThat(fooHelpers).hasSize(1); + Helper helper0 = Iterables.getOnlyElement(fooHelpers); + + // Cause seenReadyOrIdleSinceTransientFailure = true + helper0.updateBalancingState(IDLE, EMPTY_PICKER); + inOrder.verify(helper) + .updateBalancingState(eq(IDLE), pickerReturns(PickResult.withNoResult())); + helper0.updateBalancingState(CONNECTING, EMPTY_PICKER); + + // p0 keeps repeating CONNECTING, failover happens + fakeClock.forwardTime(5, TimeUnit.SECONDS); + helper0.updateBalancingState(CONNECTING, EMPTY_PICKER); + fakeClock.forwardTime(5, TimeUnit.SECONDS); + assertThat(fooBalancers).hasSize(2); + assertThat(fooHelpers).hasSize(2); + inOrder.verify(helper, times(2)) + .updateBalancingState(eq(CONNECTING), pickerReturns(PickResult.withNoResult())); + Helper helper1 = Iterables.getLast(fooHelpers); + + // p0 keeps repeating CONNECTING, no reset of failover timer + helper1.updateBalancingState(IDLE, EMPTY_PICKER); // Stop timer for p1 + inOrder.verify(helper) + .updateBalancingState(eq(IDLE), pickerReturns(PickResult.withNoResult())); + helper0.updateBalancingState(CONNECTING, EMPTY_PICKER); + fakeClock.forwardTime(10, TimeUnit.SECONDS); + inOrder.verify(helper, never()) + .updateBalancingState(eq(CONNECTING), any()); + } + @Test public void readyToConnectDoesNotFailOverButUpdatesPicker() { PriorityChildConfig priorityChildConfig0 = - new PriorityChildConfig(new PolicySelection(fooLbProvider, new Object()), true); + new PriorityChildConfig(newChildConfig(fooLbProvider, new Object()), true); PriorityChildConfig priorityChildConfig1 = - new PriorityChildConfig(new PolicySelection(fooLbProvider, new Object()), true); + new PriorityChildConfig(newChildConfig(fooLbProvider, new Object()), true); PriorityLbConfig priorityLbConfig = new PriorityLbConfig( ImmutableMap.of("p0", priorityChildConfig0, "p1", priorityChildConfig1), ImmutableList.of("p0", "p1")); - priorityLb.handleResolvedAddresses( + priorityLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(priorityLbConfig) @@ -530,7 +751,7 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { // resolution update without priority change does not trigger failover Attributes.Key fooKey = Attributes.Key.create("fooKey"); - priorityLb.handleResolvedAddresses( + priorityLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(priorityLbConfig) @@ -547,19 +768,19 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { @Test public void typicalPriorityFailOverFlowWithIdleUpdate() { PriorityChildConfig priorityChildConfig0 = - new PriorityChildConfig(new PolicySelection(fooLbProvider, new Object()), true); + new PriorityChildConfig(newChildConfig(fooLbProvider, new Object()), true); PriorityChildConfig priorityChildConfig1 = - new PriorityChildConfig(new PolicySelection(fooLbProvider, new Object()), true); + new PriorityChildConfig(newChildConfig(fooLbProvider, new Object()), true); PriorityChildConfig priorityChildConfig2 = - new PriorityChildConfig(new PolicySelection(fooLbProvider, new Object()), true); + new PriorityChildConfig(newChildConfig(fooLbProvider, new Object()), true); PriorityChildConfig priorityChildConfig3 = - new PriorityChildConfig(new PolicySelection(fooLbProvider, new Object()), true); + new PriorityChildConfig(newChildConfig(fooLbProvider, new Object()), true); PriorityLbConfig priorityLbConfig = new PriorityLbConfig( ImmutableMap.of("p0", priorityChildConfig0, "p1", priorityChildConfig1, "p2", priorityChildConfig2, "p3", priorityChildConfig3), ImmutableList.of("p0", "p1", "p2", "p3")); - priorityLb.handleResolvedAddresses( + priorityLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(priorityLbConfig) @@ -582,6 +803,7 @@ public void typicalPriorityFailOverFlowWithIdleUpdate() { assertThat(fooBalancers).hasSize(2); assertThat(fooHelpers).hasSize(2); LoadBalancer balancer1 = Iterables.getLast(fooBalancers); + Helper helper1 = Iterables.getLast(fooHelpers); // p1 timeout, and fails over to p2 fakeClock.forwardTime(10, TimeUnit.SECONDS); @@ -617,14 +839,20 @@ public void typicalPriorityFailOverFlowWithIdleUpdate() { LoadBalancer balancer3 = Iterables.getLast(fooBalancers); Helper helper3 = Iterables.getLast(fooHelpers); - // p3 timeout then the channel should go to TRANSIENT_FAILURE + // p3 timeout then the channel should stay in CONNECTING fakeClock.forwardTime(10, TimeUnit.SECONDS); - assertCurrentPickerReturnsError(Status.Code.UNAVAILABLE, "timeout"); + assertCurrentPicker(CONNECTING, PickResult.withNoResult()); - // p3 fails then the picker should have error status updated + // p3 fails then the picker should still be waiting on p1 helper3.updateBalancingState( TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(Status.DATA_LOSS.withDescription("foo")))); + assertCurrentPicker(CONNECTING, PickResult.withNoResult()); + + // p1 fails then the picker should have error status updated to p3 + helper1.updateBalancingState( + TRANSIENT_FAILURE, + new FixedResultPicker(PickResult.withError(Status.DATA_LOSS.withDescription("bar")))); assertCurrentPickerReturnsError(Status.Code.DATA_LOSS, "foo"); // p2 gets back to IDLE @@ -652,17 +880,66 @@ public void typicalPriorityFailOverFlowWithIdleUpdate() { verify(balancer3).shutdown(); } + @Test + public void failover_propagatesChildFailures() { + LoadBalancerProvider lbProvider = new CannedLoadBalancer.Provider(); + ResolvedAddresses resolvedAddresses = ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of()) + .setAttributes(Attributes.EMPTY) + .build(); + + Status status = priorityLb.acceptResolvedAddresses( + resolvedAddresses.toBuilder() + .setLoadBalancingPolicyConfig(new PriorityLbConfig( + ImmutableMap.of( + "p0", newPriorityChildConfig( + lbProvider, new CannedLoadBalancer.Config(Status.OK, TRANSIENT_FAILURE), true), + "p1", newPriorityChildConfig( + lbProvider, new CannedLoadBalancer.Config(Status.INTERNAL, CONNECTING), true)), + ImmutableList.of("p0", "p1"))) + .build()); + // Since P1's activation wasn't noticed by the result status, it triggered name resolution + assertThat(status.getCode()).isEqualTo(Status.Code.OK); + verify(helper).refreshNameResolution(); + } + + @Test + public void failoverTimer_propagatesChildFailures() { + LoadBalancerProvider lbProvider = new CannedLoadBalancer.Provider(); + ResolvedAddresses resolvedAddresses = ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of()) + .setAttributes(Attributes.EMPTY) + .build(); + + Status status = priorityLb.acceptResolvedAddresses( + resolvedAddresses.toBuilder() + .setLoadBalancingPolicyConfig(new PriorityLbConfig( + ImmutableMap.of( + "p0", newPriorityChildConfig( + lbProvider, new CannedLoadBalancer.Config(Status.OK, CONNECTING), true), + "p1", newPriorityChildConfig( + lbProvider, new CannedLoadBalancer.Config(Status.INTERNAL, CONNECTING), true)), + ImmutableList.of("p0", "p1"))) + .build()); + assertThat(status.getCode()).isEqualTo(Status.Code.OK); + + // P1's activation will refresh name resolution + verify(helper, never()).refreshNameResolution(); + fakeClock.forwardTime(10, TimeUnit.SECONDS); + verify(helper).refreshNameResolution(); + } + @Test public void bypassReresolutionRequestsIfConfiged() { PriorityChildConfig priorityChildConfig0 = - new PriorityChildConfig(new PolicySelection(fooLbProvider, new Object()), true); + new PriorityChildConfig(newChildConfig(fooLbProvider, new Object()), true); PriorityChildConfig priorityChildConfig1 = - new PriorityChildConfig(new PolicySelection(fooLbProvider, new Object()), false); + new PriorityChildConfig(newChildConfig(fooLbProvider, new Object()), false); PriorityLbConfig priorityLbConfig = new PriorityLbConfig( ImmutableMap.of("p0", priorityChildConfig0, "p1", priorityChildConfig1), ImmutableList.of("p0", "p1")); - priorityLb.handleResolvedAddresses( + priorityLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(priorityLbConfig) @@ -683,19 +960,19 @@ public void bypassReresolutionRequestsIfConfiged() { @Test public void raceBetweenShutdownAndChildLbBalancingStateUpdate() { PriorityChildConfig priorityChildConfig0 = - new PriorityChildConfig(new PolicySelection(fooLbProvider, new Object()), true); + new PriorityChildConfig(newChildConfig(fooLbProvider, new Object()), true); PriorityChildConfig priorityChildConfig1 = - new PriorityChildConfig(new PolicySelection(fooLbProvider, new Object()), false); + new PriorityChildConfig(newChildConfig(fooLbProvider, new Object()), false); PriorityLbConfig priorityLbConfig = new PriorityLbConfig( ImmutableMap.of("p0", priorityChildConfig0, "p1", priorityChildConfig1), ImmutableList.of("p0", "p1")); - priorityLb.handleResolvedAddresses( + priorityLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(priorityLbConfig) .build()); - verify(helper, times(2)).updateBalancingState(eq(CONNECTING), isA(SubchannelPicker.class)); + verify(helper).updateBalancingState(eq(CONNECTING), isA(SubchannelPicker.class)); // LB shutdown and subchannel state change can happen simultaneously. If shutdown runs first, // any further balancing state update should be ignored. @@ -710,14 +987,14 @@ public void noDuplicateOverallBalancingStateUpdate() { FakeLoadBalancerProvider fakeLbProvider = new FakeLoadBalancerProvider(); PriorityChildConfig priorityChildConfig0 = - new PriorityChildConfig(new PolicySelection(fakeLbProvider, new Object()), true); + new PriorityChildConfig(newChildConfig(fakeLbProvider, new Object()), true); PriorityChildConfig priorityChildConfig1 = - new PriorityChildConfig(new PolicySelection(fakeLbProvider, new Object()), false); + new PriorityChildConfig(newChildConfig(fakeLbProvider, new Object()), false); PriorityLbConfig priorityLbConfig = new PriorityLbConfig( ImmutableMap.of("p0", priorityChildConfig0), ImmutableList.of("p0")); - priorityLb.handleResolvedAddresses( + priorityLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(priorityLbConfig) @@ -727,13 +1004,13 @@ public void noDuplicateOverallBalancingStateUpdate() { new PriorityLbConfig( ImmutableMap.of("p0", priorityChildConfig0, "p1", priorityChildConfig1), ImmutableList.of("p0", "p1")); - priorityLb.handleResolvedAddresses( + priorityLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(priorityLbConfig) .build()); - verify(helper, times(6)).updateBalancingState(any(), any()); + verify(helper, times(4)).updateBalancingState(any(), any()); } private void assertLatestConnectivityState(ConnectivityState expectedState) { @@ -754,15 +1031,26 @@ private void assertCurrentPickerReturnsError( } private void assertCurrentPickerPicksSubchannel(Subchannel expectedSubchannelToPick) { - assertLatestConnectivityState(READY); - PickResult pickResult = pickerCaptor.getValue().pickSubchannel(mock(PickSubchannelArgs.class)); - assertThat(pickResult.getSubchannel()).isEqualTo(expectedSubchannelToPick); + assertCurrentPicker(READY, PickResult.withSubchannel(expectedSubchannelToPick)); } private void assertCurrentPickerIsBufferPicker() { - assertLatestConnectivityState(IDLE); + assertCurrentPicker(IDLE, PickResult.withNoResult()); + } + + private void assertCurrentPicker(ConnectivityState state, PickResult result) { + assertLatestConnectivityState(state); PickResult pickResult = pickerCaptor.getValue().pickSubchannel(mock(PickSubchannelArgs.class)); - assertThat(pickResult).isEqualTo(PickResult.withNoResult()); + assertThat(pickResult).isEqualTo(result); + } + + private Object newChildConfig(LoadBalancerProvider provider, Object config) { + return GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig(provider, config); + } + + private PriorityChildConfig newPriorityChildConfig( + LoadBalancerProvider provider, Object config, boolean ignoreRefresh) { + return new PriorityChildConfig(newChildConfig(provider, config), ignoreRefresh); } private static class FakeLoadBalancerProvider extends LoadBalancerProvider { @@ -797,9 +1085,10 @@ static class FakeLoadBalancer extends LoadBalancer { } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { helper.updateBalancingState( TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(Status.INTERNAL))); + return Status.OK; } @Override @@ -810,4 +1099,47 @@ public void handleNameResolutionError(Status error) { public void shutdown() { } } + + static final class CannedLoadBalancer extends LoadBalancer { + private final Helper helper; + + private CannedLoadBalancer(Helper helper) { + this.helper = helper; + } + + @Override + public Status acceptResolvedAddresses(ResolvedAddresses addresses) { + Config config = (Config) addresses.getLoadBalancingPolicyConfig(); + helper.updateBalancingState( + config.state, new FixedResultPicker(PickResult.withError(Status.INTERNAL))); + return config.resolvedAddressesResult; + } + + @Override + public void handleNameResolutionError(Status status) {} + + @Override + public void shutdown() {} + + static final class Provider extends StandardLoadBalancerProvider { + public Provider() { + super("echo"); + } + + @Override + public LoadBalancer newLoadBalancer(Helper helper) { + return new CannedLoadBalancer(helper); + } + } + + static final class Config { + final Status resolvedAddressesResult; + final ConnectivityState state; + + public Config(Status resolvedAddressesResult, ConnectivityState state) { + this.resolvedAddressesResult = resolvedAddressesResult; + this.state = state; + } + } + } } diff --git a/xds/src/test/java/io/grpc/xds/RbacFilterTest.java b/xds/src/test/java/io/grpc/xds/RbacFilterTest.java index 29af01b222f..0f6920b18eb 100644 --- a/xds/src/test/java/io/grpc/xds/RbacFilterTest.java +++ b/xds/src/test/java/io/grpc/xds/RbacFilterTest.java @@ -54,6 +54,9 @@ import io.grpc.Status; import io.grpc.testing.TestMethodDescriptors; import io.grpc.xds.Filter.FilterConfig; +import io.grpc.xds.client.Bootstrapper.BootstrapInfo; +import io.grpc.xds.client.Bootstrapper.ServerInfo; +import io.grpc.xds.client.EnvoyProtoData.Node; import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine; import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.AlwaysTrueMatcher; import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.AuthConfig; @@ -78,6 +81,15 @@ public class RbacFilterTest { private static final String PATH = "auth"; private static final StringMatcher STRING_MATCHER = StringMatcher.newBuilder().setExact("/" + PATH).setIgnoreCase(true).build(); + private static final RbacFilter.Provider FILTER_PROVIDER = new RbacFilter.Provider(); + + private final String name = "theFilterName"; + + @Test + public void filterType_serverOnly() { + assertThat(FILTER_PROVIDER.isClientFilter()).isFalse(); + assertThat(FILTER_PROVIDER.isServerFilter()).isTrue(); + } @Test @SuppressWarnings({"unchecked", "deprecation"}) @@ -111,7 +123,7 @@ public void ipPortParser() { } @Test - @SuppressWarnings({"unchecked", "deprecation"}) + @SuppressWarnings("unchecked") public void portRangeParser() { List permissionList = Arrays.asList( Permission.newBuilder().setDestinationPortRange( @@ -219,14 +231,15 @@ public void headerParser_headerName() { @SuppressWarnings("unchecked") public void compositeRules() { MetadataMatcher metadataMatcher = MetadataMatcher.newBuilder().build(); + @SuppressWarnings("deprecation") + Permission permissionMetadata = Permission.newBuilder().setMetadata(metadataMatcher).build(); List permissionList = Arrays.asList( Permission.newBuilder().setOrRules(Permission.Set.newBuilder().addRules( - Permission.newBuilder().setMetadata(metadataMatcher).build() - ).build()).build()); + permissionMetadata).build()).build()); + @SuppressWarnings("deprecation") + Principal principalMetadata = Principal.newBuilder().setMetadata(metadataMatcher).build(); List principalList = Arrays.asList( - Principal.newBuilder().setNotId( - Principal.newBuilder().setMetadata(metadataMatcher).build() - ).build()); + Principal.newBuilder().setNotId(principalMetadata).build()); ConfigOrError result = parse(permissionList, principalList); assertThat(result.errorDetail).isNull(); assertThat(result.config).isInstanceOf(RbacConfig.class); @@ -251,7 +264,7 @@ public void testAuthorizationInterceptor() { OrMatcher.create(AlwaysTrueMatcher.INSTANCE)); AuthConfig authconfig = AuthConfig.create(Collections.singletonList(policyMatcher), GrpcAuthorizationEngine.Action.ALLOW); - new RbacFilter().buildServerInterceptor(RbacConfig.create(authconfig), null) + FILTER_PROVIDER.newInstance(name).buildServerInterceptor(RbacConfig.create(authconfig), null) .interceptCall(mockServerCall, new Metadata(), mockHandler); verify(mockHandler, never()).startCall(eq(mockServerCall), any(Metadata.class)); ArgumentCaptor captor = ArgumentCaptor.forClass(Status.class); @@ -263,7 +276,7 @@ public void testAuthorizationInterceptor() { authconfig = AuthConfig.create(Collections.singletonList(policyMatcher), GrpcAuthorizationEngine.Action.DENY); - new RbacFilter().buildServerInterceptor(RbacConfig.create(authconfig), null) + FILTER_PROVIDER.newInstance(name).buildServerInterceptor(RbacConfig.create(authconfig), null) .interceptCall(mockServerCall, new Metadata(), mockHandler); verify(mockHandler).startCall(eq(mockServerCall), any(Metadata.class)); } @@ -289,7 +302,7 @@ public void handleException() { .putPolicies("policy-name", Policy.newBuilder().setCondition(Expr.newBuilder().build()).build()) .build()).build(); - result = new RbacFilter().parseFilterConfig(Any.pack(rawProto)); + result = FILTER_PROVIDER.parseFilterConfig(Any.pack(rawProto), getFilterContext()); assertThat(result.errorDetail).isNotNull(); } @@ -311,10 +324,11 @@ public void overrideConfig() { RbacConfig original = RbacConfig.create(authconfig); RBACPerRoute rbacPerRoute = RBACPerRoute.newBuilder().build(); - RbacConfig override = - new RbacFilter().parseFilterConfigOverride(Any.pack(rbacPerRoute)).config; + RbacConfig override = FILTER_PROVIDER.parseFilterConfigOverride(Any.pack(rbacPerRoute), + getFilterContext()).config; assertThat(override).isEqualTo(RbacConfig.create(null)); - ServerInterceptor interceptor = new RbacFilter().buildServerInterceptor(original, override); + ServerInterceptor interceptor = + FILTER_PROVIDER.newInstance(name).buildServerInterceptor(original, override); assertThat(interceptor).isNull(); policyMatcher = PolicyMatcher.create("policy-matcher-override", @@ -324,7 +338,7 @@ public void overrideConfig() { GrpcAuthorizationEngine.Action.ALLOW); override = RbacConfig.create(authconfig); - new RbacFilter().buildServerInterceptor(original, override) + FILTER_PROVIDER.newInstance(name).buildServerInterceptor(original, override) .interceptCall(mockServerCall, new Metadata(), mockHandler); verify(mockHandler).startCall(eq(mockServerCall), any(Metadata.class)); verify(mockServerCall).getAttributes(); @@ -336,22 +350,26 @@ public void ignoredConfig() { Message rawProto = io.envoyproxy.envoy.extensions.filters.http.rbac.v3.RBAC.newBuilder() .setRules(RBAC.newBuilder().setAction(Action.LOG) .putPolicies("policy-name", Policy.newBuilder().build()).build()).build(); - ConfigOrError result = new RbacFilter().parseFilterConfig(Any.pack(rawProto)); + ConfigOrError result = + FILTER_PROVIDER.parseFilterConfig(Any.pack(rawProto), getFilterContext()); assertThat(result.config).isEqualTo(RbacConfig.create(null)); } @Test public void testOrderIndependenceOfPolicies() { Message rawProto = buildComplexRbac(ImmutableList.of(1, 2, 3, 4, 5, 6), true); - ConfigOrError ascFirst = new RbacFilter().parseFilterConfig(Any.pack(rawProto)); + ConfigOrError ascFirst = + FILTER_PROVIDER.parseFilterConfig(Any.pack(rawProto), getFilterContext()); rawProto = buildComplexRbac(ImmutableList.of(1, 2, 3, 4, 5, 6), false); - ConfigOrError ascLast = new RbacFilter().parseFilterConfig(Any.pack(rawProto)); + ConfigOrError ascLast = + FILTER_PROVIDER.parseFilterConfig(Any.pack(rawProto), getFilterContext()); assertThat(ascFirst.config).isEqualTo(ascLast.config); rawProto = buildComplexRbac(ImmutableList.of(6, 5, 4, 3, 2, 1), true); - ConfigOrError decFirst = new RbacFilter().parseFilterConfig(Any.pack(rawProto)); + ConfigOrError decFirst = + FILTER_PROVIDER.parseFilterConfig(Any.pack(rawProto), getFilterContext()); assertThat(ascFirst.config).isEqualTo(decFirst.config); } @@ -373,14 +391,14 @@ private MethodDescriptor.Builder method() { private ConfigOrError parse(List permissionList, List principalList) { - return RbacFilter.parseRbacConfig(buildRbac(permissionList, principalList)); + return RbacFilter.Provider.parseRbacConfig(buildRbac(permissionList, principalList)); } private ConfigOrError parseRaw(List permissionList, List principalList) { Message rawProto = buildRbac(permissionList, principalList); Any proto = Any.pack(rawProto); - return new RbacFilter().parseFilterConfig(proto); + return FILTER_PROVIDER.parseFilterConfig(proto, getFilterContext()); } private io.envoyproxy.envoy.extensions.filters.http.rbac.v3.RBAC buildRbac( @@ -448,6 +466,19 @@ private ConfigOrError parseOverride(List permissionList, RBACPerRoute rbacPerRoute = RBACPerRoute.newBuilder().setRbac( buildRbac(permissionList, principalList)).build(); Any proto = Any.pack(rbacPerRoute); - return new RbacFilter().parseFilterConfigOverride(proto); + return FILTER_PROVIDER.parseFilterConfigOverride(proto, getFilterContext()); + } + + private Filter.FilterConfigParseContext getFilterContext() { + return Filter.FilterConfigParseContext.builder() + .bootstrapInfo(BootstrapInfo.builder() + .servers(Collections.singletonList( + ServerInfo.create( + "test_target", Collections.emptyMap()))) + .node(Node.newBuilder().build()) + .build()) + .serverInfo(ServerInfo.create( + "test_target", Collections.emptyMap(), false, true, false, false)) + .build(); } } diff --git a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerProviderTest.java b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerProviderTest.java index 87615a125c0..66c9c5c537e 100644 --- a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerProviderTest.java @@ -42,6 +42,8 @@ @RunWith(JUnit4.class) public class RingHashLoadBalancerProviderTest { private static final String AUTHORITY = "foo.googleapis.com"; + private static final String GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY = + "GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY"; private final SynchronizationContext syncContext = new SynchronizationContext( new UncaughtExceptionHandler() { @@ -81,6 +83,7 @@ public void parseLoadBalancingConfig_valid() throws IOException { RingHashConfig config = (RingHashConfig) configOrError.getConfig(); assertThat(config.minRingSize).isEqualTo(10L); assertThat(config.maxRingSize).isEqualTo(100L); + assertThat(config.requestHashHeader).isEmpty(); } @Test @@ -92,6 +95,7 @@ public void parseLoadBalancingConfig_missingRingSize_useDefaults() throws IOExce RingHashConfig config = (RingHashConfig) configOrError.getConfig(); assertThat(config.minRingSize).isEqualTo(RingHashLoadBalancerProvider.DEFAULT_MIN_RING_SIZE); assertThat(config.maxRingSize).isEqualTo(RingHashLoadBalancerProvider.DEFAULT_MAX_RING_SIZE); + assertThat(config.requestHashHeader).isEmpty(); } @Test @@ -102,7 +106,7 @@ public void parseLoadBalancingConfig_invalid_negativeSize() throws IOException { assertThat(configOrError.getError()).isNotNull(); assertThat(configOrError.getError().getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(configOrError.getError().getDescription()) - .isEqualTo("Invalid 'mingRingSize'/'maxRingSize'"); + .isEqualTo("Invalid 'minRingSize'/'maxRingSize'"); } @Test @@ -113,7 +117,7 @@ public void parseLoadBalancingConfig_invalid_minGreaterThanMax() throws IOExcept assertThat(configOrError.getError()).isNotNull(); assertThat(configOrError.getError().getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(configOrError.getError().getDescription()) - .isEqualTo("Invalid 'mingRingSize'/'maxRingSize'"); + .isEqualTo("Invalid 'minRingSize'/'maxRingSize'"); } @Test @@ -127,6 +131,7 @@ public void parseLoadBalancingConfig_ringTooLargeUsesCap() throws IOException { RingHashConfig config = (RingHashConfig) configOrError.getConfig(); assertThat(config.minRingSize).isEqualTo(10); assertThat(config.maxRingSize).isEqualTo(RingHashOptions.DEFAULT_RING_SIZE_CAP); + assertThat(config.requestHashHeader).isEmpty(); } @Test @@ -142,6 +147,7 @@ public void parseLoadBalancingConfig_ringCapCanBeRaised() throws IOException { RingHashConfig config = (RingHashConfig) configOrError.getConfig(); assertThat(config.minRingSize).isEqualTo(RingHashOptions.MAX_RING_SIZE_CAP); assertThat(config.maxRingSize).isEqualTo(RingHashOptions.MAX_RING_SIZE_CAP); + assertThat(config.requestHashHeader).isEmpty(); // Reset to avoid affecting subsequent test cases RingHashOptions.setRingSizeCap(RingHashOptions.DEFAULT_RING_SIZE_CAP); } @@ -159,6 +165,7 @@ public void parseLoadBalancingConfig_ringCapIsClampedTo8M() throws IOException { RingHashConfig config = (RingHashConfig) configOrError.getConfig(); assertThat(config.minRingSize).isEqualTo(RingHashOptions.MAX_RING_SIZE_CAP); assertThat(config.maxRingSize).isEqualTo(RingHashOptions.MAX_RING_SIZE_CAP); + assertThat(config.requestHashHeader).isEmpty(); // Reset to avoid affecting subsequent test cases RingHashOptions.setRingSizeCap(RingHashOptions.DEFAULT_RING_SIZE_CAP); } @@ -176,6 +183,7 @@ public void parseLoadBalancingConfig_ringCapCanBeLowered() throws IOException { RingHashConfig config = (RingHashConfig) configOrError.getConfig(); assertThat(config.minRingSize).isEqualTo(1); assertThat(config.maxRingSize).isEqualTo(1); + assertThat(config.requestHashHeader).isEmpty(); // Reset to avoid affecting subsequent test cases RingHashOptions.setRingSizeCap(RingHashOptions.DEFAULT_RING_SIZE_CAP); } @@ -193,6 +201,7 @@ public void parseLoadBalancingConfig_ringCapLowerLimitIs1() throws IOException { RingHashConfig config = (RingHashConfig) configOrError.getConfig(); assertThat(config.minRingSize).isEqualTo(1); assertThat(config.maxRingSize).isEqualTo(1); + assertThat(config.requestHashHeader).isEmpty(); // Reset to avoid affecting subsequent test cases RingHashOptions.setRingSizeCap(RingHashOptions.DEFAULT_RING_SIZE_CAP); } @@ -205,7 +214,7 @@ public void parseLoadBalancingConfig_zeroMinRingSize() throws IOException { assertThat(configOrError.getError()).isNotNull(); assertThat(configOrError.getError().getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(configOrError.getError().getDescription()) - .isEqualTo("Invalid 'mingRingSize'/'maxRingSize'"); + .isEqualTo("Invalid 'minRingSize'/'maxRingSize'"); } @Test @@ -216,7 +225,60 @@ public void parseLoadBalancingConfig_minRingSizeGreaterThanMaxRingSize() throws assertThat(configOrError.getError()).isNotNull(); assertThat(configOrError.getError().getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(configOrError.getError().getDescription()) - .isEqualTo("Invalid 'mingRingSize'/'maxRingSize'"); + .isEqualTo("Invalid 'minRingSize'/'maxRingSize'"); + } + + @Test + public void parseLoadBalancingConfig_requestHashHeaderIgnoredWhenEnvVarNotSet() + throws IOException { + String lbConfig = + "{\"minRingSize\" : 10, \"maxRingSize\" : 100, \"requestHashHeader\" : \"dummy-hash\"}"; + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); + assertThat(configOrError.getConfig()).isNotNull(); + RingHashConfig config = (RingHashConfig) configOrError.getConfig(); + assertThat(config.minRingSize).isEqualTo(10L); + assertThat(config.maxRingSize).isEqualTo(100L); + assertThat(config.requestHashHeader).isEmpty(); + } + + @Test + public void parseLoadBalancingConfig_requestHashHeaderSetWhenEnvVarSet() throws IOException { + System.setProperty(GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY, "true"); + try { + String lbConfig = + "{\"minRingSize\" : 10, \"maxRingSize\" : 100, \"requestHashHeader\" : \"dummy-hash\"}"; + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); + assertThat(configOrError.getConfig()).isNotNull(); + RingHashConfig config = (RingHashConfig) configOrError.getConfig(); + assertThat(config.minRingSize).isEqualTo(10L); + assertThat(config.maxRingSize).isEqualTo(100L); + assertThat(config.requestHashHeader).isEqualTo("dummy-hash"); + assertThat(config.toString()).contains("minRingSize=10"); + assertThat(config.toString()).contains("maxRingSize=100"); + assertThat(config.toString()).contains("requestHashHeader=dummy-hash"); + } finally { + System.clearProperty(GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY); + } + } + + @Test + public void parseLoadBalancingConfig_requestHashHeaderUnsetWhenEnvVarSet_useDefaults() + throws IOException { + System.setProperty(GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY, "true"); + try { + String lbConfig = "{\"minRingSize\" : 10, \"maxRingSize\" : 100}"; + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); + assertThat(configOrError.getConfig()).isNotNull(); + RingHashConfig config = (RingHashConfig) configOrError.getConfig(); + assertThat(config.minRingSize).isEqualTo(10L); + assertThat(config.maxRingSize).isEqualTo(100L); + assertThat(config.requestHashHeader).isEmpty(); + } finally { + System.clearProperty(GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY); + } } @SuppressWarnings("unchecked") diff --git a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java index 9d88998fe74..b515ed81158 100644 --- a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java @@ -17,12 +17,12 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.Truth.assertWithMessage; import static io.grpc.ConnectivityState.CONNECTING; import static io.grpc.ConnectivityState.IDLE; import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; -import static io.grpc.util.MultiChildLoadBalancer.IS_PETIOLE_POLICY; import static io.grpc.xds.RingHashLoadBalancerTest.InitializationFlags.DO_NOT_RESET_HELPER; import static io.grpc.xds.RingHashLoadBalancerTest.InitializationFlags.DO_NOT_VERIFY; import static io.grpc.xds.RingHashLoadBalancerTest.InitializationFlags.RESET_SUBCHANNEL_MOCKS; @@ -30,6 +30,7 @@ import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.clearInvocations; import static org.mockito.Mockito.mock; @@ -41,12 +42,16 @@ import com.google.common.collect.Iterables; import com.google.common.primitives.UnsignedInteger; +import com.google.common.testing.EqualsTester; import io.grpc.Attributes; import io.grpc.CallOptions; +import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; +import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.PickDetailsConsumer; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.ResolvedAddresses; @@ -56,11 +61,14 @@ import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.SynchronizationContext; +import io.grpc.internal.FakeClock; +import io.grpc.internal.PickFirstLoadBalancerProvider; +import io.grpc.internal.PickFirstLoadBalancerProviderAccessor; import io.grpc.internal.PickSubchannelArgsImpl; import io.grpc.testing.TestMethodDescriptors; import io.grpc.util.AbstractTestHelper; +import io.grpc.util.ForwardingLoadBalancerHelper; import io.grpc.util.MultiChildLoadBalancer.ChildLbState; -import io.grpc.xds.RingHashLoadBalancer.RingHashChildLbState; import io.grpc.xds.RingHashLoadBalancer.RingHashConfig; import java.lang.Thread.UncaughtExceptionHandler; import java.net.SocketAddress; @@ -69,8 +77,11 @@ import java.util.Collections; import java.util.Deque; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Random; +import java.util.Set; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -88,7 +99,13 @@ @RunWith(JUnit4.class) public class RingHashLoadBalancerTest { private static final String AUTHORITY = "foo.googleapis.com"; + private static final String CUSTOM_REQUEST_HASH_HEADER = "custom-request-hash-header"; + private static final Metadata.Key CUSTOM_METADATA_KEY = + Metadata.Key.of(CUSTOM_REQUEST_HASH_HEADER, Metadata.ASCII_STRING_MARSHALLER); private static final Attributes.Key CUSTOM_KEY = Attributes.Key.create("custom-key"); + private static final ConnectivityStateInfo CSI_CONNECTING = + ConnectivityStateInfo.forNonError(CONNECTING); + public static final ConnectivityStateInfo CSI_READY = ConnectivityStateInfo.forNonError(READY); @Rule public final MockitoRule mocks = MockitoJUnit.rule(); @@ -107,6 +124,7 @@ public void uncaughtException(Thread t, Throwable e) { @Captor private ArgumentCaptor pickerCaptor; private RingHashLoadBalancer loadBalancer; + private boolean defaultNewPickFirst = PickFirstLoadBalancerProvider.isEnabledNewPickFirst(); @Before public void setUp() { @@ -118,6 +136,7 @@ public void setUp() { @After public void tearDown() { + PickFirstLoadBalancerProviderAccessor.setEnableNewPickFirst(defaultNewPickFirst); loadBalancer.shutdown(); for (Subchannel subchannel : subchannels.values()) { verify(subchannel).shutdown(); @@ -127,7 +146,7 @@ public void tearDown() { @Test public void subchannelLazyConnectUntilPicked() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createWeightedServerAddrs(1); // one server Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() @@ -142,23 +161,26 @@ public void subchannelLazyConnectUntilPicked() { assertThat(result.getStatus().isOk()).isTrue(); assertThat(result.getSubchannel()).isNull(); Subchannel subchannel = Iterables.getOnlyElement(subchannels.values()); - verify(subchannel).requestConnection(); + int expectedTimes = PickFirstLoadBalancerProvider.isEnabledNewPickFirst() + && !PickFirstLoadBalancerProvider.isEnabledHappyEyeballs() ? 1 : 2; + verify(subchannel, times(expectedTimes)).requestConnection(); verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); - deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(CONNECTING)); - verify(helper, times(2)).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); + deliverSubchannelState(subchannel, CSI_CONNECTING); + int expectedCount = PickFirstLoadBalancerProvider.isEnabledNewPickFirst() ? 1 : 2; + verify(helper, times(expectedCount)).updateBalancingState(eq(CONNECTING), any()); // Subchannel becomes ready, triggers pick again. - deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(subchannel, CSI_READY); verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); result = pickerCaptor.getValue().pickSubchannel(args); assertThat(result.getSubchannel()).isSameInstanceAs(subchannel); - verifyNoMoreInteractions(helper); + AbstractTestHelper.verifyNoMoreMeaningfulInteractions(helper); } @Test public void subchannelNotAutoReconnectAfterReenteringIdle() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createWeightedServerAddrs(1); // one server Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() @@ -166,19 +188,17 @@ public void subchannelNotAutoReconnectAfterReenteringIdle() { assertThat(addressesAcceptanceStatus.isOk()).isTrue(); verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); - RingHashChildLbState childLbState = - (RingHashChildLbState) loadBalancer.getChildLbStates().iterator().next(); - assertThat(childLbState.isDeactivated()).isTrue(); + assertThat(subchannels).isEmpty(); // Picking subchannel triggers connection. PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); pickerCaptor.getValue().pickSubchannel(args); - assertThat(childLbState.isDeactivated()).isFalse(); - assertThat(childLbState.getLb().delegateType()).isEqualTo("PickFirstLoadBalancer"); - Subchannel subchannel = subchannels.get(Collections.singletonList(childLbState.getEag())); + Subchannel subchannel = subchannels.get(Collections.singletonList(servers.get(0))); InOrder inOrder = Mockito.inOrder(helper, subchannel); - inOrder.verify(subchannel).requestConnection(); - deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); + int expectedTimes = PickFirstLoadBalancerProvider.isEnabledHappyEyeballs() + || !PickFirstLoadBalancerProvider.isEnabledNewPickFirst() ? 2 : 1; + inOrder.verify(subchannel, times(expectedTimes)).requestConnection(); + deliverSubchannelState(subchannel, CSI_READY); inOrder.verify(helper).updateBalancingState(eq(READY), any(SubchannelPicker.class)); deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(IDLE)); inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); @@ -191,57 +211,58 @@ public void subchannelNotAutoReconnectAfterReenteringIdle() { @Test public void aggregateSubchannelStates_connectingReadyIdleFailure() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createWeightedServerAddrs(1, 1); InOrder inOrder = Mockito.inOrder(helper); initializeLbSubchannels(config, servers); // one in CONNECTING, one in IDLE - deliverSubchannelState( - subchannels.get(Collections.singletonList(servers.get(0))), - ConnectivityStateInfo.forNonError(CONNECTING)); + deliverSubchannelState(getSubchannel(servers, 0), CSI_CONNECTING); inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); verifyConnection(0); // two in CONNECTING - deliverSubchannelState( - subchannels.get(Collections.singletonList(servers.get(1))), - ConnectivityStateInfo.forNonError(CONNECTING)); + deliverSubchannelState(getSubchannel(servers, 1), CSI_CONNECTING); inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); verifyConnection(0); // one in CONNECTING, one in READY - deliverSubchannelState( - subchannels.get(Collections.singletonList(servers.get(1))), - ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(getSubchannel(servers, 1), CSI_READY); inOrder.verify(helper).updateBalancingState(eq(READY), any(SubchannelPicker.class)); verifyConnection(0); // one in TRANSIENT_FAILURE, one in READY deliverSubchannelState( - subchannels.get(Collections.singletonList(servers.get(0))), + getSubchannel(servers, 0), ConnectivityStateInfo.forTransientFailure( Status.UNKNOWN.withDescription("unknown failure"))); - inOrder.verify(helper).refreshNameResolution(); - inOrder.verify(helper).updateBalancingState(eq(READY), any(SubchannelPicker.class)); + if (PickFirstLoadBalancerProvider.isEnabledNewPickFirst()) { + inOrder.verify(helper).updateBalancingState(eq(READY), any()); + } else { + inOrder.verify(helper).refreshNameResolution(); + inOrder.verify(helper).updateBalancingState(eq(READY), any()); + } verifyConnection(0); // one in TRANSIENT_FAILURE, one in IDLE deliverSubchannelState( - subchannels.get(Collections.singletonList(servers.get(1))), + getSubchannel(servers, 1), ConnectivityStateInfo.forNonError(IDLE)); - inOrder.verify(helper).refreshNameResolution(); - inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); - verifyConnection(0); - - verifyNoMoreInteractions(helper); + if (PickFirstLoadBalancerProvider.isEnabledNewPickFirst()) { + inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any()); + } else { + inOrder.verify(helper).refreshNameResolution(); + inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any()); + } + verifyConnection(1); } private void verifyConnection(int times) { for (int i = 0; i < times; i++) { Subchannel connectOnce = connectionRequestedQueue.poll(); - assertThat(connectOnce).isNotNull(); + assertWithMessage("Expected %s new connections, but found %s", times, i) + .that(connectOnce).isNotNull(); clearInvocations(connectOnce); } assertThat(connectionRequestedQueue.poll()).isNull(); @@ -249,7 +270,7 @@ private void verifyConnection(int times) { @Test public void aggregateSubchannelStates_allSubchannelsInTransientFailure() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createWeightedServerAddrs(1, 1, 1, 1); List subChannelList = initializeLbSubchannels(config, servers, STAY_IN_CONNECTING); @@ -261,42 +282,53 @@ public void aggregateSubchannelStates_allSubchannelsInTransientFailure() { // one in TRANSIENT_FAILURE, three in CONNECTING deliverNotFound(subChannelList, 0); - inOrder.verify(helper).refreshNameResolution(); - inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); + refreshInvokedButNotUpdateBS(inOrder, TRANSIENT_FAILURE); // two in TRANSIENT_FAILURE, two in CONNECTING deliverNotFound(subChannelList, 1); - inOrder.verify(helper).refreshNameResolution(); - inOrder.verify(helper) - .updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class)); + refreshInvokedAndUpdateBS(inOrder, TRANSIENT_FAILURE); // All 4 in TF switch to TF deliverNotFound(subChannelList, 2); - inOrder.verify(helper).refreshNameResolution(); - inOrder.verify(helper) - .updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class)); + refreshInvokedAndUpdateBS(inOrder, TRANSIENT_FAILURE); deliverNotFound(subChannelList, 3); - inOrder.verify(helper).refreshNameResolution(); - inOrder.verify(helper) - .updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class)); + refreshInvokedAndUpdateBS(inOrder, TRANSIENT_FAILURE); // reset subchannel to CONNECTING - shouldn't change anything since PF hides the state change - deliverSubchannelState(subChannelList.get(2), ConnectivityStateInfo.forNonError(CONNECTING)); + deliverSubchannelState(subChannelList.get(2), CSI_CONNECTING); inOrder.verify(helper, never()) .updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class)); inOrder.verify(subChannelList.get(2), never()).requestConnection(); // three in TRANSIENT_FAILURE, one in READY - deliverSubchannelState(subChannelList.get(2), ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(subChannelList.get(2), CSI_READY); inOrder.verify(helper).updateBalancingState(eq(READY), any(SubchannelPicker.class)); inOrder.verify(subChannelList.get(2), never()).requestConnection(); + } - verifyNoMoreInteractions(helper); + // Old PF and new PF reverse calling order of updateBlaancingState and refreshNameResolution + private void refreshInvokedButNotUpdateBS(InOrder inOrder, ConnectivityState state) { + inOrder.verify(helper, never()).updateBalancingState(eq(state), any(SubchannelPicker.class)); + inOrder.verify(helper).refreshNameResolution(); + inOrder.verify(helper, never()).updateBalancingState(eq(state), any(SubchannelPicker.class)); + } + + // Old PF and new PF reverse calling order of updateBlaancingState and refreshNameResolution + private void refreshInvokedAndUpdateBS(InOrder inOrder, ConnectivityState state) { + if (PickFirstLoadBalancerProvider.isEnabledNewPickFirst()) { + inOrder.verify(helper).updateBalancingState(eq(state), any()); + } + + inOrder.verify(helper).refreshNameResolution(); + + if (!PickFirstLoadBalancerProvider.isEnabledNewPickFirst()) { + inOrder.verify(helper).updateBalancingState(eq(state), any()); + } } @Test public void ignoreShutdownSubchannelStateChange() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); @@ -312,14 +344,14 @@ public void ignoreShutdownSubchannelStateChange() { @Test public void deterministicPickWithHostsPartiallyRemoved() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createWeightedServerAddrs(1, 1, 1, 1, 1); initializeLbSubchannels(config, servers); InOrder inOrder = Mockito.inOrder(helper); // Bring all subchannels to READY so that next pick always succeeds. for (Subchannel subchannel : subchannels.values()) { - deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(subchannel, CSI_READY); inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); } @@ -336,8 +368,8 @@ public void deterministicPickWithHostsPartiallyRemoved() { Attributes attr = addr.getAttributes().toBuilder().set(CUSTOM_KEY, "custom value").build(); updatedServers.add(new EquivalentAddressGroup(addr.getAddresses(), attr)); } - Subchannel subchannel0_old = subchannels.get(Collections.singletonList(servers.get(0))); - Subchannel subchannel1_old = subchannels.get(Collections.singletonList(servers.get(1))); + Subchannel subchannel0_old = getSubchannel(servers, 0); + Subchannel subchannel1_old = getSubchannel(servers, 1); Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(updatedServers).setLoadBalancingPolicyConfig(config).build()); @@ -352,7 +384,7 @@ public void deterministicPickWithHostsPartiallyRemoved() { @Test public void deterministicPickWithNewHostsAdded() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createWeightedServerAddrs(1, 1); // server0 and server1 initializeLbSubchannels(config, servers, DO_NOT_VERIFY, DO_NOT_RESET_HELPER); @@ -360,7 +392,7 @@ public void deterministicPickWithNewHostsAdded() { // Bring all subchannels to READY so that next pick always succeeds. for (Subchannel subchannel : subchannels.values()) { - deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(subchannel, CSI_READY); inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); } @@ -384,6 +416,139 @@ public void deterministicPickWithNewHostsAdded() { inOrder.verifyNoMoreInteractions(); } + @Test + public void deterministicPickWithRequestHashHeader_oneHeaderValue() { + // Map each server address to exactly one ring entry. + RingHashConfig config = new RingHashConfig(3, 3, CUSTOM_REQUEST_HASH_HEADER); + List servers = createWeightedServerAddrs(1, 1, 1); + initializeLbSubchannels(config, servers); + InOrder inOrder = Mockito.inOrder(helper); + + // Bring all subchannels to READY. + for (Subchannel subchannel : subchannels.values()) { + deliverSubchannelState(subchannel, CSI_READY); + inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); + } + + // Pick subchannel with custom request hash header where the rpc hash hits server1. + Metadata headers = new Metadata(); + headers.put(CUSTOM_METADATA_KEY, "FakeSocketAddress-server1_0"); + PickSubchannelArgs args = + new PickSubchannelArgsImpl( + TestMethodDescriptors.voidMethod(), + headers, + CallOptions.DEFAULT, + new PickDetailsConsumer() {}); + SubchannelPicker picker = pickerCaptor.getValue(); + PickResult result = picker.pickSubchannel(args); + assertThat(result.getStatus().isOk()).isTrue(); + assertThat(result.getSubchannel().getAddresses()).isEqualTo(servers.get(1)); + } + + @Test + public void deterministicPickWithRequestHashHeader_multipleHeaderValues() { + // Map each server address to exactly one ring entry. + RingHashConfig config = new RingHashConfig(3, 3, CUSTOM_REQUEST_HASH_HEADER); + List servers = createWeightedServerAddrs(1, 1, 1); + initializeLbSubchannels(config, servers); + InOrder inOrder = Mockito.inOrder(helper); + + // Bring all subchannels to READY. + for (Subchannel subchannel : subchannels.values()) { + deliverSubchannelState(subchannel, CSI_READY); + inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); + } + + // Pick subchannel with custom request hash header with multiple values for the same key where + // the rpc hash hits server1. + Metadata headers = new Metadata(); + headers.put(CUSTOM_METADATA_KEY, "FakeSocketAddress-server0_0"); + headers.put(CUSTOM_METADATA_KEY, "FakeSocketAddress-server1_0"); + PickSubchannelArgs args = + new PickSubchannelArgsImpl( + TestMethodDescriptors.voidMethod(), + headers, + CallOptions.DEFAULT, + new PickDetailsConsumer() {}); + SubchannelPicker picker = pickerCaptor.getValue(); + PickResult result = picker.pickSubchannel(args); + assertThat(result.getStatus().isOk()).isTrue(); + assertThat(result.getSubchannel().getAddresses()).isEqualTo(servers.get(1)); + } + + @Test + public void pickWithRandomHash_allSubchannelsReady() { + loadBalancer = new RingHashLoadBalancer(helper, new FakeRandom()); + // Map each server address to exactly one ring entry. + RingHashConfig config = new RingHashConfig(2, 2, "dummy-random-hash"); + List servers = createWeightedServerAddrs(1, 1); + initializeLbSubchannels(config, servers); + InOrder inOrder = Mockito.inOrder(helper); + + // Bring all subchannels to READY. + Map pickCounts = new HashMap<>(); + for (Subchannel subchannel : subchannels.values()) { + deliverSubchannelState(subchannel, CSI_READY); + pickCounts.put(subchannel.getAddresses(), 0); + inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); + } + + // Pick subchannel 100 times with random hash. + SubchannelPicker picker = pickerCaptor.getValue(); + PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); + for (int i = 0; i < 100; ++i) { + Subchannel pickedSubchannel = picker.pickSubchannel(args).getSubchannel(); + EquivalentAddressGroup addr = pickedSubchannel.getAddresses(); + pickCounts.put(addr, pickCounts.get(addr) + 1); + } + + // Verify the distribution is uniform where server0 and server1 are exactly picked 50 times. + assertThat(pickCounts.get(servers.get(0))).isEqualTo(50); + assertThat(pickCounts.get(servers.get(1))).isEqualTo(50); + } + + @Test + public void pickWithRandomHash_atLeastOneSubchannelConnecting() { + // Map each server address to exactly one ring entry. + RingHashConfig config = new RingHashConfig(3, 3, "dummy-random-hash"); + List servers = createWeightedServerAddrs(1, 1, 1); + initializeLbSubchannels(config, servers); + + // Bring one subchannel to CONNECTING. + deliverSubchannelState(getSubChannel(servers.get(0)), CSI_CONNECTING); + verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + + // Pick subchannel with random hash does not trigger connection. + SubchannelPicker picker = pickerCaptor.getValue(); + PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); + PickResult result = picker.pickSubchannel(args); + assertThat(result.getStatus().isOk()).isTrue(); + assertThat(result.getSubchannel()).isNull(); // buffer request + verifyConnection(0); + } + + @Test + public void pickWithRandomHash_firstSubchannelInTransientFailure_remainingSubchannelsIdle() { + // Map each server address to exactly one ring entry. + RingHashConfig config = new RingHashConfig(3, 3, "dummy-random-hash"); + List servers = createWeightedServerAddrs(1, 1, 1); + initializeLbSubchannels(config, servers); + + // Bring one subchannel to TRANSIENT_FAILURE. + deliverSubchannelUnreachable(getSubChannel(servers.get(0))); + verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + verifyConnection(1); + + // Pick subchannel with random hash does trigger connection by walking the ring + // and choosing the first (at most one) IDLE subchannel along the way. + SubchannelPicker picker = pickerCaptor.getValue(); + PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); + PickResult result = picker.pickSubchannel(args); + assertThat(result.getStatus().isOk()).isTrue(); + assertThat(result.getSubchannel()).isNull(); // buffer request + verifyConnection(1); + } + private Subchannel getSubChannel(EquivalentAddressGroup eag) { return subchannels.get(Collections.singletonList(eag)); } @@ -391,7 +556,7 @@ private Subchannel getSubChannel(EquivalentAddressGroup eag) { @Test public void skipFailingHosts_pickNextNonFailingHost() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( @@ -400,7 +565,8 @@ public void skipFailingHosts_pickNextNonFailingHost() { assertThat(addressesAcceptanceStatus.isOk()).isTrue(); // Create subchannel for the first address - ((RingHashChildLbState)loadBalancer.getChildLbStateEag(servers.get(0))).activate(); + loadBalancer.getChildLbStates().iterator().next().getCurrentPicker() + .pickSubchannel(getDefaultPickSubchannelArgs(hashFunc.hashVoid())); verifyConnection(1); reset(helper); @@ -417,24 +583,27 @@ public void skipFailingHosts_pickNextNonFailingHost() { getSubChannel(servers.get(0)), ConnectivityStateInfo.forTransientFailure( Status.UNAVAILABLE.withDescription("unreachable"))); - verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + verify(helper, atLeastOnce()).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); PickResult result = pickerCaptor.getValue().pickSubchannel(args); assertThat(result.getStatus().isOk()).isTrue(); assertThat(result.getSubchannel()).isNull(); // buffer request - verify(getSubChannel(servers.get(1))).requestConnection(); // kicked off connection to server2 + // verify kicked off connection to server2 + int expectedTimes = PickFirstLoadBalancerProvider.isEnabledHappyEyeballs() + || !PickFirstLoadBalancerProvider.isEnabledNewPickFirst() ? 2 : 1; + + verify(getSubChannel(servers.get(1)), times(expectedTimes)).requestConnection(); assertThat(subchannels.size()).isEqualTo(2); // no excessive connection - reset(helper); - deliverSubchannelState(getSubChannel(servers.get(1)), - ConnectivityStateInfo.forNonError(CONNECTING)); - verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + deliverSubchannelState(getSubChannel(servers.get(1)), CSI_CONNECTING); + verify(helper, atLeast(1)) + .updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); result = pickerCaptor.getValue().pickSubchannel(args); assertThat(result.getStatus().isOk()).isTrue(); assertThat(result.getSubchannel()).isNull(); // buffer request - deliverSubchannelState(getSubChannel(servers.get(1)), ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(getSubChannel(servers.get(1)), CSI_READY); verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); result = pickerCaptor.getValue().pickSubchannel(args); @@ -442,13 +611,14 @@ public void skipFailingHosts_pickNextNonFailingHost() { assertThat(result.getSubchannel().getAddresses()).isEqualTo(servers.get(1)); } - private PickSubchannelArgsImpl getDefaultPickSubchannelArgs(long rpcHash) { + private PickSubchannelArgs getDefaultPickSubchannelArgs(long rpcHash) { return new PickSubchannelArgsImpl( TestMethodDescriptors.voidMethod(), new Metadata(), - CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, rpcHash)); + CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, rpcHash), + new PickDetailsConsumer() {}); } - private PickSubchannelArgsImpl getDefaultPickSubchannelArgsForServer(int serverid) { + private PickSubchannelArgs getDefaultPickSubchannelArgsForServer(int serverid) { long rpcHash = hashFunc.hashAsciiString("FakeSocketAddress-server" + serverid + "_0"); return getDefaultPickSubchannelArgs(rpcHash); } @@ -456,7 +626,7 @@ private PickSubchannelArgsImpl getDefaultPickSubchannelArgsForServer(int serveri @Test public void skipFailingHosts_firstTwoHostsFailed_pickNextFirstReady() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); @@ -471,21 +641,22 @@ public void skipFailingHosts_firstTwoHostsFailed_pickNextFirstReady() { // Bring down server0 and server2 to force trying server1. deliverSubchannelState( - subchannels.get(Collections.singletonList(servers.get(1))), + getSubchannel(servers, 1), ConnectivityStateInfo.forTransientFailure( Status.UNAVAILABLE.withDescription("unreachable"))); deliverSubchannelState( - subchannels.get(Collections.singletonList(servers.get(2))), + getSubchannel(servers, 2), ConnectivityStateInfo.forTransientFailure( Status.PERMISSION_DENIED.withDescription("permission denied"))); - verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); - verifyConnection(0); + verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); + verifyConnection(2); PickResult result = pickerCaptor.getValue().pickSubchannel(args); // activate last subchannel assertThat(result.getStatus().isOk()).isTrue(); - verifyConnection(1); + int expectedCount = PickFirstLoadBalancerProvider.isEnabledNewPickFirst() ? 0 : 1; + verifyConnection(expectedCount); deliverSubchannelState( - subchannels.get(Collections.singletonList(servers.get(0))), + getSubchannel(servers, 0), ConnectivityStateInfo.forTransientFailure( Status.PERMISSION_DENIED.withDescription("permission denied again"))); verify(helper, times(2)).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); @@ -496,9 +667,7 @@ public void skipFailingHosts_firstTwoHostsFailed_pickNextFirstReady() { assertThat(result.getStatus().getDescription()).isEqualTo("unreachable"); // Now connecting to server1. - deliverSubchannelState( - subchannels.get(Collections.singletonList(servers.get(1))), - ConnectivityStateInfo.forNonError(CONNECTING)); + deliverSubchannelState(getSubchannel(servers, 1), CSI_CONNECTING); reset(helper); @@ -509,9 +678,7 @@ public void skipFailingHosts_firstTwoHostsFailed_pickNextFirstReady() { assertThat(result.getStatus().getDescription()).isEqualTo("unreachable"); // Simulate server1 becomes READY. - deliverSubchannelState( - subchannels.get(Collections.singletonList(servers.get(1))), - ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(getSubchannel(servers, 1), CSI_READY); verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); SubchannelPicker picker = pickerCaptor.getValue(); @@ -525,7 +692,7 @@ public void skipFailingHosts_firstTwoHostsFailed_pickNextFirstReady() { @Test public void removingAddressShutdownSubchannel() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List svs1 = createWeightedServerAddrs(1, 1, 1); List subchannels1 = initializeLbSubchannels(config, svs1, STAY_IN_CONNECTING); @@ -542,7 +709,7 @@ public void removingAddressShutdownSubchannel() { @Test public void allSubchannelsInTransientFailure() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); @@ -554,7 +721,7 @@ public void allSubchannelsInTransientFailure() { } verify(helper, atLeastOnce()) .updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); - verifyConnection(0); + verifyConnection(2); // Picking subchannel triggers connection. RPC hash hits server0. PickSubchannelArgs args = getDefaultPickSubchannelArgsForServer(0); @@ -569,16 +736,17 @@ public void allSubchannelsInTransientFailure() { @Test public void firstSubchannelIdle() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); - // Go to TF does nothing, though PF will try to reconnect after backoff - deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(1))), + // As per gRFC A61, entering TF triggers a proactive connection attempt + // on an IDLE subchannel because no other subchannel is currently CONNECTING. + deliverSubchannelState(getSubchannel(servers, 1), ConnectivityStateInfo.forTransientFailure( Status.UNAVAILABLE.withDescription("unreachable"))); verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); - verifyConnection(0); + verifyConnection(1); // Picking subchannel triggers connection. RPC hash hits server0. PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); @@ -590,32 +758,31 @@ public void firstSubchannelIdle() { @Test public void firstSubchannelConnecting() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); - deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(0))), - ConnectivityStateInfo.forNonError(CONNECTING)); - deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(1))), - ConnectivityStateInfo.forNonError(CONNECTING)); + deliverSubchannelState(getSubchannel(servers, 0), CSI_CONNECTING); + deliverSubchannelState(getSubchannel(servers, 1), CSI_CONNECTING); verify(helper, times(2)).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); // Picking subchannel triggers connection. PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); PickResult result = pickerCaptor.getValue().pickSubchannel(args); assertThat(result.getStatus().isOk()).isTrue(); - verify(subchannels.get(Collections.singletonList(servers.get(0))), never()) - .requestConnection(); - verify(subchannels.get(Collections.singletonList(servers.get(1))), never()) - .requestConnection(); - verify(subchannels.get(Collections.singletonList(servers.get(2))), never()) - .requestConnection(); + verify(getSubchannel(servers, 0), never()).requestConnection(); + verify(getSubchannel(servers, 1), never()).requestConnection(); + verify(getSubchannel(servers, 2), never()).requestConnection(); + } + + private Subchannel getSubchannel(List servers, int serverIndex) { + return subchannels.get(Collections.singletonList(servers.get(serverIndex))); } @Test public void firstSubchannelFailure() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); List subchannelList = @@ -630,7 +797,7 @@ public void firstSubchannelFailure() { ConnectivityStateInfo.forTransientFailure( Status.UNAVAILABLE.withDescription("unreachable"))); verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); - verifyConnection(0); + verifyConnection(1); // Per GRFC A61 Picking subchannel should no longer request connections that were failing PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); @@ -646,7 +813,7 @@ public void firstSubchannelFailure() { @Test public void secondSubchannelConnecting() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); @@ -656,19 +823,18 @@ public void secondSubchannelConnecting() { // "FakeSocketAddress-server0_0" // "FakeSocketAddress-server2_0" - Subchannel firstSubchannel = subchannels.get(Collections.singletonList(servers.get(0))); + Subchannel firstSubchannel = getSubchannel(servers, 0); deliverSubchannelUnreachable(firstSubchannel); - verifyConnection(0); + verifyConnection(1); - deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(2))), - ConnectivityStateInfo.forNonError(CONNECTING)); + deliverSubchannelState(getSubchannel(servers, 2), CSI_CONNECTING); verify(helper, times(2)).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); verifyConnection(0); // Picking subchannel when idle triggers connection. - deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(2))), + deliverSubchannelState(getSubchannel(servers, 2), ConnectivityStateInfo.forNonError(IDLE)); - verifyConnection(0); + verifyConnection(1); PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); PickResult result = pickerCaptor.getValue().pickSubchannel(args); assertThat(result.getStatus().isOk()).isTrue(); @@ -678,7 +844,7 @@ public void secondSubchannelConnecting() { @Test public void secondSubchannelFailure() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); @@ -688,24 +854,24 @@ public void secondSubchannelFailure() { // "FakeSocketAddress-server0_0" // "FakeSocketAddress-server2_0" - Subchannel firstSubchannel = subchannels.get(Collections.singletonList(servers.get(0))); + Subchannel firstSubchannel = getSubchannel(servers, 0); deliverSubchannelUnreachable(firstSubchannel); - deliverSubchannelUnreachable(subchannels.get(Collections.singletonList(servers.get(2)))); + deliverSubchannelUnreachable(getSubchannel(servers, 2)); verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); - verifyConnection(0); + verifyConnection(2); // Picking subchannel triggers connection. PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); PickResult result = pickerCaptor.getValue().pickSubchannel(args); assertThat(result.getStatus().isOk()).isTrue(); - verify(subchannels.get(Collections.singletonList(servers.get(1)))).requestConnection(); + verify(getSubchannel(servers, 1)).requestConnection(); verifyConnection(1); } @Test public void thirdSubchannelConnecting() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); @@ -715,15 +881,14 @@ public void thirdSubchannelConnecting() { // "FakeSocketAddress-server0_0" // "FakeSocketAddress-server2_0" - Subchannel firstSubchannel = subchannels.get(Collections.singletonList(servers.get(0))); + Subchannel firstSubchannel = getSubchannel(servers, 0); deliverSubchannelUnreachable(firstSubchannel); - deliverSubchannelUnreachable(subchannels.get(Collections.singletonList(servers.get(2)))); - deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(1))), - ConnectivityStateInfo.forNonError(CONNECTING)); + deliverSubchannelUnreachable(getSubchannel(servers, 2)); + deliverSubchannelState(getSubchannel(servers, 1), CSI_CONNECTING); verify(helper, atLeastOnce()) .updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); - verifyConnection(0); + verifyConnection(2); // Picking subchannel should not trigger connection per gRFC A61. PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); @@ -735,37 +900,36 @@ public void thirdSubchannelConnecting() { @Test public void stickyTransientFailure() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(3, 3); + RingHashConfig config = new RingHashConfig(3, 3, ""); List servers = createWeightedServerAddrs(1, 1, 1); initializeLbSubchannels(config, servers); // Bring one subchannel to TRANSIENT_FAILURE. - Subchannel firstSubchannel = subchannels.get(Collections.singletonList(servers.get(0))); + Subchannel firstSubchannel = getSubchannel(servers, 0); deliverSubchannelUnreachable(firstSubchannel); verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); - verifyConnection(0); + verifyConnection(1); reset(helper); deliverSubchannelState(firstSubchannel, ConnectivityStateInfo.forNonError(IDLE)); // Should not have called updateBalancingState on the helper again because PickFirst is // shielding the higher level from the state change. verify(helper, never()).updateBalancingState(any(), any()); - verifyConnection(1); + verifyConnection(PickFirstLoadBalancerProvider.isEnabledNewPickFirst() ? 0 : 1); // Picking subchannel triggers connection on second address. RPC hash hits server0. PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); PickResult result = pickerCaptor.getValue().pickSubchannel(args); assertThat(result.getStatus().isOk()).isTrue(); - verify(subchannels.get(Collections.singletonList(servers.get(1)))).requestConnection(); - verify(subchannels.get(Collections.singletonList(servers.get(2))), never()) - .requestConnection(); + verify(getSubchannel(servers, 1)).requestConnection(); + verify(getSubchannel(servers, 2), never()).requestConnection(); } @Test public void largeWeights() { - RingHashConfig config = new RingHashConfig(10000, 100000); // large ring + RingHashConfig config = new RingHashConfig(10000, 100000, ""); // large ring List servers = createWeightedServerAddrs(Integer.MAX_VALUE, 10, 100); // MAX:10:100 @@ -803,7 +967,7 @@ public void largeWeights() { @Test public void hostSelectionProportionalToWeights() { - RingHashConfig config = new RingHashConfig(10000, 100000); // large ring + RingHashConfig config = new RingHashConfig(10000, 100000, ""); // large ring List servers = createWeightedServerAddrs(1, 10, 100); // 1:10:100 initializeLbSubchannels(config, servers); @@ -811,7 +975,7 @@ public void hostSelectionProportionalToWeights() { // Bring all subchannels to READY. Map pickCounts = new HashMap<>(); for (Subchannel subchannel : subchannels.values()) { - deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(subchannel, CSI_READY); pickCounts.put(subchannel.getAddresses(), 0); } verify(helper, times(3)).updateBalancingState(eq(READY), pickerCaptor.capture()); @@ -846,7 +1010,7 @@ public void nameResolutionErrorWithNoActiveSubchannels() { @Test public void nameResolutionErrorWithActiveSubchannels() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createWeightedServerAddrs(1); initializeLbSubchannels(config, servers, DO_NOT_VERIFY, DO_NOT_RESET_HELPER); @@ -858,7 +1022,7 @@ public void nameResolutionErrorWithActiveSubchannels() { pickerCaptor.getValue().pickSubchannel(args); verify(helper, never()).updateBalancingState(eq(READY), any(SubchannelPicker.class)); deliverSubchannelState( - Iterables.getOnlyElement(subchannels.values()), ConnectivityStateInfo.forNonError(READY)); + Iterables.getOnlyElement(subchannels.values()), CSI_READY); verify(helper).updateBalancingState(eq(READY), any(SubchannelPicker.class)); reset(helper); @@ -868,7 +1032,7 @@ public void nameResolutionErrorWithActiveSubchannels() { @Test public void duplicateAddresses() { - RingHashConfig config = new RingHashConfig(10, 100); + RingHashConfig config = new RingHashConfig(10, 100, ""); List servers = createRepeatedServerAddrs(1, 2, 3); initializeLbSubchannels(config, servers, DO_NOT_VERIFY); @@ -887,6 +1051,116 @@ public void duplicateAddresses() { assertThat(description).contains("Address: FakeSocketAddress-server2, count: 3"); } + @Test + public void subchannelHealthObserved() throws Exception { + // Only the new PF policy observes the new separate listener for health + PickFirstLoadBalancerProviderAccessor.setEnableNewPickFirst(true); + // PickFirst does most of this work. If the test fails, check IS_PETIOLE_POLICY + Map healthListeners = new HashMap<>(); + loadBalancer = new RingHashLoadBalancer(new ForwardingLoadBalancerHelper() { + @Override + public Subchannel createSubchannel(CreateSubchannelArgs args) { + Subchannel subchannel = super.createSubchannel(args.toBuilder() + .setAttributes(args.getAttributes().toBuilder() + .set(LoadBalancer.HAS_HEALTH_PRODUCER_LISTENER_KEY, true) + .build()) + .build()); + healthListeners.put( + subchannel, args.getOption(LoadBalancer.HEALTH_CONSUMER_LISTENER_ARG_KEY)); + return subchannel; + } + + @Override + protected Helper delegate() { + return helper; + } + }); + + InOrder inOrder = Mockito.inOrder(helper); + List servers = createWeightedServerAddrs(1, 1); + initializeLbSubchannels(new RingHashConfig(10, 100, ""), servers); + Subchannel subchannel0 = subchannels.get(Collections.singletonList(servers.get(0))); + Subchannel subchannel1 = subchannels.get(Collections.singletonList(servers.get(1))); + + // Subchannels go READY, but the LB waits for health + for (Subchannel subchannel : subchannels.values()) { + deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); + } + inOrder.verify(helper, times(0)).updateBalancingState(eq(READY), any(SubchannelPicker.class)); + + // Health results lets subchannels go READY + healthListeners.get(subchannel0).onSubchannelState(ConnectivityStateInfo.forNonError(READY)); + healthListeners.get(subchannel1).onSubchannelState(ConnectivityStateInfo.forNonError(READY)); + inOrder.verify(helper, times(2)).updateBalancingState(eq(READY), pickerCaptor.capture()); + SubchannelPicker picker = pickerCaptor.getValue(); + Random random = new Random(1); + Set picks = new HashSet<>(); + for (int i = 0; i < 10; i++) { + picks.add( + picker.pickSubchannel(getDefaultPickSubchannelArgs(random.nextLong())).getSubchannel()); + } + assertThat(picks).containsExactly(subchannel0, subchannel1); + + // Unhealthy subchannel skipped + healthListeners.get(subchannel0).onSubchannelState( + ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE.withDescription("oh no"))); + inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); + picker = pickerCaptor.getValue(); + random.setSeed(1); + picks.clear(); + for (int i = 0; i < 10; i++) { + picks.add( + picker.pickSubchannel(getDefaultPickSubchannelArgs(random.nextLong())).getSubchannel()); + } + assertThat(picks).containsExactly(subchannel1); + } + + @Test + public void config_equalsTester() { + new EqualsTester() + .addEqualityGroup( + new RingHashConfig(1, 2, "headerA"), + new RingHashConfig(1, 2, "headerA")) + .addEqualityGroup(new RingHashConfig(1, 1, "headerA")) + .addEqualityGroup(new RingHashConfig(2, 2, "headerA")) + .addEqualityGroup(new RingHashConfig(1, 2, "headerB")) + .addEqualityGroup(new RingHashConfig(1, 2, "")) + .testEquals(); + } + + @Test + public void tfWithoutConnectingChild_triggersIdleChildConnection() { + RingHashConfig config = new RingHashConfig(10, 100, ""); + List servers = createWeightedServerAddrs(1, 1); + + initializeLbSubchannels(config, servers); + + Subchannel tfSubchannel = getSubchannel(servers, 0); + Subchannel idleSubchannel = getSubchannel(servers, 1); + + deliverSubchannelUnreachable(tfSubchannel); + + Subchannel requested = connectionRequestedQueue.poll(); + assertThat(requested).isSameInstanceAs(idleSubchannel); + assertThat(connectionRequestedQueue.poll()).isNull(); + } + + @Test + public void tfWithReadyChild_doesNotTriggerIdleChildConnection() { + RingHashConfig config = new RingHashConfig(10, 100, ""); + List servers = createWeightedServerAddrs(1, 1, 1); + + initializeLbSubchannels(config, servers); + + Subchannel tfSubchannel = getSubchannel(servers, 0); + Subchannel readySubchannel = getSubchannel(servers, 1); + + deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelUnreachable(tfSubchannel); + + assertThat(connectionRequestedQueue.poll()).isNull(); + } + private List initializeLbSubchannels(RingHashConfig config, List servers, InitializationFlags... initFlags) { @@ -929,9 +1203,8 @@ private List initializeLbSubchannels(RingHashConfig config, // Activate them all to create the child LB and subchannel for (ChildLbState childLbState : loadBalancer.getChildLbStates()) { - ((RingHashChildLbState)childLbState).activate(); - assertThat(childLbState.getResolvedAddresses().getAttributes().get(IS_PETIOLE_POLICY)) - .isTrue(); + childLbState.getCurrentPicker() + .pickSubchannel(getDefaultPickSubchannelArgs(hashFunc.hashVoid())); } if (doVerifies) { @@ -995,7 +1268,7 @@ private static List createWeightedServerAddrs(long... we for (int i = 0; i < weights.length; i++) { SocketAddress addr = new FakeSocketAddress("server" + i); Attributes attr = Attributes.newBuilder().set( - InternalXdsAttributes.ATTR_SERVER_WEIGHT, weights[i]).build(); + XdsAttributes.ATTR_SERVER_WEIGHT, weights[i]).build(); EquivalentAddressGroup eag = new EquivalentAddressGroup(addr, attr); addrs.add(eag); } @@ -1041,6 +1314,9 @@ public String toString() { } private class TestHelper extends AbstractTestHelper { + public TestHelper() { + super(new FakeClock(), syncContext); + } @Override public Map, Subchannel> getSubchannelMap() { @@ -1052,11 +1328,6 @@ public String getAuthority() { return AUTHORITY; } - @Override - public SynchronizationContext getSynchronizationContext() { - return syncContext; - } - private Subchannel getMockSubchannel(Subchannel realSubchannel) { return realToMockSubChannelMap.get(realSubchannel); } @@ -1080,6 +1351,30 @@ public void requestConnection() { } } + private static final class FakeRandom implements ThreadSafeRandom { + int counter = 0; + + @Override + public int nextInt(int bound) { + throw new UnsupportedOperationException("Should not be called"); + } + + @Override + public long nextLong() { + ++counter; + if (counter % 2 == 0) { + return XxHash64.INSTANCE.hashAsciiString("FakeSocketAddress-server0_0"); + } else { + return XxHash64.INSTANCE.hashAsciiString("FakeSocketAddress-server1_0"); + } + } + + @Override + public long nextLong(long bound) { + throw new UnsupportedOperationException("Should not be called"); + } + } + enum InitializationFlags { DO_NOT_VERIFY, RESET_SUBCHANNEL_MOCKS, diff --git a/xds/src/test/java/io/grpc/xds/RouterFilterTest.java b/xds/src/test/java/io/grpc/xds/RouterFilterTest.java new file mode 100644 index 00000000000..30fd8a6dc38 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/RouterFilterTest.java @@ -0,0 +1,36 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link RouterFilter}. */ +@RunWith(JUnit4.class) +public class RouterFilterTest { + private static final RouterFilter.Provider FILTER_PROVIDER = new RouterFilter.Provider(); + + @Test + public void filterType_clientAndServer() { + assertThat(FILTER_PROVIDER.isClientFilter()).isTrue(); + assertThat(FILTER_PROVIDER.isServerFilter()).isTrue(); + } + +} diff --git a/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java b/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java index 0687b51aea6..29b149f166f 100644 --- a/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java @@ -18,22 +18,39 @@ import static com.google.common.truth.Truth.assertThat; +import static io.grpc.Metadata.ASCII_STRING_MARSHALLER; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; +import com.google.auth.oauth2.AccessToken; +import com.google.auth.oauth2.OAuth2Credentials; +import com.google.common.util.concurrent.SettableFuture; +import io.grpc.CallCredentials; +import io.grpc.Grpc; import io.grpc.InsecureChannelCredentials; +import io.grpc.InsecureServerCredentials; +import io.grpc.Metadata; +import io.grpc.MetricRecorder; +import io.grpc.Server; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.auth.MoreCallCredentials; import io.grpc.internal.ObjectPool; import io.grpc.xds.SharedXdsClientPoolProvider.RefCountedXdsClientObjectPool; +import io.grpc.xds.XdsListenerResource.LdsUpdate; import io.grpc.xds.client.Bootstrapper.BootstrapInfo; import io.grpc.xds.client.Bootstrapper.ServerInfo; import io.grpc.xds.client.EnvoyProtoData.Node; import io.grpc.xds.client.XdsClient; +import io.grpc.xds.client.XdsClient.ResourceWatcher; import io.grpc.xds.client.XdsInitializationException; import java.util.Collections; +import java.util.concurrent.TimeUnit; import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.Mock; @@ -47,24 +64,34 @@ public class SharedXdsClientPoolProviderTest { private static final String SERVER_URI = "trafficdirector.googleapis.com"; @Rule public final MockitoRule mocks = MockitoJUnit.rule(); - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); private final Node node = Node.newBuilder().setId("SharedXdsClientPoolProviderTest").build(); + private final MetricRecorder metricRecorder = new MetricRecorder() {}; + private static final String DUMMY_TARGET = "dummy"; + static final Metadata.Key AUTHORIZATION_METADATA_KEY = + Metadata.Key.of("Authorization", ASCII_STRING_MARSHALLER); @Mock private GrpcBootstrapperImpl bootstrapper; + @Mock private ResourceWatcher ldsResourceWatcher; + @Deprecated @Test - public void noServer() throws XdsInitializationException { + public void sharedXdsClientObjectPool_deprecated() throws XdsInitializationException { + ServerInfo server = ServerInfo.create(SERVER_URI, InsecureChannelCredentials.create()); BootstrapInfo bootstrapInfo = - BootstrapInfo.builder().servers(Collections.emptyList()).node(node).build(); + BootstrapInfo.builder().servers(Collections.singletonList(server)).node(node).build(); when(bootstrapper.bootstrap()).thenReturn(bootstrapInfo); + SharedXdsClientPoolProvider provider = new SharedXdsClientPoolProvider(bootstrapper); - thrown.expect(XdsInitializationException.class); - thrown.expectMessage("No xDS server provided"); - provider.getOrCreate(); - assertThat(provider.get()).isNull(); + assertThat(provider.get(DUMMY_TARGET)).isNull(); + ObjectPool xdsClientPool = + provider.getOrCreate(DUMMY_TARGET, metricRecorder, null); + verify(bootstrapper).bootstrap(); + assertThat(provider.getOrCreate(DUMMY_TARGET, bootstrapInfo, metricRecorder)) + .isSameInstanceAs(xdsClientPool); + assertThat(provider.get(DUMMY_TARGET)).isNotNull(); + assertThat(provider.get(DUMMY_TARGET)).isSameInstanceAs(xdsClientPool); + verifyNoMoreInteractions(bootstrapper); } @Test @@ -72,15 +99,16 @@ public void sharedXdsClientObjectPool() throws XdsInitializationException { ServerInfo server = ServerInfo.create(SERVER_URI, InsecureChannelCredentials.create()); BootstrapInfo bootstrapInfo = BootstrapInfo.builder().servers(Collections.singletonList(server)).node(node).build(); - when(bootstrapper.bootstrap()).thenReturn(bootstrapInfo); SharedXdsClientPoolProvider provider = new SharedXdsClientPoolProvider(bootstrapper); - assertThat(provider.get()).isNull(); - ObjectPool xdsClientPool = provider.getOrCreate(); - verify(bootstrapper).bootstrap(); - assertThat(provider.getOrCreate()).isSameInstanceAs(xdsClientPool); - assertThat(provider.get()).isNotNull(); - assertThat(provider.get()).isSameInstanceAs(xdsClientPool); + assertThat(provider.get(DUMMY_TARGET)).isNull(); + ObjectPool xdsClientPool = + provider.getOrCreate(DUMMY_TARGET, bootstrapInfo, metricRecorder); + verify(bootstrapper, never()).bootstrap(); + assertThat(provider.getOrCreate(DUMMY_TARGET, bootstrapInfo, metricRecorder)) + .isSameInstanceAs(xdsClientPool); + assertThat(provider.get(DUMMY_TARGET)).isNotNull(); + assertThat(provider.get(DUMMY_TARGET)).isSameInstanceAs(xdsClientPool); verifyNoMoreInteractions(bootstrapper); } @@ -89,8 +117,9 @@ public void refCountedXdsClientObjectPool_delayedCreation() { ServerInfo server = ServerInfo.create(SERVER_URI, InsecureChannelCredentials.create()); BootstrapInfo bootstrapInfo = BootstrapInfo.builder().servers(Collections.singletonList(server)).node(node).build(); + SharedXdsClientPoolProvider provider = new SharedXdsClientPoolProvider(); RefCountedXdsClientObjectPool xdsClientPool = - new RefCountedXdsClientObjectPool(bootstrapInfo); + provider.new RefCountedXdsClientObjectPool(bootstrapInfo, DUMMY_TARGET, metricRecorder); assertThat(xdsClientPool.getXdsClientForTest()).isNull(); XdsClient xdsClient = xdsClientPool.getObject(); assertThat(xdsClientPool.getXdsClientForTest()).isNotNull(); @@ -102,8 +131,9 @@ public void refCountedXdsClientObjectPool_refCounted() { ServerInfo server = ServerInfo.create(SERVER_URI, InsecureChannelCredentials.create()); BootstrapInfo bootstrapInfo = BootstrapInfo.builder().servers(Collections.singletonList(server)).node(node).build(); + SharedXdsClientPoolProvider provider = new SharedXdsClientPoolProvider(); RefCountedXdsClientObjectPool xdsClientPool = - new RefCountedXdsClientObjectPool(bootstrapInfo); + provider.new RefCountedXdsClientObjectPool(bootstrapInfo, DUMMY_TARGET, metricRecorder); // getObject once XdsClient xdsClient = xdsClientPool.getObject(); assertThat(xdsClient).isNotNull(); @@ -122,8 +152,9 @@ public void refCountedXdsClientObjectPool_getObjectCreatesNewInstanceIfAlreadySh ServerInfo server = ServerInfo.create(SERVER_URI, InsecureChannelCredentials.create()); BootstrapInfo bootstrapInfo = BootstrapInfo.builder().servers(Collections.singletonList(server)).node(node).build(); + SharedXdsClientPoolProvider provider = new SharedXdsClientPoolProvider(); RefCountedXdsClientObjectPool xdsClientPool = - new RefCountedXdsClientObjectPool(bootstrapInfo); + provider.new RefCountedXdsClientObjectPool(bootstrapInfo, DUMMY_TARGET, metricRecorder); XdsClient xdsClient1 = xdsClientPool.getObject(); assertThat(xdsClientPool.returnObject(xdsClient1)).isNull(); assertThat(xdsClient1.isShutDown()).isTrue(); @@ -132,4 +163,61 @@ public void refCountedXdsClientObjectPool_getObjectCreatesNewInstanceIfAlreadySh assertThat(xdsClient2).isNotSameInstanceAs(xdsClient1); xdsClientPool.returnObject(xdsClient2); } + + private class CallCredsServerInterceptor implements ServerInterceptor { + private SettableFuture tokenFuture = SettableFuture.create(); + + @Override + public ServerCall.Listener interceptCall( + ServerCall serverCall, + Metadata metadata, + ServerCallHandler next) { + tokenFuture.set(metadata.get(AUTHORIZATION_METADATA_KEY)); + return next.startCall(serverCall, metadata); + } + + public String getTokenWithTimeout(long timeout, TimeUnit unit) throws Exception { + return tokenFuture.get(timeout, unit); + } + } + + @Test + public void xdsClient_usesCallCredentials() throws Exception { + // Set up fake xDS server + XdsTestControlPlaneService fakeXdsService = new XdsTestControlPlaneService(); + CallCredsServerInterceptor callCredentialsInterceptor = new CallCredsServerInterceptor(); + Server xdsServer = + Grpc.newServerBuilderForPort(0, InsecureServerCredentials.create()) + .addService(fakeXdsService) + .intercept(callCredentialsInterceptor) + .build() + .start(); + String xdsServerUri = "localhost:" + xdsServer.getPort(); + + // Set up bootstrap & xDS client pool provider + ServerInfo server = ServerInfo.create(xdsServerUri, InsecureChannelCredentials.create()); + BootstrapInfo bootstrapInfo = + BootstrapInfo.builder().servers(Collections.singletonList(server)).node(node).build(); + SharedXdsClientPoolProvider provider = new SharedXdsClientPoolProvider(); + + // Create custom xDS transport CallCredentials + CallCredentials sampleCreds = + MoreCallCredentials.from( + OAuth2Credentials.create(new AccessToken("token", /* expirationTime= */ null))); + + // Create xDS client that uses the CallCredentials on the transport + ObjectPool xdsClientPool = + provider.getOrCreate("target", bootstrapInfo, metricRecorder, sampleCreds); + XdsClient xdsClient = xdsClientPool.getObject(); + xdsClient.watchXdsResource( + XdsListenerResource.getInstance(), "someLDSresource", ldsResourceWatcher); + + // Wait for xDS server to get the request and verify that it received the CallCredentials + assertThat(callCredentialsInterceptor.getTokenWithTimeout(5, TimeUnit.SECONDS)) + .isEqualTo("Bearer token"); + + // Clean up + xdsClientPool.returnObject(xdsClient); + xdsServer.shutdownNow(); + } } diff --git a/xds/src/test/java/io/grpc/xds/StatefulFilter.java b/xds/src/test/java/io/grpc/xds/StatefulFilter.java new file mode 100644 index 00000000000..a43ef14f8d4 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/StatefulFilter.java @@ -0,0 +1,176 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.collect.ImmutableList.toImmutableList; + +import com.google.common.collect.ImmutableList; +import com.google.protobuf.Message; +import io.grpc.ServerInterceptor; +import java.util.ConcurrentModificationException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.IntStream; +import javax.annotation.Nullable; + +/** + * Unlike most singleton-based filters, each StatefulFilter object has a distinct identity. + */ +class StatefulFilter implements Filter { + + static final String DEFAULT_TYPE_URL = "type.googleapis.com/grpc.test.StatefulFilter"; + private final AtomicBoolean shutdown = new AtomicBoolean(); + + final int idx; + @Nullable volatile String lastCfg = null; + + public StatefulFilter(int idx) { + this.idx = idx; + } + + public boolean isShutdown() { + return shutdown.get(); + } + + @Override + public void close() { + if (!shutdown.compareAndSet(false, true)) { + throw new ConcurrentModificationException( + "Unexpected: StatefulFilter#close called multiple times"); + } + } + + @Nullable + @Override + public ServerInterceptor buildServerInterceptor( + FilterConfig config, + @Nullable FilterConfig overrideConfig) { + Config cfg = (Config) config; + // TODO(sergiitk): to be replaced when name argument passed to the constructor. + lastCfg = cfg.getConfig(); + return null; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder().append("StatefulFilter{") + .append("idx=").append(idx); + if (lastCfg != null) { + sb.append(", name=").append(lastCfg); + } + return sb.append("}").toString(); + } + + static final class Provider implements Filter.Provider { + + private final String typeUrl; + private final ConcurrentMap instances = new ConcurrentHashMap<>(); + + volatile int counter; + + Provider() { + this(DEFAULT_TYPE_URL); + } + + Provider(String typeUrl) { + this.typeUrl = typeUrl; + } + + @Override + public String[] typeUrls() { + return new String[]{ typeUrl }; + } + + @Override + public boolean isClientFilter() { + return true; + } + + @Override + public boolean isServerFilter() { + return true; + } + + @Override + public synchronized StatefulFilter newInstance(String name) { + StatefulFilter filter = new StatefulFilter(counter++); + instances.put(filter.idx, filter); + return filter; + } + + public synchronized StatefulFilter getInstance(int idx) { + return instances.get(idx); + } + + public synchronized ImmutableList getAllInstances() { + return IntStream.range(0, counter).mapToObj(this::getInstance).collect(toImmutableList()); + } + + @SuppressWarnings("UnusedMethod") + public synchronized int getCount() { + return counter; + } + + @Override + public ConfigOrError parseFilterConfig(Message rawProtoMessage, + FilterConfigParseContext context) { + return ConfigOrError.fromConfig(Config.fromProto(rawProtoMessage, typeUrl)); + } + + @Override + public ConfigOrError parseFilterConfigOverride( + Message rawProtoMessage, FilterConfigParseContext context) { + return ConfigOrError.fromConfig(Config.fromProto(rawProtoMessage, typeUrl)); + } + } + + + static final class Config implements FilterConfig { + + private final String typeUrl; + private final String config; + + public Config(String config, String typeUrl) { + this.config = config; + this.typeUrl = typeUrl; + } + + public Config(String config) { + this(config, DEFAULT_TYPE_URL); + } + + public Config() { + this("", DEFAULT_TYPE_URL); + } + + public static Config fromProto(Message rawProtoMessage, String typeUrl) { + checkNotNull(rawProtoMessage, "rawProtoMessage"); + return new Config(rawProtoMessage.toString(), typeUrl); + } + + public String getConfig() { + return config; + } + + @Override + public String typeUrl() { + return typeUrl; + } + } +} diff --git a/xds/src/test/java/io/grpc/xds/WeightedRandomPickerTest.java b/xds/src/test/java/io/grpc/xds/WeightedRandomPickerTest.java index d6240fb09bb..691615762bf 100644 --- a/xds/src/test/java/io/grpc/xds/WeightedRandomPickerTest.java +++ b/xds/src/test/java/io/grpc/xds/WeightedRandomPickerTest.java @@ -17,6 +17,7 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import static org.mockito.Mockito.mock; import io.grpc.LoadBalancer.PickResult; @@ -30,7 +31,6 @@ import java.util.List; import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.Mock; @@ -42,9 +42,6 @@ */ @RunWith(JUnit4.class) public class WeightedRandomPickerTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); @Rule public final MockitoRule mockitoRule = MockitoJUnit.rule(); @@ -128,20 +125,18 @@ public long nextLong(long bound) { public void emptyList() { List emptyList = new ArrayList<>(); - thrown.expect(IllegalArgumentException.class); - new WeightedRandomPicker(emptyList); + assertThrows(IllegalArgumentException.class, () -> new WeightedRandomPicker(emptyList)); } @Test public void negativeWeight() { - thrown.expect(IllegalArgumentException.class); - new WeightedChildPicker(-1, childPicker0); + assertThrows(IllegalArgumentException.class, () -> new WeightedChildPicker(-1, childPicker0)); } @Test public void overWeightSingle() { - thrown.expect(IllegalArgumentException.class); - new WeightedChildPicker(Integer.MAX_VALUE * 3L, childPicker0); + assertThrows(IllegalArgumentException.class, + () -> new WeightedChildPicker(Integer.MAX_VALUE * 3L, childPicker0)); } @Test @@ -152,8 +147,8 @@ public void overWeightAggregate() { new WeightedChildPicker(Integer.MAX_VALUE, childPicker1), new WeightedChildPicker(10, childPicker2)); - thrown.expect(IllegalArgumentException.class); - new WeightedRandomPicker(weightedChildPickers, fakeRandom); + assertThrows(IllegalArgumentException.class, + () -> new WeightedRandomPicker(weightedChildPickers, fakeRandom)); } @Test diff --git a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProviderTest.java b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProviderTest.java index ddde84ca842..0bd3283cb79 100644 --- a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProviderTest.java @@ -29,6 +29,7 @@ import io.grpc.internal.FakeClock; import io.grpc.internal.JsonParser; import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancerConfig; +import io.grpc.xds.internal.MetricReportUtils.ParsedMetricName; import java.io.IOException; import java.util.Map; import org.junit.Test; @@ -111,6 +112,25 @@ public void parseLoadBalancingConfigDefaultValues() throws IOException { assertThat(config.errorUtilizationPenalty).isEqualTo(1.0F); } + @Test + public void parseLoadBalancingConfigCustomMetricsIgnoresInvalid() throws IOException { + System.setProperty("GRPC_EXPERIMENTAL_WRR_CUSTOM_METRICS", "true"); + try { + String lbConfig = + "{\"metricNamesForComputingUtilization\" : " + + "[\"utilization.foo\", \"invalid_name\", \"named_metrics.bar\"]}"; + ConfigOrError configOrError = provider.parseLoadBalancingPolicyConfig( + parseJsonObject(lbConfig)); + assertThat(configOrError.getConfig()).isNotNull(); + WeightedRoundRobinLoadBalancerConfig config = + (WeightedRoundRobinLoadBalancerConfig) configOrError.getConfig(); + assertThat(config.parsedMetricNamesForComputingUtilization).containsExactly( + ParsedMetricName.parse("utilization.foo"), ParsedMetricName.parse("named_metrics.bar")); + } finally { + System.clearProperty("GRPC_EXPERIMENTAL_WRR_CUSTOM_METRICS"); + } + } + @SuppressWarnings("unchecked") private static Map parseJsonObject(String json) throws IOException { diff --git a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java index 41847d21d87..bac62d1a103 100644 --- a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java @@ -19,9 +19,13 @@ import static com.google.common.truth.Truth.assertThat; import static io.grpc.ConnectivityState.CONNECTING; import static org.mockito.AdditionalAnswers.delegatesTo; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.eq; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -29,16 +33,22 @@ import com.github.xds.data.orca.v3.OrcaLoadReport; import com.github.xds.service.orca.v3.OrcaLoadReportRequest; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.google.common.collect.Maps; +import com.google.common.testing.EqualsTester; import com.google.protobuf.Duration; import io.grpc.Attributes; +import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.ClientCall; +import io.grpc.ClientStreamTracer; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; +import io.grpc.DoubleHistogramMetricInstrument; import io.grpc.EquivalentAddressGroup; +import io.grpc.InternalManagedChannelBuilder; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.Helper; @@ -47,18 +57,36 @@ import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancer.SubchannelStateListener; +import io.grpc.LongCounterMetricInstrument; +import io.grpc.Metadata; +import io.grpc.MetricRecorder; +import io.grpc.MetricSink; +import io.grpc.NameResolver; +import io.grpc.NoopMetricSink; +import io.grpc.ServerCall; +import io.grpc.ServerServiceDefinition; import io.grpc.Status; import io.grpc.SynchronizationContext; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; import io.grpc.internal.FakeClock; +import io.grpc.internal.PickFirstLoadBalancerProvider; import io.grpc.internal.TestUtils; +import io.grpc.internal.testing.StreamRecorder; +import io.grpc.protobuf.ProtoUtils; import io.grpc.services.InternalCallMetricRecorder; import io.grpc.services.MetricReport; +import io.grpc.stub.ClientCalls; +import io.grpc.stub.StreamObserver; +import io.grpc.testing.GrpcCleanupRule; +import io.grpc.testing.TestMethodDescriptors; import io.grpc.util.AbstractTestHelper; import io.grpc.util.MultiChildLoadBalancer.ChildLbState; import io.grpc.xds.WeightedRoundRobinLoadBalancer.StaticStrideScheduler; import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedChildLbState; import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancerConfig; import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinPicker; +import io.grpc.xds.orca.OrcaOobUtilAccessor; import java.net.SocketAddress; import java.util.Arrays; import java.util.Collections; @@ -71,7 +99,6 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CyclicBarrier; -import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import org.junit.Before; @@ -80,6 +107,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatcher; import org.mockito.Captor; import org.mockito.InOrder; import org.mockito.Mock; @@ -93,9 +121,11 @@ public class WeightedRoundRobinLoadBalancerTest { @Rule public final MockitoRule mockito = MockitoJUnit.rule(); + @Rule + public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule(); - private final TestHelper testHelperInstance = new TestHelper(); - private Helper helper = mock(Helper.class, delegatesTo(testHelperInstance)); + private final TestHelper testHelperInstance; + private final Helper helper; @Mock private LoadBalancer.PickSubchannelArgs mockArgs; @@ -107,9 +137,6 @@ public class WeightedRoundRobinLoadBalancerTest { private final List servers = Lists.newArrayList(); private final Map, Subchannel> subchannels = Maps.newLinkedHashMap(); - private final Map mockToRealSubChannelMap = new HashMap<>(); - private final Map subchannelStateListeners = - Maps.newLinkedHashMap(); private final Queue> oobCalls = new ConcurrentLinkedQueue<>(); @@ -118,6 +145,9 @@ public class WeightedRoundRobinLoadBalancerTest { private final FakeClock fakeClock = new FakeClock(); + @Mock + private MetricRecorder mockMetricRecorder; + private WeightedRoundRobinLoadBalancerConfig weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder().build(); @@ -134,6 +164,19 @@ public void uncaughtException(Thread t, Throwable e) { } }); + private String channelTarget = "channel-target"; + private String locality = "locality"; + private String backendService = "the-backend-service"; + + public WeightedRoundRobinLoadBalancerTest() { + testHelperInstance = new TestHelper(); + helper = mock(Helper.class, delegatesTo(testHelperInstance)); + } + + private static WeightedRoundRobinPicker getWrrPicker(SubchannelPicker picker) { + return (WeightedRoundRobinPicker) OrcaOobUtilAccessor.getDelegate(picker); + } + @Before public void setup() { for (int i = 0; i < 3; i++) { @@ -154,13 +197,14 @@ public ClientCall answer( return clientCall; } }); - testHelperInstance.setChannel(mockToRealSubChannelMap.get(sc), channel); + testHelperInstance.setChannel(sc, channel); subchannels.put(Arrays.asList(eag), sc); } wrr = new WeightedRoundRobinLoadBalancer(helper, fakeClock.getDeadlineTicker(), new FakeRandom(0)); verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); + reset(helper); } @Test @@ -174,9 +218,42 @@ public void pickChildLbTF() throws Exception { .forTransientFailure(Status.UNAVAILABLE)); verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - final WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getValue(); - weightedPicker.pickSubchannel(mockArgs); + final SubchannelPicker picker = pickerCaptor.getValue(); + picker.pickSubchannel(mockArgs); + } + + @Test + public void config_equalsTester() { + WeightedRoundRobinLoadBalancerConfig defaults = + WeightedRoundRobinLoadBalancerConfig.newBuilder().build(); + new EqualsTester() + .addEqualityGroup( + WeightedRoundRobinLoadBalancerConfig.newBuilder().build(), + WeightedRoundRobinLoadBalancerConfig.newBuilder().build(), + WeightedRoundRobinLoadBalancerConfig.newBuilder() + .setBlackoutPeriodNanos(defaults.blackoutPeriodNanos).build()) + .addEqualityGroup( + WeightedRoundRobinLoadBalancerConfig.newBuilder() + .setBlackoutPeriodNanos(5).build()) + .addEqualityGroup( + WeightedRoundRobinLoadBalancerConfig.newBuilder() + .setWeightExpirationPeriodNanos(5).build()) + .addEqualityGroup( + WeightedRoundRobinLoadBalancerConfig.newBuilder() + .setEnableOobLoadReport(true).build()) + .addEqualityGroup( + WeightedRoundRobinLoadBalancerConfig.newBuilder() + .setOobReportingPeriodNanos(5).build()) + .addEqualityGroup( + WeightedRoundRobinLoadBalancerConfig.newBuilder() + .setWeightUpdatePeriodNanos(5).build()) + .addEqualityGroup( + WeightedRoundRobinLoadBalancerConfig.newBuilder() + .setErrorUtilizationPenalty(0.5F).build()) + .addEqualityGroup( + WeightedRoundRobinLoadBalancerConfig.newBuilder() + .setErrorUtilizationPenalty(Float.NaN).build()) + .testEquals(); } @Test @@ -184,9 +261,9 @@ public void wrrLifeCycle() { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(6)).createSubchannel( + verify(helper, times(3)).createSubchannel( any(CreateSubchannelArgs.class)); - assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); + assertThat(getNumFilteredPendingTasks()).isEqualTo(1); Iterator it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); @@ -202,26 +279,29 @@ public void wrrLifeCycle() { eq(ConnectivityState.READY), pickerCaptor.capture()); assertThat(pickerCaptor.getAllValues().size()).isEqualTo(2); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(0); + getWrrPicker(pickerCaptor.getAllValues().get(0)); assertThat(weightedPicker.getChildren().size()).isEqualTo(1); - weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); + weightedPicker = getWrrPicker(pickerCaptor.getAllValues().get(1)); assertThat(weightedPicker.getChildren().size()).isEqualTo(2); String weightedPickerStr = weightedPicker.toString(); assertThat(weightedPickerStr).contains("enableOobLoadReport=false"); assertThat(weightedPickerStr).contains("errorUtilizationPenalty=1.0"); - assertThat(weightedPickerStr).contains("list="); + assertThat(weightedPickerStr).contains("pickers="); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1); + int expectedTasks = isEnabledHappyEyeballs() ? 2 : 1; + assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(expectedTasks); - assertThat(getAddressesFromPick(weightedPicker)).isEqualTo(weightedChild1.getEag()); + assertThat(getAddressesFromPick(weightedPicker)).isEqualTo(servers.get(0)); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder() .setWeightUpdatePeriodNanos(500_000_000L) //.5s @@ -229,13 +309,13 @@ weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).on syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); + assertThat(getNumFilteredPendingTasks()).isEqualTo(1); syncContext.execute(() -> wrr.shutdown()); for (Subchannel subchannel: subchannels.values()) { verify(subchannel).shutdown(); } - assertThat(fakeClock.getPendingTasks().size()).isEqualTo(0); + assertThat(getNumFilteredPendingTasks()).isEqualTo(0); verifyNoMoreInteractions(mockArgs); } @@ -252,7 +332,7 @@ public void enableOobLoadReportConfig() { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(6)).createSubchannel( + verify(helper, times(3)).createSubchannel( any(CreateSubchannelArgs.class)); Iterator it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); @@ -264,19 +344,21 @@ public void enableOobLoadReportConfig() { verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); + getWrrPicker(pickerCaptor.getAllValues().get(1)); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.9, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1); + int expectedTasks = isEnabledHappyEyeballs() ? 2 : 1; + assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(expectedTasks); PickResult pickResult = weightedPicker.pickSubchannel(mockArgs); - assertThat(getAddresses(pickResult)) - .isEqualTo(weightedChild1.getEag()); + assertThat(getAddresses(pickResult)).isEqualTo(servers.get(0)); assertThat(pickResult.getStreamTracerFactory()).isNotNull(); // verify per-request listener assertThat(oobCalls.isEmpty()).isTrue(); @@ -288,10 +370,9 @@ weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).on .setAttributes(affinity).build())); verify(helper, times(3)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor2.capture()); - weightedPicker = (WeightedRoundRobinPicker) pickerCaptor2.getAllValues().get(2); - pickResult = weightedPicker.pickSubchannel(mockArgs); - assertThat(getAddresses(pickResult)) - .isEqualTo(weightedChild1.getEag()); + SubchannelPicker rawPicker = pickerCaptor2.getAllValues().get(2); + pickResult = rawPicker.pickSubchannel(mockArgs); + assertThat(getAddresses(pickResult)).isEqualTo(servers.get(0)); assertThat(pickResult.getStreamTracerFactory()).isNull(); OrcaLoadReportRequest golden = OrcaLoadReportRequest.newBuilder().setReportInterval( Duration.newBuilder().setSeconds(20).setNanos(30000000).build()).build(); @@ -306,9 +387,9 @@ private void pickByWeight(MetricReport r1, MetricReport r2, MetricReport r3, syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(6)).createSubchannel( + verify(helper, times(3)).createSubchannel( any(CreateSubchannelArgs.class)); - assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); + assertThat(getNumFilteredPendingTasks()).isEqualTo(1); Iterator it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); @@ -323,13 +404,16 @@ private void pickByWeight(MetricReport r1, MetricReport r2, MetricReport r3, verify(helper, times(3)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(2); + getWrrPicker(pickerCaptor.getAllValues().get(2)); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); WeightedChildLbState weightedChild3 = (WeightedChildLbState) getChild(weightedPicker, 2); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(r1); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(r2); - weightedChild3.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(r3); + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport(r1); + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport(r2); + weightedChild3.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport(r3); assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1); Map pickCount = new HashMap<>(); @@ -338,16 +422,16 @@ private void pickByWeight(MetricReport r1, MetricReport r2, MetricReport r3, pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); } assertThat(pickCount.size()).isEqualTo(3); - assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 10000.0 - subchannel1PickRatio)) + assertThat(Math.abs(pickCount.get(servers.get(0)) / 10000.0 - subchannel1PickRatio)) .isAtMost(0.0002); - assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 10000.0 - subchannel2PickRatio )) + assertThat(Math.abs(pickCount.get(servers.get(1)) / 10000.0 - subchannel2PickRatio )) .isAtMost(0.0002); - assertThat(Math.abs(pickCount.get(weightedChild3.getEag()) / 10000.0 - subchannel3PickRatio )) + assertThat(Math.abs(pickCount.get(servers.get(2)) / 10000.0 - subchannel3PickRatio )) .isAtMost(0.0002); } private SubchannelStateListener getSubchannelStateListener(Subchannel mockSubChannel) { - return subchannelStateListeners.get(mockToRealSubChannelMap.get(mockSubChannel)); + return testHelperInstance.getSubchannelStateListener(mockSubChannel); } private static ChildLbState getChild(WeightedRoundRobinPicker picker, int index) { @@ -489,19 +573,20 @@ public void emptyConfig() { assertThat(wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(null) .setAttributes(affinity).build()).isOk()).isFalse(); - verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); + verify(helper, never()).createSubchannel(any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); assertThat(fakeClock.getPendingTasks()).isEmpty(); syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(6)).createSubchannel( + verify(helper, times(3)).createSubchannel( any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); - assertThat(pickerCaptor.getValue().getClass().getName()) - .isEqualTo("io.grpc.util.RoundRobinLoadBalancer$EmptyPicker"); - assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1); + assertThat(pickerCaptor.getValue().pickSubchannel(mockArgs)) + .isEqualTo(PickResult.withNoResult()); + int expectedCount = isEnabledHappyEyeballs() ? servers.size() + 1 : 1; + assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo( expectedCount); } @Test @@ -509,9 +594,8 @@ public void blackoutPeriod() { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(6)).createSubchannel( - any(CreateSubchannelArgs.class)); - assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); + verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); + assertThat(getNumFilteredPendingTasks()).isEqualTo(1); Iterator it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); @@ -523,16 +607,19 @@ public void blackoutPeriod() { verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); + getWrrPicker(pickerCaptor.getAllValues().get(1)); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - assertThat(fakeClock.forwardTime(5, TimeUnit.SECONDS)).isEqualTo(1); + int expectedCount = isEnabledHappyEyeballs() ? 2 : 1; + assertThat(fakeClock.forwardTime(5, TimeUnit.SECONDS)).isEqualTo(expectedCount); Map pickCount = new HashMap<>(); for (int i = 0; i < 10000; i++) { EquivalentAddressGroup result = getAddressesFromPick(weightedPicker); @@ -540,8 +627,8 @@ weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).on } assertThat(pickCount.size()).isEqualTo(2); // within blackout period, fallback to simple round robin - assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 10000.0 - 0.5)).isLessThan(0.002); - assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 10000.0 - 0.5)).isLessThan(0.002); + assertThat(Math.abs(pickCount.get(servers.get(0)) / 10000.0 - 0.5)).isLessThan(0.002); + assertThat(Math.abs(pickCount.get(servers.get(1)) / 10000.0 - 0.5)).isLessThan(0.002); assertThat(fakeClock.forwardTime(5, TimeUnit.SECONDS)).isEqualTo(1); pickCount = new HashMap<>(); @@ -551,10 +638,12 @@ weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).on } assertThat(pickCount.size()).isEqualTo(2); // after blackout period - assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 10000.0 - 2.0 / 3)) - .isLessThan(0.002); - assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 10000.0 - 1.0 / 3)) - .isLessThan(0.002); + assertThat(Math.abs(pickCount.get(servers.get(0)) / 10000.0 - 2.0 / 3)).isLessThan(0.002); + assertThat(Math.abs(pickCount.get(servers.get(1)) / 10000.0 - 1.0 / 3)).isLessThan(0.002); + } + + private boolean isEnabledHappyEyeballs() { + return PickFirstLoadBalancerProvider.isEnabledHappyEyeballs(); } @Test @@ -562,9 +651,9 @@ public void updateWeightTimer() { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(6)).createSubchannel( + verify(helper, times(3)).createSubchannel( any(CreateSubchannelArgs.class)); - assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); + assertThat(getNumFilteredPendingTasks()).isEqualTo(1); Iterator it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); @@ -580,41 +669,44 @@ public void updateWeightTimer() { eq(ConnectivityState.READY), pickerCaptor.capture()); assertThat(pickerCaptor.getAllValues().size()).isEqualTo(2); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(0); + getWrrPicker(pickerCaptor.getAllValues().get(0)); assertThat(weightedPicker.getChildren().size()).isEqualTo(1); - weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); + weightedPicker = getWrrPicker(pickerCaptor.getAllValues().get(1)); assertThat(weightedPicker.getChildren().size()).isEqualTo(2); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1); - assertThat(getAddressesFromPick(weightedPicker)) - .isEqualTo(weightedChild1.getEag()); - assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); + int expectedTasks = isEnabledHappyEyeballs() ? 2 : 1; + assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(expectedTasks); + assertThat(getAddressesFromPick(weightedPicker)).isEqualTo(servers.get(0)); + assertThat(getNumFilteredPendingTasks()).isEqualTo(1); weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder() .setWeightUpdatePeriodNanos(500_000_000L) //.5s .build(); syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + assertThat(getNumFilteredPendingTasks()).isEqualTo(1); + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); //timer fires, new weight updated - assertThat(fakeClock.forwardTime(500, TimeUnit.MILLISECONDS)).isEqualTo(1); - assertThat(getAddressesFromPick(weightedPicker)) - .isEqualTo(weightedChild2.getEag()); - assertThat(getAddressesFromPick(weightedPicker)) - .isEqualTo(weightedChild1.getEag()); + expectedTasks = isEnabledHappyEyeballs() ? 2 : 1; + assertThat(fakeClock.forwardTime(500, TimeUnit.MILLISECONDS)).isEqualTo(expectedTasks); + assertThat(getAddressesFromPick(weightedPicker)).isEqualTo(servers.get(1)); + assertThat(getAddressesFromPick(weightedPicker)).isEqualTo(servers.get(0)); } @Test @@ -622,9 +714,9 @@ public void weightExpired() { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(6)).createSubchannel( + verify(helper, times(3)).createSubchannel( any(CreateSubchannelArgs.class)); - assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); + assertThat(getNumFilteredPendingTasks()).isEqualTo(1); Iterator it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); @@ -636,26 +728,27 @@ public void weightExpired() { verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); + getWrrPicker(pickerCaptor.getAllValues().get(1)); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1); + int expectedTasks = isEnabledHappyEyeballs() ? 2 : 1; + assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(expectedTasks); Map pickCount = new HashMap<>(); for (int i = 0; i < 1000; i++) { EquivalentAddressGroup result = getAddressesFromPick(weightedPicker); pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); } assertThat(pickCount.size()).isEqualTo(2); - assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 2.0 / 3)) - .isLessThan(0.002); - assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 1.0 / 3)) - .isLessThan(0.002); + assertThat(Math.abs(pickCount.get(servers.get(0)) / 1000.0 - 2.0 / 3)).isLessThan(0.002); + assertThat(Math.abs(pickCount.get(servers.get(1)) / 1000.0 - 1.0 / 3)).isLessThan(0.002); // weight expired, fallback to simple round robin assertThat(fakeClock.forwardTime(300, TimeUnit.SECONDS)).isEqualTo(1); @@ -665,10 +758,8 @@ weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).on pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); } assertThat(pickCount.size()).isEqualTo(2); - assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 0.5)) - .isLessThan(0.002); - assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 0.5)) - .isLessThan(0.002); + assertThat(Math.abs(pickCount.get(servers.get(0)) / 1000.0 - 0.5)).isLessThan(0.002); + assertThat(Math.abs(pickCount.get(servers.get(1)) / 1000.0 - 0.5)).isLessThan(0.002); } @Test @@ -676,9 +767,9 @@ public void rrFallback() { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(6)).createSubchannel( + verify(helper, times(3)).createSubchannel( any(CreateSubchannelArgs.class)); - assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); + assertThat(getNumFilteredPendingTasks()).isEqualTo(1); Iterator it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); @@ -690,28 +781,20 @@ public void rrFallback() { verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); - assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1); - WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); - WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); - Map qpsByChannel = ImmutableMap.of(weightedChild1.getEag(), 2, - weightedChild2.getEag(), 1); + getWrrPicker(pickerCaptor.getAllValues().get(1)); + int expectedTasks = isEnabledHappyEyeballs() ? 2 : 1; + assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(expectedTasks); + Map qpsByChannel = ImmutableMap.of(servers.get(0), 2, + servers.get(1), 1); Map pickCount = new HashMap<>(); for (int i = 0; i < 1000; i++) { PickResult pickResult = weightedPicker.pickSubchannel(mockArgs); EquivalentAddressGroup addresses = getAddresses(pickResult); pickCount.merge(addresses, 1, Integer::sum); - assertThat(pickResult.getStreamTracerFactory()).isNotNull(); - WeightedChildLbState childLbState = (WeightedChildLbState) wrr.getChildLbStateEag(addresses); - childLbState.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( - InternalCallMetricRecorder.createMetricReport( - 0.1, 0, 0.1, qpsByChannel.get(addresses), 0, - new HashMap<>(), new HashMap<>(), new HashMap<>())); + reportLoadOnRpc(pickResult, 0.1, 0, 0.1, qpsByChannel.get(addresses), 0); } - assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 1.0 / 2)) - .isAtMost(0.1); - assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 1.0 / 2)) - .isAtMost(0.1); + assertThat(Math.abs(pickCount.get(servers.get(0)) / 1000.0 - 1.0 / 2)).isAtMost(0.1); + assertThat(Math.abs(pickCount.get(servers.get(1)) / 1000.0 - 1.0 / 2)).isAtMost(0.1); // Identical to above except forwards time after each pick pickCount.clear(); @@ -719,19 +802,12 @@ childLbState.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLo PickResult pickResult = weightedPicker.pickSubchannel(mockArgs); EquivalentAddressGroup addresses = getAddresses(pickResult); pickCount.merge(addresses, 1, Integer::sum); - assertThat(pickResult.getStreamTracerFactory()).isNotNull(); - WeightedChildLbState childLbState = (WeightedChildLbState) wrr.getChildLbStateEag(addresses); - childLbState.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( - InternalCallMetricRecorder.createMetricReport( - 0.1, 0, 0.1, qpsByChannel.get(addresses), 0, - new HashMap<>(), new HashMap<>(), new HashMap<>())); + reportLoadOnRpc(pickResult, 0.1, 0, 0.1, qpsByChannel.get(addresses), 0); fakeClock.forwardTime(50, TimeUnit.MILLISECONDS); } assertThat(pickCount.size()).isEqualTo(2); - assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 2.0 / 3)) - .isAtMost(0.1); - assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 1.0 / 3)) - .isAtMost(0.1); + assertThat(Math.abs(pickCount.get(servers.get(0)) / 1000.0 - 2.0 / 3)).isAtMost(0.1); + assertThat(Math.abs(pickCount.get(servers.get(1)) / 1000.0 - 1.0 / 3)).isAtMost(0.1); } private static EquivalentAddressGroup getAddresses(PickResult pickResult) { @@ -743,9 +819,9 @@ public void unknownWeightIsAvgWeight() { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(6)).createSubchannel( + verify(helper, times(3)).createSubchannel( any(CreateSubchannelArgs.class)); // 3 from setup plus 3 from the execute - assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); + assertThat(getNumFilteredPendingTasks()).isEqualTo(1); Iterator it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); @@ -760,14 +836,15 @@ public void unknownWeightIsAvgWeight() { verify(helper, times(3)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(2); + getWrrPicker(pickerCaptor.getAllValues().get(2)); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); - WeightedChildLbState weightedChild3 = (WeightedChildLbState) getChild(weightedPicker, 2); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1); @@ -777,13 +854,10 @@ weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).on pickCount.merge(result.getAddresses(), 1, Integer::sum); } assertThat(pickCount.size()).isEqualTo(3); - assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 4.0 / 9)) - .isLessThan(0.002); - assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 2.0 / 9)) - .isLessThan(0.002); + assertThat(Math.abs(pickCount.get(servers.get(0)) / 1000.0 - 4.0 / 9)).isLessThan(0.002); + assertThat(Math.abs(pickCount.get(servers.get(1)) / 1000.0 - 2.0 / 9)).isLessThan(0.002); // subchannel3's weight is average of subchannel1 and subchannel2 - assertThat(Math.abs(pickCount.get(weightedChild3.getEag()) / 1000.0 - 3.0 / 9)) - .isLessThan(0.002); + assertThat(Math.abs(pickCount.get(servers.get(2)) / 1000.0 - 3.0 / 9)).isLessThan(0.002); } @Test @@ -791,9 +865,9 @@ public void pickFromOtherThread() throws Exception { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(6)).createSubchannel( + verify(helper, times(3)).createSubchannel( any(CreateSubchannelArgs.class)); - assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); + assertThat(getNumFilteredPendingTasks()).isEqualTo(1); Iterator it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); @@ -805,19 +879,21 @@ public void pickFromOtherThread() throws Exception { verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); + getWrrPicker(pickerCaptor.getAllValues().get(1)); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); CyclicBarrier barrier = new CyclicBarrier(2); Map pickCount = new ConcurrentHashMap<>(); - pickCount.put(weightedChild1.getEag(), new AtomicInteger(0)); - pickCount.put(weightedChild2.getEag(), new AtomicInteger(0)); + pickCount.put(servers.get(0), new AtomicInteger(0)); + pickCount.put(servers.get(1), new AtomicInteger(0)); new Thread(new Runnable() { @Override public void run() { @@ -834,7 +910,8 @@ public void run() { } } }).start(); - assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1); + int expectedTasks = isEnabledHappyEyeballs() ? 2 : 1; + assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(expectedTasks); barrier.await(); for (int i = 0; i < 1000; i++) { EquivalentAddressGroup result = getAddresses(weightedPicker.pickSubchannel(mockArgs)); @@ -843,10 +920,8 @@ public void run() { barrier.await(); assertThat(pickCount.size()).isEqualTo(2); // after blackout period - assertThat(Math.abs(pickCount.get(weightedChild1.getEag()).get() / 2000.0 - 2.0 / 3)) - .isLessThan(0.002); - assertThat(Math.abs(pickCount.get(weightedChild2.getEag()).get() / 2000.0 - 1.0 / 3)) - .isLessThan(0.002); + assertThat(Math.abs(pickCount.get(servers.get(0)).get() / 2000.0 - 2.0 / 3)).isLessThan(0.002); + assertThat(Math.abs(pickCount.get(servers.get(1)).get() / 2000.0 - 1.0 / 3)).isLessThan(0.002); } @Test(expected = NullPointerException.class) @@ -1047,7 +1122,7 @@ public void testImmediateWraparound() { .isLessThan(0.002); } } - + @Test public void testWraparound() { float[] weights = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; @@ -1102,6 +1177,483 @@ public void removingAddressShutsdownSubchannel() { } + @Test + public void metrics() { + // Give WRR some valid addresses to work with. + Attributes attributesWithLocality = Attributes.newBuilder() + .set(WeightedTargetLoadBalancer.CHILD_NAME, locality) + .set(NameResolver.ATTR_BACKEND_SERVICE, backendService) + .build(); + syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() + .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) + .setAttributes(attributesWithLocality).build())); + + // Flip the three subchannels to READY state to initiate the WRR logic + Iterator it = subchannels.values().iterator(); + Subchannel readySubchannel1 = it.next(); + getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo + .forNonError(ConnectivityState.READY)); + Subchannel readySubchannel2 = it.next(); + getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo + .forNonError(ConnectivityState.READY)); + Subchannel readySubchannel3 = it.next(); + getSubchannelStateListener(readySubchannel3).onSubchannelState(ConnectivityStateInfo + .forNonError(ConnectivityState.READY)); + + // WRR creates a picker that updates the weights for each of the child subchannels. This should + // give us three "rr_fallback" metric events as we don't yet have any weights to do weighted + // round-robin. + verifyLongCounterRecord("grpc.lb.wrr.rr_fallback", 3, 1); + + // We should also see six records of endpoint weights. They should all be for 0 as we don't yet + // have valid weights. + verifyDoubleHistogramRecord("grpc.lb.wrr.endpoint_weights", 6, 0); + + // We should not yet be seeing any "endpoint_weight_stale" events since we don't even have + // valid weights yet. + verifyLongCounterRecord("grpc.lb.wrr.endpoint_weight_stale", 0, 1); + + // Each time weights are updated, WRR will see if each subchannel weight is useable. As we have + // no weights yet, we should see three "endpoint_weight_not_yet_usable" metric events with the + // value increasing by one each time as all the endpoints come online. + verifyLongCounterRecord("grpc.lb.wrr.endpoint_weight_not_yet_usable", 1, 1); + verifyLongCounterRecord("grpc.lb.wrr.endpoint_weight_not_yet_usable", 1, 2); + verifyLongCounterRecord("grpc.lb.wrr.endpoint_weight_not_yet_usable", 1, 3); + + // Send one child LB state an ORCA update with some valid utilization/qps data so that weights + // can be calculated, but it's still essentially round_robin + Iterator childLbStates = wrr.getChildLbStates().iterator(); + ((WeightedChildLbState) childLbStates.next()).new OrcaReportListener( + weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization) + .onLoadReport(InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, + new HashMap<>(), new HashMap<>(), new HashMap<>())); + + fakeClock.forwardTime(1, TimeUnit.SECONDS); + + // Now send a second child LB state an ORCA update, so there's real weights + ((WeightedChildLbState) childLbStates.next()).new OrcaReportListener( + weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization) + .onLoadReport(InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, + new HashMap<>(), new HashMap<>(), new HashMap<>())); + ((WeightedChildLbState) childLbStates.next()).new OrcaReportListener( + weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization) + .onLoadReport(InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, + new HashMap<>(), new HashMap<>(), new HashMap<>())); + + // Let's reset the mock MetricsRecorder so that it's easier to verify what happened after the + // weights were updated + reset(mockMetricRecorder); + + // We go forward in time past the default 10s blackout period for the first child. The weights + // would get updated as the default update interval is 1s. + fakeClock.forwardTime(9, TimeUnit.SECONDS); + + verifyLongCounterRecord("grpc.lb.wrr.rr_fallback", 1, 1); + + // And after another second the other children have weights + reset(mockMetricRecorder); + fakeClock.forwardTime(1, TimeUnit.SECONDS); + + // Since we have weights on all the child LB states, the weight update should not result in + // further rr_fallback metric entries. + verifyLongCounterRecord("grpc.lb.wrr.rr_fallback", 0, 1); + + // We should not see an increase to the earlier count of "endpoint_weight_not_yet_usable". + verifyLongCounterRecord("grpc.lb.wrr.endpoint_weight_not_yet_usable", 0, 1); + + // No endpoints should have gotten stale yet either. + verifyLongCounterRecord("grpc.lb.wrr.endpoint_weight_stale", 0, 1); + + // Now with valid weights we should have seen the value in the endpoint weights histogram. + verifyDoubleHistogramRecord("grpc.lb.wrr.endpoint_weights", 3, 10); + + reset(mockMetricRecorder); + + // Weights become stale in three minutes. Let's move ahead in time by 3 minutes and make sure + // we get metrics events for each endpoint. + fakeClock.forwardTime(3, TimeUnit.MINUTES); + + verifyLongCounterRecord("grpc.lb.wrr.endpoint_weight_stale", 1, 3); + + // With the weights stale each three endpoints should report 0 weights. + verifyDoubleHistogramRecord("grpc.lb.wrr.endpoint_weights", 3, 0); + + // Since the weights are now stale the update should have triggered an additional rr_fallback + // event. + verifyLongCounterRecord("grpc.lb.wrr.rr_fallback", 1, 1); + + // No further weights-not-useable events should occur, since we have received weights and + // are out of the blackout. + verifyLongCounterRecord("grpc.lb.wrr.endpoint_weight_not_yet_usable", 0, 1); + + // All metric events should be accounted for. + verifyNoMoreInteractions(mockMetricRecorder); + } + + @Test + public void metricWithRealChannel() throws Exception { + String serverName = "wrr-metrics"; + grpcCleanupRule.register( + InProcessServerBuilder.forName(serverName) + .addService(ServerServiceDefinition.builder( + TestMethodDescriptors.voidMethod().getServiceName()) + .addMethod(TestMethodDescriptors.voidMethod(), (call, headers) -> { + call.sendHeaders(new Metadata()); + call.sendMessage(null); + call.close(Status.OK, new Metadata()); + return new ServerCall.Listener() {}; + }) + .build()) + .directExecutor() + .build() + .start()); + MetricSink metrics = mock(MetricSink.class, delegatesTo(new NoopMetricSink())); + Channel channel = grpcCleanupRule.register( + InternalManagedChannelBuilder.addMetricSink( + InProcessChannelBuilder.forName(serverName) + .defaultServiceConfig(Collections.singletonMap( + "loadBalancingConfig", Arrays.asList(Collections.singletonMap( + "weighted_round_robin", Collections.emptyMap())))) + .directExecutor(), + metrics) + .directExecutor() + .build()); + + // Ping-pong to wait for channel to fully start + StreamRecorder recorder = StreamRecorder.create(); + StreamObserver requestObserver = ClientCalls.asyncClientStreamingCall( + channel.newCall(TestMethodDescriptors.voidMethod(), CallOptions.DEFAULT), recorder); + requestObserver.onCompleted(); + assertThat(recorder.awaitCompletion(10, TimeUnit.SECONDS)).isTrue(); + assertThat(recorder.getError()).isNull(); + + // Make sure at least one metric works. The other tests will make sure other metrics and the + // edge cases are working. Since this is racy, we just care it happened at least once. + verify(metrics, atLeast(1)).addLongCounter( + argThat((instr) -> instr.getName().equals("grpc.lb.wrr.rr_fallback")), + eq(1L), + eq(Arrays.asList("directaddress:///wrr-metrics")), + eq(Arrays.asList("", ""))); + } + + + @Test + public void customMetric_priority_overAppUtil() { + weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder().setBlackoutPeriodNanos(0) + .setMetricNamesForComputingUtilization(ImmutableList.of("named_metrics.cost")).build(); + wrr = new WeightedRoundRobinLoadBalancer(helper, fakeClock.getDeadlineTicker()); + + syncContext.execute( + () -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder().setAddresses(servers) + .setLoadBalancingPolicyConfig(weightedConfig).setAttributes(affinity).build())); + + Iterator it = subchannels.values().iterator(); + Subchannel readySubchannel = it.next(); + getSubchannelStateListener(readySubchannel) + .onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + + WeightedChildLbState weightedChild = + (WeightedChildLbState) wrr.getChildLbStates().iterator().next(); + WeightedChildLbState.OrcaReportListener listener = weightedChild.getOrCreateOrcaListener( + weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization); + + Map namedMetrics = new HashMap<>(); + namedMetrics.put("cost", 0.5); + // App util = 0.8 + MetricReport report = InternalCallMetricRecorder.createMetricReport(0.1, 0.8, 0.1, 1, 0, + new HashMap<>(), new HashMap<>(), namedMetrics); + listener.onLoadReport(report); + // Custom metrics now take priority over app_util + // qps=1, util=0.5 -> weight=2.0 + fakeClock.forwardTime(1100, TimeUnit.MILLISECONDS); + verify(mockMetricRecorder).recordDoubleHistogram( + argThat(instr -> instr.getName().equals("grpc.lb.wrr.endpoint_weights")), eq(2.0), any(), + any()); + } + + @Test + public void customMetric_invalid_fallbackToAppUtil() { + weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder().setBlackoutPeriodNanos(0) + .setMetricNamesForComputingUtilization(ImmutableList.of("named_metrics.cost")).build(); + wrr = new WeightedRoundRobinLoadBalancer(helper, fakeClock.getDeadlineTicker()); + + syncContext.execute( + () -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder().setAddresses(servers) + .setLoadBalancingPolicyConfig(weightedConfig).setAttributes(affinity).build())); + + Iterator it = subchannels.values().iterator(); + Subchannel readySubchannel = it.next(); + getSubchannelStateListener(readySubchannel) + .onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + + WeightedChildLbState weightedChild = + (WeightedChildLbState) wrr.getChildLbStates().iterator().next(); + WeightedChildLbState.OrcaReportListener listener = weightedChild.getOrCreateOrcaListener( + weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization); + + // custom metric is NaN, App util = 0.8 + Map namedMetrics = new HashMap<>(); + namedMetrics.put("cost", Double.NaN); + MetricReport report = InternalCallMetricRecorder.createMetricReport(0.1, 0.8, 0.1, 1, 0, + new HashMap<>(), new HashMap<>(), namedMetrics); + listener.onLoadReport(report); + + // Should fallback to App Util (0.8) + // qps=1, util=0.8 -> weight=1.25 + fakeClock.forwardTime(1100, TimeUnit.MILLISECONDS); + verify(mockMetricRecorder).recordDoubleHistogram( + argThat(instr -> instr.getName().equals("grpc.lb.wrr.endpoint_weights")), eq(1.25), any(), + any()); + } + + @Test + public void customMetric_mapLookup_used() { + weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder().setBlackoutPeriodNanos(0) + .setMetricNamesForComputingUtilization(ImmutableList.of("named_metrics.cost")).build(); + wrr = new WeightedRoundRobinLoadBalancer(helper, fakeClock.getDeadlineTicker()); + + syncContext.execute( + () -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder().setAddresses(servers) + .setLoadBalancingPolicyConfig(weightedConfig).setAttributes(affinity).build())); + + Iterator it = subchannels.values().iterator(); + Subchannel readySubchannel = it.next(); + getSubchannelStateListener(readySubchannel) + .onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + + WeightedChildLbState weightedChild = + (WeightedChildLbState) wrr.getChildLbStates().iterator().next(); + WeightedChildLbState.OrcaReportListener listener = weightedChild.getOrCreateOrcaListener( + weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization); + + Map namedMetrics = new HashMap<>(); + namedMetrics.put("cost", 0.5); + MetricReport report = InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, + new HashMap<>(), new HashMap<>(), namedMetrics); + listener.onLoadReport(report); + // qps=1, util=0.5 -> weight=2.0 + fakeClock.forwardTime(1100, TimeUnit.MILLISECONDS); + verify(mockMetricRecorder).recordDoubleHistogram( + argThat(instr -> instr.getName().equals("grpc.lb.wrr.endpoint_weights")), eq(2.0), any(), + any()); + } + + @Test + public void customMetric_shouldFilterOutAndFallbackToCpu() { + weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder().setBlackoutPeriodNanos(0) + .setMetricNamesForComputingUtilization(ImmutableList.of("named_metrics.cost")).build(); + wrr = new WeightedRoundRobinLoadBalancer(helper, fakeClock.getDeadlineTicker()); + + syncContext.execute( + () -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder().setAddresses(servers) + .setLoadBalancingPolicyConfig(weightedConfig).setAttributes(affinity).build())); + + Iterator it = subchannels.values().iterator(); + Subchannel readySubchannel = it.next(); + getSubchannelStateListener(readySubchannel) + .onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + + WeightedChildLbState weightedChild = + (WeightedChildLbState) wrr.getChildLbStates().iterator().next(); + WeightedChildLbState.OrcaReportListener listener = weightedChild.getOrCreateOrcaListener( + weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization); + + // custom metric is NaN, but CPU is 0.1 + Map namedMetrics = new HashMap<>(); + namedMetrics.put("cost", Double.NaN); + MetricReport report = InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, + new HashMap<>(), new HashMap<>(), namedMetrics); + listener.onLoadReport(report); + + // Should fallback to CPU (0.1) + // fallback to cpu: qps=1, util=0.1 -> weight=10.0 + fakeClock.forwardTime(1100, TimeUnit.MILLISECONDS); + verify(mockMetricRecorder).recordDoubleHistogram( + argThat(instr -> instr.getName().equals("grpc.lb.wrr.endpoint_weights")), eq(10.0), any(), + any()); + } + + @Test + public void customMetric_multipleMetrics_maxUsed() { + weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder().setBlackoutPeriodNanos(0) + .setMetricNamesForComputingUtilization( + ImmutableList.of("named_metrics.cost", "named_metrics.score")) + .build(); + wrr = new WeightedRoundRobinLoadBalancer(helper, fakeClock.getDeadlineTicker()); + + syncContext.execute( + () -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder().setAddresses(servers) + .setLoadBalancingPolicyConfig(weightedConfig).setAttributes(affinity).build())); + + Iterator it = subchannels.values().iterator(); + Subchannel readySubchannel = it.next(); + getSubchannelStateListener(readySubchannel) + .onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + + WeightedChildLbState weightedChild = + (WeightedChildLbState) wrr.getChildLbStates().iterator().next(); + WeightedChildLbState.OrcaReportListener listener = weightedChild.getOrCreateOrcaListener( + weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization); + + Map namedMetrics = new HashMap<>(); + namedMetrics.put("cost", 0.5); + namedMetrics.put("score", 0.8); + MetricReport report = InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, + new HashMap<>(), new HashMap<>(), namedMetrics); + listener.onLoadReport(report); + // qps=1, util=0.8 (max of 0.5 and 0.8) -> weight=1.25 + fakeClock.forwardTime(1100, TimeUnit.MILLISECONDS); + verify(mockMetricRecorder).recordDoubleHistogram( + argThat(instr -> instr.getName().equals("grpc.lb.wrr.endpoint_weights")), eq(1.25), any(), + any()); + } + + @Test + public void customMetric_allInvalid_fallbackToCpu() { + weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder().setBlackoutPeriodNanos(0) + .setMetricNamesForComputingUtilization( + ImmutableList.of("named_metrics.cost", "named_metrics.score", "named_metrics.other")) + .build(); + wrr = new WeightedRoundRobinLoadBalancer(helper, fakeClock.getDeadlineTicker()); + + syncContext.execute( + () -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder().setAddresses(servers) + .setLoadBalancingPolicyConfig(weightedConfig).setAttributes(affinity).build())); + + Iterator it = subchannels.values().iterator(); + Subchannel readySubchannel = it.next(); + getSubchannelStateListener(readySubchannel) + .onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + + WeightedChildLbState weightedChild = + (WeightedChildLbState) wrr.getChildLbStates().iterator().next(); + WeightedChildLbState.OrcaReportListener listener = weightedChild.getOrCreateOrcaListener( + weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization); + + Map namedMetrics = new HashMap<>(); + namedMetrics.put("cost", Double.NaN); + namedMetrics.put("score", 0.0); + namedMetrics.put("other", -1.0); + MetricReport report = InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, + new HashMap<>(), new HashMap<>(), namedMetrics); + listener.onLoadReport(report); + // qps=1, util=0.1 (fallback to cpu) -> weight=10.0 + fakeClock.forwardTime(1100, TimeUnit.MILLISECONDS); + verify(mockMetricRecorder).recordDoubleHistogram( + argThat(instr -> instr.getName().equals("grpc.lb.wrr.endpoint_weights")), eq(10.0), any(), + any()); + } + + @Test + public void customMetric_mixInvalidAndValid_validUsed() { + weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder().setBlackoutPeriodNanos(0) + .setMetricNamesForComputingUtilization(ImmutableList.of("named_metrics.cost", + "named_metrics.score", "named_metrics.other1", "named_metrics.other2")) + .build(); + wrr = new WeightedRoundRobinLoadBalancer(helper, fakeClock.getDeadlineTicker()); + + syncContext.execute( + () -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder().setAddresses(servers) + .setLoadBalancingPolicyConfig(weightedConfig).setAttributes(affinity).build())); + + Iterator it = subchannels.values().iterator(); + Subchannel readySubchannel = it.next(); + getSubchannelStateListener(readySubchannel) + .onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + + WeightedChildLbState weightedChild = + (WeightedChildLbState) wrr.getChildLbStates().iterator().next(); + WeightedChildLbState.OrcaReportListener listener = weightedChild.getOrCreateOrcaListener( + weightedConfig.errorUtilizationPenalty, + weightedConfig.parsedMetricNamesForComputingUtilization); + + Map namedMetrics = new HashMap<>(); + namedMetrics.put("cost", Double.NaN); + namedMetrics.put("score", 0.5); + namedMetrics.put("other1", 0.0); + namedMetrics.put("other2", -123.0); + MetricReport report = InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, + new HashMap<>(), new HashMap<>(), namedMetrics); + listener.onLoadReport(report); + // qps=1, util=0.5 -> weight=2.0 + fakeClock.forwardTime(1100, TimeUnit.MILLISECONDS); + verify(mockMetricRecorder).recordDoubleHistogram( + argThat(instr -> instr.getName().equals("grpc.lb.wrr.endpoint_weights")), eq(2.0), any(), + any()); + } + + // Verifies that the MetricRecorder has been called to record a long counter value of 1 for the + // given metric name, the given number of times + private void verifyLongCounterRecord(String name, int times, long value) { + verify(mockMetricRecorder, times(times)).addLongCounter( + argThat(new ArgumentMatcher() { + @Override + public boolean matches(LongCounterMetricInstrument longCounterInstrument) { + return longCounterInstrument.getName().equals(name); + } + }), + eq(value), + eq(Lists.newArrayList(channelTarget)), + eq(Lists.newArrayList(locality, backendService))); + } + + // Verifies that the MetricRecorder has been called to record a given double histogram value the + // given amount of times. + private void verifyDoubleHistogramRecord(String name, int times, double value) { + verify(mockMetricRecorder, times(times)).recordDoubleHistogram( + argThat(new ArgumentMatcher() { + @Override + public boolean matches(DoubleHistogramMetricInstrument doubleHistogramInstrument) { + return doubleHistogramInstrument.getName().equals(name); + } + }), + eq(value), + eq(Lists.newArrayList(channelTarget)), + eq(Lists.newArrayList(locality, backendService))); + } + + private int getNumFilteredPendingTasks() { + return AbstractTestHelper.getNumFilteredPendingTasks(fakeClock); + } + + private static final Metadata.Key ORCA_LOAD_METRICS_KEY = + Metadata.Key.of( + "endpoint-load-metrics-bin", + ProtoUtils.metadataMarshaller(OrcaLoadReport.getDefaultInstance())); + private static final ClientStreamTracer.StreamInfo STREAM_INFO = + ClientStreamTracer.StreamInfo.newBuilder().build(); + + private static void reportLoadOnRpc( + PickResult pickResult, + double cpuUtilization, + double applicationUtilization, + double memoryUtilization, + double qps, + double eps) { + ClientStreamTracer childTracer = pickResult.getStreamTracerFactory() + .newClientStreamTracer(STREAM_INFO, new Metadata()); + Metadata trailer = new Metadata(); + trailer.put( + ORCA_LOAD_METRICS_KEY, + OrcaLoadReport.newBuilder() + .setCpuUtilization(cpuUtilization) + .setApplicationUtilization(applicationUtilization) + .setMemUtilization(memoryUtilization) + .setRpsFractional(qps) + .setEps(eps) + .build()); + childTracer.inboundTrailers(trailer); + } + private static final class VerifyingScheduler { private final StaticStrideScheduler delegate; private final int max; @@ -1148,6 +1700,9 @@ public int nextInt() { } private class TestHelper extends AbstractTestHelper { + public TestHelper() { + super(fakeClock, syncContext); + } @Override public Map, Subchannel> getSubchannelMap() { @@ -1155,25 +1710,13 @@ public Map, Subchannel> getSubchannelMap() { } @Override - public Map getMockToRealSubChannelMap() { - return mockToRealSubChannelMap; + public MetricRecorder getMetricRecorder() { + return mockMetricRecorder; } @Override - public Map getSubchannelStateListeners() { - return subchannelStateListeners; + public String getChannelTarget() { + return channelTarget; } - - @Override - public SynchronizationContext getSynchronizationContext() { - return syncContext; - } - - @Override - public ScheduledExecutorService getScheduledExecutorService() { - return fakeClock.getScheduledExecutorService(); - } - - } } diff --git a/xds/src/test/java/io/grpc/xds/WeightedTargetLoadBalancerProviderTest.java b/xds/src/test/java/io/grpc/xds/WeightedTargetLoadBalancerProviderTest.java index 7a54036b73a..c8eab309f38 100644 --- a/xds/src/test/java/io/grpc/xds/WeightedTargetLoadBalancerProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/WeightedTargetLoadBalancerProviderTest.java @@ -26,7 +26,7 @@ import io.grpc.LoadBalancerRegistry; import io.grpc.NameResolver.ConfigOrError; import io.grpc.internal.JsonParser; -import io.grpc.internal.ServiceConfigUtil.PolicySelection; +import io.grpc.util.GracefulSwitchLoadBalancer; import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedPolicySelection; import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedTargetConfig; import java.util.Map; @@ -128,11 +128,13 @@ public ConfigOrError parseLoadBalancingPolicyConfig(Map rawConfig) { "target_1", new WeightedPolicySelection( 10, - new PolicySelection(lbProviderFoo, fooConfig)), + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + lbProviderFoo, fooConfig)), "target_2", new WeightedPolicySelection( 20, - new PolicySelection(lbProviderBar, barConfig))))); + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( + lbProviderBar, barConfig))))); assertThat(parsedConfig).isEqualTo(expectedConfig); } } diff --git a/xds/src/test/java/io/grpc/xds/WeightedTargetLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/WeightedTargetLoadBalancerTest.java index fa80c8d6e12..55ff0cd8078 100644 --- a/xds/src/test/java/io/grpc/xds/WeightedTargetLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/WeightedTargetLoadBalancerTest.java @@ -49,7 +49,7 @@ import io.grpc.LoadBalancerRegistry; import io.grpc.Status; import io.grpc.SynchronizationContext; -import io.grpc.internal.ServiceConfigUtil.PolicySelection; +import io.grpc.util.GracefulSwitchLoadBalancer; import io.grpc.xds.WeightedRandomPicker.WeightedChildPicker; import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedPolicySelection; import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedTargetConfig; @@ -113,6 +113,7 @@ public String getPolicyName() { public LoadBalancer newLoadBalancer(Helper helper) { childHelpers.add(helper); LoadBalancer childBalancer = mock(LoadBalancer.class); + when(childBalancer.acceptResolvedAddresses(any())).thenReturn(Status.OK); childBalancers.add(childBalancer); fooLbCreated++; return childBalancer; @@ -139,6 +140,7 @@ public String getPolicyName() { public LoadBalancer newLoadBalancer(Helper helper) { childHelpers.add(helper); LoadBalancer childBalancer = mock(LoadBalancer.class); + when(childBalancer.acceptResolvedAddresses(any())).thenReturn(Status.OK); childBalancers.add(childBalancer); barLbCreated++; return childBalancer; @@ -146,13 +148,13 @@ public LoadBalancer newLoadBalancer(Helper helper) { }; private final WeightedPolicySelection weightedLbConfig0 = new WeightedPolicySelection( - weights[0], new PolicySelection(fooLbProvider, configs[0])); + weights[0], newChildConfig(fooLbProvider, configs[0])); private final WeightedPolicySelection weightedLbConfig1 = new WeightedPolicySelection( - weights[1], new PolicySelection(barLbProvider, configs[1])); + weights[1], newChildConfig(barLbProvider, configs[1])); private final WeightedPolicySelection weightedLbConfig2 = new WeightedPolicySelection( - weights[2], new PolicySelection(barLbProvider, configs[2])); + weights[2], newChildConfig(barLbProvider, configs[2])); private final WeightedPolicySelection weightedLbConfig3 = new WeightedPolicySelection( - weights[3], new PolicySelection(fooLbProvider, configs[3])); + weights[3], newChildConfig(fooLbProvider, configs[3])); @Mock private Helper helper; @@ -180,7 +182,7 @@ public void tearDown() { } @Test - public void handleResolvedAddresses() { + public void acceptResolvedAddresses() { ArgumentCaptor resolvedAddressesCaptor = ArgumentCaptor.forClass(ResolvedAddresses.class); Attributes.Key fakeKey = Attributes.Key.create("fake_key"); @@ -203,12 +205,13 @@ public void handleResolvedAddresses() { eag2 = AddressFilter.setPathFilter(eag2, ImmutableList.of("target2")); EquivalentAddressGroup eag3 = new EquivalentAddressGroup(socketAddresses[3]); eag3 = AddressFilter.setPathFilter(eag3, ImmutableList.of("target3")); - weightedTargetLb.handleResolvedAddresses( + Status status = weightedTargetLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of(eag0, eag1, eag2, eag3)) .setAttributes(Attributes.newBuilder().set(fakeKey, fakeValue).build()) .setLoadBalancingPolicyConfig(new WeightedTargetConfig(targets)) .build()); + assertThat(status.isOk()).isTrue(); verify(helper).updateBalancingState(eq(CONNECTING), pickerReturns(PickResult.withNoResult())); assertThat(childBalancers).hasSize(4); assertThat(childHelpers).hasSize(4); @@ -216,14 +219,21 @@ public void handleResolvedAddresses() { assertThat(barLbCreated).isEqualTo(2); for (int i = 0; i < childBalancers.size(); i++) { - verify(childBalancers.get(i)).handleResolvedAddresses(resolvedAddressesCaptor.capture()); + verify(childBalancers.get(i)).acceptResolvedAddresses(resolvedAddressesCaptor.capture()); ResolvedAddresses resolvedAddresses = resolvedAddressesCaptor.getValue(); assertThat(resolvedAddresses.getLoadBalancingPolicyConfig()).isEqualTo(configs[i]); assertThat(resolvedAddresses.getAttributes().get(fakeKey)).isEqualTo(fakeValue); + assertThat(resolvedAddresses.getAttributes().get(WeightedTargetLoadBalancer.CHILD_NAME)) + .isEqualTo("target" + i); assertThat(Iterables.getOnlyElement(resolvedAddresses.getAddresses()).getAddresses()) .containsExactly(socketAddresses[i]); } + // Even when a child return an error from the update, the other children should still receive + // their updates. + Status acceptReturnStatus = Status.UNAVAILABLE.withDescription("Didn't like something"); + when(childBalancers.get(2).acceptResolvedAddresses(any())).thenReturn(acceptReturnStatus); + // Update new weighted target config for a typical workflow. // target0 removed. target1, target2, target3 changed weight and config. target4 added. int[] newWeights = new int[]{11, 22, 33, 44}; @@ -231,21 +241,22 @@ public void handleResolvedAddresses() { Map newTargets = ImmutableMap.of( "target1", new WeightedPolicySelection( - newWeights[0], new PolicySelection(barLbProvider, newConfigs[0])), + newWeights[0], newChildConfig(barLbProvider, newConfigs[0])), "target2", new WeightedPolicySelection( - newWeights[1], new PolicySelection(barLbProvider, newConfigs[1])), + newWeights[1], newChildConfig(barLbProvider, newConfigs[1])), "target3", new WeightedPolicySelection( - newWeights[2], new PolicySelection(fooLbProvider, newConfigs[2])), + newWeights[2], newChildConfig(fooLbProvider, newConfigs[2])), "target4", new WeightedPolicySelection( - newWeights[3], new PolicySelection(fooLbProvider, newConfigs[3]))); - weightedTargetLb.handleResolvedAddresses( + newWeights[3], newChildConfig(fooLbProvider, newConfigs[3]))); + status = weightedTargetLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(new WeightedTargetConfig(newTargets)) .build()); + assertThat(status.getCode()).isEqualTo(acceptReturnStatus.getCode()); verify(helper, atLeast(2)) .updateBalancingState(eq(CONNECTING), pickerReturns(PickResult.withNoResult())); assertThat(childBalancers).hasSize(5); @@ -256,7 +267,7 @@ public void handleResolvedAddresses() { verify(childBalancers.get(0)).shutdown(); for (int i = 1; i < childBalancers.size(); i++) { verify(childBalancers.get(i), atLeastOnce()) - .handleResolvedAddresses(resolvedAddressesCaptor.capture()); + .acceptResolvedAddresses(resolvedAddressesCaptor.capture()); assertThat(resolvedAddressesCaptor.getValue().getLoadBalancingPolicyConfig()) .isEqualTo(newConfigs[i - 1]); } @@ -284,7 +295,7 @@ public void handleNameResolutionError() { "target2", weightedLbConfig2, // {foo, 40, config3} "target3", weightedLbConfig3); - weightedTargetLb.handleResolvedAddresses( + weightedTargetLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(new WeightedTargetConfig(targets)) @@ -311,7 +322,7 @@ public void balancingStateUpdatedFromChildBalancers() { "target2", weightedLbConfig2, // {foo, 40, config3} "target3", weightedLbConfig3); - weightedTargetLb.handleResolvedAddresses( + weightedTargetLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(new WeightedTargetConfig(targets)) @@ -393,7 +404,7 @@ public void raceBetweenShutdownAndChildLbBalancingStateUpdate() { Map targets = ImmutableMap.of( "target0", weightedLbConfig0, "target1", weightedLbConfig1); - weightedTargetLb.handleResolvedAddresses( + weightedTargetLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(new WeightedTargetConfig(targets)) @@ -416,10 +427,10 @@ public void noDuplicateOverallBalancingStateUpdate() { Map targets = ImmutableMap.of( "target0", new WeightedPolicySelection( - weights[0], new PolicySelection(fakeLbProvider, configs[0])), + weights[0], newChildConfig(fakeLbProvider, configs[0])), "target3", new WeightedPolicySelection( - weights[3], new PolicySelection(fakeLbProvider, configs[3]))); - weightedTargetLb.handleResolvedAddresses( + weights[3], newChildConfig(fakeLbProvider, configs[3]))); + weightedTargetLb.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(new WeightedTargetConfig(targets)) @@ -432,6 +443,10 @@ weights[0], new PolicySelection(fakeLbProvider, configs[0])), } + private Object newChildConfig(LoadBalancerProvider provider, Object config) { + return GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig(provider, config); + } + private static class FakeLoadBalancerProvider extends LoadBalancerProvider { @Override @@ -464,9 +479,10 @@ static class FakeLoadBalancer extends LoadBalancer { } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { helper.updateBalancingState( TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(Status.INTERNAL))); + return Status.OK; } @Override diff --git a/xds/src/test/java/io/grpc/xds/WrrLocalityLoadBalancerProviderTest.java b/xds/src/test/java/io/grpc/xds/WrrLocalityLoadBalancerProviderTest.java index d251f3677d8..c9ec2bb6af4 100644 --- a/xds/src/test/java/io/grpc/xds/WrrLocalityLoadBalancerProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/WrrLocalityLoadBalancerProviderTest.java @@ -27,6 +27,7 @@ import io.grpc.LoadBalancerProvider; import io.grpc.LoadBalancerRegistry; import io.grpc.NameResolver; +import io.grpc.util.GracefulSwitchLoadBalancerAccessor; import io.grpc.xds.WrrLocalityLoadBalancer.WrrLocalityConfig; import java.util.Map; import org.junit.Test; @@ -64,6 +65,8 @@ public void parseConfig() { WrrLocalityLoadBalancerProvider provider = new WrrLocalityLoadBalancerProvider(); NameResolver.ConfigOrError configOrError = provider.parseLoadBalancingPolicyConfig(rawConfig); WrrLocalityConfig config = (WrrLocalityConfig) configOrError.getConfig(); - assertThat(config.childPolicy.getProvider().getPolicyName()).isEqualTo("round_robin"); + LoadBalancerProvider childProvider = + GracefulSwitchLoadBalancerAccessor.getChildProvider(config.childConfig); + assertThat(childProvider.getPolicyName()).isEqualTo("round_robin"); } } diff --git a/xds/src/test/java/io/grpc/xds/WrrLocalityLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/WrrLocalityLoadBalancerTest.java index 87ad876a182..584c32738c5 100644 --- a/xds/src/test/java/io/grpc/xds/WrrLocalityLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/WrrLocalityLoadBalancerTest.java @@ -38,11 +38,10 @@ import io.grpc.LoadBalancerRegistry; import io.grpc.Status; import io.grpc.SynchronizationContext; -import io.grpc.internal.ServiceConfigUtil.PolicySelection; +import io.grpc.util.GracefulSwitchLoadBalancer; import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedPolicySelection; import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedTargetConfig; import io.grpc.xds.WrrLocalityLoadBalancer.WrrLocalityConfig; -import io.grpc.xds.client.Locality; import java.net.SocketAddress; import java.util.Collections; import java.util.List; @@ -109,11 +108,11 @@ public void setUp() { } @Test - public void handleResolvedAddresses() { + public void acceptResolvedAddresses() { // A two locality cluster with a mock child LB policy. - Locality localityOne = Locality.create("region1", "zone1", "subzone1"); - Locality localityTwo = Locality.create("region2", "zone2", "subzone2"); - PolicySelection childPolicy = new PolicySelection(mockChildProvider, null); + String localityOne = "localityOne"; + String localityTwo = "localityTwo"; + Object childPolicy = newChildConfig(mockChildProvider, null); // The child config is delivered wrapped in the wrr_locality config and the locality weights // in a ResolvedAddresses attribute. @@ -125,27 +124,26 @@ public void handleResolvedAddresses() { // Assert that the child policy and the locality weights were correctly mapped to a // WeightedTargetConfig. - verify(mockWeightedTargetLb).handleResolvedAddresses(resolvedAddressesCaptor.capture()); + verify(mockWeightedTargetLb).acceptResolvedAddresses(resolvedAddressesCaptor.capture()); Object config = resolvedAddressesCaptor.getValue().getLoadBalancingPolicyConfig(); assertThat(config).isInstanceOf(WeightedTargetConfig.class); WeightedTargetConfig wtConfig = (WeightedTargetConfig) config; assertThat(wtConfig.targets).hasSize(2); - assertThat(wtConfig.targets).containsEntry(localityOne.toString(), + assertThat(wtConfig.targets).containsEntry(localityOne, new WeightedPolicySelection(1, childPolicy)); - assertThat(wtConfig.targets).containsEntry(localityTwo.toString(), + assertThat(wtConfig.targets).containsEntry(localityTwo, new WeightedPolicySelection(2, childPolicy)); } @Test - public void handleResolvedAddresses_noLocalityWeights() { + public void acceptResolvedAddresses_noLocalityWeights() { // A two locality cluster with a mock child LB policy. - PolicySelection childPolicy = new PolicySelection(mockChildProvider, null); + Object childPolicy = newChildConfig(mockChildProvider, null); // The child config is delivered wrapped in the wrr_locality config and the locality weights // in a ResolvedAddresses attribute. WrrLocalityConfig wlConfig = new WrrLocalityConfig(childPolicy); - deliverAddresses(wlConfig, ImmutableList.of( - makeAddress("addr", Locality.create("test-region", "test-zone", "test-subzone"), null))); + deliverAddresses(wlConfig, ImmutableList.of(makeAddress("addr", "test-locality", null))); // With no locality weights, we should get a TRANSIENT_FAILURE. verify(mockHelper).getAuthority(); @@ -165,9 +163,8 @@ public void handleNameResolutionError_noChildLb() { @Test public void handleNameResolutionError_withChildLb() { - deliverAddresses(new WrrLocalityConfig(new PolicySelection(mockChildProvider, null)), - ImmutableList.of( - makeAddress("addr1", Locality.create("test-region1", "test-zone", "test-subzone"), 1))); + deliverAddresses(new WrrLocalityConfig(newChildConfig(mockChildProvider, null)), + ImmutableList.of(makeAddress("addr1", "test-locality", 1))); Status status = Status.DEADLINE_EXCEEDED.withDescription("too slow"); loadBalancer.handleNameResolutionError(status); @@ -178,25 +175,23 @@ public void handleNameResolutionError_withChildLb() { @Test public void localityWeightAttributeNotPropagated() { - PolicySelection childPolicy = new PolicySelection(mockChildProvider, null); + Object childPolicy = newChildConfig(mockChildProvider, null); WrrLocalityConfig wlConfig = new WrrLocalityConfig(childPolicy); - deliverAddresses(wlConfig, ImmutableList.of( - makeAddress("addr1", Locality.create("test-region1", "test-zone", "test-subzone"), 1))); + deliverAddresses(wlConfig, ImmutableList.of(makeAddress("addr1", "test-locality", 1))); // Assert that the child policy and the locality weights were correctly mapped to a // WeightedTargetConfig. - verify(mockWeightedTargetLb).handleResolvedAddresses(resolvedAddressesCaptor.capture()); + verify(mockWeightedTargetLb).acceptResolvedAddresses(resolvedAddressesCaptor.capture()); //assertThat(resolvedAddressesCaptor.getValue().getAttributes() - // .get(InternalXdsAttributes.ATTR_LOCALITY_WEIGHTS)).isNull(); + // .get(XdsAttributes.ATTR_LOCALITY_WEIGHTS)).isNull(); } @Test public void shutdown() { - deliverAddresses(new WrrLocalityConfig(new PolicySelection(mockChildProvider, null)), - ImmutableList.of( - makeAddress("addr", Locality.create("test-region", "test-zone", "test-subzone"), 1))); + deliverAddresses(new WrrLocalityConfig(newChildConfig(mockChildProvider, null)), + ImmutableList.of(makeAddress("addr", "test-locality", 1))); loadBalancer.shutdown(); verify(mockWeightedTargetLb).shutdown(); @@ -204,19 +199,21 @@ public void shutdown() { @Test public void configEquality() { - WrrLocalityConfig configOne = new WrrLocalityConfig( - new PolicySelection(mockChildProvider, null)); - WrrLocalityConfig configTwo = new WrrLocalityConfig( - new PolicySelection(mockChildProvider, null)); + WrrLocalityConfig configOne = new WrrLocalityConfig(newChildConfig(mockChildProvider, null)); + WrrLocalityConfig configTwo = new WrrLocalityConfig(newChildConfig(mockChildProvider, null)); WrrLocalityConfig differentConfig = new WrrLocalityConfig( - new PolicySelection(mockChildProvider, "config")); + newChildConfig(mockChildProvider, "config")); new EqualsTester().addEqualityGroup(configOne, configTwo).addEqualityGroup(differentConfig) .testEquals(); } + private Object newChildConfig(LoadBalancerProvider provider, Object config) { + return GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig(provider, config); + } + private void deliverAddresses(WrrLocalityConfig config, List addresses) { - loadBalancer.handleResolvedAddresses( + loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(addresses).setLoadBalancingPolicyConfig(config) .build()); } @@ -224,7 +221,7 @@ private void deliverAddresses(WrrLocalityConfig config, List xdsClientPool; + private XdsClient xdsClient; + private boolean originalEnableXdsFallback; + private final FakeClock fakeClock = new FakeClock(); + private final MetricRecorder metricRecorder = new MetricRecorder() {}; + + @Mock + private XdsClientMetricReporter xdsClientMetricReporter; + + @Captor + private ArgumentCaptor> ldsUpdateCaptor; + @Captor + private ArgumentCaptor> rdsUpdateCaptor; + + private final XdsClient.ResourceWatcher raalLdsWatcher = + new XdsClient.ResourceWatcher() { + + @Override + public void onResourceChanged(StatusOr update) { + if (update.hasValue()) { + log.log(Level.FINE, "LDS update: " + update.getValue()); + } else { + log.log(Level.FINE, "LDS resource error: " + update.getStatus().getDescription()); + } + } + + @Override + public void onAmbientError(Status error) { + log.log(Level.FINE, "LDS ambient error: " + error.getDescription()); + } + }; + + @SuppressWarnings("unchecked") + private final XdsClient.ResourceWatcher ldsWatcher = + mock(XdsClient.ResourceWatcher.class, delegatesTo(raalLdsWatcher)); + @Mock + private XdsClient.ResourceWatcher ldsWatcher2; + + @Mock + private XdsClient.ResourceWatcher rdsWatcher; + @Mock + private XdsClient.ResourceWatcher rdsWatcher2; + @Mock + private XdsClient.ResourceWatcher rdsWatcher3; + + private final XdsClient.ResourceWatcher raalCdsWatcher = + new XdsClient.ResourceWatcher() { + + @Override + public void onResourceChanged(StatusOr update) { + if (update.hasValue()) { + log.log(Level.FINE, "CDS update: " + update.getValue()); + } else { + log.log(Level.FINE, "CDS resource error: " + update.getStatus().getDescription()); + } + } + + @Override + public void onAmbientError(Status error) { + // Logic from the old onError method for transient errors. + log.log(Level.FINE, "CDS ambient error: " + error.getDescription()); + } + }; + + @SuppressWarnings("unchecked") + private final XdsClient.ResourceWatcher cdsWatcher = + mock(XdsClient.ResourceWatcher.class, delegatesTo(raalCdsWatcher)); + @Mock + private XdsClient.ResourceWatcher cdsWatcher2; + + @Rule(order = 0) + public ControlPlaneRule mainXdsServer = + new ControlPlaneRule().setServerHostName(MAIN_SERVER); + + @Rule(order = 1) + public ControlPlaneRule fallbackServer = + new ControlPlaneRule().setServerHostName(MAIN_SERVER); + + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + + @Before + public void setUp() throws XdsInitializationException { + originalEnableXdsFallback = CommonBootstrapperTestUtils.setEnableXdsFallback(true); + if (mainXdsServer == null) { + throw new XdsInitializationException("Failed to create ControlPlaneRule for main TD server"); + } + setAdsConfig(mainXdsServer, MAIN_SERVER); + setAdsConfig(fallbackServer, FALLBACK_SERVER); + + SharedXdsClientPoolProvider clientPoolProvider = new SharedXdsClientPoolProvider(); + xdsClientPool = clientPoolProvider.getOrCreate( + DUMMY_TARGET, + new GrpcBootstrapperImpl().bootstrap(defaultBootstrapOverride()), + metricRecorder); + } + + @After + public void cleanUp() { + if (xdsClient != null) { + xdsClient = xdsClientPool.returnObject(xdsClient); + } + CommonBootstrapperTestUtils.setEnableXdsFallback(originalEnableXdsFallback); + } + + private static void setAdsConfig(ControlPlaneRule controlPlane, String serverName) { + InetSocketAddress edsInetSocketAddress = + (InetSocketAddress) controlPlane.getServer().getListenSockets().get(0); + boolean isMainServer = serverName.equals(MAIN_SERVER); + String rdsName = isMainServer + ? RDS_NAME + : FALLBACK_RDS_NAME; + String clusterName = isMainServer ? CLUSTER_NAME : FALLBACK_CLUSTER_NAME; + String edsName = isMainServer ? EDS_NAME : FALLBACK_EDS_NAME; + + controlPlane.setLdsConfig(ControlPlaneRule.buildServerListener(), + ControlPlaneRule.buildClientListener(MAIN_SERVER, rdsName)); + + controlPlane.setRdsConfig(rdsName, + XdsTestUtils.buildRouteConfiguration(MAIN_SERVER, rdsName, clusterName)); + controlPlane.setCdsConfig(clusterName, ControlPlaneRule.buildCluster(clusterName, edsName)); + + controlPlane.setEdsConfig(edsName, + ControlPlaneRule.buildClusterLoadAssignment(edsInetSocketAddress.getHostName(), + DataPlaneRule.ENDPOINT_HOST_NAME, edsInetSocketAddress.getPort(), edsName)); + log.log(Level.FINE, + String.format("Set ADS config for %s with address %s", serverName, edsInetSocketAddress)); + } + + // This is basically a control test to make sure everything is set up correctly. + @Test + public void everything_okay() { + mainXdsServer.restartXdsServer(); + fallbackServer.restartXdsServer(); + xdsClient = xdsClientPool.getObject(); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher); + verify(ldsWatcher, timeout(5000)).onResourceChanged(ldsUpdateCaptor.capture()); + assertThat(ldsUpdateCaptor.getValue().hasValue()).isTrue(); + assertThat(ldsUpdateCaptor.getValue().getValue()).isEqualTo( + LdsUpdate.forApiListener(MAIN_HTTP_CONNECTION_MANAGER)); + + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), RDS_NAME, rdsWatcher); + verify(rdsWatcher, timeout(5000)).onResourceChanged(rdsUpdateCaptor.capture()); + assertThat(rdsUpdateCaptor.getValue().hasValue()).isTrue(); + } + + @Test + public void mainServerDown_fallbackServerUp() { + mainXdsServer.getServer().shutdownNow(); + fallbackServer.restartXdsServer(); + xdsClient = xdsClientPool.getObject(); + log.log(Level.FINE, "Fallback port = " + fallbackServer.getServer().getPort()); + + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher); + + verify(ldsWatcher, timeout(5000)).onResourceChanged( + StatusOr.fromValue(XdsListenerResource.LdsUpdate.forApiListener( + FALLBACK_HTTP_CONNECTION_MANAGER))); + } + + @Test + public void useBadAuthority() { + xdsClient = xdsClientPool.getObject(); + InOrder inOrder = inOrder(ldsWatcher, rdsWatcher, rdsWatcher2, rdsWatcher3); + + String badPrefix = "xdstp://authority.xds.bad/envoy.config.listener.v3.Listener/"; + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), + badPrefix + "listener.googleapis.com", ldsWatcher); + inOrder.verify(ldsWatcher, timeout(5000)).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue())); + + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), + badPrefix + "route-config.googleapis.bad", rdsWatcher); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), + badPrefix + "route-config2.googleapis.bad", rdsWatcher2); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), + badPrefix + "route-config3.googleapis.bad", rdsWatcher3); + inOrder.verify(rdsWatcher, timeout(5000)).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue())); + inOrder.verify(rdsWatcher2, timeout(5000)).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue())); + inOrder.verify(rdsWatcher3, timeout(5000)).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue())); + verify(rdsWatcher, never()).onResourceChanged(argThat(StatusOr::hasValue)); + + // even after an error, a valid one will still work + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher2); + verify(ldsWatcher2, timeout(5000)).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOr = ldsUpdateCaptor.getValue(); + assertThat(statusOr.hasValue()).isTrue(); + assertThat(statusOr.getValue()).isEqualTo( + XdsListenerResource.LdsUpdate.forApiListener(MAIN_HTTP_CONNECTION_MANAGER)); + } + + @Test + public void both_down_restart_main() { + mainXdsServer.getServer().shutdownNow(); + fallbackServer.getServer().shutdownNow(); + xdsClient = xdsClientPool.getObject(); + + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher); + verify(ldsWatcher, timeout(5000).atLeastOnce()) + .onResourceChanged(argThat(statusOr -> !statusOr.hasValue())); + verify(ldsWatcher, never()).onResourceChanged(argThat(StatusOr::hasValue)); + xdsClient.watchXdsResource( + XdsRouteConfigureResource.getInstance(), RDS_NAME, rdsWatcher2); + verify(rdsWatcher2, timeout(5000).atLeastOnce()) + .onResourceChanged(argThat(statusOr -> !statusOr.hasValue())); + + mainXdsServer.restartXdsServer(); + + xdsClient.watchXdsResource( + XdsRouteConfigureResource.getInstance(), RDS_NAME, rdsWatcher); + + verify(ldsWatcher, timeout(16000)).onResourceChanged( + argThat(statusOr -> statusOr.hasValue() && statusOr.getValue().equals( + XdsListenerResource.LdsUpdate.forApiListener(MAIN_HTTP_CONNECTION_MANAGER)))); + verify(rdsWatcher, timeout(5000)).onResourceChanged(argThat(StatusOr::hasValue)); + verify(rdsWatcher2, timeout(5000)).onResourceChanged(argThat(StatusOr::hasValue)); + } + + @Test + public void mainDown_fallbackUp_restart_main() { + mainXdsServer.getServer().shutdownNow(); + fallbackServer.restartXdsServer(); + xdsClient = xdsClientPool.getObject(); + InOrder inOrder = inOrder(ldsWatcher, rdsWatcher, cdsWatcher, cdsWatcher2); + + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher); + inOrder.verify(ldsWatcher, timeout(5000)).onResourceChanged( + StatusOr.fromValue(XdsListenerResource.LdsUpdate.forApiListener( + FALLBACK_HTTP_CONNECTION_MANAGER))); + + // Watch another resource, also from the fallback server. + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), FALLBACK_CLUSTER_NAME, cdsWatcher); + @SuppressWarnings("unchecked") + ArgumentCaptor> cdsUpdateCaptor1 = ArgumentCaptor.forClass(StatusOr.class); + inOrder.verify(cdsWatcher, timeout(5000)).onResourceChanged(cdsUpdateCaptor1.capture()); + assertThat(cdsUpdateCaptor1.getValue().getStatus().isOk()).isTrue(); + + assertThat(fallbackServer.getService().getSubscriberCounts() + .get("type.googleapis.com/envoy.config.listener.v3.Listener")).isEqualTo(1); + verifyNoSubscribers(mainXdsServer); + + mainXdsServer.restartXdsServer(); + + // The existing ldsWatcher should receive a new update from the main server. + // Note: This is not an inOrder verification because the timing of the switchover + // can vary. We just need to verify it happens. + verify(ldsWatcher, timeout(5000)).onResourceChanged( + StatusOr.fromValue(XdsListenerResource.LdsUpdate.forApiListener( + MAIN_HTTP_CONNECTION_MANAGER))); + + // Watch a new resource; should now come from the main server. + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), RDS_NAME, rdsWatcher); + @SuppressWarnings("unchecked") + ArgumentCaptor> rdsUpdateCaptor = ArgumentCaptor.forClass(StatusOr.class); + inOrder.verify(rdsWatcher, timeout(5000)).onResourceChanged(rdsUpdateCaptor.capture()); + assertThat(rdsUpdateCaptor.getValue().getStatus().isOk()).isTrue(); + verifyNoSubscribers(fallbackServer); + + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), CLUSTER_NAME, cdsWatcher2); + @SuppressWarnings("unchecked") + ArgumentCaptor> cdsUpdateCaptor2 = ArgumentCaptor.forClass(StatusOr.class); + inOrder.verify(cdsWatcher2, timeout(5000)).onResourceChanged(cdsUpdateCaptor2.capture()); + assertThat(cdsUpdateCaptor2.getValue().getStatus().isOk()).isTrue(); + + verifyNoSubscribers(fallbackServer); + assertThat(mainXdsServer.getService().getSubscriberCounts() + .get("type.googleapis.com/envoy.config.listener.v3.Listener")).isEqualTo(1); + } + + private static void verifyNoSubscribers(ControlPlaneRule rule) { + for (Map.Entry me : rule.getService().getSubscriberCounts().entrySet()) { + String type = me.getKey(); + Integer count = me.getValue(); + assertWithMessage("Type with non-zero subscribers is: %s", type) + .that(count).isEqualTo(0); + } + } + + // This test takes a long time because of the 16 sec timeout for non-existent resource + @Test + public void connect_then_mainServerDown_fallbackServerUp() throws Exception { + mainXdsServer.restartXdsServer(); + fallbackServer.restartXdsServer(); + ExecutorService executor = Executors.newFixedThreadPool(1); + XdsTransportFactory xdsTransportFactory = new XdsTransportFactory() { + @Override + public XdsTransport create(Bootstrapper.ServerInfo serverInfo) { + ChannelCredentials channelCredentials = + (ChannelCredentials) serverInfo.implSpecificConfig(); + return new GrpcXdsTransportFactory.GrpcXdsTransport( + Grpc.newChannelBuilder(serverInfo.target(), channelCredentials) + .executor(executor) + .build()); + } + }; + XdsClientImpl xdsClient = CommonBootstrapperTestUtils.createXdsClient( + new GrpcBootstrapperImpl().bootstrap(defaultBootstrapOverride()), + xdsTransportFactory, fakeClock, new ExponentialBackoffPolicy.Provider(), + MessagePrinter.INSTANCE, xdsClientMetricReporter); + + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher); + + // Initial resource fetch from the main server + verify(ldsWatcher, timeout(5000)).onResourceChanged( + StatusOr.fromValue(LdsUpdate.forApiListener(MAIN_HTTP_CONNECTION_MANAGER))); + + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), RDS_NAME, rdsWatcher); + verify(rdsWatcher, timeout(5000)).onResourceChanged(argThat(StatusOr::hasValue)); + + mainXdsServer.getServer().shutdownNow(); + // Sleep for the ADS stream disconnect to be processed and for the retry to fail. Between those + // two sleeps we need the fakeClock to progress by 1 second to restart the ADS stream. + for (int i = 0; i < 5; i++) { + // FakeClock is not thread-safe, and the retry scheduling is concurrent to this test thread + executor.submit(() -> fakeClock.forwardTime(1000, TimeUnit.MILLISECONDS)).get(); + TimeUnit.SECONDS.sleep(1); + } + + // Shouldn't do fallback since all watchers are loaded + verify(ldsWatcher, never()).onResourceChanged(StatusOr.fromValue( + XdsListenerResource.LdsUpdate.forApiListener(FALLBACK_HTTP_CONNECTION_MANAGER))); + + // Should just get from cache + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher2); + xdsClient.watchXdsResource(XdsRouteConfigureResource.getInstance(), RDS_NAME, rdsWatcher2); + verify(ldsWatcher2, timeout(5000)).onResourceChanged(StatusOr.fromValue( + XdsListenerResource.LdsUpdate.forApiListener(MAIN_HTTP_CONNECTION_MANAGER))); + verify(ldsWatcher, never()).onResourceChanged(StatusOr.fromValue( + XdsListenerResource.LdsUpdate.forApiListener(FALLBACK_HTTP_CONNECTION_MANAGER))); + // Make sure that rdsWatcher wasn't called again + verify(rdsWatcher, times(1)).onResourceChanged(any()); + verify(rdsWatcher2, timeout(5000)).onResourceChanged(argThat(StatusOr::hasValue)); + + // Asking for something not in cache should force a fallback + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), FALLBACK_CLUSTER_NAME, cdsWatcher); + verify(ldsWatcher, timeout(5000)).onResourceChanged(StatusOr.fromValue( + XdsListenerResource.LdsUpdate.forApiListener(FALLBACK_HTTP_CONNECTION_MANAGER))); + verify(ldsWatcher2, timeout(5000)).onResourceChanged(StatusOr.fromValue( + XdsListenerResource.LdsUpdate.forApiListener(FALLBACK_HTTP_CONNECTION_MANAGER))); + verify(cdsWatcher, timeout(5000)).onResourceChanged(argThat(StatusOr::hasValue)); + + xdsClient.watchXdsResource( + XdsRouteConfigureResource.getInstance(), FALLBACK_RDS_NAME, rdsWatcher3); + verify(rdsWatcher3, timeout(5000)).onResourceChanged(argThat(StatusOr::hasValue)); + + // Test that resource defined in main but not fallback is handled correctly + xdsClient.watchXdsResource( + XdsClusterResource.getInstance(), CLUSTER_NAME, cdsWatcher2); + verify(cdsWatcher2, never()).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Status.Code.NOT_FOUND)); + fakeClock.forwardTime(15000, TimeUnit.MILLISECONDS); // Does not exist timer + verify(cdsWatcher2, timeout(5000)).onResourceChanged( + argThat(statusOr -> !statusOr.hasValue() + && statusOr.getStatus().getCode() == Status.Code.NOT_FOUND + && statusOr.getStatus().getDescription().contains(CLUSTER_NAME))); + xdsClient.shutdown(); + executor.shutdown(); + } + + @Test + public void connect_then_mainServerRestart_fallbackServerdown() { + mainXdsServer.restartXdsServer(); + xdsClient = xdsClientPool.getObject(); + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher); + + verify(ldsWatcher, timeout(5000)).onResourceChanged( + argThat(statusOr -> statusOr.hasValue() && statusOr.getValue().equals( + LdsUpdate.forApiListener(MAIN_HTTP_CONNECTION_MANAGER)))); + mainXdsServer.getServer().shutdownNow(); + fallbackServer.getServer().shutdownNow(); + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), CLUSTER_NAME, cdsWatcher); + mainXdsServer.restartXdsServer(); + + verify(cdsWatcher, timeout(5000)).onResourceChanged( + argThat(statusOr -> statusOr.hasValue())); + verify(ldsWatcher, timeout(5000).atLeastOnce()).onResourceChanged( + argThat(statusOr -> statusOr.hasValue() && statusOr.getValue().equals( + LdsUpdate.forApiListener(MAIN_HTTP_CONNECTION_MANAGER)))); + } + + @Test + public void fallbackFromBadUrlToGoodOne() { + // Setup xdsClient to fail on stream creation + String garbageUri = "some. garbage"; + + String validUri = "localhost:" + mainXdsServer.getServer().getPort(); + XdsClientImpl client = + CommonBootstrapperTestUtils.createXdsClient( + Arrays.asList(garbageUri, validUri), + new GrpcXdsTransportFactory(null), + fakeClock, + new ExponentialBackoffPolicy.Provider(), + MessagePrinter.INSTANCE, + xdsClientMetricReporter); + + client.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher); + fakeClock.forwardTime(20, TimeUnit.SECONDS); + verify(ldsWatcher, timeout(5000)).onResourceChanged( + StatusOr.fromValue(XdsListenerResource.LdsUpdate.forApiListener( + MAIN_HTTP_CONNECTION_MANAGER))); + verify(ldsWatcher, never()).onAmbientError(any(Status.class)); + + client.shutdown(); + } + + @Test + public void testGoodUrlFollowedByBadUrl() { + // xdsClient should succeed in stream creation as it doesn't need to use the bad url + String garbageUri = "some. garbage"; + String validUri = "localhost:" + mainXdsServer.getServer().getPort(); + + XdsClientImpl client = + CommonBootstrapperTestUtils.createXdsClient( + Arrays.asList(validUri, garbageUri), + new GrpcXdsTransportFactory(null), + fakeClock, + new ExponentialBackoffPolicy.Provider(), + MessagePrinter.INSTANCE, + xdsClientMetricReporter); + + client.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher); + verify(ldsWatcher, timeout(5000)).onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOr = ldsUpdateCaptor.getValue(); + assertThat(statusOr.hasValue()).isTrue(); + assertThat(statusOr.getValue()).isEqualTo( + XdsListenerResource.LdsUpdate.forApiListener(MAIN_HTTP_CONNECTION_MANAGER)); + verify(ldsWatcher, never()).onAmbientError(any()); + verify(ldsWatcher, times(1)).onResourceChanged(any()); + + client.shutdown(); + } + + @Test + public void testTwoBadUrl() { + // Setup xdsClient to fail on stream creation + String garbageUri1 = "some. garbage"; + String garbageUri2 = "other garbage"; + + XdsClientImpl client = + CommonBootstrapperTestUtils.createXdsClient( + Arrays.asList(garbageUri1, garbageUri2), + new GrpcXdsTransportFactory(null), + fakeClock, + new ExponentialBackoffPolicy.Provider(), + MessagePrinter.INSTANCE, + xdsClientMetricReporter); + + client.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher); + fakeClock.forwardTime(20, TimeUnit.SECONDS); + verify(ldsWatcher, Mockito.timeout(5000).atLeastOnce()) + .onResourceChanged(ldsUpdateCaptor.capture()); + StatusOr statusOr = ldsUpdateCaptor.getValue(); + assertThat(statusOr.hasValue()).isFalse(); + assertThat(statusOr.getStatus().getDescription()).contains(garbageUri2); + verify(ldsWatcher, never()).onResourceChanged(argThat(StatusOr::hasValue)); + client.shutdown(); + } + + private Bootstrapper.ServerInfo getLrsServerInfo(String target) { + for (Map.Entry entry + : xdsClient.getServerLrsClientMap().entrySet()) { + if (entry.getKey().target().equals(target)) { + return entry.getKey(); + } + } + return null; + } + + @Test + public void used_then_mainServerRestart_fallbackServerUp() { + xdsClient = xdsClientPool.getObject(); + + xdsClient.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher); + + verify(ldsWatcher, timeout(5000)).onResourceChanged( + StatusOr.fromValue(LdsUpdate.forApiListener(MAIN_HTTP_CONNECTION_MANAGER))); + + mainXdsServer.restartXdsServer(); + + assertThat(getLrsServerInfo("localhost:" + fallbackServer.getServer().getPort())).isNull(); + assertThat(getLrsServerInfo("localhost:" + mainXdsServer.getServer().getPort())).isNotNull(); + + xdsClient.watchXdsResource(XdsClusterResource.getInstance(), CLUSTER_NAME, cdsWatcher); + + verify(cdsWatcher, timeout(5000)).onResourceChanged(any()); + assertThat(getLrsServerInfo("localhost:" + fallbackServer.getServer().getPort())).isNull(); + } + + private Map defaultBootstrapOverride() { + return ImmutableMap.of( + "node", ImmutableMap.of( + "id", UUID.randomUUID().toString(), + "cluster", CLUSTER_NAME), + "xds_servers", ImmutableList.of( + ImmutableMap.of( + "server_uri", "localhost:" + mainXdsServer.getServer().getPort(), + "channel_creds", Collections.singletonList( + ImmutableMap.of("type", "insecure") + ), + "server_features", Collections.singletonList("xds_v3") + ), + ImmutableMap.of( + "server_uri", "localhost:" + fallbackServer.getServer().getPort(), + "channel_creds", Collections.singletonList( + ImmutableMap.of("type", "insecure") + ), + "server_features", Collections.singletonList("xds_v3") + ) + ), + "fallback-policy", "fallback" + ); + } +} diff --git a/xds/src/test/java/io/grpc/xds/XdsClientFederationTest.java b/xds/src/test/java/io/grpc/xds/XdsClientFederationTest.java index 149c1d6170d..da310871c25 100644 --- a/xds/src/test/java/io/grpc/xds/XdsClientFederationTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsClientFederationTest.java @@ -17,12 +17,18 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import io.grpc.MetricRecorder; +import io.grpc.Status; +import io.grpc.StatusOr; import io.grpc.internal.ObjectPool; import io.grpc.xds.Filter.NamedFilterConfig; import io.grpc.xds.XdsListenerResource.LdsUpdate; @@ -45,6 +51,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; @@ -72,12 +79,16 @@ public class XdsClientFederationTest { private ObjectPool xdsClientPool; private XdsClient xdsClient; + private static final String DUMMY_TARGET = "dummy"; + private final MetricRecorder metricRecorder = new MetricRecorder() {}; @Before public void setUp() throws XdsInitializationException { SharedXdsClientPoolProvider clientPoolProvider = new SharedXdsClientPoolProvider(); - clientPoolProvider.setBootstrapOverride(defaultBootstrapOverride()); - xdsClientPool = clientPoolProvider.getOrCreate(); + xdsClientPool = clientPoolProvider.getOrCreate( + DUMMY_TARGET, + new GrpcBootstrapperImpl().bootstrap(defaultBootstrapOverride()), + metricRecorder); xdsClient = xdsClientPool.getObject(); } @@ -102,14 +113,19 @@ public void isolatedResourceDeletions() { xdsClient.watchXdsResource(XdsListenerResource.getInstance(), "xdstp://server-one/envoy.config.listener.v3.Listener/test-server", mockDirectPathWatcher); - verify(mockWatcher, timeout(10000)).onChanged( - LdsUpdate.forApiListener( - HttpConnectionManager.forRdsName(0, "route-config.googleapis.com", ImmutableList.of( - new NamedFilterConfig("terminal-filter", RouterFilter.ROUTER_CONFIG))))); - verify(mockDirectPathWatcher, timeout(10000)).onChanged( - LdsUpdate.forApiListener( - HttpConnectionManager.forRdsName(0, "route-config.googleapis.com", ImmutableList.of( - new NamedFilterConfig("terminal-filter", RouterFilter.ROUTER_CONFIG))))); + @SuppressWarnings("unchecked") + ArgumentCaptor> captor = ArgumentCaptor.forClass(StatusOr.class); + LdsUpdate expectedUpdate = LdsUpdate.forApiListener( + HttpConnectionManager.forRdsName(0, "route-config.googleapis.com", ImmutableList.of( + new NamedFilterConfig("terminal-filter", RouterFilter.ROUTER_CONFIG)))); + + verify(mockWatcher, timeout(10000)).onResourceChanged(captor.capture()); + assertThat(captor.getValue().hasValue()).isTrue(); + assertThat(captor.getValue().getValue()).isEqualTo(expectedUpdate); + + verify(mockDirectPathWatcher, timeout(10000)).onResourceChanged(captor.capture()); + assertThat(captor.getValue().hasValue()).isTrue(); + assertThat(captor.getValue().getValue()).isEqualTo(expectedUpdate); // By setting the LDS config with a new server name we effectively make the old server to go // away as it is not in the configuration anymore. This change in one control plane (here the @@ -117,9 +133,13 @@ public void isolatedResourceDeletions() { // watcher of another control plane (here the DirectPath one). trafficdirector.setLdsConfig(ControlPlaneRule.buildServerListener(), ControlPlaneRule.buildClientListener("new-server")); - verify(mockWatcher, timeout(20000)).onResourceDoesNotExist("test-server"); - verify(mockDirectPathWatcher, times(0)).onResourceDoesNotExist( - "xdstp://server-one/envoy.config.listener.v3.Listener/test-server"); + verify(mockWatcher, timeout(20000)).onResourceChanged(argThat(statusOr -> { + return !statusOr.hasValue() + && statusOr.getStatus().getCode() == Status.Code.NOT_FOUND + && statusOr.getStatus().getDescription().contains("test-server"); + })); + verify(mockDirectPathWatcher, times(1)).onResourceChanged(any()); + verify(mockDirectPathWatcher, never()).onAmbientError(any()); } /** @@ -151,7 +171,6 @@ public void lrsClientsStartedForLocalityStats() throws InterruptedException, Exe } } - /** * Assures that when an {@link XdsClient} is asked to add cluster locality stats it appropriately * starts {@link LoadReportClient}s to do that. diff --git a/xds/src/test/java/io/grpc/xds/XdsClientMetricReporterImplTest.java b/xds/src/test/java/io/grpc/xds/XdsClientMetricReporterImplTest.java new file mode 100644 index 00000000000..509a0025b7b --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/XdsClientMetricReporterImplTest.java @@ -0,0 +1,412 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.AdditionalAnswers.delegatesTo; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import com.google.protobuf.Any; +import io.envoyproxy.envoy.config.listener.v3.Listener; +import io.grpc.MetricInstrument; +import io.grpc.MetricRecorder; +import io.grpc.MetricRecorder.BatchCallback; +import io.grpc.MetricRecorder.BatchRecorder; +import io.grpc.MetricSink; +import io.grpc.xds.XdsClientMetricReporterImpl.MetricReporterCallback; +import io.grpc.xds.client.XdsClient; +import io.grpc.xds.client.XdsClient.ResourceMetadata; +import io.grpc.xds.client.XdsClient.ServerConnectionCallback; +import io.grpc.xds.client.XdsResourceType; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.logging.Handler; +import java.util.logging.Level; +import java.util.logging.LogRecord; +import java.util.logging.Logger; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatcher; +import org.mockito.Captor; +import org.mockito.InOrder; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +/** + * Unit tests for {@link XdsClientMetricReporterImpl}. + */ +@RunWith(JUnit4.class) +public class XdsClientMetricReporterImplTest { + + private static final String target = "test-target"; + private static final String authority = "test-authority"; + private static final String server = "trafficdirector.googleapis.com"; + private static final String resourceTypeUrl = + "resourceTypeUrl.googleapis.com/envoy.config.cluster.v3.Cluster"; + + @Rule + public final MockitoRule mocks = MockitoJUnit.rule(); + + @Mock + private XdsClient mockXdsClient; + @Captor + private ArgumentCaptor gaugeBatchCallbackCaptor; + private MetricRecorder mockMetricRecorder = mock(MetricRecorder.class, + delegatesTo(new MetricRecorderImpl())); + private BatchRecorder mockBatchRecorder = mock(BatchRecorder.class, + delegatesTo(new BatchRecorderImpl())); + + private XdsClientMetricReporterImpl reporter; + + @Before + public void setUp() { + reporter = new XdsClientMetricReporterImpl(mockMetricRecorder, target); + } + + @Test + public void reportResourceUpdates() { + reporter.reportResourceUpdates(10, 5, server, resourceTypeUrl); + verify(mockMetricRecorder).addLongCounter( + eqMetricInstrumentName("grpc.xds_client.resource_updates_valid"), eq((long) 10), + eq(Lists.newArrayList(target, server, resourceTypeUrl)), + eq(Lists.newArrayList())); + verify(mockMetricRecorder).addLongCounter( + eqMetricInstrumentName("grpc.xds_client.resource_updates_invalid"), + eq((long) 5), + eq(Lists.newArrayList(target, server, resourceTypeUrl)), + eq(Lists.newArrayList())); + } + + @Test + public void reportServerFailure() { + reporter.reportServerFailure(1, server); + verify(mockMetricRecorder).addLongCounter( + eqMetricInstrumentName("grpc.xds_client.server_failure"), eq((long) 1), + eq(Lists.newArrayList(target, server)), + eq(Lists.newArrayList())); + } + + @Test + public void setXdsClient_reportMetrics() throws Exception { + SettableFuture future = SettableFuture.create(); + future.set(null); + when(mockXdsClient.getSubscribedResourcesMetadataSnapshot()).thenReturn(Futures.immediateFuture( + ImmutableMap.of())); + when(mockXdsClient.reportServerConnections(any(ServerConnectionCallback.class))) + .thenReturn(future); + reporter.setXdsClient(mockXdsClient); + verify(mockMetricRecorder).registerBatchCallback(gaugeBatchCallbackCaptor.capture(), + eqMetricInstrumentName("grpc.xds_client.connected"), + eqMetricInstrumentName("grpc.xds_client.resources")); + gaugeBatchCallbackCaptor.getValue().accept(mockBatchRecorder); + verify(mockXdsClient).reportServerConnections(any(ServerConnectionCallback.class)); + } + + @Test + public void setXdsClient_reportCallbackMetrics_resourceCountsFails() { + TestlogHandler testLogHandler = new TestlogHandler(); + Logger logger = Logger.getLogger(XdsClientMetricReporterImpl.class.getName()); + logger.addHandler(testLogHandler); + + // For reporting resource counts connections, return a normally completed future + SettableFuture future = SettableFuture.create(); + future.set(null); + when(mockXdsClient.getSubscribedResourcesMetadataSnapshot()).thenReturn(Futures.immediateFuture( + ImmutableMap.of())); + + // Create a future that will throw an exception + SettableFuture serverConnectionsFeature = SettableFuture.create(); + serverConnectionsFeature.setException(new Exception("test")); + when(mockXdsClient.reportServerConnections(any())).thenReturn(serverConnectionsFeature); + + reporter.setXdsClient(mockXdsClient); + verify(mockMetricRecorder) + .registerBatchCallback(gaugeBatchCallbackCaptor.capture(), any(), any()); + gaugeBatchCallbackCaptor.getValue().accept(mockBatchRecorder); + // Verify that the xdsClient methods were called + // verify(mockXdsClient).reportResourceCounts(any()); + verify(mockXdsClient).reportServerConnections(any()); + + assertThat(testLogHandler.getLogs().size()).isEqualTo(1); + assertThat(testLogHandler.getLogs().get(0).getLevel()).isEqualTo(Level.WARNING); + assertThat(testLogHandler.getLogs().get(0).getMessage()).isEqualTo( + "Failed to report gauge metrics"); + logger.removeHandler(testLogHandler); + } + + @Test + public void metricGauges() { + SettableFuture future = SettableFuture.create(); + future.set(null); + when(mockXdsClient.getSubscribedResourcesMetadataSnapshot()) + .thenReturn(Futures.immediateFuture(ImmutableMap.of())); + when(mockXdsClient.reportServerConnections(any(ServerConnectionCallback.class))) + .thenReturn(future); + reporter.setXdsClient(mockXdsClient); + verify(mockMetricRecorder).registerBatchCallback(gaugeBatchCallbackCaptor.capture(), + eqMetricInstrumentName("grpc.xds_client.connected"), + eqMetricInstrumentName("grpc.xds_client.resources")); + BatchCallback gaugeBatchCallback = gaugeBatchCallbackCaptor.getValue(); + InOrder inOrder = inOrder(mockBatchRecorder); + // Trigger the internal call to reportCallbackMetrics() + gaugeBatchCallback.accept(mockBatchRecorder); + + ArgumentCaptor serverConnectionCallbackCaptor = + ArgumentCaptor.forClass(ServerConnectionCallback.class); + // verify(mockXdsClient).reportResourceCounts(resourceCallbackCaptor.capture()); + verify(mockXdsClient).reportServerConnections(serverConnectionCallbackCaptor.capture()); + + // Get the captured callback + MetricReporterCallback callback = (MetricReporterCallback) + serverConnectionCallbackCaptor.getValue(); + + // Verify that reportResourceCounts and reportServerConnections were called + // with the captured callback + callback.reportResourceCountGauge(10, "MrPotatoHead", + "acked", resourceTypeUrl); + inOrder.verify(mockBatchRecorder) + .recordLongGauge(eqMetricInstrumentName("grpc.xds_client.resources"), eq(10L), any(), + any()); + callback.reportServerConnectionGauge(true, "xdsServer"); + inOrder.verify(mockBatchRecorder) + .recordLongGauge(eqMetricInstrumentName("grpc.xds_client.connected"), + eq(1L), any(), any()); + + inOrder.verifyNoMoreInteractions(); + } + + @Test + public void metricReporterCallback() { + MetricReporterCallback callback = + new MetricReporterCallback(mockBatchRecorder, target); + + callback.reportServerConnectionGauge(true, server); + verify(mockBatchRecorder, times(1)).recordLongGauge( + eqMetricInstrumentName("grpc.xds_client.connected"), eq(1L), + eq(Lists.newArrayList(target, server)), + eq(Lists.newArrayList())); + + String cacheState = "requested"; + callback.reportResourceCountGauge(10, authority, cacheState, resourceTypeUrl); + verify(mockBatchRecorder, times(1)).recordLongGauge( + eqMetricInstrumentName("grpc.xds_client.resources"), eq(10L), + eq(Arrays.asList(target, authority, cacheState, resourceTypeUrl)), + eq(Collections.emptyList())); + } + + @Test + public void reportCallbackMetrics_computeAndReportResourceCounts() { + Map, Map> metadataByType = new HashMap<>(); + XdsResourceType listenerResource = XdsListenerResource.getInstance(); + XdsResourceType routeConfigResource = XdsRouteConfigureResource.getInstance(); + XdsResourceType clusterResource = XdsClusterResource.getInstance(); + + Any rawListener = Any.pack(Listener.newBuilder().setName("listener.googleapis.com").build()); + long nanosLastUpdate = 1577923199_606042047L; + + Map ldsResourceMetadataMap = new HashMap<>(); + ldsResourceMetadataMap.put("xdstp://authority1", + ResourceMetadata.newResourceMetadataRequested()); + ResourceMetadata ackedLdsResource = + ResourceMetadata.newResourceMetadataAcked(rawListener, "42", nanosLastUpdate); + ldsResourceMetadataMap.put("resource2", ackedLdsResource); + ldsResourceMetadataMap.put("resource3", + ResourceMetadata.newResourceMetadataAcked(rawListener, "43", nanosLastUpdate)); + ldsResourceMetadataMap.put("xdstp:/no_authority", + ResourceMetadata.newResourceMetadataNacked(ackedLdsResource, "44", + nanosLastUpdate, "nacked after previous ack", true)); + + Map rdsResourceMetadataMap = new HashMap<>(); + ResourceMetadata requestedRdsResourceMetadata = ResourceMetadata.newResourceMetadataRequested(); + rdsResourceMetadataMap.put("xdstp://authority5", + ResourceMetadata.newResourceMetadataNacked(requestedRdsResourceMetadata, "24", + nanosLastUpdate, "nacked after request", false)); + rdsResourceMetadataMap.put("xdstp://authority6", + ResourceMetadata.newResourceMetadataDoesNotExist()); + + Map cdsResourceMetadataMap = new HashMap<>(); + cdsResourceMetadataMap.put("xdstp://authority7", ResourceMetadata.newResourceMetadataUnknown()); + + metadataByType.put(listenerResource, ldsResourceMetadataMap); + metadataByType.put(routeConfigResource, rdsResourceMetadataMap); + metadataByType.put(clusterResource, cdsResourceMetadataMap); + + SettableFuture reportServerConnectionsCompleted = SettableFuture.create(); + reportServerConnectionsCompleted.set(null); + when(mockXdsClient.reportServerConnections(any(MetricReporterCallback.class))) + .thenReturn(reportServerConnectionsCompleted); + + ListenableFuture, Map>> + getResourceMetadataCompleted = Futures.immediateFuture(metadataByType); + when(mockXdsClient.getSubscribedResourcesMetadataSnapshot()) + .thenReturn(getResourceMetadataCompleted); + + reporter.reportCallbackMetrics(mockBatchRecorder, mockXdsClient); + + // LDS resource requested + verify(mockBatchRecorder).recordLongGauge(eqMetricInstrumentName("grpc.xds_client.resources"), + eq(1L), + eq(Arrays.asList(target, "authority1", "requested", listenerResource.typeUrl())), any()); + // LDS resources acked + // authority = #old, for non-xdstp resource names + verify(mockBatchRecorder).recordLongGauge(eqMetricInstrumentName("grpc.xds_client.resources"), + eq(2L), + eq(Arrays.asList(target, "#old", "acked", listenerResource.typeUrl())), any()); + // LDS resource nacked but cached + // "" for missing authority in the resource name + verify(mockBatchRecorder).recordLongGauge(eqMetricInstrumentName("grpc.xds_client.resources"), + eq(1L), + eq(Arrays.asList(target, "", "nacked_but_cached", listenerResource.typeUrl())), any()); + + // RDS resource nacked + verify(mockBatchRecorder).recordLongGauge(eqMetricInstrumentName("grpc.xds_client.resources"), + eq(1L), + eq(Arrays.asList(target, "authority5", "nacked", routeConfigResource.typeUrl())), any()); + // RDS resource does not exist + verify(mockBatchRecorder).recordLongGauge(eqMetricInstrumentName("grpc.xds_client.resources"), + eq(1L), + eq(Arrays.asList(target, "authority6", "does_not_exist", routeConfigResource.typeUrl())), + any()); + + // CDS resource unknown + verify(mockBatchRecorder).recordLongGauge(eqMetricInstrumentName("grpc.xds_client.resources"), + eq(1L), + eq(Arrays.asList(target, "authority7", "unknown", clusterResource.typeUrl())), + any()); + verifyNoMoreInteractions(mockBatchRecorder); + } + + @Test + public void reportCallbackMetrics_computeAndReportResourceCounts_emptyResources() { + Map, Map> metadataByType = new HashMap<>(); + XdsResourceType listenerResource = XdsListenerResource.getInstance(); + metadataByType.put(listenerResource, Collections.emptyMap()); + + SettableFuture reportServerConnectionsCompleted = SettableFuture.create(); + reportServerConnectionsCompleted.set(null); + when(mockXdsClient.reportServerConnections(any(MetricReporterCallback.class))) + .thenReturn(reportServerConnectionsCompleted); + + ListenableFuture, Map>> + getResourceMetadataCompleted = Futures.immediateFuture(metadataByType); + when(mockXdsClient.getSubscribedResourcesMetadataSnapshot()) + .thenReturn(getResourceMetadataCompleted); + + reporter.reportCallbackMetrics(mockBatchRecorder, mockXdsClient); + + // Verify that reportResourceCountGauge is never called + verifyNoInteractions(mockBatchRecorder); + } + + @Test + public void reportCallbackMetrics_computeAndReportResourceCounts_nullMetadata() { + TestlogHandler testLogHandler = new TestlogHandler(); + Logger logger = Logger.getLogger(XdsClientMetricReporterImpl.class.getName()); + logger.addHandler(testLogHandler); + + SettableFuture reportServerConnectionsCompleted = SettableFuture.create(); + reportServerConnectionsCompleted.set(null); + when(mockXdsClient.reportServerConnections(any(MetricReporterCallback.class))) + .thenReturn(reportServerConnectionsCompleted); + + ListenableFuture, Map>> + getResourceMetadataCompleted = Futures.immediateFailedFuture( + new Exception("Error generating metadata snapshot")); + when(mockXdsClient.getSubscribedResourcesMetadataSnapshot()) + .thenReturn(getResourceMetadataCompleted); + + reporter.reportCallbackMetrics(mockBatchRecorder, mockXdsClient); + assertThat(testLogHandler.getLogs().size()).isEqualTo(1); + assertThat(testLogHandler.getLogs().get(0).getLevel()).isEqualTo(Level.WARNING); + assertThat(testLogHandler.getLogs().get(0).getMessage()).isEqualTo( + "Failed to report gauge metrics"); + logger.removeHandler(testLogHandler); + } + + @Test + public void close_closesGaugeRegistration() { + MetricSink.Registration mockRegistration = mock(MetricSink.Registration.class); + when(mockMetricRecorder.registerBatchCallback(any(MetricRecorder.BatchCallback.class), + eqMetricInstrumentName("grpc.xds_client.connected"), + eqMetricInstrumentName("grpc.xds_client.resources"))).thenReturn(mockRegistration); + + // Sets XdsClient and register the gauges + reporter.setXdsClient(mockXdsClient); + // Closes registered gauges + reporter.close(); + verify(mockRegistration, times(1)).close(); + } + + @SuppressWarnings("TypeParameterUnusedInFormals") + private T eqMetricInstrumentName(String name) { + return argThat(new ArgumentMatcher() { + @Override + public boolean matches(T instrument) { + return instrument.getName().equals(name); + } + }); + } + + static class MetricRecorderImpl implements MetricRecorder { + } + + static class BatchRecorderImpl implements BatchRecorder { + } + + static class TestlogHandler extends Handler { + List logs = new ArrayList<>(); + + @Override + public void publish(LogRecord record) { + logs.add(record); + } + + @Override + public void close() {} + + @Override + public void flush() {} + + public List getLogs() { + return logs; + } + } + +} diff --git a/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java b/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java index f3f4d74eb2f..81186d0639c 100644 --- a/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java +++ b/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java @@ -32,9 +32,11 @@ import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.SettableFuture; +import io.envoyproxy.envoy.config.core.v3.SocketAddress.Protocol; import io.grpc.Server; import io.grpc.ServerBuilder; import io.grpc.Status; +import io.grpc.StatusOr; import io.grpc.inprocess.InProcessSocketAddress; import io.grpc.internal.TestUtils.NoopChannelLogger; import io.grpc.netty.GrpcHttp2ConnectionHandler; @@ -119,7 +121,8 @@ public void setUp() { when(mockBuilder.build()).thenReturn(mockServer); when(mockServer.isShutdown()).thenReturn(false); xdsServerWrapper = new XdsServerWrapper("0.0.0.0:" + PORT, mockBuilder, listener, - selectorManager, new FakeXdsClientPoolFactory(xdsClient), FilterRegistry.newRegistry()); + selectorManager, new FakeXdsClientPoolFactory(xdsClient), + XdsServerTestHelper.RAW_BOOTSTRAP, FilterRegistry.newRegistry()); } @Test @@ -165,11 +168,12 @@ public void run() { EnvoyServerProtoData.Listener tcpListener = EnvoyServerProtoData.Listener.create( "listener1", - "10.1.2.3", + "0.0.0.0:7000", ImmutableList.of(), - null); + null, + Protocol.TCP); LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(tcpListener); - xdsClient.ldsWatcher.onChanged(listenerUpdate); + xdsClient.ldsWatcher.onResourceChanged(StatusOr.fromValue(listenerUpdate)); verify(listener, timeout(5000)).onServing(); start.get(START_WAIT_AFTER_LISTENER_MILLIS, TimeUnit.MILLISECONDS); FilterChainSelector selector = selectorManager.getSelectorToUpdateSelector(); @@ -190,7 +194,8 @@ public void run() { } }); String ldsWatched = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); - xdsClient.ldsWatcher.onResourceDoesNotExist(ldsWatched); + Status status = Status.NOT_FOUND.withDescription("Resource not found: " + ldsWatched); + xdsClient.ldsWatcher.onResourceChanged(StatusOr.fromStatus(status)); verify(listener, timeout(5000)).onNotServing(any()); try { start.get(START_WAIT_AFTER_LISTENER_MILLIS, TimeUnit.MILLISECONDS); @@ -275,7 +280,8 @@ public void releaseOldSupplierOnNotFound_verifyClose() throws Exception { getSslContextProviderSupplier(selectorManager.getSelectorToUpdateSelector()); assertThat(returnedSupplier.getTlsContext()).isSameInstanceAs(tlsContext1); callUpdateSslContext(returnedSupplier); - xdsClient.ldsWatcher.onResourceDoesNotExist("not-found Error"); + Status status = Status.NOT_FOUND.withDescription("not-found Error"); + xdsClient.ldsWatcher.onResourceChanged(StatusOr.fromStatus(status)); verify(tlsContextManager, times(1)).releaseServerSslContextProvider(eq(sslContextProvider1)); } @@ -292,14 +298,14 @@ public void releaseOldSupplierOnTemporaryError_noClose() throws Exception { getSslContextProviderSupplier(selectorManager.getSelectorToUpdateSelector()); assertThat(returnedSupplier.getTlsContext()).isSameInstanceAs(tlsContext1); callUpdateSslContext(returnedSupplier); - xdsClient.ldsWatcher.onError(Status.CANCELLED); + xdsClient.ldsWatcher.onAmbientError(Status.CANCELLED); verify(tlsContextManager, never()).releaseServerSslContextProvider(eq(sslContextProvider1)); } private void callUpdateSslContext(SslContextProviderSupplier sslContextProviderSupplier) { assertThat(sslContextProviderSupplier).isNotNull(); SslContextProvider.Callback callback = mock(SslContextProvider.Callback.class); - sslContextProviderSupplier.updateSslContext(callback); + sslContextProviderSupplier.updateSslContext(callback, false); } private void sendListenerUpdate( diff --git a/xds/src/test/java/io/grpc/xds/XdsDependencyManagerTest.java b/xds/src/test/java/io/grpc/xds/XdsDependencyManagerTest.java new file mode 100644 index 00000000000..522eb29c001 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/XdsDependencyManagerTest.java @@ -0,0 +1,1059 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.StatusMatcher.statusHasCode; +import static io.grpc.xds.XdsClusterResource.CdsUpdate.ClusterType.EDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_CDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_EDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_LDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_RDS; +import static io.grpc.xds.XdsTestUtils.CLUSTER_NAME; +import static io.grpc.xds.XdsTestUtils.ENDPOINT_HOSTNAME; +import static io.grpc.xds.XdsTestUtils.ENDPOINT_PORT; +import static io.grpc.xds.XdsTestUtils.RDS_NAME; +import static io.grpc.xds.XdsTestUtils.getEdsNameForCluster; +import static org.mockito.AdditionalAnswers.delegatesTo; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.util.concurrent.MoreExecutors; +import com.google.protobuf.Message; +import io.envoyproxy.envoy.config.cluster.v3.Cluster; +import io.envoyproxy.envoy.config.core.v3.Address; +import io.envoyproxy.envoy.config.core.v3.SocketAddress; +import io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment; +import io.envoyproxy.envoy.config.endpoint.v3.Endpoint; +import io.envoyproxy.envoy.config.endpoint.v3.LbEndpoint; +import io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints; +import io.envoyproxy.envoy.config.listener.v3.Listener; +import io.envoyproxy.envoy.config.route.v3.RouteConfiguration; +import io.grpc.BindableService; +import io.grpc.ChannelLogger; +import io.grpc.EquivalentAddressGroup; +import io.grpc.NameResolver; +import io.grpc.NameResolverRegistry; +import io.grpc.Status; +import io.grpc.StatusOr; +import io.grpc.StatusOrMatcher; +import io.grpc.SynchronizationContext; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.internal.FakeClock; +import io.grpc.internal.GrpcUtil; +import io.grpc.internal.testing.FakeNameResolverProvider; +import io.grpc.testing.GrpcCleanupRule; +import io.grpc.xds.XdsClusterResource.CdsUpdate; +import io.grpc.xds.XdsConfig.XdsClusterConfig; +import io.grpc.xds.XdsEndpointResource.EdsUpdate; +import io.grpc.xds.client.Locality; +import io.grpc.xds.client.XdsClient; +import io.grpc.xds.client.XdsClient.ResourceMetadata; +import io.grpc.xds.client.XdsResourceType; +import java.io.Closeable; +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.logging.Logger; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatcher; +import org.mockito.ArgumentMatchers; +import org.mockito.Captor; +import org.mockito.InOrder; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +/** Unit tests for {@link XdsDependencyManager}. */ +@RunWith(JUnit4.class) +public class XdsDependencyManagerTest { + private static final Logger log = Logger.getLogger(XdsDependencyManagerTest.class.getName()); + public static final String CLUSTER_TYPE_NAME = XdsClusterResource.getInstance().typeName(); + public static final String ENDPOINT_TYPE_NAME = XdsEndpointResource.getInstance().typeName(); + + private final SynchronizationContext syncContext = + new SynchronizationContext((t, e) -> { + throw new AssertionError(e); + }); + private final FakeClock fakeClock = new FakeClock(); + + private XdsClient xdsClient = XdsTestUtils.createXdsClient( + Collections.singletonList("control-plane"), + serverInfo -> new GrpcXdsTransportFactory.GrpcXdsTransport( + InProcessChannelBuilder.forName(serverInfo.target()).directExecutor().build()), + fakeClock); + + private TestWatcher xdsConfigWatcher; + + private final String serverName = "the-service-name"; + private final Queue loadReportCalls = new ArrayDeque<>(); + private final AtomicBoolean adsEnded = new AtomicBoolean(true); + private final AtomicBoolean lrsEnded = new AtomicBoolean(true); + private final XdsTestControlPlaneService controlPlaneService = new XdsTestControlPlaneService(); + private final BindableService lrsService = + XdsTestUtils.createLrsService(lrsEnded, loadReportCalls); + private final NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); + + @Rule + public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); + @Rule + public final MockitoRule mocks = MockitoJUnit.rule(); + private TestWatcher testWatcher; + private XdsConfig defaultXdsConfig; // set in setUp() + + @Captor + private ArgumentCaptor> xdsUpdateCaptor; + private final NameResolver.Args nameResolverArgs = NameResolver.Args.newBuilder() + .setDefaultPort(8080) + .setProxyDetector(GrpcUtil.DEFAULT_PROXY_DETECTOR) + .setSynchronizationContext(syncContext) + .setServiceConfigParser(mock(NameResolver.ServiceConfigParser.class)) + .setChannelLogger(mock(ChannelLogger.class)) + .setScheduledExecutorService(fakeClock.getScheduledExecutorService()) + .setNameResolverRegistry(nameResolverRegistry) + .build(); + + private XdsDependencyManager xdsDependencyManager = new XdsDependencyManager( + xdsClient, syncContext, serverName, serverName, nameResolverArgs); + private boolean savedEnableLogicalDns; + + @Before + public void setUp() throws Exception { + cleanupRule.register(InProcessServerBuilder + .forName("control-plane") + .addService(controlPlaneService) + .addService(lrsService) + .directExecutor() + .build() + .start()); + + XdsTestUtils.setAdsConfig(controlPlaneService, serverName); + + testWatcher = new TestWatcher(); + xdsConfigWatcher = mock(TestWatcher.class, delegatesTo(testWatcher)); + defaultXdsConfig = XdsTestUtils.getDefaultXdsConfig(serverName); + + savedEnableLogicalDns = XdsDependencyManager.enableLogicalDns; + } + + @After + public void tearDown() throws InterruptedException { + if (xdsDependencyManager != null) { + xdsDependencyManager.shutdown(); + } + xdsClient.shutdown(); + + assertThat(adsEnded.get()).isTrue(); + assertThat(lrsEnded.get()).isTrue(); + assertThat(fakeClock.getPendingTasks()).isEmpty(); + + XdsDependencyManager.enableLogicalDns = savedEnableLogicalDns; + } + + @Test + public void verify_basic_config() { + xdsDependencyManager.start(xdsConfigWatcher); + + verify(xdsConfigWatcher).onUpdate(StatusOr.fromValue(defaultXdsConfig)); + testWatcher.verifyStats(1, 0); + } + + @Test + public void verify_config_update() { + xdsDependencyManager.start(xdsConfigWatcher); + + InOrder inOrder = Mockito.inOrder(xdsConfigWatcher); + inOrder.verify(xdsConfigWatcher).onUpdate(StatusOr.fromValue(defaultXdsConfig)); + testWatcher.verifyStats(1, 0); + assertThat(testWatcher.lastConfig).isEqualTo(defaultXdsConfig); + + XdsTestUtils.setAdsConfig(controlPlaneService, serverName, "RDS2", "CDS2", "EDS2", + ENDPOINT_HOSTNAME + "2", ENDPOINT_PORT + 2); + inOrder.verify(xdsConfigWatcher).onUpdate(ArgumentMatchers.notNull()); + testWatcher.verifyStats(2, 0); + assertThat(testWatcher.lastConfig).isNotEqualTo(defaultXdsConfig); + } + + @Test + public void verify_simple_aggregate() { + InOrder inOrder = Mockito.inOrder(xdsConfigWatcher); + xdsDependencyManager.start(xdsConfigWatcher); + inOrder.verify(xdsConfigWatcher).onUpdate(StatusOr.fromValue(defaultXdsConfig)); + + List childNames = Arrays.asList("clusterC", "clusterB", "clusterA"); + String rootName = "root_c"; + + RouteConfiguration routeConfig = + XdsTestUtils.buildRouteConfiguration(serverName, XdsTestUtils.RDS_NAME, rootName); + controlPlaneService.setXdsConfig( + ADS_TYPE_URL_RDS, ImmutableMap.of(XdsTestUtils.RDS_NAME, routeConfig)); + + XdsTestUtils.setAggregateCdsConfig(controlPlaneService, serverName, rootName, childNames); + inOrder.verify(xdsConfigWatcher).onUpdate(any()); + + Map> lastConfigClusters = + testWatcher.lastConfig.getClusters(); + assertThat(lastConfigClusters).hasSize(childNames.size() + 1); + StatusOr rootC = lastConfigClusters.get(rootName); + assertThat(rootC.getValue().getChildren()).isInstanceOf(XdsClusterConfig.AggregateConfig.class); + XdsClusterConfig.AggregateConfig aggConfig = + (XdsClusterConfig.AggregateConfig) rootC.getValue().getChildren(); + assertThat(aggConfig.getLeafNames()).isEqualTo(childNames); + + for (String childName : childNames) { + assertThat(lastConfigClusters).containsKey(childName); + StatusOr childConfigOr = lastConfigClusters.get(childName); + CdsUpdate childResource = + childConfigOr.getValue().getClusterResource(); + assertThat(childResource.clusterType()).isEqualTo(EDS); + assertThat(childResource.edsServiceName()).isEqualTo(getEdsNameForCluster(childName)); + + StatusOr endpoint = getEndpoint(childConfigOr); + assertThat(endpoint.hasValue()).isTrue(); + assertThat(endpoint.getValue().clusterName).isEqualTo(getEdsNameForCluster(childName)); + } + } + + private static StatusOr getEndpoint(StatusOr childConfigOr) { + XdsClusterConfig.ClusterChild clusterChild = childConfigOr.getValue() + .getChildren(); + assertThat(clusterChild).isInstanceOf(XdsClusterConfig.EndpointConfig.class); + StatusOr endpoint = ((XdsClusterConfig.EndpointConfig) clusterChild).getEndpoint(); + assertThat(endpoint).isNotNull(); + return endpoint; + } + + @Test + public void testComplexRegisteredAggregate() throws IOException { + InOrder inOrder = Mockito.inOrder(xdsConfigWatcher); + + // Do initialization + String rootName1 = "root_c"; + List childNames = Arrays.asList("clusterC", "clusterB", "clusterA"); + XdsTestUtils.addAggregateToExistingConfig(controlPlaneService, rootName1, childNames); + + String rootName2 = "root_2"; + List childNames2 = Arrays.asList("clusterA", "clusterX"); + XdsTestUtils.addAggregateToExistingConfig(controlPlaneService, rootName2, childNames2); + + xdsDependencyManager.start(xdsConfigWatcher); + inOrder.verify(xdsConfigWatcher).onUpdate(any()); + + Closeable subscription1 = xdsDependencyManager.subscribeToCluster(rootName1); + inOrder.verify(xdsConfigWatcher).onUpdate(any()); + + Closeable subscription2 = xdsDependencyManager.subscribeToCluster(rootName2); + inOrder.verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + testWatcher.verifyStats(3, 0); + ImmutableSet.Builder builder = ImmutableSet.builder(); + Set expectedClusters = builder.add(rootName1).add(rootName2).add(CLUSTER_NAME) + .addAll(childNames).addAll(childNames2).build(); + assertThat(xdsUpdateCaptor.getValue().getValue().getClusters().keySet()) + .isEqualTo(expectedClusters); + + // Close 1 subscription shouldn't affect the other or RDS subscriptions + subscription1.close(); + inOrder.verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + builder = ImmutableSet.builder(); + Set expectedClusters2 = + builder.add(rootName2).add(CLUSTER_NAME).addAll(childNames2).build(); + assertThat(xdsUpdateCaptor.getValue().getValue().getClusters().keySet()) + .isEqualTo(expectedClusters2); + + subscription2.close(); + inOrder.verify(xdsConfigWatcher).onUpdate(StatusOr.fromValue(defaultXdsConfig)); + } + + @Test + public void testDelayedSubscription() { + InOrder inOrder = Mockito.inOrder(xdsConfigWatcher); + xdsDependencyManager.start(xdsConfigWatcher); + inOrder.verify(xdsConfigWatcher).onUpdate(StatusOr.fromValue(defaultXdsConfig)); + + String rootName1 = "root_c"; + + Closeable subscription1 = xdsDependencyManager.subscribeToCluster(rootName1); + assertThat(subscription1).isNotNull(); + fakeClock.forwardTime(16, TimeUnit.SECONDS); + inOrder.verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + Status status = xdsUpdateCaptor.getValue().getValue().getClusters().get(rootName1).getStatus(); + assertThat(status.getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(status.getDescription()).contains(rootName1); + + List childNames = Arrays.asList("clusterC", "clusterB", "clusterA"); + XdsTestUtils.addAggregateToExistingConfig(controlPlaneService, rootName1, childNames); + inOrder.verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + assertThat(xdsUpdateCaptor.getValue().getValue().getClusters().get(rootName1).hasValue()) + .isTrue(); + } + + @Test + public void testMissingCdsAndEds() { + // update config so that agg cluster references 2 existing & 1 non-existing cluster + List childNames = Arrays.asList("clusterC", "clusterB", "clusterA"); + Cluster cluster = XdsTestUtils.buildAggCluster(CLUSTER_NAME, childNames); + Map clusterMap = new HashMap<>(); + Map edsMap = new HashMap<>(); + + clusterMap.put(CLUSTER_NAME, cluster); + for (int i = 0; i < childNames.size() - 1; i++) { + String edsName = XdsTestUtils.EDS_NAME + i; + Cluster child = ControlPlaneRule.buildCluster(childNames.get(i), edsName); + clusterMap.put(childNames.get(i), child); + } + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, clusterMap); + + // Update config so that one of the 2 "valid" clusters has an EDS resource, the other does not + // and there is an EDS that doesn't have matching clusters + ClusterLoadAssignment clusterLoadAssignment = ControlPlaneRule.buildClusterLoadAssignment( + "127.0.1.1", ENDPOINT_HOSTNAME, ENDPOINT_PORT, XdsTestUtils.EDS_NAME + 0); + edsMap.put(XdsTestUtils.EDS_NAME + 0, clusterLoadAssignment); + clusterLoadAssignment = ControlPlaneRule.buildClusterLoadAssignment( + "127.0.1.2", ENDPOINT_HOSTNAME, ENDPOINT_PORT, "garbageEds"); + edsMap.put("garbageEds", clusterLoadAssignment); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, edsMap); + + xdsDependencyManager.start(xdsConfigWatcher); + + fakeClock.forwardTime(16, TimeUnit.SECONDS); + verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + + List> returnedClusters = new ArrayList<>(); + for (String childName : childNames) { + returnedClusters.add(xdsUpdateCaptor.getValue().getValue().getClusters().get(childName)); + } + + // Check that missing cluster reported Status and the other 2 are present + StatusOr missingCluster = returnedClusters.get(2); + assertThat(missingCluster.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(missingCluster.getStatus().getDescription()).contains(childNames.get(2)); + assertThat(returnedClusters.get(0).hasValue()).isTrue(); + assertThat(returnedClusters.get(1).hasValue()).isTrue(); + + // Check that missing EDS reported Status, the other one is present and the garbage EDS is not + assertThat(getEndpoint(returnedClusters.get(0)).hasValue()).isTrue(); + assertThat(getEndpoint(returnedClusters.get(1)).getStatus().getCode()) + .isEqualTo(Status.Code.UNAVAILABLE); + assertThat(getEndpoint(returnedClusters.get(1)).getStatus().getDescription()) + .contains(XdsTestUtils.EDS_NAME + 1); + + verify(xdsConfigWatcher, never()).onUpdate( + argThat(StatusOrMatcher.hasStatus(statusHasCode(Status.Code.UNAVAILABLE)))); + testWatcher.verifyStats(1, 0); + } + + @Test + public void testMissingLds() { + String ldsName = "badLdsName"; + xdsDependencyManager = new XdsDependencyManager(xdsClient, syncContext, + serverName, ldsName, nameResolverArgs); + xdsDependencyManager.start(xdsConfigWatcher); + + fakeClock.forwardTime(16, TimeUnit.SECONDS); + verify(xdsConfigWatcher).onUpdate( + argThat(StatusOrMatcher.hasStatus(statusHasCode(Status.Code.UNAVAILABLE) + .andDescriptionContains(ldsName)))); + + testWatcher.verifyStats(0, 1); + } + + @Test + public void testTcpListenerErrors() { + Listener serverListener = + ControlPlaneRule.buildServerListener().toBuilder().setName(serverName).build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_LDS, ImmutableMap.of(serverName, serverListener)); + xdsDependencyManager.start(xdsConfigWatcher); + + fakeClock.forwardTime(16, TimeUnit.SECONDS); + verify(xdsConfigWatcher).onUpdate( + argThat(StatusOrMatcher.hasStatus( + statusHasCode(Status.Code.UNAVAILABLE).andDescriptionContains("Not an API listener")))); + + testWatcher.verifyStats(0, 1); + } + + @Test + public void testControlPlaneError() { + Status forcedStatus = Status.NOT_FOUND + .withDescription("expected") + .withCause(new IllegalArgumentException("a random exception")); + xdsClient.shutdown(); + xdsClient = XdsTestUtils.createXdsClient( + Collections.singletonList("control-plane"), + serverInfo -> new GrpcXdsTransportFactory.GrpcXdsTransport( + InProcessChannelBuilder.forName(serverInfo.target()) + .directExecutor() + .intercept(new FailingClientInterceptor(forcedStatus)) + .build()), + fakeClock); + xdsDependencyManager = new XdsDependencyManager( + xdsClient, syncContext, serverName, serverName, nameResolverArgs); + xdsDependencyManager.start(xdsConfigWatcher); + + verify(xdsConfigWatcher).onUpdate( + argThat(StatusOrMatcher.hasStatus( + statusHasCode(Status.Code.UNAVAILABLE) + .andDescriptionContains(forcedStatus.getDescription()) + .andCause(forcedStatus.getCause())))); + testWatcher.verifyStats(0, 1); + } + + @Test + public void testMissingRds() { + String rdsName = "badRdsName"; + Listener clientListener = ControlPlaneRule.buildClientListener(serverName, rdsName); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_LDS, + ImmutableMap.of(serverName, clientListener)); + + xdsDependencyManager.start(xdsConfigWatcher); + + fakeClock.forwardTime(16, TimeUnit.SECONDS); + verify(xdsConfigWatcher).onUpdate( + argThat(StatusOrMatcher.hasStatus(statusHasCode(Status.Code.UNAVAILABLE) + .andDescriptionContains(rdsName)))); + + testWatcher.verifyStats(0, 1); + } + + @Test + public void testUpdateToMissingVirtualHost() { + RouteConfiguration routeConfig = XdsTestUtils.buildRouteConfiguration( + "wrong-virtual-host", XdsTestUtils.RDS_NAME, XdsTestUtils.CLUSTER_NAME); + controlPlaneService.setXdsConfig( + ADS_TYPE_URL_RDS, ImmutableMap.of(XdsTestUtils.RDS_NAME, routeConfig)); + xdsDependencyManager.start(xdsConfigWatcher); + + // Update with a config that has a virtual host that doesn't match the server name + verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + assertThat(xdsUpdateCaptor.getValue().getStatus().getDescription()) + .contains("Failed to find virtual host matching hostname: " + serverName); + + testWatcher.verifyStats(0, 1); + } + + @Test + public void testCorruptLds() { + String ldsResourceName = + "xdstp://unknown.example.com/envoy.config.listener.v3.Listener/listener1"; + + xdsDependencyManager = new XdsDependencyManager(xdsClient, syncContext, + serverName, ldsResourceName, nameResolverArgs); + xdsDependencyManager.start(xdsConfigWatcher); + + verify(xdsConfigWatcher).onUpdate( + argThat(StatusOrMatcher.hasStatus( + statusHasCode(Status.Code.UNAVAILABLE).andDescriptionContains(ldsResourceName)))); + + fakeClock.forwardTime(16, TimeUnit.SECONDS); + testWatcher.verifyStats(0, 1); + } + + @Test + public void testChangeRdsName_fromLds() { + InOrder inOrder = Mockito.inOrder(xdsConfigWatcher); + xdsDependencyManager.start(xdsConfigWatcher); + inOrder.verify(xdsConfigWatcher).onUpdate(StatusOr.fromValue(defaultXdsConfig)); + + String newRdsName = "newRdsName1"; + + Listener clientListener = buildInlineClientListener(newRdsName, CLUSTER_NAME); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_LDS, + ImmutableMap.of(serverName, clientListener)); + inOrder.verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + assertThat(xdsUpdateCaptor.getValue().getValue()).isNotEqualTo(defaultXdsConfig); + assertThat(xdsUpdateCaptor.getValue().getValue().getVirtualHost().name()).isEqualTo(newRdsName); + } + + @Test + public void testMultipleParentsInCdsTree() throws IOException { + /* + * Configure Xds server with the following cluster tree and point RDS to root: + 2 aggregates under root A & B + B has EDS Cluster B1 && shared agg AB1; A has agg A1 && shared agg AB1 + A1 has shared EDS Cluster A11 && shared agg AB1 + AB1 has shared EDS Clusters A11 && AB11 + + As an alternate visualization, parents are: + A -> root, B -> root, A1 -> A, AB1 -> A|B|A1, B1 -> B, A11 -> A1|AB1, AB11 -> AB1 + */ + Cluster rootCluster = + XdsTestUtils.buildAggCluster("root", Arrays.asList("clusterA", "clusterB")); + Cluster clusterA = + XdsTestUtils.buildAggCluster("clusterA", Arrays.asList("clusterA1", "clusterAB1")); + Cluster clusterB = + XdsTestUtils.buildAggCluster("clusterB", Arrays.asList("clusterB1", "clusterAB1")); + Cluster clusterA1 = + XdsTestUtils.buildAggCluster("clusterA1", Arrays.asList("clusterA11", "clusterAB1")); + Cluster clusterAB1 = + XdsTestUtils.buildAggCluster("clusterAB1", Arrays.asList("clusterA11", "clusterAB11")); + + Map clusterMap = new HashMap<>(); + Map edsMap = new HashMap<>(); + + clusterMap.put("root", rootCluster); + clusterMap.put("clusterA", clusterA); + clusterMap.put("clusterB", clusterB); + clusterMap.put("clusterA1", clusterA1); + clusterMap.put("clusterAB1", clusterAB1); + + XdsTestUtils.addEdsClusters(clusterMap, edsMap, "clusterA11", "clusterAB11", "clusterB1"); + RouteConfiguration routeConfig = + XdsTestUtils.buildRouteConfiguration(serverName, XdsTestUtils.RDS_NAME, "root"); + controlPlaneService.setXdsConfig( + ADS_TYPE_URL_RDS, ImmutableMap.of(XdsTestUtils.RDS_NAME, routeConfig)); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, clusterMap); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, edsMap); + + // Start the actual test + InOrder inOrder = Mockito.inOrder(xdsConfigWatcher); + xdsDependencyManager.start(xdsConfigWatcher); + inOrder.verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + XdsConfig initialConfig = xdsUpdateCaptor.getValue().getValue(); + + // Make sure that adding subscriptions that rds points at doesn't change the config + Closeable rootSub = xdsDependencyManager.subscribeToCluster("root"); + assertThat(xdsDependencyManager.buildUpdate().getValue()).isEqualTo(initialConfig); + Closeable clusterAB11Sub = xdsDependencyManager.subscribeToCluster("clusterAB11"); + assertThat(xdsDependencyManager.buildUpdate().getValue()).isEqualTo(initialConfig); + + // Make sure that closing subscriptions that rds points at doesn't change the config + rootSub.close(); + assertThat(xdsDependencyManager.buildUpdate().getValue()).isEqualTo(initialConfig); + clusterAB11Sub.close(); + assertThat(xdsDependencyManager.buildUpdate().getValue()).isEqualTo(initialConfig); + + // Make an explicit root subscription and then change RDS to point to A11 + rootSub = xdsDependencyManager.subscribeToCluster("root"); + RouteConfiguration newRouteConfig = + XdsTestUtils.buildRouteConfiguration(serverName, XdsTestUtils.RDS_NAME, "clusterA11"); + controlPlaneService.setXdsConfig( + ADS_TYPE_URL_RDS, ImmutableMap.of(XdsTestUtils.RDS_NAME, newRouteConfig)); + inOrder.verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + assertThat(xdsUpdateCaptor.getValue().getValue().getClusters()).hasSize(8); + + // Now that it is released, we should only have A11 + rootSub.close(); + inOrder.verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + assertThat(xdsUpdateCaptor.getValue().getValue().getClusters().keySet()) + .containsExactly("clusterA11"); + } + + @Test + public void testCdsDeleteUnsubscribesChild() throws Exception { + RouteConfiguration routeConfig = + XdsTestUtils.buildRouteConfiguration(serverName, XdsTestUtils.RDS_NAME, "clusterA"); + Map clusterMap = new HashMap<>(); + Map edsMap = new HashMap<>(); + XdsTestUtils.addEdsClusters(clusterMap, edsMap, "clusterA"); + controlPlaneService.setXdsConfig( + ADS_TYPE_URL_RDS, ImmutableMap.of(XdsTestUtils.RDS_NAME, routeConfig)); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, clusterMap); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, edsMap); + + InOrder inOrder = Mockito.inOrder(xdsConfigWatcher); + xdsDependencyManager.start(xdsConfigWatcher); + inOrder.verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + XdsConfig config = xdsUpdateCaptor.getValue().getValue(); + assertThat(config.getClusters().get("clusterA").hasValue()).isTrue(); + Map, Map> watches = + xdsClient.getSubscribedResourcesMetadataSnapshot().get(); + assertThat(watches.get(XdsEndpointResource.getInstance()).keySet()) + .containsExactly("eds_clusterA"); + + // Delete cluster + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of()); + inOrder.verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + config = xdsUpdateCaptor.getValue().getValue(); + assertThat(config.getClusters().get("clusterA").hasValue()).isFalse(); + watches = xdsClient.getSubscribedResourcesMetadataSnapshot().get(); + assertThat(watches).doesNotContainKey(XdsEndpointResource.getInstance()); + } + + @Test + public void testCdsCycleReclaimed() throws Exception { + RouteConfiguration routeConfig = + XdsTestUtils.buildRouteConfiguration(serverName, XdsTestUtils.RDS_NAME, "clusterA"); + Map clusterMap = new HashMap<>(); + Map edsMap = new HashMap<>(); + clusterMap.put("clusterA", XdsTestUtils.buildAggCluster("clusterA", Arrays.asList("clusterB"))); + clusterMap.put("clusterB", XdsTestUtils.buildAggCluster("clusterB", Arrays.asList("clusterA"))); + XdsTestUtils.addEdsClusters(clusterMap, edsMap, "clusterC"); + controlPlaneService.setXdsConfig( + ADS_TYPE_URL_RDS, ImmutableMap.of(XdsTestUtils.RDS_NAME, routeConfig)); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, clusterMap); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, edsMap); + + // The cycle is loaded and detected + InOrder inOrder = Mockito.inOrder(xdsConfigWatcher); + xdsDependencyManager.start(xdsConfigWatcher); + inOrder.verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + XdsConfig config = xdsUpdateCaptor.getValue().getValue(); + assertThat(config.getClusters().get("clusterA").hasValue()).isFalse(); + assertThat(config.getClusters().get("clusterA").getStatus().getDescription()).contains("cycle"); + assertThat(config.getClusters().get("clusterB").hasValue()).isTrue(); + + // Orphan the cycle and it is discarded + routeConfig = + XdsTestUtils.buildRouteConfiguration(serverName, XdsTestUtils.RDS_NAME, "clusterC"); + controlPlaneService.setXdsConfig( + ADS_TYPE_URL_RDS, ImmutableMap.of(XdsTestUtils.RDS_NAME, routeConfig)); + inOrder.verify(xdsConfigWatcher).onUpdate(any()); + Map, Map> watches = + xdsClient.getSubscribedResourcesMetadataSnapshot().get(); + assertThat(watches.get(XdsClusterResource.getInstance()).keySet()).containsExactly("clusterC"); + } + + @Test + public void testMultipleCdsReferToSameEds() { + // Create the maps and Update the config to have 2 clusters that refer to the same EDS resource + String edsName = "sharedEds"; + + Cluster rootCluster = + XdsTestUtils.buildAggCluster("root", Arrays.asList("clusterA", "clusterB")); + Cluster clusterA = ControlPlaneRule.buildCluster("clusterA", edsName); + Cluster clusterB = ControlPlaneRule.buildCluster("clusterB", edsName); + + Map clusterMap = new HashMap<>(); + clusterMap.put("root", rootCluster); + clusterMap.put("clusterA", clusterA); + clusterMap.put("clusterB", clusterB); + + Map edsMap = new HashMap<>(); + ClusterLoadAssignment clusterLoadAssignment = ControlPlaneRule.buildClusterLoadAssignment( + "127.0.1.4", ENDPOINT_HOSTNAME, ENDPOINT_PORT, edsName); + edsMap.put(edsName, clusterLoadAssignment); + + RouteConfiguration routeConfig = + XdsTestUtils.buildRouteConfiguration(serverName, XdsTestUtils.RDS_NAME, "root"); + controlPlaneService.setXdsConfig( + ADS_TYPE_URL_RDS, ImmutableMap.of(XdsTestUtils.RDS_NAME, routeConfig)); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, clusterMap); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, edsMap); + + // Start the actual test + xdsDependencyManager.start(xdsConfigWatcher); + verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + XdsConfig initialConfig = xdsUpdateCaptor.getValue().getValue(); + assertThat(initialConfig.getClusters().keySet()) + .containsExactly("root", "clusterA", "clusterB"); + + EdsUpdate edsForA = getEndpoint(initialConfig.getClusters().get("clusterA")).getValue(); + assertThat(edsForA.clusterName).isEqualTo(edsName); + EdsUpdate edsForB = getEndpoint(initialConfig.getClusters().get("clusterB")).getValue(); + assertThat(edsForB.clusterName).isEqualTo(edsName); + assertThat(edsForA).isEqualTo(edsForB); + edsForA.localityLbEndpointsMap.values().forEach( + localityLbEndpoints -> assertThat(localityLbEndpoints.endpoints()).hasSize(1)); + } + + @Test + public void testChangeRdsName_FromLds_complexTree() { + xdsDependencyManager.start(xdsConfigWatcher); + + // Create the same tree as in testMultipleParentsInCdsTree + Cluster rootCluster = + XdsTestUtils.buildAggCluster("root", Arrays.asList("clusterA", "clusterB")); + Cluster clusterA = + XdsTestUtils.buildAggCluster("clusterA", Arrays.asList("clusterA1", "clusterAB1")); + Cluster clusterB = + XdsTestUtils.buildAggCluster("clusterB", Arrays.asList("clusterB1", "clusterAB1")); + Cluster clusterA1 = + XdsTestUtils.buildAggCluster("clusterA1", Arrays.asList("clusterA11", "clusterAB1")); + Cluster clusterAB1 = + XdsTestUtils.buildAggCluster("clusterAB1", Arrays.asList("clusterA11", "clusterAB11")); + + Map clusterMap = new HashMap<>(); + Map edsMap = new HashMap<>(); + + clusterMap.put("root", rootCluster); + clusterMap.put("clusterA", clusterA); + clusterMap.put("clusterB", clusterB); + clusterMap.put("clusterA1", clusterA1); + clusterMap.put("clusterAB1", clusterAB1); + + XdsTestUtils.addEdsClusters(clusterMap, edsMap, "clusterA11", "clusterAB11", "clusterB1"); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, clusterMap); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, edsMap); + + InOrder inOrder = Mockito.inOrder(xdsConfigWatcher); + inOrder.verify(xdsConfigWatcher, atLeastOnce()).onUpdate(any()); + + // Do the test + String newRdsName = "newRdsName1"; + Listener clientListener = buildInlineClientListener(newRdsName, "root"); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_LDS, + ImmutableMap.of(serverName, clientListener)); + inOrder.verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + XdsConfig config = xdsUpdateCaptor.getValue().getValue(); + assertThat(config.getVirtualHost().name()).isEqualTo(newRdsName); + assertThat(config.getClusters()).hasSize(8); + } + + @Test + public void testChangeAggCluster() { + InOrder inOrder = Mockito.inOrder(xdsConfigWatcher); + + xdsDependencyManager.start(xdsConfigWatcher); + inOrder.verify(xdsConfigWatcher).onUpdate(any()); + + // Setup initial config A -> A1 -> (A11, A12) + Cluster rootCluster = + XdsTestUtils.buildAggCluster("root", Arrays.asList("clusterA")); + Cluster clusterA = + XdsTestUtils.buildAggCluster("clusterA", Arrays.asList("clusterA1")); + Cluster clusterA1 = + XdsTestUtils.buildAggCluster("clusterA1", Arrays.asList("clusterA11", "clusterA12")); + + Map clusterMap = new HashMap<>(); + Map edsMap = new HashMap<>(); + + clusterMap.put("root", rootCluster); + clusterMap.put("clusterA", clusterA); + clusterMap.put("clusterA1", clusterA1); + + XdsTestUtils.addEdsClusters(clusterMap, edsMap, "clusterA11", "clusterA12"); + Listener clientListener = buildInlineClientListener(RDS_NAME, "root"); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_LDS, + ImmutableMap.of(serverName, clientListener)); + + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, clusterMap); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, edsMap); + + inOrder.verify(xdsConfigWatcher).onUpdate(any()); + + // Update the cluster to A -> A2 -> (A21, A22) + Cluster clusterA2 = + XdsTestUtils.buildAggCluster("clusterA2", Arrays.asList("clusterA21", "clusterA22")); + clusterA = + XdsTestUtils.buildAggCluster("clusterA", Arrays.asList("clusterA2")); + clusterMap.clear(); + edsMap.clear(); + clusterMap.put("root", rootCluster); + clusterMap.put("clusterA", clusterA); + clusterMap.put("clusterA2", clusterA2); + XdsTestUtils.addEdsClusters(clusterMap, edsMap, "clusterA21", "clusterA22"); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, clusterMap); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, edsMap); + + // Verify that the config is updated as expected + ClusterNameMatcher nameMatcher = new ClusterNameMatcher(Arrays.asList( + "root", "clusterA", "clusterA2", "clusterA21", "clusterA22")); + inOrder.verify(xdsConfigWatcher).onUpdate(argThat(nameMatcher)); + } + + @Test + public void testLogicalDns_success() { + XdsDependencyManager.enableLogicalDns = true; + FakeSocketAddress fakeAddress = new FakeSocketAddress(); + nameResolverRegistry.register(new FakeNameResolverProvider( + "dns:///dns.example.com:1111", fakeAddress)); + Cluster cluster = Cluster.newBuilder() + .setName(CLUSTER_NAME) + .setType(Cluster.DiscoveryType.LOGICAL_DNS) + .setLoadAssignment(ClusterLoadAssignment.newBuilder() + .addEndpoints(LocalityLbEndpoints.newBuilder() + .addLbEndpoints(LbEndpoint.newBuilder() + .setEndpoint(Endpoint.newBuilder() + .setAddress(Address.newBuilder() + .setSocketAddress(SocketAddress.newBuilder() + .setAddress("dns.example.com") + .setPortValue(1111))))))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, + ImmutableMap.of(CLUSTER_NAME, cluster)); + xdsDependencyManager.start(xdsConfigWatcher); + + verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + XdsConfig config = xdsUpdateCaptor.getValue().getValue(); + XdsClusterConfig.ClusterChild clusterChild = + config.getClusters().get(CLUSTER_NAME).getValue().getChildren(); + assertThat(clusterChild).isInstanceOf(XdsClusterConfig.EndpointConfig.class); + StatusOr endpointOr = ((XdsClusterConfig.EndpointConfig) clusterChild).getEndpoint(); + assertThat(endpointOr.getStatus()).isEqualTo(Status.OK); + assertThat(endpointOr.getValue()).isEqualTo(new EdsUpdate( + "fakeEds_logicalDns", + ImmutableMap.of( + Locality.create("", "", ""), + Endpoints.LocalityLbEndpoints.create( + Arrays.asList(Endpoints.LbEndpoint.create( + new EquivalentAddressGroup(fakeAddress), + 1, true, "dns.example.com:1111", ImmutableMap.of())), + 1, 0, ImmutableMap.of())), + Arrays.asList())); + } + + @Test + public void testLogicalDns_noDnsNr() { + XdsDependencyManager.enableLogicalDns = true; + Cluster cluster = Cluster.newBuilder() + .setName(CLUSTER_NAME) + .setType(Cluster.DiscoveryType.LOGICAL_DNS) + .setLoadAssignment(ClusterLoadAssignment.newBuilder() + .addEndpoints(LocalityLbEndpoints.newBuilder() + .addLbEndpoints(LbEndpoint.newBuilder() + .setEndpoint(Endpoint.newBuilder() + .setAddress(Address.newBuilder() + .setSocketAddress(SocketAddress.newBuilder() + .setAddress("dns.example.com") + .setPortValue(1111))))))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, + ImmutableMap.of(CLUSTER_NAME, cluster)); + xdsDependencyManager.start(xdsConfigWatcher); + + verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + XdsConfig config = xdsUpdateCaptor.getValue().getValue(); + XdsClusterConfig.ClusterChild clusterChild = + config.getClusters().get(CLUSTER_NAME).getValue().getChildren(); + assertThat(clusterChild).isInstanceOf(XdsClusterConfig.EndpointConfig.class); + StatusOr endpointOr = ((XdsClusterConfig.EndpointConfig) clusterChild).getEndpoint(); + assertThat(endpointOr.getStatus().getCode()).isEqualTo(Status.Code.INTERNAL); + assertThat(endpointOr.getStatus().getDescription()) + .isEqualTo("Could not find dns name resolver"); + } + + @Test + public void testCdsError() throws IOException { + controlPlaneService.setXdsConfig( + ADS_TYPE_URL_CDS, ImmutableMap.of(XdsTestUtils.CLUSTER_NAME, + Cluster.newBuilder().setName(XdsTestUtils.CLUSTER_NAME).build())); + xdsDependencyManager.start(xdsConfigWatcher); + + verify(xdsConfigWatcher).onUpdate(xdsUpdateCaptor.capture()); + Status status = xdsUpdateCaptor.getValue().getValue() + .getClusters().get(CLUSTER_NAME).getStatus(); + assertThat(status.getDescription()).contains(XdsTestUtils.CLUSTER_NAME); + } + + @Test + public void ldsUpdateAfterShutdown() { + XdsTestUtils.setAdsConfig(controlPlaneService, serverName, "RDS", "CDS", "EDS", + ENDPOINT_HOSTNAME, ENDPOINT_PORT); + + xdsDependencyManager.start(xdsConfigWatcher); + + verify(xdsConfigWatcher).onUpdate(any()); + + @SuppressWarnings("unchecked") + XdsClient.ResourceWatcher resourceWatcher = + mock(XdsClient.ResourceWatcher.class); + xdsClient.watchXdsResource( + XdsListenerResource.getInstance(), + serverName, + resourceWatcher, + MoreExecutors.directExecutor()); + verify(resourceWatcher).onResourceChanged(argThat(StatusOr::hasValue)); + + syncContext.execute(() -> { + // Shutdown before any updates. This will unsubscribe from XdsClient, but only after this + // Runnable returns + xdsDependencyManager.shutdown(); + + XdsTestUtils.setAdsConfig(controlPlaneService, serverName, "RDS2", "CDS", "EDS", + ENDPOINT_HOSTNAME, ENDPOINT_PORT); + verify(resourceWatcher, times(2)).onResourceChanged(argThat(StatusOr::hasValue)); + xdsClient.cancelXdsResourceWatch( + XdsListenerResource.getInstance(), serverName, resourceWatcher); + }); + } + + @Test + public void rdsUpdateAfterShutdown() { + XdsTestUtils.setAdsConfig(controlPlaneService, serverName, "RDS", "CDS", "EDS", + ENDPOINT_HOSTNAME, ENDPOINT_PORT); + + xdsDependencyManager.start(xdsConfigWatcher); + + verify(xdsConfigWatcher).onUpdate(any()); + + @SuppressWarnings("unchecked") + XdsClient.ResourceWatcher resourceWatcher = + mock(XdsClient.ResourceWatcher.class); + xdsClient.watchXdsResource( + XdsRouteConfigureResource.getInstance(), + "RDS", + resourceWatcher, + MoreExecutors.directExecutor()); + verify(resourceWatcher).onResourceChanged(argThat(StatusOr::hasValue)); + + syncContext.execute(() -> { + // Shutdown before any updates. This will unsubscribe from XdsClient, but only after this + // Runnable returns + xdsDependencyManager.shutdown(); + + XdsTestUtils.setAdsConfig(controlPlaneService, serverName, "RDS", "CDS2", "EDS", + ENDPOINT_HOSTNAME, ENDPOINT_PORT); + verify(resourceWatcher, times(2)).onResourceChanged(argThat(StatusOr::hasValue)); + xdsClient.cancelXdsResourceWatch( + XdsRouteConfigureResource.getInstance(), serverName, resourceWatcher); + }); + } + + @Test + public void cdsUpdateAfterShutdown() { + XdsTestUtils.setAdsConfig(controlPlaneService, serverName, "RDS", "CDS", "EDS", + ENDPOINT_HOSTNAME, ENDPOINT_PORT); + + xdsDependencyManager.start(xdsConfigWatcher); + + verify(xdsConfigWatcher).onUpdate(any()); + + @SuppressWarnings("unchecked") + XdsClient.ResourceWatcher resourceWatcher = + mock(XdsClient.ResourceWatcher.class); + xdsClient.watchXdsResource( + XdsClusterResource.getInstance(), + "CDS", + resourceWatcher, + MoreExecutors.directExecutor()); + verify(resourceWatcher).onResourceChanged(argThat(StatusOr::hasValue)); + + syncContext.execute(() -> { + // Shutdown before any updates. This will unsubscribe from XdsClient, but only after this + // Runnable returns + xdsDependencyManager.shutdown(); + + XdsTestUtils.setAdsConfig(controlPlaneService, serverName, "RDS", "CDS", "EDS2", + ENDPOINT_HOSTNAME, ENDPOINT_PORT); + verify(resourceWatcher, times(2)).onResourceChanged(argThat(StatusOr::hasValue)); + xdsClient.cancelXdsResourceWatch( + XdsClusterResource.getInstance(), serverName, resourceWatcher); + }); + } + + @Test + public void edsUpdateAfterShutdown() { + XdsTestUtils.setAdsConfig(controlPlaneService, serverName, "RDS", "CDS", "EDS", + ENDPOINT_HOSTNAME, ENDPOINT_PORT); + + xdsDependencyManager.start(xdsConfigWatcher); + + verify(xdsConfigWatcher).onUpdate(any()); + + @SuppressWarnings("unchecked") + XdsClient.ResourceWatcher resourceWatcher = + mock(XdsClient.ResourceWatcher.class); + xdsClient.watchXdsResource( + XdsEndpointResource.getInstance(), + "EDS", + resourceWatcher, + MoreExecutors.directExecutor()); + verify(resourceWatcher).onResourceChanged(argThat(StatusOr::hasValue)); + + syncContext.execute(() -> { + // Shutdown before any updates. This will unsubscribe from XdsClient, but only after this + // Runnable returns + xdsDependencyManager.shutdown(); + + XdsTestUtils.setAdsConfig(controlPlaneService, serverName, "RDS", "CDS", "EDS", + ENDPOINT_HOSTNAME + "2", ENDPOINT_PORT); + verify(resourceWatcher, times(2)).onResourceChanged(argThat(StatusOr::hasValue)); + xdsClient.cancelXdsResourceWatch( + XdsEndpointResource.getInstance(), serverName, resourceWatcher); + }); + } + + @Test + public void subscribeToClusterAfterShutdown() throws Exception { + XdsTestUtils.setAdsConfig(controlPlaneService, serverName, "RDS", "CDS", "EDS", + ENDPOINT_HOSTNAME, ENDPOINT_PORT); + + InOrder inOrder = Mockito.inOrder(xdsConfigWatcher); + xdsDependencyManager.start(xdsConfigWatcher); + inOrder.verify(xdsConfigWatcher).onUpdate(any()); + xdsDependencyManager.shutdown(); + + Closeable subscription = xdsDependencyManager.subscribeToCluster("CDS"); + inOrder.verify(xdsConfigWatcher, never()).onUpdate(any()); + subscription.close(); + } + + private Listener buildInlineClientListener(String rdsName, String clusterName) { + return XdsTestUtils.buildInlineClientListener(rdsName, clusterName, serverName); + } + + private static class TestWatcher implements XdsDependencyManager.XdsConfigWatcher { + XdsConfig lastConfig; + int numUpdates = 0; + int numError = 0; + + @Override + public void onUpdate(StatusOr update) { + log.fine("Config update: " + update); + if (update.hasValue()) { + lastConfig = update.getValue(); + numUpdates++; + } else { + numError++; + } + } + + private List getStats() { + return Arrays.asList(numUpdates, numError); + } + + private void verifyStats(int updt, int err) { + assertThat(getStats()).isEqualTo(Arrays.asList(updt, err)); + } + } + + static class ClusterNameMatcher implements ArgumentMatcher> { + private final List expectedNames; + + ClusterNameMatcher(List expectedNames) { + this.expectedNames = expectedNames; + } + + @Override + public boolean matches(StatusOr update) { + if (!update.hasValue()) { + return false; + } + XdsConfig xdsConfig = update.getValue(); + if (xdsConfig == null || xdsConfig.getClusters() == null) { + return false; + } + return xdsConfig.getClusters().size() == expectedNames.size() + && xdsConfig.getClusters().keySet().containsAll(expectedNames); + } + } + + private static class FakeSocketAddress extends java.net.SocketAddress {} +} diff --git a/xds/src/test/java/io/grpc/xds/XdsNameResolverProviderTest.java b/xds/src/test/java/io/grpc/xds/XdsNameResolverProviderTest.java index a216c3de028..8998a2bae99 100644 --- a/xds/src/test/java/io/grpc/xds/XdsNameResolverProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsNameResolverProviderTest.java @@ -23,23 +23,28 @@ import com.google.common.collect.ImmutableMap; import io.grpc.ChannelLogger; import io.grpc.InternalServiceProviders; +import io.grpc.MetricRecorder; import io.grpc.NameResolver; import io.grpc.NameResolver.ServiceConfigParser; import io.grpc.NameResolverProvider; import io.grpc.NameResolverRegistry; import io.grpc.SynchronizationContext; +import io.grpc.Uri; import io.grpc.internal.FakeClock; import io.grpc.internal.GrpcUtil; import java.net.URI; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.Map; import org.junit.Test; import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; /** Unit tests for {@link XdsNameResolverProvider}. */ -@RunWith(JUnit4.class) +@RunWith(Parameterized.class) public class XdsNameResolverProviderTest { private final SynchronizationContext syncContext = new SynchronizationContext( new Thread.UncaughtExceptionHandler() { @@ -57,10 +62,18 @@ public void uncaughtException(Thread t, Throwable e) { .setServiceConfigParser(mock(ServiceConfigParser.class)) .setScheduledExecutorService(fakeClock.getScheduledExecutorService()) .setChannelLogger(mock(ChannelLogger.class)) + .setMetricRecorder(mock(MetricRecorder.class)) .build(); private XdsNameResolverProvider provider = new XdsNameResolverProvider(); + @Parameters(name = "enableRfc3986UrisParam={0}") + public static Iterable data() { + return Arrays.asList(new Object[][] {{true}, {false}}); + } + + @Parameter public boolean enableRfc3986UrisParam; + @Test public void provided() { for (NameResolverProvider current @@ -79,48 +92,46 @@ public void isAvailable() { } @Test - public void newNameResolver() { - assertThat( - provider.newNameResolver(URI.create("xds://1.1.1.1/foo.googleapis.com"), args)) + public void newNameResolver_returnsExpectedType() { + assertThat(newNameResolver(provider, "xds://1.1.1.1/foo.googleapis.com", args)) .isInstanceOf(XdsNameResolver.class); - assertThat( - provider.newNameResolver(URI.create("xds:///foo.googleapis.com"), args)) + assertThat(newNameResolver(provider, "xds:///foo.googleapis.com", args)) .isInstanceOf(XdsNameResolver.class); - assertThat( - provider.newNameResolver(URI.create("notxds://1.1.1.1/foo.googleapis.com"), - args)) - .isNull(); + } + + @Test + public void newNameResolver_matchesExpectedScheme() { + assertThat(newNameResolver(provider, "notxds://1.1.1.1/foo.googleapis.com", args)).isNull(); } @Test public void validName_withAuthority() { - XdsNameResolver resolver = - provider.newNameResolver( - URI.create("xds://trafficdirector.google.com/foo.googleapis.com"), args); + NameResolver resolver = + newNameResolver(provider, "xds://trafficdirector.google.com/foo.googleapis.com", args); assertThat(resolver).isNotNull(); assertThat(resolver.getServiceAuthority()).isEqualTo("foo.googleapis.com"); } @Test public void validName_noAuthority() { - XdsNameResolver resolver = - provider.newNameResolver(URI.create("xds:///foo.googleapis.com"), args); + NameResolver resolver = newNameResolver(provider, "xds:///foo.googleapis.com", args); assertThat(resolver).isNotNull(); assertThat(resolver.getServiceAuthority()).isEqualTo("foo.googleapis.com"); } @Test public void validName_urlExtractedAuthorityInvalidWithoutEncoding() { - XdsNameResolver resolver = - provider.newNameResolver(URI.create("xds:///1234/path/foo.googleapis.com:8080"), args); + NameResolver resolver = + newNameResolver(provider, "xds:///1234/path/foo.googleapis.com:8080", args); assertThat(resolver).isNotNull(); assertThat(resolver.getServiceAuthority()).isEqualTo("1234%2Fpath%2Ffoo.googleapis.com:8080"); } @Test public void validName_urlwithTargetAuthorityAndExtractedAuthorityInvalidWithoutEncoding() { - XdsNameResolver resolver = provider.newNameResolver(URI.create( - "xds://trafficdirector.google.com/1234/path/foo.googleapis.com:8080"), args); + NameResolver resolver = + newNameResolver( + provider, "xds://trafficdirector.google.com/1234/path/foo.googleapis.com:8080", args); assertThat(resolver).isNotNull(); assertThat(resolver.getServiceAuthority()).isEqualTo("1234%2Fpath%2Ffoo.googleapis.com:8080"); } @@ -133,18 +144,14 @@ public void newProvider_multipleScheme() { XdsNameResolverProvider provider1 = XdsNameResolverProvider.createForTest("new-xds-scheme", new HashMap()); registry.register(provider1); - assertThat(registry.asFactory() - .newNameResolver(URI.create("xds:///localhost"), args)).isNotNull(); - assertThat(registry.asFactory() - .newNameResolver(URI.create("new-xds-scheme:///localhost"), args)).isNotNull(); - assertThat(registry.asFactory() - .newNameResolver(URI.create("no-scheme:///localhost"), args)).isNotNull(); + assertThat(newNameResolver(registry.asFactory(), "xds:///localhost", args)).isNotNull(); + assertThat(newNameResolver(registry.asFactory(), "new-xds-scheme:///localhost", args)) + .isNotNull(); + assertThat(newNameResolver(registry.asFactory(), "no-scheme:///localhost", args)).isNotNull(); registry.deregister(provider1); - assertThat(registry.asFactory() - .newNameResolver(URI.create("new-xds-scheme:///localhost"), args)).isNull(); + assertThat(newNameResolver(registry.asFactory(), "new-xds-scheme:///localhost", args)).isNull(); registry.deregister(provider0); - assertThat(registry.asFactory() - .newNameResolver(URI.create("xds:///localhost"), args)).isNotNull(); + assertThat(newNameResolver(registry.asFactory(), "xds:///localhost", args)).isNotNull(); } @Test @@ -174,4 +181,11 @@ public void newProvider_overrideBootstrap() { resolver.shutdown(); registry.deregister(provider); } + + private NameResolver newNameResolver( + NameResolver.Factory factory, String uriString, NameResolver.Args args) { + return enableRfc3986UrisParam + ? factory.newNameResolver(Uri.create(uriString), args) + : factory.newNameResolver(URI.create(uriString), args); + } } diff --git a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java index cca2d84373c..e78f97635ed 100644 --- a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java @@ -17,15 +17,19 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.Truth.assertWithMessage; import static io.grpc.xds.FaultFilter.HEADER_ABORT_GRPC_STATUS_KEY; import static io.grpc.xds.FaultFilter.HEADER_ABORT_HTTP_STATUS_KEY; import static io.grpc.xds.FaultFilter.HEADER_ABORT_PERCENTAGE_KEY; import static io.grpc.xds.FaultFilter.HEADER_DELAY_KEY; import static io.grpc.xds.FaultFilter.HEADER_DELAY_PERCENTAGE_KEY; +import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; @@ -43,6 +47,7 @@ import com.google.re2j.Pattern; import io.grpc.CallOptions; import io.grpc.Channel; +import io.grpc.ChannelLogger; import io.grpc.ClientCall; import io.grpc.ClientInterceptor; import io.grpc.ClientInterceptors; @@ -50,9 +55,12 @@ import io.grpc.InsecureChannelCredentials; import io.grpc.InternalConfigSelector; import io.grpc.InternalConfigSelector.Result; +import io.grpc.LoadBalancer.PickDetailsConsumer; +import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor.MethodType; +import io.grpc.MetricRecorder; import io.grpc.NameResolver; import io.grpc.NameResolver.ConfigOrError; import io.grpc.NameResolver.ResolutionResult; @@ -61,9 +69,11 @@ import io.grpc.NoopClientCall.NoopClientCallListener; import io.grpc.Status; import io.grpc.Status.Code; +import io.grpc.StatusOr; import io.grpc.SynchronizationContext; import io.grpc.internal.AutoConfiguredLoadBalancerFactory; import io.grpc.internal.FakeClock; +import io.grpc.internal.GrpcUtil; import io.grpc.internal.JsonParser; import io.grpc.internal.JsonUtil; import io.grpc.internal.ObjectPool; @@ -83,6 +93,8 @@ import io.grpc.xds.VirtualHost.Route.RouteAction.RetryPolicy; import io.grpc.xds.VirtualHost.Route.RouteMatch; import io.grpc.xds.VirtualHost.Route.RouteMatch.PathMatcher; +import io.grpc.xds.XdsClusterResource.CdsUpdate; +import io.grpc.xds.XdsEndpointResource.EdsUpdate; import io.grpc.xds.XdsListenerResource.LdsUpdate; import io.grpc.xds.XdsRouteConfigureResource.RdsUpdate; import io.grpc.xds.client.Bootstrapper.AuthorityInfo; @@ -90,14 +102,16 @@ import io.grpc.xds.client.Bootstrapper.ServerInfo; import io.grpc.xds.client.EnvoyProtoData.Node; import io.grpc.xds.client.XdsClient; -import io.grpc.xds.client.XdsInitializationException; import io.grpc.xds.client.XdsResourceType; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; @@ -123,6 +137,17 @@ public class XdsNameResolverTest { private static final String RDS_RESOURCE_NAME = "route-configuration.googleapis.com"; private static final String FAULT_FILTER_INSTANCE_NAME = "envoy.fault"; private static final String ROUTER_FILTER_INSTANCE_NAME = "envoy.router"; + private static final FaultFilter.Provider FAULT_FILTER_PROVIDER = new FaultFilter.Provider(); + private static final RouterFilter.Provider ROUTER_FILTER_PROVIDER = new RouterFilter.Provider(); + + // Readability: makes it simpler to distinguish resource parameters. + private static final ImmutableMap NO_FILTER_OVERRIDES = ImmutableMap.of(); + private static final ImmutableList NO_HASH_POLICIES = ImmutableList.of(); + + // Stateful instance filter names. + private static final String STATEFUL_1 = "test.stateful.filter.1"; + private static final String STATEFUL_2 = "test.stateful.filter.2"; + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); private final SynchronizationContext syncContext = new SynchronizationContext( @@ -146,6 +171,16 @@ public ConfigOrError parseServiceConfig(Map rawServiceConfig) { private final CallInfo call1 = new CallInfo("HelloService", "hi"); private final CallInfo call2 = new CallInfo("GreetService", "bye"); private final TestChannel channel = new TestChannel(); + private final MetricRecorder metricRecorder = new MetricRecorder() {}; + private final Map rawBootstrap = ImmutableMap.of( + "xds_servers", ImmutableList.of( + ImmutableMap.of( + "server_uri", "td.googleapis.com", + "channel_creds", ImmutableList.of( + ImmutableMap.of( + "type", "insecure"))) + )); + private BootstrapInfo bootstrapInfo = BootstrapInfo.builder() .servers(ImmutableList.of(ServerInfo.create( "td.googleapis.com", InsecureChannelCredentials.create()))) @@ -164,64 +199,77 @@ public ConfigOrError parseServiceConfig(Map rawServiceConfig) { private XdsNameResolver resolver; private TestCall testCall; private boolean originalEnableTimeout; + private String targetUri = AUTHORITY; + private final NameResolver.Args nameResolverArgs = NameResolver.Args.newBuilder() + .setDefaultPort(8080) + .setProxyDetector(GrpcUtil.DEFAULT_PROXY_DETECTOR) + .setSynchronizationContext(syncContext) + .setServiceConfigParser(mock(NameResolver.ServiceConfigParser.class)) + .setChannelLogger(mock(ChannelLogger.class)) + .setScheduledExecutorService(fakeClock.getScheduledExecutorService()) + .build(); + @Before public void setUp() { + lenient().doReturn(Status.OK).when(mockListener).onResult2(any()); + originalEnableTimeout = XdsNameResolver.enableTimeout; XdsNameResolver.enableTimeout = true; + + // Replace FaultFilter.Provider with the one returning FaultFilter injected with mockRandom. + Filter.Provider faultFilterProvider = + mock(Filter.Provider.class, delegatesTo(FAULT_FILTER_PROVIDER)); + // Lenient: suppress [MockitoHint] Unused warning, only used in resolved_fault* tests. + lenient() + .doReturn(new FaultFilter(mockRandom, new AtomicLong())) + .when(faultFilterProvider).newInstance(any(String.class)); + FilterRegistry filterRegistry = FilterRegistry.newRegistry().register( - new FaultFilter(mockRandom, new AtomicLong()), - RouterFilter.INSTANCE); - resolver = new XdsNameResolver(null, AUTHORITY, null, + ROUTER_FILTER_PROVIDER, + faultFilterProvider); + + resolver = new XdsNameResolver(targetUri, null, AUTHORITY, null, serviceConfigParser, syncContext, scheduler, - xdsClientPoolFactory, mockRandom, filterRegistry, null); + xdsClientPoolFactory, mockRandom, filterRegistry, rawBootstrap, metricRecorder, + nameResolverArgs); } @After public void tearDown() { XdsNameResolver.enableTimeout = originalEnableTimeout; + if (resolver == null) { + // Allow tests to test shutdown. + return; + } FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); resolver.shutdown(); if (xdsClient != null) { assertThat(xdsClient.ldsWatcher).isNull(); - assertThat(xdsClient.rdsWatcher).isNull(); + assertThat(xdsClient.rdsWatchers).isEmpty(); } } @Test public void resolving_failToCreateXdsClientPool() { - XdsClientPoolFactory xdsClientPoolFactory = new XdsClientPoolFactory() { - @Override - public void setBootstrapOverride(Map bootstrap) { - } - - @Override - @Nullable - public ObjectPool get() { - throw new UnsupportedOperationException("Should not be called"); - } - - @Override - public ObjectPool getOrCreate() throws XdsInitializationException { - throw new XdsInitializationException("Fail to read bootstrap file"); - } - }; - resolver = new XdsNameResolver(null, AUTHORITY, null, + resolver = new XdsNameResolver(targetUri, null, AUTHORITY, null, serviceConfigParser, syncContext, scheduler, - xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); + xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), + Collections.emptyMap(), metricRecorder, nameResolverArgs); resolver.start(mockListener); verify(mockListener).onError(errorCaptor.capture()); Status error = errorCaptor.getValue(); assertThat(error.getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(error.getDescription()).isEqualTo("Failed to initialize xDS"); - assertThat(error.getCause()).hasMessageThat().isEqualTo("Fail to read bootstrap file"); + assertThat(error.getCause()).hasMessageThat().contains("Invalid bootstrap"); } @Test public void resolving_withTargetAuthorityNotFound() { - resolver = new XdsNameResolver( + resolver = new XdsNameResolver(targetUri, "notfound.google.com", AUTHORITY, null, serviceConfigParser, syncContext, scheduler, - xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); + xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), rawBootstrap, + metricRecorder, nameResolverArgs); resolver.start(mockListener); verify(mockListener).onError(errorCaptor.capture()); Status error = errorCaptor.getValue(); @@ -241,9 +289,44 @@ public void resolving_noTargetAuthority_templateWithoutXdstp() { String serviceAuthority = "[::FFFF:129.144.52.38]:80"; expectedLdsResourceName = "[::FFFF:129.144.52.38]:80/id=1"; resolver = new XdsNameResolver( - null, serviceAuthority, null, serviceConfigParser, syncContext, + targetUri, null, serviceAuthority, null, serviceConfigParser, syncContext, scheduler, xdsClientPoolFactory, - mockRandom, FilterRegistry.getDefaultRegistry(), null); + mockRandom, FilterRegistry.getDefaultRegistry(), rawBootstrap, metricRecorder, + nameResolverArgs); + resolver.start(mockListener); + verify(mockListener, never()).onError(any(Status.class)); + } + + @Test + public void resolving_emptyTargetAuthority_templateWithXdstp() { + bootstrapInfo = + BootstrapInfo.builder() + .servers( + ImmutableList.of( + ServerInfo.create("td.googleapis.com", InsecureChannelCredentials.create()))) + .node(Node.newBuilder().build()) + .clientDefaultListenerResourceNameTemplate( + "xdstp://xds.authority.com/envoy.config.listener.v3.Listener/%s?id=1") + .build(); + String serviceAuthority = "[::FFFF:129.144.52.38]:80"; + expectedLdsResourceName = + "xdstp://xds.authority.com/envoy.config.listener.v3.Listener/" + + "%5B::FFFF:129.144.52.38%5D:80?id=1"; + resolver = + new XdsNameResolver( + "xds:///foo.googleapis.com", + "", + serviceAuthority, + null, + serviceConfigParser, + syncContext, + scheduler, + xdsClientPoolFactory, + mockRandom, + FilterRegistry.getDefaultRegistry(), + rawBootstrap, + metricRecorder, + nameResolverArgs); resolver.start(mockListener); verify(mockListener, never()).onError(any(Status.class)); } @@ -262,8 +345,9 @@ public void resolving_noTargetAuthority_templateWithXdstp() { "xdstp://xds.authority.com/envoy.config.listener.v3.Listener/" + "%5B::FFFF:129.144.52.38%5D:80?id=1"; resolver = new XdsNameResolver( - null, serviceAuthority, null, serviceConfigParser, syncContext, scheduler, - xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); + targetUri, null, serviceAuthority, null, serviceConfigParser, syncContext, scheduler, + xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), rawBootstrap, + metricRecorder, nameResolverArgs); resolver.start(mockListener); verify(mockListener, never()).onError(any(Status.class)); } @@ -282,8 +366,9 @@ public void resolving_noTargetAuthority_xdstpWithMultipleSlashes() { "xdstp://xds.authority.com/envoy.config.listener.v3.Listener/" + "path/to/service?id=1"; resolver = new XdsNameResolver( - null, serviceAuthority, null, serviceConfigParser, syncContext, scheduler, - xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); + targetUri, null, serviceAuthority, null, serviceConfigParser, syncContext, scheduler, + xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), rawBootstrap, + metricRecorder, nameResolverArgs); // The Service Authority must be URL encoded, but unlike the LDS resource name. @@ -299,25 +384,27 @@ public void resolving_targetAuthorityInAuthoritiesMap() { String serviceAuthority = "[::FFFF:129.144.52.38]:80"; bootstrapInfo = BootstrapInfo.builder() .servers(ImmutableList.of(ServerInfo.create( - "td.googleapis.com", InsecureChannelCredentials.create(), true))) + "td.googleapis.com", InsecureChannelCredentials.create(), true, true, false, false))) .node(Node.newBuilder().build()) .authorities( ImmutableMap.of(targetAuthority, AuthorityInfo.create( "xdstp://" + targetAuthority + "/envoy.config.listener.v3.Listener/%s?foo=1&bar=2", ImmutableList.of(ServerInfo.create( - "td.googleapis.com", InsecureChannelCredentials.create(), true))))) + "td.googleapis.com", InsecureChannelCredentials.create(), + true, true, false, false))))) .build(); expectedLdsResourceName = "xdstp://xds.authority.com/envoy.config.listener.v3.Listener/" + "%5B::FFFF:129.144.52.38%5D:80?bar=2&foo=1"; // query param canonified - resolver = new XdsNameResolver( + resolver = new XdsNameResolver(targetUri, "xds.authority.com", serviceAuthority, null, serviceConfigParser, syncContext, scheduler, - xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); + xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), rawBootstrap, + metricRecorder, nameResolverArgs); resolver.start(mockListener); verify(mockListener, never()).onError(any(Status.class)); } @Test - public void resolving_ldsResourceNotFound() { + public void resolving_ldsResourceNotFound() { // hi resolver.start(mockListener); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverLdsResourceNotFound(); @@ -329,11 +416,11 @@ public void resolving_ldsResourceNotFound() { public void resolving_ldsResourceUpdateRdsName() { Route route1 = Route.forAction(RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( - cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), + cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null, false), ImmutableMap.of()); Route route2 = Route.forAction(RouteMatch.withPathExactOnly(call2.getFullMethodNameForPath()), RouteAction.forCluster( - cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(20L), null), + cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(20L), null, false), ImmutableMap.of()); bootstrapInfo = BootstrapInfo.builder() .servers(ImmutableList.of(ServerInfo.create( @@ -341,9 +428,10 @@ public void resolving_ldsResourceUpdateRdsName() { .clientDefaultListenerResourceNameTemplate("test-%s") .node(Node.newBuilder().build()) .build(); - resolver = new XdsNameResolver(null, AUTHORITY, null, + resolver = new XdsNameResolver(targetUri, null, AUTHORITY, null, serviceConfigParser, syncContext, scheduler, - xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); + xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), rawBootstrap, + metricRecorder, nameResolverArgs); // use different ldsResourceName and service authority. The virtualhost lookup should use // service authority. expectedLdsResourceName = "test-" + expectedLdsResourceName; @@ -351,32 +439,36 @@ public void resolving_ldsResourceUpdateRdsName() { resolver.start(mockListener); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverLdsUpdateForRdsName(RDS_RESOURCE_NAME); - assertThat(xdsClient.rdsResource).isEqualTo(RDS_RESOURCE_NAME); + assertThat(xdsClient.rdsWatchers.keySet()).containsExactly(RDS_RESOURCE_NAME); VirtualHost virtualHost = VirtualHost.create("virtualhost", Collections.singletonList(AUTHORITY), Collections.singletonList(route1), ImmutableMap.of()); xdsClient.deliverRdsUpdate(RDS_RESOURCE_NAME, Collections.singletonList(virtualHost)); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + createAndDeliverClusterUpdates(xdsClient, cluster1); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); assertServiceConfigForLoadBalancingConfig( Collections.singletonList(cluster1), (Map) resolutionResultCaptor.getValue().getServiceConfig().getConfig()); reset(mockListener); + when(mockListener.onResult2(any())).thenReturn(Status.OK); ArgumentCaptor resultCaptor = ArgumentCaptor.forClass(ResolutionResult.class); String alternativeRdsResource = "route-configuration-alter.googleapis.com"; xdsClient.deliverLdsUpdateForRdsName(alternativeRdsResource); - assertThat(xdsClient.rdsResource).isEqualTo(alternativeRdsResource); + assertThat(xdsClient.rdsWatchers.keySet()).contains(alternativeRdsResource); virtualHost = VirtualHost.create("virtualhost-alter", Collections.singletonList(AUTHORITY), Collections.singletonList(route2), ImmutableMap.of()); xdsClient.deliverRdsUpdate(alternativeRdsResource, Collections.singletonList(virtualHost)); + createAndDeliverClusterUpdates(xdsClient, cluster2); + assertThat(xdsClient.rdsWatchers.keySet()).containsExactly(alternativeRdsResource); // Two new service config updates triggered: // - with load balancing config being able to select cluster1 and cluster2 // - with load balancing config being able to select cluster2 only - verify(mockListener, times(2)).onResult(resultCaptor.capture()); + verify(mockListener, times(3)).onResult2(resultCaptor.capture()); assertServiceConfigForLoadBalancingConfig( Arrays.asList(cluster1, cluster2), (Map) resultCaptor.getAllValues().get(0).getServiceConfig().getConfig()); @@ -399,35 +491,39 @@ public void resolving_rdsResourceNotFound() { public void resolving_ldsResourceRevokedAndAddedBack() { Route route = Route.forAction(RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( - cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), + cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null, false), ImmutableMap.of()); resolver.start(mockListener); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverLdsUpdateForRdsName(RDS_RESOURCE_NAME); - assertThat(xdsClient.rdsResource).isEqualTo(RDS_RESOURCE_NAME); + assertThat(xdsClient.rdsWatchers.keySet()).containsExactly(RDS_RESOURCE_NAME); VirtualHost virtualHost = VirtualHost.create("virtualhost", Collections.singletonList(AUTHORITY), Collections.singletonList(route), ImmutableMap.of()); xdsClient.deliverRdsUpdate(RDS_RESOURCE_NAME, Collections.singletonList(virtualHost)); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + createAndDeliverClusterUpdates(xdsClient, cluster1); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); assertServiceConfigForLoadBalancingConfig( Collections.singletonList(cluster1), (Map) resolutionResultCaptor.getValue().getServiceConfig().getConfig()); reset(mockListener); + when(mockListener.onResult2(any())).thenReturn(Status.OK); xdsClient.deliverLdsResourceNotFound(); // revoke LDS resource - assertThat(xdsClient.rdsResource).isNull(); // stop subscribing to stale RDS resource + assertThat(xdsClient.rdsWatchers.keySet()).isEmpty(); // stop subscribing to stale RDS resource assertEmptyResolutionResult(expectedLdsResourceName); reset(mockListener); + when(mockListener.onResult2(any())).thenReturn(Status.OK); xdsClient.deliverLdsUpdateForRdsName(RDS_RESOURCE_NAME); // No name resolution result until new RDS resource update is received. Do not use stale config verifyNoInteractions(mockListener); - assertThat(xdsClient.rdsResource).isEqualTo(RDS_RESOURCE_NAME); + assertThat(xdsClient.rdsWatchers.keySet()).containsExactly(RDS_RESOURCE_NAME); xdsClient.deliverRdsUpdate(RDS_RESOURCE_NAME, Collections.singletonList(virtualHost)); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + createAndDeliverClusterUpdates(xdsClient, cluster1); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); assertServiceConfigForLoadBalancingConfig( Collections.singletonList(cluster1), (Map) resolutionResultCaptor.getValue().getServiceConfig().getConfig()); @@ -438,31 +534,35 @@ public void resolving_ldsResourceRevokedAndAddedBack() { public void resolving_rdsResourceRevokedAndAddedBack() { Route route = Route.forAction(RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( - cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), + cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null, false), ImmutableMap.of()); resolver.start(mockListener); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverLdsUpdateForRdsName(RDS_RESOURCE_NAME); - assertThat(xdsClient.rdsResource).isEqualTo(RDS_RESOURCE_NAME); + assertThat(xdsClient.rdsWatchers.keySet()).containsExactly(RDS_RESOURCE_NAME); VirtualHost virtualHost = VirtualHost.create("virtualhost", Collections.singletonList(AUTHORITY), Collections.singletonList(route), ImmutableMap.of()); xdsClient.deliverRdsUpdate(RDS_RESOURCE_NAME, Collections.singletonList(virtualHost)); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + createAndDeliverClusterUpdates(xdsClient, cluster1); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); assertServiceConfigForLoadBalancingConfig( Collections.singletonList(cluster1), (Map) resolutionResultCaptor.getValue().getServiceConfig().getConfig()); reset(mockListener); + when(mockListener.onResult2(any())).thenReturn(Status.OK); xdsClient.deliverRdsResourceNotFound(RDS_RESOURCE_NAME); // revoke RDS resource assertEmptyResolutionResult(RDS_RESOURCE_NAME); // Simulate management server adds back the previously used RDS resource. reset(mockListener); + when(mockListener.onResult2(any())).thenReturn(Status.OK); xdsClient.deliverRdsUpdate(RDS_RESOURCE_NAME, Collections.singletonList(virtualHost)); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + createAndDeliverClusterUpdates(xdsClient, cluster1); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); assertServiceConfigForLoadBalancingConfig( Collections.singletonList(cluster1), (Map) resolutionResultCaptor.getValue().getServiceConfig().getConfig()); @@ -473,11 +573,15 @@ public void resolving_encounterErrorLdsWatcherOnly() { resolver.start(mockListener); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverError(Status.UNAVAILABLE.withDescription("server unreachable")); - verify(mockListener).onError(errorCaptor.capture()); - Status error = errorCaptor.getValue(); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); + InternalConfigSelector configSelector = resolutionResultCaptor.getValue() + .getAttributes().get(InternalConfigSelector.KEY); + Result selectResult = configSelector.selectConfig( + newPickSubchannelArgs(call1.methodDescriptor, new Metadata(), CallOptions.DEFAULT)); + Status error = selectResult.getStatus(); assertThat(error.getCode()).isEqualTo(Code.UNAVAILABLE); - assertThat(error.getDescription()).isEqualTo("Unable to load LDS " + AUTHORITY - + ". xDS server returned: UNAVAILABLE: server unreachable"); + assertThat(error.getDescription()).contains(AUTHORITY); + assertThat(error.getDescription()).contains("server unreachable"); } @Test @@ -485,11 +589,15 @@ public void resolving_translateErrorLds() { resolver.start(mockListener); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverError(Status.NOT_FOUND.withDescription("server unreachable")); - verify(mockListener).onError(errorCaptor.capture()); - Status error = errorCaptor.getValue(); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); + InternalConfigSelector configSelector = resolutionResultCaptor.getValue() + .getAttributes().get(InternalConfigSelector.KEY); + Result selectResult = configSelector.selectConfig( + newPickSubchannelArgs(call1.methodDescriptor, new Metadata(), CallOptions.DEFAULT)); + Status error = selectResult.getStatus(); assertThat(error.getCode()).isEqualTo(Code.UNAVAILABLE); - assertThat(error.getDescription()).isEqualTo("Unable to load LDS " + AUTHORITY - + ". xDS server returned: NOT_FOUND: server unreachable"); + assertThat(error.getDescription()).contains(AUTHORITY); + assertThat(error.getDescription()).contains("server unreachable"); assertThat(error.getCause()).isNull(); } @@ -499,15 +607,17 @@ public void resolving_encounterErrorLdsAndRdsWatchers() { FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverLdsUpdateForRdsName(RDS_RESOURCE_NAME); xdsClient.deliverError(Status.UNAVAILABLE.withDescription("server unreachable")); - verify(mockListener, times(2)).onError(errorCaptor.capture()); - Status error = errorCaptor.getAllValues().get(0); - assertThat(error.getCode()).isEqualTo(Code.UNAVAILABLE); - assertThat(error.getDescription()).isEqualTo("Unable to load LDS " + AUTHORITY - + ". xDS server returned: UNAVAILABLE: server unreachable"); - error = errorCaptor.getAllValues().get(1); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); + InternalConfigSelector configSelector = resolutionResultCaptor.getValue() + .getAttributes().get(InternalConfigSelector.KEY); + Result selectResult = configSelector.selectConfig( + newPickSubchannelArgs(call1.methodDescriptor, new Metadata(), CallOptions.DEFAULT)); + Status error = selectResult.getStatus(); assertThat(error.getCode()).isEqualTo(Code.UNAVAILABLE); - assertThat(error.getDescription()).isEqualTo("Unable to load RDS " + RDS_RESOURCE_NAME - + ". xDS server returned: UNAVAILABLE: server unreachable"); + // XdsDepManager.buildUpdate doesn't allow this + // assertThat(error.getDescription()).contains(RDS_RESOURCE_NAME); + assertThat(error.getDescription()).contains(expectedLdsResourceName); + assertThat(error.getDescription()).contains("server unreachable"); } @SuppressWarnings("unchecked") @@ -515,20 +625,22 @@ public void resolving_encounterErrorLdsAndRdsWatchers() { public void resolving_matchingVirtualHostNotFound_matchingOverrideAuthority() { Route route = Route.forAction(RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( - cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), + cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null, false), ImmutableMap.of()); VirtualHost virtualHost = VirtualHost.create("virtualhost", Collections.singletonList("random"), Collections.singletonList(route), ImmutableMap.of()); - resolver = new XdsNameResolver(null, AUTHORITY, "random", + resolver = new XdsNameResolver(targetUri, null, AUTHORITY, "random", serviceConfigParser, syncContext, scheduler, - xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); + xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), rawBootstrap, + metricRecorder, nameResolverArgs); resolver.start(mockListener); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverLdsUpdate(0L, Arrays.asList(virtualHost)); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + createAndDeliverClusterUpdates(xdsClient, cluster1); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); assertServiceConfigForLoadBalancingConfig( Collections.singletonList(cluster1), (Map) resolutionResultCaptor.getValue().getServiceConfig().getConfig()); @@ -538,27 +650,30 @@ public void resolving_matchingVirtualHostNotFound_matchingOverrideAuthority() { public void resolving_matchingVirtualHostNotFound_notMatchingOverrideAuthority() { Route route = Route.forAction(RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( - cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), + cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null, false), ImmutableMap.of()); VirtualHost virtualHost = VirtualHost.create("virtualhost", Collections.singletonList(AUTHORITY), Collections.singletonList(route), ImmutableMap.of()); - resolver = new XdsNameResolver(null, AUTHORITY, "random", + resolver = new XdsNameResolver(targetUri, null, AUTHORITY, "random", serviceConfigParser, syncContext, scheduler, - xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); + xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), rawBootstrap, + metricRecorder, nameResolverArgs); resolver.start(mockListener); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); - xdsClient.deliverLdsUpdate(0L, Arrays.asList(virtualHost)); + xdsClient.deliverLdsUpdateOnly(0L, Arrays.asList(virtualHost)); + fakeClock.forwardTime(15, TimeUnit.SECONDS); assertEmptyResolutionResult("random"); } @Test public void resolving_matchingVirtualHostNotFoundForOverrideAuthority() { - resolver = new XdsNameResolver(null, AUTHORITY, AUTHORITY, + resolver = new XdsNameResolver(targetUri, null, AUTHORITY, AUTHORITY, serviceConfigParser, syncContext, scheduler, - xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); + xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), rawBootstrap, + metricRecorder, nameResolverArgs); resolver.start(mockListener); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverLdsUpdate(0L, buildUnmatchedVirtualHosts()); @@ -585,11 +700,11 @@ public void resolving_matchingVirtualHostNotFoundInRdsResource() { private List buildUnmatchedVirtualHosts() { Route route1 = Route.forAction(RouteMatch.withPathExactOnly(call2.getFullMethodNameForPath()), RouteAction.forCluster( - cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), + cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null, false), ImmutableMap.of()); Route route2 = Route.forAction(RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( - cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), + cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null, false), ImmutableMap.of()); return Arrays.asList( VirtualHost.create("virtualhost-foo", Collections.singletonList("hello.googleapis.com"), @@ -606,13 +721,13 @@ public void resolved_noTimeout() { FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); Route route = Route.forAction(RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( - cluster1, Collections.emptyList(), null, null), // per-route timeout unset + cluster1, Collections.emptyList(), null, null, false), // per-route timeout unset ImmutableMap.of()); VirtualHost virtualHost = VirtualHost.create("does not matter", Collections.singletonList(AUTHORITY), Collections.singletonList(route), ImmutableMap.of()); xdsClient.deliverLdsUpdate(0L, Collections.singletonList(virtualHost)); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); assertCallSelectClusterResult(call1, configSelector, cluster1, null); @@ -624,14 +739,14 @@ public void resolved_fallbackToHttpMaxStreamDurationAsTimeout() { FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); Route route = Route.forAction(RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( - cluster1, Collections.emptyList(), null, null), // per-route timeout unset + cluster1, Collections.emptyList(), null, null, false), // per-route timeout unset ImmutableMap.of()); VirtualHost virtualHost = VirtualHost.create("does not matter", Collections.singletonList(AUTHORITY), Collections.singletonList(route), ImmutableMap.of()); xdsClient.deliverLdsUpdate(TimeUnit.SECONDS.toNanos(5L), Collections.singletonList(virtualHost)); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); assertCallSelectClusterResult(call1, configSelector, cluster1, 5.0); @@ -641,8 +756,9 @@ public void resolved_fallbackToHttpMaxStreamDurationAsTimeout() { public void retryPolicyInPerMethodConfigGeneratedByResolverIsValid() { ServiceConfigParser realParser = new ScParser( true, 5, 5, new AutoConfiguredLoadBalancerFactory("pick-first")); - resolver = new XdsNameResolver(null, AUTHORITY, null, realParser, syncContext, scheduler, - xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); + resolver = new XdsNameResolver(targetUri, null, AUTHORITY, null, realParser, syncContext, + scheduler, xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), + rawBootstrap, metricRecorder, nameResolverArgs); resolver.start(mockListener); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); RetryPolicy retryPolicy = RetryPolicy.create( @@ -656,13 +772,14 @@ public void retryPolicyInPerMethodConfigGeneratedByResolverIsValid() { cluster1, Collections.emptyList(), null, - retryPolicy), + retryPolicy, + false), ImmutableMap.of()))); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); Result selectResult = configSelector.selectConfig( - new PickSubchannelArgsImpl(call1.methodDescriptor, new Metadata(), CallOptions.DEFAULT)); + newPickSubchannelArgs(call1.methodDescriptor, new Metadata(), CallOptions.DEFAULT)); Object config = selectResult.getConfig(); // Purely validating the data (io.grpc.internal.RetryPolicy). @@ -693,7 +810,7 @@ public void resolved_simpleCallFailedToRoute_noMatchingRoute() { InternalConfigSelector configSelector = resolveToClusters(); CallInfo call = new CallInfo("FooService", "barMethod"); Result selectResult = configSelector.selectConfig( - new PickSubchannelArgsImpl(call.methodDescriptor, new Metadata(), CallOptions.DEFAULT)); + newPickSubchannelArgs(call.methodDescriptor, new Metadata(), CallOptions.DEFAULT)); Status status = selectResult.getStatus(); assertThat(status.isOk()).isFalse(); assertThat(status.getCode()).isEqualTo(Code.UNAVAILABLE); @@ -714,20 +831,20 @@ public void resolved_simpleCallFailedToRoute_routeWithNonForwardingAction() { Route.forAction( RouteMatch.withPathExactOnly(call2.getFullMethodNameForPath()), RouteAction.forCluster(cluster2, Collections.emptyList(), - TimeUnit.SECONDS.toNanos(15L), null), + TimeUnit.SECONDS.toNanos(15L), null, false), ImmutableMap.of()))); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); - assertThat(result.getAddresses()).isEmpty(); + assertThat(result.getAddressesOrError().getValue()).isEmpty(); assertServiceConfigForLoadBalancingConfig( Collections.singletonList(cluster2), (Map) result.getServiceConfig().getConfig()); - assertThat(result.getAttributes().get(InternalXdsAttributes.XDS_CLIENT_POOL)).isNotNull(); - assertThat(result.getAttributes().get(InternalXdsAttributes.CALL_COUNTER_PROVIDER)).isNotNull(); + assertThat(result.getAttributes().get(XdsAttributes.XDS_CLIENT)).isNotNull(); + assertThat(result.getAttributes().get(XdsAttributes.CALL_COUNTER_PROVIDER)).isNotNull(); InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); // Simulates making a call1 RPC. Result selectResult = configSelector.selectConfig( - new PickSubchannelArgsImpl(call1.methodDescriptor, new Metadata(), CallOptions.DEFAULT)); + newPickSubchannelArgs(call1.methodDescriptor, new Metadata(), CallOptions.DEFAULT)); Status status = selectResult.getStatus(); assertThat(status.isOk()).isFalse(); assertThat(status.getCode()).isEqualTo(Code.UNAVAILABLE); @@ -750,9 +867,10 @@ public void resolved_rpcHashingByHeader_withoutSubstitution() { Collections.singletonList( HashPolicy.forHeader(false, "custom-key", null, null)), null, - null), + null, + false), ImmutableMap.of()))); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); InternalConfigSelector configSelector = resolutionResultCaptor.getValue().getAttributes().get(InternalConfigSelector.KEY); @@ -782,11 +900,13 @@ public void resolved_rpcHashingByHeader_withSubstitution() { RouteAction.forCluster( cluster1, Collections.singletonList( - HashPolicy.forHeader(false, "custom-key", Pattern.compile("value"), "val")), + HashPolicy.forHeader(false, "custom-key", Pattern.compile("value"), + "val")), + null, null, - null), + false), ImmutableMap.of()))); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); InternalConfigSelector configSelector = resolutionResultCaptor.getValue().getAttributes().get(InternalConfigSelector.KEY); @@ -823,9 +943,10 @@ public void resolved_rpcHashingByChannelId() { cluster1, Collections.singletonList(HashPolicy.forChannelId(false)), null, - null), + null, + false), ImmutableMap.of()))); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); InternalConfigSelector configSelector = resolutionResultCaptor.getValue().getAttributes().get(InternalConfigSelector.KEY); @@ -844,10 +965,12 @@ public void resolved_rpcHashingByChannelId() { // A different resolver/Channel. resolver.shutdown(); reset(mockListener); + when(mockListener.onResult2(any())).thenReturn(Status.OK); when(mockRandom.nextLong()).thenReturn(123L); - resolver = new XdsNameResolver(null, AUTHORITY, null, serviceConfigParser, + resolver = new XdsNameResolver(targetUri, null, AUTHORITY, null, serviceConfigParser, syncContext, scheduler, - xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); + xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), rawBootstrap, + metricRecorder, nameResolverArgs); resolver.start(mockListener); xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverLdsUpdate( @@ -859,9 +982,10 @@ public void resolved_rpcHashingByChannelId() { cluster1, Collections.singletonList(HashPolicy.forChannelId(false)), null, - null), + null, + false), ImmutableMap.of()))); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); configSelector = resolutionResultCaptor.getValue().getAttributes().get( InternalConfigSelector.KEY); @@ -875,6 +999,68 @@ public void resolved_rpcHashingByChannelId() { assertThat(hash3).isNotEqualTo(hash1); } + @Test + public void resolved_routeActionHasAutoHostRewrite_emitsCallOptionForTheSame() { + resolver = new XdsNameResolver(targetUri, null, AUTHORITY, null, serviceConfigParser, + syncContext, scheduler, xdsClientPoolFactory, mockRandom, + FilterRegistry.getDefaultRegistry(), rawBootstrap, metricRecorder, nameResolverArgs); + resolver.start(mockListener); + FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); + xdsClient.deliverLdsUpdate( + Collections.singletonList( + Route.forAction( + RouteMatch.withPathExactOnly( + "/" + TestMethodDescriptors.voidMethod().getFullMethodName()), + RouteAction.forCluster( + cluster1, + Collections.singletonList( + HashPolicy.forHeader(false, "custom-key", null, null)), + null, + null, + true), + ImmutableMap.of()))); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); + InternalConfigSelector configSelector = + resolutionResultCaptor.getValue().getAttributes().get(InternalConfigSelector.KEY); + + // First call, with header "custom-key": "custom-value". + startNewCall(TestMethodDescriptors.voidMethod(), configSelector, + ImmutableMap.of("custom-key", "custom-value"), CallOptions.DEFAULT); + + assertThat(testCall.callOptions.getOption(XdsNameResolver.AUTO_HOST_REWRITE_KEY)).isTrue(); + } + + @Test + public void resolved_routeActionNoAutoHostRewrite_doesntEmitCallOptionForTheSame() { + resolver = new XdsNameResolver(targetUri, null, AUTHORITY, null, serviceConfigParser, + syncContext, scheduler, xdsClientPoolFactory, mockRandom, + FilterRegistry.getDefaultRegistry(), rawBootstrap, metricRecorder, nameResolverArgs); + resolver.start(mockListener); + FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); + xdsClient.deliverLdsUpdate( + Collections.singletonList( + Route.forAction( + RouteMatch.withPathExactOnly( + "/" + TestMethodDescriptors.voidMethod().getFullMethodName()), + RouteAction.forCluster( + cluster1, + Collections.singletonList( + HashPolicy.forHeader(false, "custom-key", null, null)), + null, + null, + false), + ImmutableMap.of()))); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); + InternalConfigSelector configSelector = + resolutionResultCaptor.getValue().getAttributes().get(InternalConfigSelector.KEY); + + // First call, with header "custom-key": "custom-value". + startNewCall(TestMethodDescriptors.voidMethod(), configSelector, + ImmutableMap.of("custom-key", "custom-value"), CallOptions.DEFAULT); + + assertThat(testCall.callOptions.getOption(XdsNameResolver.AUTO_HOST_REWRITE_KEY)).isNull(); + } + @SuppressWarnings("unchecked") @Test public void resolved_resourceUpdateAfterCallStarted() { @@ -883,6 +1069,7 @@ public void resolved_resourceUpdateAfterCallStarted() { TestCall firstCall = testCall; reset(mockListener); + when(mockListener.onResult2(any())).thenReturn(Status.OK); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverLdsUpdate( Arrays.asList( @@ -890,15 +1077,15 @@ public void resolved_resourceUpdateAfterCallStarted() { RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( "another-cluster", Collections.emptyList(), - TimeUnit.SECONDS.toNanos(20L), null), + TimeUnit.SECONDS.toNanos(20L), null, false), ImmutableMap.of()), Route.forAction( RouteMatch.withPathExactOnly(call2.getFullMethodNameForPath()), RouteAction.forCluster( cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), - null), + null, false), ImmutableMap.of()))); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); // Updated service config still contains cluster1 while it is removed resource. New calls no // longer routed to cluster1. @@ -910,7 +1097,9 @@ public void resolved_resourceUpdateAfterCallStarted() { assertCallSelectClusterResult(call1, configSelector, "another-cluster", 20.0); firstCall.deliverErrorStatus(); // completes previous call - verify(mockListener, times(2)).onResult(resolutionResultCaptor.capture()); + // Two updates: one for XdsNameResolver releasing the cluster, and another for + // XdsDependencyManager updating the XdsConfig + verify(mockListener, times(3)).onResult2(resolutionResultCaptor.capture()); result = resolutionResultCaptor.getValue(); assertServiceConfigForLoadBalancingConfig( Arrays.asList(cluster2, "another-cluster"), @@ -923,6 +1112,7 @@ public void resolved_resourceUpdateAfterCallStarted() { public void resolved_resourceUpdatedBeforeCallStarted() { InternalConfigSelector configSelector = resolveToClusters(); reset(mockListener); + when(mockListener.onResult2(any())).thenReturn(Status.OK); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverLdsUpdate( Arrays.asList( @@ -930,17 +1120,17 @@ public void resolved_resourceUpdatedBeforeCallStarted() { RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( "another-cluster", Collections.emptyList(), - TimeUnit.SECONDS.toNanos(20L), null), + TimeUnit.SECONDS.toNanos(20L), null, false), ImmutableMap.of()), Route.forAction( RouteMatch.withPathExactOnly(call2.getFullMethodNameForPath()), RouteAction.forCluster( cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), - null), + null, false), ImmutableMap.of()))); // Two consecutive service config updates: one for removing clcuster1, // one for adding "another=cluster". - verify(mockListener, times(2)).onResult(resolutionResultCaptor.capture()); + verify(mockListener, times(3)).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); assertServiceConfigForLoadBalancingConfig( Arrays.asList(cluster2, "another-cluster"), @@ -959,6 +1149,7 @@ public void resolved_raceBetweenCallAndRepeatedResourceUpdate() { assertCallSelectClusterResult(call1, configSelector, cluster1, 15.0); reset(mockListener); + when(mockListener.onResult2(any())).thenReturn(Status.OK); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverLdsUpdate( Arrays.asList( @@ -966,16 +1157,16 @@ public void resolved_raceBetweenCallAndRepeatedResourceUpdate() { RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( "another-cluster", Collections.emptyList(), - TimeUnit.SECONDS.toNanos(20L), null), + TimeUnit.SECONDS.toNanos(20L), null, false), ImmutableMap.of()), Route.forAction( RouteMatch.withPathExactOnly(call2.getFullMethodNameForPath()), RouteAction.forCluster( cluster2, Collections.emptyList(), - TimeUnit.SECONDS.toNanos(15L), null), + TimeUnit.SECONDS.toNanos(15L), null, false), ImmutableMap.of()))); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); assertServiceConfigForLoadBalancingConfig( Arrays.asList(cluster1, cluster2, "another-cluster"), @@ -987,15 +1178,15 @@ public void resolved_raceBetweenCallAndRepeatedResourceUpdate() { RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( "another-cluster", Collections.emptyList(), - TimeUnit.SECONDS.toNanos(15L), null), + TimeUnit.SECONDS.toNanos(15L), null, false), ImmutableMap.of()), Route.forAction( RouteMatch.withPathExactOnly(call2.getFullMethodNameForPath()), RouteAction.forCluster( cluster2, Collections.emptyList(), - TimeUnit.SECONDS.toNanos(15L), null), + TimeUnit.SECONDS.toNanos(15L), null, false), ImmutableMap.of()))); - verifyNoMoreInteractions(mockListener); // no cluster added/deleted + verify(mockListener, times(2)).onResult2(resolutionResultCaptor.capture()); assertCallSelectClusterResult(call1, configSelector, "another-cluster", 15.0); } @@ -1010,7 +1201,7 @@ public void resolved_raceBetweenClusterReleasedAndResourceUpdateAddBackAgain() { RouteMatch.withPathExactOnly(call2.getFullMethodNameForPath()), RouteAction.forCluster( cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), - null), + null, false), ImmutableMap.of()))); xdsClient.deliverLdsUpdate( Arrays.asList( @@ -1018,16 +1209,22 @@ public void resolved_raceBetweenClusterReleasedAndResourceUpdateAddBackAgain() { RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), - null), + null, false), ImmutableMap.of()), Route.forAction( RouteMatch.withPathExactOnly(call2.getFullMethodNameForPath()), RouteAction.forCluster( cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), - null), + null, false), ImmutableMap.of()))); testCall.deliverErrorStatus(); - verifyNoMoreInteractions(mockListener); + verify(mockListener, times(3)).onResult2(resolutionResultCaptor.capture()); + assertServiceConfigForLoadBalancingConfig( + Arrays.asList(cluster1, cluster2), resolutionResultCaptor.getAllValues().get(1)); + assertServiceConfigForLoadBalancingConfig( + Arrays.asList(cluster1, cluster2), resolutionResultCaptor.getAllValues().get(2)); + assertServiceConfigForLoadBalancingConfig( + Arrays.asList(cluster1, cluster2), resolutionResultCaptor.getAllValues().get(3)); } @SuppressWarnings("unchecked") @@ -1048,19 +1245,33 @@ public void resolved_simpleCallSucceeds_routeToWeightedCluster() { cluster2, 80, ImmutableMap.of())), Collections.emptyList(), TimeUnit.SECONDS.toNanos(20L), - null), + null, false), ImmutableMap.of()))); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); - assertThat(result.getAddresses()).isEmpty(); + assertThat(result.getAddressesOrError().getValue()).isEmpty(); assertServiceConfigForLoadBalancingConfig( Arrays.asList(cluster1, cluster2), (Map) result.getServiceConfig().getConfig()); - assertThat(result.getAttributes().get(InternalXdsAttributes.XDS_CLIENT_POOL)).isNotNull(); + assertThat(result.getAttributes().get(XdsAttributes.XDS_CLIENT)).isNotNull(); InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); assertCallSelectClusterResult(call1, configSelector, cluster2, 20.0); assertCallSelectClusterResult(call1, configSelector, cluster1, 20.0); } + /** Creates and delivers both CDS and EDS updates for the given clusters. */ + private static void createAndDeliverClusterUpdates( + FakeXdsClient xdsClient, String... clusterNames) { + for (String clusterName : clusterNames) { + CdsUpdate.Builder forEds = CdsUpdate + .forEds(clusterName, clusterName, null, null, null, null, false, null) + .roundRobinLbPolicy(); + xdsClient.deliverCdsUpdate(clusterName, forEds.build()); + EdsUpdate edsUpdate = new EdsUpdate(clusterName, + XdsTestUtils.createMinimalLbEndpointsMap("127.0.0.3"), Collections.emptyList()); + xdsClient.deliverEdsUpdate(clusterName, edsUpdate); + } + } + @Test public void resolved_simpleCallSucceeds_routeToRls() { when(mockRandom.nextInt(anyInt())).thenReturn(90, 10); @@ -1077,11 +1288,11 @@ public void resolved_simpleCallSucceeds_routeToRls() { ImmutableMap.of("lookupService", "rls-cbt.googleapis.com"))), Collections.emptyList(), TimeUnit.SECONDS.toNanos(20L), - null), + null, false), ImmutableMap.of()))); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); - assertThat(result.getAddresses()).isEmpty(); + assertThat(result.getAddressesOrError().getValue()).isEmpty(); @SuppressWarnings("unchecked") Map resultServiceConfig = (Map) result.getServiceConfig().getConfig(); List> rawLbConfigs = @@ -1094,7 +1305,7 @@ public void resolved_simpleCallSucceeds_routeToRls() { "routeLookupConfig", ImmutableMap.of("lookupService", "rls-cbt.googleapis.com"), "childPolicy", - ImmutableList.of(ImmutableMap.of("cds_experimental", ImmutableMap.of())), + ImmutableList.of(ImmutableMap.of("cds_experimental", ImmutableMap.of("is_dynamic", true))), "childPolicyConfigTargetFieldName", "cluster"); Map expectedClusterManagerLbConfig = ImmutableMap.of( @@ -1106,7 +1317,7 @@ public void resolved_simpleCallSucceeds_routeToRls() { ImmutableList.of(ImmutableMap.of("rls_experimental", expectedRlsLbConfig))))); assertThat(clusterManagerLbConfig).isEqualTo(expectedClusterManagerLbConfig); - assertThat(result.getAttributes().get(InternalXdsAttributes.XDS_CLIENT_POOL)).isNotNull(); + assertThat(result.getAttributes().get(XdsAttributes.XDS_CLIENT)).isNotNull(); InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); assertCallSelectRlsPluginResult( call1, configSelector, "rls-plugin-foo", 20.0); @@ -1125,9 +1336,9 @@ public void resolved_simpleCallSucceeds_routeToRls() { Collections.emptyList(), // changed TimeUnit.SECONDS.toNanos(30L), - null), + null, false), ImmutableMap.of()))); - verify(mockListener, times(2)).onResult(resolutionResultCaptor.capture()); + verify(mockListener, times(2)).onResult2(resolutionResultCaptor.capture()); ResolutionResult result2 = resolutionResultCaptor.getValue(); @SuppressWarnings("unchecked") Map resultServiceConfig2 = (Map) result2.getServiceConfig().getConfig(); @@ -1141,7 +1352,7 @@ public void resolved_simpleCallSucceeds_routeToRls() { "routeLookupConfig", ImmutableMap.of("lookupService", "rls-cbt-2.googleapis.com"), "childPolicy", - ImmutableList.of(ImmutableMap.of("cds_experimental", ImmutableMap.of())), + ImmutableList.of(ImmutableMap.of("cds_experimental", ImmutableMap.of("is_dynamic", true))), "childPolicyConfigTargetFieldName", "cluster"); Map expectedClusterManagerLbConfig2 = ImmutableMap.of( @@ -1158,24 +1369,399 @@ public void resolved_simpleCallSucceeds_routeToRls() { call1, configSelector2, "rls-plugin-foo", 30.0); } + // Begin filter state tests. + + /** + * Verifies the lifecycle of HCM filter instances across LDS updates. + * + *

Filter instances: + * 1. Must have one unique instance per HCM filter name. + * 2. Must be reused when an LDS update with HCM contains a filter with the same name. + * 3. Must be shutdown (closed) when an HCM in a LDS update doesn't a filter with the same name. + */ + @Test + public void filterState_survivesLds() { + StatefulFilter.Provider statefulFilterProvider = filterStateTestSetupResolver(); + FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); + VirtualHost vhost = filterStateTestVhost(); + + // LDS 1. + xdsClient.deliverLdsUpdateWithFilters(vhost, filterStateTestConfigs(STATEFUL_1, STATEFUL_2)); + createAndDeliverClusterUpdates(xdsClient, cluster1); + assertClusterResolutionResult(call1, cluster1); + ImmutableList lds1Snapshot = statefulFilterProvider.getAllInstances(); + // Verify that StatefulFilter with different filter names result in different Filter instances. + assertWithMessage("LDS 1: expected to create filter instances").that(lds1Snapshot).hasSize(2); + // Naming: ldsFilter + StatefulFilter lds1Filter1 = lds1Snapshot.get(0); + StatefulFilter lds1Filter2 = lds1Snapshot.get(1); + assertThat(lds1Filter1).isNotSameInstanceAs(lds1Filter2); + // Redundant check just in case StatefulFilter synchronization is broken. + assertThat(lds1Filter1.idx).isEqualTo(0); + assertThat(lds1Filter2.idx).isEqualTo(1); + + // LDS 2: filter configs with the same names. + xdsClient.deliverLdsUpdateWithFilters(vhost, filterStateTestConfigs(STATEFUL_1, STATEFUL_2)); + assertClusterResolutionResult(call1, cluster1); + ImmutableList lds2Snapshot = statefulFilterProvider.getAllInstances(); + // Filter names hasn't changed, so expecting no new StatefulFilter instances. + assertWithMessage("LDS 2: Expected Filter instances to be reused across LDS updates") + .that(lds2Snapshot).isEqualTo(lds1Snapshot); + + // LDS 3: Filter "STATEFUL_2" removed. + xdsClient.deliverLdsUpdateWithFilters(vhost, filterStateTestConfigs(STATEFUL_1)); + assertClusterResolutionResult(call1, cluster1); + ImmutableList lds3Snapshot = statefulFilterProvider.getAllInstances(); + // Again, no new StatefulFilter instances should be created. + assertWithMessage("LDS 3: Expected Filter instances to be reused across LDS updates") + .that(lds3Snapshot).isEqualTo(lds1Snapshot); + // Verify the shutdown state. + assertThat(lds1Filter1.isShutdown()).isFalse(); + assertWithMessage("LDS 3: Expected %s to be shut down", lds1Filter2) + .that(lds1Filter2.isShutdown()).isTrue(); + + // LDS 4: Filter "STATEFUL_2" added back. + xdsClient.deliverLdsUpdateWithFilters(vhost, filterStateTestConfigs(STATEFUL_1, STATEFUL_2)); + assertClusterResolutionResult(call1, cluster1); + ImmutableList lds4Snapshot = statefulFilterProvider.getAllInstances(); + // Filter "STATEFUL_2" should be treated as any other new filter name in an LDS update: + // a new instance should be created. + assertWithMessage("LDS 4: Expected a new filter instance for %s", STATEFUL_2) + .that(lds4Snapshot).hasSize(3); + StatefulFilter lds4Filter2 = lds4Snapshot.get(2); + assertThat(lds4Filter2.idx).isEqualTo(2); + assertThat(lds4Filter2).isNotSameInstanceAs(lds1Filter2); + assertThat(lds4Snapshot).containsAtLeastElementsIn(lds1Snapshot); + // Verify the shutdown state. + assertThat(lds1Filter1.isShutdown()).isFalse(); + assertThat(lds1Filter2.isShutdown()).isTrue(); + assertThat(lds4Filter2.isShutdown()).isFalse(); + } + + /** + * Verifies the lifecycle of HCM filter instances across RDS updates. + * + *

Filter instances: + * 1. Must have instantiated by the initial LDS/RDS. + * 2. Must be reused by all subsequent RDS updates. + * 3. Must be not shutdown (closed) by valid RDS updates. + */ + @Test + public void filterState_survivesRds() { + StatefulFilter.Provider statefulFilterProvider = filterStateTestSetupResolver(); + FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); + + // LDS 1. + xdsClient.deliverLdsUpdateForRdsNameWithFilters(RDS_RESOURCE_NAME, + filterStateTestConfigs(STATEFUL_1, STATEFUL_2)); + // RDS 1. + VirtualHost vhost1 = filterStateTestVhost(); + xdsClient.deliverRdsUpdate(RDS_RESOURCE_NAME, vhost1); + createAndDeliverClusterUpdates(xdsClient, cluster1); + assertClusterResolutionResult(call1, cluster1); + // Initial RDS update should not generate Filter instances. + ImmutableList rds1Snapshot = statefulFilterProvider.getAllInstances(); + // Verify that StatefulFilter with different filter names result in different Filter instances. + assertWithMessage("RDS 1: expected to create filter instances").that(rds1Snapshot).hasSize(2); + // Naming: ldsFilter + StatefulFilter lds1Filter1 = rds1Snapshot.get(0); + StatefulFilter lds1Filter2 = rds1Snapshot.get(1); + assertThat(lds1Filter1).isNotSameInstanceAs(lds1Filter2); + + // RDS 2: exactly the same as RDS 1. + xdsClient.deliverRdsUpdate(RDS_RESOURCE_NAME, vhost1); + assertClusterResolutionResult(call1, cluster1); + ImmutableList rds2Snapshot = statefulFilterProvider.getAllInstances(); + // Neither should any subsequent RDS updates. + assertWithMessage("RDS 2: Expected Filter instances to be reused across RDS route updates") + .that(rds2Snapshot).isEqualTo(rds1Snapshot); + + // RDS 3: Contains a per-route override for STATEFUL_1. + VirtualHost vhost3 = filterStateTestVhost(ImmutableMap.of( + STATEFUL_1, new StatefulFilter.Config("RDS3") + )); + xdsClient.deliverRdsUpdate(RDS_RESOURCE_NAME, vhost3); + assertClusterResolutionResult(call1, cluster1); + ImmutableList rds3Snapshot = statefulFilterProvider.getAllInstances(); + // As with any other Route update, typed_per_filter_config overrides should not result in + // creating new filter instances. + assertWithMessage("RDS 3: Expected Filter instances to be reused on per-route filter overrides") + .that(rds3Snapshot).isEqualTo(rds1Snapshot); + } + + /** + * Verifies a special case where an existing filter is has a different typeUrl in a subsequent + * LDS update. + * + *

Expectations: + * 1. The old filter instance must be shutdown. + * 2. A new filter instance must be created for the new filter with different typeUrl. + */ + @Test + public void filterState_specialCase_sameNameDifferentTypeUrl() { + // Prepare filter registry with StatefulFilter of different typeUrl. + StatefulFilter.Provider statefulFilterProvider = new StatefulFilter.Provider(); + String altTypeUrl = "type.googleapis.com/grpc.test.AltStatefulFilter"; + StatefulFilter.Provider altStatefulFilterProvider = new StatefulFilter.Provider(altTypeUrl); + FilterRegistry filterRegistry = FilterRegistry.newRegistry() + .register(statefulFilterProvider, altStatefulFilterProvider, ROUTER_FILTER_PROVIDER); + resolver = new XdsNameResolver(targetUri, null, AUTHORITY, null, serviceConfigParser, + syncContext, scheduler, xdsClientPoolFactory, mockRandom, filterRegistry, rawBootstrap, + metricRecorder, nameResolverArgs); + resolver.start(mockListener); + + FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); + VirtualHost vhost = filterStateTestVhost(); + + // LDS 1. + xdsClient.deliverLdsUpdateWithFilters(vhost, filterStateTestConfigs(STATEFUL_1, STATEFUL_2)); + createAndDeliverClusterUpdates(xdsClient, cluster1); + assertClusterResolutionResult(call1, cluster1); + ImmutableList lds1Snapshot = statefulFilterProvider.getAllInstances(); + ImmutableList lds1SnapshotAlt = altStatefulFilterProvider.getAllInstances(); + // Verify that StatefulFilter with different filter names result in different Filter instances. + assertWithMessage("LDS 1: expected to create filter instances").that(lds1Snapshot).hasSize(2); + // Naming: ldsFilter + StatefulFilter lds1Filter1 = lds1Snapshot.get(0); + StatefulFilter lds1Filter2 = lds1Snapshot.get(1); + assertThat(lds1Filter1).isNotSameInstanceAs(lds1Filter2); + // Nothing in the alternative provider. + assertThat(lds1SnapshotAlt).isEmpty(); + + // LDS 2: Filter STATEFUL_2 present, but with a different typeUrl: altTypeUrl. + ImmutableList filterConfigs = ImmutableList.of( + new NamedFilterConfig(STATEFUL_1, new StatefulFilter.Config(STATEFUL_1)), + new NamedFilterConfig(STATEFUL_2, new StatefulFilter.Config(STATEFUL_2, altTypeUrl)), + new NamedFilterConfig(ROUTER_FILTER_INSTANCE_NAME, RouterFilter.ROUTER_CONFIG) + ); + xdsClient.deliverLdsUpdateWithFilters(vhost, filterConfigs); + assertClusterResolutionResult(call1, cluster1); + ImmutableList lds2Snapshot = statefulFilterProvider.getAllInstances(); + ImmutableList lds2SnapshotAlt = altStatefulFilterProvider.getAllInstances(); + // Filter "STATEFUL_2" has different typeUrl, and should be treated as a new filter. + // No changes in the snapshot of normal stateful filters. + assertWithMessage("LDS 2: expected a new filter instance of different type") + .that(lds2Snapshot).isEqualTo(lds1Snapshot); + // A new filter instance is created by altStatefulFilterProvider. + assertWithMessage("LDS 2: expected a new filter instance for type %s", altTypeUrl) + .that(lds2SnapshotAlt).hasSize(1); + StatefulFilter lds2Filter2Alt = lds2SnapshotAlt.get(0); + assertThat(lds2Filter2Alt).isNotSameInstanceAs(lds1Filter2); + // Verify the shutdown state. + assertThat(lds1Filter1.isShutdown()).isFalse(); + assertThat(lds1Filter2.isShutdown()).isTrue(); + assertThat(lds2Filter2Alt.isShutdown()).isFalse(); + } + + /** + * Verifies that all filter instances are shutdown (closed) on LDS resource not found. + */ + @Test + public void filterState_shutdown_onLdsNotFound() { + StatefulFilter.Provider statefulFilterProvider = filterStateTestSetupResolver(); + FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); + VirtualHost vhost = filterStateTestVhost(); + + // LDS 1. + xdsClient.deliverLdsUpdateWithFilters(vhost, filterStateTestConfigs(STATEFUL_1, STATEFUL_2)); + createAndDeliverClusterUpdates(xdsClient, cluster1); + assertClusterResolutionResult(call1, cluster1); + ImmutableList lds1Snapshot = statefulFilterProvider.getAllInstances(); + assertWithMessage("LDS 1: expected to create filter instances").that(lds1Snapshot).hasSize(2); + // Naming: ldsFilter + StatefulFilter lds1Filter1 = lds1Snapshot.get(0); + StatefulFilter lds1Filter2 = lds1Snapshot.get(1); + + // LDS 2: resource not found. + reset(mockListener); + when(mockListener.onResult2(any())).thenReturn(Status.OK); + xdsClient.deliverLdsResourceNotFound(); + assertEmptyResolutionResult(expectedLdsResourceName); + // Verify shutdown. + assertThat(lds1Filter1.isShutdown()).isTrue(); + assertThat(lds1Filter2.isShutdown()).isTrue(); + } + + @Test + public void filterState_noShutdown_onLdsDeletion() { + StatefulFilter.Provider statefulFilterProvider = filterStateTestSetupResolver(); + FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); + VirtualHost vhost = filterStateTestVhost(); + + xdsClient.deliverLdsUpdateWithFilters(vhost, filterStateTestConfigs(STATEFUL_1, STATEFUL_2)); + createAndDeliverClusterUpdates(xdsClient, cluster1); + assertClusterResolutionResult(call1, cluster1); + ImmutableList lds1Snapshot = statefulFilterProvider.getAllInstances(); + assertWithMessage("LDS 1: expected to create filter instances").that(lds1Snapshot).hasSize(2); + StatefulFilter lds1Filter1 = lds1Snapshot.get(0); + StatefulFilter lds1Filter2 = lds1Snapshot.get(1); + + // LDS 2: Deliver a resource deletion, which is now an ambient error. + reset(mockListener); + when(mockListener.onResult2(any())).thenReturn(Status.OK); + xdsClient.deliverLdsResourceDeletion(); + + // With an ambient error, no new resolution should happen. + verify(mockListener, never()).onResult2(any()); + + // Verify that the filters are NOT shut down. + assertThat(lds1Filter1.isShutdown()).isFalse(); + assertThat(lds1Filter2.isShutdown()).isFalse(); + } + + /** + * Verifies that all filter instances are shutdown (closed) on LDS ResourceWatcher shutdown. + */ + @Test + public void filterState_shutdown_onResolverShutdown() { + StatefulFilter.Provider statefulFilterProvider = filterStateTestSetupResolver(); + FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); + VirtualHost vhost = filterStateTestVhost(); + + // LDS 1. + xdsClient.deliverLdsUpdateWithFilters(vhost, filterStateTestConfigs(STATEFUL_1, STATEFUL_2)); + createAndDeliverClusterUpdates(xdsClient, cluster1); + assertClusterResolutionResult(call1, cluster1); + ImmutableList lds1Snapshot = statefulFilterProvider.getAllInstances(); + assertWithMessage("LDS 1: expected to create filter instances").that(lds1Snapshot).hasSize(2); + // Naming: ldsFilter + StatefulFilter lds1Filter1 = lds1Snapshot.get(0); + StatefulFilter lds1Filter2 = lds1Snapshot.get(1); + + // Shutdown. + resolver.shutdown(); + resolver = null; // no need to shutdown again in the teardown. + // Verify shutdown. + assertThat(lds1Filter1.isShutdown()).isTrue(); + assertThat(lds1Filter2.isShutdown()).isTrue(); + } + + /** + * Verifies that all filter instances are shutdown (closed) on RDS resource not found. + */ + @Test + public void filterState_shutdown_onRdsNotFound() { + StatefulFilter.Provider statefulFilterProvider = filterStateTestSetupResolver(); + FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); + xdsClient.deliverLdsUpdateForRdsNameWithFilters( + RDS_RESOURCE_NAME, + filterStateTestConfigs(STATEFUL_1, STATEFUL_2)); + xdsClient.deliverRdsUpdate( + RDS_RESOURCE_NAME, + Collections.singletonList(filterStateTestVhost())); + createAndDeliverClusterUpdates(xdsClient, cluster1); + assertClusterResolutionResult(call1, cluster1); + + ImmutableList rds1Snapshot = statefulFilterProvider.getAllInstances(); + assertWithMessage("RDS 1: Expected to create filter instances").that(rds1Snapshot).hasSize(2); + StatefulFilter rds1Filter1 = rds1Snapshot.get(0); + StatefulFilter rds1Filter2 = rds1Snapshot.get(1); + assertThat(rds1Filter1.isShutdown()).isFalse(); + assertThat(rds1Filter2.isShutdown()).isFalse(); + + reset(mockListener); + when(mockListener.onResult2(any())).thenReturn(Status.OK); + xdsClient.deliverRdsResourceNotFound(RDS_RESOURCE_NAME); + + assertEmptyResolutionResult(RDS_RESOURCE_NAME); + assertThat(rds1Filter1.isShutdown()).isTrue(); + assertThat(rds1Filter2.isShutdown()).isTrue(); + } + + @Test + public void filterState_noShutdown_onRdsAmbientError() { + StatefulFilter.Provider statefulFilterProvider = filterStateTestSetupResolver(); + FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); + + // LDS 1. + xdsClient.deliverLdsUpdateForRdsNameWithFilters(RDS_RESOURCE_NAME, + filterStateTestConfigs(STATEFUL_1, STATEFUL_2)); + // RDS 1: Standard vhost with a route. + xdsClient.deliverRdsUpdate(RDS_RESOURCE_NAME, filterStateTestVhost()); + createAndDeliverClusterUpdates(xdsClient, cluster1); + assertClusterResolutionResult(call1, cluster1); + ImmutableList rds1Snapshot = statefulFilterProvider.getAllInstances(); + assertWithMessage("RDS 1: expected to create filter instances").that(rds1Snapshot).hasSize(2); + // Naming: ldsFilter + StatefulFilter lds1Filter1 = rds1Snapshot.get(0); + StatefulFilter lds1Filter2 = rds1Snapshot.get(1); + + // RDS 2: RDS_RESOURCE_NAME not found. + reset(mockListener); + when(mockListener.onResult2(any())).thenReturn(Status.OK); + xdsClient.deliverRdsAmbientError(RDS_RESOURCE_NAME, Status.NOT_FOUND); + verify(mockListener, never()).onResult2(any()); + assertThat(lds1Filter1.isShutdown()).isFalse(); + assertThat(lds1Filter2.isShutdown()).isFalse(); + } + + private StatefulFilter.Provider filterStateTestSetupResolver() { + StatefulFilter.Provider statefulFilterProvider = new StatefulFilter.Provider(); + FilterRegistry filterRegistry = FilterRegistry.newRegistry() + .register(statefulFilterProvider, ROUTER_FILTER_PROVIDER); + resolver = new XdsNameResolver(targetUri, null, AUTHORITY, null, serviceConfigParser, + syncContext, scheduler, xdsClientPoolFactory, mockRandom, filterRegistry, rawBootstrap, + metricRecorder, nameResolverArgs); + resolver.start(mockListener); + return statefulFilterProvider; + } + + private ImmutableList filterStateTestConfigs(String... names) { + ImmutableList.Builder result = ImmutableList.builder(); + for (String name : names) { + result.add(new NamedFilterConfig(name, new StatefulFilter.Config(name))); + } + result.add(new NamedFilterConfig(ROUTER_FILTER_INSTANCE_NAME, RouterFilter.ROUTER_CONFIG)); + return result.build(); + } + + private Route filterStateTestRoute(ImmutableMap perRouteOverrides) { + // Standard basic route for filterState tests. + return Route.forAction( + RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), + RouteAction.forCluster(cluster1, NO_HASH_POLICIES, null, null, true), + perRouteOverrides); + } + + private VirtualHost filterStateTestVhost() { + return filterStateTestVhost(NO_FILTER_OVERRIDES); + } + + private VirtualHost filterStateTestVhost(ImmutableMap perRouteOverrides) { + return VirtualHost.create( + "stateful-vhost", + ImmutableList.of(expectedLdsResourceName), + ImmutableList.of(filterStateTestRoute(perRouteOverrides)), + NO_FILTER_OVERRIDES); + } + + // End filter state tests. + @SuppressWarnings("unchecked") private void assertEmptyResolutionResult(String resource) { - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); - assertThat(result.getAddresses()).isEmpty(); + assertThat(result.getAddressesOrError().getValue()).isEmpty(); assertThat((Map) result.getServiceConfig().getConfig()).isEmpty(); InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); Result configResult = configSelector.selectConfig( - new PickSubchannelArgsImpl(call1.methodDescriptor, new Metadata(), CallOptions.DEFAULT)); + newPickSubchannelArgs(call1.methodDescriptor, new Metadata(), CallOptions.DEFAULT)); assertThat(configResult.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); assertThat(configResult.getStatus().getDescription()).contains(resource); } + private void assertClusterResolutionResult(CallInfo call, String expectedCluster) { + verify(mockListener, atLeast(1)).onResult2(resolutionResultCaptor.capture()); + ResolutionResult result = resolutionResultCaptor.getValue(); + InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); + assertCallSelectClusterResult(call, configSelector, expectedCluster, null); + } + private void assertCallSelectClusterResult( CallInfo call, InternalConfigSelector configSelector, String expectedCluster, @Nullable Double expectedTimeoutSec) { Result result = configSelector.selectConfig( - new PickSubchannelArgsImpl(call.methodDescriptor, new Metadata(), CallOptions.DEFAULT)); + newPickSubchannelArgs(call.methodDescriptor, new Metadata(), CallOptions.DEFAULT)); assertThat(result.getStatus().isOk()).isTrue(); ClientInterceptor interceptor = result.getInterceptor(); ClientCall clientCall = interceptor.interceptCall( @@ -1183,6 +1769,10 @@ private void assertCallSelectClusterResult( clientCall.start(new NoopClientCallListener<>(), new Metadata()); assertThat(testCall.callOptions.getOption(XdsNameResolver.CLUSTER_SELECTION_KEY)) .isEqualTo("cluster:" + expectedCluster); + XdsConfig xdsConfig = + testCall.callOptions.getOption(XdsNameResolver.XDS_CONFIG_CALL_OPTION_KEY); + assertThat(xdsConfig).isNotNull(); + assertThat(xdsConfig.getClusters()).containsKey(expectedCluster); // Without "cluster:" prefix @SuppressWarnings("unchecked") Map config = (Map) result.getConfig(); if (expectedTimeoutSec != null) { @@ -1203,7 +1793,7 @@ private void assertCallSelectRlsPluginResult( CallInfo call, InternalConfigSelector configSelector, String expectedPluginName, Double expectedTimeoutSec) { Result result = configSelector.selectConfig( - new PickSubchannelArgsImpl(call.methodDescriptor, new Metadata(), CallOptions.DEFAULT)); + newPickSubchannelArgs(call.methodDescriptor, new Metadata(), CallOptions.DEFAULT)); assertThat(result.getStatus().isOk()).isTrue(); ClientInterceptor interceptor = result.getInterceptor(); ClientCall clientCall = interceptor.interceptCall( @@ -1231,24 +1821,31 @@ private InternalConfigSelector resolveToClusters() { RouteMatch.withPathExactOnly(call1.getFullMethodNameForPath()), RouteAction.forCluster( cluster1, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), - null), + null, false), ImmutableMap.of()), Route.forAction( RouteMatch.withPathExactOnly(call2.getFullMethodNameForPath()), RouteAction.forCluster( cluster2, Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), - null), + null, false), ImmutableMap.of()))); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); - assertThat(result.getAddresses()).isEmpty(); + assertThat(result.getAddressesOrError().getValue()).isEmpty(); assertServiceConfigForLoadBalancingConfig( Arrays.asList(cluster1, cluster2), (Map) result.getServiceConfig().getConfig()); - assertThat(result.getAttributes().get(InternalXdsAttributes.XDS_CLIENT_POOL)).isNotNull(); - assertThat(result.getAttributes().get(InternalXdsAttributes.CALL_COUNTER_PROVIDER)).isNotNull(); + assertThat(result.getAttributes().get(XdsAttributes.XDS_CLIENT)).isNotNull(); + assertThat(result.getAttributes().get(XdsAttributes.CALL_COUNTER_PROVIDER)).isNotNull(); return result.getAttributes().get(InternalConfigSelector.KEY); } + private static void assertServiceConfigForLoadBalancingConfig( + List clusters, ResolutionResult result) { + @SuppressWarnings("unchecked") + Map config = (Map) result.getServiceConfig().getConfig(); + assertServiceConfigForLoadBalancingConfig(clusters, config); + } + /** * Verifies the raw service config contains an xDS load balancing config for the given clusters. */ @@ -1286,7 +1883,7 @@ public void generateServiceConfig_forClusterManagerLoadBalancingConfig() throws Route route1 = Route.forAction( RouteMatch.withPathExactOnly("HelloService/hi"), RouteAction.forCluster( - "cluster-foo", Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null), + "cluster-foo", Collections.emptyList(), TimeUnit.SECONDS.toNanos(15L), null, false), ImmutableMap.of()); Route route2 = Route.forAction( RouteMatch.withPathExactOnly("HelloService/hello"), @@ -1296,7 +1893,7 @@ public void generateServiceConfig_forClusterManagerLoadBalancingConfig() throws ClusterWeight.create("cluster-baz", 50, ImmutableMap.of())), ImmutableList.of(), TimeUnit.SECONDS.toNanos(15L), - null), + null, false), ImmutableMap.of()); Map rlsConfig = ImmutableMap.of("lookupService", "rls.bigtable.google.com"); Route route3 = Route.forAction( @@ -1305,7 +1902,7 @@ public void generateServiceConfig_forClusterManagerLoadBalancingConfig() throws NamedPluginConfig.create("plugin-foo", RlsPluginConfig.create(rlsConfig)), Collections.emptyList(), TimeUnit.SECONDS.toNanos(20L), - null), + null, false), ImmutableMap.of()); resolver.start(mockListener); @@ -1316,8 +1913,9 @@ public void generateServiceConfig_forClusterManagerLoadBalancingConfig() throws ImmutableList.of(route1, route2, route3), ImmutableMap.of()); xdsClient.deliverRdsUpdate(RDS_RESOURCE_NAME, Collections.singletonList(virtualHost)); + createAndDeliverClusterUpdates(xdsClient, "cluster-foo", "cluster-bar", "cluster-baz"); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); String expectedServiceConfigJson = "{\n" + " \"loadBalancingConfig\": [{\n" @@ -1351,7 +1949,9 @@ public void generateServiceConfig_forClusterManagerLoadBalancingConfig() throws + " \"lookupService\": \"rls.bigtable.google.com\"\n" + " },\n" + " \"childPolicy\": [\n" - + " {\"cds_experimental\": {}}\n" + + " {\"cds_experimental\": {\n" + + " \"is_dynamic\": true\n" + + " }}\n" + " ],\n" + " \"childPolicyConfigTargetFieldName\": \"cluster\"\n" + " }\n" @@ -1412,7 +2012,6 @@ public void generateServiceConfig_forPerMethodConfig() throws IOException { assertThat(XdsNameResolver.generateServiceConfigWithMethodConfig(null, retryPolicy)) .isEqualTo(expectedServiceConfig); - // timeout and retry expectedServiceConfigJson = "{\n" + " \"methodConfig\": [{\n" @@ -1501,7 +2100,7 @@ public void resolved_faultAbortInLdsUpdate() { FaultAbort.forHeader(FaultConfig.FractionalPercent.perHundred(70)), null); xdsClient.deliverLdsUpdateWithFaultInjection(cluster1, httpFilterFaultConfig, null, null, null); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); // no header abort key provided in metadata, rpc should succeed @@ -1540,7 +2139,7 @@ public void resolved_faultAbortInLdsUpdate() { FaultAbort.forHeader(FaultConfig.FractionalPercent.perMillion(600_000)), null); xdsClient.deliverLdsUpdateWithFaultInjection(cluster1, httpFilterFaultConfig, null, null, null); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener, times(2)).onResult2(resolutionResultCaptor.capture()); result = resolutionResultCaptor.getValue(); configSelector = result.getAttributes().get(InternalConfigSelector.KEY); observer = startNewCall(TestMethodDescriptors.voidMethod(), configSelector, @@ -1556,7 +2155,7 @@ public void resolved_faultAbortInLdsUpdate() { FaultAbort.forHeader(FaultConfig.FractionalPercent.perMillion(0)), null); xdsClient.deliverLdsUpdateWithFaultInjection(cluster1, httpFilterFaultConfig, null, null, null); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener, times(3)).onResult2(resolutionResultCaptor.capture()); result = resolutionResultCaptor.getValue(); configSelector = result.getAttributes().get(InternalConfigSelector.KEY); observer = startNewCall(TestMethodDescriptors.voidMethod(), configSelector, @@ -1571,7 +2170,7 @@ public void resolved_faultAbortInLdsUpdate() { FaultConfig.FractionalPercent.perMillion(600_000)), null); xdsClient.deliverLdsUpdateWithFaultInjection(cluster1, httpFilterFaultConfig, null, null, null); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener, times(4)).onResult2(resolutionResultCaptor.capture()); result = resolutionResultCaptor.getValue(); configSelector = result.getAttributes().get(InternalConfigSelector.KEY); observer = startNewCall(TestMethodDescriptors.voidMethod(), configSelector, @@ -1589,7 +2188,7 @@ public void resolved_faultAbortInLdsUpdate() { FaultConfig.FractionalPercent.perMillion(400_000)), null); xdsClient.deliverLdsUpdateWithFaultInjection(cluster1, httpFilterFaultConfig, null, null, null); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener, times(5)).onResult2(resolutionResultCaptor.capture()); result = resolutionResultCaptor.getValue(); configSelector = result.getAttributes().get(InternalConfigSelector.KEY); observer = startNewCall(TestMethodDescriptors.voidMethod(), configSelector, @@ -1607,7 +2206,7 @@ public void resolved_faultDelayInLdsUpdate() { FaultConfig httpFilterFaultConfig = FaultConfig.create( FaultDelay.forHeader(FaultConfig.FractionalPercent.perHundred(70)), null, null); xdsClient.deliverLdsUpdateWithFaultInjection(cluster1, httpFilterFaultConfig, null, null, null); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); // no header delay key provided in metadata, rpc should succeed immediately @@ -1624,7 +2223,7 @@ public void resolved_faultDelayInLdsUpdate() { httpFilterFaultConfig = FaultConfig.create( FaultDelay.forHeader(FaultConfig.FractionalPercent.perMillion(600_000)), null, null); xdsClient.deliverLdsUpdateWithFaultInjection(cluster1, httpFilterFaultConfig, null, null, null); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener, times(2)).onResult2(resolutionResultCaptor.capture()); result = resolutionResultCaptor.getValue(); configSelector = result.getAttributes().get(InternalConfigSelector.KEY); observer = startNewCall(TestMethodDescriptors.voidMethod(), configSelector, @@ -1635,7 +2234,7 @@ public void resolved_faultDelayInLdsUpdate() { httpFilterFaultConfig = FaultConfig.create( FaultDelay.forHeader(FaultConfig.FractionalPercent.perMillion(0)), null, null); xdsClient.deliverLdsUpdateWithFaultInjection(cluster1, httpFilterFaultConfig, null, null, null); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener, times(3)).onResult2(resolutionResultCaptor.capture()); result = resolutionResultCaptor.getValue(); configSelector = result.getAttributes().get(InternalConfigSelector.KEY); observer = startNewCall(TestMethodDescriptors.voidMethod(), configSelector, @@ -1648,7 +2247,7 @@ public void resolved_faultDelayInLdsUpdate() { null, null); xdsClient.deliverLdsUpdateWithFaultInjection(cluster1, httpFilterFaultConfig, null, null, null); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener, times(4)).onResult2(resolutionResultCaptor.capture()); result = resolutionResultCaptor.getValue(); configSelector = result.getAttributes().get(InternalConfigSelector.KEY); observer = startNewCall(TestMethodDescriptors.voidMethod(), configSelector, @@ -1661,7 +2260,7 @@ public void resolved_faultDelayInLdsUpdate() { null, null); xdsClient.deliverLdsUpdateWithFaultInjection(cluster1, httpFilterFaultConfig, null, null, null); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener, times(5)).onResult2(resolutionResultCaptor.capture()); result = resolutionResultCaptor.getValue(); configSelector = result.getAttributes().get(InternalConfigSelector.KEY); observer = startNewCall(TestMethodDescriptors.voidMethod(), configSelector, @@ -1680,7 +2279,7 @@ public void resolved_faultDelayWithMaxActiveStreamsInLdsUpdate() { null, /* maxActiveFaults= */ 1); xdsClient.deliverLdsUpdateWithFaultInjection(cluster1, httpFilterFaultConfig, null, null, null); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); @@ -1710,7 +2309,7 @@ public void resolved_faultDelayInLdsUpdate_callWithEarlyDeadline() { null, null); xdsClient.deliverLdsUpdateWithFaultInjection(cluster1, httpFilterFaultConfig, null, null, null); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); @@ -1726,7 +2325,7 @@ public long nanoTime() { assertThat(testCall).isNull(); verifyRpcDelayedThenAborted(observer, 4000L, Status.DEADLINE_EXCEEDED.withDescription( "Deadline exceeded after up to 5000 ns of fault-injected delay:" - + " Deadline CallOptions will be exceeded in 0.000004000s. ")); + + " Deadline CallOptions was exceeded after 0.000004000s")); } @Test @@ -1742,7 +2341,7 @@ public void resolved_faultAbortAndDelayInLdsUpdateInLdsUpdate() { FaultConfig.FractionalPercent.perMillion(1000_000)), null); xdsClient.deliverLdsUpdateWithFaultInjection(cluster1, httpFilterFaultConfig, null, null, null); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); ClientCall.Listener observer = startNewCall(TestMethodDescriptors.voidMethod(), @@ -1771,7 +2370,7 @@ public void resolved_faultConfigOverrideInLdsUpdate() { null); xdsClient.deliverLdsUpdateWithFaultInjection( cluster1, httpFilterFaultConfig, virtualHostFaultConfig, null, null); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); ClientCall.Listener observer = startNewCall(TestMethodDescriptors.voidMethod(), @@ -1786,7 +2385,7 @@ public void resolved_faultConfigOverrideInLdsUpdate() { null); xdsClient.deliverLdsUpdateWithFaultInjection( cluster1, httpFilterFaultConfig, virtualHostFaultConfig, routeFaultConfig, null); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener, times(2)).onResult2(resolutionResultCaptor.capture()); result = resolutionResultCaptor.getValue(); configSelector = result.getAttributes().get(InternalConfigSelector.KEY); observer = startNewCall(TestMethodDescriptors.voidMethod(), configSelector, @@ -1803,7 +2402,7 @@ public void resolved_faultConfigOverrideInLdsUpdate() { xdsClient.deliverLdsUpdateWithFaultInjection( cluster1, httpFilterFaultConfig, virtualHostFaultConfig, routeFaultConfig, weightedClusterFaultConfig); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener, times(3)).onResult2(resolutionResultCaptor.capture()); result = resolutionResultCaptor.getValue(); configSelector = result.getAttributes().get(InternalConfigSelector.KEY); observer = startNewCall(TestMethodDescriptors.voidMethod(), configSelector, @@ -1832,7 +2431,7 @@ public void resolved_faultConfigOverrideInLdsAndInRdsUpdate() { FaultAbort.forStatus(Status.UNKNOWN, FaultConfig.FractionalPercent.perMillion(1000_000)), null); xdsClient.deliverRdsUpdateWithFaultInjection(RDS_RESOURCE_NAME, null, routeFaultConfig, null); - verify(mockListener).onResult(resolutionResultCaptor.capture()); + verify(mockListener).onResult2(resolutionResultCaptor.capture()); ResolutionResult result = resolutionResultCaptor.getValue(); InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); ClientCall.Listener observer = startNewCall(TestMethodDescriptors.voidMethod(), @@ -1850,8 +2449,7 @@ private ClientCall.Listener startNewCall( } @SuppressWarnings("unchecked") ClientCall.Listener listener = mock(ClientCall.Listener.class); - Result result = selector.selectConfig(new PickSubchannelArgsImpl( - method, metadata, callOptions)); + Result result = selector.selectConfig(newPickSubchannelArgs(method, metadata, callOptions)); ClientCall call = ClientInterceptors.intercept(channel, result.getInterceptor()).newCall(method, callOptions); call.start(listener, metadata); @@ -1889,22 +2487,29 @@ private void verifyRpcDelayedThenAborted( verifyRpcFailed(listener, expectedStatus); } + private PickSubchannelArgs newPickSubchannelArgs( + MethodDescriptor method, Metadata headers, CallOptions callOptions) { + return new PickSubchannelArgsImpl(method, headers, callOptions, new PickDetailsConsumer() {}); + } + private final class FakeXdsClientPoolFactory implements XdsClientPoolFactory { - @Override - public void setBootstrapOverride(Map bootstrap) {} + Set targets = new HashSet<>(); + XdsClient xdsClient = new FakeXdsClient(); @Override @Nullable - public ObjectPool get() { + public ObjectPool get(String target) { throw new UnsupportedOperationException("Should not be called"); } @Override - public ObjectPool getOrCreate() throws XdsInitializationException { + public ObjectPool getOrCreate( + String target, BootstrapInfo bootstrapInfo, MetricRecorder metricRecorder) { + targets.add(target); return new ObjectPool() { @Override public XdsClient getObject() { - return new FakeXdsClient(); + return xdsClient; } @Override @@ -1913,14 +2518,25 @@ public XdsClient returnObject(Object object) { } }; } + + @Override + public List getTargets() { + if (targets.isEmpty()) { + List targetList = new ArrayList<>(); + targetList.add(targetUri.toString()); + return targetList; + } + return new ArrayList<>(targets); + } } private class FakeXdsClient extends XdsClient { // Should never be subscribing to more than one LDS and RDS resource at any point of time. private String ldsResource; // should always be AUTHORITY - private String rdsResource; private ResourceWatcher ldsWatcher; - private ResourceWatcher rdsWatcher; + private final Map>> rdsWatchers = new HashMap<>(); + private final Map>> cdsWatchers = new HashMap<>(); + private final Map>> edsWatchers = new HashMap<>(); @Override public BootstrapInfo getBootstrapInfo() { @@ -1943,15 +2559,22 @@ public void watchXdsResource(XdsResourceType resou ldsWatcher = (ResourceWatcher) watcher; break; case "RDS": - assertThat(rdsResource).isNull(); - assertThat(rdsWatcher).isNull(); - rdsResource = resourceName; - rdsWatcher = (ResourceWatcher) watcher; + rdsWatchers.computeIfAbsent(resourceName, k -> new ArrayList<>()) + .add((ResourceWatcher) watcher); + break; + case "CDS": + cdsWatchers.computeIfAbsent(resourceName, k -> new ArrayList<>()) + .add((ResourceWatcher) watcher); + break; + case "EDS": + edsWatchers.computeIfAbsent(resourceName, k -> new ArrayList<>()) + .add((ResourceWatcher) watcher); break; default: } } + @SuppressWarnings("unchecked") @Override public void cancelXdsResourceWatch(XdsResourceType type, String resourceName, @@ -1965,19 +2588,57 @@ public void cancelXdsResourceWatch(XdsResourceType ldsWatcher = null; break; case "RDS": - assertThat(rdsResource).isNotNull(); - assertThat(rdsWatcher).isNotNull(); - rdsResource = null; - rdsWatcher = null; + assertThat(rdsWatchers).containsKey(resourceName); + assertThat(rdsWatchers.get(resourceName)).contains(watcher); + rdsWatchers.get(resourceName).remove((ResourceWatcher) watcher); + if (rdsWatchers.get(resourceName).isEmpty()) { + rdsWatchers.remove(resourceName); + } + break; + case "CDS": + assertThat(cdsWatchers).containsKey(resourceName); + assertThat(cdsWatchers.get(resourceName)).contains(watcher); + cdsWatchers.get(resourceName).remove((ResourceWatcher) watcher); + break; + case "EDS": + assertThat(edsWatchers).containsKey(resourceName); + assertThat(edsWatchers.get(resourceName)).contains(watcher); + edsWatchers.get(resourceName).remove((ResourceWatcher) watcher); break; default: } } + void deliverRdsAmbientError(String resourceName, Status status) { + if (!rdsWatchers.containsKey(resourceName)) { + return; + } + syncContext.execute(() -> { + List> resourceWatchers = + ImmutableList.copyOf(rdsWatchers.get(resourceName)); + resourceWatchers.forEach(w -> w.onAmbientError(status)); + }); + } + + void deliverLdsUpdateOnly(long httpMaxStreamDurationNano, List virtualHosts) { + syncContext.execute(() -> { + LdsUpdate ldsUpdate = LdsUpdate.forApiListener(HttpConnectionManager.forVirtualHosts( + httpMaxStreamDurationNano, virtualHosts, null)); + ldsWatcher.onResourceChanged(StatusOr.fromValue(ldsUpdate)); + }); + } + void deliverLdsUpdate(long httpMaxStreamDurationNano, List virtualHosts) { + List clusterNames = new ArrayList<>(); + for (VirtualHost vh : virtualHosts) { + clusterNames.addAll(getClusterNames(vh.routes())); + } + syncContext.execute(() -> { - ldsWatcher.onChanged(LdsUpdate.forApiListener(HttpConnectionManager.forVirtualHosts( - httpMaxStreamDurationNano, virtualHosts, null))); + LdsUpdate ldsUpdate = LdsUpdate.forApiListener(HttpConnectionManager.forVirtualHosts( + httpMaxStreamDurationNano, virtualHosts, null)); + ldsWatcher.onResourceChanged(StatusOr.fromValue(ldsUpdate)); + createAndDeliverClusterUpdates(this, clusterNames.toArray(new String[0])); }); } @@ -1986,9 +2647,23 @@ void deliverLdsUpdate(final List routes) { VirtualHost.create( "virtual-host", Collections.singletonList(expectedLdsResourceName), routes, ImmutableMap.of()); + List clusterNames = getClusterNames(routes); + + syncContext.execute(() -> { + LdsUpdate ldsUpdate = LdsUpdate.forApiListener(HttpConnectionManager.forVirtualHosts( + 0L, Collections.singletonList(virtualHost), null)); + ldsWatcher.onResourceChanged(StatusOr.fromValue(ldsUpdate)); + if (!clusterNames.isEmpty()) { + createAndDeliverClusterUpdates(this, clusterNames.toArray(new String[0])); + } + }); + } + + void deliverLdsUpdateWithFilters(VirtualHost vhost, List filterConfigs) { syncContext.execute(() -> { - ldsWatcher.onChanged(LdsUpdate.forApiListener(HttpConnectionManager.forVirtualHosts( - 0L, Collections.singletonList(virtualHost), null))); + LdsUpdate ldsUpdate = LdsUpdate.forApiListener(HttpConnectionManager.forVirtualHosts( + 0L, Collections.singletonList(vhost), filterConfigs)); + ldsWatcher.onResourceChanged(StatusOr.fromValue(ldsUpdate)); }); } @@ -2022,7 +2697,8 @@ void deliverLdsUpdateWithFaultInjection( Collections.singletonList(clusterWeight), Collections.emptyList(), null, - null), + null, + false), overrideConfig); overrideConfig = virtualHostFaultConfig == null ? ImmutableMap.of() @@ -2034,8 +2710,10 @@ void deliverLdsUpdateWithFaultInjection( Collections.singletonList(route), overrideConfig); syncContext.execute(() -> { - ldsWatcher.onChanged(LdsUpdate.forApiListener(HttpConnectionManager.forVirtualHosts( - 0L, Collections.singletonList(virtualHost), filterChain))); + LdsUpdate ldsUpdate = LdsUpdate.forApiListener(HttpConnectionManager.forVirtualHosts( + 0L, Collections.singletonList(virtualHost), filterChain)); + ldsWatcher.onResourceChanged(StatusOr.fromValue(ldsUpdate)); + createAndDeliverClusterUpdates(this, cluster); }); } @@ -2049,30 +2727,70 @@ void deliverLdsUpdateForRdsNameWithFaultInjection( new NamedFilterConfig(FAULT_FILTER_INSTANCE_NAME, httpFilterFaultConfig), new NamedFilterConfig(ROUTER_FILTER_INSTANCE_NAME, RouterFilter.ROUTER_CONFIG)); syncContext.execute(() -> { - ldsWatcher.onChanged(LdsUpdate.forApiListener(HttpConnectionManager.forRdsName( - 0L, rdsName, filterChain))); + LdsUpdate ldsUpdate = LdsUpdate.forApiListener(HttpConnectionManager.forRdsName( + 0L, rdsName, filterChain)); + ldsWatcher.onResourceChanged(StatusOr.fromValue(ldsUpdate)); }); } void deliverLdsUpdateForRdsName(String rdsName) { + deliverLdsUpdateForRdsNameWithFilters(rdsName, null); + } + + void deliverLdsUpdateForRdsNameWithFilters( + String rdsName, + @Nullable List filterConfigs) { + syncContext.execute(() -> { + LdsUpdate ldsUpdate = LdsUpdate.forApiListener(HttpConnectionManager.forRdsName( + 0, rdsName, filterConfigs)); + ldsWatcher.onResourceChanged(StatusOr.fromValue(ldsUpdate)); + }); + } + + void deliverLdsResourceDeletion() { + Status status = Status.NOT_FOUND.withDescription( + "Resource not found: " + expectedLdsResourceName); syncContext.execute(() -> { - ldsWatcher.onChanged(LdsUpdate.forApiListener(HttpConnectionManager.forRdsName( - 0, rdsName, null))); + ldsWatcher.onAmbientError(status); }); } void deliverLdsResourceNotFound() { + Status notFoundStatus = Status.UNAVAILABLE.withDescription( + "Resource not found: " + expectedLdsResourceName); syncContext.execute(() -> { - ldsWatcher.onResourceDoesNotExist(expectedLdsResourceName); + if (ldsWatcher != null) { + ldsWatcher.onResourceChanged(StatusOr.fromStatus(notFoundStatus)); + } }); } + private List getClusterNames(List routes) { + List clusterNames = new ArrayList<>(); + for (Route r : routes) { + if (r.routeAction() == null) { + continue; + } + String cluster = r.routeAction().cluster(); + if (cluster != null) { + clusterNames.add(cluster); + } else { + List weightedClusters = r.routeAction().weightedClusters(); + if (weightedClusters == null) { + continue; + } + for (ClusterWeight wc : weightedClusters) { + clusterNames.add(wc.name()); + } + } + } + + return clusterNames; + } + void deliverRdsUpdateWithFaultInjection( String resourceName, @Nullable FaultConfig virtualHostFaultConfig, @Nullable FaultConfig routFaultConfig, @Nullable FaultConfig weightedClusterFaultConfig) { - if (!resourceName.equals(rdsResource)) { - return; - } ImmutableMap overrideConfig = weightedClusterFaultConfig == null ? ImmutableMap.of() : ImmutableMap.of( @@ -2089,7 +2807,8 @@ void deliverRdsUpdateWithFaultInjection( Collections.singletonList(clusterWeight), Collections.emptyList(), null, - null), + null, + false), overrideConfig); overrideConfig = virtualHostFaultConfig == null ? ImmutableMap.of() @@ -2100,40 +2819,78 @@ void deliverRdsUpdateWithFaultInjection( Collections.singletonList(expectedLdsResourceName), Collections.singletonList(route), overrideConfig); - syncContext.execute(() -> { - rdsWatcher.onChanged(new RdsUpdate(Collections.singletonList(virtualHost))); - }); + deliverRdsUpdate(resourceName, virtualHost); + createAndDeliverClusterUpdates(this, cluster1); } void deliverRdsUpdate(String resourceName, List virtualHosts) { - if (!resourceName.equals(rdsResource)) { + if (!rdsWatchers.containsKey(resourceName)) { return; } syncContext.execute(() -> { - rdsWatcher.onChanged(new RdsUpdate(virtualHosts)); + RdsUpdate update = new RdsUpdate(virtualHosts); + List> resourceWatchers = + ImmutableList.copyOf(rdsWatchers.get(resourceName)); + resourceWatchers.forEach(w -> w.onResourceChanged(StatusOr.fromValue(update))); }); } + void deliverRdsUpdate(String resourceName, VirtualHost virtualHost) { + deliverRdsUpdate(resourceName, ImmutableList.of(virtualHost)); + } + void deliverRdsResourceNotFound(String resourceName) { - if (!resourceName.equals(rdsResource)) { + if (!rdsWatchers.containsKey(resourceName)) { return; } syncContext.execute(() -> { - rdsWatcher.onResourceDoesNotExist(rdsResource); + List> resourceWatchers = + ImmutableList.copyOf(rdsWatchers.get(resourceName)); + Status status = Status.UNAVAILABLE.withDescription("Resource not found: " + resourceName); + resourceWatchers.forEach(w -> w.onResourceChanged(StatusOr.fromStatus(status))); }); } + private void deliverCdsUpdate(String clusterName, CdsUpdate update) { + if (!cdsWatchers.containsKey(clusterName)) { + return; + } + syncContext.execute(() -> { + List> resourceWatchers = + ImmutableList.copyOf(cdsWatchers.get(clusterName)); + resourceWatchers.forEach(w -> w.onResourceChanged(StatusOr.fromValue(update))); + }); + } + + private void deliverEdsUpdate(String name, EdsUpdate update) { + syncContext.execute(() -> { + if (!edsWatchers.containsKey(name)) { + return; + } + List> resourceWatchers = + ImmutableList.copyOf(edsWatchers.get(name)); + resourceWatchers.forEach(w -> w.onResourceChanged(StatusOr.fromValue(update))); + }); + } + + void deliverError(final Status error) { if (ldsWatcher != null) { syncContext.execute(() -> { - ldsWatcher.onError(error); - }); - } - if (rdsWatcher != null) { - syncContext.execute(() -> { - rdsWatcher.onError(error); + ldsWatcher.onResourceChanged(StatusOr.fromStatus(error)); }); } + syncContext.execute(() -> { + List> rdsCopy = rdsWatchers.values().stream() + .flatMap(List::stream).collect(java.util.stream.Collectors.toList()); + List> cdsCopy = cdsWatchers.values().stream() + .flatMap(List::stream).collect(java.util.stream.Collectors.toList()); + List> edsCopy = edsWatchers.values().stream() + .flatMap(List::stream).collect(java.util.stream.Collectors.toList()); + rdsCopy.forEach(w -> w.onResourceChanged(StatusOr.fromStatus(error))); + cdsCopy.forEach(w -> w.onResourceChanged(StatusOr.fromStatus(error))); + edsCopy.forEach(w -> w.onResourceChanged(StatusOr.fromStatus(error))); + }); } } diff --git a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java index 2c349eec4af..6b39106f18c 100644 --- a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java @@ -24,13 +24,20 @@ import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CA_PEM_FILE; import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_KEY_FILE; import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_SPIFFE_PEM_FILE; import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE; import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_1_SPIFFE_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SPIFFE_TRUST_MAP_1_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SPIFFE_TRUST_MAP_FILE; import static org.junit.Assert.fail; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.SettableFuture; +import io.envoyproxy.envoy.config.core.v3.SocketAddress.Protocol; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; +import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; import io.grpc.Attributes; import io.grpc.EquivalentAddressGroup; import io.grpc.Grpc; @@ -43,6 +50,7 @@ import io.grpc.Server; import io.grpc.ServerCredentials; import io.grpc.Status; +import io.grpc.StatusOr; import io.grpc.StatusRuntimeException; import io.grpc.stub.StreamObserver; import io.grpc.testing.GrpcCleanupRule; @@ -61,35 +69,61 @@ import io.grpc.xds.XdsServerTestHelper.FakeXdsClient; import io.grpc.xds.XdsServerTestHelper.FakeXdsClientPoolFactory; import io.grpc.xds.client.Bootstrapper; +import io.grpc.xds.client.CommonBootstrapperTestUtils; import io.grpc.xds.internal.Matchers.HeaderMatcher; +import io.grpc.xds.internal.XdsInternalAttributes; import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; +import io.grpc.xds.internal.security.SecurityProtocolNegotiators; import io.grpc.xds.internal.security.SslContextProviderSupplier; import io.grpc.xds.internal.security.TlsContextManagerImpl; +import io.grpc.xds.internal.security.certprovider.FileWatcherCertificateProviderProvider; +import io.grpc.xds.internal.security.trust.CertificateUtils; import io.netty.handler.ssl.NotSslRecordException; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; import java.net.Inet4Address; import java.net.InetSocketAddress; import java.net.URI; -import java.net.URISyntaxException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import javax.net.ssl.SSLException; import javax.net.ssl.SSLHandshakeException; +import javax.net.ssl.TrustManagerFactory; import org.junit.After; +import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; /** * Unit tests for {@link XdsChannelCredentials} and {@link XdsServerBuilder} for plaintext/TLS/mTLS * modes. */ -@RunWith(JUnit4.class) +@RunWith(Parameterized.class) public class XdsSecurityClientServerTest { + + private static final String SNI_IN_UTC = "waterzooi.test.google.be"; + + @Parameter + public Boolean enableSpiffe; + private Boolean originalEnableSpiffe; @Rule public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); private int port; @@ -101,12 +135,37 @@ public class XdsSecurityClientServerTest { private FakeXdsClient xdsClient = new FakeXdsClient(); private FakeXdsClientPoolFactory fakePoolFactory = new FakeXdsClientPoolFactory(xdsClient); private static final String OVERRIDE_AUTHORITY = "foo.test.google.fr"; + private Attributes sslContextAttributes; + + @Parameters(name = "enableSpiffe={0}") + public static Collection data() { + return ImmutableList.of(true, false); + } + + @Before + public void setUp() throws IOException { + saveEnvironment(); + FileWatcherCertificateProviderProvider.enableSpiffe = enableSpiffe; + } + + private void saveEnvironment() { + originalEnableSpiffe = FileWatcherCertificateProviderProvider.enableSpiffe; + } @After - public void tearDown() { + public void tearDown() throws IOException { if (fakeNameResolverFactory != null) { NameResolverRegistry.getDefaultRegistry().deregister(fakeNameResolverFactory); } + FileWatcherCertificateProviderProvider.enableSpiffe = originalEnableSpiffe; + if (sslContextAttributes != null) { + SslContextProviderSupplier sslContextProviderSupplier = sslContextAttributes.get( + SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER); + if (sslContextProviderSupplier != null) { + sslContextProviderSupplier.close(); + } + sslContextAttributes = null; + } } @Test @@ -133,30 +192,317 @@ public void nullFallbackCredentials_expectException() throws Exception { @Test public void tlsClientServer_noClientAuthentication() throws Exception { DownstreamTlsContext downstreamTlsContext = - setBootstrapInfoAndBuildDownstreamTlsContext(null, null, null, null, false, false); + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, false); buildServerWithTlsContext(downstreamTlsContext); // for TLS, client only needs trustCa UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext( - CLIENT_KEY_FILE, - CLIENT_PEM_FILE, false); + CLIENT_KEY_FILE, CLIENT_PEM_FILE, null, false); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); assertThat(unaryRpc(/* requestMessage= */ "buddy", blockingStub)).isEqualTo("Hello buddy"); } + /** + * Use system root ca cert for TLS channel - no mTLS. + * Uses common_tls_context.combined_validation_context in upstream_tls_context. + */ + @Test + public void tlsClientServer_useSystemRootCerts_noMtls_useCombinedValidationContext() + throws Exception { + Path trustStoreFilePath = getCacertFilePathForTestCa(); + try { + setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, false); + buildServerWithTlsContext(downstreamTlsContext); + + UpstreamTlsContext upstreamTlsContext = + setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, + CLIENT_PEM_FILE, true, SNI_IN_UTC, false, "", false, false); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); + assertThat(unaryRpc(/* requestMessage= */ "buddy", blockingStub)).isEqualTo("Hello buddy"); + } finally { + Files.deleteIfExists(trustStoreFilePath); + clearTrustStoreSystemProperties(); + } + } + + /** + * Use system root ca cert for TLS channel - no mTLS. + * Uses common_tls_context.validation_context in upstream_tls_context. + */ + @Test + public void tlsClientServer_useSystemRootCerts_noMtls_validationContext() throws Exception { + Path trustStoreFilePath = getCacertFilePathForTestCa().toAbsolutePath(); + try { + setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, false); + buildServerWithTlsContext(downstreamTlsContext); + + UpstreamTlsContext upstreamTlsContext = + setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, + CLIENT_PEM_FILE, false, SNI_IN_UTC, false, null, false, false); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); + assertThat(unaryRpc(/* requestMessage= */ "buddy", blockingStub)).isEqualTo("Hello buddy"); + } finally { + Files.deleteIfExists(trustStoreFilePath.toAbsolutePath()); + clearTrustStoreSystemProperties(); + } + } + + @Test + public void tlsClientServer_useSystemRootCerts_mtls() throws Exception { + Path trustStoreFilePath = getCacertFilePathForTestCa(); + try { + setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, true); + buildServerWithTlsContext(downstreamTlsContext); + + UpstreamTlsContext upstreamTlsContext = + setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, + CLIENT_PEM_FILE, true, SNI_IN_UTC, true, "", false, false); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); + assertThat(unaryRpc(/* requestMessage= */ "buddy", blockingStub)).isEqualTo("Hello buddy"); + } finally { + Files.deleteIfExists(trustStoreFilePath); + clearTrustStoreSystemProperties(); + } + } + + @Test + public void tlsClientServer_noAutoSniValidation_failureToMatchSubjAltNames() + throws Exception { + Path trustStoreFilePath = getCacertFilePathForTestCa(); + try { + setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, false); + buildServerWithTlsContext(downstreamTlsContext); + + UpstreamTlsContext upstreamTlsContext = + setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, + CLIENT_PEM_FILE, true, "server1.test.google.in", false, "", false, false); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); + unaryRpc(/* requestMessage= */ "buddy", blockingStub); + fail("Expected handshake failure exception"); + } catch (StatusRuntimeException e) { + assertThat(e.getCause()).isInstanceOf(SSLHandshakeException.class); + assertThat(e.getCause().getCause()).isInstanceOf(CertificateException.class); + assertThat(e.getCause().getCause().getMessage()).isEqualTo( + "Peer certificate SAN check failed"); + } finally { + Files.deleteIfExists(trustStoreFilePath); + clearTrustStoreSystemProperties(); + } + } + + + @Test + public void tlsClientServer_autoSniValidation_sniInUtc() + throws Exception { + Path trustStoreFilePath = getCacertFilePathForTestCa(); + try { + setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, false); + buildServerWithTlsContext(downstreamTlsContext); + + UpstreamTlsContext upstreamTlsContext = + setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, + CLIENT_PEM_FILE, true, + // SAN matcher in CommonValidationContext. Will be overridden by autoSniSanValidation + "server1.test.google.in", + false, + SNI_IN_UTC, + false, true); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); + unaryRpc(/* requestMessage= */ "buddy", blockingStub); + } finally { + Files.deleteIfExists(trustStoreFilePath); + clearTrustStoreSystemProperties(); + } + } + + @Test + public void tlsClientServer_autoSniValidation_sniFromHostname() + throws Exception { + Path trustStoreFilePath = getCacertFilePathForTestCa(); + try { + setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, false); + buildServerWithTlsContext(downstreamTlsContext); + + UpstreamTlsContext upstreamTlsContext = + setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, + CLIENT_PEM_FILE, true, + // SAN matcher in CommonValidationContext. Will be overridden by autoSniSanValidation + "server1.test.google.in", + false, + "", + true, true); + + // TODO: Change this to foo.test.gooogle.fr that needs wildcard matching after + // https://github.com/grpc/grpc-java/pull/12345 is done + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY, + "waterzooi.test.google.be"); + unaryRpc(/* requestMessage= */ "buddy", blockingStub); + } finally { + Files.deleteIfExists(trustStoreFilePath); + clearTrustStoreSystemProperties(); + } + } + + @Test + public void tlsClientServer_autoSniValidation_noSniApplicable_usesMatcherFromCmnVdnCtx() + throws Exception { + Path trustStoreFilePath = getCacertFilePathForTestCa(); + boolean originalUseChannelAuthorityIfNoSniApplicable = + CertificateUtils.useChannelAuthorityIfNoSniApplicable; + try { + CertificateUtils.useChannelAuthorityIfNoSniApplicable = + true; + setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, false); + buildServerWithTlsContext(downstreamTlsContext); + + UpstreamTlsContext upstreamTlsContext = + setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, + CLIENT_PEM_FILE, true, + // This is what will get used for the SAN validation since no SNI was used + "waterzooi.test.google.be", + false, + "", + false, true); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); + unaryRpc(/* requestMessage= */ "buddy", blockingStub); + } finally { + CertificateUtils.useChannelAuthorityIfNoSniApplicable = + originalUseChannelAuthorityIfNoSniApplicable; + Files.deleteIfExists(trustStoreFilePath); + clearTrustStoreSystemProperties(); + } + } + + /** + * Use system root ca cert for TLS channel - mTLS. + */ + @Test + public void tlsClientServer_useSystemRootCerts_requireClientAuth() throws Exception { + Path trustStoreFilePath = getCacertFilePathForTestCa().toAbsolutePath(); + try { + setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, true); + buildServerWithTlsContext(downstreamTlsContext); + + UpstreamTlsContext upstreamTlsContext = + setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, + CLIENT_PEM_FILE, true, SNI_IN_UTC, false, "", false, false); + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); + assertThat(unaryRpc(/* requestMessage= */ "buddy", blockingStub)).isEqualTo("Hello buddy"); + } finally { + Files.deleteIfExists(trustStoreFilePath.toAbsolutePath()); + clearTrustStoreSystemProperties(); + } + } + + private Path getCacertFilePathForTestCa() + throws IOException, KeyStoreException, CertificateException, NoSuchAlgorithmException { + KeyStore keystore = KeyStore.getInstance(KeyStore.getDefaultType()); + keystore.load(null, null); + InputStream caCertStream = getClass().getResource("/certs/ca.pem").openStream(); + keystore.setCertificateEntry("testca", CertificateFactory.getInstance("X.509") + .generateCertificate(caCertStream)); + caCertStream.close(); + File trustStoreFile = File.createTempFile("testca-truststore", "jks"); + FileOutputStream out = new FileOutputStream(trustStoreFile); + keystore.store(out, "changeit".toCharArray()); + out.close(); + return trustStoreFile.toPath(); + } + + @Test + public void tlsClientServer_Spiffe_noClientAuthentication() throws Exception { + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_SPIFFE_PEM_FILE, null, null, null, + null, null, false, false); + buildServerWithTlsContext(downstreamTlsContext); + + // for TLS, client only needs trustCa, so BAD certs don't matter + UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext( + BAD_CLIENT_KEY_FILE, BAD_CLIENT_PEM_FILE, SPIFFE_TRUST_MAP_FILE, false); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); + assertThat(unaryRpc(/* requestMessage= */ "buddy", blockingStub)).isEqualTo("Hello buddy"); + } + + @Test + public void tlsClientServer_Spiffe_noClientAuthentication_wrongServerCert() throws Exception { + if (!enableSpiffe) { + return; + } + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, false); + buildServerWithTlsContext(downstreamTlsContext); + + // for TLS, client only needs trustCa, so BAD certs don't matter + UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext( + BAD_CLIENT_KEY_FILE, BAD_CLIENT_PEM_FILE, SPIFFE_TRUST_MAP_FILE, false); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); + try { + unaryRpc("buddy", blockingStub); + fail("exception expected"); + } catch (StatusRuntimeException sre) { + assertThat(sre.getStatus().getCode()).isEqualTo(Status.UNAVAILABLE.getCode()); + assertThat(sre.getCause().getCause().getMessage()) + .contains("Failed to extract SPIFFE ID from peer leaf certificate"); + } + } + @Test public void requireClientAuth_noClientCert_expectException() throws Exception { DownstreamTlsContext downstreamTlsContext = - setBootstrapInfoAndBuildDownstreamTlsContext(null, null, null, null, true, true); + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, true, true); buildServerWithTlsContext(downstreamTlsContext); // for TLS, client only uses trustCa UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext( - CLIENT_KEY_FILE, - CLIENT_PEM_FILE, false); + CLIENT_KEY_FILE, CLIENT_PEM_FILE, null, false); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -178,12 +524,12 @@ public void requireClientAuth_noClientCert_expectException() @Test public void noClientAuth_sendBadClientCert_passes() throws Exception { DownstreamTlsContext downstreamTlsContext = - setBootstrapInfoAndBuildDownstreamTlsContext(null, null, null, null, false, false); + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, false); buildServerWithTlsContext(downstreamTlsContext); UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext( - BAD_CLIENT_KEY_FILE, - BAD_CLIENT_PEM_FILE, true); + BAD_CLIENT_KEY_FILE, BAD_CLIENT_PEM_FILE, null, true); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -193,8 +539,7 @@ public void noClientAuth_sendBadClientCert_passes() throws Exception { @Test public void mtls_badClientCert_expectException() throws Exception { UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext( - BAD_CLIENT_KEY_FILE, - BAD_CLIENT_PEM_FILE, true); + BAD_CLIENT_KEY_FILE, BAD_CLIENT_PEM_FILE, null, true); try { performMtlsTestAndGetListenerWatcher(upstreamTlsContext, null, null, null, null); fail("exception expected"); @@ -210,20 +555,58 @@ public void mtls_badClientCert_expectException() throws Exception { } } - /** mTLS - client auth enabled - using {@link XdsChannelCredentials} API. */ + /** mTLS with Spiffe Trust Bundle - client auth enabled - using {@link XdsChannelCredentials} + * API. */ + @Test + public void mtlsClientServer_Spiffe_withClientAuthentication_withXdsChannelCreds() + throws Exception { + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_SPIFFE_PEM_FILE, null, null, null, + null, SPIFFE_TRUST_MAP_1_FILE, true, true); + buildServerWithTlsContext(downstreamTlsContext); + + UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext( + CLIENT_KEY_FILE, CLIENT_SPIFFE_PEM_FILE, SPIFFE_TRUST_MAP_1_FILE, true); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); + assertThat(unaryRpc(/* requestMessage= */ "buddy", blockingStub)).isEqualTo("Hello buddy"); + } + + @Test + public void mtlsClientServer_Spiffe_badClientCert_expectException() + throws Exception { + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_SPIFFE_PEM_FILE, null, null, null, + null, SPIFFE_TRUST_MAP_1_FILE, true, true); + buildServerWithTlsContext(downstreamTlsContext); + + UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext( + CLIENT_KEY_FILE, BAD_CLIENT_PEM_FILE, SPIFFE_TRUST_MAP_1_FILE, true); + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); + try { + assertThat(unaryRpc(/* requestMessage= */ "buddy", blockingStub)).isEqualTo("Hello buddy"); + fail("exception expected"); + } catch (StatusRuntimeException sre) { + assertThat(sre.getStatus().getCode()).isEqualTo(Status.UNAVAILABLE.getCode()); + assertThat(sre.getMessage()).contains("ssl exception"); + } + } + @Test public void mtlsClientServer_withClientAuthentication_withXdsChannelCreds() throws Exception { UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext( - CLIENT_KEY_FILE, - CLIENT_PEM_FILE, true); + CLIENT_KEY_FILE, CLIENT_PEM_FILE, null, true); performMtlsTestAndGetListenerWatcher(upstreamTlsContext, null, null, null, null); } @Test public void tlsServer_plaintextClient_expectException() throws Exception { DownstreamTlsContext downstreamTlsContext = - setBootstrapInfoAndBuildDownstreamTlsContext(null, null, null, null, false, false); + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, false); buildServerWithTlsContext(downstreamTlsContext); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = @@ -243,8 +626,7 @@ public void plaintextServer_tlsClient_expectException() throws Exception { // for TLS, client only needs trustCa UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext( - CLIENT_KEY_FILE, - CLIENT_PEM_FILE, false); + CLIENT_KEY_FILE, CLIENT_PEM_FILE, null, false); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -262,15 +644,14 @@ public void plaintextServer_tlsClient_expectException() throws Exception { public void mtlsClientServer_changeServerContext_expectException() throws Exception { UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext( - CLIENT_KEY_FILE, - CLIENT_PEM_FILE, true); + CLIENT_KEY_FILE, CLIENT_PEM_FILE, null, true); performMtlsTestAndGetListenerWatcher(upstreamTlsContext, "cert-instance-name2", BAD_SERVER_KEY_FILE, BAD_SERVER_PEM_FILE, CA_PEM_FILE); DownstreamTlsContext downstreamTlsContext = CommonTlsContextTestsUtil.buildDownstreamTlsContext( "cert-instance-name2", true, true); - EnvoyServerProtoData.Listener listener = buildListener("listener1", "0.0.0.0", + EnvoyServerProtoData.Listener listener = buildListener("listener1", "0.0.0.0:0", downstreamTlsContext, tlsContextManagerForServer); xdsClient.deliverLdsUpdate(LdsUpdate.forTcpListener(listener)); @@ -290,8 +671,8 @@ private void performMtlsTestAndGetListenerWatcher( String privateKey2, String cert2, String trustCa2) throws Exception { DownstreamTlsContext downstreamTlsContext = - setBootstrapInfoAndBuildDownstreamTlsContext(certInstanceName2, privateKey2, cert2, - trustCa2, true, true); + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, certInstanceName2, + privateKey2, cert2, trustCa2, null, true, false); buildServerWithFallbackServerCredentials( InsecureServerCredentials.create(), downstreamTlsContext); @@ -302,26 +683,58 @@ private void performMtlsTestAndGetListenerWatcher( } private DownstreamTlsContext setBootstrapInfoAndBuildDownstreamTlsContext( - String certInstanceName2, - String privateKey2, - String cert2, String trustCa2, boolean hasRootCert, boolean requireClientCertificate) { + String cert1, String certInstanceName2, String privateKey2, + String cert2, String trustCa2, String spiffeFile, + boolean hasRootCert, boolean requireClientCertificate) { bootstrapInfoForServer = CommonBootstrapperTestUtils .buildBootstrapInfo("google_cloud_private_spiffe-server", SERVER_1_KEY_FILE, - SERVER_1_PEM_FILE, CA_PEM_FILE, certInstanceName2, privateKey2, cert2, trustCa2); + cert1, CA_PEM_FILE, certInstanceName2, privateKey2, cert2, trustCa2, spiffeFile); return CommonTlsContextTestsUtil.buildDownstreamTlsContext( "google_cloud_private_spiffe-server", hasRootCert, requireClientCertificate); } private UpstreamTlsContext setBootstrapInfoAndBuildUpstreamTlsContext(String clientKeyFile, - String clientPemFile, - boolean hasIdentityCert) { + String clientPemFile, String spiffeFile, boolean hasIdentityCert) { bootstrapInfoForClient = CommonBootstrapperTestUtils .buildBootstrapInfo("google_cloud_private_spiffe-client", clientKeyFile, clientPemFile, - CA_PEM_FILE, null, null, null, null); + CA_PEM_FILE, null, null, null, null, spiffeFile); return CommonTlsContextTestsUtil .buildUpstreamTlsContext("google_cloud_private_spiffe-client", hasIdentityCert); } + @SuppressWarnings("deprecation") // gRFC A29 predates match_typed_subject_alt_names + private UpstreamTlsContext setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts( + String clientKeyFile, + String clientPemFile, + boolean useCombinedValidationContext, + String sanToMatch, + boolean isMtls, + String sniInUpstreamTlsContext, + boolean autoHostSni, boolean autoSniSanValidation) { + bootstrapInfoForClient = CommonBootstrapperTestUtils + .buildBootstrapInfo("google_cloud_private_spiffe-client", clientKeyFile, clientPemFile, + CA_PEM_FILE, null, null, null, null, null); + if (useCombinedValidationContext) { + return CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance( + isMtls ? "google_cloud_private_spiffe-client" : null, + isMtls ? "ROOT" : null, null, + null, null, + CertificateValidationContext.newBuilder() + .setSystemRootCerts( + CertificateValidationContext.SystemRootCerts.newBuilder().build()) + .addMatchSubjectAltNames( + StringMatcher.newBuilder() + .setExact(sanToMatch)) + .build(), sniInUpstreamTlsContext, autoHostSni, autoSniSanValidation); + } + return CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance( + "google_cloud_private_spiffe-client", "ROOT", null, + null, null, CertificateValidationContext.newBuilder() + .setSystemRootCerts( + CertificateValidationContext.SystemRootCerts.newBuilder().build()) + .build()); + } + private void buildServerWithTlsContext(DownstreamTlsContext downstreamTlsContext) throws Exception { buildServerWithTlsContext(downstreamTlsContext, InsecureServerCredentials.create()); @@ -340,6 +753,7 @@ private void buildServerWithFallbackServerCredentials( ServerCredentials xdsCredentials = XdsServerCredentials.create(fallbackCredentials); XdsServerBuilder builder = XdsServerBuilder.forPort(0, xdsCredentials) .xdsClientPoolFactory(fakePoolFactory) + .overrideBootstrapForTest(XdsServerTestHelper.RAW_BOOTSTRAP) .addService(new SimpleServiceImpl()); buildServer(builder, downstreamTlsContext); } @@ -351,7 +765,7 @@ private void buildServer( tlsContextManagerForServer = new TlsContextManagerImpl(bootstrapInfoForServer); XdsServerWrapper xdsServer = (XdsServerWrapper) builder.build(); SettableFuture startFuture = startServerAsync(xdsServer); - EnvoyServerProtoData.Listener listener = buildListener("listener1", "10.1.2.3", + EnvoyServerProtoData.Listener listener = buildListener("listener1", "0.0.0.0:0", downstreamTlsContext, tlsContextManagerForServer); LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); xdsClient.deliverLdsUpdate(listenerUpdate); @@ -392,13 +806,25 @@ static EnvoyServerProtoData.Listener buildListener( "filter-chain-foo", filterChainMatch, httpConnectionManager, tlsContext, tlsContextManager); EnvoyServerProtoData.Listener listener = EnvoyServerProtoData.Listener.create( - name, address, ImmutableList.of(defaultFilterChain), null); + name, address, ImmutableList.of(defaultFilterChain), null, Protocol.TCP); return listener; } private SimpleServiceGrpc.SimpleServiceBlockingStub getBlockingStub( - final UpstreamTlsContext upstreamTlsContext, String overrideAuthority) - throws URISyntaxException { + final UpstreamTlsContext upstreamTlsContext, String overrideAuthority) { + return getBlockingStub(upstreamTlsContext, overrideAuthority, overrideAuthority); + } + + // Two separate parameters for overrideAuthority and addrAttribute is for the SAN SNI validation + // test tlsClientServer_useSystemRootCerts_sni_san_validation_from_hostname that uses hostname + // passed for SNI. foo.test.google.fr is used for virtual host matching via authority but it + // can't be used for SNI in this testcase because foo.test.google.fr needs wildcard matching to + // match against *.test.google.fr in the certificate SNI, which isn't implemented yet + // (https://github.com/grpc/grpc-java/pull/12345 implements it) + // so use an exact match SAN such as waterzooi.test.google.be for SNI for this testcase. + private SimpleServiceGrpc.SimpleServiceBlockingStub getBlockingStub( + final UpstreamTlsContext upstreamTlsContext, String overrideAuthority, + String addrNameAttribute) { ManagedChannelBuilder channelBuilder = Grpc.newChannelBuilder( "sectest://localhost:" + port, @@ -410,16 +836,18 @@ private SimpleServiceGrpc.SimpleServiceBlockingStub getBlockingStub( InetSocketAddress socketAddress = new InetSocketAddress(Inet4Address.getLoopbackAddress(), port); tlsContextManagerForClient = new TlsContextManagerImpl(bootstrapInfoForClient); - Attributes attrs = - (upstreamTlsContext != null) - ? Attributes.newBuilder() - .set(InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, - new SslContextProviderSupplier( - upstreamTlsContext, tlsContextManagerForClient)) - .build() - : Attributes.EMPTY; + Attributes.Builder sslContextAttributesBuilder = (upstreamTlsContext != null) + ? Attributes.newBuilder() + .set(SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, + new SslContextProviderSupplier( + upstreamTlsContext, tlsContextManagerForClient)) + : Attributes.newBuilder(); + if (addrNameAttribute != null) { + sslContextAttributesBuilder.set(XdsInternalAttributes.ATTR_ADDRESS_NAME, addrNameAttribute); + } + sslContextAttributes = sslContextAttributesBuilder.build(); fakeNameResolverFactory.setServers( - ImmutableList.of(new EquivalentAddressGroup(socketAddress, attrs))); + ImmutableList.of(new EquivalentAddressGroup(socketAddress, sslContextAttributes))); return SimpleServiceGrpc.newBlockingStub(cleanupRule.register(channelBuilder.build())); } @@ -445,10 +873,49 @@ public void run() { } } }); - xdsClient.ldsResource.get(8000, TimeUnit.MILLISECONDS); + try { + xdsClient.ldsResource.get(8000, TimeUnit.MILLISECONDS); + } catch (Exception ex) { + // start() probably failed, so throw its exception + if (settableFuture.isDone()) { + Throwable t = settableFuture.get(); + if (t != null) { + throw new Exception(t); + } + } + throw ex; + } return settableFuture; } + private void setTrustStoreSystemProperties(String trustStoreFilePath) throws Exception { + System.setProperty("javax.net.ssl.trustStore", trustStoreFilePath); + System.setProperty("javax.net.ssl.trustStorePassword", "changeit"); + System.setProperty("javax.net.ssl.trustStoreType", "JKS"); + createDefaultTrustManager(); + } + + private void clearTrustStoreSystemProperties() throws Exception { + System.clearProperty("javax.net.ssl.trustStore"); + System.clearProperty("javax.net.ssl.trustStorePassword"); + System.clearProperty("javax.net.ssl.trustStoreType"); + createDefaultTrustManager(); + } + + /** + * Workaround the JDK's TrustManagerStore race. TrustManagerStore has a cache for the default + * certs based on the system properties. But updating the cache is not thread-safe and can cause a + * half-updated cache to appear fully-updated. When both the client and server initialize their + * trust store simultaneously, one can see a half-updated value. Creating the trust manager here + * fixes the cache while no other threads are running and thus the client and server threads won't + * race to update it. See https://github.com/grpc/grpc-java/issues/11678. + */ + private void createDefaultTrustManager() throws Exception { + TrustManagerFactory factory = + TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + factory.init((KeyStore) null); + } + private static class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase { @Override @@ -520,7 +987,8 @@ public void refresh() { } void resolved() { - ResolutionResult.Builder builder = ResolutionResult.newBuilder().setAddresses(servers); + ResolutionResult.Builder builder = ResolutionResult.newBuilder() + .setAddressesOrError(StatusOr.fromValue(servers)); listener.onResult(builder.build()); } diff --git a/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java b/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java index d28c7d7c607..ac990226259 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java @@ -17,6 +17,7 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.XdsServerTestHelper.buildTestListener; import static org.junit.Assert.fail; import static org.mockito.Mockito.any; import static org.mockito.Mockito.mock; @@ -26,13 +27,16 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.SettableFuture; import io.grpc.BindableService; import io.grpc.InsecureServerCredentials; import io.grpc.ServerServiceDefinition; import io.grpc.Status; import io.grpc.StatusException; +import io.grpc.StatusOr; import io.grpc.testing.GrpcCleanupRule; +import io.grpc.xds.XdsListenerResource.LdsUpdate; import io.grpc.xds.XdsServerTestHelper.FakeXdsClient; import io.grpc.xds.XdsServerTestHelper.FakeXdsClientPoolFactory; import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; @@ -40,7 +44,6 @@ import java.net.InetSocketAddress; import java.net.ServerSocket; import java.net.SocketAddress; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ExecutionException; @@ -81,6 +84,7 @@ private void buildBuilder(XdsServerBuilder.XdsServingStatusListener xdsServingSt XdsServerBuilder.forPort( port, XdsServerCredentials.create(InsecureServerCredentials.create())); builder.xdsClientPoolFactory(xdsClientPoolFactory); + builder.overrideBootstrapForTest(XdsServerTestHelper.RAW_BOOTSTRAP); if (xdsServingStatusListener != null) { builder.xdsServingStatusListener(xdsServingStatusListener); } @@ -135,7 +139,18 @@ public void run() { } } }); - xdsClient.ldsResource.get(5000, TimeUnit.MILLISECONDS); + try { + xdsClient.ldsResource.get(5000, TimeUnit.MILLISECONDS); + } catch (TimeoutException ex) { + // start() probably failed, so throw its exception + if (settableFuture.isDone()) { + Throwable t = settableFuture.get(); + if (t != null) { + throw new ExecutionException(t); + } + } + throw ex; + } return settableFuture; } @@ -195,13 +210,14 @@ public void xdsServer_discoverState() throws Exception { CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), tlsContextManager); future.get(5000, TimeUnit.MILLISECONDS); - xdsClient.ldsWatcher.onError(Status.ABORTED); + xdsClient.ldsWatcher.onAmbientError(Status.ABORTED); verify(mockXdsServingStatusListener, never()).onNotServing(any(StatusException.class)); reset(mockXdsServingStatusListener); - xdsClient.ldsWatcher.onError(Status.CANCELLED); + xdsClient.ldsWatcher.onAmbientError(Status.CANCELLED); verify(mockXdsServingStatusListener, never()).onNotServing(any(StatusException.class)); reset(mockXdsServingStatusListener); - xdsClient.ldsWatcher.onResourceDoesNotExist("not found error"); + Status notFoundStatus = Status.NOT_FOUND.withDescription("not found error"); + xdsClient.ldsWatcher.onResourceChanged(StatusOr.fromStatus(notFoundStatus)); verify(mockXdsServingStatusListener).onNotServing(any(StatusException.class)); reset(mockXdsServingStatusListener); XdsServerTestHelper.generateListenerUpdate( @@ -221,10 +237,13 @@ public void xdsServer_startError() buildServer(mockXdsServingStatusListener); Future future = startServerAsync(); // create port conflict for start to fail - XdsServerTestHelper.generateListenerUpdate( - xdsClient, + EnvoyServerProtoData.Listener listener = buildTestListener( + "listener1", "0.0.0.0:" + port, ImmutableList.of(), CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), - tlsContextManager); + null, tlsContextManager); + LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); + xdsClient.deliverLdsUpdate(listenerUpdate); + Throwable exception = future.get(5, TimeUnit.SECONDS); assertThat(exception).isInstanceOf(IOException.class); assertThat(exception).hasMessageThat().contains("Failed to bind"); @@ -249,7 +268,7 @@ public void xdsServerStartSecondUpdateAndError() tlsContextManager); verify(mockXdsServingStatusListener, never()).onNotServing(any(Throwable.class)); verifyServer(future, mockXdsServingStatusListener, null); - xdsClient.ldsWatcher.onError(Status.ABORTED); + xdsClient.ldsWatcher.onAmbientError(Status.ABORTED); verifyServer(null, mockXdsServingStatusListener, null); } @@ -298,9 +317,12 @@ public void drainGraceTime_negativeThrows() throws IOException { @Test public void testOverrideBootstrap() throws Exception { - Map b = new HashMap<>(); + Map b = XdsServerTestHelper.RAW_BOOTSTRAP; buildBuilder(null); builder.overrideBootstrapForTest(b); - assertThat(xdsClientPoolFactory.savedBootstrap).isEqualTo(b); + xdsServer = cleanupRule.register((XdsServerWrapper) builder.build()); + Future unused = startServerAsync(); + assertThat(xdsClientPoolFactory.savedBootstrapInfo.node().getId()) + .isEqualTo(XdsServerTestHelper.BOOTSTRAP_INFO.node().getId()); } } diff --git a/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java b/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java index 5d59e97335e..386793299d8 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java @@ -21,7 +21,11 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.SettableFuture; +import io.envoyproxy.envoy.config.core.v3.SocketAddress.Protocol; import io.grpc.InsecureChannelCredentials; +import io.grpc.MetricRecorder; +import io.grpc.Status; +import io.grpc.StatusOr; import io.grpc.internal.ObjectPool; import io.grpc.xds.EnvoyServerProtoData.ConnectionSourceType; import io.grpc.xds.EnvoyServerProtoData.FilterChain; @@ -35,8 +39,8 @@ import io.grpc.xds.client.Bootstrapper.BootstrapInfo; import io.grpc.xds.client.EnvoyProtoData; import io.grpc.xds.client.XdsClient; -import io.grpc.xds.client.XdsInitializationException; import io.grpc.xds.client.XdsResourceType; +import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -44,7 +48,10 @@ import java.util.List; import java.util.Map; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import javax.annotation.Nullable; /** @@ -57,6 +64,17 @@ public class XdsServerTestHelper { "projects/42/networks/default/nodes/5c85b298-6f5b-4722-b74a-f7d1f0ccf5ad"; private static final EnvoyProtoData.Node BOOTSTRAP_NODE = EnvoyProtoData.Node.newBuilder().setId(NODE_ID).build(); + static final Map RAW_BOOTSTRAP = ImmutableMap.of( + "node", ImmutableMap.of( + "id", NODE_ID), + "server_listener_resource_name_template", "grpc/server?udpa.resource.listening_address=%s", + "xds_servers", ImmutableList.of( + ImmutableMap.of( + "server_uri", SERVER_URI, + "channel_creds", ImmutableList.of( + ImmutableMap.of( + "type", "insecure"))) + )); static final Bootstrapper.BootstrapInfo BOOTSTRAP_INFO = Bootstrapper.BootstrapInfo.builder() .servers(Arrays.asList( @@ -69,7 +87,7 @@ public class XdsServerTestHelper { static void generateListenerUpdate(FakeXdsClient xdsClient, EnvoyServerProtoData.DownstreamTlsContext tlsContext, TlsContextManager tlsContextManager) { - EnvoyServerProtoData.Listener listener = buildTestListener("listener1", "10.1.2.3", + EnvoyServerProtoData.Listener listener = buildTestListener("listener1", "0.0.0.0:0", ImmutableList.of(), tlsContext, null, tlsContextManager); LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); xdsClient.deliverLdsUpdate(listenerUpdate); @@ -80,7 +98,8 @@ static void generateListenerUpdate( EnvoyServerProtoData.DownstreamTlsContext tlsContext, EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain, TlsContextManager tlsContextManager) { - EnvoyServerProtoData.Listener listener = buildTestListener("listener1", "10.1.2.3", sourcePorts, + EnvoyServerProtoData.Listener listener = buildTestListener( + "listener1", "0.0.0.0:7000", sourcePorts, tlsContext, tlsContextForDefaultFilterChain, tlsContextManager); LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); xdsClient.deliverLdsUpdate(listenerUpdate); @@ -125,7 +144,7 @@ static EnvoyServerProtoData.Listener buildTestListener( tlsContextForDefaultFilterChain, tlsContextManager); EnvoyServerProtoData.Listener listener = EnvoyServerProtoData.Listener.create( - name, address, ImmutableList.of(filterChain1), defaultFilterChain); + name, address, ImmutableList.of(filterChain1), defaultFilterChain, Protocol.TCP); return listener; } @@ -133,25 +152,22 @@ static final class FakeXdsClientPoolFactory implements XdsClientPoolFactory { private XdsClient xdsClient; - Map savedBootstrap; + BootstrapInfo savedBootstrapInfo; FakeXdsClientPoolFactory(XdsClient xdsClient) { this.xdsClient = xdsClient; } - @Override - public void setBootstrapOverride(Map bootstrap) { - this.savedBootstrap = bootstrap; - } - @Override @Nullable - public ObjectPool get() { + public ObjectPool get(String target) { throw new UnsupportedOperationException("Should not be called"); } @Override - public ObjectPool getOrCreate() throws XdsInitializationException { + public ObjectPool getOrCreate( + String target, BootstrapInfo bootstrapInfo, MetricRecorder metricRecorder) { + this.savedBootstrapInfo = bootstrapInfo; return new ObjectPool() { @Override public XdsClient getObject() { @@ -165,14 +181,25 @@ public XdsClient returnObject(Object object) { } }; } + + @Override + public List getTargets() { + return Collections.singletonList("fake-target"); + } } + // Implementation details: + // 1. Use `synchronized` in methods where XdsClientImpl uses its own `syncContext`. + // 2. Use `serverExecutor` via `execute()` in methods where XdsClientImpl uses watcher's executor. static final class FakeXdsClient extends XdsClient { - boolean shutdown; - SettableFuture ldsResource = SettableFuture.create(); - ResourceWatcher ldsWatcher; - CountDownLatch rdsCount = new CountDownLatch(1); + public static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(5); + + private boolean shutdown; + @Nullable SettableFuture ldsResource = SettableFuture.create(); + @Nullable ResourceWatcher ldsWatcher; + private CountDownLatch rdsCount = new CountDownLatch(1); final Map> rdsWatchers = new HashMap<>(); + @Nullable private volatile Executor serverExecutor; @Override public TlsContextManager getSecurityConfig() { @@ -186,14 +213,20 @@ public BootstrapInfo getBootstrapInfo() { @Override @SuppressWarnings("unchecked") - public void watchXdsResource(XdsResourceType resourceType, - String resourceName, - ResourceWatcher watcher, - Executor syncContext) { + public synchronized void watchXdsResource( + XdsResourceType resourceType, + String resourceName, + ResourceWatcher watcher, + Executor executor) { + if (serverExecutor != null) { + assertThat(executor).isEqualTo(serverExecutor); + } + switch (resourceType.typeName()) { case "LDS": assertThat(ldsWatcher).isNull(); ldsWatcher = (ResourceWatcher) watcher; + serverExecutor = executor; ldsResource.set(resourceName); break; case "RDS": @@ -206,14 +239,14 @@ public void watchXdsResource(XdsResourceType resou } @Override - public void cancelXdsResourceWatch(XdsResourceType type, - String resourceName, - ResourceWatcher watcher) { + public synchronized void cancelXdsResourceWatch( + XdsResourceType type, String resourceName, ResourceWatcher watcher) { switch (type.typeName()) { case "LDS": assertThat(ldsWatcher).isNotNull(); ldsResource = null; ldsWatcher = null; + serverExecutor = null; break; case "RDS": rdsWatchers.remove(resourceName); @@ -223,27 +256,92 @@ public void cancelXdsResourceWatch(XdsResourceType } @Override - public void shutdown() { + public synchronized void shutdown() { shutdown = true; } @Override - public boolean isShutDown() { + public synchronized boolean isShutDown() { return shutdown; } - void deliverLdsUpdate(List filterChains, - FilterChain defaultFilterChain) { - ldsWatcher.onChanged(LdsUpdate.forTcpListener(Listener.create( - "listener", "0.0.0.0:1", ImmutableList.copyOf(filterChains), defaultFilterChain))); + public void awaitRds(Duration timeout) throws InterruptedException, TimeoutException { + if (!rdsCount.await(timeout.toMillis(), TimeUnit.MILLISECONDS)) { + throw new TimeoutException("Timeout " + timeout + " waiting for RDSs"); + } + } + + public void setExpectedRdsCount(int count) { + rdsCount = new CountDownLatch(count); + } + + private void execute(Runnable action) { + // This method ensures that all watcher updates: + // - Happen after the server started watching LDS. + // - Are executed within the sync context of the server. + // + // Note that this doesn't guarantee that any of the RDS watchers are created. + // Tests should use setExpectedRdsCount(int) and awaitRds() for that. + awaitLdsResource(DEFAULT_TIMEOUT); + serverExecutor.execute(action); + } + + private String awaitLdsResource(Duration timeout) { + if (ldsResource == null) { + throw new IllegalStateException("xDS resource update after watcher cancel"); + } + try { + return ldsResource.get(timeout.toMillis(), TimeUnit.MILLISECONDS); + } catch (ExecutionException | TimeoutException e) { + throw new RuntimeException("Can't resolve LDS resource name in " + timeout, e); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + } + + void deliverLdsUpdateWithApiListener(long httpMaxStreamDurationNano, + List virtualHosts) { + execute(() -> { + LdsUpdate update = LdsUpdate.forApiListener(HttpConnectionManager.forVirtualHosts( + httpMaxStreamDurationNano, virtualHosts, null)); + ldsWatcher.onResourceChanged(StatusOr.fromValue(update)); + }); } void deliverLdsUpdate(LdsUpdate ldsUpdate) { - ldsWatcher.onChanged(ldsUpdate); + execute(() -> ldsWatcher.onResourceChanged(StatusOr.fromValue(ldsUpdate))); + } + + void deliverLdsUpdate( + List filterChains, + @Nullable FilterChain defaultFilterChain) { + deliverLdsUpdate(LdsUpdate.forTcpListener(Listener.create("listener", "0.0.0.0:1", + ImmutableList.copyOf(filterChains), defaultFilterChain, Protocol.TCP))); + } + + void deliverLdsUpdate(FilterChain filterChain, @Nullable FilterChain defaultFilterChain) { + deliverLdsUpdate(ImmutableList.of(filterChain), defaultFilterChain); + } + + void deliverLdsResourceNotFound() { + String resourceName = awaitLdsResource(DEFAULT_TIMEOUT); + Status status = Status.NOT_FOUND.withDescription("Resource not found: " + resourceName); + execute(() -> ldsWatcher.onResourceChanged(StatusOr.fromStatus(status))); + } + + void deliverRdsUpdate(String resourceName, List virtualHosts) { + RdsUpdate update = new RdsUpdate(virtualHosts); + execute(() -> rdsWatchers.get(resourceName).onResourceChanged(StatusOr.fromValue(update))); + } + + void deliverRdsUpdate(String resourceName, VirtualHost virtualHost) { + deliverRdsUpdate(resourceName, ImmutableList.of(virtualHost)); } - void deliverRdsUpdate(String rdsName, List virtualHosts) { - rdsWatchers.get(rdsName).onChanged(new RdsUpdate(virtualHosts)); + void deliverRdsResourceNotFound(String resourceName) { + Status status = Status.NOT_FOUND.withDescription("Resource not found: " + resourceName); + execute(() -> rdsWatchers.get(resourceName).onResourceChanged(StatusOr.fromStatus(status))); } } } diff --git a/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java b/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java index 55b8812cd17..99e3911307a 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java @@ -18,6 +18,7 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.Truth.assertWithMessage; import static io.grpc.xds.XdsServerWrapper.ATTR_SERVER_ROUTING_CONFIG; import static io.grpc.xds.XdsServerWrapper.RETRY_DELAY_NANOS; import static org.junit.Assert.fail; @@ -31,11 +32,12 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.mockito.Mockito.withSettings; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.net.InetAddresses; import com.google.common.util.concurrent.SettableFuture; +import io.envoyproxy.envoy.config.core.v3.SocketAddress.Protocol; import io.grpc.Attributes; import io.grpc.InsecureChannelCredentials; import io.grpc.Metadata; @@ -47,17 +49,22 @@ import io.grpc.ServerInterceptor; import io.grpc.Status; import io.grpc.StatusException; +import io.grpc.StatusOr; import io.grpc.SynchronizationContext; import io.grpc.internal.FakeClock; import io.grpc.testing.TestMethodDescriptors; +import io.grpc.xds.EnvoyServerProtoData.CidrRange; import io.grpc.xds.EnvoyServerProtoData.FilterChain; +import io.grpc.xds.EnvoyServerProtoData.FilterChainMatch; +import io.grpc.xds.EnvoyServerProtoData.Listener; import io.grpc.xds.Filter.FilterConfig; import io.grpc.xds.Filter.NamedFilterConfig; -import io.grpc.xds.Filter.ServerInterceptorBuilder; import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; +import io.grpc.xds.StatefulFilter.Config; import io.grpc.xds.VirtualHost.Route; import io.grpc.xds.VirtualHost.Route.RouteMatch; import io.grpc.xds.VirtualHost.Route.RouteMatch.PathMatcher; +import io.grpc.xds.XdsListenerResource.LdsUpdate; import io.grpc.xds.XdsRouteConfigureResource.RdsUpdate; import io.grpc.xds.XdsServerBuilder.XdsServingStatusListener; import io.grpc.xds.XdsServerTestHelper.FakeXdsClient; @@ -72,11 +79,12 @@ import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; import io.grpc.xds.internal.security.SslContextProviderSupplier; import java.io.IOException; +import java.net.InetAddress; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; -import java.util.concurrent.CountDownLatch; +import java.util.Map; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; @@ -96,6 +104,14 @@ @RunWith(JUnit4.class) public class XdsServerWrapperTest { private static final int START_WAIT_AFTER_LISTENER_MILLIS = 100; + private static final String ROUTER_FILTER_INSTANCE_NAME = "envoy.router"; + private static final RouterFilter.Provider ROUTER_FILTER_PROVIDER = new RouterFilter.Provider(); + + // Readability: makes it simpler to distinguish resource parameters. + private static final ImmutableMap NO_FILTER_OVERRIDES = ImmutableMap.of(); + + private static final String STATEFUL_1 = "stateful_1"; + private static final String STATEFUL_2 = "stateful_2"; @Rule public final MockitoRule mocks = MockitoJUnit.rule(); @@ -120,6 +136,7 @@ public void setup() { when(mockBuilder.build()).thenReturn(mockServer); xdsServerWrapper = new XdsServerWrapper("0.0.0.0:1", mockBuilder, listener, selectorManager, new FakeXdsClientPoolFactory(xdsClient), + XdsServerTestHelper.RAW_BOOTSTRAP, filterRegistry, executor.getScheduledExecutorService()); } @@ -142,7 +159,8 @@ public void testBootstrap() throws Exception { XdsListenerResource listenerResource = XdsListenerResource.getInstance(); when(xdsClient.getBootstrapInfo()).thenReturn(b); xdsServerWrapper = new XdsServerWrapper("[::FFFF:129.144.52.38]:80", mockBuilder, listener, - selectorManager, new FakeXdsClientPoolFactory(xdsClient), filterRegistry); + selectorManager, new FakeXdsClientPoolFactory(xdsClient), + XdsServerTestHelper.RAW_BOOTSTRAP, filterRegistry); Executors.newSingleThreadExecutor().execute(new Runnable() { @Override public void run() { @@ -175,7 +193,8 @@ private void verifyBootstrapFail(Bootstrapper.BootstrapInfo b) throws Exception XdsClient xdsClient = mock(XdsClient.class); when(xdsClient.getBootstrapInfo()).thenReturn(b); xdsServerWrapper = new XdsServerWrapper("0.0.0.0:1", mockBuilder, listener, - selectorManager, new FakeXdsClientPoolFactory(xdsClient), filterRegistry); + selectorManager, new FakeXdsClientPoolFactory(xdsClient), + XdsServerTestHelper.RAW_BOOTSTRAP, filterRegistry); final SettableFuture start = SettableFuture.create(); Executors.newSingleThreadExecutor().execute(new Runnable() { @Override @@ -214,7 +233,8 @@ public void testBootstrap_templateWithXdstp() throws Exception { XdsListenerResource listenerResource = XdsListenerResource.getInstance(); when(xdsClient.getBootstrapInfo()).thenReturn(b); xdsServerWrapper = new XdsServerWrapper("[::FFFF:129.144.52.38]:80", mockBuilder, listener, - selectorManager, new FakeXdsClientPoolFactory(xdsClient), filterRegistry); + selectorManager, new FakeXdsClientPoolFactory(xdsClient), + XdsServerTestHelper.RAW_BOOTSTRAP, filterRegistry); Executors.newSingleThreadExecutor().execute(new Runnable() { @Override public void run() { @@ -254,7 +274,7 @@ public void run() { FilterChain f0 = createFilterChain("filter-chain-0", hcm_virtual); FilterChain f1 = createFilterChain("filter-chain-1", createRds("rds")); xdsClient.deliverLdsUpdate(Collections.singletonList(f0), f1); - xdsClient.rdsCount.await(5, TimeUnit.SECONDS); + xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT); xdsClient.deliverRdsUpdate("rds", Collections.singletonList(createVirtualHost("virtual-host-1"))); verify(listener, timeout(5000)).onServing(); @@ -263,7 +283,7 @@ public void run() { xdsServerWrapper.shutdown(); assertThat(xdsServerWrapper.isShutdown()).isTrue(); assertThat(xdsClient.ldsResource).isNull(); - assertThat(xdsClient.shutdown).isTrue(); + assertThat(xdsClient.isShutDown()).isTrue(); verify(mockServer).shutdown(); assertThat(f0.sslContextProviderSupplier().isShutdown()).isTrue(); assertThat(f1.sslContextProviderSupplier().isShutdown()).isTrue(); @@ -305,7 +325,7 @@ public void run() { verify(mockServer, never()).start(); assertThat(xdsServerWrapper.isShutdown()).isTrue(); assertThat(xdsClient.ldsResource).isNull(); - assertThat(xdsClient.shutdown).isTrue(); + assertThat(xdsClient.isShutDown()).isTrue(); verify(mockServer).shutdown(); assertThat(f0.sslContextProviderSupplier().isShutdown()).isTrue(); assertThat(f1.sslContextProviderSupplier().isShutdown()).isTrue(); @@ -326,7 +346,8 @@ public void run() { } }); String ldsResource = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); - xdsClient.ldsWatcher.onResourceDoesNotExist(ldsResource); + Status notFoundStatus = Status.NOT_FOUND.withDescription("Resource not found: " + ldsResource); + xdsClient.ldsWatcher.onResourceChanged(StatusOr.fromStatus(notFoundStatus)); verify(listener, timeout(5000)).onNotServing(any()); try { start.get(START_WAIT_AFTER_LISTENER_MILLIS, TimeUnit.MILLISECONDS); @@ -344,7 +365,7 @@ public void run() { xdsServerWrapper.shutdown(); assertThat(xdsServerWrapper.isShutdown()).isTrue(); assertThat(xdsClient.ldsResource).isNull(); - assertThat(xdsClient.shutdown).isTrue(); + assertThat(xdsClient.isShutDown()).isTrue(); verify(mockBuilder, times(1)).build(); verify(mockServer, times(1)).shutdown(); xdsServerWrapper.awaitTermination(1, TimeUnit.SECONDS); @@ -369,7 +390,7 @@ public void run() { FilterChain filterChain = createFilterChain("filter-chain-1", createRds("rds")); SslContextProviderSupplier sslSupplier = filterChain.sslContextProviderSupplier(); xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null); - xdsClient.rdsCount.await(5, TimeUnit.SECONDS); + xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT); xdsClient.deliverRdsUpdate("rds", Collections.singletonList(createVirtualHost("virtual-host-1"))); try { @@ -436,7 +457,7 @@ public void run() { xdsClient.ldsResource.get(5, TimeUnit.SECONDS); FilterChain filterChain = createFilterChain("filter-chain-1", createRds("rds")); xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null); - xdsClient.rdsCount.await(5, TimeUnit.SECONDS); + xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT); xdsClient.deliverRdsUpdate("rds", Collections.singletonList(createVirtualHost("virtual-host-1"))); try { @@ -515,7 +536,8 @@ public void run() { verify(mockServer).start(); // server shutdown after resourceDoesNotExist - xdsClient.ldsWatcher.onResourceDoesNotExist(ldsResource); + Status notFoundStatus = Status.NOT_FOUND.withDescription("Resource not found: " + ldsResource); + xdsClient.ldsWatcher.onResourceChanged(StatusOr.fromStatus(notFoundStatus)); verify(mockServer).shutdown(); // re-deliver lds resource @@ -526,6 +548,150 @@ public void run() { verify(mockServer).start(); } + @Test + public void onChanged_listenerIsNull() + throws ExecutionException, InterruptedException, TimeoutException { + xdsServerWrapper = new XdsServerWrapper("10.1.2.3:1", mockBuilder, listener, + selectorManager, new FakeXdsClientPoolFactory(xdsClient), + XdsServerTestHelper.RAW_BOOTSTRAP, + filterRegistry, executor.getScheduledExecutorService()); + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + String ldsResource = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + assertThat(ldsResource).isEqualTo("grpc/server?udpa.resource.listening_address=10.1.2.3:1"); + VirtualHost virtualHost = + VirtualHost.create( + "virtual-host", Collections.singletonList("auth"), new ArrayList(), + ImmutableMap.of()); + + xdsClient.deliverLdsUpdateWithApiListener(0L, Arrays.asList(virtualHost)); + + verify(listener, timeout(10000)).onNotServing(any()); + } + + @Test + public void onChanged_listenerAddressMissingPort() + throws ExecutionException, InterruptedException, TimeoutException { + xdsServerWrapper = new XdsServerWrapper("10.1.2.3:1", mockBuilder, listener, + selectorManager, new FakeXdsClientPoolFactory(xdsClient), + XdsServerTestHelper.RAW_BOOTSTRAP, + filterRegistry, executor.getScheduledExecutorService()); + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + String ldsResource = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + assertThat(ldsResource).isEqualTo("grpc/server?udpa.resource.listening_address=10.1.2.3:1"); + VirtualHost virtualHost = + VirtualHost.create( + "virtual-host", Collections.singletonList("auth"), new ArrayList(), + ImmutableMap.of()); + HttpConnectionManager httpConnectionManager = HttpConnectionManager.forVirtualHosts( + 0L, Collections.singletonList(virtualHost), new ArrayList()); + EnvoyServerProtoData.FilterChain filterChain = EnvoyServerProtoData.FilterChain.create( + "filter-chain-foo", createMatch(), httpConnectionManager, createTls(), + mock(TlsContextManager.class)); + LdsUpdate listenerUpdate = LdsUpdate.forTcpListener( + Listener.create("listener", "20.3.4.5:", + ImmutableList.copyOf(Collections.singletonList(filterChain)), null, Protocol.TCP)); + + xdsClient.deliverLdsUpdate(listenerUpdate); + + verify(listener, timeout(10000)).onNotServing(any()); + } + + @Test + public void onChanged_listenerAddressMismatch() + throws ExecutionException, InterruptedException, TimeoutException { + xdsServerWrapper = new XdsServerWrapper("10.1.2.3:1", mockBuilder, listener, + selectorManager, new FakeXdsClientPoolFactory(xdsClient), + XdsServerTestHelper.RAW_BOOTSTRAP, + filterRegistry, executor.getScheduledExecutorService()); + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + String ldsResource = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + assertThat(ldsResource).isEqualTo("grpc/server?udpa.resource.listening_address=10.1.2.3:1"); + VirtualHost virtualHost = + VirtualHost.create( + "virtual-host", Collections.singletonList("auth"), new ArrayList(), + ImmutableMap.of()); + HttpConnectionManager httpConnectionManager = HttpConnectionManager.forVirtualHosts( + 0L, Collections.singletonList(virtualHost), new ArrayList()); + EnvoyServerProtoData.FilterChain filterChain = EnvoyServerProtoData.FilterChain.create( + "filter-chain-foo", createMatch(), httpConnectionManager, createTls(), + mock(TlsContextManager.class)); + LdsUpdate listenerUpdate = LdsUpdate.forTcpListener( + Listener.create("listener", "20.3.4.5:1", + ImmutableList.copyOf(Collections.singletonList(filterChain)), null, Protocol.TCP)); + + xdsClient.deliverLdsUpdate(listenerUpdate); + + verify(listener, timeout(10000)).onNotServing(any()); + } + + @Test + public void onChanged_listenerAddressPortMismatch() + throws ExecutionException, InterruptedException, TimeoutException { + xdsServerWrapper = new XdsServerWrapper("10.1.2.3:1", mockBuilder, listener, + selectorManager, new FakeXdsClientPoolFactory(xdsClient), + XdsServerTestHelper.RAW_BOOTSTRAP, + filterRegistry, executor.getScheduledExecutorService()); + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + String ldsResource = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + assertThat(ldsResource).isEqualTo("grpc/server?udpa.resource.listening_address=10.1.2.3:1"); + VirtualHost virtualHost = + VirtualHost.create( + "virtual-host", Collections.singletonList("auth"), new ArrayList(), + ImmutableMap.of()); + HttpConnectionManager httpConnectionManager = HttpConnectionManager.forVirtualHosts( + 0L, Collections.singletonList(virtualHost), new ArrayList()); + EnvoyServerProtoData.FilterChain filterChain = EnvoyServerProtoData.FilterChain.create( + "filter-chain-foo", createMatch(), httpConnectionManager, createTls(), + mock(TlsContextManager.class)); + LdsUpdate listenerUpdate = LdsUpdate.forTcpListener( + Listener.create("listener", "10.1.2.3:2", + ImmutableList.copyOf(Collections.singletonList(filterChain)), null, Protocol.TCP)); + + xdsClient.deliverLdsUpdate(listenerUpdate); + + verify(listener, timeout(10000)).onNotServing(any()); + } + @Test public void discoverState_rds() throws Exception { final SettableFuture start = SettableFuture.create(); @@ -546,7 +712,7 @@ public void run() { 0L, Collections.singletonList(virtualHost), new ArrayList()); EnvoyServerProtoData.FilterChain f0 = createFilterChain("filter-chain-0", hcmVirtual); EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0")); - xdsClient.rdsCount = new CountDownLatch(3); + xdsClient.setExpectedRdsCount(3); xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), null); assertThat(start.isDone()).isFalse(); assertThat(selectorManager.getSelectorToUpdateSelector()).isNull(); @@ -558,7 +724,7 @@ public void run() { xdsClient.deliverLdsUpdate(Arrays.asList(f0, f2), f3); verify(mockServer, never()).start(); verify(listener, never()).onServing(); - xdsClient.rdsCount.await(5, TimeUnit.SECONDS); + xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT); xdsClient.deliverRdsUpdate("r1", Collections.singletonList(createVirtualHost("virtual-host-1"))); @@ -604,12 +770,11 @@ public void run() { EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0")); EnvoyServerProtoData.FilterChain f2 = createFilterChain("filter-chain-2", createRds("r0")); - xdsClient.rdsCount = new CountDownLatch(1); xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), f2); assertThat(start.isDone()).isFalse(); assertThat(selectorManager.getSelectorToUpdateSelector()).isNull(); - xdsClient.rdsCount.await(5, TimeUnit.SECONDS); + xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT); xdsClient.deliverRdsUpdate("r0", Collections.singletonList(createVirtualHost("virtual-host-0"))); start.get(5000, TimeUnit.MILLISECONDS); @@ -635,9 +800,9 @@ public void run() { EnvoyServerProtoData.FilterChain f3 = createFilterChain("filter-chain-3", createRds("r0")); EnvoyServerProtoData.FilterChain f4 = createFilterChain("filter-chain-4", createRds("r1")); EnvoyServerProtoData.FilterChain f5 = createFilterChain("filter-chain-4", createRds("r1")); - xdsClient.rdsCount = new CountDownLatch(1); + xdsClient.setExpectedRdsCount(1); xdsClient.deliverLdsUpdate(Arrays.asList(f5, f3), f4); - xdsClient.rdsCount.await(5, TimeUnit.SECONDS); + xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT); xdsClient.deliverRdsUpdate("r1", Collections.singletonList(createVirtualHost("virtual-host-1"))); xdsClient.deliverRdsUpdate("r0", @@ -690,8 +855,8 @@ public void run() { EnvoyServerProtoData.FilterChain f0 = createFilterChain("filter-chain-0", hcmVirtual); EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0")); xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), null); - xdsClient.rdsCount.await(); - xdsClient.rdsWatchers.get("r0").onError(Status.CANCELLED); + xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT); + xdsClient.rdsWatchers.get("r0").onResourceChanged(StatusOr.fromStatus(Status.CANCELLED)); start.get(5000, TimeUnit.MILLISECONDS); assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size()) .isEqualTo(2); @@ -711,13 +876,14 @@ public void run() { Collections.singletonList(createVirtualHost("virtual-host-1"))); assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); - xdsClient.rdsWatchers.get("r0").onError(Status.CANCELLED); + xdsClient.rdsWatchers.get("r0").onAmbientError(Status.CANCELLED); realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f1).get(); assertThat(realConfig.virtualHosts()).isEqualTo( Collections.singletonList(createVirtualHost("virtual-host-1"))); assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); - xdsClient.rdsWatchers.get("r0").onResourceDoesNotExist("r0"); + Status notFoundStatus = Status.NOT_FOUND.withDescription("Resource r0 does not exist"); + xdsClient.rdsWatchers.get("r0").onResourceChanged(StatusOr.fromStatus(notFoundStatus)); realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f1).get(); assertThat(realConfig.virtualHosts()).isEmpty(); assertThat(realConfig.interceptors()).isEmpty(); @@ -737,7 +903,9 @@ public void run() { } }); String ldsResource = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); - xdsClient.ldsWatcher.onResourceDoesNotExist(ldsResource); + Status notFoundStatus = Status.NOT_FOUND.withDescription( + "FakeXdsClient: Resource not found: " + ldsResource); + xdsClient.ldsWatcher.onResourceChanged(StatusOr.fromStatus(notFoundStatus)); verify(listener, timeout(5000)).onNotServing(any()); try { start.get(START_WAIT_AFTER_LISTENER_MILLIS, TimeUnit.MILLISECONDS); @@ -751,10 +919,10 @@ public void run() { FilterChain filterChain0 = createFilterChain("filter-chain-0", createRds("rds")); SslContextProviderSupplier sslSupplier0 = filterChain0.sslContextProviderSupplier(); xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain0), null); - xdsClient.ldsWatcher.onError(Status.INTERNAL); + ResourceWatcher saveRdsWatcher = xdsClient.rdsWatchers.get("rds"); + xdsClient.ldsWatcher.onResourceChanged(StatusOr.fromStatus(Status.INTERNAL)); assertThat(selectorManager.getSelectorToUpdateSelector()) .isSameInstanceAs(FilterChainSelector.NO_FILTER_CHAIN); - ResourceWatcher saveRdsWatcher = xdsClient.rdsWatchers.get("rds"); verify(mockBuilder, times(1)).build(); verify(listener, times(2)).onNotServing(any(StatusException.class)); assertThat(sslSupplier0.isShutdown()).isFalse(); @@ -790,7 +958,7 @@ public void run() { xdsClient.deliverRdsUpdate("rds", Collections.singletonList(createVirtualHost("virtual-host-2"))); assertThat(sslSupplier1.isShutdown()).isFalse(); - xdsClient.ldsWatcher.onError(Status.DEADLINE_EXCEEDED); + xdsClient.ldsWatcher.onAmbientError(Status.DEADLINE_EXCEEDED); verify(mockBuilder, times(1)).build(); verify(mockServer, times(2)).start(); verify(listener, times(2)).onNotServing(any(StatusException.class)); @@ -805,17 +973,18 @@ public void run() { assertThat(sslSupplier1.isShutdown()).isFalse(); // not serving after serving - xdsClient.ldsWatcher.onResourceDoesNotExist(ldsResource); + xdsClient.ldsWatcher.onResourceChanged(StatusOr.fromStatus(notFoundStatus)); assertThat(xdsClient.rdsWatchers).isEmpty(); - verify(mockServer, times(2)).shutdown(); + verify(mockServer, times(3)).shutdown(); // This is the 3rd shutdown in the test. when(mockServer.isShutdown()).thenReturn(true); assertThat(selectorManager.getSelectorToUpdateSelector()) .isSameInstanceAs(FilterChainSelector.NO_FILTER_CHAIN); verify(listener, times(3)).onNotServing(any(StatusException.class)); assertThat(sslSupplier1.isShutdown()).isTrue(); + assertThat(xdsClient.rdsWatchers.get("rds")).isNull(); // no op - saveRdsWatcher.onChanged( - new RdsUpdate(Collections.singletonList(createVirtualHost("virtual-host-1")))); + saveRdsWatcher.onResourceChanged(StatusOr.fromValue( + new RdsUpdate(Collections.singletonList(createVirtualHost("virtual-host-1"))))); verify(mockBuilder, times(1)).build(); verify(mockServer, times(2)).start(); verify(listener, times(1)).onServing(); @@ -844,8 +1013,8 @@ public void run() { assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); assertThat(executor.numPendingTasks()).isEqualTo(1); - xdsClient.ldsWatcher.onResourceDoesNotExist(ldsResource); - verify(mockServer, times(3)).shutdown(); + xdsClient.ldsWatcher.onResourceChanged(StatusOr.fromStatus(notFoundStatus)); + verify(mockServer, times(4)).shutdown(); verify(listener, times(4)).onNotServing(any(StatusException.class)); verify(listener, times(1)).onNotServing(any(IOException.class)); when(mockServer.isShutdown()).thenReturn(true); @@ -873,7 +1042,7 @@ public void run() { assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); xdsServerWrapper.shutdown(); - verify(mockServer, times(4)).shutdown(); + verify(mockServer, times(5)).shutdown(); assertThat(sslSupplier3.isShutdown()).isTrue(); when(mockServer.awaitTermination(anyLong(), any(TimeUnit.class))).thenReturn(true); assertThat(xdsServerWrapper.awaitTermination(5, TimeUnit.SECONDS)).isTrue(); @@ -957,9 +1126,11 @@ public void run() { new AtomicReference<>(routingConfig)).build()); when(serverCall.getAuthority()).thenReturn("not-match.google.com"); - Filter filter = mock(Filter.class); - when(filter.typeUrls()).thenReturn(new String[]{"filter-type-url"}); - filterRegistry.register(filter); + Filter.Provider filterProvider = mock(Filter.Provider.class); + when(filterProvider.typeUrls()).thenReturn(new String[]{"filter-type-url"}); + when(filterProvider.isServerFilter()).thenReturn(true); + filterRegistry.register(filterProvider); + ServerCallHandler next = mock(ServerCallHandler.class); interceptor.interceptCall(serverCall, new Metadata(), next); verify(next, never()).startCall(any(ServerCall.class), any(Metadata.class)); @@ -998,9 +1169,11 @@ public void run() { when(serverCall.getMethodDescriptor()).thenReturn(createMethod("NotMatchMethod")); when(serverCall.getAuthority()).thenReturn("foo.google.com"); - Filter filter = mock(Filter.class); - when(filter.typeUrls()).thenReturn(new String[]{"filter-type-url"}); - filterRegistry.register(filter); + Filter.Provider filterProvider = mock(Filter.Provider.class); + when(filterProvider.typeUrls()).thenReturn(new String[]{"filter-type-url"}); + when(filterProvider.isServerFilter()).thenReturn(true); + filterRegistry.register(filterProvider); + ServerCallHandler next = mock(ServerCallHandler.class); interceptor.interceptCall(serverCall, new Metadata(), next); verify(next, never()).startCall(any(ServerCall.class), any(Metadata.class)); @@ -1035,7 +1208,8 @@ public void run() { "/FooService/barMethod", "foo.google.com", Route.RouteAction.forCluster( - "cluster", Collections.emptyList(), null, null)); + "cluster", Collections.emptyList(), null, null, + false)); ServerCall serverCall = mock(ServerCall.class); when(serverCall.getAttributes()).thenReturn( Attributes.newBuilder() @@ -1043,9 +1217,11 @@ public void run() { when(serverCall.getMethodDescriptor()).thenReturn(createMethod("FooService/barMethod")); when(serverCall.getAuthority()).thenReturn("foo.google.com"); - Filter filter = mock(Filter.class); - when(filter.typeUrls()).thenReturn(new String[]{"filter-type-url"}); - filterRegistry.register(filter); + Filter.Provider filterProvider = mock(Filter.Provider.class); + when(filterProvider.typeUrls()).thenReturn(new String[]{"filter-type-url"}); + when(filterProvider.isServerFilter()).thenReturn(true); + filterRegistry.register(filterProvider); + ServerCallHandler next = mock(ServerCallHandler.class); interceptor.interceptCall(serverCall, new Metadata(), next); verify(next, never()).startCall(any(ServerCall.class), any(Metadata.class)); @@ -1112,10 +1288,14 @@ public void run() { RouteMatch.create( PathMatcher.fromPath("/FooService/barMethod", true), Collections.emptyList(), null); - Filter filter = mock(Filter.class, withSettings() - .extraInterfaces(ServerInterceptorBuilder.class)); - when(filter.typeUrls()).thenReturn(new String[]{"filter-type-url"}); - filterRegistry.register(filter); + + Filter filter = mock(Filter.class); + Filter.Provider filterProvider = mock(Filter.Provider.class); + when(filterProvider.typeUrls()).thenReturn(new String[]{"filter-type-url"}); + when(filterProvider.isServerFilter()).thenReturn(true); + when(filterProvider.newInstance(any(String.class))).thenReturn(filter); + filterRegistry.register(filterProvider); + FilterConfig f0 = mock(FilterConfig.class); FilterConfig f0Override = mock(FilterConfig.class); when(f0.typeUrl()).thenReturn("filter-type-url"); @@ -1136,10 +1316,8 @@ public ServerCall.Listener interceptCall(ServerCallof()); VirtualHost virtualHost = VirtualHost.create( @@ -1184,10 +1362,13 @@ public void run() { }); xdsClient.ldsResource.get(5, TimeUnit.SECONDS); - Filter filter = mock(Filter.class, withSettings() - .extraInterfaces(ServerInterceptorBuilder.class)); - when(filter.typeUrls()).thenReturn(new String[]{"filter-type-url"}); - filterRegistry.register(filter); + Filter filter = mock(Filter.class); + Filter.Provider filterProvider = mock(Filter.Provider.class); + when(filterProvider.typeUrls()).thenReturn(new String[]{"filter-type-url"}); + when(filterProvider.isServerFilter()).thenReturn(true); + when(filterProvider.newInstance(any(String.class))).thenReturn(filter); + filterRegistry.register(filterProvider); + FilterConfig f0 = mock(FilterConfig.class); FilterConfig f0Override = mock(FilterConfig.class); when(f0.typeUrl()).thenReturn("filter-type-url"); @@ -1208,10 +1389,8 @@ public ServerCall.Listener interceptCall(ServerCall ServerCall.Listener interceptCall(ServerCall ServerCall.Listener interceptCall(ServerCall serverStart = filterStateTestStartServer(filterRegistry); + + VirtualHost vhost = filterStateTestVhost(); + + // LDS 1. + FilterChain lds1FilterChain = createFilterChain("chain_0", + createHcm(vhost, filterStateTestConfigs(STATEFUL_1, STATEFUL_2))); + xdsClient.deliverLdsUpdate(lds1FilterChain, null); + verifyServerStarted(serverStart); + ImmutableList lds1Snapshot = statefulFilterProvider.getAllInstances(); + // Verify that StatefulFilter with different filter names result in different Filter instances. + assertWithMessage("LDS 1: expected to create filter instances").that(lds1Snapshot).hasSize(2); + // Naming: ldsFilter + StatefulFilter lds1Filter1 = lds1Snapshot.get(0); + StatefulFilter lds1Filter2 = lds1Snapshot.get(1); + assertThat(lds1Filter1).isNotSameInstanceAs(lds1Filter2); + // Redundant check just in case StatefulFilter synchronization is broken. + assertThat(lds1Filter1.idx).isEqualTo(0); + assertThat(lds1Filter2.idx).isEqualTo(1); + + // LDS 2: filter configs with the same names. + FilterChain lds2FilterChain = createFilterChain("chain_0", + createHcm(vhost, filterStateTestConfigs(STATEFUL_1, STATEFUL_2))); + xdsClient.deliverLdsUpdate(lds2FilterChain, null); + ImmutableList lds2Snapshot = statefulFilterProvider.getAllInstances(); + // Filter names hasn't changed, so expecting no new StatefulFilter instances. + assertWithMessage("LDS 2: Expected Filter instances to be reused across LDS updates") + .that(lds2Snapshot).isEqualTo(lds1Snapshot); + + // LDS 3: Filter "STATEFUL_2" removed. + FilterChain lds3FilterChain = createFilterChain("chain_0", + createHcm(vhost, filterStateTestConfigs(STATEFUL_1))); + xdsClient.deliverLdsUpdate(lds3FilterChain, null); + ImmutableList lds3Snapshot = statefulFilterProvider.getAllInstances(); + // Again, no new StatefulFilter instances should be created. + assertWithMessage("LDS 3: Expected Filter instances to be reused across LDS updates") + .that(lds3Snapshot).isEqualTo(lds1Snapshot); + // Verify the shutdown state. + assertThat(lds1Filter1.isShutdown()).isFalse(); + assertWithMessage("LDS 3: Expected %s to be shut down", lds1Filter2) + .that(lds1Filter2.isShutdown()).isTrue(); + + // LDS 4: Filter "STATEFUL_2" added back. + FilterChain lds4FilterChain = createFilterChain("chain_0", + createHcm(vhost, filterStateTestConfigs(STATEFUL_1, STATEFUL_2))); + xdsClient.deliverLdsUpdate(lds4FilterChain, null); + ImmutableList lds4Snapshot = statefulFilterProvider.getAllInstances(); + // Filter "STATEFUL_2" should be treated as any other new filter name in an LDS update: + // a new instance should be created. + assertWithMessage("LDS 4: Expected a new filter instance for %s", STATEFUL_2) + .that(lds4Snapshot).hasSize(3); + StatefulFilter lds4Filter2 = lds4Snapshot.get(2); + assertThat(lds4Filter2.idx).isEqualTo(2); + assertThat(lds4Filter2).isNotSameInstanceAs(lds1Filter2); + assertThat(lds4Snapshot).containsAtLeastElementsIn(lds1Snapshot); + // Verify the shutdown state. + assertThat(lds1Filter1.isShutdown()).isFalse(); + assertThat(lds1Filter2.isShutdown()).isTrue(); + assertThat(lds4Filter2.isShutdown()).isFalse(); + } + + @Test + public void filterState_survivesRds() throws Exception { + StatefulFilter.Provider statefulFilterProvider = new StatefulFilter.Provider(); + FilterRegistry filterRegistry = filterStateTestFilterRegistry(statefulFilterProvider); + SettableFuture serverStart = filterStateTestStartServer(filterRegistry); + + String rdsName = "rds.example.com"; + + // LDS 1. + FilterChain fc1 = createFilterChain("fc1", + createHcmForRds(rdsName, filterStateTestConfigs(STATEFUL_1, STATEFUL_2))); + xdsClient.deliverLdsUpdate(fc1, null); + xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT); + verify(listener, never()).onServing(); + // Server didn't start, but filter instances should have already been created. + ImmutableList lds1Snapshot = statefulFilterProvider.getAllInstances(); + assertWithMessage("LDS 1: expected to create filter instances").that(lds1Snapshot).hasSize(2); + // Naming: ldsFilter + StatefulFilter lds1Filter1 = lds1Snapshot.get(0); + StatefulFilter lds1Filter2 = lds1Snapshot.get(1); + assertThat(lds1Filter1).isNotSameInstanceAs(lds1Filter2); + + // RDS 1. + VirtualHost vhost1 = filterStateTestVhost(); + xdsClient.deliverRdsUpdate(rdsName, vhost1); + verifyServerStarted(serverStart); + assertThat(getSelectorRoutingConfigs()).hasSize(1); + assertThat(getSelectorVhosts(fc1)).containsExactly(vhost1); + // Initial RDS update should not generate Filter instances. + ImmutableList rds1Snapshot = statefulFilterProvider.getAllInstances(); + assertWithMessage("RDS 1: Expected Filter instances to be reused across RDS route updates") + .that(rds1Snapshot).isEqualTo(lds1Snapshot); + + // RDS 2: exactly the same as RDS 1. + xdsClient.deliverRdsUpdate(rdsName, vhost1); + assertThat(getSelectorRoutingConfigs()).hasSize(1); + assertThat(getSelectorVhosts(fc1)).containsExactly(vhost1); + ImmutableList rds2Snapshot = statefulFilterProvider.getAllInstances(); + // Neither should any subsequent RDS updates. + assertWithMessage("RDS 2: Expected Filter instances to be reused across RDS route updates") + .that(rds2Snapshot).isEqualTo(lds1Snapshot); + + // RDS 3: Contains a per-route override for STATEFUL_1. + VirtualHost vhost3 = filterStateTestVhost(vhost1.name(), ImmutableMap.of( + STATEFUL_1, new Config("RDS3") + )); + xdsClient.deliverRdsUpdate(rdsName, vhost3); + assertThat(getSelectorRoutingConfigs()).hasSize(1); + assertThat(getSelectorVhosts(fc1)).containsExactly(vhost3); + ImmutableList rds3Snapshot = statefulFilterProvider.getAllInstances(); + // As with any other Route update, typed_per_filter_config overrides should not result in + // creating new filter instances. + assertWithMessage("RDS 3: Expected Filter instances to be reused on per-route filter overrides") + .that(rds3Snapshot).isEqualTo(lds1Snapshot); + } + + @Test + public void filterState_uniquePerFilterChain() { + StatefulFilter.Provider statefulFilterProvider = new StatefulFilter.Provider(); + FilterRegistry filterRegistry = filterStateTestFilterRegistry(statefulFilterProvider); + SettableFuture serverStart = filterStateTestStartServer(filterRegistry); + + // Prepare multiple filter chains matchers for testing. + FilterChainMatch matcherA = createMatchSrcIp("3fff:a::/32"); + FilterChainMatch matcherB = createMatchSrcIp("3fff:b::/32"); + + // Vhosts won't change too. + VirtualHost vhostA = filterStateTestVhost("stateful_vhost_a"); + VirtualHost vhostB = filterStateTestVhost("stateful_vhost_b"); + + // LDS 1. + FilterChain lds1ChainA = createFilterChain("chain_a", + createHcm(vhostA, filterStateTestConfigs(STATEFUL_1, STATEFUL_2)), + matcherA); + FilterChain lds1ChainB = createFilterChain("chain_b", + createHcm(vhostB, filterStateTestConfigs(STATEFUL_2)), + matcherB); + + xdsClient.deliverLdsUpdate(ImmutableList.of(lds1ChainA, lds1ChainB), null); + verifyServerStarted(serverStart); + ImmutableList lds1Snapshot = statefulFilterProvider.getAllInstances(); + // Verify that filter with name STATEFUL_2 produced separate instances unique per filter chain. + assertWithMessage("LDS 1: expected to create filter instances").that(lds1Snapshot).hasSize(3); + // Naming: ldsChainFilter + StatefulFilter lds1ChainAFilter1 = lds1Snapshot.get(0); + StatefulFilter lds1ChainAFilter2 = lds1Snapshot.get(1); + StatefulFilter lds1ChainBFilter2 = lds1Snapshot.get(2); + assertThat(lds1ChainAFilter2).isNotSameInstanceAs(lds1ChainBFilter2); + + // LDS 2: In chain B filter with name STATEFUL_1 is replaced STATEFUL_2. + FilterChain lds2ChainA = createFilterChain("chain_a", + createHcm(vhostA, filterStateTestConfigs(STATEFUL_1, STATEFUL_2)), + matcherA); + FilterChain lds2ChainB = createFilterChain("chain_b", + createHcm(vhostB, filterStateTestConfigs(STATEFUL_1)), + matcherB); + + xdsClient.deliverLdsUpdate(ImmutableList.of(lds2ChainA, lds2ChainB), null); + ImmutableList lds2Snapshot = statefulFilterProvider.getAllInstances(); + assertWithMessage("LDS 2: expected a distinct instance of filter %s for Chain B", STATEFUL_1) + .that(lds2Snapshot).hasSize(4); + StatefulFilter lds2ChainBFilter1 = lds2Snapshot.get(3); + assertThat(lds2ChainBFilter1).isNotSameInstanceAs(lds1ChainAFilter1); + // Confirm correct STATEFUL_2 has been shut down. + assertThat(lds1ChainBFilter2.isShutdown()).isTrue(); + assertThat(lds1ChainAFilter2.isShutdown()).isFalse(); + + // LDS 3: Add default chain + // Default filter chain is an exception from the uniqueness rule, and we need to make sure + // that this is accounted for when we're tracking active filters per unique FilterChain. + FilterChain lds3ChainDefault = createFilterChain("chain_default", + createHcm(vhostA, filterStateTestConfigs(STATEFUL_1, STATEFUL_2)), + matcherA); + xdsClient.deliverLdsUpdate(ImmutableList.of(lds2ChainA, lds2ChainB), lds3ChainDefault); + ImmutableList lds3Snapshot = statefulFilterProvider.getAllInstances(); + assertWithMessage("LDS 3: Expected two new distinct filter instances for default chain") + .that(lds3Snapshot).hasSize(6); + StatefulFilter lds3ChainDefaultFilter1 = lds3Snapshot.get(4); + StatefulFilter lds3ChainDefaultFilter2 = lds3Snapshot.get(5); + // STATEFUL_1 in default chain not the same STATEFUL_1 in chain A or B + assertThat(lds3ChainDefaultFilter1).isNotSameInstanceAs(lds1ChainAFilter1); + assertThat(lds3ChainDefaultFilter1).isNotSameInstanceAs(lds2ChainBFilter1); + // STATEFUL_2 in default chain not the same STATEFUL_1 in chain A + assertThat(lds3ChainDefaultFilter2).isNotSameInstanceAs(lds1ChainAFilter2); + } + + /** + * Verifies a special case where an existing filter is has a different typeUrl in a subsequent + * LDS update. + * + *

Expectations: + * 1. The old filter instance must be shutdown. + * 2. A new filter instance must be created for the new filter with different typeUrl. + */ + @Test + public void filterState_specialCase_sameNameDifferentTypeUrl() { + // Setup the server with filter containing StatefulFilter.Provider for two distict type URLs. + StatefulFilter.Provider statefulFilterProvider = new StatefulFilter.Provider(); + String altTypeUrl = "type.googleapis.com/grpc.test.AltStatefulFilter"; + StatefulFilter.Provider altStatefulFilterProvider = new StatefulFilter.Provider(altTypeUrl); + FilterRegistry filterRegistry = FilterRegistry.newRegistry() + .register(statefulFilterProvider, altStatefulFilterProvider, ROUTER_FILTER_PROVIDER); + SettableFuture serverStart = filterStateTestStartServer(filterRegistry); + + // Test a normal chain and the default chain, as it's handled separately. + VirtualHost vhost = filterStateTestVhost(); + + // LDS 1. + ImmutableList lds1Confgs = filterStateTestConfigs(STATEFUL_1, STATEFUL_2); + FilterChain lds1ChainA = createFilterChain("chain_a", createHcm(vhost, lds1Confgs)); + FilterChain lds1ChainDefault = createFilterChain("chain_default", createHcm(vhost, lds1Confgs)); + xdsClient.deliverLdsUpdate(lds1ChainA, lds1ChainDefault); + verifyServerStarted(serverStart); + ImmutableList lds1Snapshot = statefulFilterProvider.getAllInstances(); + assertWithMessage("LDS 1: expected to create filter instances").that(lds1Snapshot).hasSize(4); + // Naming: ldsChainFilter + StatefulFilter lds1ChainAFilter1 = lds1Snapshot.get(0); + StatefulFilter lds1ChainAFilter2 = lds1Snapshot.get(1); + StatefulFilter lds1ChainDefaultFilter1 = lds1Snapshot.get(2); + StatefulFilter lds1ChainDefaultFilter2 = lds1Snapshot.get(3); + + // LDS 2: Filter STATEFUL_2 present, but with a different typeUrl: altTypeUrl. + ImmutableList lds2Confgs = ImmutableList.of( + new NamedFilterConfig(STATEFUL_1, new StatefulFilter.Config(STATEFUL_1)), + new NamedFilterConfig(STATEFUL_2, new StatefulFilter.Config(STATEFUL_2, altTypeUrl)), + new NamedFilterConfig(ROUTER_FILTER_INSTANCE_NAME, RouterFilter.ROUTER_CONFIG) + ); + FilterChain lds2ChainA = createFilterChain("chain_a", createHcm(vhost, lds2Confgs)); + FilterChain lds2ChainDefault = createFilterChain("chain_default", createHcm(vhost, lds2Confgs)); + xdsClient.deliverLdsUpdate(lds2ChainA, lds2ChainDefault); + ImmutableList lds2Snapshot = statefulFilterProvider.getAllInstances(); + ImmutableList lds2SnapshotAlt = altStatefulFilterProvider.getAllInstances(); + // Filter "STATEFUL_2" has different typeUrl, and should be treated as a new filter. + // No changes in the snapshot of normal stateful filters. + assertThat(lds2Snapshot).isEqualTo(lds1Snapshot); + // Two new filter instances is created by altStatefulFilterProvider for chainA and chainDefault. + assertWithMessage("LDS 2: expected new filter instances for type %s", altTypeUrl) + .that(lds2SnapshotAlt).hasSize(2); + StatefulFilter lds2ChainAFilter2Alt = lds2SnapshotAlt.get(0); + StatefulFilter lds2ChainADefault2Alt = lds2SnapshotAlt.get(1); + // Confirm two new distict instances of STATEFUL_2 were created. + assertThat(lds2ChainAFilter2Alt).isNotSameInstanceAs(lds1ChainAFilter2); + assertThat(lds2ChainADefault2Alt).isNotSameInstanceAs(lds1ChainDefaultFilter2); + assertThat(lds2ChainAFilter2Alt).isNotSameInstanceAs(lds2ChainADefault2Alt); + // Verify the instance of STATEFUL_2 of the old type are shutdown. + assertThat(lds1ChainAFilter2.isShutdown()).isTrue(); + assertThat(lds1ChainDefaultFilter2.isShutdown()).isTrue(); + // Verify the new instances of STATEFUL_2 and the old instances of STATEFUL_1 are running. + assertThat(lds2ChainAFilter2Alt.isShutdown()).isFalse(); + assertThat(lds2ChainADefault2Alt.isShutdown()).isFalse(); + assertThat(lds1ChainAFilter1.isShutdown()).isFalse(); + assertThat(lds1ChainDefaultFilter1.isShutdown()).isFalse(); + } + + /** + * Verifies that all filter instances are shutdown (closed) on LDS resource not found. + */ + @Test + public void filterState_shutdown_onLdsNotFound() { + StatefulFilter.Provider statefulFilterProvider = new StatefulFilter.Provider(); + FilterRegistry filterRegistry = filterStateTestFilterRegistry(statefulFilterProvider); + SettableFuture serverStart = filterStateTestStartServer(filterRegistry); + + // Test a normal chain and the default chain, as it's handled separately. + VirtualHost vhost = filterStateTestVhost(); + FilterChain chainA = createFilterChain("chain_a", + createHcm(vhost, filterStateTestConfigs(STATEFUL_1))); + FilterChain chainDefault = createFilterChain("chain_default", + createHcm(vhost, filterStateTestConfigs(STATEFUL_2))); + + // LDS 1. + xdsClient.deliverLdsUpdate(chainA, chainDefault); + verifyServerStarted(serverStart); + ImmutableList lds1Snapshot = statefulFilterProvider.getAllInstances(); + assertWithMessage("LDS 1: expected to create filter instances").that(lds1Snapshot).hasSize(2); + // Naming: ldsChainFilter + StatefulFilter lds1ChainAFilter1 = lds1Snapshot.get(0); + StatefulFilter lds1ChainDefaultFilter2 = lds1Snapshot.get(1); + + // LDS 2: resource not found. + xdsClient.deliverLdsResourceNotFound(); + // Verify shutdown. + assertThat(lds1ChainAFilter1.isShutdown()).isTrue(); + assertThat(lds1ChainDefaultFilter2.isShutdown()).isTrue(); + } + + /** + * Verifies that all filter instances of a filter chain are shutdown when said chain is removed. + */ + @Test + public void filterState_shutdown_onChainRemoved() { + StatefulFilter.Provider statefulFilterProvider = new StatefulFilter.Provider(); + FilterRegistry filterRegistry = filterStateTestFilterRegistry(statefulFilterProvider); + SettableFuture serverStart = filterStateTestStartServer(filterRegistry); + + ImmutableList configs = filterStateTestConfigs(STATEFUL_1, STATEFUL_2); + FilterChain chainA = createFilterChain("chain_a", + createHcm(filterStateTestVhost("stateful_vhost_a"), configs), + createMatchSrcIp("3fff:a::/32")); + FilterChain chainB = createFilterChain("chain_b", + createHcm(filterStateTestVhost("stateful_vhost_b"), configs), + createMatchSrcIp("3fff:b::/32")); + FilterChain chainDefault = createFilterChain("chain_default", + createHcm(filterStateTestVhost("stateful_vhost_default"), configs), + createMatchSrcIp("3fff:defa::/32")); + + // LDS 1. + xdsClient.deliverLdsUpdate(ImmutableList.of(chainA, chainB), chainDefault); + verifyServerStarted(serverStart); + ImmutableList lds1Snapshot = statefulFilterProvider.getAllInstances(); + assertWithMessage("LDS 1: expected to create filter instances").that(lds1Snapshot).hasSize(6); + StatefulFilter chainAFilter1 = lds1Snapshot.get(0); + StatefulFilter chainAFilter2 = lds1Snapshot.get(1); + StatefulFilter chainBFilter1 = lds1Snapshot.get(2); + StatefulFilter chainBFilter2 = lds1Snapshot.get(3); + StatefulFilter chainDefaultFilter1 = lds1Snapshot.get(4); + StatefulFilter chainDefaultFilter2 = lds1Snapshot.get(5); + + // LDS 2: ChainB and ChainDefault are gone. + xdsClient.deliverLdsUpdate(chainA, null); + assertThat(statefulFilterProvider.getAllInstances()).isEqualTo(lds1Snapshot); + // ChainA filters not shutdown (just in case). + assertThat(chainAFilter1.isShutdown()).isFalse(); + assertThat(chainAFilter2.isShutdown()).isFalse(); + // ChainB and ChainDefault filters shutdown. + assertWithMessage("chainBFilter1").that(chainBFilter1.isShutdown()).isTrue(); + assertWithMessage("chainBFilter2").that(chainBFilter2.isShutdown()).isTrue(); + assertWithMessage("chainDefaultFilter1").that(chainDefaultFilter1.isShutdown()).isTrue(); + assertWithMessage("chainDefaultFilter2").that(chainDefaultFilter2.isShutdown()).isTrue(); + } + + /** + * Verifies that all filter instances are shutdown (closed) on LDS ResourceWatcher shutdown. + */ + @Test + public void filterState_shutdown_onServerShutdown() { + StatefulFilter.Provider statefulFilterProvider = new StatefulFilter.Provider(); + FilterRegistry filterRegistry = filterStateTestFilterRegistry(statefulFilterProvider); + SettableFuture serverStart = filterStateTestStartServer(filterRegistry); + + // Test a normal chain and the default chain, as it's handled separately. + VirtualHost vhost = filterStateTestVhost(); + FilterChain chainA = createFilterChain("chain_a", + createHcm(vhost, filterStateTestConfigs(STATEFUL_1))); + FilterChain chainDefault = createFilterChain("chain_default", + createHcm(vhost, filterStateTestConfigs(STATEFUL_2))); + + // LDS 1. + xdsClient.deliverLdsUpdate(chainA, chainDefault); + verifyServerStarted(serverStart); + ImmutableList lds1Snapshot = statefulFilterProvider.getAllInstances(); + assertWithMessage("LDS 1: expected to create filter instances").that(lds1Snapshot).hasSize(2); + // Naming: ldsChainFilter + StatefulFilter lds1ChainAFilter1 = lds1Snapshot.get(0); + StatefulFilter lds1ChainDefaultFilter2 = lds1Snapshot.get(1); + + // Shutdown. + xdsServerWrapper.shutdown(); + assertThat(xdsServerWrapper.isShutdown()).isTrue(); + assertThat(xdsClient.isShutDown()).isTrue(); + // Verify shutdown. + assertThat(lds1ChainAFilter1.isShutdown()).isTrue(); + assertThat(lds1ChainDefaultFilter2.isShutdown()).isTrue(); + } + + /** + * Verifies that filter instances are NOT shutdown on RDS_RESOURCE_NAME not found. + */ + @Test + public void filterState_shutdown_noShutdownOnRdsNotFound() throws Exception { + StatefulFilter.Provider statefulFilterProvider = new StatefulFilter.Provider(); + FilterRegistry filterRegistry = filterStateTestFilterRegistry(statefulFilterProvider); + SettableFuture serverStart = filterStateTestStartServer(filterRegistry); + + String rdsName = "rds.example.com"; + // Test a normal chain and the default chain, as it's handled separately. + FilterChain chainA = createFilterChain("chain_a", + createHcmForRds(rdsName, filterStateTestConfigs(STATEFUL_1))); + FilterChain chainDefault = createFilterChain("chain_default", + createHcmForRds(rdsName, filterStateTestConfigs(STATEFUL_2))); + + xdsClient.deliverLdsUpdate(chainA, chainDefault); + xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT); + verify(listener, never()).onServing(); + // Server didn't start, but filter instances should have already been created. + ImmutableList lds1Snapshot = statefulFilterProvider.getAllInstances(); + assertWithMessage("LDS 1: expected to create filter instances").that(lds1Snapshot).hasSize(2); + // Naming: ldsChainFilter + StatefulFilter lds1ChainAFilter1 = lds1Snapshot.get(0); + StatefulFilter lds1ChainDefaultFilter2 = lds1Snapshot.get(1); + + // RDS 1: Standard vhost with a route. + xdsClient.deliverRdsUpdate(rdsName, filterStateTestVhost()); + verifyServerStarted(serverStart); + assertThat(statefulFilterProvider.getAllInstances()).isEqualTo(lds1Snapshot); + + // RDS 2: RDS_RESOURCE_NAME not found. + xdsClient.deliverRdsResourceNotFound(rdsName); + assertThat(lds1ChainAFilter1.isShutdown()).isFalse(); + assertThat(lds1ChainDefaultFilter2.isShutdown()).isFalse(); + } + + private FilterRegistry filterStateTestFilterRegistry( + StatefulFilter.Provider statefulFilterProvider) { + return FilterRegistry.newRegistry().register(statefulFilterProvider, ROUTER_FILTER_PROVIDER); + } + + private SettableFuture filterStateTestStartServer(FilterRegistry filterRegistry) { + xdsServerWrapper = new XdsServerWrapper("0.0.0.0:1", mockBuilder, listener, + selectorManager, new FakeXdsClientPoolFactory(xdsClient), + XdsServerTestHelper.RAW_BOOTSTRAP, filterRegistry); + SettableFuture serverStart = SettableFuture.create(); + scheduleServerStart(xdsServerWrapper, serverStart); + return serverStart; + } + + private static ImmutableList filterStateTestConfigs(String... names) { + ImmutableList.Builder result = ImmutableList.builder(); + for (String name : names) { + result.add(new NamedFilterConfig(name, new StatefulFilter.Config(name))); + } + result.add(new NamedFilterConfig(ROUTER_FILTER_INSTANCE_NAME, RouterFilter.ROUTER_CONFIG)); + return result.build(); + } + + private static Route filterStateTestRoute(ImmutableMap perRouteOverrides) { + // Standard basic route for filterState tests. + return Route.forAction( + RouteMatch.withPathExactOnly("/grpc.test.HelloService/SayHello"), null, perRouteOverrides); + } + + private static VirtualHost filterStateTestVhost() { + return filterStateTestVhost("stateful-vhost", NO_FILTER_OVERRIDES); + } + + private static VirtualHost filterStateTestVhost(String name) { + return filterStateTestVhost(name, NO_FILTER_OVERRIDES); + } + + private static VirtualHost filterStateTestVhost( + String name, ImmutableMap perRouteOverrides) { + return VirtualHost.create( + name, + ImmutableList.of("stateful.test.example.com"), + ImmutableList.of(filterStateTestRoute(perRouteOverrides)), + NO_FILTER_OVERRIDES); + } + + // End filter state tests. + + private void verifyServerStarted(SettableFuture serverStart) { + try { + serverStart.get(5, TimeUnit.SECONDS); + } catch (InterruptedException | ExecutionException | TimeoutException e) { + throw new AssertionError("serverStart future failed to resolve within the timeout", e); + } + verify(listener).onServing(); + try { + verify(mockServer).start(); + } catch (IOException e) { + throw new AssertionError("mockServer.start() shouldn't throw", e); + } + } + + private Map> getSelectorRoutingConfigs() { + return selectorManager.getSelectorToUpdateSelector().getRoutingConfigs(); + } + + private ServerRoutingConfig getSelectorRoutingConfig(FilterChain fc) { + return getSelectorRoutingConfigs().get(fc).get(); + } + + private ImmutableList getSelectorVhosts(FilterChain fc) { + return getSelectorRoutingConfig(fc).virtualHosts(); + } + + public static void scheduleServerStart( + XdsServerWrapper xdsServerWrapper, SettableFuture serverStart) { + Executors.newSingleThreadExecutor().execute(() -> { + try { + serverStart.set(xdsServerWrapper.start()); + } catch (Exception e) { + serverStart.setException(e); + } + }); + } + private static FilterChain createFilterChain(String name, HttpConnectionManager hcm) { - return EnvoyServerProtoData.FilterChain.create(name, createMatch(), - hcm, createTls(), mock(TlsContextManager.class)); + return createFilterChain(name, hcm, createMatch()); + } + + private static FilterChain createFilterChain( + String name, HttpConnectionManager hcm, FilterChainMatch filterChainMatch) { + TlsContextManager tlsContextManager = mock(TlsContextManager.class); + return FilterChain.create(name, filterChainMatch, hcm, createTls(), tlsContextManager); } private static VirtualHost createVirtualHost(String name) { @@ -1273,17 +1952,27 @@ private static VirtualHost createVirtualHost(String name) { ImmutableMap.of()); } - private static HttpConnectionManager createRds(String name) { - return createRds(name, null); + private static HttpConnectionManager createHcm( + VirtualHost vhost, List filterConfigs) { + return HttpConnectionManager.forVirtualHosts(0L, ImmutableList.of(vhost), filterConfigs); + } + + private static HttpConnectionManager createHcmForRds( + String name, List filterConfigs) { + return HttpConnectionManager.forRdsName(0L, name, filterConfigs); } - private static HttpConnectionManager createRds(String name, FilterConfig filterConfig) { - return HttpConnectionManager.forRdsName(0L, name, - Arrays.asList(new NamedFilterConfig("named-config-" + name, filterConfig))); + private static HttpConnectionManager createRds(String name) { + NamedFilterConfig config = + new NamedFilterConfig(ROUTER_FILTER_INSTANCE_NAME, RouterFilter.ROUTER_CONFIG); + return createHcmForRds(name, ImmutableList.of(config)); } - private static EnvoyServerProtoData.FilterChainMatch createMatch() { - return EnvoyServerProtoData.FilterChainMatch.create( + /** + * Returns the least-specific match-all Filter Chain Match. + */ + static FilterChainMatch createMatch() { + return FilterChainMatch.create( 0, ImmutableList.of(), ImmutableList.of(), @@ -1294,6 +1983,21 @@ private static EnvoyServerProtoData.FilterChainMatch createMatch() { ""); } + private static FilterChainMatch createMatchSrcIp(String srcCidr) { + String[] srcParts = srcCidr.split("/", 2); + InetAddress ip = InetAddresses.forString(srcParts[0]); + Integer subnetMask = Integer.valueOf(srcParts[1], 10); + return FilterChainMatch.create( + 0, + ImmutableList.of(), + ImmutableList.of(), + ImmutableList.of(CidrRange.create(ip, subnetMask)), + EnvoyServerProtoData.ConnectionSourceType.ANY, + ImmutableList.of(), + ImmutableList.of(), + ""); + } + private static ServerRoutingConfig createRoutingConfig(String path, String domain) { return createRoutingConfig(path, domain, null); } @@ -1323,7 +2027,7 @@ private static MethodDescriptor createMethod(String path) { .build(); } - private static EnvoyServerProtoData.DownstreamTlsContext createTls() { + static EnvoyServerProtoData.DownstreamTlsContext createTls() { return CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); } } diff --git a/xds/src/test/java/io/grpc/xds/XdsTestControlPlaneService.java b/xds/src/test/java/io/grpc/xds/XdsTestControlPlaneService.java index c51327dc84d..a54893c9075 100644 --- a/xds/src/test/java/io/grpc/xds/XdsTestControlPlaneService.java +++ b/xds/src/test/java/io/grpc/xds/XdsTestControlPlaneService.java @@ -106,7 +106,7 @@ public void setXdsConfig(final String type, final Map copyResources = new HashMap<>(resources); xdsResources.put(type, copyResources); - String newVersionInfo = String.valueOf(xdsVersions.get(type).getAndDecrement()); + String newVersionInfo = String.valueOf(xdsVersions.get(type).getAndIncrement()); for (Map.Entry, Set> entry : subscribers.get(type).entrySet()) { @@ -119,6 +119,11 @@ public void run() { }); } + ImmutableMap getCurrentConfig(String type) { + HashMap hashMap = xdsResources.get(type); + return (hashMap != null) ? ImmutableMap.copyOf(hashMap) : ImmutableMap.of(); + } + @Override public StreamObserver streamAggregatedResources( final StreamObserver responseObserver) { @@ -135,12 +140,15 @@ public void run() { new Object[]{value.getResourceNamesList(), value.getErrorDetail()}); return; } + String resourceType = value.getTypeUrl(); - if (!value.getResponseNonce().isEmpty() - && !String.valueOf(xdsNonces.get(resourceType)).equals(value.getResponseNonce())) { + if (!value.getResponseNonce().isEmpty() && xdsNonces.containsKey(resourceType) + && !String.valueOf(xdsNonces.get(resourceType).get(responseObserver)) + .equals(value.getResponseNonce())) { logger.log(Level.FINE, "Resource nonce does not match, ignore."); return; } + Set requestedResourceNames = new HashSet<>(value.getResourceNamesList()); if (subscribers.get(resourceType).containsKey(responseObserver) && subscribers.get(resourceType).get(responseObserver) @@ -149,12 +157,14 @@ public void run() { value.getResourceNamesList()); return; } + if (!xdsNonces.get(resourceType).containsKey(responseObserver)) { xdsNonces.get(resourceType).put(responseObserver, new AtomicInteger(0)); } + DiscoveryResponse response = generateResponse(resourceType, String.valueOf(xdsVersions.get(resourceType)), - String.valueOf(xdsNonces.get(resourceType).get(responseObserver)), + String.valueOf(xdsNonces.get(resourceType).get(responseObserver).addAndGet(1)), requestedResourceNames); responseObserver.onNext(response); subscribers.get(resourceType).put(responseObserver, requestedResourceNames); @@ -197,4 +207,12 @@ private DiscoveryResponse generateResponse(String resourceType, String version, } return responseBuilder.build(); } + + public Map getSubscriberCounts() { + Map subscriberCounts = new HashMap<>(); + for (String type : subscribers.keySet()) { + subscriberCounts.put(type, subscribers.get(type).size()); + } + return subscriberCounts; + } } diff --git a/xds/src/test/java/io/grpc/xds/XdsTestUtils.java b/xds/src/test/java/io/grpc/xds/XdsTestUtils.java new file mode 100644 index 00000000000..93113411b5e --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/XdsTestUtils.java @@ -0,0 +1,442 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_CDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_EDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_LDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_RDS; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.mock; + +import com.google.common.base.Splitter; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.MoreExecutors; +import com.google.protobuf.Any; +import com.google.protobuf.Message; +import com.google.protobuf.util.Durations; +import io.envoyproxy.envoy.config.cluster.v3.Cluster; +import io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment; +import io.envoyproxy.envoy.config.endpoint.v3.ClusterStats; +import io.envoyproxy.envoy.config.listener.v3.ApiListener; +import io.envoyproxy.envoy.config.listener.v3.Listener; +import io.envoyproxy.envoy.config.route.v3.RouteConfiguration; +import io.envoyproxy.envoy.extensions.clusters.aggregate.v3.ClusterConfig; +import io.envoyproxy.envoy.extensions.filters.http.router.v3.Router; +import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpFilter; +import io.envoyproxy.envoy.service.load_stats.v3.LoadReportingServiceGrpc; +import io.envoyproxy.envoy.service.load_stats.v3.LoadStatsRequest; +import io.envoyproxy.envoy.service.load_stats.v3.LoadStatsResponse; +import io.grpc.BindableService; +import io.grpc.Context; +import io.grpc.Context.CancellationListener; +import io.grpc.InsecureChannelCredentials; +import io.grpc.StatusOr; +import io.grpc.internal.ExponentialBackoffPolicy; +import io.grpc.internal.FakeClock; +import io.grpc.internal.JsonParser; +import io.grpc.stub.StreamObserver; +import io.grpc.xds.Endpoints.LbEndpoint; +import io.grpc.xds.Endpoints.LocalityLbEndpoints; +import io.grpc.xds.XdsConfig.XdsClusterConfig.EndpointConfig; +import io.grpc.xds.client.Bootstrapper; +import io.grpc.xds.client.CommonBootstrapperTestUtils; +import io.grpc.xds.client.Locality; +import io.grpc.xds.client.XdsClient; +import io.grpc.xds.client.XdsClientMetricReporter; +import io.grpc.xds.client.XdsResourceType; +import io.grpc.xds.client.XdsTransportFactory; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.logging.Level; +import java.util.logging.Logger; +import org.mockito.ArgumentMatcher; +import org.mockito.InOrder; + +public class XdsTestUtils { + private static final Logger log = Logger.getLogger(XdsTestUtils.class.getName()); + static final String RDS_NAME = "route-config.googleapis.com"; + static final String CLUSTER_NAME = "cluster0"; + static final String EDS_NAME = "eds-service-0"; + static final String SERVER_LISTENER = "grpc/server?udpa.resource.listening_address="; + static final String HTTP_CONNECTION_MANAGER_TYPE_URL = + "type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3" + + ".HttpConnectionManager"; + static final Bootstrapper.ServerInfo EMPTY_BOOTSTRAPPER_SERVER_INFO = + Bootstrapper.ServerInfo.create( + "td.googleapis.com", InsecureChannelCredentials.create(), false, true, false, false); + static final Bootstrapper.BootstrapInfo EMPTY_BOOTSTRAP = + Bootstrapper.BootstrapInfo.builder() + .servers(com.google.common.collect.ImmutableList.of(EMPTY_BOOTSTRAPPER_SERVER_INFO)) + .node(io.grpc.xds.client.EnvoyProtoData.Node.newBuilder().setId("node-id").build()) + .build(); + public static final String ENDPOINT_HOSTNAME = "data-host"; + public static final int ENDPOINT_PORT = 1234; + + static BindableService createLrsService(AtomicBoolean lrsEnded, + Queue loadReportCalls) { + return new LoadReportingServiceGrpc.LoadReportingServiceImplBase() { + @Override + public StreamObserver streamLoadStats( + StreamObserver responseObserver) { + assertThat(lrsEnded.get()).isTrue(); + lrsEnded.set(false); + @SuppressWarnings("unchecked") + StreamObserver requestObserver = mock(StreamObserver.class); + LrsRpcCall call = new LrsRpcCall(requestObserver, responseObserver); + Context.current().addListener( + new CancellationListener() { + @Override + public void cancelled(Context context) { + lrsEnded.set(true); + } + }, MoreExecutors.directExecutor()); + loadReportCalls.offer(call); + return requestObserver; + } + }; + } + + static boolean matchErrorDetail( + com.google.rpc.Status errorDetail, int expectedCode, List expectedMessages) { + if (expectedCode != errorDetail.getCode()) { + return false; + } + List errors = Splitter.on('\n').splitToList(errorDetail.getMessage()); + if (errors.size() != expectedMessages.size()) { + return false; + } + for (int i = 0; i < errors.size(); i++) { + if (!errors.get(i).startsWith(expectedMessages.get(i))) { + return false; + } + } + return true; + } + + static void setAdsConfig(XdsTestControlPlaneService service, String serverName) { + setAdsConfig(service, serverName, RDS_NAME, CLUSTER_NAME, EDS_NAME, ENDPOINT_HOSTNAME, + ENDPOINT_PORT); + } + + static void setAdsConfig(XdsTestControlPlaneService service, String serverName, String rdsName, + String clusterName, String edsName, String endpointHostname, + int endpointPort) { + + Listener serverListener = ControlPlaneRule.buildServerListener(); + Listener clientListener = ControlPlaneRule.buildClientListener(serverName, rdsName); + service.setXdsConfig(ADS_TYPE_URL_LDS, + ImmutableMap.of(SERVER_LISTENER, serverListener, serverName, clientListener)); + + RouteConfiguration routeConfig = + buildRouteConfiguration(serverName, rdsName, clusterName); + service.setXdsConfig(ADS_TYPE_URL_RDS, ImmutableMap.of(rdsName, routeConfig));; + + Cluster cluster = ControlPlaneRule.buildCluster(clusterName, edsName); + service.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of(clusterName, cluster)); + + ClusterLoadAssignment clusterLoadAssignment = ControlPlaneRule.buildClusterLoadAssignment( + "127.0.0.11", endpointHostname, endpointPort, edsName); + service.setXdsConfig(ADS_TYPE_URL_EDS, + ImmutableMap.of(edsName, clusterLoadAssignment)); + + log.log(Level.FINE, String.format("Set ADS config for %s with address %s:%d", + serverName, endpointHostname, endpointPort)); + + } + + static String getEdsNameForCluster(String clusterName) { + return "eds_" + clusterName; + } + + static void setAggregateCdsConfig(XdsTestControlPlaneService service, String serverName, + String clusterName, List children) { + Map clusterMap = new HashMap<>(); + + ClusterConfig rootConfig = ClusterConfig.newBuilder().addAllClusters(children).build(); + Cluster.CustomClusterType type = + Cluster.CustomClusterType.newBuilder() + .setName(XdsClusterResource.AGGREGATE_CLUSTER_TYPE_NAME) + .setTypedConfig(Any.pack(rootConfig)) + .build(); + Cluster.Builder builder = Cluster.newBuilder().setName(clusterName).setClusterType(type); + builder.setLbPolicy(Cluster.LbPolicy.ROUND_ROBIN); + Cluster cluster = builder.build(); + clusterMap.put(clusterName, cluster); + + for (String child : children) { + Cluster childCluster = ControlPlaneRule.buildCluster(child, getEdsNameForCluster(child)); + clusterMap.put(child, childCluster); + } + + service.setXdsConfig(ADS_TYPE_URL_CDS, clusterMap); + + Map edsMap = new HashMap<>(); + for (String child : children) { + ClusterLoadAssignment clusterLoadAssignment = ControlPlaneRule.buildClusterLoadAssignment( + "127.0.0.16", ENDPOINT_HOSTNAME, ENDPOINT_PORT, getEdsNameForCluster(child)); + edsMap.put(getEdsNameForCluster(child), clusterLoadAssignment); + } + service.setXdsConfig(ADS_TYPE_URL_EDS, edsMap); + } + + static void addAggregateToExistingConfig(XdsTestControlPlaneService service, String rootName, + List children) { + Map clusterMap = new HashMap<>(service.getCurrentConfig(ADS_TYPE_URL_CDS)); + if (clusterMap.containsKey(rootName)) { + throw new IllegalArgumentException("Root cluster " + rootName + " already exists"); + } + ClusterConfig rootConfig = ClusterConfig.newBuilder().addAllClusters(children).build(); + Cluster.CustomClusterType type = + Cluster.CustomClusterType.newBuilder() + .setName(XdsClusterResource.AGGREGATE_CLUSTER_TYPE_NAME) + .setTypedConfig(Any.pack(rootConfig)) + .build(); + Cluster.Builder builder = Cluster.newBuilder().setName(rootName).setClusterType(type); + builder.setLbPolicy(Cluster.LbPolicy.ROUND_ROBIN); + Cluster cluster = builder.build(); + clusterMap.put(rootName, cluster); + + for (String child : children) { + if (clusterMap.containsKey(child)) { + continue; + } + Cluster childCluster = ControlPlaneRule.buildCluster(child, getEdsNameForCluster(child)); + clusterMap.put(child, childCluster); + } + + service.setXdsConfig(ADS_TYPE_URL_CDS, clusterMap); + + Map edsMap = new HashMap<>(service.getCurrentConfig(ADS_TYPE_URL_EDS)); + for (String child : children) { + if (edsMap.containsKey(getEdsNameForCluster(child))) { + continue; + } + ClusterLoadAssignment clusterLoadAssignment = ControlPlaneRule.buildClusterLoadAssignment( + "127.0.0.15", ENDPOINT_HOSTNAME, ENDPOINT_PORT, getEdsNameForCluster(child)); + edsMap.put(getEdsNameForCluster(child), clusterLoadAssignment); + } + service.setXdsConfig(ADS_TYPE_URL_EDS, edsMap); + } + + static XdsConfig getDefaultXdsConfig(String serverHostName) + throws XdsResourceType.ResourceInvalidException, IOException { + XdsConfig.XdsConfigBuilder builder = new XdsConfig.XdsConfigBuilder(); + + Filter.NamedFilterConfig routerFilterConfig = new Filter.NamedFilterConfig( + "terminal-filter", RouterFilter.ROUTER_CONFIG); + + HttpConnectionManager httpConnectionManager = HttpConnectionManager.forRdsName( + 0L, RDS_NAME, Collections.singletonList(routerFilterConfig)); + XdsListenerResource.LdsUpdate ldsUpdate = + XdsListenerResource.LdsUpdate.forApiListener(httpConnectionManager); + + RouteConfiguration routeConfiguration = + buildRouteConfiguration(serverHostName, RDS_NAME, CLUSTER_NAME); + XdsResourceType.Args args = new XdsResourceType.Args( + EMPTY_BOOTSTRAPPER_SERVER_INFO, "0", "0", EMPTY_BOOTSTRAP, null, null); + XdsRouteConfigureResource.RdsUpdate rdsUpdate = + XdsRouteConfigureResource.getInstance().doParse(args, routeConfiguration); + + // Take advantage of knowing that there is only 1 virtual host in the route configuration + assertThat(rdsUpdate.virtualHosts).hasSize(1); + VirtualHost virtualHost = rdsUpdate.virtualHosts.get(0); + + // Need to create endpoints to create locality endpoints map to create edsUpdate + Map lbEndpointsMap = new HashMap<>(); + LbEndpoint lbEndpoint = LbEndpoint.create( + "127.0.0.11", ENDPOINT_PORT, 0, true, ENDPOINT_HOSTNAME, ImmutableMap.of()); + lbEndpointsMap.put( + Locality.create("", "", ""), + LocalityLbEndpoints.create(ImmutableList.of(lbEndpoint), 10, 0, ImmutableMap.of())); + + // Need to create EdsUpdate to create CdsUpdate to create XdsClusterConfig for builder + XdsEndpointResource.EdsUpdate edsUpdate = new XdsEndpointResource.EdsUpdate( + EDS_NAME, lbEndpointsMap, Collections.emptyList()); + XdsClusterResource.CdsUpdate cdsUpdate = XdsClusterResource.CdsUpdate.forEds( + CLUSTER_NAME, EDS_NAME, null, null, null, null, false, null) + .lbPolicyConfig(getWrrLbConfigAsMap()).build(); + XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, cdsUpdate, new EndpointConfig(StatusOr.fromValue(edsUpdate))); + + builder + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(virtualHost) + .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)); + + return builder.build(); + } + + static Map createMinimalLbEndpointsMap(String serverAddress) { + Map lbEndpointsMap = new HashMap<>(); + LbEndpoint lbEndpoint = LbEndpoint.create( + serverAddress, ENDPOINT_PORT, 0, true, ENDPOINT_HOSTNAME, ImmutableMap.of()); + lbEndpointsMap.put( + Locality.create("", "", ""), + LocalityLbEndpoints.create(ImmutableList.of(lbEndpoint), 10, 0, ImmutableMap.of())); + return lbEndpointsMap; + } + + @SuppressWarnings("unchecked") + static ImmutableMap getWrrLbConfigAsMap() throws IOException { + String lbConfigStr = "{\"wrr_locality_experimental\" : " + + "{ \"childPolicy\" : [{\"round_robin\" : {}}]}}"; + + return ImmutableMap.copyOf((Map) JsonParser.parse(lbConfigStr)); + } + + static RouteConfiguration buildRouteConfiguration(String authority, String rdsName, + String clusterName) { + return ControlPlaneRule.buildRouteConfiguration(authority, rdsName, clusterName); + } + + static Cluster buildAggCluster(String name, List childNames) { + ClusterConfig rootConfig = ClusterConfig.newBuilder().addAllClusters(childNames).build(); + Cluster.CustomClusterType type = + Cluster.CustomClusterType.newBuilder() + .setName(XdsClusterResource.AGGREGATE_CLUSTER_TYPE_NAME) + .setTypedConfig(Any.pack(rootConfig)) + .build(); + Cluster.Builder builder = + Cluster.newBuilder().setName(name).setClusterType(type); + builder.setLbPolicy(Cluster.LbPolicy.ROUND_ROBIN); + Cluster cluster = builder.build(); + return cluster; + } + + static void addEdsClusters(Map clusterMap, Map edsMap, + String... clusterNames) { + for (String clusterName : clusterNames) { + String edsName = getEdsNameForCluster(clusterName); + Cluster cluster = ControlPlaneRule.buildCluster(clusterName, edsName); + clusterMap.put(clusterName, cluster); + + ClusterLoadAssignment clusterLoadAssignment = ControlPlaneRule.buildClusterLoadAssignment( + "127.0.0.13", ENDPOINT_HOSTNAME, ENDPOINT_PORT, edsName); + edsMap.put(edsName, clusterLoadAssignment); + } + } + + static Listener buildInlineClientListener(String rdsName, String clusterName, String serverName) { + HttpFilter + httpFilter = HttpFilter.newBuilder() + .setName("terminal-filter") + .setTypedConfig(Any.pack(Router.newBuilder().build())) + .setIsOptional(true) + .build(); + ApiListener.Builder clientListenerBuilder = + ApiListener.newBuilder().setApiListener(Any.pack( + io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3 + .HttpConnectionManager.newBuilder() + .setRouteConfig( + buildRouteConfiguration(serverName, rdsName, clusterName)) + .addAllHttpFilters(Collections.singletonList(httpFilter)) + .build(), + HTTP_CONNECTION_MANAGER_TYPE_URL)); + return Listener.newBuilder() + .setName(serverName) + .setApiListener(clientListenerBuilder.build()).build(); + } + + public static XdsClient createXdsClient( + List serverUris, + XdsTransportFactory xdsTransportFactory, + FakeClock fakeClock) { + return createXdsClient( + CommonBootstrapperTestUtils.buildBootStrap(serverUris), + xdsTransportFactory, + fakeClock, + new XdsClientMetricReporter() {}); + } + + /** Calls {@link CommonBootstrapperTestUtils#createXdsClient} with gRPC-specific values. */ + public static XdsClient createXdsClient( + Bootstrapper.BootstrapInfo bootstrapInfo, + XdsTransportFactory xdsTransportFactory, + FakeClock fakeClock, + XdsClientMetricReporter xdsClientMetricReporter) { + return CommonBootstrapperTestUtils.createXdsClient( + bootstrapInfo, + xdsTransportFactory, + fakeClock, + new ExponentialBackoffPolicy.Provider(), + MessagePrinter.INSTANCE, + xdsClientMetricReporter); + } + + /** + * Matches a {@link LoadStatsRequest} containing a collection of {@link ClusterStats} with + * the same list of clusterName:clusterServiceName pair. + */ + static class LrsRequestMatcher implements ArgumentMatcher { + private final List expected; + + private LrsRequestMatcher(List clusterNames) { + expected = new ArrayList<>(); + for (String[] pair : clusterNames) { + expected.add(pair[0] + ":" + (pair[1] == null ? "" : pair[1])); + } + Collections.sort(expected); + } + + @Override + public boolean matches(LoadStatsRequest argument) { + List actual = new ArrayList<>(); + for (ClusterStats clusterStats : argument.getClusterStatsList()) { + actual.add(clusterStats.getClusterName() + ":" + clusterStats.getClusterServiceName()); + } + Collections.sort(actual); + return actual.equals(expected); + } + } + + static class LrsRpcCall { + private final StreamObserver requestObserver; + private final StreamObserver responseObserver; + private final InOrder inOrder; + + private LrsRpcCall(StreamObserver requestObserver, + StreamObserver responseObserver) { + this.requestObserver = requestObserver; + this.responseObserver = responseObserver; + inOrder = inOrder(requestObserver); + } + + protected void verifyNextReportClusters(List clusters) { + inOrder.verify(requestObserver).onNext(argThat(new LrsRequestMatcher(clusters))); + } + + protected void sendResponse(List clusters, long loadReportIntervalNano) { + LoadStatsResponse response = + LoadStatsResponse.newBuilder() + .addAllClusters(clusters) + .setLoadReportingInterval(Durations.fromNanos(loadReportIntervalNano)) + .build(); + responseObserver.onNext(response); + } + } +} diff --git a/xds/src/test/java/io/grpc/xds/client/BackendMetricPropagationTest.java b/xds/src/test/java/io/grpc/xds/client/BackendMetricPropagationTest.java new file mode 100644 index 00000000000..31ad6f9c47f --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/client/BackendMetricPropagationTest.java @@ -0,0 +1,151 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.client; + +import static com.google.common.truth.Truth.assertThat; +import static java.util.Arrays.asList; + +import com.google.common.collect.ImmutableList; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Unit tests for {@link BackendMetricPropagation}. + */ +@RunWith(JUnit4.class) +public class BackendMetricPropagationTest { + + @Test + public void fromMetricSpecs_nullInput() { + BackendMetricPropagation config = BackendMetricPropagation.fromMetricSpecs(null); + + assertThat(config.propagateCpuUtilization).isFalse(); + assertThat(config.propagateMemUtilization).isFalse(); + assertThat(config.propagateApplicationUtilization).isFalse(); + assertThat(config.shouldPropagateNamedMetric("any")).isFalse(); + } + + @Test + public void fromMetricSpecs_emptyInput() { + BackendMetricPropagation config = BackendMetricPropagation.fromMetricSpecs(ImmutableList.of()); + + assertThat(config.propagateCpuUtilization).isFalse(); + assertThat(config.propagateMemUtilization).isFalse(); + assertThat(config.propagateApplicationUtilization).isFalse(); + assertThat(config.shouldPropagateNamedMetric("any")).isFalse(); + } + + @Test + public void fromMetricSpecs_partialStandardMetrics() { + BackendMetricPropagation config = BackendMetricPropagation.fromMetricSpecs( + ImmutableList.of("cpu_utilization", "mem_utilization")); + + assertThat(config.propagateCpuUtilization).isTrue(); + assertThat(config.propagateMemUtilization).isTrue(); + assertThat(config.propagateApplicationUtilization).isFalse(); + assertThat(config.shouldPropagateNamedMetric("any")).isFalse(); + } + + @Test + public void fromMetricSpecs_allStandardMetrics() { + BackendMetricPropagation config = BackendMetricPropagation.fromMetricSpecs( + ImmutableList.of("cpu_utilization", "mem_utilization", "application_utilization")); + + assertThat(config.propagateCpuUtilization).isTrue(); + assertThat(config.propagateMemUtilization).isTrue(); + assertThat(config.propagateApplicationUtilization).isTrue(); + assertThat(config.shouldPropagateNamedMetric("any")).isFalse(); + } + + @Test + public void fromMetricSpecs_wildcardNamedMetrics() { + BackendMetricPropagation config = BackendMetricPropagation.fromMetricSpecs( + ImmutableList.of("named_metrics.*")); + + assertThat(config.propagateCpuUtilization).isFalse(); + assertThat(config.propagateMemUtilization).isFalse(); + assertThat(config.propagateApplicationUtilization).isFalse(); + assertThat(config.shouldPropagateNamedMetric("any_key")).isTrue(); + assertThat(config.shouldPropagateNamedMetric("another_key")).isTrue(); + } + + @Test + public void fromMetricSpecs_specificNamedMetrics() { + BackendMetricPropagation config = BackendMetricPropagation.fromMetricSpecs( + ImmutableList.of("named_metrics.foo", "named_metrics.bar")); + + assertThat(config.shouldPropagateNamedMetric("foo")).isTrue(); + assertThat(config.shouldPropagateNamedMetric("bar")).isTrue(); + assertThat(config.shouldPropagateNamedMetric("baz")).isFalse(); + assertThat(config.shouldPropagateNamedMetric("any")).isFalse(); + } + + @Test + public void fromMetricSpecs_mixedStandardAndNamed() { + BackendMetricPropagation config = BackendMetricPropagation.fromMetricSpecs( + ImmutableList.of("cpu_utilization", "named_metrics.foo", "named_metrics.bar")); + + assertThat(config.propagateCpuUtilization).isTrue(); + assertThat(config.propagateMemUtilization).isFalse(); + assertThat(config.shouldPropagateNamedMetric("foo")).isTrue(); + assertThat(config.shouldPropagateNamedMetric("bar")).isTrue(); + assertThat(config.shouldPropagateNamedMetric("baz")).isFalse(); + } + + @Test + public void fromMetricSpecs_wildcardAndSpecificNamedMetrics() { + BackendMetricPropagation config = BackendMetricPropagation.fromMetricSpecs( + ImmutableList.of("named_metrics.foo", "named_metrics.*")); + + assertThat(config.shouldPropagateNamedMetric("foo")).isTrue(); + assertThat(config.shouldPropagateNamedMetric("bar")).isTrue(); + assertThat(config.shouldPropagateNamedMetric("any_other_key")).isTrue(); + } + + @Test + public void fromMetricSpecs_malformedAndUnknownSpecs_areIgnored() { + BackendMetricPropagation config = BackendMetricPropagation.fromMetricSpecs( + asList( + "cpu_utilization", + null, // ignored + "disk_utilization", + "named_metrics.", // empty key + "named_metrics.valid" + )); + + assertThat(config.propagateCpuUtilization).isTrue(); + assertThat(config.propagateMemUtilization).isFalse(); + assertThat(config.shouldPropagateNamedMetric("disk_utilization")).isFalse(); + assertThat(config.shouldPropagateNamedMetric("valid")).isTrue(); + assertThat(config.shouldPropagateNamedMetric("")).isFalse(); // from the empty key + } + + @Test + public void fromMetricSpecs_duplicateSpecs_areHandledGracefully() { + BackendMetricPropagation config = BackendMetricPropagation.fromMetricSpecs( + ImmutableList.of( + "cpu_utilization", + "named_metrics.foo", + "cpu_utilization", + "named_metrics.foo")); + + assertThat(config.propagateCpuUtilization).isTrue(); + assertThat(config.shouldPropagateNamedMetric("foo")).isTrue(); + assertThat(config.shouldPropagateNamedMetric("bar")).isFalse(); + } +} diff --git a/xds/src/test/java/io/grpc/xds/CommonBootstrapperTestUtils.java b/xds/src/test/java/io/grpc/xds/client/CommonBootstrapperTestUtils.java similarity index 57% rename from xds/src/test/java/io/grpc/xds/CommonBootstrapperTestUtils.java rename to xds/src/test/java/io/grpc/xds/client/CommonBootstrapperTestUtils.java index 0b2f3c7136b..e3760bd983f 100644 --- a/xds/src/test/java/io/grpc/xds/CommonBootstrapperTestUtils.java +++ b/xds/src/test/java/io/grpc/xds/client/CommonBootstrapperTestUtils.java @@ -14,21 +14,35 @@ * limitations under the License. */ -package io.grpc.xds; +package io.grpc.xds.client; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import io.grpc.ChannelCredentials; +import io.grpc.InsecureChannelCredentials; +import io.grpc.internal.BackoffPolicy; +import io.grpc.internal.FakeClock; import io.grpc.internal.JsonParser; -import io.grpc.xds.client.Bootstrapper; import io.grpc.xds.client.Bootstrapper.ServerInfo; -import io.grpc.xds.client.EnvoyProtoData; import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; +import io.grpc.xds.internal.security.TlsContextManagerImpl; import java.io.IOException; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; import javax.annotation.Nullable; public class CommonBootstrapperTestUtils { + public static final String SERVER_URI = "trafficdirector.googleapis.com"; + private static final ChannelCredentials CHANNEL_CREDENTIALS = InsecureChannelCredentials.create(); + private static final String SERVER_URI_CUSTOM_AUTHORITY = "trafficdirector2.googleapis.com"; + private static final String SERVER_URI_EMPTY_AUTHORITY = "trafficdirector3.googleapis.com"; + public static final String LDS_RESOURCE = "listener.googleapis.com"; + public static final String RDS_RESOURCE = "route-configuration.googleapis.com"; + public static final String CDS_RESOURCE = "cluster.googleapis.com"; + public static final String EDS_RESOURCE = "cluster-load-assignment.googleapis.com"; + private static final String FILE_WATCHER_CONFIG = "{\"path\": \"/etc/secret/certs\"}"; private static final String MESHCA_CONFIG = "{\n" @@ -88,7 +102,7 @@ public static Bootstrapper.BootstrapInfo buildBootstrapInfo( String certInstanceName1, @Nullable String privateKey1, @Nullable String cert1, @Nullable String trustCa1, String certInstanceName2, String privateKey2, String cert2, - String trustCa2) { + String trustCa2, @Nullable String spiffeTrustMap) { // get temp file for each file try { if (privateKey1 != null) { @@ -109,6 +123,9 @@ public static Bootstrapper.BootstrapInfo buildBootstrapInfo( if (trustCa2 != null) { trustCa2 = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(trustCa2); } + if (spiffeTrustMap != null) { + spiffeTrustMap = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(spiffeTrustMap); + } } catch (IOException ioe) { throw new RuntimeException(ioe); } @@ -116,6 +133,9 @@ public static Bootstrapper.BootstrapInfo buildBootstrapInfo( config.put("certificate_file", cert1); config.put("private_key_file", privateKey1); config.put("ca_certificate_file", trustCa1); + if (spiffeTrustMap != null) { + config.put("spiffe_trust_bundle_map_file", spiffeTrustMap); + } Bootstrapper.CertificateProviderInfo certificateProviderInfo = Bootstrapper.CertificateProviderInfo.create("file_watcher", config); HashMap certProviders = @@ -126,6 +146,9 @@ public static Bootstrapper.BootstrapInfo buildBootstrapInfo( config.put("certificate_file", cert2); config.put("private_key_file", privateKey2); config.put("ca_certificate_file", trustCa2); + if (spiffeTrustMap != null) { + config.put("spiffe_trust_bundle_map_file", spiffeTrustMap); + } certificateProviderInfo = Bootstrapper.CertificateProviderInfo.create("file_watcher", config); certProviders.put(certInstanceName2, certificateProviderInfo); @@ -136,4 +159,71 @@ public static Bootstrapper.BootstrapInfo buildBootstrapInfo( .certProviders(certProviders) .build(); } + + public static boolean setEnableXdsFallback(boolean target) { + boolean oldValue = BootstrapperImpl.enableXdsFallback; + BootstrapperImpl.enableXdsFallback = target; + return oldValue; + } + + public static XdsClientImpl createXdsClient(List serverUris, + XdsTransportFactory xdsTransportFactory, + FakeClock fakeClock, + BackoffPolicy.Provider backoffPolicyProvider, + MessagePrettyPrinter messagePrinter, + XdsClientMetricReporter xdsClientMetricReporter) { + return createXdsClient( + buildBootStrap(serverUris), + xdsTransportFactory, + fakeClock, + backoffPolicyProvider, + messagePrinter, + xdsClientMetricReporter); + } + + public static XdsClientImpl createXdsClient(Bootstrapper.BootstrapInfo bootstrapInfo, + XdsTransportFactory xdsTransportFactory, + FakeClock fakeClock, + BackoffPolicy.Provider backoffPolicyProvider, + MessagePrettyPrinter messagePrinter, + XdsClientMetricReporter xdsClientMetricReporter) { + return new XdsClientImpl( + xdsTransportFactory, + bootstrapInfo, + fakeClock.getScheduledExecutorService(), + backoffPolicyProvider, + fakeClock.getStopwatchSupplier(), + fakeClock.getTimeProvider(), + messagePrinter, + new TlsContextManagerImpl(bootstrapInfo), + xdsClientMetricReporter); + } + + public static Bootstrapper.BootstrapInfo buildBootStrap(List serverUris) { + + List serverInfos = new ArrayList<>(); + for (String uri : serverUris) { + serverInfos.add(ServerInfo.create(uri, CHANNEL_CREDENTIALS, false, true, false, false)); + } + EnvoyProtoData.Node node = EnvoyProtoData.Node.newBuilder().setId("node-id").build(); + + return Bootstrapper.BootstrapInfo.builder() + .servers(serverInfos) + .node(node) + .authorities(ImmutableMap.of( + "authority.xds.com", + Bootstrapper.AuthorityInfo.create( + "xdstp://authority.xds.com/envoy.config.listener.v3.Listener/%s", + ImmutableList.of(Bootstrapper.ServerInfo.create( + SERVER_URI_CUSTOM_AUTHORITY, CHANNEL_CREDENTIALS))), + "", + Bootstrapper.AuthorityInfo.create( + "xdstp:///envoy.config.listener.v3.Listener/%s", + ImmutableList.of(Bootstrapper.ServerInfo.create( + SERVER_URI_EMPTY_AUTHORITY, CHANNEL_CREDENTIALS))))) + .certProviders(ImmutableMap.of("cert-instance-name", + Bootstrapper.CertificateProviderInfo.create("file-watcher", ImmutableMap.of()))) + .build(); + } + } diff --git a/xds/src/test/java/io/grpc/xds/client/ControlPlaneClientTest.java b/xds/src/test/java/io/grpc/xds/client/ControlPlaneClientTest.java new file mode 100644 index 00000000000..64786c4fb3b --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/client/ControlPlaneClientTest.java @@ -0,0 +1,279 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.client; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.base.Stopwatch; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.envoyproxy.envoy.service.discovery.v3.DiscoveryRequest; +import io.envoyproxy.envoy.service.discovery.v3.DiscoveryResponse; +import io.grpc.InsecureChannelCredentials; +import io.grpc.MethodDescriptor; +import io.grpc.SynchronizationContext; +import io.grpc.internal.BackoffPolicy; +import io.grpc.internal.FakeClock; +import io.grpc.xds.client.Bootstrapper.ServerInfo; +import io.grpc.xds.client.EnvoyProtoData.Node; +import io.grpc.xds.client.XdsClient.ResourceStore; +import io.grpc.xds.client.XdsClient.XdsResponseHandler; +import io.grpc.xds.client.XdsTransportFactory.EventHandler; +import io.grpc.xds.client.XdsTransportFactory.StreamingCall; +import io.grpc.xds.client.XdsTransportFactory.XdsTransport; +import java.util.Collections; +import java.util.Map; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +/** Unit tests for {@link ControlPlaneClient}. */ +@RunWith(JUnit4.class) +public class ControlPlaneClientTest { + + private static final String CDS_TYPE_URL = "type.googleapis.com/envoy.config.cluster.v3.Cluster"; + private static final String EDS_TYPE_URL = + "type.googleapis.com/envoy.config.endpoint.v3.ClusterLoadAssignment"; + + private final SynchronizationContext syncContext = + new SynchronizationContext((t, e) -> { + throw new AssertionError("Uncaught exception in sync context", e); + }); + private final FakeClock fakeClock = new FakeClock(); + private final ServerInfo serverInfo = + ServerInfo.create("eds-control-plane:8443", InsecureChannelCredentials.create()); + private final Node bootstrapNode = Node.newBuilder().setId("test-node").build(); + + @Mock private XdsTransport xdsTransport; + @Mock private StreamingCall streamingCall; + @Mock private XdsResponseHandler responseHandler; + @Mock private ResourceStore resourceStore; + @Mock private BackoffPolicy.Provider backoffPolicyProvider; + @Mock private MessagePrettyPrinter messagePrinter; + @Mock private XdsResourceType cdsType; + @Mock private XdsResourceType edsType; + + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + + private ControlPlaneClient cpc; + private ArgumentCaptor> handlerCaptor; + + @Before + @SuppressWarnings("unchecked") + public void setUp() { + when(cdsType.typeUrl()).thenReturn(CDS_TYPE_URL); + when(cdsType.typeName()).thenReturn("CDS"); + when(edsType.typeUrl()).thenReturn(EDS_TYPE_URL); + when(edsType.typeName()).thenReturn("EDS"); + + when(xdsTransport.createStreamingCall( + anyString(), + any(MethodDescriptor.Marshaller.class), + any(MethodDescriptor.Marshaller.class))) + .thenReturn(streamingCall); + when(streamingCall.isReady()).thenReturn(true); + + handlerCaptor = ArgumentCaptor.forClass(EventHandler.class); + + cpc = new ControlPlaneClient( + xdsTransport, + serverInfo, + bootstrapNode, + responseHandler, + resourceStore, + fakeClock.getScheduledExecutorService(), + syncContext, + backoffPolicyProvider, + () -> Stopwatch.createUnstarted(fakeClock.getTicker()), + messagePrinter); + } + + /** + * Reproduces the bug where, when an ADS stream is opened to an authority-specific server (e.g. + * an EDS-only control plane), {@code sendDiscoveryRequests} previously emitted an empty + * DiscoveryRequest for every globally-subscribed resource type — including types this server + * does not handle. Authority-specific servers may reject those requests with UNIMPLEMENTED and + * tear down the stream, blocking the legitimate request that follows. + * + *

Asserts that the empty CDS request is suppressed and only the EDS request (which has + * resources for this server) goes on the wire. + */ + @Test + public void streamReady_skipsEmptyDiscoveryRequestForUnsubscribedType() { + // CDS is globally subscribed (e.g. against a different authority) but has no resources on + // this server. EDS has one resource on this server. + Map> subscribedTypes = + ImmutableMap.of(CDS_TYPE_URL, cdsType, EDS_TYPE_URL, edsType); + when(resourceStore.getSubscribedResourceTypesWithTypeUrl()).thenReturn(subscribedTypes); + when(resourceStore.getSubscribedResources(serverInfo, cdsType)).thenReturn(null); + when(resourceStore.getSubscribedResources(serverInfo, edsType)) + .thenReturn(ImmutableList.of("foo-endpoint")); + + // Triggers stream creation and registers the EventHandler. + syncContext.execute(cpc::sendDiscoveryRequests); + verify(streamingCall).start(handlerCaptor.capture()); + + // Drive the stream into the connected state. onReady() flips sentInitialRequest=true and + // re-invokes sendDiscoveryRequests, which iterates the globally-subscribed types. + handlerCaptor.getValue().onReady(); + + // EDS request was sent with the one resource for this server. + ArgumentCaptor sent = ArgumentCaptor.forClass(DiscoveryRequest.class); + verify(streamingCall, atLeastOnce()).sendMessage(sent.capture()); + ImmutableSet sentTypes = sent.getAllValues().stream() + .map(DiscoveryRequest::getTypeUrl) + .collect(ImmutableSet.toImmutableSet()); + assertThat(sentTypes).contains(EDS_TYPE_URL); + assertThat(sentTypes).doesNotContain(CDS_TYPE_URL); + + // Confirm the EDS request actually carried the resource name. + DiscoveryRequest edsReq = sent.getAllValues().stream() + .filter(r -> r.getTypeUrl().equals(EDS_TYPE_URL)) + .findFirst() + .orElseThrow(() -> new AssertionError("EDS request not sent")); + assertThat(edsReq.getResourceNamesList()).containsExactly("foo-endpoint"); + } + + /** + * If a server has resources for every globally-subscribed type, the empty-skip guard is a + * no-op: a DiscoveryRequest is sent for every type. This guards against the skip becoming + * over-eager and dropping legitimate subscriptions. + */ + @Test + public void streamReady_sendsRequestForAllTypesWhenAllHaveResources() { + Map> subscribedTypes = + ImmutableMap.of(CDS_TYPE_URL, cdsType, EDS_TYPE_URL, edsType); + when(resourceStore.getSubscribedResourceTypesWithTypeUrl()).thenReturn(subscribedTypes); + when(resourceStore.getSubscribedResources(serverInfo, cdsType)) + .thenReturn(ImmutableList.of("foo-cluster")); + when(resourceStore.getSubscribedResources(serverInfo, edsType)) + .thenReturn(ImmutableList.of("foo-endpoint")); + + syncContext.execute(cpc::sendDiscoveryRequests); + verify(streamingCall).start(handlerCaptor.capture()); + handlerCaptor.getValue().onReady(); + + ArgumentCaptor sent = ArgumentCaptor.forClass(DiscoveryRequest.class); + verify(streamingCall, times(2)).sendMessage(sent.capture()); + ImmutableSet sentTypes = sent.getAllValues().stream() + .map(DiscoveryRequest::getTypeUrl) + .collect(ImmutableSet.toImmutableSet()); + assertThat(sentTypes).containsExactly(CDS_TYPE_URL, EDS_TYPE_URL); + } + + /** + * If only one type has a subscription on this server, no request is sent for the unsubscribed + * type. This is the canonical multi-authority federation case (e.g. fabric authority owns CDS, + * eds-control-plane owns EDS — the eds-control-plane stream should only see EDS). + */ + @Test + public void streamReady_skipsTypeWithNoSubscription() { + Map> subscribedTypes = + ImmutableMap.of(CDS_TYPE_URL, cdsType, EDS_TYPE_URL, edsType); + when(resourceStore.getSubscribedResourceTypesWithTypeUrl()).thenReturn(subscribedTypes); + when(resourceStore.getSubscribedResources(serverInfo, cdsType)).thenReturn(null); + when(resourceStore.getSubscribedResources(serverInfo, edsType)) + .thenReturn(ImmutableList.of("foo-endpoint")); + + syncContext.execute(cpc::sendDiscoveryRequests); + verify(streamingCall).start(handlerCaptor.capture()); + handlerCaptor.getValue().onReady(); + + verify(streamingCall, never()).sendMessage( + argThatTypeUrlIs(CDS_TYPE_URL)); + verify(streamingCall).sendMessage(argThatTypeUrlIs(EDS_TYPE_URL)); + } + + /** + * Per the ResourceStore contract in XdsClient.java, an empty collection from + * getSubscribedResources indicates a wildcard subscription. The skip-on-empty guard must not + * suppress wildcard requests on initial stream ready — the server needs the empty-resource-list + * DiscoveryRequest to start streaming, and the watcher's missing-resource timers must start. + */ + @Test + public void streamReady_sendsWildcardRequestAndStartsTimers() { + Map> subscribedTypes = ImmutableMap.of(CDS_TYPE_URL, cdsType); + when(resourceStore.getSubscribedResourceTypesWithTypeUrl()).thenReturn(subscribedTypes); + // Empty collection == wildcard subscription per the ResourceStore contract. + when(resourceStore.getSubscribedResources(serverInfo, cdsType)) + .thenReturn(Collections.emptyList()); + + syncContext.execute(cpc::sendDiscoveryRequests); + verify(streamingCall).start(handlerCaptor.capture()); + handlerCaptor.getValue().onReady(); + + ArgumentCaptor sent = ArgumentCaptor.forClass(DiscoveryRequest.class); + verify(streamingCall, atLeastOnce()).sendMessage(sent.capture()); + DiscoveryRequest cdsReq = sent.getAllValues().stream() + .filter(r -> r.getTypeUrl().equals(CDS_TYPE_URL)) + .findFirst() + .orElseThrow(() -> new AssertionError("CDS wildcard request not sent")); + assertThat(cdsReq.getResourceNamesList()).isEmpty(); + + verify(resourceStore).startMissingResourceTimers(Collections.emptyList(), cdsType); + } + + /** + * If a watch is canceled after the initial DiscoveryRequest goes out but before any response + * is ACKed, the empty unsubscribe must still be sent — otherwise the server keeps the stale + * subscription until the stream resets. The skip guard must gate on per-stream send history, + * not on the {@code versions} map (which is only populated on ACK). + */ + @Test + public void cancelBeforeAck_sendsEmptyUnsubscribe() { + Map> subscribedTypes = ImmutableMap.of(CDS_TYPE_URL, cdsType); + when(resourceStore.getSubscribedResourceTypesWithTypeUrl()).thenReturn(subscribedTypes); + when(resourceStore.getSubscribedResources(serverInfo, cdsType)) + .thenReturn(ImmutableList.of("foo-cluster")); + + syncContext.execute(cpc::sendDiscoveryRequests); + verify(streamingCall).start(handlerCaptor.capture()); + handlerCaptor.getValue().onReady(); + + // Initial DiscoveryRequest with the resource went out. No DiscoveryResponse has been ACKed. + verify(streamingCall).sendMessage(argThatTypeUrlIs(CDS_TYPE_URL)); + + // Cancel the watch before any response arrives: store now reports no subscription. + when(resourceStore.getSubscribedResources(serverInfo, cdsType)).thenReturn(null); + syncContext.execute(() -> cpc.adjustResourceSubscription(cdsType)); + + ArgumentCaptor sent = ArgumentCaptor.forClass(DiscoveryRequest.class); + verify(streamingCall, times(2)).sendMessage(sent.capture()); + DiscoveryRequest unsub = sent.getAllValues().get(1); + assertThat(unsub.getTypeUrl()).isEqualTo(CDS_TYPE_URL); + assertThat(unsub.getResourceNamesList()).isEmpty(); + } + + private static DiscoveryRequest argThatTypeUrlIs(String typeUrl) { + return argThat(req -> req != null && typeUrl.equals(req.getTypeUrl())); + } +} \ No newline at end of file diff --git a/xds/src/test/java/io/grpc/xds/client/LoadStatsManager2Test.java b/xds/src/test/java/io/grpc/xds/client/LoadStatsManager2Test.java index 9a90a92dcbd..a0642f7e4bb 100644 --- a/xds/src/test/java/io/grpc/xds/client/LoadStatsManager2Test.java +++ b/xds/src/test/java/io/grpc/xds/client/LoadStatsManager2Test.java @@ -27,6 +27,7 @@ import io.grpc.xds.client.Stats.ClusterStats; import io.grpc.xds.client.Stats.DroppedRequests; import io.grpc.xds.client.Stats.UpstreamLocalityStats; +import java.util.Arrays; import java.util.List; import java.util.Objects; import java.util.concurrent.TimeUnit; @@ -254,6 +255,59 @@ public void sharedLoadCounterStatsAggregation() { 2.718); } + @Test + public void recordMetrics_orcaLrsPropagationEnabled_specificMetrics() { + boolean originalVal = LoadStatsManager2.isEnabledOrcaLrsPropagation; + LoadStatsManager2.isEnabledOrcaLrsPropagation = true; + BackendMetricPropagation backendMetricPropagation = BackendMetricPropagation.fromMetricSpecs( + Arrays.asList("cpu_utilization", "named_metrics.named1")); + ClusterLocalityStats stats = loadStatsManager.getClusterLocalityStats( + CLUSTER_NAME1, EDS_SERVICE_NAME1, LOCALITY1, backendMetricPropagation); + + stats.recordTopLevelMetrics(0.8, 0.5, 0.0); // cpu, mem, app + stats.recordBackendLoadMetricStats(ImmutableMap.of("named1", 123.4, "named2", 567.8)); + stats.recordCallFinished(Status.OK); + ClusterStats report = Iterables.getOnlyElement( + loadStatsManager.getClusterStatsReports(CLUSTER_NAME1)); + UpstreamLocalityStats localityStats = + Iterables.getOnlyElement(report.upstreamLocalityStatsList()); + + assertThat(localityStats.loadMetricStatsMap()).containsKey("cpu_utilization"); + assertThat(localityStats.loadMetricStatsMap().get("cpu_utilization").totalMetricValue()) + .isWithin(TOLERANCE).of(0.8); + assertThat(localityStats.loadMetricStatsMap()).doesNotContainKey("mem_utilization"); + assertThat(localityStats.loadMetricStatsMap()).containsKey("named_metrics.named1"); + assertThat(localityStats.loadMetricStatsMap().get("named_metrics.named1").totalMetricValue()) + .isWithin(TOLERANCE).of(123.4); + assertThat(localityStats.loadMetricStatsMap()).doesNotContainKey("named_metrics.named2"); + LoadStatsManager2.isEnabledOrcaLrsPropagation = originalVal; + } + + @Test + public void recordMetrics_orcaLrsPropagationEnabled_wildcardNamedMetrics() { + boolean originalVal = LoadStatsManager2.isEnabledOrcaLrsPropagation; + LoadStatsManager2.isEnabledOrcaLrsPropagation = true; + BackendMetricPropagation backendMetricPropagation = BackendMetricPropagation.fromMetricSpecs( + Arrays.asList("named_metrics.*")); + ClusterLocalityStats stats = loadStatsManager.getClusterLocalityStats( + CLUSTER_NAME1, EDS_SERVICE_NAME1, LOCALITY1, backendMetricPropagation); + + stats.recordBackendLoadMetricStats(ImmutableMap.of("named1", 123.4, "named2", 567.8)); + stats.recordCallFinished(Status.OK); + ClusterStats report = Iterables.getOnlyElement( + loadStatsManager.getClusterStatsReports(CLUSTER_NAME1)); + UpstreamLocalityStats localityStats = + Iterables.getOnlyElement(report.upstreamLocalityStatsList()); + + assertThat(localityStats.loadMetricStatsMap()).containsKey("named_metrics.named1"); + assertThat(localityStats.loadMetricStatsMap().get("named_metrics.named1").totalMetricValue()) + .isWithin(TOLERANCE).of(123.4); + assertThat(localityStats.loadMetricStatsMap()).containsKey("named_metrics.named2"); + assertThat(localityStats.loadMetricStatsMap().get("named_metrics.named2").totalMetricValue()) + .isWithin(TOLERANCE).of(567.8); + LoadStatsManager2.isEnabledOrcaLrsPropagation = originalVal; + } + @Test public void loadCounterDelayedDeletionAfterAllInProgressRequestsReported() { ClusterLocalityStats counter = loadStatsManager.getClusterLocalityStats( diff --git a/xds/src/test/java/io/grpc/xds/internal/MatcherParserTest.java b/xds/src/test/java/io/grpc/xds/internal/MatcherParserTest.java new file mode 100644 index 00000000000..86a6a95fd4b --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/MatcherParserTest.java @@ -0,0 +1,85 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import io.envoyproxy.envoy.type.matcher.v3.RegexMatcher; +import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; +import io.envoyproxy.envoy.type.v3.FractionalPercent; +import io.envoyproxy.envoy.type.v3.FractionalPercent.DenominatorType; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class MatcherParserTest { + + @Test + public void parseStringMatcher_exact() { + StringMatcher proto = + StringMatcher.newBuilder().setExact("exact-match").setIgnoreCase(true).build(); + Matchers.StringMatcher matcher = MatcherParser.parseStringMatcher(proto); + assertThat(matcher).isNotNull(); + } + + @Test + public void parseStringMatcher_allTypes() { + MatcherParser.parseStringMatcher(StringMatcher.newBuilder().setExact("test").build()); + MatcherParser.parseStringMatcher(StringMatcher.newBuilder().setPrefix("test").build()); + MatcherParser.parseStringMatcher(StringMatcher.newBuilder().setSuffix("test").build()); + MatcherParser.parseStringMatcher(StringMatcher.newBuilder().setContains("test").build()); + MatcherParser.parseStringMatcher(StringMatcher.newBuilder() + .setSafeRegex(RegexMatcher.newBuilder().setRegex(".*").build()).build()); + } + + @Test + public void parseStringMatcher_unknownTypeThrows() { + StringMatcher unknownProto = StringMatcher.getDefaultInstance(); + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, + () -> MatcherParser.parseStringMatcher(unknownProto)); + assertThat(exception).hasMessageThat().contains("Unknown StringMatcher match pattern"); + } + + @Test + public void parseFractionMatcher_denominators() { + Matchers.FractionMatcher hundred = MatcherParser.parseFractionMatcher(FractionalPercent + .newBuilder().setNumerator(1).setDenominator(DenominatorType.HUNDRED).build()); + assertThat(hundred.numerator()).isEqualTo(1); + assertThat(hundred.denominator()).isEqualTo(100); + + Matchers.FractionMatcher tenThousand = MatcherParser.parseFractionMatcher(FractionalPercent + .newBuilder().setNumerator(2).setDenominator(DenominatorType.TEN_THOUSAND).build()); + assertThat(tenThousand.numerator()).isEqualTo(2); + assertThat(tenThousand.denominator()).isEqualTo(10_000); + + Matchers.FractionMatcher million = MatcherParser.parseFractionMatcher(FractionalPercent + .newBuilder().setNumerator(3).setDenominator(DenominatorType.MILLION).build()); + assertThat(million.numerator()).isEqualTo(3); + assertThat(million.denominator()).isEqualTo(1_000_000); + } + + @Test + public void parseFractionMatcher_unknownDenominatorThrows() { + FractionalPercent unknownProto = + FractionalPercent.newBuilder().setDenominatorValue(999).build(); + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, + () -> MatcherParser.parseFractionMatcher(unknownProto)); + assertThat(exception).hasMessageThat().contains("Unknown denominator type"); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/MetricReportUtilsTest.java b/xds/src/test/java/io/grpc/xds/internal/MetricReportUtilsTest.java new file mode 100644 index 00000000000..9d7a3910216 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/MetricReportUtilsTest.java @@ -0,0 +1,115 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import io.grpc.services.InternalCallMetricRecorder; +import io.grpc.services.MetricReport; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.OptionalDouble; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link MetricReportUtils}. */ +@RunWith(JUnit4.class) +public class MetricReportUtilsTest { + + @Test + public void getMetricValue_cpuUtilization() { + MetricReport report = createMetricReport(0.5, 0.1, 0.2, 10.0, 5.0, Collections.emptyMap()); + MetricReportUtils.ParsedMetricName parsed = + MetricReportUtils.ParsedMetricName.parse("cpu_utilization"); + OptionalDouble result = MetricReportUtils.getMetricValue(report, parsed); + assertTrue(result.isPresent()); + assertEquals(0.5, result.getAsDouble(), 0.0001); + } + + @Test + public void getMetricValue_applicationUtilization() { + MetricReport report = createMetricReport(0.5, 0.1, 0.2, 10.0, 5.0, Collections.emptyMap()); + MetricReportUtils.ParsedMetricName parsed = + MetricReportUtils.ParsedMetricName.parse("application_utilization"); + OptionalDouble result = MetricReportUtils.getMetricValue(report, parsed); + assertTrue(result.isPresent()); + assertEquals(0.1, result.getAsDouble(), 0.0001); + } + + @Test + public void getMetricValue_memUtilization() { + MetricReport report = createMetricReport(0.5, 0.1, 0.2, 10.0, 5.0, Collections.emptyMap()); + MetricReportUtils.ParsedMetricName parsed = + MetricReportUtils.ParsedMetricName.parse("mem_utilization"); + OptionalDouble result = MetricReportUtils.getMetricValue(report, parsed); + assertTrue(result.isPresent()); + assertEquals(0.2, result.getAsDouble(), 0.0001); + } + + @Test + public void getMetricValue_utilizationMetric() { + Map utilizationMetrics = new HashMap<>(); + utilizationMetrics.put("foo", 1.23); + MetricReport report = InternalCallMetricRecorder.createMetricReport( + 0, 0, 0, 0, 0, Collections.emptyMap(), utilizationMetrics, Collections.emptyMap()); + + MetricReportUtils.ParsedMetricName parsed = + MetricReportUtils.ParsedMetricName.parse("utilization.foo"); + OptionalDouble result = MetricReportUtils.getMetricValue(report, parsed); + assertTrue(result.isPresent()); + assertEquals(1.23, result.getAsDouble(), 0.0001); + + MetricReportUtils.ParsedMetricName bad = + MetricReportUtils.ParsedMetricName.parse("utilization.bar"); + assertFalse(MetricReportUtils.getMetricValue(report, bad).isPresent()); + } + + @Test + public void getMetricValue_namedMetric() { + Map namedMetrics = new HashMap<>(); + namedMetrics.put("foo", 7.89); + MetricReport report = createMetricReport(0, 0, 0, 0, 0, namedMetrics); + + MetricReportUtils.ParsedMetricName parsed = + MetricReportUtils.ParsedMetricName.parse("named_metrics.foo"); + OptionalDouble result = MetricReportUtils.getMetricValue(report, parsed); + assertTrue(result.isPresent()); + assertEquals(7.89, result.getAsDouble(), 0.0001); + + MetricReportUtils.ParsedMetricName bad = + MetricReportUtils.ParsedMetricName.parse("named_metrics.bar"); + assertFalse(MetricReportUtils.getMetricValue(report, bad).isPresent()); + } + + @Test + public void getMetricValue_invalidMetric() { + MetricReport report = createMetricReport(0.5, 0.1, 0.2, 10.0, 5.0, Collections.emptyMap()); + MetricReportUtils.ParsedMetricName invalid = + MetricReportUtils.ParsedMetricName.parse("invalid_metric"); + assertFalse(MetricReportUtils.getMetricValue(report, invalid).isPresent()); + } + + private MetricReport createMetricReport(double cpu, double app, double mem, double qps, + double eps, Map namedMetrics) { + return InternalCallMetricRecorder.createMetricReport( + cpu, app, mem, qps, eps, Collections.emptyMap(), Collections.emptyMap(), namedMetrics); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/ProtobufJsonConverterTest.java b/xds/src/test/java/io/grpc/xds/internal/ProtobufJsonConverterTest.java new file mode 100644 index 00000000000..86f9be4dda8 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/ProtobufJsonConverterTest.java @@ -0,0 +1,83 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertEquals; + +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.ListValue; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ProtobufJsonConverterTest { + + @Test + public void testEmptyStruct() { + Struct emptyStruct = Struct.newBuilder().build(); + Map result = ProtobufJsonConverter.convertToJson(emptyStruct); + assertThat(result).isEmpty(); + } + + @Test + public void testStructWithValues() { + Struct struct = Struct.newBuilder() + .putFields("stringKey", Value.newBuilder().setStringValue("stringValue").build()) + .putFields("numberKey", Value.newBuilder().setNumberValue(123.45).build()) + .putFields("boolKey", Value.newBuilder().setBoolValue(true).build()) + .putFields("nullKey", Value.newBuilder().setNullValueValue(0).build()) + .putFields("structKey", Value.newBuilder() + .setStructValue(Struct.newBuilder() + .putFields("nestedKey", Value.newBuilder().setStringValue("nestedValue").build()) + .build()) + .build()) + .putFields("listKey", Value.newBuilder() + .setListValue(ListValue.newBuilder() + .addValues(Value.newBuilder().setNumberValue(1).build()) + .addValues(Value.newBuilder().setStringValue("two").build()) + .addValues(Value.newBuilder().setBoolValue(false).build()) + .build()) + .build()) + .build(); + + Map result = ProtobufJsonConverter.convertToJson(struct); + + Map goldenResult = new HashMap<>(); + goldenResult.put("stringKey", "stringValue"); + goldenResult.put("numberKey", 123.45); + goldenResult.put("boolKey", true); + goldenResult.put("nullKey", null); + goldenResult.put("structKey", ImmutableMap.of("nestedKey", "nestedValue")); + goldenResult.put("listKey", Arrays.asList(1.0, "two", false)); + + assertEquals(goldenResult, result); + } + + @Test(expected = IllegalArgumentException.class) + public void testUnknownValueType() { + Value unknownValue = Value.newBuilder().build(); // Default instance with no kind case set. + ProtobufJsonConverter.convertToJson( + Struct.newBuilder().putFields("unknownKey", unknownValue).build()); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/grpcservice/HeaderValueTest.java b/xds/src/test/java/io/grpc/xds/internal/grpcservice/HeaderValueTest.java new file mode 100644 index 00000000000..b55e6ae76f7 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/grpcservice/HeaderValueTest.java @@ -0,0 +1,49 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.grpcservice; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.protobuf.ByteString; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class HeaderValueTest { + + @Test + public void create_withStringValue_success() { + HeaderValue headerValue = HeaderValue.create("key1", "value1"); + assertThat(headerValue.key()).isEqualTo("key1"); + assertThat(headerValue.value().isPresent()).isTrue(); + assertThat(headerValue.value().get()).isEqualTo("value1"); + assertThat(headerValue.rawValue().isPresent()).isFalse(); + } + + @Test + public void create_withByteStringValue_success() { + ByteString rawValue = ByteString.copyFromUtf8("raw_value"); + HeaderValue headerValue = HeaderValue.create("key2", rawValue); + assertThat(headerValue.key()).isEqualTo("key2"); + assertThat(headerValue.rawValue().isPresent()).isTrue(); + assertThat(headerValue.rawValue().get()).isEqualTo(rawValue); + assertThat(headerValue.value().isPresent()).isFalse(); + } + + +} diff --git a/xds/src/test/java/io/grpc/xds/internal/grpcservice/HeaderValueValidationUtilsTest.java b/xds/src/test/java/io/grpc/xds/internal/grpcservice/HeaderValueValidationUtilsTest.java new file mode 100644 index 00000000000..c4658f3f305 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/grpcservice/HeaderValueValidationUtilsTest.java @@ -0,0 +1,87 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.grpcservice; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.protobuf.ByteString; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Unit tests for {@link HeaderValueValidationUtils}. + */ +@RunWith(JUnit4.class) +public class HeaderValueValidationUtilsTest { + + @Test + public void isDisallowed_string_emptyKey() { + assertThat(HeaderValueValidationUtils.isDisallowed("")).isTrue(); + } + + @Test + public void isDisallowed_string_tooLongKey() { + String longKey = new String(new char[16385]).replace('\0', 'a'); + assertThat(HeaderValueValidationUtils.isDisallowed(longKey)).isTrue(); + } + + @Test + public void isDisallowed_string_notLowercase() { + assertThat(HeaderValueValidationUtils.isDisallowed("Content-Type")).isTrue(); + } + + @Test + public void isDisallowed_string_grpcPrefix() { + assertThat(HeaderValueValidationUtils.isDisallowed("grpc-timeout")).isTrue(); + } + + @Test + public void isDisallowed_string_systemHeader_colon() { + assertThat(HeaderValueValidationUtils.isDisallowed(":authority")).isTrue(); + } + + @Test + public void isDisallowed_string_systemHeader_host() { + assertThat(HeaderValueValidationUtils.isDisallowed("host")).isTrue(); + } + + @Test + public void isDisallowed_string_valid() { + assertThat(HeaderValueValidationUtils.isDisallowed("content-type")).isFalse(); + } + + @Test + public void isDisallowed_headerValue_tooLongValue() { + String longValue = new String(new char[16385]).replace('\0', 'v'); + HeaderValue header = HeaderValue.create("content-type", longValue); + assertThat(HeaderValueValidationUtils.isDisallowed(header)).isTrue(); + } + + @Test + public void isDisallowed_headerValue_tooLongRawValue() { + ByteString longRawValue = ByteString.copyFrom(new byte[16385]); + HeaderValue header = HeaderValue.create("content-type", longRawValue); + assertThat(HeaderValueValidationUtils.isDisallowed(header)).isTrue(); + } + + @Test + public void isDisallowed_headerValue_valid() { + HeaderValue header = HeaderValue.create("content-type", "application/grpc"); + assertThat(HeaderValueValidationUtils.isDisallowed(header)).isFalse(); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationFilterTest.java b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationFilterTest.java new file mode 100644 index 00000000000..f07997a6244 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationFilterTest.java @@ -0,0 +1,244 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.common.base.Strings; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.ByteString; +import com.google.re2j.Pattern; +import io.grpc.xds.internal.headermutations.HeaderValueOption.HeaderAppendAction; +import java.util.Optional; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class HeaderMutationFilterTest { + + private static final int MAX_HEADER_LENGTH = 16384; + + private static HeaderValueOption header(String key, ByteString value) { + return HeaderValueOption.create(io.grpc.xds.internal.grpcservice.HeaderValue.create(key, value), + HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD, false); + } + + private static HeaderValueOption header(String key, String value) { + return HeaderValueOption.create(io.grpc.xds.internal.grpcservice.HeaderValue.create(key, value), + HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD, false); + } + + @Test + public void filter_validationRules_dropsInvalidHeaders() throws Exception { + HeaderMutationFilter filter = new HeaderMutationFilter(Optional.empty()); + @SuppressWarnings("InlineMeInliner") + String longString = Strings.repeat("a", MAX_HEADER_LENGTH + 1); + ByteString longBytes = ByteString.copyFrom(new byte[MAX_HEADER_LENGTH + 1]); + + HeaderMutations mutations = HeaderMutations.create( + ImmutableList.of( + header("add-key", "add-value"), header(":authority", "new-authority"), + header("host", "new-host"), header(":scheme", "https"), header(":method", "PUT"), + header("resp-add-key", "resp-add-value"), header(":scheme", "https"), + header(":path", "/new-path"), header(":grpc-trace-bin", "binary-value"), + header(":alt-svc", "h3=:443"), header("user-agent", "new-agent"), + header("Valid-Key", "value"), header("", "value"), header(longString, "value"), + header("long-value-key", longString), header("long-bin-key-bin", longBytes), + header("grpc-timeout", "10S"), header("valid-key-lower", "value")), + ImmutableList.of("remove-key", "host", ":authority", ":scheme", ":method", ":foo", ":bar", + "Valid-Key", "", longString, "grpc-timeout", "UPPER-REMOVE", "lower-remove")); + + HeaderMutations filtered = filter.filter(mutations); + + assertThat(filtered.headersToRemove()).containsExactly("remove-key", "lower-remove"); + assertThat(filtered.headers()).containsExactly( + header("add-key", "add-value"), header("resp-add-key", "resp-add-value"), + header("user-agent", "new-agent"), header("valid-key-lower", "value")); + } + + @Test + public void filter_validationRules_throwsOnInvalidHeaders() throws Exception { + HeaderMutationRulesConfig rules = + HeaderMutationRulesConfig.builder().disallowIsError(true).build(); + HeaderMutationFilter filter = new HeaderMutationFilter(Optional.of(rules)); + @SuppressWarnings("InlineMeInliner") + String longString = Strings.repeat("a", MAX_HEADER_LENGTH + 1); + + // Test system headers modification + assertThrows(HeaderMutationDisallowedException.class, () -> filter.filter(HeaderMutations + .create( + ImmutableList.of(header(":path", "/new-path")), ImmutableList.of()))); + + // Test system headers removal + assertThrows(HeaderMutationDisallowedException.class, + () -> filter.filter(HeaderMutations.create( + ImmutableList.of(), ImmutableList.of(":path")))); + + // Test uppercase header modification + assertThrows(HeaderMutationDisallowedException.class, () -> filter.filter(HeaderMutations + .create( + ImmutableList.of(header("Valid-Key", "value")), ImmutableList.of()))); + + // Test uppercase header removal + assertThrows(HeaderMutationDisallowedException.class, () -> filter + .filter(HeaderMutations.create( + ImmutableList.of(), ImmutableList.of("UPPER-REMOVE")))); + + // Test empty header + assertThrows(HeaderMutationDisallowedException.class, () -> filter + .filter(HeaderMutations.create( + ImmutableList.of(header("", "value")), ImmutableList.of()))); + + // Test long header key + assertThrows(HeaderMutationDisallowedException.class, () -> filter + .filter(HeaderMutations.create( + ImmutableList.of(), ImmutableList.of(longString)))); + } + + + @Test + public void filter_mutationRules_disallowAll_dropsAll() throws Exception { + HeaderMutationRulesConfig rules = HeaderMutationRulesConfig.builder().disallowAll(true).build(); + HeaderMutationFilter filter = new HeaderMutationFilter(Optional.of(rules)); + HeaderMutations mutations = HeaderMutations.create( + ImmutableList.of(header("add-key", "add-value"), header("resp-add-key", "resp-add-value")), + ImmutableList.of("remove-key")); + + HeaderMutations filtered = filter.filter(mutations); + + assertThat(filtered.headers()).isEmpty(); + assertThat(filtered.headersToRemove()).isEmpty(); + } + + @Test + public void filter_mutationRules_disallowAll_throws() throws Exception { + HeaderMutationRulesConfig rules = + HeaderMutationRulesConfig.builder().disallowAll(true).disallowIsError(true).build(); + HeaderMutationFilter filter = new HeaderMutationFilter(Optional.of(rules)); + + // Test add header + assertThrows(HeaderMutationDisallowedException.class, () -> filter.filter(HeaderMutations + .create( + ImmutableList.of(header("add-key", "add-value")), ImmutableList.of()))); + + // Test remove header + assertThrows(HeaderMutationDisallowedException.class, () -> filter + .filter(HeaderMutations.create( + ImmutableList.of(), ImmutableList.of("remove-key")))); + + // Test response header + assertThrows(HeaderMutationDisallowedException.class, () -> filter.filter(HeaderMutations + .create( + ImmutableList.of(header("resp-add-key", "resp-add-value")), ImmutableList.of()))); + } + + + @Test + public void filter_mutationRules_disallowExpression_dropsMatching() throws Exception { + HeaderMutationRulesConfig rules = HeaderMutationRulesConfig.builder() + .disallowExpression(Pattern.compile("^x-private-.*")).build(); + HeaderMutationFilter filter = new HeaderMutationFilter(Optional.of(rules)); + HeaderMutations mutations = HeaderMutations.create( + ImmutableList.of(header("x-public", "value"), header("x-private-key", "value"), + header("x-public-resp", "value"), header("x-private-resp", "value")), + ImmutableList.of("x-public-remove", "x-private-remove")); + + HeaderMutations filtered = filter.filter(mutations); + + assertThat(filtered.headersToRemove()).containsExactly("x-public-remove"); + assertThat(filtered.headers()).containsExactly(header("x-public", "value"), + header("x-public-resp", "value")); + } + + @Test + public void filter_mutationRules_disallowExpression_throws() throws Exception { + HeaderMutationRulesConfig rules = HeaderMutationRulesConfig.builder() + .disallowExpression(Pattern.compile("^x-private-.*")).disallowIsError(true).build(); + HeaderMutationFilter filter = new HeaderMutationFilter(Optional.of(rules)); + + // Test disallowed key modification + assertThrows(HeaderMutationDisallowedException.class, () -> filter.filter(HeaderMutations + .create( + ImmutableList.of(header("x-private-key", "value")), ImmutableList.of()))); + + // Test disallowed key removal + assertThrows(HeaderMutationDisallowedException.class, () -> filter + .filter(HeaderMutations.create( + ImmutableList.of(), ImmutableList.of("x-private-remove")))); + } + + + @Test + public void filter_mutationRules_precedence() throws Exception { + HeaderMutationRulesConfig rules = HeaderMutationRulesConfig.builder() + .disallowAll(true) + .allowExpression(Pattern.compile("^x-allowed-.*")) + .disallowExpression(Pattern.compile("^x-allowed-but-disallowed-.*")) + .build(); + HeaderMutationFilter filter = new HeaderMutationFilter(Optional.of(rules)); + + // Case 1: allowExpression overrides disallowAll + HeaderMutations mutations1 = HeaderMutations.create( + ImmutableList.of(header("x-allowed-key", "value"), header("not-allowed", "value")), + ImmutableList.of("x-allowed-remove", "not-allowed-remove")); + HeaderMutations filtered1 = filter.filter(mutations1); + assertThat(filtered1.headersToRemove()).containsExactly("x-allowed-remove"); + assertThat(filtered1.headers()).containsExactly(header("x-allowed-key", "value")); + + // Case 2: disallowExpression overrides allowExpression + HeaderMutations mutations2 = HeaderMutations.create( + ImmutableList.of(header("x-allowed-but-disallowed-key", "value")), + ImmutableList.of("x-allowed-but-disallowed-remove")); + HeaderMutations filtered2 = filter.filter(mutations2); + assertThat(filtered2.headers()).isEmpty(); + assertThat(filtered2.headersToRemove()).isEmpty(); + } + + @Test + public void filter_mutationRules_precedence_throws() throws Exception { + // Case 1: allowExpression overrides disallowAll (does not throw) + HeaderMutationRulesConfig rules1 = HeaderMutationRulesConfig.builder() + .disallowAll(true) + .allowExpression(Pattern.compile("^x-allowed-.*")) + .disallowIsError(true) + .build(); + HeaderMutationFilter filter1 = new HeaderMutationFilter(Optional.of(rules1)); + HeaderMutations mutations1 = HeaderMutations.create( + ImmutableList.of(header("x-allowed-key", "value")), ImmutableList.of("x-allowed-remove")); + HeaderMutations filtered1 = filter1.filter(mutations1); + assertThat(filtered1.headersToRemove()).containsExactly("x-allowed-remove"); + assertThat(filtered1.headers()).containsExactly(header("x-allowed-key", "value")); + + // Case 2: disallowExpression overrides allowExpression (throws) + HeaderMutationRulesConfig rules2 = HeaderMutationRulesConfig.builder() + .allowExpression(Pattern.compile("^x-allowed-.*")) + .disallowExpression(Pattern.compile("^x-allowed-but-disallowed-.*")) + .disallowIsError(true) + .build(); + HeaderMutationFilter filter2 = new HeaderMutationFilter(Optional.of(rules2)); + assertThrows(HeaderMutationDisallowedException.class, + () -> filter2.filter(HeaderMutations.create( + ImmutableList.of(header("x-allowed-but-disallowed-key", "value")), + ImmutableList.of()))); + + assertThrows(HeaderMutationDisallowedException.class, () -> filter2.filter(HeaderMutations + .create( + ImmutableList.of(), ImmutableList.of("x-allowed-but-disallowed-remove")))); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesConfigTest.java b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesConfigTest.java new file mode 100644 index 00000000000..9f5cb75460f --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesConfigTest.java @@ -0,0 +1,84 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import com.google.re2j.Pattern; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class HeaderMutationRulesConfigTest { + @Test + public void testBuilderDefaultValues() { + HeaderMutationRulesConfig config = HeaderMutationRulesConfig.builder().build(); + assertFalse(config.disallowAll()); + assertFalse(config.disallowIsError()); + assertThat(config.allowExpression()).isEmpty(); + assertThat(config.disallowExpression()).isEmpty(); + } + + @Test + public void testBuilder_setDisallowAll() { + HeaderMutationRulesConfig config = + HeaderMutationRulesConfig.builder().disallowAll(true).build(); + assertTrue(config.disallowAll()); + } + + @Test + public void testBuilder_setDisallowIsError() { + HeaderMutationRulesConfig config = + HeaderMutationRulesConfig.builder().disallowIsError(true).build(); + assertTrue(config.disallowIsError()); + } + + @Test + public void testBuilder_setAllowExpression() { + Pattern pattern = Pattern.compile("allow.*"); + HeaderMutationRulesConfig config = + HeaderMutationRulesConfig.builder().allowExpression(pattern).build(); + assertThat(config.allowExpression()).hasValue(pattern); + } + + @Test + public void testBuilder_setDisallowExpression() { + Pattern pattern = Pattern.compile("disallow.*"); + HeaderMutationRulesConfig config = + HeaderMutationRulesConfig.builder().disallowExpression(pattern).build(); + assertThat(config.disallowExpression()).hasValue(pattern); + } + + @Test + public void testBuilder_setAll() { + Pattern allowPattern = Pattern.compile("allow.*"); + Pattern disallowPattern = Pattern.compile("disallow.*"); + HeaderMutationRulesConfig config = HeaderMutationRulesConfig.builder() + .disallowAll(true) + .disallowIsError(true) + .allowExpression(allowPattern) + .disallowExpression(disallowPattern) + .build(); + assertTrue(config.disallowAll()); + assertTrue(config.disallowIsError()); + assertThat(config.allowExpression()).hasValue(allowPattern); + assertThat(config.disallowExpression()).hasValue(disallowPattern); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesParserTest.java b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesParserTest.java new file mode 100644 index 00000000000..e880c197450 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesParserTest.java @@ -0,0 +1,90 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.protobuf.BoolValue; +import io.envoyproxy.envoy.config.common.mutation_rules.v3.HeaderMutationRules; +import io.envoyproxy.envoy.type.matcher.v3.RegexMatcher; +import io.grpc.xds.internal.headermutations.HeaderMutationRulesParseException; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class HeaderMutationRulesParserTest { + + @Test + public void parse_protoWithAllFields_success() throws Exception { + HeaderMutationRules proto = HeaderMutationRules.newBuilder() + .setAllowExpression(RegexMatcher.newBuilder().setRegex("allow-.*")) + .setDisallowExpression(RegexMatcher.newBuilder().setRegex("disallow-.*")) + .setDisallowAll(BoolValue.newBuilder().setValue(true).build()) + .setDisallowIsError(BoolValue.newBuilder().setValue(true).build()) + .build(); + + HeaderMutationRulesConfig config = HeaderMutationRulesParser.parse(proto); + + assertThat(config.allowExpression().isPresent()).isTrue(); + assertThat(config.allowExpression().get().pattern()).isEqualTo("allow-.*"); + + assertThat(config.disallowExpression().isPresent()).isTrue(); + assertThat(config.disallowExpression().get().pattern()).isEqualTo("disallow-.*"); + + assertThat(config.disallowAll()).isTrue(); + assertThat(config.disallowIsError()).isTrue(); + } + + @Test + public void parse_protoWithNoExpressions_success() throws Exception { + HeaderMutationRules proto = HeaderMutationRules.newBuilder().build(); + + HeaderMutationRulesConfig config = HeaderMutationRulesParser.parse(proto); + + assertThat(config.allowExpression().isPresent()).isFalse(); + assertThat(config.disallowExpression().isPresent()).isFalse(); + assertThat(config.disallowAll()).isFalse(); + assertThat(config.disallowIsError()).isFalse(); + } + + @Test + public void parse_invalidRegexAllowExpression_throwsHeaderMutationRulesParseException() { + HeaderMutationRules proto = HeaderMutationRules.newBuilder() + .setAllowExpression(RegexMatcher.newBuilder().setRegex("allow-[")) + .build(); + + HeaderMutationRulesParseException exception = assertThrows( + HeaderMutationRulesParseException.class, () -> HeaderMutationRulesParser.parse(proto)); + + assertThat(exception).hasMessageThat().contains("Invalid regex pattern for allow_expression"); + } + + @Test + public void parse_invalidRegexDisallowExpression_throwsHeaderMutationRulesParseException() { + HeaderMutationRules proto = HeaderMutationRules.newBuilder() + .setDisallowExpression(RegexMatcher.newBuilder().setRegex("disallow-[")) + .build(); + + HeaderMutationRulesParseException exception = assertThrows( + HeaderMutationRulesParseException.class, () -> HeaderMutationRulesParser.parse(proto)); + + assertThat(exception).hasMessageThat() + .contains("Invalid regex pattern for disallow_expression"); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationsTest.java b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationsTest.java new file mode 100644 index 00000000000..ef7f22b7ac8 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationsTest.java @@ -0,0 +1,40 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.common.collect.ImmutableList; +import io.grpc.xds.internal.grpcservice.HeaderValue; +import io.grpc.xds.internal.headermutations.HeaderValueOption.HeaderAppendAction; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class HeaderMutationsTest { + @Test + public void testCreate() { + HeaderValueOption header = HeaderValueOption.create( + HeaderValue.create("key", "value"), + HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD, false); + HeaderMutations mutations = HeaderMutations.create( + ImmutableList.of(header), ImmutableList.of("remove-key")); + assertThat(mutations.headers()).containsExactly(header); + assertThat(mutations.headersToRemove()).containsExactly("remove-key"); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutatorTest.java b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutatorTest.java new file mode 100644 index 00000000000..b6806760f9b --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutatorTest.java @@ -0,0 +1,315 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.common.collect.ImmutableList; +import com.google.common.testing.TestLogHandler; +import com.google.protobuf.ByteString; +import io.grpc.Metadata; +import io.grpc.xds.internal.grpcservice.HeaderValue; +import io.grpc.xds.internal.headermutations.HeaderMutations; +import io.grpc.xds.internal.headermutations.HeaderValueOption.HeaderAppendAction; +import java.util.logging.Level; +import java.util.logging.Logger; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class HeaderMutatorTest { + + private static final Metadata.Key BINARY_KEY = + Metadata.Key.of("some-key-bin", Metadata.BINARY_BYTE_MARSHALLER); + private static final Metadata.Key APPEND_KEY = + Metadata.Key.of("append-key", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key ADD_KEY = + Metadata.Key.of("add-key", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key OVERWRITE_KEY = + Metadata.Key.of("overwrite-key", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key REMOVE_KEY = + Metadata.Key.of("remove-key", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key NEW_ADD_KEY = + Metadata.Key.of("new-add-key", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key NEW_OVERWRITE_KEY = + Metadata.Key.of("new-overwrite-key", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key OVERWRITE_IF_EXISTS_KEY = + Metadata.Key.of("overwrite-if-exists-key", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key OVERWRITE_IF_EXISTS_ABSENT_KEY = + Metadata.Key.of("overwrite-if-exists-absent-key", Metadata.ASCII_STRING_MARSHALLER); + + private final HeaderMutator headerMutator = HeaderMutator.create(); + + private static final TestLogHandler logHandler = new TestLogHandler(); + private static final Logger logger = Logger.getLogger(HeaderMutator.class.getName()); + + @Before + public void setUp() { + logHandler.clear(); + logger.addHandler(logHandler); + logger.setLevel(Level.WARNING); + } + + @After + public void tearDown() { + logger.removeHandler(logHandler); + } + + private static HeaderValueOption header(String key, String value, HeaderAppendAction action) { + return HeaderValueOption.create(HeaderValue.create(key, value), action, false); + } + + @Test + public void applyMutations_asciiHeaders() { + Metadata headers = new Metadata(); + headers.put(APPEND_KEY, "append-value-1"); + headers.put(ADD_KEY, "add-value-original"); + headers.put(OVERWRITE_KEY, "overwrite-value-original"); + headers.put(REMOVE_KEY, "remove-value-original"); + headers.put(OVERWRITE_IF_EXISTS_KEY, "original-value"); + + HeaderMutations mutations = + HeaderMutations.create( + ImmutableList.of( + header( + APPEND_KEY.name(), + "append-value-2", + HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD), + header(ADD_KEY.name(), "add-value-new", HeaderAppendAction.ADD_IF_ABSENT), + header(NEW_ADD_KEY.name(), "new-add-value", HeaderAppendAction.ADD_IF_ABSENT), + header( + OVERWRITE_KEY.name(), + "overwrite-value-new", + HeaderAppendAction.OVERWRITE_IF_EXISTS_OR_ADD), + header( + NEW_OVERWRITE_KEY.name(), + "new-overwrite-value", + HeaderAppendAction.OVERWRITE_IF_EXISTS_OR_ADD), + header( + OVERWRITE_IF_EXISTS_KEY.name(), + "new-value", + HeaderAppendAction.OVERWRITE_IF_EXISTS), + header( + OVERWRITE_IF_EXISTS_ABSENT_KEY.name(), + "new-value", + HeaderAppendAction.OVERWRITE_IF_EXISTS)), + ImmutableList.of(REMOVE_KEY.name())); + + headerMutator.applyMutations(mutations, headers); + + assertThat(headers.getAll(APPEND_KEY)).containsExactly("append-value-1", "append-value-2"); + assertThat(headers.get(ADD_KEY)).isEqualTo("add-value-original"); + assertThat(headers.get(NEW_ADD_KEY)).isEqualTo("new-add-value"); + assertThat(headers.get(OVERWRITE_KEY)).isEqualTo("overwrite-value-new"); + assertThat(headers.get(NEW_OVERWRITE_KEY)).isEqualTo("new-overwrite-value"); + assertThat(headers.containsKey(REMOVE_KEY)).isFalse(); + assertThat(headers.get(OVERWRITE_IF_EXISTS_KEY)).isEqualTo("new-value"); + assertThat(headers.containsKey(OVERWRITE_IF_EXISTS_ABSENT_KEY)).isFalse(); + } + + @Test + public void applyMutations_removalHasPriority() { + Metadata headers = new Metadata(); + headers.put(REMOVE_KEY, "value"); + HeaderMutations mutations = + HeaderMutations.create( + ImmutableList.of( + header( + REMOVE_KEY.name(), "new-value", HeaderAppendAction.OVERWRITE_IF_EXISTS_OR_ADD)), + ImmutableList.of(REMOVE_KEY.name())); + + headerMutator.applyMutations(mutations, headers); + + assertThat(headers.containsKey(REMOVE_KEY)).isFalse(); + } + + @Test + public void applyMutations_binary() { + Metadata headers = new Metadata(); + byte[] value = new byte[] {1, 2, 3}; + HeaderValueOption option = + HeaderValueOption.create( + HeaderValue.create(BINARY_KEY.name(), ByteString.copyFrom(value)), + HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD, + false); + headerMutator.applyMutations( + HeaderMutations.create(ImmutableList.of(option), ImmutableList.of()), headers); + assertThat(headers.get(BINARY_KEY)).isEqualTo(value); + } + + @Test + public void applyResponseMutations_asciiHeaders() { + Metadata headers = new Metadata(); + headers.put(APPEND_KEY, "append-value-1"); + headers.put(ADD_KEY, "add-value-original"); + headers.put(OVERWRITE_KEY, "overwrite-value-original"); + + HeaderMutations mutations = + HeaderMutations.create( + ImmutableList.of( + header( + APPEND_KEY.name(), + "append-value-2", + HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD), + header(ADD_KEY.name(), "add-value-new", HeaderAppendAction.ADD_IF_ABSENT), + header(NEW_ADD_KEY.name(), "new-add-value", HeaderAppendAction.ADD_IF_ABSENT), + header( + OVERWRITE_KEY.name(), + "overwrite-value-new", + HeaderAppendAction.OVERWRITE_IF_EXISTS_OR_ADD), + header( + NEW_OVERWRITE_KEY.name(), + "new-overwrite-value", + HeaderAppendAction.OVERWRITE_IF_EXISTS_OR_ADD)), ImmutableList.of()); + + headerMutator.applyMutations(mutations, headers); + + assertThat(headers.getAll(APPEND_KEY)).containsExactly("append-value-1", "append-value-2"); + assertThat(headers.get(ADD_KEY)).isEqualTo("add-value-original"); + assertThat(headers.get(NEW_ADD_KEY)).isEqualTo("new-add-value"); + assertThat(headers.get(OVERWRITE_KEY)).isEqualTo("overwrite-value-new"); + assertThat(headers.get(NEW_OVERWRITE_KEY)).isEqualTo("new-overwrite-value"); + } + + @Test + public void applyResponseMutations_binary() { + Metadata headers = new Metadata(); + byte[] value = new byte[] {1, 2, 3}; + HeaderValueOption option = + HeaderValueOption.create( + HeaderValue.create(BINARY_KEY.name(), ByteString.copyFrom(value)), + HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD, + false); + headerMutator.applyMutations( + HeaderMutations.create(ImmutableList.of(option), ImmutableList.of()), headers); + assertThat(headers.get(BINARY_KEY)).isEqualTo(value); + } + + @Test + public void applyMutations_keepEmptyValue() { + Metadata headers = new Metadata(); + headers.put(APPEND_KEY, "existing-value"); + headers.put(OVERWRITE_KEY, "existing-value"); + headers.put(OVERWRITE_IF_EXISTS_KEY, "existing-value"); + + HeaderMutations mutations = + HeaderMutations.create( + ImmutableList.of( + header(NEW_ADD_KEY.name(), "", HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD), + header(APPEND_KEY.name(), "", HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD), + header(OVERWRITE_KEY.name(), "", HeaderAppendAction.OVERWRITE_IF_EXISTS_OR_ADD), + header(ADD_KEY.name(), "", HeaderAppendAction.ADD_IF_ABSENT), + header(OVERWRITE_IF_EXISTS_KEY.name(), "", HeaderAppendAction.OVERWRITE_IF_EXISTS), + HeaderValueOption.create( + HeaderValue.create("keep-empty-key", ""), + HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD, + true), + HeaderValueOption.create( + HeaderValue.create("keep-empty-overwrite-key", ""), + HeaderAppendAction.OVERWRITE_IF_EXISTS_OR_ADD, + true), + HeaderValueOption.create( + HeaderValue.create("keep-empty-bin-key-bin", ByteString.EMPTY), + HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD, true), + HeaderValueOption.create( + HeaderValue.create("ignore-empty-bin-key-bin", ByteString.EMPTY), + HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD, false), + HeaderValueOption.create( + HeaderValue.create("overwrite-empty-bin-key-bin", ByteString.EMPTY), + HeaderAppendAction.OVERWRITE_IF_EXISTS_OR_ADD, false)), + ImmutableList.of()); + + headers.put( + Metadata.Key.of("keep-empty-overwrite-key", Metadata.ASCII_STRING_MARSHALLER), "old"); + + Metadata.Key overwriteEmptyBinKey = + Metadata.Key.of("overwrite-empty-bin-key-bin", Metadata.BINARY_BYTE_MARSHALLER); + byte[] originalBinValue = new byte[] {1, 2, 3}; + headers.put(overwriteEmptyBinKey, originalBinValue); + + headerMutator.applyMutations(mutations, headers); + + assertThat(headers.containsKey(NEW_ADD_KEY)).isFalse(); + assertThat(headers.getAll(APPEND_KEY)).containsExactly("existing-value"); + assertThat(headers.get(OVERWRITE_KEY)).isEqualTo("existing-value"); + assertThat(headers.containsKey(ADD_KEY)).isFalse(); + assertThat(headers.get(OVERWRITE_IF_EXISTS_KEY)).isEqualTo("existing-value"); + + Metadata.Key keepEmptyKey = + Metadata.Key.of("keep-empty-key", Metadata.ASCII_STRING_MARSHALLER); + Metadata.Key keepEmptyOverwriteKey = + Metadata.Key.of("keep-empty-overwrite-key", Metadata.ASCII_STRING_MARSHALLER); + + assertThat(headers.containsKey(keepEmptyKey)).isTrue(); + assertThat(headers.get(keepEmptyKey)).isEqualTo(""); + assertThat(headers.containsKey(keepEmptyOverwriteKey)).isTrue(); + assertThat(headers.get(keepEmptyOverwriteKey)).isEqualTo(""); + + Metadata.Key keepEmptyBinKey = + Metadata.Key.of("keep-empty-bin-key-bin", Metadata.BINARY_BYTE_MARSHALLER); + Metadata.Key ignoreEmptyBinKey = + Metadata.Key.of("ignore-empty-bin-key-bin", Metadata.BINARY_BYTE_MARSHALLER); + + assertThat(headers.containsKey(keepEmptyBinKey)).isTrue(); + assertThat(headers.get(keepEmptyBinKey)).isEqualTo(new byte[0]); + assertThat(headers.containsKey(ignoreEmptyBinKey)).isFalse(); + assertThat(headers.get(overwriteEmptyBinKey)).isEqualTo(originalBinValue); + } + + @Test + public void applyMutations_binaryRemoval() { + Metadata headers = new Metadata(); + byte[] value = new byte[] {1, 2, 3}; + headers.put(BINARY_KEY, value); + HeaderMutations mutations = + HeaderMutations.create(ImmutableList.of(), ImmutableList.of(BINARY_KEY.name())); + + headerMutator.applyMutations(mutations, headers); + + assertThat(headers.containsKey(BINARY_KEY)).isFalse(); + } + + @Test + public void applyMutations_stringValueWithBinaryKey_ignored() { + Metadata headers = new Metadata(); + HeaderValueOption option = HeaderValueOption.create(HeaderValue.create("some-key-bin", "value"), + HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD, false); + + headerMutator.applyMutations( + HeaderMutations.create(ImmutableList.of(option), ImmutableList.of()), headers); + + Metadata.Key key = Metadata.Key.of("some-key-bin", Metadata.BINARY_BYTE_MARSHALLER); + assertThat(headers.containsKey(key)).isFalse(); + } + + @Test + public void applyMutations_binaryValueWithAsciiKey_ignored() { + Metadata headers = new Metadata(); + HeaderValueOption option = HeaderValueOption.create( + HeaderValue.create("some-key", ByteString.copyFrom(new byte[] {1})), + HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD, false); + + headerMutator.applyMutations( + HeaderMutations.create(ImmutableList.of(option), ImmutableList.of()), headers); + + Metadata.Key key = Metadata.Key.of("some-key", Metadata.ASCII_STRING_MARSHALLER); + assertThat(headers.containsKey(key)).isFalse(); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderValueOptionTest.java b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderValueOptionTest.java new file mode 100644 index 00000000000..49c43749135 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderValueOptionTest.java @@ -0,0 +1,40 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import static com.google.common.truth.Truth.assertThat; + +import io.grpc.xds.internal.grpcservice.HeaderValue; +import io.grpc.xds.internal.headermutations.HeaderValueOption.HeaderAppendAction; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class HeaderValueOptionTest { + + @Test + public void create_withAllFields_success() { + HeaderValue header = HeaderValue.create("key1", "value1"); + HeaderValueOption option = HeaderValueOption.create( + header, HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD, true); + + assertThat(option.header()).isEqualTo(header); + assertThat(option.appendAction()).isEqualTo(HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD); + assertThat(option.keepEmptyValue()).isTrue(); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/rbac/engine/GrpcAuthorizationEngineTest.java b/xds/src/test/java/io/grpc/xds/internal/rbac/engine/GrpcAuthorizationEngineTest.java index 4fb38f661e1..10287c11262 100644 --- a/xds/src/test/java/io/grpc/xds/internal/rbac/engine/GrpcAuthorizationEngineTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/rbac/engine/GrpcAuthorizationEngineTest.java @@ -16,8 +16,8 @@ package io.grpc.xds.internal.rbac.engine; -import static com.google.common.base.Charsets.US_ASCII; import static com.google.common.truth.Truth.assertThat; +import static java.nio.charset.StandardCharsets.US_ASCII; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; diff --git a/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java index 4de881c710e..a0eac581d5c 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java @@ -28,18 +28,17 @@ import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.TlsCertificate; import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; -import io.grpc.xds.CommonBootstrapperTestUtils; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.client.Bootstrapper; +import io.grpc.xds.client.CommonBootstrapperTestUtils; import io.grpc.xds.client.XdsInitializationException; import io.grpc.xds.internal.security.certprovider.CertProviderClientSslContextProviderFactory; import io.grpc.xds.internal.security.certprovider.CertificateProvider; import io.grpc.xds.internal.security.certprovider.CertificateProviderProvider; import io.grpc.xds.internal.security.certprovider.CertificateProviderRegistry; import io.grpc.xds.internal.security.certprovider.CertificateProviderStore; +import io.grpc.xds.internal.security.certprovider.IgnoreUpdatesWatcher; import io.grpc.xds.internal.security.certprovider.TestCertificateProvider; -import java.io.IOException; -import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -86,7 +85,7 @@ public void createCertProviderClientSslContextProvider() throws XdsInitializatio clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], false); // verify that bootstrapInfo is cached... sslContextProvider = clientSslContextProviderFactory.create(upstreamTlsContext); @@ -121,7 +120,7 @@ public void bothPresent_expectCertProviderClientSslContextProvider() clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); } @Test @@ -147,7 +146,7 @@ public void createCertProviderClientSslContextProvider_onlyRootCert() clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); } @Test @@ -181,7 +180,7 @@ public void createCertProviderClientSslContextProvider_withStaticContext() clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); } @Test @@ -211,8 +210,8 @@ public void createCertProviderClientSslContextProvider_2providers() clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); - verifyWatcher(sslContextProvider, watcherCaptor[1]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); + verifyWatcher(sslContextProvider, watcherCaptor[1], true); } @Test @@ -248,8 +247,8 @@ public void createNewCertProviderClientSslContextProvider_withSans() { clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); - verifyWatcher(sslContextProvider, watcherCaptor[1]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); + verifyWatcher(sslContextProvider, watcherCaptor[1], true); } @Test @@ -282,23 +281,7 @@ public void createNewCertProviderClientSslContextProvider_onlyRootCert() { clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); - } - - @Test - public void createNullCommonTlsContext_exception() throws IOException { - clientSslContextProviderFactory = - new ClientSslContextProviderFactory( - null, certProviderClientSslContextProviderFactory); - UpstreamTlsContext upstreamTlsContext = new UpstreamTlsContext(null); - try { - clientSslContextProviderFactory.create(upstreamTlsContext); - Assert.fail("no exception thrown"); - } catch (NullPointerException expected) { - assertThat(expected) - .hasMessageThat() - .isEqualTo("upstreamTlsContext should have CommonTlsContext"); - } + verifyWatcher(sslContextProvider, watcherCaptor[0], true); } static void createAndRegisterProviderProvider( @@ -328,14 +311,20 @@ public CertificateProvider answer(InvocationOnMock invocation) throws Throwable } static void verifyWatcher( - SslContextProvider sslContextProvider, CertificateProvider.DistributorWatcher watcherCaptor) { + SslContextProvider sslContextProvider, CertificateProvider.DistributorWatcher watcherCaptor, + boolean usesDelegateWatcher) { assertThat(watcherCaptor).isNotNull(); assertThat(watcherCaptor.getDownstreamWatchers()).hasSize(1); - assertThat(watcherCaptor.getDownstreamWatchers().iterator().next()) - .isSameInstanceAs(sslContextProvider); + if (usesDelegateWatcher) { + assertThat(((IgnoreUpdatesWatcher) watcherCaptor.getDownstreamWatchers().iterator().next()) + .getDelegate()) + .isSameInstanceAs(sslContextProvider); + } else { + assertThat(watcherCaptor.getDownstreamWatchers().iterator().next()) + .isSameInstanceAs(sslContextProvider); + } } - @SuppressWarnings("deprecation") static CommonTlsContext.Builder addFilenames( CommonTlsContext.Builder builder, String certChain, String privateKey, String trustCa) { TlsCertificate tlsCert = @@ -347,13 +336,10 @@ static CommonTlsContext.Builder addFilenames( CertificateValidationContext.newBuilder() .setTrustedCa(DataSource.newBuilder().setFilename(trustCa)) .build(); - CommonTlsContext.CertificateProviderInstance certificateProviderInstance = - builder.getValidationContextCertificateProviderInstance(); CommonTlsContext.CombinedCertificateValidationContext.Builder combinedBuilder = CommonTlsContext.CombinedCertificateValidationContext.newBuilder(); combinedBuilder - .setDefaultValidationContext(certContext) - .setValidationContextCertificateProviderInstance(certificateProviderInstance); + .setDefaultValidationContext(certContext); return builder .addTlsCertificates(tlsCert) .setCombinedValidationContext(combinedBuilder.build()); diff --git a/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java index 8a04a3d02a7..abacd2038f8 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java @@ -23,7 +23,6 @@ import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; -import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CertificateProviderInstance; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.DownstreamTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext; @@ -37,10 +36,12 @@ import java.io.InputStream; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; +import java.util.AbstractMap; import java.util.Arrays; import java.util.List; import java.util.concurrent.Executor; import javax.annotation.Nullable; +import javax.net.ssl.X509TrustManager; /** Utility class for client and server ssl provider tests. */ public class CommonTlsContextTestsUtil { @@ -48,59 +49,43 @@ public class CommonTlsContextTestsUtil { public static final String SERVER_0_PEM_FILE = "server0.pem"; public static final String SERVER_0_KEY_FILE = "server0.key"; public static final String SERVER_1_PEM_FILE = "server1.pem"; + public static final String SERVER_1_SPIFFE_PEM_FILE = "server1_spiffe.pem"; public static final String SERVER_1_KEY_FILE = "server1.key"; public static final String CLIENT_PEM_FILE = "client.pem"; + public static final String CLIENT_SPIFFE_PEM_FILE = "client_spiffe.pem"; public static final String CLIENT_KEY_FILE = "client.key"; public static final String CA_PEM_FILE = "ca.pem"; + public static final String SPIFFE_TRUST_MAP_FILE = "spiffebundle.json"; + public static final String SPIFFE_TRUST_MAP_1_FILE = "spiffebundle1.json"; /** Bad/untrusted server certs. */ public static final String BAD_SERVER_PEM_FILE = "badserver.pem"; public static final String BAD_SERVER_KEY_FILE = "badserver.key"; public static final String BAD_CLIENT_PEM_FILE = "badclient.pem"; public static final String BAD_CLIENT_KEY_FILE = "badclient.key"; + public static final String BAD_WILDCARD_DNS_PEM_FILE = + "sni-test-certs/bad_wildcard_dns_certificate.pem"; /** takes additional values and creates CombinedCertificateValidationContext as needed. */ - @SuppressWarnings("deprecation") - static CommonTlsContext buildCommonTlsContextWithAdditionalValues( + private static CommonTlsContext buildCommonTlsContextWithAdditionalValues( String certInstanceName, String certName, String validationContextCertInstanceName, String validationContextCertName, Iterable matchSubjectAltNames, Iterable alpnNames) { - - CommonTlsContext.Builder builder = CommonTlsContext.newBuilder(); - - CertificateProviderInstance certificateProviderInstance = CertificateProviderInstance - .newBuilder().setInstanceName(certInstanceName).setCertificateName(certName).build(); - if (certificateProviderInstance != null) { - builder.setTlsCertificateCertificateProviderInstance(certificateProviderInstance); - } - CertificateProviderInstance validationCertificateProviderInstance = - CertificateProviderInstance.newBuilder().setInstanceName(validationContextCertInstanceName) - .setCertificateName(validationContextCertName).build(); - CertificateValidationContext certValidationContext = - matchSubjectAltNames == null - ? null - : CertificateValidationContext.newBuilder() - .addAllMatchSubjectAltNames(matchSubjectAltNames) - .build(); - if (validationCertificateProviderInstance != null) { - CombinedCertificateValidationContext.Builder combinedBuilder = - CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - validationCertificateProviderInstance); - if (certValidationContext != null) { - combinedBuilder = combinedBuilder.setDefaultValidationContext(certValidationContext); - } - builder.setCombinedValidationContext(combinedBuilder); - } else if (validationCertificateProviderInstance != null) { - builder - .setValidationContextCertificateProviderInstance(validationCertificateProviderInstance); - } else if (certValidationContext != null) { - builder.setValidationContext(certValidationContext); - } - if (alpnNames != null) { - builder.addAllAlpnProtocols(alpnNames); - } - return builder.build(); + @SuppressWarnings("deprecation") // gRFC A29 predates match_typed_subject_alt_names + CertificateValidationContext.Builder certificateValidationContextBuilder + = CertificateValidationContext.newBuilder() + .addAllMatchSubjectAltNames(matchSubjectAltNames); + return CommonTlsContext.newBuilder() + .setTlsCertificateProviderInstance(CertificateProviderPluginInstance.newBuilder() + .setInstanceName(certInstanceName) + .setCertificateName(certName)) + .setCombinedValidationContext(CombinedCertificateValidationContext.newBuilder() + .setDefaultValidationContext(certificateValidationContextBuilder + .setCaCertificateProviderInstance(CertificateProviderPluginInstance.newBuilder() + .setInstanceName(validationContextCertInstanceName) + .setCertificateName(validationContextCertName)))) + .addAllAlpnProtocols(alpnNames) + .build(); } /** Helper method to build DownstreamTlsContext for multiple test classes. */ @@ -148,7 +133,7 @@ public static DownstreamTlsContext buildTestDownstreamTlsContext( useSans ? Arrays.asList( StringMatcher.newBuilder() .setExact("spiffe://grpc-sds-testing.svc.id.goog/ns/default/sa/bob") - .build()) : null, + .build()) : Arrays.asList(), Arrays.asList("managed-tls")); } return buildDownstreamTlsContext(commonTlsContext, /* requireClientCert= */ false); @@ -168,11 +153,24 @@ public static String getTempFileNameForResourcesFile(String resFile) throws IOEx * Helper method to build UpstreamTlsContext for above tests. Called from other classes as well. */ static EnvoyServerProtoData.UpstreamTlsContext buildUpstreamTlsContext( - CommonTlsContext commonTlsContext) { - UpstreamTlsContext upstreamTlsContext = - UpstreamTlsContext.newBuilder().setCommonTlsContext(commonTlsContext).build(); + CommonTlsContext commonTlsContext) { + return buildUpstreamTlsContext(commonTlsContext, "", false, false); + } + + /** + * Helper method to build UpstreamTlsContext with SNI info. + */ + static EnvoyServerProtoData.UpstreamTlsContext buildUpstreamTlsContext( + CommonTlsContext commonTlsContext, String sni, boolean autoHostSni, + boolean autoSniSanValidation) { + UpstreamTlsContext.Builder upstreamTlsContext = + UpstreamTlsContext.newBuilder() + .setCommonTlsContext(commonTlsContext) + .setAutoHostSni(autoHostSni) + .setAutoSniSanValidation(autoSniSanValidation) + .setSni(sni); return EnvoyServerProtoData.UpstreamTlsContext.fromEnvoyProtoUpstreamTlsContext( - upstreamTlsContext); + upstreamTlsContext.build()); } /** Helper method to build UpstreamTlsContext for multiple test classes. */ @@ -187,6 +185,21 @@ public static EnvoyServerProtoData.UpstreamTlsContext buildUpstreamTlsContext( null); } + /** Helper method to build UpstreamTlsContext with SNI info. */ + public static EnvoyServerProtoData.UpstreamTlsContext buildUpstreamTlsContext( + String commonInstanceName, boolean hasIdentityCert, String sni, boolean autoHostSni) { + return buildUpstreamTlsContextForCertProviderInstance( + hasIdentityCert ? commonInstanceName : null, + hasIdentityCert ? "default" : null, + commonInstanceName, + "ROOT", + null, + null, + sni, + autoHostSni, + false); + } + /** Gets a cert from contents of a resource. */ public static X509Certificate getCertFromResourceName(String resourceName) throws IOException, CertificateException { @@ -195,7 +208,6 @@ public static X509Certificate getCertFromResourceName(String resourceName) } } - @SuppressWarnings("deprecation") private static CommonTlsContext buildCommonTlsContextForCertProviderInstance( String certInstanceName, String certName, @@ -206,10 +218,37 @@ private static CommonTlsContext buildCommonTlsContextForCertProviderInstance( CommonTlsContext.Builder builder = CommonTlsContext.newBuilder(); if (certInstanceName != null) { builder = - builder.setTlsCertificateCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.newBuilder() - .setInstanceName(certInstanceName) - .setCertificateName(certName)); + builder.setTlsCertificateProviderInstance( + CertificateProviderPluginInstance.newBuilder() + .setInstanceName(certInstanceName) + .setCertificateName(certName)); + } + builder = + addCertificateValidationContext( + builder, rootInstanceName, rootCertName, staticCertValidationContext); + if (alpnProtocols != null) { + builder.addAllAlpnProtocols(alpnProtocols); + } + return builder.build(); + } + + /** Helper method to build CommonTlsContext using deprecated certificate provider field. */ + @SuppressWarnings("deprecation") + public static CommonTlsContext buildCommonTlsContextWithDeprecatedCertProviderInstance( + String certInstanceName, + String certName, + String rootInstanceName, + String rootCertName, + Iterable alpnProtocols, + CertificateValidationContext staticCertValidationContext) { + CommonTlsContext.Builder builder = CommonTlsContext.newBuilder(); + if (certInstanceName != null) { + // Use deprecated field (field 11) instead of current field (field 14) + builder = + builder.setTlsCertificateCertificateProviderInstance( + CommonTlsContext.CertificateProviderInstance.newBuilder() + .setInstanceName(certInstanceName) + .setCertificateName(certName)); } builder = addCertificateValidationContext( @@ -244,29 +283,28 @@ private static CommonTlsContext buildNewCommonTlsContextForCertProviderInstance( return builder.build(); } - @SuppressWarnings("deprecation") private static CommonTlsContext.Builder addCertificateValidationContext( CommonTlsContext.Builder builder, String rootInstanceName, String rootCertName, CertificateValidationContext staticCertValidationContext) { + if (staticCertValidationContext == null && rootInstanceName == null) { + return builder; + } + CertificateValidationContext.Builder contextBuilder; + if (staticCertValidationContext == null) { + contextBuilder = CertificateValidationContext.newBuilder(); + } else { + contextBuilder = staticCertValidationContext.toBuilder(); + } if (rootInstanceName != null) { - CertificateProviderInstance providerInstance = - CertificateProviderInstance.newBuilder() - .setInstanceName(rootInstanceName) - .setCertificateName(rootCertName) - .build(); - if (staticCertValidationContext != null) { - CombinedCertificateValidationContext combined = - CombinedCertificateValidationContext.newBuilder() - .setDefaultValidationContext(staticCertValidationContext) - .setValidationContextCertificateProviderInstance(providerInstance) - .build(); - return builder.setCombinedValidationContext(combined); - } - builder = builder.setValidationContextCertificateProviderInstance(providerInstance); + contextBuilder.setCaCertificateProviderInstance(CertificateProviderPluginInstance.newBuilder() + .setInstanceName(rootInstanceName) + .setCertificateName(rootCertName)); + builder.setValidationContext(contextBuilder.build()); } - return builder; + return builder.setCombinedValidationContext(CombinedCertificateValidationContext.newBuilder() + .setDefaultValidationContext(contextBuilder)); } private static CommonTlsContext.Builder addNewCertificateValidationContext( @@ -274,19 +312,19 @@ private static CommonTlsContext.Builder addNewCertificateValidationContext( String rootInstanceName, String rootCertName, CertificateValidationContext staticCertValidationContext) { + CertificateValidationContext.Builder validationContextBuilder = + staticCertValidationContext != null ? staticCertValidationContext.toBuilder() + : CertificateValidationContext.newBuilder(); if (rootInstanceName != null) { CertificateProviderPluginInstance providerInstance = CertificateProviderPluginInstance.newBuilder() .setInstanceName(rootInstanceName) .setCertificateName(rootCertName) .build(); - CertificateValidationContext.Builder validationContextBuilder = - staticCertValidationContext != null ? staticCertValidationContext.toBuilder() - : CertificateValidationContext.newBuilder(); - return builder.setValidationContext( - validationContextBuilder.setCaCertificateProviderInstance(providerInstance)); + validationContextBuilder = validationContextBuilder.setCaCertificateProviderInstance( + providerInstance); } - return builder; + return builder.setValidationContext(validationContextBuilder); } /** Helper method to build UpstreamTlsContext for CertProvider tests. */ @@ -305,7 +343,31 @@ private static CommonTlsContext.Builder addNewCertificateValidationContext( rootInstanceName, rootCertName, alpnProtocols, - staticCertValidationContext)); + staticCertValidationContext), + "", false, false); + } + + /** Helper method to build UpstreamTlsContext with SNI info for CertProvider tests. */ + public static EnvoyServerProtoData.UpstreamTlsContext + buildUpstreamTlsContextForCertProviderInstance( + @Nullable String certInstanceName, + @Nullable String certName, + @Nullable String rootInstanceName, + @Nullable String rootCertName, + Iterable alpnProtocols, + CertificateValidationContext staticCertValidationContext, + String sni, + boolean autoHostSni, + boolean autoSniSanValidation) { + return buildUpstreamTlsContext( + buildCommonTlsContextForCertProviderInstance( + certInstanceName, + certName, + rootInstanceName, + rootCertName, + alpnProtocols, + staticCertValidationContext), + sni, autoHostSni, autoSniSanValidation); } /** Helper method to build UpstreamTlsContext for CertProvider tests. */ @@ -324,7 +386,8 @@ private static CommonTlsContext.Builder addNewCertificateValidationContext( rootInstanceName, rootCertName, alpnProtocols, - staticCertValidationContext)); + staticCertValidationContext), + "", false, false); } /** Helper method to build DownstreamTlsContext for CertProvider tests. */ @@ -368,14 +431,15 @@ private static CommonTlsContext.Builder addNewCertificateValidationContext( } /** Perform some simple checks on sslContext. */ - public static void doChecksOnSslContext(boolean server, SslContext sslContext, + public static void doChecksOnSslContext(boolean server, + AbstractMap.SimpleImmutableEntry sslContextAndTm, List expectedApnProtos) { if (server) { - assertThat(sslContext.isServer()).isTrue(); + assertThat(sslContextAndTm.getKey().isServer()).isTrue(); } else { - assertThat(sslContext.isClient()).isTrue(); + assertThat(sslContextAndTm.getKey().isClient()).isTrue(); } - List apnProtos = sslContext.applicationProtocolNegotiator().protocols(); + List apnProtos = sslContextAndTm.getKey().applicationProtocolNegotiator().protocols(); assertThat(apnProtos).isNotNull(); if (expectedApnProtos != null) { assertThat(apnProtos).isEqualTo(expectedApnProtos); @@ -401,7 +465,7 @@ public static TestCallback getValueThruCallback(SslContextProvider provider, Exe public static class TestCallback extends SslContextProvider.Callback { - public SslContext updatedSslContext; + public AbstractMap.SimpleImmutableEntry updatedSslContext; public Throwable updatedThrowable; public TestCallback(Executor executor) { @@ -409,7 +473,8 @@ public TestCallback(Executor executor) { } @Override - public void updateSslContext(SslContext sslContext) { + public void updateSslContextAndExtendedX509TrustManager( + AbstractMap.SimpleImmutableEntry sslContext) { updatedSslContext = sslContext; } diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java index da7f8113dfa..125b7e65aa6 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java @@ -45,12 +45,12 @@ import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; import io.grpc.netty.InternalProtocolNegotiators; import io.grpc.netty.ProtocolNegotiationEvent; -import io.grpc.xds.CommonBootstrapperTestUtils; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; -import io.grpc.xds.InternalXdsAttributes; import io.grpc.xds.TlsContextManager; import io.grpc.xds.client.Bootstrapper; +import io.grpc.xds.client.CommonBootstrapperTestUtils; +import io.grpc.xds.internal.XdsInternalAttributes; import io.grpc.xds.internal.security.SecurityProtocolNegotiators.ClientSecurityHandler; import io.grpc.xds.internal.security.SecurityProtocolNegotiators.ClientSecurityProtocolNegotiator; import io.grpc.xds.internal.security.certprovider.CommonCertProviderTestUtils; @@ -74,11 +74,13 @@ import java.net.InetSocketAddress; import java.net.SocketAddress; import java.security.cert.CertStoreException; +import java.util.AbstractMap; import java.util.Iterator; import java.util.Map; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import javax.net.ssl.X509TrustManager; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -87,6 +89,10 @@ @RunWith(JUnit4.class) public class SecurityProtocolNegotiatorsTest { + private static final String HOSTNAME = "hostname"; + private static final String SNI_IN_UTC = "sni-in-upstream-tls-context"; + private static final String FAKE_AUTHORITY = "authority"; + private final GrpcHttp2ConnectionHandler grpcHandler = FakeGrpcHttp2ConnectionHandler.newHandler(); @@ -122,8 +128,31 @@ public void clientSecurityProtocolNegotiatorNewHandler_noFallback_expectExceptio @Test public void clientSecurityProtocolNegotiatorNewHandler_withTlsContextAttribute() { + UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil.buildUpstreamTlsContext( + CommonTlsContext.newBuilder().build()); + ClientSecurityProtocolNegotiator pn = + new ClientSecurityProtocolNegotiator(InternalProtocolNegotiators.plaintext()); + GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class); + ChannelLogger logger = mock(ChannelLogger.class); + doNothing().when(logger).log(any(ChannelLogLevel.class), anyString()); + when(mockHandler.getNegotiationLogger()).thenReturn(logger); + TlsContextManager mockTlsContextManager = mock(TlsContextManager.class); + when(mockHandler.getEagAttributes()) + .thenReturn( + Attributes.newBuilder() + .set(SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, + new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager)) + .build()); + ChannelHandler newHandler = pn.newHandler(mockHandler); + assertThat(newHandler).isNotNull(); + assertThat(newHandler).isInstanceOf(ClientSecurityHandler.class); + } + + @Test + public void clientSecurityProtocolNegotiator_autoHostSni_hostnamePassedToClientSecurityHandlr() { UpstreamTlsContext upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext(CommonTlsContext.newBuilder().build()); + CommonTlsContextTestsUtil.buildUpstreamTlsContext( + CommonTlsContext.newBuilder().build(), "", true, false); ClientSecurityProtocolNegotiator pn = new ClientSecurityProtocolNegotiator(InternalProtocolNegotiators.plaintext()); GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class); @@ -134,12 +163,14 @@ public void clientSecurityProtocolNegotiatorNewHandler_withTlsContextAttribute() when(mockHandler.getEagAttributes()) .thenReturn( Attributes.newBuilder() - .set(InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, + .set(SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager)) + .set(XdsInternalAttributes.ATTR_ADDRESS_NAME, FAKE_AUTHORITY) .build()); ChannelHandler newHandler = pn.newHandler(mockHandler); assertThat(newHandler).isNotNull(); assertThat(newHandler).isInstanceOf(ClientSecurityHandler.class); + assertThat(((ClientSecurityHandler) newHandler).getSni()).isEqualTo(FAKE_AUTHORITY); } @Test @@ -149,7 +180,7 @@ public void clientSecurityHandler_addLast() CommonCertProviderTestUtils.register(executor); Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, - CA_PEM_FILE, null, null, null, null); + CA_PEM_FILE, null, null, null, null, null); UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true); @@ -158,7 +189,7 @@ public void clientSecurityHandler_addLast() new SslContextProviderSupplier(upstreamTlsContext, new TlsContextManagerImpl(bootstrapInfoForClient)); ClientSecurityHandler clientSecurityHandler = - new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier); + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, HOSTNAME); pipeline.addLast(clientSecurityHandler); channelHandlerCtx = pipeline.context(clientSecurityHandler); assertNotNull(channelHandlerCtx); @@ -169,19 +200,20 @@ public void clientSecurityHandler_addLast() sslContextProviderSupplier .updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) { @Override - public void updateSslContext(SslContext sslContext) { - future.set(sslContext); + public void updateSslContextAndExtendedX509TrustManager( + AbstractMap.SimpleImmutableEntry sslContextAndTm) { + future.set(sslContextAndTm); } @Override protected void onException(Throwable throwable) { future.set(throwable); } - }); + }, true); assertThat(executor.runDueTasks()).isEqualTo(1); channel.runPendingTasks(); Object fromFuture = future.get(2, TimeUnit.SECONDS); - assertThat(fromFuture).isInstanceOf(SslContext.class); + assertThat(fromFuture).isInstanceOf(AbstractMap.SimpleImmutableEntry.class); channel.runPendingTasks(); channelHandlerCtx = pipeline.context(clientSecurityHandler); assertThat(channelHandlerCtx).isNull(); @@ -195,6 +227,75 @@ protected void onException(Throwable throwable) { CommonCertProviderTestUtils.register0(); } + @Test + public void sniInClientSecurityHandler_autoHostSniIsTrue_usesEndpointHostname() { + Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils + .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, + CLIENT_PEM_FILE, CA_PEM_FILE, null, null, null, null, null); + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, "", true); + SslContextProviderSupplier sslContextProviderSupplier = + new SslContextProviderSupplier(upstreamTlsContext, + new TlsContextManagerImpl(bootstrapInfoForClient)); + + ClientSecurityHandler clientSecurityHandler = + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, HOSTNAME); + + assertThat(clientSecurityHandler.getSni()).isEqualTo(HOSTNAME); + } + + @Test + public void sniInClientSecurityHandler_autoHostSni_endpointHostnameIsEmpty_usesSniFromUtc() { + Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils + .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, + CLIENT_PEM_FILE, CA_PEM_FILE, null, null, null, null, null); + UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil.buildUpstreamTlsContext( + "google_cloud_private_spiffe-client", true, SNI_IN_UTC, true); + SslContextProviderSupplier sslContextProviderSupplier = + new SslContextProviderSupplier(upstreamTlsContext, + new TlsContextManagerImpl(bootstrapInfoForClient)); + + ClientSecurityHandler clientSecurityHandler = + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, ""); + + assertThat(clientSecurityHandler.getSni()).isEqualTo(SNI_IN_UTC); + } + + @Test + public void sniInClientSecurityHandler_autoHostSni_endpointHostnameIsNull_usesSniFromUtc() { + Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils + .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, + CLIENT_PEM_FILE, CA_PEM_FILE, null, null, null, null, null); + UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil.buildUpstreamTlsContext( + "google_cloud_private_spiffe-client", true, SNI_IN_UTC, true); + SslContextProviderSupplier sslContextProviderSupplier = + new SslContextProviderSupplier(upstreamTlsContext, + new TlsContextManagerImpl(bootstrapInfoForClient)); + + ClientSecurityHandler clientSecurityHandler = + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, null); + + assertThat(clientSecurityHandler.getSni()).isEqualTo(SNI_IN_UTC); + } + + @Test + public void sniInClientSecurityHandler_autoHostSniIsFalse_usesSniFromUpstreamTlsContext() { + Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils + .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, + CLIENT_PEM_FILE, CA_PEM_FILE, null, null, null, null, null); + UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil.buildUpstreamTlsContext( + "google_cloud_private_spiffe-client", true, SNI_IN_UTC, false); + SslContextProviderSupplier sslContextProviderSupplier = + new SslContextProviderSupplier(upstreamTlsContext, + new TlsContextManagerImpl(bootstrapInfoForClient)); + + ClientSecurityHandler clientSecurityHandler = + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, HOSTNAME); + + assertThat(clientSecurityHandler.getSni()).isEqualTo(SNI_IN_UTC); + } + @Test public void serverSecurityHandler_addLast() throws InterruptedException, TimeoutException, ExecutionException { @@ -216,7 +317,7 @@ public SocketAddress remoteAddress() { pipeline = channel.pipeline(); Bootstrapper.BootstrapInfo bootstrapInfoForServer = CommonBootstrapperTestUtils .buildBootstrapInfo("google_cloud_private_spiffe-server", SERVER_1_KEY_FILE, - SERVER_1_PEM_FILE, CA_PEM_FILE, null, null, null, null); + SERVER_1_PEM_FILE, CA_PEM_FILE, null, null, null, null, null); DownstreamTlsContext downstreamTlsContext = CommonTlsContextTestsUtil.buildDownstreamTlsContext( "google_cloud_private_spiffe-server", true, true); @@ -246,19 +347,20 @@ public SocketAddress remoteAddress() { sslContextProviderSupplier .updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) { @Override - public void updateSslContext(SslContext sslContext) { - future.set(sslContext); + public void updateSslContextAndExtendedX509TrustManager( + AbstractMap.SimpleImmutableEntry sslContextAndTm) { + future.set(sslContextAndTm); } @Override protected void onException(Throwable throwable) { future.set(throwable); } - }); + }, true); channel.runPendingTasks(); // need this for tasks to execute on eventLoop assertThat(executor.runDueTasks()).isEqualTo(1); Object fromFuture = future.get(2, TimeUnit.SECONDS); - assertThat(fromFuture).isInstanceOf(SslContext.class); + assertThat(fromFuture).isInstanceOf(AbstractMap.SimpleImmutableEntry.class); channel.runPendingTasks(); channelHandlerCtx = pipeline.context(SecurityProtocolNegotiators.ServerSecurityHandler.class); assertThat(channelHandlerCtx).isNull(); @@ -356,12 +458,12 @@ public void nullTlsContext_nullFallbackProtocolNegotiator_expectException() { @Test public void clientSecurityProtocolNegotiatorNewHandler_fireProtocolNegotiationEvent() - throws InterruptedException, TimeoutException, ExecutionException { + throws InterruptedException, TimeoutException, ExecutionException { FakeClock executor = new FakeClock(); CommonCertProviderTestUtils.register(executor); Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils - .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, - CA_PEM_FILE, null, null, null, null); + .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, + CLIENT_PEM_FILE, CA_PEM_FILE, null, null, null, null, null); UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true); @@ -370,7 +472,7 @@ public void clientSecurityProtocolNegotiatorNewHandler_fireProtocolNegotiationEv new SslContextProviderSupplier(upstreamTlsContext, new TlsContextManagerImpl(bootstrapInfoForClient)); ClientSecurityHandler clientSecurityHandler = - new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier); + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, HOSTNAME); pipeline.addLast(clientSecurityHandler); channelHandlerCtx = pipeline.context(clientSecurityHandler); @@ -382,19 +484,20 @@ public void clientSecurityProtocolNegotiatorNewHandler_fireProtocolNegotiationEv sslContextProviderSupplier .updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) { @Override - public void updateSslContext(SslContext sslContext) { - future.set(sslContext); + public void updateSslContextAndExtendedX509TrustManager( + AbstractMap.SimpleImmutableEntry sslContextAndTm) { + future.set(sslContextAndTm); } @Override protected void onException(Throwable throwable) { future.set(throwable); } - }); + }, true); executor.runDueTasks(); channel.runPendingTasks(); // need this for tasks to execute on eventLoop Object fromFuture = future.get(5, TimeUnit.SECONDS); - assertThat(fromFuture).isInstanceOf(SslContext.class); + assertThat(fromFuture).isInstanceOf(AbstractMap.SimpleImmutableEntry.class); channel.runPendingTasks(); channelHandlerCtx = pipeline.context(clientSecurityHandler); assertThat(channelHandlerCtx).isNull(); @@ -412,7 +515,7 @@ public void clientSecurityProtocolNegotiatorNewHandler_handleHandlerRemoved() { CommonCertProviderTestUtils.register(executor); Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, - CA_PEM_FILE, null, null, null, null); + CA_PEM_FILE, null, null, null, null, null); UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true); @@ -421,7 +524,7 @@ public void clientSecurityProtocolNegotiatorNewHandler_handleHandlerRemoved() { new SslContextProviderSupplier(upstreamTlsContext, new TlsContextManagerImpl(bootstrapInfoForClient)); ClientSecurityHandler clientSecurityHandler = - new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier); + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, HOSTNAME); pipeline.addLast(clientSecurityHandler); channelHandlerCtx = pipeline.context(clientSecurityHandler); @@ -459,7 +562,7 @@ static FakeGrpcHttp2ConnectionHandler newHandler() { @Override public String getAuthority() { - return "authority"; + return FAKE_AUTHORITY; } } } diff --git a/xds/src/test/java/io/grpc/xds/internal/security/ServerSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/security/ServerSslContextProviderFactoryTest.java index c455385dae9..7a5a6c00639 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/ServerSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/ServerSslContextProviderFactoryTest.java @@ -24,10 +24,10 @@ import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; -import io.grpc.xds.CommonBootstrapperTestUtils; import io.grpc.xds.EnvoyServerProtoData; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.client.Bootstrapper; +import io.grpc.xds.client.CommonBootstrapperTestUtils; import io.grpc.xds.client.XdsInitializationException; import io.grpc.xds.internal.security.certprovider.CertProviderServerSslContextProviderFactory; import io.grpc.xds.internal.security.certprovider.CertificateProvider; @@ -78,7 +78,7 @@ public void createCertProviderServerSslContextProvider() throws XdsInitializatio serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderServerSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], false); // verify that bootstrapInfo is cached... sslContextProvider = serverSslContextProviderFactory.create(downstreamTlsContext); @@ -117,7 +117,7 @@ public void bothPresent_expectCertProviderServerSslContextProvider() serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderServerSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); } @Test @@ -144,7 +144,7 @@ public void createCertProviderServerSslContextProvider_onlyCertInstance() serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderServerSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); } @Test @@ -179,7 +179,7 @@ public void createCertProviderServerSslContextProvider_withStaticContext() serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderServerSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], false); } @Test @@ -210,8 +210,8 @@ public void createCertProviderServerSslContextProvider_2providers() serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderServerSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); - verifyWatcher(sslContextProvider, watcherCaptor[1]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); + verifyWatcher(sslContextProvider, watcherCaptor[1], true); } @Test @@ -249,7 +249,7 @@ public void createNewCertProviderServerSslContextProvider_withSans() serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderServerSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); - verifyWatcher(sslContextProvider, watcherCaptor[1]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); + verifyWatcher(sslContextProvider, watcherCaptor[1], true); } } diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java index f476818297d..70a53c53205 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java @@ -17,8 +17,9 @@ package io.grpc.xds.internal.security; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.buildUpstreamTlsContext; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.any; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; @@ -26,10 +27,13 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.grpc.xds.EnvoyServerProtoData; import io.grpc.xds.TlsContextManager; import io.netty.handler.ssl.SslContext; +import java.util.AbstractMap; import java.util.concurrent.Executor; +import javax.net.ssl.X509TrustManager; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -47,14 +51,14 @@ public class SslContextProviderSupplierTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); @Mock private TlsContextManager mockTlsContextManager; + @Mock private Executor mockExecutor; private SslContextProviderSupplier supplier; private SslContextProvider mockSslContextProvider; - private EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext; + private EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext = + buildUpstreamTlsContext("google_cloud_private_spiffe", true); private SslContextProvider.Callback mockCallback; private void prepareSupplier() { - upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true); mockSslContextProvider = mock(SslContextProvider.class); doReturn(mockSslContextProvider) .when(mockTlsContextManager) @@ -64,9 +68,8 @@ private void prepareSupplier() { private void callUpdateSslContext() { mockCallback = mock(SslContextProvider.Callback.class); - Executor mockExecutor = mock(Executor.class); doReturn(mockExecutor).when(mockCallback).getExecutor(); - supplier.updateSslContext(mockCallback); + supplier.updateSslContext(mockCallback, false); } @Test @@ -82,26 +85,57 @@ public void get_updateSecret() { verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); assertThat(capturedCallback).isNotNull(); - SslContext mockSslContext = mock(SslContext.class); - capturedCallback.updateSslContext(mockSslContext); - verify(mockCallback, times(1)).updateSslContext(eq(mockSslContext)); + @SuppressWarnings("unchecked") + AbstractMap.SimpleImmutableEntry mockSslContextAndTm = + mock(AbstractMap.SimpleImmutableEntry.class); + capturedCallback.updateSslContextAndExtendedX509TrustManager(mockSslContextAndTm); + verify(mockCallback, times(1)) + .updateSslContextAndExtendedX509TrustManager(eq(mockSslContextAndTm)); verify(mockTlsContextManager, times(1)) .releaseClientSslContextProvider(eq(mockSslContextProvider)); SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); - supplier.updateSslContext(mockCallback); + supplier.updateSslContext(mockCallback, false); verify(mockTlsContextManager, times(3)) .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); } @Test - public void get_onException() { + public void autoHostSniFalse_usesSniFromUpstreamTlsContext() { prepareSupplier(); callUpdateSslContext(); + verify(mockTlsContextManager, times(2)) + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); + verify(mockTlsContextManager, times(0)) + .releaseClientSslContextProvider(any(SslContextProvider.class)); ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(SslContextProvider.Callback.class); verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); assertThat(capturedCallback).isNotNull(); + @SuppressWarnings("unchecked") + AbstractMap.SimpleImmutableEntry mockSslContextAndTm = + mock(AbstractMap.SimpleImmutableEntry.class); + capturedCallback.updateSslContextAndExtendedX509TrustManager(mockSslContextAndTm); + verify(mockCallback, times(1)) + .updateSslContextAndExtendedX509TrustManager(eq(mockSslContextAndTm)); + verify(mockTlsContextManager, times(1)) + .releaseClientSslContextProvider(eq(mockSslContextProvider)); + SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); + supplier.updateSslContext(mockCallback, false); + verify(mockTlsContextManager, times(3)) + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); + } + + @Test + public void get_onException() { + prepareSupplier(); + callUpdateSslContext(); + ArgumentCaptor callbackCaptor = + ArgumentCaptor.forClass(SslContextProvider.Callback.class); + verify(mockSslContextProvider, times(1)) + .addCallback(callbackCaptor.capture()); + SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); + assertThat(capturedCallback).isNotNull(); Exception exception = new Exception("test"); capturedCallback.onException(exception); verify(mockCallback, times(1)).onException(eq(exception)); @@ -109,6 +143,46 @@ public void get_onException() { .releaseClientSslContextProvider(eq(mockSslContextProvider)); } + @Test + public void systemRootCertsWithMtls_callbackExecutedFromProvider() { + upstreamTlsContext = + CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance( + "gcp_id", + "cert-default", + null, + "root-default", + null, + CertificateValidationContext.newBuilder() + .setSystemRootCerts( + CertificateValidationContext.SystemRootCerts.getDefaultInstance()) + .build()); + prepareSupplier(); + + callUpdateSslContext(); + + verify(mockTlsContextManager, times(2)) + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); + verify(mockTlsContextManager, times(0)) + .releaseClientSslContextProvider(any(SslContextProvider.class)); + ArgumentCaptor callbackCaptor = + ArgumentCaptor.forClass(SslContextProvider.Callback.class); + verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); + SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); + assertThat(capturedCallback).isNotNull(); + @SuppressWarnings("unchecked") + AbstractMap.SimpleImmutableEntry mockSslContextAndTm = + mock(AbstractMap.SimpleImmutableEntry.class); + capturedCallback.updateSslContextAndExtendedX509TrustManager(mockSslContextAndTm); + verify(mockCallback, times(1)) + .updateSslContextAndExtendedX509TrustManager(eq(mockSslContextAndTm)); + verify(mockTlsContextManager, times(1)) + .releaseClientSslContextProvider(eq(mockSslContextProvider)); + SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); + supplier.updateSslContext(mockCallback, false); + verify(mockTlsContextManager, times(3)) + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); + } + @Test public void testClose() { prepareSupplier(); @@ -116,7 +190,7 @@ public void testClose() { supplier.close(); verify(mockTlsContextManager, times(1)) .releaseClientSslContextProvider(eq(mockSslContextProvider)); - supplier.updateSslContext(mockCallback); + supplier.updateSslContext(mockCallback, false); verify(mockTlsContextManager, times(3)) .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); verify(mockTlsContextManager, times(1)) diff --git a/xds/src/test/java/io/grpc/xds/internal/security/TlsContextManagerTest.java b/xds/src/test/java/io/grpc/xds/internal/security/TlsContextManagerTest.java index 4d04eeb41e0..035096a3528 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/TlsContextManagerTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/TlsContextManagerTest.java @@ -30,10 +30,10 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import io.grpc.xds.CommonBootstrapperTestUtils; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.client.Bootstrapper; +import io.grpc.xds.client.CommonBootstrapperTestUtils; import io.grpc.xds.internal.security.ReferenceCountingMap.ValueFactory; import org.junit.Rule; import org.junit.Test; @@ -57,7 +57,7 @@ public class TlsContextManagerTest { public void createServerSslContextProvider() { Bootstrapper.BootstrapInfo bootstrapInfoForServer = CommonBootstrapperTestUtils .buildBootstrapInfo("google_cloud_private_spiffe-server", SERVER_1_KEY_FILE, - SERVER_1_PEM_FILE, CA_PEM_FILE, null, null, null, null); + SERVER_1_PEM_FILE, CA_PEM_FILE, null, null, null, null, null); DownstreamTlsContext downstreamTlsContext = CommonTlsContextTestsUtil.buildDownstreamTlsContext( "google_cloud_private_spiffe-server", false, false); @@ -76,7 +76,7 @@ public void createServerSslContextProvider() { public void createClientSslContextProvider() { Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, - CA_PEM_FILE, null, null, null, null); + CA_PEM_FILE, null, null, null, null, null); UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil .buildUpstreamTlsContext("google_cloud_private_spiffe-client", false); @@ -96,7 +96,7 @@ public void createServerSslContextProvider_differentInstance() { Bootstrapper.BootstrapInfo bootstrapInfoForServer = CommonBootstrapperTestUtils .buildBootstrapInfo("google_cloud_private_spiffe-server", SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE, "cert-instance2", SERVER_0_KEY_FILE, SERVER_0_PEM_FILE, - CA_PEM_FILE); + CA_PEM_FILE, null); DownstreamTlsContext downstreamTlsContext = CommonTlsContextTestsUtil.buildDownstreamTlsContext( "google_cloud_private_spiffe-server", false, false); @@ -120,7 +120,7 @@ public void createServerSslContextProvider_differentInstance() { public void createClientSslContextProvider_differentInstance() { Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, - CA_PEM_FILE, "cert-instance-2", CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); + CA_PEM_FILE, "cert-instance-2", CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE, null); UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil .buildUpstreamTlsContext("google_cloud_private_spiffe-client", false); diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java index 5925c5f03b1..91f02863ca4 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java @@ -33,9 +33,9 @@ import com.google.common.util.concurrent.MoreExecutors; import io.envoyproxy.envoy.config.core.v3.DataSource; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; -import io.grpc.xds.CommonBootstrapperTestUtils; import io.grpc.xds.EnvoyServerProtoData; import io.grpc.xds.client.Bootstrapper; +import io.grpc.xds.client.CommonBootstrapperTestUtils; import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; import io.grpc.xds.internal.security.CommonTlsContextTestsUtil.TestCallback; import java.util.Queue; @@ -72,15 +72,28 @@ private CertProviderClientSslContextProvider getSslContextProvider( String rootInstanceName, Bootstrapper.BootstrapInfo bootstrapInfo, Iterable alpnProtocols, - CertificateValidationContext staticCertValidationContext) { - EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance( - certInstanceName, - "cert-default", - rootInstanceName, - "root-default", - alpnProtocols, - staticCertValidationContext); + CertificateValidationContext staticCertValidationContext, + boolean useSystemRootCerts) { + EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext; + if (useSystemRootCerts) { + upstreamTlsContext = + CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance( + certInstanceName, + "cert-default", + rootInstanceName, + "root-default", + alpnProtocols, + staticCertValidationContext); + } else { + upstreamTlsContext = + CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance( + certInstanceName, + "cert-default", + rootInstanceName, + "root-default", + alpnProtocols, + staticCertValidationContext); + } return (CertProviderClientSslContextProvider) certProviderClientSslContextProviderFactory.getProvider( upstreamTlsContext, @@ -122,12 +135,12 @@ public void testProviderForClient_mtls() throws Exception { "gcp_id", CommonBootstrapperTestUtils.getTestBootstrapInfo(), /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null); + /* staticCertValidationContext= */ null, false); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndTrustManager()).isNull(); // now generate cert update watcherCaptor[0].updateCertificate( @@ -135,11 +148,11 @@ public void testProviderForClient_mtls() throws Exception { ImmutableList.of(getCertFromResourceName(CLIENT_PEM_FILE))); assertThat(provider.savedKey).isNotNull(); assertThat(provider.savedCertChain).isNotNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndTrustManager()).isNull(); // now generate root cert update watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndTrustManager()).isNotNull(); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); @@ -168,11 +181,92 @@ public void testProviderForClient_mtls() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndTrustManager()).isNotNull(); + testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); + } + + @Test + public void testProviderForClient_systemRootCerts_mtls() throws Exception { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + TestCertificateProvider.createAndRegisterProviderProvider( + certificateProviderRegistry, watcherCaptor, "testca", 0); + CertProviderClientSslContextProvider provider = + getSslContextProvider( + "gcp_id", + null, + CommonBootstrapperTestUtils.getTestBootstrapInfo(), + /* alpnProtocols= */ null, + CertificateValidationContext.newBuilder() + .setSystemRootCerts( + CertificateValidationContext.SystemRootCerts.getDefaultInstance()) + .build(), + true); + + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNotNull(); + assertThat(provider.getSslContextAndTrustManager()).isNull(); + + // now generate cert update + watcherCaptor[0].updateCertificate( + CommonCertProviderTestUtils.getPrivateKey(CLIENT_KEY_FILE), + ImmutableList.of(getCertFromResourceName(CLIENT_PEM_FILE))); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNotNull(); + assertThat(provider.getSslContextAndTrustManager()).isNotNull(); + + TestCallback testCallback = + CommonTlsContextTestsUtil.getValueThruCallback(provider); + + doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null); + TestCallback testCallback1 = + CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback1.updatedSslContext).isSameInstanceAs(testCallback.updatedSslContext); + + // now update id cert: sslContext should be updated i.e. different from the previous one + watcherCaptor[0].updateCertificate( + CommonCertProviderTestUtils.getPrivateKey(SERVER_1_KEY_FILE), + ImmutableList.of(getCertFromResourceName(SERVER_1_PEM_FILE))); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNotNull(); + assertThat(provider.getSslContextAndTrustManager()).isNotNull(); testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); } + @Test + public void testProviderForClient_systemRootCerts_regularTls() { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + TestCertificateProvider.createAndRegisterProviderProvider( + certificateProviderRegistry, watcherCaptor, "testca", 0); + CertProviderClientSslContextProvider provider = + getSslContextProvider( + null, + null, + CommonBootstrapperTestUtils.getTestBootstrapInfo(), + /* alpnProtocols= */ null, + CertificateValidationContext.newBuilder() + .setSystemRootCerts( + CertificateValidationContext.SystemRootCerts.getDefaultInstance()) + .build(), + true); + + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNotNull(); + assertThat(provider.getSslContextAndTrustManager()).isNotNull(); + TestCallback testCallback = + CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback.updatedSslContext).isEqualTo(provider.getSslContextAndTrustManager()); + + assertThat(watcherCaptor[0]).isNull(); + } + @Test public void testProviderForClient_mtls_newXds() throws Exception { final CertificateProvider.DistributorWatcher[] watcherCaptor = @@ -190,7 +284,7 @@ public void testProviderForClient_mtls_newXds() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndTrustManager()).isNull(); // now generate cert update watcherCaptor[0].updateCertificate( @@ -198,11 +292,11 @@ public void testProviderForClient_mtls_newXds() throws Exception { ImmutableList.of(getCertFromResourceName(CLIENT_PEM_FILE))); assertThat(provider.savedKey).isNotNull(); assertThat(provider.savedCertChain).isNotNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndTrustManager()).isNull(); // now generate root cert update watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndTrustManager()).isNotNull(); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); @@ -231,7 +325,7 @@ public void testProviderForClient_mtls_newXds() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndTrustManager()).isNotNull(); testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); } @@ -248,7 +342,7 @@ public void testProviderForClient_queueExecutor() throws Exception { "gcp_id", CommonBootstrapperTestUtils.getTestBootstrapInfo(), /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null); + /* staticCertValidationContext= */ null, false); QueuedExecutor queuedExecutor = new QueuedExecutor(); TestCallback testCallback = @@ -281,16 +375,16 @@ public void testProviderForClient_tls() throws Exception { "gcp_id", CommonBootstrapperTestUtils.getTestBootstrapInfo(), /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null); + /* staticCertValidationContext= */ null, false); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndTrustManager()).isNull(); // now generate root cert update watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndTrustManager()).isNotNull(); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); @@ -318,7 +412,7 @@ public void testProviderForClient_sslContextException_onError() throws Exception "gcp_id", CommonBootstrapperTestUtils.getTestBootstrapInfo(), /* alpnProtocols= */null, - staticCertValidationContext); + staticCertValidationContext, false); TestCallback testCallback = new TestCallback(MoreExecutors.directExecutor()); provider.addCallback(testCallback); @@ -338,7 +432,8 @@ public void testProviderForClient_sslContextException_onError() throws Exception } @Test - public void testProviderForClient_rootInstanceNull_expectError() throws Exception { + public void testProviderForClient_rootInstanceNull_and_notUsingSystemRootCerts_expectError() + throws Exception { final CertificateProvider.DistributorWatcher[] watcherCaptor = new CertificateProvider.DistributorWatcher[1]; TestCertificateProvider.createAndRegisterProviderProvider( @@ -349,13 +444,84 @@ public void testProviderForClient_rootInstanceNull_expectError() throws Exceptio /* rootInstanceName= */ null, CommonBootstrapperTestUtils.getTestBootstrapInfo(), /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null); + /* staticCertValidationContext= */ null, false); fail("exception expected"); - } catch (NullPointerException expected) { - assertThat(expected).hasMessageThat().contains("Client SSL requires rootCertInstance"); + } catch (UnsupportedOperationException expected) { + assertThat(expected).hasMessageThat().contains("Unsupported configurations in " + + "UpstreamTlsContext!"); } } + @Test + public void testProviderForClient_rootInstanceNull_but_isUsingSystemRootCerts_valid() + throws Exception { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + TestCertificateProvider.createAndRegisterProviderProvider( + certificateProviderRegistry, watcherCaptor, "testca", 0); + getSslContextProvider( + /* certInstanceName= */ null, + /* rootInstanceName= */ null, + CommonBootstrapperTestUtils.getTestBootstrapInfo(), + /* alpnProtocols= */ null, + CertificateValidationContext.newBuilder() + .setSystemRootCerts( + CertificateValidationContext.SystemRootCerts.newBuilder().build()) + .build(), false); + } + + @Test + public void testProviderForClient_deprecatedCertProviderField() throws Exception { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + TestCertificateProvider.createAndRegisterProviderProvider( + certificateProviderRegistry, watcherCaptor, "testca", 0); + + // Build UpstreamTlsContext using deprecated field + EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext = + new EnvoyServerProtoData.UpstreamTlsContext( + CommonTlsContextTestsUtil.buildCommonTlsContextWithDeprecatedCertProviderInstance( + "gcp_id", + "cert-default", + "gcp_id", + "root-default", + /* alpnProtocols= */ null, + /* staticCertValidationContext= */ null)); + + Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); + CertProviderClientSslContextProvider provider = + (CertProviderClientSslContextProvider) + certProviderClientSslContextProviderFactory.getProvider( + upstreamTlsContext, + bootstrapInfo.node().toEnvoyProtoNode(), + bootstrapInfo.certProviders()); + + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNull(); + assertThat(provider.getSslContextAndTrustManager()).isNull(); + + // Generate cert update + watcherCaptor[0].updateCertificate( + CommonCertProviderTestUtils.getPrivateKey(CLIENT_KEY_FILE), + ImmutableList.of(getCertFromResourceName(CLIENT_PEM_FILE))); + assertThat(provider.savedKey).isNotNull(); + assertThat(provider.savedCertChain).isNotNull(); + assertThat(provider.getSslContextAndTrustManager()).isNull(); + + // Generate root cert update + watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); + assertThat(provider.getSslContextAndTrustManager()).isNotNull(); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNull(); + + TestCallback testCallback = + CommonTlsContextTestsUtil.getValueThruCallback(provider); + + doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null); + } + static class QueuedExecutor implements Executor { /** A list of Runnables to be run in order. */ @VisibleForTesting final Queue runQueue = new ConcurrentLinkedQueue<>(); diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProviderTest.java index 82af7d1dc27..93559f47245 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProviderTest.java @@ -32,9 +32,9 @@ import io.envoyproxy.envoy.config.core.v3.DataSource; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; -import io.grpc.xds.CommonBootstrapperTestUtils; import io.grpc.xds.EnvoyServerProtoData; import io.grpc.xds.client.Bootstrapper; +import io.grpc.xds.client.CommonBootstrapperTestUtils; import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; import io.grpc.xds.internal.security.CommonTlsContextTestsUtil.TestCallback; import io.grpc.xds.internal.security.certprovider.CertProviderClientSslContextProviderTest.QueuedExecutor; @@ -127,7 +127,7 @@ public void testProviderForServer_mtls() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndTrustManager()).isNull(); // now generate cert update watcherCaptor[0].updateCertificate( @@ -135,11 +135,11 @@ public void testProviderForServer_mtls() throws Exception { ImmutableList.of(getCertFromResourceName(SERVER_0_PEM_FILE))); assertThat(provider.savedKey).isNotNull(); assertThat(provider.savedCertChain).isNotNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndTrustManager()).isNull(); // now generate root cert update watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndTrustManager()).isNotNull(); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); @@ -168,7 +168,7 @@ public void testProviderForServer_mtls() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndTrustManager()).isNotNull(); testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); } @@ -196,7 +196,7 @@ public void testProviderForServer_mtls_newXds() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndTrustManager()).isNull(); // now generate cert update watcherCaptor[0].updateCertificate( @@ -204,11 +204,11 @@ public void testProviderForServer_mtls_newXds() throws Exception { ImmutableList.of(getCertFromResourceName(SERVER_0_PEM_FILE))); assertThat(provider.savedKey).isNotNull(); assertThat(provider.savedCertChain).isNotNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndTrustManager()).isNull(); // now generate root cert update watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndTrustManager()).isNotNull(); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); @@ -237,7 +237,7 @@ public void testProviderForServer_mtls_newXds() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndTrustManager()).isNotNull(); testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); } @@ -294,14 +294,14 @@ public void testProviderForServer_tls() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndTrustManager()).isNull(); // now generate cert update watcherCaptor[0].updateCertificate( CommonCertProviderTestUtils.getPrivateKey(SERVER_0_KEY_FILE), ImmutableList.of(getCertFromResourceName(SERVER_0_PEM_FILE))); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndTrustManager()).isNotNull(); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertificateProviderStoreTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertificateProviderStoreTest.java index 8f77de7b5e2..c0bc095eab6 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertificateProviderStoreTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertificateProviderStoreTest.java @@ -123,7 +123,6 @@ public void notifyCertUpdatesNotSupported_expectExceptionOnSecondCall() { } @Test - @SuppressWarnings("deprecation") public void onePluginSameConfig_sameInstance() { registerPlugin("plugin1"); CertificateProvider.Watcher mockWatcher1 = mock(CertificateProvider.Watcher.class); @@ -167,7 +166,6 @@ public void onePluginSameConfig_sameInstance() { } @Test - @SuppressWarnings("deprecation") public void onePluginSameConfig_secondWatcherAfterFirstNotify() { registerPlugin("plugin1"); CertificateProvider.Watcher mockWatcher1 = mock(CertificateProvider.Watcher.class); @@ -275,7 +273,6 @@ public void twoPlugins_differentInstance() { mockWatcher1, handle1, certProviderProvider1, mockWatcher2, handle2, certProviderProvider2); } - @SuppressWarnings("deprecation") private static void checkDifferentInstances( CertificateProvider.Watcher mockWatcher1, CertificateProviderStore.Handle handle1, diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProviderTest.java index a0bdd618004..304a2dd5441 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderProviderTest.java @@ -24,22 +24,28 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.common.collect.ImmutableList; import io.grpc.internal.JsonParser; import io.grpc.internal.TimeProvider; import java.io.IOException; +import java.util.Collection; import java.util.Map; import java.util.concurrent.ScheduledExecutorService; +import org.junit.After; +import org.junit.Assume; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; /** Unit tests for {@link FileWatcherCertificateProviderProvider}. */ -@RunWith(JUnit4.class) +@RunWith(Parameterized.class) public class FileWatcherCertificateProviderProviderTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); @@ -48,13 +54,28 @@ public class FileWatcherCertificateProviderProviderTest { scheduledExecutorServiceFactory; @Mock private TimeProvider timeProvider; + @Parameter + public boolean enableSpiffe; + private boolean originalEnableSpiffe; private FileWatcherCertificateProviderProvider provider; + @Parameters(name = "enableSpiffe={0}") + public static Collection data() { + return ImmutableList.of(true, false); + } + @Before public void setUp() throws IOException { provider = new FileWatcherCertificateProviderProvider( fileWatcherCertificateProviderFactory, scheduledExecutorServiceFactory, timeProvider); + originalEnableSpiffe = FileWatcherCertificateProviderProvider.enableSpiffe; + FileWatcherCertificateProviderProvider.enableSpiffe = enableSpiffe; + } + + @After + public void restoreEnvironment() { + FileWatcherCertificateProviderProvider.enableSpiffe = originalEnableSpiffe; } @Test @@ -85,6 +106,30 @@ public void createProvider_minimalConfig() throws IOException { eq("/var/run/gke-spiffe/certs/certificates.pem"), eq("/var/run/gke-spiffe/certs/private_key.pem"), eq("/var/run/gke-spiffe/certs/ca_certificates.pem"), + eq(null), + eq(600L), + eq(mockService), + eq(timeProvider)); + } + + @Test + public void createProvider_minimalSpiffeConfig() throws IOException { + Assume.assumeTrue(enableSpiffe); + CertificateProvider.DistributorWatcher distWatcher = + new CertificateProvider.DistributorWatcher(); + @SuppressWarnings("unchecked") + Map map = (Map) JsonParser.parse(MINIMAL_FILE_WATCHER_WITH_SPIFFE_CONFIG); + ScheduledExecutorService mockService = mock(ScheduledExecutorService.class); + when(scheduledExecutorServiceFactory.create()).thenReturn(mockService); + provider.createCertificateProvider(map, distWatcher, true); + verify(fileWatcherCertificateProviderFactory, times(1)) + .create( + eq(distWatcher), + eq(true), + eq("/var/run/gke-spiffe/certs/certificates.pem"), + eq("/var/run/gke-spiffe/certs/private_key.pem"), + eq(null), + eq("/var/run/gke-spiffe/certs/spiffe_bundle.json"), eq(600L), eq(mockService), eq(timeProvider)); @@ -106,6 +151,30 @@ public void createProvider_fullConfig() throws IOException { eq("/var/run/gke-spiffe/certs/certificates2.pem"), eq("/var/run/gke-spiffe/certs/private_key3.pem"), eq("/var/run/gke-spiffe/certs/ca_certificates4.pem"), + eq(null), + eq(7890L), + eq(mockService), + eq(timeProvider)); + } + + @Test + public void createProvider_spiffeConfig() throws IOException { + Assume.assumeTrue(enableSpiffe); + CertificateProvider.DistributorWatcher distWatcher = + new CertificateProvider.DistributorWatcher(); + @SuppressWarnings("unchecked") + Map map = (Map) JsonParser.parse(FULL_FILE_WATCHER_WITH_SPIFFE_CONFIG); + ScheduledExecutorService mockService = mock(ScheduledExecutorService.class); + when(scheduledExecutorServiceFactory.create()).thenReturn(mockService); + provider.createCertificateProvider(map, distWatcher, true); + verify(fileWatcherCertificateProviderFactory, times(1)) + .create( + eq(distWatcher), + eq(true), + eq("/var/run/gke-spiffe/certs/certificates2.pem"), + eq("/var/run/gke-spiffe/certs/private_key3.pem"), + eq(null), + eq("/var/run/gke-spiffe/certs/spiffe_bundle.json"), eq(7890L), eq(mockService), eq(timeProvider)); @@ -157,15 +226,18 @@ public void createProvider_missingKey_expectException() throws IOException { @Test public void createProvider_missingRoot_expectException() throws IOException { + String expectedMessage = enableSpiffe ? "either 'ca_certificate_file' or " + + "'spiffe_trust_bundle_map_file' is required in the config" + : "'ca_certificate_file' is required in the config"; CertificateProvider.DistributorWatcher distWatcher = new CertificateProvider.DistributorWatcher(); @SuppressWarnings("unchecked") - Map map = (Map) JsonParser.parse(MISSING_ROOT_CONFIG); + Map map = (Map) JsonParser.parse(MISSING_ROOT_AND_SPIFFE_CONFIG); try { provider.createCertificateProvider(map, distWatcher, true); fail("exception expected"); } catch (NullPointerException npe) { - assertThat(npe).hasMessageThat().isEqualTo("'ca_certificate_file' is required in the config"); + assertThat(npe).hasMessageThat().isEqualTo(expectedMessage); } } @@ -176,6 +248,14 @@ public void createProvider_missingRoot_expectException() throws IOException { + " \"ca_certificate_file\": \"/var/run/gke-spiffe/certs/ca_certificates.pem\"" + " }"; + private static final String MINIMAL_FILE_WATCHER_WITH_SPIFFE_CONFIG = + "{\n" + + " \"certificate_file\": \"/var/run/gke-spiffe/certs/certificates.pem\"," + + " \"private_key_file\": \"/var/run/gke-spiffe/certs/private_key.pem\"," + + " \"spiffe_trust_bundle_map_file\":" + + " \"/var/run/gke-spiffe/certs/spiffe_bundle.json\"" + + " }"; + private static final String FULL_FILE_WATCHER_CONFIG = "{\n" + " \"certificate_file\": \"/var/run/gke-spiffe/certs/certificates2.pem\"," @@ -184,6 +264,16 @@ public void createProvider_missingRoot_expectException() throws IOException { + " \"refresh_interval\": \"7890s\"" + " }"; + private static final String FULL_FILE_WATCHER_WITH_SPIFFE_CONFIG = + "{\n" + + " \"certificate_file\": \"/var/run/gke-spiffe/certs/certificates2.pem\"," + + " \"private_key_file\": \"/var/run/gke-spiffe/certs/private_key3.pem\"," + + " \"ca_certificate_file\": \"/var/run/gke-spiffe/certs/ca_certificates4.pem\"," + + " \"spiffe_trust_bundle_map_file\":" + + " \"/var/run/gke-spiffe/certs/spiffe_bundle.json\"," + + " \"refresh_interval\": \"7890s\"" + + " }"; + private static final String MISSING_CERT_CONFIG = "{\n" + " \"private_key_file\": \"/var/run/gke-spiffe/certs/private_key.pem\"," @@ -196,7 +286,7 @@ public void createProvider_missingRoot_expectException() throws IOException { + " \"ca_certificate_file\": \"/var/run/gke-spiffe/certs/ca_certificates.pem\"" + " }"; - private static final String MISSING_ROOT_CONFIG = + private static final String MISSING_ROOT_AND_SPIFFE_CONFIG = "{\n" + " \"certificate_file\": \"/var/run/gke-spiffe/certs/certificates.pem\"," + " \"private_key_file\": \"/var/run/gke-spiffe/certs/private_key.pem\"" diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderTest.java index 210ec056732..f6fdc51dece 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/FileWatcherCertificateProviderTest.java @@ -23,6 +23,7 @@ import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_0_KEY_FILE; import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE; import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SPIFFE_TRUST_MAP_1_FILE; import static java.nio.file.StandardCopyOption.REPLACE_EXISTING; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -47,6 +48,7 @@ import java.security.cert.X509Certificate; import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.concurrent.Delayed; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; @@ -73,6 +75,7 @@ public class FileWatcherCertificateProviderTest { private static final String CERT_FILE = "cert.pem"; private static final String KEY_FILE = "key.pem"; private static final String ROOT_FILE = "root.pem"; + private static final String SPIFFE_TRUST_MAP_FILE = "spiffebundle.json"; @Mock private CertificateProvider.Watcher mockWatcher; @Mock private ScheduledExecutorService timeService; @@ -84,28 +87,33 @@ public class FileWatcherCertificateProviderTest { private String certFile; private String keyFile; private String rootFile; + private String spiffeTrustMapFile; private FileWatcherCertificateProvider provider; + private DistributorWatcher watcher; @Before public void setUp() throws IOException { - DistributorWatcher watcher = new DistributorWatcher(); + watcher = new DistributorWatcher(); watcher.addWatcher(mockWatcher); certFile = new File(tempFolder.getRoot(), CERT_FILE).getAbsolutePath(); keyFile = new File(tempFolder.getRoot(), KEY_FILE).getAbsolutePath(); rootFile = new File(tempFolder.getRoot(), ROOT_FILE).getAbsolutePath(); + spiffeTrustMapFile = new File(tempFolder.getRoot(), SPIFFE_TRUST_MAP_FILE).getAbsolutePath(); provider = - new FileWatcherCertificateProvider( - watcher, true, certFile, keyFile, rootFile, 600L, timeService, timeProvider); + new FileWatcherCertificateProvider(watcher, true, certFile, keyFile, rootFile, null, 600L, + timeService, timeProvider); } private void populateTarget( String certFileSource, String keyFileSource, String rootFileSource, + String spiffeTrustMapFileSource, boolean deleteCurCert, boolean deleteCurKey, + boolean deleteCurSpiffeTrustMap, boolean deleteCurRoot) throws IOException { if (deleteCurCert) { @@ -135,6 +143,17 @@ private void populateTarget( Files.setLastModifiedTime( Paths.get(rootFile), FileTime.fromMillis(timeProvider.currentTimeMillis())); } + if (deleteCurSpiffeTrustMap) { + Files.delete(Paths.get(spiffeTrustMapFile)); + } + if (spiffeTrustMapFileSource != null) { + spiffeTrustMapFileSource = CommonTlsContextTestsUtil + .getTempFileNameForResourcesFile(spiffeTrustMapFileSource); + Files.copy(Paths.get(spiffeTrustMapFileSource), + Paths.get(spiffeTrustMapFile), REPLACE_EXISTING); + Files.setLastModifiedTime( + Paths.get(spiffeTrustMapFile), FileTime.fromMillis(timeProvider.currentTimeMillis())); + } } @Test @@ -144,9 +163,9 @@ public void getCertificateAndCheckUpdates() throws IOException, CertificateExcep doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false); + populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, null, false, false, false, false); provider.checkAndReloadCertificates(); - verifyWatcherUpdates(CLIENT_PEM_FILE, CA_PEM_FILE); + verifyWatcherUpdates(CLIENT_PEM_FILE, CA_PEM_FILE, null); verifyTimeServiceAndScheduledFuture(); reset(mockWatcher, timeService); @@ -165,7 +184,7 @@ public void allUpdateSecondTime() throws IOException, CertificateException, Inte doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false); + populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, null, false, false, false, false); provider.checkAndReloadCertificates(); reset(mockWatcher, timeService); @@ -173,9 +192,10 @@ public void allUpdateSecondTime() throws IOException, CertificateException, Inte .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); timeProvider.forwardTime(1, TimeUnit.SECONDS); - populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, false, false, false); + populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, null, false, false, + false, false); provider.checkAndReloadCertificates(); - verifyWatcherUpdates(SERVER_0_PEM_FILE, SERVER_1_PEM_FILE); + verifyWatcherUpdates(SERVER_0_PEM_FILE, SERVER_1_PEM_FILE, null); verifyTimeServiceAndScheduledFuture(); } @@ -186,12 +206,13 @@ public void closeDoesNotScheduleNext() throws IOException, CertificateException doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false); + populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, null, false, false, false, false); provider.close(); provider.checkAndReloadCertificates(); verify(mockWatcher, never()) .updateCertificate(any(PrivateKey.class), ArgumentMatchers.anyList()); verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.anyList()); + verify(mockWatcher, never()).updateSpiffeTrustMap(ArgumentMatchers.anyMap()); verify(timeService, never()).schedule(any(Runnable.class), any(Long.TYPE), any(TimeUnit.class)); verify(timeService, times(1)).shutdownNow(); } @@ -204,7 +225,7 @@ public void rootFileUpdateOnly() throws IOException, CertificateException, Inter doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false); + populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, null, false, false, false, false); provider.checkAndReloadCertificates(); reset(mockWatcher, timeService); @@ -212,9 +233,9 @@ public void rootFileUpdateOnly() throws IOException, CertificateException, Inter .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); timeProvider.forwardTime(1, TimeUnit.SECONDS); - populateTarget(null, null, SERVER_1_PEM_FILE, false, false, false); + populateTarget(null, null, SERVER_1_PEM_FILE, null, false, false, false, false); provider.checkAndReloadCertificates(); - verifyWatcherUpdates(null, SERVER_1_PEM_FILE); + verifyWatcherUpdates(null, SERVER_1_PEM_FILE, null); verifyTimeServiceAndScheduledFuture(); } @@ -226,7 +247,32 @@ public void certAndKeyFileUpdateOnly() doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false); + populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, null, false, false, false, false); + provider.checkAndReloadCertificates(); + + reset(mockWatcher, timeService); + doReturn(scheduledFuture) + .when(timeService) + .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); + timeProvider.forwardTime(1, TimeUnit.SECONDS); + populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, null, null, false, false, false, false); + provider.checkAndReloadCertificates(); + verifyWatcherUpdates(SERVER_0_PEM_FILE, null, null); + verifyTimeServiceAndScheduledFuture(); + } + + @Test + public void certFileUpdateOnly() + throws IOException, CertificateException, InterruptedException { + TestScheduledFuture scheduledFuture = + new TestScheduledFuture<>(); + doReturn(scheduledFuture) + .when(timeService) + .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); + // Ideally we'd use a matching cert/key pair here, but we don't actually have any ready-made. + // The test doesn't notice they don't match though. + populateTarget( + CLIENT_PEM_FILE, SERVER_0_KEY_FILE, CA_PEM_FILE, null, false, false, false, false); provider.checkAndReloadCertificates(); reset(mockWatcher, timeService); @@ -234,9 +280,72 @@ public void certAndKeyFileUpdateOnly() .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); timeProvider.forwardTime(1, TimeUnit.SECONDS); - populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, null, false, false, false); + // It's normal to get a newer cert while continuing to use the same private key + populateTarget(SERVER_0_PEM_FILE, null, null, null, false, false, false, false); provider.checkAndReloadCertificates(); - verifyWatcherUpdates(SERVER_0_PEM_FILE, null); + verifyWatcherUpdates(SERVER_0_PEM_FILE, null, null); + verifyTimeServiceAndScheduledFuture(); + } + + @Test + public void keyFileUpdateOnly() + throws IOException, CertificateException, InterruptedException { + TestScheduledFuture scheduledFuture = + new TestScheduledFuture<>(); + doReturn(scheduledFuture) + .when(timeService) + .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); + // Assume the key/cert is not updated atomically and we see a tear between them. Or maybe this + // was just a bug. + populateTarget( + SERVER_0_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, null, false, false, false, false); + provider.checkAndReloadCertificates(); + + reset(mockWatcher, timeService); + doReturn(scheduledFuture) + .when(timeService) + .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); + timeProvider.forwardTime(1, TimeUnit.SECONDS); + // Even though it is strange the key updated without a cert update, we do still want to use the + // new files, as this recovers from the earlier tear. + populateTarget(null, SERVER_0_KEY_FILE, null, null, false, false, false, false); + provider.checkAndReloadCertificates(); + verifyWatcherUpdates(SERVER_0_PEM_FILE, null, null); + verifyTimeServiceAndScheduledFuture(); + } + + @Test + public void spiffeTrustMapFileUpdateOnly() throws Exception { + provider = new FileWatcherCertificateProvider(watcher, true, certFile, keyFile, null, + spiffeTrustMapFile, 600L, timeService, timeProvider); + TestScheduledFuture scheduledFuture = + new TestScheduledFuture<>(); + doReturn(scheduledFuture) + .when(timeService) + .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); + populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, null, null, false, false, false, false); + provider.checkAndReloadCertificates(); + verify(mockWatcher, never()).updateSpiffeTrustMap(ArgumentMatchers.anyMap()); + + reset(timeService); + doReturn(scheduledFuture) + .when(timeService) + .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); + timeProvider.forwardTime(1, TimeUnit.SECONDS); + populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, null, SPIFFE_TRUST_MAP_FILE, false, + false, false, false); + provider.checkAndReloadCertificates(); + verify(mockWatcher, times(1)).updateSpiffeTrustMap(ArgumentMatchers.anyMap()); + + reset(timeService); + doReturn(scheduledFuture) + .when(timeService) + .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); + timeProvider.forwardTime(1, TimeUnit.SECONDS); + populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, null, SPIFFE_TRUST_MAP_1_FILE, false, + false, false, false); + provider.checkAndReloadCertificates(); + verify(mockWatcher, times(2)).updateSpiffeTrustMap(ArgumentMatchers.anyMap()); verifyTimeServiceAndScheduledFuture(); } @@ -247,7 +356,7 @@ public void getCertificate_initialMissingCertFile() throws IOException { doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - populateTarget(null, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false); + populateTarget(null, CLIENT_KEY_FILE, CA_PEM_FILE, null, false, false, false, false); provider.checkAndReloadCertificates(); verifyWatcherErrorUpdates(Status.Code.UNKNOWN, NoSuchFileException.class, 0, 1, "cert.pem"); } @@ -255,13 +364,14 @@ public void getCertificate_initialMissingCertFile() throws IOException { @Test public void getCertificate_missingCertFile() throws IOException, InterruptedException { commonErrorTest( - null, CLIENT_KEY_FILE, CA_PEM_FILE, NoSuchFileException.class, 0, 1, 0, 0, "cert.pem"); + null, CLIENT_KEY_FILE, CA_PEM_FILE, null, NoSuchFileException.class, 0, 1, 0, 0, + "cert.pem"); } @Test public void getCertificate_missingKeyFile() throws IOException, InterruptedException { commonErrorTest( - CLIENT_PEM_FILE, null, CA_PEM_FILE, NoSuchFileException.class, 0, 1, 0, 0, "key.pem"); + CLIENT_PEM_FILE, null, CA_PEM_FILE, null, NoSuchFileException.class, 0, 1, 0, 0, "key.pem"); } @Test @@ -270,6 +380,7 @@ public void getCertificate_badKeyFile() throws IOException, InterruptedException CLIENT_PEM_FILE, SERVER_0_PEM_FILE, CA_PEM_FILE, + null, java.security.spec.InvalidKeySpecException.class, 0, 1, @@ -285,12 +396,13 @@ public void getCertificate_missingRootFile() throws IOException, InterruptedExce doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, false, false, false); + populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, null, false, false, + false, false); provider.checkAndReloadCertificates(); reset(mockWatcher); timeProvider.forwardTime(1, TimeUnit.SECONDS); - populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, null, false, false, true); + populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, null, null, false, false, false, true); timeProvider.forwardTime( CERT0_EXPIRY_TIME_MILLIS - 610_000L - timeProvider.currentTimeMillis(), TimeUnit.MILLISECONDS); @@ -302,6 +414,7 @@ private void commonErrorTest( String certFile, String keyFile, String rootFile, + String spiffeFile, Class throwableType, int firstUpdateCertCount, int firstUpdateRootCount, @@ -314,13 +427,15 @@ private void commonErrorTest( doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, false, false, false); + populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, + SPIFFE_TRUST_MAP_1_FILE, false, false, false, false); provider.checkAndReloadCertificates(); reset(mockWatcher); timeProvider.forwardTime(1, TimeUnit.SECONDS); populateTarget( - certFile, keyFile, rootFile, certFile == null, keyFile == null, rootFile == null); + certFile, keyFile, rootFile, spiffeFile, certFile == null, keyFile == null, + rootFile == null, spiffeFile == null); timeProvider.forwardTime( CERT0_EXPIRY_TIME_MILLIS - 610_000L - timeProvider.currentTimeMillis(), TimeUnit.MILLISECONDS); @@ -372,7 +487,7 @@ private void verifyTimeServiceAndScheduledFuture() { assertThat(provider.scheduledFuture.isCancelled()).isFalse(); } - private void verifyWatcherUpdates(String certPemFile, String rootPemFile) + private void verifyWatcherUpdates(String certPemFile, String rootPemFile, String spiffeFile) throws IOException, CertificateException { if (certPemFile != null) { @SuppressWarnings("unchecked") @@ -399,6 +514,17 @@ private void verifyWatcherUpdates(String certPemFile, String rootPemFile) } else { verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.anyList()); } + if (spiffeFile != null) { + @SuppressWarnings("unchecked") + ArgumentCaptor>> spiffeCaptor = + ArgumentCaptor.forClass(Map.class); + verify(mockWatcher, times(1)).updateSpiffeTrustMap(spiffeCaptor.capture()); + Map> trustMap = spiffeCaptor.getValue(); + assertThat(trustMap).hasSize(2); + verify(mockWatcher, never()).onError(any(Status.class)); + } else { + verify(mockWatcher, never()).updateSpiffeTrustMap(ArgumentMatchers.anyMap()); + } } static class TestScheduledFuture implements ScheduledFuture { diff --git a/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactoryTest.java index 77749814cf2..3077482b10b 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactoryTest.java @@ -23,6 +23,8 @@ import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.protobuf.ByteString; import io.envoyproxy.envoy.config.core.v3.DataSource; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; @@ -89,7 +91,7 @@ public void constructor_fromRootCert() CertificateValidationContext staticValidationContext = buildStaticValidationContext("san1", "san2"); XdsTrustManagerFactory factory = - new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext); + new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext, false); assertThat(factory).isNotNull(); TrustManager[] tms = factory.getTrustManagers(); assertThat(tms).isNotNull(); @@ -105,6 +107,46 @@ public void constructor_fromRootCert() .isEqualTo(CertificateUtils.toX509Certificates(TlsTesting.loadCert(CA_PEM_FILE))[0]); } + @Test + public void constructor_fromSpiffeTrustMap() + throws CertificateException, IOException, CertStoreException { + X509Certificate x509Cert = TestUtils.loadX509Cert(CA_PEM_FILE); + CertificateValidationContext staticValidationContext = buildStaticValidationContext("san1", + "san2"); + // Single domain and single cert + XdsTrustManagerFactory factory = new XdsTrustManagerFactory(ImmutableMap + .of("example.com", ImmutableList.of(x509Cert)), staticValidationContext, false); + assertThat(factory).isNotNull(); + TrustManager[] tms = factory.getTrustManagers(); + assertThat(tms).isNotNull(); + assertThat(tms).hasLength(1); + TrustManager myTm = tms[0]; + assertThat(myTm).isInstanceOf(XdsX509TrustManager.class); + XdsX509TrustManager xdsX509TrustManager = (XdsX509TrustManager) myTm; + assertThat(xdsX509TrustManager.getAcceptedIssuers()).isNotNull(); + assertThat(xdsX509TrustManager.getAcceptedIssuers()).hasLength(1); + assertThat(xdsX509TrustManager.getAcceptedIssuers()[0].getIssuerX500Principal().getName()) + .isEqualTo("CN=testca,O=Internet Widgits Pty Ltd,ST=Some-State,C=AU"); + // Multiple domains and multiple certs for one of it + X509Certificate anotherCert = TestUtils.loadX509Cert(CLIENT_PEM_FILE); + factory = new XdsTrustManagerFactory(ImmutableMap + .of("example.com", ImmutableList.of(x509Cert), + "google.com", ImmutableList.of(x509Cert, anotherCert)), staticValidationContext, false); + assertThat(factory).isNotNull(); + tms = factory.getTrustManagers(); + assertThat(tms).isNotNull(); + assertThat(tms).hasLength(1); + myTm = tms[0]; + assertThat(myTm).isInstanceOf(XdsX509TrustManager.class); + xdsX509TrustManager = (XdsX509TrustManager) myTm; + assertThat(xdsX509TrustManager.getAcceptedIssuers()).isNotNull(); + assertThat(xdsX509TrustManager.getAcceptedIssuers()).hasLength(2); + assertThat(xdsX509TrustManager.getAcceptedIssuers()[0].getIssuerX500Principal().getName()) + .isEqualTo("CN=testca,O=Internet Widgits Pty Ltd,ST=Some-State,C=AU"); + assertThat(xdsX509TrustManager.getAcceptedIssuers()[1].getIssuerX500Principal().getName()) + .isEqualTo("CN=testca,O=Internet Widgits Pty Ltd,ST=Some-State,C=AU"); + } + @Test public void constructorRootCert_checkServerTrusted() throws CertificateException, IOException, CertStoreException { @@ -112,7 +154,7 @@ public void constructorRootCert_checkServerTrusted() CertificateValidationContext staticValidationContext = buildStaticValidationContext("san1", "waterzooi.test.google.be"); XdsTrustManagerFactory factory = - new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext); + new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext, false); XdsX509TrustManager xdsX509TrustManager = (XdsX509TrustManager) factory.getTrustManagers()[0]; X509Certificate[] serverChain = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); @@ -125,7 +167,7 @@ public void constructorRootCert_nonStaticContext_throwsException() X509Certificate x509Cert = TestUtils.loadX509Cert(CA_PEM_FILE); try { new XdsTrustManagerFactory( - new X509Certificate[] {x509Cert}, getCertContextFromPath(CA_PEM_FILE)); + new X509Certificate[] {x509Cert}, getCertContextFromPath(CA_PEM_FILE), false); Assert.fail("no exception thrown"); } catch (IllegalArgumentException expected) { assertThat(expected) @@ -134,6 +176,19 @@ public void constructorRootCert_nonStaticContext_throwsException() } } + @Test + public void constructorRootCert_nonStaticContext_systemRootCerts_valid() + throws CertificateException, IOException, CertStoreException { + X509Certificate x509Cert = TestUtils.loadX509Cert(CA_PEM_FILE); + CertificateValidationContext certValidationContext = CertificateValidationContext.newBuilder() + .setTrustedCa( + DataSource.newBuilder().setFilename(TestUtils.loadCert(CA_PEM_FILE).getAbsolutePath())) + .setSystemRootCerts(CertificateValidationContext.SystemRootCerts.getDefaultInstance()) + .build(); + XdsTrustManagerFactory unused = + new XdsTrustManagerFactory(new X509Certificate[] {x509Cert}, certValidationContext, false); + } + @Test public void constructorRootCert_checkServerTrusted_throwsException() throws CertificateException, IOException, CertStoreException { @@ -141,7 +196,7 @@ public void constructorRootCert_checkServerTrusted_throwsException() CertificateValidationContext staticValidationContext = buildStaticValidationContext("san1", "san2"); XdsTrustManagerFactory factory = - new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext); + new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext, false); XdsX509TrustManager xdsX509TrustManager = (XdsX509TrustManager) factory.getTrustManagers()[0]; X509Certificate[] serverChain = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); @@ -162,7 +217,7 @@ public void constructorRootCert_checkClientTrusted_throwsException() CertificateValidationContext staticValidationContext = buildStaticValidationContext("san1", "san2"); XdsTrustManagerFactory factory = - new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext); + new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext, false); XdsX509TrustManager xdsX509TrustManager = (XdsX509TrustManager) factory.getTrustManagers()[0]; X509Certificate[] clientChain = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); diff --git a/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsX509TrustManagerTest.java b/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsX509TrustManagerTest.java index 08512396a4f..ffe0536f25b 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsX509TrustManagerTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsX509TrustManagerTest.java @@ -18,9 +18,13 @@ import static com.google.common.truth.Truth.assertThat; import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.BAD_SERVER_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.BAD_WILDCARD_DNS_PEM_FILE; import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CA_PEM_FILE; import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_SPIFFE_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE; import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_1_SPIFFE_PEM_FILE; import static org.junit.Assert.fail; import static org.mockito.Mockito.CALLS_REAL_METHODS; import static org.mockito.Mockito.doReturn; @@ -30,6 +34,7 @@ import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.type.matcher.v3.RegexMatcher; import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; @@ -38,6 +43,9 @@ import java.security.cert.CertStoreException; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.List; import javax.net.ssl.SSLEngine; @@ -48,7 +56,8 @@ import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; @@ -56,7 +65,7 @@ /** * Unit tests for {@link XdsX509TrustManager}. */ -@RunWith(JUnit4.class) +@RunWith(Parameterized.class) public class XdsX509TrustManagerTest { @Rule @@ -70,32 +79,40 @@ public class XdsX509TrustManagerTest { private XdsX509TrustManager trustManager; + private final TestParam testParam; + + public XdsX509TrustManagerTest(TestParam testParam) { + this.testParam = testParam; + } + @Test public void nullCertContextTest() throws CertificateException, IOException { - trustManager = new XdsX509TrustManager(null, mockDelegate); + trustManager = new XdsX509TrustManager(null, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, new ArrayList<>()); } @Test + @SuppressWarnings("deprecation") public void emptySanListContextTest() throws CertificateException, IOException { CertificateValidationContext certContext = CertificateValidationContext.getDefaultInstance(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); } @Test + @SuppressWarnings("deprecation") public void missingPeerCerts() { StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("foo.com").build(); @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); try { - trustManager.verifySubjectAltNameInChain(null); + trustManager.verifySubjectAltNameInChain(null, certContext.getMatchSubjectAltNamesList()); fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate(s) missing"); @@ -103,14 +120,15 @@ public void missingPeerCerts() { } @Test + @SuppressWarnings("deprecation") public void emptyArrayPeerCerts() { StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("foo.com").build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); try { - trustManager.verifySubjectAltNameInChain(new X509Certificate[0]); + trustManager.verifySubjectAltNameInChain( + new X509Certificate[0], certContext.getMatchSubjectAltNamesList()); fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate(s) missing"); @@ -118,16 +136,16 @@ public void emptyArrayPeerCerts() { } @Test + @SuppressWarnings("deprecation") public void noSansInPeerCerts() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("foo.com").build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(CLIENT_PEM_FILE)); try { - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate SAN check failed"); @@ -135,22 +153,23 @@ public void noSansInPeerCerts() throws CertificateException, IOException { } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCertsVerifies() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder() .setExact("waterzooi.test.google.be") .setIgnoreCase(false) .build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCertsVerifies_differentCase_expectException() throws CertificateException, IOException { StringMatcher stringMatcher = @@ -158,14 +177,13 @@ public void oneSanInPeerCertsVerifies_differentCase_expectException() .setExact("waterZooi.test.Google.be") .setIgnoreCase(false) .build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); try { - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate SAN check failed"); @@ -173,47 +191,48 @@ public void oneSanInPeerCertsVerifies_differentCase_expectException() } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCertsVerifies_ignoreCase() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("Waterzooi.Test.google.be").setIgnoreCase(true).build(); @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCerts_prefix() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder() .setPrefix("waterzooi.") // test.google.be .setIgnoreCase(false) .build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCertsPrefix_differentCase_expectException() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setPrefix("waterZooi.").setIgnoreCase(false).build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); try { - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate SAN check failed"); @@ -221,47 +240,47 @@ public void oneSanInPeerCertsPrefix_differentCase_expectException() } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCerts_prefixIgnoreCase() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder() .setPrefix("WaterZooi.") // test.google.be .setIgnoreCase(true) .build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCerts_suffix() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setSuffix(".google.be").setIgnoreCase(false).build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCertsSuffix_differentCase_expectException() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setSuffix(".gooGle.bE").setIgnoreCase(false).build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); try { - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate SAN check failed"); @@ -269,44 +288,45 @@ public void oneSanInPeerCertsSuffix_differentCase_expectException() } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCerts_suffixIgnoreCase() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setSuffix(".GooGle.BE").setIgnoreCase(true).build(); @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCerts_substring() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setContains("zooi.test.google").setIgnoreCase(false).build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCertsSubstring_differentCase_expectException() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setContains("zooi.Test.gooGle").setIgnoreCase(false).build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); try { - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate SAN check failed"); @@ -314,81 +334,81 @@ public void oneSanInPeerCertsSubstring_differentCase_expectException() } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCerts_substringIgnoreCase() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setContains("zooI.Test.Google").setIgnoreCase(true).build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCerts_safeRegex() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder() .setSafeRegex( RegexMatcher.newBuilder().setRegex("water[[:alpha:]]{1}ooi\\.test\\.google\\.be")) .build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCerts_safeRegex1() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder() .setSafeRegex( RegexMatcher.newBuilder().setRegex("no-match-string|\\*\\.test\\.youtube\\.com")) .build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCerts_safeRegex_ipAddress() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder() .setSafeRegex( RegexMatcher.newBuilder().setRegex("([[:digit:]]{1,3}\\.){3}[[:digit:]]{1,3}")) .build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCerts_safeRegex_noMatch() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder() .setSafeRegex( RegexMatcher.newBuilder().setRegex("water[[:alpha:]]{2}ooi\\.test\\.google\\.be")) .build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); try { - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate SAN check failed"); @@ -396,35 +416,35 @@ public void oneSanInPeerCerts_safeRegex_noMatch() throws CertificateException, I } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCertsVerifiesMultipleVerifySans() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("x.foo.com").build(); StringMatcher stringMatcher1 = StringMatcher.newBuilder().setExact("waterzooi.test.google.be").build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder() .addMatchSubjectAltNames(stringMatcher) .addMatchSubjectAltNames(stringMatcher1) .build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); } @Test + @SuppressWarnings("deprecation") public void oneSanInPeerCertsNotFoundException() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("x.foo.com").build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); try { - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate SAN check failed"); @@ -432,42 +452,43 @@ public void oneSanInPeerCertsNotFoundException() } @Test + @SuppressWarnings("deprecation") public void wildcardSanInPeerCertsVerifiesMultipleVerifySans() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("x.foo.com").build(); StringMatcher stringMatcher1 = StringMatcher.newBuilder().setSuffix("test.youTube.Com").setIgnoreCase(true).build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder() .addMatchSubjectAltNames(stringMatcher) .addMatchSubjectAltNames(stringMatcher1) // should match suffix test.youTube.Com .build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); } @Test + @SuppressWarnings("deprecation") public void wildcardSanInPeerCertsVerifiesMultipleVerifySans1() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("x.foo.com").build(); StringMatcher stringMatcher1 = StringMatcher.newBuilder().setContains("est.Google.f").setIgnoreCase(true).build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder() .addMatchSubjectAltNames(stringMatcher) .addMatchSubjectAltNames(stringMatcher1) // should contain est.Google.f .build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); } @Test + @SuppressWarnings("deprecation") public void wildcardSanInPeerCertsSubdomainMismatch() throws CertificateException, IOException { // 2. Asterisk (*) cannot match across domain name labels. @@ -475,14 +496,13 @@ public void wildcardSanInPeerCertsSubdomainMismatch() // sub.test.example.com. StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("sub.abc.test.youtube.com").build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); try { - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate SAN check failed"); @@ -490,36 +510,36 @@ public void wildcardSanInPeerCertsSubdomainMismatch() } @Test + @SuppressWarnings("deprecation") public void oneIpAddressInPeerCertsVerifies() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("x.foo.com").build(); StringMatcher stringMatcher1 = StringMatcher.newBuilder().setExact("192.168.1.3").build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder() .addMatchSubjectAltNames(stringMatcher) .addMatchSubjectAltNames(stringMatcher1) .build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); } @Test + @SuppressWarnings("deprecation") public void oneIpAddressInPeerCertsMismatch() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("x.foo.com").build(); StringMatcher stringMatcher1 = StringMatcher.newBuilder().setExact("192.168.2.3").build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder() .addMatchSubjectAltNames(stringMatcher) .addMatchSubjectAltNames(stringMatcher1) .build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); try { - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate SAN check failed"); @@ -534,6 +554,72 @@ public void checkServerTrustedSslEngine() CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); trustManager.checkServerTrusted(serverCerts, "ECDHE_ECDSA", sslEngine); verify(sslEngine, times(1)).getHandshakeSession(); + assertThat(sslEngine.getSSLParameters().getEndpointIdentificationAlgorithm()).isEmpty(); + } + + @Test + public void checkServerTrustedSslEngineSpiffeTrustMap() + throws CertificateException, IOException, CertStoreException { + TestSslEngine sslEngine = buildTrustManagerAndGetSslEngine(); + X509Certificate[] serverCerts = + CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_SPIFFE_PEM_FILE)); + List caCerts = Arrays.asList(CertificateUtils + .toX509Certificates(TlsTesting.loadCert(CA_PEM_FILE))); + trustManager = XdsTrustManagerFactory.createX509TrustManager( + ImmutableMap.of("example.com", caCerts), null, false); + trustManager.checkServerTrusted(serverCerts, "ECDHE_ECDSA", sslEngine); + verify(sslEngine, times(1)).getHandshakeSession(); + assertThat(sslEngine.getSSLParameters().getEndpointIdentificationAlgorithm()).isEmpty(); + } + + @Test + public void checkServerTrustedSslEngineSpiffeTrustMap_missing_spiffe_id() + throws CertificateException, IOException, CertStoreException { + TestSslEngine sslEngine = buildTrustManagerAndGetSslEngine(); + X509Certificate[] serverCerts = + CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); + List caCerts = Arrays.asList(CertificateUtils + .toX509Certificates(TlsTesting.loadCert(CA_PEM_FILE))); + trustManager = XdsTrustManagerFactory.createX509TrustManager( + ImmutableMap.of("example.com", caCerts), null, false); + try { + trustManager.checkServerTrusted(serverCerts, "ECDHE_ECDSA", sslEngine); + fail("exception expected"); + } catch (CertificateException expected) { + assertThat(expected).hasMessageThat() + .isEqualTo("Failed to extract SPIFFE ID from peer leaf certificate"); + } + } + + @Test + public void checkServerTrustedSpiffeSslEngineTrustMap_missing_trust_domain() + throws CertificateException, IOException, CertStoreException { + TestSslEngine sslEngine = buildTrustManagerAndGetSslEngine(); + X509Certificate[] serverCerts = + CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_SPIFFE_PEM_FILE)); + List caCerts = Arrays.asList(CertificateUtils + .toX509Certificates(TlsTesting.loadCert(CA_PEM_FILE))); + trustManager = XdsTrustManagerFactory.createX509TrustManager( + ImmutableMap.of("unknown.com", caCerts), null, false); + try { + trustManager.checkServerTrusted(serverCerts, "ECDHE_ECDSA", sslEngine); + fail("exception expected"); + } catch (CertificateException expected) { + assertThat(expected).hasMessageThat().isEqualTo("Spiffe Trust Map doesn't contain trust" + + " domain 'example.com' from peer leaf certificate"); + } + } + + @Test + public void checkClientTrustedSpiffeTrustMap() + throws CertificateException, IOException, CertStoreException { + X509Certificate[] clientCerts = + CertificateUtils.toX509Certificates(TlsTesting.loadCert(CLIENT_SPIFFE_PEM_FILE)); + List caCerts = Arrays.asList(CertificateUtils + .toX509Certificates(TlsTesting.loadCert(CA_PEM_FILE))); + trustManager = XdsTrustManagerFactory.createX509TrustManager( + ImmutableMap.of("foo.bar.com", caCerts), null, false); + trustManager.checkClientTrusted(clientCerts, "RSA"); } @Test @@ -561,6 +647,23 @@ public void checkServerTrustedSslSocket() trustManager.checkServerTrusted(serverCerts, "ECDHE_ECDSA", sslSocket); verify(sslSocket, times(1)).isConnected(); verify(sslSocket, times(1)).getHandshakeSession(); + assertThat(sslSocket.getSSLParameters().getEndpointIdentificationAlgorithm()).isEmpty(); + } + + @Test + public void checkServerTrustedSslSocketSpiffeTrustMap() + throws CertificateException, IOException, CertStoreException { + TestSslSocket sslSocket = buildTrustManagerAndGetSslSocket(); + X509Certificate[] serverCerts = + CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_SPIFFE_PEM_FILE)); + List caCerts = Arrays.asList(CertificateUtils + .toX509Certificates(TlsTesting.loadCert(CA_PEM_FILE))); + trustManager = XdsTrustManagerFactory.createX509TrustManager( + ImmutableMap.of("example.com", caCerts), null, false); + trustManager.checkServerTrusted(serverCerts, "ECDHE_ECDSA", sslSocket); + verify(sslSocket, times(1)).isConnected(); + verify(sslSocket, times(1)).getHandshakeSession(); + assertThat(sslSocket.getSSLParameters().getEndpointIdentificationAlgorithm()).isEmpty(); } @Test @@ -581,29 +684,76 @@ public void checkServerTrustedSslSocket_untrustedServer_expectException() } @Test - public void unsupportedAltNameType() throws CertificateException, IOException { + @SuppressWarnings("deprecation") + public void unsupportedAltNameType() throws CertificateException { StringMatcher stringMatcher = StringMatcher.newBuilder() .setExact("waterzooi.test.google.be") .setIgnoreCase(false) .build(); - @SuppressWarnings("deprecation") CertificateValidationContext certContext = CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); X509Certificate mockCert = mock(X509Certificate.class); when(mockCert.getSubjectAlternativeNames()) .thenReturn(Collections.>singleton(ImmutableList.of(Integer.valueOf(1), "foo"))); X509Certificate[] certs = new X509Certificate[] {mockCert}; try { - trustManager.verifySubjectAltNameInChain(certs); + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate SAN check failed"); } } + @Test + @SuppressWarnings("deprecation") + public void testDnsWildcardPatterns() + throws CertificateException, IOException { + StringMatcher stringMatcher = + StringMatcher.newBuilder() + .setExact(testParam.sanPattern) + .setIgnoreCase(testParam.ignoreCase) + .build(); + @SuppressWarnings("deprecation") + CertificateValidationContext certContext = + CertificateValidationContext.newBuilder() + .addMatchSubjectAltNames(stringMatcher) + .build(); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, false); + X509Certificate[] certs = + CertificateUtils.toX509Certificates(TlsTesting.loadCert(testParam.certFile)); + try { + trustManager.verifySubjectAltNameInChain(certs, certContext.getMatchSubjectAltNamesList()); + assertThat(testParam.expected).isTrue(); + } catch (CertificateException certException) { + assertThat(testParam.expected).isFalse(); + assertThat(certException).hasMessageThat().isEqualTo("Peer certificate SAN check failed"); + } + } + + @Parameters(name = "{index}: {0}") + public static Collection getParameters() { + return Arrays.asList(new Object[][] { + {new TestParam("*.test.google.fr", SERVER_1_PEM_FILE, false, true)}, + {new TestParam("*.test.youtube.com", SERVER_1_PEM_FILE, false, true)}, + {new TestParam("waterzooi.test.google.be", SERVER_1_PEM_FILE, false, true)}, + {new TestParam("192.168.1.3", SERVER_1_PEM_FILE, false, true)}, + {new TestParam("*.TEST.YOUTUBE.com", SERVER_1_PEM_FILE, true, true)}, + {new TestParam("w*i.test.google.be", SERVER_1_PEM_FILE, false, true)}, + {new TestParam("w*a.test.google.be", SERVER_1_PEM_FILE, false, false)}, + {new TestParam("*.test.google.com.au", SERVER_0_PEM_FILE, false, false)}, + {new TestParam("*.TEST.YOUTUBE.com", SERVER_1_PEM_FILE, false, false)}, + {new TestParam("*waterzooi", SERVER_1_PEM_FILE, false, false)}, + {new TestParam("*.lyft.com", BAD_WILDCARD_DNS_PEM_FILE, false, false)}, + {new TestParam("ly**ft.com", BAD_WILDCARD_DNS_PEM_FILE, false, false)}, + {new TestParam("*yft.c*m", BAD_WILDCARD_DNS_PEM_FILE, false, false)}, + {new TestParam("xn--*.lyft.com", BAD_WILDCARD_DNS_PEM_FILE, false, false)}, + {new TestParam("", BAD_WILDCARD_DNS_PEM_FILE, false, false)}, + }); + } + private TestSslEngine buildTrustManagerAndGetSslEngine() throws CertificateException, IOException, CertStoreException { SSLParameters sslParams = buildTrustManagerAndGetSslParameters(); @@ -630,7 +780,7 @@ private SSLParameters buildTrustManagerAndGetSslParameters() X509Certificate[] caCerts = CertificateUtils.toX509Certificates(TlsTesting.loadCert(CA_PEM_FILE)); trustManager = XdsTrustManagerFactory.createX509TrustManager(caCerts, - null); + null, false); when(mockSession.getProtocol()).thenReturn("TLSv1.2"); when(mockSession.getPeerHost()).thenReturn("peer-host-from-mock"); SSLParameters sslParams = new SSLParameters(); @@ -667,4 +817,18 @@ public void setSSLParameters(SSLParameters sslParameters) { private SSLParameters sslParameters; } + + private static class TestParam { + final String sanPattern; + final String certFile; + final boolean ignoreCase; + final boolean expected; + + TestParam(String sanPattern, String certFile, boolean ignoreCase, boolean expected) { + this.sanPattern = sanPattern; + this.certFile = certFile; + this.ignoreCase = ignoreCase; + this.expected = expected; + } + } } diff --git a/xds/src/test/java/io/grpc/xds/orca/OrcaOobUtilAccessor.java b/xds/src/test/java/io/grpc/xds/orca/OrcaOobUtilAccessor.java new file mode 100644 index 00000000000..db9168dd08e --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/orca/OrcaOobUtilAccessor.java @@ -0,0 +1,35 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.orca; + +import io.grpc.LoadBalancer; + +/** + * Accessor for white-box testing involving OrcaOobUtil. + */ +public final class OrcaOobUtilAccessor { + private OrcaOobUtilAccessor() { + // Do not instantiate + } + + public static LoadBalancer.SubchannelPicker getDelegate(LoadBalancer.SubchannelPicker picker) { + if (picker instanceof OrcaOobUtil.OrcaReportingHelper.OrcaOobPicker) { + return ((OrcaOobUtil.OrcaReportingHelper.OrcaOobPicker) picker).delegate; + } + return picker; + } +} diff --git a/xds/src/test/resources/certs/sni-test-certs/README b/xds/src/test/resources/certs/sni-test-certs/README new file mode 100644 index 00000000000..25e66021192 --- /dev/null +++ b/xds/src/test/resources/certs/sni-test-certs/README @@ -0,0 +1,55 @@ +Bad Wildcard DNS Certificate (bad_wildcard_dns_certificate.pem) +This certificate is used for testing SNI with invalid wildcard DNS SANs. It is issued by a custom, self-signed Certificate Authority (CA). + +1. Create the Certificate Authority (CA) +Create the CA's private key: +$ openssl genpkey -algorithm RSA -out ca.key -pkeyopt rsa_keygen_bits:2048 +Create the CA's self-signed certificate: +$ openssl req -x509 -new -nodes -key ca.key -sha256 -days 365 -out ca.pem -subj "/CN=My Internal CA" + +2. Generate the Server Certificate +Next, generate the server's private key and a Certificate Signing Request (CSR). +Create the server's private key: +$ openssl genpkey -algorithm RSA -out bad_wildcard_dns.key -pkeyopt rsa_keygen_bits:2048 +Create a configuration file named san.cnf with the following content. This file specifies the Subject Alternative Names (SANs) for the certificate. +[req] +distinguished_name = req_distinguished_name +req_extensions = v3_req +prompt = no + +[req_distinguished_name] +C = US +ST = Illinois +L = Chicago +O = "Example, Co." +CN = *.test.google.com + +[v3_req] +keyUsage = nonRepudiation, digitalSignature, keyEncipherment +extendedKeyUsage = serverAuth +subjectAltName = @alt_names + +[alt_names] +DNS.1 = *.test.google.fr +DNS.2 = *.test.youtube.com +DNS.3 = waterzooi.test.google.be +DNS.4 = 192.168.1.3 +DNS.5 = *.TEST.YOUTUBE.com +DNS.6 = w*i.test.google.be +DNS.7 = w*a.test.google.be +DNS.8 = *.test.google.com.au +DNS.9 = *waterzooi +DNS.10 = *.lyft.com +DNS.11 = ly**ft.com +DNS.12 = *yft.c*m +DNS.13 = xn--*.lyft.com + +Create the Certificate Signing Request (CSR): +$ openssl req -new -key bad_wildcard_dns.key -out bad_wildcard_dns.csr -config san.cnf + +3. Sign the Server Certificate +Finally, use the CA to sign the CSR, which will create the server certificate. +$ openssl x509 -req -in bad_wildcard_dns.csr -CA ca.pem -CAkey ca.key -CAcreateserial -out bad_wildcard_dns_certificate.pem -days 365 -sha256 -extensions v3_req -extfile san.cnf + +4. Clean Up +$ rm bad_wildcard_dns.key san.cnf bad_wildcard_dns.csr ca.key ca.pem ca.srl diff --git a/xds/src/test/resources/certs/sni-test-certs/bad_wildcard_dns_certificate.pem b/xds/src/test/resources/certs/sni-test-certs/bad_wildcard_dns_certificate.pem new file mode 100644 index 00000000000..b015f62e51c --- /dev/null +++ b/xds/src/test/resources/certs/sni-test-certs/bad_wildcard_dns_certificate.pem @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDsjCCApqgAwIBAgIUCs5j4C2KXgCRVFa48kc5TYRS1JwwDQYJKoZIhvcNAQEL +BQAwGTEXMBUGA1UEAwwOTXkgSW50ZXJuYWwgQ0EwIBcNMjUwOTIzMDc1NDUzWhgP +MjEyNTA4MzAwNzU0NTNaMGUxCzAJBgNVBAYTAlVTMREwDwYDVQQIDAhJbGxpbm9p +czEQMA4GA1UEBwwHQ2hpY2FnbzEVMBMGA1UECgwMRXhhbXBsZSwgQ28uMRowGAYD +VQQDDBEqLnRlc3QuZ29vZ2xlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCC +AQoCggEBAKoqcnNh9MV39GH6JjC5KVMN6MO1IoTw6wHJN0JJ/nGNx6ycIsBK8SgJ +eYRR2BEpT6WZba+f04KChcB4Z9tiPISNvUBpmEv76rAsdtcAZwSpF06q4wxHVE5F +rX6mNT8hk448mDBDGHUXNAT6g/e/Vlt6U0XRyuu713gbZq1X6JH29FG7EJ3LUx35 +h6sEkvTlZZ3m6NJr7zYoqrYh/gRkPigtPxaNcoXo0gVm4IEde0sYz27SWyNH4v/o +23NynSulOwx4DwEhBOXekLb5QJHBqwMTPynaMncBQIXF+PXeuxN9a3zR6DSn+jGw +g008tS0tn2FuAvJDBl0paEykdOr2rNMCAwEAAaOBozCBoDALBgNVHQ8EBAMCBeAw +EwYDVR0lBAwwCgYIKwYBBQUHAwEwPAYDVR0RBDUwM4IKbHkqKmZ0LmNvbYIIKnlm +dC5jKm2CCS5seWZ0LmNvbYIOeG4tLSoubHlmdC5jb22CADAdBgNVHQ4EFgQUZoL2 +OzBtK/BUzSYfgXDx3iDjcIQwHwYDVR0jBBgwFoAUHlstFN5WSLSqyJgUDy6BB0z0 +BrgwDQYJKoZIhvcNAQELBQADggEBAMYwVOT7XZJMQ6n32pQqtZhJ/Z0wlVfCAbm0 +7xospeBt6KtOz2zIsvPpq0aqPjowMAeL1EZaBvmfm/XgWUU5e/3hLUIHOHyKfswB +czDbY0RE8nfVDoF4Ck1ljPjvrFr4tSAxTzVA4JU5o3UXkblBg0LG6tTuLlZ3x5aF +KtkZnszxjE+vOg6J9MDbFP/xtA1oVHyCvk+cUgnBxAoPShI+87DINGVTmztBSetK +nJN9dOh7Q88NhTLHOe67Ora9Y0ZP+uFKHaqFv8qj8B/Q6ptb0CAksdL5EunkIHrq +glKdVdYgIP2JpRwtvVHK5FzWBlGXCi3DxTyYi6FWqsSJ+heCS2w= +-----END CERTIFICATE----- diff --git a/xds/third_party/cel-spec/LICENSE b/xds/third_party/cel-spec/LICENSE new file mode 100644 index 00000000000..d6456956733 --- /dev/null +++ b/xds/third_party/cel-spec/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/xds/third_party/cel-spec/import.sh b/xds/third_party/cel-spec/import.sh new file mode 100755 index 00000000000..bba8214fdfb --- /dev/null +++ b/xds/third_party/cel-spec/import.sh @@ -0,0 +1,59 @@ +#!/bin/bash +# Copyright 2024 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Update VERSION then execute this script + +set -e +VERSION="v0.15.0" +DOWNLOAD_URL="https://github.com/google/cel-spec/archive/refs/tags/${VERSION}.tar.gz" +DOWNLOAD_BASE_DIR="cel-spec-${VERSION#v}" +SOURCE_PROTO_BASE_DIR="${DOWNLOAD_BASE_DIR}/proto" +TARGET_PROTO_BASE_DIR="src/main/proto" +# Sorted alphabetically. +FILES=( +cel/expr/checked.proto +cel/expr/syntax.proto +) + +pushd `git rev-parse --show-toplevel`/xds/third_party/cel-spec > /dev/null + +# put the repo in a tmp directory +tmpdir="$(mktemp -d)" +trap "rm -rf ${tmpdir}" EXIT +curl -Ls "${DOWNLOAD_URL}" | tar xz -C "${tmpdir}" + +cp -p "${tmpdir}/${DOWNLOAD_BASE_DIR}/LICENSE" LICENSE + +rm -rf "${TARGET_PROTO_BASE_DIR}" +mkdir -p "${TARGET_PROTO_BASE_DIR}" +pushd "${TARGET_PROTO_BASE_DIR}" > /dev/null + +# copy proto files to project directory +TOTAL=${#FILES[@]} +COPIED=0 +for file in "${FILES[@]}" +do + mkdir -p "$(dirname "${file}")" + cp -p "${tmpdir}/${SOURCE_PROTO_BASE_DIR}/${file}" "${file}" && (( ++COPIED )) +done +popd > /dev/null + +popd > /dev/null + +echo "Imported ${COPIED} files." +if (( COPIED != TOTAL )); then + echo "Failed importing $(( TOTAL - COPIED )) files." 1>&2 + exit 1 +fi diff --git a/xds/third_party/cel-spec/src/main/proto/cel/expr/checked.proto b/xds/third_party/cel-spec/src/main/proto/cel/expr/checked.proto new file mode 100644 index 00000000000..e327db9b225 --- /dev/null +++ b/xds/third_party/cel-spec/src/main/proto/cel/expr/checked.proto @@ -0,0 +1,344 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package cel.expr; + +import "cel/expr/syntax.proto"; +import "google/protobuf/empty.proto"; +import "google/protobuf/struct.proto"; + +option cc_enable_arenas = true; +option go_package = "cel.dev/expr"; +option java_multiple_files = true; +option java_outer_classname = "DeclProto"; +option java_package = "dev.cel.expr"; + +// Protos for representing CEL declarations and typed checked expressions. + +// A CEL expression which has been successfully type checked. +message CheckedExpr { + // A map from expression ids to resolved references. + // + // The following entries are in this table: + // + // - An Ident or Select expression is represented here if it resolves to a + // declaration. For instance, if `a.b.c` is represented by + // `select(select(id(a), b), c)`, and `a.b` resolves to a declaration, + // while `c` is a field selection, then the reference is attached to the + // nested select expression (but not to the id or or the outer select). + // In turn, if `a` resolves to a declaration and `b.c` are field selections, + // the reference is attached to the ident expression. + // - Every Call expression has an entry here, identifying the function being + // called. + // - Every CreateStruct expression for a message has an entry, identifying + // the message. + map reference_map = 2; + + // A map from expression ids to types. + // + // Every expression node which has a type different than DYN has a mapping + // here. If an expression has type DYN, it is omitted from this map to save + // space. + map type_map = 3; + + // The source info derived from input that generated the parsed `expr` and + // any optimizations made during the type-checking pass. + SourceInfo source_info = 5; + + // The expr version indicates the major / minor version number of the `expr` + // representation. + // + // The most common reason for a version change will be to indicate to the CEL + // runtimes that transformations have been performed on the expr during static + // analysis. In some cases, this will save the runtime the work of applying + // the same or similar transformations prior to evaluation. + string expr_version = 6; + + // The checked expression. Semantically equivalent to the parsed `expr`, but + // may have structural differences. + Expr expr = 4; +} + +// Represents a CEL type. +message Type { + // List type with typed elements, e.g. `list`. + message ListType { + // The element type. + Type elem_type = 1; + } + + // Map type with parameterized key and value types, e.g. `map`. + message MapType { + // The type of the key. + Type key_type = 1; + + // The type of the value. + Type value_type = 2; + } + + // Function type with result and arg types. + message FunctionType { + // Result type of the function. + Type result_type = 1; + + // Argument types of the function. + repeated Type arg_types = 2; + } + + // Application defined abstract type. + message AbstractType { + // The fully qualified name of this abstract type. + string name = 1; + + // Parameter types for this abstract type. + repeated Type parameter_types = 2; + } + + // CEL primitive types. + enum PrimitiveType { + // Unspecified type. + PRIMITIVE_TYPE_UNSPECIFIED = 0; + + // Boolean type. + BOOL = 1; + + // Int64 type. + // + // 32-bit integer values are widened to int64. + INT64 = 2; + + // Uint64 type. + // + // 32-bit unsigned integer values are widened to uint64. + UINT64 = 3; + + // Double type. + // + // 32-bit float values are widened to double values. + DOUBLE = 4; + + // String type. + STRING = 5; + + // Bytes type. + BYTES = 6; + } + + // Well-known protobuf types treated with first-class support in CEL. + enum WellKnownType { + // Unspecified type. + WELL_KNOWN_TYPE_UNSPECIFIED = 0; + + // Well-known protobuf.Any type. + // + // Any types are a polymorphic message type. During type-checking they are + // treated like `DYN` types, but at runtime they are resolved to a specific + // message type specified at evaluation time. + ANY = 1; + + // Well-known protobuf.Timestamp type, internally referenced as `timestamp`. + TIMESTAMP = 2; + + // Well-known protobuf.Duration type, internally referenced as `duration`. + DURATION = 3; + } + + // The kind of type. + oneof type_kind { + // Dynamic type. + google.protobuf.Empty dyn = 1; + + // Null value. + google.protobuf.NullValue null = 2; + + // Primitive types: `true`, `1u`, `-2.0`, `'string'`, `b'bytes'`. + PrimitiveType primitive = 3; + + // Wrapper of a primitive type, e.g. `google.protobuf.Int64Value`. + PrimitiveType wrapper = 4; + + // Well-known protobuf type such as `google.protobuf.Timestamp`. + WellKnownType well_known = 5; + + // Parameterized list with elements of `list_type`, e.g. `list`. + ListType list_type = 6; + + // Parameterized map with typed keys and values. + MapType map_type = 7; + + // Function type. + FunctionType function = 8; + + // Protocol buffer message type. + // + // The `message_type` string specifies the qualified message type name. For + // example, `google.type.PhoneNumber`. + string message_type = 9; + + // Type param type. + // + // The `type_param` string specifies the type parameter name, e.g. `list` + // would be a `list_type` whose element type was a `type_param` type + // named `E`. + string type_param = 10; + + // Type type. + // + // The `type` value specifies the target type. e.g. int is type with a + // target type of `Primitive.INT64`. + Type type = 11; + + // Error type. + // + // During type-checking if an expression is an error, its type is propagated + // as the `ERROR` type. This permits the type-checker to discover other + // errors present in the expression. + google.protobuf.Empty error = 12; + + // Abstract, application defined type. + // + // An abstract type has no accessible field names, and it can only be + // inspected via helper / member functions. + AbstractType abstract_type = 14; + } +} + +// Represents a declaration of a named value or function. +// +// A declaration is part of the contract between the expression, the agent +// evaluating that expression, and the caller requesting evaluation. +message Decl { + // Identifier declaration which specifies its type and optional `Expr` value. + // + // An identifier without a value is a declaration that must be provided at + // evaluation time. An identifier with a value should resolve to a constant, + // but may be used in conjunction with other identifiers bound at evaluation + // time. + message IdentDecl { + // Required. The type of the identifier. + Type type = 1; + + // The constant value of the identifier. If not specified, the identifier + // must be supplied at evaluation time. + Constant value = 2; + + // Documentation string for the identifier. + string doc = 3; + } + + // Function declaration specifies one or more overloads which indicate the + // function's parameter types and return type. + // + // Functions have no observable side-effects (there may be side-effects like + // logging which are not observable from CEL). + message FunctionDecl { + // An overload indicates a function's parameter types and return type, and + // may optionally include a function body described in terms of + // [Expr][cel.expr.Expr] values. + // + // Functions overloads are declared in either a function or method + // call-style. For methods, the `params[0]` is the expected type of the + // target receiver. + // + // Overloads must have non-overlapping argument types after erasure of all + // parameterized type variables (similar as type erasure in Java). + message Overload { + // Required. Globally unique overload name of the function which reflects + // the function name and argument types. + // + // This will be used by a [Reference][cel.expr.Reference] to + // indicate the `overload_id` that was resolved for the function `name`. + string overload_id = 1; + + // List of function parameter [Type][cel.expr.Type] values. + // + // Param types are disjoint after generic type parameters have been + // replaced with the type `DYN`. Since the `DYN` type is compatible with + // any other type, this means that if `A` is a type parameter, the + // function types `int` and `int` are not disjoint. Likewise, + // `map` is not disjoint from `map`. + // + // When the `result_type` of a function is a generic type param, the + // type param name also appears as the `type` of on at least one params. + repeated Type params = 2; + + // The type param names associated with the function declaration. + // + // For example, `function ex(K key, map map) : V` would yield + // the type params of `K, V`. + repeated string type_params = 3; + + // Required. The result type of the function. For example, the operator + // `string.isEmpty()` would have `result_type` of `kind: BOOL`. + Type result_type = 4; + + // Whether the function is to be used in a method call-style `x.f(...)` + // of a function call-style `f(x, ...)`. + // + // For methods, the first parameter declaration, `params[0]` is the + // expected type of the target receiver. + bool is_instance_function = 5; + + // Documentation string for the overload. + string doc = 6; + } + + // Required. List of function overloads, must contain at least one overload. + repeated Overload overloads = 1; + } + + // The fully qualified name of the declaration. + // + // Declarations are organized in containers and this represents the full path + // to the declaration in its container, as in `cel.expr.Decl`. + // + // Declarations used as + // [FunctionDecl.Overload][cel.expr.Decl.FunctionDecl.Overload] + // parameters may or may not have a name depending on whether the overload is + // function declaration or a function definition containing a result + // [Expr][cel.expr.Expr]. + string name = 1; + + // Required. The declaration kind. + oneof decl_kind { + // Identifier declaration. + IdentDecl ident = 2; + + // Function declaration. + FunctionDecl function = 3; + } +} + +// Describes a resolved reference to a declaration. +message Reference { + // The fully qualified name of the declaration. + string name = 1; + + // For references to functions, this is a list of `Overload.overload_id` + // values which match according to typing rules. + // + // If the list has more than one element, overload resolution among the + // presented candidates must happen at runtime because of dynamic types. The + // type checker attempts to narrow down this list as much as possible. + // + // Empty if this is not a reference to a + // [Decl.FunctionDecl][cel.expr.Decl.FunctionDecl]. + repeated string overload_id = 3; + + // For references to constants, this may contain the value of the + // constant if known at compile time. + Constant value = 4; +} diff --git a/xds/third_party/cel-spec/src/main/proto/cel/expr/syntax.proto b/xds/third_party/cel-spec/src/main/proto/cel/expr/syntax.proto new file mode 100644 index 00000000000..ed124a74384 --- /dev/null +++ b/xds/third_party/cel-spec/src/main/proto/cel/expr/syntax.proto @@ -0,0 +1,393 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package cel.expr; + +import "google/protobuf/duration.proto"; +import "google/protobuf/struct.proto"; +import "google/protobuf/timestamp.proto"; + +option cc_enable_arenas = true; +option go_package = "cel.dev/expr"; +option java_multiple_files = true; +option java_outer_classname = "SyntaxProto"; +option java_package = "dev.cel.expr"; + +// A representation of the abstract syntax of the Common Expression Language. + +// An expression together with source information as returned by the parser. +message ParsedExpr { + // The parsed expression. + Expr expr = 2; + + // The source info derived from input that generated the parsed `expr`. + SourceInfo source_info = 3; +} + +// An abstract representation of a common expression. +// +// Expressions are abstractly represented as a collection of identifiers, +// select statements, function calls, literals, and comprehensions. All +// operators with the exception of the '.' operator are modelled as function +// calls. This makes it easy to represent new operators into the existing AST. +// +// All references within expressions must resolve to a +// [Decl][cel.expr.Decl] provided at type-check for an expression to be +// valid. A reference may either be a bare identifier `name` or a qualified +// identifier `google.api.name`. References may either refer to a value or a +// function declaration. +// +// For example, the expression `google.api.name.startsWith('expr')` references +// the declaration `google.api.name` within a +// [Expr.Select][cel.expr.Expr.Select] expression, and the function +// declaration `startsWith`. +message Expr { + // An identifier expression. e.g. `request`. + message Ident { + // Required. Holds a single, unqualified identifier, possibly preceded by a + // '.'. + // + // Qualified names are represented by the + // [Expr.Select][cel.expr.Expr.Select] expression. + string name = 1; + } + + // A field selection expression. e.g. `request.auth`. + message Select { + // Required. The target of the selection expression. + // + // For example, in the select expression `request.auth`, the `request` + // portion of the expression is the `operand`. + Expr operand = 1; + + // Required. The name of the field to select. + // + // For example, in the select expression `request.auth`, the `auth` portion + // of the expression would be the `field`. + string field = 2; + + // Whether the select is to be interpreted as a field presence test. + // + // This results from the macro `has(request.auth)`. + bool test_only = 3; + } + + // A call expression, including calls to predefined functions and operators. + // + // For example, `value == 10`, `size(map_value)`. + message Call { + // The target of an method call-style expression. For example, `x` in + // `x.f()`. + Expr target = 1; + + // Required. The name of the function or method being called. + string function = 2; + + // The arguments. + repeated Expr args = 3; + } + + // A list creation expression. + // + // Lists may either be homogenous, e.g. `[1, 2, 3]`, or heterogeneous, e.g. + // `dyn([1, 'hello', 2.0])` + message CreateList { + // The elements part of the list. + repeated Expr elements = 1; + + // The indices within the elements list which are marked as optional + // elements. + // + // When an optional-typed value is present, the value it contains + // is included in the list. If the optional-typed value is absent, the list + // element is omitted from the CreateList result. + repeated int32 optional_indices = 2; + } + + // A map or message creation expression. + // + // Maps are constructed as `{'key_name': 'value'}`. Message construction is + // similar, but prefixed with a type name and composed of field ids: + // `types.MyType{field_id: 'value'}`. + message CreateStruct { + // Represents an entry. + message Entry { + // Required. An id assigned to this node by the parser which is unique + // in a given expression tree. This is used to associate type + // information and other attributes to the node. + int64 id = 1; + + // The `Entry` key kinds. + oneof key_kind { + // The field key for a message creator statement. + string field_key = 2; + + // The key expression for a map creation statement. + Expr map_key = 3; + } + + // Required. The value assigned to the key. + // + // If the optional_entry field is true, the expression must resolve to an + // optional-typed value. If the optional value is present, the key will be + // set; however, if the optional value is absent, the key will be unset. + Expr value = 4; + + // Whether the key-value pair is optional. + bool optional_entry = 5; + } + + // The type name of the message to be created, empty when creating map + // literals. + string message_name = 1; + + // The entries in the creation expression. + repeated Entry entries = 2; + } + + // A comprehension expression applied to a list or map. + // + // Comprehensions are not part of the core syntax, but enabled with macros. + // A macro matches a specific call signature within a parsed AST and replaces + // the call with an alternate AST block. Macro expansion happens at parse + // time. + // + // The following macros are supported within CEL: + // + // Aggregate type macros may be applied to all elements in a list or all keys + // in a map: + // + // * `all`, `exists`, `exists_one` - test a predicate expression against + // the inputs and return `true` if the predicate is satisfied for all, + // any, or only one value `list.all(x, x < 10)`. + // * `filter` - test a predicate expression against the inputs and return + // the subset of elements which satisfy the predicate: + // `payments.filter(p, p > 1000)`. + // * `map` - apply an expression to all elements in the input and return the + // output aggregate type: `[1, 2, 3].map(i, i * i)`. + // + // The `has(m.x)` macro tests whether the property `x` is present in struct + // `m`. The semantics of this macro depend on the type of `m`. For proto2 + // messages `has(m.x)` is defined as 'defined, but not set`. For proto3, the + // macro tests whether the property is set to its default. For map and struct + // types, the macro tests whether the property `x` is defined on `m`. + // + // Comprehension evaluation can be best visualized as the following + // pseudocode: + // + // ``` + // let `accu_var` = `accu_init` + // for (let `iter_var` in `iter_range`) { + // if (!`loop_condition`) { + // break + // } + // `accu_var` = `loop_step` + // } + // return `result` + // ``` + message Comprehension { + // The name of the iteration variable. + string iter_var = 1; + + // The range over which var iterates. + Expr iter_range = 2; + + // The name of the variable used for accumulation of the result. + string accu_var = 3; + + // The initial value of the accumulator. + Expr accu_init = 4; + + // An expression which can contain iter_var and accu_var. + // + // Returns false when the result has been computed and may be used as + // a hint to short-circuit the remainder of the comprehension. + Expr loop_condition = 5; + + // An expression which can contain iter_var and accu_var. + // + // Computes the next value of accu_var. + Expr loop_step = 6; + + // An expression which can contain accu_var. + // + // Computes the result. + Expr result = 7; + } + + // Required. An id assigned to this node by the parser which is unique in a + // given expression tree. This is used to associate type information and other + // attributes to a node in the parse tree. + int64 id = 2; + + // Required. Variants of expressions. + oneof expr_kind { + // A constant expression. + Constant const_expr = 3; + + // An identifier expression. + Ident ident_expr = 4; + + // A field selection expression, e.g. `request.auth`. + Select select_expr = 5; + + // A call expression, including calls to predefined functions and operators. + Call call_expr = 6; + + // A list creation expression. + CreateList list_expr = 7; + + // A map or message creation expression. + CreateStruct struct_expr = 8; + + // A comprehension expression. + Comprehension comprehension_expr = 9; + } +} + +// Represents a primitive literal. +// +// Named 'Constant' here for backwards compatibility. +// +// This is similar as the primitives supported in the well-known type +// `google.protobuf.Value`, but richer so it can represent CEL's full range of +// primitives. +// +// Lists and structs are not included as constants as these aggregate types may +// contain [Expr][cel.expr.Expr] elements which require evaluation and +// are thus not constant. +// +// Examples of constants include: `"hello"`, `b'bytes'`, `1u`, `4.2`, `-2`, +// `true`, `null`. +message Constant { + // Required. The valid constant kinds. + oneof constant_kind { + // null value. + google.protobuf.NullValue null_value = 1; + + // boolean value. + bool bool_value = 2; + + // int64 value. + int64 int64_value = 3; + + // uint64 value. + uint64 uint64_value = 4; + + // double value. + double double_value = 5; + + // string value. + string string_value = 6; + + // bytes value. + bytes bytes_value = 7; + + // protobuf.Duration value. + // + // Deprecated: duration is no longer considered a builtin cel type. + google.protobuf.Duration duration_value = 8 [deprecated = true]; + + // protobuf.Timestamp value. + // + // Deprecated: timestamp is no longer considered a builtin cel type. + google.protobuf.Timestamp timestamp_value = 9 [deprecated = true]; + } +} + +// Source information collected at parse time. +message SourceInfo { + // The syntax version of the source, e.g. `cel1`. + string syntax_version = 1; + + // The location name. All position information attached to an expression is + // relative to this location. + // + // The location could be a file, UI element, or similar. For example, + // `acme/app/AnvilPolicy.cel`. + string location = 2; + + // Monotonically increasing list of code point offsets where newlines + // `\n` appear. + // + // The line number of a given position is the index `i` where for a given + // `id` the `line_offsets[i] < id_positions[id] < line_offsets[i+1]`. The + // column may be derived from `id_positions[id] - line_offsets[i]`. + repeated int32 line_offsets = 3; + + // A map from the parse node id (e.g. `Expr.id`) to the code point offset + // within the source. + map positions = 4; + + // A map from the parse node id where a macro replacement was made to the + // call `Expr` that resulted in a macro expansion. + // + // For example, `has(value.field)` is a function call that is replaced by a + // `test_only` field selection in the AST. Likewise, the call + // `list.exists(e, e > 10)` translates to a comprehension expression. The key + // in the map corresponds to the expression id of the expanded macro, and the + // value is the call `Expr` that was replaced. + map macro_calls = 5; + + // A list of tags for extensions that were used while parsing or type checking + // the source expression. For example, optimizations that require special + // runtime support may be specified. + // + // These are used to check feature support between components in separate + // implementations. This can be used to either skip redundant work or + // report an error if the extension is unsupported. + repeated Extension extensions = 6; + + // An extension that was requested for the source expression. + message Extension { + // Version + message Version { + // Major version changes indicate different required support level from + // the required components. + int64 major = 1; + // Minor version changes must not change the observed behavior from + // existing implementations, but may be provided informationally. + int64 minor = 2; + } + + // CEL component specifier. + enum Component { + // Unspecified, default. + COMPONENT_UNSPECIFIED = 0; + // Parser. Converts a CEL string to an AST. + COMPONENT_PARSER = 1; + // Type checker. Checks that references in an AST are defined and types + // agree. + COMPONENT_TYPE_CHECKER = 2; + // Runtime. Evaluates a parsed and optionally checked CEL AST against a + // context. + COMPONENT_RUNTIME = 3; + } + + // Identifier for the extension. Example: constant_folding + string id = 1; + + // If set, the listed components must understand the extension for the + // expression to evaluate correctly. + // + // This field has set semantics, repeated values should be deduplicated. + repeated Component affected_components = 2; + + // Version info. May be skipped if it isn't meaningful for the extension. + // (for example constant_folding might always be v0.0). + Version version = 3; + } +} diff --git a/xds/third_party/envoy/import.sh b/xds/third_party/envoy/import.sh index a7df33789df..74b8af750ab 100755 --- a/xds/third_party/envoy/import.sh +++ b/xds/third_party/envoy/import.sh @@ -16,8 +16,8 @@ # Update VERSION then execute this script set -e -# import VERSION from the google internal copybara_version.txt for Envoy -VERSION=147e6b9523d8d2ae0d9d2205254d6e633644c6fe +# import VERSION from the google internal go/envoy-import-status +VERSION=a0b3df32ba54c92a08d3636a9a36013cb920e471 DOWNLOAD_URL="https://github.com/envoyproxy/envoy/archive/${VERSION}.tar.gz" DOWNLOAD_BASE_DIR="envoy-${VERSION}" SOURCE_PROTO_BASE_DIR="${DOWNLOAD_BASE_DIR}/api" @@ -33,9 +33,11 @@ envoy/config/cluster/v3/circuit_breaker.proto envoy/config/cluster/v3/cluster.proto envoy/config/cluster/v3/filter.proto envoy/config/cluster/v3/outlier_detection.proto +envoy/config/common/mutation_rules/v3/mutation_rules.proto envoy/config/core/v3/address.proto envoy/config/core/v3/backoff.proto envoy/config/core/v3/base.proto +envoy/config/core/v3/cel.proto envoy/config/core/v3/config_source.proto envoy/config/core/v3/event_service_config.proto envoy/config/core/v3/extension.proto @@ -46,6 +48,7 @@ envoy/config/core/v3/http_uri.proto envoy/config/core/v3/protocol.proto envoy/config/core/v3/proxy_protocol.proto envoy/config/core/v3/resolver.proto +envoy/config/core/v3/socket_cmsg_headers.proto envoy/config/core/v3/socket_option.proto envoy/config/core/v3/substitution_format_string.proto envoy/config/core/v3/udp_socket_config.proto @@ -67,18 +70,27 @@ envoy/config/trace/v3/datadog.proto envoy/config/trace/v3/dynamic_ot.proto envoy/config/trace/v3/http_tracer.proto envoy/config/trace/v3/lightstep.proto -envoy/config/trace/v3/opencensus.proto envoy/config/trace/v3/opentelemetry.proto envoy/config/trace/v3/service.proto -envoy/config/trace/v3/trace.proto envoy/config/trace/v3/zipkin.proto envoy/data/accesslog/v3/accesslog.proto envoy/extensions/clusters/aggregate/v3/cluster.proto envoy/extensions/filters/common/fault/v3/fault.proto +envoy/extensions/filters/http/ext_authz/v3/ext_authz.proto +envoy/extensions/common/matching/v3/extension_matcher.proto envoy/extensions/filters/http/fault/v3/fault.proto +envoy/extensions/filters/http/composite/v3/composite.proto +envoy/extensions/filters/http/rate_limit_quota/v3/rate_limit_quota.proto +envoy/extensions/filters/http/gcp_authn/v3/gcp_authn.proto envoy/extensions/filters/http/rbac/v3/rbac.proto envoy/extensions/filters/http/router/v3/router.proto envoy/extensions/filters/network/http_connection_manager/v3/http_connection_manager.proto +envoy/extensions/grpc_service/call_credentials/access_token/v3/access_token_credentials.proto +envoy/extensions/grpc_service/channel_credentials/google_default/v3/google_default_credentials.proto +envoy/extensions/grpc_service/channel_credentials/insecure/v3/insecure_credentials.proto +envoy/extensions/grpc_service/channel_credentials/local/v3/local_credentials.proto +envoy/extensions/grpc_service/channel_credentials/tls/v3/tls_credentials.proto +envoy/extensions/grpc_service/channel_credentials/xds/v3/xds_credentials.proto envoy/extensions/load_balancing_policies/client_side_weighted_round_robin/v3/client_side_weighted_round_robin.proto envoy/extensions/load_balancing_policies/common/v3/common.proto envoy/extensions/load_balancing_policies/least_request/v3/least_request.proto @@ -86,18 +98,25 @@ envoy/extensions/load_balancing_policies/pick_first/v3/pick_first.proto envoy/extensions/load_balancing_policies/ring_hash/v3/ring_hash.proto envoy/extensions/load_balancing_policies/round_robin/v3/round_robin.proto envoy/extensions/load_balancing_policies/wrr_locality/v3/wrr_locality.proto +envoy/extensions/transport_sockets/http_11_proxy/v3/upstream_http_11_connect.proto envoy/extensions/transport_sockets/tls/v3/cert.proto envoy/extensions/transport_sockets/tls/v3/common.proto envoy/extensions/transport_sockets/tls/v3/secret.proto envoy/extensions/transport_sockets/tls/v3/tls.proto +envoy/service/auth/v3/attribute_context.proto +envoy/service/auth/v3/external_auth.proto envoy/service/discovery/v3/ads.proto envoy/service/discovery/v3/discovery.proto envoy/service/load_stats/v3/lrs.proto +envoy/service/rate_limit_quota/v3/rlqs.proto envoy/service/status/v3/csds.proto envoy/type/http/v3/path_transformation.proto +envoy/type/matcher/v3/address.proto envoy/type/matcher/v3/filter_state.proto +envoy/type/matcher/v3/http_inputs.proto envoy/type/matcher/v3/metadata.proto envoy/type/matcher/v3/node.proto +envoy/config/common/matcher/v3/matcher.proto envoy/type/matcher/v3/number.proto envoy/type/matcher/v3/path.proto envoy/type/matcher/v3/regex.proto @@ -107,9 +126,13 @@ envoy/type/matcher/v3/value.proto envoy/type/metadata/v3/metadata.proto envoy/type/tracing/v3/custom_tag.proto envoy/type/v3/http.proto +envoy/type/v3/http_status.proto envoy/type/v3/percent.proto envoy/type/v3/range.proto +envoy/type/v3/ratelimit_strategy.proto +envoy/type/v3/ratelimit_unit.proto envoy/type/v3/semantic_version.proto +envoy/type/v3/token_bucket.proto ) pushd "$(git rev-parse --show-toplevel)/xds/third_party/envoy" > /dev/null @@ -137,7 +160,7 @@ COPIED=0 for file in "${FILES[@]}" do mkdir -p "$(dirname "${file}")" - cp -p "${tmpdir}/${SOURCE_PROTO_BASE_DIR}/${file}" "${file}" && (( COPIED++ )) + cp -p "${tmpdir}/${SOURCE_PROTO_BASE_DIR}/${file}" "${file}" && (( ++COPIED )) done popd > /dev/null diff --git a/xds/third_party/envoy/src/main/proto/envoy/admin/v3/config_dump_shared.proto b/xds/third_party/envoy/src/main/proto/envoy/admin/v3/config_dump_shared.proto index 8de77e18e1f..b34e004d986 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/admin/v3/config_dump_shared.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/admin/v3/config_dump_shared.proto @@ -39,6 +39,14 @@ enum ClientResourceStatus { // Client received this resource and replied with NACK. NACKED = 4; + + // Client received an error from the control plane. The attached config + // dump is the most recent accepted one. If no config is accepted yet, + // the attached config dump will be empty. + RECEIVED_ERROR = 5; + + // Client timed out waiting for the resource from the control plane. + TIMEOUT = 6; } message UpdateFailureState { diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/accesslog/v3/accesslog.proto b/xds/third_party/envoy/src/main/proto/envoy/config/accesslog/v3/accesslog.proto index fe3ba2bc97c..f273f2e695f 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/accesslog/v3/accesslog.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/accesslog/v3/accesslog.proto @@ -108,6 +108,9 @@ message ComparisonFilter { // <= LE = 2; + + // != + NE = 3; } // Comparison operator. @@ -152,35 +155,38 @@ message TraceableFilter { "envoy.config.filter.accesslog.v2.TraceableFilter"; } -// Filters for random sampling of requests. +// Filters requests based on runtime-configurable sampling rates. message RuntimeFilter { option (udpa.annotations.versioning).previous_message_type = "envoy.config.filter.accesslog.v2.RuntimeFilter"; - // Runtime key to get an optional overridden numerator for use in the - // ``percent_sampled`` field. If found in runtime, this value will replace the - // default numerator. + // Specifies a key used to look up a custom sampling rate from the runtime configuration. If a value is found for this + // key, it will override the default sampling rate specified in ``percent_sampled``. string runtime_key = 1 [(validate.rules).string = {min_len: 1}]; - // The default sampling percentage. If not specified, defaults to 0% with - // denominator of 100. + // Defines the default sampling percentage when no runtime override is present. If not specified, the default is + // **0%** (with a denominator of 100). type.v3.FractionalPercent percent_sampled = 2; - // By default, sampling pivots on the header - // :ref:`x-request-id` being - // present. If :ref:`x-request-id` - // is present, the filter will consistently sample across multiple hosts based - // on the runtime key value and the value extracted from - // :ref:`x-request-id`. If it is - // missing, or ``use_independent_randomness`` is set to true, the filter will - // randomly sample based on the runtime key value alone. - // ``use_independent_randomness`` can be used for logging kill switches within - // complex nested :ref:`AndFilter - // ` and :ref:`OrFilter - // ` blocks that are easier to - // reason about from a probability perspective (i.e., setting to true will - // cause the filter to behave like an independent random variable when - // composed within logical operator filters). + // Controls how sampling decisions are made. + // + // - Default behavior (``false``): + // + // * Uses the :ref:`x-request-id` as a consistent sampling pivot. + // * When :ref:`x-request-id` is present, sampling will be consistent + // across multiple hosts based on both the ``runtime_key`` and + // :ref:`x-request-id`. + // * Useful for tracking related requests across a distributed system. + // + // - When set to ``true`` or :ref:`x-request-id` is missing: + // + // * Sampling decisions are made randomly based only on the ``runtime_key``. + // * Useful in complex filter configurations (like nested + // :ref:`AndFilter`/ + // :ref:`OrFilter` blocks) where independent probability + // calculations are desired. + // * Can be used to implement logging kill switches with predictable probability distributions. + // bool use_independent_randomness = 3; } @@ -256,6 +262,8 @@ message ResponseFlagFilter { in: "OM" in: "DF" in: "DO" + in: "DR" + in: "UDO" } } }]; diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/bootstrap/v3/bootstrap.proto b/xds/third_party/envoy/src/main/proto/envoy/config/bootstrap/v3/bootstrap.proto index b5f36f273bc..7b862c1021a 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/bootstrap/v3/bootstrap.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/bootstrap/v3/bootstrap.proto @@ -16,6 +16,7 @@ import "envoy/config/metrics/v3/stats.proto"; import "envoy/config/overload/v3/overload.proto"; import "envoy/config/trace/v3/http_tracer.proto"; import "envoy/extensions/transport_sockets/tls/v3/secret.proto"; +import "envoy/type/matcher/v3/string.proto"; import "envoy/type/v3/percent.proto"; import "google/protobuf/duration.proto"; @@ -41,7 +42,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // ` for more detail. // Bootstrap :ref:`configuration overview `. -// [#next-free-field: 41] +// [#next-free-field: 43] message Bootstrap { option (udpa.annotations.versioning).previous_message_type = "envoy.config.bootstrap.v2.Bootstrap"; @@ -57,9 +58,7 @@ message Bootstrap { // If a network based configuration source is specified for :ref:`cds_config // `, it's necessary // to have some initial cluster definitions available to allow Envoy to know - // how to speak to the management server. These cluster definitions may not - // use :ref:`EDS ` (i.e. they should be static - // IP or DNS-based). + // how to speak to the management server. repeated cluster.v3.Cluster clusters = 2; // These static secrets can be used by :ref:`SdsSecretConfig @@ -78,7 +77,7 @@ message Bootstrap { // :ref:`LDS ` configuration source. core.v3.ConfigSource lds_config = 1; - // xdstp:// resource locator for listener collection. + // ``xdstp://`` resource locator for listener collection. // [#not-implemented-hide:] string lds_resources_locator = 5; @@ -87,7 +86,7 @@ message Bootstrap { // configuration source. core.v3.ConfigSource cds_config = 2; - // xdstp:// resource locator for cluster collection. + // ``xdstp://`` resource locator for cluster collection. // [#not-implemented-hide:] string cds_resources_locator = 6; @@ -128,17 +127,19 @@ message Bootstrap { // When the flag is enabled, Envoy will lazily initialize a subset of the stats (see below). // This will save memory and CPU cycles when creating the objects that own these stats, if those // stats are never referenced throughout the lifetime of the process. However, it will incur additional - // memory overhead for these objects, and a small increase of CPU usage when a at least one of the stats + // memory overhead for these objects, and a small increase of CPU usage when at least one of the stats // is updated for the first time. + // // Groups of stats that will be lazily initialized: + // // - Cluster traffic stats: a subgroup of the :ref:`cluster statistics ` - // that are used when requests are routed to the cluster. + // that are used when requests are routed to the cluster. bool enable_deferred_creation_stats = 1; } message GrpcAsyncClientManagerConfig { // Optional field to set the expiration time for the cached gRPC client object. - // The minimal value is 5s and the default is 50s. + // The minimal value is ``5s`` and the default is ``50s``. google.protobuf.Duration max_cached_entry_idle_duration = 1 [(validate.rules).duration = {gte {seconds: 5}}]; } @@ -153,25 +154,25 @@ message Bootstrap { // A list of :ref:`Node ` field names // that will be included in the context parameters of the effective - // xdstp:// URL that is sent in a discovery request when resource + // ``xdstp://`` URL that is sent in a discovery request when resource // locators are used for LDS/CDS. Any non-string field will have its JSON // encoding set as the context parameter value, with the exception of // metadata, which will be flattened (see example below). The supported field // names are: - // - "cluster" - // - "id" - // - "locality.region" - // - "locality.sub_zone" - // - "locality.zone" - // - "metadata" - // - "user_agent_build_version.metadata" - // - "user_agent_build_version.version" - // - "user_agent_name" - // - "user_agent_version" + // - ``cluster`` + // - ``id`` + // - ``locality.region`` + // - ``locality.sub_zone`` + // - ``locality.zone`` + // - ``metadata`` + // - ``user_agent_build_version.metadata`` + // - ``user_agent_build_version.version`` + // - ``user_agent_name`` + // - ``user_agent_version`` // // The node context parameters act as a base layer dictionary for the context // parameters (i.e. more specific resource specific context parameters will - // override). Field names will be prefixed with “udpa.node.” when included in + // override). Field names will be prefixed with ````"udpa.node."```` when included in // context parameters. // // For example, if node_context_params is ``["user_agent_name", "metadata"]``, @@ -213,10 +214,10 @@ message Bootstrap { // Optional duration between flushes to configured stats sinks. For // performance reasons Envoy latches counters and only flushes counters and - // gauges at a periodic interval. If not specified the default is 5000ms (5 - // seconds). Only one of ``stats_flush_interval`` or ``stats_flush_on_admin`` + // gauges at a periodic interval. If not specified the default is ``5000ms`` (``5`` seconds). + // Only one of ``stats_flush_interval`` or ``stats_flush_on_admin`` // can be set. - // Duration must be at least 1ms and at most 5 min. + // Duration must be at least ``1ms`` and at most ``5 min``. google.protobuf.Duration stats_flush_interval = 7 [ (validate.rules).duration = { lt {seconds: 300} @@ -232,6 +233,14 @@ message Bootstrap { bool stats_flush_on_admin = 29 [(validate.rules).bool = {const: true}]; } + oneof stats_eviction { + // Optional duration to perform metric eviction. At every interval, during the stats flush + // the unused metrics are removed from the worker caches and the used metrics + // are marked as unused. Must be a multiple of the ``stats_flush_interval``. + google.protobuf.Duration stats_eviction_interval = 42 + [(validate.rules).duration = {gte {nanos: 1000000}}]; + } + // Optional watchdog configuration. // This is for a single watchdog configuration for the entire system. // Deprecated in favor of ``watchdogs`` which has finer granularity. @@ -265,23 +274,28 @@ message Bootstrap { (udpa.annotations.security).configure_for_untrusted_upstream = true ]; - // Enable :ref:`stats for event dispatcher `, defaults to false. - // Note that this records a value for each iteration of the event loop on every thread. This - // should normally be minimal overhead, but when using - // :ref:`statsd `, it will send each observed value - // over the wire individually because the statsd protocol doesn't have any way to represent a - // histogram summary. Be aware that this can be a very large volume of data. + // Enable :ref:`stats for event dispatcher `. Defaults to ``false``. + // + // .. note:: + // + // This records a value for each iteration of the event loop on every thread. This + // should normally be minimal overhead, but when using + // :ref:`statsd `, it will send each observed value + // over the wire individually because the statsd protocol doesn't have any way to represent a + // histogram summary. Be aware that this can be a very large volume of data. bool enable_dispatcher_stats = 16; - // Optional string which will be used in lieu of x-envoy in prefixing headers. + // Optional string which will be used in lieu of ``x-envoy`` in prefixing headers. + // + // For example, if this string is present and set to ``X-Foo``, then ``x-envoy-retry-on`` will be + // transformed into ``x-foo-retry-on`` etc. // - // For example, if this string is present and set to X-Foo, then x-envoy-retry-on will be - // transformed into x-foo-retry-on etc. + // .. note:: // - // Note this applies to the headers Envoy will generate, the headers Envoy will sanitize, and the - // headers Envoy will trust for core code and core extensions only. Be VERY careful making - // changes to this string, especially in multi-layer Envoy deployments or deployments using - // extensions which are not upstream. + // This applies to the headers Envoy will generate, the headers Envoy will sanitize, and the + // headers Envoy will trust for core code and core extensions only. Be VERY careful making + // changes to this string, especially in multi-layer Envoy deployments or deployments using + // extensions which are not upstream. string header_prefix = 18; // Optional proxy version which will be used to set the value of :ref:`server.version statistic @@ -289,8 +303,8 @@ message Bootstrap { // :ref:`stats sinks `. google.protobuf.UInt64Value stats_server_version_override = 19; - // Always use TCP queries instead of UDP queries for DNS lookups. - // This may be overridden on a per-cluster basis in cds_config, + // Always use ``TCP`` queries instead of ``UDP`` queries for DNS lookups. + // This may be overridden on a per-cluster basis in ``cds_config``, // when :ref:`dns_resolvers ` and // :ref:`use_tcp_for_dns_lookups ` are // specified. @@ -299,8 +313,8 @@ message Bootstrap { bool use_tcp_for_dns_lookups = 20 [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; - // DNS resolution configuration which includes the underlying dns resolver addresses and options. - // This may be overridden on a per-cluster basis in cds_config, when + // DNS resolution configuration which includes the underlying DNS resolver addresses and options. + // This may be overridden on a per-cluster basis in ``cds_config``, when // :ref:`dns_resolution_config ` // is specified. // This field is deprecated in favor of @@ -308,14 +322,15 @@ message Bootstrap { core.v3.DnsResolutionConfig dns_resolution_config = 30 [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; - // DNS resolver type configuration extension. This extension can be used to configure c-ares, apple, + // DNS resolver type configuration extension. This extension can be used to configure ``c-ares``, ``apple``, // or any other DNS resolver types and the related parameters. // For example, an object of // :ref:`CaresDnsResolverConfig ` // can be packed into this ``typed_dns_resolver_config``. This configuration replaces the // :ref:`dns_resolution_config ` // configuration. - // During the transition period when both ``dns_resolution_config`` and ``typed_dns_resolver_config`` exists, + // + // During the transition period when both ``dns_resolution_config`` and ``typed_dns_resolver_config`` exist, // when ``typed_dns_resolver_config`` is in place, Envoy will use it and ignore ``dns_resolution_config``. // When ``typed_dns_resolver_config`` is missing, the default behavior is in place. // [#extension-category: envoy.network.dns_resolver] @@ -331,9 +346,10 @@ message Bootstrap { repeated FatalAction fatal_actions = 28; // Configuration sources that will participate in - // xdstp:// URL authority resolution. The algorithm is as + // ``xdstp://`` URL authority resolution. The algorithm is as // follows: - // 1. The authority field is taken from the xdstp:// URL, call + // + // 1. The authority field is taken from the ``xdstp://`` URL, call // this ``resource_authority``. // 2. ``resource_authority`` is compared against the authorities in any peer // ``ConfigSource``. The peer ``ConfigSource`` is the configuration source @@ -349,7 +365,7 @@ message Bootstrap { // [#not-implemented-hide:] repeated core.v3.ConfigSource config_sources = 22; - // Default configuration source for xdstp:// URLs if all + // Default configuration source for ``xdstp://`` URLs if all // other resolution fails. // [#not-implemented-hide:] core.v3.ConfigSource default_config_source = 23; @@ -369,28 +385,30 @@ message Bootstrap { // allows users to customize the inline headers on-demand at Envoy startup without modifying // Envoy's source code. // - // Note that the 'set-cookie' header cannot be registered as inline header. + // .. note:: + // + // The ``set-cookie`` header cannot be registered as inline header. repeated CustomInlineHeader inline_headers = 32; - // Optional path to a file with performance tracing data created by "Perfetto" SDK in binary - // ProtoBuf format. The default value is "envoy.pftrace". + // Optional path to a file with performance tracing data created by ``Perfetto`` SDK in binary + // ProtoBuf format. The default value is ``envoy.pftrace``. string perf_tracing_file_path = 33; // Optional overriding of default regex engine. - // If the value is not specified, Google RE2 will be used by default. + // If the value is not specified, ``Google RE2`` will be used by default. // [#extension-category: envoy.regex_engines] core.v3.TypedExtensionConfig default_regex_engine = 34; // Optional XdsResourcesDelegate configuration, which allows plugging custom logic into both // fetch and load events during xDS processing. - // If a value is not specified, no XdsResourcesDelegate will be used. + // If a value is not specified, no ``XdsResourcesDelegate`` will be used. // TODO(abeyad): Add public-facing documentation. // [#not-implemented-hide:] core.v3.TypedExtensionConfig xds_delegate_extension = 35; // Optional XdsConfigTracker configuration, which allows tracking xDS responses in external components, // e.g., external tracer or monitor. It provides the process point when receive, ingest, or fail to - // process xDS resources and messages. If a value is not specified, no XdsConfigTracker will be used. + // process xDS resources and messages. If a value is not specified, no ``XdsConfigTracker`` will be used. // // .. note:: // @@ -402,20 +420,24 @@ message Bootstrap { // [#not-implemented-hide:] // This controls the type of listener manager configured for Envoy. Currently - // Envoy only supports ListenerManager for this field and Envoy Mobile - // supports ApiListenerManager. + // Envoy only supports ``ListenerManager`` for this field and Envoy Mobile + // supports ``ApiListenerManager``. core.v3.TypedExtensionConfig listener_manager = 37; // Optional application log configuration. ApplicationLogConfig application_log_config = 38; - // Optional gRPC async manager config. + // Optional gRPC async client manager config. GrpcAsyncClientManagerConfig grpc_async_client_manager_config = 40; + + // Optional configuration for memory allocation manager. + // Memory releasing is only supported for `tcmalloc allocator `_. + MemoryAllocatorManager memory_allocator_manager = 41; } // Administration interface :ref:`operations documentation // `. -// [#next-free-field: 7] +// [#next-free-field: 8] message Admin { option (udpa.annotations.versioning).previous_message_type = "envoy.config.bootstrap.v2.Admin"; @@ -424,14 +446,14 @@ message Admin { repeated accesslog.v3.AccessLog access_log = 5; // The path to write the access log for the administration server. If no - // access log is desired specify ‘/dev/null’. This is only required if + // access log is desired specify ``/dev/null``. This is only required if // :ref:`address ` is set. // Deprecated in favor of ``access_log`` which offers more options. string access_log_path = 1 [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; - // The cpu profiler output path for the administration server. If no profile - // path is specified, the default is ‘/var/log/envoy/envoy.prof’. + // The CPU profiler output path for the administration server. If no profile + // path is specified, the default is ``/var/log/envoy/envoy.prof``. string profile_path = 2; // The TCP address that the administration server will listen on. @@ -445,6 +467,21 @@ message Admin { // Indicates whether :ref:`global_downstream_max_connections ` // should apply to the admin interface or not. bool ignore_global_conn_limit = 6; + + // List of admin paths that are accessible. If not specified, all admin endpoints are accessible. + // + // When specified, only paths in this list will be accessible, all others will return ``HTTP 403 Forbidden``. + // + // Example: + // + // .. code-block:: yaml + // + // allow_paths: + // - exact: /stats + // - exact: /ready + // - prefix: /healthcheck + // + repeated type.matcher.v3.StringMatcher allow_paths = 7; } // Cluster manager :ref:`architecture overview `. @@ -481,7 +518,7 @@ message ClusterManager { OutlierDetection outlier_detection = 2; // Optional configuration used to bind newly established upstream connections. - // This may be overridden on a per-cluster basis by upstream_bind_config in the cds_config. + // This may be overridden on a per-cluster basis by ``upstream_bind_config`` in the ``cds_config``. core.v3.BindConfig upstream_bind_config = 3; // A management server endpoint to stream load stats to via @@ -492,7 +529,7 @@ message ClusterManager { // Whether the ClusterManager will create clusters on the worker threads // inline during requests. This will save memory and CPU cycles in cases where - // there are lots of inactive clusters and > 1 worker thread. + // there are lots of inactive clusters and ``> 1`` worker thread. bool enable_deferred_cluster_creation = 5; } @@ -515,12 +552,12 @@ message Watchdog { option (udpa.annotations.versioning).previous_message_type = "envoy.config.bootstrap.v2.Watchdog"; message WatchdogAction { - // The events are fired in this order: KILL, MULTIKILL, MEGAMISS, MISS. + // The events are fired in this order: ``KILL``, ``MULTIKILL``, ``MEGAMISS``, ``MISS``. // Within an event type, actions execute in the order they are configured. - // For KILL/MULTIKILL there is a default PANIC that will run after the + // For ``KILL``/``MULTIKILL`` there is a default ``PANIC`` that will run after the // registered actions and kills the process if it wasn't already killed. // It might be useful to specify several debug actions, and possibly an - // alternate FATAL action. + // alternate ``FATAL`` action. enum WatchdogEvent { UNKNOWN = 0; KILL = 1; @@ -535,46 +572,48 @@ message Watchdog { WatchdogEvent event = 2 [(validate.rules).enum = {defined_only: true}]; } - // Register actions that will fire on given WatchDog events. - // See ``WatchDogAction`` for priority of events. + // Register actions that will fire on given Watchdog events. + // See ``WatchdogAction`` for priority of events. repeated WatchdogAction actions = 7; // The duration after which Envoy counts a nonresponsive thread in the - // ``watchdog_miss`` statistic. If not specified the default is 200ms. + // ``watchdog_miss`` statistic. If not specified the default is ``200ms``. google.protobuf.Duration miss_timeout = 1; // The duration after which Envoy counts a nonresponsive thread in the - // ``watchdog_mega_miss`` statistic. If not specified the default is - // 1000ms. + // ``watchdog_mega_miss`` statistic. If not specified the default is ``1000ms``. google.protobuf.Duration megamiss_timeout = 2; // If a watched thread has been nonresponsive for this duration, assume a - // programming error and kill the entire Envoy process. Set to 0 to disable - // kill behavior. If not specified the default is 0 (disabled). + // programming error and kill the entire Envoy process. Set to ``0`` to disable + // kill behavior. If not specified the default is ``0`` (disabled). google.protobuf.Duration kill_timeout = 3; // Defines the maximum jitter used to adjust the ``kill_timeout`` if ``kill_timeout`` is // enabled. Enabling this feature would help to reduce risk of synchronized - // watchdog kill events across proxies due to external triggers. Set to 0 to - // disable. If not specified the default is 0 (disabled). + // watchdog kill events across proxies due to external triggers. Set to ``0`` to + // disable. If not specified the default is ``0`` (disabled). google.protobuf.Duration max_kill_timeout_jitter = 6 [(validate.rules).duration = {gte {}}]; - // If ``max(2, ceil(registered_threads * Fraction(*multikill_threshold*)))`` + // If ``max(2, ceil(registered_threads * Fraction(multikill_threshold)))`` // threads have been nonresponsive for at least this duration kill the entire - // Envoy process. Set to 0 to disable this behavior. If not specified the - // default is 0 (disabled). + // Envoy process. Set to ``0`` to disable this behavior. If not specified the + // default is ``0`` (disabled). google.protobuf.Duration multikill_timeout = 4; // Sets the threshold for ``multikill_timeout`` in terms of the percentage of // nonresponsive threads required for the ``multikill_timeout``. - // If not specified the default is 0. + // If not specified the default is ``0``. type.v3.Percent multikill_threshold = 5; } // Fatal actions to run while crashing. Actions can be safe (meaning they are // async-signal safe) or unsafe. We run all safe actions before we run unsafe actions. -// If using an unsafe action that could get stuck or deadlock, it important to -// have an out of band system to terminate the process. +// +// .. note:: +// +// If using an unsafe action that could get stuck or deadlock, it is important to +// have an out of band system to terminate the process. // // The interface for the extension is ``Envoy::Server::Configuration::FatalAction``. // ``FatalAction`` extensions live in the ``envoy.extensions.fatal_actions`` API @@ -657,7 +696,7 @@ message RuntimeLayer { option (udpa.annotations.versioning).previous_message_type = "envoy.config.bootstrap.v2.RuntimeLayer.RtdsLayer"; - // Resource to subscribe to at ``rtds_config`` for the RTDS layer. + // Resource to subscribe to at the ``rtds_config`` for the RTDS layer. string name = 1; // RTDS configuration source. @@ -698,11 +737,11 @@ message LayeredRuntime { // Used to specify the header that needs to be registered as an inline header. // // If request or response contain multiple headers with the same name and the header -// name is registered as an inline header. Then multiple headers will be folded +// name is registered as an inline header, then multiple headers will be folded // into one, and multiple header values will be concatenated by a suitable delimiter. // The delimiter is generally a comma. // -// For example, if 'foo' is registered as an inline header, and the headers contains +// For example, if ``foo`` is registered as an inline header, and the headers contain // the following two headers: // // .. code-block:: text @@ -734,3 +773,14 @@ message CustomInlineHeader { // The type of the header that is expected to be set as the inline header. InlineHeaderType inline_header_type = 2 [(validate.rules).enum = {defined_only: true}]; } + +message MemoryAllocatorManager { + // Configures tcmalloc to perform background release of free memory in amount of bytes per ``memory_release_interval`` interval. + // If equals to ``0``, no memory release will occur. Defaults to ``0``. + uint64 bytes_to_release = 1; + + // Interval in milliseconds for memory releasing. If specified, during every + // interval Envoy will try to release ``bytes_to_release`` of free memory back to operating system for reuse. + // Defaults to ``1000`` milliseconds. + google.protobuf.Duration memory_release_interval = 2; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/cluster.proto b/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/cluster.proto index 9b847a33126..192409096af 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/cluster.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/cluster.proto @@ -22,6 +22,7 @@ import "google/protobuf/struct.proto"; import "google/protobuf/wrappers.proto"; import "xds/core/v3/collection_entry.proto"; +import "xds/type/matcher/v3/matcher.proto"; import "envoy/annotations/deprecation.proto"; import "udpa/annotations/migrate.proto"; @@ -45,7 +46,7 @@ message ClusterCollection { } // Configuration for a single upstream cluster. -// [#next-free-field: 57] +// [#next-free-field: 60] message Cluster { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.Cluster"; @@ -168,7 +169,7 @@ message Cluster { // The name of the match, used in stats generation. string name = 1 [(validate.rules).string = {min_len: 1}]; - // Optional endpoint metadata match criteria. + // Optional metadata match criteria. // The connection to the endpoint with metadata matching what is set in this field // will use the transport socket configuration specified here. // The endpoint's metadata entry in ``envoy.transport_socket_match`` is used to match @@ -652,9 +653,10 @@ message Cluster { // If this is not set, we default to a merge window of 1000ms. To disable it, set the merge // window to 0. // - // Note: merging does not apply to cluster membership changes (e.g.: adds/removes); this is - // because merging those updates isn't currently safe. See - // https://github.com/envoyproxy/envoy/pull/3941. + // .. note:: + // Merging does not apply to cluster membership changes (e.g.: adds/removes); this is + // because merging those updates isn't currently safe. See + // https://github.com/envoyproxy/envoy/pull/3941. google.protobuf.Duration update_merge_window = 4; // If set to true, Envoy will :ref:`exclude ` new hosts @@ -746,6 +748,9 @@ message Cluster { // If both this and preconnect_ratio are set, Envoy will make sure both predicted needs are met, // basically preconnecting max(predictive-preconnect, per-upstream-preconnect), for each // upstream. + // + // This is limited somewhat arbitrarily to 3 because preconnecting too aggressively can + // harm latency more than the preconnecting helps. google.protobuf.DoubleValue predictive_preconnect_ratio = 2 [(validate.rules).double = {lte: 3.0 gte: 1.0}]; } @@ -754,12 +759,14 @@ message Cluster { reserved "hosts", "tls_context", "extension_protocol_options"; - // Configuration to use different transport sockets for different endpoints. - // The entry of ``envoy.transport_socket_match`` in the - // :ref:`LbEndpoint.Metadata ` - // is used to match against the transport sockets as they appear in the list. The first - // :ref:`match ` is used. - // For example, with the following match + // Configuration to use different transport sockets for different endpoints. The entry of + // ``envoy.transport_socket_match`` in the :ref:`LbEndpoint.Metadata + // ` is used to match against the + // transport sockets as they appear in the list. If a match is not found, the search continues in + // :ref:`LocalityLbEndpoints.Metadata + // `. The first :ref:`match + // ` is used. For example, with + // the following match // // .. code-block:: yaml // @@ -783,8 +790,9 @@ message Cluster { // socket match in case above. // // If an endpoint metadata's value under ``envoy.transport_socket_match`` does not match any - // ``TransportSocketMatch``, socket configuration fallbacks to use the ``tls_context`` or - // ``transport_socket`` specified in this cluster. + // ``TransportSocketMatch``, the locality metadata is then checked for a match. Barring any + // matches in the endpoint or locality metadata, the socket configuration fallbacks to use the + // ``tls_context`` or ``transport_socket`` specified in this cluster. // // This field allows gradual and flexible transport socket configuration changes. // @@ -805,6 +813,41 @@ message Cluster { // [#comment:TODO(incfly): add a detailed architecture doc on intended usage.] repeated TransportSocketMatch transport_socket_matches = 43; + // Optional matcher that selects a transport socket from + // :ref:`transport_socket_matches `. + // + // This matcher uses the generic xDS matcher framework to select a named transport socket + // based on various inputs available at transport socket selection time. + // + // Supported matching inputs: + // + // * ``endpoint_metadata``: Extract values from the selected endpoint's metadata. + // * ``locality_metadata``: Extract values from the endpoint's locality metadata. + // * ``transport_socket_filter_state``: Extract values from filter state that was explicitly shared from + // downstream to upstream via ``TransportSocketOptions``. This enables flexible + // downstream-connection-based matching, such as: + // + // - Network namespace matching. + // - Custom connection attributes. + // - Any data explicitly passed via filter state. + // + // .. note:: + // Filter state sharing follows the same pattern as tunneling in Envoy. Filters must explicitly + // share data by setting filter state with the appropriate sharing mode. The filter state is + // then accessible via the ``transport_socket_filter_state`` input during transport socket selection. + // + // If this field is set, it takes precedence over legacy metadata-based selection + // performed by :ref:`transport_socket_matches + // ` alone. + // If the matcher does not yield a match, Envoy uses the default transport socket + // configured for the cluster. + // + // When using this field, each entry in + // :ref:`transport_socket_matches ` + // must have a unique ``name``. The matcher outcome is expected to reference one of + // these names. + xds.type.matcher.v3.Matcher transport_socket_matcher = 59; + // Supplies the name of the cluster which must be unique across all clusters. // The cluster name is used when emitting // :ref:`statistics ` if :ref:`alt_stat_name @@ -813,12 +856,14 @@ message Cluster { string name = 1 [(validate.rules).string = {min_len: 1}]; // An optional alternative to the cluster name to be used for observability. This name is used - // emitting stats for the cluster and access logging the cluster name. This will appear as + // for emitting stats for the cluster and access logging the cluster name. This will appear as // additional information in configuration dumps of a cluster's current status as // :ref:`observability_name ` - // and as an additional tag "upstream_cluster.name" while tracing. Note: Any ``:`` in the name - // will be converted to ``_`` when emitting statistics. This should not be confused with - // :ref:`Router Filter Header `. + // and as an additional tag "upstream_cluster.name" while tracing. + // + // .. note:: + // Any ``:`` in the name will be converted to ``_`` when emitting statistics. This should not be confused with + // :ref:`Router Filter Header `. string alt_stat_name = 28 [(udpa.annotations.field_migrate).rename = "observability_name"]; oneof cluster_discovery_type { @@ -939,6 +984,7 @@ message Cluster { // "envoy.filters.network.thrift_proxy". See the extension's documentation for details on // specific options. // [#next-major-version: make this a list of typed extensions.] + // [#extension-category: envoy.upstream_options] map typed_extension_protocol_options = 36; // If the DNS refresh rate is specified and the cluster type is either @@ -950,8 +996,34 @@ message Cluster { // :ref:`STRICT_DNS` // and :ref:`LOGICAL_DNS` // this setting is ignored. - google.protobuf.Duration dns_refresh_rate = 16 - [(validate.rules).duration = {gt {nanos: 1000000}}]; + // This field is deprecated in favor of using the :ref:`cluster_type` + // extension point and configuring it with :ref:`DnsCluster`. + // If :ref:`cluster_type` is configured with + // :ref:`DnsCluster`, this field will be ignored. + google.protobuf.Duration dns_refresh_rate = 16 [ + deprecated = true, + (validate.rules).duration = {gt {nanos: 1000000}}, + (envoy.annotations.deprecated_at_minor_version) = "3.0" + ]; + + // DNS jitter can be optionally specified if the cluster type is either + // :ref:`STRICT_DNS`, + // or :ref:`LOGICAL_DNS`. + // DNS jitter causes the cluster to refresh DNS entries later by a random amount of time to avoid a + // stampede of DNS requests. This value sets the upper bound (exclusive) for the random amount. + // There will be no jitter if this value is omitted. For cluster types other than + // :ref:`STRICT_DNS` + // and :ref:`LOGICAL_DNS` + // this setting is ignored. + // This field is deprecated in favor of using the :ref:`cluster_type` + // extension point and configuring it with :ref:`DnsCluster`. + // If :ref:`cluster_type` is configured with + // :ref:`DnsCluster`, this field will be ignored. + google.protobuf.Duration dns_jitter = 58 [ + deprecated = true, + (validate.rules).duration = {gte {}}, + (envoy.annotations.deprecated_at_minor_version) = "3.0" + ]; // If the DNS failure refresh rate is specified and the cluster type is either // :ref:`STRICT_DNS`, @@ -961,16 +1033,31 @@ message Cluster { // other than :ref:`STRICT_DNS` and // :ref:`LOGICAL_DNS` this setting is // ignored. - RefreshRate dns_failure_refresh_rate = 44; + // This field is deprecated in favor of using the :ref:`cluster_type` + // extension point and configuring it with :ref:`DnsCluster`. + // If :ref:`cluster_type` is configured with + // :ref:`DnsCluster`, this field will be ignored. + RefreshRate dns_failure_refresh_rate = 44 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // Optional configuration for setting cluster's DNS refresh rate. If the value is set to true, // cluster's DNS refresh rate will be set to resource record's TTL which comes from DNS // resolution. - bool respect_dns_ttl = 39; + // This field is deprecated in favor of using the :ref:`cluster_type` + // extension point and configuring it with :ref:`DnsCluster`. + // If :ref:`cluster_type` is configured with + // :ref:`DnsCluster`, this field will be ignored. + bool respect_dns_ttl = 39 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // The DNS IP address resolution policy. If this setting is not specified, the // value defaults to // :ref:`AUTO`. + // For logical and strict dns cluster, this field is deprecated in favor of using the + // :ref:`cluster_type` + // extension point and configuring it with :ref:`DnsCluster`. + // If :ref:`cluster_type` is configured with + // :ref:`DnsCluster`, this field will be ignored. DnsLookupFamily dns_lookup_family = 17 [(validate.rules).enum = {defined_only: true}]; // If DNS resolvers are specified and the cluster type is either @@ -1010,6 +1097,9 @@ message Cluster { // During the transition period when both ``dns_resolution_config`` and ``typed_dns_resolver_config`` exists, // when ``typed_dns_resolver_config`` is in place, Envoy will use it and ignore ``dns_resolution_config``. // When ``typed_dns_resolver_config`` is missing, the default behavior is in place. + // Also note that this field is deprecated for logical dns and strict dns clusters and will be ignored when + // :ref:`cluster_type` is configured with + // :ref:`DnsCluster`. // [#extension-category: envoy.network.dns_resolver] core.v3.TypedExtensionConfig typed_dns_resolver_config = 55; @@ -1148,6 +1238,23 @@ message Cluster { // from the LRS stream here.] core.v3.ConfigSource lrs_server = 42; + // A list of metric names from :ref:`ORCA load reports ` to propagate to LRS. + // + // If not specified, then ORCA load reports will not be propagated to LRS. + // + // For map fields in the ORCA proto, the string will be of the form ``.``. + // For example, the string ``named_metrics.foo`` will mean to look for the key ``foo`` in the ORCA + // :ref:`named_metrics ` field. + // + // The special map key ``*`` means to report all entries in the map (e.g., ``named_metrics.*`` means to + // report all entries in the ORCA named_metrics field). Note that this should be used only with trusted + // backends. + // + // The metric names in LRS will follow the same semantics as this field. In other words, if this field + // contains ``named_metrics.foo``, then the LRS load report will include the data with that same string + // as the key. + repeated string lrs_report_endpoint_metrics = 57; + // If track_timeout_budgets is true, the :ref:`timeout budget histograms // ` will be published for each // request. These show what percentage of a request's per try and global timeout was used. A value @@ -1236,6 +1343,26 @@ message UpstreamConnectionOptions { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.UpstreamConnectionOptions"; + enum FirstAddressFamilyVersion { + // respect the native ranking of destination ip addresses returned from dns + // resolution + DEFAULT = 0; + + V4 = 1; + + V6 = 2; + } + + message HappyEyeballsConfig { + // Specify the IP address family to attempt connection first in happy + // eyeballs algorithm according to RFC8305#section-4. + FirstAddressFamilyVersion first_address_family_version = 1; + + // Specify the number of addresses of the first_address_family_version being + // attempted for connection before the other address family. + google.protobuf.UInt32Value first_address_family_count = 2 [(validate.rules).uint32 = {gte: 1}]; + } + // If set then set SO_KEEPALIVE on the socket to enable TCP Keepalives. core.v3.TcpKeepalive tcp_keepalive = 1; @@ -1243,6 +1370,11 @@ message UpstreamConnectionOptions { // This can be used by extensions during processing of requests. The association mechanism is // implementation specific. Defaults to false due to performance concerns. bool set_local_interface_name_on_upstream_connections = 2; + + // Configurations for happy eyeballs algorithm. + // Add configs for first_address_family_version and first_address_family_count + // when sorting destination ip addresses. + HappyEyeballsConfig happy_eyeballs_config = 3; } message TrackClusterStats { @@ -1255,7 +1387,7 @@ message TrackClusterStats { // If request_response_sizes is true, then the :ref:`histograms // ` tracking header and body sizes - // of requests and responses will be published. + // of requests and responses will be published. Additionally, number of headers in the requests and responses will be tracked. bool request_response_sizes = 2; // If true, some stats will be emitted per-endpoint, similar to the stats in admin ``/clusters`` diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/outlier_detection.proto b/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/outlier_detection.proto index 11289e26b4f..822d81da850 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/outlier_detection.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/outlier_detection.proto @@ -2,6 +2,8 @@ syntax = "proto3"; package envoy.config.cluster.v3; +import "envoy/config/core/v3/extension.proto"; + import "google/protobuf/duration.proto"; import "google/protobuf/wrappers.proto"; @@ -19,7 +21,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // See the :ref:`architecture overview ` for // more information on outlier detection. -// [#next-free-field: 24] +// [#next-free-field: 26] message OutlierDetection { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.cluster.OutlierDetection"; @@ -40,8 +42,8 @@ message OutlierDetection { // Defaults to 30000ms or 30s. google.protobuf.Duration base_ejection_time = 3 [(validate.rules).duration = {gt {}}]; - // The maximum % of an upstream cluster that can be ejected due to outlier - // detection. Defaults to 10% but will eject at least one host regardless of the value. + // The maximum % of an upstream cluster that can be ejected due to outlier detection. Defaults to 10% . + // Will eject at least one host regardless of the value if :ref:`always_eject_one_host` is enabled. google.protobuf.UInt32Value max_ejection_percent = 4 [(validate.rules).uint32 = {lte: 100}]; // The % chance that a host will be actually ejected when an outlier status @@ -167,4 +169,12 @@ message OutlierDetection { // To change this default behavior set this config to ``false`` where active health checking will not uneject the host. // Defaults to true. google.protobuf.BoolValue successful_active_health_check_uneject_host = 23; + + // Set of host's passive monitors. + // [#not-implemented-hide:] + repeated core.v3.TypedExtensionConfig monitors = 24; + + // If enabled, at least one host is ejected regardless of the value of :ref:`max_ejection_percent`. + // Defaults to false. + google.protobuf.BoolValue always_eject_one_host = 25; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/common/matcher/v3/matcher.proto b/xds/third_party/envoy/src/main/proto/envoy/config/common/matcher/v3/matcher.proto new file mode 100644 index 00000000000..9b189d1aa77 --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/config/common/matcher/v3/matcher.proto @@ -0,0 +1,239 @@ +syntax = "proto3"; + +package envoy.config.common.matcher.v3; + +import "envoy/config/core/v3/extension.proto"; +import "envoy/config/route/v3/route_components.proto"; +import "envoy/type/matcher/v3/string.proto"; + +import "udpa/annotations/status.proto"; +import "validate/validate.proto"; + +option java_package = "io.envoyproxy.envoy.config.common.matcher.v3"; +option java_outer_classname = "MatcherProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/common/matcher/v3;matcherv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: Unified Matcher API] + +// A matcher, which may traverse a matching tree in order to result in a match action. +// During matching, the tree will be traversed until a match is found, or if no match +// is found the action specified by the most specific on_no_match will be evaluated. +// As an on_no_match might result in another matching tree being evaluated, this process +// might repeat several times until the final OnMatch (or no match) is decided. +// +// .. note:: +// Please use the syntactically equivalent :ref:`matching API ` +message Matcher { + // What to do if a match is successful. + message OnMatch { + oneof on_match { + option (validate.required) = true; + + // Nested matcher to evaluate. + // If the nested matcher does not match and does not specify + // on_no_match, then this matcher is considered not to have + // matched, even if a predicate at this level or above returned + // true. + Matcher matcher = 1; + + // Protocol-specific action to take. + core.v3.TypedExtensionConfig action = 2; + } + + // If true, the action will be taken but the caller will behave as if no + // match was found. This applies both to actions directly encoded in the + // action field and to actions returned from a nested matcher tree in the + // matcher field. A subsequent matcher on_no_match action will be used + // instead. + // + // This field is not supported in all contexts in which the matcher API is + // used. If this field is set in a context in which it's not supported, + // the resource will be rejected. + bool keep_matching = 3; + } + + // A linear list of field matchers. + // The field matchers are evaluated in order, and the first match + // wins. + message MatcherList { + // Predicate to determine if a match is successful. + message Predicate { + // Predicate for a single input field. + message SinglePredicate { + // Protocol-specific specification of input field to match on. + // [#extension-category: envoy.matching.common_inputs] + core.v3.TypedExtensionConfig input = 1 [(validate.rules).message = {required: true}]; + + oneof matcher { + option (validate.required) = true; + + // Built-in string matcher. + type.matcher.v3.StringMatcher value_match = 2; + + // Extension for custom matching logic. + // [#extension-category: envoy.matching.input_matchers] + core.v3.TypedExtensionConfig custom_match = 3; + } + } + + // A list of two or more matchers. Used to allow using a list within a oneof. + message PredicateList { + repeated Predicate predicate = 1 [(validate.rules).repeated = {min_items: 2}]; + } + + oneof match_type { + option (validate.required) = true; + + // A single predicate to evaluate. + SinglePredicate single_predicate = 1; + + // A list of predicates to be OR-ed together. + PredicateList or_matcher = 2; + + // A list of predicates to be AND-ed together. + PredicateList and_matcher = 3; + + // The inverse of a predicate + Predicate not_matcher = 4; + } + } + + // An individual matcher. + message FieldMatcher { + // Determines if the match succeeds. + Predicate predicate = 1 [(validate.rules).message = {required: true}]; + + // What to do if the match succeeds. + OnMatch on_match = 2 [(validate.rules).message = {required: true}]; + } + + // A list of matchers. First match wins. + repeated FieldMatcher matchers = 1 [(validate.rules).repeated = {min_items: 1}]; + } + + message MatcherTree { + // A map of configured matchers. Used to allow using a map within a oneof. + message MatchMap { + map map = 1 [(validate.rules).map = {min_pairs: 1}]; + } + + // Protocol-specific specification of input field to match on. + core.v3.TypedExtensionConfig input = 1 [(validate.rules).message = {required: true}]; + + // Exact or prefix match maps in which to look up the input value. + // If the lookup succeeds, the match is considered successful, and + // the corresponding OnMatch is used. + oneof tree_type { + option (validate.required) = true; + + MatchMap exact_match_map = 2; + + // Longest matching prefix wins. + MatchMap prefix_match_map = 3; + + // Extension for custom matching logic. + core.v3.TypedExtensionConfig custom_match = 4; + } + } + + oneof matcher_type { + option (validate.required) = true; + + // A linear list of matchers to evaluate. + MatcherList matcher_list = 1; + + // A match tree to evaluate. + MatcherTree matcher_tree = 2; + } + + // Optional ``OnMatch`` to use if the matcher failed. + // If specified, the ``OnMatch`` is used, and the matcher is considered + // to have matched. + // If not specified, the matcher is considered not to have matched. + OnMatch on_no_match = 3; +} + +// Match configuration. This is a recursive structure which allows complex nested match +// configurations to be built using various logical operators. +// [#next-free-field: 11] +message MatchPredicate { + // A set of match configurations used for logical operations. + message MatchSet { + // The list of rules that make up the set. + repeated MatchPredicate rules = 1 [(validate.rules).repeated = {min_items: 2}]; + } + + oneof rule { + option (validate.required) = true; + + // A set that describes a logical OR. If any member of the set matches, the match configuration + // matches. + MatchSet or_match = 1; + + // A set that describes a logical AND. If all members of the set match, the match configuration + // matches. + MatchSet and_match = 2; + + // A negation match. The match configuration will match if the negated match condition matches. + MatchPredicate not_match = 3; + + // The match configuration will always match. + bool any_match = 4 [(validate.rules).bool = {const: true}]; + + // HTTP request headers match configuration. + HttpHeadersMatch http_request_headers_match = 5; + + // HTTP request trailers match configuration. + HttpHeadersMatch http_request_trailers_match = 6; + + // HTTP response headers match configuration. + HttpHeadersMatch http_response_headers_match = 7; + + // HTTP response trailers match configuration. + HttpHeadersMatch http_response_trailers_match = 8; + + // HTTP request generic body match configuration. + HttpGenericBodyMatch http_request_generic_body_match = 9; + + // HTTP response generic body match configuration. + HttpGenericBodyMatch http_response_generic_body_match = 10; + } +} + +// HTTP headers match configuration. +message HttpHeadersMatch { + // HTTP headers to match. + repeated route.v3.HeaderMatcher headers = 1; +} + +// HTTP generic body match configuration. +// List of text strings and hex strings to be located in HTTP body. +// All specified strings must be found in the HTTP body for positive match. +// The search may be limited to specified number of bytes from the body start. +// +// .. attention:: +// +// Searching for patterns in HTTP body is potentially CPU-intensive. For each specified pattern, HTTP body is scanned byte by byte to find a match. +// If multiple patterns are specified, the process is repeated for each pattern. If location of a pattern is known, ``bytes_limit`` should be specified +// to scan only part of the HTTP body. +message HttpGenericBodyMatch { + message GenericTextMatch { + oneof rule { + option (validate.required) = true; + + // Text string to be located in HTTP body. + string string_match = 1 [(validate.rules).string = {min_len: 1}]; + + // Sequence of bytes to be located in HTTP body. + bytes binary_match = 2 [(validate.rules).bytes = {min_len: 1}]; + } + } + + // Limits search to specified number of bytes - default zero (no limit - match entire captured buffer). + uint32 bytes_limit = 1; + + // List of patterns to match. + repeated GenericTextMatch patterns = 2 [(validate.rules).repeated = {min_items: 1}]; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/common/mutation_rules/v3/mutation_rules.proto b/xds/third_party/envoy/src/main/proto/envoy/config/common/mutation_rules/v3/mutation_rules.proto new file mode 100644 index 00000000000..c015db21431 --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/config/common/mutation_rules/v3/mutation_rules.proto @@ -0,0 +1,113 @@ +syntax = "proto3"; + +package envoy.config.common.mutation_rules.v3; + +import "envoy/config/core/v3/base.proto"; +import "envoy/type/matcher/v3/regex.proto"; +import "envoy/type/matcher/v3/string.proto"; + +import "google/protobuf/wrappers.proto"; + +import "udpa/annotations/status.proto"; +import "validate/validate.proto"; + +option java_package = "io.envoyproxy.envoy.config.common.mutation_rules.v3"; +option java_outer_classname = "MutationRulesProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/common/mutation_rules/v3;mutation_rulesv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: Header mutation rules] + +// The HeaderMutationRules structure specifies what headers may be +// manipulated by a processing filter. This set of rules makes it +// possible to control which modifications a filter may make. +// +// By default, an external processing server may add, modify, or remove +// any header except for an "Envoy internal" header (which is typically +// denoted by an x-envoy prefix) or specific headers that may affect +// further filter processing: +// +// * ``host`` +// * ``:authority`` +// * ``:scheme`` +// * ``:method`` +// +// Every attempt to add, change, append, or remove a header will be +// tested against the rules here. Disallowed header mutations will be +// ignored unless ``disallow_is_error`` is set to true. +// +// Attempts to remove headers are further constrained -- regardless of the +// settings, system-defined headers (that start with ``:``) and the ``host`` +// header may never be removed. +// +// In addition, a counter will be incremented whenever a mutation is +// rejected. In the ext_proc filter, that counter is named +// ``rejected_header_mutations``. +// [#next-free-field: 8] +message HeaderMutationRules { + // By default, certain headers that could affect processing of subsequent + // filters or request routing cannot be modified. These headers are + // ``host``, ``:authority``, ``:scheme``, and ``:method``. Setting this parameter + // to true allows these headers to be modified as well. + google.protobuf.BoolValue allow_all_routing = 1; + + // If true, allow modification of envoy internal headers. By default, these + // start with ``x-envoy`` but this may be overridden in the ``Bootstrap`` + // configuration using the + // :ref:`header_prefix ` + // field. Default is false. + google.protobuf.BoolValue allow_envoy = 2; + + // If true, prevent modification of any system header, defined as a header + // that starts with a ``:`` character, regardless of any other settings. + // A processing server may still override the ``:status`` of an HTTP response + // using an ``ImmediateResponse`` message. Default is false. + google.protobuf.BoolValue disallow_system = 3; + + // If true, prevent modifications of all header values, regardless of any + // other settings. A processing server may still override the ``:status`` + // of an HTTP response using an ``ImmediateResponse`` message. Default is false. + google.protobuf.BoolValue disallow_all = 4; + + // If set, specifically allow any header that matches this regular + // expression. This overrides all other settings except for + // ``disallow_expression``. + type.matcher.v3.RegexMatcher allow_expression = 5; + + // If set, specifically disallow any header that matches this regular + // expression regardless of any other settings. + type.matcher.v3.RegexMatcher disallow_expression = 6; + + // If true, and if the rules in this list cause a header mutation to be + // disallowed, then the filter using this configuration will terminate the + // request with a 500 error. In addition, regardless of the setting of this + // parameter, any attempt to set, add, or modify a disallowed header will + // cause the ``rejected_header_mutations`` counter to be incremented. + // Default is false. + google.protobuf.BoolValue disallow_is_error = 7; +} + +// The HeaderMutation structure specifies an action that may be taken on HTTP +// headers. +message HeaderMutation { + message RemoveOnMatch { + // A string matcher that will be applied to the header key. If the header key + // matches, the header will be removed. + type.matcher.v3.StringMatcher key_matcher = 1 [(validate.rules).message = {required: true}]; + } + + oneof action { + option (validate.required) = true; + + // Remove the specified header if it exists. + string remove = 1 + [(validate.rules).string = {well_known_regex: HTTP_HEADER_VALUE strict: false}]; + + // Append new header by the specified HeaderValueOption. + core.v3.HeaderValueOption append = 2; + + // Remove the header if the key matches the specified string matcher. + RemoveOnMatch remove_on_match = 3; + } +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/address.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/address.proto index d8d47882655..17a68269e34 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/address.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/address.proto @@ -50,7 +50,7 @@ message EnvoyInternalAddress { string endpoint_id = 2; } -// [#next-free-field: 7] +// [#next-free-field: 8] message SocketAddress { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.core.SocketAddress"; @@ -97,6 +97,17 @@ message SocketAddress { // allow both IPv4 and IPv6 connections, with peer IPv4 addresses mapped into // IPv6 space as ``::FFFF:``. bool ipv4_compat = 6; + + // Filepath that specifies the Linux network namespace this socket will be created in (see ``man 7 + // network_namespaces``). If this field is set, Envoy will create the socket in the specified + // network namespace. + // + // .. note:: + // Setting this parameter requires Envoy to run with the ``CAP_NET_ADMIN`` capability. + // + // .. attention:: + // Network namespaces are only configurable on Linux. Otherwise, this field has no effect. + string network_namespace_filepath = 7; } message TcpKeepalive { @@ -104,16 +115,18 @@ message TcpKeepalive { // Maximum number of keepalive probes to send without response before deciding // the connection is dead. Default is to use the OS level configuration (unless - // overridden, Linux defaults to 9.) + // overridden, Linux defaults to 9.) Setting this to ``0`` disables TCP keepalive. google.protobuf.UInt32Value keepalive_probes = 1; // The number of seconds a connection needs to be idle before keep-alive probes // start being sent. Default is to use the OS level configuration (unless - // overridden, Linux defaults to 7200s (i.e., 2 hours.) + // overridden, Linux defaults to 7200s (i.e., 2 hours.) Setting this to ``0`` disables + // TCP keepalive. google.protobuf.UInt32Value keepalive_time = 2; // The number of seconds between keep-alive probes. Default is to use the OS - // level configuration (unless overridden, Linux defaults to 75s.) + // level configuration (unless overridden, Linux defaults to 75s.) Setting this to + // ``0`` disables TCP keepalive. google.protobuf.UInt32Value keepalive_interval = 3; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/base.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/base.proto index 97131e4b8c6..978f365d5f9 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/base.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/base.proto @@ -245,7 +245,8 @@ message Metadata { // :ref:`typed_filter_metadata ` // fields are present in the metadata with same keys, // only ``typed_filter_metadata`` field will be parsed. - map filter_metadata = 1; + map filter_metadata = 1 + [(validate.rules).map = {keys {string {min_len: 1}}}]; // Key is the reverse DNS filter name, e.g. com.acme.widget. The ``envoy.*`` // namespace is reserved for Envoy's built-in filters. @@ -253,7 +254,8 @@ message Metadata { // If both :ref:`filter_metadata ` // and ``typed_filter_metadata`` fields are present in the metadata with same keys, // only ``typed_filter_metadata`` field will be parsed. - map typed_filter_metadata = 2; + map typed_filter_metadata = 2 + [(validate.rules).map = {keys {string {min_len: 1}}}]; } // Runtime derived uint32 with a default when not specified. @@ -264,7 +266,7 @@ message RuntimeUInt32 { uint32 default_value = 2; // Runtime key to get value for comparison. This value is used if defined. - string runtime_key = 3 [(validate.rules).string = {min_len: 1}]; + string runtime_key = 3; } // Runtime derived percentage with a default when not specified. @@ -273,7 +275,7 @@ message RuntimePercent { type.v3.Percent default_value = 1; // Runtime key to get value for comparison. This value is used if defined. - string runtime_key = 2 [(validate.rules).string = {min_len: 1}]; + string runtime_key = 2; } // Runtime derived double with a default when not specified. @@ -284,7 +286,7 @@ message RuntimeDouble { double default_value = 1; // Runtime key to get value for comparison. This value is used if defined. - string runtime_key = 2 [(validate.rules).string = {min_len: 1}]; + string runtime_key = 2; } // Runtime derived bool with a default when not specified. @@ -298,7 +300,91 @@ message RuntimeFeatureFlag { // Runtime key to get value for comparison. This value is used if defined. The boolean value must // be represented via its // `canonical JSON encoding `_. - string runtime_key = 2 [(validate.rules).string = {min_len: 1}]; + string runtime_key = 2; +} + +// Please use :ref:`KeyValuePair ` instead. +// [#not-implemented-hide:] +message KeyValue { + // The key of the key/value pair. + string key = 1 [ + deprecated = true, + (validate.rules).string = {min_len: 1 max_bytes: 16384}, + (envoy.annotations.deprecated_at_minor_version) = "3.0" + ]; + + // The value of the key/value pair. + // + // The ``bytes`` type is used. This means if JSON or YAML is used to to represent the + // configuration, the value must be base64 encoded. This is unfriendly for users in most + // use scenarios of this message. + // + bytes value = 2 [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; +} + +message KeyValuePair { + // The key of the key/value pair. + string key = 1 [(validate.rules).string = {min_len: 1 max_bytes: 16384}]; + + // The value of the key/value pair. + google.protobuf.Value value = 2; +} + +// Key/value pair plus option to control append behavior. This is used to specify +// key/value pairs that should be appended to a set of existing key/value pairs. +message KeyValueAppend { + // Describes the supported actions types for key/value pair append action. + enum KeyValueAppendAction { + // If the key already exists, this action will result in the following behavior: + // + // - Comma-concatenated value if multiple values are not allowed. + // - New value added to the list of values if multiple values are allowed. + // + // If the key doesn't exist then this will add pair with specified key and value. + APPEND_IF_EXISTS_OR_ADD = 0; + + // This action will add the key/value pair if it doesn't already exist. If the + // key already exists then this will be a no-op. + ADD_IF_ABSENT = 1; + + // This action will overwrite the specified value by discarding any existing + // values if the key already exists. If the key doesn't exist then this will add + // the pair with specified key and value. + OVERWRITE_IF_EXISTS_OR_ADD = 2; + + // This action will overwrite the specified value by discarding any existing + // values if the key already exists. If the key doesn't exist then this will + // be no-op. + OVERWRITE_IF_EXISTS = 3; + } + + // The single key/value pair record to be appended or overridden. This field must be set. + KeyValuePair record = 3; + + // Key/value pair entry that this option to append or overwrite. This field is deprecated + // and please use :ref:`record ` + // as replacement. + // [#not-implemented-hide:] + KeyValue entry = 1 [ + deprecated = true, + (validate.rules).message = {skip: true}, + (envoy.annotations.deprecated_at_minor_version) = "3.0" + ]; + + // Describes the action taken to append/overwrite the given value for an existing + // key or to only add this key if it's absent. + KeyValueAppendAction action = 2 [(validate.rules).enum = {defined_only: true}]; +} + +// Key/value pair to append or remove. +message KeyValueMutation { + // Key/value pair to append or overwrite. Only one of ``append`` or ``remove`` can be set or + // the configuration will be rejected. + KeyValueAppend append = 1; + + // Key to remove. Only one of ``append`` or ``remove`` can be set or the configuration will be + // rejected. + string remove = 2 [(validate.rules).string = {max_bytes: 16384}]; } // Query parameter name/value pair. @@ -398,6 +484,7 @@ message HeaderValueOption { message HeaderMap { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.core.HeaderMap"; + // A list of header names and their values. repeated HeaderValue headers = 1; } @@ -409,6 +496,7 @@ message WatchedDirectory { } // Data source consisting of a file, an inline value, or an environment variable. +// [#next-free-field: 6] message DataSource { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.core.DataSource"; @@ -427,12 +515,47 @@ message DataSource { // Environment variable data source. string environment_variable = 4 [(validate.rules).string = {min_len: 1}]; } + + // Watched directory that is watched for file changes. If this is set explicitly, the file + // specified in the ``filename`` field will be reloaded when relevant file move events occur. + // + // .. note:: + // This field only makes sense when the ``filename`` field is set. + // + // .. note:: + // Envoy only updates when the file is replaced by a file move, and not when the file is + // edited in place. + // + // .. note:: + // Not all use cases of ``DataSource`` support watching directories. It depends on the + // specific usage of the ``DataSource``. See the documentation of the parent message for + // details. + WatchedDirectory watched_directory = 5; } // The message specifies the retry policy of remote data source when fetching fails. +// [#next-free-field: 7] message RetryPolicy { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.core.RetryPolicy"; + // See :ref:`RetryPriority `. + message RetryPriority { + string name = 1 [(validate.rules).string = {min_len: 1}]; + + oneof config_type { + google.protobuf.Any typed_config = 2; + } + } + + // See :ref:`RetryHostPredicate `. + message RetryHostPredicate { + string name = 1 [(validate.rules).string = {min_len: 1}]; + + oneof config_type { + google.protobuf.Any typed_config = 2; + } + } + // Specifies parameters that control :ref:`retry backoff strategy `. // This parameter is optional, in which case the default base interval is 1000 milliseconds. The // default maximum interval is 10 times the base interval. @@ -442,6 +565,18 @@ message RetryPolicy { // defaults to 1. google.protobuf.UInt32Value num_retries = 2 [(udpa.annotations.field_migrate).rename = "max_retries"]; + + // For details, see :ref:`retry_on `. + string retry_on = 3; + + // For details, see :ref:`retry_priority `. + RetryPriority retry_priority = 4; + + // For details, see :ref:`RetryHostPredicate `. + repeated RetryHostPredicate retry_host_predicate = 5; + + // For details, see :ref:`host_selection_retry_max_attempts `. + int64 host_selection_retry_max_attempts = 6; } // The message specifies how to fetch data from remote and how to verify it. diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/cel.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/cel.proto new file mode 100644 index 00000000000..940a66d0b10 --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/cel.proto @@ -0,0 +1,63 @@ +syntax = "proto3"; + +package envoy.config.core.v3; + +import "udpa/annotations/status.proto"; + +option java_package = "io.envoyproxy.envoy.config.core.v3"; +option java_outer_classname = "CelProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/core/v3;corev3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: CEL Expression Configuration] + +// CEL expression evaluation configuration. +// These options control the behavior of the Common Expression Language runtime for +// individual CEL expressions. +message CelExpressionConfig { + // Enable string conversion functions for CEL expressions. When enabled, CEL expressions + // can convert values to strings using the ``string()`` function. + // + // .. attention:: + // + // This option is disabled by default to avoid unbounded memory allocation. + // CEL evaluation cost is typically bounded by the expression size, but converting + // arbitrary values (e.g., large messages, lists, or maps) to strings may allocate + // memory proportional to input data size, which can be unbounded and lead to + // memory exhaustion. + bool enable_string_conversion = 1; + + // Enable string concatenation for CEL expressions. When enabled, CEL expressions + // can concatenate strings using the ``+`` operator. + // + // .. attention:: + // + // This option is disabled by default to avoid unbounded memory allocation. + // While CEL normally bounds evaluation by expression size, enabling string + // concatenation allows building outputs whose size depends on input data, + // potentially causing large intermediate allocations and memory exhaustion. + bool enable_string_concat = 2; + + // Enable string manipulation functions for CEL expressions. When enabled, CEL + // expressions can use additional string functions: + // + // * ``replace(old, new)`` - Replaces all occurrences of ``old`` with ``new``. + // * ``split(separator)`` - Splits a string into a list of substrings. + // * ``lowerAscii()`` - Converts ASCII characters to lowercase. + // * ``upperAscii()`` - Converts ASCII characters to uppercase. + // + // .. note:: + // + // Standard CEL string functions like ``contains()``, ``startsWith()``, and + // ``endsWith()`` are always available regardless of this setting. + // + // .. attention:: + // + // This option is disabled by default to avoid unbounded memory allocation. + // Although CEL generally bounds evaluation by expression size, functions such as + // ``replace``, ``split``, ``lowerAscii()``, and ``upperAscii()`` can allocate memory + // proportional to input data size. Under adversarial inputs this can lead to + // unbounded allocations and memory exhaustion. + bool enable_string_functions = 3; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/config_source.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/config_source.proto index 70204bad9eb..430562aa5bd 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/config_source.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/config_source.proto @@ -28,12 +28,10 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // xDS API and non-xDS services version. This is used to describe both resource and transport // protocol versions (in distinct configuration fields). enum ApiVersion { - // When not specified, we assume v2, to ease migration to Envoy's stable API - // versioning. If a client does not support v2 (e.g. due to deprecation), this - // is an invalid value. - AUTO = 0 [deprecated = true, (envoy.annotations.deprecated_at_minor_version_enum) = "3.0"]; + // When not specified, we assume v3; it is the only supported version. + AUTO = 0; - // Use xDS v2 API. + // Use xDS v2 API. This is no longer supported. V2 = 1 [deprecated = true, (envoy.annotations.deprecated_at_minor_version_enum) = "3.0"]; // Use xDS v3 API. @@ -278,7 +276,8 @@ message ExtensionConfigSource { // to be supplied. bool apply_default_config_without_warming = 3; - // A set of permitted extension type URLs. Extension configuration updates are rejected - // if they do not match any type URL in the set. + // A set of permitted extension type URLs for the type encoded inside of the + // :ref:`TypedExtensionConfig `. Extension + // configuration updates are rejected if they do not match any type URL in the set. repeated string type_urls = 4 [(validate.rules).repeated = {min_items: 1}]; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/grpc_service.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/grpc_service.proto index f266c7bce5b..9c44006b2a9 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/grpc_service.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/grpc_service.proto @@ -25,10 +25,11 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // gRPC service configuration. This is used by :ref:`ApiConfigSource // ` and filter configurations. -// [#next-free-field: 6] +// [#next-free-field: 7] message GrpcService { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.core.GrpcService"; + // [#next-free-field: 6] message EnvoyGrpc { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.core.GrpcService.EnvoyGrpc"; @@ -44,14 +45,36 @@ message GrpcService { [(validate.rules).string = {min_len: 0 max_bytes: 16384 well_known_regex: HTTP_HEADER_VALUE strict: false}]; - // Indicates the retry policy for re-establishing the gRPC stream - // This field is optional. If max interval is not provided, it will be set to ten times the provided base interval. - // Currently only supported for xDS gRPC streams. - // If not set, xDS gRPC streams default base interval:500ms, maximum interval:30s will be applied. + // Specifies the retry backoff policy for re-establishing long‑lived xDS gRPC streams. + // + // This field is optional. If ``retry_back_off.max_interval`` is not provided, it will be set to + // ten times the configured ``retry_back_off.base_interval``. + // + // .. note:: + // + // This field is only honored for management‑plane xDS gRPC streams created from + // :ref:`ApiConfigSource ` that use + // ``envoy_grpc``. Data‑plane gRPC clients (for example external authorization or external + // processing filters) must use :ref:`GrpcService.retry_policy + // ` instead. + // + // If not set, xDS gRPC streams default to a base interval of 500ms and a maximum interval of 30s. RetryPolicy retry_policy = 3; + + // Maximum gRPC message size that is allowed to be received. + // If a message over this limit is received, the gRPC stream is terminated with the RESOURCE_EXHAUSTED error. + // This limit is applied to individual messages in the streaming response and not the total size of streaming response. + // Defaults to 0, which means unlimited. + google.protobuf.UInt32Value max_receive_message_length = 4; + + // This provides gRPC client level control over envoy generated headers. + // If false, the header will be sent but it can be overridden by per stream option. + // If true, the header will be removed and can not be overridden by per stream option. + // Default to false. + bool skip_envoy_headers = 5; } - // [#next-free-field: 9] + // [#next-free-field: 11] message GoogleGrpc { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.core.GrpcService.GoogleGrpc"; @@ -236,16 +259,31 @@ message GrpcService { } // The target URI when using the `Google C++ gRPC client - // `_. SSL credentials will be supplied in - // :ref:`channel_credentials `. + // `_. string target_uri = 1 [(validate.rules).string = {min_len: 1}]; + // The channel credentials to use. See `channel credentials + // `_. + // Ignored if ``channel_credentials_plugin`` is set. ChannelCredentials channel_credentials = 2; - // A set of call credentials that can be composed with `channel credentials + // A list of channel credentials plugins. + // The data plane will iterate over the list in order and stop at the first credential type + // that it supports. This provides a mechanism for starting to use new credential types that + // are not yet supported by all data planes. + // [#not-implemented-hide:] + repeated google.protobuf.Any channel_credentials_plugin = 9; + + // The call credentials to use. See `channel credentials // `_. + // Ignored if ``call_credentials_plugin`` is set. repeated CallCredentials call_credentials = 3; + // A list of call credentials plugins. All supported plugins will be used. + // Unsupported plugin types will be ignored. + // [#not-implemented-hide:] + repeated google.protobuf.Any call_credentials_plugin = 10; + // The human readable prefix to use when emitting statistics for the gRPC // service. // @@ -300,4 +338,18 @@ message GrpcService { // documentation on :ref:`custom request headers // `. repeated HeaderValue initial_metadata = 5; + + // Optional default retry policy for RPCs or streams initiated toward this gRPC service. + // + // If an async stream does not have a retry policy configured in its per‑stream options, this + // policy is used as the default. + // + // .. note:: + // + // This field is only applied by Envoy gRPC (``envoy_grpc``) clients. Google gRPC + // (``google_grpc``) clients currently ignore this field. + // + // If not specified, no default retry policy is applied at the client level and retries only occur + // when explicitly configured in per‑stream options. + RetryPolicy retry_policy = 6; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/health_check.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/health_check.proto index 2ec258d8ac0..a4ed6e91818 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/health_check.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/health_check.proto @@ -5,6 +5,7 @@ package envoy.config.core.v3; import "envoy/config/core/v3/base.proto"; import "envoy/config/core/v3/event_service_config.proto"; import "envoy/config/core/v3/extension.proto"; +import "envoy/config/core/v3/proxy_protocol.proto"; import "envoy/type/matcher/v3/string.proto"; import "envoy/type/v3/http.proto"; import "envoy/type/v3/range.proto"; @@ -62,7 +63,7 @@ message HealthStatusSet { [(validate.rules).repeated = {items {enum {defined_only: true}}}]; } -// [#next-free-field: 26] +// [#next-free-field: 27] message HealthCheck { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.core.HealthCheck"; @@ -95,14 +96,14 @@ message HealthCheck { // left empty (default value), the name of the cluster this health check is associated // with will be used. The host header can be customized for a specific endpoint by setting the // :ref:`hostname ` field. - string host = 1 [(validate.rules).string = {well_known_regex: HTTP_HEADER_VALUE strict: false}]; + string host = 1 [(validate.rules).string = {well_known_regex: HTTP_HEADER_VALUE}]; // Specifies the HTTP path that will be requested during health checking. For example // ``/healthcheck``. - string path = 2 - [(validate.rules).string = {min_len: 1 well_known_regex: HTTP_HEADER_VALUE strict: false}]; + string path = 2 [(validate.rules).string = {min_len: 1 well_known_regex: HTTP_HEADER_VALUE}]; - // [#not-implemented-hide:] HTTP specific payload. + // HTTP specific payload to be sent as the request body during health checking. + // If specified, the method should support a request body (POST, PUT, PATCH, etc.). Payload send = 3; // Specifies a list of HTTP expected responses to match in the first ``response_buffer_size`` bytes of the response body. @@ -161,7 +162,8 @@ message HealthCheck { type.matcher.v3.StringMatcher service_name_matcher = 11; // HTTP Method that will be used for health checking, default is "GET". - // GET, HEAD, POST, PUT, DELETE, OPTIONS, TRACE, PATCH methods are supported, but making request body is not supported. + // GET, HEAD, POST, PUT, DELETE, OPTIONS, TRACE, PATCH methods are supported. + // Request body payloads are supported for POST, PUT, PATCH, and OPTIONS methods only. // CONNECT method is disallowed because it is not appropriate for health check request. // If a non-200 response is expected by the method, it needs to be set in :ref:`expected_statuses `. RequestMethod method = 13 [(validate.rules).enum = {defined_only: true not_in: 6}]; @@ -178,6 +180,13 @@ message HealthCheck { // payload block must be found, and in the order specified, but not // necessarily contiguous. repeated Payload receive = 2; + + // When setting this value, it tries to attempt health check request with ProxyProtocol. + // When ``send`` is presented, they are sent after preceding ProxyProtocol header. + // Only ProxyProtocol header is sent when ``send`` is not presented. + // It allows to use both ProxyProtocol V1 and V2. In V1, it presents L3/L4. In V2, it includes + // LOCAL command and doesn't include L3/L4. + ProxyProtocolConfig proxy_protocol_config = 3; } message RedisHealthCheck { @@ -368,13 +377,13 @@ message HealthCheck { // The default value for "healthy edge interval" is the same as the default interval. google.protobuf.Duration healthy_edge_interval = 16 [(validate.rules).duration = {gt {}}]; - // .. attention:: - // This field is deprecated in favor of the extension - // :ref:`event_logger ` and - // :ref:`event_log_path ` - // in the file sink extension. - // // Specifies the path to the :ref:`health check event log `. + // + // .. attention:: + // This field is deprecated in favor of the extension + // :ref:`event_logger ` and + // :ref:`event_log_path ` + // in the file sink extension. string event_log_path = 17 [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; @@ -392,6 +401,11 @@ message HealthCheck { // The default value is false. bool always_log_health_check_failures = 19; + // If set to true, health check success events will always be logged. If set to false, only host addition event will be logged + // if it is the first successful health check, or if the healthy threshold is reached. + // The default value is false. + bool always_log_health_check_success = 26; + // This allows overriding the cluster TLS settings, just for health check connections. TlsOptions tls_options = 21; diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/protocol.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/protocol.proto index d128dc6d93d..63e189e689e 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/protocol.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/protocol.proto @@ -3,6 +3,7 @@ syntax = "proto3"; package envoy.config.core.v3; import "envoy/config/core/v3/extension.proto"; +import "envoy/type/matcher/v3/string.proto"; import "envoy/type/v3/percent.proto"; import "google/protobuf/duration.proto"; @@ -30,44 +31,80 @@ message TcpProtocolOptions { } // Config for keepalive probes in a QUIC connection. -// Note that QUIC keep-alive probing packets work differently from HTTP/2 keep-alive PINGs in a sense that the probing packet -// itself doesn't timeout waiting for a probing response. Quic has a shorter idle timeout than TCP, so it doesn't rely on such probing to discover dead connections. If the peer fails to respond, the connection will idle timeout eventually. Thus, they are configured differently from :ref:`connection_keepalive `. +// +// .. note:: +// +// QUIC keep-alive probing packets work differently from HTTP/2 keep-alive PINGs in a sense that the probing packet +// itself doesn't timeout waiting for a probing response. QUIC has a shorter idle timeout than TCP, so it doesn't rely on such probing to discover dead connections. If the peer fails to respond, the connection will idle timeout eventually. Thus, they are configured differently from :ref:`connection_keepalive `. message QuicKeepAliveSettings { - // The max interval for a connection to send keep-alive probing packets (with PING or PATH_RESPONSE). The value should be smaller than :ref:`connection idle_timeout ` to prevent idle timeout while not less than 1s to avoid throttling the connection or flooding the peer with probes. + // The max interval for a connection to send keep-alive probing packets (with ``PING`` or ``PATH_RESPONSE``). The value should be smaller than :ref:`connection idle_timeout ` to prevent idle timeout while not less than ``1s`` to avoid throttling the connection or flooding the peer with probes. // // If :ref:`initial_interval ` is absent or zero, a client connection will use this value to start probing. // // If zero, disable keepalive probing. // If absent, use the QUICHE default interval to probe. - google.protobuf.Duration max_interval = 1 [(validate.rules).duration = { - lte {} - gte {seconds: 1} - }]; + google.protobuf.Duration max_interval = 1; // The interval to send the first few keep-alive probing packets to prevent connection from hitting the idle timeout. Subsequent probes will be sent, each one with an interval exponentially longer than previous one, till it reaches :ref:`max_interval `. And the probes afterwards will always use :ref:`max_interval `. // // The value should be smaller than :ref:`connection idle_timeout ` to prevent idle timeout and smaller than max_interval to take effect. // - // If absent or zero, disable keepalive probing for a server connection. For a client connection, if :ref:`max_interval ` is also zero, do not keepalive, otherwise use max_interval or QUICHE default to probe all the time. + // If absent, disable keepalive probing for a server connection. For a client connection, if :ref:`max_interval ` is zero, do not keepalive, otherwise use max_interval or QUICHE default to probe all the time. google.protobuf.Duration initial_interval = 2 [(validate.rules).duration = { lte {} - gte {seconds: 1} + gte {nanos: 1000000} }]; } // QUIC protocol options which apply to both downstream and upstream connections. -// [#next-free-field: 8] +// [#next-free-field: 12] message QuicProtocolOptions { - // Maximum number of streams that the client can negotiate per connection. 100 + // Config for QUIC connection migration across network interfaces, i.e. cellular to WIFI, upon + // network change events from the platform, i.e. the current network gets + // disconnected, or upon the QUIC detecting a bad connection. After migration, the + // connection may be on a different network other than the default network + // picked by the platform. Both iOS and Android will use a default network to interact with the internet, usually prefer unmetered network (WIFI) + // over metered ones (cellular). And users can specify which network to be used as the default. A connection on non-default network is only allowed to + // serve new requests for a certain period of time before being drained, and + // meanwhile, QUIC will try to migrate to the default network if possible. + message ConnectionMigrationSettings { + // Config for options to migrate idle connections which aren't serving any requests. + message MigrateIdleConnectionSettings { + // If idle connections are allowed to be migrated, only migrate the connection + // if it hasn't been idle for longer than this idle period. Otherwise, the + // connection will be closed instead. + // Default to 30s. + google.protobuf.Duration max_idle_time_before_migration = 1 + [(validate.rules).duration = {gte {seconds: 1}}]; + } + + // Config whether and how to migrate idle connections. + // If absent, idle connections will not be migrated but be closed upon + // migration signals. + MigrateIdleConnectionSettings migrate_idle_connections = 1; + + // After migrating to a non-default network interface, the connection will + // only be allowed to stay on that network for up to this period of time before + // being drained unless it migrates to the default network or that network + // gets picked as the default by the device by then. + // Default to 128s. + google.protobuf.Duration max_time_on_non_default_network = 2 + [(validate.rules).duration = {gte {seconds: 1}}]; + } + + // Maximum number of streams that the client can negotiate per connection. ``100`` // if not specified. google.protobuf.UInt32Value max_concurrent_streams = 1 [(validate.rules).uint32 = {gte: 1}]; // `Initial stream-level flow-control receive window // `_ size. Valid values range from - // 1 to 16777216 (2^24, maximum supported by QUICHE) and defaults to 65536 (2^16). + // ``1`` to ``16777216`` (``2^24``, maximum supported by QUICHE) and defaults to ``16777216`` (``16 * 1024 * 1024``). + // + // .. note:: // - // NOTE: 16384 (2^14) is the minimum window size supported in Google QUIC. If configured smaller than it, we will use 16384 instead. - // QUICHE IETF Quic implementation supports 1 bytes window. We only support increasing the default window size now, so it's also the minimum. + // ``16384`` (``2^14``) is the minimum window size supported in Google QUIC. If configured smaller than it, we will use + // ``16384`` instead. QUICHE IETF QUIC implementation supports ``1`` byte window. We only support increasing the default + // window size now, so it's also the minimum. // // This field also acts as a soft limit on the number of bytes Envoy will buffer per-stream in the // QUIC stream send and receive buffers. Once the buffer reaches this pointer, watermark callbacks will fire to @@ -76,23 +113,26 @@ message QuicProtocolOptions { [(validate.rules).uint32 = {lte: 16777216 gte: 1}]; // Similar to ``initial_stream_window_size``, but for connection-level - // flow-control. Valid values rage from 1 to 25165824 (24MB, maximum supported by QUICHE) and defaults to 65536 (2^16). - // window. Currently, this has the same minimum/default as ``initial_stream_window_size``. + // flow-control. Valid values range from ``1`` to ``25165824`` (``24MB``, maximum supported by QUICHE) and defaults + // to ``25165824`` (``24 * 1024 * 1024``). + // + // .. note:: + // + // ``16384`` (``2^14``) is the minimum window size supported in Google QUIC. We only support increasing the default + // window size now, so it's also the minimum. // - // NOTE: 16384 (2^14) is the minimum window size supported in Google QUIC. We only support increasing the default - // window size now, so it's also the minimum. google.protobuf.UInt32Value initial_connection_window_size = 3 [(validate.rules).uint32 = {lte: 25165824 gte: 1}]; // The number of timeouts that can occur before port migration is triggered for QUIC clients. - // This defaults to 4. If set to 0, port migration will not occur on path degrading. - // Timeout here refers to QUIC internal path degrading timeout mechanism, such as PTO. + // This defaults to ``4``. If set to ``0``, port migration will not occur on path degrading. + // Timeout here refers to QUIC internal path degrading timeout mechanism, such as ``PTO``. // This has no effect on server sessions. google.protobuf.UInt32Value num_timeouts_to_trigger_port_migration = 4 [(validate.rules).uint32 = {lte: 5 gte: 0}]; - // Probes the peer at the configured interval to solicit traffic, i.e. ACK or PATH_RESPONSE, from the peer to push back connection idle timeout. - // If absent, use the default keepalive behavior of which a client connection sends PINGs every 15s, and a server connection doesn't do anything. + // Probes the peer at the configured interval to solicit traffic, i.e. ``ACK`` or ``PATH_RESPONSE``, from the peer to push back connection idle timeout. + // If absent, use the default keepalive behavior of which a client connection sends ``PING``s every ``15s``, and a server connection doesn't do anything. QuicKeepAliveSettings connection_keepalive = 5; // A comma-separated list of strings representing QUIC connection options defined in @@ -102,6 +142,37 @@ message QuicProtocolOptions { // A comma-separated list of strings representing QUIC client connection options defined in // `QUICHE `_ and to be sent by upstream connections. string client_connection_options = 7; + + // The duration that a QUIC connection stays idle before it closes itself. If this field is not present, QUICHE + // default ``600s`` will be applied. + // For internal corporate network, a long timeout is often fine. + // But for client facing network, ``30s`` is usually a good choice. + // Do not add an upper bound here. A long idle timeout is useful for maintaining warm connections at non-front-line proxy for low QPS services. + google.protobuf.Duration idle_network_timeout = 8 + [(validate.rules).duration = {gte {seconds: 1}}]; + + // Maximum packet length for QUIC connections. It refers to the largest size of a QUIC packet that can be transmitted over the connection. + // If not specified, one of the `default values in QUICHE `_ is used. + google.protobuf.UInt64Value max_packet_length = 9; + + // A customized UDP socket and a QUIC packet writer using the socket for + // client connections. i.e. Mobile uses its own implementation to interact + // with platform socket APIs. + // If not present, the default platform-independent socket and writer will be used. + // [#extension-category: envoy.quic.client_packet_writer] + TypedExtensionConfig client_packet_writer = 10; + + // Enable QUIC `connection migration + // ` + // to a different network interface when the current network is degrading or + // has become bad. + // In order to use a different network interface other than the platform's default one, + // a customized :ref:`client_packet_writer ` needs to be configured to + // create UDP sockets on non-default networks. + // Only takes effect when runtime key ``envoy.reloadable_features.use_migration_in_quiche`` is true. + // If absent, the feature will be disabled. + // [#not-implemented-hide:] + ConnectionMigrationSettings connection_migration = 11; } message UpstreamHttpProtocolOptions { @@ -113,6 +184,9 @@ message UpstreamHttpProtocolOptions { // header when :ref:`override_auto_sni_header ` // is set, as seen by the :ref:`router filter `. // Does nothing if a filter before the http router filter sets the corresponding metadata. + // + // See :ref:`SNI configuration ` for details on how this + // interacts with other validation options. bool auto_sni = 1; // Automatic validate upstream presented certificate for new upstream connections based on the @@ -120,6 +194,9 @@ message UpstreamHttpProtocolOptions { // is set, as seen by the :ref:`router filter `. // This field is intended to be set with ``auto_sni`` field. // Does nothing if a filter before the http router filter sets the corresponding metadata. + // + // See :ref:`validation configuration ` for how this interacts with + // other validation options. bool auto_san_validation = 2; // An optional alternative to the host/authority header to be used for setting the SNI value. @@ -165,9 +242,9 @@ message AlternateProtocolsCacheOptions { // not the case. string name = 1 [(validate.rules).string = {min_len: 1}]; - // The maximum number of entries that the cache will hold. If not specified defaults to 1024. + // The maximum number of entries that the cache will hold. If not specified defaults to ``1024``. // - // .. note: + // .. note:: // // The implementation is approximate and enforced independently on each worker thread, thus // it is possible for the maximum entries in the cache to go slightly above the configured @@ -196,7 +273,7 @@ message AlternateProtocolsCacheOptions { repeated string canonical_suffixes = 5; } -// [#next-free-field: 7] +// [#next-free-field: 8] message HttpProtocolOptions { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.core.HttpProtocolOptions"; @@ -210,14 +287,14 @@ message HttpProtocolOptions { // Allow headers with underscores. This is the default behavior. ALLOW = 0; - // Reject client request. HTTP/1 requests are rejected with the 400 status. HTTP/2 requests - // end with the stream reset. The "httpN.requests_rejected_with_underscores_in_headers" counter + // Reject client request. HTTP/1 requests are rejected with ``HTTP 400`` status. HTTP/2 requests + // end with the stream reset. The ``httpN.requests_rejected_with_underscores_in_headers`` counter // is incremented for each rejected request. REJECT_REQUEST = 1; // Drop the client header with name containing underscores. The header is dropped before the filter chain is // invoked and as such filters will not see dropped headers. The - // "httpN.dropped_headers_with_underscores" is incremented for each dropped header. + // ``httpN.dropped_headers_with_underscores`` is incremented for each dropped header. DROP_HEADER = 2; } @@ -227,8 +304,12 @@ message HttpProtocolOptions { // downstream connection a drain sequence will occur prior to closing the connection, see // :ref:`drain_timeout // `. - // Note that request based timeouts mean that HTTP/2 PINGs will not keep the connection alive. - // If not specified, this defaults to 1 hour. To disable idle timeouts explicitly set this to 0. + // + // .. note:: + // + // Request based timeouts mean that HTTP/2 PINGs will not keep the connection alive. + // + // If not specified, this defaults to ``1 hour``. To disable idle timeouts explicitly set this to ``0``. // // .. warning:: // Disabling this timeout has a highly likelihood of yielding connection leaks due to lost TCP @@ -240,37 +321,66 @@ message HttpProtocolOptions { google.protobuf.Duration idle_timeout = 1; // The maximum duration of a connection. The duration is defined as a period since a connection - // was established. If not set, there is no max duration. When max_connection_duration is reached - // and if there are no active streams, the connection will be closed. If the connection is a - // downstream connection and there are any active streams, the drain sequence will kick-in, - // and the connection will be force-closed after the drain period. See :ref:`drain_timeout + // was established. If not set, there is no max duration. When max_connection_duration is reached, + // the drain sequence will kick-in. The connection will be closed after the drain timeout period + // if there are no active streams. See :ref:`drain_timeout // `. google.protobuf.Duration max_connection_duration = 3; - // The maximum number of headers. If unconfigured, the default - // maximum number of request headers allowed is 100. Requests that exceed this limit will receive - // a 431 response for HTTP/1.x and cause a stream reset for HTTP/2. + // The maximum number of headers (request headers if configured on HttpConnectionManager, + // response headers when configured on a cluster). + // If unconfigured, the default maximum number of headers allowed is ``100``. + // The default value for requests can be overridden by setting runtime key ``envoy.reloadable_features.max_request_headers_count``. + // The default value for responses can be overridden by setting runtime key ``envoy.reloadable_features.max_response_headers_count``. + // Downstream requests that exceed this limit will receive a ``HTTP 431`` response for HTTP/1.x and cause a stream + // reset for HTTP/2. + // Upstream responses that exceed this limit will result in a ``HTTP 502`` response. google.protobuf.UInt32Value max_headers_count = 2 [(validate.rules).uint32 = {gte: 1}]; + // The maximum size of response headers. + // If unconfigured, the default is ``60 KiB``, except for HTTP/1 response headers which have a default + // of ``80 KiB``. + // The default value can be overridden by setting runtime key ``envoy.reloadable_features.max_response_headers_size_kb``. + // Responses that exceed this limit will result in a ``HTTP 503`` response. + // In Envoy, this setting is only valid when configured on an upstream cluster, not on the + // :ref:`HTTP Connection Manager + // `. + // + // .. note:: + // + // Currently some protocol codecs impose limits on the maximum size of a single header. + // + // * HTTP/2 (when using ``nghttp2``) limits a single header to around ``100kb``. + // * HTTP/3 limits a single header to around ``1024kb``. + // + google.protobuf.UInt32Value max_response_headers_kb = 7 + [(validate.rules).uint32 = {lte: 8192 gt: 0}]; + // Total duration to keep alive an HTTP request/response stream. If the time limit is reached the stream will be // reset independent of any other timeouts. If not specified, this value is not set. google.protobuf.Duration max_stream_duration = 4; // Action to take when a client request with a header name containing underscore characters is received. - // If this setting is not specified, the value defaults to ALLOW. - // Note: upstream responses are not affected by this setting. - // Note: this only affects client headers. It does not affect headers added - // by Envoy filters and does not have any impact if added to cluster config. + // If this setting is not specified, the value defaults to ``ALLOW``. + // + // .. note:: + // + // Upstream responses are not affected by this setting. + // + // .. note:: + // + // This only affects client headers. It does not affect headers added by Envoy filters and does not have any + // impact if added to cluster config. HeadersWithUnderscoresAction headers_with_underscores_action = 5; // Optional maximum requests for both upstream and downstream connections. // If not specified, there is no limit. - // Setting this parameter to 1 will effectively disable keep alive. + // Setting this parameter to ``1`` will effectively disable keep alive. // For HTTP/2 and HTTP/3, due to concurrent stream processing, the limit is approximate. google.protobuf.UInt32Value max_requests_per_connection = 6; } -// [#next-free-field: 11] +// [#next-free-field: 12] message Http1ProtocolOptions { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.core.Http1ProtocolOptions"; @@ -290,9 +400,12 @@ message Http1ProtocolOptions { // Formats the header by proper casing words: the first character and any character following // a special character will be capitalized if it's an alpha character. For example, - // "content-type" becomes "Content-Type", and "foo$b#$are" becomes "Foo$B#$Are". - // Note that while this results in most headers following conventional casing, certain headers - // are not covered. For example, the "TE" header will be formatted as "Te". + // ``"content-type"`` becomes ``"Content-Type"``, and ``"foo$b#$are"`` becomes ``"Foo$B#$Are"``. + // + // .. note:: + // + // While this results in most headers following conventional casing, certain headers + // are not covered. For example, the ``"TE"`` header will be formatted as ``"Te"``. ProperCaseWords proper_case_words = 1; // Configuration for stateful formatter extensions that allow using received headers to @@ -308,7 +421,7 @@ message Http1ProtocolOptions { // ``http_proxy`` environment variable. google.protobuf.BoolValue allow_absolute_url = 1; - // Handle incoming HTTP/1.0 and HTTP 0.9 requests. + // Handle incoming HTTP/1.0 and HTTP/0.9 requests. // This is off by default, and not fully standards compliant. There is support for pre-HTTP/1.1 // style connect logic, dechunking, and handling lack of client host iff // ``default_host_for_http_10`` is configured. @@ -327,19 +440,20 @@ message Http1ProtocolOptions { // // .. attention:: // - // Note that this only happens when Envoy is chunk encoding which occurs when: + // This only happens when Envoy is chunk encoding which occurs when: // - The request is HTTP/1.1. - // - Is neither a HEAD only request nor a HTTP Upgrade. - // - Not a response to a HEAD request. - // - The content length header is not present. + // - Is neither a ``HEAD`` only request nor a HTTP Upgrade. + // - Not a response to a ``HEAD`` request. + // - The ``Content-Length`` header is not present. bool enable_trailers = 5; // Allows Envoy to process requests/responses with both ``Content-Length`` and ``Transfer-Encoding`` // headers set. By default such messages are rejected, but if option is enabled - Envoy will - // remove Content-Length header and process message. + // remove ``Content-Length`` header and process message. // See `RFC7230, sec. 3.3.3 `_ for details. // // .. attention:: + // // Enabling this option might lead to request smuggling vulnerability, especially if traffic // is proxied via multiple layers of proxies. // [#comment:TODO: This field is ignored when the @@ -368,7 +482,7 @@ message Http1ProtocolOptions { // envoy.reloadable_features.http1_use_balsa_parser. // See issue #21245. google.protobuf.BoolValue use_balsa_parser = 9 - [(xds.annotations.v3.field_status).work_in_progress = true]; + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // [#not-implemented-hide:] Hiding so that field can be removed. // If true, and BalsaParser is used (either `use_balsa_parser` above is true, @@ -382,6 +496,14 @@ message Http1ProtocolOptions { // ` // to reject custom methods. bool allow_custom_methods = 10 [(xds.annotations.v3.field_status).work_in_progress = true]; + + // Ignore HTTP/1.1 upgrade values matching any of the supplied matchers. + // + // .. note:: + // + // ``h2c`` upgrades are always removed for backwards compatibility, regardless of the + // value in this setting. + repeated type.matcher.v3.StringMatcher ignore_http_11_upgrade = 11; } message KeepaliveSettings { @@ -390,9 +512,12 @@ message KeepaliveSettings { google.protobuf.Duration interval = 1 [(validate.rules).duration = {gte {nanos: 1000000}}]; // How long to wait for a response to a keepalive PING. If a response is not received within this - // time period, the connection will be aborted. Note that in order to prevent the influence of - // Head-of-line (HOL) blocking the timeout period is extended when *any* frame is received on - // the connection, under the assumption that if a frame is received the connection is healthy. + // time period, the connection will be aborted. + // + // .. note:: + // + // In order to prevent the influence of Head-of-line (HOL) blocking the timeout period is extended when *any* frame is received on + // the connection, under the assumption that if a frame is received the connection is healthy. google.protobuf.Duration timeout = 2 [(validate.rules).duration = { required: true gte {nanos: 1000000} @@ -400,7 +525,7 @@ message KeepaliveSettings { // A random jitter amount as a percentage of interval that will be added to each interval. // A value of zero means there will be no jitter. - // The default value is 15%. + // The default value is ``15%``. type.v3.Percent interval_jitter = 3; // If the connection has been idle for this duration, send a HTTP/2 ping ahead @@ -414,7 +539,7 @@ message KeepaliveSettings { [(validate.rules).duration = {gte {nanos: 1000000}}]; } -// [#next-free-field: 17] +// [#next-free-field: 19] message Http2ProtocolOptions { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.core.Http2ProtocolOptions"; @@ -437,13 +562,13 @@ message Http2ProtocolOptions { // `Maximum table size `_ // (in octets) that the encoder is permitted to use for the dynamic HPACK table. Valid values - // range from 0 to 4294967295 (2^32 - 1) and defaults to 4096. 0 effectively disables header + // range from ``0`` to ``4294967295`` (``2^32 - 1``) and defaults to ``4096``. ``0`` effectively disables header // compression. google.protobuf.UInt32Value hpack_table_size = 1; // `Maximum concurrent streams `_ - // allowed for peer on one HTTP/2 connection. Valid values range from 1 to 2147483647 (2^31 - 1) - // and defaults to 2147483647. + // allowed for peer on one HTTP/2 connection. Valid values range from ``1`` to ``2147483647`` (``2^31 - 1``) + // and defaults to ``1024`` for safety and should be sufficient for most use cases. // // For upstream connections, this also limits how many streams Envoy will initiate concurrently // on a single connection. If the limit is reached, Envoy may queue requests or establish @@ -456,12 +581,14 @@ message Http2ProtocolOptions { [(validate.rules).uint32 = {lte: 2147483647 gte: 1}]; // `Initial stream-level flow-control window - // `_ size. Valid values range from 65535 - // (2^16 - 1, HTTP/2 default) to 2147483647 (2^31 - 1, HTTP/2 maximum) and defaults to 268435456 - // (256 * 1024 * 1024). + // `_ size. Valid values range from ``65535`` + // (``2^16 - 1``, HTTP/2 default) to ``2147483647`` (``2^31 - 1``, HTTP/2 maximum) and defaults to + // ``16MiB`` (``16 * 1024 * 1024``). // - // NOTE: 65535 is the initial window size from HTTP/2 spec. We only support increasing the default - // window size now, so it's also the minimum. + // .. note:: + // + // ``65535`` is the initial window size from HTTP/2 spec. We only support increasing the default window size now, + // so it's also the minimum. // // This field also acts as a soft limit on the number of bytes Envoy will buffer per-stream in the // HTTP/2 codec buffers. Once the buffer reaches this pointer, watermark callbacks will fire to @@ -470,17 +597,17 @@ message Http2ProtocolOptions { [(validate.rules).uint32 = {lte: 2147483647 gte: 65535}]; // Similar to ``initial_stream_window_size``, but for connection-level flow-control - // window. Currently, this has the same minimum/maximum/default as ``initial_stream_window_size``. + // window. The default is ``24MiB`` (``24 * 1024 * 1024``). google.protobuf.UInt32Value initial_connection_window_size = 4 [(validate.rules).uint32 = {lte: 2147483647 gte: 65535}]; // Allows proxying Websocket and other upgrades over H2 connect. bool allow_connect = 5; - // [#not-implemented-hide:] Hiding until envoy has full metadata support. + // [#not-implemented-hide:] Hiding until Envoy has full metadata support. // Still under implementation. DO NOT USE. // - // Allows metadata. See [metadata + // Allows sending and receiving HTTP/2 METADATA frames. See [metadata // docs](https://github.com/envoyproxy/envoy/blob/main/source/docs/h2_metadata.md) for more // information. bool allow_metadata = 6; @@ -488,51 +615,51 @@ message Http2ProtocolOptions { // Limit the number of pending outbound downstream frames of all types (frames that are waiting to // be written into the socket). Exceeding this limit triggers flood mitigation and connection is // terminated. The ``http2.outbound_flood`` stat tracks the number of terminated connections due - // to flood mitigation. The default limit is 10000. + // to flood mitigation. The default limit is ``10000``. google.protobuf.UInt32Value max_outbound_frames = 7 [(validate.rules).uint32 = {gte: 1}]; - // Limit the number of pending outbound downstream frames of types PING, SETTINGS and RST_STREAM, + // Limit the number of pending outbound downstream frames of types ``PING``, ``SETTINGS`` and ``RST_STREAM``, // preventing high memory utilization when receiving continuous stream of these frames. Exceeding // this limit triggers flood mitigation and connection is terminated. The // ``http2.outbound_control_flood`` stat tracks the number of terminated connections due to flood - // mitigation. The default limit is 1000. + // mitigation. The default limit is ``1000``. google.protobuf.UInt32Value max_outbound_control_frames = 8 [(validate.rules).uint32 = {gte: 1}]; - // Limit the number of consecutive inbound frames of types HEADERS, CONTINUATION and DATA with an + // Limit the number of consecutive inbound frames of types ``HEADERS``, ``CONTINUATION`` and ``DATA`` with an // empty payload and no end stream flag. Those frames have no legitimate use and are abusive, but - // might be a result of a broken HTTP/2 implementation. The `http2.inbound_empty_frames_flood`` + // might be a result of a broken HTTP/2 implementation. The ``http2.inbound_empty_frames_flood`` // stat tracks the number of connections terminated due to flood mitigation. - // Setting this to 0 will terminate connection upon receiving first frame with an empty payload - // and no end stream flag. The default limit is 1. + // Setting this to ``0`` will terminate connection upon receiving first frame with an empty payload + // and no end stream flag. The default limit is ``1``. google.protobuf.UInt32Value max_consecutive_inbound_frames_with_empty_payload = 9; - // Limit the number of inbound PRIORITY frames allowed per each opened stream. If the number - // of PRIORITY frames received over the lifetime of connection exceeds the value calculated + // Limit the number of inbound ``PRIORITY`` frames allowed per each opened stream. If the number + // of ``PRIORITY`` frames received over the lifetime of connection exceeds the value calculated // using this formula:: // // ``max_inbound_priority_frames_per_stream`` * (1 + ``opened_streams``) // // the connection is terminated. For downstream connections the ``opened_streams`` is incremented when // Envoy receives complete response headers from the upstream server. For upstream connection the - // ``opened_streams`` is incremented when Envoy send the HEADERS frame for a new stream. The + // ``opened_streams`` is incremented when Envoy sends the ``HEADERS`` frame for a new stream. The // ``http2.inbound_priority_frames_flood`` stat tracks - // the number of connections terminated due to flood mitigation. The default limit is 100. + // the number of connections terminated due to flood mitigation. The default limit is ``100``. google.protobuf.UInt32Value max_inbound_priority_frames_per_stream = 10; - // Limit the number of inbound WINDOW_UPDATE frames allowed per DATA frame sent. If the number - // of WINDOW_UPDATE frames received over the lifetime of connection exceeds the value calculated + // Limit the number of inbound ``WINDOW_UPDATE`` frames allowed per ``DATA`` frame sent. If the number + // of ``WINDOW_UPDATE`` frames received over the lifetime of connection exceeds the value calculated // using this formula:: // - // 5 + 2 * (``opened_streams`` + - // ``max_inbound_window_update_frames_per_data_frame_sent`` * ``outbound_data_frames``) + // ``5 + 2 * (opened_streams + + // max_inbound_window_update_frames_per_data_frame_sent * outbound_data_frames)`` // // the connection is terminated. For downstream connections the ``opened_streams`` is incremented when // Envoy receives complete response headers from the upstream server. For upstream connections the - // ``opened_streams`` is incremented when Envoy sends the HEADERS frame for a new stream. The + // ``opened_streams`` is incremented when Envoy sends the ``HEADERS`` frame for a new stream. The // ``http2.inbound_priority_frames_flood`` stat tracks the number of connections terminated due to - // flood mitigation. The default max_inbound_window_update_frames_per_data_frame_sent value is 10. - // Setting this to 1 should be enough to support HTTP/2 implementations with basic flow control, - // but more complex implementations that try to estimate available bandwidth require at least 2. + // flood mitigation. The default ``max_inbound_window_update_frames_per_data_frame_sent`` value is ``10``. + // Setting this to ``1`` should be enough to support HTTP/2 implementations with basic flow control, + // but more complex implementations that try to estimate available bandwidth require at least ``2``. google.protobuf.UInt32Value max_inbound_window_update_frames_per_data_frame_sent = 11 [(validate.rules).uint32 = {gte: 1}]; @@ -570,8 +697,10 @@ message Http2ProtocolOptions { // 2. SETTINGS_ENABLE_CONNECT_PROTOCOL (0x8) is only configurable through the named field // 'allow_connect'. // - // Note that custom parameters specified through this field can not also be set in the - // corresponding named parameters: + // .. note:: + // + // Custom parameters specified through this field can not also be set in the + // corresponding named parameters: // // .. code-block:: text // @@ -598,6 +727,15 @@ message Http2ProtocolOptions { // If unset, HTTP/2 codec is selected based on envoy.reloadable_features.http2_use_oghttp2. google.protobuf.BoolValue use_oghttp2_codec = 16 [(xds.annotations.v3.field_status).work_in_progress = true]; + + // Configure the maximum amount of metadata than can be handled per stream. Defaults to ``1 MB``. + google.protobuf.UInt64Value max_metadata_size = 17; + + // Controls whether to encode headers using huffman encoding. + // This can be useful in cases where the cpu spent encoding the headers isn't + // worth the network bandwidth saved e.g. for localhost. + // If unset, uses the data plane's default value. + google.protobuf.BoolValue enable_huffman_encoding = 18; } // [#not-implemented-hide:] @@ -609,7 +747,7 @@ message GrpcProtocolOptions { } // A message which allows using HTTP/3. -// [#next-free-field: 6] +// [#next-free-field: 9] message Http3ProtocolOptions { QuicProtocolOptions quic_protocol_options = 1; @@ -626,14 +764,44 @@ message Http3ProtocolOptions { // `_ // and settings `proposed for HTTP/3 // `_ - // Note that HTTP/3 CONNECT is not yet an RFC. + // + // .. note:: + // + // HTTP/3 CONNECT is not yet an RFC. bool allow_extended_connect = 5 [(xds.annotations.v3.field_status).work_in_progress = true]; + + // [#not-implemented-hide:] Hiding until Envoy has full metadata support. + // Still under implementation. DO NOT USE. + // + // Allows sending and receiving HTTP/3 METADATA frames. See [metadata + // docs](https://github.com/envoyproxy/envoy/blob/main/source/docs/h2_metadata.md) for more + // information. + bool allow_metadata = 6; + + // [#not-implemented-hide:] Hiding until Envoy has full HTTP/3 upstream support. + // Still under implementation. DO NOT USE. + // + // Disables QPACK compression related features for HTTP/3 including: + // No huffman encoding, zero dynamic table capacity and no cookie crumbling. + // This can be useful for trading off CPU vs bandwidth when an upstream HTTP/3 connection multiplexes multiple downstream connections. + bool disable_qpack = 7; + + // Disables connection level flow control for HTTP/3 streams. This is useful in situations where the streams share the same connection + // but originate from different end-clients, so that each stream can make progress independently at non-front-line proxies. + bool disable_connection_flow_control_for_streams = 8; } // A message to control transformations to the :scheme header message SchemeHeaderTransformation { oneof transformation { // Overwrite any Scheme header with the contents of this string. + // If set, takes precedence over ``match_upstream``. string scheme_to_overwrite = 1 [(validate.rules).string = {in: "http" in: "https"}]; } + + // Set the Scheme header to match the upstream transport protocol. For example, should a + // request be sent to the upstream over TLS, the scheme header will be set to ``"https"``. Should the + // request be sent over plaintext, the scheme header will be set to ``"http"``. + // If ``scheme_to_overwrite`` is set, this field is not used. + bool match_upstream = 2; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/proxy_protocol.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/proxy_protocol.proto index 32747dd2288..2da5fe5fd4d 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/proxy_protocol.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/proxy_protocol.proto @@ -2,6 +2,8 @@ syntax = "proto3"; package envoy.config.core.v3; +import "envoy/config/core/v3/substitution_format_string.proto"; + import "udpa/annotations/status.proto"; import "validate/validate.proto"; @@ -32,6 +34,34 @@ message ProxyProtocolPassThroughTLVs { repeated uint32 tlv_type = 2 [(validate.rules).repeated = {items {uint32 {lt: 256}}}]; } +// Represents a single Type-Length-Value (TLV) entry. +message TlvEntry { + // The type of the TLV. Must be a uint8 (0-255) as per the Proxy Protocol v2 specification. + uint32 type = 1 [(validate.rules).uint32 = {lt: 256}]; + + // The static value of the TLV. + // Only one of ``value`` or ``format_string`` may be set. + bytes value = 2; + + // Uses the :ref:`format string ` to dynamically + // populate the TLV value from stream information. This allows dynamic values + // such as metadata, filter state, or other stream properties to be included in + // the TLV. + // + // For example: + // + // .. code-block:: yaml + // + // type: 0xF0 + // format_string: + // text_format_source: + // inline_string: "%DYNAMIC_METADATA(envoy.filters.network:key)%" + // + // The formatted string will be used directly as the TLV value. + // Only one of ``value`` or ``format_string`` may be set. + SubstitutionFormatString format_string = 3; +} + message ProxyProtocolConfig { enum Version { // PROXY protocol version 1. Human readable format. @@ -47,4 +77,38 @@ message ProxyProtocolConfig { // This config controls which TLVs can be passed to upstream if it is Proxy Protocol // V2 header. If there is no setting for this field, no TLVs will be passed through. ProxyProtocolPassThroughTLVs pass_through_tlvs = 2; + + // This config allows additional TLVs to be included in the upstream PROXY protocol + // V2 header. Unlike ``pass_through_tlvs``, which passes TLVs from the downstream request, + // ``added_tlvs`` provides an extension mechanism for defining new TLVs that are included + // with the upstream request. These TLVs may not be present in the downstream request and + // can be defined at either the transport socket level or the host level to provide more + // granular control over the TLVs that are included in the upstream request. + // + // Host-level TLVs are specified in the ``metadata.typed_filter_metadata`` field under the + // ``envoy.transport_sockets.proxy_protocol`` namespace. + // + // .. literalinclude:: /_configs/repo/proxy_protocol.yaml + // :language: yaml + // :lines: 49-57 + // :linenos: + // :lineno-start: 49 + // :caption: :download:`proxy_protocol.yaml ` + // + // **Precedence behavior**: + // + // - When a TLV is defined at both the host level and the transport socket level, the value + // from the host level configuration takes precedence. This allows users to define default TLVs + // at the transport socket level and override them at the host level. + // - Any TLV defined in the ``pass_through_tlvs`` field will be overridden by either the host-level + // or transport socket-level TLV. + // + // If there are multiple TLVs with the same type, only the TLVs from the highest precedence level + // will be used. + repeated TlvEntry added_tlvs = 3; +} + +message PerHostConfig { + // Enables per-host configuration for Proxy Protocol. + repeated TlvEntry added_tlvs = 1; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/socket_cmsg_headers.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/socket_cmsg_headers.proto new file mode 100644 index 00000000000..cc3e58e0996 --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/socket_cmsg_headers.proto @@ -0,0 +1,28 @@ +syntax = "proto3"; + +package envoy.config.core.v3; + +import "google/protobuf/wrappers.proto"; + +import "udpa/annotations/status.proto"; + +option java_package = "io.envoyproxy.envoy.config.core.v3"; +option java_outer_classname = "SocketCmsgHeadersProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/core/v3;corev3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: Socket CMSG headers] + +// Configuration for socket cmsg headers. +// See `:ref:CMSG `_ for further information. +message SocketCmsgHeaders { + // cmsg level. Default is unset. + google.protobuf.UInt32Value level = 1; + + // cmsg type. Default is unset. + google.protobuf.UInt32Value type = 2; + + // Expected size of cmsg value. Default is zero. + uint32 expected_size = 3; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/socket_option.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/socket_option.proto index 44f1ce3890a..ad73d72e490 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/socket_option.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/socket_option.proto @@ -36,7 +36,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // :ref:`admin's ` socket_options etc. // // It should be noted that the name or level may have different values on different platforms. -// [#next-free-field: 7] +// [#next-free-field: 8] message SocketOption { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.core.SocketOption"; @@ -51,6 +51,29 @@ message SocketOption { STATE_LISTENING = 2; } + // The `socket type `_ to apply the socket option to. + // Only one field should be set. If multiple fields are set, the precedence order will determine + // the selected one. If none of the fields is set, the socket option will be applied to all socket types. + // + // For example: + // If :ref:`stream ` is set, + // it takes precedence over :ref:`datagram `. + message SocketType { + // The stream socket type. + message Stream { + } + + // The datagram socket type. + message Datagram { + } + + // Apply the socket option to the stream socket type. + Stream stream = 1; + + // Apply the socket option to the datagram socket type. + Datagram datagram = 2; + } + // An optional name to give this socket option for debugging, etc. // Uniqueness is not required and no special meaning is assumed. string description = 1; @@ -74,6 +97,10 @@ message SocketOption { // The state in which the option will be applied. When used in BindConfig // STATE_PREBIND is currently the only valid value. SocketState state = 6 [(validate.rules).enum = {defined_only: true}]; + + // Apply the socket option to the specified `socket type `_. + // If not specified, the socket option will be applied to all socket types. + SocketType type = 7; } message SocketOptionsOverride { diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/substitution_format_string.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/substitution_format_string.proto index abe8afa68ae..3edbf5f5f00 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/substitution_format_string.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/substitution_format_string.proto @@ -22,7 +22,12 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // Optional configuration options to be used with json_format. message JsonFormatOptions { // The output JSON string properties will be sorted. - bool sort_properties = 1; + // + // .. note:: + // As the properties are always sorted, this option has no effect and is deprecated. + // + bool sort_properties = 1 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; } // Configuration to use multiple :ref:`command operators ` @@ -101,6 +106,12 @@ message SubstitutionFormatString { // * for ``text_format``, the output of the empty operator is changed from ``-`` to an // empty string, so that empty values are omitted entirely. // * for ``json_format`` the keys with null values are omitted in the output structure. + // + // .. note:: + // This option does not work perfectly with ``json_format`` as keys with ``null`` values + // will still be included in the output. See https://github.com/envoyproxy/envoy/issues/37941 + // for more details. + // bool omit_empty_values = 3; // Specify a ``content_type`` field. diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/endpoint.proto b/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/endpoint.proto index 20939526eb5..a149f6095c1 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/endpoint.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/endpoint.proto @@ -77,6 +77,12 @@ message ClusterLoadAssignment { // // Envoy supports only one element and will NACK if more than one element is present. // Other xDS-capable data planes will not necessarily have this limitation. + // + // In Envoy, this ``drop_overloads`` config can be overridden by a runtime key + // "load_balancing_policy.drop_overload_limit" setting. This runtime key can be set to + // any integer number between 0 and 100. 0 means drop 0%. 100 means drop 100%. + // When both ``drop_overloads`` config and "load_balancing_policy.drop_overload_limit" + // setting are in place, the min of these two wins. repeated DropOverload drop_overloads = 2; // Priority levels and localities are considered overprovisioned with this @@ -107,8 +113,9 @@ message ClusterLoadAssignment { // to determine the health of the priority level, or in other words assume each host has a weight of 1 for // this calculation. // - // Note: this is not currently implemented for - // :ref:`locality weighted load balancing `. + // .. note:: + // This is not currently implemented for + // :ref:`locality weighted load balancing `. bool weighted_priority_health = 6; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/endpoint_components.proto b/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/endpoint_components.proto index ebd2bb4c332..eacc555df73 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/endpoint_components.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/endpoint_components.proto @@ -9,6 +9,9 @@ import "envoy/config/core/v3/health_check.proto"; import "google/protobuf/wrappers.proto"; +import "xds/core/v3/collection_entry.proto"; + +import "envoy/annotations/deprecation.proto"; import "udpa/annotations/status.proto"; import "udpa/annotations/versioning.proto"; import "validate/validate.proto"; @@ -133,21 +136,31 @@ message LbEndpoint { google.protobuf.UInt32Value load_balancing_weight = 4 [(validate.rules).uint32 = {gte: 1}]; } +// LbEndpoint list collection. Entries are `LbEndpoint` resources or references. // [#not-implemented-hide:] -// A configuration for a LEDS collection. +message LbEndpointCollection { + xds.core.v3.CollectionEntry entries = 1; +} + +// A configuration for an LEDS collection. message LedsClusterLocalityConfig { // Configuration for the source of LEDS updates for a Locality. core.v3.ConfigSource leds_config = 1; - // The xDS transport protocol glob collection resource name. - // The service is only supported in delta xDS (incremental) mode. + // The name of the LbEndpoint collection resource. + // + // If the name ends in ``/*``, it indicates an LbEndpoint glob collection, + // which is supported only in the xDS incremental protocol variants. + // Otherwise, it indicates an LbEndpointCollection list collection. + // + // Envoy currently supports only glob collections. string leds_collection_name = 2; } // A group of endpoints belonging to a Locality. // One can have multiple LocalityLbEndpoints for a locality, but only if // they have different priorities. -// [#next-free-field: 9] +// [#next-free-field: 10] message LocalityLbEndpoints { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.endpoint.LocalityLbEndpoints"; @@ -161,19 +174,24 @@ message LocalityLbEndpoints { // Identifies location of where the upstream hosts run. core.v3.Locality locality = 1; + // Metadata to provide additional information about the locality endpoints in aggregate. + core.v3.Metadata metadata = 9; + // The group of endpoints belonging to the locality specified. - // [#comment:TODO(adisuissa): Once LEDS is implemented this field needs to be - // deprecated and replaced by ``load_balancer_endpoints``.] + // This is ignored if :ref:`leds_cluster_locality_config + // ` is set. repeated LbEndpoint lb_endpoints = 2; - // [#not-implemented-hide:] oneof lb_config { - // The group of endpoints belonging to the locality. - // [#comment:TODO(adisuissa): Once LEDS is implemented the ``lb_endpoints`` field - // needs to be deprecated.] - LbEndpointList load_balancer_endpoints = 7; + // [#not-implemented-hide:] + // Not implemented and deprecated. + LbEndpointList load_balancer_endpoints = 7 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // LEDS Configuration for the current locality. + // If this is set, the :ref:`lb_endpoints + // ` + // field is ignored. LedsClusterLocalityConfig leds_cluster_locality_config = 8; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/load_report.proto b/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/load_report.proto index 832fe83dbb0..6d12765cef5 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/load_report.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/load_report.proto @@ -8,6 +8,8 @@ import "envoy/config/core/v3/base.proto"; import "google/protobuf/duration.proto"; import "google/protobuf/struct.proto"; +import "xds/annotations/v3/status.proto"; + import "udpa/annotations/status.proto"; import "udpa/annotations/versioning.proto"; import "validate/validate.proto"; @@ -23,7 +25,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // These are stats Envoy reports to the management server at a frequency defined by // :ref:`LoadStatsResponse.load_reporting_interval`. // Stats per upstream region/zone and optionally per subzone. -// [#next-free-field: 9] +// [#next-free-field: 15] message UpstreamLocalityStats { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.endpoint.UpstreamLocalityStats"; @@ -36,7 +38,8 @@ message UpstreamLocalityStats { // locality. uint64 total_successful_requests = 2; - // The total number of unfinished requests + // The total number of unfinished requests. A request can be an HTTP request + // or a TCP connection for a TCP connection pool. uint64 total_requests_in_progress = 3; // The total number of requests that failed due to errors at the endpoint, @@ -45,10 +48,49 @@ message UpstreamLocalityStats { // The total number of requests that were issued by this Envoy since // the last report. This information is aggregated over all the - // upstream endpoints in the locality. + // upstream endpoints in the locality. A request can be an HTTP request + // or a TCP connection for a TCP connection pool. uint64 total_issued_requests = 8; - // Stats for multi-dimensional load balancing. + // The total number of connections in an established state at the time of the + // report. This field is aggregated over all the upstream endpoints in the + // locality. + // In Envoy, this information may be based on ``upstream_cx_active metric``. + // [#not-implemented-hide:] + uint64 total_active_connections = 9 [(xds.annotations.v3.field_status).work_in_progress = true]; + + // The total number of connections opened since the last report. + // This field is aggregated over all the upstream endpoints in the locality. + // In Envoy, this information may be based on ``upstream_cx_total`` metric + // compared to itself between start and end of an interval, i.e. + // ``upstream_cx_total``(now) - ``upstream_cx_total``(now - + // load_report_interval). + // [#not-implemented-hide:] + uint64 total_new_connections = 10 [(xds.annotations.v3.field_status).work_in_progress = true]; + + // The total number of connection failures since the last report. + // This field is aggregated over all the upstream endpoints in the locality. + // In Envoy, this information may be based on ``upstream_cx_connect_fail`` + // metric compared to itself between start and end of an interval, i.e. + // ``upstream_cx_connect_fail``(now) - ``upstream_cx_connect_fail``(now - + // load_report_interval). + // [#not-implemented-hide:] + uint64 total_fail_connections = 11 [(xds.annotations.v3.field_status).work_in_progress = true]; + + // CPU utilization stats for multi-dimensional load balancing. + // This typically comes from endpoint metrics reported via ORCA. + UnnamedEndpointLoadMetricStats cpu_utilization = 12; + + // Memory utilization for multi-dimensional load balancing. + // This typically comes from endpoint metrics reported via ORCA. + UnnamedEndpointLoadMetricStats mem_utilization = 13; + + // Blended application-defined utilization for multi-dimensional load balancing. + // This typically comes from endpoint metrics reported via ORCA. + UnnamedEndpointLoadMetricStats application_utilization = 14; + + // Named stats for multi-dimensional load balancing. + // These typically come from endpoint metrics reported via ORCA. repeated EndpointLoadMetricStats load_metric_stats = 5; // Endpoint granularity stats information for this locality. This information @@ -118,6 +160,16 @@ message EndpointLoadMetricStats { double total_metric_value = 3; } +// Same as EndpointLoadMetricStats, except without the metric_name field. +message UnnamedEndpointLoadMetricStats { + // Number of calls that finished and included this metric. + uint64 num_requests_finished_with_metric = 1; + + // Sum of metric values across all calls that finished with this metric for + // load_reporting_interval. + double total_metric_value = 2; +} + // Per cluster load stats. Envoy reports these stats a management server in a // :ref:`LoadStatsRequest` // Next ID: 7 diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener.proto b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener.proto index a1a3d82c1c8..54ef2cfed38 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener.proto @@ -5,6 +5,7 @@ package envoy.config.listener.v3; import "envoy/config/accesslog/v3/accesslog.proto"; import "envoy/config/core/v3/address.proto"; import "envoy/config/core/v3/base.proto"; +import "envoy/config/core/v3/config_source.proto"; import "envoy/config/core/v3/extension.proto"; import "envoy/config/core/v3/socket_option.proto"; import "envoy/config/listener/v3/api_listener.proto"; @@ -14,7 +15,6 @@ import "envoy/config/listener/v3/udp_listener_config.proto"; import "google/protobuf/duration.proto"; import "google/protobuf/wrappers.proto"; -import "xds/annotations/v3/status.proto"; import "xds/core/v3/collection_entry.proto"; import "xds/type/matcher/v3/matcher.proto"; @@ -45,6 +45,14 @@ message AdditionalAddress { // or an empty list of :ref:`socket_options `, // it means no socket option will apply. core.v3.SocketOptionsOverride socket_options = 2; + + // Configures TCP keepalive settings for the additional address. + // If not set, the listener :ref:`tcp_keepalive ` + // configuration is inherited. You can explicitly disable TCP keepalive for the additional address by setting any keepalive field + // (:ref:`keepalive_probes `, + // :ref:`keepalive_time `, or + // :ref:`keepalive_interval `) to ``0``. + core.v3.TcpKeepalive tcp_keepalive = 3; } // Listener list collections. Entries are ``Listener`` resources or references. @@ -53,7 +61,7 @@ message ListenerCollection { repeated xds.core.v3.CollectionEntry entries = 1; } -// [#next-free-field: 35] +// [#next-free-field: 38] message Listener { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.Listener"; @@ -115,6 +123,20 @@ message Listener { message InternalListenerConfig { } + // Configuration for filter chains discovery. + // [#not-implemented-hide:] + message FcdsConfig { + // Optional name to present to the filter chain discovery service. This may be an arbitrary name with arbitrary + // length. If a name is not provided, the listener's name is used. Refer to :ref:`filter_chains `. + // for details on how listener name is determined if unspecified. In addition, this may be a xdstp:// URL. + string name = 1; + + // Configuration for the source of FCDS updates for this listener. + // .. note:: + // This discovery service only supports ``AGGREGATED_GRPC`` API type. + core.v3.ConfigSource config_source = 2; + } + reserved 14, 23; // The unique name by which this listener is known. If no name is provided, @@ -126,6 +148,12 @@ message Listener { // that is governed by the bind rules of the OS. E.g., multiple listeners can listen on port 0 on // Linux as the actual port will be allocated by the OS. // Required unless ``api_listener`` or ``listener_specifier`` is populated. + // + // When the address contains a network namespace filepath (via + // :ref:`network_namespace_filepath `), + // Envoy automatically populates the filter state with key ``envoy.network.network_namespace`` + // when a connection is accepted. This provides read-only access to the network namespace for + // filters, access logs, and other components. core.v3.Address address = 2; // The additional addresses the listener should listen on. The addresses must be unique across all @@ -147,6 +175,12 @@ message Listener { // :ref:`FAQ entry `. repeated FilterChain filter_chains = 3; + // Discover filter chains configurations by external service. Dynamic discovery of filter chains is allowed + // while having statically configured filter chains, however, a filter chain name must be unique within a + // listener. If a discovered filter chain matches a name of an existing filter chain, it is discarded. + // [#not-implemented-hide:] + FcdsConfig fcds_config = 36; + // :ref:`Matcher API ` resolving the filter chain name from the // network properties. This matcher is used as a replacement for the filter chain match condition // :ref:`filter_chain_match @@ -163,8 +197,7 @@ message Listener { // connections bound to the filter chain are not drained. If, however, the // filter chain is removed or structurally modified, then the drain for its // connections is initiated. - xds.type.matcher.v3.Matcher filter_chain_matcher = 32 - [(xds.annotations.v3.field_status).work_in_progress = true]; + xds.type.matcher.v3.Matcher filter_chain_matcher = 32; // If a connection is redirected using ``iptables``, the port on which the proxy // receives it might be different from the original destination address. When this flag is set to @@ -247,10 +280,10 @@ message Listener { google.protobuf.BoolValue freebind = 11; // Additional socket options that may not be present in Envoy source code or - // precompiled binaries. The socket options can be updated for a listener when + // precompiled binaries. + // It is not allowed to update the socket options for any existing address if // :ref:`enable_reuse_port ` - // is ``true``. Otherwise, if socket options change during a listener update the update will be rejected - // to make it clear that the options were not updated. + // is ``false`` to avoid the conflict when creating new sockets for the listener. repeated core.v3.SocketOption socket_options = 13; // Whether the listener should accept TCP Fast Open (TFO) connections. @@ -352,6 +385,11 @@ message Listener { // accepted in later event loop iterations. // If no value is provided Envoy will accept all connections pending accept // from the kernel. + // + // .. note:: + // + // It is recommended to lower this value for better overload management and reduced per-event cost. + // Setting it to 1 is a viable option with no noticeable impact on performance. google.protobuf.UInt32Value max_connections_to_accept_per_socket_event = 34 [(validate.rules).uint32 = {gt: 0}]; @@ -387,6 +425,15 @@ message Listener { // Whether the listener should limit connections based upon the value of // :ref:`global_downstream_max_connections `. bool ignore_global_conn_limit = 31; + + // Whether the listener bypasses configured overload manager actions. + bool bypass_overload_manager = 35; + + // If set, TCP keepalive settings are configured for the listener address and inherited by + // additional addresses. If not set, TCP keepalive settings are not configured for the + // listener address and additional addresses by default. See :ref:`tcp_keepalive ` + // to explicitly configure TCP keepalive settings for individual additional addresses. + core.v3.TcpKeepalive tcp_keepalive = 37; } // A placeholder proto so that users can explicitly configure the standard diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener_components.proto b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener_components.proto index 2adb8bc2c80..16b43568f39 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener_components.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener_components.proto @@ -201,24 +201,9 @@ message FilterChainMatch { message FilterChain { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.listener.FilterChain"; - // The configuration for on-demand filter chain. If this field is not empty in FilterChain message, - // a filter chain will be built on-demand. - // On-demand filter chains help speedup the warming up of listeners since the building and initialization of - // an on-demand filter chain will be postponed to the arrival of new connection requests that require this filter chain. - // Filter chains that are not often used can be set as on-demand. - message OnDemandConfiguration { - // The timeout to wait for filter chain placeholders to complete rebuilding. - // 1. If this field is set to 0, timeout is disabled. - // 2. If not specified, a default timeout of 15s is used. - // Rebuilding will wait until dependencies are ready, have failed, or this timeout is reached. - // Upon failure or timeout, all connections related to this filter chain will be closed. - // Rebuilding will start again on the next new connection. - google.protobuf.Duration rebuild_timeout = 1; - } - - reserved 2; + reserved 2, 8; - reserved "tls_context"; + reserved "tls_context", "on_demand_configuration"; // The criteria to use when matching a connection to this filter chain. FilterChainMatch filter_chain_match = 1; @@ -248,7 +233,7 @@ message FilterChain { google.protobuf.BoolValue use_proxy_proto = 4 [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; - // [#not-implemented-hide:] filter chain metadata. + // Filter chain metadata. core.v3.Metadata metadata = 5; // Optional custom transport socket implementation to use for downstream connections. @@ -265,15 +250,12 @@ message FilterChain { google.protobuf.Duration transport_socket_connect_timeout = 9; // The unique name (or empty) by which this filter chain is known. - // Note: :ref:`filter_chain_matcher - // ` - // requires that filter chains are uniquely named within a listener. + // + // .. note:: + // :ref:`filter_chain_matcher + // ` + // requires that filter chains are uniquely named within a listener. string name = 7; - - // [#not-implemented-hide:] The configuration to specify whether the filter chain will be built on-demand. - // If this field is not empty, the filter chain will be built on-demand. - // Otherwise, the filter chain will be built normally and block listener warming. - OnDemandConfiguration on_demand_configuration = 8; } // Listener filter chain match configuration. This is a recursive structure which allows complex diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/quic_config.proto b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/quic_config.proto index 3a8ce2cd0a6..c208a58f4a4 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/quic_config.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/quic_config.proto @@ -5,6 +5,7 @@ package envoy.config.listener.v3; import "envoy/config/core/v3/base.proto"; import "envoy/config/core/v3/extension.proto"; import "envoy/config/core/v3/protocol.proto"; +import "envoy/config/core/v3/socket_cmsg_headers.proto"; import "google/protobuf/duration.proto"; import "google/protobuf/wrappers.proto"; @@ -24,7 +25,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: QUIC listener config] // Configuration specific to the UDP QUIC listener. -// [#next-free-field: 10] +// [#next-free-field: 15] message QuicProtocolOptions { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.listener.QuicProtocolOptions"; @@ -72,9 +73,36 @@ message QuicProtocolOptions { core.v3.TypedExtensionConfig connection_id_generator_config = 8; // Configure the server's preferred address to advertise so that client can migrate to it. See :ref:`example ` which configures a pair of v4 and v6 preferred addresses. - // The current QUICHE implementation will advertise only one of the preferred IPv4 and IPv6 addresses based on the address family the client initially connects with, and only if the client is also QUICHE-based. + // The current QUICHE implementation will advertise only one of the preferred IPv4 and IPv6 addresses based on the address family the client initially connects with. // If not specified, Envoy will not advertise any server's preferred address. // [#extension-category: envoy.quic.server_preferred_address] core.v3.TypedExtensionConfig server_preferred_address_config = 9 [(xds.annotations.v3.field_status).work_in_progress = true]; + + // Configure the server to send transport parameter `disable_active_migration `_. + // Defaults to false (do not send this transport parameter). + google.protobuf.BoolValue send_disable_active_migration = 10; + + // Configure which implementation of ``quic::QuicConnectionDebugVisitor`` to be used for this listener. + // If not specified, no debug visitor will be attached to connections. + // [#extension-category: envoy.quic.connection_debug_visitor] + core.v3.TypedExtensionConfig connection_debug_visitor_config = 11; + + // Configure a type of UDP cmsg to pass to listener filters via QuicReceivedPacket. + // Both level and type must be specified for cmsg to be saved. + // Cmsg may be truncated or omitted if expected size is not set. + // If not specified, no cmsg will be saved to QuicReceivedPacket. + repeated core.v3.SocketCmsgHeaders save_cmsg_config = 12 + [(validate.rules).repeated = {max_items: 1}]; + + // If true, the listener will reject connection-establishing packets at the + // QUIC layer by replying with an empty version negotiation packet to the + // client. + bool reject_new_connections = 13; + + // Maximum number of QUIC sessions to create per event loop. + // If not specified, the default value is 16. + // This is an equivalent of the TCP listener option + // max_connections_to_accept_per_socket_event. + google.protobuf.UInt32Value max_sessions_per_event_loop = 14 [(validate.rules).uint32 = {gt: 0}]; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/metrics/v3/stats.proto b/xds/third_party/envoy/src/main/proto/envoy/config/metrics/v3/stats.proto index e7d7f80d648..0fcf36c1c71 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/metrics/v3/stats.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/metrics/v3/stats.proto @@ -60,11 +60,6 @@ message StatsConfig { // `. They will be processed before // the custom tags. // - // .. note:: - // - // If any default tags are specified twice, the config will be considered - // invalid. - // // See :repo:`well_known_names.h ` for a list of the // default tags in Envoy. // @@ -298,10 +293,12 @@ message HistogramBucketSettings { // Each value is the upper bound of a bucket. Each bucket must be greater than 0 and unique. // The order of the buckets does not matter. repeated double buckets = 2 [(validate.rules).repeated = { - min_items: 1 unique: true items {double {gt: 0.0}} }]; + + // Initial number of bins for the ``circllhist`` thread local histogram per time series. Default value is 100. + google.protobuf.UInt32Value bins = 3 [(validate.rules).uint32 = {lte: 46082 gt: 0}]; } // Stats configuration proto schema for built-in ``envoy.stat_sinks.statsd`` sink. This sink does not support diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/overload/v3/overload.proto b/xds/third_party/envoy/src/main/proto/envoy/config/overload/v3/overload.proto index d3b8b01a173..b5bc2c4d830 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/overload/v3/overload.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/overload/v3/overload.proto @@ -103,6 +103,19 @@ message ScaleTimersOverloadActionConfig { // This affects the value of // :ref:`FilterChain.transport_socket_connect_timeout `. TRANSPORT_SOCKET_CONNECT = 3; + + // Adjusts the max connection duration timer for downstream HTTP connections. + // This affects the value of + // :ref:`HttpConnectionManager.common_http_protocol_options.max_connection_duration + // `. + HTTP_DOWNSTREAM_CONNECTION_MAX = 4; + + // Adjusts the timeout for the downstream codec to flush an ended stream. + // This affects the value of :ref:`RouteAction.flush_timeout + // ` and + // :ref:`HttpConnectionManager.stream_flush_timeout + // ` + HTTP_DOWNSTREAM_STREAM_FLUSH = 5; } message ScaleTimer { @@ -128,9 +141,16 @@ message OverloadAction { option (udpa.annotations.versioning).previous_message_type = "envoy.config.overload.v2alpha.OverloadAction"; - // The name of the overload action. This is just a well-known string that listeners can - // use for registering callbacks. Custom overload actions should be named using reverse - // DNS to ensure uniqueness. + // The name of the overload action. This is just a well-known string that + // listeners can use for registering callbacks. + // Valid known overload actions include: + // - envoy.overload_actions.stop_accepting_requests + // - envoy.overload_actions.disable_http_keepalive + // - envoy.overload_actions.stop_accepting_connections + // - envoy.overload_actions.reject_incoming_connections + // - envoy.overload_actions.shrink_heap + // - envoy.overload_actions.reduce_timeouts + // - envoy.overload_actions.reset_high_memory_stream string name = 1 [(validate.rules).string = {min_len: 1}]; // A set of triggers for this action. The state of the action is the maximum @@ -142,7 +162,7 @@ message OverloadAction { // in this list. repeated Trigger triggers = 2 [(validate.rules).repeated = {min_items: 1}]; - // Configuration for the action being instantiated. + // Configuration for the action being instantiated if applicable. google.protobuf.Any typed_config = 3; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/rbac/v3/rbac.proto b/xds/third_party/envoy/src/main/proto/envoy/config/rbac/v3/rbac.proto index 3a9271c0015..ef153ad177b 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/rbac/v3/rbac.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/rbac/v3/rbac.proto @@ -3,6 +3,7 @@ syntax = "proto3"; package envoy.config.rbac.v3; import "envoy/config/core/v3/address.proto"; +import "envoy/config/core/v3/cel.proto"; import "envoy/config/core/v3/extension.proto"; import "envoy/config/route/v3/route_components.proto"; import "envoy/type/matcher/v3/filter_state.proto"; @@ -28,6 +29,14 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Role Based Access Control (RBAC)] +enum MetadataSource { + // Query :ref:`dynamic metadata ` + DYNAMIC = 0; + + // Query :ref:`route metadata ` + ROUTE = 1; +} + // Role Based Access Control (RBAC) provides service-level and method-level access control for a // service. Requests are allowed or denied based on the ``action`` and whether a matching policy is // found. For instance, if the action is ALLOW and a matching policy is found the request should be @@ -165,6 +174,7 @@ message RBAC { // A policy matches if and only if at least one of its permissions match the // action taking place AND at least one of its principals match the downstream // AND the condition is true if specified. +// [#next-free-field: 6] message Policy { option (udpa.annotations.versioning).previous_message_type = "envoy.config.rbac.v2.Policy"; @@ -191,10 +201,37 @@ message Policy { // Only be used when condition is not used. google.api.expr.v1alpha1.CheckedExpr checked_condition = 4 [(udpa.annotations.field_migrate).oneof_promotion = "expression_specifier"]; + + // CEL expression configuration that modifies the evaluation behavior of the ``condition`` field. + // If specified, string conversion, concatenation, and manipulation functions may be enabled + // for the CEL expression. See :ref:`CelExpressionConfig ` + // for more details. + core.v3.CelExpressionConfig cel_config = 5; +} + +// SourcedMetadata enables matching against metadata from different sources in the request processing +// pipeline. It extends the base MetadataMatcher functionality by allowing specification of where the +// metadata should be sourced from, rather than only matching against dynamic metadata. +// +// The matcher can be configured to look up metadata from: +// +// * Dynamic metadata: Runtime metadata added by filters during request processing +// * Route metadata: Static metadata configured on the route entry +// +message SourcedMetadata { + // Metadata matcher configuration that defines what metadata to match against. This includes the filter name, + // metadata key path, and expected value. + type.matcher.v3.MetadataMatcher metadata_matcher = 1 + [(validate.rules).message = {required: true}]; + + // Specifies which metadata source should be used for matching. If not set, + // defaults to DYNAMIC (dynamic metadata). Set to ROUTE to match against + // static metadata configured on the route entry. + MetadataSource metadata_source = 2 [(validate.rules).enum = {defined_only: true}]; } // Permission defines an action (or actions) that a principal can take. -// [#next-free-field: 13] +// [#next-free-field: 15] message Permission { option (udpa.annotations.versioning).previous_message_type = "envoy.config.rbac.v2.Permission"; @@ -219,10 +256,14 @@ message Permission { // When any is set, it matches any action. bool any = 3 [(validate.rules).bool = {const: true}]; - // A header (or pseudo-header such as :path or :method) on the incoming HTTP request. Only - // available for HTTP request. - // Note: the pseudo-header :path includes the query and fragment string. Use the ``url_path`` - // field if you want to match the URL path without the query and fragment string. + // A header (or pseudo-header such as ``:path`` or ``:method``) on the incoming HTTP request. Only available + // for HTTP request. + // + // .. note:: + // + // The pseudo-header ``:path`` includes the query and fragment string. Use the ``url_path`` field if you + // want to match the URL path without the query and fragment string. + // route.v3.HeaderMatcher header = 4; // A URL path on the incoming HTTP request. Only available for HTTP. @@ -237,16 +278,17 @@ message Permission { // A port number range that describes a range of destination ports connecting to. type.v3.Int32Range destination_port_range = 11; - // Metadata that describes additional information about the action. - type.matcher.v3.MetadataMatcher metadata = 7; + // Metadata that describes additional information about the action. This field is deprecated; please use + // :ref:`sourced_metadata` instead. + type.matcher.v3.MetadataMatcher metadata = 7 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // Negates matching the provided permission. For instance, if the value of // ``not_rule`` would match, this permission would not match. Conversely, if // the value of ``not_rule`` would not match, this permission would match. Permission not_rule = 8; - // The request server from the client's connection request. This is - // typically TLS SNI. + // The request server from the client's connection request. This is typically TLS SNI. // // .. attention:: // @@ -263,19 +305,26 @@ message Permission { // * A :ref:`listener filter ` may // overwrite a connection's requested server name within Envoy. // - // Please refer to :ref:`this FAQ entry ` to learn to - // setup SNI. + // Please refer to :ref:`this FAQ entry ` to learn how to setup SNI. type.matcher.v3.StringMatcher requested_server_name = 9; // Extension for configuring custom matchers for RBAC. // [#extension-category: envoy.rbac.matchers] core.v3.TypedExtensionConfig matcher = 12; + + // URI template path matching. + // [#extension-category: envoy.path.match] + core.v3.TypedExtensionConfig uri_template = 13; + + // Matches against metadata from either dynamic state or route configuration. Preferred over the + // ``metadata`` field as it provides more flexibility in metadata source selection. + SourcedMetadata sourced_metadata = 14; } } // Principal defines an identity or a group of identities for a downstream // subject. -// [#next-free-field: 13] +// [#next-free-field: 15] message Principal { option (udpa.annotations.versioning).previous_message_type = "envoy.config.rbac.v2.Principal"; @@ -289,6 +338,10 @@ message Principal { } // Authentication attributes for a downstream. + // It is recommended to NOT use this type, but instead use + // :ref:`MTlsAuthenticated `, + // configured via :ref:`custom `, + // which should be used for most use cases due to its improved security. message Authenticated { option (udpa.annotations.versioning).previous_message_type = "envoy.config.rbac.v2.Principal.Authenticated"; @@ -297,25 +350,31 @@ message Principal { // The name of the principal. If set, The URI SAN or DNS SAN in that order // is used from the certificate, otherwise the subject field is used. If - // unset, it applies to any user that is authenticated. + // unset, it applies to any user that is allowed by the downstream TLS configuration. + // If :ref:`require_client_certificate ` + // is false or :ref:`trust_chain_verification ` + // is set to :ref:`ACCEPT_UNTRUSTED `, + // then no authentication is required. type.matcher.v3.StringMatcher principal_name = 2; } oneof identifier { option (validate.required) = true; - // A set of identifiers that all must match in order to define the - // downstream. + // A set of identifiers that all must match in order to define the downstream. Set and_ids = 1; - // A set of identifiers at least one must match in order to define the - // downstream. + // A set of identifiers at least one must match in order to define the downstream. Set or_ids = 2; // When any is set, it matches any downstream. bool any = 3 [(validate.rules).bool = {const: true}]; // Authenticated attributes that identify the downstream. + // It is recommended to NOT use this field, but instead use + // :ref:`MTlsAuthenticated `, + // configured via :ref:`custom `, + // which should be used for most use cases due to its improved security. Authenticated authenticated = 4; // A CIDR block that describes the downstream IP. @@ -329,31 +388,42 @@ message Principal { [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // A CIDR block that describes the downstream remote/origin address. - // Note: This is always the physical peer even if the - // :ref:`remote_ip ` is - // inferred from for example the x-forwarder-for header, proxy protocol, - // etc. + // + // .. note:: + // + // This is always the physical peer even if the + // :ref:`remote_ip ` is inferred from the + // x-forwarder-for header, the proxy protocol, etc. + // core.v3.CidrRange direct_remote_ip = 10; // A CIDR block that describes the downstream remote/origin address. - // Note: This may not be the physical peer and could be different from the - // :ref:`direct_remote_ip - // `. E.g, if the - // remote ip is inferred from for example the x-forwarder-for header, proxy - // protocol, etc. + // + // .. note:: + // + // This may not be the physical peer and could be different from the :ref:`direct_remote_ip + // `. E.g, if the remote ip is inferred from + // the x-forwarder-for header, the proxy protocol, etc. + // core.v3.CidrRange remote_ip = 11; - // A header (or pseudo-header such as :path or :method) on the incoming HTTP - // request. Only available for HTTP request. Note: the pseudo-header :path - // includes the query and fragment string. Use the ``url_path`` field if you - // want to match the URL path without the query and fragment string. + // A header (or pseudo-header such as ``:path`` or ``:method``) on the incoming HTTP request. Only available + // for HTTP request. + // + // .. note:: + // + // The pseudo-header ``:path`` includes the query and fragment string. Use the ``url_path`` field if you + // want to match the URL path without the query and fragment string. + // route.v3.HeaderMatcher header = 6; // A URL path on the incoming HTTP request. Only available for HTTP. type.matcher.v3.PathMatcher url_path = 9; - // Metadata that describes additional information about the principal. - type.matcher.v3.MetadataMatcher metadata = 7; + // Metadata that describes additional information about the principal. This field is deprecated; please use + // :ref:`sourced_metadata` instead. + type.matcher.v3.MetadataMatcher metadata = 7 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // Identifies the principal using a filter state object. type.matcher.v3.FilterStateMatcher filter_state = 12; @@ -362,6 +432,14 @@ message Principal { // ``not_id`` would match, this principal would not match. Conversely, if the // value of ``not_id`` would not match, this principal would match. Principal not_id = 8; + + // Matches against metadata from either dynamic state or route configuration. Preferred over the + // ``metadata`` field as it provides more flexibility in metadata source selection. + SourcedMetadata sourced_metadata = 13; + + // Extension for configuring custom principals for RBAC. + // [#extension-category: envoy.rbac.principals] + core.v3.TypedExtensionConfig custom = 14; } } @@ -373,7 +451,7 @@ message Action { // The action to take if the matcher matches. Every action either allows or denies a request, // and can also carry out action-specific operations. // - // Actions: + // **Actions:** // // * ``ALLOW``: If the request gets matched on ALLOW, it is permitted. // * ``DENY``: If the request gets matched on DENY, it is not permitted. @@ -382,7 +460,7 @@ message Action { // ``envoy.common`` will be set to the value ``true``. // * If the request cannot get matched, it will fallback to ``DENY``. // - // Log behavior: + // **Log behavior:** // // If the RBAC matcher contains at least one LOG action, the dynamic // metadata key ``access_log_hint`` will be set based on if the request diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route.proto b/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route.proto index c4d507d22b0..5bd909f34c3 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route.proto @@ -23,7 +23,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // * Routing :ref:`architecture overview ` // * HTTP :ref:`router filter ` -// [#next-free-field: 18] +// [#next-free-field: 19] message RouteConfiguration { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.RouteConfiguration"; @@ -129,10 +129,17 @@ message RouteConfiguration { // By default, port in :authority header (if any) is used in host matching. // With this option enabled, Envoy will ignore the port number in the :authority header (if any) when picking VirtualHost. - // NOTE: this option will not strip the port number (if any) contained in route config - // :ref:`envoy_v3_api_msg_config.route.v3.VirtualHost`.domains field. + // + // .. note:: + // This option will not strip the port number (if any) contained in route config + // :ref:`envoy_v3_api_msg_config.route.v3.VirtualHost`.domains field. bool ignore_port_in_host_matching = 14; + // Normally, virtual host matching is done using the :authority (or + // Host: in HTTP < 2) HTTP header. Setting this will instead, use a + // different HTTP header for this purpose. + string vhost_header = 18; + // Ignore path-parameters in path-matching. // Before RFC3986, URI were like(RFC1808): :///;?# // Envoy by default takes ":path" as ";". diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route_components.proto b/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route_components.proto index 1e2b486d288..4587ef10487 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route_components.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route_components.proto @@ -2,9 +2,12 @@ syntax = "proto3"; package envoy.config.route.v3; +import "envoy/config/common/mutation_rules/v3/mutation_rules.proto"; import "envoy/config/core/v3/base.proto"; import "envoy/config/core/v3/extension.proto"; import "envoy/config/core/v3/proxy_protocol.proto"; +import "envoy/config/core/v3/substitution_format_string.proto"; +import "envoy/type/matcher/v3/filter_state.proto"; import "envoy/type/matcher/v3/metadata.proto"; import "envoy/type/matcher/v3/regex.proto"; import "envoy/type/matcher/v3/string.proto"; @@ -17,7 +20,6 @@ import "google/protobuf/any.proto"; import "google/protobuf/duration.proto"; import "google/protobuf/wrappers.proto"; -import "xds/annotations/v3/status.proto"; import "xds/type/matcher/v3/matcher.proto"; import "envoy/annotations/deprecation.proto"; @@ -41,7 +43,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // host header. This allows a single listener to service multiple top level domain path trees. Once // a virtual host is selected based on the domain, the routes are processed in order to see which // upstream cluster to route to or whether to perform a redirect. -// [#next-free-field: 25] +// [#next-free-field: 26] message VirtualHost { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.VirtualHost"; @@ -78,7 +80,7 @@ message VirtualHost { // .. note:: // // The wildcard will not match the empty string. - // e.g. ``*-bar.foo.com`` will match ``baz-bar.foo.com`` but not ``-bar.foo.com``. + // For example, ``*-bar.foo.com`` will match ``baz-bar.foo.com`` but not ``-bar.foo.com``. // The longest wildcards match first. // Only a single virtual host in the entire route configuration can match on ``*``. A domain // must be unique across all virtual hosts or the config will fail to load. @@ -92,13 +94,12 @@ message VirtualHost { // The list of routes that will be matched, in order, for incoming requests. // The first route that matches will be used. // Only one of this and ``matcher`` can be specified. - repeated Route routes = 3; + repeated Route routes = 3 [(udpa.annotations.field_migrate).oneof_promotion = "route_selection"]; - // [#next-major-version: This should be included in a oneof with routes wrapped in a message.] // The match tree to use when resolving route actions for incoming requests. Only one of this and ``routes`` // can be specified. xds.type.matcher.v3.Matcher matcher = 21 - [(xds.annotations.v3.field_status).work_in_progress = true]; + [(udpa.annotations.field_migrate).oneof_promotion = "route_selection"]; // Specifies the type of TLS enforcement the virtual host expects. If this option is not // specified, there is no TLS requirement for the virtual host. @@ -156,7 +157,7 @@ message VirtualHost { // This field can be used to provide virtual host level per filter config. The key should match the // :ref:`filter config name // `. - // See :ref:`Http filter route specific config ` + // See :ref:`HTTP filter route-specific config ` // for details. // [#comment: An entry's value may be wrapped in a // :ref:`FilterConfig` @@ -167,7 +168,10 @@ message VirtualHost { // ` header should be included // in the upstream request. Setting this option will cause it to override any existing header // value, so in the case of two Envoys on the request path with this option enabled, the upstream - // will see the attempt count as perceived by the second Envoy. Defaults to false. + // will see the attempt count as perceived by the second Envoy. + // + // Defaults to ``false``. + // // This header is unaffected by the // :ref:`suppress_envoy_headers // ` flag. @@ -179,7 +183,10 @@ message VirtualHost { // ` header should be included // in the downstream response. Setting this option will cause the router to override any existing header // value, so in the case of two Envoys on the request path with this option enabled, the downstream - // will see the attempt count as perceived by the Envoy closest upstream from itself. Defaults to false. + // will see the attempt count as perceived by the Envoy closest upstream from itself. + // + // Defaults to ``false``. + // // This header is unaffected by the // :ref:`suppress_envoy_headers // ` flag. @@ -187,29 +194,56 @@ message VirtualHost { // Indicates the retry policy for all routes in this virtual host. Note that setting a // route level entry will take precedence over this config and it'll be treated - // independently (e.g.: values are not inherited). + // independently (e.g., values are not inherited). RetryPolicy retry_policy = 16; // [#not-implemented-hide:] // Specifies the configuration for retry policy extension. Note that setting a route level entry - // will take precedence over this config and it'll be treated independently (e.g.: values are not + // will take precedence over this config and it'll be treated independently (e.g., values are not // inherited). :ref:`Retry policy ` should not be // set if this field is used. google.protobuf.Any retry_policy_typed_config = 20; // Indicates the hedge policy for all routes in this virtual host. Note that setting a // route level entry will take precedence over this config and it'll be treated - // independently (e.g.: values are not inherited). + // independently (e.g., values are not inherited). HedgePolicy hedge_policy = 17; // Decides whether to include the :ref:`x-envoy-is-timeout-retry ` - // request header in retries initiated by per try timeouts. + // request header in retries initiated by per-try timeouts. bool include_is_timeout_retry_header = 23; - // The maximum bytes which will be buffered for retries and shadowing. - // If set and a route-specific limit is not set, the bytes actually buffered will be the minimum - // value of this and the listener per_connection_buffer_limit_bytes. - google.protobuf.UInt32Value per_request_buffer_limit_bytes = 18; + // The maximum bytes which will be buffered for retries and shadowing. If set, the bytes actually buffered will be + // the minimum value of this and the listener ``per_connection_buffer_limit_bytes``. + // + // .. attention:: + // + // This field has been deprecated. Please use :ref:`request_body_buffer_limit + // ` instead. + // Only one of ``per_request_buffer_limit_bytes`` and ``request_body_buffer_limit`` could be set. + google.protobuf.UInt32Value per_request_buffer_limit_bytes = 18 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; + + // The maximum bytes which will be buffered for request bodies to support large request body + // buffering beyond the ``per_connection_buffer_limit_bytes``. + // + // This limit is specifically for the request body buffering and allows buffering larger payloads while maintaining + // flow control. + // + // Buffer limit precedence (from highest to lowest priority): + // + // 1. If ``request_body_buffer_limit`` is set, then ``request_body_buffer_limit`` will be used. + // 2. If :ref:`per_request_buffer_limit_bytes ` + // is set but ``request_body_buffer_limit`` is not, then ``min(per_request_buffer_limit_bytes, per_connection_buffer_limit_bytes)`` + // will be used. + // 3. If neither is set, then ``per_connection_buffer_limit_bytes`` will be used. + // + // For flow control chunk sizes, ``min(per_connection_buffer_limit_bytes, 16KB)`` will be used. + // + // Only one of :ref:`per_request_buffer_limit_bytes ` + // and ``request_body_buffer_limit`` could be set. + google.protobuf.UInt64Value request_body_buffer_limit = 25 + [(validate.rules).message = {required: false}]; // Specify a set of default request mirroring policies for every route under this virtual host. // It takes precedence over the route config mirror policy entirely. @@ -245,7 +279,7 @@ message RouteList { // // Envoy supports routing on HTTP method via :ref:`header matching // `. -// [#next-free-field: 20] +// [#next-free-field: 21] message Route { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.Route"; @@ -298,7 +332,7 @@ message Route { // This field can be used to provide route specific per filter config. The key should match the // :ref:`filter config name // `. - // See :ref:`Http filter route specific config ` + // See :ref:`HTTP filter route-specific config ` // for details. // [#comment: An entry's value may be wrapped in a // :ref:`FilterConfig` @@ -342,7 +376,14 @@ message Route { // The maximum bytes which will be buffered for retries and shadowing. // If set, the bytes actually buffered will be the minimum value of this and the // listener per_connection_buffer_limit_bytes. - google.protobuf.UInt32Value per_request_buffer_limit_bytes = 16; + // + // .. attention:: + // + // This field has been deprecated. Please use :ref:`request_body_buffer_limit + // ` instead. + // Only one of ``per_request_buffer_limit_bytes`` and ``request_body_buffer_limit`` may be set. + google.protobuf.UInt32Value per_request_buffer_limit_bytes = 16 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // The human readable prefix to use when emitting statistics for this endpoint. // The statistics are rooted at vhost..route.. @@ -356,8 +397,27 @@ message Route { // // We do not recommend setting up a stat prefix for // every application endpoint. This is both not easily maintainable and - // statistics use a non-trivial amount of memory(approximately 1KiB per route). + // statistics use a non-trivial amount of memory (approximately 1KiB per route). string stat_prefix = 19; + + // The maximum bytes which will be buffered for request bodies to support large request body + // buffering beyond the ``per_connection_buffer_limit_bytes``. + // + // This limit is specifically for the request body buffering and allows buffering larger payloads while maintaining + // flow control. + // + // Buffer limit precedence (from highest to lowest priority): + // + // 1. If ``request_body_buffer_limit`` is set: use ``request_body_buffer_limit`` + // 2. If :ref:`per_request_buffer_limit_bytes ` + // is set but ``request_body_buffer_limit`` is not: use ``min(per_request_buffer_limit_bytes, per_connection_buffer_limit_bytes)`` + // 3. If neither is set: use ``per_connection_buffer_limit_bytes`` + // + // For flow control chunk sizes, use ``min(per_connection_buffer_limit_bytes, 16KB)``. + // + // Only one of :ref:`per_request_buffer_limit_bytes ` + // and ``request_body_buffer_limit`` may be set. + google.protobuf.UInt64Value request_body_buffer_limit = 20; } // Compared to the :ref:`cluster ` field that specifies a @@ -366,6 +426,7 @@ message Route { // multiple upstream clusters along with weights that indicate the percentage of // traffic to be forwarded to each cluster. The router selects an upstream cluster based on the // weights. +// [#next-free-field: 6] message WeightedCluster { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.WeightedCluster"; @@ -453,7 +514,7 @@ message WeightedCluster { // This field can be used to provide weighted cluster specific per filter config. The key should match the // :ref:`filter config name // `. - // See :ref:`Http filter route specific config ` + // See :ref:`HTTP filter route-specific config ` // for details. // [#comment: An entry's value may be wrapped in a // :ref:`FilterConfig` @@ -496,12 +557,18 @@ message WeightedCluster { // the process for the consistency. And the value is a unsigned number between 0 and UINT64_MAX. string header_name = 4 [(validate.rules).string = {well_known_regex: HTTP_HEADER_NAME strict: false}]; + + // When set to true, the hash policies will be used to generate the random value for weighted cluster selection. + // This could ensure consistent cluster picking across multiple proxy levels for weighted traffic. + google.protobuf.BoolValue use_hash_policy = 5; } } // Configuration for a cluster specifier plugin. message ClusterSpecifierPlugin { // The name of the plugin and its opaque configuration. + // + // [#extension-category: envoy.router.cluster_specifier_plugin] core.v3.TypedExtensionConfig extension = 1 [(validate.rules).message = {required: true}]; // If is_optional is not set or is set to false and the plugin defined by this message is not a @@ -512,7 +579,7 @@ message ClusterSpecifierPlugin { bool is_optional = 2; } -// [#next-free-field: 16] +// [#next-free-field: 18] message RouteMatch { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.RouteMatch"; @@ -570,7 +637,7 @@ message RouteMatch { // // [#next-major-version: In the v3 API we should redo how path specification works such // that we utilize StringMatcher, and additionally have consistent options around whether we - // strip query strings, do a case sensitive match, etc. In the interim it will be too disruptive + // strip query strings, do a case-sensitive match, etc. In the interim it will be too disruptive // to deprecate the existing options. We should even consider whether we want to do away with // path_specifier entirely and just rely on a set of header matchers which can already match // on :path, etc. The issue with that is it is unclear how to generically deal with query string @@ -602,7 +669,7 @@ message RouteMatch { core.v3.TypedExtensionConfig path_match_policy = 15; } - // Indicates that prefix/path matching should be case sensitive. The default + // Indicates that prefix/path matching should be case-sensitive. The default // is true. Ignored for safe_regex matching. google.protobuf.BoolValue case_sensitive = 4; @@ -642,14 +709,19 @@ message RouteMatch { // // If query parameters are used to pass request message fields when // `grpc_json_transcoder `_ - // is used, the transcoded message fields maybe different. The query parameters are - // url encoded, but the message fields are not. For example, if a query + // is used, the transcoded message fields may be different. The query parameters are + // URL-encoded, but the message fields are not. For example, if a query // parameter is "foo%20bar", the message field will be "foo bar". repeated QueryParameterMatcher query_parameters = 7; + // Specifies a set of cookies on which the route should match. The router parses the ``Cookie`` + // header and evaluates the named cookie against each matcher. If the number of specified cookie + // matchers is nonzero, they all must match for the route to be selected. + repeated CookieMatcher cookies = 17; + // If specified, only gRPC requests will be matched. The router will check - // that the content-type header has a application/grpc or one of the various - // application/grpc+ values. + // that the ``Content-Type`` header has ``application/grpc`` or one of the various + // ``application/grpc+`` values. GrpcRouteMatchOptions grpc = 8; // If specified, the client tls context will be matched against the defined @@ -663,6 +735,12 @@ message RouteMatch { // If the number of specified dynamic metadata matchers is nonzero, they all must match the // dynamic metadata for a match to occur. repeated type.matcher.v3.MetadataMatcher dynamic_metadata = 13; + + // Specifies a set of filter state matchers on which the route should match. + // The router will check the filter state against all the specified filter state matchers. + // If the number of specified filter state matchers is nonzero, they all must match the + // filter state for a match to occur. + repeated type.matcher.v3.FilterStateMatcher filter_state = 16; } // Cors policy configuration. @@ -673,7 +751,7 @@ message RouteMatch { // :ref:`CorsPolicy in filter extension ` // as as alternative. // -// [#next-free-field: 13] +// [#next-free-field: 14] message CorsPolicy { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.CorsPolicy"; @@ -727,9 +805,13 @@ message CorsPolicy { // // More details refer to https://developer.chrome.com/blog/private-network-access-preflight. google.protobuf.BoolValue allow_private_network_access = 12; + + // Specifies if preflight requests not matching the configured allowed origin should be forwarded + // to the upstream. Default is ``true``. + google.protobuf.BoolValue forward_not_matching_preflights = 13; } -// [#next-free-field: 42] +// [#next-free-field: 46] message RouteAction { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.RouteAction"; @@ -759,7 +841,8 @@ message RouteAction { // collected for the shadow cluster making this feature useful for testing. // // During shadowing, the host/authority header is altered such that ``-shadow`` is appended. This is - // useful for logging. For example, ``cluster1`` becomes ``cluster1-shadow``. + // useful for logging. For example, ``cluster1`` becomes ``cluster1-shadow``. This behavior can be + // disabled by setting ``disable_shadow_host_suffix_append`` to ``true``. // // .. note:: // @@ -767,8 +850,8 @@ message RouteAction { // // .. note:: // - // Shadowing doesn't support Http CONNECT and upgrades. - // [#next-free-field: 6] + // Shadowing doesn't support HTTP CONNECT and upgrades. + // [#next-free-field: 9] message RequestMirrorPolicy { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.RouteAction.RequestMirrorPolicy"; @@ -812,8 +895,30 @@ message RouteAction { // value, the request will be mirrored. core.v3.RuntimeFractionalPercent runtime_fraction = 3; - // Determines if the trace span should be sampled. Defaults to true. + // Specifies whether the trace span for the shadow request should be sampled. If this field is not explicitly set, + // the shadow request will inherit the sampling decision of its parent span. This ensures consistency with the trace + // sampling policy of the original request and prevents oversampling, especially in scenarios where runtime sampling + // is disabled. google.protobuf.BoolValue trace_sampled = 4; + + // Disables appending the ``-shadow`` suffix to the shadowed ``Host`` header. + // + // Defaults to ``false``. + bool disable_shadow_host_suffix_append = 6; + + // Specifies a list of header mutations that should be applied to each mirrored request. + // Header mutations are applied in the order they are specified. For more information, including + // details on header value syntax, see the documentation on :ref:`custom request headers + // `. + repeated common.mutation_rules.v3.HeaderMutation request_headers_mutations = 7 + [(validate.rules).repeated = {max_items: 1000}]; + + // Indicates that during mirroring, the host header will be swapped with this value. + // :ref:`disable_shadow_host_suffix_append + // ` + // is implicitly enabled if this field is set. + string host_rewrite_literal = 8 + [(validate.rules).string = {well_known_regex: HTTP_HEADER_VALUE strict: false}]; } // Specifies the route's hashing policy if the upstream cluster uses a hashing :ref:`load balancer @@ -975,13 +1080,15 @@ message RouteAction { bool allow_post = 2; } - // The case-insensitive name of this upgrade, e.g. "websocket". + // The case-insensitive name of this upgrade, for example, "websocket". // For each upgrade type present in upgrade_configs, requests with // Upgrade: [upgrade_type] will be proxied upstream. string upgrade_type = 1 [(validate.rules).string = {min_len: 1 well_known_regex: HTTP_HEADER_VALUE strict: false}]; - // Determines if upgrades are available on this route. Defaults to true. + // Determines if upgrades are available on this route. + // + // Defaults to ``true``. google.protobuf.BoolValue enabled = 2; // Configuration for sending data upstream as a raw data payload. This is used for @@ -1080,9 +1187,11 @@ message RouteAction { // place the original path before rewrite into the :ref:`x-envoy-original-path // ` header. // - // Only one of :ref:`regex_rewrite ` + // Only one of :ref:`regex_rewrite `, // :ref:`path_rewrite_policy `, - // or :ref:`prefix_rewrite ` may be specified. + // :ref:`path_rewrite `, + // or :ref:`prefix_rewrite ` + // may be specified. // // .. attention:: // @@ -1118,8 +1227,9 @@ message RouteAction { // ` header. // // Only one of :ref:`regex_rewrite `, - // :ref:`prefix_rewrite `, or - // :ref:`path_rewrite_policy `] + // :ref:`path_rewrite_policy `, + // :ref:`path_rewrite `, + // or :ref:`prefix_rewrite ` // may be specified. // // Examples using Google's `RE2 `_ engine: @@ -1143,12 +1253,48 @@ message RouteAction { // [#extension-category: envoy.path.rewrite] core.v3.TypedExtensionConfig path_rewrite_policy = 41; + // Rewrites the whole path (without query parameters) with the given path value. + // The router filter will + // place the original path before rewrite into the :ref:`x-envoy-original-path + // ` header. + // + // Only one of :ref:`regex_rewrite `, + // :ref:`path_rewrite_policy `, + // :ref:`path_rewrite `, + // or :ref:`prefix_rewrite ` + // may be specified. + // + // The :ref:`substitution format specifier ` could be applied here. + // For example, with the following config: + // + // .. code-block:: yaml + // + // path_rewrite: "/new_path_prefix%REQ(custom-path-header-name)%" + // + // Would rewrite the path to ``/new_path_prefix/some_value`` given the header + // ``custom-path-header-name: some_value``. If the header is not present, the path will be + // rewritten to ``/new_path_prefix``. + // + // + // If the final output of the path rewrite is empty, then the update will be ignored and the + // original path will be preserved. + string path_rewrite = 45; + + // If one of the host rewrite specifiers is set and the + // :ref:`suppress_envoy_headers + // ` flag is not + // set to true, the router filter will place the original host header value before + // rewriting into the :ref:`x-envoy-original-host + // ` header. + // + // And if the + // :ref:`append_x_forwarded_host ` + // is set to true, the original host value will also be appended to the + // :ref:`config_http_conn_man_headers_x-forwarded-host` header. + // oneof host_rewrite_specifier { // Indicates that during forwarding, the host header will be swapped with - // this value. Using this option will append the - // :ref:`config_http_conn_man_headers_x-forwarded-host` header if - // :ref:`append_x_forwarded_host ` - // is set. + // this value. string host_rewrite_literal = 6 [(validate.rules).string = {well_known_regex: HTTP_HEADER_VALUE strict: false}]; @@ -1158,18 +1304,12 @@ message RouteAction { // type ``strict_dns`` or ``logical_dns``, // or when :ref:`hostname ` // field is not empty. Setting this to true with other cluster types - // has no effect. Using this option will append the - // :ref:`config_http_conn_man_headers_x-forwarded-host` header if - // :ref:`append_x_forwarded_host ` - // is set. + // has no effect. google.protobuf.BoolValue auto_host_rewrite = 7; // Indicates that during forwarding, the host header will be swapped with the content of given // downstream or :ref:`custom ` header. - // If header value is empty, host header is left intact. Using this option will append the - // :ref:`config_http_conn_man_headers_x-forwarded-host` header if - // :ref:`append_x_forwarded_host ` - // is set. + // If header value is empty, host header is left intact. // // .. attention:: // @@ -1185,10 +1325,6 @@ message RouteAction { // Indicates that during forwarding, the host header will be swapped with // the result of the regex substitution executed on path value with query and fragment removed. // This is useful for transitioning variable content between path segment and subdomain. - // Using this option will append the - // :ref:`config_http_conn_man_headers_x-forwarded-host` header if - // :ref:`append_x_forwarded_host ` - // is set. // // For example with the following config: // @@ -1202,6 +1338,25 @@ message RouteAction { // // Would rewrite the host header to ``envoyproxy.io`` given the path ``/envoyproxy.io/some/path``. type.matcher.v3.RegexMatchAndSubstitute host_rewrite_path_regex = 35; + + // Rewrites the host header with the value of this field. The router filter will + // place the original host header value before rewriting into the :ref:`x-envoy-original-host + // ` header. + // + // The :ref:`substitution format specifier ` could be applied here. + // For example, with the following config: + // + // .. code-block:: yaml + // + // host_rewrite: "prefix-%REQ(custom-host-header-name)%" + // + // Would rewrite the host header to ``prefix-some_value`` given the header + // ``custom-host-header-name: some_value``. If the header is not present, the host header will + // be rewritten to an value of ``prefix-``. + // + // If the final output of the host rewrite is empty, then the update will be ignored and the + // original host header will be preserved. + string host_rewrite = 44; } // If set, then a host rewrite action (one of @@ -1211,7 +1366,6 @@ message RouteAction { // :ref:`host_rewrite_path_regex `) // causes the original value of the host header, if any, to be appended to the // :ref:`config_http_conn_man_headers_x-forwarded-host` HTTP header if it is different to the last value appended. - // This can be disabled by setting the runtime guard ``envoy_reloadable_features_append_xfh_idempotent`` to false. bool append_x_forwarded_host = 38; // Specifies the upstream timeout for the route. If not specified, the default is 15s. This @@ -1249,8 +1403,28 @@ message RouteAction { // If the :ref:`overload action ` "envoy.overload_actions.reduce_timeouts" // is configured, this timeout is scaled according to the value for // :ref:`HTTP_DOWNSTREAM_STREAM_IDLE `. + // + // This timeout may also be used in place of ``flush_timeout`` in very specific cases. See the + // documentation for ``flush_timeout`` for more details. google.protobuf.Duration idle_timeout = 24; + // Specifies the codec stream flush timeout for the route. + // + // If not specified, the first preference is the global :ref:`stream_flush_timeout + // `, + // but only if explicitly configured. + // + // If neither the explicit HCM-wide flush timeout nor this route-specific flush timeout is configured, + // the route's stream idle timeout is reused for this timeout. This is for + // backwards compatibility since both behaviors were historically controlled by the one timeout. + // + // If the route also does not have an idle timeout configured, the global :ref:`stream_idle_timeout + // `. used, again + // for backwards compatibility. That timeout defaults to 5 minutes. + // + // A value of 0 via any of the above paths will completely disable the timeout for a given route. + google.protobuf.Duration flush_timeout = 42; + // Specifies how to send request over TLS early data. // If absent, allows `safe HTTP requests `_ to be sent on early data. // [#extension-category: envoy.route.early_data_policy] @@ -1258,13 +1432,13 @@ message RouteAction { // Indicates that the route has a retry policy. Note that if this is set, // it'll take precedence over the virtual host level retry policy entirely - // (e.g.: policies are not merged, most internal one becomes the enforced policy). + // (e.g., policies are not merged, the most internal one becomes the enforced policy). RetryPolicy retry_policy = 9; // [#not-implemented-hide:] // Specifies the configuration for retry policy extension. Note that if this is set, it'll take - // precedence over the virtual host level retry policy entirely (e.g.: policies are not merged, - // most internal one becomes the enforced policy). :ref:`Retry policy ` + // precedence over the virtual host level retry policy entirely (e.g., policies are not merged, + // the most internal one becomes the enforced policy). :ref:`Retry policy ` // should not be set if this field is used. google.protobuf.Any retry_policy_typed_config = 33; @@ -1285,7 +1459,9 @@ message RouteAction { // :ref:`rate_limits ` are not applied to the // request. // - // This field is deprecated. Please use :ref:`vh_rate_limits ` + // .. attention:: + // + // This field is deprecated. Please use :ref:`vh_rate_limits ` google.protobuf.BoolValue include_vh_rate_limits = 14 [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; @@ -1379,7 +1555,7 @@ message RouteAction { // Indicates that the route has a hedge policy. Note that if this is set, // it'll take precedence over the virtual host level hedge policy entirely - // (e.g.: policies are not merged, most internal one becomes the enforced policy). + // (e.g., policies are not merged, the most internal one becomes the enforced policy). HedgePolicy hedge_policy = 27; // Specifies the maximum stream duration for this route. @@ -1513,7 +1689,9 @@ message RetryPolicy { // Specifies the maximum back off interval that Envoy will allow. If a reset // header contains an interval longer than this then it will be discarded and - // the next header will be tried. Defaults to 300 seconds. + // the next header will be tried. + // + // Defaults to 300 seconds. google.protobuf.Duration max_interval = 2 [(validate.rules).duration = {gt {}}]; } @@ -1542,7 +1720,7 @@ message RetryPolicy { google.protobuf.Duration per_try_timeout = 3; // Specifies an upstream idle timeout per retry attempt (including the initial attempt). This - // parameter is optional and if absent there is no per try idle timeout. The semantics of the per + // parameter is optional and if absent there is no per-try idle timeout. The semantics of the per- // try idle timeout are similar to the // :ref:`route idle timeout ` and // :ref:`stream idle timeout @@ -1617,12 +1795,14 @@ message HedgePolicy { // Specifies the number of initial requests that should be sent upstream. // Must be at least 1. + // // Defaults to 1. // [#not-implemented-hide:] google.protobuf.UInt32Value initial_requests = 1 [(validate.rules).uint32 = {gte: 1}]; // Specifies a probability that an additional upstream request should be sent // on top of what is specified by initial_requests. + // // Defaults to 0. // [#not-implemented-hide:] type.v3.FractionalPercent additional_request_chance = 2; @@ -1632,14 +1812,16 @@ message HedgePolicy { // The first request to complete successfully will be the one returned to the caller. // // * At any time, a successful response (i.e. not triggering any of the retry-on conditions) would be returned to the client. - // * Before per-try timeout, an error response (per retry-on conditions) would be retried immediately or returned ot the client + // * Before per-try timeout, an error response (per retry-on conditions) would be retried immediately or returned to the client // if there are no more retries left. // * After per-try timeout, an error response would be discarded, as a retry in the form of a hedged request is already in progress. // - // Note: For this to have effect, you must have a :ref:`RetryPolicy ` that retries at least - // one error code and specifies a maximum number of retries. + // .. note:: + // + // For this to have effect, you must have a :ref:`RetryPolicy ` that retries at least + // one error code and specifies a maximum number of retries. // - // Defaults to false. + // Defaults to ``false``. bool hedge_on_per_try_timeout = 3; } @@ -1766,6 +1948,12 @@ message DirectResponseAction { // :ref:`envoy_v3_api_msg_config.route.v3.Route`, :ref:`envoy_v3_api_msg_config.route.v3.RouteConfiguration` or // :ref:`envoy_v3_api_msg_config.route.v3.VirtualHost`. core.v3.DataSource body = 2; + + // Specifies a format string for the response body. If present, the contents of + // ``body_format`` will be formatted and used as the response body, where the + // contents of ``body`` (may be empty) will be passed as the variable ``%LOCAL_REPLY_BODY%``. + // If neither are provided, no body is included in the generated response. + core.v3.SubstitutionFormatString body_format = 3; } // [#not-implemented-hide:] @@ -1785,10 +1973,11 @@ message Decorator { // ` header. string operation = 1 [(validate.rules).string = {min_len: 1}]; - // Whether the decorated details should be propagated to the other party. The default is true. + // Whether the decorated details should be propagated to the other party. The default is ``true``. google.protobuf.BoolValue propagate = 2; } +// [#next-free-field: 7] message Tracing { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.Tracing"; @@ -1824,6 +2013,34 @@ message Tracing { // each in the HTTP connection manager and the route level, the one configured here takes // priority. repeated type.tracing.v3.CustomTag custom_tags = 4; + + // The operation name of the span which will be used for tracing. + // + // The same :ref:`format specifier ` as used for + // :ref:`HTTP access logging ` applies here, however + // unknown specifier values are replaced with the empty string instead of ``-``. + // + // This field will take precedence over and make following settings ineffective: + // + // * :ref:`route decorator `. + // * :ref:`x-envoy-decorator-operation `. + // * :ref:`HCM tracing operation + // `. + string operation = 5; + + // The operation name of the upstream span which will be used for tracing. + // This only takes effect when ``spawn_upstream_span`` is set to true and the upstream + // span is created. + // + // The same :ref:`format specifier ` as used for + // :ref:`HTTP access logging ` applies here, however + // unknown specifier values are replaced with the empty string instead of ``-``. + // + // This field will take precedence over and make following settings ineffective: + // + // * :ref:`HCM tracing upstream operation + // ` + string upstream_operation = 6; } // A virtual cluster is a way of specifying a regex matching rule against @@ -1863,10 +2080,11 @@ message VirtualCluster { // Global rate limiting :ref:`architecture overview `. // Also applies to Local rate limiting :ref:`using descriptors `. +// [#next-free-field: 7] message RateLimit { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.RateLimit"; - // [#next-free-field: 12] + // [#next-free-field: 13] message Action { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.RateLimit.Action"; @@ -1923,9 +2141,48 @@ message RateLimit { // The key to use in the descriptor entry. string descriptor_key = 2 [(validate.rules).string = {min_len: 1}]; - // If set to true, Envoy skips the descriptor while calling rate limiting service - // when header is not present in the request. By default it skips calling the - // rate limiting service if this header is not present in the request. + // Controls the behavior when the specified header is not present in the request. + // + // If set to ``false`` (default): + // + // * Envoy does **NOT** call the rate limiting service for this descriptor. + // * Useful if the header is optional and you prefer to skip rate limiting when it's absent. + // + // If set to ``true``: + // + // * Envoy calls the rate limiting service but omits this descriptor if the header is missing. + // * Useful if you want Envoy to enforce rate limiting even when the header is not present. + // + bool skip_if_absent = 3; + } + + // The following descriptor entry is appended when a query parameter contains a key that matches the + // ``query_parameter_name``: + // + // .. code-block:: cpp + // + // ("", "") + message QueryParameters { + // The name of the query parameter to use for rate limiting. Value of this query parameter is used to populate + // the value of the descriptor entry for the descriptor_key. + string query_parameter_name = 1 [(validate.rules).string = {min_len: 1}]; + + // The key to use when creating the rate limit descriptor entry. This descriptor key will be used to identify the + // rate limit rule in the rate limiting service. + string descriptor_key = 2 [(validate.rules).string = {min_len: 1}]; + + // Controls the behavior when the specified query parameter is not present in the request. + // + // If set to ``false`` (default): + // + // * Envoy does **NOT** call the rate limiting service for this descriptor. + // * Useful if the query parameter is optional and you prefer to skip rate limiting when it's absent. + // + // If set to ``true``: + // + // * Envoy calls the rate limiting service but omits this descriptor if the query parameter is missing. + // * Useful if you want Envoy to enforce rate limiting even when the query parameter is not present. + // bool skip_if_absent = 3; } @@ -1948,14 +2205,18 @@ message RateLimit { // ("masked_remote_address", "") message MaskedRemoteAddress { // Length of prefix mask len for IPv4 (e.g. 0, 32). + // // Defaults to 32 when unset. + // // For example, trusted address from x-forwarded-for is ``192.168.1.1``, // the descriptor entry is ("masked_remote_address", "192.168.1.1/32"); // if mask len is 24, the descriptor entry is ("masked_remote_address", "192.168.1.0/24"). google.protobuf.UInt32Value v4_prefix_mask_len = 1 [(validate.rules).uint32 = {lte: 32}]; // Length of prefix mask len for IPv6 (e.g. 0, 128). + // // Defaults to 128 when unset. + // // For example, trusted address from x-forwarded-for is ``2001:abcd:ef01:2345:6789:abcd:ef01:234``, // the descriptor entry is ("masked_remote_address", "2001:abcd:ef01:2345:6789:abcd:ef01:234/128"); // if mask len is 64, the descriptor entry is ("masked_remote_address", "2001:abcd:ef01:2345::/64"). @@ -1971,9 +2232,40 @@ message RateLimit { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.RateLimit.Action.GenericKey"; - // The value to use in the descriptor entry. + // Descriptor value of entry. + // + // The same :ref:`format specifier ` as used for + // :ref:`HTTP access logging ` applies here, however + // unknown specifier values are replaced with the empty string instead of ``-``. + // + // .. note:: + // + // Formatter parsing is controlled by the runtime feature flag + // ``envoy.reloadable_features.enable_formatter_for_ratelimit_action_descriptor_value`` + // (disabled by default). + // + // When enabled: The format string can contain multiple valid substitution + // fields. If multiple substitution fields are present, their results will be concatenated + // to form the final descriptor value. If it contains no substitution fields, the value + // will be used as is. If the final concatenated result is empty and ``default_value`` is set, + // the ``default_value`` will be used. If ``default_value`` is not set and the result is + // empty, this descriptor will be skipped and not included in the rate limit call. + // + // When disabled (default): The descriptor_value is used as a literal string without any formatter + // parsing or substitution. + // + // For example, ``static_value`` will be used as is since there are no substitution fields. + // ``%REQ(:method)%`` will be replaced with the HTTP method, and + // ``%REQ(:method)%%REQ(:path)%`` will be replaced with the concatenation of the HTTP method and path. + // ``%CEL(request.headers['user-id'])%`` will use CEL to extract the user ID from request headers. + // string descriptor_value = 1 [(validate.rules).string = {min_len: 1}]; + // An optional value to use if the final concatenated ``descriptor_value`` result is empty. + // Only applicable when formatter parsing is enabled by the runtime feature flag + // ``envoy.reloadable_features.enable_formatter_for_ratelimit_action_descriptor_value`` (disabled by default). + string default_value = 3; + // An optional key to use in the descriptor entry. If not set it defaults // to 'generic_key' as the descriptor key. string descriptor_key = 2; @@ -1984,16 +2276,51 @@ message RateLimit { // .. code-block:: cpp // // ("header_match", "") + // [#next-free-field: 6] message HeaderValueMatch { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.RateLimit.Action.HeaderValueMatch"; - // The key to use in the descriptor entry. Defaults to ``header_match``. - string descriptor_key = 4; - - // The value to use in the descriptor entry. + // Descriptor value of entry. + // + // The same :ref:`format specifier ` as used for + // :ref:`HTTP access logging ` applies here, however + // unknown specifier values are replaced with the empty string instead of ``-``. + // + // .. note:: + // + // Formatter parsing is controlled by the runtime feature flag + // ``envoy.reloadable_features.enable_formatter_for_ratelimit_action_descriptor_value`` + // (disabled by default). + // + // When enabled: The format string can contain multiple valid substitution + // fields. If multiple substitution fields are present, their results will be concatenated + // to form the final descriptor value. If it contains no substitution fields, the value + // will be used as is. All substitution fields will be evaluated and their results + // concatenated. If the final concatenated result is empty and ``default_value`` is set, + // the ``default_value`` will be used. If ``default_value`` is not set and the result is + // empty, this descriptor will be skipped and not included in the rate limit call. + // + // When disabled (default): The descriptor_value is used as a literal string without any formatter + // parsing or substitution. + // + // For example, ``static_value`` will be used as is since there are no substitution fields. + // ``%REQ(:method)%`` will be replaced with the HTTP method, and + // ``%REQ(:method)%%REQ(:path)%`` will be replaced with the concatenation of the HTTP method and path. + // ``%CEL(request.headers['user-id'])%`` will use CEL to extract the user ID from request headers. + // string descriptor_value = 1 [(validate.rules).string = {min_len: 1}]; + // An optional value to use if the final concatenated ``descriptor_value`` result is empty. + // Only applicable when formatter parsing is enabled by the runtime feature flag + // ``envoy.reloadable_features.enable_formatter_for_ratelimit_action_descriptor_value`` (disabled by default). + string default_value = 5; + + // The key to use in the descriptor entry. + // + // Defaults to ``header_match``. + string descriptor_key = 4; + // If set to true, the action will append a descriptor entry when the // request matches the headers. If set to false, the action will append a // descriptor entry when the request does not match the headers. The @@ -2001,7 +2328,7 @@ message RateLimit { google.protobuf.BoolValue expect_match = 2; // Specifies a set of headers that the rate limit action should match - // on. The action will check the request’s headers against all the + // on. The action will check the request's headers against all the // specified headers in the config. A match will happen if all the // headers in the config are present in the request with the same values // (or based on presence if the value field is not in the config). @@ -2060,9 +2387,19 @@ message RateLimit { // Source of metadata Source source = 4 [(validate.rules).enum = {defined_only: true}]; - // If set to true, Envoy skips the descriptor while calling rate limiting service - // when ``metadata_key`` is empty and ``default_value`` is not set. By default it skips calling the - // rate limiting service in that case. + // Controls the behavior when the specified ``metadata_key`` is empty and ``default_value`` is not set. + // + // If set to ``false`` (default): + // + // * Envoy does **NOT** call the rate limiting service for this descriptor. + // * Useful if the metadata is optional and you prefer to skip rate limiting when it's absent. + // + // If set to ``true``: + // + // * Envoy calls the rate limiting service but omits this descriptor if the ``metadata_key`` is empty and + // ``default_value`` is missing. + // * Useful if you want Envoy to enforce rate limiting even when the metadata is not present. + // bool skip_if_absent = 5; } @@ -2071,13 +2408,48 @@ message RateLimit { // .. code-block:: cpp // // ("query_match", "") + // [#next-free-field: 6] message QueryParameterValueMatch { - // The key to use in the descriptor entry. Defaults to ``query_match``. - string descriptor_key = 4; - - // The value to use in the descriptor entry. + // Descriptor value of entry. + // + // The same :ref:`format specifier ` as used for + // :ref:`HTTP access logging ` applies here, however + // unknown specifier values are replaced with the empty string instead of ``-``. + // + // .. note:: + // + // Formatter parsing is controlled by the runtime feature flag + // ``envoy.reloadable_features.enable_formatter_for_ratelimit_action_descriptor_value`` + // (disabled by default). + // + // When enabled: The format string can contain multiple valid substitution + // fields. If multiple substitution fields are present, their results will be concatenated + // to form the final descriptor value. If it contains no substitution fields, the value + // will be used as is. All substitution fields will be evaluated and their results + // concatenated. If the final concatenated result is empty and ``default_value`` is set, + // the ``default_value`` will be used. If ``default_value`` is not set and the result is + // empty, this descriptor will be skipped and not included in the rate limit call. + // + // When disabled (default): The descriptor_value is used as a literal string without any formatter + // parsing or substitution. + // + // For example, ``static_value`` will be used as is since there are no substitution fields. + // ``%REQ(:method)%`` will be replaced with the HTTP method, and + // ``%REQ(:method)%%REQ(:path)%`` will be replaced with the concatenation of the HTTP method and path. + // ``%CEL(request.headers['user-id'])%`` will use CEL to extract the user ID from request headers. + // string descriptor_value = 1 [(validate.rules).string = {min_len: 1}]; + // An optional value to use if the final concatenated ``descriptor_value`` result is empty. + // Only applicable when formatter parsing is enabled by the runtime feature flag + // ``envoy.reloadable_features.enable_formatter_for_ratelimit_action_descriptor_value`` (disabled by default). + string default_value = 5; + + // The key to use in the descriptor entry. + // + // Defaults to ``query_match``. + string descriptor_key = 4; + // If set to true, the action will append a descriptor entry when the // request matches the headers. If set to false, the action will append a // descriptor entry when the request does not match the headers. The @@ -2085,7 +2457,7 @@ message RateLimit { google.protobuf.BoolValue expect_match = 2; // Specifies a set of query parameters that the rate limit action should match - // on. The action will check the request’s query parameters against all the + // on. The action will check the request's query parameters against all the // specified query parameters in the config. A match will happen if all the // query parameters in the config are present in the request with the same values // (or based on presence if the value field is not in the config). @@ -2105,6 +2477,9 @@ message RateLimit { // Rate limit on request headers. RequestHeaders request_headers = 3; + // Rate limit on query parameters. + QueryParameters query_parameters = 12; + // Rate limit on remote address. RemoteAddress remote_address = 4; @@ -2163,6 +2538,33 @@ message RateLimit { } } + message HitsAddend { + // Fixed number of hits to add to the rate limit descriptor. + // + // One of the ``number`` or ``format`` fields should be set but not both. + google.protobuf.UInt64Value number = 1 [(validate.rules).uint64 = {lte: 1000000000}]; + + // Substitution format string to extract the number of hits to add to the rate limit descriptor. + // The same :ref:`format specifier ` as used for + // :ref:`HTTP access logging ` applies here. + // + // .. note:: + // + // The format string must contains only single valid substitution field. If the format string + // not meets the requirement, the configuration will be rejected. + // + // The substitution field should generates a non-negative number or string representation of + // a non-negative number. The value of the non-negative number should be less than or equal + // to 1000000000 like the ``number`` field. If the output of the substitution field not meet + // the requirement, this will be treated as an error and the current descriptor will be ignored. + // + // For example, the ``%BYTES_RECEIVED%`` format string will be replaced with the number of bytes + // received in the request. + // + // One of the ``number`` or ``format`` fields should be set but not both. + string format = 2 [(validate.rules).string = {prefix: "%" suffix: "%" ignore_empty: true}]; + } + // Refers to the stage set in the filter. The rate limit configuration only // applies to filters with the same stage number. The default stage number is // 0. @@ -2170,9 +2572,19 @@ message RateLimit { // .. note:: // // The filter supports a range of 0 - 10 inclusively for stage numbers. + // + // .. note:: + // This is not supported if the rate limit action is configured in the ``typed_per_filter_config`` like + // :ref:`VirtualHost.typed_per_filter_config` or + // :ref:`Route.typed_per_filter_config`, etc. google.protobuf.UInt32Value stage = 1 [(validate.rules).uint32 = {lte: 10}]; // The key to be set in runtime to disable this rate limit configuration. + // + // .. note:: + // This is not supported if the rate limit action is configured in the ``typed_per_filter_config`` like + // :ref:`VirtualHost.typed_per_filter_config` or + // :ref:`Route.typed_per_filter_config`, etc. string disable_key = 2; // A list of actions that are to be applied for this rate limit configuration. @@ -2187,7 +2599,38 @@ message RateLimit { // rate limit configuration. If the override value is invalid or cannot be resolved // from metadata, no override is provided. See :ref:`rate limit override // ` for more information. + // + // .. note:: + // This is not supported if the rate limit action is configured in the ``typed_per_filter_config`` like + // :ref:`VirtualHost.typed_per_filter_config` or + // :ref:`Route.typed_per_filter_config`, etc. Override limit = 4; + + // An optional hits addend to be appended to the descriptor produced by this rate limit + // configuration. + // + // .. note:: + // This is only supported if the rate limit action is configured in the ``typed_per_filter_config`` like + // :ref:`VirtualHost.typed_per_filter_config` or + // :ref:`Route.typed_per_filter_config`, etc. + HitsAddend hits_addend = 5; + + // If true, the rate limit request will be applied when the stream completes. The default value is false. + // This is useful when the rate limit budget needs to reflect the response context that is not available + // on the request path. + // + // For example, let's say the upstream service calculates the usage statistics and returns them in the response body + // and we want to utilize these numbers to apply the rate limit action for the subsequent requests. + // Combined with another filter that can set the desired addend based on the response (e.g. Lua filter), + // this can be used to subtract the usage statistics from the rate limit budget. + // + // A rate limit applied on the stream completion is "fire-and-forget" by nature, and rate limit is not enforced by this config. + // In other words, the current request won't be blocked when this is true, but the budget will be updated for the subsequent + // requests based on the action with this field set to true. Users should ensure that the rate limit is enforced by the actions + // applied on the request path, i.e. the ones with this field set to false. + // + // Currently, this is only supported by the HTTP global rate filter. + bool apply_on_stream_done = 6; } // .. attention:: @@ -2231,14 +2674,20 @@ message HeaderMatcher { // Specifies how the header match will be performed to route the request. oneof header_match_specifier { // If specified, header match will be performed based on the value of the header. - // This field is deprecated. Please use :ref:`string_match `. + // + // .. attention:: + // + // This field is deprecated. Please use :ref:`string_match `. string exact_match = 4 [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // If specified, this regex string is a regular expression rule which implies the entire request // header value must match the regex. The rule will not match if only a subsequence of the // request header value matches the regex. - // This field is deprecated. Please use :ref:`string_match `. + // + // .. attention:: + // + // This field is deprecated. Please use :ref:`string_match `. type.matcher.v3.RegexMatcher safe_regex_match = 11 [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; @@ -2260,8 +2709,14 @@ message HeaderMatcher { bool present_match = 7; // If specified, header match will be performed based on the prefix of the header value. - // Note: empty prefix is not allowed, please use present_match instead. - // This field is deprecated. Please use :ref:`string_match `. + // + // .. note:: + // + // Empty prefix is not allowed. Please use ``present_match`` instead. + // + // .. attention:: + // + // This field is deprecated. Please use :ref:`string_match `. // // Examples: // @@ -2273,8 +2728,14 @@ message HeaderMatcher { ]; // If specified, header match will be performed based on the suffix of the header value. - // Note: empty suffix is not allowed, please use present_match instead. - // This field is deprecated. Please use :ref:`string_match `. + // + // .. note:: + // + // Empty suffix is not allowed. Please use ``present_match`` instead. + // + // .. attention:: + // + // This field is deprecated. Please use :ref:`string_match `. // // Examples: // @@ -2287,8 +2748,14 @@ message HeaderMatcher { // If specified, header match will be performed based on whether the header value contains // the given value or not. - // Note: empty contains match is not allowed, please use present_match instead. - // This field is deprecated. Please use :ref:`string_match `. + // + // .. note:: + // + // Empty contains match is not allowed. Please use ``present_match`` instead. + // + // .. attention:: + // + // This field is deprecated. Please use :ref:`string_match `. // // Examples: // @@ -2303,7 +2770,9 @@ message HeaderMatcher { type.matcher.v3.StringMatcher string_match = 13; } - // If specified, the match result will be inverted before checking. Defaults to false. + // If specified, the match result will be inverted before checking. + // + // Defaults to ``false``. // // Examples: // @@ -2312,7 +2781,9 @@ message HeaderMatcher { bool invert_match = 8; // If specified, for any header match rule, if the header match rule specified header - // does not exist, this header value will be treated as empty. Defaults to false. + // does not exist, this header value will be treated as empty. + // + // Defaults to ``false``. // // Examples: // @@ -2364,6 +2835,20 @@ message QueryParameterMatcher { } } +// Cookie matching inspects individual name/value pairs parsed from the ``Cookie`` header. +message CookieMatcher { + // Specifies the cookie name to evaluate. + string name = 1 [(validate.rules).string = {min_len: 1 max_bytes: 1024}]; + + // Match the cookie value using :ref:`StringMatcher + // ` semantics. + type.matcher.v3.StringMatcher string_match = 2 [(validate.rules).message = {required: true}]; + + // Invert the match result. If the cookie is not present, the match result is false, so + // ``invert_match`` will cause the matcher to succeed when the cookie is absent. + bool invert_match = 3; +} + // HTTP Internal Redirect :ref:`architecture overview `. // [#next-free-field: 6] message InternalRedirectPolicy { @@ -2389,7 +2874,7 @@ message InternalRedirectPolicy { repeated core.v3.TypedExtensionConfig predicates = 3; // Allow internal redirect to follow a target URI with a different scheme than the value of - // x-forwarded-proto. The default is false. + // x-forwarded-proto. The default is ``false``. bool allow_cross_scheme_redirect = 4; // Specifies a list of headers, by name, to copy from the internal redirect into the subsequent @@ -2429,6 +2914,5 @@ message FilterConfig { // initial route will not be added back to the filter chain because the filter chain is already // created and it is too late to change the chain. // - // This field only make sense for the downstream HTTP filters for now. bool disabled = 3; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/datadog.proto b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/datadog.proto index bed6c8eec36..5359ec74267 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/datadog.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/datadog.proto @@ -2,6 +2,8 @@ syntax = "proto3"; package envoy.config.trace.v3; +import "google/protobuf/duration.proto"; + import "udpa/annotations/migrate.proto"; import "udpa/annotations/status.proto"; import "udpa/annotations/versioning.proto"; @@ -16,6 +18,13 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Datadog tracer] +// Configuration for the Remote Configuration feature. +message DatadogRemoteConfig { + // Frequency at which new configuration updates are queried. + // If no value is provided, the default value is delegated to the Datadog tracing library. + google.protobuf.Duration polling_interval = 1; +} + // Configuration for the Datadog tracer. // [#extension: envoy.tracers.datadog] message DatadogConfig { @@ -31,4 +40,11 @@ message DatadogConfig { // Optional hostname to use when sending spans to the collector_cluster. Useful for collectors // that require a specific hostname. Defaults to :ref:`collector_cluster ` above. string collector_hostname = 3; + + // Enables and configures remote configuration. + // Remote Configuration allows to configure the tracer from Datadog's user interface. + // This feature can drastically increase the number of connections to the Datadog Agent. + // Each tracer regularly polls for configuration updates, and the number of tracers is the product + // of the number of listeners and worker threads. + DatadogRemoteConfig remote_config = 4; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/dynamic_ot.proto b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/dynamic_ot.proto index 35971f30dfb..40fe8526a5f 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/dynamic_ot.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/dynamic_ot.proto @@ -20,10 +20,10 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Dynamically loadable OpenTracing tracer] -// DynamicOtConfig is used to dynamically load a tracer from a shared library +// DynamicOtConfig was used to dynamically load a tracer from a shared library // that implements the `OpenTracing dynamic loading API // `_. -// [#extension: envoy.tracers.dynamic_ot] +// [#not-implemented-hide:] message DynamicOtConfig { option (udpa.annotations.versioning).previous_message_type = "envoy.config.trace.v2.DynamicOtConfig"; @@ -33,11 +33,15 @@ message DynamicOtConfig { string library = 1 [ deprecated = true, (validate.rules).string = {min_len: 1}, - (envoy.annotations.deprecated_at_minor_version) = "3.0" + (envoy.annotations.deprecated_at_minor_version) = "3.0", + (envoy.annotations.disallowed_by_default) = true ]; // The configuration to use when creating a tracer from the given dynamic // library. - google.protobuf.Struct config = 2 - [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; + google.protobuf.Struct config = 2 [ + deprecated = true, + (envoy.annotations.deprecated_at_minor_version) = "3.0", + (envoy.annotations.disallowed_by_default) = true + ]; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/opencensus.proto b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/opencensus.proto deleted file mode 100644 index 86a986a24e4..00000000000 --- a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/opencensus.proto +++ /dev/null @@ -1,117 +0,0 @@ -syntax = "proto3"; - -package envoy.config.trace.v3; - -import "envoy/config/core/v3/grpc_service.proto"; - -import "opencensus/proto/trace/v1/trace_config.proto"; - -import "envoy/annotations/deprecation.proto"; -import "udpa/annotations/migrate.proto"; -import "udpa/annotations/status.proto"; -import "udpa/annotations/versioning.proto"; - -option java_package = "io.envoyproxy.envoy.config.trace.v3"; -option java_outer_classname = "OpencensusProto"; -option java_multiple_files = true; -option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/trace/v3;tracev3"; -option (udpa.annotations.file_migrate).move_to_package = - "envoy.extensions.tracers.opencensus.v4alpha"; -option (udpa.annotations.file_status).package_version_status = ACTIVE; - -// [#protodoc-title: OpenCensus tracer] - -// Configuration for the OpenCensus tracer. -// [#next-free-field: 15] -// [#extension: envoy.tracers.opencensus] -message OpenCensusConfig { - option (udpa.annotations.versioning).previous_message_type = - "envoy.config.trace.v2.OpenCensusConfig"; - - enum TraceContext { - // No-op default, no trace context is utilized. - NONE = 0; - - // W3C Trace-Context format "traceparent:" header. - TRACE_CONTEXT = 1; - - // Binary "grpc-trace-bin:" header. - GRPC_TRACE_BIN = 2; - - // "X-Cloud-Trace-Context:" header. - CLOUD_TRACE_CONTEXT = 3; - - // X-B3-* headers. - B3 = 4; - } - - reserved 7; - - // Configures tracing, e.g. the sampler, max number of annotations, etc. - opencensus.proto.trace.v1.TraceConfig trace_config = 1 - [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; - - // Enables the stdout exporter if set to true. This is intended for debugging - // purposes. - bool stdout_exporter_enabled = 2 - [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; - - // Enables the Stackdriver exporter if set to true. The project_id must also - // be set. - bool stackdriver_exporter_enabled = 3 - [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; - - // The Cloud project_id to use for Stackdriver tracing. - string stackdriver_project_id = 4 - [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; - - // (optional) By default, the Stackdriver exporter will connect to production - // Stackdriver. If stackdriver_address is non-empty, it will instead connect - // to this address, which is in the gRPC format: - // https://github.com/grpc/grpc/blob/master/doc/naming.md - string stackdriver_address = 10 - [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; - - // (optional) The gRPC server that hosts Stackdriver tracing service. Only - // Google gRPC is supported. If :ref:`target_uri ` - // is not provided, the default production Stackdriver address will be used. - core.v3.GrpcService stackdriver_grpc_service = 13 - [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; - - // Enables the Zipkin exporter if set to true. The url and service name must - // also be set. This is deprecated, prefer to use Envoy's :ref:`native Zipkin - // tracer `. - bool zipkin_exporter_enabled = 5 - [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; - - // The URL to Zipkin, e.g. "http://127.0.0.1:9411/api/v2/spans". This is - // deprecated, prefer to use Envoy's :ref:`native Zipkin tracer - // `. - string zipkin_url = 6 - [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; - - // Enables the OpenCensus Agent exporter if set to true. The ocagent_address or - // ocagent_grpc_service must also be set. - bool ocagent_exporter_enabled = 11 - [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; - - // The address of the OpenCensus Agent, if its exporter is enabled, in gRPC - // format: https://github.com/grpc/grpc/blob/master/doc/naming.md - // [#comment:TODO: deprecate this field] - string ocagent_address = 12 - [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; - - // (optional) The gRPC server hosted by the OpenCensus Agent. Only Google gRPC is supported. - // This is only used if the ocagent_address is left empty. - core.v3.GrpcService ocagent_grpc_service = 14 - [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; - - // List of incoming trace context headers we will accept. First one found - // wins. - repeated TraceContext incoming_trace_context = 8 - [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; - - // List of outgoing trace context headers we will produce. - repeated TraceContext outgoing_trace_context = 9 - [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; -} diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/opentelemetry.proto b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/opentelemetry.proto index 59028326f22..5260d9bd6af 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/opentelemetry.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/opentelemetry.proto @@ -6,6 +6,8 @@ import "envoy/config/core/v3/extension.proto"; import "envoy/config/core/v3/grpc_service.proto"; import "envoy/config/core/v3/http_service.proto"; +import "google/protobuf/wrappers.proto"; + import "udpa/annotations/migrate.proto"; import "udpa/annotations/status.proto"; @@ -19,7 +21,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // Configuration for the OpenTelemetry tracer. // [#extension: envoy.tracers.opentelemetry] -// [#next-free-field: 6] +// [#next-free-field: 7] message OpenTelemetryConfig { // The upstream gRPC cluster that will receive OTLP traces. // Note that the tracer drops traces if the server does not read data fast enough. @@ -57,4 +59,9 @@ message OpenTelemetryConfig { // See: `OpenTelemetry sampler specification `_ // [#extension-category: envoy.tracers.opentelemetry.samplers] core.v3.TypedExtensionConfig sampler = 5; + + // Envoy caches the span in memory when the OpenTelemetry backend service is temporarily unavailable. + // This field specifies the maximum number of spans that can be cached. If not specified, the + // default is 1024. + google.protobuf.UInt32Value max_cache_size = 6; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/trace.proto b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/trace.proto deleted file mode 100644 index 8ca43718ca3..00000000000 --- a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/trace.proto +++ /dev/null @@ -1,17 +0,0 @@ -syntax = "proto3"; - -package envoy.config.trace.v3; - -import public "envoy/config/trace/v3/datadog.proto"; -import public "envoy/config/trace/v3/dynamic_ot.proto"; -import public "envoy/config/trace/v3/http_tracer.proto"; -import public "envoy/config/trace/v3/lightstep.proto"; -import public "envoy/config/trace/v3/opencensus.proto"; -import public "envoy/config/trace/v3/opentelemetry.proto"; -import public "envoy/config/trace/v3/service.proto"; -import public "envoy/config/trace/v3/zipkin.proto"; - -option java_package = "io.envoyproxy.envoy.config.trace.v3"; -option java_outer_classname = "TraceProto"; -option java_multiple_files = true; -option go_package = "github.com/envoyproxy/go-control-plane/envoy/config/trace/v3;tracev3"; diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/zipkin.proto b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/zipkin.proto index a9aefef0c6d..2364983efc5 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/zipkin.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/trace/v3/zipkin.proto @@ -2,13 +2,14 @@ syntax = "proto3"; package envoy.config.trace.v3; +import "envoy/config/core/v3/http_service.proto"; + import "google/protobuf/wrappers.proto"; import "envoy/annotations/deprecation.proto"; import "udpa/annotations/migrate.proto"; import "udpa/annotations/status.proto"; import "udpa/annotations/versioning.proto"; -import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.trace.v3"; option java_outer_classname = "ZipkinProto"; @@ -21,10 +22,22 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // Configuration for the Zipkin tracer. // [#extension: envoy.tracers.zipkin] -// [#next-free-field: 8] +// [#next-free-field: 10] message ZipkinConfig { option (udpa.annotations.versioning).previous_message_type = "envoy.config.trace.v2.ZipkinConfig"; + // Available trace context options for handling different trace header formats. + enum TraceContextOption { + // Use B3 headers only (default behavior). + USE_B3 = 0; + + // Enable B3 and W3C dual header support: + // - For downstream: Extract from B3 headers first, fallback to W3C traceparent if B3 is unavailable. + // - For upstream: Inject both B3 and W3C traceparent headers. + // When this option is NOT set, only B3 headers are used for both extraction and injection. + USE_B3_WITH_W3C_PROPAGATION = 1; + } + // Available Zipkin collector endpoint versions. enum CollectorEndpointVersion { // Zipkin API v1, JSON over HTTP. @@ -48,11 +61,23 @@ message ZipkinConfig { } // The cluster manager cluster that hosts the Zipkin collectors. - string collector_cluster = 1 [(validate.rules).string = {min_len: 1}]; + // + // .. note:: + // This field will be deprecated in future releases in favor of + // :ref:`collector_service `. + // + // Either this field or ``collector_service`` must be specified. + string collector_cluster = 1; // The API endpoint of the Zipkin service where the spans will be sent. When // using a standard Zipkin installation. - string collector_endpoint = 2 [(validate.rules).string = {min_len: 1}]; + // + // .. note:: + // This field will be deprecated in future releases in favor of + // :ref:`collector_service `. + // + // Required when using ``collector_cluster``. + string collector_endpoint = 2; // Determines whether a 128bit trace id will be used when creating a new // trace instance. The default value is false, which will result in a 64 bit trace id being used. @@ -67,6 +92,10 @@ message ZipkinConfig { // Optional hostname to use when sending spans to the collector_cluster. Useful for collectors // that require a specific hostname. Defaults to :ref:`collector_cluster ` above. + // + // .. note:: + // This field will be deprecated in future releases in favor of + // :ref:`collector_service `. string collector_hostname = 6; // If this is set to true, then Envoy will be treated as an independent hop in trace chain. A complete span pair will be created for a single @@ -82,5 +111,66 @@ message ZipkinConfig { // If this is set to true, then the // :ref:`start_child_span of router ` // SHOULD be set to true also to ensure the correctness of trace chain. - bool split_spans_for_request = 7; + // + // Both this field and ``start_child_span`` are deprecated by the + // :ref:`spawn_upstream_span `. + // Please use that ``spawn_upstream_span`` field to control the span creation. + bool split_spans_for_request = 7 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; + + // Determines which trace context format to use for trace header extraction and propagation. + // This controls both downstream request header extraction and upstream request header injection. + // Here is the spec for W3C trace headers: https://www.w3.org/TR/trace-context/ + // The default value is USE_B3 to maintain backward compatibility. + TraceContextOption trace_context_option = 8; + + // HTTP service configuration for the Zipkin collector. + // When specified, this configuration takes precedence over the legacy fields: + // collector_cluster, collector_endpoint, and collector_hostname. + // This provides a complete HTTP service configuration including cluster, URI, timeout, and headers. + // If not specified, the legacy fields above will be used for backward compatibility. + // + // Required fields when using collector_service: + // + // * ``http_uri.cluster`` - Must be specified and non-empty + // * ``http_uri.uri`` - Must be specified and non-empty + // * ``http_uri.timeout`` - Optional + // + // Full URI Support with Automatic Parsing: + // + // The ``uri`` field supports both path-only and full URI formats: + // + // .. code-block:: yaml + // + // tracing: + // provider: + // name: envoy.tracers.zipkin + // typed_config: + // "@type": type.googleapis.com/envoy.config.trace.v3.ZipkinConfig + // collector_service: + // http_uri: + // # Full URI format - hostname and path are extracted automatically + // uri: "https://zipkin-collector.example.com/api/v2/spans" + // cluster: zipkin + // timeout: 5s + // request_headers_to_add: + // - header: + // key: "X-Custom-Token" + // value: "your-custom-token" + // - header: + // key: "X-Service-ID" + // value: "your-service-id" + // + // URI Parsing Behavior: + // + // * Full URI: ``"https://zipkin-collector.example.com/api/v2/spans"`` + // + // * Hostname: ``zipkin-collector.example.com`` (sets HTTP ``Host`` header) + // * Path: ``/api/v2/spans`` (sets HTTP request path) + // + // * Path only: ``"/api/v2/spans"`` + // + // * Hostname: Uses cluster name as fallback + // * Path: ``/api/v2/spans`` + core.v3.HttpService collector_service = 9; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/data/accesslog/v3/accesslog.proto b/xds/third_party/envoy/src/main/proto/envoy/data/accesslog/v3/accesslog.proto index a247c08df30..da029b7da2e 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/data/accesslog/v3/accesslog.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/data/accesslog/v3/accesslog.proto @@ -109,14 +109,16 @@ message AccessLogCommon { double sample_rate = 1 [(validate.rules).double = {lte: 1.0 gt: 0.0}]; // This field is the remote/origin address on which the request from the user was received. - // Note: This may not be the physical peer. E.g, if the remote address is inferred from for - // example the x-forwarder-for header, proxy protocol, etc. + // + // .. note:: + // This may not be the actual peer address. For example, it might be derived from headers like ``x-forwarded-for``, + // the proxy protocol, or similar sources. config.core.v3.Address downstream_remote_address = 2; // This field is the local/destination address on which the request from the user was received. config.core.v3.Address downstream_local_address = 3; - // If the connection is secure,S this field will contain TLS properties. + // If the connection is secure, this field will contain TLS properties. TLSProperties tls_properties = 4; // The time that Envoy started servicing this request. This is effectively the time that the first @@ -128,7 +130,7 @@ message AccessLogCommon { google.protobuf.Duration time_to_last_rx_byte = 6; // Interval between the first downstream byte received and the first upstream byte sent. There may - // by considerable delta between ``time_to_last_rx_byte`` and this value due to filters. + // be considerable delta between ``time_to_last_rx_byte`` and this value due to filters. // Additionally, the same caveats apply as documented in ``time_to_last_downstream_tx_byte`` about // not accounting for kernel socket buffer time, etc. google.protobuf.Duration time_to_first_upstream_tx_byte = 7; @@ -187,7 +189,7 @@ message AccessLogCommon { // If upstream connection failed due to transport socket (e.g. TLS handshake), provides the // failure reason from the transport socket. The format of this field depends on the configured // upstream transport socket. Common TLS failures are in - // :ref:`TLS trouble shooting `. + // :ref:`TLS troubleshooting `. string upstream_transport_failure_reason = 18; // The name of the route @@ -204,7 +206,7 @@ message AccessLogCommon { map filter_state_objects = 21; // A list of custom tags, which annotate logs with additional information. - // To configure this value, users should configure + // To configure this value, see the documentation for // :ref:`custom_tags `. map custom_tags = 22; @@ -225,40 +227,41 @@ message AccessLogCommon { // This could be any format string that could be used to identify one stream. string stream_id = 26; - // If this log entry is final log entry that flushed after the stream completed or - // intermediate log entry that flushed periodically during the stream. - // There may be multiple intermediate log entries and only one final log entry for each - // long-live stream (TCP connection, long-live HTTP2 stream). - // And if it is necessary, unique ID or identifier can be added to the log entry - // :ref:`stream_id ` to - // correlate all these intermediate log entries and final log entry. + // Indicates whether this log entry is the final entry (flushed after the stream completed) or an intermediate entry + // (flushed periodically during the stream). + // + // For long-lived streams (e.g., TCP connections or long-lived HTTP/2 streams), there may be multiple intermediate + // entries and only one final entry. + // + // If needed, a unique identifier (see :ref:`stream_id `) + // can be used to correlate all intermediate and final log entries for the same stream. // // .. attention:: // - // This field is deprecated in favor of ``access_log_type`` for better indication of the - // type of the access log record. + // This field is deprecated in favor of ``access_log_type``, which provides a clearer indication of the log entry + // type. bool intermediate_log_entry = 27 [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // If downstream connection in listener failed due to transport socket (e.g. TLS handshake), provides the // failure reason from the transport socket. The format of this field depends on the configured downstream - // transport socket. Common TLS failures are in :ref:`TLS trouble shooting `. + // transport socket. Common TLS failures are in :ref:`TLS troubleshooting `. string downstream_transport_failure_reason = 28; // For HTTP: Total number of bytes sent to the downstream by the http stream. - // For TCP: Total number of bytes sent to the downstream by the tcp proxy. + // For TCP: Total number of bytes sent to the downstream by the :ref:`TCP Proxy `. uint64 downstream_wire_bytes_sent = 29; // For HTTP: Total number of bytes received from the downstream by the http stream. Envoy over counts sizes of received HTTP/1.1 pipelined requests by adding up bytes of requests in the pipeline to the one currently being processed. - // For TCP: Total number of bytes received from the downstream by the tcp proxy. + // For TCP: Total number of bytes received from the downstream by the :ref:`TCP Proxy `. uint64 downstream_wire_bytes_received = 30; // For HTTP: Total number of bytes sent to the upstream by the http stream. This value accumulates during upstream retries. - // For TCP: Total number of bytes sent to the upstream by the tcp proxy. + // For TCP: Total number of bytes sent to the upstream by the :ref:`TCP Proxy `. uint64 upstream_wire_bytes_sent = 31; // For HTTP: Total number of bytes received from the upstream by the http stream. - // For TCP: Total number of bytes sent to the upstream by the tcp proxy. + // For TCP: Total number of bytes sent to the upstream by the :ref:`TCP Proxy `. uint64 upstream_wire_bytes_received = 32; // The type of the access log, which indicates when the log was recorded. @@ -271,7 +274,7 @@ message AccessLogCommon { } // Flags indicating occurrences during request/response processing. -// [#next-free-field: 28] +// [#next-free-field: 29] message ResponseFlags { option (udpa.annotations.versioning).previous_message_type = "envoy.data.accesslog.v2.ResponseFlags"; @@ -297,7 +300,7 @@ message ResponseFlags { // Indicates there was no healthy upstream. bool no_healthy_upstream = 2; - // Indicates an there was an upstream request timeout. + // Indicates there was an upstream request timeout. bool upstream_request_timeout = 3; // Indicates local codec level reset was sent on the stream. @@ -358,7 +361,7 @@ message ResponseFlags { // Indicates that a filter configuration is not available. bool no_filter_config_found = 22; - // Indicates that request or connection exceeded the downstream connection duration. + // Indicates that the request or connection exceeded the downstream connection duration. bool duration_timeout = 23; // Indicates there was an HTTP protocol error in the upstream response. @@ -372,6 +375,9 @@ message ResponseFlags { // Indicates a DNS resolution failed. bool dns_resolution_failure = 27; + + // Indicates a downstream remote codec level reset was received on the stream + bool downstream_remote_reset = 28; } // Properties of a negotiated TLS connection. @@ -477,7 +483,7 @@ message HTTPRequestProperties { // do not already have a request ID. string request_id = 9; - // Value of the ``X-Envoy-Original-Path`` request header. + // Value of the ``x-envoy-original-path`` request header. string original_path = 10; // Size of the HTTP request headers in bytes. diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/clusters/aggregate/v3/cluster.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/clusters/aggregate/v3/cluster.proto index 4f44ac9cd5c..d23d767f73b 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/clusters/aggregate/v3/cluster.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/clusters/aggregate/v3/cluster.proto @@ -2,6 +2,8 @@ syntax = "proto3"; package envoy.extensions.clusters.aggregate.v3; +import "envoy/config/core/v3/config_source.proto"; + import "udpa/annotations/status.proto"; import "udpa/annotations/versioning.proto"; import "validate/validate.proto"; @@ -25,3 +27,18 @@ message ClusterConfig { // appear in this list. repeated string clusters = 1 [(validate.rules).repeated = {min_items: 1}]; } + +// Configures an aggregate cluster whose +// :ref:`ClusterConfig ` +// is to be fetched from a separate xDS resource. +// [#extension: envoy.clusters.aggregate_resource] +// [#not-implemented-hide:] +message AggregateClusterResource { + // Configuration source specifier for the ClusterConfig resource. + // Only the aggregated protocol variants are supported; if configured + // otherwise, the cluster resource will be NACKed. + config.core.v3.ConfigSource config_source = 1 [(validate.rules).message = {required: true}]; + + // The name of the ClusterConfig resource to subscribe to. + string resource_name = 2 [(validate.rules).string = {min_len: 1}]; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/common/matching/v3/extension_matcher.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/common/matching/v3/extension_matcher.proto new file mode 100644 index 00000000000..817cd27a37a --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/common/matching/v3/extension_matcher.proto @@ -0,0 +1,42 @@ +syntax = "proto3"; + +package envoy.extensions.common.matching.v3; + +import "envoy/config/common/matcher/v3/matcher.proto"; +import "envoy/config/core/v3/extension.proto"; + +import "xds/type/matcher/v3/matcher.proto"; + +import "envoy/annotations/deprecation.proto"; +import "udpa/annotations/status.proto"; +import "validate/validate.proto"; + +option java_package = "io.envoyproxy.envoy.extensions.common.matching.v3"; +option java_outer_classname = "ExtensionMatcherProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/common/matching/v3;matchingv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: Extension matcher] + +// Wrapper around an existing extension that provides an associated matcher. This allows +// decorating an existing extension with a matcher, which can be used to match against +// relevant protocol data. +message ExtensionWithMatcher { + // The associated matcher. This is deprecated in favor of xds_matcher. + config.common.matcher.v3.Matcher matcher = 1 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; + + // The associated matcher. + xds.type.matcher.v3.Matcher xds_matcher = 3; + + // The underlying extension config. + config.core.v3.TypedExtensionConfig extension_config = 2 + [(validate.rules).message = {required: true}]; +} + +// Extra settings on a per virtualhost/route/weighted-cluster level. +message ExtensionWithMatcherPerRoute { + // Matcher override. + xds.type.matcher.v3.Matcher xds_matcher = 1; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/composite/v3/composite.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/composite/v3/composite.proto new file mode 100644 index 00000000000..1ab6c5eb1ef --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/composite/v3/composite.proto @@ -0,0 +1,106 @@ +syntax = "proto3"; + +package envoy.extensions.filters.http.composite.v3; + +import "envoy/config/core/v3/base.proto"; +import "envoy/config/core/v3/config_source.proto"; +import "envoy/config/core/v3/extension.proto"; + +import "udpa/annotations/migrate.proto"; +import "udpa/annotations/status.proto"; +import "validate/validate.proto"; + +option java_package = "io.envoyproxy.envoy.extensions.filters.http.composite.v3"; +option java_outer_classname = "CompositeProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/composite/v3;compositev3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: Composite] +// Composite Filter :ref:`configuration overview `. +// [#extension: envoy.filters.http.composite] + +// :ref:`Composite filter ` config. The composite filter config +// allows delegating filter handling to another filter as determined by matching on the request +// headers. This makes it possible to use different filters or filter configurations based on the +// incoming request. +// +// This is intended to be used with +// :ref:`ExtensionWithMatcher ` +// where a match tree is specified that indicates (via +// :ref:`ExecuteFilterAction `) +// which filter configuration to create and delegate to. +message Composite { + // Named filter chain definitions that can be referenced from + // :ref:`ExecuteFilterAction.filter_chain_name + // `. + // The filter chains are compiled at configuration time and can be referenced by name. + // This is useful when the same filter chain needs to be applied across many routes, + // as it avoids duplicating the filter chain configuration. + map named_filter_chains = 1; +} + +// A list of filter configurations to be called in order. Note that this can be used as the type +// inside of an ECDS :ref:`TypedExtensionConfig +// ` extension, which allows a chain of +// filters to be configured dynamically. In that case, the types of all filters in the chain must +// be present in the :ref:`ExtensionConfigSource.type_urls +// ` field. +message FilterChainConfiguration { + repeated config.core.v3.TypedExtensionConfig typed_config = 1; +} + +// Configuration for an extension configuration discovery service with name. +message DynamicConfig { + // The name of the extension configuration. It also serves as a resource name in ExtensionConfigDS. + // The resource type in the ``DiscoveryRequest`` will be :ref:`TypedExtensionConfig + // `. + string name = 1 [(validate.rules).string = {min_len: 1}]; + + // Configuration source specifier for an extension configuration discovery + // service. In case of a failure and without the default configuration, + // 500(Internal Server Error) will be returned. + config.core.v3.ExtensionConfigSource config_discovery = 2; +} + +// Composite match action (see :ref:`matching docs ` for more info on match actions). +// This specifies the filter configuration of the filter that the composite filter should delegate filter interactions to. +// [#next-free-field: 6] +message ExecuteFilterAction { + // Filter specific configuration which depends on the filter being + // instantiated. See the supported filters for further documentation. + // Only one of ``typed_config``, ``dynamic_config``, ``filter_chain``, or ``filter_chain_name`` + // can be set. + // [#extension-category: envoy.filters.http] + config.core.v3.TypedExtensionConfig typed_config = 1 + [(udpa.annotations.field_migrate).oneof_promotion = "config_type"]; + + // Dynamic configuration of filter obtained via extension configuration discovery service. + // Only one of ``typed_config``, ``dynamic_config``, ``filter_chain``, or ``filter_chain_name`` + // can be set. + DynamicConfig dynamic_config = 2 + [(udpa.annotations.field_migrate).oneof_promotion = "config_type"]; + + // An inlined list of filter configurations. The specified filters will be executed in order. + // Only one of ``typed_config``, ``dynamic_config``, ``filter_chain``, or ``filter_chain_name`` + // can be set. + FilterChainConfiguration filter_chain = 4; + + // The name of a filter chain defined in + // :ref:`Composite.named_filter_chains + // `. + // At runtime, if the named filter chain is not found in the Composite filter's configuration, + // no filter will be applied for this match (the action is silently skipped). + // Only one of ``typed_config``, ``dynamic_config``, ``filter_chain``, or ``filter_chain_name`` + // can be set. + string filter_chain_name = 5; + + // Probability of the action execution. If not specified, this is 100%. + // This allows sampling behavior for the configured actions. + // For example, if + // :ref:`default_value ` + // under the ``sample_percent`` is configured with 30%, a dice roll with that + // probability is done. The underline action will only be executed if the + // dice roll returns positive. Otherwise, the action is skipped. + config.core.v3.RuntimeFractionalPercent sample_percent = 3; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/ext_authz/v3/ext_authz.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/ext_authz/v3/ext_authz.proto new file mode 100644 index 00000000000..7f70b70013b --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/ext_authz/v3/ext_authz.proto @@ -0,0 +1,602 @@ +syntax = "proto3"; + +package envoy.extensions.filters.http.ext_authz.v3; + +import "envoy/config/common/mutation_rules/v3/mutation_rules.proto"; +import "envoy/config/core/v3/base.proto"; +import "envoy/config/core/v3/config_source.proto"; +import "envoy/config/core/v3/grpc_service.proto"; +import "envoy/config/core/v3/http_uri.proto"; +import "envoy/type/matcher/v3/metadata.proto"; +import "envoy/type/matcher/v3/string.proto"; +import "envoy/type/v3/http_status.proto"; + +import "google/protobuf/struct.proto"; +import "google/protobuf/wrappers.proto"; + +import "envoy/annotations/deprecation.proto"; +import "udpa/annotations/sensitive.proto"; +import "udpa/annotations/status.proto"; +import "udpa/annotations/versioning.proto"; +import "validate/validate.proto"; + +option java_package = "io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3"; +option java_outer_classname = "ExtAuthzProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/ext_authz/v3;ext_authzv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: External Authorization] +// External Authorization :ref:`configuration overview `. +// [#extension: envoy.filters.http.ext_authz] + +// [#next-free-field: 32] +message ExtAuthz { + option (udpa.annotations.versioning).previous_message_type = + "envoy.config.filter.http.ext_authz.v3.ExtAuthz"; + + reserved 4; + + reserved "use_alpha"; + + // External authorization service configuration. + oneof services { + // gRPC service configuration (default timeout: 200ms). + config.core.v3.GrpcService grpc_service = 1; + + // HTTP service configuration (default timeout: 200ms). + HttpService http_service = 3; + } + + // API version for ext_authz transport protocol. This describes the ext_authz gRPC endpoint and + // version of messages used on the wire. + config.core.v3.ApiVersion transport_api_version = 12 + [(validate.rules).enum = {defined_only: true}]; + + // Changes the filter's behavior on errors: + // + // * When set to ``true``, the filter will ``accept`` the client request even if communication with + // the authorization service has failed, or if the authorization service has returned an HTTP 5xx + // error. + // + // * When set to ``false``, the filter will ``reject`` client requests and return ``Forbidden`` + // if communication with the authorization service has failed, or if the authorization service + // has returned an HTTP 5xx error. + // + // Errors can always be tracked in the :ref:`stats `. + // + // Defaults to ``false``. + bool failure_mode_allow = 2; + + // When ``failure_mode_allow`` and ``failure_mode_allow_header_add`` are both set to ``true``, + // ``x-envoy-auth-failure-mode-allowed: true`` will be added to request headers if the communication + // with the authorization service has failed, or if the authorization service has returned a + // HTTP 5xx error. + bool failure_mode_allow_header_add = 19; + + // Enables the filter to buffer the client request body and send it within the authorization request. + // The ``x-envoy-auth-partial-body: false|true`` metadata header will be added to the authorization + // request indicating whether the body data is partial. + BufferSettings with_request_body = 5; + + // Clears the route cache in order to allow the external authorization service to correctly affect + // routing decisions. The filter clears all cached routes when all of the following holds: + // + // * This field is set to ``true``. + // * The status returned from the authorization service is an HTTP 200 or gRPC 0. + // * At least one ``authorization response header`` is added to the client request, or is used to + // alter another client request header. + // + // Defaults to ``false``. + bool clear_route_cache = 6; + + // Sets the HTTP status that is returned to the client when the authorization server returns an error + // or cannot be reached. + // + // The default status is ``HTTP 403 Forbidden``. + type.v3.HttpStatus status_on_error = 7; + + // When set to ``true``, the filter will check the :ref:`ext_authz response + // ` for invalid header and + // query parameter mutations. If the response is invalid, the filter will send a local reply + // to the downstream request with status ``HTTP 500 Internal Server Error``. + // + // .. note:: + // Both ``headers_to_remove`` and ``query_parameters_to_remove`` are validated, but invalid elements in + // those fields should not affect any headers and thus will not cause the filter to send a local reply. + // + // When set to ``false``, any invalid mutations will be visible to the rest of Envoy and may cause + // unexpected behavior. + // + // If you are using ext_authz with an untrusted ext_authz server, you should set this to ``true``. + // + // Defaults to ``false``. + bool validate_mutations = 24; + + // Specifies a list of metadata namespaces whose values, if present, will be passed to the + // ext_authz service. The :ref:`filter_metadata ` + // is passed as an opaque ``protobuf::Struct``. + // + // .. note:: + // This field applies exclusively to the gRPC ext_authz service and has no effect on the HTTP service. + // + // For example, if the ``jwt_authn`` filter is used and :ref:`payload_in_metadata + // ` is set, + // then the following will pass the jwt payload to the authorization server. + // + // .. code-block:: yaml + // + // metadata_context_namespaces: + // - envoy.filters.http.jwt_authn + // + repeated string metadata_context_namespaces = 8; + + // Specifies a list of metadata namespaces whose values, if present, will be passed to the + // ext_authz service. :ref:`typed_filter_metadata ` + // is passed as a ``protobuf::Any``. + // + // .. note:: + // This field applies exclusively to the gRPC ext_authz service and has no effect on the HTTP service. + // + // This works similarly to ``metadata_context_namespaces`` but allows Envoy and the ext_authz server to share + // the protobuf message definition in order to perform safe parsing. + // + repeated string typed_metadata_context_namespaces = 16; + + // Specifies a list of route metadata namespaces whose values, if present, will be passed to the + // ext_authz service at :ref:`route_metadata_context ` in + // :ref:`CheckRequest `. + // :ref:`filter_metadata ` is passed as an opaque ``protobuf::Struct``. + repeated string route_metadata_context_namespaces = 21; + + // Specifies a list of route metadata namespaces whose values, if present, will be passed to the + // ext_authz service at :ref:`route_metadata_context ` in + // :ref:`CheckRequest `. + // :ref:`typed_filter_metadata ` is passed as a ``protobuf::Any``. + repeated string route_typed_metadata_context_namespaces = 22; + + // Specifies if the filter is enabled. + // + // If :ref:`runtime_key ` is specified, + // Envoy will lookup the runtime key to get the percentage of requests to filter. + // + // If this field is not specified, the filter will be enabled for all requests. + config.core.v3.RuntimeFractionalPercent filter_enabled = 9; + + // Specifies if the filter is enabled with metadata matcher. + // If this field is not specified, the filter will be enabled for all requests. + // + // .. note:: + // + // This field is only evaluated if the filter is instantiated. If the filter is marked with + // ``disabled: true`` in the :ref:`HttpFilter + // ` + // configuration or in per-route configuration via :ref:`ExtAuthzPerRoute + // `, + // the filter will not be instantiated and this field will have no effect. + // + // .. tip:: + // + // For dynamic filter activation based on metadata (such as metadata set by a preceding + // filter), consider using :ref:`ExtensionWithMatcher + // ` instead. This + // provides a more flexible matching framework that can evaluate conditions before filter + // instantiation. See the :ref:`ext_authz filter documentation + // ` for examples. + type.matcher.v3.MetadataMatcher filter_enabled_metadata = 14; + + // Specifies whether to deny the requests when the filter is disabled. + // If :ref:`runtime_key ` is specified, + // Envoy will lookup the runtime key to determine whether to deny requests for filter-protected paths + // when the filter is disabled. If the filter is disabled in ``typed_per_filter_config`` for the path, + // requests will not be denied. + // + // If this field is not specified, all requests will be allowed when disabled. + // + // If a request is denied due to this setting, the response code in :ref:`status_on_error + // ` will + // be returned. + config.core.v3.RuntimeFeatureFlag deny_at_disable = 11; + + // Specifies if the peer certificate is sent to the external service. + // + // When this field is ``true``, Envoy will include the peer X.509 certificate, if available, in the + // :ref:`certificate`. + bool include_peer_certificate = 10; + + // Optional additional prefix to use when emitting statistics. This allows distinguishing + // emitted statistics between configured ``ext_authz`` filters in an HTTP filter chain. For example: + // + // .. code-block:: yaml + // + // http_filters: + // - name: envoy.filters.http.ext_authz + // typed_config: + // "@type": type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthz + // stat_prefix: waf # This emits ext_authz.waf.ok, ext_authz.waf.denied, etc. + // - name: envoy.filters.http.ext_authz + // typed_config: + // "@type": type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthz + // stat_prefix: blocker # This emits ext_authz.blocker.ok, ext_authz.blocker.denied, etc. + // + string stat_prefix = 13; + + // Optional labels that will be passed to :ref:`labels` in + // :ref:`destination`. + // The labels will be read from :ref:`metadata` with the specified key. + string bootstrap_metadata_labels_key = 15; + + // Check request to authorization server will include the client request headers that have a correspondent match + // in the list. If this option isn't specified, then + // all client request headers are included in the check request to a gRPC authorization server, whereas no client request headers + // (besides the ones allowed by default - see note below) are included in the check request to an HTTP authorization server. + // This inconsistency between gRPC and HTTP servers is to maintain backwards compatibility with legacy behavior. + // + // .. note:: + // + // For requests to an HTTP authorization server: in addition to the user's supplied matchers, ``Host``, ``Method``, ``Path``, + // ``Content-Length``, and ``Authorization`` are **additionally included** in the list. + // + // .. note:: + // + // For requests to an HTTP authorization server: the value of ``Content-Length`` will be set to ``0`` and the request to the + // authorization server will not have a message body. However, the check request can include the buffered + // client request body (controlled by :ref:`with_request_body + // ` setting); + // consequently, the value of ``Content-Length`` in the authorization request reflects the size of its payload. + // + // .. note:: + // + // This can be overridden by the field ``disallowed_headers`` below. That is, if a header + // matches for both ``allowed_headers`` and ``disallowed_headers``, the header will NOT be sent. + type.matcher.v3.ListStringMatcher allowed_headers = 17; + + // If set, specifically disallow any header in this list to be forwarded to the external + // authentication server. This overrides the above ``allowed_headers`` if a header matches both. + type.matcher.v3.ListStringMatcher disallowed_headers = 25; + + // Specifies if the TLS session level details like SNI are sent to the external service. + // + // When this field is ``true``, Envoy will include the SNI name used for TLSClientHello, if available, in the + // :ref:`tls_session`. + bool include_tls_session = 18; + + // Whether to increment cluster statistics (e.g. cluster..upstream_rq_*) on authorization failure. + // Defaults to ``true``. + google.protobuf.BoolValue charge_cluster_response_stats = 20; + + // Whether to encode the raw headers (i.e., unsanitized values and unconcatenated multi-line headers) + // in the authorization request. Works with both HTTP and gRPC clients. + // + // When this is set to ``true``, header values are not sanitized. Headers with the same key will also + // not be combined into a single, comma-separated header. + // Requests to gRPC services will populate the field + // :ref:`header_map`. + // Requests to HTTP services will be constructed with the unsanitized header values and preserved + // multi-line headers with the same key. + // + // If this field is set to ``false``, header values will be sanitized, with any non-UTF-8-compliant + // bytes replaced with ``'!'``. Headers with the same key will have their values concatenated into a + // single comma-separated header value. + // Requests to gRPC services will populate the field + // :ref:`headers`. + // Requests to HTTP services will have their header values sanitized and will not preserve + // multi-line headers with the same key. + // + // It is recommended to set this to ``true`` unless you rely on the previous behavior. + // + // It is set to ``false`` by default for backwards compatibility. + bool encode_raw_headers = 23; + + // Rules for what modifications an ext_authz server may make to the request headers before + // continuing decoding or forwarding upstream. + // + // If set, enables header mutation checking against the configured rules. Note that + // :ref:`HeaderMutationRules ` + // has defaults that change ext_authz behavior. Also note that if this field is set, + // ext_authz can no longer append to ``:``-prefixed headers. + // + // If unset, header mutation rule checking is completely disabled. + // + // Regardless of what is configured here, ext_authz cannot remove ``:``-prefixed headers. + // + // This field and ``validate_mutations`` have different use cases. ``validate_mutations`` enables + // correctness checks for all header and query parameter mutations (for example, invalid characters). + // This field allows the filter to reject mutations to specific headers. + config.common.mutation_rules.v3.HeaderMutationRules decoder_header_mutation_rules = 26; + + // Enable or disable ingestion of dynamic metadata from the ext_authz service. + // + // If ``false``, the filter will ignore dynamic metadata injected by the ext_authz service. If the + // ext_authz service tries injecting dynamic metadata, the filter will log, increment the + // ``ignored_dynamic_metadata`` stat, then continue handling the response. + // + // If ``true``, the filter will ingest dynamic metadata entries as normal. + // + // If unset, defaults to ``true``. + google.protobuf.BoolValue enable_dynamic_metadata_ingestion = 27; + + // Additional metadata to be added to the filter state for logging purposes. The metadata will be + // added to StreamInfo's filter state under the namespace corresponding to the ext_authz filter + // name. + google.protobuf.Struct filter_metadata = 28; + + // When set to ``true``, the filter will emit per-stream stats for access logging. The filter state + // key will be the same as the filter name. + // + // If using Envoy gRPC, emits latency, bytes sent / received, upstream info, and upstream cluster + // info. If not using Envoy gRPC, emits only latency. + // + // .. note:: + // Stats are ONLY added to filter state if a check request is actually made to an ext_authz service. + // + // If this is ``false`` the filter will not emit stats, but filter_metadata will still be respected if + // it has a value. + // + // Field ``latency_us`` is exposed for CEL and logging when using gRPC or HTTP service. + // Fields ``bytesSent`` and ``bytesReceived`` are exposed for CEL and logging only when using gRPC service. + bool emit_filter_state_stats = 29; + + // Sets the maximum size (in bytes) of the response body that the filter will send downstream + // when a request is denied by the external authorization service. + // + // If the authorization server returns a response body larger than this configured limit, + // the body will be truncated to ``max_denied_response_body_bytes`` before being sent to the + // downstream client. + // + // If this field is not set or is set to 0, no truncation will occur, and the entire + // denied response body will be forwarded. + uint32 max_denied_response_body_bytes = 30; + + // When set to ``true``, the filter will enforce the response header map's count and size limits + // by sending a local reply when those limits are violated. + // + // When set to ``false``, the filter will ignore the response header map's limits and add / set + // all response headers as specified by the external authorization service. + // + // Recommendation: enable if the external authorization service is not trusted. Otherwise, leave + // it ``false``. + // + // Defaults to ``false``. + bool enforce_response_header_limits = 31; +} + +// Configuration for buffering the request data. +message BufferSettings { + option (udpa.annotations.versioning).previous_message_type = + "envoy.config.filter.http.ext_authz.v2.BufferSettings"; + + // Sets the maximum size of a message body that the filter will hold in memory. Envoy will return + // ``HTTP 413`` and will *not* initiate the authorization process when the buffer reaches the size + // set in this field. + // + // .. note:: + // This setting will have precedence over :ref:`failure_mode_allow + // `. + uint32 max_request_bytes = 1 [(validate.rules).uint32 = {gt: 0}]; + + // When this field is ``true``, Envoy will buffer the message until ``max_request_bytes`` is reached. + // The authorization request will be dispatched and no 413 HTTP error will be returned by the + // filter. + // + // Defaults to ``false``. + bool allow_partial_message = 2; + + // If ``true``, the body sent to the external authorization service is set as raw bytes and populates + // :ref:`raw_body` + // in the HTTP request attribute context. Otherwise, :ref:`body + // ` will be populated + // with a UTF-8 string request body. + // + // This field only affects configurations using a :ref:`grpc_service + // `. In configurations that use + // an :ref:`http_service `, this + // has no effect. + // + // Defaults to ``false``. + bool pack_as_bytes = 3; +} + +// HttpService is used for raw HTTP communication between the filter and the authorization service. +// When configured, the filter will parse the client request and use these attributes to call the +// authorization server. Depending on the response, the filter may reject or accept the client +// request. +// +// .. note:: +// In any of these events, metadata can be added, removed or overridden by the filter: +// +// On authorization request, a list of allowed request headers may be supplied. See +// :ref:`allowed_headers +// ` +// for details. Additional headers metadata may be added to the authorization request. See +// :ref:`headers_to_add +// ` for +// details. +// +// On authorization response status ``HTTP 200 OK``, the filter will allow traffic to the upstream and +// additional headers metadata may be added to the original client request. See +// :ref:`allowed_upstream_headers +// ` +// for details. Additionally, the filter may add additional headers to the client's response. See +// :ref:`allowed_client_headers_on_success +// ` +// for details. +// +// On other authorization response statuses, the filter will not allow traffic. Additional headers +// metadata as well as body may be added to the client's response. See :ref:`allowed_client_headers +// ` +// for details. +// [#next-free-field: 10] +message HttpService { + option (udpa.annotations.versioning).previous_message_type = + "envoy.config.filter.http.ext_authz.v2.HttpService"; + + reserved 3, 4, 5, 6; + + // Sets the HTTP server URI which the authorization requests must be sent to. + config.core.v3.HttpUri server_uri = 1; + + // Sets a prefix to the value of authorization request header ``Path``. + string path_prefix = 2; + + // Settings used for controlling authorization request metadata. + AuthorizationRequest authorization_request = 7; + + // Settings used for controlling authorization response metadata. + AuthorizationResponse authorization_response = 8; + + // Optional retry policy for requests to the authorization server. + // If not set, no retries will be performed. + // + // .. note:: + // When this field is set, the ``ext_authz`` filter will buffer the request body for retry purposes. + config.core.v3.RetryPolicy retry_policy = 9; +} + +message AuthorizationRequest { + option (udpa.annotations.versioning).previous_message_type = + "envoy.config.filter.http.ext_authz.v2.AuthorizationRequest"; + + // Authorization request includes the client request headers that have a corresponding match + // in the list. + // This field has been deprecated in favor of :ref:`allowed_headers + // `. + // + // .. note:: + // + // In addition to the user's supplied matchers, ``Host``, ``Method``, ``Path``, + // ``Content-Length``, and ``Authorization`` are **automatically included** in the list. + // + // .. note:: + // + // By default, the ``Content-Length`` header is set to ``0`` and the request to the authorization + // service has no message body. However, the authorization request *may* include the buffered + // client request body (controlled by :ref:`with_request_body + // ` + // setting); hence the value of its ``Content-Length`` reflects the size of its payload. + // + type.matcher.v3.ListStringMatcher allowed_headers = 1 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; + + // Sets a list of headers that will be included in the request to the authorization service. + // + // .. note:: + // Client request headers with the same key will be overridden. + repeated config.core.v3.HeaderValue headers_to_add = 2; +} + +// [#next-free-field: 6] +message AuthorizationResponse { + option (udpa.annotations.versioning).previous_message_type = + "envoy.config.filter.http.ext_authz.v2.AuthorizationResponse"; + + // When this list is set, authorization + // response headers that have a correspondent match will be added to the original client request. + // + // .. note:: + // Existing headers will be overridden. + type.matcher.v3.ListStringMatcher allowed_upstream_headers = 1; + + // When this list is set, authorization + // response headers that have a correspondent match will be added to the original client request. + // + // .. note:: + // Existing headers will be appended. + type.matcher.v3.ListStringMatcher allowed_upstream_headers_to_append = 3; + + // When this list is set, authorization + // response headers that have a correspondent match will be added to the client's response. + // When a header is included in this list, ``Path``, ``Status``, ``Content-Length``, ``WWW-Authenticate`` and + // ``Location`` are automatically added. + // + // .. note:: + // When this list is *not* set, all the authorization response headers, except + // ``Authority (Host)``, will be in the response to the client. + type.matcher.v3.ListStringMatcher allowed_client_headers = 2; + + // When this list is set, authorization + // response headers that have a correspondent match will be added to the client's response when + // the authorization response itself is successful, i.e. not failed or denied. When this list is + // *not* set, no additional headers will be added to the client's response on success. + type.matcher.v3.ListStringMatcher allowed_client_headers_on_success = 4; + + // When this list is set, authorization + // response headers that have a correspondent match will be emitted as dynamic metadata to be consumed + // by the next filter. This metadata lives in a namespace specified by the canonical name of extension filter + // that requires it: + // + // - :ref:`envoy.filters.http.ext_authz ` for HTTP filter. + // - :ref:`envoy.filters.network.ext_authz ` for network filter. + type.matcher.v3.ListStringMatcher dynamic_metadata_from_headers = 5; +} + +// Extra settings on a per virtualhost/route/weighted-cluster level. +message ExtAuthzPerRoute { + option (udpa.annotations.versioning).previous_message_type = + "envoy.config.filter.http.ext_authz.v2.ExtAuthzPerRoute"; + + oneof override { + option (validate.required) = true; + + // Disable the ext auth filter for this particular vhost or route. + // If disabled is specified in multiple per-filter-configs, the most specific one will be used. + // If the filter is disabled by default and this is set to ``false``, the filter will be enabled + // for this vhost or route. + bool disabled = 1; + + // Check request settings for this route. + CheckSettings check_settings = 2 [(validate.rules).message = {required: true}]; + } +} + +// Extra settings for the check request. +// [#next-free-field: 6] +message CheckSettings { + option (udpa.annotations.versioning).previous_message_type = + "envoy.config.filter.http.ext_authz.v2.CheckSettings"; + + // Context extensions to set on the CheckRequest's + // :ref:`AttributeContext.context_extensions` + // + // You can use this to provide extra context for the external authorization server on specific + // virtual hosts/routes. For example, adding a context extension on the virtual host level can + // give the ext-authz server information on what virtual host is used without needing to parse the + // host header. If CheckSettings is specified in multiple per-filter-configs, they will be merged + // in order, and the result will be used. + // + // Merge semantics for this field are such that keys from more specific configs override. + // + // .. note:: + // These settings are only applied to a filter configured with a + // :ref:`grpc_service`. + map context_extensions = 1 [(udpa.annotations.sensitive) = true]; + + // When set to ``true``, disable the configured :ref:`with_request_body + // ` for a specific route. + // + // Only one of ``disable_request_body_buffering`` and + // :ref:`with_request_body ` + // may be specified. + bool disable_request_body_buffering = 2; + + // Enable or override request body buffering, which is configured using the + // :ref:`with_request_body ` + // option for a specific route. + // + // Only one of ``with_request_body`` and + // :ref:`disable_request_body_buffering ` + // may be specified. + BufferSettings with_request_body = 3; + + // Override the external authorization service for this route. + // This allows different routes to use different external authorization service backends + // and service types (gRPC or HTTP). If specified, this overrides the filter-level service + // configuration regardless of the original service type. + oneof service_override { + // Override with a gRPC service configuration. + config.core.v3.GrpcService grpc_service = 4; + + // Override with an HTTP service configuration. + HttpService http_service = 5; + } +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/gcp_authn/v3/gcp_authn.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/gcp_authn/v3/gcp_authn.proto new file mode 100644 index 00000000000..f4646389f7e --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/gcp_authn/v3/gcp_authn.proto @@ -0,0 +1,87 @@ +syntax = "proto3"; + +package envoy.extensions.filters.http.gcp_authn.v3; + +import "envoy/config/core/v3/base.proto"; +import "envoy/config/core/v3/http_uri.proto"; + +import "google/protobuf/duration.proto"; +import "google/protobuf/wrappers.proto"; + +import "envoy/annotations/deprecation.proto"; +import "udpa/annotations/status.proto"; +import "validate/validate.proto"; + +option java_package = "io.envoyproxy.envoy.extensions.filters.http.gcp_authn.v3"; +option java_outer_classname = "GcpAuthnProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/gcp_authn/v3;gcp_authnv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: GCP authentication] +// GCP authentication :ref:`configuration overview `. +// [#extension: envoy.filters.http.gcp_authn] + +// Filter configuration. +// [#next-free-field: 7] +message GcpAuthnFilterConfig { + // The HTTP URI to fetch tokens from GCE Metadata Server(https://cloud.google.com/compute/docs/metadata/overview). + // The URL format is "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity?audience=[AUDIENCE]" + // + // This field is deprecated because it does not match the API surface provided by the google auth libraries. + // Control planes should not attempt to override the metadata server URI. + // The cluster and timeout can be configured using the ``cluster`` and ``timeout`` fields instead. + // For backward compatibility, the cluster and timeout configured in this field will be used + // if the new ``cluster`` and ``timeout`` fields are not set. + config.core.v3.HttpUri http_uri = 1 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; + + // Retry policy for fetching tokens. + // Not supported by all data planes. + config.core.v3.RetryPolicy retry_policy = 2; + + // Token cache configuration. This field is optional. + TokenCacheConfig cache_config = 3; + + // Request header location to extract the token. By default (i.e. if this field is not specified), the token + // is extracted to the Authorization HTTP header, in the format "Authorization: Bearer ". + // Not supported by all data planes. + TokenHeader token_header = 4; + + // Cluster to send traffic to the GCE metadata server. Not supported + // by all data planes; a data plane may instead have its own mechanism + // for contacting the metadata server. + string cluster = 5; + + // Timeout for fetching the tokens from the GCE metadata server. + // Not supported by all data planes. + google.protobuf.Duration timeout = 6 [(validate.rules).duration = { + lt {seconds: 4294967296} + gte {} + }]; +} + +// Audience is the URL of the receiving service that performs token authentication. +// It will be provided to the filter through cluster's typed_filter_metadata. +message Audience { + string url = 1 [(validate.rules).string = {min_len: 1}]; +} + +// Token Cache configuration. +message TokenCacheConfig { + // The number of cache entries. The maximum number of entries is INT64_MAX as it is constrained by underlying cache implementation. + // Default value 0 (i.e., proto3 defaults) disables the cache by default. Other default values will enable the cache. + google.protobuf.UInt64Value cache_size = 1 [(validate.rules).uint64 = {lte: 9223372036854775807}]; +} + +message TokenHeader { + // The HTTP header's name. + string name = 1 + [(validate.rules).string = {min_len: 1 well_known_regex: HTTP_HEADER_NAME strict: false}]; + + // The header's prefix. The format is "value_prefix" + // For example, for "Authorization: Bearer ", value_prefix="Bearer " with a space at the + // end. + string value_prefix = 2 + [(validate.rules).string = {well_known_regex: HTTP_HEADER_VALUE strict: false}]; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/rate_limit_quota/v3/rate_limit_quota.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/rate_limit_quota/v3/rate_limit_quota.proto new file mode 100644 index 00000000000..57b8bdecd78 --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/rate_limit_quota/v3/rate_limit_quota.proto @@ -0,0 +1,423 @@ +syntax = "proto3"; + +package envoy.extensions.filters.http.rate_limit_quota.v3; + +import "envoy/config/core/v3/base.proto"; +import "envoy/config/core/v3/extension.proto"; +import "envoy/config/core/v3/grpc_service.proto"; +import "envoy/type/v3/http_status.proto"; +import "envoy/type/v3/ratelimit_strategy.proto"; + +import "google/protobuf/duration.proto"; +import "google/protobuf/wrappers.proto"; +import "google/rpc/status.proto"; + +import "xds/annotations/v3/status.proto"; +import "xds/type/matcher/v3/matcher.proto"; + +import "udpa/annotations/status.proto"; +import "validate/validate.proto"; + +option java_package = "io.envoyproxy.envoy.extensions.filters.http.rate_limit_quota.v3"; +option java_outer_classname = "RateLimitQuotaProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/rate_limit_quota/v3;rate_limit_quotav3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; +option (xds.annotations.v3.file_status).work_in_progress = true; + +// [#protodoc-title: Rate Limit Quota] +// Rate Limit Quota :ref:`configuration overview `. +// [#extension: envoy.filters.http.rate_limit_quota] + +// Configures the Rate Limit Quota filter. +// +// Can be overridden in the per-route and per-host configurations. +// The more specific definition completely overrides the less specific definition. +// [#next-free-field: 7] +message RateLimitQuotaFilterConfig { + // Configures the gRPC Rate Limit Quota Service (RLQS) RateLimitQuotaService. + config.core.v3.GrpcService rlqs_server = 1 [(validate.rules).message = {required: true}]; + + // The application domain to use when calling the service. This enables sharing the quota + // server between different applications without fear of overlap. + // E.g., "envoy". + string domain = 2 [(validate.rules).string = {min_len: 1}]; + + // The match tree to use for grouping incoming requests into buckets. + // + // Example: + // + // .. validated-code-block:: yaml + // :type-name: xds.type.matcher.v3.Matcher + // + // matcher_list: + // matchers: + // # Assign requests with header['env'] set to 'staging' to the bucket { name: 'staging' } + // - predicate: + // single_predicate: + // input: + // typed_config: + // '@type': type.googleapis.com/envoy.type.matcher.v3.HttpRequestHeaderMatchInput + // header_name: env + // value_match: + // exact: staging + // on_match: + // action: + // typed_config: + // '@type': type.googleapis.com/envoy.extensions.filters.http.rate_limit_quota.v3.RateLimitQuotaBucketSettings + // bucket_id_builder: + // bucket_id_builder: + // name: + // string_value: staging + // + // # Assign requests with header['user_group'] set to 'admin' to the bucket { acl: 'admin_users' } + // - predicate: + // single_predicate: + // input: + // typed_config: + // '@type': type.googleapis.com/xds.type.matcher.v3.HttpAttributesCelMatchInput + // custom_match: + // typed_config: + // '@type': type.googleapis.com/xds.type.matcher.v3.CelMatcher + // expr_match: + // # Shortened for illustration purposes. Here should be parsed CEL expression: + // # request.headers['user_group'] == 'admin' + // parsed_expr: {} + // on_match: + // action: + // typed_config: + // '@type': type.googleapis.com/envoy.extensions.filters.http.rate_limit_quota.v3.RateLimitQuotaBucketSettings + // bucket_id_builder: + // bucket_id_builder: + // acl: + // string_value: admin_users + // + // # Catch-all clause for the requests not matched by any of the matchers. + // # In this example, deny all requests. + // on_no_match: + // action: + // typed_config: + // '@type': type.googleapis.com/envoy.extensions.filters.http.rate_limit_quota.v3.RateLimitQuotaBucketSettings + // no_assignment_behavior: + // fallback_rate_limit: + // blanket_rule: DENY_ALL + // + // .. attention:: + // The first matched group wins. Once the request is matched into a bucket, matcher + // evaluation ends. + // + // Use ``on_no_match`` field to assign the catch-all bucket. If a request is not matched + // into any bucket, and there's no ``on_no_match`` field configured, the request will be + // ALLOWED by default. It will NOT be reported to the RLQS server. + // + // Refer to :ref:`Unified Matcher API ` + // documentation for more information on the matcher trees. + xds.type.matcher.v3.Matcher bucket_matchers = 3 [(validate.rules).message = {required: true}]; + + // If set, this will enable -- but not necessarily enforce -- the rate limit for the given + // fraction of requests. + // + // Defaults to 100% of requests. + config.core.v3.RuntimeFractionalPercent filter_enabled = 4; + + // If set, this will enforce the rate limit decisions for the given fraction of requests. + // For requests that are not enforced the filter will still obtain the quota and include it + // in the load computation, however the request will always be allowed regardless of the outcome + // of quota application. This allows validation or testing of the rate limiting service + // infrastructure without disrupting existing traffic. + // + // Note: this only applies to the fraction of enabled requests. + // + // Defaults to 100% of requests. + config.core.v3.RuntimeFractionalPercent filter_enforced = 5; + + // Specifies a list of HTTP headers that should be added to each request that + // has been rate limited and is also forwarded upstream. This can only occur when the + // filter is enabled but not enforced. + repeated config.core.v3.HeaderValueOption request_headers_to_add_when_not_enforced = 6 + [(validate.rules).repeated = {max_items: 10}]; +} + +// Per-route and per-host configuration overrides. The more specific definition completely +// overrides the less specific definition. +message RateLimitQuotaOverride { + // The application domain to use when calling the service. This enables sharing the quota + // server between different applications without fear of overlap. + // E.g., "envoy". + // + // If empty, inherits the value from the less specific definition. + string domain = 1; + + // The match tree to use for grouping incoming requests into buckets. + // + // If set, fully overrides the bucket matchers provided on the less specific definition. + // If not set, inherits the value from the less specific definition. + // + // See usage example: :ref:`RateLimitQuotaFilterConfig.bucket_matchers + // `. + xds.type.matcher.v3.Matcher bucket_matchers = 2; +} + +// Rate Limit Quota Bucket Settings to apply on the successful ``bucket_matchers`` match. +// +// Specify this message in the :ref:`Matcher.OnMatch.action +// ` field of the +// ``bucket_matchers`` matcher tree to assign the matched requests to the Quota Bucket. +// Usage example: :ref:`RateLimitQuotaFilterConfig.bucket_matchers +// `. +// [#next-free-field: 6] +message RateLimitQuotaBucketSettings { + // Configures the behavior after the first request has been matched to the bucket, and before the + // the RLQS server returns the first quota assignment. + message NoAssignmentBehavior { + oneof no_assignment_behavior { + option (validate.required) = true; + + // Apply pre-configured rate limiting strategy until the server sends the first assignment. + type.v3.RateLimitStrategy fallback_rate_limit = 1; + } + } + + // Specifies the behavior when the bucket's assignment has expired, and cannot be refreshed for + // any reason. + message ExpiredAssignmentBehavior { + // Reuse the last known quota assignment, effectively extending it for the duration + // specified in the :ref:`expired_assignment_behavior_timeout + // ` + // field. + message ReuseLastAssignment { + } + + // Limit the time :ref:`ExpiredAssignmentBehavior + // ` + // is applied. If the server doesn't respond within this duration: + // + // 1. Selected ``ExpiredAssignmentBehavior`` is no longer applied. + // 2. The bucket is abandoned. The process of abandoning the bucket is described in the + // :ref:`AbandonAction ` + // message. + // 3. If a new request is matched into the bucket that has become abandoned, + // the data plane restarts the subscription to the bucket. The process of restarting the + // subscription is described in the :ref:`AbandonAction + // ` + // message. + // + // If not set, defaults to zero, and the bucket is abandoned immediately. + google.protobuf.Duration expired_assignment_behavior_timeout = 1 + [(validate.rules).duration = {gt {}}]; + + oneof expired_assignment_behavior { + option (validate.required) = true; + + // Apply the rate limiting strategy to all requests matched into the bucket until the RLQS + // server sends a new assignment, or the :ref:`expired_assignment_behavior_timeout + // ` + // runs out. + type.v3.RateLimitStrategy fallback_rate_limit = 2; + + // Reuse the last ``active`` assignment until the RLQS server sends a new assignment, or the + // :ref:`expired_assignment_behavior_timeout + // ` + // runs out. + ReuseLastAssignment reuse_last_assignment = 3; + } + } + + // Customize the deny response to the requests over the rate limit. + message DenyResponseSettings { + // HTTP response code to deny for HTTP requests (gRPC excluded). + // Defaults to 429 (:ref:`StatusCode.TooManyRequests`). + type.v3.HttpStatus http_status = 1; + + // HTTP response body used to deny for HTTP requests (gRPC excluded). + // If not set, an empty body is returned. + google.protobuf.BytesValue http_body = 2; + + // Configure the deny response for gRPC requests over the rate limit. + // Allows to specify the `RPC status code + // `_, + // and the error message. + // Defaults to the Status with the RPC Code ``UNAVAILABLE`` and empty message. + // + // To identify gRPC requests, Envoy checks that the ``Content-Type`` header is + // ``application/grpc``, or one of the various ``application/grpc+`` values. + // + // .. note:: + // The HTTP code for a gRPC response is always 200. + google.rpc.Status grpc_status = 3; + + // Specifies a list of HTTP headers that should be added to each response for requests that + // have been rate limited. Applies both to plain HTTP, and gRPC requests. + // The headers are added even when the rate limit quota was not enforced. + repeated config.core.v3.HeaderValueOption response_headers_to_add = 4 + [(validate.rules).repeated = {max_items: 10}]; + } + + // ``BucketIdBuilder`` makes it possible to build :ref:`BucketId + // ` with values substituted + // from the dynamic properties associated with each individual request. See usage examples in + // the docs to :ref:`bucket_id_builder + // ` + // field. + message BucketIdBuilder { + // Produces the value of the :ref:`BucketId + // ` map. + message ValueBuilder { + oneof value_specifier { + option (validate.required) = true; + + // Static string value — becomes the value in the :ref:`BucketId + // ` map as is. + string string_value = 1; + + // Dynamic value — evaluated for each request. Must produce a string output, which becomes + // the value in the :ref:`BucketId ` + // map. For example, extensions with the ``envoy.matching.http.input`` category can be used. + config.core.v3.TypedExtensionConfig custom_value = 2; + } + } + + // The map translated into the ``BucketId`` map. + // + // The ``string key`` of this map and becomes the key of ``BucketId`` map as is. + // + // The ``ValueBuilder value`` for the key can be: + // + // * static ``StringValue string_value`` — becomes the value in the ``BucketId`` map as is. + // * dynamic ``TypedExtensionConfig custom_value`` — evaluated for each request. Must produce + // a string output, which becomes the value in the the ``BucketId`` map. + // + // See usage examples in the docs to :ref:`bucket_id_builder + // ` + // field. + map bucket_id_builder = 1 [(validate.rules).map = {min_pairs: 1}]; + } + + // ``BucketId`` builder. + // + // :ref:`BucketId ` is a map from + // the string key to the string value which serves as bucket identifier common for on + // the control plane and the data plane. + // + // While ``BucketId`` is always static, ``BucketIdBuilder`` allows to populate map values + // with the dynamic properties associated with the each individual request. + // + // Example 1: static fields only + // + // ``BucketIdBuilder``: + // + // .. validated-code-block:: yaml + // :type-name: envoy.extensions.filters.http.rate_limit_quota.v3.RateLimitQuotaBucketSettings.BucketIdBuilder + // + // bucket_id_builder: + // name: + // string_value: my_bucket + // hello: + // string_value: world + // + // Produces the following ``BucketId`` for all requests: + // + // .. validated-code-block:: yaml + // :type-name: envoy.service.rate_limit_quota.v3.BucketId + // + // bucket: + // name: my_bucket + // hello: world + // + // Example 2: static and dynamic fields + // + // .. validated-code-block:: yaml + // :type-name: envoy.extensions.filters.http.rate_limit_quota.v3.RateLimitQuotaBucketSettings.BucketIdBuilder + // + // bucket_id_builder: + // name: + // string_value: my_bucket + // env: + // custom_value: + // typed_config: + // '@type': type.googleapis.com/envoy.type.matcher.v3.HttpRequestHeaderMatchInput + // header_name: environment + // + // In this example, the value of ``BucketId`` key ``env`` is substituted from the ``environment`` + // request header. + // + // This is equivalent to the following ``pseudo-code``: + // + // .. code-block:: yaml + // + // name: 'my_bucket' + // env: $header['environment'] + // + // For example, the request with the HTTP header ``env`` set to ``staging`` will produce + // the following ``BucketId``: + // + // .. validated-code-block:: yaml + // :type-name: envoy.service.rate_limit_quota.v3.BucketId + // + // bucket: + // name: my_bucket + // env: staging + // + // For the request with the HTTP header ``environment`` set to ``prod``, will produce: + // + // .. validated-code-block:: yaml + // :type-name: envoy.service.rate_limit_quota.v3.BucketId + // + // bucket: + // name: my_bucket + // env: prod + // + // .. note:: + // The order of ``BucketId`` keys do not matter. Buckets ``{ a: 'A', b: 'B' }`` and + // ``{ b: 'B', a: 'A' }`` are identical. + // + // If not set, requests will NOT be reported to the server, and will always limited + // according to :ref:`no_assignment_behavior + // ` + // configuration. + BucketIdBuilder bucket_id_builder = 1; + + // The interval at which the data plane (RLQS client) is to report quota usage for this bucket. + // + // When the first request is matched to a bucket with no assignment, the data plane is to report + // the request immediately in the :ref:`RateLimitQuotaUsageReports + // ` message. + // For the RLQS server, this signals that the data plane is now subscribed to + // the quota assignments in this bucket, and will start sending the assignment as described in + // the :ref:`RLQS documentation `. + // + // After sending the initial report, the data plane is to continue reporting the bucket usage with + // the internal specified in this field. + // + // If for any reason RLQS client doesn't receive the initial assignment for the reported bucket, + // the data plane will eventually consider the bucket abandoned and stop sending the usage + // reports. This is explained in more details at :ref:`Rate Limit Quota Service (RLQS) + // `. + // + // [#comment: 100000000 nanoseconds = 0.1 seconds] + google.protobuf.Duration reporting_interval = 2 [(validate.rules).duration = { + required: true + gt {nanos: 100000000} + }]; + + // Customize the deny response to the requests over the rate limit. + // If not set, the filter will be configured as if an empty message is set, + // and will behave according to the defaults specified in :ref:`DenyResponseSettings + // `. + DenyResponseSettings deny_response_settings = 3; + + // Configures the behavior in the "no assignment" state: after the first request has been + // matched to the bucket, and before the the RLQS server returns the first quota assignment. + // + // If not set, the default behavior is to allow all requests. + NoAssignmentBehavior no_assignment_behavior = 4; + + // Configures the behavior in the "expired assignment" state: the bucket's assignment has expired, + // and cannot be refreshed. + // + // If not set, the bucket is abandoned when its ``active`` assignment expires. + // The process of abandoning the bucket, and restarting the subscription is described in the + // :ref:`AbandonAction ` + // message. + ExpiredAssignmentBehavior expired_assignment_behavior = 5; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/rbac/v3/rbac.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/rbac/v3/rbac.proto index eeb505a17fb..a37efe157db 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/rbac/v3/rbac.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/rbac/v3/rbac.proto @@ -4,7 +4,6 @@ package envoy.extensions.filters.http.rbac.v3; import "envoy/config/rbac/v3/rbac.proto"; -import "xds/annotations/v3/status.proto"; import "xds/type/matcher/v3/matcher.proto"; import "udpa/annotations/migrate.proto"; @@ -22,46 +21,57 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#extension: envoy.filters.http.rbac] // RBAC filter config. -// [#next-free-field: 6] +// [#next-free-field: 8] message RBAC { option (udpa.annotations.versioning).previous_message_type = "envoy.config.filter.http.rbac.v2.RBAC"; - // Specify the RBAC rules to be applied globally. - // If absent, no enforcing RBAC policy will be applied. - // If present and empty, DENY. - // If both rules and matcher are configured, rules will be ignored. + // The primary RBAC policy which will be applied globally, to all the incoming requests. + // + // * If absent, no RBAC enforcement occurs. + // * If set but empty, all requests are denied. + // + // .. note:: + // + // When both ``rules`` and ``matcher`` are configured, ``rules`` will be ignored. + // config.rbac.v3.RBAC rules = 1 [(udpa.annotations.field_migrate).oneof_promotion = "rules_specifier"]; - // The match tree to use when resolving RBAC action for incoming requests. Requests do not - // match any matcher will be denied. - // If absent, no enforcing RBAC matcher will be applied. - // If present and empty, deny all requests. - xds.type.matcher.v3.Matcher matcher = 4 [ - (udpa.annotations.field_migrate).oneof_promotion = "rules_specifier", - (xds.annotations.v3.field_status).work_in_progress = true - ]; - - // Shadow rules are not enforced by the filter (i.e., returning a 403) - // but will emit stats and logs and can be used for rule testing. - // If absent, no shadow RBAC policy will be applied. - // If both shadow rules and shadow matcher are configured, shadow rules will be ignored. + // If specified, rules will emit stats with the given prefix. + // This is useful for distinguishing metrics when multiple RBAC filters are configured. + string rules_stat_prefix = 6; + + // Match tree for evaluating RBAC actions on incoming requests. Requests not matching any matcher will be denied. + // + // * If absent, no RBAC enforcement occurs. + // * If set but empty, all requests are denied. + // + xds.type.matcher.v3.Matcher matcher = 4 + [(udpa.annotations.field_migrate).oneof_promotion = "rules_specifier"]; + + // Shadow policy for testing RBAC rules without enforcing them. These rules generate stats and logs but do not deny + // requests. If absent, no shadow RBAC policy will be applied. + // + // .. note:: + // + // When both ``shadow_rules`` and ``shadow_matcher`` are configured, ``shadow_rules`` will be ignored. + // config.rbac.v3.RBAC shadow_rules = 2 [(udpa.annotations.field_migrate).oneof_promotion = "shadow_rules_specifier"]; - // The match tree to use for emitting stats and logs which can be used for rule testing for - // incoming requests. // If absent, no shadow matcher will be applied. - xds.type.matcher.v3.Matcher shadow_matcher = 5 [ - (udpa.annotations.field_migrate).oneof_promotion = "shadow_rules_specifier", - (xds.annotations.v3.field_status).work_in_progress = true - ]; + // Match tree for testing RBAC rules through stats and logs without enforcing them. + // If absent, no shadow matching occurs. + xds.type.matcher.v3.Matcher shadow_matcher = 5 + [(udpa.annotations.field_migrate).oneof_promotion = "shadow_rules_specifier"]; // If specified, shadow rules will emit stats with the given prefix. - // This is useful to distinguish the stat when there are more than 1 RBAC filter configured with - // shadow rules. + // This is useful for distinguishing metrics when multiple RBAC filters use shadow rules. string shadow_rules_stat_prefix = 3; + + // If ``track_per_rule_stats`` is ``true``, counters will be published for each rule and shadow rule. + bool track_per_rule_stats = 7; } message RBACPerRoute { @@ -70,7 +80,7 @@ message RBACPerRoute { reserved 1; - // Override the global configuration of the filter with this new config. - // If absent, the global RBAC policy will be disabled for this route. + // Per-route specific RBAC configuration that overrides the global RBAC configuration. + // If absent, RBAC policy will be disabled for this route. RBAC rbac = 2; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/router/v3/router.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/router/v3/router.proto index 75bca960da1..7da658bcb33 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/router/v3/router.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/router/v3/router.proto @@ -23,7 +23,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // Router :ref:`configuration overview `. // [#extension: envoy.filters.http.router] -// [#next-free-field: 10] +// [#next-free-field: 11] message Router { option (udpa.annotations.versioning).previous_message_type = "envoy.config.filter.http.router.v2.Router"; @@ -119,11 +119,11 @@ message Router { // for more details. bool suppress_grpc_request_failure_code_stats = 7; + // Optional HTTP filters for the upstream HTTP filter chain. + // // .. note:: // Upstream HTTP filters are currently in alpha. // - // Optional HTTP filters for the upstream HTTP filter chain. - // // These filters will be applied for all requests that pass through the router. // They will also be applied to shadowed requests. // Upstream HTTP filters cannot change route or cluster. @@ -134,4 +134,10 @@ message Router { // upstream HTTP filters will count as a final response if hedging is configured. // [#extension-category: envoy.filters.http.upstream] repeated network.http_connection_manager.v3.HttpFilter upstream_http_filters = 8; + + // If set to true, Envoy will reject ``CONNECT`` requests that send data before + // receiving a ``200`` response from the upstream. This early data behavior + // is common for latency reduction but can cause issues with some upstreams. + // Defaults to false to allow early data and be compatible with common behavior. + google.protobuf.BoolValue reject_connect_request_early_data = 10; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/network/http_connection_manager/v3/http_connection_manager.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/network/http_connection_manager/v3/http_connection_manager.proto index 7a92259eb43..9d8cf8bf4fd 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/network/http_connection_manager/v3/http_connection_manager.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/network/http_connection_manager/v3/http_connection_manager.proto @@ -20,6 +20,8 @@ import "google/protobuf/any.proto"; import "google/protobuf/duration.proto"; import "google/protobuf/wrappers.proto"; +import "xds/type/matcher/v3/matcher.proto"; + import "envoy/annotations/deprecation.proto"; import "udpa/annotations/migrate.proto"; import "udpa/annotations/security.proto"; @@ -37,7 +39,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // HTTP connection manager :ref:`configuration overview `. // [#extension: envoy.filters.network.http_connection_manager] -// [#next-free-field: 57] +// [#next-free-field: 61] message HttpConnectionManager { option (udpa.annotations.versioning).previous_message_type = "envoy.config.filter.network.http_connection_manager.v2.HttpConnectionManager"; @@ -58,9 +60,8 @@ message HttpConnectionManager { // Prior knowledge is allowed). HTTP2 = 2; - // [#not-implemented-hide:] QUIC implementation is not production ready yet. Use this enum with - // caution to prevent accidental execution of QUIC code. I.e. `!= HTTP2` is no longer sufficient - // to distinguish HTTP1 and HTTP2 traffic. + // The connection manager will assume that the client is speaking HTTP/3. + // This needs to be consistent with listener and transport socket config. HTTP3 = 3; } @@ -100,41 +101,53 @@ message HttpConnectionManager { ALWAYS_FORWARD_ONLY = 4; } - // Determines the action for request that contain %2F, %2f, %5C or %5c sequences in the URI path. + // Determines the action for request that contain ``%2F``, ``%2f``, ``%5C`` or ``%5c`` sequences in the URI path. // This operation occurs before URL normalization and the merge slashes transformations if they were enabled. enum PathWithEscapedSlashesAction { // Default behavior specific to implementation (i.e. Envoy) of this configuration option. // Envoy, by default, takes the KEEP_UNCHANGED action. - // NOTE: the implementation may change the default behavior at-will. + // + // .. note:: + // + // The implementation may change the default behavior at-will. IMPLEMENTATION_SPECIFIC_DEFAULT = 0; // Keep escaped slashes. KEEP_UNCHANGED = 1; // Reject client request with the 400 status. gRPC requests will be rejected with the INTERNAL (13) error code. - // The "httpN.downstream_rq_failed_path_normalization" counter is incremented for each rejected request. + // The ``httpN.downstream_rq_failed_path_normalization`` counter is incremented for each rejected request. REJECT_REQUEST = 2; - // Unescape %2F and %5C sequences and redirect request to the new path if these sequences were present. + // Unescape ``%2F`` and ``%5C`` sequences and redirect request to the new path if these sequences were present. // Redirect occurs after path normalization and merge slashes transformations if they were configured. - // NOTE: gRPC requests will be rejected with the INTERNAL (13) error code. - // This option minimizes possibility of path confusion exploits by forcing request with unescaped slashes to - // traverse all parties: downstream client, intermediate proxies, Envoy and upstream server. - // The "httpN.downstream_rq_redirected_with_normalized_path" counter is incremented for each - // redirected request. + // + // .. note:: + // + // gRPC requests will be rejected with the INTERNAL (13) error code. This option minimizes possibility of path + // confusion exploits by forcing request with unescaped slashes to traverse all parties: downstream client, + // intermediate proxies, Envoy and upstream server. The ``httpN.downstream_rq_redirected_with_normalized_path`` + // counter is incremented for each redirected request. + // UNESCAPE_AND_REDIRECT = 3; - // Unescape %2F and %5C sequences. - // Note: this option should not be enabled if intermediaries perform path based access control as - // it may lead to path confusion vulnerabilities. + // Unescape ``%2F`` and ``%5C`` sequences. + // + // .. note:: + // + // This option should not be enabled if intermediaries perform path based access control as it may lead to path + // confusion vulnerabilities. + // UNESCAPE_AND_FORWARD = 4; } - // [#next-free-field: 11] + // [#next-free-field: 13] message Tracing { option (udpa.annotations.versioning).previous_message_type = "envoy.config.filter.network.http_connection_manager.v2.HttpConnectionManager.Tracing"; + // This OperationName makes no sense and is unnecessary in the current tracing API. + // [#not-implemented-hide:] enum OperationName { // The HTTP listener is used for ingress/incoming requests. INGRESS = 0; @@ -186,14 +199,6 @@ message HttpConnectionManager { // Configuration for an external tracing provider. // If not specified, no tracing will be performed. - // - // .. attention:: - // Please be aware that ``envoy.tracers.opencensus`` provider can only be configured once - // in Envoy lifetime. - // Any attempts to reconfigure it or to use different configurations for different HCM filters - // will be rejected. - // Such a constraint is inherent to OpenCensus itself. It cannot be overcome without changes - // on OpenCensus side. config.trace.v3.Tracing.Http provider = 9; // Create separate tracing span for each upstream request if true. And if this flag is set to true, @@ -216,6 +221,28 @@ message HttpConnectionManager { // // The default value is false for now for backward compatibility. google.protobuf.BoolValue spawn_upstream_span = 10; + + // The operation name of the span which will be used for tracing. + // + // The same :ref:`format specifier ` as used for + // :ref:`HTTP access logging ` applies here, however + // unknown specifier values are replaced with the empty string instead of ``-``. + // + // This field will take precedence over and make following settings ineffective: + // + // * :ref:`route decorator ` and + // * :ref:`x-envoy-decorator-operation ` + // header will be ignored. + string operation = 11; + + // The operation name of the upstream span which will be used for tracing. + // This only takes effect when ``spawn_upstream_span`` is set to true and the upstream + // span is created. + // + // The same :ref:`format specifier ` as used for + // :ref:`HTTP access logging ` applies here, however + // unknown specifier values are replaced with the empty string instead of ``-``. + string upstream_operation = 12; } message InternalAddressConfig { @@ -262,18 +289,26 @@ message HttpConnectionManager { bool uri = 5; } + // The configuration for forwarding client cert details. + message ForwardClientCertConfig { + // How to handle the XFCC header. + ForwardClientCertDetails forward_client_cert_details = 1; + + // How to set the current client cert details. + SetCurrentClientCertDetails set_current_client_cert_details = 2; + } + // The configuration for HTTP upgrades. // For each upgrade type desired, an UpgradeConfig must be added. // // .. warning:: // - // The current implementation of upgrade headers does not handle - // multi-valued upgrade headers. Support for multi-valued headers may be - // added in the future if needed. + // The current implementation of upgrade headers does not handle multi-valued upgrade headers. Support for + // multi-valued headers may be added in the future if needed. // // .. warning:: - // The current implementation of upgrade headers does not work with HTTP/2 - // upstreams. + // The current implementation of upgrade headers does not work with HTTP/2 upstreams. + // message UpgradeConfig { option (udpa.annotations.versioning).previous_message_type = "envoy.config.filter.network.http_connection_manager.v2.HttpConnectionManager." @@ -305,7 +340,10 @@ message HttpConnectionManager { // `) will apply to the ``:path`` header // destined for the upstream. // - // Note: access logging and tracing will show the original ``:path`` header. + // .. note:: + // + // Access logging and tracing will show the original ``:path`` header. + // message PathNormalizationOptions { // [#not-implemented-hide:] Normalization applies internally before any processing of requests by // HTTP filters, routing, and matching *and* will affect the forwarded ``:path`` header. Defaults @@ -443,10 +481,25 @@ message HttpConnectionManager { Tracing tracing = 7; // Additional settings for HTTP requests handled by the connection manager. These will be - // applicable to both HTTP1 and HTTP2 requests. + // applicable to both HTTP/1.1 and HTTP/2 requests. config.core.v3.HttpProtocolOptions common_http_protocol_options = 35 [(udpa.annotations.security).configure_for_untrusted_downstream = true]; + // If set to ``true``, Envoy will not initiate an immediate drain timer for downstream HTTP/1 connections + // once :ref:`common_http_protocol_options.max_connection_duration + // ` is exceeded. + // Instead, Envoy will wait until the next downstream request arrives, add a ``connection: close`` header + // to the response, and then gracefully close the connection once the stream has completed. + // + // This behavior adheres to `RFC 9112, Section 9.6 `_. + // + // If set to ``false``, exceeding ``max_connection_duration`` triggers Envoy's default drain behavior for HTTP/1, + // where the connection is eventually closed after all active streams finish. + // + // This option has no effect if ``max_connection_duration`` is not configured. + // Defaults to ``false``. + bool http1_safe_max_connection_duration = 58; + // Additional HTTP/1 settings that are passed to the HTTP/1 codec. // [#comment:TODO: The following fields are ignored when the // :ref:`header validation configuration ` @@ -459,7 +512,6 @@ message HttpConnectionManager { [(udpa.annotations.security).configure_for_untrusted_downstream = true]; // Additional HTTP/3 settings that are passed directly to the HTTP/3 codec. - // [#not-implemented-hide:] config.core.v3.Http3ProtocolOptions http3_protocol_options = 44; // An optional override that the connection manager will write to the server @@ -480,7 +532,16 @@ message HttpConnectionManager { // The maximum request headers size for incoming connections. // If unconfigured, the default max request headers allowed is 60 KiB. + // The default value can be overridden by setting runtime key ``envoy.reloadable_features.max_request_headers_size_kb``. // Requests that exceed this limit will receive a 431 response. + // + // .. note:: + // + // Currently some protocol codecs impose limits on the maximum size of a single header. + // + // * HTTP/2 (when using nghttp2) limits a single header to around 100kb. + // * HTTP/3 limits a single header to around 1024kb. + // google.protobuf.UInt32Value max_request_headers_kb = 29 [(validate.rules).uint32 = {lte: 8192 gt: 0}]; @@ -501,16 +562,6 @@ message HttpConnectionManager { // is terminated with a 408 Request Timeout error code if no upstream response // header has been received, otherwise a stream reset occurs. // - // This timeout also specifies the amount of time that Envoy will wait for the peer to open enough - // window to write any remaining stream data once the entirety of stream data (local end stream is - // true) has been buffered pending available window. In other words, this timeout defends against - // a peer that does not release enough window to completely write the stream, even though all - // data has been proxied within available flow control windows. If the timeout is hit in this - // case, the :ref:`tx_flush_timeout ` counter will be - // incremented. Note that :ref:`max_stream_duration - // ` does not apply to - // this corner case. - // // If the :ref:`overload action ` "envoy.overload_actions.reduce_timeouts" // is configured, this timeout is scaled according to the value for // :ref:`HTTP_DOWNSTREAM_STREAM_IDLE `. @@ -523,9 +574,29 @@ message HttpConnectionManager { // // A value of 0 will completely disable the connection manager stream idle // timeout, although per-route idle timeout overrides will continue to apply. + // + // This timeout is also used as the default value for :ref:`stream_flush_timeout + // `. google.protobuf.Duration stream_idle_timeout = 24 [(udpa.annotations.security).configure_for_untrusted_downstream = true]; + // The stream flush timeout for connections managed by the connection manager. + // + // If not specified, the value of stream_idle_timeout is used. This is for backwards compatibility + // since this was the original behavior. In essence this timeout is an override for the + // stream_idle_timeout that applies specifically to the end of stream flush case. + // + // This timeout specifies the amount of time that Envoy will wait for the peer to open enough + // window to write any remaining stream data once the entirety of stream data (local end stream is + // true) has been buffered pending available window. In other words, this timeout defends against + // a peer that does not release enough window to completely write the stream, even though all + // data has been proxied within available flow control windows. If the timeout is hit in this + // case, the :ref:`tx_flush_timeout ` counter will be + // incremented. Note that :ref:`max_stream_duration + // ` does not apply to + // this corner case. + google.protobuf.Duration stream_flush_timeout = 59; + // The amount of time that Envoy will wait for the entire request to be received. // The timer is activated when the request is initiated, and is disarmed when the last byte of the // request is sent upstream (i.e. all decoding filters have processed the request), OR when the @@ -547,9 +618,10 @@ message HttpConnectionManager { // race with the final GOAWAY frame. During this grace period, Envoy will // continue to accept new streams. After the grace period, a final GOAWAY // frame is sent and Envoy will start refusing new streams. Draining occurs - // both when a connection hits the idle timeout or during general server - // draining. The default grace period is 5000 milliseconds (5 seconds) if this - // option is not specified. + // either when a connection hits the idle timeout, when :ref:`max_connection_duration + // ` + // is reached, or during general server draining. The default grace period is + // 5000 milliseconds (5 seconds) if this option is not specified. google.protobuf.Duration drain_timeout = 12; // The delayed close timeout is for downstream connections managed by the HTTP connection manager. @@ -557,57 +629,67 @@ message HttpConnectionManager { // during which Envoy will wait for the peer to close (i.e., a TCP FIN/RST is received by Envoy // from the downstream connection) prior to Envoy closing the socket associated with that // connection. - // NOTE: This timeout is enforced even when the socket associated with the downstream connection - // is pending a flush of the write buffer. However, any progress made writing data to the socket - // will restart the timer associated with this timeout. This means that the total grace period for - // a socket in this state will be - // +. + // + // .. note:: + // + // This timeout is enforced even when the socket associated with the downstream connection is pending a flush of + // the write buffer. However, any progress made writing data to the socket will restart the timer associated with + // this timeout. This means that the total grace period for a socket in this state will be + // +. // // Delaying Envoy's connection close and giving the peer the opportunity to initiate the close // sequence mitigates a race condition that exists when downstream clients do not drain/process // data in a connection's receive buffer after a remote close has been detected via a socket - // write(). This race leads to such clients failing to process the response code sent by Envoy, + // ``write()``. This race leads to such clients failing to process the response code sent by Envoy, // which could result in erroneous downstream processing. // // If the timeout triggers, Envoy will close the connection's socket. // // The default timeout is 1000 ms if this option is not specified. // - // .. NOTE:: + // .. note:: // To be useful in avoiding the race condition described above, this timeout must be set // to *at least* +<100ms to account for // a reasonable "worst" case processing time for a full iteration of Envoy's event loop>. // - // .. WARNING:: - // A value of 0 will completely disable delayed close processing. When disabled, the downstream + // .. warning:: + // A value of ``0`` will completely disable delayed close processing. When disabled, the downstream // connection's socket will be closed immediately after the write flush is completed or will // never close if the write flush does not complete. + // google.protobuf.Duration delayed_close_timeout = 26; // Configuration for :ref:`HTTP access logs ` // emitted by the connection manager. repeated config.accesslog.v3.AccessLog access_log = 13; + // The interval to flush the above access logs. + // // .. attention:: - // This field is deprecated in favor of - // :ref:`access_log_flush_interval - // `. - // Note that if both this field and :ref:`access_log_flush_interval - // ` - // are specified, the former (deprecated field) is ignored. + // + // This field is deprecated in favor of + // :ref:`access_log_flush_interval + // `. + // Note that if both this field and :ref:`access_log_flush_interval + // ` + // are specified, the former (deprecated field) is ignored. google.protobuf.Duration access_log_flush_interval = 54 [ deprecated = true, (validate.rules).duration = {gte {nanos: 1000000}}, (envoy.annotations.deprecated_at_minor_version) = "3.0" ]; + // If set to true, HCM will flush an access log once when a new HTTP request is received, after the request + // headers have been evaluated, and before iterating through the HTTP filter chain. + // // .. attention:: - // This field is deprecated in favor of - // :ref:`flush_access_log_on_new_request - // `. - // Note that if both this field and :ref:`flush_access_log_on_new_request - // ` - // are specified, the former (deprecated field) is ignored. + // + // This field is deprecated in favor of + // :ref:`flush_access_log_on_new_request + // `. + // Note that if both this field and :ref:`flush_access_log_on_new_request + // ` + // are specified, the former (deprecated field) is ignored. bool flush_access_log_on_new_request = 55 [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; @@ -631,20 +713,19 @@ message HttpConnectionManager { // :ref:`config_http_conn_man_headers_x-forwarded-for` for more information. uint32 xff_num_trusted_hops = 19; - // The configuration for the original IP detection extensions. + // Configuration for original IP detection extensions. // - // When configured the extensions will be called along with the request headers - // and information about the downstream connection, such as the directly connected address. - // Each extension will then use these parameters to decide the request's effective remote address. - // If an extension fails to detect the original IP address and isn't configured to reject - // the request, the HCM will try the remaining extensions until one succeeds or rejects - // the request. If the request isn't rejected nor any extension succeeds, the HCM will - // fallback to using the remote address. + // When these extensions are configured, Envoy will invoke them with the incoming request headers and + // details about the downstream connection, including the directly connected address. Each extension uses + // this information to determine the effective remote IP address for the request. If an extension cannot + // identify the original IP address and isn't set to reject the request, Envoy will sequentially attempt + // the remaining extensions until one successfully determines the IP or explicitly rejects the request. + // If all extensions fail without rejection, Envoy defaults to using the directly connected remote address. // - // .. WARNING:: - // Extensions cannot be used in conjunction with :ref:`use_remote_address + // .. warning:: + // These extensions cannot be configured simultaneously with :ref:`use_remote_address // ` - // nor :ref:`xff_num_trusted_hops + // or :ref:`xff_num_trusted_hops // `. // // [#extension-category: envoy.http.original_ip_detection] @@ -663,6 +744,34 @@ message HttpConnectionManager { // purposes. If unspecified, only RFC1918 IP addresses will be considered internal. // See the documentation for :ref:`config_http_conn_man_headers_x-envoy-internal` for more // information about internal/external addresses. + // + // .. warning:: + // As of Envoy 1.33.0 no IP addresses will be considered trusted. If you have tooling such as probes + // on your private network which need to be treated as trusted (e.g. changing arbitrary x-envoy headers) + // you will have to manually include those addresses or CIDR ranges like: + // + // .. validated-code-block:: yaml + // :type-name: envoy.extensions.filters.network.http_connection_manager.v3.InternalAddressConfig + // + // cidr_ranges: + // address_prefix: 10.0.0.0 + // prefix_len: 8 + // cidr_ranges: + // address_prefix: 192.168.0.0 + // prefix_len: 16 + // cidr_ranges: + // address_prefix: 172.16.0.0 + // prefix_len: 12 + // cidr_ranges: + // address_prefix: 127.0.0.1 + // prefix_len: 32 + // cidr_ranges: + // address_prefix: fd00:: + // prefix_len: 8 + // cidr_ranges: + // address_prefix: ::1 + // prefix_len: 128 + // InternalAddressConfig internal_address_config = 25; // If set, Envoy will not append the remote address to the @@ -710,6 +819,53 @@ message HttpConnectionManager { // value. SetCurrentClientCertDetails set_current_client_cert_details = 17; + // The matcher for forwarding client cert details. This allows per-request configuration + // of forward client cert behavior based on request properties. If a matcher is configured + // and matches a request, the matched action's forward client cert config will be used. + // If the matcher is not configured or doesn't match, the static + // :ref:`forward_client_cert_details + // ` + // and + // :ref:`set_current_client_cert_details + // ` + // config will be used as fallback. + // + // Example: If the x-forwarded-client-cert header contains "trusted-client", use APPEND_FORWARD, + // otherwise use SANITIZE_SET: + // + // .. code-block:: yaml + // + // forward_client_cert_matcher: + // matcher_list: + // matchers: + // - predicate: + // single_predicate: + // input: + // name: envoy.matching.inputs.request_headers + // typed_config: + // "@type": type.googleapis.com/envoy.type.matcher.v3.HttpRequestHeaderMatchInput + // header_name: "x-forwarded-client-cert" + // value_match: + // string_match: + // contains: "trusted-client" + // on_match: + // action: + // name: forward_client_cert + // typed_config: + // "@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager.ForwardClientCertConfig + // forward_client_cert_details: APPEND_FORWARD + // set_current_client_cert_details: + // uri: true + // on_no_match: + // action: + // name: forward_client_cert + // typed_config: + // "@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager.ForwardClientCertConfig + // forward_client_cert_details: SANITIZE_SET + // set_current_client_cert_details: + // uri: true + xds.type.matcher.v3.Matcher forward_client_cert_matcher = 60; + // If proxy_100_continue is true, Envoy will proxy incoming "Expect: // 100-continue" headers upstream, and forward "100 Continue" responses // downstream. If this is false or not set, Envoy will instead strip the @@ -887,6 +1043,10 @@ message HttpConnectionManager { // will be ignored if the ``x-forwarded-port`` header has been set by any trusted proxy in front of Envoy. bool append_x_forwarded_port = 51; + // Append the :ref:`config_http_conn_man_headers_x-envoy-local-overloaded` HTTP header in the scenario where + // the Overload Manager has been triggered. + bool append_local_overload = 57; + // Whether the HCM will add ProxyProtocolFilterState to the Connection lifetime filter state. Defaults to ``true``. // This should be set to ``false`` in cases where Envoy's view of the downstream address may not correspond to the // actual client address, for example, if there's another proxy in front of the Envoy. @@ -968,7 +1128,7 @@ message Rds { "envoy.config.filter.network.http_connection_manager.v2.Rds"; // Configuration source specifier for RDS. - config.core.v3.ConfigSource config_source = 1 [(validate.rules).message = {required: true}]; + config.core.v3.ConfigSource config_source = 1; // The name of the route configuration. This name will be passed to the RDS // API. This allows an Envoy configuration with multiple HTTP listeners (and diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/call_credentials/access_token/v3/access_token_credentials.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/call_credentials/access_token/v3/access_token_credentials.proto new file mode 100644 index 00000000000..45ee3839e6f --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/call_credentials/access_token/v3/access_token_credentials.proto @@ -0,0 +1,19 @@ +syntax = "proto3"; + +package envoy.extensions.grpc_service.call_credentials.access_token.v3; + +import "udpa/annotations/status.proto"; + +option java_package = "io.envoyproxy.envoy.extensions.grpc_service.call_credentials.access_token.v3"; +option java_outer_classname = "AccessTokenCredentialsProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/grpc_service/call_credentials/access_token/v3;access_tokenv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: gRPC Access Token Credentials] + +// [#not-implemented-hide:] +message AccessTokenCredentials { + // The access token. + string token = 1; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/channel_credentials/google_default/v3/google_default_credentials.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/channel_credentials/google_default/v3/google_default_credentials.proto new file mode 100644 index 00000000000..77c3af41fdd --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/channel_credentials/google_default/v3/google_default_credentials.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +package envoy.extensions.grpc_service.channel_credentials.google_default.v3; + +import "udpa/annotations/status.proto"; + +option java_package = "io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.google_default.v3"; +option java_outer_classname = "GoogleDefaultCredentialsProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/grpc_service/channel_credentials/google_default/v3;google_defaultv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: gRPC Google Default Credentials] + +// [#not-implemented-hide:] +message GoogleDefaultCredentials { +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/channel_credentials/insecure/v3/insecure_credentials.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/channel_credentials/insecure/v3/insecure_credentials.proto new file mode 100644 index 00000000000..70d58451e2d --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/channel_credentials/insecure/v3/insecure_credentials.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +package envoy.extensions.grpc_service.channel_credentials.insecure.v3; + +import "udpa/annotations/status.proto"; + +option java_package = "io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.insecure.v3"; +option java_outer_classname = "InsecureCredentialsProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/grpc_service/channel_credentials/insecure/v3;insecurev3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: gRPC Insecure Credentials] + +// [#not-implemented-hide:] +message InsecureCredentials { +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/channel_credentials/local/v3/local_credentials.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/channel_credentials/local/v3/local_credentials.proto new file mode 100644 index 00000000000..00514a0e847 --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/channel_credentials/local/v3/local_credentials.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +package envoy.extensions.grpc_service.channel_credentials.local.v3; + +import "udpa/annotations/status.proto"; + +option java_package = "io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.local.v3"; +option java_outer_classname = "LocalCredentialsProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/grpc_service/channel_credentials/local/v3;localv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: gRPC Local Credentials] + +// [#not-implemented-hide:] +message LocalCredentials { +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/channel_credentials/tls/v3/tls_credentials.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/channel_credentials/tls/v3/tls_credentials.proto new file mode 100644 index 00000000000..f64c16bb684 --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/channel_credentials/tls/v3/tls_credentials.proto @@ -0,0 +1,27 @@ +syntax = "proto3"; + +package envoy.extensions.grpc_service.channel_credentials.tls.v3; + +import "envoy/extensions/transport_sockets/tls/v3/tls.proto"; + +import "udpa/annotations/status.proto"; + +option java_package = "io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.tls.v3"; +option java_outer_classname = "TlsCredentialsProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/grpc_service/channel_credentials/tls/v3;tlsv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: gRPC TLS Credentials] + +// [#not-implemented-hide:] +message TlsCredentials { + // The certificate provider instance for the root cert. Must be set. + transport_sockets.tls.v3.CommonTlsContext.CertificateProviderInstance root_certificate_provider = + 1; + + // The certificate provider instance for the identity cert. Optional; + // if unset, no identity certificate will be sent to the server. + transport_sockets.tls.v3.CommonTlsContext.CertificateProviderInstance + identity_certificate_provider = 2; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/channel_credentials/xds/v3/xds_credentials.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/channel_credentials/xds/v3/xds_credentials.proto new file mode 100644 index 00000000000..ba8d471dd49 --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/grpc_service/channel_credentials/xds/v3/xds_credentials.proto @@ -0,0 +1,21 @@ +syntax = "proto3"; + +package envoy.extensions.grpc_service.channel_credentials.xds.v3; + +import "google/protobuf/any.proto"; + +import "udpa/annotations/status.proto"; + +option java_package = "io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.xds.v3"; +option java_outer_classname = "XdsCredentialsProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/grpc_service/channel_credentials/xds/v3;xdsv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: gRPC xDS Credentials] + +// [#not-implemented-hide:] +message XdsCredentials { + // Fallback credentials. Required. + google.protobuf.Any fallback_credentials = 1; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/client_side_weighted_round_robin/v3/client_side_weighted_round_robin.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/client_side_weighted_round_robin/v3/client_side_weighted_round_robin.proto index c70360a0946..c55d30b89e0 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/client_side_weighted_round_robin/v3/client_side_weighted_round_robin.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/client_side_weighted_round_robin/v3/client_side_weighted_round_robin.proto @@ -2,6 +2,8 @@ syntax = "proto3"; package envoy.extensions.load_balancing_policies.client_side_weighted_round_robin.v3; +import "envoy/extensions/load_balancing_policies/common/v3/common.proto"; + import "google/protobuf/duration.proto"; import "google/protobuf/wrappers.proto"; @@ -15,7 +17,7 @@ option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/loa option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Client-Side Weighted Round Robin Load Balancing Policy] -// [#not-implemented-hide:] +// [#extension: envoy.load_balancing_policies.client_side_weighted_round_robin] // Configuration for the client_side_weighted_round_robin LB policy. // @@ -30,11 +32,19 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // regardless of result. Only failed queries count toward eps. A config // parameter error_utilization_penalty controls the penalty to adjust endpoint // weights using eps and qps. The weight of a given endpoint is computed as: -// qps / (utilization + eps/qps * error_utilization_penalty) +// ``qps / (utilization + eps/qps * error_utilization_penalty)``. +// +// Note that Envoy will forward the ORCA response headers/trailers from the upstream +// cluster to the downstream client. This means that if the downstream client is also +// configured to use ``client_side_weighted_round_robin`` it will load balance against +// Envoy based on upstream weights. This can happen when Envoy is used as a reverse proxy. +// To avoid this issue you can configure the :ref:`header_mutation filter ` to remove +// the ORCA payload from the response headers/trailers. // -// See the :ref:`load balancing architecture overview` for more information. +// See the :ref:`load balancing architecture +// overview` for more information. // -// [#next-free-field: 7] +// [#next-free-field: 9] message ClientSideWeightedRoundRobin { // Whether to enable out-of-band utilization reporting collection from // the endpoints. By default, per-request utilization reporting is used. @@ -68,4 +78,14 @@ message ClientSideWeightedRoundRobin { // calculated as eps/qps. Configuration is rejected if this value is negative. // Default is 1.0. google.protobuf.FloatValue error_utilization_penalty = 6 [(validate.rules).float = {gte: 0.0}]; + + // By default, endpoint weight is computed based on the :ref:`application_utilization ` field reported by the endpoint. + // If that field is not set, then utilization will instead be computed by taking the max of the values of the metrics specified here. + // For map fields in the ORCA proto, the string will be of the form ``.``. For example, the string ``named_metrics.foo`` will mean to look for the key ``foo`` in the ORCA :ref:`named_metrics ` field. + // If none of the specified metrics are present in the load report, then :ref:`cpu_utilization ` is used instead. + repeated string metric_names_for_computing_utilization = 7; + + // Configuration for slow start mode. + // If this configuration is not set, slow start will not be not enabled. + common.v3.SlowStartConfig slow_start_config = 8; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/common/v3/common.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/common/v3/common.proto index 51520690a29..22faf11b9c5 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/common/v3/common.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/common/v3/common.proto @@ -3,11 +3,13 @@ syntax = "proto3"; package envoy.extensions.load_balancing_policies.common.v3; import "envoy/config/core/v3/base.proto"; +import "envoy/config/route/v3/route_components.proto"; import "envoy/type/v3/percent.proto"; import "google/protobuf/duration.proto"; import "google/protobuf/wrappers.proto"; +import "envoy/annotations/deprecation.proto"; import "udpa/annotations/status.proto"; import "validate/validate.proto"; @@ -22,7 +24,34 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; message LocalityLbConfig { // Configuration for :ref:`zone aware routing // `. + // [#next-free-field: 7] message ZoneAwareLbConfig { + // Basis for computing per-locality percentages in zone-aware routing. + enum LocalityBasis { + // Use the number of healthy hosts in each locality. + HEALTHY_HOSTS_NUM = 0; + + // Use the weights of healthy hosts in each locality. + HEALTHY_HOSTS_WEIGHT = 1; + } + + // Configures Envoy to always route requests to the local zone regardless of the + // upstream zone structure. In Envoy's default configuration, traffic is distributed proportionally + // across all upstream hosts while trying to maximize local routing when possible. The approach + // with force_local_zone aims to be more predictable and if there are upstream hosts in the local + // zone, they will receive all traffic. + // * :ref:`runtime values `. + // * :ref:`Zone aware routing support `. + message ForceLocalZone { + // Configures the minimum number of upstream hosts in the local zone required when force_local_zone + // is enabled. If the number of upstream hosts in the local zone is less than the specified value, + // Envoy will fall back to the default proportional-based distribution across localities. + // If not specified, the default is 1. + // * :ref:`runtime values `. + // * :ref:`Zone aware routing support `. + google.protobuf.UInt32Value min_size = 1; + } + // Configures percentage of requests that will be considered for zone aware routing // if zone aware routing is configured. If not specified, the default is 100%. // * :ref:`runtime values `. @@ -41,6 +70,18 @@ message LocalityLbConfig { // requests as if all hosts are unhealthy. This can help avoid potentially overwhelming a // failing service. bool fail_traffic_on_panic = 3; + + // If set to true, Envoy will force LocalityDirect routing if a local locality exists. + bool force_locality_direct_routing = 4 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; + + ForceLocalZone force_local_zone = 5; + + // Determines how locality percentages are computed: + // - HEALTHY_HOSTS_NUM: proportional to the count of healthy hosts. + // - HEALTHY_HOSTS_WEIGHT: proportional to the weights of healthy hosts. + // Default value is HEALTHY_HOSTS_NUM if unset. + LocalityBasis locality_basis = 6; } // Configuration for :ref:`locality weighted load balancing @@ -111,4 +152,10 @@ message ConsistentHashingLbConfig { // This is an O(N) algorithm, unlike other load balancers. Using a lower ``hash_balance_factor`` results in more hosts // being probed, so use a higher value if you require better performance. google.protobuf.UInt32Value hash_balance_factor = 2 [(validate.rules).uint32 = {gte: 100}]; + + // Specifies a list of hash policies to use for ring hash load balancing. If ``hash_policy`` is + // set, then + // :ref:`route level hash policy ` + // will be ignored. + repeated config.route.v3.RouteAction.HashPolicy hash_policy = 3; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/least_request/v3/least_request.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/least_request/v3/least_request.proto index ebef61852e2..095f6075286 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/least_request/v3/least_request.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/least_request/v3/least_request.proto @@ -7,6 +7,7 @@ import "envoy/extensions/load_balancing_policies/common/v3/common.proto"; import "google/protobuf/wrappers.proto"; +import "envoy/annotations/deprecation.proto"; import "udpa/annotations/status.proto"; import "validate/validate.proto"; @@ -22,10 +23,34 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // This configuration allows the built-in LEAST_REQUEST LB policy to be configured via the LB policy // extension point. See the :ref:`load balancing architecture overview // ` for more information. -// [#next-free-field: 6] +// [#next-free-field: 7] message LeastRequest { + // Available methods for selecting the host set from which to return the host with the + // fewest active requests. + enum SelectionMethod { + // Return host with fewest requests from a set of ``choice_count`` randomly selected hosts. + // Best selection method for most scenarios. + N_CHOICES = 0; + + // Return host with fewest requests from all hosts. + // Useful in some niche use cases involving low request rates and one of: + // (example 1) low request limits on workloads, or (example 2) few hosts. + // + // Example 1: Consider a workload type that can only accept one connection at a time. + // If such workloads are deployed across many hosts, only a small percentage of those + // workloads have zero connections at any given time, and the rate of new connections is low, + // the ``FULL_SCAN`` method is more likely to select a suitable host than ``N_CHOICES``. + // + // Example 2: Consider a workload type that is only deployed on 2 hosts. With default settings, + // the ``N_CHOICES`` method will return the host with more active requests 25% of the time. + // If the request rate is sufficiently low, the behavior of always selecting the host with least + // requests as of the last metrics refresh may be preferable. + FULL_SCAN = 1; + } + // The number of random healthy hosts from which the host with the fewest active requests will // be chosen. Defaults to 2 so that we perform two-choice selection if the field is not set. + // Only applies to the ``N_CHOICES`` selection method. google.protobuf.UInt32Value choice_count = 1 [(validate.rules).uint32 = {gte: 2}]; // The following formula is used to calculate the dynamic weights when hosts have different load @@ -61,8 +86,12 @@ message LeastRequest { common.v3.LocalityLbConfig locality_lb_config = 4; // [#not-implemented-hide:] - // Configuration for performing full scan on the list of hosts. - // If this configuration is set, when selecting the host a full scan on the list hosts will be - // used to select the one with least requests instead of using random choices. - google.protobuf.BoolValue enable_full_scan = 5; + // Unused. Replaced by the `selection_method` enum for better extensibility. + google.protobuf.BoolValue enable_full_scan = 5 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; + + // Method for selecting the host set from which to return the host with the fewest active requests. + // + // Defaults to ``N_CHOICES``. + SelectionMethod selection_method = 6 [(validate.rules).enum = {defined_only: true}]; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/wrr_locality/v3/wrr_locality.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/wrr_locality/v3/wrr_locality.proto index ab8367a401a..e2e4ade8236 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/wrr_locality/v3/wrr_locality.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/load_balancing_policies/wrr_locality/v3/wrr_locality.proto @@ -14,7 +14,7 @@ option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/loa option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Weighted Round Robin Locality-Picking Load Balancing Policy] -// [#not-implemented-hide:] +// [#extension: envoy.load_balancing_policies.wrr_locality] // Configuration for the wrr_locality LB policy. See the :ref:`load balancing architecture overview // ` for more information. diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/http_11_proxy/v3/upstream_http_11_connect.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/http_11_proxy/v3/upstream_http_11_connect.proto new file mode 100644 index 00000000000..2c9b5333f41 --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/http_11_proxy/v3/upstream_http_11_connect.proto @@ -0,0 +1,38 @@ +syntax = "proto3"; + +package envoy.extensions.transport_sockets.http_11_proxy.v3; + +import "envoy/config/core/v3/base.proto"; + +import "udpa/annotations/status.proto"; + +option java_package = "io.envoyproxy.envoy.extensions.transport_sockets.http_11_proxy.v3"; +option java_outer_classname = "UpstreamHttp11ConnectProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/http_11_proxy/v3;http_11_proxyv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: Upstream HTTP/1.1 Proxy] +// [#extension: envoy.transport_sockets.http_11_proxy] + +// HTTP/1.1 proxy transport socket establishes an upstream connection to a proxy address +// instead of the target host's address. This behavior is triggered when the transport +// socket is configured and proxy information is provided. +// +// Behavior when proxying: +// ======================= +// When an upstream connection is established, instead of connecting directly to the endpoint +// address, the client will connect to the specified proxy address, send an HTTP/1.1 ``CONNECT`` request +// indicating the endpoint address, and process the response. If the response has HTTP status 200, +// the connection will be passed down to the underlying transport socket. +// +// Configuring proxy information: +// ============================== +// Set ``typed_filter_metadata`` in :ref:`LbEndpoint.Metadata ` or :ref:`LocalityLbEndpoints.Metadata `. +// using the key ``envoy.http11_proxy_transport_socket.proxy_address`` and the +// proxy address in ``config::core::v3::Address`` format. +// +message Http11ProxyUpstreamTransport { + // The underlying transport socket being wrapped. Defaults to plaintext (raw_buffer) if unset. + config.core.v3.TransportSocket transport_socket = 1; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/common.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/common.proto index d244adcdf54..9bc5fb5d029 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/common.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/common.proto @@ -24,7 +24,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Common TLS configuration] -// [#next-free-field: 6] +// [#next-free-field: 7] message TlsParameters { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.auth.TlsParameters"; @@ -45,6 +45,23 @@ message TlsParameters { TLSv1_3 = 4; } + enum CompliancePolicy { + // FIPS_202205 configures a TLS connection to use: + // + // * TLS 1.2 or 1.3 + // * For TLS 1.2, only ECDHE_[RSA|ECDSA]_WITH_AES_*_GCM_SHA*. + // * For TLS 1.3, only AES-GCM + // * P-256 or P-384 for key agreement. + // * For server signatures, only ``PKCS#1/PSS`` with ``SHA256/384/512``, or ECDSA + // with P-256 or P-384. + // + // .. attention:: + // + // Please refer to `BoringSSL policies `_ + // for details. + FIPS_202205 = 0; + } + // Minimum TLS protocol version. By default, it's ``TLSv1_2`` for both clients and servers. // // TLS protocol versions below TLSv1_2 require setting compatible ciphers with the @@ -157,6 +174,11 @@ message TlsParameters { // rsa_pkcs1_sha1 // ecdsa_sha1 repeated string signature_algorithms = 5; + + // Compliance policies configure various aspects of the TLS based on the given policy. + // The policies are applied last during configuration and may override the other TLS + // parameters, or any previous policy. + repeated CompliancePolicy compliance_policies = 6 [(validate.rules).repeated = {max_items: 1}]; } // BoringSSL private key method configuration. The private key methods are used for external @@ -232,12 +254,13 @@ message TlsCertificate { config.core.v3.WatchedDirectory watched_directory = 7; // BoringSSL private key method provider. This is an alternative to :ref:`private_key - // ` field. This can't be - // marked as ``oneof`` due to API compatibility reasons. Setting both :ref:`private_key - // ` and - // :ref:`private_key_provider - // ` fields will result in an - // error. + // ` field. + // When both :ref:`private_key ` and + // :ref:`private_key_provider ` fields are set, + // ``private_key_provider`` takes precedence. + // If ``private_key_provider`` is unavailable and :ref:`fallback + // ` + // is enabled, ``private_key`` will be used. PrivateKeyProvider private_key_provider = 6; // The password to decrypt the TLS private key. If this field is not set, it is assumed that the @@ -290,12 +313,12 @@ message TlsSessionTicketKeys { // respect to the TLS handshake. // [#not-implemented-hide:] message CertificateProviderPluginInstance { - // Provider instance name. If not present, defaults to "default". + // Provider instance name. // // Instance names should generally be defined not in terms of the underlying provider // implementation (e.g., "file_watcher") but rather in terms of the function of the // certificates (e.g., "foo_deployment_identity"). - string instance_name = 1; + string instance_name = 1 [(validate.rules).string = {min_len: 1}]; // Opaque name used to specify certificate instances or types. For example, "ROOTCA" to specify // a root-certificate (validation context) or "example.com" to specify a certificate for a @@ -314,16 +337,39 @@ message SubjectAltNameMatcher { DNS = 2; URI = 3; IP_ADDRESS = 4; + OTHER_NAME = 5; } // Specification of type of SAN. Note that the default enum value is an invalid choice. SanType san_type = 1 [(validate.rules).enum = {defined_only: true not_in: 0}]; // Matcher for SAN value. + // + // If the :ref:`san_type ` + // is :ref:`DNS ` + // and the matcher type is :ref:`exact `, DNS wildcards are evaluated + // according to the rules in https://www.rfc-editor.org/rfc/rfc6125#section-6.4.3. + // For example, ``*.example.com`` would match ``test.example.com`` but not ``example.com`` and not + // ``a.b.example.com``. + // + // The string matching for OTHER_NAME SAN values depends on their ASN.1 type: + // + // * OBJECT: Validated against its dotted numeric notation (e.g., "1.2.3.4") + // * BOOLEAN: Validated against strings "true" or "false" + // * INTEGER/ENUMERATED: Validated against a string containing the integer value + // * NULL: Validated against an empty string + // * Other types: Validated directly against the string value type.matcher.v3.StringMatcher matcher = 2 [(validate.rules).message = {required: true}]; + + // OID Value which is required if OTHER_NAME SAN type is used. + // For example, UPN OID is 1.3.6.1.4.1.311.20.2.3 + // (Reference: http://oid-info.com/get/1.3.6.1.4.1.311.20.2.3). + // + // If set for SAN types other than OTHER_NAME, it will be ignored. + string oid = 3; } -// [#next-free-field: 17] +// [#next-free-field: 18] message CertificateValidationContext { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.auth.CertificateValidationContext"; @@ -339,6 +385,9 @@ message CertificateValidationContext { ACCEPT_UNTRUSTED = 1; } + message SystemRootCerts { + } + reserved 4, 5; reserved "verify_subject_alt_name"; @@ -378,20 +427,23 @@ message CertificateValidationContext { // can be treated as trust anchor as well. It allows verification with building valid partial chain instead // of a full chain. // - // Only one of ``trusted_ca`` and ``ca_certificate_provider_instance`` may be specified. - // - // [#next-major-version: This field and watched_directory below should ideally be moved into a - // separate sub-message, since there's no point in specifying the latter field without this one.] + // If ``ca_certificate_provider_instance`` is set, it takes precedence over ``trusted_ca``. config.core.v3.DataSource trusted_ca = 1 [(udpa.annotations.field_migrate).oneof_promotion = "ca_cert_source"]; // Certificate provider instance for fetching TLS certificates. // - // Only one of ``trusted_ca`` and ``ca_certificate_provider_instance`` may be specified. + // If set, takes precedence over ``trusted_ca``. // [#not-implemented-hide:] CertificateProviderPluginInstance ca_certificate_provider_instance = 13 [(udpa.annotations.field_migrate).oneof_promotion = "ca_cert_source"]; + // Use system root certs for validation. + // If present, system root certs are used only if neither of the ``trusted_ca`` + // or ``ca_certificate_provider_instance`` fields are set. + // [#not-implemented-hide:] + SystemRootCerts system_root_certs = 17; + // If specified, updates of a file-based ``trusted_ca`` source will be triggered // by this watch. This allows explicit control over the path watched, by // default the parent directory of the filesystem path in ``trusted_ca`` is diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/secret.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/secret.proto index 83ad364c4bf..94660e2da9f 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/secret.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/secret.proto @@ -22,8 +22,13 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; message GenericSecret { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.auth.GenericSecret"; - // Secret of generic type and is available to filters. + // Secret of generic type and is available to filters. It is expected + // that only only one of secret and secrets is set. config.core.v3.DataSource secret = 1 [(udpa.annotations.sensitive) = true]; + + // For cases where multiple associated secrets need to be distributed together. It is expected + // that only only one of secret and secrets is set. + map secrets = 2 [(udpa.annotations.sensitive) = true]; } message SdsSecretConfig { diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/tls.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/tls.proto index f94889cfad0..d656c66b5d0 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/tls.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/tls.proto @@ -25,7 +25,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#extension: envoy.transport_sockets.tls] // The TLS contexts below provide the transport socket configuration for upstream/downstream TLS. -// [#next-free-field: 6] +// [#next-free-field: 8] message UpstreamTlsContext { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.auth.UpstreamTlsContext"; @@ -34,14 +34,32 @@ message UpstreamTlsContext { // // .. attention:: // - // Server certificate verification is not enabled by default. Configure - // :ref:`trusted_ca` to enable - // verification. + // Server certificate verification is not enabled by default. To enable verification, configure + // :ref:`trusted_ca`. CommonTlsContext common_tls_context = 1; // SNI string to use when creating TLS backend connections. string sni = 2 [(validate.rules).string = {max_bytes: 255}]; + // If true, replaces the SNI for the connection with the hostname of the upstream host, if + // the hostname is known due to either a DNS cluster type or the + // :ref:`hostname ` is set on + // the host. + // + // See :ref:`SNI configuration ` for details on how this + // interacts with other validation options. + bool auto_host_sni = 6; + + // If true, replaces any Subject Alternative Name (SAN) validations with a validation for a DNS SAN matching + // the SNI value sent. The validation uses the actual requested SNI, regardless of how the SNI is configured. + // + // For common cases where an SNI value is present and the server certificate should include a corresponding SAN, + // this option ensures the SAN is properly validated. + // + // See the :ref:`validation configuration ` for how this interacts with + // other validation options. + bool auto_sni_san_validation = 7; + // If true, server-initiated TLS renegotiation will be allowed. // // .. attention:: @@ -50,43 +68,38 @@ message UpstreamTlsContext { bool allow_renegotiation = 3; // Maximum number of session keys (Pre-Shared Keys for TLSv1.3+, Session IDs and Session Tickets - // for TLSv1.2 and older) to store for the purpose of session resumption. + // for TLSv1.2 and older) to be stored for session resumption. // // Defaults to 1, setting this to 0 disables session resumption. google.protobuf.UInt32Value max_session_keys = 4; - // This field is used to control the enforcement, whereby the handshake will fail if the keyUsage extension - // is present and incompatible with the TLS usage. Currently, the default value is false (i.e., enforcement off) - // but it is expected to be changed to true by default in a future release. - // ``ssl.was_key_usage_invalid`` in :ref:`listener metrics ` will be set for certificate - // configurations that would fail if this option were set to true. + // Controls enforcement of the ``keyUsage`` extension in peer certificates. If set to ``true``, the handshake will fail if + // the ``keyUsage`` is incompatible with TLS usage. + // + // .. note:: + // The default value is ``false`` (i.e., enforcement off). It is expected to change to ``true`` in a future release. + // + // The ``ssl.was_key_usage_invalid`` in :ref:`listener metrics ` metric will be incremented + // for configurations that would fail if this option were enabled. google.protobuf.BoolValue enforce_rsa_key_usage = 5; } -// [#next-free-field: 11] +// [#next-free-field: 12] message DownstreamTlsContext { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.auth.DownstreamTlsContext"; enum OcspStaplePolicy { - // OCSP responses are optional. If an OCSP response is absent - // or expired, the associated certificate will be used for - // connections without an OCSP staple. + // OCSP responses are optional. If absent or expired, the certificate is used without stapling. LENIENT_STAPLING = 0; - // OCSP responses are optional. If an OCSP response is absent, - // the associated certificate will be used without an - // OCSP staple. If a response is provided but is expired, - // the associated certificate will not be used for - // subsequent connections. If no suitable certificate is found, - // the connection is rejected. + // OCSP responses are optional. If absent, the certificate is used without stapling. If present but expired, + // the certificate is not used for subsequent connections. Connections are rejected if no suitable certificate + // is found. STRICT_STAPLING = 1; - // OCSP responses are required. Configuration will fail if - // a certificate is provided without an OCSP response. If a - // response expires, the associated certificate will not be - // used connections. If no suitable certificate is found, the - // connection is rejected. + // OCSP responses are required. Connections fail if a certificate lacks a valid OCSP response. Expired responses + // prevent certificate use in new connections, and connections are rejected if no suitable certificate is available. MUST_STAPLE = 2; } @@ -119,51 +132,64 @@ message DownstreamTlsContext { bool disable_stateless_session_resumption = 7; } - // If set to true, the TLS server will not maintain a session cache of TLS sessions. (This is - // relevant only for TLSv1.2 and earlier.) + // If ``true``, the TLS server will not maintain a session cache of TLS sessions. + // + // .. note:: + // This applies only to TLSv1.2 and earlier. + // bool disable_stateful_session_resumption = 10; - // If specified, ``session_timeout`` will change the maximum lifetime (in seconds) of the TLS session. - // Currently this value is used as a hint for the `TLS session ticket lifetime (for TLSv1.2) `_. - // Only seconds can be specified (fractional seconds are ignored). + // Maximum lifetime of TLS sessions. If specified, ``session_timeout`` will change the maximum lifetime + // of the TLS session. + // + // This serves as a hint for the `TLS session ticket lifetime (for TLSv1.2) `_. + // Only whole seconds are considered; fractional seconds are ignored. google.protobuf.Duration session_timeout = 6 [(validate.rules).duration = { lt {seconds: 4294967296} gte {} }]; - // Config for whether to use certificates if they do not have - // an accompanying OCSP response or if the response expires at runtime. - // Defaults to LENIENT_STAPLING + // Configuration for handling certificates without an OCSP response or with expired responses. + // + // Defaults to ``LENIENT_STAPLING`` OcspStaplePolicy ocsp_staple_policy = 8 [(validate.rules).enum = {defined_only: true}]; // Multiple certificates are allowed in Downstream transport socket to serve different SNI. - // If the client provides SNI but no such cert matched, it will decide to full scan certificates or not based on this config. - // Defaults to false. See more details in :ref:`Multiple TLS certificates `. + // This option controls the behavior when no matching certificate is found for the received SNI value, + // or no SNI value was sent. If enabled, all certificates will be evaluated for a match for non-SNI criteria + // such as key type and OCSP settings. If disabled, the first provided certificate will be used. + // Defaults to ``false``. See more details in :ref:`Multiple TLS certificates `. google.protobuf.BoolValue full_scan_certs_on_sni_mismatch = 9; + + // If ``true``, the downstream client's preferred cipher is used during the handshake. If ``false``, Envoy + // uses its preferred cipher. + // + // .. note:: + // This has no effect when using TLSv1_3. + // + bool prefer_client_ciphers = 11; } // TLS key log configuration. // The key log file format is "format used by NSS for its SSLKEYLOGFILE debugging output" (text taken from openssl man page) message TlsKeyLog { - // The path to save the TLS key log. + // Path to save the TLS key log. string path = 1 [(validate.rules).string = {min_len: 1}]; - // The local IP address that will be used to filter the connection which should save the TLS key log - // If it is not set, any local IP address will be matched. + // Local IP address ranges to filter connections for TLS key logging. If not set, matches any local IP address. repeated config.core.v3.CidrRange local_address_range = 2; - // The remote IP address that will be used to filter the connection which should save the TLS key log - // If it is not set, any remote IP address will be matched. + // Remote IP address ranges to filter connections for TLS key logging. If not set, matches any remote IP address. repeated config.core.v3.CidrRange remote_address_range = 3; } // TLS context shared by both client and server TLS contexts. -// [#next-free-field: 16] +// [#next-free-field: 17] message CommonTlsContext { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.auth.CommonTlsContext"; - // Config for Certificate provider to get certificates. This provider should allow certificates to be - // fetched/refreshed over the network asynchronously with respect to the TLS handshake. + // Config for the Certificate Provider to fetch certificates. Certificates are fetched/refreshed asynchronously over + // the network relative to the TLS handshake. // // DEPRECATED: This message is not currently used, but if we ever do need it, we will want to // move it out of CommonTlsContext and into common.proto, similar to the existing @@ -248,33 +274,35 @@ message CommonTlsContext { // :ref:`Multiple TLS certificates ` can be associated with the // same context to allow both RSA and ECDSA certificates and support SNI-based selection. // - // Only one of ``tls_certificates``, ``tls_certificate_sds_secret_configs``, - // and ``tls_certificate_provider_instance`` may be used. - // [#next-major-version: These mutually exclusive fields should ideally be in a oneof, but it's - // not legal to put a repeated field in a oneof. In the next major version, we should rework - // this to avoid this problem.] + // If ``tls_certificate_provider_instance`` is set, this field is ignored. + // If this field is set, ``tls_certificate_sds_secret_configs`` is ignored. repeated TlsCertificate tls_certificates = 2; // Configs for fetching TLS certificates via SDS API. Note SDS API allows certificates to be // fetched/refreshed over the network asynchronously with respect to the TLS handshake. // // The same number and types of certificates as :ref:`tls_certificates ` - // are valid in the the certificates fetched through this setting. + // are valid in the certificates fetched through this setting. // - // Only one of ``tls_certificates``, ``tls_certificate_sds_secret_configs``, - // and ``tls_certificate_provider_instance`` may be used. - // [#next-major-version: These mutually exclusive fields should ideally be in a oneof, but it's - // not legal to put a repeated field in a oneof. In the next major version, we should rework - // this to avoid this problem.] + // If ``tls_certificates`` or ``tls_certificate_provider_instance`` are set, this field + // is ignored. repeated SdsSecretConfig tls_certificate_sds_secret_configs = 6; // Certificate provider instance for fetching TLS certs. // - // Only one of ``tls_certificates``, ``tls_certificate_sds_secret_configs``, - // and ``tls_certificate_provider_instance`` may be used. + // If this field is set, ``tls_certificates`` and ``tls_certificate_provider_instance`` + // are ignored. // [#not-implemented-hide:] CertificateProviderPluginInstance tls_certificate_provider_instance = 14; + // Custom TLS certificate selector. + // + // Select TLS certificate based on TLS client hello. + // If empty, defaults to native TLS certificate selection behavior: + // DNS SANs or Subject Common Name in TLS certificates is extracted as server name pattern to match SNI. + // [#extension-category: envoy.tls.certificate_selectors] + config.core.v3.TypedExtensionConfig custom_tls_certificate_selector = 16; + // Certificate provider for fetching TLS certificates. // [#not-implemented-hide:] CertificateProvider tls_certificate_certificate_provider = 9 @@ -293,13 +321,17 @@ message CommonTlsContext { // fetched/refreshed over the network asynchronously with respect to the TLS handshake. SdsSecretConfig validation_context_sds_secret_config = 7; - // Combined certificate validation context holds a default CertificateValidationContext - // and SDS config. When SDS server returns dynamic CertificateValidationContext, both dynamic - // and default CertificateValidationContext are merged into a new CertificateValidationContext - // for validation. This merge is done by Message::MergeFrom(), so dynamic - // CertificateValidationContext overwrites singular fields in default - // CertificateValidationContext, and concatenates repeated fields to default - // CertificateValidationContext, and logical OR is applied to boolean fields. + // Combines the default ``CertificateValidationContext`` with the SDS-provided dynamic context for certificate + // validation. + // + // When the SDS server returns a dynamic ``CertificateValidationContext``, it is merged + // with the default context using ``Message::MergeFrom()``. The merging rules are as follows: + // + // * **Singular Fields:** Dynamic fields override the default singular fields. + // * **Repeated Fields:** Dynamic repeated fields are concatenated with the default repeated fields. + // * **Boolean Fields:** Boolean fields are combined using a logical OR operation. + // + // The resulting ``CertificateValidationContext`` is used to perform certificate validation. CombinedCertificateValidationContext combined_validation_context = 8; // Certificate provider for fetching validation context. diff --git a/xds/third_party/envoy/src/main/proto/envoy/service/auth/v3/attribute_context.proto b/xds/third_party/envoy/src/main/proto/envoy/service/auth/v3/attribute_context.proto new file mode 100644 index 00000000000..2c4fbb4b73e --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/service/auth/v3/attribute_context.proto @@ -0,0 +1,222 @@ +syntax = "proto3"; + +package envoy.service.auth.v3; + +import "envoy/config/core/v3/address.proto"; +import "envoy/config/core/v3/base.proto"; + +import "google/protobuf/timestamp.proto"; + +import "udpa/annotations/migrate.proto"; +import "udpa/annotations/status.proto"; +import "udpa/annotations/versioning.proto"; + +option java_package = "io.envoyproxy.envoy.service.auth.v3"; +option java_outer_classname = "AttributeContextProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3;authv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: Attribute context] + +// See :ref:`network filter configuration overview ` +// and :ref:`HTTP filter configuration overview `. + +// An attribute is a piece of metadata that describes an activity on a network. +// For example, the size of an HTTP request, or the status code of an HTTP response. +// +// Each attribute has a type and a name, which is logically defined as a proto message field +// of the ``AttributeContext``. The ``AttributeContext`` is a collection of individual attributes +// supported by Envoy authorization system. +// [#comment: The following items are left out of this proto +// Request.Auth field for JWTs +// Request.Api for api management +// Origin peer that originated the request +// Caching Protocol +// request_context return values to inject back into the filter chain +// peer.claims -- from X.509 extensions +// Configuration +// - field mask to send +// - which return values from request_context are copied back +// - which return values are copied into request_headers] +// [#next-free-field: 14] +message AttributeContext { + option (udpa.annotations.versioning).previous_message_type = + "envoy.service.auth.v2.AttributeContext"; + + // This message defines attributes for a node that handles a network request. + // The node can be either a service or an application that sends, forwards, + // or receives the request. Service peers should fill in the ``service``, + // ``principal``, and ``labels`` as appropriate. + // [#next-free-field: 6] + message Peer { + option (udpa.annotations.versioning).previous_message_type = + "envoy.service.auth.v2.AttributeContext.Peer"; + + // The address of the peer, this is typically the IP address. + // It can also be UDS path, or others. + config.core.v3.Address address = 1; + + // The canonical service name of the peer. + // It should be set to :ref:`the HTTP x-envoy-downstream-service-cluster + // ` + // If a more trusted source of the service name is available through mTLS/secure naming, it + // should be used. + string service = 2; + + // The labels associated with the peer. + // These could be pod labels for Kubernetes or tags for VMs. + // The source of the labels could be an X.509 certificate or other configuration. + map labels = 3; + + // The authenticated identity of this peer. + // For example, the identity associated with the workload such as a service account. + // If an X.509 certificate is used to assert the identity this field should be sourced from + // ``URI Subject Alternative Names``, ``DNS Subject Alternate Names`` or ``Subject`` in that order. + // The primary identity should be the principal. The principal format is issuer specific. + // + // Examples: + // + // - SPIFFE format is ``spiffe://trust-domain/path``. + // - Google account format is ``https://accounts.google.com/{userid}``. + string principal = 4; + + // The X.509 certificate used to authenticate the identify of this peer. + // When present, the certificate contents are encoded in URL and PEM format. + string certificate = 5; + } + + // Represents a network request, such as an HTTP request. + message Request { + option (udpa.annotations.versioning).previous_message_type = + "envoy.service.auth.v2.AttributeContext.Request"; + + // The timestamp when the proxy receives the first byte of the request. + google.protobuf.Timestamp time = 1; + + // Represents an HTTP request or an HTTP-like request. + HttpRequest http = 2; + } + + // This message defines attributes for an HTTP request. + // HTTP/1.x, HTTP/2, gRPC are all considered as HTTP requests. + // [#next-free-field: 14] + message HttpRequest { + option (udpa.annotations.versioning).previous_message_type = + "envoy.service.auth.v2.AttributeContext.HttpRequest"; + + // The unique ID for a request, which can be propagated to downstream + // systems. The ID should have low probability of collision + // within a single day for a specific service. + // For HTTP requests, it should be X-Request-ID or equivalent. + string id = 1; + + // The HTTP request method, such as ``GET``, ``POST``. + string method = 2; + + // The HTTP request headers. If multiple headers share the same key, they + // must be merged according to the HTTP spec. All header keys must be + // lower-cased, because HTTP header keys are case-insensitive. + // Header value is encoded as UTF-8 string. Non-UTF-8 characters will be replaced by "!". + // This field will not be set if + // :ref:`encode_raw_headers ` + // is set to true. + map headers = 3 + [(udpa.annotations.field_migrate).oneof_promotion = "headers_type"]; + + // A list of the raw HTTP request headers. This is used instead of + // :ref:`headers ` when + // :ref:`encode_raw_headers ` + // is set to true. + // + // Note that this is not actually a map type. ``header_map`` contains a single repeated field + // ``headers``. + // + // Here, only the ``key`` and ``raw_value`` fields will be populated for each HeaderValue, and + // that is only when + // :ref:`encode_raw_headers ` + // is set to true. + // + // Also, unlike the + // :ref:`headers ` + // field, headers with the same key are not combined into a single comma separated header. + config.core.v3.HeaderMap header_map = 13 + [(udpa.annotations.field_migrate).oneof_promotion = "headers_type"]; + + // The request target, as it appears in the first line of the HTTP request. This includes + // the URL path and query-string. No decoding is performed. + string path = 4; + + // The HTTP request ``Host`` or ``:authority`` header value. + string host = 5; + + // The HTTP URL scheme, such as ``http`` and ``https``. + string scheme = 6; + + // This field is always empty, and exists for compatibility reasons. The HTTP URL query is + // included in ``path`` field. + string query = 7; + + // This field is always empty, and exists for compatibility reasons. The URL fragment is + // not submitted as part of HTTP requests; it is unknowable. + string fragment = 8; + + // The HTTP request size in bytes. If unknown, it must be -1. + int64 size = 9; + + // The network protocol used with the request, such as "HTTP/1.0", "HTTP/1.1", or "HTTP/2". + // + // See :repo:`headers.h:ProtocolStrings ` for a list of all + // possible values. + string protocol = 10; + + // The HTTP request body. + string body = 11; + + // The HTTP request body in bytes. This is used instead of + // :ref:`body ` when + // :ref:`pack_as_bytes ` + // is set to true. + bytes raw_body = 12; + } + + // This message defines attributes for the underlying TLS session. + message TLSSession { + // SNI used for TLS session. + string sni = 1; + } + + // The source of a network activity, such as starting a TCP connection. + // In a multi hop network activity, the source represents the sender of the + // last hop. + Peer source = 1; + + // The destination of a network activity, such as accepting a TCP connection. + // In a multi hop network activity, the destination represents the receiver of + // the last hop. + Peer destination = 2; + + // Represents a network request, such as an HTTP request. + Request request = 4; + + // This is analogous to http_request.headers, however these contents will not be sent to the + // upstream server. Context_extensions provide an extension mechanism for sending additional + // information to the auth server without modifying the proto definition. It maps to the + // internal opaque context in the filter chain. + map context_extensions = 10; + + // Dynamic metadata associated with the request. + config.core.v3.Metadata metadata_context = 11; + + // Metadata associated with the selected route. + config.core.v3.Metadata route_metadata_context = 13; + + // TLS session details of the underlying connection. + // This is not populated by default and will be populated only if the ext_authz filter has + // been specifically configured to include this information. + // For HTTP ext_authz, that requires :ref:`include_tls_session ` + // to be set to true. + // For network ext_authz, that requires :ref:`include_tls_session ` + // to be set to true. + TLSSession tls_session = 12; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/service/auth/v3/external_auth.proto b/xds/third_party/envoy/src/main/proto/envoy/service/auth/v3/external_auth.proto new file mode 100644 index 00000000000..520a4ff4f31 --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/service/auth/v3/external_auth.proto @@ -0,0 +1,157 @@ +syntax = "proto3"; + +package envoy.service.auth.v3; + +import "envoy/config/core/v3/base.proto"; +import "envoy/service/auth/v3/attribute_context.proto"; +import "envoy/type/v3/http_status.proto"; + +import "google/protobuf/struct.proto"; +import "google/rpc/status.proto"; + +import "envoy/annotations/deprecation.proto"; +import "udpa/annotations/status.proto"; +import "udpa/annotations/versioning.proto"; + +option java_package = "io.envoyproxy.envoy.service.auth.v3"; +option java_outer_classname = "ExternalAuthProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3;authv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: Authorization service] + +// The authorization service request messages used by external authorization :ref:`network filter +// ` and :ref:`HTTP filter `. + +// A generic interface for performing authorization check on incoming +// requests to a networked service. +service Authorization { + // Performs authorization check based on the attributes associated with the + // incoming request, and returns status `OK` or not `OK`. + rpc Check(CheckRequest) returns (CheckResponse) { + } +} + +message CheckRequest { + option (udpa.annotations.versioning).previous_message_type = "envoy.service.auth.v2.CheckRequest"; + + // The request attributes. + AttributeContext attributes = 1; +} + +// HTTP attributes for a denied response. +message DeniedHttpResponse { + option (udpa.annotations.versioning).previous_message_type = + "envoy.service.auth.v2.DeniedHttpResponse"; + + // This field allows the authorization service to send an HTTP response status code to the + // downstream client. If not set, Envoy sends ``403 Forbidden`` HTTP status code by default. + type.v3.HttpStatus status = 1; + + // This field allows the authorization service to send HTTP response headers + // to the downstream client. Note that the :ref:`append field in HeaderValueOption ` defaults to + // false when used in this message. + repeated config.core.v3.HeaderValueOption headers = 2; + + // This field allows the authorization service to send a response body data + // to the downstream client. + string body = 3; +} + +// HTTP attributes for an OK response. +// [#next-free-field: 9] +message OkHttpResponse { + option (udpa.annotations.versioning).previous_message_type = + "envoy.service.auth.v2.OkHttpResponse"; + + // HTTP entity headers in addition to the original request headers. This allows the authorization + // service to append, to add or to override headers from the original request before + // dispatching it to the upstream. Note that the :ref:`append field in HeaderValueOption ` defaults to + // false when used in this message. By setting the ``append`` field to ``true``, + // the filter will append the correspondent header value to the matched request header. + // By leaving ``append`` as false, the filter will either add a new header, or override an existing + // one if there is a match. + repeated config.core.v3.HeaderValueOption headers = 2; + + // HTTP entity headers to remove from the original request before dispatching + // it to the upstream. This allows the authorization service to act on auth + // related headers (like ``Authorization``), process them, and consume them. + // Under this model, the upstream will either receive the request (if it's + // authorized) or not receive it (if it's not), but will not see headers + // containing authorization credentials. + // + // Pseudo headers (such as ``:authority``, ``:method``, ``:path`` etc), as well as + // the header ``Host``, may not be removed as that would make the request + // malformed. If mentioned in ``headers_to_remove`` these special headers will + // be ignored. + // + // When using the HTTP service this must instead be set by the HTTP + // authorization service as a comma separated list like so: + // ``x-envoy-auth-headers-to-remove: one-auth-header, another-auth-header``. + repeated string headers_to_remove = 5; + + // This field has been deprecated in favor of :ref:`CheckResponse.dynamic_metadata + // `. Until it is removed, + // setting this field overrides :ref:`CheckResponse.dynamic_metadata + // `. + google.protobuf.Struct dynamic_metadata = 3 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; + + // This field allows the authorization service to send HTTP response headers + // to the downstream client on success. Note that the :ref:`append field in HeaderValueOption ` + // defaults to false when used in this message. + repeated config.core.v3.HeaderValueOption response_headers_to_add = 6; + + // This field allows the authorization service to set (and overwrite) query + // string parameters on the original request before it is sent upstream. + repeated config.core.v3.QueryParameter query_parameters_to_set = 7; + + // This field allows the authorization service to specify which query parameters + // should be removed from the original request before it is sent upstream. Each + // element in this list is a case-sensitive query parameter name to be removed. + repeated string query_parameters_to_remove = 8; +} + +// Intended for gRPC and Network Authorization servers ``only``. +// [#next-free-field: 6] +message CheckResponse { + option (udpa.annotations.versioning).previous_message_type = + "envoy.service.auth.v2.CheckResponse"; + + // Status ``OK`` allows the request. Any other status indicates the request should be denied, and + // for HTTP filter, if not overridden by :ref:`denied HTTP response status ` + // Envoy sends ``403 Forbidden`` HTTP status code by default. + google.rpc.Status status = 1; + + // An message that contains HTTP response attributes. This message is + // used when the authorization service needs to send custom responses to the + // downstream client or, to modify/add request headers being dispatched to the upstream. + oneof http_response { + // Supplies http attributes for a denied response. + DeniedHttpResponse denied_response = 2; + + // Supplies http attributes for an ok response. + OkHttpResponse ok_response = 3; + + // Supplies http attributes for an error response. This is used when the authorization + // service encounters an internal error and wants to return custom headers and body to the + // downstream client. When ``error_response`` is set, the ext_authz filter increments the + // ``ext_authz_error`` stat and respects the :ref:`failure_mode_allow + // ` + // configuration. The HTTP status code, headers, and body are taken from the + // :ref:`DeniedHttpResponse ` message. + // If the status field is not set, Envoy sends the status code configured via + // :ref:`status_on_error `, + // which defaults to ``403 Forbidden``. + DeniedHttpResponse error_response = 5; + } + + // Optional response metadata that will be emitted as dynamic metadata to be consumed by the next + // filter. This metadata lives in a namespace specified by the canonical name of extension filter + // that requires it: + // + // - :ref:`envoy.filters.http.ext_authz ` for HTTP filter. + // - :ref:`envoy.filters.network.ext_authz ` for network filter. + google.protobuf.Struct dynamic_metadata = 4; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/service/discovery/v3/discovery.proto b/xds/third_party/envoy/src/main/proto/envoy/service/discovery/v3/discovery.proto index b7270f246de..e1ce827a48f 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/service/discovery/v3/discovery.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/service/discovery/v3/discovery.proto @@ -41,18 +41,29 @@ message ResourceName { DynamicParameterConstraints dynamic_parameter_constraints = 2; } +// [#not-implemented-hide:] +// An error associated with a specific resource name, returned to the +// client by the server. +message ResourceError { + // The name of the resource. + ResourceName resource_name = 1; + + // The error reported for the resource. + google.rpc.Status error_detail = 2; +} + // A DiscoveryRequest requests a set of versioned resources of the same type for // a given Envoy node on some API. // [#next-free-field: 8] message DiscoveryRequest { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.DiscoveryRequest"; - // The version_info provided in the request messages will be the version_info + // The ``version_info`` provided in the request messages will be the ``version_info`` // received with the most recent successfully processed response or empty on // the first request. It is expected that no new request is sent after a // response is received until the Envoy instance is ready to ACK/NACK the new // configuration. ACK/NACK takes place by returning the new API config version - // as applied or the previous API config version respectively. Each type_url + // as applied or the previous API config version respectively. Each ``type_url`` // (see below) has an independent version associated with it. string version_info = 1; @@ -61,10 +72,10 @@ message DiscoveryRequest { // List of resources to subscribe to, e.g. list of cluster names or a route // configuration name. If this is empty, all resources for the API are - // returned. LDS/CDS may have empty resource_names, which will cause all + // returned. LDS/CDS may have empty ``resource_names``, which will cause all // resources for the Envoy instance to be returned. The LDS and CDS responses // will then imply a number of resources that need to be fetched via EDS/RDS, - // which will be explicitly enumerated in resource_names. + // which will be explicitly enumerated in ``resource_names``. repeated string resource_names = 3; // [#not-implemented-hide:] @@ -72,21 +83,27 @@ message DiscoveryRequest { // parameters along with each resource name. Clients that populate this // field must be able to handle responses from the server where resources // are wrapped in a Resource message. - // Note that it is legal for a request to have some resources listed - // in ``resource_names`` and others in ``resource_locators``. + // + // .. note:: + // It is legal for a request to have some resources listed + // in ``resource_names`` and others in ``resource_locators``. + // repeated ResourceLocator resource_locators = 7; // Type of the resource that is being requested, e.g. - // "type.googleapis.com/envoy.api.v2.ClusterLoadAssignment". This is implicit + // ``type.googleapis.com/envoy.api.v2.ClusterLoadAssignment``. This is implicit // in requests made via singleton xDS APIs such as CDS, LDS, etc. but is // required for ADS. string type_url = 4; - // nonce corresponding to DiscoveryResponse being ACK/NACKed. See above - // discussion on version_info and the DiscoveryResponse nonce comment. This - // may be empty only if 1) this is a non-persistent-stream xDS such as HTTP, - // or 2) the client has not yet accepted an update in this xDS stream (unlike - // delta, where it is populated only for new explicit ACKs). + // nonce corresponding to ``DiscoveryResponse`` being ACK/NACKed. See above + // discussion on ``version_info`` and the ``DiscoveryResponse`` nonce comment. This + // may be empty only if: + // + // * This is a non-persistent-stream xDS such as HTTP, or + // * The client has not yet accepted an update in this xDS stream (unlike + // delta, where it is populated only for new explicit ACKs). + // string response_nonce = 5; // This is populated when the previous :ref:`DiscoveryResponse ` @@ -96,7 +113,7 @@ message DiscoveryRequest { google.rpc.Status error_detail = 6; } -// [#next-free-field: 7] +// [#next-free-field: 8] message DiscoveryResponse { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.DiscoveryResponse"; @@ -109,35 +126,46 @@ message DiscoveryResponse { // [#not-implemented-hide:] // Canary is used to support two Envoy command line flags: // - // * --terminate-on-canary-transition-failure. When set, Envoy is able to + // * ``--terminate-on-canary-transition-failure``. When set, Envoy is able to // terminate if it detects that configuration is stuck at canary. Consider // this example sequence of updates: - // - Management server applies a canary config successfully. - // - Management server rolls back to a production config. - // - Envoy rejects the new production config. + // + // * Management server applies a canary config successfully. + // * Management server rolls back to a production config. + // * Envoy rejects the new production config. + // // Since there is no sensible way to continue receiving configuration // updates, Envoy will then terminate and apply production config from a // clean slate. - // * --dry-run-canary. When set, a canary response will never be applied, only + // + // * ``--dry-run-canary``. When set, a canary response will never be applied, only // validated via a dry run. + // bool canary = 3; // Type URL for resources. Identifies the xDS API when muxing over ADS. - // Must be consistent with the type_url in the 'resources' repeated Any (if non-empty). + // Must be consistent with the ``type_url`` in the 'resources' repeated Any (if non-empty). string type_url = 4; // For gRPC based subscriptions, the nonce provides a way to explicitly ack a - // specific DiscoveryResponse in a following DiscoveryRequest. Additional + // specific ``DiscoveryResponse`` in a following ``DiscoveryRequest``. Additional // messages may have been sent by Envoy to the management server for the - // previous version on the stream prior to this DiscoveryResponse, that were + // previous version on the stream prior to this ``DiscoveryResponse``, that were // unprocessed at response send time. The nonce allows the management server - // to ignore any further DiscoveryRequests for the previous version until a - // DiscoveryRequest bearing the nonce. The nonce is optional and is not + // to ignore any further ``DiscoveryRequests`` for the previous version until a + // ``DiscoveryRequest`` bearing the nonce. The nonce is optional and is not // required for non-stream based xDS implementations. string nonce = 5; // The control plane instance that sent the response. config.core.v3.ControlPlane control_plane = 6; + + // [#not-implemented-hide:] + // Errors associated with specific resources. Clients are expected to + // remember the most recent error for a given resource across responses; + // the error condition is not considered to be cleared until a response is + // received that contains the resource in the 'resources' field. + repeated ResourceError resource_errors = 7; } // DeltaDiscoveryRequest and DeltaDiscoveryResponse are used in a new gRPC @@ -153,25 +181,28 @@ message DiscoveryResponse { // connected to it. // // In Delta xDS the nonce field is required and used to pair -// DeltaDiscoveryResponse to a DeltaDiscoveryRequest ACK or NACK. -// Optionally, a response message level system_version_info is present for +// ``DeltaDiscoveryResponse`` to a ``DeltaDiscoveryRequest`` ACK or NACK. +// Optionally, a response message level ``system_version_info`` is present for // debugging purposes only. // -// DeltaDiscoveryRequest plays two independent roles. Any DeltaDiscoveryRequest -// can be either or both of: [1] informing the server of what resources the -// client has gained/lost interest in (using resource_names_subscribe and -// resource_names_unsubscribe), or [2] (N)ACKing an earlier resource update from -// the server (using response_nonce, with presence of error_detail making it a NACK). -// Additionally, the first message (for a given type_url) of a reconnected gRPC stream +// ``DeltaDiscoveryRequest`` plays two independent roles. Any ``DeltaDiscoveryRequest`` +// can be either or both of: +// +// * Informing the server of what resources the client has gained/lost interest in +// (using ``resource_names_subscribe`` and ``resource_names_unsubscribe``), or +// * (N)ACKing an earlier resource update from the server (using ``response_nonce``, +// with presence of ``error_detail`` making it a NACK). +// +// Additionally, the first message (for a given ``type_url``) of a reconnected gRPC stream // has a third role: informing the server of the resources (and their versions) -// that the client already possesses, using the initial_resource_versions field. +// that the client already possesses, using the ``initial_resource_versions`` field. // // As with state-of-the-world, when multiple resource types are multiplexed (ADS), -// all requests/acknowledgments/updates are logically walled off by type_url: +// all requests/acknowledgments/updates are logically walled off by ``type_url``: // a Cluster ACK exists in a completely separate world from a prior Route NACK. -// In particular, initial_resource_versions being sent at the "start" of every -// gRPC stream actually entails a message for each type_url, each with its own -// initial_resource_versions. +// In particular, ``initial_resource_versions`` being sent at the "start" of every +// gRPC stream actually entails a message for each ``type_url``, each with its own +// ``initial_resource_versions``. // [#next-free-field: 10] message DeltaDiscoveryRequest { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.DeltaDiscoveryRequest"; @@ -187,23 +218,24 @@ message DeltaDiscoveryRequest { // DeltaDiscoveryRequests allow the client to add or remove individual // resources to the set of tracked resources in the context of a stream. - // All resource names in the resource_names_subscribe list are added to the - // set of tracked resources and all resource names in the resource_names_unsubscribe + // All resource names in the ``resource_names_subscribe`` list are added to the + // set of tracked resources and all resource names in the ``resource_names_unsubscribe`` // list are removed from the set of tracked resources. // - // *Unlike* state-of-the-world xDS, an empty resource_names_subscribe or - // resource_names_unsubscribe list simply means that no resources are to be + // *Unlike* state-of-the-world xDS, an empty ``resource_names_subscribe`` or + // ``resource_names_unsubscribe`` list simply means that no resources are to be // added or removed to the resource list. // *Like* state-of-the-world xDS, the server must send updates for all tracked // resources, but can also send updates for resources the client has not subscribed to. // - // NOTE: the server must respond with all resources listed in resource_names_subscribe, - // even if it believes the client has the most recent version of them. The reason: - // the client may have dropped them, but then regained interest before it had a chance - // to send the unsubscribe message. See DeltaSubscriptionStateTest.RemoveThenAdd. + // .. note:: + // The server must respond with all resources listed in ``resource_names_subscribe``, + // even if it believes the client has the most recent version of them. The reason: + // the client may have dropped them, but then regained interest before it had a chance + // to send the unsubscribe message. See DeltaSubscriptionStateTest.RemoveThenAdd. // - // These two fields can be set in any DeltaDiscoveryRequest, including ACKs - // and initial_resource_versions. + // These two fields can be set in any ``DeltaDiscoveryRequest``, including ACKs + // and ``initial_resource_versions``. // // A list of Resource names to add to the list of tracked resources. repeated string resource_names_subscribe = 3; @@ -214,31 +246,40 @@ message DeltaDiscoveryRequest { // [#not-implemented-hide:] // Alternative to ``resource_names_subscribe`` field that allows specifying dynamic parameters // along with each resource name. - // Note that it is legal for a request to have some resources listed - // in ``resource_names_subscribe`` and others in ``resource_locators_subscribe``. + // + // .. note:: + // It is legal for a request to have some resources listed + // in ``resource_names_subscribe`` and others in ``resource_locators_subscribe``. + // repeated ResourceLocator resource_locators_subscribe = 8; // [#not-implemented-hide:] // Alternative to ``resource_names_unsubscribe`` field that allows specifying dynamic parameters // along with each resource name. - // Note that it is legal for a request to have some resources listed - // in ``resource_names_unsubscribe`` and others in ``resource_locators_unsubscribe``. + // + // .. note:: + // It is legal for a request to have some resources listed + // in ``resource_names_unsubscribe`` and others in ``resource_locators_unsubscribe``. + // repeated ResourceLocator resource_locators_unsubscribe = 9; // Informs the server of the versions of the resources the xDS client knows of, to enable the // client to continue the same logical xDS session even in the face of gRPC stream reconnection. - // It will not be populated: [1] in the very first stream of a session, since the client will - // not yet have any resources, [2] in any message after the first in a stream (for a given - // type_url), since the server will already be correctly tracking the client's state. - // (In ADS, the first message *of each type_url* of a reconnected stream populates this map.) + // It will not be populated: + // + // * In the very first stream of a session, since the client will not yet have any resources. + // * In any message after the first in a stream (for a given ``type_url``), since the server will + // already be correctly tracking the client's state. + // + // (In ADS, the first message ``of each type_url`` of a reconnected stream populates this map.) // The map's keys are names of xDS resources known to the xDS client. // The map's values are opaque resource versions. map initial_resource_versions = 5; - // When the DeltaDiscoveryRequest is a ACK or NACK message in response - // to a previous DeltaDiscoveryResponse, the response_nonce must be the - // nonce in the DeltaDiscoveryResponse. - // Otherwise (unlike in DiscoveryRequest) response_nonce must be omitted. + // When the ``DeltaDiscoveryRequest`` is a ACK or NACK message in response + // to a previous ``DeltaDiscoveryResponse``, the ``response_nonce`` must be the + // nonce in the ``DeltaDiscoveryResponse``. + // Otherwise (unlike in ``DiscoveryRequest``) ``response_nonce`` must be omitted. string response_nonce = 6; // This is populated when the previous :ref:`DiscoveryResponse ` @@ -247,7 +288,7 @@ message DeltaDiscoveryRequest { google.rpc.Status error_detail = 7; } -// [#next-free-field: 9] +// [#next-free-field: 10] message DeltaDiscoveryResponse { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.DeltaDiscoveryResponse"; @@ -256,37 +297,46 @@ message DeltaDiscoveryResponse { string system_version_info = 1; // The response resources. These are typed resources, whose types must match - // the type_url field. + // the ``type_url`` field. repeated Resource resources = 2; // field id 3 IS available! // Type URL for resources. Identifies the xDS API when muxing over ADS. - // Must be consistent with the type_url in the Any within 'resources' if 'resources' is non-empty. + // Must be consistent with the ``type_url`` in the Any within 'resources' if 'resources' is non-empty. string type_url = 4; - // Resources names of resources that have be deleted and to be removed from the xDS Client. + // Resource names of resources that have been deleted and to be removed from the xDS Client. // Removed resources for missing resources can be ignored. repeated string removed_resources = 6; - // Alternative to removed_resources that allows specifying which variant of + // Alternative to ``removed_resources`` that allows specifying which variant of // a resource is being removed. This variant must be used for any resource // for which dynamic parameter constraints were sent to the client. repeated ResourceName removed_resource_names = 8; - // The nonce provides a way for DeltaDiscoveryRequests to uniquely - // reference a DeltaDiscoveryResponse when (N)ACKing. The nonce is required. + // The nonce provides a way for ``DeltaDiscoveryRequests`` to uniquely + // reference a ``DeltaDiscoveryResponse`` when (N)ACKing. The nonce is required. string nonce = 5; // [#not-implemented-hide:] // The control plane instance that sent the response. config.core.v3.ControlPlane control_plane = 7; + + // [#not-implemented-hide:] + // Errors associated with specific resources. + // + // .. note:: + // A resource in this field with a status of NOT_FOUND should be treated the same as + // a resource listed in the ``removed_resources`` or ``removed_resource_names`` fields. + // + repeated ResourceError resource_errors = 9; } // A set of dynamic parameter constraints associated with a variant of an individual xDS resource. // These constraints determine whether the resource matches a subscription based on the set of // dynamic parameters in the subscription, as specified in the -// :ref:`ResourceLocator.dynamic_parameters` +// :ref:`ResourceLocator.dynamic_parameters ` // field. This allows xDS implementations (clients, servers, and caching proxies) to determine // which variant of a resource is appropriate for a given client. message DynamicParameterConstraints { @@ -340,8 +390,11 @@ message Resource { // [#not-implemented-hide:] message CacheControl { // If true, xDS proxies may not cache this resource. - // Note that this does not apply to clients other than xDS proxies, which must cache resources - // for their own use, regardless of the value of this field. + // + // .. note:: + // This does not apply to clients other than xDS proxies, which must cache resources + // for their own use, regardless of the value of this field. + // bool do_not_cache = 1; } @@ -371,7 +424,7 @@ message Resource { // configuration for the resource will be removed. // // The TTL can be refreshed or changed by sending a response that doesn't change the resource - // version. In this case the resource field does not need to be populated, which allows for + // version. In this case the ``resource`` field does not need to be populated, which allows for // light-weight "heartbeat" updates to keep a resource with a TTL alive. // // The TTL feature is meant to support configurations that should be removed in the event of diff --git a/xds/third_party/envoy/src/main/proto/envoy/service/rate_limit_quota/v3/rlqs.proto b/xds/third_party/envoy/src/main/proto/envoy/service/rate_limit_quota/v3/rlqs.proto new file mode 100644 index 00000000000..b8fa2cd8982 --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/service/rate_limit_quota/v3/rlqs.proto @@ -0,0 +1,258 @@ +syntax = "proto3"; + +package envoy.service.rate_limit_quota.v3; + +import "envoy/type/v3/ratelimit_strategy.proto"; + +import "google/protobuf/duration.proto"; + +import "xds/annotations/v3/status.proto"; + +import "udpa/annotations/status.proto"; +import "validate/validate.proto"; + +option java_package = "io.envoyproxy.envoy.service.rate_limit_quota.v3"; +option java_outer_classname = "RlqsProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/service/rate_limit_quota/v3;rate_limit_quotav3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; +option (xds.annotations.v3.file_status).work_in_progress = true; + +// [#protodoc-title: Rate Limit Quota Service (RLQS)] + +// The Rate Limit Quota Service (RLQS) is a Envoy global rate limiting service that allows to +// delegate rate limit decisions to a remote service. The service will aggregate the usage reports +// from multiple data plane instances, and distribute Rate Limit Assignments to each instance +// based on its business logic. The logic is outside of the scope of the protocol API. +// +// The protocol is designed as a streaming-first API. It utilizes watch-like subscription model. +// The data plane groups requests into Quota Buckets as directed by the filter config, +// and periodically reports them to the RLQS server along with the Bucket identifier, :ref:`BucketId +// `. Once RLQS server has collected enough +// reports to make a decision, it'll send back the assignment with the rate limiting instructions. +// +// The first report sent by the data plane is interpreted by the RLQS server as a "watch" request, +// indicating that the data plane instance is interested in receiving further updates for the +// ``BucketId``. From then on, RLQS server may push assignments to this instance at will, even if +// the instance is not sending usage reports. It's the responsibility of the RLQS server +// to determine when the data plane instance didn't send ``BucketId`` reports for too long, +// and to respond with the :ref:`AbandonAction +// `, +// indicating that the server has now stopped sending quota assignments for the ``BucketId`` bucket, +// and the data plane instance should :ref:`abandon +// ` +// it. +// +// If for any reason the RLQS client doesn't receive the initial assignment for the reported bucket, +// in order to prevent memory exhaustion, the data plane will limit the time such bucket +// is retained. The exact time to wait for the initial assignment is chosen by the filter, +// and may vary based on the implementation. +// Once the duration ends, the data plane will stop reporting bucket usage, reject any enqueued +// requests, and purge the bucket from the memory. Subsequent requests matched into the bucket +// will re-initialize the bucket in the "no assignment" state, restarting the reports. +// +// Refer to Rate Limit Quota :ref:`configuration overview ` +// for further details. + +// Defines the Rate Limit Quota Service (RLQS). +service RateLimitQuotaService { + // Main communication channel: the data plane sends usage reports to the RLQS server, + // and the server asynchronously responding with the assignments. + rpc StreamRateLimitQuotas(stream RateLimitQuotaUsageReports) + returns (stream RateLimitQuotaResponse) { + } +} + +message RateLimitQuotaUsageReports { + // The usage report for a bucket. + // + // .. note:: + // Note that the first report sent for a ``BucketId`` indicates to the RLQS server that + // the RLQS client is subscribing for the future assignments for this ``BucketId``. + message BucketQuotaUsage { + // ``BucketId`` for which request quota usage is reported. + BucketId bucket_id = 1 [(validate.rules).message = {required: true}]; + + // Time elapsed since the last report. + google.protobuf.Duration time_elapsed = 2 [(validate.rules).duration = { + required: true + gt {} + }]; + + // Requests the data plane has allowed through. + uint64 num_requests_allowed = 3; + + // Requests throttled. + uint64 num_requests_denied = 4; + } + + // All quota requests must specify the domain. This enables sharing the quota + // server between different applications without fear of overlap. + // E.g., "envoy". + // + // Should only be provided in the first report, all subsequent messages on the same + // stream are considered to be in the same domain. In case the domain needs to be + // changes, close the stream, and reopen a new one with the different domain. + string domain = 1 [(validate.rules).string = {min_len: 1}]; + + // A list of quota usage reports. The list is processed by the RLQS server in the same order + // it's provided by the client. + repeated BucketQuotaUsage bucket_quota_usages = 2 [(validate.rules).repeated = {min_items: 1}]; +} + +message RateLimitQuotaResponse { + // Commands the data plane to apply one of the actions to the bucket with the + // :ref:`bucket_id `. + message BucketAction { + // Quota assignment for the bucket. Configures the rate limiting strategy and the duration + // for the given :ref:`bucket_id + // `. + // + // **Applying the first assignment to the bucket** + // + // Once the data plane receives the ``QuotaAssignmentAction``, it must send the current usage + // report for the bucket, and start rate limiting requests matched into the bucket + // using the strategy configured in the :ref:`rate_limit_strategy + // ` + // field. The assignment becomes bucket's ``active`` assignment. + // + // **Expiring the assignment** + // + // The duration of the assignment defined in the :ref:`assignment_time_to_live + // ` + // field. When the duration runs off, the assignment is ``expired``, and no longer ``active``. + // The data plane should stop applying the rate limiting strategy to the bucket, and transition + // the bucket to the "expired assignment" state. This activates the behavior configured in the + // :ref:`expired_assignment_behavior ` + // field. + // + // **Replacing the assignment** + // + // * If the rate limiting strategy is different from bucket's ``active`` assignment, or + // the current bucket assignment is ``expired``, the data plane must immediately + // end the current assignment, report the bucket usage, and apply the new assignment. + // The new assignment becomes bucket's ``active`` assignment. + // * If the rate limiting strategy is the same as the bucket's ``active`` (not ``expired``) + // assignment, the data plane should extend the duration of the ``active`` assignment + // for the duration of the new assignment provided in the :ref:`assignment_time_to_live + // ` + // field. The ``active`` assignment is considered unchanged. + message QuotaAssignmentAction { + // A duration after which the assignment is be considered ``expired``. The process of the + // expiration is described :ref:`above + // `. + // + // * If unset, the assignment has no expiration date. + // * If set to ``0``, the assignment expires immediately, forcing the client into the + // :ref:`"expired assignment" + // ` + // state. This may be used by the RLQS server in cases when it needs clients to proactively + // fall back to the pre-configured :ref:`ExpiredAssignmentBehavior + // `, + // f.e. before the server going into restart. + // + // .. attention:: + // Note that :ref:`expiring + // ` + // the assignment is not the same as :ref:`abandoning + // ` + // the assignment. While expiring the assignment just transitions the bucket to + // the "expired assignment" state; abandoning the assignment completely erases + // the bucket from the data plane memory, and stops the usage reports. + google.protobuf.Duration assignment_time_to_live = 2 [(validate.rules).duration = {gte {}}]; + + // Configures the local rate limiter for the request matched to the bucket. + // If not set, allow all requests. + type.v3.RateLimitStrategy rate_limit_strategy = 3; + } + + // Abandon action for the bucket. Indicates that the RLQS server will no longer be + // sending updates for the given :ref:`bucket_id + // `. + // + // If no requests are reported for a bucket, after some time the server considers the bucket + // inactive. The server stops tracking the bucket, and instructs the the data plane to abandon + // the bucket via this message. + // + // **Abandoning the assignment** + // + // The data plane is to erase the bucket (including its usage data) from the memory. + // It should stop tracking the bucket, and stop reporting its usage. This effectively resets + // the data plane to the state prior to matching the first request into the bucket. + // + // **Restarting the subscription** + // + // If a new request is matched into a bucket previously abandoned, the data plane must behave + // as if it has never tracked the bucket, and it's the first request matched into it: + // + // 1. The process of :ref:`subscription and reporting + // ` + // starts from the beginning. + // + // 2. The bucket transitions to the :ref:`"no assignment" + // ` + // state. + // + // 3. Once the new assignment is received, it's applied per + // "Applying the first assignment to the bucket" section of the :ref:`QuotaAssignmentAction + // `. + message AbandonAction { + } + + // ``BucketId`` for which request the action is applied. + BucketId bucket_id = 1 [(validate.rules).message = {required: true}]; + + oneof bucket_action { + option (validate.required) = true; + + // Apply the quota assignment to the bucket. + // + // Commands the data plane to apply a rate limiting strategy to the bucket. + // The process of applying and expiring the rate limiting strategy is detailed in the + // :ref:`QuotaAssignmentAction + // ` + // message. + QuotaAssignmentAction quota_assignment_action = 2; + + // Abandon the bucket. + // + // Commands the data plane to abandon the bucket. + // The process of abandoning the bucket is described in the :ref:`AbandonAction + // ` + // message. + AbandonAction abandon_action = 3; + } + } + + // An ordered list of actions to be applied to the buckets. The actions are applied in the + // given order, from top to bottom. + repeated BucketAction bucket_action = 1 [(validate.rules).repeated = {min_items: 1}]; +} + +// The identifier for the bucket. Used to match the bucket between the control plane (RLQS server), +// and the data plane (RLQS client), f.e.: +// +// * the data plane sends a usage report for requests matched into the bucket with ``BucketId`` +// to the control plane +// * the control plane sends an assignment for the bucket with ``BucketId`` to the data plane +// Bucket ID. +// +// Example: +// +// .. validated-code-block:: yaml +// :type-name: envoy.service.rate_limit_quota.v3.BucketId +// +// bucket: +// name: my_bucket +// env: staging +// +// .. note:: +// The order of ``BucketId`` keys do not matter. Buckets ``{ a: 'A', b: 'B' }`` and +// ``{ b: 'B', a: 'A' }`` are identical. +message BucketId { + map bucket = 1 [(validate.rules).map = { + min_pairs: 1 + keys {string {min_len: 1}} + values {string {min_len: 1}} + }]; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/service/status/v3/csds.proto b/xds/third_party/envoy/src/main/proto/envoy/service/status/v3/csds.proto index 1c51f2bac37..de62fbf9b0f 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/service/status/v3/csds.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/service/status/v3/csds.proto @@ -72,6 +72,11 @@ enum ClientConfigStatus { // config dump is not the NACKed version, but the most recent accepted one. If // no config is accepted yet, the attached config dump will be empty. CLIENT_NACKED = 3; + + // Client received an error from the control plane. The attached config + // dump is the most recent accepted one. If no config is accepted yet, + // the attached config dump will be empty. + CLIENT_RECEIVED_ERROR = 4; } // Request for client status of clients identified by a list of NodeMatchers. diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/address.proto b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/address.proto new file mode 100644 index 00000000000..8a03a5320af --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/address.proto @@ -0,0 +1,22 @@ +syntax = "proto3"; + +package envoy.type.matcher.v3; + +import "xds/core/v3/cidr.proto"; + +import "udpa/annotations/status.proto"; + +option java_package = "io.envoyproxy.envoy.type.matcher.v3"; +option java_outer_classname = "AddressProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3;matcherv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: Address Matcher] + +// Match an IP against a repeated CIDR range. This matcher is intended to be +// used in other matchers, for example in the filter state matcher to match a +// filter state object as an IP. +message AddressMatcher { + repeated xds.core.v3.CidrRange ranges = 1; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/filter_state.proto b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/filter_state.proto index f813178ae05..8c38a515ae9 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/filter_state.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/filter_state.proto @@ -2,6 +2,7 @@ syntax = "proto3"; package envoy.type.matcher.v3; +import "envoy/type/matcher/v3/address.proto"; import "envoy/type/matcher/v3/string.proto"; import "udpa/annotations/status.proto"; @@ -25,5 +26,8 @@ message FilterStateMatcher { // Matches the filter state object as a string value. StringMatcher string_match = 2; + + // Matches the filter state object as a ip Instance. + AddressMatcher address_match = 3; } } diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/http_inputs.proto b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/http_inputs.proto new file mode 100644 index 00000000000..c90199eb618 --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/http_inputs.proto @@ -0,0 +1,71 @@ +syntax = "proto3"; + +package envoy.type.matcher.v3; + +import "udpa/annotations/status.proto"; +import "validate/validate.proto"; + +option java_package = "io.envoyproxy.envoy.type.matcher.v3"; +option java_outer_classname = "HttpInputsProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3;matcherv3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: Common HTTP inputs] + +// Match input indicates that matching should be done on a specific request header. +// The resulting input string will be all headers for the given key joined by a comma, +// e.g. if the request contains two 'foo' headers with value 'bar' and 'baz', the input +// string will be 'bar,baz'. +// [#comment:TODO(snowp): Link to unified matching docs.] +// [#extension: envoy.matching.inputs.request_headers] +message HttpRequestHeaderMatchInput { + // The request header to match on. + string header_name = 1 + [(validate.rules).string = {well_known_regex: HTTP_HEADER_NAME strict: false}]; +} + +// Match input indicates that matching should be done on a specific request trailer. +// The resulting input string will be all headers for the given key joined by a comma, +// e.g. if the request contains two 'foo' headers with value 'bar' and 'baz', the input +// string will be 'bar,baz'. +// [#comment:TODO(snowp): Link to unified matching docs.] +// [#extension: envoy.matching.inputs.request_trailers] +message HttpRequestTrailerMatchInput { + // The request trailer to match on. + string header_name = 1 + [(validate.rules).string = {well_known_regex: HTTP_HEADER_NAME strict: false}]; +} + +// Match input indicating that matching should be done on a specific response header. +// The resulting input string will be all headers for the given key joined by a comma, +// e.g. if the response contains two 'foo' headers with value 'bar' and 'baz', the input +// string will be 'bar,baz'. +// [#comment:TODO(snowp): Link to unified matching docs.] +// [#extension: envoy.matching.inputs.response_headers] +message HttpResponseHeaderMatchInput { + // The response header to match on. + string header_name = 1 + [(validate.rules).string = {well_known_regex: HTTP_HEADER_NAME strict: false}]; +} + +// Match input indicates that matching should be done on a specific response trailer. +// The resulting input string will be all headers for the given key joined by a comma, +// e.g. if the request contains two 'foo' headers with value 'bar' and 'baz', the input +// string will be 'bar,baz'. +// [#comment:TODO(snowp): Link to unified matching docs.] +// [#extension: envoy.matching.inputs.response_trailers] +message HttpResponseTrailerMatchInput { + // The response trailer to match on. + string header_name = 1 + [(validate.rules).string = {well_known_regex: HTTP_HEADER_NAME strict: false}]; +} + +// Match input indicates that matching should be done on a specific query parameter. +// The resulting input string will be the first query parameter for the value +// 'query_param'. +// [#extension: envoy.matching.inputs.query_params] +message HttpRequestQueryParamMatchInput { + // The query parameter to match on. + string query_param = 1 [(validate.rules).string = {min_len: 1}]; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/metadata.proto b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/metadata.proto index d3316e88a88..30abde97c09 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/metadata.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/metadata.proto @@ -16,11 +16,11 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Metadata matcher] -// MetadataMatcher provides a general interface to check if a given value is matched in -// :ref:`Metadata `. It uses `filter` and `path` to retrieve the value -// from the Metadata and then check if it's matched to the specified value. +// ``MetadataMatcher`` provides a general interface to check if a given value is matched in +// :ref:`Metadata `. It uses ``filter`` and ``path`` to retrieve the value +// from the ``Metadata`` and then check if it's matched to the specified value. // -// For example, for the following Metadata: +// For example, for the following ``Metadata``: // // .. code-block:: yaml // @@ -41,8 +41,8 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // - string_value: m // - string_value: n // -// The following MetadataMatcher is matched as the path [a, b, c] will retrieve a string value "pro" -// from the Metadata which is matched to the specified prefix match. +// The following ``MetadataMatcher`` is matched as the path ``[a, b, c]`` will retrieve a string value ``pro`` +// from the ``Metadata`` which is matched to the specified prefix match. // // .. code-block:: yaml // @@ -55,7 +55,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // string_match: // prefix: pr // -// The following MetadataMatcher is matched as the code will match one of the string values in the +// The following ``MetadataMatcher`` is matched as the code will match one of the string values in the // list at the path [a, t]. // // .. code-block:: yaml @@ -70,7 +70,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // string_match: // exact: m // -// An example use of MetadataMatcher is specifying additional metadata in envoy.filters.http.rbac to +// An example use of ``MetadataMatcher`` is specifying additional metadata in ``envoy.filters.http.rbac`` to // enforce access control based on dynamic metadata in a request. See :ref:`Permission // ` and :ref:`Principal // `. @@ -79,9 +79,11 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; message MetadataMatcher { option (udpa.annotations.versioning).previous_message_type = "envoy.type.matcher.MetadataMatcher"; - // Specifies the segment in a path to retrieve value from Metadata. - // Note: Currently it's not supported to retrieve a value from a list in Metadata. This means that - // if the segment key refers to a list, it has to be the last segment in a path. + // Specifies the segment in a path to retrieve value from ``Metadata``. + // + // .. note:: + // Currently it's not supported to retrieve a value from a list in ``Metadata``. This means that + // if the segment key refers to a list, it has to be the last segment in a path. message PathSegment { option (udpa.annotations.versioning).previous_message_type = "envoy.type.matcher.MetadataMatcher.PathSegment"; @@ -89,18 +91,18 @@ message MetadataMatcher { oneof segment { option (validate.required) = true; - // If specified, use the key to retrieve the value in a Struct. + // If specified, use the key to retrieve the value in a ``Struct``. string key = 1 [(validate.rules).string = {min_len: 1}]; } } - // The filter name to retrieve the Struct from the Metadata. + // The filter name to retrieve the ``Struct`` from the ``Metadata``. string filter = 1 [(validate.rules).string = {min_len: 1}]; - // The path to retrieve the Value from the Struct. + // The path to retrieve the ``Value`` from the ``Struct``. repeated PathSegment path = 2 [(validate.rules).repeated = {min_items: 1}]; - // The MetadataMatcher is matched if the value retrieved by path is matched to this value. + // The ``MetadataMatcher`` is matched if the value retrieved by path is matched to this value. ValueMatcher value = 3 [(validate.rules).message = {required: true}]; // If true, the match result will be inverted. diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/string.proto b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/string.proto index 2df1bd37a6a..56d39565ca5 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/string.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/string.proto @@ -4,6 +4,8 @@ package envoy.type.matcher.v3; import "envoy/type/matcher/v3/regex.proto"; +import "xds/core/v3/extension.proto"; + import "udpa/annotations/status.proto"; import "udpa/annotations/versioning.proto"; import "validate/validate.proto"; @@ -17,7 +19,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: String matcher] // Specifies the way to match a string. -// [#next-free-field: 8] +// [#next-free-field: 9] message StringMatcher { option (udpa.annotations.versioning).previous_message_type = "envoy.type.matcher.StringMatcher"; @@ -36,7 +38,10 @@ message StringMatcher { string exact = 1; // The input string must have the prefix specified here. - // Note: empty prefix is not allowed, please use regex instead. + // + // .. note:: + // + // Empty prefix match is not allowed, please use ``safe_regex`` instead. // // Examples: // @@ -44,7 +49,10 @@ message StringMatcher { string prefix = 2 [(validate.rules).string = {min_len: 1}]; // The input string must have the suffix specified here. - // Note: empty prefix is not allowed, please use regex instead. + // + // .. note:: + // + // Empty suffix match is not allowed, please use ``safe_regex`` instead. // // Examples: // @@ -55,17 +63,25 @@ message StringMatcher { RegexMatcher safe_regex = 5 [(validate.rules).message = {required: true}]; // The input string must have the substring specified here. - // Note: empty contains match is not allowed, please use regex instead. + // + // .. note:: + // + // Empty contains match is not allowed, please use ``safe_regex`` instead. // // Examples: // // * ``abc`` matches the value ``xyz.abc.def`` string contains = 7 [(validate.rules).string = {min_len: 1}]; + + // Use an extension as the matcher type. + // [#extension-category: envoy.string_matcher] + xds.core.v3.TypedExtensionConfig custom = 8; } - // If true, indicates the exact/prefix/suffix/contains matching should be case insensitive. This - // has no effect for the safe_regex match. - // For example, the matcher ``data`` will match both input string ``Data`` and ``data`` if set to true. + // If ``true``, indicates the exact/prefix/suffix/contains matching should be case insensitive. This + // has no effect for the ``safe_regex`` match. + // For example, the matcher ``data`` will match both input string ``Data`` and ``data`` if this option + // is set to ``true``. bool ignore_case = 6; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/value.proto b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/value.proto index d773c6057fc..8d65c457ccc 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/value.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/value.proto @@ -17,7 +17,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Value matcher] -// Specifies the way to match a ProtobufWkt::Value. Primitive values and ListValue are supported. +// Specifies the way to match a Protobuf::Value. Primitive values and ListValue are supported. // StructValue is not supported and is always not matched. // [#next-free-field: 8] message ValueMatcher { diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/metadata/v3/metadata.proto b/xds/third_party/envoy/src/main/proto/envoy/type/metadata/v3/metadata.proto index 20758577503..d131635bf9f 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/metadata/v3/metadata.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/metadata/v3/metadata.proto @@ -14,10 +14,10 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Metadata] -// MetadataKey provides a general interface using ``key`` and ``path`` to retrieve value from -// :ref:`Metadata `. +// MetadataKey provides a way to retrieve values from +// :ref:`Metadata ` using a ``key`` and a ``path``. // -// For example, for the following Metadata: +// For example, consider the following Metadata: // // .. code-block:: yaml // @@ -28,7 +28,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // xyz: // hello: envoy // -// The following MetadataKey will retrieve a string value "bar" from the Metadata. +// The following MetadataKey would retrieve the string value "bar" from the Metadata: // // .. code-block:: yaml // @@ -40,8 +40,8 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; message MetadataKey { option (udpa.annotations.versioning).previous_message_type = "envoy.type.metadata.v2.MetadataKey"; - // Specifies the segment in a path to retrieve value from Metadata. - // Currently it is only supported to specify the key, i.e. field name, as one segment of a path. + // Specifies a segment in a path for retrieving values from Metadata. + // Currently, only key-based segments (field names) are supported. message PathSegment { option (udpa.annotations.versioning).previous_message_type = "envoy.type.metadata.v2.MetadataKey.PathSegment"; @@ -49,25 +49,27 @@ message MetadataKey { oneof segment { option (validate.required) = true; - // If specified, use the key to retrieve the value in a Struct. + // If specified, use this key to retrieve the value in a Struct. string key = 1 [(validate.rules).string = {min_len: 1}]; } } - // The key name of Metadata to retrieve the Struct from the metadata. - // Typically, it represents a builtin subsystem or custom extension. + // The key name of the Metadata from which to retrieve the Struct. + // This typically represents a builtin subsystem or custom extension. string key = 1 [(validate.rules).string = {min_len: 1}]; - // The path to retrieve the Value from the Struct. It can be a prefix or a full path, - // e.g. ``[prop, xyz]`` for a struct or ``[prop, foo]`` for a string in the example, - // which depends on the particular scenario. + // The path used to retrieve a specific Value from the Struct. + // This can be either a prefix or a full path, depending on the use case. + // For example, ``[prop, xyz]`` would retrieve a struct or ``[prop, foo]`` would retrieve a string + // in the example above. // - // Note: Due to that only the key type segment is supported, the path can not specify a list - // unless the list is the last segment. + // .. note:: + // Since only key-type segments are supported, a path cannot specify a list + // unless the list is the last segment. repeated PathSegment path = 2 [(validate.rules).repeated = {min_items: 1}]; } -// Describes what kind of metadata. +// Describes different types of metadata sources. message MetadataKind { option (udpa.annotations.versioning).previous_message_type = "envoy.type.metadata.v2.MetadataKind"; diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/tracing/v3/custom_tag.proto b/xds/third_party/envoy/src/main/proto/envoy/type/tracing/v3/custom_tag.proto index feb57e8eb66..cdb42a43507 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/tracing/v3/custom_tag.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/tracing/v3/custom_tag.proto @@ -17,7 +17,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Custom Tag] // Describes custom tags for the active span. -// [#next-free-field: 6] +// [#next-free-field: 7] message CustomTag { option (udpa.annotations.versioning).previous_message_type = "envoy.type.tracing.v2.CustomTag"; @@ -98,5 +98,12 @@ message CustomTag { // A custom tag to obtain tag value from the metadata. Metadata metadata = 5; + + // Custom tag value. + // + // The same :ref:`format specifier ` as used for + // :ref:`HTTP access logging ` applies here, however + // unknown specifier values are replaced with the empty string instead of ``-``. + string value = 6; } } diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/v3/http_status.proto b/xds/third_party/envoy/src/main/proto/envoy/type/v3/http_status.proto new file mode 100644 index 00000000000..40d697beefc --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/type/v3/http_status.proto @@ -0,0 +1,199 @@ +syntax = "proto3"; + +package envoy.type.v3; + +import "udpa/annotations/status.proto"; +import "udpa/annotations/versioning.proto"; +import "validate/validate.proto"; + +option java_package = "io.envoyproxy.envoy.type.v3"; +option java_outer_classname = "HttpStatusProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type/v3;typev3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: HTTP status codes] + +// HTTP response codes supported in Envoy. +// For more details: https://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml +enum StatusCode { + // Empty - This code not part of the HTTP status code specification, but it is needed for proto + // `enum` type. + Empty = 0; + + // Continue - ``100`` status code. + Continue = 100; + + // OK - ``200`` status code. + OK = 200; + + // Created - ``201`` status code. + Created = 201; + + // Accepted - ``202`` status code. + Accepted = 202; + + // NonAuthoritativeInformation - ``203`` status code. + NonAuthoritativeInformation = 203; + + // NoContent - ``204`` status code. + NoContent = 204; + + // ResetContent - ``205`` status code. + ResetContent = 205; + + // PartialContent - ``206`` status code. + PartialContent = 206; + + // MultiStatus - ``207`` status code. + MultiStatus = 207; + + // AlreadyReported - ``208`` status code. + AlreadyReported = 208; + + // IMUsed - ``226`` status code. + IMUsed = 226; + + // MultipleChoices - ``300`` status code. + MultipleChoices = 300; + + // MovedPermanently - ``301`` status code. + MovedPermanently = 301; + + // Found - ``302`` status code. + Found = 302; + + // SeeOther - ``303`` status code. + SeeOther = 303; + + // NotModified - ``304`` status code. + NotModified = 304; + + // UseProxy - ``305`` status code. + UseProxy = 305; + + // TemporaryRedirect - ``307`` status code. + TemporaryRedirect = 307; + + // PermanentRedirect - ``308`` status code. + PermanentRedirect = 308; + + // BadRequest - ``400`` status code. + BadRequest = 400; + + // Unauthorized - ``401`` status code. + Unauthorized = 401; + + // PaymentRequired - ``402`` status code. + PaymentRequired = 402; + + // Forbidden - ``403`` status code. + Forbidden = 403; + + // NotFound - ``404`` status code. + NotFound = 404; + + // MethodNotAllowed - ``405`` status code. + MethodNotAllowed = 405; + + // NotAcceptable - ``406`` status code. + NotAcceptable = 406; + + // ProxyAuthenticationRequired - ``407`` status code. + ProxyAuthenticationRequired = 407; + + // RequestTimeout - ``408`` status code. + RequestTimeout = 408; + + // Conflict - ``409`` status code. + Conflict = 409; + + // Gone - ``410`` status code. + Gone = 410; + + // LengthRequired - ``411`` status code. + LengthRequired = 411; + + // PreconditionFailed - ``412`` status code. + PreconditionFailed = 412; + + // PayloadTooLarge - ``413`` status code. + PayloadTooLarge = 413; + + // URITooLong - ``414`` status code. + URITooLong = 414; + + // UnsupportedMediaType - ``415`` status code. + UnsupportedMediaType = 415; + + // RangeNotSatisfiable - ``416`` status code. + RangeNotSatisfiable = 416; + + // ExpectationFailed - ``417`` status code. + ExpectationFailed = 417; + + // MisdirectedRequest - ``421`` status code. + MisdirectedRequest = 421; + + // UnprocessableEntity - ``422`` status code. + UnprocessableEntity = 422; + + // Locked - ``423`` status code. + Locked = 423; + + // FailedDependency - ``424`` status code. + FailedDependency = 424; + + // UpgradeRequired - ``426`` status code. + UpgradeRequired = 426; + + // PreconditionRequired - ``428`` status code. + PreconditionRequired = 428; + + // TooManyRequests - ``429`` status code. + TooManyRequests = 429; + + // RequestHeaderFieldsTooLarge - ``431`` status code. + RequestHeaderFieldsTooLarge = 431; + + // InternalServerError - ``500`` status code. + InternalServerError = 500; + + // NotImplemented - ``501`` status code. + NotImplemented = 501; + + // BadGateway - ``502`` status code. + BadGateway = 502; + + // ServiceUnavailable - ``503`` status code. + ServiceUnavailable = 503; + + // GatewayTimeout - ``504`` status code. + GatewayTimeout = 504; + + // HTTPVersionNotSupported - ``505`` status code. + HTTPVersionNotSupported = 505; + + // VariantAlsoNegotiates - ``506`` status code. + VariantAlsoNegotiates = 506; + + // InsufficientStorage - ``507`` status code. + InsufficientStorage = 507; + + // LoopDetected - ``508`` status code. + LoopDetected = 508; + + // NotExtended - ``510`` status code. + NotExtended = 510; + + // NetworkAuthenticationRequired - ``511`` status code. + NetworkAuthenticationRequired = 511; +} + +// HTTP status. +message HttpStatus { + option (udpa.annotations.versioning).previous_message_type = "envoy.type.HttpStatus"; + + // Supplies HTTP response code. + StatusCode code = 1 [(validate.rules).enum = {defined_only: true not_in: 0}]; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/v3/ratelimit_strategy.proto b/xds/third_party/envoy/src/main/proto/envoy/type/v3/ratelimit_strategy.proto new file mode 100644 index 00000000000..a86da55b854 --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/type/v3/ratelimit_strategy.proto @@ -0,0 +1,79 @@ +syntax = "proto3"; + +package envoy.type.v3; + +import "envoy/type/v3/ratelimit_unit.proto"; +import "envoy/type/v3/token_bucket.proto"; + +import "xds/annotations/v3/status.proto"; + +import "udpa/annotations/status.proto"; +import "validate/validate.proto"; + +option java_package = "io.envoyproxy.envoy.type.v3"; +option java_outer_classname = "RatelimitStrategyProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type/v3;typev3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; +option (xds.annotations.v3.file_status).work_in_progress = true; + +// [#protodoc-title: Rate Limit Strategies] + +message RateLimitStrategy { + // Choose between allow all and deny all. + enum BlanketRule { + ALLOW_ALL = 0; + DENY_ALL = 1; + } + + // Best-effort limit of the number of requests per time unit. + // + // Allows to specify the desired requests per second (RPS, QPS), requests per minute (QPM, RPM), + // etc., without specifying a rate limiting algorithm implementation. + // + // ``RequestsPerTimeUnit`` strategy does not demand any specific rate limiting algorithm to be + // used (in contrast to the :ref:`TokenBucket `, + // for example). It implies that the implementation details of rate limiting algorithm are + // irrelevant as long as the configured number of "requests per time unit" is achieved. + // + // Note that the ``TokenBucket`` is still a valid implementation of the ``RequestsPerTimeUnit`` + // strategy, and may be chosen to enforce the rate limit. However, there's no guarantee it will be + // the ``TokenBucket`` in particular, and not the Leaky Bucket, the Sliding Window, or any other + // rate limiting algorithm that fulfills the requirements. + message RequestsPerTimeUnit { + // The desired number of requests per :ref:`time_unit + // ` to allow. + // If set to ``0``, deny all (equivalent to ``BlanketRule.DENY_ALL``). + // + // .. note:: + // Note that the algorithm implementation determines the course of action for the requests + // over the limit. As long as the ``requests_per_time_unit`` converges on the desired value, + // it's allowed to treat this field as a soft-limit: allow bursts, redistribute the allowance + // over time, etc. + // + uint64 requests_per_time_unit = 1; + + // The unit of time. Ignored when :ref:`requests_per_time_unit + // ` + // is ``0`` (deny all). + RateLimitUnit time_unit = 2 [(validate.rules).enum = {defined_only: true}]; + } + + oneof strategy { + option (validate.required) = true; + + // Allow or Deny the requests. + // If unset, allow all. + BlanketRule blanket_rule = 1 [(validate.rules).enum = {defined_only: true}]; + + // Best-effort limit of the number of requests per time unit, f.e. requests per second. + // Does not prescribe any specific rate limiting algorithm, see :ref:`RequestsPerTimeUnit + // ` for details. + RequestsPerTimeUnit requests_per_time_unit = 2; + + // Limit the requests by consuming tokens from the Token Bucket. + // Allow the same number of requests as the number of tokens available in + // the token bucket. + TokenBucket token_bucket = 3; + } +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/v3/ratelimit_unit.proto b/xds/third_party/envoy/src/main/proto/envoy/type/v3/ratelimit_unit.proto new file mode 100644 index 00000000000..1a96497926d --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/type/v3/ratelimit_unit.proto @@ -0,0 +1,37 @@ +syntax = "proto3"; + +package envoy.type.v3; + +import "udpa/annotations/status.proto"; + +option java_package = "io.envoyproxy.envoy.type.v3"; +option java_outer_classname = "RatelimitUnitProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type/v3;typev3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: Ratelimit Time Unit] + +// Identifies the unit of of time for rate limit. +enum RateLimitUnit { + // The time unit is not known. + UNKNOWN = 0; + + // The time unit representing a second. + SECOND = 1; + + // The time unit representing a minute. + MINUTE = 2; + + // The time unit representing an hour. + HOUR = 3; + + // The time unit representing a day. + DAY = 4; + + // The time unit representing a month. + MONTH = 5; + + // The time unit representing a year. + YEAR = 6; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/v3/token_bucket.proto b/xds/third_party/envoy/src/main/proto/envoy/type/v3/token_bucket.proto new file mode 100644 index 00000000000..157a271efc9 --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/type/v3/token_bucket.proto @@ -0,0 +1,39 @@ +syntax = "proto3"; + +package envoy.type.v3; + +import "google/protobuf/duration.proto"; +import "google/protobuf/wrappers.proto"; + +import "udpa/annotations/status.proto"; +import "udpa/annotations/versioning.proto"; +import "validate/validate.proto"; + +option java_package = "io.envoyproxy.envoy.type.v3"; +option java_outer_classname = "TokenBucketProto"; +option java_multiple_files = true; +option go_package = "github.com/envoyproxy/go-control-plane/envoy/type/v3;typev3"; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: Token bucket] + +// Configures a token bucket, typically used for rate limiting. +message TokenBucket { + option (udpa.annotations.versioning).previous_message_type = "envoy.type.TokenBucket"; + + // The maximum tokens that the bucket can hold. This is also the number of tokens that the bucket + // initially contains. + uint32 max_tokens = 1 [(validate.rules).uint32 = {gt: 0}]; + + // The number of tokens added to the bucket during each fill interval. If not specified, defaults + // to a single token. + google.protobuf.UInt32Value tokens_per_fill = 2 [(validate.rules).uint32 = {gt: 0}]; + + // The fill interval that tokens are added to the bucket. During each fill interval + // ``tokens_per_fill`` are added to the bucket. The bucket will never contain more than + // ``max_tokens`` tokens. + google.protobuf.Duration fill_interval = 3 [(validate.rules).duration = { + required: true + gt {} + }]; +} diff --git a/xds/third_party/xds/import.sh b/xds/third_party/xds/import.sh index cda86d0368f..7af5c8489d1 100755 --- a/xds/third_party/xds/import.sh +++ b/xds/third_party/xds/import.sh @@ -17,7 +17,7 @@ set -e # import VERSION from one of the google internal CLs -VERSION=e9ce68804cb4e64cab5a52e3c8baf840d4ff87b7 +VERSION=2ac532fd44436293585084f8d94c6bdb17835af0 DOWNLOAD_URL="https://github.com/cncf/xds/archive/${VERSION}.tar.gz" DOWNLOAD_BASE_DIR="xds-${VERSION}" SOURCE_PROTO_BASE_DIR="${DOWNLOAD_BASE_DIR}" @@ -40,14 +40,18 @@ xds/annotations/v3/versioning.proto xds/core/v3/authority.proto xds/core/v3/collection_entry.proto xds/core/v3/context_params.proto +xds/core/v3/cidr.proto xds/core/v3/extension.proto xds/core/v3/resource_locator.proto xds/core/v3/resource_name.proto xds/data/orca/v3/orca_load_report.proto xds/service/orca/v3/orca.proto +xds/type/matcher/v3/cel.proto xds/type/matcher/v3/matcher.proto xds/type/matcher/v3/regex.proto xds/type/matcher/v3/string.proto +xds/type/v3/cel.proto +xds/type/matcher/v3/http_inputs.proto xds/type/v3/typed_struct.proto ) diff --git a/xds/third_party/xds/src/main/proto/udpa/annotations/migrate.proto b/xds/third_party/xds/src/main/proto/udpa/annotations/migrate.proto index 5289cb8a742..5f5f389b7d2 100644 --- a/xds/third_party/xds/src/main/proto/udpa/annotations/migrate.proto +++ b/xds/third_party/xds/src/main/proto/udpa/annotations/migrate.proto @@ -8,7 +8,7 @@ package udpa.annotations; import "google/protobuf/descriptor.proto"; -option go_package = "github.com/cncf/xds/go/annotations"; +option go_package = "github.com/cncf/xds/go/udpa/annotations"; // Magic number in this file derived from top 28bit of SHA256 digest of // "udpa.annotation.migrate". diff --git a/xds/third_party/xds/src/main/proto/udpa/annotations/security.proto b/xds/third_party/xds/src/main/proto/udpa/annotations/security.proto index 52801d30d1e..0ef919716da 100644 --- a/xds/third_party/xds/src/main/proto/udpa/annotations/security.proto +++ b/xds/third_party/xds/src/main/proto/udpa/annotations/security.proto @@ -10,7 +10,7 @@ import "udpa/annotations/status.proto"; import "google/protobuf/descriptor.proto"; -option go_package = "github.com/cncf/xds/go/annotations"; +option go_package = "github.com/cncf/xds/go/udpa/annotations"; // All annotations in this file are experimental and subject to change. Their // only consumer today is the Envoy APIs and SecuritAnnotationValidator protoc diff --git a/xds/third_party/xds/src/main/proto/udpa/annotations/sensitive.proto b/xds/third_party/xds/src/main/proto/udpa/annotations/sensitive.proto index ab822fb4884..c7d8af608be 100644 --- a/xds/third_party/xds/src/main/proto/udpa/annotations/sensitive.proto +++ b/xds/third_party/xds/src/main/proto/udpa/annotations/sensitive.proto @@ -8,7 +8,7 @@ package udpa.annotations; import "google/protobuf/descriptor.proto"; -option go_package = "github.com/cncf/xds/go/annotations"; +option go_package = "github.com/cncf/xds/go/udpa/annotations"; extend google.protobuf.FieldOptions { // Magic number is the 28 most significant bits in the sha256sum of "udpa.annotations.sensitive". diff --git a/xds/third_party/xds/src/main/proto/udpa/annotations/status.proto b/xds/third_party/xds/src/main/proto/udpa/annotations/status.proto index 76cfd4dcfef..5a90bde29c7 100644 --- a/xds/third_party/xds/src/main/proto/udpa/annotations/status.proto +++ b/xds/third_party/xds/src/main/proto/udpa/annotations/status.proto @@ -8,7 +8,7 @@ package udpa.annotations; import "google/protobuf/descriptor.proto"; -option go_package = "github.com/cncf/xds/go/annotations"; +option go_package = "github.com/cncf/xds/go/udpa/annotations"; // Magic number in this file derived from top 28bit of SHA256 digest of // "udpa.annotation.status". diff --git a/xds/third_party/xds/src/main/proto/udpa/annotations/versioning.proto b/xds/third_party/xds/src/main/proto/udpa/annotations/versioning.proto index dcb7c85fd4f..06df78d818f 100644 --- a/xds/third_party/xds/src/main/proto/udpa/annotations/versioning.proto +++ b/xds/third_party/xds/src/main/proto/udpa/annotations/versioning.proto @@ -8,7 +8,7 @@ package udpa.annotations; import "google/protobuf/descriptor.proto"; -option go_package = "github.com/cncf/xds/go/annotations"; +option go_package = "github.com/cncf/xds/go/udpa/annotations"; extend google.protobuf.MessageOptions { // Magic number derived from 0x78 ('x') 0x44 ('D') 0x53 ('S') diff --git a/xds/third_party/xds/src/main/proto/xds/core/v3/cidr.proto b/xds/third_party/xds/src/main/proto/xds/core/v3/cidr.proto new file mode 100644 index 00000000000..b8471bc8078 --- /dev/null +++ b/xds/third_party/xds/src/main/proto/xds/core/v3/cidr.proto @@ -0,0 +1,25 @@ +syntax = "proto3"; + +package xds.core.v3; + +import "xds/annotations/v3/status.proto"; +import "google/protobuf/wrappers.proto"; + +import "validate/validate.proto"; + +option java_outer_classname = "CidrRangeProto"; +option java_multiple_files = true; +option java_package = "com.github.xds.core.v3"; +option go_package = "github.com/cncf/xds/go/xds/core/v3"; + +option (xds.annotations.v3.file_status).work_in_progress = true; + +// CidrRange specifies an IP Address and a prefix length to construct +// the subnet mask for a `CIDR `_ range. +message CidrRange { + // IPv4 or IPv6 address, e.g. ``192.0.0.0`` or ``2001:db8::``. + string address_prefix = 1 [(validate.rules).string = {min_len: 1}]; + + // Length of prefix, e.g. 0, 32. Defaults to 0 when unset. + google.protobuf.UInt32Value prefix_len = 2 [(validate.rules).uint32 = {lte: 128}]; +} \ No newline at end of file diff --git a/xds/third_party/xds/src/main/proto/xds/data/orca/v3/orca_load_report.proto b/xds/third_party/xds/src/main/proto/xds/data/orca/v3/orca_load_report.proto index 53da75f78ac..1b0847585a4 100644 --- a/xds/third_party/xds/src/main/proto/xds/data/orca/v3/orca_load_report.proto +++ b/xds/third_party/xds/src/main/proto/xds/data/orca/v3/orca_load_report.proto @@ -10,7 +10,7 @@ option go_package = "github.com/cncf/xds/go/xds/data/orca/v3"; import "validate/validate.proto"; // See section `ORCA load report format` of the design document in -// :ref:`https://github.com/envoyproxy/envoy/issues/6614`. +// https://github.com/envoyproxy/envoy/issues/6614. message OrcaLoadReport { // CPU utilization expressed as a fraction of available CPU resources. This diff --git a/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/cel.proto b/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/cel.proto new file mode 100644 index 00000000000..a45af9534a0 --- /dev/null +++ b/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/cel.proto @@ -0,0 +1,37 @@ +syntax = "proto3"; + +package xds.type.matcher.v3; + +import "xds/type/v3/cel.proto"; +import "validate/validate.proto"; + +option java_package = "com.github.xds.type.matcher.v3"; +option java_outer_classname = "CelProto"; +option java_multiple_files = true; +option go_package = "github.com/cncf/xds/go/xds/type/matcher/v3"; + +// [#protodoc-title: Common Expression Language (CEL) matchers] + +// Performs a match by evaluating a `Common Expression Language +// `_ (CEL) expression against the standardized set of +// :ref:`HTTP attributes ` specified via ``HttpAttributesCelMatchInput``. +// +// .. attention:: +// +// The match is ``true``, iff the result of the evaluation is a bool AND true. +// In all other cases, the match is ``false``, including but not limited to: non-bool types, +// ``false``, ``null``, ``int(1)``, etc. +// In case CEL expression raises an error, the result of the evaluation is interpreted "no match". +// +// Refer to :ref:`Unified Matcher API ` documentation +// for usage details. +// +// [#comment: envoy.matching.matchers.cel_matcher] +message CelMatcher { + // Either parsed or checked representation of the CEL program. + type.v3.CelExpression expr_match = 1 [(validate.rules).message = {required: true}]; + + // Free-form description of the CEL AST, e.g. the original expression text, to be + // used for debugging assistance. + string description = 2; +} diff --git a/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/http_inputs.proto b/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/http_inputs.proto new file mode 100644 index 00000000000..5709d64501b --- /dev/null +++ b/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/http_inputs.proto @@ -0,0 +1,23 @@ +syntax = "proto3"; + +package xds.type.matcher.v3; + +option java_package = "com.github.xds.type.matcher.v3"; +option java_outer_classname = "HttpInputsProto"; +option java_multiple_files = true; +option go_package = "github.com/cncf/xds/go/xds/type/matcher/v3"; + +// [#protodoc-title: Common HTTP Inputs] + +// Specifies that matching should be performed on the set of :ref:`HTTP attributes +// `. +// +// The attributes will be exposed via `Common Expression Language +// `_ runtime to associated CEL matcher. +// +// Refer to :ref:`Unified Matcher API ` documentation +// for usage details. +// +// [#comment: envoy.matching.inputs.cel_data_input] +message HttpAttributesCelMatchInput { +} diff --git a/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/matcher.proto b/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/matcher.proto index 4966b456bee..cc03ff6e98f 100644 --- a/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/matcher.proto +++ b/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/matcher.proto @@ -2,7 +2,6 @@ syntax = "proto3"; package xds.type.matcher.v3; -import "xds/annotations/v3/status.proto"; import "xds/core/v3/extension.proto"; import "xds/type/matcher/v3/string.proto"; @@ -21,8 +20,6 @@ option go_package = "github.com/cncf/xds/go/xds/type/matcher/v3"; // As an on_no_match might result in another matching tree being evaluated, this process // might repeat several times until the final OnMatch (or no match) is decided. message Matcher { - option (xds.annotations.v3.message_status).work_in_progress = true; - // What to do if a match is successful. message OnMatch { oneof on_match { @@ -38,6 +35,14 @@ message Matcher { // Protocol-specific action to take. core.v3.TypedExtensionConfig action = 2; } + + // If true and the Matcher matches, the action will be taken but the caller + // will behave as if the Matcher did not match. A subsequent matcher or + // on_no_match action will be used instead. + // This field is not supported in all contexts in which the matcher API is + // used. If this field is set in a context in which it's not supported, + // the resource will be rejected. + bool keep_matching = 3; } // A linear list of field matchers. diff --git a/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/string.proto b/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/string.proto index fdc598e174a..e58cb413e96 100644 --- a/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/string.proto +++ b/xds/third_party/xds/src/main/proto/xds/type/matcher/v3/string.proto @@ -2,6 +2,7 @@ syntax = "proto3"; package xds.type.matcher.v3; +import "xds/core/v3/extension.proto"; import "xds/type/matcher/v3/regex.proto"; import "validate/validate.proto"; @@ -14,7 +15,7 @@ option go_package = "github.com/cncf/xds/go/xds/type/matcher/v3"; // [#protodoc-title: String matcher] // Specifies the way to match a string. -// [#next-free-field: 8] +// [#next-free-field: 9] message StringMatcher { oneof match_pattern { option (validate.required) = true; @@ -52,6 +53,10 @@ message StringMatcher { // // * *abc* matches the value *xyz.abc.def* string contains = 7 [(validate.rules).string = {min_len: 1}]; + + // Use an extension as the matcher type. + // [#extension-category: envoy.string_matcher] + xds.core.v3.TypedExtensionConfig custom = 8; } // If true, indicates the exact/prefix/suffix matching should be case insensitive. This has no diff --git a/xds/third_party/xds/src/main/proto/xds/type/v3/cel.proto b/xds/third_party/xds/src/main/proto/xds/type/v3/cel.proto new file mode 100644 index 00000000000..043990401c6 --- /dev/null +++ b/xds/third_party/xds/src/main/proto/xds/type/v3/cel.proto @@ -0,0 +1,77 @@ +syntax = "proto3"; + +package xds.type.v3; + +import "google/api/expr/v1alpha1/checked.proto"; +import "google/api/expr/v1alpha1/syntax.proto"; +import "cel/expr/checked.proto"; +import "cel/expr/syntax.proto"; +import "google/protobuf/wrappers.proto"; + +import "xds/annotations/v3/status.proto"; + +import "validate/validate.proto"; + +option java_package = "com.github.xds.type.v3"; +option java_outer_classname = "CelProto"; +option java_multiple_files = true; +option go_package = "github.com/cncf/xds/go/xds/type/v3"; + +option (xds.annotations.v3.file_status).work_in_progress = true; + +// [#protodoc-title: Common Expression Language (CEL)] + +// Either parsed or checked representation of the `Common Expression Language +// `_ (CEL) program. +message CelExpression { + oneof expr_specifier { + // Parsed expression in abstract syntax tree (AST) form. + // + // Deprecated -- use ``cel_expr_parsed`` field instead. + // If ``cel_expr_parsed`` or ``cel_expr_checked`` is set, this field is not used. + google.api.expr.v1alpha1.ParsedExpr parsed_expr = 1 [deprecated = true]; + + // Parsed expression in abstract syntax tree (AST) form that has been successfully type checked. + // + // Deprecated -- use ``cel_expr_checked`` field instead. + // If ``cel_expr_parsed`` or ``cel_expr_checked`` is set, this field is not used. + google.api.expr.v1alpha1.CheckedExpr checked_expr = 2 [deprecated = true]; + } + + // Parsed expression in abstract syntax tree (AST) form. + // + // If ``cel_expr_checked`` is set, this field is not used. + cel.expr.ParsedExpr cel_expr_parsed = 3; + + // Parsed expression in abstract syntax tree (AST) form that has been successfully type checked. + // + // If set, takes precedence over ``cel_expr_parsed``. + cel.expr.CheckedExpr cel_expr_checked = 4; + + // Unparsed expression in string form. For example, ``request.headers['x-env'] == 'prod'`` will + // get ``x-env`` header value and compare it with ``prod``. + // Check the `Common Expression Language `_ for more details. + // + // If set, takes precedence over ``cel_expr_parsed`` and ``cel_expr_checked``. + string cel_expr_string = 5; +} + +// Extracts a string by evaluating a `Common Expression Language +// `_ (CEL) expression against the standardized set of +// :ref:`HTTP attributes `. +// +// .. attention:: +// +// Besides CEL evaluation raising an error explicitly, CEL program returning a type other than +// the ``string``, or not returning anything, are considered an error as well. +// +// [#comment:TODO(sergiitk): When implemented, add the extension tag.] +message CelExtractString { + // The CEL expression used to extract a string from the CEL environment. + // the "subject string") that should be replaced. + CelExpression expr_extract = 1 [(validate.rules).message = {required: true}]; + + // If CEL expression evaluates to an error, this value is be returned to the caller. + // If not set, the error is propagated to the caller. + google.protobuf.StringValue default_value = 2; +}