From e4b90122ae5807e7e015b561b1a5a5ed6a2b5f96 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Fri, 27 Mar 2026 00:16:21 +0000 Subject: [PATCH] refactor: Simplify @udf code --- bigframes/functions/_function_client.py | 90 ++++++++--------- bigframes/functions/_function_session.py | 96 +++++++++---------- bigframes/functions/_utils.py | 12 --- bigframes/functions/function_template.py | 34 +++---- bigframes/functions/udf_def.py | 40 +++++++- bigframes/testing/utils.py | 14 --- .../small/functions/test_remote_function.py | 20 +++- .../functions/test_remote_function_utils.py | 72 -------------- 8 files changed, 151 insertions(+), 227 deletions(-) diff --git a/bigframes/functions/_function_client.py b/bigframes/functions/_function_client.py index 4b368f48cc..fc06465327 100644 --- a/bigframes/functions/_function_client.py +++ b/bigframes/functions/_function_client.py @@ -15,7 +15,6 @@ from __future__ import annotations -import inspect import logging import os import random @@ -25,7 +24,7 @@ import tempfile import textwrap import types -from typing import Any, cast, Optional, Sequence, TYPE_CHECKING +from typing import Any, cast, Optional, TYPE_CHECKING import warnings import requests @@ -87,7 +86,6 @@ def __init__( bq_location, bq_dataset, bq_client, - bq_connection_id, bq_connection_manager, cloud_function_region=None, cloud_functions_client=None, @@ -102,7 +100,6 @@ def __init__( self._bq_location = bq_location self._bq_dataset = bq_dataset self._bq_client = bq_client - self._bq_connection_id = bq_connection_id self._bq_connection_manager = bq_connection_manager self._session = session @@ -114,12 +111,12 @@ def __init__( self._cloud_function_docker_repository = cloud_function_docker_repository self._cloud_build_service_account = cloud_build_service_account - def _create_bq_connection(self) -> None: + def _create_bq_connection(self, connection_id: str) -> None: if self._bq_connection_manager: self._bq_connection_manager.create_bq_connection( self._gcp_project_id, self._bq_location, - self._bq_connection_id, + connection_id, "run.invoker", ) @@ -174,7 +171,7 @@ def create_bq_remote_function( ): """Create a BigQuery remote function given the artifacts of a user defined function and the http endpoint of a corresponding cloud function.""" - self._create_bq_connection() + self._create_bq_connection(udf_def.connection_id) # Create BQ function # https://cloud.google.com/bigquery/docs/reference/standard-sql/remote-functions#create_a_remote_function_2 @@ -202,7 +199,7 @@ def create_bq_remote_function( create_function_ddl = f""" CREATE OR REPLACE FUNCTION `{self._gcp_project_id}.{self._bq_dataset}`.{bq_function_name_escaped}({udf_def.signature.to_sql_input_signature()}) RETURNS {udf_def.signature.with_devirtualize().output.sql_type} - REMOTE WITH CONNECTION `{self._gcp_project_id}.{self._bq_location}.{self._bq_connection_id}` + REMOTE WITH CONNECTION `{self._gcp_project_id}.{self._bq_location}.{udf_def.connection_id}` OPTIONS ({remote_function_options_str})""" logger.info(f"Creating BQ remote function: {create_function_ddl}") @@ -212,26 +209,15 @@ def create_bq_remote_function( def provision_bq_managed_function( self, - func, - input_types: Sequence[str], - output_type: str, name: Optional[str], - packages: Optional[Sequence[str]], - max_batching_rows: Optional[int], - container_cpu: Optional[float], - container_memory: Optional[str], - is_row_processor: bool, - bq_connection_id, - *, - capture_references: bool = False, + config: udf_def.ManagedFunctionConfig, ): """Create a BigQuery managed function.""" # TODO(b/406283812): Expose the capability to pass down # capture_references=True in the public udf API. - # TODO(b/495508827): Include all config in the value hash. if ( - capture_references + config.capture_references and (python_version := _utils.get_python_version()) != _MANAGED_FUNC_PYTHON_VERSION ): @@ -241,31 +227,26 @@ def provision_bq_managed_function( ) # Create BQ managed function. - bq_function_args = [] - bq_function_return_type = output_type - - input_args = inspect.getargs(func.__code__).args - # We expect the input type annotations to be 1:1 with the input args. - for name_, type_ in zip(input_args, input_types): - bq_function_args.append(f"{name_} {type_}") + bq_function_args = config.signature.to_sql_input_signature() + bq_function_return_type = config.signature.with_devirtualize().output.sql_type managed_function_options: dict[str, Any] = { "runtime_version": _MANAGED_FUNC_PYTHON_VERSION, "entry_point": "bigframes_handler", } - if max_batching_rows: - managed_function_options["max_batching_rows"] = max_batching_rows - if container_cpu: - managed_function_options["container_cpu"] = container_cpu - if container_memory: - managed_function_options["container_memory"] = container_memory + if config.max_batching_rows: + managed_function_options["max_batching_rows"] = config.max_batching_rows + if config.container_cpu: + managed_function_options["container_cpu"] = config.container_cpu + if config.container_memory: + managed_function_options["container_memory"] = config.container_memory # Augment user package requirements with any internal package # requirements. packages = _utils.get_updated_package_requirements( - packages or [], - is_row_processor, - capture_references, + config.code.package_requirements or [], + config.signature.is_row_processor, + config.capture_references, ignore_package_version=True, ) if packages: @@ -276,26 +257,20 @@ def provision_bq_managed_function( bq_function_name = name if not bq_function_name: - # Compute a unique hash representing the user code. - function_hash = _utils.get_hash(func, packages) - bq_function_name = _utils.get_managed_function_name( - function_hash, - # session-scope in absensce of name from user - # name indicates permanent allocation - None if name else self._session.session_id, + # Compute a unique hash representing the artifact definition. + bq_function_name = get_managed_function_name( + config, self._session.session_id ) persistent_func_id = ( f"`{self._gcp_project_id}.{self._bq_dataset}`.{bq_function_name}" ) - udf_name = func.__name__ - with_connection_clause = ( ( - f"WITH CONNECTION `{self._gcp_project_id}.{self._bq_location}.{self._bq_connection_id}`" + f"WITH CONNECTION `{self._gcp_project_id}.{self._bq_location}.{config.bq_connection_id}`" ) - if bq_connection_id + if config.bq_connection_id else "" ) @@ -303,13 +278,13 @@ def provision_bq_managed_function( # including the user's function, necessary imports, and the BigQuery # handler wrapper. python_code_block = bff_template.generate_managed_function_code( - func, udf_name, is_row_processor, capture_references + config.code, config.signature, config.capture_references ) create_function_ddl = ( textwrap.dedent( f""" - CREATE OR REPLACE FUNCTION {persistent_func_id}({','.join(bq_function_args)}) + CREATE OR REPLACE FUNCTION {persistent_func_id}({bq_function_args}) RETURNS {bq_function_return_type} LANGUAGE python {with_connection_clause} @@ -590,6 +565,7 @@ def provision_bq_remote_function( cloud_function_memory_mib: int | None, cloud_function_cpus: float | None, cloud_function_ingress_settings: str, + bq_connection_id: str, ): """Provision a BigQuery remote function.""" # Augment user package requirements with any internal package @@ -657,7 +633,7 @@ def provision_bq_remote_function( intended_rf_spec = udf_def.RemoteFunctionConfig( endpoint=cf_endpoint, - connection_id=self._bq_connection_id, + connection_id=bq_connection_id, max_batching_rows=max_batching_rows or 1000, signature=func_signature, bq_metadata=func_signature.protocol_metadata, @@ -731,6 +707,18 @@ def get_bigframes_function_name( return _BQ_FUNCTION_NAME_SEPERATOR.join(parts) +def get_managed_function_name( + function_def: udf_def.ManagedFunctionConfig, + session_id: str | None = None, +): + """Get a name for the bigframes managed function for the given user defined function.""" + parts = [_BIGFRAMES_FUNCTION_PREFIX] + if session_id: + parts.append(session_id) + parts.append(function_def.stable_hash().hex()) + return _BQ_FUNCTION_NAME_SEPERATOR.join(parts) + + def _validate_routine_name(name: str) -> None: """Validate that the given name is a valid BigQuery routine name.""" # Routine IDs can contain only letters (a-z, A-Z), numbers (0-9), or underscores (_) diff --git a/bigframes/functions/_function_session.py b/bigframes/functions/_function_session.py index 85753a71ce..fe7889e955 100644 --- a/bigframes/functions/_function_session.py +++ b/bigframes/functions/_function_session.py @@ -556,34 +556,13 @@ def wrapper(func): func, **signature_kwargs, ) - if input_types is not None: - if not isinstance(input_types, collections.abc.Sequence): - input_types = [input_types] - if _utils.has_conflict_input_type(py_sig, input_types): - msg = bfe.format_message( - "Conflicting input types detected, using the one from the decorator." - ) - warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning) - py_sig = py_sig.replace( - parameters=[ - par.replace(annotation=itype) - for par, itype in zip(py_sig.parameters.values(), input_types) - ] - ) - if output_type: - if _utils.has_conflict_output_type(py_sig, output_type): - msg = bfe.format_message( - "Conflicting return type detected, using the one from the decorator." - ) - warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning) - py_sig = py_sig.replace(return_annotation=output_type) + py_sig = _resolve_signature(py_sig, input_types, output_type) remote_function_client = _function_client.FunctionClient( dataset_ref.project, bq_location, dataset_ref.dataset_id, bigquery_client, - bq_connection_id, bq_connection_manager, cloud_function_region, cloud_functions_client, @@ -618,6 +597,7 @@ def wrapper(func): cloud_function_memory_mib=cloud_function_memory_mib, cloud_function_cpus=cloud_function_cpus, cloud_function_ingress_settings=cloud_function_ingress_settings, + bq_connection_id=bq_connection_id, ) bigframes_cloud_function = ( @@ -840,27 +820,7 @@ def wrapper(func): func, **signature_kwargs, ) - if input_types is not None: - if not isinstance(input_types, collections.abc.Sequence): - input_types = [input_types] - if _utils.has_conflict_input_type(py_sig, input_types): - msg = bfe.format_message( - "Conflicting input types detected, using the one from the decorator." - ) - warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning) - py_sig = py_sig.replace( - parameters=[ - par.replace(annotation=itype) - for par, itype in zip(py_sig.parameters.values(), input_types) - ] - ) - if output_type: - if _utils.has_conflict_output_type(py_sig, output_type): - msg = bfe.format_message( - "Conflicting return type detected, using the one from the decorator." - ) - warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning) - py_sig = py_sig.replace(return_annotation=output_type) + py_sig = _resolve_signature(py_sig, input_types, output_type) # The function will actually be receiving a pandas Series, but allow # both BigQuery DataFrames and pandas object types for compatibility. @@ -872,22 +832,22 @@ def wrapper(func): bq_location, dataset_ref.dataset_id, bigquery_client, - bq_connection_id, bq_connection_manager, session=session, # type: ignore ) - - bq_function_name = managed_function_client.provision_bq_managed_function( - func=func, - input_types=tuple(arg.sql_type for arg in udf_sig.inputs), - output_type=udf_sig.output.sql_type, - name=name, - packages=packages, + config = udf_def.ManagedFunctionConfig( + code=udf_def.CodeDef.from_func(func), + signature=udf_sig, max_batching_rows=max_batching_rows, container_cpu=container_cpu, container_memory=container_memory, - is_row_processor=udf_sig.is_row_processor, bq_connection_id=bq_connection_id, + capture_references=False, + ) + + bq_function_name = managed_function_client.provision_bq_managed_function( + name=name, + config=config, ) full_rf_name = ( managed_function_client.get_remote_function_fully_qualilfied_name( @@ -907,12 +867,14 @@ def wrapper(func): if udf_sig.is_row_processor: msg = bfe.format_message("input_types=Series is in preview.") warnings.warn(msg, stacklevel=1, category=bfe.PreviewWarning) + assert session is not None # appease mypy return decorator( bq_functions.BigqueryCallableRowRoutine( udf_definition, session, local_func=func, is_managed=True ) ) else: + assert session is not None # appease mypy return decorator( bq_functions.BigqueryCallableRoutine( udf_definition, @@ -949,3 +911,33 @@ def deploy_udf( # TODO(tswast): If we update udf to defer deployment, update this method # to deploy immediately. return self.udf(**kwargs)(func) + + +def _resolve_signature( + py_sig: inspect.Signature, + input_types: Union[None, type, Sequence[type]] = None, + output_type: Optional[type] = None, +) -> inspect.Signature: + if input_types is not None: + if not isinstance(input_types, collections.abc.Sequence): + input_types = [input_types] + if _utils.has_conflict_input_type(py_sig, input_types): + msg = bfe.format_message( + "Conflicting input types detected, using the one from the decorator." + ) + warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning) + py_sig = py_sig.replace( + parameters=[ + par.replace(annotation=itype) + for par, itype in zip(py_sig.parameters.values(), input_types) + ] + ) + if output_type: + if _utils.has_conflict_output_type(py_sig, output_type): + msg = bfe.format_message( + "Conflicting return type detected, using the one from the decorator." + ) + warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning) + py_sig = py_sig.replace(return_annotation=output_type) + + return py_sig diff --git a/bigframes/functions/_utils.py b/bigframes/functions/_utils.py index c197ed14fc..e02cd94fb1 100644 --- a/bigframes/functions/_utils.py +++ b/bigframes/functions/_utils.py @@ -186,18 +186,6 @@ def routine_ref_to_string_for_query(routine_ref: bigquery.RoutineReference) -> s return f"`{routine_ref.project}.{routine_ref.dataset_id}`.{routine_ref.routine_id}" -def get_managed_function_name( - function_hash: str, - session_id: str | None = None, -): - """Get a name for the bigframes managed function for the given user defined function.""" - parts = [_BIGFRAMES_FUNCTION_PREFIX] - if session_id: - parts.append(session_id) - parts.append(function_hash) - return _BQ_FUNCTION_NAME_SEPERATOR.join(parts) - - # Deprecated: Use CodeDef.stable_hash() instead. def get_hash(def_, package_requirements=None): "Get hash (32 digits alphanumeric) of a function." diff --git a/bigframes/functions/function_template.py b/bigframes/functions/function_template.py index 31b5b20520..33a3688cf1 100644 --- a/bigframes/functions/function_template.py +++ b/bigframes/functions/function_template.py @@ -20,8 +20,6 @@ import re import textwrap -import cloudpickle - from bigframes.functions import udf_def logger = logging.getLogger(__name__) @@ -230,9 +228,10 @@ def generate_udf_code(code_def: udf_def.CodeDef, directory: str): udf_pickle_file_name = "udf.cloudpickle" # original code, only for debugging purpose - udf_code_file_path = os.path.join(directory, udf_code_file_name) - with open(udf_code_file_path, "w") as f: - f.write(code_def.function_source) + if code_def.function_source: + udf_code_file_path = os.path.join(directory, udf_code_file_name) + with open(udf_code_file_path, "w") as f: + f.write(code_def.function_source) # serialized udf udf_pickle_file_path = os.path.join(directory, udf_pickle_file_name) @@ -293,35 +292,37 @@ def generate_cloud_function_main_code( def generate_managed_function_code( - def_, - udf_name: str, - is_row_processor: bool, + code_def: udf_def.CodeDef, + signature: udf_def.UdfSignature, capture_references: bool, ) -> str: """Generates the Python code block for managed Python UDF.""" + udf_name = "unpickled_udf" if capture_references: # This code path ensures that if the udf body contains any # references to variables and/or imports outside the body, they are # captured as well. - pickled = cloudpickle.dumps(def_) func_code = textwrap.dedent( f""" import cloudpickle - {udf_name} = cloudpickle.loads({pickled}) + {udf_name} = cloudpickle.loads({code_def.pickled_code!r}) """ ) else: # This code path ensures that if the udf body is self contained, # i.e. there are no references to variables or imports outside the # body. - func_code = textwrap.dedent(inspect.getsource(def_)) + assert code_def.function_source is not None + assert code_def.entry_point is not None + func_code = code_def.function_source + udf_name = code_def.entry_point match = re.search(r"^def ", func_code, flags=re.MULTILINE) if match is None: raise ValueError("The UDF is not defined correctly.") func_code = func_code[match.start() :] - if is_row_processor: + if signature.is_row_processor: udf_code = textwrap.dedent(inspect.getsource(get_pd_series)) udf_code = udf_code[udf_code.index("def") :] bigframes_handler_code = textwrap.dedent( @@ -331,20 +332,19 @@ def bigframes_handler(str_arg): """ ) - sig = inspect.signature(def_) - params = list(sig.parameters.values()) + params = list(arg.name for arg in signature.inputs) additional_params = params[1:] # Build the parameter list for the new handler function definition. # e.g., "str_arg, y: bool, z" handler_def_parts = ["str_arg"] - handler_def_parts.extend(str(p) for p in additional_params) + handler_def_parts.extend(additional_params) handler_def_str = ", ".join(handler_def_parts) # Build the argument list for the call to the original UDF. # e.g., "get_pd_series(str_arg), y, z" udf_call_parts = [f"{get_pd_series.__name__}(str_arg)"] - udf_call_parts.extend(p.name for p in additional_params) + udf_call_parts.extend(additional_params) udf_call_str = ", ".join(udf_call_parts) bigframes_handler_code = textwrap.dedent( @@ -364,7 +364,7 @@ def bigframes_handler(*args): ) udf_code_block = [] - if not capture_references and is_row_processor: + if not capture_references and signature.is_row_processor: # Enable postponed evaluation of type annotations. This converts all # type hints to strings at runtime, which is necessary for correctly # handling the type annotation of pandas.Series after the UDF code is diff --git a/bigframes/functions/udf_def.py b/bigframes/functions/udf_def.py index f02f289ef6..3ebf2eeb47 100644 --- a/bigframes/functions/udf_def.py +++ b/bigframes/functions/udf_def.py @@ -19,7 +19,7 @@ import io import os import textwrap -from typing import Any, cast, get_args, get_origin, Sequence, Type +from typing import Any, cast, get_args, get_origin, Optional, Sequence, Type import warnings import cloudpickle @@ -401,18 +401,26 @@ class CodeDef: # Produced by cloudpickle, not compatible across python versions pickled_code: bytes # This is just the function itself, and does not include referenced objects/functions/modules - function_source: str + function_source: Optional[str] + entry_point: Optional[str] package_requirements: tuple[str, ...] @classmethod def from_func(cls, func, package_requirements: Sequence[str] | None = None): bytes_io = io.BytesIO() cloudpickle.dump(func, bytes_io, protocol=_pickle_protocol_version) - # this is hacky, but works for some nested functions - source = textwrap.dedent(inspect.getsource(func)) + source = None + entry_point = None + try: + # dedent is hacky, but works for some nested functions + source = textwrap.dedent(inspect.getsource(func)) + entry_point = func.__name__ + except OSError: + pass return cls( pickled_code=bytes_io.getvalue(), function_source=source, + entry_point=entry_point, package_requirements=tuple(package_requirements or []), ) @@ -448,6 +456,30 @@ def stable_hash(self) -> bytes: return hash_val.digest() +@dataclasses.dataclass(frozen=True) +class ManagedFunctionConfig: + code: CodeDef + signature: UdfSignature + max_batching_rows: Optional[int] + container_cpu: Optional[float] + container_memory: Optional[str] + bq_connection_id: Optional[str] + # capture_refernces=True -> deploy as cloudpickle + # capture_references=False -> deploy as source + capture_references: bool = False + + def stable_hash(self) -> bytes: + hash_val = google_crc32c.Checksum() + hash_val.update(self.code.stable_hash()) + hash_val.update(self.signature.stable_hash()) + hash_val.update(str(self.max_batching_rows).encode()) + hash_val.update(str(self.container_cpu).encode()) + hash_val.update(str(self.container_memory).encode()) + hash_val.update(str(self.bq_connection_id).encode()) + hash_val.update(str(self.capture_references).encode()) + return hash_val.digest() + + @dataclasses.dataclass(frozen=True) class CloudRunFunctionConfig: code: CodeDef diff --git a/bigframes/testing/utils.py b/bigframes/testing/utils.py index 5f4a8d2627..bd2fa41c5e 100644 --- a/bigframes/testing/utils.py +++ b/bigframes/testing/utils.py @@ -508,20 +508,6 @@ def cleanup_function_assets( pass -def get_function_name(func, package_requirements=None, is_row_processor=False): - """Get a bigframes function name for testing given a udf.""" - # Augment user package requirements with any internal package - # requirements. - package_requirements = bff_utils.get_updated_package_requirements( - package_requirements or [], is_row_processor - ) - - # Compute a unique hash representing the user code. - function_hash = bff_utils.get_hash(func, package_requirements) - - return f"bigframes_{function_hash}" - - def _apply_ops_to_sql( obj: bpd.DataFrame, ops_list: Sequence[ex.Expression], diff --git a/tests/system/small/functions/test_remote_function.py b/tests/system/small/functions/test_remote_function.py index 643f503c05..0a9875a989 100644 --- a/tests/system/small/functions/test_remote_function.py +++ b/tests/system/small/functions/test_remote_function.py @@ -34,15 +34,25 @@ from bigframes.functions import _utils as bff_utils from bigframes.functions import function as bff import bigframes.session._io.bigquery -from bigframes.testing.utils import ( - assert_frame_equal, - assert_series_equal, - get_function_name, -) +from bigframes.testing.utils import assert_frame_equal, assert_series_equal _prefixer = test_utils.prefixer.Prefixer("bigframes", "") +def get_function_name(func, package_requirements=None, is_row_processor=False): + """Get a bigframes function name for testing given a udf.""" + # Augment user package requirements with any internal package + # requirements. + package_requirements = bff_utils.get_updated_package_requirements( + package_requirements or [], is_row_processor + ) + + # Compute a unique hash representing the user code. + function_hash = bff_utils.get_hash(func, package_requirements) + + return f"bigframes_{function_hash}" + + @pytest.fixture(scope="module") def bq_cf_connection() -> str: """Pre-created BQ connection in the test project in US location, used to diff --git a/tests/unit/functions/test_remote_function_utils.py b/tests/unit/functions/test_remote_function_utils.py index dcf6058767..5ca26fe96f 100644 --- a/tests/unit/functions/test_remote_function_utils.py +++ b/tests/unit/functions/test_remote_function_utils.py @@ -188,78 +188,6 @@ def test_package_existed_helper(): assert not _utils._package_existed([], "pandas") -def _function_add_one(x): - return x + 1 - - -def _function_add_two(x): - return x + 2 - - -@pytest.mark.parametrize( - "func1, func2, should_be_equal, description", - [ - ( - _function_add_one, - _function_add_one, - True, - "Identical functions should have the same hash.", - ), - ( - _function_add_one, - _function_add_two, - False, - "Different functions should have different hashes.", - ), - ], -) -def test_get_hash_without_package_requirements( - func1, func2, should_be_equal, description -): - """Tests function hashes without any requirements.""" - hash1 = _utils.get_hash(func1) - hash2 = _utils.get_hash(func2) - - if should_be_equal: - assert hash1 == hash2, f"FAILED: {description}" - else: - assert hash1 != hash2, f"FAILED: {description}" - - -@pytest.mark.parametrize( - "reqs1, reqs2, should_be_equal, description", - [ - ( - None, - ["pandas>=1.0"], - False, - "Hash with or without requirements should differ from hash.", - ), - ( - ["pandas", "numpy", "scikit-learn"], - ["numpy", "scikit-learn", "pandas"], - True, - "Same requirements should produce the same hash.", - ), - ( - ["pandas==1.0"], - ["pandas==2.0"], - False, - "Different requirement versions should produce different hashes.", - ), - ], -) -def test_get_hash_with_package_requirements(reqs1, reqs2, should_be_equal, description): - """Tests how package requirements affect the final hash.""" - hash1 = _utils.get_hash(_function_add_one, package_requirements=reqs1) - hash2 = _utils.get_hash(_function_add_one, package_requirements=reqs2) - - if should_be_equal: - assert hash1 == hash2, f"FAILED: {description}" - else: - assert hash1 != hash2, f"FAILED: {description}" - - # Helper functions for signature inspection tests def _func_one_arg_annotated(x: int) -> int: """A function with one annotated arg and an annotated return type."""