diff --git a/playwright/_impl/_browser_type.py b/playwright/_impl/_browser_type.py index f78ba82c0..e4f32c138 100644 --- a/playwright/_impl/_browser_type.py +++ b/playwright/_impl/_browser_type.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio from pathlib import Path from typing import Dict, List, Optional, Union, cast @@ -179,11 +180,24 @@ async def connect( self._connection._object_factory, transport, ) + await connection._transport.start() connection._is_sync = self._connection._is_sync connection._loop = self._connection._loop connection._loop.create_task(connection.run()) + obj = asyncio.create_task( + connection.wait_for_object_with_known_name("Playwright") + ) + done, pending = await asyncio.wait( + { + obj, + connection._transport.on_error_future, # type: ignore + }, + return_when=asyncio.FIRST_COMPLETED, + ) + if not obj.done(): + obj.cancel() + playwright = next(iter(done)).result() self._connection._child_ws_connections.append(connection) - playwright = await connection.wait_for_object_with_known_name("Playwright") pre_launched_browser = playwright._initializer.get("preLaunchedBrowser") assert pre_launched_browser browser = cast(Browser, from_channel(pre_launched_browser)) diff --git a/playwright/_impl/_transport.py b/playwright/_impl/_transport.py index b315318b3..f67805ebd 100644 --- a/playwright/_impl/_transport.py +++ b/playwright/_impl/_transport.py @@ -42,6 +42,7 @@ def _get_stderr_fileno() -> Optional[int]: class Transport(ABC): def __init__(self) -> None: + self.on_error_future: asyncio.Future self.on_message = lambda _: None @abstractmethod @@ -55,9 +56,14 @@ def dispose(self) -> None: async def wait_until_stopped(self) -> None: pass - async def run(self) -> None: + async def start(self) -> None: + if not hasattr(self, "on_error_future"): + self.on_error_future = asyncio.Future() self._loop = asyncio.get_running_loop() - self.on_error_future: asyncio.Future = asyncio.Future() + + @abstractmethod + async def run(self) -> None: + pass @abstractmethod def send(self, message: Dict) -> None: @@ -93,17 +99,28 @@ async def wait_until_stopped(self) -> None: await self._proc.wait() async def run(self) -> None: - await super().run() + await self.start() self._stopped_future: asyncio.Future = asyncio.Future() - self._proc = proc = await asyncio.create_subprocess_exec( - str(self._driver_executable), - "run-driver", - stdin=asyncio.subprocess.PIPE, - stdout=asyncio.subprocess.PIPE, - stderr=_get_stderr_fileno(), - limit=32768, - ) + try: + self._proc = proc = await asyncio.create_subprocess_exec( + str(self._driver_executable), + "run-driver", + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=_get_stderr_fileno(), + limit=32768, + ) + except FileNotFoundError: + self.on_error_future.set_exception( + Error( + "playwright's driver is not found, You can read the contributing guide " + "for some guidance on how to get everything setup for working on the code " + "https://github.com/microsoft/playwright-python/blob/master/CONTRIBUTING.md" + ) + ) + return + assert proc.stdout assert proc.stdin self._output = proc.stdin @@ -160,15 +177,22 @@ async def wait_until_stopped(self) -> None: await self._connection.wait_closed() async def run(self) -> None: - await super().run() + await self.start() options: Dict[str, Any] = {} if self.timeout is not None: options["close_timeout"] = self.timeout / 1000 options["ping_timeout"] = self.timeout / 1000 + if self.headers is not None: options["extra_headers"] = self.headers - self._connection = await websockets.connect(self.ws_endpoint, **options) + try: + self._connection = await websockets.connect(self.ws_endpoint, **options) + except Exception as err: + self.on_error_future.set_exception( + Error(f"playwright's websocket endpoint connection error: {err}") + ) + return while not self._stopped: try: diff --git a/playwright/async_api/_context_manager.py b/playwright/async_api/_context_manager.py index 51800dd48..f8c7cd50b 100644 --- a/playwright/async_api/_context_manager.py +++ b/playwright/async_api/_context_manager.py @@ -32,10 +32,22 @@ async def __aenter__(self) -> AsyncPlaywright: ) loop = asyncio.get_running_loop() self._connection._loop = loop + obj = asyncio.create_task( + self._connection.wait_for_object_with_known_name("Playwright") + ) + await self._connection._transport.start() loop.create_task(self._connection.run()) - playwright = AsyncPlaywright( - await self._connection.wait_for_object_with_known_name("Playwright") + done, pending = await asyncio.wait( + { + obj, + self._connection._transport.on_error_future, # type: ignore + }, + return_when=asyncio.FIRST_COMPLETED, ) + if not obj.done(): + obj.cancel() + obj = next(iter(done)).result() + playwright = AsyncPlaywright(obj) # type: ignore playwright.stop = self.__aexit__ # type: ignore return playwright diff --git a/tests/async/test_browsertype_connect.py b/tests/async/test_browsertype_connect.py index 842b6c1ee..55266035f 100644 --- a/tests/async/test_browsertype_connect.py +++ b/tests/async/test_browsertype_connect.py @@ -182,3 +182,13 @@ async def test_prevent_getting_video_path( == "Path is not available when using browserType.connect(). Use save_as() to save a local copy." ) remote_server.kill() + + +async def test_connect_to_closed_server_without_hangs( + browser_type: BrowserType, launch_server +): + remote_server = launch_server() + remote_server.kill() + with pytest.raises(Error) as exc: + await browser_type.connect(remote_server.ws_endpoint) + assert "playwright's websocket endpoint connection error" in exc.value.message diff --git a/tests/sync/test_browsertype_connect.py b/tests/sync/test_browsertype_connect.py index d698d8da1..3fa95a879 100644 --- a/tests/sync/test_browsertype_connect.py +++ b/tests/sync/test_browsertype_connect.py @@ -145,3 +145,13 @@ def test_browser_type_connect_should_forward_close_events_to_pages( assert events == ["page::close", "context::close", "browser::disconnected"] remote.kill() assert events == ["page::close", "context::close", "browser::disconnected"] + + +def test_connect_to_closed_server_without_hangs( + browser_type: BrowserType, launch_server +): + remote_server = launch_server() + remote_server.kill() + with pytest.raises(Error) as exc: + browser_type.connect(remote_server.ws_endpoint) + assert "playwright's websocket endpoint connection error" in exc.value.message