From b5bf37e7ca699ff142afc4439279a2583e2e5998 Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Thu, 14 Feb 2019 01:04:44 +0530 Subject: [PATCH 001/226] Firebase auth Email Action Links API (#258) * added email action link generation to auth * added change log entry * Integration tests added * fixed review comments * fixes for review comments * minor tidy up - review comments * fixed cosmetic comments in IT --- CHANGELOG.md | 3 + firebase_admin/_auth_utils.py | 7 ++ firebase_admin/_user_mgt.py | 118 +++++++++++++++++++++++++ firebase_admin/auth.py | 77 +++++++++++++++++ integration/test_auth.py | 94 ++++++++++++++++++++ tests/test_user_mgt.py | 156 ++++++++++++++++++++++++++++++++++ 6 files changed, 455 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b7b12de99..894ef86b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,8 @@ # Unreleased +- [added] Added `generate_password_reset_link()`, + `generate_email_verification_link()` and `generate_sign_in_with_email_link()` + methods to the `auth` API. - [added] Migrated the `auth` user management API to the new Identity Toolkit endpoint. - [fixed] Extending HTTP retries to more HTTP methods like POST and PATCH. diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index 852438725..b6788355c 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -26,6 +26,7 @@ 'acr', 'amr', 'at_hash', 'aud', 'auth_time', 'azp', 'cnf', 'c_hash', 'exp', 'iat', 'iss', 'jti', 'nbf', 'nonce', 'sub', 'firebase', ]) +VALID_EMAIL_ACTION_TYPES = set(['VERIFY_EMAIL', 'EMAIL_SIGNIN', 'PASSWORD_RESET']) def validate_uid(uid, required=False): @@ -181,3 +182,9 @@ def validate_custom_claims(custom_claims, required=False): raise ValueError( 'Claim "{0}" is reserved, and must not be set.'.format(invalid_claims.pop())) return claims_str + +def validate_action_type(action_type): + if action_type not in VALID_EMAIL_ACTION_TYPES: + raise ValueError('Invalid action type provided action_type: {0}. \ + Valid values are {1}'.format(action_type, ', '.join(VALID_EMAIL_ACTION_TYPES))) + return action_type diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index 227e13151..71e2055ad 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -18,6 +18,7 @@ import requests import six +from six.moves import urllib from firebase_admin import _auth_utils from firebase_admin import _user_import @@ -30,6 +31,7 @@ USER_DELETE_ERROR = 'USER_DELETE_ERROR' USER_IMPORT_ERROR = 'USER_IMPORT_ERROR' USER_DOWNLOAD_ERROR = 'LIST_USERS_ERROR' +GENERATE_EMAIL_ACTION_LINK_ERROR = 'GENERATE_EMAIL_ACTION_LINK_ERROR' MAX_LIST_USERS_RESULTS = 1000 MAX_IMPORT_USERS_SIZE = 1000 @@ -372,6 +374,87 @@ def photo_url(self): def provider_id(self): return self._data.get('providerId') +class ActionCodeSettings(object): + """Contains required continue/state URL with optional Android and iOS settings. + Used when invoking the email action link generation APIs. + """ + + def __init__(self, url, handle_code_in_app=None, dynamic_link_domain=None, ios_bundle_id=None, + android_package_name=None, android_install_app=None, android_minimum_version=None): + self.url = url + self.handle_code_in_app = handle_code_in_app + self.dynamic_link_domain = dynamic_link_domain + self.ios_bundle_id = ios_bundle_id + self.android_package_name = android_package_name + self.android_install_app = android_install_app + self.android_minimum_version = android_minimum_version + +def encode_action_code_settings(settings): + """ Validates the provided action code settings for email link generation and + populates the REST api parameters. + + settings - ``ActionCodeSettings`` object provided to be encoded + returns - dict of parameters to be passed for link gereration. + """ + + parameters = {} + # url + if not settings.url: + raise ValueError("Dynamic action links url is mandatory") + + try: + parsed = urllib.parse.urlparse(settings.url) + if not parsed.netloc: + raise ValueError('Malformed dynamic action links url: "{0}".'.format(settings.url)) + parameters['continueUrl'] = settings.url + except Exception: + raise ValueError('Malformed dynamic action links url: "{0}".'.format(settings.url)) + + # handle_code_in_app + if settings.handle_code_in_app is not None: + if not isinstance(settings.handle_code_in_app, bool): + raise ValueError('Invalid value provided for handle_code_in_app: {0}' + .format(settings.handle_code_in_app)) + parameters['canHandleCodeInApp'] = settings.handle_code_in_app + + # dynamic_link_domain + if settings.dynamic_link_domain is not None: + if not isinstance(settings.dynamic_link_domain, six.string_types): + raise ValueError('Invalid value provided for dynamic_link_domain: {0}' + .format(settings.dynamic_link_domain)) + parameters['dynamicLinkDomain'] = settings.dynamic_link_domain + + # ios_bundle_id + if settings.ios_bundle_id is not None: + if not isinstance(settings.ios_bundle_id, six.string_types): + raise ValueError('Invalid value provided for ios_bundle_id: {0}' + .format(settings.ios_bundle_id)) + parameters['iosBundleId'] = settings.ios_bundle_id + + # android_* attributes + if (settings.android_minimum_version or settings.android_install_app) \ + and not settings.android_package_name: + raise ValueError("Android package name is required when specifying other Android settings") + + if settings.android_package_name is not None: + if not isinstance(settings.android_package_name, six.string_types): + raise ValueError('Invalid value provided for android_package_name: {0}' + .format(settings.android_package_name)) + parameters['androidPackageName'] = settings.android_package_name + + if settings.android_minimum_version is not None: + if not isinstance(settings.android_minimum_version, six.string_types): + raise ValueError('Invalid value provided for android_minimum_version: {0}' + .format(settings.android_minimum_version)) + parameters['androidMinimumVersion'] = settings.android_minimum_version + + if settings.android_install_app is not None: + if not isinstance(settings.android_install_app, bool): + raise ValueError('Invalid value provided for android_install_app: {0}' + .format(settings.android_install_app)) + parameters['androidInstallApp'] = settings.android_install_app + + return parameters class UserManager(object): """Provides methods for interacting with the Google Identity Toolkit.""" @@ -537,6 +620,41 @@ def import_users(self, users, hash_alg=None): raise ApiCallError(USER_IMPORT_ERROR, 'Failed to import users.') return response + def generate_email_action_link(self, action_type, email, action_code_settings=None): + """Fetches the email action links for types + + Args: + action_type: String. Valid values ['VERIFY_EMAIL', 'EMAIL_SIGNIN', 'PASSWORD_RESET'] + email: Email of the user for which the action is performed + action_code_settings: ``ActionCodeSettings`` object or dict (optional). Defines whether + the link is to be handled by a mobile app and the additional state information to be + passed in the deep link, etc. + Returns: + link_url: action url to be emailed to the user + + Raises: + ApiCallError: If an error occurs while generating the link + ValueError: If the provided arguments are invalid + """ + payload = { + 'requestType': _auth_utils.validate_action_type(action_type), + 'email': _auth_utils.validate_email(email), + 'returnOobLink': True + } + + if action_code_settings: + payload.update(encode_action_code_settings(action_code_settings)) + + try: + response = self._client.body('post', '/accounts:sendOobCode', json=payload) + except requests.exceptions.RequestException as error: + self._handle_http_error(GENERATE_EMAIL_ACTION_LINK_ERROR, 'Failed to generate link.', + error) + else: + if not response or not response.get('oobLink'): + raise ApiCallError(GENERATE_EMAIL_ACTION_LINK_ERROR, 'Failed to generate link.') + return response.get('oobLink') + def _handle_http_error(self, code, msg, error): if error.response is not None: msg += '\nServer response: {0}'.format(error.response.content.decode()) diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index 4c793d34b..6a65c646f 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -35,6 +35,7 @@ __all__ = [ + 'ActionCodeSettings', 'AuthError', 'ErrorInfo', 'ExportedUserRecord', @@ -51,6 +52,9 @@ 'create_session_cookie', 'create_user', 'delete_user', + 'generate_password_reset_link', + 'generate_email_verification_link', + 'generate_sign_in_with_email_link', 'get_user', 'get_user_by_email', 'get_user_by_phone_number', @@ -63,6 +67,7 @@ 'verify_session_cookie', ] +ActionCodeSettings = _user_mgt.ActionCodeSettings ErrorInfo = _user_import.ErrorInfo ExportedUserRecord = _user_mgt.ExportedUserRecord ListUsersPage = _user_mgt.ListUsersPage @@ -448,6 +453,78 @@ def import_users(users, hash_alg=None, app=None): except _user_mgt.ApiCallError as error: raise AuthError(error.code, str(error), error.detail) +def generate_password_reset_link(email, action_code_settings=None, app=None): + """Generates the out-of-band email action link for password reset flows for the specified email + address. + + Args: + email: The email of the user whose password is to be reset. + action_code_settings: ``ActionCodeSettings`` instance (optional). Defines whether + the link is to be handled by a mobile app and the additional state information to be + passed in the deep link. + app: An App instance (optional). + Returns: + link: The password reset link created by API + + Raises: + ValueError: If the provided arguments are invalid + AuthError: If an error occurs while generating the link + """ + user_manager = _get_auth_service(app).user_manager + try: + return user_manager.generate_email_action_link('PASSWORD_RESET', email, + action_code_settings=action_code_settings) + except _user_mgt.ApiCallError as error: + raise AuthError(error.code, str(error), error.detail) + +def generate_email_verification_link(email, action_code_settings=None, app=None): + """Generates the out-of-band email action link for email verification flows for the specified + email address. + + Args: + email: The email of the user to be verified. + action_code_settings: ``ActionCodeSettings`` instance (optional). Defines whether + the link is to be handled by a mobile app and the additional state information to be + passed in the deep link. + app: An App instance (optional). + Returns: + link: The email verification link created by API + + Raises: + ValueError: If the provided arguments are invalid + AuthError: If an error occurs while generating the link + """ + user_manager = _get_auth_service(app).user_manager + try: + return user_manager.generate_email_action_link('VERIFY_EMAIL', email, + action_code_settings=action_code_settings) + except _user_mgt.ApiCallError as error: + raise AuthError(error.code, str(error), error.detail) + +def generate_sign_in_with_email_link(email, action_code_settings, app=None): + """Generates the out-of-band email action link for email link sign-in flows, using the action + code settings provided. + + Args: + email: The email of the user signing in. + action_code_settings: ``ActionCodeSettings`` instance. Defines whether + the link is to be handled by a mobile app and the additional state information to be + passed in the deep link. + app: An App instance (optional). + Returns: + link: The email sign in link created by API + + Raises: + ValueError: If the provided arguments are invalid + AuthError: If an error occurs while generating the link + """ + user_manager = _get_auth_service(app).user_manager + try: + return user_manager.generate_email_action_link('EMAIL_SIGNIN', email, + action_code_settings=action_code_settings) + except _user_mgt.ApiCallError as error: + raise AuthError(error.code, str(error), error.detail) + def _check_jwt_revoked(verified_claims, error_code, label, app): user = get_user(verified_claims.get('uid'), app=app) if verified_claims.get('iat') * 1000 < user.tokens_valid_after_timestamp: diff --git a/integration/test_auth.py b/integration/test_auth.py index 8604761c3..53577b827 100644 --- a/integration/test_auth.py +++ b/integration/test_auth.py @@ -18,6 +18,7 @@ import random import time import uuid +import six import pytest import requests @@ -30,7 +31,11 @@ _verify_token_url = 'https://www.googleapis.com/identitytoolkit/v3/relyingparty/verifyCustomToken' _verify_password_url = 'https://www.googleapis.com/identitytoolkit/v3/relyingparty/verifyPassword' +_password_reset_url = 'https://www.googleapis.com/identitytoolkit/v3/relyingparty/resetPassword' +_verify_email_url = 'https://www.googleapis.com/identitytoolkit/v3/relyingparty/setAccountInfo' +_email_sign_in_url = 'https://www.googleapis.com/identitytoolkit/v3/relyingparty/emailLinkSignin' +ACTION_LINK_CONTINUE_URL = 'http://localhost?a=1&b=5#f=1' def _sign_in(custom_token, api_key): body = {'token' : custom_token.decode(), 'returnSecureToken' : True} @@ -54,6 +59,32 @@ def _random_id(): def _random_phone(): return '+1' + ''.join([str(random.randint(0, 9)) for _ in range(0, 10)]) +def _reset_password(oob_code, new_password, api_key): + body = {'oobCode': oob_code, 'newPassword': new_password} + params = {'key' : api_key} + resp = requests.request('post', _password_reset_url, params=params, json=body) + resp.raise_for_status() + return resp.json().get('email') + +def _verify_email(oob_code, api_key): + body = {'oobCode': oob_code} + params = {'key' : api_key} + resp = requests.request('post', _verify_email_url, params=params, json=body) + resp.raise_for_status() + return resp.json().get('email') + +def _sign_in_with_email_link(email, oob_code, api_key): + body = {'oobCode': oob_code, 'email': email} + params = {'key' : api_key} + resp = requests.request('post', _email_sign_in_url, params=params, json=body) + resp.raise_for_status() + return resp.json().get('idToken') + +def _extract_link_params(link): + query = six.moves.urllib.parse.urlparse(link).query + query_dict = dict(six.moves.urllib.parse.parse_qsl(query)) + return query_dict + def test_custom_token(api_key): custom_token = auth.create_custom_token('user1') id_token = _sign_in(custom_token, api_key) @@ -151,6 +182,18 @@ def new_user_list(): for uid in users: auth.delete_user(uid) +@pytest.fixture +def new_user_email_unverified(): + random_id, email = _random_id() + user = auth.create_user( + uid=random_id, + email=email, + email_verified=False, + password='password' + ) + yield user + auth.delete_user(user.uid) + def test_get_user(new_user_with_params): user = auth.get_user(new_user_with_params.uid) assert user.uid == new_user_with_params.uid @@ -372,6 +415,57 @@ def test_import_users_with_password(api_key): finally: auth.delete_user(uid) +def test_password_reset(new_user_email_unverified, api_key): + link = auth.generate_password_reset_link(new_user_email_unverified.email) + assert isinstance(link, six.string_types) + query_dict = _extract_link_params(link) + user_email = _reset_password(query_dict['oobCode'], 'newPassword', api_key) + assert new_user_email_unverified.email == user_email + # password reset also set email_verified to True + assert auth.get_user(new_user_email_unverified.uid).email_verified + +def test_email_verification(new_user_email_unverified, api_key): + link = auth.generate_email_verification_link(new_user_email_unverified.email) + assert isinstance(link, six.string_types) + query_dict = _extract_link_params(link) + user_email = _verify_email(query_dict['oobCode'], api_key) + assert new_user_email_unverified.email == user_email + assert auth.get_user(new_user_email_unverified.uid).email_verified + +def test_password_reset_with_settings(new_user_email_unverified, api_key): + action_code_settings = auth.ActionCodeSettings(ACTION_LINK_CONTINUE_URL) + link = auth.generate_password_reset_link(new_user_email_unverified.email, + action_code_settings=action_code_settings) + assert isinstance(link, six.string_types) + query_dict = _extract_link_params(link) + assert query_dict['continueUrl'] == ACTION_LINK_CONTINUE_URL + user_email = _reset_password(query_dict['oobCode'], 'newPassword', api_key) + assert new_user_email_unverified.email == user_email + # password reset also set email_verified to True + assert auth.get_user(new_user_email_unverified.uid).email_verified + +def test_email_verification_with_settings(new_user_email_unverified, api_key): + action_code_settings = auth.ActionCodeSettings(ACTION_LINK_CONTINUE_URL) + link = auth.generate_email_verification_link(new_user_email_unverified.email, + action_code_settings=action_code_settings) + assert isinstance(link, six.string_types) + query_dict = _extract_link_params(link) + assert query_dict['continueUrl'] == ACTION_LINK_CONTINUE_URL + user_email = _verify_email(query_dict['oobCode'], api_key) + assert new_user_email_unverified.email == user_email + assert auth.get_user(new_user_email_unverified.uid).email_verified + +def test_email_sign_in_with_settings(new_user_email_unverified, api_key): + action_code_settings = auth.ActionCodeSettings(ACTION_LINK_CONTINUE_URL) + link = auth.generate_sign_in_with_email_link(new_user_email_unverified.email, + action_code_settings=action_code_settings) + assert isinstance(link, six.string_types) + query_dict = _extract_link_params(link) + assert query_dict['continueUrl'] == ACTION_LINK_CONTINUE_URL + oob_code = query_dict['oobCode'] + id_token = _sign_in_with_email_link(new_user_email_unverified.email, oob_code, api_key) + assert id_token is not None and len(id_token) > 0 + assert auth.get_user(new_user_email_unverified.uid).email_verified class CredentialWrapper(credentials.Base): """A custom Firebase credential that wraps an OAuth2 token.""" diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index f20a4e714..6e033fae4 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -37,6 +37,16 @@ MOCK_GET_USER_RESPONSE = testutils.resource('get_user.json') MOCK_LIST_USERS_RESPONSE = testutils.resource('list_users.json') +MOCK_ACTION_CODE_DATA = { + 'url': 'http://localhost', + 'handle_code_in_app': True, + 'dynamic_link_domain': 'http://testly', + 'ios_bundle_id': 'test.bundle', + 'android_package_name': 'test.bundle', + 'android_minimum_version': '7', + 'android_install_app': True, +} +MOCK_ACTION_CODE_SETTINGS = auth.ActionCodeSettings(**MOCK_ACTION_CODE_DATA) @pytest.fixture(scope='module') def user_mgt_app(): @@ -972,3 +982,149 @@ def test_revoke_refresh_tokens(self, user_mgt_app): assert request['localId'] == 'testuser' assert int(request['validSince']) >= int(before_time) assert int(request['validSince']) <= int(after_time) + +class TestActionCodeSetting(object): + + def test_valid_data(self): + data = { + 'url': 'http://localhost', + 'handle_code_in_app': True, + 'dynamic_link_domain': 'http://testly', + 'ios_bundle_id': 'test.bundle', + 'android_package_name': 'test.bundle', + 'android_minimum_version': '7', + 'android_install_app': True, + } + settings = auth.ActionCodeSettings(**data) + parameters = _user_mgt.encode_action_code_settings(settings) + assert parameters['continueUrl'] == data['url'] + assert parameters['canHandleCodeInApp'] == data['handle_code_in_app'] + assert parameters['dynamicLinkDomain'] == data['dynamic_link_domain'] + assert parameters['iosBundleId'] == data['ios_bundle_id'] + assert parameters['androidPackageName'] == data['android_package_name'] + assert parameters['androidMinimumVersion'] == data['android_minimum_version'] + assert parameters['androidInstallApp'] == data['android_install_app'] + + @pytest.mark.parametrize('data', [{'handle_code_in_app':'nonboolean'}, + {'android_install_app':'nonboolean'}, + {'dynamic_link_domain': False}, + {'ios_bundle_id':11}, + {'android_package_name':dict()}, + {'android_minimum_version':tuple()}, + {'android_minimum_version':'7'}, + {'android_install_app': True}]) + def test_bad_data(self, data): + settings = auth.ActionCodeSettings('http://localhost', **data) + with pytest.raises(ValueError): + _user_mgt.encode_action_code_settings(settings) + + def test_bad_url(self): + settings = auth.ActionCodeSettings('http:') + with pytest.raises(ValueError): + _user_mgt.encode_action_code_settings(settings) + + def test_encode_action_code_bad_data(self): + with pytest.raises(AttributeError): + _user_mgt.encode_action_code_settings({"foo":"bar"}) + +class TestGenerateEmailActionLink(object): + + def test_email_verification_no_settings(self, user_mgt_app): + _, recorder = _instrument_user_manager(user_mgt_app, 200, '{"oobLink":"https://testlink"}') + link = auth.generate_email_verification_link('test@test.com', app=user_mgt_app) + request = json.loads(recorder[0].body.decode()) + + assert link == 'https://testlink' + assert request['requestType'] == 'VERIFY_EMAIL' + self._validate_request(request) + + def test_password_reset_no_settings(self, user_mgt_app): + _, recorder = _instrument_user_manager(user_mgt_app, 200, '{"oobLink":"https://testlink"}') + link = auth.generate_password_reset_link('test@test.com', app=user_mgt_app) + request = json.loads(recorder[0].body.decode()) + + assert link == 'https://testlink' + assert request['requestType'] == 'PASSWORD_RESET' + self._validate_request(request) + + def test_email_signin_with_settings(self, user_mgt_app): + _, recorder = _instrument_user_manager(user_mgt_app, 200, '{"oobLink":"https://testlink"}') + link = auth.generate_sign_in_with_email_link('test@test.com', + action_code_settings=MOCK_ACTION_CODE_SETTINGS, + app=user_mgt_app) + request = json.loads(recorder[0].body.decode()) + + assert link == 'https://testlink' + assert request['requestType'] == 'EMAIL_SIGNIN' + self._validate_request(request, MOCK_ACTION_CODE_SETTINGS) + + def test_email_verification_with_settings(self, user_mgt_app): + _, recorder = _instrument_user_manager(user_mgt_app, 200, '{"oobLink":"https://testlink"}') + link = auth.generate_email_verification_link('test@test.com', + action_code_settings=MOCK_ACTION_CODE_SETTINGS, + app=user_mgt_app) + request = json.loads(recorder[0].body.decode()) + + assert link == 'https://testlink' + assert request['requestType'] == 'VERIFY_EMAIL' + self._validate_request(request, MOCK_ACTION_CODE_SETTINGS) + + def test_password_reset_with_settings(self, user_mgt_app): + _, recorder = _instrument_user_manager(user_mgt_app, 200, '{"oobLink":"https://testlink"}') + link = auth.generate_password_reset_link('test@test.com', + action_code_settings=MOCK_ACTION_CODE_SETTINGS, + app=user_mgt_app) + request = json.loads(recorder[0].body.decode()) + + assert link == 'https://testlink' + assert request['requestType'] == 'PASSWORD_RESET' + self._validate_request(request, MOCK_ACTION_CODE_SETTINGS) + + @pytest.mark.parametrize('func', [ + auth.generate_sign_in_with_email_link, + auth.generate_email_verification_link, + auth.generate_password_reset_link, + ]) + def test_api_call_failure(self, user_mgt_app, func): + _instrument_user_manager(user_mgt_app, 500, '{"error":"dummy error"}') + with pytest.raises(auth.AuthError): + func('test@test.com', MOCK_ACTION_CODE_SETTINGS, app=user_mgt_app) + + @pytest.mark.parametrize('func', [ + auth.generate_sign_in_with_email_link, + auth.generate_email_verification_link, + auth.generate_password_reset_link, + ]) + def test_api_call_no_link(self, user_mgt_app, func): + _instrument_user_manager(user_mgt_app, 200, '{}') + with pytest.raises(auth.AuthError): + func('test@test.com', MOCK_ACTION_CODE_SETTINGS, app=user_mgt_app) + + @pytest.mark.parametrize('func', [ + auth.generate_sign_in_with_email_link, + auth.generate_email_verification_link, + auth.generate_password_reset_link, + ]) + def test_bad_settings_data(self, user_mgt_app, func): + _instrument_user_manager(user_mgt_app, 200, '{"oobLink":"https://testlink"}') + with pytest.raises(AttributeError): + func('test@test.com', app=user_mgt_app, action_code_settings=1234) + + def test_bad_action_type(self, user_mgt_app): + with pytest.raises(ValueError): + auth._get_auth_service(user_mgt_app) \ + .user_manager \ + .generate_email_action_link('BAD_TYPE', 'test@test.com', + action_code_settings=MOCK_ACTION_CODE_SETTINGS) + + def _validate_request(self, request, settings=None): + assert request['email'] == 'test@test.com' + assert request['returnOobLink'] + if settings: + assert request['continueUrl'] == settings.url + assert request['canHandleCodeInApp'] == settings.handle_code_in_app + assert request['dynamicLinkDomain'] == settings.dynamic_link_domain + assert request['iosBundleId'] == settings.ios_bundle_id + assert request['androidPackageName'] == settings.android_package_name + assert request['androidMinimumVersion'] == settings.android_minimum_version + assert request['androidInstallApp'] == settings.android_install_app From e48e2b8d3146ab474d4fdd10c81a3e55e6d7dea0 Mon Sep 17 00:00:00 2001 From: Fernando Macedo Date: Wed, 13 Feb 2019 19:39:26 -0200 Subject: [PATCH 002/226] Distribution package only including firebase_admin (#263) --- setup.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 38a21152d..9aa36f89f 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,6 @@ from os import path import sys -from setuptools import find_packages from setuptools import setup @@ -55,7 +54,7 @@ license=about['__license__'], keywords='firebase cloud development', install_requires=install_requires, - packages=find_packages(exclude=['tests']), + packages=['firebase_admin'], python_requires='>=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*', classifiers=[ 'Development Status :: 5 - Production/Stable', From 5a760583c36f6cfe1704bbde1fe142aa110efebb Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Wed, 20 Feb 2019 12:13:58 -0800 Subject: [PATCH 003/226] Added snippets for email action links API (#264) * Added snippets for email action links API * Fixing a lint error; Fixing a regression caused by https://github.com/googleapis/google-auth-library-python/pull/324 * Fixing order of exported methods * Applying feedback from docs team; Updating snippets based on code review comments --- firebase_admin/auth.py | 8 +++---- snippets/auth/index.py | 48 +++++++++++++++++++++++++++++++++++++++++ tests/test_token_gen.py | 4 ++-- 3 files changed, 54 insertions(+), 6 deletions(-) diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index 6a65c646f..984c2babd 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -52,8 +52,8 @@ 'create_session_cookie', 'create_user', 'delete_user', - 'generate_password_reset_link', 'generate_email_verification_link', + 'generate_password_reset_link', 'generate_sign_in_with_email_link', 'get_user', 'get_user_by_email', @@ -464,7 +464,7 @@ def generate_password_reset_link(email, action_code_settings=None, app=None): passed in the deep link. app: An App instance (optional). Returns: - link: The password reset link created by API + link: The password reset link created by the API Raises: ValueError: If the provided arguments are invalid @@ -488,7 +488,7 @@ def generate_email_verification_link(email, action_code_settings=None, app=None) passed in the deep link. app: An App instance (optional). Returns: - link: The email verification link created by API + link: The email verification link created by the API Raises: ValueError: If the provided arguments are invalid @@ -512,7 +512,7 @@ def generate_sign_in_with_email_link(email, action_code_settings, app=None): passed in the deep link. app: An App instance (optional). Returns: - link: The email sign in link created by API + link: The email sign-in link created by the API Raises: ValueError: If the provided arguments are invalid diff --git a/snippets/auth/index.py b/snippets/auth/index.py index 36ea949d2..5bfe21f8e 100644 --- a/snippets/auth/index.py +++ b/snippets/auth/index.py @@ -587,6 +587,54 @@ def import_without_password(): print('Error importing users:', error) # [END import_without_password] +def init_action_code_settings(): + # [START init_action_code_settings] + action_code_settings = auth.ActionCodeSettings( + url='https://www.example.com/checkout?cartId=1234', + handle_code_in_app=True, + ios_bundle_id='com.example.ios', + android_package_name='com.example.android', + android_install_app=True, + android_minimum_version='12', + dynamic_link_domain='coolapp.page.link', + ) + # [END init_action_code_settings] + return action_code_settings + +def password_reset_link(): + action_code_settings = init_action_code_settings() + # [START password_reset_link] + email = 'user@example.com' + link = auth.generate_password_reset_link(email, action_code_settings) + # Construct password reset email from a template embedding the link, and send + # using a custom SMTP server. + send_custom_email(email, link) + # [END password_reset_link] + +def email_verification_link(): + action_code_settings = init_action_code_settings() + # [START email_verification_link] + email = 'user@example.com' + link = auth.generate_email_verification_link(email, action_code_settings) + # Construct email from a template embedding the link, and send + # using a custom SMTP server. + send_custom_email(email, link) + # [END email_verification_link] + +def sign_in_with_email_link(): + action_code_settings = init_action_code_settings() + # [START sign_in_with_email_link] + email = 'user@example.com' + link = auth.generate_sign_in_with_email_link(email, action_code_settings) + # Construct email from a template embedding the link, and send + # using a custom SMTP server. + send_custom_email(email, link) + # [END sign_in_with_email_link] + +def send_custom_email(email, link): + del email + del link + initialize_sdk_with_service_account() initialize_sdk_with_application_default() #initialize_sdk_with_refresh_token() diff --git a/tests/test_token_gen.py b/tests/test_token_gen.py index 52ccd172b..412ba3d0e 100644 --- a/tests/test_token_gen.py +++ b/tests/test_token_gen.py @@ -207,7 +207,7 @@ def test_sign_with_iam(self): iam_resp = '{{"signature": "{0}"}}'.format(signature) _overwrite_iam_request(app, testutils.MockRequest(200, iam_resp)) custom_token = auth.create_custom_token(MOCK_UID, app=app).decode() - assert custom_token.endswith('.' + signature) + assert custom_token.endswith('.' + signature.rstrip('=')) self._verify_signer(custom_token, 'test-service-account') finally: firebase_admin.delete_app(app) @@ -241,7 +241,7 @@ def test_sign_with_discovered_service_account(self): request.response = testutils.MockResponse( 200, '{{"signature": "{0}"}}'.format(signature)) custom_token = auth.create_custom_token(MOCK_UID, app=app).decode() - assert custom_token.endswith('.' + signature) + assert custom_token.endswith('.' + signature.rstrip('=')) self._verify_signer(custom_token, 'discovered-service-account') assert len(request.log) == 2 assert request.log[0][1]['headers'] == {'Metadata-Flavor': 'Google'} From f4b9c8ebbb4c9774d393db0846fe569a86a19b40 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Thu, 21 Feb 2019 10:34:22 -0800 Subject: [PATCH 004/226] Bumped version to 2.16.0 (#265) --- CHANGELOG.md | 6 +++++- firebase_admin/__about__.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 894ef86b1..5ceae9ffd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,10 @@ # Unreleased -- [added] Added `generate_password_reset_link()`, +- + +# v2.16.0 + +- [added] Added `generate_password_reset_link()`, `generate_email_verification_link()` and `generate_sign_in_with_email_link()` methods to the `auth` API. - [added] Migrated the `auth` user management API to the diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index fce8b8388..cba8bc848 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '2.15.1' +__version__ = '2.16.0' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 32c34c6c5343e09b30991a8a5b318cc751d3df72 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Mon, 25 Feb 2019 14:46:21 -0800 Subject: [PATCH 005/226] Updated set up instructions (#266) --- CONTRIBUTING.md | 38 ++++++++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index fd6d15a64..487e3754a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -85,6 +85,8 @@ information on using pull requests. ### Initial Setup +You need Python 2.7 or Python 3.4+ to build and test the code in this repo. + We recommend using [pip](https://pypi.python.org/pypi/pip) for installing the necessary tools and project dependencies. Most recent versions of Python ship with pip. If your development environment does not already have pip, use the software package manager of your platform (e.g. apt-get, brew) @@ -120,7 +122,7 @@ pass `all` as an argument. ``` ./lint.sh # Lint locally modified source files -./lint.sh all # Lint all source files +./lint.sh all # Lint all source files ``` Ideally you should not see any pylint errors or warnings when you run the @@ -167,36 +169,52 @@ do not already have one suitable for running the tests aginst. Then obtain the following credentials from the project: 1. *Service account certificate*: This can be downloaded as a JSON file from - the "Settings > Service Accounts" tab of the Firebase console. + the "Settings > Service Accounts" tab of the Firebase console. Copy the + file into the repo so it's available at `scripts/cert.json`. 2. *Web API key*: This is displayed in the "Settings > General" tab of the - console. Copy it and save to a new text file. + console. Copy it and save to a new text file at `scripts/apikey.txt`. + +Then set up your Firebase/GCP project as follows: + +1. Enable Firestore: Go to the Firebase Console, and select "Database" from + the "Develop" menu. Click on the "Create database" button. You may choose + to set up Firestore either in the locked mode or in the test mode. +2. Enable password auth: Select "Authentication" from the "Develop" menu in + Firebase Console. Select the "Sign-in method" tab, and enable the + "Email/Password" sign-in method. +3. Enable the IAM API: Go to the + [Google Cloud Platform Console](https://console.cloud.google.com) and make + sure your Firebase/GCP project is selected. Select "APIs & Services > + Dashboard" from the main menu, and click the "ENABLE APIS AND SERVICES" + button. Search for and enable the "Identity and Access Management (IAM) + API". Now you can invoke the integration test suite as follows: ``` -pytest integration/ --cert path/to/service_acct.json --apikey path/to/apikey.txt +pytest integration/ --cert scripts/cert.json --apikey scripts/apikey.txt ``` ### Test Coverage -To review the test coverage, run `pytest` with the `--cov` flag. To view a detailed line by line -coverage, use +To review the test coverage, run `pytest` with the `--cov` flag. To view a detailed line by line +coverage, use ```bash pytest --cov --cov-report html ``` -and point your browser to +and point your browser to `file:////htmlcov/index.html` (where `dir` is the location from which the report was created). ### Testing in Different Environments Sometimes we want to run unit tests in multiple environments (e.g. different Python versions), and -ensure that the SDK works as expected in each of them. We use -[tox](https://tox.readthedocs.io/en/latest/) for this purpose. +ensure that the SDK works as expected in each of them. We use +[tox](https://tox.readthedocs.io/en/latest/) for this purpose. But before you can invoke tox, you must set up all the necessary target environments on your workstation. The easiest and cleanest way to achieve this is by using a tool like -[pyenv](https://github.com/pyenv/pyenv). Refer to the +[pyenv](https://github.com/pyenv/pyenv). Refer to the [pyenv documentation](https://github.com/pyenv/pyenv#installation) for instructions on how to install it. This generally involves installing some binaries as well as modifying a system level configuration file such as `.bash_profile`. Once pyenv is installed, you can install multiple From 4aa34222fc243b317cecf5c078015f4d38566e59 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Wed, 6 Mar 2019 13:48:22 -0800 Subject: [PATCH 006/226] Updated db.Reference.listen() documentation (#268) --- firebase_admin/db.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/firebase_admin/db.py b/firebase_admin/db.py index d37a7fe88..3973f7654 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -350,7 +350,11 @@ def listen(self, callback): """Registers the ``callback`` function to receive realtime updates. The specified callback function will get invoked with ``db.Event`` objects for each - realtime update received from the database. + realtime update received from the database. It will also get called whenever the SDK + reconnects to the server due to network issues and credential expiration. In general, + the OAuth2 credentials used to authorize connections to the server expire every hour. + Therefore clients should expect the ``callback`` to fire at least once every hour, even if + there are no updates in the database. This API is based on the event streaming support available in the Firebase REST API. Each call to ``listen()`` starts a new HTTP connection and a background thread. This is an From 18c2395df35977a6444fd95b8f66a29fda3b04a9 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Tue, 26 Mar 2019 20:57:50 -0500 Subject: [PATCH 007/226] Revert "Add header for opting into correct URL decoding (#206)" (#274) This reverts commit 77737c060def12a9b7720474bfe3f84cbc091499. --- firebase_admin/db.py | 5 ++--- tests/test_db.py | 19 ------------------- 2 files changed, 2 insertions(+), 22 deletions(-) diff --git a/firebase_admin/db.py b/firebase_admin/db.py index 3973f7654..f1bbeba8e 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -36,7 +36,7 @@ _DB_ATTRIBUTE = '_database' -_INVALID_PATH_CHARACTERS = '[].#$' +_INVALID_PATH_CHARACTERS = '[].?#$' _RESERVED_FILTERS = ('$key', '$value', '$priority') _USER_AGENT = 'Firebase/HTTP/{0}/{1}.{2}/AdminPython'.format( firebase_admin.__version__, sys.version_info.major, sys.version_info.minor) @@ -849,8 +849,7 @@ def __init__(self, credential, base_url, auth_override, timeout): timeout, which is the default behavior of the underlying requests library. """ _http_client.JsonHttpClient.__init__( - self, credential=credential, base_url=base_url, - headers={'User-Agent': _USER_AGENT, 'X-Firebase-Decoding': '1'}) + self, credential=credential, base_url=base_url, headers={'User-Agent': _USER_AGENT}) self.credential = credential self.auth_override = auth_override self.timeout = timeout diff --git a/tests/test_db.py b/tests/test_db.py index b42123d34..6168b72d4 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -157,7 +157,6 @@ def test_get_value(self, data): assert recorder[0].url == 'https://test.firebaseio.com/test.json' assert recorder[0].headers['Authorization'] == 'Bearer mock-token' assert recorder[0].headers['User-Agent'] == db._USER_AGENT - assert recorder[0].headers['X-Firebase-Decoding'] == '1' assert 'X-Firebase-ETag' not in recorder[0].headers @pytest.mark.parametrize('data', valid_values) @@ -170,7 +169,6 @@ def test_get_with_etag(self, data): assert recorder[0].url == 'https://test.firebaseio.com/test.json' assert recorder[0].headers['Authorization'] == 'Bearer mock-token' assert recorder[0].headers['User-Agent'] == db._USER_AGENT - assert recorder[0].headers['X-Firebase-Decoding'] == '1' assert recorder[0].headers['X-Firebase-ETag'] == 'true' @pytest.mark.parametrize('data', valid_values) @@ -182,7 +180,6 @@ def test_get_shallow(self, data): assert recorder[0].method == 'GET' assert recorder[0].url == 'https://test.firebaseio.com/test.json?shallow=true' assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['X-Firebase-Decoding'] == '1' assert recorder[0].headers['User-Agent'] == db._USER_AGENT def test_get_with_etag_and_shallow(self): @@ -200,14 +197,12 @@ def test_get_if_changed(self, data): assert recorder[0].method == 'GET' assert recorder[0].url == 'https://test.firebaseio.com/test.json' assert recorder[0].headers['if-none-match'] == 'invalid-etag' - assert recorder[0].headers['X-Firebase-Decoding'] == '1' assert ref.get_if_changed(MockAdapter.ETAG) == (False, None, None) assert len(recorder) == 2 assert recorder[1].method == 'GET' assert recorder[1].url == 'https://test.firebaseio.com/test.json' assert recorder[1].headers['if-none-match'] == MockAdapter.ETAG - assert recorder[1].headers['X-Firebase-Decoding'] == '1' @pytest.mark.parametrize('etag', [0, 1, True, False, dict(), list(), tuple()]) def test_get_if_changed_invalid_etag(self, etag): @@ -226,7 +221,6 @@ def test_order_by_query(self, data): assert recorder[0].method == 'GET' assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['X-Firebase-Decoding'] == '1' @pytest.mark.parametrize('data', valid_values) def test_limit_query(self, data): @@ -240,7 +234,6 @@ def test_limit_query(self, data): assert recorder[0].method == 'GET' assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['X-Firebase-Decoding'] == '1' @pytest.mark.parametrize('data', valid_values) def test_range_query(self, data): @@ -255,7 +248,6 @@ def test_range_query(self, data): assert recorder[0].method == 'GET' assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['X-Firebase-Decoding'] == '1' @pytest.mark.parametrize('data', valid_values) def test_set_value(self, data): @@ -267,7 +259,6 @@ def test_set_value(self, data): assert recorder[0].url == 'https://test.firebaseio.com/test.json?print=silent' assert json.loads(recorder[0].body.decode()) == data assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['X-Firebase-Decoding'] == '1' def test_set_none_value(self): ref = db.reference('/test') @@ -294,7 +285,6 @@ def test_update_children(self, data): assert recorder[0].url == 'https://test.firebaseio.com/test.json?print=silent' assert json.loads(recorder[0].body.decode()) == data assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['X-Firebase-Decoding'] == '1' @pytest.mark.parametrize('data', valid_values) def test_set_if_unchanged_success(self, data): @@ -308,7 +298,6 @@ def test_set_if_unchanged_success(self, data): assert json.loads(recorder[0].body.decode()) == data assert recorder[0].headers['Authorization'] == 'Bearer mock-token' assert recorder[0].headers['if-match'] == MockAdapter.ETAG - assert recorder[0].headers['X-Firebase-Decoding'] == '1' @pytest.mark.parametrize('data', valid_values) def test_set_if_unchanged_failure(self, data): @@ -322,7 +311,6 @@ def test_set_if_unchanged_failure(self, data): assert json.loads(recorder[0].body.decode()) == data assert recorder[0].headers['Authorization'] == 'Bearer mock-token' assert recorder[0].headers['if-match'] == 'invalid-etag' - assert recorder[0].headers['X-Firebase-Decoding'] == '1' @pytest.mark.parametrize('etag', [0, 1, True, False, dict(), list(), tuple()]) def test_set_if_unchanged_invalid_etag(self, etag): @@ -368,7 +356,6 @@ def test_push(self, data): assert json.loads(recorder[0].body.decode()) == data assert recorder[0].headers['Authorization'] == 'Bearer mock-token' assert recorder[0].headers['User-Agent'] == db._USER_AGENT - assert recorder[0].headers['X-Firebase-Decoding'] == '1' def test_push_default(self): ref = db.reference('/test') @@ -380,7 +367,6 @@ def test_push_default(self): assert json.loads(recorder[0].body.decode()) == '' assert recorder[0].headers['Authorization'] == 'Bearer mock-token' assert recorder[0].headers['User-Agent'] == db._USER_AGENT - assert recorder[0].headers['X-Firebase-Decoding'] == '1' def test_push_none_value(self): ref = db.reference('/test') @@ -397,7 +383,6 @@ def test_delete(self): assert recorder[0].url == 'https://test.firebaseio.com/test.json' assert recorder[0].headers['Authorization'] == 'Bearer mock-token' assert recorder[0].headers['User-Agent'] == db._USER_AGENT - assert recorder[0].headers['X-Firebase-Decoding'] == '1' def test_transaction(self): ref = db.reference('/test') @@ -583,7 +568,6 @@ def test_get_value(self): assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str assert recorder[0].headers['Authorization'] == 'Bearer mock-token' assert recorder[0].headers['User-Agent'] == db._USER_AGENT - assert recorder[0].headers['X-Firebase-Decoding'] == '1' def test_set_value(self): ref = db.reference('/test') @@ -597,7 +581,6 @@ def test_set_value(self): assert json.loads(recorder[0].body.decode()) == data assert recorder[0].headers['Authorization'] == 'Bearer mock-token' assert recorder[0].headers['User-Agent'] == db._USER_AGENT - assert recorder[0].headers['X-Firebase-Decoding'] == '1' def test_order_by_query(self): ref = db.reference('/test') @@ -610,7 +593,6 @@ def test_order_by_query(self): assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str assert recorder[0].headers['Authorization'] == 'Bearer mock-token' assert recorder[0].headers['User-Agent'] == db._USER_AGENT - assert recorder[0].headers['X-Firebase-Decoding'] == '1' def test_range_query(self): ref = db.reference('/test') @@ -624,7 +606,6 @@ def test_range_query(self): assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str assert recorder[0].headers['Authorization'] == 'Bearer mock-token' assert recorder[0].headers['User-Agent'] == db._USER_AGENT - assert recorder[0].headers['X-Firebase-Decoding'] == '1' class TestDatabaseInitialization(object): From 3da3b5a2df94f8c1efc91c1fc2ad92131e6db44a Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Tue, 2 Apr 2019 13:10:00 -0700 Subject: [PATCH 008/226] Added X-Firebase-Client header to FCM API calls (#278) * Added X-Firebase-Client header to FCM API calls * Using the pre-calculated version string --- firebase_admin/messaging.py | 7 ++++++- tests/test_messaging.py | 4 ++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index 4d88feef8..f7988320d 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -17,6 +17,7 @@ import requests import six +import firebase_admin from firebase_admin import _http_client from firebase_admin import _messaging_utils from firebase_admin import _utils @@ -235,6 +236,7 @@ def __init__(self, app): self._fcm_url = _MessagingService.FCM_URL.format(project_id) self._client = _http_client.JsonHttpClient(credential=app.credential.get_credential()) self._timeout = app.options.get('httpTimeout') + self._client_version = 'fire-admin-python/{0}'.format(firebase_admin.__version__) @classmethod def encode_message(cls, message): @@ -247,7 +249,10 @@ def send(self, message, dry_run=False): if dry_run: data['validate_only'] = True try: - headers = {'X-GOOG-API-FORMAT-VERSION': '2'} + headers = { + 'X-GOOG-API-FORMAT-VERSION': '2', + 'X-FIREBASE-CLIENT': self._client_version, + } resp = self._client.body( 'post', url=self._fcm_url, headers=headers, json=data, timeout=self._timeout) except requests.exceptions.RequestException as error: diff --git a/tests/test_messaging.py b/tests/test_messaging.py index ca4dcb24e..8be2b8d8f 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -1165,6 +1165,7 @@ def test_topic_management_timeout(self): class TestSend(object): _DEFAULT_RESPONSE = json.dumps({'name': 'message-id'}) + _CLIENT_VERSION = 'fire-admin-python/{0}'.format(firebase_admin.__version__) @classmethod def setup_class(cls): @@ -1210,6 +1211,7 @@ def test_send_dry_run(self): assert recorder[0].method == 'POST' assert recorder[0].url == self._get_url('explicit-project-id') assert recorder[0].headers['X-GOOG-API-FORMAT-VERSION'] == '2' + assert recorder[0].headers['X-FIREBASE-CLIENT'] == self._CLIENT_VERSION body = { 'message': messaging._MessagingService.encode_message(msg), 'validate_only': True, @@ -1225,6 +1227,7 @@ def test_send(self): assert recorder[0].method == 'POST' assert recorder[0].url == self._get_url('explicit-project-id') assert recorder[0].headers['X-GOOG-API-FORMAT-VERSION'] == '2' + assert recorder[0].headers['X-FIREBASE-CLIENT'] == self._CLIENT_VERSION assert recorder[0]._extra_kwargs['timeout'] is None body = {'message': messaging._MessagingService.encode_message(msg)} assert json.loads(recorder[0].body.decode()) == body @@ -1242,6 +1245,7 @@ def test_send_error(self, status): assert recorder[0].method == 'POST' assert recorder[0].url == self._get_url('explicit-project-id') assert recorder[0].headers['X-GOOG-API-FORMAT-VERSION'] == '2' + assert recorder[0].headers['X-FIREBASE-CLIENT'] == self._CLIENT_VERSION body = {'message': messaging._MessagingService.JSON_ENCODER.default(msg)} assert json.loads(recorder[0].body.decode()) == body From 8f7148557933ecb51888fef2e29c488a40b64d66 Mon Sep 17 00:00:00 2001 From: Zachary Orr <516458+ZachOrr@users.noreply.github.com> Date: Mon, 13 May 2019 18:53:32 -0400 Subject: [PATCH 009/226] Add messaging send_all and send_multicast functions (#283) * Add messaging send_all and send_multicast functions * Fix CI * Small changes * Updating tests * Add non-200 non-error response code tests * Fix CI * Update postproc, update tests * Fix linter errors --- firebase_admin/_messaging_utils.py | 29 +- firebase_admin/messaging.py | 228 +++++++++++++-- requirements.txt | 1 + setup.py | 1 + tests/test_messaging.py | 435 +++++++++++++++++++++++++++++ 5 files changed, 670 insertions(+), 24 deletions(-) diff --git a/firebase_admin/_messaging_utils.py b/firebase_admin/_messaging_utils.py index aba809f22..373adf68c 100644 --- a/firebase_admin/_messaging_utils.py +++ b/firebase_admin/_messaging_utils.py @@ -54,6 +54,33 @@ def __init__(self, data=None, notification=None, android=None, webpush=None, apn self.condition = condition +class MulticastMessage(object): + """A message that can be sent to multiple tokens via Firebase Cloud Messaging. + + Contains payload information as well as recipient information. In particular, the message must + contain exactly one of token, topic or condition fields. + + Args: + tokens: A list of registration token of the device to which the message should be sent. + data: A dictionary of data fields (optional). All keys and values in the dictionary must be + strings. + notification: An instance of ``messaging.Notification`` (optional). + android: An instance of ``messaging.AndroidConfig`` (optional). + webpush: An instance of ``messaging.WebpushConfig`` (optional). + apns: An instance of ``messaging.ApnsConfig`` (optional). + """ + def __init__(self, tokens, data=None, notification=None, android=None, webpush=None, apns=None): + _Validators.check_string_list('MulticastMessage.tokens', tokens) + if len(tokens) > 100: + raise ValueError('MulticastMessage.tokens must not contain more than 100 tokens.') + self.tokens = tokens + self.data = data + self.notification = notification + self.android = android + self.webpush = webpush + self.apns = apns + + class Notification(object): """A notification that can be included in a message. @@ -150,7 +177,7 @@ class WebpushConfig(object): data: A dictionary of data fields (optional). All keys and values in the dictionary must be strings. When specified, overrides any data fields set via ``Message.data``. notification: A ``messaging.WebpushNotification`` to be included in the message (optional). - fcm_options: A ``messaging.WebpushFcmOptions`` instance to be included in the messsage + fcm_options: A ``messaging.WebpushFcmOptions`` instance to be included in the message (optional). .. _Webpush Specification: https://tools.ietf.org/html/rfc8030#section-5 diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index f7988320d..8129f8de1 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -14,9 +14,14 @@ """Firebase Cloud Messaging module.""" +import json import requests import six +import googleapiclient +from googleapiclient import http +from googleapiclient import _auth + import firebase_admin from firebase_admin import _http_client from firebase_admin import _messaging_utils @@ -34,10 +39,13 @@ 'ApiCallError', 'Aps', 'ApsAlert', + 'BatchResponse', 'CriticalSound', 'ErrorInfo', 'Message', + 'MulticastMessage', 'Notification', + 'SendResponse', 'TopicManagementResponse', 'WebpushConfig', 'WebpushFcmOptions', @@ -45,6 +53,8 @@ 'WebpushNotificationAction', 'send', + 'send_all', + 'send_multicast', 'subscribe_to_topic', 'unsubscribe_from_topic', ] @@ -58,6 +68,7 @@ ApsAlert = _messaging_utils.ApsAlert CriticalSound = _messaging_utils.CriticalSound Message = _messaging_utils.Message +MulticastMessage = _messaging_utils.MulticastMessage Notification = _messaging_utils.Notification WebpushConfig = _messaging_utils.WebpushConfig WebpushFcmOptions = _messaging_utils.WebpushFcmOptions @@ -88,6 +99,56 @@ def send(message, dry_run=False, app=None): """ return _get_messaging_service(app).send(message, dry_run) +def send_all(messages, dry_run=False, app=None): + """Sends the given list of messages via Firebase Cloud Messaging as a single batch. + + If the ``dry_run`` mode is enabled, the message will not be actually delivered to the + recipients. Instead FCM performs all the usual validations, and emulates the send operation. + + Args: + messages: A list of ``messaging.Message`` instances. + dry_run: A boolean indicating whether to run the operation in dry run mode (optional). + app: An App instance (optional). + + Returns: + BatchResponse: A ``messaging.BatchResponse`` instance. + + Raises: + ApiCallError: If an error occurs while sending the message to FCM service. + ValueError: If the input arguments are invalid. + """ + return _get_messaging_service(app).send_all(messages, dry_run) + +def send_multicast(multicast_message, dry_run=False, app=None): + """Sends the given mutlicast message to all tokens via Firebase Cloud Messaging (FCM). + + If the ``dry_run`` mode is enabled, the message will not be actually delivered to the + recipients. Instead FCM performs all the usual validations, and emulates the send operation. + + Args: + multicast_message: An instance of ``messaging.MulticastMessage``. + dry_run: A boolean indicating whether to run the operation in dry run mode (optional). + app: An App instance (optional). + + Returns: + BatchResponse: A ``messaging.BatchResponse`` instance. + + Raises: + ApiCallError: If an error occurs while sending the message to FCM service. + ValueError: If the input arguments are invalid. + """ + if not isinstance(multicast_message, MulticastMessage): + raise ValueError('Message must be an instance of messaging.MulticastMessage class.') + messages = [Message( + data=multicast_message.data, + notification=multicast_message.notification, + android=multicast_message.android, + webpush=multicast_message.webpush, + apns=multicast_message.apns, + token=token + ) for token in multicast_message.tokens] + return _get_messaging_service(app).send_all(messages, dry_run) + def subscribe_to_topic(tokens, topic, app=None): """Subscribes a list of registration tokens to an FCM topic. @@ -192,10 +253,57 @@ def __init__(self, code, message, detail=None): self.detail = detail +class BatchResponse(object): + """The response received from a batch request to the FCM API.""" + + def __init__(self, responses): + self._responses = responses + self._success_count = len([resp for resp in responses if resp.success]) + + @property + def responses(self): + """A list of ``messaging.SendResponse`` objects (possibly empty).""" + return self._responses + + @property + def success_count(self): + return self._success_count + + @property + def failure_count(self): + return len(self.responses) - self.success_count + + +class SendResponse(object): + """The response received from an individual batched request to the FCM API.""" + + def __init__(self, resp, exception): + self._exception = exception + self._message_id = None + if resp: + self._message_id = resp.get('name', None) + + @property + def message_id(self): + """A message ID string that uniquely identifies the sent the message.""" + return self._message_id + + @property + def success(self): + """A boolean indicating if the request was successful.""" + return self._message_id is not None and not self._exception + + @property + def exception(self): + """A ApiCallError if an error occurs while sending the message to FCM service.""" + return self._exception + + class _MessagingService(object): """Service class that implements Firebase Cloud Messaging (FCM) functionality.""" FCM_URL = 'https://fcm.googleapis.com/v1/projects/{0}/messages:send' + FCM_BATCH_URL = 'https://fcm.googleapis.com/batch' IID_URL = 'https://iid.googleapis.com' IID_HEADERS = {'access_token_auth': 'true'} JSON_ENCODER = _messaging_utils.MessageEncoder() @@ -234,9 +342,13 @@ def __init__(self, app): 'projectId option, or use service account credentials. Alternatively, set the ' 'GOOGLE_CLOUD_PROJECT environment variable.') self._fcm_url = _MessagingService.FCM_URL.format(project_id) + self._fcm_headers = { + 'X-GOOG-API-FORMAT-VERSION': '2', + 'X-FIREBASE-CLIENT': 'fire-admin-python/{0}'.format(firebase_admin.__version__), + } self._client = _http_client.JsonHttpClient(credential=app.credential.get_credential()) self._timeout = app.options.get('httpTimeout') - self._client_version = 'fire-admin-python/{0}'.format(firebase_admin.__version__) + self._transport = _auth.authorized_http(app.credential.get_credential()) @classmethod def encode_message(cls, message): @@ -245,16 +357,15 @@ def encode_message(cls, message): return cls.JSON_ENCODER.default(message) def send(self, message, dry_run=False): - data = {'message': _MessagingService.encode_message(message)} - if dry_run: - data['validate_only'] = True + data = self._message_data(message, dry_run) try: - headers = { - 'X-GOOG-API-FORMAT-VERSION': '2', - 'X-FIREBASE-CLIENT': self._client_version, - } resp = self._client.body( - 'post', url=self._fcm_url, headers=headers, json=data, timeout=self._timeout) + 'post', + url=self._fcm_url, + headers=self._fcm_headers, + json=data, + timeout=self._timeout + ) except requests.exceptions.RequestException as error: if error.response is not None: self._handle_fcm_error(error) @@ -264,6 +375,42 @@ def send(self, message, dry_run=False): else: return resp['name'] + def send_all(self, messages, dry_run=False): + """Sends the given messages to FCM via the batch API.""" + if not isinstance(messages, list): + raise ValueError('Messages must be an list of messaging.Message instances.') + if len(messages) > 100: + raise ValueError('send_all messages must not contain more than 100 messages.') + + responses = [] + + def batch_callback(_, response, error): + exception = None + if error: + exception = self._parse_batch_error(error) + send_response = SendResponse(response, exception) + responses.append(send_response) + + batch = http.BatchHttpRequest(batch_callback, _MessagingService.FCM_BATCH_URL) + for message in messages: + body = json.dumps(self._message_data(message, dry_run)) + req = http.HttpRequest( + http=self._transport, + postproc=self._postproc, + uri=self._fcm_url, + method='POST', + body=body, + headers=self._fcm_headers + ) + batch.add(req) + + try: + batch.execute() + except googleapiclient.http.HttpError as error: + raise self._parse_batch_error(error) + else: + return BatchResponse(responses) + def make_topic_management_request(self, tokens, topic, operation): """Invokes the IID service for topic management functionality.""" if isinstance(tokens, six.string_types): @@ -299,6 +446,17 @@ def make_topic_management_request(self, tokens, topic, operation): else: return TopicManagementResponse(resp) + def _message_data(self, message, dry_run): + data = {'message': _MessagingService.encode_message(message)} + if dry_run: + data['validate_only'] = True + return data + + def _postproc(self, _, body): + """Handle response from batch API request.""" + # This only gets called for 2xx responses. + return json.loads(body.decode()) + def _handle_fcm_error(self, error): """Handles errors received from the FCM API.""" data = {} @@ -309,20 +467,8 @@ def _handle_fcm_error(self, error): except ValueError: pass - error_dict = data.get('error', {}) - server_code = None - for detail in error_dict.get('details', []): - if detail.get('@type') == 'type.googleapis.com/google.firebase.fcm.v1.FcmError': - server_code = detail.get('errorCode') - break - if not server_code: - server_code = error_dict.get('status') - code = _MessagingService.FCM_ERROR_CODES.get(server_code, _MessagingService.UNKNOWN_ERROR) - - msg = error_dict.get('message') - if not msg: - msg = 'Unexpected HTTP response with status: {0}; body: {1}'.format( - error.response.status_code, error.response.content.decode()) + code, msg = _MessagingService._parse_fcm_error( + data, error.response.content, error.response.status_code) raise ApiCallError(code, msg, error) def _handle_iid_error(self, error): @@ -342,3 +488,39 @@ def _handle_iid_error(self, error): msg = 'Unexpected HTTP response with status: {0}; body: {1}'.format( error.response.status_code, error.response.content.decode()) raise ApiCallError(code, msg, error) + + def _parse_batch_error(self, error): + """Parses a googleapiclient.http.HttpError content in to an ApiCallError.""" + if error.content is None: + msg = 'Failed to call messaging API: {0}'.format(error) + return ApiCallError(self.INTERNAL_ERROR, msg, error) + + data = {} + try: + parsed_body = json.loads(error.content.decode()) + if isinstance(parsed_body, dict): + data = parsed_body + except ValueError: + pass + + code, msg = _MessagingService._parse_fcm_error(data, error.content, error.resp.status) + return ApiCallError(code, msg, error) + + @classmethod + def _parse_fcm_error(cls, data, content, status_code): + """Parses an error response from the FCM API to a ApiCallError.""" + error_dict = data.get('error', {}) + server_code = None + for detail in error_dict.get('details', []): + if detail.get('@type') == 'type.googleapis.com/google.firebase.fcm.v1.FcmError': + server_code = detail.get('errorCode') + break + if not server_code: + server_code = error_dict.get('status') + code = _MessagingService.FCM_ERROR_CODES.get(server_code, _MessagingService.UNKNOWN_ERROR) + + msg = error_dict.get('message') + if not msg: + msg = 'Unexpected HTTP response with status: {0}; body: {1}'.format( + status_code, content.decode()) + return code, msg diff --git a/requirements.txt b/requirements.txt index 03bbe7271..7a8d855bd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ tox >= 3.6.0 cachecontrol >= 0.12.4 google-api-core[grpc] >= 1.7.0, < 2.0.0dev; platform.python_implementation != 'PyPy' +google-api-python-client >= 1.7.8 google-cloud-firestore >= 0.31.0; platform.python_implementation != 'PyPy' google-cloud-storage >= 1.13.0 six >= 1.6.1 diff --git a/setup.py b/setup.py index 9aa36f89f..15ae97f93 100644 --- a/setup.py +++ b/setup.py @@ -39,6 +39,7 @@ install_requires = [ 'cachecontrol>=0.12.4', 'google-api-core[grpc] >= 1.7.0, < 2.0.0dev; platform.python_implementation != "PyPy"', + 'google-api-python-client >= 1.7.8', 'google-cloud-firestore>=0.31.0; platform.python_implementation != "PyPy"', 'google-cloud-storage>=1.13.0', 'six>=1.6.1' diff --git a/tests/test_messaging.py b/tests/test_messaging.py index 8be2b8d8f..de940b591 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -20,6 +20,8 @@ import pytest import six +from googleapiclient.http import HttpMockSequence + import firebase_admin from firebase_admin import messaging from tests import testutils @@ -38,6 +40,30 @@ def check_encoding(msg, expected=None): assert encoded == expected +class TestMulticastMessage(object): + + @pytest.mark.parametrize('tokens', NON_LIST_ARGS) + def test_invalid_tokens_type(self, tokens): + with pytest.raises(ValueError) as excinfo: + messaging.MulticastMessage(tokens=tokens) + if isinstance(tokens, list): + expected = 'MulticastMessage.tokens must not contain non-string values.' + assert str(excinfo.value) == expected + else: + expected = 'MulticastMessage.tokens must be a list of strings.' + assert str(excinfo.value) == expected + + def test_tokens_over_one_hundred(self): + with pytest.raises(ValueError) as excinfo: + messaging.MulticastMessage(tokens=['token' for _ in range(0, 101)]) + expected = 'MulticastMessage.tokens must not contain more than 100 tokens.' + assert str(excinfo.value) == expected + + def test_tokens_type(self): + messaging.MulticastMessage(tokens=['token']) + messaging.MulticastMessage(tokens=['token' for _ in range(0, 100)]) + + class TestMessageEncoder(object): @pytest.mark.parametrize('msg', [ @@ -1316,6 +1342,415 @@ def test_send_fcm_error_code(self, status): assert json.loads(recorder[0].body.decode()) == body +class TestBatch(object): + + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': 'explicit-project-id'}) + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + def _instrument_batch_messaging_service(self, app=None, status=200, payload=''): + if not app: + app = firebase_admin.get_app() + fcm_service = messaging._get_messaging_service(app) + if status == 200: + content_type = 'multipart/mixed; boundary=boundary' + else: + content_type = 'application/json' + fcm_service._transport = HttpMockSequence([ + ({'status': str(status), 'content-type': content_type}, payload), + ]) + return fcm_service + + def _batch_payload(self, payloads): + # payloads should be a list of (status_code, content) tuples + payload = '' + _playload_format = """--boundary\r\nContent-Type: application/http\r\n\ +Content-ID: \r\n\r\nHTTP/1.1 {} Success\r\n\ +Content-Type: application/json; charset=UTF-8\r\n\r\n{}\r\n\r\n""" + for (index, (status_code, content)) in enumerate(payloads): + payload += _playload_format.format(str(index + 1), str(status_code), content) + payload += '--boundary--' + return payload + + +class TestSendAll(TestBatch): + + def test_no_project_id(self): + def evaluate(): + app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') + with pytest.raises(ValueError): + messaging.send_all([messaging.Message(topic='foo')], app=app) + testutils.run_without_project_id(evaluate) + + @pytest.mark.parametrize('msg', NON_LIST_ARGS) + def test_invalid_send_all(self, msg): + with pytest.raises(ValueError) as excinfo: + messaging.send_all(msg) + if isinstance(msg, list): + expected = 'Message must be an instance of messaging.Message class.' + assert str(excinfo.value) == expected + else: + expected = 'Messages must be an list of messaging.Message instances.' + assert str(excinfo.value) == expected + + def test_invalid_over_one_hundred(self): + msg = messaging.Message(topic='foo') + with pytest.raises(ValueError) as excinfo: + messaging.send_all([msg for _ in range(0, 101)]) + expected = 'send_all messages must not contain more than 100 messages.' + assert str(excinfo.value) == expected + + def test_send_all(self): + payload = json.dumps({'name': 'message-id'}) + _ = self._instrument_batch_messaging_service( + payload=self._batch_payload([(200, payload), (200, payload)])) + msg = messaging.Message(topic='foo') + batch_response = messaging.send_all([msg, msg], dry_run=True) + assert batch_response.success_count is 2 + assert batch_response.failure_count is 0 + assert len(batch_response.responses) == 2 + assert [r.message_id for r in batch_response.responses] == ['message-id', 'message-id'] + assert all([r.success for r in batch_response.responses]) + assert not any([r.exception for r in batch_response.responses]) + + @pytest.mark.parametrize('status', HTTP_ERRORS) + def test_send_all_detailed_error(self, status): + success_payload = json.dumps({'name': 'message-id'}) + error_payload = json.dumps({ + 'error': { + 'status': 'INVALID_ARGUMENT', + 'message': 'test error' + } + }) + _ = self._instrument_batch_messaging_service( + payload=self._batch_payload([(200, success_payload), (status, error_payload)])) + msg = messaging.Message(topic='foo') + batch_response = messaging.send_all([msg, msg]) + assert batch_response.success_count is 1 + assert batch_response.failure_count is 1 + assert len(batch_response.responses) == 2 + success_response = batch_response.responses[0] + assert success_response.message_id == 'message-id' + assert success_response.success is True + assert success_response.exception is None + error_response = batch_response.responses[1] + assert error_response.message_id is None + assert error_response.success is False + assert error_response.exception is not None + exception = error_response.exception + assert str(exception) == 'test error' + assert str(exception.code) == 'invalid-argument' + + @pytest.mark.parametrize('status', HTTP_ERRORS) + def test_send_all_canonical_error_code(self, status): + success_payload = json.dumps({'name': 'message-id'}) + error_payload = json.dumps({ + 'error': { + 'status': 'NOT_FOUND', + 'message': 'test error' + } + }) + _ = self._instrument_batch_messaging_service( + payload=self._batch_payload([(200, success_payload), (status, error_payload)])) + msg = messaging.Message(topic='foo') + batch_response = messaging.send_all([msg, msg]) + assert batch_response.success_count is 1 + assert batch_response.failure_count is 1 + assert len(batch_response.responses) == 2 + success_response = batch_response.responses[0] + assert success_response.message_id == 'message-id' + assert success_response.success is True + assert success_response.exception is None + error_response = batch_response.responses[1] + assert error_response.message_id is None + assert error_response.success is False + assert error_response.exception is not None + exception = error_response.exception + assert str(exception) == 'test error' + assert str(exception.code) == 'registration-token-not-registered' + + @pytest.mark.parametrize('status', HTTP_ERRORS) + def test_send_all_fcm_error_code(self, status): + success_payload = json.dumps({'name': 'message-id'}) + error_payload = json.dumps({ + 'error': { + 'status': 'INVALID_ARGUMENT', + 'message': 'test error', + 'details': [ + { + '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', + 'errorCode': 'UNREGISTERED', + }, + ], + } + }) + _ = self._instrument_batch_messaging_service( + payload=self._batch_payload([(200, success_payload), (status, error_payload)])) + msg = messaging.Message(topic='foo') + batch_response = messaging.send_all([msg, msg]) + assert batch_response.success_count is 1 + assert batch_response.failure_count is 1 + assert len(batch_response.responses) == 2 + success_response = batch_response.responses[0] + assert success_response.message_id == 'message-id' + assert success_response.success is True + assert success_response.exception is None + error_response = batch_response.responses[1] + assert error_response.message_id is None + assert error_response.success is False + assert error_response.exception is not None + exception = error_response.exception + assert str(exception) == 'test error' + assert str(exception.code) == 'registration-token-not-registered' + + @pytest.mark.parametrize('status', HTTP_ERRORS) + def test_send_all_batch_error(self, status): + _ = self._instrument_batch_messaging_service(status=status, payload='{}') + msg = messaging.Message(topic='foo') + with pytest.raises(messaging.ApiCallError) as excinfo: + messaging.send_all([msg]) + expected = 'Unexpected HTTP response with status: {0}; body: {{}}'.format(status) + assert str(excinfo.value) == expected + assert str(excinfo.value.code) == messaging._MessagingService.UNKNOWN_ERROR + + @pytest.mark.parametrize('status', HTTP_ERRORS) + def test_send_all_batch_detailed_error(self, status): + payload = json.dumps({ + 'error': { + 'status': 'INVALID_ARGUMENT', + 'message': 'test error' + } + }) + _ = self._instrument_batch_messaging_service(status=status, payload=payload) + msg = messaging.Message(topic='foo') + with pytest.raises(messaging.ApiCallError) as excinfo: + messaging.send_all([msg]) + assert str(excinfo.value) == 'test error' + assert str(excinfo.value.code) == 'invalid-argument' + + @pytest.mark.parametrize('status', HTTP_ERRORS) + def test_send_all_batch_canonical_error_code(self, status): + payload = json.dumps({ + 'error': { + 'status': 'NOT_FOUND', + 'message': 'test error' + } + }) + _ = self._instrument_batch_messaging_service(status=status, payload=payload) + msg = messaging.Message(topic='foo') + with pytest.raises(messaging.ApiCallError) as excinfo: + messaging.send_all([msg]) + assert str(excinfo.value) == 'test error' + assert str(excinfo.value.code) == 'registration-token-not-registered' + + @pytest.mark.parametrize('status', HTTP_ERRORS) + def test_send_all_batch_fcm_error_code(self, status): + payload = json.dumps({ + 'error': { + 'status': 'INVALID_ARGUMENT', + 'message': 'test error', + 'details': [ + { + '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', + 'errorCode': 'UNREGISTERED', + }, + ], + } + }) + _ = self._instrument_batch_messaging_service(status=status, payload=payload) + msg = messaging.Message(topic='foo') + with pytest.raises(messaging.ApiCallError) as excinfo: + messaging.send_all([msg]) + assert str(excinfo.value) == 'test error' + assert str(excinfo.value.code) == 'registration-token-not-registered' + + +class TestSendMulticast(TestBatch): + + def test_no_project_id(self): + def evaluate(): + app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') + with pytest.raises(ValueError): + messaging.send_all([messaging.Message(topic='foo')], app=app) + testutils.run_without_project_id(evaluate) + + @pytest.mark.parametrize('msg', NON_LIST_ARGS) + def test_invalid_send_multicast(self, msg): + with pytest.raises(ValueError) as excinfo: + messaging.send_multicast(msg) + expected = 'Message must be an instance of messaging.MulticastMessage class.' + assert str(excinfo.value) == expected + + def test_send_multicast(self): + payload = json.dumps({'name': 'message-id'}) + _ = self._instrument_batch_messaging_service( + payload=self._batch_payload([(200, payload), (200, payload)])) + msg = messaging.MulticastMessage(tokens=['foo', 'foo']) + batch_response = messaging.send_multicast(msg, dry_run=True) + assert batch_response.success_count is 2 + assert batch_response.failure_count is 0 + assert len(batch_response.responses) == 2 + assert [r.message_id for r in batch_response.responses] == ['message-id', 'message-id'] + assert all([r.success for r in batch_response.responses]) + assert not any([r.exception for r in batch_response.responses]) + + @pytest.mark.parametrize('status', HTTP_ERRORS) + def test_send_multicast_detailed_error(self, status): + success_payload = json.dumps({'name': 'message-id'}) + error_payload = json.dumps({ + 'error': { + 'status': 'INVALID_ARGUMENT', + 'message': 'test error' + } + }) + _ = self._instrument_batch_messaging_service( + payload=self._batch_payload([(200, success_payload), (status, error_payload)])) + msg = messaging.MulticastMessage(tokens=['foo', 'foo']) + batch_response = messaging.send_multicast(msg) + assert batch_response.success_count is 1 + assert batch_response.failure_count is 1 + assert len(batch_response.responses) == 2 + success_response = batch_response.responses[0] + assert success_response.message_id == 'message-id' + assert success_response.success is True + assert success_response.exception is None + error_response = batch_response.responses[1] + assert error_response.message_id is None + assert error_response.success is False + assert error_response.exception is not None + exception = error_response.exception + assert str(exception) == 'test error' + assert str(exception.code) == 'invalid-argument' + + @pytest.mark.parametrize('status', HTTP_ERRORS) + def test_send_multicast_canonical_error_code(self, status): + success_payload = json.dumps({'name': 'message-id'}) + error_payload = json.dumps({ + 'error': { + 'status': 'NOT_FOUND', + 'message': 'test error' + } + }) + _ = self._instrument_batch_messaging_service( + payload=self._batch_payload([(200, success_payload), (status, error_payload)])) + msg = messaging.MulticastMessage(tokens=['foo', 'foo']) + batch_response = messaging.send_multicast(msg) + assert batch_response.success_count is 1 + assert batch_response.failure_count is 1 + assert len(batch_response.responses) == 2 + success_response = batch_response.responses[0] + assert success_response.message_id == 'message-id' + assert success_response.success is True + assert success_response.exception is None + error_response = batch_response.responses[1] + assert error_response.message_id is None + assert error_response.success is False + assert error_response.exception is not None + exception = error_response.exception + assert str(exception) == 'test error' + assert str(exception.code) == 'registration-token-not-registered' + + @pytest.mark.parametrize('status', HTTP_ERRORS) + def test_send_multicast_fcm_error_code(self, status): + success_payload = json.dumps({'name': 'message-id'}) + error_payload = json.dumps({ + 'error': { + 'status': 'INVALID_ARGUMENT', + 'message': 'test error', + 'details': [ + { + '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', + 'errorCode': 'UNREGISTERED', + }, + ], + } + }) + _ = self._instrument_batch_messaging_service( + payload=self._batch_payload([(200, success_payload), (status, error_payload)])) + msg = messaging.MulticastMessage(tokens=['foo', 'foo']) + batch_response = messaging.send_multicast(msg) + assert batch_response.success_count is 1 + assert batch_response.failure_count is 1 + assert len(batch_response.responses) == 2 + success_response = batch_response.responses[0] + assert success_response.message_id == 'message-id' + assert success_response.success is True + assert success_response.exception is None + error_response = batch_response.responses[1] + assert error_response.message_id is None + assert error_response.success is False + assert error_response.exception is not None + exception = error_response.exception + assert str(exception) == 'test error' + assert str(exception.code) == 'registration-token-not-registered' + + @pytest.mark.parametrize('status', HTTP_ERRORS) + def test_send_multicast_batch_error(self, status): + _ = self._instrument_batch_messaging_service(status=status, payload='{}') + msg = messaging.MulticastMessage(tokens=['foo']) + with pytest.raises(messaging.ApiCallError) as excinfo: + messaging.send_multicast(msg) + expected = 'Unexpected HTTP response with status: {0}; body: {{}}'.format(status) + assert str(excinfo.value) == expected + assert str(excinfo.value.code) == messaging._MessagingService.UNKNOWN_ERROR + + @pytest.mark.parametrize('status', HTTP_ERRORS) + def test_send_multicast_batch_detailed_error(self, status): + payload = json.dumps({ + 'error': { + 'status': 'INVALID_ARGUMENT', + 'message': 'test error' + } + }) + _ = self._instrument_batch_messaging_service(status=status, payload=payload) + msg = messaging.MulticastMessage(tokens=['foo']) + with pytest.raises(messaging.ApiCallError) as excinfo: + messaging.send_multicast(msg) + assert str(excinfo.value) == 'test error' + assert str(excinfo.value.code) == 'invalid-argument' + + @pytest.mark.parametrize('status', HTTP_ERRORS) + def test_send_multicast_batch_canonical_error_code(self, status): + payload = json.dumps({ + 'error': { + 'status': 'NOT_FOUND', + 'message': 'test error' + } + }) + _ = self._instrument_batch_messaging_service(status=status, payload=payload) + msg = messaging.MulticastMessage(tokens=['foo']) + with pytest.raises(messaging.ApiCallError) as excinfo: + messaging.send_multicast(msg) + assert str(excinfo.value) == 'test error' + assert str(excinfo.value.code) == 'registration-token-not-registered' + + @pytest.mark.parametrize('status', HTTP_ERRORS) + def test_send_multicast_batch_fcm_error_code(self, status): + payload = json.dumps({ + 'error': { + 'status': 'INVALID_ARGUMENT', + 'message': 'test error', + 'details': [ + { + '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', + 'errorCode': 'UNREGISTERED', + }, + ], + } + }) + _ = self._instrument_batch_messaging_service(status=status, payload=payload) + msg = messaging.MulticastMessage(tokens=['foo']) + with pytest.raises(messaging.ApiCallError) as excinfo: + messaging.send_multicast(msg) + assert str(excinfo.value) == 'test error' + assert str(excinfo.value.code) == 'registration-token-not-registered' + + class TestTopicManagement(object): _DEFAULT_RESPONSE = json.dumps({'results': [{}, {'error': 'error_reason'}]}) From 7c413f6db8fbfd95d7d75fc8c97f76cc462fc50a Mon Sep 17 00:00:00 2001 From: Arash Fatahzade Date: Mon, 20 May 2019 22:36:22 +0430 Subject: [PATCH 010/226] Coding style and broken links (#288) * Fix 2 blank line between function definitions * Fix broken links --- CONTRIBUTING.md | 6 +++--- firebase_admin/auth.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 487e3754a..39f865915 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -106,14 +106,14 @@ $ pip install -r requirements.txt # Install additional tools and dependencies We use [pylint](https://pylint.org/) for verifying source code format, and enforcing other Python programming best practices. -There is a pylint configuration file ([`.pylintrc`](../.pylintrc)) at the root of this Git +There is a pylint configuration file ([`.pylintrc`](.pylintrc)) at the root of this Git repository. This enables you to invoke pylint directly from the command line: ``` pylint firebase_admin ``` -However, it is recommended that you use the [`lint.sh`](../lint.sh) bash script to invoke +However, it is recommended that you use the [`lint.sh`](lint.sh) bash script to invoke pylint. This script will run the linter on both `firebase_admin` and the corresponding `tests` module. It suprresses some of the noisy warnings that get generated when running pylint on test code. Note that by default `lint.sh` will only @@ -226,7 +226,7 @@ pyenv install 3.3.0 # install Python 3.3.0 pyenv install pypy2-5.6.0 # install pypy2 ``` -Refer to the [`tox.ini`](../tox.ini) file for a list of target environments that we usually test. +Refer to the [`tox.ini`](tox.ini) file for a list of target environments that we usually test. Use pyenv to install all the required Python versions on your workstation. Verify that they are installed by running the following command: diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index 984c2babd..4f3d34b0b 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -121,6 +121,7 @@ def create_custom_token(uid, developer_claims=None, app=None): except _token_gen.ApiCallError as error: raise AuthError(error.code, str(error), error.detail) + def verify_id_token(id_token, app=None, check_revoked=False): """Verifies the signature and data for the provided JWT. @@ -150,6 +151,7 @@ def verify_id_token(id_token, app=None, check_revoked=False): _check_jwt_revoked(verified_claims, _ID_TOKEN_REVOKED, 'ID token', app) return verified_claims + def create_session_cookie(id_token, expires_in, app=None): """Creates a new Firebase session cookie from the given ID token and options. @@ -174,6 +176,7 @@ def create_session_cookie(id_token, expires_in, app=None): except _token_gen.ApiCallError as error: raise AuthError(error.code, str(error), error.detail) + def verify_session_cookie(session_cookie, check_revoked=False, app=None): """Verifies a Firebase session cookie. @@ -199,6 +202,7 @@ def verify_session_cookie(session_cookie, check_revoked=False, app=None): _check_jwt_revoked(verified_claims, _SESSION_COOKIE_REVOKED, 'session cookie', app) return verified_claims + def revoke_refresh_tokens(uid, app=None): """Revokes all refresh tokens for an existing user. @@ -214,6 +218,7 @@ def revoke_refresh_tokens(uid, app=None): user_manager = _get_auth_service(app).user_manager user_manager.update_user(uid, valid_since=int(time.time())) + def get_user(uid, app=None): """Gets the user data corresponding to the specified user ID. @@ -236,6 +241,7 @@ def get_user(uid, app=None): except _user_mgt.ApiCallError as error: raise AuthError(error.code, str(error), error.detail) + def get_user_by_email(email, app=None): """Gets the user data corresponding to the specified user email. @@ -281,6 +287,7 @@ def get_user_by_phone_number(phone_number, app=None): except _user_mgt.ApiCallError as error: raise AuthError(error.code, str(error), error.detail) + def list_users(page_token=None, max_results=_user_mgt.MAX_LIST_USERS_RESULTS, app=None): """Retrieves a page of user accounts from a Firebase project. @@ -381,6 +388,7 @@ def update_user(uid, **kwargs): except _user_mgt.ApiCallError as error: raise AuthError(error.code, str(error), error.detail) + def set_custom_user_claims(uid, custom_claims, app=None): """Sets additional claims on an existing user account. @@ -407,6 +415,7 @@ def set_custom_user_claims(uid, custom_claims, app=None): except _user_mgt.ApiCallError as error: raise AuthError(error.code, str(error), error.detail) + def delete_user(uid, app=None): """Deletes the user identified by the specified user ID. @@ -424,6 +433,7 @@ def delete_user(uid, app=None): except _user_mgt.ApiCallError as error: raise AuthError(error.code, str(error), error.detail) + def import_users(users, hash_alg=None, app=None): """Imports the specified list of users into Firebase Auth. @@ -453,6 +463,7 @@ def import_users(users, hash_alg=None, app=None): except _user_mgt.ApiCallError as error: raise AuthError(error.code, str(error), error.detail) + def generate_password_reset_link(email, action_code_settings=None, app=None): """Generates the out-of-band email action link for password reset flows for the specified email address. @@ -477,6 +488,7 @@ def generate_password_reset_link(email, action_code_settings=None, app=None): except _user_mgt.ApiCallError as error: raise AuthError(error.code, str(error), error.detail) + def generate_email_verification_link(email, action_code_settings=None, app=None): """Generates the out-of-band email action link for email verification flows for the specified email address. @@ -501,6 +513,7 @@ def generate_email_verification_link(email, action_code_settings=None, app=None) except _user_mgt.ApiCallError as error: raise AuthError(error.code, str(error), error.detail) + def generate_sign_in_with_email_link(email, action_code_settings, app=None): """Generates the out-of-band email action link for email link sign-in flows, using the action code settings provided. @@ -525,6 +538,7 @@ def generate_sign_in_with_email_link(email, action_code_settings, app=None): except _user_mgt.ApiCallError as error: raise AuthError(error.code, str(error), error.detail) + def _check_jwt_revoked(verified_claims, error_code, label, app): user = get_user(verified_claims.get('uid'), app=app) if verified_claims.get('iat') * 1000 < user.tokens_valid_after_timestamp: From 4488f5367c78bba66dd6d7a4147a5e9cf0ad84d0 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Mon, 20 May 2019 11:14:46 -0700 Subject: [PATCH 011/226] Introducing auth.DELETE_ATTRIBUTE sentinel value (#285) * Introducing auth.DELETE_ATTRIBUTE sentinel value for updating users * Fixing a lint error --- CHANGELOG.md | 5 ++++- firebase_admin/_user_mgt.py | 21 ++++++++++++++------- firebase_admin/auth.py | 11 +++++++---- tests/test_user_mgt.py | 22 +++++++++++++++++++++- 4 files changed, 46 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5ceae9ffd..7dadefbda 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,9 @@ # Unreleased -- +- [added] Added a new `auth.DELETE_ATTRIBUTE` sentinel value, which can be + used to delete `phone_number`, `display_name`, `photo_url` and `custom_claims` + attributes from a user account. It is now recommended to use this sentinel + value over passing `None` for deleting attributes. # v2.16.0 diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index 71e2055ad..24bb2bdb6 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -36,11 +36,18 @@ MAX_LIST_USERS_RESULTS = 1000 MAX_IMPORT_USERS_SIZE = 1000 -class _Unspecified(object): - pass + +class Sentinel(object): + + def __init__(self, description): + self.description = description + # Use this internally, until sentinels are available in the public API. -_UNSPECIFIED = _Unspecified() +_UNSPECIFIED = Sentinel('No value specified') + + +DELETE_ATTRIBUTE = Sentinel('Value used to delete an attribute from a user profile') class ApiCallError(Exception): @@ -546,12 +553,12 @@ def update_user(self, uid, display_name=_UNSPECIFIED, email=None, phone_number=_ remove = [] if display_name is not _UNSPECIFIED: - if display_name is None: + if display_name is None or display_name is DELETE_ATTRIBUTE: remove.append('DISPLAY_NAME') else: payload['displayName'] = _auth_utils.validate_display_name(display_name) if photo_url is not _UNSPECIFIED: - if photo_url is None: + if photo_url is None or photo_url is DELETE_ATTRIBUTE: remove.append('PHOTO_URL') else: payload['photoUrl'] = _auth_utils.validate_photo_url(photo_url) @@ -559,13 +566,13 @@ def update_user(self, uid, display_name=_UNSPECIFIED, email=None, phone_number=_ payload['deleteAttribute'] = remove if phone_number is not _UNSPECIFIED: - if phone_number is None: + if phone_number is None or phone_number is DELETE_ATTRIBUTE: payload['deleteProvider'] = ['phone'] else: payload['phoneNumber'] = _auth_utils.validate_phone(phone_number) if custom_claims is not _UNSPECIFIED: - if custom_claims is None: + if custom_claims is None or custom_claims is DELETE_ATTRIBUTE: custom_claims = {} json_claims = json.dumps(custom_claims) if isinstance( custom_claims, dict) else custom_claims diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index 4f3d34b0b..0800d7c1e 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -37,6 +37,7 @@ __all__ = [ 'ActionCodeSettings', 'AuthError', + 'DELETE_ATTRIBUTE', 'ErrorInfo', 'ExportedUserRecord', 'ImportUserRecord', @@ -68,6 +69,7 @@ ] ActionCodeSettings = _user_mgt.ActionCodeSettings +DELETE_ATTRIBUTE = _user_mgt.DELETE_ATTRIBUTE ErrorInfo = _user_import.ErrorInfo ExportedUserRecord = _user_mgt.ExportedUserRecord ListUsersPage = _user_mgt.ListUsersPage @@ -359,17 +361,18 @@ def update_user(uid, **kwargs): Keyword Args: display_name: The user's display name (optional). Can be removed by explicitly passing - None. + ``auth.DELETE_ATTRIBUTE``. email: The user's primary email (optional). email_verified: A boolean indicating whether or not the user's primary email is verified (optional). phone_number: The user's primary phone number (optional). Can be removed by explicitly - passing None. - photo_url: The user's photo URL (optional). Can be removed by explicitly passing None. + passing ``auth.DELETE_ATTRIBUTE``. + photo_url: The user's photo URL (optional). Can be removed by explicitly passing + ``auth.DELETE_ATTRIBUTE``. password: The user's raw, unhashed password. (optional). disabled: A boolean indicating whether or not the user account is disabled (optional). custom_claims: A dictionary or a JSON string contining the custom claims to be set on the - user account (optional). + user account (optional). To remove all custom claims, pass ``auth.DELETE_ATTRIBUTE``. valid_since: An integer signifying the seconds since the epoch. This field is set by ``revoke_refresh_tokens`` and it is discouraged to set this field directly. diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index 6e033fae4..797e0ce59 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -381,7 +381,13 @@ def test_update_user_custom_claims(self, user_mgt_app): request = json.loads(recorder[0].body.decode()) assert request == {'localId' : 'testuser', 'customAttributes' : json.dumps(claims)} - def test_update_user_delete_fields(self, user_mgt_app): + def test_delete_user_custom_claims(self, user_mgt_app): + user_mgt, recorder = _instrument_user_manager(user_mgt_app, 200, '{"localId":"testuser"}') + user_mgt.update_user('testuser', custom_claims=auth.DELETE_ATTRIBUTE) + request = json.loads(recorder[0].body.decode()) + assert request == {'localId' : 'testuser', 'customAttributes' : json.dumps({})} + + def test_update_user_delete_fields_with_none(self, user_mgt_app): user_mgt, recorder = _instrument_user_manager(user_mgt_app, 200, '{"localId":"testuser"}') user_mgt.update_user('testuser', display_name=None, photo_url=None, phone_number=None) request = json.loads(recorder[0].body.decode()) @@ -391,6 +397,20 @@ def test_update_user_delete_fields(self, user_mgt_app): 'deleteProvider' : ['phone'], } + def test_update_user_delete_fields(self, user_mgt_app): + user_mgt, recorder = _instrument_user_manager(user_mgt_app, 200, '{"localId":"testuser"}') + user_mgt.update_user( + 'testuser', + display_name=auth.DELETE_ATTRIBUTE, + photo_url=auth.DELETE_ATTRIBUTE, + phone_number=auth.DELETE_ATTRIBUTE) + request = json.loads(recorder[0].body.decode()) + assert request == { + 'localId' : 'testuser', + 'deleteAttribute' : ['DISPLAY_NAME', 'PHOTO_URL'], + 'deleteProvider' : ['phone'], + } + def test_update_user_error(self, user_mgt_app): _instrument_user_manager(user_mgt_app, 500, '{"error":"test"}') with pytest.raises(auth.AuthError) as excinfo: From 5ef6f6dddb23a4c2dc54ea352f2999db2f2cfe17 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Wed, 22 May 2019 11:08:24 -0700 Subject: [PATCH 012/226] Added integration tests for the new multicast APIs (#289) * Added integration tests for the new multicast APIs * Fixing a lint error --- CHANGELOG.md | 2 ++ integration/test_messaging.py | 62 +++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7dadefbda..c9a3bc42e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,7 @@ # Unreleased +- [added] Added new `send_all()` and `send_multicast()` APIs to the + `messasing` module. - [added] Added a new `auth.DELETE_ATTRIBUTE` sentinel value, which can be used to delete `phone_number`, `display_name`, `photo_url` and `custom_claims` attributes from a user account. It is now recommended to use this sentinel diff --git a/integration/test_messaging.py b/integration/test_messaging.py index e58737c70..7ebd5866a 100644 --- a/integration/test_messaging.py +++ b/integration/test_messaging.py @@ -47,6 +47,68 @@ def test_send(): msg_id = messaging.send(msg, dry_run=True) assert re.match('^projects/.*/messages/.*$', msg_id) +def test_send_all(): + messages = [ + messaging.Message( + topic='foo-bar', notification=messaging.Notification('Title', 'Body')), + messaging.Message( + topic='foo-bar', notification=messaging.Notification('Title', 'Body')), + messaging.Message( + token='not-a-token', notification=messaging.Notification('Title', 'Body')), + ] + + batch_response = messaging.send_all(messages, dry_run=True) + + assert batch_response.success_count == 2 + assert batch_response.failure_count == 1 + assert len(batch_response.responses) == 3 + + response = batch_response.responses[0] + assert response.success is True + assert response.exception is None + assert re.match('^projects/.*/messages/.*$', response.message_id) + + response = batch_response.responses[1] + assert response.success is True + assert response.exception is None + assert re.match('^projects/.*/messages/.*$', response.message_id) + + response = batch_response.responses[2] + assert response.success is False + assert response.exception is not None + assert response.message_id is None + +def test_send_one_hundred(): + messages = [] + for msg_number in range(100): + topic = 'foo-bar-{0}'.format(msg_number % 10) + messages.append(messaging.Message(topic=topic)) + + batch_response = messaging.send_all(messages, dry_run=True) + + assert batch_response.success_count == 100 + assert batch_response.failure_count == 0 + assert len(batch_response.responses) == 100 + for response in batch_response.responses: + assert response.success is True + assert response.exception is None + assert re.match('^projects/.*/messages/.*$', response.message_id) + +def test_send_multicast(): + multicast = messaging.MulticastMessage( + notification=messaging.Notification('Title', 'Body'), + tokens=['not-a-token', 'also-not-a-token']) + + batch_response = messaging.send_multicast(multicast) + + assert batch_response.success_count is 0 + assert batch_response.failure_count == 2 + assert len(batch_response.responses) == 2 + for response in batch_response.responses: + assert response.success is False + assert response.exception is not None + assert response.message_id is None + def test_subscribe(): resp = messaging.subscribe_to_topic(_REGISTRATION_TOKEN, 'mock-topic') assert resp.success_count + resp.failure_count == 1 From d6a1671843ecde6d70c22b6e1be0753549d8a027 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Thu, 23 May 2019 10:21:06 -0700 Subject: [PATCH 013/226] Bumped version to 2.17.0 (#292) * Bumped version to 2.17.0 * Removed additional whitespace --- CHANGELOG.md | 4 ++++ firebase_admin/__about__.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c9a3bc42e..5fc1a759e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Unreleased +- + +# v2.17.0 + - [added] Added new `send_all()` and `send_multicast()` APIs to the `messasing` module. - [added] Added a new `auth.DELETE_ATTRIBUTE` sentinel value, which can be diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index cba8bc848..a5141c5e3 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '2.16.0' +__version__ = '2.17.0' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From f2dd24ea6b3c471e4b3146b6d6bb7dc993fcb11d Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Thu, 23 May 2019 13:03:23 -0700 Subject: [PATCH 014/226] API doc updates (#293) --- firebase_admin/_messaging_utils.py | 5 +---- firebase_admin/db.py | 2 +- firebase_admin/messaging.py | 10 +++++----- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/firebase_admin/_messaging_utils.py b/firebase_admin/_messaging_utils.py index 373adf68c..17067f175 100644 --- a/firebase_admin/_messaging_utils.py +++ b/firebase_admin/_messaging_utils.py @@ -57,11 +57,8 @@ def __init__(self, data=None, notification=None, android=None, webpush=None, apn class MulticastMessage(object): """A message that can be sent to multiple tokens via Firebase Cloud Messaging. - Contains payload information as well as recipient information. In particular, the message must - contain exactly one of token, topic or condition fields. - Args: - tokens: A list of registration token of the device to which the message should be sent. + tokens: A list of registration tokens of targeted devices. data: A dictionary of data fields (optional). All keys and values in the dictionary must be strings. notification: An instance of ``messaging.Notification`` (optional). diff --git a/firebase_admin/db.py b/firebase_admin/db.py index f1bbeba8e..778699d47 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -351,7 +351,7 @@ def listen(self, callback): The specified callback function will get invoked with ``db.Event`` objects for each realtime update received from the database. It will also get called whenever the SDK - reconnects to the server due to network issues and credential expiration. In general, + reconnects to the server due to network issues or credential expiration. In general, the OAuth2 credentials used to authorize connections to the server expire every hour. Therefore clients should expect the ``callback`` to fire at least once every hour, even if there are no updates in the database. diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index 8129f8de1..35d9e4ccd 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -94,7 +94,7 @@ def send(message, dry_run=False, app=None): string: A message ID string that uniquely identifies the sent the message. Raises: - ApiCallError: If an error occurs while sending the message to FCM service. + ApiCallError: If an error occurs while sending the message to the FCM service. ValueError: If the input arguments are invalid. """ return _get_messaging_service(app).send(message, dry_run) @@ -114,7 +114,7 @@ def send_all(messages, dry_run=False, app=None): BatchResponse: A ``messaging.BatchResponse`` instance. Raises: - ApiCallError: If an error occurs while sending the message to FCM service. + ApiCallError: If an error occurs while sending the message to the FCM service. ValueError: If the input arguments are invalid. """ return _get_messaging_service(app).send_all(messages, dry_run) @@ -134,7 +134,7 @@ def send_multicast(multicast_message, dry_run=False, app=None): BatchResponse: A ``messaging.BatchResponse`` instance. Raises: - ApiCallError: If an error occurs while sending the message to FCM service. + ApiCallError: If an error occurs while sending the message to the FCM service. ValueError: If the input arguments are invalid. """ if not isinstance(multicast_message, MulticastMessage): @@ -285,7 +285,7 @@ def __init__(self, resp, exception): @property def message_id(self): - """A message ID string that uniquely identifies the sent the message.""" + """A message ID string that uniquely identifies the message.""" return self._message_id @property @@ -295,7 +295,7 @@ def success(self): @property def exception(self): - """A ApiCallError if an error occurs while sending the message to FCM service.""" + """An ApiCallError if an error occurs while sending the message to the FCM service.""" return self._exception From 8cf7291957ea58684d241508d8e52c3b28eb1afd Mon Sep 17 00:00:00 2001 From: wangwei <40977340+willawang8908@users.noreply.github.com> Date: Sat, 20 Jul 2019 05:19:55 +0800 Subject: [PATCH 015/226] Analytics label based on @chemidy (#310) * add analytics_label in FcmOptions * fix errors * fix errors * add analytics_label encoders * fix line-too-long * fix lint errors --- firebase_admin/_messaging_utils.py | 109 +++++++++++++++++++++++++++-- firebase_admin/messaging.py | 7 ++ tests/test_messaging.py | 66 ++++++++++++++++- 3 files changed, 174 insertions(+), 8 deletions(-) diff --git a/firebase_admin/_messaging_utils.py b/firebase_admin/_messaging_utils.py index 17067f175..09f7daf87 100644 --- a/firebase_admin/_messaging_utils.py +++ b/firebase_admin/_messaging_utils.py @@ -36,6 +36,7 @@ class Message(object): android: An instance of ``messaging.AndroidConfig`` (optional). webpush: An instance of ``messaging.WebpushConfig`` (optional). apns: An instance of ``messaging.ApnsConfig`` (optional). + fcm_options: An instance of ``messaging.FcmOptions`` (optional). token: The registration token of the device to which the message should be sent (optional). topic: Name of the FCM topic to which the message should be sent (optional). Topic name may contain the ``/topics/`` prefix. @@ -43,12 +44,13 @@ class Message(object): """ def __init__(self, data=None, notification=None, android=None, webpush=None, apns=None, - token=None, topic=None, condition=None): + fcm_options=None, token=None, topic=None, condition=None): self.data = data self.notification = notification self.android = android self.webpush = webpush self.apns = apns + self.fcm_options = fcm_options self.token = token self.topic = topic self.condition = condition @@ -65,8 +67,10 @@ class MulticastMessage(object): android: An instance of ``messaging.AndroidConfig`` (optional). webpush: An instance of ``messaging.WebpushConfig`` (optional). apns: An instance of ``messaging.ApnsConfig`` (optional). + fcm_options: An instance of ``messaging.FcmOptions`` (optional). """ - def __init__(self, tokens, data=None, notification=None, android=None, webpush=None, apns=None): + def __init__(self, tokens, data=None, notification=None, android=None, webpush=None, apns=None, + fcm_options=None): _Validators.check_string_list('MulticastMessage.tokens', tokens) if len(tokens) > 100: raise ValueError('MulticastMessage.tokens must not contain more than 100 tokens.') @@ -76,6 +80,7 @@ def __init__(self, tokens, data=None, notification=None, android=None, webpush=N self.android = android self.webpush = webpush self.apns = apns + self.fcm_options = fcm_options class Notification(object): @@ -107,16 +112,18 @@ class AndroidConfig(object): data: A dictionary of data fields (optional). All keys and values in the dictionary must be strings. When specified, overrides any data fields set via ``Message.data``. notification: A ``messaging.AndroidNotification`` to be included in the message (optional). + fcm_options: A ``messaging.AndroidFcmOptions`` to be included in the message (optional). """ def __init__(self, collapse_key=None, priority=None, ttl=None, restricted_package_name=None, - data=None, notification=None): + data=None, notification=None, fcm_options=None): self.collapse_key = collapse_key self.priority = priority self.ttl = ttl self.restricted_package_name = restricted_package_name self.data = data self.notification = notification + self.fcm_options = fcm_options class AndroidNotification(object): @@ -165,6 +172,18 @@ def __init__(self, title=None, body=None, icon=None, color=None, sound=None, tag self.channel_id = channel_id +class AndroidFcmOptions(object): + """Options for features provided by the FCM SDK for Android. + + Args: + analytics_label: contains additional options for features provided by the FCM Android SDK + (optional). + """ + + def __init__(self, analytics_label=None): + self.analytics_label = analytics_label + + class WebpushConfig(object): """Webpush-specific options that can be included in a message. @@ -279,14 +298,17 @@ class APNSConfig(object): Args: headers: A dictionary of headers (optional). payload: A ``messaging.APNSPayload`` to be included in the message (optional). + fcm_options: A ``messaging.APNSFcmOptions`` instance to be included in the message + (optional). .. _APNS Documentation: https://developer.apple.com/library/content/documentation\ /NetworkingInternet/Conceptual/RemoteNotificationsPG/CommunicatingwithAPNs.html """ - def __init__(self, headers=None, payload=None): + def __init__(self, headers=None, payload=None, fcm_options=None): self.headers = headers self.payload = payload + self.fcm_options = fcm_options class APNSPayload(object): @@ -387,6 +409,29 @@ def __init__(self, title=None, subtitle=None, body=None, loc_key=None, loc_args= self.launch_image = launch_image +class APNSFcmOptions(object): + """Options for features provided by the FCM SDK for iOS. + + Args: + analytics_label: contains additional options for features provided by the FCM iOS SDK + (optional). + """ + + def __init__(self, analytics_label=None): + self.analytics_label = analytics_label + + +class FcmOptions(object): + """Options for features provided by SDK. + + Args: + analytics_label: contains additional options to use across all platforms (optional). + """ + + def __init__(self, analytics_label=None): + self.analytics_label = analytics_label + + class _Validators(object): """A collection of data validation utilities. @@ -442,6 +487,14 @@ def check_string_list(cls, label, value): raise ValueError('{0} must not contain non-string values.'.format(label)) return value + @classmethod + def check_analytics_label(cls, label, value): + """Checks if the given value is a valid analytics label.""" + value = _Validators.check_string(label, value) + if value is not None and not re.match(r'^[a-zA-Z0-9-_.~%]{1,50}$', value): + raise ValueError('Malformed {}.'.format(label)) + return value + class MessageEncoder(json.JSONEncoder): """A custom JSONEncoder implementation for serializing Message instances into JSON.""" @@ -468,6 +521,7 @@ def encode_android(cls, android): 'restricted_package_name': _Validators.check_string( 'AndroidConfig.restricted_package_name', android.restricted_package_name), 'ttl': cls.encode_ttl(android.ttl), + 'fcm_options': cls.encode_android_fcm_options(android.fcm_options), } result = cls.remove_null_values(result) priority = result.get('priority') @@ -475,6 +529,21 @@ def encode_android(cls, android): raise ValueError('AndroidConfig.priority must be "high" or "normal".') return result + @classmethod + def encode_android_fcm_options(cls, fcm_options): + """Encodes a AndroidFcmOptions instance into a json.""" + if fcm_options is None: + return None + if not isinstance(fcm_options, AndroidFcmOptions): + raise ValueError('AndroidConfig.fcm_options must be an instance of ' + 'AndroidFcmOptions class.') + result = { + 'analytics_label': _Validators.check_analytics_label( + 'AndroidFcmOptions.analytics_label', fcm_options.analytics_label), + } + result = cls.remove_null_values(result) + return result + @classmethod def encode_ttl(cls, ttl): """Encodes a AndroidConfig TTL duration into a string.""" @@ -553,7 +622,7 @@ def encode_webpush(cls, webpush): 'headers': _Validators.check_string_dict( 'WebpushConfig.headers', webpush.headers), 'notification': cls.encode_webpush_notification(webpush.notification), - 'fcmOptions': cls.encode_webpush_fcm_options(webpush.fcm_options), + 'fcm_options': cls.encode_webpush_fcm_options(webpush.fcm_options), } return cls.remove_null_values(result) @@ -653,6 +722,7 @@ def encode_apns(cls, apns): 'headers': _Validators.check_string_dict( 'APNSConfig.headers', apns.headers), 'payload': cls.encode_apns_payload(apns.payload), + 'fcm_options': cls.encode_apns_fcm_options(apns.fcm_options), } return cls.remove_null_values(result) @@ -670,6 +740,20 @@ def encode_apns_payload(cls, payload): result[key] = value return cls.remove_null_values(result) + @classmethod + def encode_apns_fcm_options(cls, fcm_options): + """Encodes an APNSFcmOptions instance into JSON.""" + if fcm_options is None: + return None + if not isinstance(fcm_options, APNSFcmOptions): + raise ValueError('APNSConfig.fcm_options must be an instance of APNSFcmOptions class.') + result = { + 'analytics_label': _Validators.check_analytics_label( + 'APNSFcmOptions.analytics_label', fcm_options.analytics_label), + } + result = cls.remove_null_values(result) + return result + @classmethod def encode_aps(cls, aps): """Encodes an Aps instance into JSON.""" @@ -790,6 +874,7 @@ def default(self, obj): # pylint: disable=method-hidden 'token': _Validators.check_string('Message.token', obj.token, non_empty=True), 'topic': _Validators.check_string('Message.topic', obj.topic, non_empty=True), 'webpush': MessageEncoder.encode_webpush(obj.webpush), + 'fcm_options': MessageEncoder.encode_fcm_options(obj.fcm_options), } result['topic'] = MessageEncoder.sanitize_topic_name(result.get('topic')) result = MessageEncoder.remove_null_values(result) @@ -797,3 +882,17 @@ def default(self, obj): # pylint: disable=method-hidden if target_count != 1: raise ValueError('Exactly one of token, topic or condition must be specified.') return result + + @classmethod + def encode_fcm_options(cls, fcm_options): + """Encodes an FcmOptions instance into JSON.""" + if fcm_options is None: + return None + if not isinstance(fcm_options, FcmOptions): + raise ValueError('Message.fcm_options must be an instance of FcmOptions class.') + result = { + 'analytics_label': _Validators.check_analytics_label( + 'FcmOptions.analytics_label', fcm_options.analytics_label), + } + result = cls.remove_null_values(result) + return result diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index 35d9e4ccd..ddaef19f0 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -33,8 +33,10 @@ __all__ = [ 'AndroidConfig', + 'AndroidFcmOptions', 'AndroidNotification', 'APNSConfig', + 'APNSFcmOptions', 'APNSPayload', 'ApiCallError', 'Aps', @@ -42,6 +44,7 @@ 'BatchResponse', 'CriticalSound', 'ErrorInfo', + 'FcmOptions', 'Message', 'MulticastMessage', 'Notification', @@ -61,12 +64,15 @@ AndroidConfig = _messaging_utils.AndroidConfig +AndroidFcmOptions = _messaging_utils.AndroidFcmOptions AndroidNotification = _messaging_utils.AndroidNotification APNSConfig = _messaging_utils.APNSConfig +APNSFcmOptions = _messaging_utils.APNSFcmOptions APNSPayload = _messaging_utils.APNSPayload Aps = _messaging_utils.Aps ApsAlert = _messaging_utils.ApsAlert CriticalSound = _messaging_utils.CriticalSound +FcmOptions = _messaging_utils.FcmOptions Message = _messaging_utils.Message MulticastMessage = _messaging_utils.MulticastMessage Notification = _messaging_utils.Notification @@ -145,6 +151,7 @@ def send_multicast(multicast_message, dry_run=False, app=None): android=multicast_message.android, webpush=multicast_message.webpush, apns=multicast_message.apns, + fcm_options=multicast_message.fcm_options, token=token ) for token in multicast_message.tokens] return _get_messaging_service(app).send_all(messages, dry_run) diff --git a/tests/test_messaging.py b/tests/test_messaging.py index de940b591..878e1365b 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -120,6 +120,15 @@ def test_data_message(self): def test_prefixed_topic(self): check_encoding(messaging.Message(topic='/topics/topic'), {'topic': 'topic'}) + def test_fcm_options(self): + check_encoding( + messaging.Message( + topic='topic', fcm_options=messaging.FcmOptions('analytics_label_v1')), + {'topic': 'topic', 'fcm_options': {'analytics_label': 'analytics_label_v1'}}) + check_encoding( + messaging.Message(topic='topic', fcm_options=messaging.FcmOptions()), + {'topic': 'topic'}) + class TestNotificationEncoder(object): @@ -157,6 +166,47 @@ def test_notification_message(self): {'topic': 'topic', 'notification': {'title': 't'}}) +class TestFcmOptionEncoder(object): + + @pytest.mark.parametrize('label', [ + '!', + 'THIS_IS_LONGER_THAN_50_CHARACTERS_WHICH_IS_NOT_ALLOWED', + '', + ]) + def test_invalid_fcm_options(self, label): + with pytest.raises(ValueError) as excinfo: + check_encoding(messaging.Message( + topic='topic', + fcm_options=messaging.FcmOptions(label) + )) + expected = 'Malformed FcmOptions.analytics_label.' + assert str(excinfo.value) == expected + + def test_fcm_options(self): + check_encoding( + messaging.Message( + topic='topic', + fcm_options=messaging.FcmOptions(), + android=messaging.AndroidConfig(fcm_options=messaging.AndroidFcmOptions()), + apns=messaging.APNSConfig(fcm_options=messaging.APNSFcmOptions()) + ), + {'topic': 'topic'}) + check_encoding( + messaging.Message( + topic='topic', + fcm_options=messaging.FcmOptions('message-label'), + android=messaging.AndroidConfig( + fcm_options=messaging.AndroidFcmOptions('android-label')), + apns=messaging.APNSConfig(fcm_options=messaging.APNSFcmOptions('apns-label')) + ), + { + 'topic': 'topic', + 'fcm_options': {'analytics_label': 'message-label'}, + 'android': {'fcm_options': {'analytics_label': 'android-label'}}, + 'apns': {'fcm_options': {'analytics_label': 'apns-label'}}, + }) + + class TestAndroidConfigEncoder(object): @pytest.mark.parametrize('data', NON_OBJECT_ARGS) @@ -216,7 +266,8 @@ def test_android_config(self): restricted_package_name='package', priority='high', ttl=123, - data={'k1': 'v1', 'k2': 'v2'} + data={'k1': 'v1', 'k2': 'v2'}, + fcm_options=messaging.AndroidFcmOptions('analytics_label_v1') ) ) expected = { @@ -230,6 +281,9 @@ def test_android_config(self): 'k1': 'v1', 'k2': 'v2', }, + 'fcm_options': { + 'analytics_label': 'analytics_label_v1', + }, }, } check_encoding(msg, expected) @@ -484,7 +538,7 @@ def test_webpush_notification(self): expected = { 'topic': 'topic', 'webpush': { - 'fcmOptions': { + 'fcm_options': { 'link': 'https://example', }, }, @@ -714,7 +768,10 @@ def test_invalid_headers(self, data): def test_apns_config(self): msg = messaging.Message( topic='topic', - apns=messaging.APNSConfig(headers={'h1': 'v1', 'h2': 'v2'}) + apns=messaging.APNSConfig( + headers={'h1': 'v1', 'h2': 'v2'}, + fcm_options=messaging.APNSFcmOptions('analytics_label_v1') + ), ) expected = { 'topic': 'topic', @@ -723,6 +780,9 @@ def test_apns_config(self): 'h1': 'v1', 'h2': 'v2', }, + 'fcm_options': { + 'analytics_label': 'analytics_label_v1', + }, }, } check_encoding(msg, expected) From aecd35e1c91af966c5623dead46fac5b40082150 Mon Sep 17 00:00:00 2001 From: Yuchen Shi Date: Thu, 1 Aug 2019 14:22:39 -0700 Subject: [PATCH 016/226] Support RTDB Emulator via FIREBASE_DATABASE_EMULATOR_HOST. (#313) * Support RTDB Emulator via FIREBASE_DATABASE_EMULATOR_HOST. * Fix linter issues. * Defer ApplicationDefault init and remove FakeCredential class. * Fix lazy initialization and tests. * Address PR feedback and clean up URL parsing logic. * Use non-global app for db tests. * Simplify app project_id initialization logic. * Docstring. * Simplify parsing logic again! * Docstring and indentation fix. * Return! --- CONTRIBUTING.md | 12 +++ firebase_admin/__init__.py | 58 +++++++------- firebase_admin/credentials.py | 24 ++++-- firebase_admin/db.py | 139 ++++++++++++++++++++++++++-------- integration/conftest.py | 1 - integration/test_db.py | 57 ++++++++++---- tests/test_credentials.py | 5 +- tests/test_db.py | 49 ++++++++++-- tox.ini | 4 +- 9 files changed, 260 insertions(+), 89 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 39f865915..8c58e63e9 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -195,6 +195,18 @@ Now you can invoke the integration test suite as follows: pytest integration/ --cert scripts/cert.json --apikey scripts/apikey.txt ``` +### Emulator-based Integration Testing + +Some integration tests can run against emulators. This allows local testing +without using real projects or credentials. For now, only the RTDB Emulator +is supported. + +First, install the Firebase CLI, then run: + +``` +firebase emulators:exec --only database --project fake-project-id 'pytest integration/test_db.py' +``` + ### Test Coverage To review the test coverage, run `pytest` with the `--cov` flag. To view a detailed line by line diff --git a/firebase_admin/__init__.py b/firebase_admin/__init__.py index d802e15a0..bc9526378 100644 --- a/firebase_admin/__init__.py +++ b/firebase_admin/__init__.py @@ -215,39 +215,15 @@ def __init__(self, name, credential, options): self._options = _AppOptions(options) self._lock = threading.RLock() self._services = {} - self._project_id = App._lookup_project_id(self._credential, self._options) - @classmethod - def _lookup_project_id(cls, credential, options): - """Looks up the Firebase project ID associated with an App. - - This method first inspects the app options for a ``projectId`` entry. Then it attempts to - get the project ID from the credential used to initialize the app. If that also fails, - attempts to look up the ``GOOGLE_CLOUD_PROJECT`` and ``GCLOUD_PROJECT`` environment - variables. - - Args: - credential: A Firebase credential instance. - options: A Firebase AppOptions instance. - - Returns: - str: A project ID string or None. + App._validate_project_id(self._options.get('projectId')) + self._project_id_initialized = False - Raises: - ValueError: If a non-string project ID value is specified. - """ - project_id = options.get('projectId') - if not project_id: - try: - project_id = credential.project_id - except AttributeError: - pass - if not project_id: - project_id = os.environ.get('GOOGLE_CLOUD_PROJECT', os.environ.get('GCLOUD_PROJECT')) + @classmethod + def _validate_project_id(cls, project_id): if project_id is not None and not isinstance(project_id, six.string_types): raise ValueError( 'Invalid project ID: "{0}". project ID must be a string.'.format(project_id)) - return project_id @property def name(self): @@ -263,8 +239,34 @@ def options(self): @property def project_id(self): + if not self._project_id_initialized: + self._project_id = self._lookup_project_id() + self._project_id_initialized = True return self._project_id + def _lookup_project_id(self): + """Looks up the Firebase project ID associated with an App. + + If a ``projectId`` is specified in app options, it is returned. Then tries to + get the project ID from the credential used to initialize the app. If that also fails, + attempts to look up the ``GOOGLE_CLOUD_PROJECT`` and ``GCLOUD_PROJECT`` environment + variables. + + Returns: + str: A project ID string or None. + """ + project_id = self._options.get('projectId') + if not project_id: + try: + project_id = self._credential.project_id + except AttributeError: + pass + if not project_id: + project_id = os.environ.get('GOOGLE_CLOUD_PROJECT', + os.environ.get('GCLOUD_PROJECT')) + App._validate_project_id(self._options.get('projectId')) + return project_id + def _get_service(self, name, initializer): """Returns the service instance identified by the given name. diff --git a/firebase_admin/credentials.py b/firebase_admin/credentials.py index b5864beb8..2e400d9e4 100644 --- a/firebase_admin/credentials.py +++ b/firebase_admin/credentials.py @@ -123,26 +123,40 @@ class ApplicationDefault(Base): """A Google Application Default credential.""" def __init__(self): - """Initializes the Application Default credentials for the current environment. + """Creates an instance that will use Application Default credentials. - Raises: - google.auth.exceptions.DefaultCredentialsError: If Application Default - credentials cannot be initialized in the current environment. + The credentials will be lazily initialized when get_credential() or + project_id() is called. See those methods for possible errors raised. """ super(ApplicationDefault, self).__init__() - self._g_credential, self._project_id = google.auth.default(scopes=_scopes) + self._g_credential = None # Will be lazily-loaded via _load_credential(). def get_credential(self): """Returns the underlying Google credential. + Raises: + google.auth.exceptions.DefaultCredentialsError: If Application Default + credentials cannot be initialized in the current environment. Returns: google.auth.credentials.Credentials: A Google Auth credential instance.""" + self._load_credential() return self._g_credential @property def project_id(self): + """Returns the project_id from the underlying Google credential. + + Raises: + google.auth.exceptions.DefaultCredentialsError: If Application Default + credentials cannot be initialized in the current environment. + Returns: + str: The project id.""" + self._load_credential() return self._project_id + def _load_credential(self): + if not self._g_credential: + self._g_credential, self._project_id = google.auth.default(scopes=_scopes) class RefreshToken(Base): """A credential initialized from an existing refresh token.""" diff --git a/firebase_admin/db.py b/firebase_admin/db.py index 778699d47..53efd9b15 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -22,9 +22,11 @@ import collections import json +import os import sys import threading +import google.auth import requests import six from six.moves import urllib @@ -41,6 +43,7 @@ _USER_AGENT = 'Firebase/HTTP/{0}/{1}.{2}/AdminPython'.format( firebase_admin.__version__, sys.version_info.major, sys.version_info.minor) _TRANSACTION_MAX_RETRIES = 25 +_EMULATOR_HOST_ENV_VAR = 'FIREBASE_DATABASE_EMULATOR_HOST' def reference(path='/', app=None, url=None): @@ -768,46 +771,108 @@ class _DatabaseService(object): _DEFAULT_AUTH_OVERRIDE = '_admin_' def __init__(self, app): - self._credential = app.credential.get_credential() + self._credential = app.credential db_url = app.options.get('databaseURL') if db_url: - self._db_url = _DatabaseService._validate_url(db_url) + _DatabaseService._parse_db_url(db_url) # Just for validation. + self._db_url = db_url else: self._db_url = None auth_override = _DatabaseService._get_auth_override(app) if auth_override != self._DEFAULT_AUTH_OVERRIDE and auth_override != {}: - encoded = json.dumps(auth_override, separators=(',', ':')) - self._auth_override = 'auth_variable_override={0}'.format(encoded) + self._auth_override = json.dumps(auth_override, separators=(',', ':')) else: self._auth_override = None self._timeout = app.options.get('httpTimeout') self._clients = {} - def get_client(self, base_url=None): - if base_url is None: - base_url = self._db_url - base_url = _DatabaseService._validate_url(base_url) - if base_url not in self._clients: - client = _Client(self._credential, base_url, self._auth_override, self._timeout) - self._clients[base_url] = client - return self._clients[base_url] + emulator_host = os.environ.get(_EMULATOR_HOST_ENV_VAR) + if emulator_host: + if '//' in emulator_host: + raise ValueError( + 'Invalid {0}: "{1}". It must follow format "host:port".'.format( + _EMULATOR_HOST_ENV_VAR, emulator_host)) + self._emulator_host = emulator_host + else: + self._emulator_host = None + + def get_client(self, db_url=None): + """Creates a client based on the db_url. Clients may be cached.""" + if db_url is None: + db_url = self._db_url + + base_url, namespace = _DatabaseService._parse_db_url(db_url, self._emulator_host) + if base_url == 'https://{0}.firebaseio.com'.format(namespace): + # Production base_url. No need to specify namespace in query params. + params = {} + credential = self._credential.get_credential() + else: + # Emulator base_url. Use fake credentials and specify ?ns=foo in query params. + credential = _EmulatorAdminCredentials() + params = {'ns': namespace} + if self._auth_override: + params['auth_variable_override'] = self._auth_override + + client_cache_key = (base_url, json.dumps(params, sort_keys=True)) + if client_cache_key not in self._clients: + client = _Client(credential, base_url, self._timeout, params) + self._clients[client_cache_key] = client + return self._clients[client_cache_key] @classmethod - def _validate_url(cls, url): - """Parses and validates a given database URL.""" + def _parse_db_url(cls, url, emulator_host=None): + """Parses (base_url, namespace) from a database URL. + + The input can be either a production URL (https://foo-bar.firebaseio.com/) + or an Emulator URL (http://localhost:8080/?ns=foo-bar). In case of Emulator + URL, the namespace is extracted from the query param ns. The resulting + base_url never includes query params. + + If url is a production URL and emulator_host is specified, the result + base URL will use emulator_host instead. emulator_host is ignored + if url is already an emulator URL. + """ if not url or not isinstance(url, six.string_types): raise ValueError( 'Invalid database URL: "{0}". Database URL must be a non-empty ' 'URL string.'.format(url)) - parsed = urllib.parse.urlparse(url) - if parsed.scheme != 'https': + parsed_url = urllib.parse.urlparse(url) + if parsed_url.netloc.endswith('.firebaseio.com'): + return cls._parse_production_url(parsed_url, emulator_host) + else: + return cls._parse_emulator_url(parsed_url) + + @classmethod + def _parse_production_url(cls, parsed_url, emulator_host): + """Parses production URL like https://foo-bar.firebaseio.com/""" + if parsed_url.scheme != 'https': raise ValueError( - 'Invalid database URL: "{0}". Database URL must be an HTTPS URL.'.format(url)) - elif not parsed.netloc.endswith('.firebaseio.com'): + 'Invalid database URL scheme: "{0}". Database URL must be an HTTPS URL.'.format( + parsed_url.scheme)) + namespace = parsed_url.netloc.split('.')[0] + if not namespace: raise ValueError( 'Invalid database URL: "{0}". Database URL must be a valid URL to a ' - 'Firebase Realtime Database instance.'.format(url)) - return 'https://{0}'.format(parsed.netloc) + 'Firebase Realtime Database instance.'.format(parsed_url.geturl())) + + if emulator_host: + base_url = 'http://{0}'.format(emulator_host) + else: + base_url = 'https://{0}'.format(parsed_url.netloc) + return base_url, namespace + + @classmethod + def _parse_emulator_url(cls, parsed_url): + """Parses emulator URL like http://localhost:8080/?ns=foo-bar""" + query_ns = urllib.parse.parse_qs(parsed_url.query).get('ns') + if parsed_url.scheme != 'http' or (not query_ns or len(query_ns) != 1 or not query_ns[0]): + raise ValueError( + 'Invalid database URL: "{0}". Database URL must be a valid URL to a ' + 'Firebase Realtime Database instance.'.format(parsed_url.geturl())) + + namespace = query_ns[0] + base_url = '{0}://{1}'.format(parsed_url.scheme, parsed_url.netloc) + return base_url, namespace @classmethod def _get_auth_override(cls, app): @@ -833,7 +898,7 @@ class _Client(_http_client.JsonHttpClient): marshalling and unmarshalling of JSON data. """ - def __init__(self, credential, base_url, auth_override, timeout): + def __init__(self, credential, base_url, timeout, params=None): """Creates a new _Client from the given parameters. This exists primarily to enable testing. For regular use, obtain _Client instances by @@ -843,22 +908,21 @@ def __init__(self, credential, base_url, auth_override, timeout): credential: A Google credential that can be used to authenticate requests. base_url: A URL prefix to be added to all outgoing requests. This is typically the Firebase Realtime Database URL. - auth_override: The encoded auth_variable_override query parameter to be included in - outgoing requests. timeout: HTTP request timeout in seconds. If not set connections will never timeout, which is the default behavior of the underlying requests library. + params: Dict of query parameters to add to all outgoing requests. """ _http_client.JsonHttpClient.__init__( self, credential=credential, base_url=base_url, headers={'User-Agent': _USER_AGENT}) self.credential = credential - self.auth_override = auth_override self.timeout = timeout + self.params = params if params else {} def request(self, method, url, **kwargs): """Makes an HTTP call using the Python requests library. - Extends the request() method of the parent JsonHttpClient class. Handles auth overrides, - and low-level exceptions. + Extends the request() method of the parent JsonHttpClient class. Handles default + params like auth overrides, and low-level exceptions. Args: method: HTTP method name as a string (e.g. get, post). @@ -872,13 +936,15 @@ def request(self, method, url, **kwargs): Raises: ApiCallError: If an error occurs while making the HTTP call. """ - if self.auth_override: - params = kwargs.get('params') - if params: - params += '&{0}'.format(self.auth_override) + query = '&'.join('{0}={1}'.format(key, self.params[key]) for key in self.params) + extra_params = kwargs.get('params') + if extra_params: + if query: + query = extra_params + '&' + query else: - params = self.auth_override - kwargs['params'] = params + query = extra_params + kwargs['params'] = query + if self.timeout: kwargs['timeout'] = self.timeout try: @@ -911,3 +977,12 @@ def extract_error_message(cls, error): except ValueError: pass return '{0}\nReason: {1}'.format(error, error.response.content.decode()) + + +class _EmulatorAdminCredentials(google.auth.credentials.Credentials): + def __init__(self): + google.auth.credentials.Credentials.__init__(self) + self.token = 'owner' + + def refresh(self, request): + pass diff --git a/integration/conftest.py b/integration/conftest.py index 912247a01..169e02d5b 100644 --- a/integration/conftest.py +++ b/integration/conftest.py @@ -70,4 +70,3 @@ def api_key(request): 'command-line option.') with open(path) as keyfile: return keyfile.read().strip() - \ No newline at end of file diff --git a/integration/test_db.py b/integration/test_db.py index cd666f576..d88d145ba 100644 --- a/integration/test_db.py +++ b/integration/test_db.py @@ -15,6 +15,7 @@ """Integration tests for firebase_admin.db module.""" import collections import json +import os import pytest import six @@ -25,11 +26,35 @@ from tests import testutils +def integration_conf(request): + host_override = os.environ.get('FIREBASE_DATABASE_EMULATOR_HOST') + if host_override: + return None, 'fake-project-id' + else: + return conftest.integration_conf(request) + + +@pytest.fixture(scope='module') +def app(request): + cred, project_id = integration_conf(request) + ops = { + 'databaseURL' : 'https://{0}.firebaseio.com'.format(project_id), + } + return firebase_admin.initialize_app(cred, ops, name='integration-db') + + +@pytest.fixture(scope='module', autouse=True) +def default_app(): + # Overwrites the default_app fixture in conftest.py. + # This test suite should not use the default app. Use the app fixture instead. + pass + + @pytest.fixture(scope='module') -def update_rules(): +def update_rules(app): with open(testutils.resource_filename('dinosaurs_index.json')) as rules_file: new_rules = json.load(rules_file) - client = db.reference()._client + client = db.reference('', app)._client rules = client.body('get', '/.settings/rules.json') existing = rules.get('rules') if existing != new_rules: @@ -42,7 +67,7 @@ def testdata(): return json.load(dino_file) @pytest.fixture(scope='module') -def testref(update_rules, testdata): +def testref(update_rules, testdata, app): """Adds the necessary DB indices, and sets the initial values. This fixture is attached to the module scope, and therefore is guaranteed to run only once @@ -52,7 +77,7 @@ def testref(update_rules, testdata): Reference: A reference to the test dinosaur database. """ del update_rules - ref = db.reference('_adminsdk/python/dinodb') + ref = db.reference('_adminsdk/python/dinodb', app) ref.set(testdata) return ref @@ -304,7 +329,7 @@ def test_filter_by_value(self, testref): @pytest.fixture(scope='module') def override_app(request, update_rules): del update_rules - cred, project_id = conftest.integration_conf(request) + cred, project_id = integration_conf(request) ops = { 'databaseURL' : 'https://{0}.firebaseio.com'.format(project_id), 'databaseAuthVariableOverride' : {'uid' : 'user1'} @@ -316,7 +341,7 @@ def override_app(request, update_rules): @pytest.fixture(scope='module') def none_override_app(request, update_rules): del update_rules - cred, project_id = conftest.integration_conf(request) + cred, project_id = integration_conf(request) ops = { 'databaseURL' : 'https://{0}.firebaseio.com'.format(project_id), 'databaseAuthVariableOverride' : None @@ -329,8 +354,8 @@ def none_override_app(request, update_rules): class TestAuthVariableOverride(object): """Test cases for database auth variable overrides.""" - def init_ref(self, path): - admin_ref = db.reference(path) + def init_ref(self, path, app): + admin_ref = db.reference(path, app) admin_ref.set('test') assert admin_ref.get() == 'test' @@ -338,9 +363,9 @@ def check_permission_error(self, excinfo): assert isinstance(excinfo.value, db.ApiCallError) assert 'Reason: Permission denied' in str(excinfo.value) - def test_no_access(self, override_app): + def test_no_access(self, app, override_app): path = '_adminsdk/python/admin' - self.init_ref(path) + self.init_ref(path, app) user_ref = db.reference(path, override_app) with pytest.raises(db.ApiCallError) as excinfo: assert user_ref.get() @@ -350,18 +375,18 @@ def test_no_access(self, override_app): user_ref.set('test2') self.check_permission_error(excinfo) - def test_read(self, override_app): + def test_read(self, app, override_app): path = '_adminsdk/python/protected/user2' - self.init_ref(path) + self.init_ref(path, app) user_ref = db.reference(path, override_app) assert user_ref.get() == 'test' with pytest.raises(db.ApiCallError) as excinfo: user_ref.set('test2') self.check_permission_error(excinfo) - def test_read_write(self, override_app): + def test_read_write(self, app, override_app): path = '_adminsdk/python/protected/user1' - self.init_ref(path) + self.init_ref(path, app) user_ref = db.reference(path, override_app) assert user_ref.get() == 'test' user_ref.set('test2') @@ -373,9 +398,9 @@ def test_query(self, override_app): user_ref.order_by_key().limit_to_first(2).get() self.check_permission_error(excinfo) - def test_none_auth_override(self, none_override_app): + def test_none_auth_override(self, app, none_override_app): path = '_adminsdk/python/public' - self.init_ref(path) + self.init_ref(path, app) public_ref = db.reference(path, none_override_app) assert public_ref.get() == 'test' diff --git a/tests/test_credentials.py b/tests/test_credentials.py index 685d89d0f..6f081d796 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -115,8 +115,11 @@ def test_init(self, app_default): indirect=True) def test_nonexisting_path(self, app_default): del app_default + # This does not yet throw because the credentials are lazily loaded. + creds = credentials.ApplicationDefault() + with pytest.raises(exceptions.DefaultCredentialsError): - credentials.ApplicationDefault() + creds.get_credential() # This now throws. class TestRefreshToken(object): diff --git a/tests/test_db.py b/tests/test_db.py index 6168b72d4..211eabb4b 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -623,6 +623,45 @@ def test_no_db_url(self): with pytest.raises(ValueError): db.reference() + @pytest.mark.parametrize( + 'url,emulator_host,expected_base_url,expected_namespace', + [ + # Production URLs with no override: + ('https://test.firebaseio.com', None, 'https://test.firebaseio.com', 'test'), + ('https://test.firebaseio.com/', None, 'https://test.firebaseio.com', 'test'), + + # Production URLs with emulator_host override: + ('https://test.firebaseio.com', 'localhost:9000', 'http://localhost:9000', 'test'), + ('https://test.firebaseio.com/', 'localhost:9000', 'http://localhost:9000', 'test'), + + # Emulator URLs with no override. + ('http://localhost:8000/?ns=test', None, 'http://localhost:8000', 'test'), + # emulator_host is ignored when the original URL is already emulator. + ('http://localhost:8000/?ns=test', 'localhost:9999', 'http://localhost:8000', 'test'), + ] + ) + def test_parse_db_url(self, url, emulator_host, expected_base_url, expected_namespace): + base_url, namespace = db._DatabaseService._parse_db_url(url, emulator_host) + assert base_url == expected_base_url + assert namespace == expected_namespace + + @pytest.mark.parametrize('url,emulator_host', [ + ('', None), + (None, None), + (42, None), + ('test.firebaseio.com', None), # Not a URL. + ('http://test.firebaseio.com', None), # Use of non-HTTPs in production URLs. + ('ftp://test.firebaseio.com', None), # Use of non-HTTPs in production URLs. + ('https://example.com', None), # Invalid RTDB URL. + ('http://localhost:9000/', None), # No ns specified. + ('http://localhost:9000/?ns=', None), # No ns specified. + ('http://localhost:9000/?ns=test1&ns=test2', None), # Two ns parameters specified. + ('ftp://localhost:9000/?ns=test', None), # Neither HTTP or HTTPS. + ]) + def test_parse_db_url_errors(self, url, emulator_host): + with pytest.raises(ValueError): + db._DatabaseService._parse_db_url(url, emulator_host) + @pytest.mark.parametrize('url', [ 'https://test.firebaseio.com', 'https://test.firebaseio.com/' ]) @@ -633,7 +672,7 @@ def test_valid_db_url(self, url): adapter = MockAdapter('{}', 200, recorder) ref._client.session.mount(url, adapter) assert ref._client.base_url == 'https://test.firebaseio.com' - assert ref._client.auth_override is None + assert 'auth_variable_override' not in ref._client.params assert ref._client.timeout is None assert ref.get() == {} assert len(recorder) == 1 @@ -658,7 +697,7 @@ def test_multi_db_support(self): }) ref = db.reference() assert ref._client.base_url == default_url - assert ref._client.auth_override is None + assert 'auth_variable_override' not in ref._client.params assert ref._client.timeout is None assert ref._client is db.reference()._client assert ref._client is db.reference(url=default_url)._client @@ -666,7 +705,7 @@ def test_multi_db_support(self): other_url = 'https://other.firebaseio.com' other_ref = db.reference(url=other_url) assert other_ref._client.base_url == other_url - assert other_ref._client.auth_override is None + assert 'auth_variable_override' not in ref._client.params assert other_ref._client.timeout is None assert other_ref._client is db.reference(url=other_url)._client assert other_ref._client is db.reference(url=other_url + '/')._client @@ -682,10 +721,10 @@ def test_valid_auth_override(self, override): for ref in [default_ref, other_ref]: assert ref._client.timeout is None if override == {}: - assert ref._client.auth_override is None + assert 'auth_variable_override' not in ref._client.params else: encoded = json.dumps(override, separators=(',', ':')) - assert ref._client.auth_override == 'auth_variable_override={0}'.format(encoded) + assert ref._client.params['auth_variable_override'] == encoded @pytest.mark.parametrize('override', [ '', 'foo', 0, 1, True, False, list(), tuple(), _Object()]) diff --git a/tox.ini b/tox.ini index 64239322c..dec7b618f 100644 --- a/tox.ini +++ b/tox.ini @@ -7,7 +7,9 @@ envlist = py2,py3,pypy,cover [testenv] -commands = pytest +passenv = + FIREBASE_DATABASE_EMULATOR_HOST +commands = pytest {posargs} deps = pytest pytest-localserver From bbfa5e861cbb07b8e5a89195bb0dde5e72cd7e56 Mon Sep 17 00:00:00 2001 From: Yuchen Shi Date: Thu, 1 Aug 2019 15:33:51 -0700 Subject: [PATCH 017/226] Add RTDB-emulator-based integration testing. (#316) * Add RTDB-emulator-based integration testing. * Combine unit tests and integration tests into a single job. --- .travis.yml | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/.travis.yml b/.travis.yml index a9114f0f4..4db3c3708 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,10 +5,20 @@ python: - "3.5" - "3.6" - "pypy3.5" -# command to install dependencies -install: "pip install -r requirements.txt" -before_script: - - export PY_VERSION=`python -c 'import sys; print sys.version_info.major'` - - if [[ "$PY_VERSION" == '2' ]]; then ./lint.sh all; fi -# command to run tests -script: pytest + +jobs: + include: + - name: "Lint" + python: "2.7" + script: ./lint.sh all + +before_install: + - nvm install 8 && npm install -g firebase-tools +script: + - pytest + - firebase emulators:exec --only database --project fake-project-id 'pytest integration/test_db.py' +cache: + pip: true + npm: true + directories: + - $HOME/.cache/firebase/emulators From 5408fbc1d7a9f8310a9fd85f22580ae38d795ea2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktoras=20Laukevi=C4=8Dius?= Date: Wed, 14 Aug 2019 21:33:49 +0300 Subject: [PATCH 018/226] Add support for arbitrary key-value pairs in messaging.ApsAlert (#322) --- firebase_admin/_messaging_utils.py | 14 ++++++- tests/test_messaging.py | 66 ++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 1 deletion(-) diff --git a/firebase_admin/_messaging_utils.py b/firebase_admin/_messaging_utils.py index 09f7daf87..d6e263dcf 100644 --- a/firebase_admin/_messaging_utils.py +++ b/firebase_admin/_messaging_utils.py @@ -394,10 +394,13 @@ class ApsAlert(object): action_loc_key: Key of the text in the app's string resources to use to localize the action button text (optional). launch_image: Image for the notification action (optional). + custom_data: A dict of custom key-value pairs to be included in the ApsAlert dictionary + (optional) """ def __init__(self, title=None, subtitle=None, body=None, loc_key=None, loc_args=None, - title_loc_key=None, title_loc_args=None, action_loc_key=None, launch_image=None): + title_loc_key=None, title_loc_args=None, action_loc_key=None, launch_image=None, + custom_data=None): self.title = title self.subtitle = subtitle self.body = body @@ -407,6 +410,7 @@ def __init__(self, title=None, subtitle=None, body=None, loc_key=None, loc_args= self.title_loc_args = title_loc_args self.action_loc_key = action_loc_key self.launch_image = launch_image + self.custom_data = custom_data class APNSFcmOptions(object): @@ -835,6 +839,14 @@ def encode_aps_alert(cls, alert): if result.get('title-loc-args') and not result.get('title-loc-key'): raise ValueError( 'ApsAlert.title_loc_key is required when specifying title_loc_args.') + if alert.custom_data is not None: + if not isinstance(alert.custom_data, dict): + raise ValueError('ApsAlert.custom_data must be a dict.') + for key, val in alert.custom_data.items(): + _Validators.check_string('ApsAlert.custom_data key', key) + # allow specifying key override because Apple could update API so that key + # could have unexpected value type + result[key] = val return cls.remove_null_values(result) @classmethod diff --git a/tests/test_messaging.py b/tests/test_messaging.py index 878e1365b..67ee0721d 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -1209,6 +1209,72 @@ def test_aps_alert(self): } check_encoding(msg, expected) + def test_aps_alert_custom_data_merge(self): + msg = messaging.Message( + topic='topic', + apns=messaging.APNSConfig( + payload=messaging.APNSPayload( + aps=messaging.Aps( + alert=messaging.ApsAlert( + title='t', + subtitle='st', + custom_data={'k1': 'v1', 'k2': 'v2'} + ) + ), + ) + ) + ) + expected = { + 'topic': 'topic', + 'apns': { + 'payload': { + 'aps': { + 'alert': { + 'title': 't', + 'subtitle': 'st', + 'k1': 'v1', + 'k2': 'v2' + }, + }, + } + }, + } + check_encoding(msg, expected) + + def test_aps_alert_custom_data_override(self): + msg = messaging.Message( + topic='topic', + apns=messaging.APNSConfig( + payload=messaging.APNSPayload( + aps=messaging.Aps( + alert=messaging.ApsAlert( + title='t', + subtitle='st', + launch_image='li', + custom_data={'launch-image': ['li1', 'li2']} + ) + ), + ) + ) + ) + expected = { + 'topic': 'topic', + 'apns': { + 'payload': { + 'aps': { + 'alert': { + 'title': 't', + 'subtitle': 'st', + 'launch-image': [ + 'li1', + 'li2' + ] + }, + }, + } + }, + } + check_encoding(msg, expected) class TestTimeout(object): From a79ee979b4758ee4915592499af590c1c50e52f2 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Fri, 16 Aug 2019 14:58:46 -0700 Subject: [PATCH 019/226] Updated the metadata server URL (#324) --- firebase_admin/_token_gen.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/firebase_admin/_token_gen.py b/firebase_admin/_token_gen.py index e2eaa5715..7af7b73b7 100644 --- a/firebase_admin/_token_gen.py +++ b/firebase_admin/_token_gen.py @@ -48,8 +48,8 @@ 'acr', 'amr', 'at_hash', 'aud', 'auth_time', 'azp', 'cnf', 'c_hash', 'exp', 'firebase', 'iat', 'iss', 'jti', 'nbf', 'nonce', 'sub' ]) -METADATA_SERVICE_URL = ('http://metadata/computeMetadata/v1/instance/service-accounts/' - 'default/email') +METADATA_SERVICE_URL = ('http://metadata.google.internal/computeMetadata/v1/instance/' + 'service-accounts/default/email') # Error codes COOKIE_CREATE_ERROR = 'COOKIE_CREATE_ERROR' From 69c03e39a7b1cceeb7cee847828143b194c1a377 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Fri, 16 Aug 2019 15:25:00 -0700 Subject: [PATCH 020/226] FCM multicast snippets (#321) --- snippets/messaging/cloud_messaging.py | 69 +++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/snippets/messaging/cloud_messaging.py b/snippets/messaging/cloud_messaging.py index a22e7ebcf..6dc1aad10 100644 --- a/snippets/messaging/cloud_messaging.py +++ b/snippets/messaging/cloud_messaging.py @@ -220,3 +220,72 @@ def unsubscribe_from_topic(): # for the contents of response. print(response.success_count, 'tokens were unsubscribed successfully') # [END unsubscribe] + + +def send_all(): + registration_token = 'YOUR_REGISTRATION_TOKEN' + # [START send_all] + # Create a list containing up to 100 messages. + messages = [ + messaging.Message( + notification=messaging.Notification('Price drop', '5% off all electronics'), + token=registration_token, + ), + # ... + messaging.Message( + notification=messaging.Notification('Price drop', '2% off all books'), + topic='readers-club', + ), + ] + + response = messaging.send_all(messages) + # See the BatchResponse reference documentation + # for the contents of response. + print('{0} messages were sent successfully'.format(response.success_count)) + # [END send_all] + + +def send_multicast(): + # [START send_multicast] + # Create a list containing up to 100 registration tokens. + # These registration tokens come from the client FCM SDKs. + registration_tokens = [ + 'YOUR_REGISTRATION_TOKEN_1', + # ... + 'YOUR_REGISTRATION_TOKEN_N', + ] + + message = messaging.MulticastMessage( + data={'score': '850', 'time': '2:45'}, + tokens=registration_tokens, + ) + response = messaging.send_multicast(message) + # See the BatchResponse reference documentation + # for the contents of response. + print('{0} messages were sent successfully'.format(response.success_count)) + # [END send_multicast] + + +def send_multicast_and_handle_errors(): + # [START send_multicast_error] + # These registration tokens come from the client FCM SDKs. + registration_tokens = [ + 'YOUR_REGISTRATION_TOKEN_1', + # ... + 'YOUR_REGISTRATION_TOKEN_N', + ] + + message = messaging.MulticastMessage( + data={'score': '850', 'time': '2:45'}, + tokens=registration_tokens, + ) + response = messaging.send_multicast(message) + if response.failure_count > 0: + responses = response.responses + failed_tokens = [] + for idx, resp in enumerate(responses): + if not resp.success: + # The order of responses corresponds to the order of the registration tokens. + failed_tokens.append(registration_tokens[idx]) + print('List of tokens that caused failures: {0}'.format(failed_tokens)) + # [END send_multicast_error] From de4ee2df40c1bde7f76a6629c3a9b58dc4d24262 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Tue, 20 Aug 2019 10:28:11 -0700 Subject: [PATCH 021/226] Renamed FcmOptions types to FCMOptions (#328) --- firebase_admin/_messaging_utils.py | 46 +++++++++++++------------- firebase_admin/messaging.py | 17 +++++----- tests/test_messaging.py | 53 ++++++++++++++++++++---------- 3 files changed, 68 insertions(+), 48 deletions(-) diff --git a/firebase_admin/_messaging_utils.py b/firebase_admin/_messaging_utils.py index d6e263dcf..72e2acab3 100644 --- a/firebase_admin/_messaging_utils.py +++ b/firebase_admin/_messaging_utils.py @@ -36,7 +36,7 @@ class Message(object): android: An instance of ``messaging.AndroidConfig`` (optional). webpush: An instance of ``messaging.WebpushConfig`` (optional). apns: An instance of ``messaging.ApnsConfig`` (optional). - fcm_options: An instance of ``messaging.FcmOptions`` (optional). + fcm_options: An instance of ``messaging.FCMOptions`` (optional). token: The registration token of the device to which the message should be sent (optional). topic: Name of the FCM topic to which the message should be sent (optional). Topic name may contain the ``/topics/`` prefix. @@ -67,7 +67,7 @@ class MulticastMessage(object): android: An instance of ``messaging.AndroidConfig`` (optional). webpush: An instance of ``messaging.WebpushConfig`` (optional). apns: An instance of ``messaging.ApnsConfig`` (optional). - fcm_options: An instance of ``messaging.FcmOptions`` (optional). + fcm_options: An instance of ``messaging.FCMOptions`` (optional). """ def __init__(self, tokens, data=None, notification=None, android=None, webpush=None, apns=None, fcm_options=None): @@ -112,7 +112,7 @@ class AndroidConfig(object): data: A dictionary of data fields (optional). All keys and values in the dictionary must be strings. When specified, overrides any data fields set via ``Message.data``. notification: A ``messaging.AndroidNotification`` to be included in the message (optional). - fcm_options: A ``messaging.AndroidFcmOptions`` to be included in the message (optional). + fcm_options: A ``messaging.AndroidFCMOptions`` to be included in the message (optional). """ def __init__(self, collapse_key=None, priority=None, ttl=None, restricted_package_name=None, @@ -172,7 +172,7 @@ def __init__(self, title=None, body=None, icon=None, color=None, sound=None, tag self.channel_id = channel_id -class AndroidFcmOptions(object): +class AndroidFCMOptions(object): """Options for features provided by the FCM SDK for Android. Args: @@ -193,7 +193,7 @@ class WebpushConfig(object): data: A dictionary of data fields (optional). All keys and values in the dictionary must be strings. When specified, overrides any data fields set via ``Message.data``. notification: A ``messaging.WebpushNotification`` to be included in the message (optional). - fcm_options: A ``messaging.WebpushFcmOptions`` instance to be included in the message + fcm_options: A ``messaging.WebpushFCMOptions`` instance to be included in the message (optional). .. _Webpush Specification: https://tools.ietf.org/html/rfc8030#section-5 @@ -278,7 +278,7 @@ def __init__(self, title=None, body=None, icon=None, actions=None, badge=None, d self.custom_data = custom_data -class WebpushFcmOptions(object): +class WebpushFCMOptions(object): """Options for features provided by the FCM SDK for Web. Args: @@ -298,7 +298,7 @@ class APNSConfig(object): Args: headers: A dictionary of headers (optional). payload: A ``messaging.APNSPayload`` to be included in the message (optional). - fcm_options: A ``messaging.APNSFcmOptions`` instance to be included in the message + fcm_options: A ``messaging.APNSFCMOptions`` instance to be included in the message (optional). .. _APNS Documentation: https://developer.apple.com/library/content/documentation\ @@ -413,7 +413,7 @@ def __init__(self, title=None, subtitle=None, body=None, loc_key=None, loc_args= self.custom_data = custom_data -class APNSFcmOptions(object): +class APNSFCMOptions(object): """Options for features provided by the FCM SDK for iOS. Args: @@ -425,7 +425,7 @@ def __init__(self, analytics_label=None): self.analytics_label = analytics_label -class FcmOptions(object): +class FCMOptions(object): """Options for features provided by SDK. Args: @@ -535,15 +535,15 @@ def encode_android(cls, android): @classmethod def encode_android_fcm_options(cls, fcm_options): - """Encodes a AndroidFcmOptions instance into a json.""" + """Encodes an AndroidFCMOptions instance into a json.""" if fcm_options is None: return None - if not isinstance(fcm_options, AndroidFcmOptions): + if not isinstance(fcm_options, AndroidFCMOptions): raise ValueError('AndroidConfig.fcm_options must be an instance of ' - 'AndroidFcmOptions class.') + 'AndroidFCMOptions class.') result = { 'analytics_label': _Validators.check_analytics_label( - 'AndroidFcmOptions.analytics_label', fcm_options.analytics_label), + 'AndroidFCMOptions.analytics_label', fcm_options.analytics_label), } result = cls.remove_null_values(result) return result @@ -703,7 +703,7 @@ def encode_webpush_notification_actions(cls, actions): @classmethod def encode_webpush_fcm_options(cls, options): - """Encodes a WebpushFcmOptions instance into JSON.""" + """Encodes a WebpushFCMOptions instance into JSON.""" if options is None: return None result = { @@ -712,7 +712,7 @@ def encode_webpush_fcm_options(cls, options): result = cls.remove_null_values(result) link = result.get('link') if link is not None and not link.startswith('https://'): - raise ValueError('WebpushFcmOptions.link must be a HTTPS URL.') + raise ValueError('WebpushFCMOptions.link must be a HTTPS URL.') return result @classmethod @@ -746,14 +746,14 @@ def encode_apns_payload(cls, payload): @classmethod def encode_apns_fcm_options(cls, fcm_options): - """Encodes an APNSFcmOptions instance into JSON.""" + """Encodes an APNSFCMOptions instance into JSON.""" if fcm_options is None: return None - if not isinstance(fcm_options, APNSFcmOptions): - raise ValueError('APNSConfig.fcm_options must be an instance of APNSFcmOptions class.') + if not isinstance(fcm_options, APNSFCMOptions): + raise ValueError('APNSConfig.fcm_options must be an instance of APNSFCMOptions class.') result = { 'analytics_label': _Validators.check_analytics_label( - 'APNSFcmOptions.analytics_label', fcm_options.analytics_label), + 'APNSFCMOptions.analytics_label', fcm_options.analytics_label), } result = cls.remove_null_values(result) return result @@ -897,14 +897,14 @@ def default(self, obj): # pylint: disable=method-hidden @classmethod def encode_fcm_options(cls, fcm_options): - """Encodes an FcmOptions instance into JSON.""" + """Encodes an FCMOptions instance into JSON.""" if fcm_options is None: return None - if not isinstance(fcm_options, FcmOptions): - raise ValueError('Message.fcm_options must be an instance of FcmOptions class.') + if not isinstance(fcm_options, FCMOptions): + raise ValueError('Message.fcm_options must be an instance of FCMOptions class.') result = { 'analytics_label': _Validators.check_analytics_label( - 'FcmOptions.analytics_label', fcm_options.analytics_label), + 'FCMOptions.analytics_label', fcm_options.analytics_label), } result = cls.remove_null_values(result) return result diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index ddaef19f0..c0f023169 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -33,10 +33,10 @@ __all__ = [ 'AndroidConfig', - 'AndroidFcmOptions', + 'AndroidFCMOptions', 'AndroidNotification', 'APNSConfig', - 'APNSFcmOptions', + 'APNSFCMOptions', 'APNSPayload', 'ApiCallError', 'Aps', @@ -44,14 +44,14 @@ 'BatchResponse', 'CriticalSound', 'ErrorInfo', - 'FcmOptions', + 'FCMOptions', 'Message', 'MulticastMessage', 'Notification', 'SendResponse', 'TopicManagementResponse', 'WebpushConfig', - 'WebpushFcmOptions', + 'WebpushFCMOptions', 'WebpushNotification', 'WebpushNotificationAction', @@ -64,20 +64,21 @@ AndroidConfig = _messaging_utils.AndroidConfig -AndroidFcmOptions = _messaging_utils.AndroidFcmOptions +AndroidFCMOptions = _messaging_utils.AndroidFCMOptions AndroidNotification = _messaging_utils.AndroidNotification APNSConfig = _messaging_utils.APNSConfig -APNSFcmOptions = _messaging_utils.APNSFcmOptions +APNSFCMOptions = _messaging_utils.APNSFCMOptions APNSPayload = _messaging_utils.APNSPayload Aps = _messaging_utils.Aps ApsAlert = _messaging_utils.ApsAlert CriticalSound = _messaging_utils.CriticalSound -FcmOptions = _messaging_utils.FcmOptions +FCMOptions = _messaging_utils.FCMOptions Message = _messaging_utils.Message MulticastMessage = _messaging_utils.MulticastMessage Notification = _messaging_utils.Notification WebpushConfig = _messaging_utils.WebpushConfig -WebpushFcmOptions = _messaging_utils.WebpushFcmOptions +WebpushFCMOptions = _messaging_utils.WebpushFCMOptions +WebpushFcmOptions = _messaging_utils.WebpushFCMOptions WebpushNotification = _messaging_utils.WebpushNotification WebpushNotificationAction = _messaging_utils.WebpushNotificationAction diff --git a/tests/test_messaging.py b/tests/test_messaging.py index 67ee0721d..0b8739195 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -123,10 +123,10 @@ def test_prefixed_topic(self): def test_fcm_options(self): check_encoding( messaging.Message( - topic='topic', fcm_options=messaging.FcmOptions('analytics_label_v1')), + topic='topic', fcm_options=messaging.FCMOptions('analytics_label_v1')), {'topic': 'topic', 'fcm_options': {'analytics_label': 'analytics_label_v1'}}) check_encoding( - messaging.Message(topic='topic', fcm_options=messaging.FcmOptions()), + messaging.Message(topic='topic', fcm_options=messaging.FCMOptions()), {'topic': 'topic'}) @@ -177,27 +177,27 @@ def test_invalid_fcm_options(self, label): with pytest.raises(ValueError) as excinfo: check_encoding(messaging.Message( topic='topic', - fcm_options=messaging.FcmOptions(label) + fcm_options=messaging.FCMOptions(label) )) - expected = 'Malformed FcmOptions.analytics_label.' + expected = 'Malformed FCMOptions.analytics_label.' assert str(excinfo.value) == expected def test_fcm_options(self): check_encoding( messaging.Message( topic='topic', - fcm_options=messaging.FcmOptions(), - android=messaging.AndroidConfig(fcm_options=messaging.AndroidFcmOptions()), - apns=messaging.APNSConfig(fcm_options=messaging.APNSFcmOptions()) + fcm_options=messaging.FCMOptions(), + android=messaging.AndroidConfig(fcm_options=messaging.AndroidFCMOptions()), + apns=messaging.APNSConfig(fcm_options=messaging.APNSFCMOptions()) ), {'topic': 'topic'}) check_encoding( messaging.Message( topic='topic', - fcm_options=messaging.FcmOptions('message-label'), + fcm_options=messaging.FCMOptions('message-label'), android=messaging.AndroidConfig( - fcm_options=messaging.AndroidFcmOptions('android-label')), - apns=messaging.APNSConfig(fcm_options=messaging.APNSFcmOptions('apns-label')) + fcm_options=messaging.AndroidFCMOptions('android-label')), + apns=messaging.APNSConfig(fcm_options=messaging.APNSFCMOptions('apns-label')) ), { 'topic': 'topic', @@ -267,7 +267,7 @@ def test_android_config(self): priority='high', ttl=123, data={'k1': 'v1', 'k2': 'v2'}, - fcm_options=messaging.AndroidFcmOptions('analytics_label_v1') + fcm_options=messaging.AndroidFCMOptions('analytics_label_v1') ) ) expected = { @@ -500,7 +500,7 @@ def test_webpush_config(self): check_encoding(msg, expected) -class TestWebpushFcmOptionsEncoder(object): +class TestWebpushFCMOptionsEncoder(object): @pytest.mark.parametrize('data', NON_OBJECT_ARGS) def test_invalid_webpush_fcm_options(self, data): @@ -510,7 +510,7 @@ def test_invalid_webpush_fcm_options(self, data): @pytest.mark.parametrize('data', NON_STRING_ARGS) def test_invalid_link_type(self, data): - options = messaging.WebpushFcmOptions(link=data) + options = messaging.WebpushFCMOptions(link=data) with pytest.raises(ValueError) as excinfo: check_encoding(messaging.Message( topic='topic', webpush=messaging.WebpushConfig(fcm_options=options))) @@ -519,14 +519,33 @@ def test_invalid_link_type(self, data): @pytest.mark.parametrize('data', ['', 'foo', 'http://example']) def test_invalid_link_format(self, data): - options = messaging.WebpushFcmOptions(link=data) + options = messaging.WebpushFCMOptions(link=data) with pytest.raises(ValueError) as excinfo: check_encoding(messaging.Message( topic='topic', webpush=messaging.WebpushConfig(fcm_options=options))) - expected = 'WebpushFcmOptions.link must be a HTTPS URL.' + expected = 'WebpushFCMOptions.link must be a HTTPS URL.' assert str(excinfo.value) == expected - def test_webpush_notification(self): + def test_webpush_options(self): + msg = messaging.Message( + topic='topic', + webpush=messaging.WebpushConfig( + fcm_options=messaging.WebpushFCMOptions( + link='https://example', + ), + ) + ) + expected = { + 'topic': 'topic', + 'webpush': { + 'fcm_options': { + 'link': 'https://example', + }, + }, + } + check_encoding(msg, expected) + + def test_deprecated_fcm_options(self): msg = messaging.Message( topic='topic', webpush=messaging.WebpushConfig( @@ -770,7 +789,7 @@ def test_apns_config(self): topic='topic', apns=messaging.APNSConfig( headers={'h1': 'v1', 'h2': 'v2'}, - fcm_options=messaging.APNSFcmOptions('analytics_label_v1') + fcm_options=messaging.APNSFCMOptions('analytics_label_v1') ), ) expected = { From 150a8551359f818ce5c2edf6acbc02d48ba98747 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Wed, 21 Aug 2019 10:38:49 -0700 Subject: [PATCH 022/226] Bumped version to 2.18.0 (#333) * Bumped version to 2.18.0 * Removing extra whitespace --- CHANGELOG.md | 10 ++++++++++ firebase_admin/__about__.py | 2 +- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5fc1a759e..0c6378f49 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,16 @@ - +# v2.18.0 + +- [added] Added support for specifying the analytics label for notifications. +- [added] Added support for arbitrary key-value pairs in `messaging.ApsAlert`. +- [changed] The `WebpushFcmOptions` type is now deprecated. Developers should use + the PEP8 compliant type name `WebpushFCMOptions` instead. +- [added] Developers can now test their Database API calls by directing the + SDK traffic to the RTDB emulator. Set the `FIREBASE_DATABASE_EMULATOR_HOST` + environment variable to specify the emulator endpoint in `host:port` format. + # v2.17.0 - [added] Added new `send_all()` and `send_multicast()` APIs to the diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index a5141c5e3..5a2f77e32 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '2.17.0' +__version__ = '2.18.0' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 5d4d6cb83455f993eff19d93d45b75d794c27ae2 Mon Sep 17 00:00:00 2001 From: Herbert Verdida <38489033+bertdida@users.noreply.github.com> Date: Fri, 6 Sep 2019 02:11:36 +0800 Subject: [PATCH 023/226] Fix: corrected ValueError message (#338) --- firebase_admin/auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index 0800d7c1e..fba5f3540 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -146,7 +146,7 @@ def verify_id_token(id_token, app=None, check_revoked=False): if not isinstance(check_revoked, bool): # guard against accidental wrong assignment. raise ValueError('Illegal check_revoked argument. Argument must be of type ' - ' bool, but given "{0}".'.format(type(app))) + ' bool, but given "{0}".'.format(type(check_revoked))) token_verifier = _get_auth_service(app).token_verifier verified_claims = token_verifier.verify_id_token(id_token) if check_revoked: From 0ed372e6a81c71fad07c9e80b02b2c7a3e27d1c1 Mon Sep 17 00:00:00 2001 From: cchamm Date: Wed, 11 Sep 2019 02:55:48 +0800 Subject: [PATCH 024/226] feat(fcm): Added support for sending an image URL in notifications (#332) * Add "image" field to AndroidNotification and Notification * Add Doc String to _messaging_utils.encode_notification * Shorten Line Width * Add "image" field to APNSFCMOptions. * APNSFCMOptions: Shorten line --- firebase_admin/_messaging_utils.py | 19 ++++++++++++++++--- integration/test_messaging.py | 8 ++++++-- tests/test_messaging.py | 10 ++++++++-- 3 files changed, 30 insertions(+), 7 deletions(-) diff --git a/firebase_admin/_messaging_utils.py b/firebase_admin/_messaging_utils.py index 72e2acab3..34738b168 100644 --- a/firebase_admin/_messaging_utils.py +++ b/firebase_admin/_messaging_utils.py @@ -89,11 +89,13 @@ class Notification(object): Args: title: Title of the notification (optional). body: Body of the notification (optional). + image: Image url of the notification (optional) """ - def __init__(self, title=None, body=None): + def __init__(self, title=None, body=None, image=None): self.title = title self.body = body + self.image = image class AndroidConfig(object): @@ -153,11 +155,12 @@ class AndroidNotification(object): title_loc_args: A list of resource keys that will be used in place of the format specifiers in ``title_loc_key`` (optional). channel_id: channel_id of the notification (optional). + image: Image url of the notification (optional). """ def __init__(self, title=None, body=None, icon=None, color=None, sound=None, tag=None, click_action=None, body_loc_key=None, body_loc_args=None, title_loc_key=None, - title_loc_args=None, channel_id=None): + title_loc_args=None, channel_id=None, image=None): self.title = title self.body = body self.icon = icon @@ -170,6 +173,7 @@ def __init__(self, title=None, body=None, icon=None, color=None, sound=None, tag self.title_loc_key = title_loc_key self.title_loc_args = title_loc_args self.channel_id = channel_id + self.image = image class AndroidFCMOptions(object): @@ -419,10 +423,13 @@ class APNSFCMOptions(object): Args: analytics_label: contains additional options for features provided by the FCM iOS SDK (optional). + image: contains the URL of an image that is going to be displayed in a notification + (optional). """ - def __init__(self, analytics_label=None): + def __init__(self, analytics_label=None, image=None): self.analytics_label = analytics_label + self.image = image class FCMOptions(object): @@ -600,6 +607,9 @@ def encode_android_notification(cls, notification): 'AndroidNotification.title_loc_key', notification.title_loc_key), 'channel_id': _Validators.check_string( 'AndroidNotification.channel_id', notification.channel_id), + 'image': _Validators.check_string( + 'image', notification.image + ) } result = cls.remove_null_values(result) color = result.get('color') @@ -754,6 +764,7 @@ def encode_apns_fcm_options(cls, fcm_options): result = { 'analytics_label': _Validators.check_analytics_label( 'APNSFCMOptions.analytics_label', fcm_options.analytics_label), + 'image': _Validators.check_string('APNSFCMOptions.image', fcm_options.image) } result = cls.remove_null_values(result) return result @@ -851,6 +862,7 @@ def encode_aps_alert(cls, alert): @classmethod def encode_notification(cls, notification): + """Encodes an Notification instance into JSON.""" if notification is None: return None if not isinstance(notification, Notification): @@ -858,6 +870,7 @@ def encode_notification(cls, notification): result = { 'body': _Validators.check_string('Notification.body', notification.body), 'title': _Validators.check_string('Notification.title', notification.title), + 'image': _Validators.check_string('Notification.image', notification.image) } return cls.remove_null_values(result) diff --git a/integration/test_messaging.py b/integration/test_messaging.py index 7ebd5866a..b1caa09f9 100644 --- a/integration/test_messaging.py +++ b/integration/test_messaging.py @@ -27,12 +27,16 @@ def test_send(): msg = messaging.Message( topic='foo-bar', - notification=messaging.Notification('test-title', 'test-body'), + notification=messaging.Notification('test-title', 'test-body', + 'https://images.unsplash.com/photo-1494438639946' + '-1ebd1d20bf85?fit=crop&w=900&q=60'), android=messaging.AndroidConfig( restricted_package_name='com.google.firebase.demos', notification=messaging.AndroidNotification( title='android-title', - body='android-body' + body='android-body', + image='https://images.unsplash.com/' + 'photo-1494438639946-1ebd1d20bf85?fit=crop&w=900&q=60' ) ), apns=messaging.APNSConfig(payload=messaging.APNSPayload( diff --git a/tests/test_messaging.py b/tests/test_messaging.py index 0b8739195..4f7520045 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -197,13 +197,19 @@ def test_fcm_options(self): fcm_options=messaging.FCMOptions('message-label'), android=messaging.AndroidConfig( fcm_options=messaging.AndroidFCMOptions('android-label')), - apns=messaging.APNSConfig(fcm_options=messaging.APNSFCMOptions('apns-label')) + apns=messaging.APNSConfig(fcm_options= + messaging.APNSFCMOptions( + analytics_label='apns-label', + image='https://images.unsplash.com/photo-14944386399' + '46-1ebd1d20bf85?fit=crop&w=900&q=60')) ), { 'topic': 'topic', 'fcm_options': {'analytics_label': 'message-label'}, 'android': {'fcm_options': {'analytics_label': 'android-label'}}, - 'apns': {'fcm_options': {'analytics_label': 'apns-label'}}, + 'apns': {'fcm_options': {'analytics_label': 'apns-label', + 'image': 'https://images.unsplash.com/photo-14944386399' + '46-1ebd1d20bf85?fit=crop&w=900&q=60'}}, }) From 9406afe157c2aaab0d0b1da081c5c5dee8200ffe Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Tue, 10 Sep 2019 14:26:59 -0700 Subject: [PATCH 025/226] Error handling revamp (v3 release) (#334) * Introduced the exceptions module (#296) * Added the exceptions module * Cleaned up the error handling logic; Added tests * Updated docs; Fixed some typos * Migrating FCM Send APIs to the New Exceptions (#297) * Migrated FCM send APIs to the new error handling regime * Moved error parsing logic to _utils * Refactored OP error handling code * Fixing a broken test * Added utils for handling googleapiclient errors * Added tests for new error handling logic * Updated public API docs * Fixing test for python3 * Cleaning up the error code lookup code * Cleaning up the error parsing APIs * Cleaned up error parsing logic; Updated docs * Migrated remaining messaging APIs to new error types (#298) * Migrated FCM send APIs to the new error handling regime * Moved error parsing logic to _utils * Refactored OP error handling code * Fixing a broken test * Added utils for handling googleapiclient errors * Added tests for new error handling logic * Updated public API docs * Fixing test for python3 * Cleaning up the error code lookup code * Cleaning up the error parsing APIs * Cleaned up error parsing logic; Updated docs * Migrated the FCM IID APIs to the new error types * Introducing TokenSignError to represent custom token creation errors (#302) * Migrated FCM send APIs to the new error handling regime * Moved error parsing logic to _utils * Refactored OP error handling code * Fixing a broken test * Added utils for handling googleapiclient errors * Added tests for new error handling logic * Updated public API docs * Fixing test for python3 * Cleaning up the error code lookup code * Cleaning up the error parsing APIs * Cleaned up error parsing logic; Updated docs * Migrated the FCM IID APIs to the new error types * Migrated custom token API to new error types * Raising FirebaseError from create_session_cookie() API (#306) * Migrated FCM send APIs to the new error handling regime * Moved error parsing logic to _utils * Refactored OP error handling code * Fixing a broken test * Added utils for handling googleapiclient errors * Added tests for new error handling logic * Updated public API docs * Fixing test for python3 * Cleaning up the error code lookup code * Cleaning up the error parsing APIs * Cleaned up error parsing logic; Updated docs * Migrated the FCM IID APIs to the new error types * Migrated custom token API to new error types * Migrated create cookie API to new error types * Improved error message computation * Refactored the shared error handling code * Fixing lint errors * Renamed variable for clarity * Introducing UserNotFoundError type (#309) * Added UserNotFoundError type * Fixed some lint errors * Some formatting updates * Updated docs and tests * New error handling support in create/update/delete user APIs (#311) * New error handling support in create/update/delete user APIs * Fixing some lint errors * Error handling improvements in email action link APIs (#312) * New error handling support in create/update/delete user APIs * Fixing some lint errors * Error handling update in email action link APIs * Project management API migrated to new error types (#314) * Error handling updated for remaining user_mgt APIs (#315) * Error handling updated for remaining user_mgt APIs * Removed unused constants * Migrated token verification APIs to new exception types (#317) * Migrated token verification APIs to new error types * Removed old AuthError type * Added new exception types for revoked tokens * Migrated the db module to the new exception types (#318) * Migrating db module to new exception types * Error handling for transactions * Updated integration tests * Restoring the old txn abort behavior * Updated error type in snippet * Added comment * Adding a few overlooked error types (#319) * Adding some missing error types * Updated documentation * Removing the ability to delete user properties by passing None (#320) * Some types renamed to be PEP8 compliant (#330) * Upgraded Cloud Firestore and Cloud Storage dependencies (#325) * Added documentation for error codes (#339) * A few API doc updates (#340) * Added documentation for error codes * Updated API docs --- CHANGELOG.md | 5 +- README.md | 8 +- firebase_admin/_auth_utils.py | 121 +++++++++ firebase_admin/_http_client.py | 4 + firebase_admin/_messaging_utils.py | 32 +++ firebase_admin/_token_gen.py | 150 +++++++---- firebase_admin/_user_mgt.py | 136 +++++----- firebase_admin/_utils.py | 259 +++++++++++++++++++ firebase_admin/auth.py | 191 +++++++------- firebase_admin/db.py | 95 ++++--- firebase_admin/exceptions.py | 237 +++++++++++++++++ firebase_admin/instance_id.py | 15 +- firebase_admin/messaging.py | 155 ++++-------- firebase_admin/project_management.py | 146 ++++------- integration/test_auth.py | 34 +-- integration/test_db.py | 33 ++- integration/test_instance_id.py | 3 +- integration/test_messaging.py | 21 +- integration/test_project_management.py | 43 ++-- lint.sh | 2 +- requirements.txt | 6 +- setup.py | 6 +- snippets/auth/index.py | 55 ++-- snippets/database/index.py | 2 +- tests/test_db.py | 97 +++++-- tests/test_exceptions.py | 335 +++++++++++++++++++++++++ tests/test_instance_id.py | 51 +++- tests/test_messaging.py | 242 +++++++++--------- tests/test_project_management.py | 242 ++++++++++-------- tests/test_token_gen.py | 133 +++++++--- tests/test_user_mgt.py | 230 +++++++++++++---- 31 files changed, 2175 insertions(+), 914 deletions(-) create mode 100644 firebase_admin/exceptions.py create mode 100644 tests/test_exceptions.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 0c6378f49..72c40dbe0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,9 @@ # Unreleased -- +- [added] Added the new `firebase_admin.exceptions` module containing the + base exception types and global error codes. +- [changed] Updated the `firebase_admin.instance_id` module to use the new + shared exception types. The type `instance_id.ApiCallError` was removed. # v2.18.0 diff --git a/README.md b/README.md index 80adc0583..757a3f8cd 100644 --- a/README.md +++ b/README.md @@ -36,13 +36,15 @@ pip install firebase-admin Please refer to the [CONTRIBUTING page](./CONTRIBUTING.md) for more information about how you can contribute to this project. We welcome bug reports, feature -requests, code review feedback, and also pull requests. +requests, code review feedback, and also pull requests. ## Supported Python Versions -We support Python 2.7 and Python 3.3+. Firebase Admin Python SDK is also tested -on PyPy and [Google App Engine](https://cloud.google.com/appengine/) environments. +We currently support Python 2.7 and Python 3.4+. However, Python 2.7 support is +being phased out, and the developers are advised to use latest Python 3. +Firebase Admin Python SDK is also tested on PyPy and +[Google App Engine](https://cloud.google.com/appengine/) environments. ## Documentation diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index b6788355c..d90b494f5 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -20,6 +20,9 @@ import six from six.moves import urllib +from firebase_admin import exceptions +from firebase_admin import _utils + MAX_CLAIMS_PAYLOAD_SIZE = 1000 RESERVED_CLAIMS = set([ @@ -188,3 +191,121 @@ def validate_action_type(action_type): raise ValueError('Invalid action type provided action_type: {0}. \ Valid values are {1}'.format(action_type, ', '.join(VALID_EMAIL_ACTION_TYPES))) return action_type + + +class UidAlreadyExistsError(exceptions.AlreadyExistsError): + """The user with the provided uid already exists.""" + + default_message = 'The user with the provided uid already exists' + + def __init__(self, message, cause, http_response): + exceptions.AlreadyExistsError.__init__(self, message, cause, http_response) + + +class EmailAlreadyExistsError(exceptions.AlreadyExistsError): + """The user with the provided email already exists.""" + + default_message = 'The user with the provided email already exists' + + def __init__(self, message, cause, http_response): + exceptions.AlreadyExistsError.__init__(self, message, cause, http_response) + + +class InvalidDynamicLinkDomainError(exceptions.InvalidArgumentError): + """Dynamic link domain in ActionCodeSettings is not authorized.""" + + default_message = 'Dynamic link domain specified in ActionCodeSettings is not authorized' + + def __init__(self, message, cause, http_response): + exceptions.InvalidArgumentError.__init__(self, message, cause, http_response) + + +class InvalidIdTokenError(exceptions.InvalidArgumentError): + """The provided ID token is not a valid Firebase ID token.""" + + default_message = 'The provided ID token is invalid' + + def __init__(self, message, cause=None, http_response=None): + exceptions.InvalidArgumentError.__init__(self, message, cause, http_response) + + +class PhoneNumberAlreadyExistsError(exceptions.AlreadyExistsError): + """The user with the provided phone number already exists.""" + + default_message = 'The user with the provided phone number already exists' + + def __init__(self, message, cause, http_response): + exceptions.AlreadyExistsError.__init__(self, message, cause, http_response) + + +class UnexpectedResponseError(exceptions.UnknownError): + """Backend service responded with an unexpected or malformed response.""" + + def __init__(self, message, cause=None, http_response=None): + exceptions.UnknownError.__init__(self, message, cause, http_response) + + +class UserNotFoundError(exceptions.NotFoundError): + """No user record found for the specified identifier.""" + + default_message = 'No user record found for the given identifier' + + def __init__(self, message, cause=None, http_response=None): + exceptions.NotFoundError.__init__(self, message, cause, http_response) + + +_CODE_TO_EXC_TYPE = { + 'DUPLICATE_EMAIL': EmailAlreadyExistsError, + 'DUPLICATE_LOCAL_ID': UidAlreadyExistsError, + 'INVALID_DYNAMIC_LINK_DOMAIN': InvalidDynamicLinkDomainError, + 'INVALID_ID_TOKEN': InvalidIdTokenError, + 'PHONE_NUMBER_EXISTS': PhoneNumberAlreadyExistsError, + 'USER_NOT_FOUND': UserNotFoundError, +} + + +def handle_auth_backend_error(error): + """Converts a requests error received from the Firebase Auth service into a FirebaseError.""" + if error.response is None: + raise _utils.handle_requests_error(error) + + code, custom_message = _parse_error_body(error.response) + if not code: + msg = 'Unexpected error response: {0}'.format(error.response.content.decode()) + raise _utils.handle_requests_error(error, message=msg) + + exc_type = _CODE_TO_EXC_TYPE.get(code) + msg = _build_error_message(code, exc_type, custom_message) + if not exc_type: + return _utils.handle_requests_error(error, message=msg) + + return exc_type(msg, cause=error, http_response=error.response) + + +def _parse_error_body(response): + """Parses the given error response to extract Auth error code and message.""" + error_dict = {} + try: + parsed_body = response.json() + if isinstance(parsed_body, dict): + error_dict = parsed_body.get('error', {}) + except ValueError: + pass + + # Auth error response format: {"error": {"message": "AUTH_ERROR_CODE: Optional text"}} + code = error_dict.get('message') if isinstance(error_dict, dict) else None + custom_message = None + if code: + separator = code.find(':') + if separator != -1: + custom_message = code[separator + 1:].strip() + code = code[:separator] + + return code, custom_message + + +def _build_error_message(code, exc_type, custom_message): + default_message = exc_type.default_message if ( + exc_type and hasattr(exc_type, 'default_message')) else 'Error while calling Auth service' + ext = ' {0}'.format(custom_message) if custom_message else '' + return '{0} ({1}).{2}'.format(default_message, code, ext) diff --git a/firebase_admin/_http_client.py b/firebase_admin/_http_client.py index 73028f833..eb8c4027a 100644 --- a/firebase_admin/_http_client.py +++ b/firebase_admin/_http_client.py @@ -109,6 +109,10 @@ def headers(self, method, url, **kwargs): resp = self.request(method, url, **kwargs) return resp.headers + def body_and_response(self, method, url, **kwargs): + resp = self.request(method, url, **kwargs) + return self.parse_body(resp), resp + def body(self, method, url, **kwargs): resp = self.request(method, url, **kwargs) return self.parse_body(resp) diff --git a/firebase_admin/_messaging_utils.py b/firebase_admin/_messaging_utils.py index 34738b168..5c99cb8ef 100644 --- a/firebase_admin/_messaging_utils.py +++ b/firebase_admin/_messaging_utils.py @@ -22,6 +22,8 @@ import six +from firebase_admin import exceptions + class Message(object): """A message that can be sent via Firebase Cloud Messaging. @@ -921,3 +923,33 @@ def encode_fcm_options(cls, fcm_options): } result = cls.remove_null_values(result) return result + + +class ThirdPartyAuthError(exceptions.UnauthenticatedError): + """APNs certificate or web push auth key was invalid or missing.""" + + def __init__(self, message, cause=None, http_response=None): + exceptions.UnauthenticatedError.__init__(self, message, cause, http_response) + + +class QuotaExceededError(exceptions.ResourceExhaustedError): + """Sending limit exceeded for the message target.""" + + def __init__(self, message, cause=None, http_response=None): + exceptions.ResourceExhaustedError.__init__(self, message, cause, http_response) + + +class SenderIdMismatchError(exceptions.PermissionDeniedError): + """The authenticated sender ID is different from the sender ID for the registration token.""" + + def __init__(self, message, cause=None, http_response=None): + exceptions.PermissionDeniedError.__init__(self, message, cause, http_response) + + +class UnregisteredError(exceptions.NotFoundError): + """App instance was unregistered from FCM. + + This usually means that the token used is no longer valid and a new one must be used.""" + + def __init__(self, message, cause=None, http_response=None): + exceptions.NotFoundError.__init__(self, message, cause, http_response) diff --git a/firebase_admin/_token_gen.py b/firebase_admin/_token_gen.py index 7af7b73b7..339714dcd 100644 --- a/firebase_admin/_token_gen.py +++ b/firebase_admin/_token_gen.py @@ -21,13 +21,16 @@ import requests import six from google.auth import credentials -from google.auth import exceptions from google.auth import iam from google.auth import jwt from google.auth import transport +import google.auth.exceptions import google.oauth2.id_token import google.oauth2.service_account +from firebase_admin import exceptions +from firebase_admin import _auth_utils + # ID token constants ID_TOKEN_ISSUER_PREFIX = 'https://securetoken.google.com/' @@ -51,19 +54,6 @@ METADATA_SERVICE_URL = ('http://metadata.google.internal/computeMetadata/v1/instance/' 'service-accounts/default/email') -# Error codes -COOKIE_CREATE_ERROR = 'COOKIE_CREATE_ERROR' -TOKEN_SIGN_ERROR = 'TOKEN_SIGN_ERROR' - - -class ApiCallError(Exception): - """Represents an Exception encountered while invoking the ID toolkit API.""" - - def __init__(self, code, message, error=None): - Exception.__init__(self, message) - self.code = code - self.detail = error - class _SigningProvider(object): """Stores a reference to a google.auth.crypto.Signer.""" @@ -177,9 +167,9 @@ def create_custom_token(self, uid, developer_claims=None): payload['claims'] = developer_claims try: return jwt.encode(signing_provider.signer, payload) - except exceptions.TransportError as error: + except google.auth.exceptions.TransportError as error: msg = 'Failed to sign custom token. {0}'.format(error) - raise ApiCallError(TOKEN_SIGN_ERROR, msg, error) + raise TokenSignError(msg, error) def create_session_cookie(self, id_token, expires_in): @@ -206,20 +196,15 @@ def create_session_cookie(self, id_token, expires_in): 'validDuration': expires_in, } try: - response = self.client.body('post', ':createSessionCookie', json=payload) + body, http_resp = self.client.body_and_response( + 'post', ':createSessionCookie', json=payload) except requests.exceptions.RequestException as error: - self._handle_http_error(COOKIE_CREATE_ERROR, 'Failed to create session cookie', error) + raise _auth_utils.handle_auth_backend_error(error) else: - if not response or not response.get('sessionCookie'): - raise ApiCallError(COOKIE_CREATE_ERROR, 'Failed to create session cookie.') - return response.get('sessionCookie') - - def _handle_http_error(self, code, msg, error): - if error.response is not None: - msg += '\nServer response: {0}'.format(error.response.content.decode()) - else: - msg += '\nReason: {0}'.format(error) - raise ApiCallError(code, msg, error) + if not body or not body.get('sessionCookie'): + raise _auth_utils.UnexpectedResponseError( + 'Failed to create session cookie.', http_response=http_resp) + return body.get('sessionCookie') class TokenVerifier(object): @@ -232,12 +217,18 @@ def __init__(self, app): project_id=app.project_id, short_name='ID token', operation='verify_id_token()', doc_url='https://firebase.google.com/docs/auth/admin/verify-id-tokens', - cert_url=ID_TOKEN_CERT_URI, issuer=ID_TOKEN_ISSUER_PREFIX) + cert_url=ID_TOKEN_CERT_URI, + issuer=ID_TOKEN_ISSUER_PREFIX, + invalid_token_error=_auth_utils.InvalidIdTokenError, + expired_token_error=ExpiredIdTokenError) self.cookie_verifier = _JWTVerifier( project_id=app.project_id, short_name='session cookie', operation='verify_session_cookie()', doc_url='https://firebase.google.com/docs/auth/admin/verify-id-tokens', - cert_url=COOKIE_CERT_URI, issuer=COOKIE_ISSUER_PREFIX) + cert_url=COOKIE_CERT_URI, + issuer=COOKIE_ISSUER_PREFIX, + invalid_token_error=InvalidSessionCookieError, + expired_token_error=ExpiredSessionCookieError) def verify_id_token(self, id_token): return self.id_token_verifier.verify(id_token, self.request) @@ -260,6 +251,8 @@ def __init__(self, **kwargs): self.articled_short_name = 'an {0}'.format(self.short_name) else: self.articled_short_name = 'a {0}'.format(self.short_name) + self._invalid_token_error = kwargs.pop('invalid_token_error') + self._expired_token_error = kwargs.pop('expired_token_error') def verify(self, token, request): """Verifies the signature and data for the provided JWT.""" @@ -276,8 +269,7 @@ def verify(self, token, request): 'or set your Firebase project ID as an app option. Alternatively set the ' 'GOOGLE_CLOUD_PROJECT environment variable.'.format(self.operation)) - header = jwt.decode_header(token) - payload = jwt.decode(token, verify=False) + header, payload = self._decode_unverified(token) issuer = payload.get('iss') audience = payload.get('aud') subject = payload.get('sub') @@ -290,12 +282,12 @@ def verify(self, token, request): 'See {0} for details on how to retrieve {1}.'.format(self.url, self.short_name)) error_message = None - if not header.get('kid'): - if audience == FIREBASE_AUDIENCE: - error_message = ( - '{0} expects {1}, but was given a custom ' - 'token.'.format(self.operation, self.articled_short_name)) - elif header.get('alg') == 'HS256' and payload.get( + if audience == FIREBASE_AUDIENCE: + error_message = ( + '{0} expects {1}, but was given a custom ' + 'token.'.format(self.operation, self.articled_short_name)) + elif not header.get('kid'): + if header.get('alg') == 'HS256' and payload.get( 'v') is 0 and 'uid' in payload.get('d', {}): error_message = ( '{0} expects {1}, but was given a legacy custom ' @@ -330,12 +322,76 @@ def verify(self, token, request): '{1}'.format(self.short_name, verify_id_token_msg)) if error_message: - raise ValueError(error_message) - - verified_claims = google.oauth2.id_token.verify_token( - token, - request=request, - audience=self.project_id, - certs_url=self.cert_url) - verified_claims['uid'] = verified_claims['sub'] - return verified_claims + raise self._invalid_token_error(error_message) + + try: + verified_claims = google.oauth2.id_token.verify_token( + token, + request=request, + audience=self.project_id, + certs_url=self.cert_url) + verified_claims['uid'] = verified_claims['sub'] + return verified_claims + except google.auth.exceptions.TransportError as error: + raise CertificateFetchError(str(error), cause=error) + except ValueError as error: + if 'Token expired' in str(error): + raise self._expired_token_error(str(error), cause=error) + raise self._invalid_token_error(str(error), cause=error) + + def _decode_unverified(self, token): + try: + header = jwt.decode_header(token) + payload = jwt.decode(token, verify=False) + return header, payload + except ValueError as error: + raise self._invalid_token_error(str(error), cause=error) + + +class TokenSignError(exceptions.UnknownError): + """Unexpected error while signing a Firebase custom token.""" + + def __init__(self, message, cause): + exceptions.UnknownError.__init__(self, message, cause) + + +class CertificateFetchError(exceptions.UnknownError): + """Failed to fetch some public key certificates required to verify a token.""" + + def __init__(self, message, cause): + exceptions.UnknownError.__init__(self, message, cause) + + +class ExpiredIdTokenError(_auth_utils.InvalidIdTokenError): + """The provided ID token is expired.""" + + def __init__(self, message, cause): + _auth_utils.InvalidIdTokenError.__init__(self, message, cause) + + +class RevokedIdTokenError(_auth_utils.InvalidIdTokenError): + """The provided ID token has been revoked.""" + + def __init__(self, message): + _auth_utils.InvalidIdTokenError.__init__(self, message) + + +class InvalidSessionCookieError(exceptions.InvalidArgumentError): + """The provided string is not a valid Firebase session cookie.""" + + def __init__(self, message, cause=None): + exceptions.InvalidArgumentError.__init__(self, message, cause) + + +class ExpiredSessionCookieError(InvalidSessionCookieError): + """The provided session cookie is expired.""" + + def __init__(self, message, cause): + InvalidSessionCookieError.__init__(self, message, cause) + + +class RevokedSessionCookieError(InvalidSessionCookieError): + """The provided session cookie has been revoked.""" + + def __init__(self, message): + InvalidSessionCookieError.__init__(self, message) diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index 24bb2bdb6..867b6dd89 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -24,15 +24,6 @@ from firebase_admin import _user_import -INTERNAL_ERROR = 'INTERNAL_ERROR' -USER_NOT_FOUND_ERROR = 'USER_NOT_FOUND_ERROR' -USER_CREATE_ERROR = 'USER_CREATE_ERROR' -USER_UPDATE_ERROR = 'USER_UPDATE_ERROR' -USER_DELETE_ERROR = 'USER_DELETE_ERROR' -USER_IMPORT_ERROR = 'USER_IMPORT_ERROR' -USER_DOWNLOAD_ERROR = 'LIST_USERS_ERROR' -GENERATE_EMAIL_ACTION_LINK_ERROR = 'GENERATE_EMAIL_ACTION_LINK_ERROR' - MAX_LIST_USERS_RESULTS = 1000 MAX_IMPORT_USERS_SIZE = 1000 @@ -43,22 +34,9 @@ def __init__(self, description): self.description = description -# Use this internally, until sentinels are available in the public API. -_UNSPECIFIED = Sentinel('No value specified') - - DELETE_ATTRIBUTE = Sentinel('Value used to delete an attribute from a user profile') -class ApiCallError(Exception): - """Represents an Exception encountered while invoking the Firebase user management API.""" - - def __init__(self, code, message, error=None): - Exception.__init__(self, message) - self.code = code - self.detail = error - - class UserMetadata(object): """Contains additional metadata associated with a user account.""" @@ -381,6 +359,7 @@ def photo_url(self): def provider_id(self): return self._data.get('providerId') + class ActionCodeSettings(object): """Contains required continue/state URL with optional Android and iOS settings. Used when invoking the email action link generation APIs. @@ -396,6 +375,7 @@ def __init__(self, url, handle_code_in_app=None, dynamic_link_domain=None, ios_b self.android_install_app = android_install_app self.android_minimum_version = android_minimum_version + def encode_action_code_settings(settings): """ Validates the provided action code settings for email link generation and populates the REST api parameters. @@ -463,6 +443,7 @@ def encode_action_code_settings(settings): return parameters + class UserManager(object): """Provides methods for interacting with the Google Identity Toolkit.""" @@ -484,16 +465,16 @@ def get_user(self, **kwargs): raise TypeError('Unsupported keyword arguments: {0}.'.format(kwargs)) try: - response = self._client.body('post', '/accounts:lookup', json=payload) + body, http_resp = self._client.body_and_response( + 'post', '/accounts:lookup', json=payload) except requests.exceptions.RequestException as error: - msg = 'Failed to get user by {0}: {1}.'.format(key_type, key) - self._handle_http_error(INTERNAL_ERROR, msg, error) + raise _auth_utils.handle_auth_backend_error(error) else: - if not response or not response.get('users'): - raise ApiCallError( - USER_NOT_FOUND_ERROR, - 'No user record found for the provided {0}: {1}.'.format(key_type, key)) - return response['users'][0] + if not body or not body.get('users'): + raise _auth_utils.UserNotFoundError( + 'No user record found for the provided {0}: {1}.'.format(key_type, key), + http_response=http_resp) + return body['users'][0] def list_users(self, page_token=None, max_results=MAX_LIST_USERS_RESULTS): """Retrieves a batch of users.""" @@ -513,7 +494,7 @@ def list_users(self, page_token=None, max_results=MAX_LIST_USERS_RESULTS): try: return self._client.body('get', '/accounts:batchGet', params=payload) except requests.exceptions.RequestException as error: - self._handle_http_error(USER_DOWNLOAD_ERROR, 'Failed to download user accounts.', error) + raise _auth_utils.handle_auth_backend_error(error) def create_user(self, uid=None, display_name=None, email=None, phone_number=None, photo_url=None, password=None, disabled=None, email_verified=None): @@ -530,17 +511,18 @@ def create_user(self, uid=None, display_name=None, email=None, phone_number=None } payload = {k: v for k, v in payload.items() if v is not None} try: - response = self._client.body('post', '/accounts', json=payload) + body, http_resp = self._client.body_and_response('post', '/accounts', json=payload) except requests.exceptions.RequestException as error: - self._handle_http_error(USER_CREATE_ERROR, 'Failed to create new user.', error) + raise _auth_utils.handle_auth_backend_error(error) else: - if not response or not response.get('localId'): - raise ApiCallError(USER_CREATE_ERROR, 'Failed to create new user.') - return response.get('localId') - - def update_user(self, uid, display_name=_UNSPECIFIED, email=None, phone_number=_UNSPECIFIED, - photo_url=_UNSPECIFIED, password=None, disabled=None, email_verified=None, - valid_since=None, custom_claims=_UNSPECIFIED): + if not body or not body.get('localId'): + raise _auth_utils.UnexpectedResponseError( + 'Failed to create new user.', http_response=http_resp) + return body.get('localId') + + def update_user(self, uid, display_name=None, email=None, phone_number=None, + photo_url=None, password=None, disabled=None, email_verified=None, + valid_since=None, custom_claims=None): """Updates an existing user account with the specified properties""" payload = { 'localId': _auth_utils.validate_uid(uid, required=True), @@ -552,27 +534,27 @@ def update_user(self, uid, display_name=_UNSPECIFIED, email=None, phone_number=_ } remove = [] - if display_name is not _UNSPECIFIED: - if display_name is None or display_name is DELETE_ATTRIBUTE: + if display_name is not None: + if display_name is DELETE_ATTRIBUTE: remove.append('DISPLAY_NAME') else: payload['displayName'] = _auth_utils.validate_display_name(display_name) - if photo_url is not _UNSPECIFIED: - if photo_url is None or photo_url is DELETE_ATTRIBUTE: + if photo_url is not None: + if photo_url is DELETE_ATTRIBUTE: remove.append('PHOTO_URL') else: payload['photoUrl'] = _auth_utils.validate_photo_url(photo_url) if remove: payload['deleteAttribute'] = remove - if phone_number is not _UNSPECIFIED: - if phone_number is None or phone_number is DELETE_ATTRIBUTE: + if phone_number is not None: + if phone_number is DELETE_ATTRIBUTE: payload['deleteProvider'] = ['phone'] else: payload['phoneNumber'] = _auth_utils.validate_phone(phone_number) - if custom_claims is not _UNSPECIFIED: - if custom_claims is None or custom_claims is DELETE_ATTRIBUTE: + if custom_claims is not None: + if custom_claims is DELETE_ATTRIBUTE: custom_claims = {} json_claims = json.dumps(custom_claims) if isinstance( custom_claims, dict) else custom_claims @@ -580,26 +562,28 @@ def update_user(self, uid, display_name=_UNSPECIFIED, email=None, phone_number=_ payload = {k: v for k, v in payload.items() if v is not None} try: - response = self._client.body('post', '/accounts:update', json=payload) + body, http_resp = self._client.body_and_response( + 'post', '/accounts:update', json=payload) except requests.exceptions.RequestException as error: - self._handle_http_error( - USER_UPDATE_ERROR, 'Failed to update user: {0}.'.format(uid), error) + raise _auth_utils.handle_auth_backend_error(error) else: - if not response or not response.get('localId'): - raise ApiCallError(USER_UPDATE_ERROR, 'Failed to update user: {0}.'.format(uid)) - return response.get('localId') + if not body or not body.get('localId'): + raise _auth_utils.UnexpectedResponseError( + 'Failed to update user: {0}.'.format(uid), http_response=http_resp) + return body.get('localId') def delete_user(self, uid): """Deletes the user identified by the specified user ID.""" _auth_utils.validate_uid(uid, required=True) try: - response = self._client.body('post', '/accounts:delete', json={'localId' : uid}) + body, http_resp = self._client.body_and_response( + 'post', '/accounts:delete', json={'localId' : uid}) except requests.exceptions.RequestException as error: - self._handle_http_error( - USER_DELETE_ERROR, 'Failed to delete user: {0}.'.format(uid), error) + raise _auth_utils.handle_auth_backend_error(error) else: - if not response or not response.get('kind'): - raise ApiCallError(USER_DELETE_ERROR, 'Failed to delete user: {0}.'.format(uid)) + if not body or not body.get('kind'): + raise _auth_utils.UnexpectedResponseError( + 'Failed to delete user: {0}.'.format(uid), http_response=http_resp) def import_users(self, users, hash_alg=None): """Imports the given list of users to Firebase Auth.""" @@ -619,13 +603,15 @@ def import_users(self, users, hash_alg=None): raise ValueError('A UserImportHash is required to import users with passwords.') payload.update(hash_alg.to_dict()) try: - response = self._client.body('post', '/accounts:batchCreate', json=payload) + body, http_resp = self._client.body_and_response( + 'post', '/accounts:batchCreate', json=payload) except requests.exceptions.RequestException as error: - self._handle_http_error(USER_IMPORT_ERROR, 'Failed to import users.', error) + raise _auth_utils.handle_auth_backend_error(error) else: - if not isinstance(response, dict): - raise ApiCallError(USER_IMPORT_ERROR, 'Failed to import users.') - return response + if not isinstance(body, dict): + raise _auth_utils.UnexpectedResponseError( + 'Failed to import users.', http_response=http_resp) + return body def generate_email_action_link(self, action_type, email, action_code_settings=None): """Fetches the email action links for types @@ -640,7 +626,7 @@ def generate_email_action_link(self, action_type, email, action_code_settings=No link_url: action url to be emailed to the user Raises: - ApiCallError: If an error occurs while generating the link + FirebaseError: If an error occurs while generating the link ValueError: If the provided arguments are invalid """ payload = { @@ -653,21 +639,15 @@ def generate_email_action_link(self, action_type, email, action_code_settings=No payload.update(encode_action_code_settings(action_code_settings)) try: - response = self._client.body('post', '/accounts:sendOobCode', json=payload) + body, http_resp = self._client.body_and_response( + 'post', '/accounts:sendOobCode', json=payload) except requests.exceptions.RequestException as error: - self._handle_http_error(GENERATE_EMAIL_ACTION_LINK_ERROR, 'Failed to generate link.', - error) - else: - if not response or not response.get('oobLink'): - raise ApiCallError(GENERATE_EMAIL_ACTION_LINK_ERROR, 'Failed to generate link.') - return response.get('oobLink') - - def _handle_http_error(self, code, msg, error): - if error.response is not None: - msg += '\nServer response: {0}'.format(error.response.content.decode()) + raise _auth_utils.handle_auth_backend_error(error) else: - msg += '\nReason: {0}'.format(error) - raise ApiCallError(code, msg, error) + if not body or not body.get('oobLink'): + raise _auth_utils.UnexpectedResponseError( + 'Failed to generate email action link.', http_response=http_resp) + return body.get('oobLink') class _UserIterator(object): diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index b28853868..95ed2c414 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -14,7 +14,49 @@ """Internal utilities common to all modules.""" +import json +import socket + +import googleapiclient +import httplib2 +import requests +import six + import firebase_admin +from firebase_admin import exceptions + + +_ERROR_CODE_TO_EXCEPTION_TYPE = { + exceptions.INVALID_ARGUMENT: exceptions.InvalidArgumentError, + exceptions.FAILED_PRECONDITION: exceptions.FailedPreconditionError, + exceptions.OUT_OF_RANGE: exceptions.OutOfRangeError, + exceptions.UNAUTHENTICATED: exceptions.UnauthenticatedError, + exceptions.PERMISSION_DENIED: exceptions.PermissionDeniedError, + exceptions.NOT_FOUND: exceptions.NotFoundError, + exceptions.ABORTED: exceptions.AbortedError, + exceptions.ALREADY_EXISTS: exceptions.AlreadyExistsError, + exceptions.CONFLICT: exceptions.ConflictError, + exceptions.RESOURCE_EXHAUSTED: exceptions.ResourceExhaustedError, + exceptions.CANCELLED: exceptions.CancelledError, + exceptions.DATA_LOSS: exceptions.DataLossError, + exceptions.UNKNOWN: exceptions.UnknownError, + exceptions.INTERNAL: exceptions.InternalError, + exceptions.UNAVAILABLE: exceptions.UnavailableError, + exceptions.DEADLINE_EXCEEDED: exceptions.DeadlineExceededError, +} + + +_HTTP_STATUS_TO_ERROR_CODE = { + 400: exceptions.INVALID_ARGUMENT, + 401: exceptions.UNAUTHENTICATED, + 403: exceptions.PERMISSION_DENIED, + 404: exceptions.NOT_FOUND, + 409: exceptions.CONFLICT, + 412: exceptions.FAILED_PRECONDITION, + 429: exceptions.RESOURCE_EXHAUSTED, + 500: exceptions.INTERNAL, + 503: exceptions.UNAVAILABLE, +} def _get_initialized_app(app): @@ -30,6 +72,223 @@ def _get_initialized_app(app): raise ValueError('Illegal app argument. Argument must be of type ' ' firebase_admin.App, but given "{0}".'.format(type(app))) + def get_app_service(app, name, initializer): app = _get_initialized_app(app) return app._get_service(name, initializer) # pylint: disable=protected-access + + +def handle_platform_error_from_requests(error, handle_func=None): + """Constructs a ``FirebaseError`` from the given requests error. + + This can be used to handle errors returned by Google Cloud Platform (GCP) APIs. + + Args: + error: An error raised by the requests module while making an HTTP call to a GCP API. + handle_func: A function that can be used to handle platform errors in a custom way. When + specified, this function will be called with three arguments. It has the same + signature as ```_handle_func_requests``, but may return ``None``. + + Returns: + FirebaseError: A ``FirebaseError`` that can be raised to the user code. + """ + if error.response is None: + return handle_requests_error(error) + + response = error.response + content = response.content.decode() + status_code = response.status_code + error_dict, message = _parse_platform_error(content, status_code) + exc = None + if handle_func: + exc = handle_func(error, message, error_dict) + + return exc if exc else _handle_func_requests(error, message, error_dict) + + +def _handle_func_requests(error, message, error_dict): + """Constructs a ``FirebaseError`` from the given GCP error. + + Args: + error: An error raised by the requests module while making an HTTP call. + message: A message to be included in the resulting ``FirebaseError``. + error_dict: Parsed GCP error response. + + Returns: + FirebaseError: A ``FirebaseError`` that can be raised to the user code or None. + """ + code = error_dict.get('status') + return handle_requests_error(error, message, code) + + +def handle_requests_error(error, message=None, code=None): + """Constructs a ``FirebaseError`` from the given requests error. + + This method is agnostic of the remote service that produced the error, whether it is a GCP + service or otherwise. Therefore, this method does not attempt to parse the error response in + any way. + + Args: + error: An error raised by the requests module while making an HTTP call. + message: A message to be included in the resulting ``FirebaseError`` (optional). If not + specified the string representation of the ``error`` argument is used as the message. + code: A GCP error code that will be used to determine the resulting error type (optional). + If not specified the HTTP status code on the error response is used to determine a + suitable error code. + + Returns: + FirebaseError: A ``FirebaseError`` that can be raised to the user code. + """ + if isinstance(error, requests.exceptions.Timeout): + return exceptions.DeadlineExceededError( + message='Timed out while making an API call: {0}'.format(error), + cause=error) + elif isinstance(error, requests.exceptions.ConnectionError): + return exceptions.UnavailableError( + message='Failed to establish a connection: {0}'.format(error), + cause=error) + elif error.response is None: + return exceptions.UnknownError( + message='Unknown error while making a remote service call: {0}'.format(error), + cause=error) + + if not code: + code = _http_status_to_error_code(error.response.status_code) + if not message: + message = str(error) + + err_type = _error_code_to_exception_type(code) + return err_type(message=message, cause=error, http_response=error.response) + + +def handle_platform_error_from_googleapiclient(error, handle_func=None): + """Constructs a ``FirebaseError`` from the given googleapiclient error. + + This can be used to handle errors returned by Google Cloud Platform (GCP) APIs. + + Args: + error: An error raised by the googleapiclient while making an HTTP call to a GCP API. + handle_func: A function that can be used to handle platform errors in a custom way. When + specified, this function will be called with three arguments. It has the same + signature as ```_handle_func_googleapiclient``, but may return ``None``. + + Returns: + FirebaseError: A ``FirebaseError`` that can be raised to the user code. + """ + if not isinstance(error, googleapiclient.errors.HttpError): + return handle_googleapiclient_error(error) + + content = error.content.decode() + status_code = error.resp.status + error_dict, message = _parse_platform_error(content, status_code) + http_response = _http_response_from_googleapiclient_error(error) + exc = None + if handle_func: + exc = handle_func(error, message, error_dict, http_response) + + return exc if exc else _handle_func_googleapiclient(error, message, error_dict, http_response) + + +def _handle_func_googleapiclient(error, message, error_dict, http_response): + """Constructs a ``FirebaseError`` from the given GCP error. + + Args: + error: An error raised by the googleapiclient module while making an HTTP call. + message: A message to be included in the resulting ``FirebaseError``. + error_dict: Parsed GCP error response. + http_response: A requests HTTP response object to associate with the exception. + + Returns: + FirebaseError: A ``FirebaseError`` that can be raised to the user code or None. + """ + code = error_dict.get('status') + return handle_googleapiclient_error(error, message, code, http_response) + + +def handle_googleapiclient_error(error, message=None, code=None, http_response=None): + """Constructs a ``FirebaseError`` from the given googleapiclient error. + + This method is agnostic of the remote service that produced the error, whether it is a GCP + service or otherwise. Therefore, this method does not attempt to parse the error response in + any way. + + Args: + error: An error raised by the googleapiclient module while making an HTTP call. + message: A message to be included in the resulting ``FirebaseError`` (optional). If not + specified the string representation of the ``error`` argument is used as the message. + code: A GCP error code that will be used to determine the resulting error type (optional). + If not specified the HTTP status code on the error response is used to determine a + suitable error code. + http_response: A requests HTTP response object to associate with the exception (optional). + If not specified, one will be created from the ``error``. + + Returns: + FirebaseError: A ``FirebaseError`` that can be raised to the user code. + """ + if isinstance(error, socket.timeout) or ( + isinstance(error, socket.error) and 'timed out' in str(error)): + return exceptions.DeadlineExceededError( + message='Timed out while making an API call: {0}'.format(error), + cause=error) + elif isinstance(error, httplib2.ServerNotFoundError): + return exceptions.UnavailableError( + message='Failed to establish a connection: {0}'.format(error), + cause=error) + elif not isinstance(error, googleapiclient.errors.HttpError): + return exceptions.UnknownError( + message='Unknown error while making a remote service call: {0}'.format(error), + cause=error) + + if not code: + code = _http_status_to_error_code(error.resp.status) + if not message: + message = str(error) + if not http_response: + http_response = _http_response_from_googleapiclient_error(error) + + err_type = _error_code_to_exception_type(code) + return err_type(message=message, cause=error, http_response=http_response) + + +def _http_response_from_googleapiclient_error(error): + """Creates a requests HTTP Response object from the given googleapiclient error.""" + resp = requests.models.Response() + resp.raw = six.BytesIO(error.content) + resp.status_code = error.resp.status + return resp + + +def _http_status_to_error_code(status): + """Maps an HTTP status to a platform error code.""" + return _HTTP_STATUS_TO_ERROR_CODE.get(status, exceptions.UNKNOWN) + + +def _error_code_to_exception_type(code): + """Maps a platform error code to an exception type.""" + return _ERROR_CODE_TO_EXCEPTION_TYPE.get(code, exceptions.UnknownError) + + +def _parse_platform_error(content, status_code): + """Parses an HTTP error response from a Google Cloud Platform API and extracts the error code + and message fields. + + Args: + content: Decoded content of the response body. + status_code: HTTP status code. + + Returns: + tuple: A tuple containing error code and message. + """ + data = {} + try: + parsed_body = json.loads(content) + if isinstance(parsed_body, dict): + data = parsed_body + except ValueError: + pass + + error_dict = data.get('error', {}) + msg = error_dict.get('message') + if not msg: + msg = 'Unexpected HTTP response with status: {0}; body: {1}'.format(status_code, content) + return error_dict, msg diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index fba5f3540..47a9a23f7 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -22,6 +22,7 @@ import time import firebase_admin +from firebase_admin import _auth_utils from firebase_admin import _http_client from firebase_admin import _token_gen from firebase_admin import _user_import @@ -30,22 +31,33 @@ _AUTH_ATTRIBUTE = '_auth' -_ID_TOKEN_REVOKED = 'ID_TOKEN_REVOKED' -_SESSION_COOKIE_REVOKED = 'SESSION_COOKIE_REVOKED' __all__ = [ 'ActionCodeSettings', - 'AuthError', + 'CertificateFetchError', 'DELETE_ATTRIBUTE', + 'EmailAlreadyExistsError', 'ErrorInfo', + 'ExpiredIdTokenError', + 'ExpiredSessionCookieError', 'ExportedUserRecord', 'ImportUserRecord', + 'InvalidDynamicLinkDomainError', + 'InvalidIdTokenError', + 'InvalidSessionCookieError', 'ListUsersPage', + 'PhoneNumberAlreadyExistsError', + 'RevokedIdTokenError', + 'RevokedSessionCookieError', + 'TokenSignError', + 'UidAlreadyExistsError', + 'UnexpectedResponseError', 'UserImportHash', 'UserImportResult', 'UserInfo', 'UserMetadata', + 'UserNotFoundError', 'UserProvider', 'UserRecord', @@ -69,15 +81,29 @@ ] ActionCodeSettings = _user_mgt.ActionCodeSettings +CertificateFetchError = _token_gen.CertificateFetchError DELETE_ATTRIBUTE = _user_mgt.DELETE_ATTRIBUTE +EmailAlreadyExistsError = _auth_utils.EmailAlreadyExistsError ErrorInfo = _user_import.ErrorInfo +ExpiredIdTokenError = _token_gen.ExpiredIdTokenError +ExpiredSessionCookieError = _token_gen.ExpiredSessionCookieError ExportedUserRecord = _user_mgt.ExportedUserRecord +ImportUserRecord = _user_import.ImportUserRecord +InvalidDynamicLinkDomainError = _auth_utils.InvalidDynamicLinkDomainError +InvalidIdTokenError = _auth_utils.InvalidIdTokenError +InvalidSessionCookieError = _token_gen.InvalidSessionCookieError ListUsersPage = _user_mgt.ListUsersPage +PhoneNumberAlreadyExistsError = _auth_utils.PhoneNumberAlreadyExistsError +RevokedIdTokenError = _token_gen.RevokedIdTokenError +RevokedSessionCookieError = _token_gen.RevokedSessionCookieError +TokenSignError = _token_gen.TokenSignError +UidAlreadyExistsError = _auth_utils.UidAlreadyExistsError +UnexpectedResponseError = _auth_utils.UnexpectedResponseError UserImportHash = _user_import.UserImportHash -ImportUserRecord = _user_import.ImportUserRecord UserImportResult = _user_import.UserImportResult UserInfo = _user_mgt.UserInfo UserMetadata = _user_mgt.UserMetadata +UserNotFoundError = _auth_utils.UserNotFoundError UserProvider = _user_import.UserProvider UserRecord = _user_mgt.UserRecord @@ -115,13 +141,10 @@ def create_custom_token(uid, developer_claims=None, app=None): Raises: ValueError: If input parameters are invalid. - AuthError: If an error occurs while creating the token using the remote IAM service. + TokenSignError: If an error occurs while signing the token using the remote IAM service. """ token_generator = _get_auth_service(app).token_generator - try: - return token_generator.create_custom_token(uid, developer_claims) - except _token_gen.ApiCallError as error: - raise AuthError(error.code, str(error), error.detail) + return token_generator.create_custom_token(uid, developer_claims) def verify_id_token(id_token, app=None, check_revoked=False): @@ -139,9 +162,12 @@ def verify_id_token(id_token, app=None, check_revoked=False): dict: A dictionary of key-value pairs parsed from the decoded JWT. Raises: - ValueError: If the JWT was found to be invalid, or if the App's project ID cannot - be determined. - AuthError: If ``check_revoked`` is requested and the token was revoked. + ValueError: If ``id_token`` is a not a string or is empty. + InvalidIdTokenError: If ``id_token`` is not a valid Firebase ID token. + ExpiredIdTokenError: If the specified ID token has expired. + RevokedIdTokenError: If ``check_revoked`` is ``True`` and the ID token has been revoked. + CertificateFetchError: If an error occurs while fetching the public key certificates + required to verify the ID token. """ if not isinstance(check_revoked, bool): # guard against accidental wrong assignment. @@ -150,7 +176,7 @@ def verify_id_token(id_token, app=None, check_revoked=False): token_verifier = _get_auth_service(app).token_verifier verified_claims = token_verifier.verify_id_token(id_token) if check_revoked: - _check_jwt_revoked(verified_claims, _ID_TOKEN_REVOKED, 'ID token', app) + _check_jwt_revoked(verified_claims, RevokedIdTokenError, 'ID token', app) return verified_claims @@ -170,13 +196,10 @@ def create_session_cookie(id_token, expires_in, app=None): Raises: ValueError: If input parameters are invalid. - AuthError: If an error occurs while creating the cookie. + FirebaseError: If an error occurs while creating the cookie. """ token_generator = _get_auth_service(app).token_generator - try: - return token_generator.create_session_cookie(id_token, expires_in) - except _token_gen.ApiCallError as error: - raise AuthError(error.code, str(error), error.detail) + return token_generator.create_session_cookie(id_token, expires_in) def verify_session_cookie(session_cookie, check_revoked=False, app=None): @@ -194,14 +217,17 @@ def verify_session_cookie(session_cookie, check_revoked=False, app=None): dict: A dictionary of key-value pairs parsed from the decoded JWT. Raises: - ValueError: If the cookie was found to be invalid, or if the App's project ID cannot - be determined. - AuthError: If ``check_revoked`` is requested and the cookie was revoked. + ValueError: If ``session_cookie`` is a not a string or is empty. + InvalidSessionCookieError: If ``session_cookie`` is not a valid Firebase session cookie. + ExpiredSessionCookieError: If the specified session cookie has expired. + RevokedSessionCookieError: If ``check_revoked`` is ``True`` and the cookie has been revoked. + CertificateFetchError: If an error occurs while fetching the public key certificates + required to verify the session cookie. """ token_verifier = _get_auth_service(app).token_verifier verified_claims = token_verifier.verify_session_cookie(session_cookie) if check_revoked: - _check_jwt_revoked(verified_claims, _SESSION_COOKIE_REVOKED, 'session cookie', app) + _check_jwt_revoked(verified_claims, RevokedSessionCookieError, 'session cookie', app) return verified_claims @@ -233,15 +259,12 @@ def get_user(uid, app=None): Raises: ValueError: If the user ID is None, empty or malformed. - AuthError: If an error occurs while retrieving the user or if the specified user ID - does not exist. + UserNotFoundError: If the specified user ID does not exist. + FirebaseError: If an error occurs while retrieving the user. """ user_manager = _get_auth_service(app).user_manager - try: - response = user_manager.get_user(uid=uid) - return UserRecord(response) - except _user_mgt.ApiCallError as error: - raise AuthError(error.code, str(error), error.detail) + response = user_manager.get_user(uid=uid) + return UserRecord(response) def get_user_by_email(email, app=None): @@ -256,15 +279,12 @@ def get_user_by_email(email, app=None): Raises: ValueError: If the email is None, empty or malformed. - AuthError: If an error occurs while retrieving the user or no user exists by the specified - email address. + UserNotFoundError: If no user exists by the specified email address. + FirebaseError: If an error occurs while retrieving the user. """ user_manager = _get_auth_service(app).user_manager - try: - response = user_manager.get_user(email=email) - return UserRecord(response) - except _user_mgt.ApiCallError as error: - raise AuthError(error.code, str(error), error.detail) + response = user_manager.get_user(email=email) + return UserRecord(response) def get_user_by_phone_number(phone_number, app=None): @@ -279,15 +299,12 @@ def get_user_by_phone_number(phone_number, app=None): Raises: ValueError: If the phone number is None, empty or malformed. - AuthError: If an error occurs while retrieving the user or no user exists by the specified - phone number. + UserNotFoundError: If no user exists by the specified phone number. + FirebaseError: If an error occurs while retrieving the user. """ user_manager = _get_auth_service(app).user_manager - try: - response = user_manager.get_user(phone_number=phone_number) - return UserRecord(response) - except _user_mgt.ApiCallError as error: - raise AuthError(error.code, str(error), error.detail) + response = user_manager.get_user(phone_number=phone_number) + return UserRecord(response) def list_users(page_token=None, max_results=_user_mgt.MAX_LIST_USERS_RESULTS, app=None): @@ -310,14 +327,11 @@ def list_users(page_token=None, max_results=_user_mgt.MAX_LIST_USERS_RESULTS, ap Raises: ValueError: If max_results or page_token are invalid. - AuthError: If an error occurs while retrieving the user accounts. + FirebaseError: If an error occurs while retrieving the user accounts. """ user_manager = _get_auth_service(app).user_manager def download(page_token, max_results): - try: - return user_manager.list_users(page_token, max_results) - except _user_mgt.ApiCallError as error: - raise AuthError(error.code, str(error), error.detail) + return user_manager.list_users(page_token, max_results) return ListUsersPage(download, page_token, max_results) @@ -341,15 +355,12 @@ def create_user(**kwargs): Raises: ValueError: If the specified user properties are invalid. - AuthError: If an error occurs while creating the user account. + FirebaseError: If an error occurs while creating the user account. """ app = kwargs.pop('app', None) user_manager = _get_auth_service(app).user_manager - try: - uid = user_manager.create_user(**kwargs) - return UserRecord(user_manager.get_user(uid=uid)) - except _user_mgt.ApiCallError as error: - raise AuthError(error.code, str(error), error.detail) + uid = user_manager.create_user(**kwargs) + return UserRecord(user_manager.get_user(uid=uid)) def update_user(uid, **kwargs): @@ -381,15 +392,12 @@ def update_user(uid, **kwargs): Raises: ValueError: If the specified user ID or properties are invalid. - AuthError: If an error occurs while updating the user account. + FirebaseError: If an error occurs while updating the user account. """ app = kwargs.pop('app', None) user_manager = _get_auth_service(app).user_manager - try: - user_manager.update_user(uid, **kwargs) - return UserRecord(user_manager.get_user(uid=uid)) - except _user_mgt.ApiCallError as error: - raise AuthError(error.code, str(error), error.detail) + user_manager.update_user(uid, **kwargs) + return UserRecord(user_manager.get_user(uid=uid)) def set_custom_user_claims(uid, custom_claims, app=None): @@ -410,13 +418,10 @@ def set_custom_user_claims(uid, custom_claims, app=None): Raises: ValueError: If the specified user ID or the custom claims are invalid. - AuthError: If an error occurs while updating the user account. + FirebaseError: If an error occurs while updating the user account. """ user_manager = _get_auth_service(app).user_manager - try: - user_manager.update_user(uid, custom_claims=custom_claims) - except _user_mgt.ApiCallError as error: - raise AuthError(error.code, str(error), error.detail) + user_manager.update_user(uid, custom_claims=custom_claims) def delete_user(uid, app=None): @@ -428,13 +433,10 @@ def delete_user(uid, app=None): Raises: ValueError: If the user ID is None, empty or malformed. - AuthError: If an error occurs while deleting the user account. + FirebaseError: If an error occurs while deleting the user account. """ user_manager = _get_auth_service(app).user_manager - try: - user_manager.delete_user(uid) - except _user_mgt.ApiCallError as error: - raise AuthError(error.code, str(error), error.detail) + user_manager.delete_user(uid) def import_users(users, hash_alg=None, app=None): @@ -457,14 +459,11 @@ def import_users(users, hash_alg=None, app=None): Raises: ValueError: If the provided arguments are invalid. - AuthError: If an error occurs while importing users. + FirebaseError: If an error occurs while importing users. """ user_manager = _get_auth_service(app).user_manager - try: - result = user_manager.import_users(users, hash_alg) - return UserImportResult(result, len(users)) - except _user_mgt.ApiCallError as error: - raise AuthError(error.code, str(error), error.detail) + result = user_manager.import_users(users, hash_alg) + return UserImportResult(result, len(users)) def generate_password_reset_link(email, action_code_settings=None, app=None): @@ -482,14 +481,11 @@ def generate_password_reset_link(email, action_code_settings=None, app=None): Raises: ValueError: If the provided arguments are invalid - AuthError: If an error occurs while generating the link + FirebaseError: If an error occurs while generating the link """ user_manager = _get_auth_service(app).user_manager - try: - return user_manager.generate_email_action_link('PASSWORD_RESET', email, - action_code_settings=action_code_settings) - except _user_mgt.ApiCallError as error: - raise AuthError(error.code, str(error), error.detail) + return user_manager.generate_email_action_link( + 'PASSWORD_RESET', email, action_code_settings=action_code_settings) def generate_email_verification_link(email, action_code_settings=None, app=None): @@ -507,14 +503,11 @@ def generate_email_verification_link(email, action_code_settings=None, app=None) Raises: ValueError: If the provided arguments are invalid - AuthError: If an error occurs while generating the link + FirebaseError: If an error occurs while generating the link """ user_manager = _get_auth_service(app).user_manager - try: - return user_manager.generate_email_action_link('VERIFY_EMAIL', email, - action_code_settings=action_code_settings) - except _user_mgt.ApiCallError as error: - raise AuthError(error.code, str(error), error.detail) + return user_manager.generate_email_action_link( + 'VERIFY_EMAIL', email, action_code_settings=action_code_settings) def generate_sign_in_with_email_link(email, action_code_settings, app=None): @@ -532,29 +525,17 @@ def generate_sign_in_with_email_link(email, action_code_settings, app=None): Raises: ValueError: If the provided arguments are invalid - AuthError: If an error occurs while generating the link + FirebaseError: If an error occurs while generating the link """ user_manager = _get_auth_service(app).user_manager - try: - return user_manager.generate_email_action_link('EMAIL_SIGNIN', email, - action_code_settings=action_code_settings) - except _user_mgt.ApiCallError as error: - raise AuthError(error.code, str(error), error.detail) + return user_manager.generate_email_action_link( + 'EMAIL_SIGNIN', email, action_code_settings=action_code_settings) -def _check_jwt_revoked(verified_claims, error_code, label, app): +def _check_jwt_revoked(verified_claims, exc_type, label, app): user = get_user(verified_claims.get('uid'), app=app) if verified_claims.get('iat') * 1000 < user.tokens_valid_after_timestamp: - raise AuthError(error_code, 'The Firebase {0} has been revoked.'.format(label)) - - -class AuthError(Exception): - """Represents an Exception encountered while invoking the Firebase auth API.""" - - def __init__(self, code, message, error=None): - Exception.__init__(self, message) - self.code = code - self.detail = error + raise exc_type('The Firebase {0} has been revoked.'.format(label)) class _AuthService(object): diff --git a/firebase_admin/db.py b/firebase_admin/db.py index 53efd9b15..ef7c96721 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -32,6 +32,7 @@ from six.moves import urllib import firebase_admin +from firebase_admin import exceptions from firebase_admin import _http_client from firebase_admin import _sseclient from firebase_admin import _utils @@ -209,7 +210,7 @@ def get(self, etag=False, shallow=False): Raises: ValueError: If both ``etag`` and ``shallow`` are set to True. - ApiCallError: If an error occurs while communicating with the remote database server. + FirebaseError: If an error occurs while communicating with the remote database server. """ if etag: if shallow: @@ -236,7 +237,7 @@ def get_if_changed(self, etag): Raises: ValueError: If the ETag is not a string. - ApiCallError: If an error occurs while communicating with the remote database server. + FirebaseError: If an error occurs while communicating with the remote database server. """ if not isinstance(etag, six.string_types): raise ValueError('ETag must be a string.') @@ -258,7 +259,7 @@ def set(self, value): Raises: ValueError: If the provided value is None. TypeError: If the value is not JSON-serializable. - ApiCallError: If an error occurs while communicating with the remote database server. + FirebaseError: If an error occurs while communicating with the remote database server. """ if value is None: raise ValueError('Value must not be None.') @@ -281,7 +282,7 @@ def set_if_unchanged(self, expected_etag, value): Raises: ValueError: If the value is None, or if expected_etag is not a string. - ApiCallError: If an error occurs while communicating with the remote database server. + FirebaseError: If an error occurs while communicating with the remote database server. """ # pylint: disable=missing-raises-doc if not isinstance(expected_etag, six.string_types): @@ -293,11 +294,11 @@ def set_if_unchanged(self, expected_etag, value): headers = self._client.headers( 'put', self._add_suffix(), json=value, headers={'if-match': expected_etag}) return True, value, headers.get('ETag') - except ApiCallError as error: - detail = error.detail - if detail.response is not None and 'ETag' in detail.response.headers: - etag = detail.response.headers['ETag'] - snapshot = detail.response.json() + except exceptions.FailedPreconditionError as error: + http_response = error.http_response + if http_response is not None and 'ETag' in http_response.headers: + etag = http_response.headers['ETag'] + snapshot = http_response.json() return False, snapshot, etag else: raise error @@ -317,7 +318,7 @@ def push(self, value=''): Raises: ValueError: If the value is None. TypeError: If the value is not JSON-serializable. - ApiCallError: If an error occurs while communicating with the remote database server. + FirebaseError: If an error occurs while communicating with the remote database server. """ if value is None: raise ValueError('Value must not be None.') @@ -333,7 +334,7 @@ def update(self, value): Raises: ValueError: If value is empty or not a dictionary. - ApiCallError: If an error occurs while communicating with the remote database server. + FirebaseError: If an error occurs while communicating with the remote database server. """ if not value or not isinstance(value, dict): raise ValueError('Value argument must be a non-empty dictionary.') @@ -345,7 +346,7 @@ def delete(self): """Deletes this node from the database. Raises: - ApiCallError: If an error occurs while communicating with the remote database server. + FirebaseError: If an error occurs while communicating with the remote database server. """ self._client.request('delete', self._add_suffix()) @@ -371,7 +372,7 @@ def listen(self, callback): ListenerRegistration: An object that can be used to stop the event listener. Raises: - ApiCallError: If an error occurs while starting the initial HTTP connection. + FirebaseError: If an error occurs while starting the initial HTTP connection. """ session = _sseclient.KeepAuthSession(self._client.credential) return self._listen_with_session(callback, session) @@ -387,9 +388,9 @@ def transaction(self, transaction_update): value of this reference into a new value. If another client writes to this location before the new value is successfully saved, the update function is called again with the new current value, and the write will be retried. In case of repeated failures, this method - will retry the transaction up to 25 times before giving up and raising a TransactionError. - The update function may also force an early abort by raising an exception instead of - returning a value. + will retry the transaction up to 25 times before giving up and raising a + TransactionAbortedError. The update function may also force an early abort by raising an + exception instead of returning a value. Args: transaction_update: A function which will be passed the current data stored at this @@ -402,7 +403,7 @@ def transaction(self, transaction_update): object: New value of the current database Reference (only if the transaction commits). Raises: - TransactionError: If the transaction aborts after exhausting all retry attempts. + TransactionAbortedError: If the transaction aborts after exhausting all retry attempts. ValueError: If transaction_update is not a function. """ if not callable(transaction_update): @@ -416,7 +417,8 @@ def transaction(self, transaction_update): if success: return new_data tries += 1 - raise TransactionError('Transaction aborted after failed retries.') + + raise TransactionAbortedError('Transaction aborted after failed retries.') def order_by_child(self, path): """Returns a Query that orders data by child values. @@ -468,7 +470,7 @@ def _listen_with_session(self, callback, session): sse = _sseclient.SSEClient(url, session) return ListenerRegistration(callback, sse) except requests.exceptions.RequestException as error: - raise ApiCallError(_Client.extract_error_message(error), error) + raise _Client.handle_rtdb_error(error) class Query(object): @@ -614,7 +616,7 @@ def get(self): object: Decoded JSON result of the Query. Raises: - ApiCallError: If an error occurs while communicating with the remote database server. + FirebaseError: If an error occurs while communicating with the remote database server. """ result = self._client.body('get', self._pathurl, params=self._querystr) if isinstance(result, (dict, list)) and self._order_by != '$priority': @@ -622,20 +624,11 @@ def get(self): return result -class ApiCallError(Exception): - """Represents an Exception encountered while invoking the Firebase database server API.""" - - def __init__(self, message, error): - Exception.__init__(self, message) - self.detail = error - - -class TransactionError(Exception): - """Represents an Exception encountered while performing a transaction.""" +class TransactionAbortedError(exceptions.AbortedError): + """A transaction was aborted aftr exceeding the maximum number of retries.""" def __init__(self, message): - Exception.__init__(self, message) - + exceptions.AbortedError.__init__(self, message) class _Sorter(object): @@ -934,7 +927,7 @@ def request(self, method, url, **kwargs): Response: An HTTP response object. Raises: - ApiCallError: If an error occurs while making the HTTP call. + FirebaseError: If an error occurs while making the HTTP call. """ query = '&'.join('{0}={1}'.format(key, self.params[key]) for key in self.params) extra_params = kwargs.get('params') @@ -950,33 +943,39 @@ def request(self, method, url, **kwargs): try: return super(_Client, self).request(method, url, **kwargs) except requests.exceptions.RequestException as error: - raise ApiCallError(_Client.extract_error_message(error), error) + raise _Client.handle_rtdb_error(error) + + @classmethod + def handle_rtdb_error(cls, error): + """Converts an error encountered while calling RTDB into a FirebaseError.""" + if error.response is None: + return _utils.handle_requests_error(error) + + message = cls._extract_error_message(error.response) + return _utils.handle_requests_error(error, message=message) @classmethod - def extract_error_message(cls, error): - """Extracts an error message from an exception. + def _extract_error_message(cls, response): + """Extracts an error message from an error response. - If the server has not sent any response, simply converts the exception into a string. If the server has sent a JSON response with an 'error' field, which is the typical behavior of the Realtime Database REST API, parses the response to retrieve the error message. If the server has sent a non-JSON response, returns the full response as the error message. - - Args: - error: An exception raised by the requests library. - - Returns: - str: A string error message extracted from the exception. """ - if error.response is None: - return str(error) + message = None try: - data = error.response.json() + # RTDB error format: {"error": "text message"} + data = response.json() if isinstance(data, dict): - return '{0}\nReason: {1}'.format(error, data.get('error', 'unknown')) + message = data.get('error') except ValueError: pass - return '{0}\nReason: {1}'.format(error, error.response.content.decode()) + + if not message: + message = 'Unexpected response from database: {0}'.format(response.content.decode()) + + return message class _EmulatorAdminCredentials(google.auth.credentials.Credentials): diff --git a/firebase_admin/exceptions.py b/firebase_admin/exceptions.py new file mode 100644 index 000000000..06504225f --- /dev/null +++ b/firebase_admin/exceptions.py @@ -0,0 +1,237 @@ +# Copyright 2019 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Firebase Exceptions module. + +This module defines the base types for exceptions and the platform-wide error codes as outlined in +https://cloud.google.com/apis/design/errors. + +:class:`FirebaseError` is the parent class of all exceptions raised by the Admin SDK. It contains +the ``code``, ``http_response`` and ``cause`` properties common to all Firebase exception types. +Each exception also carries a message that outlines what went wrong. This can be logged for +audit or debugging purposes. + +When calling an Admin SDK API, developers can catch the parent ``FirebaseError`` and +inspect its ``code`` to implement fine-grained error handling. Alternatively, developers can +catch one or more subtypes of ``FirebaseError``. Under normal conditions, any given API can raise +only a small subset of the available exception subtypes. However, the SDK also exposes rare error +conditions like connection timeouts and other I/O errors as instances of ``FirebaseError``. +Therefore it is always a good idea to have a handler specified for ``FirebaseError``, after all the +subtype error handlers. +""" + + +#: Error code for ``InvalidArgumentError`` type. +INVALID_ARGUMENT = 'INVALID_ARGUMENT' + +#: Error code for ``FailedPreconditionError`` type. +FAILED_PRECONDITION = 'FAILED_PRECONDITION' + +#: Error code for ``OutOfRangeError`` type. +OUT_OF_RANGE = 'OUT_OF_RANGE' + +#: Error code for ``UnauthenticatedError`` type. +UNAUTHENTICATED = 'UNAUTHENTICATED' + +#: Error code for ``PermissionDeniedError`` type. +PERMISSION_DENIED = 'PERMISSION_DENIED' + +#: Error code for ``NotFoundError`` type. +NOT_FOUND = 'NOT_FOUND' + +#: Error code for ``ConflictError`` type. +CONFLICT = 'CONFLICT' + +#: Error code for ``AbortedError`` type. +ABORTED = 'ABORTED' + +#: Error code for ``AlreadyExistsError`` type. +ALREADY_EXISTS = 'ALREADY_EXISTS' + +#: Error code for ``ResourceExhaustedError`` type. +RESOURCE_EXHAUSTED = 'RESOURCE_EXHAUSTED' + +#: Error code for ``CancelledError`` type. +CANCELLED = 'CANCELLED' + +#: Error code for ``DataLossError`` type. +DATA_LOSS = 'DATA_LOSS' + +#: Error code for ``UnknownError`` type. +UNKNOWN = 'UNKNOWN' + +#: Error code for ``InternalError`` type. +INTERNAL = 'INTERNAL' + +#: Error code for ``UnavailableError`` type. +UNAVAILABLE = 'UNAVAILABLE' + +#: Error code for ``DeadlineExceededError`` type. +DEADLINE_EXCEEDED = 'DEADLINE_EXCEEDED' + + +class FirebaseError(Exception): + """Base class for all errors raised by the Admin SDK. + + Args: + code: A string error code that represents the type of the exception. Possible error + codes are defined in https://cloud.google.com/apis/design/errors#handling_errors. + message: A human-readable error message string. + cause: The exception that caused this error (optional). + http_response: If this error was caused by an HTTP error response, this property is + set to the ``requests.Response`` object that represents the HTTP response (optional). + See https://2.python-requests.org/en/master/api/#requests.Response for details of + this object. + """ + + def __init__(self, code, message, cause=None, http_response=None): + Exception.__init__(self, message) + self._code = code + self._cause = cause + self._http_response = http_response + + @property + def code(self): + return self._code + + @property + def cause(self): + return self._cause + + @property + def http_response(self): + return self._http_response + + +class InvalidArgumentError(FirebaseError): + """Client specified an invalid argument.""" + + def __init__(self, message, cause=None, http_response=None): + FirebaseError.__init__(self, INVALID_ARGUMENT, message, cause, http_response) + + +class FailedPreconditionError(FirebaseError): + """Request can not be executed in the current system state, such as deleting a non-empty + directory.""" + + def __init__(self, message, cause=None, http_response=None): + FirebaseError.__init__(self, FAILED_PRECONDITION, message, cause, http_response) + + +class OutOfRangeError(FirebaseError): + """Client specified an invalid range.""" + + def __init__(self, message, cause=None, http_response=None): + FirebaseError.__init__(self, OUT_OF_RANGE, message, cause, http_response) + + +class UnauthenticatedError(FirebaseError): + """Request not authenticated due to missing, invalid, or expired OAuth token.""" + + def __init__(self, message, cause=None, http_response=None): + FirebaseError.__init__(self, UNAUTHENTICATED, message, cause, http_response) + + +class PermissionDeniedError(FirebaseError): + """Client does not have sufficient permission. + + This can happen because the OAuth token does not have the right scopes, the client doesn't + have permission, or the API has not been enabled for the client project. + """ + + def __init__(self, message, cause=None, http_response=None): + FirebaseError.__init__(self, PERMISSION_DENIED, message, cause, http_response) + + +class NotFoundError(FirebaseError): + """A specified resource is not found, or the request is rejected by undisclosed reasons, such + as whitelisting.""" + + def __init__(self, message, cause=None, http_response=None): + FirebaseError.__init__(self, NOT_FOUND, message, cause, http_response) + + +class ConflictError(FirebaseError): + """Concurrency conflict, such as read-modify-write conflict.""" + + def __init__(self, message, cause=None, http_response=None): + FirebaseError.__init__(self, CONFLICT, message, cause, http_response) + + +class AbortedError(FirebaseError): + """Concurrency conflict, such as read-modify-write conflict.""" + + def __init__(self, message, cause=None, http_response=None): + FirebaseError.__init__(self, ABORTED, message, cause, http_response) + + +class AlreadyExistsError(FirebaseError): + """The resource that a client tried to create already exists.""" + + def __init__(self, message, cause=None, http_response=None): + FirebaseError.__init__(self, ALREADY_EXISTS, message, cause, http_response) + + +class ResourceExhaustedError(FirebaseError): + """Either out of resource quota or reaching rate limiting.""" + + def __init__(self, message, cause=None, http_response=None): + FirebaseError.__init__(self, RESOURCE_EXHAUSTED, message, cause, http_response) + + +class CancelledError(FirebaseError): + """Request cancelled by the client.""" + + def __init__(self, message, cause=None, http_response=None): + FirebaseError.__init__(self, CANCELLED, message, cause, http_response) + + +class DataLossError(FirebaseError): + """Unrecoverable data loss or data corruption.""" + + def __init__(self, message, cause=None, http_response=None): + FirebaseError.__init__(self, DATA_LOSS, message, cause, http_response) + + +class UnknownError(FirebaseError): + """Unknown server error.""" + + def __init__(self, message, cause=None, http_response=None): + FirebaseError.__init__(self, UNKNOWN, message, cause, http_response) + + +class InternalError(FirebaseError): + """Internal server error.""" + + def __init__(self, message, cause=None, http_response=None): + FirebaseError.__init__(self, INTERNAL, message, cause, http_response) + + +class UnavailableError(FirebaseError): + """Service unavailable. Typically the server is down.""" + + def __init__(self, message, cause=None, http_response=None): + FirebaseError.__init__(self, UNAVAILABLE, message, cause, http_response) + + +class DeadlineExceededError(FirebaseError): + """Request deadline exceeded. + + This will happen only if the caller sets a deadline that is shorter than the method's + default deadline (i.e. requested deadline is not enough for the server to process the + request) and the request did not finish within the deadline. + """ + + def __init__(self, message, cause=None, http_response=None): + FirebaseError.__init__(self, DEADLINE_EXCEEDED, message, cause, http_response) diff --git a/firebase_admin/instance_id.py b/firebase_admin/instance_id.py index b290e9e7f..e9134fc28 100644 --- a/firebase_admin/instance_id.py +++ b/firebase_admin/instance_id.py @@ -53,14 +53,6 @@ def delete_instance_id(instance_id, app=None): _get_iid_service(app).delete_instance_id(instance_id) -class ApiCallError(Exception): - """Represents an Exception encountered while invoking the Firebase instance ID service.""" - - def __init__(self, message, error): - Exception.__init__(self, message) - self.detail = error - - class _InstanceIdService(object): """Provides methods for interacting with the remote instance ID service.""" @@ -94,14 +86,15 @@ def delete_instance_id(self, instance_id): try: self._client.request('delete', path) except requests.exceptions.RequestException as error: - raise ApiCallError(self._extract_message(instance_id, error), error) + msg = self._extract_message(instance_id, error) + raise _utils.handle_requests_error(error, msg) def _extract_message(self, instance_id, error): if error.response is None: - return str(error) + return None status = error.response.status_code msg = self.error_codes.get(status) if msg: return 'Instance ID "{0}": {1}'.format(instance_id, msg) else: - return str(error) + return 'Instance ID "{0}": {1}'.format(instance_id, error) diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index c0f023169..cbd3522fa 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -38,7 +38,6 @@ 'APNSConfig', 'APNSFCMOptions', 'APNSPayload', - 'ApiCallError', 'Aps', 'ApsAlert', 'BatchResponse', @@ -48,8 +47,12 @@ 'Message', 'MulticastMessage', 'Notification', + 'QuotaExceededError', + 'SenderIdMismatchError', 'SendResponse', + 'ThirdPartyAuthError', 'TopicManagementResponse', + 'UnregisteredError', 'WebpushConfig', 'WebpushFCMOptions', 'WebpushNotification', @@ -78,10 +81,14 @@ Notification = _messaging_utils.Notification WebpushConfig = _messaging_utils.WebpushConfig WebpushFCMOptions = _messaging_utils.WebpushFCMOptions -WebpushFcmOptions = _messaging_utils.WebpushFCMOptions WebpushNotification = _messaging_utils.WebpushNotification WebpushNotificationAction = _messaging_utils.WebpushNotificationAction +QuotaExceededError = _messaging_utils.QuotaExceededError +SenderIdMismatchError = _messaging_utils.SenderIdMismatchError +ThirdPartyAuthError = _messaging_utils.ThirdPartyAuthError +UnregisteredError = _messaging_utils.UnregisteredError + def _get_messaging_service(app): return _utils.get_app_service(app, _MESSAGING_ATTRIBUTE, _MessagingService) @@ -101,7 +108,7 @@ def send(message, dry_run=False, app=None): string: A message ID string that uniquely identifies the sent the message. Raises: - ApiCallError: If an error occurs while sending the message to the FCM service. + FirebaseError: If an error occurs while sending the message to the FCM service. ValueError: If the input arguments are invalid. """ return _get_messaging_service(app).send(message, dry_run) @@ -121,7 +128,7 @@ def send_all(messages, dry_run=False, app=None): BatchResponse: A ``messaging.BatchResponse`` instance. Raises: - ApiCallError: If an error occurs while sending the message to the FCM service. + FirebaseError: If an error occurs while sending the message to the FCM service. ValueError: If the input arguments are invalid. """ return _get_messaging_service(app).send_all(messages, dry_run) @@ -141,7 +148,7 @@ def send_multicast(multicast_message, dry_run=False, app=None): BatchResponse: A ``messaging.BatchResponse`` instance. Raises: - ApiCallError: If an error occurs while sending the message to the FCM service. + FirebaseError: If an error occurs while sending the message to the FCM service. ValueError: If the input arguments are invalid. """ if not isinstance(multicast_message, MulticastMessage): @@ -170,7 +177,7 @@ def subscribe_to_topic(tokens, topic, app=None): TopicManagementResponse: A ``TopicManagementResponse`` instance. Raises: - ApiCallError: If an error occurs while communicating with instance ID service. + FirebaseError: If an error occurs while communicating with instance ID service. ValueError: If the input arguments are invalid. """ return _get_messaging_service(app).make_topic_management_request( @@ -189,7 +196,7 @@ def unsubscribe_from_topic(tokens, topic, app=None): TopicManagementResponse: A ``TopicManagementResponse`` instance. Raises: - ApiCallError: If an error occurs while communicating with instance ID service. + FirebaseError: If an error occurs while communicating with instance ID service. ValueError: If the input arguments are invalid. """ return _get_messaging_service(app).make_topic_management_request( @@ -246,21 +253,6 @@ def errors(self): return self._errors -class ApiCallError(Exception): - """Represents an Exception encountered while invoking the FCM API. - - Attributes: - code: A string error code. - message: A error message string. - detail: Original low-level exception. - """ - - def __init__(self, code, message, detail=None): - Exception.__init__(self, message) - self.code = code - self.detail = detail - - class BatchResponse(object): """The response received from a batch request to the FCM API.""" @@ -303,7 +295,7 @@ def success(self): @property def exception(self): - """An ApiCallError if an error occurs while sending the message to the FCM service.""" + """A ``FirebaseError`` if an error occurs while sending the message to the FCM service.""" return self._exception @@ -316,30 +308,12 @@ class _MessagingService(object): IID_HEADERS = {'access_token_auth': 'true'} JSON_ENCODER = _messaging_utils.MessageEncoder() - INTERNAL_ERROR = 'internal-error' - UNKNOWN_ERROR = 'unknown-error' - FCM_ERROR_CODES = { - # FCM v1 canonical error codes - 'NOT_FOUND': 'registration-token-not-registered', - 'PERMISSION_DENIED': 'mismatched-credential', - 'RESOURCE_EXHAUSTED': 'message-rate-exceeded', - 'UNAUTHENTICATED': 'invalid-apns-credentials', - - # FCM v1 new error codes - 'APNS_AUTH_ERROR': 'invalid-apns-credentials', - 'INTERNAL': INTERNAL_ERROR, - 'INVALID_ARGUMENT': 'invalid-argument', - 'QUOTA_EXCEEDED': 'message-rate-exceeded', - 'SENDER_ID_MISMATCH': 'mismatched-credential', - 'UNAVAILABLE': 'server-unavailable', - 'UNREGISTERED': 'registration-token-not-registered', - } - IID_ERROR_CODES = { - 400: 'invalid-argument', - 401: 'authentication-error', - 403: 'authentication-error', - 500: INTERNAL_ERROR, - 503: 'server-unavailable', + FCM_ERROR_TYPES = { + 'APNS_AUTH_ERROR': ThirdPartyAuthError, + 'QUOTA_EXCEEDED': QuotaExceededError, + 'SENDER_ID_MISMATCH': SenderIdMismatchError, + 'THIRD_PARTY_AUTH_ERROR': ThirdPartyAuthError, + 'UNREGISTERED': UnregisteredError, } def __init__(self, app): @@ -375,11 +349,7 @@ def send(self, message, dry_run=False): timeout=self._timeout ) except requests.exceptions.RequestException as error: - if error.response is not None: - self._handle_fcm_error(error) - else: - msg = 'Failed to call messaging API: {0}'.format(error) - raise ApiCallError(self.INTERNAL_ERROR, msg, error) + raise self._handle_fcm_error(error) else: return resp['name'] @@ -395,7 +365,7 @@ def send_all(self, messages, dry_run=False): def batch_callback(_, response, error): exception = None if error: - exception = self._parse_batch_error(error) + exception = self._handle_batch_error(error) send_response = SendResponse(response, exception) responses.append(send_response) @@ -415,7 +385,7 @@ def batch_callback(_, response, error): try: batch.execute() except googleapiclient.http.HttpError as error: - raise self._parse_batch_error(error) + raise self._handle_batch_error(error) else: return BatchResponse(responses) @@ -447,10 +417,7 @@ def make_topic_management_request(self, tokens, topic, operation): timeout=self._timeout ) except requests.exceptions.RequestException as error: - if error.response is not None: - self._handle_iid_error(error) - else: - raise ApiCallError(self.INTERNAL_ERROR, 'Failed to call instance ID API.', error) + raise self._handle_iid_error(error) else: return TopicManagementResponse(resp) @@ -467,20 +434,14 @@ def _postproc(self, _, body): def _handle_fcm_error(self, error): """Handles errors received from the FCM API.""" - data = {} - try: - parsed_body = error.response.json() - if isinstance(parsed_body, dict): - data = parsed_body - except ValueError: - pass - - code, msg = _MessagingService._parse_fcm_error( - data, error.response.content, error.response.status_code) - raise ApiCallError(code, msg, error) + return _utils.handle_platform_error_from_requests( + error, _MessagingService._build_fcm_error_requests) def _handle_iid_error(self, error): """Handles errors received from the Instance ID API.""" + if error.response is None: + raise _utils.handle_requests_error(error) + data = {} try: parsed_body = error.response.json() @@ -489,46 +450,40 @@ def _handle_iid_error(self, error): except ValueError: pass - code = _MessagingService.IID_ERROR_CODES.get( - error.response.status_code, _MessagingService.UNKNOWN_ERROR) + # IID error response format: {"error": "some error message"} msg = data.get('error') if not msg: msg = 'Unexpected HTTP response with status: {0}; body: {1}'.format( error.response.status_code, error.response.content.decode()) - raise ApiCallError(code, msg, error) - def _parse_batch_error(self, error): - """Parses a googleapiclient.http.HttpError content in to an ApiCallError.""" - if error.content is None: - msg = 'Failed to call messaging API: {0}'.format(error) - return ApiCallError(self.INTERNAL_ERROR, msg, error) + return _utils.handle_requests_error(error, msg) - data = {} - try: - parsed_body = json.loads(error.content.decode()) - if isinstance(parsed_body, dict): - data = parsed_body - except ValueError: - pass + def _handle_batch_error(self, error): + """Handles errors received from the googleapiclient while making batch requests.""" + return _utils.handle_platform_error_from_googleapiclient( + error, _MessagingService._build_fcm_error_googleapiclient) + + @classmethod + def _build_fcm_error_requests(cls, error, message, error_dict): + """Parses an error response from the FCM API and creates a FCM-specific exception if + appropriate.""" + exc_type = cls._build_fcm_error(error_dict) + return exc_type(message, cause=error, http_response=error.response) if exc_type else None - code, msg = _MessagingService._parse_fcm_error(data, error.content, error.resp.status) - return ApiCallError(code, msg, error) + @classmethod + def _build_fcm_error_googleapiclient(cls, error, message, error_dict, http_response): + """Parses an error response from the FCM API and creates a FCM-specific exception if + appropriate.""" + exc_type = cls._build_fcm_error(error_dict) + return exc_type(message, cause=error, http_response=http_response) if exc_type else None @classmethod - def _parse_fcm_error(cls, data, content, status_code): - """Parses an error response from the FCM API to a ApiCallError.""" - error_dict = data.get('error', {}) - server_code = None + def _build_fcm_error(cls, error_dict): + if not error_dict: + return None + fcm_code = None for detail in error_dict.get('details', []): if detail.get('@type') == 'type.googleapis.com/google.firebase.fcm.v1.FcmError': - server_code = detail.get('errorCode') + fcm_code = detail.get('errorCode') break - if not server_code: - server_code = error_dict.get('status') - code = _MessagingService.FCM_ERROR_CODES.get(server_code, _MessagingService.UNKNOWN_ERROR) - - msg = error_dict.get('message') - if not msg: - msg = 'Unexpected HTTP response with status: {0}; body: {1}'.format( - status_code, content.decode()) - return code, msg + return _MessagingService.FCM_ERROR_TYPES.get(fcm_code) diff --git a/firebase_admin/project_management.py b/firebase_admin/project_management.py index cc57471c5..68e10797c 100644 --- a/firebase_admin/project_management.py +++ b/firebase_admin/project_management.py @@ -25,6 +25,7 @@ import six import firebase_admin +from firebase_admin import exceptions from firebase_admin import _http_client from firebase_admin import _utils @@ -57,9 +58,9 @@ def ios_app(app_id, app=None): app: An App instance (optional). Returns: - IosApp: An ``IosApp`` instance. + IOSApp: An ``IOSApp`` instance. """ - return IosApp(app_id=app_id, service=_get_project_management_service(app)) + return IOSApp(app_id=app_id, service=_get_project_management_service(app)) def list_android_apps(app=None): @@ -82,7 +83,7 @@ def list_ios_apps(app=None): app: An App instance (optional). Returns: - list: a list of ``IosApp`` instances referring to each iOS app in the Firebase project. + list: a list of ``IOSApp`` instances referring to each iOS app in the Firebase project. """ return _get_project_management_service(app).list_ios_apps() @@ -110,7 +111,7 @@ def create_ios_app(bundle_id, display_name=None, app=None): app: An App instance (optional). Returns: - IosApp: An ``IosApp`` instance that is a reference to the newly created app. + IOSApp: An ``IOSApp`` instance that is a reference to the newly created app. """ return _get_project_management_service(app).create_ios_app(bundle_id, display_name) @@ -139,21 +140,6 @@ def _check_not_none(obj, field_name): return obj -class ApiCallError(Exception): - """An error encountered while interacting with the Firebase Project Management Service.""" - - def __init__(self, message, error): - Exception.__init__(self, message) - self.detail = error - - -class _PollingError(Exception): - """An error encountered during the polling of an app's creation status.""" - - def __init__(self, message): - Exception.__init__(self, message) - - class AndroidApp(object): """A reference to an Android app within a Firebase project. @@ -185,7 +171,7 @@ def get_metadata(self): AndroidAppMetadata: An ``AndroidAppMetadata`` instance. Raises: - ApiCallError: If an error occurs while communicating with the Firebase Project + FirebaseError: If an error occurs while communicating with the Firebase Project Management Service. """ return self._service.get_android_app_metadata(self._app_id) @@ -200,7 +186,7 @@ def set_display_name(self, new_display_name): NoneType: None. Raises: - ApiCallError: If an error occurs while communicating with the Firebase Project + FirebaseError: If an error occurs while communicating with the Firebase Project Management Service. """ return self._service.set_android_app_display_name(self._app_id, new_display_name) @@ -213,10 +199,10 @@ def get_sha_certificates(self): """Retrieves the entire list of SHA certificates associated with this Android app. Returns: - list: A list of ``ShaCertificate`` instances. + list: A list of ``SHACertificate`` instances. Raises: - ApiCallError: If an error occurs while communicating with the Firebase Project + FirebaseError: If an error occurs while communicating with the Firebase Project Management Service. """ return self._service.get_sha_certificates(self._app_id) @@ -231,7 +217,7 @@ def add_sha_certificate(self, certificate_to_add): NoneType: None. Raises: - ApiCallError: If an error occurs while communicating with the Firebase Project + FirebaseError: If an error occurs while communicating with the Firebase Project Management Service. (For example, if the certificate_to_add already exists.) """ return self._service.add_sha_certificate(self._app_id, certificate_to_add) @@ -246,13 +232,13 @@ def delete_sha_certificate(self, certificate_to_delete): NoneType: None. Raises: - ApiCallError: If an error occurs while communicating with the Firebase Project + FirebaseError: If an error occurs while communicating with the Firebase Project Management Service. (For example, if the certificate_to_delete is not found.) """ return self._service.delete_sha_certificate(certificate_to_delete) -class IosApp(object): +class IOSApp(object): """A reference to an iOS app within a Firebase project. Note: Unless otherwise specified, all methods defined in this class make an RPC. @@ -280,10 +266,10 @@ def get_metadata(self): """Retrieves detailed information about this iOS app. Returns: - IosAppMetadata: An ``IosAppMetadata`` instance. + IOSAppMetadata: An ``IOSAppMetadata`` instance. Raises: - ApiCallError: If an error occurs while communicating with the Firebase Project + FirebaseError: If an error occurs while communicating with the Firebase Project Management Service. """ return self._service.get_ios_app_metadata(self._app_id) @@ -298,7 +284,7 @@ def set_display_name(self, new_display_name): NoneType: None. Raises: - ApiCallError: If an error occurs while communicating with the Firebase Project + FirebaseError: If an error occurs while communicating with the Firebase Project Management Service. """ return self._service.set_ios_app_display_name(self._app_id, new_display_name) @@ -373,12 +359,12 @@ def __hash__(self): (self._name, self.app_id, self.display_name, self.project_id, self.package_name)) -class IosAppMetadata(_AppMetadata): +class IOSAppMetadata(_AppMetadata): """iOS-specific information about an iOS Firebase app.""" def __init__(self, bundle_id, name, app_id, display_name, project_id): """Clients should not instantiate this class directly.""" - super(IosAppMetadata, self).__init__(name, app_id, display_name, project_id) + super(IOSAppMetadata, self).__init__(name, app_id, display_name, project_id) self._bundle_id = _check_is_nonempty_string(bundle_id, 'bundle_id') @property @@ -387,7 +373,7 @@ def bundle_id(self): return self._bundle_id def __eq__(self, other): - return super(IosAppMetadata, self).__eq__(other) and self.bundle_id == other.bundle_id + return super(IOSAppMetadata, self).__eq__(other) and self.bundle_id == other.bundle_id def __ne__(self, other): return not self.__eq__(other) @@ -396,7 +382,7 @@ def __hash__(self): return hash((self._name, self.app_id, self.display_name, self.project_id, self.bundle_id)) -class ShaCertificate(object): +class SHACertificate(object): """Represents a SHA-1 or SHA-256 certificate associated with an Android app.""" SHA_1 = 'SHA_1' @@ -406,7 +392,7 @@ class ShaCertificate(object): _SHA_256_RE = re.compile('^[0-9A-Fa-f]{64}$') def __init__(self, sha_hash, name=None): - """Creates a new ShaCertificate instance. + """Creates a new SHACertificate instance. Args: sha_hash: A string; the certificate hash for the Android app. @@ -421,10 +407,10 @@ def __init__(self, sha_hash, name=None): _check_is_nonempty_string_or_none(name, 'name') self._name = name self._sha_hash = sha_hash.lower() - if ShaCertificate._SHA_1_RE.match(sha_hash): - self._cert_type = ShaCertificate.SHA_1 - elif ShaCertificate._SHA_256_RE.match(sha_hash): - self._cert_type = ShaCertificate.SHA_256 + if SHACertificate._SHA_1_RE.match(sha_hash): + self._cert_type = SHACertificate.SHA_1 + elif SHACertificate._SHA_256_RE.match(sha_hash): + self._cert_type = SHACertificate.SHA_256 else: raise ValueError( 'The supplied certificate hash is neither a valid SHA-1 nor SHA_256 hash.') @@ -458,7 +444,7 @@ def cert_type(self): return self._cert_type def __eq__(self, other): - if not isinstance(other, ShaCertificate): + if not isinstance(other, SHACertificate): return False return (self.name == other.name and self.sha_hash == other.sha_hash and self.cert_type == other.cert_type) @@ -478,22 +464,11 @@ class _ProjectManagementService(object): MAXIMUM_POLLING_ATTEMPTS = 8 POLL_BASE_WAIT_TIME_SECONDS = 0.5 POLL_EXPONENTIAL_BACKOFF_FACTOR = 1.5 - ERROR_CODES = { - 401: 'Request not authorized.', - 403: 'Client does not have sufficient privileges.', - 404: 'Failed to find the resource.', - 409: 'The resource already exists.', - 429: 'Request throttled out by the backend server.', - 500: 'Internal server error.', - 503: 'Backend servers are over capacity. Try again later.' - } ANDROID_APPS_RESOURCE_NAME = 'androidApps' ANDROID_APP_IDENTIFIER_NAME = 'packageName' - ANDROID_APP_IDENTIFIER_LABEL = 'Package name' IOS_APPS_RESOURCE_NAME = 'iosApps' IOS_APP_IDENTIFIER_NAME = 'bundleId' - IOS_APP_IDENTIFIER_LABEL = 'Bundle ID' def __init__(self, app): project_id = app.project_id @@ -521,14 +496,14 @@ def get_ios_app_metadata(self, app_id): return self._get_app_metadata( platform_resource_name=_ProjectManagementService.IOS_APPS_RESOURCE_NAME, identifier_name=_ProjectManagementService.IOS_APP_IDENTIFIER_NAME, - metadata_class=IosAppMetadata, + metadata_class=IOSAppMetadata, app_id=app_id) def _get_app_metadata(self, platform_resource_name, identifier_name, metadata_class, app_id): """Retrieves detailed information about an Android or iOS app.""" _check_is_nonempty_string(app_id, 'app_id') path = '/v1beta1/projects/-/{0}/{1}'.format(platform_resource_name, app_id) - response = self._make_request('get', path, app_id, 'App ID') + response = self._make_request('get', path) return metadata_class( response[identifier_name], name=response['name'], @@ -553,7 +528,7 @@ def _set_display_name(self, app_id, new_display_name, platform_resource_name): path = '/v1beta1/projects/-/{0}/{1}?updateMask=displayName'.format( platform_resource_name, app_id) request_body = {'displayName': new_display_name} - self._make_request('patch', path, app_id, 'App ID', json=request_body) + self._make_request('patch', path, json=request_body) def list_android_apps(self): return self._list_apps( @@ -563,7 +538,7 @@ def list_android_apps(self): def list_ios_apps(self): return self._list_apps( platform_resource_name=_ProjectManagementService.IOS_APPS_RESOURCE_NAME, - app_class=IosApp) + app_class=IOSApp) def _list_apps(self, platform_resource_name, app_class): """Lists all the Android or iOS apps within the Firebase project.""" @@ -571,7 +546,7 @@ def _list_apps(self, platform_resource_name, app_class): self._project_id, platform_resource_name, _ProjectManagementService.MAXIMUM_LIST_APPS_PAGE_SIZE) - response = self._make_request('get', path, self._project_id, 'Project ID') + response = self._make_request('get', path) apps_list = [] while True: apps = response.get('apps') @@ -587,14 +562,13 @@ def _list_apps(self, platform_resource_name, app_class): platform_resource_name, next_page_token, _ProjectManagementService.MAXIMUM_LIST_APPS_PAGE_SIZE) - response = self._make_request('get', path, self._project_id, 'Project ID') + response = self._make_request('get', path) return apps_list def create_android_app(self, package_name, display_name=None): return self._create_app( platform_resource_name=_ProjectManagementService.ANDROID_APPS_RESOURCE_NAME, identifier_name=_ProjectManagementService.ANDROID_APP_IDENTIFIER_NAME, - identifier_label=_ProjectManagementService.ANDROID_APP_IDENTIFIER_LABEL, identifier=package_name, display_name=display_name, app_class=AndroidApp) @@ -603,16 +577,14 @@ def create_ios_app(self, bundle_id, display_name=None): return self._create_app( platform_resource_name=_ProjectManagementService.IOS_APPS_RESOURCE_NAME, identifier_name=_ProjectManagementService.IOS_APP_IDENTIFIER_NAME, - identifier_label=_ProjectManagementService.IOS_APP_IDENTIFIER_LABEL, identifier=bundle_id, display_name=display_name, - app_class=IosApp) + app_class=IOSApp) def _create_app( self, platform_resource_name, identifier_name, - identifier_label, identifier, display_name, app_class): @@ -622,15 +594,10 @@ def _create_app( request_body = {identifier_name: identifier} if display_name: request_body['displayName'] = display_name - response = self._make_request('post', path, identifier, identifier_label, json=request_body) + response = self._make_request('post', path, json=request_body) operation_name = response['name'] - try: - poll_response = self._poll_app_creation(operation_name) - return app_class(app_id=poll_response['appId'], service=self) - except _PollingError as error: - raise ApiCallError( - _ProjectManagementService._extract_message(operation_name, 'Operation name', error), - error) + poll_response = self._poll_app_creation(operation_name) + return app_class(app_id=poll_response['appId'], service=self) def _poll_app_creation(self, operation_name): """Polls the Long-Running Operation repeatedly until it is done with exponential backoff.""" @@ -640,16 +607,17 @@ def _poll_app_creation(self, operation_name): wait_time_seconds = delay_factor * _ProjectManagementService.POLL_BASE_WAIT_TIME_SECONDS time.sleep(wait_time_seconds) path = '/v1/{0}'.format(operation_name) - poll_response = self._make_request('get', path, operation_name, 'Operation name') + poll_response, http_response = self._body_and_response('get', path) done = poll_response.get('done') if done: response = poll_response.get('response') if response: return response else: - raise _PollingError( - 'Polling finished, but the operation terminated in an error.') - raise _PollingError('Polling deadline exceeded.') + raise exceptions.UnknownError( + 'Polling finished, but the operation terminated in an error.', + http_response=http_response) + raise exceptions.DeadlineExceededError('Polling deadline exceeded.') def get_android_app_config(self, app_id): return self._get_app_config( @@ -662,44 +630,36 @@ def get_ios_app_config(self, app_id): def _get_app_config(self, platform_resource_name, app_id): path = '/v1beta1/projects/-/{0}/{1}/config'.format(platform_resource_name, app_id) - response = self._make_request('get', path, app_id, 'App ID') + response = self._make_request('get', path) # In Python 2.7, the base64 module works with strings, while in Python 3, it works with # bytes objects. This line works in both versions. return base64.standard_b64decode(response['configFileContents']).decode(encoding='utf-8') def get_sha_certificates(self, app_id): path = '/v1beta1/projects/-/androidApps/{0}/sha'.format(app_id) - response = self._make_request('get', path, app_id, 'App ID') + response = self._make_request('get', path) cert_list = response.get('certificates') or [] - return [ShaCertificate(sha_hash=cert['shaHash'], name=cert['name']) for cert in cert_list] + return [SHACertificate(sha_hash=cert['shaHash'], name=cert['name']) for cert in cert_list] def add_sha_certificate(self, app_id, certificate_to_add): path = '/v1beta1/projects/-/androidApps/{0}/sha'.format(app_id) sha_hash = _check_not_none(certificate_to_add, 'certificate_to_add').sha_hash cert_type = certificate_to_add.cert_type request_body = {'shaHash': sha_hash, 'certType': cert_type} - self._make_request('post', path, app_id, 'App ID', json=request_body) + self._make_request('post', path, json=request_body) def delete_sha_certificate(self, certificate_to_delete): name = _check_not_none(certificate_to_delete, 'certificate_to_delete').name path = '/v1beta1/{0}'.format(name) - self._make_request('delete', path, name, 'SHA ID') + self._make_request('delete', path) + + def _make_request(self, method, url, json=None): + body, _ = self._body_and_response(method, url, json) + return body - def _make_request(self, method, url, resource_identifier, resource_identifier_label, json=None): + def _body_and_response(self, method, url, json=None): try: - return self._client.body(method=method, url=url, json=json, timeout=self._timeout) + return self._client.body_and_response( + method=method, url=url, json=json, timeout=self._timeout) except requests.exceptions.RequestException as error: - raise ApiCallError( - _ProjectManagementService._extract_message( - resource_identifier, resource_identifier_label, error), - error) - - @staticmethod - def _extract_message(identifier, identifier_label, error): - if not isinstance(error, requests.exceptions.RequestException) or error.response is None: - return '{0} "{1}": {2}'.format(identifier_label, identifier, str(error)) - status = error.response.status_code - message = _ProjectManagementService.ERROR_CODES.get(status) - if message: - return '{0} "{1}": {2}'.format(identifier_label, identifier, message) - return '{0} "{1}": Error {2}.'.format(identifier_label, identifier, status) + raise _utils.handle_platform_error_from_requests(error) diff --git a/integration/test_auth.py b/integration/test_auth.py index 53577b827..eb1464476 100644 --- a/integration/test_auth.py +++ b/integration/test_auth.py @@ -29,6 +29,7 @@ import google.oauth2.credentials from google.auth import transport + _verify_token_url = 'https://www.googleapis.com/identitytoolkit/v3/relyingparty/verifyCustomToken' _verify_password_url = 'https://www.googleapis.com/identitytoolkit/v3/relyingparty/verifyPassword' _password_reset_url = 'https://www.googleapis.com/identitytoolkit/v3/relyingparty/resetPassword' @@ -129,25 +130,30 @@ def test_session_cookies(api_key): estimated_exp = int(time.time() + expires_in.total_seconds()) assert abs(claims['exp'] - estimated_exp) < 5 +def test_session_cookie_error(): + expires_in = datetime.timedelta(days=1) + with pytest.raises(auth.InvalidIdTokenError): + auth.create_session_cookie('not.a.token', expires_in=expires_in) + def test_get_non_existing_user(): - with pytest.raises(auth.AuthError) as excinfo: + with pytest.raises(auth.UserNotFoundError) as excinfo: auth.get_user('non.existing') - assert 'USER_NOT_FOUND_ERROR' in str(excinfo.value.code) + assert str(excinfo.value) == 'No user record found for the provided user ID: non.existing.' def test_get_non_existing_user_by_email(): - with pytest.raises(auth.AuthError) as excinfo: + with pytest.raises(auth.UserNotFoundError) as excinfo: auth.get_user_by_email('non.existing@definitely.non.existing') - assert 'USER_NOT_FOUND_ERROR' in str(excinfo.value.code) + error_msg = ('No user record found for the provided email: ' + 'non.existing@definitely.non.existing.') + assert str(excinfo.value) == error_msg def test_update_non_existing_user(): - with pytest.raises(auth.AuthError) as excinfo: + with pytest.raises(auth.UserNotFoundError): auth.update_user('non.existing') - assert 'USER_UPDATE_ERROR' in str(excinfo.value.code) def test_delete_non_existing_user(): - with pytest.raises(auth.AuthError) as excinfo: + with pytest.raises(auth.UserNotFoundError): auth.delete_user('non.existing') - assert 'USER_DELETE_ERROR' in str(excinfo.value.code) @pytest.fixture def new_user(): @@ -250,9 +256,8 @@ def test_create_user(new_user): assert user.user_metadata.creation_timestamp > 0 assert user.user_metadata.last_sign_in_timestamp is None assert len(user.provider_data) is 0 - with pytest.raises(auth.AuthError) as excinfo: + with pytest.raises(auth.UidAlreadyExistsError): auth.create_user(uid=new_user.uid) - assert excinfo.value.code == 'USER_CREATE_ERROR' def test_update_user(new_user): _, email = _random_id() @@ -321,9 +326,8 @@ def test_disable_user(new_user_with_params): def test_delete_user(): user = auth.create_user() auth.delete_user(user.uid) - with pytest.raises(auth.AuthError) as excinfo: + with pytest.raises(auth.UserNotFoundError): auth.get_user(user.uid) - assert excinfo.value.code == 'USER_NOT_FOUND_ERROR' def test_revoke_refresh_tokens(new_user): user = auth.get_user(new_user.uid) @@ -347,9 +351,8 @@ def test_verify_id_token_revoked(new_user, api_key): # verify_id_token succeeded because it didn't check revoked. assert claims['iat'] * 1000 < user.tokens_valid_after_timestamp - with pytest.raises(auth.AuthError) as excinfo: + with pytest.raises(auth.RevokedIdTokenError) as excinfo: claims = auth.verify_id_token(id_token, check_revoked=True) - assert excinfo.value.code == auth._ID_TOKEN_REVOKED assert str(excinfo.value) == 'The Firebase ID token has been revoked.' # Sign in again, verify works. @@ -369,9 +372,8 @@ def test_verify_session_cookie_revoked(new_user, api_key): # verify_session_cookie succeeded because it didn't check revoked. assert claims['iat'] * 1000 < user.tokens_valid_after_timestamp - with pytest.raises(auth.AuthError) as excinfo: + with pytest.raises(auth.RevokedSessionCookieError) as excinfo: claims = auth.verify_session_cookie(session_cookie, check_revoked=True) - assert excinfo.value.code == auth._SESSION_COOKIE_REVOKED assert str(excinfo.value) == 'The Firebase session cookie has been revoked.' # Sign in again, verify works. diff --git a/integration/test_db.py b/integration/test_db.py index d88d145ba..4c2f6bde2 100644 --- a/integration/test_db.py +++ b/integration/test_db.py @@ -22,6 +22,7 @@ import firebase_admin from firebase_admin import db +from firebase_admin import exceptions from integration import conftest from tests import testutils @@ -359,30 +360,26 @@ def init_ref(self, path, app): admin_ref.set('test') assert admin_ref.get() == 'test' - def check_permission_error(self, excinfo): - assert isinstance(excinfo.value, db.ApiCallError) - assert 'Reason: Permission denied' in str(excinfo.value) - def test_no_access(self, app, override_app): path = '_adminsdk/python/admin' self.init_ref(path, app) user_ref = db.reference(path, override_app) - with pytest.raises(db.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnauthenticatedError) as excinfo: assert user_ref.get() - self.check_permission_error(excinfo) + assert str(excinfo.value) == 'Permission denied' - with pytest.raises(db.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnauthenticatedError) as excinfo: user_ref.set('test2') - self.check_permission_error(excinfo) + assert str(excinfo.value) == 'Permission denied' def test_read(self, app, override_app): path = '_adminsdk/python/protected/user2' self.init_ref(path, app) user_ref = db.reference(path, override_app) assert user_ref.get() == 'test' - with pytest.raises(db.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnauthenticatedError) as excinfo: user_ref.set('test2') - self.check_permission_error(excinfo) + assert str(excinfo.value) == 'Permission denied' def test_read_write(self, app, override_app): path = '_adminsdk/python/protected/user1' @@ -394,9 +391,9 @@ def test_read_write(self, app, override_app): def test_query(self, override_app): user_ref = db.reference('_adminsdk/python/protected', override_app) - with pytest.raises(db.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnauthenticatedError) as excinfo: user_ref.order_by_key().limit_to_first(2).get() - self.check_permission_error(excinfo) + assert str(excinfo.value) == 'Permission denied' def test_none_auth_override(self, app, none_override_app): path = '_adminsdk/python/public' @@ -405,14 +402,14 @@ def test_none_auth_override(self, app, none_override_app): assert public_ref.get() == 'test' ref = db.reference('_adminsdk/python', none_override_app) - with pytest.raises(db.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnauthenticatedError) as excinfo: assert ref.child('protected/user1').get() - self.check_permission_error(excinfo) + assert str(excinfo.value) == 'Permission denied' - with pytest.raises(db.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnauthenticatedError) as excinfo: assert ref.child('protected/user2').get() - self.check_permission_error(excinfo) + assert str(excinfo.value) == 'Permission denied' - with pytest.raises(db.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnauthenticatedError) as excinfo: assert ref.child('admin').get() - self.check_permission_error(excinfo) + assert str(excinfo.value) == 'Permission denied' diff --git a/integration/test_instance_id.py b/integration/test_instance_id.py index 1a176a9a0..99b6787d3 100644 --- a/integration/test_instance_id.py +++ b/integration/test_instance_id.py @@ -16,10 +16,11 @@ import pytest +from firebase_admin import exceptions from firebase_admin import instance_id def test_delete_non_existing(): - with pytest.raises(instance_id.ApiCallError) as excinfo: + with pytest.raises(exceptions.NotFoundError) as excinfo: # legal instance IDs are /[cdef][A-Za-z0-9_-]{9}[AEIMQUYcgkosw048]/ instance_id.delete_instance_id('fictive-ID0') assert str(excinfo.value) == 'Instance ID "fictive-ID0": Failed to find the instance ID.' diff --git a/integration/test_messaging.py b/integration/test_messaging.py index b1caa09f9..21f9d9669 100644 --- a/integration/test_messaging.py +++ b/integration/test_messaging.py @@ -16,6 +16,9 @@ import re +import pytest + +from firebase_admin import exceptions from firebase_admin import messaging @@ -51,6 +54,22 @@ def test_send(): msg_id = messaging.send(msg, dry_run=True) assert re.match('^projects/.*/messages/.*$', msg_id) +def test_send_invalid_token(): + msg = messaging.Message( + token=_REGISTRATION_TOKEN, + notification=messaging.Notification('test-title', 'test-body') + ) + with pytest.raises(messaging.SenderIdMismatchError): + messaging.send(msg, dry_run=True) + +def test_send_malformed_token(): + msg = messaging.Message( + token='not-a-token', + notification=messaging.Notification('test-title', 'test-body') + ) + with pytest.raises(exceptions.InvalidArgumentError): + messaging.send(msg, dry_run=True) + def test_send_all(): messages = [ messaging.Message( @@ -79,7 +98,7 @@ def test_send_all(): response = batch_response.responses[2] assert response.success is False - assert response.exception is not None + assert isinstance(response.exception, exceptions.InvalidArgumentError) assert response.message_id is None def test_send_one_hundred(): diff --git a/integration/test_project_management.py b/integration/test_project_management.py index 7386a4837..ca648f12d 100644 --- a/integration/test_project_management.py +++ b/integration/test_project_management.py @@ -20,6 +20,7 @@ import pytest +from firebase_admin import exceptions from firebase_admin import project_management @@ -31,8 +32,8 @@ SHA_1_HASH_2 = 'aaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbb' SHA_256_HASH_1 = '123456789a123456789a123456789a123456789a123456789a123456789a1234' SHA_256_HASH_2 = 'cafef00dba5eba11b01dfaceacc01adeda7aba5eca55e77e0b57ac1e5ca1ab1e' -SHA_1 = project_management.ShaCertificate.SHA_1 -SHA_256 = project_management.ShaCertificate.SHA_256 +SHA_1 = project_management.SHACertificate.SHA_1 +SHA_256 = project_management.SHACertificate.SHA_256 def _starts_with(display_name, prefix): @@ -64,11 +65,12 @@ def ios_app(default_app): def test_create_android_app_already_exists(android_app): del android_app - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.AlreadyExistsError) as excinfo: project_management.create_android_app( package_name=TEST_APP_PACKAGE_NAME, display_name=TEST_APP_DISPLAY_NAME_PREFIX) - assert 'The resource already exists' in str(excinfo.value) - assert excinfo.value.detail is not None + assert 'Requested entity already exists' in str(excinfo.value) + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None def test_android_set_display_name_and_get_metadata(android_app, project_id): @@ -118,10 +120,10 @@ def test_android_sha_certificates(android_app): android_app.delete_sha_certificate(cert) # Add four different certs and assert that they have all been added successfully. - android_app.add_sha_certificate(project_management.ShaCertificate(SHA_1_HASH_1)) - android_app.add_sha_certificate(project_management.ShaCertificate(SHA_1_HASH_2)) - android_app.add_sha_certificate(project_management.ShaCertificate(SHA_256_HASH_1)) - android_app.add_sha_certificate(project_management.ShaCertificate(SHA_256_HASH_2)) + android_app.add_sha_certificate(project_management.SHACertificate(SHA_1_HASH_1)) + android_app.add_sha_certificate(project_management.SHACertificate(SHA_1_HASH_2)) + android_app.add_sha_certificate(project_management.SHACertificate(SHA_256_HASH_1)) + android_app.add_sha_certificate(project_management.SHACertificate(SHA_256_HASH_2)) cert_list = android_app.get_sha_certificates() @@ -133,10 +135,11 @@ def test_android_sha_certificates(android_app): assert cert.name # Adding the same cert twice should cause an already-exists error. - with pytest.raises(project_management.ApiCallError) as excinfo: - android_app.add_sha_certificate(project_management.ShaCertificate(SHA_256_HASH_2)) - assert 'The resource already exists' in str(excinfo.value) - assert excinfo.value.detail is not None + with pytest.raises(exceptions.AlreadyExistsError) as excinfo: + android_app.add_sha_certificate(project_management.SHACertificate(SHA_256_HASH_2)) + assert 'Requested entity already exists' in str(excinfo.value) + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None # Delete all certs and assert that they have all been deleted successfully. for cert in cert_list: @@ -145,20 +148,22 @@ def test_android_sha_certificates(android_app): assert android_app.get_sha_certificates() == [] # Deleting a nonexistent cert should cause a not-found error. - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.NotFoundError) as excinfo: android_app.delete_sha_certificate(cert_list[0]) - assert 'Failed to find the resource' in str(excinfo.value) - assert excinfo.value.detail is not None + assert 'Requested entity was not found' in str(excinfo.value) + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None def test_create_ios_app_already_exists(ios_app): del ios_app - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.AlreadyExistsError) as excinfo: project_management.create_ios_app( bundle_id=TEST_APP_BUNDLE_ID, display_name=TEST_APP_DISPLAY_NAME_PREFIX) - assert 'The resource already exists' in str(excinfo.value) - assert excinfo.value.detail is not None + assert 'Requested entity already exists' in str(excinfo.value) + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None def test_ios_set_display_name_and_get_metadata(ios_app, project_id): diff --git a/lint.sh b/lint.sh index 603b78f92..aeb37f741 100755 --- a/lint.sh +++ b/lint.sh @@ -20,7 +20,7 @@ function lintAllFiles () { } function lintChangedFiles () { - files=`git status -s $1 | grep -v "^D" | awk '{print $NF}' | grep .py$` + files=`git status -s $1 | (grep -v "^D") | awk '{print $NF}' | (grep .py$ || true)` for f in $files do echo "Running linter on $f" diff --git a/requirements.txt b/requirements.txt index 7a8d855bd..fd73d36bd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,8 +5,8 @@ pytest-localserver >= 0.4.1 tox >= 3.6.0 cachecontrol >= 0.12.4 -google-api-core[grpc] >= 1.7.0, < 2.0.0dev; platform.python_implementation != 'PyPy' +google-api-core[grpc] >= 1.14.0, < 2.0.0dev; platform.python_implementation != 'PyPy' google-api-python-client >= 1.7.8 -google-cloud-firestore >= 0.31.0; platform.python_implementation != 'PyPy' -google-cloud-storage >= 1.13.0 +google-cloud-firestore >= 1.4.0; platform.python_implementation != 'PyPy' +google-cloud-storage >= 1.18.0 six >= 1.6.1 diff --git a/setup.py b/setup.py index 15ae97f93..a3cce8be5 100644 --- a/setup.py +++ b/setup.py @@ -38,10 +38,10 @@ 'to integrate Firebase into their services and applications.') install_requires = [ 'cachecontrol>=0.12.4', - 'google-api-core[grpc] >= 1.7.0, < 2.0.0dev; platform.python_implementation != "PyPy"', + 'google-api-core[grpc] >= 1.14.0, < 2.0.0dev; platform.python_implementation != "PyPy"', 'google-api-python-client >= 1.7.8', - 'google-cloud-firestore>=0.31.0; platform.python_implementation != "PyPy"', - 'google-cloud-storage>=1.13.0', + 'google-cloud-firestore>=1.4.0; platform.python_implementation != "PyPy"', + 'google-cloud-storage>=1.18.0', 'six>=1.6.1' ] diff --git a/snippets/auth/index.py b/snippets/auth/index.py index 5bfe21f8e..552875696 100644 --- a/snippets/auth/index.py +++ b/snippets/auth/index.py @@ -24,6 +24,7 @@ # [END import_sdk] from firebase_admin import credentials from firebase_admin import auth +from firebase_admin import exceptions sys.path.append("lib") @@ -31,6 +32,7 @@ def initialize_sdk_with_service_account(): # [START initialize_sdk_with_service_account] import firebase_admin from firebase_admin import credentials + from firebase_admin import exceptions cred = credentials.Certificate('path/to/serviceAccountKey.json') default_app = firebase_admin.initialize_app(cred) @@ -144,13 +146,12 @@ def verify_token_uid_check_revoke(id_token): decoded_token = auth.verify_id_token(id_token, check_revoked=True) # Token is valid and not revoked. uid = decoded_token['uid'] - except auth.AuthError as exc: - if exc.code == 'ID_TOKEN_REVOKED': - # Token revoked, inform the user to reauthenticate or signOut(). - pass - else: - # Token is invalid - pass + except auth.RevokedIdTokenError: + # Token revoked, inform the user to reauthenticate or signOut(). + pass + except auth.InvalidIdTokenError: + # Token is invalid + pass # [END verify_token_id_check_revoked] firebase_admin.delete_app(default_app) return uid @@ -322,7 +323,7 @@ def session_login(): response.set_cookie( 'session', session_cookie, expires=expires, httponly=True, secure=True) return response - except auth.AuthError: + except exceptions.FirebaseError: return flask.abort(401, 'Failed to create a session cookie') # [END session_login] @@ -344,9 +345,9 @@ def check_auth_time(id_token, flask): # User did not sign in recently. To guard against ID token theft, require # re-authentication. return flask.abort(401, 'Recent sign in required') - except ValueError: + except auth.InvalidIdTokenError: return flask.abort(401, 'Invalid ID token') - except auth.AuthError: + except exceptions.FirebaseError: return flask.abort(401, 'Failed to create a session cookie') # [END check_auth_time] @@ -359,16 +360,17 @@ def serve_content_for_user(decoded_claims): @app.route('/profile', methods=['POST']) def access_restricted_content(): session_cookie = flask.request.cookies.get('session') + if not session_cookie: + # Session cookie is unavailable. Force user to login. + return flask.redirect('/login') + # Verify the session cookie. In this case an additional check is added to detect # if the user's Firebase session was revoked, user deleted/disabled, etc. try: decoded_claims = auth.verify_session_cookie(session_cookie, check_revoked=True) return serve_content_for_user(decoded_claims) - except ValueError: - # Session cookie is unavailable or invalid. Force user to login. - return flask.redirect('/login') - except auth.AuthError: - # Session revoked. Force user to login. + except auth.InvalidSessionCookieError: + # Session cookie is invalid, expired or revoked. Force user to login. return flask.redirect('/login') # [END session_verify] @@ -385,11 +387,8 @@ def serve_content_for_admin(decoded_claims): return serve_content_for_admin(decoded_claims) else: return flask.abort(401, 'Insufficient permissions') - except ValueError: - # Session cookie is unavailable or invalid. Force user to login. - return flask.redirect('/login') - except auth.AuthError: - # Session revoked. Force user to login. + except auth.InvalidSessionCookieError: + # Session cookie is invalid, expired or revoked. Force user to login. return flask.redirect('/login') # [END session_verify_with_permission_check] @@ -413,7 +412,7 @@ def session_logout(): response = flask.make_response(flask.redirect('/login')) response.set_cookie('session', expires=0) return response - except ValueError: + except auth.InvalidSessionCookieError: return flask.redirect('/login') # [END session_clear_and_revoke] @@ -444,7 +443,7 @@ def import_users(): result.success_count, result.failure_count)) for err in result.errors: print('Failed to import {0} due to {1}'.format(users[err.index].uid, err.reason)) - except auth.AuthError: + except exceptions.FirebaseError: # Some unrecoverable error occurred that prevented the operation from running. pass # [END import_users] @@ -465,7 +464,7 @@ def import_with_hmac(): result = auth.import_users(users, hash_alg=hash_alg) for err in result.errors: print('Failed to import user:', err.reason) - except auth.AuthError as error: + except exceptions.FirebaseError as error: print('Error importing users:', error) # [END import_with_hmac] @@ -485,7 +484,7 @@ def import_with_pbkdf(): result = auth.import_users(users, hash_alg=hash_alg) for err in result.errors: print('Failed to import user:', err.reason) - except auth.AuthError as error: + except exceptions.FirebaseError as error: print('Error importing users:', error) # [END import_with_pbkdf] @@ -506,7 +505,7 @@ def import_with_standard_scrypt(): result = auth.import_users(users, hash_alg=hash_alg) for err in result.errors: print('Failed to import user:', err.reason) - except auth.AuthError as error: + except exceptions.FirebaseError as error: print('Error importing users:', error) # [END import_with_standard_scrypt] @@ -526,7 +525,7 @@ def import_with_bcrypt(): result = auth.import_users(users, hash_alg=hash_alg) for err in result.errors: print('Failed to import user:', err.reason) - except auth.AuthError as error: + except exceptions.FirebaseError as error: print('Error importing users:', error) # [END import_with_bcrypt] @@ -553,7 +552,7 @@ def import_with_scrypt(): result = auth.import_users(users, hash_alg=hash_alg) for err in result.errors: print('Failed to import user:', err.reason) - except auth.AuthError as error: + except exceptions.FirebaseError as error: print('Error importing users:', error) # [END import_with_scrypt] @@ -583,7 +582,7 @@ def import_without_password(): result = auth.import_users(users) for err in result.errors: print('Failed to import user:', err.reason) - except auth.AuthError as error: + except exceptions.FirebaseError as error: print('Error importing users:', error) # [END import_without_password] diff --git a/snippets/database/index.py b/snippets/database/index.py index fee23f626..adfa13476 100644 --- a/snippets/database/index.py +++ b/snippets/database/index.py @@ -214,7 +214,7 @@ def increment_votes(current_value): try: new_vote_count = upvotes_ref.transaction(increment_votes) print('Transaction completed') - except db.TransactionError: + except db.TransactionAbortedError: print('Transaction failed to commit') # [END transaction] diff --git a/tests/test_db.py b/tests/test_db.py index 211eabb4b..081c31e3d 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -22,6 +22,7 @@ import firebase_admin from firebase_admin import db +from firebase_admin import exceptions from firebase_admin import _sseclient from tests import testutils @@ -31,14 +32,15 @@ class MockAdapter(testutils.MockAdapter): ETAG = '0' - def __init__(self, data, status, recorder): + def __init__(self, data, status, recorder, etag=ETAG): testutils.MockAdapter.__init__(self, data, status, recorder) + self._etag = etag def send(self, request, **kwargs): if_match = request.headers.get('if-match') if_none_match = request.headers.get('if-none-match') resp = super(MockAdapter, self).send(request, **kwargs) - resp.headers = {'ETag': MockAdapter.ETAG} + resp.headers = {'ETag': self._etag} if if_match and if_match != MockAdapter.ETAG: resp.status_code = 412 elif if_none_match == MockAdapter.ETAG: @@ -125,6 +127,38 @@ def test_invalid_child(self, child): parent.child(child) +class _RefOperations(object): + """A collection of operations that can be performed using a ``db.Reference``. + + This can be used to test any functionality that is common across multiple API calls. + """ + + @classmethod + def get(cls, ref): + ref.get() + + @classmethod + def push(cls, ref): + ref.push() + + @classmethod + def set(cls, ref): + ref.set({'foo': 'bar'}) + + @classmethod + def delete(cls, ref): + ref.delete() + + @classmethod + def query(cls, ref): + query = ref.order_by_key() + query.get() + + @classmethod + def get_ops(cls): + return [cls.get, cls.push, cls.set, cls.delete, cls.query] + + class TestReference(object): """Test cases for database queries via References.""" @@ -132,6 +166,12 @@ class TestReference(object): valid_values = [ '', 'foo', 0, 1, 100, 1.2, True, False, [], [1, 2], {}, {'foo' : 'bar'} ] + error_codes = { + 400: exceptions.InvalidArgumentError, + 401: exceptions.UnauthenticatedError, + 404: exceptions.NotFoundError, + 500: exceptions.InternalError, + } @classmethod def setup_class(cls): @@ -141,9 +181,9 @@ def setup_class(cls): def teardown_class(cls): testutils.cleanup_apps() - def instrument(self, ref, payload, status=200): + def instrument(self, ref, payload, status=200, etag=MockAdapter.ETAG): recorder = [] - adapter = MockAdapter(payload, status, recorder) + adapter = MockAdapter(payload, status, recorder, etag) ref._client.session.mount(self.test_url, adapter) return recorder @@ -427,6 +467,19 @@ def transaction_update(data): assert len(recorder) == 1 assert recorder[0].method == 'GET' + def test_transaction_abort(self): + ref = db.reference('/test/count') + data = 42 + recorder = self.instrument(ref, json.dumps(data), etag='1') + + with pytest.raises(db.TransactionAbortedError) as excinfo: + ref.transaction(lambda x: x + 1 if x else 1) + assert isinstance(excinfo.value, exceptions.AbortedError) + assert str(excinfo.value) == 'Transaction aborted after failed retries.' + assert excinfo.value.cause is None + assert excinfo.value.http_response is None + assert len(recorder) == 1 + 25 + @pytest.mark.parametrize('func', [None, 0, 1, True, False, 'foo', dict(), list(), tuple()]) def test_transaction_invalid_function(self, func): ref = db.reference('/test') @@ -449,21 +502,29 @@ def test_get_reference(self, path, expected): else: assert ref.parent.path == parent - @pytest.mark.parametrize('error_code', [400, 401, 500]) - def test_server_error(self, error_code): + @pytest.mark.parametrize('error_code', error_codes.keys()) + @pytest.mark.parametrize('func', _RefOperations.get_ops()) + def test_server_error(self, error_code, func): ref = db.reference('/test') self.instrument(ref, json.dumps({'error' : 'json error message'}), error_code) - with pytest.raises(db.ApiCallError) as excinfo: - ref.get() - assert 'Reason: json error message' in str(excinfo.value) - - @pytest.mark.parametrize('error_code', [400, 401, 500]) - def test_other_error(self, error_code): + exc_type = self.error_codes[error_code] + with pytest.raises(exc_type) as excinfo: + func(ref) + assert str(excinfo.value) == 'json error message' + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None + + @pytest.mark.parametrize('error_code', error_codes.keys()) + @pytest.mark.parametrize('func', _RefOperations.get_ops()) + def test_other_error(self, error_code, func): ref = db.reference('/test') self.instrument(ref, 'custom error message', error_code) - with pytest.raises(db.ApiCallError) as excinfo: - ref.get() - assert 'Reason: custom error message' in str(excinfo.value) + exc_type = self.error_codes[error_code] + with pytest.raises(exc_type) as excinfo: + func(ref) + assert str(excinfo.value) == 'Unexpected response from database: custom error message' + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None class TestListenerRegistration(object): @@ -481,9 +542,11 @@ def test_listen_error(self): session.mount(test_url, adapter) def callback(_): pass - with pytest.raises(db.ApiCallError) as excinfo: + with pytest.raises(exceptions.InternalError) as excinfo: ref._listen_with_session(callback, session) - assert 'Reason: json error message' in str(excinfo.value) + assert str(excinfo.value) == 'json error message' + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None finally: testutils.cleanup_apps() diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py new file mode 100644 index 000000000..98d9ce5e9 --- /dev/null +++ b/tests/test_exceptions.py @@ -0,0 +1,335 @@ +# Copyright 2019 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import socket + +import httplib2 +import pytest +import requests +from requests import models +import six + +from googleapiclient import errors +from firebase_admin import exceptions +from firebase_admin import _utils + + +_NOT_FOUND_ERROR_DICT = { + 'status': 'NOT_FOUND', + 'message': 'test error' +} + + +_NOT_FOUND_PAYLOAD = json.dumps({ + 'error': _NOT_FOUND_ERROR_DICT, +}) + + +class TestRequests(object): + + def test_timeout_error(self): + error = requests.exceptions.Timeout('Test error') + firebase_error = _utils.handle_requests_error(error) + assert isinstance(firebase_error, exceptions.DeadlineExceededError) + assert str(firebase_error) == 'Timed out while making an API call: Test error' + assert firebase_error.cause is error + assert firebase_error.http_response is None + + def test_requests_connection_error(self): + error = requests.exceptions.ConnectionError('Test error') + firebase_error = _utils.handle_requests_error(error) + assert isinstance(firebase_error, exceptions.UnavailableError) + assert str(firebase_error) == 'Failed to establish a connection: Test error' + assert firebase_error.cause is error + assert firebase_error.http_response is None + + def test_unknown_transport_error(self): + error = requests.exceptions.RequestException('Test error') + firebase_error = _utils.handle_requests_error(error) + assert isinstance(firebase_error, exceptions.UnknownError) + assert str(firebase_error) == 'Unknown error while making a remote service call: Test error' + assert firebase_error.cause is error + assert firebase_error.http_response is None + + def test_http_response(self): + resp, error = self._create_response() + firebase_error = _utils.handle_requests_error(error) + assert isinstance(firebase_error, exceptions.InternalError) + assert str(firebase_error) == 'Test error' + assert firebase_error.cause is error + assert firebase_error.http_response is resp + + def test_http_response_with_unknown_status(self): + resp, error = self._create_response(status=501) + firebase_error = _utils.handle_requests_error(error) + assert isinstance(firebase_error, exceptions.UnknownError) + assert str(firebase_error) == 'Test error' + assert firebase_error.cause is error + assert firebase_error.http_response is resp + + def test_http_response_with_message(self): + resp, error = self._create_response() + firebase_error = _utils.handle_requests_error(error, message='Explicit error message') + assert isinstance(firebase_error, exceptions.InternalError) + assert str(firebase_error) == 'Explicit error message' + assert firebase_error.cause is error + assert firebase_error.http_response is resp + + def test_http_response_with_code(self): + resp, error = self._create_response() + firebase_error = _utils.handle_requests_error(error, code=exceptions.UNAVAILABLE) + assert isinstance(firebase_error, exceptions.UnavailableError) + assert str(firebase_error) == 'Test error' + assert firebase_error.cause is error + assert firebase_error.http_response is resp + + def test_http_response_with_message_and_code(self): + resp, error = self._create_response() + firebase_error = _utils.handle_requests_error( + error, message='Explicit error message', code=exceptions.UNAVAILABLE) + assert isinstance(firebase_error, exceptions.UnavailableError) + assert str(firebase_error) == 'Explicit error message' + assert firebase_error.cause is error + assert firebase_error.http_response is resp + + def test_handle_platform_error(self): + resp, error = self._create_response(payload=_NOT_FOUND_PAYLOAD) + firebase_error = _utils.handle_platform_error_from_requests(error) + assert isinstance(firebase_error, exceptions.NotFoundError) + assert str(firebase_error) == 'test error' + assert firebase_error.cause is error + assert firebase_error.http_response is resp + + def test_handle_platform_error_with_no_response(self): + error = requests.exceptions.RequestException('Test error') + firebase_error = _utils.handle_platform_error_from_requests(error) + assert isinstance(firebase_error, exceptions.UnknownError) + assert str(firebase_error) == 'Unknown error while making a remote service call: Test error' + assert firebase_error.cause is error + assert firebase_error.http_response is None + + def test_handle_platform_error_with_no_error_code(self): + resp, error = self._create_response(payload='no error code') + firebase_error = _utils.handle_platform_error_from_requests(error) + assert isinstance(firebase_error, exceptions.InternalError) + message = 'Unexpected HTTP response with status: 500; body: no error code' + assert str(firebase_error) == message + assert firebase_error.cause is error + assert firebase_error.http_response is resp + + def test_handle_platform_error_with_custom_handler(self): + resp, error = self._create_response(payload=_NOT_FOUND_PAYLOAD) + invocations = [] + + def _custom_handler(cause, message, error_dict): + invocations.append((cause, message, error_dict)) + return exceptions.InvalidArgumentError('Custom message', cause, cause.response) + + firebase_error = _utils.handle_platform_error_from_requests(error, _custom_handler) + + assert isinstance(firebase_error, exceptions.InvalidArgumentError) + assert str(firebase_error) == 'Custom message' + assert firebase_error.cause is error + assert firebase_error.http_response is resp + assert len(invocations) == 1 + args = invocations[0] + assert len(args) == 3 + assert args[0] is error + assert args[1] == 'test error' + assert args[2] == _NOT_FOUND_ERROR_DICT + + def test_handle_platform_error_with_custom_handler_ignore(self): + resp, error = self._create_response(payload=_NOT_FOUND_PAYLOAD) + invocations = [] + + def _custom_handler(cause, message, error_dict): + invocations.append((cause, message, error_dict)) + return None + + firebase_error = _utils.handle_platform_error_from_requests(error, _custom_handler) + + assert isinstance(firebase_error, exceptions.NotFoundError) + assert str(firebase_error) == 'test error' + assert firebase_error.cause is error + assert firebase_error.http_response is resp + assert len(invocations) == 1 + args = invocations[0] + assert len(args) == 3 + assert args[0] is error + assert args[1] == 'test error' + assert args[2] == _NOT_FOUND_ERROR_DICT + + def _create_response(self, status=500, payload=None): + resp = models.Response() + resp.status_code = status + if payload: + resp.raw = six.BytesIO(payload.encode()) + exc = requests.exceptions.RequestException('Test error', response=resp) + return resp, exc + + +class TestGoogleApiClient(object): + + @pytest.mark.parametrize('error', [ + socket.timeout('Test error'), + socket.error('Read timed out') + ]) + def test_googleapicleint_timeout_error(self, error): + firebase_error = _utils.handle_googleapiclient_error(error) + assert isinstance(firebase_error, exceptions.DeadlineExceededError) + assert str(firebase_error) == 'Timed out while making an API call: {0}'.format(error) + assert firebase_error.cause is error + assert firebase_error.http_response is None + + def test_googleapiclient_connection_error(self): + error = httplib2.ServerNotFoundError('Test error') + firebase_error = _utils.handle_googleapiclient_error(error) + assert isinstance(firebase_error, exceptions.UnavailableError) + assert str(firebase_error) == 'Failed to establish a connection: Test error' + assert firebase_error.cause is error + assert firebase_error.http_response is None + + def test_unknown_transport_error(self): + error = socket.error('Test error') + firebase_error = _utils.handle_googleapiclient_error(error) + assert isinstance(firebase_error, exceptions.UnknownError) + assert str(firebase_error) == 'Unknown error while making a remote service call: Test error' + assert firebase_error.cause is error + assert firebase_error.http_response is None + + def test_http_response(self): + error = self._create_http_error() + firebase_error = _utils.handle_googleapiclient_error(error) + assert isinstance(firebase_error, exceptions.InternalError) + assert str(firebase_error) == str(error) + assert firebase_error.cause is error + assert firebase_error.http_response.status_code == 500 + assert firebase_error.http_response.content.decode() == 'Body' + + def test_http_response_with_unknown_status(self): + error = self._create_http_error(status=501) + firebase_error = _utils.handle_googleapiclient_error(error) + assert isinstance(firebase_error, exceptions.UnknownError) + assert str(firebase_error) == str(error) + assert firebase_error.cause is error + assert firebase_error.http_response.status_code == 501 + assert firebase_error.http_response.content.decode() == 'Body' + + def test_http_response_with_message(self): + error = self._create_http_error() + firebase_error = _utils.handle_googleapiclient_error( + error, message='Explicit error message') + assert isinstance(firebase_error, exceptions.InternalError) + assert str(firebase_error) == 'Explicit error message' + assert firebase_error.cause is error + assert firebase_error.http_response.status_code == 500 + assert firebase_error.http_response.content.decode() == 'Body' + + def test_http_response_with_code(self): + error = self._create_http_error() + firebase_error = _utils.handle_googleapiclient_error( + error, code=exceptions.UNAVAILABLE) + assert isinstance(firebase_error, exceptions.UnavailableError) + assert str(firebase_error) == str(error) + assert firebase_error.cause is error + assert firebase_error.http_response.status_code == 500 + assert firebase_error.http_response.content.decode() == 'Body' + + def test_http_response_with_message_and_code(self): + error = self._create_http_error() + firebase_error = _utils.handle_googleapiclient_error( + error, message='Explicit error message', code=exceptions.UNAVAILABLE) + assert isinstance(firebase_error, exceptions.UnavailableError) + assert str(firebase_error) == 'Explicit error message' + assert firebase_error.cause is error + assert firebase_error.http_response.status_code == 500 + assert firebase_error.http_response.content.decode() == 'Body' + + def test_handle_platform_error(self): + error = self._create_http_error(payload=_NOT_FOUND_PAYLOAD) + firebase_error = _utils.handle_platform_error_from_googleapiclient(error) + assert isinstance(firebase_error, exceptions.NotFoundError) + assert str(firebase_error) == 'test error' + assert firebase_error.cause is error + assert firebase_error.http_response.status_code == 500 + assert firebase_error.http_response.content.decode() == _NOT_FOUND_PAYLOAD + + def test_handle_platform_error_with_no_response(self): + error = socket.error('Test error') + firebase_error = _utils.handle_platform_error_from_googleapiclient(error) + assert isinstance(firebase_error, exceptions.UnknownError) + assert str(firebase_error) == 'Unknown error while making a remote service call: Test error' + assert firebase_error.cause is error + assert firebase_error.http_response is None + + def test_handle_platform_error_with_no_error_code(self): + error = self._create_http_error(payload='no error code') + firebase_error = _utils.handle_platform_error_from_googleapiclient(error) + assert isinstance(firebase_error, exceptions.InternalError) + message = 'Unexpected HTTP response with status: 500; body: no error code' + assert str(firebase_error) == message + assert firebase_error.cause is error + assert firebase_error.http_response.status_code == 500 + assert firebase_error.http_response.content.decode() == 'no error code' + + def test_handle_platform_error_with_custom_handler(self): + error = self._create_http_error(payload=_NOT_FOUND_PAYLOAD) + invocations = [] + + def _custom_handler(cause, message, error_dict, http_response): + invocations.append((cause, message, error_dict, http_response)) + return exceptions.InvalidArgumentError('Custom message', cause, http_response) + + firebase_error = _utils.handle_platform_error_from_googleapiclient(error, _custom_handler) + + assert isinstance(firebase_error, exceptions.InvalidArgumentError) + assert str(firebase_error) == 'Custom message' + assert firebase_error.cause is error + assert firebase_error.http_response.status_code == 500 + assert firebase_error.http_response.content.decode() == _NOT_FOUND_PAYLOAD + assert len(invocations) == 1 + args = invocations[0] + assert len(args) == 4 + assert args[0] is error + assert args[1] == 'test error' + assert args[2] == _NOT_FOUND_ERROR_DICT + assert args[3] is not None + + def test_handle_platform_error_with_custom_handler_ignore(self): + error = self._create_http_error(payload=_NOT_FOUND_PAYLOAD) + invocations = [] + + def _custom_handler(cause, message, error_dict, http_response): + invocations.append((cause, message, error_dict, http_response)) + return None + + firebase_error = _utils.handle_platform_error_from_googleapiclient(error, _custom_handler) + + assert isinstance(firebase_error, exceptions.NotFoundError) + assert str(firebase_error) == 'test error' + assert firebase_error.cause is error + assert firebase_error.http_response.status_code == 500 + assert firebase_error.http_response.content.decode() == _NOT_FOUND_PAYLOAD + assert len(invocations) == 1 + args = invocations[0] + assert len(args) == 4 + assert args[0] is error + assert args[1] == 'test error' + assert args[2] == _NOT_FOUND_ERROR_DICT + assert args[3] is not None + + def _create_http_error(self, status=500, payload='Body'): + resp = httplib2.Response({'status': status}) + return errors.HttpError(resp, payload.encode()) diff --git a/tests/test_instance_id.py b/tests/test_instance_id.py index e8e8edd27..83e66491a 100644 --- a/tests/test_instance_id.py +++ b/tests/test_instance_id.py @@ -17,15 +17,37 @@ import pytest import firebase_admin +from firebase_admin import exceptions from firebase_admin import instance_id from tests import testutils http_errors = { - 404: 'Instance ID "test_iid": Failed to find the instance ID.', - 409: 'Instance ID "test_iid": Already deleted.', - 429: 'Instance ID "test_iid": Request throttled out by the backend server.', - 500: 'Instance ID "test_iid": Internal server error.', + 400: ( + 'Instance ID "test_iid": Malformed instance ID argument.', + exceptions.InvalidArgumentError), + 401: ( + 'Instance ID "test_iid": Request not authorized.', + exceptions.UnauthenticatedError), + 403: ( + ('Instance ID "test_iid": Project does not match instance ID or the client does not have ' + 'sufficient privileges.'), + exceptions.PermissionDeniedError), + 404: ( + 'Instance ID "test_iid": Failed to find the instance ID.', + exceptions.NotFoundError), + 409: ( + 'Instance ID "test_iid": Already deleted.', + exceptions.ConflictError), + 429: ( + 'Instance ID "test_iid": Request throttled out by the backend server.', + exceptions.ResourceExhaustedError), + 500: ( + 'Instance ID "test_iid": Internal server error.', + exceptions.InternalError), + 503: ( + 'Instance ID "test_iid": Backend servers are over capacity. Try again later.', + exceptions.UnavailableError), } class TestDeleteInstanceId(object): @@ -74,11 +96,17 @@ def test_delete_instance_id_error(self, status): cred = testutils.MockCredential() app = firebase_admin.initialize_app(cred, {'projectId': 'explicit-project-id'}) _, recorder = self._instrument_iid_service(app, status, 'some error') - with pytest.raises(instance_id.ApiCallError) as excinfo: + msg, exc = http_errors.get(status) + with pytest.raises(exc) as excinfo: instance_id.delete_instance_id('test_iid') - assert str(excinfo.value) == http_errors.get(status) - assert excinfo.value.detail is not None - assert len(recorder) == 1 + assert str(excinfo.value) == msg + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None + if status != 401: + assert len(recorder) == 1 + else: + # 401 responses are automatically retried by google-auth + assert len(recorder) == 3 assert recorder[0].method == 'DELETE' assert recorder[0].url == self._get_url('explicit-project-id', 'test_iid') @@ -86,12 +114,13 @@ def test_delete_instance_id_unexpected_error(self): cred = testutils.MockCredential() app = firebase_admin.initialize_app(cred, {'projectId': 'explicit-project-id'}) _, recorder = self._instrument_iid_service(app, 501, 'some error') - with pytest.raises(instance_id.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnknownError) as excinfo: instance_id.delete_instance_id('test_iid') url = self._get_url('explicit-project-id', 'test_iid') - message = '501 Server Error: None for url: {0}'.format(url) + message = 'Instance ID "test_iid": 501 Server Error: None for url: {0}'.format(url) assert str(excinfo.value) == message - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 1 assert recorder[0].method == 'DELETE' assert recorder[0].url == url diff --git a/tests/test_messaging.py b/tests/test_messaging.py index 4f7520045..1f6fa102c 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -23,6 +23,7 @@ from googleapiclient.http import HttpMockSequence import firebase_admin +from firebase_admin import exceptions from firebase_admin import messaging from tests import testutils @@ -31,7 +32,20 @@ NON_DICT_ARGS = ['', list(), tuple(), True, False, 1, 0, {1: 'foo'}, {'foo': 1}] NON_OBJECT_ARGS = [list(), tuple(), dict(), 'foo', 0, 1, True, False] NON_LIST_ARGS = ['', tuple(), dict(), True, False, 1, 0, [1], ['foo', 1]] -HTTP_ERRORS = [400, 404, 500] +HTTP_ERROR_CODES = { + 400: exceptions.InvalidArgumentError, + 403: exceptions.PermissionDeniedError, + 404: exceptions.NotFoundError, + 500: exceptions.InternalError, + 503: exceptions.UnavailableError, +} +FCM_ERROR_CODES = { + 'APNS_AUTH_ERROR': messaging.ThirdPartyAuthError, + 'QUOTA_EXCEEDED': messaging.QuotaExceededError, + 'SENDER_ID_MISMATCH': messaging.SenderIdMismatchError, + 'THIRD_PARTY_AUTH_ERROR': messaging.ThirdPartyAuthError, + 'UNREGISTERED': messaging.UnregisteredError, +} def check_encoding(msg, expected=None): @@ -39,6 +53,13 @@ def check_encoding(msg, expected=None): if expected: assert encoded == expected +def check_exception(exception, message, status): + assert isinstance(exception, exceptions.FirebaseError) + assert str(exception) == message + assert exception.cause is not None + assert exception.http_response is not None + assert exception.http_response.status_code == status + class TestMulticastMessage(object): @@ -551,25 +572,6 @@ def test_webpush_options(self): } check_encoding(msg, expected) - def test_deprecated_fcm_options(self): - msg = messaging.Message( - topic='topic', - webpush=messaging.WebpushConfig( - fcm_options=messaging.WebpushFcmOptions( - link='https://example', - ), - ) - ) - expected = { - 'topic': 'topic', - 'webpush': { - 'fcm_options': { - 'link': 'https://example', - }, - }, - } - check_encoding(msg, expected) - class TestWebpushNotificationEncoder(object): @@ -1409,15 +1411,14 @@ def test_send(self): body = {'message': messaging._MessagingService.encode_message(msg)} assert json.loads(recorder[0].body.decode()) == body - @pytest.mark.parametrize('status', HTTP_ERRORS) - def test_send_error(self, status): + @pytest.mark.parametrize('status,exc_type', HTTP_ERROR_CODES.items()) + def test_send_error(self, status, exc_type): _, recorder = self._instrument_messaging_service(status=status, payload='{}') msg = messaging.Message(topic='foo') - with pytest.raises(messaging.ApiCallError) as excinfo: + with pytest.raises(exc_type) as excinfo: messaging.send(msg) expected = 'Unexpected HTTP response with status: {0}; body: {{}}'.format(status) - assert str(excinfo.value) == expected - assert str(excinfo.value.code) == messaging._MessagingService.UNKNOWN_ERROR + check_exception(excinfo.value, expected, status) assert len(recorder) == 1 assert recorder[0].method == 'POST' assert recorder[0].url == self._get_url('explicit-project-id') @@ -1426,7 +1427,7 @@ def test_send_error(self, status): body = {'message': messaging._MessagingService.JSON_ENCODER.default(msg)} assert json.loads(recorder[0].body.decode()) == body - @pytest.mark.parametrize('status', HTTP_ERRORS) + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_detailed_error(self, status): payload = json.dumps({ 'error': { @@ -1436,17 +1437,16 @@ def test_send_detailed_error(self, status): }) _, recorder = self._instrument_messaging_service(status=status, payload=payload) msg = messaging.Message(topic='foo') - with pytest.raises(messaging.ApiCallError) as excinfo: + with pytest.raises(exceptions.InvalidArgumentError) as excinfo: messaging.send(msg) - assert str(excinfo.value) == 'test error' - assert str(excinfo.value.code) == 'invalid-argument' + check_exception(excinfo.value, 'test error', status) assert len(recorder) == 1 assert recorder[0].method == 'POST' assert recorder[0].url == self._get_url('explicit-project-id') body = {'message': messaging._MessagingService.JSON_ENCODER.default(msg)} assert json.loads(recorder[0].body.decode()) == body - @pytest.mark.parametrize('status', HTTP_ERRORS) + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_canonical_error_code(self, status): payload = json.dumps({ 'error': { @@ -1456,18 +1456,18 @@ def test_send_canonical_error_code(self, status): }) _, recorder = self._instrument_messaging_service(status=status, payload=payload) msg = messaging.Message(topic='foo') - with pytest.raises(messaging.ApiCallError) as excinfo: + with pytest.raises(exceptions.NotFoundError) as excinfo: messaging.send(msg) - assert str(excinfo.value) == 'test error' - assert str(excinfo.value.code) == 'registration-token-not-registered' + check_exception(excinfo.value, 'test error', status) assert len(recorder) == 1 assert recorder[0].method == 'POST' assert recorder[0].url == self._get_url('explicit-project-id') body = {'message': messaging._MessagingService.JSON_ENCODER.default(msg)} assert json.loads(recorder[0].body.decode()) == body - @pytest.mark.parametrize('status', HTTP_ERRORS) - def test_send_fcm_error_code(self, status): + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) + @pytest.mark.parametrize('fcm_error_code, exc_type', FCM_ERROR_CODES.items()) + def test_send_fcm_error_code(self, status, fcm_error_code, exc_type): payload = json.dumps({ 'error': { 'status': 'INVALID_ARGUMENT', @@ -1475,17 +1475,41 @@ def test_send_fcm_error_code(self, status): 'details': [ { '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', - 'errorCode': 'UNREGISTERED', + 'errorCode': fcm_error_code, + }, + ], + } + }) + _, recorder = self._instrument_messaging_service(status=status, payload=payload) + msg = messaging.Message(topic='foo') + with pytest.raises(exc_type) as excinfo: + messaging.send(msg) + check_exception(excinfo.value, 'test error', status) + assert len(recorder) == 1 + assert recorder[0].method == 'POST' + assert recorder[0].url == self._get_url('explicit-project-id') + body = {'message': messaging._MessagingService.JSON_ENCODER.default(msg)} + assert json.loads(recorder[0].body.decode()) == body + + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) + def test_send_unknown_fcm_error_code(self, status): + payload = json.dumps({ + 'error': { + 'status': 'INVALID_ARGUMENT', + 'message': 'test error', + 'details': [ + { + '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', + 'errorCode': 'SOME_UNKNOWN_CODE', }, ], } }) _, recorder = self._instrument_messaging_service(status=status, payload=payload) msg = messaging.Message(topic='foo') - with pytest.raises(messaging.ApiCallError) as excinfo: + with pytest.raises(exceptions.InvalidArgumentError) as excinfo: messaging.send(msg) - assert str(excinfo.value) == 'test error' - assert str(excinfo.value.code) == 'registration-token-not-registered' + check_exception(excinfo.value, 'test error', status) assert len(recorder) == 1 assert recorder[0].method == 'POST' assert recorder[0].url == self._get_url('explicit-project-id') @@ -1569,7 +1593,7 @@ def test_send_all(self): assert all([r.success for r in batch_response.responses]) assert not any([r.exception for r in batch_response.responses]) - @pytest.mark.parametrize('status', HTTP_ERRORS) + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_all_detailed_error(self, status): success_payload = json.dumps({'name': 'message-id'}) error_payload = json.dumps({ @@ -1592,12 +1616,11 @@ def test_send_all_detailed_error(self, status): error_response = batch_response.responses[1] assert error_response.message_id is None assert error_response.success is False - assert error_response.exception is not None exception = error_response.exception - assert str(exception) == 'test error' - assert str(exception.code) == 'invalid-argument' + assert isinstance(exception, exceptions.InvalidArgumentError) + check_exception(exception, 'test error', status) - @pytest.mark.parametrize('status', HTTP_ERRORS) + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_all_canonical_error_code(self, status): success_payload = json.dumps({'name': 'message-id'}) error_payload = json.dumps({ @@ -1620,13 +1643,13 @@ def test_send_all_canonical_error_code(self, status): error_response = batch_response.responses[1] assert error_response.message_id is None assert error_response.success is False - assert error_response.exception is not None exception = error_response.exception - assert str(exception) == 'test error' - assert str(exception.code) == 'registration-token-not-registered' + assert isinstance(exception, exceptions.NotFoundError) + check_exception(exception, 'test error', status) - @pytest.mark.parametrize('status', HTTP_ERRORS) - def test_send_all_fcm_error_code(self, status): + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) + @pytest.mark.parametrize('fcm_error_code, exc_type', FCM_ERROR_CODES.items()) + def test_send_all_fcm_error_code(self, status, fcm_error_code, exc_type): success_payload = json.dumps({'name': 'message-id'}) error_payload = json.dumps({ 'error': { @@ -1635,7 +1658,7 @@ def test_send_all_fcm_error_code(self, status): 'details': [ { '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', - 'errorCode': 'UNREGISTERED', + 'errorCode': fcm_error_code, }, ], } @@ -1654,22 +1677,20 @@ def test_send_all_fcm_error_code(self, status): error_response = batch_response.responses[1] assert error_response.message_id is None assert error_response.success is False - assert error_response.exception is not None exception = error_response.exception - assert str(exception) == 'test error' - assert str(exception.code) == 'registration-token-not-registered' + assert isinstance(exception, exc_type) + check_exception(exception, 'test error', status) - @pytest.mark.parametrize('status', HTTP_ERRORS) - def test_send_all_batch_error(self, status): + @pytest.mark.parametrize('status, exc_type', HTTP_ERROR_CODES.items()) + def test_send_all_batch_error(self, status, exc_type): _ = self._instrument_batch_messaging_service(status=status, payload='{}') msg = messaging.Message(topic='foo') - with pytest.raises(messaging.ApiCallError) as excinfo: + with pytest.raises(exc_type) as excinfo: messaging.send_all([msg]) expected = 'Unexpected HTTP response with status: {0}; body: {{}}'.format(status) - assert str(excinfo.value) == expected - assert str(excinfo.value.code) == messaging._MessagingService.UNKNOWN_ERROR + check_exception(excinfo.value, expected, status) - @pytest.mark.parametrize('status', HTTP_ERRORS) + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_all_batch_detailed_error(self, status): payload = json.dumps({ 'error': { @@ -1679,12 +1700,11 @@ def test_send_all_batch_detailed_error(self, status): }) _ = self._instrument_batch_messaging_service(status=status, payload=payload) msg = messaging.Message(topic='foo') - with pytest.raises(messaging.ApiCallError) as excinfo: + with pytest.raises(exceptions.InvalidArgumentError) as excinfo: messaging.send_all([msg]) - assert str(excinfo.value) == 'test error' - assert str(excinfo.value.code) == 'invalid-argument' + check_exception(excinfo.value, 'test error', status) - @pytest.mark.parametrize('status', HTTP_ERRORS) + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_all_batch_canonical_error_code(self, status): payload = json.dumps({ 'error': { @@ -1694,12 +1714,11 @@ def test_send_all_batch_canonical_error_code(self, status): }) _ = self._instrument_batch_messaging_service(status=status, payload=payload) msg = messaging.Message(topic='foo') - with pytest.raises(messaging.ApiCallError) as excinfo: + with pytest.raises(exceptions.NotFoundError) as excinfo: messaging.send_all([msg]) - assert str(excinfo.value) == 'test error' - assert str(excinfo.value.code) == 'registration-token-not-registered' + check_exception(excinfo.value, 'test error', status) - @pytest.mark.parametrize('status', HTTP_ERRORS) + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_all_batch_fcm_error_code(self, status): payload = json.dumps({ 'error': { @@ -1715,10 +1734,9 @@ def test_send_all_batch_fcm_error_code(self, status): }) _ = self._instrument_batch_messaging_service(status=status, payload=payload) msg = messaging.Message(topic='foo') - with pytest.raises(messaging.ApiCallError) as excinfo: + with pytest.raises(messaging.UnregisteredError) as excinfo: messaging.send_all([msg]) - assert str(excinfo.value) == 'test error' - assert str(excinfo.value.code) == 'registration-token-not-registered' + check_exception(excinfo.value, 'test error', status) class TestSendMulticast(TestBatch): @@ -1750,7 +1768,7 @@ def test_send_multicast(self): assert all([r.success for r in batch_response.responses]) assert not any([r.exception for r in batch_response.responses]) - @pytest.mark.parametrize('status', HTTP_ERRORS) + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_multicast_detailed_error(self, status): success_payload = json.dumps({'name': 'message-id'}) error_payload = json.dumps({ @@ -1775,10 +1793,10 @@ def test_send_multicast_detailed_error(self, status): assert error_response.success is False assert error_response.exception is not None exception = error_response.exception - assert str(exception) == 'test error' - assert str(exception.code) == 'invalid-argument' + assert isinstance(exception, exceptions.InvalidArgumentError) + check_exception(exception, 'test error', status) - @pytest.mark.parametrize('status', HTTP_ERRORS) + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_multicast_canonical_error_code(self, status): success_payload = json.dumps({'name': 'message-id'}) error_payload = json.dumps({ @@ -1803,10 +1821,10 @@ def test_send_multicast_canonical_error_code(self, status): assert error_response.success is False assert error_response.exception is not None exception = error_response.exception - assert str(exception) == 'test error' - assert str(exception.code) == 'registration-token-not-registered' + assert isinstance(exception, exceptions.NotFoundError) + check_exception(exception, 'test error', status) - @pytest.mark.parametrize('status', HTTP_ERRORS) + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_multicast_fcm_error_code(self, status): success_payload = json.dumps({'name': 'message-id'}) error_payload = json.dumps({ @@ -1837,20 +1855,19 @@ def test_send_multicast_fcm_error_code(self, status): assert error_response.success is False assert error_response.exception is not None exception = error_response.exception - assert str(exception) == 'test error' - assert str(exception.code) == 'registration-token-not-registered' + assert isinstance(exception, messaging.UnregisteredError) + check_exception(exception, 'test error', status) - @pytest.mark.parametrize('status', HTTP_ERRORS) - def test_send_multicast_batch_error(self, status): + @pytest.mark.parametrize('status, exc_type', HTTP_ERROR_CODES.items()) + def test_send_multicast_batch_error(self, status, exc_type): _ = self._instrument_batch_messaging_service(status=status, payload='{}') msg = messaging.MulticastMessage(tokens=['foo']) - with pytest.raises(messaging.ApiCallError) as excinfo: + with pytest.raises(exc_type) as excinfo: messaging.send_multicast(msg) expected = 'Unexpected HTTP response with status: {0}; body: {{}}'.format(status) - assert str(excinfo.value) == expected - assert str(excinfo.value.code) == messaging._MessagingService.UNKNOWN_ERROR + check_exception(excinfo.value, expected, status) - @pytest.mark.parametrize('status', HTTP_ERRORS) + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_multicast_batch_detailed_error(self, status): payload = json.dumps({ 'error': { @@ -1860,12 +1877,11 @@ def test_send_multicast_batch_detailed_error(self, status): }) _ = self._instrument_batch_messaging_service(status=status, payload=payload) msg = messaging.MulticastMessage(tokens=['foo']) - with pytest.raises(messaging.ApiCallError) as excinfo: + with pytest.raises(exceptions.InvalidArgumentError) as excinfo: messaging.send_multicast(msg) - assert str(excinfo.value) == 'test error' - assert str(excinfo.value.code) == 'invalid-argument' + check_exception(excinfo.value, 'test error', status) - @pytest.mark.parametrize('status', HTTP_ERRORS) + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_multicast_batch_canonical_error_code(self, status): payload = json.dumps({ 'error': { @@ -1875,12 +1891,11 @@ def test_send_multicast_batch_canonical_error_code(self, status): }) _ = self._instrument_batch_messaging_service(status=status, payload=payload) msg = messaging.MulticastMessage(tokens=['foo']) - with pytest.raises(messaging.ApiCallError) as excinfo: + with pytest.raises(exceptions.NotFoundError) as excinfo: messaging.send_multicast(msg) - assert str(excinfo.value) == 'test error' - assert str(excinfo.value.code) == 'registration-token-not-registered' + check_exception(excinfo.value, 'test error', status) - @pytest.mark.parametrize('status', HTTP_ERRORS) + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_multicast_batch_fcm_error_code(self, status): payload = json.dumps({ 'error': { @@ -1896,10 +1911,9 @@ def test_send_multicast_batch_fcm_error_code(self, status): }) _ = self._instrument_batch_messaging_service(status=status, payload=payload) msg = messaging.MulticastMessage(tokens=['foo']) - with pytest.raises(messaging.ApiCallError) as excinfo: + with pytest.raises(messaging.UnregisteredError) as excinfo: messaging.send_multicast(msg) - assert str(excinfo.value) == 'test error' - assert str(excinfo.value.code) == 'registration-token-not-registered' + check_exception(excinfo.value, 'test error', status) class TestTopicManagement(object): @@ -1977,30 +1991,24 @@ def test_subscribe_to_topic(self, args): assert recorder[0].url == self._get_url('iid/v1:batchAdd') assert json.loads(recorder[0].body.decode()) == args[2] - @pytest.mark.parametrize('status', HTTP_ERRORS) - def test_subscribe_to_topic_error(self, status): + @pytest.mark.parametrize('status, exc_type', HTTP_ERROR_CODES.items()) + def test_subscribe_to_topic_error(self, status, exc_type): _, recorder = self._instrument_iid_service( status=status, payload=self._DEFAULT_ERROR_RESPONSE) - with pytest.raises(messaging.ApiCallError) as excinfo: + with pytest.raises(exc_type) as excinfo: messaging.subscribe_to_topic('foo', 'test-topic') assert str(excinfo.value) == 'error_reason' - code = messaging._MessagingService.IID_ERROR_CODES.get( - status, messaging._MessagingService.UNKNOWN_ERROR) - assert excinfo.value.code == code assert len(recorder) == 1 assert recorder[0].method == 'POST' assert recorder[0].url == self._get_url('iid/v1:batchAdd') - @pytest.mark.parametrize('status', HTTP_ERRORS) - def test_subscribe_to_topic_non_json_error(self, status): + @pytest.mark.parametrize('status, exc_type', HTTP_ERROR_CODES.items()) + def test_subscribe_to_topic_non_json_error(self, status, exc_type): _, recorder = self._instrument_iid_service(status=status, payload='not json') - with pytest.raises(messaging.ApiCallError) as excinfo: + with pytest.raises(exc_type) as excinfo: messaging.subscribe_to_topic('foo', 'test-topic') reason = 'Unexpected HTTP response with status: {0}; body: not json'.format(status) - code = messaging._MessagingService.IID_ERROR_CODES.get( - status, messaging._MessagingService.UNKNOWN_ERROR) assert str(excinfo.value) == reason - assert excinfo.value.code == code assert len(recorder) == 1 assert recorder[0].method == 'POST' assert recorder[0].url == self._get_url('iid/v1:batchAdd') @@ -2015,30 +2023,24 @@ def test_unsubscribe_from_topic(self, args): assert recorder[0].url == self._get_url('iid/v1:batchRemove') assert json.loads(recorder[0].body.decode()) == args[2] - @pytest.mark.parametrize('status', HTTP_ERRORS) - def test_unsubscribe_from_topic_error(self, status): + @pytest.mark.parametrize('status, exc_type', HTTP_ERROR_CODES.items()) + def test_unsubscribe_from_topic_error(self, status, exc_type): _, recorder = self._instrument_iid_service( status=status, payload=self._DEFAULT_ERROR_RESPONSE) - with pytest.raises(messaging.ApiCallError) as excinfo: + with pytest.raises(exc_type) as excinfo: messaging.unsubscribe_from_topic('foo', 'test-topic') assert str(excinfo.value) == 'error_reason' - code = messaging._MessagingService.IID_ERROR_CODES.get( - status, messaging._MessagingService.UNKNOWN_ERROR) - assert excinfo.value.code == code assert len(recorder) == 1 assert recorder[0].method == 'POST' assert recorder[0].url == self._get_url('iid/v1:batchRemove') - @pytest.mark.parametrize('status', HTTP_ERRORS) - def test_unsubscribe_from_topic_non_json_error(self, status): + @pytest.mark.parametrize('status, exc_type', HTTP_ERROR_CODES.items()) + def test_unsubscribe_from_topic_non_json_error(self, status, exc_type): _, recorder = self._instrument_iid_service(status=status, payload='not json') - with pytest.raises(messaging.ApiCallError) as excinfo: + with pytest.raises(exc_type) as excinfo: messaging.unsubscribe_from_topic('foo', 'test-topic') reason = 'Unexpected HTTP response with status: {0}; body: not json'.format(status) - code = messaging._MessagingService.IID_ERROR_CODES.get( - status, messaging._MessagingService.UNKNOWN_ERROR) assert str(excinfo.value) == reason - assert excinfo.value.code == code assert len(recorder) == 1 assert recorder[0].method == 'POST' assert recorder[0].url == self._get_url('iid/v1:batchRemove') diff --git a/tests/test_project_management.py b/tests/test_project_management.py index 9de95f7fd..e8353e212 100644 --- a/tests/test_project_management.py +++ b/tests/test_project_management.py @@ -20,6 +20,7 @@ import pytest import firebase_admin +from firebase_admin import exceptions from firebase_admin import project_management from tests import testutils @@ -172,10 +173,10 @@ 'configFileContents': TEST_APP_ENCODED_CONFIG, }) -SHA_1_CERTIFICATE = project_management.ShaCertificate( +SHA_1_CERTIFICATE = project_management.SHACertificate( '123456789a123456789a123456789a123456789a', 'projects/-/androidApps/1:12345678:android:deadbeef/sha/name1') -SHA_256_CERTIFICATE = project_management.ShaCertificate( +SHA_256_CERTIFICATE = project_management.SHACertificate( '123456789a123456789a123456789a123456789a123456789a123456789a1234', 'projects/-/androidApps/1:12345678:android:deadbeef/sha/name256') GET_SHA_CERTIFICATES_RESPONSE = json.dumps({'certificates': [ @@ -189,13 +190,17 @@ app_id='1:12345678:android:deadbeef', display_name='My Android App', project_id='test-project-id') -IOS_APP_METADATA = project_management.IosAppMetadata( +IOS_APP_METADATA = project_management.IOSAppMetadata( bundle_id='com.hello.world.ios', name='projects/test-project-id/iosApps/1:12345678:ios:ca5cade5', app_id='1:12345678:android:deadbeef', display_name='My iOS App', project_id='test-project-id') +ALREADY_EXISTS_RESPONSE = ('{"error": {"status": "ALREADY_EXISTS", ' + '"message": "The resource already exists"}}') +NOT_FOUND_RESPONSE = '{"error": {"message": "Failed to find the resource"}}' +UNAVAILABLE_RESPONSE = '{"error": {"message": "Backend servers are over capacity"}}' class TestAndroidAppMetadata(object): @@ -310,12 +315,12 @@ def test_android_app_metadata_project_id(self): assert ANDROID_APP_METADATA.project_id == 'test-project-id' -class TestIosAppMetadata(object): +class TestIOSAppMetadata(object): def test_create_ios_app_metadata_errors(self): # bundle_id must be a non-empty string. with pytest.raises(ValueError): - project_management.IosAppMetadata( + project_management.IOSAppMetadata( bundle_id='', name='projects/test-project-id/iosApps/1:12345678:ios:ca5cade5', app_id='1:12345678:android:deadbeef', @@ -323,7 +328,7 @@ def test_create_ios_app_metadata_errors(self): project_id='test-project-id') # name must be a non-empty string. with pytest.raises(ValueError): - project_management.IosAppMetadata( + project_management.IOSAppMetadata( bundle_id='com.hello.world.ios', name='', app_id='1:12345678:android:deadbeef', @@ -331,7 +336,7 @@ def test_create_ios_app_metadata_errors(self): project_id='test-project-id') # app_id must be a non-empty string. with pytest.raises(ValueError): - project_management.IosAppMetadata( + project_management.IOSAppMetadata( bundle_id='com.hello.world.ios', name='projects/test-project-id/iosApps/1:12345678:ios:ca5cade5', app_id='', @@ -339,7 +344,7 @@ def test_create_ios_app_metadata_errors(self): project_id='test-project-id') # display_name must be a string or None. with pytest.raises(ValueError): - project_management.IosAppMetadata( + project_management.IOSAppMetadata( bundle_id='com.hello.world.ios', name='projects/test-project-id/iosApps/1:12345678:ios:ca5cade5', app_id='1:12345678:android:deadbeef', @@ -347,7 +352,7 @@ def test_create_ios_app_metadata_errors(self): project_id='test-project-id') # project_id must be a nonempty string. with pytest.raises(ValueError): - project_management.IosAppMetadata( + project_management.IOSAppMetadata( bundle_id='com.hello.world.ios', name='projects/test-project-id/iosApps/1:12345678:ios:ca5cade5', app_id='1:12345678:android:deadbeef', @@ -356,37 +361,37 @@ def test_create_ios_app_metadata_errors(self): def test_ios_app_metadata_eq_and_hash(self): metadata_1 = IOS_APP_METADATA - metadata_2 = project_management.IosAppMetadata( + metadata_2 = project_management.IOSAppMetadata( bundle_id='different', name='projects/test-project-id/iosApps/1:12345678:ios:ca5cade5', app_id='1:12345678:android:deadbeef', display_name='My iOS App', project_id='test-project-id') - metadata_3 = project_management.IosAppMetadata( + metadata_3 = project_management.IOSAppMetadata( bundle_id='com.hello.world.ios', name='different', app_id='1:12345678:android:deadbeef', display_name='My iOS App', project_id='test-project-id') - metadata_4 = project_management.IosAppMetadata( + metadata_4 = project_management.IOSAppMetadata( bundle_id='com.hello.world.ios', name='projects/test-project-id/iosApps/1:12345678:ios:ca5cade5', app_id='different', display_name='My iOS App', project_id='test-project-id') - metadata_5 = project_management.IosAppMetadata( + metadata_5 = project_management.IOSAppMetadata( bundle_id='com.hello.world.ios', name='projects/test-project-id/iosApps/1:12345678:ios:ca5cade5', app_id='1:12345678:android:deadbeef', display_name='different', project_id='test-project-id') - metadata_6 = project_management.IosAppMetadata( + metadata_6 = project_management.IOSAppMetadata( bundle_id='com.hello.world.ios', name='projects/test-project-id/iosApps/1:12345678:ios:ca5cade5', app_id='1:12345678:android:deadbeef', display_name='My iOS App', project_id='different') - metadata_7 = project_management.IosAppMetadata( + metadata_7 = project_management.IOSAppMetadata( bundle_id='com.hello.world.ios', name='projects/test-project-id/iosApps/1:12345678:ios:ca5cade5', app_id='1:12345678:android:deadbeef', @@ -422,40 +427,40 @@ def test_ios_app_metadata_project_id(self): assert IOS_APP_METADATA.project_id == 'test-project-id' -class TestShaCertificate(object): +class TestSHACertificate(object): def test_create_sha_certificate_errors(self): # sha_hash cannot be None. with pytest.raises(ValueError): - project_management.ShaCertificate(sha_hash=None) + project_management.SHACertificate(sha_hash=None) # sha_hash must be a string. with pytest.raises(ValueError): - project_management.ShaCertificate(sha_hash=0x123456789a123456789a123456789a123456789a) + project_management.SHACertificate(sha_hash=0x123456789a123456789a123456789a123456789a) # sha_hash must be a valid SHA-1 or SHA-256 hash. with pytest.raises(ValueError): - project_management.ShaCertificate(sha_hash='123456789a123456789') + project_management.SHACertificate(sha_hash='123456789a123456789') with pytest.raises(ValueError): - project_management.ShaCertificate(sha_hash='123456789a123456789a123456789a123456oops') + project_management.SHACertificate(sha_hash='123456789a123456789a123456789a123456oops') def test_sha_certificate_eq(self): - sha_cert_1 = project_management.ShaCertificate( + sha_cert_1 = project_management.SHACertificate( '123456789a123456789a123456789a123456789a', 'projects/-/androidApps/1:12345678:android:deadbeef/sha/name1') # sha_hash is different from sha_cert_1, but name is the same. - sha_cert_2 = project_management.ShaCertificate( + sha_cert_2 = project_management.SHACertificate( '0000000000000000000000000000000000000000', 'projects/-/androidApps/1:12345678:android:deadbeef/sha/name1') # name is different from sha_cert_1, but sha_hash is the same. - sha_cert_3 = project_management.ShaCertificate( + sha_cert_3 = project_management.SHACertificate( '123456789a123456789a123456789a123456789a', None) # name is different from sha_cert_1, but sha_hash is the same. - sha_cert_4 = project_management.ShaCertificate( + sha_cert_4 = project_management.SHACertificate( '123456789a123456789a123456789a123456789a', 'projects/-/androidApps/{0}/sha/notname1') # sha_hash and cert_type are different from sha_cert_1, but name is the same. - sha_cert_5 = project_management.ShaCertificate( + sha_cert_5 = project_management.SHACertificate( '123456789a123456789a123456789a123456789a123456789a123456789a1234', 'projects/-/androidApps/{0}/sha/name1') # Exactly the same as sha_cert_1. - sha_cert_6 = project_management.ShaCertificate( + sha_cert_6 = project_management.SHACertificate( '123456789a123456789a123456789a123456789a', 'projects/-/androidApps/1:12345678:android:deadbeef/sha/name1') not_a_sha_cert = { @@ -578,15 +583,16 @@ def test_create_android_app(self): recorder[2], 'GET', 'https://firebase.googleapis.com/v1/operations/abcdefg') def test_create_android_app_already_exists(self): - recorder = self._instrument_service(statuses=[409], responses=['some error response']) + recorder = self._instrument_service(statuses=[409], responses=[ALREADY_EXISTS_RESPONSE]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.AlreadyExistsError) as excinfo: project_management.create_android_app( package_name='com.hello.world.android', display_name='My Android App') assert 'The resource already exists' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 1 def test_create_android_app_polling_rpc_error(self): @@ -595,16 +601,17 @@ def test_create_android_app_polling_rpc_error(self): responses=[ OPERATION_IN_PROGRESS_RESPONSE, # Request to create Android app asynchronously. OPERATION_IN_PROGRESS_RESPONSE, # Creation operation is still not done. - 'some error response', # Error 503. + UNAVAILABLE_RESPONSE, # Error 503. ]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnavailableError) as excinfo: project_management.create_android_app( package_name='com.hello.world.android', display_name='My Android App') assert 'Backend servers are over capacity' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 3 def test_create_android_app_polling_failure(self): @@ -616,13 +623,14 @@ def test_create_android_app_polling_failure(self): OPERATION_FAILED_RESPONSE, # Operation is finished, but terminated with an error. ]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnknownError) as excinfo: project_management.create_android_app( package_name='com.hello.world.android', display_name='My Android App') assert 'Polling finished, but the operation terminated in an error' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is None + assert excinfo.value.http_response is not None assert len(recorder) == 3 def test_create_android_app_polling_limit_exceeded(self): @@ -635,17 +643,17 @@ def test_create_android_app_polling_limit_exceeded(self): OPERATION_IN_PROGRESS_RESPONSE, # Creation Operation is still not done. ]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.DeadlineExceededError) as excinfo: project_management.create_android_app( package_name='com.hello.world.android', display_name='My Android App') assert 'Polling deadline exceeded' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is None assert len(recorder) == 3 -class TestCreateIosApp(BaseProjectManagementTest): +class TestCreateIOSApp(BaseProjectManagementTest): _CREATION_URL = 'https://firebase.googleapis.com/v1beta1/projects/test-project-id/iosApps' def test_create_ios_app_without_display_name(self): @@ -663,7 +671,7 @@ def test_create_ios_app_without_display_name(self): assert ios_app.app_id == '1:12345678:ios:ca5cade5' assert len(recorder) == 3 body = {'bundleId': 'com.hello.world.ios'} - self._assert_request_is_correct(recorder[0], 'POST', TestCreateIosApp._CREATION_URL, body) + self._assert_request_is_correct(recorder[0], 'POST', TestCreateIOSApp._CREATION_URL, body) self._assert_request_is_correct( recorder[1], 'GET', 'https://firebase.googleapis.com/v1/operations/abcdefg') self._assert_request_is_correct( @@ -688,22 +696,23 @@ def test_create_ios_app(self): 'bundleId': 'com.hello.world.ios', 'displayName': 'My iOS App', } - self._assert_request_is_correct(recorder[0], 'POST', TestCreateIosApp._CREATION_URL, body) + self._assert_request_is_correct(recorder[0], 'POST', TestCreateIOSApp._CREATION_URL, body) self._assert_request_is_correct( recorder[1], 'GET', 'https://firebase.googleapis.com/v1/operations/abcdefg') self._assert_request_is_correct( recorder[2], 'GET', 'https://firebase.googleapis.com/v1/operations/abcdefg') def test_create_ios_app_already_exists(self): - recorder = self._instrument_service(statuses=[409], responses=['some error response']) + recorder = self._instrument_service(statuses=[409], responses=[ALREADY_EXISTS_RESPONSE]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.AlreadyExistsError) as excinfo: project_management.create_ios_app( bundle_id='com.hello.world.ios', display_name='My iOS App') assert 'The resource already exists' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 1 def test_create_ios_app_polling_rpc_error(self): @@ -712,16 +721,17 @@ def test_create_ios_app_polling_rpc_error(self): responses=[ OPERATION_IN_PROGRESS_RESPONSE, # Request to create iOS app asynchronously. OPERATION_IN_PROGRESS_RESPONSE, # Creation operation is still not done. - 'some error response', # Error 503. + UNAVAILABLE_RESPONSE, # Error 503. ]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnavailableError) as excinfo: project_management.create_ios_app( bundle_id='com.hello.world.ios', display_name='My iOS App') assert 'Backend servers are over capacity' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 3 def test_create_ios_app_polling_failure(self): @@ -733,13 +743,14 @@ def test_create_ios_app_polling_failure(self): OPERATION_FAILED_RESPONSE, # Operation is finished, but terminated with an error. ]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnknownError) as excinfo: project_management.create_ios_app( bundle_id='com.hello.world.ios', display_name='My iOS App') assert 'Polling finished, but the operation terminated in an error' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is None + assert excinfo.value.http_response is not None assert len(recorder) == 3 def test_create_ios_app_polling_limit_exceeded(self): @@ -752,13 +763,13 @@ def test_create_ios_app_polling_limit_exceeded(self): OPERATION_IN_PROGRESS_RESPONSE, # Creation Operation is still not done. ]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.DeadlineExceededError) as excinfo: project_management.create_ios_app( bundle_id='com.hello.world.ios', display_name='My iOS App') assert 'Polling deadline exceeded' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is None assert len(recorder) == 3 @@ -779,13 +790,14 @@ def test_list_android_apps(self): self._assert_request_is_correct(recorder[0], 'GET', TestListAndroidApps._LISTING_URL) def test_list_android_apps_rpc_error(self): - recorder = self._instrument_service(statuses=[503], responses=['some error response']) + recorder = self._instrument_service(statuses=[503], responses=[UNAVAILABLE_RESPONSE]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnavailableError) as excinfo: project_management.list_android_apps() assert 'Backend servers are over capacity' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 1 def test_list_android_apps_empty_list(self): @@ -813,17 +825,18 @@ def test_list_android_apps_multiple_pages(self): def test_list_android_apps_multiple_pages_rpc_error(self): recorder = self._instrument_service( statuses=[200, 503], - responses=[LIST_ANDROID_APPS_PAGE_1_RESPONSE, 'some error response']) + responses=[LIST_ANDROID_APPS_PAGE_1_RESPONSE, UNAVAILABLE_RESPONSE]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnavailableError) as excinfo: project_management.list_android_apps() assert 'Backend servers are over capacity' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 2 -class TestListIosApps(BaseProjectManagementTest): +class TestListIOSApps(BaseProjectManagementTest): _LISTING_URL = ('https://firebase.googleapis.com/v1beta1/projects/test-project-id/' 'iosApps?pageSize=100') _LISTING_PAGE_2_URL = ('https://firebase.googleapis.com/v1beta1/projects/test-project-id/' @@ -837,16 +850,17 @@ def test_list_ios_apps(self): expected_app_ids = set(['1:12345678:ios:ca5cade5', '1:12345678:ios:ca5cade5cafe']) assert set(app.app_id for app in ios_apps) == expected_app_ids assert len(recorder) == 1 - self._assert_request_is_correct(recorder[0], 'GET', TestListIosApps._LISTING_URL) + self._assert_request_is_correct(recorder[0], 'GET', TestListIOSApps._LISTING_URL) def test_list_ios_apps_rpc_error(self): - recorder = self._instrument_service(statuses=[503], responses=['some error response']) + recorder = self._instrument_service(statuses=[503], responses=[UNAVAILABLE_RESPONSE]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnavailableError) as excinfo: project_management.list_ios_apps() assert 'Backend servers are over capacity' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 1 def test_list_ios_apps_empty_list(self): @@ -856,7 +870,7 @@ def test_list_ios_apps_empty_list(self): assert ios_apps == [] assert len(recorder) == 1 - self._assert_request_is_correct(recorder[0], 'GET', TestListIosApps._LISTING_URL) + self._assert_request_is_correct(recorder[0], 'GET', TestListIOSApps._LISTING_URL) def test_list_ios_apps_multiple_pages(self): recorder = self._instrument_service( @@ -868,19 +882,20 @@ def test_list_ios_apps_multiple_pages(self): expected_app_ids = set(['1:12345678:ios:ca5cade5', '1:12345678:ios:ca5cade5cafe']) assert set(app.app_id for app in ios_apps) == expected_app_ids assert len(recorder) == 2 - self._assert_request_is_correct(recorder[0], 'GET', TestListIosApps._LISTING_URL) - self._assert_request_is_correct(recorder[1], 'GET', TestListIosApps._LISTING_PAGE_2_URL) + self._assert_request_is_correct(recorder[0], 'GET', TestListIOSApps._LISTING_URL) + self._assert_request_is_correct(recorder[1], 'GET', TestListIOSApps._LISTING_PAGE_2_URL) def test_list_ios_apps_multiple_pages_rpc_error(self): recorder = self._instrument_service( statuses=[200, 503], - responses=[LIST_IOS_APPS_PAGE_1_RESPONSE, 'some error response']) + responses=[LIST_IOS_APPS_PAGE_1_RESPONSE, UNAVAILABLE_RESPONSE]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnavailableError) as excinfo: project_management.list_ios_apps() assert 'Backend servers are over capacity' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 2 @@ -936,21 +951,24 @@ def test_get_metadata_unknown_error(self, android_app): recorder = self._instrument_service( statuses=[428], responses=['precondition required error']) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnknownError) as excinfo: android_app.get_metadata() - assert 'Error 428' in str(excinfo.value) - assert excinfo.value.detail is not None + message = 'Unexpected HTTP response with status: 428; body: precondition required error' + assert str(excinfo.value) == message + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 1 def test_get_metadata_not_found(self, android_app): - recorder = self._instrument_service(statuses=[404], responses=['some error response']) + recorder = self._instrument_service(statuses=[404], responses=[NOT_FOUND_RESPONSE]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.NotFoundError) as excinfo: android_app.get_metadata() assert 'Failed to find the resource' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 1 def test_set_display_name(self, android_app): @@ -965,14 +983,15 @@ def test_set_display_name(self, android_app): recorder[0], 'PATCH', TestAndroidApp._SET_DISPLAY_NAME_URL, body) def test_set_display_name_not_found(self, android_app): - recorder = self._instrument_service(statuses=[404], responses=['some error response']) + recorder = self._instrument_service(statuses=[404], responses=[NOT_FOUND_RESPONSE]) new_display_name = 'A new display name!' - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.NotFoundError) as excinfo: android_app.set_display_name(new_display_name) assert 'Failed to find the resource' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 1 def test_get_config(self, android_app): @@ -985,13 +1004,14 @@ def test_get_config(self, android_app): self._assert_request_is_correct(recorder[0], 'GET', TestAndroidApp._GET_CONFIG_URL) def test_get_config_not_found(self, android_app): - recorder = self._instrument_service(statuses=[404], responses=['some error response']) + recorder = self._instrument_service(statuses=[404], responses=[NOT_FOUND_RESPONSE]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.NotFoundError) as excinfo: android_app.get_config() assert 'Failed to find the resource' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 1 def test_get_sha_certificates(self, android_app): @@ -1005,13 +1025,14 @@ def test_get_sha_certificates(self, android_app): self._assert_request_is_correct(recorder[0], 'GET', TestAndroidApp._LIST_CERTS_URL) def test_get_sha_certificates_not_found(self, android_app): - recorder = self._instrument_service(statuses=[404], responses=['some error response']) + recorder = self._instrument_service(statuses=[404], responses=[NOT_FOUND_RESPONSE]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.NotFoundError) as excinfo: android_app.get_sha_certificates() assert 'Failed to find the resource' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 1 def test_add_certificate_none_error(self, android_app): @@ -1022,7 +1043,7 @@ def test_add_sha_1_certificate(self, android_app): recorder = self._instrument_service(statuses=[200], responses=[json.dumps({})]) android_app.add_sha_certificate( - project_management.ShaCertificate('123456789a123456789a123456789a123456789a')) + project_management.SHACertificate('123456789a123456789a123456789a123456789a')) assert len(recorder) == 1 body = {'shaHash': '123456789a123456789a123456789a123456789a', 'certType': 'SHA_1'} @@ -1031,7 +1052,7 @@ def test_add_sha_1_certificate(self, android_app): def test_add_sha_256_certificate(self, android_app): recorder = self._instrument_service(statuses=[200], responses=[json.dumps({})]) - android_app.add_sha_certificate(project_management.ShaCertificate( + android_app.add_sha_certificate(project_management.SHACertificate( '123456789a123456789a123456789a123456789a123456789a123456789a1234')) assert len(recorder) == 1 @@ -1042,14 +1063,15 @@ def test_add_sha_256_certificate(self, android_app): self._assert_request_is_correct(recorder[0], 'POST', TestAndroidApp._ADD_CERT_URL, body) def test_add_sha_certificates_already_exists(self, android_app): - recorder = self._instrument_service(statuses=[409], responses=['some error response']) + recorder = self._instrument_service(statuses=[409], responses=[ALREADY_EXISTS_RESPONSE]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.AlreadyExistsError) as excinfo: android_app.add_sha_certificate( - project_management.ShaCertificate('123456789a123456789a123456789a123456789a')) + project_management.SHACertificate('123456789a123456789a123456789a123456789a')) assert 'The resource already exists' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 1 def test_delete_certificate_none_error(self, android_app): @@ -1075,13 +1097,14 @@ def test_delete_sha_256_certificate(self, android_app): recorder[0], 'DELETE', TestAndroidApp._DELETE_SHA_256_CERT_URL) def test_delete_sha_certificates_not_found(self, android_app): - recorder = self._instrument_service(statuses=[404], responses=['some error response']) + recorder = self._instrument_service(statuses=[404], responses=[NOT_FOUND_RESPONSE]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.NotFoundError) as excinfo: android_app.delete_sha_certificate(SHA_1_CERTIFICATE) assert 'Failed to find the resource' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 1 def test_raises_if_app_has_no_project_id(self): @@ -1094,7 +1117,7 @@ def evaluate(): testutils.run_without_project_id(evaluate) -class TestIosApp(BaseProjectManagementTest): +class TestIOSApp(BaseProjectManagementTest): _GET_METADATA_URL = ('https://firebase.googleapis.com/v1beta1/projects/-/iosApps/' '1:12345678:ios:ca5cade5') _SET_DISPLAY_NAME_URL = ('https://firebase.googleapis.com/v1beta1/projects/-/iosApps/' @@ -1118,7 +1141,7 @@ def test_get_metadata_no_display_name(self, ios_app): assert metadata.project_id == 'test-project-id' assert metadata.bundle_id == 'com.hello.world.ios' assert len(recorder) == 1 - self._assert_request_is_correct(recorder[0], 'GET', TestIosApp._GET_METADATA_URL) + self._assert_request_is_correct(recorder[0], 'GET', TestIOSApp._GET_METADATA_URL) def test_get_metadata(self, ios_app): recorder = self._instrument_service(statuses=[200], responses=[IOS_APP_METADATA_RESPONSE]) @@ -1131,27 +1154,30 @@ def test_get_metadata(self, ios_app): assert metadata.project_id == 'test-project-id' assert metadata.bundle_id == 'com.hello.world.ios' assert len(recorder) == 1 - self._assert_request_is_correct(recorder[0], 'GET', TestIosApp._GET_METADATA_URL) + self._assert_request_is_correct(recorder[0], 'GET', TestIOSApp._GET_METADATA_URL) def test_get_metadata_unknown_error(self, ios_app): recorder = self._instrument_service( statuses=[428], responses=['precondition required error']) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnknownError) as excinfo: ios_app.get_metadata() - assert 'Error 428' in str(excinfo.value) - assert excinfo.value.detail is not None + message = 'Unexpected HTTP response with status: 428; body: precondition required error' + assert str(excinfo.value) == message + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 1 def test_get_metadata_not_found(self, ios_app): - recorder = self._instrument_service(statuses=[404], responses=['some error response']) + recorder = self._instrument_service(statuses=[404], responses=[NOT_FOUND_RESPONSE]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.NotFoundError) as excinfo: ios_app.get_metadata() assert 'Failed to find the resource' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 1 def test_set_display_name(self, ios_app): @@ -1163,17 +1189,18 @@ def test_set_display_name(self, ios_app): assert len(recorder) == 1 body = {'displayName': new_display_name} self._assert_request_is_correct( - recorder[0], 'PATCH', TestIosApp._SET_DISPLAY_NAME_URL, body) + recorder[0], 'PATCH', TestIOSApp._SET_DISPLAY_NAME_URL, body) def test_set_display_name_not_found(self, ios_app): - recorder = self._instrument_service(statuses=[404], responses=['some error response']) + recorder = self._instrument_service(statuses=[404], responses=[NOT_FOUND_RESPONSE]) new_display_name = 'A new display name!' - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.NotFoundError) as excinfo: ios_app.set_display_name(new_display_name) assert 'Failed to find the resource' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 1 def test_get_config(self, ios_app): @@ -1183,16 +1210,17 @@ def test_get_config(self, ios_app): assert config == 'hello world' assert len(recorder) == 1 - self._assert_request_is_correct(recorder[0], 'GET', TestIosApp._GET_CONFIG_URL) + self._assert_request_is_correct(recorder[0], 'GET', TestIOSApp._GET_CONFIG_URL) def test_get_config_not_found(self, ios_app): - recorder = self._instrument_service(statuses=[404], responses=['some error response']) + recorder = self._instrument_service(statuses=[404], responses=[NOT_FOUND_RESPONSE]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.NotFoundError) as excinfo: ios_app.get_config() assert 'Failed to find the resource' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 1 def test_raises_if_app_has_no_project_id(self): diff --git a/tests/test_token_gen.py b/tests/test_token_gen.py index 412ba3d0e..e016b8fb1 100644 --- a/tests/test_token_gen.py +++ b/tests/test_token_gen.py @@ -21,8 +21,8 @@ import time from google.auth import crypt -from google.auth import exceptions from google.auth import jwt +import google.auth.exceptions import google.oauth2.id_token import pytest from pytest_localserver import plugin @@ -31,6 +31,7 @@ import firebase_admin from firebase_admin import auth from firebase_admin import credentials +from firebase_admin import exceptions from firebase_admin import _token_gen from tests import testutils @@ -45,6 +46,15 @@ INVALID_STRINGS = [None, '', 0, 1, True, False, list(), tuple(), dict()] INVALID_BOOLS = [None, '', 'foo', 0, 1, list(), tuple(), dict()] +INVALID_JWT_ARGS = { + 'NoneToken': None, + 'EmptyToken': '', + 'BoolToken': True, + 'IntToken': 1, + 'ListToken': [], + 'EmptyDictToken': {}, + 'NonEmptyDictToken': {'a': 1}, +} # Fixture for mocking a HTTP server httpserver = plugin.httpserver @@ -219,10 +229,12 @@ def test_sign_with_iam_error(self): try: iam_resp = '{"error": {"code": 403, "message": "test error"}}' _overwrite_iam_request(app, testutils.MockRequest(403, iam_resp)) - with pytest.raises(auth.AuthError) as excinfo: + with pytest.raises(auth.TokenSignError) as excinfo: auth.create_custom_token(MOCK_UID, app=app) - assert excinfo.value.code == _token_gen.TOKEN_SIGN_ERROR - assert iam_resp in str(excinfo.value) + error = excinfo.value + assert error.code == exceptions.UNKNOWN + assert iam_resp in str(error) + assert isinstance(error.cause, google.auth.exceptions.TransportError) finally: firebase_admin.delete_app(app) @@ -298,17 +310,38 @@ def test_valid_args(self, user_mgt_app, expires_in): assert request == {'idToken' : 'id_token', 'validDuration': 3600} def test_error(self, user_mgt_app): - _instrument_user_manager(user_mgt_app, 500, '{"error":"test"}') - with pytest.raises(auth.AuthError) as excinfo: + _instrument_user_manager(user_mgt_app, 500, '{"error":{"message": "INVALID_ID_TOKEN"}}') + with pytest.raises(auth.InvalidIdTokenError) as excinfo: auth.create_session_cookie('id_token', expires_in=3600, app=user_mgt_app) - assert excinfo.value.code == _token_gen.COOKIE_CREATE_ERROR - assert '{"error":"test"}' in str(excinfo.value) + assert excinfo.value.code == exceptions.INVALID_ARGUMENT + assert str(excinfo.value) == 'The provided ID token is invalid (INVALID_ID_TOKEN).' + + def test_error_with_details(self, user_mgt_app): + _instrument_user_manager( + user_mgt_app, 500, '{"error":{"message": "INVALID_ID_TOKEN: More details."}}') + with pytest.raises(auth.InvalidIdTokenError) as excinfo: + auth.create_session_cookie('id_token', expires_in=3600, app=user_mgt_app) + assert excinfo.value.code == exceptions.INVALID_ARGUMENT + expected = 'The provided ID token is invalid (INVALID_ID_TOKEN). More details.' + assert str(excinfo.value) == expected + + def test_unexpected_error_code(self, user_mgt_app): + _instrument_user_manager(user_mgt_app, 500, '{"error":{"message": "SOMETHING_UNUSUAL"}}') + with pytest.raises(exceptions.InternalError) as excinfo: + auth.create_session_cookie('id_token', expires_in=3600, app=user_mgt_app) + assert str(excinfo.value) == 'Error while calling Auth service (SOMETHING_UNUSUAL).' + + def test_unexpected_error_response(self, user_mgt_app): + _instrument_user_manager(user_mgt_app, 500, '{}') + with pytest.raises(exceptions.InternalError) as excinfo: + auth.create_session_cookie('id_token', expires_in=3600, app=user_mgt_app) + assert str(excinfo.value) == 'Unexpected error response: {}' def test_unexpected_response(self, user_mgt_app): _instrument_user_manager(user_mgt_app, 200, '{}') - with pytest.raises(auth.AuthError) as excinfo: + with pytest.raises(auth.UnexpectedResponseError) as excinfo: auth.create_session_cookie('id_token', expires_in=3600, app=user_mgt_app) - assert excinfo.value.code == _token_gen.COOKIE_CREATE_ERROR + assert excinfo.value.code == exceptions.UNKNOWN assert 'Failed to create session cookie' in str(excinfo.value) @@ -339,13 +372,6 @@ class TestVerifyIdToken(object): 'iat': int(time.time()) - 10000, 'exp': int(time.time()) - 3600 }), - 'NoneToken': None, - 'EmptyToken': '', - 'BoolToken': True, - 'IntToken': 1, - 'ListToken': [], - 'EmptyDictToken': {}, - 'NonEmptyDictToken': {'a': 1}, 'BadFormatToken': 'foobar' } @@ -368,9 +394,8 @@ def test_valid_token_check_revoked(self, user_mgt_app, id_token): def test_revoked_token_check_revoked(self, user_mgt_app, revoked_tokens, id_token): _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) _instrument_user_manager(user_mgt_app, 200, revoked_tokens) - with pytest.raises(auth.AuthError) as excinfo: + with pytest.raises(auth.RevokedIdTokenError) as excinfo: auth.verify_id_token(id_token, app=user_mgt_app, check_revoked=True) - assert excinfo.value.code == 'ID_TOKEN_REVOKED' assert str(excinfo.value) == 'The Firebase ID token has been revoked.' @pytest.mark.parametrize('arg', INVALID_BOOLS) @@ -387,11 +412,30 @@ def test_revoked_token_do_not_check_revoked(self, user_mgt_app, revoked_tokens, assert claims['admin'] is True assert claims['uid'] == claims['sub'] + @pytest.mark.parametrize('id_token', INVALID_JWT_ARGS.values(), ids=list(INVALID_JWT_ARGS)) + def test_invalid_arg(self, user_mgt_app, id_token): + _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) + with pytest.raises(ValueError) as excinfo: + auth.verify_id_token(id_token, app=user_mgt_app) + assert 'Illegal ID token provided' in str(excinfo.value) + @pytest.mark.parametrize('id_token', invalid_tokens.values(), ids=list(invalid_tokens)) def test_invalid_token(self, user_mgt_app, id_token): _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) - with pytest.raises(ValueError): + with pytest.raises(auth.InvalidIdTokenError) as excinfo: + auth.verify_id_token(id_token, app=user_mgt_app) + assert isinstance(excinfo.value, exceptions.InvalidArgumentError) + assert excinfo.value.http_response is None + + def test_expired_token(self, user_mgt_app): + _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) + id_token = self.invalid_tokens['ExpiredToken'] + with pytest.raises(auth.ExpiredIdTokenError) as excinfo: auth.verify_id_token(id_token, app=user_mgt_app) + assert isinstance(excinfo.value, auth.InvalidIdTokenError) + assert 'Token expired' in str(excinfo.value) + assert excinfo.value.cause is not None + assert excinfo.value.http_response is None def test_project_id_option(self): app = firebase_admin.initialize_app( @@ -416,13 +460,19 @@ def test_project_id_env_var(self, env_var_app): def test_custom_token(self, auth_app): id_token = auth.create_custom_token(MOCK_UID, app=auth_app) _overwrite_cert_request(auth_app, MOCK_REQUEST) - with pytest.raises(ValueError): + with pytest.raises(auth.InvalidIdTokenError) as excinfo: auth.verify_id_token(id_token, app=auth_app) + message = 'verify_id_token() expects an ID token, but was given a custom token.' + assert str(excinfo.value) == message def test_certificate_request_failure(self, user_mgt_app): _overwrite_cert_request(user_mgt_app, testutils.MockRequest(404, 'not found')) - with pytest.raises(exceptions.TransportError): + with pytest.raises(auth.CertificateFetchError) as excinfo: auth.verify_id_token(TEST_ID_TOKEN, app=user_mgt_app) + assert 'Could not fetch certificates' in str(excinfo.value) + assert isinstance(excinfo.value, exceptions.UnknownError) + assert excinfo.value.cause is not None + assert excinfo.value.http_response is None class TestVerifySessionCookie(object): @@ -447,13 +497,6 @@ class TestVerifySessionCookie(object): 'iat': int(time.time()) - 10000, 'exp': int(time.time()) - 3600 }), - 'NoneCookie': None, - 'EmptyCookie': '', - 'BoolCookie': True, - 'IntCookie': 1, - 'ListCookie': [], - 'EmptyDictCookie': {}, - 'NonEmptyDictCookie': {'a': 1}, 'BadFormatCookie': 'foobar', 'IDToken': TEST_ID_TOKEN, } @@ -477,9 +520,8 @@ def test_valid_cookie_check_revoked(self, user_mgt_app, cookie): def test_revoked_cookie_check_revoked(self, user_mgt_app, revoked_tokens, cookie): _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) _instrument_user_manager(user_mgt_app, 200, revoked_tokens) - with pytest.raises(auth.AuthError) as excinfo: + with pytest.raises(auth.RevokedSessionCookieError) as excinfo: auth.verify_session_cookie(cookie, app=user_mgt_app, check_revoked=True) - assert excinfo.value.code == 'SESSION_COOKIE_REVOKED' assert str(excinfo.value) == 'The Firebase session cookie has been revoked.' @pytest.mark.parametrize('cookie', valid_cookies.values(), ids=list(valid_cookies)) @@ -490,11 +532,30 @@ def test_revoked_cookie_does_not_check_revoked(self, user_mgt_app, revoked_token assert claims['admin'] is True assert claims['uid'] == claims['sub'] + @pytest.mark.parametrize('cookie', INVALID_JWT_ARGS.values(), ids=list(INVALID_JWT_ARGS)) + def test_invalid_args(self, user_mgt_app, cookie): + _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) + with pytest.raises(ValueError) as excinfo: + auth.verify_session_cookie(cookie, app=user_mgt_app) + assert 'Illegal session cookie provided' in str(excinfo.value) + @pytest.mark.parametrize('cookie', invalid_cookies.values(), ids=list(invalid_cookies)) def test_invalid_cookie(self, user_mgt_app, cookie): _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) - with pytest.raises(ValueError): + with pytest.raises(auth.InvalidSessionCookieError) as excinfo: auth.verify_session_cookie(cookie, app=user_mgt_app) + assert isinstance(excinfo.value, exceptions.InvalidArgumentError) + assert excinfo.value.http_response is None + + def test_expired_cookie(self, user_mgt_app): + _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) + cookie = self.invalid_cookies['ExpiredCookie'] + with pytest.raises(auth.ExpiredSessionCookieError) as excinfo: + auth.verify_session_cookie(cookie, app=user_mgt_app) + assert isinstance(excinfo.value, auth.InvalidSessionCookieError) + assert 'Token expired' in str(excinfo.value) + assert excinfo.value.cause is not None + assert excinfo.value.http_response is None def test_project_id_option(self): app = firebase_admin.initialize_app( @@ -516,13 +577,17 @@ def test_project_id_env_var(self, env_var_app): def test_custom_token(self, auth_app): custom_token = auth.create_custom_token(MOCK_UID, app=auth_app) _overwrite_cert_request(auth_app, MOCK_REQUEST) - with pytest.raises(ValueError): + with pytest.raises(auth.InvalidSessionCookieError): auth.verify_session_cookie(custom_token, app=auth_app) def test_certificate_request_failure(self, user_mgt_app): _overwrite_cert_request(user_mgt_app, testutils.MockRequest(404, 'not found')) - with pytest.raises(exceptions.TransportError): + with pytest.raises(auth.CertificateFetchError) as excinfo: auth.verify_session_cookie(TEST_SESSION_COOKIE, app=user_mgt_app) + assert 'Could not fetch certificates' in str(excinfo.value) + assert isinstance(excinfo.value, exceptions.UnknownError) + assert excinfo.value.cause is not None + assert excinfo.value.http_response is None class TestCertificateCaching(object): diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index 797e0ce59..dc71b6b6d 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -21,6 +21,7 @@ import firebase_admin from firebase_admin import auth +from firebase_admin import exceptions from firebase_admin import _auth_utils from firebase_admin import _user_import from firebase_admin import _user_mgt @@ -211,34 +212,89 @@ def test_get_user_by_phone(self, user_mgt_app): def test_get_user_non_existing(self, user_mgt_app): _instrument_user_manager(user_mgt_app, 200, '{"users":[]}') - with pytest.raises(auth.AuthError) as excinfo: + with pytest.raises(auth.UserNotFoundError) as excinfo: auth.get_user('nonexistentuser', user_mgt_app) - assert excinfo.value.code == _user_mgt.USER_NOT_FOUND_ERROR + error_msg = 'No user record found for the provided user ID: nonexistentuser.' + assert excinfo.value.code == exceptions.NOT_FOUND + assert str(excinfo.value) == error_msg + assert excinfo.value.http_response is not None + assert excinfo.value.cause is None + + def test_get_user_by_email_non_existing(self, user_mgt_app): + _instrument_user_manager(user_mgt_app, 200, '{"users":[]}') + with pytest.raises(auth.UserNotFoundError) as excinfo: + auth.get_user_by_email('nonexistent@user', user_mgt_app) + error_msg = 'No user record found for the provided email: nonexistent@user.' + assert excinfo.value.code == exceptions.NOT_FOUND + assert str(excinfo.value) == error_msg + assert excinfo.value.http_response is not None + assert excinfo.value.cause is None + + def test_get_user_by_phone_non_existing(self, user_mgt_app): + _instrument_user_manager(user_mgt_app, 200, '{"users":[]}') + with pytest.raises(auth.UserNotFoundError) as excinfo: + auth.get_user_by_phone_number('+1234567890', user_mgt_app) + error_msg = 'No user record found for the provided phone number: +1234567890.' + assert excinfo.value.code == exceptions.NOT_FOUND + assert str(excinfo.value) == error_msg + assert excinfo.value.http_response is not None + assert excinfo.value.cause is None def test_get_user_http_error(self, user_mgt_app): - _instrument_user_manager(user_mgt_app, 500, '{"error":"test"}') - with pytest.raises(auth.AuthError) as excinfo: + _instrument_user_manager(user_mgt_app, 500, '{"error":{"message": "USER_NOT_FOUND"}}') + with pytest.raises(auth.UserNotFoundError) as excinfo: + auth.get_user('testuser', user_mgt_app) + error_msg = 'No user record found for the given identifier (USER_NOT_FOUND).' + assert excinfo.value.code == exceptions.NOT_FOUND + assert str(excinfo.value) == error_msg + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None + + def test_get_user_http_error_unexpected_code(self, user_mgt_app): + _instrument_user_manager(user_mgt_app, 500, '{"error":{"message": "UNEXPECTED_CODE"}}') + with pytest.raises(exceptions.InternalError) as excinfo: auth.get_user('testuser', user_mgt_app) - assert excinfo.value.code == _user_mgt.INTERNAL_ERROR - assert '{"error":"test"}' in str(excinfo.value) + assert str(excinfo.value) == 'Error while calling Auth service (UNEXPECTED_CODE).' + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None + + def test_get_user_http_error_malformed_response(self, user_mgt_app): + _instrument_user_manager(user_mgt_app, 500, '{"error": "UNEXPECTED_CODE"}') + with pytest.raises(exceptions.InternalError) as excinfo: + auth.get_user('testuser', user_mgt_app) + assert str(excinfo.value) == 'Unexpected error response: {"error": "UNEXPECTED_CODE"}' + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None def test_get_user_by_email_http_error(self, user_mgt_app): - _instrument_user_manager(user_mgt_app, 500, '{"error":"test"}') - with pytest.raises(auth.AuthError) as excinfo: + _instrument_user_manager(user_mgt_app, 500, '{"error":{"message": "USER_NOT_FOUND"}}') + with pytest.raises(auth.UserNotFoundError) as excinfo: auth.get_user_by_email('non.existent.user@example.com', user_mgt_app) - assert excinfo.value.code == _user_mgt.INTERNAL_ERROR - assert '{"error":"test"}' in str(excinfo.value) + error_msg = 'No user record found for the given identifier (USER_NOT_FOUND).' + assert excinfo.value.code == exceptions.NOT_FOUND + assert str(excinfo.value) == error_msg + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None def test_get_user_by_phone_http_error(self, user_mgt_app): - _instrument_user_manager(user_mgt_app, 500, '{"error":"test"}') - with pytest.raises(auth.AuthError) as excinfo: + _instrument_user_manager(user_mgt_app, 500, '{"error":{"message": "USER_NOT_FOUND"}}') + with pytest.raises(auth.UserNotFoundError) as excinfo: auth.get_user_by_phone_number('+1234567890', user_mgt_app) - assert excinfo.value.code == _user_mgt.INTERNAL_ERROR - assert '{"error":"test"}' in str(excinfo.value) + error_msg = 'No user record found for the given identifier (USER_NOT_FOUND).' + assert excinfo.value.code == exceptions.NOT_FOUND + assert str(excinfo.value) == error_msg + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None class TestCreateUser(object): + already_exists_errors = { + 'DUPLICATE_EMAIL': auth.EmailAlreadyExistsError, + 'DUPLICATE_LOCAL_ID': auth.UidAlreadyExistsError, + 'PHONE_NUMBER_EXISTS': auth.PhoneNumberAlreadyExistsError, + } + @pytest.mark.parametrize('arg', INVALID_STRINGS[1:] + ['a'*129]) def test_invalid_uid(self, user_mgt_app, arg): with pytest.raises(ValueError): @@ -301,11 +357,33 @@ def test_create_user_with_id(self, user_mgt_app): assert request == {'localId' : 'testuser'} def test_create_user_error(self, user_mgt_app): - _instrument_user_manager(user_mgt_app, 500, '{"error":"test"}') - with pytest.raises(auth.AuthError) as excinfo: + _instrument_user_manager(user_mgt_app, 500, '{"error": {"message": "UNEXPECTED_CODE"}}') + with pytest.raises(exceptions.InternalError) as excinfo: + auth.create_user(app=user_mgt_app) + assert str(excinfo.value) == 'Error while calling Auth service (UNEXPECTED_CODE).' + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None + + @pytest.mark.parametrize('error_code', already_exists_errors.keys()) + def test_user_already_exists(self, user_mgt_app, error_code): + resp = {'error': {'message': error_code}} + _instrument_user_manager(user_mgt_app, 500, json.dumps(resp)) + exc_type = self.already_exists_errors[error_code] + with pytest.raises(exc_type) as excinfo: auth.create_user(app=user_mgt_app) - assert excinfo.value.code == _user_mgt.USER_CREATE_ERROR - assert '{"error":"test"}' in str(excinfo.value) + assert isinstance(excinfo.value, exceptions.AlreadyExistsError) + assert str(excinfo.value) == '{0} ({1}).'.format(exc_type.default_message, error_code) + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None + + def test_create_user_unexpected_response(self, user_mgt_app): + _instrument_user_manager(user_mgt_app, 200, '{"error": "test"}') + with pytest.raises(auth.UnexpectedResponseError) as excinfo: + auth.create_user(app=user_mgt_app) + assert str(excinfo.value) == 'Failed to create new user.' + assert excinfo.value.http_response is not None + assert excinfo.value.cause is None + assert isinstance(excinfo.value, exceptions.UnknownError) class TestUpdateUser(object): @@ -387,16 +465,6 @@ def test_delete_user_custom_claims(self, user_mgt_app): request = json.loads(recorder[0].body.decode()) assert request == {'localId' : 'testuser', 'customAttributes' : json.dumps({})} - def test_update_user_delete_fields_with_none(self, user_mgt_app): - user_mgt, recorder = _instrument_user_manager(user_mgt_app, 200, '{"localId":"testuser"}') - user_mgt.update_user('testuser', display_name=None, photo_url=None, phone_number=None) - request = json.loads(recorder[0].body.decode()) - assert request == { - 'localId' : 'testuser', - 'deleteAttribute' : ['DISPLAY_NAME', 'PHOTO_URL'], - 'deleteProvider' : ['phone'], - } - def test_update_user_delete_fields(self, user_mgt_app): user_mgt, recorder = _instrument_user_manager(user_mgt_app, 200, '{"localId":"testuser"}') user_mgt.update_user( @@ -412,11 +480,21 @@ def test_update_user_delete_fields(self, user_mgt_app): } def test_update_user_error(self, user_mgt_app): - _instrument_user_manager(user_mgt_app, 500, '{"error":"test"}') - with pytest.raises(auth.AuthError) as excinfo: + _instrument_user_manager(user_mgt_app, 500, '{"error": {"message": "UNEXPECTED_CODE"}}') + with pytest.raises(exceptions.InternalError) as excinfo: auth.update_user('user', app=user_mgt_app) - assert excinfo.value.code == _user_mgt.USER_UPDATE_ERROR - assert '{"error":"test"}' in str(excinfo.value) + assert str(excinfo.value) == 'Error while calling Auth service (UNEXPECTED_CODE).' + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None + + def test_update_user_unexpected_response(self, user_mgt_app): + _instrument_user_manager(user_mgt_app, 200, '{"error": "test"}') + with pytest.raises(auth.UnexpectedResponseError) as excinfo: + auth.update_user('user', app=user_mgt_app) + assert str(excinfo.value) == 'Failed to update user: user.' + assert excinfo.value.http_response is not None + assert excinfo.value.cause is None + assert isinstance(excinfo.value, exceptions.UnknownError) @pytest.mark.parametrize('arg', [1, 1.0]) def test_update_user_valid_since(self, user_mgt_app, arg): @@ -473,18 +551,19 @@ def test_set_custom_user_claims_str(self, user_mgt_app): request = json.loads(recorder[0].body.decode()) assert request == {'localId' : 'testuser', 'customAttributes' : claims} - def test_set_custom_user_claims_none(self, user_mgt_app): + def test_set_custom_user_claims_remove(self, user_mgt_app): _, recorder = _instrument_user_manager(user_mgt_app, 200, '{"localId":"testuser"}') - auth.set_custom_user_claims('testuser', None, app=user_mgt_app) + auth.set_custom_user_claims('testuser', auth.DELETE_ATTRIBUTE, app=user_mgt_app) request = json.loads(recorder[0].body.decode()) assert request == {'localId' : 'testuser', 'customAttributes' : json.dumps({})} def test_set_custom_user_claims_error(self, user_mgt_app): - _instrument_user_manager(user_mgt_app, 500, '{"error":"test"}') - with pytest.raises(auth.AuthError) as excinfo: + _instrument_user_manager(user_mgt_app, 500, '{"error": {"message": "UNEXPECTED_CODE"}}') + with pytest.raises(exceptions.InternalError) as excinfo: auth.set_custom_user_claims('user', {}, app=user_mgt_app) - assert excinfo.value.code == _user_mgt.USER_UPDATE_ERROR - assert '{"error":"test"}' in str(excinfo.value) + assert str(excinfo.value) == 'Error while calling Auth service (UNEXPECTED_CODE).' + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None class TestDeleteUser(object): @@ -500,11 +579,21 @@ def test_delete_user(self, user_mgt_app): auth.delete_user('testuser', user_mgt_app) def test_delete_user_error(self, user_mgt_app): - _instrument_user_manager(user_mgt_app, 500, '{"error":"test"}') - with pytest.raises(auth.AuthError) as excinfo: + _instrument_user_manager(user_mgt_app, 500, '{"error": {"message": "UNEXPECTED_CODE"}}') + with pytest.raises(exceptions.InternalError) as excinfo: + auth.delete_user('user', app=user_mgt_app) + assert str(excinfo.value) == 'Error while calling Auth service (UNEXPECTED_CODE).' + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None + + def test_delete_user_unexpected_response(self, user_mgt_app): + _instrument_user_manager(user_mgt_app, 200, '{"error": "test"}') + with pytest.raises(auth.UnexpectedResponseError) as excinfo: auth.delete_user('user', app=user_mgt_app) - assert excinfo.value.code == _user_mgt.USER_DELETE_ERROR - assert '{"error":"test"}' in str(excinfo.value) + assert str(excinfo.value) == 'Failed to delete user: user.' + assert excinfo.value.http_response is not None + assert excinfo.value.cause is None + assert isinstance(excinfo.value, exceptions.UnknownError) class TestListUsers(object): @@ -640,10 +729,9 @@ def test_list_users_with_all_args(self, user_mgt_app): def test_list_users_error(self, user_mgt_app): _instrument_user_manager(user_mgt_app, 500, '{"error":"test"}') - with pytest.raises(auth.AuthError) as excinfo: + with pytest.raises(exceptions.InternalError) as excinfo: auth.list_users(app=user_mgt_app) - assert excinfo.value.code == _user_mgt.USER_DOWNLOAD_ERROR - assert '{"error":"test"}' in str(excinfo.value) + assert str(excinfo.value) == 'Unexpected error response: {"error":"test"}' def _check_page(self, page): assert isinstance(page, auth.ListUsersPage) @@ -718,6 +806,7 @@ def test_invalid_args(self, arg): with pytest.raises(ValueError): auth.UserMetadata(**arg) + class TestImportUserRecord(object): _INVALID_USERS = ( @@ -984,6 +1073,25 @@ def test_import_users_with_hash(self, user_mgt_app): } self._check_rpc_calls(recorder, expected) + def test_import_users_http_error(self, user_mgt_app): + _instrument_user_manager(user_mgt_app, 401, '{"error": {"message": "ERROR_CODE"}}') + users = [ + auth.ImportUserRecord(uid='user1'), + auth.ImportUserRecord(uid='user2'), + ] + with pytest.raises(exceptions.UnauthenticatedError) as excinfo: + auth.import_users(users, app=user_mgt_app) + assert str(excinfo.value) == 'Error while calling Auth service (ERROR_CODE).' + + def test_import_users_unexpected_response(self, user_mgt_app): + _instrument_user_manager(user_mgt_app, 200, '"not dict"') + users = [ + auth.ImportUserRecord(uid='user1'), + auth.ImportUserRecord(uid='user2'), + ] + with pytest.raises(auth.UnexpectedResponseError): + auth.import_users(users, app=user_mgt_app) + def _check_rpc_calls(self, recorder, expected): assert len(recorder) == 1 request = json.loads(recorder[0].body.decode()) @@ -1003,6 +1111,7 @@ def test_revoke_refresh_tokens(self, user_mgt_app): assert int(request['validSince']) >= int(before_time) assert int(request['validSince']) <= int(after_time) + class TestActionCodeSetting(object): def test_valid_data(self): @@ -1047,6 +1156,7 @@ def test_encode_action_code_bad_data(self): with pytest.raises(AttributeError): _user_mgt.encode_action_code_settings({"foo":"bar"}) + class TestGenerateEmailActionLink(object): def test_email_verification_no_settings(self, user_mgt_app): @@ -1106,9 +1216,29 @@ def test_password_reset_with_settings(self, user_mgt_app): auth.generate_password_reset_link, ]) def test_api_call_failure(self, user_mgt_app, func): - _instrument_user_manager(user_mgt_app, 500, '{"error":"dummy error"}') - with pytest.raises(auth.AuthError): + _instrument_user_manager(user_mgt_app, 500, '{"error":{"message": "UNEXPECTED_CODE"}}') + with pytest.raises(exceptions.InternalError) as excinfo: + func('test@test.com', MOCK_ACTION_CODE_SETTINGS, app=user_mgt_app) + assert str(excinfo.value) == 'Error while calling Auth service (UNEXPECTED_CODE).' + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None + + @pytest.mark.parametrize('func', [ + auth.generate_sign_in_with_email_link, + auth.generate_email_verification_link, + auth.generate_password_reset_link, + ]) + def test_invalid_dynamic_link(self, user_mgt_app, func): + resp = '{"error":{"message": "INVALID_DYNAMIC_LINK_DOMAIN: Because of this reason."}}' + _instrument_user_manager(user_mgt_app, 500, resp) + with pytest.raises(auth.InvalidDynamicLinkDomainError) as excinfo: func('test@test.com', MOCK_ACTION_CODE_SETTINGS, app=user_mgt_app) + assert isinstance(excinfo.value, exceptions.InvalidArgumentError) + assert str(excinfo.value) == ('Dynamic link domain specified in ActionCodeSettings is ' + 'not authorized (INVALID_DYNAMIC_LINK_DOMAIN). Because ' + 'of this reason.') + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None @pytest.mark.parametrize('func', [ auth.generate_sign_in_with_email_link, @@ -1117,8 +1247,12 @@ def test_api_call_failure(self, user_mgt_app, func): ]) def test_api_call_no_link(self, user_mgt_app, func): _instrument_user_manager(user_mgt_app, 200, '{}') - with pytest.raises(auth.AuthError): + with pytest.raises(auth.UnexpectedResponseError) as excinfo: func('test@test.com', MOCK_ACTION_CODE_SETTINGS, app=user_mgt_app) + assert str(excinfo.value) == 'Failed to generate email action link.' + assert excinfo.value.http_response is not None + assert excinfo.value.cause is None + assert isinstance(excinfo.value, exceptions.UnknownError) @pytest.mark.parametrize('func', [ auth.generate_sign_in_with_email_link, From 88cd33fc1c277a79f2cac4eea1fdb8e722725c6d Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Wed, 11 Sep 2019 10:12:31 -0700 Subject: [PATCH 026/226] Support deleting custom claims by passing None (#341) --- firebase_admin/auth.py | 2 ++ integration/test_auth.py | 6 +++--- tests/test_user_mgt.py | 5 +++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index 47a9a23f7..ebc133d4c 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -421,6 +421,8 @@ def set_custom_user_claims(uid, custom_claims, app=None): FirebaseError: If an error occurs while updating the user account. """ user_manager = _get_auth_service(app).user_manager + if custom_claims is None: + custom_claims = DELETE_ATTRIBUTE user_manager.update_user(uid, custom_claims=custom_claims) diff --git a/integration/test_auth.py b/integration/test_auth.py index eb1464476..1a4bacceb 100644 --- a/integration/test_auth.py +++ b/integration/test_auth.py @@ -310,9 +310,9 @@ def test_update_custom_user_claims(new_user): def test_disable_user(new_user_with_params): user = auth.update_user( new_user_with_params.uid, - display_name=None, - photo_url=None, - phone_number=None, + display_name=auth.DELETE_ATTRIBUTE, + photo_url=auth.DELETE_ATTRIBUTE, + phone_number=auth.DELETE_ATTRIBUTE, disabled=True) assert user.uid == new_user_with_params.uid assert user.email == new_user_with_params.email diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index dc71b6b6d..a971c40a0 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -551,9 +551,10 @@ def test_set_custom_user_claims_str(self, user_mgt_app): request = json.loads(recorder[0].body.decode()) assert request == {'localId' : 'testuser', 'customAttributes' : claims} - def test_set_custom_user_claims_remove(self, user_mgt_app): + @pytest.mark.parametrize('claims', [None, auth.DELETE_ATTRIBUTE]) + def test_set_custom_user_claims_remove(self, user_mgt_app, claims): _, recorder = _instrument_user_manager(user_mgt_app, 200, '{"localId":"testuser"}') - auth.set_custom_user_claims('testuser', auth.DELETE_ATTRIBUTE, app=user_mgt_app) + auth.set_custom_user_claims('testuser', claims, app=user_mgt_app) request = json.loads(recorder[0].body.decode()) assert request == {'localId' : 'testuser', 'customAttributes' : json.dumps({})} From 733481da6a18eaca0f84f3a33019da9ac5e27a76 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Wed, 11 Sep 2019 11:05:19 -0700 Subject: [PATCH 027/226] Bumped version to 3.0.0 (#344) --- CHANGELOG.md | 5 +++++ firebase_admin/__about__.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 72c40dbe0..751639ff6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # Unreleased +- + +# v3.0.0 + + - [added] Added the new `firebase_admin.exceptions` module containing the base exception types and global error codes. - [changed] Updated the `firebase_admin.instance_id` module to use the new diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index 5a2f77e32..546b6cb7c 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '2.18.0' +__version__ = '3.0.0' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From f889dff96c47a8eeb8299f0b75148512d1e117d5 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Fri, 13 Sep 2019 13:47:26 -0700 Subject: [PATCH 028/226] Handling the EMAIL_EXISTS error code (#348) --- firebase_admin/_auth_utils.py | 1 + tests/test_user_mgt.py | 1 + 2 files changed, 2 insertions(+) diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index d90b494f5..bdba9f81d 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -257,6 +257,7 @@ def __init__(self, message, cause=None, http_response=None): _CODE_TO_EXC_TYPE = { 'DUPLICATE_EMAIL': EmailAlreadyExistsError, 'DUPLICATE_LOCAL_ID': UidAlreadyExistsError, + 'EMAIL_EXISTS': EmailAlreadyExistsError, 'INVALID_DYNAMIC_LINK_DOMAIN': InvalidDynamicLinkDomainError, 'INVALID_ID_TOKEN': InvalidIdTokenError, 'PHONE_NUMBER_EXISTS': PhoneNumberAlreadyExistsError, diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index a971c40a0..3c19a98d8 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -292,6 +292,7 @@ class TestCreateUser(object): already_exists_errors = { 'DUPLICATE_EMAIL': auth.EmailAlreadyExistsError, 'DUPLICATE_LOCAL_ID': auth.UidAlreadyExistsError, + 'EMAIL_EXISTS': auth.EmailAlreadyExistsError, 'PHONE_NUMBER_EXISTS': auth.PhoneNumberAlreadyExistsError, } From 3cab0c1ffdd32a7d5d4a368e0bcee98d5556d1e1 Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Wed, 25 Sep 2019 12:45:06 -0400 Subject: [PATCH 029/226] fix(fcm): String representation in Message class (#350) * String representation in Message class * PR fixes and add test cases * Fix pylint errors --- firebase_admin/_messaging_utils.py | 3 +++ tests/test_messaging.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/firebase_admin/_messaging_utils.py b/firebase_admin/_messaging_utils.py index 5c99cb8ef..f0bc969eb 100644 --- a/firebase_admin/_messaging_utils.py +++ b/firebase_admin/_messaging_utils.py @@ -57,6 +57,9 @@ def __init__(self, data=None, notification=None, android=None, webpush=None, apn self.topic = topic self.condition = condition + def __str__(self): + return json.dumps(self, cls=MessageEncoder, sort_keys=True) + class MulticastMessage(object): """A message that can be sent to multiple tokens via Firebase Cloud Messaging. diff --git a/tests/test_messaging.py b/tests/test_messaging.py index 1f6fa102c..dbfe5d2c0 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -61,6 +61,34 @@ def check_exception(exception, message, status): assert exception.http_response.status_code == status +class TestMessageStr(object): + + @pytest.mark.parametrize('msg', [ + messaging.Message(), + messaging.Message(topic='topic', token='token'), + messaging.Message(topic='topic', condition='condition'), + messaging.Message(condition='condition', token='token'), + messaging.Message(topic='topic', token='token', condition='condition'), + ]) + def test_invalid_target_message(self, msg): + with pytest.raises(ValueError) as excinfo: + str(msg) + assert str( + excinfo.value) == 'Exactly one of token, topic or condition must be specified.' + + def test_empty_message(self): + assert str(messaging.Message(token='value')) == '{"token": "value"}' + assert str(messaging.Message(topic='value')) == '{"topic": "value"}' + assert str(messaging.Message(condition='value') + ) == '{"condition": "value"}' + + def test_data_message(self): + assert str(messaging.Message(topic='topic', data={}) + ) == '{"topic": "topic"}' + assert str(messaging.Message(topic='topic', data={ + 'k1': 'v1', 'k2': 'v2'})) == '{"data": {"k1": "v1", "k2": "v2"}, "topic": "topic"}' + + class TestMulticastMessage(object): @pytest.mark.parametrize('tokens', NON_LIST_ARGS) From 3c504e6c83d9d19b6a2059fa38faca646a7e4d1e Mon Sep 17 00:00:00 2001 From: rsgowman Date: Mon, 30 Sep 2019 16:01:57 -0400 Subject: [PATCH 030/226] Remove (base64) 'REDACTED' passwords from user records. (#352) These values *look* like passwords hashes, but aren't, leading to potential confusion. Additionally, added docs to CONTRIBUTING.md detailing how to add the permission that causes password hashes to be properly returned as well as adjusting the test failure message should the developer not add that permission. b/141189502 --- CONTRIBUTING.md | 14 ++++++++++++-- firebase_admin/_user_mgt.py | 18 ++++++++++++++---- integration/test_auth.py | 16 ++++++++++++---- tests/test_user_mgt.py | 8 ++++++++ 4 files changed, 46 insertions(+), 10 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 8c58e63e9..7b4a0ea84 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -115,7 +115,7 @@ pylint firebase_admin However, it is recommended that you use the [`lint.sh`](lint.sh) bash script to invoke pylint. This script will run the linter on both `firebase_admin` and the corresponding -`tests` module. It suprresses some of the noisy warnings that get generated +`tests` module. It suppresses some of the noisy warnings that get generated when running pylint on test code. Note that by default `lint.sh` will only validate the locally modified source files. To validate all source files, pass `all` as an argument. @@ -181,13 +181,23 @@ Then set up your Firebase/GCP project as follows: to set up Firestore either in the locked mode or in the test mode. 2. Enable password auth: Select "Authentication" from the "Develop" menu in Firebase Console. Select the "Sign-in method" tab, and enable the - "Email/Password" sign-in method. + "Email/Password" sign-in method, including the Email link (passwordless + sign-in) option. + 3. Enable the IAM API: Go to the [Google Cloud Platform Console](https://console.cloud.google.com) and make sure your Firebase/GCP project is selected. Select "APIs & Services > Dashboard" from the main menu, and click the "ENABLE APIS AND SERVICES" button. Search for and enable the "Identity and Access Management (IAM) API". +4. Grant your service account the 'Firebase Authentication Admin' role. This is + required to ensure that exported user records contain the password hashes of + the user accounts: + 1. Go to [Google Cloud Platform Console / IAM & admin](https://console.cloud.google.com/iam-admin). + 2. Find your service account in the list, and click the 'pencil' icon to edit it's permissions. + 3. Click 'ADD ANOTHER ROLE' and choose 'Firebase Authentication Admin'. + 4. Click 'SAVE'. + Now you can invoke the integration test suite as follows: diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index 867b6dd89..2e10fac1b 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -14,8 +14,8 @@ """Firebase user management sub module.""" +import base64 import json - import requests import six from six.moves import urllib @@ -26,6 +26,7 @@ MAX_LIST_USERS_RESULTS = 1000 MAX_IMPORT_USERS_SIZE = 1000 +B64_REDACTED = base64.b64encode(b'REDACTED') class Sentinel(object): @@ -257,9 +258,17 @@ def password_hash(self): If the Firebase Auth hashing algorithm (SCRYPT) was used to create the user account, this is the base64-encoded password hash of the user. If a different hashing algorithm was used to create this user, as is typical when migrating from another Auth system, this - is an empty string. If no password is set, this is ``None``. + is an empty string. If no password is set, or if the service account doesn't have permission + to read the password, then this is ``None``. """ - return self._data.get('passwordHash') + password_hash = self._data.get('passwordHash') + + # If the password hash is redacted (probably due to missing permissions) then clear it out, + # similar to how the salt is returned. (Otherwise, it *looks* like a b64-encoded hash is + # present, which is confusing.) + if password_hash == B64_REDACTED: + return None + return password_hash @property def password_salt(self): @@ -268,7 +277,8 @@ def password_salt(self): If the Firebase Auth hashing algorithm (SCRYPT) was used to create the user account, this is the base64-encoded password salt of the user. If a different hashing algorithm was used to create this user, as is typical when migrating from another Auth system, this is - an empty string. If no password is set, this is ``None``. + an empty string. If no password is set, or if the service account doesn't have permission to + read the password, then this is ``None``. """ return self._data.get('salt') diff --git a/integration/test_auth.py b/integration/test_auth.py index 1a4bacceb..9d5d0dfe3 100644 --- a/integration/test_auth.py +++ b/integration/test_auth.py @@ -220,6 +220,10 @@ def test_get_user(new_user_with_params): assert provider_ids == ['password', 'phone'] def test_list_users(new_user_list): + err_msg_template = ( + 'Missing {field} field. A common cause would be forgetting to add the "Firebase ' + + 'Authentication Admin" permission. See instructions in CONTRIBUTING.md') + fetched = [] # Test exporting all user accounts. page = auth.list_users() @@ -228,8 +232,10 @@ def test_list_users(new_user_list): assert isinstance(user, auth.ExportedUserRecord) if user.uid in new_user_list: fetched.append(user.uid) - assert user.password_hash is not None - assert user.password_salt is not None + assert user.password_hash is not None, ( + err_msg_template.format(field='password_hash')) + assert user.password_salt is not None, ( + err_msg_template.format(field='password_salt')) page = page.get_next_page() assert len(fetched) == len(new_user_list) @@ -239,8 +245,10 @@ def test_list_users(new_user_list): assert isinstance(user, auth.ExportedUserRecord) if user.uid in new_user_list: fetched.append(user.uid) - assert user.password_hash is not None - assert user.password_salt is not None + assert user.password_hash is not None, ( + err_msg_template.format(field='password_hash')) + assert user.password_salt is not None, ( + err_msg_template.format(field='password_salt')) assert len(fetched) == len(new_user_list) def test_create_user(new_user): diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index 3c19a98d8..8b1bab133 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -14,6 +14,7 @@ """Test cases for the firebase_admin._user_mgt module.""" +import base64 import json import time @@ -152,6 +153,13 @@ def test_exported_record_empty_password(self): assert user.password_hash == '' assert user.password_salt == '' + def test_redacted_passwords_cleared(self): + user = auth.ExportedUserRecord({ + 'localId': 'user', + 'passwordHash': base64.b64encode(b'REDACTED'), + }) + assert user.password_hash is None + def test_custom_claims(self): user = auth.UserRecord({ 'localId' : 'user', From 69264dc7871a58146f89b48a2c3c352c3881bf03 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Thu, 10 Oct 2019 10:57:57 -0700 Subject: [PATCH 031/226] feat(auth): Added InsufficientPermissionError type (#354) * Added InsufficientPermissionError type * Fixing lint error --- firebase_admin/_auth_utils.py | 13 +++++++++++++ firebase_admin/auth.py | 2 ++ tests/test_user_mgt.py | 15 +++++++++++++++ 3 files changed, 30 insertions(+) diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index bdba9f81d..df3e0acfc 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -211,6 +211,18 @@ def __init__(self, message, cause, http_response): exceptions.AlreadyExistsError.__init__(self, message, cause, http_response) +class InsufficientPermissionError(exceptions.PermissionDeniedError): + """The credential used to initialize the SDK lacks required permissions.""" + + default_message = ('The credential used to initialize the SDK has insufficient ' + 'permissions to perform the requested operation. See ' + 'https://firebase.google.com/docs/admin/setup for details ' + 'on how to initialize the Admin SDK with appropriate permissions') + + def __init__(self, message, cause, http_response): + exceptions.PermissionDeniedError.__init__(self, message, cause, http_response) + + class InvalidDynamicLinkDomainError(exceptions.InvalidArgumentError): """Dynamic link domain in ActionCodeSettings is not authorized.""" @@ -258,6 +270,7 @@ def __init__(self, message, cause=None, http_response=None): 'DUPLICATE_EMAIL': EmailAlreadyExistsError, 'DUPLICATE_LOCAL_ID': UidAlreadyExistsError, 'EMAIL_EXISTS': EmailAlreadyExistsError, + 'INSUFFICIENT_PERMISSION': InsufficientPermissionError, 'INVALID_DYNAMIC_LINK_DOMAIN': InvalidDynamicLinkDomainError, 'INVALID_ID_TOKEN': InvalidIdTokenError, 'PHONE_NUMBER_EXISTS': PhoneNumberAlreadyExistsError, diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index ebc133d4c..a5110c211 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -43,6 +43,7 @@ 'ExpiredSessionCookieError', 'ExportedUserRecord', 'ImportUserRecord', + 'InsufficientPermissionError', 'InvalidDynamicLinkDomainError', 'InvalidIdTokenError', 'InvalidSessionCookieError', @@ -89,6 +90,7 @@ ExpiredSessionCookieError = _token_gen.ExpiredSessionCookieError ExportedUserRecord = _user_mgt.ExportedUserRecord ImportUserRecord = _user_import.ImportUserRecord +InsufficientPermissionError = _auth_utils.InsufficientPermissionError InvalidDynamicLinkDomainError = _auth_utils.InvalidDynamicLinkDomainError InvalidIdTokenError = _auth_utils.InvalidIdTokenError InvalidSessionCookieError = _token_gen.InvalidSessionCookieError diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index 8b1bab133..b213fce1b 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -743,6 +743,21 @@ def test_list_users_error(self, user_mgt_app): auth.list_users(app=user_mgt_app) assert str(excinfo.value) == 'Unexpected error response: {"error":"test"}' + def test_permission_error(self, user_mgt_app): + _instrument_user_manager( + user_mgt_app, 400, '{"error": {"message": "INSUFFICIENT_PERMISSION"}}') + with pytest.raises(auth.InsufficientPermissionError) as excinfo: + auth.list_users(app=user_mgt_app) + assert isinstance(excinfo.value, exceptions.PermissionDeniedError) + msg = ('The credential used to initialize the SDK has insufficient ' + 'permissions to perform the requested operation. See ' + 'https://firebase.google.com/docs/admin/setup for details ' + 'on how to initialize the Admin SDK with appropriate permissions ' + '(INSUFFICIENT_PERMISSION).') + assert str(excinfo.value) == msg + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None + def _check_page(self, page): assert isinstance(page, auth.ListUsersPage) index = 0 From 972cda0eebb8b4285e735a635f56e579e0dfc93e Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Wed, 16 Oct 2019 10:48:08 -0700 Subject: [PATCH 032/226] Removing the CHANGELOG file (#356) --- CHANGELOG.md | 285 ------------------------------------- scripts/prepare_release.sh | 13 +- 2 files changed, 4 insertions(+), 294 deletions(-) delete mode 100644 CHANGELOG.md diff --git a/CHANGELOG.md b/CHANGELOG.md deleted file mode 100644 index 751639ff6..000000000 --- a/CHANGELOG.md +++ /dev/null @@ -1,285 +0,0 @@ -# Unreleased - -- - -# v3.0.0 - - -- [added] Added the new `firebase_admin.exceptions` module containing the - base exception types and global error codes. -- [changed] Updated the `firebase_admin.instance_id` module to use the new - shared exception types. The type `instance_id.ApiCallError` was removed. - -# v2.18.0 - -- [added] Added support for specifying the analytics label for notifications. -- [added] Added support for arbitrary key-value pairs in `messaging.ApsAlert`. -- [changed] The `WebpushFcmOptions` type is now deprecated. Developers should use - the PEP8 compliant type name `WebpushFCMOptions` instead. -- [added] Developers can now test their Database API calls by directing the - SDK traffic to the RTDB emulator. Set the `FIREBASE_DATABASE_EMULATOR_HOST` - environment variable to specify the emulator endpoint in `host:port` format. - -# v2.17.0 - -- [added] Added new `send_all()` and `send_multicast()` APIs to the - `messasing` module. -- [added] Added a new `auth.DELETE_ATTRIBUTE` sentinel value, which can be - used to delete `phone_number`, `display_name`, `photo_url` and `custom_claims` - attributes from a user account. It is now recommended to use this sentinel - value over passing `None` for deleting attributes. - -# v2.16.0 - -- [added] Added `generate_password_reset_link()`, - `generate_email_verification_link()` and `generate_sign_in_with_email_link()` - methods to the `auth` API. -- [added] Migrated the `auth` user management API to the - new Identity Toolkit endpoint. -- [fixed] Extending HTTP retries to more HTTP methods like POST and PATCH. - -# v2.15.1 - -- [added] Implemented HTTP retries. The SDK now retries HTTP calls on - low-level connection and socket read errors, as well as HTTP 500 and - 503 errors. - -# v2.15.0 - -- [changed] Taking a direct dependency on `google-api-core[grpc]` in order to - resolve some long standing Firestore installation problems. -- `messaging.WebpushConfig` class now supports configuring additional - [added] FCM options for the features supported by the web SDK. A new - `messaging.WebpushFcmOptions` type has been introduced for this - purpose. -- [added] `messaging.Aps` class now supports configuring a critical sound. A - new `messaging.CriticalSound` class has been introduced for this purpose. -- [changed] Dropped support for Python 3.3. - -# v2.14.0 - -- [added] A new `project_management` API for managing apps in a - project. -- [added] `messaging.AndroidNotification` type now supports `channel_id`. -- [fixed] FCM errors sent by the back-end now include more details - that are helpful when debugging problems. -- [fixed] Fixing error handling in FCM. The SDK now checks the key - type.googleapis.com/google.firebase.fcm.v1.FcmError to set error code. -- [fixed] Ensuring that `UserRecord.tokens_valid_after_time` always - returns an integer, and never returns `None`. -- [fixed] Fixing a performance issue in the `db.listen()` API - where it was taking a long time to process large RTDB nodes. - -# v2.13.0 - -- [added] The `db.Reference` type now provides a `listen()` API for - receiving realtime update events from the Firebase Database. -- [added] The `db.reference()` method now optionally takes a `url` - parameter. This can be used to access multiple Firebase Databases - in the same project more easily. -- [added] The `messaging.WebpushNotification` type now supports - additional parameters. - -# v2.12.0 - -- [added] Implemented the ability to create custom tokens without - service account credentials. -- [added] Admin SDK can now read the project ID from both `GCLOUD_PROJECT` and - `GOOGLE_CLOUD_PROJECT` environment variables. - -# v2.11.0 - -- [added] A new `auth.import_users()` API for importing users into Firebase - Auth in bulk. -- [fixed] The `db.Reference.update()` function now accepts dictionaries with - `None` values. This can be used to delete child keys from a reference. - -# v2.10.0 - -- [added] A new `create_session_cookie()` method for creating a long-lived - session cookie given a valid ID token. -- [added] A new `verify_session_cookie()` method for verifying a given - cookie string is valid. -- [added] `auth` module now caches the public key certificates used to - verify ID tokens and sessions cookies. This enables the SDK to avoid - making a network call everytime a credential needs to be verified. -- [added] Added the `mutable_content` optional field to the `messaging.Aps` - type. -- [added] Added support for specifying arbitrary custom key-value - fields in the `messaging.Aps` type. - -# v2.9.1 - -### Cloud Messaging - -- [changed] Improved error handling in FCM by mapping more server-side - errors to client-side error codes. See [documentation](https://firebase.google.com/docs/cloud-messaging/admin/errors). -- [changed] The `messaging` module now supports specifying an HTTP timeout - for all egress requests. Pass the `httpTimeout` option - to `firebase_admin.initialize_app()` before invoking any functions in - `messaging`. - -# v2.9.0 - -### Cloud Messaging - -- [feature] Added the `firebase_admin.messaging` module for sending - Firebase notifications and managing topic subscriptions. - -### Authentication - -- [added] The ['verify_id_token()'](https://firebase.google.com/docs/reference/admin/python/firebase_admin.auth#verify_id_token) - function now accepts an optional `check_revoked` parameter. When `True`, an - additional check is performed to see whether the token has been revoked. -- [added] A new - ['auth.revoke_refresh_tokens(uid)'](https://firebase.google.com/docs/reference/admin/python/firebase_admin.auth#revoke_refresh_tokens) - function has been added to invalidate all tokens issued to a user. -- [added] A new `tokens_valid_after_timestamp` property has been added to the - ['UserRecord'](https://firebase.google.com/docs/reference/admin/python/firebase_admin.auth#userrecord), - class indicating the time before which tokens are not valid. - -# v2.8.0 - -### Initialization - -- [added] The [`initialize_app()`](https://firebase.google.com/docs/reference/admin/python/firebase_admin#initialize_app) - method can now be invoked without any arguments. This initializes an app - using Google Application Default Credentials, and other - options loaded from the `FIREBASE_CONFIG` environment variable. - -### Realtime Database - -- [added] The [`db.Reference.get()`](https://firebase.google.com/docs/reference/admin/python/firebase_admin.db#reference) - method now accepts an optional `shallow` - argument. If set to `True` this causes the SDK to execute a shallow read, - which does not retrieve the child node values of the current reference. - -# v2.7.0 - -- [added] A new [`instance_id`](https://firebase.google.com/docs/reference/admin/python/firebase_admin.instance_id) - API that facilitates deleting instance IDs and associated user data from - Firebase projects. - -# v2.6.0 - -### Authentication - -- [added] Added the - [`list_users()`](https://firebase.google.com/docs/reference/admin/python/firebase_admin.auth#list_users) - function to the `firebase_admin.auth` module. This function enables listing - or iterating over all user accounts in a Firebase project. -- [added] Added the - [`set_custom_user_claims()`](https://firebase.google.com/docs/reference/admin/python/firebase_admin.auth#set_custom_user_claims) - function to the `firebase_admin.auth` module. This function enables setting - custom claims on a Firebase user. The custom claims can be accessed via that - user's ID token. - -### Realtime Database - -- [changed] Updated the `start_at()`, `end_at()` and `equal_to()` methods of - the [`db.Query`](https://firebase.google.com/docs/reference/admin/python/firebase_admin.db#query) class - so they can accept empty string arguments. - -# v2.5.0 - -- [added] A new [`Firestore` API](https://firebase.google.com/docs/reference/admin/python/firebase_admin.firestore) - that enables access to [Cloud Firestore](https://firebase.google.com/docs/firestore) databases. - -# v2.4.0 - -### Realtime Database - -- [added] The [`db.Reference`](https://firebase.google.com/docs/reference/admin/python/firebase_admin.db#reference) - class now has a `get_if_changed()` method, which retrieves a - database value only if the value has changed since last read. -- [added] The options dictionary passed to - [`initialize_app()`](https://firebase.google.com/docs/reference/admin/python/firebase_admin#initialize_app) - function can now contain an `httpTimeout` option, which sets - the timeout (in seconds) for outbound HTTP connections started by the SDK. - -# v2.3.0 - -### Realtime Database - -- [added] You can now get the ETag value of a database reference by passing - `etag=True` to the `get()` method of a - [`db.Reference`](https://firebase.google.com/docs/reference/admin/python/firebase_admin.db#reference) - object. -- [added] The [`db.Reference`](https://firebase.google.com/docs/reference/admin/python/firebase_admin.db#reference) - class now has a `set_if_unchanged()` method, which you can use to write to a - database location only when the location has the ETag value you specify. -- [changed] Fixed an issue with the `transaction()` method that prevented you - from updating scalar values in a transaction. - -# v2.2.0 - -- [added] A new [Cloud Storage API](https://firebase.google.com/docs/reference/admin/python/firebase_admin.storage) - that facilitates accessing Google Cloud Storage buckets using the - [`google-cloud-storage`](https://googlecloudplatform.github.io/google-cloud-python/stable/storage/client.html) - library. - -### Authentication -- [added] A new user management API that allows provisioning and managing - Firebase users from Python applications. This API adds `get_user()`, - `get_user_by_email()`, `get_user_by_phone_number()`, `create_user()`, - `update_user()` and `delete_dser()` methods - to the [`firebase_admin.auth`](https://firebase.google.com/docs/reference/admin/python/firebase_admin.auth) - module. - -### Realtime Database -- [added] The [`db.Reference`](https://firebase.google.com/docs/reference/admin/python/firebase_admin.db#reference) - class now exposes a `transaction()` method, which can be used to execute atomic updates - on database references. - -# v2.1.1 - -- [changed] Constructors of - [`Certificate`](https://firebase.google.com/docs/reference/admin/python/firebase_admin.credentials#certificate) and - [`RefreshToken`](https://firebase.google.com/docs/reference/admin/python/firebase_admin.credentials#refreshtoken) - credential types can now be invoked with either a file path or a parsed JSON object. - This facilitates the consumption of service account credentials and refresh token - credentials from sources other than the local file system. -- [changed] Better integration with the `google-auth` library for making authenticated - HTTP requests from the SDK. - -# v2.1.0 - -- [added] A new [database API](https://firebase.google.com/docs/reference/admin/python/firebase_admin.db) - that facilitates basic data manipulation - operations (create, read, update and delete), and advanced queries. Currently, - this API does not support realtime event listeners. See - [Add the Firebase Admin SDK to your Server](/docs/admin/setup/) - to get started. - -# v2.0.0 - -- [changed] This SDK has been migrated from `oauth2client` to the new - `google-auth` library. - -### Authentication -- [changed] This SDK now supports verifying ID tokens when initialized with - application default credentials. - - -# v1.0.0 - -- [added] Initial release of the Admin Python SDK. See - [Add the Firebase Admin SDK to your Server](https://firebase.google.com/docs/admin/setup/) - to get started. - -### Initialization -- [added] Implemented the - [`firebase_admin`](https://firebase.google.com/docs/reference/admin/python/firebase_admin) - module, which provides the `initialize_app()` function for initializing the - SDK with a credential. -- [added] Implemented the - [`firebase_admin.credentials`](https://firebase.google.com/docs/reference/admin/python/firebase_admin.credentials) - module, which contains constructors for `Certificate`, `ApplicationDefault` - and `RefreshToken` credential types. - -### Authentication -- [added] Implemented the - [`firebase_admin.auth`](https://firebase.google.com/docs/reference/admin/python/firebase_admin.auth) - module, which provides `create_custom_token()` and `verify_id_token()` - functions for minting custom authentication tokens and verifying Firebase ID - tokens. diff --git a/scripts/prepare_release.sh b/scripts/prepare_release.sh index ae9747c45..ca30d9043 100755 --- a/scripts/prepare_release.sh +++ b/scripts/prepare_release.sh @@ -114,23 +114,18 @@ if [[ $(git status --porcelain) ]]; then fi -################################## -# UPDATE VERSION AND CHANGELOG # -################################## +#################### +# UPDATE VERSION # +#################### HOST=$(uname) -echo "[INFO] Updating __about__.py and CHANGELOG.md" +echo "[INFO] Updating __about__.py" if [ $HOST == "Darwin" ]; then sed -i "" -e "s/__version__ = '$CUR_VERSION'/__version__ = '$VERSION'/" "../firebase_admin/__about__.py" - sed -i "" -e "1 s/# Unreleased//" "../CHANGELOG.md" else sed -i -e "s/__version__ = '$CUR_VERSION'/__version__ = '$VERSION'/" "../firebase_admin/__about__.py" - sed -i -e "1 s/# Unreleased//" "../CHANGELOG.md" fi -echo -e "# Unreleased\n\n-\n\n# v${VERSION}" | cat - ../CHANGELOG.md > TEMP_CHANGELOG.md -mv TEMP_CHANGELOG.md ../CHANGELOG.md - ################## # LAUNCH TESTS # From d625dddfcde190c0ab54f16d0553a7f141529501 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Wed, 23 Oct 2019 12:44:23 -0700 Subject: [PATCH 033/226] Bumped version to 3.1.0 (#358) --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index 546b6cb7c..04a662b25 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '3.0.0' +__version__ = '3.1.0' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 1199712c6ec92c9d9e49b64c30722fac82116e56 Mon Sep 17 00:00:00 2001 From: rsgowman Date: Tue, 5 Nov 2019 10:43:47 -0500 Subject: [PATCH 034/226] Reject rounds=0 for SHA1 hashes (#361) Port of https://github.com/firebase/firebase-admin-node/pull/677 --- firebase_admin/_user_import.py | 37 +++++++++++++++----------- tests/test_user_mgt.py | 48 ++++++++++++++++++---------------- 2 files changed, 48 insertions(+), 37 deletions(-) diff --git a/firebase_admin/_user_import.py b/firebase_admin/_user_import.py index 1794548f7..86252ffb8 100644 --- a/firebase_admin/_user_import.py +++ b/firebase_admin/_user_import.py @@ -282,11 +282,6 @@ def _hmac(cls, name, key): } return UserImportHash(name, data) - @classmethod - def _basic_hash(cls, name, rounds): - data = {'rounds': _auth_utils.validate_int(rounds, 'rounds', 0, 120000)} - return UserImportHash(name, data) - @classmethod def hmac_sha512(cls, key): """Creates a new HMAC SHA512 algorithm instance. @@ -340,48 +335,56 @@ def md5(cls, rounds): """Creates a new MD5 algorithm instance. Args: - rounds: Number of rounds. Must be an integer between 0 and 120000. + rounds: Number of rounds. Must be an integer between 0 and 8192. Returns: UserImportHash: A new ``UserImportHash``. """ - return cls._basic_hash('MD5', rounds) + return UserImportHash( + 'MD5', + {'rounds': _auth_utils.validate_int(rounds, 'rounds', 0, 8192)}) @classmethod def sha1(cls, rounds): """Creates a new SHA1 algorithm instance. Args: - rounds: Number of rounds. Must be an integer between 0 and 120000. + rounds: Number of rounds. Must be an integer between 1 and 8192. Returns: UserImportHash: A new ``UserImportHash``. """ - return cls._basic_hash('SHA1', rounds) + return UserImportHash( + 'SHA1', + {'rounds': _auth_utils.validate_int(rounds, 'rounds', 1, 8192)}) @classmethod def sha256(cls, rounds): """Creates a new SHA256 algorithm instance. Args: - rounds: Number of rounds. Must be an integer between 0 and 120000. + rounds: Number of rounds. Must be an integer between 1 and 8192. Returns: UserImportHash: A new ``UserImportHash``. """ - return cls._basic_hash('SHA256', rounds) + return UserImportHash( + 'SHA256', + {'rounds': _auth_utils.validate_int(rounds, 'rounds', 1, 8192)}) @classmethod def sha512(cls, rounds): """Creates a new SHA512 algorithm instance. Args: - rounds: Number of rounds. Must be an integer between 0 and 120000. + rounds: Number of rounds. Must be an integer between 1 and 8192. Returns: UserImportHash: A new ``UserImportHash``. """ - return cls._basic_hash('SHA512', rounds) + return UserImportHash( + 'SHA512', + {'rounds': _auth_utils.validate_int(rounds, 'rounds', 1, 8192)}) @classmethod def pbkdf_sha1(cls, rounds): @@ -393,7 +396,9 @@ def pbkdf_sha1(cls, rounds): Returns: UserImportHash: A new ``UserImportHash``. """ - return cls._basic_hash('PBKDF_SHA1', rounds) + return UserImportHash( + 'PBKDF_SHA1', + {'rounds': _auth_utils.validate_int(rounds, 'rounds', 0, 120000)}) @classmethod def pbkdf2_sha256(cls, rounds): @@ -405,7 +410,9 @@ def pbkdf2_sha256(cls, rounds): Returns: UserImportHash: A new ``UserImportHash``. """ - return cls._basic_hash('PBKDF2_SHA256', rounds) + return UserImportHash( + 'PBKDF2_SHA256', + {'rounds': _auth_utils.validate_int(rounds, 'rounds', 0, 120000)}) @classmethod def scrypt(cls, key, rounds, memory_cost, salt_separator=None): diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index b213fce1b..f4e03cc3f 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -933,31 +933,35 @@ def test_invalid_hmac(self, func, key): with pytest.raises(ValueError): func(key=key) - @pytest.mark.parametrize('func,name', [ - (auth.UserImportHash.sha512, 'SHA512'), - (auth.UserImportHash.sha256, 'SHA256'), - (auth.UserImportHash.sha1, 'SHA1'), - (auth.UserImportHash.md5, 'MD5'), - (auth.UserImportHash.pbkdf_sha1, 'PBKDF_SHA1'), - (auth.UserImportHash.pbkdf2_sha256, 'PBKDF2_SHA256'), + @pytest.mark.parametrize('func,name,rounds', [ + (auth.UserImportHash.md5, 'MD5', [0, 8192]), + (auth.UserImportHash.sha1, 'SHA1', [1, 8192]), + (auth.UserImportHash.sha256, 'SHA256', [1, 8192]), + (auth.UserImportHash.sha512, 'SHA512', [1, 8192]), + (auth.UserImportHash.pbkdf_sha1, 'PBKDF_SHA1', [0, 120000]), + (auth.UserImportHash.pbkdf2_sha256, 'PBKDF2_SHA256', [0, 120000]), ]) - def test_basic(self, func, name): - basic = func(rounds=10) - expected = { - 'hashAlgorithm': name, - 'rounds': 10, - } - assert basic.to_dict() == expected - - @pytest.mark.parametrize('func', [ - auth.UserImportHash.sha512, auth.UserImportHash.sha256, - auth.UserImportHash.sha1, auth.UserImportHash.md5, - auth.UserImportHash.pbkdf_sha1, auth.UserImportHash.pbkdf2_sha256, + def test_basic(self, func, name, rounds): + for rnds in rounds: + basic = func(rounds=rnds) + expected = { + 'hashAlgorithm': name, + 'rounds': rnds, + } + assert basic.to_dict() == expected + + @pytest.mark.parametrize('func,rounds', [ + (auth.UserImportHash.md5, INVALID_INTS + [-1, 8193]), + (auth.UserImportHash.sha1, INVALID_INTS + [0, 8193]), + (auth.UserImportHash.sha256, INVALID_INTS + [0, 8193]), + (auth.UserImportHash.sha512, INVALID_INTS + [0, 8193]), + (auth.UserImportHash.pbkdf_sha1, INVALID_INTS + [-1, 120001]), + (auth.UserImportHash.pbkdf2_sha256, INVALID_INTS + [-1, 120001]), ]) - @pytest.mark.parametrize('rounds', INVALID_INTS + [120001]) def test_invalid_basic(self, func, rounds): - with pytest.raises(ValueError): - func(rounds=rounds) + for rnds in rounds: + with pytest.raises(ValueError): + func(rounds=rnds) def test_scrypt(self): scrypt = auth.UserImportHash.scrypt( From c6080e41a2817175704d0216d8173fb4ee983913 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Thu, 7 Nov 2019 10:21:41 -0800 Subject: [PATCH 035/226] Increased FCM batch request limit to 500 (#362) --- firebase_admin/_messaging_utils.py | 4 ++-- firebase_admin/messaging.py | 6 +++--- integration/test_messaging.py | 8 ++++---- tests/test_messaging.py | 21 ++++++++++++--------- 4 files changed, 21 insertions(+), 18 deletions(-) diff --git a/firebase_admin/_messaging_utils.py b/firebase_admin/_messaging_utils.py index f0bc969eb..28d283d73 100644 --- a/firebase_admin/_messaging_utils.py +++ b/firebase_admin/_messaging_utils.py @@ -77,8 +77,8 @@ class MulticastMessage(object): def __init__(self, tokens, data=None, notification=None, android=None, webpush=None, apns=None, fcm_options=None): _Validators.check_string_list('MulticastMessage.tokens', tokens) - if len(tokens) > 100: - raise ValueError('MulticastMessage.tokens must not contain more than 100 tokens.') + if len(tokens) > 500: + raise ValueError('MulticastMessage.tokens must not contain more than 500 tokens.') self.tokens = tokens self.data = data self.notification = notification diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index cbd3522fa..e7062ba04 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -356,9 +356,9 @@ def send(self, message, dry_run=False): def send_all(self, messages, dry_run=False): """Sends the given messages to FCM via the batch API.""" if not isinstance(messages, list): - raise ValueError('Messages must be an list of messaging.Message instances.') - if len(messages) > 100: - raise ValueError('send_all messages must not contain more than 100 messages.') + raise ValueError('messages must be a list of messaging.Message instances.') + if len(messages) > 500: + raise ValueError('messages must not contain more than 500 elements.') responses = [] diff --git a/integration/test_messaging.py b/integration/test_messaging.py index 21f9d9669..01e1d212a 100644 --- a/integration/test_messaging.py +++ b/integration/test_messaging.py @@ -101,17 +101,17 @@ def test_send_all(): assert isinstance(response.exception, exceptions.InvalidArgumentError) assert response.message_id is None -def test_send_one_hundred(): +def test_send_all_500(): messages = [] - for msg_number in range(100): + for msg_number in range(500): topic = 'foo-bar-{0}'.format(msg_number % 10) messages.append(messaging.Message(topic=topic)) batch_response = messaging.send_all(messages, dry_run=True) - assert batch_response.success_count == 100 + assert batch_response.success_count == 500 assert batch_response.failure_count == 0 - assert len(batch_response.responses) == 100 + assert len(batch_response.responses) == 500 for response in batch_response.responses: assert response.success is True assert response.exception is None diff --git a/tests/test_messaging.py b/tests/test_messaging.py index dbfe5d2c0..04ef36d8c 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -102,15 +102,18 @@ def test_invalid_tokens_type(self, tokens): expected = 'MulticastMessage.tokens must be a list of strings.' assert str(excinfo.value) == expected - def test_tokens_over_one_hundred(self): + def test_tokens_over_500(self): with pytest.raises(ValueError) as excinfo: - messaging.MulticastMessage(tokens=['token' for _ in range(0, 101)]) - expected = 'MulticastMessage.tokens must not contain more than 100 tokens.' + messaging.MulticastMessage(tokens=['token' for _ in range(0, 501)]) + expected = 'MulticastMessage.tokens must not contain more than 500 tokens.' assert str(excinfo.value) == expected def test_tokens_type(self): - messaging.MulticastMessage(tokens=['token']) - messaging.MulticastMessage(tokens=['token' for _ in range(0, 100)]) + message = messaging.MulticastMessage(tokens=['token']) + assert len(message.tokens) == 1 + + message = messaging.MulticastMessage(tokens=['token' for _ in range(0, 500)]) + assert len(message.tokens) == 500 class TestMessageEncoder(object): @@ -1598,14 +1601,14 @@ def test_invalid_send_all(self, msg): expected = 'Message must be an instance of messaging.Message class.' assert str(excinfo.value) == expected else: - expected = 'Messages must be an list of messaging.Message instances.' + expected = 'messages must be a list of messaging.Message instances.' assert str(excinfo.value) == expected - def test_invalid_over_one_hundred(self): + def test_invalid_over_500(self): msg = messaging.Message(topic='foo') with pytest.raises(ValueError) as excinfo: - messaging.send_all([msg for _ in range(0, 101)]) - expected = 'send_all messages must not contain more than 100 messages.' + messaging.send_all([msg for _ in range(0, 501)]) + expected = 'messages must not contain more than 500 elements.' assert str(excinfo.value) == expected def test_send_all(self): From 85803619d10198c1c4fe58b0cb86e7181403d9f9 Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Wed, 13 Nov 2019 12:05:51 -0500 Subject: [PATCH 036/226] feat(fcm): Add 12 new Android Notification Parameters Support (#363) * feat(fcm): Add 12 new Android Notification Parameters Support * Move message encoders and validators to a separate module * Fix cyclic import * PR Fixes * PR Fixes * Code-font class names and other PR fixes --- firebase_admin/_messaging_encoder.py | 701 +++++++++++++++++++++++++++ firebase_admin/_messaging_utils.py | 639 ++++-------------------- firebase_admin/messaging.py | 9 +- integration/test_messaging.py | 13 +- tests/test_messaging.py | 213 +++++++- 5 files changed, 1015 insertions(+), 560 deletions(-) create mode 100644 firebase_admin/_messaging_encoder.py diff --git a/firebase_admin/_messaging_encoder.py b/firebase_admin/_messaging_encoder.py new file mode 100644 index 000000000..1177ffb65 --- /dev/null +++ b/firebase_admin/_messaging_encoder.py @@ -0,0 +1,701 @@ +# Copyright 2019 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Encoding and validation utils for the messaging (FCM) module.""" + +import datetime +import json +import math +import numbers +import re + +import six + +import firebase_admin._messaging_utils as _messaging_utils + + +class Message(object): + """A message that can be sent via Firebase Cloud Messaging. + + Contains payload information as well as recipient information. In particular, the message must + contain exactly one of token, topic or condition fields. + + Args: + data: A dictionary of data fields (optional). All keys and values in the dictionary must be + strings. + notification: An instance of ``messaging.Notification`` (optional). + android: An instance of ``messaging.AndroidConfig`` (optional). + webpush: An instance of ``messaging.WebpushConfig`` (optional). + apns: An instance of ``messaging.ApnsConfig`` (optional). + fcm_options: An instance of ``messaging.FCMOptions`` (optional). + token: The registration token of the device to which the message should be sent (optional). + topic: Name of the FCM topic to which the message should be sent (optional). Topic name + may contain the ``/topics/`` prefix. + condition: The FCM condition to which the message should be sent (optional). + """ + + def __init__(self, data=None, notification=None, android=None, webpush=None, apns=None, + fcm_options=None, token=None, topic=None, condition=None): + self.data = data + self.notification = notification + self.android = android + self.webpush = webpush + self.apns = apns + self.fcm_options = fcm_options + self.token = token + self.topic = topic + self.condition = condition + + def __str__(self): + return json.dumps(self, cls=MessageEncoder, sort_keys=True) + + +class MulticastMessage(object): + """A message that can be sent to multiple tokens via Firebase Cloud Messaging. + + Args: + tokens: A list of registration tokens of targeted devices. + data: A dictionary of data fields (optional). All keys and values in the dictionary must be + strings. + notification: An instance of ``messaging.Notification`` (optional). + android: An instance of ``messaging.AndroidConfig`` (optional). + webpush: An instance of ``messaging.WebpushConfig`` (optional). + apns: An instance of ``messaging.ApnsConfig`` (optional). + fcm_options: An instance of ``messaging.FCMOptions`` (optional). + """ + def __init__(self, tokens, data=None, notification=None, android=None, webpush=None, apns=None, + fcm_options=None): + _Validators.check_string_list('MulticastMessage.tokens', tokens) + if len(tokens) > 500: + raise ValueError('MulticastMessage.tokens must not contain more than 500 tokens.') + self.tokens = tokens + self.data = data + self.notification = notification + self.android = android + self.webpush = webpush + self.apns = apns + self.fcm_options = fcm_options + + +class _Validators(object): + """A collection of data validation utilities. + + Methods provided in this class raise ``ValueErrors`` if any validations fail. + """ + + @classmethod + def check_string(cls, label, value, non_empty=False): + """Checks if the given value is a string.""" + if value is None: + return None + if not isinstance(value, six.string_types): + if non_empty: + raise ValueError('{0} must be a non-empty string.'.format(label)) + else: + raise ValueError('{0} must be a string.'.format(label)) + if non_empty and not value: + raise ValueError('{0} must be a non-empty string.'.format(label)) + return value + + @classmethod + def check_number(cls, label, value): + if value is None: + return None + if not isinstance(value, numbers.Number): + raise ValueError('{0} must be a number.'.format(label)) + return value + + @classmethod + def check_string_dict(cls, label, value): + """Checks if the given value is a dictionary comprised only of string keys and values.""" + if value is None or value == {}: + return None + if not isinstance(value, dict): + raise ValueError('{0} must be a dictionary.'.format(label)) + non_str = [k for k in value if not isinstance(k, six.string_types)] + if non_str: + raise ValueError('{0} must not contain non-string keys.'.format(label)) + non_str = [v for v in value.values() if not isinstance(v, six.string_types)] + if non_str: + raise ValueError('{0} must not contain non-string values.'.format(label)) + return value + + @classmethod + def check_string_list(cls, label, value): + """Checks if the given value is a list comprised only of strings.""" + if value is None or value == []: + return None + if not isinstance(value, list): + raise ValueError('{0} must be a list of strings.'.format(label)) + non_str = [k for k in value if not isinstance(k, six.string_types)] + if non_str: + raise ValueError('{0} must not contain non-string values.'.format(label)) + return value + + @classmethod + def check_number_list(cls, label, value): + """Checks if the given value is a list comprised only of numbers.""" + if value is None or value == []: + return None + if not isinstance(value, list): + raise ValueError('{0} must be a list of numbers.'.format(label)) + non_number = [k for k in value if not isinstance(k, numbers.Number)] + if non_number: + raise ValueError('{0} must not contain non-number values.'.format(label)) + return value + + @classmethod + def check_analytics_label(cls, label, value): + """Checks if the given value is a valid analytics label.""" + value = _Validators.check_string(label, value) + if value is not None and not re.match(r'^[a-zA-Z0-9-_.~%]{1,50}$', value): + raise ValueError('Malformed {}.'.format(label)) + return value + + @classmethod + def check_datetime(cls, label, value): + """Checks if the given value is a datetime.""" + if value is None: + return None + if not isinstance(value, datetime.datetime): + raise ValueError('{0} must be a datetime.'.format(label)) + return value + + +class MessageEncoder(json.JSONEncoder): + """A custom ``JSONEncoder`` implementation for serializing Message instances into JSON.""" + + @classmethod + def remove_null_values(cls, dict_value): + return {k: v for k, v in dict_value.items() if v not in [None, [], {}]} + + @classmethod + def encode_android(cls, android): + """Encodes an ``AndroidConfig`` instance into JSON.""" + if android is None: + return None + if not isinstance(android, _messaging_utils.AndroidConfig): + raise ValueError('Message.android must be an instance of AndroidConfig class.') + result = { + 'collapse_key': _Validators.check_string( + 'AndroidConfig.collapse_key', android.collapse_key), + 'data': _Validators.check_string_dict( + 'AndroidConfig.data', android.data), + 'notification': cls.encode_android_notification(android.notification), + 'priority': _Validators.check_string( + 'AndroidConfig.priority', android.priority, non_empty=True), + 'restricted_package_name': _Validators.check_string( + 'AndroidConfig.restricted_package_name', android.restricted_package_name), + 'ttl': cls.encode_ttl(android.ttl), + 'fcm_options': cls.encode_android_fcm_options(android.fcm_options), + } + result = cls.remove_null_values(result) + priority = result.get('priority') + if priority and priority not in ('high', 'normal'): + raise ValueError('AndroidConfig.priority must be "high" or "normal".') + return result + + @classmethod + def encode_android_fcm_options(cls, fcm_options): + """Encodes an ``AndroidFCMOptions`` instance into JSON.""" + if fcm_options is None: + return None + if not isinstance(fcm_options, _messaging_utils.AndroidFCMOptions): + raise ValueError('AndroidConfig.fcm_options must be an instance of ' + 'AndroidFCMOptions class.') + result = { + 'analytics_label': _Validators.check_analytics_label( + 'AndroidFCMOptions.analytics_label', fcm_options.analytics_label), + } + result = cls.remove_null_values(result) + return result + + @classmethod + def encode_ttl(cls, ttl): + """Encodes an ``AndroidConfig`` ``TTL`` duration into a string.""" + if ttl is None: + return None + if isinstance(ttl, numbers.Number): + ttl = datetime.timedelta(seconds=ttl) + if not isinstance(ttl, datetime.timedelta): + raise ValueError('AndroidConfig.ttl must be a duration in seconds or an instance of ' + 'datetime.timedelta.') + total_seconds = ttl.total_seconds() + if total_seconds < 0: + raise ValueError('AndroidConfig.ttl must not be negative.') + seconds = int(math.floor(total_seconds)) + nanos = int((total_seconds - seconds) * 1e9) + if nanos: + return '{0}.{1}s'.format(seconds, str(nanos).zfill(9)) + return '{0}s'.format(seconds) + + @classmethod + def encode_milliseconds(cls, label, msec): + """Encodes a duration in milliseconds into a string.""" + if msec is None: + return None + if isinstance(msec, numbers.Number): + msec = datetime.timedelta(milliseconds=msec) + if not isinstance(msec, datetime.timedelta): + raise ValueError('{0} must be a duration in milliseconds or an instance of ' + 'datetime.timedelta.'.format(label)) + total_seconds = msec.total_seconds() + if total_seconds < 0: + raise ValueError('{0} must not be negative.'.format(label)) + seconds = int(math.floor(total_seconds)) + nanos = int((total_seconds - seconds) * 1e9) + if nanos: + return '{0}.{1}s'.format(seconds, str(nanos).zfill(9)) + return '{0}s'.format(seconds) + + @classmethod + def encode_boolean(cls, value): + """Encodes a boolean into JSON.""" + if value is None: + return None + return 1 if value else 0 + + @classmethod + def encode_android_notification(cls, notification): + """Encodes an ``AndroidNotification`` instance into JSON.""" + if notification is None: + return None + if not isinstance(notification, _messaging_utils.AndroidNotification): + raise ValueError('AndroidConfig.notification must be an instance of ' + 'AndroidNotification class.') + result = { + 'body': _Validators.check_string( + 'AndroidNotification.body', notification.body), + 'body_loc_args': _Validators.check_string_list( + 'AndroidNotification.body_loc_args', notification.body_loc_args), + 'body_loc_key': _Validators.check_string( + 'AndroidNotification.body_loc_key', notification.body_loc_key), + 'click_action': _Validators.check_string( + 'AndroidNotification.click_action', notification.click_action), + 'color': _Validators.check_string( + 'AndroidNotification.color', notification.color, non_empty=True), + 'icon': _Validators.check_string( + 'AndroidNotification.icon', notification.icon), + 'sound': _Validators.check_string( + 'AndroidNotification.sound', notification.sound), + 'tag': _Validators.check_string( + 'AndroidNotification.tag', notification.tag), + 'title': _Validators.check_string( + 'AndroidNotification.title', notification.title), + 'title_loc_args': _Validators.check_string_list( + 'AndroidNotification.title_loc_args', notification.title_loc_args), + 'title_loc_key': _Validators.check_string( + 'AndroidNotification.title_loc_key', notification.title_loc_key), + 'channel_id': _Validators.check_string( + 'AndroidNotification.channel_id', notification.channel_id), + 'image': _Validators.check_string( + 'image', notification.image), + 'ticker': _Validators.check_string( + 'AndroidNotification.ticker', notification.ticker), + 'sticky': cls.encode_boolean(notification.sticky), + 'event_time': _Validators.check_datetime( + 'AndroidNotification.event_timestamp', notification.event_timestamp), + 'local_only': cls.encode_boolean(notification.local_only), + 'notification_priority': _Validators.check_string( + 'AndroidNotification.priority', notification.priority, non_empty=True), + 'vibrate_timings': _Validators.check_number_list( + 'AndroidNotification.vibrate_timings_millis', notification.vibrate_timings_millis), + 'default_vibrate_timings': cls.encode_boolean(notification.default_vibrate_timings), + 'default_sound': cls.encode_boolean(notification.default_sound), + 'default_light_settings': cls.encode_boolean(notification.default_light_settings), + 'light_settings': cls.encode_light_settings(notification.light_settings), + 'visibility': _Validators.check_string( + 'AndroidNotification.visibility', notification.visibility, non_empty=True), + 'notification_count': _Validators.check_number( + 'AndroidNotification.notification_count', notification.notification_count) + } + result = cls.remove_null_values(result) + color = result.get('color') + if color and not re.match(r'^#[0-9a-fA-F]{6}$', color): + raise ValueError( + 'AndroidNotification.color must be in the form #RRGGBB.') + if result.get('body_loc_args') and not result.get('body_loc_key'): + raise ValueError( + 'AndroidNotification.body_loc_key is required when specifying body_loc_args.') + if result.get('title_loc_args') and not result.get('title_loc_key'): + raise ValueError( + 'AndroidNotification.title_loc_key is required when specifying title_loc_args.') + + event_time = result.get('event_time') + if event_time: + result['event_time'] = str(event_time.isoformat()) + 'Z' + + priority = result.get('notification_priority') + if priority: + if priority not in ('min', 'low', 'default', 'high', 'max'): + raise ValueError('AndroidNotification.priority must be "default", "min", "low", ' + '"high" or "max".') + result['notification_priority'] = 'PRIORITY_' + priority.upper() + + visibility = result.get('visibility') + if visibility: + if visibility not in ('private', 'public', 'secret'): + raise ValueError( + 'AndroidNotification.visibility must be "private", "public" or "secret".') + result['visibility'] = visibility.upper() + + vibrate_timings_millis = result.get('vibrate_timings') + if vibrate_timings_millis: + vibrate_timing_strings = [] + for msec in vibrate_timings_millis: + formated_string = cls.encode_milliseconds( + 'AndroidNotification.vibrate_timings_millis', msec) + vibrate_timing_strings.append(formated_string) + result['vibrate_timings'] = vibrate_timing_strings + return result + + @classmethod + def encode_light_settings(cls, light_settings): + """Encodes a ``LightSettings`` instance into JSON.""" + if light_settings is None: + return None + if not isinstance(light_settings, _messaging_utils.LightSettings): + raise ValueError( + 'AndroidNotification.light_settings must be an instance of LightSettings class.') + result = { + 'color': _Validators.check_string( + 'LightSettings.color', light_settings.color, non_empty=True), + 'light_on_duration': cls.encode_milliseconds( + 'LightSettings.light_on_duration_millis', light_settings.light_on_duration_millis), + 'light_off_duration': cls.encode_milliseconds( + 'LightSettings.light_off_duration_millis', + light_settings.light_off_duration_millis), + } + result = cls.remove_null_values(result) + light_on_duration = result.get('light_on_duration') + if not light_on_duration: + raise ValueError( + 'LightSettings.light_on_duration_millis is required.') + + light_off_duration = result.get('light_off_duration') + if not light_off_duration: + raise ValueError( + 'LightSettings.light_off_duration_millis is required.') + + color = result.get('color') + if not color: + raise ValueError('LightSettings.color is required.') + if not re.match(r'^#[0-9a-fA-F]{6}$', color) and not re.match(r'^#[0-9a-fA-F]{8}$', color): + raise ValueError( + 'LightSettings.color must be in the form #RRGGBB or #RRGGBBAA.') + if len(color) == 7: + color = (color+'FF') + rgba = [int(color[i:i + 2], 16) / 255.0 for i in (1, 3, 5, 7)] + result['color'] = {'red': rgba[0], 'green': rgba[1], + 'blue': rgba[2], 'alpha': rgba[3]} + return result + + @classmethod + def encode_webpush(cls, webpush): + """Encodes a ``WebpushConfig`` instance into JSON.""" + if webpush is None: + return None + if not isinstance(webpush, _messaging_utils.WebpushConfig): + raise ValueError('Message.webpush must be an instance of WebpushConfig class.') + result = { + 'data': _Validators.check_string_dict( + 'WebpushConfig.data', webpush.data), + 'headers': _Validators.check_string_dict( + 'WebpushConfig.headers', webpush.headers), + 'notification': cls.encode_webpush_notification(webpush.notification), + 'fcm_options': cls.encode_webpush_fcm_options(webpush.fcm_options), + } + return cls.remove_null_values(result) + + @classmethod + def encode_webpush_notification(cls, notification): + """Encodes a ``WebpushNotification`` instance into JSON.""" + if notification is None: + return None + if not isinstance(notification, _messaging_utils.WebpushNotification): + raise ValueError('WebpushConfig.notification must be an instance of ' + 'WebpushNotification class.') + result = { + 'actions': cls.encode_webpush_notification_actions(notification.actions), + 'badge': _Validators.check_string( + 'WebpushNotification.badge', notification.badge), + 'body': _Validators.check_string( + 'WebpushNotification.body', notification.body), + 'data': notification.data, + 'dir': _Validators.check_string( + 'WebpushNotification.direction', notification.direction), + 'icon': _Validators.check_string( + 'WebpushNotification.icon', notification.icon), + 'image': _Validators.check_string( + 'WebpushNotification.image', notification.image), + 'lang': _Validators.check_string( + 'WebpushNotification.language', notification.language), + 'renotify': notification.renotify, + 'requireInteraction': notification.require_interaction, + 'silent': notification.silent, + 'tag': _Validators.check_string( + 'WebpushNotification.tag', notification.tag), + 'timestamp': _Validators.check_number( + 'WebpushNotification.timestamp_millis', notification.timestamp_millis), + 'title': _Validators.check_string( + 'WebpushNotification.title', notification.title), + 'vibrate': notification.vibrate, + } + direction = result.get('dir') + if direction and direction not in ('auto', 'ltr', 'rtl'): + raise ValueError('WebpushNotification.direction must be "auto", "ltr" or "rtl".') + if notification.custom_data is not None: + if not isinstance(notification.custom_data, dict): + raise ValueError('WebpushNotification.custom_data must be a dict.') + for key, value in notification.custom_data.items(): + if key in result: + raise ValueError( + 'Multiple specifications for {0} in WebpushNotification.'.format(key)) + result[key] = value + return cls.remove_null_values(result) + + @classmethod + def encode_webpush_notification_actions(cls, actions): + """Encodes a list of ``WebpushNotificationActions`` into JSON.""" + if actions is None: + return None + if not isinstance(actions, list): + raise ValueError('WebpushConfig.notification.actions must be a list of ' + 'WebpushNotificationAction instances.') + results = [] + for action in actions: + if not isinstance(action, _messaging_utils.WebpushNotificationAction): + raise ValueError('WebpushConfig.notification.actions must be a list of ' + 'WebpushNotificationAction instances.') + result = { + 'action': _Validators.check_string( + 'WebpushNotificationAction.action', action.action), + 'title': _Validators.check_string( + 'WebpushNotificationAction.title', action.title), + 'icon': _Validators.check_string( + 'WebpushNotificationAction.icon', action.icon), + } + results.append(cls.remove_null_values(result)) + return results + + @classmethod + def encode_webpush_fcm_options(cls, options): + """Encodes a ``WebpushFCMOptions`` instance into JSON.""" + if options is None: + return None + result = { + 'link': _Validators.check_string('WebpushConfig.fcm_options.link', options.link), + } + result = cls.remove_null_values(result) + link = result.get('link') + if link is not None and not link.startswith('https://'): + raise ValueError('WebpushFCMOptions.link must be a HTTPS URL.') + return result + + @classmethod + def encode_apns(cls, apns): + """Encodes an ``APNSConfig`` instance into JSON.""" + if apns is None: + return None + if not isinstance(apns, _messaging_utils.APNSConfig): + raise ValueError('Message.apns must be an instance of APNSConfig class.') + result = { + 'headers': _Validators.check_string_dict( + 'APNSConfig.headers', apns.headers), + 'payload': cls.encode_apns_payload(apns.payload), + 'fcm_options': cls.encode_apns_fcm_options(apns.fcm_options), + } + return cls.remove_null_values(result) + + @classmethod + def encode_apns_payload(cls, payload): + """Encodes an ``APNSPayload`` instance into JSON.""" + if payload is None: + return None + if not isinstance(payload, _messaging_utils.APNSPayload): + raise ValueError('APNSConfig.payload must be an instance of APNSPayload class.') + result = { + 'aps': cls.encode_aps(payload.aps) + } + for key, value in payload.custom_data.items(): + result[key] = value + return cls.remove_null_values(result) + + @classmethod + def encode_apns_fcm_options(cls, fcm_options): + """Encodes an ``APNSFCMOptions`` instance into JSON.""" + if fcm_options is None: + return None + if not isinstance(fcm_options, _messaging_utils.APNSFCMOptions): + raise ValueError('APNSConfig.fcm_options must be an instance of APNSFCMOptions class.') + result = { + 'analytics_label': _Validators.check_analytics_label( + 'APNSFCMOptions.analytics_label', fcm_options.analytics_label), + 'image': _Validators.check_string('APNSFCMOptions.image', fcm_options.image) + } + result = cls.remove_null_values(result) + return result + + @classmethod + def encode_aps(cls, aps): + """Encodes an ``Aps`` instance into JSON.""" + if not isinstance(aps, _messaging_utils.Aps): + raise ValueError('APNSPayload.aps must be an instance of Aps class.') + result = { + 'alert': cls.encode_aps_alert(aps.alert), + 'badge': _Validators.check_number('Aps.badge', aps.badge), + 'sound': cls.encode_aps_sound(aps.sound), + 'category': _Validators.check_string('Aps.category', aps.category), + 'thread-id': _Validators.check_string('Aps.thread_id', aps.thread_id), + } + if aps.content_available is True: + result['content-available'] = 1 + if aps.mutable_content is True: + result['mutable-content'] = 1 + if aps.custom_data is not None: + if not isinstance(aps.custom_data, dict): + raise ValueError('Aps.custom_data must be a dict.') + for key, val in aps.custom_data.items(): + _Validators.check_string('Aps.custom_data key', key) + if key in result: + raise ValueError('Multiple specifications for {0} in Aps.'.format(key)) + result[key] = val + return cls.remove_null_values(result) + + @classmethod + def encode_aps_sound(cls, sound): + """Encodes an APNs sound configuration into JSON.""" + if sound is None: + return None + if sound and isinstance(sound, six.string_types): + return sound + if not isinstance(sound, _messaging_utils.CriticalSound): + raise ValueError( + 'Aps.sound must be a non-empty string or an instance of CriticalSound class.') + result = { + 'name': _Validators.check_string('CriticalSound.name', sound.name, non_empty=True), + 'volume': _Validators.check_number('CriticalSound.volume', sound.volume), + } + if sound.critical: + result['critical'] = 1 + if not result['name']: + raise ValueError('CriticalSond.name must be a non-empty string.') + volume = result['volume'] + if volume is not None and (volume < 0 or volume > 1): + raise ValueError('CriticalSound.volume must be in the interval [0,1].') + return cls.remove_null_values(result) + + @classmethod + def encode_aps_alert(cls, alert): + """Encodes an ``ApsAlert`` instance into JSON.""" + if alert is None: + return None + if isinstance(alert, six.string_types): + return alert + if not isinstance(alert, _messaging_utils.ApsAlert): + raise ValueError('Aps.alert must be a string or an instance of ApsAlert class.') + result = { + 'title': _Validators.check_string('ApsAlert.title', alert.title), + 'subtitle': _Validators.check_string('ApsAlert.subtitle', alert.subtitle), + 'body': _Validators.check_string('ApsAlert.body', alert.body), + 'title-loc-key': _Validators.check_string( + 'ApsAlert.title_loc_key', alert.title_loc_key), + 'title-loc-args': _Validators.check_string_list( + 'ApsAlert.title_loc_args', alert.title_loc_args), + 'loc-key': _Validators.check_string( + 'ApsAlert.loc_key', alert.loc_key), + 'loc-args': _Validators.check_string_list( + 'ApsAlert.loc_args', alert.loc_args), + 'action-loc-key': _Validators.check_string( + 'ApsAlert.action_loc_key', alert.action_loc_key), + 'launch-image': _Validators.check_string( + 'ApsAlert.launch_image', alert.launch_image), + } + if result.get('loc-args') and not result.get('loc-key'): + raise ValueError( + 'ApsAlert.loc_key is required when specifying loc_args.') + if result.get('title-loc-args') and not result.get('title-loc-key'): + raise ValueError( + 'ApsAlert.title_loc_key is required when specifying title_loc_args.') + if alert.custom_data is not None: + if not isinstance(alert.custom_data, dict): + raise ValueError('ApsAlert.custom_data must be a dict.') + for key, val in alert.custom_data.items(): + _Validators.check_string('ApsAlert.custom_data key', key) + # allow specifying key override because Apple could update API so that key + # could have unexpected value type + result[key] = val + return cls.remove_null_values(result) + + @classmethod + def encode_notification(cls, notification): + """Encodes a ``Notification`` instance into JSON.""" + if notification is None: + return None + if not isinstance(notification, _messaging_utils.Notification): + raise ValueError('Message.notification must be an instance of Notification class.') + result = { + 'body': _Validators.check_string('Notification.body', notification.body), + 'title': _Validators.check_string('Notification.title', notification.title), + 'image': _Validators.check_string('Notification.image', notification.image) + } + return cls.remove_null_values(result) + + @classmethod + def sanitize_topic_name(cls, topic): + if not topic: + return None + prefix = '/topics/' + if topic.startswith(prefix): + topic = topic[len(prefix):] + # Checks for illegal characters and empty string. + if not re.match(r'^[a-zA-Z0-9-_\.~%]+$', topic): + raise ValueError('Malformed topic name.') + return topic + + def default(self, obj): # pylint: disable=method-hidden + if not isinstance(obj, Message): + return json.JSONEncoder.default(self, obj) + result = { + 'android': MessageEncoder.encode_android(obj.android), + 'apns': MessageEncoder.encode_apns(obj.apns), + 'condition': _Validators.check_string( + 'Message.condition', obj.condition, non_empty=True), + 'data': _Validators.check_string_dict('Message.data', obj.data), + 'notification': MessageEncoder.encode_notification(obj.notification), + 'token': _Validators.check_string('Message.token', obj.token, non_empty=True), + 'topic': _Validators.check_string('Message.topic', obj.topic, non_empty=True), + 'webpush': MessageEncoder.encode_webpush(obj.webpush), + 'fcm_options': MessageEncoder.encode_fcm_options(obj.fcm_options), + } + result['topic'] = MessageEncoder.sanitize_topic_name(result.get('topic')) + result = MessageEncoder.remove_null_values(result) + target_count = sum([t in result for t in ['token', 'topic', 'condition']]) + if target_count != 1: + raise ValueError('Exactly one of token, topic or condition must be specified.') + return result + + @classmethod + def encode_fcm_options(cls, fcm_options): + """Encodes an ``FCMOptions`` instance into JSON.""" + if fcm_options is None: + return None + if not isinstance(fcm_options, _messaging_utils.FCMOptions): + raise ValueError('Message.fcm_options must be an instance of FCMOptions class.') + result = { + 'analytics_label': _Validators.check_analytics_label( + 'FCMOptions.analytics_label', fcm_options.analytics_label), + } + result = cls.remove_null_values(result) + return result diff --git a/firebase_admin/_messaging_utils.py b/firebase_admin/_messaging_utils.py index 28d283d73..7287e57d9 100644 --- a/firebase_admin/_messaging_utils.py +++ b/firebase_admin/_messaging_utils.py @@ -14,80 +14,9 @@ """Types and utilities used by the messaging (FCM) module.""" -import datetime -import json -import math -import numbers -import re - -import six - from firebase_admin import exceptions -class Message(object): - """A message that can be sent via Firebase Cloud Messaging. - - Contains payload information as well as recipient information. In particular, the message must - contain exactly one of token, topic or condition fields. - - Args: - data: A dictionary of data fields (optional). All keys and values in the dictionary must be - strings. - notification: An instance of ``messaging.Notification`` (optional). - android: An instance of ``messaging.AndroidConfig`` (optional). - webpush: An instance of ``messaging.WebpushConfig`` (optional). - apns: An instance of ``messaging.ApnsConfig`` (optional). - fcm_options: An instance of ``messaging.FCMOptions`` (optional). - token: The registration token of the device to which the message should be sent (optional). - topic: Name of the FCM topic to which the message should be sent (optional). Topic name - may contain the ``/topics/`` prefix. - condition: The FCM condition to which the message should be sent (optional). - """ - - def __init__(self, data=None, notification=None, android=None, webpush=None, apns=None, - fcm_options=None, token=None, topic=None, condition=None): - self.data = data - self.notification = notification - self.android = android - self.webpush = webpush - self.apns = apns - self.fcm_options = fcm_options - self.token = token - self.topic = topic - self.condition = condition - - def __str__(self): - return json.dumps(self, cls=MessageEncoder, sort_keys=True) - - -class MulticastMessage(object): - """A message that can be sent to multiple tokens via Firebase Cloud Messaging. - - Args: - tokens: A list of registration tokens of targeted devices. - data: A dictionary of data fields (optional). All keys and values in the dictionary must be - strings. - notification: An instance of ``messaging.Notification`` (optional). - android: An instance of ``messaging.AndroidConfig`` (optional). - webpush: An instance of ``messaging.WebpushConfig`` (optional). - apns: An instance of ``messaging.ApnsConfig`` (optional). - fcm_options: An instance of ``messaging.FCMOptions`` (optional). - """ - def __init__(self, tokens, data=None, notification=None, android=None, webpush=None, apns=None, - fcm_options=None): - _Validators.check_string_list('MulticastMessage.tokens', tokens) - if len(tokens) > 500: - raise ValueError('MulticastMessage.tokens must not contain more than 500 tokens.') - self.tokens = tokens - self.data = data - self.notification = notification - self.android = android - self.webpush = webpush - self.apns = apns - self.fcm_options = fcm_options - - class Notification(object): """A notification that can be included in a message. @@ -161,11 +90,67 @@ class AndroidNotification(object): in ``title_loc_key`` (optional). channel_id: channel_id of the notification (optional). image: Image url of the notification (optional). + ticker: Sets the ``ticker`` text, which is sent to accessibility services. Prior to API + level 21 (Lollipop), sets the text that is displayed in the status bar when the + notification first arrives (optional). + sticky: When set to ``false`` or unset, the notification is automatically dismissed when the + user clicks it in the panel. When set to ``True``, the notification persists even when + the user clicks it (optional). + event_timestamp: For notifications that inform users about events with an absolute time + reference, sets the time that the event in the notification occurred as a + ``datetime.datetime`` instance. Notifications in the panel are sorted by this time + (optional). + local_only: Sets whether or not this notification is relevant only to the current device. + Some notifications can be bridged to other devices for remote display, such as a Wear OS + watch. This hint can be set to recommend this notification not be bridged (optional). + See Wear OS guides: + https://developer.android.com/training/wearables/notifications/bridger#existing-method-of-preventing-bridging + priority: Sets the relative priority for this notification. Low-priority notifications may + be hidden from the user in certain situations. Note this priority differs from + ``AndroidMessagePriority``. This priority is processed by the client after the message + has been delivered. Whereas ``AndroidMessagePriority`` is an FCM concept that controls + when the message is delivered (optional). Must be one of ``default``, ``min``, ``low``, + ``high``, ``max`` or ``normal``. + vibrate_timings_millis: Sets the vibration pattern to use. Pass in an array of milliseconds + to turn the vibrator on or off. The first value indicates the duration to wait before + turning the vibrator on. The next value indicates the duration to keep the vibrator on. + Subsequent values alternate between duration to turn the vibrator off and to turn the + vibrator on. If ``vibrate_timings`` is set and ``default_vibrate_timings`` is set to + ``True``, the default value is used instead of the user-specified ``vibrate_timings``. + default_vibrate_timings: If set to ``True``, use the Android framework's default vibrate + pattern for the notification (optional). Default values are specified in ``config.xml`` + https://android.googlesource.com/platform/frameworks/base/+/master/core/res/res/values/config.xml. + If ``default_vibrate_timings`` is set to ``True`` and ``vibrate_timings`` is also set, + the default value is used instead of the user-specified ``vibrate_timings``. + default_sound: If set to ``True``, use the Android framework's default sound for the + notification (optional). Default values are specified in ``config.xml`` + https://android.googlesource.com/platform/frameworks/base/+/master/core/res/res/values/config.xml + light_settings: Settings to control the notification's LED blinking rate and color if LED is + available on the device. The total blinking time is controlled by the OS (optional). + default_light_settings: If set to ``True``, use the Android framework's default LED light + settings for the notification. Default values are specified in ``config.xml`` + https://android.googlesource.com/platform/frameworks/base/+/master/core/res/res/values/config.xml. + If ``default_light_settings`` is set to ``True`` and ``light_settings`` is also set, the + user-specified ``light_settings`` is used instead of the default value. + visibility: Sets the visibility of the notification. Must be either ``private``, ``public``, + or ``secret``. If unspecified, default to ``private``. + notification_count: Sets the number of items this notification represents. May be displayed + as a badge count for Launchers that support badging. See ``NotificationBadge`` + https://developer.android.com/training/notify-user/badges. For example, this might be + useful if you're using just one notification to represent multiple new messages but you + want the count here to represent the number of total new messages. If zero or + unspecified, systems that support badging use the default, which is to increment a + number displayed on the long-press menu each time a new notification arrives. + + """ def __init__(self, title=None, body=None, icon=None, color=None, sound=None, tag=None, click_action=None, body_loc_key=None, body_loc_args=None, title_loc_key=None, - title_loc_args=None, channel_id=None, image=None): + title_loc_args=None, channel_id=None, image=None, ticker=None, sticky=None, + event_timestamp=None, local_only=None, priority=None, vibrate_timings_millis=None, + default_vibrate_timings=None, default_sound=None, light_settings=None, + default_light_settings=None, visibility=None, notification_count=None): self.title = title self.body = body self.icon = icon @@ -179,6 +164,36 @@ def __init__(self, title=None, body=None, icon=None, color=None, sound=None, tag self.title_loc_args = title_loc_args self.channel_id = channel_id self.image = image + self.ticker = ticker + self.sticky = sticky + self.event_timestamp = event_timestamp + self.local_only = local_only + self.priority = priority + self.vibrate_timings_millis = vibrate_timings_millis + self.default_vibrate_timings = default_vibrate_timings + self.default_sound = default_sound + self.light_settings = light_settings + self.default_light_settings = default_light_settings + self.visibility = visibility + self.notification_count = notification_count + + +class LightSettings(object): + """Represents settings to control notification LED that can be included in a + ``messaging.AndroidNotification``. + + Args: + color: Sets the color of the LED in ``#rrggbb`` or ``#rrggbbaa`` format. + light_on_duration_millis: Along with ``light_off_duration``, defines the blink rate of LED + flashes. + light_off_duration_millis: Along with ``light_on_duration``, defines the blink rate of LED + flashes. + """ + def __init__(self, color, light_on_duration_millis, + light_off_duration_millis): + self.color = color + self.light_on_duration_millis = light_on_duration_millis + self.light_off_duration_millis = light_off_duration_millis class AndroidFCMOptions(object): @@ -448,486 +463,6 @@ def __init__(self, analytics_label=None): self.analytics_label = analytics_label -class _Validators(object): - """A collection of data validation utilities. - - Methods provided in this class raise ValueErrors if any validations fail. - """ - - @classmethod - def check_string(cls, label, value, non_empty=False): - """Checks if the given value is a string.""" - if value is None: - return None - if not isinstance(value, six.string_types): - if non_empty: - raise ValueError('{0} must be a non-empty string.'.format(label)) - else: - raise ValueError('{0} must be a string.'.format(label)) - if non_empty and not value: - raise ValueError('{0} must be a non-empty string.'.format(label)) - return value - - @classmethod - def check_number(cls, label, value): - if value is None: - return None - if not isinstance(value, numbers.Number): - raise ValueError('{0} must be a number.'.format(label)) - return value - - @classmethod - def check_string_dict(cls, label, value): - """Checks if the given value is a dictionary comprised only of string keys and values.""" - if value is None or value == {}: - return None - if not isinstance(value, dict): - raise ValueError('{0} must be a dictionary.'.format(label)) - non_str = [k for k in value if not isinstance(k, six.string_types)] - if non_str: - raise ValueError('{0} must not contain non-string keys.'.format(label)) - non_str = [v for v in value.values() if not isinstance(v, six.string_types)] - if non_str: - raise ValueError('{0} must not contain non-string values.'.format(label)) - return value - - @classmethod - def check_string_list(cls, label, value): - """Checks if the given value is a list comprised only of strings.""" - if value is None or value == []: - return None - if not isinstance(value, list): - raise ValueError('{0} must be a list of strings.'.format(label)) - non_str = [k for k in value if not isinstance(k, six.string_types)] - if non_str: - raise ValueError('{0} must not contain non-string values.'.format(label)) - return value - - @classmethod - def check_analytics_label(cls, label, value): - """Checks if the given value is a valid analytics label.""" - value = _Validators.check_string(label, value) - if value is not None and not re.match(r'^[a-zA-Z0-9-_.~%]{1,50}$', value): - raise ValueError('Malformed {}.'.format(label)) - return value - - -class MessageEncoder(json.JSONEncoder): - """A custom JSONEncoder implementation for serializing Message instances into JSON.""" - - @classmethod - def remove_null_values(cls, dict_value): - return {k: v for k, v in dict_value.items() if v not in [None, [], {}]} - - @classmethod - def encode_android(cls, android): - """Encodes an AndroidConfig instance into JSON.""" - if android is None: - return None - if not isinstance(android, AndroidConfig): - raise ValueError('Message.android must be an instance of AndroidConfig class.') - result = { - 'collapse_key': _Validators.check_string( - 'AndroidConfig.collapse_key', android.collapse_key), - 'data': _Validators.check_string_dict( - 'AndroidConfig.data', android.data), - 'notification': cls.encode_android_notification(android.notification), - 'priority': _Validators.check_string( - 'AndroidConfig.priority', android.priority, non_empty=True), - 'restricted_package_name': _Validators.check_string( - 'AndroidConfig.restricted_package_name', android.restricted_package_name), - 'ttl': cls.encode_ttl(android.ttl), - 'fcm_options': cls.encode_android_fcm_options(android.fcm_options), - } - result = cls.remove_null_values(result) - priority = result.get('priority') - if priority and priority not in ('high', 'normal'): - raise ValueError('AndroidConfig.priority must be "high" or "normal".') - return result - - @classmethod - def encode_android_fcm_options(cls, fcm_options): - """Encodes an AndroidFCMOptions instance into a json.""" - if fcm_options is None: - return None - if not isinstance(fcm_options, AndroidFCMOptions): - raise ValueError('AndroidConfig.fcm_options must be an instance of ' - 'AndroidFCMOptions class.') - result = { - 'analytics_label': _Validators.check_analytics_label( - 'AndroidFCMOptions.analytics_label', fcm_options.analytics_label), - } - result = cls.remove_null_values(result) - return result - - @classmethod - def encode_ttl(cls, ttl): - """Encodes a AndroidConfig TTL duration into a string.""" - if ttl is None: - return None - if isinstance(ttl, numbers.Number): - ttl = datetime.timedelta(seconds=ttl) - if not isinstance(ttl, datetime.timedelta): - raise ValueError('AndroidConfig.ttl must be a duration in seconds or an instance of ' - 'datetime.timedelta.') - total_seconds = ttl.total_seconds() - if total_seconds < 0: - raise ValueError('AndroidConfig.ttl must not be negative.') - seconds = int(math.floor(total_seconds)) - nanos = int((total_seconds - seconds) * 1e9) - if nanos: - return '{0}.{1}s'.format(seconds, str(nanos).zfill(9)) - return '{0}s'.format(seconds) - - @classmethod - def encode_android_notification(cls, notification): - """Encodes an AndroidNotification instance into JSON.""" - if notification is None: - return None - if not isinstance(notification, AndroidNotification): - raise ValueError('AndroidConfig.notification must be an instance of ' - 'AndroidNotification class.') - result = { - 'body': _Validators.check_string( - 'AndroidNotification.body', notification.body), - 'body_loc_args': _Validators.check_string_list( - 'AndroidNotification.body_loc_args', notification.body_loc_args), - 'body_loc_key': _Validators.check_string( - 'AndroidNotification.body_loc_key', notification.body_loc_key), - 'click_action': _Validators.check_string( - 'AndroidNotification.click_action', notification.click_action), - 'color': _Validators.check_string( - 'AndroidNotification.color', notification.color, non_empty=True), - 'icon': _Validators.check_string( - 'AndroidNotification.icon', notification.icon), - 'sound': _Validators.check_string( - 'AndroidNotification.sound', notification.sound), - 'tag': _Validators.check_string( - 'AndroidNotification.tag', notification.tag), - 'title': _Validators.check_string( - 'AndroidNotification.title', notification.title), - 'title_loc_args': _Validators.check_string_list( - 'AndroidNotification.title_loc_args', notification.title_loc_args), - 'title_loc_key': _Validators.check_string( - 'AndroidNotification.title_loc_key', notification.title_loc_key), - 'channel_id': _Validators.check_string( - 'AndroidNotification.channel_id', notification.channel_id), - 'image': _Validators.check_string( - 'image', notification.image - ) - } - result = cls.remove_null_values(result) - color = result.get('color') - if color and not re.match(r'^#[0-9a-fA-F]{6}$', color): - raise ValueError('AndroidNotification.color must be in the form #RRGGBB.') - if result.get('body_loc_args') and not result.get('body_loc_key'): - raise ValueError( - 'AndroidNotification.body_loc_key is required when specifying body_loc_args.') - if result.get('title_loc_args') and not result.get('title_loc_key'): - raise ValueError( - 'AndroidNotification.title_loc_key is required when specifying title_loc_args.') - return result - - @classmethod - def encode_webpush(cls, webpush): - """Encodes a WebpushConfig instance into JSON.""" - if webpush is None: - return None - if not isinstance(webpush, WebpushConfig): - raise ValueError('Message.webpush must be an instance of WebpushConfig class.') - result = { - 'data': _Validators.check_string_dict( - 'WebpushConfig.data', webpush.data), - 'headers': _Validators.check_string_dict( - 'WebpushConfig.headers', webpush.headers), - 'notification': cls.encode_webpush_notification(webpush.notification), - 'fcm_options': cls.encode_webpush_fcm_options(webpush.fcm_options), - } - return cls.remove_null_values(result) - - @classmethod - def encode_webpush_notification(cls, notification): - """Encodes a WebpushNotification instance into JSON.""" - if notification is None: - return None - if not isinstance(notification, WebpushNotification): - raise ValueError('WebpushConfig.notification must be an instance of ' - 'WebpushNotification class.') - result = { - 'actions': cls.encode_webpush_notification_actions(notification.actions), - 'badge': _Validators.check_string( - 'WebpushNotification.badge', notification.badge), - 'body': _Validators.check_string( - 'WebpushNotification.body', notification.body), - 'data': notification.data, - 'dir': _Validators.check_string( - 'WebpushNotification.direction', notification.direction), - 'icon': _Validators.check_string( - 'WebpushNotification.icon', notification.icon), - 'image': _Validators.check_string( - 'WebpushNotification.image', notification.image), - 'lang': _Validators.check_string( - 'WebpushNotification.language', notification.language), - 'renotify': notification.renotify, - 'requireInteraction': notification.require_interaction, - 'silent': notification.silent, - 'tag': _Validators.check_string( - 'WebpushNotification.tag', notification.tag), - 'timestamp': _Validators.check_number( - 'WebpushNotification.timestamp_millis', notification.timestamp_millis), - 'title': _Validators.check_string( - 'WebpushNotification.title', notification.title), - 'vibrate': notification.vibrate, - } - direction = result.get('dir') - if direction and direction not in ('auto', 'ltr', 'rtl'): - raise ValueError('WebpushNotification.direction must be "auto", "ltr" or "rtl".') - if notification.custom_data is not None: - if not isinstance(notification.custom_data, dict): - raise ValueError('WebpushNotification.custom_data must be a dict.') - for key, value in notification.custom_data.items(): - if key in result: - raise ValueError( - 'Multiple specifications for {0} in WebpushNotification.'.format(key)) - result[key] = value - return cls.remove_null_values(result) - - @classmethod - def encode_webpush_notification_actions(cls, actions): - """Encodes a list of WebpushNotificationActions into JSON.""" - if actions is None: - return None - if not isinstance(actions, list): - raise ValueError('WebpushConfig.notification.actions must be a list of ' - 'WebpushNotificationAction instances.') - results = [] - for action in actions: - if not isinstance(action, WebpushNotificationAction): - raise ValueError('WebpushConfig.notification.actions must be a list of ' - 'WebpushNotificationAction instances.') - result = { - 'action': _Validators.check_string( - 'WebpushNotificationAction.action', action.action), - 'title': _Validators.check_string( - 'WebpushNotificationAction.title', action.title), - 'icon': _Validators.check_string( - 'WebpushNotificationAction.icon', action.icon), - } - results.append(cls.remove_null_values(result)) - return results - - @classmethod - def encode_webpush_fcm_options(cls, options): - """Encodes a WebpushFCMOptions instance into JSON.""" - if options is None: - return None - result = { - 'link': _Validators.check_string('WebpushConfig.fcm_options.link', options.link), - } - result = cls.remove_null_values(result) - link = result.get('link') - if link is not None and not link.startswith('https://'): - raise ValueError('WebpushFCMOptions.link must be a HTTPS URL.') - return result - - @classmethod - def encode_apns(cls, apns): - """Encodes an APNSConfig instance into JSON.""" - if apns is None: - return None - if not isinstance(apns, APNSConfig): - raise ValueError('Message.apns must be an instance of APNSConfig class.') - result = { - 'headers': _Validators.check_string_dict( - 'APNSConfig.headers', apns.headers), - 'payload': cls.encode_apns_payload(apns.payload), - 'fcm_options': cls.encode_apns_fcm_options(apns.fcm_options), - } - return cls.remove_null_values(result) - - @classmethod - def encode_apns_payload(cls, payload): - """Encodes an APNSPayload instance into JSON.""" - if payload is None: - return None - if not isinstance(payload, APNSPayload): - raise ValueError('APNSConfig.payload must be an instance of APNSPayload class.') - result = { - 'aps': cls.encode_aps(payload.aps) - } - for key, value in payload.custom_data.items(): - result[key] = value - return cls.remove_null_values(result) - - @classmethod - def encode_apns_fcm_options(cls, fcm_options): - """Encodes an APNSFCMOptions instance into JSON.""" - if fcm_options is None: - return None - if not isinstance(fcm_options, APNSFCMOptions): - raise ValueError('APNSConfig.fcm_options must be an instance of APNSFCMOptions class.') - result = { - 'analytics_label': _Validators.check_analytics_label( - 'APNSFCMOptions.analytics_label', fcm_options.analytics_label), - 'image': _Validators.check_string('APNSFCMOptions.image', fcm_options.image) - } - result = cls.remove_null_values(result) - return result - - @classmethod - def encode_aps(cls, aps): - """Encodes an Aps instance into JSON.""" - if not isinstance(aps, Aps): - raise ValueError('APNSPayload.aps must be an instance of Aps class.') - result = { - 'alert': cls.encode_aps_alert(aps.alert), - 'badge': _Validators.check_number('Aps.badge', aps.badge), - 'sound': cls.encode_aps_sound(aps.sound), - 'category': _Validators.check_string('Aps.category', aps.category), - 'thread-id': _Validators.check_string('Aps.thread_id', aps.thread_id), - } - if aps.content_available is True: - result['content-available'] = 1 - if aps.mutable_content is True: - result['mutable-content'] = 1 - if aps.custom_data is not None: - if not isinstance(aps.custom_data, dict): - raise ValueError('Aps.custom_data must be a dict.') - for key, val in aps.custom_data.items(): - _Validators.check_string('Aps.custom_data key', key) - if key in result: - raise ValueError('Multiple specifications for {0} in Aps.'.format(key)) - result[key] = val - return cls.remove_null_values(result) - - @classmethod - def encode_aps_sound(cls, sound): - """Encodes an APNs sound configuration into JSON.""" - if sound is None: - return None - if sound and isinstance(sound, six.string_types): - return sound - if not isinstance(sound, CriticalSound): - raise ValueError( - 'Aps.sound must be a non-empty string or an instance of CriticalSound class.') - result = { - 'name': _Validators.check_string('CriticalSound.name', sound.name, non_empty=True), - 'volume': _Validators.check_number('CriticalSound.volume', sound.volume), - } - if sound.critical: - result['critical'] = 1 - if not result['name']: - raise ValueError('CriticalSond.name must be a non-empty string.') - volume = result['volume'] - if volume is not None and (volume < 0 or volume > 1): - raise ValueError('CriticalSound.volume must be in the interval [0,1].') - return cls.remove_null_values(result) - - @classmethod - def encode_aps_alert(cls, alert): - """Encodes an ApsAlert instance into JSON.""" - if alert is None: - return None - if isinstance(alert, six.string_types): - return alert - if not isinstance(alert, ApsAlert): - raise ValueError('Aps.alert must be a string or an instance of ApsAlert class.') - result = { - 'title': _Validators.check_string('ApsAlert.title', alert.title), - 'subtitle': _Validators.check_string('ApsAlert.subtitle', alert.subtitle), - 'body': _Validators.check_string('ApsAlert.body', alert.body), - 'title-loc-key': _Validators.check_string( - 'ApsAlert.title_loc_key', alert.title_loc_key), - 'title-loc-args': _Validators.check_string_list( - 'ApsAlert.title_loc_args', alert.title_loc_args), - 'loc-key': _Validators.check_string( - 'ApsAlert.loc_key', alert.loc_key), - 'loc-args': _Validators.check_string_list( - 'ApsAlert.loc_args', alert.loc_args), - 'action-loc-key': _Validators.check_string( - 'ApsAlert.action_loc_key', alert.action_loc_key), - 'launch-image': _Validators.check_string( - 'ApsAlert.launch_image', alert.launch_image), - } - if result.get('loc-args') and not result.get('loc-key'): - raise ValueError( - 'ApsAlert.loc_key is required when specifying loc_args.') - if result.get('title-loc-args') and not result.get('title-loc-key'): - raise ValueError( - 'ApsAlert.title_loc_key is required when specifying title_loc_args.') - if alert.custom_data is not None: - if not isinstance(alert.custom_data, dict): - raise ValueError('ApsAlert.custom_data must be a dict.') - for key, val in alert.custom_data.items(): - _Validators.check_string('ApsAlert.custom_data key', key) - # allow specifying key override because Apple could update API so that key - # could have unexpected value type - result[key] = val - return cls.remove_null_values(result) - - @classmethod - def encode_notification(cls, notification): - """Encodes an Notification instance into JSON.""" - if notification is None: - return None - if not isinstance(notification, Notification): - raise ValueError('Message.notification must be an instance of Notification class.') - result = { - 'body': _Validators.check_string('Notification.body', notification.body), - 'title': _Validators.check_string('Notification.title', notification.title), - 'image': _Validators.check_string('Notification.image', notification.image) - } - return cls.remove_null_values(result) - - @classmethod - def sanitize_topic_name(cls, topic): - if not topic: - return None - prefix = '/topics/' - if topic.startswith(prefix): - topic = topic[len(prefix):] - # Checks for illegal characters and empty string. - if not re.match(r'^[a-zA-Z0-9-_\.~%]+$', topic): - raise ValueError('Malformed topic name.') - return topic - - def default(self, obj): # pylint: disable=method-hidden - if not isinstance(obj, Message): - return json.JSONEncoder.default(self, obj) - result = { - 'android': MessageEncoder.encode_android(obj.android), - 'apns': MessageEncoder.encode_apns(obj.apns), - 'condition': _Validators.check_string( - 'Message.condition', obj.condition, non_empty=True), - 'data': _Validators.check_string_dict('Message.data', obj.data), - 'notification': MessageEncoder.encode_notification(obj.notification), - 'token': _Validators.check_string('Message.token', obj.token, non_empty=True), - 'topic': _Validators.check_string('Message.topic', obj.topic, non_empty=True), - 'webpush': MessageEncoder.encode_webpush(obj.webpush), - 'fcm_options': MessageEncoder.encode_fcm_options(obj.fcm_options), - } - result['topic'] = MessageEncoder.sanitize_topic_name(result.get('topic')) - result = MessageEncoder.remove_null_values(result) - target_count = sum([t in result for t in ['token', 'topic', 'condition']]) - if target_count != 1: - raise ValueError('Exactly one of token, topic or condition must be specified.') - return result - - @classmethod - def encode_fcm_options(cls, fcm_options): - """Encodes an FCMOptions instance into JSON.""" - if fcm_options is None: - return None - if not isinstance(fcm_options, FCMOptions): - raise ValueError('Message.fcm_options must be an instance of FCMOptions class.') - result = { - 'analytics_label': _Validators.check_analytics_label( - 'FCMOptions.analytics_label', fcm_options.analytics_label), - } - result = cls.remove_null_values(result) - return result - - class ThirdPartyAuthError(exceptions.UnauthenticatedError): """APNs certificate or web push auth key was invalid or missing.""" diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index e7062ba04..a35afc87a 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -24,6 +24,7 @@ import firebase_admin from firebase_admin import _http_client +from firebase_admin import _messaging_encoder from firebase_admin import _messaging_utils from firebase_admin import _utils @@ -44,6 +45,7 @@ 'CriticalSound', 'ErrorInfo', 'FCMOptions', + 'LightSettings', 'Message', 'MulticastMessage', 'Notification', @@ -76,8 +78,9 @@ ApsAlert = _messaging_utils.ApsAlert CriticalSound = _messaging_utils.CriticalSound FCMOptions = _messaging_utils.FCMOptions -Message = _messaging_utils.Message -MulticastMessage = _messaging_utils.MulticastMessage +LightSettings = _messaging_utils.LightSettings +Message = _messaging_encoder.Message +MulticastMessage = _messaging_encoder.MulticastMessage Notification = _messaging_utils.Notification WebpushConfig = _messaging_utils.WebpushConfig WebpushFCMOptions = _messaging_utils.WebpushFCMOptions @@ -306,7 +309,7 @@ class _MessagingService(object): FCM_BATCH_URL = 'https://fcm.googleapis.com/batch' IID_URL = 'https://iid.googleapis.com' IID_HEADERS = {'access_token_auth': 'true'} - JSON_ENCODER = _messaging_utils.MessageEncoder() + JSON_ENCODER = _messaging_encoder.MessageEncoder() FCM_ERROR_TYPES = { 'APNS_AUTH_ERROR': ThirdPartyAuthError, diff --git a/integration/test_messaging.py b/integration/test_messaging.py index 01e1d212a..45b53ce97 100644 --- a/integration/test_messaging.py +++ b/integration/test_messaging.py @@ -15,6 +15,7 @@ """Integration tests for firebase_admin.messaging module.""" import re +from datetime import datetime import pytest @@ -39,7 +40,17 @@ def test_send(): title='android-title', body='android-body', image='https://images.unsplash.com/' - 'photo-1494438639946-1ebd1d20bf85?fit=crop&w=900&q=60' + 'photo-1494438639946-1ebd1d20bf85?fit=crop&w=900&q=60', + event_timestamp=datetime.now(), + priority='high', + vibrate_timings_millis=[100, 200, 300, 400], + visibility='public', + light_settings=messaging.LightSettings( + color='#aabbcc', + light_off_duration_millis=200, + light_on_duration_millis=300 + ), + notification_count=1 ) ), apns=messaging.APNSConfig(payload=messaging.APNSPayload( diff --git a/tests/test_messaging.py b/tests/test_messaging.py index 04ef36d8c..5d75f246c 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -32,6 +32,7 @@ NON_DICT_ARGS = ['', list(), tuple(), True, False, 1, 0, {1: 'foo'}, {'foo': 1}] NON_OBJECT_ARGS = [list(), tuple(), dict(), 'foo', 0, 1, True, False] NON_LIST_ARGS = ['', tuple(), dict(), True, False, 1, 0, [1], ['foo', 1]] +NON_UINT_ARGS = ['1.23s', list(), tuple(), dict(), -1.23] HTTP_ERROR_CODES = { 400: exceptions.InvalidArgumentError, 403: exceptions.PermissionDeniedError, @@ -292,7 +293,7 @@ def test_invalid_priority(self, data): else: assert str(excinfo.value) == 'AndroidConfig.priority must be a non-empty string.' - @pytest.mark.parametrize('data', ['1.23s', list(), tuple(), dict(), -1.23]) + @pytest.mark.parametrize('data', NON_UINT_ARGS) def test_invalid_ttl(self, data): with pytest.raises(ValueError) as excinfo: check_encoding(messaging.Message( @@ -474,11 +475,70 @@ def test_no_body_loc_key(self): assert str(excinfo.value) == expected @pytest.mark.parametrize('data', NON_STRING_ARGS) - def test_invalid_channek_id(self, data): + def test_invalid_channel_id(self, data): notification = messaging.AndroidNotification(channel_id=data) excinfo = self._check_notification(notification) assert str(excinfo.value) == 'AndroidNotification.channel_id must be a string.' + @pytest.mark.parametrize('timestamp', [100, '', 'foo', {}, [], list(), dict()]) + def test_invalid_event_timestamp(self, timestamp): + notification = messaging.AndroidNotification(event_timestamp=timestamp) + excinfo = self._check_notification(notification) + expected = 'AndroidNotification.event_timestamp must be a datetime.' + assert str(excinfo.value) == expected + + @pytest.mark.parametrize('priority', NON_STRING_ARGS + ['foo']) + def test_invalid_priority(self, priority): + notification = messaging.AndroidNotification(priority=priority) + excinfo = self._check_notification(notification) + if isinstance(priority, six.string_types): + if not priority: + expected = 'AndroidNotification.priority must be a non-empty string.' + else: + expected = ('AndroidNotification.priority must be "default", "min", "low", "high" ' + 'or "max".') + else: + expected = 'AndroidNotification.priority must be a non-empty string.' + assert str(excinfo.value) == expected + + @pytest.mark.parametrize('visibility', NON_STRING_ARGS + ['foo']) + def test_invalid_visibility(self, visibility): + notification = messaging.AndroidNotification(visibility=visibility) + excinfo = self._check_notification(notification) + if isinstance(visibility, six.string_types): + if not visibility: + expected = 'AndroidNotification.visibility must be a non-empty string.' + else: + expected = ('AndroidNotification.visibility must be "private", "public" or' + ' "secret".') + else: + expected = 'AndroidNotification.visibility must be a non-empty string.' + assert str(excinfo.value) == expected + + @pytest.mark.parametrize('vibrate_timings', ['', 1, True, 'msec', ['500', 500], [0, 'abc']]) + def test_invalid_vibrate_timings_millis(self, vibrate_timings): + notification = messaging.AndroidNotification(vibrate_timings_millis=vibrate_timings) + excinfo = self._check_notification(notification) + if isinstance(vibrate_timings, list): + expected = ('AndroidNotification.vibrate_timings_millis must not contain non-number ' + 'values.') + else: + expected = 'AndroidNotification.vibrate_timings_millis must be a list of numbers.' + assert str(excinfo.value) == expected + + def test_negative_vibrate_timings_millis(self): + notification = messaging.AndroidNotification( + vibrate_timings_millis=[100, -20, 15]) + excinfo = self._check_notification(notification) + expected = 'AndroidNotification.vibrate_timings_millis must not be negative.' + assert str(excinfo.value) == expected + + @pytest.mark.parametrize('notification_count', ['', 'foo', list(), tuple(), dict()]) + def test_invalid_notification_count(self, notification_count): + notification = messaging.AndroidNotification(notification_count=notification_count) + excinfo = self._check_notification(notification) + assert str(excinfo.value) == 'AndroidNotification.notification_count must be a number.' + def test_android_notification(self): msg = messaging.Message( topic='topic', @@ -486,7 +546,17 @@ def test_android_notification(self): notification=messaging.AndroidNotification( title='t', body='b', icon='i', color='#112233', sound='s', tag='t', click_action='ca', title_loc_key='tlk', body_loc_key='blk', - title_loc_args=['t1', 't2'], body_loc_args=['b1', 'b2'], channel_id='c' + title_loc_args=['t1', 't2'], body_loc_args=['b1', 'b2'], channel_id='c', + ticker='ticker', sticky=True, + event_timestamp=datetime.datetime(2019, 10, 20, 15, 12, 23, 123), + local_only=False, + priority='high', vibrate_timings_millis=[100, 50, 250], + default_vibrate_timings=False, default_sound=True, + light_settings=messaging.LightSettings( + color='#AABBCCDD', light_on_duration_millis=200, + light_off_duration_millis=300, + ), + default_light_settings=False, visibility='public', notification_count=1, ) ) ) @@ -505,7 +575,142 @@ def test_android_notification(self): 'body_loc_key': 'blk', 'title_loc_args': ['t1', 't2'], 'body_loc_args': ['b1', 'b2'], - 'channel_id' : 'c', + 'channel_id': 'c', + 'ticker': 'ticker', + 'sticky': 1, + 'event_time': '2019-10-20T15:12:23.000123Z', + 'local_only': 0, + 'notification_priority': 'PRIORITY_HIGH', + 'vibrate_timings': ['0.100000000s', '0.050000000s', '0.250000000s'], + 'default_vibrate_timings': 0, + 'default_sound': 1, + 'light_settings': { + 'color': { + 'red': 0.6666666666666666, + 'green': 0.7333333333333333, + 'blue': 0.8, + 'alpha': 0.8666666666666667, + }, + 'light_on_duration': '0.200000000s', + 'light_off_duration': '0.300000000s', + }, + 'default_light_settings': 0, + 'visibility': 'PUBLIC', + 'notification_count': 1, + }, + }, + } + check_encoding(msg, expected) + + +class TestLightSettingsEncoder(object): + + def _check_light_settings(self, light_settings): + with pytest.raises(ValueError) as excinfo: + check_encoding(messaging.Message( + topic='topic', android=messaging.AndroidConfig( + notification=messaging.AndroidNotification( + light_settings=light_settings + )))) + return excinfo + + @pytest.mark.parametrize('data', NON_OBJECT_ARGS) + def test_invalid_light_settings(self, data): + with pytest.raises(ValueError) as excinfo: + check_encoding(messaging.Message( + topic='topic', android=messaging.AndroidConfig( + notification=messaging.AndroidNotification( + light_settings=data + )))) + expected = 'AndroidNotification.light_settings must be an instance of LightSettings class.' + assert str(excinfo.value) == expected + + def test_no_color(self): + light_settings = messaging.LightSettings(color=None, light_on_duration_millis=200, + light_off_duration_millis=200) + excinfo = self._check_light_settings(light_settings) + expected = 'LightSettings.color is required.' + assert str(excinfo.value) == expected + + def test_no_light_on_duration_millis(self): + light_settings = messaging.LightSettings(color='#aabbcc', light_on_duration_millis=None, + light_off_duration_millis=200) + excinfo = self._check_light_settings(light_settings) + expected = 'LightSettings.light_on_duration_millis is required.' + assert str(excinfo.value) == expected + + def test_no_light_off_duration_millis(self): + light_settings = messaging.LightSettings(color='#aabbcc', light_on_duration_millis=200, + light_off_duration_millis=None) + excinfo = self._check_light_settings(light_settings) + expected = 'LightSettings.light_off_duration_millis is required.' + assert str(excinfo.value) == expected + + @pytest.mark.parametrize('data', NON_UINT_ARGS) + def test_invalid_light_off_duration_millis(self, data): + light_settings = messaging.LightSettings(color='#aabbcc', + light_on_duration_millis=200, + light_off_duration_millis=data) + excinfo = self._check_light_settings(light_settings) + if isinstance(data, numbers.Number): + assert str(excinfo.value) == ('LightSettings.light_off_duration_millis must not be ' + 'negative.') + else: + assert str(excinfo.value) == ('LightSettings.light_off_duration_millis must be a ' + 'duration in milliseconds or ' + 'an instance of datetime.timedelta.') + + @pytest.mark.parametrize('data', NON_UINT_ARGS) + def test_invalid_light_on_duration_millis(self, data): + light_settings = messaging.LightSettings(color='#aabbcc', + light_on_duration_millis=data, + light_off_duration_millis=200) + excinfo = self._check_light_settings(light_settings) + if isinstance(data, numbers.Number): + assert str(excinfo.value) == ('LightSettings.light_on_duration_millis must not be ' + 'negative.') + else: + assert str(excinfo.value) == ('LightSettings.light_on_duration_millis must be a ' + 'duration in milliseconds or ' + 'an instance of datetime.timedelta.') + + @pytest.mark.parametrize('data', NON_STRING_ARGS + ['foo', '#xxyyzz', '112233', '#11223']) + def test_invalid_color(self, data): + notification = messaging.LightSettings(color=data, light_on_duration_millis=300, + light_off_duration_millis=200) + excinfo = self._check_light_settings(notification) + if isinstance(data, six.string_types): + assert str(excinfo.value) == ('LightSettings.color must be in the form #RRGGBB or ' + '#RRGGBBAA.') + else: + assert str( + excinfo.value) == 'LightSettings.color must be a non-empty string.' + + def test_light_settings(self): + msg = messaging.Message( + topic='topic', android=messaging.AndroidConfig( + notification=messaging.AndroidNotification( + light_settings=messaging.LightSettings( + color="#aabbcc", + light_on_duration_millis=200, + light_off_duration_millis=300, + ) + )) + ) + expected = { + 'topic': 'topic', + 'android': { + 'notification': { + 'light_settings': { + 'color': { + 'red': 0.6666666666666666, + 'green': 0.7333333333333333, + 'blue': 0.8, + 'alpha': 1, + }, + 'light_on_duration': '0.200000000s', + 'light_off_duration': '0.300000000s', + } }, }, } From 24e5ad42e048db122c2357f110aee358f33a3480 Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Wed, 13 Nov 2019 17:16:15 -0500 Subject: [PATCH 037/226] Bumped version to 3.2.0 (#366) --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index 04a662b25..9377a4c5b 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '3.1.0' +__version__ = '3.2.0' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 2e67b96f6d670cd6af6129b16409b02a0a17e29f Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Mon, 2 Dec 2019 08:46:08 -0500 Subject: [PATCH 038/226] Fix FCM Android Notification boolean parameters (#370) --- firebase_admin/_messaging_encoder.py | 17 +++++------------ firebase_admin/_messaging_utils.py | 2 +- integration/test_messaging.py | 5 +++++ tests/test_messaging.py | 8 ++++---- 4 files changed, 15 insertions(+), 17 deletions(-) diff --git a/firebase_admin/_messaging_encoder.py b/firebase_admin/_messaging_encoder.py index 1177ffb65..a65b2f4ee 100644 --- a/firebase_admin/_messaging_encoder.py +++ b/firebase_admin/_messaging_encoder.py @@ -259,13 +259,6 @@ def encode_milliseconds(cls, label, msec): return '{0}.{1}s'.format(seconds, str(nanos).zfill(9)) return '{0}s'.format(seconds) - @classmethod - def encode_boolean(cls, value): - """Encodes a boolean into JSON.""" - if value is None: - return None - return 1 if value else 0 - @classmethod def encode_android_notification(cls, notification): """Encodes an ``AndroidNotification`` instance into JSON.""" @@ -303,17 +296,17 @@ def encode_android_notification(cls, notification): 'image', notification.image), 'ticker': _Validators.check_string( 'AndroidNotification.ticker', notification.ticker), - 'sticky': cls.encode_boolean(notification.sticky), + 'sticky': notification.sticky, 'event_time': _Validators.check_datetime( 'AndroidNotification.event_timestamp', notification.event_timestamp), - 'local_only': cls.encode_boolean(notification.local_only), + 'local_only': notification.local_only, 'notification_priority': _Validators.check_string( 'AndroidNotification.priority', notification.priority, non_empty=True), 'vibrate_timings': _Validators.check_number_list( 'AndroidNotification.vibrate_timings_millis', notification.vibrate_timings_millis), - 'default_vibrate_timings': cls.encode_boolean(notification.default_vibrate_timings), - 'default_sound': cls.encode_boolean(notification.default_sound), - 'default_light_settings': cls.encode_boolean(notification.default_light_settings), + 'default_vibrate_timings': notification.default_vibrate_timings, + 'default_sound': notification.default_sound, + 'default_light_settings': notification.default_light_settings, 'light_settings': cls.encode_light_settings(notification.light_settings), 'visibility': _Validators.check_string( 'AndroidNotification.visibility', notification.visibility, non_empty=True), diff --git a/firebase_admin/_messaging_utils.py b/firebase_admin/_messaging_utils.py index 7287e57d9..10ede8a5b 100644 --- a/firebase_admin/_messaging_utils.py +++ b/firebase_admin/_messaging_utils.py @@ -93,7 +93,7 @@ class AndroidNotification(object): ticker: Sets the ``ticker`` text, which is sent to accessibility services. Prior to API level 21 (Lollipop), sets the text that is displayed in the status bar when the notification first arrives (optional). - sticky: When set to ``false`` or unset, the notification is automatically dismissed when the + sticky: When set to ``False`` or unset, the notification is automatically dismissed when the user clicks it in the panel. When set to ``True``, the notification persists even when the user clicks it (optional). event_timestamp: For notifications that inform users about events with an absolute time diff --git a/integration/test_messaging.py b/integration/test_messaging.py index 45b53ce97..bc4f1d1ca 100644 --- a/integration/test_messaging.py +++ b/integration/test_messaging.py @@ -45,6 +45,11 @@ def test_send(): priority='high', vibrate_timings_millis=[100, 200, 300, 400], visibility='public', + sticky=True, + local_only=False, + default_vibrate_timings=False, + default_sound=True, + default_light_settings=False, light_settings=messaging.LightSettings( color='#aabbcc', light_off_duration_millis=200, diff --git a/tests/test_messaging.py b/tests/test_messaging.py index 5d75f246c..040f84023 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -577,12 +577,12 @@ def test_android_notification(self): 'body_loc_args': ['b1', 'b2'], 'channel_id': 'c', 'ticker': 'ticker', - 'sticky': 1, + 'sticky': True, 'event_time': '2019-10-20T15:12:23.000123Z', - 'local_only': 0, + 'local_only': False, 'notification_priority': 'PRIORITY_HIGH', 'vibrate_timings': ['0.100000000s', '0.050000000s', '0.250000000s'], - 'default_vibrate_timings': 0, + 'default_vibrate_timings': False, 'default_sound': 1, 'light_settings': { 'color': { @@ -594,7 +594,7 @@ def test_android_notification(self): 'light_on_duration': '0.200000000s', 'light_off_duration': '0.300000000s', }, - 'default_light_settings': 0, + 'default_light_settings': False, 'visibility': 'PUBLIC', 'notification_count': 1, }, From 22f6761fbeccbd4633a68f9d1acfccf8e54a1e01 Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Wed, 11 Dec 2019 16:16:05 -0500 Subject: [PATCH 039/226] Bumped version to 3.2.1 (#374) --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index 9377a4c5b..d44e3ccb5 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '3.2.0' +__version__ = '3.2.1' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 81463e26b4e832776ea7faa8f3d31f2771688f77 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Fri, 3 Jan 2020 12:10:03 -0800 Subject: [PATCH 040/226] Upgraded cachecontrol to latest (#378) * Upgraded cachecontrol to latest * Fixing some failing requests tests --- requirements.txt | 2 +- setup.py | 2 +- tests/test_db.py | 2 +- tests/test_messaging.py | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/requirements.txt b/requirements.txt index fd73d36bd..cc4534b03 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ pytest-cov >= 2.4.0 pytest-localserver >= 0.4.1 tox >= 3.6.0 -cachecontrol >= 0.12.4 +cachecontrol >= 0.12.6 google-api-core[grpc] >= 1.14.0, < 2.0.0dev; platform.python_implementation != 'PyPy' google-api-python-client >= 1.7.8 google-cloud-firestore >= 1.4.0; platform.python_implementation != 'PyPy' diff --git a/setup.py b/setup.py index a3cce8be5..cb698f774 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ long_description = ('The Firebase Admin Python SDK enables server-side (backend) Python developers ' 'to integrate Firebase into their services and applications.') install_requires = [ - 'cachecontrol>=0.12.4', + 'cachecontrol>=0.12.6', 'google-api-core[grpc] >= 1.14.0, < 2.0.0dev; platform.python_implementation != "PyPy"', 'google-api-python-client >= 1.7.8', 'google-cloud-firestore>=1.4.0; platform.python_implementation != "PyPy"', diff --git a/tests/test_db.py b/tests/test_db.py index 081c31e3d..e9f8f7dda 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -819,7 +819,7 @@ def test_http_timeout(self): assert ref._client.timeout == 60 assert ref.get() == {} assert len(recorder) == 1 - assert recorder[0]._extra_kwargs['timeout'] == 60 + assert recorder[0]._extra_kwargs['timeout'] == pytest.approx(60, 0.001) def test_app_delete(self): app = firebase_admin.initialize_app( diff --git a/tests/test_messaging.py b/tests/test_messaging.py index 040f84023..96b512577 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -1562,7 +1562,7 @@ def test_send(self): msg = messaging.Message(topic='foo') messaging.send(msg) assert len(self.recorder) == 1 - assert self.recorder[0]._extra_kwargs['timeout'] == 4 + assert self.recorder[0]._extra_kwargs['timeout'] == pytest.approx(4, 0.001) def test_topic_management_timeout(self): self.fcm_service._client.session.mount( @@ -1574,7 +1574,7 @@ def test_topic_management_timeout(self): ) messaging.subscribe_to_topic(['1'], 'a') assert len(self.recorder) == 1 - assert self.recorder[0]._extra_kwargs['timeout'] == 4 + assert self.recorder[0]._extra_kwargs['timeout'] == pytest.approx(4, 0.001) class TestSend(object): From 31d91d6d767a3d9852daa40e3d8013e2a98f52dc Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Fri, 10 Jan 2020 14:24:00 -0800 Subject: [PATCH 041/226] Removing Python 2 support (#381) --- .travis.yml | 2 +- CONTRIBUTING.md | 49 +------------------------------------- README.md | 6 ++--- requirements.txt | 1 - scripts/prepare_release.sh | 2 +- setup.py | 8 +++---- tox.ini | 33 ------------------------- 7 files changed, 8 insertions(+), 93 deletions(-) delete mode 100644 tox.ini diff --git a/.travis.yml b/.travis.yml index 4db3c3708..0c00ccc23 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,9 +1,9 @@ language: python python: - - "2.7" - "3.4" - "3.5" - "3.6" + - "3.7" - "pypy3.5" jobs: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 7b4a0ea84..7b521ec99 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -85,7 +85,7 @@ information on using pull requests. ### Initial Setup -You need Python 2.7 or Python 3.4+ to build and test the code in this repo. +You need Python 3.4+ to build and test the code in this repo. We recommend using [pip](https://pypi.python.org/pypi/pip) for installing the necessary tools and project dependencies. Most recent versions of Python ship with pip. If your development environment @@ -227,53 +227,6 @@ pytest --cov --cov-report html and point your browser to `file:////htmlcov/index.html` (where `dir` is the location from which the report was created). - -### Testing in Different Environments - -Sometimes we want to run unit tests in multiple environments (e.g. different Python versions), and -ensure that the SDK works as expected in each of them. We use -[tox](https://tox.readthedocs.io/en/latest/) for this purpose. - -But before you can invoke tox, you must set up all the necessary target environments on your -workstation. The easiest and cleanest way to achieve this is by using a tool like -[pyenv](https://github.com/pyenv/pyenv). Refer to the -[pyenv documentation](https://github.com/pyenv/pyenv#installation) for instructions on how to -install it. This generally involves installing some binaries as well as modifying a system level -configuration file such as `.bash_profile`. Once pyenv is installed, you can install multiple -versions of Python as follows: - -``` -pyenv install 2.7.6 # install Python 2.7.6 -pyenv install 3.3.0 # install Python 3.3.0 -pyenv install pypy2-5.6.0 # install pypy2 -``` - -Refer to the [`tox.ini`](tox.ini) file for a list of target environments that we usually test. -Use pyenv to install all the required Python versions on your workstation. Verify that they are -installed by running the following command: - -``` -pyenv versions -``` - -To make all the required Python versions available to tox for testing, run the `pyenv local` command -with all the Python versions as arguments. The following example shows how to make Python versions -2.7.6, 3.3.0 and pypy2 available to tox. - -``` -pyenv local 2.7.6 3.3.0 pypy2-5.6.0 -``` - -Once your system is fully set up, you can execute the following command from the root of the -repository to launch tox: - -``` -tox -``` - -This command will read the list of target environments from `tox.ini`, and execute tests in each of -those environments. It will also generate a code coverage report at the end of the execution. - ### Repo Organization Here are some highlights of the directory structure and notable source files diff --git a/README.md b/README.md index 757a3f8cd..8e9efd0ee 100644 --- a/README.md +++ b/README.md @@ -41,10 +41,8 @@ requests, code review feedback, and also pull requests. ## Supported Python Versions -We currently support Python 2.7 and Python 3.4+. However, Python 2.7 support is -being phased out, and the developers are advised to use latest Python 3. -Firebase Admin Python SDK is also tested on PyPy and -[Google App Engine](https://cloud.google.com/appengine/) environments. +We currently support Python 3.4+. Firebase Admin Python SDK is also tested on +PyPy and [Google App Engine](https://cloud.google.com/appengine/) environments. ## Documentation diff --git a/requirements.txt b/requirements.txt index cc4534b03..2f1a09a5b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,6 @@ pylint == 1.6.4 pytest >= 3.6.0 pytest-cov >= 2.4.0 pytest-localserver >= 0.4.1 -tox >= 3.6.0 cachecontrol >= 0.12.6 google-api-core[grpc] >= 1.14.0, < 2.0.0dev; platform.python_implementation != 'PyPy' diff --git a/scripts/prepare_release.sh b/scripts/prepare_release.sh index ca30d9043..aa55dae92 100755 --- a/scripts/prepare_release.sh +++ b/scripts/prepare_release.sh @@ -132,7 +132,7 @@ fi ################## echo "[INFO] Running unit tests" -tox +pytest ../tests echo "[INFO] Running integration tests" pytest ../integration --cert cert.json --apikey apikey.txt diff --git a/setup.py b/setup.py index cb698f774..b492ec922 100644 --- a/setup.py +++ b/setup.py @@ -22,8 +22,8 @@ (major, minor) = (sys.version_info.major, sys.version_info.minor) -if (major == 2 and minor < 7) or (major == 3 and minor < 4): - print('firebase_admin requires python2 >= 2.7 or python3 >= 3.4', file=sys.stderr) +if major != 3 or minor < 4: + print('firebase_admin requires python >= 3.4', file=sys.stderr) sys.exit(1) # Read in the package metadata per recommendations from: @@ -56,13 +56,11 @@ keywords='firebase cloud development', install_requires=install_requires, packages=['firebase_admin'], - python_requires='>=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*', + python_requires='>=3.4', classifiers=[ 'Development Status :: 5 - Production/Stable', 'Intended Audience :: Developers', 'Topic :: Software Development :: Build Tools', - 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', diff --git a/tox.ini b/tox.ini deleted file mode 100644 index dec7b618f..000000000 --- a/tox.ini +++ /dev/null @@ -1,33 +0,0 @@ -# Tox (https://tox.readthedocs.io/) is a tool for running tests -# in multiple virtualenvs. This configuration file will run the -# test suite on all supported python versions. To use it, "pip install tox" -# and then run "tox" from this directory. - -[tox] -envlist = py2,py3,pypy,cover - -[testenv] -passenv = - FIREBASE_DATABASE_EMULATOR_HOST -commands = pytest {posargs} -deps = - pytest - pytest-localserver - -[coverbase] -basepython = python2.7 -commands = - pytest \ - --cov=firebase_admin \ - --cov=tests -deps = {[testenv]deps} - coverage - pytest-cov - -[testenv:cover] -basepython = {[coverbase]basepython} -commands = - {[coverbase]commands} - coverage report --show-missing -deps = - {[coverbase]deps} From cf3203bb75181257749bc19c624eb16a7b58737d Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Tue, 14 Jan 2020 11:49:18 -0800 Subject: [PATCH 042/226] Upgraded to pylint 2.x (#384) * Removing Python 2 support * Upgraded to Pylint 2.x and fixed all linter errors for Python 3 * Downgrading to pylint 2.3 since 2.4 won't install on Python 3.4 --- .travis.yml | 2 +- firebase_admin/__init__.py | 36 +++++++------- firebase_admin/_auth_utils.py | 4 +- firebase_admin/_http_client.py | 2 +- firebase_admin/_messaging_encoder.py | 34 ++++++------- firebase_admin/_messaging_utils.py | 32 ++++++------- firebase_admin/_sseclient.py | 10 ++-- firebase_admin/_token_gen.py | 10 ++-- firebase_admin/_user_import.py | 10 ++-- firebase_admin/_user_mgt.py | 20 ++++---- firebase_admin/_utils.py | 18 +++---- firebase_admin/auth.py | 9 ++-- firebase_admin/credentials.py | 2 +- firebase_admin/db.py | 64 ++++++++++++------------- firebase_admin/firestore.py | 2 +- firebase_admin/instance_id.py | 6 +-- firebase_admin/messaging.py | 11 +++-- firebase_admin/project_management.py | 18 +++---- firebase_admin/storage.py | 4 +- integration/test_auth.py | 8 ++-- integration/test_db.py | 14 +++--- integration/test_messaging.py | 2 +- lint.sh | 4 +- requirements.txt | 2 +- snippets/auth/index.py | 4 +- tests/test_app.py | 10 ++-- tests/test_credentials.py | 8 ++-- tests/test_db.py | 28 +++++------ tests/test_exceptions.py | 6 +-- tests/test_firestore.py | 2 +- tests/test_http_client.py | 2 +- tests/test_instance_id.py | 4 +- tests/test_messaging.py | 72 ++++++++++++++-------------- tests/test_project_management.py | 11 ++--- tests/test_sseclient.py | 6 +-- tests/test_token_gen.py | 10 ++-- tests/test_user_mgt.py | 51 ++++++++++---------- tests/testutils.py | 6 +-- 38 files changed, 272 insertions(+), 272 deletions(-) diff --git a/.travis.yml b/.travis.yml index 0c00ccc23..8d6b9246a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -9,7 +9,7 @@ python: jobs: include: - name: "Lint" - python: "2.7" + python: "3.7" script: ./lint.sh all before_install: diff --git a/firebase_admin/__init__.py b/firebase_admin/__init__.py index bc9526378..eae68bd06 100644 --- a/firebase_admin/__init__.py +++ b/firebase_admin/__init__.py @@ -77,12 +77,12 @@ def initialize_app(credential=None, options=None, name=_DEFAULT_APP_NAME): 'initialize_app() once. But if you do want to initialize multiple ' 'apps, pass a second argument to initialize_app() to give each app ' 'a unique name.')) - else: - raise ValueError(( - 'Firebase app named "{0}" already exists. This means you called ' - 'initialize_app() more than once with the same app name as the ' - 'second argument. Make sure you provide a unique name every time ' - 'you call initialize_app().').format(name)) + + raise ValueError(( + 'Firebase app named "{0}" already exists. This means you called ' + 'initialize_app() more than once with the same app name as the ' + 'second argument. Make sure you provide a unique name every time ' + 'you call initialize_app().').format(name)) def delete_app(app): @@ -106,11 +106,11 @@ def delete_app(app): raise ValueError( 'The default Firebase app is not initialized. Make sure to initialize ' 'the default app by calling initialize_app().') - else: - raise ValueError( - ('Firebase app named "{0}" is not initialized. Make sure to initialize ' - 'the app by calling initialize_app() with your app name as the ' - 'second argument.').format(app.name)) + + raise ValueError( + ('Firebase app named "{0}" is not initialized. Make sure to initialize ' + 'the app by calling initialize_app() with your app name as the ' + 'second argument.').format(app.name)) def get_app(name=_DEFAULT_APP_NAME): @@ -137,14 +137,14 @@ def get_app(name=_DEFAULT_APP_NAME): raise ValueError( 'The default Firebase app does not exist. Make sure to initialize ' 'the SDK by calling initialize_app().') - else: - raise ValueError( - ('Firebase app named "{0}" does not exist. Make sure to initialize ' - 'the SDK by calling initialize_app() with your app name as the ' - 'second argument.').format(name)) + + raise ValueError( + ('Firebase app named "{0}" does not exist. Make sure to initialize ' + 'the SDK by calling initialize_app() with your app name as the ' + 'second argument.').format(name)) -class _AppOptions(object): +class _AppOptions: """A collection of configuration options for an App.""" def __init__(self, options): @@ -185,7 +185,7 @@ def _load_from_environment(self): return {k: v for k, v in json_data.items() if k in _CONFIG_VALID_KEYS} -class App(object): +class App: """The entry point for Firebase Python SDK. Represents a Firebase app, while holding the configuration and state diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index df3e0acfc..b54e7d480 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -103,6 +103,7 @@ def validate_provider_id(provider_id, required=True): return provider_id def validate_photo_url(photo_url, required=False): + """Parses and validates the given URL string.""" if photo_url is None and not required: return None if not isinstance(photo_url, six.string_types) or not photo_url: @@ -118,6 +119,7 @@ def validate_photo_url(photo_url, required=False): raise ValueError('Malformed photo URL: "{0}".'.format(photo_url)) def validate_timestamp(timestamp, label, required=False): + """Validates the given timestamp value. Timestamps must be positive integers.""" if timestamp is None and not required: return None if isinstance(timestamp, bool): @@ -181,7 +183,7 @@ def validate_custom_claims(custom_claims, required=False): if len(invalid_claims) > 1: joined = ', '.join(sorted(invalid_claims)) raise ValueError('Claims "{0}" are reserved, and must not be set.'.format(joined)) - elif len(invalid_claims) == 1: + if len(invalid_claims) == 1: raise ValueError( 'Claim "{0}" is reserved, and must not be set.'.format(invalid_claims.pop())) return claims_str diff --git a/firebase_admin/_http_client.py b/firebase_admin/_http_client.py index eb8c4027a..1daaf371b 100644 --- a/firebase_admin/_http_client.py +++ b/firebase_admin/_http_client.py @@ -32,7 +32,7 @@ raise_on_status=False, backoff_factor=0.5) -class HttpClient(object): +class HttpClient: """Base HTTP client used to make HTTP calls. HttpClient maintains an HTTP session, and handles request authentication and retries if diff --git a/firebase_admin/_messaging_encoder.py b/firebase_admin/_messaging_encoder.py index a65b2f4ee..aefaf3e2f 100644 --- a/firebase_admin/_messaging_encoder.py +++ b/firebase_admin/_messaging_encoder.py @@ -25,7 +25,7 @@ import firebase_admin._messaging_utils as _messaging_utils -class Message(object): +class Message: """A message that can be sent via Firebase Cloud Messaging. Contains payload information as well as recipient information. In particular, the message must @@ -61,7 +61,7 @@ def __str__(self): return json.dumps(self, cls=MessageEncoder, sort_keys=True) -class MulticastMessage(object): +class MulticastMessage: """A message that can be sent to multiple tokens via Firebase Cloud Messaging. Args: @@ -88,7 +88,7 @@ def __init__(self, tokens, data=None, notification=None, android=None, webpush=N self.fcm_options = fcm_options -class _Validators(object): +class _Validators: """A collection of data validation utilities. Methods provided in this class raise ``ValueErrors`` if any validations fail. @@ -102,8 +102,7 @@ def check_string(cls, label, value, non_empty=False): if not isinstance(value, six.string_types): if non_empty: raise ValueError('{0} must be a non-empty string.'.format(label)) - else: - raise ValueError('{0} must be a string.'.format(label)) + raise ValueError('{0} must be a string.'.format(label)) if non_empty and not value: raise ValueError('{0} must be a non-empty string.'.format(label)) return value @@ -647,6 +646,7 @@ def encode_notification(cls, notification): @classmethod def sanitize_topic_name(cls, topic): + """Removes the /topics/ prefix from the topic name, if present.""" if not topic: return None prefix = '/topics/' @@ -657,20 +657,20 @@ def sanitize_topic_name(cls, topic): raise ValueError('Malformed topic name.') return topic - def default(self, obj): # pylint: disable=method-hidden - if not isinstance(obj, Message): - return json.JSONEncoder.default(self, obj) + def default(self, o): # pylint: disable=method-hidden + if not isinstance(o, Message): + return json.JSONEncoder.default(self, o) result = { - 'android': MessageEncoder.encode_android(obj.android), - 'apns': MessageEncoder.encode_apns(obj.apns), + 'android': MessageEncoder.encode_android(o.android), + 'apns': MessageEncoder.encode_apns(o.apns), 'condition': _Validators.check_string( - 'Message.condition', obj.condition, non_empty=True), - 'data': _Validators.check_string_dict('Message.data', obj.data), - 'notification': MessageEncoder.encode_notification(obj.notification), - 'token': _Validators.check_string('Message.token', obj.token, non_empty=True), - 'topic': _Validators.check_string('Message.topic', obj.topic, non_empty=True), - 'webpush': MessageEncoder.encode_webpush(obj.webpush), - 'fcm_options': MessageEncoder.encode_fcm_options(obj.fcm_options), + 'Message.condition', o.condition, non_empty=True), + 'data': _Validators.check_string_dict('Message.data', o.data), + 'notification': MessageEncoder.encode_notification(o.notification), + 'token': _Validators.check_string('Message.token', o.token, non_empty=True), + 'topic': _Validators.check_string('Message.topic', o.topic, non_empty=True), + 'webpush': MessageEncoder.encode_webpush(o.webpush), + 'fcm_options': MessageEncoder.encode_fcm_options(o.fcm_options), } result['topic'] = MessageEncoder.sanitize_topic_name(result.get('topic')) result = MessageEncoder.remove_null_values(result) diff --git a/firebase_admin/_messaging_utils.py b/firebase_admin/_messaging_utils.py index 10ede8a5b..3a1943c04 100644 --- a/firebase_admin/_messaging_utils.py +++ b/firebase_admin/_messaging_utils.py @@ -17,7 +17,7 @@ from firebase_admin import exceptions -class Notification(object): +class Notification: """A notification that can be included in a message. Args: @@ -32,7 +32,7 @@ def __init__(self, title=None, body=None, image=None): self.image = image -class AndroidConfig(object): +class AndroidConfig: """Android-specific options that can be included in a message. Args: @@ -62,7 +62,7 @@ def __init__(self, collapse_key=None, priority=None, ttl=None, restricted_packag self.fcm_options = fcm_options -class AndroidNotification(object): +class AndroidNotification: """Android-specific notification parameters. Args: @@ -178,7 +178,7 @@ def __init__(self, title=None, body=None, icon=None, color=None, sound=None, tag self.notification_count = notification_count -class LightSettings(object): +class LightSettings: """Represents settings to control notification LED that can be included in a ``messaging.AndroidNotification``. @@ -196,7 +196,7 @@ def __init__(self, color, light_on_duration_millis, self.light_off_duration_millis = light_off_duration_millis -class AndroidFCMOptions(object): +class AndroidFCMOptions: """Options for features provided by the FCM SDK for Android. Args: @@ -208,7 +208,7 @@ def __init__(self, analytics_label=None): self.analytics_label = analytics_label -class WebpushConfig(object): +class WebpushConfig: """Webpush-specific options that can be included in a message. Args: @@ -230,7 +230,7 @@ def __init__(self, headers=None, data=None, notification=None, fcm_options=None) self.fcm_options = fcm_options -class WebpushNotificationAction(object): +class WebpushNotificationAction: """An action available to the users when the notification is presented. Args: @@ -245,7 +245,7 @@ def __init__(self, action, title, icon=None): self.icon = icon -class WebpushNotification(object): +class WebpushNotification: """Webpush-specific notification parameters. Refer to the `Notification Reference`_ for more information. @@ -302,7 +302,7 @@ def __init__(self, title=None, body=None, icon=None, actions=None, badge=None, d self.custom_data = custom_data -class WebpushFCMOptions(object): +class WebpushFCMOptions: """Options for features provided by the FCM SDK for Web. Args: @@ -314,7 +314,7 @@ def __init__(self, link=None): self.link = link -class APNSConfig(object): +class APNSConfig: """APNS-specific options that can be included in a message. Refer to `APNS Documentation`_ for more information. @@ -335,7 +335,7 @@ def __init__(self, headers=None, payload=None, fcm_options=None): self.fcm_options = fcm_options -class APNSPayload(object): +class APNSPayload: """Payload of an APNS message. Args: @@ -349,7 +349,7 @@ def __init__(self, aps, **kwargs): self.custom_data = kwargs -class Aps(object): +class Aps: """Aps dictionary to be included in an APNS payload. Args: @@ -379,7 +379,7 @@ def __init__(self, alert=None, badge=None, sound=None, content_available=None, c self.custom_data = custom_data -class CriticalSound(object): +class CriticalSound: """Critical alert sound configuration that can be included in ``messaging.Aps``. Args: @@ -398,7 +398,7 @@ def __init__(self, name, critical=None, volume=None): self.volume = volume -class ApsAlert(object): +class ApsAlert: """An alert that can be included in ``messaging.Aps``. Args: @@ -437,7 +437,7 @@ def __init__(self, title=None, subtitle=None, body=None, loc_key=None, loc_args= self.custom_data = custom_data -class APNSFCMOptions(object): +class APNSFCMOptions: """Options for features provided by the FCM SDK for iOS. Args: @@ -452,7 +452,7 @@ def __init__(self, analytics_label=None, image=None): self.image = image -class FCMOptions(object): +class FCMOptions: """Options for features provided by SDK. Args: diff --git a/firebase_admin/_sseclient.py b/firebase_admin/_sseclient.py index eab79f9e3..6585dfc80 100644 --- a/firebase_admin/_sseclient.py +++ b/firebase_admin/_sseclient.py @@ -40,7 +40,7 @@ def rebuild_auth(self, prepared_request, response): pass -class _EventBuffer(object): +class _EventBuffer: """A helper class for buffering and parsing raw SSE data.""" def __init__(self): @@ -68,7 +68,7 @@ def buffer_string(self): return ''.join(self._buffer) -class SSEClient(object): +class SSEClient: """SSE client implementation.""" def __init__(self, url, session, retry=3000, **kwargs): @@ -140,7 +140,7 @@ def __next__(self): if event.data == 'credential is no longer valid': self._connect() return None - elif event.data == 'null': + if event.data == 'null': return None # If the server requests a specific retry delay, we need to honor it. @@ -157,7 +157,7 @@ def next(self): return self.__next__() -class Event(object): +class Event: """Event represents the events fired by SSE.""" sse_line_pattern = re.compile('(?P[^:]*):?( ?(?P.*))?') @@ -192,7 +192,7 @@ def parse(cls, raw): if name == '': # line began with a ":", so is a comment. Ignore continue - elif name == 'data': + if name == 'data': # If we already have some data, then join to it with a newline. # Else this is it. if event.data: diff --git a/firebase_admin/_token_gen.py b/firebase_admin/_token_gen.py index 339714dcd..471630cca 100644 --- a/firebase_admin/_token_gen.py +++ b/firebase_admin/_token_gen.py @@ -55,7 +55,7 @@ 'service-accounts/default/email') -class _SigningProvider(object): +class _SigningProvider: """Stores a reference to a google.auth.crypto.Signer.""" def __init__(self, signer, signer_email): @@ -80,7 +80,7 @@ def from_iam(cls, request, google_cred, service_account): return _SigningProvider(signer, service_account) -class TokenGenerator(object): +class TokenGenerator: """Generates custom tokens and session cookies.""" def __init__(self, app, client): @@ -207,7 +207,7 @@ def create_session_cookie(self, id_token, expires_in): return body.get('sessionCookie') -class TokenVerifier(object): +class TokenVerifier: """Verifies ID tokens and session cookies.""" def __init__(self, app): @@ -237,7 +237,7 @@ def verify_session_cookie(self, cookie): return self.cookie_verifier.verify(cookie, self.request) -class _JWTVerifier(object): +class _JWTVerifier: """Verifies Firebase JWTs (ID tokens or session cookies).""" def __init__(self, **kwargs): @@ -288,7 +288,7 @@ def verify(self, token, request): 'token.'.format(self.operation, self.articled_short_name)) elif not header.get('kid'): if header.get('alg') == 'HS256' and payload.get( - 'v') is 0 and 'uid' in payload.get('d', {}): + 'v') == 0 and 'uid' in payload.get('d', {}): error_message = ( '{0} expects {1}, but was given a legacy custom ' 'token.'.format(self.operation, self.articled_short_name)) diff --git a/firebase_admin/_user_import.py b/firebase_admin/_user_import.py index 86252ffb8..21cc8082d 100644 --- a/firebase_admin/_user_import.py +++ b/firebase_admin/_user_import.py @@ -24,7 +24,7 @@ def b64_encode(bytes_value): return base64.urlsafe_b64encode(bytes_value).decode() -class UserProvider(object): +class UserProvider: """Represents a user identity provider that can be associated with a Firebase user. One or more providers can be specified in an ``ImportUserRecord`` when importing users via @@ -97,7 +97,7 @@ def to_dict(self): return {k: v for k, v in payload.items() if v is not None} -class ImportUserRecord(object): +class ImportUserRecord: """Represents a user account to be imported to Firebase Auth. Must specify the ``uid`` field at a minimum. A sequence of ``ImportUserRecord`` objects can be @@ -255,7 +255,7 @@ def to_dict(self): return {k: v for k, v in payload.items() if v is not None} -class UserImportHash(object): +class UserImportHash: """Represents a hash algorithm used to hash user passwords. An instance of this class must be specified when importing users with passwords via the @@ -471,7 +471,7 @@ def standard_scrypt(cls, memory_cost, parallelization, block_size, derived_key_l return UserImportHash('STANDARD_SCRYPT', data) -class ErrorInfo(object): +class ErrorInfo: """Represents an error encountered while importing an ``ImportUserRecord``.""" def __init__(self, error): @@ -487,7 +487,7 @@ def reason(self): return self._reason -class UserImportResult(object): +class UserImportResult: """Represents the result of a bulk user import operation. See ``auth.import_users()`` API for more details. diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index 2e10fac1b..5b33abb39 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -29,7 +29,7 @@ B64_REDACTED = base64.b64encode(b'REDACTED') -class Sentinel(object): +class Sentinel: def __init__(self, description): self.description = description @@ -38,7 +38,7 @@ def __init__(self, description): DELETE_ATTRIBUTE = Sentinel('Value used to delete an attribute from a user profile') -class UserMetadata(object): +class UserMetadata: """Contains additional metadata associated with a user account.""" def __init__(self, creation_timestamp=None, last_sign_in_timestamp=None): @@ -66,7 +66,7 @@ def last_sign_in_timestamp(self): return self._last_sign_in_timestamp -class UserInfo(object): +class UserInfo: """A collection of standard profile information for a user. Used to expose profile information returned by an identity provider. @@ -248,9 +248,6 @@ def custom_claims(self): class ExportedUserRecord(UserRecord): """Contains metadata associated with a user including password hash and salt.""" - def __init__(self, data): - super(ExportedUserRecord, self).__init__(data) - @property def password_hash(self): """The user's password hash as a base64-encoded string. @@ -283,7 +280,7 @@ def password_salt(self): return self._data.get('salt') -class ListUsersPage(object): +class ListUsersPage: """Represents a page of user records exported from a Firebase project. Provides methods for traversing the user accounts included in this page, as well as retrieving @@ -370,7 +367,7 @@ def provider_id(self): return self._data.get('providerId') -class ActionCodeSettings(object): +class ActionCodeSettings: """Contains required continue/state URL with optional Android and iOS settings. Used when invoking the email action link generation APIs. """ @@ -454,7 +451,7 @@ def encode_action_code_settings(settings): return parameters -class UserManager(object): +class UserManager: """Provides methods for interacting with the Google Identity Toolkit.""" def __init__(self, client): @@ -493,7 +490,7 @@ def list_users(self, page_token=None, max_results=MAX_LIST_USERS_RESULTS): raise ValueError('Page token must be a non-empty string.') if not isinstance(max_results, int): raise ValueError('Max results must be an integer.') - elif max_results < 1 or max_results > MAX_LIST_USERS_RESULTS: + if max_results < 1 or max_results > MAX_LIST_USERS_RESULTS: raise ValueError( 'Max results must be a positive integer less than ' '{0}.'.format(MAX_LIST_USERS_RESULTS)) @@ -636,6 +633,7 @@ def generate_email_action_link(self, action_type, email, action_code_settings=No link_url: action url to be emailed to the user Raises: + UnexpectedResponseError: If the backend server responds with an unexpected message FirebaseError: If an error occurs while generating the link ValueError: If the provided arguments are invalid """ @@ -660,7 +658,7 @@ def generate_email_action_link(self, action_type, email, action_code_settings=No return body.get('oobLink') -class _UserIterator(object): +class _UserIterator: """An iterator that allows iterating over user accounts, one at a time. This implementation loads a page of users into memory, and iterates on them. When the whole diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index 95ed2c414..7ec1b8fb8 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -60,17 +60,19 @@ def _get_initialized_app(app): + """Returns a reference to an initialized App instance.""" if app is None: return firebase_admin.get_app() - elif isinstance(app, firebase_admin.App): + + if isinstance(app, firebase_admin.App): initialized_app = firebase_admin.get_app(app.name) if app is not initialized_app: raise ValueError('Illegal app argument. App instance not ' 'initialized via the firebase module.') return app - else: - raise ValueError('Illegal app argument. Argument must be of type ' - ' firebase_admin.App, but given "{0}".'.format(type(app))) + + raise ValueError('Illegal app argument. Argument must be of type ' + ' firebase_admin.App, but given "{0}".'.format(type(app))) def get_app_service(app, name, initializer): @@ -143,11 +145,11 @@ def handle_requests_error(error, message=None, code=None): return exceptions.DeadlineExceededError( message='Timed out while making an API call: {0}'.format(error), cause=error) - elif isinstance(error, requests.exceptions.ConnectionError): + if isinstance(error, requests.exceptions.ConnectionError): return exceptions.UnavailableError( message='Failed to establish a connection: {0}'.format(error), cause=error) - elif error.response is None: + if error.response is None: return exceptions.UnknownError( message='Unknown error while making a remote service call: {0}'.format(error), cause=error) @@ -230,11 +232,11 @@ def handle_googleapiclient_error(error, message=None, code=None, http_response=N return exceptions.DeadlineExceededError( message='Timed out while making an API call: {0}'.format(error), cause=error) - elif isinstance(error, httplib2.ServerNotFoundError): + if isinstance(error, httplib2.ServerNotFoundError): return exceptions.UnavailableError( message='Failed to establish a connection: {0}'.format(error), cause=error) - elif not isinstance(error, googleapiclient.errors.HttpError): + if not isinstance(error, googleapiclient.errors.HttpError): return exceptions.UnknownError( message='Unknown error while making a remote service call: {0}'.format(error), cause=error) diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index a5110c211..6f85e622c 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -337,9 +337,12 @@ def download(page_token, max_results): return ListUsersPage(download, page_token, max_results) -def create_user(**kwargs): +def create_user(**kwargs): # pylint: disable=differing-param-doc """Creates a new user account with the specified properties. + Args: + kwargs: A series of keyword arguments (optional). + Keyword Args: uid: User ID to assign to the newly created user (optional). display_name: The user's display name (optional). @@ -365,7 +368,7 @@ def create_user(**kwargs): return UserRecord(user_manager.get_user(uid=uid)) -def update_user(uid, **kwargs): +def update_user(uid, **kwargs): # pylint: disable=differing-param-doc """Updates an existing user account with the specified properties. Args: @@ -542,7 +545,7 @@ def _check_jwt_revoked(verified_claims, exc_type, label, app): raise exc_type('The Firebase {0} has been revoked.'.format(label)) -class _AuthService(object): +class _AuthService: """Firebase Authentication service.""" ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v1/projects/' diff --git a/firebase_admin/credentials.py b/firebase_admin/credentials.py index 2e400d9e4..e930675bd 100644 --- a/firebase_admin/credentials.py +++ b/firebase_admin/credentials.py @@ -41,7 +41,7 @@ """ -class Base(object): +class Base: """Provides OAuth2 access tokens for accessing Firebase services.""" def get_access_token(self): diff --git a/firebase_admin/db.py b/firebase_admin/db.py index ef7c96721..2fb8b3a74 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -81,7 +81,7 @@ def _parse_path(path): return [seg for seg in path.split('/') if seg] -class Event(object): +class Event: """Represents a realtime update event received from the database.""" def __init__(self, sse_event): @@ -104,7 +104,7 @@ def event_type(self): return self._sse_event.event_type -class ListenerRegistration(object): +class ListenerRegistration: """Represents the addition of an event listener to a database reference.""" def __init__(self, callback, sse): @@ -138,7 +138,7 @@ def close(self): self._thread.join() -class Reference(object): +class Reference: """Reference represents a node in the Firebase realtime database.""" def __init__(self, **kwargs): @@ -218,9 +218,9 @@ def get(self, etag=False, shallow=False): headers, data = self._client.headers_and_body( 'get', self._add_suffix(), headers={'X-Firebase-ETag' : 'true'}) return data, headers.get('ETag') - else: - params = 'shallow=true' if shallow else None - return self._client.body('get', self._add_suffix(), params=params) + + params = 'shallow=true' if shallow else None + return self._client.body('get', self._add_suffix(), params=params) def get_if_changed(self, etag): """Gets data in this location only if the specified ETag does not match. @@ -245,8 +245,8 @@ def get_if_changed(self, etag): resp = self._client.request('get', self._add_suffix(), headers={'if-none-match': etag}) if resp.status_code == 304: return False, None, None - else: - return True, resp.json(), resp.headers.get('ETag') + + return True, resp.json(), resp.headers.get('ETag') def set(self, value): """Sets the data at this location to the given value. @@ -300,8 +300,8 @@ def set_if_unchanged(self, expected_etag, value): etag = http_response.headers['ETag'] snapshot = http_response.json() return False, snapshot, etag - else: - raise error + + raise error def push(self, value=''): """Creates a new child node. @@ -473,7 +473,7 @@ def _listen_with_session(self, callback, session): raise _Client.handle_rtdb_error(error) -class Query(object): +class Query: """Represents a complex query that can be executed on a Reference. Complex queries can consist of up to 2 components: a required ordering constraint, and an @@ -631,7 +631,7 @@ def __init__(self, message): exceptions.AbortedError.__init__(self, message) -class _Sorter(object): +class _Sorter: """Helper class for sorting query results.""" def __init__(self, results, order_by): @@ -648,11 +648,11 @@ def __init__(self, results, order_by): def get(self): if self.dict_input: return collections.OrderedDict([(e.key, e.value) for e in self.sort_entries]) - else: - return [e.value for e in self.sort_entries] + return [e.value for e in self.sort_entries] -class _SortEntry(object): + +class _SortEntry: """A wrapper that is capable of sorting items in a dictionary.""" _type_none = 0 @@ -665,7 +665,7 @@ class _SortEntry(object): def __init__(self, key, value, order_by): self._key = key self._value = value - if order_by == '$key' or order_by == '$priority': + if order_by in ('$key', '$priority'): self._index = key elif order_by == '$value': self._index = value @@ -698,16 +698,16 @@ def _get_index_type(cls, index): """ if index is None: return cls._type_none - elif isinstance(index, bool) and not index: + if isinstance(index, bool) and not index: return cls._type_bool_false - elif isinstance(index, bool) and index: + if isinstance(index, bool) and index: return cls._type_bool_true - elif isinstance(index, (int, float)): + if isinstance(index, (int, float)): return cls._type_numeric - elif isinstance(index, six.string_types): + if isinstance(index, six.string_types): return cls._type_string - else: - return cls._type_object + + return cls._type_object @classmethod def _extract_child(cls, value, path): @@ -737,10 +737,10 @@ def _compare(self, other): if self_key < other_key: return -1 - elif self_key > other_key: + if self_key > other_key: return 1 - else: - return 0 + + return 0 def __lt__(self, other): return self._compare(other) < 0 @@ -755,10 +755,10 @@ def __ge__(self, other): return self._compare(other) >= 0 def __eq__(self, other): - return self._compare(other) is 0 + return self._compare(other) == 0 -class _DatabaseService(object): +class _DatabaseService: """Service that maintains a collection of database clients.""" _DEFAULT_AUTH_OVERRIDE = '_admin_' @@ -772,7 +772,7 @@ def __init__(self, app): else: self._db_url = None auth_override = _DatabaseService._get_auth_override(app) - if auth_override != self._DEFAULT_AUTH_OVERRIDE and auth_override != {}: + if auth_override not in (self._DEFAULT_AUTH_OVERRIDE, {}): self._auth_override = json.dumps(auth_override, separators=(',', ':')) else: self._auth_override = None @@ -832,8 +832,8 @@ def _parse_db_url(cls, url, emulator_host=None): parsed_url = urllib.parse.urlparse(url) if parsed_url.netloc.endswith('.firebaseio.com'): return cls._parse_production_url(parsed_url, emulator_host) - else: - return cls._parse_emulator_url(parsed_url) + + return cls._parse_emulator_url(parsed_url) @classmethod def _parse_production_url(cls, parsed_url, emulator_host): @@ -875,8 +875,8 @@ def _get_auth_override(cls, app): if not isinstance(auth_override, dict): raise ValueError('Invalid databaseAuthVariableOverride option: "{0}". Override ' 'value must be a dict or None.'.format(auth_override)) - else: - return auth_override + + return auth_override def close(self): for value in self._clients.values(): diff --git a/firebase_admin/firestore.py b/firebase_admin/firestore.py index a9887b195..32c9897d5 100644 --- a/firebase_admin/firestore.py +++ b/firebase_admin/firestore.py @@ -54,7 +54,7 @@ def client(app=None): return fs_client.get() -class _FirestoreClient(object): +class _FirestoreClient: """Holds a Google Cloud Firestore client instance.""" def __init__(self, credentials, project): diff --git a/firebase_admin/instance_id.py b/firebase_admin/instance_id.py index e9134fc28..f90d058cc 100644 --- a/firebase_admin/instance_id.py +++ b/firebase_admin/instance_id.py @@ -53,7 +53,7 @@ def delete_instance_id(instance_id, app=None): _get_iid_service(app).delete_instance_id(instance_id) -class _InstanceIdService(object): +class _InstanceIdService: """Provides methods for interacting with the remote instance ID service.""" error_codes = { @@ -96,5 +96,5 @@ def _extract_message(self, instance_id, error): msg = self.error_codes.get(status) if msg: return 'Instance ID "{0}": {1}'.format(instance_id, msg) - else: - return 'Instance ID "{0}": {1}'.format(instance_id, error) + + return 'Instance ID "{0}": {1}'.format(instance_id, error) diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index a35afc87a..71366e5c4 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -206,7 +206,7 @@ def unsubscribe_from_topic(tokens, topic, app=None): tokens, topic, 'iid/v1:batchRemove') -class ErrorInfo(object): +class ErrorInfo: """An error encountered when performing a topic management operation.""" def __init__(self, index, reason): @@ -224,7 +224,7 @@ def reason(self): return self._reason -class TopicManagementResponse(object): +class TopicManagementResponse: """The response received from a topic management operation.""" def __init__(self, resp): @@ -256,7 +256,7 @@ def errors(self): return self._errors -class BatchResponse(object): +class BatchResponse: """The response received from a batch request to the FCM API.""" def __init__(self, responses): @@ -277,7 +277,7 @@ def failure_count(self): return len(self.responses) - self.success_count -class SendResponse(object): +class SendResponse: """The response received from an individual batched request to the FCM API.""" def __init__(self, resp, exception): @@ -302,7 +302,7 @@ def exception(self): return self._exception -class _MessagingService(object): +class _MessagingService: """Service class that implements Firebase Cloud Messaging (FCM) functionality.""" FCM_URL = 'https://fcm.googleapis.com/v1/projects/{0}/messages:send' @@ -342,6 +342,7 @@ def encode_message(cls, message): return cls.JSON_ENCODER.default(message) def send(self, message, dry_run=False): + """Sends the given message to FCM via the FCM v1 API.""" data = self._message_data(message, dry_run) try: resp = self._client.body( diff --git a/firebase_admin/project_management.py b/firebase_admin/project_management.py index 68e10797c..076542bda 100644 --- a/firebase_admin/project_management.py +++ b/firebase_admin/project_management.py @@ -140,7 +140,7 @@ def _check_not_none(obj, field_name): return obj -class AndroidApp(object): +class AndroidApp: """A reference to an Android app within a Firebase project. Note: Unless otherwise specified, all methods defined in this class make an RPC. @@ -238,7 +238,7 @@ def delete_sha_certificate(self, certificate_to_delete): return self._service.delete_sha_certificate(certificate_to_delete) -class IOSApp(object): +class IOSApp: """A reference to an iOS app within a Firebase project. Note: Unless otherwise specified, all methods defined in this class make an RPC. @@ -294,7 +294,7 @@ def get_config(self): return self._service.get_ios_app_config(self._app_id) -class _AppMetadata(object): +class _AppMetadata: """Detailed information about a Firebase Android or iOS app.""" def __init__(self, name, app_id, display_name, project_id): @@ -382,7 +382,7 @@ def __hash__(self): return hash((self._name, self.app_id, self.display_name, self.project_id, self.bundle_id)) -class SHACertificate(object): +class SHACertificate: """Represents a SHA-1 or SHA-256 certificate associated with an Android app.""" SHA_1 = 'SHA_1' @@ -456,7 +456,7 @@ def __hash__(self): return hash((self.name, self.sha_hash, self.cert_type)) -class _ProjectManagementService(object): +class _ProjectManagementService: """Provides methods for interacting with the Firebase Project Management Service.""" BASE_URL = 'https://firebase.googleapis.com' @@ -613,10 +613,10 @@ def _poll_app_creation(self, operation_name): response = poll_response.get('response') if response: return response - else: - raise exceptions.UnknownError( - 'Polling finished, but the operation terminated in an error.', - http_response=http_response) + + raise exceptions.UnknownError( + 'Polling finished, but the operation terminated in an error.', + http_response=http_response) raise exceptions.DeadlineExceededError('Polling deadline exceeded.') def get_android_app_config(self, app_id): diff --git a/firebase_admin/storage.py b/firebase_admin/storage.py index 6aab1ebc1..a080b31ef 100644 --- a/firebase_admin/storage.py +++ b/firebase_admin/storage.py @@ -54,7 +54,7 @@ def bucket(name=None, app=None): return client.bucket(name) -class _StorageClient(object): +class _StorageClient: """Holds a Google Cloud Storage client instance.""" def __init__(self, credentials, project, default_bucket): @@ -77,7 +77,7 @@ def bucket(self, name=None): 'Storage bucket name not specified. Specify the bucket name via the ' '"storageBucket" option when initializing the App, or specify the bucket ' 'name explicitly when calling the storage.bucket() function.') - elif not bucket_name or not isinstance(bucket_name, six.string_types): + if not bucket_name or not isinstance(bucket_name, six.string_types): raise ValueError( 'Invalid storage bucket name: "{0}". Bucket name must be a non-empty ' 'string.'.format(bucket_name)) diff --git a/integration/test_auth.py b/integration/test_auth.py index 9d5d0dfe3..c3759ce12 100644 --- a/integration/test_auth.py +++ b/integration/test_auth.py @@ -18,16 +18,16 @@ import random import time import uuid -import six +import google.oauth2.credentials +from google.auth import transport import pytest import requests +import six import firebase_admin from firebase_admin import auth from firebase_admin import credentials -import google.oauth2.credentials -from google.auth import transport _verify_token_url = 'https://www.googleapis.com/identitytoolkit/v3/relyingparty/verifyCustomToken' @@ -263,7 +263,7 @@ def test_create_user(new_user): assert user.custom_claims is None assert user.user_metadata.creation_timestamp > 0 assert user.user_metadata.last_sign_in_timestamp is None - assert len(user.provider_data) is 0 + assert len(user.provider_data) == 0 with pytest.raises(auth.UidAlreadyExistsError): auth.create_user(uid=new_user.uid) diff --git a/integration/test_db.py b/integration/test_db.py index 4c2f6bde2..abd02660f 100644 --- a/integration/test_db.py +++ b/integration/test_db.py @@ -31,8 +31,8 @@ def integration_conf(request): host_override = os.environ.get('FIREBASE_DATABASE_EMULATOR_HOST') if host_override: return None, 'fake-project-id' - else: - return conftest.integration_conf(request) + + return conftest.integration_conf(request) @pytest.fixture(scope='module') @@ -83,7 +83,7 @@ def testref(update_rules, testdata, app): return ref -class TestReferenceAttributes(object): +class TestReferenceAttributes: """Test cases for attributes exposed by db.Reference class.""" def test_ref_attributes(self, testref): @@ -101,7 +101,7 @@ def test_parent(self, testref): assert parent.path == '/_adminsdk/python' -class TestReadOperations(object): +class TestReadOperations: """Test cases for reading node values.""" def test_get_value(self, testref, testdata): @@ -143,7 +143,7 @@ def test_get_nonexisting_child_value(self, testref): assert testref.child('none_existing').get() is None -class TestWriteOperations(object): +class TestWriteOperations: """Test cases for creating and updating node values.""" def test_push(self, testref): @@ -247,7 +247,7 @@ def test_delete(self, testref): assert ref.get() is None -class TestAdvancedQueries(object): +class TestAdvancedQueries: """Test cases for advanced interactions via the db.Query interface.""" height_sorted = [ @@ -352,7 +352,7 @@ def none_override_app(request, update_rules): firebase_admin.delete_app(app) -class TestAuthVariableOverride(object): +class TestAuthVariableOverride: """Test cases for database auth variable overrides.""" def init_ref(self, path, app): diff --git a/integration/test_messaging.py b/integration/test_messaging.py index bc4f1d1ca..001b96a0a 100644 --- a/integration/test_messaging.py +++ b/integration/test_messaging.py @@ -140,7 +140,7 @@ def test_send_multicast(): batch_response = messaging.send_multicast(multicast) - assert batch_response.success_count is 0 + assert batch_response.success_count == 0 assert batch_response.failure_count == 2 assert len(batch_response.responses) == 2 for response in batch_response.responses: diff --git a/lint.sh b/lint.sh index aeb37f741..0fd5058a3 100755 --- a/lint.sh +++ b/lint.sh @@ -31,8 +31,8 @@ function lintChangedFiles () { set -o errexit set -o nounset -SKIP_FOR_TESTS="redefined-outer-name,protected-access,missing-docstring,too-many-lines" -SKIP_FOR_SNIPPETS="${SKIP_FOR_TESTS},reimported,unused-variable" +SKIP_FOR_TESTS="redefined-outer-name,protected-access,missing-docstring,too-many-lines,len-as-condition" +SKIP_FOR_SNIPPETS="${SKIP_FOR_TESTS},reimported,unused-variable,unused-import,import-outside-toplevel" if [[ "$#" -eq 1 && "$1" = "all" ]] then diff --git a/requirements.txt b/requirements.txt index 2f1a09a5b..6d28b38ac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -pylint == 1.6.4 +pylint == 2.3.1 pytest >= 3.6.0 pytest-cov >= 2.4.0 pytest-localserver >= 0.4.1 diff --git a/snippets/auth/index.py b/snippets/auth/index.py index 552875696..b1c091064 100644 --- a/snippets/auth/index.py +++ b/snippets/auth/index.py @@ -385,8 +385,8 @@ def serve_content_for_admin(decoded_claims): # Check custom claims to confirm user is an admin. if decoded_claims.get('admin') is True: return serve_content_for_admin(decoded_claims) - else: - return flask.abort(401, 'Insufficient permissions') + + return flask.abort(401, 'Insufficient permissions') except auth.InvalidSessionCookieError: # Session cookie is invalid, expired or revoked. Force user to login. return flask.redirect('/login') diff --git a/tests/test_app.py b/tests/test_app.py index 9d3446692..fe3a43a5c 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -30,7 +30,7 @@ # This fixture will ignore the environment variable pointing to the default # configuration for the duration of the tests. -class CredentialProvider(object): +class CredentialProvider: def init(self): pass @@ -73,7 +73,7 @@ def get(self): return None -class AppService(object): +class AppService: def __init__(self, app): self._app = app @@ -89,8 +89,8 @@ def app_credential(request): def init_app(request): if request.param: return firebase_admin.initialize_app(CREDENTIAL, name=request.param) - else: - return firebase_admin.initialize_app(CREDENTIAL) + + return firebase_admin.initialize_app(CREDENTIAL) @pytest.fixture(scope="function") def env_test_case(request): @@ -211,7 +211,7 @@ def revert_config_env(config_old): elif os.environ.get(CONFIG_JSON) is not None: del os.environ[CONFIG_JSON] -class TestFirebaseApp(object): +class TestFirebaseApp: """Test cases for App initialization and life cycle.""" invalid_credentials = ['', 'foo', 0, 1, dict(), list(), tuple(), True, False] diff --git a/tests/test_credentials.py b/tests/test_credentials.py index 6f081d796..d78ef5192 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -22,9 +22,9 @@ from google.auth import exceptions from google.oauth2 import credentials as gcredentials from google.oauth2 import service_account -from firebase_admin import credentials import pytest +from firebase_admin import credentials from tests import testutils @@ -33,7 +33,7 @@ def check_scopes(g_credential): assert sorted(credentials._scopes) == sorted(g_credential.scopes) -class TestCertificate(object): +class TestCertificate: invalid_certs = { 'NonExistingFile': ('non_existing.json', IOError), @@ -91,7 +91,7 @@ def app_default(request): del os.environ[var_name] -class TestApplicationDefault(object): +class TestApplicationDefault: @pytest.mark.parametrize('app_default', [testutils.resource_filename('service_account.json')], indirect=True) @@ -122,7 +122,7 @@ def test_nonexisting_path(self, app_default): creds.get_credential() # This now throws. -class TestRefreshToken(object): +class TestRefreshToken: def test_init_from_file(self): credential = credentials.RefreshToken( diff --git a/tests/test_db.py b/tests/test_db.py index e9f8f7dda..b20f99cb9 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -48,7 +48,7 @@ def send(self, request, **kwargs): return resp -class MockSSEClient(object): +class MockSSEClient: """A mock SSE client that mimics long-lived HTTP connections.""" def __init__(self, events): @@ -62,11 +62,11 @@ def close(self): self.closed = True -class _Object(object): +class _Object: pass -class TestReferencePath(object): +class TestReferencePath: """Test cases for Reference paths.""" # path => (fullstr, key, parent) @@ -127,7 +127,7 @@ def test_invalid_child(self, child): parent.child(child) -class _RefOperations(object): +class _RefOperations: """A collection of operations that can be performed using a ``db.Reference``. This can be used to test any functionality that is common across multiple API calls. @@ -159,7 +159,7 @@ def get_ops(cls): return [cls.get, cls.push, cls.set, cls.delete, cls.query] -class TestReference(object): +class TestReference: """Test cases for database queries via References.""" test_url = 'https://test.firebaseio.com' @@ -381,7 +381,7 @@ def test_set_invalid_update(self, update): recorder = self.instrument(ref, '') with pytest.raises(ValueError): ref.update(update) - assert len(recorder) is 0 + assert len(recorder) == 0 @pytest.mark.parametrize('data', valid_values) def test_push(self, data): @@ -527,7 +527,7 @@ def test_other_error(self, error_code, func): assert excinfo.value.http_response is not None -class TestListenerRegistration(object): +class TestListenerRegistration: """Test cases for receiving events via ListenerRegistrations.""" def test_listen_error(self): @@ -598,7 +598,7 @@ def wait_for(cls, events, count=1, timeout_seconds=5): raise pytest.fail('Timed out while waiting for events') -class TestReferenceWithAuthOverride(object): +class TestReferenceWithAuthOverride: """Test cases for database queries via References.""" test_url = 'https://test.firebaseio.com' @@ -671,7 +671,7 @@ def test_range_query(self): assert recorder[0].headers['User-Agent'] == db._USER_AGENT -class TestDatabaseInitialization(object): +class TestDatabaseInitialization: """Test cases for database initialization.""" def teardown_method(self): @@ -847,13 +847,13 @@ def initquery(request): ref = db.Reference(path='foo') if request.param == '$key': return ref.order_by_key(), request.param - elif request.param == '$value': + if request.param == '$value': return ref.order_by_value(), request.param - else: - return ref.order_by_child(request.param), request.param + return ref.order_by_child(request.param), request.param -class TestQuery(object): + +class TestQuery: """Test cases for db.Query class.""" valid_paths = { @@ -982,7 +982,7 @@ def test_invalid_query_args(self): db.Query(order_by='$key', client=ref._client, pathurl=ref._add_suffix(), foo='bar') -class TestSorter(object): +class TestSorter: """Test cases for db._Sorter class.""" value_test_cases = [ diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 98d9ce5e9..3df7ec2e3 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -37,7 +37,7 @@ }) -class TestRequests(object): +class TestRequests: def test_timeout_error(self): error = requests.exceptions.Timeout('Test error') @@ -156,7 +156,6 @@ def test_handle_platform_error_with_custom_handler_ignore(self): def _custom_handler(cause, message, error_dict): invocations.append((cause, message, error_dict)) - return None firebase_error = _utils.handle_platform_error_from_requests(error, _custom_handler) @@ -180,7 +179,7 @@ def _create_response(self, status=500, payload=None): return resp, exc -class TestGoogleApiClient(object): +class TestGoogleApiClient: @pytest.mark.parametrize('error', [ socket.timeout('Test error'), @@ -313,7 +312,6 @@ def test_handle_platform_error_with_custom_handler_ignore(self): def _custom_handler(cause, message, error_dict, http_response): invocations.append((cause, message, error_dict, http_response)) - return None firebase_error = _utils.handle_platform_error_from_googleapiclient(error, _custom_handler) diff --git a/tests/test_firestore.py b/tests/test_firestore.py index 01b019333..768eb637e 100644 --- a/tests/test_firestore.py +++ b/tests/test_firestore.py @@ -30,7 +30,7 @@ @pytest.mark.skipif( platform.python_implementation() == 'PyPy', reason='Firestore is not supported on PyPy') -class TestFirestore(object): +class TestFirestore: """Test class Firestore APIs.""" def teardown_method(self, method): diff --git a/tests/test_http_client.py b/tests/test_http_client.py index a0c6dc654..ce35e5ce4 100644 --- a/tests/test_http_client.py +++ b/tests/test_http_client.py @@ -84,7 +84,7 @@ def _instrument(client, payload, status=200): return recorder -class TestHttpRetry(object): +class TestHttpRetry: """Unit tests for the default HTTP retry configuration.""" ENTITY_ENCLOSING_METHODS = ['post', 'put', 'patch'] diff --git a/tests/test_instance_id.py b/tests/test_instance_id.py index 83e66491a..a13506a07 100644 --- a/tests/test_instance_id.py +++ b/tests/test_instance_id.py @@ -50,7 +50,7 @@ exceptions.UnavailableError), } -class TestDeleteInstanceId(object): +class TestDeleteInstanceId: def teardown_method(self): testutils.cleanup_apps() @@ -132,4 +132,4 @@ def test_invalid_instance_id(self, iid): _, recorder = self._instrument_iid_service(app) with pytest.raises(ValueError): instance_id.delete_instance_id(iid) - assert len(recorder) is 0 + assert len(recorder) == 0 diff --git a/tests/test_messaging.py b/tests/test_messaging.py index 96b512577..36f5943be 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -62,7 +62,7 @@ def check_exception(exception, message, status): assert exception.http_response.status_code == status -class TestMessageStr(object): +class TestMessageStr: @pytest.mark.parametrize('msg', [ messaging.Message(), @@ -90,7 +90,7 @@ def test_data_message(self): 'k1': 'v1', 'k2': 'v2'})) == '{"data": {"k1": "v1", "k2": "v2"}, "topic": "topic"}' -class TestMulticastMessage(object): +class TestMulticastMessage: @pytest.mark.parametrize('tokens', NON_LIST_ARGS) def test_invalid_tokens_type(self, tokens): @@ -117,7 +117,7 @@ def test_tokens_type(self): assert len(message.tokens) == 500 -class TestMessageEncoder(object): +class TestMessageEncoder: @pytest.mark.parametrize('msg', [ messaging.Message(), @@ -183,7 +183,7 @@ def test_fcm_options(self): {'topic': 'topic'}) -class TestNotificationEncoder(object): +class TestNotificationEncoder: @pytest.mark.parametrize('data', NON_OBJECT_ARGS) def test_invalid_notification(self, data): @@ -219,7 +219,7 @@ def test_notification_message(self): {'topic': 'topic', 'notification': {'title': 't'}}) -class TestFcmOptionEncoder(object): +class TestFcmOptionEncoder: @pytest.mark.parametrize('label', [ '!', @@ -266,7 +266,7 @@ def test_fcm_options(self): }) -class TestAndroidConfigEncoder(object): +class TestAndroidConfigEncoder: @pytest.mark.parametrize('data', NON_OBJECT_ARGS) def test_invalid_android(self, data): @@ -367,7 +367,7 @@ def test_android_ttl(self, ttl): check_encoding(msg, expected) -class TestAndroidNotificationEncoder(object): +class TestAndroidNotificationEncoder: def _check_notification(self, notification): with pytest.raises(ValueError) as excinfo: @@ -603,7 +603,7 @@ def test_android_notification(self): check_encoding(msg, expected) -class TestLightSettingsEncoder(object): +class TestLightSettingsEncoder: def _check_light_settings(self, light_settings): with pytest.raises(ValueError) as excinfo: @@ -717,7 +717,7 @@ def test_light_settings(self): check_encoding(msg, expected) -class TestWebpushConfigEncoder(object): +class TestWebpushConfigEncoder: @pytest.mark.parametrize('data', NON_OBJECT_ARGS) def test_invalid_webpush(self, data): @@ -763,7 +763,7 @@ def test_webpush_config(self): check_encoding(msg, expected) -class TestWebpushFCMOptionsEncoder(object): +class TestWebpushFCMOptionsEncoder: @pytest.mark.parametrize('data', NON_OBJECT_ARGS) def test_invalid_webpush_fcm_options(self, data): @@ -809,7 +809,7 @@ def test_webpush_options(self): check_encoding(msg, expected) -class TestWebpushNotificationEncoder(object): +class TestWebpushNotificationEncoder: def _check_notification(self, notification): with pytest.raises(ValueError) as excinfo: @@ -1012,7 +1012,7 @@ def test_invalid_action_icon(self, data): assert str(excinfo.value) == 'WebpushNotificationAction.icon must be a string.' -class TestAPNSConfigEncoder(object): +class TestAPNSConfigEncoder: @pytest.mark.parametrize('data', NON_OBJECT_ARGS) def test_invalid_apns(self, data): @@ -1051,7 +1051,7 @@ def test_apns_config(self): check_encoding(msg, expected) -class TestAPNSPayloadEncoder(object): +class TestAPNSPayloadEncoder: @pytest.mark.parametrize('data', NON_OBJECT_ARGS) def test_invalid_payload(self, data): @@ -1085,7 +1085,7 @@ def test_apns_payload(self): check_encoding(msg, expected) -class TestApsEncoder(object): +class TestApsEncoder: def _encode_aps(self, aps): return check_encoding(messaging.Message( @@ -1227,7 +1227,7 @@ def test_aps_custom_data(self): check_encoding(msg, expected) -class TestApsSoundEncoder(object): +class TestApsSoundEncoder: def _check_sound(self, sound): with pytest.raises(ValueError) as excinfo: @@ -1335,7 +1335,7 @@ def test_critical_sound_name_only(self): check_encoding(msg, expected) -class TestApsAlertEncoder(object): +class TestApsAlertEncoder: def _check_alert(self, alert): with pytest.raises(ValueError) as excinfo: @@ -1539,7 +1539,7 @@ def test_aps_alert_custom_data_override(self): } check_encoding(msg, expected) -class TestTimeout(object): +class TestTimeout: @classmethod def setup_class(cls): @@ -1577,7 +1577,7 @@ def test_topic_management_timeout(self): assert self.recorder[0]._extra_kwargs['timeout'] == pytest.approx(4, 0.001) -class TestSend(object): +class TestSend: _DEFAULT_RESPONSE = json.dumps({'name': 'message-id'}) _CLIENT_VERSION = 'fire-admin-python/{0}'.format(firebase_admin.__version__) @@ -1753,7 +1753,7 @@ def test_send_unknown_fcm_error_code(self, status): assert json.loads(recorder[0].body.decode()) == body -class TestBatch(object): +class TestBatch: @classmethod def setup_class(cls): @@ -1822,8 +1822,8 @@ def test_send_all(self): payload=self._batch_payload([(200, payload), (200, payload)])) msg = messaging.Message(topic='foo') batch_response = messaging.send_all([msg, msg], dry_run=True) - assert batch_response.success_count is 2 - assert batch_response.failure_count is 0 + assert batch_response.success_count == 2 + assert batch_response.failure_count == 0 assert len(batch_response.responses) == 2 assert [r.message_id for r in batch_response.responses] == ['message-id', 'message-id'] assert all([r.success for r in batch_response.responses]) @@ -1842,8 +1842,8 @@ def test_send_all_detailed_error(self, status): payload=self._batch_payload([(200, success_payload), (status, error_payload)])) msg = messaging.Message(topic='foo') batch_response = messaging.send_all([msg, msg]) - assert batch_response.success_count is 1 - assert batch_response.failure_count is 1 + assert batch_response.success_count == 1 + assert batch_response.failure_count == 1 assert len(batch_response.responses) == 2 success_response = batch_response.responses[0] assert success_response.message_id == 'message-id' @@ -1869,8 +1869,8 @@ def test_send_all_canonical_error_code(self, status): payload=self._batch_payload([(200, success_payload), (status, error_payload)])) msg = messaging.Message(topic='foo') batch_response = messaging.send_all([msg, msg]) - assert batch_response.success_count is 1 - assert batch_response.failure_count is 1 + assert batch_response.success_count == 1 + assert batch_response.failure_count == 1 assert len(batch_response.responses) == 2 success_response = batch_response.responses[0] assert success_response.message_id == 'message-id' @@ -1903,8 +1903,8 @@ def test_send_all_fcm_error_code(self, status, fcm_error_code, exc_type): payload=self._batch_payload([(200, success_payload), (status, error_payload)])) msg = messaging.Message(topic='foo') batch_response = messaging.send_all([msg, msg]) - assert batch_response.success_count is 1 - assert batch_response.failure_count is 1 + assert batch_response.success_count == 1 + assert batch_response.failure_count == 1 assert len(batch_response.responses) == 2 success_response = batch_response.responses[0] assert success_response.message_id == 'message-id' @@ -1997,8 +1997,8 @@ def test_send_multicast(self): payload=self._batch_payload([(200, payload), (200, payload)])) msg = messaging.MulticastMessage(tokens=['foo', 'foo']) batch_response = messaging.send_multicast(msg, dry_run=True) - assert batch_response.success_count is 2 - assert batch_response.failure_count is 0 + assert batch_response.success_count == 2 + assert batch_response.failure_count == 0 assert len(batch_response.responses) == 2 assert [r.message_id for r in batch_response.responses] == ['message-id', 'message-id'] assert all([r.success for r in batch_response.responses]) @@ -2017,8 +2017,8 @@ def test_send_multicast_detailed_error(self, status): payload=self._batch_payload([(200, success_payload), (status, error_payload)])) msg = messaging.MulticastMessage(tokens=['foo', 'foo']) batch_response = messaging.send_multicast(msg) - assert batch_response.success_count is 1 - assert batch_response.failure_count is 1 + assert batch_response.success_count == 1 + assert batch_response.failure_count == 1 assert len(batch_response.responses) == 2 success_response = batch_response.responses[0] assert success_response.message_id == 'message-id' @@ -2045,8 +2045,8 @@ def test_send_multicast_canonical_error_code(self, status): payload=self._batch_payload([(200, success_payload), (status, error_payload)])) msg = messaging.MulticastMessage(tokens=['foo', 'foo']) batch_response = messaging.send_multicast(msg) - assert batch_response.success_count is 1 - assert batch_response.failure_count is 1 + assert batch_response.success_count == 1 + assert batch_response.failure_count == 1 assert len(batch_response.responses) == 2 success_response = batch_response.responses[0] assert success_response.message_id == 'message-id' @@ -2079,8 +2079,8 @@ def test_send_multicast_fcm_error_code(self, status): payload=self._batch_payload([(200, success_payload), (status, error_payload)])) msg = messaging.MulticastMessage(tokens=['foo', 'foo']) batch_response = messaging.send_multicast(msg) - assert batch_response.success_count is 1 - assert batch_response.failure_count is 1 + assert batch_response.success_count == 1 + assert batch_response.failure_count == 1 assert len(batch_response.responses) == 2 success_response = batch_response.responses[0] assert success_response.message_id == 'message-id' @@ -2152,7 +2152,7 @@ def test_send_multicast_batch_fcm_error_code(self, status): check_exception(excinfo.value, 'test error', status) -class TestTopicManagement(object): +class TestTopicManagement: _DEFAULT_RESPONSE = json.dumps({'results': [{}, {'error': 'error_reason'}]}) _DEFAULT_ERROR_RESPONSE = json.dumps({'error': 'error_reason'}) diff --git a/tests/test_project_management.py b/tests/test_project_management.py index e8353e212..aa717bbf7 100644 --- a/tests/test_project_management.py +++ b/tests/test_project_management.py @@ -202,7 +202,7 @@ NOT_FOUND_RESPONSE = '{"error": {"message": "Failed to find the resource"}}' UNAVAILABLE_RESPONSE = '{"error": {"message": "Backend servers are over capacity"}}' -class TestAndroidAppMetadata(object): +class TestAndroidAppMetadata: def test_create_android_app_metadata_errors(self): # package_name must be a non-empty string. @@ -289,7 +289,6 @@ def test_android_app_metadata_eq_and_hash(self): # Don't trigger __ne__. assert not metadata_1 == ios_metadata # pylint: disable=unneeded-not assert metadata_1 != ios_metadata - assert metadata_1 == metadata_1 assert metadata_1 != metadata_2 assert metadata_1 != metadata_3 assert metadata_1 != metadata_4 @@ -315,7 +314,7 @@ def test_android_app_metadata_project_id(self): assert ANDROID_APP_METADATA.project_id == 'test-project-id' -class TestIOSAppMetadata(object): +class TestIOSAppMetadata: def test_create_ios_app_metadata_errors(self): # bundle_id must be a non-empty string. @@ -402,7 +401,6 @@ def test_ios_app_metadata_eq_and_hash(self): # Don't trigger __ne__. assert not metadata_1 == android_metadata # pylint: disable=unneeded-not assert metadata_1 != android_metadata - assert metadata_1 == metadata_1 assert metadata_1 != metadata_2 assert metadata_1 != metadata_3 assert metadata_1 != metadata_4 @@ -427,7 +425,7 @@ def test_ios_app_metadata_project_id(self): assert IOS_APP_METADATA.project_id == 'test-project-id' -class TestSHACertificate(object): +class TestSHACertificate: def test_create_sha_certificate_errors(self): # sha_hash cannot be None. with pytest.raises(ValueError): @@ -469,7 +467,6 @@ def test_sha_certificate_eq(self): 'cert_type': 'SHA_1', } - assert sha_cert_1 == sha_cert_1 assert sha_cert_1 != sha_cert_2 assert sha_cert_1 != sha_cert_3 assert sha_cert_1 != sha_cert_4 @@ -496,7 +493,7 @@ def test_sha_certificate_cert_type(self): assert SHA_256_CERTIFICATE.cert_type == 'SHA_256' -class BaseProjectManagementTest(object): +class BaseProjectManagementTest: @classmethod def setup_class(cls): project_management._ProjectManagementService.POLL_BASE_WAIT_TIME_SECONDS = 0.01 diff --git a/tests/test_sseclient.py b/tests/test_sseclient.py index a9ec2edf7..881ecc6b9 100644 --- a/tests/test_sseclient.py +++ b/tests/test_sseclient.py @@ -36,7 +36,7 @@ def send(self, request, **kwargs): return resp -class TestSSEClient(object): +class TestSSEClient: """Test cases for the SSEClient""" test_url = "https://test.firebaseio.com" @@ -54,7 +54,7 @@ def test_init_sseclient(self): payload = 'event: put\ndata: {"path":"/","data":"testevent"}\n\n' sseclient = self.init_sse(payload) assert sseclient.url == self.test_url - assert sseclient.session != None + assert sseclient.session is not None def test_single_event(self): payload = 'event: put\ndata: {"path":"/","data":"testevent"}\n\n' @@ -120,7 +120,7 @@ def test_event_separators(self): assert len(recorder) == 1 -class TestEvent(object): +class TestEvent: """Test cases for server-side events""" def test_normal(self): diff --git a/tests/test_token_gen.py b/tests/test_token_gen.py index e016b8fb1..e92fd0059 100644 --- a/tests/test_token_gen.py +++ b/tests/test_token_gen.py @@ -166,7 +166,7 @@ def revoked_tokens(): return json.dumps(mock_user) -class TestCreateCustomToken(object): +class TestCreateCustomToken: valid_args = { 'Basic': (MOCK_UID, {'one': 2, 'three': 'four'}), @@ -283,7 +283,7 @@ def _verify_signer(self, token, signer): assert body['sub'] == signer -class TestCreateSessionCookie(object): +class TestCreateSessionCookie: @pytest.mark.parametrize('id_token', [None, '', 0, 1, True, False, list(), dict(), tuple()]) def test_invalid_id_token(self, user_mgt_app, id_token): @@ -350,7 +350,7 @@ def test_unexpected_response(self, user_mgt_app): TEST_SESSION_COOKIE = _get_session_cookie() -class TestVerifyIdToken(object): +class TestVerifyIdToken: valid_tokens = { 'BinaryToken': TEST_ID_TOKEN, @@ -475,7 +475,7 @@ def test_certificate_request_failure(self, user_mgt_app): assert excinfo.value.http_response is None -class TestVerifySessionCookie(object): +class TestVerifySessionCookie: valid_cookies = { 'BinaryCookie': TEST_SESSION_COOKIE, @@ -590,7 +590,7 @@ def test_certificate_request_failure(self, user_mgt_app): assert excinfo.value.http_response is None -class TestCertificateCaching(object): +class TestCertificateCaching: def test_certificate_caching(self, user_mgt_app, httpserver): httpserver.serve_content(MOCK_PUBLIC_CERTS, 200, headers={'Cache-Control': 'max-age=3600'}) diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index f4e03cc3f..f1572baf2 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -19,6 +19,7 @@ import time import pytest +from six.moves import urllib import firebase_admin from firebase_admin import auth @@ -28,8 +29,6 @@ from firebase_admin import _user_mgt from tests import testutils -from six.moves import urllib - INVALID_STRINGS = [None, '', 0, 1, True, False, list(), tuple(), dict()] INVALID_DICTS = [None, 'foo', 0, 1, True, False, list(), tuple()] @@ -101,7 +100,7 @@ def _check_user_record(user, expected_uid='testuser'): assert provider.provider_id == 'phone' -class TestAuthServiceInitialization(object): +class TestAuthServiceInitialization: def test_fail_on_no_project_id(self): app = firebase_admin.initialize_app(testutils.MockCredential(), name='userMgt2') @@ -109,7 +108,7 @@ def test_fail_on_no_project_id(self): auth._get_auth_service(app) firebase_admin.delete_app(app) -class TestUserRecord(object): +class TestUserRecord: # Input dict must be non-empty, and must not contain unsupported keys. @pytest.mark.parametrize('data', INVALID_DICTS + [{}, {'foo':'bar'}]) @@ -186,10 +185,10 @@ def test_tokens_valid_after_time(self): def test_no_tokens_valid_after_time(self): user = auth.UserRecord({'localId' : 'user'}) - assert user.tokens_valid_after_timestamp is 0 + assert user.tokens_valid_after_timestamp == 0 -class TestGetUser(object): +class TestGetUser: @pytest.mark.parametrize('arg', INVALID_STRINGS + ['a'*129]) def test_invalid_get_user(self, arg, user_mgt_app): @@ -295,7 +294,7 @@ def test_get_user_by_phone_http_error(self, user_mgt_app): assert excinfo.value.cause is not None -class TestCreateUser(object): +class TestCreateUser: already_exists_errors = { 'DUPLICATE_EMAIL': auth.EmailAlreadyExistsError, @@ -395,7 +394,7 @@ def test_create_user_unexpected_response(self, user_mgt_app): assert isinstance(excinfo.value, exceptions.UnknownError) -class TestUpdateUser(object): +class TestUpdateUser: @pytest.mark.parametrize('arg', INVALID_STRINGS + ['a'*129]) def test_invalid_uid(self, user_mgt_app, arg): @@ -513,7 +512,7 @@ def test_update_user_valid_since(self, user_mgt_app, arg): assert request == {'localId': 'testuser', 'validSince': int(arg)} -class TestSetCustomUserClaims(object): +class TestSetCustomUserClaims: @pytest.mark.parametrize('arg', INVALID_STRINGS + ['a'*129]) def test_invalid_uid(self, user_mgt_app, arg): @@ -576,7 +575,7 @@ def test_set_custom_user_claims_error(self, user_mgt_app): assert excinfo.value.cause is not None -class TestDeleteUser(object): +class TestDeleteUser: @pytest.mark.parametrize('arg', INVALID_STRINGS + ['a'*129]) def test_invalid_delete_user(self, user_mgt_app, arg): @@ -606,7 +605,7 @@ def test_delete_user_unexpected_response(self, user_mgt_app): assert isinstance(excinfo.value, exceptions.UnknownError) -class TestListUsers(object): +class TestListUsers: @pytest.mark.parametrize('arg', [None, 'foo', list(), dict(), 0, -1, 1001, False]) def test_invalid_max_results(self, user_mgt_app, arg): @@ -625,7 +624,7 @@ def test_list_single_page(self, user_mgt_app): assert page.next_page_token == '' assert page.has_next_page is False assert page.get_next_page() is None - users = [user for user in page.iterate_all()] + users = list(user for user in page.iterate_all()) assert len(users) == 2 self._check_rpc_calls(recorder) @@ -710,7 +709,7 @@ def test_list_users_stop_iteration(self, user_mgt_app): assert len(page.users) == 3 iterator = page.iterate_all() - users = [user for user in iterator] + users = list(user for user in iterator) assert len(page.users) == 3 with pytest.raises(StopIteration): next(iterator) @@ -721,9 +720,9 @@ def test_list_users_no_users_response(self, user_mgt_app): response = {'users': []} _instrument_user_manager(user_mgt_app, 200, json.dumps(response)) page = auth.list_users(app=user_mgt_app) - assert len(page.users) is 0 - users = [user for user in page.iterate_all()] - assert len(users) is 0 + assert len(page.users) == 0 + users = list(user for user in page.iterate_all()) + assert len(users) == 0 def test_list_users_with_max_results(self, user_mgt_app): _, recorder = _instrument_user_manager(user_mgt_app, 200, MOCK_LIST_USERS_RESPONSE) @@ -777,7 +776,7 @@ def _check_rpc_calls(self, recorder, expected=None): assert request == expected -class TestUserProvider(object): +class TestUserProvider: _INVALID_PROVIDERS = ( [{'display_name': arg} for arg in INVALID_STRINGS[1:]] + @@ -819,7 +818,7 @@ def test_invalid_arg(self, arg): auth.UserProvider(uid='test', provider_id='google.com', **arg) -class TestUserMetadata(object): +class TestUserMetadata: _INVALID_ARGS = ( [{'creation_timestamp': arg} for arg in INVALID_TIMESTAMPS] + @@ -832,7 +831,7 @@ def test_invalid_args(self, arg): auth.UserMetadata(**arg) -class TestImportUserRecord(object): +class TestImportUserRecord: _INVALID_USERS = ( [{'display_name': arg} for arg in INVALID_STRINGS[1:]] + @@ -908,7 +907,7 @@ def test_disabled(self, disabled): assert user.to_dict() == {'localId': 'test', 'disabled': disabled} -class TestUserImportHash(object): +class TestUserImportHash: @pytest.mark.parametrize('func,name', [ (auth.UserImportHash.hmac_sha512, 'HMAC_SHA512'), @@ -1021,7 +1020,7 @@ def test_invalid_standard_scrypt(self, arg): auth.UserImportHash.standard_scrypt(**params) -class TestImportUsers(object): +class TestImportUsers: @pytest.mark.parametrize('arg', [None, list(), tuple(), dict(), 0, 1, 'foo']) def test_invalid_users(self, user_mgt_app, arg): @@ -1041,7 +1040,7 @@ def test_import_users(self, user_mgt_app): ] result = auth.import_users(users, app=user_mgt_app) assert result.success_count == 2 - assert result.failure_count is 0 + assert result.failure_count == 0 assert result.errors == [] expected = {'users': [{'localId': 'user1'}, {'localId': 'user2'}]} self._check_rpc_calls(recorder, expected) @@ -1087,7 +1086,7 @@ def test_import_users_with_hash(self, user_mgt_app): b'key', rounds=8, memory_cost=14, salt_separator=b'sep') result = auth.import_users(users, hash_alg=hash_alg, app=user_mgt_app) assert result.success_count == 2 - assert result.failure_count is 0 + assert result.failure_count == 0 assert result.errors == [] expected = { 'users': [ @@ -1127,7 +1126,7 @@ def _check_rpc_calls(self, recorder, expected): assert request == expected -class TestRevokeRefreshTokkens(object): +class TestRevokeRefreshTokkens: def test_revoke_refresh_tokens(self, user_mgt_app): _, recorder = _instrument_user_manager(user_mgt_app, 200, '{"localId":"testuser"}') @@ -1141,7 +1140,7 @@ def test_revoke_refresh_tokens(self, user_mgt_app): assert int(request['validSince']) <= int(after_time) -class TestActionCodeSetting(object): +class TestActionCodeSetting: def test_valid_data(self): data = { @@ -1186,7 +1185,7 @@ def test_encode_action_code_bad_data(self): _user_mgt.encode_action_code_settings({"foo":"bar"}) -class TestGenerateEmailActionLink(object): +class TestGenerateEmailActionLink: def test_email_verification_no_settings(self, user_mgt_app): _, recorder = _instrument_user_manager(user_mgt_app, 200, '{"oobLink":"https://testlink"}') diff --git a/tests/testutils.py b/tests/testutils.py index cdbf75aef..9c69663a0 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -88,7 +88,7 @@ def __init__(self, status, response): self.response = MockResponse(status, response) self.log = [] - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs): # pylint: disable=arguments-differ self.log.append((args, kwargs)) return self.response @@ -100,7 +100,7 @@ def __init__(self, error): self.error = error self.log = [] - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs): # pylint: disable=arguments-differ self.log.append((args, kwargs)) raise self.error @@ -139,7 +139,7 @@ def __init__(self, responses, statuses, recorder): self._statuses = list(statuses) self._recorder = recorder - def send(self, request, **kwargs): + def send(self, request, **kwargs): # pylint: disable=arguments-differ request._extra_kwargs = kwargs self._recorder.append(request) resp = models.Response() From d3dda24dac3e34ab89daf82d8fb8b03cdb0c1a76 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Tue, 14 Jan 2020 21:44:58 -0800 Subject: [PATCH 043/226] Create a GitHub Actions based CI Pipeline (#386) * Create an GitHub Actions based CI Pipeline * Temporarily removing Py3.4 * Added emulator-based integration tests --- .github/workflows/ci.yml | 36 ++++++++++++++++++++++++++++++++++++ .travis.yml | 10 ---------- 2 files changed, 36 insertions(+), 10 deletions(-) create mode 100644 .github/workflows/ci.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 000000000..ca223bd5b --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,36 @@ +name: Continuous Integration + +on: [push, pull_request] + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + matrix: + python: [3.5, 3.6, 3.7, pypy3] + + steps: + - uses: actions/checkout@v1 + - name: Set up Python ${{ matrix.python }} + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + - name: Lint with pylint + if: matrix.python == '3.7' + run: ./lint.sh all + - name: Test with pytest + if: success() || failure() + run: pytest + - name: Set up Node.js 10 + uses: actions/setup-node@v1 + with: + node-version: 10.x + - name: Run integration tests against emulator + run: | + npm install -g firebase-tools + firebase emulators:exec --only database --project fake-project-id 'pytest integration/test_db.py' diff --git a/.travis.yml b/.travis.yml index 8d6b9246a..8cec7e1d9 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,16 +1,6 @@ language: python python: - "3.4" - - "3.5" - - "3.6" - - "3.7" - - "pypy3.5" - -jobs: - include: - - name: "Lint" - python: "3.7" - script: ./lint.sh all before_install: - nvm install 8 && npm install -g firebase-tools From 0e4f3bfbcc0217dad02b049e11ce0d65f6c9eeec Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Wed, 15 Jan 2020 10:21:00 -0800 Subject: [PATCH 044/226] Speeding up the HTTPClient tests by reusing the test server instance (#387) --- tests/test_http_client.py | 44 ++++++++++++++++++++++++--------------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/tests/test_http_client.py b/tests/test_http_client.py index ce35e5ce4..d4d2885f3 100644 --- a/tests/test_http_client.py +++ b/tests/test_http_client.py @@ -14,16 +14,13 @@ """Tests for firebase_admin._http_client.""" import pytest -from pytest_localserver import plugin +from pytest_localserver import http import requests from firebase_admin import _http_client from tests import testutils -# Fixture for mocking a HTTP server -httpserver = plugin.httpserver - _TEST_URL = 'http://firebase.test.url/' @@ -92,40 +89,53 @@ class TestHttpRetry: @classmethod def setup_class(cls): - # Turn off exponential backoff for faster execution + # Turn off exponential backoff for faster execution. _http_client.DEFAULT_RETRY_CONFIG.backoff_factor = 0 + # Start a test server instance scoped to the class. + server = http.ContentServer() + server.start() + cls.httpserver = server + + @classmethod + def teardown_class(cls): + cls.httpserver.stop() + + def setup_method(self): + # Clean up any state in the server before starting a new test case. + self.httpserver.requests = [] + @pytest.mark.parametrize('method', ALL_METHODS) - def test_retry_on_503(self, httpserver, method): - httpserver.serve_content({}, 503) + def test_retry_on_503(self, method): + self.httpserver.serve_content({}, 503) client = _http_client.JsonHttpClient( - credential=testutils.MockGoogleCredential(), base_url=httpserver.url) + credential=testutils.MockGoogleCredential(), base_url=self.httpserver.url) body = None if method in self.ENTITY_ENCLOSING_METHODS: body = {'key': 'value'} with pytest.raises(requests.exceptions.HTTPError) as excinfo: client.request(method, '/', json=body) assert excinfo.value.response.status_code == 503 - assert len(httpserver.requests) == 5 + assert len(self.httpserver.requests) == 5 @pytest.mark.parametrize('method', ALL_METHODS) - def test_retry_on_500(self, httpserver, method): - httpserver.serve_content({}, 500) + def test_retry_on_500(self, method): + self.httpserver.serve_content({}, 500) client = _http_client.JsonHttpClient( - credential=testutils.MockGoogleCredential(), base_url=httpserver.url) + credential=testutils.MockGoogleCredential(), base_url=self.httpserver.url) body = None if method in self.ENTITY_ENCLOSING_METHODS: body = {'key': 'value'} with pytest.raises(requests.exceptions.HTTPError) as excinfo: client.request(method, '/', json=body) assert excinfo.value.response.status_code == 500 - assert len(httpserver.requests) == 5 + assert len(self.httpserver.requests) == 5 - def test_no_retry_on_404(self, httpserver): - httpserver.serve_content({}, 404) + def test_no_retry_on_404(self): + self.httpserver.serve_content({}, 404) client = _http_client.JsonHttpClient( - credential=testutils.MockGoogleCredential(), base_url=httpserver.url) + credential=testutils.MockGoogleCredential(), base_url=self.httpserver.url) with pytest.raises(requests.exceptions.HTTPError) as excinfo: client.request('get', '/') assert excinfo.value.response.status_code == 404 - assert len(httpserver.requests) == 1 + assert len(self.httpserver.requests) == 1 From 7078e9661b81ee7b44c48c566d6e1209a2960e97 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Wed, 15 Jan 2020 14:20:45 -0800 Subject: [PATCH 045/226] chore: Dropped the dependency on six (#385) * Removing Python 2 support * Upgraded to Pylint 2.x and fixed all linter errors for Python 3 * Removed the dependency on the six library --- firebase_admin/__init__.py | 10 ++++------ firebase_admin/_auth_utils.py | 22 ++++++++++------------ firebase_admin/_messaging_encoder.py | 14 ++++++-------- firebase_admin/_token_gen.py | 13 ++++++------- firebase_admin/_user_mgt.py | 16 ++++++++-------- firebase_admin/_utils.py | 4 ++-- firebase_admin/credentials.py | 5 ++--- firebase_admin/db.py | 21 ++++++++++----------- firebase_admin/instance_id.py | 3 +-- firebase_admin/messaging.py | 9 ++++----- firebase_admin/project_management.py | 5 ++--- firebase_admin/storage.py | 4 +--- integration/test_auth.py | 16 ++++++++-------- integration/test_db.py | 7 +++---- requirements.txt | 1 - setup.py | 1 - tests/test_exceptions.py | 4 ++-- tests/test_messaging.py | 18 ++++++++---------- tests/test_sseclient.py | 4 ++-- tests/test_token_gen.py | 3 +-- tests/test_user_mgt.py | 4 ++-- tests/testutils.py | 4 ++-- 22 files changed, 84 insertions(+), 104 deletions(-) diff --git a/firebase_admin/__init__.py b/firebase_admin/__init__.py index eae68bd06..400396266 100644 --- a/firebase_admin/__init__.py +++ b/firebase_admin/__init__.py @@ -18,8 +18,6 @@ import os import threading -import six - from firebase_admin import credentials from firebase_admin.__about__ import __version__ @@ -126,7 +124,7 @@ def get_app(name=_DEFAULT_APP_NAME): ValueError: If the specified name is not a string, or if the specified app does not exist. """ - if not isinstance(name, six.string_types): + if not isinstance(name, str): raise ValueError('Illegal app name argument type: "{}". App name ' 'must be a string.'.format(type(name))) with _apps_lock: @@ -203,7 +201,7 @@ def __init__(self, name, credential, options): Raises: ValueError: If an argument is None or invalid. """ - if not name or not isinstance(name, six.string_types): + if not name or not isinstance(name, str): raise ValueError('Illegal Firebase app name "{0}" provided. App name must be a ' 'non-empty string.'.format(name)) self._name = name @@ -221,7 +219,7 @@ def __init__(self, name, credential, options): @classmethod def _validate_project_id(cls, project_id): - if project_id is not None and not isinstance(project_id, six.string_types): + if project_id is not None and not isinstance(project_id, str): raise ValueError( 'Invalid project ID: "{0}". project ID must be a string.'.format(project_id)) @@ -286,7 +284,7 @@ def _get_service(self, name, initializer): Raises: ValueError: If the provided name is invalid, or if the App is already deleted. """ - if not name or not isinstance(name, six.string_types): + if not name or not isinstance(name, str): raise ValueError( 'Illegal name argument: "{0}". Name must be a non-empty string.'.format(name)) with self._lock: diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index b54e7d480..2f7383c0b 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -16,9 +16,7 @@ import json import re - -import six -from six.moves import urllib +from urllib import parse from firebase_admin import exceptions from firebase_admin import _utils @@ -35,7 +33,7 @@ def validate_uid(uid, required=False): if uid is None and not required: return None - if not isinstance(uid, six.string_types) or not uid or len(uid) > 128: + if not isinstance(uid, str) or not uid or len(uid) > 128: raise ValueError( 'Invalid uid: "{0}". The uid must be a non-empty string with no more than 128 ' 'characters.'.format(uid)) @@ -44,7 +42,7 @@ def validate_uid(uid, required=False): def validate_email(email, required=False): if email is None and not required: return None - if not isinstance(email, six.string_types) or not email: + if not isinstance(email, str) or not email: raise ValueError( 'Invalid email: "{0}". Email must be a non-empty string.'.format(email)) parts = email.split('@') @@ -61,7 +59,7 @@ def validate_phone(phone, required=False): """ if phone is None and not required: return None - if not isinstance(phone, six.string_types) or not phone: + if not isinstance(phone, str) or not phone: raise ValueError('Invalid phone number: "{0}". Phone number must be a non-empty ' 'string.'.format(phone)) if not phone.startswith('+') or not re.search('[a-zA-Z0-9]', phone): @@ -72,7 +70,7 @@ def validate_phone(phone, required=False): def validate_password(password, required=False): if password is None and not required: return None - if not isinstance(password, six.string_types) or len(password) < 6: + if not isinstance(password, str) or len(password) < 6: raise ValueError( 'Invalid password string. Password must be a string at least 6 characters long.') return password @@ -80,14 +78,14 @@ def validate_password(password, required=False): def validate_bytes(value, label, required=False): if value is None and not required: return None - if not isinstance(value, six.binary_type) or not value: + if not isinstance(value, bytes) or not value: raise ValueError('{0} must be a non-empty byte sequence.'.format(label)) return value def validate_display_name(display_name, required=False): if display_name is None and not required: return None - if not isinstance(display_name, six.string_types) or not display_name: + if not isinstance(display_name, str) or not display_name: raise ValueError( 'Invalid display name: "{0}". Display name must be a non-empty ' 'string.'.format(display_name)) @@ -96,7 +94,7 @@ def validate_display_name(display_name, required=False): def validate_provider_id(provider_id, required=True): if provider_id is None and not required: return None - if not isinstance(provider_id, six.string_types) or not provider_id: + if not isinstance(provider_id, str) or not provider_id: raise ValueError( 'Invalid provider ID: "{0}". Provider ID must be a non-empty ' 'string.'.format(provider_id)) @@ -106,12 +104,12 @@ def validate_photo_url(photo_url, required=False): """Parses and validates the given URL string.""" if photo_url is None and not required: return None - if not isinstance(photo_url, six.string_types) or not photo_url: + if not isinstance(photo_url, str) or not photo_url: raise ValueError( 'Invalid photo URL: "{0}". Photo URL must be a non-empty ' 'string.'.format(photo_url)) try: - parsed = urllib.parse.urlparse(photo_url) + parsed = parse.urlparse(photo_url) if not parsed.netloc: raise ValueError('Malformed photo URL: "{0}".'.format(photo_url)) return photo_url diff --git a/firebase_admin/_messaging_encoder.py b/firebase_admin/_messaging_encoder.py index aefaf3e2f..c4da53f0d 100644 --- a/firebase_admin/_messaging_encoder.py +++ b/firebase_admin/_messaging_encoder.py @@ -20,8 +20,6 @@ import numbers import re -import six - import firebase_admin._messaging_utils as _messaging_utils @@ -99,7 +97,7 @@ def check_string(cls, label, value, non_empty=False): """Checks if the given value is a string.""" if value is None: return None - if not isinstance(value, six.string_types): + if not isinstance(value, str): if non_empty: raise ValueError('{0} must be a non-empty string.'.format(label)) raise ValueError('{0} must be a string.'.format(label)) @@ -122,10 +120,10 @@ def check_string_dict(cls, label, value): return None if not isinstance(value, dict): raise ValueError('{0} must be a dictionary.'.format(label)) - non_str = [k for k in value if not isinstance(k, six.string_types)] + non_str = [k for k in value if not isinstance(k, str)] if non_str: raise ValueError('{0} must not contain non-string keys.'.format(label)) - non_str = [v for v in value.values() if not isinstance(v, six.string_types)] + non_str = [v for v in value.values() if not isinstance(v, str)] if non_str: raise ValueError('{0} must not contain non-string values.'.format(label)) return value @@ -137,7 +135,7 @@ def check_string_list(cls, label, value): return None if not isinstance(value, list): raise ValueError('{0} must be a list of strings.'.format(label)) - non_str = [k for k in value if not isinstance(k, six.string_types)] + non_str = [k for k in value if not isinstance(k, str)] if non_str: raise ValueError('{0} must not contain non-string values.'.format(label)) return value @@ -570,7 +568,7 @@ def encode_aps_sound(cls, sound): """Encodes an APNs sound configuration into JSON.""" if sound is None: return None - if sound and isinstance(sound, six.string_types): + if sound and isinstance(sound, str): return sound if not isinstance(sound, _messaging_utils.CriticalSound): raise ValueError( @@ -593,7 +591,7 @@ def encode_aps_alert(cls, alert): """Encodes an ``ApsAlert`` instance into JSON.""" if alert is None: return None - if isinstance(alert, six.string_types): + if isinstance(alert, str): return alert if not isinstance(alert, _messaging_utils.ApsAlert): raise ValueError('Aps.alert must be a string or an instance of ApsAlert class.') diff --git a/firebase_admin/_token_gen.py b/firebase_admin/_token_gen.py index 471630cca..4234bfa7b 100644 --- a/firebase_admin/_token_gen.py +++ b/firebase_admin/_token_gen.py @@ -19,7 +19,6 @@ import cachecontrol import requests -import six from google.auth import credentials from google.auth import iam from google.auth import jwt @@ -149,7 +148,7 @@ def create_custom_token(self, uid, developer_claims=None): ', '.join(disallowed_keys))) raise ValueError(error_message) - if not uid or not isinstance(uid, six.string_types) or len(uid) > 128: + if not uid or not isinstance(uid, str) or len(uid) > 128: raise ValueError('uid must be a string between 1 and 128 characters.') signing_provider = self.signing_provider @@ -174,8 +173,8 @@ def create_custom_token(self, uid, developer_claims=None): def create_session_cookie(self, id_token, expires_in): """Creates a session cookie from the provided ID token.""" - id_token = id_token.decode('utf-8') if isinstance(id_token, six.binary_type) else id_token - if not isinstance(id_token, six.text_type) or not id_token: + id_token = id_token.decode('utf-8') if isinstance(id_token, bytes) else id_token + if not isinstance(id_token, str) or not id_token: raise ValueError( 'Illegal ID token provided: {0}. ID token must be a non-empty ' 'string.'.format(id_token)) @@ -256,8 +255,8 @@ def __init__(self, **kwargs): def verify(self, token, request): """Verifies the signature and data for the provided JWT.""" - token = token.encode('utf-8') if isinstance(token, six.text_type) else token - if not isinstance(token, six.binary_type) or not token: + token = token.encode('utf-8') if isinstance(token, str) else token + if not isinstance(token, bytes) or not token: raise ValueError( 'Illegal {0} provided: {1}. {0} must be a non-empty ' 'string.'.format(self.short_name, token)) @@ -308,7 +307,7 @@ def verify(self, token, request): 'Firebase {0} has incorrect "iss" (issuer) claim. Expected "{1}" but ' 'got "{2}". {3} {4}'.format(self.short_name, expected_issuer, issuer, project_id_match_msg, verify_id_token_msg)) - elif subject is None or not isinstance(subject, six.string_types): + elif subject is None or not isinstance(subject, str): error_message = ( 'Firebase {0} has no "sub" (subject) claim. ' '{1}'.format(self.short_name, verify_id_token_msg)) diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index 5b33abb39..533259e70 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -16,9 +16,9 @@ import base64 import json +from urllib import parse + import requests -import six -from six.moves import urllib from firebase_admin import _auth_utils from firebase_admin import _user_import @@ -397,7 +397,7 @@ def encode_action_code_settings(settings): raise ValueError("Dynamic action links url is mandatory") try: - parsed = urllib.parse.urlparse(settings.url) + parsed = parse.urlparse(settings.url) if not parsed.netloc: raise ValueError('Malformed dynamic action links url: "{0}".'.format(settings.url)) parameters['continueUrl'] = settings.url @@ -413,14 +413,14 @@ def encode_action_code_settings(settings): # dynamic_link_domain if settings.dynamic_link_domain is not None: - if not isinstance(settings.dynamic_link_domain, six.string_types): + if not isinstance(settings.dynamic_link_domain, str): raise ValueError('Invalid value provided for dynamic_link_domain: {0}' .format(settings.dynamic_link_domain)) parameters['dynamicLinkDomain'] = settings.dynamic_link_domain # ios_bundle_id if settings.ios_bundle_id is not None: - if not isinstance(settings.ios_bundle_id, six.string_types): + if not isinstance(settings.ios_bundle_id, str): raise ValueError('Invalid value provided for ios_bundle_id: {0}' .format(settings.ios_bundle_id)) parameters['iosBundleId'] = settings.ios_bundle_id @@ -431,13 +431,13 @@ def encode_action_code_settings(settings): raise ValueError("Android package name is required when specifying other Android settings") if settings.android_package_name is not None: - if not isinstance(settings.android_package_name, six.string_types): + if not isinstance(settings.android_package_name, str): raise ValueError('Invalid value provided for android_package_name: {0}' .format(settings.android_package_name)) parameters['androidPackageName'] = settings.android_package_name if settings.android_minimum_version is not None: - if not isinstance(settings.android_minimum_version, six.string_types): + if not isinstance(settings.android_minimum_version, str): raise ValueError('Invalid value provided for android_minimum_version: {0}' .format(settings.android_minimum_version)) parameters['androidMinimumVersion'] = settings.android_minimum_version @@ -486,7 +486,7 @@ def get_user(self, **kwargs): def list_users(self, page_token=None, max_results=MAX_LIST_USERS_RESULTS): """Retrieves a batch of users.""" if page_token is not None: - if not isinstance(page_token, six.string_types) or not page_token: + if not isinstance(page_token, str) or not page_token: raise ValueError('Page token must be a non-empty string.') if not isinstance(max_results, int): raise ValueError('Max results must be an integer.') diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index 7ec1b8fb8..2c4cec868 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -14,13 +14,13 @@ """Internal utilities common to all modules.""" +import io import json import socket import googleapiclient import httplib2 import requests -import six import firebase_admin from firebase_admin import exceptions @@ -255,7 +255,7 @@ def handle_googleapiclient_error(error, message=None, code=None, http_response=N def _http_response_from_googleapiclient_error(error): """Creates a requests HTTP Response object from the given googleapiclient error.""" resp = requests.models.Response() - resp.raw = six.BytesIO(error.content) + resp.raw = io.BytesIO(error.content) resp.status_code = error.resp.status return resp diff --git a/firebase_admin/credentials.py b/firebase_admin/credentials.py index e930675bd..8f9c504f0 100644 --- a/firebase_admin/credentials.py +++ b/firebase_admin/credentials.py @@ -15,7 +15,6 @@ """Firebase credentials module.""" import collections import json -import six import google.auth from google.auth.transport import requests @@ -79,7 +78,7 @@ def __init__(self, cert): ValueError: If the specified certificate is invalid. """ super(Certificate, self).__init__() - if isinstance(cert, six.string_types): + if isinstance(cert, str): with open(cert) as json_file: json_data = json.load(json_file) elif isinstance(cert, dict): @@ -180,7 +179,7 @@ def __init__(self, refresh_token): ValueError: If the refresh token configuration is invalid. """ super(RefreshToken, self).__init__() - if isinstance(refresh_token, six.string_types): + if isinstance(refresh_token, str): with open(refresh_token) as json_file: json_data = json.load(json_file) elif isinstance(refresh_token, dict): diff --git a/firebase_admin/db.py b/firebase_admin/db.py index 2fb8b3a74..9092a955c 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -25,11 +25,10 @@ import os import sys import threading +from urllib import parse import google.auth import requests -import six -from six.moves import urllib import firebase_admin from firebase_admin import exceptions @@ -73,7 +72,7 @@ def reference(path='/', app=None, url=None): def _parse_path(path): """Parses a path string into a set of segments.""" - if not isinstance(path, six.string_types): + if not isinstance(path, str): raise ValueError('Invalid path: "{0}". Path must be a string.'.format(path)) if any(ch in path for ch in _INVALID_PATH_CHARACTERS): raise ValueError( @@ -185,7 +184,7 @@ def child(self, path): Raises: ValueError: If the child path is not a string, not well-formed or begins with '/'. """ - if not path or not isinstance(path, six.string_types): + if not path or not isinstance(path, str): raise ValueError( 'Invalid path argument: "{0}". Path must be a non-empty string.'.format(path)) if path.startswith('/'): @@ -239,7 +238,7 @@ def get_if_changed(self, etag): ValueError: If the ETag is not a string. FirebaseError: If an error occurs while communicating with the remote database server. """ - if not isinstance(etag, six.string_types): + if not isinstance(etag, str): raise ValueError('ETag must be a string.') resp = self._client.request('get', self._add_suffix(), headers={'if-none-match': etag}) @@ -285,7 +284,7 @@ def set_if_unchanged(self, expected_etag, value): FirebaseError: If an error occurs while communicating with the remote database server. """ # pylint: disable=missing-raises-doc - if not isinstance(expected_etag, six.string_types): + if not isinstance(expected_etag, str): raise ValueError('Expected ETag must be a string.') if value is None: raise ValueError('Value must not be none.') @@ -488,7 +487,7 @@ class Query: def __init__(self, **kwargs): order_by = kwargs.pop('order_by') - if not order_by or not isinstance(order_by, six.string_types): + if not order_by or not isinstance(order_by, str): raise ValueError('order_by field must be a non-empty string') if order_by not in _RESERVED_FILTERS: if order_by.startswith('/'): @@ -704,7 +703,7 @@ def _get_index_type(cls, index): return cls._type_bool_true if isinstance(index, (int, float)): return cls._type_numeric - if isinstance(index, six.string_types): + if isinstance(index, str): return cls._type_string return cls._type_object @@ -825,11 +824,11 @@ def _parse_db_url(cls, url, emulator_host=None): base URL will use emulator_host instead. emulator_host is ignored if url is already an emulator URL. """ - if not url or not isinstance(url, six.string_types): + if not url or not isinstance(url, str): raise ValueError( 'Invalid database URL: "{0}". Database URL must be a non-empty ' 'URL string.'.format(url)) - parsed_url = urllib.parse.urlparse(url) + parsed_url = parse.urlparse(url) if parsed_url.netloc.endswith('.firebaseio.com'): return cls._parse_production_url(parsed_url, emulator_host) @@ -857,7 +856,7 @@ def _parse_production_url(cls, parsed_url, emulator_host): @classmethod def _parse_emulator_url(cls, parsed_url): """Parses emulator URL like http://localhost:8080/?ns=foo-bar""" - query_ns = urllib.parse.parse_qs(parsed_url.query).get('ns') + query_ns = parse.parse_qs(parsed_url.query).get('ns') if parsed_url.scheme != 'http' or (not query_ns or len(query_ns) != 1 or not query_ns[0]): raise ValueError( 'Invalid database URL: "{0}". Database URL must be a valid URL to a ' diff --git a/firebase_admin/instance_id.py b/firebase_admin/instance_id.py index f90d058cc..604158d9c 100644 --- a/firebase_admin/instance_id.py +++ b/firebase_admin/instance_id.py @@ -18,7 +18,6 @@ """ import requests -import six from firebase_admin import _http_client from firebase_admin import _utils @@ -80,7 +79,7 @@ def __init__(self, app): credential=app.credential.get_credential(), base_url=_IID_SERVICE_URL) def delete_instance_id(self, instance_id): - if not isinstance(instance_id, six.string_types) or not instance_id: + if not isinstance(instance_id, str) or not instance_id: raise ValueError('Instance ID must be a non-empty string.') path = 'project/{0}/instanceId/{1}'.format(self._project_id, instance_id) try: diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index 71366e5c4..9262751a1 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -15,12 +15,11 @@ """Firebase Cloud Messaging module.""" import json -import requests -import six import googleapiclient from googleapiclient import http from googleapiclient import _auth +import requests import firebase_admin from firebase_admin import _http_client @@ -395,15 +394,15 @@ def batch_callback(_, response, error): def make_topic_management_request(self, tokens, topic, operation): """Invokes the IID service for topic management functionality.""" - if isinstance(tokens, six.string_types): + if isinstance(tokens, str): tokens = [tokens] if not isinstance(tokens, list) or not tokens: raise ValueError('Tokens must be a string or a non-empty list of strings.') - invalid_str = [t for t in tokens if not isinstance(t, six.string_types) or not t] + invalid_str = [t for t in tokens if not isinstance(t, str) or not t] if invalid_str: raise ValueError('Tokens must be non-empty strings.') - if not isinstance(topic, six.string_types) or not topic: + if not isinstance(topic, str) or not topic: raise ValueError('Topic must be a non-empty string.') if not topic.startswith('/topics/'): topic = '/topics/{0}'.format(topic) diff --git a/firebase_admin/project_management.py b/firebase_admin/project_management.py index 076542bda..91aa1eebb 100644 --- a/firebase_admin/project_management.py +++ b/firebase_admin/project_management.py @@ -22,7 +22,6 @@ import time import requests -import six import firebase_admin from firebase_admin import exceptions @@ -117,13 +116,13 @@ def create_ios_app(bundle_id, display_name=None, app=None): def _check_is_string_or_none(obj, field_name): - if obj is None or isinstance(obj, six.string_types): + if obj is None or isinstance(obj, str): return obj raise ValueError('{0} must be a string.'.format(field_name)) def _check_is_nonempty_string(obj, field_name): - if isinstance(obj, six.string_types) and obj: + if isinstance(obj, str) and obj: return obj raise ValueError('{0} must be a non-empty string.'.format(field_name)) diff --git a/firebase_admin/storage.py b/firebase_admin/storage.py index a080b31ef..16f48e273 100644 --- a/firebase_admin/storage.py +++ b/firebase_admin/storage.py @@ -25,8 +25,6 @@ raise ImportError('Failed to import the Cloud Storage library for Python. Make sure ' 'to install the "google-cloud-storage" module.') -import six - from firebase_admin import _utils @@ -77,7 +75,7 @@ def bucket(self, name=None): 'Storage bucket name not specified. Specify the bucket name via the ' '"storageBucket" option when initializing the App, or specify the bucket ' 'name explicitly when calling the storage.bucket() function.') - if not bucket_name or not isinstance(bucket_name, six.string_types): + if not bucket_name or not isinstance(bucket_name, str): raise ValueError( 'Invalid storage bucket name: "{0}". Bucket name must be a non-empty ' 'string.'.format(bucket_name)) diff --git a/integration/test_auth.py b/integration/test_auth.py index c3759ce12..5d26dd9f1 100644 --- a/integration/test_auth.py +++ b/integration/test_auth.py @@ -17,13 +17,13 @@ import datetime import random import time +from urllib import parse import uuid import google.oauth2.credentials from google.auth import transport import pytest import requests -import six import firebase_admin from firebase_admin import auth @@ -82,8 +82,8 @@ def _sign_in_with_email_link(email, oob_code, api_key): return resp.json().get('idToken') def _extract_link_params(link): - query = six.moves.urllib.parse.urlparse(link).query - query_dict = dict(six.moves.urllib.parse.parse_qsl(query)) + query = parse.urlparse(link).query + query_dict = dict(parse.parse_qsl(query)) return query_dict def test_custom_token(api_key): @@ -427,7 +427,7 @@ def test_import_users_with_password(api_key): def test_password_reset(new_user_email_unverified, api_key): link = auth.generate_password_reset_link(new_user_email_unverified.email) - assert isinstance(link, six.string_types) + assert isinstance(link, str) query_dict = _extract_link_params(link) user_email = _reset_password(query_dict['oobCode'], 'newPassword', api_key) assert new_user_email_unverified.email == user_email @@ -436,7 +436,7 @@ def test_password_reset(new_user_email_unverified, api_key): def test_email_verification(new_user_email_unverified, api_key): link = auth.generate_email_verification_link(new_user_email_unverified.email) - assert isinstance(link, six.string_types) + assert isinstance(link, str) query_dict = _extract_link_params(link) user_email = _verify_email(query_dict['oobCode'], api_key) assert new_user_email_unverified.email == user_email @@ -446,7 +446,7 @@ def test_password_reset_with_settings(new_user_email_unverified, api_key): action_code_settings = auth.ActionCodeSettings(ACTION_LINK_CONTINUE_URL) link = auth.generate_password_reset_link(new_user_email_unverified.email, action_code_settings=action_code_settings) - assert isinstance(link, six.string_types) + assert isinstance(link, str) query_dict = _extract_link_params(link) assert query_dict['continueUrl'] == ACTION_LINK_CONTINUE_URL user_email = _reset_password(query_dict['oobCode'], 'newPassword', api_key) @@ -458,7 +458,7 @@ def test_email_verification_with_settings(new_user_email_unverified, api_key): action_code_settings = auth.ActionCodeSettings(ACTION_LINK_CONTINUE_URL) link = auth.generate_email_verification_link(new_user_email_unverified.email, action_code_settings=action_code_settings) - assert isinstance(link, six.string_types) + assert isinstance(link, str) query_dict = _extract_link_params(link) assert query_dict['continueUrl'] == ACTION_LINK_CONTINUE_URL user_email = _verify_email(query_dict['oobCode'], api_key) @@ -469,7 +469,7 @@ def test_email_sign_in_with_settings(new_user_email_unverified, api_key): action_code_settings = auth.ActionCodeSettings(ACTION_LINK_CONTINUE_URL) link = auth.generate_sign_in_with_email_link(new_user_email_unverified.email, action_code_settings=action_code_settings) - assert isinstance(link, six.string_types) + assert isinstance(link, str) query_dict = _extract_link_params(link) assert query_dict['continueUrl'] == ACTION_LINK_CONTINUE_URL oob_code = query_dict['oobCode'] diff --git a/integration/test_db.py b/integration/test_db.py index abd02660f..7a73ea3ad 100644 --- a/integration/test_db.py +++ b/integration/test_db.py @@ -18,7 +18,6 @@ import os import pytest -import six import firebase_admin from firebase_admin import db @@ -113,7 +112,7 @@ def test_get_value_and_etag(self, testref, testdata): value, etag = testref.get(etag=True) assert isinstance(value, dict) assert testdata == value - assert isinstance(etag, six.string_types) + assert isinstance(etag, str) def test_get_shallow(self, testref): value = testref.get(shallow=True) @@ -124,7 +123,7 @@ def test_get_if_changed(self, testref, testdata): success, data, etag = testref.get_if_changed('wrong_etag') assert success is True assert data == testdata - assert isinstance(etag, six.string_types) + assert isinstance(etag, str) assert testref.get_if_changed(etag) == (False, None, None) def test_get_child_value(self, testref, testdata): @@ -211,7 +210,7 @@ def test_set_if_unchanged(self, testref): success, data, etag = edward.set_if_unchanged('invalid-etag', update_data) assert success is False assert data == push_data - assert isinstance(etag, six.string_types) + assert isinstance(etag, str) success, data, new_etag = edward.set_if_unchanged(etag, update_data) assert success is True diff --git a/requirements.txt b/requirements.txt index 6d28b38ac..d7fb6d736 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,3 @@ google-api-core[grpc] >= 1.14.0, < 2.0.0dev; platform.python_implementation != ' google-api-python-client >= 1.7.8 google-cloud-firestore >= 1.4.0; platform.python_implementation != 'PyPy' google-cloud-storage >= 1.18.0 -six >= 1.6.1 diff --git a/setup.py b/setup.py index b492ec922..43da5eb85 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,6 @@ 'google-api-python-client >= 1.7.8', 'google-cloud-firestore>=1.4.0; platform.python_implementation != "PyPy"', 'google-cloud-storage>=1.18.0', - 'six>=1.6.1' ] setup( diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 3df7ec2e3..96072d91b 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import io import json import socket @@ -19,7 +20,6 @@ import pytest import requests from requests import models -import six from googleapiclient import errors from firebase_admin import exceptions @@ -174,7 +174,7 @@ def _create_response(self, status=500, payload=None): resp = models.Response() resp.status_code = status if payload: - resp.raw = six.BytesIO(payload.encode()) + resp.raw = io.BytesIO(payload.encode()) exc = requests.exceptions.RequestException('Test error', response=resp) return resp, exc diff --git a/tests/test_messaging.py b/tests/test_messaging.py index 36f5943be..33c99445b 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -17,10 +17,8 @@ import json import numbers -import pytest -import six - from googleapiclient.http import HttpMockSequence +import pytest import firebase_admin from firebase_admin import exceptions @@ -288,7 +286,7 @@ def test_invalid_priority(self, data): with pytest.raises(ValueError) as excinfo: check_encoding(messaging.Message( topic='topic', android=messaging.AndroidConfig(priority=data))) - if isinstance(data, six.string_types): + if isinstance(data, str): assert str(excinfo.value) == 'AndroidConfig.priority must be "high" or "normal".' else: assert str(excinfo.value) == 'AndroidConfig.priority must be a non-empty string.' @@ -405,7 +403,7 @@ def test_invalid_icon(self, data): def test_invalid_color(self, data): notification = messaging.AndroidNotification(color=data) excinfo = self._check_notification(notification) - if isinstance(data, six.string_types): + if isinstance(data, str): assert str(excinfo.value) == 'AndroidNotification.color must be in the form #RRGGBB.' else: assert str(excinfo.value) == 'AndroidNotification.color must be a non-empty string.' @@ -491,7 +489,7 @@ def test_invalid_event_timestamp(self, timestamp): def test_invalid_priority(self, priority): notification = messaging.AndroidNotification(priority=priority) excinfo = self._check_notification(notification) - if isinstance(priority, six.string_types): + if isinstance(priority, str): if not priority: expected = 'AndroidNotification.priority must be a non-empty string.' else: @@ -505,7 +503,7 @@ def test_invalid_priority(self, priority): def test_invalid_visibility(self, visibility): notification = messaging.AndroidNotification(visibility=visibility) excinfo = self._check_notification(notification) - if isinstance(visibility, six.string_types): + if isinstance(visibility, str): if not visibility: expected = 'AndroidNotification.visibility must be a non-empty string.' else: @@ -679,7 +677,7 @@ def test_invalid_color(self, data): notification = messaging.LightSettings(color=data, light_on_duration_millis=300, light_off_duration_millis=200) excinfo = self._check_light_settings(notification) - if isinstance(data, six.string_types): + if isinstance(data, str): assert str(excinfo.value) == ('LightSettings.color must be in the form #RRGGBB or ' '#RRGGBBAA.') else: @@ -853,7 +851,7 @@ def test_invalid_badge(self, data): def test_invalid_direction(self, data): notification = messaging.WebpushNotification(direction=data) excinfo = self._check_notification(notification) - if isinstance(data, six.string_types): + if isinstance(data, str): assert str(excinfo.value) == ('WebpushNotification.direction must be "auto", ' '"ltr" or "rtl".') else: @@ -2195,7 +2193,7 @@ def _get_url(self, path): @pytest.mark.parametrize('tokens', [None, '', list(), dict(), tuple()]) def test_invalid_tokens(self, tokens): expected = 'Tokens must be a string or a non-empty list of strings.' - if isinstance(tokens, six.string_types): + if isinstance(tokens, str): expected = 'Tokens must be non-empty strings.' with pytest.raises(ValueError) as excinfo: diff --git a/tests/test_sseclient.py b/tests/test_sseclient.py index 881ecc6b9..70edcf0d0 100644 --- a/tests/test_sseclient.py +++ b/tests/test_sseclient.py @@ -13,10 +13,10 @@ # limitations under the License. """Tests for firebase_admin._sseclient.""" +import io import json import requests -import six from firebase_admin import _sseclient from tests import testutils @@ -31,7 +31,7 @@ def send(self, request, **kwargs): resp = super(MockSSEClientAdapter, self).send(request, **kwargs) resp.url = request.url resp.status_code = self.status - resp.raw = six.BytesIO(self.data.encode()) + resp.raw = io.BytesIO(self.data.encode()) resp.encoding = "utf-8" return resp diff --git a/tests/test_token_gen.py b/tests/test_token_gen.py index e92fd0059..439c1ba6e 100644 --- a/tests/test_token_gen.py +++ b/tests/test_token_gen.py @@ -26,7 +26,6 @@ import google.oauth2.id_token import pytest from pytest_localserver import plugin -import six import firebase_admin from firebase_admin import auth @@ -68,7 +67,7 @@ def _merge_jwt_claims(defaults, overrides): return defaults def _verify_custom_token(custom_token, expected_claims): - assert isinstance(custom_token, six.binary_type) + assert isinstance(custom_token, bytes) token = google.oauth2.id_token.verify_token( custom_token, testutils.MockRequest(200, MOCK_PUBLIC_CERTS), diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index f1572baf2..9b0b4ce11 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -17,9 +17,9 @@ import base64 import json import time +from urllib import parse import pytest -from six.moves import urllib import firebase_admin from firebase_admin import auth @@ -772,7 +772,7 @@ def _check_rpc_calls(self, recorder, expected=None): if expected is None: expected = {'maxResults' : '1000'} assert len(recorder) == 1 - request = dict(urllib.parse.parse_qsl(urllib.parse.urlsplit(recorder[0].url).query)) + request = dict(parse.parse_qsl(parse.urlsplit(recorder[0].url).query)) assert request == expected diff --git a/tests/testutils.py b/tests/testutils.py index 9c69663a0..d0663ead1 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -13,13 +13,13 @@ # limitations under the License. """Common utility classes and functions for testing.""" +import io import os from google.auth import credentials from google.auth import transport from requests import adapters from requests import models -import six import firebase_admin @@ -145,7 +145,7 @@ def send(self, request, **kwargs): # pylint: disable=arguments-differ resp = models.Response() resp.url = request.url resp.status_code = self._statuses[self._current_response] - resp.raw = six.BytesIO(self._responses[self._current_response].encode()) + resp.raw = io.BytesIO(self._responses[self._current_response].encode()) self._current_response = min(self._current_response + 1, len(self._responses) - 1) return resp From 00cf1d34fdd8ca3531ff3556688c2e785db17920 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Fri, 17 Jan 2020 14:44:04 -0800 Subject: [PATCH 046/226] chore: Removing Python 3.4 support (#389) --- .travis.yml | 14 -------------- README.md | 4 +++- setup.py | 7 +++---- 3 files changed, 6 insertions(+), 19 deletions(-) delete mode 100644 .travis.yml diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 8cec7e1d9..000000000 --- a/.travis.yml +++ /dev/null @@ -1,14 +0,0 @@ -language: python -python: - - "3.4" - -before_install: - - nvm install 8 && npm install -g firebase-tools -script: - - pytest - - firebase emulators:exec --only database --project fake-project-id 'pytest integration/test_db.py' -cache: - pip: true - npm: true - directories: - - $HOME/.cache/firebase/emulators diff --git a/README.md b/README.md index 8e9efd0ee..7f33af68b 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,6 @@ [![Build Status](https://travis-ci.org/firebase/firebase-admin-python.svg?branch=master)](https://travis-ci.org/firebase/firebase-admin-python) +[![Python](https://img.shields.io/pypi/pyversions/firebase-admin.svg)](https://pypi.org/project/firebase-admin/) +[![Version](https://img.shields.io/pypi/v/firebase-admin.svg)](https://pypi.org/project/firebase-admin/) # Firebase Admin Python SDK @@ -41,7 +43,7 @@ requests, code review feedback, and also pull requests. ## Supported Python Versions -We currently support Python 3.4+. Firebase Admin Python SDK is also tested on +We currently support Python 3.5+. Firebase Admin Python SDK is also tested on PyPy and [Google App Engine](https://cloud.google.com/appengine/) environments. diff --git a/setup.py b/setup.py index 43da5eb85..0ebcc3455 100644 --- a/setup.py +++ b/setup.py @@ -22,8 +22,8 @@ (major, minor) = (sys.version_info.major, sys.version_info.minor) -if major != 3 or minor < 4: - print('firebase_admin requires python >= 3.4', file=sys.stderr) +if major != 3 or minor < 5: + print('firebase_admin requires python >= 3.5', file=sys.stderr) sys.exit(1) # Read in the package metadata per recommendations from: @@ -55,13 +55,12 @@ keywords='firebase cloud development', install_requires=install_requires, packages=['firebase_admin'], - python_requires='>=3.4', + python_requires='>=3.5', classifiers=[ 'Development Status :: 5 - Production/Stable', 'Intended Audience :: Developers', 'Topic :: Software Development :: Build Tools', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', From b5f228f68d95a4c24efa39e47e30045047db1a6e Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Thu, 30 Jan 2020 11:10:40 -0800 Subject: [PATCH 047/226] fix: Setting a default timeout on all HTTP connections (#397) * Setting a default timeout on all HTTP connections * Refactored and cleaned up tests * Further cleaning up the tests * Cleaning up tests based on feedback --- firebase_admin/__init__.py | 4 +- firebase_admin/_http_client.py | 18 ++++++-- firebase_admin/db.py | 13 ++---- firebase_admin/messaging.py | 11 ++--- firebase_admin/project_management.py | 8 ++-- tests/test_db.py | 45 +++++++++--------- tests/test_http_client.py | 18 ++++++++ tests/test_instance_id.py | 7 +++ tests/test_messaging.py | 69 +++++++++++++++++----------- tests/test_project_management.py | 20 ++++++++ tests/test_user_mgt.py | 7 +++ 11 files changed, 149 insertions(+), 71 deletions(-) diff --git a/firebase_admin/__init__.py b/firebase_admin/__init__.py index 400396266..7e3b2eab0 100644 --- a/firebase_admin/__init__.py +++ b/firebase_admin/__init__.py @@ -49,8 +49,8 @@ def initialize_app(credential=None, options=None, name=_DEFAULT_APP_NAME): Google Application Default Credentials are used. options: A dictionary of configuration options (optional). Supported options include ``databaseURL``, ``storageBucket``, ``projectId``, ``databaseAuthVariableOverride``, - ``serviceAccountId`` and ``httpTimeout``. If ``httpTimeout`` is not set, HTTP - connections initiated by client modules such as ``db`` will not time out. + ``serviceAccountId`` and ``httpTimeout``. If ``httpTimeout`` is not set, the SDK + uses a default timeout of 120 seconds. name: Name of the app (optional). Returns: App: A newly initialized instance of App. diff --git a/firebase_admin/_http_client.py b/firebase_admin/_http_client.py index 1daaf371b..f6f0d89fa 100644 --- a/firebase_admin/_http_client.py +++ b/firebase_admin/_http_client.py @@ -32,6 +32,9 @@ raise_on_status=False, backoff_factor=0.5) +DEFAULT_TIMEOUT_SECONDS = 120 + + class HttpClient: """Base HTTP client used to make HTTP calls. @@ -41,7 +44,7 @@ class HttpClient: def __init__( self, credential=None, session=None, base_url='', headers=None, - retries=DEFAULT_RETRY_CONFIG): + retries=DEFAULT_RETRY_CONFIG, timeout=DEFAULT_TIMEOUT_SECONDS): """Creates a new HttpClient instance from the provided arguments. If a credential is provided, initializes a new HTTP session authorized with it. If neither @@ -55,6 +58,8 @@ def __init__( retries: A urllib retry configuration. Default settings would retry once for low-level connection and socket read errors, and up to 4 times for HTTP 500 and 503 errors. Pass a False value to disable retries (optional). + timeout: HTTP timeout in seconds. Defaults to 120 seconds when not specified. Set to + None to disable timeouts (optional). """ if credential: self._session = transport.requests.AuthorizedSession(credential) @@ -69,6 +74,7 @@ def __init__( self._session.mount('http://', requests.adapters.HTTPAdapter(max_retries=retries)) self._session.mount('https://', requests.adapters.HTTPAdapter(max_retries=retries)) self._base_url = base_url + self._timeout = timeout @property def session(self): @@ -78,6 +84,10 @@ def session(self): def base_url(self): return self._base_url + @property + def timeout(self): + return self._timeout + def parse_body(self, resp): raise NotImplementedError @@ -93,7 +103,7 @@ class call this method to send HTTP requests out. Refer to method: HTTP method name as a string (e.g. get, post). url: URL of the remote endpoint. kwargs: An additional set of keyword arguments to be passed into the requests API - (e.g. json, params). + (e.g. json, params, timeout). Returns: Response: An HTTP response object. @@ -101,7 +111,9 @@ class call this method to send HTTP requests out. Refer to Raises: RequestException: Any requests exceptions encountered while making the HTTP call. """ - resp = self._session.request(method, self._base_url + url, **kwargs) + if 'timeout' not in kwargs: + kwargs['timeout'] = self.timeout + resp = self._session.request(method, self.base_url + url, **kwargs) resp.raise_for_status() return resp diff --git a/firebase_admin/db.py b/firebase_admin/db.py index 9092a955c..b82a327ed 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -775,7 +775,7 @@ def __init__(self, app): self._auth_override = json.dumps(auth_override, separators=(',', ':')) else: self._auth_override = None - self._timeout = app.options.get('httpTimeout') + self._timeout = app.options.get('httpTimeout', _http_client.DEFAULT_TIMEOUT_SECONDS) self._clients = {} emulator_host = os.environ.get(_EMULATOR_HOST_ENV_VAR) @@ -900,14 +900,13 @@ def __init__(self, credential, base_url, timeout, params=None): credential: A Google credential that can be used to authenticate requests. base_url: A URL prefix to be added to all outgoing requests. This is typically the Firebase Realtime Database URL. - timeout: HTTP request timeout in seconds. If not set connections will never + timeout: HTTP request timeout in seconds. If set to None connections will never timeout, which is the default behavior of the underlying requests library. params: Dict of query parameters to add to all outgoing requests. """ - _http_client.JsonHttpClient.__init__( - self, credential=credential, base_url=base_url, headers={'User-Agent': _USER_AGENT}) - self.credential = credential - self.timeout = timeout + super().__init__( + credential=credential, base_url=base_url, + timeout=timeout, headers={'User-Agent': _USER_AGENT}) self.params = params if params else {} def request(self, method, url, **kwargs): @@ -937,8 +936,6 @@ def request(self, method, url, **kwargs): query = extra_params kwargs['params'] = query - if self.timeout: - kwargs['timeout'] = self.timeout try: return super(_Client, self).request(method, url, **kwargs) except requests.exceptions.RequestException as error: diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index 9262751a1..788875048 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -330,8 +330,9 @@ def __init__(self, app): 'X-GOOG-API-FORMAT-VERSION': '2', 'X-FIREBASE-CLIENT': 'fire-admin-python/{0}'.format(firebase_admin.__version__), } - self._client = _http_client.JsonHttpClient(credential=app.credential.get_credential()) - self._timeout = app.options.get('httpTimeout') + timeout = app.options.get('httpTimeout', _http_client.DEFAULT_TIMEOUT_SECONDS) + self._client = _http_client.JsonHttpClient( + credential=app.credential.get_credential(), timeout=timeout) self._transport = _auth.authorized_http(app.credential.get_credential()) @classmethod @@ -348,8 +349,7 @@ def send(self, message, dry_run=False): 'post', url=self._fcm_url, headers=self._fcm_headers, - json=data, - timeout=self._timeout + json=data ) except requests.exceptions.RequestException as error: raise self._handle_fcm_error(error) @@ -416,8 +416,7 @@ def make_topic_management_request(self, tokens, topic, operation): 'post', url=url, json=data, - headers=_MessagingService.IID_HEADERS, - timeout=self._timeout + headers=_MessagingService.IID_HEADERS ) except requests.exceptions.RequestException as error: raise self._handle_iid_error(error) diff --git a/firebase_admin/project_management.py b/firebase_admin/project_management.py index 91aa1eebb..ed292b80f 100644 --- a/firebase_admin/project_management.py +++ b/firebase_admin/project_management.py @@ -478,11 +478,12 @@ def __init__(self, app): 'the GOOGLE_CLOUD_PROJECT environment variable.') self._project_id = project_id version_header = 'Python/Admin/{0}'.format(firebase_admin.__version__) + timeout = app.options.get('httpTimeout', _http_client.DEFAULT_TIMEOUT_SECONDS) self._client = _http_client.JsonHttpClient( credential=app.credential.get_credential(), base_url=_ProjectManagementService.BASE_URL, - headers={'X-Client-Version': version_header}) - self._timeout = app.options.get('httpTimeout') + headers={'X-Client-Version': version_header}, + timeout=timeout) def get_android_app_metadata(self, app_id): return self._get_app_metadata( @@ -658,7 +659,6 @@ def _make_request(self, method, url, json=None): def _body_and_response(self, method, url, json=None): try: - return self._client.body_and_response( - method=method, url=url, json=json, timeout=self._timeout) + return self._client.body_and_response(method=method, url=url, json=json) except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) diff --git a/tests/test_db.py b/tests/test_db.py index b20f99cb9..1743347c5 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -23,6 +23,7 @@ import firebase_admin from firebase_admin import db from firebase_admin import exceptions +from firebase_admin import _http_client from firebase_admin import _sseclient from tests import testutils @@ -731,15 +732,8 @@ def test_parse_db_url_errors(self, url, emulator_host): def test_valid_db_url(self, url): firebase_admin.initialize_app(testutils.MockCredential(), {'databaseURL' : url}) ref = db.reference() - recorder = [] - adapter = MockAdapter('{}', 200, recorder) - ref._client.session.mount(url, adapter) assert ref._client.base_url == 'https://test.firebaseio.com' assert 'auth_variable_override' not in ref._client.params - assert ref._client.timeout is None - assert ref.get() == {} - assert len(recorder) == 1 - assert recorder[0]._extra_kwargs.get('timeout') is None @pytest.mark.parametrize('url', [ None, '', 'foo', 'http://test.firebaseio.com', 'https://google.com', @@ -761,7 +755,6 @@ def test_multi_db_support(self): ref = db.reference() assert ref._client.base_url == default_url assert 'auth_variable_override' not in ref._client.params - assert ref._client.timeout is None assert ref._client is db.reference()._client assert ref._client is db.reference(url=default_url)._client @@ -769,7 +762,6 @@ def test_multi_db_support(self): other_ref = db.reference(url=other_url) assert other_ref._client.base_url == other_url assert 'auth_variable_override' not in ref._client.params - assert other_ref._client.timeout is None assert other_ref._client is db.reference(url=other_url)._client assert other_ref._client is db.reference(url=other_url + '/')._client @@ -782,7 +774,6 @@ def test_valid_auth_override(self, override): default_ref = db.reference() other_ref = db.reference(url='https://other.firebaseio.com') for ref in [default_ref, other_ref]: - assert ref._client.timeout is None if override == {}: assert 'auth_variable_override' not in ref._client.params else: @@ -804,22 +795,22 @@ def test_invalid_auth_override(self, override): with pytest.raises(ValueError): db.reference(app=other_app, url='https://other.firebaseio.com') - def test_http_timeout(self): + @pytest.mark.parametrize('options, timeout', [ + ({'httpTimeout': 4}, 4), + ({'httpTimeout': None}, None), + ({}, _http_client.DEFAULT_TIMEOUT_SECONDS), + ]) + def test_http_timeout(self, options, timeout): test_url = 'https://test.firebaseio.com' - firebase_admin.initialize_app(testutils.MockCredential(), { + all_options = { 'databaseURL' : test_url, - 'httpTimeout': 60 - }) + } + all_options.update(options) + firebase_admin.initialize_app(testutils.MockCredential(), all_options) default_ref = db.reference() other_ref = db.reference(url='https://other.firebaseio.com') for ref in [default_ref, other_ref]: - recorder = [] - adapter = MockAdapter('{}', 200, recorder) - ref._client.session.mount(ref._client.base_url, adapter) - assert ref._client.timeout == 60 - assert ref.get() == {} - assert len(recorder) == 1 - assert recorder[0]._extra_kwargs['timeout'] == pytest.approx(60, 0.001) + self._check_timeout(ref, timeout) def test_app_delete(self): app = firebase_admin.initialize_app( @@ -841,6 +832,18 @@ def test_user_agent_format(self): firebase_admin.__version__, sys.version_info.major, sys.version_info.minor) assert db._USER_AGENT == expected + def _check_timeout(self, ref, timeout): + assert ref._client.timeout == timeout + recorder = [] + adapter = MockAdapter('{}', 200, recorder) + ref._client.session.mount(ref._client.base_url, adapter) + assert ref.get() == {} + assert len(recorder) == 1 + if timeout is None: + assert recorder[0]._extra_kwargs['timeout'] is None + else: + assert recorder[0]._extra_kwargs['timeout'] == pytest.approx(timeout, 0.001) + @pytest.fixture(params=['foo', '$key', '$value']) def initquery(request): diff --git a/tests/test_http_client.py b/tests/test_http_client.py index d4d2885f3..12ba03b48 100644 --- a/tests/test_http_client.py +++ b/tests/test_http_client.py @@ -74,6 +74,24 @@ def test_credential(): assert recorder[0].url == _TEST_URL assert recorder[0].headers['Authorization'] == 'Bearer mock-token' +@pytest.mark.parametrize('options, timeout', [ + ({}, _http_client.DEFAULT_TIMEOUT_SECONDS), + ({'timeout': 7}, 7), + ({'timeout': 0}, 0), + ({'timeout': None}, None), +]) +def test_timeout(options, timeout): + client = _http_client.HttpClient(**options) + assert client.timeout == timeout + recorder = _instrument(client, 'body') + client.request('get', _TEST_URL) + assert len(recorder) == 1 + if timeout is None: + assert recorder[0]._extra_kwargs['timeout'] is None + else: + assert recorder[0]._extra_kwargs['timeout'] == pytest.approx(timeout, 0.001) + + def _instrument(client, payload, status=200): recorder = [] adapter = testutils.MockAdapter(payload, status, recorder) diff --git a/tests/test_instance_id.py b/tests/test_instance_id.py index a13506a07..08b0fe6db 100644 --- a/tests/test_instance_id.py +++ b/tests/test_instance_id.py @@ -19,6 +19,7 @@ import firebase_admin from firebase_admin import exceptions from firebase_admin import instance_id +from firebase_admin import _http_client from tests import testutils @@ -73,6 +74,12 @@ def evaluate(): instance_id.delete_instance_id('test') testutils.run_without_project_id(evaluate) + def test_default_timeout(self): + cred = testutils.MockCredential() + app = firebase_admin.initialize_app(cred, {'projectId': 'explicit-project-id'}) + iid_service = instance_id._get_iid_service(app) + assert iid_service._client.timeout == _http_client.DEFAULT_TIMEOUT_SECONDS + def test_delete_instance_id(self): cred = testutils.MockCredential() app = firebase_admin.initialize_app(cred, {'projectId': 'explicit-project-id'}) diff --git a/tests/test_messaging.py b/tests/test_messaging.py index 33c99445b..f8be4cd67 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -23,6 +23,7 @@ import firebase_admin from firebase_admin import exceptions from firebase_admin import messaging +from firebase_admin import _http_client from tests import testutils @@ -1537,42 +1538,57 @@ def test_aps_alert_custom_data_override(self): } check_encoding(msg, expected) -class TestTimeout: - @classmethod - def setup_class(cls): - cred = testutils.MockCredential() - firebase_admin.initialize_app(cred, {'httpTimeout': 4, 'projectId': 'explicit-project-id'}) +class TestTimeout: - @classmethod - def teardown_class(cls): + def teardown(self): testutils.cleanup_apps() - def setup(self): + def _instrument_service(self, url, response): app = firebase_admin.get_app() - self.fcm_service = messaging._get_messaging_service(app) - self.recorder = [] + fcm_service = messaging._get_messaging_service(app) + recorder = [] + fcm_service._client.session.mount( + url, testutils.MockAdapter(json.dumps(response), 200, recorder)) + return recorder - def test_send(self): - self.fcm_service._client.session.mount( - 'https://fcm.googleapis.com', - testutils.MockAdapter(json.dumps({'name': 'message-id'}), 200, self.recorder)) + def _check_timeout(self, recorder, timeout): + assert len(recorder) == 1 + if timeout is None: + assert recorder[0]._extra_kwargs['timeout'] is None + else: + assert recorder[0]._extra_kwargs['timeout'] == pytest.approx(timeout, 0.001) + + @pytest.mark.parametrize('options, timeout', [ + ({'httpTimeout': 4}, 4), + ({'httpTimeout': None}, None), + ({}, _http_client.DEFAULT_TIMEOUT_SECONDS), + ]) + def test_send(self, options, timeout): + cred = testutils.MockCredential() + all_options = {'projectId': 'explicit-project-id'} + all_options.update(options) + firebase_admin.initialize_app(cred, all_options) + recorder = self._instrument_service( + 'https://fcm.googleapis.com', {'name': 'message-id'}) msg = messaging.Message(topic='foo') messaging.send(msg) - assert len(self.recorder) == 1 - assert self.recorder[0]._extra_kwargs['timeout'] == pytest.approx(4, 0.001) + self._check_timeout(recorder, timeout) - def test_topic_management_timeout(self): - self.fcm_service._client.session.mount( - 'https://iid.googleapis.com', - testutils.MockAdapter( - json.dumps({'results': [{}, {'error': 'error_reason'}]}), - 200, - self.recorder) - ) + @pytest.mark.parametrize('options, timeout', [ + ({'httpTimeout': 4}, 4), + ({'httpTimeout': None}, None), + ({}, _http_client.DEFAULT_TIMEOUT_SECONDS), + ]) + def test_topic_management_custom_timeout(self, options, timeout): + cred = testutils.MockCredential() + all_options = {'projectId': 'explicit-project-id'} + all_options.update(options) + firebase_admin.initialize_app(cred, all_options) + recorder = self._instrument_service( + 'https://iid.googleapis.com', {'results': [{}, {'error': 'error_reason'}]}) messaging.subscribe_to_topic(['1'], 'a') - assert len(self.recorder) == 1 - assert self.recorder[0]._extra_kwargs['timeout'] == pytest.approx(4, 0.001) + self._check_timeout(recorder, timeout) class TestSend: @@ -1641,7 +1657,6 @@ def test_send(self): assert recorder[0].url == self._get_url('explicit-project-id') assert recorder[0].headers['X-GOOG-API-FORMAT-VERSION'] == '2' assert recorder[0].headers['X-FIREBASE-CLIENT'] == self._CLIENT_VERSION - assert recorder[0]._extra_kwargs['timeout'] is None body = {'message': messaging._MessagingService.encode_message(msg)} assert json.loads(recorder[0].body.decode()) == body diff --git a/tests/test_project_management.py b/tests/test_project_management.py index aa717bbf7..183195510 100644 --- a/tests/test_project_management.py +++ b/tests/test_project_management.py @@ -22,6 +22,7 @@ import firebase_admin from firebase_admin import exceptions from firebase_admin import project_management +from firebase_admin import _http_client from tests import testutils OPERATION_IN_PROGRESS_RESPONSE = json.dumps({ @@ -528,6 +529,25 @@ def _assert_request_is_correct( assert json.loads(request.body.decode()) == expected_body +class TestTimeout(BaseProjectManagementTest): + + def test_default_timeout(self): + app = firebase_admin.get_app() + project_management_service = project_management._get_project_management_service(app) + assert project_management_service._client.timeout == _http_client.DEFAULT_TIMEOUT_SECONDS + + @pytest.mark.parametrize('timeout', [4, None]) + def test_custom_timeout(self, timeout): + options = { + 'httpTimeout': timeout, + 'projectId': 'test-project-id' + } + app = firebase_admin.initialize_app( + testutils.MockCredential(), options, 'timeout-{0}'.format(timeout)) + project_management_service = project_management._get_project_management_service(app) + assert project_management_service._client.timeout == timeout + + class TestCreateAndroidApp(BaseProjectManagementTest): _CREATION_URL = 'https://firebase.googleapis.com/v1beta1/projects/test-project-id/androidApps' diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index 9b0b4ce11..958bbf9c4 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -25,6 +25,7 @@ from firebase_admin import auth from firebase_admin import exceptions from firebase_admin import _auth_utils +from firebase_admin import _http_client from firebase_admin import _user_import from firebase_admin import _user_mgt from tests import testutils @@ -102,12 +103,18 @@ def _check_user_record(user, expected_uid='testuser'): class TestAuthServiceInitialization: + def test_default_timeout(self, user_mgt_app): + auth_service = auth._get_auth_service(user_mgt_app) + user_manager = auth_service.user_manager + assert user_manager._client.timeout == _http_client.DEFAULT_TIMEOUT_SECONDS + def test_fail_on_no_project_id(self): app = firebase_admin.initialize_app(testutils.MockCredential(), name='userMgt2') with pytest.raises(ValueError): auth._get_auth_service(app) firebase_admin.delete_app(app) + class TestUserRecord: # Input dict must be non-empty, and must not contain unsupported keys. From 04e2b1b67bcc1b6d3633c93716484282d951e5b6 Mon Sep 17 00:00:00 2001 From: Alastair Hendricks Date: Thu, 30 Jan 2020 21:13:42 +0200 Subject: [PATCH 048/226] Fix send_all & send_multicast snippet comment to match implementation (#376) * Fix send_all comment * Update send_multicast comment --- snippets/messaging/cloud_messaging.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/snippets/messaging/cloud_messaging.py b/snippets/messaging/cloud_messaging.py index 6dc1aad10..bb63db065 100644 --- a/snippets/messaging/cloud_messaging.py +++ b/snippets/messaging/cloud_messaging.py @@ -225,7 +225,7 @@ def unsubscribe_from_topic(): def send_all(): registration_token = 'YOUR_REGISTRATION_TOKEN' # [START send_all] - # Create a list containing up to 100 messages. + # Create a list containing up to 500 messages. messages = [ messaging.Message( notification=messaging.Notification('Price drop', '5% off all electronics'), @@ -247,7 +247,7 @@ def send_all(): def send_multicast(): # [START send_multicast] - # Create a list containing up to 100 registration tokens. + # Create a list containing up to 500 registration tokens. # These registration tokens come from the client FCM SDKs. registration_tokens = [ 'YOUR_REGISTRATION_TOKEN_1', From ffebd3cd42aa4b05fc3cf6f5a4a01442d04f8fc5 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Thu, 30 Jan 2020 13:34:30 -0800 Subject: [PATCH 049/226] Defined linter as a separate job (#398) * Defined linter as a separate job * Added missing env config * Fixed pylint install directive * Installing all dependencies * Disabling fail fast for build matrix * Merged with master --- .github/workflows/ci.yml | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ca223bd5b..976767d64 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -4,9 +4,9 @@ on: [push, pull_request] jobs: build: - runs-on: ubuntu-latest strategy: + fail-fast: false matrix: python: [3.5, 3.6, 3.7, pypy3] @@ -20,11 +20,7 @@ jobs: run: | python -m pip install --upgrade pip pip install -r requirements.txt - - name: Lint with pylint - if: matrix.python == '3.7' - run: ./lint.sh all - name: Test with pytest - if: success() || failure() run: pytest - name: Set up Node.js 10 uses: actions/setup-node@v1 @@ -34,3 +30,18 @@ jobs: run: | npm install -g firebase-tools firebase emulators:exec --only database --project fake-project-id 'pytest integration/test_db.py' + + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v1 + - name: Set up Python 3.7 + uses: actions/setup-python@v1 + with: + python-version: 3.7 + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + - name: Lint with pylint + run: ./lint.sh all From ccefa63b3e83587ea3266b5498908eb2b5853778 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Wed, 5 Feb 2020 13:58:16 -0800 Subject: [PATCH 050/226] Removing universal flag from binary dist configuration (#404) --- setup.cfg | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/setup.cfg b/setup.cfg index a038cfa0d..25c649748 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,2 @@ -[bdist_wheel] -universal = 1 - [tool:pytest] -testpaths = tests \ No newline at end of file +testpaths = tests From 57f1603cfc9a9d5c05ddd9fb68233d6fe04b2ff9 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Thu, 6 Feb 2020 11:39:38 -0800 Subject: [PATCH 051/226] chore: Experimental GitHub Actions based workflow for publishing releases (#402) * Experimental release workflow * Release note generation and more improvements * Simplified release process --- .github/scripts/generate_changelog.sh | 65 ++++++++++ .github/scripts/publish_preflight_check.sh | 134 +++++++++++++++++++++ .github/workflows/release.yml | 108 +++++++++++++++++ 3 files changed, 307 insertions(+) create mode 100755 .github/scripts/generate_changelog.sh create mode 100755 .github/scripts/publish_preflight_check.sh create mode 100644 .github/workflows/release.yml diff --git a/.github/scripts/generate_changelog.sh b/.github/scripts/generate_changelog.sh new file mode 100755 index 000000000..3c97dca0c --- /dev/null +++ b/.github/scripts/generate_changelog.sh @@ -0,0 +1,65 @@ +#!/bin/bash + +set -e +set -u + +function printChangelog() { + local TITLE=$1 + shift + # Skip the sentinel value. + local ENTRIES=("${@:2}") + if [ ${#ENTRIES[@]} -ne 0 ]; then + echo "### ${TITLE}" + echo "" + for ((i = 0; i < ${#ENTRIES[@]}; i++)) + do + echo "* ${ENTRIES[$i]}" + done + echo "" + fi +} + +if [[ -z "${GITHUB_SHA}" ]]; then + GITHUB_SHA="HEAD" +fi + +LAST_TAG=`git describe --tags $(git rev-list --tags --max-count=1) 2> /dev/null` || true +if [[ -z "${LAST_TAG}" ]]; then + echo "[INFO] No tags found. Including all commits up to ${GITHUB_SHA}." + VERSION_RANGE="${GITHUB_SHA}" +else + echo "[INFO] Last release tag: ${LAST_TAG}." + COMMIT_SHA=`git show-ref -s ${LAST_TAG}` + echo "[INFO] Last release commit: ${COMMIT_SHA}." + VERSION_RANGE="${COMMIT_SHA}..${GITHUB_SHA}" + echo "[INFO] Including all commits in the range ${VERSION_RANGE}." +fi + +echo "" + +# Older versions of Bash (< 4.4) treat empty arrays as unbound variables, which triggers +# errors when referencing them. Therefore we initialize each of these arrays with an empty +# sentinel value, and later skip them. +CHANGES=("") +FIXES=("") +FEATS=("") +MISC=("") + +while read -r line +do + COMMIT_MSG=`echo ${line} | cut -d ' ' -f 2-` + if [[ $COMMIT_MSG =~ ^change(\(.*\))?: ]]; then + CHANGES+=("$COMMIT_MSG") + elif [[ $COMMIT_MSG =~ ^fix(\(.*\))?: ]]; then + FIXES+=("$COMMIT_MSG") + elif [[ $COMMIT_MSG =~ ^feat(\(.*\))?: ]]; then + FEATS+=("$COMMIT_MSG") + else + MISC+=("${COMMIT_MSG}") + fi +done < <(git log ${VERSION_RANGE} --oneline) + +printChangelog "Breaking Changes" "${CHANGES[@]}" +printChangelog "New Features" "${FEATS[@]}" +printChangelog "Bug Fixes" "${FIXES[@]}" +printChangelog "Miscellaneous" "${MISC[@]}" diff --git a/.github/scripts/publish_preflight_check.sh b/.github/scripts/publish_preflight_check.sh new file mode 100755 index 000000000..38b0be20c --- /dev/null +++ b/.github/scripts/publish_preflight_check.sh @@ -0,0 +1,134 @@ +#!/bin/bash + +###################################### Outputs ##################################### + +# 1. version: The version of this release including the 'v' prefix (e.g. v1.2.3). +# 2. changelog: Formatted changelog text for this release. + +#################################################################################### + +set -e +set -u + +function echo_info() { + local MESSAGE=$1 + echo "[INFO] ${MESSAGE}" +} + +function echo_warn() { + local MESSAGE=$1 + echo "[WARN] ${MESSAGE}" +} + +function terminate() { + echo "" + echo_warn "--------------------------------------------" + echo_warn "PREFLIGHT FAILED" + echo_warn "--------------------------------------------" + exit 1 +} + + +echo_info "Starting publish preflight check..." +echo_info "Git revision : ${GITHUB_SHA}" +echo_info "Workflow triggered by : ${GITHUB_ACTOR}" +echo_info "GitHub event : ${GITHUB_EVENT_NAME}" + + +echo_info "" +echo_info "--------------------------------------------" +echo_info "Extracting release version" +echo_info "--------------------------------------------" +echo_info "" + +readonly ABOUT_FILE="firebase_admin/__about__.py" +echo_info "Loading version from: ${ABOUT_FILE}" + +readonly VERSION_SCRIPT="exec(open('${ABOUT_FILE}').read()); print(__version__)" +readonly RELEASE_VERSION=`python -c "${VERSION_SCRIPT}"` || true +if [[ -z "${RELEASE_VERSION}" ]]; then + echo_warn "Failed to extract release version from: ${ABOUT_FILE}" + terminate +fi + +if [[ ! "${RELEASE_VERSION}" =~ ^([0-9]*)\.([0-9]*)\.([0-9]*)$ ]]; then + echo_warn "Malformed release version string: ${RELEASE_VERSION}. Exiting." + terminate +fi + +echo_info "Extracted release version: ${RELEASE_VERSION}" +echo "::set-output name=version::v${RELEASE_VERSION}" + + +echo_info "" +echo_info "--------------------------------------------" +echo_info "Checking previous releases" +echo_info "--------------------------------------------" +echo_info "" + +readonly PYPI_URL="https://pypi.org/pypi/firebase-admin/${RELEASE_VERSION}/json" +readonly PYPI_STATUS=`curl -s -o /dev/null -L -w "%{http_code}" ${PYPI_URL}` +if [[ $PYPI_STATUS -eq 404 ]]; then + echo_info "Release version ${RELEASE_VERSION} not found in Pypi." +elif [[ $PYPI_STATUS -eq 200 ]]; then + echo_warn "Release version ${RELEASE_VERSION} already present in Pypi." + terminate +else + echo_warn "Unexpected ${PYPI_STATUS} response from Pypi. Exiting." + terminate +fi + + +echo_info "" +echo_info "--------------------------------------------" +echo_info "Checking release tag" +echo_info "--------------------------------------------" +echo_info "" + +echo_info "---< git fetch --depth=1 origin +refs/tags/*:refs/tags/* >---" +git fetch --depth=1 origin +refs/tags/*:refs/tags/* +echo "" + +readonly EXISTING_TAG=`git rev-parse -q --verify "refs/tags/v${RELEASE_VERSION}"` || true +if [[ -n "${EXISTING_TAG}" ]]; then + echo_warn "Tag v${RELEASE_VERSION} already exists. Exiting." + echo_warn "If the tag was created in a previous unsuccessful attempt, delete it and try again." + echo_warn " $ git tag -d v${RELEASE_VERSION}" + echo_warn " $ git push --delete origin v${RELEASE_VERSION}" + + readonly RELEASE_URL="https://github.com/firebase/firebase-admin-python/releases/tag/v${RELEASE_VERSION}" + echo_warn "Delete any corresponding releases at ${RELEASE_URL}." + terminate +fi + +echo_info "Tag v${RELEASE_VERSION} does not exist." + + +echo_info "" +echo_info "--------------------------------------------" +echo_info "Generating changelog" +echo_info "--------------------------------------------" +echo_info "" + +echo_info "---< git fetch origin master --prune --unshallow >---" +git fetch origin master --prune --unshallow +echo "" + +echo_info "Generating changelog from history..." +readonly CURRENT_DIR=$(dirname "$0") +readonly CHANGELOG=`${CURRENT_DIR}/generate_changelog.sh` +echo "$CHANGELOG" + +# Parse and preformat the text to handle multi-line output. +# See https://github.community/t5/GitHub-Actions/set-output-Truncates-Multiline-Strings/td-p/37870 +FILTERED_CHANGELOG=`echo "$CHANGELOG" | grep -v "\\[INFO\\]"` +FILTERED_CHANGELOG="${FILTERED_CHANGELOG//'%'/'%25'}" +FILTERED_CHANGELOG="${FILTERED_CHANGELOG//$'\n'/'%0A'}" +FILTERED_CHANGELOG="${FILTERED_CHANGELOG//$'\r'/'%0D'}" +echo "::set-output name=changelog::${FILTERED_CHANGELOG}" + + +echo "" +echo_info "--------------------------------------------" +echo_info "PREFLIGHT SUCCESSFUL" +echo_info "--------------------------------------------" diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 000000000..670a5cab4 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,108 @@ +# Copyright 2020 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: Release + +on: + # Only run the workflow when a PR is closed, or when a developer explicitly requests + # a build by sending a 'firebase_build' event. + pull_request: + types: + - closed + + repository_dispatch: + types: + - firebase_build + +jobs: + stage_release: + # If triggered by a PR it must be merged and contain the label 'release:build'. + if: github.event.action == 'firebase_build' || + (github.event.pull_request.merged && + contains(github.event.pull_request.labels.*.name, 'release:build')) + + runs-on: ubuntu-latest + + # When manually triggering the build, the requester can specify a target branch or a tag + # via the 'ref' client parameter. + steps: + - name: Checkout source for staging + uses: actions/checkout@v2 + with: + ref: ${{ github.event.client_payload.ref || github.ref }} + + - name: Set up Python + uses: actions/setup-python@v1 + with: + python-version: 3.6 + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + + - name: Run tests + run: | + pytest + echo "Running integration tests" + + - name: Package release artifacts + run: python setup.py bdist_wheel bdist_egg + + # Attach the packaged artifacts to the workflow output. These can be manually + # downloaded for later inspection if necessary. + - name: Archive artifacts + uses: actions/upload-artifact@v1 + with: + name: dist + path: dist + + # Check whether the release should be published. We publish only when the trigger PR is + # 1. merged + # 2. to the master branch + # 3. with the title prefix '[chore] Release '. + - name: Publish preflight check + if: success() && github.event.pull_request.merged && + github.ref == 'master' && + startsWith(github.event.pull_request.title, '[chore] Release ') + id: preflight + run: | + ./.github/scripts/publish_preflight_check.sh + echo ::set-env name=FIREBASE_PUBLISH::true + + # Tag the release if not executing in the dryrun mode. We pull this action froma + # custom fork of a contributor until https://github.com/actions/create-release/pull/32 + # is merged. Also note that v1 of this action does not support the "body" parameter. + - name: Create release tag + if: success() && env.FIREBASE_PUBLISH + uses: fleskesvor/create-release@1a72e235c178bf2ae6c51a8ae36febc24568c5fe + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + tag_name: ${{ steps.preflight.outputs.version }} + release_name: Firebase Admin Python SDK ${{ steps.preflight.outputs.version }} + body: ${{ steps.preflight.outputs.changelog }} + draft: false + prerelease: false + + - name: Publish to Pypi + if: success() && env.FIREBASE_PUBLISH + run: echo Publishing to Pypi + + # Post to Twitter if explicitly opted-in by adding the label 'release:tweet'. + - name: Post to Twitter + if: success() && env.FIREBASE_PUBLISH && + contains(github.event.pull_request.labels.*.name, 'release:tweet') + run: echo Posting Tweet + continue-on-error: true From 02c2ac25e7840325a5b64f795c1909ff81bb5cbc Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Thu, 6 Feb 2020 13:56:28 -0800 Subject: [PATCH 052/226] chore: Installing wheel package during build staging (#405) --- .github/workflows/release.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 670a5cab4..6111dd7f1 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -51,6 +51,7 @@ jobs: run: | python -m pip install --upgrade pip pip install -r requirements.txt + pip install wheel - name: Run tests run: | From 814306863fe90703aa62987d4e4bf3ac4d6b1484 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Fri, 7 Feb 2020 13:26:27 -0800 Subject: [PATCH 053/226] chore: Running integration tests in release workflow (#406) --- .../resources/integ-service-account.json.gpg | Bin 0 -> 1733 bytes .github/scripts/run_integration_tests.sh | 11 +++++++++ .github/workflows/ci.yml | 2 +- .github/workflows/release.yml | 22 ++++++++++-------- integration/test_db.py | 2 +- 5 files changed, 25 insertions(+), 12 deletions(-) create mode 100644 .github/resources/integ-service-account.json.gpg create mode 100755 .github/scripts/run_integration_tests.sh diff --git a/.github/resources/integ-service-account.json.gpg b/.github/resources/integ-service-account.json.gpg new file mode 100644 index 0000000000000000000000000000000000000000..e8cc3e2a2a970b6760faceb0ff7dfc3e4b02c60a GIT binary patch literal 1733 zcmV;$20HnS4Fm}T0&*P1GU?%eegD$x0rvxBAXZGDSEYyg&@r!tuJO!^@&2Mg0RD-w z6$6e;;`CFWj*fn+wYp+93`zD7be^H9^(f2dArxBMIsVqnH^uk`NBpM}j)R?=LOWsA z{+*+Is!_mH0&L^@gva$WU4tSy{!^+>46sf3l~2(yGHLCQJ|AD}OSS4|w)W)bIhUCB ztLZxR_T>Ox?nQ%n2SuXok1L6Q?jrQ##qw?ASK7)@vt_ihr{>UJqi+)ws0`vx6lVaz zYoMV^widvCLEMCOzE9r4f`M#Cz#4ukZW(RbYSG4zyaS7o#kYgzHU2SEPW}TDgOeO$ zl8#OV2K4iNZ(_%sDt3-8G?e{}9)Xa63cE<5Tw~PUBhHb%f%Wp+7@Dii4KH~>ThXnc zJU+i2l19mPJ*o^DiA1mUbL4eB> zbfqZwFA9iW9)@X5^Jg;FmKH1LyrWqA*+tn$T-+S{Q!Ba4sb3{w26;5ycQyh=djxuG zE6Ef&YBF(Bdd_s{DEg>abVPH z6sfGhfLXxrTgj3Towx+c#p!H^NDd+(me+3C^t?B9kvUgNYxtM|xwc~5A3rV?MQ&Eb zRJ)mm($0#cd@i+T7wA;@d~q4WBv2zt(~SwCC&t<(Qzk>wvIas=r{SY@FlV!7%Y4X|e=JDOGfN84Xw7mlf9b{Z_{P(a;fJd&Ex zPE;unRrc%Lt6Yl#y!I` zW#bH&1CfNm3=Yxg&2XrnR$`$>mqEN6koziF5@98L^u}ja`*?#L9DqbkUsSVDO%mCg zyP&~Gf!nBhwHZTdZJZ=?(OfE%qT3Og6Pj0p(`{TM5&%c%Sp%Ly^p!vvKI|2`DN^t$ z+dont3_<4G_ds{)3CY0N32M6o8Nw=bodIN2N}2j6Kx>s0I%&=ci z8<<`m=?L9(VxxCHhes%UrPSTNQG&?qWdm{1;3a=FlA-^@>_1LqL4PsLKUIJ|lW}}ll^Rmr+`u^00-Zu3=0MK{X7=wW1xG9v zb>+uzSLWHC`%$|q#}_OwBP?rUGXCC>US#9h1-dE_KdPv%`8PBHTXX2${J{nG-7ZR7 z)j;WYkbTKBD-I^zOeBUw_AQ~!0`Ql&jp&}z(t&1V;13U#S4w>IC{WpQqf+-0w%_krosfryd7nMa+kp*`jl7_Q}9BY)T)sl`80~A!l zY=Kfl2&-TmHbwDIP%GrpRKMWoYsK z6~5b^5HMXvSl{rRG8_Innc{l2^8=++T>ul8K#~j(O_2?|&4mMfRkS+zMy2vVS}Cb;C8Qbj zh_XDiovWpqGvdj(vC6e_XKoj<`4>%Sm#lPtm?g(B1X|F~Ecq?+-dLoI<(r(K|0;(F z-GjL5-xIv&u}KCiyG&3X*caBWCC>sfo`m7udmy%Fq=do+s!GWP52a-~I()x-5i@;$ zKiD+X_ujNViv(}lsFTk0pW{G7J$y)Z-62(VZ^ literal 0 HcmV?d00001 diff --git a/.github/scripts/run_integration_tests.sh b/.github/scripts/run_integration_tests.sh new file mode 100755 index 000000000..060c5ba9f --- /dev/null +++ b/.github/scripts/run_integration_tests.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +set -e +set -u + +gpg --quiet --batch --yes --decrypt --passphrase="${FIREBASE_SERVICE_ACCT_KEY}" \ + --output integ-service-account.json .github/resources/integ-service-account.json.gpg + +echo "${FIREBASE_API_KEY}" > integ-api-key.txt + +pytest integration/ --cert integ-service-account.json --apikey integ-api-key.txt diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 976767d64..5a952418c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,6 +1,6 @@ name: Continuous Integration -on: [push, pull_request] +on: push jobs: build: diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 6111dd7f1..df6183952 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -15,11 +15,10 @@ name: Release on: - # Only run the workflow when a PR is closed, or when a developer explicitly requests + # Only run the workflow when a PR is updated or when a developer explicitly requests # a build by sending a 'firebase_build' event. pull_request: - types: - - closed + types: [opened, synchronize, closed] repository_dispatch: types: @@ -27,10 +26,9 @@ on: jobs: stage_release: - # If triggered by a PR it must be merged and contain the label 'release:build'. + # If triggered by a PR it must contain the label 'release:build'. if: github.event.action == 'firebase_build' || - (github.event.pull_request.merged && - contains(github.event.pull_request.labels.*.name, 'release:build')) + contains(github.event.pull_request.labels.*.name, 'release:build') runs-on: ubuntu-latest @@ -53,10 +51,14 @@ jobs: pip install -r requirements.txt pip install wheel - - name: Run tests - run: | - pytest - echo "Running integration tests" + - name: Run unit tests + run: pytest + + - name: Run integration tests + run: ./.github/scripts/run_integration_tests.sh + env: + FIREBASE_SERVICE_ACCT_KEY: ${{ secrets.FIREBASE_SERVICE_ACCT_KEY }} + FIREBASE_API_KEY: ${{ secrets.FIREBASE_API_KEY }} - name: Package release artifacts run: python setup.py bdist_wheel bdist_egg diff --git a/integration/test_db.py b/integration/test_db.py index 7a73ea3ad..c448436d6 100644 --- a/integration/test_db.py +++ b/integration/test_db.py @@ -55,7 +55,7 @@ def update_rules(app): with open(testutils.resource_filename('dinosaurs_index.json')) as rules_file: new_rules = json.load(rules_file) client = db.reference('', app)._client - rules = client.body('get', '/.settings/rules.json') + rules = client.body('get', '/.settings/rules.json', params='format=strict') existing = rules.get('rules') if existing != new_rules: rules['rules'] = new_rules From a4f2bda0c92deac408a4a5dd4d9571092708b49e Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Mon, 10 Feb 2020 11:02:07 -0800 Subject: [PATCH 054/226] chore: Making the separation between staging and publishing explicit (#407) * chore: Running integration tests in release workflow * chore: Making the separation between staging and publishing explicit * Fixing a syntax error * Minor clean up of yml --- .github/scripts/publish_preflight_check.sh | 35 ++++++++++++++ .github/workflows/release.yml | 54 ++++++++++++++++------ 2 files changed, 74 insertions(+), 15 deletions(-) diff --git a/.github/scripts/publish_preflight_check.sh b/.github/scripts/publish_preflight_check.sh index 38b0be20c..eaf0270f8 100755 --- a/.github/scripts/publish_preflight_check.sh +++ b/.github/scripts/publish_preflight_check.sh @@ -60,6 +60,41 @@ echo_info "Extracted release version: ${RELEASE_VERSION}" echo "::set-output name=version::v${RELEASE_VERSION}" +echo_info "" +echo_info "--------------------------------------------" +echo_info "Check release artifacts" +echo_info "--------------------------------------------" +echo_info "" + +if [[ ! -d dist ]]; then + echo_warn "dist directory does not exist." + terminate +fi + +readonly BIN_DIST="dist/firebase_admin-${RELEASE_VERSION}-py3-none-any.whl" +if [[ -f "${BIN_DIST}" ]]; then + echo_info "Found binary distribution (bdist_wheel): ${BIN_DIST}" +else + echo_warn "Binary distribution ${BIN_DIST} not found." + terminate +fi + +readonly SRC_DIST="dist/firebase_admin-${RELEASE_VERSION}.tar.gz" +if [[ -f "${SRC_DIST}" ]]; then + echo_info "Found source distribution (sdist): ${SRC_DIST}" +else + echo_warn "Source distribution ${SRC_DIST} not found." + terminate +fi + +readonly ARTIFACT_COUNT=`ls dist/ | wc -l` +if [[ $ARTIFACT_COUNT -ne 2 ]]; then + echo_warn "Unexpected artifacts in the distribution directory." + ls -l dist + terminate +fi + + echo_info "" echo_info "--------------------------------------------" echo_info "Checking previous releases" diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index df6183952..b51ace956 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -name: Release +name: Release Candidate on: # Only run the workflow when a PR is updated or when a developer explicitly requests @@ -49,7 +49,7 @@ jobs: run: | python -m pip install --upgrade pip pip install -r requirements.txt - pip install wheel + pip install setuptools wheel - name: Run unit tests run: pytest @@ -60,8 +60,9 @@ jobs: FIREBASE_SERVICE_ACCT_KEY: ${{ secrets.FIREBASE_SERVICE_ACCT_KEY }} FIREBASE_API_KEY: ${{ secrets.FIREBASE_API_KEY }} + # Build the Python Wheel and the source distribution. - name: Package release artifacts - run: python setup.py bdist_wheel bdist_egg + run: python setup.py bdist_wheel sdist # Attach the packaged artifacts to the workflow output. These can be manually # downloaded for later inspection if necessary. @@ -71,24 +72,48 @@ jobs: name: dist path: dist + publish_release: + needs: stage_release + # Check whether the release should be published. We publish only when the trigger PR is # 1. merged # 2. to the master branch # 3. with the title prefix '[chore] Release '. + if: github.event.pull_request.merged && + github.ref == 'master' && + startsWith(github.event.pull_request.title, '[chore] Release ') + + runs-on: ubuntu-latest + + steps: + - name: Checkout source for publish + uses: actions/checkout@v2 + + # Download the artifacts created by the stage_release job. + - name: Download release candidates + uses: actions/download-artifact@v1 + with: + name: dist + + # Python is needed to run Twine and some of the preflight checks. + - name: Set up Python + uses: actions/setup-python@v1 + with: + python-version: 3.6 + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install twine + - name: Publish preflight check - if: success() && github.event.pull_request.merged && - github.ref == 'master' && - startsWith(github.event.pull_request.title, '[chore] Release ') id: preflight - run: | - ./.github/scripts/publish_preflight_check.sh - echo ::set-env name=FIREBASE_PUBLISH::true + run: ./.github/scripts/publish_preflight_check.sh - # Tag the release if not executing in the dryrun mode. We pull this action froma - # custom fork of a contributor until https://github.com/actions/create-release/pull/32 - # is merged. Also note that v1 of this action does not support the "body" parameter. + # We pull this action from a custom fork of a contributor until + # https://github.com/actions/create-release/pull/32 is merged. Also note that v1 of + # this action does not support the "body" parameter. - name: Create release tag - if: success() && env.FIREBASE_PUBLISH uses: fleskesvor/create-release@1a72e235c178bf2ae6c51a8ae36febc24568c5fe env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} @@ -100,12 +125,11 @@ jobs: prerelease: false - name: Publish to Pypi - if: success() && env.FIREBASE_PUBLISH run: echo Publishing to Pypi # Post to Twitter if explicitly opted-in by adding the label 'release:tweet'. - name: Post to Twitter - if: success() && env.FIREBASE_PUBLISH && + if: success() && contains(github.event.pull_request.labels.*.name, 'release:tweet') run: echo Posting Tweet continue-on-error: true From 08ed809d95b8aec311b2d17f257eb28b2c1d62a8 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Tue, 11 Feb 2020 14:15:15 -0800 Subject: [PATCH 055/226] Updated release trigger mechanisms (#409) * Updated release trigger mechanisms * Added license information to scripts --- .github/scripts/generate_changelog.sh | 14 ++++++++++++++ .github/scripts/publish_preflight_check.sh | 15 +++++++++++++++ .github/scripts/run_integration_tests.sh | 14 ++++++++++++++ .github/workflows/release.yml | 12 +++++++++--- 4 files changed, 52 insertions(+), 3 deletions(-) diff --git a/.github/scripts/generate_changelog.sh b/.github/scripts/generate_changelog.sh index 3c97dca0c..e393f40e4 100755 --- a/.github/scripts/generate_changelog.sh +++ b/.github/scripts/generate_changelog.sh @@ -1,5 +1,19 @@ #!/bin/bash +# Copyright 2020 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + set -e set -u diff --git a/.github/scripts/publish_preflight_check.sh b/.github/scripts/publish_preflight_check.sh index eaf0270f8..6b7b36180 100755 --- a/.github/scripts/publish_preflight_check.sh +++ b/.github/scripts/publish_preflight_check.sh @@ -1,5 +1,20 @@ #!/bin/bash +# Copyright 2020 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + ###################################### Outputs ##################################### # 1. version: The version of this release including the 'v' prefix (e.g. v1.2.3). diff --git a/.github/scripts/run_integration_tests.sh b/.github/scripts/run_integration_tests.sh index 060c5ba9f..96b0ad75d 100755 --- a/.github/scripts/run_integration_tests.sh +++ b/.github/scripts/run_integration_tests.sh @@ -1,5 +1,19 @@ #!/bin/bash +# Copyright 2020 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + set -e set -u diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index b51ace956..bc085a63f 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -26,9 +26,13 @@ on: jobs: stage_release: - # If triggered by a PR it must contain the label 'release:build'. + # To publish a release, merge the release PR with the label 'release:publish'. + # To stage a release without publishing it, send a 'firebase_build' event or apply + # the 'release:stage' label to a PR. if: github.event.action == 'firebase_build' || - contains(github.event.pull_request.labels.*.name, 'release:build') + contains(github.event.pull_request.labels.*.name, 'release:stage') || + (github.event.pull_request.merged && + contains(github.event.pull_request.labels.*.name, 'release:publish')) runs-on: ubuntu-latest @@ -78,9 +82,11 @@ jobs: # Check whether the release should be published. We publish only when the trigger PR is # 1. merged # 2. to the master branch - # 3. with the title prefix '[chore] Release '. + # 3. with the label 'release:publish', and + # 4. the title prefix '[chore] Release '. if: github.event.pull_request.merged && github.ref == 'master' && + contains(github.event.pull_request.labels.*.name, 'release:publish') && startsWith(github.event.pull_request.title, '[chore] Release ') runs-on: ubuntu-latest From 5b244a24b805bc2e4b577fcce7417367520f6ea1 Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Wed, 12 Feb 2020 11:38:10 -0500 Subject: [PATCH 056/226] fix(fcm): Convert event_time to UTC (#403) * fix(fcm): Convert event_time to UTC - Check if the datetime object is naive or timezone aware - If a naive datetime object is provided then set the timezone to local timezone - Convert the event_time to UTC Zulu format * Remove the third party library * Add new test case for naive event_timestamp * Consider naive datetimes are in UTC * Add a comment explaning the logic * Update docs --- firebase_admin/_messaging_encoder.py | 6 +++++- firebase_admin/_messaging_utils.py | 3 ++- tests/test_messaging.py | 29 ++++++++++++++++++++++++++-- 3 files changed, 34 insertions(+), 4 deletions(-) diff --git a/firebase_admin/_messaging_encoder.py b/firebase_admin/_messaging_encoder.py index c4da53f0d..48a3dd3cd 100644 --- a/firebase_admin/_messaging_encoder.py +++ b/firebase_admin/_messaging_encoder.py @@ -324,7 +324,11 @@ def encode_android_notification(cls, notification): event_time = result.get('event_time') if event_time: - result['event_time'] = str(event_time.isoformat()) + 'Z' + # if the datetime instance is not naive (tzinfo is present), convert to UTC + # otherwise (tzinfo is None) assume the datetime instance is already in UTC + if event_time.tzinfo is not None: + event_time = event_time.astimezone(datetime.timezone.utc) + result['event_time'] = event_time.strftime('%Y-%m-%dT%H:%M:%S.%fZ') priority = result.get('notification_priority') if priority: diff --git a/firebase_admin/_messaging_utils.py b/firebase_admin/_messaging_utils.py index 3a1943c04..d25ba5520 100644 --- a/firebase_admin/_messaging_utils.py +++ b/firebase_admin/_messaging_utils.py @@ -98,7 +98,8 @@ class AndroidNotification: the user clicks it (optional). event_timestamp: For notifications that inform users about events with an absolute time reference, sets the time that the event in the notification occurred as a - ``datetime.datetime`` instance. Notifications in the panel are sorted by this time + ``datetime.datetime`` instance. If the ``datetime.datetime`` instance is naive, it + defaults to be in the UTC timezone. Notifications in the panel are sorted by this time (optional). local_only: Sets whether or not this notification is relevant only to the current device. Some notifications can be bridged to other devices for remote display, such as a Wear OS diff --git a/tests/test_messaging.py b/tests/test_messaging.py index f8be4cd67..f2ef47cf8 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -547,7 +547,10 @@ def test_android_notification(self): click_action='ca', title_loc_key='tlk', body_loc_key='blk', title_loc_args=['t1', 't2'], body_loc_args=['b1', 'b2'], channel_id='c', ticker='ticker', sticky=True, - event_timestamp=datetime.datetime(2019, 10, 20, 15, 12, 23, 123), + event_timestamp=datetime.datetime( + 2019, 10, 20, 15, 12, 23, 123, + tzinfo=datetime.timezone(datetime.timedelta(hours=-5)) + ), local_only=False, priority='high', vibrate_timings_millis=[100, 50, 250], default_vibrate_timings=False, default_sound=True, @@ -577,7 +580,7 @@ def test_android_notification(self): 'channel_id': 'c', 'ticker': 'ticker', 'sticky': True, - 'event_time': '2019-10-20T15:12:23.000123Z', + 'event_time': '2019-10-20T20:12:23.000123Z', 'local_only': False, 'notification_priority': 'PRIORITY_HIGH', 'vibrate_timings': ['0.100000000s', '0.050000000s', '0.250000000s'], @@ -601,6 +604,28 @@ def test_android_notification(self): } check_encoding(msg, expected) + def test_android_notification_naive_event_timestamp(self): + event_time = datetime.datetime.now() + msg = messaging.Message( + topic='topic', + android=messaging.AndroidConfig( + notification=messaging.AndroidNotification( + title='t', + event_timestamp=event_time, + ) + ) + ) + expected = { + 'topic': 'topic', + 'android': { + 'notification': { + 'title': 't', + 'event_time': event_time.strftime('%Y-%m-%dT%H:%M:%S.%fZ') + }, + }, + } + check_encoding(msg, expected) + class TestLightSettingsEncoder: From e35a45a68885d1edfe7a28a2e75a9f1cc444f272 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Wed, 12 Feb 2020 11:42:16 -0800 Subject: [PATCH 057/226] chore: Implementing Pypi publish and Tweet steps (#410) * Updated release trigger mechanisms * Added license information to scripts * Added actions for publishing to Pypi and Twitter * Using shorter secret names; Pinned twitter action to a specific commit which seems to be safer --- .github/scripts/publish_preflight_check.sh | 3 +-- .github/workflows/release.yml | 26 +++++++++++----------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/.github/scripts/publish_preflight_check.sh b/.github/scripts/publish_preflight_check.sh index 6b7b36180..c962d8807 100755 --- a/.github/scripts/publish_preflight_check.sh +++ b/.github/scripts/publish_preflight_check.sh @@ -59,8 +59,7 @@ echo_info "" readonly ABOUT_FILE="firebase_admin/__about__.py" echo_info "Loading version from: ${ABOUT_FILE}" -readonly VERSION_SCRIPT="exec(open('${ABOUT_FILE}').read()); print(__version__)" -readonly RELEASE_VERSION=`python -c "${VERSION_SCRIPT}"` || true +readonly RELEASE_VERSION=`grep "__version__" ${ABOUT_FILE} | awk '{print $3}' | tr -d \'` || true if [[ -z "${RELEASE_VERSION}" ]]; then echo_warn "Failed to extract release version from: ${ABOUT_FILE}" terminate diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index bc085a63f..f8a9b5de8 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -101,17 +101,6 @@ jobs: with: name: dist - # Python is needed to run Twine and some of the preflight checks. - - name: Set up Python - uses: actions/setup-python@v1 - with: - python-version: 3.6 - - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install twine - - name: Publish preflight check id: preflight run: ./.github/scripts/publish_preflight_check.sh @@ -131,11 +120,22 @@ jobs: prerelease: false - name: Publish to Pypi - run: echo Publishing to Pypi + uses: pypa/gh-action-pypi-publish@v1.0.0a0 + with: + user: firebase + password: ${{ secrets.PYPI_PASSWORD }} # Post to Twitter if explicitly opted-in by adding the label 'release:tweet'. - name: Post to Twitter if: success() && contains(github.event.pull_request.labels.*.name, 'release:tweet') - run: echo Posting Tweet + uses: ethomson/send-tweet-action@288f9339e0412e3038dce350e0da5ecdf12133a6 + with: + status: > + ${{ steps.preflight.outputs.version }} of @Firebase Admin Python SDK is avaialble. + https://github.com/firebase/firebase-admin-python/releases/tag/${{ steps.preflight.outputs.version }} + consumer-key: ${{ secrets.TWITTER_CONSUMER_KEY }} + consumer-secret: ${{ secrets.TWITTER_CONSUMER_SECRET }} + access-token: ${{ secrets.TWITTER_ACCESS_TOKEN }} + access-token-secret: ${{ secrets.TWITTER_ACCESS_TOKEN_SECRET }} continue-on-error: true From 076591f40328018fa7e36d36454b7268886472c3 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Wed, 19 Feb 2020 11:48:08 -0800 Subject: [PATCH 058/226] fix(fcm): Passing params as keyword arguments to googleapiclient (#414) * fix(fcm): Passing batch URI as a keyword arg to googleapiclient * Adding test case --- firebase_admin/messaging.py | 3 ++- tests/test_messaging.py | 19 +++++++++++++++++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index 788875048..e4e223091 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -372,7 +372,8 @@ def batch_callback(_, response, error): send_response = SendResponse(response, exception) responses.append(send_response) - batch = http.BatchHttpRequest(batch_callback, _MessagingService.FCM_BATCH_URL) + batch = http.BatchHttpRequest( + callback=batch_callback, batch_uri=_MessagingService.FCM_BATCH_URL) for message in messages: body = json.dumps(self._message_data(message, dry_run)) req = http.HttpRequest( diff --git a/tests/test_messaging.py b/tests/test_messaging.py index f2ef47cf8..6e776cc5f 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -17,7 +17,8 @@ import json import numbers -from googleapiclient.http import HttpMockSequence +from googleapiclient import http +from googleapiclient import _helpers import pytest import firebase_admin @@ -1810,7 +1811,7 @@ def _instrument_batch_messaging_service(self, app=None, status=200, payload=''): content_type = 'multipart/mixed; boundary=boundary' else: content_type = 'application/json' - fcm_service._transport = HttpMockSequence([ + fcm_service._transport = http.HttpMockSequence([ ({'status': str(status), 'content-type': content_type}, payload), ]) return fcm_service @@ -1867,6 +1868,20 @@ def test_send_all(self): assert all([r.success for r in batch_response.responses]) assert not any([r.exception for r in batch_response.responses]) + def test_send_all_with_positional_param_enforcement(self): + payload = json.dumps({'name': 'message-id'}) + _ = self._instrument_batch_messaging_service( + payload=self._batch_payload([(200, payload), (200, payload)])) + msg = messaging.Message(topic='foo') + + enforcement = _helpers.positional_parameters_enforcement + _helpers.positional_parameters_enforcement = _helpers.POSITIONAL_EXCEPTION + try: + batch_response = messaging.send_all([msg, msg], dry_run=True) + assert batch_response.success_count == 2 + finally: + _helpers.positional_parameters_enforcement = enforcement + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_all_detailed_error(self, status): success_payload = json.dumps({'name': 'message-id'}) From 2482f33500ac900492a4c1f087b650f63d09045b Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Mon, 24 Feb 2020 11:25:22 -0800 Subject: [PATCH 059/226] [chore] Release 4.0.0 (#415) --- .github/workflows/release.yml | 2 +- firebase_admin/__about__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index f8a9b5de8..6d626eef2 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -129,7 +129,7 @@ jobs: - name: Post to Twitter if: success() && contains(github.event.pull_request.labels.*.name, 'release:tweet') - uses: ethomson/send-tweet-action@288f9339e0412e3038dce350e0da5ecdf12133a6 + uses: firebase/firebase-admin-node/.github/actions/send-tweet@master with: status: > ${{ steps.preflight.outputs.version }} of @Firebase Admin Python SDK is avaialble. diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index d44e3ccb5..c1bc469bf 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '3.2.1' +__version__ = '4.0.0' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 6a26c152c066719d88d28ab2d1404d731fec0923 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Tue, 25 Feb 2020 14:40:58 -0800 Subject: [PATCH 060/226] chore: Cleaning up scripts used in the old release process (#416) --- .github/workflows/ci.yml | 2 +- .gitignore | 4 +- CONTRIBUTING.md | 6 +- scripts/bash_utils.sh | 25 ------- scripts/prepare_release.sh | 140 ------------------------------------- scripts/verify_release.sh | 45 ------------ 6 files changed, 6 insertions(+), 216 deletions(-) delete mode 100644 scripts/bash_utils.sh delete mode 100755 scripts/prepare_release.sh delete mode 100755 scripts/verify_release.sh diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5a952418c..61d3861bd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,6 +1,6 @@ name: Continuous Integration -on: push +on: pull_request jobs: build: diff --git a/.gitignore b/.gitignore index 4880bc525..79d2d5ff3 100644 --- a/.gitignore +++ b/.gitignore @@ -7,8 +7,8 @@ build/ dist/ *~ -scripts/cert.json -scripts/apikey.txt +cert.json +apikey.txt htmlcov/ .pytest_cache/ .vscode/ diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 7b521ec99..80a607a8d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -170,9 +170,9 @@ following credentials from the project: 1. *Service account certificate*: This can be downloaded as a JSON file from the "Settings > Service Accounts" tab of the Firebase console. Copy the - file into the repo so it's available at `scripts/cert.json`. + file into the repo so it's available at `cert.json`. 2. *Web API key*: This is displayed in the "Settings > General" tab of the - console. Copy it and save to a new text file at `scripts/apikey.txt`. + console. Copy it and save to a new text file at `apikey.txt`. Then set up your Firebase/GCP project as follows: @@ -202,7 +202,7 @@ Then set up your Firebase/GCP project as follows: Now you can invoke the integration test suite as follows: ``` -pytest integration/ --cert scripts/cert.json --apikey scripts/apikey.txt +pytest integration/ --cert cert.json --apikey apikey.txt ``` ### Emulator-based Integration Testing diff --git a/scripts/bash_utils.sh b/scripts/bash_utils.sh deleted file mode 100644 index 628068fb7..000000000 --- a/scripts/bash_utils.sh +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright 2017 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -#!/bin/bash - -function parseVersion { - if [[ ! "$1" =~ ^([0-9]*)\.([0-9]*)\.([0-9]*)$ ]]; then - return 1 - fi - MAJOR_VERSION=$(echo "$1" | sed -e 's/^\([0-9]*\)\.\([0-9]*\)\.\([0-9]*\)$/\1/') - MINOR_VERSION=$(echo "$1" | sed -e 's/^\([0-9]*\)\.\([0-9]*\)\.\([0-9]*\)$/\2/') - PATCH_VERSION=$(echo "$1" | sed -e 's/^\([0-9]*\)\.\([0-9]*\)\.\([0-9]*\)$/\3/') - return 0 -} diff --git a/scripts/prepare_release.sh b/scripts/prepare_release.sh deleted file mode 100755 index aa55dae92..000000000 --- a/scripts/prepare_release.sh +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright 2017 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -#!/bin/bash - -source bash_utils.sh - -function isNewerVersion { - parseVersion "$1" - ARG_MAJOR=$MAJOR_VERSION - ARG_MINOR=$MINOR_VERSION - ARG_PATCH=$PATCH_VERSION - - parseVersion "$2" - if [ "$ARG_MAJOR" -ne "$MAJOR_VERSION" ]; then - if [ "$ARG_MAJOR" -lt "$MAJOR_VERSION" ]; then return 1; else return 0; fi; - fi - if [ "$ARG_MINOR" -ne "$MINOR_VERSION" ]; then - if [ "$ARG_MINOR" -lt "$MINOR_VERSION" ]; then return 1; else return 0; fi; - fi - if [ "$ARG_PATCH" -ne "$PATCH_VERSION" ]; then - if [ "$ARG_PATCH" -lt "$PATCH_VERSION" ]; then return 1; else return 0; fi; - fi - # The build numbers are equal - return 1 -} - -set -e - -if [ -z "$1" ]; then - echo "[ERROR] No version number provided." - echo "[INFO] Usage: ./prepare_release.sh " - exit 1 -fi - - -############################# -# VALIDATE VERSION NUMBER # -############################# - -VERSION="$1" -if ! parseVersion "$VERSION"; then - echo "[ERROR] Illegal version number provided. Version number must match semver." - exit 1 -fi - -CUR_VERSION=$(grep "^__version__ =" ../firebase_admin/__about__.py | awk '{print $3}' | sed "s/'//g") -if [ -z "$CUR_VERSION" ]; then - echo "[ERROR] Failed to find the current version. Check firebase_admin/__about__.py for version declaration." - exit 1 -fi -if ! parseVersion "$CUR_VERSION"; then - echo "[ERROR] Illegal current version number. Version number must match semver." - exit 1 -fi - -if ! isNewerVersion "$VERSION" "$CUR_VERSION"; then - echo "[ERROR] Illegal version number provided. Version $VERSION <= $CUR_VERSION" - exit 1 -fi - - -############################# -# VALIDATE TEST RESOURCES # -############################# - -if [[ ! -e "cert.json" ]]; then - echo "[ERROR] cert.json file is required to run integration tests." - exit 1 -fi - -if [[ ! -e "apikey.txt" ]]; then - echo "[ERROR] apikey.txt file is required to run integration tests." - exit 1 -fi - - -################### -# VALIDATE REPO # -################### - -# Ensure the checked out branch is master -CHECKED_OUT_BRANCH="$(git branch | grep "*" | awk -F ' ' '{print $2}')" -if [[ $CHECKED_OUT_BRANCH != "master" ]]; then - read -p "[WARN] You are on the '${CHECKED_OUT_BRANCH}' branch, not 'master'. Continue? (y/N) " CONTINUE - echo - - if ! [[ $CONTINUE == "y" || $CONTINUE == "Y" ]]; then - echo "[INFO] You chose not to continue." - exit 1 - fi -fi - -# Ensure the branch does not have local changes -if [[ $(git status --porcelain) ]]; then - read -p "[WARN] Local changes exist in the repo. Continue? (y/N) " CONTINUE - echo - - if ! [[ $CONTINUE == "y" || $CONTINUE == "Y" ]]; then - echo "[INFO] You chose not to continue." - exit 1 - fi -fi - - -#################### -# UPDATE VERSION # -#################### - -HOST=$(uname) -echo "[INFO] Updating __about__.py" -if [ $HOST == "Darwin" ]; then - sed -i "" -e "s/__version__ = '$CUR_VERSION'/__version__ = '$VERSION'/" "../firebase_admin/__about__.py" -else - sed -i -e "s/__version__ = '$CUR_VERSION'/__version__ = '$VERSION'/" "../firebase_admin/__about__.py" -fi - - -################## -# LAUNCH TESTS # -################## - -echo "[INFO] Running unit tests" -pytest ../tests - -echo "[INFO] Running integration tests" -pytest ../integration --cert cert.json --apikey apikey.txt - -echo "[INFO] This repo has been prepared for a release. Create a branch and commit the changes." diff --git a/scripts/verify_release.sh b/scripts/verify_release.sh deleted file mode 100755 index f4edd25de..000000000 --- a/scripts/verify_release.sh +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2017 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -#!/bin/bash - -source bash_utils.sh - -if [ -z "$1" ]; then - echo "[ERROR] No version number provided." - echo "[INFO] Usage: ./verify_release.sh " - exit 1 -fi - -VERSION="$1" -if ! parseVersion "$VERSION"; then - echo "[ERROR] Illegal version number provided. Version number must match semver." - exit 1 -fi - -mkdir sandbox -virtualenv sandbox -source sandbox/bin/activate -pip install firebase_admin -INSTALLED_VERSION=`python -c 'import firebase_admin; print firebase_admin.__version__'` -echo "[INFO] Installed firebase_admin version $INSTALLED_VERSION" -deactivate -rm -rf sandbox - -if [[ "$VERSION" == "$INSTALLED_VERSION" ]]; then - echo "[INFO] Release verified successfully" -else - echo "[ERROR] Installed version did not match the release version." - exit 1 -fi From 69c940cadfabfeafba384e8954aa6c8c1e17082a Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Thu, 27 Feb 2020 13:48:35 -0800 Subject: [PATCH 061/226] fix(fcm): Updated topic management error format (#417) * fix(fcm): Updated topic management error format * Better default error messages * Removed redundant variable initializer --- firebase_admin/messaging.py | 9 ++++++--- tests/test_messaging.py | 4 ++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index e4e223091..217cf0a56 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -453,9 +453,12 @@ def _handle_iid_error(self, error): except ValueError: pass - # IID error response format: {"error": "some error message"} - msg = data.get('error') - if not msg: + # IID error response format: {"error": "ErrorCode"} + code = data.get('error') + msg = None + if code: + msg = 'Error while calling the IID service: {0}'.format(code) + else: msg = 'Unexpected HTTP response with status: {0}; body: {1}'.format( error.response.status_code, error.response.content.decode()) diff --git a/tests/test_messaging.py b/tests/test_messaging.py index 6e776cc5f..6333aad46 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -2286,7 +2286,7 @@ def test_subscribe_to_topic_error(self, status, exc_type): status=status, payload=self._DEFAULT_ERROR_RESPONSE) with pytest.raises(exc_type) as excinfo: messaging.subscribe_to_topic('foo', 'test-topic') - assert str(excinfo.value) == 'error_reason' + assert str(excinfo.value) == 'Error while calling the IID service: error_reason' assert len(recorder) == 1 assert recorder[0].method == 'POST' assert recorder[0].url == self._get_url('iid/v1:batchAdd') @@ -2318,7 +2318,7 @@ def test_unsubscribe_from_topic_error(self, status, exc_type): status=status, payload=self._DEFAULT_ERROR_RESPONSE) with pytest.raises(exc_type) as excinfo: messaging.unsubscribe_from_topic('foo', 'test-topic') - assert str(excinfo.value) == 'error_reason' + assert str(excinfo.value) == 'Error while calling the IID service: error_reason' assert len(recorder) == 1 assert recorder[0].method == 'POST' assert recorder[0].url == self._get_url('iid/v1:batchRemove') From 9805758ef936b882f12f08cd78cc2589985235b7 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Fri, 28 Feb 2020 10:24:37 -0800 Subject: [PATCH 062/226] fix(rtdb): Fixed a bug in the Reference.listen() API (#418) --- firebase_admin/db.py | 12 +++++++++--- tests/test_db.py | 11 +++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/firebase_admin/db.py b/firebase_admin/db.py index b82a327ed..d42370317 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -373,8 +373,7 @@ def listen(self, callback): Raises: FirebaseError: If an error occurs while starting the initial HTTP connection. """ - session = _sseclient.KeepAuthSession(self._client.credential) - return self._listen_with_session(callback, session) + return self._listen_with_session(callback) def transaction(self, transaction_update): """Atomically modifies the data at this location. @@ -463,8 +462,11 @@ def order_by_value(self): def _add_suffix(self, suffix='.json'): return self._pathurl + suffix - def _listen_with_session(self, callback, session): + def _listen_with_session(self, callback, session=None): url = self._client.base_url + self._add_suffix() + if not session: + session = self._client.create_listener_session() + try: sse = _sseclient.SSEClient(url, session) return ListenerRegistration(callback, sse) @@ -907,6 +909,7 @@ def __init__(self, credential, base_url, timeout, params=None): super().__init__( credential=credential, base_url=base_url, timeout=timeout, headers={'User-Agent': _USER_AGENT}) + self.credential = credential self.params = params if params else {} def request(self, method, url, **kwargs): @@ -941,6 +944,9 @@ def request(self, method, url, **kwargs): except requests.exceptions.RequestException as error: raise _Client.handle_rtdb_error(error) + def create_listener_session(self): + return _sseclient.KeepAuthSession(self.credential) + @classmethod def handle_rtdb_error(cls, error): """Converts an error encountered while calling RTDB into a FirebaseError.""" diff --git a/tests/test_db.py b/tests/test_db.py index 1743347c5..2989fc030 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -551,6 +551,17 @@ def callback(_): finally: testutils.cleanup_apps() + def test_listener_session(self): + firebase_admin.initialize_app(testutils.MockCredential(), { + 'databaseURL' : 'https://test.firebaseio.com', + }) + try: + ref = db.reference() + session = ref._client.create_listener_session() + assert isinstance(session, _sseclient.KeepAuthSession) + finally: + testutils.cleanup_apps() + def test_single_event(self): self.events = [] def callback(event): From af1b4565bd9f9a33b6e236dd4a915f0bf45c8087 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Mon, 30 Mar 2020 11:49:08 -0700 Subject: [PATCH 063/226] [chore] Release 4.0.1 (#434) --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index c1bc469bf..d9a27bd92 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '4.0.0' +__version__ = '4.0.1' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 43e246de940cb7ff17ecd14ae9edbc2352a35427 Mon Sep 17 00:00:00 2001 From: ifielker Date: Mon, 20 Apr 2020 14:00:25 -0400 Subject: [PATCH 064/226] feat(ml): Adding Firebase ML management APIs (#447) * Introduced the exceptions module (#296) * Added the exceptions module * Cleaned up the error handling logic; Added tests * Updated docs; Fixed some typos * Migrating FCM Send APIs to the New Exceptions (#297) * Migrated FCM send APIs to the new error handling regime * Moved error parsing logic to _utils * Refactored OP error handling code * Fixing a broken test * Added utils for handling googleapiclient errors * Added tests for new error handling logic * Updated public API docs * Fixing test for python3 * Cleaning up the error code lookup code * Cleaning up the error parsing APIs * Cleaned up error parsing logic; Updated docs * Migrated remaining messaging APIs to new error types (#298) * Migrated FCM send APIs to the new error handling regime * Moved error parsing logic to _utils * Refactored OP error handling code * Fixing a broken test * Added utils for handling googleapiclient errors * Added tests for new error handling logic * Updated public API docs * Fixing test for python3 * Cleaning up the error code lookup code * Cleaning up the error parsing APIs * Cleaned up error parsing logic; Updated docs * Migrated the FCM IID APIs to the new error types * Introducing TokenSignError to represent custom token creation errors (#302) * Migrated FCM send APIs to the new error handling regime * Moved error parsing logic to _utils * Refactored OP error handling code * Fixing a broken test * Added utils for handling googleapiclient errors * Added tests for new error handling logic * Updated public API docs * Fixing test for python3 * Cleaning up the error code lookup code * Cleaning up the error parsing APIs * Cleaned up error parsing logic; Updated docs * Migrated the FCM IID APIs to the new error types * Migrated custom token API to new error types * Raising FirebaseError from create_session_cookie() API (#306) * Migrated FCM send APIs to the new error handling regime * Moved error parsing logic to _utils * Refactored OP error handling code * Fixing a broken test * Added utils for handling googleapiclient errors * Added tests for new error handling logic * Updated public API docs * Fixing test for python3 * Cleaning up the error code lookup code * Cleaning up the error parsing APIs * Cleaned up error parsing logic; Updated docs * Migrated the FCM IID APIs to the new error types * Migrated custom token API to new error types * Migrated create cookie API to new error types * Improved error message computation * Refactored the shared error handling code * Fixing lint errors * Renamed variable for clarity * Introducing UserNotFoundError type (#309) * Added UserNotFoundError type * Fixed some lint errors * Some formatting updates * Updated docs and tests * New error handling support in create/update/delete user APIs (#311) * New error handling support in create/update/delete user APIs * Fixing some lint errors * Error handling improvements in email action link APIs (#312) * New error handling support in create/update/delete user APIs * Fixing some lint errors * Error handling update in email action link APIs * Project management API migrated to new error types (#314) * Error handling updated for remaining user_mgt APIs (#315) * Error handling updated for remaining user_mgt APIs * Removed unused constants * Migrated token verification APIs to new exception types (#317) * Migrated token verification APIs to new error types * Removed old AuthError type * Added new exception types for revoked tokens * Migrated the db module to the new exception types (#318) * Migrating db module to new exception types * Error handling for transactions * Updated integration tests * Restoring the old txn abort behavior * Updated error type in snippet * Added comment * Adding a few overlooked error types (#319) * Adding some missing error types * Updated documentation * Removing the ability to delete user properties by passing None (#320) * Adding beginning of _MLKitService (#323) * Adding beginning of _MLKitService * Added License and Docstring * Firebase ML Kit Get Model API implementation (#326) * added GetModel * Added tests for get_model * Firebase ML Kit Delete Model API implementation (#327) * implement delete model * Firebase ML Kit List Models API implementation (#331) * implemented list models plus tests * Implementation of Model, ModelFormat, TFLiteModelSource and subclasses (#335) * Implementation of Model, ModelFormat, ModelSource and subclasses * Firebase ML Kit Create Model API implementation (#337) * create model plus long running operation handling * Model.wait_for_unlocked * Firebase ML Kit Update Model API implementation (#343) * Firebase ML Kit Create Model API implementation * Firebase ML Kit Publish and Unpublish Implementation (#345) * Firebase ML Kit Publish and Unpublish Implementation * Firebase ML Kit TFLiteGCSModelSource.from_tflite_model implementation and conversion helpers (#346) * Firebase ML Kit TFLiteGCSModelSource.from_tflite_model implementation * support for tensorflow lite conversion helpers (version 1.x) * Quick pass at filling in missing docstrings (#367) * Quick pass at filling in missing docstrings * More punctuation * Modify Operation Handling to not require a name for Done Operations (#371) * Firebase ML Kit Modify Operation Handling to not require a name for Done Operations * Adding support for TensorFlow 2.x * rename from mlkit to ml (#373) * Adding File naming capability to from_saved_model and from_keras_model. (#375) adding File naming capability for ModelSource * Firebase ML Modify Operation Handling Code to match rpc codes not html codes (#390) * Firebase ML Modify Operation Handling Code to match actual codes * apply database fix too * Mlkit fix date handling2 (#391) * Fix create/update date handling * Skip unrelated failing tests (until sync) * Firebase Ml Fix upload file naming (#392) * Fix File Naming * Integration tests for Firebase ML (#394) * Integration tests for Firebase ML * Fixing lint errors for Py3 (#401) * Fixing lint errors for Py3 * Removed dependency on six * Fixing a couple of merge errors * Modifying operation handling to support backend changes (#423) * modifying operation handling to support backend changes * Firebase ML Changing service endpoint (#421) * Mlkit add headers (#445) * add Headers * fixed test (#448) * Adding tensorflow and keras so we don't skip tests (#449) * Adding tensorflow and keras so we don't skip tests * Add additional instructions for integration tests for ml Co-authored-by: Hiranya Jayathilaka Co-authored-by: Kevin Cheung --- .github/workflows/release.yml | 2 + CONTRIBUTING.md | 9 +- firebase_admin/_utils.py | 45 ++ firebase_admin/ml.py | 938 ++++++++++++++++++++++++++ integration/test_ml.py | 373 +++++++++++ tests/data/invalid_model.tflite | 1 + tests/data/model1.tflite | Bin 0 -> 736 bytes tests/test_ml.py | 1113 +++++++++++++++++++++++++++++++ 8 files changed, 2478 insertions(+), 3 deletions(-) create mode 100644 firebase_admin/ml.py create mode 100644 integration/test_ml.py create mode 100644 tests/data/invalid_model.tflite create mode 100644 tests/data/model1.tflite create mode 100644 tests/test_ml.py diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 6d626eef2..64ee304ce 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -54,6 +54,8 @@ jobs: python -m pip install --upgrade pip pip install -r requirements.txt pip install setuptools wheel + pip install tensorflow + pip install keras - name: Run unit tests run: pytest diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 80a607a8d..f6d09b093 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -183,14 +183,17 @@ Then set up your Firebase/GCP project as follows: Firebase Console. Select the "Sign-in method" tab, and enable the "Email/Password" sign-in method, including the Email link (passwordless sign-in) option. - -3. Enable the IAM API: Go to the +3. Enable the Firebase ML API: Go to the + [Google Developers Console]( + https://console.developers.google.com/apis/api/firebaseml.googleapis.com/overview) + and make sure your project is selected. If the API is not already enabled, click Enable. +4. Enable the IAM API: Go to the [Google Cloud Platform Console](https://console.cloud.google.com) and make sure your Firebase/GCP project is selected. Select "APIs & Services > Dashboard" from the main menu, and click the "ENABLE APIS AND SERVICES" button. Search for and enable the "Identity and Access Management (IAM) API". -4. Grant your service account the 'Firebase Authentication Admin' role. This is +5. Grant your service account the 'Firebase Authentication Admin' role. This is required to ensure that exported user records contain the password hashes of the user accounts: 1. Go to [Google Cloud Platform Console / IAM & admin](https://console.cloud.google.com/iam-admin). diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index 2c4cec868..a5fc8d022 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -59,6 +59,26 @@ } +# See https://github.com/googleapis/googleapis/blob/master/google/rpc/code.proto +_RPC_CODE_TO_ERROR_CODE = { + 1: exceptions.CANCELLED, + 2: exceptions.UNKNOWN, + 3: exceptions.INVALID_ARGUMENT, + 4: exceptions.DEADLINE_EXCEEDED, + 5: exceptions.NOT_FOUND, + 6: exceptions.ALREADY_EXISTS, + 7: exceptions.PERMISSION_DENIED, + 8: exceptions.RESOURCE_EXHAUSTED, + 9: exceptions.FAILED_PRECONDITION, + 10: exceptions.ABORTED, + 11: exceptions.OUT_OF_RANGE, + 13: exceptions.INTERNAL, + 14: exceptions.UNAVAILABLE, + 15: exceptions.DATA_LOSS, + 16: exceptions.UNAUTHENTICATED, +} + + def _get_initialized_app(app): """Returns a reference to an initialized App instance.""" if app is None: @@ -75,6 +95,7 @@ def _get_initialized_app(app): ' firebase_admin.App, but given "{0}".'.format(type(app))) + def get_app_service(app, name, initializer): app = _get_initialized_app(app) return app._get_service(name, initializer) # pylint: disable=protected-access @@ -108,6 +129,27 @@ def handle_platform_error_from_requests(error, handle_func=None): return exc if exc else _handle_func_requests(error, message, error_dict) +def handle_operation_error(error): + """Constructs a ``FirebaseError`` from the given operation error. + + Args: + error: An error returned by a long running operation. + + Returns: + FirebaseError: A ``FirebaseError`` that can be raised to the user code. + """ + if not isinstance(error, dict): + return exceptions.UnknownError( + message='Unknown error while making a remote service call: {0}'.format(error), + cause=error) + + rpc_code = error.get('code') + message = error.get('message') + error_code = _rpc_code_to_error_code(rpc_code) + err_type = _error_code_to_exception_type(error_code) + return err_type(message=message) + + def _handle_func_requests(error, message, error_dict): """Constructs a ``FirebaseError`` from the given GCP error. @@ -264,6 +306,9 @@ def _http_status_to_error_code(status): """Maps an HTTP status to a platform error code.""" return _HTTP_STATUS_TO_ERROR_CODE.get(status, exceptions.UNKNOWN) +def _rpc_code_to_error_code(rpc_code): + """Maps an RPC code to a platform error code.""" + return _RPC_CODE_TO_ERROR_CODE.get(rpc_code, exceptions.UNKNOWN) def _error_code_to_exception_type(code): """Maps a platform error code to an exception type.""" diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py new file mode 100644 index 000000000..db1657839 --- /dev/null +++ b/firebase_admin/ml.py @@ -0,0 +1,938 @@ +# Copyright 2019 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Firebase ML module. + +This module contains functions for creating, updating, getting, listing, +deleting, publishing and unpublishing Firebase ML models. +""" + + +import datetime +import re +import time +import os +from urllib import parse + +import requests + +import firebase_admin +from firebase_admin import _http_client +from firebase_admin import _utils +from firebase_admin import exceptions + +# pylint: disable=import-error,no-name-in-module +try: + from firebase_admin import storage + _GCS_ENABLED = True +except ImportError: + _GCS_ENABLED = False + +# pylint: disable=import-error,no-name-in-module +try: + import tensorflow as tf + _TF_ENABLED = True +except ImportError: + _TF_ENABLED = False + +_ML_ATTRIBUTE = '_ml' +_MAX_PAGE_SIZE = 100 +_MODEL_ID_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') +_DISPLAY_NAME_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') +_TAG_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') +_GCS_TFLITE_URI_PATTERN = re.compile( + r'^gs://(?P[a-z0-9_.-]{3,63})/(?P.+)$') +_RESOURCE_NAME_PATTERN = re.compile( + r'^projects/(?P[a-z0-9-]{6,30})/models/(?P[A-Za-z0-9_-]{1,60})$') +_OPERATION_NAME_PATTERN = re.compile( + r'^projects/(?P[a-z0-9-]{6,30})/operations/[^/]+$') + + +def _get_ml_service(app): + """ Returns an _MLService instance for an App. + + Args: + app: A Firebase App instance (or None to use the default App). + + Returns: + _MLService: An _MLService for the specified App instance. + + Raises: + ValueError: If the app argument is invalid. + """ + return _utils.get_app_service(app, _ML_ATTRIBUTE, _MLService) + + +def create_model(model, app=None): + """Creates a model in Firebase ML. + + Args: + model: An ml.Model to create. + app: A Firebase app instance (or None to use the default app). + + Returns: + Model: The model that was created in Firebase ML. + """ + ml_service = _get_ml_service(app) + return Model.from_dict(ml_service.create_model(model), app=app) + + +def update_model(model, app=None): + """Updates a model in Firebase ML. + + Args: + model: The ml.Model to update. + app: A Firebase app instance (or None to use the default app). + + Returns: + Model: The updated model. + """ + ml_service = _get_ml_service(app) + return Model.from_dict(ml_service.update_model(model), app=app) + + +def publish_model(model_id, app=None): + """Publishes a model in Firebase ML. + + Args: + model_id: The id of the model to publish. + app: A Firebase app instance (or None to use the default app). + + Returns: + Model: The published model. + """ + ml_service = _get_ml_service(app) + return Model.from_dict(ml_service.set_published(model_id, publish=True), app=app) + + +def unpublish_model(model_id, app=None): + """Unpublishes a model in Firebase ML. + + Args: + model_id: The id of the model to unpublish. + app: A Firebase app instance (or None to use the default app). + + Returns: + Model: The unpublished model. + """ + ml_service = _get_ml_service(app) + return Model.from_dict(ml_service.set_published(model_id, publish=False), app=app) + + +def get_model(model_id, app=None): + """Gets a model from Firebase ML. + + Args: + model_id: The id of the model to get. + app: A Firebase app instance (or None to use the default app). + + Returns: + Model: The requested model. + """ + ml_service = _get_ml_service(app) + return Model.from_dict(ml_service.get_model(model_id), app=app) + + +def list_models(list_filter=None, page_size=None, page_token=None, app=None): + """Lists models from Firebase ML. + + Args: + list_filter: a list filter string such as "tags:'tag_1'". None will return all models. + page_size: A number between 1 and 100 inclusive that specifies the maximum + number of models to return per page. None for default. + page_token: A next page token returned from a previous page of results. None + for first page of results. + app: A Firebase app instance (or None to use the default app). + + Returns: + ListModelsPage: A (filtered) list of models. + """ + ml_service = _get_ml_service(app) + return ListModelsPage( + ml_service.list_models, list_filter, page_size, page_token, app=app) + + +def delete_model(model_id, app=None): + """Deletes a model from Firebase ML. + + Args: + model_id: The id of the model you wish to delete. + app: A Firebase app instance (or None to use the default app). + """ + ml_service = _get_ml_service(app) + ml_service.delete_model(model_id) + + +class Model: + """A Firebase ML Model object. + + Args: + display_name: The display name of your model - used to identify your model in code. + tags: Optional list of strings associated with your model. Can be used in list queries. + model_format: A subclass of ModelFormat. (e.g. TFLiteFormat) Specifies the model details. + """ + def __init__(self, display_name=None, tags=None, model_format=None): + self._app = None # Only needed for wait_for_unlo + self._data = {} + self._model_format = None + + if display_name is not None: + self.display_name = display_name + if tags is not None: + self.tags = tags + if model_format is not None: + self.model_format = model_format + + @classmethod + def from_dict(cls, data, app=None): + """Create an instance of the object from a dict.""" + data_copy = dict(data) + tflite_format = None + tflite_format_data = data_copy.pop('tfliteModel', None) + data_copy.pop('@type', None) # Returned by Operations. (Not needed) + if tflite_format_data: + tflite_format = TFLiteFormat.from_dict(tflite_format_data) + model = Model(model_format=tflite_format) + model._data = data_copy # pylint: disable=protected-access + model._app = app # pylint: disable=protected-access + return model + + def _update_from_dict(self, data): + copy = Model.from_dict(data) + self.model_format = copy.model_format + self._data = copy._data # pylint: disable=protected-access + + def __eq__(self, other): + if isinstance(other, self.__class__): + # pylint: disable=protected-access + return self._data == other._data and self._model_format == other._model_format + return False + + def __ne__(self, other): + return not self.__eq__(other) + + @property + def model_id(self): + """The model's ID, unique to the project.""" + if not self._data.get('name'): + return None + _, model_id = _validate_and_parse_name(self._data.get('name')) + return model_id + + @property + def display_name(self): + """The model's display name, used to refer to the model in code and in + the Firebase console.""" + return self._data.get('displayName') + + @display_name.setter + def display_name(self, display_name): + self._data['displayName'] = _validate_display_name(display_name) + return self + + @staticmethod + def _convert_to_millis(date_string): + if not date_string: + return None + format_str = '%Y-%m-%dT%H:%M:%S.%fZ' + epoch = datetime.datetime.utcfromtimestamp(0) + datetime_object = datetime.datetime.strptime(date_string, format_str) + millis = int((datetime_object - epoch).total_seconds() * 1000) + return millis + + @property + def create_time(self): + """The time the model was created.""" + return Model._convert_to_millis(self._data.get('createTime', None)) + + @property + def update_time(self): + """The time the model was last updated.""" + return Model._convert_to_millis(self._data.get('updateTime', None)) + + @property + def validation_error(self): + """Validation error message.""" + return self._data.get('state', {}).get('validationError', {}).get('message') + + @property + def published(self): + """True if the model is published and available for clients to + download.""" + return bool(self._data.get('state', {}).get('published')) + + @property + def etag(self): + """The entity tag (ETag) of the model resource.""" + return self._data.get('etag') + + @property + def model_hash(self): + """SHA256 hash of the model binary.""" + return self._data.get('modelHash') + + @property + def tags(self): + """Tag strings, used for filtering query results.""" + return self._data.get('tags') + + @tags.setter + def tags(self, tags): + self._data['tags'] = _validate_tags(tags) + return self + + @property + def locked(self): + """True if the Model object is locked by an active operation.""" + return bool(self._data.get('activeOperations') and + len(self._data.get('activeOperations')) > 0) + + def wait_for_unlocked(self, max_time_seconds=None): + """Waits for the model to be unlocked. (All active operations complete) + + Args: + max_time_seconds: The maximum number of seconds to wait for the model to unlock. + (None for no limit) + + Raises: + exceptions.DeadlineExceeded: If max_time_seconds passed and the model is still locked. + """ + if not self.locked: + return + ml_service = _get_ml_service(self._app) + op_name = self._data.get('activeOperations')[0].get('name') + model_dict = ml_service.handle_operation( + ml_service.get_operation(op_name), + wait_for_operation=True, + max_time_seconds=max_time_seconds) + self._update_from_dict(model_dict) + + @property + def model_format(self): + """The model's ``ModelFormat`` object, which represents the model's + format and storage location.""" + return self._model_format + + @model_format.setter + def model_format(self, model_format): + if model_format is not None: + _validate_model_format(model_format) + self._model_format = model_format #Can be None + return self + + def as_dict(self, for_upload=False): + """Returns a serializable representation of the object.""" + copy = dict(self._data) + if self._model_format: + copy.update(self._model_format.as_dict(for_upload=for_upload)) + return copy + + +class ModelFormat: + """Abstract base class representing a Model Format such as TFLite.""" + def as_dict(self, for_upload=False): + """Returns a serializable representation of the object.""" + raise NotImplementedError + + +class TFLiteFormat(ModelFormat): + """Model format representing a TFLite model. + + Args: + model_source: A TFLiteModelSource sub class. Specifies the details of the model source. + """ + def __init__(self, model_source=None): + self._data = {} + self._model_source = None + + if model_source is not None: + self.model_source = model_source + + @classmethod + def from_dict(cls, data): + """Create an instance of the object from a dict.""" + data_copy = dict(data) + model_source = None + gcs_tflite_uri = data_copy.pop('gcsTfliteUri', None) + if gcs_tflite_uri: + model_source = TFLiteGCSModelSource(gcs_tflite_uri=gcs_tflite_uri) + tflite_format = TFLiteFormat(model_source=model_source) + tflite_format._data = data_copy # pylint: disable=protected-access + return tflite_format + + + def __eq__(self, other): + if isinstance(other, self.__class__): + # pylint: disable=protected-access + return self._data == other._data and self._model_source == other._model_source + return False + + def __ne__(self, other): + return not self.__eq__(other) + + @property + def model_source(self): + """The TF Lite model's location.""" + return self._model_source + + @model_source.setter + def model_source(self, model_source): + if model_source is not None: + if not isinstance(model_source, TFLiteModelSource): + raise TypeError('Model source must be a TFLiteModelSource object.') + self._model_source = model_source # Can be None + + @property + def size_bytes(self): + """The size in bytes of the TF Lite model.""" + return self._data.get('sizeBytes') + + def as_dict(self, for_upload=False): + """Returns a serializable representation of the object.""" + copy = dict(self._data) + if self._model_source: + copy.update(self._model_source.as_dict(for_upload=for_upload)) + return {'tfliteModel': copy} + + +class TFLiteModelSource: + """Abstract base class representing a model source for TFLite format models.""" + def as_dict(self, for_upload=False): + """Returns a serializable representation of the object.""" + raise NotImplementedError + + +class _CloudStorageClient: + """Cloud Storage helper class""" + + GCS_URI = 'gs://{0}/{1}' + BLOB_NAME = 'Firebase/ML/Models/{0}' + + @staticmethod + def _assert_gcs_enabled(): + if not _GCS_ENABLED: + raise ImportError('Failed to import the Cloud Storage library for Python. Make sure ' + 'to install the "google-cloud-storage" module.') + + @staticmethod + def _parse_gcs_tflite_uri(uri): + # GCS Bucket naming rules are complex. The regex is not comprehensive. + # See https://cloud.google.com/storage/docs/naming for full details. + matcher = _GCS_TFLITE_URI_PATTERN.match(uri) + if not matcher: + raise ValueError('GCS TFLite URI format is invalid.') + return matcher.group('bucket_name'), matcher.group('blob_name') + + @staticmethod + def upload(bucket_name, model_file_name, app): + """Upload a model file to the specified Storage bucket.""" + _CloudStorageClient._assert_gcs_enabled() + + file_name = os.path.basename(model_file_name) + bucket = storage.bucket(bucket_name, app=app) + blob_name = _CloudStorageClient.BLOB_NAME.format(file_name) + blob = bucket.blob(blob_name) + blob.upload_from_filename(model_file_name) + return _CloudStorageClient.GCS_URI.format(bucket.name, blob_name) + + @staticmethod + def sign_uri(gcs_tflite_uri, app): + """Makes the gcs_tflite_uri readable for GET for 10 minutes via signed_uri.""" + _CloudStorageClient._assert_gcs_enabled() + bucket_name, blob_name = _CloudStorageClient._parse_gcs_tflite_uri(gcs_tflite_uri) + bucket = storage.bucket(bucket_name, app=app) + blob = bucket.blob(blob_name) + return blob.generate_signed_url( + version='v4', + expiration=datetime.timedelta(minutes=10), + method='GET' + ) + + +class TFLiteGCSModelSource(TFLiteModelSource): + """TFLite model source representing a tflite model file stored in GCS.""" + + _STORAGE_CLIENT = _CloudStorageClient() + + def __init__(self, gcs_tflite_uri, app=None): + self._app = app + self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri) + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self._gcs_tflite_uri == other._gcs_tflite_uri # pylint: disable=protected-access + return False + + def __ne__(self, other): + return not self.__eq__(other) + + @classmethod + def from_tflite_model_file(cls, model_file_name, bucket_name=None, app=None): + """Uploads the model file to an existing Google Cloud Storage bucket. + + Args: + model_file_name: The name of the model file. + bucket_name: The name of an existing bucket. None to use the default bucket configured + in the app. + app: A Firebase app instance (or None to use the default app). + + Returns: + TFLiteGCSModelSource: The source created from the model_file + + Raises: + ImportError: If the Cloud Storage Library has not been installed. + """ + gcs_uri = TFLiteGCSModelSource._STORAGE_CLIENT.upload(bucket_name, model_file_name, app) + return TFLiteGCSModelSource(gcs_tflite_uri=gcs_uri, app=app) + + @staticmethod + def _assert_tf_enabled(): + if not _TF_ENABLED: + raise ImportError('Failed to import the tensorflow library for Python. Make sure ' + 'to install the tensorflow module.') + if not tf.version.VERSION.startswith('1.') and not tf.version.VERSION.startswith('2.'): + raise ImportError('Expected tensorflow version 1.x or 2.x, but found {0}' + .format(tf.version.VERSION)) + + @staticmethod + def _tf_convert_from_saved_model(saved_model_dir): + # Same for both v1.x and v2.x + converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) + return converter.convert() + + @staticmethod + def _tf_convert_from_keras_model(keras_model): + """Converts the given Keras model into a TF Lite model.""" + # Version 1.x conversion function takes a model file. Version 2.x takes the model itself. + if tf.version.VERSION.startswith('1.'): + keras_file = 'firebase_keras_model.h5' + tf.keras.models.save_model(keras_model, keras_file) + converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file) + else: + converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) + + return converter.convert() + + @classmethod + def from_saved_model(cls, saved_model_dir, model_file_name='firebase_ml_model.tflite', + bucket_name=None, app=None): + """Creates a Tensor Flow Lite model from the saved model, and uploads the model to GCS. + + Args: + saved_model_dir: The saved model directory. + model_file_name: The name that the tflite model will be saved as in Cloud Storage. + bucket_name: The name of an existing bucket. None to use the default bucket configured + in the app. + app: Optional. A Firebase app instance (or None to use the default app) + + Returns: + TFLiteGCSModelSource: The source created from the saved_model_dir + + Raises: + ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed. + """ + TFLiteGCSModelSource._assert_tf_enabled() + tflite_model = TFLiteGCSModelSource._tf_convert_from_saved_model(saved_model_dir) + with open(model_file_name, 'wb') as model_file: + model_file.write(tflite_model) + return TFLiteGCSModelSource.from_tflite_model_file(model_file_name, bucket_name, app) + + @classmethod + def from_keras_model(cls, keras_model, model_file_name='firebase_ml_model.tflite', + bucket_name=None, app=None): + """Creates a Tensor Flow Lite model from the keras model, and uploads the model to GCS. + + Args: + keras_model: A tf.keras model. + model_file_name: The name that the tflite model will be saved as in Cloud Storage. + bucket_name: The name of an existing bucket. None to use the default bucket configured + in the app. + app: Optional. A Firebase app instance (or None to use the default app) + + Returns: + TFLiteGCSModelSource: The source created from the keras_model + + Raises: + ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed. + """ + TFLiteGCSModelSource._assert_tf_enabled() + tflite_model = TFLiteGCSModelSource._tf_convert_from_keras_model(keras_model) + with open(model_file_name, 'wb') as model_file: + model_file.write(tflite_model) + return TFLiteGCSModelSource.from_tflite_model_file(model_file_name, bucket_name, app) + + @property + def gcs_tflite_uri(self): + """URI of the model file in Cloud Storage.""" + return self._gcs_tflite_uri + + @gcs_tflite_uri.setter + def gcs_tflite_uri(self, gcs_tflite_uri): + self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri) + + def _get_signed_gcs_tflite_uri(self): + """Signs the GCS uri, so the model file can be uploaded to Firebase ML and verified.""" + return TFLiteGCSModelSource._STORAGE_CLIENT.sign_uri(self._gcs_tflite_uri, self._app) + + def as_dict(self, for_upload=False): + """Returns a serializable representation of the object.""" + if for_upload: + return {'gcsTfliteUri': self._get_signed_gcs_tflite_uri()} + + return {'gcsTfliteUri': self._gcs_tflite_uri} + + +class ListModelsPage: + """Represents a page of models in a firebase project. + + Provides methods for traversing the models included in this page, as well as + retrieving subsequent pages of models. The iterator returned by + ``iterate_all()`` can be used to iterate through all the models in the + Firebase project starting from this page. + """ + def __init__(self, list_models_func, list_filter, page_size, page_token, app): + self._list_models_func = list_models_func + self._list_filter = list_filter + self._page_size = page_size + self._page_token = page_token + self._app = app + self._list_response = list_models_func(list_filter, page_size, page_token) + + @property + def models(self): + """A list of Models from this page.""" + return [ + Model.from_dict(model, app=self._app) for model in self._list_response.get('models', []) + ] + + @property + def list_filter(self): + """The filter string used to filter the models.""" + return self._list_filter + + @property + def next_page_token(self): + """Token identifying the next page of results.""" + return self._list_response.get('nextPageToken', '') + + @property + def has_next_page(self): + """True if more pages are available.""" + return bool(self.next_page_token) + + def get_next_page(self): + """Retrieves the next page of models if available. + + Returns: + ListModelsPage: Next page of models, or None if this is the last page. + """ + if self.has_next_page: + return ListModelsPage( + self._list_models_func, + self._list_filter, + self._page_size, + self.next_page_token, + self._app) + return None + + def iterate_all(self): + """Retrieves an iterator for Models. + + Returned iterator will iterate through all the models in the Firebase + project starting from this page. The iterator will never buffer more than + one page of models in memory at a time. + + Returns: + iterator: An iterator of Model instances. + """ + return _ModelIterator(self) + + +class _ModelIterator: + """An iterator that allows iterating over models, one at a time. + + This implementation loads a page of models into memory, and iterates on them. + When the whole page has been traversed, it loads another page. This class + never keeps more than one page of entries in memory. + """ + def __init__(self, current_page): + if not isinstance(current_page, ListModelsPage): + raise TypeError('Current page must be a ListModelsPage') + self._current_page = current_page + self._index = 0 + + def next(self): + if self._index == len(self._current_page.models): + if self._current_page.has_next_page: + self._current_page = self._current_page.get_next_page() + self._index = 0 + if self._index < len(self._current_page.models): + result = self._current_page.models[self._index] + self._index += 1 + return result + raise StopIteration + + def __next__(self): + return self.next() + + def __iter__(self): + return self + + +def _validate_and_parse_name(name): + # The resource name is added automatically from API call responses. + # The only way it could be invalid is if someone tries to + # create a model from a dictionary manually and does it incorrectly. + matcher = _RESOURCE_NAME_PATTERN.match(name) + if not matcher: + raise ValueError('Model resource name format is invalid.') + return matcher.group('project_id'), matcher.group('model_id') + + +def _validate_model(model, update_mask=None): + if not isinstance(model, Model): + raise TypeError('Model must be an ml.Model.') + if update_mask is None and not model.display_name: + raise ValueError('Model must have a display name.') + + +def _validate_model_id(model_id): + if not _MODEL_ID_PATTERN.match(model_id): + raise ValueError('Model ID format is invalid.') + + +def _validate_operation_name(op_name): + if not _OPERATION_NAME_PATTERN.match(op_name): + raise ValueError('Operation name format is invalid.') + return op_name + + +def _validate_display_name(display_name): + if not _DISPLAY_NAME_PATTERN.match(display_name): + raise ValueError('Display name format is invalid.') + return display_name + + +def _validate_tags(tags): + if not isinstance(tags, list) or not \ + all(isinstance(tag, str) for tag in tags): + raise TypeError('Tags must be a list of strings.') + if not all(_TAG_PATTERN.match(tag) for tag in tags): + raise ValueError('Tag format is invalid.') + return tags + + +def _validate_gcs_tflite_uri(uri): + # GCS Bucket naming rules are complex. The regex is not comprehensive. + # See https://cloud.google.com/storage/docs/naming for full details. + if not _GCS_TFLITE_URI_PATTERN.match(uri): + raise ValueError('GCS TFLite URI format is invalid.') + return uri + + +def _validate_model_format(model_format): + if not isinstance(model_format, ModelFormat): + raise TypeError('Model format must be a ModelFormat object.') + return model_format + + +def _validate_list_filter(list_filter): + if list_filter is not None: + if not isinstance(list_filter, str): + raise TypeError('List filter must be a string or None.') + + +def _validate_page_size(page_size): + if page_size is not None: + if type(page_size) is not int: # pylint: disable=unidiomatic-typecheck + # Specifically type() to disallow boolean which is a subtype of int + raise TypeError('Page size must be a number or None.') + if page_size < 1 or page_size > _MAX_PAGE_SIZE: + raise ValueError('Page size must be a positive integer between ' + '1 and {0}'.format(_MAX_PAGE_SIZE)) + + +def _validate_page_token(page_token): + if page_token is not None: + if not isinstance(page_token, str): + raise TypeError('Page token must be a string or None.') + + +class _MLService: + """Firebase ML service.""" + + PROJECT_URL = 'https://firebaseml.googleapis.com/v1beta2/projects/{0}/' + OPERATION_URL = 'https://firebaseml.googleapis.com/v1beta2/' + POLL_EXPONENTIAL_BACKOFF_FACTOR = 1.5 + POLL_BASE_WAIT_TIME_SECONDS = 3 + + def __init__(self, app): + self._project_id = app.project_id + if not self._project_id: + raise ValueError( + 'Project ID is required to access ML service. Either set the ' + 'projectId option, or use service account credentials.') + self._project_url = _MLService.PROJECT_URL.format(self._project_id) + ml_headers = { + 'X-FIREBASE-CLIENT': 'fire-admin-python/{0}'.format(firebase_admin.__version__), + } + self._client = _http_client.JsonHttpClient( + credential=app.credential.get_credential(), + headers=ml_headers, + base_url=self._project_url) + self._operation_client = _http_client.JsonHttpClient( + credential=app.credential.get_credential(), + headers=ml_headers, + base_url=_MLService.OPERATION_URL) + + def get_operation(self, op_name): + _validate_operation_name(op_name) + try: + return self._operation_client.body('get', url=op_name) + except requests.exceptions.RequestException as error: + raise _utils.handle_platform_error_from_requests(error) + + def _exponential_backoff(self, current_attempt, stop_time): + """Sleeps for the appropriate amount of time. Or throws deadline exceeded.""" + delay_factor = pow(_MLService.POLL_EXPONENTIAL_BACKOFF_FACTOR, current_attempt) + wait_time_seconds = delay_factor * _MLService.POLL_BASE_WAIT_TIME_SECONDS + + if stop_time is not None: + max_seconds_left = (stop_time - datetime.datetime.now()).total_seconds() + if max_seconds_left < 1: # allow a bit of time for rpc + raise exceptions.DeadlineExceededError('Polling max time exceeded.') + wait_time_seconds = min(wait_time_seconds, max_seconds_left - 1) + time.sleep(wait_time_seconds) + + def handle_operation(self, operation, wait_for_operation=False, max_time_seconds=None): + """Handles long running operations. + + Args: + operation: The operation to handle. + wait_for_operation: Should we allow polling for the operation to complete. + If no polling is requested, a locked model will be returned instead. + max_time_seconds: The maximum seconds to try polling for operation complete. + (None for no limit) + + Returns: + dict: A dictionary of the returned model properties. + + Raises: + TypeError: if the operation is not a dictionary. + ValueError: If the operation is malformed. + UnknownError: If the server responds with an unexpected response. + err: If the operation exceeds polling attempts or stop_time + """ + if not isinstance(operation, dict): + raise TypeError('Operation must be a dictionary.') + + if operation.get('done'): + # Operations which are immediately done don't have an operation name + if operation.get('response'): + return operation.get('response') + if operation.get('error'): + raise _utils.handle_operation_error(operation.get('error')) + raise exceptions.UnknownError(message='Internal Error: Malformed Operation.') + + op_name = _validate_operation_name(operation.get('name')) + metadata = operation.get('metadata', {}) + metadata_type = metadata.get('@type', '') + if not metadata_type.endswith('ModelOperationMetadata'): + raise TypeError('Unknown type of operation metadata.') + _, model_id = _validate_and_parse_name(metadata.get('name')) + current_attempt = 0 + start_time = datetime.datetime.now() + stop_time = (None if max_time_seconds is None else + start_time + datetime.timedelta(seconds=max_time_seconds)) + while wait_for_operation and not operation.get('done'): + # We just got this operation. Wait before getting another + # so we don't exceed the GetOperation maximum request rate. + self._exponential_backoff(current_attempt, stop_time) + operation = self.get_operation(op_name) + current_attempt += 1 + + if operation.get('done'): + if operation.get('response'): + return operation.get('response') + if operation.get('error'): + raise _utils.handle_operation_error(operation.get('error')) + + # If the operation is not complete or timed out, return a (locked) model instead + return get_model(model_id).as_dict() + + + def create_model(self, model): + _validate_model(model) + try: + return self.handle_operation( + self._client.body('post', url='models', json=model.as_dict(for_upload=True))) + except requests.exceptions.RequestException as error: + raise _utils.handle_platform_error_from_requests(error) + + def update_model(self, model, update_mask=None): + _validate_model(model, update_mask) + path = 'models/{0}'.format(model.model_id) + if update_mask is not None: + path = path + '?updateMask={0}'.format(update_mask) + try: + return self.handle_operation( + self._client.body('patch', url=path, json=model.as_dict(for_upload=True))) + except requests.exceptions.RequestException as error: + raise _utils.handle_platform_error_from_requests(error) + + def set_published(self, model_id, publish): + _validate_model_id(model_id) + model_name = 'projects/{0}/models/{1}'.format(self._project_id, model_id) + model = Model.from_dict({ + 'name': model_name, + 'state': { + 'published': publish + } + }) + return self.update_model(model, update_mask='state.published') + + def get_model(self, model_id): + _validate_model_id(model_id) + try: + return self._client.body('get', url='models/{0}'.format(model_id)) + except requests.exceptions.RequestException as error: + raise _utils.handle_platform_error_from_requests(error) + + def list_models(self, list_filter, page_size, page_token): + """ lists Firebase ML models.""" + _validate_list_filter(list_filter) + _validate_page_size(page_size) + _validate_page_token(page_token) + params = {} + if list_filter: + params['filter'] = list_filter + if page_size: + params['page_size'] = page_size + if page_token: + params['page_token'] = page_token + path = 'models' + if params: + param_str = parse.urlencode(sorted(params.items()), True) + path = path + '?' + param_str + try: + return self._client.body('get', url=path) + except requests.exceptions.RequestException as error: + raise _utils.handle_platform_error_from_requests(error) + + def delete_model(self, model_id): + _validate_model_id(model_id) + try: + self._client.body('delete', url='models/{0}'.format(model_id)) + except requests.exceptions.RequestException as error: + raise _utils.handle_platform_error_from_requests(error) diff --git a/integration/test_ml.py b/integration/test_ml.py new file mode 100644 index 000000000..be791d8fa --- /dev/null +++ b/integration/test_ml.py @@ -0,0 +1,373 @@ +# Copyright 2020 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for firebase_admin.ml module.""" +import os +import random +import re +import shutil +import string +import tempfile + +import pytest + +from firebase_admin import exceptions +from firebase_admin import ml +from tests import testutils + + +# pylint: disable=import-error,no-name-in-module +try: + import tensorflow as tf + _TF_ENABLED = True +except ImportError: + _TF_ENABLED = False + + +def _random_identifier(prefix): + #pylint: disable=unused-variable + suffix = ''.join([random.choice(string.ascii_letters + string.digits) for n in range(8)]) + return '{0}_{1}'.format(prefix, suffix) + + +NAME_ONLY_ARGS = { + 'display_name': _random_identifier('TestModel123_') +} +NAME_ONLY_ARGS_UPDATED = { + 'display_name': _random_identifier('TestModel123_updated_') +} +NAME_AND_TAGS_ARGS = { + 'display_name': _random_identifier('TestModel123_tags_'), + 'tags': ['test_tag123'] +} +FULL_MODEL_ARGS = { + 'display_name': _random_identifier('TestModel123_full_'), + 'tags': ['test_tag567'], + 'file_name': 'model1.tflite' +} +INVALID_FULL_MODEL_ARGS = { + 'display_name': _random_identifier('TestModel123_invalid_full_'), + 'tags': ['test_tag890'], + 'file_name': 'invalid_model.tflite' +} + + +@pytest.fixture +def firebase_model(request): + args = request.param + tflite_format = None + file_name = args.get('file_name') + if file_name: + file_path = testutils.resource_filename(file_name) + source = ml.TFLiteGCSModelSource.from_tflite_model_file(file_path) + tflite_format = ml.TFLiteFormat(model_source=source) + + ml_model = ml.Model( + display_name=args.get('display_name'), + tags=args.get('tags'), + model_format=tflite_format) + model = ml.create_model(model=ml_model) + yield model + _clean_up_model(model) + + +@pytest.fixture +def model_list(): + ml_model_1 = ml.Model(display_name=_random_identifier('TestModel123_list1_')) + model_1 = ml.create_model(model=ml_model_1) + + ml_model_2 = ml.Model(display_name=_random_identifier('TestModel123_list2_'), + tags=['test_tag123']) + model_2 = ml.create_model(model=ml_model_2) + + yield [model_1, model_2] + + _clean_up_model(model_1) + _clean_up_model(model_2) + + +def _clean_up_model(model): + try: + # Try to delete the model. + # Some tests delete the model as part of the test. + ml.delete_model(model.model_id) + except exceptions.NotFoundError: + pass + + +# For rpc errors +def check_firebase_error(excinfo, status, msg): + err = excinfo.value + assert isinstance(err, exceptions.FirebaseError) + assert err.cause is not None + assert err.http_response is not None + assert err.http_response.status_code == status + assert str(err) == msg + + +# For operation errors +def check_operation_error(excinfo, msg): + err = excinfo.value + assert isinstance(err, exceptions.FirebaseError) + assert str(err) == msg + + +def check_model(model, args): + assert model.display_name == args.get('display_name') + assert model.tags == args.get('tags') + assert model.model_id is not None + assert model.create_time is not None + assert model.update_time is not None + assert model.locked is False + assert model.etag is not None + + +def check_model_format(model, has_model_format=False, validation_error=None): + if has_model_format: + assert model.validation_error == validation_error + assert model.published is False + assert model.model_format.model_source.gcs_tflite_uri.startswith('gs://') + if validation_error: + assert model.model_format.size_bytes is None + assert model.model_hash is None + else: + assert model.model_format.size_bytes is not None + assert model.model_hash is not None + else: + assert model.model_format is None + assert model.validation_error == 'No model file has been uploaded.' + assert model.published is False + assert model.model_hash is None + + +@pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True) +def test_create_simple_model(firebase_model): + check_model(firebase_model, NAME_AND_TAGS_ARGS) + check_model_format(firebase_model) + + +@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True) +def test_create_full_model(firebase_model): + check_model(firebase_model, FULL_MODEL_ARGS) + check_model_format(firebase_model, True) + + +@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True) +def test_create_already_existing_fails(firebase_model): + with pytest.raises(exceptions.AlreadyExistsError) as excinfo: + ml.create_model(model=firebase_model) + check_operation_error( + excinfo, + 'Model \'{0}\' already exists'.format(firebase_model.display_name)) + + +@pytest.mark.parametrize('firebase_model', [INVALID_FULL_MODEL_ARGS], indirect=True) +def test_create_invalid_model(firebase_model): + check_model(firebase_model, INVALID_FULL_MODEL_ARGS) + check_model_format(firebase_model, True, 'Invalid flatbuffer format') + + +@pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True) +def test_get_model(firebase_model): + get_model = ml.get_model(firebase_model.model_id) + check_model(get_model, NAME_AND_TAGS_ARGS) + check_model_format(get_model) + + +@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) +def test_get_non_existing_model(firebase_model): + # Get a valid model_id that no longer exists + ml.delete_model(firebase_model.model_id) + + with pytest.raises(exceptions.NotFoundError) as excinfo: + ml.get_model(firebase_model.model_id) + check_firebase_error(excinfo, 404, 'Requested entity was not found.') + + +@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) +def test_update_model(firebase_model): + new_model_name = NAME_ONLY_ARGS_UPDATED.get('display_name') + firebase_model.display_name = new_model_name + updated_model = ml.update_model(firebase_model) + check_model(updated_model, NAME_ONLY_ARGS_UPDATED) + check_model_format(updated_model) + + # Second call with same model does not cause error + updated_model2 = ml.update_model(updated_model) + check_model(updated_model2, NAME_ONLY_ARGS_UPDATED) + check_model_format(updated_model2) + + +@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) +def test_update_non_existing_model(firebase_model): + ml.delete_model(firebase_model.model_id) + + firebase_model.tags = ['tag987'] + with pytest.raises(exceptions.NotFoundError) as excinfo: + ml.update_model(firebase_model) + check_operation_error( + excinfo, + 'Model \'{0}\' was not found'.format(firebase_model.as_dict().get('name'))) + + +@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True) +def test_publish_unpublish_model(firebase_model): + assert firebase_model.published is False + + published_model = ml.publish_model(firebase_model.model_id) + assert published_model.published is True + + unpublished_model = ml.unpublish_model(published_model.model_id) + assert unpublished_model.published is False + + +@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) +def test_publish_invalid_fails(firebase_model): + assert firebase_model.validation_error is not None + + with pytest.raises(exceptions.FailedPreconditionError) as excinfo: + ml.publish_model(firebase_model.model_id) + check_operation_error( + excinfo, + 'Cannot publish a model that is not verified.') + + +@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True) +def test_publish_unpublish_non_existing_model(firebase_model): + ml.delete_model(firebase_model.model_id) + + with pytest.raises(exceptions.NotFoundError) as excinfo: + ml.publish_model(firebase_model.model_id) + check_operation_error( + excinfo, + 'Model \'{0}\' was not found'.format(firebase_model.as_dict().get('name'))) + + with pytest.raises(exceptions.NotFoundError) as excinfo: + ml.unpublish_model(firebase_model.model_id) + check_operation_error( + excinfo, + 'Model \'{0}\' was not found'.format(firebase_model.as_dict().get('name'))) + + +def test_list_models(model_list): + filter_str = 'displayName={0} OR tags:{1}'.format( + model_list[0].display_name, model_list[1].tags[0]) + + all_models = ml.list_models(list_filter=filter_str) + all_model_ids = [mdl.model_id for mdl in all_models.iterate_all()] + for mdl in model_list: + assert mdl.model_id in all_model_ids + + +def test_list_models_invalid_filter(): + invalid_filter = 'InvalidFilterParam=123' + + with pytest.raises(exceptions.InvalidArgumentError) as excinfo: + ml.list_models(list_filter=invalid_filter) + check_firebase_error(excinfo, 400, 'Request contains an invalid argument.') + + +@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) +def test_delete_model(firebase_model): + ml.delete_model(firebase_model.model_id) + + # Second delete of same model will fail + with pytest.raises(exceptions.NotFoundError) as excinfo: + ml.delete_model(firebase_model.model_id) + check_firebase_error(excinfo, 404, 'Requested entity was not found.') + + +# Test tensor flow conversion functions if tensor flow is enabled. +#'pip install tensorflow' in the environment if you want _TF_ENABLED = True +#'pip install tensorflow==2.0.0b' for version 2 etc. + + +def _clean_up_directory(save_dir): + if save_dir.startswith(tempfile.gettempdir()) and os.path.exists(save_dir): + shutil.rmtree(save_dir) + + +@pytest.fixture +def keras_model(): + assert _TF_ENABLED + x_array = [-1, 0, 1, 2, 3, 4] + y_array = [-3, -1, 1, 3, 5, 7] + model = tf.keras.models.Sequential( + [tf.keras.layers.Dense(units=1, input_shape=[1])]) + model.compile(optimizer='sgd', loss='mean_squared_error') + model.fit(x_array, y_array, epochs=3) + return model + + +@pytest.fixture +def saved_model_dir(keras_model): + assert _TF_ENABLED + # Make a new parent directory. The child directory must not exist yet. + # The child directory gets created by tf. If it exists, the tf call fails. + parent = tempfile.mkdtemp() + save_dir = os.path.join(parent, 'child') + + # different versions have different model conversion capability + # pick something that works for each version + if tf.version.VERSION.startswith('1.'): + tf.reset_default_graph() + x_var = tf.placeholder(tf.float32, (None, 3), name="x") + y_var = tf.multiply(x_var, x_var, name="y") + with tf.Session() as sess: + tf.saved_model.simple_save(sess, save_dir, {"x": x_var}, {"y": y_var}) + else: + # If it's not version 1.x or version 2.x we need to update the test. + assert tf.version.VERSION.startswith('2.') + tf.saved_model.save(keras_model, save_dir) + yield save_dir + _clean_up_directory(parent) + + +@pytest.mark.skipif(not _TF_ENABLED, reason='Tensor flow is required for this test.') +def test_from_keras_model(keras_model): + source = ml.TFLiteGCSModelSource.from_keras_model(keras_model, 'model2.tflite') + assert re.search( + '^gs://.*/Firebase/ML/Models/model2.tflite$', + source.gcs_tflite_uri) is not None + + # Validate the conversion by creating a model + model_format = ml.TFLiteFormat(model_source=source) + model = ml.Model(display_name=_random_identifier('KerasModel_'), model_format=model_format) + created_model = ml.create_model(model) + + try: + check_model(created_model, {'display_name': model.display_name}) + check_model_format(created_model, True) + finally: + _clean_up_model(created_model) + + +@pytest.mark.skipif(not _TF_ENABLED, reason='Tensor flow is required for this test.') +def test_from_saved_model(saved_model_dir): + # Test the conversion helper + source = ml.TFLiteGCSModelSource.from_saved_model(saved_model_dir, 'model3.tflite') + assert re.search( + '^gs://.*/Firebase/ML/Models/model3.tflite$', + source.gcs_tflite_uri) is not None + + # Validate the conversion by creating a model + model_format = ml.TFLiteFormat(model_source=source) + model = ml.Model(display_name=_random_identifier('SavedModel_'), model_format=model_format) + created_model = ml.create_model(model) + + try: + assert created_model.model_id is not None + assert created_model.validation_error is None + finally: + _clean_up_model(created_model) diff --git a/tests/data/invalid_model.tflite b/tests/data/invalid_model.tflite new file mode 100644 index 000000000..d8482f436 --- /dev/null +++ b/tests/data/invalid_model.tflite @@ -0,0 +1 @@ +This is not a tflite file. diff --git a/tests/data/model1.tflite b/tests/data/model1.tflite new file mode 100644 index 0000000000000000000000000000000000000000..c4b71b7a222ebc59ee9fa1239fe2b8efb382cf8b GIT binary patch literal 736 zcmaJY5FPb2r$h~!B1MWTQ%GV^!A6@CK}ZNlustqhi-Y8H-iIg%{+S@QcK(Dk zR)Rl3tSqc7Y;=9^?h;aY@NQ;j_RY-O-KvOmPg{E;TT&H6Oeso9%7|7F5m^Gpi-MRS zFR}v^fd$|pw*l-X(CyeA%O3exDvVXXE-Q$g0Ea*gAfI&%;7x1IS|70Au#+FH4KSD! zSQ8%k6Eu29ZW(^Feo)_K>{n|Oma%PM==n~V_^~%s4thu4$j6N3nHtV(0aQiaEoyRp zezXMpUIWy`Jtl%{m~A>Qxh()kG2`sRkJM$N(Aph1%|>7Ok%DczaXT3_&XwE0a6`}S z4OAy+#G&g)!6;Io$rCiN=X3S*IL!O-tl7r`=KFA-Gt`c~_y(?gfqS2GxQ`s3wFx?^MfSDns>_^X1%E{Y8Sb)GfRIXJ-KXm39D=`^W;!7eZm6%(eLy;H^LSfV^(T? zeSA40uL6|QKPM`rbs3=!d?x$wt<-YMb0LrVras*CGoXmIndd!cZ>NyH9V}M=02G>W Ai~s-t literal 0 HcmV?d00001 diff --git a/tests/test_ml.py b/tests/test_ml.py new file mode 100644 index 000000000..10b0441db --- /dev/null +++ b/tests/test_ml.py @@ -0,0 +1,1113 @@ +# Copyright 2019 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test cases for the firebase_admin.ml module.""" + +import json + +import pytest + +import firebase_admin +from firebase_admin import exceptions +from firebase_admin import ml +from tests import testutils + + +BASE_URL = 'https://firebaseml.googleapis.com/v1beta2/' +HEADER_CLIENT_KEY = 'X-FIREBASE-CLIENT' +HEADER_CLIENT_VALUE = 'fire-admin-python/{0}'.format(firebase_admin.__version__) +PROJECT_ID = 'my-project-1' + +PAGE_TOKEN = 'pageToken' +NEXT_PAGE_TOKEN = 'nextPageToken' + +CREATE_TIME = '2020-01-21T20:44:27.392932Z' +CREATE_TIME_MILLIS = 1579639467392 + +UPDATE_TIME = '2020-01-21T22:45:29.392932Z' +UPDATE_TIME_MILLIS = 1579646729392 + +CREATE_TIME_2 = '2020-01-21T21:44:27.392932Z' +UPDATE_TIME_2 = '2020-01-21T23:45:29.392932Z' + +ETAG = '33a64df551425fcc55e4d42a148795d9f25f89d4' +MODEL_HASH = '987987a98b98798d098098e09809fc0893897' +TAG_1 = 'Tag1' +TAG_2 = 'Tag2' +TAG_3 = 'Tag3' +TAGS = [TAG_1, TAG_2] +TAGS_2 = [TAG_1, TAG_3] + +MODEL_ID_1 = 'modelId1' +MODEL_NAME_1 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_1) +DISPLAY_NAME_1 = 'displayName1' +MODEL_JSON_1 = { + 'name': MODEL_NAME_1, + 'displayName': DISPLAY_NAME_1 +} +MODEL_1 = ml.Model.from_dict(MODEL_JSON_1) + +MODEL_ID_2 = 'modelId2' +MODEL_NAME_2 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_2) +DISPLAY_NAME_2 = 'displayName2' +MODEL_JSON_2 = { + 'name': MODEL_NAME_2, + 'displayName': DISPLAY_NAME_2 +} +MODEL_2 = ml.Model.from_dict(MODEL_JSON_2) + +MODEL_ID_3 = 'modelId3' +MODEL_NAME_3 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_3) +DISPLAY_NAME_3 = 'displayName3' +MODEL_JSON_3 = { + 'name': MODEL_NAME_3, + 'displayName': DISPLAY_NAME_3 +} +MODEL_3 = ml.Model.from_dict(MODEL_JSON_3) + +MODEL_STATE_PUBLISHED_JSON = { + 'published': True +} +VALIDATION_ERROR_CODE = 400 +VALIDATION_ERROR_MSG = 'No model format found for {0}.'.format(MODEL_ID_1) +MODEL_STATE_ERROR_JSON = { + 'validationError': { + 'code': VALIDATION_ERROR_CODE, + 'message': VALIDATION_ERROR_MSG, + } +} + +OPERATION_NAME_1 = 'projects/{0}/operations/123'.format(PROJECT_ID) +OPERATION_NOT_DONE_JSON_1 = { + 'name': OPERATION_NAME_1, + 'metadata': { + '@type': 'type.googleapis.com/google.firebase.ml.v1beta2.ModelOperationMetadata', + 'name': 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_1), + 'basic_operation_status': 'BASIC_OPERATION_STATUS_UPLOADING' + } +} + +GCS_BUCKET_NAME = 'my_bucket' +GCS_BLOB_NAME = 'mymodel.tflite' +GCS_TFLITE_URI = 'gs://{0}/{1}'.format(GCS_BUCKET_NAME, GCS_BLOB_NAME) +GCS_TFLITE_URI_JSON = {'gcsTfliteUri': GCS_TFLITE_URI} +GCS_TFLITE_MODEL_SOURCE = ml.TFLiteGCSModelSource(GCS_TFLITE_URI) +TFLITE_FORMAT_JSON = { + 'gcsTfliteUri': GCS_TFLITE_URI, + 'sizeBytes': '1234567' +} +TFLITE_FORMAT = ml.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON) + +GCS_TFLITE_SIGNED_URI_PATTERN = ( + 'https://storage.googleapis.com/{0}/{1}?X-Goog-Algorithm=GOOG4-RSA-SHA256&foo') +GCS_TFLITE_SIGNED_URI = GCS_TFLITE_SIGNED_URI_PATTERN.format(GCS_BUCKET_NAME, GCS_BLOB_NAME) + +GCS_TFLITE_URI_2 = 'gs://my_bucket/mymodel2.tflite' +GCS_TFLITE_URI_JSON_2 = {'gcsTfliteUri': GCS_TFLITE_URI_2} +GCS_TFLITE_MODEL_SOURCE_2 = ml.TFLiteGCSModelSource(GCS_TFLITE_URI_2) +TFLITE_FORMAT_JSON_2 = { + 'gcsTfliteUri': GCS_TFLITE_URI_2, + 'sizeBytes': '2345678' +} +TFLITE_FORMAT_2 = ml.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_2) + +CREATED_UPDATED_MODEL_JSON_1 = { + 'name': MODEL_NAME_1, + 'displayName': DISPLAY_NAME_1, + 'createTime': CREATE_TIME, + 'updateTime': UPDATE_TIME, + 'state': MODEL_STATE_ERROR_JSON, + 'etag': ETAG, + 'modelHash': MODEL_HASH, + 'tags': TAGS, +} +CREATED_UPDATED_MODEL_1 = ml.Model.from_dict(CREATED_UPDATED_MODEL_JSON_1) + +LOCKED_MODEL_JSON_1 = { + 'name': MODEL_NAME_1, + 'displayName': DISPLAY_NAME_1, + 'createTime': CREATE_TIME, + 'updateTime': UPDATE_TIME, + 'tags': TAGS, + 'activeOperations': [OPERATION_NOT_DONE_JSON_1] +} + +LOCKED_MODEL_JSON_2 = { + 'name': MODEL_NAME_1, + 'displayName': DISPLAY_NAME_2, + 'createTime': CREATE_TIME_2, + 'updateTime': UPDATE_TIME_2, + 'tags': TAGS_2, + 'activeOperations': [OPERATION_NOT_DONE_JSON_1] +} + +OPERATION_DONE_MODEL_JSON_1 = { + 'done': True, + 'response': CREATED_UPDATED_MODEL_JSON_1 +} +OPERATION_MALFORMED_JSON_1 = { + 'done': True, + # if done is true then either response or error should be populated +} +OPERATION_MISSING_NAME = { + # Name is required if the operation is not done. + 'done': False +} +OPERATION_ERROR_CODE = 3 +OPERATION_ERROR_MSG = "Invalid argument" +OPERATION_ERROR_EXPECTED_STATUS = 'INVALID_ARGUMENT' +OPERATION_ERROR_JSON_1 = { + 'done': True, + 'error': { + 'code': OPERATION_ERROR_CODE, + 'message': OPERATION_ERROR_MSG, + } +} + +FULL_MODEL_ERR_STATE_LRO_JSON = { + 'name': MODEL_NAME_1, + 'displayName': DISPLAY_NAME_1, + 'createTime': CREATE_TIME, + 'updateTime': UPDATE_TIME, + 'state': MODEL_STATE_ERROR_JSON, + 'etag': ETAG, + 'modelHash': MODEL_HASH, + 'tags': TAGS, + 'activeOperations': [OPERATION_NOT_DONE_JSON_1], +} +FULL_MODEL_PUBLISHED_JSON = { + 'name': MODEL_NAME_1, + 'displayName': DISPLAY_NAME_1, + 'createTime': CREATE_TIME, + 'updateTime': UPDATE_TIME, + 'state': MODEL_STATE_PUBLISHED_JSON, + 'etag': ETAG, + 'modelHash': MODEL_HASH, + 'tags': TAGS, + 'tfliteModel': TFLITE_FORMAT_JSON +} +FULL_MODEL_PUBLISHED = ml.Model.from_dict(FULL_MODEL_PUBLISHED_JSON) +OPERATION_DONE_FULL_MODEL_PUBLISHED_JSON = { + 'name': OPERATION_NAME_1, + 'done': True, + 'response': FULL_MODEL_PUBLISHED_JSON +} + +EMPTY_RESPONSE = json.dumps({}) +OPERATION_NOT_DONE_RESPONSE = json.dumps(OPERATION_NOT_DONE_JSON_1) +OPERATION_DONE_RESPONSE = json.dumps(OPERATION_DONE_MODEL_JSON_1) +OPERATION_DONE_PUBLISHED_RESPONSE = json.dumps(OPERATION_DONE_FULL_MODEL_PUBLISHED_JSON) +OPERATION_ERROR_RESPONSE = json.dumps(OPERATION_ERROR_JSON_1) +OPERATION_MALFORMED_RESPONSE = json.dumps(OPERATION_MALFORMED_JSON_1) +OPERATION_MISSING_NAME_RESPONSE = json.dumps(OPERATION_MISSING_NAME) +DEFAULT_GET_RESPONSE = json.dumps(MODEL_JSON_1) +LOCKED_MODEL_2_RESPONSE = json.dumps(LOCKED_MODEL_JSON_2) +NO_MODELS_LIST_RESPONSE = json.dumps({}) +DEFAULT_LIST_RESPONSE = json.dumps({ + 'models': [MODEL_JSON_1, MODEL_JSON_2], + 'nextPageToken': NEXT_PAGE_TOKEN +}) +LAST_PAGE_LIST_RESPONSE = json.dumps({ + 'models': [MODEL_JSON_3] +}) +ONE_PAGE_LIST_RESPONSE = json.dumps({ + 'models': [MODEL_JSON_1, MODEL_JSON_2, MODEL_JSON_3], +}) + +ERROR_CODE_NOT_FOUND = 404 +ERROR_MSG_NOT_FOUND = 'The resource was not found' +ERROR_STATUS_NOT_FOUND = 'NOT_FOUND' +ERROR_JSON_NOT_FOUND = { + 'error': { + 'code': ERROR_CODE_NOT_FOUND, + 'message': ERROR_MSG_NOT_FOUND, + 'status': ERROR_STATUS_NOT_FOUND + } +} +ERROR_RESPONSE_NOT_FOUND = json.dumps(ERROR_JSON_NOT_FOUND) + +ERROR_CODE_BAD_REQUEST = 400 +ERROR_MSG_BAD_REQUEST = 'Invalid Argument' +ERROR_STATUS_BAD_REQUEST = 'INVALID_ARGUMENT' +ERROR_JSON_BAD_REQUEST = { + 'error': { + 'code': ERROR_CODE_BAD_REQUEST, + 'message': ERROR_MSG_BAD_REQUEST, + 'status': ERROR_STATUS_BAD_REQUEST + } +} +ERROR_RESPONSE_BAD_REQUEST = json.dumps(ERROR_JSON_BAD_REQUEST) + +INVALID_MODEL_ID_ARGS = [ + ('', ValueError), + ('&_*#@:/?', ValueError), + (None, TypeError), + (12345, TypeError), +] +INVALID_MODEL_ARGS = [ + 'abc', + 4.2, + list(), + dict(), + True, + -1, + 0, + None +] +INVALID_OP_NAME_ARGS = [ + 'abc', + '123', + 'operations/project/1234/model/abc/operation/123', + 'projects/operations/123', + 'projects/$#@/operations/123', + 'projects/1234/operations/123/extrathing', +] +PAGE_SIZE_VALUE_ERROR_MSG = 'Page size must be a positive integer between ' \ + '1 and {0}'.format(ml._MAX_PAGE_SIZE) +INVALID_STRING_OR_NONE_ARGS = [0, -1, 4.2, 0x10, False, list(), dict()] + + +# For validation type errors +def check_error(excinfo, err_type, msg=None): + err = excinfo.value + assert isinstance(err, err_type) + if msg: + assert str(err) == msg + + +# For errors that are returned in an operation +def check_operation_error(excinfo, code, msg): + err = excinfo.value + assert isinstance(err, exceptions.FirebaseError) + assert err.code == code + assert str(err) == msg + + +# For rpc errors +def check_firebase_error(excinfo, code, status, msg): + err = excinfo.value + assert isinstance(err, exceptions.FirebaseError) + assert err.code == code + assert err.http_response is not None + assert err.http_response.status_code == status + assert str(err) == msg + + +def instrument_ml_service(status=200, payload=None, operations=False, app=None): + if not app: + app = firebase_admin.get_app() + ml_service = ml._get_ml_service(app) + recorder = [] + session_url = 'https://firebaseml.googleapis.com/v1beta2/' + + if isinstance(status, list): + adapter = testutils.MockMultiRequestAdapter + else: + adapter = testutils.MockAdapter + + if operations: + ml_service._operation_client.session.mount( + session_url, adapter(payload, status, recorder)) + else: + ml_service._client.session.mount( + session_url, adapter(payload, status, recorder)) + return recorder + +class _TestStorageClient: + @staticmethod + def upload(bucket_name, model_file_name, app): + del app # unused variable + blob_name = ml._CloudStorageClient.BLOB_NAME.format(model_file_name) + return ml._CloudStorageClient.GCS_URI.format(bucket_name, blob_name) + + @staticmethod + def sign_uri(gcs_tflite_uri, app): + del app # unused variable + bucket_name, blob_name = ml._CloudStorageClient._parse_gcs_tflite_uri(gcs_tflite_uri) + return GCS_TFLITE_SIGNED_URI_PATTERN.format(bucket_name, blob_name) + +class TestModel: + """Tests ml.Model class.""" + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) + ml._MLService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test + ml.TFLiteGCSModelSource._STORAGE_CLIENT = _TestStorageClient() + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + @staticmethod + def _op_url(project_id): + return BASE_URL + \ + 'projects/{0}/operations/123'.format(project_id) + + def test_model_success_err_state_lro(self): + model = ml.Model.from_dict(FULL_MODEL_ERR_STATE_LRO_JSON) + assert model.model_id == MODEL_ID_1 + assert model.display_name == DISPLAY_NAME_1 + assert model.create_time == CREATE_TIME_MILLIS + assert model.update_time == UPDATE_TIME_MILLIS + assert model.validation_error == VALIDATION_ERROR_MSG + assert model.published is False + assert model.etag == ETAG + assert model.model_hash == MODEL_HASH + assert model.tags == TAGS + assert model.locked is True + assert model.model_format is None + assert model.as_dict() == FULL_MODEL_ERR_STATE_LRO_JSON + + def test_model_success_published(self): + model = ml.Model.from_dict(FULL_MODEL_PUBLISHED_JSON) + assert model.model_id == MODEL_ID_1 + assert model.display_name == DISPLAY_NAME_1 + assert model.create_time == CREATE_TIME_MILLIS + assert model.update_time == UPDATE_TIME_MILLIS + assert model.validation_error is None + assert model.published is True + assert model.etag == ETAG + assert model.model_hash == MODEL_HASH + assert model.tags == TAGS + assert model.locked is False + assert model.model_format == TFLITE_FORMAT + assert model.as_dict() == FULL_MODEL_PUBLISHED_JSON + + def test_model_keyword_based_creation_and_setters(self): + model = ml.Model(display_name=DISPLAY_NAME_1, tags=TAGS, model_format=TFLITE_FORMAT) + assert model.display_name == DISPLAY_NAME_1 + assert model.tags == TAGS + assert model.model_format == TFLITE_FORMAT + assert model.as_dict() == { + 'displayName': DISPLAY_NAME_1, + 'tags': TAGS, + 'tfliteModel': TFLITE_FORMAT_JSON + } + + model.display_name = DISPLAY_NAME_2 + model.tags = TAGS_2 + model.model_format = TFLITE_FORMAT_2 + assert model.as_dict() == { + 'displayName': DISPLAY_NAME_2, + 'tags': TAGS_2, + 'tfliteModel': TFLITE_FORMAT_JSON_2 + } + + def test_model_format_source_creation(self): + model_source = ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI) + model_format = ml.TFLiteFormat(model_source=model_source) + model = ml.Model(display_name=DISPLAY_NAME_1, model_format=model_format) + assert model.as_dict() == { + 'displayName': DISPLAY_NAME_1, + 'tfliteModel': { + 'gcsTfliteUri': GCS_TFLITE_URI + } + } + + def test_source_creation_from_tflite_file(self): + model_source = ml.TFLiteGCSModelSource.from_tflite_model_file( + "my_model.tflite", "my_bucket") + assert model_source.as_dict() == { + 'gcsTfliteUri': 'gs://my_bucket/Firebase/ML/Models/my_model.tflite' + } + + def test_model_source_setters(self): + model_source = ml.TFLiteGCSModelSource(GCS_TFLITE_URI) + model_source.gcs_tflite_uri = GCS_TFLITE_URI_2 + assert model_source.gcs_tflite_uri == GCS_TFLITE_URI_2 + assert model_source.as_dict() == GCS_TFLITE_URI_JSON_2 + + def test_model_format_setters(self): + model_format = ml.TFLiteFormat(model_source=GCS_TFLITE_MODEL_SOURCE) + model_format.model_source = GCS_TFLITE_MODEL_SOURCE_2 + assert model_format.model_source == GCS_TFLITE_MODEL_SOURCE_2 + assert model_format.as_dict() == { + 'tfliteModel': { + 'gcsTfliteUri': GCS_TFLITE_URI_2 + } + } + + def test_model_as_dict_for_upload(self): + model_source = ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI) + model_format = ml.TFLiteFormat(model_source=model_source) + model = ml.Model(display_name=DISPLAY_NAME_1, model_format=model_format) + assert model.as_dict(for_upload=True) == { + 'displayName': DISPLAY_NAME_1, + 'tfliteModel': { + 'gcsTfliteUri': GCS_TFLITE_SIGNED_URI + } + } + + @pytest.mark.parametrize('helper_func', [ + ml.TFLiteGCSModelSource.from_keras_model, + ml.TFLiteGCSModelSource.from_saved_model + ]) + def test_tf_not_enabled(self, helper_func): + ml._TF_ENABLED = False # for reliability + with pytest.raises(ImportError) as excinfo: + helper_func(None) + check_error(excinfo, ImportError) + + @pytest.mark.parametrize('display_name, exc_type', [ + ('', ValueError), + ('&_*#@:/?', ValueError), + (12345, TypeError) + ]) + def test_model_display_name_validation_errors(self, display_name, exc_type): + with pytest.raises(exc_type) as excinfo: + ml.Model(display_name=display_name) + check_error(excinfo, exc_type) + + @pytest.mark.parametrize('tags, exc_type, error_message', [ + ('tag1', TypeError, 'Tags must be a list of strings.'), + (123, TypeError, 'Tags must be a list of strings.'), + (['tag1', 123, 'tag2'], TypeError, 'Tags must be a list of strings.'), + (['tag1', '@#$%^&'], ValueError, 'Tag format is invalid.'), + (['', 'tag2'], ValueError, 'Tag format is invalid.'), + (['sixty-one_characters_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx', + 'tag2'], ValueError, 'Tag format is invalid.') + ]) + def test_model_tags_validation_errors(self, tags, exc_type, error_message): + with pytest.raises(exc_type) as excinfo: + ml.Model(tags=tags) + check_error(excinfo, exc_type, error_message) + + @pytest.mark.parametrize('model_format', [ + 123, + "abc", + {}, + [], + True + ]) + def test_model_format_validation_errors(self, model_format): + with pytest.raises(TypeError) as excinfo: + ml.Model(model_format=model_format) + check_error(excinfo, TypeError, 'Model format must be a ModelFormat object.') + + @pytest.mark.parametrize('model_source', [ + 123, + "abc", + {}, + [], + True + ]) + def test_model_source_validation_errors(self, model_source): + with pytest.raises(TypeError) as excinfo: + ml.TFLiteFormat(model_source=model_source) + check_error(excinfo, TypeError, 'Model source must be a TFLiteModelSource object.') + + @pytest.mark.parametrize('uri, exc_type', [ + (123, TypeError), + ('abc', ValueError), + ('gs://NO_CAPITALS', ValueError), + ('gs://abc/', ValueError), + ('gs://aa/model.tflite', ValueError), + ('gs://@#$%/model.tflite', ValueError), + ('gs://invalid space/model.tflite', ValueError), + ('gs://sixty-four-characters_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx/model.tflite', + ValueError) + ]) + def test_gcs_tflite_source_validation_errors(self, uri, exc_type): + with pytest.raises(exc_type) as excinfo: + ml.TFLiteGCSModelSource(gcs_tflite_uri=uri) + check_error(excinfo, exc_type) + + def test_wait_for_unlocked_not_locked(self): + model = ml.Model(display_name="not_locked") + model.wait_for_unlocked() + + def test_wait_for_unlocked(self): + recorder = instrument_ml_service(status=200, + operations=True, + payload=OPERATION_DONE_PUBLISHED_RESPONSE) + model = ml.Model.from_dict(LOCKED_MODEL_JSON_1) + model.wait_for_unlocked() + assert model == FULL_MODEL_PUBLISHED + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == TestModel._op_url(PROJECT_ID) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + + def test_wait_for_unlocked_timeout(self): + recorder = instrument_ml_service( + status=200, operations=True, payload=OPERATION_NOT_DONE_RESPONSE) + ml._MLService.POLL_BASE_WAIT_TIME_SECONDS = 3 # longer so timeout applies immediately + model = ml.Model.from_dict(LOCKED_MODEL_JSON_1) + with pytest.raises(Exception) as excinfo: + model.wait_for_unlocked(max_time_seconds=0.1) + check_error(excinfo, exceptions.DeadlineExceededError, 'Polling max time exceeded.') + assert len(recorder) == 1 + + +class TestCreateModel: + """Tests ml.create_model.""" + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) + ml._MLService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + @staticmethod + def _url(project_id): + return BASE_URL + 'projects/{0}/models'.format(project_id) + + @staticmethod + def _op_url(project_id): + return BASE_URL + \ + 'projects/{0}/operations/123'.format(project_id) + + @staticmethod + def _get_url(project_id, model_id): + return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) + + def test_immediate_done(self): + instrument_ml_service(status=200, payload=OPERATION_DONE_RESPONSE) + model = ml.create_model(MODEL_1) + assert model == CREATED_UPDATED_MODEL_1 + + def test_returns_locked(self): + recorder = instrument_ml_service( + status=[200, 200], + payload=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE]) + expected_model = ml.Model.from_dict(LOCKED_MODEL_JSON_2) + model = ml.create_model(MODEL_1) + + assert model == expected_model + assert len(recorder) == 2 + assert recorder[0].method == 'POST' + assert recorder[0].url == TestCreateModel._url(PROJECT_ID) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + assert recorder[1].method == 'GET' + assert recorder[1].url == TestCreateModel._get_url(PROJECT_ID, MODEL_ID_1) + assert recorder[1].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + + def test_operation_error(self): + instrument_ml_service(status=200, payload=OPERATION_ERROR_RESPONSE) + with pytest.raises(Exception) as excinfo: + ml.create_model(MODEL_1) + # The http request succeeded, the operation returned contains a create failure + check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) + + def test_malformed_operation(self): + instrument_ml_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) + with pytest.raises(Exception) as excinfo: + ml.create_model(MODEL_1) + check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.') + + def test_rpc_error_create(self): + create_recorder = instrument_ml_service( + status=400, payload=ERROR_RESPONSE_BAD_REQUEST) + with pytest.raises(Exception) as excinfo: + ml.create_model(MODEL_1) + check_firebase_error( + excinfo, + ERROR_STATUS_BAD_REQUEST, + ERROR_CODE_BAD_REQUEST, + ERROR_MSG_BAD_REQUEST + ) + assert len(create_recorder) == 1 + + @pytest.mark.parametrize('model', INVALID_MODEL_ARGS) + def test_not_model(self, model): + with pytest.raises(Exception) as excinfo: + ml.create_model(model) + check_error(excinfo, TypeError, 'Model must be an ml.Model.') + + def test_missing_display_name(self): + with pytest.raises(Exception) as excinfo: + ml.create_model(ml.Model.from_dict({})) + check_error(excinfo, ValueError, 'Model must have a display name.') + + def test_missing_op_name(self): + instrument_ml_service(status=200, payload=OPERATION_MISSING_NAME_RESPONSE) + with pytest.raises(Exception) as excinfo: + ml.create_model(MODEL_1) + check_error(excinfo, TypeError) + + @pytest.mark.parametrize('op_name', INVALID_OP_NAME_ARGS) + def test_invalid_op_name(self, op_name): + payload = json.dumps({'name': op_name}) + instrument_ml_service(status=200, payload=payload) + with pytest.raises(Exception) as excinfo: + ml.create_model(MODEL_1) + check_error(excinfo, ValueError, 'Operation name format is invalid.') + + +class TestUpdateModel: + """Tests ml.update_model.""" + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) + ml._MLService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + @staticmethod + def _url(project_id, model_id): + return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) + + @staticmethod + def _op_url(project_id): + return BASE_URL + \ + 'projects/{0}/operations/123'.format(project_id) + + def test_immediate_done(self): + instrument_ml_service(status=200, payload=OPERATION_DONE_RESPONSE) + model = ml.update_model(MODEL_1) + assert model == CREATED_UPDATED_MODEL_1 + + def test_returns_locked(self): + recorder = instrument_ml_service( + status=[200, 200], + payload=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE]) + expected_model = ml.Model.from_dict(LOCKED_MODEL_JSON_2) + model = ml.update_model(MODEL_1) + + assert model == expected_model + assert len(recorder) == 2 + assert recorder[0].method == 'PATCH' + assert recorder[0].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + assert recorder[1].method == 'GET' + assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1) + assert recorder[1].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + + def test_operation_error(self): + instrument_ml_service(status=200, payload=OPERATION_ERROR_RESPONSE) + with pytest.raises(Exception) as excinfo: + ml.update_model(MODEL_1) + # The http request succeeded, the operation returned contains an update failure + check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) + + def test_malformed_operation(self): + instrument_ml_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) + with pytest.raises(Exception) as excinfo: + ml.update_model(MODEL_1) + check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.') + + def test_rpc_error(self): + create_recorder = instrument_ml_service( + status=400, payload=ERROR_RESPONSE_BAD_REQUEST) + with pytest.raises(Exception) as excinfo: + ml.update_model(MODEL_1) + check_firebase_error( + excinfo, + ERROR_STATUS_BAD_REQUEST, + ERROR_CODE_BAD_REQUEST, + ERROR_MSG_BAD_REQUEST + ) + assert len(create_recorder) == 1 + + @pytest.mark.parametrize('model', INVALID_MODEL_ARGS) + def test_not_model(self, model): + with pytest.raises(Exception) as excinfo: + ml.update_model(model) + check_error(excinfo, TypeError, 'Model must be an ml.Model.') + + def test_missing_display_name(self): + with pytest.raises(Exception) as excinfo: + ml.update_model(ml.Model.from_dict({})) + check_error(excinfo, ValueError, 'Model must have a display name.') + + def test_missing_op_name(self): + instrument_ml_service(status=200, payload=OPERATION_MISSING_NAME_RESPONSE) + with pytest.raises(Exception) as excinfo: + ml.update_model(MODEL_1) + check_error(excinfo, TypeError) + + @pytest.mark.parametrize('op_name', INVALID_OP_NAME_ARGS) + def test_invalid_op_name(self, op_name): + payload = json.dumps({'name': op_name}) + instrument_ml_service(status=200, payload=payload) + with pytest.raises(Exception) as excinfo: + ml.update_model(MODEL_1) + check_error(excinfo, ValueError, 'Operation name format is invalid.') + + +class TestPublishUnpublish: + """Tests ml.publish_model and ml.unpublish_model.""" + + PUBLISH_UNPUBLISH_WITH_ARGS = [ + (ml.publish_model, True), + (ml.unpublish_model, False) + ] + PUBLISH_UNPUBLISH_FUNCS = [item[0] for item in PUBLISH_UNPUBLISH_WITH_ARGS] + + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) + ml._MLService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + @staticmethod + def _update_url(project_id, model_id): + update_url = 'projects/{0}/models/{1}?updateMask=state.published'.format( + project_id, model_id) + return BASE_URL + update_url + + @staticmethod + def _get_url(project_id, model_id): + return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) + + @staticmethod + def _op_url(project_id): + return BASE_URL + \ + 'projects/{0}/operations/123'.format(project_id) + + @pytest.mark.parametrize('publish_function, published', PUBLISH_UNPUBLISH_WITH_ARGS) + def test_immediate_done(self, publish_function, published): + recorder = instrument_ml_service(status=200, payload=OPERATION_DONE_RESPONSE) + model = publish_function(MODEL_ID_1) + assert model == CREATED_UPDATED_MODEL_1 + assert len(recorder) == 1 + assert recorder[0].method == 'PATCH' + assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + body = json.loads(recorder[0].body.decode()) + assert body.get('state', {}).get('published', None) is published + + @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) + def test_returns_locked(self, publish_function): + recorder = instrument_ml_service( + status=[200, 200], + payload=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE]) + expected_model = ml.Model.from_dict(LOCKED_MODEL_JSON_2) + model = publish_function(MODEL_ID_1) + + assert model == expected_model + assert len(recorder) == 2 + assert recorder[0].method == 'PATCH' + assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + assert recorder[1].method == 'GET' + assert recorder[1].url == TestPublishUnpublish._get_url(PROJECT_ID, MODEL_ID_1) + assert recorder[1].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + + @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) + def test_operation_error(self, publish_function): + instrument_ml_service(status=200, payload=OPERATION_ERROR_RESPONSE) + with pytest.raises(Exception) as excinfo: + publish_function(MODEL_ID_1) + # The http request succeeded, the operation returned contains an update failure + check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) + + @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) + def test_malformed_operation(self, publish_function): + instrument_ml_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) + with pytest.raises(Exception) as excinfo: + publish_function(MODEL_ID_1) + check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.') + + @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) + def test_rpc_error(self, publish_function): + create_recorder = instrument_ml_service( + status=400, payload=ERROR_RESPONSE_BAD_REQUEST) + with pytest.raises(Exception) as excinfo: + publish_function(MODEL_ID_1) + check_firebase_error( + excinfo, + ERROR_STATUS_BAD_REQUEST, + ERROR_CODE_BAD_REQUEST, + ERROR_MSG_BAD_REQUEST + ) + assert len(create_recorder) == 1 + + +class TestGetModel: + """Tests ml.get_model.""" + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + @staticmethod + def _url(project_id, model_id): + return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) + + def test_get_model(self): + recorder = instrument_ml_service(status=200, payload=DEFAULT_GET_RESPONSE) + model = ml.get_model(MODEL_ID_1) + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == TestGetModel._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + assert model == MODEL_1 + assert model.model_id == MODEL_ID_1 + assert model.display_name == DISPLAY_NAME_1 + + @pytest.mark.parametrize('model_id, exc_type', INVALID_MODEL_ID_ARGS) + def test_get_model_validation_errors(self, model_id, exc_type): + with pytest.raises(exc_type) as excinfo: + ml.get_model(model_id) + check_error(excinfo, exc_type) + + def test_get_model_error(self): + recorder = instrument_ml_service(status=404, payload=ERROR_RESPONSE_NOT_FOUND) + with pytest.raises(exceptions.NotFoundError) as excinfo: + ml.get_model(MODEL_ID_1) + check_firebase_error( + excinfo, + ERROR_STATUS_NOT_FOUND, + ERROR_CODE_NOT_FOUND, + ERROR_MSG_NOT_FOUND + ) + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == TestGetModel._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + + def test_no_project_id(self): + def evaluate(): + app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') + with pytest.raises(ValueError): + ml.get_model(MODEL_ID_1, app) + testutils.run_without_project_id(evaluate) + + +class TestDeleteModel: + """Tests ml.delete_model.""" + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + @staticmethod + def _url(project_id, model_id): + return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) + + def test_delete_model(self): + recorder = instrument_ml_service(status=200, payload=EMPTY_RESPONSE) + ml.delete_model(MODEL_ID_1) # no response for delete + assert len(recorder) == 1 + assert recorder[0].method == 'DELETE' + assert recorder[0].url == TestDeleteModel._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + + @pytest.mark.parametrize('model_id, exc_type', INVALID_MODEL_ID_ARGS) + def test_delete_model_validation_errors(self, model_id, exc_type): + with pytest.raises(exc_type) as excinfo: + ml.delete_model(model_id) + check_error(excinfo, exc_type) + + def test_delete_model_error(self): + recorder = instrument_ml_service(status=404, payload=ERROR_RESPONSE_NOT_FOUND) + with pytest.raises(exceptions.NotFoundError) as excinfo: + ml.delete_model(MODEL_ID_1) + check_firebase_error( + excinfo, + ERROR_STATUS_NOT_FOUND, + ERROR_CODE_NOT_FOUND, + ERROR_MSG_NOT_FOUND + ) + assert len(recorder) == 1 + assert recorder[0].method == 'DELETE' + assert recorder[0].url == self._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + + def test_no_project_id(self): + def evaluate(): + app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') + with pytest.raises(ValueError): + ml.delete_model(MODEL_ID_1, app) + testutils.run_without_project_id(evaluate) + + +class TestListModels: + """Tests ml.list_models.""" + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + @staticmethod + def _url(project_id): + return BASE_URL + 'projects/{0}/models'.format(project_id) + + @staticmethod + def _check_page(page, model_count): + assert isinstance(page, ml.ListModelsPage) + assert len(page.models) == model_count + for model in page.models: + assert isinstance(model, ml.Model) + + def test_list_models_no_args(self): + recorder = instrument_ml_service(status=200, payload=DEFAULT_LIST_RESPONSE) + models_page = ml.list_models() + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == TestListModels._url(PROJECT_ID) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + TestListModels._check_page(models_page, 2) + assert models_page.has_next_page + assert models_page.next_page_token == NEXT_PAGE_TOKEN + assert models_page.models[0] == MODEL_1 + assert models_page.models[1] == MODEL_2 + + def test_list_models_with_all_args(self): + recorder = instrument_ml_service(status=200, payload=LAST_PAGE_LIST_RESPONSE) + models_page = ml.list_models( + 'display_name=displayName3', + page_size=10, + page_token=PAGE_TOKEN) + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == ( + TestListModels._url(PROJECT_ID) + + '?filter=display_name%3DdisplayName3&page_size=10&page_token={0}' + .format(PAGE_TOKEN)) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + assert isinstance(models_page, ml.ListModelsPage) + assert len(models_page.models) == 1 + assert models_page.models[0] == MODEL_3 + assert not models_page.has_next_page + + @pytest.mark.parametrize('list_filter', INVALID_STRING_OR_NONE_ARGS) + def test_list_models_list_filter_validation(self, list_filter): + with pytest.raises(TypeError) as excinfo: + ml.list_models(list_filter=list_filter) + check_error(excinfo, TypeError, 'List filter must be a string or None.') + + @pytest.mark.parametrize('page_size, exc_type, error_message', [ + ('abc', TypeError, 'Page size must be a number or None.'), + (4.2, TypeError, 'Page size must be a number or None.'), + (list(), TypeError, 'Page size must be a number or None.'), + (dict(), TypeError, 'Page size must be a number or None.'), + (True, TypeError, 'Page size must be a number or None.'), + (-1, ValueError, PAGE_SIZE_VALUE_ERROR_MSG), + (0, ValueError, PAGE_SIZE_VALUE_ERROR_MSG), + (ml._MAX_PAGE_SIZE + 1, ValueError, PAGE_SIZE_VALUE_ERROR_MSG) + ]) + def test_list_models_page_size_validation(self, page_size, exc_type, error_message): + with pytest.raises(exc_type) as excinfo: + ml.list_models(page_size=page_size) + check_error(excinfo, exc_type, error_message) + + @pytest.mark.parametrize('page_token', INVALID_STRING_OR_NONE_ARGS) + def test_list_models_page_token_validation(self, page_token): + with pytest.raises(TypeError) as excinfo: + ml.list_models(page_token=page_token) + check_error(excinfo, TypeError, 'Page token must be a string or None.') + + def test_list_models_error(self): + recorder = instrument_ml_service(status=400, payload=ERROR_RESPONSE_BAD_REQUEST) + with pytest.raises(exceptions.InvalidArgumentError) as excinfo: + ml.list_models() + check_firebase_error( + excinfo, + ERROR_STATUS_BAD_REQUEST, + ERROR_CODE_BAD_REQUEST, + ERROR_MSG_BAD_REQUEST + ) + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == TestListModels._url(PROJECT_ID) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + + def test_no_project_id(self): + def evaluate(): + app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') + with pytest.raises(ValueError): + ml.list_models(app=app) + testutils.run_without_project_id(evaluate) + + def test_list_single_page(self): + recorder = instrument_ml_service(status=200, payload=LAST_PAGE_LIST_RESPONSE) + models_page = ml.list_models() + assert len(recorder) == 1 + assert models_page.next_page_token == '' + assert models_page.has_next_page is False + assert models_page.get_next_page() is None + models = [model for model in models_page.iterate_all()] + assert len(models) == 1 + + def test_list_multiple_pages(self): + # Page 1 + recorder = instrument_ml_service(status=200, payload=DEFAULT_LIST_RESPONSE) + page = ml.list_models() + assert len(recorder) == 1 + assert len(page.models) == 2 + assert page.next_page_token == NEXT_PAGE_TOKEN + assert page.has_next_page is True + + # Page 2 + recorder = instrument_ml_service(status=200, payload=LAST_PAGE_LIST_RESPONSE) + page_2 = page.get_next_page() + assert len(recorder) == 1 + assert len(page_2.models) == 1 + assert page_2.next_page_token == '' + assert page_2.has_next_page is False + assert page_2.get_next_page() is None + + def test_list_models_paged_iteration(self): + # Page 1 + recorder = instrument_ml_service(status=200, payload=DEFAULT_LIST_RESPONSE) + page = ml.list_models() + assert page.next_page_token == NEXT_PAGE_TOKEN + assert page.has_next_page is True + iterator = page.iterate_all() + for index in range(2): + model = next(iterator) + assert model.display_name == 'displayName{0}'.format(index+1) + assert len(recorder) == 1 + + # Page 2 + recorder = instrument_ml_service(status=200, payload=LAST_PAGE_LIST_RESPONSE) + model = next(iterator) + assert model.display_name == DISPLAY_NAME_3 + with pytest.raises(StopIteration): + next(iterator) + + def test_list_models_stop_iteration(self): + recorder = instrument_ml_service(status=200, payload=ONE_PAGE_LIST_RESPONSE) + page = ml.list_models() + assert len(recorder) == 1 + assert len(page.models) == 3 + iterator = page.iterate_all() + models = [model for model in iterator] + assert len(page.models) == 3 + with pytest.raises(StopIteration): + next(iterator) + assert len(models) == 3 + + def test_list_models_no_models(self): + recorder = instrument_ml_service(status=200, payload=NO_MODELS_LIST_RESPONSE) + page = ml.list_models() + assert len(recorder) == 1 + assert len(page.models) == 0 + models = [model for model in page.iterate_all()] + assert len(models) == 0 From 1f534473e98d1a22f56ed48cf06843b289c02196 Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Tue, 21 Apr 2020 15:57:10 -0400 Subject: [PATCH 065/226] [chore] Release 4.1.0 (#451) - Release 4.1.0 --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index d9a27bd92..bd19af68f 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '4.0.1' +__version__ = '4.1.0' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From ec78e6c80a5af313aedbf5f031c12d091b80dc2c Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Tue, 21 Apr 2020 16:37:18 -0400 Subject: [PATCH 066/226] Revert "[chore] Release 4.1.0 (#451)" (#452) This reverts commit 1f534473e98d1a22f56ed48cf06843b289c02196. --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index bd19af68f..d9a27bd92 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '4.1.0' +__version__ = '4.0.1' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 9a0b5aecdaad0eb12c02a630a939377e5b5f55c3 Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Tue, 21 Apr 2020 16:43:00 -0400 Subject: [PATCH 067/226] [chore] Release 4.1.0 --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index d9a27bd92..bd19af68f 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '4.0.1' +__version__ = '4.1.0' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From f2b4f19e233d199b6c0200b9696fbf15a73dd015 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Mon, 27 Apr 2020 11:28:59 -0700 Subject: [PATCH 068/226] feat(auth): Adding multi-tenancy and IdP management APIs (#450) * feat(auth): APIs for retrieving and deleting tenants (#422) * feat(auth): Added Tenant class and get_tenant() API * Added delete_tenant() API * Added delete_tenant to _all_ list * Fixing a lint error * Fixing a lint error * Added create_tenant() and update_tenant() APIs (#424) * Create tenant API * Added update tenant API * Added docstring to fix lint error * Added list_tenants() API (#429) * Added list_tenants() API * Update firebase_admin/tenant_mgt.py Co-Authored-By: Lahiru Maramba * Updated error message Co-authored-by: Lahiru Maramba * Moved all public auth APIs to _AuthClient (#430) * Tenant-scoped user management operations (#431) * Adding tenant_mgt.auth_for_tenant() API * Added more tenant-aware user mgt tests * Full test coverage for tenant-aware user mgt APIs * Updated docstring to fix lint error * Removed unused var; Fixing lint error * Tenant-aware ID token verification support (#432) * Tenant-aware ID token verification support * Extended InvalidArgumentError in TenantIdMismatchError * Fixing lint errors * Added tenant-scoped custom token support (#433) * Added tenant-scoped custom token support * Fixed a lint error; Improved test assertion * Renamed _AuthService to Client (#436) * Renamed _AuthService to Client * Renamed some local variables for consistency * Added documentation to Client APIs * Fixed doc lint error * feat(auth): Adding SAMLProviderConfig type and the getter method (#437) * feat(auth): Adding SAMLProviderConfig type and the getter method * Added ConfigurationNotFoundError type * Fixing a lint error related to super delegation * feat(auth): Added create and update APIs for SAMLProviderConfig (#440) * feat(auth): Added create_saml_provider_config() API * Added update_saml_provider_config() API * Moved auth.Client to a separate submodule * Moved auth.Client; Updated docs * feat(auth): Added delete and list APIs for SAMLProviderConfig (#441) * feat(auth): Added delete_saml_provider_config() API * Preliminary list provider config impl * Refactored the common paging logic into base classes * Added more tests for list API * feat(auth): Added OIDCProviderConfig type and get/delete APIs (#442) * feat(auth): Added OIDCProviderConfig type and get/delete APIs * Added newline to eof * OIDCProviderConfig create/update APIs (#443) * feat(auth): Added list_oidc_provider_configs() API (#444) * fix(auth): Integration tests for multi-tenancy and IdP management APIs (#446) * fix(auth): Integration tests for IdP management APIs * More integration tests for tenant_mgt module; Made display_name required for tenants * Integration tests for tenant-aware IdP management * Fixing lint error; Added unit test for UserRecord.tenant_id * Trigger staging * Added unit tests for tenant names longer than 20 chars * Updated API reference docs * fix(auth): Snippets for multi-tenancy and IdP management APIs (#455) Co-authored-by: Lahiru Maramba --- firebase_admin/_auth_client.py | 625 ++++++++++++ firebase_admin/_auth_providers.py | 390 ++++++++ firebase_admin/_auth_utils.py | 92 +- firebase_admin/_token_gen.py | 15 +- firebase_admin/_user_mgt.py | 147 ++- firebase_admin/auth.py | 435 ++++++--- firebase_admin/tenant_mgt.py | 445 +++++++++ integration/test_auth.py | 186 ++++ integration/test_tenant_mgt.py | 417 ++++++++ snippets/auth/index.py | 413 ++++++++ tests/data/list_oidc_provider_configs.json | 18 + tests/data/list_saml_provider_configs.json | 40 + tests/data/oidc_provider_config.json | 7 + tests/data/saml_provider_config.json | 18 + tests/test_auth_providers.py | 732 ++++++++++++++ tests/test_tenant_mgt.py | 1004 ++++++++++++++++++++ tests/test_token_gen.py | 46 +- tests/test_user_mgt.py | 58 +- 18 files changed, 4851 insertions(+), 237 deletions(-) create mode 100644 firebase_admin/_auth_client.py create mode 100644 firebase_admin/_auth_providers.py create mode 100644 firebase_admin/tenant_mgt.py create mode 100644 integration/test_tenant_mgt.py create mode 100644 tests/data/list_oidc_provider_configs.json create mode 100644 tests/data/list_saml_provider_configs.json create mode 100644 tests/data/oidc_provider_config.json create mode 100644 tests/data/saml_provider_config.json create mode 100644 tests/test_auth_providers.py create mode 100644 tests/test_tenant_mgt.py diff --git a/firebase_admin/_auth_client.py b/firebase_admin/_auth_client.py new file mode 100644 index 000000000..b7af6ddb6 --- /dev/null +++ b/firebase_admin/_auth_client.py @@ -0,0 +1,625 @@ +# Copyright 2020 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Firebase auth client sub module.""" + +import time + +import firebase_admin +from firebase_admin import _auth_providers +from firebase_admin import _auth_utils +from firebase_admin import _http_client +from firebase_admin import _token_gen +from firebase_admin import _user_import +from firebase_admin import _user_mgt + + +class Client: + """Firebase Authentication client scoped to a specific tenant.""" + + def __init__(self, app, tenant_id=None): + if not app.project_id: + raise ValueError("""A project ID is required to access the auth service. + 1. Use a service account credential, or + 2. set the project ID explicitly via Firebase App options, or + 3. set the project ID via the GOOGLE_CLOUD_PROJECT environment variable.""") + + credential = app.credential.get_credential() + version_header = 'Python/Admin/{0}'.format(firebase_admin.__version__) + http_client = _http_client.JsonHttpClient( + credential=credential, headers={'X-Client-Version': version_header}) + + self._tenant_id = tenant_id + self._token_generator = _token_gen.TokenGenerator(app, http_client) + self._token_verifier = _token_gen.TokenVerifier(app) + self._user_manager = _user_mgt.UserManager(http_client, app.project_id, tenant_id) + self._provider_manager = _auth_providers.ProviderConfigClient( + http_client, app.project_id, tenant_id) + + @property + def tenant_id(self): + """Tenant ID associated with this client.""" + return self._tenant_id + + def create_custom_token(self, uid, developer_claims=None): + """Builds and signs a Firebase custom auth token. + + Args: + uid: ID of the user for whom the token is created. + developer_claims: A dictionary of claims to be included in the token + (optional). + + Returns: + bytes: A token minted from the input parameters. + + Raises: + ValueError: If input parameters are invalid. + TokenSignError: If an error occurs while signing the token using the remote IAM service. + """ + return self._token_generator.create_custom_token( + uid, developer_claims, tenant_id=self.tenant_id) + + def verify_id_token(self, id_token, check_revoked=False): + """Verifies the signature and data for the provided JWT. + + Accepts a signed token string, verifies that it is current, was issued + to this project, and that it was correctly signed by Google. + + Args: + id_token: A string of the encoded JWT. + check_revoked: Boolean, If true, checks whether the token has been revoked (optional). + + Returns: + dict: A dictionary of key-value pairs parsed from the decoded JWT. + + Raises: + ValueError: If ``id_token`` is a not a string or is empty. + InvalidIdTokenError: If ``id_token`` is not a valid Firebase ID token. + ExpiredIdTokenError: If the specified ID token has expired. + RevokedIdTokenError: If ``check_revoked`` is ``True`` and the ID token has been + revoked. + TenantIdMismatchError: If ``id_token`` belongs to a tenant that is different than + this ``Client`` instance. + CertificateFetchError: If an error occurs while fetching the public key certificates + required to verify the ID token. + """ + if not isinstance(check_revoked, bool): + # guard against accidental wrong assignment. + raise ValueError('Illegal check_revoked argument. Argument must be of type ' + ' bool, but given "{0}".'.format(type(check_revoked))) + + verified_claims = self._token_verifier.verify_id_token(id_token) + if self.tenant_id: + token_tenant_id = verified_claims.get('firebase', {}).get('tenant') + if self.tenant_id != token_tenant_id: + raise _auth_utils.TenantIdMismatchError( + 'Invalid tenant ID: {0}'.format(token_tenant_id)) + + if check_revoked: + self._check_jwt_revoked(verified_claims, _token_gen.RevokedIdTokenError, 'ID token') + return verified_claims + + def revoke_refresh_tokens(self, uid): + """Revokes all refresh tokens for an existing user. + + This method updates the user's ``tokens_valid_after_timestamp`` to the current UTC + in seconds since the epoch. It is important that the server on which this is called has its + clock set correctly and synchronized. + + While this revokes all sessions for a specified user and disables any new ID tokens for + existing sessions from getting minted, existing ID tokens may remain active until their + natural expiration (one hour). To verify that ID tokens are revoked, use + ``verify_id_token(idToken, check_revoked=True)``. + + Args: + uid: A user ID string. + + Raises: + ValueError: If the user ID is None, empty or malformed. + FirebaseError: If an error occurs while revoking the refresh token. + """ + self._user_manager.update_user(uid, valid_since=int(time.time())) + + def get_user(self, uid): + """Gets the user data corresponding to the specified user ID. + + Args: + uid: A user ID string. + + Returns: + UserRecord: A user record instance. + + Raises: + ValueError: If the user ID is None, empty or malformed. + UserNotFoundError: If the specified user ID does not exist. + FirebaseError: If an error occurs while retrieving the user. + """ + response = self._user_manager.get_user(uid=uid) + return _user_mgt.UserRecord(response) + + def get_user_by_email(self, email): + """Gets the user data corresponding to the specified user email. + + Args: + email: A user email address string. + + Returns: + UserRecord: A user record instance. + + Raises: + ValueError: If the email is None, empty or malformed. + UserNotFoundError: If no user exists by the specified email address. + FirebaseError: If an error occurs while retrieving the user. + """ + response = self._user_manager.get_user(email=email) + return _user_mgt.UserRecord(response) + + def get_user_by_phone_number(self, phone_number): + """Gets the user data corresponding to the specified phone number. + + Args: + phone_number: A phone number string. + + Returns: + UserRecord: A user record instance. + + Raises: + ValueError: If the phone number is ``None``, empty or malformed. + UserNotFoundError: If no user exists by the specified phone number. + FirebaseError: If an error occurs while retrieving the user. + """ + response = self._user_manager.get_user(phone_number=phone_number) + return _user_mgt.UserRecord(response) + + def list_users(self, page_token=None, max_results=_user_mgt.MAX_LIST_USERS_RESULTS): + """Retrieves a page of user accounts from a Firebase project. + + The ``page_token`` argument governs the starting point of the page. The ``max_results`` + argument governs the maximum number of user accounts that may be included in the returned + page. This function never returns ``None``. If there are no user accounts in the Firebase + project, this returns an empty page. + + Args: + page_token: A non-empty page token string, which indicates the starting point of the + page (optional). Defaults to ``None``, which will retrieve the first page of users. + max_results: A positive integer indicating the maximum number of users to include in + the returned page (optional). Defaults to 1000, which is also the maximum number + allowed. + + Returns: + ListUsersPage: A page of user accounts. + + Raises: + ValueError: If max_results or page_token are invalid. + FirebaseError: If an error occurs while retrieving the user accounts. + """ + def download(page_token, max_results): + return self._user_manager.list_users(page_token, max_results) + return _user_mgt.ListUsersPage(download, page_token, max_results) + + def create_user(self, **kwargs): # pylint: disable=differing-param-doc + """Creates a new user account with the specified properties. + + Args: + kwargs: A series of keyword arguments (optional). + + Keyword Args: + uid: User ID to assign to the newly created user (optional). + display_name: The user's display name (optional). + email: The user's primary email (optional). + email_verified: A boolean indicating whether or not the user's primary email is + verified (optional). + phone_number: The user's primary phone number (optional). + photo_url: The user's photo URL (optional). + password: The user's raw, unhashed password. (optional). + disabled: A boolean indicating whether or not the user account is disabled (optional). + + Returns: + UserRecord: A UserRecord instance for the newly created user. + + Raises: + ValueError: If the specified user properties are invalid. + FirebaseError: If an error occurs while creating the user account. + """ + uid = self._user_manager.create_user(**kwargs) + return self.get_user(uid=uid) + + def update_user(self, uid, **kwargs): # pylint: disable=differing-param-doc + """Updates an existing user account with the specified properties. + + Args: + uid: A user ID string. + kwargs: A series of keyword arguments (optional). + + Keyword Args: + display_name: The user's display name (optional). Can be removed by explicitly passing + ``auth.DELETE_ATTRIBUTE``. + email: The user's primary email (optional). + email_verified: A boolean indicating whether or not the user's primary email is + verified (optional). + phone_number: The user's primary phone number (optional). Can be removed by explicitly + passing ``auth.DELETE_ATTRIBUTE``. + photo_url: The user's photo URL (optional). Can be removed by explicitly passing + ``auth.DELETE_ATTRIBUTE``. + password: The user's raw, unhashed password. (optional). + disabled: A boolean indicating whether or not the user account is disabled (optional). + custom_claims: A dictionary or a JSON string contining the custom claims to be set on + the user account (optional). To remove all custom claims, pass + ``auth.DELETE_ATTRIBUTE``. + valid_since: An integer signifying the seconds since the epoch (optional). This field + is set by ``revoke_refresh_tokens`` and it is discouraged to set this field + directly. + + Returns: + UserRecord: An updated UserRecord instance for the user. + + Raises: + ValueError: If the specified user ID or properties are invalid. + FirebaseError: If an error occurs while updating the user account. + """ + self._user_manager.update_user(uid, **kwargs) + return self.get_user(uid=uid) + + def set_custom_user_claims(self, uid, custom_claims): + """Sets additional claims on an existing user account. + + Custom claims set via this function can be used to define user roles and privilege levels. + These claims propagate to all the devices where the user is already signed in (after token + expiration or when token refresh is forced), and next time the user signs in. The claims + can be accessed via the user's ID token JWT. If a reserved OIDC claim is specified (sub, + iat, iss, etc), an error is thrown. Claims payload must also not be larger then 1000 + characters when serialized into a JSON string. + + Args: + uid: A user ID string. + custom_claims: A dictionary or a JSON string of custom claims. Pass None to unset any + claims set previously. + + Raises: + ValueError: If the specified user ID or the custom claims are invalid. + FirebaseError: If an error occurs while updating the user account. + """ + if custom_claims is None: + custom_claims = _user_mgt.DELETE_ATTRIBUTE + self._user_manager.update_user(uid, custom_claims=custom_claims) + + def delete_user(self, uid): + """Deletes the user identified by the specified user ID. + + Args: + uid: A user ID string. + + Raises: + ValueError: If the user ID is None, empty or malformed. + FirebaseError: If an error occurs while deleting the user account. + """ + self._user_manager.delete_user(uid) + + def import_users(self, users, hash_alg=None): + """Imports the specified list of users into Firebase Auth. + + At most 1000 users can be imported at a time. This operation is optimized for bulk imports + and ignores checks on identifier uniqueness, which could result in duplications. The + ``hash_alg`` parameter must be specified when importing users with passwords. Refer to the + ``UserImportHash`` class for supported hash algorithms. + + Args: + users: A list of ``ImportUserRecord`` instances to import. Length of the list must not + exceed 1000. + hash_alg: A ``UserImportHash`` object (optional). Required when importing users with + passwords. + + Returns: + UserImportResult: An object summarizing the result of the import operation. + + Raises: + ValueError: If the provided arguments are invalid. + FirebaseError: If an error occurs while importing users. + """ + result = self._user_manager.import_users(users, hash_alg) + return _user_import.UserImportResult(result, len(users)) + + def generate_password_reset_link(self, email, action_code_settings=None): + """Generates the out-of-band email action link for password reset flows for the specified + email address. + + Args: + email: The email of the user whose password is to be reset. + action_code_settings: ``ActionCodeSettings`` instance (optional). Defines whether + the link is to be handled by a mobile app and the additional state information to + be passed in the deep link. + + Returns: + link: The password reset link created by the API + + Raises: + ValueError: If the provided arguments are invalid + FirebaseError: If an error occurs while generating the link + """ + return self._user_manager.generate_email_action_link( + 'PASSWORD_RESET', email, action_code_settings=action_code_settings) + + def generate_email_verification_link(self, email, action_code_settings=None): + """Generates the out-of-band email action link for email verification flows for the + specified email address. + + Args: + email: The email of the user to be verified. + action_code_settings: ``ActionCodeSettings`` instance (optional). Defines whether + the link is to be handled by a mobile app and the additional state information to + be passed in the deep link. + + Returns: + link: The email verification link created by the API + + Raises: + ValueError: If the provided arguments are invalid + FirebaseError: If an error occurs while generating the link + """ + return self._user_manager.generate_email_action_link( + 'VERIFY_EMAIL', email, action_code_settings=action_code_settings) + + def generate_sign_in_with_email_link(self, email, action_code_settings): + """Generates the out-of-band email action link for email link sign-in flows, using the + action code settings provided. + + Args: + email: The email of the user signing in. + action_code_settings: ``ActionCodeSettings`` instance. Defines whether + the link is to be handled by a mobile app and the additional state information to be + passed in the deep link. + + Returns: + link: The email sign-in link created by the API + + Raises: + ValueError: If the provided arguments are invalid + FirebaseError: If an error occurs while generating the link + """ + return self._user_manager.generate_email_action_link( + 'EMAIL_SIGNIN', email, action_code_settings=action_code_settings) + + def get_oidc_provider_config(self, provider_id): + """Returns the ``OIDCProviderConfig`` with the given ID. + + Args: + provider_id: Provider ID string. + + Returns: + SAMLProviderConfig: An OIDC provider config instance. + + Raises: + ValueError: If the provider ID is invalid, empty or does not have ``oidc.`` prefix. + ConfigurationNotFoundError: If no OIDC provider is available with the given identifier. + FirebaseError: If an error occurs while retrieving the OIDC provider. + """ + return self._provider_manager.get_oidc_provider_config(provider_id) + + def create_oidc_provider_config( + self, provider_id, client_id, issuer, display_name=None, enabled=None): + """Creates a new OIDC provider config from the given parameters. + + OIDC provider support requires Google Cloud's Identity Platform (GCIP). To learn more about + GCIP, including pricing and features, see https://cloud.google.com/identity-platform. + + Args: + provider_id: Provider ID string. Must have the prefix ``oidc.``. + client_id: Client ID of the new config. + issuer: Issuer of the new config. Must be a valid URL. + display_name: The user-friendly display name to the current configuration (optional). + This name is also used as the provider label in the Cloud Console. + enabled: A boolean indicating whether the provider configuration is enabled or disabled + (optional). A user cannot sign in using a disabled provider. + + Returns: + OIDCProviderConfig: The newly created OIDC provider config instance. + + Raises: + ValueError: If any of the specified input parameters are invalid. + FirebaseError: If an error occurs while creating the new OIDC provider config. + """ + return self._provider_manager.create_oidc_provider_config( + provider_id, client_id=client_id, issuer=issuer, display_name=display_name, + enabled=enabled) + + def update_oidc_provider_config( + self, provider_id, client_id=None, issuer=None, display_name=None, enabled=None): + """Updates an existing OIDC provider config with the given parameters. + + Args: + provider_id: Provider ID string. Must have the prefix ``oidc.``. + client_id: Client ID of the new config (optional). + issuer: Issuer of the new config (optional). Must be a valid URL. + display_name: The user-friendly display name to the current configuration (optional). + Pass ``auth.DELETE_ATTRIBUTE`` to delete the current display name. + enabled: A boolean indicating whether the provider configuration is enabled or disabled + (optional). + + Returns: + OIDCProviderConfig: The updated OIDC provider config instance. + + Raises: + ValueError: If any of the specified input parameters are invalid. + FirebaseError: If an error occurs while updating the OIDC provider config. + """ + return self._provider_manager.update_oidc_provider_config( + provider_id, client_id=client_id, issuer=issuer, display_name=display_name, + enabled=enabled) + + def delete_oidc_provider_config(self, provider_id): + """Deletes the ``OIDCProviderConfig`` with the given ID. + + Args: + provider_id: Provider ID string. + + Raises: + ValueError: If the provider ID is invalid, empty or does not have ``oidc.`` prefix. + ConfigurationNotFoundError: If no OIDC provider is available with the given identifier. + FirebaseError: If an error occurs while deleting the OIDC provider. + """ + self._provider_manager.delete_oidc_provider_config(provider_id) + + def list_oidc_provider_configs( + self, page_token=None, max_results=_auth_providers.MAX_LIST_CONFIGS_RESULTS): + """Retrieves a page of OIDC provider configs from a Firebase project. + + The ``page_token`` argument governs the starting point of the page. The ``max_results`` + argument governs the maximum number of configs that may be included in the returned + page. This function never returns ``None``. If there are no OIDC configs in the Firebase + project, this returns an empty page. + + Args: + page_token: A non-empty page token string, which indicates the starting point of the + page (optional). Defaults to ``None``, which will retrieve the first page of users. + max_results: A positive integer indicating the maximum number of users to include in + the returned page (optional). Defaults to 100, which is also the maximum number + allowed. + + Returns: + ListProviderConfigsPage: A page of OIDC provider config instances. + + Raises: + ValueError: If ``max_results`` or ``page_token`` are invalid. + FirebaseError: If an error occurs while retrieving the OIDC provider configs. + """ + return self._provider_manager.list_oidc_provider_configs(page_token, max_results) + + def get_saml_provider_config(self, provider_id): + """Returns the ``SAMLProviderConfig`` with the given ID. + + Args: + provider_id: Provider ID string. + + Returns: + SAMLProviderConfig: A SAML provider config instance. + + Raises: + ValueError: If the provider ID is invalid, empty or does not have ``saml.`` prefix. + ConfigurationNotFoundError: If no SAML provider is available with the given identifier. + FirebaseError: If an error occurs while retrieving the SAML provider. + """ + return self._provider_manager.get_saml_provider_config(provider_id) + + def create_saml_provider_config( + self, provider_id, idp_entity_id, sso_url, x509_certificates, rp_entity_id, + callback_url, display_name=None, enabled=None): + """Creates a new SAML provider config from the given parameters. + + SAML provider support requires Google Cloud's Identity Platform (GCIP). To learn more about + GCIP, including pricing and features, see https://cloud.google.com/identity-platform. + + Args: + provider_id: Provider ID string. Must have the prefix ``saml.``. + idp_entity_id: The SAML IdP entity identifier. + sso_url: The SAML IdP SSO URL. Must be a valid URL. + x509_certificates: The list of SAML IdP X.509 certificates issued by CA for this + provider. Multiple certificates are accepted to prevent outages during IdP key + rotation (for example ADFS rotates every 10 days). When the Auth server receives a + SAML response, it will match the SAML response with the certificate on record. + Otherwise the response is rejected. Developers are expected to manage the + certificate updates as keys are rotated. + rp_entity_id: The SAML relying party (service provider) entity ID. This is defined by + the developer but needs to be provided to the SAML IdP. + callback_url: Callback URL string. This is fixed and must always be the same as the + OAuth redirect URL provisioned by Firebase Auth, unless a custom authDomain is + used. + display_name: The user-friendly display name to the current configuration (optional). + This name is also used as the provider label in the Cloud Console. + enabled: A boolean indicating whether the provider configuration is enabled or disabled + (optional). A user cannot sign in using a disabled provider. + + Returns: + SAMLProviderConfig: The newly created SAML provider config instance. + + Raises: + ValueError: If any of the specified input parameters are invalid. + FirebaseError: If an error occurs while creating the new SAML provider config. + """ + return self._provider_manager.create_saml_provider_config( + provider_id, idp_entity_id=idp_entity_id, sso_url=sso_url, + x509_certificates=x509_certificates, rp_entity_id=rp_entity_id, + callback_url=callback_url, display_name=display_name, enabled=enabled) + + def update_saml_provider_config( + self, provider_id, idp_entity_id=None, sso_url=None, x509_certificates=None, + rp_entity_id=None, callback_url=None, display_name=None, enabled=None): + """Updates an existing SAML provider config with the given parameters. + + Args: + provider_id: Provider ID string. Must have the prefix ``saml.``. + idp_entity_id: The SAML IdP entity identifier (optional). + sso_url: The SAML IdP SSO URL. Must be a valid URL (optional). + x509_certificates: The list of SAML IdP X.509 certificates issued by CA for this + provider (optional). + rp_entity_id: The SAML relying party entity ID (optional). + callback_url: Callback URL string (optional). + display_name: The user-friendly display name of the current configuration (optional). + Pass ``auth.DELETE_ATTRIBUTE`` to delete the current display name. + enabled: A boolean indicating whether the provider configuration is enabled or disabled + (optional). + + Returns: + SAMLProviderConfig: The updated SAML provider config instance. + + Raises: + ValueError: If any of the specified input parameters are invalid. + FirebaseError: If an error occurs while updating the SAML provider config. + """ + return self._provider_manager.update_saml_provider_config( + provider_id, idp_entity_id=idp_entity_id, sso_url=sso_url, + x509_certificates=x509_certificates, rp_entity_id=rp_entity_id, + callback_url=callback_url, display_name=display_name, enabled=enabled) + + def delete_saml_provider_config(self, provider_id): + """Deletes the ``SAMLProviderConfig`` with the given ID. + + Args: + provider_id: Provider ID string. + + Raises: + ValueError: If the provider ID is invalid, empty or does not have ``saml.`` prefix. + ConfigurationNotFoundError: If no SAML provider is available with the given identifier. + FirebaseError: If an error occurs while deleting the SAML provider. + """ + self._provider_manager.delete_saml_provider_config(provider_id) + + def list_saml_provider_configs( + self, page_token=None, max_results=_auth_providers.MAX_LIST_CONFIGS_RESULTS): + """Retrieves a page of SAML provider configs from a Firebase project. + + The ``page_token`` argument governs the starting point of the page. The ``max_results`` + argument governs the maximum number of configs that may be included in the returned + page. This function never returns ``None``. If there are no SAML configs in the Firebase + project, this returns an empty page. + + Args: + page_token: A non-empty page token string, which indicates the starting point of the + page (optional). Defaults to ``None``, which will retrieve the first page of users. + max_results: A positive integer indicating the maximum number of users to include in + the returned page (optional). Defaults to 100, which is also the maximum number + allowed. + + Returns: + ListProviderConfigsPage: A page of SAML provider config instances. + + Raises: + ValueError: If ``max_results`` or ``page_token`` are invalid. + FirebaseError: If an error occurs while retrieving the SAML provider configs. + """ + return self._provider_manager.list_saml_provider_configs(page_token, max_results) + + def _check_jwt_revoked(self, verified_claims, exc_type, label): + user = self.get_user(verified_claims.get('uid')) + if verified_claims.get('iat') * 1000 < user.tokens_valid_after_timestamp: + raise exc_type('The Firebase {0} has been revoked.'.format(label)) diff --git a/firebase_admin/_auth_providers.py b/firebase_admin/_auth_providers.py new file mode 100644 index 000000000..96f1b5348 --- /dev/null +++ b/firebase_admin/_auth_providers.py @@ -0,0 +1,390 @@ +# Copyright 2020 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Firebase auth providers management sub module.""" + +from urllib import parse + +import requests + +from firebase_admin import _auth_utils +from firebase_admin import _user_mgt + + +MAX_LIST_CONFIGS_RESULTS = 100 + + +class ProviderConfig: + """Parent type for all authentication provider config types.""" + + def __init__(self, data): + self._data = data + + @property + def provider_id(self): + name = self._data['name'] + return name.split('/')[-1] + + @property + def display_name(self): + return self._data.get('displayName') + + @property + def enabled(self): + return self._data.get('enabled', False) + + +class OIDCProviderConfig(ProviderConfig): + """Represents the OIDC auth provider configuration. + + See https://openid.net/specs/openid-connect-core-1_0-final.html. + """ + + @property + def issuer(self): + return self._data['issuer'] + + @property + def client_id(self): + return self._data['clientId'] + + +class SAMLProviderConfig(ProviderConfig): + """Represents he SAML auth provider configuration. + + See http://docs.oasis-open.org/security/saml/Post2.0/sstc-saml-tech-overview-2.0.html. + """ + + @property + def idp_entity_id(self): + return self._data.get('idpConfig', {})['idpEntityId'] + + @property + def sso_url(self): + return self._data.get('idpConfig', {})['ssoUrl'] + + @property + def x509_certificates(self): + certs = self._data.get('idpConfig', {})['idpCertificates'] + return [c['x509Certificate'] for c in certs] + + @property + def callback_url(self): + return self._data.get('spConfig', {})['callbackUri'] + + @property + def rp_entity_id(self): + return self._data.get('spConfig', {})['spEntityId'] + + +class ListProviderConfigsPage: + """Represents a page of AuthProviderConfig instances retrieved from a Firebase project. + + Provides methods for traversing the provider configs included in this page, as well as + retrieving subsequent pages. The iterator returned by ``iterate_all()`` can be used to iterate + through all provider configs in the Firebase project starting from this page. + """ + + def __init__(self, download, page_token, max_results): + self._download = download + self._max_results = max_results + self._current = download(page_token, max_results) + + @property + def provider_configs(self): + """A list of ``AuthProviderConfig`` instances available in this page.""" + raise NotImplementedError + + @property + def next_page_token(self): + """Page token string for the next page (empty string indicates no more pages).""" + return self._current.get('nextPageToken', '') + + @property + def has_next_page(self): + """A boolean indicating whether more pages are available.""" + return bool(self.next_page_token) + + def get_next_page(self): + """Retrieves the next page of provider configs, if available. + + Returns: + ListProviderConfigsPage: Next page of provider configs, or None if this is the last + page. + """ + if self.has_next_page: + return self.__class__(self._download, self.next_page_token, self._max_results) + return None + + def iterate_all(self): + """Retrieves an iterator for provider configs. + + Returned iterator will iterate through all the provider configs in the Firebase project + starting from this page. The iterator will never buffer more than one page of configs + in memory at a time. + + Returns: + iterator: An iterator of AuthProviderConfig instances. + """ + return _ProviderConfigIterator(self) + + +class _ListOIDCProviderConfigsPage(ListProviderConfigsPage): + + @property + def provider_configs(self): + return [OIDCProviderConfig(data) for data in self._current.get('oauthIdpConfigs', [])] + + +class _ListSAMLProviderConfigsPage(ListProviderConfigsPage): + + @property + def provider_configs(self): + return [SAMLProviderConfig(data) for data in self._current.get('inboundSamlConfigs', [])] + + +class _ProviderConfigIterator(_auth_utils.PageIterator): + + @property + def items(self): + return self._current_page.provider_configs + + +class ProviderConfigClient: + """Client for managing Auth provider configurations.""" + + PROVIDER_CONFIG_URL = 'https://identitytoolkit.googleapis.com/v2beta1' + + def __init__(self, http_client, project_id, tenant_id=None): + self.http_client = http_client + self.base_url = '{0}/projects/{1}'.format(self.PROVIDER_CONFIG_URL, project_id) + if tenant_id: + self.base_url += '/tenants/{0}'.format(tenant_id) + + def get_oidc_provider_config(self, provider_id): + _validate_oidc_provider_id(provider_id) + body = self._make_request('get', '/oauthIdpConfigs/{0}'.format(provider_id)) + return OIDCProviderConfig(body) + + def create_oidc_provider_config( + self, provider_id, client_id, issuer, display_name=None, enabled=None): + """Creates a new OIDC provider config from the given parameters.""" + _validate_oidc_provider_id(provider_id) + req = { + 'clientId': _validate_non_empty_string(client_id, 'client_id'), + 'issuer': _validate_url(issuer, 'issuer'), + } + if display_name is not None: + req['displayName'] = _auth_utils.validate_string(display_name, 'display_name') + if enabled is not None: + req['enabled'] = _auth_utils.validate_boolean(enabled, 'enabled') + + params = 'oauthIdpConfigId={0}'.format(provider_id) + body = self._make_request('post', '/oauthIdpConfigs', json=req, params=params) + return OIDCProviderConfig(body) + + def update_oidc_provider_config( + self, provider_id, client_id=None, issuer=None, display_name=None, enabled=None): + """Updates an existing OIDC provider config with the given parameters.""" + _validate_oidc_provider_id(provider_id) + req = {} + if display_name is not None: + if display_name == _user_mgt.DELETE_ATTRIBUTE: + req['displayName'] = None + else: + req['displayName'] = _auth_utils.validate_string(display_name, 'display_name') + if enabled is not None: + req['enabled'] = _auth_utils.validate_boolean(enabled, 'enabled') + if client_id: + req['clientId'] = _validate_non_empty_string(client_id, 'client_id') + if issuer: + req['issuer'] = _validate_url(issuer, 'issuer') + + if not req: + raise ValueError('At least one parameter must be specified for update.') + + update_mask = _auth_utils.build_update_mask(req) + params = 'updateMask={0}'.format(','.join(update_mask)) + url = '/oauthIdpConfigs/{0}'.format(provider_id) + body = self._make_request('patch', url, json=req, params=params) + return OIDCProviderConfig(body) + + def delete_oidc_provider_config(self, provider_id): + _validate_oidc_provider_id(provider_id) + self._make_request('delete', '/oauthIdpConfigs/{0}'.format(provider_id)) + + def list_oidc_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): + return _ListOIDCProviderConfigsPage( + self._fetch_oidc_provider_configs, page_token, max_results) + + def _fetch_oidc_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): + return self._fetch_provider_configs('/oauthIdpConfigs', page_token, max_results) + + def get_saml_provider_config(self, provider_id): + _validate_saml_provider_id(provider_id) + body = self._make_request('get', '/inboundSamlConfigs/{0}'.format(provider_id)) + return SAMLProviderConfig(body) + + def create_saml_provider_config( + self, provider_id, idp_entity_id, sso_url, x509_certificates, + rp_entity_id, callback_url, display_name=None, enabled=None): + """Creates a new SAML provider config from the given parameters.""" + _validate_saml_provider_id(provider_id) + req = { + 'idpConfig': { + 'idpEntityId': _validate_non_empty_string(idp_entity_id, 'idp_entity_id'), + 'ssoUrl': _validate_url(sso_url, 'sso_url'), + 'idpCertificates': _validate_x509_certificates(x509_certificates), + }, + 'spConfig': { + 'spEntityId': _validate_non_empty_string(rp_entity_id, 'rp_entity_id'), + 'callbackUri': _validate_url(callback_url, 'callback_url'), + }, + } + if display_name is not None: + req['displayName'] = _auth_utils.validate_string(display_name, 'display_name') + if enabled is not None: + req['enabled'] = _auth_utils.validate_boolean(enabled, 'enabled') + + params = 'inboundSamlConfigId={0}'.format(provider_id) + body = self._make_request('post', '/inboundSamlConfigs', json=req, params=params) + return SAMLProviderConfig(body) + + def update_saml_provider_config( + self, provider_id, idp_entity_id=None, sso_url=None, x509_certificates=None, + rp_entity_id=None, callback_url=None, display_name=None, enabled=None): + """Updates an existing SAML provider config with the given parameters.""" + _validate_saml_provider_id(provider_id) + idp_config = {} + if idp_entity_id is not None: + idp_config['idpEntityId'] = _validate_non_empty_string(idp_entity_id, 'idp_entity_id') + if sso_url is not None: + idp_config['ssoUrl'] = _validate_url(sso_url, 'sso_url') + if x509_certificates is not None: + idp_config['idpCertificates'] = _validate_x509_certificates(x509_certificates) + + sp_config = {} + if rp_entity_id is not None: + sp_config['spEntityId'] = _validate_non_empty_string(rp_entity_id, 'rp_entity_id') + if callback_url is not None: + sp_config['callbackUri'] = _validate_url(callback_url, 'callback_url') + + req = {} + if display_name is not None: + if display_name == _user_mgt.DELETE_ATTRIBUTE: + req['displayName'] = None + else: + req['displayName'] = _auth_utils.validate_string(display_name, 'display_name') + if enabled is not None: + req['enabled'] = _auth_utils.validate_boolean(enabled, 'enabled') + if idp_config: + req['idpConfig'] = idp_config + if sp_config: + req['spConfig'] = sp_config + + if not req: + raise ValueError('At least one parameter must be specified for update.') + + update_mask = _auth_utils.build_update_mask(req) + params = 'updateMask={0}'.format(','.join(update_mask)) + url = '/inboundSamlConfigs/{0}'.format(provider_id) + body = self._make_request('patch', url, json=req, params=params) + return SAMLProviderConfig(body) + + def delete_saml_provider_config(self, provider_id): + _validate_saml_provider_id(provider_id) + self._make_request('delete', '/inboundSamlConfigs/{0}'.format(provider_id)) + + def list_saml_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): + return _ListSAMLProviderConfigsPage( + self._fetch_saml_provider_configs, page_token, max_results) + + def _fetch_saml_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): + return self._fetch_provider_configs('/inboundSamlConfigs', page_token, max_results) + + def _fetch_provider_configs(self, path, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): + """Fetches a page of auth provider configs""" + if page_token is not None: + if not isinstance(page_token, str) or not page_token: + raise ValueError('Page token must be a non-empty string.') + if not isinstance(max_results, int): + raise ValueError('Max results must be an integer.') + if max_results < 1 or max_results > MAX_LIST_CONFIGS_RESULTS: + raise ValueError( + 'Max results must be a positive integer less than or equal to ' + '{0}.'.format(MAX_LIST_CONFIGS_RESULTS)) + + params = 'pageSize={0}'.format(max_results) + if page_token: + params += '&pageToken={0}'.format(page_token) + return self._make_request('get', path, params=params) + + def _make_request(self, method, path, **kwargs): + url = '{0}{1}'.format(self.base_url, path) + try: + return self.http_client.body(method, url, **kwargs) + except requests.exceptions.RequestException as error: + raise _auth_utils.handle_auth_backend_error(error) + + +def _validate_oidc_provider_id(provider_id): + if not isinstance(provider_id, str): + raise ValueError( + 'Invalid OIDC provider ID: {0}. Provider ID must be a non-empty string.'.format( + provider_id)) + if not provider_id.startswith('oidc.'): + raise ValueError('Invalid OIDC provider ID: {0}.'.format(provider_id)) + return provider_id + + +def _validate_saml_provider_id(provider_id): + if not isinstance(provider_id, str): + raise ValueError( + 'Invalid SAML provider ID: {0}. Provider ID must be a non-empty string.'.format( + provider_id)) + if not provider_id.startswith('saml.'): + raise ValueError('Invalid SAML provider ID: {0}.'.format(provider_id)) + return provider_id + + +def _validate_non_empty_string(value, label): + """Validates that the given value is a non-empty string.""" + if not isinstance(value, str): + raise ValueError('Invalid type for {0}: {1}.'.format(label, value)) + if not value: + raise ValueError('{0} must not be empty.'.format(label)) + return value + + +def _validate_url(url, label): + """Validates that the given value is a well-formed URL string.""" + if not isinstance(url, str) or not url: + raise ValueError( + 'Invalid photo URL: "{0}". {1} must be a non-empty ' + 'string.'.format(url, label)) + try: + parsed = parse.urlparse(url) + if not parsed.netloc: + raise ValueError('Malformed {0}: "{1}".'.format(label, url)) + return url + except Exception: + raise ValueError('Malformed {0}: "{1}".'.format(label, url)) + + +def _validate_x509_certificates(x509_certificates): + if not isinstance(x509_certificates, list) or not x509_certificates: + raise ValueError('x509_certificates must be a non-empty list.') + if not all([isinstance(cert, str) and cert for cert in x509_certificates]): + raise ValueError('x509_certificates must only contain non-empty strings.') + return [{'x509Certificate': cert} for cert in x509_certificates] diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index 2f7383c0b..f1ce97dee 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -30,6 +30,42 @@ VALID_EMAIL_ACTION_TYPES = set(['VERIFY_EMAIL', 'EMAIL_SIGNIN', 'PASSWORD_RESET']) +class PageIterator: + """An iterator that allows iterating over a sequence of items, one at a time. + + This implementation loads a page of items into memory, and iterates on them. When the whole + page has been traversed, it loads another page. This class never keeps more than one page + of entries in memory. + """ + + def __init__(self, current_page): + if not current_page: + raise ValueError('Current page must not be None.') + self._current_page = current_page + self._index = 0 + + def next(self): + if self._index == len(self.items): + if self._current_page.has_next_page: + self._current_page = self._current_page.get_next_page() + self._index = 0 + if self._index < len(self.items): + result = self.items[self._index] + self._index += 1 + return result + raise StopIteration + + @property + def items(self): + raise NotImplementedError + + def __next__(self): + return self.next() + + def __iter__(self): + return self + + def validate_uid(uid, required=False): if uid is None and not required: return None @@ -157,6 +193,18 @@ def validate_int(value, label, low=None, high=None): raise ValueError('{0} must not be larger than {1}.'.format(label, high)) return val_int +def validate_string(value, label): + """Validates that the given value is a string.""" + if not isinstance(value, str): + raise ValueError('Invalid type for {0}: {1}.'.format(label, value)) + return value + +def validate_boolean(value, label): + """Validates that the given value is a boolean.""" + if not isinstance(value, bool): + raise ValueError('Invalid type for {0}: {1}.'.format(label, value)) + return value + def validate_custom_claims(custom_claims, required=False): """Validates the specified custom claims. @@ -192,6 +240,19 @@ def validate_action_type(action_type): Valid values are {1}'.format(action_type, ', '.join(VALID_EMAIL_ACTION_TYPES))) return action_type +def build_update_mask(params): + """Creates an update mask list from the given dictionary.""" + mask = [] + for key, value in params.items(): + if isinstance(value, dict): + child_mask = build_update_mask(value) + for child in child_mask: + mask.append('{0}.{1}'.format(key, child)) + else: + mask.append(key) + + return sorted(mask) + class UidAlreadyExistsError(exceptions.AlreadyExistsError): """The user with the provided uid already exists.""" @@ -266,7 +327,33 @@ def __init__(self, message, cause=None, http_response=None): exceptions.NotFoundError.__init__(self, message, cause, http_response) +class TenantNotFoundError(exceptions.NotFoundError): + """No tenant found for the specified identifier.""" + + default_message = 'No tenant found for the given identifier' + + def __init__(self, message, cause=None, http_response=None): + exceptions.NotFoundError.__init__(self, message, cause, http_response) + + +class TenantIdMismatchError(exceptions.InvalidArgumentError): + """Missing or invalid tenant ID field in the given JWT.""" + + def __init__(self, message): + exceptions.InvalidArgumentError.__init__(self, message) + + +class ConfigurationNotFoundError(exceptions.NotFoundError): + """No auth provider found for the specified identifier.""" + + default_message = 'No auth provider found for the given identifier' + + def __init__(self, message, cause=None, http_response=None): + exceptions.NotFoundError.__init__(self, message, cause, http_response) + + _CODE_TO_EXC_TYPE = { + 'CONFIGURATION_NOT_FOUND': ConfigurationNotFoundError, 'DUPLICATE_EMAIL': EmailAlreadyExistsError, 'DUPLICATE_LOCAL_ID': UidAlreadyExistsError, 'EMAIL_EXISTS': EmailAlreadyExistsError, @@ -274,6 +361,7 @@ def __init__(self, message, cause=None, http_response=None): 'INVALID_DYNAMIC_LINK_DOMAIN': InvalidDynamicLinkDomainError, 'INVALID_ID_TOKEN': InvalidIdTokenError, 'PHONE_NUMBER_EXISTS': PhoneNumberAlreadyExistsError, + 'TENANT_NOT_FOUND': TenantNotFoundError, 'USER_NOT_FOUND': UserNotFoundError, } @@ -281,12 +369,12 @@ def __init__(self, message, cause=None, http_response=None): def handle_auth_backend_error(error): """Converts a requests error received from the Firebase Auth service into a FirebaseError.""" if error.response is None: - raise _utils.handle_requests_error(error) + return _utils.handle_requests_error(error) code, custom_message = _parse_error_body(error.response) if not code: msg = 'Unexpected error response: {0}'.format(error.response.content.decode()) - raise _utils.handle_requests_error(error, message=msg) + return _utils.handle_requests_error(error, message=msg) exc_type = _CODE_TO_EXC_TYPE.get(code) msg = _build_error_message(code, exc_type, custom_message) diff --git a/firebase_admin/_token_gen.py b/firebase_admin/_token_gen.py index 4234bfa7b..18a8008c7 100644 --- a/firebase_admin/_token_gen.py +++ b/firebase_admin/_token_gen.py @@ -82,10 +82,13 @@ def from_iam(cls, request, google_cred, service_account): class TokenGenerator: """Generates custom tokens and session cookies.""" - def __init__(self, app, client): + ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v1' + + def __init__(self, app, http_client): self.app = app - self.client = client + self.http_client = http_client self.request = transport.requests.Request() + self.base_url = '{0}/projects/{1}'.format(self.ID_TOOLKIT_URL, app.project_id) self._signing_provider = None def _init_signing_provider(self): @@ -130,7 +133,7 @@ def signing_provider(self): 'details on creating custom tokens.'.format(error, url)) return self._signing_provider - def create_custom_token(self, uid, developer_claims=None): + def create_custom_token(self, uid, developer_claims=None, tenant_id=None): """Builds and signs a Firebase custom auth token.""" if developer_claims is not None: if not isinstance(developer_claims, dict): @@ -161,6 +164,8 @@ def create_custom_token(self, uid, developer_claims=None): 'iat': now, 'exp': now + MAX_TOKEN_LIFETIME_SECONDS, } + if tenant_id: + payload['tenant_id'] = tenant_id if developer_claims is not None: payload['claims'] = developer_claims @@ -190,13 +195,13 @@ def create_session_cookie(self, id_token, expires_in): raise ValueError('Illegal expiry duration: {0}. Duration must be at most {1} ' 'seconds.'.format(expires_in, MAX_SESSION_COOKIE_DURATION_SECONDS)) + url = '{0}:createSessionCookie'.format(self.base_url) payload = { 'idToken': id_token, 'validDuration': expires_in, } try: - body, http_resp = self.client.body_and_response( - 'post', ':createSessionCookie', json=payload) + body, http_resp = self.http_client.body_and_response('post', url, json=payload) except requests.exceptions.RequestException as error: raise _auth_utils.handle_auth_backend_error(error) else: diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index 533259e70..0b0c5ddb6 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -244,6 +244,15 @@ def custom_claims(self): return parsed return None + @property + def tenant_id(self): + """Returns the tenant ID of this user. + + Returns: + string: A tenant ID string or None. + """ + return self._data.get('tenantId') + class ExportedUserRecord(UserRecord): """Contains metadata associated with a user including password hash and salt.""" @@ -454,8 +463,13 @@ def encode_action_code_settings(settings): class UserManager: """Provides methods for interacting with the Google Identity Toolkit.""" - def __init__(self, client): - self._client = client + ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v1' + + def __init__(self, http_client, project_id, tenant_id=None): + self.http_client = http_client + self.base_url = '{0}/projects/{1}'.format(self.ID_TOOLKIT_URL, project_id) + if tenant_id: + self.base_url += '/tenants/{0}'.format(tenant_id) def get_user(self, **kwargs): """Gets the user data corresponding to the provided key.""" @@ -471,17 +485,12 @@ def get_user(self, **kwargs): else: raise TypeError('Unsupported keyword arguments: {0}.'.format(kwargs)) - try: - body, http_resp = self._client.body_and_response( - 'post', '/accounts:lookup', json=payload) - except requests.exceptions.RequestException as error: - raise _auth_utils.handle_auth_backend_error(error) - else: - if not body or not body.get('users'): - raise _auth_utils.UserNotFoundError( - 'No user record found for the provided {0}: {1}.'.format(key_type, key), - http_response=http_resp) - return body['users'][0] + body, http_resp = self._make_request('post', '/accounts:lookup', json=payload) + if not body or not body.get('users'): + raise _auth_utils.UserNotFoundError( + 'No user record found for the provided {0}: {1}.'.format(key_type, key), + http_response=http_resp) + return body['users'][0] def list_users(self, page_token=None, max_results=MAX_LIST_USERS_RESULTS): """Retrieves a batch of users.""" @@ -498,10 +507,8 @@ def list_users(self, page_token=None, max_results=MAX_LIST_USERS_RESULTS): payload = {'maxResults': max_results} if page_token: payload['nextPageToken'] = page_token - try: - return self._client.body('get', '/accounts:batchGet', params=payload) - except requests.exceptions.RequestException as error: - raise _auth_utils.handle_auth_backend_error(error) + body, _ = self._make_request('get', '/accounts:batchGet', params=payload) + return body def create_user(self, uid=None, display_name=None, email=None, phone_number=None, photo_url=None, password=None, disabled=None, email_verified=None): @@ -517,15 +524,11 @@ def create_user(self, uid=None, display_name=None, email=None, phone_number=None 'disabled': bool(disabled) if disabled is not None else None, } payload = {k: v for k, v in payload.items() if v is not None} - try: - body, http_resp = self._client.body_and_response('post', '/accounts', json=payload) - except requests.exceptions.RequestException as error: - raise _auth_utils.handle_auth_backend_error(error) - else: - if not body or not body.get('localId'): - raise _auth_utils.UnexpectedResponseError( - 'Failed to create new user.', http_response=http_resp) - return body.get('localId') + body, http_resp = self._make_request('post', '/accounts', json=payload) + if not body or not body.get('localId'): + raise _auth_utils.UnexpectedResponseError( + 'Failed to create new user.', http_response=http_resp) + return body.get('localId') def update_user(self, uid, display_name=None, email=None, phone_number=None, photo_url=None, password=None, disabled=None, email_verified=None, @@ -568,29 +571,19 @@ def update_user(self, uid, display_name=None, email=None, phone_number=None, payload['customAttributes'] = _auth_utils.validate_custom_claims(json_claims) payload = {k: v for k, v in payload.items() if v is not None} - try: - body, http_resp = self._client.body_and_response( - 'post', '/accounts:update', json=payload) - except requests.exceptions.RequestException as error: - raise _auth_utils.handle_auth_backend_error(error) - else: - if not body or not body.get('localId'): - raise _auth_utils.UnexpectedResponseError( - 'Failed to update user: {0}.'.format(uid), http_response=http_resp) - return body.get('localId') + body, http_resp = self._make_request('post', '/accounts:update', json=payload) + if not body or not body.get('localId'): + raise _auth_utils.UnexpectedResponseError( + 'Failed to update user: {0}.'.format(uid), http_response=http_resp) + return body.get('localId') def delete_user(self, uid): """Deletes the user identified by the specified user ID.""" _auth_utils.validate_uid(uid, required=True) - try: - body, http_resp = self._client.body_and_response( - 'post', '/accounts:delete', json={'localId' : uid}) - except requests.exceptions.RequestException as error: - raise _auth_utils.handle_auth_backend_error(error) - else: - if not body or not body.get('kind'): - raise _auth_utils.UnexpectedResponseError( - 'Failed to delete user: {0}.'.format(uid), http_response=http_resp) + body, http_resp = self._make_request('post', '/accounts:delete', json={'localId' : uid}) + if not body or not body.get('kind'): + raise _auth_utils.UnexpectedResponseError( + 'Failed to delete user: {0}.'.format(uid), http_response=http_resp) def import_users(self, users, hash_alg=None): """Imports the given list of users to Firebase Auth.""" @@ -609,16 +602,11 @@ def import_users(self, users, hash_alg=None): if not isinstance(hash_alg, _user_import.UserImportHash): raise ValueError('A UserImportHash is required to import users with passwords.') payload.update(hash_alg.to_dict()) - try: - body, http_resp = self._client.body_and_response( - 'post', '/accounts:batchCreate', json=payload) - except requests.exceptions.RequestException as error: - raise _auth_utils.handle_auth_backend_error(error) - else: - if not isinstance(body, dict): - raise _auth_utils.UnexpectedResponseError( - 'Failed to import users.', http_response=http_resp) - return body + body, http_resp = self._make_request('post', '/accounts:batchCreate', json=payload) + if not isinstance(body, dict): + raise _auth_utils.UnexpectedResponseError( + 'Failed to import users.', http_response=http_resp) + return body def generate_email_action_link(self, action_type, email, action_code_settings=None): """Fetches the email action links for types @@ -646,45 +634,22 @@ def generate_email_action_link(self, action_type, email, action_code_settings=No if action_code_settings: payload.update(encode_action_code_settings(action_code_settings)) + body, http_resp = self._make_request('post', '/accounts:sendOobCode', json=payload) + if not body or not body.get('oobLink'): + raise _auth_utils.UnexpectedResponseError( + 'Failed to generate email action link.', http_response=http_resp) + return body.get('oobLink') + + def _make_request(self, method, path, **kwargs): + url = '{0}{1}'.format(self.base_url, path) try: - body, http_resp = self._client.body_and_response( - 'post', '/accounts:sendOobCode', json=payload) + return self.http_client.body_and_response(method, url, **kwargs) except requests.exceptions.RequestException as error: raise _auth_utils.handle_auth_backend_error(error) - else: - if not body or not body.get('oobLink'): - raise _auth_utils.UnexpectedResponseError( - 'Failed to generate email action link.', http_response=http_resp) - return body.get('oobLink') - -class _UserIterator: - """An iterator that allows iterating over user accounts, one at a time. - This implementation loads a page of users into memory, and iterates on them. When the whole - page has been traversed, it loads another page. This class never keeps more than one page - of entries in memory. - """ +class _UserIterator(_auth_utils.PageIterator): - def __init__(self, current_page): - if not current_page: - raise ValueError('Current page must not be None.') - self._current_page = current_page - self._index = 0 - - def next(self): - if self._index == len(self._current_page.users): - if self._current_page.has_next_page: - self._current_page = self._current_page.get_next_page() - self._index = 0 - if self._index < len(self._current_page.users): - result = self._current_page.users[self._index] - self._index += 1 - return result - raise StopIteration - - def __next__(self): - return self.next() - - def __iter__(self): - return self + @property + def items(self): + return self._current_page.users diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index 6f85e622c..cb8782ea7 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -19,11 +19,9 @@ creating and managing user accounts in Firebase projects. """ -import time - -import firebase_admin +from firebase_admin import _auth_client +from firebase_admin import _auth_providers from firebase_admin import _auth_utils -from firebase_admin import _http_client from firebase_admin import _token_gen from firebase_admin import _user_import from firebase_admin import _user_mgt @@ -36,6 +34,7 @@ __all__ = [ 'ActionCodeSettings', 'CertificateFetchError', + 'Client', 'DELETE_ATTRIBUTE', 'EmailAlreadyExistsError', 'ErrorInfo', @@ -47,10 +46,13 @@ 'InvalidDynamicLinkDomainError', 'InvalidIdTokenError', 'InvalidSessionCookieError', + 'ListProviderConfigsPage', 'ListUsersPage', 'PhoneNumberAlreadyExistsError', + 'ProviderConfig', 'RevokedIdTokenError', 'RevokedSessionCookieError', + 'SAMLProviderConfig', 'TokenSignError', 'UidAlreadyExistsError', 'UnexpectedResponseError', @@ -63,19 +65,24 @@ 'UserRecord', 'create_custom_token', + 'create_saml_provider_config', 'create_session_cookie', 'create_user', + 'delete_saml_provider_config', 'delete_user', 'generate_email_verification_link', 'generate_password_reset_link', 'generate_sign_in_with_email_link', + 'get_saml_provider_config', 'get_user', 'get_user_by_email', 'get_user_by_phone_number', 'import_users', + 'list_saml_provider_configs', 'list_users', 'revoke_refresh_tokens', 'set_custom_user_claims', + 'update_saml_provider_config', 'update_user', 'verify_id_token', 'verify_session_cookie', @@ -83,6 +90,8 @@ ActionCodeSettings = _user_mgt.ActionCodeSettings CertificateFetchError = _token_gen.CertificateFetchError +Client = _auth_client.Client +ConfigurationNotFoundError = _auth_utils.ConfigurationNotFoundError DELETE_ATTRIBUTE = _user_mgt.DELETE_ATTRIBUTE EmailAlreadyExistsError = _auth_utils.EmailAlreadyExistsError ErrorInfo = _user_import.ErrorInfo @@ -94,10 +103,14 @@ InvalidDynamicLinkDomainError = _auth_utils.InvalidDynamicLinkDomainError InvalidIdTokenError = _auth_utils.InvalidIdTokenError InvalidSessionCookieError = _token_gen.InvalidSessionCookieError +ListProviderConfigsPage = _auth_providers.ListProviderConfigsPage ListUsersPage = _user_mgt.ListUsersPage +OIDCProviderConfig = _auth_providers.OIDCProviderConfig PhoneNumberAlreadyExistsError = _auth_utils.PhoneNumberAlreadyExistsError +ProviderConfig = _auth_providers.ProviderConfigClient RevokedIdTokenError = _token_gen.RevokedIdTokenError RevokedSessionCookieError = _token_gen.RevokedSessionCookieError +SAMLProviderConfig = _auth_providers.SAMLProviderConfig TokenSignError = _token_gen.TokenSignError UidAlreadyExistsError = _auth_utils.UidAlreadyExistsError UnexpectedResponseError = _auth_utils.UnexpectedResponseError @@ -110,23 +123,23 @@ UserRecord = _user_mgt.UserRecord -def _get_auth_service(app): - """Returns an _AuthService instance for an App. +def _get_client(app): + """Returns a client instance for an App. - If the App already has an _AuthService associated with it, simply returns - it. Otherwise creates a new _AuthService, and adds it to the App before + If the App already has a client associated with it, simply returns + it. Otherwise creates a new client, and adds it to the App before returning it. Args: - app: A Firebase App instance (or None to use the default App). + app: A Firebase App instance (or ``None`` to use the default App). Returns: - _AuthService: An _AuthService for the specified App instance. + Client: A client for the specified App instance. Raises: ValueError: If the app argument is invalid. """ - return _utils.get_app_service(app, _AUTH_ATTRIBUTE, _AuthService) + return _utils.get_app_service(app, _AUTH_ATTRIBUTE, Client) def create_custom_token(uid, developer_claims=None, app=None): @@ -145,8 +158,8 @@ def create_custom_token(uid, developer_claims=None, app=None): ValueError: If input parameters are invalid. TokenSignError: If an error occurs while signing the token using the remote IAM service. """ - token_generator = _get_auth_service(app).token_generator - return token_generator.create_custom_token(uid, developer_claims) + client = _get_client(app) + return client.create_custom_token(uid, developer_claims) def verify_id_token(id_token, app=None, check_revoked=False): @@ -171,15 +184,8 @@ def verify_id_token(id_token, app=None, check_revoked=False): CertificateFetchError: If an error occurs while fetching the public key certificates required to verify the ID token. """ - if not isinstance(check_revoked, bool): - # guard against accidental wrong assignment. - raise ValueError('Illegal check_revoked argument. Argument must be of type ' - ' bool, but given "{0}".'.format(type(check_revoked))) - token_verifier = _get_auth_service(app).token_verifier - verified_claims = token_verifier.verify_id_token(id_token) - if check_revoked: - _check_jwt_revoked(verified_claims, RevokedIdTokenError, 'ID token', app) - return verified_claims + client = _get_client(app) + return client.verify_id_token(id_token, check_revoked=check_revoked) def create_session_cookie(id_token, expires_in, app=None): @@ -200,8 +206,9 @@ def create_session_cookie(id_token, expires_in, app=None): ValueError: If input parameters are invalid. FirebaseError: If an error occurs while creating the cookie. """ - token_generator = _get_auth_service(app).token_generator - return token_generator.create_session_cookie(id_token, expires_in) + client = _get_client(app) + # pylint: disable=protected-access + return client._token_generator.create_session_cookie(id_token, expires_in) def verify_session_cookie(session_cookie, check_revoked=False, app=None): @@ -226,17 +233,18 @@ def verify_session_cookie(session_cookie, check_revoked=False, app=None): CertificateFetchError: If an error occurs while fetching the public key certificates required to verify the session cookie. """ - token_verifier = _get_auth_service(app).token_verifier - verified_claims = token_verifier.verify_session_cookie(session_cookie) + client = _get_client(app) + # pylint: disable=protected-access + verified_claims = client._token_verifier.verify_session_cookie(session_cookie) if check_revoked: - _check_jwt_revoked(verified_claims, RevokedSessionCookieError, 'session cookie', app) + client._check_jwt_revoked(verified_claims, RevokedSessionCookieError, 'session cookie') return verified_claims def revoke_refresh_tokens(uid, app=None): """Revokes all refresh tokens for an existing user. - revoke_refresh_tokens updates the user's tokens_valid_after_timestamp to the current UTC + This function updates the user's ``tokens_valid_after_timestamp`` to the current UTC in seconds since the epoch. It is important that the server on which this is called has its clock set correctly and synchronized. @@ -244,9 +252,17 @@ def revoke_refresh_tokens(uid, app=None): existing sessions from getting minted, existing ID tokens may remain active until their natural expiration (one hour). To verify that ID tokens are revoked, use ``verify_id_token(idToken, check_revoked=True)``. + + Args: + uid: A user ID string. + app: An App instance (optional). + + Raises: + ValueError: If the user ID is None, empty or malformed. + FirebaseError: If an error occurs while revoking the refresh token. """ - user_manager = _get_auth_service(app).user_manager - user_manager.update_user(uid, valid_since=int(time.time())) + client = _get_client(app) + client.revoke_refresh_tokens(uid) def get_user(uid, app=None): @@ -257,16 +273,15 @@ def get_user(uid, app=None): app: An App instance (optional). Returns: - UserRecord: A UserRecord instance. + UserRecord: A user record instance. Raises: ValueError: If the user ID is None, empty or malformed. UserNotFoundError: If the specified user ID does not exist. FirebaseError: If an error occurs while retrieving the user. """ - user_manager = _get_auth_service(app).user_manager - response = user_manager.get_user(uid=uid) - return UserRecord(response) + client = _get_client(app) + return client.get_user(uid=uid) def get_user_by_email(email, app=None): @@ -277,16 +292,15 @@ def get_user_by_email(email, app=None): app: An App instance (optional). Returns: - UserRecord: A UserRecord instance. + UserRecord: A user record instance. Raises: ValueError: If the email is None, empty or malformed. UserNotFoundError: If no user exists by the specified email address. FirebaseError: If an error occurs while retrieving the user. """ - user_manager = _get_auth_service(app).user_manager - response = user_manager.get_user(email=email) - return UserRecord(response) + client = _get_client(app) + return client.get_user_by_email(email=email) def get_user_by_phone_number(phone_number, app=None): @@ -297,16 +311,15 @@ def get_user_by_phone_number(phone_number, app=None): app: An App instance (optional). Returns: - UserRecord: A UserRecord instance. + UserRecord: A user record instance. Raises: ValueError: If the phone number is None, empty or malformed. UserNotFoundError: If no user exists by the specified phone number. FirebaseError: If an error occurs while retrieving the user. """ - user_manager = _get_auth_service(app).user_manager - response = user_manager.get_user(phone_number=phone_number) - return UserRecord(response) + client = _get_client(app) + return client.get_user_by_phone_number(phone_number=phone_number) def list_users(page_token=None, max_results=_user_mgt.MAX_LIST_USERS_RESULTS, app=None): @@ -325,16 +338,14 @@ def list_users(page_token=None, max_results=_user_mgt.MAX_LIST_USERS_RESULTS, ap app: An App instance (optional). Returns: - ListUsersPage: A ListUsersPage instance. + ListUsersPage: A page of user accounts. Raises: - ValueError: If max_results or page_token are invalid. + ValueError: If ``max_results`` or ``page_token`` are invalid. FirebaseError: If an error occurs while retrieving the user accounts. """ - user_manager = _get_auth_service(app).user_manager - def download(page_token, max_results): - return user_manager.list_users(page_token, max_results) - return ListUsersPage(download, page_token, max_results) + client = _get_client(app) + return client.list_users(page_token=page_token, max_results=max_results) def create_user(**kwargs): # pylint: disable=differing-param-doc @@ -356,16 +367,15 @@ def create_user(**kwargs): # pylint: disable=differing-param-doc app: An App instance (optional). Returns: - UserRecord: A UserRecord instance for the newly created user. + UserRecord: A user record instance for the newly created user. Raises: ValueError: If the specified user properties are invalid. FirebaseError: If an error occurs while creating the user account. """ app = kwargs.pop('app', None) - user_manager = _get_auth_service(app).user_manager - uid = user_manager.create_user(**kwargs) - return UserRecord(user_manager.get_user(uid=uid)) + client = _get_client(app) + return client.create_user(**kwargs) def update_user(uid, **kwargs): # pylint: disable=differing-param-doc @@ -389,20 +399,20 @@ def update_user(uid, **kwargs): # pylint: disable=differing-param-doc disabled: A boolean indicating whether or not the user account is disabled (optional). custom_claims: A dictionary or a JSON string contining the custom claims to be set on the user account (optional). To remove all custom claims, pass ``auth.DELETE_ATTRIBUTE``. - valid_since: An integer signifying the seconds since the epoch. This field is set by - ``revoke_refresh_tokens`` and it is discouraged to set this field directly. + valid_since: An integer signifying the seconds since the epoch (optional). This field is + set by ``revoke_refresh_tokens`` and it is discouraged to set this field directly. + app: An App instance (optional). Returns: - UserRecord: An updated UserRecord instance for the user. + UserRecord: An updated user record instance for the user. Raises: ValueError: If the specified user ID or properties are invalid. FirebaseError: If an error occurs while updating the user account. """ app = kwargs.pop('app', None) - user_manager = _get_auth_service(app).user_manager - user_manager.update_user(uid, **kwargs) - return UserRecord(user_manager.get_user(uid=uid)) + client = _get_client(app) + return client.update_user(uid, **kwargs) def set_custom_user_claims(uid, custom_claims, app=None): @@ -425,10 +435,8 @@ def set_custom_user_claims(uid, custom_claims, app=None): ValueError: If the specified user ID or the custom claims are invalid. FirebaseError: If an error occurs while updating the user account. """ - user_manager = _get_auth_service(app).user_manager - if custom_claims is None: - custom_claims = DELETE_ATTRIBUTE - user_manager.update_user(uid, custom_claims=custom_claims) + client = _get_client(app) + client.set_custom_user_claims(uid, custom_claims=custom_claims) def delete_user(uid, app=None): @@ -442,8 +450,8 @@ def delete_user(uid, app=None): ValueError: If the user ID is None, empty or malformed. FirebaseError: If an error occurs while deleting the user account. """ - user_manager = _get_auth_service(app).user_manager - user_manager.delete_user(uid) + client = _get_client(app) + client.delete_user(uid) def import_users(users, hash_alg=None, app=None): @@ -468,9 +476,8 @@ def import_users(users, hash_alg=None, app=None): ValueError: If the provided arguments are invalid. FirebaseError: If an error occurs while importing users. """ - user_manager = _get_auth_service(app).user_manager - result = user_manager.import_users(users, hash_alg) - return UserImportResult(result, len(users)) + client = _get_client(app) + return client.import_users(users, hash_alg) def generate_password_reset_link(email, action_code_settings=None, app=None): @@ -490,9 +497,8 @@ def generate_password_reset_link(email, action_code_settings=None, app=None): ValueError: If the provided arguments are invalid FirebaseError: If an error occurs while generating the link """ - user_manager = _get_auth_service(app).user_manager - return user_manager.generate_email_action_link( - 'PASSWORD_RESET', email, action_code_settings=action_code_settings) + client = _get_client(app) + return client.generate_password_reset_link(email, action_code_settings=action_code_settings) def generate_email_verification_link(email, action_code_settings=None, app=None): @@ -512,9 +518,9 @@ def generate_email_verification_link(email, action_code_settings=None, app=None) ValueError: If the provided arguments are invalid FirebaseError: If an error occurs while generating the link """ - user_manager = _get_auth_service(app).user_manager - return user_manager.generate_email_action_link( - 'VERIFY_EMAIL', email, action_code_settings=action_code_settings) + client = _get_client(app) + return client.generate_email_verification_link( + email, action_code_settings=action_code_settings) def generate_sign_in_with_email_link(email, action_code_settings, app=None): @@ -527,6 +533,7 @@ def generate_sign_in_with_email_link(email, action_code_settings, app=None): the link is to be handled by a mobile app and the additional state information to be passed in the deep link. app: An App instance (optional). + Returns: link: The email sign-in link created by the API @@ -534,47 +541,263 @@ def generate_sign_in_with_email_link(email, action_code_settings, app=None): ValueError: If the provided arguments are invalid FirebaseError: If an error occurs while generating the link """ - user_manager = _get_auth_service(app).user_manager - return user_manager.generate_email_action_link( - 'EMAIL_SIGNIN', email, action_code_settings=action_code_settings) + client = _get_client(app) + return client.generate_sign_in_with_email_link( + email, action_code_settings=action_code_settings) + + +def get_oidc_provider_config(provider_id, app=None): + """Returns the ``OIDCProviderConfig`` with the given ID. + + Args: + provider_id: Provider ID string. + app: An App instance (optional). + + Returns: + OIDCProviderConfig: An OIDC provider config instance. + + Raises: + ValueError: If the provider ID is invalid, empty or does not have ``oidc.`` prefix. + ConfigurationNotFoundError: If no OIDC provider is available with the given identifier. + FirebaseError: If an error occurs while retrieving the OIDC provider. + """ + client = _get_client(app) + return client.get_oidc_provider_config(provider_id) + +def create_oidc_provider_config( + provider_id, client_id, issuer, display_name=None, enabled=None, app=None): + """Creates a new OIDC provider config from the given parameters. + + OIDC provider support requires Google Cloud's Identity Platform (GCIP). To learn more about + GCIP, including pricing and features, see https://cloud.google.com/identity-platform. + + Args: + provider_id: Provider ID string. Must have the prefix ``oidc.``. + client_id: Client ID of the new config. + issuer: Issuer of the new config. Must be a valid URL. + display_name: The user-friendly display name to the current configuration (optional). + This name is also used as the provider label in the Cloud Console. + enabled: A boolean indicating whether the provider configuration is enabled or disabled + (optional). A user cannot sign in using a disabled provider. + app: An App instance (optional). + + Returns: + OIDCProviderConfig: The newly created OIDC provider config instance. + + Raises: + ValueError: If any of the specified input parameters are invalid. + FirebaseError: If an error occurs while creating the new OIDC provider config. + """ + client = _get_client(app) + return client.create_oidc_provider_config( + provider_id, client_id=client_id, issuer=issuer, display_name=display_name, + enabled=enabled) + + +def update_oidc_provider_config( + provider_id, client_id=None, issuer=None, display_name=None, enabled=None, app=None): + """Updates an existing OIDC provider config with the given parameters. + + Args: + provider_id: Provider ID string. Must have the prefix ``oidc.``. + client_id: Client ID of the new config (optional). + issuer: Issuer of the new config (optional). Must be a valid URL. + display_name: The user-friendly display name of the current configuration (optional). + Pass ``auth.DELETE_ATTRIBUTE`` to delete the current display name. + enabled: A boolean indicating whether the provider configuration is enabled or disabled + (optional). + app: An App instance (optional). + + Returns: + OIDCProviderConfig: The updated OIDC provider config instance. + + Raises: + ValueError: If any of the specified input parameters are invalid. + FirebaseError: If an error occurs while updating the OIDC provider config. + """ + client = _get_client(app) + return client.update_oidc_provider_config( + provider_id, client_id=client_id, issuer=issuer, display_name=display_name, + enabled=enabled) + + +def delete_oidc_provider_config(provider_id, app=None): + """Deletes the ``OIDCProviderConfig`` with the given ID. + + Args: + provider_id: Provider ID string. + app: An App instance (optional). + + Raises: + ValueError: If the provider ID is invalid, empty or does not have ``oidc.`` prefix. + ConfigurationNotFoundError: If no OIDC provider is available with the given identifier. + FirebaseError: If an error occurs while deleting the OIDC provider. + """ + client = _get_client(app) + client.delete_oidc_provider_config(provider_id) + + +def list_oidc_provider_configs( + page_token=None, max_results=_auth_providers.MAX_LIST_CONFIGS_RESULTS, app=None): + """Retrieves a page of OIDC provider configs from a Firebase project. + + The ``page_token`` argument governs the starting point of the page. The ``max_results`` + argument governs the maximum number of configs that may be included in the returned + page. This function never returns ``None``. If there are no OIDC configs in the Firebase + project, this returns an empty page. + + Args: + page_token: A non-empty page token string, which indicates the starting point of the + page (optional). Defaults to ``None``, which will retrieve the first page of users. + max_results: A positive integer indicating the maximum number of users to include in + the returned page (optional). Defaults to 100, which is also the maximum number + allowed. + app: An App instance (optional). + + Returns: + ListProviderConfigsPage: A page of OIDC provider config instances. + + Raises: + ValueError: If ``max_results`` or ``page_token`` are invalid. + FirebaseError: If an error occurs while retrieving the OIDC provider configs. + """ + client = _get_client(app) + return client.list_oidc_provider_configs(page_token, max_results) + + +def get_saml_provider_config(provider_id, app=None): + """Returns the ``SAMLProviderConfig`` with the given ID. + + Args: + provider_id: Provider ID string. + app: An App instance (optional). + + Returns: + SAMLProviderConfig: A SAML provider config instance. + + Raises: + ValueError: If the provider ID is invalid, empty or does not have ``saml.`` prefix. + ConfigurationNotFoundError: If no SAML provider is available with the given identifier. + FirebaseError: If an error occurs while retrieving the SAML provider. + """ + client = _get_client(app) + return client.get_saml_provider_config(provider_id) + + +def create_saml_provider_config( + provider_id, idp_entity_id, sso_url, x509_certificates, rp_entity_id, callback_url, + display_name=None, enabled=None, app=None): + """Creates a new SAML provider config from the given parameters. + + SAML provider support requires Google Cloud's Identity Platform (GCIP). To learn more about + GCIP, including pricing and features, see https://cloud.google.com/identity-platform. + + Args: + provider_id: Provider ID string. Must have the prefix ``saml.``. + idp_entity_id: The SAML IdP entity identifier. + sso_url: The SAML IdP SSO URL. Must be a valid URL. + x509_certificates: The list of SAML IdP X.509 certificates issued by CA for this provider. + Multiple certificates are accepted to prevent outages during IdP key rotation (for + example ADFS rotates every 10 days). When the Auth server receives a SAML response, it + will match the SAML response with the certificate on record. Otherwise the response is + rejected. Developers are expected to manage the certificate updates as keys are + rotated. + rp_entity_id: The SAML relying party (service provider) entity ID. This is defined by the + developer but needs to be provided to the SAML IdP. + callback_url: Callback URL string. This is fixed and must always be the same as the OAuth + redirect URL provisioned by Firebase Auth, unless a custom authDomain is used. + display_name: The user-friendly display name to the current configuration (optional). This + name is also used as the provider label in the Cloud Console. + enabled: A boolean indicating whether the provider configuration is enabled or disabled + (optional). A user cannot sign in using a disabled provider. + app: An App instance (optional). + + Returns: + SAMLProviderConfig: The newly created SAML provider config instance. + + Raises: + ValueError: If any of the specified input parameters are invalid. + FirebaseError: If an error occurs while creating the new SAML provider config. + """ + client = _get_client(app) + return client.create_saml_provider_config( + provider_id, idp_entity_id=idp_entity_id, sso_url=sso_url, + x509_certificates=x509_certificates, rp_entity_id=rp_entity_id, callback_url=callback_url, + display_name=display_name, enabled=enabled) + + +def update_saml_provider_config( + provider_id, idp_entity_id=None, sso_url=None, x509_certificates=None, + rp_entity_id=None, callback_url=None, display_name=None, enabled=None, app=None): + """Updates an existing SAML provider config with the given parameters. + + Args: + provider_id: Provider ID string. Must have the prefix ``saml.``. + idp_entity_id: The SAML IdP entity identifier (optional). + sso_url: The SAML IdP SSO URL. Must be a valid URL (optional). + x509_certificates: The list of SAML IdP X.509 certificates issued by CA for this + provider (optional). + rp_entity_id: The SAML relying party entity ID (optional). + callback_url: Callback URL string (optional). + display_name: The user-friendly display name of the current configuration (optional). + Pass ``auth.DELETE_ATTRIBUTE`` to delete the current display name. + enabled: A boolean indicating whether the provider configuration is enabled or disabled + (optional). + app: An App instance (optional). + + Returns: + SAMLProviderConfig: The updated SAML provider config instance. + Raises: + ValueError: If any of the specified input parameters are invalid. + FirebaseError: If an error occurs while updating the SAML provider config. + """ + client = _get_client(app) + return client.update_saml_provider_config( + provider_id, idp_entity_id=idp_entity_id, sso_url=sso_url, + x509_certificates=x509_certificates, rp_entity_id=rp_entity_id, + callback_url=callback_url, display_name=display_name, enabled=enabled) -def _check_jwt_revoked(verified_claims, exc_type, label, app): - user = get_user(verified_claims.get('uid'), app=app) - if verified_claims.get('iat') * 1000 < user.tokens_valid_after_timestamp: - raise exc_type('The Firebase {0} has been revoked.'.format(label)) +def delete_saml_provider_config(provider_id, app=None): + """Deletes the ``SAMLProviderConfig`` with the given ID. -class _AuthService: - """Firebase Authentication service.""" + Args: + provider_id: Provider ID string. + app: An App instance (optional). - ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v1/projects/' + Raises: + ValueError: If the provider ID is invalid, empty or does not have ``saml.`` prefix. + ConfigurationNotFoundError: If no SAML provider is available with the given identifier. + FirebaseError: If an error occurs while deleting the SAML provider. + """ + client = _get_client(app) + client.delete_saml_provider_config(provider_id) - def __init__(self, app): - credential = app.credential.get_credential() - version_header = 'Python/Admin/{0}'.format(firebase_admin.__version__) - if not app.project_id: - raise ValueError("""Project ID is required to access the auth service. - 1. Use a service account credential, or - 2. set the project ID explicitly via Firebase App options, or - 3. set the project ID via the GOOGLE_CLOUD_PROJECT environment variable.""") +def list_saml_provider_configs( + page_token=None, max_results=_auth_providers.MAX_LIST_CONFIGS_RESULTS, app=None): + """Retrieves a page of SAML provider configs from a Firebase project. - client = _http_client.JsonHttpClient( - credential=credential, base_url=self.ID_TOOLKIT_URL + app.project_id, - headers={'X-Client-Version': version_header}) - self._token_generator = _token_gen.TokenGenerator(app, client) - self._token_verifier = _token_gen.TokenVerifier(app) - self._user_manager = _user_mgt.UserManager(client) + The ``page_token`` argument governs the starting point of the page. The ``max_results`` + argument governs the maximum number of configs that may be included in the returned + page. This function never returns ``None``. If there are no SAML configs in the Firebase + project, this returns an empty page. - @property - def token_generator(self): - return self._token_generator + Args: + page_token: A non-empty page token string, which indicates the starting point of the + page (optional). Defaults to ``None``, which will retrieve the first page of users. + max_results: A positive integer indicating the maximum number of users to include in + the returned page (optional). Defaults to 100, which is also the maximum number + allowed. + app: An App instance (optional). - @property - def token_verifier(self): - return self._token_verifier + Returns: + ListProviderConfigsPage: A page of SAML provider config instances. - @property - def user_manager(self): - return self._user_manager + Raises: + ValueError: If ``max_results`` or ``page_token`` are invalid. + FirebaseError: If an error occurs while retrieving the SAML provider configs. + """ + client = _get_client(app) + return client.list_saml_provider_configs(page_token, max_results) diff --git a/firebase_admin/tenant_mgt.py b/firebase_admin/tenant_mgt.py new file mode 100644 index 000000000..396a819fb --- /dev/null +++ b/firebase_admin/tenant_mgt.py @@ -0,0 +1,445 @@ +# Copyright 2020 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Firebase tenant management module. + +This module contains functions for creating and configuring authentication tenants within a +Google Cloud Identity Platform (GCIP) instance. +""" + +import re +import threading + +import requests + +import firebase_admin +from firebase_admin import auth +from firebase_admin import _auth_utils +from firebase_admin import _http_client +from firebase_admin import _utils + + +_TENANT_MGT_ATTRIBUTE = '_tenant_mgt' +_MAX_LIST_TENANTS_RESULTS = 100 +_DISPLAY_NAME_PATTERN = re.compile('^[a-zA-Z][a-zA-Z0-9-]{3,19}$') + + +__all__ = [ + 'ListTenantsPage', + 'Tenant', + 'TenantIdMismatchError', + 'TenantNotFoundError', + + 'auth_for_tenant', + 'create_tenant', + 'delete_tenant', + 'get_tenant', + 'list_tenants', + 'update_tenant', +] + + +TenantIdMismatchError = _auth_utils.TenantIdMismatchError +TenantNotFoundError = _auth_utils.TenantNotFoundError + + +def auth_for_tenant(tenant_id, app=None): + """Gets an Auth Client instance scoped to the given tenant ID. + + Args: + tenant_id: A tenant ID string. + app: An App instance (optional). + + Returns: + auth.Client: An ``auth.Client`` object. + + Raises: + ValueError: If the tenant ID is None, empty or not a string. + """ + tenant_mgt_service = _get_tenant_mgt_service(app) + return tenant_mgt_service.auth_for_tenant(tenant_id) + + +def get_tenant(tenant_id, app=None): + """Gets the tenant corresponding to the given ``tenant_id``. + + Args: + tenant_id: A tenant ID string. + app: An App instance (optional). + + Returns: + Tenant: A tenant object. + + Raises: + ValueError: If the tenant ID is None, empty or not a string. + TenantNotFoundError: If no tenant exists by the given ID. + FirebaseError: If an error occurs while retrieving the tenant. + """ + tenant_mgt_service = _get_tenant_mgt_service(app) + return tenant_mgt_service.get_tenant(tenant_id) + + +def create_tenant( + display_name, allow_password_sign_up=None, enable_email_link_sign_in=None, app=None): + """Creates a new tenant from the given options. + + Args: + display_name: Display name string for the new tenant. Must begin with a letter and contain + only letters, digits and hyphens. Length must be between 4 and 20. + allow_password_sign_up: A boolean indicating whether to enable or disable the email sign-in + provider (optional). + enable_email_link_sign_in: A boolean indicating whether to enable or disable email link + sign-in (optional). Disabling this makes the password required for email sign-in. + app: An App instance (optional). + + Returns: + Tenant: A tenant object. + + Raises: + ValueError: If any of the given arguments are invalid. + FirebaseError: If an error occurs while creating the tenant. + """ + tenant_mgt_service = _get_tenant_mgt_service(app) + return tenant_mgt_service.create_tenant( + display_name=display_name, allow_password_sign_up=allow_password_sign_up, + enable_email_link_sign_in=enable_email_link_sign_in) + + +def update_tenant( + tenant_id, display_name=None, allow_password_sign_up=None, enable_email_link_sign_in=None, + app=None): + """Updates an existing tenant with the given options. + + Args: + tenant_id: ID of the tenant to update. + display_name: Updated display name string for the tenant (optional). + allow_password_sign_up: A boolean indicating whether to enable or disable the email sign-in + provider. + enable_email_link_sign_in: A boolean indicating whether to enable or disable email link + sign-in. Disabling this makes the password required for email sign-in. + app: An App instance (optional). + + Returns: + Tenant: The updated tenant object. + + Raises: + ValueError: If any of the given arguments are invalid. + TenantNotFoundError: If no tenant exists by the given ID. + FirebaseError: If an error occurs while creating the tenant. + """ + tenant_mgt_service = _get_tenant_mgt_service(app) + return tenant_mgt_service.update_tenant( + tenant_id, display_name=display_name, allow_password_sign_up=allow_password_sign_up, + enable_email_link_sign_in=enable_email_link_sign_in) + + +def delete_tenant(tenant_id, app=None): + """Deletes the tenant corresponding to the given ``tenant_id``. + + Args: + tenant_id: A tenant ID string. + app: An App instance (optional). + + Raises: + ValueError: If the tenant ID is None, empty or not a string. + TenantNotFoundError: If no tenant exists by the given ID. + FirebaseError: If an error occurs while retrieving the tenant. + """ + tenant_mgt_service = _get_tenant_mgt_service(app) + tenant_mgt_service.delete_tenant(tenant_id) + + +def list_tenants(page_token=None, max_results=_MAX_LIST_TENANTS_RESULTS, app=None): + """Retrieves a page of tenants from a Firebase project. + + The ``page_token`` argument governs the starting point of the page. The ``max_results`` + argument governs the maximum number of tenants that may be included in the returned page. + This function never returns None. If there are no user accounts in the Firebase project, this + returns an empty page. + + Args: + page_token: A non-empty page token string, which indicates the starting point of the page + (optional). Defaults to ``None``, which will retrieve the first page of users. + max_results: A positive integer indicating the maximum number of users to include in the + returned page (optional). Defaults to 100, which is also the maximum number allowed. + app: An App instance (optional). + + Returns: + ListTenantsPage: A page of tenants. + + Raises: + ValueError: If ``max_results`` or ``page_token`` are invalid. + FirebaseError: If an error occurs while retrieving the user accounts. + """ + tenant_mgt_service = _get_tenant_mgt_service(app) + def download(page_token, max_results): + return tenant_mgt_service.list_tenants(page_token, max_results) + return ListTenantsPage(download, page_token, max_results) + + +def _get_tenant_mgt_service(app): + return _utils.get_app_service(app, _TENANT_MGT_ATTRIBUTE, _TenantManagementService) + + +class Tenant: + """Represents a tenant in a multi-tenant application. + + Multi-tenancy support requires Google Cloud Identity Platform (GCIP). To learn more about + GCIP including pricing and features, see https://cloud.google.com/identity-platform. + + Before multi-tenancy can be used in a Google Cloud Identity Platform project, tenants must be + enabled in that project via the Cloud Console UI. A Tenant instance provides information + such as the display name, tenant identifier and email authentication configuration. + """ + + def __init__(self, data): + if not isinstance(data, dict): + raise ValueError('Invalid data argument in Tenant constructor: {0}'.format(data)) + if not 'name' in data: + raise ValueError('Tenant response missing required keys.') + + self._data = data + + @property + def tenant_id(self): + name = self._data['name'] + return name.split('/')[-1] + + @property + def display_name(self): + return self._data.get('displayName') + + @property + def allow_password_sign_up(self): + return self._data.get('allowPasswordSignup', False) + + @property + def enable_email_link_sign_in(self): + return self._data.get('enableEmailLinkSignin', False) + + +class _TenantManagementService: + """Firebase tenant management service.""" + + TENANT_MGT_URL = 'https://identitytoolkit.googleapis.com/v2beta1' + + def __init__(self, app): + credential = app.credential.get_credential() + version_header = 'Python/Admin/{0}'.format(firebase_admin.__version__) + base_url = '{0}/projects/{1}'.format(self.TENANT_MGT_URL, app.project_id) + self.app = app + self.client = _http_client.JsonHttpClient( + credential=credential, base_url=base_url, headers={'X-Client-Version': version_header}) + self.tenant_clients = {} + self.lock = threading.RLock() + + def auth_for_tenant(self, tenant_id): + """Gets an Auth Client instance scoped to the given tenant ID.""" + if not isinstance(tenant_id, str) or not tenant_id: + raise ValueError( + 'Invalid tenant ID: {0}. Tenant ID must be a non-empty string.'.format(tenant_id)) + + with self.lock: + if tenant_id in self.tenant_clients: + return self.tenant_clients[tenant_id] + + client = auth.Client(self.app, tenant_id=tenant_id) + self.tenant_clients[tenant_id] = client + return client + + def get_tenant(self, tenant_id): + """Gets the tenant corresponding to the given ``tenant_id``.""" + if not isinstance(tenant_id, str) or not tenant_id: + raise ValueError( + 'Invalid tenant ID: {0}. Tenant ID must be a non-empty string.'.format(tenant_id)) + + try: + body = self.client.body('get', '/tenants/{0}'.format(tenant_id)) + except requests.exceptions.RequestException as error: + raise _auth_utils.handle_auth_backend_error(error) + else: + return Tenant(body) + + def create_tenant( + self, display_name, allow_password_sign_up=None, enable_email_link_sign_in=None): + """Creates a new tenant from the given parameters.""" + + payload = {'displayName': _validate_display_name(display_name)} + if allow_password_sign_up is not None: + payload['allowPasswordSignup'] = _auth_utils.validate_boolean( + allow_password_sign_up, 'allowPasswordSignup') + if enable_email_link_sign_in is not None: + payload['enableEmailLinkSignin'] = _auth_utils.validate_boolean( + enable_email_link_sign_in, 'enableEmailLinkSignin') + + try: + body = self.client.body('post', '/tenants', json=payload) + except requests.exceptions.RequestException as error: + raise _auth_utils.handle_auth_backend_error(error) + else: + return Tenant(body) + + def update_tenant( + self, tenant_id, display_name=None, allow_password_sign_up=None, + enable_email_link_sign_in=None): + """Updates the specified tenant with the given parameters.""" + if not isinstance(tenant_id, str) or not tenant_id: + raise ValueError('Tenant ID must be a non-empty string.') + + payload = {} + if display_name is not None: + payload['displayName'] = _validate_display_name(display_name) + if allow_password_sign_up is not None: + payload['allowPasswordSignup'] = _auth_utils.validate_boolean( + allow_password_sign_up, 'allowPasswordSignup') + if enable_email_link_sign_in is not None: + payload['enableEmailLinkSignin'] = _auth_utils.validate_boolean( + enable_email_link_sign_in, 'enableEmailLinkSignin') + + if not payload: + raise ValueError('At least one parameter must be specified for update.') + + url = '/tenants/{0}'.format(tenant_id) + update_mask = ','.join(_auth_utils.build_update_mask(payload)) + params = 'updateMask={0}'.format(update_mask) + try: + body = self.client.body('patch', url, json=payload, params=params) + except requests.exceptions.RequestException as error: + raise _auth_utils.handle_auth_backend_error(error) + else: + return Tenant(body) + + def delete_tenant(self, tenant_id): + """Deletes the tenant corresponding to the given ``tenant_id``.""" + if not isinstance(tenant_id, str) or not tenant_id: + raise ValueError( + 'Invalid tenant ID: {0}. Tenant ID must be a non-empty string.'.format(tenant_id)) + + try: + self.client.request('delete', '/tenants/{0}'.format(tenant_id)) + except requests.exceptions.RequestException as error: + raise _auth_utils.handle_auth_backend_error(error) + + def list_tenants(self, page_token=None, max_results=_MAX_LIST_TENANTS_RESULTS): + """Retrieves a batch of tenants.""" + if page_token is not None: + if not isinstance(page_token, str) or not page_token: + raise ValueError('Page token must be a non-empty string.') + if not isinstance(max_results, int): + raise ValueError('Max results must be an integer.') + if max_results < 1 or max_results > _MAX_LIST_TENANTS_RESULTS: + raise ValueError( + 'Max results must be a positive integer less than or equal to ' + '{0}.'.format(_MAX_LIST_TENANTS_RESULTS)) + + payload = {'pageSize': max_results} + if page_token: + payload['pageToken'] = page_token + try: + return self.client.body('get', '/tenants', params=payload) + except requests.exceptions.RequestException as error: + raise _auth_utils.handle_auth_backend_error(error) + + +class ListTenantsPage: + """Represents a page of tenants fetched from a Firebase project. + + Provides methods for traversing tenants included in this page, as well as retrieving + subsequent pages of tenants. The iterator returned by ``iterate_all()`` can be used to iterate + through all tenants in the Firebase project starting from this page. + """ + + def __init__(self, download, page_token, max_results): + self._download = download + self._max_results = max_results + self._current = download(page_token, max_results) + + @property + def tenants(self): + """A list of ``ExportedUserRecord`` instances available in this page.""" + return [Tenant(data) for data in self._current.get('tenants', [])] + + @property + def next_page_token(self): + """Page token string for the next page (empty string indicates no more pages).""" + return self._current.get('nextPageToken', '') + + @property + def has_next_page(self): + """A boolean indicating whether more pages are available.""" + return bool(self.next_page_token) + + def get_next_page(self): + """Retrieves the next page of tenants, if available. + + Returns: + ListTenantsPage: Next page of tenants, or None if this is the last page. + """ + if self.has_next_page: + return ListTenantsPage(self._download, self.next_page_token, self._max_results) + return None + + def iterate_all(self): + """Retrieves an iterator for tenants. + + Returned iterator will iterate through all the tenants in the Firebase project + starting from this page. The iterator will never buffer more than one page of tenants + in memory at a time. + + Returns: + iterator: An iterator of Tenant instances. + """ + return _TenantIterator(self) + + +class _TenantIterator: + """An iterator that allows iterating over tenants. + + This implementation loads a page of tenants into memory, and iterates on them. When the whole + page has been traversed, it loads another page. This class never keeps more than one page + of entries in memory. + """ + + def __init__(self, current_page): + if not current_page: + raise ValueError('Current page must not be None.') + self._current_page = current_page + self._index = 0 + + def next(self): + if self._index == len(self._current_page.tenants): + if self._current_page.has_next_page: + self._current_page = self._current_page.get_next_page() + self._index = 0 + if self._index < len(self._current_page.tenants): + result = self._current_page.tenants[self._index] + self._index += 1 + return result + raise StopIteration + + def __next__(self): + return self.next() + + def __iter__(self): + return self + + +def _validate_display_name(display_name): + if not isinstance(display_name, str): + raise ValueError('Invalid type for displayName') + if not _DISPLAY_NAME_PATTERN.search(display_name): + raise ValueError( + 'displayName must start with a letter and only consist of letters, digits and ' + 'hyphens with 4-20 characters.') + return display_name diff --git a/integration/test_auth.py b/integration/test_auth.py index 5d26dd9f1..cfd775016 100644 --- a/integration/test_auth.py +++ b/integration/test_auth.py @@ -16,6 +16,7 @@ import base64 import datetime import random +import string import time from urllib import parse import uuid @@ -38,6 +39,30 @@ ACTION_LINK_CONTINUE_URL = 'http://localhost?a=1&b=5#f=1' +X509_CERTIFICATES = [ + ('-----BEGIN CERTIFICATE-----\nMIICZjCCAc+gAwIBAgIBADANBgkqhkiG9w0BAQ0FADBQMQswCQYDVQQGEwJ1czE' + 'L\nMAkGA1UECAwCQ0ExDTALBgNVBAoMBEFjbWUxETAPBgNVBAMMCGFjbWUuY29tMRIw\nEAYDVQQHDAlTdW5ueXZhbGU' + 'wHhcNMTgxMjA2MDc1MTUxWhcNMjgxMjAzMDc1MTUx\nWjBQMQswCQYDVQQGEwJ1czELMAkGA1UECAwCQ0ExDTALBgNVB' + 'AoMBEFjbWUxETAP\nBgNVBAMMCGFjbWUuY29tMRIwEAYDVQQHDAlTdW5ueXZhbGUwgZ8wDQYJKoZIhvcN\nAQEBBQADg' + 'Y0AMIGJAoGBAKphmggjiVgqMLXyzvI7cKphscIIQ+wcv7Dld6MD4aKv\n7Jqr8ltujMxBUeY4LFEKw8Terb01snYpDot' + 'filaG6NxpF/GfVVmMalzwWp0mT8+H\nyzyPj89mRcozu17RwuooR6n1ofXjGcBE86lqC21UhA3WVgjPOLqB42rlE9gPn' + 'ZLB\nAgMBAAGjUDBOMB0GA1UdDgQWBBS0iM7WnbCNOnieOP1HIA+Oz/ML+zAfBgNVHSME\nGDAWgBS0iM7WnbCNOnieO' + 'P1HIA+Oz/ML+zAMBgNVHRMEBTADAQH/MA0GCSqGSIb3\nDQEBDQUAA4GBAF3jBgS+wP+K/jTupEQur6iaqS4UvXd//d4' + 'vo1MV06oTLQMTz+rP\nOSMDNwxzfaOn6vgYLKP/Dcy9dSTnSzgxLAxfKvDQZA0vE3udsw0Bd245MmX4+GOp\nlbrN99X' + 'P1u+lFxCSdMUzvQ/jW4ysw/Nq4JdJ0gPAyPvL6Qi/3mQdIQwx\n-----END CERTIFICATE-----\n'), + ('-----BEGIN CERTIFICATE-----\nMIICZjCCAc+gAwIBAgIBADANBgkqhkiG9w0BAQ0FADBQMQswCQYDVQQGEwJ1czE' + 'L\nMAkGA1UECAwCQ0ExDTALBgNVBAoMBEFjbWUxETAPBgNVBAMMCGFjbWUuY29tMRIw\nEAYDVQQHDAlTdW5ueXZhbGU' + 'wHhcNMTgxMjA2MDc1ODE4WhcNMjgxMjAzMDc1ODE4\nWjBQMQswCQYDVQQGEwJ1czELMAkGA1UECAwCQ0ExDTALBgNVB' + 'AoMBEFjbWUxETAP\nBgNVBAMMCGFjbWUuY29tMRIwEAYDVQQHDAlTdW5ueXZhbGUwgZ8wDQYJKoZIhvcN\nAQEBBQADg' + 'Y0AMIGJAoGBAKuzYKfDZGA6DJgQru3wNUqv+S0hMZfP/jbp8ou/8UKu\nrNeX7cfCgt3yxoGCJYKmF6t5mvo76JY0MWw' + 'A53BxeP/oyXmJ93uHG5mFRAsVAUKs\ncVVb0Xi6ujxZGVdDWFV696L0BNOoHTfXmac6IBoZQzNNK4n1AATqwo+z7a0pf' + 'RrJ\nAgMBAAGjUDBOMB0GA1UdDgQWBBSKmi/ZKMuLN0ES7/jPa7q7jAjPiDAfBgNVHSME\nGDAWgBSKmi/ZKMuLN0ES7' + '/jPa7q7jAjPiDAMBgNVHRMEBTADAQH/MA0GCSqGSIb3\nDQEBDQUAA4GBAAg2a2kSn05NiUOuWOHwPUjW3wQRsGxPXtb' + 'hWMhmNdCfKKteM2+/\nLd/jz5F3qkOgGQ3UDgr3SHEoWhnLaJMF4a2tm6vL2rEIfPEK81KhTTRxSsAgMVbU\nJXBz1md' + '6Ur0HlgQC7d1CHC8/xi2DDwHopLyxhogaZUxy9IaRxUEa2vJW\n-----END CERTIFICATE-----\n'), +] + + def _sign_in(custom_token, api_key): body = {'token' : custom_token.decode(), 'returnSecureToken' : True} params = {'key' : api_key} @@ -52,6 +77,10 @@ def _sign_in_with_password(email, password, api_key): resp.raise_for_status() return resp.json().get('idToken') +def _random_string(length=10): + letters = string.ascii_lowercase + return ''.join(random.choice(letters) for i in range(length)) + def _random_id(): random_id = str(uuid.uuid4()).lower().replace('-', '') email = 'test{0}@example.{1}.com'.format(random_id[:12], random_id[12:]) @@ -477,6 +506,163 @@ def test_email_sign_in_with_settings(new_user_email_unverified, api_key): assert id_token is not None and len(id_token) > 0 assert auth.get_user(new_user_email_unverified.uid).email_verified + +@pytest.fixture(scope='module') +def oidc_provider(): + provider_config = _create_oidc_provider_config() + yield provider_config + auth.delete_oidc_provider_config(provider_config.provider_id) + + +def test_create_oidc_provider_config(oidc_provider): + assert isinstance(oidc_provider, auth.OIDCProviderConfig) + assert oidc_provider.client_id == 'OIDC_CLIENT_ID' + assert oidc_provider.issuer == 'https://oidc.com/issuer' + assert oidc_provider.display_name == 'OIDC_DISPLAY_NAME' + assert oidc_provider.enabled is True + + +def test_get_oidc_provider_config(oidc_provider): + provider_config = auth.get_oidc_provider_config(oidc_provider.provider_id) + assert isinstance(provider_config, auth.OIDCProviderConfig) + assert provider_config.provider_id == oidc_provider.provider_id + assert provider_config.client_id == 'OIDC_CLIENT_ID' + assert provider_config.issuer == 'https://oidc.com/issuer' + assert provider_config.display_name == 'OIDC_DISPLAY_NAME' + assert provider_config.enabled is True + + +def test_list_oidc_provider_configs(oidc_provider): + page = auth.list_oidc_provider_configs() + result = None + for provider_config in page.iterate_all(): + if provider_config.provider_id == oidc_provider.provider_id: + result = provider_config + break + + assert result is not None + + +def test_update_oidc_provider_config(): + provider_config = _create_oidc_provider_config() + try: + provider_config = auth.update_oidc_provider_config( + provider_config.provider_id, + client_id='UPDATED_OIDC_CLIENT_ID', + issuer='https://oidc.com/updated_issuer', + display_name='UPDATED_OIDC_DISPLAY_NAME', + enabled=False) + assert provider_config.client_id == 'UPDATED_OIDC_CLIENT_ID' + assert provider_config.issuer == 'https://oidc.com/updated_issuer' + assert provider_config.display_name == 'UPDATED_OIDC_DISPLAY_NAME' + assert provider_config.enabled is False + finally: + auth.delete_oidc_provider_config(provider_config.provider_id) + + +def test_delete_oidc_provider_config(): + provider_config = _create_oidc_provider_config() + auth.delete_oidc_provider_config(provider_config.provider_id) + with pytest.raises(auth.ConfigurationNotFoundError): + auth.get_oidc_provider_config(provider_config.provider_id) + + +@pytest.fixture(scope='module') +def saml_provider(): + provider_config = _create_saml_provider_config() + yield provider_config + auth.delete_saml_provider_config(provider_config.provider_id) + + +def test_create_saml_provider_config(saml_provider): + assert isinstance(saml_provider, auth.SAMLProviderConfig) + assert saml_provider.idp_entity_id == 'IDP_ENTITY_ID' + assert saml_provider.sso_url == 'https://example.com/login' + assert saml_provider.x509_certificates == [X509_CERTIFICATES[0]] + assert saml_provider.rp_entity_id == 'RP_ENTITY_ID' + assert saml_provider.callback_url == 'https://projectId.firebaseapp.com/__/auth/handler' + assert saml_provider.display_name == 'SAML_DISPLAY_NAME' + assert saml_provider.enabled is True + + +def test_get_saml_provider_config(saml_provider): + provider_config = auth.get_saml_provider_config(saml_provider.provider_id) + assert isinstance(provider_config, auth.SAMLProviderConfig) + assert provider_config.provider_id == saml_provider.provider_id + assert provider_config.idp_entity_id == 'IDP_ENTITY_ID' + assert provider_config.sso_url == 'https://example.com/login' + assert provider_config.x509_certificates == [X509_CERTIFICATES[0]] + assert provider_config.rp_entity_id == 'RP_ENTITY_ID' + assert provider_config.callback_url == 'https://projectId.firebaseapp.com/__/auth/handler' + assert provider_config.display_name == 'SAML_DISPLAY_NAME' + assert provider_config.enabled is True + + +def test_list_saml_provider_configs(saml_provider): + page = auth.list_saml_provider_configs() + result = None + for provider_config in page.iterate_all(): + if provider_config.provider_id == saml_provider.provider_id: + result = provider_config + break + + assert result is not None + + +def test_update_saml_provider_config(): + provider_config = _create_saml_provider_config() + try: + provider_config = auth.update_saml_provider_config( + provider_config.provider_id, + idp_entity_id='UPDATED_IDP_ENTITY_ID', + sso_url='https://example.com/updated_login', + x509_certificates=[X509_CERTIFICATES[1]], + rp_entity_id='UPDATED_RP_ENTITY_ID', + callback_url='https://updatedProjectId.firebaseapp.com/__/auth/handler', + display_name='UPDATED_SAML_DISPLAY_NAME', + enabled=False) + assert provider_config.idp_entity_id == 'UPDATED_IDP_ENTITY_ID' + assert provider_config.sso_url == 'https://example.com/updated_login' + assert provider_config.x509_certificates == [X509_CERTIFICATES[1]] + assert provider_config.rp_entity_id == 'UPDATED_RP_ENTITY_ID' + assert provider_config.callback_url == ('https://updatedProjectId.firebaseapp.com/' + '__/auth/handler') + assert provider_config.display_name == 'UPDATED_SAML_DISPLAY_NAME' + assert provider_config.enabled is False + finally: + auth.delete_saml_provider_config(provider_config.provider_id) + + +def test_delete_saml_provider_config(): + provider_config = _create_saml_provider_config() + auth.delete_saml_provider_config(provider_config.provider_id) + with pytest.raises(auth.ConfigurationNotFoundError): + auth.get_saml_provider_config(provider_config.provider_id) + + +def _create_oidc_provider_config(): + provider_id = 'oidc.{0}'.format(_random_string()) + return auth.create_oidc_provider_config( + provider_id=provider_id, + client_id='OIDC_CLIENT_ID', + issuer='https://oidc.com/issuer', + display_name='OIDC_DISPLAY_NAME', + enabled=True) + + +def _create_saml_provider_config(): + provider_id = 'saml.{0}'.format(_random_string()) + return auth.create_saml_provider_config( + provider_id=provider_id, + idp_entity_id='IDP_ENTITY_ID', + sso_url='https://example.com/login', + x509_certificates=[X509_CERTIFICATES[0]], + rp_entity_id='RP_ENTITY_ID', + callback_url='https://projectId.firebaseapp.com/__/auth/handler', + display_name='SAML_DISPLAY_NAME', + enabled=True) + + class CredentialWrapper(credentials.Base): """A custom Firebase credential that wraps an OAuth2 token.""" diff --git a/integration/test_tenant_mgt.py b/integration/test_tenant_mgt.py new file mode 100644 index 000000000..c9eefd96e --- /dev/null +++ b/integration/test_tenant_mgt.py @@ -0,0 +1,417 @@ +# Copyright 2020 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for firebase_admin.tenant_mgt module.""" + +import random +import string +import time +from urllib import parse +import uuid + +import requests +import pytest + +from firebase_admin import auth +from firebase_admin import tenant_mgt +from integration import test_auth + + +ACTION_LINK_CONTINUE_URL = 'http://localhost?a=1&b=5#f=1' +ACTION_CODE_SETTINGS = auth.ActionCodeSettings(ACTION_LINK_CONTINUE_URL) +VERIFY_TOKEN_URL = 'https://www.googleapis.com/identitytoolkit/v3/relyingparty/verifyCustomToken' + + +@pytest.fixture(scope='module') +def sample_tenant(): + tenant = tenant_mgt.create_tenant( + display_name='admin-python-tenant', + allow_password_sign_up=True, + enable_email_link_sign_in=True) + yield tenant + tenant_mgt.delete_tenant(tenant.tenant_id) + + +@pytest.fixture(scope='module') +def tenant_user(sample_tenant): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + email = _random_email() + user = client.create_user(email=email) + yield user + client.delete_user(user.uid) + + +def test_get_tenant(sample_tenant): + tenant = tenant_mgt.get_tenant(sample_tenant.tenant_id) + assert isinstance(tenant, tenant_mgt.Tenant) + assert tenant.tenant_id == sample_tenant.tenant_id + assert tenant.display_name == 'admin-python-tenant' + assert tenant.allow_password_sign_up is True + assert tenant.enable_email_link_sign_in is True + + +def test_list_tenants(sample_tenant): + page = tenant_mgt.list_tenants() + result = None + for tenant in page.iterate_all(): + if tenant.tenant_id == sample_tenant.tenant_id: + result = tenant + break + assert isinstance(result, tenant_mgt.Tenant) + assert result.tenant_id == sample_tenant.tenant_id + assert result.display_name == 'admin-python-tenant' + assert result.allow_password_sign_up is True + assert result.enable_email_link_sign_in is True + + +def test_update_tenant(): + tenant = tenant_mgt.create_tenant( + display_name='py-update-test', allow_password_sign_up=True, enable_email_link_sign_in=True) + try: + tenant = tenant_mgt.update_tenant( + tenant.tenant_id, display_name='updated-py-tenant', allow_password_sign_up=False, + enable_email_link_sign_in=False) + assert isinstance(tenant, tenant_mgt.Tenant) + assert tenant.tenant_id == tenant.tenant_id + assert tenant.display_name == 'updated-py-tenant' + assert tenant.allow_password_sign_up is False + assert tenant.enable_email_link_sign_in is False + finally: + tenant_mgt.delete_tenant(tenant.tenant_id) + + +def test_delete_tenant(): + tenant = tenant_mgt.create_tenant(display_name='py-delete-test') + tenant_mgt.delete_tenant(tenant.tenant_id) + with pytest.raises(tenant_mgt.TenantNotFoundError): + tenant_mgt.get_tenant(tenant.tenant_id) + + +def test_auth_for_client(sample_tenant): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + assert isinstance(client, auth.Client) + assert client.tenant_id == sample_tenant.tenant_id + + +def test_custom_token(sample_tenant, api_key): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + custom_token = client.create_custom_token('user1') + id_token = _sign_in(custom_token, sample_tenant.tenant_id, api_key) + claims = client.verify_id_token(id_token) + assert claims['uid'] == 'user1' + assert claims['firebase']['tenant'] == sample_tenant.tenant_id + + +def test_custom_token_with_claims(sample_tenant, api_key): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + custom_token = client.create_custom_token('user1', {'premium': True}) + id_token = _sign_in(custom_token, sample_tenant.tenant_id, api_key) + claims = client.verify_id_token(id_token) + assert claims['uid'] == 'user1' + assert claims['premium'] is True + assert claims['firebase']['tenant'] == sample_tenant.tenant_id + + +def test_create_user(sample_tenant, tenant_user): + assert tenant_user.tenant_id == sample_tenant.tenant_id + + +def test_update_user(sample_tenant): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + user = client.create_user() + try: + email = _random_email() + phone = _random_phone() + user = client.update_user(user.uid, email=email, phone_number=phone) + assert user.tenant_id == sample_tenant.tenant_id + assert user.email == email + assert user.phone_number == phone + finally: + client.delete_user(user.uid) + + +def test_get_user(sample_tenant, tenant_user): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + user = client.get_user(tenant_user.uid) + assert user.uid == tenant_user.uid + assert user.tenant_id == sample_tenant.tenant_id + + +def test_list_users(sample_tenant, tenant_user): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + page = client.list_users() + result = None + for user in page.iterate_all(): + if user.uid == tenant_user.uid: + result = user + break + assert result.tenant_id == sample_tenant.tenant_id + + +def test_set_custom_user_claims(sample_tenant, tenant_user): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + client.set_custom_user_claims(tenant_user.uid, {'premium': True}) + user = client.get_user(tenant_user.uid) + assert user.custom_claims == {'premium': True} + + +def test_delete_user(sample_tenant): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + user = client.create_user() + client.delete_user(user.uid) + with pytest.raises(auth.UserNotFoundError): + client.get_user(user.uid) + + +def test_revoke_refresh_tokens(sample_tenant, tenant_user): + valid_since = int(time.time()) + time.sleep(1) + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + client.revoke_refresh_tokens(tenant_user.uid) + user = client.get_user(tenant_user.uid) + assert user.tokens_valid_after_timestamp > valid_since + + +def test_password_reset_link(sample_tenant, tenant_user): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + link = client.generate_password_reset_link(tenant_user.email, ACTION_CODE_SETTINGS) + assert _tenant_id_from_link(link) == sample_tenant.tenant_id + + +def test_email_verification_link(sample_tenant, tenant_user): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + link = client.generate_email_verification_link(tenant_user.email, ACTION_CODE_SETTINGS) + assert _tenant_id_from_link(link) == sample_tenant.tenant_id + + +def test_sign_in_with_email_link(sample_tenant, tenant_user): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + link = client.generate_sign_in_with_email_link(tenant_user.email, ACTION_CODE_SETTINGS) + assert _tenant_id_from_link(link) == sample_tenant.tenant_id + + +def test_import_users(sample_tenant): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + user = auth.ImportUserRecord( + uid=_random_uid(), email=_random_email()) + result = client.import_users([user]) + try: + assert result.success_count == 1 + assert result.failure_count == 0 + saved_user = client.get_user(user.uid) + assert saved_user.email == user.email + finally: + client.delete_user(user.uid) + + +@pytest.fixture(scope='module') +def oidc_provider(sample_tenant): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + provider_config = _create_oidc_provider_config(client) + yield provider_config + client.delete_oidc_provider_config(provider_config.provider_id) + + +def test_create_oidc_provider_config(oidc_provider): + assert isinstance(oidc_provider, auth.OIDCProviderConfig) + assert oidc_provider.client_id == 'OIDC_CLIENT_ID' + assert oidc_provider.issuer == 'https://oidc.com/issuer' + assert oidc_provider.display_name == 'OIDC_DISPLAY_NAME' + assert oidc_provider.enabled is True + + +def test_get_oidc_provider_config(sample_tenant, oidc_provider): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + provider_config = client.get_oidc_provider_config(oidc_provider.provider_id) + assert isinstance(provider_config, auth.OIDCProviderConfig) + assert provider_config.provider_id == oidc_provider.provider_id + assert provider_config.client_id == 'OIDC_CLIENT_ID' + assert provider_config.issuer == 'https://oidc.com/issuer' + assert provider_config.display_name == 'OIDC_DISPLAY_NAME' + assert provider_config.enabled is True + + +def test_list_oidc_provider_configs(sample_tenant, oidc_provider): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + page = client.list_oidc_provider_configs() + result = None + for provider_config in page.iterate_all(): + if provider_config.provider_id == oidc_provider.provider_id: + result = provider_config + break + + assert result is not None + + +def test_update_oidc_provider_config(sample_tenant): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + provider_config = _create_oidc_provider_config(client) + try: + provider_config = client.update_oidc_provider_config( + provider_config.provider_id, + client_id='UPDATED_OIDC_CLIENT_ID', + issuer='https://oidc.com/updated_issuer', + display_name='UPDATED_OIDC_DISPLAY_NAME', + enabled=False) + assert provider_config.client_id == 'UPDATED_OIDC_CLIENT_ID' + assert provider_config.issuer == 'https://oidc.com/updated_issuer' + assert provider_config.display_name == 'UPDATED_OIDC_DISPLAY_NAME' + assert provider_config.enabled is False + finally: + client.delete_oidc_provider_config(provider_config.provider_id) + + +def test_delete_oidc_provider_config(sample_tenant): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + provider_config = _create_oidc_provider_config(client) + client.delete_oidc_provider_config(provider_config.provider_id) + with pytest.raises(auth.ConfigurationNotFoundError): + client.get_oidc_provider_config(provider_config.provider_id) + + +@pytest.fixture(scope='module') +def saml_provider(sample_tenant): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + provider_config = _create_saml_provider_config(client) + yield provider_config + client.delete_saml_provider_config(provider_config.provider_id) + + +def test_create_saml_provider_config(saml_provider): + assert isinstance(saml_provider, auth.SAMLProviderConfig) + assert saml_provider.idp_entity_id == 'IDP_ENTITY_ID' + assert saml_provider.sso_url == 'https://example.com/login' + assert saml_provider.x509_certificates == [test_auth.X509_CERTIFICATES[0]] + assert saml_provider.rp_entity_id == 'RP_ENTITY_ID' + assert saml_provider.callback_url == 'https://projectId.firebaseapp.com/__/auth/handler' + assert saml_provider.display_name == 'SAML_DISPLAY_NAME' + assert saml_provider.enabled is True + + +def test_get_saml_provider_config(sample_tenant, saml_provider): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + provider_config = client.get_saml_provider_config(saml_provider.provider_id) + assert isinstance(provider_config, auth.SAMLProviderConfig) + assert provider_config.provider_id == saml_provider.provider_id + assert provider_config.idp_entity_id == 'IDP_ENTITY_ID' + assert provider_config.sso_url == 'https://example.com/login' + assert provider_config.x509_certificates == [test_auth.X509_CERTIFICATES[0]] + assert provider_config.rp_entity_id == 'RP_ENTITY_ID' + assert provider_config.callback_url == 'https://projectId.firebaseapp.com/__/auth/handler' + assert provider_config.display_name == 'SAML_DISPLAY_NAME' + assert provider_config.enabled is True + + +def test_list_saml_provider_configs(sample_tenant, saml_provider): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + page = client.list_saml_provider_configs() + result = None + for provider_config in page.iterate_all(): + if provider_config.provider_id == saml_provider.provider_id: + result = provider_config + break + + assert result is not None + + +def test_update_saml_provider_config(sample_tenant): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + provider_config = _create_saml_provider_config(client) + try: + provider_config = client.update_saml_provider_config( + provider_config.provider_id, + idp_entity_id='UPDATED_IDP_ENTITY_ID', + sso_url='https://example.com/updated_login', + x509_certificates=[test_auth.X509_CERTIFICATES[1]], + rp_entity_id='UPDATED_RP_ENTITY_ID', + callback_url='https://updatedProjectId.firebaseapp.com/__/auth/handler', + display_name='UPDATED_SAML_DISPLAY_NAME', + enabled=False) + assert provider_config.idp_entity_id == 'UPDATED_IDP_ENTITY_ID' + assert provider_config.sso_url == 'https://example.com/updated_login' + assert provider_config.x509_certificates == [test_auth.X509_CERTIFICATES[1]] + assert provider_config.rp_entity_id == 'UPDATED_RP_ENTITY_ID' + assert provider_config.callback_url == ('https://updatedProjectId.firebaseapp.com/' + '__/auth/handler') + assert provider_config.display_name == 'UPDATED_SAML_DISPLAY_NAME' + assert provider_config.enabled is False + finally: + client.delete_saml_provider_config(provider_config.provider_id) + + +def test_delete_saml_provider_config(sample_tenant): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + provider_config = _create_saml_provider_config(client) + client.delete_saml_provider_config(provider_config.provider_id) + with pytest.raises(auth.ConfigurationNotFoundError): + client.get_saml_provider_config(provider_config.provider_id) + + +def _create_oidc_provider_config(client): + provider_id = 'oidc.{0}'.format(_random_string()) + return client.create_oidc_provider_config( + provider_id=provider_id, + client_id='OIDC_CLIENT_ID', + issuer='https://oidc.com/issuer', + display_name='OIDC_DISPLAY_NAME', + enabled=True) + + +def _create_saml_provider_config(client): + provider_id = 'saml.{0}'.format(_random_string()) + return client.create_saml_provider_config( + provider_id=provider_id, + idp_entity_id='IDP_ENTITY_ID', + sso_url='https://example.com/login', + x509_certificates=[test_auth.X509_CERTIFICATES[0]], + rp_entity_id='RP_ENTITY_ID', + callback_url='https://projectId.firebaseapp.com/__/auth/handler', + display_name='SAML_DISPLAY_NAME', + enabled=True) + + +def _random_uid(): + return str(uuid.uuid4()).lower().replace('-', '') + + +def _random_email(): + random_id = str(uuid.uuid4()).lower().replace('-', '') + return 'test{0}@example.{1}.com'.format(random_id[:12], random_id[12:]) + + +def _random_phone(): + return '+1' + ''.join([str(random.randint(0, 9)) for _ in range(0, 10)]) + + +def _random_string(length=10): + letters = string.ascii_lowercase + return ''.join(random.choice(letters) for i in range(length)) + + +def _tenant_id_from_link(link): + query = parse.urlparse(link).query + parsed_query = parse.parse_qs(query) + return parsed_query['tenantId'][0] + + +def _sign_in(custom_token, tenant_id, api_key): + body = { + 'token' : custom_token.decode(), + 'returnSecureToken' : True, + 'tenantId': tenant_id, + } + params = {'key' : api_key} + resp = requests.request('post', VERIFY_TOKEN_URL, params=params, json=body) + resp.raise_for_status() + return resp.json().get('idToken') diff --git a/snippets/auth/index.py b/snippets/auth/index.py index b1c091064..428c54e09 100644 --- a/snippets/auth/index.py +++ b/snippets/auth/index.py @@ -25,6 +25,7 @@ from firebase_admin import credentials from firebase_admin import auth from firebase_admin import exceptions +from firebase_admin import tenant_mgt sys.path.append("lib") @@ -634,6 +635,418 @@ def send_custom_email(email, link): del email del link +def create_saml_provider_config(): + # [START create_saml_provider] + saml = auth.create_saml_provider_config( + display_name='SAML provider name', + enabled=True, + provider_id='saml.myProvider', + idp_entity_id='IDP_ENTITY_ID', + sso_url='https://example.com/saml/sso/1234/', + x509_certificates=[ + '-----BEGIN CERTIFICATE-----\nCERT1...\n-----END CERTIFICATE-----', + '-----BEGIN CERTIFICATE-----\nCERT2...\n-----END CERTIFICATE-----', + ], + rp_entity_id='P_ENTITY_ID', + callback_url='https://project-id.firebaseapp.com/__/auth/handler') + + print('Created new SAML provider:', saml.provider_id) + # [END create_saml_provider] + +def update_saml_provider_config(): + # [START update_saml_provider] + saml = auth.update_saml_provider_config( + 'saml.myProvider', + x509_certificates=[ + '-----BEGIN CERTIFICATE-----\nCERT2...\n-----END CERTIFICATE-----', + '-----BEGIN CERTIFICATE-----\nCERT3...\n-----END CERTIFICATE-----', + ]) + + print('Updated SAML provider:', saml.provider_id) + # [END update_saml_provider] + +def get_saml_provider_config(): + # [START get_saml_provider] + saml = auth.get_saml_provider_config('saml.myProvider') + print(saml.display_name, saml.enabled) + # [END get_saml_provider] + +def delete_saml_provider_config(): + # [START delete_saml_provider] + auth.delete_saml_provider_config('saml.myProvider') + # [END delete_saml_provider] + +def list_saml_provider_configs(): + # [START list_saml_providers] + for saml in auth.list_saml_provider_configs('nextPageToken').iterate_all(): + print(saml.provider_id) + # [END list_saml_providers] + +def create_oidc_provider_config(): + # [START create_oidc_provider] + oidc = auth.create_oidc_provider_config( + display_name='OIDC provider name', + enabled=True, + provider_id='oidc.myProvider', + client_id='CLIENT_ID2', + issuer='https://oidc.com/CLIENT_ID2') + + print('Created new OIDC provider:', oidc.provider_id) + # [END create_oidc_provider] + +def update_oidc_provider_config(): + # [START update_oidc_provider] + oidc = auth.update_oidc_provider_config( + 'oidc.myProvider', + client_id='CLIENT_ID', + issuer='https://oidc.com') + + print('Updated OIDC provider:', oidc.provider_id) + # [END update_oidc_provider] + +def get_oidc_provider_config(): + # [START get_oidc_provider] + oidc = auth.get_oidc_provider_config('oidc.myProvider') + + print(oidc.display_name, oidc.enabled) + # [END get_oidc_provider] + +def delete_oidc_provider_config(): + # [START delete_oidc_provider] + auth.delete_oidc_provider_config('oidc.myProvider') + # [END delete_oidc_provider] + +def list_oidc_provider_configs(): + # [START list_oidc_providers] + for oidc in auth.list_oidc_provider_configs('nextPageToken').iterate_all(): + print(oidc.provider_id) + # [END list_oidc_providers] + +def get_tenant_client(tenant_id): + # [START get_tenant_client] + from firebase_admin import tenant_mgt + + tenant_client = tenant_mgt.auth_for_tenant(tenant_id) + # [END get_tenant_client] + return tenant_client + +def get_tenant(tenant_id): + # [START get_tenant] + tenant = tenant_mgt.get_tenant(tenant_id) + + print('Retreieved tenant:', tenant.tenant_id) + # [END get_tenant] + +def create_tenant(): + # [START create_tenant] + tenant = tenant_mgt.create_tenant( + display_name='myTenant1', + enable_email_link_sign_in=True, + allow_password_sign_up=True) + + print('Created tenant:', tenant.tenant_id) + # [END create_tenant] + +def update_tenant(tenant_id): + # [START update_tenant] + tenant = tenant_mgt.update_tenant( + tenant_id, + display_name='updatedName', + allow_password_sign_up=False) # Disable email provider + + print('Updated tenant:', tenant.tenant_id) + # [END update_tenant] + +def delete_tenant(tenant_id): + # [START delete_tenant] + tenant_mgt.delete_tenant(tenant_id) + # [END delete_tenant] + +def list_tenants(): + # [START list_tenants] + for tenant in tenant_mgt.list_tenants().iterate_all(): + print('Retrieved tenant:', tenant.tenant_id) + # [END list_tenants] + +def create_provider_tenant(): + # [START get_tenant_client_short] + tenant_client = tenant_mgt.auth_for_tenant('TENANT-ID') + # [END get_tenant_client_short] + + # [START create_saml_provider_tenant] + saml = tenant_client.create_saml_provider_config( + display_name='SAML provider name', + enabled=True, + provider_id='saml.myProvider', + idp_entity_id='IDP_ENTITY_ID', + sso_url='https://example.com/saml/sso/1234/', + x509_certificates=[ + '-----BEGIN CERTIFICATE-----\nCERT1...\n-----END CERTIFICATE-----', + '-----BEGIN CERTIFICATE-----\nCERT2...\n-----END CERTIFICATE-----', + ], + rp_entity_id='P_ENTITY_ID', + callback_url='https://project-id.firebaseapp.com/__/auth/handler') + + print('Created new SAML provider:', saml.provider_id) + # [END create_saml_provider_tenant] + +def update_provider_tenant(tenant_client): + # [START update_saml_provider_tenant] + saml = tenant_client.update_saml_provider_config( + 'saml.myProvider', + x509_certificates=[ + '-----BEGIN CERTIFICATE-----\nCERT2...\n-----END CERTIFICATE-----', + '-----BEGIN CERTIFICATE-----\nCERT3...\n-----END CERTIFICATE-----', + ]) + + print('Updated SAML provider:', saml.provider_id) + # [END update_saml_provider_tenant] + +def get_provider_tenant(tennat_client): + # [START get_saml_provider_tenant] + saml = tennat_client.get_saml_provider_config('saml.myProvider') + print(saml.display_name, saml.enabled) + # [END get_saml_provider_tenant] + +def list_provider_configs_tenant(tenant_client): + # [START list_saml_providers_tenant] + for saml in tenant_client.list_saml_provider_configs('nextPageToken').iterate_all(): + print(saml.provider_id) + # [END list_saml_providers_tenant] + +def delete_provider_config_tenant(tenant_client): + # [START delete_saml_provider_tenant] + tenant_client.delete_saml_provider_config('saml.myProvider') + # [END delete_saml_provider_tenant] + +def get_user_tenant(tenant_client): + uid = 'some_string_uid' + + # [START get_user_tenant] + # Get an auth.Client from tenant_mgt.auth_for_tenant() + user = tenant_client.get_user(uid) + print('Successfully fetched user data:', user.uid) + # [END get_user_tenant] + +def get_user_by_email_tenant(tenant_client): + email = 'some@email.com' + # [START get_user_by_email_tenant] + user = tenant_client.get_user_by_email(email) + print('Successfully fetched user data:', user.uid) + # [END get_user_by_email_tenant] + +def create_user_tenant(tenant_client): + # [START create_user_tenant] + user = tenant_client.create_user( + email='user@example.com', + email_verified=False, + phone_number='+15555550100', + password='secretPassword', + display_name='John Doe', + photo_url='http://www.example.com/12345678/photo.png', + disabled=False) + print('Sucessfully created new user:', user.uid) + # [END create_user_tenant] + +def update_user_tenant(tenant_client, uid): + # [START update_user_tenant] + user = tenant_client.update_user( + uid, + email='user@example.com', + phone_number='+15555550100', + email_verified=True, + password='newPassword', + display_name='John Doe', + photo_url='http://www.example.com/12345678/photo.png', + disabled=True) + print('Sucessfully updated user:', user.uid) + # [END update_user_tenant] + +def delete_user_tenant(tenant_client, uid): + # [START delete_user_tenant] + tenant_client.delete_user(uid) + print('Successfully deleted user') + # [END delete_user_tenant] + +def list_users_tenant(tenant_client): + # [START list_all_users_tenant] + # Note, behind the scenes, the iterator will retrive 1000 users at a time through the API + for user in tenant_client.list_users().iterate_all(): + print('User: ' + user.uid) + + # Iterating by pages of 1000 users at a time. + page = tenant_client.list_users() + while page: + for user in page.users: + print('User: ' + user.uid) + # Get next batch of users. + page = page.get_next_page() + # [END list_all_users_tenant] + +def import_with_hmac_tenant(tenant_client): + # [START import_with_hmac_tenant] + users = [ + auth.ImportUserRecord( + uid='uid1', + email='user1@example.com', + password_hash=b'password_hash_1', + password_salt=b'salt1' + ), + auth.ImportUserRecord( + uid='uid2', + email='user2@example.com', + password_hash=b'password_hash_2', + password_salt=b'salt2' + ), + ] + + hash_alg = auth.UserImportHash.hmac_sha256(key=b'secret') + try: + result = tenant_client.import_users(users, hash_alg=hash_alg) + for err in result.errors: + print('Failed to import user:', err.reason) + except exceptions.FirebaseError as error: + print('Error importing users:', error) + # [END import_with_hmac_tenant] + +def import_without_password_tenant(tenant_client): + # [START import_without_password_tenant] + users = [ + auth.ImportUserRecord( + uid='some-uid', + display_name='John Doe', + email='johndoe@gmail.com', + photo_url='http://www.example.com/12345678/photo.png', + email_verified=True, + phone_number='+11234567890', + custom_claims={'admin': True}, # set this user as admin + provider_data=[ # user with SAML provider + auth.UserProvider( + uid='saml-uid', + email='johndoe@gmail.com', + display_name='John Doe', + photo_url='http://www.example.com/12345678/photo.png', + provider_id='saml.acme' + ) + ], + ), + ] + try: + result = tenant_client.import_users(users) + for err in result.errors: + print('Failed to import user:', err.reason) + except exceptions.FirebaseError as error: + print('Error importing users:', error) + # [END import_without_password_tenant] + +def verify_id_token_tenant(tenant_client, id_token): + # [START verify_id_token_tenant] + # id_token comes from the client app + try: + decoded_token = tenant_client.verify_id_token(id_token) + + # This should be set to TENANT-ID. Otherwise TenantIdMismatchError error raised. + print('Verified ID token from tenant:', decoded_token['firebase']['tenant']) + except tenant_mgt.TenantIdMismatchError: + # Token revoked, inform the user to reauthenticate or signOut(). + pass + # [END verify_id_token_tenant] + +def verify_id_token_access_control_tenant(id_token): + # [START id_token_access_control_tenant] + decoded_token = auth.verify_id_token(id_token) + + tenant = decoded_token['firebase']['tenant'] + if tenant == 'TENANT-ID1': + # Allow appropriate level of access for TENANT-ID1. + pass + elif tenant == 'TENANT-ID2': + # Allow appropriate level of access for TENANT-ID2. + pass + else: + # Access not allowed -- Handle error + pass + # [END id_token_access_control_tenant] + +def revoke_refresh_tokens_tenant(tenant_client, uid): + # [START revoke_tokens_tenant] + # Revoke all refresh tokens for a specified user in a specified tenant for whatever reason. + # Retrieve the timestamp of the revocation, in seconds since the epoch. + tenant_client.revoke_refresh_tokens(uid) + + user = tenant_client.get_user(uid) + # Convert to seconds as the auth_time in the token claims is in seconds. + revocation_second = user.tokens_valid_after_timestamp / 1000 + print('Tokens revoked at: {0}'.format(revocation_second)) + # [END revoke_tokens_tenant] + +def verify_id_token_and_check_revoked_tenant(tenant_client, id_token): + # [START verify_id_token_and_check_revoked_tenant] + # Verify the ID token for a specific tenant while checking if the token is revoked. + try: + # Verify the ID token while checking if the token is revoked by + # passing check_revoked=True. + decoded_token = tenant_client.verify_id_token(id_token, check_revoked=True) + # Token is valid and not revoked. + uid = decoded_token['uid'] + except tenant_mgt.TenantIdMismatchError: + # Token belongs to a different tenant. + pass + except auth.RevokedIdTokenError: + # Token revoked, inform the user to reauthenticate or signOut(). + pass + except auth.InvalidIdTokenError: + # Token is invalid + pass + # [END verify_id_token_and_check_revoked_tenant] + +def custom_claims_set_tenant(tenant_client, uid): + # [START set_custom_user_claims_tenant] + # Set admin privilege on the user corresponding to uid. + tenant_client.set_custom_user_claims(uid, {'admin': True}) + # The new custom claims will propagate to the user's ID token the + # next time a new one is issued. + # [END set_custom_user_claims_tenant] + +def custom_claims_verify_tenant(tenant_client, id_token): + # [START verify_custom_claims_tenant] + # Verify the ID token first. + claims = tenant_client.verify_id_token(id_token) + if claims['admin'] is True: + # Allow access to requested admin resource. + pass + # [END verify_custom_claims_tenant] + +def custom_claims_read_tenant(tenant_client, uid): + # [START read_custom_user_claims_tenant] + # Lookup the user associated with the specified uid. + user = tenant_client.get_user(uid) + + # The claims can be accessed on the user record. + print(user.custom_claims.get('admin')) + # [END read_custom_user_claims_tenant] + +def generate_email_verification_link_tenant(tenant_client): + # [START email_verification_link_tenant] + action_code_settings = auth.ActionCodeSettings( + url='https://www.example.com/checkout?cartId=1234', + handle_code_in_app=True, + ios_bundle_id='com.example.ios', + android_package_name='com.example.android', + android_install_app=True, + android_minimum_version='12', + # FDL custom domain. + dynamic_link_domain='coolapp.page.link', + ) + + email = 'user@example.com' + link = tenant_client.generate_email_verification_link(email, action_code_settings) + # Construct email from a template embedding the link, and send + # using a custom SMTP server. + send_custom_email(email, link) + # [END email_verification_link_tenant] + + initialize_sdk_with_service_account() initialize_sdk_with_application_default() #initialize_sdk_with_refresh_token() diff --git a/tests/data/list_oidc_provider_configs.json b/tests/data/list_oidc_provider_configs.json new file mode 100644 index 000000000..b2b381304 --- /dev/null +++ b/tests/data/list_oidc_provider_configs.json @@ -0,0 +1,18 @@ +{ + "oauthIdpConfigs": [ + { + "name":"projects/mock-project-id/oauthIdpConfigs/oidc.provider0", + "clientId": "CLIENT_ID", + "issuer": "https://oidc.com/issuer", + "displayName": "oidcProviderName", + "enabled": true + }, + { + "name":"projects/mock-project-id/oauthIdpConfigs/oidc.provider1", + "clientId": "CLIENT_ID", + "issuer": "https://oidc.com/issuer", + "displayName": "oidcProviderName", + "enabled": true + } + ] +} diff --git a/tests/data/list_saml_provider_configs.json b/tests/data/list_saml_provider_configs.json new file mode 100644 index 000000000..b568e1e09 --- /dev/null +++ b/tests/data/list_saml_provider_configs.json @@ -0,0 +1,40 @@ +{ + "inboundSamlConfigs": [ + { + "name": "projects/mock-project-id/inboundSamlConfigs/saml.provider0", + "idpConfig": { + "idpEntityId": "IDP_ENTITY_ID", + "ssoUrl": "https://example.com/login", + "signRequest": true, + "idpCertificates": [ + {"x509Certificate": "CERT1"}, + {"x509Certificate": "CERT2"} + ] + }, + "spConfig": { + "spEntityId": "RP_ENTITY_ID", + "callbackUri": "https://projectId.firebaseapp.com/__/auth/handler" + }, + "displayName": "samlProviderName", + "enabled": true + }, + { + "name": "projects/mock-project-id/inboundSamlConfigs/saml.provider1", + "idpConfig": { + "idpEntityId": "IDP_ENTITY_ID", + "ssoUrl": "https://example.com/login", + "signRequest": true, + "idpCertificates": [ + {"x509Certificate": "CERT1"}, + {"x509Certificate": "CERT2"} + ] + }, + "spConfig": { + "spEntityId": "RP_ENTITY_ID", + "callbackUri": "https://projectId.firebaseapp.com/__/auth/handler" + }, + "displayName": "samlProviderName", + "enabled": true + } + ] +} diff --git a/tests/data/oidc_provider_config.json b/tests/data/oidc_provider_config.json new file mode 100644 index 000000000..89cf3eacf --- /dev/null +++ b/tests/data/oidc_provider_config.json @@ -0,0 +1,7 @@ +{ + "name":"projects/mock-project-id/oauthIdpConfigs/oidc.provider", + "clientId": "CLIENT_ID", + "issuer": "https://oidc.com/issuer", + "displayName": "oidcProviderName", + "enabled": true +} diff --git a/tests/data/saml_provider_config.json b/tests/data/saml_provider_config.json new file mode 100644 index 000000000..577340f2a --- /dev/null +++ b/tests/data/saml_provider_config.json @@ -0,0 +1,18 @@ +{ + "name": "projects/mock-project-id/inboundSamlConfigs/saml.provider", + "idpConfig": { + "idpEntityId": "IDP_ENTITY_ID", + "ssoUrl": "https://example.com/login", + "signRequest": true, + "idpCertificates": [ + {"x509Certificate": "CERT1"}, + {"x509Certificate": "CERT2"} + ] + }, + "spConfig": { + "spEntityId": "RP_ENTITY_ID", + "callbackUri": "https://projectId.firebaseapp.com/__/auth/handler" + }, + "displayName": "samlProviderName", + "enabled": true +} \ No newline at end of file diff --git a/tests/test_auth_providers.py b/tests/test_auth_providers.py new file mode 100644 index 000000000..124aea3cc --- /dev/null +++ b/tests/test_auth_providers.py @@ -0,0 +1,732 @@ +# Copyright 2020 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test cases for the firebase_admin._auth_providers module.""" + +import json + +import pytest + +import firebase_admin +from firebase_admin import auth +from firebase_admin import exceptions +from firebase_admin import _auth_providers +from tests import testutils + +USER_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v2beta1/projects/mock-project-id' +OIDC_PROVIDER_CONFIG_RESPONSE = testutils.resource('oidc_provider_config.json') +SAML_PROVIDER_CONFIG_RESPONSE = testutils.resource('saml_provider_config.json') +LIST_OIDC_PROVIDER_CONFIGS_RESPONSE = testutils.resource('list_oidc_provider_configs.json') +LIST_SAML_PROVIDER_CONFIGS_RESPONSE = testutils.resource('list_saml_provider_configs.json') + +CONFIG_NOT_FOUND_RESPONSE = """{ + "error": { + "message": "CONFIGURATION_NOT_FOUND" + } +}""" + +INVALID_PROVIDER_IDS = [None, True, False, 1, 0, list(), tuple(), dict(), ''] + + +@pytest.fixture(scope='module') +def user_mgt_app(): + app = firebase_admin.initialize_app(testutils.MockCredential(), name='providerConfig', + options={'projectId': 'mock-project-id'}) + yield app + firebase_admin.delete_app(app) + + +def _instrument_provider_mgt(app, status, payload): + client = auth._get_client(app) + provider_manager = client._provider_manager + recorder = [] + provider_manager.http_client.session.mount( + _auth_providers.ProviderConfigClient.PROVIDER_CONFIG_URL, + testutils.MockAdapter(payload, status, recorder)) + return recorder + + +class TestOIDCProviderConfig: + + VALID_CREATE_OPTIONS = { + 'provider_id': 'oidc.provider', + 'client_id': 'CLIENT_ID', + 'issuer': 'https://oidc.com/issuer', + 'display_name': 'oidcProviderName', + 'enabled': True, + } + + OIDC_CONFIG_REQUEST = { + 'displayName': 'oidcProviderName', + 'enabled': True, + 'clientId': 'CLIENT_ID', + 'issuer': 'https://oidc.com/issuer', + } + + @pytest.mark.parametrize('provider_id', INVALID_PROVIDER_IDS + ['saml.provider']) + def test_get_invalid_provider_id(self, user_mgt_app, provider_id): + with pytest.raises(ValueError) as excinfo: + auth.get_oidc_provider_config(provider_id, app=user_mgt_app) + + assert str(excinfo.value).startswith('Invalid OIDC provider ID') + + def test_get(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, OIDC_PROVIDER_CONFIG_RESPONSE) + + provider_config = auth.get_oidc_provider_config('oidc.provider', app=user_mgt_app) + + self._assert_provider_config(provider_config) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, '/oauthIdpConfigs/oidc.provider') + + @pytest.mark.parametrize('invalid_opts', [ + {'provider_id': None}, {'provider_id': ''}, {'provider_id': 'saml.provider'}, + {'client_id': None}, {'client_id': ''}, + {'issuer': None}, {'issuer': ''}, {'issuer': 'not a url'}, + {'display_name': True}, + {'enabled': 'true'}, + ]) + def test_create_invalid_args(self, user_mgt_app, invalid_opts): + options = dict(self.VALID_CREATE_OPTIONS) + options.update(invalid_opts) + with pytest.raises(ValueError): + auth.create_oidc_provider_config(**options, app=user_mgt_app) + + def test_create(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, OIDC_PROVIDER_CONFIG_RESPONSE) + + provider_config = auth.create_oidc_provider_config( + **self.VALID_CREATE_OPTIONS, app=user_mgt_app) + + self._assert_provider_config(provider_config) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'POST' + assert req.url == '{0}/oauthIdpConfigs?oauthIdpConfigId=oidc.provider'.format( + USER_MGT_URL_PREFIX) + got = json.loads(req.body.decode()) + assert got == self.OIDC_CONFIG_REQUEST + + def test_create_minimal(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, OIDC_PROVIDER_CONFIG_RESPONSE) + options = dict(self.VALID_CREATE_OPTIONS) + del options['display_name'] + del options['enabled'] + want = dict(self.OIDC_CONFIG_REQUEST) + del want['displayName'] + del want['enabled'] + + provider_config = auth.create_oidc_provider_config(**options, app=user_mgt_app) + + self._assert_provider_config(provider_config) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'POST' + assert req.url == '{0}/oauthIdpConfigs?oauthIdpConfigId=oidc.provider'.format( + USER_MGT_URL_PREFIX) + got = json.loads(req.body.decode()) + assert got == want + + def test_create_empty_values(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, OIDC_PROVIDER_CONFIG_RESPONSE) + options = dict(self.VALID_CREATE_OPTIONS) + options['display_name'] = '' + options['enabled'] = False + want = dict(self.OIDC_CONFIG_REQUEST) + want['displayName'] = '' + want['enabled'] = False + + provider_config = auth.create_oidc_provider_config(**options, app=user_mgt_app) + + self._assert_provider_config(provider_config) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'POST' + assert req.url == '{0}/oauthIdpConfigs?oauthIdpConfigId=oidc.provider'.format( + USER_MGT_URL_PREFIX) + got = json.loads(req.body.decode()) + assert got == want + + @pytest.mark.parametrize('invalid_opts', [ + {}, + {'provider_id': None}, {'provider_id': ''}, {'provider_id': 'saml.provider'}, + {'client_id': ''}, + {'issuer': ''}, {'issuer': 'not a url'}, + {'display_name': True}, + {'enabled': 'true'}, + ]) + def test_update_invalid_args(self, user_mgt_app, invalid_opts): + options = {'provider_id': 'oidc.provider'} + options.update(invalid_opts) + with pytest.raises(ValueError): + auth.update_oidc_provider_config(**options, app=user_mgt_app) + + def test_update(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, OIDC_PROVIDER_CONFIG_RESPONSE) + + provider_config = auth.update_oidc_provider_config( + **self.VALID_CREATE_OPTIONS, app=user_mgt_app) + + self._assert_provider_config(provider_config) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'PATCH' + mask = ['clientId', 'displayName', 'enabled', 'issuer'] + assert req.url == '{0}/oauthIdpConfigs/oidc.provider?updateMask={1}'.format( + USER_MGT_URL_PREFIX, ','.join(mask)) + got = json.loads(req.body.decode()) + assert got == self.OIDC_CONFIG_REQUEST + + def test_update_minimal(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, OIDC_PROVIDER_CONFIG_RESPONSE) + + provider_config = auth.update_oidc_provider_config( + 'oidc.provider', display_name='oidcProviderName', app=user_mgt_app) + + self._assert_provider_config(provider_config) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'PATCH' + assert req.url == '{0}/oauthIdpConfigs/oidc.provider?updateMask=displayName'.format( + USER_MGT_URL_PREFIX) + got = json.loads(req.body.decode()) + assert got == {'displayName': 'oidcProviderName'} + + def test_update_empty_values(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, OIDC_PROVIDER_CONFIG_RESPONSE) + + provider_config = auth.update_oidc_provider_config( + 'oidc.provider', display_name=auth.DELETE_ATTRIBUTE, enabled=False, app=user_mgt_app) + + self._assert_provider_config(provider_config) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'PATCH' + mask = ['displayName', 'enabled'] + assert req.url == '{0}/oauthIdpConfigs/oidc.provider?updateMask={1}'.format( + USER_MGT_URL_PREFIX, ','.join(mask)) + got = json.loads(req.body.decode()) + assert got == {'displayName': None, 'enabled': False} + + @pytest.mark.parametrize('provider_id', INVALID_PROVIDER_IDS + ['saml.provider']) + def test_delete_invalid_provider_id(self, user_mgt_app, provider_id): + with pytest.raises(ValueError) as excinfo: + auth.delete_oidc_provider_config(provider_id, app=user_mgt_app) + + assert str(excinfo.value).startswith('Invalid OIDC provider ID') + + def test_delete(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, '{}') + + auth.delete_oidc_provider_config('oidc.provider', app=user_mgt_app) + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'DELETE' + assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, '/oauthIdpConfigs/oidc.provider') + + @pytest.mark.parametrize('arg', [None, 'foo', list(), dict(), 0, -1, 101, False]) + def test_invalid_max_results(self, user_mgt_app, arg): + with pytest.raises(ValueError): + auth.list_oidc_provider_configs(max_results=arg, app=user_mgt_app) + + @pytest.mark.parametrize('arg', ['', list(), dict(), 0, -1, 101, False]) + def test_invalid_page_token(self, user_mgt_app, arg): + with pytest.raises(ValueError): + auth.list_oidc_provider_configs(page_token=arg, app=user_mgt_app) + + def test_list_single_page(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, LIST_OIDC_PROVIDER_CONFIGS_RESPONSE) + page = auth.list_oidc_provider_configs(app=user_mgt_app) + + self._assert_page(page) + provider_configs = list(config for config in page.iterate_all()) + assert len(provider_configs) == 2 + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, '/oauthIdpConfigs?pageSize=100') + + def test_list_multiple_pages(self, user_mgt_app): + sample_response = json.loads(OIDC_PROVIDER_CONFIG_RESPONSE) + configs = _create_list_response(sample_response) + + # Page 1 + response = { + 'oauthIdpConfigs': configs[:2], + 'nextPageToken': 'token' + } + recorder = _instrument_provider_mgt(user_mgt_app, 200, json.dumps(response)) + page = auth.list_oidc_provider_configs(max_results=10, app=user_mgt_app) + + self._assert_page(page, next_page_token='token') + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/oauthIdpConfigs?pageSize=10'.format(USER_MGT_URL_PREFIX) + + # Page 2 (also the last page) + response = {'oauthIdpConfigs': configs[2:]} + recorder = _instrument_provider_mgt(user_mgt_app, 200, json.dumps(response)) + page = page.get_next_page() + + self._assert_page(page, count=1, start=2) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/oauthIdpConfigs?pageSize=10&pageToken=token'.format( + USER_MGT_URL_PREFIX) + + def test_paged_iteration(self, user_mgt_app): + sample_response = json.loads(OIDC_PROVIDER_CONFIG_RESPONSE) + configs = _create_list_response(sample_response) + + # Page 1 + response = { + 'oauthIdpConfigs': configs[:2], + 'nextPageToken': 'token' + } + recorder = _instrument_provider_mgt(user_mgt_app, 200, json.dumps(response)) + page = auth.list_oidc_provider_configs(app=user_mgt_app) + iterator = page.iterate_all() + + for index in range(2): + provider_config = next(iterator) + assert provider_config.provider_id == 'oidc.provider{0}'.format(index) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/oauthIdpConfigs?pageSize=100'.format(USER_MGT_URL_PREFIX) + + # Page 2 (also the last page) + response = {'oauthIdpConfigs': configs[2:]} + recorder = _instrument_provider_mgt(user_mgt_app, 200, json.dumps(response)) + + provider_config = next(iterator) + assert provider_config.provider_id == 'oidc.provider2' + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/oauthIdpConfigs?pageSize=100&pageToken=token'.format( + USER_MGT_URL_PREFIX) + + with pytest.raises(StopIteration): + next(iterator) + + def test_list_empty_response(self, user_mgt_app): + response = {'oauthIdpConfigs': []} + _instrument_provider_mgt(user_mgt_app, 200, json.dumps(response)) + page = auth.list_oidc_provider_configs(app=user_mgt_app) + assert len(page.provider_configs) == 0 + provider_configs = list(config for config in page.iterate_all()) + assert len(provider_configs) == 0 + + def test_list_error(self, user_mgt_app): + _instrument_provider_mgt(user_mgt_app, 500, '{"error":"test"}') + with pytest.raises(exceptions.InternalError) as excinfo: + auth.list_oidc_provider_configs(app=user_mgt_app) + assert str(excinfo.value) == 'Unexpected error response: {"error":"test"}' + + def test_config_not_found(self, user_mgt_app): + _instrument_provider_mgt(user_mgt_app, 500, CONFIG_NOT_FOUND_RESPONSE) + + with pytest.raises(auth.ConfigurationNotFoundError) as excinfo: + auth.get_oidc_provider_config('oidc.provider', app=user_mgt_app) + + error_msg = 'No auth provider found for the given identifier (CONFIGURATION_NOT_FOUND).' + assert excinfo.value.code == exceptions.NOT_FOUND + assert str(excinfo.value) == error_msg + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None + + def _assert_provider_config(self, provider_config, want_id='oidc.provider'): + assert isinstance(provider_config, auth.OIDCProviderConfig) + assert provider_config.provider_id == want_id + assert provider_config.display_name == 'oidcProviderName' + assert provider_config.enabled is True + assert provider_config.issuer == 'https://oidc.com/issuer' + assert provider_config.client_id == 'CLIENT_ID' + + def _assert_page(self, page, count=2, start=0, next_page_token=''): + assert isinstance(page, auth.ListProviderConfigsPage) + index = start + assert len(page.provider_configs) == count + for provider_config in page.provider_configs: + self._assert_provider_config(provider_config, want_id='oidc.provider{0}'.format(index)) + index += 1 + + if next_page_token: + assert page.next_page_token == next_page_token + assert page.has_next_page is True + else: + assert page.next_page_token == '' + assert page.has_next_page is False + assert page.get_next_page() is None + + +class TestSAMLProviderConfig: + + VALID_CREATE_OPTIONS = { + 'provider_id': 'saml.provider', + 'idp_entity_id': 'IDP_ENTITY_ID', + 'sso_url': 'https://example.com/login', + 'x509_certificates': ['CERT1', 'CERT2'], + 'rp_entity_id': 'RP_ENTITY_ID', + 'callback_url': 'https://projectId.firebaseapp.com/__/auth/handler', + 'display_name': 'samlProviderName', + 'enabled': True, + } + + SAML_CONFIG_REQUEST = { + 'displayName': 'samlProviderName', + 'enabled': True, + 'idpConfig': { + 'idpEntityId': 'IDP_ENTITY_ID', + 'ssoUrl': 'https://example.com/login', + 'idpCertificates': [{'x509Certificate': 'CERT1'}, {'x509Certificate': 'CERT2'}] + }, + 'spConfig': { + 'spEntityId': 'RP_ENTITY_ID', + 'callbackUri': 'https://projectId.firebaseapp.com/__/auth/handler', + } + } + + @pytest.mark.parametrize('provider_id', INVALID_PROVIDER_IDS + ['oidc.provider']) + def test_get_invalid_provider_id(self, user_mgt_app, provider_id): + with pytest.raises(ValueError) as excinfo: + auth.get_saml_provider_config(provider_id, app=user_mgt_app) + + assert str(excinfo.value).startswith('Invalid SAML provider ID') + + def test_get(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, SAML_PROVIDER_CONFIG_RESPONSE) + + provider_config = auth.get_saml_provider_config('saml.provider', app=user_mgt_app) + + self._assert_provider_config(provider_config) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, '/inboundSamlConfigs/saml.provider') + + @pytest.mark.parametrize('invalid_opts', [ + {'provider_id': None}, {'provider_id': ''}, {'provider_id': 'oidc.provider'}, + {'idp_entity_id': None}, {'idp_entity_id': ''}, + {'sso_url': None}, {'sso_url': ''}, {'sso_url': 'not a url'}, + {'x509_certificates': None}, {'x509_certificates': []}, {'x509_certificates': 'cert'}, + {'x509_certificates': [None]}, {'x509_certificates': ['foo', {}]}, + {'rp_entity_id': None}, {'rp_entity_id': ''}, + {'callback_url': None}, {'callback_url': ''}, {'callback_url': 'not a url'}, + {'display_name': True}, + {'enabled': 'true'}, + ]) + def test_create_invalid_args(self, user_mgt_app, invalid_opts): + options = dict(self.VALID_CREATE_OPTIONS) + options.update(invalid_opts) + with pytest.raises(ValueError): + auth.create_saml_provider_config(**options, app=user_mgt_app) + + def test_create(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, SAML_PROVIDER_CONFIG_RESPONSE) + + provider_config = auth.create_saml_provider_config( + **self.VALID_CREATE_OPTIONS, app=user_mgt_app) + + self._assert_provider_config(provider_config) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'POST' + assert req.url == '{0}/inboundSamlConfigs?inboundSamlConfigId=saml.provider'.format( + USER_MGT_URL_PREFIX) + got = json.loads(req.body.decode()) + assert got == self.SAML_CONFIG_REQUEST + + def test_create_minimal(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, SAML_PROVIDER_CONFIG_RESPONSE) + options = dict(self.VALID_CREATE_OPTIONS) + del options['display_name'] + del options['enabled'] + want = dict(self.SAML_CONFIG_REQUEST) + del want['displayName'] + del want['enabled'] + + provider_config = auth.create_saml_provider_config(**options, app=user_mgt_app) + + self._assert_provider_config(provider_config) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'POST' + assert req.url == '{0}/inboundSamlConfigs?inboundSamlConfigId=saml.provider'.format( + USER_MGT_URL_PREFIX) + got = json.loads(req.body.decode()) + assert got == want + + def test_create_empty_values(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, SAML_PROVIDER_CONFIG_RESPONSE) + options = dict(self.VALID_CREATE_OPTIONS) + options['display_name'] = '' + options['enabled'] = False + want = dict(self.SAML_CONFIG_REQUEST) + want['displayName'] = '' + want['enabled'] = False + + provider_config = auth.create_saml_provider_config(**options, app=user_mgt_app) + + self._assert_provider_config(provider_config) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'POST' + assert req.url == '{0}/inboundSamlConfigs?inboundSamlConfigId=saml.provider'.format( + USER_MGT_URL_PREFIX) + got = json.loads(req.body.decode()) + assert got == want + + @pytest.mark.parametrize('invalid_opts', [ + {}, + {'provider_id': None}, {'provider_id': ''}, {'provider_id': 'oidc.provider'}, + {'idp_entity_id': ''}, + {'sso_url': ''}, {'sso_url': 'not a url'}, + {'x509_certificates': []}, {'x509_certificates': 'cert'}, + {'x509_certificates': [None]}, {'x509_certificates': ['foo', {}]}, + {'rp_entity_id': ''}, + {'callback_url': ''}, {'callback_url': 'not a url'}, + {'display_name': True}, + {'enabled': 'true'}, + ]) + def test_update_invalid_args(self, user_mgt_app, invalid_opts): + options = {'provider_id': 'saml.provider'} + options.update(invalid_opts) + with pytest.raises(ValueError): + auth.update_saml_provider_config(**options, app=user_mgt_app) + + def test_update(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, SAML_PROVIDER_CONFIG_RESPONSE) + + provider_config = auth.update_saml_provider_config( + **self.VALID_CREATE_OPTIONS, app=user_mgt_app) + + self._assert_provider_config(provider_config) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'PATCH' + mask = [ + 'displayName', 'enabled', 'idpConfig.idpCertificates', 'idpConfig.idpEntityId', + 'idpConfig.ssoUrl', 'spConfig.callbackUri', 'spConfig.spEntityId', + ] + assert req.url == '{0}/inboundSamlConfigs/saml.provider?updateMask={1}'.format( + USER_MGT_URL_PREFIX, ','.join(mask)) + got = json.loads(req.body.decode()) + assert got == self.SAML_CONFIG_REQUEST + + def test_update_minimal(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, SAML_PROVIDER_CONFIG_RESPONSE) + + provider_config = auth.update_saml_provider_config( + 'saml.provider', display_name='samlProviderName', app=user_mgt_app) + + self._assert_provider_config(provider_config) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'PATCH' + assert req.url == '{0}/inboundSamlConfigs/saml.provider?updateMask=displayName'.format( + USER_MGT_URL_PREFIX) + got = json.loads(req.body.decode()) + assert got == {'displayName': 'samlProviderName'} + + def test_update_empty_values(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, SAML_PROVIDER_CONFIG_RESPONSE) + + provider_config = auth.update_saml_provider_config( + 'saml.provider', display_name=auth.DELETE_ATTRIBUTE, enabled=False, app=user_mgt_app) + + self._assert_provider_config(provider_config) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'PATCH' + mask = ['displayName', 'enabled'] + assert req.url == '{0}/inboundSamlConfigs/saml.provider?updateMask={1}'.format( + USER_MGT_URL_PREFIX, ','.join(mask)) + got = json.loads(req.body.decode()) + assert got == {'displayName': None, 'enabled': False} + + @pytest.mark.parametrize('provider_id', INVALID_PROVIDER_IDS + ['oidc.provider']) + def test_delete_invalid_provider_id(self, user_mgt_app, provider_id): + with pytest.raises(ValueError) as excinfo: + auth.delete_saml_provider_config(provider_id, app=user_mgt_app) + + assert str(excinfo.value).startswith('Invalid SAML provider ID') + + def test_delete(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, '{}') + + auth.delete_saml_provider_config('saml.provider', app=user_mgt_app) + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'DELETE' + assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, '/inboundSamlConfigs/saml.provider') + + def test_config_not_found(self, user_mgt_app): + _instrument_provider_mgt(user_mgt_app, 500, CONFIG_NOT_FOUND_RESPONSE) + + with pytest.raises(auth.ConfigurationNotFoundError) as excinfo: + auth.get_saml_provider_config('saml.provider', app=user_mgt_app) + + error_msg = 'No auth provider found for the given identifier (CONFIGURATION_NOT_FOUND).' + assert excinfo.value.code == exceptions.NOT_FOUND + assert str(excinfo.value) == error_msg + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None + + @pytest.mark.parametrize('arg', [None, 'foo', list(), dict(), 0, -1, 101, False]) + def test_invalid_max_results(self, user_mgt_app, arg): + with pytest.raises(ValueError): + auth.list_saml_provider_configs(max_results=arg, app=user_mgt_app) + + @pytest.mark.parametrize('arg', ['', list(), dict(), 0, -1, 101, False]) + def test_invalid_page_token(self, user_mgt_app, arg): + with pytest.raises(ValueError): + auth.list_saml_provider_configs(page_token=arg, app=user_mgt_app) + + def test_list_single_page(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, LIST_SAML_PROVIDER_CONFIGS_RESPONSE) + page = auth.list_saml_provider_configs(app=user_mgt_app) + + self._assert_page(page) + provider_configs = list(config for config in page.iterate_all()) + assert len(provider_configs) == 2 + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, '/inboundSamlConfigs?pageSize=100') + + def test_list_multiple_pages(self, user_mgt_app): + sample_response = json.loads(SAML_PROVIDER_CONFIG_RESPONSE) + configs = _create_list_response(sample_response) + + # Page 1 + response = { + 'inboundSamlConfigs': configs[:2], + 'nextPageToken': 'token' + } + recorder = _instrument_provider_mgt(user_mgt_app, 200, json.dumps(response)) + page = auth.list_saml_provider_configs(max_results=10, app=user_mgt_app) + + self._assert_page(page, next_page_token='token') + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/inboundSamlConfigs?pageSize=10'.format(USER_MGT_URL_PREFIX) + + # Page 2 (also the last page) + response = {'inboundSamlConfigs': configs[2:]} + recorder = _instrument_provider_mgt(user_mgt_app, 200, json.dumps(response)) + page = page.get_next_page() + + self._assert_page(page, count=1, start=2) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/inboundSamlConfigs?pageSize=10&pageToken=token'.format( + USER_MGT_URL_PREFIX) + + def test_paged_iteration(self, user_mgt_app): + sample_response = json.loads(SAML_PROVIDER_CONFIG_RESPONSE) + configs = _create_list_response(sample_response) + + # Page 1 + response = { + 'inboundSamlConfigs': configs[:2], + 'nextPageToken': 'token' + } + recorder = _instrument_provider_mgt(user_mgt_app, 200, json.dumps(response)) + page = auth.list_saml_provider_configs(app=user_mgt_app) + iterator = page.iterate_all() + + for index in range(2): + provider_config = next(iterator) + assert provider_config.provider_id == 'saml.provider{0}'.format(index) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/inboundSamlConfigs?pageSize=100'.format(USER_MGT_URL_PREFIX) + + # Page 2 (also the last page) + response = {'inboundSamlConfigs': configs[2:]} + recorder = _instrument_provider_mgt(user_mgt_app, 200, json.dumps(response)) + + provider_config = next(iterator) + assert provider_config.provider_id == 'saml.provider2' + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/inboundSamlConfigs?pageSize=100&pageToken=token'.format( + USER_MGT_URL_PREFIX) + + with pytest.raises(StopIteration): + next(iterator) + + def test_list_empty_response(self, user_mgt_app): + response = {'inboundSamlConfigs': []} + _instrument_provider_mgt(user_mgt_app, 200, json.dumps(response)) + page = auth.list_saml_provider_configs(app=user_mgt_app) + assert len(page.provider_configs) == 0 + provider_configs = list(config for config in page.iterate_all()) + assert len(provider_configs) == 0 + + def test_list_error(self, user_mgt_app): + _instrument_provider_mgt(user_mgt_app, 500, '{"error":"test"}') + with pytest.raises(exceptions.InternalError) as excinfo: + auth.list_saml_provider_configs(app=user_mgt_app) + assert str(excinfo.value) == 'Unexpected error response: {"error":"test"}' + + def _assert_provider_config(self, provider_config, want_id='saml.provider'): + assert isinstance(provider_config, auth.SAMLProviderConfig) + assert provider_config.provider_id == want_id + assert provider_config.display_name == 'samlProviderName' + assert provider_config.enabled is True + assert provider_config.idp_entity_id == 'IDP_ENTITY_ID' + assert provider_config.sso_url == 'https://example.com/login' + assert provider_config.x509_certificates == ['CERT1', 'CERT2'] + assert provider_config.rp_entity_id == 'RP_ENTITY_ID' + assert provider_config.callback_url == 'https://projectId.firebaseapp.com/__/auth/handler' + + def _assert_page(self, page, count=2, start=0, next_page_token=''): + assert isinstance(page, auth.ListProviderConfigsPage) + index = start + assert len(page.provider_configs) == count + for provider_config in page.provider_configs: + self._assert_provider_config(provider_config, want_id='saml.provider{0}'.format(index)) + index += 1 + + if next_page_token: + assert page.next_page_token == next_page_token + assert page.has_next_page is True + else: + assert page.next_page_token == '' + assert page.has_next_page is False + assert page.get_next_page() is None + + +def _create_list_response(sample_response, count=3): + configs = [] + for idx in range(count): + config = dict(sample_response) + config['name'] += str(idx) + configs.append(config) + return configs diff --git a/tests/test_tenant_mgt.py b/tests/test_tenant_mgt.py new file mode 100644 index 000000000..f92dd2a83 --- /dev/null +++ b/tests/test_tenant_mgt.py @@ -0,0 +1,1004 @@ +# Copyright 2020 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test cases for the firebase_admin.tenant_mgt module.""" + +import json +from urllib import parse + +import pytest + +import firebase_admin +from firebase_admin import auth +from firebase_admin import credentials +from firebase_admin import exceptions +from firebase_admin import tenant_mgt +from firebase_admin import _auth_providers +from firebase_admin import _user_mgt +from tests import testutils +from tests import test_token_gen + + +GET_TENANT_RESPONSE = """{ + "name": "projects/mock-project-id/tenants/tenant-id", + "displayName": "Test Tenant", + "allowPasswordSignup": true, + "enableEmailLinkSignin": true +}""" + +TENANT_NOT_FOUND_RESPONSE = """{ + "error": { + "message": "TENANT_NOT_FOUND" + } +}""" + +LIST_TENANTS_RESPONSE = """{ + "tenants": [ + { + "name": "projects/mock-project-id/tenants/tenant0", + "displayName": "Test Tenant", + "allowPasswordSignup": true, + "enableEmailLinkSignin": true + }, + { + "name": "projects/mock-project-id/tenants/tenant1", + "displayName": "Test Tenant", + "allowPasswordSignup": true, + "enableEmailLinkSignin": true + } + ] +}""" + +LIST_TENANTS_RESPONSE_WITH_TOKEN = """{ + "tenants": [ + { + "name": "projects/mock-project-id/tenants/tenant0" + }, + { + "name": "projects/mock-project-id/tenants/tenant1" + }, + { + "name": "projects/mock-project-id/tenants/tenant2" + } + ], + "nextPageToken": "token" +}""" + +MOCK_GET_USER_RESPONSE = testutils.resource('get_user.json') +MOCK_LIST_USERS_RESPONSE = testutils.resource('list_users.json') + +OIDC_PROVIDER_CONFIG_RESPONSE = testutils.resource('oidc_provider_config.json') +OIDC_PROVIDER_CONFIG_REQUEST = { + 'displayName': 'oidcProviderName', + 'enabled': True, + 'clientId': 'CLIENT_ID', + 'issuer': 'https://oidc.com/issuer', +} + +SAML_PROVIDER_CONFIG_RESPONSE = testutils.resource('saml_provider_config.json') +SAML_PROVIDER_CONFIG_REQUEST = body = { + 'displayName': 'samlProviderName', + 'enabled': True, + 'idpConfig': { + 'idpEntityId': 'IDP_ENTITY_ID', + 'ssoUrl': 'https://example.com/login', + 'idpCertificates': [{'x509Certificate': 'CERT1'}, {'x509Certificate': 'CERT2'}] + }, + 'spConfig': { + 'spEntityId': 'RP_ENTITY_ID', + 'callbackUri': 'https://projectId.firebaseapp.com/__/auth/handler', + } +} + +LIST_OIDC_PROVIDER_CONFIGS_RESPONSE = testutils.resource('list_oidc_provider_configs.json') +LIST_SAML_PROVIDER_CONFIGS_RESPONSE = testutils.resource('list_saml_provider_configs.json') + +INVALID_TENANT_IDS = [None, '', 0, 1, True, False, list(), tuple(), dict()] +INVALID_BOOLEANS = ['', 1, 0, list(), tuple(), dict()] + +USER_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v1/projects/mock-project-id' +PROVIDER_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v2beta1/projects/mock-project-id' +TENANT_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v2beta1/projects/mock-project-id' + + +@pytest.fixture(scope='module') +def tenant_mgt_app(): + app = firebase_admin.initialize_app( + testutils.MockCredential(), name='tenantMgt', options={'projectId': 'mock-project-id'}) + yield app + firebase_admin.delete_app(app) + + +def _instrument_tenant_mgt(app, status, payload): + service = tenant_mgt._get_tenant_mgt_service(app) + recorder = [] + service.client.session.mount( + tenant_mgt._TenantManagementService.TENANT_MGT_URL, + testutils.MockAdapter(payload, status, recorder)) + return service, recorder + + +def _instrument_user_mgt(client, status, payload): + recorder = [] + user_manager = client._user_manager + user_manager.http_client.session.mount( + _user_mgt.UserManager.ID_TOOLKIT_URL, + testutils.MockAdapter(payload, status, recorder)) + return recorder + + +def _instrument_provider_mgt(client, status, payload): + recorder = [] + provider_manager = client._provider_manager + provider_manager.http_client.session.mount( + _auth_providers.ProviderConfigClient.PROVIDER_CONFIG_URL, + testutils.MockAdapter(payload, status, recorder)) + return recorder + + +class TestTenant: + + @pytest.mark.parametrize('data', [None, 'foo', 0, 1, True, False, list(), tuple(), dict()]) + def test_invalid_data(self, data): + with pytest.raises(ValueError): + tenant_mgt.Tenant(data) + + def test_tenant(self): + data = { + 'name': 'projects/test-project/tenants/tenant-id', + 'displayName': 'Test Tenant', + 'allowPasswordSignup': True, + 'enableEmailLinkSignin': True, + } + tenant = tenant_mgt.Tenant(data) + assert tenant.tenant_id == 'tenant-id' + assert tenant.display_name == 'Test Tenant' + assert tenant.allow_password_sign_up is True + assert tenant.enable_email_link_sign_in is True + + def test_tenant_optional_params(self): + data = { + 'name': 'projects/test-project/tenants/tenant-id', + } + tenant = tenant_mgt.Tenant(data) + assert tenant.tenant_id == 'tenant-id' + assert tenant.display_name is None + assert tenant.allow_password_sign_up is False + assert tenant.enable_email_link_sign_in is False + + +class TestGetTenant: + + @pytest.mark.parametrize('tenant_id', INVALID_TENANT_IDS) + def test_invalid_tenant_id(self, tenant_id, tenant_mgt_app): + with pytest.raises(ValueError) as excinfo: + tenant_mgt.get_tenant(tenant_id, app=tenant_mgt_app) + assert str(excinfo.value).startswith('Invalid tenant ID') + + def test_get_tenant(self, tenant_mgt_app): + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, GET_TENANT_RESPONSE) + tenant = tenant_mgt.get_tenant('tenant-id', app=tenant_mgt_app) + + _assert_tenant(tenant) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/tenants/tenant-id'.format(TENANT_MGT_URL_PREFIX) + + def test_tenant_not_found(self, tenant_mgt_app): + _instrument_tenant_mgt(tenant_mgt_app, 500, TENANT_NOT_FOUND_RESPONSE) + with pytest.raises(tenant_mgt.TenantNotFoundError) as excinfo: + tenant_mgt.get_tenant('tenant-id', app=tenant_mgt_app) + + error_msg = 'No tenant found for the given identifier (TENANT_NOT_FOUND).' + assert excinfo.value.code == exceptions.NOT_FOUND + assert str(excinfo.value) == error_msg + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None + + +class TestCreateTenant: + + @pytest.mark.parametrize('display_name', [True, False, 1, 0, list(), tuple(), dict()]) + def test_invalid_display_name_type(self, display_name, tenant_mgt_app): + with pytest.raises(ValueError) as excinfo: + tenant_mgt.create_tenant(display_name=display_name, app=tenant_mgt_app) + assert str(excinfo.value).startswith('Invalid type for displayName') + + @pytest.mark.parametrize('display_name', ['', 'foo', '1test', 'foo bar', 'a'*21]) + def test_invalid_display_name_value(self, display_name, tenant_mgt_app): + with pytest.raises(ValueError) as excinfo: + tenant_mgt.create_tenant(display_name=display_name, app=tenant_mgt_app) + assert str(excinfo.value).startswith('displayName must start') + + @pytest.mark.parametrize('allow', INVALID_BOOLEANS) + def test_invalid_allow_password_sign_up(self, allow, tenant_mgt_app): + with pytest.raises(ValueError) as excinfo: + tenant_mgt.create_tenant( + display_name='test', allow_password_sign_up=allow, app=tenant_mgt_app) + assert str(excinfo.value).startswith('Invalid type for allowPasswordSignup') + + @pytest.mark.parametrize('enable', INVALID_BOOLEANS) + def test_invalid_enable_email_link_sign_in(self, enable, tenant_mgt_app): + with pytest.raises(ValueError) as excinfo: + tenant_mgt.create_tenant( + display_name='test', enable_email_link_sign_in=enable, app=tenant_mgt_app) + assert str(excinfo.value).startswith('Invalid type for enableEmailLinkSignin') + + def test_create_tenant(self, tenant_mgt_app): + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, GET_TENANT_RESPONSE) + tenant = tenant_mgt.create_tenant( + display_name='My-Tenant', allow_password_sign_up=True, enable_email_link_sign_in=True, + app=tenant_mgt_app) + + _assert_tenant(tenant) + self._assert_request(recorder, { + 'displayName': 'My-Tenant', + 'allowPasswordSignup': True, + 'enableEmailLinkSignin': True, + }) + + def test_create_tenant_false_values(self, tenant_mgt_app): + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, GET_TENANT_RESPONSE) + tenant = tenant_mgt.create_tenant( + display_name='test', allow_password_sign_up=False, enable_email_link_sign_in=False, + app=tenant_mgt_app) + + _assert_tenant(tenant) + self._assert_request(recorder, { + 'displayName': 'test', + 'allowPasswordSignup': False, + 'enableEmailLinkSignin': False, + }) + + def test_create_tenant_minimal(self, tenant_mgt_app): + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, GET_TENANT_RESPONSE) + tenant = tenant_mgt.create_tenant(display_name='test', app=tenant_mgt_app) + + _assert_tenant(tenant) + self._assert_request(recorder, {'displayName': 'test'}) + + def test_error(self, tenant_mgt_app): + _instrument_tenant_mgt(tenant_mgt_app, 500, '{}') + with pytest.raises(exceptions.InternalError) as excinfo: + tenant_mgt.create_tenant(display_name='test', app=tenant_mgt_app) + + error_msg = 'Unexpected error response: {}' + assert excinfo.value.code == exceptions.INTERNAL + assert str(excinfo.value) == error_msg + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None + + def _assert_request(self, recorder, body): + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'POST' + assert req.url == '{0}/tenants'.format(TENANT_MGT_URL_PREFIX) + got = json.loads(req.body.decode()) + assert got == body + + +class TestUpdateTenant: + + @pytest.mark.parametrize('tenant_id', INVALID_TENANT_IDS) + def test_invalid_tenant_id(self, tenant_id, tenant_mgt_app): + with pytest.raises(ValueError) as excinfo: + tenant_mgt.update_tenant(tenant_id, display_name='My Tenant', app=tenant_mgt_app) + assert str(excinfo.value).startswith('Tenant ID must be a non-empty string') + + @pytest.mark.parametrize('display_name', [True, False, 1, 0, list(), tuple(), dict()]) + def test_invalid_display_name_type(self, display_name, tenant_mgt_app): + with pytest.raises(ValueError) as excinfo: + tenant_mgt.update_tenant('tenant-id', display_name=display_name, app=tenant_mgt_app) + assert str(excinfo.value).startswith('Invalid type for displayName') + + @pytest.mark.parametrize('display_name', ['', 'foo', '1test', 'foo bar', 'a'*21]) + def test_invalid_display_name_value(self, display_name, tenant_mgt_app): + with pytest.raises(ValueError) as excinfo: + tenant_mgt.update_tenant('tenant-id', display_name=display_name, app=tenant_mgt_app) + assert str(excinfo.value).startswith('displayName must start') + + @pytest.mark.parametrize('allow', INVALID_BOOLEANS) + def test_invalid_allow_password_sign_up(self, allow, tenant_mgt_app): + with pytest.raises(ValueError) as excinfo: + tenant_mgt.update_tenant('tenant-id', allow_password_sign_up=allow, app=tenant_mgt_app) + assert str(excinfo.value).startswith('Invalid type for allowPasswordSignup') + + @pytest.mark.parametrize('enable', INVALID_BOOLEANS) + def test_invalid_enable_email_link_sign_in(self, enable, tenant_mgt_app): + with pytest.raises(ValueError) as excinfo: + tenant_mgt.update_tenant( + 'tenant-id', enable_email_link_sign_in=enable, app=tenant_mgt_app) + assert str(excinfo.value).startswith('Invalid type for enableEmailLinkSignin') + + def test_update_tenant_no_args(self, tenant_mgt_app): + with pytest.raises(ValueError) as excinfo: + tenant_mgt.update_tenant('tenant-id', app=tenant_mgt_app) + assert str(excinfo.value).startswith('At least one parameter must be specified for update') + + def test_update_tenant(self, tenant_mgt_app): + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, GET_TENANT_RESPONSE) + tenant = tenant_mgt.update_tenant( + 'tenant-id', display_name='My-Tenant', allow_password_sign_up=True, + enable_email_link_sign_in=True, app=tenant_mgt_app) + + _assert_tenant(tenant) + body = { + 'displayName': 'My-Tenant', + 'allowPasswordSignup': True, + 'enableEmailLinkSignin': True, + } + mask = ['allowPasswordSignup', 'displayName', 'enableEmailLinkSignin'] + self._assert_request(recorder, body, mask) + + def test_update_tenant_false_values(self, tenant_mgt_app): + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, GET_TENANT_RESPONSE) + tenant = tenant_mgt.update_tenant( + 'tenant-id', allow_password_sign_up=False, + enable_email_link_sign_in=False, app=tenant_mgt_app) + + _assert_tenant(tenant) + body = { + 'allowPasswordSignup': False, + 'enableEmailLinkSignin': False, + } + mask = ['allowPasswordSignup', 'enableEmailLinkSignin'] + self._assert_request(recorder, body, mask) + + def test_update_tenant_minimal(self, tenant_mgt_app): + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, GET_TENANT_RESPONSE) + tenant = tenant_mgt.update_tenant( + 'tenant-id', display_name='My-Tenant', app=tenant_mgt_app) + + _assert_tenant(tenant) + body = {'displayName': 'My-Tenant'} + mask = ['displayName'] + self._assert_request(recorder, body, mask) + + def test_tenant_not_found_error(self, tenant_mgt_app): + _instrument_tenant_mgt(tenant_mgt_app, 500, TENANT_NOT_FOUND_RESPONSE) + with pytest.raises(tenant_mgt.TenantNotFoundError) as excinfo: + tenant_mgt.update_tenant('tenant', display_name='My-Tenant', app=tenant_mgt_app) + + error_msg = 'No tenant found for the given identifier (TENANT_NOT_FOUND).' + assert excinfo.value.code == exceptions.NOT_FOUND + assert str(excinfo.value) == error_msg + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None + + def _assert_request(self, recorder, body, mask): + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'PATCH' + assert req.url == '{0}/tenants/tenant-id?updateMask={1}'.format( + TENANT_MGT_URL_PREFIX, ','.join(mask)) + got = json.loads(req.body.decode()) + assert got == body + + +class TestDeleteTenant: + + @pytest.mark.parametrize('tenant_id', INVALID_TENANT_IDS) + def test_invalid_tenant_id(self, tenant_id, tenant_mgt_app): + with pytest.raises(ValueError) as excinfo: + tenant_mgt.delete_tenant(tenant_id, app=tenant_mgt_app) + assert str(excinfo.value).startswith('Invalid tenant ID') + + def test_delete_tenant(self, tenant_mgt_app): + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, '{}') + tenant_mgt.delete_tenant('tenant-id', app=tenant_mgt_app) + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'DELETE' + assert req.url == '{0}/tenants/tenant-id'.format(TENANT_MGT_URL_PREFIX) + + def test_tenant_not_found(self, tenant_mgt_app): + _instrument_tenant_mgt(tenant_mgt_app, 500, TENANT_NOT_FOUND_RESPONSE) + with pytest.raises(tenant_mgt.TenantNotFoundError) as excinfo: + tenant_mgt.delete_tenant('tenant-id', app=tenant_mgt_app) + + error_msg = 'No tenant found for the given identifier (TENANT_NOT_FOUND).' + assert excinfo.value.code == exceptions.NOT_FOUND + assert str(excinfo.value) == error_msg + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None + + +class TestListTenants: + + @pytest.mark.parametrize('arg', [None, 'foo', list(), dict(), 0, -1, 101, False]) + def test_invalid_max_results(self, tenant_mgt_app, arg): + with pytest.raises(ValueError): + tenant_mgt.list_tenants(max_results=arg, app=tenant_mgt_app) + + @pytest.mark.parametrize('arg', ['', list(), dict(), 0, -1, True, False]) + def test_invalid_page_token(self, tenant_mgt_app, arg): + with pytest.raises(ValueError): + tenant_mgt.list_tenants(page_token=arg, app=tenant_mgt_app) + + def test_list_single_page(self, tenant_mgt_app): + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, LIST_TENANTS_RESPONSE) + page = tenant_mgt.list_tenants(app=tenant_mgt_app) + self._assert_tenants_page(page) + assert page.next_page_token == '' + assert page.has_next_page is False + assert page.get_next_page() is None + tenants = [tenant for tenant in page.iterate_all()] + assert len(tenants) == 2 + self._assert_request(recorder) + + def test_list_multiple_pages(self, tenant_mgt_app): + # Page 1 + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, LIST_TENANTS_RESPONSE_WITH_TOKEN) + page = tenant_mgt.list_tenants(app=tenant_mgt_app) + assert len(page.tenants) == 3 + assert page.next_page_token == 'token' + assert page.has_next_page is True + self._assert_request(recorder) + + # Page 2 (also the last page) + response = {'tenants': [{'name': 'projects/mock-project-id/tenants/tenant3'}]} + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, json.dumps(response)) + page = page.get_next_page() + assert len(page.tenants) == 1 + assert page.next_page_token == '' + assert page.has_next_page is False + assert page.get_next_page() is None + self._assert_request(recorder, {'pageSize': '100', 'pageToken': 'token'}) + + def test_list_tenants_paged_iteration(self, tenant_mgt_app): + # Page 1 + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, LIST_TENANTS_RESPONSE_WITH_TOKEN) + page = tenant_mgt.list_tenants(app=tenant_mgt_app) + iterator = page.iterate_all() + for index in range(3): + tenant = next(iterator) + assert tenant.tenant_id == 'tenant{0}'.format(index) + self._assert_request(recorder) + + # Page 2 (also the last page) + response = {'tenants': [{'name': 'projects/mock-project-id/tenants/tenant3'}]} + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, json.dumps(response)) + tenant = next(iterator) + assert tenant.tenant_id == 'tenant3' + + with pytest.raises(StopIteration): + next(iterator) + self._assert_request(recorder, {'pageSize': '100', 'pageToken': 'token'}) + + def test_list_tenants_iterator_state(self, tenant_mgt_app): + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, LIST_TENANTS_RESPONSE) + page = tenant_mgt.list_tenants(app=tenant_mgt_app) + + # Advance iterator. + iterator = page.iterate_all() + tenant = next(iterator) + assert tenant.tenant_id == 'tenant0' + + # Iterator should resume from where left off. + tenant = next(iterator) + assert tenant.tenant_id == 'tenant1' + + with pytest.raises(StopIteration): + next(iterator) + self._assert_request(recorder) + + def test_list_tenants_stop_iteration(self, tenant_mgt_app): + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, LIST_TENANTS_RESPONSE) + page = tenant_mgt.list_tenants(app=tenant_mgt_app) + iterator = page.iterate_all() + tenants = [tenant for tenant in iterator] + assert len(tenants) == 2 + + with pytest.raises(StopIteration): + next(iterator) + self._assert_request(recorder) + + def test_list_tenants_no_tenants_response(self, tenant_mgt_app): + response = {'tenants': []} + _instrument_tenant_mgt(tenant_mgt_app, 200, json.dumps(response)) + page = tenant_mgt.list_tenants(app=tenant_mgt_app) + assert len(page.tenants) == 0 + tenants = [tenant for tenant in page.iterate_all()] + assert len(tenants) == 0 + + def test_list_tenants_with_max_results(self, tenant_mgt_app): + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, LIST_TENANTS_RESPONSE) + page = tenant_mgt.list_tenants(max_results=50, app=tenant_mgt_app) + self._assert_tenants_page(page) + self._assert_request(recorder, {'pageSize' : '50'}) + + def test_list_tenants_with_all_args(self, tenant_mgt_app): + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, LIST_TENANTS_RESPONSE) + page = tenant_mgt.list_tenants(page_token='foo', max_results=50, app=tenant_mgt_app) + self._assert_tenants_page(page) + self._assert_request(recorder, {'pageToken' : 'foo', 'pageSize' : '50'}) + + def test_list_tenants_error(self, tenant_mgt_app): + _instrument_tenant_mgt(tenant_mgt_app, 500, '{"error":"test"}') + with pytest.raises(exceptions.InternalError) as excinfo: + tenant_mgt.list_tenants(app=tenant_mgt_app) + assert str(excinfo.value) == 'Unexpected error response: {"error":"test"}' + + def _assert_tenants_page(self, page): + assert isinstance(page, tenant_mgt.ListTenantsPage) + assert len(page.tenants) == 2 + for idx, tenant in enumerate(page.tenants): + _assert_tenant(tenant, 'tenant{0}'.format(idx)) + + def _assert_request(self, recorder, expected=None): + if expected is None: + expected = {'pageSize' : '100'} + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + request = dict(parse.parse_qsl(parse.urlsplit(req.url).query)) + assert request == expected + + +class TestAuthForTenant: + + @pytest.mark.parametrize('tenant_id', INVALID_TENANT_IDS) + def test_invalid_tenant_id(self, tenant_id, tenant_mgt_app): + with pytest.raises(ValueError): + tenant_mgt.auth_for_tenant(tenant_id, app=tenant_mgt_app) + + def test_client(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant1', app=tenant_mgt_app) + assert client.tenant_id == 'tenant1' + + def test_client_reuse(self, tenant_mgt_app): + client1 = tenant_mgt.auth_for_tenant('tenant1', app=tenant_mgt_app) + client2 = tenant_mgt.auth_for_tenant('tenant1', app=tenant_mgt_app) + client3 = tenant_mgt.auth_for_tenant('tenant2', app=tenant_mgt_app) + assert client1 is client2 + assert client1 is not client3 + + +class TestTenantAwareUserManagement: + + def test_get_user(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, MOCK_GET_USER_RESPONSE) + + user = client.get_user('testuser') + + assert isinstance(user, auth.UserRecord) + assert user.uid == 'testuser' + assert user.email == 'testuser@example.com' + self._assert_request(recorder, '/accounts:lookup', {'localId': ['testuser']}) + + def test_get_user_by_email(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, MOCK_GET_USER_RESPONSE) + + user = client.get_user_by_email('testuser@example.com') + + assert isinstance(user, auth.UserRecord) + assert user.uid == 'testuser' + assert user.email == 'testuser@example.com' + self._assert_request(recorder, '/accounts:lookup', {'email': ['testuser@example.com']}) + + def test_get_user_by_phone_number(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, MOCK_GET_USER_RESPONSE) + + user = client.get_user_by_phone_number('+1234567890') + + assert isinstance(user, auth.UserRecord) + assert user.uid == 'testuser' + assert user.email == 'testuser@example.com' + self._assert_request(recorder, '/accounts:lookup', {'phoneNumber': ['+1234567890']}) + + def test_create_user(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, '{"localId":"testuser"}') + + uid = client._user_manager.create_user() + + assert uid == 'testuser' + self._assert_request(recorder, '/accounts', {}) + + def test_update_user(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, '{"localId":"testuser"}') + + uid = client._user_manager.update_user('testuser', email='testuser@example.com') + + assert uid == 'testuser' + self._assert_request(recorder, '/accounts:update', { + 'localId': 'testuser', + 'email': 'testuser@example.com', + }) + + def test_delete_user(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, '{"kind":"deleteresponse"}') + + client.delete_user('testuser') + + self._assert_request(recorder, '/accounts:delete', {'localId': 'testuser'}) + + def test_set_custom_user_claims(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, '{"localId":"testuser"}') + claims = {'admin': True} + + client.set_custom_user_claims('testuser', claims) + + self._assert_request(recorder, '/accounts:update', { + 'localId': 'testuser', + 'customAttributes': json.dumps(claims), + }) + + def test_revoke_refresh_tokens(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, '{"localId":"testuser"}') + + client.revoke_refresh_tokens('testuser') + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'POST' + assert req.url == '{0}/tenants/tenant-id/accounts:update'.format( + USER_MGT_URL_PREFIX) + body = json.loads(req.body.decode()) + assert body['localId'] == 'testuser' + assert 'validSince' in body + + def test_list_users(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, MOCK_LIST_USERS_RESPONSE) + + page = client.list_users() + + assert isinstance(page, auth.ListUsersPage) + assert page.next_page_token == '' + assert page.has_next_page is False + assert page.get_next_page() is None + users = list(user for user in page.iterate_all()) + assert len(users) == 2 + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/tenants/tenant-id/accounts:batchGet?maxResults=1000'.format( + USER_MGT_URL_PREFIX) + + def test_import_users(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, '{}') + users = [ + auth.ImportUserRecord(uid='user1'), + auth.ImportUserRecord(uid='user2'), + ] + + result = client.import_users(users) + + assert isinstance(result, auth.UserImportResult) + assert result.success_count == 2 + assert result.failure_count == 0 + assert result.errors == [] + self._assert_request(recorder, '/accounts:batchCreate', { + 'users': [{'localId': 'user1'}, {'localId': 'user2'}], + }) + + def test_generate_password_reset_link(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, '{"oobLink":"https://testlink"}') + + link = client.generate_password_reset_link('test@test.com') + + assert link == 'https://testlink' + self._assert_request(recorder, '/accounts:sendOobCode', { + 'email': 'test@test.com', + 'requestType': 'PASSWORD_RESET', + 'returnOobLink': True, + }) + + def test_generate_email_verification_link(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, '{"oobLink":"https://testlink"}') + + link = client.generate_email_verification_link('test@test.com') + + assert link == 'https://testlink' + self._assert_request(recorder, '/accounts:sendOobCode', { + 'email': 'test@test.com', + 'requestType': 'VERIFY_EMAIL', + 'returnOobLink': True, + }) + + def test_generate_sign_in_with_email_link(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, '{"oobLink":"https://testlink"}') + settings = auth.ActionCodeSettings(url='http://localhost') + + link = client.generate_sign_in_with_email_link('test@test.com', settings) + + assert link == 'https://testlink' + self._assert_request(recorder, '/accounts:sendOobCode', { + 'email': 'test@test.com', + 'requestType': 'EMAIL_SIGNIN', + 'returnOobLink': True, + 'continueUrl': 'http://localhost', + }) + + def test_get_oidc_provider_config(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_provider_mgt(client, 200, OIDC_PROVIDER_CONFIG_RESPONSE) + + provider_config = client.get_oidc_provider_config('oidc.provider') + + self._assert_oidc_provider_config(provider_config) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/tenants/tenant-id/oauthIdpConfigs/oidc.provider'.format( + PROVIDER_MGT_URL_PREFIX) + + def test_create_oidc_provider_config(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_provider_mgt(client, 200, OIDC_PROVIDER_CONFIG_RESPONSE) + + provider_config = client.create_oidc_provider_config( + 'oidc.provider', client_id='CLIENT_ID', issuer='https://oidc.com/issuer', + display_name='oidcProviderName', enabled=True) + + self._assert_oidc_provider_config(provider_config) + self._assert_request( + recorder, '/oauthIdpConfigs?oauthIdpConfigId=oidc.provider', + OIDC_PROVIDER_CONFIG_REQUEST, prefix=PROVIDER_MGT_URL_PREFIX) + + def test_update_oidc_provider_config(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_provider_mgt(client, 200, OIDC_PROVIDER_CONFIG_RESPONSE) + + provider_config = client.update_oidc_provider_config( + 'oidc.provider', client_id='CLIENT_ID', issuer='https://oidc.com/issuer', + display_name='oidcProviderName', enabled=True) + + self._assert_oidc_provider_config(provider_config) + mask = ['clientId', 'displayName', 'enabled', 'issuer'] + url = '/oauthIdpConfigs/oidc.provider?updateMask={0}'.format(','.join(mask)) + self._assert_request( + recorder, url, OIDC_PROVIDER_CONFIG_REQUEST, method='PATCH', + prefix=PROVIDER_MGT_URL_PREFIX) + + def test_delete_oidc_provider_config(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_provider_mgt(client, 200, '{}') + + client.delete_oidc_provider_config('oidc.provider') + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'DELETE' + assert req.url == '{0}/tenants/tenant-id/oauthIdpConfigs/oidc.provider'.format( + PROVIDER_MGT_URL_PREFIX) + + def test_list_oidc_provider_configs(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_provider_mgt(client, 200, LIST_OIDC_PROVIDER_CONFIGS_RESPONSE) + + page = client.list_oidc_provider_configs() + + assert isinstance(page, auth.ListProviderConfigsPage) + index = 0 + assert len(page.provider_configs) == 2 + for provider_config in page.provider_configs: + self._assert_oidc_provider_config( + provider_config, want_id='oidc.provider{0}'.format(index)) + index += 1 + + assert page.next_page_token == '' + assert page.has_next_page is False + assert page.get_next_page() is None + provider_configs = list(config for config in page.iterate_all()) + assert len(provider_configs) == 2 + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}{1}'.format( + PROVIDER_MGT_URL_PREFIX, '/tenants/tenant-id/oauthIdpConfigs?pageSize=100') + + def test_get_saml_provider_config(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_provider_mgt(client, 200, SAML_PROVIDER_CONFIG_RESPONSE) + + provider_config = client.get_saml_provider_config('saml.provider') + + self._assert_saml_provider_config(provider_config) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/tenants/tenant-id/inboundSamlConfigs/saml.provider'.format( + PROVIDER_MGT_URL_PREFIX) + + def test_create_saml_provider_config(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_provider_mgt(client, 200, SAML_PROVIDER_CONFIG_RESPONSE) + + provider_config = client.create_saml_provider_config( + 'saml.provider', idp_entity_id='IDP_ENTITY_ID', sso_url='https://example.com/login', + x509_certificates=['CERT1', 'CERT2'], rp_entity_id='RP_ENTITY_ID', + callback_url='https://projectId.firebaseapp.com/__/auth/handler', + display_name='samlProviderName', enabled=True) + + self._assert_saml_provider_config(provider_config) + self._assert_request( + recorder, '/inboundSamlConfigs?inboundSamlConfigId=saml.provider', + SAML_PROVIDER_CONFIG_REQUEST, prefix=PROVIDER_MGT_URL_PREFIX) + + def test_update_saml_provider_config(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_provider_mgt(client, 200, SAML_PROVIDER_CONFIG_RESPONSE) + + provider_config = client.update_saml_provider_config( + 'saml.provider', idp_entity_id='IDP_ENTITY_ID', sso_url='https://example.com/login', + x509_certificates=['CERT1', 'CERT2'], rp_entity_id='RP_ENTITY_ID', + callback_url='https://projectId.firebaseapp.com/__/auth/handler', + display_name='samlProviderName', enabled=True) + + self._assert_saml_provider_config(provider_config) + mask = [ + 'displayName', 'enabled', 'idpConfig.idpCertificates', 'idpConfig.idpEntityId', + 'idpConfig.ssoUrl', 'spConfig.callbackUri', 'spConfig.spEntityId', + ] + url = '/inboundSamlConfigs/saml.provider?updateMask={0}'.format(','.join(mask)) + self._assert_request( + recorder, url, SAML_PROVIDER_CONFIG_REQUEST, method='PATCH', + prefix=PROVIDER_MGT_URL_PREFIX) + + def test_delete_saml_provider_config(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_provider_mgt(client, 200, '{}') + + client.delete_saml_provider_config('saml.provider') + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'DELETE' + assert req.url == '{0}/tenants/tenant-id/inboundSamlConfigs/saml.provider'.format( + PROVIDER_MGT_URL_PREFIX) + + def test_list_saml_provider_configs(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_provider_mgt(client, 200, LIST_SAML_PROVIDER_CONFIGS_RESPONSE) + + page = client.list_saml_provider_configs() + + assert isinstance(page, auth.ListProviderConfigsPage) + index = 0 + assert len(page.provider_configs) == 2 + for provider_config in page.provider_configs: + self._assert_saml_provider_config( + provider_config, want_id='saml.provider{0}'.format(index)) + index += 1 + + assert page.next_page_token == '' + assert page.has_next_page is False + assert page.get_next_page() is None + provider_configs = list(config for config in page.iterate_all()) + assert len(provider_configs) == 2 + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}{1}'.format( + PROVIDER_MGT_URL_PREFIX, '/tenants/tenant-id/inboundSamlConfigs?pageSize=100') + + def test_tenant_not_found(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + _instrument_user_mgt(client, 500, TENANT_NOT_FOUND_RESPONSE) + with pytest.raises(tenant_mgt.TenantNotFoundError) as excinfo: + client.get_user('testuser') + + error_msg = 'No tenant found for the given identifier (TENANT_NOT_FOUND).' + assert excinfo.value.code == exceptions.NOT_FOUND + assert str(excinfo.value) == error_msg + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None + + def _assert_request( + self, recorder, want_url, want_body, method='POST', prefix=USER_MGT_URL_PREFIX): + assert len(recorder) == 1 + req = recorder[0] + assert req.method == method + assert req.url == '{0}/tenants/tenant-id{1}'.format(prefix, want_url) + body = json.loads(req.body.decode()) + assert body == want_body + + def _assert_oidc_provider_config(self, provider_config, want_id='oidc.provider'): + assert isinstance(provider_config, auth.OIDCProviderConfig) + assert provider_config.provider_id == want_id + assert provider_config.display_name == 'oidcProviderName' + assert provider_config.enabled is True + assert provider_config.client_id == 'CLIENT_ID' + assert provider_config.issuer == 'https://oidc.com/issuer' + + def _assert_saml_provider_config(self, provider_config, want_id='saml.provider'): + assert isinstance(provider_config, auth.SAMLProviderConfig) + assert provider_config.provider_id == want_id + assert provider_config.display_name == 'samlProviderName' + assert provider_config.enabled is True + assert provider_config.idp_entity_id == 'IDP_ENTITY_ID' + assert provider_config.sso_url == 'https://example.com/login' + assert provider_config.x509_certificates == ['CERT1', 'CERT2'] + assert provider_config.rp_entity_id == 'RP_ENTITY_ID' + assert provider_config.callback_url == 'https://projectId.firebaseapp.com/__/auth/handler' + + +class TestVerifyIdToken: + + def test_valid_token(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('test-tenant', app=tenant_mgt_app) + client._token_verifier.request = test_token_gen.MOCK_REQUEST + + claims = client.verify_id_token(test_token_gen.TEST_ID_TOKEN_WITH_TENANT) + + assert claims['admin'] is True + assert claims['uid'] == claims['sub'] + assert claims['firebase']['tenant'] == 'test-tenant' + + def test_invalid_tenant_id(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('other-tenant', app=tenant_mgt_app) + client._token_verifier.request = test_token_gen.MOCK_REQUEST + + with pytest.raises(tenant_mgt.TenantIdMismatchError) as excinfo: + client.verify_id_token(test_token_gen.TEST_ID_TOKEN_WITH_TENANT) + + assert 'Invalid tenant ID: test-tenant' in str(excinfo.value) + assert isinstance(excinfo.value, exceptions.InvalidArgumentError) + assert excinfo.value.cause is None + assert excinfo.value.http_response is None + + +@pytest.fixture(scope='module') +def tenant_aware_custom_token_app(): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + app = firebase_admin.initialize_app(cred, name='tenantAwareCustomToken') + yield app + firebase_admin.delete_app(app) + + +class TestCreateCustomToken: + + def test_custom_token(self, tenant_aware_custom_token_app): + client = tenant_mgt.auth_for_tenant('test-tenant', app=tenant_aware_custom_token_app) + + custom_token = client.create_custom_token('user1') + + test_token_gen.verify_custom_token( + custom_token, expected_claims=None, tenant_id='test-tenant') + + def test_custom_token_with_claims(self, tenant_aware_custom_token_app): + client = tenant_mgt.auth_for_tenant('test-tenant', app=tenant_aware_custom_token_app) + claims = {'admin': True} + + custom_token = client.create_custom_token('user1', claims) + + test_token_gen.verify_custom_token( + custom_token, expected_claims=claims, tenant_id='test-tenant') + + +def _assert_tenant(tenant, tenant_id='tenant-id'): + assert isinstance(tenant, tenant_mgt.Tenant) + assert tenant.tenant_id == tenant_id + assert tenant.display_name == 'Test Tenant' + assert tenant.allow_password_sign_up is True + assert tenant.enable_email_link_sign_in is True diff --git a/tests/test_token_gen.py b/tests/test_token_gen.py index 439c1ba6e..f88c87ff4 100644 --- a/tests/test_token_gen.py +++ b/tests/test_token_gen.py @@ -66,7 +66,7 @@ def _merge_jwt_claims(defaults, overrides): del defaults[key] return defaults -def _verify_custom_token(custom_token, expected_claims): +def verify_custom_token(custom_token, expected_claims, tenant_id=None): assert isinstance(custom_token, bytes) token = google.oauth2.id_token.verify_token( custom_token, @@ -75,6 +75,11 @@ def _verify_custom_token(custom_token, expected_claims): assert token['uid'] == MOCK_UID assert token['iss'] == MOCK_SERVICE_ACCOUNT_EMAIL assert token['sub'] == MOCK_SERVICE_ACCOUNT_EMAIL + if tenant_id is None: + assert 'tenant_id' not in token + else: + assert token['tenant_id'] == tenant_id + header = jwt.decode_header(custom_token) assert header.get('typ') == 'JWT' assert header.get('alg') == 'RS256' @@ -94,6 +99,9 @@ def _get_id_token(payload_overrides=None, header_overrides=None): 'exp': int(time.time()) + 3600, 'sub': '1234567890', 'admin': True, + 'firebase': { + 'sign_in_provider': 'provider', + }, } if header_overrides: headers = _merge_jwt_claims(headers, header_overrides) @@ -109,21 +117,21 @@ def _get_session_cookie(payload_overrides=None, header_overrides=None): return _get_id_token(payload_overrides, header_overrides) def _instrument_user_manager(app, status, payload): - auth_service = auth._get_auth_service(app) - user_manager = auth_service.user_manager + client = auth._get_client(app) + user_manager = client._user_manager recorder = [] - user_manager._client.session.mount( - auth._AuthService.ID_TOOLKIT_URL, + user_manager.http_client.session.mount( + _token_gen.TokenGenerator.ID_TOOLKIT_URL, testutils.MockAdapter(payload, status, recorder)) return user_manager, recorder def _overwrite_cert_request(app, request): - auth_service = auth._get_auth_service(app) - auth_service.token_verifier.request = request + client = auth._get_client(app) + client._token_verifier.request = request def _overwrite_iam_request(app, request): - auth_service = auth._get_auth_service(app) - auth_service.token_generator.request = request + client = auth._get_client(app) + client._token_generator.request = request @pytest.fixture(scope='module') def auth_app(): @@ -195,7 +203,7 @@ class TestCreateCustomToken: def test_valid_params(self, auth_app, values): user, claims = values custom_token = auth.create_custom_token(user, claims, app=auth_app) - _verify_custom_token(custom_token, claims) + verify_custom_token(custom_token, claims) @pytest.mark.parametrize('values', invalid_args.values(), ids=list(invalid_args)) def test_invalid_params(self, auth_app, values): @@ -245,8 +253,9 @@ def test_sign_with_discovered_service_account(self): try: _overwrite_iam_request(app, request) # Force initialization of the signing provider. This will invoke the Metadata service. - auth_service = auth._get_auth_service(app) - assert auth_service.token_generator.signing_provider is not None + client = auth._get_client(app) + assert client._token_generator.signing_provider is not None + # Now invoke the IAM signer. signature = base64.b64encode(b'test').decode() request.response = testutils.MockResponse( @@ -346,6 +355,11 @@ def test_unexpected_response(self, user_mgt_app): MOCK_GET_USER_RESPONSE = testutils.resource('get_user.json') TEST_ID_TOKEN = _get_id_token() +TEST_ID_TOKEN_WITH_TENANT = _get_id_token({ + 'firebase': { + 'tenant': 'test-tenant', + } +}) TEST_SESSION_COOKIE = _get_session_cookie() @@ -380,6 +394,14 @@ def test_valid_token(self, user_mgt_app, id_token): claims = auth.verify_id_token(id_token, app=user_mgt_app) assert claims['admin'] is True assert claims['uid'] == claims['sub'] + assert claims['firebase']['sign_in_provider'] == 'provider' + + def test_valid_token_with_tenant(self, user_mgt_app): + _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) + claims = auth.verify_id_token(TEST_ID_TOKEN_WITH_TENANT, app=user_mgt_app) + assert claims['admin'] is True + assert claims['uid'] == claims['sub'] + assert claims['firebase']['tenant'] == 'test-tenant' @pytest.mark.parametrize('id_token', valid_tokens.values(), ids=list(valid_tokens)) def test_valid_token_check_revoked(self, user_mgt_app, id_token): diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index 958bbf9c4..c7b2de496 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -50,6 +50,9 @@ } MOCK_ACTION_CODE_SETTINGS = auth.ActionCodeSettings(**MOCK_ACTION_CODE_DATA) +USER_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v1/projects/mock-project-id' + + @pytest.fixture(scope='module') def user_mgt_app(): app = firebase_admin.initialize_app(testutils.MockCredential(), name='userMgt', @@ -58,11 +61,11 @@ def user_mgt_app(): firebase_admin.delete_app(app) def _instrument_user_manager(app, status, payload): - auth_service = auth._get_auth_service(app) - user_manager = auth_service.user_manager + client = auth._get_client(app) + user_manager = client._user_manager recorder = [] - user_manager._client.session.mount( - auth._AuthService.ID_TOOLKIT_URL, + user_manager.http_client.session.mount( + _user_mgt.UserManager.ID_TOOLKIT_URL, testutils.MockAdapter(payload, status, recorder)) return user_manager, recorder @@ -78,6 +81,7 @@ def _check_user_record(user, expected_uid='testuser'): assert user.user_metadata.creation_timestamp == 1234567890000 assert user.user_metadata.last_sign_in_timestamp is None assert user.provider_id == 'firebase' + assert user.tenant_id is None claims = user.custom_claims assert claims['admin'] is True @@ -101,17 +105,27 @@ def _check_user_record(user, expected_uid='testuser'): assert provider.provider_id == 'phone' +def _check_request(recorder, want_url, want_body=None): + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'POST' + assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, want_url) + if want_body: + body = json.loads(req.body.decode()) + assert body == want_body + + class TestAuthServiceInitialization: def test_default_timeout(self, user_mgt_app): - auth_service = auth._get_auth_service(user_mgt_app) - user_manager = auth_service.user_manager - assert user_manager._client.timeout == _http_client.DEFAULT_TIMEOUT_SECONDS + client = auth._get_client(user_mgt_app) + user_manager = client._user_manager + assert user_manager.http_client.timeout == _http_client.DEFAULT_TIMEOUT_SECONDS def test_fail_on_no_project_id(self): app = firebase_admin.initialize_app(testutils.MockCredential(), name='userMgt2') with pytest.raises(ValueError): - auth._get_auth_service(app) + auth._get_client(app) firebase_admin.delete_app(app) @@ -194,6 +208,10 @@ def test_no_tokens_valid_after_time(self): user = auth.UserRecord({'localId' : 'user'}) assert user.tokens_valid_after_timestamp == 0 + def test_tenant_id(self): + user = auth.UserRecord({'localId' : 'user', 'tenantId': 'test-tenant'}) + assert user.tenant_id == 'test-tenant' + class TestGetUser: @@ -203,8 +221,9 @@ def test_invalid_get_user(self, arg, user_mgt_app): auth.get_user(arg, app=user_mgt_app) def test_get_user(self, user_mgt_app): - _instrument_user_manager(user_mgt_app, 200, MOCK_GET_USER_RESPONSE) + _, recorder = _instrument_user_manager(user_mgt_app, 200, MOCK_GET_USER_RESPONSE) _check_user_record(auth.get_user('testuser', user_mgt_app)) + _check_request(recorder, '/accounts:lookup', {'localId': ['testuser']}) @pytest.mark.parametrize('arg', INVALID_STRINGS + ['not-an-email']) def test_invalid_get_user_by_email(self, arg, user_mgt_app): @@ -212,8 +231,9 @@ def test_invalid_get_user_by_email(self, arg, user_mgt_app): auth.get_user_by_email(arg, app=user_mgt_app) def test_get_user_by_email(self, user_mgt_app): - _instrument_user_manager(user_mgt_app, 200, MOCK_GET_USER_RESPONSE) + _, recorder = _instrument_user_manager(user_mgt_app, 200, MOCK_GET_USER_RESPONSE) _check_user_record(auth.get_user_by_email('testuser@example.com', user_mgt_app)) + _check_request(recorder, '/accounts:lookup', {'email': ['testuser@example.com']}) @pytest.mark.parametrize('arg', INVALID_STRINGS + ['not-a-phone']) def test_invalid_get_user_by_phone(self, arg, user_mgt_app): @@ -221,8 +241,9 @@ def test_invalid_get_user_by_phone(self, arg, user_mgt_app): auth.get_user_by_phone_number(arg, app=user_mgt_app) def test_get_user_by_phone(self, user_mgt_app): - _instrument_user_manager(user_mgt_app, 200, MOCK_GET_USER_RESPONSE) + _, recorder = _instrument_user_manager(user_mgt_app, 200, MOCK_GET_USER_RESPONSE) _check_user_record(auth.get_user_by_phone_number('+1234567890', user_mgt_app)) + _check_request(recorder, '/accounts:lookup', {'phoneNumber': ['+1234567890']}) def test_get_user_non_existing(self, user_mgt_app): _instrument_user_manager(user_mgt_app, 200, '{"users":[]}') @@ -1050,7 +1071,7 @@ def test_import_users(self, user_mgt_app): assert result.failure_count == 0 assert result.errors == [] expected = {'users': [{'localId': 'user1'}, {'localId': 'user2'}]} - self._check_rpc_calls(recorder, expected) + _check_request(recorder, '/accounts:batchCreate', expected) def test_import_users_error(self, user_mgt_app): _, recorder = _instrument_user_manager(user_mgt_app, 200, """{"error": [ @@ -1073,7 +1094,7 @@ def test_import_users_error(self, user_mgt_app): assert err.index == 2 assert err.reason == 'Another error occured in user3' expected = {'users': [{'localId': 'user1'}, {'localId': 'user2'}, {'localId': 'user3'}]} - self._check_rpc_calls(recorder, expected) + _check_request(recorder, '/accounts:batchCreate', expected) def test_import_users_missing_required_hash(self, user_mgt_app): users = [ @@ -1106,7 +1127,7 @@ def test_import_users_with_hash(self, user_mgt_app): 'memoryCost': 14, 'saltSeparator': _user_import.b64_encode(b'sep'), } - self._check_rpc_calls(recorder, expected) + _check_request(recorder, '/accounts:batchCreate', expected) def test_import_users_http_error(self, user_mgt_app): _instrument_user_manager(user_mgt_app, 401, '{"error": {"message": "ERROR_CODE"}}') @@ -1127,11 +1148,6 @@ def test_import_users_unexpected_response(self, user_mgt_app): with pytest.raises(auth.UnexpectedResponseError): auth.import_users(users, app=user_mgt_app) - def _check_rpc_calls(self, recorder, expected): - assert len(recorder) == 1 - request = json.loads(recorder[0].body.decode()) - assert request == expected - class TestRevokeRefreshTokkens: @@ -1301,8 +1317,8 @@ def test_bad_settings_data(self, user_mgt_app, func): def test_bad_action_type(self, user_mgt_app): with pytest.raises(ValueError): - auth._get_auth_service(user_mgt_app) \ - .user_manager \ + auth._get_client(user_mgt_app) \ + ._user_manager \ .generate_email_action_link('BAD_TYPE', 'test@test.com', action_code_settings=MOCK_ACTION_CODE_SETTINGS) From 088c33ef7a4afe7a329f7d564b1e5fdf31063d41 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Tue, 28 Apr 2020 11:23:56 -0700 Subject: [PATCH 069/226] fix(auth): Fixed some API reference mistakes (#456) * fix(auth): Fixed some API reference mistakes * Fixed astroid version to avoid lint error --- firebase_admin/auth.py | 8 +++++++- firebase_admin/ml.py | 2 +- requirements.txt | 1 + 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index cb8782ea7..1cce9ea00 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -35,6 +35,7 @@ 'ActionCodeSettings', 'CertificateFetchError', 'Client', + 'ConfigurationNotFoundError', 'DELETE_ATTRIBUTE', 'EmailAlreadyExistsError', 'ErrorInfo', @@ -48,6 +49,7 @@ 'InvalidSessionCookieError', 'ListProviderConfigsPage', 'ListUsersPage', + 'OIDCProviderConfig', 'PhoneNumberAlreadyExistsError', 'ProviderConfig', 'RevokedIdTokenError', @@ -65,14 +67,17 @@ 'UserRecord', 'create_custom_token', + 'create_oidc_provider_config', 'create_saml_provider_config', 'create_session_cookie', 'create_user', + 'delete_oidc_provider_config', 'delete_saml_provider_config', 'delete_user', 'generate_email_verification_link', 'generate_password_reset_link', 'generate_sign_in_with_email_link', + 'get_oidc_provider_config', 'get_saml_provider_config', 'get_user', 'get_user_by_email', @@ -82,6 +87,7 @@ 'list_users', 'revoke_refresh_tokens', 'set_custom_user_claims', + 'update_oidc_provider_config', 'update_saml_provider_config', 'update_user', 'verify_id_token', @@ -107,7 +113,7 @@ ListUsersPage = _user_mgt.ListUsersPage OIDCProviderConfig = _auth_providers.OIDCProviderConfig PhoneNumberAlreadyExistsError = _auth_utils.PhoneNumberAlreadyExistsError -ProviderConfig = _auth_providers.ProviderConfigClient +ProviderConfig = _auth_providers.ProviderConfig RevokedIdTokenError = _token_gen.RevokedIdTokenError RevokedSessionCookieError = _token_gen.RevokedSessionCookieError SAMLProviderConfig = _auth_providers.SAMLProviderConfig diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index db1657839..900fdc24b 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -148,7 +148,7 @@ def list_models(list_filter=None, page_size=None, page_token=None, app=None): """Lists models from Firebase ML. Args: - list_filter: a list filter string such as "tags:'tag_1'". None will return all models. + list_filter: a list filter string such as ``tags:'tag_1'``. None will return all models. page_size: A number between 1 and 100 inclusive that specifies the maximum number of models to return per page. None for default. page_token: A next page token returned from a previous page of results. None diff --git a/requirements.txt b/requirements.txt index d7fb6d736..dbeaee3b6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +astroid == 2.3.3 pylint == 2.3.1 pytest >= 3.6.0 pytest-cov >= 2.4.0 From 0afcad871de1f3e32d89530941c00fb6149dfa3f Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Thu, 30 Apr 2020 11:42:22 -0700 Subject: [PATCH 070/226] [chore] Release 4.2.0 (#457) --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index bd19af68f..ff12296f2 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '4.1.0' +__version__ = '4.2.0' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 05d4cc213ebcc2f5f85c35bdb1d51b64143f07aa Mon Sep 17 00:00:00 2001 From: rsgowman Date: Thu, 7 May 2020 14:57:16 -0400 Subject: [PATCH 071/226] Move shebang to top of file. (#458) It otherwise has no effect, causing the default shell to be tried (which is /bin/sh on my home system). --- lint.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lint.sh b/lint.sh index 0fd5058a3..5e65862f3 100755 --- a/lint.sh +++ b/lint.sh @@ -1,3 +1,5 @@ +#!/bin/bash + # Copyright 2017 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,8 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -#!/bin/bash - function lintAllFiles () { echo "Running linter on module $1" pylint --disable=$2 $1 From 96b82c0870e55c515791ae0415c209175f2a9d74 Mon Sep 17 00:00:00 2001 From: ifielker Date: Fri, 8 May 2020 14:25:55 -0400 Subject: [PATCH 072/226] Firebase ML: fixed displayName and tags regexes to match changed backend requirements. (#459) --- firebase_admin/ml.py | 4 ++-- integration/test_ml.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index 900fdc24b..2613a3de3 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -49,8 +49,8 @@ _ML_ATTRIBUTE = '_ml' _MAX_PAGE_SIZE = 100 _MODEL_ID_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') -_DISPLAY_NAME_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') -_TAG_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') +_DISPLAY_NAME_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,32}$') +_TAG_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,32}$') _GCS_TFLITE_URI_PATTERN = re.compile( r'^gs://(?P[a-z0-9_.-]{3,63})/(?P.+)$') _RESOURCE_NAME_PATTERN = re.compile( diff --git a/integration/test_ml.py b/integration/test_ml.py index be791d8fa..1d32ebed1 100644 --- a/integration/test_ml.py +++ b/integration/test_ml.py @@ -42,22 +42,22 @@ def _random_identifier(prefix): NAME_ONLY_ARGS = { - 'display_name': _random_identifier('TestModel123_') + 'display_name': _random_identifier('TestModel_') } NAME_ONLY_ARGS_UPDATED = { - 'display_name': _random_identifier('TestModel123_updated_') + 'display_name': _random_identifier('TestModel_updated_') } NAME_AND_TAGS_ARGS = { - 'display_name': _random_identifier('TestModel123_tags_'), + 'display_name': _random_identifier('TestModel_tags_'), 'tags': ['test_tag123'] } FULL_MODEL_ARGS = { - 'display_name': _random_identifier('TestModel123_full_'), + 'display_name': _random_identifier('TestModel_full_'), 'tags': ['test_tag567'], 'file_name': 'model1.tflite' } INVALID_FULL_MODEL_ARGS = { - 'display_name': _random_identifier('TestModel123_invalid_full_'), + 'display_name': _random_identifier('TestModel_invalid_full_'), 'tags': ['test_tag890'], 'file_name': 'invalid_model.tflite' } From d1aae525fea24a92203673bec901ffeb89c21679 Mon Sep 17 00:00:00 2001 From: rsgowman Date: Tue, 12 May 2020 16:57:30 -0400 Subject: [PATCH 073/226] feat(auth): Add bulk get/delete methods (#400) This PR allows callers to retrieve a list of users by unique identifier (uid, email, phone, federated provider uid) as well as to delete a list of users. RELEASE NOTE: Added get_users() and delete_users() APIs for retrieving and deleting user accounts in bulk. --- firebase_admin/_auth_client.py | 78 ++++++++++++ firebase_admin/_auth_utils.py | 9 ++ firebase_admin/_rfc3339.py | 87 +++++++++++++ firebase_admin/_user_identifier.py | 103 +++++++++++++++ firebase_admin/_user_import.py | 7 +- firebase_admin/_user_mgt.py | 195 ++++++++++++++++++++++++++++- firebase_admin/auth.py | 73 +++++++++++ integration/test_auth.py | 159 ++++++++++++++++++++++- tests/test_rfc3339.py | 67 ++++++++++ tests/test_user_mgt.py | 134 ++++++++++++++++++++ 10 files changed, 907 insertions(+), 5 deletions(-) create mode 100644 firebase_admin/_rfc3339.py create mode 100644 firebase_admin/_user_identifier.py create mode 100644 tests/test_rfc3339.py diff --git a/firebase_admin/_auth_client.py b/firebase_admin/_auth_client.py index b7af6ddb6..12d60592e 100644 --- a/firebase_admin/_auth_client.py +++ b/firebase_admin/_auth_client.py @@ -21,6 +21,7 @@ from firebase_admin import _auth_utils from firebase_admin import _http_client from firebase_admin import _token_gen +from firebase_admin import _user_identifier from firebase_admin import _user_import from firebase_admin import _user_mgt @@ -182,6 +183,56 @@ def get_user_by_phone_number(self, phone_number): response = self._user_manager.get_user(phone_number=phone_number) return _user_mgt.UserRecord(response) + def get_users(self, identifiers): + """Gets the user data corresponding to the specified identifiers. + + There are no ordering guarantees; in particular, the nth entry in the + result list is not guaranteed to correspond to the nth entry in the input + parameters list. + + A maximum of 100 identifiers may be supplied. If more than 100 + identifiers are supplied, this method raises a `ValueError`. + + Args: + identifiers (list[Identifier]): A list of ``Identifier`` instances used + to indicate which user records should be returned. Must have <= 100 + entries. + + Returns: + GetUsersResult: A ``GetUsersResult`` instance corresponding to the + specified identifiers. + + Raises: + ValueError: If any of the identifiers are invalid or if more than 100 + identifiers are specified. + """ + response = self._user_manager.get_users(identifiers=identifiers) + + def _matches(identifier, user_record): + if isinstance(identifier, _user_identifier.UidIdentifier): + return identifier.uid == user_record.uid + if isinstance(identifier, _user_identifier.EmailIdentifier): + return identifier.email == user_record.email + if isinstance(identifier, _user_identifier.PhoneIdentifier): + return identifier.phone_number == user_record.phone_number + if isinstance(identifier, _user_identifier.ProviderIdentifier): + return next(( + True + for user_info in user_record.provider_data + if identifier.provider_id == user_info.provider_id + and identifier.provider_uid == user_info.uid + ), False) + raise TypeError("Unexpected type: {}".format(type(identifier))) + + def _is_user_found(identifier, user_records): + return any(_matches(identifier, user_record) for user_record in user_records) + + users = [_user_mgt.UserRecord(user) for user in response] + not_found = [ + identifier for identifier in identifiers if not _is_user_found(identifier, users)] + + return _user_mgt.GetUsersResult(users=users, not_found=not_found) + def list_users(self, page_token=None, max_results=_user_mgt.MAX_LIST_USERS_RESULTS): """Retrieves a page of user accounts from a Firebase project. @@ -306,6 +357,33 @@ def delete_user(self, uid): """ self._user_manager.delete_user(uid) + def delete_users(self, uids): + """Deletes the users specified by the given identifiers. + + Deleting a non-existing user does not generate an error (the method is + idempotent.) Non-existing users are considered to be successfully + deleted and are therefore included in the + `DeleteUserResult.success_count` value. + + A maximum of 1000 identifiers may be supplied. If more than 1000 + identifiers are supplied, this method raises a `ValueError`. + + Args: + uids: A list of strings indicating the uids of the users to be deleted. + Must have <= 1000 entries. + + Returns: + DeleteUsersResult: The total number of successful/failed deletions, as + well as the array of errors that correspond to the failed + deletions. + + Raises: + ValueError: If any of the identifiers are invalid or if more than 1000 + identifiers are specified. + """ + result = self._user_manager.delete_users(uids, force_delete=True) + return _user_mgt.DeleteUsersResult(result, len(uids)) + def import_users(self, users, hash_alg=None): """Imports the specified list of users into Firebase Auth. diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index f1ce97dee..2226675f9 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -136,6 +136,15 @@ def validate_provider_id(provider_id, required=True): 'string.'.format(provider_id)) return provider_id +def validate_provider_uid(provider_uid, required=True): + if provider_uid is None and not required: + return None + if not isinstance(provider_uid, str) or not provider_uid: + raise ValueError( + 'Invalid provider UID: "{0}". Provider UID must be a non-empty ' + 'string.'.format(provider_uid)) + return provider_uid + def validate_photo_url(photo_url, required=False): """Parses and validates the given URL string.""" if photo_url is None and not required: diff --git a/firebase_admin/_rfc3339.py b/firebase_admin/_rfc3339.py new file mode 100644 index 000000000..2c720bdd1 --- /dev/null +++ b/firebase_admin/_rfc3339.py @@ -0,0 +1,87 @@ +# Copyright 2020 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Parse RFC3339 date strings""" + +from datetime import datetime, timezone +import re + +def parse_to_epoch(datestr): + """Parse an RFC3339 date string and return the number of seconds since the + epoch (as a float). + + In particular, this method is meant to parse the strings returned by the + JSON mapping of protobuf google.protobuf.timestamp.Timestamp instances: + https://github.com/protocolbuffers/protobuf/blob/4cf5bfee9546101d98754d23ff378ff718ba8438/src/google/protobuf/timestamp.proto#L99 + + This method has microsecond precision; nanoseconds will be truncated. + + Args: + datestr: A string in RFC3339 format. + Returns: + Float: The number of seconds since the Unix epoch. + Raises: + ValueError: Raised if the `datestr` is not a valid RFC3339 date string. + """ + return _parse_to_datetime(datestr).timestamp() + + +def _parse_to_datetime(datestr): + """Parse an RFC3339 date string and return a python datetime instance. + + Args: + datestr: A string in RFC3339 format. + Returns: + datetime: The corresponding `datetime` (with timezone information). + Raises: + ValueError: Raised if the `datestr` is not a valid RFC3339 date string. + """ + # If more than 6 digits appear in the fractional seconds position, truncate + # to just the most significant 6. (i.e. we only have microsecond precision; + # nanos are truncated.) + datestr_modified = re.sub(r'(\.\d{6})\d*', r'\1', datestr) + + # This format is the one we actually expect to occur from our backend. The + # others are only present because the spec says we *should* accept them. + try: + return datetime.strptime( + datestr_modified, '%Y-%m-%dT%H:%M:%S.%fZ' + ).replace(tzinfo=timezone.utc) + except ValueError: + pass + + try: + return datetime.strptime( + datestr_modified, '%Y-%m-%dT%H:%M:%SZ' + ).replace(tzinfo=timezone.utc) + except ValueError: + pass + + # Note: %z parses timezone offsets, but requires the timezone offset *not* + # include a separating ':'. As of python 3.7, this was relaxed. + # TODO(rsgowman): Once python3.7 becomes our floor, we can drop the regex + # replacement. + datestr_modified = re.sub(r'(\d\d):(\d\d)$', r'\1\2', datestr_modified) + + try: + return datetime.strptime(datestr_modified, '%Y-%m-%dT%H:%M:%S.%f%z') + except ValueError: + pass + + try: + return datetime.strptime(datestr_modified, '%Y-%m-%dT%H:%M:%S%z') + except ValueError: + pass + + raise ValueError('time data {0} does not match RFC3339 format'.format(datestr)) diff --git a/firebase_admin/_user_identifier.py b/firebase_admin/_user_identifier.py new file mode 100644 index 000000000..85a224e0b --- /dev/null +++ b/firebase_admin/_user_identifier.py @@ -0,0 +1,103 @@ +# Copyright 2020 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Classes to uniquely identify a user.""" + +from firebase_admin import _auth_utils + +class UserIdentifier: + """Identifies a user to be looked up.""" + + +class UidIdentifier(UserIdentifier): + """Used for looking up an account by uid. + + See ``auth.get_user()``. + """ + + def __init__(self, uid): + """Constructs a new `UidIdentifier` object. + + Args: + uid: A user ID string. + """ + self._uid = _auth_utils.validate_uid(uid, required=True) + + @property + def uid(self): + return self._uid + + +class EmailIdentifier(UserIdentifier): + """Used for looking up an account by email. + + See ``auth.get_user()``. + """ + + def __init__(self, email): + """Constructs a new `EmailIdentifier` object. + + Args: + email: A user email address string. + """ + self._email = _auth_utils.validate_email(email, required=True) + + @property + def email(self): + return self._email + + +class PhoneIdentifier(UserIdentifier): + """Used for looking up an account by phone number. + + See ``auth.get_user()``. + """ + + def __init__(self, phone_number): + """Constructs a new `PhoneIdentifier` object. + + Args: + phone_number: A phone number string. + """ + self._phone_number = _auth_utils.validate_phone(phone_number, required=True) + + @property + def phone_number(self): + return self._phone_number + + +class ProviderIdentifier(UserIdentifier): + """Used for looking up an account by provider. + + See ``auth.get_user()``. + """ + + def __init__(self, provider_id, provider_uid): + """Constructs a new `ProviderIdentifier` object. + +   Args: +     provider_id: A provider ID string. +     provider_uid: A provider UID string. + """ + self._provider_id = _auth_utils.validate_provider_id(provider_id, required=True) + self._provider_uid = _auth_utils.validate_provider_uid( + provider_uid, required=True) + + @property + def provider_id(self): + return self._provider_id + + @property + def provider_uid(self): + return self._provider_uid diff --git a/firebase_admin/_user_import.py b/firebase_admin/_user_import.py index 21cc8082d..7834b232a 100644 --- a/firebase_admin/_user_import.py +++ b/firebase_admin/_user_import.py @@ -472,7 +472,12 @@ def standard_scrypt(cls, memory_cost, parallelization, block_size, derived_key_l class ErrorInfo: - """Represents an error encountered while importing an ``ImportUserRecord``.""" + """Represents an error encountered while performing a batch operation such + as importing users or deleting multiple user accounts. + """ + # TODO(rsgowman): This class used to be specific to importing users (hence + # it's home in _user_import.py). It's now also used by bulk deletion of + # users. Move this to a more common location. def __init__(self, error): self._index = error['index'] diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index 0b0c5ddb6..0307959f3 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -15,13 +15,17 @@ """Firebase user management sub module.""" import base64 +from collections import defaultdict import json from urllib import parse import requests from firebase_admin import _auth_utils +from firebase_admin import _rfc3339 +from firebase_admin import _user_identifier from firebase_admin import _user_import +from firebase_admin._user_import import ErrorInfo MAX_LIST_USERS_RESULTS = 1000 @@ -41,11 +45,14 @@ def __init__(self, description): class UserMetadata: """Contains additional metadata associated with a user account.""" - def __init__(self, creation_timestamp=None, last_sign_in_timestamp=None): + def __init__(self, creation_timestamp=None, last_sign_in_timestamp=None, + last_refresh_timestamp=None): self._creation_timestamp = _auth_utils.validate_timestamp( creation_timestamp, 'creation_timestamp') self._last_sign_in_timestamp = _auth_utils.validate_timestamp( last_sign_in_timestamp, 'last_sign_in_timestamp') + self._last_refresh_timestamp = _auth_utils.validate_timestamp( + last_refresh_timestamp, 'last_refresh_timestamp') @property def creation_timestamp(self): @@ -65,6 +72,16 @@ def last_sign_in_timestamp(self): """ return self._last_sign_in_timestamp + @property + def last_refresh_timestamp(self): + """The time at which the user was last active (ID token refreshed). + + Returns: + integer: Milliseconds since epoch timestamp, or `None` if the user was + never active. + """ + return self._last_refresh_timestamp + class UserInfo: """A collection of standard profile information for a user. @@ -216,7 +233,12 @@ def _int_or_none(key): if key in self._data: return int(self._data[key]) return None - return UserMetadata(_int_or_none('createdAt'), _int_or_none('lastLoginAt')) + last_refresh_at_millis = None + last_refresh_at_rfc3339 = self._data.get('lastRefreshAt', None) + if last_refresh_at_rfc3339: + last_refresh_at_millis = int(_rfc3339.parse_to_epoch(last_refresh_at_rfc3339) * 1000) + return UserMetadata( + _int_or_none('createdAt'), _int_or_none('lastLoginAt'), last_refresh_at_millis) @property def provider_data(self): @@ -289,6 +311,35 @@ def password_salt(self): return self._data.get('salt') +class GetUsersResult: + """Represents the result of the ``auth.get_users()`` API.""" + + def __init__(self, users, not_found): + """Constructs a `GetUsersResult` object. + + Args: + users: List of `UserRecord` instances. + not_found: List of `UserIdentifier` instances. + """ + self._users = users + self._not_found = not_found + + @property + def users(self): + """Set of `UserRecord` instances, corresponding to the set of users + that were requested. Only users that were found are listed here. The + result set is unordered. + """ + return self._users + + @property + def not_found(self): + """Set of `UserIdentifier` instances that were requested, but not + found. + """ + return self._not_found + + class ListUsersPage: """Represents a page of user records exported from a Firebase project. @@ -340,6 +391,63 @@ def iterate_all(self): return _UserIterator(self) +class DeleteUsersResult: + """Represents the result of the ``auth.delete_users()`` API.""" + + def __init__(self, result, total): + """Constructs a `DeleteUsersResult` object. + + Args: + result: The proto response, wrapped in a + `BatchDeleteAccountsResponse` instance. + total: Total integer number of deletion attempts. + """ + errors = result.errors + self._success_count = total - len(errors) + self._failure_count = len(errors) + self._errors = errors + + @property + def success_count(self): + """Returns the number of users that were deleted successfully (possibly + zero). + + Users that did not exist prior to calling `delete_users()` are + considered to be successfully deleted. + """ + return self._success_count + + @property + def failure_count(self): + """Returns the number of users that failed to be deleted (possibly + zero). + """ + return self._failure_count + + @property + def errors(self): + """A list of `auth.ErrorInfo` instances describing the errors that + were encountered during the deletion. Length of this list is equal to + `failure_count`. + """ + return self._errors + + +class BatchDeleteAccountsResponse: + """Represents the results of a `delete_users()` call.""" + + def __init__(self, errors=None): + """Constructs a `BatchDeleteAccountsResponse` instance, corresponding to + the JSON representing the `BatchDeleteAccountsResponse` proto. + + Args: + errors: List of dictionaries, with each dictionary representing an + `ErrorInfo` instance as returned by the server. `None` implies + an empty list. + """ + self.errors = [ErrorInfo(err) for err in errors] if errors else [] + + class ProviderUserInfo(UserInfo): """Contains metadata regarding how a user is known by a particular identity provider.""" @@ -492,6 +600,53 @@ def get_user(self, **kwargs): http_response=http_resp) return body['users'][0] + def get_users(self, identifiers): + """Looks up multiple users by their identifiers (uid, email, etc.) + + Args: + identifiers: UserIdentifier[]: The identifiers indicating the user + to be looked up. Must have <= 100 entries. + + Returns: + list[dict[string, string]]: List of dicts representing the JSON + `UserInfo` responses from the server. + + Raises: + ValueError: If any of the identifiers are invalid or if more than + 100 identifiers are specified. + UnexpectedResponseError: If the backend server responds with an + unexpected message. + """ + if not identifiers: + return [] + if len(identifiers) > 100: + raise ValueError('`identifiers` parameter must have <= 100 entries.') + + payload = defaultdict(list) + for identifier in identifiers: + if isinstance(identifier, _user_identifier.UidIdentifier): + payload['localId'].append(identifier.uid) + elif isinstance(identifier, _user_identifier.EmailIdentifier): + payload['email'].append(identifier.email) + elif isinstance(identifier, _user_identifier.PhoneIdentifier): + payload['phoneNumber'].append(identifier.phone_number) + elif isinstance(identifier, _user_identifier.ProviderIdentifier): + payload['federatedUserId'].append({ + 'providerId': identifier.provider_id, + 'rawId': identifier.provider_uid + }) + else: + raise ValueError( + 'Invalid entry in "identifiers" list. Unsupported type: {}' + .format(type(identifier))) + + body, http_resp = self._make_request( + 'post', '/accounts:lookup', json=payload) + if not http_resp.ok: + raise _auth_utils.UnexpectedResponseError( + 'Failed to get users.', http_response=http_resp) + return body.get('users', []) + def list_users(self, page_token=None, max_results=MAX_LIST_USERS_RESULTS): """Retrieves a batch of users.""" if page_token is not None: @@ -585,6 +740,42 @@ def delete_user(self, uid): raise _auth_utils.UnexpectedResponseError( 'Failed to delete user: {0}.'.format(uid), http_response=http_resp) + def delete_users(self, uids, force_delete=False): + """Deletes the users identified by the specified user ids. + + Args: + uids: A list of strings indicating the uids of the users to be deleted. + Must have <= 1000 entries. + force_delete: Optional parameter that indicates if users should be + deleted, even if they're not disabled. Defaults to False. + + + Returns: + BatchDeleteAccountsResponse: Server's proto response, wrapped in a + python object. + + Raises: + ValueError: If any of the identifiers are invalid or if more than 1000 + identifiers are specified. + UnexpectedResponseError: If the backend server responds with an + unexpected message. + """ + if not uids: + return BatchDeleteAccountsResponse() + + if len(uids) > 1000: + raise ValueError("`uids` paramter must have <= 1000 entries.") + for uid in uids: + _auth_utils.validate_uid(uid, required=True) + + body, http_resp = self._make_request('post', '/accounts:batchDelete', + json={'localIds': uids, 'force': force_delete}) + if not isinstance(body, dict): + raise _auth_utils.UnexpectedResponseError( + 'Unexpected response from server while attempting to delete users.', + http_response=http_resp) + return BatchDeleteAccountsResponse(body.get('errors', [])) + def import_users(self, users, hash_alg=None): """Imports the given list of users to Firebase Auth.""" try: diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index 1cce9ea00..5d2fe0f68 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -22,6 +22,7 @@ from firebase_admin import _auth_client from firebase_admin import _auth_providers from firebase_admin import _auth_utils +from firebase_admin import _user_identifier from firebase_admin import _token_gen from firebase_admin import _user_import from firebase_admin import _user_mgt @@ -66,6 +67,12 @@ 'UserProvider', 'UserRecord', + 'UserIdentifier', + 'UidIdentifier', + 'EmailIdentifier', + 'PhoneIdentifier', + 'ProviderIdentifier', + 'create_custom_token', 'create_oidc_provider_config', 'create_saml_provider_config', @@ -74,6 +81,7 @@ 'delete_oidc_provider_config', 'delete_saml_provider_config', 'delete_user', + 'delete_users', 'generate_email_verification_link', 'generate_password_reset_link', 'generate_sign_in_with_email_link', @@ -82,6 +90,7 @@ 'get_user', 'get_user_by_email', 'get_user_by_phone_number', + 'get_users', 'import_users', 'list_saml_provider_configs', 'list_users', @@ -99,11 +108,13 @@ Client = _auth_client.Client ConfigurationNotFoundError = _auth_utils.ConfigurationNotFoundError DELETE_ATTRIBUTE = _user_mgt.DELETE_ATTRIBUTE +DeleteUsersResult = _user_mgt.DeleteUsersResult EmailAlreadyExistsError = _auth_utils.EmailAlreadyExistsError ErrorInfo = _user_import.ErrorInfo ExpiredIdTokenError = _token_gen.ExpiredIdTokenError ExpiredSessionCookieError = _token_gen.ExpiredSessionCookieError ExportedUserRecord = _user_mgt.ExportedUserRecord +GetUsersResult = _user_mgt.GetUsersResult ImportUserRecord = _user_import.ImportUserRecord InsufficientPermissionError = _auth_utils.InsufficientPermissionError InvalidDynamicLinkDomainError = _auth_utils.InvalidDynamicLinkDomainError @@ -128,6 +139,12 @@ UserProvider = _user_import.UserProvider UserRecord = _user_mgt.UserRecord +UserIdentifier = _user_identifier.UserIdentifier +UidIdentifier = _user_identifier.UidIdentifier +EmailIdentifier = _user_identifier.EmailIdentifier +PhoneIdentifier = _user_identifier.PhoneIdentifier +ProviderIdentifier = _user_identifier.ProviderIdentifier + def _get_client(app): """Returns a client instance for an App. @@ -328,6 +345,34 @@ def get_user_by_phone_number(phone_number, app=None): return client.get_user_by_phone_number(phone_number=phone_number) +def get_users(identifiers, app=None): + """Gets the user data corresponding to the specified identifiers. + + There are no ordering guarantees; in particular, the nth entry in the + result list is not guaranteed to correspond to the nth entry in the input + parameters list. + + A maximum of 100 identifiers may be supplied. If more than 100 + identifiers are supplied, this method raises a `ValueError`. + + Args: + identifiers (list[Identifier]): A list of ``Identifier`` instances used + to indicate which user records should be returned. Must have <= 100 + entries. + app: An App instance (optional). + + Returns: + GetUsersResult: A ``GetUsersResult`` instance corresponding to the + specified identifiers. + + Raises: + ValueError: If any of the identifiers are invalid or if more than 100 + identifiers are specified. + """ + client = _get_client(app) + return client.get_users(identifiers) + + def list_users(page_token=None, max_results=_user_mgt.MAX_LIST_USERS_RESULTS, app=None): """Retrieves a page of user accounts from a Firebase project. @@ -460,6 +505,34 @@ def delete_user(uid, app=None): client.delete_user(uid) +def delete_users(uids, app=None): + """Deletes the users specified by the given identifiers. + + Deleting a non-existing user does not generate an error (the method is + idempotent.) Non-existing users are considered to be successfully deleted + and are therefore included in the `DeleteUserResult.success_count` value. + + A maximum of 1000 identifiers may be supplied. If more than 1000 + identifiers are supplied, this method raises a `ValueError`. + + Args: + uids: A list of strings indicating the uids of the users to be deleted. + Must have <= 1000 entries. + app: An App instance (optional). + + Returns: + DeleteUsersResult: The total number of successful/failed deletions, as + well as the array of errors that correspond to the failed + deletions. + + Raises: + ValueError: If any of the identifiers are invalid or if more than 1000 + identifiers are specified. + """ + client = _get_client(app) + return client.delete_users(uids) + + def import_users(users, hash_alg=None, app=None): """Imports the specified list of users into Firebase Auth. diff --git a/integration/test_auth.py b/integration/test_auth.py index cfd775016..26cf53d20 100644 --- a/integration/test_auth.py +++ b/integration/test_auth.py @@ -18,6 +18,7 @@ import random import string import time +from typing import List from urllib import parse import uuid @@ -71,7 +72,7 @@ def _sign_in(custom_token, api_key): return resp.json().get('idToken') def _sign_in_with_password(email, password, api_key): - body = {'email': email, 'password': password} + body = {'email': email, 'password': password, 'returnSecureToken': True} params = {'key' : api_key} resp = requests.request('post', _verify_password_url, params=params, json=body) resp.raise_for_status() @@ -191,7 +192,7 @@ def new_user(): auth.delete_user(user.uid) @pytest.fixture -def new_user_with_params(): +def new_user_with_params() -> auth.UserRecord: random_id, email = _random_id() phone = _random_phone() user = auth.create_user( @@ -214,9 +215,52 @@ def new_user_list(): auth.create_user(password='password').uid, ] yield users + # TODO(rsgowman): Using auth.delete_users() would make more sense here, but + # that's currently rate limited to 1qps, so using it in this context would + # almost certainly trigger errors. When/if that limit is relaxed, switch to + # batch delete. for uid in users: auth.delete_user(uid) +@pytest.fixture +def new_user_record_list() -> List[auth.UserRecord]: + uid1, email1 = _random_id() + uid2, email2 = _random_id() + uid3, email3 = _random_id() + users = [ + auth.create_user( + uid=uid1, email=email1, password='password', phone_number=_random_phone()), + auth.create_user( + uid=uid2, email=email2, password='password', phone_number=_random_phone()), + auth.create_user( + uid=uid3, email=email3, password='password', phone_number=_random_phone()), + ] + yield users + for user in users: + auth.delete_user(user.uid) + +@pytest.fixture +def new_user_with_provider() -> auth.UserRecord: + uid4, email4 = _random_id() + google_uid, google_email = _random_id() + import_user1 = auth.ImportUserRecord( + uid=uid4, + email=email4, + provider_data=[ + auth.UserProvider( + uid=google_uid, + provider_id='google.com', + email=google_email, + ) + ]) + user_import_result = auth.import_users([import_user1]) + assert user_import_result.success_count == 1 + assert user_import_result.failure_count == 0 + + user = auth.get_user(uid4) + yield user + auth.delete_user(user.uid) + @pytest.fixture def new_user_email_unverified(): random_id, email = _random_id() @@ -248,6 +292,87 @@ def test_get_user(new_user_with_params): provider_ids = sorted([provider.provider_id for provider in user.provider_data]) assert provider_ids == ['password', 'phone'] +class TestGetUsers: + @staticmethod + def _map_user_record_to_uid_email_phones(user_record): + return { + 'uid': user_record.uid, + 'email': user_record.email, + 'phone_number': user_record.phone_number + } + + def test_multiple_uid_types(self, new_user_record_list, new_user_with_provider): + get_users_results = auth.get_users([ + auth.UidIdentifier(new_user_record_list[0].uid), + auth.EmailIdentifier(new_user_record_list[1].email), + auth.PhoneIdentifier(new_user_record_list[2].phone_number), + auth.ProviderIdentifier( + new_user_with_provider.provider_data[0].provider_id, + new_user_with_provider.provider_data[0].uid, + )]) + actual = sorted([ + self._map_user_record_to_uid_email_phones(user) + for user in get_users_results.users + ], key=lambda user: user['uid']) + expected = sorted([ + self._map_user_record_to_uid_email_phones(user) + for user in new_user_record_list + [new_user_with_provider] + ], key=lambda user: user['uid']) + + assert actual == expected + + def test_existing_and_non_existing_users(self, new_user_record_list): + get_users_results = auth.get_users([ + auth.UidIdentifier(new_user_record_list[0].uid), + auth.UidIdentifier('uid_that_doesnt_exist'), + auth.UidIdentifier(new_user_record_list[2].uid)]) + actual = sorted([ + self._map_user_record_to_uid_email_phones(user) + for user in get_users_results.users + ], key=lambda user: user['uid']) + expected = sorted([ + self._map_user_record_to_uid_email_phones(user) + for user in [new_user_record_list[0], new_user_record_list[2]] + ], key=lambda user: user['uid']) + + assert actual == expected + + def test_non_existing_users(self): + not_found_ids = [auth.UidIdentifier('non-existing user')] + get_users_results = auth.get_users(not_found_ids) + + assert get_users_results.users == [] + assert get_users_results.not_found == not_found_ids + + def test_de_dups_duplicate_users(self, new_user): + get_users_results = auth.get_users([ + auth.UidIdentifier(new_user.uid), + auth.UidIdentifier(new_user.uid)]) + actual = [ + self._map_user_record_to_uid_email_phones(user) + for user in get_users_results.users] + expected = [self._map_user_record_to_uid_email_phones(new_user)] + assert actual == expected + +def test_last_refresh_timestamp(new_user_with_params: auth.UserRecord, api_key): + # new users should not have a last_refresh_timestamp set + assert new_user_with_params.user_metadata.last_refresh_timestamp is None + + # login to cause the last_refresh_timestamp to be set + _sign_in_with_password(new_user_with_params.email, 'secret', api_key) + new_user_with_params = auth.get_user(new_user_with_params.uid) + + # Ensure the last refresh time occurred at approximately 'now'. (With a + # tolerance of up to 1 minute; we ideally want to ensure that any timezone + # considerations are handled properly, so as long as we're within an hour, + # we're in good shape.) + millis_per_second = 1000 + millis_per_minute = millis_per_second * 60 + + last_refresh_timestamp = new_user_with_params.user_metadata.last_refresh_timestamp + assert last_refresh_timestamp == pytest.approx( + time.time()*millis_per_second, 1*millis_per_minute) + def test_list_users(new_user_list): err_msg_template = ( 'Missing {field} field. A common cause would be forgetting to add the "Firebase ' + @@ -366,6 +491,36 @@ def test_delete_user(): with pytest.raises(auth.UserNotFoundError): auth.get_user(user.uid) + +class TestDeleteUsers: + def test_delete_multiple_users(self): + uid1 = auth.create_user(disabled=True).uid + uid2 = auth.create_user(disabled=False).uid + uid3 = auth.create_user(disabled=True).uid + + delete_users_result = auth.delete_users([uid1, uid2, uid3]) + assert delete_users_result.success_count == 3 + assert delete_users_result.failure_count == 0 + assert len(delete_users_result.errors) == 0 + + get_users_results = auth.get_users( + [auth.UidIdentifier(uid1), auth.UidIdentifier(uid2), auth.UidIdentifier(uid3)]) + assert len(get_users_results.users) == 0 + + def test_is_idempotent(self): + uid = auth.create_user().uid + + delete_users_result = auth.delete_users([uid]) + assert delete_users_result.success_count == 1 + assert delete_users_result.failure_count == 0 + + # Delete the user again, ensuring that everything still counts as a + # success. + delete_users_result = auth.delete_users([uid]) + assert delete_users_result.success_count == 1 + assert delete_users_result.failure_count == 0 + + def test_revoke_refresh_tokens(new_user): user = auth.get_user(new_user.uid) old_valid_after = user.tokens_valid_after_timestamp diff --git a/tests/test_rfc3339.py b/tests/test_rfc3339.py new file mode 100644 index 000000000..5a844b07e --- /dev/null +++ b/tests/test_rfc3339.py @@ -0,0 +1,67 @@ +# Copyright 2020 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test cases for the firebase_admin._rfc3339 module.""" + +import pytest + +from firebase_admin import _rfc3339 + +def test_epoch(): + expected = pytest.approx(0) + assert _rfc3339.parse_to_epoch("1970-01-01T00:00:00Z") == expected + assert _rfc3339.parse_to_epoch("1970-01-01T00:00:00z") == expected + assert _rfc3339.parse_to_epoch("1970-01-01T00:00:00+00:00") == expected + assert _rfc3339.parse_to_epoch("1970-01-01T00:00:00-00:00") == expected + assert _rfc3339.parse_to_epoch("1970-01-01T01:00:00+01:00") == expected + assert _rfc3339.parse_to_epoch("1969-12-31T23:00:00-01:00") == expected + +def test_pre_epoch(): + expected = -5617641600 + assert _rfc3339.parse_to_epoch("1791-12-26T00:00:00Z") == expected + assert _rfc3339.parse_to_epoch("1791-12-26T00:00:00+00:00") == expected + assert _rfc3339.parse_to_epoch("1791-12-26T00:00:00-00:00") == expected + assert _rfc3339.parse_to_epoch("1791-12-26T01:00:00+01:00") == expected + assert _rfc3339.parse_to_epoch("1791-12-25T23:00:00-01:00") == expected + +def test_post_epoch(): + expected = 904892400 + assert _rfc3339.parse_to_epoch("1998-09-04T07:00:00Z") == expected + assert _rfc3339.parse_to_epoch("1998-09-04T07:00:00+00:00") == expected + assert _rfc3339.parse_to_epoch("1998-09-04T08:00:00+01:00") == expected + assert _rfc3339.parse_to_epoch("1998-09-04T06:00:00-01:00") == expected + +def test_micros_millis(): + assert _rfc3339.parse_to_epoch("1970-01-01T00:00:00Z") == pytest.approx(0) + assert _rfc3339.parse_to_epoch("1970-01-01T00:00:00.1Z") == pytest.approx(0.1) + assert _rfc3339.parse_to_epoch("1970-01-01T00:00:00.001Z") == pytest.approx(0.001) + assert _rfc3339.parse_to_epoch("1970-01-01T00:00:00.000001Z") == pytest.approx(0.000001) + + assert _rfc3339.parse_to_epoch("1970-01-01T00:00:00+00:00") == pytest.approx(0) + assert _rfc3339.parse_to_epoch("1970-01-01T00:00:00.1+00:00") == pytest.approx(0.1) + assert _rfc3339.parse_to_epoch("1970-01-01T00:00:00.001+00:00") == pytest.approx(0.001) + assert _rfc3339.parse_to_epoch("1970-01-01T00:00:00.000001+00:00") == pytest.approx(0.000001) + +def test_nanos(): + assert _rfc3339.parse_to_epoch("1970-01-01T00:00:00.0000001Z") == pytest.approx(0) + +@pytest.mark.parametrize('datestr', [ + 'not a date string', + '1970-01-01 00:00:00Z', + '1970-01-01 00:00:00+00:00', + '1970-01-01T00:00:00', + ]) +def test_bad_datestrs(datestr): + with pytest.raises(ValueError): + _rfc3339.parse_to_epoch(datestr) diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index c7b2de496..79e23373f 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -322,6 +322,92 @@ def test_get_user_by_phone_http_error(self, user_mgt_app): assert excinfo.value.cause is not None +class TestGetUsers: + + @staticmethod + def _map_user_record_to_uid_email_phones(user_record): + return { + 'uid': user_record.uid, + 'email': user_record.email, + 'phone_number': user_record.phone_number + } + + def test_more_than_100_identifiers(self, user_mgt_app): + identifiers = [auth.UidIdentifier('id' + str(i)) for i in range(101)] + with pytest.raises(ValueError): + auth.get_users(identifiers, app=user_mgt_app) + + def test_no_identifiers(self, user_mgt_app): + get_users_results = auth.get_users([], app=user_mgt_app) + assert get_users_results.users == [] + assert get_users_results.not_found == [] + + def test_identifiers_that_do_not_exist(self, user_mgt_app): + _instrument_user_manager(user_mgt_app, 200, '{}') + not_found_ids = [auth.UidIdentifier('id that doesnt exist')] + get_users_results = auth.get_users(not_found_ids, app=user_mgt_app) + assert get_users_results.users == [] + assert get_users_results.not_found == not_found_ids + + def test_invalid_uid(self): + with pytest.raises(ValueError): + auth.UidIdentifier('too long ' + '.'*128) + + def test_invalid_email(self): + with pytest.raises(ValueError): + auth.EmailIdentifier('invalid email addr') + + def test_invalid_phone_number(self): + with pytest.raises(ValueError): + auth.PhoneIdentifier('invalid phone number') + + def test_invalid_provider(self): + with pytest.raises(ValueError): + auth.ProviderIdentifier(provider_id='', provider_uid='') + + def test_success(self, user_mgt_app): + mock_users = [{ + "localId": "uid1", + "email": "user1@example.com", + "phoneNumber": "+15555550001" + }, { + "localId": "uid2", + "email": "user2@example.com", + "phoneNumber": "+15555550002" + }, { + "localId": "uid3", + "email": "user3@example.com", + "phoneNumber": "+15555550003" + }, { + "localId": "uid4", + "email": "user4@example.com", + "phoneNumber": "+15555550004", + "providerUserInfo": [{ + "providerId": "google.com", + "rawId": "google_uid4" + }] + }] + _instrument_user_manager(user_mgt_app, 200, '{ "users": ' + json.dumps(mock_users) + '}') + + get_users_results = auth.get_users([ + auth.UidIdentifier('uid1'), + auth.EmailIdentifier('user2@example.com'), + auth.PhoneIdentifier('+15555550003'), + auth.ProviderIdentifier(provider_id='google.com', provider_uid='google_uid4'), + auth.UidIdentifier('this-user-doesnt-exist'), + ], app=user_mgt_app) + + actual = sorted( + [self._map_user_record_to_uid_email_phones(user) for user in get_users_results.users], + key=lambda user: user['uid']) + expected = sorted([ + self._map_user_record_to_uid_email_phones(auth.UserRecord(user)) + for user in mock_users + ], key=lambda user: user['uid']) + assert actual == expected + assert [u.uid for u in get_users_results.not_found] == ['this-user-doesnt-exist'] + + class TestCreateUser: already_exists_errors = { @@ -633,6 +719,54 @@ def test_delete_user_unexpected_response(self, user_mgt_app): assert isinstance(excinfo.value, exceptions.UnknownError) +class TestDeleteUsers: + + def test_empty_list(self, user_mgt_app): + delete_users_result = auth.delete_users([], app=user_mgt_app) + assert delete_users_result.success_count == 0 + assert delete_users_result.failure_count == 0 + assert len(delete_users_result.errors) == 0 + + def test_too_many_identifiers_should_fail(self, user_mgt_app): + ids = ['id' + str(i) for i in range(1001)] + with pytest.raises(ValueError): + auth.delete_users(ids, app=user_mgt_app) + + def test_invalid_id_should_fail(self, user_mgt_app): + ids = ['too long ' + '.'*128] + with pytest.raises(ValueError): + auth.delete_users(ids, app=user_mgt_app) + + def test_should_index_errors_correctly_in_results(self, user_mgt_app): + _instrument_user_manager(user_mgt_app, 200, """{ + "errors": [{ + "index": 0, + "localId": "uid1", + "message": "NOT_DISABLED : Disable the account before batch deletion." + }, { + "index": 2, + "localId": "uid3", + "message": "something awful" + }] + }""") + + delete_users_result = auth.delete_users(['uid1', 'uid2', 'uid3', 'uid4'], app=user_mgt_app) + assert delete_users_result.success_count == 2 + assert delete_users_result.failure_count == 2 + assert len(delete_users_result.errors) == 2 + assert delete_users_result.errors[0].index == 0 + assert delete_users_result.errors[0].reason.startswith('NOT_DISABLED') + assert delete_users_result.errors[1].index == 2 + assert delete_users_result.errors[1].reason == 'something awful' + + def test_success(self, user_mgt_app): + _instrument_user_manager(user_mgt_app, 200, '{}') + delete_users_result = auth.delete_users(['uid1', 'uid2', 'uid3'], app=user_mgt_app) + assert delete_users_result.success_count == 3 + assert delete_users_result.failure_count == 0 + assert len(delete_users_result.errors) == 0 + + class TestListUsers: @pytest.mark.parametrize('arg', [None, 'foo', list(), dict(), 0, -1, 1001, False]) From f8b1ffffe2f186cadb66db731c1c01930c3af494 Mon Sep 17 00:00:00 2001 From: rsgowman Date: Wed, 13 May 2020 08:41:50 -0400 Subject: [PATCH 074/226] Followup to #400 to ensure all new types are exported (#461) Also fix multi-line Returns: doc statements. (See https://github.com/sphinx-contrib/napoleon/issues/4) --- firebase_admin/_auth_client.py | 6 +++--- firebase_admin/_auth_providers.py | 2 +- firebase_admin/_user_mgt.py | 6 +++--- firebase_admin/auth.py | 13 +++++++------ 4 files changed, 14 insertions(+), 13 deletions(-) diff --git a/firebase_admin/_auth_client.py b/firebase_admin/_auth_client.py index 12d60592e..1c9b37082 100644 --- a/firebase_admin/_auth_client.py +++ b/firebase_admin/_auth_client.py @@ -200,7 +200,7 @@ def get_users(self, identifiers): Returns: GetUsersResult: A ``GetUsersResult`` instance corresponding to the - specified identifiers. + specified identifiers. Raises: ValueError: If any of the identifiers are invalid or if more than 100 @@ -374,8 +374,8 @@ def delete_users(self, uids): Returns: DeleteUsersResult: The total number of successful/failed deletions, as - well as the array of errors that correspond to the failed - deletions. + well as the array of errors that correspond to the failed + deletions. Raises: ValueError: If any of the identifiers are invalid or if more than 1000 diff --git a/firebase_admin/_auth_providers.py b/firebase_admin/_auth_providers.py index 96f1b5348..46de6fe5f 100644 --- a/firebase_admin/_auth_providers.py +++ b/firebase_admin/_auth_providers.py @@ -121,7 +121,7 @@ def get_next_page(self): Returns: ListProviderConfigsPage: Next page of provider configs, or None if this is the last - page. + page. """ if self.has_next_page: return self.__class__(self._download, self.next_page_token, self._max_results) diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index 0307959f3..1d97dd504 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -78,7 +78,7 @@ def last_refresh_timestamp(self): Returns: integer: Milliseconds since epoch timestamp, or `None` if the user was - never active. + never active. """ return self._last_refresh_timestamp @@ -215,7 +215,7 @@ def tokens_valid_after_timestamp(self): Returns: int: Timestamp in milliseconds since the epoch, truncated to the second. - All tokens issued before that time are considered revoked. + All tokens issued before that time are considered revoked. """ valid_since = self._data.get('validSince') if valid_since is not None: @@ -752,7 +752,7 @@ def delete_users(self, uids, force_delete=False): Returns: BatchDeleteAccountsResponse: Server's proto response, wrapped in a - python object. + python object. Raises: ValueError: If any of the identifiers are invalid or if more than 1000 diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index 5d2fe0f68..c5361bc38 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -43,6 +43,8 @@ 'ExpiredIdTokenError', 'ExpiredSessionCookieError', 'ExportedUserRecord', + 'DeleteUsersResult', + 'GetUsersResult', 'ImportUserRecord', 'InsufficientPermissionError', 'InvalidDynamicLinkDomainError', @@ -356,14 +358,14 @@ def get_users(identifiers, app=None): identifiers are supplied, this method raises a `ValueError`. Args: - identifiers (list[Identifier]): A list of ``Identifier`` instances used - to indicate which user records should be returned. Must have <= 100 - entries. + identifiers (list[UserIdentifier]): A list of ``UserIdentifier`` + instances used to indicate which user records should be returned. + Must have <= 100 entries. app: An App instance (optional). Returns: GetUsersResult: A ``GetUsersResult`` instance corresponding to the - specified identifiers. + specified identifiers. Raises: ValueError: If any of the identifiers are invalid or if more than 100 @@ -522,8 +524,7 @@ def delete_users(uids, app=None): Returns: DeleteUsersResult: The total number of successful/failed deletions, as - well as the array of errors that correspond to the failed - deletions. + well as the array of errors that correspond to the failed deletions. Raises: ValueError: If any of the identifiers are invalid or if more than 1000 From 74b517934ceedf231e8e2855560f700455a328ba Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Thu, 14 May 2020 12:29:36 -0700 Subject: [PATCH 075/226] [chore] Release 4.3.0 (#462) --- .github/workflows/release.yml | 2 +- firebase_admin/__about__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 64ee304ce..fbde8ed59 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -134,7 +134,7 @@ jobs: uses: firebase/firebase-admin-node/.github/actions/send-tweet@master with: status: > - ${{ steps.preflight.outputs.version }} of @Firebase Admin Python SDK is avaialble. + ${{ steps.preflight.outputs.version }} of @Firebase Admin Python SDK is available. https://github.com/firebase/firebase-admin-python/releases/tag/${{ steps.preflight.outputs.version }} consumer-key: ${{ secrets.TWITTER_CONSUMER_KEY }} consumer-secret: ${{ secrets.TWITTER_CONSUMER_SECRET }} diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index ff12296f2..298f3703e 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '4.2.0' +__version__ = '4.3.0' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 28c4d46d0e7f0378b8d88ab821f6f746ae11564f Mon Sep 17 00:00:00 2001 From: Sam Stern Date: Thu, 28 May 2020 09:07:09 -0400 Subject: [PATCH 076/226] Add auth bulk get/delete snippets (#464) --- snippets/auth/index.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/snippets/auth/index.py b/snippets/auth/index.py index 428c54e09..9de9cfa03 100644 --- a/snippets/auth/index.py +++ b/snippets/auth/index.py @@ -192,6 +192,26 @@ def get_user_by_email(): print('Successfully fetched user data: {0}'.format(user.uid)) # [END get_user_by_email] +def bulk_get_users(): + # [START bulk_get_users] + from firebase_admin import auth + + result = auth.get_users([ + auth.UidIdentifier('uid1'), + auth.EmailIdentifier('user2@example.com'), + auth.PhoneIdentifier(+15555550003), + auth.ProviderIdentifier('google.com', 'google_uid4') + ]) + + print('Successfully fetched user data:') + for user in result.users: + print(user.uid) + + print('Unable to find users corresponding to these identifiers:') + for uid in result.not_found: + print(uid) + # [END bulk_get_users] + def get_user_by_phone_number(): phone = '+1 555 555 0100' # [START get_user_by_phone] @@ -242,6 +262,18 @@ def delete_user(uid): print('Successfully deleted user') # [END delete_user] +def bulk_delete_users(): + # [START bulk_delete_users] + from firebase_admin import auth + + result = auth.delete_users(["uid1", "uid2", "uid3"]) + + print('Successfully deleted {0} users'.format(result.success_count)) + print('Failed to delete {0} users'.format(result.failure_count)) + for err in result.errors: + print('error #{0}, reason: {1}'.format(result.index, result.reason)) + # [END bulk_delete_users] + def set_custom_user_claims(uid): # [START set_custom_user_claims] # Set admin privilege on the user corresponding to uid. From f43e6876684d2c7e9acf5b0b013642b44883c63a Mon Sep 17 00:00:00 2001 From: rsgowman Date: Tue, 16 Jun 2020 16:31:24 -0400 Subject: [PATCH 077/226] Fixed a flaky auth integration test by retrying the GetUser() API call (#469) Includes bonus fix to ensure bulk deleting users doesn't hit the quota. --- integration/test_auth.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/integration/test_auth.py b/integration/test_auth.py index 26cf53d20..16ae52a86 100644 --- a/integration/test_auth.py +++ b/integration/test_auth.py @@ -360,7 +360,18 @@ def test_last_refresh_timestamp(new_user_with_params: auth.UserRecord, api_key): # login to cause the last_refresh_timestamp to be set _sign_in_with_password(new_user_with_params.email, 'secret', api_key) - new_user_with_params = auth.get_user(new_user_with_params.uid) + + # Attempt to retrieve the user 3 times (with a small delay between each + # attempt). Occassionally, this call retrieves the user data without the + # lastLoginTime/lastRefreshTime set; possibly because it's hitting a + # different server than the login request uses. + user_record = None + for iteration in range(0, 3): + user_record = auth.get_user(new_user_with_params.uid) + if user_record.user_metadata.last_refresh_timestamp is not None: + break + + time.sleep(2 ** iteration) # Ensure the last refresh time occurred at approximately 'now'. (With a # tolerance of up to 1 minute; we ideally want to ensure that any timezone @@ -369,7 +380,7 @@ def test_last_refresh_timestamp(new_user_with_params: auth.UserRecord, api_key): millis_per_second = 1000 millis_per_minute = millis_per_second * 60 - last_refresh_timestamp = new_user_with_params.user_metadata.last_refresh_timestamp + last_refresh_timestamp = user_record.user_metadata.last_refresh_timestamp assert last_refresh_timestamp == pytest.approx( time.time()*millis_per_second, 1*millis_per_minute) @@ -498,7 +509,7 @@ def test_delete_multiple_users(self): uid2 = auth.create_user(disabled=False).uid uid3 = auth.create_user(disabled=True).uid - delete_users_result = auth.delete_users([uid1, uid2, uid3]) + delete_users_result = self._slow_delete_users(auth, [uid1, uid2, uid3]) assert delete_users_result.success_count == 3 assert delete_users_result.failure_count == 0 assert len(delete_users_result.errors) == 0 @@ -510,16 +521,22 @@ def test_delete_multiple_users(self): def test_is_idempotent(self): uid = auth.create_user().uid - delete_users_result = auth.delete_users([uid]) + delete_users_result = self._slow_delete_users(auth, [uid]) assert delete_users_result.success_count == 1 assert delete_users_result.failure_count == 0 # Delete the user again, ensuring that everything still counts as a # success. - delete_users_result = auth.delete_users([uid]) + delete_users_result = self._slow_delete_users(auth, [uid]) assert delete_users_result.success_count == 1 assert delete_users_result.failure_count == 0 + def _slow_delete_users(self, auth, uids): + """The batchDelete endpoint has a rate limit of 1 QPS. Use this test + helper to ensure you don't exceed the quota.""" + time.sleep(1) + return auth.delete_users(uids) + def test_revoke_refresh_tokens(new_user): user = auth.get_user(new_user.uid) From 7a3dcb70df80de9f5c140aff8df00d2b87f414c9 Mon Sep 17 00:00:00 2001 From: John Carter Date: Fri, 28 Aug 2020 05:26:45 +1200 Subject: [PATCH 078/226] Fix doc spelling (#486) --- firebase_admin/auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index c5361bc38..5154bb495 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -450,7 +450,7 @@ def update_user(uid, **kwargs): # pylint: disable=differing-param-doc ``auth.DELETE_ATTRIBUTE``. password: The user's raw, unhashed password. (optional). disabled: A boolean indicating whether or not the user account is disabled (optional). - custom_claims: A dictionary or a JSON string contining the custom claims to be set on the + custom_claims: A dictionary or a JSON string containing the custom claims to be set on the user account (optional). To remove all custom claims, pass ``auth.DELETE_ATTRIBUTE``. valid_since: An integer signifying the seconds since the epoch (optional). This field is set by ``revoke_refresh_tokens`` and it is discouraged to set this field directly. From 9acaff94e0be601ca9961d67ed1e931a1166f6b9 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Thu, 27 Aug 2020 15:17:58 -0700 Subject: [PATCH 079/226] chore: Temporarily disabling a lint rule (#485) * chore: Temporarily disabling a lint rule * chore: Fixing another similar lint error --- firebase_admin/db.py | 4 +++- tests/testutils.py | 3 +++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/firebase_admin/db.py b/firebase_admin/db.py index d42370317..be2b9c917 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -979,7 +979,9 @@ def _extract_error_message(cls, response): return message - +# Temporarily disable the lint rule. For more information see: +# https://github.com/googleapis/google-auth-library-python/pull/561 +# pylint: disable=abstract-method class _EmulatorAdminCredentials(google.auth.credentials.Credentials): def __init__(self): google.auth.credentials.Credentials.__init__(self) diff --git a/tests/testutils.py b/tests/testutils.py index d0663ead1..556155253 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -105,6 +105,9 @@ def __call__(self, *args, **kwargs): # pylint: disable=arguments-differ raise self.error +# Temporarily disable the lint rule. For more information see: +# https://github.com/googleapis/google-auth-library-python/pull/561 +# pylint: disable=abstract-method class MockGoogleCredential(credentials.Credentials): """A mock Google authentication credential.""" def refresh(self, request): From 8868d8d1903642f161cbc994a50b0364a7b14ab1 Mon Sep 17 00:00:00 2001 From: ifielker Date: Fri, 11 Sep 2020 14:42:02 -0400 Subject: [PATCH 080/226] feat(ml): Adding Firebase ML support for AutoML models (#489) Added support for AutoML models RELEASE NOTES: Added support for creating, updating, getting, listing, publishing, unpublishing, and deleting Firebase-hosted custom ML models created with AutoML. --- firebase_admin/ml.py | 73 +++++++++++++++++++++------ integration/test_ml.py | 112 ++++++++++++++++++++++++++++++++--------- requirements.txt | 1 + tests/test_ml.py | 67 +++++++++++++++++++++++- 4 files changed, 213 insertions(+), 40 deletions(-) diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index 2613a3de3..bcc4b9390 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -53,6 +53,9 @@ _TAG_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,32}$') _GCS_TFLITE_URI_PATTERN = re.compile( r'^gs://(?P[a-z0-9_.-]{3,63})/(?P.+)$') +_AUTO_ML_MODEL_PATTERN = re.compile( + r'^projects/(?P[a-z0-9-]{6,30})/locations/(?P[^/]+)/' + + r'models/(?P[A-Za-z0-9]+)$') _RESOURCE_NAME_PATTERN = re.compile( r'^projects/(?P[a-z0-9-]{6,30})/models/(?P[A-Za-z0-9_-]{1,60})$') _OPERATION_NAME_PATTERN = re.compile( @@ -75,7 +78,7 @@ def _get_ml_service(app): def create_model(model, app=None): - """Creates a model in Firebase ML. + """Creates a model in the current Firebase project. Args: model: An ml.Model to create. @@ -89,7 +92,7 @@ def create_model(model, app=None): def update_model(model, app=None): - """Updates a model in Firebase ML. + """Updates a model's metadata or model file. Args: model: The ml.Model to update. @@ -103,7 +106,9 @@ def update_model(model, app=None): def publish_model(model_id, app=None): - """Publishes a model in Firebase ML. + """Publishes a Firebase ML model. + + A published model can be downloaded to client apps. Args: model_id: The id of the model to publish. @@ -117,7 +122,7 @@ def publish_model(model_id, app=None): def unpublish_model(model_id, app=None): - """Unpublishes a model in Firebase ML. + """Unpublishes a Firebase ML model. Args: model_id: The id of the model to unpublish. @@ -131,7 +136,7 @@ def unpublish_model(model_id, app=None): def get_model(model_id, app=None): - """Gets a model from Firebase ML. + """Gets the model specified by the given ID. Args: model_id: The id of the model to get. @@ -145,7 +150,7 @@ def get_model(model_id, app=None): def list_models(list_filter=None, page_size=None, page_token=None, app=None): - """Lists models from Firebase ML. + """Lists the current project's models. Args: list_filter: a list filter string such as ``tags:'tag_1'``. None will return all models. @@ -164,7 +169,7 @@ def list_models(list_filter=None, page_size=None, page_token=None, app=None): def delete_model(model_id, app=None): - """Deletes a model from Firebase ML. + """Deletes a model from the current project. Args: model_id: The id of the model you wish to delete. @@ -363,15 +368,10 @@ def __init__(self, model_source=None): def from_dict(cls, data): """Create an instance of the object from a dict.""" data_copy = dict(data) - model_source = None - gcs_tflite_uri = data_copy.pop('gcsTfliteUri', None) - if gcs_tflite_uri: - model_source = TFLiteGCSModelSource(gcs_tflite_uri=gcs_tflite_uri) - tflite_format = TFLiteFormat(model_source=model_source) + tflite_format = TFLiteFormat(model_source=cls._init_model_source(data_copy)) tflite_format._data = data_copy # pylint: disable=protected-access return tflite_format - def __eq__(self, other): if isinstance(other, self.__class__): # pylint: disable=protected-access @@ -381,6 +381,16 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) + @staticmethod + def _init_model_source(data): + gcs_tflite_uri = data.pop('gcsTfliteUri', None) + if gcs_tflite_uri: + return TFLiteGCSModelSource(gcs_tflite_uri=gcs_tflite_uri) + auto_ml_model = data.pop('automlModel', None) + if auto_ml_model: + return TFLiteAutoMlSource(auto_ml_model=auto_ml_model) + return None + @property def model_source(self): """The TF Lite model's location.""" @@ -593,8 +603,38 @@ def as_dict(self, for_upload=False): return {'gcsTfliteUri': self._gcs_tflite_uri} +class TFLiteAutoMlSource(TFLiteModelSource): + """TFLite model source representing a tflite model created with AutoML.""" + + def __init__(self, auto_ml_model, app=None): + self._app = app + self.auto_ml_model = auto_ml_model + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self.auto_ml_model == other.auto_ml_model + return False + + def __ne__(self, other): + return not self.__eq__(other) + + @property + def auto_ml_model(self): + """Resource name of the model, created by the AutoML API or Cloud console.""" + return self._auto_ml_model + + @auto_ml_model.setter + def auto_ml_model(self, auto_ml_model): + self._auto_ml_model = _validate_auto_ml_model(auto_ml_model) + + def as_dict(self, for_upload=False): + """Returns a serializable representation of the object.""" + # Upload is irrelevant for auto_ml models + return {'automlModel': self._auto_ml_model} + + class ListModelsPage: - """Represents a page of models in a firebase project. + """Represents a page of models in a Firebase project. Provides methods for traversing the models included in this page, as well as retrieving subsequent pages of models. The iterator returned by @@ -740,6 +780,11 @@ def _validate_gcs_tflite_uri(uri): raise ValueError('GCS TFLite URI format is invalid.') return uri +def _validate_auto_ml_model(model): + if not _AUTO_ML_MODEL_PATTERN.match(model): + raise ValueError('Model resource name format is invalid.') + return model + def _validate_model_format(model_format): if not isinstance(model_format, ModelFormat): diff --git a/integration/test_ml.py b/integration/test_ml.py index 1d32ebed1..52cb1bb7e 100644 --- a/integration/test_ml.py +++ b/integration/test_ml.py @@ -22,6 +22,7 @@ import pytest +import firebase_admin from firebase_admin import exceptions from firebase_admin import ml from tests import testutils @@ -34,6 +35,11 @@ except ImportError: _TF_ENABLED = False +try: + from google.cloud import automl_v1 + _AUTOML_ENABLED = True +except ImportError: + _AUTOML_ENABLED = False def _random_identifier(prefix): #pylint: disable=unused-variable @@ -62,7 +68,6 @@ def _random_identifier(prefix): 'file_name': 'invalid_model.tflite' } - @pytest.fixture def firebase_model(request): args = request.param @@ -101,6 +106,7 @@ def _clean_up_model(model): try: # Try to delete the model. # Some tests delete the model as part of the test. + model.wait_for_unlocked() ml.delete_model(model.model_id) except exceptions.NotFoundError: pass @@ -132,35 +138,45 @@ def check_model(model, args): assert model.locked is False assert model.etag is not None +# Model Format Checks -def check_model_format(model, has_model_format=False, validation_error=None): - if has_model_format: - assert model.validation_error == validation_error - assert model.published is False - assert model.model_format.model_source.gcs_tflite_uri.startswith('gs://') - if validation_error: - assert model.model_format.size_bytes is None - assert model.model_hash is None - else: - assert model.model_format.size_bytes is not None - assert model.model_hash is not None - else: - assert model.model_format is None - assert model.validation_error == 'No model file has been uploaded.' - assert model.published is False +def check_no_model_format(model): + assert model.model_format is None + assert model.validation_error == 'No model file has been uploaded.' + assert model.published is False + assert model.model_hash is None + + +def check_tflite_gcs_format(model, validation_error=None): + assert model.validation_error == validation_error + assert model.published is False + assert model.model_format.model_source.gcs_tflite_uri.startswith('gs://') + if validation_error: + assert model.model_format.size_bytes is None assert model.model_hash is None + else: + assert model.model_format.size_bytes is not None + assert model.model_hash is not None + + +def check_tflite_automl_format(model): + assert model.validation_error is None + assert model.published is False + assert model.model_format.model_source.auto_ml_model.startswith('projects/') + # Automl models don't have validation errors since they are references + # to valid automl models. @pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True) def test_create_simple_model(firebase_model): check_model(firebase_model, NAME_AND_TAGS_ARGS) - check_model_format(firebase_model) + check_no_model_format(firebase_model) @pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True) def test_create_full_model(firebase_model): check_model(firebase_model, FULL_MODEL_ARGS) - check_model_format(firebase_model, True) + check_tflite_gcs_format(firebase_model) @pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True) @@ -175,14 +191,14 @@ def test_create_already_existing_fails(firebase_model): @pytest.mark.parametrize('firebase_model', [INVALID_FULL_MODEL_ARGS], indirect=True) def test_create_invalid_model(firebase_model): check_model(firebase_model, INVALID_FULL_MODEL_ARGS) - check_model_format(firebase_model, True, 'Invalid flatbuffer format') + check_tflite_gcs_format(firebase_model, 'Invalid flatbuffer format') @pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True) def test_get_model(firebase_model): get_model = ml.get_model(firebase_model.model_id) check_model(get_model, NAME_AND_TAGS_ARGS) - check_model_format(get_model) + check_no_model_format(get_model) @pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) @@ -201,12 +217,12 @@ def test_update_model(firebase_model): firebase_model.display_name = new_model_name updated_model = ml.update_model(firebase_model) check_model(updated_model, NAME_ONLY_ARGS_UPDATED) - check_model_format(updated_model) + check_no_model_format(updated_model) # Second call with same model does not cause error updated_model2 = ml.update_model(updated_model) check_model(updated_model2, NAME_ONLY_ARGS_UPDATED) - check_model_format(updated_model2) + check_no_model_format(updated_model2) @pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) @@ -290,7 +306,7 @@ def test_delete_model(firebase_model): # Test tensor flow conversion functions if tensor flow is enabled. #'pip install tensorflow' in the environment if you want _TF_ENABLED = True -#'pip install tensorflow==2.0.0b' for version 2 etc. +#'pip install tensorflow==2.2.0' for version 2.2.0 etc. def _clean_up_directory(save_dir): @@ -334,6 +350,7 @@ def saved_model_dir(keras_model): _clean_up_directory(parent) + @pytest.mark.skipif(not _TF_ENABLED, reason='Tensor flow is required for this test.') def test_from_keras_model(keras_model): source = ml.TFLiteGCSModelSource.from_keras_model(keras_model, 'model2.tflite') @@ -348,7 +365,7 @@ def test_from_keras_model(keras_model): try: check_model(created_model, {'display_name': model.display_name}) - check_model_format(created_model, True) + check_tflite_gcs_format(created_model) finally: _clean_up_model(created_model) @@ -371,3 +388,50 @@ def test_from_saved_model(saved_model_dir): assert created_model.validation_error is None finally: _clean_up_model(created_model) + + +# Test AutoML functionality if AutoML is enabled. +#'pip install google-cloud-automl' in the environment if you want _AUTOML_ENABLED = True +# You will also need a predefined AutoML model named 'admin_sdk_integ_test1' to run the +# successful test. (Test is skipped otherwise) + +@pytest.fixture +def automl_model(): + assert _AUTOML_ENABLED + + # It takes > 20 minutes to train a model, so we expect a predefined AutoMl + # model named 'admin_sdk_integ_test1' to exist in the project, or we skip + # the test. + automl_client = automl_v1.AutoMlClient() + project_id = firebase_admin.get_app().project_id + parent = automl_client.location_path(project_id, 'us-central1') + models = automl_client.list_models(parent, filter_="display_name=admin_sdk_integ_test1") + # Expecting exactly one. (Ok to use last one if somehow more than 1) + automl_ref = None + for model in models: + automl_ref = model.name + + # Skip if no pre-defined model. (It takes min > 20 minutes to train a model) + if automl_ref is None: + pytest.skip("No pre-existing AutoML model found. Skipping test") + + source = ml.TFLiteAutoMlSource(automl_ref) + tflite_format = ml.TFLiteFormat(model_source=source) + ml_model = ml.Model( + display_name=_random_identifier('TestModel_automl_'), + tags=['test_automl'], + model_format=tflite_format) + model = ml.create_model(model=ml_model) + yield model + _clean_up_model(model) + +@pytest.mark.skipif(not _AUTOML_ENABLED, reason='AutoML is required for this test.') +def test_automl_model(automl_model): + # This test looks for a predefined automl model with display_name = 'admin_sdk_integ_test1' + automl_model.wait_for_unlocked() + + check_model(automl_model, { + 'display_name': automl_model.display_name, + 'tags': ['test_automl'], + }) + check_tflite_automl_format(automl_model) diff --git a/requirements.txt b/requirements.txt index dbeaee3b6..1a55482da 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,5 +7,6 @@ pytest-localserver >= 0.4.1 cachecontrol >= 0.12.6 google-api-core[grpc] >= 1.14.0, < 2.0.0dev; platform.python_implementation != 'PyPy' google-api-python-client >= 1.7.8 +google-auth == 1.18.0 # temporary workaround google-cloud-firestore >= 1.4.0; platform.python_implementation != 'PyPy' google-cloud-storage >= 1.18.0 diff --git a/tests/test_ml.py b/tests/test_ml.py index 10b0441db..abd6d06f9 100644 --- a/tests/test_ml.py +++ b/tests/test_ml.py @@ -122,6 +122,18 @@ } TFLITE_FORMAT_2 = ml.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_2) +AUTOML_MODEL_NAME = 'projects/111111111111/locations/us-central1/models/ICN7683346839371803263' +AUTOML_MODEL_SOURCE = ml.TFLiteAutoMlSource(AUTOML_MODEL_NAME) +TFLITE_FORMAT_JSON_3 = { + 'automlModel': AUTOML_MODEL_NAME, + 'sizeBytes': '3456789' +} +TFLITE_FORMAT_3 = ml.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_3) + +AUTOML_MODEL_NAME_2 = 'projects/2222222222/locations/us-central1/models/ICN2222222222222222222' +AUTOML_MODEL_NAME_JSON_2 = {'automlModel': AUTOML_MODEL_NAME_2} +AUTOML_MODEL_SOURCE_2 = ml.TFLiteAutoMlSource(AUTOML_MODEL_NAME_2) + CREATED_UPDATED_MODEL_JSON_1 = { 'name': MODEL_NAME_1, 'displayName': DISPLAY_NAME_1, @@ -405,7 +417,15 @@ def test_model_keyword_based_creation_and_setters(self): 'tfliteModel': TFLITE_FORMAT_JSON_2 } - def test_model_format_source_creation(self): + model.model_format = TFLITE_FORMAT_3 + assert model.as_dict() == { + 'displayName': DISPLAY_NAME_2, + 'tags': TAGS_2, + 'tfliteModel': TFLITE_FORMAT_JSON_3 + } + + + def test_gcs_tflite_model_format_source_creation(self): model_source = ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI) model_format = ml.TFLiteFormat(model_source=model_source) model = ml.Model(display_name=DISPLAY_NAME_1, model_format=model_format) @@ -416,6 +436,17 @@ def test_model_format_source_creation(self): } } + def test_auto_ml_tflite_model_format_source_creation(self): + model_source = ml.TFLiteAutoMlSource(auto_ml_model=AUTOML_MODEL_NAME) + model_format = ml.TFLiteFormat(model_source=model_source) + model = ml.Model(display_name=DISPLAY_NAME_1, model_format=model_format) + assert model.as_dict() == { + 'displayName': DISPLAY_NAME_1, + 'tfliteModel': { + 'automlModel': AUTOML_MODEL_NAME + } + } + def test_source_creation_from_tflite_file(self): model_source = ml.TFLiteGCSModelSource.from_tflite_model_file( "my_model.tflite", "my_bucket") @@ -423,12 +454,19 @@ def test_source_creation_from_tflite_file(self): 'gcsTfliteUri': 'gs://my_bucket/Firebase/ML/Models/my_model.tflite' } - def test_model_source_setters(self): + def test_gcs_tflite_model_source_setters(self): model_source = ml.TFLiteGCSModelSource(GCS_TFLITE_URI) model_source.gcs_tflite_uri = GCS_TFLITE_URI_2 assert model_source.gcs_tflite_uri == GCS_TFLITE_URI_2 assert model_source.as_dict() == GCS_TFLITE_URI_JSON_2 + def test_auto_ml_tflite_model_source_setters(self): + model_source = ml.TFLiteAutoMlSource(AUTOML_MODEL_NAME) + model_source.auto_ml_model = AUTOML_MODEL_NAME_2 + assert model_source.auto_ml_model == AUTOML_MODEL_NAME_2 + assert model_source.as_dict() == AUTOML_MODEL_NAME_JSON_2 + + def test_model_format_setters(self): model_format = ml.TFLiteFormat(model_source=GCS_TFLITE_MODEL_SOURCE) model_format.model_source = GCS_TFLITE_MODEL_SOURCE_2 @@ -439,6 +477,14 @@ def test_model_format_setters(self): } } + model_format.model_source = AUTOML_MODEL_SOURCE + assert model_format.model_source == AUTOML_MODEL_SOURCE + assert model_format.as_dict() == { + 'tfliteModel': { + 'automlModel': AUTOML_MODEL_NAME + } + } + def test_model_as_dict_for_upload(self): model_source = ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI) model_format = ml.TFLiteFormat(model_source=model_source) @@ -524,6 +570,23 @@ def test_gcs_tflite_source_validation_errors(self, uri, exc_type): ml.TFLiteGCSModelSource(gcs_tflite_uri=uri) check_error(excinfo, exc_type) + @pytest.mark.parametrize('auto_ml_model, exc_type', [ + (123, TypeError), + ('abc', ValueError), + ('/projects/123456/locations/us-central1/models/noLeadingSlash', ValueError), + ('projects/123546/models/ICN123456', ValueError), + ('projects//locations/us-central1/models/ICN123456', ValueError), + ('projects/123456/locations//models/ICN123456', ValueError), + ('projects/123456/locations/us-central1/models/', ValueError), + ('projects/ABC/locations/us-central1/models/ICN123456', ValueError), + ('projects/123456/locations/us-central1/models/@#$%^&', ValueError), + ('projects/123456/locations/us-cent/ral1/models/ICN123456', ValueError), + ]) + def test_auto_ml_tflite_source_validation_errors(self, auto_ml_model, exc_type): + with pytest.raises(exc_type) as excinfo: + ml.TFLiteAutoMlSource(auto_ml_model=auto_ml_model) + check_error(excinfo, exc_type) + def test_wait_for_unlocked_not_locked(self): model = ml.Model(display_name="not_locked") model.wait_for_unlocked() From 378ec06878ee6e5cd45502b51139e15e96f8b7bb Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Tue, 15 Sep 2020 14:05:33 -0400 Subject: [PATCH 081/226] [chore] Release 4.4.0 (#490) --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index 298f3703e..de6a75223 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '4.3.0' +__version__ = '4.4.0' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 873aa7c10f32d1f07ce28dde6dc5414e7583afd3 Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Tue, 15 Sep 2020 16:37:03 -0400 Subject: [PATCH 082/226] [chore] Release 4.4.0 Take 2 --- integration/test_messaging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration/test_messaging.py b/integration/test_messaging.py index 001b96a0a..b5612b63d 100644 --- a/integration/test_messaging.py +++ b/integration/test_messaging.py @@ -75,7 +75,7 @@ def test_send_invalid_token(): token=_REGISTRATION_TOKEN, notification=messaging.Notification('test-title', 'test-body') ) - with pytest.raises(messaging.SenderIdMismatchError): + with pytest.raises(messaging.UnregisteredError): messaging.send(msg, dry_run=True) def test_send_malformed_token(): From 2b8cb45cec4df47c45b122fee80721e287f45a8a Mon Sep 17 00:00:00 2001 From: Mike Moore Date: Fri, 23 Oct 2020 15:20:30 -0600 Subject: [PATCH 083/226] Fixed typo in code comments. (#497) --- firebase_admin/credentials.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/credentials.py b/firebase_admin/credentials.py index 8f9c504f0..1f207e483 100644 --- a/firebase_admin/credentials.py +++ b/firebase_admin/credentials.py @@ -165,7 +165,7 @@ class RefreshToken(Base): def __init__(self, refresh_token): """Initializes a credential from a refresh token JSON file. - The JSON must consist of client_id, client_secert and refresh_token fields. Refresh + The JSON must consist of client_id, client_secret and refresh_token fields. Refresh token files are typically created and managed by the gcloud SDK. To instantiate a credential from a refresh token file, either specify the file path or a dict representing the parsed contents of the file. From 273b05801a30102fee4ad0249b17915b197201de Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Thu, 3 Dec 2020 17:43:57 -0800 Subject: [PATCH 084/226] fix(fcm): Converting unexpected gapic runtime errors to FirebaseError (#509) --- firebase_admin/messaging.py | 3 +-- tests/test_messaging.py | 53 ++++++++++++++++++++++++++++++++----- 2 files changed, 47 insertions(+), 9 deletions(-) diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index 217cf0a56..7c92a3d8d 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -16,7 +16,6 @@ import json -import googleapiclient from googleapiclient import http from googleapiclient import _auth import requests @@ -388,7 +387,7 @@ def batch_callback(_, response, error): try: batch.execute() - except googleapiclient.http.HttpError as error: + except Exception as error: raise self._handle_batch_error(error) else: return BatchResponse(responses) diff --git a/tests/test_messaging.py b/tests/test_messaging.py index 6333aad46..8eb24c0a9 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -1792,6 +1792,15 @@ def test_send_unknown_fcm_error_code(self, status): assert json.loads(recorder[0].body.decode()) == body +class _HttpMockException: + + def __init__(self, exc): + self._exc = exc + + def request(self, url, **kwargs): + raise self._exc + + class TestBatch: @classmethod @@ -1803,17 +1812,21 @@ def setup_class(cls): def teardown_class(cls): testutils.cleanup_apps() - def _instrument_batch_messaging_service(self, app=None, status=200, payload=''): + def _instrument_batch_messaging_service(self, app=None, status=200, payload='', exc=None): if not app: app = firebase_admin.get_app() + fcm_service = messaging._get_messaging_service(app) - if status == 200: - content_type = 'multipart/mixed; boundary=boundary' + if exc: + fcm_service._transport = _HttpMockException(exc) else: - content_type = 'application/json' - fcm_service._transport = http.HttpMockSequence([ - ({'status': str(status), 'content-type': content_type}, payload), - ]) + if status == 200: + content_type = 'multipart/mixed; boundary=boundary' + else: + content_type = 'application/json' + fcm_service._transport = http.HttpMockSequence([ + ({'status': str(status), 'content-type': content_type}, payload), + ]) return fcm_service def _batch_payload(self, payloads): @@ -2027,6 +2040,19 @@ def test_send_all_batch_fcm_error_code(self, status): messaging.send_all([msg]) check_exception(excinfo.value, 'test error', status) + def test_send_all_runtime_exception(self): + exc = BrokenPipeError('Test error') + _ = self._instrument_batch_messaging_service(exc=exc) + msg = messaging.Message(topic='foo') + + with pytest.raises(exceptions.UnknownError) as excinfo: + messaging.send_all([msg]) + + expected = 'Unknown error while making a remote service call: Test error' + assert str(excinfo.value) == expected + assert excinfo.value.cause is exc + assert excinfo.value.http_response is None + class TestSendMulticast(TestBatch): @@ -2204,6 +2230,19 @@ def test_send_multicast_batch_fcm_error_code(self, status): messaging.send_multicast(msg) check_exception(excinfo.value, 'test error', status) + def test_send_multicast_runtime_exception(self): + exc = BrokenPipeError('Test error') + _ = self._instrument_batch_messaging_service(exc=exc) + msg = messaging.MulticastMessage(tokens=['foo']) + + with pytest.raises(exceptions.UnknownError) as excinfo: + messaging.send_multicast(msg) + + expected = 'Unknown error while making a remote service call: Test error' + assert str(excinfo.value) == expected + assert excinfo.value.cause is exc + assert excinfo.value.http_response is None + class TestTopicManagement: From ba12bda0d550c33a2831b02ab31a47b32d06aad3 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Fri, 4 Dec 2020 10:49:18 -0800 Subject: [PATCH 085/226] change: Deprecated support for Python 3.5 (#511) --- README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7f33af68b..180bc2bff 100644 --- a/README.md +++ b/README.md @@ -43,8 +43,10 @@ requests, code review feedback, and also pull requests. ## Supported Python Versions -We currently support Python 3.5+. Firebase Admin Python SDK is also tested on -PyPy and [Google App Engine](https://cloud.google.com/appengine/) environments. +We currently support Python 3.5+. However, Python 3.5 support is deprecated, +and the developers are strongly advised to use Python 3.6 or higher. Firebase +Admin Python SDK is also tested on PyPy and +[Google App Engine](https://cloud.google.com/appengine/) environments. ## Documentation From eefc31b67bc8ad50a734a7bb0a52f56716e0e4d7 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Wed, 9 Dec 2020 11:20:29 -0800 Subject: [PATCH 086/226] [chore] Release 4.5.0 (#516) --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index de6a75223..428d85d06 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '4.4.0' +__version__ = '4.5.0' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 6d0023104413bfb02b8496e2469d5c363d6ab6db Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Tue, 15 Dec 2020 12:01:20 -0800 Subject: [PATCH 087/226] fix(rtdb): Support parsing non-US RTDB instance URLs (#517) * fix(rtdb): Support parsing non-US RTDB instance URLs * fix: Deferred credential loading until emulator URL is determined --- firebase_admin/db.py | 83 +++++++++++++++++++------------------------- tests/test_db.py | 70 ++++++++++++++++++++++++------------- 2 files changed, 81 insertions(+), 72 deletions(-) diff --git a/firebase_admin/db.py b/firebase_admin/db.py index be2b9c917..3384bd440 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -768,10 +768,10 @@ def __init__(self, app): self._credential = app.credential db_url = app.options.get('databaseURL') if db_url: - _DatabaseService._parse_db_url(db_url) # Just for validation. self._db_url = db_url else: self._db_url = None + auth_override = _DatabaseService._get_auth_override(app) if auth_override not in (self._DEFAULT_AUTH_OVERRIDE, {}): self._auth_override = json.dumps(auth_override, separators=(',', ':')) @@ -795,15 +795,29 @@ def get_client(self, db_url=None): if db_url is None: db_url = self._db_url - base_url, namespace = _DatabaseService._parse_db_url(db_url, self._emulator_host) - if base_url == 'https://{0}.firebaseio.com'.format(namespace): - # Production base_url. No need to specify namespace in query params. - params = {} - credential = self._credential.get_credential() - else: - # Emulator base_url. Use fake credentials and specify ?ns=foo in query params. + if not db_url or not isinstance(db_url, str): + raise ValueError( + 'Invalid database URL: "{0}". Database URL must be a non-empty ' + 'URL string.'.format(db_url)) + + parsed_url = parse.urlparse(db_url) + if not parsed_url.netloc: + raise ValueError( + 'Invalid database URL: "{0}". Database URL must be a wellformed ' + 'URL string.'.format(db_url)) + + emulator_config = self._get_emulator_config(parsed_url) + if emulator_config: credential = _EmulatorAdminCredentials() - params = {'ns': namespace} + base_url = emulator_config.base_url + params = {'ns': emulator_config.namespace} + else: + # Defer credential lookup until we are certain it's going to be prod connection. + credential = self._credential.get_credential() + base_url = 'https://{0}'.format(parsed_url.netloc) + params = {} + + if self._auth_override: params['auth_variable_override'] = self._auth_override @@ -813,47 +827,20 @@ def get_client(self, db_url=None): self._clients[client_cache_key] = client return self._clients[client_cache_key] - @classmethod - def _parse_db_url(cls, url, emulator_host=None): - """Parses (base_url, namespace) from a database URL. - - The input can be either a production URL (https://foo-bar.firebaseio.com/) - or an Emulator URL (http://localhost:8080/?ns=foo-bar). In case of Emulator - URL, the namespace is extracted from the query param ns. The resulting - base_url never includes query params. - - If url is a production URL and emulator_host is specified, the result - base URL will use emulator_host instead. emulator_host is ignored - if url is already an emulator URL. - """ - if not url or not isinstance(url, str): - raise ValueError( - 'Invalid database URL: "{0}". Database URL must be a non-empty ' - 'URL string.'.format(url)) - parsed_url = parse.urlparse(url) - if parsed_url.netloc.endswith('.firebaseio.com'): - return cls._parse_production_url(parsed_url, emulator_host) - - return cls._parse_emulator_url(parsed_url) - - @classmethod - def _parse_production_url(cls, parsed_url, emulator_host): - """Parses production URL like https://foo-bar.firebaseio.com/""" + def _get_emulator_config(self, parsed_url): + """Checks whether the SDK should connect to the RTDB emulator.""" + EmulatorConfig = collections.namedtuple('EmulatorConfig', ['base_url', 'namespace']) if parsed_url.scheme != 'https': - raise ValueError( - 'Invalid database URL scheme: "{0}". Database URL must be an HTTPS URL.'.format( - parsed_url.scheme)) - namespace = parsed_url.netloc.split('.')[0] - if not namespace: - raise ValueError( - 'Invalid database URL: "{0}". Database URL must be a valid URL to a ' - 'Firebase Realtime Database instance.'.format(parsed_url.geturl())) + # Emulator mode enabled by passing http URL via AppOptions + base_url, namespace = _DatabaseService._parse_emulator_url(parsed_url) + return EmulatorConfig(base_url, namespace) + if self._emulator_host: + # Emulator mode enabled via environment variable + base_url = 'http://{0}'.format(self._emulator_host) + namespace = parsed_url.netloc.split('.')[0] + return EmulatorConfig(base_url, namespace) - if emulator_host: - base_url = 'http://{0}'.format(emulator_host) - else: - base_url = 'https://{0}'.format(parsed_url.netloc) - return base_url, namespace + return None @classmethod def _parse_emulator_url(cls, parsed_url): diff --git a/tests/test_db.py b/tests/test_db.py index 2989fc030..5f8ba4b51 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -15,6 +15,7 @@ """Tests for firebase_admin.db.""" import collections import json +import os import sys import time @@ -28,6 +29,9 @@ from tests import testutils +_EMULATOR_HOST_ENV_VAR = 'FIREBASE_DATABASE_EMULATOR_HOST' + + class MockAdapter(testutils.MockAdapter): """A mock HTTP adapter that mimics RTDB server behavior.""" @@ -702,52 +706,70 @@ def test_no_db_url(self): 'url,emulator_host,expected_base_url,expected_namespace', [ # Production URLs with no override: - ('https://test.firebaseio.com', None, 'https://test.firebaseio.com', 'test'), - ('https://test.firebaseio.com/', None, 'https://test.firebaseio.com', 'test'), + ('https://test.firebaseio.com', None, 'https://test.firebaseio.com', None), + ('https://test.firebaseio.com/', None, 'https://test.firebaseio.com', None), # Production URLs with emulator_host override: ('https://test.firebaseio.com', 'localhost:9000', 'http://localhost:9000', 'test'), ('https://test.firebaseio.com/', 'localhost:9000', 'http://localhost:9000', 'test'), - # Emulator URLs with no override. + # Emulator URL with no override. ('http://localhost:8000/?ns=test', None, 'http://localhost:8000', 'test'), + # emulator_host is ignored when the original URL is already emulator. ('http://localhost:8000/?ns=test', 'localhost:9999', 'http://localhost:8000', 'test'), ] ) def test_parse_db_url(self, url, emulator_host, expected_base_url, expected_namespace): - base_url, namespace = db._DatabaseService._parse_db_url(url, emulator_host) - assert base_url == expected_base_url - assert namespace == expected_namespace - - @pytest.mark.parametrize('url,emulator_host', [ - ('', None), - (None, None), - (42, None), - ('test.firebaseio.com', None), # Not a URL. - ('http://test.firebaseio.com', None), # Use of non-HTTPs in production URLs. - ('ftp://test.firebaseio.com', None), # Use of non-HTTPs in production URLs. - ('https://example.com', None), # Invalid RTDB URL. - ('http://localhost:9000/', None), # No ns specified. - ('http://localhost:9000/?ns=', None), # No ns specified. - ('http://localhost:9000/?ns=test1&ns=test2', None), # Two ns parameters specified. - ('ftp://localhost:9000/?ns=test', None), # Neither HTTP or HTTPS. + if emulator_host: + os.environ[_EMULATOR_HOST_ENV_VAR] = emulator_host + + try: + firebase_admin.initialize_app(testutils.MockCredential(), {'databaseURL' : url}) + ref = db.reference() + assert ref._client._base_url == expected_base_url + assert ref._client.params.get('ns') == expected_namespace + if expected_base_url.startswith('http://localhost'): + assert isinstance(ref._client.credential, db._EmulatorAdminCredentials) + else: + assert isinstance(ref._client.credential, testutils.MockGoogleCredential) + finally: + if _EMULATOR_HOST_ENV_VAR in os.environ: + del os.environ[_EMULATOR_HOST_ENV_VAR] + + @pytest.mark.parametrize('url', [ + '', + None, + 42, + 'test.firebaseio.com', # Not a URL. + 'http://test.firebaseio.com', # Use of non-HTTPs in production URLs. + 'ftp://test.firebaseio.com', # Use of non-HTTPs in production URLs. + 'http://localhost:9000/', # No ns specified. + 'http://localhost:9000/?ns=', # No ns specified. + 'http://localhost:9000/?ns=test1&ns=test2', # Two ns parameters specified. + 'ftp://localhost:9000/?ns=test', # Neither HTTP or HTTPS. ]) - def test_parse_db_url_errors(self, url, emulator_host): + def test_parse_db_url_errors(self, url): + firebase_admin.initialize_app(testutils.MockCredential(), {'databaseURL' : url}) with pytest.raises(ValueError): - db._DatabaseService._parse_db_url(url, emulator_host) + db.reference() @pytest.mark.parametrize('url', [ - 'https://test.firebaseio.com', 'https://test.firebaseio.com/' + 'https://test.firebaseio.com', 'https://test.firebaseio.com/', + 'https://test.eu-west1.firebasdatabase.app', 'https://test.eu-west1.firebasdatabase.app/' ]) def test_valid_db_url(self, url): firebase_admin.initialize_app(testutils.MockCredential(), {'databaseURL' : url}) ref = db.reference() - assert ref._client.base_url == 'https://test.firebaseio.com' + expected_url = url + if url.endswith('/'): + expected_url = url[:-1] + assert ref._client.base_url == expected_url assert 'auth_variable_override' not in ref._client.params + assert 'ns' not in ref._client.params @pytest.mark.parametrize('url', [ - None, '', 'foo', 'http://test.firebaseio.com', 'https://google.com', + None, '', 'foo', 'http://test.firebaseio.com', 'http://test.firebasedatabase.app', True, False, 1, 0, dict(), list(), tuple(), _Object() ]) def test_invalid_db_url(self, url): From d18f42b578be8aa24025dce471c6be9ca457f939 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Thu, 17 Dec 2020 15:38:07 -0800 Subject: [PATCH 088/226] Adding delayed response message for holidays (#520) --- ISSUE_TEMPLATE.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ISSUE_TEMPLATE.md b/ISSUE_TEMPLATE.md index 5de83b2cc..b7dc143fa 100644 --- a/ISSUE_TEMPLATE.md +++ b/ISSUE_TEMPLATE.md @@ -1,3 +1,5 @@ +**Thank you for submitting your issue. We are operating at reduced capacity from Dec 18 2020 to Jan 4 2021. Please expect delayed responses. For more urgent requests please reach us via our support channels https://firebase.google.com/support** + ### [READ] Step 1: Are you in the right place? * For issues or feature requests related to __the code in this repository__ From 4849ca8b346f236ed814e9e6ecf8ddcd5f659f9d Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Tue, 5 Jan 2021 13:26:53 -0500 Subject: [PATCH 089/226] Remove delayed response message for holidays (#526) - Remove delayed response message for holidays --- ISSUE_TEMPLATE.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/ISSUE_TEMPLATE.md b/ISSUE_TEMPLATE.md index b7dc143fa..5de83b2cc 100644 --- a/ISSUE_TEMPLATE.md +++ b/ISSUE_TEMPLATE.md @@ -1,5 +1,3 @@ -**Thank you for submitting your issue. We are operating at reduced capacity from Dec 18 2020 to Jan 4 2021. Please expect delayed responses. For more urgent requests please reach us via our support channels https://firebase.google.com/support** - ### [READ] Step 1: Are you in the right place? * For issues or feature requests related to __the code in this repository__ From f6400ced6b8a1f92faec93ea3922efaec05378be Mon Sep 17 00:00:00 2001 From: Allen Thomas Varghese Date: Thu, 7 Jan 2021 20:55:07 +0000 Subject: [PATCH 090/226] Add Py3.9 support (#525) --- .github/workflows/ci.yml | 2 +- setup.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 61d3861bd..067cd8f18 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,7 +8,7 @@ jobs: strategy: fail-fast: false matrix: - python: [3.5, 3.6, 3.7, pypy3] + python: [3.5, 3.6, 3.7, 3.9, pypy3] steps: - uses: actions/checkout@v1 diff --git a/setup.py b/setup.py index 0ebcc3455..4f0d42365 100644 --- a/setup.py +++ b/setup.py @@ -64,6 +64,7 @@ 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.9', 'License :: OSI Approved :: Apache Software License', ], ) From b35abb9bb74ab7629ae1ff5c6f73772c6dd450c5 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Tue, 12 Jan 2021 11:56:22 -0800 Subject: [PATCH 091/226] [chore] Release 4.5.1 (#527) --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index 428d85d06..28979015d 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '4.5.0' +__version__ = '4.5.1' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 8c05981a8bbf7de3a185439f578d9a5e4e1f503c Mon Sep 17 00:00:00 2001 From: xmo-odoo Date: Wed, 10 Feb 2021 21:12:49 +0100 Subject: [PATCH 092/226] Remove use of method_whitelist when possible (#532) Deprecated in favor of allowed_methods. Fall back to the old argument for older versions of urllib3 which do not support the new one. Uses conditional attribute check as recommended in https://github.com/urllib3/urllib3/issues/2057 --- firebase_admin/_http_client.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/firebase_admin/_http_client.py b/firebase_admin/_http_client.py index f6f0d89fa..ae312095b 100644 --- a/firebase_admin/_http_client.py +++ b/firebase_admin/_http_client.py @@ -22,14 +22,16 @@ from requests.packages.urllib3.util import retry # pylint: disable=import-error -_ANY_METHOD = None - +if hasattr(retry.Retry.DEFAULT, 'allowed_methods'): + _ANY_METHOD = {'allowed_methods': None} +else: + _ANY_METHOD = {'method_whitelist': None} # Default retry configuration: Retries once on low-level connection and socket read errors. # Retries up to 4 times on HTTP 500 and 503 errors, with exponential backoff. Returns the # last response upon exhausting all retries. DEFAULT_RETRY_CONFIG = retry.Retry( - connect=1, read=1, status=4, status_forcelist=[500, 503], method_whitelist=_ANY_METHOD, - raise_on_status=False, backoff_factor=0.5) + connect=1, read=1, status=4, status_forcelist=[500, 503], + raise_on_status=False, backoff_factor=0.5, **_ANY_METHOD) DEFAULT_TIMEOUT_SECONDS = 120 From c10c7471eda060ef16a11e568e3e503316e67793 Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Thu, 18 Feb 2021 16:25:00 -0500 Subject: [PATCH 093/226] [chore] Release 4.5.2 (#533) - Release 4.5.2 --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index 28979015d..58ec6ddb8 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '4.5.1' +__version__ = '4.5.2' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 32e45f1535e39a0769e298d2ec3b3e87366ee322 Mon Sep 17 00:00:00 2001 From: Danielle Hanks <41087581+daniellehanks@users.noreply.github.com> Date: Mon, 15 Mar 2021 15:15:20 -0600 Subject: [PATCH 094/226] fix(auth): Make auth client respect app options httpTimeout (#536) --- firebase_admin/_auth_client.py | 3 ++- tests/test_user_mgt.py | 27 ++++++++++++++++++++++++++- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/firebase_admin/_auth_client.py b/firebase_admin/_auth_client.py index 1c9b37082..60be96811 100644 --- a/firebase_admin/_auth_client.py +++ b/firebase_admin/_auth_client.py @@ -38,8 +38,9 @@ def __init__(self, app, tenant_id=None): credential = app.credential.get_credential() version_header = 'Python/Admin/{0}'.format(firebase_admin.__version__) + timeout = app.options.get('httpTimeout', _http_client.DEFAULT_TIMEOUT_SECONDS) http_client = _http_client.JsonHttpClient( - credential=credential, headers={'X-Client-Version': version_header}) + credential=credential, headers={'X-Client-Version': version_header}, timeout=timeout) self._tenant_id = tenant_id self._token_generator = _token_gen.TokenGenerator(app, http_client) diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index 79e23373f..240f19bdc 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -52,6 +52,8 @@ USER_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v1/projects/mock-project-id' +TEST_TIMEOUT = 42 + @pytest.fixture(scope='module') def user_mgt_app(): @@ -60,6 +62,16 @@ def user_mgt_app(): yield app firebase_admin.delete_app(app) +@pytest.fixture(scope='module') +def user_mgt_app_with_timeout(): + app = firebase_admin.initialize_app( + testutils.MockCredential(), + name='userMgtTimeout', + options={'projectId': 'mock-project-id', 'httpTimeout': TEST_TIMEOUT} + ) + yield app + firebase_admin.delete_app(app) + def _instrument_user_manager(app, status, payload): client = auth._get_client(app) user_manager = client._user_manager @@ -105,7 +117,7 @@ def _check_user_record(user, expected_uid='testuser'): assert provider.provider_id == 'phone' -def _check_request(recorder, want_url, want_body=None): +def _check_request(recorder, want_url, want_body=None, want_timeout=None): assert len(recorder) == 1 req = recorder[0] assert req.method == 'POST' @@ -113,6 +125,8 @@ def _check_request(recorder, want_url, want_body=None): if want_body: body = json.loads(req.body.decode()) assert body == want_body + if want_timeout: + assert recorder[0]._extra_kwargs['timeout'] == pytest.approx(want_timeout, 0.001) class TestAuthServiceInitialization: @@ -122,6 +136,11 @@ def test_default_timeout(self, user_mgt_app): user_manager = client._user_manager assert user_manager.http_client.timeout == _http_client.DEFAULT_TIMEOUT_SECONDS + def test_app_options_timeout(self, user_mgt_app_with_timeout): + client = auth._get_client(user_mgt_app_with_timeout) + user_manager = client._user_manager + assert user_manager.http_client.timeout == TEST_TIMEOUT + def test_fail_on_no_project_id(self): app = firebase_admin.initialize_app(testutils.MockCredential(), name='userMgt2') with pytest.raises(ValueError): @@ -225,6 +244,12 @@ def test_get_user(self, user_mgt_app): _check_user_record(auth.get_user('testuser', user_mgt_app)) _check_request(recorder, '/accounts:lookup', {'localId': ['testuser']}) + def test_get_user_with_timeout(self, user_mgt_app_with_timeout): + _, recorder = _instrument_user_manager( + user_mgt_app_with_timeout, 200, MOCK_GET_USER_RESPONSE) + _check_user_record(auth.get_user('testuser', user_mgt_app_with_timeout)) + _check_request(recorder, '/accounts:lookup', {'localId': ['testuser']}, TEST_TIMEOUT) + @pytest.mark.parametrize('arg', INVALID_STRINGS + ['not-an-email']) def test_invalid_get_user_by_email(self, arg, user_mgt_app): with pytest.raises(ValueError): From 3bdb182554ad188db9d7d66cfc136e999bf593b0 Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Thu, 18 Mar 2021 18:37:17 -0400 Subject: [PATCH 095/226] [chore] Release 4.5.3 (#537) - Release 4.5.3 --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index 58ec6ddb8..c4665e933 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '4.5.2' +__version__ = '4.5.3' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 04c406978948a076db6198b5f9a1fa96a9addc2e Mon Sep 17 00:00:00 2001 From: Murukesh Mohanan Date: Wed, 24 Mar 2021 03:10:41 +0900 Subject: [PATCH 096/226] feat(auth): Add auth emulator support via the FIREBASE_AUTH_EMULATOR_HOST environment variable. (#531) * Support auth emulator via FIREBASE_AUTH_EMULATOR_HOST Modeled on https://github.com/firebase/firebase-admin-go/pull/414 * Tests for emulator support in auth, user mgmt and token gen To minimize modification of tests, the app fixture and instrumentation have been modified to use a global dict of URLs, which are then monkey-patched based on fixture parameters. Essentially, all tests using the app fixture are run twice, once with the emulated endpoint and once without. * fallback for monkeypatch in python 3.5 * Token verification for the auth emulator * Accommodate auth emulator behaviour in tests. Where possible, tests are modified to account for the current behaviour in emulator mode (e.g., invalid or expired tokens or cookies still work). Fixtures were changed to function scope to avoid problems caused by overlap when some fixtures being in emulator mode and some in normal mode concurrently. --- firebase_admin/_auth_client.py | 28 +++++- firebase_admin/_auth_providers.py | 5 +- firebase_admin/_auth_utils.py | 15 ++++ firebase_admin/_token_gen.py | 43 +++++++-- firebase_admin/_user_mgt.py | 5 +- firebase_admin/_utils.py | 18 ++++ firebase_admin/db.py | 14 +-- tests/test_auth_providers.py | 79 ++++++++++------- tests/test_db.py | 3 +- tests/test_token_gen.py | 143 +++++++++++++++++++++++------- tests/test_user_mgt.py | 27 ++++-- tests/testutils.py | 11 +++ 12 files changed, 289 insertions(+), 102 deletions(-) diff --git a/firebase_admin/_auth_client.py b/firebase_admin/_auth_client.py index 60be96811..2f6713d41 100644 --- a/firebase_admin/_auth_client.py +++ b/firebase_admin/_auth_client.py @@ -24,6 +24,7 @@ from firebase_admin import _user_identifier from firebase_admin import _user_import from firebase_admin import _user_mgt +from firebase_admin import _utils class Client: @@ -36,18 +37,37 @@ def __init__(self, app, tenant_id=None): 2. set the project ID explicitly via Firebase App options, or 3. set the project ID via the GOOGLE_CLOUD_PROJECT environment variable.""") - credential = app.credential.get_credential() + credential = None version_header = 'Python/Admin/{0}'.format(firebase_admin.__version__) timeout = app.options.get('httpTimeout', _http_client.DEFAULT_TIMEOUT_SECONDS) + # Non-default endpoint URLs for emulator support are set in this dict later. + endpoint_urls = {} + self.emulated = False + + # If an emulator is present, check that the given value matches the expected format and set + # endpoint URLs to use the emulator. Additionally, use a fake credential. + emulator_host = _auth_utils.get_emulator_host() + if emulator_host: + base_url = 'http://{0}/identitytoolkit.googleapis.com'.format(emulator_host) + endpoint_urls['v1'] = base_url + '/v1' + endpoint_urls['v2beta1'] = base_url + '/v2beta1' + credential = _utils.EmulatorAdminCredentials() + self.emulated = True + else: + # Use credentials if provided + credential = app.credential.get_credential() + http_client = _http_client.JsonHttpClient( credential=credential, headers={'X-Client-Version': version_header}, timeout=timeout) self._tenant_id = tenant_id - self._token_generator = _token_gen.TokenGenerator(app, http_client) + self._token_generator = _token_gen.TokenGenerator( + app, http_client, url_override=endpoint_urls.get('v1')) self._token_verifier = _token_gen.TokenVerifier(app) - self._user_manager = _user_mgt.UserManager(http_client, app.project_id, tenant_id) + self._user_manager = _user_mgt.UserManager( + http_client, app.project_id, tenant_id, url_override=endpoint_urls.get('v1')) self._provider_manager = _auth_providers.ProviderConfigClient( - http_client, app.project_id, tenant_id) + http_client, app.project_id, tenant_id, url_override=endpoint_urls.get('v2beta1')) @property def tenant_id(self): diff --git a/firebase_admin/_auth_providers.py b/firebase_admin/_auth_providers.py index 46de6fe5f..5126c862c 100644 --- a/firebase_admin/_auth_providers.py +++ b/firebase_admin/_auth_providers.py @@ -166,9 +166,10 @@ class ProviderConfigClient: PROVIDER_CONFIG_URL = 'https://identitytoolkit.googleapis.com/v2beta1' - def __init__(self, http_client, project_id, tenant_id=None): + def __init__(self, http_client, project_id, tenant_id=None, url_override=None): self.http_client = http_client - self.base_url = '{0}/projects/{1}'.format(self.PROVIDER_CONFIG_URL, project_id) + url_prefix = url_override or self.PROVIDER_CONFIG_URL + self.base_url = '{0}/projects/{1}'.format(url_prefix, project_id) if tenant_id: self.base_url += '/tenants/{0}'.format(tenant_id) diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index 2226675f9..d8e49b1a1 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -15,6 +15,7 @@ """Firebase auth utils.""" import json +import os import re from urllib import parse @@ -22,6 +23,7 @@ from firebase_admin import _utils +EMULATOR_HOST_ENV_VAR = 'FIREBASE_AUTH_EMULATOR_HOST' MAX_CLAIMS_PAYLOAD_SIZE = 1000 RESERVED_CLAIMS = set([ 'acr', 'amr', 'at_hash', 'aud', 'auth_time', 'azp', 'cnf', 'c_hash', 'exp', 'iat', @@ -66,6 +68,19 @@ def __iter__(self): return self +def get_emulator_host(): + emulator_host = os.getenv(EMULATOR_HOST_ENV_VAR, '') + if emulator_host and '//' in emulator_host: + raise ValueError( + 'Invalid {0}: "{1}". It must follow format "host:port".'.format( + EMULATOR_HOST_ENV_VAR, emulator_host)) + return emulator_host + + +def is_emulated(): + return get_emulator_host() != '' + + def validate_uid(uid, required=False): if uid is None and not required: return None diff --git a/firebase_admin/_token_gen.py b/firebase_admin/_token_gen.py index 18a8008c7..562e77fa5 100644 --- a/firebase_admin/_token_gen.py +++ b/firebase_admin/_token_gen.py @@ -53,6 +53,19 @@ METADATA_SERVICE_URL = ('http://metadata.google.internal/computeMetadata/v1/instance/' 'service-accounts/default/email') +# Emulator fake account +AUTH_EMULATOR_EMAIL = 'firebase-auth-emulator@example.com' + + +class _EmulatedSigner(google.auth.crypt.Signer): + key_id = None + + def __init__(self): + pass + + def sign(self, message): + return b'' + class _SigningProvider: """Stores a reference to a google.auth.crypto.Signer.""" @@ -78,21 +91,28 @@ def from_iam(cls, request, google_cred, service_account): signer = iam.Signer(request, google_cred, service_account) return _SigningProvider(signer, service_account) + @classmethod + def for_emulator(cls): + return _SigningProvider(_EmulatedSigner(), AUTH_EMULATOR_EMAIL) + class TokenGenerator: """Generates custom tokens and session cookies.""" ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v1' - def __init__(self, app, http_client): + def __init__(self, app, http_client, url_override=None): self.app = app self.http_client = http_client self.request = transport.requests.Request() - self.base_url = '{0}/projects/{1}'.format(self.ID_TOOLKIT_URL, app.project_id) + url_prefix = url_override or self.ID_TOOLKIT_URL + self.base_url = '{0}/projects/{1}'.format(url_prefix, app.project_id) self._signing_provider = None def _init_signing_provider(self): """Initializes a signing provider by following the go/firebase-admin-sign protocol.""" + if _auth_utils.is_emulated(): + return _SigningProvider.for_emulator() # If the SDK was initialized with a service account, use it to sign bytes. google_cred = self.app.credential.get_credential() if isinstance(google_cred, google.oauth2.service_account.Credentials): @@ -285,12 +305,14 @@ def verify(self, token, request): verify_id_token_msg = ( 'See {0} for details on how to retrieve {1}.'.format(self.url, self.short_name)) + emulated = _auth_utils.is_emulated() + error_message = None if audience == FIREBASE_AUDIENCE: error_message = ( '{0} expects {1}, but was given a custom ' 'token.'.format(self.operation, self.articled_short_name)) - elif not header.get('kid'): + elif not emulated and not header.get('kid'): if header.get('alg') == 'HS256' and payload.get( 'v') == 0 and 'uid' in payload.get('d', {}): error_message = ( @@ -298,7 +320,7 @@ def verify(self, token, request): 'token.'.format(self.operation, self.articled_short_name)) else: error_message = 'Firebase {0} has no "kid" claim.'.format(self.short_name) - elif header.get('alg') != 'RS256': + elif not emulated and header.get('alg') != 'RS256': error_message = ( 'Firebase {0} has incorrect algorithm. Expected "RS256" but got ' '"{1}". {2}'.format(self.short_name, header.get('alg'), verify_id_token_msg)) @@ -329,11 +351,14 @@ def verify(self, token, request): raise self._invalid_token_error(error_message) try: - verified_claims = google.oauth2.id_token.verify_token( - token, - request=request, - audience=self.project_id, - certs_url=self.cert_url) + if emulated: + verified_claims = payload + else: + verified_claims = google.oauth2.id_token.verify_token( + token, + request=request, + audience=self.project_id, + certs_url=self.cert_url) verified_claims['uid'] = verified_claims['sub'] return verified_claims except google.auth.exceptions.TransportError as error: diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index 1d97dd504..b60c4d100 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -573,9 +573,10 @@ class UserManager: ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v1' - def __init__(self, http_client, project_id, tenant_id=None): + def __init__(self, http_client, project_id, tenant_id=None, url_override=None): self.http_client = http_client - self.base_url = '{0}/projects/{1}'.format(self.ID_TOOLKIT_URL, project_id) + url_prefix = url_override or self.ID_TOOLKIT_URL + self.base_url = '{0}/projects/{1}'.format(url_prefix, project_id) if tenant_id: self.base_url += '/tenants/{0}'.format(tenant_id) diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index a5fc8d022..8c640276c 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -18,6 +18,7 @@ import json import socket +import google.auth import googleapiclient import httplib2 import requests @@ -339,3 +340,20 @@ def _parse_platform_error(content, status_code): if not msg: msg = 'Unexpected HTTP response with status: {0}; body: {1}'.format(status_code, content) return error_dict, msg + + +# Temporarily disable the lint rule. For more information see: +# https://github.com/googleapis/google-auth-library-python/pull/561 +# pylint: disable=abstract-method +class EmulatorAdminCredentials(google.auth.credentials.Credentials): + """ Credentials for use with the firebase local emulator. + + This is used instead of user-supplied credentials or ADC. It will silently do nothing when + asked to refresh credentials. + """ + def __init__(self): + google.auth.credentials.Credentials.__init__(self) + self.token = 'owner' + + def refresh(self, request): + pass diff --git a/firebase_admin/db.py b/firebase_admin/db.py index 3384bd440..1d293bb89 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -27,7 +27,6 @@ import threading from urllib import parse -import google.auth import requests import firebase_admin @@ -808,7 +807,7 @@ def get_client(self, db_url=None): emulator_config = self._get_emulator_config(parsed_url) if emulator_config: - credential = _EmulatorAdminCredentials() + credential = _utils.EmulatorAdminCredentials() base_url = emulator_config.base_url params = {'ns': emulator_config.namespace} else: @@ -965,14 +964,3 @@ def _extract_error_message(cls, response): message = 'Unexpected response from database: {0}'.format(response.content.decode()) return message - -# Temporarily disable the lint rule. For more information see: -# https://github.com/googleapis/google-auth-library-python/pull/561 -# pylint: disable=abstract-method -class _EmulatorAdminCredentials(google.auth.credentials.Credentials): - def __init__(self): - google.auth.credentials.Credentials.__init__(self) - self.token = 'owner' - - def refresh(self, request): - pass diff --git a/tests/test_auth_providers.py b/tests/test_auth_providers.py index 124aea3cc..0947c77ae 100644 --- a/tests/test_auth_providers.py +++ b/tests/test_auth_providers.py @@ -21,10 +21,18 @@ import firebase_admin from firebase_admin import auth from firebase_admin import exceptions -from firebase_admin import _auth_providers from tests import testutils -USER_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v2beta1/projects/mock-project-id' +ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v2beta1' +EMULATOR_HOST_ENV_VAR = 'FIREBASE_AUTH_EMULATOR_HOST' +AUTH_EMULATOR_HOST = 'localhost:9099' +EMULATED_ID_TOOLKIT_URL = 'http://{}/identitytoolkit.googleapis.com/v2beta1'.format( + AUTH_EMULATOR_HOST) +URL_PROJECT_SUFFIX = '/projects/mock-project-id' +USER_MGT_URLS = { + 'ID_TOOLKIT': ID_TOOLKIT_URL, + 'PREFIX': ID_TOOLKIT_URL + URL_PROJECT_SUFFIX, +} OIDC_PROVIDER_CONFIG_RESPONSE = testutils.resource('oidc_provider_config.json') SAML_PROVIDER_CONFIG_RESPONSE = testutils.resource('saml_provider_config.json') LIST_OIDC_PROVIDER_CONFIGS_RESPONSE = testutils.resource('list_oidc_provider_configs.json') @@ -39,12 +47,18 @@ INVALID_PROVIDER_IDS = [None, True, False, 1, 0, list(), tuple(), dict(), ''] -@pytest.fixture(scope='module') -def user_mgt_app(): +@pytest.fixture(scope='module', params=[{'emulated': False}, {'emulated': True}]) +def user_mgt_app(request): + monkeypatch = testutils.new_monkeypatch() + if request.param['emulated']: + monkeypatch.setenv(EMULATOR_HOST_ENV_VAR, AUTH_EMULATOR_HOST) + monkeypatch.setitem(USER_MGT_URLS, 'ID_TOOLKIT', EMULATED_ID_TOOLKIT_URL) + monkeypatch.setitem(USER_MGT_URLS, 'PREFIX', EMULATED_ID_TOOLKIT_URL + URL_PROJECT_SUFFIX) app = firebase_admin.initialize_app(testutils.MockCredential(), name='providerConfig', options={'projectId': 'mock-project-id'}) yield app firebase_admin.delete_app(app) + monkeypatch.undo() def _instrument_provider_mgt(app, status, payload): @@ -52,7 +66,7 @@ def _instrument_provider_mgt(app, status, payload): provider_manager = client._provider_manager recorder = [] provider_manager.http_client.session.mount( - _auth_providers.ProviderConfigClient.PROVIDER_CONFIG_URL, + USER_MGT_URLS['ID_TOOLKIT'], testutils.MockAdapter(payload, status, recorder)) return recorder @@ -90,7 +104,7 @@ def test_get(self, user_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'GET' - assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, '/oauthIdpConfigs/oidc.provider') + assert req.url == '{0}{1}'.format(USER_MGT_URLS['PREFIX'], '/oauthIdpConfigs/oidc.provider') @pytest.mark.parametrize('invalid_opts', [ {'provider_id': None}, {'provider_id': ''}, {'provider_id': 'saml.provider'}, @@ -116,7 +130,7 @@ def test_create(self, user_mgt_app): req = recorder[0] assert req.method == 'POST' assert req.url == '{0}/oauthIdpConfigs?oauthIdpConfigId=oidc.provider'.format( - USER_MGT_URL_PREFIX) + USER_MGT_URLS['PREFIX']) got = json.loads(req.body.decode()) assert got == self.OIDC_CONFIG_REQUEST @@ -136,7 +150,7 @@ def test_create_minimal(self, user_mgt_app): req = recorder[0] assert req.method == 'POST' assert req.url == '{0}/oauthIdpConfigs?oauthIdpConfigId=oidc.provider'.format( - USER_MGT_URL_PREFIX) + USER_MGT_URLS['PREFIX']) got = json.loads(req.body.decode()) assert got == want @@ -156,7 +170,7 @@ def test_create_empty_values(self, user_mgt_app): req = recorder[0] assert req.method == 'POST' assert req.url == '{0}/oauthIdpConfigs?oauthIdpConfigId=oidc.provider'.format( - USER_MGT_URL_PREFIX) + USER_MGT_URLS['PREFIX']) got = json.loads(req.body.decode()) assert got == want @@ -186,7 +200,7 @@ def test_update(self, user_mgt_app): assert req.method == 'PATCH' mask = ['clientId', 'displayName', 'enabled', 'issuer'] assert req.url == '{0}/oauthIdpConfigs/oidc.provider?updateMask={1}'.format( - USER_MGT_URL_PREFIX, ','.join(mask)) + USER_MGT_URLS['PREFIX'], ','.join(mask)) got = json.loads(req.body.decode()) assert got == self.OIDC_CONFIG_REQUEST @@ -201,7 +215,7 @@ def test_update_minimal(self, user_mgt_app): req = recorder[0] assert req.method == 'PATCH' assert req.url == '{0}/oauthIdpConfigs/oidc.provider?updateMask=displayName'.format( - USER_MGT_URL_PREFIX) + USER_MGT_URLS['PREFIX']) got = json.loads(req.body.decode()) assert got == {'displayName': 'oidcProviderName'} @@ -217,7 +231,7 @@ def test_update_empty_values(self, user_mgt_app): assert req.method == 'PATCH' mask = ['displayName', 'enabled'] assert req.url == '{0}/oauthIdpConfigs/oidc.provider?updateMask={1}'.format( - USER_MGT_URL_PREFIX, ','.join(mask)) + USER_MGT_URLS['PREFIX'], ','.join(mask)) got = json.loads(req.body.decode()) assert got == {'displayName': None, 'enabled': False} @@ -236,7 +250,7 @@ def test_delete(self, user_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'DELETE' - assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, '/oauthIdpConfigs/oidc.provider') + assert req.url == '{0}{1}'.format(USER_MGT_URLS['PREFIX'], '/oauthIdpConfigs/oidc.provider') @pytest.mark.parametrize('arg', [None, 'foo', list(), dict(), 0, -1, 101, False]) def test_invalid_max_results(self, user_mgt_app, arg): @@ -259,7 +273,7 @@ def test_list_single_page(self, user_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'GET' - assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, '/oauthIdpConfigs?pageSize=100') + assert req.url == '{0}{1}'.format(USER_MGT_URLS['PREFIX'], '/oauthIdpConfigs?pageSize=100') def test_list_multiple_pages(self, user_mgt_app): sample_response = json.loads(OIDC_PROVIDER_CONFIG_RESPONSE) @@ -277,7 +291,7 @@ def test_list_multiple_pages(self, user_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'GET' - assert req.url == '{0}/oauthIdpConfigs?pageSize=10'.format(USER_MGT_URL_PREFIX) + assert req.url == '{0}/oauthIdpConfigs?pageSize=10'.format(USER_MGT_URLS['PREFIX']) # Page 2 (also the last page) response = {'oauthIdpConfigs': configs[2:]} @@ -289,7 +303,7 @@ def test_list_multiple_pages(self, user_mgt_app): req = recorder[0] assert req.method == 'GET' assert req.url == '{0}/oauthIdpConfigs?pageSize=10&pageToken=token'.format( - USER_MGT_URL_PREFIX) + USER_MGT_URLS['PREFIX']) def test_paged_iteration(self, user_mgt_app): sample_response = json.loads(OIDC_PROVIDER_CONFIG_RESPONSE) @@ -310,7 +324,7 @@ def test_paged_iteration(self, user_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'GET' - assert req.url == '{0}/oauthIdpConfigs?pageSize=100'.format(USER_MGT_URL_PREFIX) + assert req.url == '{0}/oauthIdpConfigs?pageSize=100'.format(USER_MGT_URLS['PREFIX']) # Page 2 (also the last page) response = {'oauthIdpConfigs': configs[2:]} @@ -322,7 +336,7 @@ def test_paged_iteration(self, user_mgt_app): req = recorder[0] assert req.method == 'GET' assert req.url == '{0}/oauthIdpConfigs?pageSize=100&pageToken=token'.format( - USER_MGT_URL_PREFIX) + USER_MGT_URLS['PREFIX']) with pytest.raises(StopIteration): next(iterator) @@ -421,7 +435,8 @@ def test_get(self, user_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'GET' - assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, '/inboundSamlConfigs/saml.provider') + assert req.url == '{0}{1}'.format(USER_MGT_URLS['PREFIX'], + '/inboundSamlConfigs/saml.provider') @pytest.mark.parametrize('invalid_opts', [ {'provider_id': None}, {'provider_id': ''}, {'provider_id': 'oidc.provider'}, @@ -451,7 +466,7 @@ def test_create(self, user_mgt_app): req = recorder[0] assert req.method == 'POST' assert req.url == '{0}/inboundSamlConfigs?inboundSamlConfigId=saml.provider'.format( - USER_MGT_URL_PREFIX) + USER_MGT_URLS['PREFIX']) got = json.loads(req.body.decode()) assert got == self.SAML_CONFIG_REQUEST @@ -471,7 +486,7 @@ def test_create_minimal(self, user_mgt_app): req = recorder[0] assert req.method == 'POST' assert req.url == '{0}/inboundSamlConfigs?inboundSamlConfigId=saml.provider'.format( - USER_MGT_URL_PREFIX) + USER_MGT_URLS['PREFIX']) got = json.loads(req.body.decode()) assert got == want @@ -491,7 +506,7 @@ def test_create_empty_values(self, user_mgt_app): req = recorder[0] assert req.method == 'POST' assert req.url == '{0}/inboundSamlConfigs?inboundSamlConfigId=saml.provider'.format( - USER_MGT_URL_PREFIX) + USER_MGT_URLS['PREFIX']) got = json.loads(req.body.decode()) assert got == want @@ -528,7 +543,7 @@ def test_update(self, user_mgt_app): 'idpConfig.ssoUrl', 'spConfig.callbackUri', 'spConfig.spEntityId', ] assert req.url == '{0}/inboundSamlConfigs/saml.provider?updateMask={1}'.format( - USER_MGT_URL_PREFIX, ','.join(mask)) + USER_MGT_URLS['PREFIX'], ','.join(mask)) got = json.loads(req.body.decode()) assert got == self.SAML_CONFIG_REQUEST @@ -543,7 +558,7 @@ def test_update_minimal(self, user_mgt_app): req = recorder[0] assert req.method == 'PATCH' assert req.url == '{0}/inboundSamlConfigs/saml.provider?updateMask=displayName'.format( - USER_MGT_URL_PREFIX) + USER_MGT_URLS['PREFIX']) got = json.loads(req.body.decode()) assert got == {'displayName': 'samlProviderName'} @@ -559,7 +574,7 @@ def test_update_empty_values(self, user_mgt_app): assert req.method == 'PATCH' mask = ['displayName', 'enabled'] assert req.url == '{0}/inboundSamlConfigs/saml.provider?updateMask={1}'.format( - USER_MGT_URL_PREFIX, ','.join(mask)) + USER_MGT_URLS['PREFIX'], ','.join(mask)) got = json.loads(req.body.decode()) assert got == {'displayName': None, 'enabled': False} @@ -578,7 +593,8 @@ def test_delete(self, user_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'DELETE' - assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, '/inboundSamlConfigs/saml.provider') + assert req.url == '{0}{1}'.format(USER_MGT_URLS['PREFIX'], + '/inboundSamlConfigs/saml.provider') def test_config_not_found(self, user_mgt_app): _instrument_provider_mgt(user_mgt_app, 500, CONFIG_NOT_FOUND_RESPONSE) @@ -613,7 +629,8 @@ def test_list_single_page(self, user_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'GET' - assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, '/inboundSamlConfigs?pageSize=100') + assert req.url == '{0}{1}'.format(USER_MGT_URLS['PREFIX'], + '/inboundSamlConfigs?pageSize=100') def test_list_multiple_pages(self, user_mgt_app): sample_response = json.loads(SAML_PROVIDER_CONFIG_RESPONSE) @@ -631,7 +648,7 @@ def test_list_multiple_pages(self, user_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'GET' - assert req.url == '{0}/inboundSamlConfigs?pageSize=10'.format(USER_MGT_URL_PREFIX) + assert req.url == '{0}/inboundSamlConfigs?pageSize=10'.format(USER_MGT_URLS['PREFIX']) # Page 2 (also the last page) response = {'inboundSamlConfigs': configs[2:]} @@ -643,7 +660,7 @@ def test_list_multiple_pages(self, user_mgt_app): req = recorder[0] assert req.method == 'GET' assert req.url == '{0}/inboundSamlConfigs?pageSize=10&pageToken=token'.format( - USER_MGT_URL_PREFIX) + USER_MGT_URLS['PREFIX']) def test_paged_iteration(self, user_mgt_app): sample_response = json.loads(SAML_PROVIDER_CONFIG_RESPONSE) @@ -664,7 +681,7 @@ def test_paged_iteration(self, user_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'GET' - assert req.url == '{0}/inboundSamlConfigs?pageSize=100'.format(USER_MGT_URL_PREFIX) + assert req.url == '{0}/inboundSamlConfigs?pageSize=100'.format(USER_MGT_URLS['PREFIX']) # Page 2 (also the last page) response = {'inboundSamlConfigs': configs[2:]} @@ -676,7 +693,7 @@ def test_paged_iteration(self, user_mgt_app): req = recorder[0] assert req.method == 'GET' assert req.url == '{0}/inboundSamlConfigs?pageSize=100&pageToken=token'.format( - USER_MGT_URL_PREFIX) + USER_MGT_URLS['PREFIX']) with pytest.raises(StopIteration): next(iterator) diff --git a/tests/test_db.py b/tests/test_db.py index 5f8ba4b51..aa2c83bd9 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -26,6 +26,7 @@ from firebase_admin import exceptions from firebase_admin import _http_client from firebase_admin import _sseclient +from firebase_admin import _utils from tests import testutils @@ -730,7 +731,7 @@ def test_parse_db_url(self, url, emulator_host, expected_base_url, expected_name assert ref._client._base_url == expected_base_url assert ref._client.params.get('ns') == expected_namespace if expected_base_url.startswith('http://localhost'): - assert isinstance(ref._client.credential, db._EmulatorAdminCredentials) + assert isinstance(ref._client.credential, _utils.EmulatorAdminCredentials) else: assert isinstance(ref._client.credential, testutils.MockGoogleCredential) finally: diff --git a/tests/test_token_gen.py b/tests/test_token_gen.py index f88c87ff4..29c70da80 100644 --- a/tests/test_token_gen.py +++ b/tests/test_token_gen.py @@ -55,6 +55,14 @@ 'NonEmptyDictToken': {'a': 1}, } +ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v1' +EMULATOR_HOST_ENV_VAR = 'FIREBASE_AUTH_EMULATOR_HOST' +AUTH_EMULATOR_HOST = 'localhost:9099' +EMULATED_ID_TOOLKIT_URL = 'http://{}/identitytoolkit.googleapis.com/v1'.format(AUTH_EMULATOR_HOST) +TOKEN_MGT_URLS = { + 'ID_TOOLKIT': ID_TOOLKIT_URL, +} + # Fixture for mocking a HTTP server httpserver = plugin.httpserver @@ -68,13 +76,18 @@ def _merge_jwt_claims(defaults, overrides): def verify_custom_token(custom_token, expected_claims, tenant_id=None): assert isinstance(custom_token, bytes) - token = google.oauth2.id_token.verify_token( - custom_token, - testutils.MockRequest(200, MOCK_PUBLIC_CERTS), - _token_gen.FIREBASE_AUDIENCE) + expected_email = MOCK_SERVICE_ACCOUNT_EMAIL + if _is_emulated(): + expected_email = _token_gen.AUTH_EMULATOR_EMAIL + token = jwt.decode(custom_token, verify=False) + else: + token = google.oauth2.id_token.verify_token( + custom_token, + testutils.MockRequest(200, MOCK_PUBLIC_CERTS), + _token_gen.FIREBASE_AUDIENCE) assert token['uid'] == MOCK_UID - assert token['iss'] == MOCK_SERVICE_ACCOUNT_EMAIL - assert token['sub'] == MOCK_SERVICE_ACCOUNT_EMAIL + assert token['iss'] == expected_email + assert token['sub'] == expected_email if tenant_id is None: assert 'tenant_id' not in token else: @@ -121,7 +134,7 @@ def _instrument_user_manager(app, status, payload): user_manager = client._user_manager recorder = [] user_manager.http_client.session.mount( - _token_gen.TokenGenerator.ID_TOOLKIT_URL, + TOKEN_MGT_URLS['ID_TOOLKIT'], testutils.MockAdapter(payload, status, recorder)) return user_manager, recorder @@ -133,23 +146,41 @@ def _overwrite_iam_request(app, request): client = auth._get_client(app) client._token_generator.request = request -@pytest.fixture(scope='module') -def auth_app(): + +def _is_emulated(): + emulator_host = os.getenv(EMULATOR_HOST_ENV_VAR, '') + return emulator_host and '//' not in emulator_host + + +# These fixtures are set to the default function scope as the emulator environment variable bleeds +# over when in module scope. +@pytest.fixture(params=[{'emulated': False}, {'emulated': True}]) +def auth_app(request): """Returns an App initialized with a mock service account credential. This can be used in any scenario where the private key is required. Use user_mgt_app for everything else. """ + monkeypatch = testutils.new_monkeypatch() + if request.param['emulated']: + monkeypatch.setenv(EMULATOR_HOST_ENV_VAR, AUTH_EMULATOR_HOST) + monkeypatch.setitem(TOKEN_MGT_URLS, 'ID_TOOLKIT', EMULATED_ID_TOOLKIT_URL) app = firebase_admin.initialize_app(MOCK_CREDENTIAL, name='tokenGen') yield app firebase_admin.delete_app(app) - -@pytest.fixture(scope='module') -def user_mgt_app(): + monkeypatch.undo() + +@pytest.fixture(params=[{'emulated': False}, {'emulated': True}]) +def user_mgt_app(request): + monkeypatch = testutils.new_monkeypatch() + if request.param['emulated']: + monkeypatch.setenv(EMULATOR_HOST_ENV_VAR, AUTH_EMULATOR_HOST) + monkeypatch.setitem(TOKEN_MGT_URLS, 'ID_TOOLKIT', EMULATED_ID_TOOLKIT_URL) app = firebase_admin.initialize_app(testutils.MockCredential(), name='userMgt', options={'projectId': 'mock-project-id'}) yield app firebase_admin.delete_app(app) + monkeypatch.undo() @pytest.fixture def env_var_app(request): @@ -212,6 +243,12 @@ def test_invalid_params(self, auth_app, values): auth.create_custom_token(user, claims, app=auth_app) def test_noncert_credential(self, user_mgt_app): + if _is_emulated(): + # Should work fine with the emulator, so do a condensed version of + # test_sign_with_iam below. + custom_token = auth.create_custom_token(MOCK_UID, app=user_mgt_app).decode() + self._verify_signer(custom_token, _token_gen.AUTH_EMULATOR_EMAIL) + return with pytest.raises(ValueError): auth.create_custom_token(MOCK_UID, app=user_mgt_app) @@ -286,7 +323,7 @@ def test_sign_with_discovery_failure(self): def _verify_signer(self, token, signer): segments = token.split('.') assert len(segments) == 3 - body = json.loads(base64.b64decode(segments[1]).decode()) + body = jwt.decode(token, verify=False) assert body['iss'] == signer assert body['sub'] == signer @@ -388,14 +425,24 @@ class TestVerifyIdToken: 'BadFormatToken': 'foobar' } - @pytest.mark.parametrize('id_token', valid_tokens.values(), ids=list(valid_tokens)) - def test_valid_token(self, user_mgt_app, id_token): - _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) - claims = auth.verify_id_token(id_token, app=user_mgt_app) + tokens_accepted_in_emulator = [ + 'NoKid', + 'WrongKid', + 'FutureToken', + 'ExpiredToken' + ] + + def _assert_valid_token(self, id_token, app): + claims = auth.verify_id_token(id_token, app=app) assert claims['admin'] is True assert claims['uid'] == claims['sub'] assert claims['firebase']['sign_in_provider'] == 'provider' + @pytest.mark.parametrize('id_token', valid_tokens.values(), ids=list(valid_tokens)) + def test_valid_token(self, user_mgt_app, id_token): + _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) + self._assert_valid_token(id_token, app=user_mgt_app) + def test_valid_token_with_tenant(self, user_mgt_app): _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) claims = auth.verify_id_token(TEST_ID_TOKEN_WITH_TENANT, app=user_mgt_app) @@ -440,8 +487,12 @@ def test_invalid_arg(self, user_mgt_app, id_token): auth.verify_id_token(id_token, app=user_mgt_app) assert 'Illegal ID token provided' in str(excinfo.value) - @pytest.mark.parametrize('id_token', invalid_tokens.values(), ids=list(invalid_tokens)) - def test_invalid_token(self, user_mgt_app, id_token): + @pytest.mark.parametrize('id_token_key', list(invalid_tokens)) + def test_invalid_token(self, user_mgt_app, id_token_key): + id_token = self.invalid_tokens[id_token_key] + if _is_emulated() and id_token_key in self.tokens_accepted_in_emulator: + self._assert_valid_token(id_token, user_mgt_app) + return _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) with pytest.raises(auth.InvalidIdTokenError) as excinfo: auth.verify_id_token(id_token, app=user_mgt_app) @@ -451,6 +502,9 @@ def test_invalid_token(self, user_mgt_app, id_token): def test_expired_token(self, user_mgt_app): _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) id_token = self.invalid_tokens['ExpiredToken'] + if _is_emulated(): + self._assert_valid_token(id_token, user_mgt_app) + return with pytest.raises(auth.ExpiredIdTokenError) as excinfo: auth.verify_id_token(id_token, app=user_mgt_app) assert isinstance(excinfo.value, auth.InvalidIdTokenError) @@ -488,6 +542,10 @@ def test_custom_token(self, auth_app): def test_certificate_request_failure(self, user_mgt_app): _overwrite_cert_request(user_mgt_app, testutils.MockRequest(404, 'not found')) + if _is_emulated(): + # Shouldn't fetch certificates in emulator mode. + self._assert_valid_token(TEST_ID_TOKEN, app=user_mgt_app) + return with pytest.raises(auth.CertificateFetchError) as excinfo: auth.verify_id_token(TEST_ID_TOKEN, app=user_mgt_app) assert 'Could not fetch certificates' in str(excinfo.value) @@ -522,20 +580,28 @@ class TestVerifySessionCookie: 'IDToken': TEST_ID_TOKEN, } + cookies_accepted_in_emulator = [ + 'NoKid', + 'WrongKid', + 'FutureCookie', + 'ExpiredCookie' + ] + + def _assert_valid_cookie(self, cookie, app, check_revoked=False): + claims = auth.verify_session_cookie(cookie, app=app, check_revoked=check_revoked) + assert claims['admin'] is True + assert claims['uid'] == claims['sub'] + @pytest.mark.parametrize('cookie', valid_cookies.values(), ids=list(valid_cookies)) def test_valid_cookie(self, user_mgt_app, cookie): _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) - claims = auth.verify_session_cookie(cookie, app=user_mgt_app) - assert claims['admin'] is True - assert claims['uid'] == claims['sub'] + self._assert_valid_cookie(cookie, user_mgt_app) @pytest.mark.parametrize('cookie', valid_cookies.values(), ids=list(valid_cookies)) def test_valid_cookie_check_revoked(self, user_mgt_app, cookie): _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) _instrument_user_manager(user_mgt_app, 200, MOCK_GET_USER_RESPONSE) - claims = auth.verify_session_cookie(cookie, app=user_mgt_app, check_revoked=True) - assert claims['admin'] is True - assert claims['uid'] == claims['sub'] + self._assert_valid_cookie(cookie, app=user_mgt_app, check_revoked=True) @pytest.mark.parametrize('cookie', valid_cookies.values(), ids=list(valid_cookies)) def test_revoked_cookie_check_revoked(self, user_mgt_app, revoked_tokens, cookie): @@ -549,9 +615,7 @@ def test_revoked_cookie_check_revoked(self, user_mgt_app, revoked_tokens, cookie def test_revoked_cookie_does_not_check_revoked(self, user_mgt_app, revoked_tokens, cookie): _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) _instrument_user_manager(user_mgt_app, 200, revoked_tokens) - claims = auth.verify_session_cookie(cookie, app=user_mgt_app, check_revoked=False) - assert claims['admin'] is True - assert claims['uid'] == claims['sub'] + self._assert_valid_cookie(cookie, app=user_mgt_app, check_revoked=False) @pytest.mark.parametrize('cookie', INVALID_JWT_ARGS.values(), ids=list(INVALID_JWT_ARGS)) def test_invalid_args(self, user_mgt_app, cookie): @@ -560,8 +624,12 @@ def test_invalid_args(self, user_mgt_app, cookie): auth.verify_session_cookie(cookie, app=user_mgt_app) assert 'Illegal session cookie provided' in str(excinfo.value) - @pytest.mark.parametrize('cookie', invalid_cookies.values(), ids=list(invalid_cookies)) - def test_invalid_cookie(self, user_mgt_app, cookie): + @pytest.mark.parametrize('cookie_key', list(invalid_cookies)) + def test_invalid_cookie(self, user_mgt_app, cookie_key): + cookie = self.invalid_cookies[cookie_key] + if _is_emulated() and cookie_key in self.cookies_accepted_in_emulator: + self._assert_valid_cookie(cookie, user_mgt_app) + return _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) with pytest.raises(auth.InvalidSessionCookieError) as excinfo: auth.verify_session_cookie(cookie, app=user_mgt_app) @@ -571,6 +639,9 @@ def test_invalid_cookie(self, user_mgt_app, cookie): def test_expired_cookie(self, user_mgt_app): _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) cookie = self.invalid_cookies['ExpiredCookie'] + if _is_emulated(): + self._assert_valid_cookie(cookie, user_mgt_app) + return with pytest.raises(auth.ExpiredSessionCookieError) as excinfo: auth.verify_session_cookie(cookie, app=user_mgt_app) assert isinstance(excinfo.value, auth.InvalidSessionCookieError) @@ -603,6 +674,10 @@ def test_custom_token(self, auth_app): def test_certificate_request_failure(self, user_mgt_app): _overwrite_cert_request(user_mgt_app, testutils.MockRequest(404, 'not found')) + if _is_emulated(): + # Shouldn't fetch certificates in emulator mode. + auth.verify_session_cookie(TEST_SESSION_COOKIE, app=user_mgt_app) + return with pytest.raises(auth.CertificateFetchError) as excinfo: auth.verify_session_cookie(TEST_SESSION_COOKIE, app=user_mgt_app) assert 'Could not fetch certificates' in str(excinfo.value) @@ -619,9 +694,11 @@ def test_certificate_caching(self, user_mgt_app, httpserver): verifier.cookie_verifier.cert_url = httpserver.url verifier.id_token_verifier.cert_url = httpserver.url verifier.verify_session_cookie(TEST_SESSION_COOKIE) - assert len(httpserver.requests) == 1 + # No requests should be made in emulated mode + request_count = 0 if _is_emulated() else 1 + assert len(httpserver.requests) == request_count # Subsequent requests should not fetch certs from the server verifier.verify_session_cookie(TEST_SESSION_COOKIE) - assert len(httpserver.requests) == 1 + assert len(httpserver.requests) == request_count verifier.verify_id_token(TEST_ID_TOKEN) - assert len(httpserver.requests) == 1 + assert len(httpserver.requests) == request_count diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index 240f19bdc..ac80a92a6 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -50,19 +50,32 @@ } MOCK_ACTION_CODE_SETTINGS = auth.ActionCodeSettings(**MOCK_ACTION_CODE_DATA) -USER_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v1/projects/mock-project-id' - TEST_TIMEOUT = 42 +ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v1' +EMULATOR_HOST_ENV_VAR = 'FIREBASE_AUTH_EMULATOR_HOST' +AUTH_EMULATOR_HOST = 'localhost:9099' +EMULATED_ID_TOOLKIT_URL = 'http://{}/identitytoolkit.googleapis.com/v1'.format(AUTH_EMULATOR_HOST) +URL_PROJECT_SUFFIX = '/projects/mock-project-id' +USER_MGT_URLS = { + 'ID_TOOLKIT': ID_TOOLKIT_URL, + 'PREFIX': ID_TOOLKIT_URL + URL_PROJECT_SUFFIX, +} -@pytest.fixture(scope='module') -def user_mgt_app(): +@pytest.fixture(params=[{'emulated': False}, {'emulated': True}]) +def user_mgt_app(request): + monkeypatch = testutils.new_monkeypatch() + if request.param['emulated']: + monkeypatch.setenv(EMULATOR_HOST_ENV_VAR, AUTH_EMULATOR_HOST) + monkeypatch.setitem(USER_MGT_URLS, 'ID_TOOLKIT', EMULATED_ID_TOOLKIT_URL) + monkeypatch.setitem(USER_MGT_URLS, 'PREFIX', EMULATED_ID_TOOLKIT_URL + URL_PROJECT_SUFFIX) app = firebase_admin.initialize_app(testutils.MockCredential(), name='userMgt', options={'projectId': 'mock-project-id'}) yield app firebase_admin.delete_app(app) + monkeypatch.undo() -@pytest.fixture(scope='module') +@pytest.fixture def user_mgt_app_with_timeout(): app = firebase_admin.initialize_app( testutils.MockCredential(), @@ -77,7 +90,7 @@ def _instrument_user_manager(app, status, payload): user_manager = client._user_manager recorder = [] user_manager.http_client.session.mount( - _user_mgt.UserManager.ID_TOOLKIT_URL, + USER_MGT_URLS['ID_TOOLKIT'], testutils.MockAdapter(payload, status, recorder)) return user_manager, recorder @@ -121,7 +134,7 @@ def _check_request(recorder, want_url, want_body=None, want_timeout=None): assert len(recorder) == 1 req = recorder[0] assert req.method == 'POST' - assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, want_url) + assert req.url == '{0}{1}'.format(USER_MGT_URLS['PREFIX'], want_url) if want_body: body = json.loads(req.body.decode()) assert body == want_body diff --git a/tests/testutils.py b/tests/testutils.py index 556155253..4a77c9d80 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -16,6 +16,8 @@ import io import os +import pytest + from google.auth import credentials from google.auth import transport from requests import adapters @@ -58,6 +60,15 @@ def run_without_project_id(func): os.environ[env_var] = gcloud_project +def new_monkeypatch(): + try: + return pytest.MonkeyPatch() + except AttributeError: + # Fallback for Python 3.5 + from _pytest.monkeypatch import MonkeyPatch + return MonkeyPatch() + + class MockResponse(transport.Response): def __init__(self, status, response): self._status = status From 5e88d927757e61023e274cd8b00270d899fdf513 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Thu, 25 Mar 2021 12:20:53 -0700 Subject: [PATCH 097/226] fix(auth): Setting httpTimeout on certificate fetch requests (#538) * fix(auth): Setting httpTimeout on certificate fetch requests * fix: Removed unused import --- firebase_admin/_token_gen.py | 30 ++++++++++++++++++++-- tests/test_token_gen.py | 50 ++++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 2 deletions(-) diff --git a/firebase_admin/_token_gen.py b/firebase_admin/_token_gen.py index 562e77fa5..135573c01 100644 --- a/firebase_admin/_token_gen.py +++ b/firebase_admin/_token_gen.py @@ -29,6 +29,7 @@ from firebase_admin import exceptions from firebase_admin import _auth_utils +from firebase_admin import _http_client # ID token constants @@ -231,12 +232,37 @@ def create_session_cookie(self, id_token, expires_in): return body.get('sessionCookie') +class CertificateFetchRequest(transport.Request): + """A google-auth transport that supports HTTP cache-control. + + Also injects a timeout to each outgoing HTTP request. + """ + + def __init__(self, timeout_seconds=None): + self._session = cachecontrol.CacheControl(requests.Session()) + self._delegate = transport.requests.Request(self.session) + self._timeout_seconds = timeout_seconds + + @property + def session(self): + return self._session + + @property + def timeout_seconds(self): + return self._timeout_seconds + + def __call__(self, url, method='GET', body=None, headers=None, timeout=None, **kwargs): + timeout = timeout or self.timeout_seconds + return self._delegate( + url, method=method, body=body, headers=headers, timeout=timeout, **kwargs) + + class TokenVerifier: """Verifies ID tokens and session cookies.""" def __init__(self, app): - session = cachecontrol.CacheControl(requests.Session()) - self.request = transport.requests.Request(session=session) + timeout = app.options.get('httpTimeout', _http_client.DEFAULT_TIMEOUT_SECONDS) + self.request = CertificateFetchRequest(timeout) self.id_token_verifier = _JWTVerifier( project_id=app.project_id, short_name='ID token', operation='verify_id_token()', diff --git a/tests/test_token_gen.py b/tests/test_token_gen.py index 29c70da80..d8450c59c 100644 --- a/tests/test_token_gen.py +++ b/tests/test_token_gen.py @@ -31,6 +31,7 @@ from firebase_admin import auth from firebase_admin import credentials from firebase_admin import exceptions +from firebase_admin import _http_client from firebase_admin import _token_gen from tests import testutils @@ -702,3 +703,52 @@ def test_certificate_caching(self, user_mgt_app, httpserver): assert len(httpserver.requests) == request_count verifier.verify_id_token(TEST_ID_TOKEN) assert len(httpserver.requests) == request_count + + +class TestCertificateFetchTimeout: + + timeout_configs = [ + ({'httpTimeout': 4}, 4), + ({'httpTimeout': None}, None), + ({}, _http_client.DEFAULT_TIMEOUT_SECONDS), + ] + + @pytest.mark.parametrize('options, timeout', timeout_configs) + def test_init_request(self, options, timeout): + app = firebase_admin.initialize_app(MOCK_CREDENTIAL, options=options) + + client = auth._get_client(app) + request = client._token_verifier.request + + assert isinstance(request, _token_gen.CertificateFetchRequest) + assert request.timeout_seconds == timeout + + @pytest.mark.parametrize('options, timeout', timeout_configs) + def test_verify_id_token_timeout(self, options, timeout): + app = firebase_admin.initialize_app(MOCK_CREDENTIAL, options=options) + recorder = self._instrument_session(app) + + auth.verify_id_token(TEST_ID_TOKEN) + + assert len(recorder) == 1 + assert recorder[0]._extra_kwargs['timeout'] == timeout + + @pytest.mark.parametrize('options, timeout', timeout_configs) + def test_verify_session_cookie_timeout(self, options, timeout): + app = firebase_admin.initialize_app(MOCK_CREDENTIAL, options=options) + recorder = self._instrument_session(app) + + auth.verify_session_cookie(TEST_SESSION_COOKIE) + + assert len(recorder) == 1 + assert recorder[0]._extra_kwargs['timeout'] == timeout + + def _instrument_session(self, app): + client = auth._get_client(app) + request = client._token_verifier.request + recorder = [] + request.session.mount('https://', testutils.MockAdapter(MOCK_PUBLIC_CERTS, 200, recorder)) + return recorder + + def teardown(self): + testutils.cleanup_apps() From a6714a12705c8531ca6d0e6e0e9982d435b4c4c4 Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Tue, 20 Apr 2021 14:39:02 -0400 Subject: [PATCH 098/226] change: Drop Python 3.5 support (#542) * chore: Drop Python 3.5 support * Remove SHA1 hash tests * Clean up previous python version hacks * update to use pytest.MonkeyPatch() * upgrade pytest to 6.2.0 and up --- .github/workflows/ci.yml | 2 +- CONTRIBUTING.md | 2 +- README.md | 3 +-- integration/test_project_management.py | 16 +++------------- requirements.txt | 2 +- setup.py | 8 ++++---- tests/testutils.py | 7 +------ 7 files changed, 12 insertions(+), 28 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 067cd8f18..d81f932a8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,7 +8,7 @@ jobs: strategy: fail-fast: false matrix: - python: [3.5, 3.6, 3.7, 3.9, pypy3] + python: [3.6, 3.7, 3.8, 3.9, pypy3] steps: - uses: actions/checkout@v1 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f6d09b093..30685394e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -85,7 +85,7 @@ information on using pull requests. ### Initial Setup -You need Python 3.4+ to build and test the code in this repo. +You need Python 3.6+ to build and test the code in this repo. We recommend using [pip](https://pypi.python.org/pypi/pip) for installing the necessary tools and project dependencies. Most recent versions of Python ship with pip. If your development environment diff --git a/README.md b/README.md index 180bc2bff..646d3d0e3 100644 --- a/README.md +++ b/README.md @@ -43,8 +43,7 @@ requests, code review feedback, and also pull requests. ## Supported Python Versions -We currently support Python 3.5+. However, Python 3.5 support is deprecated, -and the developers are strongly advised to use Python 3.6 or higher. Firebase +We currently support Python 3.6+. Firebase Admin Python SDK is also tested on PyPy and [Google App Engine](https://cloud.google.com/appengine/) environments. diff --git a/integration/test_project_management.py b/integration/test_project_management.py index ca648f12d..362515535 100644 --- a/integration/test_project_management.py +++ b/integration/test_project_management.py @@ -28,8 +28,6 @@ TEST_APP_PACKAGE_NAME = 'com.firebase.adminsdk_python_integration_test' TEST_APP_DISPLAY_NAME_PREFIX = 'Created By Firebase AdminSDK Python Integration Testing' -SHA_1_HASH_1 = '123456789a123456789a123456789a123456789a' -SHA_1_HASH_2 = 'aaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbb' SHA_256_HASH_1 = '123456789a123456789a123456789a123456789a123456789a123456789a1234' SHA_256_HASH_2 = 'cafef00dba5eba11b01dfaceacc01adeda7aba5eca55e77e0b57ac1e5ca1ab1e' SHA_1 = project_management.SHACertificate.SHA_1 @@ -119,17 +117,13 @@ def test_android_sha_certificates(android_app): for cert in android_app.get_sha_certificates(): android_app.delete_sha_certificate(cert) - # Add four different certs and assert that they have all been added successfully. - android_app.add_sha_certificate(project_management.SHACertificate(SHA_1_HASH_1)) - android_app.add_sha_certificate(project_management.SHACertificate(SHA_1_HASH_2)) + # Add two different certs and assert that they have all been added successfully. android_app.add_sha_certificate(project_management.SHACertificate(SHA_256_HASH_1)) android_app.add_sha_certificate(project_management.SHACertificate(SHA_256_HASH_2)) cert_list = android_app.get_sha_certificates() - sha_1_hashes = set(cert.sha_hash for cert in cert_list if cert.cert_type == SHA_1) sha_256_hashes = set(cert.sha_hash for cert in cert_list if cert.cert_type == SHA_256) - assert sha_1_hashes == set([SHA_1_HASH_1, SHA_1_HASH_2]) assert sha_256_hashes == set([SHA_256_HASH_1, SHA_256_HASH_2]) for cert in cert_list: assert cert.name @@ -195,12 +189,8 @@ def test_list_ios_apps(ios_app): def test_get_ios_app_config(ios_app, project_id): config = ios_app.get_config() - # In Python 2.7, the plistlib module works with strings, while in Python 3, it is significantly - # redesigned and works with bytes objects instead. - try: - plist = plistlib.loads(config.encode('utf-8')) - except AttributeError: # Python 2.7 plistlib does not have the loads attribute. - plist = plistlib.readPlistFromString(config) # pylint: disable=no-member + plist = plistlib.loads(config.encode('utf-8')) + assert plist['BUNDLE_ID'] == TEST_APP_BUNDLE_ID assert plist['PROJECT_ID'] == project_id assert plist['GOOGLE_APP_ID'] == ios_app.app_id diff --git a/requirements.txt b/requirements.txt index 1a55482da..08dfd1ab5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ astroid == 2.3.3 pylint == 2.3.1 -pytest >= 3.6.0 +pytest >= 6.2.0 pytest-cov >= 2.4.0 pytest-localserver >= 0.4.1 diff --git a/setup.py b/setup.py index 4f0d42365..ea8286dce 100644 --- a/setup.py +++ b/setup.py @@ -22,8 +22,8 @@ (major, minor) = (sys.version_info.major, sys.version_info.minor) -if major != 3 or minor < 5: - print('firebase_admin requires python >= 3.5', file=sys.stderr) +if major != 3 or minor < 6: + print('firebase_admin requires python >= 3.6', file=sys.stderr) sys.exit(1) # Read in the package metadata per recommendations from: @@ -55,15 +55,15 @@ keywords='firebase cloud development', install_requires=install_requires, packages=['firebase_admin'], - python_requires='>=3.5', + python_requires='>=3.6', classifiers=[ 'Development Status :: 5 - Production/Stable', 'Intended Audience :: Developers', 'Topic :: Software Development :: Build Tools', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'License :: OSI Approved :: Apache Software License', ], diff --git a/tests/testutils.py b/tests/testutils.py index 4a77c9d80..92755107c 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -61,12 +61,7 @@ def run_without_project_id(func): def new_monkeypatch(): - try: - return pytest.MonkeyPatch() - except AttributeError: - # Fallback for Python 3.5 - from _pytest.monkeypatch import MonkeyPatch - return MonkeyPatch() + return pytest.MonkeyPatch() class MockResponse(transport.Response): From d9bc50c50d8f7b0a426e31c9fed9824109b02c5d Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Tue, 20 Apr 2021 16:24:21 -0400 Subject: [PATCH 099/226] chore: Add nightly build workflow (#540) * chore: Add nightly build workflow * Add missing version info for send-email action --- .github/workflows/nightly.yml | 98 +++++++++++++++++++++++++++++++++++ integration/test_messaging.py | 3 ++ 2 files changed, 101 insertions(+) create mode 100644 .github/workflows/nightly.yml diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml new file mode 100644 index 000000000..f22eb99c8 --- /dev/null +++ b/.github/workflows/nightly.yml @@ -0,0 +1,98 @@ +# Copyright 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: Nightly Builds + +on: + # Runs every day at 06:20 AM (PT) and 08:20 PM (PT) / 04:20 AM (UTC) and 02:20 PM (UTC) + # or on 'firebase_nightly_build' repository dispatch event. + schedule: + - cron: "20 4,14 * * *" + repository_dispatch: + types: [firebase_nightly_build] + +jobs: + nightly: + + runs-on: ubuntu-latest + + steps: + - name: Checkout source for staging + uses: actions/checkout@v2 + with: + ref: ${{ github.event.client_payload.ref || github.ref }} + + - name: Set up Python + uses: actions/setup-python@v1 + with: + python-version: 3.6 + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install setuptools wheel + pip install tensorflow + pip install keras + + - name: Run unit tests + run: pytest + + - name: Run integration tests + run: ./.github/scripts/run_integration_tests.sh + env: + FIREBASE_SERVICE_ACCT_KEY: ${{ secrets.FIREBASE_SERVICE_ACCT_KEY }} + FIREBASE_API_KEY: ${{ secrets.FIREBASE_API_KEY }} + + # Build the Python Wheel and the source distribution. + - name: Package release artifacts + run: python setup.py bdist_wheel sdist + + # Attach the packaged artifacts to the workflow output. These can be manually + # downloaded for later inspection if necessary. + - name: Archive artifacts + uses: actions/upload-artifact@v1 + with: + name: dist + path: dist + + - name: Send email on failure + if: failure() + uses: firebase/firebase-admin-node/.github/actions/send-email@master + with: + api-key: ${{ secrets.OSS_BOT_MAILGUN_KEY }} + domain: ${{ secrets.OSS_BOT_MAILGUN_DOMAIN }} + from: 'GitHub ' + to: ${{ secrets.FIREBASE_ADMIN_GITHUB_EMAIL }} + subject: 'Nightly build ${{github.run_id}} of ${{github.repository}} failed!' + html: > + Nightly workflow ${{github.run_id}} failed on: ${{github.repository}} +

Navigate to the + failed workflow. + continue-on-error: true + + - name: Send email on cancelled + if: cancelled() + uses: firebase/firebase-admin-node/.github/actions/send-email@master + with: + api-key: ${{ secrets.OSS_BOT_MAILGUN_KEY }} + domain: ${{ secrets.OSS_BOT_MAILGUN_DOMAIN }} + from: 'GitHub ' + to: ${{ secrets.FIREBASE_ADMIN_GITHUB_EMAIL }} + subject: 'Nightly build ${{github.run_id}} of ${{github.repository}} cancelled!' + html: > + Nightly workflow ${{github.run_id}} cancelled on: ${{github.repository}} +

Navigate to the + cancelled workflow. + continue-on-error: true diff --git a/integration/test_messaging.py b/integration/test_messaging.py index b5612b63d..48f8208f3 100644 --- a/integration/test_messaging.py +++ b/integration/test_messaging.py @@ -28,6 +28,9 @@ '1SsRoE') +def test_to_trigger_nightly_email_notification(): + assert 'a' == 'b' + def test_send(): msg = messaging.Message( topic='foo-bar', From df47e570e30fa46db8c669d7bdcd5392821a6f50 Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Wed, 21 Apr 2021 14:24:42 -0400 Subject: [PATCH 100/226] Remove failing integration test added for nightly (#545) - Remove the failing integration test added in #540 to test the nightly email notifications (the test runs are completed). Note: staging to trigger integration tests. All tests should pass. --- integration/test_messaging.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/integration/test_messaging.py b/integration/test_messaging.py index 48f8208f3..b5612b63d 100644 --- a/integration/test_messaging.py +++ b/integration/test_messaging.py @@ -28,9 +28,6 @@ '1SsRoE') -def test_to_trigger_nightly_email_notification(): - assert 'a' == 'b' - def test_send(): msg = messaging.Message( topic='foo-bar', From 44ae03822f2ecb468d89348cc0f94e03bc1e9e5d Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Wed, 21 Apr 2021 14:44:58 -0400 Subject: [PATCH 101/226] chore: Upgraded Google Auth, Cloud Firestore, and Cloud Storage dependencies (#544) * chore: Upgrade auth, firestore, and storage * remove google-auth dependency --- requirements.txt | 7 +++---- setup.py | 6 +++--- tests/test_token_gen.py | 4 ++-- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/requirements.txt b/requirements.txt index 08dfd1ab5..131b65f8b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,8 +5,7 @@ pytest-cov >= 2.4.0 pytest-localserver >= 0.4.1 cachecontrol >= 0.12.6 -google-api-core[grpc] >= 1.14.0, < 2.0.0dev; platform.python_implementation != 'PyPy' +google-api-core[grpc] >= 1.22.1, < 2.0.0dev; platform.python_implementation != 'PyPy' google-api-python-client >= 1.7.8 -google-auth == 1.18.0 # temporary workaround -google-cloud-firestore >= 1.4.0; platform.python_implementation != 'PyPy' -google-cloud-storage >= 1.18.0 +google-cloud-firestore >= 2.1.0; platform.python_implementation != 'PyPy' +google-cloud-storage >= 1.37.1 diff --git a/setup.py b/setup.py index ea8286dce..83b7291df 100644 --- a/setup.py +++ b/setup.py @@ -38,10 +38,10 @@ 'to integrate Firebase into their services and applications.') install_requires = [ 'cachecontrol>=0.12.6', - 'google-api-core[grpc] >= 1.14.0, < 2.0.0dev; platform.python_implementation != "PyPy"', + 'google-api-core[grpc] >= 1.22.1, < 2.0.0dev; platform.python_implementation != "PyPy"', 'google-api-python-client >= 1.7.8', - 'google-cloud-firestore>=1.4.0; platform.python_implementation != "PyPy"', - 'google-cloud-storage>=1.18.0', + 'google-cloud-firestore>=2.1.0; platform.python_implementation != "PyPy"', + 'google-cloud-storage>=1.37.1', ] setup( diff --git a/tests/test_token_gen.py b/tests/test_token_gen.py index d8450c59c..b0a744f1d 100644 --- a/tests/test_token_gen.py +++ b/tests/test_token_gen.py @@ -259,7 +259,7 @@ def test_sign_with_iam(self): testutils.MockCredential(), name='iam-signer-app', options=options) try: signature = base64.b64encode(b'test').decode() - iam_resp = '{{"signature": "{0}"}}'.format(signature) + iam_resp = '{{"signedBlob": "{0}"}}'.format(signature) _overwrite_iam_request(app, testutils.MockRequest(200, iam_resp)) custom_token = auth.create_custom_token(MOCK_UID, app=app).decode() assert custom_token.endswith('.' + signature.rstrip('=')) @@ -297,7 +297,7 @@ def test_sign_with_discovered_service_account(self): # Now invoke the IAM signer. signature = base64.b64encode(b'test').decode() request.response = testutils.MockResponse( - 200, '{{"signature": "{0}"}}'.format(signature)) + 200, '{{"signedBlob": "{0}"}}'.format(signature)) custom_token = auth.create_custom_token(MOCK_UID, app=app).decode() assert custom_token.endswith('.' + signature.rstrip('=')) self._verify_signer(custom_token, 'discovered-service-account') From 939375c021f4c4425beb4bd77b76f0d1be9a7ddd Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Wed, 21 Apr 2021 13:39:28 -0700 Subject: [PATCH 102/226] fix: Using alg=none header for custom tokens in emulator mode (#541) * fix(auth): Using alg=none header for custom tokens in emulator mode * fix: Dropping google-auth explicit dependency; Temp test skip for Py 3.5 * chore: Removed py35 hack --- firebase_admin/_token_gen.py | 15 ++++++++++++--- tests/test_token_gen.py | 10 +++++++--- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/firebase_admin/_token_gen.py b/firebase_admin/_token_gen.py index 135573c01..32c109d5d 100644 --- a/firebase_admin/_token_gen.py +++ b/firebase_admin/_token_gen.py @@ -53,6 +53,8 @@ ]) METADATA_SERVICE_URL = ('http://metadata.google.internal/computeMetadata/v1/instance/' 'service-accounts/default/email') +ALGORITHM_RS256 = 'RS256' +ALGORITHM_NONE = 'none' # Emulator fake account AUTH_EMULATOR_EMAIL = 'firebase-auth-emulator@example.com' @@ -71,9 +73,10 @@ def sign(self, message): class _SigningProvider: """Stores a reference to a google.auth.crypto.Signer.""" - def __init__(self, signer, signer_email): + def __init__(self, signer, signer_email, alg=ALGORITHM_RS256): self._signer = signer self._signer_email = signer_email + self._alg = alg @property def signer(self): @@ -83,6 +86,10 @@ def signer(self): def signer_email(self): return self._signer_email + @property + def alg(self): + return self._alg + @classmethod def from_credential(cls, google_cred): return _SigningProvider(google_cred.signer, google_cred.signer_email) @@ -94,7 +101,7 @@ def from_iam(cls, request, google_cred, service_account): @classmethod def for_emulator(cls): - return _SigningProvider(_EmulatedSigner(), AUTH_EMULATOR_EMAIL) + return _SigningProvider(_EmulatedSigner(), AUTH_EMULATOR_EMAIL, ALGORITHM_NONE) class TokenGenerator: @@ -190,8 +197,10 @@ def create_custom_token(self, uid, developer_claims=None, tenant_id=None): if developer_claims is not None: payload['claims'] = developer_claims + + header = {'alg': signing_provider.alg} try: - return jwt.encode(signing_provider.signer, payload) + return jwt.encode(signing_provider.signer, payload, header=header) except google.auth.exceptions.TransportError as error: msg = 'Failed to sign custom token. {0}'.format(error) raise TokenSignError(msg, error) diff --git a/tests/test_token_gen.py b/tests/test_token_gen.py index b0a744f1d..0a09862ab 100644 --- a/tests/test_token_gen.py +++ b/tests/test_token_gen.py @@ -75,17 +75,24 @@ def _merge_jwt_claims(defaults, overrides): del defaults[key] return defaults + def verify_custom_token(custom_token, expected_claims, tenant_id=None): assert isinstance(custom_token, bytes) expected_email = MOCK_SERVICE_ACCOUNT_EMAIL + header = jwt.decode_header(custom_token) + assert header.get('typ') == 'JWT' if _is_emulated(): + assert header.get('alg') == 'none' + assert custom_token.split(b'.')[2] == b'' expected_email = _token_gen.AUTH_EMULATOR_EMAIL token = jwt.decode(custom_token, verify=False) else: + assert header.get('alg') == 'RS256' token = google.oauth2.id_token.verify_token( custom_token, testutils.MockRequest(200, MOCK_PUBLIC_CERTS), _token_gen.FIREBASE_AUDIENCE) + assert token['uid'] == MOCK_UID assert token['iss'] == expected_email assert token['sub'] == expected_email @@ -94,9 +101,6 @@ def verify_custom_token(custom_token, expected_claims, tenant_id=None): else: assert token['tenant_id'] == tenant_id - header = jwt.decode_header(custom_token) - assert header.get('typ') == 'JWT' - assert header.get('alg') == 'RS256' if expected_claims: for key, value in expected_claims.items(): assert value == token['claims'][key] From dcd3a86345b53d3cab7c9a747be717f892137c6a Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Fri, 23 Apr 2021 14:47:30 -0700 Subject: [PATCH 103/226] fix: Accept Path-like objects in credential factory functions (#510) * fix: Accept Path-like objects in credential factory functions * chore: Trigger CI --- firebase_admin/credentials.py | 13 +++++++++++-- tests/test_credentials.py | 12 ++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/firebase_admin/credentials.py b/firebase_admin/credentials.py index 1f207e483..5477e1cf7 100644 --- a/firebase_admin/credentials.py +++ b/firebase_admin/credentials.py @@ -15,6 +15,7 @@ """Firebase credentials module.""" import collections import json +import pathlib import google.auth from google.auth.transport import requests @@ -78,7 +79,7 @@ def __init__(self, cert): ValueError: If the specified certificate is invalid. """ super(Certificate, self).__init__() - if isinstance(cert, str): + if _is_file_path(cert): with open(cert) as json_file: json_data = json.load(json_file) elif isinstance(cert, dict): @@ -179,7 +180,7 @@ def __init__(self, refresh_token): ValueError: If the refresh token configuration is invalid. """ super(RefreshToken, self).__init__() - if isinstance(refresh_token, str): + if _is_file_path(refresh_token): with open(refresh_token) as json_file: json_data = json.load(json_file) elif isinstance(refresh_token, dict): @@ -212,3 +213,11 @@ def get_credential(self): Returns: google.auth.credentials.Credentials: A Google Auth credential instance.""" return self._g_credential + + +def _is_file_path(path): + try: + pathlib.Path(path) + return True + except TypeError: + return False diff --git a/tests/test_credentials.py b/tests/test_credentials.py index d78ef5192..cceb6b6f9 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -16,6 +16,7 @@ import datetime import json import os +import pathlib import google.auth from google.auth import crypt @@ -47,6 +48,12 @@ def test_init_from_file(self): testutils.resource_filename('service_account.json')) self._verify_credential(credential) + def test_init_from_path_like(self): + path = pathlib.Path(testutils.resource_filename('service_account.json')) + credential = credentials.Certificate(path) + self._verify_credential(credential) + + def test_init_from_dict(self): parsed_json = json.loads(testutils.resource('service_account.json')) credential = credentials.Certificate(parsed_json) @@ -129,6 +136,11 @@ def test_init_from_file(self): testutils.resource_filename('refresh_token.json')) self._verify_credential(credential) + def test_init_from_path_like(self): + path = pathlib.Path(testutils.resource_filename('refresh_token.json')) + credential = credentials.RefreshToken(path) + self._verify_credential(credential) + def test_init_from_dict(self): parsed_json = json.loads(testutils.resource('refresh_token.json')) credential = credentials.RefreshToken(parsed_json) From 4d27a2be9c09c471debc6222fcbc0d4f63c51328 Mon Sep 17 00:00:00 2001 From: Omid-eD <51051081+Omid-eD@users.noreply.github.com> Date: Tue, 27 Apr 2021 02:05:15 +0430 Subject: [PATCH 104/226] Fix Typo in messaging.py (#546) --- firebase_admin/messaging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index 7c92a3d8d..548bcfc37 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -106,7 +106,7 @@ def send(message, dry_run=False, app=None): app: An App instance (optional). Returns: - string: A message ID string that uniquely identifies the sent the message. + string: A message ID string that uniquely identifies the sent message. Raises: FirebaseError: If an error occurs while sending the message to the FCM service. From ebf09619ad24cdcf04a8a81362844233fa2a2130 Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Wed, 28 Apr 2021 14:29:12 -0400 Subject: [PATCH 105/226] [chore] Release 5.0.0 (#547) --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index c4665e933..353643533 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '4.5.3' +__version__ = '5.0.0' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 527a8245e1ba461817be3f38ac142a0563daeb4d Mon Sep 17 00:00:00 2001 From: bojeil-google Date: Mon, 3 May 2021 17:27:39 -0700 Subject: [PATCH 106/226] fix(auth): adds missing EMAIL_NOT_FOUND error code (#550) * fix(auth): adds missing EMAIL_NOT_FOUND error code Catch EMAIL_NOT_FOUND and translate to EmailNotFoundError when /accounts:sendOobCode is called for password reset on a user that does not exist. --- firebase_admin/_auth_client.py | 6 ++++-- firebase_admin/_auth_utils.py | 10 ++++++++++ firebase_admin/auth.py | 2 ++ tests/test_user_mgt.py | 11 +++++++++++ 4 files changed, 27 insertions(+), 2 deletions(-) diff --git a/firebase_admin/_auth_client.py b/firebase_admin/_auth_client.py index 2f6713d41..a58dbef74 100644 --- a/firebase_admin/_auth_client.py +++ b/firebase_admin/_auth_client.py @@ -181,7 +181,7 @@ def get_user_by_email(self, email): Raises: ValueError: If the email is None, empty or malformed. - UserNotFoundError: If no user exists by the specified email address. + UserNotFoundError: If no user exists for the specified email address. FirebaseError: If an error occurs while retrieving the user. """ response = self._user_manager.get_user(email=email) @@ -198,7 +198,7 @@ def get_user_by_phone_number(self, phone_number): Raises: ValueError: If the phone number is ``None``, empty or malformed. - UserNotFoundError: If no user exists by the specified phone number. + UserNotFoundError: If no user exists for the specified phone number. FirebaseError: If an error occurs while retrieving the user. """ response = self._user_manager.get_user(phone_number=phone_number) @@ -444,6 +444,7 @@ def generate_password_reset_link(self, email, action_code_settings=None): Raises: ValueError: If the provided arguments are invalid + EmailNotFoundError: If no user exists for the specified email address. FirebaseError: If an error occurs while generating the link """ return self._user_manager.generate_email_action_link( @@ -464,6 +465,7 @@ def generate_email_verification_link(self, email, action_code_settings=None): Raises: ValueError: If the provided arguments are invalid + UserNotFoundError: If no user exists for the specified email address. FirebaseError: If an error occurs while generating the link """ return self._user_manager.generate_email_action_link( diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index d8e49b1a1..50c52812e 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -351,6 +351,15 @@ def __init__(self, message, cause=None, http_response=None): exceptions.NotFoundError.__init__(self, message, cause, http_response) +class EmailNotFoundError(exceptions.NotFoundError): + """No user record found for the specified email.""" + + default_message = 'No user record found for the given email' + + def __init__(self, message, cause=None, http_response=None): + exceptions.NotFoundError.__init__(self, message, cause, http_response) + + class TenantNotFoundError(exceptions.NotFoundError): """No tenant found for the specified identifier.""" @@ -381,6 +390,7 @@ def __init__(self, message, cause=None, http_response=None): 'DUPLICATE_EMAIL': EmailAlreadyExistsError, 'DUPLICATE_LOCAL_ID': UidAlreadyExistsError, 'EMAIL_EXISTS': EmailAlreadyExistsError, + 'EMAIL_NOT_FOUND': EmailNotFoundError, 'INSUFFICIENT_PERMISSION': InsufficientPermissionError, 'INVALID_DYNAMIC_LINK_DOMAIN': InvalidDynamicLinkDomainError, 'INVALID_ID_TOKEN': InvalidIdTokenError, diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index 5154bb495..ed9829aca 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -39,6 +39,7 @@ 'ConfigurationNotFoundError', 'DELETE_ATTRIBUTE', 'EmailAlreadyExistsError', + 'EmailNotFoundError', 'ErrorInfo', 'ExpiredIdTokenError', 'ExpiredSessionCookieError', @@ -112,6 +113,7 @@ DELETE_ATTRIBUTE = _user_mgt.DELETE_ATTRIBUTE DeleteUsersResult = _user_mgt.DeleteUsersResult EmailAlreadyExistsError = _auth_utils.EmailAlreadyExistsError +EmailNotFoundError = _auth_utils.EmailNotFoundError ErrorInfo = _user_import.ErrorInfo ExpiredIdTokenError = _token_gen.ExpiredIdTokenError ExpiredSessionCookieError = _token_gen.ExpiredSessionCookieError diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index ac80a92a6..10dfe698f 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -1446,6 +1446,17 @@ def test_api_call_failure(self, user_mgt_app, func): assert excinfo.value.http_response is not None assert excinfo.value.cause is not None + def test_password_reset_non_existing(self, user_mgt_app): + _instrument_user_manager(user_mgt_app, 400, '{"error":{"message": "EMAIL_NOT_FOUND"}}') + with pytest.raises(auth.EmailNotFoundError) as excinfo: + auth.generate_password_reset_link( + 'nonexistent@user', MOCK_ACTION_CODE_SETTINGS, app=user_mgt_app) + error_msg = 'No user record found for the given email (EMAIL_NOT_FOUND).' + assert excinfo.value.code == exceptions.NOT_FOUND + assert str(excinfo.value) == error_msg + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None + @pytest.mark.parametrize('func', [ auth.generate_sign_in_with_email_link, auth.generate_email_verification_link, From 172f200722fd4ac7a1122aecbcd45ad53127eb2b Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Thu, 17 Jun 2021 11:44:38 -0400 Subject: [PATCH 107/226] [chore] Release 5.0.1 (#557) --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index 353643533..4648863f3 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '5.0.0' +__version__ = '5.0.1' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 9ff16bd7d74c554facceaa1c1aaa7f40857cfd48 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Fri, 25 Jun 2021 12:05:28 -0700 Subject: [PATCH 108/226] fix(fcm): A workaround for the concurrency issues in googleapiclient (#558) * fix(fcm): FA workaround for the concurrency issues in googleapiclient * fix: Added test case --- firebase_admin/messaging.py | 9 +++++---- tests/test_messaging.py | 40 ++++++++++++++++++++++++++++++------- 2 files changed, 38 insertions(+), 11 deletions(-) diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index 548bcfc37..95fc03e04 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -330,9 +330,9 @@ def __init__(self, app): 'X-FIREBASE-CLIENT': 'fire-admin-python/{0}'.format(firebase_admin.__version__), } timeout = app.options.get('httpTimeout', _http_client.DEFAULT_TIMEOUT_SECONDS) - self._client = _http_client.JsonHttpClient( - credential=app.credential.get_credential(), timeout=timeout) - self._transport = _auth.authorized_http(app.credential.get_credential()) + self._credential = app.credential.get_credential() + self._client = _http_client.JsonHttpClient(credential=self._credential, timeout=timeout) + self._build_transport = _auth.authorized_http @classmethod def encode_message(cls, message): @@ -373,10 +373,11 @@ def batch_callback(_, response, error): batch = http.BatchHttpRequest( callback=batch_callback, batch_uri=_MessagingService.FCM_BATCH_URL) + transport = self._build_transport(self._credential) for message in messages: body = json.dumps(self._message_data(message, dry_run)) req = http.HttpRequest( - http=self._transport, + http=transport, postproc=self._postproc, uri=self._fcm_url, method='POST', diff --git a/tests/test_messaging.py b/tests/test_messaging.py index 8eb24c0a9..3d8740cc1 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -1813,20 +1813,23 @@ def teardown_class(cls): testutils.cleanup_apps() def _instrument_batch_messaging_service(self, app=None, status=200, payload='', exc=None): - if not app: - app = firebase_admin.get_app() + def build_mock_transport(_): + if exc: + return _HttpMockException(exc) - fcm_service = messaging._get_messaging_service(app) - if exc: - fcm_service._transport = _HttpMockException(exc) - else: if status == 200: content_type = 'multipart/mixed; boundary=boundary' else: content_type = 'application/json' - fcm_service._transport = http.HttpMockSequence([ + return http.HttpMockSequence([ ({'status': str(status), 'content-type': content_type}, payload), ]) + + if not app: + app = firebase_admin.get_app() + + fcm_service = messaging._get_messaging_service(app) + fcm_service._build_transport = build_mock_transport return fcm_service def _batch_payload(self, payloads): @@ -2053,6 +2056,29 @@ def test_send_all_runtime_exception(self): assert excinfo.value.cause is exc assert excinfo.value.http_response is None + def test_send_transport_init(self): + def track_call_count(build_transport): + def wrapper(credential): + wrapper.calls += 1 + return build_transport(credential) + wrapper.calls = 0 + return wrapper + + payload = json.dumps({'name': 'message-id'}) + fcm_service = self._instrument_batch_messaging_service( + payload=self._batch_payload([(200, payload), (200, payload)])) + build_mock_transport = fcm_service._build_transport + fcm_service._build_transport = track_call_count(build_mock_transport) + msg = messaging.Message(topic='foo') + + batch_response = messaging.send_all([msg, msg], dry_run=True) + assert batch_response.success_count == 2 + assert fcm_service._build_transport.calls == 1 + + batch_response = messaging.send_all([msg, msg], dry_run=True) + assert batch_response.success_count == 2 + assert fcm_service._build_transport.calls == 2 + class TestSendMulticast(TestBatch): From fb64981f08228aae2c3742c4e507d9793c16c17e Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Mon, 12 Jul 2021 10:29:32 -0700 Subject: [PATCH 109/226] chore: Configuring GitHub issue templates (#562) --- .../ISSUE_TEMPLATE/bug_report.md | 22 ++++++++++++++----- .github/ISSUE_TEMPLATE/feature_request.md | 20 +++++++++++++++++ 2 files changed, 37 insertions(+), 5 deletions(-) rename ISSUE_TEMPLATE.md => .github/ISSUE_TEMPLATE/bug_report.md (62%) create mode 100644 .github/ISSUE_TEMPLATE/feature_request.md diff --git a/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE/bug_report.md similarity index 62% rename from ISSUE_TEMPLATE.md rename to .github/ISSUE_TEMPLATE/bug_report.md index 5de83b2cc..2970d494f 100644 --- a/ISSUE_TEMPLATE.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -1,10 +1,21 @@ +--- +name: Bug report +about: Bug reports related to any component in this repo +title: '' +labels: '' +assignees: '' + +--- + ### [READ] Step 1: Are you in the right place? - * For issues or feature requests related to __the code in this repository__ - file a Github issue. - * If this is a __feature request__ make sure the issue title starts with "FR:". + * For issues related to __the code in this repository__ file a GitHub issue. + * If the issue pertains to __Cloud Firestore__, report directly in the + [Python Firestore](https://github.com/googleapis/python-firestore) GitHub repo. Firestore + bugs reported in this repo will be closed with a reference to the Python Firestore + project. * For general technical questions, post a question on [StackOverflow](http://stackoverflow.com/) - with the firebase tag. + with the `firebase` tag. * For general Firebase discussion, use the [firebase-talk](https://groups.google.com/forum/#!forum/firebase-talk) google group. * For help troubleshooting your application that does not fall under one @@ -15,8 +26,9 @@ * Operating System version: _____ * Firebase SDK version: _____ - * Library version: _____ * Firebase Product: _____ (auth, database, storage, etc) + * Python version: _____ + * Pip version: _____ ### [REQUIRED] Step 3: Describe the problem diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 000000000..7729d13a4 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,20 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: "[FR]" +labels: 'type: feature request' +assignees: '' + +--- + +**Is your feature request related to a problem? Please describe.** +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +**Describe the solution you'd like** +A clear and concise description of what you want to happen. + +**Describe alternatives you've considered** +A clear and concise description of any alternative solutions or features you've considered. + +**Additional context** +Add any other context, code samples or screenshots about the feature request here. From 0e35c9a6ce7bc4bd8034c5369e601cb440cff061 Mon Sep 17 00:00:00 2001 From: bojeil-google Date: Wed, 4 Aug 2021 13:35:35 -0700 Subject: [PATCH 110/226] fix(auth): check if user disabled on check_revoked (#565) * fix(auth): check if user disabled on check_revoked When `verify_session_cookie` or `verify_id_token` is called with `check_revoked` set to `True` we should also check if the user is disabled. If disabled the `UserDisabledError` is raised. --- firebase_admin/_auth_client.py | 12 +++++-- firebase_admin/_auth_utils.py | 9 +++++ firebase_admin/auth.py | 15 +++++++-- integration/test_auth.py | 36 ++++++++++++++++++++ snippets/auth/index.py | 6 ++++ tests/test_token_gen.py | 61 ++++++++++++++++++++++++++++++++++ 6 files changed, 133 insertions(+), 6 deletions(-) diff --git a/firebase_admin/_auth_client.py b/firebase_admin/_auth_client.py index a58dbef74..4418a034d 100644 --- a/firebase_admin/_auth_client.py +++ b/firebase_admin/_auth_client.py @@ -100,7 +100,8 @@ def verify_id_token(self, id_token, check_revoked=False): Args: id_token: A string of the encoded JWT. - check_revoked: Boolean, If true, checks whether the token has been revoked (optional). + check_revoked: Boolean, If true, checks whether the token has been revoked or + the user disabled (optional). Returns: dict: A dictionary of key-value pairs parsed from the decoded JWT. @@ -115,6 +116,8 @@ def verify_id_token(self, id_token, check_revoked=False): this ``Client`` instance. CertificateFetchError: If an error occurs while fetching the public key certificates required to verify the ID token. + UserDisabledError: If ``check_revoked`` is ``True`` and the corresponding user + record is disabled. """ if not isinstance(check_revoked, bool): # guard against accidental wrong assignment. @@ -129,7 +132,8 @@ def verify_id_token(self, id_token, check_revoked=False): 'Invalid tenant ID: {0}'.format(token_tenant_id)) if check_revoked: - self._check_jwt_revoked(verified_claims, _token_gen.RevokedIdTokenError, 'ID token') + self._check_jwt_revoked_or_disabled( + verified_claims, _token_gen.RevokedIdTokenError, 'ID token') return verified_claims def revoke_refresh_tokens(self, uid): @@ -720,7 +724,9 @@ def list_saml_provider_configs( """ return self._provider_manager.list_saml_provider_configs(page_token, max_results) - def _check_jwt_revoked(self, verified_claims, exc_type, label): + def _check_jwt_revoked_or_disabled(self, verified_claims, exc_type, label): user = self.get_user(verified_claims.get('uid')) + if user.disabled: + raise _auth_utils.UserDisabledError('The user record is disabled.') if verified_claims.get('iat') * 1000 < user.tokens_valid_after_timestamp: raise exc_type('The Firebase {0} has been revoked.'.format(label)) diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index 50c52812e..e368342e8 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -385,6 +385,15 @@ def __init__(self, message, cause=None, http_response=None): exceptions.NotFoundError.__init__(self, message, cause, http_response) +class UserDisabledError(exceptions.InvalidArgumentError): + """An operation failed due to a user record being disabled.""" + + default_message = 'The user record is disabled' + + def __init__(self, message, cause=None, http_response=None): + exceptions.InvalidArgumentError.__init__(self, message, cause, http_response) + + _CODE_TO_EXC_TYPE = { 'CONFIGURATION_NOT_FOUND': ConfigurationNotFoundError, 'DUPLICATE_EMAIL': EmailAlreadyExistsError, diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index ed9829aca..40a5b611f 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -62,6 +62,7 @@ 'TokenSignError', 'UidAlreadyExistsError', 'UnexpectedResponseError', + 'UserDisabledError', 'UserImportHash', 'UserImportResult', 'UserInfo', @@ -135,6 +136,7 @@ TokenSignError = _token_gen.TokenSignError UidAlreadyExistsError = _auth_utils.UidAlreadyExistsError UnexpectedResponseError = _auth_utils.UnexpectedResponseError +UserDisabledError = _auth_utils.UserDisabledError UserImportHash = _user_import.UserImportHash UserImportResult = _user_import.UserImportResult UserInfo = _user_mgt.UserInfo @@ -198,7 +200,8 @@ def verify_id_token(id_token, app=None, check_revoked=False): Args: id_token: A string of the encoded JWT. app: An App instance (optional). - check_revoked: Boolean, If true, checks whether the token has been revoked (optional). + check_revoked: Boolean, If true, checks whether the token has been revoked or + the user disabled (optional). Returns: dict: A dictionary of key-value pairs parsed from the decoded JWT. @@ -210,6 +213,8 @@ def verify_id_token(id_token, app=None, check_revoked=False): RevokedIdTokenError: If ``check_revoked`` is ``True`` and the ID token has been revoked. CertificateFetchError: If an error occurs while fetching the public key certificates required to verify the ID token. + UserDisabledError: If ``check_revoked`` is ``True`` and the corresponding user + record is disabled. """ client = _get_client(app) return client.verify_id_token(id_token, check_revoked=check_revoked) @@ -246,7 +251,8 @@ def verify_session_cookie(session_cookie, check_revoked=False, app=None): Args: session_cookie: A session cookie string to verify. - check_revoked: Boolean, if true, checks whether the cookie has been revoked (optional). + check_revoked: Boolean, if true, checks whether the cookie has been revoked or the + user disabled (optional). app: An App instance (optional). Returns: @@ -259,12 +265,15 @@ def verify_session_cookie(session_cookie, check_revoked=False, app=None): RevokedSessionCookieError: If ``check_revoked`` is ``True`` and the cookie has been revoked. CertificateFetchError: If an error occurs while fetching the public key certificates required to verify the session cookie. + UserDisabledError: If ``check_revoked`` is ``True`` and the corresponding user + record is disabled. """ client = _get_client(app) # pylint: disable=protected-access verified_claims = client._token_verifier.verify_session_cookie(session_cookie) if check_revoked: - client._check_jwt_revoked(verified_claims, RevokedSessionCookieError, 'session cookie') + client._check_jwt_revoked_or_disabled( + verified_claims, RevokedSessionCookieError, 'session cookie') return verified_claims diff --git a/integration/test_auth.py b/integration/test_auth.py index 16ae52a86..55ddbb0a0 100644 --- a/integration/test_auth.py +++ b/integration/test_auth.py @@ -569,6 +569,24 @@ def test_verify_id_token_revoked(new_user, api_key): claims = auth.verify_id_token(id_token, check_revoked=True) assert claims['iat'] * 1000 >= user.tokens_valid_after_timestamp +def test_verify_id_token_disabled(new_user, api_key): + custom_token = auth.create_custom_token(new_user.uid) + id_token = _sign_in(custom_token, api_key) + claims = auth.verify_id_token(id_token, check_revoked=True) + + # Disable the user record. + auth.update_user(new_user.uid, disabled=True) + # Verify the ID token without checking revocation. This should + # not raise. + claims = auth.verify_id_token(id_token, check_revoked=False) + assert claims['sub'] == new_user.uid + + # Verify the ID token while checking revocation. This should + # raise an exception. + with pytest.raises(auth.UserDisabledError) as excinfo: + auth.verify_id_token(id_token, check_revoked=True) + assert str(excinfo.value) == 'The user record is disabled.' + def test_verify_session_cookie_revoked(new_user, api_key): custom_token = auth.create_custom_token(new_user.uid) id_token = _sign_in(custom_token, api_key) @@ -591,6 +609,24 @@ def test_verify_session_cookie_revoked(new_user, api_key): claims = auth.verify_session_cookie(session_cookie, check_revoked=True) assert claims['iat'] * 1000 >= user.tokens_valid_after_timestamp +def test_verify_session_cookie_disabled(new_user, api_key): + custom_token = auth.create_custom_token(new_user.uid) + id_token = _sign_in(custom_token, api_key) + session_cookie = auth.create_session_cookie(id_token, expires_in=datetime.timedelta(days=1)) + + # Disable the user record. + auth.update_user(new_user.uid, disabled=True) + # Verify the session cookie without checking revocation. This should + # not raise. + claims = auth.verify_session_cookie(session_cookie, check_revoked=False) + assert claims['sub'] == new_user.uid + + # Verify the session cookie while checking revocation. This should + # raise an exception. + with pytest.raises(auth.UserDisabledError) as excinfo: + auth.verify_session_cookie(session_cookie, check_revoked=True) + assert str(excinfo.value) == 'The user record is disabled.' + def test_import_users(): uid, email = _random_id() user = auth.ImportUserRecord(uid=uid, email=email) diff --git a/snippets/auth/index.py b/snippets/auth/index.py index 9de9cfa03..9d6f29ebd 100644 --- a/snippets/auth/index.py +++ b/snippets/auth/index.py @@ -150,6 +150,9 @@ def verify_token_uid_check_revoke(id_token): except auth.RevokedIdTokenError: # Token revoked, inform the user to reauthenticate or signOut(). pass + except auth.UserDisabledError: + # Token belongs to a disabled user record. + pass except auth.InvalidIdTokenError: # Token is invalid pass @@ -1027,6 +1030,9 @@ def verify_id_token_and_check_revoked_tenant(tenant_client, id_token): except auth.RevokedIdTokenError: # Token revoked, inform the user to reauthenticate or signOut(). pass + except auth.UserDisabledError: + # Token belongs to a disabled user record. + pass except auth.InvalidIdTokenError: # Token is invalid pass diff --git a/tests/test_token_gen.py b/tests/test_token_gen.py index 0a09862ab..00b7956fa 100644 --- a/tests/test_token_gen.py +++ b/tests/test_token_gen.py @@ -208,6 +208,19 @@ def revoked_tokens(): mock_user['users'][0]['validSince'] = str(int(time.time())+100) return json.dumps(mock_user) +@pytest.fixture(scope='module') +def user_disabled(): + mock_user = json.loads(testutils.resource('get_user.json')) + mock_user['users'][0]['disabled'] = True + return json.dumps(mock_user) + +@pytest.fixture(scope='module') +def user_disabled_and_revoked(): + mock_user = json.loads(testutils.resource('get_user.json')) + mock_user['users'][0]['disabled'] = True + mock_user['users'][0]['validSince'] = str(int(time.time())+100) + return json.dumps(mock_user) + class TestCreateCustomToken: @@ -471,6 +484,23 @@ def test_revoked_token_check_revoked(self, user_mgt_app, revoked_tokens, id_toke auth.verify_id_token(id_token, app=user_mgt_app, check_revoked=True) assert str(excinfo.value) == 'The Firebase ID token has been revoked.' + @pytest.mark.parametrize('id_token', valid_tokens.values(), ids=list(valid_tokens)) + def test_disabled_user_check_revoked(self, user_mgt_app, user_disabled, id_token): + _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) + _instrument_user_manager(user_mgt_app, 200, user_disabled) + with pytest.raises(auth.UserDisabledError) as excinfo: + auth.verify_id_token(id_token, app=user_mgt_app, check_revoked=True) + assert str(excinfo.value) == 'The user record is disabled.' + + @pytest.mark.parametrize('id_token', valid_tokens.values(), ids=list(valid_tokens)) + def test_check_disabled_before_revoked( + self, user_mgt_app, user_disabled_and_revoked, id_token): + _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) + _instrument_user_manager(user_mgt_app, 200, user_disabled_and_revoked) + with pytest.raises(auth.UserDisabledError) as excinfo: + auth.verify_id_token(id_token, app=user_mgt_app, check_revoked=True) + assert str(excinfo.value) == 'The user record is disabled.' + @pytest.mark.parametrize('arg', INVALID_BOOLS) def test_invalid_check_revoked(self, user_mgt_app, arg): _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) @@ -485,6 +515,14 @@ def test_revoked_token_do_not_check_revoked(self, user_mgt_app, revoked_tokens, assert claims['admin'] is True assert claims['uid'] == claims['sub'] + @pytest.mark.parametrize('id_token', valid_tokens.values(), ids=list(valid_tokens)) + def test_disabled_user_do_not_check_revoked(self, user_mgt_app, user_disabled, id_token): + _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) + _instrument_user_manager(user_mgt_app, 200, user_disabled) + claims = auth.verify_id_token(id_token, app=user_mgt_app, check_revoked=False) + assert claims['admin'] is True + assert claims['uid'] == claims['sub'] + @pytest.mark.parametrize('id_token', INVALID_JWT_ARGS.values(), ids=list(INVALID_JWT_ARGS)) def test_invalid_arg(self, user_mgt_app, id_token): _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) @@ -622,6 +660,29 @@ def test_revoked_cookie_does_not_check_revoked(self, user_mgt_app, revoked_token _instrument_user_manager(user_mgt_app, 200, revoked_tokens) self._assert_valid_cookie(cookie, app=user_mgt_app, check_revoked=False) + @pytest.mark.parametrize('cookie', valid_cookies.values(), ids=list(valid_cookies)) + def test_disabled_user_check_revoked(self, user_mgt_app, user_disabled, cookie): + _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) + _instrument_user_manager(user_mgt_app, 200, user_disabled) + with pytest.raises(auth.UserDisabledError) as excinfo: + auth.verify_session_cookie(cookie, app=user_mgt_app, check_revoked=True) + assert str(excinfo.value) == 'The user record is disabled.' + + @pytest.mark.parametrize('cookie', valid_cookies.values(), ids=list(valid_cookies)) + def test_check_disabled_before_revoked( + self, user_mgt_app, user_disabled_and_revoked, cookie): + _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) + _instrument_user_manager(user_mgt_app, 200, user_disabled_and_revoked) + with pytest.raises(auth.UserDisabledError) as excinfo: + auth.verify_session_cookie(cookie, app=user_mgt_app, check_revoked=True) + assert str(excinfo.value) == 'The user record is disabled.' + + @pytest.mark.parametrize('cookie', valid_cookies.values(), ids=list(valid_cookies)) + def test_disabled_user_does_not_check_revoked(self, user_mgt_app, user_disabled, cookie): + _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) + _instrument_user_manager(user_mgt_app, 200, user_disabled) + self._assert_valid_cookie(cookie, app=user_mgt_app, check_revoked=False) + @pytest.mark.parametrize('cookie', INVALID_JWT_ARGS.values(), ids=list(INVALID_JWT_ARGS)) def test_invalid_args(self, user_mgt_app, cookie): _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) From 01db7eb8da6094e09fc0311930718deec5ccd4ad Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Fri, 13 Aug 2021 14:11:12 -0400 Subject: [PATCH 111/226] [chore] Release 5.0.2 (#567) --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index 4648863f3..0fbd30de4 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '5.0.1' +__version__ = '5.0.2' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 52eb94a961e5d72eae74fb60d23335561bd76815 Mon Sep 17 00:00:00 2001 From: David Buxton Date: Wed, 8 Sep 2021 22:01:08 +0100 Subject: [PATCH 112/226] Speed up the PageIterator by evaluating items once per page (#572) The `firebase_admin.auth.list_users().iterate_all()` method uses a sub-class of PageIterator, which happens to access the .items computed property more than once for every page of results. This has been changed so we take care not to access the `self.items` property more than once per page. This is a lot faster. --- firebase_admin/_auth_utils.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index e368342e8..02d32b659 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -43,30 +43,32 @@ class PageIterator: def __init__(self, current_page): if not current_page: raise ValueError('Current page must not be None.') + self._current_page = current_page - self._index = 0 + self._iter = None + + def __next__(self): + if self._iter is None: + self._iter = iter(self.items) - def next(self): - if self._index == len(self.items): + try: + return next(self._iter) + except StopIteration: if self._current_page.has_next_page: self._current_page = self._current_page.get_next_page() - self._index = 0 - if self._index < len(self.items): - result = self.items[self._index] - self._index += 1 - return result - raise StopIteration + self._iter = iter(self.items) - @property - def items(self): - raise NotImplementedError + return next(self._iter) - def __next__(self): - return self.next() + raise def __iter__(self): return self + @property + def items(self): + raise NotImplementedError + def get_emulator_host(): emulator_host = os.getenv(EMULATOR_HOST_ENV_VAR, '') From fccf814d3098b69ecfdac85c750426a804b00e98 Mon Sep 17 00:00:00 2001 From: jimcasteleiro Date: Thu, 16 Sep 2021 13:13:30 -0400 Subject: [PATCH 113/226] Allows google-api-core[grpc] versions 2.X.X (#576) --- requirements.txt | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 131b65f8b..0dd529c04 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ pytest-cov >= 2.4.0 pytest-localserver >= 0.4.1 cachecontrol >= 0.12.6 -google-api-core[grpc] >= 1.22.1, < 2.0.0dev; platform.python_implementation != 'PyPy' +google-api-core[grpc] >= 1.22.1, < 3.0.0dev; platform.python_implementation != 'PyPy' google-api-python-client >= 1.7.8 google-cloud-firestore >= 2.1.0; platform.python_implementation != 'PyPy' google-cloud-storage >= 1.37.1 diff --git a/setup.py b/setup.py index 83b7291df..6b47b2214 100644 --- a/setup.py +++ b/setup.py @@ -38,7 +38,7 @@ 'to integrate Firebase into their services and applications.') install_requires = [ 'cachecontrol>=0.12.6', - 'google-api-core[grpc] >= 1.22.1, < 2.0.0dev; platform.python_implementation != "PyPy"', + 'google-api-core[grpc] >= 1.22.1, < 3.0.0dev; platform.python_implementation != "PyPy"', 'google-api-python-client >= 1.7.8', 'google-cloud-firestore>=2.1.0; platform.python_implementation != "PyPy"', 'google-cloud-storage>=1.37.1', From 1a53b04f1ca489ea588941fd348ca2d0082d631c Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Thu, 23 Sep 2021 16:45:01 -0400 Subject: [PATCH 114/226] [chore] Release 5.0.3 (#580) --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index 0fbd30de4..c83f05ae6 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '5.0.2' +__version__ = '5.0.3' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From f38c5f7f4401f8235c7f1c289b92d57fe2cb6b54 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Mon, 4 Oct 2021 10:18:05 -0700 Subject: [PATCH 115/226] fix: Extracting GAPIC API calls into a new module (#581) --- firebase_admin/_gapic_utils.py | 122 +++++++++++++++++++++++++++++++++ firebase_admin/_utils.py | 101 --------------------------- firebase_admin/messaging.py | 3 +- tests/test_exceptions.py | 29 ++++---- 4 files changed, 140 insertions(+), 115 deletions(-) create mode 100644 firebase_admin/_gapic_utils.py diff --git a/firebase_admin/_gapic_utils.py b/firebase_admin/_gapic_utils.py new file mode 100644 index 000000000..3c975808c --- /dev/null +++ b/firebase_admin/_gapic_utils.py @@ -0,0 +1,122 @@ +# Copyright 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Internal utilities for interacting with Google API client.""" + +import io +import socket + +import googleapiclient +import httplib2 +import requests + +from firebase_admin import exceptions +from firebase_admin import _utils + + +def handle_platform_error_from_googleapiclient(error, handle_func=None): + """Constructs a ``FirebaseError`` from the given googleapiclient error. + + This can be used to handle errors returned by Google Cloud Platform (GCP) APIs. + + Args: + error: An error raised by the googleapiclient while making an HTTP call to a GCP API. + handle_func: A function that can be used to handle platform errors in a custom way. When + specified, this function will be called with three arguments. It has the same + signature as ```_handle_func_googleapiclient``, but may return ``None``. + + Returns: + FirebaseError: A ``FirebaseError`` that can be raised to the user code. + """ + if not isinstance(error, googleapiclient.errors.HttpError): + return handle_googleapiclient_error(error) + + content = error.content.decode() + status_code = error.resp.status + error_dict, message = _utils._parse_platform_error(content, status_code) # pylint: disable=protected-access + http_response = _http_response_from_googleapiclient_error(error) + exc = None + if handle_func: + exc = handle_func(error, message, error_dict, http_response) + + return exc if exc else _handle_func_googleapiclient(error, message, error_dict, http_response) + + +def _handle_func_googleapiclient(error, message, error_dict, http_response): + """Constructs a ``FirebaseError`` from the given GCP error. + + Args: + error: An error raised by the googleapiclient module while making an HTTP call. + message: A message to be included in the resulting ``FirebaseError``. + error_dict: Parsed GCP error response. + http_response: A requests HTTP response object to associate with the exception. + + Returns: + FirebaseError: A ``FirebaseError`` that can be raised to the user code or None. + """ + code = error_dict.get('status') + return handle_googleapiclient_error(error, message, code, http_response) + + +def handle_googleapiclient_error(error, message=None, code=None, http_response=None): + """Constructs a ``FirebaseError`` from the given googleapiclient error. + + This method is agnostic of the remote service that produced the error, whether it is a GCP + service or otherwise. Therefore, this method does not attempt to parse the error response in + any way. + + Args: + error: An error raised by the googleapiclient module while making an HTTP call. + message: A message to be included in the resulting ``FirebaseError`` (optional). If not + specified the string representation of the ``error`` argument is used as the message. + code: A GCP error code that will be used to determine the resulting error type (optional). + If not specified the HTTP status code on the error response is used to determine a + suitable error code. + http_response: A requests HTTP response object to associate with the exception (optional). + If not specified, one will be created from the ``error``. + + Returns: + FirebaseError: A ``FirebaseError`` that can be raised to the user code. + """ + if isinstance(error, socket.timeout) or ( + isinstance(error, socket.error) and 'timed out' in str(error)): + return exceptions.DeadlineExceededError( + message='Timed out while making an API call: {0}'.format(error), + cause=error) + if isinstance(error, httplib2.ServerNotFoundError): + return exceptions.UnavailableError( + message='Failed to establish a connection: {0}'.format(error), + cause=error) + if not isinstance(error, googleapiclient.errors.HttpError): + return exceptions.UnknownError( + message='Unknown error while making a remote service call: {0}'.format(error), + cause=error) + + if not code: + code = _utils._http_status_to_error_code(error.resp.status) # pylint: disable=protected-access + if not message: + message = str(error) + if not http_response: + http_response = _http_response_from_googleapiclient_error(error) + + err_type = _utils._error_code_to_exception_type(code) # pylint: disable=protected-access + return err_type(message=message, cause=error, http_response=http_response) + + +def _http_response_from_googleapiclient_error(error): + """Creates a requests HTTP Response object from the given googleapiclient error.""" + resp = requests.models.Response() + resp.raw = io.BytesIO(error.content) + resp.status_code = error.resp.status + return resp diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index 8c640276c..dcfb520d2 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -14,13 +14,9 @@ """Internal utilities common to all modules.""" -import io import json -import socket import google.auth -import googleapiclient -import httplib2 import requests import firebase_admin @@ -206,103 +202,6 @@ def handle_requests_error(error, message=None, code=None): return err_type(message=message, cause=error, http_response=error.response) -def handle_platform_error_from_googleapiclient(error, handle_func=None): - """Constructs a ``FirebaseError`` from the given googleapiclient error. - - This can be used to handle errors returned by Google Cloud Platform (GCP) APIs. - - Args: - error: An error raised by the googleapiclient while making an HTTP call to a GCP API. - handle_func: A function that can be used to handle platform errors in a custom way. When - specified, this function will be called with three arguments. It has the same - signature as ```_handle_func_googleapiclient``, but may return ``None``. - - Returns: - FirebaseError: A ``FirebaseError`` that can be raised to the user code. - """ - if not isinstance(error, googleapiclient.errors.HttpError): - return handle_googleapiclient_error(error) - - content = error.content.decode() - status_code = error.resp.status - error_dict, message = _parse_platform_error(content, status_code) - http_response = _http_response_from_googleapiclient_error(error) - exc = None - if handle_func: - exc = handle_func(error, message, error_dict, http_response) - - return exc if exc else _handle_func_googleapiclient(error, message, error_dict, http_response) - - -def _handle_func_googleapiclient(error, message, error_dict, http_response): - """Constructs a ``FirebaseError`` from the given GCP error. - - Args: - error: An error raised by the googleapiclient module while making an HTTP call. - message: A message to be included in the resulting ``FirebaseError``. - error_dict: Parsed GCP error response. - http_response: A requests HTTP response object to associate with the exception. - - Returns: - FirebaseError: A ``FirebaseError`` that can be raised to the user code or None. - """ - code = error_dict.get('status') - return handle_googleapiclient_error(error, message, code, http_response) - - -def handle_googleapiclient_error(error, message=None, code=None, http_response=None): - """Constructs a ``FirebaseError`` from the given googleapiclient error. - - This method is agnostic of the remote service that produced the error, whether it is a GCP - service or otherwise. Therefore, this method does not attempt to parse the error response in - any way. - - Args: - error: An error raised by the googleapiclient module while making an HTTP call. - message: A message to be included in the resulting ``FirebaseError`` (optional). If not - specified the string representation of the ``error`` argument is used as the message. - code: A GCP error code that will be used to determine the resulting error type (optional). - If not specified the HTTP status code on the error response is used to determine a - suitable error code. - http_response: A requests HTTP response object to associate with the exception (optional). - If not specified, one will be created from the ``error``. - - Returns: - FirebaseError: A ``FirebaseError`` that can be raised to the user code. - """ - if isinstance(error, socket.timeout) or ( - isinstance(error, socket.error) and 'timed out' in str(error)): - return exceptions.DeadlineExceededError( - message='Timed out while making an API call: {0}'.format(error), - cause=error) - if isinstance(error, httplib2.ServerNotFoundError): - return exceptions.UnavailableError( - message='Failed to establish a connection: {0}'.format(error), - cause=error) - if not isinstance(error, googleapiclient.errors.HttpError): - return exceptions.UnknownError( - message='Unknown error while making a remote service call: {0}'.format(error), - cause=error) - - if not code: - code = _http_status_to_error_code(error.resp.status) - if not message: - message = str(error) - if not http_response: - http_response = _http_response_from_googleapiclient_error(error) - - err_type = _error_code_to_exception_type(code) - return err_type(message=message, cause=error, http_response=http_response) - - -def _http_response_from_googleapiclient_error(error): - """Creates a requests HTTP Response object from the given googleapiclient error.""" - resp = requests.models.Response() - resp.raw = io.BytesIO(error.content) - resp.status_code = error.resp.status - return resp - - def _http_status_to_error_code(status): """Maps an HTTP status to a platform error code.""" return _HTTP_STATUS_TO_ERROR_CODE.get(status, exceptions.UNKNOWN) diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index 95fc03e04..46dd7d410 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -24,6 +24,7 @@ from firebase_admin import _http_client from firebase_admin import _messaging_encoder from firebase_admin import _messaging_utils +from firebase_admin import _gapic_utils from firebase_admin import _utils @@ -466,7 +467,7 @@ def _handle_iid_error(self, error): def _handle_batch_error(self, error): """Handles errors received from the googleapiclient while making batch requests.""" - return _utils.handle_platform_error_from_googleapiclient( + return _gapic_utils.handle_platform_error_from_googleapiclient( error, _MessagingService._build_fcm_error_googleapiclient) @classmethod diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 96072d91b..4347c838a 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -24,6 +24,7 @@ from googleapiclient import errors from firebase_admin import exceptions from firebase_admin import _utils +from firebase_admin import _gapic_utils _NOT_FOUND_ERROR_DICT = { @@ -186,7 +187,7 @@ class TestGoogleApiClient: socket.error('Read timed out') ]) def test_googleapicleint_timeout_error(self, error): - firebase_error = _utils.handle_googleapiclient_error(error) + firebase_error = _gapic_utils.handle_googleapiclient_error(error) assert isinstance(firebase_error, exceptions.DeadlineExceededError) assert str(firebase_error) == 'Timed out while making an API call: {0}'.format(error) assert firebase_error.cause is error @@ -194,7 +195,7 @@ def test_googleapicleint_timeout_error(self, error): def test_googleapiclient_connection_error(self): error = httplib2.ServerNotFoundError('Test error') - firebase_error = _utils.handle_googleapiclient_error(error) + firebase_error = _gapic_utils.handle_googleapiclient_error(error) assert isinstance(firebase_error, exceptions.UnavailableError) assert str(firebase_error) == 'Failed to establish a connection: Test error' assert firebase_error.cause is error @@ -202,7 +203,7 @@ def test_googleapiclient_connection_error(self): def test_unknown_transport_error(self): error = socket.error('Test error') - firebase_error = _utils.handle_googleapiclient_error(error) + firebase_error = _gapic_utils.handle_googleapiclient_error(error) assert isinstance(firebase_error, exceptions.UnknownError) assert str(firebase_error) == 'Unknown error while making a remote service call: Test error' assert firebase_error.cause is error @@ -210,7 +211,7 @@ def test_unknown_transport_error(self): def test_http_response(self): error = self._create_http_error() - firebase_error = _utils.handle_googleapiclient_error(error) + firebase_error = _gapic_utils.handle_googleapiclient_error(error) assert isinstance(firebase_error, exceptions.InternalError) assert str(firebase_error) == str(error) assert firebase_error.cause is error @@ -219,7 +220,7 @@ def test_http_response(self): def test_http_response_with_unknown_status(self): error = self._create_http_error(status=501) - firebase_error = _utils.handle_googleapiclient_error(error) + firebase_error = _gapic_utils.handle_googleapiclient_error(error) assert isinstance(firebase_error, exceptions.UnknownError) assert str(firebase_error) == str(error) assert firebase_error.cause is error @@ -228,7 +229,7 @@ def test_http_response_with_unknown_status(self): def test_http_response_with_message(self): error = self._create_http_error() - firebase_error = _utils.handle_googleapiclient_error( + firebase_error = _gapic_utils.handle_googleapiclient_error( error, message='Explicit error message') assert isinstance(firebase_error, exceptions.InternalError) assert str(firebase_error) == 'Explicit error message' @@ -238,7 +239,7 @@ def test_http_response_with_message(self): def test_http_response_with_code(self): error = self._create_http_error() - firebase_error = _utils.handle_googleapiclient_error( + firebase_error = _gapic_utils.handle_googleapiclient_error( error, code=exceptions.UNAVAILABLE) assert isinstance(firebase_error, exceptions.UnavailableError) assert str(firebase_error) == str(error) @@ -248,7 +249,7 @@ def test_http_response_with_code(self): def test_http_response_with_message_and_code(self): error = self._create_http_error() - firebase_error = _utils.handle_googleapiclient_error( + firebase_error = _gapic_utils.handle_googleapiclient_error( error, message='Explicit error message', code=exceptions.UNAVAILABLE) assert isinstance(firebase_error, exceptions.UnavailableError) assert str(firebase_error) == 'Explicit error message' @@ -258,7 +259,7 @@ def test_http_response_with_message_and_code(self): def test_handle_platform_error(self): error = self._create_http_error(payload=_NOT_FOUND_PAYLOAD) - firebase_error = _utils.handle_platform_error_from_googleapiclient(error) + firebase_error = _gapic_utils.handle_platform_error_from_googleapiclient(error) assert isinstance(firebase_error, exceptions.NotFoundError) assert str(firebase_error) == 'test error' assert firebase_error.cause is error @@ -267,7 +268,7 @@ def test_handle_platform_error(self): def test_handle_platform_error_with_no_response(self): error = socket.error('Test error') - firebase_error = _utils.handle_platform_error_from_googleapiclient(error) + firebase_error = _gapic_utils.handle_platform_error_from_googleapiclient(error) assert isinstance(firebase_error, exceptions.UnknownError) assert str(firebase_error) == 'Unknown error while making a remote service call: Test error' assert firebase_error.cause is error @@ -275,7 +276,7 @@ def test_handle_platform_error_with_no_response(self): def test_handle_platform_error_with_no_error_code(self): error = self._create_http_error(payload='no error code') - firebase_error = _utils.handle_platform_error_from_googleapiclient(error) + firebase_error = _gapic_utils.handle_platform_error_from_googleapiclient(error) assert isinstance(firebase_error, exceptions.InternalError) message = 'Unexpected HTTP response with status: 500; body: no error code' assert str(firebase_error) == message @@ -291,7 +292,8 @@ def _custom_handler(cause, message, error_dict, http_response): invocations.append((cause, message, error_dict, http_response)) return exceptions.InvalidArgumentError('Custom message', cause, http_response) - firebase_error = _utils.handle_platform_error_from_googleapiclient(error, _custom_handler) + firebase_error = _gapic_utils.handle_platform_error_from_googleapiclient( + error, _custom_handler) assert isinstance(firebase_error, exceptions.InvalidArgumentError) assert str(firebase_error) == 'Custom message' @@ -313,7 +315,8 @@ def test_handle_platform_error_with_custom_handler_ignore(self): def _custom_handler(cause, message, error_dict, http_response): invocations.append((cause, message, error_dict, http_response)) - firebase_error = _utils.handle_platform_error_from_googleapiclient(error, _custom_handler) + firebase_error = _gapic_utils.handle_platform_error_from_googleapiclient( + error, _custom_handler) assert isinstance(firebase_error, exceptions.NotFoundError) assert str(firebase_error) == 'test error' From 0a11d07190314cc57edb48395be44df11775f480 Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Tue, 5 Oct 2021 11:22:26 +0530 Subject: [PATCH 116/226] feat(auth): ability to delete provider in auth (#579) * [add] ability to delete provider in auth * [fix] pylint * [fix] tests * [fix] tests * fix comments * fix tests * fix lint * [fix] address comments --- firebase_admin/_auth_client.py | 2 ++ firebase_admin/_auth_utils.py | 9 +++++++++ firebase_admin/_user_mgt.py | 8 ++++++-- integration/test_auth.py | 8 ++++++++ tests/test_user_mgt.py | 17 +++++++++++++++++ 5 files changed, 42 insertions(+), 2 deletions(-) diff --git a/firebase_admin/_auth_client.py b/firebase_admin/_auth_client.py index 4418a034d..27dd5c7ce 100644 --- a/firebase_admin/_auth_client.py +++ b/firebase_admin/_auth_client.py @@ -336,6 +336,8 @@ def update_user(self, uid, **kwargs): # pylint: disable=differing-param-doc valid_since: An integer signifying the seconds since the epoch (optional). This field is set by ``revoke_refresh_tokens`` and it is discouraged to set this field directly. + providers_to_delete: The list of provider IDs to unlink, + eg: 'google.com', 'password', etc. Returns: UserRecord: An updated UserRecord instance for the user. diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index 02d32b659..7aece495e 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -266,6 +266,15 @@ def validate_action_type(action_type): Valid values are {1}'.format(action_type, ', '.join(VALID_EMAIL_ACTION_TYPES))) return action_type +def validate_provider_ids(provider_ids, required=False): + if not provider_ids: + if required: + raise ValueError('Invalid provider IDs. Provider ids should be provided') + return [] + for provider_id in provider_ids: + validate_provider_id(provider_id, True) + return provider_ids + def build_update_mask(params): """Creates an update mask list from the given dictionary.""" mask = [] diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index b60c4d100..c77c4d40d 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -688,7 +688,7 @@ def create_user(self, uid=None, display_name=None, email=None, phone_number=None def update_user(self, uid, display_name=None, email=None, phone_number=None, photo_url=None, password=None, disabled=None, email_verified=None, - valid_since=None, custom_claims=None): + valid_since=None, custom_claims=None, providers_to_delete=None): """Updates an existing user account with the specified properties""" payload = { 'localId': _auth_utils.validate_uid(uid, required=True), @@ -700,6 +700,7 @@ def update_user(self, uid, display_name=None, email=None, phone_number=None, } remove = [] + remove_provider = _auth_utils.validate_provider_ids(providers_to_delete) if display_name is not None: if display_name is DELETE_ATTRIBUTE: remove.append('DISPLAY_NAME') @@ -715,7 +716,7 @@ def update_user(self, uid, display_name=None, email=None, phone_number=None, if phone_number is not None: if phone_number is DELETE_ATTRIBUTE: - payload['deleteProvider'] = ['phone'] + remove_provider.append('phone') else: payload['phoneNumber'] = _auth_utils.validate_phone(phone_number) @@ -726,6 +727,9 @@ def update_user(self, uid, display_name=None, email=None, phone_number=None, custom_claims, dict) else custom_claims payload['customAttributes'] = _auth_utils.validate_custom_claims(json_claims) + if remove_provider: + payload['deleteProvider'] = list(set(remove_provider)) + payload = {k: v for k, v in payload.items() if v is not None} body, http_resp = self._make_request('post', '/accounts:update', json=payload) if not body or not body.get('localId'): diff --git a/integration/test_auth.py b/integration/test_auth.py index 55ddbb0a0..2dd2cb639 100644 --- a/integration/test_auth.py +++ b/integration/test_auth.py @@ -496,6 +496,14 @@ def test_disable_user(new_user_with_params): assert user.disabled is True assert len(user.provider_data) == 1 +def test_remove_provider(new_user_with_provider): + provider_ids = [provider.provider_id for provider in new_user_with_provider.provider_data] + assert 'google.com' in provider_ids + user = auth.update_user(new_user_with_provider, providers_to_delete=['google.com']) + assert user.uid == new_user_with_params.uid + new_provider_ids = [provider.provider_id for provider in user.provider_data] + assert 'google.com' not in new_provider_ids + def test_delete_user(): user = auth.create_user() auth.delete_user(user.uid) diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index 10dfe698f..67447c6ba 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -663,6 +663,23 @@ def test_update_user_valid_since(self, user_mgt_app, arg): request = json.loads(recorder[0].body.decode()) assert request == {'localId': 'testuser', 'validSince': int(arg)} + @pytest.mark.parametrize('arg', [['phone'], ['google.com', 'phone']]) + def test_update_user_delete_provider(self, user_mgt_app, arg): + user_mgt, recorder = _instrument_user_manager(user_mgt_app, 200, '{"localId":"testuser"}') + user_mgt.update_user('testuser', providers_to_delete=arg) + request = json.loads(recorder[0].body.decode()) + assert set(request['deleteProvider']) == set(arg) + + @pytest.mark.parametrize('arg', [[], ['phone'], ['google.com'], ['google.com', 'phone']]) + def test_update_user_delete_provider_and_phone(self, user_mgt_app, arg): + user_mgt, recorder = _instrument_user_manager(user_mgt_app, 200, '{"localId":"testuser"}') + user_mgt.update_user('testuser', + providers_to_delete=arg, + phone_number=auth.DELETE_ATTRIBUTE) + request = json.loads(recorder[0].body.decode()) + assert 'phone' in request['deleteProvider'] + assert len(set(request['deleteProvider'])) == len(request['deleteProvider']) + assert set(arg) - set(request['deleteProvider']) == set() class TestSetCustomUserClaims: From ebf1bcd946fcc2cc97801f418bfebfbc45783f44 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Wed, 6 Oct 2021 11:29:39 -0700 Subject: [PATCH 117/226] fix: Fixing a broken integration test (#582) * fix: Fixing a broken integration test * fix: Fixing another typo --- integration/test_auth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integration/test_auth.py b/integration/test_auth.py index 2dd2cb639..d2d3e8577 100644 --- a/integration/test_auth.py +++ b/integration/test_auth.py @@ -499,8 +499,8 @@ def test_disable_user(new_user_with_params): def test_remove_provider(new_user_with_provider): provider_ids = [provider.provider_id for provider in new_user_with_provider.provider_data] assert 'google.com' in provider_ids - user = auth.update_user(new_user_with_provider, providers_to_delete=['google.com']) - assert user.uid == new_user_with_params.uid + user = auth.update_user(new_user_with_provider.uid, providers_to_delete=['google.com']) + assert user.uid == new_user_with_provider.uid new_provider_ids = [provider.provider_id for provider in user.provider_data] assert 'google.com' not in new_provider_ids From 9f3143c593be72092eb0c2823c5f213b81fc02ec Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Thu, 4 Nov 2021 17:42:30 -0400 Subject: [PATCH 118/226] [chore] Release 5.1.0 (#587) Bumped version to 5.1.0 --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index c83f05ae6..df58224f8 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '5.0.3' +__version__ = '5.1.0' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 348a90dd5cd3851817bac809ed2b593b91d1cece Mon Sep 17 00:00:00 2001 From: Sarmad Gulzar Date: Tue, 30 Nov 2021 20:28:33 +0500 Subject: [PATCH 119/226] Added return type for `bucket` function (#591) --- firebase_admin/storage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/storage.py b/firebase_admin/storage.py index 16f48e273..f3948371c 100644 --- a/firebase_admin/storage.py +++ b/firebase_admin/storage.py @@ -30,7 +30,7 @@ _STORAGE_ATTRIBUTE = '_storage' -def bucket(name=None, app=None): +def bucket(name=None, app=None) -> storage.Bucket: """Returns a handle to a Google Cloud Storage bucket. If the name argument is not provided, uses the 'storageBucket' option specified when From 02596dc149d05fa69c8288c0f3c73cef60ebff41 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Skar=C5=BCy=C5=84ski?= Date: Tue, 30 Nov 2021 19:07:26 +0100 Subject: [PATCH 120/226] correct kwargs documentation in docstrings (#559) --- firebase_admin/_auth_client.py | 4 ++-- firebase_admin/_http_client.py | 2 +- firebase_admin/_messaging_utils.py | 2 +- firebase_admin/auth.py | 4 ++-- firebase_admin/db.py | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/firebase_admin/_auth_client.py b/firebase_admin/_auth_client.py index 27dd5c7ce..0265197d9 100644 --- a/firebase_admin/_auth_client.py +++ b/firebase_admin/_auth_client.py @@ -288,7 +288,7 @@ def create_user(self, **kwargs): # pylint: disable=differing-param-doc """Creates a new user account with the specified properties. Args: - kwargs: A series of keyword arguments (optional). + **kwargs: A series of keyword arguments (optional). Keyword Args: uid: User ID to assign to the newly created user (optional). @@ -316,7 +316,7 @@ def update_user(self, uid, **kwargs): # pylint: disable=differing-param-doc Args: uid: A user ID string. - kwargs: A series of keyword arguments (optional). + **kwargs: A series of keyword arguments (optional). Keyword Args: display_name: The user's display name (optional). Can be removed by explicitly passing diff --git a/firebase_admin/_http_client.py b/firebase_admin/_http_client.py index ae312095b..d259faddf 100644 --- a/firebase_admin/_http_client.py +++ b/firebase_admin/_http_client.py @@ -104,7 +104,7 @@ class call this method to send HTTP requests out. Refer to Args: method: HTTP method name as a string (e.g. get, post). url: URL of the remote endpoint. - kwargs: An additional set of keyword arguments to be passed into the requests API + **kwargs: An additional set of keyword arguments to be passed into the requests API (e.g. json, params, timeout). Returns: diff --git a/firebase_admin/_messaging_utils.py b/firebase_admin/_messaging_utils.py index d25ba5520..64930f1b8 100644 --- a/firebase_admin/_messaging_utils.py +++ b/firebase_admin/_messaging_utils.py @@ -341,7 +341,7 @@ class APNSPayload: Args: aps: A ``messaging.Aps`` instance to be included in the payload. - kwargs: Arbitrary keyword arguments to be included as custom fields in the payload + **kwargs: Arbitrary keyword arguments to be included as custom fields in the payload (optional). """ diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index 40a5b611f..cbaaf6c01 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -416,7 +416,7 @@ def create_user(**kwargs): # pylint: disable=differing-param-doc """Creates a new user account with the specified properties. Args: - kwargs: A series of keyword arguments (optional). + **kwargs: A series of keyword arguments (optional). Keyword Args: uid: User ID to assign to the newly created user (optional). @@ -447,7 +447,7 @@ def update_user(uid, **kwargs): # pylint: disable=differing-param-doc Args: uid: A user ID string. - kwargs: A series of keyword arguments (optional). + **kwargs: A series of keyword arguments (optional). Keyword Args: display_name: The user's display name (optional). Can be removed by explicitly passing diff --git a/firebase_admin/db.py b/firebase_admin/db.py index 1d293bb89..890968796 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -907,7 +907,7 @@ def request(self, method, url, **kwargs): Args: method: HTTP method name as a string (e.g. get, post). url: URL path of the remote endpoint. This will be appended to the server's base URL. - kwargs: An additional set of keyword arguments to be passed into requests API + **kwargs: An additional set of keyword arguments to be passed into requests API (e.g. json, params). Returns: From 008b1d83717c728a102a582883e1b53e48cafe9f Mon Sep 17 00:00:00 2001 From: Ryan Kohler Date: Tue, 14 Dec 2021 04:46:44 -0800 Subject: [PATCH 121/226] feat(auth): enables OIDC auth code flow (#549) Provides an option for developers to specify the OAuth response type for their OIDC provider (either one of these can be set:): - id_token - code (if set, must also set the client secret) RELEASE NOTES: Added support for configuring the authorization code flow for OIDC providers. --- firebase_admin/_auth_client.py | 32 ++++++++++++++++++--- firebase_admin/_auth_providers.py | 47 +++++++++++++++++++++++++++++-- firebase_admin/auth.py | 30 +++++++++++++++++--- integration/test_auth.py | 18 ++++++++++-- tests/test_auth_providers.py | 39 ++++++++++++++++++++++--- 5 files changed, 150 insertions(+), 16 deletions(-) diff --git a/firebase_admin/_auth_client.py b/firebase_admin/_auth_client.py index 0265197d9..eaf491f32 100644 --- a/firebase_admin/_auth_client.py +++ b/firebase_admin/_auth_client.py @@ -514,7 +514,8 @@ def get_oidc_provider_config(self, provider_id): return self._provider_manager.get_oidc_provider_config(provider_id) def create_oidc_provider_config( - self, provider_id, client_id, issuer, display_name=None, enabled=None): + self, provider_id, client_id, issuer, display_name=None, enabled=None, + client_secret=None, id_token_response_type=None, code_response_type=None): """Creates a new OIDC provider config from the given parameters. OIDC provider support requires Google Cloud's Identity Platform (GCIP). To learn more about @@ -528,6 +529,16 @@ def create_oidc_provider_config( This name is also used as the provider label in the Cloud Console. enabled: A boolean indicating whether the provider configuration is enabled or disabled (optional). A user cannot sign in using a disabled provider. + client_secret: A string which sets the client secret for the new provider. + This is required for the code flow. + code_response_type: A boolean which sets whether to enable the code response flow for + the new provider. By default, this is not enabled if no response type is + specified. A client secret must be set for this response type. + Having both the code and ID token response flows is currently not supported. + id_token_response_type: A boolean which sets whether to enable the ID token response + flow for the new provider. By default, this is enabled if no response type is + specified. + Having both the code and ID token response flows is currently not supported. Returns: OIDCProviderConfig: The newly created OIDC provider config instance. @@ -538,10 +549,12 @@ def create_oidc_provider_config( """ return self._provider_manager.create_oidc_provider_config( provider_id, client_id=client_id, issuer=issuer, display_name=display_name, - enabled=enabled) + enabled=enabled, client_secret=client_secret, + id_token_response_type=id_token_response_type, code_response_type=code_response_type) def update_oidc_provider_config( - self, provider_id, client_id=None, issuer=None, display_name=None, enabled=None): + self, provider_id, client_id=None, issuer=None, display_name=None, enabled=None, + client_secret=None, id_token_response_type=None, code_response_type=None): """Updates an existing OIDC provider config with the given parameters. Args: @@ -552,6 +565,16 @@ def update_oidc_provider_config( Pass ``auth.DELETE_ATTRIBUTE`` to delete the current display name. enabled: A boolean indicating whether the provider configuration is enabled or disabled (optional). + client_secret: A string which sets the client secret for the new provider. + This is required for the code flow. + code_response_type: A boolean which sets whether to enable the code response flow for + the new provider. By default, this is not enabled if no response type is specified. + A client secret must be set for this response type. + Having both the code and ID token response flows is currently not supported. + id_token_response_type: A boolean which sets whether to enable the ID token response + flow for the new provider. By default, this is enabled if no response type is + specified. + Having both the code and ID token response flows is currently not supported. Returns: OIDCProviderConfig: The updated OIDC provider config instance. @@ -562,7 +585,8 @@ def update_oidc_provider_config( """ return self._provider_manager.update_oidc_provider_config( provider_id, client_id=client_id, issuer=issuer, display_name=display_name, - enabled=enabled) + enabled=enabled, client_secret=client_secret, + id_token_response_type=id_token_response_type, code_response_type=code_response_type) def delete_oidc_provider_config(self, provider_id): """Deletes the ``OIDCProviderConfig`` with the given ID. diff --git a/firebase_admin/_auth_providers.py b/firebase_admin/_auth_providers.py index 5126c862c..31511f3c5 100644 --- a/firebase_admin/_auth_providers.py +++ b/firebase_admin/_auth_providers.py @@ -59,6 +59,18 @@ def issuer(self): def client_id(self): return self._data['clientId'] + @property + def client_secret(self): + return self._data.get('clientSecret') + + @property + def id_token_response_type(self): + return self._data.get('responseType', {}).get('idToken', False) + + @property + def code_response_type(self): + return self._data.get('responseType', {}).get('code', False) + class SAMLProviderConfig(ProviderConfig): """Represents he SAML auth provider configuration. @@ -179,7 +191,8 @@ def get_oidc_provider_config(self, provider_id): return OIDCProviderConfig(body) def create_oidc_provider_config( - self, provider_id, client_id, issuer, display_name=None, enabled=None): + self, provider_id, client_id, issuer, display_name=None, enabled=None, + client_secret=None, id_token_response_type=None, code_response_type=None): """Creates a new OIDC provider config from the given parameters.""" _validate_oidc_provider_id(provider_id) req = { @@ -191,12 +204,28 @@ def create_oidc_provider_config( if enabled is not None: req['enabled'] = _auth_utils.validate_boolean(enabled, 'enabled') + response_type = {} + if id_token_response_type is False and code_response_type is False: + raise ValueError('At least one response type must be returned.') + if id_token_response_type is not None: + response_type['idToken'] = _auth_utils.validate_boolean( + id_token_response_type, 'id_token_response_type') + if code_response_type is not None: + response_type['code'] = _auth_utils.validate_boolean( + code_response_type, 'code_response_type') + if code_response_type: + req['clientSecret'] = _validate_non_empty_string(client_secret, 'client_secret') + if response_type: + req['responseType'] = response_type + params = 'oauthIdpConfigId={0}'.format(provider_id) body = self._make_request('post', '/oauthIdpConfigs', json=req, params=params) return OIDCProviderConfig(body) def update_oidc_provider_config( - self, provider_id, client_id=None, issuer=None, display_name=None, enabled=None): + self, provider_id, client_id=None, issuer=None, display_name=None, + enabled=None, client_secret=None, id_token_response_type=None, + code_response_type=None): """Updates an existing OIDC provider config with the given parameters.""" _validate_oidc_provider_id(provider_id) req = {} @@ -212,6 +241,20 @@ def update_oidc_provider_config( if issuer: req['issuer'] = _validate_url(issuer, 'issuer') + response_type = {} + if id_token_response_type is False and code_response_type is False: + raise ValueError('At least one response type must be returned.') + if id_token_response_type is not None: + response_type['idToken'] = _auth_utils.validate_boolean( + id_token_response_type, 'id_token_response_type') + if code_response_type is not None: + response_type['code'] = _auth_utils.validate_boolean( + code_response_type, 'code_response_type') + if code_response_type: + req['clientSecret'] = _validate_non_empty_string(client_secret, 'client_secret') + if response_type: + req['responseType'] = response_type + if not req: raise ValueError('At least one parameter must be specified for update.') diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index cbaaf6c01..6902a322f 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -656,7 +656,8 @@ def get_oidc_provider_config(provider_id, app=None): return client.get_oidc_provider_config(provider_id) def create_oidc_provider_config( - provider_id, client_id, issuer, display_name=None, enabled=None, app=None): + provider_id, client_id, issuer, display_name=None, enabled=None, client_secret=None, + id_token_response_type=None, code_response_type=None, app=None): """Creates a new OIDC provider config from the given parameters. OIDC provider support requires Google Cloud's Identity Platform (GCIP). To learn more about @@ -671,6 +672,15 @@ def create_oidc_provider_config( enabled: A boolean indicating whether the provider configuration is enabled or disabled (optional). A user cannot sign in using a disabled provider. app: An App instance (optional). + client_secret: A string which sets the client secret for the new provider. + This is required for the code flow. + code_response_type: A boolean which sets whether to enable the code response flow for the + new provider. By default, this is not enabled if no response type is specified. + A client secret must be set for this response type. + Having both the code and ID token response flows is currently not supported. + id_token_response_type: A boolean which sets whether to enable the ID token response flow + for the new provider. By default, this is enabled if no response type is specified. + Having both the code and ID token response flows is currently not supported. Returns: OIDCProviderConfig: The newly created OIDC provider config instance. @@ -682,11 +692,13 @@ def create_oidc_provider_config( client = _get_client(app) return client.create_oidc_provider_config( provider_id, client_id=client_id, issuer=issuer, display_name=display_name, - enabled=enabled) + enabled=enabled, client_secret=client_secret, id_token_response_type=id_token_response_type, + code_response_type=code_response_type) def update_oidc_provider_config( - provider_id, client_id=None, issuer=None, display_name=None, enabled=None, app=None): + provider_id, client_id=None, issuer=None, display_name=None, enabled=None, + client_secret=None, id_token_response_type=None, code_response_type=None, app=None): """Updates an existing OIDC provider config with the given parameters. Args: @@ -698,6 +710,15 @@ def update_oidc_provider_config( enabled: A boolean indicating whether the provider configuration is enabled or disabled (optional). app: An App instance (optional). + client_secret: A string which sets the client secret for the new provider. + This is required for the code flow. + code_response_type: A boolean which sets whether to enable the code response flow for the + new provider. By default, this is not enabled if no response type is specified. + A client secret must be set for this response type. + Having both the code and ID token response flows is currently not supported. + id_token_response_type: A boolean which sets whether to enable the ID token response flow + for the new provider. By default, this is enabled if no response type is specified. + Having both the code and ID token response flows is currently not supported. Returns: OIDCProviderConfig: The updated OIDC provider config instance. @@ -709,7 +730,8 @@ def update_oidc_provider_config( client = _get_client(app) return client.update_oidc_provider_config( provider_id, client_id=client_id, issuer=issuer, display_name=display_name, - enabled=enabled) + enabled=enabled, client_secret=client_secret, id_token_response_type=id_token_response_type, + code_response_type=code_response_type) def delete_oidc_provider_config(provider_id, app=None): diff --git a/integration/test_auth.py b/integration/test_auth.py index d2d3e8577..1009816eb 100644 --- a/integration/test_auth.py +++ b/integration/test_auth.py @@ -736,6 +736,9 @@ def test_create_oidc_provider_config(oidc_provider): assert oidc_provider.issuer == 'https://oidc.com/issuer' assert oidc_provider.display_name == 'OIDC_DISPLAY_NAME' assert oidc_provider.enabled is True + assert oidc_provider.response_type.id_token is True + assert oidc_provider.response_type.code is False + assert oidc_provider.client_secret is None def test_get_oidc_provider_config(oidc_provider): @@ -746,6 +749,9 @@ def test_get_oidc_provider_config(oidc_provider): assert provider_config.issuer == 'https://oidc.com/issuer' assert provider_config.display_name == 'OIDC_DISPLAY_NAME' assert provider_config.enabled is True + assert provider_config.response_type.id_token is True + assert provider_config.response_type.code is False + assert provider_config.client_secret is None def test_list_oidc_provider_configs(oidc_provider): @@ -767,11 +773,17 @@ def test_update_oidc_provider_config(): client_id='UPDATED_OIDC_CLIENT_ID', issuer='https://oidc.com/updated_issuer', display_name='UPDATED_OIDC_DISPLAY_NAME', - enabled=False) + enabled=False, + client_secret='CLIENT_SECRET', + id_token_response_type=False, + code_response_type=True) assert provider_config.client_id == 'UPDATED_OIDC_CLIENT_ID' assert provider_config.issuer == 'https://oidc.com/updated_issuer' assert provider_config.display_name == 'UPDATED_OIDC_DISPLAY_NAME' assert provider_config.enabled is False + assert provider_config.response_type.id_token is False + assert provider_config.response_type.code is True + assert provider_config.client_secret == 'CLIENT_SECRET' finally: auth.delete_oidc_provider_config(provider_config.provider_id) @@ -863,7 +875,9 @@ def _create_oidc_provider_config(): client_id='OIDC_CLIENT_ID', issuer='https://oidc.com/issuer', display_name='OIDC_DISPLAY_NAME', - enabled=True) + enabled=True, + id_token_response_type=True, + code_response_type=False) def _create_saml_provider_config(): diff --git a/tests/test_auth_providers.py b/tests/test_auth_providers.py index 0947c77ae..b67a8eb96 100644 --- a/tests/test_auth_providers.py +++ b/tests/test_auth_providers.py @@ -79,13 +79,21 @@ class TestOIDCProviderConfig: 'issuer': 'https://oidc.com/issuer', 'display_name': 'oidcProviderName', 'enabled': True, + 'id_token_response_type': True, + 'code_response_type': True, + 'client_secret': 'CLIENT_SECRET', } OIDC_CONFIG_REQUEST = { 'displayName': 'oidcProviderName', 'enabled': True, 'clientId': 'CLIENT_ID', + 'clientSecret': 'CLIENT_SECRET', 'issuer': 'https://oidc.com/issuer', + 'responseType': { + 'code': True, + 'idToken': True, + }, } @pytest.mark.parametrize('provider_id', INVALID_PROVIDER_IDS + ['saml.provider']) @@ -112,6 +120,11 @@ def test_get(self, user_mgt_app): {'issuer': None}, {'issuer': ''}, {'issuer': 'not a url'}, {'display_name': True}, {'enabled': 'true'}, + {'id_token_response_type': 'true'}, {'code_response_type': 'true'}, + {'code_response_type': True, 'client_secret': ''}, + {'code_response_type': True, 'client_secret': True}, + {'code_response_type': True, 'client_secret': None}, + {'code_response_type': False, 'id_token_response_type': False}, ]) def test_create_invalid_args(self, user_mgt_app, invalid_opts): options = dict(self.VALID_CREATE_OPTIONS) @@ -139,9 +152,14 @@ def test_create_minimal(self, user_mgt_app): options = dict(self.VALID_CREATE_OPTIONS) del options['display_name'] del options['enabled'] + del options['client_secret'] + del options['id_token_response_type'] + del options['code_response_type'] want = dict(self.OIDC_CONFIG_REQUEST) del want['displayName'] del want['enabled'] + del want['clientSecret'] + del want['responseType'] provider_config = auth.create_oidc_provider_config(**options, app=user_mgt_app) @@ -159,9 +177,15 @@ def test_create_empty_values(self, user_mgt_app): options = dict(self.VALID_CREATE_OPTIONS) options['display_name'] = '' options['enabled'] = False + options['code_response_type'] = False want = dict(self.OIDC_CONFIG_REQUEST) want['displayName'] = '' want['enabled'] = False + want['responseType'] = { + 'code': False, + 'idToken': True, + } + del want['clientSecret'] provider_config = auth.create_oidc_provider_config(**options, app=user_mgt_app) @@ -181,6 +205,11 @@ def test_create_empty_values(self, user_mgt_app): {'issuer': ''}, {'issuer': 'not a url'}, {'display_name': True}, {'enabled': 'true'}, + {'id_token_response_type': 'true'}, {'code_response_type': 'true'}, + {'code_response_type': True, 'client_secret': ''}, + {'code_response_type': True, 'client_secret': True}, + {'code_response_type': True, 'client_secret': None}, + {'code_response_type': False, 'id_token_response_type': False}, ]) def test_update_invalid_args(self, user_mgt_app, invalid_opts): options = {'provider_id': 'oidc.provider'} @@ -198,7 +227,8 @@ def test_update(self, user_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'PATCH' - mask = ['clientId', 'displayName', 'enabled', 'issuer'] + mask = ['clientId', 'clientSecret', 'displayName', 'enabled', 'issuer', + 'responseType.code', 'responseType.idToken'] assert req.url == '{0}/oauthIdpConfigs/oidc.provider?updateMask={1}'.format( USER_MGT_URLS['PREFIX'], ','.join(mask)) got = json.loads(req.body.decode()) @@ -223,17 +253,18 @@ def test_update_empty_values(self, user_mgt_app): recorder = _instrument_provider_mgt(user_mgt_app, 200, OIDC_PROVIDER_CONFIG_RESPONSE) provider_config = auth.update_oidc_provider_config( - 'oidc.provider', display_name=auth.DELETE_ATTRIBUTE, enabled=False, app=user_mgt_app) + 'oidc.provider', display_name=auth.DELETE_ATTRIBUTE, enabled=False, + id_token_response_type=False, app=user_mgt_app) self._assert_provider_config(provider_config) assert len(recorder) == 1 req = recorder[0] assert req.method == 'PATCH' - mask = ['displayName', 'enabled'] + mask = ['displayName', 'enabled', 'responseType.idToken'] assert req.url == '{0}/oauthIdpConfigs/oidc.provider?updateMask={1}'.format( USER_MGT_URLS['PREFIX'], ','.join(mask)) got = json.loads(req.body.decode()) - assert got == {'displayName': None, 'enabled': False} + assert got == {'displayName': None, 'enabled': False, 'responseType': {'idToken': False}} @pytest.mark.parametrize('provider_id', INVALID_PROVIDER_IDS + ['saml.provider']) def test_delete_invalid_provider_id(self, user_mgt_app, provider_id): From f695072e138c1089cc59cce6ba0caf83fda27710 Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Wed, 15 Dec 2021 15:34:55 -0500 Subject: [PATCH 122/226] Fixing integration tests (#595) Co-authored-by: Ryan Kohler --- integration/test_auth.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/integration/test_auth.py b/integration/test_auth.py index 1009816eb..82974732d 100644 --- a/integration/test_auth.py +++ b/integration/test_auth.py @@ -736,8 +736,8 @@ def test_create_oidc_provider_config(oidc_provider): assert oidc_provider.issuer == 'https://oidc.com/issuer' assert oidc_provider.display_name == 'OIDC_DISPLAY_NAME' assert oidc_provider.enabled is True - assert oidc_provider.response_type.id_token is True - assert oidc_provider.response_type.code is False + assert oidc_provider.id_token_response_type is True + assert oidc_provider.code_response_type is False assert oidc_provider.client_secret is None @@ -749,8 +749,8 @@ def test_get_oidc_provider_config(oidc_provider): assert provider_config.issuer == 'https://oidc.com/issuer' assert provider_config.display_name == 'OIDC_DISPLAY_NAME' assert provider_config.enabled is True - assert provider_config.response_type.id_token is True - assert provider_config.response_type.code is False + assert provider_config.id_token_response_type is True + assert provider_config.code_response_type is False assert provider_config.client_secret is None @@ -781,8 +781,8 @@ def test_update_oidc_provider_config(): assert provider_config.issuer == 'https://oidc.com/updated_issuer' assert provider_config.display_name == 'UPDATED_OIDC_DISPLAY_NAME' assert provider_config.enabled is False - assert provider_config.response_type.id_token is False - assert provider_config.response_type.code is True + assert provider_config.id_token_response_type is False + assert provider_config.code_response_type is True assert provider_config.client_secret == 'CLIENT_SECRET' finally: auth.delete_oidc_provider_config(provider_config.provider_id) From 6dd41aad788d2f1acf88d4ede0ab29e50a81d791 Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Wed, 15 Dec 2021 16:31:04 -0500 Subject: [PATCH 123/226] [chore] Release 5.2.0 (#596) --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index df58224f8..d4c8d76c8 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '5.1.0' +__version__ = '5.2.0' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 684bd248b5aa6d4ae07eae7c1cd6b0e098a1e325 Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Thu, 16 Dec 2021 19:01:55 -0500 Subject: [PATCH 124/226] Add delayed response message for holidays (#597) * Add delayed response message for holidays * Remove the table from the markdown --- .github/ISSUE_TEMPLATE/bug_report.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 2970d494f..f0b0fd7a1 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -7,6 +7,8 @@ assignees: '' --- +**Thank you for submitting your issue. We are operating at reduced capacity from Dec 17 2021 to Jan 4 2022. Please expect delayed responses. For more urgent requests please reach us via our support channels https://firebase.google.com/support** + ### [READ] Step 1: Are you in the right place? * For issues related to __the code in this repository__ file a GitHub issue. From 6ee8528b0fd02da316203be3d56b430ff83fec6e Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Fri, 7 Jan 2022 14:37:02 -0500 Subject: [PATCH 125/226] chore: Update Node.js version in CI workflow (#602) * chore: Update Node.js version in CI workflow - Firebase CLI dropped support for Node.js 10. Updating the Node.js version to 12 for emulator test setup. * Update ci.yml Update the version to current LTS 16 --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d81f932a8..0ae8f6585 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,10 +22,10 @@ jobs: pip install -r requirements.txt - name: Test with pytest run: pytest - - name: Set up Node.js 10 + - name: Set up Node.js 16 uses: actions/setup-node@v1 with: - node-version: 10.x + node-version: 16.x - name: Run integration tests against emulator run: | npm install -g firebase-tools From 7cfb798ca78f706d08a46c3c48028895787d0ea4 Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Fri, 7 Jan 2022 14:40:26 -0500 Subject: [PATCH 126/226] Revert "Add delayed response message for holidays (#597)" (#600) This reverts commit 684bd248b5aa6d4ae07eae7c1cd6b0e098a1e325. --- .github/ISSUE_TEMPLATE/bug_report.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index f0b0fd7a1..2970d494f 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -7,8 +7,6 @@ assignees: '' --- -**Thank you for submitting your issue. We are operating at reduced capacity from Dec 17 2021 to Jan 4 2022. Please expect delayed responses. For more urgent requests please reach us via our support channels https://firebase.google.com/support** - ### [READ] Step 1: Are you in the right place? * For issues related to __the code in this repository__ file a GitHub issue. From 6d826fd15c87db62f7501f569b6e0a762d17a05e Mon Sep 17 00:00:00 2001 From: Andrii Oriekhov Date: Fri, 8 Apr 2022 22:01:03 +0300 Subject: [PATCH 127/226] add GitHub URL for PyPi (#613) --- setup.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/setup.py b/setup.py index 6b47b2214..282c3aa59 100644 --- a/setup.py +++ b/setup.py @@ -50,6 +50,9 @@ description='Firebase Admin Python SDK', long_description=long_description, url=about['__url__'], + project_urls={ + 'Source': 'https://github.com/firebase/firebase-admin-python', + }, author=about['__author__'], license=about['__license__'], keywords='firebase cloud development', From 44a8fde5672828232ffa68267a71eedb270dbb16 Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Fri, 19 Aug 2022 10:34:07 -0400 Subject: [PATCH 128/226] Remove failing nightly tests in project_management (#636) --- integration/test_project_management.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/integration/test_project_management.py b/integration/test_project_management.py index 362515535..b0b7fa52a 100644 --- a/integration/test_project_management.py +++ b/integration/test_project_management.py @@ -128,26 +128,12 @@ def test_android_sha_certificates(android_app): for cert in cert_list: assert cert.name - # Adding the same cert twice should cause an already-exists error. - with pytest.raises(exceptions.AlreadyExistsError) as excinfo: - android_app.add_sha_certificate(project_management.SHACertificate(SHA_256_HASH_2)) - assert 'Requested entity already exists' in str(excinfo.value) - assert excinfo.value.cause is not None - assert excinfo.value.http_response is not None - # Delete all certs and assert that they have all been deleted successfully. for cert in cert_list: android_app.delete_sha_certificate(cert) assert android_app.get_sha_certificates() == [] - # Deleting a nonexistent cert should cause a not-found error. - with pytest.raises(exceptions.NotFoundError) as excinfo: - android_app.delete_sha_certificate(cert_list[0]) - assert 'Requested entity was not found' in str(excinfo.value) - assert excinfo.value.cause is not None - assert excinfo.value.http_response is not None - def test_create_ios_app_already_exists(ios_app): del ios_app From e1c6c6f436daafa5d2b03f7cd93941d3622f8a7e Mon Sep 17 00:00:00 2001 From: Jonathan Edey <55252373+jkyle109@users.noreply.github.com> Date: Wed, 24 Aug 2022 10:40:47 -0400 Subject: [PATCH 129/226] feat(firestore): Async Firestore (#635) * feat(firestore): Expose Async Firestore Client. (#621) * feat(firestore): Expose Async Firestore Client. * fix: Added type hints and defintion wording changes * fix: removed future annotations until Python 3.6 is depreciated. * fix: added missed type and clarifying comment for Python 3.6 type hinting. * fix: lint * Adds integration tests for the Async Firstore module (#623) * Add integration tests for async firstore module * fix: made pytest Python 3.6 compatible * Trigger Integration Tests * fix: correct copyright year * Add code snippets for firestore modules. (#628) * Add code snippets for firestore modules. * fix: clarified snippet names and fixed newline. * fix: Removed var tags. These won't work as I intended it to since html is escaped when using includecode. Co-authored-by: Lahiru Maramba --- firebase_admin/firestore_async.py | 82 ++++++++++++++++ integration/conftest.py | 10 ++ integration/test_firestore_async.py | 53 +++++++++++ requirements.txt | 1 + snippets/firestore/__init__.py | 0 snippets/firestore/firestore.py | 84 ++++++++++++++++ snippets/firestore/firestore_async.py | 132 ++++++++++++++++++++++++++ tests/test_firestore_async.py | 81 ++++++++++++++++ 8 files changed, 443 insertions(+) create mode 100644 firebase_admin/firestore_async.py create mode 100644 integration/test_firestore_async.py create mode 100644 snippets/firestore/__init__.py create mode 100644 snippets/firestore/firestore.py create mode 100644 snippets/firestore/firestore_async.py create mode 100644 tests/test_firestore_async.py diff --git a/firebase_admin/firestore_async.py b/firebase_admin/firestore_async.py new file mode 100644 index 000000000..a63d5a761 --- /dev/null +++ b/firebase_admin/firestore_async.py @@ -0,0 +1,82 @@ +# Copyright 2022 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Cloud Firestore Async module. + +This module contains utilities for asynchronusly accessing the Google Cloud Firestore databases +associated with Firebase apps. This requires the ``google-cloud-firestore`` Python module. +""" + +from typing import Type + +from firebase_admin import ( + App, + _utils, +) +from firebase_admin.credentials import Base + +try: + from google.cloud import firestore # type: ignore # pylint: disable=import-error,no-name-in-module + existing = globals().keys() + for key, value in firestore.__dict__.items(): + if not key.startswith('_') and key not in existing: + globals()[key] = value +except ImportError: + raise ImportError('Failed to import the Cloud Firestore library for Python. Make sure ' + 'to install the "google-cloud-firestore" module.') + +_FIRESTORE_ASYNC_ATTRIBUTE: str = '_firestore_async' + + +def client(app: App = None) -> firestore.AsyncClient: + """Returns an async client that can be used to interact with Google Cloud Firestore. + + Args: + app: An App instance (optional). + + Returns: + google.cloud.firestore.Firestore_Async: A `Firestore Async Client`_. + + Raises: + ValueError: If a project ID is not specified either via options, credentials or + environment variables, or if the specified project ID is not a valid string. + + .. _Firestore Async Client: https://googleapis.dev/python/firestore/latest/client.html + """ + fs_client = _utils.get_app_service( + app, _FIRESTORE_ASYNC_ATTRIBUTE, _FirestoreAsyncClient.from_app) + return fs_client.get() + + +class _FirestoreAsyncClient: + """Holds a Google Cloud Firestore Async Client instance.""" + + def __init__(self, credentials: Type[Base], project: str) -> None: + self._client = firestore.AsyncClient(credentials=credentials, project=project) + + def get(self) -> firestore.AsyncClient: + return self._client + + @classmethod + def from_app(cls, app: App) -> "_FirestoreAsyncClient": + # Replace remove future reference quotes by importing annotations in Python 3.7+ b/238779406 + """Creates a new _FirestoreAsyncClient for the specified app.""" + credentials = app.credential.get_credential() + project = app.project_id + if not project: + raise ValueError( + 'Project ID is required to access Firestore. Either set the projectId option, ' + 'or use service account credentials. Alternatively, set the GOOGLE_CLOUD_PROJECT ' + 'environment variable.') + return _FirestoreAsyncClient(credentials, project) diff --git a/integration/conftest.py b/integration/conftest.py index 169e02d5b..71f53f612 100644 --- a/integration/conftest.py +++ b/integration/conftest.py @@ -15,6 +15,7 @@ """pytest configuration and global fixtures for integration tests.""" import json +import asyncio import pytest import firebase_admin @@ -70,3 +71,12 @@ def api_key(request): 'command-line option.') with open(path) as keyfile: return keyfile.read().strip() + +@pytest.fixture(scope="session") +def event_loop(): + """Create an instance of the default event loop for test session. + This avoids early eventloop closure. + """ + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() diff --git a/integration/test_firestore_async.py b/integration/test_firestore_async.py new file mode 100644 index 000000000..2a5b93217 --- /dev/null +++ b/integration/test_firestore_async.py @@ -0,0 +1,53 @@ +# Copyright 2022 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for firebase_admin.firestore_async module.""" +import datetime +import pytest + +from firebase_admin import firestore_async + +@pytest.mark.asyncio +async def test_firestore_async(): + client = firestore_async.client() + expected = { + 'name': u'Mountain View', + 'country': u'USA', + 'population': 77846, + 'capital': False + } + doc = client.collection('cities').document() + await doc.set(expected) + + data = await doc.get() + assert data.to_dict() == expected + + await doc.delete() + data = await doc.get() + assert data.exists is False + +@pytest.mark.asyncio +async def test_server_timestamp(): + client = firestore_async.client() + expected = { + 'name': u'Mountain View', + 'timestamp': firestore_async.SERVER_TIMESTAMP # pylint: disable=no-member + } + doc = client.collection('cities').document() + await doc.set(expected) + + data = await doc.get() + data = data.to_dict() + assert isinstance(data['timestamp'], datetime.datetime) + await doc.delete() diff --git a/requirements.txt b/requirements.txt index 0dd529c04..87142fe93 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ pylint == 2.3.1 pytest >= 6.2.0 pytest-cov >= 2.4.0 pytest-localserver >= 0.4.1 +pytest-asyncio >= 0.16.0 cachecontrol >= 0.12.6 google-api-core[grpc] >= 1.22.1, < 3.0.0dev; platform.python_implementation != 'PyPy' diff --git a/snippets/firestore/__init__.py b/snippets/firestore/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/snippets/firestore/firestore.py b/snippets/firestore/firestore.py new file mode 100644 index 000000000..18040b742 --- /dev/null +++ b/snippets/firestore/firestore.py @@ -0,0 +1,84 @@ +# Copyright 2022 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from firebase_admin import firestore + +# pylint: disable=invalid-name +def init_firestore_client(): + # [START init_firestore_client] + import firebase_admin + from firebase_admin import firestore + + # Application Default credentials are automatically created. + app = firebase_admin.initialize_app() + db = firestore.client() + # [END init_firestore_client] + +def init_firestore_client_application_default(): + # [START init_firestore_client_application_default] + import firebase_admin + from firebase_admin import credentials + from firebase_admin import firestore + + # Use the application default credentials. + cred = credentials.ApplicationDefault() + + firebase_admin.initialize_app(cred) + db = firestore.client() + # [END init_firestore_client_application_default] + +def init_firestore_client_service_account(): + # [START init_firestore_client_service_account] + import firebase_admin + from firebase_admin import credentials + from firebase_admin import firestore + + # Use a service account. + cred = credentials.Certificate('path/to/serviceAccount.json') + + app = firebase_admin.initialize_app(cred) + + db = firestore.client() + # [END init_firestore_client_service_account] + +def read_data(): + import firebase_admin + from firebase_admin import firestore + + app = firebase_admin.initialize_app() + db = firestore.client() + + # [START read_data] + doc_ref = db.collection('users').document('alovelace') + doc = doc_ref.get() + if doc.exists: + return f'data: {doc.to_dict()}' + return "Document does not exist." + # [END read_data] + +def add_data(): + import firebase_admin + from firebase_admin import firestore + + app = firebase_admin.initialize_app() + db = firestore.client() + + # [START add_data] + doc_ref = db.collection("users").document("alovelace") + doc_ref.set({ + "first": "Ada", + "last": "Lovelace", + "born": 1815 + }) + # [END add_data] diff --git a/snippets/firestore/firestore_async.py b/snippets/firestore/firestore_async.py new file mode 100644 index 000000000..cf815504e --- /dev/null +++ b/snippets/firestore/firestore_async.py @@ -0,0 +1,132 @@ +# Copyright 2022 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio + +from firebase_admin import firestore_async + +# pylint: disable=invalid-name +def init_firestore_async_client(): + # [START init_firestore_async_client] + import firebase_admin + from firebase_admin import firestore_async + + # Application Default credentials are automatically created. + app = firebase_admin.initialize_app() + db = firestore_async.client() + # [END init_firestore_async_client] + +def init_firestore_async_client_application_default(): + # [START init_firestore_async_client_application_default] + import firebase_admin + from firebase_admin import credentials + from firebase_admin import firestore_async + + # Use the application default credentials. + cred = credentials.ApplicationDefault() + + firebase_admin.initialize_app(cred) + db = firestore_async.client() + # [END init_firestore_async_client_application_default] + +def init_firestore_async_client_service_account(): + # [START init_firestore_async_client_service_account] + import firebase_admin + from firebase_admin import credentials + from firebase_admin import firestore_async + + # Use a service account. + cred = credentials.Certificate('path/to/serviceAccount.json') + + app = firebase_admin.initialize_app(cred) + + db = firestore_async.client() + # [END init_firestore_async_client_service_account] + +def close_async_sessions(): + import firebase_admin + from firebase_admin import firestore_async + + # [START close_async_sessions] + app = firebase_admin.initialize_app() + db = firestore_async.client() + + # Perform firestore tasks... + + # Delete app to ensure that all the async sessions are closed gracefully. + firebase_admin.delete_app(app) + # [END close_async_sessions] + +async def read_data(): + import firebase_admin + from firebase_admin import firestore_async + + app = firebase_admin.initialize_app() + db = firestore_async.client() + + # [START read_data] + doc_ref = db.collection('users').document('alovelace') + doc = await doc_ref.get() + if doc.exists: + return f'data: {doc.to_dict()}' + # [END read_data] + +async def add_data(): + import firebase_admin + from firebase_admin import firestore_async + + app = firebase_admin.initialize_app() + db = firestore_async.client() + + # [START add_data] + doc_ref = db.collection("users").document("alovelace") + await doc_ref.set({ + "first": "Ada", + "last": "Lovelace", + "born": 1815 + }) + # [END add_data] + +def firestore_async_client_with_asyncio_eventloop(): + # [START firestore_async_client_with_asyncio_eventloop] + import asyncio + import firebase_admin + from firebase_admin import firestore_async + + app = firebase_admin.initialize_app() + db = firestore_async.client() + + # Create coroutine to add user data. + async def add_data(): + doc_ref = db.collection("users").document("alovelace") + print("Start adding user...") + await doc_ref.set({ + "first": "Ada", + "last": "Lovelace", + "born": 1815 + }) + print("Done adding user!") + + # Another corutine with secondary tasks we want to complete. + async def while_waiting(): + print("Start other tasks...") + await asyncio.sleep(2) + print("Finished with other tasks!") + + # Initialize an eventloop to execute tasks until completion. + loop = asyncio.get_event_loop() + tasks = [add_data(), while_waiting()] + loop.run_until_complete(asyncio.gather(*tasks)) + firebase_admin.delete_app(app) + # [END firestore_async_client_with_asyncio_eventloop] diff --git a/tests/test_firestore_async.py b/tests/test_firestore_async.py new file mode 100644 index 000000000..0fb17c813 --- /dev/null +++ b/tests/test_firestore_async.py @@ -0,0 +1,81 @@ +# Copyright 2022 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for firebase_admin.firestore_async.""" + +import platform + +import pytest + +import firebase_admin +from firebase_admin import credentials +try: + from firebase_admin import firestore_async +except ImportError: + pass +from tests import testutils + + +@pytest.mark.skipif( + platform.python_implementation() == 'PyPy', + reason='Firestore is not supported on PyPy') +class TestFirestoreAsync: + """Test class Firestore Async APIs.""" + + def teardown_method(self, method): + del method + testutils.cleanup_apps() + + def test_no_project_id(self): + def evaluate(): + firebase_admin.initialize_app(testutils.MockCredential()) + with pytest.raises(ValueError): + firestore_async.client() + testutils.run_without_project_id(evaluate) + + def test_project_id(self): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + firebase_admin.initialize_app(cred, {'projectId': 'explicit-project-id'}) + client = firestore_async.client() + assert client is not None + assert client.project == 'explicit-project-id' + + def test_project_id_with_explicit_app(self): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + app = firebase_admin.initialize_app(cred, {'projectId': 'explicit-project-id'}) + client = firestore_async.client(app=app) + assert client is not None + assert client.project == 'explicit-project-id' + + def test_service_account(self): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + firebase_admin.initialize_app(cred) + client = firestore_async.client() + assert client is not None + assert client.project == 'mock-project-id' + + def test_service_account_with_explicit_app(self): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + app = firebase_admin.initialize_app(cred) + client = firestore_async.client(app=app) + assert client is not None + assert client.project == 'mock-project-id' + + def test_geo_point(self): + geo_point = firestore_async.GeoPoint(10, 20) # pylint: disable=no-member + assert geo_point.latitude == 10 + assert geo_point.longitude == 20 + + def test_server_timestamp(self): + assert firestore_async.SERVER_TIMESTAMP is not None # pylint: disable=no-member From 336dbef17d8f835b54ce609dc0352e10114c7645 Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Thu, 25 Aug 2022 14:56:58 -0400 Subject: [PATCH 130/226] [chore] Release 5.3.0 (#637) - Release v 5.3.0 - Adding async firestore API --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index d4c8d76c8..b63347bcd 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '5.2.0' +__version__ = '5.3.0' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 9e5b8e383e066c319a483285df903150d5029a34 Mon Sep 17 00:00:00 2001 From: pragatimodi <110490169+pragatimodi@users.noreply.github.com> Date: Thu, 15 Sep 2022 01:49:21 +0530 Subject: [PATCH 131/226] Bug fix - Changing variable from "MemoryCost" to "cpuMemCost" in standard_scrypt method (#643) * Changing 'memoryCost' to CpuMemCost RELEASE NOTE: Fixed an incorrect key used to set MemoryCost config in standard_scrypt() API * Update Test for standard_scrypt * Update _user_import.py --- firebase_admin/_user_import.py | 4 ++-- tests/test_user_mgt.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/firebase_admin/_user_import.py b/firebase_admin/_user_import.py index 7834b232a..659a68701 100644 --- a/firebase_admin/_user_import.py +++ b/firebase_admin/_user_import.py @@ -454,7 +454,7 @@ def standard_scrypt(cls, memory_cost, parallelization, block_size, derived_key_l """Creates a new standard Scrypt algorithm instance. Args: - memory_cost: Memory cost as a non-negaive integer. + memory_cost: CPU Memory cost as a non-negative integer. parallelization: Parallelization as a non-negative integer. block_size: Block size as a non-negative integer. derived_key_length: Derived key length as a non-negative integer. @@ -463,7 +463,7 @@ def standard_scrypt(cls, memory_cost, parallelization, block_size, derived_key_l UserImportHash: A new ``UserImportHash``. """ data = { - 'memoryCost': _auth_utils.validate_int(memory_cost, 'memory_cost', low=0), + 'cpuMemCost': _auth_utils.validate_int(memory_cost, 'memory_cost', low=0), 'parallelization': _auth_utils.validate_int(parallelization, 'parallelization', low=0), 'blockSize': _auth_utils.validate_int(block_size, 'block_size', low=0), 'dkLen': _auth_utils.validate_int(derived_key_length, 'derived_key_length', low=0), diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index 67447c6ba..b590cca05 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -1212,7 +1212,7 @@ def test_standard_scrypt(self): memory_cost=14, parallelization=2, block_size=10, derived_key_length=128) expected = { 'hashAlgorithm': 'STANDARD_SCRYPT', - 'memoryCost': 14, + 'cpuMemCost': 14, 'parallelization': 2, 'blockSize': 10, 'dkLen': 128, From 32d40ca592d7dceda952547ce8e1d959351bf617 Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Thu, 22 Sep 2022 17:04:30 -0400 Subject: [PATCH 132/226] change: Deprecated support for Python 3.6 (#646) * Deprecated support for Python 3.6 --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 646d3d0e3..9c5101e3f 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,8 @@ requests, code review feedback, and also pull requests. ## Supported Python Versions -We currently support Python 3.6+. Firebase +We currently support Python 3.6+. However, Python 3.6 support is deprecated, +and developers are strongly advised to use Python 3.7 or higher. Firebase Admin Python SDK is also tested on PyPy and [Google App Engine](https://cloud.google.com/appengine/) environments. From 78d2e44e70b3209911f413ff8f5012eadef88590 Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Thu, 22 Sep 2022 18:54:34 -0400 Subject: [PATCH 133/226] [chore] Release 5.4.0 (#647) * [chore] Release 5.4.0 Release v5.4.0 Deprecated Python 3.6 (reached EoL) --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index b63347bcd..b24ab002e 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '5.3.0' +__version__ = '5.4.0' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 0dd630389947ebe81791491170ce6a134fb053ba Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Wed, 28 Sep 2022 16:03:45 -0400 Subject: [PATCH 134/226] change: Drop Python 3.6 support (#645) * Drop support for Python 3.6 * Fix 3.10 * update readme --- .github/workflows/ci.yml | 10 +++++----- .github/workflows/nightly.yml | 6 +++--- .github/workflows/release.yml | 4 ++-- CONTRIBUTING.md | 2 +- README.md | 3 +-- setup.py | 8 ++++---- 6 files changed, 16 insertions(+), 17 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0ae8f6585..d2129720b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,12 +8,12 @@ jobs: strategy: fail-fast: false matrix: - python: [3.6, 3.7, 3.8, 3.9, pypy3] + python: ['3.7', '3.8', '3.9', '3.10', 'pypy3.7'] steps: - - uses: actions/checkout@v1 + - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python }} - uses: actions/setup-python@v1 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python }} - name: Install dependencies @@ -34,9 +34,9 @@ jobs: lint: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v1 + - uses: actions/checkout@v3 - name: Set up Python 3.7 - uses: actions/setup-python@v1 + uses: actions/setup-python@v4 with: python-version: 3.7 - name: Install dependencies diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index f22eb99c8..ac6c62abe 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -29,14 +29,14 @@ jobs: steps: - name: Checkout source for staging - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: ref: ${{ github.event.client_payload.ref || github.ref }} - name: Set up Python - uses: actions/setup-python@v1 + uses: actions/setup-python@v4 with: - python-version: 3.6 + python-version: 3.7 - name: Install dependencies run: | diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index fbde8ed59..5eb4bfaea 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -45,9 +45,9 @@ jobs: ref: ${{ github.event.client_payload.ref || github.ref }} - name: Set up Python - uses: actions/setup-python@v1 + uses: actions/setup-python@v4 with: - python-version: 3.6 + python-version: 3.7 - name: Install dependencies run: | diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 30685394e..1d500cba8 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -85,7 +85,7 @@ information on using pull requests. ### Initial Setup -You need Python 3.6+ to build and test the code in this repo. +You need Python 3.7+ to build and test the code in this repo. We recommend using [pip](https://pypi.python.org/pypi/pip) for installing the necessary tools and project dependencies. Most recent versions of Python ship with pip. If your development environment diff --git a/README.md b/README.md index 9c5101e3f..041c41673 100644 --- a/README.md +++ b/README.md @@ -43,8 +43,7 @@ requests, code review feedback, and also pull requests. ## Supported Python Versions -We currently support Python 3.6+. However, Python 3.6 support is deprecated, -and developers are strongly advised to use Python 3.7 or higher. Firebase +We currently support Python 3.7+. Firebase Admin Python SDK is also tested on PyPy and [Google App Engine](https://cloud.google.com/appengine/) environments. diff --git a/setup.py b/setup.py index 282c3aa59..a54949891 100644 --- a/setup.py +++ b/setup.py @@ -22,8 +22,8 @@ (major, minor) = (sys.version_info.major, sys.version_info.minor) -if major != 3 or minor < 6: - print('firebase_admin requires python >= 3.6', file=sys.stderr) +if major != 3 or minor < 7: + print('firebase_admin requires python >= 3.7', file=sys.stderr) sys.exit(1) # Read in the package metadata per recommendations from: @@ -58,16 +58,16 @@ keywords='firebase cloud development', install_requires=install_requires, packages=['firebase_admin'], - python_requires='>=3.6', + python_requires='>=3.7', classifiers=[ 'Development Status :: 5 - Production/Stable', 'Intended Audience :: Developers', 'Topic :: Software Development :: Build Tools', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', 'License :: OSI Approved :: Apache Software License', ], ) From 5b7ac0558ef89cdd386daddaf55a3d5ee3122a72 Mon Sep 17 00:00:00 2001 From: dwyfrequency Date: Thu, 29 Sep 2022 13:40:51 -0400 Subject: [PATCH 135/226] feat: Add function to verify an App Check token (#642) * Sketch out initial private methods and service * Remove unnecessary notes * Fix some lint issues * Fix style guide issues * Update code structure * Add pyjwt version to requirments & update code based on comments * Add app_id key for verified claims dict * Add initial test * Add tests for token headers * Add decode token test and notes * Updating requirements for mocks and note in test * Add verify token test and decode test * Update pytest-mock requirements * Add tests for error messages * Update requirements for lifespan cache * update error message and test * Explicitly pass audience to jwt.decode and update key retrieval * Mock signing key * Update aud check logic and tests * Remove print statement * Update method doc string * Add test for decode_token error * Catch additional errors and add custom error messages for them * Mock out all the common errors * Updating error messages and tests per comments * Make jwks_client a class property * Add validation for the subject in the JWT payload * Update docs and error message strings --- .gitignore | 1 + firebase_admin/app_check.py | 150 ++++++++++++++++++++ requirements.txt | 2 + tests/test_app_check.py | 275 ++++++++++++++++++++++++++++++++++++ 4 files changed, 428 insertions(+) create mode 100644 firebase_admin/app_check.py create mode 100644 tests/test_app_check.py diff --git a/.gitignore b/.gitignore index 79d2d5ff3..e5c1902d5 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,4 @@ apikey.txt htmlcov/ .pytest_cache/ .vscode/ +.venv/ diff --git a/firebase_admin/app_check.py b/firebase_admin/app_check.py new file mode 100644 index 000000000..91b0c4c31 --- /dev/null +++ b/firebase_admin/app_check.py @@ -0,0 +1,150 @@ +# Copyright 2022 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Firebase App Check module.""" + +from typing import Any, Dict +import jwt +from jwt import PyJWKClient, ExpiredSignatureError, InvalidTokenError +from jwt import InvalidAudienceError, InvalidIssuerError, InvalidSignatureError +from firebase_admin import _utils + +_APP_CHECK_ATTRIBUTE = '_app_check' + +def _get_app_check_service(app) -> Any: + return _utils.get_app_service(app, _APP_CHECK_ATTRIBUTE, _AppCheckService) + +def verify_token(token: str, app=None) -> Dict[str, Any]: + """Verifies a Firebase App Check token. + + Args: + token: A token from App Check. + app: An App instance (optional). + + Returns: + Dict[str, Any]: The token's decoded claims. + + Raises: + ValueError: If the app's ``project_id`` is invalid or unspecified, + or if the token's headers or payload are invalid. + """ + return _get_app_check_service(app).verify_token(token) + +class _AppCheckService: + """Service class that implements Firebase App Check functionality.""" + + _APP_CHECK_ISSUER = 'https://firebaseappcheck.googleapis.com/' + _JWKS_URL = 'https://firebaseappcheck.googleapis.com/v1/jwks' + _project_id = None + _scoped_project_id = None + _jwks_client = None + + def __init__(self, app): + # Validate and store the project_id to validate the JWT claims + self._project_id = app.project_id + if not self._project_id: + raise ValueError( + 'A project ID must be specified to access the App Check ' + 'service. Either set the projectId option, use service ' + 'account credentials, or set the ' + 'GOOGLE_CLOUD_PROJECT environment variable.') + self._scoped_project_id = 'projects/' + app.project_id + # Default lifespan is 300 seconds (5 minutes) so we change it to 21600 seconds (6 hours). + self._jwks_client = PyJWKClient(self._JWKS_URL, lifespan=21600) + + + def verify_token(self, token: str) -> Dict[str, Any]: + """Verifies a Firebase App Check token.""" + _Validators.check_string("app check token", token) + + # Obtain the Firebase App Check Public Keys + # Note: It is not recommended to hard code these keys as they rotate, + # but you should cache them for up to 6 hours. + signing_key = self._jwks_client.get_signing_key_from_jwt(token) + self._has_valid_token_headers(jwt.get_unverified_header(token)) + verified_claims = self._decode_and_verify(token, signing_key.key) + + verified_claims['app_id'] = verified_claims.get('sub') + return verified_claims + + def _has_valid_token_headers(self, headers: Any) -> None: + """Checks whether the token has valid headers for App Check.""" + # Ensure the token's header has type JWT + if headers.get('typ') != 'JWT': + raise ValueError("The provided App Check token has an incorrect type header") + # Ensure the token's header uses the algorithm RS256 + algorithm = headers.get('alg') + if algorithm != 'RS256': + raise ValueError( + 'The provided App Check token has an incorrect alg header. ' + f'Expected RS256 but got {algorithm}.' + ) + + def _decode_and_verify(self, token: str, signing_key: str): + """Decodes and verifies the token from App Check.""" + payload = {} + try: + payload = jwt.decode( + token, + signing_key, + algorithms=["RS256"], + audience=self._scoped_project_id + ) + except InvalidSignatureError: + raise ValueError( + 'The provided App Check token has an invalid signature.' + ) + except InvalidAudienceError: + raise ValueError( + 'The provided App Check token has an incorrect "aud" (audience) claim. ' + f'Expected payload to include {self._scoped_project_id}.' + ) + except InvalidIssuerError: + raise ValueError( + 'The provided App Check token has an incorrect "iss" (issuer) claim. ' + f'Expected claim to include {self._APP_CHECK_ISSUER}' + ) + except ExpiredSignatureError: + raise ValueError( + 'The provided App Check token has expired.' + ) + except InvalidTokenError as exception: + raise ValueError( + f'Decoding App Check token failed. Error: {exception}' + ) + + audience = payload.get('aud') + if not isinstance(audience, list) or self._scoped_project_id not in audience: + raise ValueError('Firebase App Check token has incorrect "aud" (audience) claim.') + if not payload.get('iss').startswith(self._APP_CHECK_ISSUER): + raise ValueError('Token does not contain the correct "iss" (issuer).') + _Validators.check_string( + 'The provided App Check token "sub" (subject) claim', + payload.get('sub')) + + return payload + +class _Validators: + """A collection of data validation utilities. + + Methods provided in this class raise ``ValueErrors`` if any validations fail. + """ + + @classmethod + def check_string(cls, label: str, value: Any): + """Checks if the given value is a string.""" + if value is None: + raise ValueError('{0} "{1}" must be a non-empty string.'.format(label, value)) + if not isinstance(value, str): + raise ValueError('{0} "{1}" must be a string.'.format(label, value)) diff --git a/requirements.txt b/requirements.txt index 87142fe93..c66212673 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,9 +4,11 @@ pytest >= 6.2.0 pytest-cov >= 2.4.0 pytest-localserver >= 0.4.1 pytest-asyncio >= 0.16.0 +pytest-mock >= 3.6.1 cachecontrol >= 0.12.6 google-api-core[grpc] >= 1.22.1, < 3.0.0dev; platform.python_implementation != 'PyPy' google-api-python-client >= 1.7.8 google-cloud-firestore >= 2.1.0; platform.python_implementation != 'PyPy' google-cloud-storage >= 1.37.1 +pyjwt[crypto] >= 2.5.0 \ No newline at end of file diff --git a/tests/test_app_check.py b/tests/test_app_check.py new file mode 100644 index 000000000..168d0a972 --- /dev/null +++ b/tests/test_app_check.py @@ -0,0 +1,275 @@ +# Copyright 2022 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test cases for the firebase_admin.app_check module.""" +import base64 +import pytest + +from jwt import PyJWK, InvalidAudienceError, InvalidIssuerError +from jwt import ExpiredSignatureError, InvalidSignatureError +import firebase_admin +from firebase_admin import app_check +from tests import testutils + +NON_STRING_ARGS = [list(), tuple(), dict(), True, False, 1, 0] + +APP_ID = "1234567890" +PROJECT_ID = "1334" +SCOPED_PROJECT_ID = f"projects/{PROJECT_ID}" +ISSUER = "https://firebaseappcheck.googleapis.com/" +JWT_PAYLOAD_SAMPLE = { + "headers": { + "alg": "RS256", + "typ": "JWT" + }, + "sub": APP_ID, + "name": "John Doe", + "iss": ISSUER, + "aud": [SCOPED_PROJECT_ID] +} + +secret_key = "secret" +signing_key = { + "kty": "oct", + # Using HS256 for simplicity, production key will use RS256 + "alg": "HS256", + "k": base64.urlsafe_b64encode(secret_key.encode()) +} + +class TestBatch: + + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + +class TestVerifyToken(TestBatch): + + def test_no_project_id(self): + def evaluate(): + app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') + with pytest.raises(ValueError): + app_check.verify_token(token="app_check_token", app=app) + testutils.run_without_project_id(evaluate) + + @pytest.mark.parametrize('token', NON_STRING_ARGS) + def test_verify_token_with_non_string_raises_error(self, token): + with pytest.raises(ValueError) as excinfo: + app_check.verify_token(token) + expected = 'app check token "{0}" must be a string.'.format(token) + assert str(excinfo.value) == expected + + def test_has_valid_token_headers(self): + app = firebase_admin.get_app() + app_check_service = app_check._get_app_check_service(app) + + headers = {"alg": "RS256", 'typ': "JWT"} + assert app_check_service._has_valid_token_headers(headers=headers) is None + + def test_has_valid_token_headers_with_incorrect_type_raises_error(self): + app = firebase_admin.get_app() + app_check_service = app_check._get_app_check_service(app) + headers = {"alg": "RS256", 'typ': "WRONG"} + with pytest.raises(ValueError) as excinfo: + app_check_service._has_valid_token_headers(headers=headers) + + expected = 'The provided App Check token has an incorrect type header' + assert str(excinfo.value) == expected + + def test_has_valid_token_headers_with_incorrect_algorithm_raises_error(self): + app = firebase_admin.get_app() + app_check_service = app_check._get_app_check_service(app) + headers = {"alg": "HS256", 'typ': "JWT"} + with pytest.raises(ValueError) as excinfo: + app_check_service._has_valid_token_headers(headers=headers) + + expected = ('The provided App Check token has an incorrect alg header. ' + 'Expected RS256 but got HS256.') + assert str(excinfo.value) == expected + + def test_decode_and_verify(self, mocker): + jwt_decode_mock = mocker.patch("jwt.decode", return_value=JWT_PAYLOAD_SAMPLE) + app = firebase_admin.get_app() + app_check_service = app_check._get_app_check_service(app) + payload = app_check_service._decode_and_verify( + token=None, + signing_key="1234", + ) + + jwt_decode_mock.assert_called_once_with( + None, "1234", algorithms=["RS256"], audience=SCOPED_PROJECT_ID) + assert payload == JWT_PAYLOAD_SAMPLE.copy() + + def test_decode_and_verify_with_incorrect_token_and_key(self): + app = firebase_admin.get_app() + app_check_service = app_check._get_app_check_service(app) + with pytest.raises(ValueError) as excinfo: + app_check_service._decode_and_verify( + token="1232132", + signing_key=signing_key, + ) + + expected = ( + 'Decoding App Check token failed. Error: Not enough segments') + assert str(excinfo.value) == expected + + def test_decode_and_verify_with_expired_token_raises_error(self, mocker): + mocker.patch("jwt.decode", side_effect=ExpiredSignatureError) + app = firebase_admin.get_app() + app_check_service = app_check._get_app_check_service(app) + with pytest.raises(ValueError) as excinfo: + app_check_service._decode_and_verify( + token="1232132", + signing_key=signing_key, + ) + + expected = ( + 'The provided App Check token has expired.') + assert str(excinfo.value) == expected + + def test_decode_and_verify_with_invalid_signature_raises_error(self, mocker): + mocker.patch("jwt.decode", side_effect=InvalidSignatureError) + app = firebase_admin.get_app() + app_check_service = app_check._get_app_check_service(app) + with pytest.raises(ValueError) as excinfo: + app_check_service._decode_and_verify( + token="1232132", + signing_key=signing_key, + ) + + expected = ( + 'The provided App Check token has an invalid signature.') + assert str(excinfo.value) == expected + + def test_decode_and_verify_with_invalid_aud_raises_error(self, mocker): + mocker.patch("jwt.decode", side_effect=InvalidAudienceError) + app = firebase_admin.get_app() + app_check_service = app_check._get_app_check_service(app) + with pytest.raises(ValueError) as excinfo: + app_check_service._decode_and_verify( + token="1232132", + signing_key=signing_key, + ) + + expected = ( + 'The provided App Check token has an incorrect "aud" (audience) claim. ' + f'Expected payload to include {SCOPED_PROJECT_ID}.') + assert str(excinfo.value) == expected + + def test_decode_and_verify_with_invalid_iss_raises_error(self, mocker): + mocker.patch("jwt.decode", side_effect=InvalidIssuerError) + app = firebase_admin.get_app() + app_check_service = app_check._get_app_check_service(app) + with pytest.raises(ValueError) as excinfo: + app_check_service._decode_and_verify( + token="1232132", + signing_key=signing_key, + ) + + expected = ( + 'The provided App Check token has an incorrect "iss" (issuer) claim. ' + f'Expected claim to include {ISSUER}') + assert str(excinfo.value) == expected + + def test_decode_and_verify_with_none_sub_raises_error(self, mocker): + jwt_with_none_sub = JWT_PAYLOAD_SAMPLE.copy() + jwt_with_none_sub['sub'] = None + mocker.patch("jwt.decode", return_value=jwt_with_none_sub) + app = firebase_admin.get_app() + app_check_service = app_check._get_app_check_service(app) + with pytest.raises(ValueError) as excinfo: + app_check_service._decode_and_verify( + token="1232132", + signing_key=signing_key, + ) + + expected = ( + 'The provided App Check token "sub" (subject) claim ' + f'"{None}" must be a non-empty string.') + assert str(excinfo.value) == expected + + def test_decode_and_verify_with_non_string_sub_raises_error(self, mocker): + sub_number = 1234 + jwt_with_none_sub = JWT_PAYLOAD_SAMPLE.copy() + jwt_with_none_sub['sub'] = sub_number + mocker.patch("jwt.decode", return_value=jwt_with_none_sub) + app = firebase_admin.get_app() + app_check_service = app_check._get_app_check_service(app) + with pytest.raises(ValueError) as excinfo: + app_check_service._decode_and_verify( + token="1232132", + signing_key=signing_key, + ) + + expected = ( + 'The provided App Check token "sub" (subject) claim ' + f'"{sub_number}" must be a string.') + assert str(excinfo.value) == expected + + def test_verify_token(self, mocker): + mocker.patch("jwt.decode", return_value=JWT_PAYLOAD_SAMPLE) + mocker.patch("jwt.PyJWKClient.get_signing_key_from_jwt", return_value=PyJWK(signing_key)) + mocker.patch("jwt.get_unverified_header", return_value=JWT_PAYLOAD_SAMPLE.get("headers")) + app = firebase_admin.get_app() + + payload = app_check.verify_token("encoded", app) + expected = JWT_PAYLOAD_SAMPLE.copy() + expected['app_id'] = APP_ID + assert payload == expected + + def test_verify_token_with_non_list_audience_raises_error(self, mocker): + jwt_with_non_list_audience = JWT_PAYLOAD_SAMPLE.copy() + jwt_with_non_list_audience["aud"] = '1234' + mocker.patch("jwt.decode", return_value=jwt_with_non_list_audience) + mocker.patch("jwt.PyJWKClient.get_signing_key_from_jwt", return_value=PyJWK(signing_key)) + mocker.patch("jwt.get_unverified_header", return_value=JWT_PAYLOAD_SAMPLE.get("headers")) + app = firebase_admin.get_app() + + with pytest.raises(ValueError) as excinfo: + app_check.verify_token("encoded", app) + + expected = 'Firebase App Check token has incorrect "aud" (audience) claim.' + assert str(excinfo.value) == expected + + def test_verify_token_with_empty_list_audience_raises_error(self, mocker): + jwt_with_empty_list_audience = JWT_PAYLOAD_SAMPLE.copy() + jwt_with_empty_list_audience["aud"] = [] + mocker.patch("jwt.decode", return_value=jwt_with_empty_list_audience) + mocker.patch("jwt.PyJWKClient.get_signing_key_from_jwt", return_value=PyJWK(signing_key)) + mocker.patch("jwt.get_unverified_header", return_value=JWT_PAYLOAD_SAMPLE.get("headers")) + app = firebase_admin.get_app() + + with pytest.raises(ValueError) as excinfo: + app_check.verify_token("encoded", app) + + expected = 'Firebase App Check token has incorrect "aud" (audience) claim.' + assert str(excinfo.value) == expected + + def test_verify_token_with_incorrect_issuer_raises_error(self, mocker): + jwt_with_non_incorrect_issuer = JWT_PAYLOAD_SAMPLE.copy() + jwt_with_non_incorrect_issuer["iss"] = "https://dwyfrequency.googleapis.com/" + mocker.patch("jwt.decode", return_value=jwt_with_non_incorrect_issuer) + mocker.patch("jwt.PyJWKClient.get_signing_key_from_jwt", return_value=PyJWK(signing_key)) + mocker.patch("jwt.get_unverified_header", return_value=JWT_PAYLOAD_SAMPLE.get("headers")) + app = firebase_admin.get_app() + + with pytest.raises(ValueError) as excinfo: + app_check.verify_token("encoded", app) + + expected = 'Token does not contain the correct "iss" (issuer).' + assert str(excinfo.value) == expected From 6c565f22c2f025b0ee4edd38496add11b8851f3c Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Thu, 6 Oct 2022 14:40:28 -0400 Subject: [PATCH 136/226] [chore] Release 6.0.0 (#649) * [chore] Release 6.0.0 - Major version bump to drop support for Python 3.6 - Added `app_check.verify_token()` API --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index b24ab002e..326da5650 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '5.4.0' +__version__ = '6.0.0' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From d7772f97a7974ee7973343e782d38a15de5dcfa3 Mon Sep 17 00:00:00 2001 From: Thomas Burke <40719837+thomasmburke@users.noreply.github.com> Date: Fri, 14 Oct 2022 13:48:53 -0700 Subject: [PATCH 137/226] password_hash obtained from Firebase Auth backend needs to be base64URL decoded before import to avoid double encoding (#652) Co-authored-by: Thomas Burke --- snippets/auth/index.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/snippets/auth/index.py b/snippets/auth/index.py index 9d6f29ebd..ed324e486 100644 --- a/snippets/auth/index.py +++ b/snippets/auth/index.py @@ -571,8 +571,8 @@ def import_with_scrypt(): auth.ImportUserRecord( uid='some-uid', email='user@example.com', - password_hash=b'password_hash', - password_salt=b'salt' + password_hash=base64.urlsafe_b64decode('password_hash'), + password_salt=base64.urlsafe_b64decode('salt') ), ] From 37ecf18d0ad09d7402143f22443e957a57ec7a2a Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Fri, 14 Oct 2022 17:06:12 -0400 Subject: [PATCH 138/226] fix(fac): Include pyjwt in distribution artifacts (#654) Adding pyjwt to setup.py to include the dependency in the distribution artifacts. --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index a54949891..5d917c661 100644 --- a/setup.py +++ b/setup.py @@ -42,6 +42,7 @@ 'google-api-python-client >= 1.7.8', 'google-cloud-firestore>=2.1.0; platform.python_implementation != "PyPy"', 'google-cloud-storage>=1.37.1', + 'pyjwt[crypto] >= 2.5.0', ] setup( From b9e95e8248eb1473ca5a13bf64e8a33b79dc9db3 Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Mon, 17 Oct 2022 15:14:29 -0400 Subject: [PATCH 139/226] [chore] Release 6.0.1 (#655) - Release 6.0.1 --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index 326da5650..29568f759 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '6.0.0' +__version__ = '6.0.1' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 18714fbb3e765d2db023dfb29d07ad9d0cedcb09 Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Wed, 1 Feb 2023 14:19:10 -0500 Subject: [PATCH 140/226] chore(firestore): Upgrade google-cloud-firestore to include COUNT queries (#671) --- requirements.txt | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index c66212673..acf09438b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,6 @@ pytest-mock >= 3.6.1 cachecontrol >= 0.12.6 google-api-core[grpc] >= 1.22.1, < 3.0.0dev; platform.python_implementation != 'PyPy' google-api-python-client >= 1.7.8 -google-cloud-firestore >= 2.1.0; platform.python_implementation != 'PyPy' +google-cloud-firestore >= 2.9.1; platform.python_implementation != 'PyPy' google-cloud-storage >= 1.37.1 pyjwt[crypto] >= 2.5.0 \ No newline at end of file diff --git a/setup.py b/setup.py index 5d917c661..1ba2ffa92 100644 --- a/setup.py +++ b/setup.py @@ -40,7 +40,7 @@ 'cachecontrol>=0.12.6', 'google-api-core[grpc] >= 1.22.1, < 3.0.0dev; platform.python_implementation != "PyPy"', 'google-api-python-client >= 1.7.8', - 'google-cloud-firestore>=2.1.0; platform.python_implementation != "PyPy"', + 'google-cloud-firestore>=2.9.1; platform.python_implementation != "PyPy"', 'google-cloud-storage>=1.37.1', 'pyjwt[crypto] >= 2.5.0', ] From 68001d9b4bfb25d142130f106669ef865ff15fe5 Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Thu, 2 Feb 2023 11:24:39 -0500 Subject: [PATCH 141/226] change(ml): Deprecate AutoML model support (#670) * chore(ml): Deprecate AutoML model support * fix lint --- firebase_admin/ml.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index bcc4b9390..98bdbb56a 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -24,6 +24,7 @@ import time import os from urllib import parse +import warnings import requests @@ -383,11 +384,14 @@ def __ne__(self, other): @staticmethod def _init_model_source(data): + """Initialize the ML model source.""" gcs_tflite_uri = data.pop('gcsTfliteUri', None) if gcs_tflite_uri: return TFLiteGCSModelSource(gcs_tflite_uri=gcs_tflite_uri) auto_ml_model = data.pop('automlModel', None) if auto_ml_model: + warnings.warn('AutoML model support is deprecated and will be removed in the next ' + 'major version.', DeprecationWarning) return TFLiteAutoMlSource(auto_ml_model=auto_ml_model) return None @@ -604,9 +608,14 @@ def as_dict(self, for_upload=False): class TFLiteAutoMlSource(TFLiteModelSource): - """TFLite model source representing a tflite model created with AutoML.""" + """TFLite model source representing a tflite model created with AutoML. + + AutoML model support is deprecated and will be removed in the next major version. + """ def __init__(self, auto_ml_model, app=None): + warnings.warn('AutoML model support is deprecated and will be removed in the next ' + 'major version.', DeprecationWarning) self._app = app self.auto_ml_model = auto_ml_model From 5c21b81e35443f749a5df16e7e02dd817dca8c1c Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Thu, 2 Feb 2023 13:00:36 -0500 Subject: [PATCH 142/226] [chore] Release 6.1.0 (#672) - Release 6.1.0 --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index 29568f759..42ac2bd04 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '6.0.1' +__version__ = '6.1.0' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 6ae9408139664e9cf9ac660db00a64b378a6a084 Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Wed, 5 Apr 2023 19:55:51 -0400 Subject: [PATCH 143/226] chore: Fix pypy tests (#694) --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d2129720b..6612efe55 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,7 +8,7 @@ jobs: strategy: fail-fast: false matrix: - python: ['3.7', '3.8', '3.9', '3.10', 'pypy3.7'] + python: ['3.7', '3.8', '3.9', '3.10', 'pypy3.8'] steps: - uses: actions/checkout@v3 From 77848a602a3808ffe8be002840dc112ae9d6cc97 Mon Sep 17 00:00:00 2001 From: pragatimodi <110490169+pragatimodi@users.noreply.github.com> Date: Wed, 5 Apr 2023 17:00:16 -0700 Subject: [PATCH 144/226] chore(auth): Update Auth API to `v2` (#691) * `v2beta1` -> `v2` * Reverting auto formatting changes * undo auto formatting --- firebase_admin/_auth_client.py | 4 ++-- firebase_admin/_auth_providers.py | 2 +- firebase_admin/tenant_mgt.py | 2 +- tests/test_auth_providers.py | 4 ++-- tests/test_tenant_mgt.py | 4 ++-- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/firebase_admin/_auth_client.py b/firebase_admin/_auth_client.py index eaf491f32..0fc9d2bee 100644 --- a/firebase_admin/_auth_client.py +++ b/firebase_admin/_auth_client.py @@ -50,7 +50,7 @@ def __init__(self, app, tenant_id=None): if emulator_host: base_url = 'http://{0}/identitytoolkit.googleapis.com'.format(emulator_host) endpoint_urls['v1'] = base_url + '/v1' - endpoint_urls['v2beta1'] = base_url + '/v2beta1' + endpoint_urls['v2'] = base_url + '/v2' credential = _utils.EmulatorAdminCredentials() self.emulated = True else: @@ -67,7 +67,7 @@ def __init__(self, app, tenant_id=None): self._user_manager = _user_mgt.UserManager( http_client, app.project_id, tenant_id, url_override=endpoint_urls.get('v1')) self._provider_manager = _auth_providers.ProviderConfigClient( - http_client, app.project_id, tenant_id, url_override=endpoint_urls.get('v2beta1')) + http_client, app.project_id, tenant_id, url_override=endpoint_urls.get('v2')) @property def tenant_id(self): diff --git a/firebase_admin/_auth_providers.py b/firebase_admin/_auth_providers.py index 31511f3c5..31894a4dc 100644 --- a/firebase_admin/_auth_providers.py +++ b/firebase_admin/_auth_providers.py @@ -176,7 +176,7 @@ def items(self): class ProviderConfigClient: """Client for managing Auth provider configurations.""" - PROVIDER_CONFIG_URL = 'https://identitytoolkit.googleapis.com/v2beta1' + PROVIDER_CONFIG_URL = 'https://identitytoolkit.googleapis.com/v2' def __init__(self, http_client, project_id, tenant_id=None, url_override=None): self.http_client = http_client diff --git a/firebase_admin/tenant_mgt.py b/firebase_admin/tenant_mgt.py index 396a819fb..8c53e30a1 100644 --- a/firebase_admin/tenant_mgt.py +++ b/firebase_admin/tenant_mgt.py @@ -232,7 +232,7 @@ def enable_email_link_sign_in(self): class _TenantManagementService: """Firebase tenant management service.""" - TENANT_MGT_URL = 'https://identitytoolkit.googleapis.com/v2beta1' + TENANT_MGT_URL = 'https://identitytoolkit.googleapis.com/v2' def __init__(self, app): credential = app.credential.get_credential() diff --git a/tests/test_auth_providers.py b/tests/test_auth_providers.py index b67a8eb96..a5716266c 100644 --- a/tests/test_auth_providers.py +++ b/tests/test_auth_providers.py @@ -23,10 +23,10 @@ from firebase_admin import exceptions from tests import testutils -ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v2beta1' +ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v2' EMULATOR_HOST_ENV_VAR = 'FIREBASE_AUTH_EMULATOR_HOST' AUTH_EMULATOR_HOST = 'localhost:9099' -EMULATED_ID_TOOLKIT_URL = 'http://{}/identitytoolkit.googleapis.com/v2beta1'.format( +EMULATED_ID_TOOLKIT_URL = 'http://{}/identitytoolkit.googleapis.com/v2'.format( AUTH_EMULATOR_HOST) URL_PROJECT_SUFFIX = '/projects/mock-project-id' USER_MGT_URLS = { diff --git a/tests/test_tenant_mgt.py b/tests/test_tenant_mgt.py index f92dd2a83..53b766239 100644 --- a/tests/test_tenant_mgt.py +++ b/tests/test_tenant_mgt.py @@ -108,8 +108,8 @@ INVALID_BOOLEANS = ['', 1, 0, list(), tuple(), dict()] USER_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v1/projects/mock-project-id' -PROVIDER_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v2beta1/projects/mock-project-id' -TENANT_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v2beta1/projects/mock-project-id' +PROVIDER_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v2/projects/mock-project-id' +TENANT_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v2/projects/mock-project-id' @pytest.fixture(scope='module') From f0865f7493a2c642d2e551efdf19da1e53c1a8c3 Mon Sep 17 00:00:00 2001 From: Samuel Dion-Girardeau Date: Wed, 5 Apr 2023 20:08:45 -0400 Subject: [PATCH 145/226] Add release notes to project URLs in PyPI (#679) It's useful to be able to navigate to the release notes easily from the package index when upgrading. "Release Notes" is a special keyword that will have the scroll icon in the project page. A random example: * https://pypi.org/project/streamlit/ * https://github.com/streamlit/streamlit/blob/815a3ea6fa3e7f9099b479e8365bd3a5874ddc35/lib/setup.py#L111 Co-authored-by: Lahiru Maramba --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 1ba2ffa92..a82bc47f3 100644 --- a/setup.py +++ b/setup.py @@ -52,6 +52,7 @@ long_description=long_description, url=about['__url__'], project_urls={ + 'Release Notes': 'https://firebase.google.com/support/release-notes/admin/python', 'Source': 'https://github.com/firebase/firebase-admin-python', }, author=about['__author__'], From 4323ed88cef168d66844add3aeaf3b7d19b46be6 Mon Sep 17 00:00:00 2001 From: Doris-Ge Date: Fri, 9 Jun 2023 10:27:21 -0700 Subject: [PATCH 146/226] feat(fcm): Add `send_each` and `send_each_for_multicast` for FCM batch send (#706) * Implement `send_each` and `send_each_for_multicast` (#692) `send_each` vs `send_all` 1. `send_each` sends one HTTP request to V1 Send endpoint for each message in the list. `send_all` sends only one HTTP request to V1 Batch Send endpoint to send all messages in the array. 2. `send_each` uses concurrent.futures.ThreadPoolExecutor to run and wait for all `request` calls to complete and construct a `BatchResponse`. An `request` call to V1 Send endpoint either completes with a success or throws an exception. So if an exception is thrown out, the exception will be caught in `send_each` and turned into a `SendResponse` with an exception. Therefore, unlike `send_all`, `send_each` does not always throw an exception for a total failure. It can also return a `BatchResponse` with only exceptions in it. `send_each_for_multicast` calls `send_each` under the hood. * Add integration tests for send_each and send_each_for_multicast (#700) * Add integration tests for send_each and send_each_for_multicast Add test_send_each, test_send_each_500 and test_send_each_for_multicast * chore: Fix pypy tests (#694) * chore(auth): Update Auth API to `v2` (#691) * `v2beta1` -> `v2` * Reverting auto formatting changes * undo auto formatting * Add release notes to project URLs in PyPI (#679) It's useful to be able to navigate to the release notes easily from the package index when upgrading. "Release Notes" is a special keyword that will have the scroll icon in the project page. A random example: * https://pypi.org/project/streamlit/ * https://github.com/streamlit/streamlit/blob/815a3ea6fa3e7f9099b479e8365bd3a5874ddc35/lib/setup.py#L111 Co-authored-by: Lahiru Maramba --------- Co-authored-by: Lahiru Maramba Co-authored-by: pragatimodi <110490169+pragatimodi@users.noreply.github.com> Co-authored-by: Samuel Dion-Girardeau --------- Co-authored-by: Lahiru Maramba Co-authored-by: pragatimodi <110490169+pragatimodi@users.noreply.github.com> Co-authored-by: Samuel Dion-Girardeau --- firebase_admin/messaging.py | 92 +++++++++++- integration/test_messaging.py | 62 ++++++++ tests/test_messaging.py | 265 ++++++++++++++++++++++++++++++++++ tests/testutils.py | 30 ++++ 4 files changed, 448 insertions(+), 1 deletion(-) diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index 46dd7d410..7e63933e1 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -14,11 +14,13 @@ """Firebase Cloud Messaging module.""" +import concurrent.futures import json +import warnings +import requests from googleapiclient import http from googleapiclient import _auth -import requests import firebase_admin from firebase_admin import _http_client @@ -26,6 +28,7 @@ from firebase_admin import _messaging_utils from firebase_admin import _gapic_utils from firebase_admin import _utils +from firebase_admin import exceptions _MESSAGING_ATTRIBUTE = '_messaging' @@ -115,6 +118,57 @@ def send(message, dry_run=False, app=None): """ return _get_messaging_service(app).send(message, dry_run) +def send_each(messages, dry_run=False, app=None): + """Sends each message in the given list via Firebase Cloud Messaging. + + If the ``dry_run`` mode is enabled, the message will not be actually delivered to the + recipients. Instead FCM performs all the usual validations, and emulates the send operation. + + Args: + messages: A list of ``messaging.Message`` instances. + dry_run: A boolean indicating whether to run the operation in dry run mode (optional). + app: An App instance (optional). + + Returns: + BatchResponse: A ``messaging.BatchResponse`` instance. + + Raises: + FirebaseError: If an error occurs while sending the message to the FCM service. + ValueError: If the input arguments are invalid. + """ + return _get_messaging_service(app).send_each(messages, dry_run) + +def send_each_for_multicast(multicast_message, dry_run=False, app=None): + """Sends the given mutlicast message to each token via Firebase Cloud Messaging (FCM). + + If the ``dry_run`` mode is enabled, the message will not be actually delivered to the + recipients. Instead FCM performs all the usual validations, and emulates the send operation. + + Args: + multicast_message: An instance of ``messaging.MulticastMessage``. + dry_run: A boolean indicating whether to run the operation in dry run mode (optional). + app: An App instance (optional). + + Returns: + BatchResponse: A ``messaging.BatchResponse`` instance. + + Raises: + FirebaseError: If an error occurs while sending the message to the FCM service. + ValueError: If the input arguments are invalid. + """ + if not isinstance(multicast_message, MulticastMessage): + raise ValueError('Message must be an instance of messaging.MulticastMessage class.') + messages = [Message( + data=multicast_message.data, + notification=multicast_message.notification, + android=multicast_message.android, + webpush=multicast_message.webpush, + apns=multicast_message.apns, + fcm_options=multicast_message.fcm_options, + token=token + ) for token in multicast_message.tokens] + return _get_messaging_service(app).send_each(messages, dry_run) + def send_all(messages, dry_run=False, app=None): """Sends the given list of messages via Firebase Cloud Messaging as a single batch. @@ -132,7 +186,10 @@ def send_all(messages, dry_run=False, app=None): Raises: FirebaseError: If an error occurs while sending the message to the FCM service. ValueError: If the input arguments are invalid. + + send_all() is deprecated. Use send_each() instead. """ + warnings.warn('send_all() is deprecated. Use send_each() instead.', DeprecationWarning) return _get_messaging_service(app).send_all(messages, dry_run) def send_multicast(multicast_message, dry_run=False, app=None): @@ -152,7 +209,11 @@ def send_multicast(multicast_message, dry_run=False, app=None): Raises: FirebaseError: If an error occurs while sending the message to the FCM service. ValueError: If the input arguments are invalid. + + send_multicast() is deprecated. Use send_each_for_multicast() instead. """ + warnings.warn('send_multicast() is deprecated. Use send_each_for_multicast() instead.', + DeprecationWarning) if not isinstance(multicast_message, MulticastMessage): raise ValueError('Message must be an instance of messaging.MulticastMessage class.') messages = [Message( @@ -356,6 +417,35 @@ def send(self, message, dry_run=False): else: return resp['name'] + def send_each(self, messages, dry_run=False): + """Sends the given messages to FCM via the FCM v1 API.""" + if not isinstance(messages, list): + raise ValueError('messages must be a list of messaging.Message instances.') + if len(messages) > 500: + raise ValueError('messages must not contain more than 500 elements.') + + def send_data(data): + try: + resp = self._client.body( + 'post', + url=self._fcm_url, + headers=self._fcm_headers, + json=data) + except requests.exceptions.RequestException as exception: + return SendResponse(resp=None, exception=self._handle_fcm_error(exception)) + else: + return SendResponse(resp, exception=None) + + message_data = [self._message_data(message, dry_run) for message in messages] + try: + with concurrent.futures.ThreadPoolExecutor(max_workers=len(message_data)) as executor: + responses = [resp for resp in executor.map(send_data, message_data)] + return BatchResponse(responses) + except Exception as error: + raise exceptions.UnknownError( + message='Unknown error while making remote service calls: {0}'.format(error), + cause=error) + def send_all(self, messages, dry_run=False): """Sends the given messages to FCM via the batch API.""" if not isinstance(messages, list): diff --git a/integration/test_messaging.py b/integration/test_messaging.py index b5612b63d..ab5d09b9e 100644 --- a/integration/test_messaging.py +++ b/integration/test_messaging.py @@ -86,6 +86,68 @@ def test_send_malformed_token(): with pytest.raises(exceptions.InvalidArgumentError): messaging.send(msg, dry_run=True) +def test_send_each(): + messages = [ + messaging.Message( + topic='foo-bar', notification=messaging.Notification('Title', 'Body')), + messaging.Message( + topic='foo-bar', notification=messaging.Notification('Title', 'Body')), + messaging.Message( + token='not-a-token', notification=messaging.Notification('Title', 'Body')), + ] + + batch_response = messaging.send_each(messages, dry_run=True) + + assert batch_response.success_count == 2 + assert batch_response.failure_count == 1 + assert len(batch_response.responses) == 3 + + response = batch_response.responses[0] + assert response.success is True + assert response.exception is None + assert re.match('^projects/.*/messages/.*$', response.message_id) + + response = batch_response.responses[1] + assert response.success is True + assert response.exception is None + assert re.match('^projects/.*/messages/.*$', response.message_id) + + response = batch_response.responses[2] + assert response.success is False + assert isinstance(response.exception, exceptions.InvalidArgumentError) + assert response.message_id is None + +def test_send_each_500(): + messages = [] + for msg_number in range(500): + topic = 'foo-bar-{0}'.format(msg_number % 10) + messages.append(messaging.Message(topic=topic)) + + batch_response = messaging.send_each(messages, dry_run=True) + + assert batch_response.success_count == 500 + assert batch_response.failure_count == 0 + assert len(batch_response.responses) == 500 + for response in batch_response.responses: + assert response.success is True + assert response.exception is None + assert re.match('^projects/.*/messages/.*$', response.message_id) + +def test_send_each_for_multicast(): + multicast = messaging.MulticastMessage( + notification=messaging.Notification('Title', 'Body'), + tokens=['not-a-token', 'also-not-a-token']) + + batch_response = messaging.send_each_for_multicast(multicast) + + assert batch_response.success_count == 0 + assert batch_response.failure_count == 2 + assert len(batch_response.responses) == 2 + for response in batch_response.responses: + assert response.success is False + assert response.exception is not None + assert response.message_id is None + def test_send_all(): messages = [ messaging.Message( diff --git a/tests/test_messaging.py b/tests/test_messaging.py index 3d8740cc1..71bb13eed 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -1812,6 +1812,16 @@ def setup_class(cls): def teardown_class(cls): testutils.cleanup_apps() + def _instrument_messaging_service(self, response_dict, app=None): + if not app: + app = firebase_admin.get_app() + fcm_service = messaging._get_messaging_service(app) + recorder = [] + fcm_service._client.session.mount( + 'https://fcm.googleapis.com', + testutils.MockRequestBasedMultiRequestAdapter(response_dict, recorder)) + return fcm_service, recorder + def _instrument_batch_messaging_service(self, app=None, status=200, payload='', exc=None): def build_mock_transport(_): if exc: @@ -1844,6 +1854,261 @@ def _batch_payload(self, payloads): return payload +class TestSendEach(TestBatch): + + def test_no_project_id(self): + def evaluate(): + app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') + with pytest.raises(ValueError): + messaging.send_each([messaging.Message(topic='foo')], app=app) + testutils.run_without_project_id(evaluate) + + @pytest.mark.parametrize('msg', NON_LIST_ARGS) + def test_invalid_send_each(self, msg): + with pytest.raises(ValueError) as excinfo: + messaging.send_each(msg) + if isinstance(msg, list): + expected = 'Message must be an instance of messaging.Message class.' + assert str(excinfo.value) == expected + else: + expected = 'messages must be a list of messaging.Message instances.' + assert str(excinfo.value) == expected + + def test_invalid_over_500(self): + msg = messaging.Message(topic='foo') + with pytest.raises(ValueError) as excinfo: + messaging.send_each([msg for _ in range(0, 501)]) + expected = 'messages must not contain more than 500 elements.' + assert str(excinfo.value) == expected + + def test_send_each(self): + payload1 = json.dumps({'name': 'message-id1'}) + payload2 = json.dumps({'name': 'message-id2'}) + _ = self._instrument_messaging_service( + response_dict={'foo1': [200, payload1], 'foo2': [200, payload2]}) + msg1 = messaging.Message(topic='foo1') + msg2 = messaging.Message(topic='foo2') + batch_response = messaging.send_each([msg1, msg2], dry_run=True) + assert batch_response.success_count == 2 + assert batch_response.failure_count == 0 + assert len(batch_response.responses) == 2 + assert [r.message_id for r in batch_response.responses] == ['message-id1', 'message-id2'] + assert all([r.success for r in batch_response.responses]) + assert not any([r.exception for r in batch_response.responses]) + + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) + def test_send_each_detailed_error(self, status): + success_payload = json.dumps({'name': 'message-id'}) + error_payload = json.dumps({ + 'error': { + 'status': 'INVALID_ARGUMENT', + 'message': 'test error' + } + }) + _ = self._instrument_messaging_service( + response_dict={'foo1': [200, success_payload], 'foo2': [status, error_payload]}) + msg1 = messaging.Message(topic='foo1') + msg2 = messaging.Message(topic='foo2') + batch_response = messaging.send_each([msg1, msg2]) + assert batch_response.success_count == 1 + assert batch_response.failure_count == 1 + assert len(batch_response.responses) == 2 + success_response = batch_response.responses[0] + assert success_response.message_id == 'message-id' + assert success_response.success is True + assert success_response.exception is None + error_response = batch_response.responses[1] + assert error_response.message_id is None + assert error_response.success is False + exception = error_response.exception + assert isinstance(exception, exceptions.InvalidArgumentError) + check_exception(exception, 'test error', status) + + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) + def test_send_each_canonical_error_code(self, status): + success_payload = json.dumps({'name': 'message-id'}) + error_payload = json.dumps({ + 'error': { + 'status': 'NOT_FOUND', + 'message': 'test error' + } + }) + _ = self._instrument_messaging_service( + response_dict={'foo1': [200, success_payload], 'foo2': [status, error_payload]}) + msg1 = messaging.Message(topic='foo1') + msg2 = messaging.Message(topic='foo2') + batch_response = messaging.send_each([msg1, msg2]) + assert batch_response.success_count == 1 + assert batch_response.failure_count == 1 + assert len(batch_response.responses) == 2 + success_response = batch_response.responses[0] + assert success_response.message_id == 'message-id' + assert success_response.success is True + assert success_response.exception is None + error_response = batch_response.responses[1] + assert error_response.message_id is None + assert error_response.success is False + exception = error_response.exception + assert isinstance(exception, exceptions.NotFoundError) + check_exception(exception, 'test error', status) + + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) + @pytest.mark.parametrize('fcm_error_code, exc_type', FCM_ERROR_CODES.items()) + def test_send_each_fcm_error_code(self, status, fcm_error_code, exc_type): + success_payload = json.dumps({'name': 'message-id'}) + error_payload = json.dumps({ + 'error': { + 'status': 'INVALID_ARGUMENT', + 'message': 'test error', + 'details': [ + { + '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', + 'errorCode': fcm_error_code, + }, + ], + } + }) + _ = self._instrument_messaging_service( + response_dict={'foo1': [200, success_payload], 'foo2': [status, error_payload]}) + msg1 = messaging.Message(topic='foo1') + msg2 = messaging.Message(topic='foo2') + batch_response = messaging.send_each([msg1, msg2]) + assert batch_response.success_count == 1 + assert batch_response.failure_count == 1 + assert len(batch_response.responses) == 2 + success_response = batch_response.responses[0] + assert success_response.message_id == 'message-id' + assert success_response.success is True + assert success_response.exception is None + error_response = batch_response.responses[1] + assert error_response.message_id is None + assert error_response.success is False + exception = error_response.exception + assert isinstance(exception, exc_type) + check_exception(exception, 'test error', status) + + +class TestSendEachForMulticast(TestBatch): + + def test_no_project_id(self): + def evaluate(): + app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') + with pytest.raises(ValueError): + messaging.send_all([messaging.Message(topic='foo')], app=app) + testutils.run_without_project_id(evaluate) + + @pytest.mark.parametrize('msg', NON_LIST_ARGS) + def test_invalid_send_each_for_multicast(self, msg): + with pytest.raises(ValueError) as excinfo: + messaging.send_multicast(msg) + expected = 'Message must be an instance of messaging.MulticastMessage class.' + assert str(excinfo.value) == expected + + def test_send_each_for_multicast(self): + payload1 = json.dumps({'name': 'message-id1'}) + payload2 = json.dumps({'name': 'message-id2'}) + _ = self._instrument_messaging_service( + response_dict={'foo1': [200, payload1], 'foo2': [200, payload2]}) + msg = messaging.MulticastMessage(tokens=['foo1', 'foo2']) + batch_response = messaging.send_each_for_multicast(msg, dry_run=True) + assert batch_response.success_count == 2 + assert batch_response.failure_count == 0 + assert len(batch_response.responses) == 2 + assert [r.message_id for r in batch_response.responses] == ['message-id1', 'message-id2'] + assert all([r.success for r in batch_response.responses]) + assert not any([r.exception for r in batch_response.responses]) + + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) + def test_send_each_for_multicast_detailed_error(self, status): + success_payload = json.dumps({'name': 'message-id'}) + error_payload = json.dumps({ + 'error': { + 'status': 'INVALID_ARGUMENT', + 'message': 'test error' + } + }) + _ = self._instrument_messaging_service( + response_dict={'foo1': [200, success_payload], 'foo2': [status, error_payload]}) + msg = messaging.MulticastMessage(tokens=['foo1', 'foo2']) + batch_response = messaging.send_each_for_multicast(msg) + assert batch_response.success_count == 1 + assert batch_response.failure_count == 1 + assert len(batch_response.responses) == 2 + success_response = batch_response.responses[0] + assert success_response.message_id == 'message-id' + assert success_response.success is True + assert success_response.exception is None + error_response = batch_response.responses[1] + assert error_response.message_id is None + assert error_response.success is False + assert error_response.exception is not None + exception = error_response.exception + assert isinstance(exception, exceptions.InvalidArgumentError) + check_exception(exception, 'test error', status) + + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) + def test_send_each_for_multicast_canonical_error_code(self, status): + success_payload = json.dumps({'name': 'message-id'}) + error_payload = json.dumps({ + 'error': { + 'status': 'NOT_FOUND', + 'message': 'test error' + } + }) + _ = self._instrument_messaging_service( + response_dict={'foo1': [200, success_payload], 'foo2': [status, error_payload]}) + msg = messaging.MulticastMessage(tokens=['foo1', 'foo2']) + batch_response = messaging.send_each_for_multicast(msg) + assert batch_response.success_count == 1 + assert batch_response.failure_count == 1 + assert len(batch_response.responses) == 2 + success_response = batch_response.responses[0] + assert success_response.message_id == 'message-id' + assert success_response.success is True + assert success_response.exception is None + error_response = batch_response.responses[1] + assert error_response.message_id is None + assert error_response.success is False + assert error_response.exception is not None + exception = error_response.exception + assert isinstance(exception, exceptions.NotFoundError) + check_exception(exception, 'test error', status) + + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) + def test_send_each_for_multicast_fcm_error_code(self, status): + success_payload = json.dumps({'name': 'message-id'}) + error_payload = json.dumps({ + 'error': { + 'status': 'INVALID_ARGUMENT', + 'message': 'test error', + 'details': [ + { + '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', + 'errorCode': 'UNREGISTERED', + }, + ], + } + }) + _ = self._instrument_messaging_service( + response_dict={'foo1': [200, success_payload], 'foo2': [status, error_payload]}) + msg = messaging.MulticastMessage(tokens=['foo1', 'foo2']) + batch_response = messaging.send_each_for_multicast(msg) + assert batch_response.success_count == 1 + assert batch_response.failure_count == 1 + assert len(batch_response.responses) == 2 + success_response = batch_response.responses[0] + assert success_response.message_id == 'message-id' + assert success_response.success is True + assert success_response.exception is None + error_response = batch_response.responses[1] + assert error_response.message_id is None + assert error_response.success is False + assert error_response.exception is not None + exception = error_response.exception + assert isinstance(exception, messaging.UnregisteredError) + check_exception(exception, 'test error', status) + + class TestSendAll(TestBatch): def test_no_project_id(self): diff --git a/tests/testutils.py b/tests/testutils.py index 92755107c..e52b90d1a 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -171,3 +171,33 @@ def status(self): @property def data(self): return self._responses[0] + +class MockRequestBasedMultiRequestAdapter(adapters.HTTPAdapter): + """A mock HTTP adapter that supports multiple responses for the Python requests module. + The response for each incoming request should be specified in response_dict during + initialization. Each incoming request should contain an identifier in the its body.""" + def __init__(self, response_dict, recorder): + """Constructs a MockRequestBasedMultiRequestAdapter. + + Each incoming request consumes the response and status mapped to it. If no response + is specified for the request, the response will be 404 with an empty body. + """ + adapters.HTTPAdapter.__init__(self) + self._current_response = 0 + self._response_dict = dict(response_dict) + self._recorder = recorder + + def send(self, request, **kwargs): # pylint: disable=arguments-differ + request._extra_kwargs = kwargs + self._recorder.append(request) + resp = models.Response() + resp.url = request.url + resp.status_code = 404 # Not found. + resp.raw = None + for req_id, pair in self._response_dict.items(): + if req_id in str(request.body): + status, response = pair + resp.status_code = status + resp.raw = io.BytesIO(response.encode()) + break + return resp From 59a22b3ef3263530b1f1b61a3416ef311c24477b Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Thu, 22 Jun 2023 14:41:15 -0400 Subject: [PATCH 147/226] [chore] Release 6.2.0 (#708) - Release 6.2.0 --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index 42ac2bd04..ff6ad252d 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '6.1.0' +__version__ = '6.2.0' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From aef52be90951204bc2ce966656fca796f4f87228 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Thu, 21 Sep 2023 12:53:11 -0400 Subject: [PATCH 148/226] fix: Correctly catch DefaultCredentialsError when looking up project_id (#720) * Added additional exception catching with unit tests * lint: fixed spacing * fix: explicitly force exception raise in unit test --- firebase_admin/__init__.py | 3 ++- tests/test_app.py | 22 ++++++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/firebase_admin/__init__.py b/firebase_admin/__init__.py index 7e3b2eab0..e2c8f1ec5 100644 --- a/firebase_admin/__init__.py +++ b/firebase_admin/__init__.py @@ -18,6 +18,7 @@ import os import threading +from google.auth.exceptions import DefaultCredentialsError from firebase_admin import credentials from firebase_admin.__about__ import __version__ @@ -257,7 +258,7 @@ def _lookup_project_id(self): if not project_id: try: project_id = self._credential.project_id - except AttributeError: + except (AttributeError, DefaultCredentialsError): pass if not project_id: project_id = os.environ.get('GOOGLE_CLOUD_PROJECT', diff --git a/tests/test_app.py b/tests/test_app.py index fe3a43a5c..4233d5849 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -17,6 +17,7 @@ import os import pytest +from google.auth.exceptions import DefaultCredentialsError import firebase_admin from firebase_admin import credentials @@ -315,6 +316,27 @@ def evaluate(): assert app.project_id is None testutils.run_without_project_id(evaluate) + def test_no_project_id_from_environment(self, app_credential): + default_env = 'GOOGLE_APPLICATION_CREDENTIALS' + gcloud_env = 'CLOUDSDK_CONFIG' + def evaluate(): + app = firebase_admin.initialize_app(app_credential, name='myApp') + app._credential._g_credential = None + old_gcloud_var = os.environ.get(gcloud_env) + os.environ[gcloud_env] = '' + old_default_var = os.environ.get(default_env) + if old_default_var: + del os.environ[default_env] + with pytest.raises((AttributeError, DefaultCredentialsError)): + project_id = app._credential.project_id + project_id = app.project_id + if old_default_var: + os.environ[default_env] = old_default_var + if old_gcloud_var: + os.environ[gcloud_env] = old_gcloud_var + assert project_id is None + testutils.run_without_project_id(evaluate) + def test_non_string_project_id(self): options = {'projectId': {'key': 'not a string'}} with pytest.raises(ValueError): From 4052a3cdc7d7b9955a3c4611409554200225297a Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Thu, 12 Oct 2023 12:48:33 -0400 Subject: [PATCH 149/226] Update `github.ref` value in `release.yml` (#730) - Fixes the release workflow to match the updates to `github.ref` - `github.ref` now returns a fully-formed value `refs/heads/...` - See https://github.blog/changelog/2023-09-13-github-actions-updates-to-github_ref-and-github-ref/ --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 5eb4bfaea..e4c9b9fd0 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -87,7 +87,7 @@ jobs: # 3. with the label 'release:publish', and # 4. the title prefix '[chore] Release '. if: github.event.pull_request.merged && - github.ref == 'master' && + github.ref == 'refs/heads/master' && contains(github.event.pull_request.labels.*.name, 'release:publish') && startsWith(github.event.pull_request.title, '[chore] Release ') From 44b756875fe0183e290ddac5b059b8da073e88dc Mon Sep 17 00:00:00 2001 From: Chris Hua Date: Thu, 26 Oct 2023 10:11:17 -0400 Subject: [PATCH 150/226] feat: add clockSkewSeconds (#714) * feat: add clockSkewSeconds per feedback in https://github.com/firebase/firebase-admin-python/pull/625#issuecomment-1331197410 adds unit and integration tests as well. unit tests and lint pass. * fix: test * chore: version bump for testing * chore: address CR * fix:lint * chore: address CR * chore: remove test * fix: remove more tests * chore: address CR --- firebase_admin/__init__.py | 5 ++-- firebase_admin/_auth_client.py | 6 +++-- firebase_admin/_token_gen.py | 18 ++++++++++----- firebase_admin/auth.py | 14 ++++++++---- integration/test_auth.py | 1 + tests/test_token_gen.py | 42 ++++++++++++++++++++++++++++++++-- 6 files changed, 69 insertions(+), 17 deletions(-) diff --git a/firebase_admin/__init__.py b/firebase_admin/__init__.py index e2c8f1ec5..0ca82ec5e 100644 --- a/firebase_admin/__init__.py +++ b/firebase_admin/__init__.py @@ -50,8 +50,9 @@ def initialize_app(credential=None, options=None, name=_DEFAULT_APP_NAME): Google Application Default Credentials are used. options: A dictionary of configuration options (optional). Supported options include ``databaseURL``, ``storageBucket``, ``projectId``, ``databaseAuthVariableOverride``, - ``serviceAccountId`` and ``httpTimeout``. If ``httpTimeout`` is not set, the SDK - uses a default timeout of 120 seconds. + ``serviceAccountId`` and ``httpTimeout``. If ``httpTimeout`` is not set, the SDK uses + a default timeout of 120 seconds. + name: Name of the app (optional). Returns: App: A newly initialized instance of App. diff --git a/firebase_admin/_auth_client.py b/firebase_admin/_auth_client.py index 0fc9d2bee..38b42993a 100644 --- a/firebase_admin/_auth_client.py +++ b/firebase_admin/_auth_client.py @@ -92,7 +92,7 @@ def create_custom_token(self, uid, developer_claims=None): return self._token_generator.create_custom_token( uid, developer_claims, tenant_id=self.tenant_id) - def verify_id_token(self, id_token, check_revoked=False): + def verify_id_token(self, id_token, check_revoked=False, clock_skew_seconds=0): """Verifies the signature and data for the provided JWT. Accepts a signed token string, verifies that it is current, was issued @@ -102,6 +102,8 @@ def verify_id_token(self, id_token, check_revoked=False): id_token: A string of the encoded JWT. check_revoked: Boolean, If true, checks whether the token has been revoked or the user disabled (optional). + clock_skew_seconds: The number of seconds to tolerate when checking the token. + Must be between 0-60. Defaults to 0. Returns: dict: A dictionary of key-value pairs parsed from the decoded JWT. @@ -124,7 +126,7 @@ def verify_id_token(self, id_token, check_revoked=False): raise ValueError('Illegal check_revoked argument. Argument must be of type ' ' bool, but given "{0}".'.format(type(check_revoked))) - verified_claims = self._token_verifier.verify_id_token(id_token) + verified_claims = self._token_verifier.verify_id_token(id_token, clock_skew_seconds) if self.tenant_id: token_tenant_id = verified_claims.get('firebase', {}).get('tenant') if self.tenant_id != token_tenant_id: diff --git a/firebase_admin/_token_gen.py b/firebase_admin/_token_gen.py index 32c109d5d..a2fc725e8 100644 --- a/firebase_admin/_token_gen.py +++ b/firebase_admin/_token_gen.py @@ -289,11 +289,11 @@ def __init__(self, app): invalid_token_error=InvalidSessionCookieError, expired_token_error=ExpiredSessionCookieError) - def verify_id_token(self, id_token): - return self.id_token_verifier.verify(id_token, self.request) + def verify_id_token(self, id_token, clock_skew_seconds=0): + return self.id_token_verifier.verify(id_token, self.request, clock_skew_seconds) - def verify_session_cookie(self, cookie): - return self.cookie_verifier.verify(cookie, self.request) + def verify_session_cookie(self, cookie, clock_skew_seconds=0): + return self.cookie_verifier.verify(cookie, self.request, clock_skew_seconds) class _JWTVerifier: @@ -313,7 +313,7 @@ def __init__(self, **kwargs): self._invalid_token_error = kwargs.pop('invalid_token_error') self._expired_token_error = kwargs.pop('expired_token_error') - def verify(self, token, request): + def verify(self, token, request, clock_skew_seconds=0): """Verifies the signature and data for the provided JWT.""" token = token.encode('utf-8') if isinstance(token, str) else token if not isinstance(token, bytes) or not token: @@ -328,6 +328,11 @@ def verify(self, token, request): 'or set your Firebase project ID as an app option. Alternatively set the ' 'GOOGLE_CLOUD_PROJECT environment variable.'.format(self.operation)) + if clock_skew_seconds < 0 or clock_skew_seconds > 60: + raise ValueError( + 'Illegal clock_skew_seconds value: {0}. Must be between 0 and 60, inclusive.' + .format(clock_skew_seconds)) + header, payload = self._decode_unverified(token) issuer = payload.get('iss') audience = payload.get('aud') @@ -393,7 +398,8 @@ def verify(self, token, request): token, request=request, audience=self.project_id, - certs_url=self.cert_url) + certs_url=self.cert_url, + clock_skew_in_seconds=clock_skew_seconds) verified_claims['uid'] = verified_claims['sub'] return verified_claims except google.auth.exceptions.TransportError as error: diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index 6902a322f..84873c3da 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -191,7 +191,7 @@ def create_custom_token(uid, developer_claims=None, app=None): return client.create_custom_token(uid, developer_claims) -def verify_id_token(id_token, app=None, check_revoked=False): +def verify_id_token(id_token, app=None, check_revoked=False, clock_skew_seconds=0): """Verifies the signature and data for the provided JWT. Accepts a signed token string, verifies that it is current, and issued @@ -202,7 +202,8 @@ def verify_id_token(id_token, app=None, check_revoked=False): app: An App instance (optional). check_revoked: Boolean, If true, checks whether the token has been revoked or the user disabled (optional). - + clock_skew_seconds: The number of seconds to tolerate when checking the token. + Must be between 0-60. Defaults to 0. Returns: dict: A dictionary of key-value pairs parsed from the decoded JWT. @@ -217,7 +218,8 @@ def verify_id_token(id_token, app=None, check_revoked=False): record is disabled. """ client = _get_client(app) - return client.verify_id_token(id_token, check_revoked=check_revoked) + return client.verify_id_token( + id_token, check_revoked=check_revoked, clock_skew_seconds=clock_skew_seconds) def create_session_cookie(id_token, expires_in, app=None): @@ -243,7 +245,7 @@ def create_session_cookie(id_token, expires_in, app=None): return client._token_generator.create_session_cookie(id_token, expires_in) -def verify_session_cookie(session_cookie, check_revoked=False, app=None): +def verify_session_cookie(session_cookie, check_revoked=False, app=None, clock_skew_seconds=0): """Verifies a Firebase session cookie. Accepts a session cookie string, verifies that it is current, and issued @@ -254,6 +256,7 @@ def verify_session_cookie(session_cookie, check_revoked=False, app=None): check_revoked: Boolean, if true, checks whether the cookie has been revoked or the user disabled (optional). app: An App instance (optional). + clock_skew_seconds: The number of seconds to tolerate when checking the cookie. Returns: dict: A dictionary of key-value pairs parsed from the decoded JWT. @@ -270,7 +273,8 @@ def verify_session_cookie(session_cookie, check_revoked=False, app=None): """ client = _get_client(app) # pylint: disable=protected-access - verified_claims = client._token_verifier.verify_session_cookie(session_cookie) + verified_claims = client._token_verifier.verify_session_cookie( + session_cookie, clock_skew_seconds) if check_revoked: client._check_jwt_revoked_or_disabled( verified_claims, RevokedSessionCookieError, 'session cookie') diff --git a/integration/test_auth.py b/integration/test_auth.py index 82974732d..e1d01a254 100644 --- a/integration/test_auth.py +++ b/integration/test_auth.py @@ -617,6 +617,7 @@ def test_verify_session_cookie_revoked(new_user, api_key): claims = auth.verify_session_cookie(session_cookie, check_revoked=True) assert claims['iat'] * 1000 >= user.tokens_valid_after_timestamp + def test_verify_session_cookie_disabled(new_user, api_key): custom_token = auth.create_custom_token(new_user.uid) id_token = _sign_in(custom_token, api_key) diff --git a/tests/test_token_gen.py b/tests/test_token_gen.py index 00b7956fa..64540f26f 100644 --- a/tests/test_token_gen.py +++ b/tests/test_token_gen.py @@ -440,6 +440,10 @@ class TestVerifyIdToken: 'iat': int(time.time()) - 10000, 'exp': int(time.time()) - 3600 }), + 'ExpiredTokenShort': _get_id_token({ + 'iat': int(time.time()) - 10000, + 'exp': int(time.time()) - 30 + }), 'BadFormatToken': 'foobar' } @@ -447,7 +451,8 @@ class TestVerifyIdToken: 'NoKid', 'WrongKid', 'FutureToken', - 'ExpiredToken' + 'ExpiredToken', + 'ExpiredTokenShort', ] def _assert_valid_token(self, id_token, app): @@ -555,6 +560,20 @@ def test_expired_token(self, user_mgt_app): assert excinfo.value.cause is not None assert excinfo.value.http_response is None + def test_expired_token_with_tolerance(self, user_mgt_app): + _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) + id_token = self.invalid_tokens['ExpiredTokenShort'] + if _is_emulated(): + self._assert_valid_token(id_token, user_mgt_app) + return + claims = auth.verify_id_token(id_token, app=user_mgt_app, + clock_skew_seconds=60) + assert claims['admin'] is True + assert claims['uid'] == claims['sub'] + with pytest.raises(auth.ExpiredIdTokenError): + auth.verify_id_token(id_token, app=user_mgt_app, + clock_skew_seconds=20) + def test_project_id_option(self): app = firebase_admin.initialize_app( testutils.MockCredential(), options={'projectId': 'mock-project-id'}, name='myApp') @@ -619,6 +638,10 @@ class TestVerifySessionCookie: 'iat': int(time.time()) - 10000, 'exp': int(time.time()) - 3600 }), + 'ExpiredCookieShort': _get_session_cookie({ + 'iat': int(time.time()) - 10000, + 'exp': int(time.time()) - 30 + }), 'BadFormatCookie': 'foobar', 'IDToken': TEST_ID_TOKEN, } @@ -627,7 +650,8 @@ class TestVerifySessionCookie: 'NoKid', 'WrongKid', 'FutureCookie', - 'ExpiredCookie' + 'ExpiredCookie', + 'ExpiredCookieShort', ] def _assert_valid_cookie(self, cookie, app, check_revoked=False): @@ -715,6 +739,20 @@ def test_expired_cookie(self, user_mgt_app): assert excinfo.value.cause is not None assert excinfo.value.http_response is None + def test_expired_cookie_with_tolerance(self, user_mgt_app): + _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) + cookie = self.invalid_cookies['ExpiredCookieShort'] + if _is_emulated(): + self._assert_valid_cookie(cookie, user_mgt_app) + return + claims = auth.verify_session_cookie(cookie, app=user_mgt_app, check_revoked=False, + clock_skew_seconds=59) + assert claims['admin'] is True + assert claims['uid'] == claims['sub'] + with pytest.raises(auth.ExpiredSessionCookieError): + auth.verify_session_cookie(cookie, app=user_mgt_app, check_revoked=False, + clock_skew_seconds=29) + def test_project_id_option(self): app = firebase_admin.initialize_app( testutils.MockCredential(), options={'projectId': 'mock-project-id'}, name='myApp') From c77608e173685d2c882a9515bae6e5699b91d54f Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Wed, 15 Nov 2023 15:19:58 -0500 Subject: [PATCH 151/226] fix: Add `PyJWKClientError` to raised errors documentation and handle possible uncaught errors. (#733) * fix: Add PyJWKClientError to raised error documentaion and handle possible uncaught errors * fix: grammar --- firebase_admin/app_check.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/firebase_admin/app_check.py b/firebase_admin/app_check.py index 91b0c4c31..6bc10b2f4 100644 --- a/firebase_admin/app_check.py +++ b/firebase_admin/app_check.py @@ -16,7 +16,7 @@ from typing import Any, Dict import jwt -from jwt import PyJWKClient, ExpiredSignatureError, InvalidTokenError +from jwt import PyJWKClient, ExpiredSignatureError, InvalidTokenError, DecodeError from jwt import InvalidAudienceError, InvalidIssuerError, InvalidSignatureError from firebase_admin import _utils @@ -38,6 +38,7 @@ def verify_token(token: str, app=None) -> Dict[str, Any]: Raises: ValueError: If the app's ``project_id`` is invalid or unspecified, or if the token's headers or payload are invalid. + PyJWKClientError: If PyJWKClient fails to fetch a valid signing key. """ return _get_app_check_service(app).verify_token(token) @@ -71,9 +72,14 @@ def verify_token(self, token: str) -> Dict[str, Any]: # Obtain the Firebase App Check Public Keys # Note: It is not recommended to hard code these keys as they rotate, # but you should cache them for up to 6 hours. - signing_key = self._jwks_client.get_signing_key_from_jwt(token) - self._has_valid_token_headers(jwt.get_unverified_header(token)) - verified_claims = self._decode_and_verify(token, signing_key.key) + try: + signing_key = self._jwks_client.get_signing_key_from_jwt(token) + self._has_valid_token_headers(jwt.get_unverified_header(token)) + verified_claims = self._decode_and_verify(token, signing_key.key) + except (InvalidTokenError, DecodeError) as exception: + raise ValueError( + f'Verifying App Check token failed. Error: {exception}' + ) verified_claims['app_id'] = verified_claims.get('sub') return verified_claims From ea885c17f045fe2e9f2733d92933bea56f037d04 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Mon, 4 Dec 2023 15:42:54 -0500 Subject: [PATCH 152/226] Deprecated support for Python 3.7 (#741) --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 041c41673..f7cae21ff 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,8 @@ requests, code review feedback, and also pull requests. ## Supported Python Versions -We currently support Python 3.7+. Firebase +We currently support Python 3.7+. However, Python 3.7 support is deprecated, +and developers are strongly advised to use Python 3.8 or higher. Firebase Admin Python SDK is also tested on PyPy and [Google App Engine](https://cloud.google.com/appengine/) environments. From 3773b6407f5d4dfabb17fcee5b6f8a7b1f3f8069 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Tue, 5 Dec 2023 12:15:08 -0500 Subject: [PATCH 153/226] [chore] Release 6.3.0 (#742) --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index ff6ad252d..091fbd205 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '6.2.0' +__version__ = '6.3.0' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From e073f8cb6e067626d4cbf1d45d45649fe64e5a07 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Tue, 19 Dec 2023 12:15:34 -0500 Subject: [PATCH 154/226] chore: Bump github actions CI to use Node 20 (#748) * Lock firebase-tools to v13 and bump actions to use Node 20. * set firebase-tools to latest --- .github/workflows/ci.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6612efe55..4829256eb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,7 +11,7 @@ jobs: python: ['3.7', '3.8', '3.9', '3.10', 'pypy3.8'] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python }} uses: actions/setup-python@v4 with: @@ -22,10 +22,10 @@ jobs: pip install -r requirements.txt - name: Test with pytest run: pytest - - name: Set up Node.js 16 - uses: actions/setup-node@v1 + - name: Set up Node.js 20 + uses: actions/setup-node@v4 with: - node-version: 16.x + node-version: 20 - name: Run integration tests against emulator run: | npm install -g firebase-tools @@ -34,7 +34,7 @@ jobs: lint: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python 3.7 uses: actions/setup-python@v4 with: From e2ddedbbc5078e82fee4bbb149e292889556d087 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20Losada?= Date: Tue, 19 Dec 2023 18:20:23 +0100 Subject: [PATCH 155/226] fix(auth): Fix iOSBundleId parameter name (#727) Co-authored-by: Lahiru Maramba --- firebase_admin/_user_mgt.py | 2 +- tests/test_user_mgt.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index c77c4d40d..aa0dfb0a4 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -540,7 +540,7 @@ def encode_action_code_settings(settings): if not isinstance(settings.ios_bundle_id, str): raise ValueError('Invalid value provided for ios_bundle_id: {0}' .format(settings.ios_bundle_id)) - parameters['iosBundleId'] = settings.ios_bundle_id + parameters['iOSBundleId'] = settings.ios_bundle_id # android_* attributes if (settings.android_minimum_version or settings.android_install_app) \ diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index b590cca05..ea9c87e6f 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -1369,7 +1369,7 @@ def test_valid_data(self): assert parameters['continueUrl'] == data['url'] assert parameters['canHandleCodeInApp'] == data['handle_code_in_app'] assert parameters['dynamicLinkDomain'] == data['dynamic_link_domain'] - assert parameters['iosBundleId'] == data['ios_bundle_id'] + assert parameters['iOSBundleId'] == data['ios_bundle_id'] assert parameters['androidPackageName'] == data['android_package_name'] assert parameters['androidMinimumVersion'] == data['android_minimum_version'] assert parameters['androidInstallApp'] == data['android_install_app'] @@ -1529,7 +1529,7 @@ def _validate_request(self, request, settings=None): assert request['continueUrl'] == settings.url assert request['canHandleCodeInApp'] == settings.handle_code_in_app assert request['dynamicLinkDomain'] == settings.dynamic_link_domain - assert request['iosBundleId'] == settings.ios_bundle_id + assert request['iOSBundleId'] == settings.ios_bundle_id assert request['androidPackageName'] == settings.android_package_name assert request['androidMinimumVersion'] == settings.android_minimum_version assert request['androidInstallApp'] == settings.android_install_app From 7a9dfa0e2de06c86a558b8289eee1c3407dae89b Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Tue, 19 Dec 2023 12:30:02 -0500 Subject: [PATCH 156/226] chore: Update Firebase test project setup instructions. (#736) * Update Firebase test project setup instructions. * fix: numbering * fix: add missing step. * mirror Tech Writer review changes * fix: pencil * Added service account management note. --- CONTRIBUTING.md | 127 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 89 insertions(+), 38 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 1d500cba8..c06d7de2c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -162,44 +162,95 @@ pytest --cov=firebase_admin --cov=tests ### Integration Testing -A suite of integration tests are available under the `integration/` directory. -These tests are designed to run against an actual Firebase project. Create a new -project in the [Firebase Console](https://console.firebase.google.com), if you -do not already have one suitable for running the tests aginst. Then obtain the -following credentials from the project: - -1. *Service account certificate*: This can be downloaded as a JSON file from - the "Settings > Service Accounts" tab of the Firebase console. Copy the - file into the repo so it's available at `cert.json`. -2. *Web API key*: This is displayed in the "Settings > General" tab of the - console. Copy it and save to a new text file at `apikey.txt`. - -Then set up your Firebase/GCP project as follows: - -1. Enable Firestore: Go to the Firebase Console, and select "Database" from - the "Develop" menu. Click on the "Create database" button. You may choose - to set up Firestore either in the locked mode or in the test mode. -2. Enable password auth: Select "Authentication" from the "Develop" menu in - Firebase Console. Select the "Sign-in method" tab, and enable the - "Email/Password" sign-in method, including the Email link (passwordless - sign-in) option. -3. Enable the Firebase ML API: Go to the - [Google Developers Console]( - https://console.developers.google.com/apis/api/firebaseml.googleapis.com/overview) - and make sure your project is selected. If the API is not already enabled, click Enable. -4. Enable the IAM API: Go to the - [Google Cloud Platform Console](https://console.cloud.google.com) and make - sure your Firebase/GCP project is selected. Select "APIs & Services > - Dashboard" from the main menu, and click the "ENABLE APIS AND SERVICES" - button. Search for and enable the "Identity and Access Management (IAM) - API". -5. Grant your service account the 'Firebase Authentication Admin' role. This is - required to ensure that exported user records contain the password hashes of - the user accounts: - 1. Go to [Google Cloud Platform Console / IAM & admin](https://console.cloud.google.com/iam-admin). - 2. Find your service account in the list, and click the 'pencil' icon to edit it's permissions. - 3. Click 'ADD ANOTHER ROLE' and choose 'Firebase Authentication Admin'. - 4. Click 'SAVE'. + +Integration tests are executed against a real life Firebase project. If you do not already +have one suitable for running the tests against, you can create a new project in the +[Firebase Console](https://console.firebase.google.com) following the setup guide below. +If you already have a Firebase project, you'll need to obtain credentials to communicate and +authorize access to your Firebase project: + + +1. Service account certificate: This allows access to your Firebase project through a service account +which is required for all integration tests. This can be downloaded as a JSON file from the +**Settings > Service Accounts** tab of the Firebase console when you click the +**Generate new private key** button. Copy the file into the repo so it's available at `cert.json`. + > **Note:** Service accounts should be carefully managed and their keys should never be stored in publicly accessible source code or repositories. + + +2. Web API key: This allows for Auth sign-in needed for some Authentication and Tenant Management +integration tests. This is displayed in the **Settings > General** tab of the Firebase console +after enabling Authentication as described in the steps below. Copy it and save to a new text +file at `apikey.txt`. + + +Set up your Firebase project as follows: + + +1. Enable Authentication: + 1. Go to the Firebase Console, and select **Authentication** from the **Build** menu. + 2. Click on **Get Started**. + 3. Select **Sign-in method > Add new provider > Email/Password** then enable both the + **Email/Password** and **Email link (passwordless sign-in)** options. + + +2. Enable Firestore: + 1. Go to the Firebase Console, and select **Firestore Database** from the **Build** menu. + 2. Click on the **Create database** button. You can choose to set up Firestore either in + the production mode or in the test mode. + + +3. Enable Realtime Database: + 1. Go to the Firebase Console, and select **Realtime Database** from the **Build** menu. + 2. Click on the **Create Database** button. You can choose to set up the Realtime Database + either in the locked mode or in the test mode. + + > **Note:** Integration tests are not run against the default Realtime Database reference and are + instead run against a database created at `https://{PROJECT_ID}.firebaseio.com`. + This second Realtime Database reference is created in the following steps. + + 3. In the **Data** tab click on the kebab menu (3 dots) and select **Create Database**. + 4. Enter your Project ID (Found in the **General** tab in **Account Settings**) as the + **Realtime Database reference**. Again, you can choose to set up the Realtime Database + either in the locked mode or in the test mode. + + +4. Enable Storage: + 1. Go to the Firebase Console, and select **Storage** from the **Build** menu. + 2. Click on the **Get started** button. You can choose to set up Cloud Storage + either in the production mode or in the test mode. + + +5. Enable the Firebase ML API: + 1. Go to the + [Google Cloud console | Firebase ML API](https://console.cloud.google.com/apis/api/firebaseml.googleapis.com/overview) + and make sure your project is selected. + 2. If the API is not already enabled, click **Enable**. + + +6. Enable the IAM API: + 1. Go to the [Google Cloud console](https://console.cloud.google.com) + and make sure your Firebase project is selected. + 2. Select **APIs & Services** from the main menu, and click the + **ENABLE APIS AND SERVICES** button. + 3. Search for and enable **Identity and Access Management (IAM) API** by Google Enterprise API. + + +7. Enable Tenant Management: + 1. Go to + [Google Cloud console | Identity Platform](https://console.cloud.google.com/customer-identity/) + and if it is not already enabled, click **Enable**. + 2. Then + [enable multi-tenancy](https://cloud.google.com/identity-platform/docs/multi-tenancy-quickstart#enabling_multi-tenancy) + for your project. + + +8. Ensure your service account has the **Firebase Authentication Admin** role. This is required +to ensure that exported user records contain the password hashes of the user accounts: + 1. Go to [Google Cloud console | IAM & admin](https://console.cloud.google.com/iam-admin). + 2. Find your service account in the list. If not added click the pencil icon to edit its + permissions. + 3. Click **ADD ANOTHER ROLE** and choose **Firebase Authentication Admin**. + 4. Click **SAVE**. Now you can invoke the integration test suite as follows: From 3c391867ebddef0d06f57799aac40d8603ea114f Mon Sep 17 00:00:00 2001 From: Edwin Liu Date: Thu, 21 Dec 2023 02:40:34 +1100 Subject: [PATCH 157/226] Add missing return type for firebase_admin.firestore.client() (#747) firestore.client() should have return type google.cloud.firestore.Client Co-authored-by: Lahiru Maramba --- firebase_admin/firestore.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/firestore.py b/firebase_admin/firestore.py index 32c9897d5..224ba3aeb 100644 --- a/firebase_admin/firestore.py +++ b/firebase_admin/firestore.py @@ -34,7 +34,7 @@ _FIRESTORE_ATTRIBUTE = '_firestore' -def client(app=None): +def client(app=None) -> firestore.Client: """Returns a client that can be used to interact with Google Cloud Firestore. Args: From b2173da73631b1c635a45eebf4238864819b5f6e Mon Sep 17 00:00:00 2001 From: Marco Tomas Rodriguez Date: Wed, 3 Jan 2024 14:51:08 -0300 Subject: [PATCH 158/226] fix(fcm): Export `send_each` and `send_each_for_multicast` (#749) --- firebase_admin/messaging.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index 7e63933e1..d2ad04a04 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -65,6 +65,8 @@ 'send', 'send_all', 'send_multicast', + 'send_each', + 'send_each_for_multicast', 'subscribe_to_topic', 'unsubscribe_from_topic', ] From df94f8bf92a072369d49f98b177fc70aba528c22 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Wed, 3 Jan 2024 14:14:33 -0500 Subject: [PATCH 159/226] Update release and ci workflows to use `GITHUB_OUTPUT` and bump `actions/checkout` to v4. (#752) * Update release workflow to use GITHUB_OUTPUT. * Use actions/checkout@v4 for nightly builds --- .github/scripts/publish_preflight_check.sh | 12 ++++++------ .github/workflows/nightly.yml | 2 +- .github/workflows/release.yml | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/scripts/publish_preflight_check.sh b/.github/scripts/publish_preflight_check.sh index c962d8807..c787c8548 100755 --- a/.github/scripts/publish_preflight_check.sh +++ b/.github/scripts/publish_preflight_check.sh @@ -71,7 +71,7 @@ if [[ ! "${RELEASE_VERSION}" =~ ^([0-9]*)\.([0-9]*)\.([0-9]*)$ ]]; then fi echo_info "Extracted release version: ${RELEASE_VERSION}" -echo "::set-output name=version::v${RELEASE_VERSION}" +echo "version=v${RELEASE_VERSION}" >> $GITHUB_OUTPUT echo_info "" @@ -169,12 +169,12 @@ readonly CHANGELOG=`${CURRENT_DIR}/generate_changelog.sh` echo "$CHANGELOG" # Parse and preformat the text to handle multi-line output. -# See https://github.community/t5/GitHub-Actions/set-output-Truncates-Multiline-Strings/td-p/37870 +# See https://docs.github.com/en/actions/using-workflows/workflow-commands-for-github-actions#example-of-a-multiline-string +# and https://github.com/github/docs/issues/21529#issue-1418590935 FILTERED_CHANGELOG=`echo "$CHANGELOG" | grep -v "\\[INFO\\]"` -FILTERED_CHANGELOG="${FILTERED_CHANGELOG//'%'/'%25'}" -FILTERED_CHANGELOG="${FILTERED_CHANGELOG//$'\n'/'%0A'}" -FILTERED_CHANGELOG="${FILTERED_CHANGELOG//$'\r'/'%0D'}" -echo "::set-output name=changelog::${FILTERED_CHANGELOG}" +echo "changelog<> $GITHUB_OUTPUT +echo "$FILTERED_CHANGELOG" >> $GITHUB_OUTPUT +echo "CHANGELOGEOF" >> $GITHUB_OUTPUT echo "" diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index ac6c62abe..9dd0883ad 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -29,7 +29,7 @@ jobs: steps: - name: Checkout source for staging - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: ref: ${{ github.event.client_payload.ref || github.ref }} diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index e4c9b9fd0..f0835da6b 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -40,7 +40,7 @@ jobs: # via the 'ref' client parameter. steps: - name: Checkout source for staging - uses: actions/checkout@v2 + uses: actions/checkout@v4 with: ref: ${{ github.event.client_payload.ref || github.ref }} @@ -95,7 +95,7 @@ jobs: steps: - name: Checkout source for publish - uses: actions/checkout@v2 + uses: actions/checkout@v4 # Download the artifacts created by the stage_release job. - name: Download release candidates From 4f20371711db1d81cfc36cf032f289d36f78d946 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Tue, 9 Jan 2024 10:56:29 -0500 Subject: [PATCH 160/226] feat(fcm): Enabled the `direct_boot_ok` parameter for FCM Android Config. (#734) * feat(fcm): Enabled direct_boot_ok Android Config parameter. * Added tests. * fix: add to correct config. * fix: Validator label --- firebase_admin/_messaging_encoder.py | 11 +++++++++++ firebase_admin/_messaging_utils.py | 5 ++++- tests/test_messaging.py | 25 ++++++++++++++++++++++--- 3 files changed, 37 insertions(+), 4 deletions(-) diff --git a/firebase_admin/_messaging_encoder.py b/firebase_admin/_messaging_encoder.py index 48a3dd3cd..85072b597 100644 --- a/firebase_admin/_messaging_encoder.py +++ b/firebase_admin/_messaging_encoder.py @@ -160,6 +160,15 @@ def check_analytics_label(cls, label, value): raise ValueError('Malformed {}.'.format(label)) return value + @classmethod + def check_boolean(cls, label, value): + """Checks if the given value is boolean.""" + if value is None: + return None + if not isinstance(value, bool): + raise ValueError('{0} must be a boolean.'.format(label)) + return value + @classmethod def check_datetime(cls, label, value): """Checks if the given value is a datetime.""" @@ -196,6 +205,8 @@ def encode_android(cls, android): 'AndroidConfig.restricted_package_name', android.restricted_package_name), 'ttl': cls.encode_ttl(android.ttl), 'fcm_options': cls.encode_android_fcm_options(android.fcm_options), + 'direct_boot_ok': _Validators.check_boolean( + 'AndroidConfig.direct_boot_ok', android.direct_boot_ok), } result = cls.remove_null_values(result) priority = result.get('priority') diff --git a/firebase_admin/_messaging_utils.py b/firebase_admin/_messaging_utils.py index 64930f1b8..29b8276bc 100644 --- a/firebase_admin/_messaging_utils.py +++ b/firebase_admin/_messaging_utils.py @@ -49,10 +49,12 @@ class AndroidConfig: strings. When specified, overrides any data fields set via ``Message.data``. notification: A ``messaging.AndroidNotification`` to be included in the message (optional). fcm_options: A ``messaging.AndroidFCMOptions`` to be included in the message (optional). + direct_boot_ok: A boolean indicating whether messages will be allowed to be delivered to + the app while the device is in direct boot mode (optional). """ def __init__(self, collapse_key=None, priority=None, ttl=None, restricted_package_name=None, - data=None, notification=None, fcm_options=None): + data=None, notification=None, fcm_options=None, direct_boot_ok=None): self.collapse_key = collapse_key self.priority = priority self.ttl = ttl @@ -60,6 +62,7 @@ def __init__(self, collapse_key=None, priority=None, ttl=None, restricted_packag self.data = data self.notification = notification self.fcm_options = fcm_options + self.direct_boot_ok = direct_boot_ok class AndroidNotification: diff --git a/tests/test_messaging.py b/tests/test_messaging.py index 71bb13eed..5072df6ea 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -33,6 +33,7 @@ NON_OBJECT_ARGS = [list(), tuple(), dict(), 'foo', 0, 1, True, False] NON_LIST_ARGS = ['', tuple(), dict(), True, False, 1, 0, [1], ['foo', 1]] NON_UINT_ARGS = ['1.23s', list(), tuple(), dict(), -1.23] +NON_BOOL_ARGS = ['', list(), tuple(), dict(), 1, 0, [1], ['foo', 1], {1: 'foo'}, {'foo': 1}] HTTP_ERROR_CODES = { 400: exceptions.InvalidArgumentError, 403: exceptions.PermissionDeniedError, @@ -249,7 +250,8 @@ def test_fcm_options(self): topic='topic', fcm_options=messaging.FCMOptions('message-label'), android=messaging.AndroidConfig( - fcm_options=messaging.AndroidFCMOptions('android-label')), + fcm_options=messaging.AndroidFCMOptions('android-label'), + direct_boot_ok=False), apns=messaging.APNSConfig(fcm_options= messaging.APNSFCMOptions( analytics_label='apns-label', @@ -259,7 +261,8 @@ def test_fcm_options(self): { 'topic': 'topic', 'fcm_options': {'analytics_label': 'message-label'}, - 'android': {'fcm_options': {'analytics_label': 'android-label'}}, + 'android': {'fcm_options': {'analytics_label': 'android-label'}, + 'direct_boot_ok': False}, 'apns': {'fcm_options': {'analytics_label': 'apns-label', 'image': 'https://images.unsplash.com/photo-14944386399' '46-1ebd1d20bf85?fit=crop&w=900&q=60'}}, @@ -317,6 +320,20 @@ def test_invalid_data(self, data): check_encoding(messaging.Message( topic='topic', android=messaging.AndroidConfig(data=data))) + @pytest.mark.parametrize('data', NON_STRING_ARGS) + def test_invalid_analytics_label(self, data): + with pytest.raises(ValueError): + check_encoding(messaging.Message( + topic='topic', android=messaging.AndroidConfig( + fcm_options=messaging.AndroidFCMOptions(analytics_label=data)))) + + @pytest.mark.parametrize('data', NON_BOOL_ARGS) + def test_invalid_direct_boot_ok(self, data): + with pytest.raises(ValueError): + check_encoding(messaging.Message( + topic='topic', android=messaging.AndroidConfig(direct_boot_ok=data))) + + def test_android_config(self): msg = messaging.Message( topic='topic', @@ -326,7 +343,8 @@ def test_android_config(self): priority='high', ttl=123, data={'k1': 'v1', 'k2': 'v2'}, - fcm_options=messaging.AndroidFCMOptions('analytics_label_v1') + fcm_options=messaging.AndroidFCMOptions('analytics_label_v1'), + direct_boot_ok=True, ) ) expected = { @@ -343,6 +361,7 @@ def test_android_config(self): 'fcm_options': { 'analytics_label': 'analytics_label_v1', }, + 'direct_boot_ok': True, }, } check_encoding(msg, expected) From 2252f17539067333bebff6f3ba1368e19fcc3ca4 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Tue, 16 Jan 2024 14:33:45 -0500 Subject: [PATCH 161/226] chore: Add python 3.11 and 3.12 to CI tests (#754) * Add python 3.11 and 3.12 to CI tests * fix: typo * added new versions to classifier list * Trigger CI tests --- .github/workflows/ci.yml | 2 +- setup.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4829256eb..2ff59ec77 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,7 +8,7 @@ jobs: strategy: fail-fast: false matrix: - python: ['3.7', '3.8', '3.9', '3.10', 'pypy3.8'] + python: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12', 'pypy3.8'] steps: - uses: actions/checkout@v4 diff --git a/setup.py b/setup.py index a82bc47f3..ef30e6be6 100644 --- a/setup.py +++ b/setup.py @@ -70,6 +70,8 @@ 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', 'License :: OSI Approved :: Apache Software License', ], ) From 2d64228bb6204460d05fc67f2ea9d9ed562fc264 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Tue, 16 Jan 2024 15:00:38 -0500 Subject: [PATCH 162/226] Use GitHub CLI in publish workflow. (#753) --- .github/workflows/release.yml | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index f0835da6b..a6ef19c9e 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -107,19 +107,13 @@ jobs: id: preflight run: ./.github/scripts/publish_preflight_check.sh - # We pull this action from a custom fork of a contributor until - # https://github.com/actions/create-release/pull/32 is merged. Also note that v1 of - # this action does not support the "body" parameter. + # See: https://cli.github.com/manual/gh_release_create - name: Create release tag - uses: fleskesvor/create-release@1a72e235c178bf2ae6c51a8ae36febc24568c5fe env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - tag_name: ${{ steps.preflight.outputs.version }} - release_name: Firebase Admin Python SDK ${{ steps.preflight.outputs.version }} - body: ${{ steps.preflight.outputs.changelog }} - draft: false - prerelease: false + run: gh release create ${{ steps.preflight.outputs.version }} + --title "Firebase Admin Python SDK ${{ steps.preflight.outputs.version }}" + --notes "${{ steps.preflight.outputs.changelog }}" - name: Publish to Pypi uses: pypa/gh-action-pypi-publish@v1.0.0a0 From c988d2fa395aeb43174eaf7956a698836b2e1abe Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Thu, 18 Jan 2024 12:10:30 -0500 Subject: [PATCH 163/226] [chore] Release 6.4.0 (#756) * [chore] Release 6.4.0 * Trigger integration tests * Trigger integration tests again --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index 091fbd205..7ce5b6f79 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '6.3.0' +__version__ = '6.4.0' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 23765843a124065dcb8fd5a86a3d8f3b784df34c Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Thu, 18 Jan 2024 13:12:44 -0500 Subject: [PATCH 164/226] Revert "[chore] Release 6.4.0" (#757) --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index 7ce5b6f79..091fbd205 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '6.4.0' +__version__ = '6.3.0' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 9c1f5b564e4c5d5fe276d62d3ddeb460f7e573e6 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Mon, 22 Jan 2024 16:15:43 -0500 Subject: [PATCH 165/226] chore: Update PyPi to use trusted publisher for authentication and correctly escape change log body. (#759) * fix: Escape release tag body and change PyPi to use trusted publisher for authentication. * fix typo --- .github/scripts/publish_preflight_check.sh | 2 +- .github/workflows/release.yml | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/.github/scripts/publish_preflight_check.sh b/.github/scripts/publish_preflight_check.sh index c787c8548..c5e231690 100755 --- a/.github/scripts/publish_preflight_check.sh +++ b/.github/scripts/publish_preflight_check.sh @@ -173,7 +173,7 @@ echo "$CHANGELOG" # and https://github.com/github/docs/issues/21529#issue-1418590935 FILTERED_CHANGELOG=`echo "$CHANGELOG" | grep -v "\\[INFO\\]"` echo "changelog<> $GITHUB_OUTPUT -echo "$FILTERED_CHANGELOG" >> $GITHUB_OUTPUT +echo -e "$FILTERED_CHANGELOG" >> $GITHUB_OUTPUT echo "CHANGELOGEOF" >> $GITHUB_OUTPUT diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index a6ef19c9e..60cd9f457 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -92,6 +92,11 @@ jobs: startsWith(github.event.pull_request.title, '[chore] Release ') runs-on: ubuntu-latest + permissions: + # Used to create a short-lived OIDC token which is given to PyPi to identify this workflow job + # See: https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/about-security-hardening-with-openid-connect#adding-permissions-settings + # and https://docs.pypi.org/trusted-publishers/using-a-publisher/ + id-token: write steps: - name: Checkout source for publish @@ -116,10 +121,7 @@ jobs: --notes "${{ steps.preflight.outputs.changelog }}" - name: Publish to Pypi - uses: pypa/gh-action-pypi-publish@v1.0.0a0 - with: - user: firebase - password: ${{ secrets.PYPI_PASSWORD }} + uses: pypa/gh-action-pypi-publish@release/v1 # Post to Twitter if explicitly opted-in by adding the label 'release:tweet'. - name: Post to Twitter From ee5cb242e378592d23d9c94f18d47c3efb63b606 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Tue, 23 Jan 2024 11:04:55 -0500 Subject: [PATCH 166/226] [chore] Release 6.4.0 (#760) --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index 091fbd205..7ce5b6f79 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '6.3.0' +__version__ = '6.4.0' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 451880f8d95f8a10a521520218f279d553f1e518 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Tue, 23 Jan 2024 14:22:39 -0500 Subject: [PATCH 167/226] [chore] Release 6.4.0 Take #2 (#762) --- .github/scripts/publish_preflight_check.sh | 1 + .github/workflows/release.yml | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/scripts/publish_preflight_check.sh b/.github/scripts/publish_preflight_check.sh index c5e231690..1d001c3b9 100755 --- a/.github/scripts/publish_preflight_check.sh +++ b/.github/scripts/publish_preflight_check.sh @@ -172,6 +172,7 @@ echo "$CHANGELOG" # See https://docs.github.com/en/actions/using-workflows/workflow-commands-for-github-actions#example-of-a-multiline-string # and https://github.com/github/docs/issues/21529#issue-1418590935 FILTERED_CHANGELOG=`echo "$CHANGELOG" | grep -v "\\[INFO\\]"` +FILTERED_CHANGELOG="${FILTERED_CHANGELOG//$'\''/'"'}" echo "changelog<> $GITHUB_OUTPUT echo -e "$FILTERED_CHANGELOG" >> $GITHUB_OUTPUT echo "CHANGELOGEOF" >> $GITHUB_OUTPUT diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 60cd9f457..adfde4886 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -118,7 +118,7 @@ jobs: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: gh release create ${{ steps.preflight.outputs.version }} --title "Firebase Admin Python SDK ${{ steps.preflight.outputs.version }}" - --notes "${{ steps.preflight.outputs.changelog }}" + --notes '${{ steps.preflight.outputs.changelog }}' - name: Publish to Pypi uses: pypa/gh-action-pypi-publish@release/v1 From b992604fb831f19a91430bc699946323e86fb344 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Tue, 23 Jan 2024 17:12:24 -0500 Subject: [PATCH 168/226] [chore] Release 6.4.0 Take 3 (#763) --- .github/workflows/release.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index adfde4886..7b57582d3 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -97,6 +97,7 @@ jobs: # See: https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/about-security-hardening-with-openid-connect#adding-permissions-settings # and https://docs.pypi.org/trusted-publishers/using-a-publisher/ id-token: write + contents: write steps: - name: Checkout source for publish From a7ac17a33c86bd8dd0a80e2d9a978b96ad5249f6 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Thu, 1 Feb 2024 15:14:19 -0500 Subject: [PATCH 169/226] [chore] Upgrade `actions/setup-python` to v5 (#765) --- .github/workflows/ci.yml | 4 ++-- .github/workflows/nightly.yml | 2 +- .github/workflows/release.yml | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2ff59ec77..00a01a908 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -13,7 +13,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python }} - name: Install dependencies @@ -36,7 +36,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up Python 3.7 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: 3.7 - name: Install dependencies diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 9dd0883ad..0fe418cf7 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -34,7 +34,7 @@ jobs: ref: ${{ github.event.client_payload.ref || github.ref }} - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: 3.7 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7b57582d3..00e1267c8 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -45,7 +45,7 @@ jobs: ref: ${{ github.event.client_payload.ref || github.ref }} - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: 3.7 From 0752992e3f3a186a47a1bb935b14bff392894bd6 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Tue, 13 Feb 2024 15:53:23 -0500 Subject: [PATCH 170/226] [chore] Rename pytest teardown methods from deprecated `teardown` to `teardown_method` (#768) --- tests/test_messaging.py | 2 +- tests/test_token_gen.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_messaging.py b/tests/test_messaging.py index 5072df6ea..d482438f5 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -1586,7 +1586,7 @@ def test_aps_alert_custom_data_override(self): class TestTimeout: - def teardown(self): + def teardown_method(self): testutils.cleanup_apps() def _instrument_service(self, url, response): diff --git a/tests/test_token_gen.py b/tests/test_token_gen.py index 64540f26f..536a5ec91 100644 --- a/tests/test_token_gen.py +++ b/tests/test_token_gen.py @@ -853,5 +853,5 @@ def _instrument_session(self, app): request.session.mount('https://', testutils.MockAdapter(MOCK_PUBLIC_CERTS, 200, recorder)) return recorder - def teardown(self): + def teardown_method(self): testutils.cleanup_apps() From 8bcc751ad52b43fbdd1c098aedb2b0029addf6aa Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Tue, 27 Feb 2024 15:39:23 -0500 Subject: [PATCH 171/226] feat(functions): Add task queue API support (#770) * feat(functions): Add task queue API support (#751) * Draft implementation of task queue * fix lint * Error handling, code review fixes and typos * feat(functions): Add unit and integration tests for task queue api support (#764) * Unit and Integration tests for task queues. * fix: copyright year * fix: remove commented code * feat(functions): Added `uri` task option and additional task queue test coverage (#767) * feat(functions): Add task queue API support (#751) * Draft implementation of task queue * fix lint * Error handling, code review fixes and typos * feat(functions): Add unit and integration tests for task queue api support (#764) * Unit and Integration tests for task queues. * fix: copyright year * fix: remove commented code * feat(functions): Added `uri` task option and additional task queue test coverage * Removed uri and add doc strings * fix removed typo * re-add missing uri changes * fix missing check * fix: TW requested changes * fix: Added extra note for full list of replaced headers and undo Content-Type change --- firebase_admin/functions.py | 437 ++++++++++++++++++++++++++++++++++ integration/test_functions.py | 56 +++++ tests/test_functions.py | 301 +++++++++++++++++++++++ tests/testutils.py | 19 +- 4 files changed, 812 insertions(+), 1 deletion(-) create mode 100644 firebase_admin/functions.py create mode 100644 integration/test_functions.py create mode 100644 tests/test_functions.py diff --git a/firebase_admin/functions.py b/firebase_admin/functions.py new file mode 100644 index 000000000..b39ee0a66 --- /dev/null +++ b/firebase_admin/functions.py @@ -0,0 +1,437 @@ +# Copyright 2024 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Firebase Functions module.""" + +from __future__ import annotations +from datetime import datetime, timedelta +from urllib import parse +import re +import json +from base64 import b64encode +from typing import Any, Optional, Dict +from dataclasses import dataclass +from google.auth.compute_engine import Credentials as ComputeEngineCredentials + +import requests +import firebase_admin +from firebase_admin import App +from firebase_admin import _http_client +from firebase_admin import _utils + +_FUNCTIONS_ATTRIBUTE = '_functions' + +__all__ = [ + 'TaskOptions', + + 'task_queue', +] + + +_CLOUD_TASKS_API_RESOURCE_PATH = \ + 'projects/{project_id}/locations/{location_id}/queues/{resource_id}/tasks' +_CLOUD_TASKS_API_URL_FORMAT = \ + 'https://cloudtasks.googleapis.com/v2/' + _CLOUD_TASKS_API_RESOURCE_PATH +_FIREBASE_FUNCTION_URL_FORMAT = \ + 'https://{location_id}-{project_id}.cloudfunctions.net/{resource_id}' + +_FUNCTIONS_HEADERS = { + 'X-GOOG-API-FORMAT-VERSION': '2', + 'X-FIREBASE-CLIENT': 'fire-admin-python/{0}'.format(firebase_admin.__version__), +} + +# Default canonical location ID of the task queue. +_DEFAULT_LOCATION = 'us-central1' + +def _get_functions_service(app) -> _FunctionsService: + return _utils.get_app_service(app, _FUNCTIONS_ATTRIBUTE, _FunctionsService) + +def task_queue( + function_name: str, + extension_id: Optional[str] = None, + app: Optional[App] = None + ) -> TaskQueue: + """Creates a reference to a TaskQueue for a given function name. + + The function name can be either: + 1. A fully qualified function resource name: + `projects/{project-id}/locations/{location-id}/functions/{function-name}` + + 2. A partial resource name with location and function name, in which case + the runtime project ID is used: + `locations/{location-id}/functions/{function-name}` + + 3. A partial function name, in which case the runtime project ID and the + default location, `us-central1`, is used: + `{function-name}` + + Args: + function_name: Name of the function. + extension_id: Firebase extension ID (optional). + app: An App instance (optional). + + Returns: + TaskQueue: A TaskQueue instance. + + Raises: + ValueError: If the input arguments are invalid. + """ + return _get_functions_service(app).task_queue(function_name, extension_id) + +class _FunctionsService: + """Service class that implements Firebase Functions functionality.""" + def __init__(self, app: App): + self._project_id = app.project_id + if not self._project_id: + raise ValueError( + 'Project ID is required to access the Cloud Functions service. Either set the ' + 'projectId option, or use service account credentials. Alternatively, set the ' + 'GOOGLE_CLOUD_PROJECT environment variable.') + + self._credential = app.credential.get_credential() + self._http_client = _http_client.JsonHttpClient(credential=self._credential) + + def task_queue(self, function_name: str, extension_id: Optional[str] = None) -> TaskQueue: + """Creates a TaskQueue instance.""" + return TaskQueue( + function_name, extension_id, self._project_id, self._credential, self._http_client) + + @classmethod + def handle_functions_error(cls, error: Any): + """Handles errors received from the Cloud Functions API.""" + + return _utils.handle_platform_error_from_requests(error) + +class TaskQueue: + """TaskQueue class that implements Firebase Cloud Tasks Queues functionality.""" + def __init__( + self, + function_name: str, + extension_id: Optional[str], + project_id, + credential, + http_client + ) -> None: + + # Validate function_name + _Validators.check_non_empty_string('function_name', function_name) + + self._project_id = project_id + self._credential = credential + self._http_client = http_client + self._function_name = function_name + self._extension_id = extension_id + # Parse resources from function_name + self._resource = self._parse_resource_name(self._function_name, 'functions') + + # Apply defaults and validate resource_id + self._resource.project_id = self._resource.project_id or self._project_id + self._resource.location_id = self._resource.location_id or _DEFAULT_LOCATION + _Validators.check_non_empty_string('resource.resource_id', self._resource.resource_id) + # Validate extension_id if provided and edit resources depending + if self._extension_id is not None: + _Validators.check_non_empty_string('extension_id', self._extension_id) + self._resource.resource_id = f'ext-{self._extension_id}-{self._resource.resource_id}' + + + def enqueue(self, task_data: Any, opts: Optional[TaskOptions] = None) -> str: + """Creates a task and adds it to the queue. Tasks cannot be updated after creation. + + This action requires `cloudtasks.tasks.create` IAM permission on the service account. + + Args: + task_data: The data payload of the task. + opts: Options when enqueuing a new task (optional). + + Raises: + FirebaseError: If an error occurs while requesting the task to be queued by + the Cloud Functions service. + ValueError: If the input arguments are invalid. + + Returns: + str: The ID of the task relative to this queue. + """ + task = self._validate_task_options(task_data, self._resource, opts) + service_url = self._get_url(self._resource, _CLOUD_TASKS_API_URL_FORMAT) + task_payload = self._update_task_payload(task, self._resource, self._extension_id) + try: + resp = self._http_client.body( + 'post', + url=service_url, + headers=_FUNCTIONS_HEADERS, + json={'task': task_payload.__dict__} + ) + task_name = resp.get('name', None) + task_resource = \ + self._parse_resource_name(task_name, f'queues/{self._resource.resource_id}/tasks') + return task_resource.resource_id + except requests.exceptions.RequestException as error: + raise _FunctionsService.handle_functions_error(error) + + def delete(self, task_id: str) -> None: + """Deletes an enqueued task if it has not yet started. + + This action requires `cloudtasks.tasks.delete` IAM permission on the service account. + + Args: + task_id: The ID of the task relative to this queue. + + Raises: + FirebaseError: If an error occurs while requesting the task to be deleted by + the Cloud Functions service. + ValueError: If the input arguments are invalid. + """ + _Validators.check_non_empty_string('task_id', task_id) + service_url = self._get_url(self._resource, _CLOUD_TASKS_API_URL_FORMAT + f'/{task_id}') + try: + self._http_client.body( + 'delete', + url=service_url, + headers=_FUNCTIONS_HEADERS, + ) + except requests.exceptions.RequestException as error: + raise _FunctionsService.handle_functions_error(error) + + + def _parse_resource_name(self, resource_name: str, resource_id_key: str) -> Resource: + """Parses a full or partial resource path into a ``Resource``.""" + if '/' not in resource_name: + return Resource(resource_id=resource_name) + + reg = f'^(projects/([^/]+)/)?locations/([^/]+)/{resource_id_key}/([^/]+)$' + match = re.search(reg, resource_name) + if match is None: + raise ValueError('Invalid resource name format.') + return Resource(project_id=match[2], location_id=match[3], resource_id=match[4]) + + def _get_url(self, resource: Resource, url_format: str) -> str: + """Generates url path from a ``Resource`` and url format string.""" + return url_format.format( + project_id=resource.project_id, + location_id=resource.location_id, + resource_id=resource.resource_id) + + def _validate_task_options( + self, + data: Any, + resource: Resource, + opts: Optional[TaskOptions] = None + ) -> Task: + """Validate and create a Task from optional ``TaskOptions``.""" + task_http_request = { + 'url': '', + 'oidc_token': { + 'service_account_email': '' + }, + 'body': b64encode(json.dumps(data).encode()).decode(), + 'headers': { + 'Content-Type': 'application/json', + } + } + task = Task(http_request=task_http_request) + + if opts is not None: + if opts.headers is not None: + task.http_request['headers'] = {**task.http_request['headers'], **opts.headers} + if opts.schedule_time is not None and opts.schedule_delay_seconds is not None: + raise ValueError( + 'Both sechdule_delay_seconds and schedule_time cannot be set at the same time.') + if opts.schedule_time is not None and opts.schedule_delay_seconds is None: + if not isinstance(opts.schedule_time, datetime): + raise ValueError('schedule_time should be UTC datetime.') + task.schedule_time = opts.schedule_time.strftime('%Y-%m-%dT%H:%M:%S.%fZ') + if opts.schedule_delay_seconds is not None and opts.schedule_time is None: + if not isinstance(opts.schedule_delay_seconds, int) \ + or opts.schedule_delay_seconds < 0: + raise ValueError('schedule_delay_seconds should be positive int.') + schedule_time = datetime.utcnow() + timedelta(seconds=opts.schedule_delay_seconds) + task.schedule_time = schedule_time.strftime('%Y-%m-%dT%H:%M:%S.%fZ') + if opts.dispatch_deadline_seconds is not None: + if not isinstance(opts.dispatch_deadline_seconds, int) \ + or opts.dispatch_deadline_seconds < 15 \ + or opts.dispatch_deadline_seconds > 1800: + raise ValueError( + 'dispatch_deadline_seconds should be int in the range of 15s to ' + '1800s (30 mins).') + task.dispatch_deadline = f'{opts.dispatch_deadline_seconds}s' + if opts.task_id is not None: + if not _Validators.is_task_id(opts.task_id): + raise ValueError( + 'task_id can contain only letters ([A-Za-z]), numbers ([0-9]), hyphens (-)' + ', or underscores (_). The maximum length is 500 characters.') + task.name = self._get_url( + resource, _CLOUD_TASKS_API_RESOURCE_PATH + f'/{opts.task_id}') + if opts.uri is not None: + if not _Validators.is_url(opts.uri): + raise ValueError( + 'uri must be a valid RFC3986 URI string using the https or http schema.') + task.http_request['url'] = opts.uri + return task + + def _update_task_payload(self, task: Task, resource: Resource, extension_id: str) -> Task: + """Prepares task to be sent with credentials.""" + # Get function url from task or generate from resources + if not _Validators.is_non_empty_string(task.http_request['url']): + task.http_request['url'] = self._get_url(resource, _FIREBASE_FUNCTION_URL_FORMAT) + # If extension id is provided, it emplies that it is being run from a deployed extension. + # Meaning that it's credential should be a Compute Engine Credential. + if _Validators.is_non_empty_string(extension_id) and \ + isinstance(self._credential, ComputeEngineCredentials): + + id_token = self._credential.token + task.http_request['headers'] = \ + {**task.http_request['headers'], 'Authorization': f'Bearer ${id_token}'} + # Delete oidc token + del task.http_request['oidc_token'] + else: + task.http_request['oidc_token'] = \ + {'service_account_email': self._credential.service_account_email} + return task + + +class _Validators: + """A collection of data validation utilities.""" + @classmethod + def check_non_empty_string(cls, label: str, value: Any): + """Checks if given value is a non-empty string and throws error if not.""" + if not isinstance(value, str): + raise ValueError('{0} "{1}" must be a string.'.format(label, value)) + if value == '': + raise ValueError('{0} "{1}" must be a non-empty string.'.format(label, value)) + + @classmethod + def is_non_empty_string(cls, value: Any): + """Checks if given value is a non-empty string and returns bool.""" + if not isinstance(value, str) or value == '': + return False + return True + + @classmethod + def is_task_id(cls, task_id: Any): + """Checks if given value is a valid task id.""" + reg = '^[A-Za-z0-9_-]+$' + if re.match(reg, task_id) is not None and len(task_id) <= 500: + return True + return False + + @classmethod + def is_url(cls, url: Any): + """Checks if given value is a valid url.""" + if not isinstance(url, str): + return False + try: + parsed = parse.urlparse(url) + if not parsed.netloc or parsed.scheme not in ['http', 'https']: + return False + return True + except Exception: # pylint: disable=broad-except + return False + + +@dataclass +class TaskOptions: + """Task Options that can be applied to a Task. + Args: + schedule_delay_seconds: The number of seconds after the current time at which to attempt or + retry the task. Should only be set if ``schedule_time`` is not set. + + schedule_time: The time when the task is scheduled to be attempted or retried. Should only + be set if ``schedule_delay_seconds`` is not set. + + dispatch_deadline_seconds: The deadline for requests sent to the worker. If the worker does + not respond by this deadline then the request is cancelled and the attempt is marked as + a ``DEADLINE_EXCEEDED`` failure. Cloud Tasks will retry the task according to the + ``RetryConfig``. The default is 10 minutes. The deadline must be in the range of 15 + seconds and 30 minutes (1800 seconds). + + task_id: The ID to use for the enqueued task. If not provided, one will be automatically + generated. + + If provided, an explicitly specified task ID enables task de-duplication. + Task IDs should be strings that contain only letters ([A-Za-z]), numbers ([0-9]), + hyphens (-), and underscores (_) with a maximum length of 500 characters. If a task's + ID is identical to that of an existing task or a task that was deleted or executed + recently then the call will throw an error with code "functions/task-already-exists". + Another task with the same ID can't be created for ~1hour after the original task was + deleted or executed. + + Because there is an extra lookup cost to identify duplicate task IDs, setting ID + significantly increases latency. + + Also, note that the infrastructure relies on an approximately uniform distribution + of task IDs to store and serve tasks efficiently. For this reason, using hashed strings + for the task ID or for the prefix of the task ID is recommended. Choosing task IDs that + are sequential or have sequential prefixes, for example using a timestamp, causes an + increase in latency and error rates in all task commands. + + "Push IDs" from the Firebase Realtime Database make poor IDs because they are based on + timestamps and will cause contention (slowdowns) in your task queue. Reversed push IDs + however form a perfect distribution and are an ideal key. To reverse a string in Python + use ``reversedString = someString[::-1]`` + + headers: HTTP request headers to include in the request to the task queue function. These + headers represent a subset of the headers that will accompany the task's HTTP request. + Some HTTP request headers will be ignored or replaced: `Authorization`, `Host`, + `Content-Length`, `User-Agent` and others cannot be overridden. + + A complete list of these ignored or replaced headers can be found in the following + definition of the HttpRequest.headers property: + https://cloud.google.com/tasks/docs/reference/rest/v2/projects.locations.queues.tasks#httprequest + + By default, Content-Type is set to 'application/json'. + + The size of the headers must be less than 80KB. + + uri: The full URL that the request will be sent to. Must be a valid RFC3986 https or + http URL. + """ + schedule_delay_seconds: Optional[int] = None + schedule_time: Optional[datetime] = None + dispatch_deadline_seconds: Optional[int] = None + task_id: Optional[str] = None + headers: Optional[Dict[str, str]] = None + uri: Optional[str] = None + +@dataclass +class Task: + """Contains the relevant fields for enqueueing tasks that trigger Cloud Functions. + + This is a limited subset of the Cloud Functions `Task` resource. See the following + page for definitions of this class's properties: + https://cloud.google.com/tasks/docs/reference/rest/v2/projects.locations.queues.tasks#resource:-task + + Args: + httpRequest: The request to be made by the task worker. + name: The name of the function. See the Cloud docs for the format of this property. + schedule_time: The time when the task is scheduled to be attempted or retried. + dispatch_deadline: The deadline for requests sent to the worker. + """ + http_request: Dict[str, Optional[str | dict]] + name: Optional[str] = None + schedule_time: Optional[str] = None + dispatch_deadline: Optional[str] = None + + +@dataclass +class Resource: + """Contains the parsed address of a resource. + + Args: + resource_id: The ID of the resource. + project_id: The project ID of the resource. + location_id: The location ID of the resource. + """ + resource_id: str + project_id: Optional[str] = None + location_id: Optional[str] = None diff --git a/integration/test_functions.py b/integration/test_functions.py new file mode 100644 index 000000000..606798436 --- /dev/null +++ b/integration/test_functions.py @@ -0,0 +1,56 @@ +# Copyright 2024 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for firebase_admin.functions module.""" + +import pytest + +import firebase_admin +from firebase_admin import functions +from integration import conftest + + +@pytest.fixture(scope='module') +def app(request): + cred, _ = conftest.integration_conf(request) + return firebase_admin.initialize_app(cred, name='integration-functions') + + +class TestFunctions: + + _TEST_FUNCTIONS_PARAMS = [ + {'function_name': 'function-name'}, + {'function_name': 'projects/test-project/locations/test-location/functions/function-name'}, + {'function_name': 'function-name', 'extension_id': 'extension-id'}, + { + 'function_name': \ + 'projects/test-project/locations/test-location/functions/function-name', + 'extension_id': 'extension-id' + } + ] + + @pytest.mark.parametrize('task_queue_params', _TEST_FUNCTIONS_PARAMS) + def test_task_queue(self, task_queue_params): + queue = functions.task_queue(**task_queue_params) + assert queue is not None + assert callable(queue.enqueue) + assert callable(queue.delete) + + @pytest.mark.parametrize('task_queue_params', _TEST_FUNCTIONS_PARAMS) + def test_task_queue_app(self, task_queue_params, app): + assert app.name == 'integration-functions' + queue = functions.task_queue(**task_queue_params, app=app) + assert queue is not None + assert callable(queue.enqueue) + assert callable(queue.delete) diff --git a/tests/test_functions.py b/tests/test_functions.py new file mode 100644 index 000000000..75809c1ad --- /dev/null +++ b/tests/test_functions.py @@ -0,0 +1,301 @@ +# Copyright 2024 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test cases for the firebase_admin.functions module.""" + +from datetime import datetime, timedelta +import json +import time +import pytest + +import firebase_admin +from firebase_admin import functions +from tests import testutils + + +_DEFAULT_DATA = {'city': 'Seattle'} +_CLOUD_TASKS_URL = 'https://cloudtasks.googleapis.com/v2/' +_DEFAULT_TASK_PATH = \ + 'projects/test-project/locations/us-central1/queues/test-function-name/tasks/test-task-id' +_DEFAULT_REQUEST_URL = \ + _CLOUD_TASKS_URL + 'projects/test-project/locations/us-central1/queues/test-function-name/tasks' +_DEFAULT_TASK_URL = _CLOUD_TASKS_URL + _DEFAULT_TASK_PATH +_DEFAULT_RESPONSE = json.dumps({'name': _DEFAULT_TASK_PATH}) +_ENQUEUE_TIME = datetime.utcnow() +_SCHEDULE_TIME = _ENQUEUE_TIME + timedelta(seconds=100) + +class TestTaskQueue: + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': 'test-project'}) + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + def _instrument_functions_service(self, app=None, status=200, payload=_DEFAULT_RESPONSE): + if not app: + app = firebase_admin.get_app() + functions_service = functions._get_functions_service(app) + recorder = [] + functions_service._http_client.session.mount( + _CLOUD_TASKS_URL, + testutils.MockAdapter(payload, status, recorder)) + return functions_service, recorder + + def test_task_queue_no_project_id(self): + def evaluate(): + app = firebase_admin.initialize_app(testutils.MockCredential(), name='no-project-id') + with pytest.raises(ValueError): + functions.task_queue('test-function-name', app=app) + testutils.run_without_project_id(evaluate) + + @pytest.mark.parametrize('function_name', [ + 'projects/test-project/locations/us-central1/functions/test-function-name', + 'locations/us-central1/functions/test-function-name', + 'test-function-name', + ]) + def test_task_queue_function_name(self, function_name): + queue = functions.task_queue(function_name) + assert queue._resource.resource_id == 'test-function-name' + assert queue._resource.project_id == 'test-project' + assert queue._resource.location_id == 'us-central1' + + def test_task_queue_empty_function_name_error(self): + with pytest.raises(ValueError) as excinfo: + functions.task_queue('') + assert str(excinfo.value) == 'function_name "" must be a non-empty string.' + + def test_task_queue_non_string_function_name_error(self): + with pytest.raises(ValueError) as excinfo: + functions.task_queue(1234) + assert str(excinfo.value) == 'function_name "1234" must be a string.' + + @pytest.mark.parametrize('function_name', [ + '/test', + 'test/', + 'test-project/us-central1/test-function-name', + 'projects/test-project/functions/test-function-name', + 'functions/test-function-name', + ]) + def test_task_queue_invalid_function_name_error(self, function_name): + with pytest.raises(ValueError) as excinfo: + functions.task_queue(function_name) + assert str(excinfo.value) == 'Invalid resource name format.' + + def test_task_queue_extension_id(self): + queue = functions.task_queue("test-function-name", "test-extension-id") + assert queue._resource.resource_id == 'ext-test-extension-id-test-function-name' + assert queue._resource.project_id == 'test-project' + assert queue._resource.location_id == 'us-central1' + + def test_task_queue_empty_extension_id_error(self): + with pytest.raises(ValueError) as excinfo: + functions.task_queue('test-function-name', '') + assert str(excinfo.value) == 'extension_id "" must be a non-empty string.' + + def test_task_queue_non_string_extension_id_error(self): + with pytest.raises(ValueError) as excinfo: + functions.task_queue('test-function-name', 1234) + assert str(excinfo.value) == 'extension_id "1234" must be a string.' + + + def test_task_enqueue(self): + _, recorder = self._instrument_functions_service() + queue = functions.task_queue('test-function-name') + task_id = queue.enqueue(_DEFAULT_DATA) + assert len(recorder) == 1 + assert recorder[0].method == 'POST' + assert recorder[0].url == _DEFAULT_REQUEST_URL + assert recorder[0].headers['Content-Type'] == 'application/json' + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + assert task_id == 'test-task-id' + + def test_task_enqueue_with_extension(self): + resource_name = ( + 'projects/test-project/locations/us-central1/queues/' + 'ext-test-extension-id-test-function-name/tasks' + ) + extension_response = json.dumps({'name': resource_name + '/test-task-id'}) + _, recorder = self._instrument_functions_service(payload=extension_response) + queue = functions.task_queue('test-function-name', 'test-extension-id') + task_id = queue.enqueue(_DEFAULT_DATA) + assert len(recorder) == 1 + assert recorder[0].method == 'POST' + assert recorder[0].url == _CLOUD_TASKS_URL + resource_name + assert recorder[0].headers['Content-Type'] == 'application/json' + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + assert task_id == 'test-task-id' + + def test_task_delete(self): + _, recorder = self._instrument_functions_service() + queue = functions.task_queue('test-function-name') + queue.delete('test-task-id') + assert len(recorder) == 1 + assert recorder[0].method == 'DELETE' + assert recorder[0].url == _DEFAULT_TASK_URL + + +class TestTaskQueueOptions: + + _DEFAULT_TASK_OPTS = {'schedule_delay_seconds': None, 'schedule_time': None, \ + 'dispatch_deadline_seconds': None, 'task_id': None, 'headers': None} + + non_alphanumeric_chars = [ + ',', '.', '?', '!', ':', ';', "'", '"', '(', ')', '[', ']', '{', '}', + '@', '&', '*', '+', '=', '$', '%', '#', '~', '\\', '/', '|', '^', + '\t', '\n', '\r', '\f', '\v', '\0', '\a', '\b', + 'é', 'ç', 'ö', '❤️', '€', '¥', '£', '←', '→', '↑', '↓', 'π', 'Ω', 'ß' + ] + + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': 'test-project'}) + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + def _instrument_functions_service(self, app=None, status=200, payload=_DEFAULT_RESPONSE): + if not app: + app = firebase_admin.get_app() + functions_service = functions._get_functions_service(app) + recorder = [] + functions_service._http_client.session.mount( + _CLOUD_TASKS_URL, + testutils.MockAdapter(payload, status, recorder)) + return functions_service, recorder + + + @pytest.mark.parametrize('task_opts_params', [ + { + 'schedule_delay_seconds': 100, + 'schedule_time': None, + 'dispatch_deadline_seconds': 200, + 'task_id': 'test-task-id', + 'headers': {'x-test-header': 'test-header-value'}, + 'uri': 'https://google.com' + }, + { + 'schedule_delay_seconds': None, + 'schedule_time': _SCHEDULE_TIME, + 'dispatch_deadline_seconds': 200, + 'task_id': 'test-task-id', + 'headers': {'x-test-header': 'test-header-value'}, + 'uri': 'http://google.com' + }, + ]) + def test_task_options(self, task_opts_params): + _, recorder = self._instrument_functions_service() + queue = functions.task_queue('test-function-name') + task_opts = functions.TaskOptions(**task_opts_params) + queue.enqueue(_DEFAULT_DATA, task_opts) + + assert len(recorder) == 1 + task = json.loads(recorder[0].body.decode())['task'] + + schedule_time = datetime.fromisoformat(task['schedule_time'][:-1]) + delta = abs(schedule_time - _SCHEDULE_TIME) + assert delta <= timedelta(seconds=15) + + assert task['dispatch_deadline'] == '200s' + assert task['http_request']['headers']['x-test-header'] == 'test-header-value' + assert task['http_request']['url'] in ['http://google.com', 'https://google.com'] + assert task['name'] == _DEFAULT_TASK_PATH + + + def test_schedule_set_twice_error(self): + _, recorder = self._instrument_functions_service() + opts = functions.TaskOptions(schedule_delay_seconds=100, schedule_time=datetime.utcnow()) + queue = functions.task_queue('test-function-name') + with pytest.raises(ValueError) as excinfo: + queue.enqueue(_DEFAULT_DATA, opts) + assert len(recorder) == 0 + assert str(excinfo.value) == \ + 'Both sechdule_delay_seconds and schedule_time cannot be set at the same time.' + + + @pytest.mark.parametrize('schedule_time', [ + time.time(), + str(datetime.utcnow()), + datetime.utcnow().isoformat(), + datetime.utcnow().isoformat() + 'Z', + '', ' ' + ]) + def test_invalid_schedule_time_error(self, schedule_time): + _, recorder = self._instrument_functions_service() + opts = functions.TaskOptions(schedule_time=schedule_time) + queue = functions.task_queue('test-function-name') + with pytest.raises(ValueError) as excinfo: + queue.enqueue(_DEFAULT_DATA, opts) + assert len(recorder) == 0 + assert str(excinfo.value) == 'schedule_time should be UTC datetime.' + + + @pytest.mark.parametrize('schedule_delay_seconds', [ + -1, '100', '-1', '', ' ', -1.23, 1.23 + ]) + def test_invalid_schedule_delay_seconds_error(self, schedule_delay_seconds): + _, recorder = self._instrument_functions_service() + opts = functions.TaskOptions(schedule_delay_seconds=schedule_delay_seconds) + queue = functions.task_queue('test-function-name') + with pytest.raises(ValueError) as excinfo: + queue.enqueue(_DEFAULT_DATA, opts) + assert len(recorder) == 0 + assert str(excinfo.value) == 'schedule_delay_seconds should be positive int.' + + + @pytest.mark.parametrize('dispatch_deadline_seconds', [ + 14, 1801, -15, -1800, 0, '100', '-1', '', ' ', -1.23, 1.23, + ]) + def test_invalid_dispatch_deadline_seconds_error(self, dispatch_deadline_seconds): + _, recorder = self._instrument_functions_service() + opts = functions.TaskOptions(dispatch_deadline_seconds=dispatch_deadline_seconds) + queue = functions.task_queue('test-function-name') + with pytest.raises(ValueError) as excinfo: + queue.enqueue(_DEFAULT_DATA, opts) + assert len(recorder) == 0 + assert str(excinfo.value) == \ + 'dispatch_deadline_seconds should be int in the range of 15s to 1800s (30 mins).' + + + @pytest.mark.parametrize('task_id', [ + '', ' ', 'task/1', 'task.1', 'a'*501, *non_alphanumeric_chars + ]) + def test_invalid_task_id_error(self, task_id): + _, recorder = self._instrument_functions_service() + opts = functions.TaskOptions(task_id=task_id) + queue = functions.task_queue('test-function-name') + with pytest.raises(ValueError) as excinfo: + queue.enqueue(_DEFAULT_DATA, opts) + assert len(recorder) == 0 + assert str(excinfo.value) == ( + 'task_id can contain only letters ([A-Za-z]), numbers ([0-9]), ' + 'hyphens (-), or underscores (_). The maximum length is 500 characters.' + ) + + @pytest.mark.parametrize('uri', [ + '', ' ', 'a', 'foo', 'image.jpg', [], {}, True, 'google.com', 'www.google.com' + ]) + def test_invalid_uri_error(self, uri): + _, recorder = self._instrument_functions_service() + opts = functions.TaskOptions(uri=uri) + queue = functions.task_queue('test-function-name') + with pytest.raises(ValueError) as excinfo: + queue.enqueue(_DEFAULT_DATA, opts) + assert len(recorder) == 0 + assert str(excinfo.value) == \ + 'uri must be a valid RFC3986 URI string using the https or http schema.' diff --git a/tests/testutils.py b/tests/testutils.py index e52b90d1a..ab4fb40cb 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -18,7 +18,7 @@ import pytest -from google.auth import credentials +from google.auth import credentials, compute_engine from google.auth import transport from requests import adapters from requests import models @@ -119,6 +119,10 @@ class MockGoogleCredential(credentials.Credentials): def refresh(self, request): self.token = 'mock-token' + @property + def service_account_email(self): + return 'mock-email' + class MockCredential(firebase_admin.credentials.Base): """A mock Firebase credential implementation.""" @@ -129,6 +133,19 @@ def __init__(self): def get_credential(self): return self._g_credential +class MockGoogleComputeEngineCredential(compute_engine.Credentials): + """A mock Compute Engine credential""" + def refresh(self, request): + self.token = 'mock-compute-engine-token' + +class MockComputeEngineCredential(firebase_admin.credentials.Base): + """A mock Firebase credential implementation.""" + + def __init__(self): + self._g_credential = MockGoogleComputeEngineCredential() + + def get_credential(self): + return self._g_credential class MockMultiRequestAdapter(adapters.HTTPAdapter): """A mock HTTP adapter that supports multiple responses for the Python requests module.""" From 7dbe27820fd0f38a0d5d0340b0583ad97aec3cb6 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Tue, 5 Mar 2024 15:51:06 -0500 Subject: [PATCH 172/226] fix: doc spacing (#775) --- firebase_admin/functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/firebase_admin/functions.py b/firebase_admin/functions.py index b39ee0a66..7df9bc607 100644 --- a/firebase_admin/functions.py +++ b/firebase_admin/functions.py @@ -342,6 +342,7 @@ def is_url(cls, url: Any): @dataclass class TaskOptions: """Task Options that can be applied to a Task. + Args: schedule_delay_seconds: The number of seconds after the current time at which to attempt or retry the task. Should only be set if ``schedule_time`` is not set. From a14ca32ad2cb06bbd30990d15d28eee01c82814a Mon Sep 17 00:00:00 2001 From: Pieter Ennes Date: Wed, 6 Mar 2024 20:42:47 +0100 Subject: [PATCH 173/226] Add rate limiting exceptions. (#695) --- firebase_admin/_auth_utils.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index 7aece495e..ac7b322ff 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -405,6 +405,20 @@ def __init__(self, message, cause=None, http_response=None): exceptions.InvalidArgumentError.__init__(self, message, cause, http_response) +class TooManyAttemptsTryLaterError(exceptions.ResourceExhaustedError): + """Rate limited because of too many attempts.""" + + def __init__(self, message, cause=None, http_response=None): + exceptions.ResourceExhaustedError.__init__(self, message, cause, http_response) + + +class ResetPasswordExceedLimitError(exceptions.ResourceExhaustedError): + """Reset password emails exceeded their limits.""" + + def __init__(self, message, cause=None, http_response=None): + exceptions.ResourceExhaustedError.__init__(self, message, cause, http_response) + + _CODE_TO_EXC_TYPE = { 'CONFIGURATION_NOT_FOUND': ConfigurationNotFoundError, 'DUPLICATE_EMAIL': EmailAlreadyExistsError, @@ -417,6 +431,8 @@ def __init__(self, message, cause=None, http_response=None): 'PHONE_NUMBER_EXISTS': PhoneNumberAlreadyExistsError, 'TENANT_NOT_FOUND': TenantNotFoundError, 'USER_NOT_FOUND': UserNotFoundError, + 'TOO_MANY_ATTEMPTS_TRY_LATER': TooManyAttemptsTryLaterError, + 'RESET_PASSWORD_EXCEED_LIMIT': ResetPasswordExceedLimitError, } From f25394eaf53c9d6065a6aa752fec70ee021cc968 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Wed, 6 Mar 2024 16:36:53 -0500 Subject: [PATCH 174/226] Export rate limiting error types `TooManyAttemptsTryLaterError` and `ResetPasswordExceedLimitError` (#777) --- firebase_admin/auth.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index 84873c3da..ced143112 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -56,10 +56,12 @@ 'OIDCProviderConfig', 'PhoneNumberAlreadyExistsError', 'ProviderConfig', + 'ResetPasswordExceedLimitError', 'RevokedIdTokenError', 'RevokedSessionCookieError', 'SAMLProviderConfig', 'TokenSignError', + 'TooManyAttemptsTryLaterError', 'UidAlreadyExistsError', 'UnexpectedResponseError', 'UserDisabledError', @@ -130,10 +132,12 @@ OIDCProviderConfig = _auth_providers.OIDCProviderConfig PhoneNumberAlreadyExistsError = _auth_utils.PhoneNumberAlreadyExistsError ProviderConfig = _auth_providers.ProviderConfig +ResetPasswordExceedLimitError = _auth_utils.ResetPasswordExceedLimitError RevokedIdTokenError = _token_gen.RevokedIdTokenError RevokedSessionCookieError = _token_gen.RevokedSessionCookieError SAMLProviderConfig = _auth_providers.SAMLProviderConfig TokenSignError = _token_gen.TokenSignError +TooManyAttemptsTryLaterError = _auth_utils.TooManyAttemptsTryLaterError UidAlreadyExistsError = _auth_utils.UidAlreadyExistsError UnexpectedResponseError = _auth_utils.UnexpectedResponseError UserDisabledError = _auth_utils.UserDisabledError From 244f32b747d729e2d0728ce295b1e56605b00879 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Wed, 6 Mar 2024 17:06:01 -0500 Subject: [PATCH 175/226] fix: doc quotes (#778) --- firebase_admin/functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/firebase_admin/functions.py b/firebase_admin/functions.py index 7df9bc607..fa17dfc0c 100644 --- a/firebase_admin/functions.py +++ b/firebase_admin/functions.py @@ -363,7 +363,7 @@ class TaskOptions: Task IDs should be strings that contain only letters ([A-Za-z]), numbers ([0-9]), hyphens (-), and underscores (_) with a maximum length of 500 characters. If a task's ID is identical to that of an existing task or a task that was deleted or executed - recently then the call will throw an error with code "functions/task-already-exists". + recently then the call will throw an error with code `functions/task-already-exists`. Another task with the same ID can't be created for ~1hour after the original task was deleted or executed. @@ -376,7 +376,7 @@ class TaskOptions: are sequential or have sequential prefixes, for example using a timestamp, causes an increase in latency and error rates in all task commands. - "Push IDs" from the Firebase Realtime Database make poor IDs because they are based on + Push IDs from the Firebase Realtime Database make poor IDs because they are based on timestamps and will cause contention (slowdowns) in your task queue. Reversed push IDs however form a perfect distribution and are an ideal key. To reverse a string in Python use ``reversedString = someString[::-1]`` From 32b900b0b7f9d8f453f821bf684bf7023c799808 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Mon, 11 Mar 2024 10:49:56 -0400 Subject: [PATCH 176/226] [chore] Release 6.5.0 (#779) --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index 7ce5b6f79..75f3f4b41 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '6.4.0' +__version__ = '6.5.0' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 0e87c44b07e6edec811acdafd6c0c99574e88ef1 Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Wed, 3 Jul 2024 18:30:58 +0000 Subject: [PATCH 177/226] chore: Update integration test resources (#796) --- .../resources/integ-service-account.json.gpg | Bin 1733 -> 1762 bytes .gitignore | 1 + 2 files changed, 1 insertion(+) diff --git a/.github/resources/integ-service-account.json.gpg b/.github/resources/integ-service-account.json.gpg index e8cc3e2a2a970b6760faceb0ff7dfc3e4b02c60a..7740dccd8bdada2eecc181f75c552c00e912e5c2 100644 GIT binary patch literal 1762 zcmV<81|9i~4Fm}T0)n*iur~xoxnP%@af)-$RhZ&7)3w(?>+$-b1cw>C)uo%^A>P*=<*0iel03 za}Fsg1`ruYzno_!>aY;)OsTkc%h&&pe{$nZw5?@UjcH^%VjHu)VgwEnKQUK4HJk0= zyZ5gIrQ>xo>l)_!-wlcQad@p8gshGgv1xcg_Yjj!nW#U|n*LMBm3Wlbk?i_CZ5Ume zal+1184AaJ?*V#JRCa=MTz1K6B-nEb*&W)}8_{YN|7UTh+U!ds%WR1v?9Hbxk@ z+EogP!glp>+*yB=n_Y5O!Q;p67Vmj%K0Z&IQHoSQT$5HK`B~z1yiX&eY#WwU=~;7S z*R}F3&c9V&rDM=B?(>!Z9}?HjEfYzeD!pR<0+x@F>TU9zk7}`hv#*;%+wZyrU0HEofh5VjbOiejUW3pY z12D!+Wf#4>(VKPmD}rZ71T zja(Co^VJ`de?Q=DM7vMvigtGHs?54vGS7;%Hh%Kgpcm|%^7?KEi%S69Cr3HH;P{#2 zQKW(H3{(*{l=~Bf7Rl-tT)anXQbTFgdpURPL%zy_*wvVIb#5uC(O*IiG}XiMGUA%y z{hDOjK*5$4CjH6|s5_9ya=cleH7`i?@E_&%aD!Azwfl->_ z*G%8=hHUxvc;-w-8SZhmd?W7x5Q>Gdn&TWCqCer@&YB71xG=dZ z)$Q#4CnyByk9+}lQNhr0X%Yl^Z@x^E{?CS9l0qFDy5|%MRx3RMp&FRUbXi#z!svt^ z$1+LBRDKak{r8lR!uG7M$3y`i=mM%I6C;mOE*azp=yC&EjEO@hZAV~%vb<>vw*Tx}<@1-F%E1MJ z&Kx%>OC=ss;A2c#Xs*e!O^%Y!Dj=VPAc>{~z^9u8{@H%2A-4|8w>siL6sFHXXGAGp z6u#;-ci&~<6YWll-R0BZ(9=FJ*d=Z^1zpP&loXySLaK2LILSvW!Dc?^#B&6$r*#Z0VQL=xlo$_b_5MO*>kW6!}M4b=rO4;kC? z7>1#k9x6oenqHJ=>|?M)^;{lP0reAV^DsBc!yt{hJ=)44ze!{NMcolK=)E*z3xPIFQ}^9xIoZ8E2fWR@#G|3UOCwegz@y^#$8?WWnJ zs2>$dD-U!>e@QG8RfQO0o(=0E;V|(Wz7kVeam0oZT7r|6xiFJ zNd8hqdzqJ0-`fL-a@okskU*`zWMi=%n1PPxq!&m!cL|J#l~QZ}A3JebK?T zy$sIEE!|IkVGk9;6HzHA)9Y-IXTCA)$QGOVqdT1OhM%`yYT53qvYtc=4KjLTfqCH9 zgC*H0Jk^%P@UH!391-ApSu$Z{4D#3;W{vq*kO zsL+qpghA`Jm|4{gJrTOb%(>>3vR5{Am*n3SBI9H-+h-WW$}kgTg9MyYUUi*K+#nml Ero6&wRR910 literal 1733 zcmV;$20HnS4Fm}T0&*P1GU?%eegD$x0rvxBAXZGDSEYyg&@r!tuJO!^@&2Mg0RD-w z6$6e;;`CFWj*fn+wYp+93`zD7be^H9^(f2dArxBMIsVqnH^uk`NBpM}j)R?=LOWsA z{+*+Is!_mH0&L^@gva$WU4tSy{!^+>46sf3l~2(yGHLCQJ|AD}OSS4|w)W)bIhUCB ztLZxR_T>Ox?nQ%n2SuXok1L6Q?jrQ##qw?ASK7)@vt_ihr{>UJqi+)ws0`vx6lVaz zYoMV^widvCLEMCOzE9r4f`M#Cz#4ukZW(RbYSG4zyaS7o#kYgzHU2SEPW}TDgOeO$ zl8#OV2K4iNZ(_%sDt3-8G?e{}9)Xa63cE<5Tw~PUBhHb%f%Wp+7@Dii4KH~>ThXnc zJU+i2l19mPJ*o^DiA1mUbL4eB> zbfqZwFA9iW9)@X5^Jg;FmKH1LyrWqA*+tn$T-+S{Q!Ba4sb3{w26;5ycQyh=djxuG zE6Ef&YBF(Bdd_s{DEg>abVPH z6sfGhfLXxrTgj3Towx+c#p!H^NDd+(me+3C^t?B9kvUgNYxtM|xwc~5A3rV?MQ&Eb zRJ)mm($0#cd@i+T7wA;@d~q4WBv2zt(~SwCC&t<(Qzk>wvIas=r{SY@FlV!7%Y4X|e=JDOGfN84Xw7mlf9b{Z_{P(a;fJd&Ex zPE;unRrc%Lt6Yl#y!I` zW#bH&1CfNm3=Yxg&2XrnR$`$>mqEN6koziF5@98L^u}ja`*?#L9DqbkUsSVDO%mCg zyP&~Gf!nBhwHZTdZJZ=?(OfE%qT3Og6Pj0p(`{TM5&%c%Sp%Ly^p!vvKI|2`DN^t$ z+dont3_<4G_ds{)3CY0N32M6o8Nw=bodIN2N}2j6Kx>s0I%&=ci z8<<`m=?L9(VxxCHhes%UrPSTNQG&?qWdm{1;3a=FlA-^@>_1LqL4PsLKUIJ|lW}}ll^Rmr+`u^00-Zu3=0MK{X7=wW1xG9v zb>+uzSLWHC`%$|q#}_OwBP?rUGXCC>US#9h1-dE_KdPv%`8PBHTXX2${J{nG-7ZR7 z)j;WYkbTKBD-I^zOeBUw_AQ~!0`Ql&jp&}z(t&1V;13U#S4w>IC{WpQqf+-0w%_krosfryd7nMa+kp*`jl7_Q}9BY)T)sl`80~A!l zY=Kfl2&-TmHbwDIP%GrpRKMWoYsK z6~5b^5HMXvSl{rRG8_Innc{l2^8=++T>ul8K#~j(O_2?|&4mMfRkS+zMy2vVS}Cb;C8Qbj zh_XDiovWpqGvdj(vC6e_XKoj<`4>%Sm#lPtm?g(B1X|F~Ecq?+-dLoI<(r(K|0;(F z-GjL5-xIv&u}KCiyG&3X*caBWCC>sfo`m7udmy%Fq=do+s!GWP52a-~I()x-5i@;$ zKiD+X_ujNViv(}lsFTk0pW{G7J$y)Z-62(VZ^ diff --git a/.gitignore b/.gitignore index e5c1902d5..d9d47dc51 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ htmlcov/ .pytest_cache/ .vscode/ .venv/ +.DS_Store From 0d498a65dc7d055361b8f8e30053f4a117f266f1 Mon Sep 17 00:00:00 2001 From: Terence Nip Date: Thu, 1 Aug 2024 14:18:22 -0400 Subject: [PATCH 178/226] Fix link to `requests.Response` in FirebaseError. (#800) --- firebase_admin/exceptions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/exceptions.py b/firebase_admin/exceptions.py index 06504225f..947f36806 100644 --- a/firebase_admin/exceptions.py +++ b/firebase_admin/exceptions.py @@ -91,7 +91,7 @@ class FirebaseError(Exception): cause: The exception that caused this error (optional). http_response: If this error was caused by an HTTP error response, this property is set to the ``requests.Response`` object that represents the HTTP response (optional). - See https://2.python-requests.org/en/master/api/#requests.Response for details of + See https://docs.python-requests.org/en/master/api/#requests.Response for details of this object. """ From b4700da4e9e362378113777653fa0e6ea216653e Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Wed, 28 Aug 2024 14:48:11 +0000 Subject: [PATCH 179/226] chore: Remove Python 3.7 from workflows (#806) * chore: Remove Python 3.7 from workflows * Skip batch send tests --- .github/workflows/ci.yml | 6 +++--- .github/workflows/nightly.yml | 2 +- .github/workflows/release.yml | 2 +- CONTRIBUTING.md | 2 +- integration/test_messaging.py | 2 ++ 5 files changed, 8 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 00a01a908..127aa2e7e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,7 +8,7 @@ jobs: strategy: fail-fast: false matrix: - python: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12', 'pypy3.8'] + python: ['3.8', '3.9', '3.10', '3.11', '3.12', 'pypy3.8'] steps: - uses: actions/checkout@v4 @@ -35,10 +35,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - name: Set up Python 3.7 + - name: Set up Python 3.8 uses: actions/setup-python@v5 with: - python-version: 3.7 + python-version: 3.8 - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 0fe418cf7..935ee56ce 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -36,7 +36,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: 3.7 + python-version: 3.8 - name: Install dependencies run: | diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 00e1267c8..8271e9e67 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -47,7 +47,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: 3.7 + python-version: 3.8 - name: Install dependencies run: | diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c06d7de2c..de5934866 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -85,7 +85,7 @@ information on using pull requests. ### Initial Setup -You need Python 3.7+ to build and test the code in this repo. +You need Python 3.8+ to build and test the code in this repo. We recommend using [pip](https://pypi.python.org/pypi/pip) for installing the necessary tools and project dependencies. Most recent versions of Python ship with pip. If your development environment diff --git a/integration/test_messaging.py b/integration/test_messaging.py index ab5d09b9e..522e87e85 100644 --- a/integration/test_messaging.py +++ b/integration/test_messaging.py @@ -148,6 +148,7 @@ def test_send_each_for_multicast(): assert response.exception is not None assert response.message_id is None +@pytest.mark.skip(reason="Replaced with test_send_each") def test_send_all(): messages = [ messaging.Message( @@ -179,6 +180,7 @@ def test_send_all(): assert isinstance(response.exception, exceptions.InvalidArgumentError) assert response.message_id is None +@pytest.mark.skip(reason="Replaced with test_send_each_500") def test_send_all_500(): messages = [] for msg_number in range(500): From 18e60438dae57afa4e197bca790b35dbf53f5877 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 10 Sep 2024 16:36:00 -0400 Subject: [PATCH 180/226] chore(deps): bump actions/download-artifact in /.github/workflows (#810) Bumps [actions/download-artifact](https://github.com/actions/download-artifact) from 1 to 4.1.7. - [Release notes](https://github.com/actions/download-artifact/releases) - [Commits](https://github.com/actions/download-artifact/compare/v1...v4.1.7) --- updated-dependencies: - dependency-name: actions/download-artifact dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 8271e9e67..9111ef547 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -105,7 +105,7 @@ jobs: # Download the artifacts created by the stage_release job. - name: Download release candidates - uses: actions/download-artifact@v1 + uses: actions/download-artifact@v4.1.7 with: name: dist From c0447297d5f1020d7f82e3d0f50e6220babd9b07 Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Thu, 12 Sep 2024 12:07:40 -0400 Subject: [PATCH 181/226] chore: Update `actions/upload-artifact` (#812) * chore: Update `actions/upload-artifact` * Trigger CI --- .github/workflows/nightly.yml | 2 +- .github/workflows/release.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 935ee56ce..282cb1b91 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -62,7 +62,7 @@ jobs: # Attach the packaged artifacts to the workflow output. These can be manually # downloaded for later inspection if necessary. - name: Archive artifacts - uses: actions/upload-artifact@v1 + uses: actions/upload-artifact@v4 with: name: dist path: dist diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 9111ef547..7aab71b23 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -73,7 +73,7 @@ jobs: # Attach the packaged artifacts to the workflow output. These can be manually # downloaded for later inspection if necessary. - name: Archive artifacts - uses: actions/upload-artifact@v1 + uses: actions/upload-artifact@v4 with: name: dist path: dist From 8727e91739f2d04f2417705e098331362b8ac5c8 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Thu, 24 Oct 2024 10:18:05 -0400 Subject: [PATCH 182/226] feat(firestore): Add Firestore Multi Database Support (#818) * Added multi db support for firestore and firestore_async * Added unit and integration tests * fix docs strings --- firebase_admin/firestore.py | 88 ++++++++++++++++----------- firebase_admin/firestore_async.py | 94 ++++++++++++++++------------- integration/test_firestore.py | 55 +++++++++++++++++ integration/test_firestore_async.py | 69 +++++++++++++++++++-- tests/test_firestore.py | 86 ++++++++++++++++++++++++++ tests/test_firestore_async.py | 86 ++++++++++++++++++++++++++ 6 files changed, 396 insertions(+), 82 deletions(-) diff --git a/firebase_admin/firestore.py b/firebase_admin/firestore.py index 224ba3aeb..52ea90671 100644 --- a/firebase_admin/firestore.py +++ b/firebase_admin/firestore.py @@ -18,59 +18,75 @@ Firebase apps. This requires the ``google-cloud-firestore`` Python module. """ +from __future__ import annotations +from typing import Optional, Dict +from firebase_admin import App +from firebase_admin import _utils + try: - from google.cloud import firestore # pylint: disable=import-error,no-name-in-module + from google.cloud import firestore + from google.cloud.firestore_v1.base_client import DEFAULT_DATABASE existing = globals().keys() for key, value in firestore.__dict__.items(): if not key.startswith('_') and key not in existing: globals()[key] = value -except ImportError: +except ImportError as error: raise ImportError('Failed to import the Cloud Firestore library for Python. Make sure ' - 'to install the "google-cloud-firestore" module.') - -from firebase_admin import _utils + 'to install the "google-cloud-firestore" module.') from error _FIRESTORE_ATTRIBUTE = '_firestore' -def client(app=None) -> firestore.Client: +def client(app: Optional[App] = None, database_id: Optional[str] = None) -> firestore.Client: """Returns a client that can be used to interact with Google Cloud Firestore. Args: - app: An App instance (optional). + app: An App instance (optional). + database_id: The database ID of the Google Cloud Firestore database to be used. + Defaults to the default Firestore database ID if not specified or an empty string + (optional). Returns: - google.cloud.firestore.Firestore: A `Firestore Client`_. + google.cloud.firestore.Firestore: A `Firestore Client`_. Raises: - ValueError: If a project ID is not specified either via options, credentials or - environment variables, or if the specified project ID is not a valid string. + ValueError: If the specified database ID is not a valid string, or if a project ID is not + specified either via options, credentials or environment variables, or if the specified + project ID is not a valid string. - .. _Firestore Client: https://googlecloudplatform.github.io/google-cloud-python/latest\ - /firestore/client.html + .. _Firestore Client: https://cloud.google.com/python/docs/reference/firestore/latest/\ + google.cloud.firestore_v1.client.Client """ - fs_client = _utils.get_app_service(app, _FIRESTORE_ATTRIBUTE, _FirestoreClient.from_app) - return fs_client.get() - - -class _FirestoreClient: - """Holds a Google Cloud Firestore client instance.""" - - def __init__(self, credentials, project): - self._client = firestore.Client(credentials=credentials, project=project) - - def get(self): - return self._client - - @classmethod - def from_app(cls, app): - """Creates a new _FirestoreClient for the specified app.""" - credentials = app.credential.get_credential() - project = app.project_id - if not project: - raise ValueError( - 'Project ID is required to access Firestore. Either set the projectId option, ' - 'or use service account credentials. Alternatively, set the GOOGLE_CLOUD_PROJECT ' - 'environment variable.') - return _FirestoreClient(credentials, project) + # Validate database_id + if database_id is not None and not isinstance(database_id, str): + raise ValueError(f'database_id "{database_id}" must be a string or None.') + fs_service = _utils.get_app_service(app, _FIRESTORE_ATTRIBUTE, _FirestoreService) + return fs_service.get_client(database_id) + + +class _FirestoreService: + """Service that maintains a collection of firestore clients.""" + + def __init__(self, app: App) -> None: + self._app: App = app + self._clients: Dict[str, firestore.Client] = {} + + def get_client(self, database_id: Optional[str]) -> firestore.Client: + """Creates a client based on the database_id. These clients are cached.""" + database_id = database_id or DEFAULT_DATABASE + if database_id not in self._clients: + # Create a new client and cache it in _clients + credentials = self._app.credential.get_credential() + project = self._app.project_id + if not project: + raise ValueError( + 'Project ID is required to access Firestore. Either set the projectId option, ' + 'or use service account credentials. Alternatively, set the ' + 'GOOGLE_CLOUD_PROJECT environment variable.') + + fs_client = firestore.Client( + credentials=credentials, project=project, database=database_id) + self._clients[database_id] = fs_client + + return self._clients[database_id] diff --git a/firebase_admin/firestore_async.py b/firebase_admin/firestore_async.py index a63d5a761..4a197e9df 100644 --- a/firebase_admin/firestore_async.py +++ b/firebase_admin/firestore_async.py @@ -18,65 +18,75 @@ associated with Firebase apps. This requires the ``google-cloud-firestore`` Python module. """ -from typing import Type - -from firebase_admin import ( - App, - _utils, -) -from firebase_admin.credentials import Base +from __future__ import annotations +from typing import Optional, Dict +from firebase_admin import App +from firebase_admin import _utils try: - from google.cloud import firestore # type: ignore # pylint: disable=import-error,no-name-in-module + from google.cloud import firestore + from google.cloud.firestore_v1.base_client import DEFAULT_DATABASE existing = globals().keys() for key, value in firestore.__dict__.items(): if not key.startswith('_') and key not in existing: globals()[key] = value -except ImportError: +except ImportError as error: raise ImportError('Failed to import the Cloud Firestore library for Python. Make sure ' - 'to install the "google-cloud-firestore" module.') + 'to install the "google-cloud-firestore" module.') from error + _FIRESTORE_ASYNC_ATTRIBUTE: str = '_firestore_async' -def client(app: App = None) -> firestore.AsyncClient: +def client(app: Optional[App] = None, database_id: Optional[str] = None) -> firestore.AsyncClient: """Returns an async client that can be used to interact with Google Cloud Firestore. Args: - app: An App instance (optional). + app: An App instance (optional). + database_id: The database ID of the Google Cloud Firestore database to be used. + Defaults to the default Firestore database ID if not specified or an empty string + (optional). Returns: - google.cloud.firestore.Firestore_Async: A `Firestore Async Client`_. + google.cloud.firestore.Firestore_Async: A `Firestore Async Client`_. Raises: - ValueError: If a project ID is not specified either via options, credentials or - environment variables, or if the specified project ID is not a valid string. + ValueError: If the specified database ID is not a valid string, or if a project ID is not + specified either via options, credentials or environment variables, or if the specified + project ID is not a valid string. - .. _Firestore Async Client: https://googleapis.dev/python/firestore/latest/client.html + .. _Firestore Async Client: https://cloud.google.com/python/docs/reference/firestore/latest/\ + google.cloud.firestore_v1.async_client.AsyncClient """ - fs_client = _utils.get_app_service( - app, _FIRESTORE_ASYNC_ATTRIBUTE, _FirestoreAsyncClient.from_app) - return fs_client.get() - - -class _FirestoreAsyncClient: - """Holds a Google Cloud Firestore Async Client instance.""" - - def __init__(self, credentials: Type[Base], project: str) -> None: - self._client = firestore.AsyncClient(credentials=credentials, project=project) - - def get(self) -> firestore.AsyncClient: - return self._client - - @classmethod - def from_app(cls, app: App) -> "_FirestoreAsyncClient": - # Replace remove future reference quotes by importing annotations in Python 3.7+ b/238779406 - """Creates a new _FirestoreAsyncClient for the specified app.""" - credentials = app.credential.get_credential() - project = app.project_id - if not project: - raise ValueError( - 'Project ID is required to access Firestore. Either set the projectId option, ' - 'or use service account credentials. Alternatively, set the GOOGLE_CLOUD_PROJECT ' - 'environment variable.') - return _FirestoreAsyncClient(credentials, project) + # Validate database_id + if database_id is not None and not isinstance(database_id, str): + raise ValueError(f'database_id "{database_id}" must be a string or None.') + + fs_service = _utils.get_app_service(app, _FIRESTORE_ASYNC_ATTRIBUTE, _FirestoreAsyncService) + return fs_service.get_client(database_id) + +class _FirestoreAsyncService: + """Service that maintains a collection of firestore async clients.""" + + def __init__(self, app: App) -> None: + self._app: App = app + self._clients: Dict[str, firestore.AsyncClient] = {} + + def get_client(self, database_id: Optional[str]) -> firestore.AsyncClient: + """Creates an async client based on the database_id. These clients are cached.""" + database_id = database_id or DEFAULT_DATABASE + if database_id not in self._clients: + # Create a new client and cache it in _clients + credentials = self._app.credential.get_credential() + project = self._app.project_id + if not project: + raise ValueError( + 'Project ID is required to access Firestore. Either set the projectId option, ' + 'or use service account credentials. Alternatively, set the ' + 'GOOGLE_CLOUD_PROJECT environment variable.') + + fs_client = firestore.AsyncClient( + credentials=credentials, project=project, database=database_id) + self._clients[database_id] = fs_client + + return self._clients[database_id] diff --git a/integration/test_firestore.py b/integration/test_firestore.py index 2bc3d1931..fd39d9b8a 100644 --- a/integration/test_firestore.py +++ b/integration/test_firestore.py @@ -17,6 +17,20 @@ from firebase_admin import firestore +_CITY = { + 'name': u'Mountain View', + 'country': u'USA', + 'population': 77846, + 'capital': False + } + +_MOVIE = { + 'Name': u'Interstellar', + 'Year': 2014, + 'Runtime': u'2h 49m', + 'Academy Award Winner': True + } + def test_firestore(): client = firestore.client() @@ -35,6 +49,47 @@ def test_firestore(): doc.delete() assert doc.get().exists is False +def test_firestore_explicit_database_id(): + client = firestore.client(database_id='testing-database') + expected = _CITY + doc = client.collection('cities').document() + doc.set(expected) + + data = doc.get() + assert data.to_dict() == expected + + doc.delete() + data = doc.get() + assert data.exists is False + +def test_firestore_multi_db(): + city_client = firestore.client() + movie_client = firestore.client(database_id='testing-database') + + expected_city = _CITY + expected_movie = _MOVIE + + city_doc = city_client.collection('cities').document() + movie_doc = movie_client.collection('movies').document() + + city_doc.set(expected_city) + movie_doc.set(expected_movie) + + city_data = city_doc.get() + movie_data = movie_doc.get() + + assert city_data.to_dict() == expected_city + assert movie_data.to_dict() == expected_movie + + city_doc.delete() + movie_doc.delete() + + city_data = city_doc.get() + movie_data = movie_doc.get() + + assert city_data.exists is False + assert movie_data.exists is False + def test_server_timestamp(): client = firestore.client() expected = { diff --git a/integration/test_firestore_async.py b/integration/test_firestore_async.py index 2a5b93217..8b73dda0f 100644 --- a/integration/test_firestore_async.py +++ b/integration/test_firestore_async.py @@ -13,20 +13,31 @@ # limitations under the License. """Integration tests for firebase_admin.firestore_async module.""" +import asyncio import datetime import pytest from firebase_admin import firestore_async -@pytest.mark.asyncio -async def test_firestore_async(): - client = firestore_async.client() - expected = { +_CITY = { 'name': u'Mountain View', 'country': u'USA', 'population': 77846, 'capital': False } + +_MOVIE = { + 'Name': u'Interstellar', + 'Year': 2014, + 'Runtime': u'2h 49m', + 'Academy Award Winner': True + } + + +@pytest.mark.asyncio +async def test_firestore_async(): + client = firestore_async.client() + expected = _CITY doc = client.collection('cities').document() await doc.set(expected) @@ -37,6 +48,56 @@ async def test_firestore_async(): data = await doc.get() assert data.exists is False +@pytest.mark.asyncio +async def test_firestore_async_explicit_database_id(): + client = firestore_async.client(database_id='testing-database') + expected = _CITY + doc = client.collection('cities').document() + await doc.set(expected) + + data = await doc.get() + assert data.to_dict() == expected + + await doc.delete() + data = await doc.get() + assert data.exists is False + +@pytest.mark.asyncio +async def test_firestore_async_multi_db(): + city_client = firestore_async.client() + movie_client = firestore_async.client(database_id='testing-database') + + expected_city = _CITY + expected_movie = _MOVIE + + city_doc = city_client.collection('cities').document() + movie_doc = movie_client.collection('movies').document() + + await asyncio.gather( + city_doc.set(expected_city), + movie_doc.set(expected_movie) + ) + + data = await asyncio.gather( + city_doc.get(), + movie_doc.get() + ) + + assert data[0].to_dict() == expected_city + assert data[1].to_dict() == expected_movie + + await asyncio.gather( + city_doc.delete(), + movie_doc.delete() + ) + + data = await asyncio.gather( + city_doc.get(), + movie_doc.get() + ) + assert data[0].exists is False + assert data[1].exists is False + @pytest.mark.asyncio async def test_server_timestamp(): client = firestore_async.client() diff --git a/tests/test_firestore.py b/tests/test_firestore.py index 768eb637e..47debd54b 100644 --- a/tests/test_firestore.py +++ b/tests/test_firestore.py @@ -50,6 +50,7 @@ def test_project_id(self): client = firestore.client() assert client is not None assert client.project == 'explicit-project-id' + assert client._database == '(default)' def test_project_id_with_explicit_app(self): cred = credentials.Certificate(testutils.resource_filename('service_account.json')) @@ -57,6 +58,7 @@ def test_project_id_with_explicit_app(self): client = firestore.client(app=app) assert client is not None assert client.project == 'explicit-project-id' + assert client._database == '(default)' def test_service_account(self): cred = credentials.Certificate(testutils.resource_filename('service_account.json')) @@ -64,6 +66,7 @@ def test_service_account(self): client = firestore.client() assert client is not None assert client.project == 'mock-project-id' + assert client._database == '(default)' def test_service_account_with_explicit_app(self): cred = credentials.Certificate(testutils.resource_filename('service_account.json')) @@ -71,6 +74,89 @@ def test_service_account_with_explicit_app(self): client = firestore.client(app=app) assert client is not None assert client.project == 'mock-project-id' + assert client._database == '(default)' + + @pytest.mark.parametrize('database_id', [123, False, True, {}, []]) + def test_invalid_database_id(self, database_id): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + firebase_admin.initialize_app(cred) + with pytest.raises(ValueError) as excinfo: + firestore.client(database_id=database_id) + assert str(excinfo.value) == f'database_id "{database_id}" must be a string or None.' + + def test_database_id(self): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + firebase_admin.initialize_app(cred) + database_id = 'mock-database-id' + client = firestore.client(database_id=database_id) + assert client is not None + assert client.project == 'mock-project-id' + assert client._database == 'mock-database-id' + + @pytest.mark.parametrize('database_id', ['', '(default)', None]) + def test_database_id_with_default_id(self, database_id): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + firebase_admin.initialize_app(cred) + client = firestore.client(database_id=database_id) + assert client is not None + assert client.project == 'mock-project-id' + assert client._database == '(default)' + + def test_database_id_with_explicit_app(self): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + app = firebase_admin.initialize_app(cred) + database_id = 'mock-database-id' + client = firestore.client(app, database_id) + assert client is not None + assert client.project == 'mock-project-id' + assert client._database == 'mock-database-id' + + def test_database_id_with_multi_db(self): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + firebase_admin.initialize_app(cred) + database_id_1 = 'mock-database-id-1' + database_id_2 = 'mock-database-id-2' + client_1 = firestore.client(database_id=database_id_1) + client_2 = firestore.client(database_id=database_id_2) + assert (client_1 is not None) and (client_2 is not None) + assert client_1 is not client_2 + assert client_1.project == 'mock-project-id' + assert client_2.project == 'mock-project-id' + assert client_1._database == 'mock-database-id-1' + assert client_2._database == 'mock-database-id-2' + + def test_database_id_with_multi_db_uses_cache(self): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + firebase_admin.initialize_app(cred) + database_id = 'mock-database-id' + client_1 = firestore.client(database_id=database_id) + client_2 = firestore.client(database_id=database_id) + assert (client_1 is not None) and (client_2 is not None) + assert client_1 is client_2 + assert client_1.project == 'mock-project-id' + assert client_2.project == 'mock-project-id' + assert client_1._database == 'mock-database-id' + assert client_2._database == 'mock-database-id' + + def test_database_id_with_multi_db_uses_cache_default(self): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + firebase_admin.initialize_app(cred) + database_id_1 = '' + database_id_2 = '(default)' + client_1 = firestore.client(database_id=database_id_1) + client_2 = firestore.client(database_id=database_id_2) + client_3 = firestore.client() + assert (client_1 is not None) and (client_2 is not None) and (client_3 is not None) + assert client_1 is client_2 + assert client_1 is client_3 + assert client_2 is client_3 + assert client_1.project == 'mock-project-id' + assert client_2.project == 'mock-project-id' + assert client_3.project == 'mock-project-id' + assert client_1._database == '(default)' + assert client_2._database == '(default)' + assert client_3._database == '(default)' + def test_geo_point(self): geo_point = firestore.GeoPoint(10, 20) # pylint: disable=no-member diff --git a/tests/test_firestore_async.py b/tests/test_firestore_async.py index 0fb17c813..3d17cbfc5 100644 --- a/tests/test_firestore_async.py +++ b/tests/test_firestore_async.py @@ -50,6 +50,7 @@ def test_project_id(self): client = firestore_async.client() assert client is not None assert client.project == 'explicit-project-id' + assert client._database == '(default)' def test_project_id_with_explicit_app(self): cred = credentials.Certificate(testutils.resource_filename('service_account.json')) @@ -57,6 +58,7 @@ def test_project_id_with_explicit_app(self): client = firestore_async.client(app=app) assert client is not None assert client.project == 'explicit-project-id' + assert client._database == '(default)' def test_service_account(self): cred = credentials.Certificate(testutils.resource_filename('service_account.json')) @@ -64,6 +66,7 @@ def test_service_account(self): client = firestore_async.client() assert client is not None assert client.project == 'mock-project-id' + assert client._database == '(default)' def test_service_account_with_explicit_app(self): cred = credentials.Certificate(testutils.resource_filename('service_account.json')) @@ -71,6 +74,89 @@ def test_service_account_with_explicit_app(self): client = firestore_async.client(app=app) assert client is not None assert client.project == 'mock-project-id' + assert client._database == '(default)' + + @pytest.mark.parametrize('database_id', [123, False, True, {}, []]) + def test_invalid_database_id(self, database_id): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + firebase_admin.initialize_app(cred) + with pytest.raises(ValueError) as excinfo: + firestore_async.client(database_id=database_id) + assert str(excinfo.value) == f'database_id "{database_id}" must be a string or None.' + + def test_database_id(self): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + firebase_admin.initialize_app(cred) + database_id = 'mock-database-id' + client = firestore_async.client(database_id=database_id) + assert client is not None + assert client.project == 'mock-project-id' + assert client._database == 'mock-database-id' + + @pytest.mark.parametrize('database_id', ['', '(default)', None]) + def test_database_id_with_default_id(self, database_id): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + firebase_admin.initialize_app(cred) + client = firestore_async.client(database_id=database_id) + assert client is not None + assert client.project == 'mock-project-id' + assert client._database == '(default)' + + def test_database_id_with_explicit_app(self): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + app = firebase_admin.initialize_app(cred) + database_id = 'mock-database-id' + client = firestore_async.client(app, database_id) + assert client is not None + assert client.project == 'mock-project-id' + assert client._database == 'mock-database-id' + + def test_database_id_with_multi_db(self): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + firebase_admin.initialize_app(cred) + database_id_1 = 'mock-database-id-1' + database_id_2 = 'mock-database-id-2' + client_1 = firestore_async.client(database_id=database_id_1) + client_2 = firestore_async.client(database_id=database_id_2) + assert (client_1 is not None) and (client_2 is not None) + assert client_1 is not client_2 + assert client_1.project == 'mock-project-id' + assert client_2.project == 'mock-project-id' + assert client_1._database == 'mock-database-id-1' + assert client_2._database == 'mock-database-id-2' + + def test_database_id_with_multi_db_uses_cache(self): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + firebase_admin.initialize_app(cred) + database_id = 'mock-database-id' + client_1 = firestore_async.client(database_id=database_id) + client_2 = firestore_async.client(database_id=database_id) + assert (client_1 is not None) and (client_2 is not None) + assert client_1 is client_2 + assert client_1.project == 'mock-project-id' + assert client_2.project == 'mock-project-id' + assert client_1._database == 'mock-database-id' + assert client_2._database == 'mock-database-id' + + def test_database_id_with_multi_db_uses_cache_default(self): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + firebase_admin.initialize_app(cred) + database_id_1 = '' + database_id_2 = '(default)' + client_1 = firestore_async.client(database_id=database_id_1) + client_2 = firestore_async.client(database_id=database_id_2) + client_3 = firestore_async.client() + assert (client_1 is not None) and (client_2 is not None) and (client_3 is not None) + assert client_1 is client_2 + assert client_1 is client_3 + assert client_2 is client_3 + assert client_1.project == 'mock-project-id' + assert client_2.project == 'mock-project-id' + assert client_3.project == 'mock-project-id' + assert client_1._database == '(default)' + assert client_2._database == '(default)' + assert client_3._database == '(default)' + def test_geo_point(self): geo_point = firestore_async.GeoPoint(10, 20) # pylint: disable=no-member From 2a8170f85871c8c84579ba9a3e96f97dc512d34b Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Thu, 24 Oct 2024 13:34:44 -0400 Subject: [PATCH 183/226] [chore] Bump cachecontrol (#819) --- requirements.txt | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index acf09438b..fa48f7f57 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ pytest-localserver >= 0.4.1 pytest-asyncio >= 0.16.0 pytest-mock >= 3.6.1 -cachecontrol >= 0.12.6 +cachecontrol >= 0.12.14 google-api-core[grpc] >= 1.22.1, < 3.0.0dev; platform.python_implementation != 'PyPy' google-api-python-client >= 1.7.8 google-cloud-firestore >= 2.9.1; platform.python_implementation != 'PyPy' diff --git a/setup.py b/setup.py index ef30e6be6..e479e39e6 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ long_description = ('The Firebase Admin Python SDK enables server-side (backend) Python developers ' 'to integrate Firebase into their services and applications.') install_requires = [ - 'cachecontrol>=0.12.6', + 'cachecontrol>=0.12.14', 'google-api-core[grpc] >= 1.22.1, < 3.0.0dev; platform.python_implementation != "PyPy"', 'google-api-python-client >= 1.7.8', 'google-cloud-firestore>=2.9.1; platform.python_implementation != "PyPy"', From d8d6aea496a20fc5c7cc86f5335c2351f39cefd1 Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Thu, 24 Oct 2024 19:43:19 -0400 Subject: [PATCH 184/226] chore: Create dependabot.yml (#820) Add dependabot support to the repo --- .github/dependabot.yml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 .github/dependabot.yml diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 000000000..6a7695c06 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,6 @@ +version: 2 +updates: + - package-ecosystem: "pip" + directory: "/" + schedule: + interval: "weekly" From 32e8dd284429f5dfbdaa21f6ac54f9d0698151cc Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Mon, 4 Nov 2024 15:22:53 -0500 Subject: [PATCH 185/226] feat: Support passing `google.auth` typed credentials in `initialize_app()` (#821) * feat: Support passing `google.auth` typed credentials in `initialize_app()` * Refactor and add unit test --- firebase_admin/__init__.py | 8 ++++++-- firebase_admin/credentials.py | 14 ++++++++++++++ tests/test_app.py | 10 ++++++++++ 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/firebase_admin/__init__.py b/firebase_admin/__init__.py index 0ca82ec5e..7bb9c59c2 100644 --- a/firebase_admin/__init__.py +++ b/firebase_admin/__init__.py @@ -18,6 +18,7 @@ import os import threading +from google.auth.credentials import Credentials as GoogleAuthCredentials from google.auth.exceptions import DefaultCredentialsError from firebase_admin import credentials from firebase_admin.__about__ import __version__ @@ -208,10 +209,13 @@ def __init__(self, name, credential, options): 'non-empty string.'.format(name)) self._name = name - if not isinstance(credential, credentials.Base): + if isinstance(credential, GoogleAuthCredentials): + self._credential = credentials._ExternalCredentials(credential) # pylint: disable=protected-access + elif isinstance(credential, credentials.Base): + self._credential = credential + else: raise ValueError('Illegal Firebase credential provided. App must be initialized ' 'with a valid credential instance.') - self._credential = credential self._options = _AppOptions(options) self._lock = threading.RLock() self._services = {} diff --git a/firebase_admin/credentials.py b/firebase_admin/credentials.py index 5477e1cf7..750600280 100644 --- a/firebase_admin/credentials.py +++ b/firebase_admin/credentials.py @@ -18,6 +18,7 @@ import pathlib import google.auth +from google.auth.credentials import Credentials as GoogleAuthCredentials from google.auth.transport import requests from google.oauth2 import credentials from google.oauth2 import service_account @@ -58,6 +59,19 @@ def get_credential(self): """Returns the Google credential instance used for authentication.""" raise NotImplementedError +class _ExternalCredentials(Base): + """A wrapper for google.auth.credentials.Credentials typed credential instances""" + + def __init__(self, credential: GoogleAuthCredentials): + super(_ExternalCredentials, self).__init__() + self._g_credential = credential + + def get_credential(self): + """Returns the underlying Google Credential + + Returns: + google.auth.credentials.Credentials: A Google Auth credential instance.""" + return self._g_credential class Certificate(Base): """A credential initialized from a JSON certificate keyfile.""" diff --git a/tests/test_app.py b/tests/test_app.py index 4233d5849..5b203661f 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -246,6 +246,16 @@ def test_non_default_app_init(self, app_credential): with pytest.raises(ValueError): firebase_admin.initialize_app(app_credential, name='myApp') + def test_app_init_with_google_auth_cred(self): + cred = testutils.MockGoogleCredential() + assert isinstance(cred, credentials.GoogleAuthCredentials) + app = firebase_admin.initialize_app(cred) + assert cred is app.credential.get_credential() + assert isinstance(app.credential, credentials.Base) + assert isinstance(app.credential, credentials._ExternalCredentials) + with pytest.raises(ValueError): + firebase_admin.initialize_app(app_credential) + @pytest.mark.parametrize('cred', invalid_credentials) def test_app_init_with_invalid_credential(self, cred): with pytest.raises(ValueError): From be56a0ffe6c5fa87be64f49d77450c84c2db4d46 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Mon, 4 Nov 2024 15:47:08 -0500 Subject: [PATCH 186/226] chore: Add `X-Goog-Api-Client` metric header to requests (#826) --- firebase_admin/_http_client.py | 5 + firebase_admin/_utils.py | 3 + firebase_admin/app_check.py | 7 +- firebase_admin/storage.py | 7 +- tests/test_auth_providers.py | 190 +++++++++++++------------------ tests/test_db.py | 118 ++++++++----------- tests/test_functions.py | 4 + tests/test_http_client.py | 14 ++- tests/test_instance_id.py | 21 ++-- tests/test_messaging.py | 71 ++++++------ tests/test_ml.py | 74 +++++------- tests/test_project_management.py | 5 +- tests/test_tenant_mgt.py | 13 +++ tests/test_user_mgt.py | 2 + 14 files changed, 255 insertions(+), 279 deletions(-) diff --git a/firebase_admin/_http_client.py b/firebase_admin/_http_client.py index d259faddf..f1eccbcf2 100644 --- a/firebase_admin/_http_client.py +++ b/firebase_admin/_http_client.py @@ -21,6 +21,7 @@ import requests from requests.packages.urllib3.util import retry # pylint: disable=import-error +from firebase_admin import _utils if hasattr(retry.Retry.DEFAULT, 'allowed_methods'): _ANY_METHOD = {'allowed_methods': None} @@ -36,6 +37,9 @@ DEFAULT_TIMEOUT_SECONDS = 120 +METRICS_HEADERS = { + 'X-GOOG-API-CLIENT': _utils.get_metrics_header(), +} class HttpClient: """Base HTTP client used to make HTTP calls. @@ -72,6 +76,7 @@ def __init__( if headers: self._session.headers.update(headers) + self._session.headers.update(METRICS_HEADERS) if retries: self._session.mount('http://', requests.adapters.HTTPAdapter(max_retries=retries)) self._session.mount('https://', requests.adapters.HTTPAdapter(max_retries=retries)) diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index dcfb520d2..b6e292546 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -15,6 +15,7 @@ """Internal utilities common to all modules.""" import json +from platform import python_version import google.auth import requests @@ -75,6 +76,8 @@ 16: exceptions.UNAUTHENTICATED, } +def get_metrics_header(): + return f'gl-python/{python_version()} fire-admin/{firebase_admin.__version__}' def _get_initialized_app(app): """Returns a reference to an initialized App instance.""" diff --git a/firebase_admin/app_check.py b/firebase_admin/app_check.py index 6bc10b2f4..e6b66efc1 100644 --- a/firebase_admin/app_check.py +++ b/firebase_admin/app_check.py @@ -51,6 +51,10 @@ class _AppCheckService: _scoped_project_id = None _jwks_client = None + _APP_CHECK_HEADERS = { + 'X-GOOG-API-CLIENT': _utils.get_metrics_header(), + } + def __init__(self, app): # Validate and store the project_id to validate the JWT claims self._project_id = app.project_id @@ -62,7 +66,8 @@ def __init__(self, app): 'GOOGLE_CLOUD_PROJECT environment variable.') self._scoped_project_id = 'projects/' + app.project_id # Default lifespan is 300 seconds (5 minutes) so we change it to 21600 seconds (6 hours). - self._jwks_client = PyJWKClient(self._JWKS_URL, lifespan=21600) + self._jwks_client = PyJWKClient( + self._JWKS_URL, lifespan=21600, headers=self._APP_CHECK_HEADERS) def verify_token(self, token: str) -> Dict[str, Any]: diff --git a/firebase_admin/storage.py b/firebase_admin/storage.py index f3948371c..46f5f6043 100644 --- a/firebase_admin/storage.py +++ b/firebase_admin/storage.py @@ -55,8 +55,13 @@ def bucket(name=None, app=None) -> storage.Bucket: class _StorageClient: """Holds a Google Cloud Storage client instance.""" + STORAGE_HEADERS = { + 'X-GOOG-API-CLIENT': _utils.get_metrics_header(), + } + def __init__(self, credentials, project, default_bucket): - self._client = storage.Client(credentials=credentials, project=project) + self._client = storage.Client( + credentials=credentials, project=project, extra_headers=self.STORAGE_HEADERS) self._default_bucket = default_bucket @classmethod diff --git a/tests/test_auth_providers.py b/tests/test_auth_providers.py index a5716266c..48f38a011 100644 --- a/tests/test_auth_providers.py +++ b/tests/test_auth_providers.py @@ -21,6 +21,7 @@ import firebase_admin from firebase_admin import auth from firebase_admin import exceptions +from firebase_admin import _utils from tests import testutils ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v2' @@ -70,6 +71,11 @@ def _instrument_provider_mgt(app, status, payload): testutils.MockAdapter(payload, status, recorder)) return recorder +def _assert_request(request, expected_method, expected_url): + assert request.method == expected_method + assert request.url == expected_url + assert request.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' + assert request.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() class TestOIDCProviderConfig: @@ -110,9 +116,8 @@ def test_get(self, user_mgt_app): self._assert_provider_config(provider_config) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'GET' - assert req.url == '{0}{1}'.format(USER_MGT_URLS['PREFIX'], '/oauthIdpConfigs/oidc.provider') + _assert_request( + recorder[0], 'GET', f'{USER_MGT_URLS["PREFIX"]}/oauthIdpConfigs/oidc.provider') @pytest.mark.parametrize('invalid_opts', [ {'provider_id': None}, {'provider_id': ''}, {'provider_id': 'saml.provider'}, @@ -140,11 +145,9 @@ def test_create(self, user_mgt_app): self._assert_provider_config(provider_config) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'POST' - assert req.url == '{0}/oauthIdpConfigs?oauthIdpConfigId=oidc.provider'.format( - USER_MGT_URLS['PREFIX']) - got = json.loads(req.body.decode()) + _assert_request(recorder[0], 'POST', + f'{USER_MGT_URLS["PREFIX"]}/oauthIdpConfigs?oauthIdpConfigId=oidc.provider') + got = json.loads(recorder[0].body.decode()) assert got == self.OIDC_CONFIG_REQUEST def test_create_minimal(self, user_mgt_app): @@ -165,11 +168,9 @@ def test_create_minimal(self, user_mgt_app): self._assert_provider_config(provider_config) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'POST' - assert req.url == '{0}/oauthIdpConfigs?oauthIdpConfigId=oidc.provider'.format( - USER_MGT_URLS['PREFIX']) - got = json.loads(req.body.decode()) + _assert_request(recorder[0], 'POST', + f'{USER_MGT_URLS["PREFIX"]}/oauthIdpConfigs?oauthIdpConfigId=oidc.provider') + got = json.loads(recorder[0].body.decode()) assert got == want def test_create_empty_values(self, user_mgt_app): @@ -191,11 +192,9 @@ def test_create_empty_values(self, user_mgt_app): self._assert_provider_config(provider_config) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'POST' - assert req.url == '{0}/oauthIdpConfigs?oauthIdpConfigId=oidc.provider'.format( - USER_MGT_URLS['PREFIX']) - got = json.loads(req.body.decode()) + _assert_request(recorder[0], 'POST', + f'{USER_MGT_URLS["PREFIX"]}/oauthIdpConfigs?oauthIdpConfigId=oidc.provider') + got = json.loads(recorder[0].body.decode()) assert got == want @pytest.mark.parametrize('invalid_opts', [ @@ -225,13 +224,12 @@ def test_update(self, user_mgt_app): self._assert_provider_config(provider_config) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'PATCH' mask = ['clientId', 'clientSecret', 'displayName', 'enabled', 'issuer', 'responseType.code', 'responseType.idToken'] - assert req.url == '{0}/oauthIdpConfigs/oidc.provider?updateMask={1}'.format( - USER_MGT_URLS['PREFIX'], ','.join(mask)) - got = json.loads(req.body.decode()) + _assert_request(recorder[0], 'PATCH', + f'{USER_MGT_URLS["PREFIX"]}/oauthIdpConfigs/oidc.provider?' + f'updateMask={",".join(mask)}') + got = json.loads(recorder[0].body.decode()) assert got == self.OIDC_CONFIG_REQUEST def test_update_minimal(self, user_mgt_app): @@ -242,11 +240,10 @@ def test_update_minimal(self, user_mgt_app): self._assert_provider_config(provider_config) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'PATCH' - assert req.url == '{0}/oauthIdpConfigs/oidc.provider?updateMask=displayName'.format( - USER_MGT_URLS['PREFIX']) - got = json.loads(req.body.decode()) + _assert_request(recorder[0], 'PATCH', + f'{USER_MGT_URLS["PREFIX"]}/oauthIdpConfigs/oidc.provider?' + f'updateMask=displayName') + got = json.loads(recorder[0].body.decode()) assert got == {'displayName': 'oidcProviderName'} def test_update_empty_values(self, user_mgt_app): @@ -258,12 +255,11 @@ def test_update_empty_values(self, user_mgt_app): self._assert_provider_config(provider_config) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'PATCH' mask = ['displayName', 'enabled', 'responseType.idToken'] - assert req.url == '{0}/oauthIdpConfigs/oidc.provider?updateMask={1}'.format( - USER_MGT_URLS['PREFIX'], ','.join(mask)) - got = json.loads(req.body.decode()) + _assert_request(recorder[0], 'PATCH', + f'{USER_MGT_URLS["PREFIX"]}/oauthIdpConfigs/oidc.provider?' + f'updateMask={",".join(mask)}') + got = json.loads(recorder[0].body.decode()) assert got == {'displayName': None, 'enabled': False, 'responseType': {'idToken': False}} @pytest.mark.parametrize('provider_id', INVALID_PROVIDER_IDS + ['saml.provider']) @@ -279,9 +275,8 @@ def test_delete(self, user_mgt_app): auth.delete_oidc_provider_config('oidc.provider', app=user_mgt_app) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'DELETE' - assert req.url == '{0}{1}'.format(USER_MGT_URLS['PREFIX'], '/oauthIdpConfigs/oidc.provider') + _assert_request(recorder[0], 'DELETE', + f'{USER_MGT_URLS["PREFIX"]}/oauthIdpConfigs/oidc.provider') @pytest.mark.parametrize('arg', [None, 'foo', list(), dict(), 0, -1, 101, False]) def test_invalid_max_results(self, user_mgt_app, arg): @@ -302,9 +297,8 @@ def test_list_single_page(self, user_mgt_app): assert len(provider_configs) == 2 assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'GET' - assert req.url == '{0}{1}'.format(USER_MGT_URLS['PREFIX'], '/oauthIdpConfigs?pageSize=100') + _assert_request(recorder[0], 'GET', + f'{USER_MGT_URLS["PREFIX"]}/oauthIdpConfigs?pageSize=100') def test_list_multiple_pages(self, user_mgt_app): sample_response = json.loads(OIDC_PROVIDER_CONFIG_RESPONSE) @@ -320,9 +314,8 @@ def test_list_multiple_pages(self, user_mgt_app): self._assert_page(page, next_page_token='token') assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'GET' - assert req.url == '{0}/oauthIdpConfigs?pageSize=10'.format(USER_MGT_URLS['PREFIX']) + _assert_request(recorder[0], 'GET', + f'{USER_MGT_URLS["PREFIX"]}/oauthIdpConfigs?pageSize=10') # Page 2 (also the last page) response = {'oauthIdpConfigs': configs[2:]} @@ -331,10 +324,8 @@ def test_list_multiple_pages(self, user_mgt_app): self._assert_page(page, count=1, start=2) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'GET' - assert req.url == '{0}/oauthIdpConfigs?pageSize=10&pageToken=token'.format( - USER_MGT_URLS['PREFIX']) + _assert_request(recorder[0], 'GET', + f'{USER_MGT_URLS["PREFIX"]}/oauthIdpConfigs?pageSize=10&pageToken=token') def test_paged_iteration(self, user_mgt_app): sample_response = json.loads(OIDC_PROVIDER_CONFIG_RESPONSE) @@ -353,9 +344,8 @@ def test_paged_iteration(self, user_mgt_app): provider_config = next(iterator) assert provider_config.provider_id == 'oidc.provider{0}'.format(index) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'GET' - assert req.url == '{0}/oauthIdpConfigs?pageSize=100'.format(USER_MGT_URLS['PREFIX']) + _assert_request(recorder[0], 'GET', + f'{USER_MGT_URLS["PREFIX"]}/oauthIdpConfigs?pageSize=100') # Page 2 (also the last page) response = {'oauthIdpConfigs': configs[2:]} @@ -364,10 +354,8 @@ def test_paged_iteration(self, user_mgt_app): provider_config = next(iterator) assert provider_config.provider_id == 'oidc.provider2' assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'GET' - assert req.url == '{0}/oauthIdpConfigs?pageSize=100&pageToken=token'.format( - USER_MGT_URLS['PREFIX']) + _assert_request(recorder[0], 'GET', + f'{USER_MGT_URLS["PREFIX"]}/oauthIdpConfigs?pageSize=100&pageToken=token') with pytest.raises(StopIteration): next(iterator) @@ -464,10 +452,8 @@ def test_get(self, user_mgt_app): self._assert_provider_config(provider_config) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'GET' - assert req.url == '{0}{1}'.format(USER_MGT_URLS['PREFIX'], - '/inboundSamlConfigs/saml.provider') + _assert_request(recorder[0], 'GET', + f'{USER_MGT_URLS["PREFIX"]}/inboundSamlConfigs/saml.provider') @pytest.mark.parametrize('invalid_opts', [ {'provider_id': None}, {'provider_id': ''}, {'provider_id': 'oidc.provider'}, @@ -494,11 +480,10 @@ def test_create(self, user_mgt_app): self._assert_provider_config(provider_config) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'POST' - assert req.url == '{0}/inboundSamlConfigs?inboundSamlConfigId=saml.provider'.format( - USER_MGT_URLS['PREFIX']) - got = json.loads(req.body.decode()) + _assert_request(recorder[0], 'POST', + f'{USER_MGT_URLS["PREFIX"]}/inboundSamlConfigs?' + f'inboundSamlConfigId=saml.provider') + got = json.loads(recorder[0].body.decode()) assert got == self.SAML_CONFIG_REQUEST def test_create_minimal(self, user_mgt_app): @@ -514,11 +499,10 @@ def test_create_minimal(self, user_mgt_app): self._assert_provider_config(provider_config) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'POST' - assert req.url == '{0}/inboundSamlConfigs?inboundSamlConfigId=saml.provider'.format( - USER_MGT_URLS['PREFIX']) - got = json.loads(req.body.decode()) + _assert_request(recorder[0], 'POST', + f'{USER_MGT_URLS["PREFIX"]}/inboundSamlConfigs?' + f'inboundSamlConfigId=saml.provider') + got = json.loads(recorder[0].body.decode()) assert got == want def test_create_empty_values(self, user_mgt_app): @@ -534,11 +518,10 @@ def test_create_empty_values(self, user_mgt_app): self._assert_provider_config(provider_config) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'POST' - assert req.url == '{0}/inboundSamlConfigs?inboundSamlConfigId=saml.provider'.format( - USER_MGT_URLS['PREFIX']) - got = json.loads(req.body.decode()) + _assert_request(recorder[0], 'POST', + f'{USER_MGT_URLS["PREFIX"]}/inboundSamlConfigs?' + f'inboundSamlConfigId=saml.provider') + got = json.loads(recorder[0].body.decode()) assert got == want @pytest.mark.parametrize('invalid_opts', [ @@ -567,15 +550,14 @@ def test_update(self, user_mgt_app): self._assert_provider_config(provider_config) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'PATCH' mask = [ 'displayName', 'enabled', 'idpConfig.idpCertificates', 'idpConfig.idpEntityId', 'idpConfig.ssoUrl', 'spConfig.callbackUri', 'spConfig.spEntityId', ] - assert req.url == '{0}/inboundSamlConfigs/saml.provider?updateMask={1}'.format( - USER_MGT_URLS['PREFIX'], ','.join(mask)) - got = json.loads(req.body.decode()) + _assert_request(recorder[0], 'PATCH', + f'{USER_MGT_URLS["PREFIX"]}/inboundSamlConfigs/saml.provider?' + f'updateMask={",".join(mask)}') + got = json.loads(recorder[0].body.decode()) assert got == self.SAML_CONFIG_REQUEST def test_update_minimal(self, user_mgt_app): @@ -586,11 +568,10 @@ def test_update_minimal(self, user_mgt_app): self._assert_provider_config(provider_config) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'PATCH' - assert req.url == '{0}/inboundSamlConfigs/saml.provider?updateMask=displayName'.format( - USER_MGT_URLS['PREFIX']) - got = json.loads(req.body.decode()) + _assert_request(recorder[0], 'PATCH', + f'{USER_MGT_URLS["PREFIX"]}/inboundSamlConfigs/saml.provider?' + f'updateMask=displayName') + got = json.loads(recorder[0].body.decode()) assert got == {'displayName': 'samlProviderName'} def test_update_empty_values(self, user_mgt_app): @@ -601,12 +582,11 @@ def test_update_empty_values(self, user_mgt_app): self._assert_provider_config(provider_config) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'PATCH' mask = ['displayName', 'enabled'] - assert req.url == '{0}/inboundSamlConfigs/saml.provider?updateMask={1}'.format( - USER_MGT_URLS['PREFIX'], ','.join(mask)) - got = json.loads(req.body.decode()) + _assert_request(recorder[0], 'PATCH', + f'{USER_MGT_URLS["PREFIX"]}/inboundSamlConfigs/saml.provider?' + f'updateMask={",".join(mask)}') + got = json.loads(recorder[0].body.decode()) assert got == {'displayName': None, 'enabled': False} @pytest.mark.parametrize('provider_id', INVALID_PROVIDER_IDS + ['oidc.provider']) @@ -622,10 +602,8 @@ def test_delete(self, user_mgt_app): auth.delete_saml_provider_config('saml.provider', app=user_mgt_app) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'DELETE' - assert req.url == '{0}{1}'.format(USER_MGT_URLS['PREFIX'], - '/inboundSamlConfigs/saml.provider') + _assert_request( + recorder[0], 'DELETE', f'{USER_MGT_URLS["PREFIX"]}/inboundSamlConfigs/saml.provider') def test_config_not_found(self, user_mgt_app): _instrument_provider_mgt(user_mgt_app, 500, CONFIG_NOT_FOUND_RESPONSE) @@ -658,10 +636,8 @@ def test_list_single_page(self, user_mgt_app): assert len(provider_configs) == 2 assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'GET' - assert req.url == '{0}{1}'.format(USER_MGT_URLS['PREFIX'], - '/inboundSamlConfigs?pageSize=100') + _assert_request( + recorder[0], 'GET', f'{USER_MGT_URLS["PREFIX"]}/inboundSamlConfigs?pageSize=100') def test_list_multiple_pages(self, user_mgt_app): sample_response = json.loads(SAML_PROVIDER_CONFIG_RESPONSE) @@ -677,9 +653,8 @@ def test_list_multiple_pages(self, user_mgt_app): self._assert_page(page, next_page_token='token') assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'GET' - assert req.url == '{0}/inboundSamlConfigs?pageSize=10'.format(USER_MGT_URLS['PREFIX']) + _assert_request( + recorder[0], 'GET', f'{USER_MGT_URLS["PREFIX"]}/inboundSamlConfigs?pageSize=10') # Page 2 (also the last page) response = {'inboundSamlConfigs': configs[2:]} @@ -688,10 +663,9 @@ def test_list_multiple_pages(self, user_mgt_app): self._assert_page(page, count=1, start=2) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'GET' - assert req.url == '{0}/inboundSamlConfigs?pageSize=10&pageToken=token'.format( - USER_MGT_URLS['PREFIX']) + _assert_request( + recorder[0], 'GET', + f'{USER_MGT_URLS["PREFIX"]}/inboundSamlConfigs?pageSize=10&pageToken=token') def test_paged_iteration(self, user_mgt_app): sample_response = json.loads(SAML_PROVIDER_CONFIG_RESPONSE) @@ -710,9 +684,8 @@ def test_paged_iteration(self, user_mgt_app): provider_config = next(iterator) assert provider_config.provider_id == 'saml.provider{0}'.format(index) assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'GET' - assert req.url == '{0}/inboundSamlConfigs?pageSize=100'.format(USER_MGT_URLS['PREFIX']) + _assert_request( + recorder[0], 'GET', f'{USER_MGT_URLS["PREFIX"]}/inboundSamlConfigs?pageSize=100') # Page 2 (also the last page) response = {'inboundSamlConfigs': configs[2:]} @@ -721,10 +694,9 @@ def test_paged_iteration(self, user_mgt_app): provider_config = next(iterator) assert provider_config.provider_id == 'saml.provider2' assert len(recorder) == 1 - req = recorder[0] - assert req.method == 'GET' - assert req.url == '{0}/inboundSamlConfigs?pageSize=100&pageToken=token'.format( - USER_MGT_URLS['PREFIX']) + _assert_request( + recorder[0], 'GET', + f'{USER_MGT_URLS["PREFIX"]}/inboundSamlConfigs?pageSize=100&pageToken=token') with pytest.raises(StopIteration): next(iterator) diff --git a/tests/test_db.py b/tests/test_db.py index aa2c83bd9..4245f65fb 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -193,16 +193,20 @@ def instrument(self, ref, payload, status=200, etag=MockAdapter.ETAG): ref._client.session.mount(self.test_url, adapter) return recorder + def _assert_request(self, request, expected_method, expected_url): + assert request.method == expected_method + assert request.url == expected_url + assert request.headers['Authorization'] == 'Bearer mock-token' + assert request.headers['User-Agent'] == db._USER_AGENT + assert request.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() + @pytest.mark.parametrize('data', valid_values) def test_get_value(self, data): ref = db.reference('/test') recorder = self.instrument(ref, json.dumps(data)) assert ref.get() == data assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json' - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT + self._assert_request(recorder[0], 'GET', 'https://test.firebaseio.com/test.json') assert 'X-Firebase-ETag' not in recorder[0].headers @pytest.mark.parametrize('data', valid_values) @@ -211,10 +215,7 @@ def test_get_with_etag(self, data): recorder = self.instrument(ref, json.dumps(data)) assert ref.get(etag=True) == (data, MockAdapter.ETAG) assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json' - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT + self._assert_request(recorder[0], 'GET', 'https://test.firebaseio.com/test.json') assert recorder[0].headers['X-Firebase-ETag'] == 'true' @pytest.mark.parametrize('data', valid_values) @@ -223,10 +224,8 @@ def test_get_shallow(self, data): recorder = self.instrument(ref, json.dumps(data)) assert ref.get(shallow=True) == data assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?shallow=true' - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT + self._assert_request( + recorder[0], 'GET', 'https://test.firebaseio.com/test.json?shallow=true') def test_get_with_etag_and_shallow(self): ref = db.reference('/test') @@ -240,14 +239,12 @@ def test_get_if_changed(self, data): assert ref.get_if_changed('invalid-etag') == (True, data, MockAdapter.ETAG) assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json' + self._assert_request(recorder[0], 'GET', 'https://test.firebaseio.com/test.json') assert recorder[0].headers['if-none-match'] == 'invalid-etag' assert ref.get_if_changed(MockAdapter.ETAG) == (False, None, None) assert len(recorder) == 2 - assert recorder[1].method == 'GET' - assert recorder[1].url == 'https://test.firebaseio.com/test.json' + self._assert_request(recorder[1], 'GET', 'https://test.firebaseio.com/test.json') assert recorder[1].headers['if-none-match'] == MockAdapter.ETAG @pytest.mark.parametrize('etag', [0, 1, True, False, dict(), list(), tuple()]) @@ -264,9 +261,8 @@ def test_order_by_query(self, data): query_str = 'orderBy=%22foo%22' assert query.get() == data assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + self._assert_request( + recorder[0], 'GET', 'https://test.firebaseio.com/test.json?' + query_str) @pytest.mark.parametrize('data', valid_values) def test_limit_query(self, data): @@ -277,9 +273,8 @@ def test_limit_query(self, data): query_str = 'limitToFirst=100&orderBy=%22foo%22' assert query.get() == data assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + self._assert_request( + recorder[0], 'GET', 'https://test.firebaseio.com/test.json?' + query_str) @pytest.mark.parametrize('data', valid_values) def test_range_query(self, data): @@ -291,9 +286,8 @@ def test_range_query(self, data): query_str = 'endAt=200&orderBy=%22foo%22&startAt=100' assert query.get() == data assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + self._assert_request( + recorder[0], 'GET', 'https://test.firebaseio.com/test.json?' + query_str) @pytest.mark.parametrize('data', valid_values) def test_set_value(self, data): @@ -301,10 +295,9 @@ def test_set_value(self, data): recorder = self.instrument(ref, '') ref.set(data) assert len(recorder) == 1 - assert recorder[0].method == 'PUT' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?print=silent' + self._assert_request( + recorder[0], 'PUT', 'https://test.firebaseio.com/test.json?print=silent') assert json.loads(recorder[0].body.decode()) == data - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' def test_set_none_value(self): ref = db.reference('/test') @@ -327,10 +320,9 @@ def test_update_children(self, data): recorder = self.instrument(ref, json.dumps(data)) ref.update(data) assert len(recorder) == 1 - assert recorder[0].method == 'PATCH' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?print=silent' + self._assert_request( + recorder[0], 'PATCH', 'https://test.firebaseio.com/test.json?print=silent') assert json.loads(recorder[0].body.decode()) == data - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' @pytest.mark.parametrize('data', valid_values) def test_set_if_unchanged_success(self, data): @@ -339,10 +331,8 @@ def test_set_if_unchanged_success(self, data): vals = ref.set_if_unchanged(MockAdapter.ETAG, data) assert vals == (True, data, MockAdapter.ETAG) assert len(recorder) == 1 - assert recorder[0].method == 'PUT' - assert recorder[0].url == 'https://test.firebaseio.com/test.json' + self._assert_request(recorder[0], 'PUT', 'https://test.firebaseio.com/test.json') assert json.loads(recorder[0].body.decode()) == data - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' assert recorder[0].headers['if-match'] == MockAdapter.ETAG @pytest.mark.parametrize('data', valid_values) @@ -352,10 +342,8 @@ def test_set_if_unchanged_failure(self, data): vals = ref.set_if_unchanged('invalid-etag', data) assert vals == (False, {'foo':'bar'}, MockAdapter.ETAG) assert len(recorder) == 1 - assert recorder[0].method == 'PUT' - assert recorder[0].url == 'https://test.firebaseio.com/test.json' + self._assert_request(recorder[0], 'PUT', 'https://test.firebaseio.com/test.json') assert json.loads(recorder[0].body.decode()) == data - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' assert recorder[0].headers['if-match'] == 'invalid-etag' @pytest.mark.parametrize('etag', [0, 1, True, False, dict(), list(), tuple()]) @@ -397,22 +385,16 @@ def test_push(self, data): assert isinstance(child, db.Reference) assert child.key == 'testkey' assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == 'https://test.firebaseio.com/test.json' + self._assert_request(recorder[0], 'POST', 'https://test.firebaseio.com/test.json') assert json.loads(recorder[0].body.decode()) == data - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT def test_push_default(self): ref = db.reference('/test') recorder = self.instrument(ref, json.dumps({'name' : 'testkey'})) assert ref.push().key == 'testkey' assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == 'https://test.firebaseio.com/test.json' + self._assert_request(recorder[0], 'POST', 'https://test.firebaseio.com/test.json') assert json.loads(recorder[0].body.decode()) == '' - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT def test_push_none_value(self): ref = db.reference('/test') @@ -425,10 +407,7 @@ def test_delete(self): recorder = self.instrument(ref, '') ref.delete() assert len(recorder) == 1 - assert recorder[0].method == 'DELETE' - assert recorder[0].url == 'https://test.firebaseio.com/test.json' - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT + self._assert_request(recorder[0], 'DELETE', 'https://test.firebaseio.com/test.json') def test_transaction(self): ref = db.reference('/test') @@ -442,8 +421,8 @@ def transaction_update(data): new_value = ref.transaction(transaction_update) assert new_value == {'foo1' : 'bar1', 'foo2' : 'bar2'} assert len(recorder) == 2 - assert recorder[0].method == 'GET' - assert recorder[1].method == 'PUT' + self._assert_request(recorder[0], 'GET', 'https://test.firebaseio.com/test.json') + self._assert_request(recorder[1], 'PUT', 'https://test.firebaseio.com/test.json') assert json.loads(recorder[1].body.decode()) == {'foo1': 'bar1', 'foo2': 'bar2'} def test_transaction_scalar(self): @@ -454,8 +433,8 @@ def test_transaction_scalar(self): new_value = ref.transaction(lambda x: x + 1 if x else 1) assert new_value == 43 assert len(recorder) == 2 - assert recorder[0].method == 'GET' - assert recorder[1].method == 'PUT' + self._assert_request(recorder[0], 'GET', 'https://test.firebaseio.com/test/count.json') + self._assert_request(recorder[1], 'PUT', 'https://test.firebaseio.com/test/count.json') assert json.loads(recorder[1].body.decode()) == 43 def test_transaction_error(self): @@ -471,7 +450,7 @@ def transaction_update(data): ref.transaction(transaction_update) assert str(excinfo.value) == 'test error' assert len(recorder) == 1 - assert recorder[0].method == 'GET' + self._assert_request(recorder[0], 'GET', 'https://test.firebaseio.com/test.json') def test_transaction_abort(self): ref = db.reference('/test/count') @@ -638,16 +617,21 @@ def instrument(self, ref, payload, status=200): ref._client.session.mount(self.test_url, adapter) return recorder + def _assert_request(self, request, expected_method, expected_url): + assert request.method == expected_method + assert request.url == expected_url + assert request.headers['Authorization'] == 'Bearer mock-token' + assert request.headers['User-Agent'] == db._USER_AGENT + assert request.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() + def test_get_value(self): ref = db.reference('/test') recorder = self.instrument(ref, json.dumps('data')) query_str = 'auth_variable_override={0}'.format(self.encoded_override) assert ref.get() == 'data' assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT + self._assert_request( + recorder[0], 'GET', 'https://test.firebaseio.com/test.json?' + query_str) def test_set_value(self): ref = db.reference('/test') @@ -656,11 +640,9 @@ def test_set_value(self): ref.set(data) query_str = 'print=silent&auth_variable_override={0}'.format(self.encoded_override) assert len(recorder) == 1 - assert recorder[0].method == 'PUT' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str + self._assert_request( + recorder[0], 'PUT', 'https://test.firebaseio.com/test.json?' + query_str) assert json.loads(recorder[0].body.decode()) == data - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT def test_order_by_query(self): ref = db.reference('/test') @@ -669,10 +651,8 @@ def test_order_by_query(self): query_str = 'orderBy=%22foo%22&auth_variable_override={0}'.format(self.encoded_override) assert query.get() == 'data' assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT + self._assert_request( + recorder[0], 'GET', 'https://test.firebaseio.com/test.json?' + query_str) def test_range_query(self): ref = db.reference('/test') @@ -682,10 +662,8 @@ def test_range_query(self): 'auth_variable_override={0}'.format(self.encoded_override)) assert query.get() == 'data' assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == 'https://test.firebaseio.com/test.json?' + query_str - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['User-Agent'] == db._USER_AGENT + self._assert_request( + recorder[0], 'GET', 'https://test.firebaseio.com/test.json?' + query_str) class TestDatabaseInitialization: diff --git a/tests/test_functions.py b/tests/test_functions.py index 75809c1ad..f8f675890 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -21,6 +21,7 @@ import firebase_admin from firebase_admin import functions +from firebase_admin import _utils from tests import testutils @@ -121,6 +122,7 @@ def test_task_enqueue(self): assert recorder[0].url == _DEFAULT_REQUEST_URL assert recorder[0].headers['Content-Type'] == 'application/json' assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + assert recorder[0].headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() assert task_id == 'test-task-id' def test_task_enqueue_with_extension(self): @@ -137,6 +139,7 @@ def test_task_enqueue_with_extension(self): assert recorder[0].url == _CLOUD_TASKS_URL + resource_name assert recorder[0].headers['Content-Type'] == 'application/json' assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + assert recorder[0].headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() assert task_id == 'test-task-id' def test_task_delete(self): @@ -146,6 +149,7 @@ def test_task_delete(self): assert len(recorder) == 1 assert recorder[0].method == 'DELETE' assert recorder[0].url == _DEFAULT_TASK_URL + assert recorder[0].headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() class TestTaskQueueOptions: diff --git a/tests/test_http_client.py b/tests/test_http_client.py index 12ba03b48..cc948b393 100644 --- a/tests/test_http_client.py +++ b/tests/test_http_client.py @@ -17,7 +17,7 @@ from pytest_localserver import http import requests -from firebase_admin import _http_client +from firebase_admin import _http_client, _utils from tests import testutils @@ -61,6 +61,18 @@ def test_base_url(): assert recorder[0].method == 'GET' assert recorder[0].url == _TEST_URL + 'foo' +def test_metrics_headers(): + client = _http_client.HttpClient() + assert client.session is not None + recorder = _instrument(client, 'body') + resp = client.request('get', _TEST_URL) + assert resp.status_code == 200 + assert resp.text == 'body' + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == _TEST_URL + assert recorder[0].headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() + def test_credential(): client = _http_client.HttpClient( credential=testutils.MockGoogleCredential()) diff --git a/tests/test_instance_id.py b/tests/test_instance_id.py index 08b0fe6db..720171cd9 100644 --- a/tests/test_instance_id.py +++ b/tests/test_instance_id.py @@ -20,6 +20,7 @@ from firebase_admin import exceptions from firebase_admin import instance_id from firebase_admin import _http_client +from firebase_admin import _utils from tests import testutils @@ -64,6 +65,11 @@ def _instrument_iid_service(self, app, status=200, payload='True'): testutils.MockAdapter(payload, status, recorder)) return iid_service, recorder + def _assert_request(self, request, expected_method, expected_url): + assert request.method == expected_method + assert request.url == expected_url + assert request.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() + def _get_url(self, project_id, iid): return instance_id._IID_SERVICE_URL + 'project/{0}/instanceId/{1}'.format(project_id, iid) @@ -86,8 +92,8 @@ def test_delete_instance_id(self): _, recorder = self._instrument_iid_service(app) instance_id.delete_instance_id('test_iid') assert len(recorder) == 1 - assert recorder[0].method == 'DELETE' - assert recorder[0].url == self._get_url('explicit-project-id', 'test_iid') + self._assert_request( + recorder[0], 'DELETE', self._get_url('explicit-project-id', 'test_iid')) def test_delete_instance_id_with_explicit_app(self): cred = testutils.MockCredential() @@ -95,8 +101,8 @@ def test_delete_instance_id_with_explicit_app(self): _, recorder = self._instrument_iid_service(app) instance_id.delete_instance_id('test_iid', app) assert len(recorder) == 1 - assert recorder[0].method == 'DELETE' - assert recorder[0].url == self._get_url('explicit-project-id', 'test_iid') + self._assert_request( + recorder[0], 'DELETE', self._get_url('explicit-project-id', 'test_iid')) @pytest.mark.parametrize('status', http_errors.keys()) def test_delete_instance_id_error(self, status): @@ -114,8 +120,8 @@ def test_delete_instance_id_error(self, status): else: # 401 responses are automatically retried by google-auth assert len(recorder) == 3 - assert recorder[0].method == 'DELETE' - assert recorder[0].url == self._get_url('explicit-project-id', 'test_iid') + self._assert_request( + recorder[0], 'DELETE', self._get_url('explicit-project-id', 'test_iid')) def test_delete_instance_id_unexpected_error(self): cred = testutils.MockCredential() @@ -129,8 +135,7 @@ def test_delete_instance_id_unexpected_error(self): assert excinfo.value.cause is not None assert excinfo.value.http_response is not None assert len(recorder) == 1 - assert recorder[0].method == 'DELETE' - assert recorder[0].url == url + self._assert_request(recorder[0], 'DELETE', url) @pytest.mark.parametrize('iid', [None, '', 0, 1, True, False, list(), dict(), tuple()]) def test_invalid_instance_id(self, iid): diff --git a/tests/test_messaging.py b/tests/test_messaging.py index d482438f5..edb36f53a 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -25,6 +25,7 @@ from firebase_admin import exceptions from firebase_admin import messaging from firebase_admin import _http_client +from firebase_admin import _utils from tests import testutils @@ -1660,6 +1661,18 @@ def _instrument_messaging_service(self, app=None, status=200, payload=_DEFAULT_R testutils.MockAdapter(payload, status, recorder)) return fcm_service, recorder + + def _assert_request(self, request, expected_method, expected_url, expected_body=None): + assert request.method == expected_method + assert request.url == expected_url + assert request.headers['X-GOOG-API-FORMAT-VERSION'] == '2' + assert request.headers['X-FIREBASE-CLIENT'] == self._CLIENT_VERSION + assert request.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() + if expected_body is None: + assert request.body is None + else: + assert json.loads(request.body.decode()) == expected_body + def _get_url(self, project_id): return messaging._MessagingService.FCM_URL.format(project_id) @@ -1682,15 +1695,11 @@ def test_send_dry_run(self): msg_id = messaging.send(msg, dry_run=True) assert msg_id == 'message-id' assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == self._get_url('explicit-project-id') - assert recorder[0].headers['X-GOOG-API-FORMAT-VERSION'] == '2' - assert recorder[0].headers['X-FIREBASE-CLIENT'] == self._CLIENT_VERSION body = { 'message': messaging._MessagingService.encode_message(msg), 'validate_only': True, } - assert json.loads(recorder[0].body.decode()) == body + self._assert_request(recorder[0], 'POST', self._get_url('explicit-project-id'), body) def test_send(self): _, recorder = self._instrument_messaging_service() @@ -1698,12 +1707,8 @@ def test_send(self): msg_id = messaging.send(msg) assert msg_id == 'message-id' assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == self._get_url('explicit-project-id') - assert recorder[0].headers['X-GOOG-API-FORMAT-VERSION'] == '2' - assert recorder[0].headers['X-FIREBASE-CLIENT'] == self._CLIENT_VERSION body = {'message': messaging._MessagingService.encode_message(msg)} - assert json.loads(recorder[0].body.decode()) == body + self._assert_request(recorder[0], 'POST', self._get_url('explicit-project-id'), body) @pytest.mark.parametrize('status,exc_type', HTTP_ERROR_CODES.items()) def test_send_error(self, status, exc_type): @@ -1714,12 +1719,8 @@ def test_send_error(self, status, exc_type): expected = 'Unexpected HTTP response with status: {0}; body: {{}}'.format(status) check_exception(excinfo.value, expected, status) assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == self._get_url('explicit-project-id') - assert recorder[0].headers['X-GOOG-API-FORMAT-VERSION'] == '2' - assert recorder[0].headers['X-FIREBASE-CLIENT'] == self._CLIENT_VERSION body = {'message': messaging._MessagingService.JSON_ENCODER.default(msg)} - assert json.loads(recorder[0].body.decode()) == body + self._assert_request(recorder[0], 'POST', self._get_url('explicit-project-id'), body) @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_detailed_error(self, status): @@ -1735,10 +1736,8 @@ def test_send_detailed_error(self, status): messaging.send(msg) check_exception(excinfo.value, 'test error', status) assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == self._get_url('explicit-project-id') body = {'message': messaging._MessagingService.JSON_ENCODER.default(msg)} - assert json.loads(recorder[0].body.decode()) == body + self._assert_request(recorder[0], 'POST', self._get_url('explicit-project-id'), body) @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_canonical_error_code(self, status): @@ -1754,10 +1753,8 @@ def test_send_canonical_error_code(self, status): messaging.send(msg) check_exception(excinfo.value, 'test error', status) assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == self._get_url('explicit-project-id') body = {'message': messaging._MessagingService.JSON_ENCODER.default(msg)} - assert json.loads(recorder[0].body.decode()) == body + self._assert_request(recorder[0], 'POST', self._get_url('explicit-project-id'), body) @pytest.mark.parametrize('status', HTTP_ERROR_CODES) @pytest.mark.parametrize('fcm_error_code, exc_type', FCM_ERROR_CODES.items()) @@ -1780,10 +1777,8 @@ def test_send_fcm_error_code(self, status, fcm_error_code, exc_type): messaging.send(msg) check_exception(excinfo.value, 'test error', status) assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == self._get_url('explicit-project-id') body = {'message': messaging._MessagingService.JSON_ENCODER.default(msg)} - assert json.loads(recorder[0].body.decode()) == body + self._assert_request(recorder[0], 'POST', self._get_url('explicit-project-id'), body) @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_unknown_fcm_error_code(self, status): @@ -1805,10 +1800,8 @@ def test_send_unknown_fcm_error_code(self, status): messaging.send(msg) check_exception(excinfo.value, 'test error', status) assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == self._get_url('explicit-project-id') body = {'message': messaging._MessagingService.JSON_ENCODER.default(msg)} - assert json.loads(recorder[0].body.decode()) == body + self._assert_request(recorder[0], 'POST', self._get_url('explicit-project-id'), body) class _HttpMockException: @@ -2591,6 +2584,12 @@ def _instrument_iid_service(self, app=None, status=200, payload=_DEFAULT_RESPONS testutils.MockAdapter(payload, status, recorder)) return fcm_service, recorder + def _assert_request(self, request, expected_method, expected_url): + assert request.method == expected_method + assert request.url == expected_url + assert request.headers['access_token_auth'] == 'true' + assert request.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() + def _get_url(self, path): return '{0}/{1}'.format(messaging._MessagingService.IID_URL, path) @@ -2625,8 +2624,7 @@ def test_subscribe_to_topic(self, args): resp = messaging.subscribe_to_topic(args[0], args[1]) self._check_response(resp) assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == self._get_url('iid/v1:batchAdd') + self._assert_request(recorder[0], 'POST', self._get_url('iid/v1:batchAdd')) assert json.loads(recorder[0].body.decode()) == args[2] @pytest.mark.parametrize('status, exc_type', HTTP_ERROR_CODES.items()) @@ -2637,8 +2635,7 @@ def test_subscribe_to_topic_error(self, status, exc_type): messaging.subscribe_to_topic('foo', 'test-topic') assert str(excinfo.value) == 'Error while calling the IID service: error_reason' assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == self._get_url('iid/v1:batchAdd') + self._assert_request(recorder[0], 'POST', self._get_url('iid/v1:batchAdd')) @pytest.mark.parametrize('status, exc_type', HTTP_ERROR_CODES.items()) def test_subscribe_to_topic_non_json_error(self, status, exc_type): @@ -2648,8 +2645,7 @@ def test_subscribe_to_topic_non_json_error(self, status, exc_type): reason = 'Unexpected HTTP response with status: {0}; body: not json'.format(status) assert str(excinfo.value) == reason assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == self._get_url('iid/v1:batchAdd') + self._assert_request(recorder[0], 'POST', self._get_url('iid/v1:batchAdd')) @pytest.mark.parametrize('args', _VALID_ARGS) def test_unsubscribe_from_topic(self, args): @@ -2657,8 +2653,7 @@ def test_unsubscribe_from_topic(self, args): resp = messaging.unsubscribe_from_topic(args[0], args[1]) self._check_response(resp) assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == self._get_url('iid/v1:batchRemove') + self._assert_request(recorder[0], 'POST', self._get_url('iid/v1:batchRemove')) assert json.loads(recorder[0].body.decode()) == args[2] @pytest.mark.parametrize('status, exc_type', HTTP_ERROR_CODES.items()) @@ -2669,8 +2664,7 @@ def test_unsubscribe_from_topic_error(self, status, exc_type): messaging.unsubscribe_from_topic('foo', 'test-topic') assert str(excinfo.value) == 'Error while calling the IID service: error_reason' assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == self._get_url('iid/v1:batchRemove') + self._assert_request(recorder[0], 'POST', self._get_url('iid/v1:batchRemove')) @pytest.mark.parametrize('status, exc_type', HTTP_ERROR_CODES.items()) def test_unsubscribe_from_topic_non_json_error(self, status, exc_type): @@ -2680,8 +2674,7 @@ def test_unsubscribe_from_topic_non_json_error(self, status, exc_type): reason = 'Unexpected HTTP response with status: {0}; body: not json'.format(status) assert str(excinfo.value) == reason assert len(recorder) == 1 - assert recorder[0].method == 'POST' - assert recorder[0].url == self._get_url('iid/v1:batchRemove') + self._assert_request(recorder[0], 'POST', self._get_url('iid/v1:batchRemove')) def _check_response(self, resp): assert resp.success_count == 1 diff --git a/tests/test_ml.py b/tests/test_ml.py index abd6d06f9..137fe4cf6 100644 --- a/tests/test_ml.py +++ b/tests/test_ml.py @@ -21,12 +21,11 @@ import firebase_admin from firebase_admin import exceptions from firebase_admin import ml +from firebase_admin import _utils from tests import testutils BASE_URL = 'https://firebaseml.googleapis.com/v1beta2/' -HEADER_CLIENT_KEY = 'X-FIREBASE-CLIENT' -HEADER_CLIENT_VALUE = 'fire-admin-python/{0}'.format(firebase_admin.__version__) PROJECT_ID = 'my-project-1' PAGE_TOKEN = 'pageToken' @@ -336,6 +335,12 @@ def instrument_ml_service(status=200, payload=None, operations=False, app=None): session_url, adapter(payload, status, recorder)) return recorder +def _assert_request(request, expected_method, expected_url): + assert request.method == expected_method + assert request.url == expected_url + assert request.headers['X-FIREBASE-CLIENT'] == f'fire-admin-python/{firebase_admin.__version__}' + assert request.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() + class _TestStorageClient: @staticmethod def upload(bucket_name, model_file_name, app): @@ -599,9 +604,7 @@ def test_wait_for_unlocked(self): model.wait_for_unlocked() assert model == FULL_MODEL_PUBLISHED assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == TestModel._op_url(PROJECT_ID) - assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + _assert_request(recorder[0], 'GET', TestModel._op_url(PROJECT_ID)) def test_wait_for_unlocked_timeout(self): recorder = instrument_ml_service( @@ -653,12 +656,8 @@ def test_returns_locked(self): assert model == expected_model assert len(recorder) == 2 - assert recorder[0].method == 'POST' - assert recorder[0].url == TestCreateModel._url(PROJECT_ID) - assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE - assert recorder[1].method == 'GET' - assert recorder[1].url == TestCreateModel._get_url(PROJECT_ID, MODEL_ID_1) - assert recorder[1].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + _assert_request(recorder[0], 'POST', TestCreateModel._url(PROJECT_ID)) + _assert_request(recorder[1], 'GET', TestCreateModel._get_url(PROJECT_ID, MODEL_ID_1)) def test_operation_error(self): instrument_ml_service(status=200, payload=OPERATION_ERROR_RESPONSE) @@ -747,12 +746,8 @@ def test_returns_locked(self): assert model == expected_model assert len(recorder) == 2 - assert recorder[0].method == 'PATCH' - assert recorder[0].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1) - assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE - assert recorder[1].method == 'GET' - assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1) - assert recorder[1].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + _assert_request(recorder[0], 'PATCH', TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)) + _assert_request(recorder[1], 'GET', TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)) def test_operation_error(self): instrument_ml_service(status=200, payload=OPERATION_ERROR_RESPONSE) @@ -846,9 +841,8 @@ def test_immediate_done(self, publish_function, published): model = publish_function(MODEL_ID_1) assert model == CREATED_UPDATED_MODEL_1 assert len(recorder) == 1 - assert recorder[0].method == 'PATCH' - assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1) - assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + _assert_request( + recorder[0], 'PATCH', TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1)) body = json.loads(recorder[0].body.decode()) assert body.get('state', {}).get('published', None) is published @@ -862,12 +856,10 @@ def test_returns_locked(self, publish_function): assert model == expected_model assert len(recorder) == 2 - assert recorder[0].method == 'PATCH' - assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1) - assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE - assert recorder[1].method == 'GET' - assert recorder[1].url == TestPublishUnpublish._get_url(PROJECT_ID, MODEL_ID_1) - assert recorder[1].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + _assert_request( + recorder[0], 'PATCH', TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1)) + _assert_request( + recorder[1], 'GET', TestPublishUnpublish._get_url(PROJECT_ID, MODEL_ID_1)) @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) def test_operation_error(self, publish_function): @@ -918,9 +910,7 @@ def test_get_model(self): recorder = instrument_ml_service(status=200, payload=DEFAULT_GET_RESPONSE) model = ml.get_model(MODEL_ID_1) assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == TestGetModel._url(PROJECT_ID, MODEL_ID_1) - assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + _assert_request(recorder[0], 'GET', TestGetModel._url(PROJECT_ID, MODEL_ID_1)) assert model == MODEL_1 assert model.model_id == MODEL_ID_1 assert model.display_name == DISPLAY_NAME_1 @@ -942,9 +932,7 @@ def test_get_model_error(self): ERROR_MSG_NOT_FOUND ) assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == TestGetModel._url(PROJECT_ID, MODEL_ID_1) - assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + _assert_request(recorder[0], 'GET', TestGetModel._url(PROJECT_ID, MODEL_ID_1)) def test_no_project_id(self): def evaluate(): @@ -973,9 +961,7 @@ def test_delete_model(self): recorder = instrument_ml_service(status=200, payload=EMPTY_RESPONSE) ml.delete_model(MODEL_ID_1) # no response for delete assert len(recorder) == 1 - assert recorder[0].method == 'DELETE' - assert recorder[0].url == TestDeleteModel._url(PROJECT_ID, MODEL_ID_1) - assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + _assert_request(recorder[0], 'DELETE', TestDeleteModel._url(PROJECT_ID, MODEL_ID_1)) @pytest.mark.parametrize('model_id, exc_type', INVALID_MODEL_ID_ARGS) def test_delete_model_validation_errors(self, model_id, exc_type): @@ -994,9 +980,7 @@ def test_delete_model_error(self): ERROR_MSG_NOT_FOUND ) assert len(recorder) == 1 - assert recorder[0].method == 'DELETE' - assert recorder[0].url == self._url(PROJECT_ID, MODEL_ID_1) - assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + _assert_request(recorder[0], 'DELETE', self._url(PROJECT_ID, MODEL_ID_1)) def test_no_project_id(self): def evaluate(): @@ -1032,9 +1016,7 @@ def test_list_models_no_args(self): recorder = instrument_ml_service(status=200, payload=DEFAULT_LIST_RESPONSE) models_page = ml.list_models() assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == TestListModels._url(PROJECT_ID) - assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + _assert_request(recorder[0], 'GET', TestListModels._url(PROJECT_ID)) TestListModels._check_page(models_page, 2) assert models_page.has_next_page assert models_page.next_page_token == NEXT_PAGE_TOKEN @@ -1048,12 +1030,10 @@ def test_list_models_with_all_args(self): page_size=10, page_token=PAGE_TOKEN) assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == ( + _assert_request(recorder[0], 'GET', ( TestListModels._url(PROJECT_ID) + '?filter=display_name%3DdisplayName3&page_size=10&page_token={0}' - .format(PAGE_TOKEN)) - assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + .format(PAGE_TOKEN))) assert isinstance(models_page, ml.ListModelsPage) assert len(models_page.models) == 1 assert models_page.models[0] == MODEL_3 @@ -1097,9 +1077,7 @@ def test_list_models_error(self): ERROR_MSG_BAD_REQUEST ) assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == TestListModels._url(PROJECT_ID) - assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + _assert_request(recorder[0], 'GET', TestListModels._url(PROJECT_ID)) def test_no_project_id(self): def evaluate(): diff --git a/tests/test_project_management.py b/tests/test_project_management.py index 183195510..0a1bf97e5 100644 --- a/tests/test_project_management.py +++ b/tests/test_project_management.py @@ -23,6 +23,7 @@ from firebase_admin import exceptions from firebase_admin import project_management from firebase_admin import _http_client +from firebase_admin import _utils from tests import testutils OPERATION_IN_PROGRESS_RESPONSE = json.dumps({ @@ -521,8 +522,8 @@ def _assert_request_is_correct( self, request, expected_method, expected_url, expected_body=None): assert request.method == expected_method assert request.url == expected_url - client_version = 'Python/Admin/{0}'.format(firebase_admin.__version__) - assert request.headers['X-Client-Version'] == client_version + assert request.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' + assert request.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() if expected_body is None: assert request.body is None else: diff --git a/tests/test_tenant_mgt.py b/tests/test_tenant_mgt.py index 53b766239..1da6d938a 100644 --- a/tests/test_tenant_mgt.py +++ b/tests/test_tenant_mgt.py @@ -26,6 +26,7 @@ from firebase_admin import tenant_mgt from firebase_admin import _auth_providers from firebase_admin import _user_mgt +from firebase_admin import _utils from tests import testutils from tests import test_token_gen @@ -195,6 +196,8 @@ def test_get_tenant(self, tenant_mgt_app): req = recorder[0] assert req.method == 'GET' assert req.url == '{0}/tenants/tenant-id'.format(TENANT_MGT_URL_PREFIX) + assert req.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' + assert req.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() def test_tenant_not_found(self, tenant_mgt_app): _instrument_tenant_mgt(tenant_mgt_app, 500, TENANT_NOT_FOUND_RESPONSE) @@ -285,6 +288,8 @@ def _assert_request(self, recorder, body): req = recorder[0] assert req.method == 'POST' assert req.url == '{0}/tenants'.format(TENANT_MGT_URL_PREFIX) + assert req.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' + assert req.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() got = json.loads(req.body.decode()) assert got == body @@ -383,6 +388,8 @@ def _assert_request(self, recorder, body, mask): assert req.method == 'PATCH' assert req.url == '{0}/tenants/tenant-id?updateMask={1}'.format( TENANT_MGT_URL_PREFIX, ','.join(mask)) + assert req.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' + assert req.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() got = json.loads(req.body.decode()) assert got == body @@ -403,6 +410,8 @@ def test_delete_tenant(self, tenant_mgt_app): req = recorder[0] assert req.method == 'DELETE' assert req.url == '{0}/tenants/tenant-id'.format(TENANT_MGT_URL_PREFIX) + assert req.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' + assert req.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() def test_tenant_not_found(self, tenant_mgt_app): _instrument_tenant_mgt(tenant_mgt_app, 500, TENANT_NOT_FOUND_RESPONSE) @@ -545,6 +554,8 @@ def _assert_request(self, recorder, expected=None): assert len(recorder) == 1 req = recorder[0] assert req.method == 'GET' + assert req.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' + assert req.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() request = dict(parse.parse_qsl(parse.urlsplit(req.url).query)) assert request == expected @@ -920,6 +931,8 @@ def _assert_request( req = recorder[0] assert req.method == method assert req.url == '{0}/tenants/tenant-id{1}'.format(prefix, want_url) + assert req.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' + assert req.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() body = json.loads(req.body.decode()) assert body == want_body diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index ea9c87e6f..604ec9959 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -28,6 +28,7 @@ from firebase_admin import _http_client from firebase_admin import _user_import from firebase_admin import _user_mgt +from firebase_admin import _utils from tests import testutils @@ -135,6 +136,7 @@ def _check_request(recorder, want_url, want_body=None, want_timeout=None): req = recorder[0] assert req.method == 'POST' assert req.url == '{0}{1}'.format(USER_MGT_URLS['PREFIX'], want_url) + assert req.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() if want_body: body = json.loads(req.body.decode()) assert body == want_body From 50ace23cbcdf0d66f38a0b8ec0e0ded02f91fe46 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Thu, 7 Nov 2024 14:35:16 -0500 Subject: [PATCH 187/226] feat(firestore): Upgrade `google-cloud-firestore` to support Firestore Multi Database (#827) * feat(firestore): Upgrade `google-cloud-firestore` to support Firestore Multi Database * Bump to v2.19.0 --- requirements.txt | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index fa48f7f57..fd5b0b39c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,6 @@ pytest-mock >= 3.6.1 cachecontrol >= 0.12.14 google-api-core[grpc] >= 1.22.1, < 3.0.0dev; platform.python_implementation != 'PyPy' google-api-python-client >= 1.7.8 -google-cloud-firestore >= 2.9.1; platform.python_implementation != 'PyPy' +google-cloud-firestore >= 2.19.0; platform.python_implementation != 'PyPy' google-cloud-storage >= 1.37.1 pyjwt[crypto] >= 2.5.0 \ No newline at end of file diff --git a/setup.py b/setup.py index e479e39e6..23be6d481 100644 --- a/setup.py +++ b/setup.py @@ -40,7 +40,7 @@ 'cachecontrol>=0.12.14', 'google-api-core[grpc] >= 1.22.1, < 3.0.0dev; platform.python_implementation != "PyPy"', 'google-api-python-client >= 1.7.8', - 'google-cloud-firestore>=2.9.1; platform.python_implementation != "PyPy"', + 'google-cloud-firestore>=2.19.0; platform.python_implementation != "PyPy"', 'google-cloud-storage>=1.37.1', 'pyjwt[crypto] >= 2.5.0', ] From d3e2a6306be13d1492ff5a33a46065a075e22a2d Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Thu, 7 Nov 2024 17:23:35 -0500 Subject: [PATCH 188/226] [chore] Release 6.6.0 (#829) --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index 75f3f4b41..4ee475c8a 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '6.5.0' +__version__ = '6.6.0' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 1b131f01fa84a4c4bd521431e6e6bb342ebddfe6 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Thu, 7 Nov 2024 17:45:43 -0500 Subject: [PATCH 189/226] [chore] Release 6.6.0 Take 2 (#830) --- .github/workflows/release.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7aab71b23..7a7986a5a 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -108,6 +108,7 @@ jobs: uses: actions/download-artifact@v4.1.7 with: name: dist + path: dist - name: Publish preflight check id: preflight From 43ab91e22d02e5d07dfcc77e6943f4d1251c291e Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Thu, 19 Dec 2024 12:35:16 -0500 Subject: [PATCH 190/226] chore: Skip integration test for deprecated FCM API and bump pypy CI to 3.9 (#840) * chore: Skip integration test for deprecated FCM API * chore: Bump pypy test version to 3.9 --- .github/workflows/ci.yml | 2 +- integration/test_messaging.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 127aa2e7e..4cc8ec481 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,7 +8,7 @@ jobs: strategy: fail-fast: false matrix: - python: ['3.8', '3.9', '3.10', '3.11', '3.12', 'pypy3.8'] + python: ['3.8', '3.9', '3.10', '3.11', '3.12', 'pypy3.9'] steps: - uses: actions/checkout@v4 diff --git a/integration/test_messaging.py b/integration/test_messaging.py index 522e87e85..50b4ae3a4 100644 --- a/integration/test_messaging.py +++ b/integration/test_messaging.py @@ -197,6 +197,7 @@ def test_send_all_500(): assert response.exception is None assert re.match('^projects/.*/messages/.*$', response.message_id) +@pytest.mark.skip(reason="Replaced with test_send_each_for_multicast") def test_send_multicast(): multicast = messaging.MulticastMessage( notification=messaging.Notification('Title', 'Body'), From 8ba819a4175e758576f1a7cccc131c1b66d6417a Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Fri, 20 Dec 2024 16:42:54 -0500 Subject: [PATCH 191/226] chore: Adding delayed response message for holidays (#842) * Adding delayed response message for holidays * fix date --- .github/ISSUE_TEMPLATE/bug_report.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 2970d494f..ade9ad153 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -7,6 +7,11 @@ assignees: '' --- +--- +**Thank you for submitting your issue. We are operating at reduced capacity from Dec 23 2024 to Jan 6 2025. Please expect delayed responses. For more urgent requests please reach us via our support channels https://firebase.google.com/support** + +--- + ### [READ] Step 1: Are you in the right place? * For issues related to __the code in this repository__ file a GitHub issue. From 0ce187ffe710e6c295656cb515d4bbb9ac31217f Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Mon, 6 Jan 2025 14:49:16 -0500 Subject: [PATCH 192/226] Revert "chore: Adding delayed response message for holidays (#842)" (#848) This reverts commit 8ba819a4175e758576f1a7cccc131c1b66d6417a. --- .github/ISSUE_TEMPLATE/bug_report.md | 5 ----- 1 file changed, 5 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index ade9ad153..2970d494f 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -7,11 +7,6 @@ assignees: '' --- ---- -**Thank you for submitting your issue. We are operating at reduced capacity from Dec 23 2024 to Jan 6 2025. Please expect delayed responses. For more urgent requests please reach us via our support channels https://firebase.google.com/support** - ---- - ### [READ] Step 1: Are you in the right place? * For issues related to __the code in this repository__ file a GitHub issue. From e5618c0bb9c2d186cebeffa25bff90b3ee223ffb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E3=81=82=E3=81=84=E3=81=86=E3=81=88=E3=81=8A?= <130837816+aiueo-1234@users.noreply.github.com> Date: Tue, 14 Jan 2025 05:15:01 +0900 Subject: [PATCH 193/226] pass clinet's params to SSEClient (#845) --- firebase_admin/db.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/db.py b/firebase_admin/db.py index 890968796..1dec98653 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -467,7 +467,7 @@ def _listen_with_session(self, callback, session=None): session = self._client.create_listener_session() try: - sse = _sseclient.SSEClient(url, session) + sse = _sseclient.SSEClient(url, session, **{"params": self._client.params}) return ListenerRegistration(callback, sse) except requests.exceptions.RequestException as error: raise _Client.handle_rtdb_error(error) From e6c95e7ef6e8f4f77ff7e6540e6a070cb728aa4e Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Wed, 22 Jan 2025 12:06:20 -0500 Subject: [PATCH 194/226] chore: Add tests for `Reference.listen()` (#851) * chore: Add unit tests for `Reference.listen()` * Integration test for rtdb listeners * fix lint --- integration/test_db.py | 32 +++++++++++++++++++++++++++++++ tests/test_db.py | 43 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+) diff --git a/integration/test_db.py b/integration/test_db.py index c448436d6..0170743dd 100644 --- a/integration/test_db.py +++ b/integration/test_db.py @@ -16,6 +16,7 @@ import collections import json import os +import time import pytest @@ -245,6 +246,37 @@ def test_delete(self, testref): ref.delete() assert ref.get() is None +class TestListenOperations: + """Test cases for listening to changes to node values.""" + + def test_listen(self, testref): + self.events = [] + def callback(event): + self.events.append(event) + + python = testref.parent + registration = python.listen(callback) + try: + ref = python.child('users').push() + assert ref.path == '/_adminsdk/python/users/' + ref.key + assert ref.get() == '' + + self.wait_for(self.events, count=2) + assert len(self.events) == 2 + + assert self.events[1].event_type == 'put' + assert self.events[1].path == '/users/' + ref.key + assert self.events[1].data == '' + finally: + registration.close() + + @classmethod + def wait_for(cls, events, count=1, timeout_seconds=5): + must_end = time.time() + timeout_seconds + while time.time() < must_end: + if len(events) >= count: + return + raise pytest.fail('Timed out while waiting for events') class TestAdvancedQueries: """Test cases for advanced interactions via the db.Query interface.""" diff --git a/tests/test_db.py b/tests/test_db.py index 4245f65fb..f2ba08827 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -535,6 +535,49 @@ def callback(_): finally: testutils.cleanup_apps() + @pytest.mark.parametrize( + 'url,emulator_host,expected_base_url,expected_namespace', + [ + # Production URLs with no override: + ('https://test.firebaseio.com', None, 'https://test.firebaseio.com/.json', None), + ('https://test.firebaseio.com/', None, 'https://test.firebaseio.com/.json', None), + + # Production URLs with emulator_host override: + ('https://test.firebaseio.com', 'localhost:9000', 'http://localhost:9000/.json', + 'test'), + ('https://test.firebaseio.com/', 'localhost:9000', 'http://localhost:9000/.json', + 'test'), + + # Emulator URL with no override. + ('http://localhost:8000/?ns=test', None, 'http://localhost:8000/.json', 'test'), + + # emulator_host is ignored when the original URL is already emulator. + ('http://localhost:8000/?ns=test', 'localhost:9999', 'http://localhost:8000/.json', + 'test'), + ] + ) + def test_listen_sse_client(self, url, emulator_host, expected_base_url, expected_namespace, + mocker): + if emulator_host: + os.environ[_EMULATOR_HOST_ENV_VAR] = emulator_host + + try: + firebase_admin.initialize_app(testutils.MockCredential(), {'databaseURL' : url}) + ref = db.reference() + mock_sse_client = mocker.patch('firebase_admin._sseclient.SSEClient') + mock_callback = mocker.Mock() + ref.listen(mock_callback) + args, kwargs = mock_sse_client.call_args + assert args[0] == expected_base_url + if expected_namespace: + assert kwargs.get('params') == {'ns': expected_namespace} + else: + assert kwargs.get('params') == {} + finally: + if _EMULATOR_HOST_ENV_VAR in os.environ: + del os.environ[_EMULATOR_HOST_ENV_VAR] + testutils.cleanup_apps() + def test_listener_session(self): firebase_admin.initialize_app(testutils.MockCredential(), { 'databaseURL' : 'https://test.firebaseio.com', From cc9a069237ea2ff6dabcb495b2b346f419da8bab Mon Sep 17 00:00:00 2001 From: Pijush Chakraborty Date: Wed, 5 Mar 2025 23:35:56 +0530 Subject: [PATCH 195/226] feat(rc): Sever Side Remote Config Integration (#863) --- firebase_admin/remote_config.py | 764 +++++++++++++++++++++++++ tests/test_remote_config.py | 984 ++++++++++++++++++++++++++++++++ tests/testutils.py | 40 ++ 3 files changed, 1788 insertions(+) create mode 100644 firebase_admin/remote_config.py create mode 100644 tests/test_remote_config.py diff --git a/firebase_admin/remote_config.py b/firebase_admin/remote_config.py new file mode 100644 index 000000000..943141ccf --- /dev/null +++ b/firebase_admin/remote_config.py @@ -0,0 +1,764 @@ +# Copyright 2024 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Firebase Remote Config Module. +This module has required APIs for the clients to use Firebase Remote Config with python. +""" + +import asyncio +import json +import logging +import threading +from typing import Dict, Optional, Literal, Union, Any +from enum import Enum +import re +import hashlib +import requests +from firebase_admin import App, _http_client, _utils +import firebase_admin + +# Set up logging (you can customize the level and output) +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +_REMOTE_CONFIG_ATTRIBUTE = '_remoteconfig' +MAX_CONDITION_RECURSION_DEPTH = 10 +ValueSource = Literal['default', 'remote', 'static'] # Define the ValueSource type + +class PercentConditionOperator(Enum): + """Enum representing the available operators for percent conditions. + """ + LESS_OR_EQUAL = "LESS_OR_EQUAL" + GREATER_THAN = "GREATER_THAN" + BETWEEN = "BETWEEN" + UNKNOWN = "UNKNOWN" + +class CustomSignalOperator(Enum): + """Enum representing the available operators for custom signal conditions. + """ + STRING_CONTAINS = "STRING_CONTAINS" + STRING_DOES_NOT_CONTAIN = "STRING_DOES_NOT_CONTAIN" + STRING_EXACTLY_MATCHES = "STRING_EXACTLY_MATCHES" + STRING_CONTAINS_REGEX = "STRING_CONTAINS_REGEX" + NUMERIC_LESS_THAN = "NUMERIC_LESS_THAN" + NUMERIC_LESS_EQUAL = "NUMERIC_LESS_EQUAL" + NUMERIC_EQUAL = "NUMERIC_EQUAL" + NUMERIC_NOT_EQUAL = "NUMERIC_NOT_EQUAL" + NUMERIC_GREATER_THAN = "NUMERIC_GREATER_THAN" + NUMERIC_GREATER_EQUAL = "NUMERIC_GREATER_EQUAL" + SEMANTIC_VERSION_LESS_THAN = "SEMANTIC_VERSION_LESS_THAN" + SEMANTIC_VERSION_LESS_EQUAL = "SEMANTIC_VERSION_LESS_EQUAL" + SEMANTIC_VERSION_EQUAL = "SEMANTIC_VERSION_EQUAL" + SEMANTIC_VERSION_NOT_EQUAL = "SEMANTIC_VERSION_NOT_EQUAL" + SEMANTIC_VERSION_GREATER_THAN = "SEMANTIC_VERSION_GREATER_THAN" + SEMANTIC_VERSION_GREATER_EQUAL = "SEMANTIC_VERSION_GREATER_EQUAL" + UNKNOWN = "UNKNOWN" + +class _ServerTemplateData: + """Parses, validates and encapsulates template data and metadata.""" + def __init__(self, template_data): + """Initializes a new ServerTemplateData instance. + + Args: + template_data: The data to be parsed for getting the parameters and conditions. + + Raises: + ValueError: If the template data is not valid. + """ + if 'parameters' in template_data: + if template_data['parameters'] is not None: + self._parameters = template_data['parameters'] + else: + raise ValueError('Remote Config parameters must be a non-null object') + else: + self._parameters = {} + + if 'conditions' in template_data: + if template_data['conditions'] is not None: + self._conditions = template_data['conditions'] + else: + raise ValueError('Remote Config conditions must be a non-null object') + else: + self._conditions = [] + + self._version = '' + if 'version' in template_data: + self._version = template_data['version'] + + self._etag = '' + if 'etag' in template_data and isinstance(template_data['etag'], str): + self._etag = template_data['etag'] + + self._template_data_json = json.dumps(template_data) + + @property + def parameters(self): + return self._parameters + + @property + def etag(self): + return self._etag + + @property + def version(self): + return self._version + + @property + def conditions(self): + return self._conditions + + @property + def template_data_json(self): + return self._template_data_json + + +class ServerTemplate: + """Represents a Server Template with implementations for loading and evaluating the template.""" + def __init__(self, app: App = None, default_config: Optional[Dict[str, str]] = None): + """Initializes a ServerTemplate instance. + + Args: + app: App instance to be used. This is optional and the default app instance will + be used if not present. + default_config: The default config to be used in the evaluated config. + """ + self._rc_service = _utils.get_app_service(app, + _REMOTE_CONFIG_ATTRIBUTE, _RemoteConfigService) + # This gets set when the template is + # fetched from RC servers via the load API, or via the set API. + self._cache = None + self._stringified_default_config: Dict[str, str] = {} + self._lock = threading.RLock() + + # RC stores all remote values as string, but it's more intuitive + # to declare default values with specific types, so this converts + # the external declaration to an internal string representation. + if default_config is not None: + for key in default_config: + self._stringified_default_config[key] = str(default_config[key]) + + async def load(self): + """Fetches the server template and caches the data.""" + rc_server_template = await self._rc_service.get_server_template() + with self._lock: + self._cache = rc_server_template + + def evaluate(self, context: Optional[Dict[str, Union[str, int]]] = None) -> 'ServerConfig': + """Evaluates the cached server template to produce a ServerConfig. + + Args: + context: A dictionary of values to use for evaluating conditions. + + Returns: + A ServerConfig object. + Raises: + ValueError: If the input arguments are invalid. + """ + # Logic to process the cached template into a ServerConfig here. + if not self._cache: + raise ValueError("""No Remote Config Server template in cache. + Call load() before calling evaluate().""") + context = context or {} + config_values = {} + + with self._lock: + template_conditions = self._cache.conditions + template_parameters = self._cache.parameters + + # Initializes config Value objects with default values. + if self._stringified_default_config is not None: + for key, value in self._stringified_default_config.items(): + config_values[key] = _Value('default', value) + self._evaluator = _ConditionEvaluator(template_conditions, + template_parameters, context, + config_values) + return ServerConfig(config_values=self._evaluator.evaluate()) + + def set(self, template_data_json: str): + """Updates the cache to store the given template is of type ServerTemplateData. + + Args: + template_data_json: A json string representing ServerTemplateData to be cached. + """ + template_data_map = json.loads(template_data_json) + template_data = _ServerTemplateData(template_data_map) + + with self._lock: + self._cache = template_data + + def to_json(self): + """Provides the server template in a JSON format to be used for initialization later.""" + if not self._cache: + raise ValueError("""No Remote Config Server template in cache. + Call load() before calling toJSON().""") + with self._lock: + template_json = self._cache.template_data_json + return template_json + + +class ServerConfig: + """Represents a Remote Config Server Side Config.""" + def __init__(self, config_values): + self._config_values = config_values # dictionary of param key to values + + def get_boolean(self, key): + """Returns the value as a boolean.""" + return self._get_value(key).as_boolean() + + def get_string(self, key): + """Returns the value as a string.""" + return self._get_value(key).as_string() + + def get_int(self, key): + """Returns the value as an integer.""" + return self._get_value(key).as_int() + + def get_float(self, key): + """Returns the value as a float.""" + return self._get_value(key).as_float() + + def get_value_source(self, key): + """Returns the source of the value.""" + return self._get_value(key).get_source() + + def _get_value(self, key): + return self._config_values.get(key, _Value('static')) + + +class _RemoteConfigService: + """Internal class that facilitates sending requests to the Firebase Remote + Config backend API. + """ + def __init__(self, app): + """Initialize a JsonHttpClient with necessary inputs. + + Args: + app: App instance to be used for fetching app specific details required + for initializing the http client. + """ + remote_config_base_url = 'https://firebaseremoteconfig.googleapis.com' + self._project_id = app.project_id + app_credential = app.credential.get_credential() + rc_headers = { + 'X-FIREBASE-CLIENT': 'fire-admin-python/{0}'.format(firebase_admin.__version__), } + timeout = app.options.get('httpTimeout', _http_client.DEFAULT_TIMEOUT_SECONDS) + + self._client = _http_client.JsonHttpClient(credential=app_credential, + base_url=remote_config_base_url, + headers=rc_headers, timeout=timeout) + + async def get_server_template(self): + """Requests for a server template and converts the response to an instance of + ServerTemplateData for storing the template parameters and conditions.""" + try: + loop = asyncio.get_event_loop() + headers, template_data = await loop.run_in_executor(None, + self._client.headers_and_body, + 'get', self._get_url()) + except requests.exceptions.RequestException as error: + raise self._handle_remote_config_error(error) + else: + template_data['etag'] = headers.get('etag') + return _ServerTemplateData(template_data) + + def _get_url(self): + """Returns project prefix for url, in the format of /v1/projects/${projectId}""" + return "/v1/projects/{0}/namespaces/firebase-server/serverRemoteConfig".format( + self._project_id) + + @classmethod + def _handle_remote_config_error(cls, error: Any): + """Handles errors received from the Cloud Functions API.""" + return _utils.handle_platform_error_from_requests(error) + + +class _ConditionEvaluator: + """Internal class that facilitates sending requests to the Firebase Remote + Config backend API.""" + def __init__(self, conditions, parameters, context, config_values): + self._context = context + self._conditions = conditions + self._parameters = parameters + self._config_values = config_values + + def evaluate(self): + """Internal function that evaluates the cached server template to produce + a ServerConfig""" + evaluated_conditions = self.evaluate_conditions(self._conditions, self._context) + + # Overlays config Value objects derived by evaluating the template. + if self._parameters: + for key, parameter in self._parameters.items(): + conditional_values = parameter.get('conditionalValues', {}) + default_value = parameter.get('defaultValue', {}) + parameter_value_wrapper = None + # Iterates in order over condition list. If there is a value associated + # with a condition, this checks if the condition is true. + if evaluated_conditions: + for condition_name, condition_evaluation in evaluated_conditions.items(): + if condition_name in conditional_values and condition_evaluation: + parameter_value_wrapper = conditional_values[condition_name] + break + + if parameter_value_wrapper and parameter_value_wrapper.get('useInAppDefault'): + logger.info("Using in-app default value for key '%s'", key) + continue + + if parameter_value_wrapper: + parameter_value = parameter_value_wrapper.get('value') + self._config_values[key] = _Value('remote', parameter_value) + continue + + if not default_value: + logger.warning("No default value found for key '%s'", key) + continue + + if default_value.get('useInAppDefault'): + logger.info("Using in-app default value for key '%s'", key) + continue + self._config_values[key] = _Value('remote', default_value.get('value')) + return self._config_values + + def evaluate_conditions(self, conditions, context)-> Dict[str, bool]: + """Evaluates a list of conditions and returns a dictionary of results. + + Args: + conditions: A list of NamedCondition objects. + context: An EvaluationContext object. + + Returns: + A dictionary that maps condition names to boolean evaluation results. + """ + evaluated_conditions = {} + for condition in conditions: + evaluated_conditions[condition.get('name')] = self.evaluate_condition( + condition.get('condition'), context + ) + return evaluated_conditions + + def evaluate_condition(self, condition, context, + nesting_level: int = 0) -> bool: + """Recursively evaluates a condition. + + Args: + condition: The condition to evaluate. + context: An EvaluationContext object. + nesting_level: The current recursion depth. + + Returns: + The boolean result of the condition evaluation. + """ + if nesting_level >= MAX_CONDITION_RECURSION_DEPTH: + logger.warning("Maximum condition recursion depth exceeded.") + return False + if condition.get('orCondition') is not None: + return self.evaluate_or_condition(condition.get('orCondition'), + context, nesting_level + 1) + if condition.get('andCondition') is not None: + return self.evaluate_and_condition(condition.get('andCondition'), + context, nesting_level + 1) + if condition.get('true') is not None: + return True + if condition.get('false') is not None: + return False + if condition.get('percent') is not None: + return self.evaluate_percent_condition(condition.get('percent'), context) + if condition.get('customSignal') is not None: + return self.evaluate_custom_signal_condition(condition.get('customSignal'), context) + logger.warning("Unknown condition type encountered.") + return False + + def evaluate_or_condition(self, or_condition, + context, + nesting_level: int = 0) -> bool: + """Evaluates an OR condition. + + Args: + or_condition: The OR condition to evaluate. + context: An EvaluationContext object. + nesting_level: The current recursion depth. + + Returns: + True if any of the subconditions are true, False otherwise. + """ + sub_conditions = or_condition.get('conditions') or [] + for sub_condition in sub_conditions: + result = self.evaluate_condition(sub_condition, context, nesting_level + 1) + if result: + return True + return False + + def evaluate_and_condition(self, and_condition, + context, + nesting_level: int = 0) -> bool: + """Evaluates an AND condition. + + Args: + and_condition: The AND condition to evaluate. + context: An EvaluationContext object. + nesting_level: The current recursion depth. + + Returns: + True if all of the subconditions are met; False otherwise. + """ + sub_conditions = and_condition.get('conditions') or [] + for sub_condition in sub_conditions: + result = self.evaluate_condition(sub_condition, context, nesting_level + 1) + if not result: + return False + return True + + def evaluate_percent_condition(self, percent_condition, + context) -> bool: + """Evaluates a percent condition. + + Args: + percent_condition: The percent condition to evaluate. + context: An EvaluationContext object. + + Returns: + True if the condition is met, False otherwise. + """ + if not context.get('randomization_id'): + logger.warning("Missing randomization_id in context for evaluating percent condition.") + return False + + seed = percent_condition.get('seed') + percent_operator = percent_condition.get('percentOperator') + micro_percent = percent_condition.get('microPercent') + micro_percent_range = percent_condition.get('microPercentRange') + if not percent_operator: + logger.warning("Missing percent operator for percent condition.") + return False + if micro_percent_range: + norm_percent_upper_bound = micro_percent_range.get('microPercentUpperBound') or 0 + norm_percent_lower_bound = micro_percent_range.get('microPercentLowerBound') or 0 + else: + norm_percent_upper_bound = 0 + norm_percent_lower_bound = 0 + if micro_percent: + norm_micro_percent = micro_percent + else: + norm_micro_percent = 0 + seed_prefix = f"{seed}." if seed else "" + string_to_hash = f"{seed_prefix}{context.get('randomization_id')}" + + hash64 = self.hash_seeded_randomization_id(string_to_hash) + instance_micro_percentile = hash64 % (100 * 1000000) + if percent_operator == PercentConditionOperator.LESS_OR_EQUAL.value: + return instance_micro_percentile <= norm_micro_percent + if percent_operator == PercentConditionOperator.GREATER_THAN.value: + return instance_micro_percentile > norm_micro_percent + if percent_operator == PercentConditionOperator.BETWEEN.value: + return norm_percent_lower_bound < instance_micro_percentile <= norm_percent_upper_bound + logger.warning("Unknown percent operator: %s", percent_operator) + return False + def hash_seeded_randomization_id(self, seeded_randomization_id: str) -> int: + """Hashes a seeded randomization ID. + + Args: + seeded_randomization_id: The seeded randomization ID to hash. + + Returns: + The hashed value. + """ + hash_object = hashlib.sha256() + hash_object.update(seeded_randomization_id.encode('utf-8')) + hash64 = hash_object.hexdigest() + return abs(int(hash64, 16)) + + def evaluate_custom_signal_condition(self, custom_signal_condition, + context) -> bool: + """Evaluates a custom signal condition. + + Args: + custom_signal_condition: The custom signal condition to evaluate. + context: An EvaluationContext object. + + Returns: + True if the condition is met, False otherwise. + """ + custom_signal_operator = custom_signal_condition.get('customSignalOperator') or {} + custom_signal_key = custom_signal_condition.get('customSignalKey') or {} + target_custom_signal_values = ( + custom_signal_condition.get('targetCustomSignalValues') or {}) + + if not all([custom_signal_operator, custom_signal_key, target_custom_signal_values]): + logger.warning("Missing operator, key, or target values for custom signal condition.") + return False + + if not target_custom_signal_values: + return False + actual_custom_signal_value = context.get(custom_signal_key) or {} + + if not actual_custom_signal_value: + logger.debug("Custom signal value not found in context: %s", custom_signal_key) + return False + + if custom_signal_operator == CustomSignalOperator.STRING_CONTAINS.value: + return self._compare_strings(target_custom_signal_values, + actual_custom_signal_value, + lambda target, actual: target in actual) + if custom_signal_operator == CustomSignalOperator.STRING_DOES_NOT_CONTAIN.value: + return not self._compare_strings(target_custom_signal_values, + actual_custom_signal_value, + lambda target, actual: target in actual) + if custom_signal_operator == CustomSignalOperator.STRING_EXACTLY_MATCHES.value: + return self._compare_strings(target_custom_signal_values, + actual_custom_signal_value, + lambda target, actual: target.strip() == actual.strip()) + if custom_signal_operator == CustomSignalOperator.STRING_CONTAINS_REGEX.value: + return self._compare_strings(target_custom_signal_values, + actual_custom_signal_value, + re.search) + + # For numeric operators only one target value is allowed. + if custom_signal_operator == CustomSignalOperator.NUMERIC_LESS_THAN.value: + return self._compare_numbers(custom_signal_key, + target_custom_signal_values[0], + actual_custom_signal_value, + lambda r: r < 0) + if custom_signal_operator == CustomSignalOperator.NUMERIC_LESS_EQUAL.value: + return self._compare_numbers(custom_signal_key, + target_custom_signal_values[0], + actual_custom_signal_value, + lambda r: r <= 0) + if custom_signal_operator == CustomSignalOperator.NUMERIC_EQUAL.value: + return self._compare_numbers(custom_signal_key, + target_custom_signal_values[0], + actual_custom_signal_value, + lambda r: r == 0) + if custom_signal_operator == CustomSignalOperator.NUMERIC_NOT_EQUAL.value: + return self._compare_numbers(custom_signal_key, + target_custom_signal_values[0], + actual_custom_signal_value, + lambda r: r != 0) + if custom_signal_operator == CustomSignalOperator.NUMERIC_GREATER_THAN.value: + return self._compare_numbers(custom_signal_key, + target_custom_signal_values[0], + actual_custom_signal_value, + lambda r: r > 0) + if custom_signal_operator == CustomSignalOperator.NUMERIC_GREATER_EQUAL.value: + return self._compare_numbers(custom_signal_key, + target_custom_signal_values[0], + actual_custom_signal_value, + lambda r: r >= 0) + + # For semantic operators only one target value is allowed. + if custom_signal_operator == CustomSignalOperator.SEMANTIC_VERSION_LESS_THAN.value: + return self._compare_semantic_versions(custom_signal_key, + target_custom_signal_values[0], + actual_custom_signal_value, + lambda r: r < 0) + if custom_signal_operator == CustomSignalOperator.SEMANTIC_VERSION_LESS_EQUAL.value: + return self._compare_semantic_versions(custom_signal_key, + target_custom_signal_values[0], + actual_custom_signal_value, + lambda r: r <= 0) + if custom_signal_operator == CustomSignalOperator.SEMANTIC_VERSION_EQUAL.value: + return self._compare_semantic_versions(custom_signal_key, + target_custom_signal_values[0], + actual_custom_signal_value, + lambda r: r == 0) + if custom_signal_operator == CustomSignalOperator.SEMANTIC_VERSION_NOT_EQUAL.value: + return self._compare_semantic_versions(custom_signal_key, + target_custom_signal_values[0], + actual_custom_signal_value, + lambda r: r != 0) + if custom_signal_operator == CustomSignalOperator.SEMANTIC_VERSION_GREATER_THAN.value: + return self._compare_semantic_versions(custom_signal_key, + target_custom_signal_values[0], + actual_custom_signal_value, + lambda r: r > 0) + if custom_signal_operator == CustomSignalOperator.SEMANTIC_VERSION_GREATER_EQUAL.value: + return self._compare_semantic_versions(custom_signal_key, + target_custom_signal_values[0], + actual_custom_signal_value, + lambda r: r >= 0) + logger.warning("Unknown custom signal operator: %s", custom_signal_operator) + return False + + def _compare_strings(self, target_values, actual_value, predicate_fn) -> bool: + """Compares the actual string value of a signal against a list of target values. + + Args: + target_values: A list of target string values. + actual_value: The actual value to compare, which can be a string or number. + predicate_fn: A function that takes two string arguments (target and actual) + and returns a boolean indicating whether + the target matches the actual value. + + Returns: + bool: True if the predicate function returns True for any target value in the list, + False otherwise. + """ + + for target in target_values: + if predicate_fn(target, str(actual_value)): + return True + return False + + def _compare_numbers(self, custom_signal_key, target_value, actual_value, predicate_fn) -> bool: + try: + target = float(target_value) + actual = float(actual_value) + result = -1 if actual < target else 1 if actual > target else 0 + return predicate_fn(result) + except ValueError: + logger.warning("Invalid numeric value for comparison for custom signal key %s.", + custom_signal_key) + return False + + def _compare_semantic_versions(self, custom_signal_key, + target_value, actual_value, predicate_fn) -> bool: + """Compares the actual semantic version value of a signal against a target value. + Calls the predicate function with -1, 0, 1 if actual is less than, equal to, + or greater than target. + + Args: + custom_signal_key: The custom signal for which the evaluation is being performed. + target_values: A list of target string values. + actual_value: The actual value to compare, which can be a string or number. + predicate_fn: A function that takes an integer (-1, 0, or 1) and returns a boolean. + + Returns: + bool: True if the predicate function returns True for the result of the comparison, + False otherwise. + """ + return self._compare_versions(custom_signal_key, str(actual_value), + str(target_value), predicate_fn) + + def _compare_versions(self, custom_signal_key, + sem_version_1, sem_version_2, predicate_fn) -> bool: + """Compares two semantic version strings. + + Args: + custom_signal_key: The custom singal for which the evaluation is being performed. + sem_version_1: The first semantic version string. + sem_version_2: The second semantic version string. + predicate_fn: A function that takes an integer and returns a boolean. + + Returns: + bool: The result of the predicate function. + """ + try: + v1_parts = [int(part) for part in sem_version_1.split('.')] + v2_parts = [int(part) for part in sem_version_2.split('.')] + max_length = max(len(v1_parts), len(v2_parts)) + v1_parts.extend([0] * (max_length - len(v1_parts))) + v2_parts.extend([0] * (max_length - len(v2_parts))) + + for part1, part2 in zip(v1_parts, v2_parts): + if any((part1 < 0, part2 < 0)): + raise ValueError + if part1 < part2: + return predicate_fn(-1) + if part1 > part2: + return predicate_fn(1) + return predicate_fn(0) + except ValueError: + logger.warning( + "Invalid semantic version format for comparison for custom signal key %s.", + custom_signal_key) + return False + +async def get_server_template(app: App = None, default_config: Optional[Dict[str, str]] = None): + """Initializes a new ServerTemplate instance and fetches the server template. + + Args: + app: App instance to be used. This is optional and the default app instance will + be used if not present. + default_config: The default config to be used in the evaluated config. + + Returns: + ServerTemplate: An object having the cached server template to be used for evaluation. + """ + template = init_server_template(app=app, default_config=default_config) + await template.load() + return template + +def init_server_template(app: App = None, default_config: Optional[Dict[str, str]] = None, + template_data_json: Optional[str] = None): + """Initializes a new ServerTemplate instance. + + Args: + app: App instance to be used. This is optional and the default app instance will + be used if not present. + default_config: The default config to be used in the evaluated config. + template_data_json: An optional template data JSON to be set on initialization. + + Returns: + ServerTemplate: A new ServerTemplate instance initialized with an optional + template and config. + """ + template = ServerTemplate(app=app, default_config=default_config) + if template_data_json is not None: + template.set(template_data_json) + return template + +class _Value: + """Represents a value fetched from Remote Config. + """ + DEFAULT_VALUE_FOR_BOOLEAN = False + DEFAULT_VALUE_FOR_STRING = '' + DEFAULT_VALUE_FOR_INTEGER = 0 + DEFAULT_VALUE_FOR_FLOAT_NUMBER = 0.0 + BOOLEAN_TRUTHY_VALUES = ['1', 'true', 't', 'yes', 'y', 'on'] + + def __init__(self, source: ValueSource, value: str = DEFAULT_VALUE_FOR_STRING): + """Initializes a Value instance. + + Args: + source: The source of the value (e.g., 'default', 'remote', 'static'). + "static" indicates the value was defined by a static constant. + "default" indicates the value was defined by default config. + "remote" indicates the value was defined by config produced by evaluating a template. + value: The string value. + """ + self.source = source + self.value = value + + def as_string(self) -> str: + """Returns the value as a string.""" + if self.source == 'static': + return self.DEFAULT_VALUE_FOR_STRING + return str(self.value) + + def as_boolean(self) -> bool: + """Returns the value as a boolean.""" + if self.source == 'static': + return self.DEFAULT_VALUE_FOR_BOOLEAN + return str(self.value).lower() in self.BOOLEAN_TRUTHY_VALUES + + def as_int(self) -> float: + """Returns the value as a number.""" + if self.source == 'static': + return self.DEFAULT_VALUE_FOR_INTEGER + try: + return int(self.value) + except ValueError: + return self.DEFAULT_VALUE_FOR_INTEGER + + def as_float(self) -> float: + """Returns the value as a number.""" + if self.source == 'static': + return self.DEFAULT_VALUE_FOR_FLOAT_NUMBER + try: + return float(self.value) + except ValueError: + return self.DEFAULT_VALUE_FOR_FLOAT_NUMBER + + def get_source(self) -> ValueSource: + """Returns the source of the value.""" + return self.source diff --git a/tests/test_remote_config.py b/tests/test_remote_config.py new file mode 100644 index 000000000..8c6248e18 --- /dev/null +++ b/tests/test_remote_config.py @@ -0,0 +1,984 @@ +# Copyright 2024 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for firebase_admin.remote_config.""" +import json +import uuid +import pytest +import firebase_admin +from firebase_admin.remote_config import ( + CustomSignalOperator, + PercentConditionOperator, + _REMOTE_CONFIG_ATTRIBUTE, + _RemoteConfigService) +from firebase_admin import remote_config, _utils +from tests import testutils + +VERSION_INFO = { + 'versionNumber': '86', + 'updateOrigin': 'ADMIN_SDK_PYTHON', + 'updateType': 'INCREMENTAL_UPDATE', + 'updateUser': { + 'email': 'firebase-adminsdk@gserviceaccount.com' + }, + 'description': 'production version', + 'updateTime': '2024-11-05T16:45:03.541527Z' + } + +SERVER_REMOTE_CONFIG_RESPONSE = { + 'conditions': [ + { + 'name': 'ios', + 'condition': { + 'orCondition': { + 'conditions': [ + { + 'andCondition': { + 'conditions': [ + {'true': {}} + ] + } + } + ] + } + } + }, + ], + 'parameters': { + 'holiday_promo_enabled': { + 'defaultValue': {'value': 'true'}, + 'conditionalValues': {'ios': {'useInAppDefault': 'true'}} + }, + }, + 'parameterGroups': '', + 'etag': 'etag-123456789012-5', + 'version': VERSION_INFO, + } + +SEMENTIC_VERSION_LESS_THAN_TRUE = [ + CustomSignalOperator.SEMANTIC_VERSION_LESS_THAN.value, ['12.1.3.444'], '12.1.3.443', True] +SEMENTIC_VERSION_EQUAL_TRUE = [ + CustomSignalOperator.SEMANTIC_VERSION_EQUAL.value, ['12.1.3.444'], '12.1.3.444', True] +SEMANTIC_VERSION_GREATER_THAN_FALSE = [ + CustomSignalOperator.SEMANTIC_VERSION_LESS_THAN.value, ['12.1.3.4'], '12.1.3.4', False] +SEMANTIC_VERSION_INVALID_FORMAT_STRING = [ + CustomSignalOperator.SEMANTIC_VERSION_LESS_THAN.value, ['12.1.3.444'], '12.1.3.abc', False] +SEMANTIC_VERSION_INVALID_FORMAT_NEGATIVE_INTEGER = [ + CustomSignalOperator.SEMANTIC_VERSION_LESS_THAN.value, ['12.1.3.444'], '12.1.3.-2', False] + +class TestEvaluate: + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': 'project-id'}) + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + def test_evaluate_or_and_true_condition_true(self): + app = firebase_admin.get_app() + default_config = {'param1': 'in_app_default_param1', 'param3': 'in_app_default_param3'} + condition = { + 'name': 'is_true', + 'condition': { + 'orCondition': { + 'conditions': [ + { + 'andCondition': { + 'conditions': [ + { + 'name': '', + 'true': { + } + } + ] + } + } + ] + } + } + } + template_data = { + 'conditions': [condition], + 'parameters': { + 'is_enabled': { + 'defaultValue': {'value': 'false'}, + 'conditionalValues': {'is_true': {'value': 'true'}} + }, + }, + 'parameterGroups': '', + 'version': '', + 'etag': 'etag' + } + server_template = remote_config.init_server_template( + app=app, + default_config=default_config, + template_data_json=json.dumps(template_data) + ) + + server_config = server_template.evaluate() + assert server_config.get_boolean('is_enabled') + assert server_config.get_value_source('is_enabled') == 'remote' + + def test_evaluate_or_and_false_condition_false(self): + app = firebase_admin.get_app() + default_config = {'param1': 'in_app_default_param1', 'param3': 'in_app_default_param3'} + condition = { + 'name': 'is_true', + 'condition': { + 'orCondition': { + 'conditions': [ + { + 'andCondition': { + 'conditions': [ + { + 'name': '', + 'false': { + } + } + ] + } + } + ] + } + } + } + template_data = { + 'conditions': [condition], + 'parameters': { + 'is_enabled': { + 'defaultValue': {'value': 'false'}, + 'conditionalValues': {'is_true': {'value': 'true'}} + }, + }, + 'parameterGroups': '', + 'version': '', + 'etag': 'etag' + } + server_template = remote_config.init_server_template( + app=app, + default_config=default_config, + template_data_json=json.dumps(template_data) + ) + + server_config = server_template.evaluate() + assert not server_config.get_boolean('is_enabled') + + def test_evaluate_non_or_condition(self): + app = firebase_admin.get_app() + default_config = {'param1': 'in_app_default_param1', 'param3': 'in_app_default_param3'} + condition = { + 'name': 'is_true', + 'condition': { + 'true': { + } + } + } + template_data = { + 'conditions': [condition], + 'parameters': { + 'is_enabled': { + 'defaultValue': {'value': 'false'}, + 'conditionalValues': {'is_true': {'value': 'true'}} + }, + }, + 'parameterGroups': '', + 'version': '', + 'etag': 'etag' + } + server_template = remote_config.init_server_template( + app=app, + default_config=default_config, + template_data_json=json.dumps(template_data) + ) + + server_config = server_template.evaluate() + assert server_config.get_boolean('is_enabled') + + def test_evaluate_return_conditional_values_honor_order(self): + app = firebase_admin.get_app() + default_config = {'param1': 'in_app_default_param1', 'param3': 'in_app_default_param3'} + template_data = { + 'conditions': [ + { + 'name': 'is_true', + 'condition': { + 'orCondition': { + 'conditions': [ + { + 'andCondition': { + 'conditions': [ + { + 'true': { + } + } + ] + } + } + ] + } + } + }, + { + 'name': 'is_true_too', + 'condition': { + 'orCondition': { + 'conditions': [ + { + 'andCondition': { + 'conditions': [ + { + 'true': { + } + } + ] + } + } + ] + } + } + } + ], + 'parameters': { + 'dog_type': { + 'defaultValue': {'value': 'chihuahua'}, + 'conditionalValues': { + 'is_true_too': {'value': 'dachshund'}, + 'is_true': {'value': 'corgi'} + } + }, + }, + 'parameterGroups':'', + 'version':'', + 'etag': 'etag' + } + server_template = remote_config.init_server_template( + app=app, + default_config=default_config, + template_data_json=json.dumps(template_data) + ) + server_config = server_template.evaluate() + assert server_config.get_string('dog_type') == 'corgi' + + def test_evaluate_default_when_no_param(self): + app = firebase_admin.get_app() + default_config = {'promo_enabled': False, 'promo_discount': '20',} + template_data = SERVER_REMOTE_CONFIG_RESPONSE + template_data['parameters'] = {} + server_template = remote_config.init_server_template( + app=app, + default_config=default_config, + template_data_json=json.dumps(template_data) + ) + server_config = server_template.evaluate() + assert server_config.get_boolean('promo_enabled') == default_config.get('promo_enabled') + assert server_config.get_int('promo_discount') == int(default_config.get('promo_discount')) + + def test_evaluate_default_when_no_default_value(self): + app = firebase_admin.get_app() + default_config = {'default_value': 'local default'} + template_data = SERVER_REMOTE_CONFIG_RESPONSE + template_data['parameters'] = { + 'default_value': {} + } + server_template = remote_config.init_server_template( + app=app, + default_config=default_config, + template_data_json=json.dumps(template_data) + ) + server_config = server_template.evaluate() + assert server_config.get_string('default_value') == default_config.get('default_value') + + def test_evaluate_default_when_in_default(self): + app = firebase_admin.get_app() + template_data = SERVER_REMOTE_CONFIG_RESPONSE + template_data['parameters'] = { + 'remote_default_value': {} + } + default_config = { + 'inapp_default': '🐕' + } + server_template = remote_config.init_server_template( + app=app, + default_config=default_config, + template_data_json=json.dumps(template_data) + ) + server_config = server_template.evaluate() + assert server_config.get_string('inapp_default') == default_config.get('inapp_default') + + def test_evaluate_default_when_defined(self): + app = firebase_admin.get_app() + template_data = SERVER_REMOTE_CONFIG_RESPONSE + template_data['parameters'] = {} + default_config = { + 'dog_type': 'shiba' + } + server_template = remote_config.init_server_template( + app=app, + default_config=default_config, + template_data_json=json.dumps(template_data) + ) + server_config = server_template.evaluate() + assert server_config.get_string('dog_type') == 'shiba' + + def test_evaluate_return_numeric_value(self): + app = firebase_admin.get_app() + template_data = SERVER_REMOTE_CONFIG_RESPONSE + default_config = { + 'dog_age': '12' + } + server_template = remote_config.init_server_template( + app=app, + default_config=default_config, + template_data_json=json.dumps(template_data) + ) + server_config = server_template.evaluate() + assert server_config.get_int('dog_age') == int(default_config.get('dog_age')) + + def test_evaluate_return_boolean_value(self): + app = firebase_admin.get_app() + template_data = SERVER_REMOTE_CONFIG_RESPONSE + default_config = { + 'dog_is_cute': True + } + server_template = remote_config.init_server_template( + app=app, + default_config=default_config, + template_data_json=json.dumps(template_data) + ) + server_config = server_template.evaluate() + assert server_config.get_boolean('dog_is_cute') + + def test_evaluate_unknown_operator_to_false(self): + app = firebase_admin.get_app() + condition = { + 'name': 'is_true', + 'condition': { + 'orCondition': { + 'conditions': [{ + 'andCondition': { + 'conditions': [{ + 'percent': { + 'percentOperator': PercentConditionOperator.UNKNOWN.value + } + }], + } + }] + } + } + } + default_config = { + 'dog_is_cute': True + } + template_data = { + 'conditions': [condition], + 'parameters': { + 'is_enabled': { + 'defaultValue': {'value': 'false'}, + 'conditionalValues': {'is_true': {'value': 'true'}} + }, + }, + 'parameterGroups':'', + 'version':'', + 'etag': '123' + } + context = {'randomization_id': '123'} + server_template = remote_config.init_server_template( + app=app, + default_config=default_config, + template_data_json=json.dumps(template_data) + ) + server_config = server_template.evaluate(context) + assert not server_config.get_boolean('is_enabled') + + def test_evaluate_less_or_equal_to_max_to_true(self): + app = firebase_admin.get_app() + condition = { + 'name': 'is_true', + 'condition': { + 'orCondition': { + 'conditions': [{ + 'andCondition': { + 'conditions': [{ + 'percent': { + 'percentOperator': PercentConditionOperator.LESS_OR_EQUAL.value, + 'seed': 'abcdef', + 'microPercent': 100_000_000 + } + }], + } + }] + } + } + } + default_config = { + 'dog_is_cute': True + } + template_data = { + 'conditions': [condition], + 'parameters': { + 'is_enabled': { + 'defaultValue': {'value': 'false'}, + 'conditionalValues': {'is_true': {'value': 'true'}} + }, + }, + 'parameterGroups':'', + 'version':'', + 'etag': '123' + } + context = {'randomization_id': '123'} + server_template = remote_config.init_server_template( + app=app, + default_config=default_config, + template_data_json=json.dumps(template_data) + ) + server_config = server_template.evaluate(context) + assert server_config.get_boolean('is_enabled') + + def test_evaluate_undefined_micropercent_to_false(self): + app = firebase_admin.get_app() + condition = { + 'name': 'is_true', + 'condition': { + 'orCondition': { + 'conditions': [{ + 'andCondition': { + 'conditions': [{ + 'percent': { + 'percentOperator': PercentConditionOperator.LESS_OR_EQUAL.value, + # Leaves microPercent undefined + } + }], + } + }] + } + } + } + default_config = { + 'dog_is_cute': True + } + template_data = { + 'conditions': [condition], + 'parameters': { + 'is_enabled': { + 'defaultValue': {'value': 'false'}, + 'conditionalValues': {'is_true': {'value': 'true'}} + }, + }, + 'parameterGroups':'', + 'version':'', + 'etag': '123' + } + context = {'randomization_id': '123'} + server_template = remote_config.init_server_template( + app=app, + default_config=default_config, + template_data_json=json.dumps(template_data) + ) + server_config = server_template.evaluate(context) + assert not server_config.get_boolean('is_enabled') + + def test_evaluate_undefined_micropercentrange_to_false(self): + app = firebase_admin.get_app() + condition = { + 'name': 'is_true', + 'condition': { + 'orCondition': { + 'conditions': [{ + 'andCondition': { + 'conditions': [{ + 'percent': { + 'percentOperator': PercentConditionOperator.BETWEEN.value, + # Leaves microPercent undefined + } + }], + } + }] + } + } + } + default_config = { + 'dog_is_cute': True + } + template_data = { + 'conditions': [condition], + 'parameters': { + 'is_enabled': { + 'defaultValue': {'value': 'false'}, + 'conditionalValues': {'is_true': {'value': 'true'}} + }, + }, + 'parameterGroups':'', + 'version':'', + 'etag': '123' + } + context = {'randomization_id': '123'} + server_template = remote_config.init_server_template( + app=app, + default_config=default_config, + template_data_json=json.dumps(template_data) + ) + server_config = server_template.evaluate(context) + assert not server_config.get_boolean('is_enabled') + + def test_evaluate_between_min_max_to_true(self): + app = firebase_admin.get_app() + condition = { + 'name': 'is_true', + 'condition': { + 'orCondition': { + 'conditions': [{ + 'andCondition': { + 'conditions': [{ + 'percent': { + 'percentOperator': PercentConditionOperator.BETWEEN.value, + 'seed': 'abcdef', + 'microPercentRange': { + 'microPercentLowerBound': 0, + 'microPercentUpperBound': 100_000_000 + } + } + }], + } + }] + } + } + } + default_config = { + 'dog_is_cute': True + } + template_data = { + 'conditions': [condition], + 'parameters': { + 'is_enabled': { + 'defaultValue': {'value': 'false'}, + 'conditionalValues': {'is_true': {'value': 'true'}} + }, + }, + 'parameterGroups':'', + 'version':'', + 'etag': '123' + } + context = {'randomization_id': '123'} + server_template = remote_config.init_server_template( + app=app, + default_config=default_config, + template_data_json=json.dumps(template_data) + ) + server_config = server_template.evaluate(context) + assert server_config.get_boolean('is_enabled') + + def test_evaluate_between_equal_bounds_to_false(self): + app = firebase_admin.get_app() + condition = { + 'name': 'is_true', + 'condition': { + 'orCondition': { + 'conditions': [{ + 'andCondition': { + 'conditions': [{ + 'percent': { + 'percentOperator': PercentConditionOperator.BETWEEN.value, + 'seed': 'abcdef', + 'microPercentRange': { + 'microPercentLowerBound': 50000000, + 'microPercentUpperBound': 50000000 + } + } + }], + } + }] + } + } + } + default_config = { + 'dog_is_cute': True + } + template_data = { + 'conditions': [condition], + 'parameters': { + 'is_enabled': { + 'defaultValue': {'value': 'false'}, + 'conditionalValues': {'is_true': {'value': 'true'}} + }, + }, + 'parameterGroups':'', + 'version':'', + 'etag': '123' + } + context = {'randomization_id': '123'} + server_template = remote_config.init_server_template( + app=app, + default_config=default_config, + template_data_json=json.dumps(template_data) + ) + server_config = server_template.evaluate(context) + assert not server_config.get_boolean('is_enabled') + + def test_evaluate_less_or_equal_to_approx(self): + app = firebase_admin.get_app() + condition = { + 'name': 'is_true', + 'condition': { + 'orCondition': { + 'conditions': [{ + 'andCondition': { + 'conditions': [{ + 'percent': { + 'percentOperator': PercentConditionOperator.LESS_OR_EQUAL.value, + 'seed': 'abcdef', + 'microPercent': 10_000_000 # 10% + } + }], + } + }] + } + } + } + default_config = { + 'dog_is_cute': True + } + + truthy_assignments = self.evaluate_random_assignments(condition, 100000, + app, default_config) + tolerance = 284 + assert truthy_assignments >= 10000 - tolerance + assert truthy_assignments <= 10000 + tolerance + + def test_evaluate_between_approx(self): + app = firebase_admin.get_app() + condition = { + 'name': 'is_true', + 'condition': { + 'orCondition': { + 'conditions': [{ + 'andCondition': { + 'conditions': [{ + 'percent': { + 'percentOperator': PercentConditionOperator.BETWEEN.value, + 'seed': 'abcdef', + 'microPercentRange': { + 'microPercentLowerBound': 40_000_000, + 'microPercentUpperBound': 60_000_000 + } + } + }], + } + }] + } + } + } + default_config = { + 'dog_is_cute': True + } + + truthy_assignments = self.evaluate_random_assignments(condition, 100000, + app, default_config) + tolerance = 379 + assert truthy_assignments >= 20000 - tolerance + assert truthy_assignments <= 20000 + tolerance + + def test_evaluate_between_interquartile_range_accuracy(self): + app = firebase_admin.get_app() + condition = { + 'name': 'is_true', + 'condition': { + 'orCondition': { + 'conditions': [{ + 'andCondition': { + 'conditions': [{ + 'percent': { + 'percentOperator': PercentConditionOperator.BETWEEN.value, + 'seed': 'abcdef', + 'microPercentRange': { + 'microPercentLowerBound': 25_000_000, + 'microPercentUpperBound': 75_000_000 + } + } + }], + } + }] + } + } + } + default_config = { + 'dog_is_cute': True + } + + truthy_assignments = self.evaluate_random_assignments(condition, 100000, + app, default_config) + tolerance = 490 + assert truthy_assignments >= 50000 - tolerance + assert truthy_assignments <= 50000 + tolerance + + def evaluate_random_assignments(self, condition, num_of_assignments, mock_app, default_config): + """Evaluates random assignments based on a condition. + + Args: + condition: The condition to evaluate. + num_of_assignments: The number of assignments to generate. + condition_evaluator: An instance of the ConditionEvaluator class. + + Returns: + int: The number of assignments that evaluated to true. + """ + eval_true_count = 0 + template_data = { + 'conditions': [condition], + 'parameters': { + 'is_enabled': { + 'defaultValue': {'value': 'false'}, + 'conditionalValues': {'is_true': {'value': 'true'}} + }, + }, + 'parameterGroups':'', + 'version':'', + 'etag': '123' + } + server_template = remote_config.init_server_template( + app=mock_app, + default_config=default_config, + template_data_json=json.dumps(template_data) + ) + + for _ in range(num_of_assignments): + context = {'randomization_id': str(uuid.uuid4())} + result = server_template.evaluate(context) + if result.get_boolean('is_enabled') is True: + eval_true_count += 1 + + return eval_true_count + + @pytest.mark.parametrize( + 'custom_signal_opearator, \ + target_custom_signal_value, actual_custom_signal_value, parameter_value', + [ + SEMENTIC_VERSION_LESS_THAN_TRUE, + SEMANTIC_VERSION_GREATER_THAN_FALSE, + SEMENTIC_VERSION_EQUAL_TRUE, + SEMANTIC_VERSION_INVALID_FORMAT_NEGATIVE_INTEGER, + SEMANTIC_VERSION_INVALID_FORMAT_STRING + ]) + def test_evaluate_custom_signal_semantic_version(self, + custom_signal_opearator, + target_custom_signal_value, + actual_custom_signal_value, + parameter_value): + app = firebase_admin.get_app() + condition = { + 'name': 'is_true', + 'condition': { + 'orCondition': { + 'conditions': [{ + 'andCondition': { + 'conditions': [{ + 'customSignal': { + 'customSignalOperator': custom_signal_opearator, + 'customSignalKey': 'sementic_version_key', + 'targetCustomSignalValues': target_custom_signal_value + } + }], + } + }] + } + } + } + default_config = { + 'dog_is_cute': True + } + template_data = { + 'conditions': [condition], + 'parameters': { + 'is_enabled': { + 'defaultValue': {'value': 'false'}, + 'conditionalValues': {'is_true': {'value': 'true'}} + }, + }, + 'parameterGroups':'', + 'version':'', + 'etag': '123' + } + context = {'randomization_id': '123', 'sementic_version_key': actual_custom_signal_value} + server_template = remote_config.init_server_template( + app=app, + default_config=default_config, + template_data_json=json.dumps(template_data) + ) + server_config = server_template.evaluate(context) + assert server_config.get_boolean('is_enabled') == parameter_value + + +class MockAdapter(testutils.MockAdapter): + """A Mock HTTP Adapter that provides Firebase Remote Config responses with ETag in header.""" + + ETAG = 'etag' + + def __init__(self, data, status, recorder, etag=ETAG): + testutils.MockAdapter.__init__(self, data, status, recorder) + self._etag = etag + + def send(self, request, **kwargs): + resp = super(MockAdapter, self).send(request, **kwargs) + resp.headers = {'etag': self._etag} + return resp + + +class TestRemoteConfigService: + """Tests methods on _RemoteConfigService""" + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': 'project-id'}) + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + @pytest.mark.asyncio + async def test_rc_instance_get_server_template(self): + recorder = [] + response = json.dumps({ + 'parameters': { + 'test_key': 'test_value' + }, + 'conditions': [], + 'version': 'test' + }) + + rc_instance = _utils.get_app_service(firebase_admin.get_app(), + _REMOTE_CONFIG_ATTRIBUTE, _RemoteConfigService) + rc_instance._client.session.mount( + 'https://firebaseremoteconfig.googleapis.com', + MockAdapter(response, 200, recorder)) + + template = await rc_instance.get_server_template() + + assert template.parameters == dict(test_key="test_value") + assert str(template.version) == 'test' + assert str(template.etag) == 'etag' + + @pytest.mark.asyncio + async def test_rc_instance_get_server_template_empty_params(self): + recorder = [] + response = json.dumps({ + 'conditions': [], + 'version': 'test' + }) + + rc_instance = _utils.get_app_service(firebase_admin.get_app(), + _REMOTE_CONFIG_ATTRIBUTE, _RemoteConfigService) + rc_instance._client.session.mount( + 'https://firebaseremoteconfig.googleapis.com', + MockAdapter(response, 200, recorder)) + + template = await rc_instance.get_server_template() + + assert template.parameters == {} + assert str(template.version) == 'test' + assert str(template.etag) == 'etag' + + +class TestRemoteConfigModule: + """Tests methods on firebase_admin.remote_config""" + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': 'project-id'}) + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + def test_init_server_template(self): + app = firebase_admin.get_app() + template_data = { + 'conditions': [], + 'parameters': { + 'test_key': { + 'defaultValue': {'value': 'test_value'}, + 'conditionalValues': {} + } + }, + 'version': '', + } + + template = remote_config.init_server_template( + app=app, + default_config={'default_test': 'default_value'}, + template_data_json=json.dumps(template_data) + ) + + config = template.evaluate() + assert config.get_string('test_key') == 'test_value' + + @pytest.mark.asyncio + async def test_get_server_template(self): + app = firebase_admin.get_app() + rc_instance = _utils.get_app_service(app, + _REMOTE_CONFIG_ATTRIBUTE, _RemoteConfigService) + + recorder = [] + response = json.dumps({ + 'parameters': { + 'test_key': { + 'defaultValue': {'value': 'test_value'}, + 'conditionalValues': {} + } + }, + 'conditions': [], + 'version': 'test' + }) + + rc_instance._client.session.mount( + 'https://firebaseremoteconfig.googleapis.com', + MockAdapter(response, 200, recorder)) + + template = await remote_config.get_server_template(app=app) + + config = template.evaluate() + assert config.get_string('test_key') == 'test_value' + + @pytest.mark.asyncio + async def test_server_template_to_json(self): + app = firebase_admin.get_app() + rc_instance = _utils.get_app_service(app, + _REMOTE_CONFIG_ATTRIBUTE, _RemoteConfigService) + + recorder = [] + response = json.dumps({ + 'parameters': { + 'test_key': { + 'defaultValue': {'value': 'test_value'}, + 'conditionalValues': {} + } + }, + 'conditions': [], + 'version': 'test' + }) + + expected_template_json = '{"parameters": {' \ + '"test_key": {' \ + '"defaultValue": {' \ + '"value": "test_value"}, ' \ + '"conditionalValues": {}}}, "conditions": [], ' \ + '"version": "test", "etag": "etag"}' + + rc_instance._client.session.mount( + 'https://firebaseremoteconfig.googleapis.com', + MockAdapter(response, 200, recorder)) + template = await remote_config.get_server_template(app=app) + + template_json = template.to_json() + assert template_json == expected_template_json diff --git a/tests/testutils.py b/tests/testutils.py index ab4fb40cb..17013b469 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -218,3 +218,43 @@ def send(self, request, **kwargs): # pylint: disable=arguments-differ resp.raw = io.BytesIO(response.encode()) break return resp + +def build_mock_condition(name, condition): + return { + 'name': name, + 'condition': condition, + } + +def build_mock_parameter(name, description, value=None, + conditional_values=None, default_value=None, parameter_groups=None): + return { + 'name': name, + 'description': description, + 'value': value, + 'conditionalValues': conditional_values, + 'defaultValue': default_value, + 'parameterGroups': parameter_groups, + } + +def build_mock_conditional_value(condition_name, value): + return { + 'conditionName': condition_name, + 'value': value, + } + +def build_mock_default_value(value): + return { + 'value': value, + } + +def build_mock_parameter_group(name, description, parameters): + return { + 'name': name, + 'description': description, + 'parameters': parameters, + } + +def build_mock_version(version_number): + return { + 'versionNumber': version_number, + } From 3c862081dee305474350a26291dc6d6488da1222 Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Wed, 12 Mar 2025 12:55:04 -0400 Subject: [PATCH 196/226] [chore] Release 6.7.0 (#867) --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index 4ee475c8a..2c606611f 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '6.6.0' +__version__ = '6.7.0' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 387f11a5d61edd19e2fd2d1c3f8baf81d5d2aa9f Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Wed, 19 Mar 2025 16:54:01 -0400 Subject: [PATCH 197/226] feat(fcm): Support `proxy` field in FCM `AndroidNotification` (#868) * feat(fcm): Support `proxy` field in FCM `AndroidNotification` * fix lint * fix: Update `proxy` and `visibility` doc string with TW suggestion --- firebase_admin/_messaging_encoder.py | 11 ++++++++++- firebase_admin/_messaging_utils.py | 10 ++++++++-- integration/test_messaging.py | 3 ++- tests/test_messaging.py | 16 ++++++++++++++++ 4 files changed, 36 insertions(+), 4 deletions(-) diff --git a/firebase_admin/_messaging_encoder.py b/firebase_admin/_messaging_encoder.py index 85072b597..d7f233289 100644 --- a/firebase_admin/_messaging_encoder.py +++ b/firebase_admin/_messaging_encoder.py @@ -319,7 +319,9 @@ def encode_android_notification(cls, notification): 'visibility': _Validators.check_string( 'AndroidNotification.visibility', notification.visibility, non_empty=True), 'notification_count': _Validators.check_number( - 'AndroidNotification.notification_count', notification.notification_count) + 'AndroidNotification.notification_count', notification.notification_count), + 'proxy': _Validators.check_string( + 'AndroidNotification.proxy', notification.proxy, non_empty=True) } result = cls.remove_null_values(result) color = result.get('color') @@ -363,6 +365,13 @@ def encode_android_notification(cls, notification): 'AndroidNotification.vibrate_timings_millis', msec) vibrate_timing_strings.append(formated_string) result['vibrate_timings'] = vibrate_timing_strings + + proxy = result.get('proxy') + if proxy: + if proxy not in ('allow', 'deny', 'if_priority_lowered'): + raise ValueError( + 'AndroidNotification.proxy must be "allow", "deny" or "if_priority_lowered".') + result['proxy'] = proxy.upper() return result @classmethod diff --git a/firebase_admin/_messaging_utils.py b/firebase_admin/_messaging_utils.py index 29b8276bc..ae1f5cc56 100644 --- a/firebase_admin/_messaging_utils.py +++ b/firebase_admin/_messaging_utils.py @@ -137,7 +137,8 @@ class AndroidNotification: If ``default_light_settings`` is set to ``True`` and ``light_settings`` is also set, the user-specified ``light_settings`` is used instead of the default value. visibility: Sets the visibility of the notification. Must be either ``private``, ``public``, - or ``secret``. If unspecified, default to ``private``. + or ``secret``. If unspecified, it remains undefined in the Admin SDK, and defers to + the FCM backend's default mapping. notification_count: Sets the number of items this notification represents. May be displayed as a badge count for Launchers that support badging. See ``NotificationBadge`` https://developer.android.com/training/notify-user/badges. For example, this might be @@ -145,6 +146,9 @@ class AndroidNotification: want the count here to represent the number of total new messages. If zero or unspecified, systems that support badging use the default, which is to increment a number displayed on the long-press menu each time a new notification arrives. + proxy: Sets if the notification may be proxied. Must be one of ``allow``, ``deny``, or + ``if_priority_lowered``. If unspecified, it remains undefined in the Admin SDK, and + defers to the FCM backend's default mapping. """ @@ -154,7 +158,8 @@ def __init__(self, title=None, body=None, icon=None, color=None, sound=None, tag title_loc_args=None, channel_id=None, image=None, ticker=None, sticky=None, event_timestamp=None, local_only=None, priority=None, vibrate_timings_millis=None, default_vibrate_timings=None, default_sound=None, light_settings=None, - default_light_settings=None, visibility=None, notification_count=None): + default_light_settings=None, visibility=None, notification_count=None, + proxy=None): self.title = title self.body = body self.icon = icon @@ -180,6 +185,7 @@ def __init__(self, title=None, body=None, icon=None, color=None, sound=None, tag self.default_light_settings = default_light_settings self.visibility = visibility self.notification_count = notification_count + self.proxy = proxy class LightSettings: diff --git a/integration/test_messaging.py b/integration/test_messaging.py index 50b4ae3a4..4c1d7d0dc 100644 --- a/integration/test_messaging.py +++ b/integration/test_messaging.py @@ -55,7 +55,8 @@ def test_send(): light_off_duration_millis=200, light_on_duration_millis=300 ), - notification_count=1 + notification_count=1, + proxy='if_priority_lowered', ) ), apns=messaging.APNSConfig(payload=messaging.APNSPayload( diff --git a/tests/test_messaging.py b/tests/test_messaging.py index edb36f53a..b7b5c69ba 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -535,6 +535,20 @@ def test_invalid_visibility(self, visibility): expected = 'AndroidNotification.visibility must be a non-empty string.' assert str(excinfo.value) == expected + @pytest.mark.parametrize('proxy', NON_STRING_ARGS + ['foo']) + def test_invalid_proxy(self, proxy): + notification = messaging.AndroidNotification(proxy=proxy) + excinfo = self._check_notification(notification) + if isinstance(proxy, str): + if not proxy: + expected = 'AndroidNotification.proxy must be a non-empty string.' + else: + expected = ('AndroidNotification.proxy must be "allow", "deny" or' + ' "if_priority_lowered".') + else: + expected = 'AndroidNotification.proxy must be a non-empty string.' + assert str(excinfo.value) == expected + @pytest.mark.parametrize('vibrate_timings', ['', 1, True, 'msec', ['500', 500], [0, 'abc']]) def test_invalid_vibrate_timings_millis(self, vibrate_timings): notification = messaging.AndroidNotification(vibrate_timings_millis=vibrate_timings) @@ -580,6 +594,7 @@ def test_android_notification(self): light_off_duration_millis=300, ), default_light_settings=False, visibility='public', notification_count=1, + proxy='if_priority_lowered', ) ) ) @@ -620,6 +635,7 @@ def test_android_notification(self): 'default_light_settings': False, 'visibility': 'PUBLIC', 'notification_count': 1, + 'proxy': 'IF_PRIORITY_LOWERED' }, }, } From ffeb939d55ada0aac4b18b91b26ef431da58495e Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Tue, 22 Apr 2025 16:09:26 -0400 Subject: [PATCH 198/226] Python 3.8 has EoL'ed. Update README to deprecate Python 3.8 support (#873) Updated the 'Supported Python Versions' section in README.md to indicate that Python 3.7 and Python 3.8 support is deprecated, advising users to use Python 3.9 or higher. Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index f7cae21ff..6e3ed6805 100644 --- a/README.md +++ b/README.md @@ -43,8 +43,8 @@ requests, code review feedback, and also pull requests. ## Supported Python Versions -We currently support Python 3.7+. However, Python 3.7 support is deprecated, -and developers are strongly advised to use Python 3.8 or higher. Firebase +We currently support Python 3.7+. However, Python 3.7 and Python 3.8 support is deprecated, +and developers are strongly advised to use Python 3.9 or higher. Firebase Admin Python SDK is also tested on PyPy and [Google App Engine](https://cloud.google.com/appengine/) environments. From bde3fb0134b2f84d789cff47b932c07981d6565b Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Thu, 24 Apr 2025 14:49:17 -0400 Subject: [PATCH 199/226] [chore] Release 6.8.0 (#874) --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index 2c606611f..c822fb375 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '6.7.0' +__version__ = '6.8.0' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 70013c8ad55181befc1e4c29a2871ebcfe036e34 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Thu, 8 May 2025 14:15:39 -0400 Subject: [PATCH 200/226] chore: Correct x-goog-api-client header logic (#876) --- firebase_admin/_http_client.py | 4 ++-- firebase_admin/app_check.py | 2 +- firebase_admin/storage.py | 2 +- tests/test_auth_providers.py | 6 +++++- tests/test_db.py | 6 ++++-- tests/test_functions.py | 10 ++++++---- tests/test_http_client.py | 16 +++++++++++++++- tests/test_instance_id.py | 3 ++- tests/test_messaging.py | 6 ++++-- tests/test_ml.py | 3 ++- tests/test_project_management.py | 3 ++- tests/test_tenant_mgt.py | 18 ++++++++++++------ tests/test_user_mgt.py | 6 +++++- tests/testutils.py | 4 ++++ 14 files changed, 65 insertions(+), 24 deletions(-) diff --git a/firebase_admin/_http_client.py b/firebase_admin/_http_client.py index f1eccbcf2..57c09e2e4 100644 --- a/firebase_admin/_http_client.py +++ b/firebase_admin/_http_client.py @@ -38,7 +38,7 @@ DEFAULT_TIMEOUT_SECONDS = 120 METRICS_HEADERS = { - 'X-GOOG-API-CLIENT': _utils.get_metrics_header(), + 'x-goog-api-client': _utils.get_metrics_header(), } class HttpClient: @@ -76,7 +76,6 @@ def __init__( if headers: self._session.headers.update(headers) - self._session.headers.update(METRICS_HEADERS) if retries: self._session.mount('http://', requests.adapters.HTTPAdapter(max_retries=retries)) self._session.mount('https://', requests.adapters.HTTPAdapter(max_retries=retries)) @@ -120,6 +119,7 @@ class call this method to send HTTP requests out. Refer to """ if 'timeout' not in kwargs: kwargs['timeout'] = self.timeout + kwargs.setdefault('headers', {}).update(METRICS_HEADERS) resp = self._session.request(method, self.base_url + url, **kwargs) resp.raise_for_status() return resp diff --git a/firebase_admin/app_check.py b/firebase_admin/app_check.py index e6b66efc1..53686db3d 100644 --- a/firebase_admin/app_check.py +++ b/firebase_admin/app_check.py @@ -52,7 +52,7 @@ class _AppCheckService: _jwks_client = None _APP_CHECK_HEADERS = { - 'X-GOOG-API-CLIENT': _utils.get_metrics_header(), + 'x-goog-api-client': _utils.get_metrics_header(), } def __init__(self, app): diff --git a/firebase_admin/storage.py b/firebase_admin/storage.py index 46f5f6043..b6084842a 100644 --- a/firebase_admin/storage.py +++ b/firebase_admin/storage.py @@ -56,7 +56,7 @@ class _StorageClient: """Holds a Google Cloud Storage client instance.""" STORAGE_HEADERS = { - 'X-GOOG-API-CLIENT': _utils.get_metrics_header(), + 'x-goog-api-client': _utils.get_metrics_header(), } def __init__(self, credentials, project, default_bucket): diff --git a/tests/test_auth_providers.py b/tests/test_auth_providers.py index 48f38a011..304e0fd78 100644 --- a/tests/test_auth_providers.py +++ b/tests/test_auth_providers.py @@ -75,7 +75,11 @@ def _assert_request(request, expected_method, expected_url): assert request.method == expected_method assert request.url == expected_url assert request.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' - assert request.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() + expected_metrics_header = [ + _utils.get_metrics_header(), + _utils.get_metrics_header() + ' mock-cred-metric-tag' + ] + assert request.headers['x-goog-api-client'] in expected_metrics_header class TestOIDCProviderConfig: diff --git a/tests/test_db.py b/tests/test_db.py index f2ba08827..00a0077cb 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -198,7 +198,8 @@ def _assert_request(self, request, expected_method, expected_url): assert request.url == expected_url assert request.headers['Authorization'] == 'Bearer mock-token' assert request.headers['User-Agent'] == db._USER_AGENT - assert request.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() + expected_metrics_header = _utils.get_metrics_header() + ' mock-cred-metric-tag' + assert request.headers['x-goog-api-client'] == expected_metrics_header @pytest.mark.parametrize('data', valid_values) def test_get_value(self, data): @@ -665,7 +666,8 @@ def _assert_request(self, request, expected_method, expected_url): assert request.url == expected_url assert request.headers['Authorization'] == 'Bearer mock-token' assert request.headers['User-Agent'] == db._USER_AGENT - assert request.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() + expected_metrics_header = _utils.get_metrics_header() + ' mock-cred-metric-tag' + assert request.headers['x-goog-api-client'] == expected_metrics_header def test_get_value(self): ref = db.reference('/test') diff --git a/tests/test_functions.py b/tests/test_functions.py index f8f675890..1856426d9 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -122,7 +122,8 @@ def test_task_enqueue(self): assert recorder[0].url == _DEFAULT_REQUEST_URL assert recorder[0].headers['Content-Type'] == 'application/json' assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() + expected_metrics_header = _utils.get_metrics_header() + ' mock-cred-metric-tag' + assert recorder[0].headers['x-goog-api-client'] == expected_metrics_header assert task_id == 'test-task-id' def test_task_enqueue_with_extension(self): @@ -139,7 +140,8 @@ def test_task_enqueue_with_extension(self): assert recorder[0].url == _CLOUD_TASKS_URL + resource_name assert recorder[0].headers['Content-Type'] == 'application/json' assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - assert recorder[0].headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() + expected_metrics_header = _utils.get_metrics_header() + ' mock-cred-metric-tag' + assert recorder[0].headers['x-goog-api-client'] == expected_metrics_header assert task_id == 'test-task-id' def test_task_delete(self): @@ -149,8 +151,8 @@ def test_task_delete(self): assert len(recorder) == 1 assert recorder[0].method == 'DELETE' assert recorder[0].url == _DEFAULT_TASK_URL - assert recorder[0].headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() - + expected_metrics_header = _utils.get_metrics_header() + ' mock-cred-metric-tag' + assert recorder[0].headers['x-goog-api-client'] == expected_metrics_header class TestTaskQueueOptions: diff --git a/tests/test_http_client.py b/tests/test_http_client.py index cc948b393..78036166c 100644 --- a/tests/test_http_client.py +++ b/tests/test_http_client.py @@ -71,7 +71,21 @@ def test_metrics_headers(): assert len(recorder) == 1 assert recorder[0].method == 'GET' assert recorder[0].url == _TEST_URL - assert recorder[0].headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() + assert recorder[0].headers['x-goog-api-client'] == _utils.get_metrics_header() + +def test_metrics_headers_with_credentials(): + client = _http_client.HttpClient( + credential=testutils.MockGoogleCredential()) + assert client.session is not None + recorder = _instrument(client, 'body') + resp = client.request('get', _TEST_URL) + assert resp.status_code == 200 + assert resp.text == 'body' + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == _TEST_URL + expected_metrics_header = _utils.get_metrics_header() + ' mock-cred-metric-tag' + assert recorder[0].headers['x-goog-api-client'] == expected_metrics_header def test_credential(): client = _http_client.HttpClient( diff --git a/tests/test_instance_id.py b/tests/test_instance_id.py index 720171cd9..387e067c9 100644 --- a/tests/test_instance_id.py +++ b/tests/test_instance_id.py @@ -68,7 +68,8 @@ def _instrument_iid_service(self, app, status=200, payload='True'): def _assert_request(self, request, expected_method, expected_url): assert request.method == expected_method assert request.url == expected_url - assert request.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() + expected_metrics_header = _utils.get_metrics_header() + ' mock-cred-metric-tag' + assert request.headers['x-goog-api-client'] == expected_metrics_header def _get_url(self, project_id, iid): return instance_id._IID_SERVICE_URL + 'project/{0}/instanceId/{1}'.format(project_id, iid) diff --git a/tests/test_messaging.py b/tests/test_messaging.py index b7b5c69ba..45a5bc6d5 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -1683,7 +1683,8 @@ def _assert_request(self, request, expected_method, expected_url, expected_body= assert request.url == expected_url assert request.headers['X-GOOG-API-FORMAT-VERSION'] == '2' assert request.headers['X-FIREBASE-CLIENT'] == self._CLIENT_VERSION - assert request.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() + expected_metrics_header = _utils.get_metrics_header() + ' mock-cred-metric-tag' + assert request.headers['x-goog-api-client'] == expected_metrics_header if expected_body is None: assert request.body is None else: @@ -2604,7 +2605,8 @@ def _assert_request(self, request, expected_method, expected_url): assert request.method == expected_method assert request.url == expected_url assert request.headers['access_token_auth'] == 'true' - assert request.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() + expected_metrics_header = _utils.get_metrics_header() + ' mock-cred-metric-tag' + assert request.headers['x-goog-api-client'] == expected_metrics_header def _get_url(self, path): return '{0}/{1}'.format(messaging._MessagingService.IID_URL, path) diff --git a/tests/test_ml.py b/tests/test_ml.py index 137fe4cf6..18a9e2754 100644 --- a/tests/test_ml.py +++ b/tests/test_ml.py @@ -339,7 +339,8 @@ def _assert_request(request, expected_method, expected_url): assert request.method == expected_method assert request.url == expected_url assert request.headers['X-FIREBASE-CLIENT'] == f'fire-admin-python/{firebase_admin.__version__}' - assert request.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() + expected_metrics_header = _utils.get_metrics_header() + ' mock-cred-metric-tag' + assert request.headers['x-goog-api-client'] == expected_metrics_header class _TestStorageClient: @staticmethod diff --git a/tests/test_project_management.py b/tests/test_project_management.py index 0a1bf97e5..a242f523f 100644 --- a/tests/test_project_management.py +++ b/tests/test_project_management.py @@ -523,7 +523,8 @@ def _assert_request_is_correct( assert request.method == expected_method assert request.url == expected_url assert request.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' - assert request.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() + expected_metrics_header = _utils.get_metrics_header() + ' mock-cred-metric-tag' + assert request.headers['x-goog-api-client'] == expected_metrics_header if expected_body is None: assert request.body is None else: diff --git a/tests/test_tenant_mgt.py b/tests/test_tenant_mgt.py index 1da6d938a..224fdcc16 100644 --- a/tests/test_tenant_mgt.py +++ b/tests/test_tenant_mgt.py @@ -197,7 +197,8 @@ def test_get_tenant(self, tenant_mgt_app): assert req.method == 'GET' assert req.url == '{0}/tenants/tenant-id'.format(TENANT_MGT_URL_PREFIX) assert req.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' - assert req.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() + expected_metrics_header = _utils.get_metrics_header() + ' mock-cred-metric-tag' + assert req.headers['x-goog-api-client'] == expected_metrics_header def test_tenant_not_found(self, tenant_mgt_app): _instrument_tenant_mgt(tenant_mgt_app, 500, TENANT_NOT_FOUND_RESPONSE) @@ -289,7 +290,8 @@ def _assert_request(self, recorder, body): assert req.method == 'POST' assert req.url == '{0}/tenants'.format(TENANT_MGT_URL_PREFIX) assert req.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' - assert req.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() + expected_metrics_header = _utils.get_metrics_header() + ' mock-cred-metric-tag' + assert req.headers['x-goog-api-client'] == expected_metrics_header got = json.loads(req.body.decode()) assert got == body @@ -389,7 +391,8 @@ def _assert_request(self, recorder, body, mask): assert req.url == '{0}/tenants/tenant-id?updateMask={1}'.format( TENANT_MGT_URL_PREFIX, ','.join(mask)) assert req.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' - assert req.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() + expected_metrics_header = _utils.get_metrics_header() + ' mock-cred-metric-tag' + assert req.headers['x-goog-api-client'] == expected_metrics_header got = json.loads(req.body.decode()) assert got == body @@ -411,7 +414,8 @@ def test_delete_tenant(self, tenant_mgt_app): assert req.method == 'DELETE' assert req.url == '{0}/tenants/tenant-id'.format(TENANT_MGT_URL_PREFIX) assert req.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' - assert req.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() + expected_metrics_header = _utils.get_metrics_header() + ' mock-cred-metric-tag' + assert req.headers['x-goog-api-client'] == expected_metrics_header def test_tenant_not_found(self, tenant_mgt_app): _instrument_tenant_mgt(tenant_mgt_app, 500, TENANT_NOT_FOUND_RESPONSE) @@ -555,7 +559,8 @@ def _assert_request(self, recorder, expected=None): req = recorder[0] assert req.method == 'GET' assert req.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' - assert req.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() + expected_metrics_header = _utils.get_metrics_header() + ' mock-cred-metric-tag' + assert req.headers['x-goog-api-client'] == expected_metrics_header request = dict(parse.parse_qsl(parse.urlsplit(req.url).query)) assert request == expected @@ -932,7 +937,8 @@ def _assert_request( assert req.method == method assert req.url == '{0}/tenants/tenant-id{1}'.format(prefix, want_url) assert req.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' - assert req.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() + expected_metrics_header = _utils.get_metrics_header() + ' mock-cred-metric-tag' + assert req.headers['x-goog-api-client'] == expected_metrics_header body = json.loads(req.body.decode()) assert body == want_body diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index 604ec9959..34b698be4 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -136,7 +136,11 @@ def _check_request(recorder, want_url, want_body=None, want_timeout=None): req = recorder[0] assert req.method == 'POST' assert req.url == '{0}{1}'.format(USER_MGT_URLS['PREFIX'], want_url) - assert req.headers['X-GOOG-API-CLIENT'] == _utils.get_metrics_header() + expected_metrics_header = [ + _utils.get_metrics_header(), + _utils.get_metrics_header() + ' mock-cred-metric-tag' + ] + assert req.headers['x-goog-api-client'] in expected_metrics_header if want_body: body = json.loads(req.body.decode()) assert body == want_body diff --git a/tests/testutils.py b/tests/testutils.py index 17013b469..62f7bd9b5 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -123,6 +123,10 @@ def refresh(self, request): def service_account_email(self): return 'mock-email' + # Simulate x-goog-api-client modification in credential refresh + def _metric_header_for_usage(self): + return 'mock-cred-metric-tag' + class MockCredential(firebase_admin.credentials.Base): """A mock Firebase credential implementation.""" From 2d9b18c6009cdab53654c972b4f0e0fecf50eed3 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Tue, 27 May 2025 09:07:55 -0400 Subject: [PATCH 201/226] chore: Use mock time for consistent token generation and verification tests (#881) * Fix(tests): Use mock time for consistent token generation and verification tests Patches time.time and google.auth.jwt._helpers.utcnow to use a fixed timestamp (MOCK_CURRENT_TIME) throughout tests/test_token_gen.py. This addresses test flakiness and inconsistencies by ensuring that: 1. Tokens and cookies are generated with predictable `iat` and `exp` claims based on MOCK_CURRENT_TIME. 2. The verification logic within the Firebase Admin SDK and the underlying google-auth library also uses MOCK_CURRENT_TIME. Helper functions _get_id_token and _get_session_cookie were updated to default to using MOCK_CURRENT_TIME for their internal time calculations, simplifying test code. Relevant fixtures and token definitions were updated to rely on these new defaults and the fixed timestamp. The setup_method in TestVerifyIdToken, TestVerifySessionCookie, TestCertificateCaching, and TestCertificateFetchTimeout now mock time.time and google.auth.jwt._helpers.utcnow to ensure that all time-sensitive operations during testing use the MOCK_CURRENT_TIME. * Fix(tests): Apply time mocking to test_tenant_mgt.py Extends the time mocking strategy (using a fixed MOCK_CURRENT_TIME) to tests in `tests/test_tenant_mgt.py` to ensure consistency with changes previously made in `tests/test_token_gen.py`. Specifically: - Imported `MOCK_CURRENT_TIME` from `tests.test_token_gen`. - Added `setup_method` (and `teardown_method`) to the `TestVerifyIdToken` and `TestCreateCustomToken` classes. - These setup methods patch `time.time` and `google.auth.jwt._helpers.utcnow` to return `MOCK_CURRENT_TIME` (or its datetime equivalent). This ensures that token generation (for custom tokens) and token verification within `test_tenant_mgt.py` align with the mocked timeline, preventing potential flakiness or failures due to time inconsistencies. All tests in `test_tenant_mgt.py` pass with these changes. * fix lint and refactor --------- Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> --- tests/test_tenant_mgt.py | 24 +++++++++++ tests/test_token_gen.py | 88 +++++++++++++++++++++++++++++++--------- 2 files changed, 92 insertions(+), 20 deletions(-) diff --git a/tests/test_tenant_mgt.py b/tests/test_tenant_mgt.py index 224fdcc16..018892e3a 100644 --- a/tests/test_tenant_mgt.py +++ b/tests/test_tenant_mgt.py @@ -15,6 +15,7 @@ """Test cases for the firebase_admin.tenant_mgt module.""" import json +import unittest.mock from urllib import parse import pytest @@ -29,6 +30,7 @@ from firebase_admin import _utils from tests import testutils from tests import test_token_gen +from tests.test_token_gen import MOCK_CURRENT_TIME, MOCK_CURRENT_TIME_UTC GET_TENANT_RESPONSE = """{ @@ -964,6 +966,17 @@ def _assert_saml_provider_config(self, provider_config, want_id='saml.provider') class TestVerifyIdToken: + def setup_method(self): + self.time_patch = unittest.mock.patch('time.time', return_value=MOCK_CURRENT_TIME) + self.mock_time = self.time_patch.start() + self.utcnow_patch = unittest.mock.patch( + 'google.auth.jwt._helpers.utcnow', return_value=MOCK_CURRENT_TIME_UTC) + self.mock_utcnow = self.utcnow_patch.start() + + def teardown_method(self): + self.time_patch.stop() + self.utcnow_patch.stop() + def test_valid_token(self, tenant_mgt_app): client = tenant_mgt.auth_for_tenant('test-tenant', app=tenant_mgt_app) client._token_verifier.request = test_token_gen.MOCK_REQUEST @@ -997,6 +1010,17 @@ def tenant_aware_custom_token_app(): class TestCreateCustomToken: + def setup_method(self): + self.time_patch = unittest.mock.patch('time.time', return_value=MOCK_CURRENT_TIME) + self.mock_time = self.time_patch.start() + self.utcnow_patch = unittest.mock.patch( + 'google.auth.jwt._helpers.utcnow', return_value=MOCK_CURRENT_TIME_UTC) + self.mock_utcnow = self.utcnow_patch.start() + + def teardown_method(self): + self.time_patch.stop() + self.utcnow_patch.stop() + def test_custom_token(self, tenant_aware_custom_token_app): client = tenant_mgt.auth_for_tenant('test-tenant', app=tenant_aware_custom_token_app) diff --git a/tests/test_token_gen.py b/tests/test_token_gen.py index 536a5ec91..fe0b28dbe 100644 --- a/tests/test_token_gen.py +++ b/tests/test_token_gen.py @@ -19,6 +19,7 @@ import json import os import time +import unittest.mock from google.auth import crypt from google.auth import jwt @@ -36,6 +37,9 @@ from tests import testutils +MOCK_CURRENT_TIME = 1500000000 +MOCK_CURRENT_TIME_UTC = datetime.datetime.fromtimestamp( + MOCK_CURRENT_TIME, tz=datetime.timezone.utc) MOCK_UID = 'user1' MOCK_CREDENTIAL = credentials.Certificate( testutils.resource_filename('service_account.json')) @@ -105,16 +109,17 @@ def verify_custom_token(custom_token, expected_claims, tenant_id=None): for key, value in expected_claims.items(): assert value == token['claims'][key] -def _get_id_token(payload_overrides=None, header_overrides=None): +def _get_id_token(payload_overrides=None, header_overrides=None, current_time=MOCK_CURRENT_TIME): signer = crypt.RSASigner.from_string(MOCK_PRIVATE_KEY) headers = { 'kid': 'mock-key-id-1' } + now = int(current_time if current_time is not None else time.time()) payload = { 'aud': MOCK_CREDENTIAL.project_id, 'iss': 'https://securetoken.google.com/' + MOCK_CREDENTIAL.project_id, - 'iat': int(time.time()) - 100, - 'exp': int(time.time()) + 3600, + 'iat': now - 100, + 'exp': now + 3600, 'sub': '1234567890', 'admin': True, 'firebase': { @@ -127,12 +132,13 @@ def _get_id_token(payload_overrides=None, header_overrides=None): payload = _merge_jwt_claims(payload, payload_overrides) return jwt.encode(signer, payload, header=headers) -def _get_session_cookie(payload_overrides=None, header_overrides=None): +def _get_session_cookie( + payload_overrides=None, header_overrides=None, current_time=MOCK_CURRENT_TIME): payload_overrides = payload_overrides or {} if 'iss' not in payload_overrides: payload_overrides['iss'] = 'https://session.firebase.google.com/{0}'.format( MOCK_CREDENTIAL.project_id) - return _get_id_token(payload_overrides, header_overrides) + return _get_id_token(payload_overrides, header_overrides, current_time=current_time) def _instrument_user_manager(app, status, payload): client = auth._get_client(app) @@ -205,7 +211,7 @@ def env_var_app(request): @pytest.fixture(scope='module') def revoked_tokens(): mock_user = json.loads(testutils.resource('get_user.json')) - mock_user['users'][0]['validSince'] = str(int(time.time())+100) + mock_user['users'][0]['validSince'] = str(MOCK_CURRENT_TIME + 100) return json.dumps(mock_user) @pytest.fixture(scope='module') @@ -218,7 +224,7 @@ def user_disabled(): def user_disabled_and_revoked(): mock_user = json.loads(testutils.resource('get_user.json')) mock_user['users'][0]['disabled'] = True - mock_user['users'][0]['validSince'] = str(int(time.time())+100) + mock_user['users'][0]['validSince'] = str(MOCK_CURRENT_TIME + 100) return json.dumps(mock_user) @@ -420,6 +426,17 @@ def test_unexpected_response(self, user_mgt_app): class TestVerifyIdToken: + def setup_method(self): + self.time_patch = unittest.mock.patch('time.time', return_value=MOCK_CURRENT_TIME) + self.time_patch.start() + self.utcnow_patch = unittest.mock.patch( + 'google.auth.jwt._helpers.utcnow', return_value=MOCK_CURRENT_TIME_UTC) + self.utcnow_patch.start() + + def teardown_method(self): + self.time_patch.stop() + self.utcnow_patch.stop() + valid_tokens = { 'BinaryToken': TEST_ID_TOKEN, 'TextToken': TEST_ID_TOKEN.decode('utf-8'), @@ -435,14 +452,14 @@ class TestVerifyIdToken: 'EmptySubject': _get_id_token({'sub': ''}), 'IntSubject': _get_id_token({'sub': 10}), 'LongStrSubject': _get_id_token({'sub': 'a' * 129}), - 'FutureToken': _get_id_token({'iat': int(time.time()) + 1000}), + 'FutureToken': _get_id_token({'iat': MOCK_CURRENT_TIME + 1000}), 'ExpiredToken': _get_id_token({ - 'iat': int(time.time()) - 10000, - 'exp': int(time.time()) - 3600 + 'iat': MOCK_CURRENT_TIME - 10000, + 'exp': MOCK_CURRENT_TIME - 3600 }), 'ExpiredTokenShort': _get_id_token({ - 'iat': int(time.time()) - 10000, - 'exp': int(time.time()) - 30 + 'iat': MOCK_CURRENT_TIME - 10000, + 'exp': MOCK_CURRENT_TIME - 30 }), 'BadFormatToken': 'foobar' } @@ -618,6 +635,17 @@ def test_certificate_request_failure(self, user_mgt_app): class TestVerifySessionCookie: + def setup_method(self): + self.time_patch = unittest.mock.patch('time.time', return_value=MOCK_CURRENT_TIME) + self.time_patch.start() + self.utcnow_patch = unittest.mock.patch( + 'google.auth.jwt._helpers.utcnow', return_value=MOCK_CURRENT_TIME_UTC) + self.utcnow_patch.start() + + def teardown_method(self): + self.time_patch.stop() + self.utcnow_patch.stop() + valid_cookies = { 'BinaryCookie': TEST_SESSION_COOKIE, 'TextCookie': TEST_SESSION_COOKIE.decode('utf-8'), @@ -633,14 +661,14 @@ class TestVerifySessionCookie: 'EmptySubject': _get_session_cookie({'sub': ''}), 'IntSubject': _get_session_cookie({'sub': 10}), 'LongStrSubject': _get_session_cookie({'sub': 'a' * 129}), - 'FutureCookie': _get_session_cookie({'iat': int(time.time()) + 1000}), + 'FutureCookie': _get_session_cookie({'iat': MOCK_CURRENT_TIME + 1000}), 'ExpiredCookie': _get_session_cookie({ - 'iat': int(time.time()) - 10000, - 'exp': int(time.time()) - 3600 + 'iat': MOCK_CURRENT_TIME - 10000, + 'exp': MOCK_CURRENT_TIME - 3600 }), 'ExpiredCookieShort': _get_session_cookie({ - 'iat': int(time.time()) - 10000, - 'exp': int(time.time()) - 30 + 'iat': MOCK_CURRENT_TIME - 10000, + 'exp': MOCK_CURRENT_TIME - 30 }), 'BadFormatCookie': 'foobar', 'IDToken': TEST_ID_TOKEN, @@ -792,6 +820,17 @@ def test_certificate_request_failure(self, user_mgt_app): class TestCertificateCaching: + def setup_method(self): + self.time_patch = unittest.mock.patch('time.time', return_value=MOCK_CURRENT_TIME) + self.time_patch.start() + self.utcnow_patch = unittest.mock.patch( + 'google.auth.jwt._helpers.utcnow', return_value=MOCK_CURRENT_TIME_UTC) + self.utcnow_patch.start() + + def teardown_method(self): + self.time_patch.stop() + self.utcnow_patch.stop() + def test_certificate_caching(self, user_mgt_app, httpserver): httpserver.serve_content(MOCK_PUBLIC_CERTS, 200, headers={'Cache-Control': 'max-age=3600'}) verifier = _token_gen.TokenVerifier(user_mgt_app) @@ -810,6 +849,18 @@ def test_certificate_caching(self, user_mgt_app, httpserver): class TestCertificateFetchTimeout: + def setup_method(self): + self.time_patch = unittest.mock.patch('time.time', return_value=MOCK_CURRENT_TIME) + self.time_patch.start() + self.utcnow_patch = unittest.mock.patch( + 'google.auth.jwt._helpers.utcnow', return_value=MOCK_CURRENT_TIME_UTC) + self.utcnow_patch.start() + + def teardown_method(self): + self.time_patch.stop() + self.utcnow_patch.stop() + testutils.cleanup_apps() + timeout_configs = [ ({'httpTimeout': 4}, 4), ({'httpTimeout': None}, None), @@ -852,6 +903,3 @@ def _instrument_session(self, app): recorder = [] request.session.mount('https://', testutils.MockAdapter(MOCK_PUBLIC_CERTS, 200, recorder)) return recorder - - def teardown_method(self): - testutils.cleanup_apps() From f7546f5271267154ebb3a70d368ce25b88c6a76a Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Tue, 3 Jun 2025 14:20:41 -0400 Subject: [PATCH 202/226] feat(fcm): Add `live_activity_token` to `APNSConfig` (#880) * Add live_activity_token to `APNSConfig`, allowing you to specify this token for APNS messages. This change introduces: - Adding the `live_activity_token` field to the `APNSConfig` class - Updated unit test to verify that the `live_activity_token` is correctly included in the encoded message * Refactor and edit doc string --------- Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> --- firebase_admin/_messaging_encoder.py | 2 ++ firebase_admin/_messaging_utils.py | 4 +++- tests/test_messaging.py | 4 +++- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/firebase_admin/_messaging_encoder.py b/firebase_admin/_messaging_encoder.py index d7f233289..32f97875e 100644 --- a/firebase_admin/_messaging_encoder.py +++ b/firebase_admin/_messaging_encoder.py @@ -529,6 +529,8 @@ def encode_apns(cls, apns): 'APNSConfig.headers', apns.headers), 'payload': cls.encode_apns_payload(apns.payload), 'fcm_options': cls.encode_apns_fcm_options(apns.fcm_options), + 'live_activity_token': _Validators.check_string( + 'APNSConfig.live_activity_token', apns.live_activity_token), } return cls.remove_null_values(result) diff --git a/firebase_admin/_messaging_utils.py b/firebase_admin/_messaging_utils.py index ae1f5cc56..8fd720701 100644 --- a/firebase_admin/_messaging_utils.py +++ b/firebase_admin/_messaging_utils.py @@ -334,15 +334,17 @@ class APNSConfig: payload: A ``messaging.APNSPayload`` to be included in the message (optional). fcm_options: A ``messaging.APNSFCMOptions`` instance to be included in the message (optional). + live_activity_token: A live activity token string (optional). .. _APNS Documentation: https://developer.apple.com/library/content/documentation\ /NetworkingInternet/Conceptual/RemoteNotificationsPG/CommunicatingwithAPNs.html """ - def __init__(self, headers=None, payload=None, fcm_options=None): + def __init__(self, headers=None, payload=None, fcm_options=None, live_activity_token=None): self.headers = headers self.payload = payload self.fcm_options = fcm_options + self.live_activity_token = live_activity_token class APNSPayload: diff --git a/tests/test_messaging.py b/tests/test_messaging.py index 45a5bc6d5..54173ea97 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -1094,7 +1094,8 @@ def test_apns_config(self): topic='topic', apns=messaging.APNSConfig( headers={'h1': 'v1', 'h2': 'v2'}, - fcm_options=messaging.APNSFCMOptions('analytics_label_v1') + fcm_options=messaging.APNSFCMOptions('analytics_label_v1'), + live_activity_token='test_token_string' ), ) expected = { @@ -1107,6 +1108,7 @@ def test_apns_config(self): 'fcm_options': { 'analytics_label': 'analytics_label_v1', }, + 'live_activity_token': 'test_token_string', }, } check_encoding(msg, expected) From e0599f98e67b3d7a97db125fbe84fc9d3bc59571 Mon Sep 17 00:00:00 2001 From: Minki Kim <68267535+mingi3314@users.noreply.github.com> Date: Wed, 4 Jun 2025 03:29:08 +0900 Subject: [PATCH 203/226] refactor: Optimize success count calculation in BatchResponse (#837) Co-authored-by: Lahiru Maramba Co-authored-by: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> --- firebase_admin/messaging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index d2ad04a04..c2870eac7 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -323,7 +323,7 @@ class BatchResponse: def __init__(self, responses): self._responses = responses - self._success_count = len([resp for resp in responses if resp.success]) + self._success_count = sum(1 for resp in responses if resp.success) @property def responses(self): From 99b60207dbe1f359783d54b6a320de723748bc7b Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Tue, 3 Jun 2025 14:45:44 -0400 Subject: [PATCH 204/226] feat(fcm) Add `send_each_async` and `send_each_for_multicast_async` for FCM async and HTTP/2 support (#882) * Added minimal support for sending FCM messages in async using HTTP/2 (#870) * httpx async_send_each prototype * Clean up code and lint * fix: Add extra dependancy for http2 * fix: reset message batch limit to 500 * fix: Add new import to `setup.py` * Refactored retry config into `_retry.py` and added support for exponential backoff and `Retry-After` header (#871) * Refactored retry config to `_retry.py` and added support for backoff and Retry-After * Added unit tests for `_retry.py` * Updated unit tests for HTTPX request errors * Address review comments * Added `HttpxAsyncClient` wrapper for `httpx.AsyncClient` and support for `send_each_for_multicast_async()` (#878) * Refactored retry config to `_retry.py` and added support for backoff and Retry-After * Added unit tests for `_retry.py` * Updated unit tests for HTTPX request errors * Add HttpxAsyncClient to wrap httpx.AsyncClient * Added forced refresh to google auth credential flow and fixed lint * Added unit tests for `GoogleAuthCredentialFlow` and `HttpxAsyncClient` * Removed duplicate export * Added support for `send_each_for_multicast_async()` and updated doc string and type hints * Remove duplicate auth class * Cover auth request error case when `requests` request fails in HTTPX auth flow * Update test for `send_each_for_multicast_async()` * Address review comments * fix lint and some types * Address review comments and removed unused code * Update metric header test logic for `TestHttpxAsyncClient` * Add `send_each_for_multicast_async` to `__all__` * Apply suggestions from TW review --- firebase_admin/_http_client.py | 214 ++++++++++- firebase_admin/_retry.py | 223 +++++++++++ firebase_admin/_utils.py | 86 +++++ firebase_admin/messaging.py | 188 +++++++-- integration/conftest.py | 15 +- integration/test_messaging.py | 65 ++++ requirements.txt | 4 +- setup.py | 1 + tests/test_http_client.py | 683 ++++++++++++++++++++++++++++----- tests/test_messaging.py | 198 ++++++++++ tests/test_retry.py | 454 ++++++++++++++++++++++ 11 files changed, 1990 insertions(+), 141 deletions(-) create mode 100644 firebase_admin/_retry.py create mode 100644 tests/test_retry.py diff --git a/firebase_admin/_http_client.py b/firebase_admin/_http_client.py index 57c09e2e4..6d2582291 100644 --- a/firebase_admin/_http_client.py +++ b/firebase_admin/_http_client.py @@ -14,14 +14,23 @@ """Internal HTTP client module. - This module provides utilities for making HTTP calls using the requests library. - """ - -from google.auth import transport -import requests +This module provides utilities for making HTTP calls using the requests library. +""" + +from __future__ import annotations +import logging +from typing import Any, Dict, Generator, Optional, Tuple, Union +import httpx +import requests.adapters from requests.packages.urllib3.util import retry # pylint: disable=import-error +from google.auth import credentials +from google.auth import transport +from google.auth.transport import requests as google_auth_requests from firebase_admin import _utils +from firebase_admin._retry import HttpxRetry, HttpxRetryTransport + +logger = logging.getLogger(__name__) if hasattr(retry.Retry.DEFAULT, 'allowed_methods'): _ANY_METHOD = {'allowed_methods': None} @@ -34,6 +43,9 @@ connect=1, read=1, status=4, status_forcelist=[500, 503], raise_on_status=False, backoff_factor=0.5, **_ANY_METHOD) +DEFAULT_HTTPX_RETRY_CONFIG = HttpxRetry( + max_retries=4, status_forcelist=[500, 503], backoff_factor=0.5) + DEFAULT_TIMEOUT_SECONDS = 120 @@ -144,7 +156,6 @@ def close(self): self._session.close() self._session = None - class JsonHttpClient(HttpClient): """An HTTP client that parses response messages as JSON.""" @@ -153,3 +164,194 @@ def __init__(self, **kwargs): def parse_body(self, resp): return resp.json() + +class GoogleAuthCredentialFlow(httpx.Auth): + """Google Auth Credential Auth Flow""" + def __init__(self, credential: credentials.Credentials): + self._credential = credential + self._max_refresh_attempts = 2 + self._refresh_status_codes = (401,) + + def apply_auth_headers( + self, + request: httpx.Request, + auth_request: google_auth_requests.Request + ) -> None: + """A helper function that refreshes credentials if needed and mutates the request headers + to contain access token and any other Google Auth headers.""" + + logger.debug( + 'Attempting to apply auth headers. Credential validity before: %s', + self._credential.valid + ) + self._credential.before_request( + auth_request, request.method, str(request.url), request.headers + ) + logger.debug('Auth headers applied. Credential validity after: %s', self._credential.valid) + + def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Response, None]: + _original_headers = request.headers.copy() + _credential_refresh_attempt = 0 + + # Create a Google auth request object to be used for refreshing credentials + auth_request = google_auth_requests.Request() + + while True: + # Copy original headers for each attempt + request.headers = _original_headers.copy() + + # Apply auth headers (which might include an implicit refresh if token is expired) + self.apply_auth_headers(request, auth_request) + + logger.debug( + 'Dispatching request, attempt %d of %d', + _credential_refresh_attempt, self._max_refresh_attempts + ) + response: httpx.Response = yield request + + if response.status_code in self._refresh_status_codes: + if _credential_refresh_attempt < self._max_refresh_attempts: + logger.debug( + 'Received status %d. Attempting explicit credential refresh. \ + Attempt %d of %d.', + response.status_code, + _credential_refresh_attempt + 1, + self._max_refresh_attempts + ) + # Explicitly force a credentials refresh + self._credential.refresh(auth_request) + _credential_refresh_attempt += 1 + else: + logger.debug( + 'Received status %d, but max auth refresh attempts (%d) reached. \ + Returning last response.', + response.status_code, self._max_refresh_attempts + ) + break + else: + # Status code is not one that requires a refresh, so break and return response + logger.debug( + 'Status code %d does not require refresh. Returning response.', + response.status_code + ) + break + # The last yielded response is automatically returned by httpx's auth flow. + +class HttpxAsyncClient(): + """Async HTTP client used to make HTTP/2 calls using HTTPX. + + HttpxAsyncClient maintains an async HTTPX client, handles request authentication, and retries + if necessary. + """ + def __init__( + self, + credential: Optional[credentials.Credentials] = None, + base_url: str = '', + headers: Optional[Union[httpx.Headers, Dict[str, str]]] = None, + retry_config: HttpxRetry = DEFAULT_HTTPX_RETRY_CONFIG, + timeout: int = DEFAULT_TIMEOUT_SECONDS, + http2: bool = True + ) -> None: + """Creates a new HttpxAsyncClient instance from the provided arguments. + + If a credential is provided, initializes a new async HTTPX client authorized with it. + Otherwise, initializes a new unauthorized async HTTPX client. + + Args: + credential: A Google credential that can be used to authenticate requests (optional). + base_url: A URL prefix to be added to all outgoing requests (optional). + headers: A map of headers to be added to all outgoing requests (optional). + retry_config: A HttpxRetry configuration. Default settings would retry up to 4 times for + HTTP 500 and 503 errors (optional). + timeout: HTTP timeout in seconds. Defaults to 120 seconds when not specified (optional). + http2: A boolean indicating if HTTP/2 support should be enabled. Defaults to `True` when + not specified (optional). + """ + self._base_url = base_url + self._timeout = timeout + self._headers = {**headers, **METRICS_HEADERS} if headers else {**METRICS_HEADERS} + self._retry_config = retry_config + + # Only set up retries on urls starting with 'http://' and 'https://' + self._mounts = { + 'http://': HttpxRetryTransport(retry=self._retry_config, http2=http2), + 'https://': HttpxRetryTransport(retry=self._retry_config, http2=http2) + } + + if credential: + self._async_client = httpx.AsyncClient( + http2=http2, + timeout=self._timeout, + headers=self._headers, + auth=GoogleAuthCredentialFlow(credential), # Add auth flow for credentials. + mounts=self._mounts + ) + else: + self._async_client = httpx.AsyncClient( + http2=http2, + timeout=self._timeout, + headers=self._headers, + mounts=self._mounts + ) + + @property + def base_url(self): + return self._base_url + + @property + def timeout(self): + return self._timeout + + @property + def async_client(self): + return self._async_client + + async def request(self, method: str, url: str, **kwargs: Any) -> httpx.Response: + """Makes an HTTP call using the HTTPX library. + + This is the sole entry point to the HTTPX library. All other helper methods in this + class call this method to send HTTP requests out. Refer to + https://www.python-httpx.org/api/ for more information on supported options + and features. + + Args: + method: HTTP method name as a string (e.g. get, post). + url: URL of the remote endpoint. + **kwargs: An additional set of keyword arguments to be passed into the HTTPX API + (e.g. json, params, timeout). + + Returns: + Response: An HTTPX response object. + + Raises: + HTTPError: Any HTTPX exceptions encountered while making the HTTP call. + RequestException: Any requests exceptions encountered while making the HTTP call. + """ + if 'timeout' not in kwargs: + kwargs['timeout'] = self.timeout + resp = await self._async_client.request(method, self.base_url + url, **kwargs) + return resp.raise_for_status() + + async def headers(self, method: str, url: str, **kwargs: Any) -> httpx.Headers: + resp = await self.request(method, url, **kwargs) + return resp.headers + + async def body_and_response( + self, method: str, url: str, **kwargs: Any) -> Tuple[Any, httpx.Response]: + resp = await self.request(method, url, **kwargs) + return self.parse_body(resp), resp + + async def body(self, method: str, url: str, **kwargs: Any) -> Any: + resp = await self.request(method, url, **kwargs) + return self.parse_body(resp) + + async def headers_and_body( + self, method: str, url: str, **kwargs: Any) -> Tuple[httpx.Headers, Any]: + resp = await self.request(method, url, **kwargs) + return resp.headers, self.parse_body(resp) + + def parse_body(self, resp: httpx.Response) -> Any: + return resp.json() + + async def aclose(self) -> None: + await self._async_client.aclose() diff --git a/firebase_admin/_retry.py b/firebase_admin/_retry.py new file mode 100644 index 000000000..efd90a743 --- /dev/null +++ b/firebase_admin/_retry.py @@ -0,0 +1,223 @@ +# Copyright 2025 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Internal retry logic module + +This module provides utilities for adding retry logic to HTTPX requests +""" + +from __future__ import annotations +import copy +import email.utils +import random +import re +import time +from typing import Any, Callable, List, Optional, Tuple, Coroutine +import logging +import asyncio +import httpx + +logger = logging.getLogger(__name__) + + +class HttpxRetry: + """HTTPX based retry config""" + # Status codes to be used for respecting `Retry-After` header + RETRY_AFTER_STATUS_CODES = frozenset([413, 429, 503]) + + # Default maximum backoff time. + DEFAULT_BACKOFF_MAX = 120 + + def __init__( + self, + max_retries: int = 10, + status_forcelist: Optional[List[int]] = None, + backoff_factor: float = 0, + backoff_max: float = DEFAULT_BACKOFF_MAX, + backoff_jitter: float = 0, + history: Optional[List[Tuple[ + httpx.Request, + Optional[httpx.Response], + Optional[Exception] + ]]] = None, + respect_retry_after_header: bool = False, + ) -> None: + self.retries_left = max_retries + self.status_forcelist = status_forcelist + self.backoff_factor = backoff_factor + self.backoff_max = backoff_max + self.backoff_jitter = backoff_jitter + if history: + self.history = history + else: + self.history = [] + self.respect_retry_after_header = respect_retry_after_header + + def copy(self) -> HttpxRetry: + """Creates a deep copy of this instance.""" + return copy.deepcopy(self) + + def is_retryable_response(self, response: httpx.Response) -> bool: + """Determine if a response implies that the request should be retried if possible.""" + if self.status_forcelist and response.status_code in self.status_forcelist: + return True + + has_retry_after = bool(response.headers.get("Retry-After")) + if ( + self.respect_retry_after_header + and has_retry_after + and response.status_code in self.RETRY_AFTER_STATUS_CODES + ): + return True + + return False + + def is_exhausted(self) -> bool: + """Determine if there are anymore more retires.""" + # retries_left is negative + return self.retries_left < 0 + + # Identical implementation of `urllib3.Retry.parse_retry_after()` + def _parse_retry_after(self, retry_after_header: str) -> float | None: + """Parses Retry-After string into a float with unit seconds.""" + seconds: float + # Whitespace: https://tools.ietf.org/html/rfc7230#section-3.2.4 + if re.match(r"^\s*[0-9]+\s*$", retry_after_header): + seconds = int(retry_after_header) + else: + retry_date_tuple = email.utils.parsedate_tz(retry_after_header) + if retry_date_tuple is None: + raise httpx.RemoteProtocolError(f"Invalid Retry-After header: {retry_after_header}") + + retry_date = email.utils.mktime_tz(retry_date_tuple) + seconds = retry_date - time.time() + + seconds = max(seconds, 0) + + return seconds + + def get_retry_after(self, response: httpx.Response) -> float | None: + """Determine the Retry-After time needed before sending the next request.""" + retry_after_header = response.headers.get('Retry-After', None) + if retry_after_header: + # Convert retry header to a float in seconds + return self._parse_retry_after(retry_after_header) + return None + + def get_backoff_time(self): + """Determine the backoff time needed before sending the next request.""" + # attempt_count is the number of previous request attempts + attempt_count = len(self.history) + # Backoff should be set to 0 until after first retry. + if attempt_count <= 1: + return 0 + backoff = self.backoff_factor * (2 ** (attempt_count-1)) + if self.backoff_jitter: + backoff += random.random() * self.backoff_jitter + return float(max(0, min(self.backoff_max, backoff))) + + async def sleep_for_backoff(self) -> None: + """Determine and wait the backoff time needed before sending the next request.""" + backoff = self.get_backoff_time() + logger.debug('Sleeping for backoff of %f seconds following failed request', backoff) + await asyncio.sleep(backoff) + + async def sleep(self, response: httpx.Response) -> None: + """Determine and wait the time needed before sending the next request.""" + if self.respect_retry_after_header: + retry_after = self.get_retry_after(response) + if retry_after: + logger.debug( + 'Sleeping for Retry-After header of %f seconds following failed request', + retry_after + ) + await asyncio.sleep(retry_after) + return + await self.sleep_for_backoff() + + def increment( + self, + request: httpx.Request, + response: Optional[httpx.Response] = None, + error: Optional[Exception] = None + ) -> None: + """Update the retry state based on request attempt.""" + self.retries_left -= 1 + self.history.append((request, response, error)) + + +class HttpxRetryTransport(httpx.AsyncBaseTransport): + """HTTPX transport with retry logic.""" + + DEFAULT_RETRY = HttpxRetry(max_retries=4, status_forcelist=[500, 503], backoff_factor=0.5) + + def __init__(self, retry: HttpxRetry = DEFAULT_RETRY, **kwargs: Any) -> None: + self._retry = retry + + transport_kwargs = kwargs.copy() + transport_kwargs.update({'retries': 0, 'http2': True}) + # We use a full AsyncHTTPTransport under the hood that is already + # set up to handle requests. We also insure that that transport's internal + # retries are not allowed. + self._wrapped_transport = httpx.AsyncHTTPTransport(**transport_kwargs) + + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + return await self._dispatch_with_retry( + request, self._wrapped_transport.handle_async_request) + + async def _dispatch_with_retry( + self, + request: httpx.Request, + dispatch_method: Callable[[httpx.Request], Coroutine[Any, Any, httpx.Response]] + ) -> httpx.Response: + """Sends a request with retry logic using a provided dispatch method.""" + # This request config is used across all requests that use this transport and therefore + # needs to be copied to be used for just this request and it's retries. + retry = self._retry.copy() + # First request + response, error = None, None + + while not retry.is_exhausted(): + + # First retry + if response: + await retry.sleep(response) + + # Need to reset here so only last attempt's error or response is saved. + response, error = None, None + + try: + logger.debug('Sending request in _dispatch_with_retry(): %r', request) + response = await dispatch_method(request) + logger.debug('Received response: %r', response) + except httpx.HTTPError as err: + logger.debug('Received error: %r', err) + error = err + + if response and not retry.is_retryable_response(response): + return response + + if error: + raise error + + retry.increment(request, response, error) + + if response: + return response + if error: + raise error + raise AssertionError('_dispatch_with_retry() ended with no response or exception') + + async def aclose(self) -> None: + await self._wrapped_transport.aclose() diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index b6e292546..765d11587 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -16,9 +16,11 @@ import json from platform import python_version +from typing import Callable, Optional import google.auth import requests +import httpx import firebase_admin from firebase_admin import exceptions @@ -128,6 +130,36 @@ def handle_platform_error_from_requests(error, handle_func=None): return exc if exc else _handle_func_requests(error, message, error_dict) +def handle_platform_error_from_httpx( + error: httpx.HTTPError, + handle_func: Optional[Callable[..., Optional[exceptions.FirebaseError]]] = None +) -> exceptions.FirebaseError: + """Constructs a ``FirebaseError`` from the given httpx error. + + This can be used to handle errors returned by Google Cloud Platform (GCP) APIs. + + Args: + error: An error raised by the httpx module while making an HTTP call to a GCP API. + handle_func: A function that can be used to handle platform errors in a custom way. When + specified, this function will be called with three arguments. It has the same + signature as ```_handle_func_httpx``, but may return ``None``. + + Returns: + FirebaseError: A ``FirebaseError`` that can be raised to the user code. + """ + + if isinstance(error, httpx.HTTPStatusError): + response = error.response + content = response.content.decode() + status_code = response.status_code + error_dict, message = _parse_platform_error(content, status_code) + exc = None + if handle_func: + exc = handle_func(error, message, error_dict) + + return exc if exc else _handle_func_httpx(error, message, error_dict) + return handle_httpx_error(error) + def handle_operation_error(error): """Constructs a ``FirebaseError`` from the given operation error. @@ -204,6 +236,60 @@ def handle_requests_error(error, message=None, code=None): err_type = _error_code_to_exception_type(code) return err_type(message=message, cause=error, http_response=error.response) +def _handle_func_httpx(error: httpx.HTTPError, message, error_dict) -> exceptions.FirebaseError: + """Constructs a ``FirebaseError`` from the given GCP error. + + Args: + error: An error raised by the httpx module while making an HTTP call. + message: A message to be included in the resulting ``FirebaseError``. + error_dict: Parsed GCP error response. + + Returns: + FirebaseError: A ``FirebaseError`` that can be raised to the user code or None. + """ + code = error_dict.get('status') + return handle_httpx_error(error, message, code) + + +def handle_httpx_error(error: httpx.HTTPError, message=None, code=None) -> exceptions.FirebaseError: + """Constructs a ``FirebaseError`` from the given httpx error. + + This method is agnostic of the remote service that produced the error, whether it is a GCP + service or otherwise. Therefore, this method does not attempt to parse the error response in + any way. + + Args: + error: An error raised by the httpx module while making an HTTP call. + message: A message to be included in the resulting ``FirebaseError`` (optional). If not + specified the string representation of the ``error`` argument is used as the message. + code: A GCP error code that will be used to determine the resulting error type (optional). + If not specified the HTTP status code on the error response is used to determine a + suitable error code. + + Returns: + FirebaseError: A ``FirebaseError`` that can be raised to the user code. + """ + if isinstance(error, httpx.TimeoutException): + return exceptions.DeadlineExceededError( + message='Timed out while making an API call: {0}'.format(error), + cause=error) + if isinstance(error, httpx.ConnectError): + return exceptions.UnavailableError( + message='Failed to establish a connection: {0}'.format(error), + cause=error) + if isinstance(error, httpx.HTTPStatusError): + print("printing status error", error) + if not code: + code = _http_status_to_error_code(error.response.status_code) + if not message: + message = str(error) + + err_type = _error_code_to_exception_type(code) + return err_type(message=message, cause=error, http_response=error.response) + + return exceptions.UnknownError( + message='Unknown error while making a remote service call: {0}'.format(error), + cause=error) def _http_status_to_error_code(status): """Maps an HTTP status to a platform error code.""" diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index c2870eac7..99dc93a67 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -14,22 +14,31 @@ """Firebase Cloud Messaging module.""" +from __future__ import annotations +from typing import Any, Callable, Dict, List, Optional, cast import concurrent.futures import json import warnings +import asyncio +import logging import requests +import httpx from googleapiclient import http from googleapiclient import _auth import firebase_admin -from firebase_admin import _http_client -from firebase_admin import _messaging_encoder -from firebase_admin import _messaging_utils -from firebase_admin import _gapic_utils -from firebase_admin import _utils -from firebase_admin import exceptions - +from firebase_admin import ( + _http_client, + _messaging_encoder, + _messaging_utils, + _gapic_utils, + _utils, + exceptions, + App +) + +logger = logging.getLogger(__name__) _MESSAGING_ATTRIBUTE = '_messaging' @@ -66,7 +75,9 @@ 'send_all', 'send_multicast', 'send_each', + 'send_each_async', 'send_each_for_multicast', + 'send_each_for_multicast_async', 'subscribe_to_topic', 'unsubscribe_from_topic', ] @@ -97,14 +108,14 @@ UnregisteredError = _messaging_utils.UnregisteredError -def _get_messaging_service(app): +def _get_messaging_service(app: Optional[App]) -> _MessagingService: return _utils.get_app_service(app, _MESSAGING_ATTRIBUTE, _MessagingService) -def send(message, dry_run=False, app=None): +def send(message: Message, dry_run: bool = False, app: Optional[App] = None) -> str: """Sends the given message via Firebase Cloud Messaging (FCM). If the ``dry_run`` mode is enabled, the message will not be actually delivered to the - recipients. Instead FCM performs all the usual validations, and emulates the send operation. + recipients. Instead, FCM performs all the usual validations and emulates the send operation. Args: message: An instance of ``messaging.Message``. @@ -120,11 +131,15 @@ def send(message, dry_run=False, app=None): """ return _get_messaging_service(app).send(message, dry_run) -def send_each(messages, dry_run=False, app=None): +def send_each( + messages: List[Message], + dry_run: bool = False, + app: Optional[App] = None + ) -> BatchResponse: """Sends each message in the given list via Firebase Cloud Messaging. If the ``dry_run`` mode is enabled, the message will not be actually delivered to the - recipients. Instead FCM performs all the usual validations, and emulates the send operation. + recipients. Instead, FCM performs all the usual validations and emulates the send operation. Args: messages: A list of ``messaging.Message`` instances. @@ -140,11 +155,71 @@ def send_each(messages, dry_run=False, app=None): """ return _get_messaging_service(app).send_each(messages, dry_run) +async def send_each_async( + messages: List[Message], + dry_run: bool = False, + app: Optional[App] = None + ) -> BatchResponse: + """Sends each message in the given list asynchronously via Firebase Cloud Messaging. + + If the ``dry_run`` mode is enabled, the message will not be actually delivered to the + recipients. Instead, FCM performs all the usual validations and emulates the send operation. + + Args: + messages: A list of ``messaging.Message`` instances. + dry_run: A boolean indicating whether to run the operation in dry run mode (optional). + app: An App instance (optional). + + Returns: + BatchResponse: A ``messaging.BatchResponse`` instance. + + Raises: + FirebaseError: If an error occurs while sending the message to the FCM service. + ValueError: If the input arguments are invalid. + """ + return await _get_messaging_service(app).send_each_async(messages, dry_run) + +async def send_each_for_multicast_async( + multicast_message: MulticastMessage, + dry_run: bool = False, + app: Optional[App] = None + ) -> BatchResponse: + """Sends the given mutlicast message to each token asynchronously via Firebase Cloud Messaging + (FCM). + + If the ``dry_run`` mode is enabled, the message will not be actually delivered to the + recipients. Instead, FCM performs all the usual validations and emulates the send operation. + + Args: + multicast_message: An instance of ``messaging.MulticastMessage``. + dry_run: A boolean indicating whether to run the operation in dry run mode (optional). + app: An App instance (optional). + + Returns: + BatchResponse: A ``messaging.BatchResponse`` instance. + + Raises: + FirebaseError: If an error occurs while sending the message to the FCM service. + ValueError: If the input arguments are invalid. + """ + if not isinstance(multicast_message, MulticastMessage): + raise ValueError('Message must be an instance of messaging.MulticastMessage class.') + messages = [Message( + data=multicast_message.data, + notification=multicast_message.notification, + android=multicast_message.android, + webpush=multicast_message.webpush, + apns=multicast_message.apns, + fcm_options=multicast_message.fcm_options, + token=token + ) for token in multicast_message.tokens] + return await _get_messaging_service(app).send_each_async(messages, dry_run) + def send_each_for_multicast(multicast_message, dry_run=False, app=None): """Sends the given mutlicast message to each token via Firebase Cloud Messaging (FCM). If the ``dry_run`` mode is enabled, the message will not be actually delivered to the - recipients. Instead FCM performs all the usual validations, and emulates the send operation. + recipients. Instead, FCM performs all the usual validations and emulates the send operation. Args: multicast_message: An instance of ``messaging.MulticastMessage``. @@ -175,7 +250,7 @@ def send_all(messages, dry_run=False, app=None): """Sends the given list of messages via Firebase Cloud Messaging as a single batch. If the ``dry_run`` mode is enabled, the message will not be actually delivered to the - recipients. Instead FCM performs all the usual validations, and emulates the send operation. + recipients. Instead, FCM performs all the usual validations and emulates the send operation. Args: messages: A list of ``messaging.Message`` instances. @@ -198,7 +273,7 @@ def send_multicast(multicast_message, dry_run=False, app=None): """Sends the given mutlicast message to all tokens via Firebase Cloud Messaging (FCM). If the ``dry_run`` mode is enabled, the message will not be actually delivered to the - recipients. Instead FCM performs all the usual validations, and emulates the send operation. + recipients. Instead, FCM performs all the usual validations and emulates the send operation. Args: multicast_message: An instance of ``messaging.MulticastMessage``. @@ -321,21 +396,21 @@ def errors(self): class BatchResponse: """The response received from a batch request to the FCM API.""" - def __init__(self, responses): + def __init__(self, responses: List[SendResponse]) -> None: self._responses = responses self._success_count = sum(1 for resp in responses if resp.success) @property - def responses(self): + def responses(self) -> List[SendResponse]: """A list of ``messaging.SendResponse`` objects (possibly empty).""" return self._responses @property - def success_count(self): + def success_count(self) -> int: return self._success_count @property - def failure_count(self): + def failure_count(self) -> int: return len(self.responses) - self.success_count @@ -363,7 +438,6 @@ def exception(self): """A ``FirebaseError`` if an error occurs while sending the message to the FCM service.""" return self._exception - class _MessagingService: """Service class that implements Firebase Cloud Messaging (FCM) functionality.""" @@ -381,7 +455,7 @@ class _MessagingService: 'UNREGISTERED': UnregisteredError, } - def __init__(self, app): + def __init__(self, app: App) -> None: project_id = app.project_id if not project_id: raise ValueError( @@ -396,6 +470,8 @@ def __init__(self, app): timeout = app.options.get('httpTimeout', _http_client.DEFAULT_TIMEOUT_SECONDS) self._credential = app.credential.get_credential() self._client = _http_client.JsonHttpClient(credential=self._credential, timeout=timeout) + self._async_client = _http_client.HttpxAsyncClient( + credential=self._credential, timeout=timeout) self._build_transport = _auth.authorized_http @classmethod @@ -404,7 +480,7 @@ def encode_message(cls, message): raise ValueError('Message must be an instance of messaging.Message class.') return cls.JSON_ENCODER.default(message) - def send(self, message, dry_run=False): + def send(self, message: Message, dry_run: bool = False) -> str: """Sends the given message to FCM via the FCM v1 API.""" data = self._message_data(message, dry_run) try: @@ -417,9 +493,9 @@ def send(self, message, dry_run=False): except requests.exceptions.RequestException as error: raise self._handle_fcm_error(error) else: - return resp['name'] + return cast(str, resp['name']) - def send_each(self, messages, dry_run=False): + def send_each(self, messages: List[Message], dry_run: bool = False) -> BatchResponse: """Sends the given messages to FCM via the FCM v1 API.""" if not isinstance(messages, list): raise ValueError('messages must be a list of messaging.Message instances.') @@ -448,6 +524,38 @@ def send_data(data): message='Unknown error while making remote service calls: {0}'.format(error), cause=error) + async def send_each_async(self, messages: List[Message], dry_run: bool = True) -> BatchResponse: + """Sends the given messages to FCM via the FCM v1 API.""" + if not isinstance(messages, list): + raise ValueError('messages must be a list of messaging.Message instances.') + if len(messages) > 500: + raise ValueError('messages must not contain more than 500 elements.') + + async def send_data(data): + try: + resp = await self._async_client.request( + 'post', + url=self._fcm_url, + headers=self._fcm_headers, + json=data) + except httpx.HTTPError as exception: + return SendResponse(resp=None, exception=self._handle_fcm_httpx_error(exception)) + # Catch errors caused by the requests library during authorization + except requests.exceptions.RequestException as exception: + return SendResponse(resp=None, exception=self._handle_fcm_error(exception)) + else: + return SendResponse(resp.json(), exception=None) + + message_data = [self._message_data(message, dry_run) for message in messages] + try: + responses = await asyncio.gather(*[send_data(message) for message in message_data]) + return BatchResponse(responses) + except Exception as error: + raise exceptions.UnknownError( + message='Unknown error while making remote service calls: {0}'.format(error), + cause=error) + + def send_all(self, messages, dry_run=False): """Sends the given messages to FCM via the batch API.""" if not isinstance(messages, list): @@ -533,6 +641,11 @@ def _handle_fcm_error(self, error): return _utils.handle_platform_error_from_requests( error, _MessagingService._build_fcm_error_requests) + def _handle_fcm_httpx_error(self, error: httpx.HTTPError) -> exceptions.FirebaseError: + """Handles errors received from the FCM API.""" + return _utils.handle_platform_error_from_httpx( + error, _MessagingService._build_fcm_error_httpx) + def _handle_iid_error(self, error): """Handles errors received from the Instance ID API.""" if error.response is None: @@ -562,6 +675,9 @@ def _handle_batch_error(self, error): return _gapic_utils.handle_platform_error_from_googleapiclient( error, _MessagingService._build_fcm_error_googleapiclient) + def close(self) -> None: + asyncio.run(self._async_client.aclose()) + @classmethod def _build_fcm_error_requests(cls, error, message, error_dict): """Parses an error response from the FCM API and creates a FCM-specific exception if @@ -569,6 +685,22 @@ def _build_fcm_error_requests(cls, error, message, error_dict): exc_type = cls._build_fcm_error(error_dict) return exc_type(message, cause=error, http_response=error.response) if exc_type else None + @classmethod + def _build_fcm_error_httpx( + cls, + error: httpx.HTTPError, + message: str, + error_dict: Optional[Dict[str, Any]] + ) -> Optional[exceptions.FirebaseError]: + """Parses a httpx error response from the FCM API and creates a FCM-specific exception if + appropriate.""" + exc_type = cls._build_fcm_error(error_dict) + if isinstance(error, httpx.HTTPStatusError): + return exc_type( + message, cause=error, http_response=error.response) if exc_type else None + return exc_type(message, cause=error) if exc_type else None + + @classmethod def _build_fcm_error_googleapiclient(cls, error, message, error_dict, http_response): """Parses an error response from the FCM API and creates a FCM-specific exception if @@ -577,7 +709,11 @@ def _build_fcm_error_googleapiclient(cls, error, message, error_dict, http_respo return exc_type(message, cause=error, http_response=http_response) if exc_type else None @classmethod - def _build_fcm_error(cls, error_dict): + def _build_fcm_error( + cls, + error_dict: Optional[Dict[str, Any]] + ) -> Optional[Callable[..., exceptions.FirebaseError]]: + """Parses an error response to determine the appropriate FCM-specific error type.""" if not error_dict: return None fcm_code = None @@ -585,4 +721,4 @@ def _build_fcm_error(cls, error_dict): if detail.get('@type') == 'type.googleapis.com/google.firebase.fcm.v1.FcmError': fcm_code = detail.get('errorCode') break - return _MessagingService.FCM_ERROR_TYPES.get(fcm_code) + return _MessagingService.FCM_ERROR_TYPES.get(fcm_code) if fcm_code else None diff --git a/integration/conftest.py b/integration/conftest.py index 71f53f612..efa45932d 100644 --- a/integration/conftest.py +++ b/integration/conftest.py @@ -15,8 +15,8 @@ """pytest configuration and global fixtures for integration tests.""" import json -import asyncio import pytest +from pytest_asyncio import is_async_test import firebase_admin from firebase_admin import credentials @@ -72,11 +72,8 @@ def api_key(request): with open(path) as keyfile: return keyfile.read().strip() -@pytest.fixture(scope="session") -def event_loop(): - """Create an instance of the default event loop for test session. - This avoids early eventloop closure. - """ - loop = asyncio.get_event_loop_policy().new_event_loop() - yield loop - loop.close() +def pytest_collection_modifyitems(items): + pytest_asyncio_tests = (item for item in items if is_async_test(item)) + session_scope_marker = pytest.mark.asyncio(loop_scope="session") + for async_test in pytest_asyncio_tests: + async_test.add_marker(session_scope_marker, append=False) diff --git a/integration/test_messaging.py b/integration/test_messaging.py index 4c1d7d0dc..296a4d338 100644 --- a/integration/test_messaging.py +++ b/integration/test_messaging.py @@ -221,3 +221,68 @@ def test_subscribe(): def test_unsubscribe(): resp = messaging.unsubscribe_from_topic(_REGISTRATION_TOKEN, 'mock-topic') assert resp.success_count + resp.failure_count == 1 + +@pytest.mark.asyncio +async def test_send_each_async(): + messages = [ + messaging.Message( + topic='foo-bar', notification=messaging.Notification('Title', 'Body')), + messaging.Message( + topic='foo-bar', notification=messaging.Notification('Title', 'Body')), + messaging.Message( + token='not-a-token', notification=messaging.Notification('Title', 'Body')), + ] + + batch_response = await messaging.send_each_async(messages, dry_run=True) + + assert batch_response.success_count == 2 + assert batch_response.failure_count == 1 + assert len(batch_response.responses) == 3 + + response = batch_response.responses[0] + assert response.success is True + assert response.exception is None + assert re.match('^projects/.*/messages/.*$', response.message_id) + + response = batch_response.responses[1] + assert response.success is True + assert response.exception is None + assert re.match('^projects/.*/messages/.*$', response.message_id) + + response = batch_response.responses[2] + assert response.success is False + assert isinstance(response.exception, exceptions.InvalidArgumentError) + assert response.message_id is None + +@pytest.mark.asyncio +async def test_send_each_async_500(): + messages = [] + for msg_number in range(500): + topic = 'foo-bar-{0}'.format(msg_number % 10) + messages.append(messaging.Message(topic=topic)) + + batch_response = await messaging.send_each_async(messages, dry_run=True) + + assert batch_response.success_count == 500 + assert batch_response.failure_count == 0 + assert len(batch_response.responses) == 500 + for response in batch_response.responses: + assert response.success is True + assert response.exception is None + assert re.match('^projects/.*/messages/.*$', response.message_id) + +@pytest.mark.asyncio +async def test_send_each_for_multicast_async(): + multicast = messaging.MulticastMessage( + notification=messaging.Notification('Title', 'Body'), + tokens=['not-a-token', 'also-not-a-token']) + + batch_response = await messaging.send_each_for_multicast_async(multicast) + + assert batch_response.success_count == 0 + assert batch_response.failure_count == 2 + assert len(batch_response.responses) == 2 + for response in batch_response.responses: + assert response.success is False + assert response.exception is not None + assert response.message_id is None diff --git a/requirements.txt b/requirements.txt index fd5b0b39c..ba6f2f947 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,10 +5,12 @@ pytest-cov >= 2.4.0 pytest-localserver >= 0.4.1 pytest-asyncio >= 0.16.0 pytest-mock >= 3.6.1 +respx == 0.22.0 cachecontrol >= 0.12.14 google-api-core[grpc] >= 1.22.1, < 3.0.0dev; platform.python_implementation != 'PyPy' google-api-python-client >= 1.7.8 google-cloud-firestore >= 2.19.0; platform.python_implementation != 'PyPy' google-cloud-storage >= 1.37.1 -pyjwt[crypto] >= 2.5.0 \ No newline at end of file +pyjwt[crypto] >= 2.5.0 +httpx[http2] == 0.28.1 \ No newline at end of file diff --git a/setup.py b/setup.py index 23be6d481..e92d207aa 100644 --- a/setup.py +++ b/setup.py @@ -43,6 +43,7 @@ 'google-cloud-firestore>=2.19.0; platform.python_implementation != "PyPy"', 'google-cloud-storage>=1.37.1', 'pyjwt[crypto] >= 2.5.0', + 'httpx[http2] == 0.28.1', ] setup( diff --git a/tests/test_http_client.py b/tests/test_http_client.py index 78036166c..f1e7f6a64 100644 --- a/tests/test_http_client.py +++ b/tests/test_http_client.py @@ -13,116 +13,131 @@ # limitations under the License. """Tests for firebase_admin._http_client.""" +from typing import Dict, Optional, Union import pytest +import httpx +import respx from pytest_localserver import http +from pytest_mock import MockerFixture import requests from firebase_admin import _http_client, _utils +from firebase_admin._retry import HttpxRetry, HttpxRetryTransport +from firebase_admin._http_client import ( + HttpxAsyncClient, + GoogleAuthCredentialFlow, + DEFAULT_TIMEOUT_SECONDS +) from tests import testutils _TEST_URL = 'http://firebase.test.url/' +@pytest.fixture +def default_retry_config() -> HttpxRetry: + """Provides a fresh copy of the default retry config instance.""" + return _http_client.DEFAULT_HTTPX_RETRY_CONFIG -def test_http_client_default_session(): - client = _http_client.HttpClient() - assert client.session is not None - assert client.base_url == '' - recorder = _instrument(client, 'body') - resp = client.request('get', _TEST_URL) - assert resp.status_code == 200 - assert resp.text == 'body' - assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == _TEST_URL - -def test_http_client_custom_session(): - session = requests.Session() - client = _http_client.HttpClient(session=session) - assert client.session is session - assert client.base_url == '' - recorder = _instrument(client, 'body') - resp = client.request('get', _TEST_URL) - assert resp.status_code == 200 - assert resp.text == 'body' - assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == _TEST_URL - -def test_base_url(): - client = _http_client.HttpClient(base_url=_TEST_URL) - assert client.session is not None - assert client.base_url == _TEST_URL - recorder = _instrument(client, 'body') - resp = client.request('get', 'foo') - assert resp.status_code == 200 - assert resp.text == 'body' - assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == _TEST_URL + 'foo' - -def test_metrics_headers(): - client = _http_client.HttpClient() - assert client.session is not None - recorder = _instrument(client, 'body') - resp = client.request('get', _TEST_URL) - assert resp.status_code == 200 - assert resp.text == 'body' - assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == _TEST_URL - assert recorder[0].headers['x-goog-api-client'] == _utils.get_metrics_header() - -def test_metrics_headers_with_credentials(): - client = _http_client.HttpClient( - credential=testutils.MockGoogleCredential()) - assert client.session is not None - recorder = _instrument(client, 'body') - resp = client.request('get', _TEST_URL) - assert resp.status_code == 200 - assert resp.text == 'body' - assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == _TEST_URL - expected_metrics_header = _utils.get_metrics_header() + ' mock-cred-metric-tag' - assert recorder[0].headers['x-goog-api-client'] == expected_metrics_header - -def test_credential(): - client = _http_client.HttpClient( - credential=testutils.MockGoogleCredential()) - assert client.session is not None - recorder = _instrument(client, 'body') - resp = client.request('get', _TEST_URL) - assert resp.status_code == 200 - assert resp.text == 'body' - assert len(recorder) == 1 - assert recorder[0].method == 'GET' - assert recorder[0].url == _TEST_URL - assert recorder[0].headers['Authorization'] == 'Bearer mock-token' - -@pytest.mark.parametrize('options, timeout', [ - ({}, _http_client.DEFAULT_TIMEOUT_SECONDS), - ({'timeout': 7}, 7), - ({'timeout': 0}, 0), - ({'timeout': None}, None), -]) -def test_timeout(options, timeout): - client = _http_client.HttpClient(**options) - assert client.timeout == timeout - recorder = _instrument(client, 'body') - client.request('get', _TEST_URL) - assert len(recorder) == 1 - if timeout is None: - assert recorder[0]._extra_kwargs['timeout'] is None - else: - assert recorder[0]._extra_kwargs['timeout'] == pytest.approx(timeout, 0.001) - - -def _instrument(client, payload, status=200): - recorder = [] - adapter = testutils.MockAdapter(payload, status, recorder) - client.session.mount(_TEST_URL, adapter) - return recorder +class TestHttpClient: + def test_http_client_default_session(self): + client = _http_client.HttpClient() + assert client.session is not None + assert client.base_url == '' + recorder = self._instrument(client, 'body') + resp = client.request('get', _TEST_URL) + assert resp.status_code == 200 + assert resp.text == 'body' + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == _TEST_URL + + def test_http_client_custom_session(self): + session = requests.Session() + client = _http_client.HttpClient(session=session) + assert client.session is session + assert client.base_url == '' + recorder = self._instrument(client, 'body') + resp = client.request('get', _TEST_URL) + assert resp.status_code == 200 + assert resp.text == 'body' + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == _TEST_URL + + def test_base_url(self): + client = _http_client.HttpClient(base_url=_TEST_URL) + assert client.session is not None + assert client.base_url == _TEST_URL + recorder = self._instrument(client, 'body') + resp = client.request('get', 'foo') + assert resp.status_code == 200 + assert resp.text == 'body' + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == _TEST_URL + 'foo' + + def test_metrics_headers(self): + client = _http_client.HttpClient() + assert client.session is not None + recorder = self._instrument(client, 'body') + resp = client.request('get', _TEST_URL) + assert resp.status_code == 200 + assert resp.text == 'body' + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == _TEST_URL + assert recorder[0].headers['x-goog-api-client'] == _utils.get_metrics_header() + + def test_metrics_headers_with_credentials(self): + client = _http_client.HttpClient( + credential=testutils.MockGoogleCredential()) + assert client.session is not None + recorder = self._instrument(client, 'body') + resp = client.request('get', _TEST_URL) + assert resp.status_code == 200 + assert resp.text == 'body' + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == _TEST_URL + expected_metrics_header = _utils.get_metrics_header() + ' mock-cred-metric-tag' + assert recorder[0].headers['x-goog-api-client'] == expected_metrics_header + + def test_credential(self): + client = _http_client.HttpClient( + credential=testutils.MockGoogleCredential()) + assert client.session is not None + recorder = self._instrument(client, 'body') + resp = client.request('get', _TEST_URL) + assert resp.status_code == 200 + assert resp.text == 'body' + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == _TEST_URL + assert recorder[0].headers['Authorization'] == 'Bearer mock-token' + + @pytest.mark.parametrize('options, timeout', [ + ({}, _http_client.DEFAULT_TIMEOUT_SECONDS), + ({'timeout': 7}, 7), + ({'timeout': 0}, 0), + ({'timeout': None}, None), + ]) + def test_timeout(self, options, timeout): + client = _http_client.HttpClient(**options) + assert client.timeout == timeout + recorder = self._instrument(client, 'body') + client.request('get', _TEST_URL) + assert len(recorder) == 1 + if timeout is None: + assert recorder[0]._extra_kwargs['timeout'] is None + else: + assert recorder[0]._extra_kwargs['timeout'] == pytest.approx(timeout, 0.001) + + + def _instrument(self, client, payload, status=200): + recorder = [] + adapter = testutils.MockAdapter(payload, status, recorder) + client.session.mount(_TEST_URL, adapter) + return recorder class TestHttpRetry: @@ -183,3 +198,473 @@ def test_no_retry_on_404(self): client.request('get', '/') assert excinfo.value.response.status_code == 404 assert len(self.httpserver.requests) == 1 + +class TestHttpxAsyncClient: + def test_init_default(self, mocker: MockerFixture, default_retry_config: HttpxRetry): + """Test client initialization with default settings (no credentials).""" + + # Mock httpx.AsyncClient and HttpxRetryTransport init to check args passed to them + mock_async_client_init = mocker.patch('httpx.AsyncClient.__init__', return_value=None) + mock_transport_init = mocker.patch( + 'firebase_admin._retry.HttpxRetryTransport.__init__', return_value=None + ) + + client = HttpxAsyncClient() + + assert client.base_url == '' + assert client.timeout == DEFAULT_TIMEOUT_SECONDS + assert client._headers == _http_client.METRICS_HEADERS + assert client._retry_config == default_retry_config + + # Check httpx.AsyncClient call args + _, init_kwargs = mock_async_client_init.call_args + assert init_kwargs.get('http2') is True + assert init_kwargs.get('timeout') == DEFAULT_TIMEOUT_SECONDS + assert init_kwargs.get('headers') == _http_client.METRICS_HEADERS + assert init_kwargs.get('auth') is None + assert 'mounts' in init_kwargs + assert 'http://' in init_kwargs['mounts'] + assert 'https://' in init_kwargs['mounts'] + assert isinstance(init_kwargs['mounts']['http://'], HttpxRetryTransport) + assert isinstance(init_kwargs['mounts']['https://'], HttpxRetryTransport) + + # Check that HttpxRetryTransport was initialized with the default retry config + assert mock_transport_init.call_count >= 1 + _, transport_call_kwargs = mock_transport_init.call_args_list[0] + assert transport_call_kwargs.get('retry') == default_retry_config + assert transport_call_kwargs.get('http2') is True + + def test_init_with_credentials(self, mocker: MockerFixture, default_retry_config: HttpxRetry): + """Test client initialization with credentials.""" + + # Mock GoogleAuthCredentialFlow, httpx.AsyncClient and HttpxRetryTransport init to + # check args passed to them + mock_auth_flow_init = mocker.patch( + 'firebase_admin._http_client.GoogleAuthCredentialFlow.__init__', return_value=None + ) + mock_async_client_init = mocker.patch('httpx.AsyncClient.__init__', return_value=None) + mock_transport_init = mocker.patch( + 'firebase_admin._retry.HttpxRetryTransport.__init__', return_value=None + ) + + mock_credential = testutils.MockGoogleCredential() + client = HttpxAsyncClient(credential=mock_credential) + + assert client.base_url == '' + assert client.timeout == DEFAULT_TIMEOUT_SECONDS + assert client._headers == _http_client.METRICS_HEADERS + assert client._retry_config == default_retry_config + + # Verify GoogleAuthCredentialFlow was initialized with the credential + mock_auth_flow_init.assert_called_once_with(mock_credential) + + # Check httpx.AsyncClient call args + _, init_kwargs = mock_async_client_init.call_args + assert init_kwargs.get('http2') is True + assert init_kwargs.get('timeout') == DEFAULT_TIMEOUT_SECONDS + assert init_kwargs.get('headers') == _http_client.METRICS_HEADERS + assert isinstance(init_kwargs.get('auth'), GoogleAuthCredentialFlow) + assert 'mounts' in init_kwargs + assert 'http://' in init_kwargs['mounts'] + assert 'https://' in init_kwargs['mounts'] + assert isinstance(init_kwargs['mounts']['http://'], HttpxRetryTransport) + assert isinstance(init_kwargs['mounts']['https://'], HttpxRetryTransport) + + # Check that HttpxRetryTransport was initialized with the default retry config + assert mock_transport_init.call_count >= 1 + _, transport_call_kwargs = mock_transport_init.call_args_list[0] + assert transport_call_kwargs.get('retry') == default_retry_config + assert transport_call_kwargs.get('http2') is True + + def test_init_with_custom_settings(self, mocker: MockerFixture): + """Test client initialization with custom settings.""" + + # Mock httpx.AsyncClient and HttpxRetryTransport init to check args passed to them + mock_auth_flow_init = mocker.patch( + 'firebase_admin._http_client.GoogleAuthCredentialFlow.__init__', return_value=None + ) + mock_async_client_init = mocker.patch('httpx.AsyncClient.__init__', return_value=None) + mock_transport_init = mocker.patch( + 'firebase_admin._retry.HttpxRetryTransport.__init__', return_value=None + ) + + mock_credential = testutils.MockGoogleCredential() + headers = {'X-Custom': 'Test'} + custom_retry = HttpxRetry(max_retries=1, status_forcelist=[429], backoff_factor=0) + timeout = 60 + http2 = False + + expected_headers = {**headers, **_http_client.METRICS_HEADERS} + + client = HttpxAsyncClient( + credential=mock_credential, base_url=_TEST_URL, headers=headers, + retry_config=custom_retry, timeout=timeout, http2=http2) + + assert client.base_url == _TEST_URL + assert client._headers == expected_headers + assert client._retry_config == custom_retry + assert client.timeout == timeout + + # Verify GoogleAuthCredentialFlow was initialized with the credential + mock_auth_flow_init.assert_called_once_with(mock_credential) + # Verify original headers are not mutated + assert headers == {'X-Custom': 'Test'} + + # Check httpx.AsyncClient call args + _, init_kwargs = mock_async_client_init.call_args + assert init_kwargs.get('http2') is False + assert init_kwargs.get('timeout') == timeout + assert init_kwargs.get('headers') == expected_headers + assert isinstance(init_kwargs.get('auth'), GoogleAuthCredentialFlow) + assert 'mounts' in init_kwargs + assert 'http://' in init_kwargs['mounts'] + assert 'https://' in init_kwargs['mounts'] + assert isinstance(init_kwargs['mounts']['http://'], HttpxRetryTransport) + assert isinstance(init_kwargs['mounts']['https://'], HttpxRetryTransport) + + # Check that HttpxRetryTransport was initialized with the default retry config + assert mock_transport_init.call_count >= 1 + _, transport_call_kwargs = mock_transport_init.call_args_list[0] + assert transport_call_kwargs.get('retry') == custom_retry + assert transport_call_kwargs.get('http2') is False + + + @respx.mock + @pytest.mark.asyncio + async def test_request(self): + """Test client request.""" + + client = HttpxAsyncClient() + + responses = [ + respx.MockResponse(200, http_version='HTTP/2', content='body'), + ] + route = respx.request('POST', _TEST_URL).mock(side_effect=responses) + + resp = await client.request('post', _TEST_URL) + assert resp.status_code == 200 + assert resp.text == 'body' + assert route.call_count == 1 + + request = route.calls.last.request + assert request.method == 'POST' + assert request.url == _TEST_URL + self.check_headers(request.headers, has_auth=False) + + @respx.mock + @pytest.mark.asyncio + async def test_request_raise_for_status(self): + """Test client request raise for status error.""" + + client = HttpxAsyncClient() + + responses = [ + respx.MockResponse(404, http_version='HTTP/2', content='Status error'), + ] + route = respx.request('POST', _TEST_URL).mock(side_effect=responses) + + with pytest.raises(httpx.HTTPStatusError) as exc_info: + resp = await client.request('post', _TEST_URL) + resp = exc_info.value.response + assert resp.status_code == 404 + assert resp.text == 'Status error' + assert route.call_count == 1 + + request = route.calls.last.request + assert request.method == 'POST' + assert request.url == _TEST_URL + self.check_headers(request.headers, has_auth=False) + + + @respx.mock + @pytest.mark.asyncio + async def test_request_with_base_url(self): + """Test client request with base_url.""" + + client = HttpxAsyncClient(base_url=_TEST_URL) + + url_extension = 'post/123' + responses = [ + respx.MockResponse(200, http_version='HTTP/2', content='body'), + ] + route = respx.request('POST', _TEST_URL + url_extension).mock(side_effect=responses) + + resp = await client.request('POST', url_extension) + assert resp.status_code == 200 + assert resp.text == 'body' + assert route.call_count == 1 + + request = route.calls.last.request + assert request.method == 'POST' + assert request.url == _TEST_URL + url_extension + self.check_headers(request.headers, has_auth=False) + + @respx.mock + @pytest.mark.asyncio + async def test_request_with_timeout(self): + """Test client request with timeout.""" + + timeout = 60 + client = HttpxAsyncClient(timeout=timeout) + responses = [ + respx.MockResponse(200, http_version='HTTP/2', content='body'), + ] + route = respx.request('POST', _TEST_URL).mock(side_effect=responses) + + resp = await client.request('POST', _TEST_URL) + assert resp.status_code == 200 + assert resp.text == 'body' + assert route.call_count == 1 + + request = route.calls.last.request + assert request.method == 'POST' + assert request.url == _TEST_URL + self.check_headers(request.headers, has_auth=False) + + @respx.mock + @pytest.mark.asyncio + async def test_request_with_credential(self): + """Test client request with credentials.""" + + mock_credential = testutils.MockGoogleCredential() + client = HttpxAsyncClient(credential=mock_credential) + + responses = [ + respx.MockResponse(200, http_version='HTTP/2', content='test'), + ] + route = respx.request('POST', _TEST_URL).mock(side_effect=responses) + + resp = await client.request('post', _TEST_URL) + + assert resp.status_code == 200 + assert resp.text == 'test' + assert route.call_count == 1 + + request = route.calls.last.request + assert request.method == 'POST' + assert request.url == _TEST_URL + self.check_headers(request.headers) + + @respx.mock + @pytest.mark.asyncio + async def test_request_with_headers(self): + """Test client request with credentials.""" + + mock_credential = testutils.MockGoogleCredential() + headers = httpx.Headers({'X-Custom': 'Test'}) + client = HttpxAsyncClient(credential=mock_credential, headers=headers) + + responses = [ + respx.MockResponse(200, http_version='HTTP/2', content='body'), + ] + route = respx.request('POST', _TEST_URL).mock(side_effect=responses) + + resp = await client.request('post', _TEST_URL) + assert resp.status_code == 200 + assert resp.text == 'body' + assert route.call_count == 1 + + request = route.calls.last.request + assert request.method == 'POST' + assert request.url == _TEST_URL + self.check_headers(request.headers, expected_headers=headers) + + + @respx.mock + @pytest.mark.asyncio + async def test_response_get_headers(self): + """Test the headers() helper method.""" + + client = HttpxAsyncClient() + expected_headers = {'X-Custom': 'Test'} + + responses = [ + respx.MockResponse(200, http_version='HTTP/2', headers=expected_headers), + ] + route = respx.request('POST', _TEST_URL).mock(side_effect=responses) + + headers = await client.headers('post', _TEST_URL) + + self.check_headers( + headers, expected_headers=expected_headers, has_auth=False, has_metrics=False + ) + assert route.call_count == 1 + + request = route.calls.last.request + assert request.method == 'POST' + assert request.url == _TEST_URL + self.check_headers(request.headers, has_auth=False) + + @respx.mock + @pytest.mark.asyncio + async def test_response_get_body_and_response(self): + """Test the body_and_response() helper method.""" + + client = HttpxAsyncClient() + expected_body = {'key': 'value'} + + responses = [ + respx.MockResponse(200, http_version='HTTP/2', json=expected_body), + ] + route = respx.request('POST', _TEST_URL).mock(side_effect=responses) + + body, resp = await client.body_and_response('post', _TEST_URL) + + assert resp.status_code == 200 + assert body == expected_body + assert route.call_count == 1 + + request = route.calls.last.request + assert request.method == 'POST' + assert request.url == _TEST_URL + self.check_headers(request.headers, has_auth=False) + + + @respx.mock + @pytest.mark.asyncio + async def test_response_get_body(self): + """Test the body() helper method.""" + + client = HttpxAsyncClient() + expected_body = {'key': 'value'} + + responses = [ + respx.MockResponse(200, http_version='HTTP/2', json=expected_body), + ] + route = respx.request('POST', _TEST_URL).mock(side_effect=responses) + + body = await client.body('post', _TEST_URL) + + assert body == expected_body + assert route.call_count == 1 + + request = route.calls.last.request + assert request.method == 'POST' + assert request.url == _TEST_URL + self.check_headers(request.headers, has_auth=False) + + @respx.mock + @pytest.mark.asyncio + async def test_response_get_headers_and_body(self): + """Test the headers_and_body() helper method.""" + + client = HttpxAsyncClient() + expected_headers = {'X-Custom': 'Test'} + expected_body = {'key': 'value'} + + responses = [ + respx.MockResponse( + 200, http_version='HTTP/2', json=expected_body, headers=expected_headers), + ] + route = respx.request('POST', _TEST_URL).mock(side_effect=responses) + + headers, body = await client.headers_and_body('post', _TEST_URL) + + assert body == expected_body + self.check_headers( + headers, expected_headers=expected_headers, has_auth=False, has_metrics=False + ) + assert route.call_count == 1 + + request = route.calls.last.request + assert request.method == 'POST' + assert request.url == _TEST_URL + self.check_headers(request.headers, has_auth=False) + + @pytest.mark.asyncio + async def test_aclose(self): + """Test that aclose calls the underlying client's aclose.""" + + client = HttpxAsyncClient() + assert client._async_client.is_closed is False + await client.aclose() + assert client._async_client.is_closed is True + + + def check_headers( + self, + headers: Union[httpx.Headers, Dict[str, str]], + expected_headers: Optional[Union[httpx.Headers, Dict[str, str]]] = None, + has_auth: bool = True, + has_metrics: bool = True + ): + if expected_headers: + for header_key in expected_headers.keys(): + assert header_key in headers + assert headers.get(header_key) == expected_headers.get(header_key) + + if has_auth: + assert 'Authorization' in headers + assert headers.get('Authorization') == 'Bearer mock-token' + + if has_metrics: + for header_key in _http_client.METRICS_HEADERS: + assert header_key in headers + expected_metrics_header = _http_client.METRICS_HEADERS.get(header_key, '') + if has_auth: + expected_metrics_header += ' mock-cred-metric-tag' + assert headers.get(header_key) == expected_metrics_header + + +class TestGoogleAuthCredentialFlow: + + @respx.mock + @pytest.mark.asyncio + async def test_auth_headers_retry(self): + """Test invalid credential retry.""" + + mock_credential = testutils.MockGoogleCredential() + client = HttpxAsyncClient(credential=mock_credential) + + responses = [ + respx.MockResponse(401, http_version='HTTP/2', content='Auth error'), + respx.MockResponse(401, http_version='HTTP/2', content='Auth error'), + respx.MockResponse(200, http_version='HTTP/2', content='body'), + ] + route = respx.request('POST', _TEST_URL).mock(side_effect=responses) + + resp = await client.request('post', _TEST_URL) + assert resp.status_code == 200 + assert resp.text == 'body' + assert route.call_count == 3 + + request = route.calls.last.request + assert request.method == 'POST' + assert request.url == _TEST_URL + headers = request.headers + assert 'Authorization' in headers + assert headers.get('Authorization') == 'Bearer mock-token' + + @respx.mock + @pytest.mark.asyncio + async def test_auth_headers_retry_exhausted(self, mocker: MockerFixture): + """Test invalid credential retry exhausted.""" + + mock_credential = testutils.MockGoogleCredential() + mock_credential_patch = mocker.spy(mock_credential, 'refresh') + client = HttpxAsyncClient(credential=mock_credential) + + responses = [ + respx.MockResponse(401, http_version='HTTP/2', content='Auth error'), + respx.MockResponse(401, http_version='HTTP/2', content='Auth error'), + respx.MockResponse(401, http_version='HTTP/2', content='Auth error'), + # Should stop after previous response + respx.MockResponse(200, http_version='HTTP/2', content='body'), + ] + route = respx.request('POST', _TEST_URL).mock(side_effect=responses) + + with pytest.raises(httpx.HTTPStatusError) as exc_info: + resp = await client.request('post', _TEST_URL) + resp = exc_info.value.response + assert resp.status_code == 401 + assert resp.text == 'Auth error' + assert route.call_count == 3 + + assert mock_credential_patch.call_count == 3 + + request = route.calls.last.request + assert request.method == 'POST' + assert request.url == _TEST_URL + headers = request.headers + assert 'Authorization' in headers + assert headers.get('Authorization') == 'Bearer mock-token' diff --git a/tests/test_messaging.py b/tests/test_messaging.py index 54173ea97..76cee2a33 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -14,8 +14,11 @@ """Test cases for the firebase_admin.messaging module.""" import datetime +from itertools import chain, repeat import json import numbers +import httpx +import respx from googleapiclient import http from googleapiclient import _helpers @@ -1927,6 +1930,201 @@ def test_send_each(self): assert all([r.success for r in batch_response.responses]) assert not any([r.exception for r in batch_response.responses]) + @respx.mock + @pytest.mark.asyncio + async def test_send_each_async(self): + responses = [ + respx.MockResponse(200, http_version='HTTP/2', json={'name': 'message-id1'}), + respx.MockResponse(200, http_version='HTTP/2', json={'name': 'message-id2'}), + respx.MockResponse(200, http_version='HTTP/2', json={'name': 'message-id3'}), + ] + msg1 = messaging.Message(topic='foo1') + msg2 = messaging.Message(topic='foo2') + msg3 = messaging.Message(topic='foo3') + route = respx.request( + 'POST', + 'https://fcm.googleapis.com/v1/projects/explicit-project-id/messages:send' + ).mock(side_effect=responses) + + batch_response = await messaging.send_each_async([msg1, msg2, msg3], dry_run=True) + + # try: + # batch_response = await messaging.send_each_async([msg1, msg2], dry_run=True) + # except Exception as error: + # if isinstance(error.cause.__cause__, StopIteration): + # raise Exception('Received more requests than mocks') + + assert batch_response.success_count == 3 + assert batch_response.failure_count == 0 + assert len(batch_response.responses) == 3 + assert [r.message_id for r in batch_response.responses] \ + == ['message-id1', 'message-id2', 'message-id3'] + assert all([r.success for r in batch_response.responses]) + assert not any([r.exception for r in batch_response.responses]) + + assert route.call_count == 3 + + @respx.mock + @pytest.mark.asyncio + async def test_send_each_async_error_401_fail_auth_retry(self): + payload = json.dumps({ + 'error': { + 'status': 'UNAUTHENTICATED', + 'message': 'test unauthenticated error', + 'details': [ + { + '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', + 'errorCode': 'SOME_UNKNOWN_CODE', + }, + ], + } + }) + + responses = repeat(respx.MockResponse(401, http_version='HTTP/2', content=payload)) + + msg1 = messaging.Message(topic='foo1') + route = respx.request( + 'POST', + 'https://fcm.googleapis.com/v1/projects/explicit-project-id/messages:send' + ).mock(side_effect=responses) + batch_response = await messaging.send_each_async([msg1], dry_run=True) + + assert route.call_count == 3 + assert batch_response.success_count == 0 + assert batch_response.failure_count == 1 + assert len(batch_response.responses) == 1 + exception = batch_response.responses[0].exception + assert isinstance(exception, exceptions.UnauthenticatedError) + + @respx.mock + @pytest.mark.asyncio + async def test_send_each_async_error_401_pass_on_auth_retry(self): + payload = json.dumps({ + 'error': { + 'status': 'UNAUTHENTICATED', + 'message': 'test unauthenticated error', + 'details': [ + { + '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', + 'errorCode': 'SOME_UNKNOWN_CODE', + }, + ], + } + }) + responses = [ + respx.MockResponse(401, http_version='HTTP/2', content=payload), + respx.MockResponse(200, http_version='HTTP/2', json={'name': 'message-id1'}), + ] + + msg1 = messaging.Message(topic='foo1') + route = respx.request( + 'POST', + 'https://fcm.googleapis.com/v1/projects/explicit-project-id/messages:send' + ).mock(side_effect=responses) + batch_response = await messaging.send_each_async([msg1], dry_run=True) + + assert route.call_count == 2 + assert batch_response.success_count == 1 + assert batch_response.failure_count == 0 + assert len(batch_response.responses) == 1 + assert [r.message_id for r in batch_response.responses] == ['message-id1'] + assert all([r.success for r in batch_response.responses]) + assert not any([r.exception for r in batch_response.responses]) + + @respx.mock + @pytest.mark.asyncio + async def test_send_each_async_error_500_fail_retry_config(self): + payload = json.dumps({ + 'error': { + 'status': 'INTERNAL', + 'message': 'test INTERNAL error', + 'details': [ + { + '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', + 'errorCode': 'SOME_UNKNOWN_CODE', + }, + ], + } + }) + + responses = repeat(respx.MockResponse(500, http_version='HTTP/2', content=payload)) + + msg1 = messaging.Message(topic='foo1') + route = respx.request( + 'POST', + 'https://fcm.googleapis.com/v1/projects/explicit-project-id/messages:send' + ).mock(side_effect=responses) + batch_response = await messaging.send_each_async([msg1], dry_run=True) + + assert route.call_count == 5 + assert batch_response.success_count == 0 + assert batch_response.failure_count == 1 + assert len(batch_response.responses) == 1 + exception = batch_response.responses[0].exception + assert isinstance(exception, exceptions.InternalError) + + + @respx.mock + @pytest.mark.asyncio + async def test_send_each_async_error_500_pass_on_retry_config(self): + payload = json.dumps({ + 'error': { + 'status': 'INTERNAL', + 'message': 'test INTERNAL error', + 'details': [ + { + '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', + 'errorCode': 'SOME_UNKNOWN_CODE', + }, + ], + } + }) + responses = chain( + [ + respx.MockResponse(500, http_version='HTTP/2', content=payload), + respx.MockResponse(500, http_version='HTTP/2', content=payload), + respx.MockResponse(500, http_version='HTTP/2', content=payload), + respx.MockResponse(500, http_version='HTTP/2', content=payload), + respx.MockResponse(200, http_version='HTTP/2', json={'name': 'message-id1'}), + ], + ) + + msg1 = messaging.Message(topic='foo1') + route = respx.request( + 'POST', + 'https://fcm.googleapis.com/v1/projects/explicit-project-id/messages:send' + ).mock(side_effect=responses) + batch_response = await messaging.send_each_async([msg1], dry_run=True) + + assert route.call_count == 5 + assert batch_response.success_count == 1 + assert batch_response.failure_count == 0 + assert len(batch_response.responses) == 1 + assert [r.message_id for r in batch_response.responses] == ['message-id1'] + assert all([r.success for r in batch_response.responses]) + assert not any([r.exception for r in batch_response.responses]) + + @respx.mock + @pytest.mark.asyncio + async def test_send_each_async_request_error(self): + responses = httpx.ConnectError("Test request error", request=httpx.Request( + 'POST', + 'https://fcm.googleapis.com/v1/projects/explicit-project-id/messages:send')) + + msg1 = messaging.Message(topic='foo1') + route = respx.request( + 'POST', + 'https://fcm.googleapis.com/v1/projects/explicit-project-id/messages:send' + ).mock(side_effect=responses) + batch_response = await messaging.send_each_async([msg1], dry_run=True) + + assert route.call_count == 1 + assert batch_response.success_count == 0 + assert batch_response.failure_count == 1 + assert len(batch_response.responses) == 1 + exception = batch_response.responses[0].exception + assert isinstance(exception, exceptions.UnavailableError) + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_each_detailed_error(self, status): success_payload = json.dumps({'name': 'message-id'}) diff --git a/tests/test_retry.py b/tests/test_retry.py new file mode 100644 index 000000000..751fdea7b --- /dev/null +++ b/tests/test_retry.py @@ -0,0 +1,454 @@ +# Copyright 2025 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test cases for the firebase_admin._retry module.""" + +import time +import email.utils +from itertools import repeat +from unittest.mock import call +import pytest +import httpx +from pytest_mock import MockerFixture +import respx + +from firebase_admin._retry import HttpxRetry, HttpxRetryTransport + +_TEST_URL = 'http://firebase.test.url/' + +@pytest.fixture +def base_url() -> str: + """Provides a consistent base URL for tests.""" + return _TEST_URL + +class TestHttpxRetryTransport(): + @pytest.mark.asyncio + @respx.mock + async def test_no_retry_on_success(self, base_url: str, mocker: MockerFixture): + """Test that a successful response doesn't trigger retries.""" + retry_config = HttpxRetry(max_retries=3, status_forcelist=[500]) + transport = HttpxRetryTransport(retry=retry_config) + client = httpx.AsyncClient(transport=transport) + + route = respx.post(base_url).mock(return_value=httpx.Response(200, text="Success")) + + mock_sleep = mocker.patch('asyncio.sleep', return_value=None) + response = await client.post(base_url) + + assert response.status_code == 200 + assert response.text == "Success" + assert route.call_count == 1 + mock_sleep.assert_not_called() + + @pytest.mark.asyncio + @respx.mock + async def test_no_retry_on_non_retryable_status(self, base_url: str, mocker: MockerFixture): + """Test that a non-retryable error status doesn't trigger retries.""" + retry_config = HttpxRetry(max_retries=3, status_forcelist=[500, 503]) + transport = HttpxRetryTransport(retry=retry_config) + client = httpx.AsyncClient(transport=transport) + + route = respx.post(base_url).mock(return_value=httpx.Response(404, text="Not Found")) + + mock_sleep = mocker.patch('asyncio.sleep', return_value=None) + response = await client.post(base_url) + + assert response.status_code == 404 + assert response.text == "Not Found" + assert route.call_count == 1 + mock_sleep.assert_not_called() + + @pytest.mark.asyncio + @respx.mock + async def test_retry_on_status_code_success_on_last_retry( + self, base_url: str, mocker: MockerFixture + ): + """Test retry on status code from status_forcelist, succeeding on the last attempt.""" + retry_config = HttpxRetry(max_retries=2, status_forcelist=[503, 500], backoff_factor=0.5) + transport = HttpxRetryTransport(retry=retry_config) + client = httpx.AsyncClient(transport=transport) + + route = respx.post(base_url).mock(side_effect=[ + httpx.Response(503, text="Attempt 1 Failed"), + httpx.Response(500, text="Attempt 2 Failed"), + httpx.Response(200, text="Attempt 3 Success"), + ]) + + mock_sleep = mocker.patch('asyncio.sleep', return_value=None) + response = await client.post(base_url) + + assert response.status_code == 200 + assert response.text == "Attempt 3 Success" + assert route.call_count == 3 + assert mock_sleep.call_count == 2 + # Check sleep calls (backoff_factor is 0.5) + mock_sleep.assert_has_calls([call(0.0), call(1.0)]) + + @pytest.mark.asyncio + @respx.mock + async def test_retry_exhausted_returns_last_response( + self, base_url: str, mocker: MockerFixture + ): + """Test that the last response is returned when retries are exhausted.""" + retry_config = HttpxRetry(max_retries=1, status_forcelist=[500], backoff_factor=0) + transport = HttpxRetryTransport(retry=retry_config) + client = httpx.AsyncClient(transport=transport) + + route = respx.post(base_url).mock(side_effect=[ + httpx.Response(500, text="Attempt 1 Failed"), + httpx.Response(500, text="Attempt 2 Failed (Final)"), + # Should stop after previous response + httpx.Response(200, text="This should not be reached"), + ]) + + mock_sleep = mocker.patch('asyncio.sleep', return_value=None) + response = await client.post(base_url) + + assert response.status_code == 500 + assert response.text == "Attempt 2 Failed (Final)" + assert route.call_count == 2 # Initial call + 1 retry + assert mock_sleep.call_count == 1 # Slept before the single retry + + @pytest.mark.asyncio + @respx.mock + async def test_retry_after_header_seconds(self, base_url: str, mocker: MockerFixture): + """Test respecting Retry-After header with seconds value.""" + retry_config = HttpxRetry( + max_retries=1, respect_retry_after_header=True, backoff_factor=100) + transport = HttpxRetryTransport(retry=retry_config) + client = httpx.AsyncClient(transport=transport) + + route = respx.post(base_url).mock(side_effect=[ + httpx.Response(429, text="Too Many Requests", headers={'Retry-After': '10'}), + httpx.Response(200, text="OK"), + ]) + + mock_sleep = mocker.patch('asyncio.sleep', return_value=None) + response = await client.post(base_url) + + assert response.status_code == 200 + assert route.call_count == 2 + assert mock_sleep.call_count == 1 + # Assert sleep was called with the value from Retry-After header + mock_sleep.assert_called_once_with(10.0) + + @pytest.mark.asyncio + @respx.mock + async def test_retry_after_header_http_date(self, base_url: str, mocker: MockerFixture): + """Test respecting Retry-After header with an HTTP-date value.""" + retry_config = HttpxRetry( + max_retries=1, respect_retry_after_header=True, backoff_factor=100) + transport = HttpxRetryTransport(retry=retry_config) + client = httpx.AsyncClient(transport=transport) + + # Calculate a future time and format as HTTP-date + retry_delay_seconds = 60 + time_at_request = time.time() + retry_time = time_at_request + retry_delay_seconds + http_date = email.utils.formatdate(retry_time) + + route = respx.post(base_url).mock(side_effect=[ + httpx.Response(503, text="Maintenance", headers={'Retry-After': http_date}), + httpx.Response(200, text="OK"), + ]) + + mock_sleep = mocker.patch('asyncio.sleep', return_value=None) + # Patch time.time() within the test context to control the baseline for date calculation + # Set the mock time to be *just before* the Retry-After time + mocker.patch('time.time', return_value=time_at_request) + response = await client.post(base_url) + + assert response.status_code == 200 + assert route.call_count == 2 + assert mock_sleep.call_count == 1 + # Check that sleep was called with approximately the correct delay + # Allow for small floating point inaccuracies + mock_sleep.assert_called_once() + args, _ = mock_sleep.call_args + assert args[0] == pytest.approx(retry_delay_seconds, abs=2) + + @pytest.mark.asyncio + @respx.mock + async def test_retry_after_ignored_when_disabled(self, base_url: str, mocker: MockerFixture): + """Test Retry-After header is ignored if `respect_retry_after_header` is `False`.""" + retry_config = HttpxRetry( + max_retries=3, respect_retry_after_header=False, status_forcelist=[429], + backoff_factor=0.5, backoff_max=10) + transport = HttpxRetryTransport(retry=retry_config) + client = httpx.AsyncClient(transport=transport) + + route = respx.post(base_url).mock(side_effect=[ + httpx.Response(429, text="Too Many Requests", headers={'Retry-After': '60'}), + httpx.Response(429, text="Too Many Requests", headers={'Retry-After': '60'}), + httpx.Response(429, text="Too Many Requests", headers={'Retry-After': '60'}), + httpx.Response(200, text="OK"), + ]) + + mock_sleep = mocker.patch('asyncio.sleep', return_value=None) + response = await client.post(base_url) + + assert response.status_code == 200 + assert route.call_count == 4 + assert mock_sleep.call_count == 3 + + # Assert sleep was called with the calculated backoff times: + # After first attempt: delay = 0 + # After retry 1 (attempt 2): delay = 0.5 * (2**(2-1)) = 0.5 * 2 = 1.0 + # After retry 2 (attempt 3): delay = 0.5 * (2**(3-1)) = 0.5 * 4 = 2.0 + expected_sleeps = [call(0), call(1.0), call(2.0)] + mock_sleep.assert_has_calls(expected_sleeps) + + @pytest.mark.asyncio + @respx.mock + async def test_retry_after_header_missing_backoff_fallback( + self, base_url: str, mocker: MockerFixture + ): + """Test Retry-After header is ignored if `respect_retry_after_header`is `True` but header is + not set.""" + retry_config = HttpxRetry( + max_retries=3, respect_retry_after_header=True, status_forcelist=[429], + backoff_factor=0.5, backoff_max=10) + transport = HttpxRetryTransport(retry=retry_config) + client = httpx.AsyncClient(transport=transport) + + route = respx.post(base_url).mock(side_effect=[ + httpx.Response(429, text="Too Many Requests"), + httpx.Response(429, text="Too Many Requests"), + httpx.Response(429, text="Too Many Requests"), + httpx.Response(200, text="OK"), + ]) + + mock_sleep = mocker.patch('asyncio.sleep', return_value=None) + response = await client.post(base_url) + + assert response.status_code == 200 + assert route.call_count == 4 + assert mock_sleep.call_count == 3 + + # Assert sleep was called with the calculated backoff times: + # After first attempt: delay = 0 + # After retry 1 (attempt 2): delay = 0.5 * (2**(2-1)) = 0.5 * 2 = 1.0 + # After retry 2 (attempt 3): delay = 0.5 * (2**(3-1)) = 0.5 * 4 = 2.0 + expected_sleeps = [call(0), call(1.0), call(2.0)] + mock_sleep.assert_has_calls(expected_sleeps) + + @pytest.mark.asyncio + @respx.mock + async def test_exponential_backoff(self, base_url: str, mocker: MockerFixture): + """Test that sleep time increases exponentially with `backoff_factor`.""" + # status=3 allows 3 retries (attempts 2, 3, 4) + retry_config = HttpxRetry( + max_retries=3, status_forcelist=[500], backoff_factor=0.1, backoff_max=10.0) + transport = HttpxRetryTransport(retry=retry_config) + client = httpx.AsyncClient(transport=transport) + + route = respx.post(base_url).mock(side_effect=[ + httpx.Response(500, text="Fail 1"), + httpx.Response(500, text="Fail 2"), + httpx.Response(500, text="Fail 3"), + httpx.Response(200, text="Success"), + ]) + + mock_sleep = mocker.patch('asyncio.sleep', return_value=None) + response = await client.post(base_url) + + assert response.status_code == 200 + assert route.call_count == 4 + assert mock_sleep.call_count == 3 + + # Check expected backoff times: + # After first attempt: delay = 0 + # After retry 1 (attempt 2): delay = 0.1 * (2**(2-1)) = 0.1 * 2 = 0.2 + # After retry 2 (attempt 3): delay = 0.1 * (2**(3-1)) = 0.1 * 4 = 0.4 + expected_sleeps = [call(0), call(0.2), call(0.4)] + mock_sleep.assert_has_calls(expected_sleeps) + + @pytest.mark.asyncio + @respx.mock + async def test_backoff_max(self, base_url: str, mocker: MockerFixture): + """Test that backoff time respects `backoff_max`.""" + # status=4 allows 4 retries. backoff_factor=1 causes rapid increase. + retry_config = HttpxRetry( + max_retries=4, status_forcelist=[500], backoff_factor=1, backoff_max=3.0) + transport = HttpxRetryTransport(retry=retry_config) + client = httpx.AsyncClient(transport=transport) + + route = respx.post(base_url).mock(side_effect=[ + httpx.Response(500, text="Fail 1"), + httpx.Response(500, text="Fail 2"), + httpx.Response(500, text="Fail 2"), + httpx.Response(500, text="Fail 4"), + httpx.Response(200, text="Success"), + ]) + + mock_sleep = mocker.patch('asyncio.sleep', return_value=None) + response = await client.post(base_url) + + assert response.status_code == 200 + assert route.call_count == 5 + assert mock_sleep.call_count == 4 + + # Check expected backoff times: + # After first attempt: delay = 0 + # After retry 1 (attempt 2): delay = 1*(2**(2-1)) = 2. Clamped by max(0, min(3.0, 2)) = 2.0 + # After retry 2 (attempt 3): delay = 1*(2**(3-1)) = 4. Clamped by max(0, min(3.0, 4)) = 3.0 + # After retry 3 (attempt 4): delay = 1*(2**(4-1)) = 8. Clamped by max(0, min(3.0, 8)) = 3.0 + expected_sleeps = [call(0.0), call(2.0), call(3.0), call(3.0)] + mock_sleep.assert_has_calls(expected_sleeps) + + @pytest.mark.asyncio + @respx.mock + async def test_backoff_jitter(self, base_url: str, mocker: MockerFixture): + """Test that `backoff_jitter` adds randomness within bounds.""" + retry_config = HttpxRetry( + max_retries=3, status_forcelist=[500], backoff_factor=0.2, backoff_jitter=0.1) + transport = HttpxRetryTransport(retry=retry_config) + client = httpx.AsyncClient(transport=transport) + + route = respx.post(base_url).mock(side_effect=[ + httpx.Response(500, text="Fail 1"), + httpx.Response(500, text="Fail 2"), + httpx.Response(500, text="Fail 3"), + httpx.Response(200, text="Success"), + ]) + + mock_sleep = mocker.patch('asyncio.sleep', return_value=None) + response = await client.post(base_url) + + assert response.status_code == 200 + assert route.call_count == 4 + assert mock_sleep.call_count == 3 + + # Check expected backoff times are within the expected range [base - jitter, base + jitter] + # After first attempt: delay = 0 + # After retry 1 (attempt 2): delay = 0.2 * (2**(2-1)) = 0.2 * 2 = 0.4 +/- 0.1 + # After retry 2 (attempt 3): delay = 0.2 * (2**(3-1)) = 0.2 * 4 = 0.8 +/- 0.1 + expected_sleeps = [ + call(pytest.approx(0.0, abs=0.1)), + call(pytest.approx(0.4, abs=0.1)), + call(pytest.approx(0.8, abs=0.1)) + ] + mock_sleep.assert_has_calls(expected_sleeps) + + @pytest.mark.asyncio + @respx.mock + async def test_error_not_retryable(self, base_url): + """Test that non-HTTP errors are raised immediately if not retryable.""" + retry_config = HttpxRetry(max_retries=3) + transport = HttpxRetryTransport(retry=retry_config) + client = httpx.AsyncClient(transport=transport) + + # Mock a connection error + route = respx.post(base_url).mock( + side_effect=repeat(httpx.ConnectError("Connection failed"))) + + with pytest.raises(httpx.ConnectError, match="Connection failed"): + await client.post(base_url) + + assert route.call_count == 1 + + +class TestHttpxRetry(): + _TEST_REQUEST = httpx.Request('POST', _TEST_URL) + + def test_httpx_retry_copy(self, base_url): + """Test that `HttpxRetry.copy()` creates a deep copy.""" + original = HttpxRetry(max_retries=5, status_forcelist=[500, 503], backoff_factor=0.5) + original.history.append((base_url, None, None)) # Add something mutable + + copied = original.copy() + + # Assert they are different objects + assert original is not copied + assert original.history is not copied.history + + # Assert values are the same initially + assert copied.retries_left == original.retries_left + assert copied.status_forcelist == original.status_forcelist + assert copied.backoff_factor == original.backoff_factor + assert len(copied.history) == 1 + + # Modify the copy and check original is unchanged + copied.retries_left = 1 + copied.status_forcelist = [404] + copied.history.append((base_url, None, None)) + + assert original.retries_left == 5 + assert original.status_forcelist == [500, 503] + assert len(original.history) == 1 + + def test_parse_retry_after_seconds(self): + retry = HttpxRetry() + assert retry._parse_retry_after('10') == 10.0 + assert retry._parse_retry_after(' 30 ') == 30.0 + + + def test_parse_retry_after_http_date(self, mocker: MockerFixture): + mocker.patch('time.time', return_value=1000.0) + retry = HttpxRetry() + # Date string representing 1015 seconds since epoch + http_date = email.utils.formatdate(1015.0) + # time.time() is mocked to 1000.0, so delay should be 15s + assert retry._parse_retry_after(http_date) == pytest.approx(15.0) + + def test_parse_retry_after_past_http_date(self, mocker: MockerFixture): + """Test that a past date results in 0 seconds.""" + mocker.patch('time.time', return_value=1000.0) + retry = HttpxRetry() + http_date = email.utils.formatdate(990.0) # 10s in the past + assert retry._parse_retry_after(http_date) == 0.0 + + def test_parse_retry_after_invalid_date(self): + retry = HttpxRetry() + with pytest.raises(httpx.RemoteProtocolError, match='Invalid Retry-After header'): + retry._parse_retry_after('Invalid Date Format') + + def test_get_backoff_time_calculation(self): + retry = HttpxRetry( + max_retries=6, status_forcelist=[503], backoff_factor=0.5, backoff_max=10.0) + response = httpx.Response(503) + # No history -> attempt 1 -> no backoff before first request + # Note: get_backoff_time() is typically called *before* the *next* request, + # so history length reflects completed attempts. + assert retry.get_backoff_time() == 0.0 + + # Simulate attempt 1 completed + retry.increment(self._TEST_REQUEST, response) + # History len 1, attempt 2 -> base case 0 + assert retry.get_backoff_time() == pytest.approx(0) + + # Simulate attempt 2 completed + retry.increment(self._TEST_REQUEST, response) + # History len 2, attempt 3 -> 0.5*(2^1) = 1.0 + assert retry.get_backoff_time() == pytest.approx(1.0) + + # Simulate attempt 3 completed + retry.increment(self._TEST_REQUEST, response) + # History len 3, attempt 4 -> 0.5*(2^2) = 2.0 + assert retry.get_backoff_time() == pytest.approx(2.0) + + # Simulate attempt 4 completed + retry.increment(self._TEST_REQUEST, response) + # History len 4, attempt 5 -> 0.5*(2^3) = 4.0 + assert retry.get_backoff_time() == pytest.approx(4.0) + + # Simulate attempt 5 completed + retry.increment(self._TEST_REQUEST, response) + # History len 5, attempt 6 -> 0.5*(2^4) = 8.0 + assert retry.get_backoff_time() == pytest.approx(8.0) + + # Simulate attempt 6 completed + retry.increment(self._TEST_REQUEST, response) + # History len 6, attempt 7 -> 0.5*(2^5) = 16.0 Clamped to 10 + assert retry.get_backoff_time() == pytest.approx(10.0) From e4aff7efbc1c476421b8da521e93e2a6d523cf91 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Thu, 5 Jun 2025 14:44:47 -0400 Subject: [PATCH 205/226] [chore] Release 6.9.0 (#885) --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index c822fb375..2ee3bbd62 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '6.8.0' +__version__ = '6.9.0' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 363166bbfdbf72945ace00296c93b7c37773e092 Mon Sep 17 00:00:00 2001 From: joefspiro <97258781+joefspiro@users.noreply.github.com> Date: Fri, 13 Jun 2025 16:39:25 -0400 Subject: [PATCH 206/226] Adds send each snippets. (#891) --- snippets/messaging/cloud_messaging.py | 48 +++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/snippets/messaging/cloud_messaging.py b/snippets/messaging/cloud_messaging.py index bb63db065..6caf316d0 100644 --- a/snippets/messaging/cloud_messaging.py +++ b/snippets/messaging/cloud_messaging.py @@ -245,6 +245,29 @@ def send_all(): # [END send_all] +def send_each(): + registration_token = 'YOUR_REGISTRATION_TOKEN' + # [START send_each] + # Create a list containing up to 500 messages. + messages = [ + messaging.Message( + notification=messaging.Notification('Price drop', '5% off all electronics'), + token=registration_token, + ), + # ... + messaging.Message( + notification=messaging.Notification('Price drop', '2% off all books'), + topic='readers-club', + ), + ] + + response = messaging.send_each(messages) + # See the BatchResponse reference documentation + # for the contents of response. + print('{0} messages were sent successfully'.format(response.success_count)) + # [END send_each] + + def send_multicast(): # [START send_multicast] # Create a list containing up to 500 registration tokens. @@ -289,3 +312,28 @@ def send_multicast_and_handle_errors(): failed_tokens.append(registration_tokens[idx]) print('List of tokens that caused failures: {0}'.format(failed_tokens)) # [END send_multicast_error] + + +def send_each_for_multicast_and_handle_errors(): + # [START send_each_for_multicast_error] + # These registration tokens come from the client FCM SDKs. + registration_tokens = [ + 'YOUR_REGISTRATION_TOKEN_1', + # ... + 'YOUR_REGISTRATION_TOKEN_N', + ] + + message = messaging.MulticastMessage( + data={'score': '850', 'time': '2:45'}, + tokens=registration_tokens, + ) + response = messaging.send_each_for_multicast(message) + if response.failure_count > 0: + responses = response.responses + failed_tokens = [] + for idx, resp in enumerate(responses): + if not resp.success: + # The order of responses corresponds to the order of the registration tokens. + failed_tokens.append(registration_tokens[idx]) + print('List of tokens that caused failures: {0}'.format(failed_tokens)) + # [END send_each_for_multicast_error] From 339452e7cf02daffdd74637cd8c827bb4cfd5b49 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Wed, 2 Jul 2025 16:04:09 -0400 Subject: [PATCH 207/226] fix(functions): Remove usage of deprecated `datetime.utcnow() and fix flaky unit test` (#896) --- firebase_admin/functions.py | 5 ++-- tests/test_functions.py | 58 +++++++++++++++++++++++-------------- 2 files changed, 40 insertions(+), 23 deletions(-) diff --git a/firebase_admin/functions.py b/firebase_admin/functions.py index fa17dfc0c..48ce62a76 100644 --- a/firebase_admin/functions.py +++ b/firebase_admin/functions.py @@ -15,7 +15,7 @@ """Firebase Functions module.""" from __future__ import annotations -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from urllib import parse import re import json @@ -255,7 +255,8 @@ def _validate_task_options( if not isinstance(opts.schedule_delay_seconds, int) \ or opts.schedule_delay_seconds < 0: raise ValueError('schedule_delay_seconds should be positive int.') - schedule_time = datetime.utcnow() + timedelta(seconds=opts.schedule_delay_seconds) + schedule_time = ( + datetime.now(timezone.utc) + timedelta(seconds=opts.schedule_delay_seconds)) task.schedule_time = schedule_time.strftime('%Y-%m-%dT%H:%M:%S.%fZ') if opts.dispatch_deadline_seconds is not None: if not isinstance(opts.dispatch_deadline_seconds, int) \ diff --git a/tests/test_functions.py b/tests/test_functions.py index 1856426d9..52e92c1b2 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -14,7 +14,7 @@ """Test cases for the firebase_admin.functions module.""" -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone import json import time import pytest @@ -33,8 +33,6 @@ _CLOUD_TASKS_URL + 'projects/test-project/locations/us-central1/queues/test-function-name/tasks' _DEFAULT_TASK_URL = _CLOUD_TASKS_URL + _DEFAULT_TASK_PATH _DEFAULT_RESPONSE = json.dumps({'name': _DEFAULT_TASK_PATH}) -_ENQUEUE_TIME = datetime.utcnow() -_SCHEDULE_TIME = _ENQUEUE_TIME + timedelta(seconds=100) class TestTaskQueue: @classmethod @@ -185,27 +183,46 @@ def _instrument_functions_service(self, app=None, status=200, payload=_DEFAULT_R testutils.MockAdapter(payload, status, recorder)) return functions_service, recorder - - @pytest.mark.parametrize('task_opts_params', [ - { + def test_task_options_delay_seconds(self): + _, recorder = self._instrument_functions_service() + enqueue_time = datetime.now(timezone.utc) + expected_schedule_time = enqueue_time + timedelta(seconds=100) + task_opts_params = { 'schedule_delay_seconds': 100, 'schedule_time': None, 'dispatch_deadline_seconds': 200, 'task_id': 'test-task-id', 'headers': {'x-test-header': 'test-header-value'}, 'uri': 'https://google.com' - }, - { + } + queue = functions.task_queue('test-function-name') + task_opts = functions.TaskOptions(**task_opts_params) + queue.enqueue(_DEFAULT_DATA, task_opts) + + assert len(recorder) == 1 + task = json.loads(recorder[0].body.decode())['task'] + + task_schedule_time = datetime.fromisoformat(task['schedule_time'].replace('Z', '+00:00')) + delta = abs(task_schedule_time - expected_schedule_time) + assert delta <= timedelta(seconds=1) + + assert task['dispatch_deadline'] == '200s' + assert task['http_request']['headers']['x-test-header'] == 'test-header-value' + assert task['http_request']['url'] in ['http://google.com', 'https://google.com'] + assert task['name'] == _DEFAULT_TASK_PATH + + def test_task_options_utc_time(self): + _, recorder = self._instrument_functions_service() + enqueue_time = datetime.now(timezone.utc) + expected_schedule_time = enqueue_time + timedelta(seconds=100) + task_opts_params = { 'schedule_delay_seconds': None, - 'schedule_time': _SCHEDULE_TIME, + 'schedule_time': expected_schedule_time, 'dispatch_deadline_seconds': 200, 'task_id': 'test-task-id', 'headers': {'x-test-header': 'test-header-value'}, 'uri': 'http://google.com' - }, - ]) - def test_task_options(self, task_opts_params): - _, recorder = self._instrument_functions_service() + } queue = functions.task_queue('test-function-name') task_opts = functions.TaskOptions(**task_opts_params) queue.enqueue(_DEFAULT_DATA, task_opts) @@ -213,19 +230,18 @@ def test_task_options(self, task_opts_params): assert len(recorder) == 1 task = json.loads(recorder[0].body.decode())['task'] - schedule_time = datetime.fromisoformat(task['schedule_time'][:-1]) - delta = abs(schedule_time - _SCHEDULE_TIME) - assert delta <= timedelta(seconds=15) + task_schedule_time = datetime.fromisoformat(task['schedule_time'].replace('Z', '+00:00')) + assert task_schedule_time == expected_schedule_time assert task['dispatch_deadline'] == '200s' assert task['http_request']['headers']['x-test-header'] == 'test-header-value' assert task['http_request']['url'] in ['http://google.com', 'https://google.com'] assert task['name'] == _DEFAULT_TASK_PATH - def test_schedule_set_twice_error(self): _, recorder = self._instrument_functions_service() - opts = functions.TaskOptions(schedule_delay_seconds=100, schedule_time=datetime.utcnow()) + opts = functions.TaskOptions( + schedule_delay_seconds=100, schedule_time=datetime.now(timezone.utc)) queue = functions.task_queue('test-function-name') with pytest.raises(ValueError) as excinfo: queue.enqueue(_DEFAULT_DATA, opts) @@ -236,9 +252,9 @@ def test_schedule_set_twice_error(self): @pytest.mark.parametrize('schedule_time', [ time.time(), - str(datetime.utcnow()), - datetime.utcnow().isoformat(), - datetime.utcnow().isoformat() + 'Z', + str(datetime.now(timezone.utc)), + datetime.now(timezone.utc).isoformat(), + datetime.now(timezone.utc).isoformat() + 'Z', '', ' ' ]) def test_invalid_schedule_time_error(self, schedule_time): From 8d8439f4883107ab228002c72027b2f09d19adb0 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Mon, 14 Jul 2025 16:45:21 -0400 Subject: [PATCH 208/226] Merge v7 Feature Branch (#900) * change(fcm): Remove deprecated FCM APIs (#890) * chore(deps): Bump minimum supported Python version to 3.9 and add 3.13 to CIs (#892) * chore(deps): Bump minimum supported Python version to 3.9 and add 3.13 to CIs * fix deprecation warnings * fix GHA build status svg * fix: Correctly scope async eventloop * fix: Bump pylint to v2.7.4 and astroid to v2.5.8 to fix lint issues * fix ml tests * fix lint * fix: remove commented code * change(ml): Drop AutoML model support (#894) * chore: Bump `pylint` to v3.3.7 and `astroid` to v3.3.10 (#895) * chore: Bump pylint to v3 * chore: fix src lint * chore: fix unit test lint * chore: fix integration test lint * chore: fix snippets lint * chore: 2nd pass for errors * fix: corrected use of the `bad-functions` config * fix: add EoF newline * chore: Upgraded Google API Core, Cloud Firestore, and Cloud Storage dependencies (#897) * chore: Bump dependencies * fix: Also update setup.py * fix(functions): Remove usage of deprecated `datetime.utcnow() and fix flaky unit test` (#896) --- .github/workflows/ci.yml | 6 +- .github/workflows/nightly.yml | 5 +- .github/workflows/release.yml | 5 +- .pylintrc | 105 ++-- CONTRIBUTING.md | 2 +- README.md | 6 +- firebase_admin/__init__.py | 48 +- firebase_admin/_auth_client.py | 15 +- firebase_admin/_auth_providers.py | 57 +- firebase_admin/_auth_utils.py | 108 ++-- firebase_admin/_gapic_utils.py | 122 ----- firebase_admin/_messaging_encoder.py | 52 +- firebase_admin/_rfc3339.py | 2 +- firebase_admin/_sseclient.py | 11 +- firebase_admin/_token_gen.py | 117 +++-- firebase_admin/_user_import.py | 6 +- firebase_admin/_user_mgt.py | 71 ++- firebase_admin/_utils.py | 21 +- firebase_admin/app_check.py | 24 +- firebase_admin/credentials.py | 34 +- firebase_admin/db.py | 61 ++- firebase_admin/functions.py | 6 +- firebase_admin/instance_id.py | 6 +- firebase_admin/messaging.py | 155 +----- firebase_admin/ml.py | 79 +-- firebase_admin/project_management.py | 51 +- firebase_admin/remote_config.py | 10 +- firebase_admin/storage.py | 8 +- firebase_admin/tenant_mgt.py | 36 +- integration/conftest.py | 15 +- integration/test_auth.py | 17 +- integration/test_db.py | 14 +- integration/test_firestore.py | 14 +- integration/test_firestore_async.py | 18 +- integration/test_messaging.py | 75 +-- integration/test_ml.py | 91 +--- integration/test_project_management.py | 11 +- integration/test_storage.py | 8 +- integration/test_tenant_mgt.py | 9 +- requirements.txt | 19 +- setup.cfg | 2 + setup.py | 20 +- snippets/auth/get_service_account_tokens.py | 2 +- snippets/auth/index.py | 29 +- snippets/database/index.py | 12 +- snippets/messaging/cloud_messaging.py | 74 +-- tests/test_app.py | 14 +- tests/test_app_check.py | 4 +- tests/test_auth_providers.py | 21 +- tests/test_credentials.py | 4 +- tests/test_db.py | 67 +-- tests/test_exceptions.py | 161 ------ tests/test_instance_id.py | 6 +- tests/test_messaging.py | 549 ++------------------ tests/test_ml.py | 133 ++--- tests/test_project_management.py | 6 +- tests/test_remote_config.py | 4 +- tests/test_sseclient.py | 4 +- tests/test_storage.py | 2 +- tests/test_tenant_mgt.py | 81 +-- tests/test_token_gen.py | 20 +- tests/test_user_mgt.py | 39 +- tests/testutils.py | 4 +- 63 files changed, 773 insertions(+), 2005 deletions(-) delete mode 100644 firebase_admin/_gapic_utils.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4cc8ec481..bfd29e2cc 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,7 +8,7 @@ jobs: strategy: fail-fast: false matrix: - python: ['3.8', '3.9', '3.10', '3.11', '3.12', 'pypy3.9'] + python: ['3.9', '3.10', '3.11', '3.12', '3.13', 'pypy3.9'] steps: - uses: actions/checkout@v4 @@ -35,10 +35,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - name: Set up Python 3.8 + - name: Set up Python 3.9 uses: actions/setup-python@v5 with: - python-version: 3.8 + python-version: 3.9 - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 282cb1b91..3d5420537 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -36,7 +36,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: 3.8 + python-version: 3.9 - name: Install dependencies run: | @@ -45,6 +45,7 @@ jobs: pip install setuptools wheel pip install tensorflow pip install keras + pip install build - name: Run unit tests run: pytest @@ -57,7 +58,7 @@ jobs: # Build the Python Wheel and the source distribution. - name: Package release artifacts - run: python setup.py bdist_wheel sdist + run: python -m build # Attach the packaged artifacts to the workflow output. These can be manually # downloaded for later inspection if necessary. diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7a7986a5a..6cd1d3f07 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -47,7 +47,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: 3.8 + python-version: 3.9 - name: Install dependencies run: | @@ -56,6 +56,7 @@ jobs: pip install setuptools wheel pip install tensorflow pip install keras + pip install build - name: Run unit tests run: pytest @@ -68,7 +69,7 @@ jobs: # Build the Python Wheel and the source distribution. - name: Package release artifacts - run: python setup.py bdist_wheel sdist + run: python -m build # Attach the packaged artifacts to the workflow output. These can be manually # downloaded for later inspection if necessary. diff --git a/.pylintrc b/.pylintrc index 2155853c7..ea54e481c 100644 --- a/.pylintrc +++ b/.pylintrc @@ -1,4 +1,4 @@ -[MASTER] +[MAIN] # Specify a configuration file. #rcfile= @@ -20,7 +20,9 @@ persistent=no # List of plugins (as comma separated values of python modules names) to load, # usually to register additional checkers. -load-plugins=pylint.extensions.docparams,pylint.extensions.docstyle +load-plugins=pylint.extensions.docparams, + pylint.extensions.docstyle, + pylint.extensions.bad_builtin, # Use multiple processes to speed up Pylint. jobs=1 @@ -34,15 +36,6 @@ unsafe-load-any-extension=no # run arbitrary code extension-pkg-whitelist= -# Allow optimization of some AST trees. This will activate a peephole AST -# optimizer, which will apply various small optimizations. For instance, it can -# be used to obtain the result of joining multiple strings with the addition -# operator. Joining a lot of strings can lead to a maximum recursion error in -# Pylint and this flag can prevent that. It has one side effect, the resulting -# AST will be different than the one from reality. This option is deprecated -# and it will be removed in Pylint 2.0. -optimize-ast=no - [MESSAGES CONTROL] @@ -65,21 +58,31 @@ enable=indexing-exception,old-raise-syntax # --enable=similarities". If you want to run only the classes checker, but have # no Warning level messages displayed, use"--disable=all --enable=classes # --disable=W" -disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,file-ignored,missing-type-doc +disable=design, + similarities, + no-self-use, + attribute-defined-outside-init, + locally-disabled, + star-args, + pointless-except, + bad-option-value, + lobal-statement, + fixme, + suppressed-message, + useless-suppression, + locally-enabled, + file-ignored, + missing-type-doc, + c-extension-no-member, [REPORTS] -# Set the output format. Available formats are text, parseable, colorized, msvs -# (visual studio) and html. You can also give a reporter class, eg -# mypackage.mymodule.MyReporterClass. -output-format=text - -# Put messages in a separate file for each module / package specified on the -# command line instead of printing them on stdout. Reports (if any) will be -# written in a file name "pylint_global.[txt|html]". This option is deprecated -# and it will be removed in Pylint 2.0. -files-output=no +# Set the output format. Available formats are: 'text', 'parseable', +# 'colorized', 'json2' (improved json format), 'json' (old json format), msvs +# (visual studio) and 'github' (GitHub actions). You can also give a reporter +# class, e.g. mypackage.mymodule.MyReporterClass. +output-format=colorized # Tells whether to display a full report or only the messages reports=no @@ -176,9 +179,12 @@ logging-modules=logging good-names=main,_ # Bad variable names which should always be refused, separated by a comma -bad-names= - -bad-functions=input,apply,reduce +bad-names=foo, + bar, + baz, + toto, + tutu, + tata # Colon-delimited sets of names that determine each other's naming style when # the name regexes allow several styles. @@ -194,64 +200,33 @@ property-classes=abc.abstractproperty # Regular expression matching correct function names function-rgx=[a-z_][a-z0-9_]*$ -# Naming hint for function names -function-name-hint=[a-z_][a-z0-9_]*$ - # Regular expression matching correct variable names variable-rgx=[a-z_][a-z0-9_]{2,30}$ -# Naming hint for variable names -variable-name-hint=[a-z_][a-z0-9_]{2,30}$ - # Regular expression matching correct constant names const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ - -# Naming hint for constant names -const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$ - # Regular expression matching correct attribute names attr-rgx=[a-z_][a-z0-9_]{2,30}$ -# Naming hint for attribute names -attr-name-hint=[a-z_][a-z0-9_]{2,30}$ - # Regular expression matching correct argument names argument-rgx=[a-z_][a-z0-9_]{2,30}$ -# Naming hint for argument names -argument-name-hint=[a-z_][a-z0-9_]{2,30}$ - # Regular expression matching correct class attribute names class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ -# Naming hint for class attribute names -class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ - # Regular expression matching correct inline iteration names inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$ -# Naming hint for inline iteration names -inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$ - # Regular expression matching correct class names class-rgx=[A-Z_][a-zA-Z0-9]+$ -# Naming hint for class names -class-name-hint=[A-Z_][a-zA-Z0-9]+$ - # Regular expression matching correct module names module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ -# Naming hint for module names -module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ - # Regular expression matching correct method names method-rgx=[a-z_][a-z0-9_]*$ -# Naming hint for method names -method-name-hint=[a-z_][a-z0-9_]*$ - # Regular expression which should only match function or class names that do # not require a docstring. no-docstring-rgx=(__.*__|main) @@ -294,12 +269,6 @@ ignore-long-lines=^\s*(# )??$ # else. single-line-if-stmt=no -# List of optional constructs for which whitespace checking is disabled. `dict- -# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. -# `trailing-comma` allows a space between comma and closing bracket: (a, ). -# `empty-line` allows space-only lines. -no-space-check=trailing-comma,dict-separator - # Maximum number of lines in a module max-module-lines=1000 @@ -405,6 +374,12 @@ exclude-protected=_asdict,_fields,_replace,_source,_make [EXCEPTIONS] -# Exceptions that will emit a warning when being caught. Defaults to -# "Exception" -overgeneral-exceptions=Exception +# Exceptions that will emit a warning when caught. +overgeneral-exceptions=builtins.BaseException,builtins.Exception + +[DEPRECATED_BUILTINS] + +# List of builtins function names that should not be used, separated by a comma +bad-functions=input, + apply, + reduce diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index de5934866..72933a24f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -85,7 +85,7 @@ information on using pull requests. ### Initial Setup -You need Python 3.8+ to build and test the code in this repo. +You need Python 3.9+ to build and test the code in this repo. We recommend using [pip](https://pypi.python.org/pypi/pip) for installing the necessary tools and project dependencies. Most recent versions of Python ship with pip. If your development environment diff --git a/README.md b/README.md index 6e3ed6805..29303fd4f 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -[![Build Status](https://travis-ci.org/firebase/firebase-admin-python.svg?branch=master)](https://travis-ci.org/firebase/firebase-admin-python) +[![Nightly Builds](https://github.com/firebase/firebase-admin-python/actions/workflows/nightly.yml/badge.svg)](https://github.com/firebase/firebase-admin-python/actions/workflows/nightly.yml) [![Python](https://img.shields.io/pypi/pyversions/firebase-admin.svg)](https://pypi.org/project/firebase-admin/) [![Version](https://img.shields.io/pypi/v/firebase-admin.svg)](https://pypi.org/project/firebase-admin/) @@ -43,8 +43,8 @@ requests, code review feedback, and also pull requests. ## Supported Python Versions -We currently support Python 3.7+. However, Python 3.7 and Python 3.8 support is deprecated, -and developers are strongly advised to use Python 3.9 or higher. Firebase +We currently support Python 3.9+. However, Python 3.9 support is deprecated, +and developers are strongly advised to use Python 3.10 or higher. Firebase Admin Python SDK is also tested on PyPy and [Google App Engine](https://cloud.google.com/appengine/) environments. diff --git a/firebase_admin/__init__.py b/firebase_admin/__init__.py index 7bb9c59c2..8c9f628e5 100644 --- a/firebase_admin/__init__.py +++ b/firebase_admin/__init__.py @@ -79,11 +79,11 @@ def initialize_app(credential=None, options=None, name=_DEFAULT_APP_NAME): 'apps, pass a second argument to initialize_app() to give each app ' 'a unique name.')) - raise ValueError(( - 'Firebase app named "{0}" already exists. This means you called ' + raise ValueError( + f'Firebase app named "{name}" already exists. This means you called ' 'initialize_app() more than once with the same app name as the ' 'second argument. Make sure you provide a unique name every time ' - 'you call initialize_app().').format(name)) + 'you call initialize_app().') def delete_app(app): @@ -96,8 +96,7 @@ def delete_app(app): ValueError: If the app is not initialized. """ if not isinstance(app, App): - raise ValueError('Illegal app argument type: "{}". Argument must be of ' - 'type App.'.format(type(app))) + raise ValueError(f'Illegal app argument type: "{type(app)}". Argument must be of type App.') with _apps_lock: if _apps.get(app.name) is app: del _apps[app.name] @@ -109,9 +108,9 @@ def delete_app(app): 'the default app by calling initialize_app().') raise ValueError( - ('Firebase app named "{0}" is not initialized. Make sure to initialize ' - 'the app by calling initialize_app() with your app name as the ' - 'second argument.').format(app.name)) + f'Firebase app named "{app.name}" is not initialized. Make sure to initialize ' + 'the app by calling initialize_app() with your app name as the ' + 'second argument.') def get_app(name=_DEFAULT_APP_NAME): @@ -128,8 +127,8 @@ def get_app(name=_DEFAULT_APP_NAME): app does not exist. """ if not isinstance(name, str): - raise ValueError('Illegal app name argument type: "{}". App name ' - 'must be a string.'.format(type(name))) + raise ValueError( + f'Illegal app name argument type: "{type(name)}". App name must be a string.') with _apps_lock: if name in _apps: return _apps[name] @@ -140,9 +139,9 @@ def get_app(name=_DEFAULT_APP_NAME): 'the SDK by calling initialize_app().') raise ValueError( - ('Firebase app named "{0}" does not exist. Make sure to initialize ' - 'the SDK by calling initialize_app() with your app name as the ' - 'second argument.').format(name)) + f'Firebase app named "{name}" does not exist. Make sure to initialize ' + 'the SDK by calling initialize_app() with your app name as the ' + 'second argument.') class _AppOptions: @@ -153,8 +152,9 @@ def __init__(self, options): options = self._load_from_environment() if not isinstance(options, dict): - raise ValueError('Illegal Firebase app options type: {0}. Options ' - 'must be a dictionary.'.format(type(options))) + raise ValueError( + f'Illegal Firebase app options type: {type(options)}. ' + 'Options must be a dictionary.') self._options = options def get(self, key, default=None): @@ -175,14 +175,15 @@ def _load_from_environment(self): json_str = config_file else: try: - with open(config_file, 'r') as json_file: + with open(config_file, 'r', encoding='utf-8') as json_file: json_str = json_file.read() except Exception as err: - raise ValueError('Unable to read file {}. {}'.format(config_file, err)) + raise ValueError(f'Unable to read file {config_file}. {err}') from err try: json_data = json.loads(json_str) except Exception as err: - raise ValueError('JSON string "{0}" is not valid json. {1}'.format(json_str, err)) + raise ValueError( + f'JSON string "{json_str}" is not valid json. {err}') from err return {k: v for k, v in json_data.items() if k in _CONFIG_VALID_KEYS} @@ -205,8 +206,9 @@ def __init__(self, name, credential, options): ValueError: If an argument is None or invalid. """ if not name or not isinstance(name, str): - raise ValueError('Illegal Firebase app name "{0}" provided. App name must be a ' - 'non-empty string.'.format(name)) + raise ValueError( + f'Illegal Firebase app name "{name}" provided. App name must be a ' + 'non-empty string.') self._name = name if isinstance(credential, GoogleAuthCredentials): @@ -227,7 +229,7 @@ def __init__(self, name, credential, options): def _validate_project_id(cls, project_id): if project_id is not None and not isinstance(project_id, str): raise ValueError( - 'Invalid project ID: "{0}". project ID must be a string.'.format(project_id)) + f'Invalid project ID: "{project_id}". project ID must be a string.') @property def name(self): @@ -292,11 +294,11 @@ def _get_service(self, name, initializer): """ if not name or not isinstance(name, str): raise ValueError( - 'Illegal name argument: "{0}". Name must be a non-empty string.'.format(name)) + f'Illegal name argument: "{name}". Name must be a non-empty string.') with self._lock: if self._services is None: raise ValueError( - 'Service requested from deleted Firebase App: "{0}".'.format(self._name)) + f'Service requested from deleted Firebase App: "{self._name}".') if name not in self._services: self._services[name] = initializer(self) return self._services[name] diff --git a/firebase_admin/_auth_client.py b/firebase_admin/_auth_client.py index 38b42993a..74261fa37 100644 --- a/firebase_admin/_auth_client.py +++ b/firebase_admin/_auth_client.py @@ -38,7 +38,7 @@ def __init__(self, app, tenant_id=None): 3. set the project ID via the GOOGLE_CLOUD_PROJECT environment variable.""") credential = None - version_header = 'Python/Admin/{0}'.format(firebase_admin.__version__) + version_header = f'Python/Admin/{firebase_admin.__version__}' timeout = app.options.get('httpTimeout', _http_client.DEFAULT_TIMEOUT_SECONDS) # Non-default endpoint URLs for emulator support are set in this dict later. endpoint_urls = {} @@ -48,7 +48,7 @@ def __init__(self, app, tenant_id=None): # endpoint URLs to use the emulator. Additionally, use a fake credential. emulator_host = _auth_utils.get_emulator_host() if emulator_host: - base_url = 'http://{0}/identitytoolkit.googleapis.com'.format(emulator_host) + base_url = f'http://{emulator_host}/identitytoolkit.googleapis.com' endpoint_urls['v1'] = base_url + '/v1' endpoint_urls['v2'] = base_url + '/v2' credential = _utils.EmulatorAdminCredentials() @@ -123,15 +123,16 @@ def verify_id_token(self, id_token, check_revoked=False, clock_skew_seconds=0): """ if not isinstance(check_revoked, bool): # guard against accidental wrong assignment. - raise ValueError('Illegal check_revoked argument. Argument must be of type ' - ' bool, but given "{0}".'.format(type(check_revoked))) + raise ValueError( + 'Illegal check_revoked argument. Argument must be of type bool, but given ' + f'"{type(check_revoked)}".') verified_claims = self._token_verifier.verify_id_token(id_token, clock_skew_seconds) if self.tenant_id: token_tenant_id = verified_claims.get('firebase', {}).get('tenant') if self.tenant_id != token_tenant_id: raise _auth_utils.TenantIdMismatchError( - 'Invalid tenant ID: {0}'.format(token_tenant_id)) + f'Invalid tenant ID: {token_tenant_id}') if check_revoked: self._check_jwt_revoked_or_disabled( @@ -249,7 +250,7 @@ def _matches(identifier, user_record): if identifier.provider_id == user_info.provider_id and identifier.provider_uid == user_info.uid ), False) - raise TypeError("Unexpected type: {}".format(type(identifier))) + raise TypeError(f"Unexpected type: {type(identifier)}") def _is_user_found(identifier, user_records): return any(_matches(identifier, user_record) for user_record in user_records) @@ -757,4 +758,4 @@ def _check_jwt_revoked_or_disabled(self, verified_claims, exc_type, label): if user.disabled: raise _auth_utils.UserDisabledError('The user record is disabled.') if verified_claims.get('iat') * 1000 < user.tokens_valid_after_timestamp: - raise exc_type('The Firebase {0} has been revoked.'.format(label)) + raise exc_type(f'The Firebase {label} has been revoked.') diff --git a/firebase_admin/_auth_providers.py b/firebase_admin/_auth_providers.py index 31894a4dc..cc7949526 100644 --- a/firebase_admin/_auth_providers.py +++ b/firebase_admin/_auth_providers.py @@ -181,13 +181,13 @@ class ProviderConfigClient: def __init__(self, http_client, project_id, tenant_id=None, url_override=None): self.http_client = http_client url_prefix = url_override or self.PROVIDER_CONFIG_URL - self.base_url = '{0}/projects/{1}'.format(url_prefix, project_id) + self.base_url = f'{url_prefix}/projects/{project_id}' if tenant_id: - self.base_url += '/tenants/{0}'.format(tenant_id) + self.base_url += f'/tenants/{tenant_id}' def get_oidc_provider_config(self, provider_id): _validate_oidc_provider_id(provider_id) - body = self._make_request('get', '/oauthIdpConfigs/{0}'.format(provider_id)) + body = self._make_request('get', f'/oauthIdpConfigs/{provider_id}') return OIDCProviderConfig(body) def create_oidc_provider_config( @@ -218,7 +218,7 @@ def create_oidc_provider_config( if response_type: req['responseType'] = response_type - params = 'oauthIdpConfigId={0}'.format(provider_id) + params = f'oauthIdpConfigId={provider_id}' body = self._make_request('post', '/oauthIdpConfigs', json=req, params=params) return OIDCProviderConfig(body) @@ -259,14 +259,14 @@ def update_oidc_provider_config( raise ValueError('At least one parameter must be specified for update.') update_mask = _auth_utils.build_update_mask(req) - params = 'updateMask={0}'.format(','.join(update_mask)) - url = '/oauthIdpConfigs/{0}'.format(provider_id) + params = f'updateMask={",".join(update_mask)}' + url = f'/oauthIdpConfigs/{provider_id}' body = self._make_request('patch', url, json=req, params=params) return OIDCProviderConfig(body) def delete_oidc_provider_config(self, provider_id): _validate_oidc_provider_id(provider_id) - self._make_request('delete', '/oauthIdpConfigs/{0}'.format(provider_id)) + self._make_request('delete', f'/oauthIdpConfigs/{provider_id}') def list_oidc_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): return _ListOIDCProviderConfigsPage( @@ -277,7 +277,7 @@ def _fetch_oidc_provider_configs(self, page_token=None, max_results=MAX_LIST_CON def get_saml_provider_config(self, provider_id): _validate_saml_provider_id(provider_id) - body = self._make_request('get', '/inboundSamlConfigs/{0}'.format(provider_id)) + body = self._make_request('get', f'/inboundSamlConfigs/{provider_id}') return SAMLProviderConfig(body) def create_saml_provider_config( @@ -301,7 +301,7 @@ def create_saml_provider_config( if enabled is not None: req['enabled'] = _auth_utils.validate_boolean(enabled, 'enabled') - params = 'inboundSamlConfigId={0}'.format(provider_id) + params = f'inboundSamlConfigId={provider_id}' body = self._make_request('post', '/inboundSamlConfigs', json=req, params=params) return SAMLProviderConfig(body) @@ -341,14 +341,14 @@ def update_saml_provider_config( raise ValueError('At least one parameter must be specified for update.') update_mask = _auth_utils.build_update_mask(req) - params = 'updateMask={0}'.format(','.join(update_mask)) - url = '/inboundSamlConfigs/{0}'.format(provider_id) + params = f'updateMask={",".join(update_mask)}' + url = f'/inboundSamlConfigs/{provider_id}' body = self._make_request('patch', url, json=req, params=params) return SAMLProviderConfig(body) def delete_saml_provider_config(self, provider_id): _validate_saml_provider_id(provider_id) - self._make_request('delete', '/inboundSamlConfigs/{0}'.format(provider_id)) + self._make_request('delete', f'/inboundSamlConfigs/{provider_id}') def list_saml_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): return _ListSAMLProviderConfigsPage( @@ -367,15 +367,15 @@ def _fetch_provider_configs(self, path, page_token=None, max_results=MAX_LIST_CO if max_results < 1 or max_results > MAX_LIST_CONFIGS_RESULTS: raise ValueError( 'Max results must be a positive integer less than or equal to ' - '{0}.'.format(MAX_LIST_CONFIGS_RESULTS)) + f'{MAX_LIST_CONFIGS_RESULTS}.') - params = 'pageSize={0}'.format(max_results) + params = f'pageSize={max_results}' if page_token: - params += '&pageToken={0}'.format(page_token) + params += f'&pageToken={page_token}' return self._make_request('get', path, params=params) def _make_request(self, method, path, **kwargs): - url = '{0}{1}'.format(self.base_url, path) + url = f'{self.base_url}{path}' try: return self.http_client.body(method, url, **kwargs) except requests.exceptions.RequestException as error: @@ -385,29 +385,27 @@ def _make_request(self, method, path, **kwargs): def _validate_oidc_provider_id(provider_id): if not isinstance(provider_id, str): raise ValueError( - 'Invalid OIDC provider ID: {0}. Provider ID must be a non-empty string.'.format( - provider_id)) + f'Invalid OIDC provider ID: {provider_id}. Provider ID must be a non-empty string.') if not provider_id.startswith('oidc.'): - raise ValueError('Invalid OIDC provider ID: {0}.'.format(provider_id)) + raise ValueError(f'Invalid OIDC provider ID: {provider_id}.') return provider_id def _validate_saml_provider_id(provider_id): if not isinstance(provider_id, str): raise ValueError( - 'Invalid SAML provider ID: {0}. Provider ID must be a non-empty string.'.format( - provider_id)) + f'Invalid SAML provider ID: {provider_id}. Provider ID must be a non-empty string.') if not provider_id.startswith('saml.'): - raise ValueError('Invalid SAML provider ID: {0}.'.format(provider_id)) + raise ValueError(f'Invalid SAML provider ID: {provider_id}.') return provider_id def _validate_non_empty_string(value, label): """Validates that the given value is a non-empty string.""" if not isinstance(value, str): - raise ValueError('Invalid type for {0}: {1}.'.format(label, value)) + raise ValueError(f'Invalid type for {label}: {value}.') if not value: - raise ValueError('{0} must not be empty.'.format(label)) + raise ValueError(f'{label} must not be empty.') return value @@ -415,20 +413,19 @@ def _validate_url(url, label): """Validates that the given value is a well-formed URL string.""" if not isinstance(url, str) or not url: raise ValueError( - 'Invalid photo URL: "{0}". {1} must be a non-empty ' - 'string.'.format(url, label)) + f'Invalid photo URL: "{url}". {label} must be a non-empty string.') try: parsed = parse.urlparse(url) if not parsed.netloc: - raise ValueError('Malformed {0}: "{1}".'.format(label, url)) + raise ValueError(f'Malformed {label}: "{url}".') return url - except Exception: - raise ValueError('Malformed {0}: "{1}".'.format(label, url)) + except Exception as exception: + raise ValueError(f'Malformed {label}: "{url}".') from exception def _validate_x509_certificates(x509_certificates): if not isinstance(x509_certificates, list) or not x509_certificates: raise ValueError('x509_certificates must be a non-empty list.') - if not all([isinstance(cert, str) and cert for cert in x509_certificates]): + if not all(isinstance(cert, str) and cert for cert in x509_certificates): raise ValueError('x509_certificates must only contain non-empty strings.') return [{'x509Certificate': cert} for cert in x509_certificates] diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index ac7b322ff..60d411822 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -74,8 +74,8 @@ def get_emulator_host(): emulator_host = os.getenv(EMULATOR_HOST_ENV_VAR, '') if emulator_host and '//' in emulator_host: raise ValueError( - 'Invalid {0}: "{1}". It must follow format "host:port".'.format( - EMULATOR_HOST_ENV_VAR, emulator_host)) + f'Invalid {EMULATOR_HOST_ENV_VAR}: "{emulator_host}". ' + 'It must follow format "host:port".') return emulator_host @@ -88,8 +88,8 @@ def validate_uid(uid, required=False): return None if not isinstance(uid, str) or not uid or len(uid) > 128: raise ValueError( - 'Invalid uid: "{0}". The uid must be a non-empty string with no more than 128 ' - 'characters.'.format(uid)) + f'Invalid uid: "{uid}". The uid must be a non-empty string with no more than 128 ' + 'characters.') return uid def validate_email(email, required=False): @@ -97,10 +97,10 @@ def validate_email(email, required=False): return None if not isinstance(email, str) or not email: raise ValueError( - 'Invalid email: "{0}". Email must be a non-empty string.'.format(email)) + f'Invalid email: "{email}". Email must be a non-empty string.') parts = email.split('@') if len(parts) != 2 or not parts[0] or not parts[1]: - raise ValueError('Malformed email address string: "{0}".'.format(email)) + raise ValueError(f'Malformed email address string: "{email}".') return email def validate_phone(phone, required=False): @@ -113,11 +113,12 @@ def validate_phone(phone, required=False): if phone is None and not required: return None if not isinstance(phone, str) or not phone: - raise ValueError('Invalid phone number: "{0}". Phone number must be a non-empty ' - 'string.'.format(phone)) + raise ValueError( + f'Invalid phone number: "{phone}". Phone number must be a non-empty string.') if not phone.startswith('+') or not re.search('[a-zA-Z0-9]', phone): - raise ValueError('Invalid phone number: "{0}". Phone number must be a valid, E.164 ' - 'compliant identifier.'.format(phone)) + raise ValueError( + f'Invalid phone number: "{phone}". Phone number must be a valid, E.164 ' + 'compliant identifier.') return phone def validate_password(password, required=False): @@ -132,7 +133,7 @@ def validate_bytes(value, label, required=False): if value is None and not required: return None if not isinstance(value, bytes) or not value: - raise ValueError('{0} must be a non-empty byte sequence.'.format(label)) + raise ValueError(f'{label} must be a non-empty byte sequence.') return value def validate_display_name(display_name, required=False): @@ -140,8 +141,8 @@ def validate_display_name(display_name, required=False): return None if not isinstance(display_name, str) or not display_name: raise ValueError( - 'Invalid display name: "{0}". Display name must be a non-empty ' - 'string.'.format(display_name)) + f'Invalid display name: "{display_name}". Display name must be a non-empty ' + 'string.') return display_name def validate_provider_id(provider_id, required=True): @@ -149,8 +150,7 @@ def validate_provider_id(provider_id, required=True): return None if not isinstance(provider_id, str) or not provider_id: raise ValueError( - 'Invalid provider ID: "{0}". Provider ID must be a non-empty ' - 'string.'.format(provider_id)) + f'Invalid provider ID: "{provider_id}". Provider ID must be a non-empty string.') return provider_id def validate_provider_uid(provider_uid, required=True): @@ -158,8 +158,7 @@ def validate_provider_uid(provider_uid, required=True): return None if not isinstance(provider_uid, str) or not provider_uid: raise ValueError( - 'Invalid provider UID: "{0}". Provider UID must be a non-empty ' - 'string.'.format(provider_uid)) + f'Invalid provider UID: "{provider_uid}". Provider UID must be a non-empty string.') return provider_uid def validate_photo_url(photo_url, required=False): @@ -168,15 +167,14 @@ def validate_photo_url(photo_url, required=False): return None if not isinstance(photo_url, str) or not photo_url: raise ValueError( - 'Invalid photo URL: "{0}". Photo URL must be a non-empty ' - 'string.'.format(photo_url)) + f'Invalid photo URL: "{photo_url}". Photo URL must be a non-empty string.') try: parsed = parse.urlparse(photo_url) if not parsed.netloc: - raise ValueError('Malformed photo URL: "{0}".'.format(photo_url)) + raise ValueError(f'Malformed photo URL: "{photo_url}".') return photo_url - except Exception: - raise ValueError('Malformed photo URL: "{0}".'.format(photo_url)) + except Exception as err: + raise ValueError(f'Malformed photo URL: "{photo_url}".') from err def validate_timestamp(timestamp, label, required=False): """Validates the given timestamp value. Timestamps must be positive integers.""" @@ -186,14 +184,13 @@ def validate_timestamp(timestamp, label, required=False): raise ValueError('Boolean value specified as timestamp.') try: timestamp_int = int(timestamp) - except TypeError: - raise ValueError('Invalid type for timestamp value: {0}.'.format(timestamp)) - else: - if timestamp_int != timestamp: - raise ValueError('{0} must be a numeric value and a whole number.'.format(label)) - if timestamp_int <= 0: - raise ValueError('{0} timestamp must be a positive interger.'.format(label)) - return timestamp_int + except TypeError as err: + raise ValueError(f'Invalid type for timestamp value: {timestamp}.') from err + if timestamp_int != timestamp: + raise ValueError(f'{label} must be a numeric value and a whole number.') + if timestamp_int <= 0: + raise ValueError(f'{label} timestamp must be a positive interger.') + return timestamp_int def validate_int(value, label, low=None, high=None): """Validates that the given value represents an integer. @@ -204,31 +201,30 @@ def validate_int(value, label, low=None, high=None): a developer error. """ if value is None or isinstance(value, bool): - raise ValueError('Invalid type for integer value: {0}.'.format(value)) + raise ValueError(f'Invalid type for integer value: {value}.') try: val_int = int(value) - except TypeError: - raise ValueError('Invalid type for integer value: {0}.'.format(value)) - else: - if val_int != value: - # This will be True for non-numeric values like '2' and non-whole numbers like 2.5. - raise ValueError('{0} must be a numeric value and a whole number.'.format(label)) - if low is not None and val_int < low: - raise ValueError('{0} must not be smaller than {1}.'.format(label, low)) - if high is not None and val_int > high: - raise ValueError('{0} must not be larger than {1}.'.format(label, high)) - return val_int + except TypeError as err: + raise ValueError(f'Invalid type for integer value: {value}.') from err + if val_int != value: + # This will be True for non-numeric values like '2' and non-whole numbers like 2.5. + raise ValueError(f'{label} must be a numeric value and a whole number.') + if low is not None and val_int < low: + raise ValueError(f'{label} must not be smaller than {low}.') + if high is not None and val_int > high: + raise ValueError(f'{label} must not be larger than {high}.') + return val_int def validate_string(value, label): """Validates that the given value is a string.""" if not isinstance(value, str): - raise ValueError('Invalid type for {0}: {1}.'.format(label, value)) + raise ValueError(f'Invalid type for {label}: {value}.') return value def validate_boolean(value, label): """Validates that the given value is a boolean.""" if not isinstance(value, bool): - raise ValueError('Invalid type for {0}: {1}.'.format(label, value)) + raise ValueError(f'Invalid type for {label}: {value}.') return value def validate_custom_claims(custom_claims, required=False): @@ -242,28 +238,28 @@ def validate_custom_claims(custom_claims, required=False): claims_str = str(custom_claims) if len(claims_str) > MAX_CLAIMS_PAYLOAD_SIZE: raise ValueError( - 'Custom claims payload must not exceed {0} characters.'.format( - MAX_CLAIMS_PAYLOAD_SIZE)) + f'Custom claims payload must not exceed {MAX_CLAIMS_PAYLOAD_SIZE} characters.') try: parsed = json.loads(claims_str) - except Exception: - raise ValueError('Failed to parse custom claims string as JSON.') + except Exception as err: + raise ValueError('Failed to parse custom claims string as JSON.') from err if not isinstance(parsed, dict): raise ValueError('Custom claims must be parseable as a JSON object.') invalid_claims = RESERVED_CLAIMS.intersection(set(parsed.keys())) if len(invalid_claims) > 1: joined = ', '.join(sorted(invalid_claims)) - raise ValueError('Claims "{0}" are reserved, and must not be set.'.format(joined)) + raise ValueError(f'Claims "{joined}" are reserved, and must not be set.') if len(invalid_claims) == 1: raise ValueError( - 'Claim "{0}" is reserved, and must not be set.'.format(invalid_claims.pop())) + f'Claim "{invalid_claims.pop()}" is reserved, and must not be set.') return claims_str def validate_action_type(action_type): if action_type not in VALID_EMAIL_ACTION_TYPES: - raise ValueError('Invalid action type provided action_type: {0}. \ - Valid values are {1}'.format(action_type, ', '.join(VALID_EMAIL_ACTION_TYPES))) + raise ValueError( + f'Invalid action type provided action_type: {action_type}. Valid values are ' + f'{", ".join(VALID_EMAIL_ACTION_TYPES)}') return action_type def validate_provider_ids(provider_ids, required=False): @@ -282,7 +278,7 @@ def build_update_mask(params): if isinstance(value, dict): child_mask = build_update_mask(value) for child in child_mask: - mask.append('{0}.{1}'.format(key, child)) + mask.append(f'{key}.{child}') else: mask.append(key) @@ -443,7 +439,7 @@ def handle_auth_backend_error(error): code, custom_message = _parse_error_body(error.response) if not code: - msg = 'Unexpected error response: {0}'.format(error.response.content.decode()) + msg = f'Unexpected error response: {error.response.content.decode()}' return _utils.handle_requests_error(error, message=msg) exc_type = _CODE_TO_EXC_TYPE.get(code) @@ -479,5 +475,5 @@ def _parse_error_body(response): def _build_error_message(code, exc_type, custom_message): default_message = exc_type.default_message if ( exc_type and hasattr(exc_type, 'default_message')) else 'Error while calling Auth service' - ext = ' {0}'.format(custom_message) if custom_message else '' - return '{0} ({1}).{2}'.format(default_message, code, ext) + ext = f' {custom_message}' if custom_message else '' + return f'{default_message} ({code}).{ext}' diff --git a/firebase_admin/_gapic_utils.py b/firebase_admin/_gapic_utils.py deleted file mode 100644 index 3c975808c..000000000 --- a/firebase_admin/_gapic_utils.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright 2021 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Internal utilities for interacting with Google API client.""" - -import io -import socket - -import googleapiclient -import httplib2 -import requests - -from firebase_admin import exceptions -from firebase_admin import _utils - - -def handle_platform_error_from_googleapiclient(error, handle_func=None): - """Constructs a ``FirebaseError`` from the given googleapiclient error. - - This can be used to handle errors returned by Google Cloud Platform (GCP) APIs. - - Args: - error: An error raised by the googleapiclient while making an HTTP call to a GCP API. - handle_func: A function that can be used to handle platform errors in a custom way. When - specified, this function will be called with three arguments. It has the same - signature as ```_handle_func_googleapiclient``, but may return ``None``. - - Returns: - FirebaseError: A ``FirebaseError`` that can be raised to the user code. - """ - if not isinstance(error, googleapiclient.errors.HttpError): - return handle_googleapiclient_error(error) - - content = error.content.decode() - status_code = error.resp.status - error_dict, message = _utils._parse_platform_error(content, status_code) # pylint: disable=protected-access - http_response = _http_response_from_googleapiclient_error(error) - exc = None - if handle_func: - exc = handle_func(error, message, error_dict, http_response) - - return exc if exc else _handle_func_googleapiclient(error, message, error_dict, http_response) - - -def _handle_func_googleapiclient(error, message, error_dict, http_response): - """Constructs a ``FirebaseError`` from the given GCP error. - - Args: - error: An error raised by the googleapiclient module while making an HTTP call. - message: A message to be included in the resulting ``FirebaseError``. - error_dict: Parsed GCP error response. - http_response: A requests HTTP response object to associate with the exception. - - Returns: - FirebaseError: A ``FirebaseError`` that can be raised to the user code or None. - """ - code = error_dict.get('status') - return handle_googleapiclient_error(error, message, code, http_response) - - -def handle_googleapiclient_error(error, message=None, code=None, http_response=None): - """Constructs a ``FirebaseError`` from the given googleapiclient error. - - This method is agnostic of the remote service that produced the error, whether it is a GCP - service or otherwise. Therefore, this method does not attempt to parse the error response in - any way. - - Args: - error: An error raised by the googleapiclient module while making an HTTP call. - message: A message to be included in the resulting ``FirebaseError`` (optional). If not - specified the string representation of the ``error`` argument is used as the message. - code: A GCP error code that will be used to determine the resulting error type (optional). - If not specified the HTTP status code on the error response is used to determine a - suitable error code. - http_response: A requests HTTP response object to associate with the exception (optional). - If not specified, one will be created from the ``error``. - - Returns: - FirebaseError: A ``FirebaseError`` that can be raised to the user code. - """ - if isinstance(error, socket.timeout) or ( - isinstance(error, socket.error) and 'timed out' in str(error)): - return exceptions.DeadlineExceededError( - message='Timed out while making an API call: {0}'.format(error), - cause=error) - if isinstance(error, httplib2.ServerNotFoundError): - return exceptions.UnavailableError( - message='Failed to establish a connection: {0}'.format(error), - cause=error) - if not isinstance(error, googleapiclient.errors.HttpError): - return exceptions.UnknownError( - message='Unknown error while making a remote service call: {0}'.format(error), - cause=error) - - if not code: - code = _utils._http_status_to_error_code(error.resp.status) # pylint: disable=protected-access - if not message: - message = str(error) - if not http_response: - http_response = _http_response_from_googleapiclient_error(error) - - err_type = _utils._error_code_to_exception_type(code) # pylint: disable=protected-access - return err_type(message=message, cause=error, http_response=http_response) - - -def _http_response_from_googleapiclient_error(error): - """Creates a requests HTTP Response object from the given googleapiclient error.""" - resp = requests.models.Response() - resp.raw = io.BytesIO(error.content) - resp.status_code = error.resp.status - return resp diff --git a/firebase_admin/_messaging_encoder.py b/firebase_admin/_messaging_encoder.py index 32f97875e..960a6d742 100644 --- a/firebase_admin/_messaging_encoder.py +++ b/firebase_admin/_messaging_encoder.py @@ -20,7 +20,7 @@ import numbers import re -import firebase_admin._messaging_utils as _messaging_utils +from firebase_admin import _messaging_utils class Message: @@ -99,10 +99,10 @@ def check_string(cls, label, value, non_empty=False): return None if not isinstance(value, str): if non_empty: - raise ValueError('{0} must be a non-empty string.'.format(label)) - raise ValueError('{0} must be a string.'.format(label)) + raise ValueError(f'{label} must be a non-empty string.') + raise ValueError(f'{label} must be a string.') if non_empty and not value: - raise ValueError('{0} must be a non-empty string.'.format(label)) + raise ValueError(f'{label} must be a non-empty string.') return value @classmethod @@ -110,7 +110,7 @@ def check_number(cls, label, value): if value is None: return None if not isinstance(value, numbers.Number): - raise ValueError('{0} must be a number.'.format(label)) + raise ValueError(f'{label} must be a number.') return value @classmethod @@ -119,13 +119,13 @@ def check_string_dict(cls, label, value): if value is None or value == {}: return None if not isinstance(value, dict): - raise ValueError('{0} must be a dictionary.'.format(label)) + raise ValueError(f'{label} must be a dictionary.') non_str = [k for k in value if not isinstance(k, str)] if non_str: - raise ValueError('{0} must not contain non-string keys.'.format(label)) + raise ValueError(f'{label} must not contain non-string keys.') non_str = [v for v in value.values() if not isinstance(v, str)] if non_str: - raise ValueError('{0} must not contain non-string values.'.format(label)) + raise ValueError(f'{label} must not contain non-string values.') return value @classmethod @@ -134,10 +134,10 @@ def check_string_list(cls, label, value): if value is None or value == []: return None if not isinstance(value, list): - raise ValueError('{0} must be a list of strings.'.format(label)) + raise ValueError(f'{label} must be a list of strings.') non_str = [k for k in value if not isinstance(k, str)] if non_str: - raise ValueError('{0} must not contain non-string values.'.format(label)) + raise ValueError(f'{label} must not contain non-string values.') return value @classmethod @@ -146,10 +146,10 @@ def check_number_list(cls, label, value): if value is None or value == []: return None if not isinstance(value, list): - raise ValueError('{0} must be a list of numbers.'.format(label)) + raise ValueError(f'{label} must be a list of numbers.') non_number = [k for k in value if not isinstance(k, numbers.Number)] if non_number: - raise ValueError('{0} must not contain non-number values.'.format(label)) + raise ValueError(f'{label} must not contain non-number values.') return value @classmethod @@ -157,7 +157,7 @@ def check_analytics_label(cls, label, value): """Checks if the given value is a valid analytics label.""" value = _Validators.check_string(label, value) if value is not None and not re.match(r'^[a-zA-Z0-9-_.~%]{1,50}$', value): - raise ValueError('Malformed {}.'.format(label)) + raise ValueError(f'Malformed {label}.') return value @classmethod @@ -166,7 +166,7 @@ def check_boolean(cls, label, value): if value is None: return None if not isinstance(value, bool): - raise ValueError('{0} must be a boolean.'.format(label)) + raise ValueError(f'{label} must be a boolean.') return value @classmethod @@ -175,7 +175,7 @@ def check_datetime(cls, label, value): if value is None: return None if not isinstance(value, datetime.datetime): - raise ValueError('{0} must be a datetime.'.format(label)) + raise ValueError(f'{label} must be a datetime.') return value @@ -245,8 +245,8 @@ def encode_ttl(cls, ttl): seconds = int(math.floor(total_seconds)) nanos = int((total_seconds - seconds) * 1e9) if nanos: - return '{0}.{1}s'.format(seconds, str(nanos).zfill(9)) - return '{0}s'.format(seconds) + return f'{seconds}.{str(nanos).zfill(9)}s' + return f'{seconds}s' @classmethod def encode_milliseconds(cls, label, msec): @@ -256,16 +256,16 @@ def encode_milliseconds(cls, label, msec): if isinstance(msec, numbers.Number): msec = datetime.timedelta(milliseconds=msec) if not isinstance(msec, datetime.timedelta): - raise ValueError('{0} must be a duration in milliseconds or an instance of ' - 'datetime.timedelta.'.format(label)) + raise ValueError( + f'{label} must be a duration in milliseconds or an instance of datetime.timedelta.') total_seconds = msec.total_seconds() if total_seconds < 0: - raise ValueError('{0} must not be negative.'.format(label)) + raise ValueError(f'{label} must not be negative.') seconds = int(math.floor(total_seconds)) nanos = int((total_seconds - seconds) * 1e9) if nanos: - return '{0}.{1}s'.format(seconds, str(nanos).zfill(9)) - return '{0}s'.format(seconds) + return f'{seconds}.{str(nanos).zfill(9)}s' + return f'{seconds}s' @classmethod def encode_android_notification(cls, notification): @@ -409,7 +409,7 @@ def encode_light_settings(cls, light_settings): raise ValueError( 'LightSettings.color must be in the form #RRGGBB or #RRGGBBAA.') if len(color) == 7: - color = (color+'FF') + color = color+'FF' rgba = [int(color[i:i + 2], 16) / 255.0 for i in (1, 3, 5, 7)] result['color'] = {'red': rgba[0], 'green': rgba[1], 'blue': rgba[2], 'alpha': rgba[3]} @@ -475,7 +475,7 @@ def encode_webpush_notification(cls, notification): for key, value in notification.custom_data.items(): if key in result: raise ValueError( - 'Multiple specifications for {0} in WebpushNotification.'.format(key)) + f'Multiple specifications for {key} in WebpushNotification.') result[key] = value return cls.remove_null_values(result) @@ -585,7 +585,7 @@ def encode_aps(cls, aps): for key, val in aps.custom_data.items(): _Validators.check_string('Aps.custom_data key', key) if key in result: - raise ValueError('Multiple specifications for {0} in Aps.'.format(key)) + raise ValueError(f'Multiple specifications for {key} in Aps.') result[key] = val return cls.remove_null_values(result) @@ -698,7 +698,7 @@ def default(self, o): # pylint: disable=method-hidden } result['topic'] = MessageEncoder.sanitize_topic_name(result.get('topic')) result = MessageEncoder.remove_null_values(result) - target_count = sum([t in result for t in ['token', 'topic', 'condition']]) + target_count = sum(t in result for t in ['token', 'topic', 'condition']) if target_count != 1: raise ValueError('Exactly one of token, topic or condition must be specified.') return result diff --git a/firebase_admin/_rfc3339.py b/firebase_admin/_rfc3339.py index 2c720bdd1..8489bdcb9 100644 --- a/firebase_admin/_rfc3339.py +++ b/firebase_admin/_rfc3339.py @@ -84,4 +84,4 @@ def _parse_to_datetime(datestr): except ValueError: pass - raise ValueError('time data {0} does not match RFC3339 format'.format(datestr)) + raise ValueError(f'time data {datestr} does not match RFC3339 format') diff --git a/firebase_admin/_sseclient.py b/firebase_admin/_sseclient.py index 6585dfc80..3372fe5f2 100644 --- a/firebase_admin/_sseclient.py +++ b/firebase_admin/_sseclient.py @@ -34,7 +34,7 @@ class KeepAuthSession(transport.requests.AuthorizedSession): """A session that does not drop authentication on redirects between domains.""" def __init__(self, credential): - super(KeepAuthSession, self).__init__(credential) + super().__init__(credential) def rebuild_auth(self, prepared_request, response): pass @@ -86,7 +86,7 @@ def __init__(self, url, session, retry=3000, **kwargs): self.requests_kwargs = kwargs self.should_connect = True self.last_id = None - self.buf = u'' # Keep data here as it streams in + self.buf = '' # Keep data here as it streams in headers = self.requests_kwargs.get('headers', {}) # The SSE spec requires making requests with Cache-Control: no-cache @@ -153,9 +153,6 @@ def __next__(self): self.last_id = event.event_id return event - def next(self): - return self.__next__() - class Event: """Event represents the events fired by SSE.""" @@ -184,7 +181,7 @@ def parse(cls, raw): match = cls.sse_line_pattern.match(line) if match is None: # Malformed line. Discard but warn. - warnings.warn('Invalid SSE line: "%s"' % line, SyntaxWarning) + warnings.warn(f'Invalid SSE line: "{line}"', SyntaxWarning) continue name = match.groupdict()['name'] @@ -196,7 +193,7 @@ def parse(cls, raw): # If we already have some data, then join to it with a newline. # Else this is it. if event.data: - event.data = '%s\n%s' % (event.data, value) + event.data = f'{event.data}\n{value}' else: event.data = value elif name == 'event': diff --git a/firebase_admin/_token_gen.py b/firebase_admin/_token_gen.py index a2fc725e8..1607ef0ba 100644 --- a/firebase_admin/_token_gen.py +++ b/firebase_admin/_token_gen.py @@ -114,7 +114,7 @@ def __init__(self, app, http_client, url_override=None): self.http_client = http_client self.request = transport.requests.Request() url_prefix = url_override or self.ID_TOOLKIT_URL - self.base_url = '{0}/projects/{1}'.format(url_prefix, app.project_id) + self.base_url = f'{url_prefix}/projects/{app.project_id}' self._signing_provider = None def _init_signing_provider(self): @@ -142,7 +142,7 @@ def _init_signing_provider(self): resp = self.request(url=METADATA_SERVICE_URL, headers={'Metadata-Flavor': 'Google'}) if resp.status != 200: raise ValueError( - 'Failed to contact the local metadata service: {0}.'.format(resp.data.decode())) + f'Failed to contact the local metadata service: {resp.data.decode()}.') service_account = resp.data.decode() return _SigningProvider.from_iam(self.request, google_cred, service_account) @@ -155,10 +155,10 @@ def signing_provider(self): except Exception as error: url = 'https://firebase.google.com/docs/auth/admin/create-custom-tokens' raise ValueError( - 'Failed to determine service account: {0}. Make sure to initialize the SDK ' - 'with service account credentials or specify a service account ID with ' - 'iam.serviceAccounts.signBlob permission. Please refer to {1} for more ' - 'details on creating custom tokens.'.format(error, url)) + f'Failed to determine service account: {error}. Make sure to initialize the ' + 'SDK with service account credentials or specify a service account ID with ' + f'iam.serviceAccounts.signBlob permission. Please refer to {url} for more ' + 'details on creating custom tokens.') from error return self._signing_provider def create_custom_token(self, uid, developer_claims=None, tenant_id=None): @@ -170,13 +170,13 @@ def create_custom_token(self, uid, developer_claims=None, tenant_id=None): disallowed_keys = set(developer_claims.keys()) & RESERVED_CLAIMS if disallowed_keys: if len(disallowed_keys) > 1: - error_message = ('Developer claims {0} are reserved and ' - 'cannot be specified.'.format( - ', '.join(disallowed_keys))) + error_message = ( + f'Developer claims {", ".join(disallowed_keys)} are reserved and cannot be ' + 'specified.') else: - error_message = ('Developer claim {0} is reserved and ' - 'cannot be specified.'.format( - ', '.join(disallowed_keys))) + error_message = ( + f'Developer claim {", ".join(disallowed_keys)} is reserved and cannot be ' + 'specified.') raise ValueError(error_message) if not uid or not isinstance(uid, str) or len(uid) > 128: @@ -202,8 +202,8 @@ def create_custom_token(self, uid, developer_claims=None, tenant_id=None): try: return jwt.encode(signing_provider.signer, payload, header=header) except google.auth.exceptions.TransportError as error: - msg = 'Failed to sign custom token. {0}'.format(error) - raise TokenSignError(msg, error) + msg = f'Failed to sign custom token. {error}' + raise TokenSignError(msg, error) from error def create_session_cookie(self, id_token, expires_in): @@ -211,21 +211,22 @@ def create_session_cookie(self, id_token, expires_in): id_token = id_token.decode('utf-8') if isinstance(id_token, bytes) else id_token if not isinstance(id_token, str) or not id_token: raise ValueError( - 'Illegal ID token provided: {0}. ID token must be a non-empty ' - 'string.'.format(id_token)) + f'Illegal ID token provided: {id_token}. ID token must be a non-empty string.') if isinstance(expires_in, datetime.timedelta): expires_in = int(expires_in.total_seconds()) if isinstance(expires_in, bool) or not isinstance(expires_in, int): - raise ValueError('Illegal expiry duration: {0}.'.format(expires_in)) + raise ValueError(f'Illegal expiry duration: {expires_in}.') if expires_in < MIN_SESSION_COOKIE_DURATION_SECONDS: - raise ValueError('Illegal expiry duration: {0}. Duration must be at least {1} ' - 'seconds.'.format(expires_in, MIN_SESSION_COOKIE_DURATION_SECONDS)) + raise ValueError( + f'Illegal expiry duration: {expires_in}. Duration must be at least ' + f'{MIN_SESSION_COOKIE_DURATION_SECONDS} seconds.') if expires_in > MAX_SESSION_COOKIE_DURATION_SECONDS: - raise ValueError('Illegal expiry duration: {0}. Duration must be at most {1} ' - 'seconds.'.format(expires_in, MAX_SESSION_COOKIE_DURATION_SECONDS)) + raise ValueError( + f'Illegal expiry duration: {expires_in}. Duration must be at most ' + f'{MAX_SESSION_COOKIE_DURATION_SECONDS} seconds.') - url = '{0}:createSessionCookie'.format(self.base_url) + url = f'{self.base_url}:createSessionCookie' payload = { 'idToken': id_token, 'validDuration': expires_in, @@ -234,11 +235,10 @@ def create_session_cookie(self, id_token, expires_in): body, http_resp = self.http_client.body_and_response('post', url, json=payload) except requests.exceptions.RequestException as error: raise _auth_utils.handle_auth_backend_error(error) - else: - if not body or not body.get('sessionCookie'): - raise _auth_utils.UnexpectedResponseError( - 'Failed to create session cookie.', http_response=http_resp) - return body.get('sessionCookie') + if not body or not body.get('sessionCookie'): + raise _auth_utils.UnexpectedResponseError( + 'Failed to create session cookie.', http_response=http_resp) + return body.get('sessionCookie') class CertificateFetchRequest(transport.Request): @@ -307,9 +307,9 @@ def __init__(self, **kwargs): self.cert_url = kwargs.pop('cert_url') self.issuer = kwargs.pop('issuer') if self.short_name[0].lower() in 'aeiou': - self.articled_short_name = 'an {0}'.format(self.short_name) + self.articled_short_name = f'an {self.short_name}' else: - self.articled_short_name = 'a {0}'.format(self.short_name) + self.articled_short_name = f'a {self.short_name}' self._invalid_token_error = kwargs.pop('invalid_token_error') self._expired_token_error = kwargs.pop('expired_token_error') @@ -318,20 +318,20 @@ def verify(self, token, request, clock_skew_seconds=0): token = token.encode('utf-8') if isinstance(token, str) else token if not isinstance(token, bytes) or not token: raise ValueError( - 'Illegal {0} provided: {1}. {0} must be a non-empty ' - 'string.'.format(self.short_name, token)) + f'Illegal {self.short_name} provided: {token}. {self.short_name} must be a ' + 'non-empty string.') if not self.project_id: raise ValueError( 'Failed to ascertain project ID from the credential or the environment. Project ' - 'ID is required to call {0}. Initialize the app with a credentials.Certificate ' - 'or set your Firebase project ID as an app option. Alternatively set the ' - 'GOOGLE_CLOUD_PROJECT environment variable.'.format(self.operation)) + f'ID is required to call {self.operation}. Initialize the app with a ' + 'credentials.Certificate or set your Firebase project ID as an app option. ' + 'Alternatively set the GOOGLE_CLOUD_PROJECT environment variable.') if clock_skew_seconds < 0 or clock_skew_seconds > 60: raise ValueError( - 'Illegal clock_skew_seconds value: {0}. Must be between 0 and 60, inclusive.' - .format(clock_skew_seconds)) + f'Illegal clock_skew_seconds value: {clock_skew_seconds}. Must be between 0 and 60' + ', inclusive.') header, payload = self._decode_unverified(token) issuer = payload.get('iss') @@ -340,52 +340,51 @@ def verify(self, token, request, clock_skew_seconds=0): expected_issuer = self.issuer + self.project_id project_id_match_msg = ( - 'Make sure the {0} comes from the same Firebase project as the service account used ' - 'to authenticate this SDK.'.format(self.short_name)) + f'Make sure the {self.short_name} comes from the same Firebase project as the service ' + 'account used to authenticate this SDK.') verify_id_token_msg = ( - 'See {0} for details on how to retrieve {1}.'.format(self.url, self.short_name)) + f'See {self.url} for details on how to retrieve {self.short_name}.') emulated = _auth_utils.is_emulated() error_message = None if audience == FIREBASE_AUDIENCE: error_message = ( - '{0} expects {1}, but was given a custom ' - 'token.'.format(self.operation, self.articled_short_name)) + f'{self.operation} expects {self.articled_short_name}, but was given a custom ' + 'token.') elif not emulated and not header.get('kid'): if header.get('alg') == 'HS256' and payload.get( 'v') == 0 and 'uid' in payload.get('d', {}): error_message = ( - '{0} expects {1}, but was given a legacy custom ' - 'token.'.format(self.operation, self.articled_short_name)) + f'{self.operation} expects {self.articled_short_name}, but was given a legacy ' + 'custom token.') else: - error_message = 'Firebase {0} has no "kid" claim.'.format(self.short_name) + error_message = f'Firebase {self.short_name} has no "kid" claim.' elif not emulated and header.get('alg') != 'RS256': error_message = ( - 'Firebase {0} has incorrect algorithm. Expected "RS256" but got ' - '"{1}". {2}'.format(self.short_name, header.get('alg'), verify_id_token_msg)) + f'Firebase {self.short_name} has incorrect algorithm. Expected "RS256" but got ' + f'"{header.get("alg")}". {verify_id_token_msg}') elif audience != self.project_id: error_message = ( - 'Firebase {0} has incorrect "aud" (audience) claim. Expected "{1}" but ' - 'got "{2}". {3} {4}'.format(self.short_name, self.project_id, audience, - project_id_match_msg, verify_id_token_msg)) + f'Firebase {self.short_name} has incorrect "aud" (audience) claim. Expected ' + f'"{self.project_id}" but got "{audience}". {project_id_match_msg} ' + f'{verify_id_token_msg}') elif issuer != expected_issuer: error_message = ( - 'Firebase {0} has incorrect "iss" (issuer) claim. Expected "{1}" but ' - 'got "{2}". {3} {4}'.format(self.short_name, expected_issuer, issuer, - project_id_match_msg, verify_id_token_msg)) + f'Firebase {self.short_name} has incorrect "iss" (issuer) claim. Expected ' + f'"{expected_issuer}" but got "{issuer}". {project_id_match_msg} ' + f'{verify_id_token_msg}') elif subject is None or not isinstance(subject, str): error_message = ( - 'Firebase {0} has no "sub" (subject) claim. ' - '{1}'.format(self.short_name, verify_id_token_msg)) + f'Firebase {self.short_name} has no "sub" (subject) claim. {verify_id_token_msg}') elif not subject: error_message = ( - 'Firebase {0} has an empty string "sub" (subject) claim. ' - '{1}'.format(self.short_name, verify_id_token_msg)) + f'Firebase {self.short_name} has an empty string "sub" (subject) claim. ' + f'{verify_id_token_msg}') elif len(subject) > 128: error_message = ( - 'Firebase {0} has a "sub" (subject) claim longer than 128 characters. ' - '{1}'.format(self.short_name, verify_id_token_msg)) + f'Firebase {self.short_name} has a "sub" (subject) claim longer than 128 ' + f'characters. {verify_id_token_msg}') if error_message: raise self._invalid_token_error(error_message) @@ -403,7 +402,7 @@ def verify(self, token, request, clock_skew_seconds=0): verified_claims['uid'] = verified_claims['sub'] return verified_claims except google.auth.exceptions.TransportError as error: - raise CertificateFetchError(str(error), cause=error) + raise CertificateFetchError(str(error), cause=error) from error except ValueError as error: if 'Token expired' in str(error): raise self._expired_token_error(str(error), cause=error) diff --git a/firebase_admin/_user_import.py b/firebase_admin/_user_import.py index 659a68701..7c7a9e70b 100644 --- a/firebase_admin/_user_import.py +++ b/firebase_admin/_user_import.py @@ -216,10 +216,10 @@ def provider_data(self): def provider_data(self, provider_data): if provider_data is not None: try: - if any([not isinstance(p, UserProvider) for p in provider_data]): + if any(not isinstance(p, UserProvider) for p in provider_data): raise ValueError('One or more provider data instances are invalid.') - except TypeError: - raise ValueError('provider_data must be iterable.') + except TypeError as err: + raise ValueError('provider_data must be iterable.') from err self._provider_data = provider_data @property diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index aa0dfb0a4..9a75b7a2e 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -128,9 +128,9 @@ class UserRecord(UserInfo): """Contains metadata associated with a Firebase user account.""" def __init__(self, data): - super(UserRecord, self).__init__() + super().__init__() if not isinstance(data, dict): - raise ValueError('Invalid data argument: {0}. Must be a dictionary.'.format(data)) + raise ValueError(f'Invalid data argument: {data}. Must be a dictionary.') if not data.get('localId'): raise ValueError('User ID must not be None or empty.') self._data = data @@ -452,9 +452,9 @@ class ProviderUserInfo(UserInfo): """Contains metadata regarding how a user is known by a particular identity provider.""" def __init__(self, data): - super(ProviderUserInfo, self).__init__() + super().__init__() if not isinstance(data, dict): - raise ValueError('Invalid data argument: {0}. Must be a dictionary.'.format(data)) + raise ValueError(f'Invalid data argument: {data}. Must be a dictionary.') if not data.get('rawId'): raise ValueError('User ID must not be None or empty.') self._data = data @@ -516,30 +516,30 @@ def encode_action_code_settings(settings): try: parsed = parse.urlparse(settings.url) if not parsed.netloc: - raise ValueError('Malformed dynamic action links url: "{0}".'.format(settings.url)) + raise ValueError(f'Malformed dynamic action links url: "{settings.url}".') parameters['continueUrl'] = settings.url - except Exception: - raise ValueError('Malformed dynamic action links url: "{0}".'.format(settings.url)) + except Exception as err: + raise ValueError(f'Malformed dynamic action links url: "{settings.url}".') from err # handle_code_in_app if settings.handle_code_in_app is not None: if not isinstance(settings.handle_code_in_app, bool): - raise ValueError('Invalid value provided for handle_code_in_app: {0}' - .format(settings.handle_code_in_app)) + raise ValueError( + f'Invalid value provided for handle_code_in_app: {settings.handle_code_in_app}') parameters['canHandleCodeInApp'] = settings.handle_code_in_app # dynamic_link_domain if settings.dynamic_link_domain is not None: if not isinstance(settings.dynamic_link_domain, str): - raise ValueError('Invalid value provided for dynamic_link_domain: {0}' - .format(settings.dynamic_link_domain)) + raise ValueError( + f'Invalid value provided for dynamic_link_domain: {settings.dynamic_link_domain}') parameters['dynamicLinkDomain'] = settings.dynamic_link_domain # ios_bundle_id if settings.ios_bundle_id is not None: if not isinstance(settings.ios_bundle_id, str): - raise ValueError('Invalid value provided for ios_bundle_id: {0}' - .format(settings.ios_bundle_id)) + raise ValueError( + f'Invalid value provided for ios_bundle_id: {settings.ios_bundle_id}') parameters['iOSBundleId'] = settings.ios_bundle_id # android_* attributes @@ -549,20 +549,21 @@ def encode_action_code_settings(settings): if settings.android_package_name is not None: if not isinstance(settings.android_package_name, str): - raise ValueError('Invalid value provided for android_package_name: {0}' - .format(settings.android_package_name)) + raise ValueError( + f'Invalid value provided for android_package_name: {settings.android_package_name}') parameters['androidPackageName'] = settings.android_package_name if settings.android_minimum_version is not None: if not isinstance(settings.android_minimum_version, str): - raise ValueError('Invalid value provided for android_minimum_version: {0}' - .format(settings.android_minimum_version)) + raise ValueError( + 'Invalid value provided for android_minimum_version: ' + f'{settings.android_minimum_version}') parameters['androidMinimumVersion'] = settings.android_minimum_version if settings.android_install_app is not None: if not isinstance(settings.android_install_app, bool): - raise ValueError('Invalid value provided for android_install_app: {0}' - .format(settings.android_install_app)) + raise ValueError( + f'Invalid value provided for android_install_app: {settings.android_install_app}') parameters['androidInstallApp'] = settings.android_install_app return parameters @@ -576,9 +577,9 @@ class UserManager: def __init__(self, http_client, project_id, tenant_id=None, url_override=None): self.http_client = http_client url_prefix = url_override or self.ID_TOOLKIT_URL - self.base_url = '{0}/projects/{1}'.format(url_prefix, project_id) + self.base_url = f'{url_prefix}/projects/{project_id}' if tenant_id: - self.base_url += '/tenants/{0}'.format(tenant_id) + self.base_url += f'/tenants/{tenant_id}' def get_user(self, **kwargs): """Gets the user data corresponding to the provided key.""" @@ -592,12 +593,12 @@ def get_user(self, **kwargs): key, key_type = kwargs.pop('phone_number'), 'phone number' payload = {'phoneNumber' : [_auth_utils.validate_phone(key, required=True)]} else: - raise TypeError('Unsupported keyword arguments: {0}.'.format(kwargs)) + raise TypeError(f'Unsupported keyword arguments: {kwargs}.') body, http_resp = self._make_request('post', '/accounts:lookup', json=payload) if not body or not body.get('users'): raise _auth_utils.UserNotFoundError( - 'No user record found for the provided {0}: {1}.'.format(key_type, key), + f'No user record found for the provided {key_type}: {key}.', http_response=http_resp) return body['users'][0] @@ -638,8 +639,7 @@ def get_users(self, identifiers): }) else: raise ValueError( - 'Invalid entry in "identifiers" list. Unsupported type: {}' - .format(type(identifier))) + f'Invalid entry in "identifiers" list. Unsupported type: {type(identifier)}') body, http_resp = self._make_request( 'post', '/accounts:lookup', json=payload) @@ -657,8 +657,7 @@ def list_users(self, page_token=None, max_results=MAX_LIST_USERS_RESULTS): raise ValueError('Max results must be an integer.') if max_results < 1 or max_results > MAX_LIST_USERS_RESULTS: raise ValueError( - 'Max results must be a positive integer less than ' - '{0}.'.format(MAX_LIST_USERS_RESULTS)) + f'Max results must be a positive integer less than {MAX_LIST_USERS_RESULTS}.') payload = {'maxResults': max_results} if page_token: @@ -734,7 +733,7 @@ def update_user(self, uid, display_name=None, email=None, phone_number=None, body, http_resp = self._make_request('post', '/accounts:update', json=payload) if not body or not body.get('localId'): raise _auth_utils.UnexpectedResponseError( - 'Failed to update user: {0}.'.format(uid), http_response=http_resp) + f'Failed to update user: {uid}.', http_response=http_resp) return body.get('localId') def delete_user(self, uid): @@ -743,7 +742,7 @@ def delete_user(self, uid): body, http_resp = self._make_request('post', '/accounts:delete', json={'localId' : uid}) if not body or not body.get('kind'): raise _auth_utils.UnexpectedResponseError( - 'Failed to delete user: {0}.'.format(uid), http_response=http_resp) + f'Failed to delete user: {uid}.', http_response=http_resp) def delete_users(self, uids, force_delete=False): """Deletes the users identified by the specified user ids. @@ -786,15 +785,15 @@ def import_users(self, users, hash_alg=None): try: if not users or len(users) > MAX_IMPORT_USERS_SIZE: raise ValueError( - 'Users must be a non-empty list with no more than {0} elements.'.format( - MAX_IMPORT_USERS_SIZE)) - if any([not isinstance(u, _user_import.ImportUserRecord) for u in users]): + 'Users must be a non-empty list with no more than ' + f'{MAX_IMPORT_USERS_SIZE} elements.') + if any(not isinstance(u, _user_import.ImportUserRecord) for u in users): raise ValueError('One or more user objects are invalid.') - except TypeError: - raise ValueError('users must be iterable') + except TypeError as err: + raise ValueError('users must be iterable') from err payload = {'users': [u.to_dict() for u in users]} - if any(['passwordHash' in u for u in payload['users']]): + if any('passwordHash' in u for u in payload['users']): if not isinstance(hash_alg, _user_import.UserImportHash): raise ValueError('A UserImportHash is required to import users with passwords.') payload.update(hash_alg.to_dict()) @@ -837,7 +836,7 @@ def generate_email_action_link(self, action_type, email, action_code_settings=No return body.get('oobLink') def _make_request(self, method, path, **kwargs): - url = '{0}{1}'.format(self.base_url, path) + url = f'{self.base_url}{path}' try: return self.http_client.body_and_response(method, url, **kwargs) except requests.exceptions.RequestException as error: diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index 765d11587..d0aca884b 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -93,8 +93,9 @@ def _get_initialized_app(app): 'initialized via the firebase module.') return app - raise ValueError('Illegal app argument. Argument must be of type ' - ' firebase_admin.App, but given "{0}".'.format(type(app))) + raise ValueError( + 'Illegal app argument. Argument must be of type firebase_admin.App, but given ' + f'"{type(app)}".') @@ -172,7 +173,7 @@ def handle_operation_error(error): """ if not isinstance(error, dict): return exceptions.UnknownError( - message='Unknown error while making a remote service call: {0}'.format(error), + message=f'Unknown error while making a remote service call: {error}', cause=error) rpc_code = error.get('code') @@ -217,15 +218,15 @@ def handle_requests_error(error, message=None, code=None): """ if isinstance(error, requests.exceptions.Timeout): return exceptions.DeadlineExceededError( - message='Timed out while making an API call: {0}'.format(error), + message=f'Timed out while making an API call: {error}', cause=error) if isinstance(error, requests.exceptions.ConnectionError): return exceptions.UnavailableError( - message='Failed to establish a connection: {0}'.format(error), + message=f'Failed to establish a connection: {error}', cause=error) if error.response is None: return exceptions.UnknownError( - message='Unknown error while making a remote service call: {0}'.format(error), + message=f'Unknown error while making a remote service call: {error}', cause=error) if not code: @@ -271,11 +272,11 @@ def handle_httpx_error(error: httpx.HTTPError, message=None, code=None) -> excep """ if isinstance(error, httpx.TimeoutException): return exceptions.DeadlineExceededError( - message='Timed out while making an API call: {0}'.format(error), + message=f'Timed out while making an API call: {error}', cause=error) if isinstance(error, httpx.ConnectError): return exceptions.UnavailableError( - message='Failed to establish a connection: {0}'.format(error), + message=f'Failed to establish a connection: {error}', cause=error) if isinstance(error, httpx.HTTPStatusError): print("printing status error", error) @@ -288,7 +289,7 @@ def handle_httpx_error(error: httpx.HTTPError, message=None, code=None) -> excep return err_type(message=message, cause=error, http_response=error.response) return exceptions.UnknownError( - message='Unknown error while making a remote service call: {0}'.format(error), + message=f'Unknown error while making a remote service call: {error}', cause=error) def _http_status_to_error_code(status): @@ -326,7 +327,7 @@ def _parse_platform_error(content, status_code): error_dict = data.get('error', {}) msg = error_dict.get('message') if not msg: - msg = 'Unexpected HTTP response with status: {0}; body: {1}'.format(status_code, content) + msg = f'Unexpected HTTP response with status: {status_code}; body: {content}' return error_dict, msg diff --git a/firebase_admin/app_check.py b/firebase_admin/app_check.py index 53686db3d..40d857f4e 100644 --- a/firebase_admin/app_check.py +++ b/firebase_admin/app_check.py @@ -84,7 +84,7 @@ def verify_token(self, token: str) -> Dict[str, Any]: except (InvalidTokenError, DecodeError) as exception: raise ValueError( f'Verifying App Check token failed. Error: {exception}' - ) + ) from exception verified_claims['app_id'] = verified_claims.get('sub') return verified_claims @@ -112,28 +112,28 @@ def _decode_and_verify(self, token: str, signing_key: str): algorithms=["RS256"], audience=self._scoped_project_id ) - except InvalidSignatureError: + except InvalidSignatureError as exception: raise ValueError( 'The provided App Check token has an invalid signature.' - ) - except InvalidAudienceError: + ) from exception + except InvalidAudienceError as exception: raise ValueError( 'The provided App Check token has an incorrect "aud" (audience) claim. ' f'Expected payload to include {self._scoped_project_id}.' - ) - except InvalidIssuerError: + ) from exception + except InvalidIssuerError as exception: raise ValueError( 'The provided App Check token has an incorrect "iss" (issuer) claim. ' f'Expected claim to include {self._APP_CHECK_ISSUER}' - ) - except ExpiredSignatureError: + ) from exception + except ExpiredSignatureError as exception: raise ValueError( 'The provided App Check token has expired.' - ) + ) from exception except InvalidTokenError as exception: raise ValueError( f'Decoding App Check token failed. Error: {exception}' - ) + ) from exception audience = payload.get('aud') if not isinstance(audience, list) or self._scoped_project_id not in audience: @@ -156,6 +156,6 @@ class _Validators: def check_string(cls, label: str, value: Any): """Checks if the given value is a string.""" if value is None: - raise ValueError('{0} "{1}" must be a non-empty string.'.format(label, value)) + raise ValueError(f'{label} "{value}" must be a non-empty string.') if not isinstance(value, str): - raise ValueError('{0} "{1}" must be a string.'.format(label, value)) + raise ValueError(f'{label} "{value}" must be a string.') diff --git a/firebase_admin/credentials.py b/firebase_admin/credentials.py index 750600280..7117b71a9 100644 --- a/firebase_admin/credentials.py +++ b/firebase_admin/credentials.py @@ -63,7 +63,7 @@ class _ExternalCredentials(Base): """A wrapper for google.auth.credentials.Credentials typed credential instances""" def __init__(self, credential: GoogleAuthCredentials): - super(_ExternalCredentials, self).__init__() + super().__init__() self._g_credential = credential def get_credential(self): @@ -92,26 +92,27 @@ def __init__(self, cert): IOError: If the specified certificate file doesn't exist or cannot be read. ValueError: If the specified certificate is invalid. """ - super(Certificate, self).__init__() + super().__init__() if _is_file_path(cert): - with open(cert) as json_file: + with open(cert, encoding='utf-8') as json_file: json_data = json.load(json_file) elif isinstance(cert, dict): json_data = cert else: raise ValueError( - 'Invalid certificate argument: "{0}". Certificate argument must be a file path, ' - 'or a dict containing the parsed file contents.'.format(cert)) + f'Invalid certificate argument: "{cert}". Certificate argument must be a file ' + 'path, or a dict containing the parsed file contents.') if json_data.get('type') != self._CREDENTIAL_TYPE: - raise ValueError('Invalid service account certificate. Certificate must contain a ' - '"type" field set to "{0}".'.format(self._CREDENTIAL_TYPE)) + raise ValueError( + 'Invalid service account certificate. Certificate must contain a ' + f'"type" field set to "{self._CREDENTIAL_TYPE}".') try: self._g_credential = service_account.Credentials.from_service_account_info( json_data, scopes=_scopes) except ValueError as error: - raise ValueError('Failed to initialize a certificate credential. ' - 'Caused by: "{0}"'.format(error)) + raise ValueError( + f'Failed to initialize a certificate credential. Caused by: "{error}"') from error @property def project_id(self): @@ -142,7 +143,7 @@ def __init__(self): The credentials will be lazily initialized when get_credential() or project_id() is called. See those methods for possible errors raised. """ - super(ApplicationDefault, self).__init__() + super().__init__() self._g_credential = None # Will be lazily-loaded via _load_credential(). def get_credential(self): @@ -193,20 +194,21 @@ def __init__(self, refresh_token): IOError: If the specified file doesn't exist or cannot be read. ValueError: If the refresh token configuration is invalid. """ - super(RefreshToken, self).__init__() + super().__init__() if _is_file_path(refresh_token): - with open(refresh_token) as json_file: + with open(refresh_token, encoding='utf-8') as json_file: json_data = json.load(json_file) elif isinstance(refresh_token, dict): json_data = refresh_token else: raise ValueError( - 'Invalid refresh token argument: "{0}". Refresh token argument must be a file ' - 'path, or a dict containing the parsed file contents.'.format(refresh_token)) + f'Invalid refresh token argument: "{refresh_token}". Refresh token argument must ' + 'be a file path, or a dict containing the parsed file contents.') if json_data.get('type') != self._CREDENTIAL_TYPE: - raise ValueError('Invalid refresh token configuration. JSON must contain a ' - '"type" field set to "{0}".'.format(self._CREDENTIAL_TYPE)) + raise ValueError( + 'Invalid refresh token configuration. JSON must contain a ' + f'"type" field set to "{self._CREDENTIAL_TYPE}".') self._g_credential = credentials.Credentials.from_authorized_user_info(json_data, _scopes) @property diff --git a/firebase_admin/db.py b/firebase_admin/db.py index 1dec98653..800cbf8e3 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -39,8 +39,10 @@ _DB_ATTRIBUTE = '_database' _INVALID_PATH_CHARACTERS = '[].?#$' _RESERVED_FILTERS = ('$key', '$value', '$priority') -_USER_AGENT = 'Firebase/HTTP/{0}/{1}.{2}/AdminPython'.format( - firebase_admin.__version__, sys.version_info.major, sys.version_info.minor) +_USER_AGENT = ( + f'Firebase/HTTP/{firebase_admin.__version__}/{sys.version_info.major}' + f'.{sys.version_info.minor}/AdminPython' +) _TRANSACTION_MAX_RETRIES = 25 _EMULATOR_HOST_ENV_VAR = 'FIREBASE_DATABASE_EMULATOR_HOST' @@ -72,10 +74,9 @@ def reference(path='/', app=None, url=None): def _parse_path(path): """Parses a path string into a set of segments.""" if not isinstance(path, str): - raise ValueError('Invalid path: "{0}". Path must be a string.'.format(path)) + raise ValueError(f'Invalid path: "{path}". Path must be a string.') if any(ch in path for ch in _INVALID_PATH_CHARACTERS): - raise ValueError( - 'Invalid path: "{0}". Path contains illegal characters.'.format(path)) + raise ValueError(f'Invalid path: "{path}". Path contains illegal characters.') return [seg for seg in path.split('/') if seg] @@ -184,11 +185,9 @@ def child(self, path): ValueError: If the child path is not a string, not well-formed or begins with '/'. """ if not path or not isinstance(path, str): - raise ValueError( - 'Invalid path argument: "{0}". Path must be a non-empty string.'.format(path)) + raise ValueError(f'Invalid path argument: "{path}". Path must be a non-empty string.') if path.startswith('/'): - raise ValueError( - 'Invalid path argument: "{0}". Child path must not start with "/"'.format(path)) + raise ValueError(f'Invalid path argument: "{path}". Child path must not start with "/"') full_path = self._pathurl + '/' + path return Reference(client=self._client, path=full_path) @@ -433,7 +432,7 @@ def order_by_child(self, path): ValueError: If the child path is not a string, not well-formed or None. """ if path in _RESERVED_FILTERS: - raise ValueError('Illegal child path: {0}'.format(path)) + raise ValueError(f'Illegal child path: {path}') return Query(order_by=path, client=self._client, pathurl=self._add_suffix()) def order_by_key(self): @@ -492,8 +491,8 @@ def __init__(self, **kwargs): raise ValueError('order_by field must be a non-empty string') if order_by not in _RESERVED_FILTERS: if order_by.startswith('/'): - raise ValueError('Invalid path argument: "{0}". Child path must not start ' - 'with "/"'.format(order_by)) + raise ValueError( + f'Invalid path argument: "{order_by}". Child path must not start with "/"') segments = _parse_path(order_by) order_by = '/'.join(segments) self._client = kwargs.pop('client') @@ -501,7 +500,7 @@ def __init__(self, **kwargs): self._order_by = order_by self._params = {'orderBy' : json.dumps(order_by)} if kwargs: - raise ValueError('Unexpected keyword arguments: {0}'.format(kwargs)) + raise ValueError(f'Unexpected keyword arguments: {kwargs}') def limit_to_first(self, limit): """Creates a query with limit, and anchors it to the start of the window. @@ -604,7 +603,7 @@ def equal_to(self, value): def _querystr(self): params = [] for key in sorted(self._params): - params.append('{0}={1}'.format(key, self._params[key])) + params.append(f'{key}={self._params[key]}') return '&'.join(params) def get(self): @@ -642,7 +641,7 @@ def __init__(self, results, order_by): self.dict_input = False entries = [_SortEntry(k, v, order_by) for k, v in enumerate(results)] else: - raise ValueError('Sorting not supported for "{0}" object.'.format(type(results))) + raise ValueError(f'Sorting not supported for "{type(results)}" object.') self.sort_entries = sorted(entries) def get(self): @@ -783,8 +782,8 @@ def __init__(self, app): if emulator_host: if '//' in emulator_host: raise ValueError( - 'Invalid {0}: "{1}". It must follow format "host:port".'.format( - _EMULATOR_HOST_ENV_VAR, emulator_host)) + f'Invalid {_EMULATOR_HOST_ENV_VAR}: "{emulator_host}". It must follow format ' + '"host:port".') self._emulator_host = emulator_host else: self._emulator_host = None @@ -796,14 +795,12 @@ def get_client(self, db_url=None): if not db_url or not isinstance(db_url, str): raise ValueError( - 'Invalid database URL: "{0}". Database URL must be a non-empty ' - 'URL string.'.format(db_url)) + f'Invalid database URL: "{db_url}". Database URL must be a non-empty URL string.') parsed_url = parse.urlparse(db_url) if not parsed_url.netloc: raise ValueError( - 'Invalid database URL: "{0}". Database URL must be a wellformed ' - 'URL string.'.format(db_url)) + f'Invalid database URL: "{db_url}". Database URL must be a wellformed URL string.') emulator_config = self._get_emulator_config(parsed_url) if emulator_config: @@ -813,7 +810,7 @@ def get_client(self, db_url=None): else: # Defer credential lookup until we are certain it's going to be prod connection. credential = self._credential.get_credential() - base_url = 'https://{0}'.format(parsed_url.netloc) + base_url = f'https://{parsed_url.netloc}' params = {} @@ -835,7 +832,7 @@ def _get_emulator_config(self, parsed_url): return EmulatorConfig(base_url, namespace) if self._emulator_host: # Emulator mode enabled via environment variable - base_url = 'http://{0}'.format(self._emulator_host) + base_url = f'http://{self._emulator_host}' namespace = parsed_url.netloc.split('.')[0] return EmulatorConfig(base_url, namespace) @@ -847,21 +844,23 @@ def _parse_emulator_url(cls, parsed_url): query_ns = parse.parse_qs(parsed_url.query).get('ns') if parsed_url.scheme != 'http' or (not query_ns or len(query_ns) != 1 or not query_ns[0]): raise ValueError( - 'Invalid database URL: "{0}". Database URL must be a valid URL to a ' - 'Firebase Realtime Database instance.'.format(parsed_url.geturl())) + f'Invalid database URL: "{parsed_url.geturl()}". Database URL must be a valid URL ' + 'to a Firebase Realtime Database instance.') namespace = query_ns[0] - base_url = '{0}://{1}'.format(parsed_url.scheme, parsed_url.netloc) + base_url = f'{parsed_url.scheme}://{parsed_url.netloc}' return base_url, namespace @classmethod def _get_auth_override(cls, app): + """Gets and validates the database auth override to be used.""" auth_override = app.options.get('databaseAuthVariableOverride', cls._DEFAULT_AUTH_OVERRIDE) if auth_override == cls._DEFAULT_AUTH_OVERRIDE or auth_override is None: return auth_override if not isinstance(auth_override, dict): - raise ValueError('Invalid databaseAuthVariableOverride option: "{0}". Override ' - 'value must be a dict or None.'.format(auth_override)) + raise ValueError( + f'Invalid databaseAuthVariableOverride option: "{auth_override}". Override ' + 'value must be a dict or None.') return auth_override @@ -916,7 +915,7 @@ def request(self, method, url, **kwargs): Raises: FirebaseError: If an error occurs while making the HTTP call. """ - query = '&'.join('{0}={1}'.format(key, self.params[key]) for key in self.params) + query = '&'.join(f'{key}={value}' for key, value in self.params.items()) extra_params = kwargs.get('params') if extra_params: if query: @@ -926,7 +925,7 @@ def request(self, method, url, **kwargs): kwargs['params'] = query try: - return super(_Client, self).request(method, url, **kwargs) + return super().request(method, url, **kwargs) except requests.exceptions.RequestException as error: raise _Client.handle_rtdb_error(error) @@ -961,6 +960,6 @@ def _extract_error_message(cls, response): pass if not message: - message = 'Unexpected response from database: {0}'.format(response.content.decode()) + message = f'Unexpected response from database: {response.content.decode()}' return message diff --git a/firebase_admin/functions.py b/firebase_admin/functions.py index 48ce62a76..6db0fbb42 100644 --- a/firebase_admin/functions.py +++ b/firebase_admin/functions.py @@ -48,7 +48,7 @@ _FUNCTIONS_HEADERS = { 'X-GOOG-API-FORMAT-VERSION': '2', - 'X-FIREBASE-CLIENT': 'fire-admin-python/{0}'.format(firebase_admin.__version__), + 'X-FIREBASE-CLIENT': f'fire-admin-python/{firebase_admin.__version__}', } # Default canonical location ID of the task queue. @@ -307,9 +307,9 @@ class _Validators: def check_non_empty_string(cls, label: str, value: Any): """Checks if given value is a non-empty string and throws error if not.""" if not isinstance(value, str): - raise ValueError('{0} "{1}" must be a string.'.format(label, value)) + raise ValueError(f'{label} "{value}" must be a string.') if value == '': - raise ValueError('{0} "{1}" must be a non-empty string.'.format(label, value)) + raise ValueError(f'{label} "{value}" must be a non-empty string.') @classmethod def is_non_empty_string(cls, value: Any): diff --git a/firebase_admin/instance_id.py b/firebase_admin/instance_id.py index 604158d9c..812daf40b 100644 --- a/firebase_admin/instance_id.py +++ b/firebase_admin/instance_id.py @@ -81,7 +81,7 @@ def __init__(self, app): def delete_instance_id(self, instance_id): if not isinstance(instance_id, str) or not instance_id: raise ValueError('Instance ID must be a non-empty string.') - path = 'project/{0}/instanceId/{1}'.format(self._project_id, instance_id) + path = f'project/{self._project_id}/instanceId/{instance_id}' try: self._client.request('delete', path) except requests.exceptions.RequestException as error: @@ -94,6 +94,6 @@ def _extract_message(self, instance_id, error): status = error.response.status_code msg = self.error_codes.get(status) if msg: - return 'Instance ID "{0}": {1}'.format(instance_id, msg) + return f'Instance ID "{instance_id}": {msg}' - return 'Instance ID "{0}": {1}'.format(instance_id, error) + return f'Instance ID "{instance_id}": {error}' diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index 99dc93a67..749044436 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -18,21 +18,16 @@ from typing import Any, Callable, Dict, List, Optional, cast import concurrent.futures import json -import warnings import asyncio import logging import requests import httpx -from googleapiclient import http -from googleapiclient import _auth - import firebase_admin from firebase_admin import ( _http_client, _messaging_encoder, _messaging_utils, - _gapic_utils, _utils, exceptions, App @@ -72,8 +67,6 @@ 'WebpushNotificationAction', 'send', - 'send_all', - 'send_multicast', 'send_each', 'send_each_async', 'send_each_for_multicast', @@ -246,64 +239,6 @@ def send_each_for_multicast(multicast_message, dry_run=False, app=None): ) for token in multicast_message.tokens] return _get_messaging_service(app).send_each(messages, dry_run) -def send_all(messages, dry_run=False, app=None): - """Sends the given list of messages via Firebase Cloud Messaging as a single batch. - - If the ``dry_run`` mode is enabled, the message will not be actually delivered to the - recipients. Instead, FCM performs all the usual validations and emulates the send operation. - - Args: - messages: A list of ``messaging.Message`` instances. - dry_run: A boolean indicating whether to run the operation in dry run mode (optional). - app: An App instance (optional). - - Returns: - BatchResponse: A ``messaging.BatchResponse`` instance. - - Raises: - FirebaseError: If an error occurs while sending the message to the FCM service. - ValueError: If the input arguments are invalid. - - send_all() is deprecated. Use send_each() instead. - """ - warnings.warn('send_all() is deprecated. Use send_each() instead.', DeprecationWarning) - return _get_messaging_service(app).send_all(messages, dry_run) - -def send_multicast(multicast_message, dry_run=False, app=None): - """Sends the given mutlicast message to all tokens via Firebase Cloud Messaging (FCM). - - If the ``dry_run`` mode is enabled, the message will not be actually delivered to the - recipients. Instead, FCM performs all the usual validations and emulates the send operation. - - Args: - multicast_message: An instance of ``messaging.MulticastMessage``. - dry_run: A boolean indicating whether to run the operation in dry run mode (optional). - app: An App instance (optional). - - Returns: - BatchResponse: A ``messaging.BatchResponse`` instance. - - Raises: - FirebaseError: If an error occurs while sending the message to the FCM service. - ValueError: If the input arguments are invalid. - - send_multicast() is deprecated. Use send_each_for_multicast() instead. - """ - warnings.warn('send_multicast() is deprecated. Use send_each_for_multicast() instead.', - DeprecationWarning) - if not isinstance(multicast_message, MulticastMessage): - raise ValueError('Message must be an instance of messaging.MulticastMessage class.') - messages = [Message( - data=multicast_message.data, - notification=multicast_message.notification, - android=multicast_message.android, - webpush=multicast_message.webpush, - apns=multicast_message.apns, - fcm_options=multicast_message.fcm_options, - token=token - ) for token in multicast_message.tokens] - return _get_messaging_service(app).send_all(messages, dry_run) - def subscribe_to_topic(tokens, topic, app=None): """Subscribes a list of registration tokens to an FCM topic. @@ -366,7 +301,7 @@ class TopicManagementResponse: def __init__(self, resp): if not isinstance(resp, dict) or 'results' not in resp: - raise ValueError('Unexpected topic management response: {0}.'.format(resp)) + raise ValueError(f'Unexpected topic management response: {resp}.') self._success_count = 0 self._failure_count = 0 self._errors = [] @@ -465,14 +400,13 @@ def __init__(self, app: App) -> None: self._fcm_url = _MessagingService.FCM_URL.format(project_id) self._fcm_headers = { 'X-GOOG-API-FORMAT-VERSION': '2', - 'X-FIREBASE-CLIENT': 'fire-admin-python/{0}'.format(firebase_admin.__version__), + 'X-FIREBASE-CLIENT': f'fire-admin-python/{firebase_admin.__version__}', } timeout = app.options.get('httpTimeout', _http_client.DEFAULT_TIMEOUT_SECONDS) self._credential = app.credential.get_credential() self._client = _http_client.JsonHttpClient(credential=self._credential, timeout=timeout) self._async_client = _http_client.HttpxAsyncClient( credential=self._credential, timeout=timeout) - self._build_transport = _auth.authorized_http @classmethod def encode_message(cls, message): @@ -492,8 +426,7 @@ def send(self, message: Message, dry_run: bool = False) -> str: ) except requests.exceptions.RequestException as error: raise self._handle_fcm_error(error) - else: - return cast(str, resp['name']) + return cast(str, resp['name']) def send_each(self, messages: List[Message], dry_run: bool = False) -> BatchResponse: """Sends the given messages to FCM via the FCM v1 API.""" @@ -511,17 +444,16 @@ def send_data(data): json=data) except requests.exceptions.RequestException as exception: return SendResponse(resp=None, exception=self._handle_fcm_error(exception)) - else: - return SendResponse(resp, exception=None) + return SendResponse(resp, exception=None) message_data = [self._message_data(message, dry_run) for message in messages] try: with concurrent.futures.ThreadPoolExecutor(max_workers=len(message_data)) as executor: - responses = [resp for resp in executor.map(send_data, message_data)] + responses = list(executor.map(send_data, message_data)) return BatchResponse(responses) except Exception as error: raise exceptions.UnknownError( - message='Unknown error while making remote service calls: {0}'.format(error), + message=f'Unknown error while making remote service calls: {error}', cause=error) async def send_each_async(self, messages: List[Message], dry_run: bool = True) -> BatchResponse: @@ -543,8 +475,7 @@ async def send_data(data): # Catch errors caused by the requests library during authorization except requests.exceptions.RequestException as exception: return SendResponse(resp=None, exception=self._handle_fcm_error(exception)) - else: - return SendResponse(resp.json(), exception=None) + return SendResponse(resp.json(), exception=None) message_data = [self._message_data(message, dry_run) for message in messages] try: @@ -552,48 +483,9 @@ async def send_data(data): return BatchResponse(responses) except Exception as error: raise exceptions.UnknownError( - message='Unknown error while making remote service calls: {0}'.format(error), + message=f'Unknown error while making remote service calls: {error}', cause=error) - - def send_all(self, messages, dry_run=False): - """Sends the given messages to FCM via the batch API.""" - if not isinstance(messages, list): - raise ValueError('messages must be a list of messaging.Message instances.') - if len(messages) > 500: - raise ValueError('messages must not contain more than 500 elements.') - - responses = [] - - def batch_callback(_, response, error): - exception = None - if error: - exception = self._handle_batch_error(error) - send_response = SendResponse(response, exception) - responses.append(send_response) - - batch = http.BatchHttpRequest( - callback=batch_callback, batch_uri=_MessagingService.FCM_BATCH_URL) - transport = self._build_transport(self._credential) - for message in messages: - body = json.dumps(self._message_data(message, dry_run)) - req = http.HttpRequest( - http=transport, - postproc=self._postproc, - uri=self._fcm_url, - method='POST', - body=body, - headers=self._fcm_headers - ) - batch.add(req) - - try: - batch.execute() - except Exception as error: - raise self._handle_batch_error(error) - else: - return BatchResponse(responses) - def make_topic_management_request(self, tokens, topic, operation): """Invokes the IID service for topic management functionality.""" if isinstance(tokens, str): @@ -607,12 +499,12 @@ def make_topic_management_request(self, tokens, topic, operation): if not isinstance(topic, str) or not topic: raise ValueError('Topic must be a non-empty string.') if not topic.startswith('/topics/'): - topic = '/topics/{0}'.format(topic) + topic = f'/topics/{topic}' data = { 'to': topic, 'registration_tokens': tokens, } - url = '{0}/{1}'.format(_MessagingService.IID_URL, operation) + url = f'{_MessagingService.IID_URL}/{operation}' try: resp = self._client.body( 'post', @@ -622,8 +514,7 @@ def make_topic_management_request(self, tokens, topic, operation): ) except requests.exceptions.RequestException as error: raise self._handle_iid_error(error) - else: - return TopicManagementResponse(resp) + return TopicManagementResponse(resp) def _message_data(self, message, dry_run): data = {'message': _MessagingService.encode_message(message)} @@ -663,18 +554,15 @@ def _handle_iid_error(self, error): code = data.get('error') msg = None if code: - msg = 'Error while calling the IID service: {0}'.format(code) + msg = f'Error while calling the IID service: {code}' else: - msg = 'Unexpected HTTP response with status: {0}; body: {1}'.format( - error.response.status_code, error.response.content.decode()) + msg = ( + f'Unexpected HTTP response with status: {error.response.status_code}; body: ' + f'{error.response.content.decode()}' + ) return _utils.handle_requests_error(error, msg) - def _handle_batch_error(self, error): - """Handles errors received from the googleapiclient while making batch requests.""" - return _gapic_utils.handle_platform_error_from_googleapiclient( - error, _MessagingService._build_fcm_error_googleapiclient) - def close(self) -> None: asyncio.run(self._async_client.aclose()) @@ -683,6 +571,7 @@ def _build_fcm_error_requests(cls, error, message, error_dict): """Parses an error response from the FCM API and creates a FCM-specific exception if appropriate.""" exc_type = cls._build_fcm_error(error_dict) + # pylint: disable=not-callable return exc_type(message, cause=error, http_response=error.response) if exc_type else None @classmethod @@ -696,18 +585,12 @@ def _build_fcm_error_httpx( appropriate.""" exc_type = cls._build_fcm_error(error_dict) if isinstance(error, httpx.HTTPStatusError): + # pylint: disable=not-callable return exc_type( message, cause=error, http_response=error.response) if exc_type else None + # pylint: disable=not-callable return exc_type(message, cause=error) if exc_type else None - - @classmethod - def _build_fcm_error_googleapiclient(cls, error, message, error_dict, http_response): - """Parses an error response from the FCM API and creates a FCM-specific exception if - appropriate.""" - exc_type = cls._build_fcm_error(error_dict) - return exc_type(message, cause=error, http_response=http_response) if exc_type else None - @classmethod def _build_fcm_error( cls, diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index 98bdbb56a..3a77dd05f 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -24,7 +24,6 @@ import time import os from urllib import parse -import warnings import requests @@ -33,14 +32,14 @@ from firebase_admin import _utils from firebase_admin import exceptions -# pylint: disable=import-error,no-name-in-module +# pylint: disable=import-error,no-member try: from firebase_admin import storage _GCS_ENABLED = True except ImportError: _GCS_ENABLED = False -# pylint: disable=import-error,no-name-in-module +# pylint: disable=import-error,no-member try: import tensorflow as tf _TF_ENABLED = True @@ -54,9 +53,6 @@ _TAG_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,32}$') _GCS_TFLITE_URI_PATTERN = re.compile( r'^gs://(?P[a-z0-9_.-]{3,63})/(?P.+)$') -_AUTO_ML_MODEL_PATTERN = re.compile( - r'^projects/(?P[a-z0-9-]{6,30})/locations/(?P[^/]+)/' + - r'models/(?P[A-Za-z0-9]+)$') _RESOURCE_NAME_PATTERN = re.compile( r'^projects/(?P[a-z0-9-]{6,30})/models/(?P[A-Za-z0-9_-]{1,60})$') _OPERATION_NAME_PATTERN = re.compile( @@ -388,11 +384,6 @@ def _init_model_source(data): gcs_tflite_uri = data.pop('gcsTfliteUri', None) if gcs_tflite_uri: return TFLiteGCSModelSource(gcs_tflite_uri=gcs_tflite_uri) - auto_ml_model = data.pop('automlModel', None) - if auto_ml_model: - warnings.warn('AutoML model support is deprecated and will be removed in the next ' - 'major version.', DeprecationWarning) - return TFLiteAutoMlSource(auto_ml_model=auto_ml_model) return None @property @@ -516,8 +507,8 @@ def _assert_tf_enabled(): raise ImportError('Failed to import the tensorflow library for Python. Make sure ' 'to install the tensorflow module.') if not tf.version.VERSION.startswith('1.') and not tf.version.VERSION.startswith('2.'): - raise ImportError('Expected tensorflow version 1.x or 2.x, but found {0}' - .format(tf.version.VERSION)) + raise ImportError( + f'Expected tensorflow version 1.x or 2.x, but found {tf.version.VERSION}') @staticmethod def _tf_convert_from_saved_model(saved_model_dir): @@ -606,42 +597,6 @@ def as_dict(self, for_upload=False): return {'gcsTfliteUri': self._gcs_tflite_uri} - -class TFLiteAutoMlSource(TFLiteModelSource): - """TFLite model source representing a tflite model created with AutoML. - - AutoML model support is deprecated and will be removed in the next major version. - """ - - def __init__(self, auto_ml_model, app=None): - warnings.warn('AutoML model support is deprecated and will be removed in the next ' - 'major version.', DeprecationWarning) - self._app = app - self.auto_ml_model = auto_ml_model - - def __eq__(self, other): - if isinstance(other, self.__class__): - return self.auto_ml_model == other.auto_ml_model - return False - - def __ne__(self, other): - return not self.__eq__(other) - - @property - def auto_ml_model(self): - """Resource name of the model, created by the AutoML API or Cloud console.""" - return self._auto_ml_model - - @auto_ml_model.setter - def auto_ml_model(self, auto_ml_model): - self._auto_ml_model = _validate_auto_ml_model(auto_ml_model) - - def as_dict(self, for_upload=False): - """Returns a serializable representation of the object.""" - # Upload is irrelevant for auto_ml models - return {'automlModel': self._auto_ml_model} - - class ListModelsPage: """Represents a page of models in a Firebase project. @@ -721,7 +676,7 @@ def __init__(self, current_page): self._current_page = current_page self._index = 0 - def next(self): + def __next__(self): if self._index == len(self._current_page.models): if self._current_page.has_next_page: self._current_page = self._current_page.get_next_page() @@ -732,9 +687,6 @@ def next(self): return result raise StopIteration - def __next__(self): - return self.next() - def __iter__(self): return self @@ -789,11 +741,6 @@ def _validate_gcs_tflite_uri(uri): raise ValueError('GCS TFLite URI format is invalid.') return uri -def _validate_auto_ml_model(model): - if not _AUTO_ML_MODEL_PATTERN.match(model): - raise ValueError('Model resource name format is invalid.') - return model - def _validate_model_format(model_format): if not isinstance(model_format, ModelFormat): @@ -813,8 +760,8 @@ def _validate_page_size(page_size): # Specifically type() to disallow boolean which is a subtype of int raise TypeError('Page size must be a number or None.') if page_size < 1 or page_size > _MAX_PAGE_SIZE: - raise ValueError('Page size must be a positive integer between ' - '1 and {0}'.format(_MAX_PAGE_SIZE)) + raise ValueError( + f'Page size must be a positive integer between 1 and {_MAX_PAGE_SIZE}') def _validate_page_token(page_token): @@ -839,7 +786,7 @@ def __init__(self, app): 'projectId option, or use service account credentials.') self._project_url = _MLService.PROJECT_URL.format(self._project_id) ml_headers = { - 'X-FIREBASE-CLIENT': 'fire-admin-python/{0}'.format(firebase_admin.__version__), + 'X-FIREBASE-CLIENT': f'fire-admin-python/{firebase_admin.__version__}', } self._client = _http_client.JsonHttpClient( credential=app.credential.get_credential(), @@ -936,9 +883,9 @@ def create_model(self, model): def update_model(self, model, update_mask=None): _validate_model(model, update_mask) - path = 'models/{0}'.format(model.model_id) + path = f'models/{model.model_id}' if update_mask is not None: - path = path + '?updateMask={0}'.format(update_mask) + path = path + f'?updateMask={update_mask}' try: return self.handle_operation( self._client.body('patch', url=path, json=model.as_dict(for_upload=True))) @@ -947,7 +894,7 @@ def update_model(self, model, update_mask=None): def set_published(self, model_id, publish): _validate_model_id(model_id) - model_name = 'projects/{0}/models/{1}'.format(self._project_id, model_id) + model_name = f'projects/{self._project_id}/models/{model_id}' model = Model.from_dict({ 'name': model_name, 'state': { @@ -959,7 +906,7 @@ def set_published(self, model_id, publish): def get_model(self, model_id): _validate_model_id(model_id) try: - return self._client.body('get', url='models/{0}'.format(model_id)) + return self._client.body('get', url=f'models/{model_id}') except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) @@ -987,6 +934,6 @@ def list_models(self, list_filter, page_size, page_token): def delete_model(self, model_id): _validate_model_id(model_id) try: - self._client.body('delete', url='models/{0}'.format(model_id)) + self._client.body('delete', url=f'models/{model_id}') except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) diff --git a/firebase_admin/project_management.py b/firebase_admin/project_management.py index ed292b80f..73c100d3a 100644 --- a/firebase_admin/project_management.py +++ b/firebase_admin/project_management.py @@ -118,13 +118,13 @@ def create_ios_app(bundle_id, display_name=None, app=None): def _check_is_string_or_none(obj, field_name): if obj is None or isinstance(obj, str): return obj - raise ValueError('{0} must be a string.'.format(field_name)) + raise ValueError(f'{field_name} must be a string.') def _check_is_nonempty_string(obj, field_name): if isinstance(obj, str) and obj: return obj - raise ValueError('{0} must be a non-empty string.'.format(field_name)) + raise ValueError(f'{field_name} must be a non-empty string.') def _check_is_nonempty_string_or_none(obj, field_name): @@ -135,7 +135,7 @@ def _check_is_nonempty_string_or_none(obj, field_name): def _check_not_none(obj, field_name): if obj is None: - raise ValueError('{0} cannot be None.'.format(field_name)) + raise ValueError(f'{field_name} cannot be None.') return obj @@ -338,7 +338,7 @@ class AndroidAppMetadata(_AppMetadata): def __init__(self, package_name, name, app_id, display_name, project_id): """Clients should not instantiate this class directly.""" - super(AndroidAppMetadata, self).__init__(name, app_id, display_name, project_id) + super().__init__(name, app_id, display_name, project_id) self._package_name = _check_is_nonempty_string(package_name, 'package_name') @property @@ -347,7 +347,7 @@ def package_name(self): return self._package_name def __eq__(self, other): - return (super(AndroidAppMetadata, self).__eq__(other) and + return (super().__eq__(other) and self.package_name == other.package_name) def __ne__(self, other): @@ -363,7 +363,7 @@ class IOSAppMetadata(_AppMetadata): def __init__(self, bundle_id, name, app_id, display_name, project_id): """Clients should not instantiate this class directly.""" - super(IOSAppMetadata, self).__init__(name, app_id, display_name, project_id) + super().__init__(name, app_id, display_name, project_id) self._bundle_id = _check_is_nonempty_string(bundle_id, 'bundle_id') @property @@ -372,7 +372,7 @@ def bundle_id(self): return self._bundle_id def __eq__(self, other): - return super(IOSAppMetadata, self).__eq__(other) and self.bundle_id == other.bundle_id + return super().__eq__(other) and self.bundle_id == other.bundle_id def __ne__(self, other): return not self.__eq__(other) @@ -477,7 +477,7 @@ def __init__(self, app): 'set the projectId option, or use service account credentials. Alternatively, set ' 'the GOOGLE_CLOUD_PROJECT environment variable.') self._project_id = project_id - version_header = 'Python/Admin/{0}'.format(firebase_admin.__version__) + version_header = f'Python/Admin/{firebase_admin.__version__}' timeout = app.options.get('httpTimeout', _http_client.DEFAULT_TIMEOUT_SECONDS) self._client = _http_client.JsonHttpClient( credential=app.credential.get_credential(), @@ -502,7 +502,7 @@ def get_ios_app_metadata(self, app_id): def _get_app_metadata(self, platform_resource_name, identifier_name, metadata_class, app_id): """Retrieves detailed information about an Android or iOS app.""" _check_is_nonempty_string(app_id, 'app_id') - path = '/v1beta1/projects/-/{0}/{1}'.format(platform_resource_name, app_id) + path = f'/v1beta1/projects/-/{platform_resource_name}/{app_id}' response = self._make_request('get', path) return metadata_class( response[identifier_name], @@ -525,8 +525,7 @@ def set_ios_app_display_name(self, app_id, new_display_name): def _set_display_name(self, app_id, new_display_name, platform_resource_name): """Sets the display name of an Android or iOS app.""" - path = '/v1beta1/projects/-/{0}/{1}?updateMask=displayName'.format( - platform_resource_name, app_id) + path = f'/v1beta1/projects/-/{platform_resource_name}/{app_id}?updateMask=displayName' request_body = {'displayName': new_display_name} self._make_request('patch', path, json=request_body) @@ -542,10 +541,10 @@ def list_ios_apps(self): def _list_apps(self, platform_resource_name, app_class): """Lists all the Android or iOS apps within the Firebase project.""" - path = '/v1beta1/projects/{0}/{1}?pageSize={2}'.format( - self._project_id, - platform_resource_name, - _ProjectManagementService.MAXIMUM_LIST_APPS_PAGE_SIZE) + path = ( + f'/v1beta1/projects/{self._project_id}/{platform_resource_name}?pageSize=' + f'{_ProjectManagementService.MAXIMUM_LIST_APPS_PAGE_SIZE}' + ) response = self._make_request('get', path) apps_list = [] while True: @@ -557,11 +556,11 @@ def _list_apps(self, platform_resource_name, app_class): if not next_page_token: break # Retrieve the next page of apps. - path = '/v1beta1/projects/{0}/{1}?pageToken={2}&pageSize={3}'.format( - self._project_id, - platform_resource_name, - next_page_token, - _ProjectManagementService.MAXIMUM_LIST_APPS_PAGE_SIZE) + path = ( + f'/v1beta1/projects/{self._project_id}/{platform_resource_name}' + f'?pageToken={next_page_token}' + f'&pageSize={_ProjectManagementService.MAXIMUM_LIST_APPS_PAGE_SIZE}' + ) response = self._make_request('get', path) return apps_list @@ -590,7 +589,7 @@ def _create_app( app_class): """Creates an Android or iOS app.""" _check_is_string_or_none(display_name, 'display_name') - path = '/v1beta1/projects/{0}/{1}'.format(self._project_id, platform_resource_name) + path = f'/v1beta1/projects/{self._project_id}/{platform_resource_name}' request_body = {identifier_name: identifier} if display_name: request_body['displayName'] = display_name @@ -606,7 +605,7 @@ def _poll_app_creation(self, operation_name): _ProjectManagementService.POLL_EXPONENTIAL_BACKOFF_FACTOR, current_attempt) wait_time_seconds = delay_factor * _ProjectManagementService.POLL_BASE_WAIT_TIME_SECONDS time.sleep(wait_time_seconds) - path = '/v1/{0}'.format(operation_name) + path = f'/v1/{operation_name}' poll_response, http_response = self._body_and_response('get', path) done = poll_response.get('done') if done: @@ -629,20 +628,20 @@ def get_ios_app_config(self, app_id): platform_resource_name=_ProjectManagementService.IOS_APPS_RESOURCE_NAME, app_id=app_id) def _get_app_config(self, platform_resource_name, app_id): - path = '/v1beta1/projects/-/{0}/{1}/config'.format(platform_resource_name, app_id) + path = f'/v1beta1/projects/-/{platform_resource_name}/{app_id}/config' response = self._make_request('get', path) # In Python 2.7, the base64 module works with strings, while in Python 3, it works with # bytes objects. This line works in both versions. return base64.standard_b64decode(response['configFileContents']).decode(encoding='utf-8') def get_sha_certificates(self, app_id): - path = '/v1beta1/projects/-/androidApps/{0}/sha'.format(app_id) + path = f'/v1beta1/projects/-/androidApps/{app_id}/sha' response = self._make_request('get', path) cert_list = response.get('certificates') or [] return [SHACertificate(sha_hash=cert['shaHash'], name=cert['name']) for cert in cert_list] def add_sha_certificate(self, app_id, certificate_to_add): - path = '/v1beta1/projects/-/androidApps/{0}/sha'.format(app_id) + path = f'/v1beta1/projects/-/androidApps/{app_id}/sha' sha_hash = _check_not_none(certificate_to_add, 'certificate_to_add').sha_hash cert_type = certificate_to_add.cert_type request_body = {'shaHash': sha_hash, 'certType': cert_type} @@ -650,7 +649,7 @@ def add_sha_certificate(self, app_id, certificate_to_add): def delete_sha_certificate(self, certificate_to_delete): name = _check_not_none(certificate_to_delete, 'certificate_to_delete').name - path = '/v1beta1/{0}'.format(name) + path = f'/v1beta1/{name}' self._make_request('delete', path) def _make_request(self, method, url, json=None): diff --git a/firebase_admin/remote_config.py b/firebase_admin/remote_config.py index 943141ccf..880804d3d 100644 --- a/firebase_admin/remote_config.py +++ b/firebase_admin/remote_config.py @@ -251,7 +251,7 @@ def __init__(self, app): self._project_id = app.project_id app_credential = app.credential.get_credential() rc_headers = { - 'X-FIREBASE-CLIENT': 'fire-admin-python/{0}'.format(firebase_admin.__version__), } + 'X-FIREBASE-CLIENT': f'fire-admin-python/{firebase_admin.__version__}', } timeout = app.options.get('httpTimeout', _http_client.DEFAULT_TIMEOUT_SECONDS) self._client = _http_client.JsonHttpClient(credential=app_credential, @@ -268,14 +268,12 @@ async def get_server_template(self): 'get', self._get_url()) except requests.exceptions.RequestException as error: raise self._handle_remote_config_error(error) - else: - template_data['etag'] = headers.get('etag') - return _ServerTemplateData(template_data) + template_data['etag'] = headers.get('etag') + return _ServerTemplateData(template_data) def _get_url(self): """Returns project prefix for url, in the format of /v1/projects/${projectId}""" - return "/v1/projects/{0}/namespaces/firebase-server/serverRemoteConfig".format( - self._project_id) + return f"/v1/projects/{self._project_id}/namespaces/firebase-server/serverRemoteConfig" @classmethod def _handle_remote_config_error(cls, error: Any): diff --git a/firebase_admin/storage.py b/firebase_admin/storage.py index b6084842a..d2f004be6 100644 --- a/firebase_admin/storage.py +++ b/firebase_admin/storage.py @@ -21,9 +21,9 @@ # pylint: disable=import-error,no-name-in-module try: from google.cloud import storage -except ImportError: +except ImportError as exception: raise ImportError('Failed to import the Cloud Storage library for Python. Make sure ' - 'to install the "google-cloud-storage" module.') + 'to install the "google-cloud-storage" module.') from exception from firebase_admin import _utils @@ -82,6 +82,6 @@ def bucket(self, name=None): 'name explicitly when calling the storage.bucket() function.') if not bucket_name or not isinstance(bucket_name, str): raise ValueError( - 'Invalid storage bucket name: "{0}". Bucket name must be a non-empty ' - 'string.'.format(bucket_name)) + f'Invalid storage bucket name: "{bucket_name}". Bucket name must be a non-empty ' + 'string.') return self._client.bucket(bucket_name) diff --git a/firebase_admin/tenant_mgt.py b/firebase_admin/tenant_mgt.py index 8c53e30a1..9e713d988 100644 --- a/firebase_admin/tenant_mgt.py +++ b/firebase_admin/tenant_mgt.py @@ -205,7 +205,7 @@ class Tenant: def __init__(self, data): if not isinstance(data, dict): - raise ValueError('Invalid data argument in Tenant constructor: {0}'.format(data)) + raise ValueError(f'Invalid data argument in Tenant constructor: {data}') if not 'name' in data: raise ValueError('Tenant response missing required keys.') @@ -236,8 +236,8 @@ class _TenantManagementService: def __init__(self, app): credential = app.credential.get_credential() - version_header = 'Python/Admin/{0}'.format(firebase_admin.__version__) - base_url = '{0}/projects/{1}'.format(self.TENANT_MGT_URL, app.project_id) + version_header = f'Python/Admin/{firebase_admin.__version__}' + base_url = f'{self.TENANT_MGT_URL}/projects/{app.project_id}' self.app = app self.client = _http_client.JsonHttpClient( credential=credential, base_url=base_url, headers={'X-Client-Version': version_header}) @@ -248,7 +248,7 @@ def auth_for_tenant(self, tenant_id): """Gets an Auth Client instance scoped to the given tenant ID.""" if not isinstance(tenant_id, str) or not tenant_id: raise ValueError( - 'Invalid tenant ID: {0}. Tenant ID must be a non-empty string.'.format(tenant_id)) + f'Invalid tenant ID: {tenant_id}. Tenant ID must be a non-empty string.') with self.lock: if tenant_id in self.tenant_clients: @@ -262,14 +262,13 @@ def get_tenant(self, tenant_id): """Gets the tenant corresponding to the given ``tenant_id``.""" if not isinstance(tenant_id, str) or not tenant_id: raise ValueError( - 'Invalid tenant ID: {0}. Tenant ID must be a non-empty string.'.format(tenant_id)) + f'Invalid tenant ID: {tenant_id}. Tenant ID must be a non-empty string.') try: - body = self.client.body('get', '/tenants/{0}'.format(tenant_id)) + body = self.client.body('get', f'/tenants/{tenant_id}') except requests.exceptions.RequestException as error: raise _auth_utils.handle_auth_backend_error(error) - else: - return Tenant(body) + return Tenant(body) def create_tenant( self, display_name, allow_password_sign_up=None, enable_email_link_sign_in=None): @@ -287,8 +286,7 @@ def create_tenant( body = self.client.body('post', '/tenants', json=payload) except requests.exceptions.RequestException as error: raise _auth_utils.handle_auth_backend_error(error) - else: - return Tenant(body) + return Tenant(body) def update_tenant( self, tenant_id, display_name=None, allow_password_sign_up=None, @@ -310,24 +308,23 @@ def update_tenant( if not payload: raise ValueError('At least one parameter must be specified for update.') - url = '/tenants/{0}'.format(tenant_id) + url = f'/tenants/{tenant_id}' update_mask = ','.join(_auth_utils.build_update_mask(payload)) - params = 'updateMask={0}'.format(update_mask) + params = f'updateMask={update_mask}' try: body = self.client.body('patch', url, json=payload, params=params) except requests.exceptions.RequestException as error: raise _auth_utils.handle_auth_backend_error(error) - else: - return Tenant(body) + return Tenant(body) def delete_tenant(self, tenant_id): """Deletes the tenant corresponding to the given ``tenant_id``.""" if not isinstance(tenant_id, str) or not tenant_id: raise ValueError( - 'Invalid tenant ID: {0}. Tenant ID must be a non-empty string.'.format(tenant_id)) + f'Invalid tenant ID: {tenant_id}. Tenant ID must be a non-empty string.') try: - self.client.request('delete', '/tenants/{0}'.format(tenant_id)) + self.client.request('delete', f'/tenants/{tenant_id}') except requests.exceptions.RequestException as error: raise _auth_utils.handle_auth_backend_error(error) @@ -341,7 +338,7 @@ def list_tenants(self, page_token=None, max_results=_MAX_LIST_TENANTS_RESULTS): if max_results < 1 or max_results > _MAX_LIST_TENANTS_RESULTS: raise ValueError( 'Max results must be a positive integer less than or equal to ' - '{0}.'.format(_MAX_LIST_TENANTS_RESULTS)) + f'{_MAX_LIST_TENANTS_RESULTS}.') payload = {'pageSize': max_results} if page_token: @@ -417,7 +414,7 @@ def __init__(self, current_page): self._current_page = current_page self._index = 0 - def next(self): + def __next__(self): if self._index == len(self._current_page.tenants): if self._current_page.has_next_page: self._current_page = self._current_page.get_next_page() @@ -428,9 +425,6 @@ def next(self): return result raise StopIteration - def __next__(self): - return self.next() - def __iter__(self): return self diff --git a/integration/conftest.py b/integration/conftest.py index efa45932d..ebaf9297a 100644 --- a/integration/conftest.py +++ b/integration/conftest.py @@ -16,7 +16,6 @@ import json import pytest -from pytest_asyncio import is_async_test import firebase_admin from firebase_admin import credentials @@ -37,7 +36,7 @@ def _get_cert_path(request): def integration_conf(request): cert_path = _get_cert_path(request) - with open(cert_path) as cert: + with open(cert_path, encoding='utf-8') as cert: project_id = json.load(cert).get('project_id') if not project_id: raise ValueError('Failed to determine project ID from service account certificate.') @@ -58,8 +57,8 @@ def default_app(request): """ cred, project_id = integration_conf(request) ops = { - 'databaseURL' : 'https://{0}.firebaseio.com'.format(project_id), - 'storageBucket' : '{0}.appspot.com'.format(project_id) + 'databaseURL' : f'https://{project_id}.firebaseio.com', + 'storageBucket' : f'{project_id}.appspot.com' } return firebase_admin.initialize_app(cred, ops) @@ -69,11 +68,5 @@ def api_key(request): if not path: raise ValueError('API key file not specified. Make sure to specify the "--apikey" ' 'command-line option.') - with open(path) as keyfile: + with open(path, encoding='utf-8') as keyfile: return keyfile.read().strip() - -def pytest_collection_modifyitems(items): - pytest_asyncio_tests = (item for item in items if is_async_test(item)) - session_scope_marker = pytest.mark.asyncio(loop_scope="session") - for async_test in pytest_asyncio_tests: - async_test.add_marker(session_scope_marker, append=False) diff --git a/integration/test_auth.py b/integration/test_auth.py index e1d01a254..7f4725dfe 100644 --- a/integration/test_auth.py +++ b/integration/test_auth.py @@ -30,6 +30,7 @@ import firebase_admin from firebase_admin import auth from firebase_admin import credentials +from firebase_admin._http_client import DEFAULT_TIMEOUT_SECONDS as timeout _verify_token_url = 'https://www.googleapis.com/identitytoolkit/v3/relyingparty/verifyCustomToken' @@ -67,14 +68,14 @@ def _sign_in(custom_token, api_key): body = {'token' : custom_token.decode(), 'returnSecureToken' : True} params = {'key' : api_key} - resp = requests.request('post', _verify_token_url, params=params, json=body) + resp = requests.request('post', _verify_token_url, params=params, json=body, timeout=timeout) resp.raise_for_status() return resp.json().get('idToken') def _sign_in_with_password(email, password, api_key): body = {'email': email, 'password': password, 'returnSecureToken': True} params = {'key' : api_key} - resp = requests.request('post', _verify_password_url, params=params, json=body) + resp = requests.request('post', _verify_password_url, params=params, json=body, timeout=timeout) resp.raise_for_status() return resp.json().get('idToken') @@ -84,7 +85,7 @@ def _random_string(length=10): def _random_id(): random_id = str(uuid.uuid4()).lower().replace('-', '') - email = 'test{0}@example.{1}.com'.format(random_id[:12], random_id[12:]) + email = f'test{random_id[:12]}@example.{random_id[12:]}.com' return random_id, email def _random_phone(): @@ -93,21 +94,21 @@ def _random_phone(): def _reset_password(oob_code, new_password, api_key): body = {'oobCode': oob_code, 'newPassword': new_password} params = {'key' : api_key} - resp = requests.request('post', _password_reset_url, params=params, json=body) + resp = requests.request('post', _password_reset_url, params=params, json=body, timeout=timeout) resp.raise_for_status() return resp.json().get('email') def _verify_email(oob_code, api_key): body = {'oobCode': oob_code} params = {'key' : api_key} - resp = requests.request('post', _verify_email_url, params=params, json=body) + resp = requests.request('post', _verify_email_url, params=params, json=body, timeout=timeout) resp.raise_for_status() return resp.json().get('email') def _sign_in_with_email_link(email, oob_code, api_key): body = {'oobCode': oob_code, 'email': email} params = {'key' : api_key} - resp = requests.request('post', _email_sign_in_url, params=params, json=body) + resp = requests.request('post', _email_sign_in_url, params=params, json=body, timeout=timeout) resp.raise_for_status() return resp.json().get('idToken') @@ -870,7 +871,7 @@ def test_delete_saml_provider_config(): def _create_oidc_provider_config(): - provider_id = 'oidc.{0}'.format(_random_string()) + provider_id = f'oidc.{_random_string()}' return auth.create_oidc_provider_config( provider_id=provider_id, client_id='OIDC_CLIENT_ID', @@ -882,7 +883,7 @@ def _create_oidc_provider_config(): def _create_saml_provider_config(): - provider_id = 'saml.{0}'.format(_random_string()) + provider_id = f'saml.{_random_string()}' return auth.create_saml_provider_config( provider_id=provider_id, idp_entity_id='IDP_ENTITY_ID', diff --git a/integration/test_db.py b/integration/test_db.py index 0170743dd..1ceb0b992 100644 --- a/integration/test_db.py +++ b/integration/test_db.py @@ -39,7 +39,7 @@ def integration_conf(request): def app(request): cred, project_id = integration_conf(request) ops = { - 'databaseURL' : 'https://{0}.firebaseio.com'.format(project_id), + 'databaseURL' : f'https://{project_id}.firebaseio.com', } return firebase_admin.initialize_app(cred, ops, name='integration-db') @@ -53,7 +53,7 @@ def default_app(): @pytest.fixture(scope='module') def update_rules(app): - with open(testutils.resource_filename('dinosaurs_index.json')) as rules_file: + with open(testutils.resource_filename('dinosaurs_index.json'), encoding='utf-8') as rules_file: new_rules = json.load(rules_file) client = db.reference('', app)._client rules = client.body('get', '/.settings/rules.json', params='format=strict') @@ -64,7 +64,7 @@ def update_rules(app): @pytest.fixture(scope='module') def testdata(): - with open(testutils.resource_filename('dinosaurs.json')) as dino_file: + with open(testutils.resource_filename('dinosaurs.json'), encoding='utf-8') as dino_file: return json.load(dino_file) @pytest.fixture(scope='module') @@ -195,8 +195,8 @@ def test_update_nested_children(self, testref): edward = python.child('users').push({'name' : 'Edward Cope', 'since' : 1800}) jack = python.child('users').push({'name' : 'Jack Horner', 'since' : 1940}) delta = { - '{0}/since'.format(edward.key) : 1840, - '{0}/since'.format(jack.key) : 1946 + f'{edward.key}/since' : 1840, + f'{jack.key}/since' : 1946 } python.child('users').update(delta) assert edward.get() == {'name' : 'Edward Cope', 'since' : 1840} @@ -363,7 +363,7 @@ def override_app(request, update_rules): del update_rules cred, project_id = integration_conf(request) ops = { - 'databaseURL' : 'https://{0}.firebaseio.com'.format(project_id), + 'databaseURL' : f'https://{project_id}.firebaseio.com', 'databaseAuthVariableOverride' : {'uid' : 'user1'} } app = firebase_admin.initialize_app(cred, ops, 'db-override') @@ -375,7 +375,7 @@ def none_override_app(request, update_rules): del update_rules cred, project_id = integration_conf(request) ops = { - 'databaseURL' : 'https://{0}.firebaseio.com'.format(project_id), + 'databaseURL' : f'https://{project_id}.firebaseio.com', 'databaseAuthVariableOverride' : None } app = firebase_admin.initialize_app(cred, ops, 'db-none-override') diff --git a/integration/test_firestore.py b/integration/test_firestore.py index fd39d9b8a..96cdd3fb1 100644 --- a/integration/test_firestore.py +++ b/integration/test_firestore.py @@ -18,16 +18,16 @@ from firebase_admin import firestore _CITY = { - 'name': u'Mountain View', - 'country': u'USA', + 'name': 'Mountain View', + 'country': 'USA', 'population': 77846, 'capital': False } _MOVIE = { - 'Name': u'Interstellar', + 'Name': 'Interstellar', 'Year': 2014, - 'Runtime': u'2h 49m', + 'Runtime': '2h 49m', 'Academy Award Winner': True } @@ -35,8 +35,8 @@ def test_firestore(): client = firestore.client() expected = { - 'name': u'Mountain View', - 'country': u'USA', + 'name': 'Mountain View', + 'country': 'USA', 'population': 77846, 'capital': False } @@ -93,7 +93,7 @@ def test_firestore_multi_db(): def test_server_timestamp(): client = firestore.client() expected = { - 'name': u'Mountain View', + 'name': 'Mountain View', 'timestamp': firestore.SERVER_TIMESTAMP # pylint: disable=no-member } doc = client.collection('cities').document() diff --git a/integration/test_firestore_async.py b/integration/test_firestore_async.py index 8b73dda0f..e899f25b2 100644 --- a/integration/test_firestore_async.py +++ b/integration/test_firestore_async.py @@ -20,21 +20,21 @@ from firebase_admin import firestore_async _CITY = { - 'name': u'Mountain View', - 'country': u'USA', + 'name': 'Mountain View', + 'country': 'USA', 'population': 77846, 'capital': False } _MOVIE = { - 'Name': u'Interstellar', + 'Name': 'Interstellar', 'Year': 2014, - 'Runtime': u'2h 49m', + 'Runtime': '2h 49m', 'Academy Award Winner': True } -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="session") async def test_firestore_async(): client = firestore_async.client() expected = _CITY @@ -48,7 +48,7 @@ async def test_firestore_async(): data = await doc.get() assert data.exists is False -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="session") async def test_firestore_async_explicit_database_id(): client = firestore_async.client(database_id='testing-database') expected = _CITY @@ -62,7 +62,7 @@ async def test_firestore_async_explicit_database_id(): data = await doc.get() assert data.exists is False -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="session") async def test_firestore_async_multi_db(): city_client = firestore_async.client() movie_client = firestore_async.client(database_id='testing-database') @@ -98,11 +98,11 @@ async def test_firestore_async_multi_db(): assert data[0].exists is False assert data[1].exists is False -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="session") async def test_server_timestamp(): client = firestore_async.client() expected = { - 'name': u'Mountain View', + 'name': 'Mountain View', 'timestamp': firestore_async.SERVER_TIMESTAMP # pylint: disable=no-member } doc = client.collection('cities').document() diff --git a/integration/test_messaging.py b/integration/test_messaging.py index 296a4d338..e72086741 100644 --- a/integration/test_messaging.py +++ b/integration/test_messaging.py @@ -121,7 +121,7 @@ def test_send_each(): def test_send_each_500(): messages = [] for msg_number in range(500): - topic = 'foo-bar-{0}'.format(msg_number % 10) + topic = f'foo-bar-{msg_number % 10}' messages.append(messaging.Message(topic=topic)) batch_response = messaging.send_each(messages, dry_run=True) @@ -149,71 +149,6 @@ def test_send_each_for_multicast(): assert response.exception is not None assert response.message_id is None -@pytest.mark.skip(reason="Replaced with test_send_each") -def test_send_all(): - messages = [ - messaging.Message( - topic='foo-bar', notification=messaging.Notification('Title', 'Body')), - messaging.Message( - topic='foo-bar', notification=messaging.Notification('Title', 'Body')), - messaging.Message( - token='not-a-token', notification=messaging.Notification('Title', 'Body')), - ] - - batch_response = messaging.send_all(messages, dry_run=True) - - assert batch_response.success_count == 2 - assert batch_response.failure_count == 1 - assert len(batch_response.responses) == 3 - - response = batch_response.responses[0] - assert response.success is True - assert response.exception is None - assert re.match('^projects/.*/messages/.*$', response.message_id) - - response = batch_response.responses[1] - assert response.success is True - assert response.exception is None - assert re.match('^projects/.*/messages/.*$', response.message_id) - - response = batch_response.responses[2] - assert response.success is False - assert isinstance(response.exception, exceptions.InvalidArgumentError) - assert response.message_id is None - -@pytest.mark.skip(reason="Replaced with test_send_each_500") -def test_send_all_500(): - messages = [] - for msg_number in range(500): - topic = 'foo-bar-{0}'.format(msg_number % 10) - messages.append(messaging.Message(topic=topic)) - - batch_response = messaging.send_all(messages, dry_run=True) - - assert batch_response.success_count == 500 - assert batch_response.failure_count == 0 - assert len(batch_response.responses) == 500 - for response in batch_response.responses: - assert response.success is True - assert response.exception is None - assert re.match('^projects/.*/messages/.*$', response.message_id) - -@pytest.mark.skip(reason="Replaced with test_send_each_for_multicast") -def test_send_multicast(): - multicast = messaging.MulticastMessage( - notification=messaging.Notification('Title', 'Body'), - tokens=['not-a-token', 'also-not-a-token']) - - batch_response = messaging.send_multicast(multicast) - - assert batch_response.success_count == 0 - assert batch_response.failure_count == 2 - assert len(batch_response.responses) == 2 - for response in batch_response.responses: - assert response.success is False - assert response.exception is not None - assert response.message_id is None - def test_subscribe(): resp = messaging.subscribe_to_topic(_REGISTRATION_TOKEN, 'mock-topic') assert resp.success_count + resp.failure_count == 1 @@ -222,7 +157,7 @@ def test_unsubscribe(): resp = messaging.unsubscribe_from_topic(_REGISTRATION_TOKEN, 'mock-topic') assert resp.success_count + resp.failure_count == 1 -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="session") async def test_send_each_async(): messages = [ messaging.Message( @@ -254,11 +189,11 @@ async def test_send_each_async(): assert isinstance(response.exception, exceptions.InvalidArgumentError) assert response.message_id is None -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="session") async def test_send_each_async_500(): messages = [] for msg_number in range(500): - topic = 'foo-bar-{0}'.format(msg_number % 10) + topic = f'foo-bar-{msg_number % 10}' messages.append(messaging.Message(topic=topic)) batch_response = await messaging.send_each_async(messages, dry_run=True) @@ -271,7 +206,7 @@ async def test_send_each_async_500(): assert response.exception is None assert re.match('^projects/.*/messages/.*$', response.message_id) -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="session") async def test_send_each_for_multicast_async(): multicast = messaging.MulticastMessage( notification=messaging.Notification('Title', 'Body'), diff --git a/integration/test_ml.py b/integration/test_ml.py index 52cb1bb7e..ea5b10be9 100644 --- a/integration/test_ml.py +++ b/integration/test_ml.py @@ -22,29 +22,22 @@ import pytest -import firebase_admin from firebase_admin import exceptions from firebase_admin import ml from tests import testutils -# pylint: disable=import-error,no-name-in-module +# pylint: disable=import-error, no-member try: import tensorflow as tf _TF_ENABLED = True except ImportError: _TF_ENABLED = False -try: - from google.cloud import automl_v1 - _AUTOML_ENABLED = True -except ImportError: - _AUTOML_ENABLED = False - def _random_identifier(prefix): #pylint: disable=unused-variable suffix = ''.join([random.choice(string.ascii_letters + string.digits) for n in range(8)]) - return '{0}_{1}'.format(prefix, suffix) + return f'{prefix}_{suffix}' NAME_ONLY_ARGS = { @@ -159,14 +152,6 @@ def check_tflite_gcs_format(model, validation_error=None): assert model.model_hash is not None -def check_tflite_automl_format(model): - assert model.validation_error is None - assert model.published is False - assert model.model_format.model_source.auto_ml_model.startswith('projects/') - # Automl models don't have validation errors since they are references - # to valid automl models. - - @pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True) def test_create_simple_model(firebase_model): check_model(firebase_model, NAME_AND_TAGS_ARGS) @@ -185,7 +170,7 @@ def test_create_already_existing_fails(firebase_model): ml.create_model(model=firebase_model) check_operation_error( excinfo, - 'Model \'{0}\' already exists'.format(firebase_model.display_name)) + f'Model \'{firebase_model.display_name}\' already exists') @pytest.mark.parametrize('firebase_model', [INVALID_FULL_MODEL_ARGS], indirect=True) @@ -234,7 +219,7 @@ def test_update_non_existing_model(firebase_model): ml.update_model(firebase_model) check_operation_error( excinfo, - 'Model \'{0}\' was not found'.format(firebase_model.as_dict().get('name'))) + f'Model \'{firebase_model.as_dict().get("name")}\' was not found') @pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True) @@ -267,18 +252,17 @@ def test_publish_unpublish_non_existing_model(firebase_model): ml.publish_model(firebase_model.model_id) check_operation_error( excinfo, - 'Model \'{0}\' was not found'.format(firebase_model.as_dict().get('name'))) + f'Model \'{firebase_model.as_dict().get("name")}\' was not found') with pytest.raises(exceptions.NotFoundError) as excinfo: ml.unpublish_model(firebase_model.model_id) check_operation_error( excinfo, - 'Model \'{0}\' was not found'.format(firebase_model.as_dict().get('name'))) + f'Model \'{firebase_model.as_dict().get("name")}\' was not found') def test_list_models(model_list): - filter_str = 'displayName={0} OR tags:{1}'.format( - model_list[0].display_name, model_list[1].tags[0]) + filter_str = f'displayName={model_list[0].display_name} OR tags:{model_list[1].tags[0]}' all_models = ml.list_models(list_filter=filter_str) all_model_ids = [mdl.model_id for mdl in all_models.iterate_all()] @@ -317,12 +301,16 @@ def _clean_up_directory(save_dir): @pytest.fixture def keras_model(): assert _TF_ENABLED - x_array = [-1, 0, 1, 2, 3, 4] - y_array = [-3, -1, 1, 3, 5, 7] - model = tf.keras.models.Sequential( - [tf.keras.layers.Dense(units=1, input_shape=[1])]) + x_list = [-1, 0, 1, 2, 3, 4] + y_list = [-3, -1, 1, 3, 5, 7] + x_tensor = tf.convert_to_tensor(x_list, dtype=tf.float32) + y_tensor = tf.convert_to_tensor(y_list, dtype=tf.float32) + model = tf.keras.models.Sequential([ + tf.keras.Input(shape=(1,)), + tf.keras.layers.Dense(units=1) + ]) model.compile(optimizer='sgd', loss='mean_squared_error') - model.fit(x_array, y_array, epochs=3) + model.fit(x_tensor, y_tensor, epochs=3) return model @@ -388,50 +376,3 @@ def test_from_saved_model(saved_model_dir): assert created_model.validation_error is None finally: _clean_up_model(created_model) - - -# Test AutoML functionality if AutoML is enabled. -#'pip install google-cloud-automl' in the environment if you want _AUTOML_ENABLED = True -# You will also need a predefined AutoML model named 'admin_sdk_integ_test1' to run the -# successful test. (Test is skipped otherwise) - -@pytest.fixture -def automl_model(): - assert _AUTOML_ENABLED - - # It takes > 20 minutes to train a model, so we expect a predefined AutoMl - # model named 'admin_sdk_integ_test1' to exist in the project, or we skip - # the test. - automl_client = automl_v1.AutoMlClient() - project_id = firebase_admin.get_app().project_id - parent = automl_client.location_path(project_id, 'us-central1') - models = automl_client.list_models(parent, filter_="display_name=admin_sdk_integ_test1") - # Expecting exactly one. (Ok to use last one if somehow more than 1) - automl_ref = None - for model in models: - automl_ref = model.name - - # Skip if no pre-defined model. (It takes min > 20 minutes to train a model) - if automl_ref is None: - pytest.skip("No pre-existing AutoML model found. Skipping test") - - source = ml.TFLiteAutoMlSource(automl_ref) - tflite_format = ml.TFLiteFormat(model_source=source) - ml_model = ml.Model( - display_name=_random_identifier('TestModel_automl_'), - tags=['test_automl'], - model_format=tflite_format) - model = ml.create_model(model=ml_model) - yield model - _clean_up_model(model) - -@pytest.mark.skipif(not _AUTOML_ENABLED, reason='AutoML is required for this test.') -def test_automl_model(automl_model): - # This test looks for a predefined automl model with display_name = 'admin_sdk_integ_test1' - automl_model.wait_for_unlocked() - - check_model(automl_model, { - 'display_name': automl_model.display_name, - 'tags': ['test_automl'], - }) - check_tflite_automl_format(automl_model) diff --git a/integration/test_project_management.py b/integration/test_project_management.py index b0b7fa52a..ba2c5ec16 100644 --- a/integration/test_project_management.py +++ b/integration/test_project_management.py @@ -74,14 +74,13 @@ def test_create_android_app_already_exists(android_app): def test_android_set_display_name_and_get_metadata(android_app, project_id): app_id = android_app.app_id android_app = project_management.android_app(app_id) - new_display_name = '{0} helloworld {1}'.format( - TEST_APP_DISPLAY_NAME_PREFIX, random.randint(0, 10000)) + new_display_name = f'{TEST_APP_DISPLAY_NAME_PREFIX} helloworld {random.randint(0, 10000)}' android_app.set_display_name(new_display_name) metadata = project_management.android_app(app_id).get_metadata() android_app.set_display_name(TEST_APP_DISPLAY_NAME_PREFIX) # Revert the display name. - assert metadata._name == 'projects/{0}/androidApps/{1}'.format(project_id, app_id) + assert metadata._name == f'projects/{project_id}/androidApps/{app_id}' assert metadata.app_id == app_id assert metadata.project_id == project_id assert metadata.display_name == new_display_name @@ -149,15 +148,13 @@ def test_create_ios_app_already_exists(ios_app): def test_ios_set_display_name_and_get_metadata(ios_app, project_id): app_id = ios_app.app_id ios_app = project_management.ios_app(app_id) - new_display_name = '{0} helloworld {1}'.format( - TEST_APP_DISPLAY_NAME_PREFIX, random.randint(0, 10000)) + new_display_name = f'{TEST_APP_DISPLAY_NAME_PREFIX} helloworld {random.randint(0, 10000)}' ios_app.set_display_name(new_display_name) metadata = project_management.ios_app(app_id).get_metadata() ios_app.set_display_name(TEST_APP_DISPLAY_NAME_PREFIX) # Revert the display name. - assert metadata._name == 'projects/{0}/iosApps/{1}'.format(project_id, app_id) - assert metadata.app_id == app_id + assert metadata._name == f'projects/{project_id}/iosApps/{app_id}' assert metadata.project_id == project_id assert metadata.display_name == new_display_name assert metadata.bundle_id == TEST_APP_BUNDLE_ID diff --git a/integration/test_storage.py b/integration/test_storage.py index 729190950..32e4d86a3 100644 --- a/integration/test_storage.py +++ b/integration/test_storage.py @@ -20,10 +20,10 @@ def test_default_bucket(project_id): bucket = storage.bucket() - _verify_bucket(bucket, '{0}.appspot.com'.format(project_id)) + _verify_bucket(bucket, f'{project_id}.appspot.com') def test_custom_bucket(project_id): - bucket_name = '{0}.appspot.com'.format(project_id) + bucket_name = f'{project_id}.appspot.com' bucket = storage.bucket(bucket_name) _verify_bucket(bucket, bucket_name) @@ -33,12 +33,12 @@ def test_non_existing_bucket(): def _verify_bucket(bucket, expected_name): assert bucket.name == expected_name - file_name = 'data_{0}.txt'.format(int(time.time())) + file_name = f'data_{int(time.time())}.txt' blob = bucket.blob(file_name) blob.upload_from_string('Hello World') blob = bucket.get_blob(file_name) - assert blob.download_as_string().decode() == 'Hello World' + assert blob.download_as_bytes().decode() == 'Hello World' bucket.delete_blob(file_name) assert not bucket.get_blob(file_name) diff --git a/integration/test_tenant_mgt.py b/integration/test_tenant_mgt.py index c9eefd96e..f0bad58b2 100644 --- a/integration/test_tenant_mgt.py +++ b/integration/test_tenant_mgt.py @@ -25,6 +25,7 @@ from firebase_admin import auth from firebase_admin import tenant_mgt +from firebase_admin._http_client import DEFAULT_TIMEOUT_SECONDS as timeout from integration import test_auth @@ -359,7 +360,7 @@ def test_delete_saml_provider_config(sample_tenant): def _create_oidc_provider_config(client): - provider_id = 'oidc.{0}'.format(_random_string()) + provider_id = f'oidc.{_random_string()}' return client.create_oidc_provider_config( provider_id=provider_id, client_id='OIDC_CLIENT_ID', @@ -369,7 +370,7 @@ def _create_oidc_provider_config(client): def _create_saml_provider_config(client): - provider_id = 'saml.{0}'.format(_random_string()) + provider_id = f'saml.{_random_string()}' return client.create_saml_provider_config( provider_id=provider_id, idp_entity_id='IDP_ENTITY_ID', @@ -387,7 +388,7 @@ def _random_uid(): def _random_email(): random_id = str(uuid.uuid4()).lower().replace('-', '') - return 'test{0}@example.{1}.com'.format(random_id[:12], random_id[12:]) + return f'test{random_id[:12]}@example.{random_id[12:]}.com' def _random_phone(): @@ -412,6 +413,6 @@ def _sign_in(custom_token, tenant_id, api_key): 'tenantId': tenant_id, } params = {'key' : api_key} - resp = requests.request('post', VERIFY_TOKEN_URL, params=params, json=body) + resp = requests.request('post', VERIFY_TOKEN_URL, params=params, json=body, timeout=timeout) resp.raise_for_status() return resp.json().get('idToken') diff --git a/requirements.txt b/requirements.txt index ba6f2f947..ff15072a6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,16 +1,15 @@ -astroid == 2.3.3 -pylint == 2.3.1 -pytest >= 6.2.0 +astroid == 3.3.10 +pylint == 3.3.7 +pytest >= 8.2.2 pytest-cov >= 2.4.0 pytest-localserver >= 0.4.1 -pytest-asyncio >= 0.16.0 +pytest-asyncio >= 0.26.0 pytest-mock >= 3.6.1 respx == 0.22.0 -cachecontrol >= 0.12.14 -google-api-core[grpc] >= 1.22.1, < 3.0.0dev; platform.python_implementation != 'PyPy' -google-api-python-client >= 1.7.8 -google-cloud-firestore >= 2.19.0; platform.python_implementation != 'PyPy' -google-cloud-storage >= 1.37.1 -pyjwt[crypto] >= 2.5.0 +cachecontrol >= 0.14.3 +google-api-core[grpc] >= 2.25.1, < 3.0.0dev; platform.python_implementation != 'PyPy' +google-cloud-firestore >= 2.21.0; platform.python_implementation != 'PyPy' +google-cloud-storage >= 3.1.1 +pyjwt[crypto] >= 2.10.1 httpx[http2] == 0.28.1 \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 25c649748..32e00676b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,2 +1,4 @@ [tool:pytest] testpaths = tests +asyncio_default_test_loop_scope = class +asyncio_default_fixture_loop_scope = None diff --git a/setup.py b/setup.py index e92d207aa..21e29332e 100644 --- a/setup.py +++ b/setup.py @@ -22,8 +22,8 @@ (major, minor) = (sys.version_info.major, sys.version_info.minor) -if major != 3 or minor < 7: - print('firebase_admin requires python >= 3.7', file=sys.stderr) +if major != 3 or minor < 9: + print('firebase_admin requires python >= 3.9', file=sys.stderr) sys.exit(1) # Read in the package metadata per recommendations from: @@ -37,12 +37,11 @@ long_description = ('The Firebase Admin Python SDK enables server-side (backend) Python developers ' 'to integrate Firebase into their services and applications.') install_requires = [ - 'cachecontrol>=0.12.14', - 'google-api-core[grpc] >= 1.22.1, < 3.0.0dev; platform.python_implementation != "PyPy"', - 'google-api-python-client >= 1.7.8', - 'google-cloud-firestore>=2.19.0; platform.python_implementation != "PyPy"', - 'google-cloud-storage>=1.37.1', - 'pyjwt[crypto] >= 2.5.0', + 'cachecontrol>=0.14.3', + 'google-api-core[grpc] >= 2.25.1, < 3.0.0dev; platform.python_implementation != "PyPy"', + 'google-cloud-firestore>=2.21.0; platform.python_implementation != "PyPy"', + 'google-cloud-storage>=3.1.1', + 'pyjwt[crypto] >= 2.10.1', 'httpx[http2] == 0.28.1', ] @@ -61,18 +60,17 @@ keywords='firebase cloud development', install_requires=install_requires, packages=['firebase_admin'], - python_requires='>=3.7', + python_requires='>=3.9', classifiers=[ 'Development Status :: 5 - Production/Stable', 'Intended Audience :: Developers', 'Topic :: Software Development :: Build Tools', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', 'Programming Language :: Python :: 3.12', + 'Programming Language :: Python :: 3.13', 'License :: OSI Approved :: Apache Software License', ], ) diff --git a/snippets/auth/get_service_account_tokens.py b/snippets/auth/get_service_account_tokens.py index 9f60590fe..7ad67a093 100644 --- a/snippets/auth/get_service_account_tokens.py +++ b/snippets/auth/get_service_account_tokens.py @@ -26,4 +26,4 @@ # After expiration_time, you must generate a new access token # [END get_service_account_tokens] -print('The access token {} expires at {}'.format(access_token, expiration_time)) +print(f'The access token {access_token} expires at {expiration_time}') diff --git a/snippets/auth/index.py b/snippets/auth/index.py index ed324e486..6a509b8f5 100644 --- a/snippets/auth/index.py +++ b/snippets/auth/index.py @@ -169,7 +169,7 @@ def revoke_refresh_token_uid(): user = auth.get_user(uid) # Convert to seconds as the auth_time in the token claims is in seconds. revocation_second = user.tokens_valid_after_timestamp / 1000 - print('Tokens revoked at: {0}'.format(revocation_second)) + print(f'Tokens revoked at: {revocation_second}') # [END revoke_tokens] # [START save_revocation_in_db] metadata_ref = firebase_admin.db.reference("metadata/" + uid) @@ -183,7 +183,7 @@ def get_user(uid): from firebase_admin import auth user = auth.get_user(uid) - print('Successfully fetched user data: {0}'.format(user.uid)) + print(f'Successfully fetched user data: {user.uid}') # [END get_user] def get_user_by_email(): @@ -192,7 +192,7 @@ def get_user_by_email(): from firebase_admin import auth user = auth.get_user_by_email(email) - print('Successfully fetched user data: {0}'.format(user.uid)) + print(f'Successfully fetched user data: {user.uid}') # [END get_user_by_email] def bulk_get_users(): @@ -221,7 +221,7 @@ def get_user_by_phone_number(): from firebase_admin import auth user = auth.get_user_by_phone_number(phone) - print('Successfully fetched user data: {0}'.format(user.uid)) + print(f'Successfully fetched user data: {user.uid}') # [END get_user_by_phone] def create_user(): @@ -234,7 +234,7 @@ def create_user(): display_name='John Doe', photo_url='http://www.example.com/12345678/photo.png', disabled=False) - print('Sucessfully created new user: {0}'.format(user.uid)) + print(f'Sucessfully created new user: {user.uid}') # [END create_user] return user.uid @@ -242,7 +242,7 @@ def create_user_with_id(): # [START create_user_with_id] user = auth.create_user( uid='some-uid', email='user@example.com', phone_number='+15555550100') - print('Sucessfully created new user: {0}'.format(user.uid)) + print(f'Sucessfully created new user: {user.uid}') # [END create_user_with_id] def update_user(uid): @@ -256,7 +256,7 @@ def update_user(uid): display_name='John Doe', photo_url='http://www.example.com/12345678/photo.png', disabled=True) - print('Sucessfully updated user: {0}'.format(user.uid)) + print(f'Sucessfully updated user: {user.uid}') # [END update_user] def delete_user(uid): @@ -271,10 +271,10 @@ def bulk_delete_users(): result = auth.delete_users(["uid1", "uid2", "uid3"]) - print('Successfully deleted {0} users'.format(result.success_count)) - print('Failed to delete {0} users'.format(result.failure_count)) + print(f'Successfully deleted {result.success_count} users') + print(f'Failed to delete {result.failure_count} users') for err in result.errors: - print('error #{0}, reason: {1}'.format(result.index, result.reason)) + print(f'error #{result.index}, reason: {result.reason}') # [END bulk_delete_users] def set_custom_user_claims(uid): @@ -475,10 +475,11 @@ def import_users(): hash_alg = auth.UserImportHash.hmac_sha256(key=b'secret_key') try: result = auth.import_users(users, hash_alg=hash_alg) - print('Successfully imported {0} users. Failed to import {1} users.'.format( - result.success_count, result.failure_count)) + print( + f'Successfully imported {result.success_count} users. Failed to import ' + f'{result.failure_count} users.') for err in result.errors: - print('Failed to import {0} due to {1}'.format(users[err.index].uid, err.reason)) + print(f'Failed to import {users[err.index].uid} due to {err.reason}') except exceptions.FirebaseError: # Some unrecoverable error occurred that prevented the operation from running. pass @@ -1012,7 +1013,7 @@ def revoke_refresh_tokens_tenant(tenant_client, uid): user = tenant_client.get_user(uid) # Convert to seconds as the auth_time in the token claims is in seconds. revocation_second = user.tokens_valid_after_timestamp / 1000 - print('Tokens revoked at: {0}'.format(revocation_second)) + print(f'Tokens revoked at: {revocation_second}') # [END revoke_tokens_tenant] def verify_id_token_and_check_revoked_tenant(tenant_client, id_token): diff --git a/snippets/database/index.py b/snippets/database/index.py index adfa13476..99bb4981e 100644 --- a/snippets/database/index.py +++ b/snippets/database/index.py @@ -235,7 +235,7 @@ def order_by_child(): ref = db.reference('dinosaurs') snapshot = ref.order_by_child('height').get() for key, val in snapshot.items(): - print('{0} was {1} meters tall'.format(key, val)) + print(f'{key} was {val} meters tall') # [END order_by_child] def order_by_nested_child(): @@ -243,7 +243,7 @@ def order_by_nested_child(): ref = db.reference('dinosaurs') snapshot = ref.order_by_child('dimensions/height').get() for key, val in snapshot.items(): - print('{0} was {1} meters tall'.format(key, val)) + print(f'{key} was {val} meters tall') # [END order_by_nested_child] def order_by_key(): @@ -258,7 +258,7 @@ def order_by_value(): ref = db.reference('scores') snapshot = ref.order_by_value().get() for key, val in snapshot.items(): - print('The {0} dinosaur\'s score is {1}'.format(key, val)) + print(f'The {key} dinosaur\'s score is {val}') # [END order_by_value] def limit_query(): @@ -280,7 +280,7 @@ def limit_query(): scores_ref = db.reference('scores') snapshot = scores_ref.order_by_value().limit_to_last(3).get() for key, val in snapshot.items(): - print('The {0} dinosaur\'s score is {1}'.format(key, val)) + print(f'The {key} dinosaur\'s score is {val}') # [END limit_query_3] def range_query(): @@ -300,7 +300,7 @@ def range_query(): # [START range_query_3] ref = db.reference('dinosaurs') - snapshot = ref.order_by_key().start_at('b').end_at(u'b\uf8ff').get() + snapshot = ref.order_by_key().start_at('b').end_at('b\uf8ff').get() for key in snapshot: print(key) # [END range_query_3] @@ -322,7 +322,7 @@ def complex_query(): # Data is ordered by increasing height, so we want the first entry. # Second entry is stegosarus. for key in snapshot: - print('The dinosaur just shorter than the stegosaurus is {0}'.format(key)) + print(f'The dinosaur just shorter than the stegosaurus is {key}') return else: print('The stegosaurus is the shortest dino') diff --git a/snippets/messaging/cloud_messaging.py b/snippets/messaging/cloud_messaging.py index 6caf316d0..6fb525231 100644 --- a/snippets/messaging/cloud_messaging.py +++ b/snippets/messaging/cloud_messaging.py @@ -222,29 +222,6 @@ def unsubscribe_from_topic(): # [END unsubscribe] -def send_all(): - registration_token = 'YOUR_REGISTRATION_TOKEN' - # [START send_all] - # Create a list containing up to 500 messages. - messages = [ - messaging.Message( - notification=messaging.Notification('Price drop', '5% off all electronics'), - token=registration_token, - ), - # ... - messaging.Message( - notification=messaging.Notification('Price drop', '2% off all books'), - topic='readers-club', - ), - ] - - response = messaging.send_all(messages) - # See the BatchResponse reference documentation - # for the contents of response. - print('{0} messages were sent successfully'.format(response.success_count)) - # [END send_all] - - def send_each(): registration_token = 'YOUR_REGISTRATION_TOKEN' # [START send_each] @@ -264,56 +241,9 @@ def send_each(): response = messaging.send_each(messages) # See the BatchResponse reference documentation # for the contents of response. - print('{0} messages were sent successfully'.format(response.success_count)) + print(f'{response.success_count} messages were sent successfully') # [END send_each] - -def send_multicast(): - # [START send_multicast] - # Create a list containing up to 500 registration tokens. - # These registration tokens come from the client FCM SDKs. - registration_tokens = [ - 'YOUR_REGISTRATION_TOKEN_1', - # ... - 'YOUR_REGISTRATION_TOKEN_N', - ] - - message = messaging.MulticastMessage( - data={'score': '850', 'time': '2:45'}, - tokens=registration_tokens, - ) - response = messaging.send_multicast(message) - # See the BatchResponse reference documentation - # for the contents of response. - print('{0} messages were sent successfully'.format(response.success_count)) - # [END send_multicast] - - -def send_multicast_and_handle_errors(): - # [START send_multicast_error] - # These registration tokens come from the client FCM SDKs. - registration_tokens = [ - 'YOUR_REGISTRATION_TOKEN_1', - # ... - 'YOUR_REGISTRATION_TOKEN_N', - ] - - message = messaging.MulticastMessage( - data={'score': '850', 'time': '2:45'}, - tokens=registration_tokens, - ) - response = messaging.send_multicast(message) - if response.failure_count > 0: - responses = response.responses - failed_tokens = [] - for idx, resp in enumerate(responses): - if not resp.success: - # The order of responses corresponds to the order of the registration tokens. - failed_tokens.append(registration_tokens[idx]) - print('List of tokens that caused failures: {0}'.format(failed_tokens)) - # [END send_multicast_error] - - def send_each_for_multicast_and_handle_errors(): # [START send_each_for_multicast_error] # These registration tokens come from the client FCM SDKs. @@ -335,5 +265,5 @@ def send_each_for_multicast_and_handle_errors(): if not resp.success: # The order of responses corresponds to the order of the registration tokens. failed_tokens.append(registration_tokens[idx]) - print('List of tokens that caused failures: {0}'.format(failed_tokens)) + print(f'List of tokens that caused failures: {failed_tokens}') # [END send_each_for_multicast_error] diff --git a/tests/test_app.py b/tests/test_app.py index 5b203661f..0ff0854b4 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -215,11 +215,11 @@ def revert_config_env(config_old): class TestFirebaseApp: """Test cases for App initialization and life cycle.""" - invalid_credentials = ['', 'foo', 0, 1, dict(), list(), tuple(), True, False] - invalid_options = ['', 0, 1, list(), tuple(), True, False] - invalid_names = [None, '', 0, 1, dict(), list(), tuple(), True, False] + invalid_credentials = ['', 'foo', 0, 1, {}, [], tuple(), True, False] + invalid_options = ['', 0, 1, [], tuple(), True, False] + invalid_names = [None, '', 0, 1, {}, [], tuple(), True, False] invalid_apps = [ - None, '', 0, 1, dict(), list(), tuple(), True, False, + None, '', 0, 1, {}, [], tuple(), True, False, firebase_admin.App('uninitialized', CREDENTIAL, {}) ] @@ -308,11 +308,11 @@ def test_project_id_from_environment(self): variables = ['GOOGLE_CLOUD_PROJECT', 'GCLOUD_PROJECT'] for idx, var in enumerate(variables): old_project_id = os.environ.get(var) - new_project_id = 'env-project-{0}'.format(idx) + new_project_id = f'env-project-{idx}' os.environ[var] = new_project_id try: app = firebase_admin.initialize_app( - testutils.MockCredential(), name='myApp{0}'.format(var)) + testutils.MockCredential(), name=f'myApp{var}') assert app.project_id == new_project_id finally: if old_project_id: @@ -388,7 +388,7 @@ def test_app_services(self, init_app): with pytest.raises(ValueError): _utils.get_app_service(init_app, 'test.service', AppService) - @pytest.mark.parametrize('arg', [0, 1, True, False, 'str', list(), dict(), tuple()]) + @pytest.mark.parametrize('arg', [0, 1, True, False, 'str', [], {}, tuple()]) def test_app_services_invalid_arg(self, arg): with pytest.raises(ValueError): _utils.get_app_service(arg, 'test.service', AppService) diff --git a/tests/test_app_check.py b/tests/test_app_check.py index 168d0a972..e55ae39de 100644 --- a/tests/test_app_check.py +++ b/tests/test_app_check.py @@ -22,7 +22,7 @@ from firebase_admin import app_check from tests import testutils -NON_STRING_ARGS = [list(), tuple(), dict(), True, False, 1, 0] +NON_STRING_ARGS = [[], tuple(), {}, True, False, 1, 0] APP_ID = "1234567890" PROJECT_ID = "1334" @@ -71,7 +71,7 @@ def evaluate(): def test_verify_token_with_non_string_raises_error(self, token): with pytest.raises(ValueError) as excinfo: app_check.verify_token(token) - expected = 'app check token "{0}" must be a string.'.format(token) + expected = f'app check token "{token}" must be a string.' assert str(excinfo.value) == expected def test_has_valid_token_headers(self): diff --git a/tests/test_auth_providers.py b/tests/test_auth_providers.py index 304e0fd78..106e1cae3 100644 --- a/tests/test_auth_providers.py +++ b/tests/test_auth_providers.py @@ -27,8 +27,7 @@ ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v2' EMULATOR_HOST_ENV_VAR = 'FIREBASE_AUTH_EMULATOR_HOST' AUTH_EMULATOR_HOST = 'localhost:9099' -EMULATED_ID_TOOLKIT_URL = 'http://{}/identitytoolkit.googleapis.com/v2'.format( - AUTH_EMULATOR_HOST) +EMULATED_ID_TOOLKIT_URL = f'http://{AUTH_EMULATOR_HOST}/identitytoolkit.googleapis.com/v2' URL_PROJECT_SUFFIX = '/projects/mock-project-id' USER_MGT_URLS = { 'ID_TOOLKIT': ID_TOOLKIT_URL, @@ -45,7 +44,7 @@ } }""" -INVALID_PROVIDER_IDS = [None, True, False, 1, 0, list(), tuple(), dict(), ''] +INVALID_PROVIDER_IDS = [None, True, False, 1, 0, [], tuple(), {}, ''] @pytest.fixture(scope='module', params=[{'emulated': False}, {'emulated': True}]) @@ -282,12 +281,12 @@ def test_delete(self, user_mgt_app): _assert_request(recorder[0], 'DELETE', f'{USER_MGT_URLS["PREFIX"]}/oauthIdpConfigs/oidc.provider') - @pytest.mark.parametrize('arg', [None, 'foo', list(), dict(), 0, -1, 101, False]) + @pytest.mark.parametrize('arg', [None, 'foo', [], {}, 0, -1, 101, False]) def test_invalid_max_results(self, user_mgt_app, arg): with pytest.raises(ValueError): auth.list_oidc_provider_configs(max_results=arg, app=user_mgt_app) - @pytest.mark.parametrize('arg', ['', list(), dict(), 0, -1, 101, False]) + @pytest.mark.parametrize('arg', ['', [], {}, 0, -1, 101, False]) def test_invalid_page_token(self, user_mgt_app, arg): with pytest.raises(ValueError): auth.list_oidc_provider_configs(page_token=arg, app=user_mgt_app) @@ -346,7 +345,7 @@ def test_paged_iteration(self, user_mgt_app): for index in range(2): provider_config = next(iterator) - assert provider_config.provider_id == 'oidc.provider{0}'.format(index) + assert provider_config.provider_id == f'oidc.provider{index}' assert len(recorder) == 1 _assert_request(recorder[0], 'GET', f'{USER_MGT_URLS["PREFIX"]}/oauthIdpConfigs?pageSize=100') @@ -403,7 +402,7 @@ def _assert_page(self, page, count=2, start=0, next_page_token=''): index = start assert len(page.provider_configs) == count for provider_config in page.provider_configs: - self._assert_provider_config(provider_config, want_id='oidc.provider{0}'.format(index)) + self._assert_provider_config(provider_config, want_id=f'oidc.provider{index}') index += 1 if next_page_token: @@ -621,12 +620,12 @@ def test_config_not_found(self, user_mgt_app): assert excinfo.value.http_response is not None assert excinfo.value.cause is not None - @pytest.mark.parametrize('arg', [None, 'foo', list(), dict(), 0, -1, 101, False]) + @pytest.mark.parametrize('arg', [None, 'foo', [], {}, 0, -1, 101, False]) def test_invalid_max_results(self, user_mgt_app, arg): with pytest.raises(ValueError): auth.list_saml_provider_configs(max_results=arg, app=user_mgt_app) - @pytest.mark.parametrize('arg', ['', list(), dict(), 0, -1, 101, False]) + @pytest.mark.parametrize('arg', ['', [], {}, 0, -1, 101, False]) def test_invalid_page_token(self, user_mgt_app, arg): with pytest.raises(ValueError): auth.list_saml_provider_configs(page_token=arg, app=user_mgt_app) @@ -686,7 +685,7 @@ def test_paged_iteration(self, user_mgt_app): for index in range(2): provider_config = next(iterator) - assert provider_config.provider_id == 'saml.provider{0}'.format(index) + assert provider_config.provider_id == f'saml.provider{index}' assert len(recorder) == 1 _assert_request( recorder[0], 'GET', f'{USER_MGT_URLS["PREFIX"]}/inboundSamlConfigs?pageSize=100') @@ -735,7 +734,7 @@ def _assert_page(self, page, count=2, start=0, next_page_token=''): index = start assert len(page.provider_configs) == count for provider_config in page.provider_configs: - self._assert_provider_config(provider_config, want_id='saml.provider{0}'.format(index)) + self._assert_provider_config(provider_config, want_id=f'saml.provider{index}') index += 1 if next_page_token: diff --git a/tests/test_credentials.py b/tests/test_credentials.py index cceb6b6f9..1e1db6460 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -64,7 +64,7 @@ def test_init_from_invalid_certificate(self, file_name, error): with pytest.raises(error): credentials.Certificate(testutils.resource_filename(file_name)) - @pytest.mark.parametrize('arg', [None, 0, 1, True, False, list(), tuple(), dict()]) + @pytest.mark.parametrize('arg', [None, 0, 1, True, False, [], tuple(), {}]) def test_invalid_args(self, arg): with pytest.raises(ValueError): credentials.Certificate(arg) @@ -156,7 +156,7 @@ def test_init_from_invalid_file(self): credentials.RefreshToken( testutils.resource_filename('service_account.json')) - @pytest.mark.parametrize('arg', [None, 0, 1, True, False, list(), tuple(), dict()]) + @pytest.mark.parametrize('arg', [None, 0, 1, True, False, [], tuple(), {}]) def test_invalid_args(self, arg): with pytest.raises(ValueError): credentials.RefreshToken(arg) diff --git a/tests/test_db.py b/tests/test_db.py index 00a0077cb..abba3baa8 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -45,7 +45,7 @@ def __init__(self, data, status, recorder, etag=ETAG): def send(self, request, **kwargs): if_match = request.headers.get('if-match') if_none_match = request.headers.get('if-none-match') - resp = super(MockAdapter, self).send(request, **kwargs) + resp = super().send(request, **kwargs) resp.headers = {'ETag': self._etag} if if_match and if_match != MockAdapter.ETAG: resp.status_code = 412 @@ -87,7 +87,7 @@ class TestReferencePath: } invalid_paths = [ - None, True, False, 0, 1, dict(), list(), tuple(), _Object(), + None, True, False, 0, 1, {}, [], tuple(), _Object(), 'foo#', 'foo.', 'foo$', 'foo[', 'foo]', ] @@ -98,7 +98,7 @@ class TestReferencePath: } invalid_children = [ - None, '', '/foo', '/foo/bar', True, False, 0, 1, dict(), list(), tuple(), + None, '', '/foo', '/foo/bar', True, False, 0, 1, {}, [], tuple(), 'foo#', 'foo.', 'foo$', 'foo[', 'foo]', _Object() ] @@ -248,7 +248,7 @@ def test_get_if_changed(self, data): self._assert_request(recorder[1], 'GET', 'https://test.firebaseio.com/test.json') assert recorder[1].headers['if-none-match'] == MockAdapter.ETAG - @pytest.mark.parametrize('etag', [0, 1, True, False, dict(), list(), tuple()]) + @pytest.mark.parametrize('etag', [0, 1, True, False, {}, [], tuple()]) def test_get_if_changed_invalid_etag(self, etag): ref = db.reference('/test') with pytest.raises(ValueError): @@ -347,7 +347,7 @@ def test_set_if_unchanged_failure(self, data): assert json.loads(recorder[0].body.decode()) == data assert recorder[0].headers['if-match'] == 'invalid-etag' - @pytest.mark.parametrize('etag', [0, 1, True, False, dict(), list(), tuple()]) + @pytest.mark.parametrize('etag', [0, 1, True, False, {}, [], tuple()]) def test_set_if_unchanged_invalid_etag(self, etag): ref = db.reference('/test') with pytest.raises(ValueError): @@ -369,7 +369,7 @@ def test_set_if_unchanged_non_json_value(self, value): ref.set_if_unchanged(MockAdapter.ETAG, value) @pytest.mark.parametrize('update', [ - None, {}, {None:'foo'}, '', 'foo', 0, 1, list(), tuple(), _Object() + None, {}, {None:'foo'}, '', 'foo', 0, 1, [], tuple(), _Object() ]) def test_set_invalid_update(self, update): ref = db.reference('/test') @@ -466,7 +466,7 @@ def test_transaction_abort(self): assert excinfo.value.http_response is None assert len(recorder) == 1 + 25 - @pytest.mark.parametrize('func', [None, 0, 1, True, False, 'foo', dict(), list(), tuple()]) + @pytest.mark.parametrize('func', [None, 0, 1, True, False, 'foo', {}, [], tuple()]) def test_transaction_invalid_function(self, func): ref = db.reference('/test') with pytest.raises(ValueError): @@ -672,7 +672,7 @@ def _assert_request(self, request, expected_method, expected_url): def test_get_value(self): ref = db.reference('/test') recorder = self.instrument(ref, json.dumps('data')) - query_str = 'auth_variable_override={0}'.format(self.encoded_override) + query_str = f'auth_variable_override={self.encoded_override}' assert ref.get() == 'data' assert len(recorder) == 1 self._assert_request( @@ -683,7 +683,7 @@ def test_set_value(self): recorder = self.instrument(ref, '') data = {'foo' : 'bar'} ref.set(data) - query_str = 'print=silent&auth_variable_override={0}'.format(self.encoded_override) + query_str = f'print=silent&auth_variable_override={self.encoded_override}' assert len(recorder) == 1 self._assert_request( recorder[0], 'PUT', 'https://test.firebaseio.com/test.json?' + query_str) @@ -693,7 +693,7 @@ def test_order_by_query(self): ref = db.reference('/test') recorder = self.instrument(ref, json.dumps('data')) query = ref.order_by_child('foo') - query_str = 'orderBy=%22foo%22&auth_variable_override={0}'.format(self.encoded_override) + query_str = f'orderBy=%22foo%22&auth_variable_override={self.encoded_override}' assert query.get() == 'data' assert len(recorder) == 1 self._assert_request( @@ -703,8 +703,9 @@ def test_range_query(self): ref = db.reference('/test') recorder = self.instrument(ref, json.dumps('data')) query = ref.order_by_child('foo').start_at(1).end_at(10) - query_str = ('endAt=10&orderBy=%22foo%22&startAt=1&' - 'auth_variable_override={0}'.format(self.encoded_override)) + query_str = ( + f'endAt=10&orderBy=%22foo%22&startAt=1&auth_variable_override={self.encoded_override}' + ) assert query.get() == 'data' assert len(recorder) == 1 self._assert_request( @@ -794,7 +795,7 @@ def test_valid_db_url(self, url): @pytest.mark.parametrize('url', [ None, '', 'foo', 'http://test.firebaseio.com', 'http://test.firebasedatabase.app', - True, False, 1, 0, dict(), list(), tuple(), _Object() + True, False, 1, 0, {}, [], tuple(), _Object() ]) def test_invalid_db_url(self, url): firebase_admin.initialize_app(testutils.MockCredential(), {'databaseURL' : url}) @@ -838,7 +839,7 @@ def test_valid_auth_override(self, override): assert ref._client.params['auth_variable_override'] == encoded @pytest.mark.parametrize('override', [ - '', 'foo', 0, 1, True, False, list(), tuple(), _Object()]) + '', 'foo', 0, 1, True, False, [], tuple(), _Object()]) def test_invalid_auth_override(self, override): firebase_admin.initialize_app(testutils.MockCredential(), { 'databaseURL' : 'https://test.firebaseio.com', @@ -885,8 +886,10 @@ def test_app_delete(self): assert other_ref._client.session is None def test_user_agent_format(self): - expected = 'Firebase/HTTP/{0}/{1}.{2}/AdminPython'.format( - firebase_admin.__version__, sys.version_info.major, sys.version_info.minor) + expected = ( + f'Firebase/HTTP/{firebase_admin.__version__}/{sys.version_info.major}.' + f'{sys.version_info.minor}/AdminPython' + ) assert db._USER_AGENT == expected def _check_timeout(self, ref, timeout): @@ -925,7 +928,7 @@ class TestQuery: ref = db.Reference(path='foo') @pytest.mark.parametrize('path', [ - '', None, '/', '/foo', 0, 1, True, False, dict(), list(), tuple(), _Object(), + '', None, '/', '/foo', 0, 1, True, False, {}, [], tuple(), _Object(), '$foo', '.foo', '#foo', '[foo', 'foo]', '$key', '$value', '$priority' ]) def test_invalid_path(self, path): @@ -935,13 +938,13 @@ def test_invalid_path(self, path): @pytest.mark.parametrize('path, expected', valid_paths.items()) def test_order_by_valid_path(self, path, expected): query = self.ref.order_by_child(path) - assert query._querystr == 'orderBy="{0}"'.format(expected) + assert query._querystr == f'orderBy="{expected}"' @pytest.mark.parametrize('path, expected', valid_paths.items()) def test_filter_by_valid_path(self, path, expected): query = self.ref.order_by_child(path) query.equal_to(10) - assert query._querystr == 'equalTo=10&orderBy="{0}"'.format(expected) + assert query._querystr == f'equalTo=10&orderBy="{expected}"' def test_order_by_key(self): query = self.ref.order_by_key() @@ -972,7 +975,7 @@ def test_multiple_limits(self): with pytest.raises(ValueError): query.limit_to_first(1) - @pytest.mark.parametrize('limit', [None, -1, 'foo', 1.2, list(), dict(), tuple(), _Object()]) + @pytest.mark.parametrize('limit', [None, -1, 'foo', 1.2, [], {}, tuple(), _Object()]) def test_invalid_limit(self, limit): query = self.ref.order_by_child('foo') with pytest.raises(ValueError): @@ -985,47 +988,47 @@ def test_start_at_none(self): with pytest.raises(ValueError): query.start_at(None) - @pytest.mark.parametrize('arg', ['', 'foo', True, False, 0, 1, dict()]) + @pytest.mark.parametrize('arg', ['', 'foo', True, False, 0, 1, {}]) def test_valid_start_at(self, arg): query = self.ref.order_by_child('foo').start_at(arg) - assert query._querystr == 'orderBy="foo"&startAt={0}'.format(json.dumps(arg)) + assert query._querystr == f'orderBy="foo"&startAt={json.dumps(arg)}' def test_end_at_none(self): query = self.ref.order_by_child('foo') with pytest.raises(ValueError): query.end_at(None) - @pytest.mark.parametrize('arg', ['', 'foo', True, False, 0, 1, dict()]) + @pytest.mark.parametrize('arg', ['', 'foo', True, False, 0, 1, {}]) def test_valid_end_at(self, arg): query = self.ref.order_by_child('foo').end_at(arg) - assert query._querystr == 'endAt={0}&orderBy="foo"'.format(json.dumps(arg)) + assert query._querystr == f'endAt={json.dumps(arg)}&orderBy="foo"' def test_equal_to_none(self): query = self.ref.order_by_child('foo') with pytest.raises(ValueError): query.equal_to(None) - @pytest.mark.parametrize('arg', ['', 'foo', True, False, 0, 1, dict()]) + @pytest.mark.parametrize('arg', ['', 'foo', True, False, 0, 1, {}]) def test_valid_equal_to(self, arg): query = self.ref.order_by_child('foo').equal_to(arg) - assert query._querystr == 'equalTo={0}&orderBy="foo"'.format(json.dumps(arg)) + assert query._querystr == f'equalTo={json.dumps(arg)}&orderBy="foo"' def test_range_query(self, initquery): query, order_by = initquery query.start_at(1) query.equal_to(2) query.end_at(3) - assert query._querystr == 'endAt=3&equalTo=2&orderBy="{0}"&startAt=1'.format(order_by) + assert query._querystr == f'endAt=3&equalTo=2&orderBy="{order_by}"&startAt=1' def test_limit_first_query(self, initquery): query, order_by = initquery query.limit_to_first(1) - assert query._querystr == 'limitToFirst=1&orderBy="{0}"'.format(order_by) + assert query._querystr == f'limitToFirst=1&orderBy="{order_by}"' def test_limit_last_query(self, initquery): query, order_by = initquery query.limit_to_last(1) - assert query._querystr == 'limitToLast=1&orderBy="{0}"'.format(order_by) + assert query._querystr == f'limitToLast=1&orderBy="{order_by}"' def test_all_in(self, initquery): query, order_by = initquery @@ -1033,7 +1036,7 @@ def test_all_in(self, initquery): query.equal_to(2) query.end_at(3) query.limit_to_first(10) - expected = 'endAt=3&equalTo=2&limitToFirst=10&orderBy="{0}"&startAt=1'.format(order_by) + expected = f'endAt=3&equalTo=2&limitToFirst=10&orderBy="{order_by}"&startAt=1' assert query._querystr == expected def test_invalid_query_args(self): @@ -1059,9 +1062,9 @@ class TestSorter: ({'k1' : False, 'k2' : 'bar', 'k3' : None}, ['k3', 'k1', 'k2']), ({'k1' : False, 'k2' : 1, 'k3' : None}, ['k3', 'k1', 'k2']), ({'k1' : True, 'k2' : 0, 'k3' : None, 'k4' : 'foo'}, ['k3', 'k1', 'k2', 'k4']), - ({'k1' : True, 'k2' : 0, 'k3' : None, 'k4' : 'foo', 'k5' : False, 'k6' : dict()}, + ({'k1' : True, 'k2' : 0, 'k3' : None, 'k4' : 'foo', 'k5' : False, 'k6' : {}}, ['k3', 'k5', 'k1', 'k2', 'k4', 'k6']), - ({'k1' : True, 'k2' : 0, 'k3' : 'foo', 'k4' : 'foo', 'k5' : False, 'k6' : dict()}, + ({'k1' : True, 'k2' : 0, 'k3' : 'foo', 'k4' : 'foo', 'k5' : False, 'k6' : {}}, ['k5', 'k1', 'k2', 'k3', 'k4', 'k6']), ] diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 4347c838a..fa1276feb 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -14,17 +14,12 @@ import io import json -import socket -import httplib2 -import pytest import requests from requests import models -from googleapiclient import errors from firebase_admin import exceptions from firebase_admin import _utils -from firebase_admin import _gapic_utils _NOT_FOUND_ERROR_DICT = { @@ -178,159 +173,3 @@ def _create_response(self, status=500, payload=None): resp.raw = io.BytesIO(payload.encode()) exc = requests.exceptions.RequestException('Test error', response=resp) return resp, exc - - -class TestGoogleApiClient: - - @pytest.mark.parametrize('error', [ - socket.timeout('Test error'), - socket.error('Read timed out') - ]) - def test_googleapicleint_timeout_error(self, error): - firebase_error = _gapic_utils.handle_googleapiclient_error(error) - assert isinstance(firebase_error, exceptions.DeadlineExceededError) - assert str(firebase_error) == 'Timed out while making an API call: {0}'.format(error) - assert firebase_error.cause is error - assert firebase_error.http_response is None - - def test_googleapiclient_connection_error(self): - error = httplib2.ServerNotFoundError('Test error') - firebase_error = _gapic_utils.handle_googleapiclient_error(error) - assert isinstance(firebase_error, exceptions.UnavailableError) - assert str(firebase_error) == 'Failed to establish a connection: Test error' - assert firebase_error.cause is error - assert firebase_error.http_response is None - - def test_unknown_transport_error(self): - error = socket.error('Test error') - firebase_error = _gapic_utils.handle_googleapiclient_error(error) - assert isinstance(firebase_error, exceptions.UnknownError) - assert str(firebase_error) == 'Unknown error while making a remote service call: Test error' - assert firebase_error.cause is error - assert firebase_error.http_response is None - - def test_http_response(self): - error = self._create_http_error() - firebase_error = _gapic_utils.handle_googleapiclient_error(error) - assert isinstance(firebase_error, exceptions.InternalError) - assert str(firebase_error) == str(error) - assert firebase_error.cause is error - assert firebase_error.http_response.status_code == 500 - assert firebase_error.http_response.content.decode() == 'Body' - - def test_http_response_with_unknown_status(self): - error = self._create_http_error(status=501) - firebase_error = _gapic_utils.handle_googleapiclient_error(error) - assert isinstance(firebase_error, exceptions.UnknownError) - assert str(firebase_error) == str(error) - assert firebase_error.cause is error - assert firebase_error.http_response.status_code == 501 - assert firebase_error.http_response.content.decode() == 'Body' - - def test_http_response_with_message(self): - error = self._create_http_error() - firebase_error = _gapic_utils.handle_googleapiclient_error( - error, message='Explicit error message') - assert isinstance(firebase_error, exceptions.InternalError) - assert str(firebase_error) == 'Explicit error message' - assert firebase_error.cause is error - assert firebase_error.http_response.status_code == 500 - assert firebase_error.http_response.content.decode() == 'Body' - - def test_http_response_with_code(self): - error = self._create_http_error() - firebase_error = _gapic_utils.handle_googleapiclient_error( - error, code=exceptions.UNAVAILABLE) - assert isinstance(firebase_error, exceptions.UnavailableError) - assert str(firebase_error) == str(error) - assert firebase_error.cause is error - assert firebase_error.http_response.status_code == 500 - assert firebase_error.http_response.content.decode() == 'Body' - - def test_http_response_with_message_and_code(self): - error = self._create_http_error() - firebase_error = _gapic_utils.handle_googleapiclient_error( - error, message='Explicit error message', code=exceptions.UNAVAILABLE) - assert isinstance(firebase_error, exceptions.UnavailableError) - assert str(firebase_error) == 'Explicit error message' - assert firebase_error.cause is error - assert firebase_error.http_response.status_code == 500 - assert firebase_error.http_response.content.decode() == 'Body' - - def test_handle_platform_error(self): - error = self._create_http_error(payload=_NOT_FOUND_PAYLOAD) - firebase_error = _gapic_utils.handle_platform_error_from_googleapiclient(error) - assert isinstance(firebase_error, exceptions.NotFoundError) - assert str(firebase_error) == 'test error' - assert firebase_error.cause is error - assert firebase_error.http_response.status_code == 500 - assert firebase_error.http_response.content.decode() == _NOT_FOUND_PAYLOAD - - def test_handle_platform_error_with_no_response(self): - error = socket.error('Test error') - firebase_error = _gapic_utils.handle_platform_error_from_googleapiclient(error) - assert isinstance(firebase_error, exceptions.UnknownError) - assert str(firebase_error) == 'Unknown error while making a remote service call: Test error' - assert firebase_error.cause is error - assert firebase_error.http_response is None - - def test_handle_platform_error_with_no_error_code(self): - error = self._create_http_error(payload='no error code') - firebase_error = _gapic_utils.handle_platform_error_from_googleapiclient(error) - assert isinstance(firebase_error, exceptions.InternalError) - message = 'Unexpected HTTP response with status: 500; body: no error code' - assert str(firebase_error) == message - assert firebase_error.cause is error - assert firebase_error.http_response.status_code == 500 - assert firebase_error.http_response.content.decode() == 'no error code' - - def test_handle_platform_error_with_custom_handler(self): - error = self._create_http_error(payload=_NOT_FOUND_PAYLOAD) - invocations = [] - - def _custom_handler(cause, message, error_dict, http_response): - invocations.append((cause, message, error_dict, http_response)) - return exceptions.InvalidArgumentError('Custom message', cause, http_response) - - firebase_error = _gapic_utils.handle_platform_error_from_googleapiclient( - error, _custom_handler) - - assert isinstance(firebase_error, exceptions.InvalidArgumentError) - assert str(firebase_error) == 'Custom message' - assert firebase_error.cause is error - assert firebase_error.http_response.status_code == 500 - assert firebase_error.http_response.content.decode() == _NOT_FOUND_PAYLOAD - assert len(invocations) == 1 - args = invocations[0] - assert len(args) == 4 - assert args[0] is error - assert args[1] == 'test error' - assert args[2] == _NOT_FOUND_ERROR_DICT - assert args[3] is not None - - def test_handle_platform_error_with_custom_handler_ignore(self): - error = self._create_http_error(payload=_NOT_FOUND_PAYLOAD) - invocations = [] - - def _custom_handler(cause, message, error_dict, http_response): - invocations.append((cause, message, error_dict, http_response)) - - firebase_error = _gapic_utils.handle_platform_error_from_googleapiclient( - error, _custom_handler) - - assert isinstance(firebase_error, exceptions.NotFoundError) - assert str(firebase_error) == 'test error' - assert firebase_error.cause is error - assert firebase_error.http_response.status_code == 500 - assert firebase_error.http_response.content.decode() == _NOT_FOUND_PAYLOAD - assert len(invocations) == 1 - args = invocations[0] - assert len(args) == 4 - assert args[0] is error - assert args[1] == 'test error' - assert args[2] == _NOT_FOUND_ERROR_DICT - assert args[3] is not None - - def _create_http_error(self, status=500, payload='Body'): - resp = httplib2.Response({'status': status}) - return errors.HttpError(resp, payload.encode()) diff --git a/tests/test_instance_id.py b/tests/test_instance_id.py index 387e067c9..2b0e21079 100644 --- a/tests/test_instance_id.py +++ b/tests/test_instance_id.py @@ -72,7 +72,7 @@ def _assert_request(self, request, expected_method, expected_url): assert request.headers['x-goog-api-client'] == expected_metrics_header def _get_url(self, project_id, iid): - return instance_id._IID_SERVICE_URL + 'project/{0}/instanceId/{1}'.format(project_id, iid) + return instance_id._IID_SERVICE_URL + f'project/{project_id}/instanceId/{iid}' def test_no_project_id(self): def evaluate(): @@ -131,14 +131,14 @@ def test_delete_instance_id_unexpected_error(self): with pytest.raises(exceptions.UnknownError) as excinfo: instance_id.delete_instance_id('test_iid') url = self._get_url('explicit-project-id', 'test_iid') - message = 'Instance ID "test_iid": 501 Server Error: None for url: {0}'.format(url) + message = f'Instance ID "test_iid": 501 Server Error: None for url: {url}' assert str(excinfo.value) == message assert excinfo.value.cause is not None assert excinfo.value.http_response is not None assert len(recorder) == 1 self._assert_request(recorder[0], 'DELETE', url) - @pytest.mark.parametrize('iid', [None, '', 0, 1, True, False, list(), dict(), tuple()]) + @pytest.mark.parametrize('iid', [None, '', 0, 1, True, False, [], {}, tuple()]) def test_invalid_instance_id(self, iid): cred = testutils.MockCredential() app = firebase_admin.initialize_app(cred, {'projectId': 'explicit-project-id'}) diff --git a/tests/test_messaging.py b/tests/test_messaging.py index 76cee2a33..9fa30fef9 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -20,8 +20,6 @@ import httpx import respx -from googleapiclient import http -from googleapiclient import _helpers import pytest import firebase_admin @@ -32,12 +30,12 @@ from tests import testutils -NON_STRING_ARGS = [list(), tuple(), dict(), True, False, 1, 0] -NON_DICT_ARGS = ['', list(), tuple(), True, False, 1, 0, {1: 'foo'}, {'foo': 1}] -NON_OBJECT_ARGS = [list(), tuple(), dict(), 'foo', 0, 1, True, False] -NON_LIST_ARGS = ['', tuple(), dict(), True, False, 1, 0, [1], ['foo', 1]] -NON_UINT_ARGS = ['1.23s', list(), tuple(), dict(), -1.23] -NON_BOOL_ARGS = ['', list(), tuple(), dict(), 1, 0, [1], ['foo', 1], {1: 'foo'}, {'foo': 1}] +NON_STRING_ARGS = [[], tuple(), {}, True, False, 1, 0] +NON_DICT_ARGS = ['', [], tuple(), True, False, 1, 0, {1: 'foo'}, {'foo': 1}] +NON_OBJECT_ARGS = [[], tuple(), {}, 'foo', 0, 1, True, False] +NON_LIST_ARGS = ['', tuple(), {}, True, False, 1, 0, [1], ['foo', 1]] +NON_UINT_ARGS = ['1.23s', [], tuple(), {}, -1.23] +NON_BOOL_ARGS = ['', [], tuple(), {}, 1, 0, [1], ['foo', 1], {1: 'foo'}, {'foo': 1}] HTTP_ERROR_CODES = { 400: exceptions.InvalidArgumentError, 403: exceptions.PermissionDeniedError, @@ -503,7 +501,7 @@ def test_invalid_channel_id(self, data): excinfo = self._check_notification(notification) assert str(excinfo.value) == 'AndroidNotification.channel_id must be a string.' - @pytest.mark.parametrize('timestamp', [100, '', 'foo', {}, [], list(), dict()]) + @pytest.mark.parametrize('timestamp', [100, '', 'foo', {}, []]) def test_invalid_event_timestamp(self, timestamp): notification = messaging.AndroidNotification(event_timestamp=timestamp) excinfo = self._check_notification(notification) @@ -570,7 +568,7 @@ def test_negative_vibrate_timings_millis(self): expected = 'AndroidNotification.vibrate_timings_millis must not be negative.' assert str(excinfo.value) == expected - @pytest.mark.parametrize('notification_count', ['', 'foo', list(), tuple(), dict()]) + @pytest.mark.parametrize('notification_count', ['', 'foo', [], tuple(), {}]) def test_invalid_notification_count(self, notification_count): notification = messaging.AndroidNotification(notification_count=notification_count) excinfo = self._check_notification(notification) @@ -941,19 +939,19 @@ def test_invalid_tag(self, data): excinfo = self._check_notification(notification) assert str(excinfo.value) == 'WebpushNotification.tag must be a string.' - @pytest.mark.parametrize('data', ['', 'foo', list(), tuple(), dict()]) + @pytest.mark.parametrize('data', ['', 'foo', [], tuple(), {}]) def test_invalid_timestamp(self, data): notification = messaging.WebpushNotification(timestamp_millis=data) excinfo = self._check_notification(notification) assert str(excinfo.value) == 'WebpushNotification.timestamp_millis must be a number.' - @pytest.mark.parametrize('data', ['', list(), tuple(), True, False, 1, 0]) + @pytest.mark.parametrize('data', ['', [], tuple(), True, False, 1, 0]) def test_invalid_custom_data(self, data): notification = messaging.WebpushNotification(custom_data=data) excinfo = self._check_notification(notification) assert str(excinfo.value) == 'WebpushNotification.custom_data must be a dict.' - @pytest.mark.parametrize('data', ['', dict(), tuple(), True, False, 1, 0, [1, 2]]) + @pytest.mark.parametrize('data', ['', {}, tuple(), True, False, 1, 0, [1, 2]]) def test_invalid_actions(self, data): notification = messaging.WebpushNotification(actions=data) excinfo = self._check_notification(notification) @@ -1174,7 +1172,7 @@ def test_invalid_alert(self, data): expected = 'Aps.alert must be a string or an instance of ApsAlert class.' assert str(excinfo.value) == expected - @pytest.mark.parametrize('data', [list(), tuple(), dict(), 'foo']) + @pytest.mark.parametrize('data', [[], tuple(), {}, 'foo']) def test_invalid_badge(self, data): aps = messaging.Aps(badge=data) with pytest.raises(ValueError) as excinfo: @@ -1206,7 +1204,7 @@ def test_invalid_thread_id(self, data): expected = 'Aps.thread_id must be a string.' assert str(excinfo.value) == expected - @pytest.mark.parametrize('data', ['', list(), tuple(), True, False, 1, 0, ]) + @pytest.mark.parametrize('data', ['', [], tuple(), True, False, 1, 0, ]) def test_invalid_custom_data_dict(self, data): if isinstance(data, dict): return @@ -1311,7 +1309,7 @@ def test_invalid_name(self, data): expected = 'CriticalSound.name must be a non-empty string.' assert str(excinfo.value) == expected - @pytest.mark.parametrize('data', [list(), tuple(), dict(), 'foo']) + @pytest.mark.parametrize('data', [[], tuple(), {}, 'foo']) def test_invalid_volume(self, data): sound = messaging.CriticalSound(name='default', volume=data) excinfo = self._check_sound(sound) @@ -1661,7 +1659,7 @@ def test_topic_management_custom_timeout(self, options, timeout): class TestSend: _DEFAULT_RESPONSE = json.dumps({'name': 'message-id'}) - _CLIENT_VERSION = 'fire-admin-python/{0}'.format(firebase_admin.__version__) + _CLIENT_VERSION = f'fire-admin-python/{firebase_admin.__version__}' @classmethod def setup_class(cls): @@ -1738,7 +1736,7 @@ def test_send_error(self, status, exc_type): msg = messaging.Message(topic='foo') with pytest.raises(exc_type) as excinfo: messaging.send(msg) - expected = 'Unexpected HTTP response with status: {0}; body: {{}}'.format(status) + expected = f'Unexpected HTTP response with status: {status}; body: {{}}' check_exception(excinfo.value, expected, status) assert len(recorder) == 1 body = {'message': messaging._MessagingService.JSON_ENCODER.default(msg)} @@ -1826,17 +1824,7 @@ def test_send_unknown_fcm_error_code(self, status): self._assert_request(recorder[0], 'POST', self._get_url('explicit-project-id'), body) -class _HttpMockException: - - def __init__(self, exc): - self._exc = exc - - def request(self, url, **kwargs): - raise self._exc - - -class TestBatch: - +class TestSendEach(): @classmethod def setup_class(cls): cred = testutils.MockCredential() @@ -1856,40 +1844,6 @@ def _instrument_messaging_service(self, response_dict, app=None): testutils.MockRequestBasedMultiRequestAdapter(response_dict, recorder)) return fcm_service, recorder - def _instrument_batch_messaging_service(self, app=None, status=200, payload='', exc=None): - def build_mock_transport(_): - if exc: - return _HttpMockException(exc) - - if status == 200: - content_type = 'multipart/mixed; boundary=boundary' - else: - content_type = 'application/json' - return http.HttpMockSequence([ - ({'status': str(status), 'content-type': content_type}, payload), - ]) - - if not app: - app = firebase_admin.get_app() - - fcm_service = messaging._get_messaging_service(app) - fcm_service._build_transport = build_mock_transport - return fcm_service - - def _batch_payload(self, payloads): - # payloads should be a list of (status_code, content) tuples - payload = '' - _playload_format = """--boundary\r\nContent-Type: application/http\r\n\ -Content-ID: \r\n\r\nHTTP/1.1 {} Success\r\n\ -Content-Type: application/json; charset=UTF-8\r\n\r\n{}\r\n\r\n""" - for (index, (status_code, content)) in enumerate(payloads): - payload += _playload_format.format(str(index + 1), str(status_code), content) - payload += '--boundary--' - return payload - - -class TestSendEach(TestBatch): - def test_no_project_id(self): def evaluate(): app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') @@ -1927,8 +1881,8 @@ def test_send_each(self): assert batch_response.failure_count == 0 assert len(batch_response.responses) == 2 assert [r.message_id for r in batch_response.responses] == ['message-id1', 'message-id2'] - assert all([r.success for r in batch_response.responses]) - assert not any([r.exception for r in batch_response.responses]) + assert all(r.success for r in batch_response.responses) + assert not any(r.exception for r in batch_response.responses) @respx.mock @pytest.mark.asyncio @@ -1948,19 +1902,13 @@ async def test_send_each_async(self): batch_response = await messaging.send_each_async([msg1, msg2, msg3], dry_run=True) - # try: - # batch_response = await messaging.send_each_async([msg1, msg2], dry_run=True) - # except Exception as error: - # if isinstance(error.cause.__cause__, StopIteration): - # raise Exception('Received more requests than mocks') - assert batch_response.success_count == 3 assert batch_response.failure_count == 0 assert len(batch_response.responses) == 3 assert [r.message_id for r in batch_response.responses] \ == ['message-id1', 'message-id2', 'message-id3'] - assert all([r.success for r in batch_response.responses]) - assert not any([r.exception for r in batch_response.responses]) + assert all(r.success for r in batch_response.responses) + assert not any(r.exception for r in batch_response.responses) assert route.call_count == 3 @@ -2028,8 +1976,8 @@ async def test_send_each_async_error_401_pass_on_auth_retry(self): assert batch_response.failure_count == 0 assert len(batch_response.responses) == 1 assert [r.message_id for r in batch_response.responses] == ['message-id1'] - assert all([r.success for r in batch_response.responses]) - assert not any([r.exception for r in batch_response.responses]) + assert all(r.success for r in batch_response.responses) + assert not any(r.exception for r in batch_response.responses) @respx.mock @pytest.mark.asyncio @@ -2101,11 +2049,12 @@ async def test_send_each_async_error_500_pass_on_retry_config(self): assert batch_response.failure_count == 0 assert len(batch_response.responses) == 1 assert [r.message_id for r in batch_response.responses] == ['message-id1'] - assert all([r.success for r in batch_response.responses]) - assert not any([r.exception for r in batch_response.responses]) + assert all(r.success for r in batch_response.responses) + assert not any(r.exception for r in batch_response.responses) + - @respx.mock @pytest.mark.asyncio + @respx.mock async def test_send_each_async_request_error(self): responses = httpx.ConnectError("Test request error", request=httpx.Request( 'POST', @@ -2217,19 +2166,19 @@ def test_send_each_fcm_error_code(self, status, fcm_error_code, exc_type): check_exception(exception, 'test error', status) -class TestSendEachForMulticast(TestBatch): +class TestSendEachForMulticast(TestSendEach): def test_no_project_id(self): def evaluate(): app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') with pytest.raises(ValueError): - messaging.send_all([messaging.Message(topic='foo')], app=app) + messaging.send_each([messaging.Message(topic='foo')], app=app) testutils.run_without_project_id(evaluate) @pytest.mark.parametrize('msg', NON_LIST_ARGS) def test_invalid_send_each_for_multicast(self, msg): with pytest.raises(ValueError) as excinfo: - messaging.send_multicast(msg) + messaging.send_each_for_multicast(msg) expected = 'Message must be an instance of messaging.MulticastMessage class.' assert str(excinfo.value) == expected @@ -2244,8 +2193,8 @@ def test_send_each_for_multicast(self): assert batch_response.failure_count == 0 assert len(batch_response.responses) == 2 assert [r.message_id for r in batch_response.responses] == ['message-id1', 'message-id2'] - assert all([r.success for r in batch_response.responses]) - assert not any([r.exception for r in batch_response.responses]) + assert all(r.success for r in batch_response.responses) + assert not any(r.exception for r in batch_response.responses) @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_each_for_multicast_detailed_error(self, status): @@ -2338,432 +2287,6 @@ def test_send_each_for_multicast_fcm_error_code(self, status): check_exception(exception, 'test error', status) -class TestSendAll(TestBatch): - - def test_no_project_id(self): - def evaluate(): - app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') - with pytest.raises(ValueError): - messaging.send_all([messaging.Message(topic='foo')], app=app) - testutils.run_without_project_id(evaluate) - - @pytest.mark.parametrize('msg', NON_LIST_ARGS) - def test_invalid_send_all(self, msg): - with pytest.raises(ValueError) as excinfo: - messaging.send_all(msg) - if isinstance(msg, list): - expected = 'Message must be an instance of messaging.Message class.' - assert str(excinfo.value) == expected - else: - expected = 'messages must be a list of messaging.Message instances.' - assert str(excinfo.value) == expected - - def test_invalid_over_500(self): - msg = messaging.Message(topic='foo') - with pytest.raises(ValueError) as excinfo: - messaging.send_all([msg for _ in range(0, 501)]) - expected = 'messages must not contain more than 500 elements.' - assert str(excinfo.value) == expected - - def test_send_all(self): - payload = json.dumps({'name': 'message-id'}) - _ = self._instrument_batch_messaging_service( - payload=self._batch_payload([(200, payload), (200, payload)])) - msg = messaging.Message(topic='foo') - batch_response = messaging.send_all([msg, msg], dry_run=True) - assert batch_response.success_count == 2 - assert batch_response.failure_count == 0 - assert len(batch_response.responses) == 2 - assert [r.message_id for r in batch_response.responses] == ['message-id', 'message-id'] - assert all([r.success for r in batch_response.responses]) - assert not any([r.exception for r in batch_response.responses]) - - def test_send_all_with_positional_param_enforcement(self): - payload = json.dumps({'name': 'message-id'}) - _ = self._instrument_batch_messaging_service( - payload=self._batch_payload([(200, payload), (200, payload)])) - msg = messaging.Message(topic='foo') - - enforcement = _helpers.positional_parameters_enforcement - _helpers.positional_parameters_enforcement = _helpers.POSITIONAL_EXCEPTION - try: - batch_response = messaging.send_all([msg, msg], dry_run=True) - assert batch_response.success_count == 2 - finally: - _helpers.positional_parameters_enforcement = enforcement - - @pytest.mark.parametrize('status', HTTP_ERROR_CODES) - def test_send_all_detailed_error(self, status): - success_payload = json.dumps({'name': 'message-id'}) - error_payload = json.dumps({ - 'error': { - 'status': 'INVALID_ARGUMENT', - 'message': 'test error' - } - }) - _ = self._instrument_batch_messaging_service( - payload=self._batch_payload([(200, success_payload), (status, error_payload)])) - msg = messaging.Message(topic='foo') - batch_response = messaging.send_all([msg, msg]) - assert batch_response.success_count == 1 - assert batch_response.failure_count == 1 - assert len(batch_response.responses) == 2 - success_response = batch_response.responses[0] - assert success_response.message_id == 'message-id' - assert success_response.success is True - assert success_response.exception is None - error_response = batch_response.responses[1] - assert error_response.message_id is None - assert error_response.success is False - exception = error_response.exception - assert isinstance(exception, exceptions.InvalidArgumentError) - check_exception(exception, 'test error', status) - - @pytest.mark.parametrize('status', HTTP_ERROR_CODES) - def test_send_all_canonical_error_code(self, status): - success_payload = json.dumps({'name': 'message-id'}) - error_payload = json.dumps({ - 'error': { - 'status': 'NOT_FOUND', - 'message': 'test error' - } - }) - _ = self._instrument_batch_messaging_service( - payload=self._batch_payload([(200, success_payload), (status, error_payload)])) - msg = messaging.Message(topic='foo') - batch_response = messaging.send_all([msg, msg]) - assert batch_response.success_count == 1 - assert batch_response.failure_count == 1 - assert len(batch_response.responses) == 2 - success_response = batch_response.responses[0] - assert success_response.message_id == 'message-id' - assert success_response.success is True - assert success_response.exception is None - error_response = batch_response.responses[1] - assert error_response.message_id is None - assert error_response.success is False - exception = error_response.exception - assert isinstance(exception, exceptions.NotFoundError) - check_exception(exception, 'test error', status) - - @pytest.mark.parametrize('status', HTTP_ERROR_CODES) - @pytest.mark.parametrize('fcm_error_code, exc_type', FCM_ERROR_CODES.items()) - def test_send_all_fcm_error_code(self, status, fcm_error_code, exc_type): - success_payload = json.dumps({'name': 'message-id'}) - error_payload = json.dumps({ - 'error': { - 'status': 'INVALID_ARGUMENT', - 'message': 'test error', - 'details': [ - { - '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', - 'errorCode': fcm_error_code, - }, - ], - } - }) - _ = self._instrument_batch_messaging_service( - payload=self._batch_payload([(200, success_payload), (status, error_payload)])) - msg = messaging.Message(topic='foo') - batch_response = messaging.send_all([msg, msg]) - assert batch_response.success_count == 1 - assert batch_response.failure_count == 1 - assert len(batch_response.responses) == 2 - success_response = batch_response.responses[0] - assert success_response.message_id == 'message-id' - assert success_response.success is True - assert success_response.exception is None - error_response = batch_response.responses[1] - assert error_response.message_id is None - assert error_response.success is False - exception = error_response.exception - assert isinstance(exception, exc_type) - check_exception(exception, 'test error', status) - - @pytest.mark.parametrize('status, exc_type', HTTP_ERROR_CODES.items()) - def test_send_all_batch_error(self, status, exc_type): - _ = self._instrument_batch_messaging_service(status=status, payload='{}') - msg = messaging.Message(topic='foo') - with pytest.raises(exc_type) as excinfo: - messaging.send_all([msg]) - expected = 'Unexpected HTTP response with status: {0}; body: {{}}'.format(status) - check_exception(excinfo.value, expected, status) - - @pytest.mark.parametrize('status', HTTP_ERROR_CODES) - def test_send_all_batch_detailed_error(self, status): - payload = json.dumps({ - 'error': { - 'status': 'INVALID_ARGUMENT', - 'message': 'test error' - } - }) - _ = self._instrument_batch_messaging_service(status=status, payload=payload) - msg = messaging.Message(topic='foo') - with pytest.raises(exceptions.InvalidArgumentError) as excinfo: - messaging.send_all([msg]) - check_exception(excinfo.value, 'test error', status) - - @pytest.mark.parametrize('status', HTTP_ERROR_CODES) - def test_send_all_batch_canonical_error_code(self, status): - payload = json.dumps({ - 'error': { - 'status': 'NOT_FOUND', - 'message': 'test error' - } - }) - _ = self._instrument_batch_messaging_service(status=status, payload=payload) - msg = messaging.Message(topic='foo') - with pytest.raises(exceptions.NotFoundError) as excinfo: - messaging.send_all([msg]) - check_exception(excinfo.value, 'test error', status) - - @pytest.mark.parametrize('status', HTTP_ERROR_CODES) - def test_send_all_batch_fcm_error_code(self, status): - payload = json.dumps({ - 'error': { - 'status': 'INVALID_ARGUMENT', - 'message': 'test error', - 'details': [ - { - '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', - 'errorCode': 'UNREGISTERED', - }, - ], - } - }) - _ = self._instrument_batch_messaging_service(status=status, payload=payload) - msg = messaging.Message(topic='foo') - with pytest.raises(messaging.UnregisteredError) as excinfo: - messaging.send_all([msg]) - check_exception(excinfo.value, 'test error', status) - - def test_send_all_runtime_exception(self): - exc = BrokenPipeError('Test error') - _ = self._instrument_batch_messaging_service(exc=exc) - msg = messaging.Message(topic='foo') - - with pytest.raises(exceptions.UnknownError) as excinfo: - messaging.send_all([msg]) - - expected = 'Unknown error while making a remote service call: Test error' - assert str(excinfo.value) == expected - assert excinfo.value.cause is exc - assert excinfo.value.http_response is None - - def test_send_transport_init(self): - def track_call_count(build_transport): - def wrapper(credential): - wrapper.calls += 1 - return build_transport(credential) - wrapper.calls = 0 - return wrapper - - payload = json.dumps({'name': 'message-id'}) - fcm_service = self._instrument_batch_messaging_service( - payload=self._batch_payload([(200, payload), (200, payload)])) - build_mock_transport = fcm_service._build_transport - fcm_service._build_transport = track_call_count(build_mock_transport) - msg = messaging.Message(topic='foo') - - batch_response = messaging.send_all([msg, msg], dry_run=True) - assert batch_response.success_count == 2 - assert fcm_service._build_transport.calls == 1 - - batch_response = messaging.send_all([msg, msg], dry_run=True) - assert batch_response.success_count == 2 - assert fcm_service._build_transport.calls == 2 - - -class TestSendMulticast(TestBatch): - - def test_no_project_id(self): - def evaluate(): - app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') - with pytest.raises(ValueError): - messaging.send_all([messaging.Message(topic='foo')], app=app) - testutils.run_without_project_id(evaluate) - - @pytest.mark.parametrize('msg', NON_LIST_ARGS) - def test_invalid_send_multicast(self, msg): - with pytest.raises(ValueError) as excinfo: - messaging.send_multicast(msg) - expected = 'Message must be an instance of messaging.MulticastMessage class.' - assert str(excinfo.value) == expected - - def test_send_multicast(self): - payload = json.dumps({'name': 'message-id'}) - _ = self._instrument_batch_messaging_service( - payload=self._batch_payload([(200, payload), (200, payload)])) - msg = messaging.MulticastMessage(tokens=['foo', 'foo']) - batch_response = messaging.send_multicast(msg, dry_run=True) - assert batch_response.success_count == 2 - assert batch_response.failure_count == 0 - assert len(batch_response.responses) == 2 - assert [r.message_id for r in batch_response.responses] == ['message-id', 'message-id'] - assert all([r.success for r in batch_response.responses]) - assert not any([r.exception for r in batch_response.responses]) - - @pytest.mark.parametrize('status', HTTP_ERROR_CODES) - def test_send_multicast_detailed_error(self, status): - success_payload = json.dumps({'name': 'message-id'}) - error_payload = json.dumps({ - 'error': { - 'status': 'INVALID_ARGUMENT', - 'message': 'test error' - } - }) - _ = self._instrument_batch_messaging_service( - payload=self._batch_payload([(200, success_payload), (status, error_payload)])) - msg = messaging.MulticastMessage(tokens=['foo', 'foo']) - batch_response = messaging.send_multicast(msg) - assert batch_response.success_count == 1 - assert batch_response.failure_count == 1 - assert len(batch_response.responses) == 2 - success_response = batch_response.responses[0] - assert success_response.message_id == 'message-id' - assert success_response.success is True - assert success_response.exception is None - error_response = batch_response.responses[1] - assert error_response.message_id is None - assert error_response.success is False - assert error_response.exception is not None - exception = error_response.exception - assert isinstance(exception, exceptions.InvalidArgumentError) - check_exception(exception, 'test error', status) - - @pytest.mark.parametrize('status', HTTP_ERROR_CODES) - def test_send_multicast_canonical_error_code(self, status): - success_payload = json.dumps({'name': 'message-id'}) - error_payload = json.dumps({ - 'error': { - 'status': 'NOT_FOUND', - 'message': 'test error' - } - }) - _ = self._instrument_batch_messaging_service( - payload=self._batch_payload([(200, success_payload), (status, error_payload)])) - msg = messaging.MulticastMessage(tokens=['foo', 'foo']) - batch_response = messaging.send_multicast(msg) - assert batch_response.success_count == 1 - assert batch_response.failure_count == 1 - assert len(batch_response.responses) == 2 - success_response = batch_response.responses[0] - assert success_response.message_id == 'message-id' - assert success_response.success is True - assert success_response.exception is None - error_response = batch_response.responses[1] - assert error_response.message_id is None - assert error_response.success is False - assert error_response.exception is not None - exception = error_response.exception - assert isinstance(exception, exceptions.NotFoundError) - check_exception(exception, 'test error', status) - - @pytest.mark.parametrize('status', HTTP_ERROR_CODES) - def test_send_multicast_fcm_error_code(self, status): - success_payload = json.dumps({'name': 'message-id'}) - error_payload = json.dumps({ - 'error': { - 'status': 'INVALID_ARGUMENT', - 'message': 'test error', - 'details': [ - { - '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', - 'errorCode': 'UNREGISTERED', - }, - ], - } - }) - _ = self._instrument_batch_messaging_service( - payload=self._batch_payload([(200, success_payload), (status, error_payload)])) - msg = messaging.MulticastMessage(tokens=['foo', 'foo']) - batch_response = messaging.send_multicast(msg) - assert batch_response.success_count == 1 - assert batch_response.failure_count == 1 - assert len(batch_response.responses) == 2 - success_response = batch_response.responses[0] - assert success_response.message_id == 'message-id' - assert success_response.success is True - assert success_response.exception is None - error_response = batch_response.responses[1] - assert error_response.message_id is None - assert error_response.success is False - assert error_response.exception is not None - exception = error_response.exception - assert isinstance(exception, messaging.UnregisteredError) - check_exception(exception, 'test error', status) - - @pytest.mark.parametrize('status, exc_type', HTTP_ERROR_CODES.items()) - def test_send_multicast_batch_error(self, status, exc_type): - _ = self._instrument_batch_messaging_service(status=status, payload='{}') - msg = messaging.MulticastMessage(tokens=['foo']) - with pytest.raises(exc_type) as excinfo: - messaging.send_multicast(msg) - expected = 'Unexpected HTTP response with status: {0}; body: {{}}'.format(status) - check_exception(excinfo.value, expected, status) - - @pytest.mark.parametrize('status', HTTP_ERROR_CODES) - def test_send_multicast_batch_detailed_error(self, status): - payload = json.dumps({ - 'error': { - 'status': 'INVALID_ARGUMENT', - 'message': 'test error' - } - }) - _ = self._instrument_batch_messaging_service(status=status, payload=payload) - msg = messaging.MulticastMessage(tokens=['foo']) - with pytest.raises(exceptions.InvalidArgumentError) as excinfo: - messaging.send_multicast(msg) - check_exception(excinfo.value, 'test error', status) - - @pytest.mark.parametrize('status', HTTP_ERROR_CODES) - def test_send_multicast_batch_canonical_error_code(self, status): - payload = json.dumps({ - 'error': { - 'status': 'NOT_FOUND', - 'message': 'test error' - } - }) - _ = self._instrument_batch_messaging_service(status=status, payload=payload) - msg = messaging.MulticastMessage(tokens=['foo']) - with pytest.raises(exceptions.NotFoundError) as excinfo: - messaging.send_multicast(msg) - check_exception(excinfo.value, 'test error', status) - - @pytest.mark.parametrize('status', HTTP_ERROR_CODES) - def test_send_multicast_batch_fcm_error_code(self, status): - payload = json.dumps({ - 'error': { - 'status': 'INVALID_ARGUMENT', - 'message': 'test error', - 'details': [ - { - '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', - 'errorCode': 'UNREGISTERED', - }, - ], - } - }) - _ = self._instrument_batch_messaging_service(status=status, payload=payload) - msg = messaging.MulticastMessage(tokens=['foo']) - with pytest.raises(messaging.UnregisteredError) as excinfo: - messaging.send_multicast(msg) - check_exception(excinfo.value, 'test error', status) - - def test_send_multicast_runtime_exception(self): - exc = BrokenPipeError('Test error') - _ = self._instrument_batch_messaging_service(exc=exc) - msg = messaging.MulticastMessage(tokens=['foo']) - - with pytest.raises(exceptions.UnknownError) as excinfo: - messaging.send_multicast(msg) - - expected = 'Unknown error while making a remote service call: Test error' - assert str(excinfo.value) == expected - assert excinfo.value.cause is exc - assert excinfo.value.http_response is None - - class TestTopicManagement: _DEFAULT_RESPONSE = json.dumps({'results': [{}, {'error': 'error_reason'}]}) @@ -2809,9 +2332,9 @@ def _assert_request(self, request, expected_method, expected_url): assert request.headers['x-goog-api-client'] == expected_metrics_header def _get_url(self, path): - return '{0}/{1}'.format(messaging._MessagingService.IID_URL, path) + return f'{messaging._MessagingService.IID_URL}/{path}' - @pytest.mark.parametrize('tokens', [None, '', list(), dict(), tuple()]) + @pytest.mark.parametrize('tokens', [None, '', [], {}, tuple()]) def test_invalid_tokens(self, tokens): expected = 'Tokens must be a string or a non-empty list of strings.' if isinstance(tokens, str): @@ -2860,7 +2383,7 @@ def test_subscribe_to_topic_non_json_error(self, status, exc_type): _, recorder = self._instrument_iid_service(status=status, payload='not json') with pytest.raises(exc_type) as excinfo: messaging.subscribe_to_topic('foo', 'test-topic') - reason = 'Unexpected HTTP response with status: {0}; body: not json'.format(status) + reason = f'Unexpected HTTP response with status: {status}; body: not json' assert str(excinfo.value) == reason assert len(recorder) == 1 self._assert_request(recorder[0], 'POST', self._get_url('iid/v1:batchAdd')) @@ -2889,7 +2412,7 @@ def test_unsubscribe_from_topic_non_json_error(self, status, exc_type): _, recorder = self._instrument_iid_service(status=status, payload='not json') with pytest.raises(exc_type) as excinfo: messaging.unsubscribe_from_topic('foo', 'test-topic') - reason = 'Unexpected HTTP response with status: {0}; body: not json'.format(status) + reason = f'Unexpected HTTP response with status: {status}; body: not json' assert str(excinfo.value) == reason assert len(recorder) == 1 self._assert_request(recorder[0], 'POST', self._get_url('iid/v1:batchRemove')) diff --git a/tests/test_ml.py b/tests/test_ml.py index 18a9e2754..bcc93fd05 100644 --- a/tests/test_ml.py +++ b/tests/test_ml.py @@ -49,7 +49,7 @@ TAGS_2 = [TAG_1, TAG_3] MODEL_ID_1 = 'modelId1' -MODEL_NAME_1 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_1) +MODEL_NAME_1 = f'projects/{PROJECT_ID}/models/{MODEL_ID_1}' DISPLAY_NAME_1 = 'displayName1' MODEL_JSON_1 = { 'name': MODEL_NAME_1, @@ -58,7 +58,7 @@ MODEL_1 = ml.Model.from_dict(MODEL_JSON_1) MODEL_ID_2 = 'modelId2' -MODEL_NAME_2 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_2) +MODEL_NAME_2 = f'projects/{PROJECT_ID}/models/{MODEL_ID_2}' DISPLAY_NAME_2 = 'displayName2' MODEL_JSON_2 = { 'name': MODEL_NAME_2, @@ -67,7 +67,7 @@ MODEL_2 = ml.Model.from_dict(MODEL_JSON_2) MODEL_ID_3 = 'modelId3' -MODEL_NAME_3 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_3) +MODEL_NAME_3 = f'projects/{PROJECT_ID}/models/{MODEL_ID_3}' DISPLAY_NAME_3 = 'displayName3' MODEL_JSON_3 = { 'name': MODEL_NAME_3, @@ -79,7 +79,7 @@ 'published': True } VALIDATION_ERROR_CODE = 400 -VALIDATION_ERROR_MSG = 'No model format found for {0}.'.format(MODEL_ID_1) +VALIDATION_ERROR_MSG = f'No model format found for {MODEL_ID_1}.' MODEL_STATE_ERROR_JSON = { 'validationError': { 'code': VALIDATION_ERROR_CODE, @@ -87,19 +87,19 @@ } } -OPERATION_NAME_1 = 'projects/{0}/operations/123'.format(PROJECT_ID) +OPERATION_NAME_1 = f'projects/{PROJECT_ID}/operations/123' OPERATION_NOT_DONE_JSON_1 = { 'name': OPERATION_NAME_1, 'metadata': { '@type': 'type.googleapis.com/google.firebase.ml.v1beta2.ModelOperationMetadata', - 'name': 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_1), + 'name': f'projects/{PROJECT_ID}/models/{MODEL_ID_1}', 'basic_operation_status': 'BASIC_OPERATION_STATUS_UPLOADING' } } GCS_BUCKET_NAME = 'my_bucket' GCS_BLOB_NAME = 'mymodel.tflite' -GCS_TFLITE_URI = 'gs://{0}/{1}'.format(GCS_BUCKET_NAME, GCS_BLOB_NAME) +GCS_TFLITE_URI = f'gs://{GCS_BUCKET_NAME}/{GCS_BLOB_NAME}' GCS_TFLITE_URI_JSON = {'gcsTfliteUri': GCS_TFLITE_URI} GCS_TFLITE_MODEL_SOURCE = ml.TFLiteGCSModelSource(GCS_TFLITE_URI) TFLITE_FORMAT_JSON = { @@ -121,18 +121,6 @@ } TFLITE_FORMAT_2 = ml.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_2) -AUTOML_MODEL_NAME = 'projects/111111111111/locations/us-central1/models/ICN7683346839371803263' -AUTOML_MODEL_SOURCE = ml.TFLiteAutoMlSource(AUTOML_MODEL_NAME) -TFLITE_FORMAT_JSON_3 = { - 'automlModel': AUTOML_MODEL_NAME, - 'sizeBytes': '3456789' -} -TFLITE_FORMAT_3 = ml.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_3) - -AUTOML_MODEL_NAME_2 = 'projects/2222222222/locations/us-central1/models/ICN2222222222222222222' -AUTOML_MODEL_NAME_JSON_2 = {'automlModel': AUTOML_MODEL_NAME_2} -AUTOML_MODEL_SOURCE_2 = ml.TFLiteAutoMlSource(AUTOML_MODEL_NAME_2) - CREATED_UPDATED_MODEL_JSON_1 = { 'name': MODEL_NAME_1, 'displayName': DISPLAY_NAME_1, @@ -269,8 +257,8 @@ INVALID_MODEL_ARGS = [ 'abc', 4.2, - list(), - dict(), + [], + {}, True, -1, 0, @@ -284,9 +272,10 @@ 'projects/$#@/operations/123', 'projects/1234/operations/123/extrathing', ] -PAGE_SIZE_VALUE_ERROR_MSG = 'Page size must be a positive integer between ' \ - '1 and {0}'.format(ml._MAX_PAGE_SIZE) -INVALID_STRING_OR_NONE_ARGS = [0, -1, 4.2, 0x10, False, list(), dict()] +PAGE_SIZE_VALUE_ERROR_MSG = ( + f'Page size must be a positive integer between 1 and {ml._MAX_PAGE_SIZE}' +) +INVALID_STRING_OR_NONE_ARGS = [0, -1, 4.2, 0x10, False, [], {}] # For validation type errors @@ -370,8 +359,7 @@ def teardown_class(cls): @staticmethod def _op_url(project_id): - return BASE_URL + \ - 'projects/{0}/operations/123'.format(project_id) + return BASE_URL + f'projects/{project_id}/operations/123' def test_model_success_err_state_lro(self): model = ml.Model.from_dict(FULL_MODEL_ERR_STATE_LRO_JSON) @@ -423,14 +411,6 @@ def test_model_keyword_based_creation_and_setters(self): 'tfliteModel': TFLITE_FORMAT_JSON_2 } - model.model_format = TFLITE_FORMAT_3 - assert model.as_dict() == { - 'displayName': DISPLAY_NAME_2, - 'tags': TAGS_2, - 'tfliteModel': TFLITE_FORMAT_JSON_3 - } - - def test_gcs_tflite_model_format_source_creation(self): model_source = ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI) model_format = ml.TFLiteFormat(model_source=model_source) @@ -442,17 +422,6 @@ def test_gcs_tflite_model_format_source_creation(self): } } - def test_auto_ml_tflite_model_format_source_creation(self): - model_source = ml.TFLiteAutoMlSource(auto_ml_model=AUTOML_MODEL_NAME) - model_format = ml.TFLiteFormat(model_source=model_source) - model = ml.Model(display_name=DISPLAY_NAME_1, model_format=model_format) - assert model.as_dict() == { - 'displayName': DISPLAY_NAME_1, - 'tfliteModel': { - 'automlModel': AUTOML_MODEL_NAME - } - } - def test_source_creation_from_tflite_file(self): model_source = ml.TFLiteGCSModelSource.from_tflite_model_file( "my_model.tflite", "my_bucket") @@ -466,13 +435,6 @@ def test_gcs_tflite_model_source_setters(self): assert model_source.gcs_tflite_uri == GCS_TFLITE_URI_2 assert model_source.as_dict() == GCS_TFLITE_URI_JSON_2 - def test_auto_ml_tflite_model_source_setters(self): - model_source = ml.TFLiteAutoMlSource(AUTOML_MODEL_NAME) - model_source.auto_ml_model = AUTOML_MODEL_NAME_2 - assert model_source.auto_ml_model == AUTOML_MODEL_NAME_2 - assert model_source.as_dict() == AUTOML_MODEL_NAME_JSON_2 - - def test_model_format_setters(self): model_format = ml.TFLiteFormat(model_source=GCS_TFLITE_MODEL_SOURCE) model_format.model_source = GCS_TFLITE_MODEL_SOURCE_2 @@ -483,14 +445,6 @@ def test_model_format_setters(self): } } - model_format.model_source = AUTOML_MODEL_SOURCE - assert model_format.model_source == AUTOML_MODEL_SOURCE - assert model_format.as_dict() == { - 'tfliteModel': { - 'automlModel': AUTOML_MODEL_NAME - } - } - def test_model_as_dict_for_upload(self): model_source = ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI) model_format = ml.TFLiteFormat(model_source=model_source) @@ -576,23 +530,6 @@ def test_gcs_tflite_source_validation_errors(self, uri, exc_type): ml.TFLiteGCSModelSource(gcs_tflite_uri=uri) check_error(excinfo, exc_type) - @pytest.mark.parametrize('auto_ml_model, exc_type', [ - (123, TypeError), - ('abc', ValueError), - ('/projects/123456/locations/us-central1/models/noLeadingSlash', ValueError), - ('projects/123546/models/ICN123456', ValueError), - ('projects//locations/us-central1/models/ICN123456', ValueError), - ('projects/123456/locations//models/ICN123456', ValueError), - ('projects/123456/locations/us-central1/models/', ValueError), - ('projects/ABC/locations/us-central1/models/ICN123456', ValueError), - ('projects/123456/locations/us-central1/models/@#$%^&', ValueError), - ('projects/123456/locations/us-cent/ral1/models/ICN123456', ValueError), - ]) - def test_auto_ml_tflite_source_validation_errors(self, auto_ml_model, exc_type): - with pytest.raises(exc_type) as excinfo: - ml.TFLiteAutoMlSource(auto_ml_model=auto_ml_model) - check_error(excinfo, exc_type) - def test_wait_for_unlocked_not_locked(self): model = ml.Model(display_name="not_locked") model.wait_for_unlocked() @@ -632,16 +569,15 @@ def teardown_class(cls): @staticmethod def _url(project_id): - return BASE_URL + 'projects/{0}/models'.format(project_id) + return BASE_URL + f'projects/{project_id}/models' @staticmethod def _op_url(project_id): - return BASE_URL + \ - 'projects/{0}/operations/123'.format(project_id) + return BASE_URL + f'projects/{project_id}/operations/123' @staticmethod def _get_url(project_id, model_id): - return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) + return BASE_URL + f'projects/{project_id}/models/{model_id}' def test_immediate_done(self): instrument_ml_service(status=200, payload=OPERATION_DONE_RESPONSE) @@ -726,12 +662,11 @@ def teardown_class(cls): @staticmethod def _url(project_id, model_id): - return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) + return BASE_URL + f'projects/{project_id}/models/{model_id}' @staticmethod def _op_url(project_id): - return BASE_URL + \ - 'projects/{0}/operations/123'.format(project_id) + return BASE_URL + f'projects/{project_id}/operations/123' def test_immediate_done(self): instrument_ml_service(status=200, payload=OPERATION_DONE_RESPONSE) @@ -823,18 +758,16 @@ def teardown_class(cls): @staticmethod def _update_url(project_id, model_id): - update_url = 'projects/{0}/models/{1}?updateMask=state.published'.format( - project_id, model_id) + update_url = f'projects/{project_id}/models/{model_id}?updateMask=state.published' return BASE_URL + update_url @staticmethod def _get_url(project_id, model_id): - return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) + return BASE_URL + f'projects/{project_id}/models/{model_id}' @staticmethod def _op_url(project_id): - return BASE_URL + \ - 'projects/{0}/operations/123'.format(project_id) + return BASE_URL + f'projects/{project_id}/operations/123' @pytest.mark.parametrize('publish_function, published', PUBLISH_UNPUBLISH_WITH_ARGS) def test_immediate_done(self, publish_function, published): @@ -905,7 +838,7 @@ def teardown_class(cls): @staticmethod def _url(project_id, model_id): - return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) + return BASE_URL + f'projects/{project_id}/models/{model_id}' def test_get_model(self): recorder = instrument_ml_service(status=200, payload=DEFAULT_GET_RESPONSE) @@ -956,7 +889,7 @@ def teardown_class(cls): @staticmethod def _url(project_id, model_id): - return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) + return BASE_URL + f'projects/{project_id}/models/{model_id}' def test_delete_model(self): recorder = instrument_ml_service(status=200, payload=EMPTY_RESPONSE) @@ -1004,7 +937,7 @@ def teardown_class(cls): @staticmethod def _url(project_id): - return BASE_URL + 'projects/{0}/models'.format(project_id) + return BASE_URL + f'projects/{project_id}/models' @staticmethod def _check_page(page, model_count): @@ -1033,8 +966,8 @@ def test_list_models_with_all_args(self): assert len(recorder) == 1 _assert_request(recorder[0], 'GET', ( TestListModels._url(PROJECT_ID) + - '?filter=display_name%3DdisplayName3&page_size=10&page_token={0}' - .format(PAGE_TOKEN))) + f'?filter=display_name%3DdisplayName3&page_size=10&page_token={PAGE_TOKEN}' + )) assert isinstance(models_page, ml.ListModelsPage) assert len(models_page.models) == 1 assert models_page.models[0] == MODEL_3 @@ -1049,8 +982,8 @@ def test_list_models_list_filter_validation(self, list_filter): @pytest.mark.parametrize('page_size, exc_type, error_message', [ ('abc', TypeError, 'Page size must be a number or None.'), (4.2, TypeError, 'Page size must be a number or None.'), - (list(), TypeError, 'Page size must be a number or None.'), - (dict(), TypeError, 'Page size must be a number or None.'), + ([], TypeError, 'Page size must be a number or None.'), + ({}, TypeError, 'Page size must be a number or None.'), (True, TypeError, 'Page size must be a number or None.'), (-1, ValueError, PAGE_SIZE_VALUE_ERROR_MSG), (0, ValueError, PAGE_SIZE_VALUE_ERROR_MSG), @@ -1094,7 +1027,7 @@ def test_list_single_page(self): assert models_page.next_page_token == '' assert models_page.has_next_page is False assert models_page.get_next_page() is None - models = [model for model in models_page.iterate_all()] + models = list(models_page.iterate_all()) assert len(models) == 1 def test_list_multiple_pages(self): @@ -1124,7 +1057,7 @@ def test_list_models_paged_iteration(self): iterator = page.iterate_all() for index in range(2): model = next(iterator) - assert model.display_name == 'displayName{0}'.format(index+1) + assert model.display_name == f'displayName{index+1}' assert len(recorder) == 1 # Page 2 @@ -1140,7 +1073,7 @@ def test_list_models_stop_iteration(self): assert len(recorder) == 1 assert len(page.models) == 3 iterator = page.iterate_all() - models = [model for model in iterator] + models = list(iterator) assert len(page.models) == 3 with pytest.raises(StopIteration): next(iterator) @@ -1151,5 +1084,5 @@ def test_list_models_no_models(self): page = ml.list_models() assert len(recorder) == 1 assert len(page.models) == 0 - models = [model for model in page.iterate_all()] + models = list(page.iterate_all()) assert len(models) == 0 diff --git a/tests/test_project_management.py b/tests/test_project_management.py index a242f523f..89e48c2e5 100644 --- a/tests/test_project_management.py +++ b/tests/test_project_management.py @@ -545,7 +545,7 @@ def test_custom_timeout(self, timeout): 'projectId': 'test-project-id' } app = firebase_admin.initialize_app( - testutils.MockCredential(), options, 'timeout-{0}'.format(timeout)) + testutils.MockCredential(), options, f'timeout-{timeout}') project_management_service = project_management._get_project_management_service(app) assert project_management_service._client.timeout == timeout @@ -820,7 +820,7 @@ def test_list_android_apps_rpc_error(self): assert len(recorder) == 1 def test_list_android_apps_empty_list(self): - recorder = self._instrument_service(statuses=[200], responses=[json.dumps(dict())]) + recorder = self._instrument_service(statuses=[200], responses=[json.dumps({})]) android_apps = project_management.list_android_apps() @@ -883,7 +883,7 @@ def test_list_ios_apps_rpc_error(self): assert len(recorder) == 1 def test_list_ios_apps_empty_list(self): - recorder = self._instrument_service(statuses=[200], responses=[json.dumps(dict())]) + recorder = self._instrument_service(statuses=[200], responses=[json.dumps({})]) ios_apps = project_management.list_ios_apps() diff --git a/tests/test_remote_config.py b/tests/test_remote_config.py index 8c6248e18..7bbf9721d 100644 --- a/tests/test_remote_config.py +++ b/tests/test_remote_config.py @@ -830,7 +830,7 @@ def __init__(self, data, status, recorder, etag=ETAG): self._etag = etag def send(self, request, **kwargs): - resp = super(MockAdapter, self).send(request, **kwargs) + resp = super().send(request, **kwargs) resp.headers = {'etag': self._etag} return resp @@ -865,7 +865,7 @@ async def test_rc_instance_get_server_template(self): template = await rc_instance.get_server_template() - assert template.parameters == dict(test_key="test_value") + assert template.parameters == {"test_key": 'test_value'} assert str(template.version) == 'test' assert str(template.etag) == 'etag' diff --git a/tests/test_sseclient.py b/tests/test_sseclient.py index 70edcf0d0..2c523e36f 100644 --- a/tests/test_sseclient.py +++ b/tests/test_sseclient.py @@ -25,10 +25,10 @@ class MockSSEClientAdapter(testutils.MockAdapter): def __init__(self, payload, recorder): - super(MockSSEClientAdapter, self).__init__(payload, 200, recorder) + super().__init__(payload, 200, recorder) def send(self, request, **kwargs): - resp = super(MockSSEClientAdapter, self).send(request, **kwargs) + resp = super().send(request, **kwargs) resp.url = request.url resp.status_code = self.status resp.raw = io.BytesIO(self.data.encode()) diff --git a/tests/test_storage.py b/tests/test_storage.py index e15c4e2ab..c874ef640 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -33,7 +33,7 @@ def test_invalid_config(): with pytest.raises(ValueError): storage.bucket() -@pytest.mark.parametrize('name', [None, '', 0, 1, True, False, list(), tuple(), dict()]) +@pytest.mark.parametrize('name', [None, '', 0, 1, True, False, [], tuple(), {}]) def test_invalid_name(name): with pytest.raises(ValueError): storage.bucket(name) diff --git a/tests/test_tenant_mgt.py b/tests/test_tenant_mgt.py index 018892e3a..900faa376 100644 --- a/tests/test_tenant_mgt.py +++ b/tests/test_tenant_mgt.py @@ -107,8 +107,8 @@ LIST_OIDC_PROVIDER_CONFIGS_RESPONSE = testutils.resource('list_oidc_provider_configs.json') LIST_SAML_PROVIDER_CONFIGS_RESPONSE = testutils.resource('list_saml_provider_configs.json') -INVALID_TENANT_IDS = [None, '', 0, 1, True, False, list(), tuple(), dict()] -INVALID_BOOLEANS = ['', 1, 0, list(), tuple(), dict()] +INVALID_TENANT_IDS = [None, '', 0, 1, True, False, [], tuple(), {}] +INVALID_BOOLEANS = ['', 1, 0, [], tuple(), {}] USER_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v1/projects/mock-project-id' PROVIDER_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v2/projects/mock-project-id' @@ -152,7 +152,7 @@ def _instrument_provider_mgt(client, status, payload): class TestTenant: - @pytest.mark.parametrize('data', [None, 'foo', 0, 1, True, False, list(), tuple(), dict()]) + @pytest.mark.parametrize('data', [None, 'foo', 0, 1, True, False, [], tuple(), {}]) def test_invalid_data(self, data): with pytest.raises(ValueError): tenant_mgt.Tenant(data) @@ -197,7 +197,7 @@ def test_get_tenant(self, tenant_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'GET' - assert req.url == '{0}/tenants/tenant-id'.format(TENANT_MGT_URL_PREFIX) + assert req.url == f'{TENANT_MGT_URL_PREFIX}/tenants/tenant-id' assert req.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' expected_metrics_header = _utils.get_metrics_header() + ' mock-cred-metric-tag' assert req.headers['x-goog-api-client'] == expected_metrics_header @@ -216,7 +216,7 @@ def test_tenant_not_found(self, tenant_mgt_app): class TestCreateTenant: - @pytest.mark.parametrize('display_name', [True, False, 1, 0, list(), tuple(), dict()]) + @pytest.mark.parametrize('display_name', [True, False, 1, 0, [], tuple(), {}]) def test_invalid_display_name_type(self, display_name, tenant_mgt_app): with pytest.raises(ValueError) as excinfo: tenant_mgt.create_tenant(display_name=display_name, app=tenant_mgt_app) @@ -290,7 +290,7 @@ def _assert_request(self, recorder, body): assert len(recorder) == 1 req = recorder[0] assert req.method == 'POST' - assert req.url == '{0}/tenants'.format(TENANT_MGT_URL_PREFIX) + assert req.url == f'{TENANT_MGT_URL_PREFIX}/tenants' assert req.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' expected_metrics_header = _utils.get_metrics_header() + ' mock-cred-metric-tag' assert req.headers['x-goog-api-client'] == expected_metrics_header @@ -306,7 +306,7 @@ def test_invalid_tenant_id(self, tenant_id, tenant_mgt_app): tenant_mgt.update_tenant(tenant_id, display_name='My Tenant', app=tenant_mgt_app) assert str(excinfo.value).startswith('Tenant ID must be a non-empty string') - @pytest.mark.parametrize('display_name', [True, False, 1, 0, list(), tuple(), dict()]) + @pytest.mark.parametrize('display_name', [True, False, 1, 0, [], tuple(), {}]) def test_invalid_display_name_type(self, display_name, tenant_mgt_app): with pytest.raises(ValueError) as excinfo: tenant_mgt.update_tenant('tenant-id', display_name=display_name, app=tenant_mgt_app) @@ -390,8 +390,7 @@ def _assert_request(self, recorder, body, mask): assert len(recorder) == 1 req = recorder[0] assert req.method == 'PATCH' - assert req.url == '{0}/tenants/tenant-id?updateMask={1}'.format( - TENANT_MGT_URL_PREFIX, ','.join(mask)) + assert req.url == f'{TENANT_MGT_URL_PREFIX}/tenants/tenant-id?updateMask={",".join(mask)}' assert req.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' expected_metrics_header = _utils.get_metrics_header() + ' mock-cred-metric-tag' assert req.headers['x-goog-api-client'] == expected_metrics_header @@ -414,7 +413,7 @@ def test_delete_tenant(self, tenant_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'DELETE' - assert req.url == '{0}/tenants/tenant-id'.format(TENANT_MGT_URL_PREFIX) + assert req.url == f'{TENANT_MGT_URL_PREFIX}/tenants/tenant-id' assert req.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' expected_metrics_header = _utils.get_metrics_header() + ' mock-cred-metric-tag' assert req.headers['x-goog-api-client'] == expected_metrics_header @@ -433,12 +432,12 @@ def test_tenant_not_found(self, tenant_mgt_app): class TestListTenants: - @pytest.mark.parametrize('arg', [None, 'foo', list(), dict(), 0, -1, 101, False]) + @pytest.mark.parametrize('arg', [None, 'foo', [], {}, 0, -1, 101, False]) def test_invalid_max_results(self, tenant_mgt_app, arg): with pytest.raises(ValueError): tenant_mgt.list_tenants(max_results=arg, app=tenant_mgt_app) - @pytest.mark.parametrize('arg', ['', list(), dict(), 0, -1, True, False]) + @pytest.mark.parametrize('arg', ['', [], {}, 0, -1, True, False]) def test_invalid_page_token(self, tenant_mgt_app, arg): with pytest.raises(ValueError): tenant_mgt.list_tenants(page_token=arg, app=tenant_mgt_app) @@ -450,7 +449,7 @@ def test_list_single_page(self, tenant_mgt_app): assert page.next_page_token == '' assert page.has_next_page is False assert page.get_next_page() is None - tenants = [tenant for tenant in page.iterate_all()] + tenants = list(page.iterate_all()) assert len(tenants) == 2 self._assert_request(recorder) @@ -480,7 +479,7 @@ def test_list_tenants_paged_iteration(self, tenant_mgt_app): iterator = page.iterate_all() for index in range(3): tenant = next(iterator) - assert tenant.tenant_id == 'tenant{0}'.format(index) + assert tenant.tenant_id == f'tenant{index}' self._assert_request(recorder) # Page 2 (also the last page) @@ -514,7 +513,7 @@ def test_list_tenants_stop_iteration(self, tenant_mgt_app): _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, LIST_TENANTS_RESPONSE) page = tenant_mgt.list_tenants(app=tenant_mgt_app) iterator = page.iterate_all() - tenants = [tenant for tenant in iterator] + tenants = list(iterator) assert len(tenants) == 2 with pytest.raises(StopIteration): @@ -526,7 +525,7 @@ def test_list_tenants_no_tenants_response(self, tenant_mgt_app): _instrument_tenant_mgt(tenant_mgt_app, 200, json.dumps(response)) page = tenant_mgt.list_tenants(app=tenant_mgt_app) assert len(page.tenants) == 0 - tenants = [tenant for tenant in page.iterate_all()] + tenants = list(page.iterate_all()) assert len(tenants) == 0 def test_list_tenants_with_max_results(self, tenant_mgt_app): @@ -551,7 +550,7 @@ def _assert_tenants_page(self, page): assert isinstance(page, tenant_mgt.ListTenantsPage) assert len(page.tenants) == 2 for idx, tenant in enumerate(page.tenants): - _assert_tenant(tenant, 'tenant{0}'.format(idx)) + _assert_tenant(tenant, f'tenant{idx}') def _assert_request(self, recorder, expected=None): if expected is None: @@ -671,8 +670,7 @@ def test_revoke_refresh_tokens(self, tenant_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'POST' - assert req.url == '{0}/tenants/tenant-id/accounts:update'.format( - USER_MGT_URL_PREFIX) + assert req.url == f'{USER_MGT_URL_PREFIX}/tenants/tenant-id/accounts:update' body = json.loads(req.body.decode()) assert body['localId'] == 'testuser' assert 'validSince' in body @@ -693,8 +691,9 @@ def test_list_users(self, tenant_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'GET' - assert req.url == '{0}/tenants/tenant-id/accounts:batchGet?maxResults=1000'.format( - USER_MGT_URL_PREFIX) + assert req.url == ( + f'{USER_MGT_URL_PREFIX}/tenants/tenant-id/accounts:batchGet?maxResults=1000' + ) def test_import_users(self, tenant_mgt_app): client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) @@ -765,8 +764,9 @@ def test_get_oidc_provider_config(self, tenant_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'GET' - assert req.url == '{0}/tenants/tenant-id/oauthIdpConfigs/oidc.provider'.format( - PROVIDER_MGT_URL_PREFIX) + assert req.url == ( + f'{PROVIDER_MGT_URL_PREFIX}/tenants/tenant-id/oauthIdpConfigs/oidc.provider' + ) def test_create_oidc_provider_config(self, tenant_mgt_app): client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) @@ -791,7 +791,7 @@ def test_update_oidc_provider_config(self, tenant_mgt_app): self._assert_oidc_provider_config(provider_config) mask = ['clientId', 'displayName', 'enabled', 'issuer'] - url = '/oauthIdpConfigs/oidc.provider?updateMask={0}'.format(','.join(mask)) + url = f'/oauthIdpConfigs/oidc.provider?updateMask={",".join(mask)}' self._assert_request( recorder, url, OIDC_PROVIDER_CONFIG_REQUEST, method='PATCH', prefix=PROVIDER_MGT_URL_PREFIX) @@ -805,8 +805,9 @@ def test_delete_oidc_provider_config(self, tenant_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'DELETE' - assert req.url == '{0}/tenants/tenant-id/oauthIdpConfigs/oidc.provider'.format( - PROVIDER_MGT_URL_PREFIX) + assert req.url == ( + f'{PROVIDER_MGT_URL_PREFIX}/tenants/tenant-id/oauthIdpConfigs/oidc.provider' + ) def test_list_oidc_provider_configs(self, tenant_mgt_app): client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) @@ -819,7 +820,7 @@ def test_list_oidc_provider_configs(self, tenant_mgt_app): assert len(page.provider_configs) == 2 for provider_config in page.provider_configs: self._assert_oidc_provider_config( - provider_config, want_id='oidc.provider{0}'.format(index)) + provider_config, want_id=f'oidc.provider{index}') index += 1 assert page.next_page_token == '' @@ -831,8 +832,9 @@ def test_list_oidc_provider_configs(self, tenant_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'GET' - assert req.url == '{0}{1}'.format( - PROVIDER_MGT_URL_PREFIX, '/tenants/tenant-id/oauthIdpConfigs?pageSize=100') + assert req.url == ( + f'{PROVIDER_MGT_URL_PREFIX}/tenants/tenant-id/oauthIdpConfigs?pageSize=100' + ) def test_get_saml_provider_config(self, tenant_mgt_app): client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) @@ -844,8 +846,9 @@ def test_get_saml_provider_config(self, tenant_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'GET' - assert req.url == '{0}/tenants/tenant-id/inboundSamlConfigs/saml.provider'.format( - PROVIDER_MGT_URL_PREFIX) + assert req.url == ( + f'{PROVIDER_MGT_URL_PREFIX}/tenants/tenant-id/inboundSamlConfigs/saml.provider' + ) def test_create_saml_provider_config(self, tenant_mgt_app): client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) @@ -877,7 +880,7 @@ def test_update_saml_provider_config(self, tenant_mgt_app): 'displayName', 'enabled', 'idpConfig.idpCertificates', 'idpConfig.idpEntityId', 'idpConfig.ssoUrl', 'spConfig.callbackUri', 'spConfig.spEntityId', ] - url = '/inboundSamlConfigs/saml.provider?updateMask={0}'.format(','.join(mask)) + url = f'/inboundSamlConfigs/saml.provider?updateMask={",".join(mask)}' self._assert_request( recorder, url, SAML_PROVIDER_CONFIG_REQUEST, method='PATCH', prefix=PROVIDER_MGT_URL_PREFIX) @@ -891,8 +894,9 @@ def test_delete_saml_provider_config(self, tenant_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'DELETE' - assert req.url == '{0}/tenants/tenant-id/inboundSamlConfigs/saml.provider'.format( - PROVIDER_MGT_URL_PREFIX) + assert req.url == ( + f'{PROVIDER_MGT_URL_PREFIX}/tenants/tenant-id/inboundSamlConfigs/saml.provider' + ) def test_list_saml_provider_configs(self, tenant_mgt_app): client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) @@ -905,7 +909,7 @@ def test_list_saml_provider_configs(self, tenant_mgt_app): assert len(page.provider_configs) == 2 for provider_config in page.provider_configs: self._assert_saml_provider_config( - provider_config, want_id='saml.provider{0}'.format(index)) + provider_config, want_id=f'saml.provider{index}') index += 1 assert page.next_page_token == '' @@ -917,8 +921,9 @@ def test_list_saml_provider_configs(self, tenant_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'GET' - assert req.url == '{0}{1}'.format( - PROVIDER_MGT_URL_PREFIX, '/tenants/tenant-id/inboundSamlConfigs?pageSize=100') + assert req.url == ( + f'{PROVIDER_MGT_URL_PREFIX}/tenants/tenant-id/inboundSamlConfigs?pageSize=100' + ) def test_tenant_not_found(self, tenant_mgt_app): client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) @@ -937,7 +942,7 @@ def _assert_request( assert len(recorder) == 1 req = recorder[0] assert req.method == method - assert req.url == '{0}/tenants/tenant-id{1}'.format(prefix, want_url) + assert req.url == f'{prefix}/tenants/tenant-id{want_url}' assert req.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' expected_metrics_header = _utils.get_metrics_header() + ' mock-cred-metric-tag' assert req.headers['x-goog-api-client'] == expected_metrics_header diff --git a/tests/test_token_gen.py b/tests/test_token_gen.py index fe0b28dbe..384bc22c3 100644 --- a/tests/test_token_gen.py +++ b/tests/test_token_gen.py @@ -48,8 +48,8 @@ MOCK_SERVICE_ACCOUNT_EMAIL = MOCK_CREDENTIAL.service_account_email MOCK_REQUEST = testutils.MockRequest(200, MOCK_PUBLIC_CERTS) -INVALID_STRINGS = [None, '', 0, 1, True, False, list(), tuple(), dict()] -INVALID_BOOLS = [None, '', 'foo', 0, 1, list(), tuple(), dict()] +INVALID_STRINGS = [None, '', 0, 1, True, False, [], tuple(), {}] +INVALID_BOOLS = [None, '', 'foo', 0, 1, [], tuple(), {}] INVALID_JWT_ARGS = { 'NoneToken': None, 'EmptyToken': '', @@ -63,7 +63,7 @@ ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v1' EMULATOR_HOST_ENV_VAR = 'FIREBASE_AUTH_EMULATOR_HOST' AUTH_EMULATOR_HOST = 'localhost:9099' -EMULATED_ID_TOOLKIT_URL = 'http://{}/identitytoolkit.googleapis.com/v1'.format(AUTH_EMULATOR_HOST) +EMULATED_ID_TOOLKIT_URL = f'http://{AUTH_EMULATOR_HOST}/identitytoolkit.googleapis.com/v1' TOKEN_MGT_URLS = { 'ID_TOOLKIT': ID_TOOLKIT_URL, } @@ -136,8 +136,9 @@ def _get_session_cookie( payload_overrides=None, header_overrides=None, current_time=MOCK_CURRENT_TIME): payload_overrides = payload_overrides or {} if 'iss' not in payload_overrides: - payload_overrides['iss'] = 'https://session.firebase.google.com/{0}'.format( - MOCK_CREDENTIAL.project_id) + payload_overrides['iss'] = ( + f'https://session.firebase.google.com/{MOCK_CREDENTIAL.project_id}' + ) return _get_id_token(payload_overrides, header_overrides, current_time=current_time) def _instrument_user_manager(app, status, payload): @@ -282,7 +283,7 @@ def test_sign_with_iam(self): testutils.MockCredential(), name='iam-signer-app', options=options) try: signature = base64.b64encode(b'test').decode() - iam_resp = '{{"signedBlob": "{0}"}}'.format(signature) + iam_resp = json.dumps({'signedBlob': signature}) _overwrite_iam_request(app, testutils.MockRequest(200, iam_resp)) custom_token = auth.create_custom_token(MOCK_UID, app=app).decode() assert custom_token.endswith('.' + signature.rstrip('=')) @@ -319,8 +320,7 @@ def test_sign_with_discovered_service_account(self): # Now invoke the IAM signer. signature = base64.b64encode(b'test').decode() - request.response = testutils.MockResponse( - 200, '{{"signedBlob": "{0}"}}'.format(signature)) + request.response = testutils.MockResponse(200, json.dumps({'signedBlob': signature})) custom_token = auth.create_custom_token(MOCK_UID, app=app).decode() assert custom_token.endswith('.' + signature.rstrip('=')) self._verify_signer(custom_token, 'discovered-service-account') @@ -354,13 +354,13 @@ def _verify_signer(self, token, signer): class TestCreateSessionCookie: - @pytest.mark.parametrize('id_token', [None, '', 0, 1, True, False, list(), dict(), tuple()]) + @pytest.mark.parametrize('id_token', [None, '', 0, 1, True, False, [], {}, tuple()]) def test_invalid_id_token(self, user_mgt_app, id_token): with pytest.raises(ValueError): auth.create_session_cookie(id_token, expires_in=3600, app=user_mgt_app) @pytest.mark.parametrize('expires_in', [ - None, '', True, False, list(), dict(), tuple(), + None, '', True, False, [], {}, tuple(), _token_gen.MIN_SESSION_COOKIE_DURATION_SECONDS - 1, _token_gen.MAX_SESSION_COOKIE_DURATION_SECONDS + 1, ]) diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index 34b698be4..2c747ee5e 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -32,10 +32,10 @@ from tests import testutils -INVALID_STRINGS = [None, '', 0, 1, True, False, list(), tuple(), dict()] -INVALID_DICTS = [None, 'foo', 0, 1, True, False, list(), tuple()] -INVALID_INTS = [None, 'foo', '1', -1, 1.1, True, False, list(), tuple(), dict()] -INVALID_TIMESTAMPS = ['foo', '1', 0, -1, 1.1, True, False, list(), tuple(), dict()] +INVALID_STRINGS = [None, '', 0, 1, True, False, [], tuple(), {}] +INVALID_DICTS = [None, 'foo', 0, 1, True, False, [], tuple()] +INVALID_INTS = [None, 'foo', '1', -1, 1.1, True, False, [], tuple(), {}] +INVALID_TIMESTAMPS = ['foo', '1', 0, -1, 1.1, True, False, [], tuple(), {}] MOCK_GET_USER_RESPONSE = testutils.resource('get_user.json') MOCK_LIST_USERS_RESPONSE = testutils.resource('list_users.json') @@ -56,7 +56,7 @@ ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v1' EMULATOR_HOST_ENV_VAR = 'FIREBASE_AUTH_EMULATOR_HOST' AUTH_EMULATOR_HOST = 'localhost:9099' -EMULATED_ID_TOOLKIT_URL = 'http://{}/identitytoolkit.googleapis.com/v1'.format(AUTH_EMULATOR_HOST) +EMULATED_ID_TOOLKIT_URL = f'http://{AUTH_EMULATOR_HOST}/identitytoolkit.googleapis.com/v1' URL_PROJECT_SUFFIX = '/projects/mock-project-id' USER_MGT_URLS = { 'ID_TOOLKIT': ID_TOOLKIT_URL, @@ -135,7 +135,7 @@ def _check_request(recorder, want_url, want_body=None, want_timeout=None): assert len(recorder) == 1 req = recorder[0] assert req.method == 'POST' - assert req.url == '{0}{1}'.format(USER_MGT_URLS['PREFIX'], want_url) + assert req.url == f'{USER_MGT_URLS["PREFIX"]}{want_url}' expected_metrics_header = [ _utils.get_metrics_header(), _utils.get_metrics_header() + ' mock-cred-metric-tag' @@ -538,7 +538,7 @@ def test_user_already_exists(self, user_mgt_app, error_code): with pytest.raises(exc_type) as excinfo: auth.create_user(app=user_mgt_app) assert isinstance(excinfo.value, exceptions.AlreadyExistsError) - assert str(excinfo.value) == '{0} ({1}).'.format(exc_type.default_message, error_code) + assert str(excinfo.value) == f'{exc_type.default_message} ({error_code}).' assert excinfo.value.http_response is not None assert excinfo.value.cause is not None @@ -704,15 +704,14 @@ def test_single_reserved_claim(self, user_mgt_app, key): claims = {key : 'value'} with pytest.raises(ValueError) as excinfo: auth.set_custom_user_claims('user', claims, app=user_mgt_app) - assert str(excinfo.value) == 'Claim "{0}" is reserved, and must not be set.'.format(key) + assert str(excinfo.value) == f'Claim "{key}" is reserved, and must not be set.' def test_multiple_reserved_claims(self, user_mgt_app): claims = {key : 'value' for key in _auth_utils.RESERVED_CLAIMS} with pytest.raises(ValueError) as excinfo: auth.set_custom_user_claims('user', claims, app=user_mgt_app) joined = ', '.join(sorted(claims.keys())) - assert str(excinfo.value) == ('Claims "{0}" are reserved, and must not be ' - 'set.'.format(joined)) + assert str(excinfo.value) == f'Claims "{joined}" are reserved, and must not be set.' def test_large_claims_payload(self, user_mgt_app): claims = {'key' : 'A'*1000} @@ -830,12 +829,12 @@ def test_success(self, user_mgt_app): class TestListUsers: - @pytest.mark.parametrize('arg', [None, 'foo', list(), dict(), 0, -1, 1001, False]) + @pytest.mark.parametrize('arg', [None, 'foo', [], {}, 0, -1, 1001, False]) def test_invalid_max_results(self, user_mgt_app, arg): with pytest.raises(ValueError): auth.list_users(max_results=arg, app=user_mgt_app) - @pytest.mark.parametrize('arg', ['', list(), dict(), 0, -1, 1001, False]) + @pytest.mark.parametrize('arg', ['', [], {}, 0, -1, 1001, False]) def test_invalid_page_token(self, user_mgt_app, arg): with pytest.raises(ValueError): auth.list_users(page_token=arg, app=user_mgt_app) @@ -887,7 +886,7 @@ def test_list_users_paged_iteration(self, user_mgt_app): iterator = page.iterate_all() for index in range(3): user = next(iterator) - assert user.uid == 'user{0}'.format(index+1) + assert user.uid == f'user{index+1}' assert len(recorder) == 1 self._check_rpc_calls(recorder) @@ -912,7 +911,7 @@ def test_list_users_iterator_state(self, user_mgt_app): iterator = page.iterate_all() for user in iterator: index += 1 - assert user.uid == 'user{0}'.format(index) + assert user.uid == f'user{index}' if index == 2: break @@ -986,7 +985,7 @@ def _check_page(self, page): assert len(page.users) == 2 for user in page.users: assert isinstance(user, auth.ExportedUserRecord) - _check_user_record(user, 'testuser{0}'.format(index)) + _check_user_record(user, f'testuser{index}') assert user.password_hash == 'passwordHash' assert user.password_salt == 'passwordSalt' index += 1 @@ -1061,8 +1060,8 @@ class TestImportUserRecord: [{'email': arg} for arg in INVALID_STRINGS[1:] + ['not-an-email']] + [{'photo_url': arg} for arg in INVALID_STRINGS[1:] + ['not-a-url']] + [{'phone_number': arg} for arg in INVALID_STRINGS[1:] + ['not-a-phone']] + - [{'password_hash': arg} for arg in INVALID_STRINGS[1:] + [u'test']] + - [{'password_salt': arg} for arg in INVALID_STRINGS[1:] + [u'test']] + + [{'password_hash': arg} for arg in INVALID_STRINGS[1:] + ['test']] + + [{'password_salt': arg} for arg in INVALID_STRINGS[1:] + ['test']] + [{'custom_claims': arg} for arg in INVALID_DICTS[1:] + ['"json"', {'key': 'a'*1000}]] + [{'provider_data': arg} for arg in ['foo', 1, True]] ) @@ -1245,13 +1244,13 @@ def test_invalid_standard_scrypt(self, arg): class TestImportUsers: - @pytest.mark.parametrize('arg', [None, list(), tuple(), dict(), 0, 1, 'foo']) + @pytest.mark.parametrize('arg', [None, [], tuple(), {}, 0, 1, 'foo']) def test_invalid_users(self, user_mgt_app, arg): with pytest.raises(Exception): auth.import_users(arg, app=user_mgt_app) def test_too_many_users(self, user_mgt_app): - users = [auth.ImportUserRecord(uid='test{0}'.format(i)) for i in range(1001)] + users = [auth.ImportUserRecord(uid=f'test{i}') for i in range(1001)] with pytest.raises(ValueError): auth.import_users(users, app=user_mgt_app) @@ -1384,7 +1383,7 @@ def test_valid_data(self): {'android_install_app':'nonboolean'}, {'dynamic_link_domain': False}, {'ios_bundle_id':11}, - {'android_package_name':dict()}, + {'android_package_name':{}}, {'android_minimum_version':tuple()}, {'android_minimum_version':'7'}, {'android_install_app': True}]) diff --git a/tests/testutils.py b/tests/testutils.py index 62f7bd9b5..598a929b4 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -33,7 +33,7 @@ def resource_filename(filename): def resource(filename): """Returns the contents of a test resource.""" - with open(resource_filename(filename), 'r') as file_obj: + with open(resource_filename(filename), 'r', encoding='utf-8') as file_obj: return file_obj.read() @@ -183,7 +183,7 @@ def send(self, request, **kwargs): # pylint: disable=arguments-differ class MockAdapter(MockMultiRequestAdapter): """A mock HTTP adapter for the Python requests module.""" def __init__(self, data, status, recorder): - super(MockAdapter, self).__init__([data], [status], recorder) + super().__init__([data], [status], recorder) @property def status(self): From dfaceecf27a2a7402c6589799e595e1f92803966 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 15 Jul 2025 16:22:36 +0000 Subject: [PATCH 209/226] chore(deps): bump astroid from 3.3.10 to 3.3.11 (#901) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index ff15072a6..c68d71a0f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -astroid == 3.3.10 +astroid == 3.3.11 pylint == 3.3.7 pytest >= 8.2.2 pytest-cov >= 2.4.0 From 2c8a34a7362246bb0f129f253816a35a2c6f9af4 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Thu, 17 Jul 2025 10:23:24 -0400 Subject: [PATCH 210/226] [chore] Release 7.0.0 (#902) --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index 2ee3bbd62..6a05c663f 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '6.9.0' +__version__ = '7.0.0' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 4fa29189d7124c7912826d33136413a556e7d6fd Mon Sep 17 00:00:00 2001 From: Huw Martin Date: Wed, 23 Jul 2025 18:02:51 +0100 Subject: [PATCH 211/226] feat(auth) Update `ActionCodeSettings` to support `link_domain` and deprecate `dynamic_link_domain` (#884) * Add link_domain to ActionCodeSettings; update encode_action_code_settings to handle link_domain * Add handling for InvalidHostingLinkDomainError * Add deprecation warning for dynamic_link_domain * Update error message for InvalidHostingLinkDomainError * Fix lint * Add type hints to ActionCodeSettings * Fix f-string lint --- firebase_admin/_auth_utils.py | 12 ++++++++++++ firebase_admin/_user_mgt.py | 28 ++++++++++++++++++++++++++-- firebase_admin/auth.py | 2 ++ tests/test_user_mgt.py | 25 +++++++++++++++++++++++-- 4 files changed, 63 insertions(+), 4 deletions(-) diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index 60d411822..a514442c4 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -324,6 +324,17 @@ def __init__(self, message, cause, http_response): exceptions.InvalidArgumentError.__init__(self, message, cause, http_response) +class InvalidHostingLinkDomainError(exceptions.InvalidArgumentError): + """The provided hosting link domain is not configured in Firebase Hosting + or is not owned by the current project.""" + + default_message = ('The provided hosting link domain is not configured in Firebase ' + 'Hosting or is not owned by the current project') + + def __init__(self, message, cause, http_response): + exceptions.InvalidArgumentError.__init__(self, message, cause, http_response) + + class InvalidIdTokenError(exceptions.InvalidArgumentError): """The provided ID token is not a valid Firebase ID token.""" @@ -423,6 +434,7 @@ def __init__(self, message, cause=None, http_response=None): 'EMAIL_NOT_FOUND': EmailNotFoundError, 'INSUFFICIENT_PERMISSION': InsufficientPermissionError, 'INVALID_DYNAMIC_LINK_DOMAIN': InvalidDynamicLinkDomainError, + 'INVALID_HOSTING_LINK_DOMAIN': InvalidHostingLinkDomainError, 'INVALID_ID_TOKEN': InvalidIdTokenError, 'PHONE_NUMBER_EXISTS': PhoneNumberAlreadyExistsError, 'TENANT_NOT_FOUND': TenantNotFoundError, diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index 9a75b7a2e..e7825499c 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -17,7 +17,9 @@ import base64 from collections import defaultdict import json +from typing import Optional from urllib import parse +import warnings import requests @@ -489,8 +491,22 @@ class ActionCodeSettings: Used when invoking the email action link generation APIs. """ - def __init__(self, url, handle_code_in_app=None, dynamic_link_domain=None, ios_bundle_id=None, - android_package_name=None, android_install_app=None, android_minimum_version=None): + def __init__( + self, + url: str, + handle_code_in_app: Optional[bool] = None, + dynamic_link_domain: Optional[str] = None, + ios_bundle_id: Optional[str] = None, + android_package_name: Optional[str] = None, + android_install_app: Optional[str] = None, + android_minimum_version: Optional[str] = None, + link_domain: Optional[str] = None, + ): + if dynamic_link_domain is not None: + warnings.warn( + 'dynamic_link_domain is deprecated, use link_domain instead', + DeprecationWarning + ) self.url = url self.handle_code_in_app = handle_code_in_app self.dynamic_link_domain = dynamic_link_domain @@ -498,6 +514,7 @@ def __init__(self, url, handle_code_in_app=None, dynamic_link_domain=None, ios_b self.android_package_name = android_package_name self.android_install_app = android_install_app self.android_minimum_version = android_minimum_version + self.link_domain = link_domain def encode_action_code_settings(settings): @@ -535,6 +552,13 @@ def encode_action_code_settings(settings): f'Invalid value provided for dynamic_link_domain: {settings.dynamic_link_domain}') parameters['dynamicLinkDomain'] = settings.dynamic_link_domain + # link_domain + if settings.link_domain is not None: + if not isinstance(settings.link_domain, str): + raise ValueError( + f'Invalid value provided for link_domain: {settings.link_domain}') + parameters['linkDomain'] = settings.link_domain + # ios_bundle_id if settings.ios_bundle_id is not None: if not isinstance(settings.ios_bundle_id, str): diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index ced143112..cb63ab7f0 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -49,6 +49,7 @@ 'ImportUserRecord', 'InsufficientPermissionError', 'InvalidDynamicLinkDomainError', + 'InvalidHostingLinkDomainError', 'InvalidIdTokenError', 'InvalidSessionCookieError', 'ListProviderConfigsPage', @@ -125,6 +126,7 @@ ImportUserRecord = _user_import.ImportUserRecord InsufficientPermissionError = _auth_utils.InsufficientPermissionError InvalidDynamicLinkDomainError = _auth_utils.InvalidDynamicLinkDomainError +InvalidHostingLinkDomainError = _auth_utils.InvalidHostingLinkDomainError InvalidIdTokenError = _auth_utils.InvalidIdTokenError InvalidSessionCookieError = _token_gen.InvalidSessionCookieError ListProviderConfigsPage = _auth_providers.ListProviderConfigsPage diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index 2c747ee5e..4623f5e54 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -43,7 +43,8 @@ MOCK_ACTION_CODE_DATA = { 'url': 'http://localhost', 'handle_code_in_app': True, - 'dynamic_link_domain': 'http://testly', + 'dynamic_link_domain': 'http://dynamic-link-domain', + 'link_domain': 'http://link-domain', 'ios_bundle_id': 'test.bundle', 'android_package_name': 'test.bundle', 'android_minimum_version': '7', @@ -1363,7 +1364,8 @@ def test_valid_data(self): data = { 'url': 'http://localhost', 'handle_code_in_app': True, - 'dynamic_link_domain': 'http://testly', + 'dynamic_link_domain': 'http://dynamic-link-domain', + 'link_domain': 'http://link-domain', 'ios_bundle_id': 'test.bundle', 'android_package_name': 'test.bundle', 'android_minimum_version': '7', @@ -1374,6 +1376,7 @@ def test_valid_data(self): assert parameters['continueUrl'] == data['url'] assert parameters['canHandleCodeInApp'] == data['handle_code_in_app'] assert parameters['dynamicLinkDomain'] == data['dynamic_link_domain'] + assert parameters['linkDomain'] == data['link_domain'] assert parameters['iOSBundleId'] == data['ios_bundle_id'] assert parameters['androidPackageName'] == data['android_package_name'] assert parameters['androidMinimumVersion'] == data['android_minimum_version'] @@ -1496,6 +1499,23 @@ def test_invalid_dynamic_link(self, user_mgt_app, func): assert excinfo.value.http_response is not None assert excinfo.value.cause is not None + @pytest.mark.parametrize('func', [ + auth.generate_sign_in_with_email_link, + auth.generate_email_verification_link, + auth.generate_password_reset_link, + ]) + def test_invalid_hosting_link(self, user_mgt_app, func): + resp = '{"error":{"message": "INVALID_HOSTING_LINK_DOMAIN: Because of this reason."}}' + _instrument_user_manager(user_mgt_app, 500, resp) + with pytest.raises(auth.InvalidHostingLinkDomainError) as excinfo: + func('test@test.com', MOCK_ACTION_CODE_SETTINGS, app=user_mgt_app) + assert isinstance(excinfo.value, exceptions.InvalidArgumentError) + assert str(excinfo.value) == ('The provided hosting link domain is not configured in ' + 'Firebase Hosting or is not owned by the current project ' + '(INVALID_HOSTING_LINK_DOMAIN). Because of this reason.') + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None + @pytest.mark.parametrize('func', [ auth.generate_sign_in_with_email_link, auth.generate_email_verification_link, @@ -1534,6 +1554,7 @@ def _validate_request(self, request, settings=None): assert request['continueUrl'] == settings.url assert request['canHandleCodeInApp'] == settings.handle_code_in_app assert request['dynamicLinkDomain'] == settings.dynamic_link_domain + assert request['linkDomain'] == settings.link_domain assert request['iOSBundleId'] == settings.ios_bundle_id assert request['androidPackageName'] == settings.android_package_name assert request['androidMinimumVersion'] == settings.android_minimum_version From 6555a84baaa27a4cfb0d76d2ce6c09c899573b3c Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Thu, 31 Jul 2025 16:31:52 -0400 Subject: [PATCH 212/226] [chore] Release 7.1.0 (#903) --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index 6a05c663f..9fb40b11c 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '7.0.0' +__version__ = '7.1.0' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 5e752502fdaede3246e4224684dba6ea089a7726 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Wed, 20 Aug 2025 13:23:32 -0400 Subject: [PATCH 213/226] chore: Added an `AGENTS.md` file to instruct AI agents how to interact with this repository (#906) --- AGENTS.md | 170 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 170 insertions(+) create mode 100644 AGENTS.md diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 000000000..28bba4b55 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,170 @@ +# Firebase Admin Python SDK - Agent Guide + +This document provides AI agents with a comprehensive guide to the conventions, design patterns, and architectural nuances of the Firebase Admin Python SDK. Adhering to this guide ensures that all contributions are idiomatic and align with the existing codebase. + +## 1. High-Level Overview + +The Firebase Admin Python SDK provides a Pythonic interface to Firebase services. Its design emphasizes thread-safety, a consistent and predictable API, and seamless integration with Google Cloud Platform services. + +## 2. Directory Structure + +- `firebase_admin/`: The main package directory. + - `__init__.py`: The primary entry point. It exposes the `initialize_app()` function and manages the lifecycle of `App` instances. + - `exceptions.py`: Defines the custom exception hierarchy for the SDK. + - `_http_client.py`: Contains the centralized `JsonHttpClient` and `HttpxAsyncClient` for all outgoing HTTP requests. + - Service modules (e.g., `auth.py`, `db.py`, `messaging.py`): Each module contains the logic for a specific Firebase service. +- `tests/`: Contains all unit tests. + - `tests/resources/`: Contains mock data, keys, and other test assets. +- `integration/`: Contains all integration tests.* + - These integration tests require a real Firebase project to run against. + - `integration/conftest.py`: Contains provides configurations for these integration tests including how credentials are provided through pytest. +- `snippets/`: Contains code snippets used in documentation. +- `setup.py`: Package definition, including the required environment dependencies. +- `requirements.txt`: A list of all development dependencies. +- `.pylintrc`: Configuration file for the `pylint` linter. +- `CONTRIBUTING.md`: General guidelines for human contributors. Your instructions here supersede this file. + +## 3. Core Design Patterns + +### Initialization + +The SDK is initialized by calling the `initialize_app(credential, options)` function. This creates a default `App` instance that SDK modules use implicitly. For multi-project use cases, named apps can be created by providing a `name` argument: `initialize_app(credential, options, name='my_app')`. + +### Service Clients + +Service clients are accessed via module-level factory functions. These functions automatically use the default app unless a specific `App` object is provided via the `app` parameter. The clients are created lazily and cached for the lifetime of the application. + +- **Direct Action Modules (auth, db)**: Some modules provide functions that perform actions directly. +- **Client Factory Modules (firestore, storage)**: Other modules have a function (e.g., client() or bucket()) that returns a client object, which you then use for operations. + + +### Error Handling + +- All SDK-specific exceptions inherit from `firebase_admin.exceptions.FirebaseError`. +- Specific error conditions are represented by subclasses, such as `firebase_admin.exceptions.InvalidArgumentError` and `firebase_admin.exceptions.UnauthenticatedError`. +- Each service may additionaly define exceptions under these subclasses and apply them by passing a handle function to `_utils.handle_platform_error_from_requests()` or `_utils.handle_platform_error_from_httpx()`. Each services error handling patterns should be considered before making changes. + +### HTTP Communication + +- All synchronous HTTP requests are made through the `JsonHttpClient` class in `firebase_admin._http_client`. +- All asynchronous HTTP requests are made through the `HttpxAsyncClient` class in `firebase_admin._http_client`. +- These clients handle authentication and retries for all API calls. + +### Asynchronous Operations + +Asynchronous operations are supported using Python's `asyncio` library. Asynchronous methods are typically named with an `_async` suffix (e.g., `messaging.send_each_async()`). + +## 4. Coding Style and Naming Conventions + +- **Formatting:** This project uses **pylint** to enforce code style and detect potential errors. Before submitting code, you **must** run the linter and ensure your changes do not introduce any new errors. Run the linter from the repository's root directory with the following command: + ```bash + ./lint.sh all # Lint all source files + ``` + or + ```bash + ./lint.sh # Lint locally modified source files + ``` +- **Naming:** + - Classes: `PascalCase` (e.g., `FirebaseError`). + - Methods and Functions: `snake_case` (e.g., `initialize_app`). + - Private Members: An underscore prefix (e.g., `_http_client`). + - Constants: `UPPER_SNAKE_CASE` (e.g., `INVALID_ARGUMENT`). + +## 5. Testing Philosophy + +- **Unit Tests:** + - Located in the `tests/` directory. + - Test files follow the `test_*.py` naming convention. + - Unit tests can be run using the following command: + ```bash + pytest + ``` +- **Integration Tests:** + - Located in the `integration/` directory. + - These tests make real API calls to Firebase services and require a configured project. Running these tests be should be ignored without a project and instead rely on the repository's GitHub Actions. + +## 6. Dependency Management + +- **Manager:** `pip` +- **Manifest:** `requirements.txt` +- **Command:** `pip install -r requirements.txt` + +## 7. Critical Developer Journeys + +### Journey 1: How to Add a New API Method + +1. **Define Public Method:** Add the new method or change to the appropriate service client files (e.g., `firebase_admin/auth.py`). +2. **Expose the public API method** by updating the `__all__` constant with the name of the new method. +3. **Internal Logic:** Implement the core logic within the service package. +4. **HTTP Client:** Use the HTTP client (`JsonHttpClient` or `HttpxAsyncClient`) to make the API call. +5. **Error Handling:** Catching exceptions from the HTTP client and raise the appropriate `FirebaseError` subclass using the services error handling logic +6. **Testing:** + - Add unit tests in the corresponding `test_*.py` file (e.g., `tests/test_user_mgt.py`). + - Add integration tests in the `integration/` directory if applicable. +7. **Snippets:** (Optional) Add or update code snippets in the `snippets/` directory. + +### Journey 2: How to Deprecate a Field/Method in an Existing API + +1. **Add Deprecation Note:** Locate where the deprecated object is defined and add a deprecation note to its docstring (e.g. `X is deprecated. Use Y instead.`). +2. **Add Deprecation Warning:** In the same location where the deprecated object is defined, add a deprecation warning to the code. (e.g. `warnings.warn('X is deprecated. Use Y instead.', DeprecationWarning)`) + +## 8. Critical Do's and Don'ts + +- **DO:** Use the centralized `JsonHttpClient` or `HttpxAsyncClient` for all HTTP requests. +- **DO:** Follow the established error handling patterns by using `FirebaseError` and its subclasses. +- **DON'T:** Expose implementation details from private (underscored) modules or functions in the public API. +- **DON'T:** Introduce new third-party dependencies without updating `requirements.txt` and `setup.py`. + +## 9. Branch Creation +- When creating a new barnch use the format `agentName-short-description`. + * Example: `jules-auth-token-parsing` + * Example: `gemini-add-storage-file-signer` + +## 10. Commit and Pull Request Generation + +After implementing and testing a change, you may create a commit and pull request which must follow the following these rules: + +### Commit and Pull Request Title Format: +Use the [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/) specification: `type(scope): subject` +- `type` should be one of `feat`, `fix` or `chore`. +- `scope` should be the service package changed (e.g., `auth`, `rtdb`, `deps`). + - **Note**: Some services use specific abbreviations. Use the abbreviation if one exists. Common abbreviations include: + - `messaging` -> `fcm` + - `dataconnect` -> `fdc` + - `database` -> `rtdb` + - `appcheck` -> `fac` +- `subject` should be a brief summary of the change depending on the action: + - For pull requests this should focus on the larger goal the included commits achieve. + - Example: `fix(auth): Resolved issue with custom token verification` + - For commits this should focus on the specific changes made in that commit. + - Example: `fix(auth): Added a new token verification check` + +### Commit Body: +This should be a brief explanation of code changes. + +Example: +``` +feat(fcm): Added `send_each_for_multicast` support for multicast messages + +Added a new `send_each_for_multicast` method to the messaging client. This method wraps the `send_each` method and sends the same message to each token. +``` + +### Pull Request Body: +- A brief explanation of the problem and the solution. +- A summary of the testing strategy (e.g., "Added a new unit test to verify the fix."). +- A **Context Sources** section that lists the `id` and repository path of every `AGENTS.md` file you used. + +Example: +``` +feat(fcm): Added support for multicast messages + +This change introduces a new `send_each_for_multicast` method to the messaging client, allowing developers to send a single message to multiple tokens efficiently. + +Testing: Added unit tests in `tests/test_messaging.py` with mock requests and an integration test in `integration/test_messaging.py`. + +Context Sources Used: +- id: firebase-admin-python +``` + +## 11. Metadata +- id: firebase-admin-python \ No newline at end of file From 3d3ef0c9b5f99044c92072b241f470d2dc9fe9d2 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Mon, 8 Sep 2025 15:05:59 -0400 Subject: [PATCH 214/226] fix(auth): Fixed auth error code parsing (#908) --- firebase_admin/_auth_utils.py | 2 +- integration/test_auth.py | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index a514442c4..8f3c419a7 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -479,7 +479,7 @@ def _parse_error_body(response): separator = code.find(':') if separator != -1: custom_message = code[separator + 1:].strip() - code = code[:separator] + code = code[:separator].strip() return code, custom_message diff --git a/integration/test_auth.py b/integration/test_auth.py index 7f4725dfe..b36063d19 100644 --- a/integration/test_auth.py +++ b/integration/test_auth.py @@ -724,6 +724,19 @@ def test_email_sign_in_with_settings(new_user_email_unverified, api_key): assert id_token is not None and len(id_token) > 0 assert auth.get_user(new_user_email_unverified.uid).email_verified +def test_auth_error_parse(new_user_email_unverified): + action_code_settings = auth.ActionCodeSettings( + ACTION_LINK_CONTINUE_URL, handle_code_in_app=True, link_domain="cool.link") + with pytest.raises(auth.InvalidHostingLinkDomainError) as excinfo: + auth.generate_sign_in_with_email_link(new_user_email_unverified.email, + action_code_settings=action_code_settings) + assert str(excinfo.value) == ('The provided hosting link domain is not configured in Firebase ' + 'Hosting or is not owned by the current project ' + '(INVALID_HOSTING_LINK_DOMAIN). The provided hosting link ' + 'domain is not configured in Firebase Hosting or is not owned ' + 'by the current project. This cannot be a default hosting domain ' + '(web.app or firebaseapp.com).') + @pytest.fixture(scope='module') def oidc_provider(): From de713d21da83b1f50c24c5a23132ffc442700448 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Tue, 16 Sep 2025 09:33:19 -0400 Subject: [PATCH 215/226] chore: Removed invalid `asyncio_default_fixture_loop_scope` config (#912) --- setup.cfg | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 32e00676b..4c6cf8d8f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,4 +1,3 @@ [tool:pytest] testpaths = tests asyncio_default_test_loop_scope = class -asyncio_default_fixture_loop_scope = None From ee8fd701def6ae4252af5a846ea70c85be8d4cfe Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Wed, 17 Sep 2025 17:22:29 -0400 Subject: [PATCH 216/226] fix(functions): Refresh credentials before enqueueing first task (#907) * fix(functions): Refresh credentials before enqueueing task This change addresses an issue where enqueueing a task from a Cloud Function would fail with a InvalidArgumentError error. This was caused by uninitialized credentials being used to in the task payload. The fix explicitly refreshes the credential before accessing the credential, ensuring a valid token or service account email is used in the in the task payload. This also includes a correction for an f-string typo in the Authorization header construction. * fix(functions): Move credential refresh to functions service init * fix(functions): Moved credential refresh to run on task payload update with freshness guard --------- Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> --- firebase_admin/functions.py | 16 +++++++++-- tests/test_functions.py | 57 +++++++++++++++++++++++++++++++++++++ tests/testutils.py | 31 +++++++++++++++++++- 3 files changed, 101 insertions(+), 3 deletions(-) diff --git a/firebase_admin/functions.py b/firebase_admin/functions.py index 6db0fbb42..8e77d8560 100644 --- a/firebase_admin/functions.py +++ b/firebase_admin/functions.py @@ -22,7 +22,11 @@ from base64 import b64encode from typing import Any, Optional, Dict from dataclasses import dataclass + from google.auth.compute_engine import Credentials as ComputeEngineCredentials +from google.auth.credentials import TokenState +from google.auth.exceptions import RefreshError +from google.auth.transport import requests as google_auth_requests import requests import firebase_admin @@ -285,14 +289,22 @@ def _update_task_payload(self, task: Task, resource: Resource, extension_id: str # Get function url from task or generate from resources if not _Validators.is_non_empty_string(task.http_request['url']): task.http_request['url'] = self._get_url(resource, _FIREBASE_FUNCTION_URL_FORMAT) + + # Refresh the credential to ensure all attributes (e.g. service_account_email, id_token) + # are populated, preventing cold start errors. + if self._credential.token_state != TokenState.FRESH: + try: + self._credential.refresh(google_auth_requests.Request()) + except RefreshError as err: + raise ValueError(f'Initial task payload credential refresh failed: {err}') from err + # If extension id is provided, it emplies that it is being run from a deployed extension. # Meaning that it's credential should be a Compute Engine Credential. if _Validators.is_non_empty_string(extension_id) and \ isinstance(self._credential, ComputeEngineCredentials): - id_token = self._credential.token task.http_request['headers'] = \ - {**task.http_request['headers'], 'Authorization': f'Bearer ${id_token}'} + {**task.http_request['headers'], 'Authorization': f'Bearer {id_token}'} # Delete oidc token del task.http_request['oidc_token'] else: diff --git a/tests/test_functions.py b/tests/test_functions.py index 52e92c1b2..953563449 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -124,6 +124,10 @@ def test_task_enqueue(self): assert recorder[0].headers['x-goog-api-client'] == expected_metrics_header assert task_id == 'test-task-id' + task = json.loads(recorder[0].body.decode())['task'] + assert task['http_request']['oidc_token'] == {'service_account_email': 'mock-email'} + assert task['http_request']['headers'] == {'Content-Type': 'application/json'} + def test_task_enqueue_with_extension(self): resource_name = ( 'projects/test-project/locations/us-central1/queues/' @@ -142,6 +146,59 @@ def test_task_enqueue_with_extension(self): assert recorder[0].headers['x-goog-api-client'] == expected_metrics_header assert task_id == 'test-task-id' + task = json.loads(recorder[0].body.decode())['task'] + assert task['http_request']['oidc_token'] == {'service_account_email': 'mock-email'} + assert task['http_request']['headers'] == {'Content-Type': 'application/json'} + + def test_task_enqueue_compute_engine(self): + app = firebase_admin.initialize_app( + testutils.MockComputeEngineCredential(), + options={'projectId': 'test-project'}, + name='test-project-gce') + _, recorder = self._instrument_functions_service(app) + queue = functions.task_queue('test-function-name', app=app) + task_id = queue.enqueue(_DEFAULT_DATA) + assert len(recorder) == 1 + assert recorder[0].method == 'POST' + assert recorder[0].url == _DEFAULT_REQUEST_URL + assert recorder[0].headers['Content-Type'] == 'application/json' + assert recorder[0].headers['Authorization'] == 'Bearer mock-compute-engine-token' + expected_metrics_header = _utils.get_metrics_header() + ' mock-gce-cred-metric-tag' + assert recorder[0].headers['x-goog-api-client'] == expected_metrics_header + assert task_id == 'test-task-id' + + task = json.loads(recorder[0].body.decode())['task'] + assert task['http_request']['oidc_token'] == {'service_account_email': 'mock-gce-email'} + assert task['http_request']['headers'] == {'Content-Type': 'application/json'} + + def test_task_enqueue_with_extension_compute_engine(self): + resource_name = ( + 'projects/test-project/locations/us-central1/queues/' + 'ext-test-extension-id-test-function-name/tasks' + ) + extension_response = json.dumps({'name': resource_name + '/test-task-id'}) + app = firebase_admin.initialize_app( + testutils.MockComputeEngineCredential(), + options={'projectId': 'test-project'}, + name='test-project-gce-extensions') + _, recorder = self._instrument_functions_service(app, payload=extension_response) + queue = functions.task_queue('test-function-name', 'test-extension-id', app) + task_id = queue.enqueue(_DEFAULT_DATA) + assert len(recorder) == 1 + assert recorder[0].method == 'POST' + assert recorder[0].url == _CLOUD_TASKS_URL + resource_name + assert recorder[0].headers['Content-Type'] == 'application/json' + assert recorder[0].headers['Authorization'] == 'Bearer mock-compute-engine-token' + expected_metrics_header = _utils.get_metrics_header() + ' mock-gce-cred-metric-tag' + assert recorder[0].headers['x-goog-api-client'] == expected_metrics_header + assert task_id == 'test-task-id' + + task = json.loads(recorder[0].body.decode())['task'] + assert 'oidc_token' not in task['http_request'] + assert task['http_request']['headers'] == { + 'Content-Type': 'application/json', + 'Authorization': 'Bearer mock-compute-engine-token'} + def test_task_delete(self): _, recorder = self._instrument_functions_service() queue = functions.task_queue('test-function-name') diff --git a/tests/testutils.py b/tests/testutils.py index 598a929b4..7546595af 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -116,12 +116,25 @@ def __call__(self, *args, **kwargs): # pylint: disable=arguments-differ # pylint: disable=abstract-method class MockGoogleCredential(credentials.Credentials): """A mock Google authentication credential.""" + + def __init__(self): + super().__init__() + self.token = None + self._service_account_email = None + self._token_state = credentials.TokenState.INVALID + def refresh(self, request): self.token = 'mock-token' + self._service_account_email = 'mock-email' + self._token_state = credentials.TokenState.FRESH + + @property + def token_state(self): + return self._token_state @property def service_account_email(self): - return 'mock-email' + return self._service_account_email # Simulate x-goog-api-client modification in credential refresh def _metric_header_for_usage(self): @@ -139,8 +152,24 @@ def get_credential(self): class MockGoogleComputeEngineCredential(compute_engine.Credentials): """A mock Compute Engine credential""" + + def __init__(self): + super().__init__() + self.token = None + self._service_account_email = None + self._token_state = credentials.TokenState.INVALID + def refresh(self, request): self.token = 'mock-compute-engine-token' + self._service_account_email = 'mock-gce-email' + self._token_state = credentials.TokenState.FRESH + + @property + def token_state(self): + return self._token_state + + def _metric_header_for_usage(self): + return 'mock-gce-cred-metric-tag' class MockComputeEngineCredential(firebase_admin.credentials.Base): """A mock Firebase credential implementation.""" From f85a8de1b5d9a252971827d2a6c075d59d564004 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Mon, 29 Sep 2025 15:27:56 -0400 Subject: [PATCH 217/226] chore: Fix typo (#913) --- firebase_admin/credentials.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/credentials.py b/firebase_admin/credentials.py index 7117b71a9..0edbecaae 100644 --- a/firebase_admin/credentials.py +++ b/firebase_admin/credentials.py @@ -37,7 +37,7 @@ AccessTokenInfo = collections.namedtuple('AccessTokenInfo', ['access_token', 'expiry']) """Data included in an OAuth2 access token. -Contains the access token string and the expiry time. The expirty time is exposed as a +Contains the access token string and the expiry time. The expiry time is exposed as a ``datetime`` value. """ From fc6c8ee67ea29fe498e7cfca907a5c9fa41a3fed Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 25 Nov 2025 15:43:48 +0000 Subject: [PATCH 218/226] chore(deps): bump pylint from 3.3.7 to 3.3.9 (#917) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index c68d71a0f..3b96eea00 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ astroid == 3.3.11 -pylint == 3.3.7 +pylint == 3.3.9 pytest >= 8.2.2 pytest-cov >= 2.4.0 pytest-localserver >= 0.4.1 From 2305519d058afb5aaaa326e790cc52690ec596f6 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Wed, 3 Dec 2025 12:46:03 -0500 Subject: [PATCH 219/226] feat(functions): Enable Cloud Task Queue Emulator support (#920) * feat(functions): Enable Cloud Task Queue Emulator support * fix: lint * fix: Resolved issues from gemini review * chore: Added basic integration tests for task enqueue and delete * chore: Setup emulator testing for Functions integration tests * fix: Re-added accidentally removed lint * fix: integration test default apps * fix: lint --- .github/workflows/ci.yml | 22 ++++- CONTRIBUTING.md | 11 +++ firebase_admin/functions.py | 97 ++++++++++++++++--- integration/emulators/.gitignore | 69 +++++++++++++ integration/emulators/firebase.json | 29 ++++++ integration/emulators/functions/.gitignore | 6 ++ integration/emulators/functions/main.py | 7 ++ .../emulators/functions/requirements.txt | 1 + integration/test_functions.py | 52 ++++++++-- tests/test_functions.py | 91 +++++++++++++---- 10 files changed, 336 insertions(+), 49 deletions(-) create mode 100644 integration/emulators/.gitignore create mode 100644 integration/emulators/firebase.json create mode 100644 integration/emulators/functions/.gitignore create mode 100644 integration/emulators/functions/main.py create mode 100644 integration/emulators/functions/requirements.txt diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index bfd29e2cc..2ba09880b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,6 +12,17 @@ jobs: steps: - uses: actions/checkout@v4 + + - name: Set up Python 3.13 for emulator + uses: actions/setup-python@v5 + with: + python-version: '3.13' + - name: Setup functions emulator environment + run: | + python -m venv integration/emulators/functions/venv + source integration/emulators/functions/venv/bin/activate + pip install -r integration/emulators/functions/requirements.txt + deactivate - name: Set up Python ${{ matrix.python }} uses: actions/setup-python@v5 with: @@ -26,11 +37,12 @@ jobs: uses: actions/setup-node@v4 with: node-version: 20 - - name: Run integration tests against emulator - run: | - npm install -g firebase-tools - firebase emulators:exec --only database --project fake-project-id 'pytest integration/test_db.py' - + - name: Install firebase-tools + run: npm install -g firebase-tools + - name: Run Database emulator tests + run: firebase emulators:exec --only database --project fake-project-id 'pytest integration/test_db.py' + - name: Run Functions emulator tests + run: firebase emulators:exec --config integration/emulators/firebase.json --only tasks,functions --project fake-project-id 'CLOUD_TASKS_EMULATOR_HOST=localhost:9499 pytest integration/test_functions.py' lint: runs-on: ubuntu-latest steps: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 72933a24f..71da12dc6 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -252,6 +252,17 @@ to ensure that exported user records contain the password hashes of the user acc 3. Click **ADD ANOTHER ROLE** and choose **Firebase Authentication Admin**. 4. Click **SAVE**. +9. Enable Cloud Tasks: + 1. Search for and enable **Cloud Run**. + 2. Search for and enable **Cloud Tasks**. + 3. Go to [Google Cloud console | IAM & admin](https://console.cloud.google.com/iam-admin) + and make sure your Firebase project is selected. + 4. Ensure your service account has the following required roles: + * **Cloud Tasks Enqueuer** - `cloudtasks.taskEnqueuer` + * **Cloud Tasks Task Deleter** - `cloudtasks.taskDeleter` + * **Cloud Run Invoker** - `run.invoker` + * **Service Account User** - `iam.serviceAccountUser` + Now you can invoke the integration test suite as follows: diff --git a/firebase_admin/functions.py b/firebase_admin/functions.py index 8e77d8560..66ba700b3 100644 --- a/firebase_admin/functions.py +++ b/firebase_admin/functions.py @@ -18,6 +18,7 @@ from datetime import datetime, timedelta, timezone from urllib import parse import re +import os import json from base64 import b64encode from typing import Any, Optional, Dict @@ -49,6 +50,8 @@ 'https://cloudtasks.googleapis.com/v2/' + _CLOUD_TASKS_API_RESOURCE_PATH _FIREBASE_FUNCTION_URL_FORMAT = \ 'https://{location_id}-{project_id}.cloudfunctions.net/{resource_id}' +_EMULATOR_HOST_ENV_VAR = 'CLOUD_TASKS_EMULATOR_HOST' +_EMULATED_SERVICE_ACCOUNT_DEFAULT = 'emulated-service-acct@email.com' _FUNCTIONS_HEADERS = { 'X-GOOG-API-FORMAT-VERSION': '2', @@ -58,6 +61,17 @@ # Default canonical location ID of the task queue. _DEFAULT_LOCATION = 'us-central1' +def _get_emulator_host() -> Optional[str]: + emulator_host = os.environ.get(_EMULATOR_HOST_ENV_VAR) + if emulator_host: + if '//' in emulator_host: + raise ValueError( + f'Invalid {_EMULATOR_HOST_ENV_VAR}: "{emulator_host}". It must follow format ' + '"host:port".') + return emulator_host + return None + + def _get_functions_service(app) -> _FunctionsService: return _utils.get_app_service(app, _FUNCTIONS_ATTRIBUTE, _FunctionsService) @@ -103,13 +117,19 @@ def __init__(self, app: App): 'projectId option, or use service account credentials. Alternatively, set the ' 'GOOGLE_CLOUD_PROJECT environment variable.') - self._credential = app.credential.get_credential() + self._emulator_host = _get_emulator_host() + if self._emulator_host: + self._credential = _utils.EmulatorAdminCredentials() + else: + self._credential = app.credential.get_credential() + self._http_client = _http_client.JsonHttpClient(credential=self._credential) def task_queue(self, function_name: str, extension_id: Optional[str] = None) -> TaskQueue: """Creates a TaskQueue instance.""" return TaskQueue( - function_name, extension_id, self._project_id, self._credential, self._http_client) + function_name, extension_id, self._project_id, self._credential, self._http_client, + self._emulator_host) @classmethod def handle_functions_error(cls, error: Any): @@ -125,7 +145,8 @@ def __init__( extension_id: Optional[str], project_id, credential, - http_client + http_client, + emulator_host: Optional[str] = None ) -> None: # Validate function_name @@ -134,6 +155,7 @@ def __init__( self._project_id = project_id self._credential = credential self._http_client = http_client + self._emulator_host = emulator_host self._function_name = function_name self._extension_id = extension_id # Parse resources from function_name @@ -167,16 +189,26 @@ def enqueue(self, task_data: Any, opts: Optional[TaskOptions] = None) -> str: str: The ID of the task relative to this queue. """ task = self._validate_task_options(task_data, self._resource, opts) - service_url = self._get_url(self._resource, _CLOUD_TASKS_API_URL_FORMAT) + emulator_url = self._get_emulator_url(self._resource) + service_url = emulator_url or self._get_url(self._resource, _CLOUD_TASKS_API_URL_FORMAT) task_payload = self._update_task_payload(task, self._resource, self._extension_id) try: resp = self._http_client.body( 'post', url=service_url, headers=_FUNCTIONS_HEADERS, - json={'task': task_payload.__dict__} + json={'task': task_payload.to_api_dict()} ) - task_name = resp.get('name', None) + if self._is_emulated(): + # Emulator returns a response with format {task: {name: }} + # The task name also has an extra '/' at the start compared to prod + task_info = resp.get('task') or {} + task_name = task_info.get('name') + if task_name: + task_name = task_name[1:] + else: + # Production returns a response with format {name: } + task_name = resp.get('name') task_resource = \ self._parse_resource_name(task_name, f'queues/{self._resource.resource_id}/tasks') return task_resource.resource_id @@ -197,7 +229,11 @@ def delete(self, task_id: str) -> None: ValueError: If the input arguments are invalid. """ _Validators.check_non_empty_string('task_id', task_id) - service_url = self._get_url(self._resource, _CLOUD_TASKS_API_URL_FORMAT + f'/{task_id}') + emulator_url = self._get_emulator_url(self._resource) + if emulator_url: + service_url = emulator_url + f'/{task_id}' + else: + service_url = self._get_url(self._resource, _CLOUD_TASKS_API_URL_FORMAT + f'/{task_id}') try: self._http_client.body( 'delete', @@ -235,8 +271,8 @@ def _validate_task_options( """Validate and create a Task from optional ``TaskOptions``.""" task_http_request = { 'url': '', - 'oidc_token': { - 'service_account_email': '' + 'oidcToken': { + 'serviceAccountEmail': '' }, 'body': b64encode(json.dumps(data).encode()).decode(), 'headers': { @@ -250,7 +286,7 @@ def _validate_task_options( task.http_request['headers'] = {**task.http_request['headers'], **opts.headers} if opts.schedule_time is not None and opts.schedule_delay_seconds is not None: raise ValueError( - 'Both sechdule_delay_seconds and schedule_time cannot be set at the same time.') + 'Both schedule_delay_seconds and schedule_time cannot be set at the same time.') if opts.schedule_time is not None and opts.schedule_delay_seconds is None: if not isinstance(opts.schedule_time, datetime): raise ValueError('schedule_time should be UTC datetime.') @@ -288,7 +324,10 @@ def _update_task_payload(self, task: Task, resource: Resource, extension_id: str """Prepares task to be sent with credentials.""" # Get function url from task or generate from resources if not _Validators.is_non_empty_string(task.http_request['url']): - task.http_request['url'] = self._get_url(resource, _FIREBASE_FUNCTION_URL_FORMAT) + if self._is_emulated(): + task.http_request['url'] = '' + else: + task.http_request['url'] = self._get_url(resource, _FIREBASE_FUNCTION_URL_FORMAT) # Refresh the credential to ensure all attributes (e.g. service_account_email, id_token) # are populated, preventing cold start errors. @@ -298,7 +337,7 @@ def _update_task_payload(self, task: Task, resource: Resource, extension_id: str except RefreshError as err: raise ValueError(f'Initial task payload credential refresh failed: {err}') from err - # If extension id is provided, it emplies that it is being run from a deployed extension. + # If extension id is provided, it implies that it is being run from a deployed extension. # Meaning that it's credential should be a Compute Engine Credential. if _Validators.is_non_empty_string(extension_id) and \ isinstance(self._credential, ComputeEngineCredentials): @@ -306,12 +345,32 @@ def _update_task_payload(self, task: Task, resource: Resource, extension_id: str task.http_request['headers'] = \ {**task.http_request['headers'], 'Authorization': f'Bearer {id_token}'} # Delete oidc token - del task.http_request['oidc_token'] + del task.http_request['oidcToken'] else: - task.http_request['oidc_token'] = \ - {'service_account_email': self._credential.service_account_email} + try: + task.http_request['oidcToken'] = \ + {'serviceAccountEmail': self._credential.service_account_email} + except AttributeError as error: + if self._is_emulated(): + task.http_request['oidcToken'] = \ + {'serviceAccountEmail': _EMULATED_SERVICE_ACCOUNT_DEFAULT} + else: + raise ValueError( + 'Failed to determine service account. Initialize the SDK with service ' + 'account credentials or set service account ID as an app option.' + ) from error return task + def _get_emulator_url(self, resource: Resource): + if self._emulator_host: + emulator_url_format = f'http://{self._emulator_host}/' + _CLOUD_TASKS_API_RESOURCE_PATH + url = self._get_url(resource, emulator_url_format) + return url + return None + + def _is_emulated(self): + return self._emulator_host is not None + class _Validators: """A collection of data validation utilities.""" @@ -436,6 +495,14 @@ class Task: schedule_time: Optional[str] = None dispatch_deadline: Optional[str] = None + def to_api_dict(self) -> dict: + """Converts the Task object to a dictionary suitable for the Cloud Tasks API.""" + return { + 'httpRequest': self.http_request, + 'name': self.name, + 'scheduleTime': self.schedule_time, + 'dispatchDeadline': self.dispatch_deadline, + } @dataclass class Resource: diff --git a/integration/emulators/.gitignore b/integration/emulators/.gitignore new file mode 100644 index 000000000..b17f63107 --- /dev/null +++ b/integration/emulators/.gitignore @@ -0,0 +1,69 @@ +# Logs +logs +*.log +npm-debug.log* +yarn-debug.log* +yarn-error.log* +firebase-debug.log* +firebase-debug.*.log* + +# Firebase cache +.firebase/ + +# Firebase config + +# Uncomment this if you'd like others to create their own Firebase project. +# For a team working on the same Firebase project(s), it is recommended to leave +# it commented so all members can deploy to the same project(s) in .firebaserc. +# .firebaserc + +# Runtime data +pids +*.pid +*.seed +*.pid.lock + +# Directory for instrumented libs generated by jscoverage/JSCover +lib-cov + +# Coverage directory used by tools like istanbul +coverage + +# nyc test coverage +.nyc_output + +# Grunt intermediate storage (http://gruntjs.com/creating-plugins#storing-task-files) +.grunt + +# Bower dependency directory (https://bower.io/) +bower_components + +# node-waf configuration +.lock-wscript + +# Compiled binary addons (http://nodejs.org/api/addons.html) +build/Release + +# Dependency directories +node_modules/ + +# Optional npm cache directory +.npm + +# Optional eslint cache +.eslintcache + +# Optional REPL history +.node_repl_history + +# Output of 'npm pack' +*.tgz + +# Yarn Integrity file +.yarn-integrity + +# dotenv environment variables file +.env + +# dataconnect generated files +.dataconnect diff --git a/integration/emulators/firebase.json b/integration/emulators/firebase.json new file mode 100644 index 000000000..a7b727c4d --- /dev/null +++ b/integration/emulators/firebase.json @@ -0,0 +1,29 @@ +{ + "emulators": { + "tasks": { + "port": 9499 + }, + "ui": { + "enabled": false + }, + "singleProjectMode": true, + "functions": { + "port": 5001 + } + }, + "functions": [ + { + "source": "functions", + "codebase": "default", + "disallowLegacyRuntimeConfig": true, + "ignore": [ + "venv", + ".git", + "firebase-debug.log", + "firebase-debug.*.log", + "*.local" + ], + "runtime": "python313" + } + ] +} diff --git a/integration/emulators/functions/.gitignore b/integration/emulators/functions/.gitignore new file mode 100644 index 000000000..1609bab70 --- /dev/null +++ b/integration/emulators/functions/.gitignore @@ -0,0 +1,6 @@ +# Python bytecode +__pycache__/ + +# Python virtual environment +venv/ +*.local diff --git a/integration/emulators/functions/main.py b/integration/emulators/functions/main.py new file mode 100644 index 000000000..6cd2c5766 --- /dev/null +++ b/integration/emulators/functions/main.py @@ -0,0 +1,7 @@ +from firebase_functions import tasks_fn + +@tasks_fn.on_task_dispatched() +def testTaskQueue(req: tasks_fn.CallableRequest) -> None: + """Handles tasks from the task queue.""" + print(f"Received task with data: {req.data}") + return diff --git a/integration/emulators/functions/requirements.txt b/integration/emulators/functions/requirements.txt new file mode 100644 index 000000000..6bbab42f8 --- /dev/null +++ b/integration/emulators/functions/requirements.txt @@ -0,0 +1 @@ +firebase_functions~=0.4.1 diff --git a/integration/test_functions.py b/integration/test_functions.py index 606798436..fc972f9e5 100644 --- a/integration/test_functions.py +++ b/integration/test_functions.py @@ -14,17 +14,34 @@ """Integration tests for firebase_admin.functions module.""" +import os import pytest import firebase_admin from firebase_admin import functions +from firebase_admin import _utils from integration import conftest +_DEFAULT_DATA = {'data': {'city': 'Seattle'}} +def integration_conf(request): + host_override = os.environ.get('CLOUD_TASKS_EMULATOR_HOST') + if host_override: + return _utils.EmulatorAdminCredentials(), 'fake-project-id' + + return conftest.integration_conf(request) + @pytest.fixture(scope='module') def app(request): - cred, _ = conftest.integration_conf(request) - return firebase_admin.initialize_app(cred, name='integration-functions') + cred, project_id = integration_conf(request) + return firebase_admin.initialize_app( + cred, options={'projectId': project_id}, name='integration-functions') + +@pytest.fixture(scope='module', autouse=True) +def default_app(): + # Overwrites the default_app fixture in conftest.py. + # This test suite should not use the default app. Use the app fixture instead. + pass class TestFunctions: @@ -41,16 +58,31 @@ class TestFunctions: ] @pytest.mark.parametrize('task_queue_params', _TEST_FUNCTIONS_PARAMS) - def test_task_queue(self, task_queue_params): - queue = functions.task_queue(**task_queue_params) - assert queue is not None - assert callable(queue.enqueue) - assert callable(queue.delete) - - @pytest.mark.parametrize('task_queue_params', _TEST_FUNCTIONS_PARAMS) - def test_task_queue_app(self, task_queue_params, app): + def test_task_queue(self, task_queue_params, app): assert app.name == 'integration-functions' queue = functions.task_queue(**task_queue_params, app=app) assert queue is not None assert callable(queue.enqueue) assert callable(queue.delete) + + def test_task_enqueue(self, app): + queue = functions.task_queue('testTaskQueue', app=app) + task_id = queue.enqueue(_DEFAULT_DATA) + assert task_id is not None + + @pytest.mark.skipif( + os.environ.get('CLOUD_TASKS_EMULATOR_HOST') is not None, + reason="Skipping test_task_delete against emulator due to bug in firebase-tools" + ) + def test_task_delete(self, app): + # Skip this test against the emulator since tasks can't be delayed there to verify deletion + # See: https://github.com/firebase/firebase-tools/issues/8254 + task_options = functions.TaskOptions(schedule_delay_seconds=60) + queue = functions.task_queue('testTaskQueue', app=app) + task_id = queue.enqueue(_DEFAULT_DATA, task_options) + assert task_id is not None + queue.delete(task_id) + # We don't have a way to check the contents of the queue so we check that the deleted + # task is not found using the delete method again. + with pytest.raises(firebase_admin.exceptions.NotFoundError): + queue.delete(task_id) diff --git a/tests/test_functions.py b/tests/test_functions.py index 953563449..0f766767a 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -44,13 +44,14 @@ def setup_class(cls): def teardown_class(cls): testutils.cleanup_apps() - def _instrument_functions_service(self, app=None, status=200, payload=_DEFAULT_RESPONSE): + def _instrument_functions_service( + self, app=None, status=200, payload=_DEFAULT_RESPONSE, mounted_url=_CLOUD_TASKS_URL): if not app: app = firebase_admin.get_app() functions_service = functions._get_functions_service(app) recorder = [] functions_service._http_client.session.mount( - _CLOUD_TASKS_URL, + mounted_url, testutils.MockAdapter(payload, status, recorder)) return functions_service, recorder @@ -125,8 +126,8 @@ def test_task_enqueue(self): assert task_id == 'test-task-id' task = json.loads(recorder[0].body.decode())['task'] - assert task['http_request']['oidc_token'] == {'service_account_email': 'mock-email'} - assert task['http_request']['headers'] == {'Content-Type': 'application/json'} + assert task['httpRequest']['oidcToken'] == {'serviceAccountEmail': 'mock-email'} + assert task['httpRequest']['headers'] == {'Content-Type': 'application/json'} def test_task_enqueue_with_extension(self): resource_name = ( @@ -147,8 +148,8 @@ def test_task_enqueue_with_extension(self): assert task_id == 'test-task-id' task = json.loads(recorder[0].body.decode())['task'] - assert task['http_request']['oidc_token'] == {'service_account_email': 'mock-email'} - assert task['http_request']['headers'] == {'Content-Type': 'application/json'} + assert task['httpRequest']['oidcToken'] == {'serviceAccountEmail': 'mock-email'} + assert task['httpRequest']['headers'] == {'Content-Type': 'application/json'} def test_task_enqueue_compute_engine(self): app = firebase_admin.initialize_app( @@ -168,8 +169,8 @@ def test_task_enqueue_compute_engine(self): assert task_id == 'test-task-id' task = json.loads(recorder[0].body.decode())['task'] - assert task['http_request']['oidc_token'] == {'service_account_email': 'mock-gce-email'} - assert task['http_request']['headers'] == {'Content-Type': 'application/json'} + assert task['httpRequest']['oidcToken'] == {'serviceAccountEmail': 'mock-gce-email'} + assert task['httpRequest']['headers'] == {'Content-Type': 'application/json'} def test_task_enqueue_with_extension_compute_engine(self): resource_name = ( @@ -194,8 +195,8 @@ def test_task_enqueue_with_extension_compute_engine(self): assert task_id == 'test-task-id' task = json.loads(recorder[0].body.decode())['task'] - assert 'oidc_token' not in task['http_request'] - assert task['http_request']['headers'] == { + assert 'oidcToken' not in task['httpRequest'] + assert task['httpRequest']['headers'] == { 'Content-Type': 'application/json', 'Authorization': 'Bearer mock-compute-engine-token'} @@ -209,6 +210,58 @@ def test_task_delete(self): expected_metrics_header = _utils.get_metrics_header() + ' mock-cred-metric-tag' assert recorder[0].headers['x-goog-api-client'] == expected_metrics_header + def test_task_enqueue_with_emulator_host(self, monkeypatch): + emulator_host = 'localhost:8124' + emulator_url = f'http://{emulator_host}/' + request_url = emulator_url + _DEFAULT_TASK_PATH.replace('/tasks/test-task-id', '/tasks') + + monkeypatch.setenv('CLOUD_TASKS_EMULATOR_HOST', emulator_host) + app = firebase_admin.initialize_app( + _utils.EmulatorAdminCredentials(), {'projectId': 'test-project'}, name='emulator-app') + + expected_task_name = ( + '/projects/test-project/locations/us-central1' + '/queues/test-function-name/tasks/test-task-id' + ) + expected_response = json.dumps({'task': {'name': expected_task_name}}) + _, recorder = self._instrument_functions_service( + app, payload=expected_response, mounted_url=emulator_url) + + queue = functions.task_queue('test-function-name', app=app) + task_id = queue.enqueue(_DEFAULT_DATA) + + assert len(recorder) == 1 + assert recorder[0].method == 'POST' + assert recorder[0].url == request_url + assert recorder[0].headers['Content-Type'] == 'application/json' + + task = json.loads(recorder[0].body.decode())['task'] + assert task['httpRequest']['oidcToken'] == { + 'serviceAccountEmail': 'emulated-service-acct@email.com' + } + assert task_id == 'test-task-id' + + def test_task_enqueue_without_emulator_host_error(self, monkeypatch): + app = firebase_admin.initialize_app( + _utils.EmulatorAdminCredentials(), + {'projectId': 'test-project'}, name='no-emulator-app') + + _, recorder = self._instrument_functions_service(app) + monkeypatch.delenv('CLOUD_TASKS_EMULATOR_HOST', raising=False) + queue = functions.task_queue('test-function-name', app=app) + with pytest.raises(ValueError) as excinfo: + queue.enqueue(_DEFAULT_DATA) + assert "Failed to determine service account" in str(excinfo.value) + assert len(recorder) == 0 + + def test_get_emulator_url_invalid_format(self, monkeypatch): + monkeypatch.setenv('CLOUD_TASKS_EMULATOR_HOST', 'http://localhost:8124') + app = firebase_admin.initialize_app( + testutils.MockCredential(), {'projectId': 'test-project'}, name='invalid-host-app') + with pytest.raises(ValueError) as excinfo: + functions.task_queue('test-function-name', app=app) + assert 'Invalid CLOUD_TASKS_EMULATOR_HOST' in str(excinfo.value) + class TestTaskQueueOptions: _DEFAULT_TASK_OPTS = {'schedule_delay_seconds': None, 'schedule_time': None, \ @@ -259,13 +312,13 @@ def test_task_options_delay_seconds(self): assert len(recorder) == 1 task = json.loads(recorder[0].body.decode())['task'] - task_schedule_time = datetime.fromisoformat(task['schedule_time'].replace('Z', '+00:00')) + task_schedule_time = datetime.fromisoformat(task['scheduleTime'].replace('Z', '+00:00')) delta = abs(task_schedule_time - expected_schedule_time) assert delta <= timedelta(seconds=1) - assert task['dispatch_deadline'] == '200s' - assert task['http_request']['headers']['x-test-header'] == 'test-header-value' - assert task['http_request']['url'] in ['http://google.com', 'https://google.com'] + assert task['dispatchDeadline'] == '200s' + assert task['httpRequest']['headers']['x-test-header'] == 'test-header-value' + assert task['httpRequest']['url'] in ['http://google.com', 'https://google.com'] assert task['name'] == _DEFAULT_TASK_PATH def test_task_options_utc_time(self): @@ -287,12 +340,12 @@ def test_task_options_utc_time(self): assert len(recorder) == 1 task = json.loads(recorder[0].body.decode())['task'] - task_schedule_time = datetime.fromisoformat(task['schedule_time'].replace('Z', '+00:00')) + task_schedule_time = datetime.fromisoformat(task['scheduleTime'].replace('Z', '+00:00')) assert task_schedule_time == expected_schedule_time - assert task['dispatch_deadline'] == '200s' - assert task['http_request']['headers']['x-test-header'] == 'test-header-value' - assert task['http_request']['url'] in ['http://google.com', 'https://google.com'] + assert task['dispatchDeadline'] == '200s' + assert task['httpRequest']['headers']['x-test-header'] == 'test-header-value' + assert task['httpRequest']['url'] in ['http://google.com', 'https://google.com'] assert task['name'] == _DEFAULT_TASK_PATH def test_schedule_set_twice_error(self): @@ -304,7 +357,7 @@ def test_schedule_set_twice_error(self): queue.enqueue(_DEFAULT_DATA, opts) assert len(recorder) == 0 assert str(excinfo.value) == \ - 'Both sechdule_delay_seconds and schedule_time cannot be set at the same time.' + 'Both schedule_delay_seconds and schedule_time cannot be set at the same time.' @pytest.mark.parametrize('schedule_time', [ From 807e7e1d8abdf37a8413f42864c12bebce89fd21 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Tue, 9 Dec 2025 15:32:31 -0500 Subject: [PATCH 220/226] chore: Fix auth snippet typo (#924) --- snippets/auth/index.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/snippets/auth/index.py b/snippets/auth/index.py index 6a509b8f5..656137dba 100644 --- a/snippets/auth/index.py +++ b/snippets/auth/index.py @@ -770,7 +770,7 @@ def get_tenant(tenant_id): # [START get_tenant] tenant = tenant_mgt.get_tenant(tenant_id) - print('Retreieved tenant:', tenant.tenant_id) + print('Retrieved tenant:', tenant.tenant_id) # [END get_tenant] def create_tenant(): From d5aba8443196e0212d724bd7b81f73689b5c8a08 Mon Sep 17 00:00:00 2001 From: Lahiru Maramba Date: Fri, 12 Dec 2025 17:16:11 -0500 Subject: [PATCH 221/226] chore: Update default branch to `main` (#926) * chore: Update default branch to main * set java version to fix emulator tools --- .github/scripts/publish_preflight_check.sh | 4 ++-- .github/workflows/ci.yml | 6 ++++++ .github/workflows/nightly.yml | 4 ++-- .github/workflows/release.yml | 6 +++--- CONTRIBUTING.md | 2 +- 5 files changed, 14 insertions(+), 8 deletions(-) diff --git a/.github/scripts/publish_preflight_check.sh b/.github/scripts/publish_preflight_check.sh index 1d001c3b9..38fe49a88 100755 --- a/.github/scripts/publish_preflight_check.sh +++ b/.github/scripts/publish_preflight_check.sh @@ -159,8 +159,8 @@ echo_info "Generating changelog" echo_info "--------------------------------------------" echo_info "" -echo_info "---< git fetch origin master --prune --unshallow >---" -git fetch origin master --prune --unshallow +echo_info "---< git fetch origin main --prune --unshallow >---" +git fetch origin main --prune --unshallow echo "" echo_info "Generating changelog from history..." diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2ba09880b..fa980083e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -37,6 +37,12 @@ jobs: uses: actions/setup-node@v4 with: node-version: 20 + - name: Set up Java 21 + uses: actions/setup-java@v5 + with: + distribution: 'temurin' + java-version: '21' + check-latest: true - name: Install firebase-tools run: npm install -g firebase-tools - name: Run Database emulator tests diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 3d5420537..61644e806 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -70,7 +70,7 @@ jobs: - name: Send email on failure if: failure() - uses: firebase/firebase-admin-node/.github/actions/send-email@master + uses: firebase/firebase-admin-node/.github/actions/send-email@main with: api-key: ${{ secrets.OSS_BOT_MAILGUN_KEY }} domain: ${{ secrets.OSS_BOT_MAILGUN_DOMAIN }} @@ -85,7 +85,7 @@ jobs: - name: Send email on cancelled if: cancelled() - uses: firebase/firebase-admin-node/.github/actions/send-email@master + uses: firebase/firebase-admin-node/.github/actions/send-email@main with: api-key: ${{ secrets.OSS_BOT_MAILGUN_KEY }} domain: ${{ secrets.OSS_BOT_MAILGUN_DOMAIN }} diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 6cd1d3f07..738dfca55 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -84,11 +84,11 @@ jobs: # Check whether the release should be published. We publish only when the trigger PR is # 1. merged - # 2. to the master branch + # 2. to the main branch # 3. with the label 'release:publish', and # 4. the title prefix '[chore] Release '. if: github.event.pull_request.merged && - github.ref == 'refs/heads/master' && + github.ref == 'refs/heads/main' && contains(github.event.pull_request.labels.*.name, 'release:publish') && startsWith(github.event.pull_request.title, '[chore] Release ') @@ -130,7 +130,7 @@ jobs: - name: Post to Twitter if: success() && contains(github.event.pull_request.labels.*.name, 'release:tweet') - uses: firebase/firebase-admin-node/.github/actions/send-tweet@master + uses: firebase/firebase-admin-node/.github/actions/send-tweet@main with: status: > ${{ steps.preflight.outputs.version }} of @Firebase Admin Python SDK is available. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 71da12dc6..139e7f96c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -47,7 +47,7 @@ Great, we love hearing how we can improve our products! Share you idea through o ## Want to submit a pull request? Sweet, we'd love to accept your contribution! -[Open a new pull request](https://github.com/firebase/firebase-admin-python/pull/new/master) and fill +[Open a new pull request](https://github.com/firebase/firebase-admin-python/pull/new) and fill out the provided template. **If you want to implement a new feature, please open an issue with a proposal first so that we can From e8276552c377d72452f6cd182ad9f4fc62982112 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Tue, 13 Jan 2026 15:54:59 -0500 Subject: [PATCH 222/226] chore: Update github actions workflows and integration test resources (#932) * chore: Pinned github actions to full-length comit SHAs * chore: Update integration test resource * chore: Added environment label to release action * Trigger integration tests --- .../resources/integ-service-account.json.gpg | Bin 1762 -> 1756 bytes .github/workflows/ci.yml | 14 +++++++------- .github/workflows/nightly.yml | 10 +++++----- .github/workflows/release.yml | 15 ++++++++------- 4 files changed, 20 insertions(+), 19 deletions(-) diff --git a/.github/resources/integ-service-account.json.gpg b/.github/resources/integ-service-account.json.gpg index 7740dccd8bdada2eecc181f75c552c00e912e5c2..5a52805c9a854fdea91101dfa1822cdd5e79c874 100644 GIT binary patch literal 1756 zcmV<21|#{54Fm}T3b4{&ennUfga6X%0rw}><=tSBuPhv7FQy8=I?3Lmw%%fwb*z*z z=6KA27D5iuBgx~x$Nyfl&GpjP(Qv$$D`=U2l?uxu35Sn`=|kyz<8?VA+jsyF+7=2@ z<@|}U!==V573LAFo)Fk!&)W@2^oQ_BI8xVK1@Fz@nY52OgWsOxEWs4?ws$bNy2VpekV}20eGu8>0Z((IN{nq@79Jv z>qJM3@B0X3ROSPK`6JTc9rmfZp<%iCo=L%D$)96DgP`ezf<5a55_+{><$|JWn2{6- zu8|~zd^P6mDHo4^^s;zYo=gn$v3_|-6h#Q*v>YB4FwzzYxJ$OFI2qsh-<}WE`UP0o zvNc^F5c1`|u5^Lo6gM(449xUT0J_lPfT`o-Rv3YXMm45zU)xYv%V8E`t!^%3W2T@d zME7SXe{`62EFI}_gg}|*iwSRjWZ)dX-wOf2SENE#%c{=78oej87A$)iafd`bW8SH6 zJ-kB6gL(_ov>A`sM?z99l_M;#(lNwfS+_jn-)(p+mR%!(X&i-aQ;uoOep`8uJh{NU z`;X((7+wEarUlyb_5_u^GXdq zaHU$=fZ8&|fJOjb;?BfCUET*97w=(u`g8iTDN%`zGy@e26Rze|?C=eo-{bX~^~{(OslofH)*vmX`q?uh;mD5YR6@E1U1-|M>X0?* znFf7wQOrblD2(asM#I1dWRAPl@zh(QD z3XsUb(1C36YvZSlmD6!_%N!na+}IMOU|^m>_N`mzbT zZz&Ee7(3&M^=M8Xx_;F26O2}mG$#jR3sA&;ot)ktfCj624V{@9&f|P1`hLMTpnf^v zyhk%W;T7PT7K8gpz*`iIN{Id&dQVMOU|K%KwbY*8!(h}Mh1@VUgaWc{)H1_HNj{KH zq!~m9Ks6xb#27@Jch`+}Rsh!qa(SViGg&w78YzoJc~jys?pKii|FV@TzhUOuk}*;X z4UM8Nu9^|{7&XbCsEJX@rHSaw)%Fuy4-8YCUH<)teY{M^oANVK%EBx4h~)y7RP!u+ zG4q4$%cl%;>0k?v&6UdW=S7~lH-VQ%rOiDIg!tzMdLHwerG6__b|UZ{+x*``6FM_x z9Q%*J0L1Bea||NK+{ND-NlRa5UL%ezMBxFUxC1VlSB|DLWU6iI+7OPXmC?gb}g23VDc?X&{)zPJ~y802E;k33pZuLifi1$Fy+IgceC8QUX(c8vuYUjQKL zPLpo*M3{Hm$;qSSe&l)rR9` zt9n!#Tg#AJ^E292w0{@6@OsPYQ6>R0#1H-qumS64GunHce~@%Y9h397pn53owR5DC zt*{lGzs20XTi%=eqw1~J<2TYgq7Y{bA?-e8DSjCROK>ezwNru{_CUNLu@-pZpX7y~ z;J^s(GiVC9k!jG$m?zZX8w$6!$XoIBdUL|re!W6X_8{PtSQ38mbE@D8(%y$_zcNI6 zQOb{GR=s*=H?m`D`{_yr@g0Jr`3B&P>JtuTtoLNZgc9|vz>at=Wdh*Zk?DLam2X-i zvOz~E+b``WR>({bK{cqmo9hPH%#g(%I1-u!fh6lSWv!AmZ6+;`&jUAf%zUT!Y_d3L ye!cL($Ue&aX9i7=EYc2_I+(Y=Ru;Ur6L%VZx+|5b-zMEMj_Xxk_VrmJgD2Lh0c>#q literal 1762 zcmV<81|9i~4Fm}T0)n*iur~xoxnP%@af)-$RhZ&7)3w(?>+$-b1cw>C)uo%^A>P*=<*0iel03 za}Fsg1`ruYzno_!>aY;)OsTkc%h&&pe{$nZw5?@UjcH^%VjHu)VgwEnKQUK4HJk0= zyZ5gIrQ>xo>l)_!-wlcQad@p8gshGgv1xcg_Yjj!nW#U|n*LMBm3Wlbk?i_CZ5Ume zal+1184AaJ?*V#JRCa=MTz1K6B-nEb*&W)}8_{YN|7UTh+U!ds%WR1v?9Hbxk@ z+EogP!glp>+*yB=n_Y5O!Q;p67Vmj%K0Z&IQHoSQT$5HK`B~z1yiX&eY#WwU=~;7S z*R}F3&c9V&rDM=B?(>!Z9}?HjEfYzeD!pR<0+x@F>TU9zk7}`hv#*;%+wZyrU0HEofh5VjbOiejUW3pY z12D!+Wf#4>(VKPmD}rZ71T zja(Co^VJ`de?Q=DM7vMvigtGHs?54vGS7;%Hh%Kgpcm|%^7?KEi%S69Cr3HH;P{#2 zQKW(H3{(*{l=~Bf7Rl-tT)anXQbTFgdpURPL%zy_*wvVIb#5uC(O*IiG}XiMGUA%y z{hDOjK*5$4CjH6|s5_9ya=cleH7`i?@E_&%aD!Azwfl->_ z*G%8=hHUxvc;-w-8SZhmd?W7x5Q>Gdn&TWCqCer@&YB71xG=dZ z)$Q#4CnyByk9+}lQNhr0X%Yl^Z@x^E{?CS9l0qFDy5|%MRx3RMp&FRUbXi#z!svt^ z$1+LBRDKak{r8lR!uG7M$3y`i=mM%I6C;mOE*azp=yC&EjEO@hZAV~%vb<>vw*Tx}<@1-F%E1MJ z&Kx%>OC=ss;A2c#Xs*e!O^%Y!Dj=VPAc>{~z^9u8{@H%2A-4|8w>siL6sFHXXGAGp z6u#;-ci&~<6YWll-R0BZ(9=FJ*d=Z^1zpP&loXySLaK2LILSvW!Dc?^#B&6$r*#Z0VQL=xlo$_b_5MO*>kW6!}M4b=rO4;kC? z7>1#k9x6oenqHJ=>|?M)^;{lP0reAV^DsBc!yt{hJ=)44ze!{NMcolK=)E*z3xPIFQ}^9xIoZ8E2fWR@#G|3UOCwegz@y^#$8?WWnJ zs2>$dD-U!>e@QG8RfQO0o(=0E;V|(Wz7kVeam0oZT7r|6xiFJ zNd8hqdzqJ0-`fL-a@okskU*`zWMi=%n1PPxq!&m!cL|J#l~QZ}A3JebK?T zy$sIEE!|IkVGk9;6HzHA)9Y-IXTCA)$QGOVqdT1OhM%`yYT53qvYtc=4KjLTfqCH9 zgC*H0Jk^%P@UH!391-ApSu$Z{4D#3;W{vq*kO zsL+qpghA`Jm|4{gJrTOb%(>>3vR5{Am*n3SBI9H-+h-WW$}kgTg9MyYUUi*K+#nml Ero6&wRR910 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fa980083e..5bf78a56b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,10 +11,10 @@ jobs: python: ['3.9', '3.10', '3.11', '3.12', '3.13', 'pypy3.9'] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # 4.3.1 - name: Set up Python 3.13 for emulator - uses: actions/setup-python@v5 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # 5.6.0 with: python-version: '3.13' - name: Setup functions emulator environment @@ -24,7 +24,7 @@ jobs: pip install -r integration/emulators/functions/requirements.txt deactivate - name: Set up Python ${{ matrix.python }} - uses: actions/setup-python@v5 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # 5.6.0 with: python-version: ${{ matrix.python }} - name: Install dependencies @@ -34,11 +34,11 @@ jobs: - name: Test with pytest run: pytest - name: Set up Node.js 20 - uses: actions/setup-node@v4 + uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # 4.4.0 with: node-version: 20 - name: Set up Java 21 - uses: actions/setup-java@v5 + uses: actions/setup-java@f2beeb24e141e01a676f977032f5a29d81c9e27e # 5.1.0 with: distribution: 'temurin' java-version: '21' @@ -52,9 +52,9 @@ jobs: lint: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # 4.3.1 - name: Set up Python 3.9 - uses: actions/setup-python@v5 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # 5.6.0 with: python-version: 3.9 - name: Install dependencies diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 61644e806..d60b3cd0b 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -29,12 +29,12 @@ jobs: steps: - name: Checkout source for staging - uses: actions/checkout@v4 + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # 4.3.1 with: ref: ${{ github.event.client_payload.ref || github.ref }} - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # 5.6.0 with: python-version: 3.9 @@ -63,14 +63,14 @@ jobs: # Attach the packaged artifacts to the workflow output. These can be manually # downloaded for later inspection if necessary. - name: Archive artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 with: name: dist path: dist - name: Send email on failure if: failure() - uses: firebase/firebase-admin-node/.github/actions/send-email@main + uses: firebase/firebase-admin-node/.github/actions/send-email@2e2b36a84ba28679bcb7aecdacabfec0bded2d48 # Admin Node SDK v13.6.0 with: api-key: ${{ secrets.OSS_BOT_MAILGUN_KEY }} domain: ${{ secrets.OSS_BOT_MAILGUN_DOMAIN }} @@ -85,7 +85,7 @@ jobs: - name: Send email on cancelled if: cancelled() - uses: firebase/firebase-admin-node/.github/actions/send-email@main + uses: firebase/firebase-admin-node/.github/actions/send-email@2e2b36a84ba28679bcb7aecdacabfec0bded2d48 # Admin Node SDK v13.6.0 with: api-key: ${{ secrets.OSS_BOT_MAILGUN_KEY }} domain: ${{ secrets.OSS_BOT_MAILGUN_DOMAIN }} diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 738dfca55..53ebe825c 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -40,12 +40,12 @@ jobs: # via the 'ref' client parameter. steps: - name: Checkout source for staging - uses: actions/checkout@v4 + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # 4.3.1 with: ref: ${{ github.event.client_payload.ref || github.ref }} - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # 5.6.0 with: python-version: 3.9 @@ -74,7 +74,7 @@ jobs: # Attach the packaged artifacts to the workflow output. These can be manually # downloaded for later inspection if necessary. - name: Archive artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 with: name: dist path: dist @@ -93,6 +93,7 @@ jobs: startsWith(github.event.pull_request.title, '[chore] Release ') runs-on: ubuntu-latest + environment: Release permissions: # Used to create a short-lived OIDC token which is given to PyPi to identify this workflow job # See: https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/about-security-hardening-with-openid-connect#adding-permissions-settings @@ -102,11 +103,11 @@ jobs: steps: - name: Checkout source for publish - uses: actions/checkout@v4 + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # 4.3.1 # Download the artifacts created by the stage_release job. - name: Download release candidates - uses: actions/download-artifact@v4.1.7 + uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0 with: name: dist path: dist @@ -124,13 +125,13 @@ jobs: --notes '${{ steps.preflight.outputs.changelog }}' - name: Publish to Pypi - uses: pypa/gh-action-pypi-publish@release/v1 + uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e # v1.13.0 # Post to Twitter if explicitly opted-in by adding the label 'release:tweet'. - name: Post to Twitter if: success() && contains(github.event.pull_request.labels.*.name, 'release:tweet') - uses: firebase/firebase-admin-node/.github/actions/send-tweet@main + uses: firebase/firebase-admin-node/.github/actions/send-tweet@2e2b36a84ba28679bcb7aecdacabfec0bded2d48 # Admin Node SDK v13.6.0 with: status: > ${{ steps.preflight.outputs.version }} of @Firebase Admin Python SDK is available. From d11b211739f69ed384516e4ed63de7f7ff6a895f Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Tue, 10 Feb 2026 11:48:34 -0500 Subject: [PATCH 223/226] chore: Update release workflows for push triggers (#935) * chore: Update release workflows for push triggers * chore: Update release instructions and fix tag creation --- .github/workflows/release.yml | 60 +++++++++++++++-------------------- 1 file changed, 25 insertions(+), 35 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 53ebe825c..6bbf19aab 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -15,10 +15,18 @@ name: Release Candidate on: - # Only run the workflow when a PR is updated or when a developer explicitly requests - # a build by sending a 'firebase_build' event. + # Run the workflow when: + # 1. A PR is created or updated (staging checks). + # 2. A commit is pushed to main (release publication). + # 3. A developer explicitly requests a build via 'firebase_build' event. pull_request: - types: [opened, synchronize, closed] + types: [opened, synchronize] + + push: + branches: + - main + paths: + - 'firebase_admin/__about__.py' repository_dispatch: types: @@ -26,23 +34,19 @@ on: jobs: stage_release: - # To publish a release, merge the release PR with the label 'release:publish'. + # To publish a release, merge a PR with the title prefix '[chore] Release ' to main + # and ensure the squashed commit message also has the prefix. # To stage a release without publishing it, send a 'firebase_build' event or apply # the 'release:stage' label to a PR. if: github.event.action == 'firebase_build' || contains(github.event.pull_request.labels.*.name, 'release:stage') || - (github.event.pull_request.merged && - contains(github.event.pull_request.labels.*.name, 'release:publish')) + (github.event_name == 'push' && startsWith(github.event.head_commit.message, '[chore] Release ')) runs-on: ubuntu-latest - # When manually triggering the build, the requester can specify a target branch or a tag - # via the 'ref' client parameter. steps: - name: Checkout source for staging uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # 4.3.1 - with: - ref: ${{ github.event.client_payload.ref || github.ref }} - name: Set up Python uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # 5.6.0 @@ -82,15 +86,13 @@ jobs: publish_release: needs: stage_release - # Check whether the release should be published. We publish only when the trigger PR is - # 1. merged + # Check whether the release should be published. We publish only when the trigger is + # 1. a push (merge) # 2. to the main branch - # 3. with the label 'release:publish', and - # 4. the title prefix '[chore] Release '. - if: github.event.pull_request.merged && + # 3. and the commit message has the title prefix '[chore] Release '. + if: github.event_name == 'push' && github.ref == 'refs/heads/main' && - contains(github.event.pull_request.labels.*.name, 'release:publish') && - startsWith(github.event.pull_request.title, '[chore] Release ') + startsWith(github.event.head_commit.message, '[chore] Release ') runs-on: ubuntu-latest environment: Release @@ -120,24 +122,12 @@ jobs: - name: Create release tag env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - run: gh release create ${{ steps.preflight.outputs.version }} - --title "Firebase Admin Python SDK ${{ steps.preflight.outputs.version }}" - --notes '${{ steps.preflight.outputs.changelog }}' + RELEASE_VER: ${{ steps.preflight.outputs.version }} + RELEASE_BODY: ${{ steps.preflight.outputs.changelog }} + run: | + gh release create "$RELEASE_VER" \ + --title "Firebase Admin Python SDK $RELEASE_VER" \ + --notes "$RELEASE_BODY" - name: Publish to Pypi uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e # v1.13.0 - - # Post to Twitter if explicitly opted-in by adding the label 'release:tweet'. - - name: Post to Twitter - if: success() && - contains(github.event.pull_request.labels.*.name, 'release:tweet') - uses: firebase/firebase-admin-node/.github/actions/send-tweet@2e2b36a84ba28679bcb7aecdacabfec0bded2d48 # Admin Node SDK v13.6.0 - with: - status: > - ${{ steps.preflight.outputs.version }} of @Firebase Admin Python SDK is available. - https://github.com/firebase/firebase-admin-python/releases/tag/${{ steps.preflight.outputs.version }} - consumer-key: ${{ secrets.TWITTER_CONSUMER_KEY }} - consumer-secret: ${{ secrets.TWITTER_CONSUMER_SECRET }} - access-token: ${{ secrets.TWITTER_ACCESS_TOKEN }} - access-token-secret: ${{ secrets.TWITTER_ACCESS_TOKEN_SECRET }} - continue-on-error: true From 005b44dc3f5600f0ecebc8c24a5e91392c7fcc75 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Wed, 25 Feb 2026 12:25:49 -0500 Subject: [PATCH 224/226] [chore] Release 7.2.0 (#937) --- firebase_admin/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/__about__.py b/firebase_admin/__about__.py index 9fb40b11c..d219f5ed7 100644 --- a/firebase_admin/__about__.py +++ b/firebase_admin/__about__.py @@ -14,7 +14,7 @@ """About information (version, etc) for Firebase Admin SDK.""" -__version__ = '7.1.0' +__version__ = '7.2.0' __title__ = 'firebase_admin' __author__ = 'Firebase' __license__ = 'Apache License 2.0' From 581ef26c3ea0964d44bbd77dfbae1940985c1300 Mon Sep 17 00:00:00 2001 From: Huon Imberger Date: Fri, 27 Feb 2026 03:47:30 +1100 Subject: [PATCH 225/226] Remove debug print for HTTP status error (#939) Co-authored-by: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> --- firebase_admin/_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index d0aca884b..0277b9e5f 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -279,7 +279,6 @@ def handle_httpx_error(error: httpx.HTTPError, message=None, code=None) -> excep message=f'Failed to establish a connection: {error}', cause=error) if isinstance(error, httpx.HTTPStatusError): - print("printing status error", error) if not code: code = _http_status_to_error_code(error.response.status_code) if not message: From d62a15e7da70e2d1829aadf034ad34546d7433cb Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Thu, 12 Mar 2026 13:07:34 -0400 Subject: [PATCH 226/226] feat(fcm): Add support for bandwidth constrained and restricted satellite APIs (#940) * feat(fcm): Add support for bandwidth constrained and restricted satellite APIs * chore: Add type hints to `AndroidConfig` --- firebase_admin/_messaging_encoder.py | 4 ++++ firebase_admin/_messaging_utils.py | 24 ++++++++++++++++++++++-- tests/test_messaging.py | 16 ++++++++++++++++ 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/firebase_admin/_messaging_encoder.py b/firebase_admin/_messaging_encoder.py index 960a6d742..4c0c6daa4 100644 --- a/firebase_admin/_messaging_encoder.py +++ b/firebase_admin/_messaging_encoder.py @@ -207,6 +207,10 @@ def encode_android(cls, android): 'fcm_options': cls.encode_android_fcm_options(android.fcm_options), 'direct_boot_ok': _Validators.check_boolean( 'AndroidConfig.direct_boot_ok', android.direct_boot_ok), + 'bandwidth_constrained_ok': _Validators.check_boolean( + 'AndroidConfig.bandwidth_constrained_ok', android.bandwidth_constrained_ok), + 'restricted_satellite_ok': _Validators.check_boolean( + 'AndroidConfig.restricted_satellite_ok', android.restricted_satellite_ok), } result = cls.remove_null_values(result) priority = result.get('priority') diff --git a/firebase_admin/_messaging_utils.py b/firebase_admin/_messaging_utils.py index 8fd720701..773ed6057 100644 --- a/firebase_admin/_messaging_utils.py +++ b/firebase_admin/_messaging_utils.py @@ -13,6 +13,9 @@ # limitations under the License. """Types and utilities used by the messaging (FCM) module.""" +from __future__ import annotations +import datetime +from typing import Dict, Optional, Union from firebase_admin import exceptions @@ -51,10 +54,25 @@ class AndroidConfig: fcm_options: A ``messaging.AndroidFCMOptions`` to be included in the message (optional). direct_boot_ok: A boolean indicating whether messages will be allowed to be delivered to the app while the device is in direct boot mode (optional). + bandwidth_constrained_ok: A boolean indicating whether messages will be allowed to be + delivered to the app while the device is on a bandwidth constrained network (optional). + restricted_satellite_ok: A boolean indicating whether messages will be allowed to be + delivered to the app while the device is on a restricted satellite network (optional). """ - def __init__(self, collapse_key=None, priority=None, ttl=None, restricted_package_name=None, - data=None, notification=None, fcm_options=None, direct_boot_ok=None): + def __init__( + self, + collapse_key: Optional[str] = None, + priority: Optional[str] = None, + ttl: Optional[Union[int, float, datetime.timedelta]] = None, + restricted_package_name: Optional[str] = None, + data: Optional[Dict[str, str]] = None, + notification: Optional[AndroidNotification] = None, + fcm_options: Optional[AndroidFCMOptions] = None, + direct_boot_ok: Optional[bool] = None, + bandwidth_constrained_ok: Optional[bool] = None, + restricted_satellite_ok: Optional[bool] = None + ): self.collapse_key = collapse_key self.priority = priority self.ttl = ttl @@ -63,6 +81,8 @@ def __init__(self, collapse_key=None, priority=None, ttl=None, restricted_packag self.notification = notification self.fcm_options = fcm_options self.direct_boot_ok = direct_boot_ok + self.bandwidth_constrained_ok = bandwidth_constrained_ok + self.restricted_satellite_ok = restricted_satellite_ok class AndroidNotification: diff --git a/tests/test_messaging.py b/tests/test_messaging.py index 9fa30fef9..b30790f14 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -335,6 +335,18 @@ def test_invalid_direct_boot_ok(self, data): check_encoding(messaging.Message( topic='topic', android=messaging.AndroidConfig(direct_boot_ok=data))) + @pytest.mark.parametrize('data', NON_BOOL_ARGS) + def test_invalid_bandwidth_constrained_ok(self, data): + with pytest.raises(ValueError): + check_encoding(messaging.Message( + topic='topic', android=messaging.AndroidConfig(bandwidth_constrained_ok=data))) + + @pytest.mark.parametrize('data', NON_BOOL_ARGS) + def test_invalid_restricted_satellite_ok(self, data): + with pytest.raises(ValueError): + check_encoding(messaging.Message( + topic='topic', android=messaging.AndroidConfig(restricted_satellite_ok=data))) + def test_android_config(self): msg = messaging.Message( @@ -347,6 +359,8 @@ def test_android_config(self): data={'k1': 'v1', 'k2': 'v2'}, fcm_options=messaging.AndroidFCMOptions('analytics_label_v1'), direct_boot_ok=True, + bandwidth_constrained_ok=True, + restricted_satellite_ok=True, ) ) expected = { @@ -364,6 +378,8 @@ def test_android_config(self): 'analytics_label': 'analytics_label_v1', }, 'direct_boot_ok': True, + 'bandwidth_constrained_ok': True, + 'restricted_satellite_ok': True, }, } check_encoding(msg, expected)