Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 39 additions & 51 deletions bigframes/functions/_function_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

from __future__ import annotations

import inspect
import logging
import os
import random
Expand All @@ -25,7 +24,7 @@
import tempfile
import textwrap
import types
from typing import Any, cast, Optional, Sequence, TYPE_CHECKING
from typing import Any, cast, Optional, TYPE_CHECKING
import warnings

import requests
Expand Down Expand Up @@ -87,7 +86,6 @@ def __init__(
bq_location,
bq_dataset,
bq_client,
bq_connection_id,
bq_connection_manager,
cloud_function_region=None,
cloud_functions_client=None,
Expand All @@ -102,7 +100,6 @@ def __init__(
self._bq_location = bq_location
self._bq_dataset = bq_dataset
self._bq_client = bq_client
self._bq_connection_id = bq_connection_id
self._bq_connection_manager = bq_connection_manager
self._session = session

Expand All @@ -114,12 +111,12 @@ def __init__(
self._cloud_function_docker_repository = cloud_function_docker_repository
self._cloud_build_service_account = cloud_build_service_account

def _create_bq_connection(self) -> None:
def _create_bq_connection(self, connection_id: str) -> None:
if self._bq_connection_manager:
self._bq_connection_manager.create_bq_connection(
self._gcp_project_id,
self._bq_location,
self._bq_connection_id,
connection_id,
"run.invoker",
)

Expand Down Expand Up @@ -174,7 +171,7 @@ def create_bq_remote_function(
):
"""Create a BigQuery remote function given the artifacts of a user defined
function and the http endpoint of a corresponding cloud function."""
self._create_bq_connection()
self._create_bq_connection(udf_def.connection_id)

# Create BQ function
# https://cloud.google.com/bigquery/docs/reference/standard-sql/remote-functions#create_a_remote_function_2
Expand Down Expand Up @@ -202,7 +199,7 @@ def create_bq_remote_function(
create_function_ddl = f"""
CREATE OR REPLACE FUNCTION `{self._gcp_project_id}.{self._bq_dataset}`.{bq_function_name_escaped}({udf_def.signature.to_sql_input_signature()})
RETURNS {udf_def.signature.with_devirtualize().output.sql_type}
REMOTE WITH CONNECTION `{self._gcp_project_id}.{self._bq_location}.{self._bq_connection_id}`
REMOTE WITH CONNECTION `{self._gcp_project_id}.{self._bq_location}.{udf_def.connection_id}`
OPTIONS ({remote_function_options_str})"""

logger.info(f"Creating BQ remote function: {create_function_ddl}")
Expand All @@ -212,26 +209,15 @@ def create_bq_remote_function(

def provision_bq_managed_function(
self,
func,
input_types: Sequence[str],
output_type: str,
name: Optional[str],
packages: Optional[Sequence[str]],
max_batching_rows: Optional[int],
container_cpu: Optional[float],
container_memory: Optional[str],
is_row_processor: bool,
bq_connection_id,
*,
capture_references: bool = False,
config: udf_def.ManagedFunctionConfig,
):
"""Create a BigQuery managed function."""

# TODO(b/406283812): Expose the capability to pass down
# capture_references=True in the public udf API.
# TODO(b/495508827): Include all config in the value hash.
if (
capture_references
config.capture_references
and (python_version := _utils.get_python_version())
!= _MANAGED_FUNC_PYTHON_VERSION
):
Expand All @@ -241,31 +227,26 @@ def provision_bq_managed_function(
)

# Create BQ managed function.
bq_function_args = []
bq_function_return_type = output_type

input_args = inspect.getargs(func.__code__).args
# We expect the input type annotations to be 1:1 with the input args.
for name_, type_ in zip(input_args, input_types):
bq_function_args.append(f"{name_} {type_}")
bq_function_args = config.signature.to_sql_input_signature()
bq_function_return_type = config.signature.with_devirtualize().output.sql_type

managed_function_options: dict[str, Any] = {
"runtime_version": _MANAGED_FUNC_PYTHON_VERSION,
"entry_point": "bigframes_handler",
}
if max_batching_rows:
managed_function_options["max_batching_rows"] = max_batching_rows
if container_cpu:
managed_function_options["container_cpu"] = container_cpu
if container_memory:
managed_function_options["container_memory"] = container_memory
if config.max_batching_rows:
managed_function_options["max_batching_rows"] = config.max_batching_rows
if config.container_cpu:
managed_function_options["container_cpu"] = config.container_cpu
if config.container_memory:
managed_function_options["container_memory"] = config.container_memory

# Augment user package requirements with any internal package
# requirements.
packages = _utils.get_updated_package_requirements(
packages or [],
is_row_processor,
capture_references,
config.code.package_requirements or [],
config.signature.is_row_processor,
config.capture_references,
ignore_package_version=True,
)
if packages:
Expand All @@ -276,40 +257,34 @@ def provision_bq_managed_function(

bq_function_name = name
if not bq_function_name:
# Compute a unique hash representing the user code.
function_hash = _utils.get_hash(func, packages)
bq_function_name = _utils.get_managed_function_name(
function_hash,
# session-scope in absence of a name from the user
# name indicates permanent allocation
None if name else self._session.session_id,
# Compute a unique hash representing the artifact definition.
bq_function_name = get_managed_function_name(
config, self._session.session_id
)

persistent_func_id = (
f"`{self._gcp_project_id}.{self._bq_dataset}`.{bq_function_name}"
)

udf_name = func.__name__

with_connection_clause = (
(
f"WITH CONNECTION `{self._gcp_project_id}.{self._bq_location}.{self._bq_connection_id}`"
f"WITH CONNECTION `{self._gcp_project_id}.{self._bq_location}.{config.bq_connection_id}`"
)
if bq_connection_id
if config.bq_connection_id
else ""
)

# Generate the complete Python code block for the managed Python UDF,
# including the user's function, necessary imports, and the BigQuery
# handler wrapper.
python_code_block = bff_template.generate_managed_function_code(
func, udf_name, is_row_processor, capture_references
config.code, config.signature, config.capture_references
)

create_function_ddl = (
textwrap.dedent(
f"""
CREATE OR REPLACE FUNCTION {persistent_func_id}({','.join(bq_function_args)})
CREATE OR REPLACE FUNCTION {persistent_func_id}({bq_function_args})
RETURNS {bq_function_return_type}
LANGUAGE python
{with_connection_clause}
Expand Down Expand Up @@ -590,6 +565,7 @@ def provision_bq_remote_function(
cloud_function_memory_mib: int | None,
cloud_function_cpus: float | None,
cloud_function_ingress_settings: str,
bq_connection_id: str,
):
"""Provision a BigQuery remote function."""
# Augment user package requirements with any internal package
Expand Down Expand Up @@ -657,7 +633,7 @@ def provision_bq_remote_function(

intended_rf_spec = udf_def.RemoteFunctionConfig(
endpoint=cf_endpoint,
connection_id=self._bq_connection_id,
connection_id=bq_connection_id,
max_batching_rows=max_batching_rows or 1000,
signature=func_signature,
bq_metadata=func_signature.protocol_metadata,
Expand Down Expand Up @@ -731,6 +707,18 @@ def get_bigframes_function_name(
return _BQ_FUNCTION_NAME_SEPERATOR.join(parts)


def get_managed_function_name(
    function_def: udf_def.ManagedFunctionConfig,
    session_id: str | None = None,
):
    """Build the BigQuery routine name for a bigframes managed function.

    The name joins the bigframes prefix, the session id (only when the
    function is session-scoped), and a stable hash of the function
    definition, using the module-level name separator.
    """
    artifact_hash = function_def.stable_hash().hex()
    if session_id:
        pieces = (_BIGFRAMES_FUNCTION_PREFIX, session_id, artifact_hash)
    else:
        pieces = (_BIGFRAMES_FUNCTION_PREFIX, artifact_hash)
    return _BQ_FUNCTION_NAME_SEPERATOR.join(pieces)


def _validate_routine_name(name: str) -> None:
"""Validate that the given name is a valid BigQuery routine name."""
# Routine IDs can contain only letters (a-z, A-Z), numbers (0-9), or underscores (_)
Expand Down
96 changes: 44 additions & 52 deletions bigframes/functions/_function_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,34 +556,13 @@ def wrapper(func):
func,
**signature_kwargs,
)
if input_types is not None:
if not isinstance(input_types, collections.abc.Sequence):
input_types = [input_types]
if _utils.has_conflict_input_type(py_sig, input_types):
msg = bfe.format_message(
"Conflicting input types detected, using the one from the decorator."
)
warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning)
py_sig = py_sig.replace(
parameters=[
par.replace(annotation=itype)
for par, itype in zip(py_sig.parameters.values(), input_types)
]
)
if output_type:
if _utils.has_conflict_output_type(py_sig, output_type):
msg = bfe.format_message(
"Conflicting return type detected, using the one from the decorator."
)
warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning)
py_sig = py_sig.replace(return_annotation=output_type)
py_sig = _resolve_signature(py_sig, input_types, output_type)

remote_function_client = _function_client.FunctionClient(
dataset_ref.project,
bq_location,
dataset_ref.dataset_id,
bigquery_client,
bq_connection_id,
bq_connection_manager,
cloud_function_region,
cloud_functions_client,
Expand Down Expand Up @@ -618,6 +597,7 @@ def wrapper(func):
cloud_function_memory_mib=cloud_function_memory_mib,
cloud_function_cpus=cloud_function_cpus,
cloud_function_ingress_settings=cloud_function_ingress_settings,
bq_connection_id=bq_connection_id,
)

bigframes_cloud_function = (
Expand Down Expand Up @@ -840,27 +820,7 @@ def wrapper(func):
func,
**signature_kwargs,
)
if input_types is not None:
if not isinstance(input_types, collections.abc.Sequence):
input_types = [input_types]
if _utils.has_conflict_input_type(py_sig, input_types):
msg = bfe.format_message(
"Conflicting input types detected, using the one from the decorator."
)
warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning)
py_sig = py_sig.replace(
parameters=[
par.replace(annotation=itype)
for par, itype in zip(py_sig.parameters.values(), input_types)
]
)
if output_type:
if _utils.has_conflict_output_type(py_sig, output_type):
msg = bfe.format_message(
"Conflicting return type detected, using the one from the decorator."
)
warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning)
py_sig = py_sig.replace(return_annotation=output_type)
py_sig = _resolve_signature(py_sig, input_types, output_type)

# The function will actually be receiving a pandas Series, but allow
# both BigQuery DataFrames and pandas object types for compatibility.
Expand All @@ -872,22 +832,22 @@ def wrapper(func):
bq_location,
dataset_ref.dataset_id,
bigquery_client,
bq_connection_id,
bq_connection_manager,
session=session, # type: ignore
)

bq_function_name = managed_function_client.provision_bq_managed_function(
func=func,
input_types=tuple(arg.sql_type for arg in udf_sig.inputs),
output_type=udf_sig.output.sql_type,
name=name,
packages=packages,
config = udf_def.ManagedFunctionConfig(
code=udf_def.CodeDef.from_func(func),
signature=udf_sig,
max_batching_rows=max_batching_rows,
container_cpu=container_cpu,
container_memory=container_memory,
is_row_processor=udf_sig.is_row_processor,
bq_connection_id=bq_connection_id,
capture_references=False,
)

bq_function_name = managed_function_client.provision_bq_managed_function(
name=name,
config=config,
)
full_rf_name = (
managed_function_client.get_remote_function_fully_qualilfied_name(
Expand All @@ -907,12 +867,14 @@ def wrapper(func):
if udf_sig.is_row_processor:
msg = bfe.format_message("input_types=Series is in preview.")
warnings.warn(msg, stacklevel=1, category=bfe.PreviewWarning)
assert session is not None # appease mypy
return decorator(
bq_functions.BigqueryCallableRowRoutine(
udf_definition, session, local_func=func, is_managed=True
)
)
else:
assert session is not None # appease mypy
return decorator(
bq_functions.BigqueryCallableRoutine(
udf_definition,
Expand Down Expand Up @@ -949,3 +911,33 @@ def deploy_udf(
# TODO(tswast): If we update udf to defer deployment, update this method
# to deploy immediately.
return self.udf(**kwargs)(func)


def _resolve_signature(
    py_sig: inspect.Signature,
    input_types: Union[None, type, Sequence[type]] = None,
    output_type: Optional[type] = None,
) -> inspect.Signature:
    """Overlay decorator-supplied input/output types onto a python signature.

    Types passed explicitly to the decorator take precedence over the
    function's own annotations; a FunctionConflictTypeHintWarning is emitted
    whenever the two disagree.
    """
    if input_types is not None:
        # A single bare type is treated as a one-element sequence.
        if not isinstance(input_types, collections.abc.Sequence):
            input_types = [input_types]
        if _utils.has_conflict_input_type(py_sig, input_types):
            warnings.warn(
                bfe.format_message(
                    "Conflicting input types detected, using the one from the decorator."
                ),
                category=bfe.FunctionConflictTypeHintWarning,
            )
        # Re-annotate each parameter, pairing them positionally with the
        # decorator-provided input types.
        updated_params = [
            param.replace(annotation=annotation)
            for param, annotation in zip(py_sig.parameters.values(), input_types)
        ]
        py_sig = py_sig.replace(parameters=updated_params)
    if output_type:
        if _utils.has_conflict_output_type(py_sig, output_type):
            warnings.warn(
                bfe.format_message(
                    "Conflicting return type detected, using the one from the decorator."
                ),
                category=bfe.FunctionConflictTypeHintWarning,
            )
        py_sig = py_sig.replace(return_annotation=output_type)

    return py_sig
12 changes: 0 additions & 12 deletions bigframes/functions/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,18 +186,6 @@ def routine_ref_to_string_for_query(routine_ref: bigquery.RoutineReference) -> s
return f"`{routine_ref.project}.{routine_ref.dataset_id}`.{routine_ref.routine_id}"


def get_managed_function_name(
function_hash: str,
session_id: str | None = None,
):
"""Get a name for the bigframes managed function for the given user defined function."""
parts = [_BIGFRAMES_FUNCTION_PREFIX]
if session_id:
parts.append(session_id)
parts.append(function_hash)
return _BQ_FUNCTION_NAME_SEPERATOR.join(parts)


# Deprecated: Use CodeDef.stable_hash() instead.
def get_hash(def_, package_requirements=None):
"Get hash (32 digits alphanumeric) of a function."
Expand Down
Loading
Loading