diff --git a/.coveragerc b/.coveragerc index b2713c796..e2ccc2cae 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,3 +1,2 @@ [run] -parallel=True source=pgcli diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 35e8486bf..52c903d80 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -8,5 +8,5 @@ - [ ] I've added this contribution to the `changelog.rst`. - [ ] I've added my name to the `AUTHORS` file (or it's already there). -- [ ] I installed pre-commit hooks (`pip install pre-commit && pre-commit install`), and ran `black` on my code. +- [ ] I installed pre-commit hooks (`pip install pre-commit && pre-commit install`). - [x] Please squash merge this pull request (uncheck if you'd like us to merge as multiple commits) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ce54d6f59..ac5b3dae1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,6 +1,9 @@ name: pgcli on: + push: + branches: + - main pull_request: paths-ignore: - '**.rst' @@ -11,11 +14,11 @@ jobs: strategy: matrix: - python-version: [3.6, 3.7, 3.8, 3.9] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] services: postgres: - image: postgres:9.6 + image: postgres:10 env: POSTGRES_USER: postgres POSTGRES_PASSWORD: postgres @@ -28,39 +31,62 @@ jobs: --health-retries 5 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0 + with: + version: "latest" - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 with: python-version: ${{ matrix.python-version }} - - name: Install requirements + - name: Install pgbouncer run: | - pip install -U pip setuptools - pip install 
--no-cache-dir . - pip install -r requirements-dev.txt - pip install keyrings.alt>=3.1 + sudo apt install pgbouncer -y + + sudo chmod 666 /etc/pgbouncer/*.* + + cat <<EOF > /etc/pgbouncer/userlist.txt + "postgres" "postgres" + EOF + + cat <<EOF > /etc/pgbouncer/pgbouncer.ini + [databases] + * = host=localhost port=5432 + [pgbouncer] + listen_port = 6432 + listen_addr = localhost + auth_type = trust + auth_file = /etc/pgbouncer/userlist.txt + logfile = pgbouncer.log + pidfile = pgbouncer.pid + admin_users = postgres + EOF + + sudo systemctl stop pgbouncer + + pgbouncer -d /etc/pgbouncer/pgbouncer.ini + + psql -h localhost -U postgres -p 6432 pgbouncer -c 'show help' + + - name: Install requirements + run: uv sync --all-extras -p ${{ matrix.python-version }} - name: Run unit tests - run: coverage run --source pgcli -m py.test + run: uv run tox -e py${{ matrix.python-version }} - name: Run integration tests env: PGUSER: postgres PGPASSWORD: postgres + TERM: xterm - run: behave tests/features --no-capture + run: uv run tox -e integration - name: Check changelog for ReST compliance - run: rst2html.py --halt=warning changelog.rst >/dev/null - - - name: Run Black - run: pip install black && black --check . 
- if: matrix.python-version == '3.6' + run: uv run tox -e rest - - name: Coverage - run: | - coverage combine - coverage report - codecov + - name: Run style checks + run: uv run tox -e style diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml new file mode 100644 index 000000000..c9232c711 --- /dev/null +++ b/.github/workflows/codeql.yml @@ -0,0 +1,41 @@ +name: "CodeQL" + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + schedule: + - cron: "29 13 * * 1" + +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + permissions: + actions: read + contents: read + security-events: write + + strategy: + fail-fast: false + matrix: + language: [ python ] + + steps: + - name: Checkout + uses: actions/checkout@v3 + + - name: Initialize CodeQL + uses: github/codeql-action/init@v2 + with: + languages: ${{ matrix.language }} + queries: +security-and-quality + + - name: Autobuild + uses: github/codeql-action/autobuild@v2 + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v2 + with: + category: "/language:${{ matrix.language }}" diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 000000000..8b9d5728e --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,97 @@ +name: Publish Python Package + +on: + release: + types: [created] + +permissions: + contents: read + +jobs: + test: + runs-on: ubuntu-latest + + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + + services: + postgres: + image: postgres:10 + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + ports: + - 5432:5432 + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # 
v6.0.0 + with: + version: "latest" + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: uv sync --all-extras -p ${{ matrix.python-version }} + + - name: Run unit tests + env: + LANG: en_US.UTF-8 + run: uv run tox -e py${{ matrix.python-version }} + + - name: Run Style Checks + run: uv run tox -e style + + build: + runs-on: ubuntu-latest + needs: [test] + + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0 + with: + version: "latest" + + - name: Set up Python + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 + with: + python-version: '3.13' + + - name: Install dependencies + run: uv sync --all-extras -p 3.13 + + - name: Build + run: uv build + + - name: Store the distribution packages + uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + with: + name: python-packages + path: dist/ + + publish: + name: Publish to PyPI + runs-on: ubuntu-latest + if: startsWith(github.ref, 'refs/tags/') + needs: [build] + environment: release + permissions: + id-token: write + steps: + - name: Download distribution packages + uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0 + with: + name: python-packages + path: dist/ + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # v1.12.4 \ No newline at end of file diff --git a/.gitignore b/.gitignore index 170585df2..1437096ab 100644 --- a/.gitignore +++ b/.gitignore @@ -42,6 +42,7 @@ htmlcov/ nosetests.xml coverage.xml .pytest_cache +tests/behave.ini # Translations *.mo @@ -69,3 +70,7 @@ target/ .vscode/ venv/ + +.ropeproject/ +uv.lock + diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 
9e27ab8ec..9284586b1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,10 @@ repos: -- repo: https://github.com/psf/black - rev: 21.5b0 - hooks: - - id: black - language_version: python3.7 - +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.14.10 + hooks: + # Run the linter. + - id: ruff + args: [ --fix ] + # Run the formatter. + - id: ruff-format diff --git a/AUTHORS b/AUTHORS index bcfba6a33..79c650c5d 100644 --- a/AUTHORS +++ b/AUTHORS @@ -116,8 +116,36 @@ Contributors: * Kevin Marsh (kevinmarsh) * Eero Ruohola (ruohola) * Miroslav Šedivý (eumiro) - * Eric R Young (ERYoung11) + * Eric R Young (ERYoung11) * Paweł Sacawa (psacawa) + * Bruno Inec (sweenu) + * Daniele Varrazzo + * Daniel Kukula (dkuku) + * Kian-Meng Ang (kianmeng) + * Liu Zhao (astroshot) + * Rigo Neri (rigoneri) + * Anna Glasgall (annathyst) + * Andy Schoenberger (andyscho) + * Damien Baty (dbaty) + * blag + * Rob Berry (rob-b) + * Sharon Yogev (sharonyogev) + * Hollis Wu (holi0317) + * Antonio Aguilar (crazybolillo) + * Andrew M. MacFie (amacfie) + * saucoide + * Chris Rose (offbyone/offby1) + * Mathieu Dupuy (deronnax) + * Chris Novakovic + * Max Smolin (maximsmol) + * Josh Lynch (josh-lynch) + * Fabio (3ximus) + * Doug Harris (dougharris) + * Jay Knight (jay-knight) + * fbdb + * Charbel Jacquin (charbeljc) + * Devadathan M B (devadathanmb) + * Charalampos Stratakis Creator: -------- diff --git a/DEVELOP.rst b/CONTRIBUTING.rst similarity index 70% rename from DEVELOP.rst rename to CONTRIBUTING.rst index e262823d9..ad7eb5bdc 100644 --- a/DEVELOP.rst +++ b/CONTRIBUTING.rst @@ -23,8 +23,8 @@ repo. $ git remote add upstream git@github.com:dbcli/pgcli.git Once the 'upstream' end point is added you can then periodically do a ``git -pull upstream master`` to update your local copy and then do a ``git push -origin master`` to keep your own fork up to date. 
+pull upstream main`` to update your local copy and then do a ``git push +origin main`` to keep your own fork up to date. Check Github's `Understanding the GitHub flow guide `_ for a more detailed @@ -38,26 +38,23 @@ pgcli. If you're developing pgcli, you'll need to install it in a slightly different way so you can see the effects of your changes right away without having to go through the install cycle every time you change the code. -It is highly recommended to use virtualenv for development. If you don't know -what a virtualenv is, `this guide `_ -will help you get started. - -Create a virtualenv (let's call it pgcli-dev). Activate it: +Set up `uv <https://docs.astral.sh/uv/getting-started/installation/>`_ for development: :: + cd pgcli + uv venv source ./pgcli-dev/bin/activate -Once the virtualenv is activated, `cd` into the local clone of pgcli folder -and install pgcli using pip as follows: +Once the virtualenv is activated, install pgcli using pip as follows: :: - $ pip install --editable . + $ uv pip install --editable . or - $ pip install -e . + $ uv pip install -e . This will install the necessary dependencies as well as install pgcli from the working folder into the virtualenv. By installing it using `pip install -e` @@ -73,6 +70,37 @@ If you want to work on adding new meta-commands (such as `\dp`, `\ds`, `dy`), you need to contribute to `pgspecial `_ project. +Visual Studio Code Debugging +----------------------------- +To set up Visual Studio Code to debug pgcli requires a launch.json file. + +Within the project, create a file: .vscode\\launch.json like below. + +:: + + { + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. 
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python: Module", + "type": "python", + "request": "launch", + "module": "pgcli.main", + "justMyCode": false, + "console": "externalTerminal", + "env": { + "PGUSER": "postgres", + "PGPASS": "password", + "PGHOST": "localhost", + "PGPORT": "5432" + } + } + ] + } + Building RPM and DEB packages ----------------------------- @@ -130,8 +158,7 @@ in the ``tests`` directory. An example:: First, install the requirements for testing: :: - - $ pip install -r requirements-dev.txt + $ uv pip install ".[dev]" Ensure that the database user has permissions to create and drop test databases by checking your ``pg_hba.conf`` file. The default user should be ``postgres`` @@ -144,19 +171,14 @@ service for the changes to take effect. # ONLY IF YOU MADE CHANGES TO YOUR pg_hba.conf FILE $ sudo service postgresql restart -After that, tests in the ``/pgcli/tests`` directory can be run with: +After that: :: - # on directory /pgcli/tests + $ cd pgcli/tests $ behave -And on the ``/pgcli`` directory: - -:: - - # on directory /pgcli - $ py.test +Note that these ``behave`` tests do not currently work when developing on Windows due to pexpect incompatibility. To see stdout/stderr, use the following command: @@ -172,7 +194,23 @@ Troubleshooting the integration tests - Check `this issue `_ for relevant information. - `File an issue `_. +Running the unit tests +---------------------- + +The unit tests can be run with pytest: + +:: + + $ cd pgcli + $ pytest + + Coding Style ------------ -``pgcli`` uses `black `_ to format the source code. Make sure to install black. +``pgcli`` uses `ruff `_ to format the source code. + +Releases +-------- + +If you're the person responsible for releasing `pgcli`, `this guide `_ is for you. 
diff --git a/README.rst b/README.rst index 95137f7c5..6dc4f9e8a 100644 --- a/README.rst +++ b/README.rst @@ -1,7 +1,30 @@ +We stand with Ukraine +--------------------- + +Ukrainian people are fighting for their country. A lot of civilians, women and children, are suffering. Hundreds were killed and injured, and thousands were displaced. + +This is an image from my home town, Kharkiv. This place is right in the old city center. + +.. image:: screenshots/kharkiv-destroyed.jpg + +Picture by @fomenko_ph (Telegram). + +Please consider donating or volunteering. + +* https://bank.gov.ua/en/ +* https://savelife.in.ua/en/donate/ +* https://www.comebackalive.in.ua/donate +* https://www.globalgiving.org/projects/ukraine-crisis-relief-fund/ +* https://www.savethechildren.org/us/where-we-work/ukraine +* https://www.facebook.com/donate/1137971146948461/ +* https://donate.wck.org/give/393234#!/donation/checkout +* https://atlantaforukraine.com/ + + A REPL for Postgres ------------------- -|Build Status| |CodeCov| |PyPI| |Landscape| +|Build Status| |CodeCov| |PyPI| |netlify| This is a postgres client that does auto-completion and syntax highlighting. @@ -29,10 +52,7 @@ If you already know how to install python packages, then you can simply do: If you don't know how to install python packages, please check the `detailed instructions`_. -If you are restricted to using psycopg2 2.7.x then pip will try to install it from a binary. There are some known issues with the psycopg2 2.7 binary - see the `psycopg docs`_ for more information about this and how to force installation from source. psycopg2 2.8 has fixed these problems, and will build from source. - .. _`detailed instructions`: https://github.com/dbcli/pgcli#detailed-installation-instructions -.. 
_`psycopg docs`: http://initd.org/psycopg/docs/install.html#change-in-binary-packages-between-psycopg-2-7-and-2-8 Usage ----- @@ -135,10 +155,12 @@ If you're interested in contributing to this project, first of all I would like to extend my heartfelt gratitude. I've written a small doc to describe how to get this running in a development setup. -https://github.com/dbcli/pgcli/blob/master/DEVELOP.rst +https://github.com/dbcli/pgcli/blob/main/CONTRIBUTING.rst -Please feel free to reach out to me if you need help. -My email: amjith.r@gmail.com, Twitter: `@amjithr `_ +Please feel free to reach out to us if you need help. + +* Amjith, pgcli author: amjith.r@gmail.com, Twitter: `@amjithr `_ +* Irina, pgcli maintainer: i.chernyavska@gmail.com, Twitter: `@irinatruong `_ Detailed Installation Instructions: ----------------------------------- @@ -158,7 +180,7 @@ Alternatively, you can install ``pgcli`` as a python package using a package manager called called ``pip``. You will need postgres installed on your system for this to work. -In depth getting started guide for ``pip`` - https://pip.pypa.io/en/latest/installing.html. +In depth getting started guide for ``pip`` - https://pip.pypa.io/en/latest/installation/ :: @@ -188,43 +210,27 @@ If pip is not installed check if easy_install is available on the system. Linux: ====== -In depth getting started guide for ``pip`` - https://pip.pypa.io/en/latest/installing.html. +Many distributions have ``pgcli`` packages. +Refer to https://repology.org/project/pgcli/versions or your distribution to check the available versions. -Check if pip is already available in your system. +Alternatively, you can use tools such as `pipx`_ or `uvx`_ to install the latest published package to an isolated virtual environment. -:: +.. _pipx: https://pipx.pypa.io/ +.. _uvx: https://docs.astral.sh/uv/guides/tools/ - $ which pip - -If it doesn't exist, use your linux package manager to install `pip`. 
This -might look something like: +Run: :: - $ sudo apt-get install python-pip # Debian, Ubuntu, Mint etc - - or - - $ sudo yum install python-pip # RHEL, Centos, Fedora etc - -``pgcli`` requires python-dev, libpq-dev and libevent-dev packages. You can -install these via your operating system package manager. + $ pipx install pgcli +to install ``pgcli`` with ``pipx``, or run: :: - $ sudo apt-get install python-dev libpq-dev libevent-dev - - or - - $ sudo yum install python-devel postgresql-devel - -Then you can install pgcli: - -:: - - $ sudo pip install pgcli + $ uvx pgcli +to run ``pgcli`` by installing on the fly with ``uvx``. Docker ====== @@ -331,8 +337,10 @@ choice: In [3]: my_result = _ -Pgcli only runs on Python3.6+ since 2.2.0, if you use an old version of Python, -you should use install ``pgcli <= 2.2.0``. +Pgcli dropped support for: + +* Python<3.8 as of 4.0.0. +* Python<3.9 as of 4.2.0. Thanks: ------- @@ -346,23 +354,27 @@ of this app. `Click `_ is used for command line option parsing and printing error messages. -Thanks to `psycopg `_ for providing a rock solid +Thanks to `psycopg `_ for providing a rock solid interface to Postgres database. Thanks to all the beta testers and contributors for your time and patience. :) -.. |Build Status| image:: https://github.com/dbcli/pgcli/workflows/pgcli/badge.svg - :target: https://github.com/dbcli/pgcli/actions?query=workflow%3Apgcli +.. |Build Status| image:: https://github.com/dbcli/pgcli/actions/workflows/ci.yml/badge.svg?branch=main + :target: https://github.com/dbcli/pgcli/actions/workflows/ci.yml -.. |CodeCov| image:: https://codecov.io/gh/dbcli/pgcli/branch/master/graph/badge.svg +.. |CodeCov| image:: https://codecov.io/gh/dbcli/pgcli/branch/main/graph/badge.svg :target: https://codecov.io/gh/dbcli/pgcli :alt: Code coverage report -.. |Landscape| image:: https://landscape.io/github/dbcli/pgcli/master/landscape.svg?style=flat - :target: https://landscape.io/github/dbcli/pgcli/master +.. 
|Landscape| image:: https://landscape.io/github/dbcli/pgcli/main/landscape.svg?style=flat + :target: https://landscape.io/github/dbcli/pgcli/main :alt: Code Health .. |PyPI| image:: https://img.shields.io/pypi/v/pgcli.svg :target: https://pypi.python.org/pypi/pgcli/ :alt: Latest Version + +.. |netlify| image:: https://api.netlify.com/api/v1/badges/3a0a14dd-776d-445d-804c-3dd74fe31c4e/deploy-status + :target: https://app.netlify.com/sites/pgcli/deploys + :alt: Netlify diff --git a/RELEASES.md b/RELEASES.md new file mode 100644 index 000000000..d5bc64035 --- /dev/null +++ b/RELEASES.md @@ -0,0 +1,6 @@ +Releasing pgcli +--------------- + +You have been made the maintainer of `pgcli`? Congratulations! + +To release a new version of the package, [create a new release](https://github.com/dbcli/pgcli/releases) in Github. This will trigger a Github action which will run all the tests, build the wheel and upload it to PyPI. \ No newline at end of file diff --git a/changelog.rst b/changelog.rst index 54951023f..fdfcde538 100644 --- a/changelog.rst +++ b/changelog.rst @@ -1,5 +1,200 @@ -TBD -=== +Upcoming (TBD) +============== + +Features: +--------- +* Add support for `\\T` prompt escape sequence to display transaction status (similar to psql's `%x`). + +4.4.0 (2025-12-24) +================== + +Features: +--------- +* Add support for `init-command` to run when the connection is established. 
+ * Command line option `--init-command` + * Provide `init-command` in the config file + * Support dsn specific init-command in the config file +* Add suggestion when setting the search_path +* Allow per dsn_alias ssh tunnel selection + +Internal: +--------- + +* Modernize the repository + * Use uv instead of pip + * Use github trusted publisher for pypi release + * Update dev requirements and replace requirements-dev.txt with pyproject.toml + * Use ruff instead of black + +Bug fixes: +---------- + +* Improve display of larger durations when passed as floats + +4.3.0 (2025-03-22) +================== + +Features +-------- +* The session time zone setting is set to the system time zone by default + +4.2.0 (2025-03-06) +================== + +Features +-------- +* Add a `--ping` command line option; allows pgcli to replace `pg_isready` +* Changed the packaging metadata from setup.py to pyproject.toml +* Add bash completion for services defined in the service file `~/.pg_service.conf` +* Added support for per-column date/time formatting using `column_date_formats` in config + +Bug fixes: +---------- +* Avoid raising `NameError` when exiting unsuccessfully in some cases +* Use configured `alias_map_file` to generate table aliases if available. + +Internal: +--------- + +* Drop support for Python 3.8 and add 3.13. + +4.1.0 (2024-03-09) +================== + +Features: +--------- +* Support `PGAPPNAME` as an environment variable and `--application-name` as a command line argument. +* Add `verbose_errors` config and `\v` special command which enable the + displaying of all Postgres error fields received. +* Show Postgres notifications. +* Support sqlparse 0.5.x +* Add `--log-file [filename]` cli argument and `\log-file [filename]` special commands to + log to an external file in addition to the normal output + +Bug fixes: +---------- + +* Fix display of "short host" in prompt (with `\h`) for IPv4 addresses ([issue 964](https://github.com/dbcli/pgcli/issues/964)). 
+* Fix backwards display of NOTICEs from a Function ([issue 1443](https://github.com/dbcli/pgcli/issues/1443)) +* Fix psycopg errors when installing on Windows. ([issue 1413](https://github.com/dbcli/pgcli/issues/1413)) +* Use a home-made function to display query duration instead of relying on a third-party library (the general behaviour does not change), which fixes the installation of `pgcli` on 32-bit architectures ([issue 1451](https://github.com/dbcli/pgcli/issues/1451)) + +================== +4.0.1 (2023-10-30) +================== + +Internal: +--------- +* Allow stable version of pendulum. + +================== +4.0.0 (2023-10-27) +================== + +Features: +--------- + +* Ask for confirmation when quitting cli while a transaction is ongoing. +* New `destructive_statements_require_transaction` config option to refuse to execute a + destructive SQL statement if outside a transaction. This option is off by default. +* Changed the `destructive_warning` config to be a list of commands that are considered + destructive. This would allow you to be warned on `create`, `grant`, or `insert` queries. +* Destructive warnings will now include the alias dsn connection string name if provided (-D option). +* pgcli.magic will now work with connection URLs that use TLS client certificates for authentication +* Have config option to retry queries on operational errors like connections being lost. + Also prevents getting stuck in a retry loop. +* Config option to not restart connection when cancelling a `destructive_warning` query. By default, + it will now not restart. +* Config option to always run with a single connection. +* Add comment explaining default LESS environment variable behavior and change example pager setting. +* Added `\echo` & `\qecho` special commands. ([issue 1335](https://github.com/dbcli/pgcli/issues/1335)). 
+ +Bug fixes: +---------- + +* Fix `\ev` not producing a correctly quoted "schema"."view" +* Fix 'invalid connection option "dsn"' ([issue 1373](https://github.com/dbcli/pgcli/issues/1373)). +* Fix explain mode when used with `expand`, `auto_expand`, or `--explain-vertical-output` ([issue 1393](https://github.com/dbcli/pgcli/issues/1393)). +* Fix sql-insert format emits NULL as 'None' ([issue 1408](https://github.com/dbcli/pgcli/issues/1408)). +* Improve check for prompt-toolkit 3.0.6 ([issue 1416](https://github.com/dbcli/pgcli/issues/1416)). +* Allow specifying an `alias_map_file` in the config that will use + predetermined table aliases instead of generating aliases programmatically on + the fly +* Fixed SQL error when there is a comment on the first line: ([issue 1403](https://github.com/dbcli/pgcli/issues/1403)) +* Fix wrong usage of prompt instead of confirm when confirm execution of destructive query + +Internal: +--------- + +* Drop support for Python 3.7 and add 3.12. + +3.5.0 (2022/09/15): +=================== + +Features: +--------- + +* New formatter is added to export query result to sql format (such as sql-insert, sql-update) like mycli. + +Bug fixes: +---------- + +* Fix exception when retrieving password from keyring ([issue 1338](https://github.com/dbcli/pgcli/issues/1338)). +* Fix using comments with special commands ([issue 1362](https://github.com/dbcli/pgcli/issues/1362)). +* Small improvements to the Windows developer experience +* Fix submitting queries in safe multiline mode ([1360](https://github.com/dbcli/pgcli/issues/1360)). + +Internal: +--------- + +* Port to psycopg3 (https://github.com/psycopg/psycopg). +* Fix typos + +3.4.1 (2022/03/19) +================== + +Bug fixes: +---------- + +* Fix the bug with Redshift not displaying word count in status ([related issue](https://github.com/dbcli/pgcli/issues/1320)). +* Show the error status for CSV output format. 
+ + +3.4.0 (2022/02/21) +================== + +Features: +--------- + +* Add optional support for automatically creating an SSH tunnel to a machine with access to the remote database ([related issue](https://github.com/dbcli/pgcli/issues/459)). + +3.3.1 (2022/01/18) +================== + +Bug fixes: +---------- + +* Prompt for password when -W is provided even if there is a password in keychain. Fixes #1307. +* Upgrade cli_helpers to 2.2.1 + +3.3.0 (2022/01/11) +================== + +Features: +--------- + +* Add `max_field_width` setting to config, to enable more control over field truncation ([related issue](https://github.com/dbcli/pgcli/issues/1250)). +* Re-run last query via bare `\watch`. (Thanks: `Saif Hakim`_) + +Bug fixes: +---------- + +* Pin the version of pygments to prevent breaking change + +3.2.0 +===== + +Release date: 2021/08/23 Features: --------- @@ -8,6 +203,7 @@ Features: `destructive_warning` setting to `all|moderate|off`, vs `true|false`. (#1239) * Skip initial comment in .pg_session even if it doesn't start with '#' * Include functions from schemas in search_path. (`Amjith Ramanujam`_) +* Easy way to show explain output under F5 Bug fixes: ---------- @@ -721,7 +917,7 @@ Internal Changes: * Added code coverage to the tests. (Thanks: `Irina Truong`_) * Run behaviorial tests as part of TravisCI (Thanks: `Irina Truong`_) * Upgraded prompt_toolkit version to 0.45 (Thanks: `Jonathan Slenders`_) -* Update the minumum required version of click to 4.1. +* Update the minimum required version of click to 4.1. 0.18.0 ====== @@ -959,7 +1155,7 @@ Features: * IPython integration through `ipython-sql`_ (Thanks: `Darik Gamble`_) * Add an ipython magic extension to embed pgcli inside ipython. * Results from a pgcli query are sent back to ipython. -* Multiple sql statments in the same line separated by semi-colon. (Thanks: https://github.com/macobo) +* Multiple sql statements in the same line separated by semi-colon. (Thanks: https://github.com/macobo) .. 
_`ipython-sql`: https://github.com/catherinedevlin/ipython-sql @@ -1085,3 +1281,4 @@ Improvements: .. _`thegeorgeous`: https://github.com/thegeorgeous .. _`laixintao`: https://github.com/laixintao .. _`anthonydb`: https://github.com/anthonydb +.. _`Daniel Kukula`: https://github.com/dkuku diff --git a/pgcli-completion.bash b/pgcli-completion.bash index 3549b5614..620563d27 100644 --- a/pgcli-completion.bash +++ b/pgcli-completion.bash @@ -3,9 +3,9 @@ _pg_databases() # -w was introduced in 8.4, https://launchpad.net/bugs/164772 # "Access privileges" in output may contain linefeeds, hence the NF > 1 COMPREPLY=( $( compgen -W "$( psql -AtqwlF $'\t' 2>/dev/null | \ - awk 'NF > 1 { print $1 }' )" -- "$cur" ) ) + awk 'NF > 1 { print $1 }' )" -- "$cur" ) ) } - + _pg_users() { # -w was introduced in 8.4, https://launchpad.net/bugs/164772 @@ -13,12 +13,23 @@ _pg_users() template1 2>/dev/null )" -- "$cur" ) ) [[ ${#COMPREPLY[@]} -eq 0 ]] && COMPREPLY=( $( compgen -u -- "$cur" ) ) } - + +_pg_services() +{ + # return list of available services + local services + if [[ -f "$HOME/.pg_service.conf" ]]; then + services=$(grep -oP '(?<=^\[).*?(?=\])' "$HOME/.pg_service.conf") + fi + local suffix="${cur#*=}" + COMPREPLY=( $(compgen -W "$services" -- "$suffix") ) +} + _pgcli() { local cur prev words cword _init_completion -s || return - + case $prev in -h|--host) _known_hosts_real "$cur" @@ -39,23 +50,27 @@ _pgcli() esac case "$cur" in - --*) - # return list of available options - COMPREPLY=( $( compgen -W '--host --port --user --password --no-password - --single-connection --version --dbname --pgclirc --dsn - --row-limit --help' -- "$cur" ) ) - [[ $COMPREPLY == *= ]] && compopt -o nospace - return 0 - ;; - -) - # only complete long options - compopt -o nospace - COMPREPLY=( -- ) - return 0 - ;; - *) + service=*) + _pg_services + return 0 + ;; + --*) + # return list of available options + COMPREPLY=( $( compgen -W '--host --port --user --password --no-password + --single-connection 
--version --dbname --pgclirc --dsn + --row-limit --help' -- "$cur" ) ) + [[ $COMPREPLY == *= ]] && compopt -o nospace + return 0 + ;; + -) + # only complete long options + compopt -o nospace + COMPREPLY=( -- ) + return 0 + ;; + *) # return list of available databases - _pg_databases + _pg_databases esac -} && +} && complete -F _pgcli pgcli diff --git a/pgcli/__init__.py b/pgcli/__init__.py index f5f41e567..ecdb1cef9 100644 --- a/pgcli/__init__.py +++ b/pgcli/__init__.py @@ -1 +1 @@ -__version__ = "3.1.0" +__version__ = "4.4.0" diff --git a/pgcli/auth.py b/pgcli/auth.py new file mode 100644 index 000000000..513097a31 --- /dev/null +++ b/pgcli/auth.py @@ -0,0 +1,56 @@ +import click +from textwrap import dedent + + +keyring = None # keyring will be loaded later + + +keyring_error_message = dedent( + """\ + {} + {} + To remove this message do one of the following: + - prepare keyring as described at: https://keyring.readthedocs.io/en/stable/ + - uninstall keyring: pip uninstall keyring + - disable keyring in our configuration: add keyring = False to [main]""" +) + + +def keyring_initialize(keyring_enabled, *, logger): + """Initialize keyring only if explicitly enabled""" + global keyring + + if keyring_enabled: + # Try best to load keyring (issue #1041). 
+ import importlib + + try: + keyring = importlib.import_module("keyring") + except ModuleNotFoundError as e: # ImportError for Python 2, ModuleNotFoundError for Python 3 + logger.warning("import keyring failed: %r.", e) + + +def keyring_get_password(key): + """Attempt to get password from keyring""" + # Find password from store + passwd = "" + try: + passwd = keyring.get_password("pgcli", key) or "" + except Exception as e: + click.secho( + keyring_error_message.format("Load your password from keyring returned:", str(e)), + err=True, + fg="red", + ) + return passwd + + +def keyring_set_password(key, passwd): + try: + keyring.set_password("pgcli", key, passwd) + except Exception as e: + click.secho( + keyring_error_message.format("Set password in keyring returned:", str(e)), + err=True, + fg="red", + ) diff --git a/pgcli/completion_refresher.py b/pgcli/completion_refresher.py index 1039d5159..34771b8a2 100644 --- a/pgcli/completion_refresher.py +++ b/pgcli/completion_refresher.py @@ -6,7 +6,6 @@ class CompletionRefresher: - refreshers = OrderedDict() def __init__(self): @@ -39,20 +38,16 @@ def refresh(self, executor, special, callbacks, history=None, settings=None): args=(executor, special, callbacks, history, settings), name="completion_refresh", ) - self._completer_thread.setDaemon(True) + self._completer_thread.daemon = True self._completer_thread.start() - return [ - (None, None, None, "Auto-completion refresh started in the background.") - ] + return [(None, None, None, "Auto-completion refresh started in the background.")] def is_refreshing(self): return self._completer_thread and self._completer_thread.is_alive() def _bg_refresh(self, pgexecute, special, callbacks, history=None, settings=None): settings = settings or {} - completer = PGCompleter( - smart_completion=True, pgspecial=special, settings=settings - ) + completer = PGCompleter(smart_completion=True, pgspecial=special, settings=settings) if settings.get("single_connection"): executor = pgexecute 
diff --git a/pgcli/config.py b/pgcli/config.py index 22f08dc07..2b44a7bb7 100644 --- a/pgcli/config.py +++ b/pgcli/config.py @@ -1,4 +1,3 @@ -import errno import shutil import os import platform diff --git a/pgcli/explain_output_formatter.py b/pgcli/explain_output_formatter.py new file mode 100644 index 000000000..ce45b4f8d --- /dev/null +++ b/pgcli/explain_output_formatter.py @@ -0,0 +1,19 @@ +from pgcli.pyev import Visualizer +import json + + +"""Explain response output adapter""" + + +class ExplainOutputFormatter: + def __init__(self, max_width): + self.max_width = max_width + + def format_output(self, cur, headers, **output_kwargs): + # explain query results should always contain 1 row each + [(data,)] = list(cur) + explain_list = json.loads(data) + visualizer = Visualizer(self.max_width) + for explain in explain_list: + visualizer.load(explain) + yield visualizer.get_list() diff --git a/pgcli/key_bindings.py b/pgcli/key_bindings.py index 23174b6bc..11855df74 100644 --- a/pgcli/key_bindings.py +++ b/pgcli/key_bindings.py @@ -9,7 +9,7 @@ vi_mode, ) -from .pgbuffer import buffer_should_be_handled +from .pgbuffer import buffer_should_be_handled, safe_multi_line_mode _logger = logging.getLogger(__name__) @@ -39,6 +39,12 @@ def _(event): pgcli.vi_mode = not pgcli.vi_mode event.app.editing_mode = EditingMode.VI if pgcli.vi_mode else EditingMode.EMACS + @kb.add("f5") + def _(event): + """Toggle explain mode.""" + _logger.debug("Detected F5 key.") + pgcli.explain_mode = not pgcli.explain_mode + @kb.add("tab") def _(event): """Force autocompletion at cursor on non-empty lines.""" @@ -101,14 +107,13 @@ def _(event): # history search, and one of several conditions are True @kb.add( "enter", - filter=~(completion_is_selected | is_searching) - & buffer_should_be_handled(pgcli), + filter=~(completion_is_selected | is_searching) & buffer_should_be_handled(pgcli), ) def _(event): _logger.debug("Detected enter key.") event.current_buffer.validate_and_handle() - 
@kb.add("escape", "enter", filter=~vi_mode) + @kb.add("escape", "enter", filter=~vi_mode & ~safe_multi_line_mode(pgcli)) def _(event): """Introduces a line break regardless of multi-line mode or not.""" _logger.debug("Detected alt-enter key.") diff --git a/pgcli/magic.py b/pgcli/magic.py index 6e58f28b7..09902a295 100644 --- a/pgcli/magic.py +++ b/pgcli/magic.py @@ -43,7 +43,7 @@ def pgcli_line_magic(line): u = conn.session.engine.url _logger.debug("New pgcli: %r", str(u)) - pgcli.connect(u.database, u.host, u.username, u.port, u.password) + pgcli.connect_uri(str(u._replace(drivername="postgres"))) conn._pgcli = pgcli # For convenience, print the connection alias diff --git a/pgcli/main.py b/pgcli/main.py index 5135f6fda..913228b33 100644 --- a/pgcli/main.py +++ b/pgcli/main.py @@ -1,13 +1,9 @@ -import platform -import warnings -from os.path import expanduser - +from zoneinfo import ZoneInfoNotFoundError from configobj import ConfigObj, ParseError from pgspecial.namedqueries import NamedQueries from .config import skip_initial_comment -warnings.filterwarnings("ignore", category=UserWarning, module="psycopg2") - +import atexit import os import re import sys @@ -16,18 +12,23 @@ import threading import shutil import functools -import pendulum import datetime as dt import itertools +import pathlib import platform from time import time, sleep - -keyring = None # keyring will be loaded later +from typing import Optional from cli_helpers.tabular_output import TabularOutputFormatter -from cli_helpers.tabular_output.preprocessors import align_decimals, format_numbers +from cli_helpers.tabular_output.preprocessors import ( + align_decimals, + format_numbers, + format_timestamps, +) from cli_helpers.utils import strip_ansi +from .explain_output_formatter import ExplainOutputFormatter import click +import tzlocal try: import setproctitle @@ -52,6 +53,7 @@ from pgspecial.main import PGSpecial, NO_QUERY, PAGER_OFF, PAGER_LONG_OUTPUT import pgspecial as special +from . 
import auth from .pgcompleter import PGCompleter from .pgtoolbar import create_toolbar_tokens_func from .pgstyle import style_factory, style_factory_output @@ -66,26 +68,35 @@ get_config_filename, ) from .key_bindings import pgcli_bindings -from .packages.prompt_utils import confirm_destructive_query +from .packages.formatter.sqlformatter import register_new_formatter +from .packages.prompt_utils import confirm, confirm_destructive_query +from .packages.parseutils import is_destructive +from .packages.parseutils import parse_destructive_warning from .__init__ import __version__ click.disable_unicode_literals_warning = True -try: - from urlparse import urlparse, unquote, parse_qs -except ImportError: - from urllib.parse import urlparse, unquote, parse_qs +from urllib.parse import urlparse from getpass import getuser -from psycopg2 import OperationalError, InterfaceError -import psycopg2 + +from psycopg import OperationalError, InterfaceError, Notify +from psycopg.conninfo import make_conninfo, conninfo_to_dict +from psycopg.errors import Diagnostic from collections import namedtuple -from textwrap import dedent +try: + import sshtunnel + + SSH_TUNNEL_SUPPORT = True +except ImportError: + SSH_TUNNEL_SUPPORT = False + # Ref: https://stackoverflow.com/questions/30425105/filter-special-chars-such-as-color-codes-from-shell-output COLOR_CODE_REGEX = re.compile(r"\x1b(\[.*?[@-~]|\].*?(\x07|\x1b\\))") +DEFAULT_MAX_FIELD_WIDTH = 500 # Query tuples are used for maintaining history MetaQuery = namedtuple( @@ -106,17 +117,19 @@ OutputSettings = namedtuple( "OutputSettings", - "table_format dcmlfmt floatfmt missingval expanded max_width case_function style_output", + "table_format dcmlfmt floatfmt column_date_formats missingval expanded max_width case_function style_output max_field_width", ) OutputSettings.__new__.__defaults__ = ( None, None, None, + None, "", False, None, lambda x: x, None, + DEFAULT_MAX_FIELD_WIDTH, ) @@ -124,6 +137,13 @@ class PgCliQuitError(Exception): pass 
+def notify_callback(notify: Notify): + click.secho( + 'Notification received on channel "{}" (PID {}):\n{}'.format(notify.channel, notify.pid, notify.payload), + fg="green", + ) + + class PGCli: default_prompt = "\\u@\\h:\\d> " max_len_prompt = 30 @@ -133,9 +153,7 @@ def set_default_pager(self, config): os_environ_pager = os.environ.get("PAGER") if configured_pager: - self.logger.info( - 'Default pager found in config file: "%s"', configured_pager - ) + self.logger.info('Default pager found in config file: "%s"', configured_pager) os.environ["PAGER"] = configured_pager elif os_environ_pager: self.logger.info( @@ -144,9 +162,7 @@ def set_default_pager(self, config): ) os.environ["PAGER"] = os_environ_pager else: - self.logger.info( - "No default pager found in environment. Using os default pager" - ) + self.logger.info("No default pager found in environment. Using os default pager") # Set default set of less recommended options, if they are not already set. # They are ignored if pager is different than less. 
@@ -160,14 +176,16 @@ def __init__( pgexecute=None, pgclirc_file=None, row_limit=None, + application_name="pgcli", single_connection=False, less_chatty=None, prompt=None, prompt_dsn=None, auto_vertical_output=False, warn=None, + ssh_tunnel_url: Optional[str] = None, + log_file: Optional[str] = None, ): - self.force_passwd_prompt = force_passwd_prompt self.never_passwd_prompt = never_passwd_prompt self.pgexecute = pgexecute @@ -190,10 +208,13 @@ def __init__( self.output_file = None self.pgspecial = PGSpecial() + self.hide_named_query_text = "hide_named_query_text" in c["main"] and c["main"].as_bool("hide_named_query_text") + self.explain_mode = False self.multi_line = c["main"].as_bool("multi_line") self.multiline_mode = c["main"].get("multi_line_mode", "psql") self.vi_mode = c["main"].as_bool("vi") self.auto_expand = auto_vertical_output or c["main"].as_bool("auto_expand") + self.auto_retry_closed_connection = c["main"].as_bool("auto_retry_closed_connection") self.expanded_output = c["main"].as_bool("expand") self.pgspecial.timing_enabled = c["main"].as_bool("timing") if row_limit is not None: @@ -201,34 +222,41 @@ def __init__( else: self.row_limit = c["main"].as_int("row_limit") + self.application_name = application_name + + # if not specified, set to DEFAULT_MAX_FIELD_WIDTH + # if specified but empty, set to None to disable truncation + # ellipsis will take at least 3 symbols, so this can't be less than 3 if specified and > 0 + max_field_width = c["main"].get("max_field_width", DEFAULT_MAX_FIELD_WIDTH) + if max_field_width and max_field_width.lower() != "none": + max_field_width = max(3, abs(int(max_field_width))) + else: + max_field_width = None + self.max_field_width = max_field_width + self.min_num_menu_lines = c["main"].as_int("min_num_menu_lines") self.multiline_continuation_char = c["main"]["multiline_continuation_char"] self.table_format = c["main"]["table_format"] self.syntax_style = c["main"]["syntax_style"] self.cli_style = c["colors"] 
self.wider_completion_menu = c["main"].as_bool("wider_completion_menu") - self.destructive_warning = warn or c["main"]["destructive_warning"] - # also handle boolean format of destructive warning - self.destructive_warning = {"true": "all", "false": "off"}.get( - self.destructive_warning.lower(), self.destructive_warning - ) + self.destructive_warning = parse_destructive_warning(warn or c["main"].as_list("destructive_warning")) + self.destructive_warning_restarts_connection = c["main"].as_bool("destructive_warning_restarts_connection") + self.destructive_statements_require_transaction = c["main"].as_bool("destructive_statements_require_transaction") + self.less_chatty = bool(less_chatty) or c["main"].as_bool("less_chatty") + self.verbose_errors = "verbose_errors" in c["main"] and c["main"].as_bool("verbose_errors") self.null_string = c["main"].get("null_string", "") - self.prompt_format = ( - prompt - if prompt is not None - else c["main"].get("prompt", self.default_prompt) - ) + self.prompt_format = prompt if prompt is not None else c["main"].get("prompt", self.default_prompt) self.prompt_dsn_format = prompt_dsn self.on_error = c["main"]["on_error"].upper() self.decimal_format = c["data_formats"]["decimal"] self.float_format = c["data_formats"]["float"] - self.initialize_keyring() + self.column_date_formats = c["column_date_formats"] + auth.keyring_initialize(c["main"].as_bool("keyring"), logger=self.logger) self.show_bottom_toolbar = c["main"].as_bool("show_bottom_toolbar") - self.pgspecial.pset_pager( - self.config["main"].as_bool("enable_pager") and "on" or "off" - ) + self.pgspecial.pset_pager(self.config["main"].as_bool("enable_pager") and "on" or "off") self.style_output = style_factory_output(self.syntax_style, c["colors"]) @@ -241,6 +269,7 @@ def __init__( # Initialize completer smart_completion = c["main"].as_bool("smart_completion") keyword_casing = c["main"]["keyword_casing"] + single_connection = single_connection or 
c["main"].as_bool("always_use_single_connection") self.settings = { "casing_file": get_casing_file(c), "generate_casing_file": c["main"].as_bool("generate_casing_file"), @@ -252,21 +281,54 @@ def __init__( "single_connection": single_connection, "less_chatty": less_chatty, "keyword_casing": keyword_casing, + "alias_map_file": c["main"]["alias_map_file"] or None, } - completer = PGCompleter( - smart_completion, pgspecial=self.pgspecial, settings=self.settings - ) + completer = PGCompleter(smart_completion, pgspecial=self.pgspecial, settings=self.settings) self.completer = completer self._completer_lock = threading.Lock() self.register_special_commands() self.prompt_app = None + self.dsn_ssh_tunnel_config = c.get("dsn ssh tunnels") + self.ssh_tunnel_config = c.get("ssh tunnels") + self.ssh_tunnel_url = ssh_tunnel_url + self.ssh_tunnel = None + + if log_file: + with open(log_file, "a+"): + pass # ensure writeable + self.log_file = log_file + + # formatter setup + self.formatter = TabularOutputFormatter(format_name=c["main"]["table_format"]) + register_new_formatter(self.formatter) + def quit(self): raise PgCliQuitError + def toggle_named_query_quiet(self): + """Toggle hiding of named query text""" + self.hide_named_query_text = not self.hide_named_query_text + status = "ON" if self.hide_named_query_text else "OFF" + message = f"Named query quiet mode: {status}" + return [(None, None, None, message)] + + def _is_named_query_execution(self, text): + """Check if the command is a named query execution (\n ).""" + text = text.strip() + return text.startswith("\\n ") and not text.startswith("\\ns ") and not text.startswith("\\nd ") + def register_special_commands(self): + self.pgspecial.register( + self.toggle_named_query_quiet, + "\\nq", + "\\nq", + "Toggle named query quiet mode (hide query text)", + arg_type=NO_QUERY, + case_sensitive=True, + ) self.pgspecial.register( self.change_db, @@ -276,7 +338,8 @@ def register_special_commands(self): aliases=("use", "\\connect", 
"USE"), ) - refresh_callback = lambda: self.refresh_completions(persist_priorities="all") + def refresh_callback(): + return self.refresh_completions(persist_priorities="all") self.pgspecial.register( self.quit, @@ -310,9 +373,7 @@ def register_special_commands(self): "Refresh auto-completions.", arg_type=NO_QUERY, ) - self.pgspecial.register( - self.execute_from_file, "\\i", "\\i filename", "Execute commands from file." - ) + self.pgspecial.register(self.execute_from_file, "\\i", "\\i filename", "Execute commands from file.") self.pgspecial.register( self.write_to_file, "\\o", @@ -320,8 +381,12 @@ def register_special_commands(self): "Send all query results to file.", ) self.pgspecial.register( - self.info_connection, "\\conninfo", "\\conninfo", "Get connection details" + self.write_to_logfile, + "\\log-file", + "\\log-file [filename]", + "Log all query results to a logfile, in addition to the normal output destination.", ) + self.pgspecial.register(self.info_connection, "\\conninfo", "\\conninfo", "Get connection details") self.pgspecial.register( self.change_table_format, "\\T", @@ -329,6 +394,43 @@ def register_special_commands(self): "Change the table format used to output results", ) + self.pgspecial.register( + self.echo, + "\\echo", + "\\echo [string]", + "Echo a string to stdout", + ) + + self.pgspecial.register( + self.echo, + "\\qecho", + "\\qecho [string]", + "Echo a string to the query output channel.", + ) + + self.pgspecial.register( + self.toggle_verbose_errors, + "\\v", + "\\v [on|off]", + "Toggle verbose errors.", + ) + + def toggle_verbose_errors(self, pattern, **_): + flag = pattern.strip() + + if flag == "on": + self.verbose_errors = True + elif flag == "off": + self.verbose_errors = False + else: + self.verbose_errors = not self.verbose_errors + + message = "Verbose errors " + "on." if self.verbose_errors else "off." 
+ return [(None, None, None, message)] + + def echo(self, pattern, **_): + return [(None, None, None, pattern)] + def change_table_format(self, pattern, **_): try: if pattern not in TabularOutputFormatter().supported_formats: @@ -353,8 +455,7 @@ def info_connection(self, **_): None, None, 'You are connected to database "%s" as user ' - '"%s" on %s at port "%s".' - % (self.pgexecute.dbname, self.pgexecute.user, host, self.pgexecute.port), + '"%s" on %s at port "%s".' % (self.pgexecute.dbname, self.pgexecute.user, host, self.pgexecute.port), ) def change_db(self, pattern, **_): @@ -362,7 +463,7 @@ def change_db(self, pattern, **_): # Get all the parameters in pattern, handling double quotes if any. infos = re.findall(r'"[^"]*"|[^"\'\s]+', pattern) # Now removing quotes. - list(map(lambda s: s.strip('"'), infos)) + [s.strip('"') for s in infos] infos.extend([None] * (4 - len(infos))) db, user, host, port = infos @@ -384,8 +485,7 @@ def change_db(self, pattern, **_): None, None, None, - 'You are now connected to database "%s" as ' - 'user "%s"' % (self.pgexecute.dbname, self.pgexecute.user), + 'You are now connected to database "%s" as user "%s"' % (self.pgexecute.dbname, self.pgexecute.user), ) def execute_from_file(self, pattern, **_): @@ -398,18 +498,47 @@ def execute_from_file(self, pattern, **_): except OSError as e: return [(None, None, None, str(e), "", False, True)] - if ( - self.destructive_warning != "off" - and confirm_destructive_query(query, self.destructive_warning) is False - ): - message = "Wise choice. Command execution stopped." - return [(None, None, None, message)] + if self.destructive_warning: + if ( + self.destructive_statements_require_transaction + and not self.pgexecute.valid_transaction() + and is_destructive(query, self.destructive_warning) + ): + message = "Destructive statements must be run within a transaction. Command execution stopped." 
+ return [(None, None, None, message)] + destroy = confirm_destructive_query(query, self.destructive_warning, self.dsn_alias) + if destroy is False: + message = "Wise choice. Command execution stopped." + return [(None, None, None, message)] on_error_resume = self.on_error == "RESUME" return self.pgexecute.run( - query, self.pgspecial, on_error_resume=on_error_resume + query, + self.pgspecial, + on_error_resume=on_error_resume, + explain_mode=self.explain_mode, ) + def write_to_logfile(self, pattern, **_): + if not pattern: + self.log_file = None + message = "Logfile capture disabled" + return [(None, None, None, message, "", True, True)] + + log_file = pathlib.Path(pattern).expanduser().absolute() + + try: + with open(log_file, "a+"): + pass # ensure writeable + except OSError as e: + self.log_file = None + message = str(e) + "\nLogfile capture disabled" + return [(None, None, None, message, "", False, True)] + + self.log_file = str(log_file) + message = 'Writing to file "%s"' % self.log_file + return [(None, None, None, message, "", True, True)] + def write_to_file(self, pattern, **_): if not pattern: self.output_file = None @@ -428,7 +557,6 @@ def write_to_file(self, pattern, **_): return [(None, None, None, message, "", True, True)] def initialize_logging(self): - log_file = self.config["main"]["log_file"] if log_file == "default": log_file = config_location() + "log" @@ -453,10 +581,7 @@ def initialize_logging(self): log_level = level_map[log_level.upper()] - formatter = logging.Formatter( - "%(asctime)s (%(process)d/%(threadName)s) " - "%(name)s %(levelname)s - %(message)s" - ) + formatter = logging.Formatter("%(asctime)s (%(process)d/%(threadName)s) %(name)s %(levelname)s - %(message)s") handler.setFormatter(formatter) @@ -471,29 +596,14 @@ def initialize_logging(self): pgspecial_logger.addHandler(handler) pgspecial_logger.setLevel(log_level) - def initialize_keyring(self): - global keyring - - keyring_enabled = self.config["main"].as_bool("keyring") - if 
keyring_enabled: - # Try best to load keyring (issue #1041). - import importlib - - try: - keyring = importlib.import_module("keyring") - except Exception as e: # ImportError for Python 2, ModuleNotFoundError for Python 3 - self.logger.warning("import keyring failed: %r.", e) - def connect_dsn(self, dsn, **kwargs): self.connect(dsn=dsn, **kwargs) def connect_service(self, service, user): service_config, file = parse_service_info(service) if service_config is None: - click.secho( - f"service '{service}' was not found in {file}", err=True, fg="red" - ) - exit(1) + click.secho(f"service '{service}' was not found in {file}", err=True, fg="red") + sys.exit(1) self.connect( database=service_config.get("dbname"), host=service_config.get("host"), @@ -503,14 +613,12 @@ def connect_service(self, service, user): ) def connect_uri(self, uri): - kwargs = psycopg2.extensions.parse_dsn(uri) + kwargs = conninfo_to_dict(uri) remap = {"dbname": "database", "password": "passwd"} kwargs = {remap.get(k, k): v for k, v in kwargs.items()} self.connect(**kwargs) - def connect( - self, database="", host="", user="", port="", passwd="", dsn="", **kwargs - ): + def connect(self, database="", host="", user="", port="", passwd="", dsn="", **kwargs): # Connect to the database. if not user: @@ -519,46 +627,25 @@ def connect( if not database: database = user - kwargs.setdefault("application_name", "pgcli") + kwargs.setdefault("application_name", self.application_name) # If password prompt is not forced but no password is provided, try # getting it from environment variable. 
if not self.force_passwd_prompt and not passwd: passwd = os.environ.get("PGPASSWORD", "") - # Find password from store - key = f"{user}@{host}" - keyring_error_message = dedent( - """\ - {} - {} - To remove this message do one of the following: - - prepare keyring as described at: https://keyring.readthedocs.io/en/stable/ - - uninstall keyring: pip uninstall keyring - - disable keyring in our configuration: add keyring = False to [main]""" - ) - if not passwd and keyring: - - try: - passwd = keyring.get_password("pgcli", key) - except (RuntimeError, keyring.errors.InitError) as e: - click.secho( - keyring_error_message.format( - "Load your password from keyring returned:", str(e) - ), - err=True, - fg="red", - ) - # Prompt for a password immediately if requested via the -W flag. This # avoids wasting time trying to connect to the database and catching a # no-password exception. # If we successfully parsed a password from a URI, there's no need to # prompt for it, even with the -W flag if self.force_passwd_prompt and not passwd: - passwd = click.prompt( - "Password for %s" % user, hide_input=True, show_default=False, type=str - ) + passwd = click.prompt("Password for %s" % user, hide_input=True, show_default=False, type=str) + + key = f"{user}@{host}" + + if not passwd and auth.keyring: + passwd = auth.keyring_get_password(key) def should_ask_for_password(exc): # Prompt for a password after 1st attempt to connect @@ -572,13 +659,78 @@ def should_ask_for_password(exc): return True return False + if dsn: + parsed_dsn = conninfo_to_dict(dsn) + if "host" in parsed_dsn: + host = parsed_dsn["host"] + if "port" in parsed_dsn: + port = parsed_dsn["port"] + + if self.dsn_alias and self.dsn_ssh_tunnel_config and not self.ssh_tunnel_url: + for dsn_regex, tunnel_url in self.dsn_ssh_tunnel_config.items(): + if re.search(dsn_regex, self.dsn_alias): + self.ssh_tunnel_url = tunnel_url + break + + if self.ssh_tunnel_config and not self.ssh_tunnel_url: + for db_host_regex, tunnel_url 
in self.ssh_tunnel_config.items(): + if re.search(db_host_regex, host): + self.ssh_tunnel_url = tunnel_url + break + + if self.ssh_tunnel_url: + # We add the protocol as urlparse doesn't find it by itself + if "://" not in self.ssh_tunnel_url: + self.ssh_tunnel_url = f"ssh://{self.ssh_tunnel_url}" + + tunnel_info = urlparse(self.ssh_tunnel_url) + params = { + "local_bind_address": ("127.0.0.1",), + "remote_bind_address": (host, int(port or 5432)), + "ssh_address_or_host": (tunnel_info.hostname, tunnel_info.port or 22), + "logger": self.logger, + } + if tunnel_info.username: + params["ssh_username"] = tunnel_info.username + if tunnel_info.password: + params["ssh_password"] = tunnel_info.password + + # Hack: sshtunnel adds a console handler to the logger, so we revert handlers. + # We can remove this when https://github.com/pahaz/sshtunnel/pull/250 is merged. + logger_handlers = self.logger.handlers.copy() + try: + self.ssh_tunnel = sshtunnel.SSHTunnelForwarder(**params) + self.ssh_tunnel.start() + except Exception as e: + self.logger.handlers = logger_handlers + self.logger.error("traceback: %r", traceback.format_exc()) + click.secho(str(e), err=True, fg="red") + sys.exit(1) + self.logger.handlers = logger_handlers + + atexit.register(self.ssh_tunnel.stop) + host = "127.0.0.1" + port = self.ssh_tunnel.local_bind_ports[0] + + if dsn: + dsn = make_conninfo(dsn, host=host, port=port) + # Attempt to connect to the database. # Note that passwd may be empty on the first attempt. If connection # fails because of a missing or incorrect password, but we're allowed to # prompt for a password (no -w flag), prompt for a passwd and try again. 
try: try: - pgexecute = PGExecute(database, user, passwd, host, port, dsn, **kwargs) + pgexecute = PGExecute( + database, + user, + passwd, + host, + port, + dsn, + notify_callback, + **kwargs, + ) except (OperationalError, InterfaceError) as e: if should_ask_for_password(e): passwd = click.prompt( @@ -588,27 +740,25 @@ def should_ask_for_password(exc): type=str, ) pgexecute = PGExecute( - database, user, passwd, host, port, dsn, **kwargs + database, + user, + passwd, + host, + port, + dsn, + notify_callback, + **kwargs, ) else: raise e - if passwd and keyring: - try: - keyring.set_password("pgcli", key, passwd) - except (RuntimeError, keyring.errors.KeyringError) as e: - click.secho( - keyring_error_message.format( - "Set password in keyring returned:", str(e) - ), - err=True, - fg="red", - ) + if passwd and auth.keyring: + auth.keyring_set_password(key, passwd) except Exception as e: # Connecting to a database could fail. self.logger.debug("Database connection failed: %r.", e) self.logger.error("traceback: %r", traceback.format_exc()) click.secho(str(e), err=True, fg="red") - exit(1) + sys.exit(1) self.pgexecute = pgexecute @@ -650,34 +800,46 @@ def handle_editor_command(self, text): editor_command = special.editor_command(text) return text - def execute_command(self, text): + def execute_command(self, text, handle_closed_connection=True): logger = self.logger query = MetaQuery(query=text, successful=False) try: - if self.destructive_warning != "off": - destroy = confirm = confirm_destructive_query( - text, self.destructive_warning - ) + if self.destructive_warning: + if ( + self.destructive_statements_require_transaction + and not self.pgexecute.valid_transaction() + and is_destructive(text, self.destructive_warning) + ): + click.secho("Destructive statements must be run within a transaction.") + raise KeyboardInterrupt + destroy = confirm_destructive_query(text, self.destructive_warning, self.dsn_alias) if destroy is False: click.secho("Wise choice!") raise 
KeyboardInterrupt elif destroy: click.secho("Your call!") + output, query = self._evaluate_command(text) except KeyboardInterrupt: - # Restart connection to the database - self.pgexecute.connect() - logger.debug("cancelled query, sql: %r", text) - click.secho("cancelled query", err=True, fg="red") + if self.destructive_warning_restarts_connection: + # Restart connection to the database + self.pgexecute.connect() + logger.debug("cancelled query and restarted connection, sql: %r", text) + click.secho("cancelled query and restarted connection", err=True, fg="red") + else: + logger.debug("cancelled query, sql: %r", text) + click.secho("cancelled query", err=True, fg="red") except NotImplementedError: click.secho("Not Yet Implemented.", fg="yellow") except OperationalError as e: logger.error("sql: %r, error: %r", text, e) logger.error("traceback: %r", traceback.format_exc()) - self._handle_server_closed_connection(text) - except (PgCliQuitError, EOFError) as e: + click.secho(str(e), err=True, fg="red") + if handle_closed_connection: + self._handle_server_closed_connection(text) + except (PgCliQuitError, EOFError): raise except Exception as e: logger.error("sql: %r, error: %r", text, e) @@ -685,10 +847,17 @@ def execute_command(self, text): click.secho(str(e), err=True, fg="red") else: try: - if self.output_file and not text.startswith(("\\o ", "\\? ")): + if self.output_file and not text.startswith(("\\o ", "\\log-file", "\\? 
", "\\echo ")): try: with open(self.output_file, "a", encoding="utf-8") as f: - click.echo(text, file=f) + should_hide = ( + self.hide_named_query_text + and query.is_special + and query.successful + and self._is_named_query_execution(text) + ) + if not should_hide: + click.echo(text, file=f) click.echo("\n".join(output), file=f) click.echo("", file=f) # extra newline except OSError as e: @@ -696,6 +865,24 @@ def execute_command(self, text): else: if output: self.echo_via_pager("\n".join(output)) + + # Log to file in addition to normal output + if self.log_file and not text.startswith(("\\o ", "\\log-file", "\\? ", "\\echo ")) and not text.strip() == "": + try: + with open(self.log_file, "a", encoding="utf-8") as f: + click.echo(dt.datetime.now().isoformat(), file=f) # timestamp log + should_hide = ( + self.hide_named_query_text + and query.is_special + and query.successful + and self._is_named_query_execution(text) + ) + if not should_hide: + click.echo(text, file=f) + click.echo("\n".join(output), file=f) + click.echo("", file=f) # extra newline + except OSError as e: + click.secho(str(e), err=True, fg="red") except KeyboardInterrupt: pass @@ -706,9 +893,9 @@ def execute_command(self, text): "Time: %0.03fs (%s), executed in: %0.03fs (%s)" % ( query.total_time, - pendulum.Duration(seconds=query.total_time).in_words(), + duration_in_words(query.total_time), query.execution_time, - pendulum.Duration(seconds=query.execution_time).in_words(), + duration_in_words(query.execution_time), ) ) else: @@ -729,6 +916,34 @@ def execute_command(self, text): logger.debug("Search path: %r", self.completer.search_path) return query + def _check_ongoing_transaction_and_allow_quitting(self): + """Return whether we can really quit, possibly by asking the + user to confirm so if there is an ongoing transaction. + """ + if not self.pgexecute.valid_transaction(): + return True + while 1: + try: + choice = click.prompt( + "A transaction is ongoing. 
Choose `c` to COMMIT, `r` to ROLLBACK, `a` to abort exit.", + default="a", + ) + except click.Abort: + # Print newline if user aborts with `^C`, otherwise + # pgcli's prompt will be printed on the same line + # (just after the confirmation prompt). + click.echo(None, err=False) + choice = "a" + choice = choice.lower() + if choice == "a": + return False # do not quit + if choice == "c": + query = self.execute_command("commit") + return query.successful # quit only if query is successful + if choice == "r": + query = self.execute_command("rollback") + return query.successful # quit only if query is successful + def run_cli(self): logger = self.logger @@ -751,6 +966,10 @@ def run_cli(self): text = self.prompt_app.prompt() except KeyboardInterrupt: continue + except EOFError: + if not self._check_ongoing_transaction_and_allow_quitting(): + continue + raise try: text = self.handle_editor_command(text) @@ -760,18 +979,12 @@ def run_cli(self): click.secho(str(e), err=True, fg="red") continue - # Initialize default metaquery in case execution fails - self.watch_command, timing = special.get_watch_command(text) - if self.watch_command: - while self.watch_command: - try: - query = self.execute_command(self.watch_command) - click.echo(f"Waiting for {timing} seconds before repeating") - sleep(timing) - except KeyboardInterrupt: - self.watch_command = None - else: - query = self.execute_command(text) + try: + self.handle_watch_command(text) + except PgCliQuitError: + if not self._check_ongoing_transaction_and_allow_quitting(): + continue + raise self.now = dt.datetime.today() @@ -779,12 +992,38 @@ def run_cli(self): with self._completer_lock: self.completer.extend_query_history(text) - self.query_history.append(query) - except (PgCliQuitError, EOFError): if not self.less_chatty: print("Goodbye!") + def handle_watch_command(self, text): + # Initialize default metaquery in case execution fails + self.watch_command, timing = special.get_watch_command(text) + + # If we run \watch 
without a command, apply it to the last query run. + if self.watch_command is not None and not self.watch_command.strip(): + try: + self.watch_command = self.query_history[-1].query + except IndexError: + click.secho("\\watch cannot be used with an empty query", err=True, fg="red") + self.watch_command = None + + # If there's a command to \watch, run it in a loop. + if self.watch_command: + while self.watch_command: + try: + query = self.execute_command(self.watch_command) + click.echo(f"Waiting for {timing} seconds before repeating") + sleep(timing) + except KeyboardInterrupt: + self.watch_command = None + + # Otherwise, execute it as a regular command. + else: + query = self.execute_command(text) + + self.query_history.append(query) + def _build_cli(self, history): key_bindings = pgcli_bindings(self) @@ -796,10 +1035,7 @@ def get_message(): prompt = self.get_prompt(prompt_format) - if ( - prompt_format == self.default_prompt - and len(prompt) > self.max_len_prompt - ): + if prompt_format == self.default_prompt and len(prompt) > self.max_len_prompt: prompt = self.get_prompt("\\d> ") prompt = prompt.replace("\\x1b", "\x1b") @@ -857,15 +1093,12 @@ def get_continuation(width, line_number, is_soft_wrap): def _should_limit_output(self, sql, cur): """returns True if the output should be truncated, False otherwise.""" + if self.explain_mode: + return False if not is_select(sql): return False - return ( - not self._has_limit(sql) - and self.row_limit != 0 - and cur - and cur.rowcount > self.row_limit - ) + return not self._has_limit(sql) and self.row_limit != 0 and cur and cur.rowcount > self.row_limit def _has_limit(self, sql): if not sql: @@ -889,6 +1122,8 @@ def _evaluate_command(self, text): logger = self.logger logger.debug("sql: %r", text) + # set query to formatter in order to parse table name + self.formatter.query = text all_success = True meta_changed = False # CREATE, ALTER, DROP, etc mutated = False # INSERT, DELETE, etc @@ -902,7 +1137,11 @@ def 
_evaluate_command(self, text): start = time() on_error_resume = self.on_error == "RESUME" res = self.pgexecute.run( - text, self.pgspecial, exception_formatter, on_error_resume + text, + self.pgspecial, + lambda x: exception_formatter(x, self.verbose_errors), + on_error_resume, + explain_mode=self.explain_mode, ) is_special = None @@ -925,18 +1164,28 @@ def _evaluate_command(self, text): table_format=self.table_format, dcmlfmt=self.decimal_format, floatfmt=self.float_format, + column_date_formats=self.column_date_formats, missingval=self.null_string, expanded=expanded, max_width=max_width, - case_function=( - self.completer.case - if self.settings["case_column_headers"] - else lambda x: x - ), + case_function=(self.completer.case if self.settings["case_column_headers"] else lambda x: x), style_output=self.style_output, + max_field_width=self.max_field_width, ) + + # Hide query text for named queries in quiet mode + if ( + self.hide_named_query_text + and is_special + and success + and self._is_named_query_execution(text) + and title + and title.startswith("> ") + ): + title = None + execution = time() - start - formatted = format_output(title, cur, headers, status, settings) + formatted = format_output(title, cur, headers, status, settings, self.explain_mode) output.extend(formatted) total = time() - start @@ -971,10 +1220,15 @@ def _handle_server_closed_connection(self, text): click.secho("Reconnecting...", fg="green") self.pgexecute.connect() click.secho("Reconnected!", fg="green") - self.execute_command(text) except OperationalError as e: click.secho("Reconnect Failed", fg="red") click.secho(str(e), err=True, fg="red") + else: + retry = self.auto_retry_closed_connection or confirm("Run the query from before reconnecting?") + if retry: + click.secho("Running query...", fg="green") + # Don't get stuck in a retry loop + self.execute_command(text, handle_closed_connection=False) def refresh_completions(self, history=None, persist_priorities="all"): """Refresh 
outdated completions @@ -985,9 +1239,7 @@ def refresh_completions(self, history=None, persist_priorities="all"): :param persist_priorities: 'all' or 'keywords' """ - callback = functools.partial( - self._on_completions_refreshed, persist_priorities=persist_priorities - ) + callback = functools.partial(self._on_completions_refreshed, persist_priorities=persist_priorities) return self.completion_refresher.refresh( self.pgexecute, self.pgspecial, @@ -1038,9 +1290,7 @@ def _swap_completer_objects(self, new_completer, persist_priorities): def get_completions(self, text, cursor_positition): with self._completer_lock: - return self.completer.get_completions( - Document(text=text, cursor_position=cursor_positition), None - ) + return self.completer.get_completions(Document(text=text, cursor_position=cursor_positition), None) def get_prompt(self, string): # should be before replacing \\d @@ -1057,6 +1307,7 @@ def get_prompt(self, string): string = string.replace("\\i", str(self.pgexecute.pid) or "(none)") string = string.replace("\\#", "#" if self.pgexecute.superuser else ">") string = string.replace("\\n", "\n") + string = string.replace("\\T", self.pgexecute.transaction_indicator) return string def get_last_query(self): @@ -1067,10 +1318,7 @@ def is_too_wide(self, line): """Will this line be too wide to fit into terminal?""" if not self.prompt_app: return False - return ( - len(COLOR_CODE_REGEX.sub("", line)) - > self.prompt_app.output.get_size().columns - ) + return len(COLOR_CODE_REGEX.sub("", line)) > self.prompt_app.output.get_size().columns def is_too_tall(self, lines): """Are there too many lines to fit into terminal?""" @@ -1081,10 +1329,7 @@ def is_too_tall(self, lines): def echo_via_pager(self, text, color=None): if self.pgspecial.pager_config == PAGER_OFF or self.watch_command: click.echo(text, color=color) - elif ( - self.pgspecial.pager_config == PAGER_LONG_OUTPUT - and self.table_format != "csv" - ): + elif self.pgspecial.pager_config == PAGER_LONG_OUTPUT and 
self.table_format != "csv": lines = text.split("\n") # The last 4 lines are reserved for the pgcli menu and padding @@ -1097,7 +1342,7 @@ def echo_via_pager(self, text, color=None): @click.command() -# Default host is '' so psycopg2 can default to either localhost or unix socket +# Default host is '' so psycopg can default to either localhost or unix socket @click.option( "-h", "--host", @@ -1109,7 +1354,7 @@ def echo_via_pager(self, text, color=None): "-p", "--port", default=5432, - help="Port number at which the " "postgres instance is listening.", + help="Port number at which the postgres instance is listening.", envvar="PGPORT", type=click.INT, ) @@ -1119,9 +1364,7 @@ def echo_via_pager(self, text, color=None): "username_opt", help="Username to connect to the postgres database.", ) -@click.option( - "-u", "--user", "username_opt", help="Username to connect to the postgres database." -) +@click.option("-u", "--user", "username_opt", help="Username to connect to the postgres database.") @click.option( "-W", "--password", @@ -1174,6 +1417,12 @@ def echo_via_pager(self, text, color=None): type=click.INT, help="Set threshold for row limit prompt. 
Use 0 to disable prompt.", ) +@click.option( + "--application-name", + default="pgcli", + envvar="PGAPPNAME", + help="Application name for the connection.", +) @click.option( "--less-chatty", "less_chatty", @@ -1191,7 +1440,14 @@ def echo_via_pager(self, text, color=None): "--list", "list_databases", is_flag=True, - help="list " "available databases, then exit.", + help="list available databases, then exit.", +) +@click.option( + "--ping", + "ping_database", + is_flag=True, + default=False, + help="Check database connectivity, then exit.", ) @click.option( "--auto-vertical-output", @@ -1201,9 +1457,24 @@ def echo_via_pager(self, text, color=None): @click.option( "--warn", default=None, - type=click.Choice(["all", "moderate", "off"]), help="Warn before running a destructive query.", ) +@click.option( + "--ssh-tunnel", + default=None, + help="Open an SSH tunnel to the given address and connect to the database from it.", +) +@click.option( + "--log-file", + default=None, + help="Write all queries & output into a file, in addition to the normal output destination.", +) +@click.option( + "--init-command", + "init_command", + type=str, + help="SQL statement to execute after connecting.", +) @click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1) @click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1) def cli( @@ -1220,13 +1491,18 @@ def cli( pgclirc, dsn, row_limit, + application_name, less_chatty, prompt, prompt_dsn, list_databases, + ping_database, auto_vertical_output, list_dsn, warn, + ssh_tunnel: str, + init_command: str, + log_file: str, ): if version: print("Version:", __version__) @@ -1254,26 +1530,37 @@ def cli( for alias in cfg["alias_dsn"]: click.secho(alias + " : " + cfg["alias_dsn"][alias]) sys.exit(0) - except Exception as err: + except Exception: click.secho( - "Invalid DSNs found in the config file. " - 'Please check the "[alias_dsn]" section in pgclirc.', + "Invalid DSNs found in the config file. 
Please check the \"[alias_dsn]\" section in pgclirc.", err=True, fg="red", ) - exit(1) + sys.exit(1) + + if ssh_tunnel and not SSH_TUNNEL_SUPPORT: + click.secho( + 'Cannot open SSH tunnel, "sshtunnel" package was not found. ' + "Please install pgcli with `pip install pgcli[sshtunnel]` if you want SSH tunnel support.", + err=True, + fg="red", + ) + sys.exit(1) pgcli = PGCli( prompt_passwd, never_prompt, pgclirc_file=pgclirc, row_limit=row_limit, + application_name=application_name, single_connection=single_connection, less_chatty=less_chatty, prompt=prompt, prompt_dsn=prompt_dsn, auto_vertical_output=auto_vertical_output, warn=warn, + ssh_tunnel_url=ssh_tunnel, + log_file=log_file, ) # Choose which ever one has a valid value. @@ -1287,32 +1574,30 @@ def cli( service = database[8:] elif os.getenv("PGSERVICE") is not None: service = os.getenv("PGSERVICE") - # because option --list or -l are not supposed to have a db name - if list_databases: + # because option --ping, --list or -l are not supposed to have a db name + if list_databases or ping_database: database = "postgres" + cfg = load_config(pgclirc, config_full_path) if dsn != "": try: - cfg = load_config(pgclirc, config_full_path) dsn_config = cfg["alias_dsn"][dsn] except KeyError: click.secho( - f"Could not find a DSN with alias {dsn}. " - 'Please check the "[alias_dsn]" section in pgclirc.', + f"Could not find a DSN with alias {dsn}. Please check the \"[alias_dsn]\" section in pgclirc.", err=True, fg="red", ) - exit(1) + sys.exit(1) except Exception: click.secho( - "Invalid DSNs found in the config file. " - 'Please check the "[alias_dsn]" section in pgclirc.', + "Invalid DSNs found in the config file. 
Please check the \"[alias_dsn]\" section in pgclirc.", err=True, fg="red", ) - exit(1) - pgcli.connect_uri(dsn_config) + sys.exit(1) pgcli.dsn_alias = dsn + pgcli.connect_uri(dsn_config) elif "://" in database: pgcli.connect_uri(database) elif "=" in database and service is None: @@ -1322,6 +1607,80 @@ def cli( else: pgcli.connect(database, host, user, port) + if "use_local_timezone" not in cfg["main"] or cfg["main"].as_bool("use_local_timezone"): + server_tz = pgcli.pgexecute.get_timezone() + + def echo_error(msg: str): + click.secho( + "Failed to determine the local time zone", + err=True, + fg="yellow", + ) + click.secho( + msg, + err=True, + fg="yellow", + ) + click.secho( + f"Continuing with the default time zone as preset by the server ({server_tz})", + err=True, + fg="yellow", + ) + click.secho( + "Set `use_local_timezone = False` in the config to avoid trying to override the server time zone\n", + err=True, + dim=True, + ) + + local_tz = None + try: + local_tz = tzlocal.get_localzone_name() + + if local_tz is None: + echo_error("No local time zone configuration found\n") + else: + click.secho( + f"Using local time zone {local_tz} (server uses {server_tz})", + fg="green", + ) + click.secho( + "Use `set time zone ` to override, or set `use_local_timezone = False` in the config", + dim=True, + ) + + pgcli.pgexecute.set_timezone(local_tz) + except ZoneInfoNotFoundError as e: + # e.args[0] is the pre-formatted message which includes a list + # of conflicting sources + echo_error(e.args[0]) + + # Merge init-commands: global, DSN-specific, then CLI-provided + init_cmds = [] + # 1) Global init-commands + global_section = pgcli.config.get("init-commands", {}) + for _, val in global_section.items(): + if isinstance(val, (list, tuple)): + init_cmds.extend(val) + elif val: + init_cmds.append(val) + # 2) DSN-specific init-commands + if dsn: + alias_section = pgcli.config.get("alias_dsn.init-commands", {}) + if dsn in alias_section: + val = alias_section.get(dsn) + if 
isinstance(val, (list, tuple)): + init_cmds.extend(val) + elif val: + init_cmds.append(val) + # 3) CLI-provided init-command + if init_command: + init_cmds.append(init_command) + if init_cmds: + click.echo("Running init commands: %s" % "; ".join(init_cmds)) + for cmd in init_cmds: + # Execute each init command + list(pgcli.pgexecute.run(cmd)) + if list_databases: cur, headers, status = pgcli.pgexecute.full_databases() @@ -1332,8 +1691,22 @@ def cli( sys.exit(0) + if ping_database: + try: + list(pgcli.pgexecute.run("SELECT 1")) + except Exception: + click.secho( + "Could not connect to the database. Please check that the database is running.", + err=True, + fg="red", + ) + sys.exit(1) + else: + click.echo("PONG") + sys.exit(0) + pgcli.logger.debug( - "Launch Params: \n" "\tdatabase: %r" "\tuser: %r" "\thost: %r" "\tport: %r", + "Launch Params: \n\tdatabase: %r\tuser: %r\thost: %r\tport: %r", database, user, host, @@ -1351,9 +1724,7 @@ def obfuscate_process_password(): if "://" in process_title: process_title = re.sub(r":(.*):(.*)@", r":\1:xxxx@", process_title) elif "=" in process_title: - process_title = re.sub( - r"password=(.+?)((\s[a-zA-Z]+=)|$)", r"password=xxxx\2", process_title - ) + process_title = re.sub(r"password=(.+?)((\s[a-zA-Z]+=)|$)", r"password=xxxx\2", process_title) setproctitle.setproctitle(process_title) @@ -1405,17 +1776,83 @@ def is_select(status): return status.split(None, 1)[0].lower() == "select" -def exception_formatter(e): - return click.style(str(e), fg="red") +def diagnostic_output(diagnostic: Diagnostic) -> str: + fields = [] + + if diagnostic.severity is not None: + fields.append("Severity: " + diagnostic.severity) + + if diagnostic.severity_nonlocalized is not None: + fields.append("Severity (non-localized): " + diagnostic.severity_nonlocalized) + if diagnostic.sqlstate is not None: + fields.append("SQLSTATE code: " + diagnostic.sqlstate) -def format_output(title, cur, headers, status, settings): + if diagnostic.message_primary is not 
None: + fields.append("Message: " + diagnostic.message_primary) + + if diagnostic.message_detail is not None: + fields.append("Detail: " + diagnostic.message_detail) + + if diagnostic.message_hint is not None: + fields.append("Hint: " + diagnostic.message_hint) + + if diagnostic.statement_position is not None: + fields.append("Position: " + diagnostic.statement_position) + + if diagnostic.internal_position is not None: + fields.append("Internal position: " + diagnostic.internal_position) + + if diagnostic.internal_query is not None: + fields.append("Internal query: " + diagnostic.internal_query) + + if diagnostic.context is not None: + fields.append("Where: " + diagnostic.context) + + if diagnostic.schema_name is not None: + fields.append("Schema name: " + diagnostic.schema_name) + + if diagnostic.table_name is not None: + fields.append("Table name: " + diagnostic.table_name) + + if diagnostic.column_name is not None: + fields.append("Column name: " + diagnostic.column_name) + + if diagnostic.datatype_name is not None: + fields.append("Data type name: " + diagnostic.datatype_name) + + if diagnostic.constraint_name is not None: + fields.append("Constraint name: " + diagnostic.constraint_name) + + if diagnostic.source_file is not None: + fields.append("File: " + diagnostic.source_file) + + if diagnostic.source_line is not None: + fields.append("Line: " + diagnostic.source_line) + + if diagnostic.source_function is not None: + fields.append("Routine: " + diagnostic.source_function) + + return "\n".join(fields) + + +def exception_formatter(e, verbose_errors: bool = False): + s = str(e) + if verbose_errors: + s += "\n" + diagnostic_output(e.diag) + return click.style(s, fg="red") + + +def format_output(title, cur, headers, status, settings, explain_mode=False): output = [] expanded = settings.expanded or settings.table_format == "vertical" table_format = "vertical" if settings.expanded else settings.table_format max_width = settings.max_width case_function = 
settings.case_function - formatter = TabularOutputFormatter(format_name=table_format) + if explain_mode: + formatter = ExplainOutputFormatter(max_width or 100) + else: + formatter = TabularOutputFormatter(format_name=table_format) def format_array(val): if val is None: @@ -1427,12 +1864,18 @@ def format_array(val): def format_arrays(data, headers, **_): data = list(data) for row in data: - row[:] = [ - format_array(val) if isinstance(val, list) else val for val in row - ] + row[:] = [format_array(val) if isinstance(val, list) else val for val in row] return data, headers + def format_status(cur, status): + # redshift does not return rowcount as part of status. + # See https://github.com/dbcli/pgcli/issues/1320 + if cur and hasattr(cur, "rowcount") and cur.rowcount is not None: + if status and not status.endswith(str(cur.rowcount)): + status += " %s" % cur.rowcount + return status + output_kwargs = { "sep_title": "RECORD {n}", "sep_character": "-", @@ -1440,14 +1883,19 @@ def format_arrays(data, headers, **_): "missing_value": settings.missingval, "integer_format": settings.dcmlfmt, "float_format": settings.floatfmt, + "column_date_formats": settings.column_date_formats, "preprocessors": (format_numbers, format_arrays), "disable_numparse": True, "preserve_whitespace": True, "style": settings.style_output, + "max_field_width": settings.max_field_width, } if not settings.floatfmt: output_kwargs["preprocessors"] = (align_decimals,) + if settings.column_date_formats: + output_kwargs["preprocessors"] += (format_timestamps,) + if table_format == "csv": # The default CSV dialect is "excel" which is not handling newline values correctly # Nevertheless, we want to keep on using "excel" on Windows since it uses '\r\n' @@ -1467,15 +1915,11 @@ def format_arrays(data, headers, **_): if hasattr(cur, "description"): column_types = [] for d in cur.description: - if ( - d[1] in psycopg2.extensions.DECIMAL.values - or d[1] in psycopg2.extensions.FLOAT.values - ): + col_type = 
cur.adapters.types.get(d.type_code) + type_name = col_type.name if col_type else None + if type_name in ("numeric", "float4", "float8"): column_types.append(float) - if ( - d[1] == psycopg2.extensions.INTEGER.values - or d[1] in psycopg2.extensions.LONGINTEGER.values - ): + if type_name in ("int2", "int4", "int8"): column_types.append(int) else: column_types.append(str) @@ -1485,23 +1929,22 @@ def format_arrays(data, headers, **_): formatted = iter(formatted.splitlines()) first_line = next(formatted) formatted = itertools.chain([first_line], formatted) - if ( - not expanded - and max_width - and len(strip_ansi(first_line)) > max_width - and headers - ): + if not explain_mode and not expanded and max_width and len(strip_ansi(first_line)) > max_width and headers: formatted = formatter.format_output( - cur, headers, format_name="vertical", column_types=None, **output_kwargs + cur, + headers, + format_name="vertical", + column_types=column_types, + **output_kwargs, ) if isinstance(formatted, str): formatted = iter(formatted.splitlines()) output = itertools.chain(output, formatted) - # Only print the status if it's not None and we are not producing CSV - if status and table_format != "csv": - output = itertools.chain(output, [status]) + # Only print the status if it's not None + if status: + output = itertools.chain(output, [format_status(cur, status)]) return output @@ -1516,7 +1959,7 @@ def parse_service_info(service): elif os.getenv("PGSYSCONFDIR"): service_file = os.path.join(os.getenv("PGSYSCONFDIR"), ".pg_service.conf") else: - service_file = expanduser("~/.pg_service.conf") + service_file = os.path.expanduser("~/.pg_service.conf") if not service or not os.path.exists(service_file): # nothing to do return None, service_file @@ -1533,5 +1976,28 @@ def parse_service_info(service): return service_conf, service_file +def duration_in_words(duration_in_seconds: float) -> str: + if not duration_in_seconds: + return "0 seconds" + components = [] + hours, remainder = 
divmod(duration_in_seconds, 3600) + if hours > 1: + components.append(f"{int(hours)} hours") + elif hours == 1: + components.append("1 hour") + minutes, seconds = divmod(remainder, 60) + if minutes > 1: + components.append(f"{int(minutes)} minutes") + elif minutes == 1: + components.append("1 minute") + if seconds >= 2: + components.append(f"{int(seconds)} seconds") + elif seconds >= 1: + components.append("1 second") + elif seconds: + components.append(f"{round(seconds, 3)} second") + return " ".join(components) + + if __name__ == "__main__": cli() diff --git a/pgcli/packages/formatter/__init__.py b/pgcli/packages/formatter/__init__.py new file mode 100644 index 000000000..9bad5790a --- /dev/null +++ b/pgcli/packages/formatter/__init__.py @@ -0,0 +1 @@ +# coding=utf-8 diff --git a/pgcli/packages/formatter/sqlformatter.py b/pgcli/packages/formatter/sqlformatter.py new file mode 100644 index 000000000..6c4973db7 --- /dev/null +++ b/pgcli/packages/formatter/sqlformatter.py @@ -0,0 +1,67 @@ +# coding=utf-8 + +from pgcli.packages.parseutils.tables import extract_tables + + +supported_formats = ( + "sql-insert", + "sql-update", + "sql-update-1", + "sql-update-2", +) + +preprocessors = () + + +def escape_for_sql_statement(value): + if value is None: + return "NULL" + + if isinstance(value, bytes): + return f"X'{value.hex()}'" + + return "'{}'".format(value) + + +def adapter(data, headers, table_format=None, **kwargs): + tables = extract_tables(formatter.query) + if len(tables) > 0: + table = tables[0] + if table[0]: + table_name = "{}.{}".format(*table[:2]) + else: + table_name = table[1] + else: + table_name = "DUAL" + if table_format == "sql-insert": + h = '", "'.join(headers) + yield 'INSERT INTO "{}" ("{}") VALUES'.format(table_name, h) + prefix = " " + for d in data: + values = ", ".join(escape_for_sql_statement(v) for i, v in enumerate(d)) + yield "{}({})".format(prefix, values) + if prefix == " ": + prefix = ", " + yield ";" + if 
table_format.startswith("sql-update"): + s = table_format.split("-") + keys = 1 + if len(s) > 2: + keys = int(s[-1]) + for d in data: + yield 'UPDATE "{}" SET'.format(table_name) + prefix = " " + for i, v in enumerate(d[keys:], keys): + yield '{}"{}" = {}'.format(prefix, headers[i], escape_for_sql_statement(v)) + if prefix == " ": + prefix = ", " + f = '"{}" = {}' + where = (f.format(headers[i], escape_for_sql_statement(d[i])) for i in range(keys)) + yield "WHERE {};".format(" AND ".join(where)) + + +def register_new_formatter(TabularOutputFormatter): + global formatter + formatter = TabularOutputFormatter + for sql_format in supported_formats: + TabularOutputFormatter.register_new_formatter(sql_format, adapter, preprocessors, {"table_format": sql_format}) diff --git a/pgcli/packages/parseutils/__init__.py b/pgcli/packages/parseutils/__init__.py index 1acc008e0..76a930ce0 100644 --- a/pgcli/packages/parseutils/__init__.py +++ b/pgcli/packages/parseutils/__init__.py @@ -1,6 +1,17 @@ import sqlparse +BASE_KEYWORDS = [ + "drop", + "shutdown", + "delete", + "truncate", + "alter", + "unconditional_update", +] +ALL_KEYWORDS = BASE_KEYWORDS + ["update"] + + def query_starts_with(formatted_sql, prefixes): """Check if the query starts with any item from *prefixes*.""" prefixes = [prefix.lower() for prefix in prefixes] @@ -13,22 +24,33 @@ def query_is_unconditional_update(formatted_sql): return bool(tokens) and tokens[0] == "update" and "where" not in tokens -def query_is_simple_update(formatted_sql): - """Check if the query starts with UPDATE.""" - tokens = formatted_sql.split() - return bool(tokens) and tokens[0] == "update" - - -def is_destructive(queries, warning_level="all"): +def is_destructive(queries, keywords): """Returns if any of the queries in *queries* is destructive.""" - keywords = ("drop", "shutdown", "delete", "truncate", "alter") for query in sqlparse.split(queries): if query: formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip() - if 
query_starts_with(formatted_sql, keywords): - return True - if query_is_unconditional_update(formatted_sql): + if "unconditional_update" in keywords and query_is_unconditional_update(formatted_sql): return True - if warning_level == "all" and query_is_simple_update(formatted_sql): + if query_starts_with(formatted_sql, keywords): return True return False + + +def parse_destructive_warning(warning_level): + """Converts a deprecated destructive warning option to a list of command keywords.""" + if not warning_level: + return [] + + if not isinstance(warning_level, list): + if "," in warning_level: + return warning_level.split(",") + warning_level = [warning_level] + + return { + "true": ALL_KEYWORDS, + "false": [], + "all": ALL_KEYWORDS, + "moderate": BASE_KEYWORDS, + "off": [], + "": [], + }.get(warning_level[0], warning_level) diff --git a/pgcli/packages/parseutils/ctes.py b/pgcli/packages/parseutils/ctes.py index e1f908850..a6a364a02 100644 --- a/pgcli/packages/parseutils/ctes.py +++ b/pgcli/packages/parseutils/ctes.py @@ -17,7 +17,7 @@ def isolate_query_ctes(full_text, text_before_cursor): """Simplify a query by converting CTEs into table metadata objects""" if not full_text or not full_text.strip(): - return full_text, text_before_cursor, tuple() + return full_text, text_before_cursor, () ctes, remainder = extract_ctes(full_text) if not ctes: diff --git a/pgcli/packages/parseutils/meta.py b/pgcli/packages/parseutils/meta.py index 333cab559..df41cf4ee 100644 --- a/pgcli/packages/parseutils/meta.py +++ b/pgcli/packages/parseutils/meta.py @@ -1,8 +1,6 @@ from collections import namedtuple -_ColumnMetadata = namedtuple( - "ColumnMetadata", ["name", "datatype", "foreignkeys", "default", "has_default"] -) +_ColumnMetadata = namedtuple("ColumnMetadata", ["name", "datatype", "foreignkeys", "default", "has_default"]) def ColumnMetadata(name, datatype, foreignkeys=None, default=None, has_default=False): @@ -143,11 +141,7 @@ def arg(name, typ, num): num_args = len(args) 
num_defaults = len(self.arg_defaults) has_default = num + num_defaults >= num_args - default = ( - self.arg_defaults[num - num_args + num_defaults] - if has_default - else None - ) + default = self.arg_defaults[num - num_args + num_defaults] if has_default else None return ColumnMetadata(name, typ, [], default, has_default) return [arg(name, typ, num) for num, (name, typ) in enumerate(args)] diff --git a/pgcli/packages/parseutils/tables.py b/pgcli/packages/parseutils/tables.py index aaa676ccd..bf67df09c 100644 --- a/pgcli/packages/parseutils/tables.py +++ b/pgcli/packages/parseutils/tables.py @@ -3,16 +3,9 @@ from sqlparse.sql import IdentifierList, Identifier, Function from sqlparse.tokens import Keyword, DML, Punctuation -TableReference = namedtuple( - "TableReference", ["schema", "name", "alias", "is_function"] -) +TableReference = namedtuple("TableReference", ["schema", "name", "alias", "is_function"]) TableReference.ref = property( - lambda self: self.alias - or ( - self.name - if self.name.islower() or self.name[0] == '"' - else '"' + self.name + '"' - ) + lambda self: self.alias or (self.name if self.name.islower() or self.name[0] == '"' else '"' + self.name + '"') ) @@ -53,27 +46,19 @@ def extract_from_part(parsed, stop_at_punctuation=True): # Also 'SELECT * FROM abc JOIN def' will trigger this elif # condition. So we need to ignore the keyword JOIN and its variants # INNER JOIN, FULL OUTER JOIN, etc. 
- elif ( - item.ttype is Keyword - and (not item.value.upper() == "FROM") - and (not item.value.upper().endswith("JOIN")) - ): + elif item.ttype is Keyword and (not item.value.upper() == "FROM") and (not item.value.upper().endswith("JOIN")): tbl_prefix_seen = False else: yield item elif item.ttype is Keyword or item.ttype is Keyword.DML: item_val = item.value.upper() - if ( - item_val - in ( - "COPY", - "FROM", - "INTO", - "UPDATE", - "TABLE", - ) - or item_val.endswith("JOIN") - ): + if item_val in ( + "COPY", + "FROM", + "INTO", + "UPDATE", + "TABLE", + ) or item_val.endswith("JOIN"): tbl_prefix_seen = True # 'SELECT a, FROM abc' will detect FROM as part of the column list. # So this check here is necessary. @@ -120,15 +105,11 @@ def parse_identifier(item): try: schema_name = identifier.get_parent_name() real_name = identifier.get_real_name() - is_function = allow_functions and _identifier_is_function( - identifier - ) + is_function = allow_functions and _identifier_is_function(identifier) except AttributeError: continue if real_name: - yield TableReference( - schema_name, real_name, identifier.get_alias(), is_function - ) + yield TableReference(schema_name, real_name, identifier.get_alias(), is_function) elif isinstance(item, Identifier): schema_name, real_name, alias = parse_identifier(item) is_function = allow_functions and _identifier_is_function(item) @@ -143,7 +124,7 @@ def parse_identifier(item): # extract_tables is inspired from examples in the sqlparse lib. def extract_tables(sql): - """Extract the table names from an SQL statment. + """Extract the table names from an SQL statement. 
Returns a list of TableReference namedtuples diff --git a/pgcli/packages/parseutils/utils.py b/pgcli/packages/parseutils/utils.py index 034c96e92..6d577430a 100644 --- a/pgcli/packages/parseutils/utils.py +++ b/pgcli/packages/parseutils/utils.py @@ -79,9 +79,7 @@ def find_prev_keyword(sql, n_skip=0): logical_operators = ("AND", "OR", "NOT", "BETWEEN") for t in reversed(flattened): - if t.value == "(" or ( - t.is_keyword and (t.value.upper() not in logical_operators) - ): + if t.value == "(" or (t.is_keyword and (t.value.upper() not in logical_operators)): # Find the location of token t in the original parsed statement # We can't use parsed.token_index(t) because t may be a child token # inside a TokenList, in which case token_index throws an error diff --git a/pgcli/packages/pgliterals/pgliterals.json b/pgcli/packages/pgliterals/pgliterals.json index df00817a5..6828a4632 100644 --- a/pgcli/packages/pgliterals/pgliterals.json +++ b/pgcli/packages/pgliterals/pgliterals.json @@ -227,7 +227,7 @@ "ROWS": [], "SELECT": [], "SESSION": [], - "SET": [], + "SET": ["SEARCH_PATH TO"], "SHARE": [], "SHOW": [], "SIZE": [], diff --git a/pgcli/packages/prompt_utils.py b/pgcli/packages/prompt_utils.py index e8589def3..997b86e7f 100644 --- a/pgcli/packages/prompt_utils.py +++ b/pgcli/packages/prompt_utils.py @@ -3,7 +3,7 @@ from .parseutils import is_destructive -def confirm_destructive_query(queries, warning_level): +def confirm_destructive_query(queries, keywords, alias): """Check if the query is destructive and prompts the user to confirm. Returns: @@ -12,11 +12,13 @@ def confirm_destructive_query(queries, warning_level): * False if the query is destructive and the user doesn't want to proceed. """ - prompt_text = ( - "You're about to run a destructive command.\n" "Do you want to proceed? 
(y/n)" - ) - if is_destructive(queries, warning_level) and sys.stdin.isatty(): - return prompt(prompt_text, type=bool) + info = "You're about to run a destructive command" + if alias: + info += f" in {click.style(alias, fg='red')}" + + prompt_text = f"{info}.\nDo you want to proceed?" + if is_destructive(queries, keywords) and sys.stdin.isatty(): + return confirm(prompt_text) def confirm(*args, **kwargs): diff --git a/pgcli/packages/sqlcompletion.py b/pgcli/packages/sqlcompletion.py index 630530189..20194ab14 100644 --- a/pgcli/packages/sqlcompletion.py +++ b/pgcli/packages/sqlcompletion.py @@ -1,4 +1,3 @@ -import sys import re import sqlparse from collections import namedtuple @@ -27,16 +26,16 @@ Function = namedtuple("Function", ["schema", "table_refs", "usage"]) # For convenience, don't require the `usage` argument in Function constructor -Function.__new__.__defaults__ = (None, tuple(), None) -Table.__new__.__defaults__ = (None, tuple(), tuple()) -View.__new__.__defaults__ = (None, tuple()) -FromClauseItem.__new__.__defaults__ = (None, tuple(), tuple()) +Function.__new__.__defaults__ = (None, (), None) +Table.__new__.__defaults__ = (None, (), ()) +View.__new__.__defaults__ = (None, ()) +FromClauseItem.__new__.__defaults__ = (None, (), ()) Column = namedtuple( "Column", ["table_refs", "require_last_table", "local_tables", "qualifiable", "context"], ) -Column.__new__.__defaults__ = (None, None, tuple(), False, None) +Column.__new__.__defaults__ = (None, None, (), False, None) Keyword = namedtuple("Keyword", ["last_token"]) Keyword.__new__.__defaults__ = (None,) @@ -50,15 +49,11 @@ class SqlStatement: def __init__(self, full_text, text_before_cursor): self.identifier = None - self.word_before_cursor = word_before_cursor = last_word( - text_before_cursor, include="many_punctuations" - ) + self.word_before_cursor = word_before_cursor = last_word(text_before_cursor, include="many_punctuations") full_text = _strip_named_query(full_text) text_before_cursor = 
_strip_named_query(text_before_cursor) - full_text, text_before_cursor, self.local_tables = isolate_query_ctes( - full_text, text_before_cursor - ) + full_text, text_before_cursor, self.local_tables = isolate_query_ctes(full_text, text_before_cursor) self.text_before_cursor_including_last_word = text_before_cursor @@ -78,9 +73,7 @@ def __init__(self, full_text, text_before_cursor): else: parsed = sqlparse.parse(text_before_cursor) - full_text, text_before_cursor, parsed = _split_multiple_statements( - full_text, text_before_cursor, parsed - ) + full_text, text_before_cursor, parsed = _split_multiple_statements(full_text, text_before_cursor, parsed) self.full_text = full_text self.text_before_cursor = text_before_cursor @@ -98,9 +91,7 @@ def get_tables(self, scope="full"): If 'before', only tables before the cursor are returned. If not 'insert' and the stmt is an insert, the first table is skipped. """ - tables = extract_tables( - self.full_text if scope == "full" else self.text_before_cursor - ) + tables = extract_tables(self.full_text if scope == "full" else self.text_before_cursor) if scope == "insert": tables = tables[:1] elif self.is_insert(): @@ -119,9 +110,7 @@ def get_identifier_schema(self): return schema def reduce_to_prev_keyword(self, n_skip=0): - prev_keyword, self.text_before_cursor = find_prev_keyword( - self.text_before_cursor, n_skip=n_skip - ) + prev_keyword, self.text_before_cursor = find_prev_keyword(self.text_before_cursor, n_skip=n_skip) return prev_keyword @@ -222,9 +211,7 @@ def _split_multiple_statements(full_text, text_before_cursor, parsed): token1_idx = statement.token_index(token1) token2 = statement.token_next(token1_idx)[1] if token2 and token2.value.upper() == "FUNCTION": - full_text, text_before_cursor, statement = _statement_from_function( - full_text, text_before_cursor, statement - ) + full_text, text_before_cursor, statement = _statement_from_function(full_text, text_before_cursor, statement) return full_text, text_before_cursor, 
statement @@ -283,14 +270,13 @@ def suggest_special(text): return (Schema(), Function(schema=None, usage="special")) return (Schema(), rel_type(schema=None)) - if cmd in ["\\n", "\\ns", "\\nd"]: + if cmd in ["\\n", "\\ns", "\\nd", "\\nq"]: return (NamedQuery(),) return (Keyword(), Special()) def suggest_based_on_last_token(token, stmt): - if isinstance(token, str): token_v = token.lower() elif isinstance(token, Comparison): @@ -362,11 +348,7 @@ def suggest_based_on_last_token(token, stmt): # Get the token before the parens prev_tok = p.token_prev(len(p.tokens) - 1)[1] - if ( - prev_tok - and prev_tok.value - and prev_tok.value.lower().split(" ")[-1] == "using" - ): + if prev_tok and prev_tok.value and prev_tok.value.lower().split(" ")[-1] == "using": # tbl1 INNER JOIN tbl2 USING (col1, col2) tables = stmt.get_tables("before") @@ -380,7 +362,7 @@ def suggest_based_on_last_token(token, stmt): ) elif p.token_first().value.lower() == "select": - # If the lparen is preceeded by a space chances are we're about to + # If the lparen is preceded by a space chances are we're about to # do a sub-select. if last_word(stmt.text_before_cursor, "all_punctuations").startswith("("): return (Keyword(),) @@ -390,16 +372,21 @@ def suggest_based_on_last_token(token, stmt): # We're probably in a function argument list return _suggest_expression(token_v, stmt) elif token_v == "set": + # "set" for changing a run-time parameter + p = sqlparse.parse(stmt.text_before_cursor)[0] + is_first_token = p.token_first().value.upper() == token_v.upper() + if is_first_token: + return (Keyword(token_v.upper()),) + + # E.g. 
'UPDATE foo SET' return (Column(table_refs=stmt.get_tables(), local_tables=stmt.local_tables),) + elif token_v in ("select", "where", "having", "order by", "distinct"): return _suggest_expression(token_v, stmt) elif token_v == "as": # Don't suggest anything for aliases return () - elif (token_v.endswith("join") and token.is_keyword) or ( - token_v in ("copy", "from", "update", "into", "describe", "truncate") - ): - + elif (token_v.endswith("join") and token.is_keyword) or (token_v in ("copy", "from", "update", "into", "describe", "truncate")): schema = stmt.get_identifier_schema() tables = extract_tables(stmt.text_before_cursor) is_join = token_v.endswith("join") and token.is_keyword @@ -413,11 +400,7 @@ def suggest_based_on_last_token(token, stmt): suggest.insert(0, Schema()) if token_v == "from" or is_join: - suggest.append( - FromClauseItem( - schema=schema, table_refs=tables, local_tables=stmt.local_tables - ) - ) + suggest.append(FromClauseItem(schema=schema, table_refs=tables, local_tables=stmt.local_tables)) elif token_v == "truncate": suggest.append(Table(schema)) else: @@ -436,7 +419,6 @@ def suggest_based_on_last_token(token, stmt): try: prev = stmt.get_previous_token(token).value.lower() if prev in ("drop", "alter", "create", "create or replace"): - # Suggest functions from either the currently-selected schema or the # public schema if no schema has been specified suggest = [] @@ -450,7 +432,7 @@ def suggest_based_on_last_token(token, stmt): except ValueError: pass - return tuple() + return () elif token_v in ("table", "view"): # E.g. 'ALTER TABLE ' @@ -520,6 +502,9 @@ def suggest_based_on_last_token(token, stmt): return tuple(suggestions) elif token_v in {"alter", "create", "drop"}: return (Keyword(token_v.upper()),) + elif token_v == "to": + # E.g. 
'SET search_path TO' + return (Schema(),) elif token.is_keyword: # token is a keyword we haven't implemented any special handling for # go backwards in the query until we find one we do recognize @@ -556,14 +541,10 @@ def _suggest_expression(token_v, stmt): ) -def identifies(id, ref): +def identifies(table_id, ref): """Returns true if string `id` matches TableReference `ref`""" - return ( - id == ref.alias - or id == ref.name - or (ref.schema and (id == ref.schema + "." + ref.name)) - ) + return table_id == ref.alias or table_id == ref.name or (ref.schema and (table_id == ref.schema + "." + ref.name)) def _allow_join_condition(statement): diff --git a/pgcli/pgbuffer.py b/pgcli/pgbuffer.py index 706ed25fe..aba180c8f 100644 --- a/pgcli/pgbuffer.py +++ b/pgcli/pgbuffer.py @@ -22,6 +22,15 @@ def _is_complete(sql): """ +def safe_multi_line_mode(pgcli): + @Condition + def cond(): + _logger.debug('Multi-line mode state: "%s" / "%s"', pgcli.multi_line, pgcli.multiline_mode) + return pgcli.multi_line and (pgcli.multiline_mode == "safe") + + return cond + + def buffer_should_be_handled(pgcli): @Condition def cond(): @@ -37,14 +46,13 @@ def cond(): text = doc.text.strip() return ( - text.startswith("\\") # Special Command - or text.endswith(r"\e") # Special Command - or text.endswith(r"\G") # Ended with \e which should launch the editor - or _is_complete(text) # A complete SQL command - or (text == "exit") # Exit doesn't need semi-colon - or (text == "quit") # Quit doesn't need semi-colon - or (text == ":q") # To all the vim fans out there - or (text == "") # Just a plain enter without any text + text.startswith("\\") + or text.endswith((r"\e", r"\G")) + or _is_complete(text) + or text == "exit" + or text == "quit" + or text == ":q" + or text == "" # Just a plain enter without any text ) return cond diff --git a/pgcli/pgclirc b/pgcli/pgclirc index 15c10f5e0..35ff41c5a 100644 --- a/pgcli/pgclirc +++ b/pgcli/pgclirc @@ -9,6 +9,10 @@ smart_completion = True # visible.) 
wider_completion_menu = False +# Do not create new connections for refreshing completions; Equivalent to +# always running with the --single-connection flag. +always_use_single_connection = False + # Multi-line mode allows breaking up the sql statements into multiple lines. If # this is set to True, then the end of the statements must have a semi-colon. # If this is set to False then sql statements can't be split into multiple @@ -22,14 +26,23 @@ multi_line = False # a command. multi_line_mode = psql -# Destructive warning mode will alert you before executing a sql statement +# Destructive warning will alert you before executing a sql statement # that may cause harm to the database such as "drop table", "drop database", # "shutdown", "delete", or "update". -# Possible values: -# "all" - warn on data definition statements, server actions such as SHUTDOWN, DELETE or UPDATE -# "moderate" - skip warning on UPDATE statements, except for unconditional updates -# "off" - skip all warnings -destructive_warning = all +# You can pass a list of destructive commands or leave it empty if you want to skip all warnings. +# "unconditional_update" will warn you of update statements that don't have a where clause +destructive_warning = drop, shutdown, delete, truncate, alter, update, unconditional_update + +# When `destructive_warning` is on and the user declines to proceed with a +# destructive statement, the current transaction (if any) is left untouched, +# by default. When setting `destructive_warning_restarts_connection` to +# "True", the connection to the server is restarted. In that case, the +# transaction (if any) is rolled back. +destructive_warning_restarts_connection = False + +# When this option is on (and if `destructive_warning` is not empty), +# destructive statements are not executed when outside of a transaction. +destructive_statements_require_transaction = False # Enables expand mode, which is similar to `\x` in psql. 
expand = False @@ -37,9 +50,21 @@ expand = False # Enables auto expand mode, which is similar to `\x auto` in psql. auto_expand = False +# Auto-retry queries on connection failures and other operational errors. If +# False, will prompt to rerun the failed query instead of auto-retrying. +auto_retry_closed_connection = True + # If set to True, table suggestions will include a table alias generate_aliases = False +# Path to a json file that specifies specific table aliases to use when generate_aliases is set to True +# the format for this file should be: +# { +# "some_table_name": "desired_alias", +# "some_other_table_name": "another_alias" +# } +alias_map_file = + # log_file location. # In Unix/Linux: ~/.config/pgcli/log # In Windows: %USERPROFILE%\AppData\Local\dbcli\pgcli\log @@ -83,19 +108,27 @@ qualify_columns = if_more_than_one_table # When no schema is entered, only suggest objects in search_path search_path_filter = False -# Default pager. -# By default 'PAGER' environment variable is used -# pager = less -SRXF +# Default pager. See https://www.pgcli.com/pager for more information on settings. +# By default 'PAGER' environment variable is used. If the pager is less, and the 'LESS' +# environment variable is not set, then LESS='-SRXF' will be automatically set. +# pager = less # Timing of sql statements and table rendering. timing = True +# Hide the query text when executing named queries (\n ). +# Only the query results will be displayed. +# Can be toggled at runtime with \nq command. +hide_named_query_text = False + # Show/hide the informational toolbar with function keymap at the footer. show_bottom_toolbar = True # Table format. Possible values: psql, plain, simple, grid, fancy_grid, pipe, # ascii, double, github, orgtbl, rst, mediawiki, html, latex, latex_booktabs, -# textile, moinmoin, jira, vertical, tsv, csv. 
+# textile, moinmoin, jira, vertical, tsv, csv, sql-insert, sql-update, +# sql-update-1, sql-update-2 (formatter with sql-* prefix can format query +# output to executable insertion or updating sql). # Recommended: psql, fancy_grid and grid. table_format = psql @@ -119,9 +152,20 @@ on_error = STOP # Set threshold for row limit. Use 0 to disable limiting. row_limit = 1000 +# Truncate long text fields to this value for tabular display (does not apply to csv). +# Leave unset to disable truncation. Example: "max_field_width = " +# Be aware that formatting might get slow with values larger than 500 and tables with +# lots of records. +max_field_width = 500 + # Skip intro on startup and goodbye on exit less_chatty = False +# Show all Postgres error fields (as listed in +# https://www.postgresql.org/docs/current/protocol-error-fields.html). +# Can be toggled with \v. +verbose_errors = False + # Postgres prompt # \t - Current date and time # \u - Username @@ -132,7 +176,8 @@ less_chatty = False # \i - Postgres PID # \# - "@" sign if logged in as superuser, '>' in other case # \n - Newline -# \dsn_alias - name of dsn alias if -D option is used (empty otherwise) +# \T - Transaction status: '*' if in a valid transaction, '!' if in a failed transaction, '?' if disconnected, empty otherwise +# \dsn_alias - name of dsn connection string alias if -D option is used (empty otherwise) # \x1b[...m - insert ANSI escape sequence # eg: prompt = '\x1b[35m\u@\x1b[32m\h:\x1b[36m\d>' prompt = '\u@\h:\d> ' @@ -152,6 +197,10 @@ enable_pager = True # Use keyring to automatically save and load password in a secure manner keyring = True +# Automatically set the session time zone to the local time zone +# If unset, uses the server's time zone, which is the Postgres default +use_local_timezone = True + # Custom colors for the completion menu, toolbar, etc. 
[colors] completion-menu.completion.current = 'bg:#ffffff #000000' @@ -189,14 +238,38 @@ output.null = "#808080" # Named queries are queries you can execute by name. [named queries] +# ver = "SELECT version()" -# DSN to call by -D option +# Here's where you can provide a list of connection string aliases. +# You can use it by passing the -D option. `pgcli -D example_dsn` [alias_dsn] # example_dsn = postgresql://[user[:password]@][netloc][:port][/dbname] +# Initial commands to execute when connecting to any database. +[init-commands] +# example = "SET search_path TO myschema" + +# Initial commands to execute when connecting to a DSN alias. +[alias_dsn.init-commands] +# example_dsn = "SET search_path TO otherschema; SET timezone TO 'UTC'" + # Format for number representation # for decimal "d" - 12345678, ",d" - 12,345,678 # for float "g" - 123456.78, ",g" - 123,456.78 [data_formats] decimal = "" float = "" + +# Per column formats for date/timestamp columns +[column_date_formats] +# use strftime format, e.g. 
+# created = "%Y-%m-%d" + +# Per host ssh tunnel configuration +[ssh tunnels] +# ^example.*\.host$ = myuser:mypasswd@my.tunnel.com:4000 +# .*\.net = another.tunnel.com + +# Per dsn_alias ssh tunnel configuration +[dsn ssh tunnels] +# ^example_dsn$ = myuser:mypasswd@my.tunnel.com:4000 diff --git a/pgcli/pgcompleter.py b/pgcli/pgcompleter.py index 227e25c64..ced0f1687 100644 --- a/pgcli/pgcompleter.py +++ b/pgcli/pgcompleter.py @@ -1,6 +1,7 @@ +import json import logging import re -from itertools import count, repeat, chain +from itertools import count, chain import operator from collections import namedtuple, defaultdict, OrderedDict from cli_helpers.tabular_output import TabularOutputFormatter @@ -31,7 +32,6 @@ from .packages.parseutils.tables import TableReference from .packages.pgliterals.main import get_literals from .packages.prioritization import PrevalenceCounter -from .config import load_config, config_location _logger = logging.getLogger(__name__) @@ -47,30 +47,50 @@ def SchemaObject(name, schema=None, meta=None): _Candidate = namedtuple("Candidate", "completion prio meta synonyms prio2 display") -def Candidate( - completion, prio=None, meta=None, synonyms=None, prio2=None, display=None -): - return _Candidate( - completion, prio, meta, synonyms or [completion], prio2, display or completion - ) +def Candidate(completion, prio=None, meta=None, synonyms=None, prio2=None, display=None): + return _Candidate(completion, prio, meta, synonyms or [completion], prio2, display or completion) # Used to strip trailing '::some_type' from default-value expressions arg_default_type_strip_regex = re.compile(r"::[\w\.]+(\[\])?$") -normalize_ref = lambda ref: ref if ref[0] == '"' else '"' + ref.lower() + '"' +def normalize_ref(ref): + return ref if ref[0] == '"' else '"' + ref.lower() + '"' -def generate_alias(tbl): - """Generate a table alias, consisting of all upper-case letters in - the table name, or, if there are no upper-case letters, the first letter + - all letters 
preceded by _ - param tbl - unescaped name of the table to alias + +def generate_alias(tbl, alias_map=None): + """Generate a table alias. + + Given a table name will return an alias for that table using the first of + the following options there's a match for. + + 1. The predefined alias for table defined in the alias_map. + 2. All upper-case letters in the table name. + 3. The first letter of the table name and all letters preceded by _ + + :param tbl: unescaped name of the table to alias + :param alias_map: optional mapping of predefined table aliases """ - return "".join( - [l for l in tbl if l.isupper()] - or [l for l, prev in zip(tbl, "_" + tbl) if prev == "_" and l != "_"] - ) + if alias_map and tbl in alias_map: + return alias_map[tbl] + return "".join([l for l in tbl if l.isupper()] or [l for l, prev in zip(tbl, "_" + tbl) if prev == "_" and l != "_"]) + + +class InvalidMapFile(ValueError): + pass + + +def load_alias_map_file(path): + try: + with open(path) as fo: + alias_map = json.load(fo) + except FileNotFoundError as err: + raise InvalidMapFile(f"Cannot read alias_map_file - {err.filename} does not exist") + except json.JSONDecodeError: + raise InvalidMapFile(f"Cannot read alias_map_file - {path} is not valid json") + else: + return alias_map class PGCompleter(Completer): @@ -88,30 +108,24 @@ def __init__(self, smart_completion=True, pgspecial=None, settings=None): self.pgspecial = pgspecial self.prioritizer = PrevalenceCounter() settings = settings or {} - self.signature_arg_style = settings.get( - "signature_arg_style", "{arg_name} {arg_type}" - ) - self.call_arg_style = settings.get( - "call_arg_style", "{arg_name: <{max_arg_len}} := {arg_default}" - ) - self.call_arg_display_style = settings.get( - "call_arg_display_style", "{arg_name}" - ) + self.signature_arg_style = settings.get("signature_arg_style", "{arg_name} {arg_type}") + self.call_arg_style = settings.get("call_arg_style", "{arg_name: <{max_arg_len}} := {arg_default}") + 
self.call_arg_display_style = settings.get("call_arg_display_style", "{arg_name}") self.call_arg_oneliner_max = settings.get("call_arg_oneliner_max", 2) self.search_path_filter = settings.get("search_path_filter") self.generate_aliases = settings.get("generate_aliases") + alias_map_file = settings.get("alias_map_file") + if alias_map_file is not None: + self.alias_map = load_alias_map_file(alias_map_file) + else: + self.alias_map = None self.casing_file = settings.get("casing_file") self.insert_col_skip_patterns = [ - re.compile(pattern) - for pattern in settings.get( - "insert_col_skip_patterns", [r"^now\(\)$", r"^nextval\("] - ) + re.compile(pattern) for pattern in settings.get("insert_col_skip_patterns", [r"^now\(\)$", r"^nextval\("]) ] self.generate_casing_file = settings.get("generate_casing_file") self.qualify_columns = settings.get("qualify_columns", "if_more_than_one_table") - self.asterisk_column_order = settings.get( - "asterisk_column_order", "table_order" - ) + self.asterisk_column_order = settings.get("asterisk_column_order", "table_order") keyword_casing = settings.get("keyword_casing", "upper").lower() if keyword_casing not in ("upper", "lower", "auto"): @@ -127,11 +141,7 @@ def __init__(self, smart_completion=True, pgspecial=None, settings=None): self.all_completions = set(self.keywords + self.functions) def escape_name(self, name): - if name and ( - (not self.name_pattern.match(name)) - or (name.upper() in self.reserved_words) - or (name.upper() in self.functions) - ): + if name and ((not self.name_pattern.match(name)) or (name.upper() in self.reserved_words) or (name.upper() in self.functions)): name = '"%s"' % name return name @@ -157,7 +167,6 @@ def extend_keywords(self, additional_keywords): self.all_completions.update(additional_keywords) def extend_schemata(self, schemata): - # schemata is a list of schema names schemata = self.escaped_names(schemata) metadata = self.dbmetadata["tables"] @@ -198,9 +207,7 @@ def extend_relations(self, data, 
kind): try: metadata[schema][relname] = OrderedDict() except KeyError: - _logger.error( - "%r %r listed in unrecognized schema %r", kind, relname, schema - ) + _logger.error("%r %r listed in unrecognized schema %r", kind, relname, schema) self.all_completions.add(relname) def extend_columns(self, column_data, kind): @@ -226,7 +233,6 @@ def extend_columns(self, column_data, kind): self.all_completions.add(colname) def extend_functions(self, func_data): - # func_data is a list of function metadata namedtuples # dbmetadata['schema_name']['functions']['function_name'] should return @@ -260,7 +266,6 @@ def _refresh_arg_list_cache(self): } def extend_foreignkeys(self, fk_data): - # fk_data is a list of ForeignKey namedtuples, with fields # parentschema, childschema, parenttable, childtable, # parentcolumns, childcolumns @@ -276,14 +281,11 @@ def extend_foreignkeys(self, fk_data): childcol, parcol = e([fk.childcolumn, fk.parentcolumn]) childcolmeta = meta[childschema][childtable][childcol] parcolmeta = meta[parentschema][parenttable][parcol] - fk = ForeignKey( - parentschema, parenttable, parcol, childschema, childtable, childcol - ) + fk = ForeignKey(parentschema, parenttable, parcol, childschema, childtable, childcol) childcolmeta.foreignkeys.append(fk) parcolmeta.foreignkeys.append(fk) def extend_datatypes(self, type_data): - # dbmetadata['datatypes'][schema_name][type_name] should store type # metadata, such as composite type field names. Currently, we're not # storing any metadata beyond typename, so just store None @@ -423,12 +425,7 @@ def _match(item): # We also use the unescape_name to make sure quoted names have # the same priority as unquoted names. 
lexical_priority = ( - tuple( - 0 if c in " _" else -ord(c) - for c in self.unescape_name(item.lower()) - ) - + (1,) - + tuple(c for c in item) + tuple(0 if c in " _" else -ord(c) for c in self.unescape_name(item.lower())) + (1,) + tuple(c for c in item) ) item = self.case(item) @@ -466,9 +463,7 @@ def get_completions(self, document, complete_event, smart_completion=None): # If smart_completion is off then match any word that starts with # 'word_before_cursor'. if not smart_completion: - matches = self.find_matches( - word_before_cursor, self.all_completions, mode="strict" - ) + matches = self.find_matches(word_before_cursor, self.all_completions, mode="strict") completions = [m.completion for m in matches] return sorted(completions, key=operator.attrgetter("text")) @@ -491,77 +486,58 @@ def get_completions(self, document, complete_event, smart_completion=None): def get_column_matches(self, suggestion, word_before_cursor): tables = suggestion.table_refs - do_qualify = suggestion.qualifiable and { - "always": True, - "never": False, - "if_more_than_one_table": len(tables) > 1, - }[self.qualify_columns] - qualify = lambda col, tbl: ( - (tbl + "." + self.case(col)) if do_qualify else self.case(col) + do_qualify = ( + suggestion.qualifiable + and { + "always": True, + "never": False, + "if_more_than_one_table": len(tables) > 1, + }[self.qualify_columns] ) + + def qualify(col, tbl): + return (tbl + "." 
+ self.case(col)) if do_qualify else self.case(col) + _logger.debug("Completion column scope: %r", tables) scoped_cols = self.populate_scoped_cols(tables, suggestion.local_tables) def make_cand(name, ref): - synonyms = (name, generate_alias(self.case(name))) + synonyms = (name, generate_alias(self.case(name), alias_map=self.alias_map)) return Candidate(qualify(name, ref), 0, "column", synonyms) def flat_cols(): - return [ - make_cand(c.name, t.ref) - for t, cols in scoped_cols.items() - for c in cols - ] + return [make_cand(c.name, t.ref) for t, cols in scoped_cols.items() for c in cols] if suggestion.require_last_table: # require_last_table is used for 'tb11 JOIN tbl2 USING (...' which should # suggest only columns that appear in the last table and one more ltbl = tables[-1].ref - other_tbl_cols = { - c.name for t, cs in scoped_cols.items() if t.ref != ltbl for c in cs - } - scoped_cols = { - t: [col for col in cols if col.name in other_tbl_cols] - for t, cols in scoped_cols.items() - if t.ref == ltbl - } + other_tbl_cols = {c.name for t, cs in scoped_cols.items() if t.ref != ltbl for c in cs} + scoped_cols = {t: [col for col in cols if col.name in other_tbl_cols] for t, cols in scoped_cols.items() if t.ref == ltbl} lastword = last_word(word_before_cursor, include="most_punctuations") if lastword == "*": if suggestion.context == "insert": - def filter(col): + def _filter(col): if not col.has_default: return True - return not any( - p.match(col.default) for p in self.insert_col_skip_patterns - ) + return not any(p.match(col.default) for p in self.insert_col_skip_patterns) - scoped_cols = { - t: [col for col in cols if filter(col)] - for t, cols in scoped_cols.items() - } + scoped_cols = {t: [col for col in cols if _filter(col)] for t, cols in scoped_cols.items()} if self.asterisk_column_order == "alphabetic": for cols in scoped_cols.values(): cols.sort(key=operator.attrgetter("name")) - if ( - lastword != word_before_cursor - and len(tables) == 1 - and 
word_before_cursor[-len(lastword) - 1] == "." - ): + if lastword != word_before_cursor and len(tables) == 1 and word_before_cursor[-len(lastword) - 1] == ".": # User typed x.*; replicate "x." for all columns except the # first, which gets the original (as we only replace the "*"") sep = ", " + word_before_cursor[:-1] collist = sep.join(self.case(c.completion) for c in flat_cols()) else: - collist = ", ".join( - qualify(c.name, t.ref) for t, cs in scoped_cols.items() for c in cs - ) + collist = ", ".join(qualify(c.name, t.ref) for t, cs in scoped_cols.items() for c in cs) return [ Match( - completion=Completion( - collist, -1, display_meta="columns", display="*" - ), + completion=Completion(collist, -1, display_meta="columns", display="*"), priority=(1, 1, 1), ) ] @@ -576,7 +552,7 @@ def alias(self, tbl, tbls): tbl = self.case(tbl) tbls = {normalize_ref(t.ref) for t in tbls} if self.generate_aliases: - tbl = generate_alias(self.unescape_name(tbl)) + tbl = generate_alias(self.unescape_name(tbl), alias_map=self.alias_map) if normalize_ref(tbl) not in tbls: return tbl elif tbl[0] == '"': @@ -595,12 +571,7 @@ def get_join_matches(self, suggestion, word_before_cursor): other_tbls = {(t.schema, t.name) for t in list(cols)[:-1]} joins = [] # Iterate over FKs in existing tables to find potential joins - fks = ( - (fk, rtbl, rcol) - for rtbl, rcols in cols.items() - for rcol in rcols - for fk in rcol.foreignkeys - ) + fks = ((fk, rtbl, rcol) for rtbl, rcols in cols.items() for rcol in rcols for fk in rcol.foreignkeys) col = namedtuple("col", "schema tbl col") for fk, rtbl, rcol in fks: right = col(rtbl.schema, rtbl.name, rcol.name) @@ -612,31 +583,21 @@ def get_join_matches(self, suggestion, word_before_cursor): c = self.case if self.generate_aliases or normalize_ref(left.tbl) in refs: lref = self.alias(left.tbl, suggestion.table_refs) - join = "{0} {4} ON {4}.{1} = {2}.{3}".format( - c(left.tbl), c(left.col), rtbl.ref, c(right.col), lref - ) + join = "{0} {4} ON {4}.{1} = 
{2}.{3}".format(c(left.tbl), c(left.col), rtbl.ref, c(right.col), lref) else: - join = "{0} ON {0}.{1} = {2}.{3}".format( - c(left.tbl), c(left.col), rtbl.ref, c(right.col) - ) - alias = generate_alias(self.case(left.tbl)) + join = "{0} ON {0}.{1} = {2}.{3}".format(c(left.tbl), c(left.col), rtbl.ref, c(right.col)) + alias = generate_alias(self.case(left.tbl), alias_map=self.alias_map) synonyms = [ join, - "{0} ON {0}.{1} = {2}.{3}".format( - alias, c(left.col), rtbl.ref, c(right.col) - ), + "{0} ON {0}.{1} = {2}.{3}".format(alias, c(left.col), rtbl.ref, c(right.col)), ] # Schema-qualify if (1) new table in same schema as old, and old # is schema-qualified, or (2) new in other schema, except public if not suggestion.schema and ( - qualified[normalize_ref(rtbl.ref)] - and left.schema == right.schema - or left.schema not in (right.schema, "public") + qualified[normalize_ref(rtbl.ref)] and left.schema == right.schema or left.schema not in (right.schema, "public") ): join = left.schema + "." + join - prio = ref_prio[normalize_ref(rtbl.ref)] * 2 + ( - 0 if (left.schema, left.tbl) in other_tbls else 1 - ) + prio = ref_prio[normalize_ref(rtbl.ref)] * 2 + (0 if (left.schema, left.tbl) in other_tbls else 1) joins.append(Candidate(join, prio, "join", synonyms=synonyms)) return self.find_matches(word_before_cursor, joins, meta="join") @@ -669,9 +630,7 @@ def list_dict(pairs): # Turns [(a, b), (a, c)] into {a: [b, c]} # Tables that are closer to the cursor get higher prio ref_prio = {tbl.ref: num for num, tbl in enumerate(suggestion.table_refs)} # Map (schema, table, col) to tables - coldict = list_dict( - ((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref - ) + coldict = list_dict(((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref) # For each fk from the left table, generate a join condition if # the other table is also in the scope fks = ((fk, lcol.name) for lcol in lcols for fk in lcol.foreignkeys) @@ -694,7 +653,6 @@ def list_dict(pairs): # 
Turns [(a, b), (a, c)] into {a: [b, c]} return self.find_matches(word_before_cursor, conds, meta="join") def get_function_matches(self, suggestion, word_before_cursor, alias=False): - if suggestion.usage == "from": # Only suggest functions allowed in FROM clause @@ -703,24 +661,16 @@ def filt(f): not f.is_aggregate and not f.is_window and not f.is_extension - and ( - f.is_public - or f.schema_name in self.search_path - or f.schema_name == suggestion.schema - ) + and (f.is_public or f.schema_name in self.search_path or f.schema_name == suggestion.schema) ) else: alias = False def filt(f): - return not f.is_extension and ( - f.is_public or f.schema_name == suggestion.schema - ) + return not f.is_extension and (f.is_public or f.schema_name == suggestion.schema) - arg_mode = {"signature": "signature", "special": None}.get( - suggestion.usage, "call" - ) + arg_mode = {"signature": "signature", "special": None}.get(suggestion.usage, "call") # Function overloading means we way have multiple functions of the same # name at this point, so keep unique names only @@ -731,9 +681,7 @@ def filt(f): if not suggestion.schema and not suggestion.usage: # also suggest hardcoded functions using startswith matching - predefined_funcs = self.find_matches( - word_before_cursor, self.functions, mode="strict", meta="function" - ) + predefined_funcs = self.find_matches(word_before_cursor, self.functions, mode="strict", meta="function") matches.extend(predefined_funcs) return matches @@ -784,10 +732,7 @@ def _arg_list(self, func, usage): return "()" multiline = usage == "call" and len(args) > self.call_arg_oneliner_max max_arg_len = max(len(a.name) for a in args) if multiline else 0 - args = ( - self._format_arg(template, arg, arg_num + 1, max_arg_len) - for arg_num, arg in enumerate(args) - ) + args = (self._format_arg(template, arg, arg_num + 1, max_arg_len) for arg_num, arg in enumerate(args)) if multiline: return "(" + ",".join("\n " + a for a in args if a) + "\n)" else: @@ -821,7 +766,7 
@@ def _make_cand(self, tbl, do_alias, suggestion, arg_mode=None): cased_tbl = self.case(tbl.name) if do_alias: alias = self.alias(cased_tbl, suggestion.table_refs) - synonyms = (cased_tbl, generate_alias(cased_tbl)) + synonyms = (cased_tbl, generate_alias(cased_tbl, alias_map=self.alias_map)) maybe_alias = (" " + alias) if do_alias else "" maybe_schema = (self.case(tbl.schema) + ".") if tbl.schema else "" suffix = self._arg_list_cache[arg_mode][tbl.meta] if arg_mode else "" @@ -886,15 +831,11 @@ def get_keyword_matches(self, suggestion, word_before_cursor): else: keywords = [k.lower() for k in keywords] - return self.find_matches( - word_before_cursor, keywords, mode="strict", meta="keyword" - ) + return self.find_matches(word_before_cursor, keywords, mode="strict", meta="keyword") def get_path_matches(self, _, word_before_cursor): completer = PathCompleter(expanduser=True) - document = Document( - text=word_before_cursor, cursor_position=len(word_before_cursor) - ) + document = Document(text=word_before_cursor, cursor_position=len(word_before_cursor)) for c in completer.get_completions(document, None): yield Match(completion=c, priority=(0,)) @@ -915,18 +856,12 @@ def get_datatype_matches(self, suggestion, word_before_cursor): if not suggestion.schema: # Also suggest hardcoded types - matches.extend( - self.find_matches( - word_before_cursor, self.datatypes, mode="strict", meta="datatype" - ) - ) + matches.extend(self.find_matches(word_before_cursor, self.datatypes, mode="strict", meta="datatype")) return matches def get_namedquery_matches(self, _, word_before_cursor): - return self.find_matches( - word_before_cursor, NamedQueries.instance.list(), meta="named query" - ) + return self.find_matches(word_before_cursor, NamedQueries.instance.list(), meta="named query") suggestion_matchers = { FromClauseItem: get_from_clause_item_matches, @@ -1016,9 +951,7 @@ def populate_schema_objects(self, schema, obj_type): """ return [ - SchemaObject( - name=obj, 
schema=(self._maybe_schema(schema=sch, parent=schema)) - ) + SchemaObject(name=obj, schema=(self._maybe_schema(schema=sch, parent=schema))) for sch in self._get_schemas(obj_type, schema) for obj in self.dbmetadata[obj_type][sch].keys() ] diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py index a013b558c..2864c8645 100644 --- a/pgcli/pgexecute.py +++ b/pgcli/pgexecute.py @@ -1,148 +1,63 @@ +import ipaddress import logging -import select import traceback - +from collections import namedtuple +import re import pgspecial as special -import psycopg2 -import psycopg2.errorcodes -import psycopg2.extensions as ext -import psycopg2.extras +import psycopg +import psycopg.sql +from psycopg.conninfo import make_conninfo import sqlparse -from psycopg2.extensions import POLL_OK, POLL_READ, POLL_WRITE, make_dsn from .packages.parseutils.meta import FunctionMetadata, ForeignKey _logger = logging.getLogger(__name__) -# Cast all database input to unicode automatically. -# See http://initd.org/psycopg/docs/usage.html#unicode-handling for more info. -ext.register_type(ext.UNICODE) -ext.register_type(ext.UNICODEARRAY) -ext.register_type(ext.new_type((705,), "UNKNOWN", ext.UNICODE)) -# See https://github.com/dbcli/pgcli/issues/426 for more details. -# This registers a unicode type caster for datatype 'RECORD'. -ext.register_type(ext.new_type((2249,), "RECORD", ext.UNICODE)) +ViewDef = namedtuple("ViewDef", "nspname relname relkind viewdef reloptions checkoption") -# Cast bytea fields to text. By default, this will render as hex strings with -# Postgres 9+ and as escaped binary in earlier versions. -ext.register_type(ext.new_type((17,), "BYTEA_TEXT", psycopg2.STRING)) -# TODO: Get default timeout from pgclirc? -_WAIT_SELECT_TIMEOUT = 1 -_wait_callback_is_set = False +# we added this funcion to strip beginning comments +# because sqlparse didn't handle tem well. 
It won't be needed if sqlparse +# does parsing of this situation better -def _wait_select(conn): - """ - copy-pasted from psycopg2.extras.wait_select - the default implementation doesn't define a timeout in the select calls - """ - try: - while 1: - try: - state = conn.poll() - if state == POLL_OK: - break - elif state == POLL_READ: - select.select([conn.fileno()], [], [], _WAIT_SELECT_TIMEOUT) - elif state == POLL_WRITE: - select.select([], [conn.fileno()], [], _WAIT_SELECT_TIMEOUT) - else: - raise conn.OperationalError("bad state from poll: %s" % state) - except KeyboardInterrupt: - conn.cancel() - # the loop will be broken by a server error - continue - except OSError as e: - errno = e.args[0] - if errno != 4: - raise - except psycopg2.OperationalError: - pass +def remove_beginning_comments(command): + # Regular expression pattern to match comments + pattern = r"^(/\*.*?\*/|--.*?)(?:\n|$)" + # Find and remove all comments from the beginning + cleaned_command = command + comments = [] + match = re.match(pattern, cleaned_command, re.DOTALL) + while match: + comments.append(match.group()) + cleaned_command = cleaned_command[len(match.group()) :].lstrip() + match = re.match(pattern, cleaned_command, re.DOTALL) -def _set_wait_callback(is_virtual_database): - global _wait_callback_is_set - if _wait_callback_is_set: - return - _wait_callback_is_set = True - if is_virtual_database: - return - # When running a query, make pressing CTRL+C raise a KeyboardInterrupt - # See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/ - # See also https://github.com/psycopg/psycopg2/issues/468 - ext.set_wait_callback(_wait_select) + return [cleaned_command, comments] -def register_date_typecasters(connection): - """ - Casts date and timestamp values to string, resolves issues with out of - range dates (e.g. 
BC) which psycopg2 can't handle - """ +def register_typecasters(connection): + """Casts date and timestamp values to string, resolves issues with out-of-range + dates (e.g. BC) which psycopg can't handle""" + for forced_text_type in [ + "date", + "time", + "timestamp", + "timestamptz", + "bytea", + "json", + "jsonb", + ]: + connection.adapters.register_loader(forced_text_type, psycopg.types.string.TextLoader) - def cast_date(value, cursor): - return value - - cursor = connection.cursor() - cursor.execute("SELECT NULL::date") - if cursor.description is None: - return - date_oid = cursor.description[0][1] - cursor.execute("SELECT NULL::timestamp") - timestamp_oid = cursor.description[0][1] - cursor.execute("SELECT NULL::timestamp with time zone") - timestamptz_oid = cursor.description[0][1] - oids = (date_oid, timestamp_oid, timestamptz_oid) - new_type = psycopg2.extensions.new_type(oids, "DATE", cast_date) - psycopg2.extensions.register_type(new_type) - - -def register_json_typecasters(conn, loads_fn): - """Set the function for converting JSON data for a connection. - - Use the supplied function to decode JSON data returned from the database - via the given connection. The function should accept a single argument of - the data as a string encoded in the database's character encoding. - psycopg2's default handler for JSON data is json.loads. - http://initd.org/psycopg/docs/extras.html#json-adaptation - - This function attempts to register the typecaster for both JSON and JSONB - types. - - Returns a set that is a subset of {'json', 'jsonb'} indicating which types - (if any) were successfully registered. 
- """ - available = set() - - for name in ["json", "jsonb"]: - try: - psycopg2.extras.register_json(conn, loads=loads_fn, name=name) - available.add(name) - except (psycopg2.ProgrammingError, psycopg2.errors.ProtocolViolation): - pass - return available +# pg3: I don't know what is this +class ProtocolSafeCursor(psycopg.Cursor): + """This class wraps and suppresses Protocol Errors with pgbouncer database. + See https://github.com/dbcli/pgcli/pull/1097. + Pgbouncer database is a virtual database with its own set of commands.""" - -def register_hstore_typecaster(conn): - """ - Instead of using register_hstore() which converts hstore into a python - dict, we query the 'oid' of hstore which will be different for each - database and register a type caster that converts it to unicode. - http://initd.org/psycopg/docs/extras.html#psycopg2.extras.register_hstore - """ - with conn.cursor() as cur: - try: - cur.execute( - "select t.oid FROM pg_type t WHERE t.typname = 'hstore' and t.typisdefined" - ) - oid = cur.fetchone()[0] - ext.register_type(ext.new_type((oid,), "HSTORE", ext.UNICODE)) - except Exception: - pass - - -class ProtocolSafeCursor(psycopg2.extensions.cursor): def __init__(self, *args, **kwargs): self.protocol_error = False self.protocol_message = "" @@ -163,19 +78,22 @@ def fetchone(self): return (self.protocol_message,) return super().fetchone() - def execute(self, sql, args=None): + # def mogrify(self, query, params): + # args = [Literal(v).as_string(self.connection) for v in params] + # return query % tuple(args) + # + def execute(self, *args, **kwargs): try: - psycopg2.extensions.cursor.execute(self, sql, args) + super().execute(*args, **kwargs) self.protocol_error = False self.protocol_message = "" - except psycopg2.errors.ProtocolViolation as ex: + except psycopg.errors.ProtocolViolation as ex: self.protocol_error = True - self.protocol_message = ex.pgerror + self.protocol_message = str(ex) _logger.debug("%s: %s" % (ex.__class__.__name__, ex)) class 
PGExecute: - # The boolean argument to the current_schemas function indicates whether # implicit schemas, e.g. pg_catalog search_path_query = """ @@ -245,6 +163,7 @@ def __init__( host=None, port=None, dsn=None, + notify_callback=None, **kwargs, ): self._conn_params = {} @@ -257,6 +176,7 @@ def __init__( self.port = None self.server_version = None self.extra_args = None + self.notify_callback = notify_callback self.connect(database, user, password, host, port, dsn, **kwargs) self.reset_expanded = None @@ -279,11 +199,10 @@ def connect( dsn=None, **kwargs, ): - conn_params = self._conn_params.copy() new_params = { - "database": database, + "dbname": database, "user": user, "password": password, "host": host, @@ -296,15 +215,17 @@ def connect( new_params = {"dsn": new_params["dsn"], "password": new_params["password"]} if new_params["password"]: - new_params["dsn"] = make_dsn( - new_params["dsn"], password=new_params.pop("password") - ) + new_params["dsn"] = make_conninfo(new_params["dsn"], password=new_params.pop("password")) conn_params.update({k: v for k, v in new_params.items() if v}) - conn_params["cursor_factory"] = ProtocolSafeCursor - conn = psycopg2.connect(**conn_params) - conn.set_client_encoding("utf8") + if "dsn" in conn_params: + other_params = {k: v for k, v in conn_params.items() if k != "dsn"} + conn_info = make_conninfo(conn_params["dsn"], **other_params) + else: + conn_info = make_conninfo(**conn_params) + conn = psycopg.connect(conn_info) + conn.cursor_factory = ProtocolSafeCursor self._conn_params = conn_params if self.conn: @@ -312,22 +233,13 @@ def connect( self.conn = conn self.conn.autocommit = True + if self.notify_callback is not None: + self.conn.add_notify_handler(self.notify_callback) + # When we connect using a DSN, we don't really know what db, # user, etc. we connected to. Let's read it. # Note: moved this after setting autocommit because of #664. 
- libpq_version = psycopg2.__libpq_version__ - dsn_parameters = {} - if libpq_version >= 93000: - # use actual connection info from psycopg2.extensions.Connection.info - # as libpq_version > 9.3 is available and required dependency - dsn_parameters = conn.info.dsn_parameters - else: - try: - dsn_parameters = conn.get_dsn_parameters() - except Exception as x: - # https://github.com/dbcli/pgcli/issues/1110 - # PQconninfo not available in libpq < 9.3 - _logger.info("Exception in get_dsn_parameters: %r", x) + dsn_parameters = conn.info.get_parameters() if dsn_parameters: self.dbname = dsn_parameters.get("dbname") @@ -344,25 +256,24 @@ def connect( self.extra_args = kwargs if not self.host: - self.host = ( - "pgbouncer" - if self.is_virtual_database() - else self.get_socket_directory() - ) + self.host = "pgbouncer" if self.is_virtual_database() else self.get_socket_directory() - self.pid = conn.get_backend_pid() - self.superuser = conn.get_parameter_status("is_superuser") in ("on", "1") - self.server_version = conn.get_parameter_status("server_version") or "" + self.pid = conn.info.backend_pid + self.superuser = conn.info.parameter_status("is_superuser") in ("on", "1") + self.server_version = conn.info.parameter_status("server_version") or "" - _set_wait_callback(self.is_virtual_database()) + # _set_wait_callback(self.is_virtual_database()) if not self.is_virtual_database(): - register_date_typecasters(conn) - register_json_typecasters(self.conn, self._json_typecaster) - register_hstore_typecaster(self.conn) + register_typecasters(conn) @property def short_host(self): + try: + ipaddress.ip_address(self.host) + return self.host + except ValueError: + pass if "," in self.host: host, _, _ = self.host.partition(",") else: @@ -380,30 +291,33 @@ def _select_one(self, cur, sql): cur.execute(sql) return cur.fetchone() - def _json_typecaster(self, json_data): - """Interpret incoming JSON data as a string. 
- - The raw data is decoded using the connection's encoding, which defaults - to the database's encoding. - - See http://initd.org/psycopg/docs/connection.html#connection.encoding - """ - - return json_data - def failed_transaction(self): - status = self.conn.get_transaction_status() - return status == ext.TRANSACTION_STATUS_INERROR + return self.conn.info.transaction_status == psycopg.pq.TransactionStatus.INERROR def valid_transaction(self): - status = self.conn.get_transaction_status() - return ( - status == ext.TRANSACTION_STATUS_ACTIVE - or status == ext.TRANSACTION_STATUS_INTRANS - ) + status = self.conn.info.transaction_status + return status == psycopg.pq.TransactionStatus.ACTIVE or status == psycopg.pq.TransactionStatus.INTRANS + + def is_connection_closed(self): + return self.conn.info.transaction_status == psycopg.pq.TransactionStatus.UNKNOWN + + @property + def transaction_indicator(self): + if self.is_connection_closed(): + return "?" + if self.failed_transaction(): + return "!" + if self.valid_transaction(): + return "*" + return "" def run( - self, statement, pgspecial=None, exception_formatter=None, on_error_resume=False + self, + statement, + pgspecial=None, + exception_formatter=None, + on_error_resume=False, + explain_mode=False, ): """Execute the sql in the database and return the results. @@ -424,17 +338,37 @@ def run( # Remove spaces and EOL statement = statement.strip() if not statement: # Empty string - yield (None, None, None, None, statement, False, False) + yield None, None, None, None, statement, False, False + + # sql parse doesn't split on a comment first + special + # so we're going to do it + + removed_comments = [] + sqlarr = [] + cleaned_command = "" - # Split the sql into separate queries and run each one. 
- for sql in sqlparse.split(statement): + # could skip if statement doesn't match ^-- or ^/* + cleaned_command, removed_comments = remove_beginning_comments(statement) + + sqlarr = sqlparse.split(cleaned_command) + + # now re-add the beginning comments if there are any, so that they show up in + # log files etc when running these commands + + if len(removed_comments) > 0: + sqlarr = removed_comments + sqlarr + + # run each sql query + for sql in sqlarr: # Remove spaces, eol and semi-colons. sql = sql.rstrip(";") - sql = sqlparse.format(sql, strip_comments=True).strip() + sql = sqlparse.format(sql, strip_comments=False).strip() if not sql: continue try: - if pgspecial: + if explain_mode: + sql = self.explain_prefix() + sql + elif pgspecial: # \G is treated specially since we have to set the expanded output. if sql.endswith("\\G"): if not pgspecial.expanded_output: @@ -446,7 +380,7 @@ def run( _logger.debug("Trying a pgspecial command. sql: %r", sql) try: cur = self.conn.cursor() - except psycopg2.InterfaceError: + except psycopg.InterfaceError: # edge case when connection is already closed, but we # don't need cursor for special_cmd.arg_type == NO_QUERY. # See https://github.com/dbcli/pgcli/issues/1014. @@ -470,7 +404,7 @@ def run( # Not a special command, so execute as normal sql yield self.execute_normal_sql(sql) + (sql, True, False) - except psycopg2.DatabaseError as e: + except psycopg.DatabaseError as e: _logger.error("sql: %r, error: %r", sql, e) _logger.error("traceback: %r", traceback.format_exc()) @@ -490,7 +424,7 @@ def _must_raise(self, e): """Return true if e is an error that should not be caught in ``run``. An uncaught error will prompt the user to reconnect; as long as we - detect that the connection is stil open, we catch the error, as + detect that the connection is still open, we catch the error, as reconnecting won't solve that problem. :param e: DatabaseError. An exception raised while executing a query. 
@@ -503,13 +437,27 @@ def _must_raise(self, e): def execute_normal_sql(self, split_sql): """Returns tuple (title, rows, headers, status)""" _logger.debug("Regular sql statement. sql: %r", split_sql) - cur = self.conn.cursor() - cur.execute(split_sql) - # conn.notices persist between queies, we use pop to clear out the list title = "" - while len(self.conn.notices) > 0: - title = self.conn.notices.pop() + title + + def handle_notices(n): + nonlocal title + title = f"{title}" + if n.message_primary is not None: + title = f"{title}\n{n.message_primary}" + if n.message_detail is not None: + title = f"{title}\n{n.message_detail}" + + self.conn.add_notice_handler(handle_notices) + + if self.is_virtual_database() and "show help" in split_sql.lower(): + # see https://github.com/psycopg/psycopg/issues/303 + # special case "show help" in pgbouncer + res = self.conn.pgconn.exec_(split_sql.encode()) + return title, None, None, res.command_status.decode() + + cur = self.conn.cursor() + cur.execute(split_sql) # cur.description will be None for operations that do not return # rows. @@ -531,7 +479,7 @@ def search_path(self): _logger.debug("Search path query. sql: %r", self.search_path_query) cur.execute(self.search_path_query) return [x[0] for x in cur.fetchall()] - except psycopg2.ProgrammingError: + except psycopg.ProgrammingError: fallback = "SELECT * FROM current_schemas(true)" with self.conn.cursor() as cur: _logger.debug("Search path query. sql: %r", fallback) @@ -541,7 +489,6 @@ def search_path(self): def view_definition(self, spec): """Returns the SQL defining views described by `spec`""" - template = "CREATE OR REPLACE {6} VIEW {0}.{1} AS \n{3}" # 2: relkind, v or m (materialized) # 4: reloptions, null # 5: checkoption: local or cascaded @@ -550,11 +497,22 @@ def view_definition(self, spec): _logger.debug("View Definition Query. 
sql: %r\nspec: %r", sql, spec) try: cur.execute(sql, (spec,)) - except psycopg2.ProgrammingError: + except psycopg.ProgrammingError: raise RuntimeError(f"View {spec} does not exist.") - result = cur.fetchone() - view_type = "MATERIALIZED" if result[2] == "m" else "" - return template.format(*result + (view_type,)) + result = ViewDef(*cur.fetchone()) + if result.relkind == "m": + template = "CREATE OR REPLACE MATERIALIZED VIEW {name} AS \n{stmt}" + else: + template = "CREATE OR REPLACE VIEW {name} AS \n{stmt}" + return ( + psycopg.sql + .SQL(template) + .format( + name=psycopg.sql.Identifier(result.nspname, result.relname), + stmt=psycopg.sql.SQL(result.viewdef), + ) + .as_string(self.conn) + ) def function_definition(self, spec): """Returns the SQL defining functions described by `spec`""" @@ -566,7 +524,7 @@ def function_definition(self, spec): cur.execute(sql, (spec,)) result = cur.fetchone() return result[0] - except psycopg2.ProgrammingError: + except psycopg.ProgrammingError: raise RuntimeError(f"Function {spec} does not exist.") def schemata(self): @@ -590,9 +548,9 @@ def _relations(self, kinds=("r", "p", "f", "v", "m")): """ with self.conn.cursor() as cur: - sql = cur.mogrify(self.tables_query, [kinds]) - _logger.debug("Tables Query. sql: %r", sql) - cur.execute(sql) + # sql = cur.mogrify(self.tables_query, kinds) + # _logger.debug("Tables Query. sql: %r", sql) + cur.execute(self.tables_query, [kinds]) yield from cur def tables(self): @@ -618,7 +576,7 @@ def _columns(self, kinds=("r", "p", "f", "v", "m")): :return: list of (schema_name, relation_name, column_name, column_type) tuples """ - if self.conn.server_version >= 80400: + if self.conn.info.server_version >= 80400: columns_query = """ SELECT nsp.nspname schema_name, cls.relname table_name, @@ -659,9 +617,9 @@ def _columns(self, kinds=("r", "p", "f", "v", "m")): ORDER BY 1, 2, att.attnum""" with self.conn.cursor() as cur: - sql = cur.mogrify(columns_query, [kinds]) - _logger.debug("Columns Query. 
sql: %r", sql) - cur.execute(sql) + # sql = cur.mogrify(columns_query, kinds) + # _logger.debug("Columns Query. sql: %r", sql) + cur.execute(columns_query, [kinds]) yield from cur def table_columns(self): @@ -692,9 +650,7 @@ def is_protocol_error(self): def get_socket_directory(self): with self.conn.cursor() as cur: - _logger.debug( - "Socket directory Query. sql: %r", self.socket_directory_query - ) + _logger.debug("Socket directory Query. sql: %r", self.socket_directory_query) cur.execute(self.socket_directory_query) result = cur.fetchone() return result[0] if result else "" @@ -702,7 +658,7 @@ def get_socket_directory(self): def foreignkeys(self): """Yields ForeignKey named tuples""" - if self.conn.server_version < 90000: + if self.conn.info.server_version < 90000: return with self.conn.cursor() as cur: @@ -742,7 +698,7 @@ def foreignkeys(self): def functions(self): """Yields FunctionMetadata named tuples""" - if self.conn.server_version >= 110000: + if self.conn.info.server_version >= 110000: query = """ SELECT n.nspname schema_name, p.proname func_name, @@ -762,7 +718,7 @@ def functions(self): WHERE p.prorettype::regtype != 'trigger'::regtype ORDER BY 1, 2 """ - elif self.conn.server_version > 90000: + elif self.conn.info.server_version > 90000: query = """ SELECT n.nspname schema_name, p.proname func_name, @@ -782,7 +738,7 @@ def functions(self): WHERE p.prorettype::regtype != 'trigger'::regtype ORDER BY 1, 2 """ - elif self.conn.server_version >= 80400: + elif self.conn.info.server_version >= 80400: query = """ SELECT n.nspname schema_name, p.proname func_name, @@ -833,7 +789,7 @@ def datatypes(self): """Yields tuples of (schema_name, type_name)""" with self.conn.cursor() as cur: - if self.conn.server_version > 90000: + if self.conn.info.server_version > 90000: query = """ SELECT n.nspname schema_name, t.typname type_name @@ -921,3 +877,17 @@ def casing(self): cur.execute(query) for row in cur: yield row[0] + + def explain_prefix(self): + return "EXPLAIN 
(ANALYZE, COSTS, VERBOSE, BUFFERS, FORMAT JSON) " + + def get_timezone(self) -> str: + query = psycopg.sql.SQL("show time zone") + with self.conn.cursor() as cur: + cur.execute(query) + return cur.fetchone()[0] + + def set_timezone(self, timezone: str): + query = psycopg.sql.SQL("set time zone {}").format(psycopg.sql.Identifier(timezone)) + with self.conn.cursor() as cur: + cur.execute(query) diff --git a/pgcli/pgstyle.py b/pgcli/pgstyle.py index 822903705..450f2c8b9 100644 --- a/pgcli/pgstyle.py +++ b/pgcli/pgstyle.py @@ -83,13 +83,11 @@ def style_factory(name, cli_style): logger.error("Unhandled style / class name: %s", token) else: # treat as prompt style name (2.0). See default style names here: - # https://github.com/jonathanslenders/python-prompt-toolkit/blob/master/prompt_toolkit/styles/defaults.py + # https://github.com/prompt-toolkit/python-prompt-toolkit/blob/master/src/prompt_toolkit/styles/defaults.py prompt_styles.append((token, cli_style[token])) override_style = Style([("bottom-toolbar", "noreverse")]) - return merge_styles( - [style_from_pygments_cls(style), override_style, Style(prompt_styles)] - ) + return merge_styles([style_from_pygments_cls(style), override_style, Style(prompt_styles)]) def style_factory_output(name, cli_style): diff --git a/pgcli/pgtoolbar.py b/pgcli/pgtoolbar.py index 41f903dab..ae8bc67c9 100644 --- a/pgcli/pgtoolbar.py +++ b/pgcli/pgtoolbar.py @@ -1,18 +1,14 @@ -from pkg_resources import packaging - -import prompt_toolkit from prompt_toolkit.key_binding.vi_state import InputMode from prompt_toolkit.application import get_app -parse_version = packaging.version.parse - vi_modes = { InputMode.INSERT: "I", InputMode.NAVIGATION: "N", InputMode.REPLACE: "R", InputMode.INSERT_MULTIPLE: "M", } -if parse_version(prompt_toolkit.__version__) >= parse_version("3.0.6"): +# REPLACE_SINGLE is available in prompt_toolkit >= 3.0.6 +if "REPLACE_SINGLE" in {e.name for e in InputMode}: vi_modes[InputMode.REPLACE_SINGLE] = "R" @@ -41,26 +37,23 
@@ def get_toolbar_tokens(): if pgcli.multiline_mode == "safe": result.append(("class:bottom-toolbar", " ([Esc] [Enter] to execute]) ")) else: - result.append( - ("class:bottom-toolbar", " (Semi-colon [;] will end the line) ") - ) + result.append(("class:bottom-toolbar", " (Semi-colon [;] will end the line) ")) if pgcli.vi_mode: - result.append( - ("class:bottom-toolbar", "[F4] Vi-mode (" + _get_vi_mode() + ")") - ) + result.append(("class:bottom-toolbar", "[F4] Vi-mode (" + _get_vi_mode() + ") ")) + else: + result.append(("class:bottom-toolbar", "[F4] Emacs-mode ")) + + if pgcli.explain_mode: + result.append(("class:bottom-toolbar", "[F5] Explain: ON ")) else: - result.append(("class:bottom-toolbar", "[F4] Emacs-mode")) + result.append(("class:bottom-toolbar", "[F5] Explain: OFF ")) if pgcli.pgexecute.failed_transaction(): - result.append( - ("class:bottom-toolbar.transaction.failed", " Failed transaction") - ) + result.append(("class:bottom-toolbar.transaction.failed", " Failed transaction")) if pgcli.pgexecute.valid_transaction(): - result.append( - ("class:bottom-toolbar.transaction.valid", " Transaction") - ) + result.append(("class:bottom-toolbar.transaction.valid", " Transaction")) if pgcli.completion_refresher.is_refreshing(): result.append(("class:bottom-toolbar", " Refreshing completions...")) diff --git a/pgcli/pyev.py b/pgcli/pyev.py new file mode 100644 index 000000000..322060fa6 --- /dev/null +++ b/pgcli/pyev.py @@ -0,0 +1,417 @@ +import textwrap +import re +from click import style as color + +DESCRIPTIONS = { + "Append": "Used in a UNION to merge multiple record sets by appending them together.", + "Limit": "Returns a specified number of rows from a record set.", + "Sort": "Sorts a record set based on the specified sort key.", + "Nested Loop": "Merges two record sets by looping through every record in the first set and trying to find a match in the second set. 
All matching records are returned.", + "Merge Join": "Merges two record sets by first sorting them on a join key.", + "Hash": "Generates a hash table from the records in the input recordset. Hash is used by Hash Join.", + "Hash Join": "Joins to record sets by hashing one of them (using a Hash Scan).", + "Aggregate": "Groups records together based on a GROUP BY or aggregate function (e.g. sum()).", + "Hashaggregate": "Groups records together based on a GROUP BY or aggregate function (e.g. sum()). Hash Aggregate uses a hash to first organize the records by a key.", + "Sequence Scan": "Finds relevant records by sequentially scanning the input record set. When reading from a table, Seq Scans (unlike Index Scans) perform a single read operation (only the table is read).", + "Seq Scan": "Finds relevant records by sequentially scanning the input record set. When reading from a table, Seq Scans (unlike Index Scans) perform a single read operation (only the table is read).", + "Index Scan": "Finds relevant records based on an Index. Index Scans perform 2 read operations: one to read the index and another to read the actual value from the table.", + "Index Only Scan": "Finds relevant records based on an Index. Index Only Scans perform a single read operation from the index and do not read from the corresponding table.", + "Bitmap Heap Scan": "Searches through the pages returned by the Bitmap Index Scan for relevant rows.", + "Bitmap Index Scan": "Uses a Bitmap Index (index which uses 1 bit per page) to find all relevant pages. Results of this node are fed to the Bitmap Heap Scan.", + "CTEScan": "Performs a sequential scan of Common Table Expression (CTE) query results. Note that results of a CTE are materialized (calculated and temporarily stored).", + "ProjectSet": "ProjectSet appears when the SELECT or ORDER BY clause of the query. 
They basically just execute the set-returning function(s) for each tuple until none of the functions return any more records.", + "Result": "Returns result", +} + + +class Visualizer: + def __init__(self, terminal_width=100, color=True): + self.color = color + self.terminal_width = terminal_width + self.string_lines = [] + + def load(self, explain_dict): + self.plan = explain_dict.pop("Plan") + self.explain = explain_dict + self.process_all() + self.generate_lines() + + def process_all(self): + self.plan = self.process_plan(self.plan) + self.plan = self.calculate_outlier_nodes(self.plan) + + # + def process_plan(self, plan): + plan = self.calculate_planner_estimate(plan) + plan = self.calculate_actuals(plan) + self.calculate_maximums(plan) + # + for index in range(len(plan.get("Plans", []))): + _plan = plan["Plans"][index] + plan["Plans"][index] = self.process_plan(_plan) + return plan + + def prefix_format(self, v): + if self.color: + return color(v, fg="bright_black") + return v + + def tag_format(self, v): + if self.color: + return color(v, fg="white", bg="red") + return v + + def muted_format(self, v): + if self.color: + return color(v, fg="bright_black") + return v + + def bold_format(self, v): + if self.color: + return color(v, fg="white") + return v + + def good_format(self, v): + if self.color: + return color(v, fg="green") + return v + + def warning_format(self, v): + if self.color: + return color(v, fg="yellow") + return v + + def critical_format(self, v): + if self.color: + return color(v, fg="red") + return v + + def output_format(self, v): + if self.color: + return color(v, fg="cyan") + return v + + def calculate_planner_estimate(self, plan): + plan["Planner Row Estimate Factor"] = 0 + plan["Planner Row Estimate Direction"] = "Under" + + if plan["Plan Rows"] == plan["Actual Rows"]: + return plan + + if plan["Plan Rows"] != 0: + plan["Planner Row Estimate Factor"] = plan["Actual Rows"] / plan["Plan Rows"] + + if plan["Planner Row Estimate Factor"] < 10: 
+ plan["Planner Row Estimate Factor"] = 0 + plan["Planner Row Estimate Direction"] = "Over" + if plan["Actual Rows"] != 0: + plan["Planner Row Estimate Factor"] = plan["Plan Rows"] / plan["Actual Rows"] + return plan + + # + def calculate_actuals(self, plan): + plan["Actual Duration"] = plan["Actual Total Time"] + plan["Actual Cost"] = plan["Total Cost"] + + for child in plan.get("Plans", []): + if child["Node Type"] != "CTEScan": + plan["Actual Duration"] = plan["Actual Duration"] - child["Actual Total Time"] + plan["Actual Cost"] = plan["Actual Cost"] - child["Total Cost"] + + if plan["Actual Cost"] < 0: + plan["Actual Cost"] = 0 + + plan["Actual Duration"] = plan["Actual Duration"] * plan["Actual Loops"] + return plan + + def calculate_outlier_nodes(self, plan): + plan["Costliest"] = plan["Actual Cost"] == self.explain["Max Cost"] + plan["Largest"] = plan["Actual Rows"] == self.explain["Max Rows"] + plan["Slowest"] = plan["Actual Duration"] == self.explain["Max Duration"] + + for index in range(len(plan.get("Plans", []))): + _plan = plan["Plans"][index] + plan["Plans"][index] = self.calculate_outlier_nodes(_plan) + return plan + + def calculate_maximums(self, plan): + if not self.explain.get("Max Rows"): + self.explain["Max Rows"] = plan["Actual Rows"] + elif self.explain.get("Max Rows") < plan["Actual Rows"]: + self.explain["Max Rows"] = plan["Actual Rows"] + + if not self.explain.get("Max Cost"): + self.explain["Max Cost"] = plan["Actual Cost"] + elif self.explain.get("Max Cost") < plan["Actual Cost"]: + self.explain["Max Cost"] = plan["Actual Cost"] + + if not self.explain.get("Max Duration"): + self.explain["Max Duration"] = plan["Actual Duration"] + elif self.explain.get("Max Duration") < plan["Actual Duration"]: + self.explain["Max Duration"] = plan["Actual Duration"] + + if not self.explain.get("Total Cost"): + self.explain["Total Cost"] = plan["Actual Cost"] + elif self.explain.get("Total Cost") < plan["Actual Cost"]: + self.explain["Total Cost"] = 
plan["Actual Cost"] + + # + def duration_to_string(self, value): + if value < 1: + return self.good_format("<1 ms") + elif value < 100: + return self.good_format("%.2f ms" % value) + elif value < 1000: + return self.warning_format("%.2f ms" % value) + elif value < 60000: + return self.critical_format( + "%.2f s" % (value / 1000.0), + ) + else: + return self.critical_format( + "%.2f m" % (value / 60000.0), + ) + + # } + # + def format_details(self, plan): + details = [] + + if plan.get("Scan Direction"): + details.append(plan["Scan Direction"]) + + if plan.get("Strategy"): + details.append(plan["Strategy"]) + + if len(details) > 0: + return self.muted_format(" [%s]" % ", ".join(details)) + + return "" + + def format_tags(self, plan): + tags = [] + + if plan["Slowest"]: + tags.append(self.tag_format("slowest")) + if plan["Costliest"]: + tags.append(self.tag_format("costliest")) + if plan["Largest"]: + tags.append(self.tag_format("largest")) + if plan.get("Planner Row Estimate Factor", 0) >= 100: + tags.append(self.tag_format("bad estimate")) + + return " ".join(tags) + + def get_terminator(self, index, plan): + if index == 0: + if len(plan.get("Plans", [])) == 0: + return "⌡► " + else: + return "├► " + else: + if len(plan.get("Plans", [])) == 0: + return " " + else: + return "│ " + + def wrap_string(self, line, width): + if width == 0: + return [line] + return textwrap.wrap(line, width) + + def intcomma(self, value): + sep = "," + if not isinstance(value, str): + value = int(value) + + orig = str(value) + + new = re.sub(r"^(-?\d+)(\d{3})", rf"\g<1>{sep}\g<2>", orig) + if orig == new: + return new + else: + return self.intcomma(new) + + def output_fn(self, current_prefix, string): + return "%s%s" % (self.prefix_format(current_prefix), string) + + def create_lines(self, plan, prefix, depth, width, last_child): + current_prefix = prefix + self.string_lines.append(self.output_fn(current_prefix, self.prefix_format("│"))) + + joint = "├" + if last_child: + joint = "└" + # 
+ self.string_lines.append( + self.output_fn( + current_prefix, + "%s %s%s %s" + % ( + self.prefix_format(joint + "─⌠"), + self.bold_format(plan["Node Type"]), + self.format_details(plan), + self.format_tags(plan), + ), + ) + ) + # + if last_child: + prefix += " " + else: + prefix += "│ " + + current_prefix = prefix + "│ " + + cols = width - len(current_prefix) + + for line in self.wrap_string( + DESCRIPTIONS.get(plan["Node Type"], "Not found : %s" % plan["Node Type"]), + cols, + ): + self.string_lines.append(self.output_fn(current_prefix, "%s" % self.muted_format(line))) + # + if plan.get("Actual Duration"): + self.string_lines.append( + self.output_fn( + current_prefix, + "○ %s %s (%.0f%%)" + % ( + "Duration:", + self.duration_to_string(plan["Actual Duration"]), + (plan["Actual Duration"] / self.explain["Execution Time"]) * 100, + ), + ) + ) + + self.string_lines.append( + self.output_fn( + current_prefix, + "○ %s %s (%.0f%%)" + % ( + "Cost:", + self.intcomma(plan["Actual Cost"]), + (plan["Actual Cost"] / self.explain["Total Cost"]) * 100, + ), + ) + ) + + self.string_lines.append( + self.output_fn( + current_prefix, + "○ %s %s" % ("Rows:", self.intcomma(plan["Actual Rows"])), + ) + ) + + current_prefix = current_prefix + " " + + if plan.get("Join Type"): + self.string_lines.append( + self.output_fn( + current_prefix, + "%s %s" % (plan["Join Type"], self.muted_format("join")), + ) + ) + + if plan.get("Relation Name"): + self.string_lines.append( + self.output_fn( + current_prefix, + "%s %s.%s" + % ( + self.muted_format("on"), + plan.get("Schema", "unknown"), + plan["Relation Name"], + ), + ) + ) + + if plan.get("Index Name"): + self.string_lines.append( + self.output_fn( + current_prefix, + "%s %s" % (self.muted_format("using"), plan["Index Name"]), + ) + ) + + if plan.get("Index Condition"): + self.string_lines.append( + self.output_fn( + current_prefix, + "%s %s" % (self.muted_format("condition"), plan["Index Condition"]), + ) + ) + + if plan.get("Filter"): + 
self.string_lines.append( + self.output_fn( + current_prefix, + "%s %s %s" + % ( + self.muted_format("filter"), + plan["Filter"], + self.muted_format("[-%s rows]" % self.intcomma(plan["Rows Removed by Filter"])), + ), + ) + ) + + if plan.get("Hash Condition"): + self.string_lines.append( + self.output_fn( + current_prefix, + "%s %s" % (self.muted_format("on"), plan["Hash Condition"]), + ) + ) + + if plan.get("CTE Name"): + self.string_lines.append(self.output_fn(current_prefix, "CTE %s" % plan["CTE Name"])) + + if plan.get("Planner Row Estimate Factor") != 0: + self.string_lines.append( + self.output_fn( + current_prefix, + "%s %sestimated %s %.2fx" + % ( + self.muted_format("rows"), + plan["Planner Row Estimate Direction"], + self.muted_format("by"), + plan["Planner Row Estimate Factor"], + ), + ) + ) + + current_prefix = prefix + + if len(plan.get("Output", [])) > 0: + for index, line in enumerate(self.wrap_string(" + ".join(plan["Output"]), cols)): + self.string_lines.append( + self.output_fn( + current_prefix, + self.prefix_format(self.get_terminator(index, plan)) + self.output_format(line), + ) + ) + + for index, nested_plan in enumerate(plan.get("Plans", [])): + self.create_lines(nested_plan, prefix, depth + 1, width, index == len(plan["Plans"]) - 1) + + def generate_lines(self): + self.string_lines = [ + "○ Total Cost: %s" % self.intcomma(self.explain["Total Cost"]), + "○ Planning Time: %s" % self.duration_to_string(self.explain["Planning Time"]), + "○ Execution Time: %s" % self.duration_to_string(self.explain["Execution Time"]), + self.prefix_format("┬"), + ] + self.create_lines( + self.plan, + "", + 0, + self.terminal_width, + len(self.plan.get("Plans", [])) == 1, + ) + + def get_list(self): + return "\n".join(self.string_lines) + + def print(self): + for lin in self.string_lines: + print(lin) diff --git a/pyproject.toml b/pyproject.toml index c9bf518c7..6f3d28e4c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,22 +1,136 @@ -[tool.black] 
-line-length = 88 -target-version = ['py36'] -include = '\.pyi?$' -exclude = ''' -/( - \.eggs - | \.git - | \.hg - | \.mypy_cache - | \.tox - | \.venv - | \.cache - | \.pytest_cache - | _build - | buck-out - | build - | dist - | tests/data -)/ -''' +[project] +name = "pgcli" +authors = [{ name = "Pgcli Core Team", email = "pgcli-dev@googlegroups.com" }] +license = { text = "BSD" } +description = "CLI for Postgres Database. With auto-completion and syntax highlighting." +readme = "README.rst" +classifiers = [ + "Intended Audience :: Developers", + "Operating System :: Unix", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: SQL", + "Topic :: Database", + "Topic :: Database :: Front-Ends", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries :: Python Modules", +] +urls = { Homepage = "https://pgcli.com" } +requires-python = ">=3.9" +dependencies = [ + "pgspecial>=2.0.0", + # Click 8.1.8 through 8.3.0 have broken pager invocation for multi-argument PAGER values, which causes behave test failures. + "click >= 4.1, != 8.1.8, != 8.2.*, != 8.3.0, < 9", + "Pygments>=2.0", # Pygments has to be Capitalcased. + # We still need to use pt-2 unless pt-3 released on Fedora32 + # see: https://github.com/dbcli/pgcli/pull/1197 + "prompt_toolkit>=2.0.6,<4.0.0", + "psycopg >= 3.0.14; sys_platform != 'win32'", + "psycopg-binary >= 3.0.14; sys_platform == 'win32'", + "sqlparse >=0.3.0,<0.6", + "configobj >= 5.0.6", + "cli_helpers[styles] >= 2.4.0", + # setproctitle is used to mask the password when running `ps` in command line. + # But this is not necessary in Windows since the password is never shown in the + # task manager. 
Also setproctitle is a hard dependency to install in Windows, + # so we'll only install it if we're not in Windows. + "setproctitle >= 1.1.9; sys_platform != 'win32' and 'CYGWIN' not in sys_platform", + "tzlocal >= 5.2", +] +dynamic = ["version"] + +[project.scripts] +pgcli = "pgcli.main:cli" + +[project.optional-dependencies] +keyring = ["keyring >= 12.2.0"] +sshtunnel = ["sshtunnel >= 0.4.0"] +dev = [ + "behave>=1.2.4", + "coverage>=7.2.7", + "docutils>=0.13.1", + "keyrings.alt>=3.1", + "pexpect>=4.9.0; platform_system != 'Windows'", + "pytest>=7.4.4", + "pytest-cov>=4.1.0", + "ruff>=0.11.7", + "sshtunnel>=0.4.0", + "tox>=1.9.2", +] + +[build-system] +requires = ["setuptools>=64.0", "setuptools-scm>=8"] +build-backend = "setuptools.build_meta" + +[tool.setuptools_scm] + +[tool.setuptools] +include-package-data = false + +[tool.setuptools.dynamic] +version = { attr = "pgcli.__version__" } + +[tool.setuptools.packages] +find = { namespaces = false } + +[tool.setuptools.package-data] +pgcli = ["pgclirc", "packages/pgliterals/pgliterals.json"] + +[tool.ruff] +target-version = 'py39' +line-length = 140 +show-fixes = true + +[tool.ruff.lint] +select = [ + 'A', +# 'I', # todo enableme imports + 'E', + 'W', + 'F', + 'C4', + 'PIE', + 'TID', +] +ignore = [ + 'E401', # Multiple imports on one line + 'E402', # Module level import not at top of file + 'PIE808', # range() starting with 0 + # https://docs.astral.sh/ruff/formatter/#conflicting-lint-rules + 'E111', # indentation-with-invalid-multiple + 'E114', # indentation-with-invalid-multiple-comment + 'E117', # over-indented + 'W191', # tab-indentation + 'E741', # ambiguous-variable-name + # TODO + 'PIE796', # todo enableme Enum contains duplicate value +] +exclude = [ + 'pgcli/magic.py', + 'pgcli/pyev.py', +] + +[tool.ruff.lint.isort] +force-sort-within-sections = true +known-first-party = [ + 'pgcli', + 'tests', +] + +[tool.ruff.format] +preview = true +quote-style = 'preserve' +exclude = [ + 'build', +] + 
+[tool.pytest.ini_options] +minversion = "6.0" +addopts = "--capture=sys --showlocals -rxs" +testpaths = ["tests"] \ No newline at end of file diff --git a/release.py b/release.py index e83d23921..929393796 100644 --- a/release.py +++ b/release.py @@ -2,10 +2,10 @@ """A script to publish a release of pgcli to PyPI.""" import io -from optparse import OptionParser import re import subprocess import sys +from optparse import OptionParser import click @@ -45,9 +45,7 @@ def run_step(*args): def version(version_file): - _version_re = re.compile( - r'__version__\s+=\s+(?P[\'"])(?P.*)(?P=quote)' - ) + _version_re = re.compile(r'__version__\s+=\s+(?P[\'"])(?P.*)(?P=quote)') with io.open(version_file, encoding="utf-8") as f: ver = _version_re.search(f.read()).group("version") @@ -57,7 +55,7 @@ def version(version_file): def commit_for_release(version_file, ver): run_step("git", "reset") - run_step("git", "add", version_file) + run_step("git", "add", "-u") run_step("git", "commit", "--message", "Releasing version {}".format(ver)) @@ -66,7 +64,8 @@ def create_git_tag(tag_name): def create_distribution_files(): - run_step("python", "setup.py", "clean", "--all", "sdist", "bdist_wheel") + run_step("rm", "-rf", "dist/") + run_step("python", "-m", "build") def upload_distribution_files(): @@ -74,7 +73,7 @@ def upload_distribution_files(): def push_to_github(): - run_step("git", "push", "origin", "master") + run_step("git", "push", "origin", "main") def push_tags_to_github(): @@ -91,11 +90,11 @@ def checklist(questions): if DEBUG: subprocess.check_output = lambda x: x - checks = [ - "Have you updated the AUTHORS file?", - "Have you updated the `Usage` section of the README?", - ] - checklist(checks) + # checks = [ + # "Have you updated the AUTHORS file?", + # "Have you updated the `Usage` section of the README?", + # ] + # checklist(checks) ver = version("pgcli/__init__.py") print("Releasing Version:", ver) @@ -107,9 +106,7 @@ def checklist(questions): action="store_true", 
dest="confirm_steps", default=False, - help=( - "Confirm every step. If the step is not " "confirmed, it will be skipped." - ), + help=("Confirm every step. If the step is not confirmed, it will be skipped."), ) parser.add_option( "-d", diff --git a/release_procedure.txt b/release_procedure.txt deleted file mode 100644 index 9f3bff0ec..000000000 --- a/release_procedure.txt +++ /dev/null @@ -1,13 +0,0 @@ -# vi: ft=vimwiki - -* Bump the version number in pgcli/__init__.py -* Commit with message: 'Releasing version X.X.X.' -* Create a tag: git tag vX.X.X -* Fix the image url in PyPI to point to github raw content. https://raw.githubusercontent.com/dbcli/pgcli/master/screenshots/image01.png -* Create source dist tar ball: python setup.py sdist -* Test this by installing it in a fresh new virtualenv. Run SanityChecks [./sanity_checks.txt]. -* Upload the source dist to PyPI: https://pypi.python.org/pypi/pgcli -* pip install pgcli -* Run SanityChecks. -* Push the version back to github: git push --tags origin master -* Done! diff --git a/requirements-dev.txt b/requirements-dev.txt deleted file mode 100644 index 84fa6bf73..000000000 --- a/requirements-dev.txt +++ /dev/null @@ -1,13 +0,0 @@ -pytest>=2.7.0 -tox>=1.9.2 -behave>=1.2.4 -pexpect==3.3 -pre-commit>=1.16.0 -coverage==5.0.4 -codecov>=1.5.1 -docutils>=0.13.1 -autopep8==1.3.3 -click==6.7 -twine==1.11.0 -wheel==0.33.6 -prompt_toolkit==3.0.5 diff --git a/screenshots/kharkiv-destroyed.jpg b/screenshots/kharkiv-destroyed.jpg new file mode 100644 index 000000000..4f9578390 Binary files /dev/null and b/screenshots/kharkiv-destroyed.jpg differ diff --git a/setup.py b/setup.py deleted file mode 100644 index fc2103237..000000000 --- a/setup.py +++ /dev/null @@ -1,64 +0,0 @@ -import platform -from setuptools import setup, find_packages - -from pgcli import __version__ - -description = "CLI for Postgres Database. With auto-completion and syntax highlighting." 
- -install_requirements = [ - "pgspecial>=1.11.8", - "click >= 4.1", - "Pygments >= 2.0", # Pygments has to be Capitalcased. WTF? - # We still need to use pt-2 unless pt-3 released on Fedora32 - # see: https://github.com/dbcli/pgcli/pull/1197 - "prompt_toolkit>=2.0.6,<4.0.0", - "psycopg2 >= 2.8", - "sqlparse >=0.3.0,<0.5", - "configobj >= 5.0.6", - "pendulum>=2.1.0", - "cli_helpers[styles] >= 2.0.0", -] - - -# setproctitle is used to mask the password when running `ps` in command line. -# But this is not necessary in Windows since the password is never shown in the -# task manager. Also setproctitle is a hard dependency to install in Windows, -# so we'll only install it if we're not in Windows. -if platform.system() != "Windows" and not platform.system().startswith("CYGWIN"): - install_requirements.append("setproctitle >= 1.1.9") - -setup( - name="pgcli", - author="Pgcli Core Team", - author_email="pgcli-dev@googlegroups.com", - version=__version__, - license="BSD", - url="http://pgcli.com", - packages=find_packages(), - package_data={"pgcli": ["pgclirc", "packages/pgliterals/pgliterals.json"]}, - description=description, - long_description=open("README.rst").read(), - install_requires=install_requirements, - extras_require={"keyring": ["keyring >= 12.2.0"]}, - python_requires=">=3.6", - entry_points=""" - [console_scripts] - pgcli=pgcli.main:cli - """, - classifiers=[ - "Intended Audience :: Developers", - "License :: OSI Approved :: BSD License", - "Operating System :: Unix", - "Programming Language :: Python", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.6", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: SQL", - "Topic :: Database", - "Topic :: Database :: Front-Ends", - "Topic :: Software Development", - "Topic :: Software Development :: Libraries :: Python Modules", - ], -) diff --git a/tests/conftest.py 
b/tests/conftest.py index 33cddf247..e50f1fe07 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,7 @@ db_connection, drop_tables, ) +import pgcli.main import pgcli.pgexecute @@ -37,6 +38,7 @@ def executor(connection): password=POSTGRES_PASSWORD, port=POSTGRES_PORT, dsn=None, + notify_callback=pgcli.main.notify_callback, ) diff --git a/tests/features/basic_commands.feature b/tests/features/basic_commands.feature index 99f893e2f..7e975dcc8 100644 --- a/tests/features/basic_commands.feature +++ b/tests/features/basic_commands.feature @@ -23,10 +23,38 @@ Feature: run the cli, When we send "ctrl + d" then dbcli exits + Scenario: confirm exit when a transaction is ongoing + When we begin transaction + and we try to send "ctrl + d" + then we see ongoing transaction message + when we send "c" + then dbcli exits + + Scenario: cancel exit when a transaction is ongoing + When we begin transaction + and we try to send "ctrl + d" + then we see ongoing transaction message + when we send "a" + then we see dbcli prompt + when we rollback transaction + when we send "ctrl + d" + then dbcli exits + + Scenario: interrupt current query via "ctrl + c" + When we send sleep query + and we send "ctrl + c" + then we see cancelled query warning + when we check for any non-idle sleep queries + then we don't see any non-idle sleep queries + Scenario: list databases When we list databases then we see list of databases + Scenario: ping databases + When we ping the database + then we get a pong response + Scenario: run the cli with --username When we launch dbcli using --username and we send "\?" command @@ -49,7 +77,6 @@ Feature: run the cli, when we send "\?" 
command then we see help output - @wip Scenario: run the cli with dsn and password When we launch dbcli using dsn_password then we send password diff --git a/tests/features/crud_database.feature b/tests/features/crud_database.feature index ed13bbe02..87da4e392 100644 --- a/tests/features/crud_database.feature +++ b/tests/features/crud_database.feature @@ -5,7 +5,7 @@ Feature: manipulate databases: When we create database then we see database created when we drop database - then we confirm the destructive warning + then we respond to the destructive warning: y then we see database dropped when we connect to dbserver then we see database connected diff --git a/tests/features/crud_table.feature b/tests/features/crud_table.feature index 1f9db4a03..8a43c5c0a 100644 --- a/tests/features/crud_table.feature +++ b/tests/features/crud_table.feature @@ -8,15 +8,38 @@ Feature: manipulate tables: then we see table created when we insert into table then we see record inserted + when we select from table + then we see data selected: initial when we update table then we see record updated when we select from table - then we see data selected + then we see data selected: updated when we delete from table - then we confirm the destructive warning + then we respond to the destructive warning: y then we see record deleted when we drop table - then we confirm the destructive warning + then we respond to the destructive warning: y then we see table dropped when we connect to dbserver then we see database connected + + Scenario: transaction handling, with cancelling on a destructive warning. 
+ When we connect to test database + then we see database connected + when we create table + then we see table created + when we begin transaction + then we see transaction began + when we insert into table + then we see record inserted + when we delete from table + then we respond to the destructive warning: n + when we select from table + then we see data selected: initial + when we rollback transaction + then we see transaction rolled back + when we select from table + then we see select output without data + when we drop table + then we respond to the destructive warning: y + then we see table dropped diff --git a/tests/features/db_utils.py b/tests/features/db_utils.py index 6898394e3..db7f017f1 100644 --- a/tests/features/db_utils.py +++ b/tests/features/db_utils.py @@ -1,10 +1,7 @@ -from psycopg2 import connect -from psycopg2.extensions import AsIs +from psycopg import connect -def create_db( - hostname="localhost", username=None, password=None, dbname=None, port=None -): +def create_db(hostname="localhost", username=None, password=None, dbname=None, port=None): """Create test database. :param hostname: string @@ -17,13 +14,10 @@ def create_db( """ cn = create_cn(hostname, password, username, "postgres", port) - # ISOLATION_LEVEL_AUTOCOMMIT = 0 - # Needed for DB creation. 
- cn.set_isolation_level(0) - + cn.autocommit = True with cn.cursor() as cr: - cr.execute("drop database if exists %s", (AsIs(dbname),)) - cr.execute("create database %s", (AsIs(dbname),)) + cr.execute(f"drop database if exists {dbname}") + cr.execute(f"create database {dbname}") cn.close() @@ -40,14 +34,25 @@ def create_cn(hostname, password, username, dbname, port): :param dbname: string :return: psycopg2.connection """ - cn = connect( - host=hostname, user=username, database=dbname, password=password, port=port - ) + cn = connect(host=hostname, user=username, dbname=dbname, password=password, port=port) - print(f"Created connection: {cn.dsn}.") + print(f"Created connection: {cn.info.get_parameters()}.") return cn +def pgbouncer_available(hostname="localhost", password=None, username="postgres"): + cn = None + try: + cn = create_cn(hostname, password, username, "pgbouncer", 6432) + return True + except Exception: + print("Pgbouncer is not available.") + finally: + if cn: + cn.close() + return False + + def drop_db(hostname="localhost", username=None, password=None, dbname=None, port=None): """ Drop database. @@ -58,12 +63,11 @@ def drop_db(hostname="localhost", username=None, password=None, dbname=None, por """ cn = create_cn(hostname, password, username, "postgres", port) - # ISOLATION_LEVEL_AUTOCOMMIT = 0 # Needed for DB drop. 
- cn.set_isolation_level(0) + cn.autocommit = True with cn.cursor() as cr: - cr.execute("drop database if exists %s", (AsIs(dbname),)) + cr.execute(f"drop database if exists {dbname}") close_cn(cn) @@ -74,5 +78,6 @@ def close_cn(cn=None): :param connection: psycopg2.connection """ if cn: + cn_params = cn.info.get_parameters() cn.close() - print(f"Closed connection: {cn.dsn}.") + print(f"Closed connection: {cn_params}.") diff --git a/tests/features/environment.py b/tests/features/environment.py index 215c85cd5..a6cde7021 100644 --- a/tests/features/environment.py +++ b/tests/features/environment.py @@ -1,13 +1,13 @@ import copy import os +import shutil +import signal import sys +import tempfile + import db_utils as dbutils import fixture_utils as fixutils import pexpect -import tempfile -import shutil -import signal - from steps import wrappers @@ -22,17 +22,13 @@ def before_all(context): os.environ["VISUAL"] = "ex" os.environ["PROMPT_TOOLKIT_NO_CPR"] = "1" - context.package_root = os.path.abspath( - os.path.dirname(os.path.dirname(os.path.dirname(__file__))) - ) + context.package_root = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) fixture_dir = os.path.join(context.package_root, "tests/features/fixture_data") print("package root:", context.package_root) print("fixture dir:", fixture_dir) - os.environ["COVERAGE_PROCESS_START"] = os.path.join( - context.package_root, ".coveragerc" - ) + os.environ["COVERAGE_PROCESS_START"] = os.path.join(context.package_root, ".coveragerc") context.exit_sent = False @@ -42,30 +38,20 @@ def before_all(context): # Store get params from config. 
context.conf = { - "host": context.config.userdata.get( - "pg_test_host", os.getenv("PGHOST", "localhost") - ), - "user": context.config.userdata.get( - "pg_test_user", os.getenv("PGUSER", "postgres") - ), - "pass": context.config.userdata.get( - "pg_test_pass", os.getenv("PGPASSWORD", None) - ), - "port": context.config.userdata.get( - "pg_test_port", os.getenv("PGPORT", "5432") - ), + "host": context.config.userdata.get("pg_test_host", os.getenv("PGHOST", "localhost")), + "user": context.config.userdata.get("pg_test_user", os.getenv("PGUSER", "postgres")), + "pass": context.config.userdata.get("pg_test_pass", os.getenv("PGPASSWORD", None)), + "port": context.config.userdata.get("pg_test_port", os.getenv("PGPORT", "5432")), "cli_command": ( context.config.userdata.get("pg_cli_command", None) or '{python} -c "{startup}"'.format( python=sys.executable, - startup="; ".join( - [ - "import coverage", - "coverage.process_startup()", - "import pgcli.main", - "pgcli.main.cli(auto_envvar_prefix='BEHAVE')", - ] - ), + startup="; ".join([ + "import coverage", + "coverage.process_startup()", + "import pgcli.main", + "pgcli.main.cli(auto_envvar_prefix='BEHAVE')", + ]), ) ), "dbname": db_name_full, @@ -111,7 +97,11 @@ def before_all(context): context.conf["dbname"], context.conf["port"], ) - + context.pgbouncer_available = dbutils.pgbouncer_available( + hostname=context.conf["host"], + password=context.conf["pass"], + username=context.conf["user"], + ) context.fixture_data = fixutils.read_fixture_files() # use temporary directory as config home @@ -145,7 +135,7 @@ def after_all(context): context.conf["port"], ) - # Remove temp config direcotry + # Remove temp config directory shutil.rmtree(context.env_config_home) # Restore env vars. 
@@ -160,11 +150,38 @@ def before_step(context, _): context.atprompt = False +def is_known_problem(scenario): + """TODO: can we fix this?""" + return scenario.name in ( + 'interrupt current query via "ctrl + c"', + "run the cli with --username", + "run the cli with --user", + "run the cli with --port", + "confirm exit when a transaction is ongoing", + "cancel exit when a transaction is ongoing", + "run the cli and exit", + ) + + def before_scenario(context, scenario): if scenario.name == "list databases": # not using the cli for that return - wrappers.run_cli(context) + if is_known_problem(scenario): + scenario.skip() + currentdb = None + if "pgbouncer" in scenario.feature.tags: + if context.pgbouncer_available: + os.environ["PGDATABASE"] = "pgbouncer" + os.environ["PGPORT"] = "6432" + currentdb = "pgbouncer" + else: + scenario.skip() + else: + # set env vars back to normal test database + os.environ["PGDATABASE"] = context.conf["dbname"] + os.environ["PGPORT"] = context.conf["port"] + wrappers.run_cli(context, currentdb=currentdb) wrappers.wait_prompt(context) @@ -172,13 +189,17 @@ def after_scenario(context, scenario): """Cleans up after each scenario completes.""" if hasattr(context, "cli") and context.cli and not context.exit_sent: # Quit nicely. 
- if not context.atprompt: + if not getattr(context, "atprompt", False): dbname = context.currentdb - context.cli.expect_exact(f"{dbname}> ", timeout=15) - context.cli.sendcontrol("c") - context.cli.sendcontrol("d") + context.cli.expect_exact(f"{dbname}>", timeout=5) + try: + context.cli.sendcontrol("c") + context.cli.sendcontrol("d") + except Exception as x: + print("Failed cleanup after scenario:") + print(x) try: - context.cli.expect_exact(pexpect.EOF, timeout=15) + context.cli.expect_exact(pexpect.EOF, timeout=5) except pexpect.TIMEOUT: print(f"--- after_scenario {scenario.name}: kill cli") context.cli.kill(signal.SIGKILL) diff --git a/tests/features/expanded.feature b/tests/features/expanded.feature index 4f381f81d..e4860486a 100644 --- a/tests/features/expanded.feature +++ b/tests/features/expanded.feature @@ -7,7 +7,7 @@ Feature: expanded mode: and we select from table then we see expanded data selected when we drop table - then we confirm the destructive warning + then we respond to the destructive warning: y then we see table dropped Scenario: expanded off @@ -16,7 +16,7 @@ Feature: expanded mode: and we select from table then we see nonexpanded data selected when we drop table - then we confirm the destructive warning + then we respond to the destructive warning: y then we see table dropped Scenario: expanded auto @@ -25,5 +25,5 @@ Feature: expanded mode: and we select from table then we see auto data selected when we drop table - then we confirm the destructive warning + then we respond to the destructive warning: y then we see table dropped diff --git a/tests/features/pgbouncer.feature b/tests/features/pgbouncer.feature new file mode 100644 index 000000000..14cc5ad8a --- /dev/null +++ b/tests/features/pgbouncer.feature @@ -0,0 +1,12 @@ +@pgbouncer +Feature: run pgbouncer, + call the help command, + exit the cli + + Scenario: run "show help" command + When we send "show help" command + then we see the pgbouncer help output + + Scenario: run the cli and 
exit + When we send "ctrl + d" + then dbcli exits diff --git a/tests/features/steps/auto_vertical.py b/tests/features/steps/auto_vertical.py index 1643ea5e2..d7cdccd43 100644 --- a/tests/features/steps/auto_vertical.py +++ b/tests/features/steps/auto_vertical.py @@ -24,11 +24,11 @@ def step_see_small_results(context): context, dedent( """\ - +------------+\r - | ?column? |\r - |------------|\r - | 1 |\r - +------------+\r + +----------+\r + | ?column? |\r + |----------|\r + | 1 |\r + +----------+\r SELECT 1\r """ ), diff --git a/tests/features/steps/basic_commands.py b/tests/features/steps/basic_commands.py index 07e9ec174..24c7e1e2f 100644 --- a/tests/features/steps/basic_commands.py +++ b/tests/features/steps/basic_commands.py @@ -26,6 +26,19 @@ def step_see_list_databases(context): context.cmd_output = None +@when("we ping the database") +def step_ping_database(context): + cmd = ["pgcli", "--ping"] + context.cmd_output = subprocess.check_output(cmd, cwd=context.package_root) + + +@then("we get a pong response") +def step_get_pong_response(context): + # exit code 0 is implied by the presence of cmd_output here, which + # is only set on a successful run. + assert b"PONG" in context.cmd_output.strip(), f"Output was {context.cmd_output}" + + @when("we run dbcli") def step_run_cli(context): wrappers.run_cli(context) @@ -49,9 +62,7 @@ def step_run_cli_using_arg(context, arg): arg = "service=mock_postgres --password" prompt_check = False currentdb = "postgres" - wrappers.run_cli( - context, run_args=[arg], prompt_check=prompt_check, currentdb=currentdb - ) + wrappers.run_cli(context, run_args=[arg], prompt_check=prompt_check, currentdb=currentdb) @when("we wait for prompt") @@ -64,13 +75,83 @@ def step_ctrl_d(context): """ Send Ctrl + D to hopefully exit. 
""" + step_try_to_ctrl_d(context) + context.cli.expect(pexpect.EOF, timeout=5) + context.exit_sent = True + + +@when('we try to send "ctrl + d"') +def step_try_to_ctrl_d(context): + """ + Send Ctrl + D, perhaps exiting, perhaps not (if a transaction is + ongoing). + """ # turn off pager before exiting context.cli.sendcontrol("c") context.cli.sendline(r"\pset pager off") wrappers.wait_prompt(context) context.cli.sendcontrol("d") - context.cli.expect(pexpect.EOF, timeout=15) - context.exit_sent = True + + +@when('we send "ctrl + c"') +def step_ctrl_c(context): + """Send Ctrl + c to hopefully interrupt.""" + context.cli.sendcontrol("c") + + +@then("we see cancelled query warning") +def step_see_cancelled_query_warning(context): + """ + Make sure we receive the warning that the current query was cancelled. + """ + wrappers.expect_exact(context, "cancelled query", timeout=2) + + +@then("we see ongoing transaction message") +def step_see_ongoing_transaction_error(context): + """ + Make sure we receive the warning that a transaction is ongoing. + """ + context.cli.expect("A transaction is ongoing.", timeout=2) + + +@when("we send sleep query") +def step_send_sleep_15_seconds(context): + """ + Send query to sleep for 15 seconds. + """ + context.cli.sendline("select pg_sleep(15)") + + +@when("we check for any non-idle sleep queries") +def step_check_for_active_sleep_queries(context): + """ + Send query to check for any non-idle pg_sleep queries. 
+ """ + context.cli.sendline( + "select state from pg_stat_activity where query not like '%pg_stat_activity%' and query like '%pg_sleep%' and state != 'idle';" + ) + + +@then("we don't see any non-idle sleep queries") +def step_no_active_sleep_queries(context): + """Confirm that any pg_sleep queries are either idle or not active.""" + wrappers.expect_exact( + context, + context.conf["pager_boundary"] + + "\r" + + dedent( + """ + +-------+\r + | state |\r + |-------|\r + +-------+\r + SELECT 0\r + """ + ) + + context.conf["pager_boundary"], + timeout=5, + ) @when(r'we send "\?" command') @@ -97,17 +178,15 @@ def step_see_error_message(context): @when("we send source command") def step_send_source_command(context): context.tmpfile_sql_help = tempfile.NamedTemporaryFile(prefix="pgcli_") - context.tmpfile_sql_help.write(br"\?") + context.tmpfile_sql_help.write(rb"\?") context.tmpfile_sql_help.flush() - context.cli.sendline(fr"\i {context.tmpfile_sql_help.name}") + context.cli.sendline(rf"\i {context.tmpfile_sql_help.name}") wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5) @when("we run query to check application_name") def step_check_application_name(context): - context.cli.sendline( - "SELECT 'found' FROM pg_stat_activity WHERE application_name = 'pgcli' HAVING COUNT(*) > 0;" - ) + context.cli.sendline("SELECT 'found' FROM pg_stat_activity WHERE application_name = 'pgcli' HAVING COUNT(*) > 0;") @then("we see found") @@ -118,11 +197,11 @@ def step_see_found(context): + "\r" + dedent( """ - +------------+\r - | ?column? |\r - |------------|\r - | found |\r - +------------+\r + +----------+\r + | ?column? 
|\r + |----------|\r + | found |\r + +----------+\r SELECT 1\r """ ) @@ -131,18 +210,31 @@ def step_see_found(context): ) -@then("we confirm the destructive warning") -def step_confirm_destructive_command(context): - """Confirm destructive command.""" +@then("we respond to the destructive warning: {response}") +def step_resppond_to_destructive_command(context, response): + """Respond to destructive command.""" wrappers.expect_exact( context, - "You're about to run a destructive command.\r\nDo you want to proceed? (y/n):", + "You're about to run a destructive command.\r\nDo you want to proceed? [y/N]:", timeout=2, ) - context.cli.sendline("y") + context.cli.sendline(response.strip()) @then("we send password") def step_send_password(context): wrappers.expect_exact(context, "Password for", timeout=5) context.cli.sendline(context.conf["pass"] or "DOES NOT MATTER") + + +@when('we send "{text}"') +def step_send_text(context, text): + context.cli.sendline(text) + # Try to detect whether we are exiting. If so, set `exit_sent` + # so that `after_scenario` correctly cleans up. + try: + context.cli.expect(pexpect.EOF, timeout=0.2) + except pexpect.TIMEOUT: + pass + else: + context.exit_sent = True diff --git a/tests/features/steps/crud_database.py b/tests/features/steps/crud_database.py index 3f5d0e718..9507d461a 100644 --- a/tests/features/steps/crud_database.py +++ b/tests/features/steps/crud_database.py @@ -3,6 +3,7 @@ Each step is defined by the string decorating it. This string is used to call the step in "*.feature" file. """ + import pexpect from behave import when, then @@ -59,7 +60,7 @@ def step_see_prompt(context): Wait to see the prompt. 
""" db_name = getattr(context, "currentdb", context.conf["dbname"]) - wrappers.expect_exact(context, f"{db_name}> ", timeout=5) + wrappers.expect_exact(context, f"{db_name}>", timeout=5) context.atprompt = True diff --git a/tests/features/steps/crud_table.py b/tests/features/steps/crud_table.py index 0375883a4..114e3edf5 100644 --- a/tests/features/steps/crud_table.py +++ b/tests/features/steps/crud_table.py @@ -9,6 +9,10 @@ import wrappers +INITIAL_DATA = "xxx" +UPDATED_DATA = "yyy" + + @when("we create table") def step_create_table(context): """ @@ -22,7 +26,7 @@ def step_insert_into_table(context): """ Send insert into table. """ - context.cli.sendline("""insert into a(x) values('xxx');""") + context.cli.sendline(f"""insert into a(x) values('{INITIAL_DATA}');""") @when("we update table") @@ -30,7 +34,7 @@ def step_update_table(context): """ Send insert into table. """ - context.cli.sendline("""update a set x = 'yyy' where x = 'xxx';""") + context.cli.sendline(f"""update a set x = '{UPDATED_DATA}' where x = '{INITIAL_DATA}';""") @when("we select from table") @@ -46,7 +50,7 @@ def step_delete_from_table(context): """ Send deete from table. """ - context.cli.sendline("""delete from a where x = 'yyy';""") + context.cli.sendline(f"""delete from a where x = '{UPDATED_DATA}';""") @when("we drop table") @@ -57,6 +61,30 @@ def step_drop_table(context): context.cli.sendline("drop table a;") +@when("we alter the table") +def step_alter_table(context): + """ + Alter the table by adding a column. 
+ """ + context.cli.sendline("""alter table a add column y varchar;""") + + +@when("we begin transaction") +def step_begin_transaction(context): + """ + Begin transaction + """ + context.cli.sendline("begin;") + + +@when("we rollback transaction") +def step_rollback_transaction(context): + """ + Rollback transaction + """ + context.cli.sendline("rollback;") + + @then("we see table created") def step_see_table_created(context): """ @@ -81,19 +109,20 @@ def step_see_record_updated(context): wrappers.expect_pager(context, "UPDATE 1\r\n", timeout=2) -@then("we see data selected") -def step_see_data_selected(context): +@then("we see data selected: {data}") +def step_see_data_selected(context, data): """ - Wait to see select output. + Wait to see select output with initial or updated data. """ + x = UPDATED_DATA if data == "updated" else INITIAL_DATA wrappers.expect_pager( context, dedent( - """\ + f"""\ +-----+\r | x |\r |-----|\r - | yyy |\r + | {x} |\r +-----+\r SELECT 1\r """ @@ -102,6 +131,26 @@ def step_see_data_selected(context): ) +@then("we see select output without data") +def step_see_no_data_selected(context): + """ + Wait to see select output without data. + """ + wrappers.expect_pager( + context, + dedent( + """\ + +---+\r + | x |\r + |---|\r + +---+\r + SELECT 0\r + """ + ), + timeout=1, + ) + + @then("we see record deleted") def step_see_data_deleted(context): """ @@ -116,3 +165,19 @@ def step_see_table_dropped(context): Wait to see drop output. """ wrappers.expect_pager(context, "DROP TABLE\r\n", timeout=2) + + +@then("we see transaction began") +def step_see_transaction_began(context): + """ + Wait to see transaction began. + """ + wrappers.expect_pager(context, "BEGIN\r\n", timeout=2) + + +@then("we see transaction rolled back") +def step_see_transaction_rolled_back(context): + """ + Wait to see transaction rollback. 
+ """ + wrappers.expect_pager(context, "ROLLBACK\r\n", timeout=2) diff --git a/tests/features/steps/expanded.py b/tests/features/steps/expanded.py index 265ea39b2..302cab949 100644 --- a/tests/features/steps/expanded.py +++ b/tests/features/steps/expanded.py @@ -16,7 +16,7 @@ def step_prepare_data(context): context.cli.sendline("drop table if exists a;") wrappers.expect_exact( context, - "You're about to run a destructive command.\r\nDo you want to proceed? (y/n):", + "You're about to run a destructive command.\r\nDo you want to proceed? [y/N]:", timeout=2, ) context.cli.sendline("y") @@ -58,11 +58,11 @@ def step_see_data(context, which): context, dedent( """\ - +-----+-----+--------+\r - | x | y | z |\r - |-----+-----+--------|\r - | 1 | 1.0 | 1.0000 |\r - +-----+-----+--------+\r + +---+-----+--------+\r + | x | y | z |\r + |---+-----+--------|\r + | 1 | 1.0 | 1.0000 |\r + +---+-----+--------+\r SELECT 1\r """ ), diff --git a/tests/features/steps/iocommands.py b/tests/features/steps/iocommands.py index a614490a7..615eecbb1 100644 --- a/tests/features/steps/iocommands.py +++ b/tests/features/steps/iocommands.py @@ -8,15 +8,11 @@ @when("we start external editor providing a file name") def step_edit_file(context): """Edit file with external editor.""" - context.editor_file_name = os.path.join( - context.package_root, "test_file_{0}.sql".format(context.conf["vi"]) - ) + context.editor_file_name = os.path.join(context.package_root, "test_file_{0}.sql".format(context.conf["vi"])) if os.path.exists(context.editor_file_name): os.remove(context.editor_file_name) context.cli.sendline(r"\e {}".format(os.path.basename(context.editor_file_name))) - wrappers.expect_exact( - context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2 - ) + wrappers.expect_exact(context, 'Entering Ex mode. 
Type "visual" to go to Normal mode.', timeout=2) wrappers.expect_exact(context, ":", timeout=2) @@ -48,9 +44,7 @@ def step_edit_done_sql(context): @when("we tee output") def step_tee_ouptut(context): - context.tee_file_name = os.path.join( - context.package_root, "tee_file_{0}.sql".format(context.conf["vi"]) - ) + context.tee_file_name = os.path.join(context.package_root, "tee_file_{0}.sql".format(context.conf["vi"])) if os.path.exists(context.tee_file_name): os.remove(context.tee_file_name) context.cli.sendline(r"\o {}".format(os.path.basename(context.tee_file_name))) diff --git a/tests/features/steps/pgbouncer.py b/tests/features/steps/pgbouncer.py new file mode 100644 index 000000000..f15698231 --- /dev/null +++ b/tests/features/steps/pgbouncer.py @@ -0,0 +1,22 @@ +""" +Steps for behavioral style tests are defined in this module. +Each step is defined by the string decorating it. +This string is used to call the step in "*.feature" file. +""" + +from behave import when, then +import wrappers + + +@when('we send "show help" command') +def step_send_help_command(context): + context.cli.sendline("show help") + + +@then("we see the pgbouncer help output") +def see_pgbouncer_help(context): + wrappers.expect_exact( + context, + "SHOW HELP|CONFIG|DATABASES|POOLS|CLIENTS|SERVERS|USERS|VERSION", + timeout=3, + ) diff --git a/tests/features/steps/wrappers.py b/tests/features/steps/wrappers.py index 0ca83669c..2162b8b96 100644 --- a/tests/features/steps/wrappers.py +++ b/tests/features/steps/wrappers.py @@ -1,12 +1,8 @@ import re import pexpect -from pgcli.main import COLOR_CODE_REGEX import textwrap -try: - from StringIO import StringIO -except ImportError: - from io import StringIO +from io import StringIO def expect_exact(context, expected, timeout): @@ -40,10 +36,7 @@ def expect_exact(context, expected, timeout): def expect_pager(context, expected, timeout): formatted = expected if isinstance(expected, list) else [expected] - formatted = [ - 
f"{context.conf['pager_boundary']}\r\n{t}{context.conf['pager_boundary']}\r\n" - for t in formatted - ] + formatted = [f"{context.conf['pager_boundary']}\r\n{t}{context.conf['pager_boundary']}\r\n" for t in formatted] expect_exact( context, @@ -70,4 +63,5 @@ def run_cli(context, run_args=None, prompt_check=True, currentdb=None): def wait_prompt(context): """Make sure prompt is displayed.""" - expect_exact(context, "{0}> ".format(context.conf["dbname"]), timeout=5) + prompt_str = "{0}>".format(context.currentdb) + expect_exact(context, [prompt_str + " ", prompt_str, pexpect.EOF], timeout=3) diff --git a/tests/formatter/__init__.py b/tests/formatter/__init__.py new file mode 100644 index 000000000..9bad5790a --- /dev/null +++ b/tests/formatter/__init__.py @@ -0,0 +1 @@ +# coding=utf-8 diff --git a/tests/formatter/test_sqlformatter.py b/tests/formatter/test_sqlformatter.py new file mode 100644 index 000000000..78ff5dc95 --- /dev/null +++ b/tests/formatter/test_sqlformatter.py @@ -0,0 +1,111 @@ +# coding=utf-8 + +from pgcli.packages.formatter.sqlformatter import escape_for_sql_statement + +from cli_helpers.tabular_output import TabularOutputFormatter +from pgcli.packages.formatter.sqlformatter import adapter, register_new_formatter + + +def test_escape_for_sql_statement_bytes(): + bts = b"837124ab3e8dc0f" + escaped_bytes = escape_for_sql_statement(bts) + assert escaped_bytes == "X'383337313234616233653864633066'" + + +def test_escape_for_sql_statement_number(): + num = 2981 + escaped_bytes = escape_for_sql_statement(num) + assert escaped_bytes == "'2981'" + + +def test_escape_for_sql_statement_str(): + example_str = "example str" + escaped_bytes = escape_for_sql_statement(example_str) + assert escaped_bytes == "'example str'" + + +def test_output_sql_insert(): + global formatter + formatter = TabularOutputFormatter + register_new_formatter(formatter) + data = [ + [ + 1, + "Jackson", + "jackson_test@gmail.com", + "132454789", + None, + "2022-09-09 19:44:32.712343+08", + 
"2022-09-09 19:44:32.712343+08", + ] + ] + header = ["id", "name", "email", "phone", "description", "created_at", "updated_at"] + table_format = "sql-insert" + kwargs = { + "column_types": [int, str, str, str, str, str, str], + "sep_title": "RECORD {n}", + "sep_character": "-", + "sep_length": (1, 25), + "missing_value": "", + "integer_format": "", + "float_format": "", + "disable_numparse": True, + "preserve_whitespace": True, + "max_field_width": 500, + } + formatter.query = 'SELECT * FROM "user";' + output = adapter(data, header, table_format=table_format, **kwargs) + output_list = list(output) + expected = [ + 'INSERT INTO "user" ("id", "name", "email", "phone", "description", "created_at", "updated_at") VALUES', + " ('1', 'Jackson', 'jackson_test@gmail.com', '132454789', NULL, " + + "'2022-09-09 19:44:32.712343+08', '2022-09-09 19:44:32.712343+08')", + ";", + ] + assert expected == output_list + + +def test_output_sql_update(): + global formatter + formatter = TabularOutputFormatter + register_new_formatter(formatter) + data = [ + [ + 1, + "Jackson", + "jackson_test@gmail.com", + "132454789", + "", + "2022-09-09 19:44:32.712343+08", + "2022-09-09 19:44:32.712343+08", + ] + ] + header = ["id", "name", "email", "phone", "description", "created_at", "updated_at"] + table_format = "sql-update" + kwargs = { + "column_types": [int, str, str, str, str, str, str], + "sep_title": "RECORD {n}", + "sep_character": "-", + "sep_length": (1, 25), + "missing_value": "", + "integer_format": "", + "float_format": "", + "disable_numparse": True, + "preserve_whitespace": True, + "max_field_width": 500, + } + formatter.query = 'SELECT * FROM "user";' + output = adapter(data, header, table_format=table_format, **kwargs) + output_list = list(output) + print(output_list) + expected = [ + 'UPDATE "user" SET', + " \"name\" = 'Jackson'", + ", \"email\" = 'jackson_test@gmail.com'", + ", \"phone\" = '132454789'", + ", \"description\" = ''", + ", \"created_at\" = '2022-09-09 
19:44:32.712343+08'", + ", \"updated_at\" = '2022-09-09 19:44:32.712343+08'", + "WHERE \"id\" = '1';", + ] + assert expected == output_list diff --git a/tests/metadata.py b/tests/metadata.py index 4ebcccd07..3ab88a719 100644 --- a/tests/metadata.py +++ b/tests/metadata.py @@ -23,16 +23,12 @@ def completion(display_meta, text, pos=0): def function(text, pos=0, display=None): - return Completion( - text, display=display or text, start_position=pos, display_meta="function" - ) + return Completion(text, display=display or text, start_position=pos, display_meta="function") def get_result(completer, text, position=None): position = len(text) if position is None else position - return completer.get_completions( - Document(text=text, cursor_position=position), Mock() - ) + return completer.get_completions(Document(text=text, cursor_position=position), Mock()) def result_set(completer, text, position=None): @@ -73,10 +69,7 @@ def keywords(self, pos=0): return [keyword(kw, pos) for kw in self.completer.keywords_tree.keys()] def specials(self, pos=0): - return [ - Completion(text=k, start_position=pos, display_meta=v.description) - for k, v in self.completer.pgspecial.commands.items() - ] + return [Completion(text=k, start_position=pos, display_meta=v.description) for k, v in self.completer.pgspecial.commands.items()] def columns(self, tbl, parent="public", typ="tables", pos=0): if typ == "functions": @@ -87,42 +80,23 @@ def columns(self, tbl, parent="public", typ="tables", pos=0): return [column(escape(col), pos) for col in cols] def datatypes(self, parent="public", pos=0): - return [ - datatype(escape(x), pos) - for x in self.metadata.get("datatypes", {}).get(parent, []) - ] + return [datatype(escape(x), pos) for x in self.metadata.get("datatypes", {}).get(parent, [])] def tables(self, parent="public", pos=0): - return [ - table(escape(x), pos) - for x in self.metadata.get("tables", {}).get(parent, []) - ] + return [table(escape(x), pos) for x in self.metadata.get("tables", 
{}).get(parent, [])] def views(self, parent="public", pos=0): - return [ - view(escape(x), pos) for x in self.metadata.get("views", {}).get(parent, []) - ] + return [view(escape(x), pos) for x in self.metadata.get("views", {}).get(parent, [])] def functions(self, parent="public", pos=0): return [ function( escape(x[0]) + "(" - + ", ".join( - arg_name + " := " - for (arg_name, arg_mode) in zip(x[1], x[3]) - if arg_mode in ("b", "i") - ) + + ", ".join(arg_name + " := " for (arg_name, arg_mode) in zip(x[1], x[3]) if arg_mode in ("b", "i")) + ")", pos, - escape(x[0]) - + "(" - + ", ".join( - arg_name - for (arg_name, arg_mode) in zip(x[1], x[3]) - if arg_mode in ("b", "i") - ) - + ")", + escape(x[0]) + "(" + ", ".join(arg_name for (arg_name, arg_mode) in zip(x[1], x[3]) if arg_mode in ("b", "i")) + ")", ) for x in self.metadata.get("functions", {}).get(parent, []) ] @@ -132,24 +106,14 @@ def schemas(self, pos=0): return [schema(escape(s), pos=pos) for s in schemas] def functions_and_keywords(self, parent="public", pos=0): - return ( - self.functions(parent, pos) - + self.builtin_functions(pos) - + self.keywords(pos) - ) + return self.functions(parent, pos) + self.builtin_functions(pos) + self.keywords(pos) # Note that the filtering parameters here only apply to the columns def columns_functions_and_keywords(self, tbl, parent="public", typ="tables", pos=0): - return self.functions_and_keywords(pos=pos) + self.columns( - tbl, parent, typ, pos - ) + return self.functions_and_keywords(pos=pos) + self.columns(tbl, parent, typ, pos) def from_clause_items(self, parent="public", pos=0): - return ( - self.functions(parent, pos) - + self.views(parent, pos) - + self.tables(parent, pos) - ) + return self.functions(parent, pos) + self.views(parent, pos) + self.tables(parent, pos) def schemas_and_from_clause_items(self, parent="public", pos=0): return self.from_clause_items(parent, pos) + self.schemas(pos) @@ -205,9 +169,7 @@ def get_completer(self, settings=None, casing=None): from 
pgcli.pgcompleter import PGCompleter from pgspecial import PGSpecial - comp = PGCompleter( - smart_completion=True, settings=settings, pgspecial=PGSpecial() - ) + comp = PGCompleter(smart_completion=True, settings=settings, pgspecial=PGSpecial()) schemata, tables, tbl_cols, views, view_cols = [], [], [], [], [] @@ -226,20 +188,12 @@ def get_completer(self, settings=None, casing=None): view_cols.extend([self._make_col(sch, tbl, col) for col in cols]) functions = [ - FunctionMetadata(sch, *func_meta, arg_defaults=None) - for sch, funcs in metadata["functions"].items() - for func_meta in funcs + FunctionMetadata(sch, *func_meta, arg_defaults=None) for sch, funcs in metadata["functions"].items() for func_meta in funcs ] - datatypes = [ - (sch, typ) - for sch, datatypes in metadata["datatypes"].items() - for typ in datatypes - ] + datatypes = [(sch, typ) for sch, datatypes in metadata["datatypes"].items() for typ in datatypes] - foreignkeys = [ - ForeignKey(*fk) for fks in metadata["foreignkeys"].values() for fk in fks - ] + foreignkeys = [ForeignKey(*fk) for fks in metadata["foreignkeys"].values() for fk in fks] comp.extend_schemata(schemata) comp.extend_relations(tables, kind="tables") diff --git a/tests/parseutils/test_function_metadata.py b/tests/parseutils/test_function_metadata.py index 0350e2a17..c4000ab1c 100644 --- a/tests/parseutils/test_function_metadata.py +++ b/tests/parseutils/test_function_metadata.py @@ -2,15 +2,9 @@ def test_function_metadata_eq(): - f1 = FunctionMetadata( - "s", "f", ["x"], ["integer"], [], "int", False, False, False, False, None - ) - f2 = FunctionMetadata( - "s", "f", ["x"], ["integer"], [], "int", False, False, False, False, None - ) - f3 = FunctionMetadata( - "s", "g", ["x"], ["integer"], [], "int", False, False, False, False, None - ) + f1 = FunctionMetadata("s", "f", ["x"], ["integer"], [], "int", False, False, False, False, None) + f2 = FunctionMetadata("s", "f", ["x"], ["integer"], [], "int", False, False, False, False, None) + 
f3 = FunctionMetadata("s", "g", ["x"], ["integer"], [], "int", False, False, False, False, None) assert f1 == f2 assert f1 != f3 assert not (f1 != f2) diff --git a/tests/parseutils/test_parseutils.py b/tests/parseutils/test_parseutils.py index 5a375d70f..90749ebfc 100644 --- a/tests/parseutils/test_parseutils.py +++ b/tests/parseutils/test_parseutils.py @@ -1,5 +1,10 @@ import pytest -from pgcli.packages.parseutils import is_destructive +from pgcli.packages.parseutils import ( + is_destructive, + parse_destructive_warning, + BASE_KEYWORDS, + ALL_KEYWORDS, +) from pgcli.packages.parseutils.tables import extract_tables from pgcli.packages.parseutils.utils import find_prev_keyword, is_open_quote @@ -14,9 +19,7 @@ def test_simple_select_single_table(): assert tables == ((None, "abc", None, False),) -@pytest.mark.parametrize( - "sql", ['select * from "abc"."def"', 'select * from abc."def"'] -) +@pytest.mark.parametrize("sql", ['select * from "abc"."def"', 'select * from abc."def"']) def test_simple_select_single_table_schema_qualified_quoted_table(sql): tables = extract_tables(sql) assert tables == (("abc", "def", '"def"', False),) @@ -167,7 +170,7 @@ def test_subselect_tables(): @pytest.mark.parametrize("text", ["SELECT * FROM foo.", "SELECT 123 AS foo"]) def test_extract_no_tables(text): tables = extract_tables(text) - assert tables == tuple() + assert tables == () @pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"]) @@ -220,9 +223,7 @@ def test_find_prev_keyword_where(sql): assert kw.value == "where" and stripped == "select * from foo where" -@pytest.mark.parametrize( - "sql", ["create table foo (bar int, baz ", "select * from foo() as bar (baz "] -) +@pytest.mark.parametrize("sql", ["create table foo (bar int, baz ", "select * from foo() as bar (baz "]) def test_find_prev_keyword_open_parens(sql): kw, _ = find_prev_keyword(sql) assert kw.value == "(" @@ -263,18 +264,43 @@ def test_is_open_quote__open(sql): @pytest.mark.parametrize( - ("sql", 
"warning_level", "expected"), + ("sql", "keywords", "expected"), + [ + ("update abc set x = 1", ALL_KEYWORDS, True), + ("update abc set x = 1 where y = 2", ALL_KEYWORDS, True), + ("update abc set x = 1", BASE_KEYWORDS, True), + ("update abc set x = 1 where y = 2", BASE_KEYWORDS, False), + ("select x, y, z from abc", ALL_KEYWORDS, False), + ("drop abc", ALL_KEYWORDS, True), + ("alter abc", ALL_KEYWORDS, True), + ("delete abc", ALL_KEYWORDS, True), + ("truncate abc", ALL_KEYWORDS, True), + ("insert into abc values (1, 2, 3)", ALL_KEYWORDS, False), + ("insert into abc values (1, 2, 3)", BASE_KEYWORDS, False), + ("insert into abc values (1, 2, 3)", ["insert"], True), + ("insert into abc values (1, 2, 3)", ["insert"], True), + ], +) +def test_is_destructive(sql, keywords, expected): + assert is_destructive(sql, keywords) == expected + + +@pytest.mark.parametrize( + ("warning_level", "expected"), [ - ("update abc set x = 1", "all", True), - ("update abc set x = 1 where y = 2", "all", True), - ("update abc set x = 1", "moderate", True), - ("update abc set x = 1 where y = 2", "moderate", False), - ("select x, y, z from abc", "all", False), - ("drop abc", "all", True), - ("alter abc", "all", True), - ("delete abc", "all", True), - ("truncate abc", "all", True), + ("true", ALL_KEYWORDS), + ("false", []), + ("all", ALL_KEYWORDS), + ("moderate", BASE_KEYWORDS), + ("off", []), + ("", []), + (None, []), + (ALL_KEYWORDS, ALL_KEYWORDS), + (BASE_KEYWORDS, BASE_KEYWORDS), + ("insert", ["insert"]), + ("drop,alter,delete", ["drop", "alter", "delete"]), + (["drop", "alter", "delete"], ["drop", "alter", "delete"]), ], ) -def test_is_destructive(sql, warning_level, expected): - assert is_destructive(sql, warning_level=warning_level) == expected +def test_parse_destructive_warning(warning_level, expected): + assert parse_destructive_warning(warning_level) == expected diff --git a/tests/test_application_name.py b/tests/test_application_name.py new file mode 100644 index 
000000000..e50e397a5 --- /dev/null +++ b/tests/test_application_name.py @@ -0,0 +1,15 @@ +from unittest.mock import patch + +from click.testing import CliRunner + +from pgcli.main import cli +from pgcli.pgexecute import PGExecute + + +def test_application_name_in_env(): + runner = CliRunner() + app_name = "wonderful_app" + with patch.object(PGExecute, "__init__") as mock_pgxecute: + runner.invoke(cli, ["127.0.0.1:5432/hello", "user"], env={"PGAPPNAME": app_name}) + kwargs = mock_pgxecute.call_args.kwargs + assert kwargs.get("application_name") == app_name diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 000000000..13eed58db --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,36 @@ +import pytest +from unittest import mock +from pgcli import auth + + +@pytest.mark.parametrize("enabled,call_count", [(True, 1), (False, 0)]) +def test_keyring_initialize(enabled, call_count): + logger = mock.MagicMock() + + with mock.patch("importlib.import_module", return_value=True) as import_method: + auth.keyring_initialize(enabled, logger=logger) + assert import_method.call_count == call_count + + +def test_keyring_get_password_ok(): + with mock.patch("pgcli.auth.keyring", return_value=mock.MagicMock()): + with mock.patch("pgcli.auth.keyring.get_password", return_value="abc123"): + assert auth.keyring_get_password("test") == "abc123" + + +def test_keyring_get_password_exception(): + with mock.patch("pgcli.auth.keyring", return_value=mock.MagicMock()): + with mock.patch("pgcli.auth.keyring.get_password", side_effect=Exception("Boom!")): + assert auth.keyring_get_password("test") == "" + + +def test_keyring_set_password_ok(): + with mock.patch("pgcli.auth.keyring", return_value=mock.MagicMock()): + with mock.patch("pgcli.auth.keyring.set_password"): + auth.keyring_set_password("test", "abc123") + + +def test_keyring_set_password_exception(): + with mock.patch("pgcli.auth.keyring", return_value=mock.MagicMock()): + with 
mock.patch("pgcli.auth.keyring.set_password", side_effect=Exception("Boom!")): + auth.keyring_set_password("test", "abc123") diff --git a/tests/test_exceptionals.py b/tests/test_exceptionals.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/test_init_commands_simple.py b/tests/test_init_commands_simple.py new file mode 100644 index 000000000..72b11f3bd --- /dev/null +++ b/tests/test_init_commands_simple.py @@ -0,0 +1,94 @@ +import pytest +from click.testing import CliRunner + +from pgcli.main import cli, PGCli + + +@pytest.fixture +def dummy_exec(monkeypatch, tmp_path): + # Capture executed commands + # Isolate config directory for tests + monkeypatch.setenv("XDG_CONFIG_HOME", str(tmp_path)) + dummy_cmds = [] + + class DummyExec: + def run(self, cmd): + # Ignore ping SELECT 1 commands used for exiting CLI + if cmd.strip().upper() == "SELECT 1": + return [] + # Record init commands + dummy_cmds.append(cmd) + return [] + + def get_timezone(self): + return "UTC" + + def set_timezone(self, *args, **kwargs): + pass + + def fake_connect(self, *args, **kwargs): + self.pgexecute = DummyExec() + + monkeypatch.setattr(PGCli, "connect", fake_connect) + return dummy_cmds + + +def test_init_command_option(dummy_exec): + "Test that --init-command triggers execution of the command." + runner = CliRunner() + # Use a custom init command and --ping to exit the CLI after init commands + result = runner.invoke(cli, ["--init-command", "SELECT foo", "--ping", "db", "user"]) + assert result.exit_code == 0 + # Should print the init command + assert "Running init commands: SELECT foo" in result.output + # Should exit via ping + assert "PONG" in result.output + # DummyExec should have recorded only the init command + assert dummy_exec == ["SELECT foo"] + + +def test_init_commands_from_config(dummy_exec, tmp_path): + """ + Test that init commands defined in the config file are executed on startup. 
+ """ + # Create a temporary config file with init-commands + config_file = tmp_path / "pgclirc_test" + config_file.write_text("[main]\n[init-commands]\nfirst = SELECT foo;\nsecond = SELECT bar;\n") + + runner = CliRunner() + # Use --ping to exit the CLI after init commands + result = runner.invoke(cli, ["--pgclirc", str(config_file.absolute()), "--ping", "testdb", "user"]) + assert result.exit_code == 0 + # Should print both init commands in order (note trailing semicolons cause double ';;') + assert "Running init commands: SELECT foo;; SELECT bar;" in result.output + # DummyExec should have recorded both commands + assert dummy_exec == ["SELECT foo;", "SELECT bar;"] + + +def test_init_commands_option_and_config(dummy_exec, tmp_path): + """ + Test that CLI-provided init command is appended after config-defined commands. + """ + # Create a temporary config file with init-commands + config_file = tmp_path / "pgclirc_test" + config_file.write_text("[main]\n [init-commands]\nfirst = SELECT foo;\n") + + runner = CliRunner() + # Use --ping to exit the CLI after init commands + result = runner.invoke( + cli, + [ + "--pgclirc", + str(config_file), + "--init-command", + "SELECT baz;", + "--ping", + "testdb", + "user", + ], + ) + assert result.exit_code == 0 + # Should print config command followed by CLI option (double ';' between commands) + assert "Running init commands: SELECT foo;; SELECT baz;" in result.output + # DummyExec should record both commands in order + assert dummy_exec == ["SELECT foo;", "SELECT baz;"] diff --git a/tests/test_main.py b/tests/test_main.py index c48accbeb..52415a008 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,5 +1,8 @@ import os import platform +import re +import tempfile +import datetime from unittest import mock import pytest @@ -11,7 +14,9 @@ from pgcli.main import ( obfuscate_process_password, + duration_in_words, format_output, + notify_callback, PGCli, OutputSettings, COLOR_CODE_REGEX, @@ -56,21 +61,97 @@ def 
test_obfuscate_process_password(): def test_format_output(): settings = OutputSettings(table_format="psql", dcmlfmt="d", floatfmt="g") - results = format_output( - "Title", [("abc", "def")], ["head1", "head2"], "test status", settings + results = format_output("Title", [("abc", "def")], ["head1", "head2"], "test status", settings) + expected = [ + "Title", + "+-------+-------+", + "| head1 | head2 |", + "|-------+-------|", + "| abc | def |", + "+-------+-------+", + "test status", + ] + assert list(results) == expected + + +def test_column_date_formats(): + settings = OutputSettings( + table_format="psql", + column_date_formats={ + "date_col": "%Y-%m-%d", + "datetime_col": "%I:%M:%S %m/%d/%y", + }, ) + data = [ + ("name1", "2024-12-13T18:32:22", "2024-12-13T19:32:22", "2024-12-13T20:32:22"), + ("name2", "2025-02-13T02:32:22", "2025-02-13T02:32:22", "2025-02-13T02:32:22"), + ] + headers = ["name", "date_col", "datetime_col", "unchanged_col"] + + results = format_output("Title", data, headers, "test status", settings) + expected = [ + "Title", + "+-------+------------+-------------------+---------------------+", + "| name | date_col | datetime_col | unchanged_col |", + "|-------+------------+-------------------+---------------------|", + "| name1 | 2024-12-13 | 07:32:22 12/13/24 | 2024-12-13T20:32:22 |", + "| name2 | 2025-02-13 | 02:32:22 02/13/25 | 2025-02-13T02:32:22 |", + "+-------+------------+-------------------+---------------------+", + "test status", + ] + assert list(results) == expected + + +def test_no_column_date_formats(): + """Test that not setting any column date formats returns unaltered datetime columns""" + settings = OutputSettings(table_format="psql") + data = [ + ("name1", "2024-12-13T18:32:22", "2024-12-13T19:32:22", "2024-12-13T20:32:22"), + ("name2", "2025-02-13T02:32:22", "2025-02-13T02:32:22", "2025-02-13T02:32:22"), + ] + headers = ["name", "date_col", "datetime_col", "unchanged_col"] + + results = format_output("Title", data, headers, 
"test status", settings) expected = [ "Title", - "+---------+---------+", - "| head1 | head2 |", - "|---------+---------|", - "| abc | def |", - "+---------+---------+", + "+-------+---------------------+---------------------+---------------------+", + "| name | date_col | datetime_col | unchanged_col |", + "|-------+---------------------+---------------------+---------------------|", + "| name1 | 2024-12-13T18:32:22 | 2024-12-13T19:32:22 | 2024-12-13T20:32:22 |", + "| name2 | 2025-02-13T02:32:22 | 2025-02-13T02:32:22 | 2025-02-13T02:32:22 |", + "+-------+---------------------+---------------------+---------------------+", "test status", ] assert list(results) == expected +def test_format_output_truncate_on(): + settings = OutputSettings(table_format="psql", dcmlfmt="d", floatfmt="g", max_field_width=10) + results = format_output( + None, + [("first field value", "second field value")], + ["head1", "head2"], + None, + settings, + ) + expected = [ + "+------------+------------+", + "| head1 | head2 |", + "|------------+------------|", + "| first f... | second ... 
|", + "+------------+------------+", + ] + assert list(results) == expected + + +def test_format_output_truncate_off(): + settings = OutputSettings(table_format="psql", dcmlfmt="d", floatfmt="g", max_field_width=None) + long_field_value = ("first field " * 100).strip() + results = format_output(None, [(long_field_value,)], ["head1"], None, settings) + lines = list(results) + assert lines[3] == f"| {long_field_value} |" + + @dbtest def test_format_array_output(executor): statement = """ @@ -83,12 +164,12 @@ def test_format_array_output(executor): """ results = run(executor, statement) expected = [ - "+----------------+------------------------+--------------+", - "| bigint_array | nested_numeric_array | 配列 |", - "|----------------+------------------------+--------------|", - "| {1,2,3} | {{1,2},{3,4}} | {å,魚,текст} |", - "| {} | | {} |", - "+----------------+------------------------+--------------+", + "+--------------+----------------------+--------------+", + "| bigint_array | nested_numeric_array | 配列 |", + "|--------------+----------------------+--------------|", + "| {1,2,3} | {{1,2},{3,4}} | {å,魚,текст} |", + "| {} | | {} |", + "+--------------+----------------------+--------------+", "SELECT 2", ] assert list(results) == expected @@ -120,19 +201,15 @@ def test_format_array_output_expanded(executor): def test_format_output_auto_expand(): - settings = OutputSettings( - table_format="psql", dcmlfmt="d", floatfmt="g", max_width=100 - ) - table_results = format_output( - "Title", [("abc", "def")], ["head1", "head2"], "test status", settings - ) + settings = OutputSettings(table_format="psql", dcmlfmt="d", floatfmt="g", max_width=100) + table_results = format_output("Title", [("abc", "def")], ["head1", "head2"], "test status", settings) table = [ "Title", - "+---------+---------+", - "| head1 | head2 |", - "|---------+---------|", - "| abc | def |", - "+---------+---------+", + "+-------+-------+", + "| head1 | head2 |", + "|-------+-------|", + "| abc | def |", + 
"+-------+-------+", "test status", ] assert list(table_results) == table @@ -182,19 +259,18 @@ def test_format_output_auto_expand(): def pset_pager_mocks(): cli = PGCli() cli.watch_command = None - with mock.patch("pgcli.main.click.echo") as mock_echo, mock.patch( - "pgcli.main.click.echo_via_pager" - ) as mock_echo_via_pager, mock.patch.object(cli, "prompt_app") as mock_app: - + with ( + mock.patch("pgcli.main.click.echo") as mock_echo, + mock.patch("pgcli.main.click.echo_via_pager") as mock_echo_via_pager, + mock.patch.object(cli, "prompt_app") as mock_app, + ): yield cli, mock_echo, mock_echo_via_pager, mock_app @pytest.mark.parametrize("term_height,term_width,text", test_data, ids=test_ids) def test_pset_pager_off(term_height, term_width, text, pset_pager_mocks): cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks - mock_cli.output.get_size.return_value = termsize( - rows=term_height, columns=term_width - ) + mock_cli.output.get_size.return_value = termsize(rows=term_height, columns=term_width) with mock.patch.object(cli.pgspecial, "pager_config", PAGER_OFF): cli.echo_via_pager(text) @@ -206,9 +282,7 @@ def test_pset_pager_off(term_height, term_width, text, pset_pager_mocks): @pytest.mark.parametrize("term_height,term_width,text", test_data, ids=test_ids) def test_pset_pager_always(term_height, term_width, text, pset_pager_mocks): cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks - mock_cli.output.get_size.return_value = termsize( - rows=term_height, columns=term_width - ) + mock_cli.output.get_size.return_value = termsize(rows=term_height, columns=term_width) with mock.patch.object(cli.pgspecial, "pager_config", PAGER_ALWAYS): cli.echo_via_pager(text) @@ -220,14 +294,10 @@ def test_pset_pager_always(term_height, term_width, text, pset_pager_mocks): pager_on_test_data = [l + (r,) for l, r in zip(test_data, use_pager_when_on)] -@pytest.mark.parametrize( - "term_height,term_width,text,use_pager", pager_on_test_data, ids=test_ids -) 
+@pytest.mark.parametrize("term_height,term_width,text,use_pager", pager_on_test_data, ids=test_ids) def test_pset_pager_on(term_height, term_width, text, use_pager, pset_pager_mocks): cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks - mock_cli.output.get_size.return_value = termsize( - rows=term_height, columns=term_width - ) + mock_cli.output.get_size.return_value = termsize(rows=term_height, columns=term_width) with mock.patch.object(cli.pgspecial, "pager_config", PAGER_LONG_OUTPUT): cli.echo_via_pager(text) @@ -244,15 +314,14 @@ def test_pset_pager_on(term_height, term_width, text, use_pager, pset_pager_mock "text,expected_length", [ ( - "22200K .......\u001b[0m\u001b[91m... .......... ...\u001b[0m\u001b[91m.\u001b[0m\u001b[91m...... .........\u001b[0m\u001b[91m.\u001b[0m\u001b[91m \u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m...... 50% 28.6K 12m55s", + "22200K .......\u001b[0m\u001b[91m... .......... ...\u001b[0m\u001b[91m.\u001b[0m\u001b[91m...... .........\u001b[0m\u001b[91m.\u001b[0m\u001b[91m \u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m...... 
50% 28.6K 12m55s", # noqa: E501 78, ), ("=\u001b[m=", 2), ("-\u001b]23\u0007-", 2), ], ) -def test_color_pattern(text, expected_length, pset_pager_mocks): - cli = pset_pager_mocks[0] +def test_color_pattern(text, expected_length): assert len(COLOR_CODE_REGEX.sub("", text)) == expected_length @@ -266,6 +335,108 @@ def test_i_works(tmpdir, executor): run(executor, statement, pgspecial=cli.pgspecial) +@dbtest +def test_toggle_verbose_errors(executor): + cli = PGCli(pgexecute=executor) + + cli._evaluate_command("\\v on") + assert cli.verbose_errors + output, _ = cli._evaluate_command("SELECT 1/0") + assert "SQLSTATE" in output[0] + + cli._evaluate_command("\\v off") + assert not cli.verbose_errors + output, _ = cli._evaluate_command("SELECT 1/0") + assert "SQLSTATE" not in output[0] + + cli._evaluate_command("\\v") + assert cli.verbose_errors + + +@dbtest +def test_echo_works(executor): + cli = PGCli(pgexecute=executor) + statement = r"\echo asdf" + result = run(executor, statement, pgspecial=cli.pgspecial) + assert result == ["asdf"] + + +@dbtest +def test_qecho_works(executor): + cli = PGCli(pgexecute=executor) + statement = r"\qecho asdf" + result = run(executor, statement, pgspecial=cli.pgspecial) + assert result == ["asdf"] + + +@dbtest +def test_logfile_works(executor): + with tempfile.TemporaryDirectory() as tmpdir: + log_file = f"{tmpdir}/tempfile.log" + cli = PGCli(pgexecute=executor, log_file=log_file) + statement = r"\qecho hello!" + cli.execute_command(statement) + with open(log_file, "r") as f: + log_contents = f.readlines() + assert datetime.datetime.fromisoformat(log_contents[0].strip()) + assert log_contents[1].strip() == r"\qecho hello!" + assert log_contents[2].strip() == "hello!" 
+ + +@dbtest +def test_logfile_unwriteable_file(executor): + cli = PGCli(pgexecute=executor) + statement = r"\log-file forbidden.log" + with mock.patch("builtins.open") as mock_open: + mock_open.side_effect = PermissionError("[Errno 13] Permission denied: 'forbidden.log'") + result = run(executor, statement, pgspecial=cli.pgspecial) + assert result == ["[Errno 13] Permission denied: 'forbidden.log'\nLogfile capture disabled"] + + +@dbtest +def test_watch_works(executor): + cli = PGCli(pgexecute=executor) + + def run_with_watch(query, target_call_count=1, expected_output="", expected_timing=None): + """ + :param query: Input to the CLI + :param target_call_count: Number of times the user lets the command run before Ctrl-C + :param expected_output: Substring expected to be found for each executed query + :param expected_timing: value `time.sleep` expected to be called with on every invocation + """ + with mock.patch.object(cli, "echo_via_pager") as mock_echo, mock.patch("pgcli.main.sleep") as mock_sleep: + mock_sleep.side_effect = [None] * (target_call_count - 1) + [KeyboardInterrupt] + cli.handle_watch_command(query) + # Validate that sleep was called with the right timing + for i in range(target_call_count - 1): + assert mock_sleep.call_args_list[i][0][0] == expected_timing + # Validate that the output of the query was expected + assert mock_echo.call_count == target_call_count + for i in range(target_call_count): + assert expected_output in mock_echo.call_args_list[i][0][0] + + # With no history, it errors. + with mock.patch("pgcli.main.click.secho") as mock_secho: + cli.handle_watch_command(r"\watch 2") + mock_secho.assert_called() + assert r"\watch cannot be used with an empty query" in mock_secho.call_args_list[0][0][0] + + # Usage 1: Run a query and then re-run it with \watch across two prompts. 
+ run_with_watch("SELECT 111", expected_output="111") + run_with_watch("\\watch 10", target_call_count=2, expected_output="111", expected_timing=10) + + # Usage 2: Run a query and \watch via the same prompt. + run_with_watch( + "SELECT 222; \\watch 4", + target_call_count=3, + expected_output="222", + expected_timing=4, + ) + + # Usage 3: Re-run the last watched command with a new timing + run_with_watch("\\watch 5", target_call_count=4, expected_output="222", expected_timing=5) + + def test_missing_rc_dir(tmpdir): rcfile = str(tmpdir.join("subdir").join("rcfile")) @@ -277,13 +448,10 @@ def test_quoted_db_uri(tmpdir): with mock.patch.object(PGCli, "connect") as mock_connect: cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) cli.connect_uri("postgres://bar%5E:%5Dfoo@baz.com/testdb%5B") - mock_connect.assert_called_with( - database="testdb[", host="baz.com", user="bar^", passwd="]foo" - ) + mock_connect.assert_called_with(database="testdb[", host="baz.com", user="bar^", passwd="]foo") def test_pg_service_file(tmpdir): - with mock.patch.object(PGCli, "connect") as mock_connect: cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) with open(tmpdir.join(".pg_service.conf").strpath, "w") as service_conf: @@ -329,6 +497,7 @@ def test_pg_service_file(tmpdir): "b_host", "5435", "", + notify_callback, application_name="pgcli", ) del os.environ["PGPASSWORD"] @@ -339,8 +508,7 @@ def test_ssl_db_uri(tmpdir): with mock.patch.object(PGCli, "connect") as mock_connect: cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) cli.connect_uri( - "postgres://bar%5E:%5Dfoo@baz.com/testdb%5B?" 
- "sslmode=verify-full&sslcert=m%79.pem&sslkey=my-key.pem&sslrootcert=c%61.pem" + "postgres://bar%5E:%5Dfoo@baz.com/testdb%5B?sslmode=verify-full&sslcert=m%79.pem&sslkey=my-key.pem&sslrootcert=c%61.pem" ) mock_connect.assert_called_with( database="testdb[", @@ -358,17 +526,13 @@ def test_port_db_uri(tmpdir): with mock.patch.object(PGCli, "connect") as mock_connect: cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) cli.connect_uri("postgres://bar:foo@baz.com:2543/testdb") - mock_connect.assert_called_with( - database="testdb", host="baz.com", user="bar", passwd="foo", port="2543" - ) + mock_connect.assert_called_with(database="testdb", host="baz.com", user="bar", passwd="foo", port="2543") def test_multihost_db_uri(tmpdir): with mock.patch.object(PGCli, "connect") as mock_connect: cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) - cli.connect_uri( - "postgres://bar:foo@baz1.com:2543,baz2.com:2543,baz3.com:2543/testdb" - ) + cli.connect_uri("postgres://bar:foo@baz1.com:2543,baz2.com:2543,baz3.com:2543/testdb") mock_connect.assert_called_with( database="testdb", host="baz1.com,baz2.com,baz3.com", @@ -383,6 +547,92 @@ def test_application_name_db_uri(tmpdir): mock_pgexecute.return_value = None cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) cli.connect_uri("postgres://bar@baz.com/?application_name=cow") - mock_pgexecute.assert_called_with( - "bar", "bar", "", "baz.com", "", "", application_name="cow" + mock_pgexecute.assert_called_with("bar", "bar", "", "baz.com", "", "", notify_callback, application_name="cow") + + +@pytest.mark.parametrize( + "duration_in_seconds,words", + [ + (0, "0 seconds"), + (0.0009, "0.001 second"), + (0.0005, "0.001 second"), + (0.0004, "0.0 second"), # not perfect, but will do + (0.2, "0.2 second"), + (1, "1 second"), + (1.4, "1 second"), + (2, "2 seconds"), + (3.4, "3 seconds"), + (60, "1 minute"), + (61, "1 minute 1 second"), + (123, "2 minutes 3 seconds"), + (124.4, "2 minutes 4 seconds"), + (3600, "1 hour"), + (7235, "2 hours 
35 seconds"), + (9005, "2 hours 30 minutes 5 seconds"), + (9006.7, "2 hours 30 minutes 6 seconds"), + (86401, "24 hours 1 second"), + ], +) +def test_duration_in_words(duration_in_seconds, words): + assert duration_in_words(duration_in_seconds) == words + + +@pytest.mark.parametrize( + "transaction_indicator,expected", + [ + ("*", "*testuser"), # valid transaction + ("!", "!testuser"), # failed transaction + ("?", "?testuser"), # connection closed + ("", "testuser"), # idle + ], +) +def test_get_prompt_with_transaction_status(transaction_indicator, expected): + cli = PGCli() + cli.pgexecute = mock.MagicMock() + cli.pgexecute.user = "testuser" + cli.pgexecute.dbname = "testdb" + cli.pgexecute.host = "localhost" + cli.pgexecute.short_host = "localhost" + cli.pgexecute.port = 5432 + cli.pgexecute.pid = 12345 + cli.pgexecute.superuser = False + cli.pgexecute.transaction_indicator = transaction_indicator + + result = cli.get_prompt("\\T\\u") + assert result == expected + + +def test_get_prompt_transaction_status_in_full_prompt(): + cli = PGCli() + cli.pgexecute = mock.MagicMock() + cli.pgexecute.user = "user" + cli.pgexecute.dbname = "mydb" + cli.pgexecute.host = "db.example.com" + cli.pgexecute.short_host = "db.example.com" + cli.pgexecute.port = 5432 + cli.pgexecute.pid = 12345 + cli.pgexecute.superuser = False + cli.pgexecute.transaction_indicator = "*" + + result = cli.get_prompt("\\T\\u@\\h:\\d> ") + assert result == "*user@db.example.com:mydb> " + + +@dbtest +def test_notifications(executor): + run(executor, "listen chan1") + + with mock.patch("pgcli.main.click.secho") as mock_secho: + run(executor, "notify chan1, 'testing1'") + mock_secho.assert_called() + arg = mock_secho.call_args_list[0].args[0] + assert re.match( + r'Notification received on channel "chan1" \(PID \d+\):\ntesting1', + arg, ) + + run(executor, "unlisten chan1") + + with mock.patch("pgcli.main.click.secho") as mock_secho: + run(executor, "notify chan1, 'testing2'") + 
mock_secho.assert_not_called() diff --git a/tests/test_naive_completion.py b/tests/test_naive_completion.py index 5b936619b..ec5034e04 100644 --- a/tests/test_naive_completion.py +++ b/tests/test_naive_completion.py @@ -21,56 +21,38 @@ def complete_event(): def test_empty_string_completion(completer, complete_event): text = "" position = 0 - result = completions_to_set( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) + result = completions_to_set(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == completions_to_set(map(Completion, completer.all_completions)) def test_select_keyword_completion(completer, complete_event): text = "SEL" position = len("SEL") - result = completions_to_set( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) + result = completions_to_set(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == completions_to_set([Completion(text="SELECT", start_position=-3)]) def test_function_name_completion(completer, complete_event): text = "SELECT MA" position = len("SELECT MA") - result = completions_to_set( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) - assert result == completions_to_set( - [ - Completion(text="MATERIALIZED VIEW", start_position=-2), - Completion(text="MAX", start_position=-2), - Completion(text="MAXEXTENTS", start_position=-2), - Completion(text="MAKE_DATE", start_position=-2), - Completion(text="MAKE_TIME", start_position=-2), - Completion(text="MAKE_TIMESTAMPTZ", start_position=-2), - Completion(text="MAKE_INTERVAL", start_position=-2), - Completion(text="MASKLEN", start_position=-2), - Completion(text="MAKE_TIMESTAMP", start_position=-2), - ] - ) + result = completions_to_set(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + 
assert result == completions_to_set([ + Completion(text="MATERIALIZED VIEW", start_position=-2), + Completion(text="MAX", start_position=-2), + Completion(text="MAXEXTENTS", start_position=-2), + Completion(text="MAKE_DATE", start_position=-2), + Completion(text="MAKE_TIME", start_position=-2), + Completion(text="MAKE_TIMESTAMPTZ", start_position=-2), + Completion(text="MAKE_INTERVAL", start_position=-2), + Completion(text="MASKLEN", start_position=-2), + Completion(text="MAKE_TIMESTAMP", start_position=-2), + ]) def test_column_name_completion(completer, complete_event): text = "SELECT FROM users" position = len("SELECT ") - result = completions_to_set( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) + result = completions_to_set(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == completions_to_set(map(Completion, completer.all_completions)) @@ -84,27 +66,18 @@ def test_alter_well_known_keywords_completion(completer, complete_event): smart_completion=True, ) ) - assert result > completions_to_set( - [ - Completion(text="DATABASE", display_meta="keyword"), - Completion(text="TABLE", display_meta="keyword"), - Completion(text="SYSTEM", display_meta="keyword"), - ] - ) - assert ( - completions_to_set([Completion(text="CREATE", display_meta="keyword")]) - not in result - ) + assert result > completions_to_set([ + Completion(text="DATABASE", display_meta="keyword"), + Completion(text="TABLE", display_meta="keyword"), + Completion(text="SYSTEM", display_meta="keyword"), + ]) + assert completions_to_set([Completion(text="CREATE", display_meta="keyword")]) not in result def test_special_name_completion(completer, complete_event): text = "\\" position = len("\\") - result = completions_to_set( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) + result = 
completions_to_set(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) # Special commands will NOT be suggested during naive completion mode. assert result == completions_to_set([]) @@ -119,15 +92,13 @@ def test_datatype_name_completion(completer, complete_event): smart_completion=True, ) ) - assert result == completions_to_set( - [ - Completion(text="INET", display_meta="datatype"), - Completion(text="INT", display_meta="datatype"), - Completion(text="INT2", display_meta="datatype"), - Completion(text="INT4", display_meta="datatype"), - Completion(text="INT8", display_meta="datatype"), - Completion(text="INTEGER", display_meta="datatype"), - Completion(text="INTERNAL", display_meta="datatype"), - Completion(text="INTERVAL", display_meta="datatype"), - ] - ) + assert result == completions_to_set([ + Completion(text="INET", display_meta="datatype"), + Completion(text="INT", display_meta="datatype"), + Completion(text="INT2", display_meta="datatype"), + Completion(text="INT4", display_meta="datatype"), + Completion(text="INT8", display_meta="datatype"), + Completion(text="INTEGER", display_meta="datatype"), + Completion(text="INTERNAL", display_meta="datatype"), + Completion(text="INTERVAL", display_meta="datatype"), + ]) diff --git a/tests/test_pgcompleter.py b/tests/test_pgcompleter.py new file mode 100644 index 000000000..a9f390d76 --- /dev/null +++ b/tests/test_pgcompleter.py @@ -0,0 +1,94 @@ +import json +import pytest +from pgcli import pgcompleter +import tempfile + + +def test_load_alias_map_file_missing_file(): + with pytest.raises( + pgcompleter.InvalidMapFile, + match=r"Cannot read alias_map_file - /path/to/non-existent/file.json does not exist$", + ): + pgcompleter.load_alias_map_file("/path/to/non-existent/file.json") + + +def test_load_alias_map_file_invalid_json(tmp_path): + fpath = tmp_path / "foo.json" + fpath.write_text("this is not valid json") + with pytest.raises(pgcompleter.InvalidMapFile, match=r".*is not 
valid json$"): + pgcompleter.load_alias_map_file(str(fpath)) + + +@pytest.mark.parametrize( + "table_name, alias", + [ + ("SomE_Table", "SET"), + ("SOmeTabLe", "SOTL"), + ("someTable", "T"), + ], +) +def test_generate_alias_uses_upper_case_letters_from_name(table_name, alias): + assert pgcompleter.generate_alias(table_name) == alias + + +@pytest.mark.parametrize( + "table_name, alias", + [ + ("some_tab_le", "stl"), + ("s_ome_table", "sot"), + ("sometable", "s"), + ], +) +def test_generate_alias_uses_first_char_and_every_preceded_by_underscore(table_name, alias): + assert pgcompleter.generate_alias(table_name) == alias + + +@pytest.mark.parametrize( + "table_name, alias_map, alias", + [ + ("some_table", {"some_table": "my_alias"}, "my_alias"), + pytest.param("some_other_table", {"some_table": "my_alias"}, "sot", id="no_match_in_map"), + ], +) +def test_generate_alias_can_use_alias_map(table_name, alias_map, alias): + assert pgcompleter.generate_alias(table_name, alias_map) == alias + + +@pytest.mark.parametrize( + "table_name, alias_map, alias", + [ + ("some_table", {"some_table": "my_alias"}, "my_alias"), + ], +) +def test_pgcompleter_alias_uses_configured_alias_map(table_name, alias_map, alias): + with tempfile.NamedTemporaryFile(mode="w", suffix=".json") as alias_map_file: + alias_map_file.write(json.dumps(alias_map)) + alias_map_file.seek(0) + completer = pgcompleter.PGCompleter( + settings={ + "generate_aliases": True, + "alias_map_file": alias_map_file.name, + } + ) + assert completer.alias(table_name, []) == alias + + +@pytest.mark.parametrize( + "table_name, alias_map, alias", + [ + ("SomeTable", {"SomeTable": "my_alias"}, "my_alias"), + ], +) +def test_generate_alias_prefers_alias_over_upper_case_name(table_name, alias_map, alias): + assert pgcompleter.generate_alias(table_name, alias_map) == alias + + +@pytest.mark.parametrize( + "table_name, alias", + [ + ("Some_tablE", "SE"), + ("SomeTab_le", "ST"), + ], +) +def 
test_generate_alias_prefers_upper_case_name_over_underscore_name(table_name, alias): + assert pgcompleter.generate_alias(table_name) == alias diff --git a/tests/test_pgexecute.py b/tests/test_pgexecute.py index 109674cb7..2b8e87cc0 100644 --- a/tests/test_pgexecute.py +++ b/tests/test_pgexecute.py @@ -1,12 +1,13 @@ +import re from textwrap import dedent -import psycopg2 +import psycopg import pytest from unittest.mock import patch, MagicMock from pgspecial.main import PGSpecial, NO_QUERY from utils import run, dbtest, requires_json, requires_jsonb -from pgcli.main import PGCli +from pgcli.main import PGCli, exception_formatter as main_exception_formatter from pgcli.packages.parseutils.meta import FunctionMetadata @@ -89,8 +90,8 @@ def test_expanded_slash_G(executor, pgspecial): # Tests whether we reset the expanded output after a \G. run(executor, """create table test(a boolean)""") run(executor, """insert into test values(True)""") - results = run(executor, r"""select * from test \G""", pgspecial=pgspecial) - assert pgspecial.expanded_output == False + run(executor, r"""select * from test \G""", pgspecial=pgspecial) + assert pgspecial.expanded_output is False @dbtest @@ -131,9 +132,7 @@ def test_schemata_table_views_and_columns_query(executor): # views assert set(executor.views()) >= {("public", "d")} - assert set(executor.view_columns()) >= { - ("public", "d", "e", "integer", False, None) - } + assert set(executor.view_columns()) >= {("public", "d", "e", "integer", False, None)} @dbtest @@ -146,9 +145,7 @@ def test_foreign_key_query(executor): "create table schema2.child(childid int PRIMARY KEY, motherid int REFERENCES schema1.parent)", ) - assert set(executor.foreignkeys()) >= { - ("schema1", "parent", "parentid", "schema2", "child", "motherid") - } + assert set(executor.foreignkeys()) >= {("schema1", "parent", "parentid", "schema2", "child", "motherid")} @dbtest @@ -197,9 +194,7 @@ def test_functions_query(executor): return_type="integer", 
is_set_returning=True, ), - function_meta_data( - schema_name="schema1", func_name="func2", return_type="integer" - ), + function_meta_data(schema_name="schema1", func_name="func2", return_type="integer"), } @@ -219,15 +214,38 @@ def test_database_list(executor): @dbtest def test_invalid_syntax(executor, exception_formatter): - result = run(executor, "invalid syntax!", exception_formatter=exception_formatter) + result = run( + executor, + "invalid syntax!", + exception_formatter=lambda x: main_exception_formatter(x, verbose_errors=False), + ) assert 'syntax error at or near "invalid"' in result[0] + assert "SQLSTATE" not in result[0] @dbtest -def test_invalid_column_name(executor, exception_formatter): +def test_invalid_syntax_verbose(executor): result = run( - executor, "select invalid command", exception_formatter=exception_formatter + executor, + "invalid syntax!", + exception_formatter=lambda x: main_exception_formatter(x, verbose_errors=True), ) + fields = r""" +Severity: ERROR +Severity \(non-localized\): ERROR +SQLSTATE code: 42601 +Message: syntax error at or near "invalid" +Position: 1 +File: scan\.l +Line: \d+ +Routine: scanner_yyerror + """.strip() + assert re.search(fields, result[0]) + + +@dbtest +def test_invalid_column_name(executor, exception_formatter): + result = run(executor, "select invalid command", exception_formatter=exception_formatter) assert 'column "invalid" does not exist' in result[0] @@ -242,9 +260,7 @@ def test_unicode_support_in_output(executor, expanded): run(executor, "insert into unicodechars (t) values ('é')") # See issue #24, this raises an exception without proper handling - assert "é" in run( - executor, "select * from unicodechars", join=True, expanded=expanded - ) + assert "é" in run(executor, "select * from unicodechars", join=True, expanded=expanded) @dbtest @@ -253,8 +269,8 @@ def test_not_is_special(executor, pgspecial): query = "select 1" result = list(executor.run(query, pgspecial=pgspecial)) success, is_special = 
result[0][5:] - assert success == True - assert is_special == False + assert success is True + assert is_special is False @dbtest @@ -263,8 +279,8 @@ def test_execute_from_file_no_arg(executor, pgspecial): result = list(executor.run(r"\i", pgspecial=pgspecial)) status, sql, success, is_special = result[0][3:] assert "missing required argument" in status - assert success == False - assert is_special == True + assert success is False + assert is_special is True @dbtest @@ -278,8 +294,215 @@ def test_execute_from_file_io_error(os, executor, pgspecial): result = list(executor.run(r"\i test", pgspecial=pgspecial)) status, sql, success, is_special = result[0][3:] assert status == "test" - assert success == False - assert is_special == True + assert success is False + assert is_special is True + + +@dbtest +def test_execute_from_commented_file_that_executes_another_file(executor, pgspecial, tmpdir): + # https://github.com/dbcli/pgcli/issues/1336 + sqlfile1 = tmpdir.join("test01.sql") + sqlfile1.write("-- asdf \n\\h") + sqlfile2 = tmpdir.join("test00.sql") + sqlfile2.write("--An useless comment;\nselect now();\n-- another useless comment") + + rcfile = str(tmpdir.join("rcfile")) + print(rcfile) + cli = PGCli(pgexecute=executor, pgclirc_file=rcfile) + assert cli is not None + statement = "--comment\n\\h" + result = run(executor, statement, pgspecial=cli.pgspecial) + assert result is not None + assert result[0].find("ALTER TABLE") + + +@dbtest +def test_execute_commented_first_line_and_special(executor, pgspecial, tmpdir): + # just some base cases that should work also + statement = "--comment\nselect now();" + result = run(executor, statement, pgspecial=pgspecial) + assert result is not None + assert result[1].find("now") >= 0 + + statement = "/*comment*/\nselect now();" + result = run(executor, statement, pgspecial=pgspecial) + assert result is not None + assert result[1].find("now") >= 0 + + # https://github.com/dbcli/pgcli/issues/1362 + statement = "--comment\n\\h" + 
result = run(executor, statement, pgspecial=pgspecial) + assert result is not None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + + statement = "--comment1\n--comment2\n\\h" + result = run(executor, statement, pgspecial=pgspecial) + assert result is not None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + + statement = "/*comment*/\n\\h;" + result = run(executor, statement, pgspecial=pgspecial) + assert result is not None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + + statement = r"""/*comment1 + comment2*/ + \h""" + result = run(executor, statement, pgspecial=pgspecial) + assert result is not None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + + statement = """/*comment1 + comment2*/ + /*comment 3 + comment4*/ + \\h""" + result = run(executor, statement, pgspecial=pgspecial) + assert result is not None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + + statement = " /*comment*/\n\\h;" + result = run(executor, statement, pgspecial=pgspecial) + assert result is not None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + + statement = "/*comment\ncomment line2*/\n\\h;" + result = run(executor, statement, pgspecial=pgspecial) + assert result is not None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + + statement = " /*comment\ncomment line2*/\n\\h;" + result = run(executor, statement, pgspecial=pgspecial) + assert result is not None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + + statement = """\\h /*comment4 */""" + result = run(executor, statement, pgspecial=pgspecial) + print(result) + assert result is not None + assert result[0].find("No help") >= 0 + + # TODO: we probably don't want to do this but sqlparse is not parsing things well + # we relly want it to find help but right now, sqlparse isn't dropping the /*comment*/ + # 
style comments after command + + statement = r"""/*comment1*/ + \h + /*comment4 */""" + result = run(executor, statement, pgspecial=pgspecial) + assert result is not None + assert result[0].find("No help") >= 0 + + # TODO: same for this one + statement = """/*comment1 + comment3 + comment2*/ + \\h + /*comment4 + comment5 + comment6*/""" + result = run(executor, statement, pgspecial=pgspecial) + assert result is not None + assert result[0].find("No help") >= 0 + + +@dbtest +def test_execute_commented_first_line_and_normal(executor, pgspecial, tmpdir): + # https://github.com/dbcli/pgcli/issues/1403 + + # just some base cases that should work also + statement = "--comment\nselect now();" + result = run(executor, statement, pgspecial=pgspecial) + assert result is not None + assert result[1].find("now") >= 0 + + statement = "/*comment*/\nselect now();" + result = run(executor, statement, pgspecial=pgspecial) + assert result is not None + assert result[1].find("now") >= 0 + + # this simulates the original error (1403) without having to add/drop tables + # since it was just an error on reading input files and not the actual + # command itself + + # test that the statement works + statement = """VALUES (1, 'one'), (2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result is not None + assert result[5].find("three") >= 0 + + # test the statement with a \n in the middle + statement = """VALUES (1, 'one'),\n (2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result is not None + assert result[5].find("three") >= 0 + + # test the statement with a newline in the middle + statement = """VALUES (1, 'one'), + (2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result is not None + assert result[5].find("three") >= 0 + + # now add a single comment line + statement = """--comment\nVALUES (1, 'one'), (2, 'two'), (3, 'three');""" + result = run(executor, 
statement, pgspecial=pgspecial) + assert result is not None + assert result[5].find("three") >= 0 + + # doing without special char \n + statement = """--comment + VALUES (1,'one'), + (2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result is not None + assert result[5].find("three") >= 0 + + # two comment lines + statement = """--comment\n--comment2\nVALUES (1,'one'), (2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result is not None + assert result[5].find("three") >= 0 + + # doing without special char \n + statement = """--comment + --comment2 + VALUES (1,'one'), (2, 'two'), (3, 'three'); + """ + result = run(executor, statement, pgspecial=pgspecial) + assert result is not None + assert result[5].find("three") >= 0 + + # multiline comment + newline in middle of the statement + statement = """/*comment +comment2 +comment3*/ +VALUES (1,'one'), +(2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result is not None + assert result[5].find("three") >= 0 + + # multiline comment + newline in middle of the statement + # + comments after the statement + statement = """/*comment +comment2 +comment3*/ +VALUES (1,'one'), +(2, 'two'), (3, 'three'); +--comment4 +--comment5""" + result = run(executor, statement, pgspecial=pgspecial) + assert result is not None + assert result[5].find("three") >= 0 @dbtest @@ -347,9 +570,7 @@ def test_unicode_support_in_enum_type(executor): def test_json_renders_without_u_prefix(executor, expanded): run(executor, "create table jsontest(d json)") run(executor, """insert into jsontest (d) values ('{"name": "Éowyn"}')""") - result = run( - executor, "SELECT d FROM jsontest LIMIT 1", join=True, expanded=expanded - ) + result = run(executor, "SELECT d FROM jsontest LIMIT 1", join=True, expanded=expanded) assert '{"name": "Éowyn"}' in result @@ -358,9 +579,7 @@ def test_json_renders_without_u_prefix(executor, 
expanded): def test_jsonb_renders_without_u_prefix(executor, expanded): run(executor, "create table jsonbtest(d jsonb)") run(executor, """insert into jsonbtest (d) values ('{"name": "Éowyn"}')""") - result = run( - executor, "SELECT d FROM jsonbtest LIMIT 1", join=True, expanded=expanded - ) + result = run(executor, "SELECT d FROM jsonbtest LIMIT 1", join=True, expanded=expanded) assert '{"name": "Éowyn"}' in result @@ -368,28 +587,10 @@ def test_jsonb_renders_without_u_prefix(executor, expanded): @dbtest def test_date_time_types(executor): run(executor, "SET TIME ZONE UTC") - assert ( - run(executor, "SELECT (CAST('00:00:00' AS time))", join=True).split("\n")[3] - == "| 00:00:00 |" - ) - assert ( - run(executor, "SELECT (CAST('00:00:00+14:59' AS timetz))", join=True).split( - "\n" - )[3] - == "| 00:00:00+14:59 |" - ) - assert ( - run(executor, "SELECT (CAST('4713-01-01 BC' AS date))", join=True).split("\n")[ - 3 - ] - == "| 4713-01-01 BC |" - ) - assert ( - run( - executor, "SELECT (CAST('4713-01-01 00:00:00 BC' AS timestamp))", join=True - ).split("\n")[3] - == "| 4713-01-01 00:00:00 BC |" - ) + assert run(executor, "SELECT (CAST('00:00:00' AS time))", join=True).split("\n")[3] == "| 00:00:00 |" + assert run(executor, "SELECT (CAST('00:00:00+14:59' AS timetz))", join=True).split("\n")[3] == "| 00:00:00+14:59 |" + assert run(executor, "SELECT (CAST('4713-01-01 BC' AS date))", join=True).split("\n")[3] == "| 4713-01-01 BC |" + assert run(executor, "SELECT (CAST('4713-01-01 00:00:00 BC' AS timestamp))", join=True).split("\n")[3] == "| 4713-01-01 00:00:00 BC |" assert ( run( executor, @@ -399,10 +600,7 @@ def test_date_time_types(executor): == "| 4713-01-01 00:00:00+00 BC |" ) assert ( - run( - executor, "SELECT (CAST('-123456789 days 12:23:56' AS interval))", join=True - ).split("\n")[3] - == "| -123456789 days, 12:23:56 |" + run(executor, "SELECT (CAST('-123456789 days 12:23:56' AS interval))", join=True).split("\n")[3] == "| -123456789 days, 12:23:56 |" ) @@ 
-428,27 +626,21 @@ def test_describe_special(executor, command, verbose, pattern, pgspecial): @dbtest @pytest.mark.parametrize("sql", ["invalid sql", "SELECT 1; select error;"]) def test_raises_with_no_formatter(executor, sql): - with pytest.raises(psycopg2.ProgrammingError): + with pytest.raises(psycopg.ProgrammingError): list(executor.run(sql)) @dbtest def test_on_error_resume(executor, exception_formatter): sql = "select 1; error; select 1;" - result = list( - executor.run(sql, on_error_resume=True, exception_formatter=exception_formatter) - ) + result = list(executor.run(sql, on_error_resume=True, exception_formatter=exception_formatter)) assert len(result) == 3 @dbtest def test_on_error_stop(executor, exception_formatter): sql = "select 1; error; select 1;" - result = list( - executor.run( - sql, on_error_resume=False, exception_formatter=exception_formatter - ) - ) + result = list(executor.run(sql, on_error_resume=False, exception_formatter=exception_formatter)) assert len(result) == 2 @@ -462,7 +654,7 @@ def test_on_error_stop(executor, exception_formatter): @dbtest def test_nonexistent_function_definition(executor): with pytest.raises(RuntimeError): - result = executor.view_definition("there_is_no_such_function") + executor.view_definition("there_is_no_such_function") @dbtest @@ -478,7 +670,39 @@ def test_function_definition(executor): $function$ """, ) - result = executor.function_definition("the_number_three") + executor.function_definition("the_number_three") + + +@dbtest +def test_function_notice_order(executor): + run( + executor, + """ + CREATE OR REPLACE FUNCTION demo_order() RETURNS VOID AS + $$ + BEGIN + RAISE NOTICE 'first'; + RAISE NOTICE 'second'; + RAISE NOTICE 'third'; + RAISE NOTICE 'fourth'; + RAISE NOTICE 'fifth'; + RAISE NOTICE 'sixth'; + END; + $$ + LANGUAGE plpgsql; + """, + ) + + executor.function_definition("demo_order") + + result = run(executor, "select demo_order()") + assert "first\nsecond\nthird\nfourth\nfifth\nsixth" in result[0] 
+ assert "+------------+" in result[1] + assert "| demo_order |" in result[2] + assert "|------------|" in result[3] + assert "| |" in result[4] + assert "+------------+" in result[5] + assert "SELECT 1" in result[6] @dbtest @@ -487,6 +711,7 @@ def test_view_definition(executor): run(executor, "create view vw1 AS SELECT * FROM tbl1") run(executor, "create materialized view mvw1 AS SELECT * FROM tbl1") result = executor.view_definition("vw1") + assert 'VIEW "public"."vw1" AS' in result assert "FROM tbl1" in result # import pytest; pytest.set_trace() result = executor.view_definition("mvw1") @@ -496,9 +721,9 @@ def test_view_definition(executor): @dbtest def test_nonexistent_view_definition(executor): with pytest.raises(RuntimeError): - result = executor.view_definition("there_is_no_such_view") + executor.view_definition("there_is_no_such_view") with pytest.raises(RuntimeError): - result = executor.view_definition("mvw1") + executor.view_definition("mvw1") @dbtest @@ -507,17 +732,12 @@ def test_short_host(executor): assert executor.short_host == "localhost" with patch.object(executor, "host", "localhost.example.org"): assert executor.short_host == "localhost" - with patch.object( - executor, "host", "localhost1.example.org,localhost2.example.org" - ): + with patch.object(executor, "host", "localhost1.example.org,localhost2.example.org"): assert executor.short_host == "localhost1" - - -class BrokenConnection: - """Mock a connection that failed.""" - - def cursor(self): - raise psycopg2.InterfaceError("I'm broken!") + with patch.object(executor, "host", "ec2-11-222-333-444.compute-1.amazonaws.com"): + assert executor.short_host == "ec2-11-222-333-444" + with patch.object(executor, "host", "1.2.3.4"): + assert executor.short_host == "1.2.3.4" class VirtualCursor: @@ -549,13 +769,13 @@ def test_exit_without_active_connection(executor): aliases=(":q",), ) - with patch.object(executor, "conn", BrokenConnection()): + with patch.object(executor.conn, "cursor", 
side_effect=psycopg.InterfaceError("I'm broken!")): # we should be able to quit the app, even without active connection run(executor, "\\q", pgspecial=pgspecial) quit_handler.assert_called_once() # an exception should be raised when running a query without active connection - with pytest.raises(psycopg2.InterfaceError): + with pytest.raises(psycopg.InterfaceError): run(executor, "select 1", pgspecial=pgspecial) diff --git a/tests/test_plan.wiki b/tests/test_plan.wiki deleted file mode 100644 index 6812f18aa..000000000 --- a/tests/test_plan.wiki +++ /dev/null @@ -1,38 +0,0 @@ -= Gross Checks = - * [ ] Check connecting to a local database. - * [ ] Check connecting to a remote database. - * [ ] Check connecting to a database with a user/password. - * [ ] Check connecting to a non-existent database. - * [ ] Test changing the database. - - == PGExecute == - * [ ] Test successful execution given a cursor. - * [ ] Test unsuccessful execution with a syntax error. - * [ ] Test a series of executions with the same cursor without failure. - * [ ] Test a series of executions with the same cursor with failure. - * [ ] Test passing in a special command. - - == Naive Autocompletion == - * [ ] Input empty string, ask for completions - Everything. - * [ ] Input partial prefix, ask for completions - Stars with prefix. - * [ ] Input fully autocompleted string, ask for completions - Only full match - * [ ] Input non-existent prefix, ask for completions - nothing - * [ ] Input lowercase prefix - case insensitive completions - - == Smart Autocompletion == - * [ ] Input empty string and check if only keywords are returned. - * [ ] Input SELECT prefix and check if only columns are returned. - * [ ] Input SELECT blah - only keywords are returned. 
- * [ ] Input SELECT * FROM - Table names only - - == PGSpecial == - * [ ] Test \d - * [ ] Test \d tablename - * [ ] Test \d tablena* - * [ ] Test \d non-existent-tablename - * [ ] Test \d index - * [ ] Test \d sequence - * [ ] Test \d view - - == Exceptionals == - * [ ] Test the 'use' command to change db. diff --git a/tests/test_prompt_utils.py b/tests/test_prompt_utils.py index a8a3a1e08..91abe374a 100644 --- a/tests/test_prompt_utils.py +++ b/tests/test_prompt_utils.py @@ -7,4 +7,11 @@ def test_confirm_destructive_query_notty(): stdin = click.get_text_stream("stdin") if not stdin.isatty(): sql = "drop database foo;" - assert confirm_destructive_query(sql, "all") is None + assert confirm_destructive_query(sql, [], None) is None + + +def test_confirm_destructive_query_with_alias(): + stdin = click.get_text_stream("stdin") + if not stdin.isatty(): + sql = "drop database foo;" + assert confirm_destructive_query(sql, ["drop"], "test") is None diff --git a/tests/test_rowlimit.py b/tests/test_rowlimit.py index 947fc80d5..da916b4da 100644 --- a/tests/test_rowlimit.py +++ b/tests/test_rowlimit.py @@ -4,7 +4,7 @@ from pgcli.main import PGCli -# We need this fixtures beacause we need PGCli object to be created +# We need this fixtures because we need PGCli object to be created # after test collection so it has config loaded from temp directory diff --git a/tests/test_smart_completion_multiple_schemata.py b/tests/test_smart_completion_multiple_schemata.py index 5c9c9af48..98feb02db 100644 --- a/tests/test_smart_completion_multiple_schemata.py +++ b/tests/test_smart_completion_multiple_schemata.py @@ -11,7 +11,6 @@ wildcard_expansion, column, get_result, - result_set, qual, no_qual, parametrize, @@ -125,9 +124,7 @@ @parametrize("table", ["users", '"users"']) def test_suggested_column_names_from_shadowed_visible_table(completer, table): result = get_result(completer, "SELECT FROM " + table, len("SELECT ")) - assert completions_to_set(result) == completions_to_set( - 
testdata.columns_functions_and_keywords("users") - ) + assert completions_to_set(result) == completions_to_set(testdata.columns_functions_and_keywords("users")) @parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual)) @@ -140,18 +137,14 @@ def test_suggested_column_names_from_shadowed_visible_table(completer, table): ) def test_suggested_column_names_from_qualified_shadowed_table(completer, text): result = get_result(completer, text, position=text.find(" ") + 1) - assert completions_to_set(result) == completions_to_set( - testdata.columns_functions_and_keywords("users", "custom") - ) + assert completions_to_set(result) == completions_to_set(testdata.columns_functions_and_keywords("users", "custom")) @parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual)) @parametrize("text", ["WITH users as (SELECT 1 AS foo) SELECT from users"]) def test_suggested_column_names_from_cte(completer, text): result = completions_to_set(get_result(completer, text, text.find(" ") + 1)) - assert result == completions_to_set( - [column("foo")] + testdata.functions_and_keywords() - ) + assert result == completions_to_set([column("foo")] + testdata.functions_and_keywords()) @parametrize("completer", completers(casing=False)) @@ -166,14 +159,12 @@ def test_suggested_column_names_from_cte(completer, text): ) def test_suggested_join_conditions(completer, text): result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set( - [ - alias("users"), - alias("shipments"), - name_join("shipments.id = users.id"), - fk_join("shipments.user_id = users.id"), - ] - ) + assert completions_to_set(result) == completions_to_set([ + alias("users"), + alias("shipments"), + name_join("shipments.id = users.id"), + fk_join("shipments.user_id = users.id"), + ]) @parametrize("completer", completers(filtr=True, casing=False, aliasing=False)) @@ -192,17 +183,14 @@ def test_suggested_join_conditions(completer, text): def 
test_suggested_joins(completer, query, tbl): result = get_result(completer, query.format(tbl)) assert completions_to_set(result) == completions_to_set( - testdata.schemas_and_from_clause_items() - + [join(f"custom.shipments ON shipments.user_id = {tbl}.id")] + testdata.schemas_and_from_clause_items() + [join(f"custom.shipments ON shipments.user_id = {tbl}.id")] ) @parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual)) def test_suggested_column_names_from_schema_qualifed_table(completer): result = get_result(completer, "SELECT from custom.products", len("SELECT ")) - assert completions_to_set(result) == completions_to_set( - testdata.columns_functions_and_keywords("products", "custom") - ) + assert completions_to_set(result) == completions_to_set(testdata.columns_functions_and_keywords("products", "custom")) @parametrize( @@ -216,19 +204,13 @@ def test_suggested_column_names_from_schema_qualifed_table(completer): ) @parametrize("completer", completers(filtr=True, casing=False)) def test_suggested_columns_with_insert(completer, text): - assert completions_to_set(get_result(completer, text)) == completions_to_set( - testdata.columns("orders") - ) + assert completions_to_set(get_result(completer, text)) == completions_to_set(testdata.columns("orders")) @parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual)) def test_suggested_column_names_in_function(completer): - result = get_result( - completer, "SELECT MAX( from custom.products", len("SELECT MAX(") - ) - assert completions_to_set(result) == completions_to_set( - testdata.columns_functions_and_keywords("products", "custom") - ) + result = get_result(completer, "SELECT MAX( from custom.products", len("SELECT MAX(")) + assert completions_to_set(result) == completions_to_set(testdata.columns_functions_and_keywords("products", "custom")) @parametrize("completer", completers(casing=False, aliasing=False)) @@ -237,9 +219,7 @@ def 
test_suggested_column_names_in_function(completer): ["SELECT * FROM Custom.", "SELECT * FROM custom.", 'SELECT * FROM "custom".'], ) @parametrize("use_leading_double_quote", [False, True]) -def test_suggested_table_names_with_schema_dot( - completer, text, use_leading_double_quote -): +def test_suggested_table_names_with_schema_dot(completer, text, use_leading_double_quote): if use_leading_double_quote: text += '"' start_position = -1 @@ -247,17 +227,13 @@ def test_suggested_table_names_with_schema_dot( start_position = 0 result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set( - testdata.from_clause_items("custom", start_position) - ) + assert completions_to_set(result) == completions_to_set(testdata.from_clause_items("custom", start_position)) @parametrize("completer", completers(casing=False, aliasing=False)) @parametrize("text", ['SELECT * FROM "Custom".']) @parametrize("use_leading_double_quote", [False, True]) -def test_suggested_table_names_with_schema_dot2( - completer, text, use_leading_double_quote -): +def test_suggested_table_names_with_schema_dot2(completer, text, use_leading_double_quote): if use_leading_double_quote: text += '"' start_position = -1 @@ -265,37 +241,25 @@ def test_suggested_table_names_with_schema_dot2( start_position = 0 result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set( - testdata.from_clause_items("Custom", start_position) - ) + assert completions_to_set(result) == completions_to_set(testdata.from_clause_items("Custom", start_position)) @parametrize("completer", completers(filtr=True, casing=False)) def test_suggested_column_names_with_qualified_alias(completer): result = get_result(completer, "SELECT p. 
from custom.products p", len("SELECT p.")) - assert completions_to_set(result) == completions_to_set( - testdata.columns("products", "custom") - ) + assert completions_to_set(result) == completions_to_set(testdata.columns("products", "custom")) @parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual)) def test_suggested_multiple_column_names(completer): - result = get_result( - completer, "SELECT id, from custom.products", len("SELECT id, ") - ) - assert completions_to_set(result) == completions_to_set( - testdata.columns_functions_and_keywords("products", "custom") - ) + result = get_result(completer, "SELECT id, from custom.products", len("SELECT id, ")) + assert completions_to_set(result) == completions_to_set(testdata.columns_functions_and_keywords("products", "custom")) @parametrize("completer", completers(filtr=True, casing=False)) def test_suggested_multiple_column_names_with_alias(completer): - result = get_result( - completer, "SELECT p.id, p. from custom.products p", len("SELECT u.id, u.") - ) - assert completions_to_set(result) == completions_to_set( - testdata.columns("products", "custom") - ) + result = get_result(completer, "SELECT p.id, p. 
from custom.products p", len("SELECT u.id, u.")) + assert completions_to_set(result) == completions_to_set(testdata.columns("products", "custom")) @parametrize("completer", completers(filtr=True, casing=False)) @@ -307,19 +271,15 @@ def test_suggested_multiple_column_names_with_alias(completer): ], ) def test_suggestions_after_on(completer, text): - position = len( - "SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON " - ) + position = len("SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON ") result = get_result(completer, text, position) - assert completions_to_set(result) == completions_to_set( - [ - alias("x"), - alias("y"), - name_join("y.price = x.price"), - name_join("y.product_name = x.product_name"), - name_join("y.id = x.id"), - ] - ) + assert completions_to_set(result) == completions_to_set([ + alias("x"), + alias("y"), + name_join("y.price = x.price"), + name_join("y.product_name = x.product_name"), + name_join("y.id = x.id"), + ]) @parametrize("completer", completers()) @@ -333,32 +293,26 @@ def test_suggested_aliases_after_on_right_side(completer): def test_table_names_after_from(completer): text = "SELECT * FROM " result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set( - testdata.schemas_and_from_clause_items() - ) + assert completions_to_set(result) == completions_to_set(testdata.schemas_and_from_clause_items()) @parametrize("completer", completers(filtr=True, casing=False)) def test_schema_qualified_function_name(completer): text = "SELECT custom.func" result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set( - [ - function("func3()", -len("func")), - function("set_returning_func()", -len("func")), - ] - ) + assert completions_to_set(result) == completions_to_set([ + function("func3()", -len("func")), + function("set_returning_func()", -len("func")), + ]) @parametrize("completer", completers(filtr=True, casing=False, 
aliasing=False)) def test_schema_qualified_function_name_after_from(completer): text = "SELECT * FROM custom.set_r" result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set( - [ - function("set_returning_func()", -len("func")), - ] - ) + assert completions_to_set(result) == completions_to_set([ + function("set_returning_func()", -len("func")), + ]) @parametrize("completer", completers(filtr=True, casing=False, aliasing=False)) @@ -373,11 +327,9 @@ def test_unqualified_function_name_in_search_path(completer): completer.search_path = ["public", "custom"] text = "SELECT * FROM set_r" result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set( - [ - function("set_returning_func()", -len("func")), - ] - ) + assert completions_to_set(result) == completions_to_set([ + function("set_returning_func()", -len("func")), + ]) @parametrize("completer", completers(filtr=True, casing=False)) @@ -397,12 +349,8 @@ def test_schema_qualified_type_name(completer, text): @parametrize("completer", completers(filtr=True, casing=False)) def test_suggest_columns_from_aliased_set_returning_function(completer): - result = get_result( - completer, "select f. from custom.set_returning_func() f", len("select f.") - ) - assert completions_to_set(result) == completions_to_set( - testdata.columns("set_returning_func", "custom", "functions") - ) + result = get_result(completer, "select f. 
from custom.set_returning_func() f", len("select f.")) + assert completions_to_set(result) == completions_to_set(testdata.columns("set_returning_func", "custom", "functions")) @parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual)) @@ -499,10 +447,7 @@ def test_wildcard_column_expansion_with_two_tables(completer): completions = get_result(completer, text, position) - cols = ( - '"select".id, "select"."localtime", "select"."ABC", ' - "users.id, users.phone_number" - ) + cols = '"select".id, "select"."localtime", "select"."ABC", users.id, users.phone_number' expected = [wildcard_expansion(cols)] assert completions == expected @@ -535,21 +480,15 @@ def test_wildcard_column_expansion_with_two_tables_and_parent(completer): def test_suggest_columns_from_unquoted_table(completer, text): position = len("SELECT U.") result = get_result(completer, text, position) - assert completions_to_set(result) == completions_to_set( - testdata.columns("users", "custom") - ) + assert completions_to_set(result) == completions_to_set(testdata.columns("users", "custom")) @parametrize("completer", completers(filtr=True, casing=False)) -@parametrize( - "text", ['SELECT U. FROM custom."Users" U', 'SELECT U. FROM "custom"."Users" U'] -) +@parametrize("text", ['SELECT U. FROM custom."Users" U', 'SELECT U. 
FROM "custom"."Users" U']) def test_suggest_columns_from_quoted_table(completer, text): position = len("SELECT U.") result = get_result(completer, text, position) - assert completions_to_set(result) == completions_to_set( - testdata.columns("Users", "custom") - ) + assert completions_to_set(result) == completions_to_set(testdata.columns("Users", "custom")) texts = ["SELECT * FROM ", "SELECT * FROM public.Orders O CROSS JOIN "] @@ -559,9 +498,7 @@ def test_suggest_columns_from_quoted_table(completer, text): @parametrize("text", texts) def test_schema_or_visible_table_completion(completer, text): result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set( - testdata.schemas_and_from_clause_items() - ) + assert completions_to_set(result) == completions_to_set(testdata.schemas_and_from_clause_items()) @parametrize("completer", completers(aliasing=True, casing=False, filtr=True)) @@ -703,9 +640,7 @@ def test_column_alias_search(completer): @parametrize("completer", completers(casing=True)) def test_column_alias_search_qualified(completer): - result = get_result( - completer, "SELECT E.ei FROM blog.Entries E", len("SELECT E.ei") - ) + result = get_result(completer, "SELECT E.ei FROM blog.Entries E", len("SELECT E.ei")) cols = ("EntryID", "EntryTitle") assert result[:3] == [column(c, -2) for c in cols] @@ -713,9 +648,7 @@ def test_column_alias_search_qualified(completer): @parametrize("completer", completers(casing=False, filtr=False, aliasing=False)) def test_schema_object_order(completer): result = get_result(completer, "SELECT * FROM u") - assert result[:3] == [ - table(t, pos=-1) for t in ("users", 'custom."Users"', "custom.users") - ] + assert result[:3] == [table(t, pos=-1) for t in ("users", 'custom."Users"', "custom.users")] @parametrize("completer", completers(casing=False, filtr=False, aliasing=False)) @@ -723,8 +656,7 @@ def test_all_schema_objects(completer): text = "SELECT * FROM " result = get_result(completer, text) 
assert completions_to_set(result) >= completions_to_set( - [table(x) for x in ("orders", '"select"', "custom.shipments")] - + [function(x + "()") for x in ("func2",)] + [table(x) for x in ("orders", '"select"', "custom.shipments")] + [function(x + "()") for x in ("func2",)] ) @@ -733,8 +665,7 @@ def test_all_schema_objects_with_casing(completer): text = "SELECT * FROM " result = get_result(completer, text) assert completions_to_set(result) >= completions_to_set( - [table(x) for x in ("Orders", '"select"', "CUSTOM.shipments")] - + [function(x + "()") for x in ("func2",)] + [table(x) for x in ("Orders", '"select"', "CUSTOM.shipments")] + [function(x + "()") for x in ("func2",)] ) @@ -743,8 +674,7 @@ def test_all_schema_objects_with_aliases(completer): text = "SELECT * FROM " result = get_result(completer, text) assert completions_to_set(result) >= completions_to_set( - [table(x) for x in ("orders o", '"select" s', "custom.shipments s")] - + [function(x) for x in ("func2() f",)] + [table(x) for x in ("orders o", '"select" s', "custom.shipments s")] + [function(x) for x in ("func2() f",)] ) @@ -752,6 +682,4 @@ def test_all_schema_objects_with_aliases(completer): def test_set_schema(completer): text = "SET SCHEMA " result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set( - [schema("'blog'"), schema("'Custom'"), schema("'custom'"), schema("'public'")] - ) + assert completions_to_set(result) == completions_to_set([schema("'blog'"), schema("'Custom'"), schema("'custom'"), schema("'public'")]) diff --git a/tests/test_smart_completion_public_schema_only.py b/tests/test_smart_completion_public_schema_only.py index db1fe0a39..92bfff765 100644 --- a/tests/test_smart_completion_public_schema_only.py +++ b/tests/test_smart_completion_public_schema_only.py @@ -12,7 +12,6 @@ column, wildcard_expansion, get_result, - result_set, qual, no_qual, parametrize, @@ -68,19 +67,11 @@ ] cased_tbls = ["Users", "Orders"] cased_views = ["User_Emails", 
"Functions"] -casing = ( - ["SELECT", "PUBLIC"] - + cased_func_names - + cased_tbls - + cased_views - + cased_users_col_names - + cased_users2_col_names -) +casing = ["SELECT", "PUBLIC"] + cased_func_names + cased_tbls + cased_views + cased_users_col_names + cased_users2_col_names # Lists for use in assertions -cased_funcs = [ - function(f) - for f in ("Custom_Fun()", "_custom_fun()", "Custom_Func1()", "custom_func2()") -] + [function("set_returning_func(x := , y := )", display="set_returning_func(x, y)")] +cased_funcs = [function(f) for f in ("Custom_Fun()", "_custom_fun()", "Custom_Func1()", "custom_func2()")] + [ + function("set_returning_func(x := , y := )", display="set_returning_func(x, y)") +] cased_tbls = [table(t) for t in (cased_tbls + ['"Users"', '"select"'])] cased_rels = [view(t) for t in cased_views] + cased_funcs + cased_tbls cased_users_cols = [column(c) for c in cased_users_col_names] @@ -132,25 +123,19 @@ def test_function_column_name(completer): len("SELECT * FROM Functions WHERE function:"), len("SELECT * FROM Functions WHERE function:text") + 1, ): - assert [] == get_result( - completer, "SELECT * FROM Functions WHERE function:text"[:l] - ) + assert [] == get_result(completer, "SELECT * FROM Functions WHERE function:text"[:l]) @parametrize("action", ["ALTER", "DROP", "CREATE", "CREATE OR REPLACE"]) @parametrize("completer", completers()) def test_drop_alter_function(completer, action): - assert get_result(completer, action + " FUNCTION set_ret") == [ - function("set_returning_func(x integer, y integer)", -len("set_ret")) - ] + assert get_result(completer, action + " FUNCTION set_ret") == [function("set_returning_func(x integer, y integer)", -len("set_ret"))] @parametrize("completer", completers()) def test_empty_string_completion(completer): result = get_result(completer, "") - assert completions_to_set( - testdata.keywords() + testdata.specials() - ) == completions_to_set(result) + assert completions_to_set(testdata.keywords() + 
testdata.specials()) == completions_to_set(result) @parametrize("completer", completers()) @@ -162,19 +147,17 @@ def test_select_keyword_completion(completer): @parametrize("completer", completers()) def test_builtin_function_name_completion(completer): result = get_result(completer, "SELECT MA") - assert completions_to_set(result) == completions_to_set( - [ - function("MAKE_DATE", -2), - function("MAKE_INTERVAL", -2), - function("MAKE_TIME", -2), - function("MAKE_TIMESTAMP", -2), - function("MAKE_TIMESTAMPTZ", -2), - function("MASKLEN", -2), - function("MAX", -2), - keyword("MAXEXTENTS", -2), - keyword("MATERIALIZED VIEW", -2), - ] - ) + assert completions_to_set(result) == completions_to_set([ + function("MAKE_DATE", -2), + function("MAKE_INTERVAL", -2), + function("MAKE_TIME", -2), + function("MAKE_TIMESTAMP", -2), + function("MAKE_TIMESTAMPTZ", -2), + function("MASKLEN", -2), + function("MAX", -2), + keyword("MAXEXTENTS", -2), + keyword("MATERIALIZED VIEW", -2), + ]) @parametrize("completer", completers()) @@ -189,58 +172,47 @@ def test_builtin_function_matches_only_at_start(completer): @parametrize("completer", completers(casing=False, aliasing=False)) def test_user_function_name_completion(completer): result = get_result(completer, "SELECT cu") - assert completions_to_set(result) == completions_to_set( - [ - function("custom_fun()", -2), - function("_custom_fun()", -2), - function("custom_func1()", -2), - function("custom_func2()", -2), - function("CURRENT_DATE", -2), - function("CURRENT_TIMESTAMP", -2), - function("CUME_DIST", -2), - function("CURRENT_TIME", -2), - keyword("CURRENT", -2), - ] - ) + assert completions_to_set(result) == completions_to_set([ + function("custom_fun()", -2), + function("_custom_fun()", -2), + function("custom_func1()", -2), + function("custom_func2()", -2), + function("CURRENT_DATE", -2), + function("CURRENT_TIMESTAMP", -2), + function("CUME_DIST", -2), + function("CURRENT_TIME", -2), + keyword("CURRENT", -2), + ]) 
@parametrize("completer", completers(casing=False, aliasing=False)) def test_user_function_name_completion_matches_anywhere(completer): result = get_result(completer, "SELECT om") - assert completions_to_set(result) == completions_to_set( - [ - function("custom_fun()", -2), - function("_custom_fun()", -2), - function("custom_func1()", -2), - function("custom_func2()", -2), - ] - ) + assert completions_to_set(result) == completions_to_set([ + function("custom_fun()", -2), + function("_custom_fun()", -2), + function("custom_func1()", -2), + function("custom_func2()", -2), + ]) @parametrize("completer", completers(casing=True)) def test_list_functions_for_special(completer): result = get_result(completer, r"\df ") - assert completions_to_set(result) == completions_to_set( - [schema("PUBLIC")] + [function(f) for f in cased_func_names] - ) + assert completions_to_set(result) == completions_to_set([schema("PUBLIC")] + [function(f) for f in cased_func_names]) @parametrize("completer", completers(casing=False, qualify=no_qual)) def test_suggested_column_names_from_visible_table(completer): result = get_result(completer, "SELECT from users", len("SELECT ")) - assert completions_to_set(result) == completions_to_set( - testdata.columns_functions_and_keywords("users") - ) + assert completions_to_set(result) == completions_to_set(testdata.columns_functions_and_keywords("users")) @parametrize("completer", completers(casing=True, qualify=no_qual)) def test_suggested_cased_column_names(completer): result = get_result(completer, "SELECT from users", len("SELECT ")) assert completions_to_set(result) == completions_to_set( - cased_funcs - + cased_users_cols - + testdata.builtin_functions() - + testdata.keywords() + cased_funcs + cased_users_cols + testdata.builtin_functions() + testdata.keywords() ) @@ -250,9 +222,7 @@ def test_suggested_auto_qualified_column_names(text, completer): position = text.index(" ") + 1 cols = [column(c.lower()) for c in cased_users_col_names] result = 
get_result(completer, text, position) - assert completions_to_set(result) == completions_to_set( - cols + testdata.functions_and_keywords() - ) + assert completions_to_set(result) == completions_to_set(cols + testdata.functions_and_keywords()) @parametrize("completer", completers(casing=False, qualify=qual)) @@ -268,9 +238,7 @@ def test_suggested_auto_qualified_column_names_two_tables(text, completer): cols = [column("U." + c.lower()) for c in cased_users_col_names] cols += [column('"Users".' + c.lower()) for c in cased_users2_col_names] result = get_result(completer, text, position) - assert completions_to_set(result) == completions_to_set( - cols + testdata.functions_and_keywords() - ) + assert completions_to_set(result) == completions_to_set(cols + testdata.functions_and_keywords()) @parametrize("completer", completers(casing=True, qualify=["always"])) @@ -287,17 +255,13 @@ def test_suggested_cased_always_qualified_column_names(completer): position = len("SELECT ") cols = [column("users." 
+ c) for c in cased_users_col_names] result = get_result(completer, text, position) - assert completions_to_set(result) == completions_to_set( - cased_funcs + cols + testdata.builtin_functions() + testdata.keywords() - ) + assert completions_to_set(result) == completions_to_set(cased_funcs + cols + testdata.builtin_functions() + testdata.keywords()) @parametrize("completer", completers(casing=False, qualify=no_qual)) def test_suggested_column_names_in_function(completer): result = get_result(completer, "SELECT MAX( from users", len("SELECT MAX(")) - assert completions_to_set(result) == completions_to_set( - testdata.columns_functions_and_keywords("users") - ) + assert completions_to_set(result) == completions_to_set(testdata.columns_functions_and_keywords("users")) @parametrize("completer", completers(casing=False)) @@ -315,24 +279,18 @@ def test_suggested_column_names_with_alias(completer): @parametrize("completer", completers(casing=False, qualify=no_qual)) def test_suggested_multiple_column_names(completer): result = get_result(completer, "SELECT id, from users u", len("SELECT id, ")) - assert completions_to_set(result) == completions_to_set( - testdata.columns_functions_and_keywords("users") - ) + assert completions_to_set(result) == completions_to_set(testdata.columns_functions_and_keywords("users")) @parametrize("completer", completers(casing=False)) def test_suggested_multiple_column_names_with_alias(completer): - result = get_result( - completer, "SELECT u.id, u. from users u", len("SELECT u.id, u.") - ) + result = get_result(completer, "SELECT u.id, u. from users u", len("SELECT u.id, u.")) assert completions_to_set(result) == completions_to_set(testdata.columns("users")) @parametrize("completer", completers(casing=True)) def test_suggested_cased_column_names_with_alias(completer): - result = get_result( - completer, "SELECT u.id, u. from users u", len("SELECT u.id, u.") - ) + result = get_result(completer, "SELECT u.id, u. 
from users u", len("SELECT u.id, u.")) assert completions_to_set(result) == completions_to_set(cased_users_cols) @@ -378,18 +336,14 @@ def test_suggest_columns_after_three_way_join(completer): @parametrize("text", join_condition_texts) def test_suggested_join_conditions(completer, text): result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set( - [alias("U"), alias("U2"), fk_join("U2.userid = U.id")] - ) + assert completions_to_set(result) == completions_to_set([alias("U"), alias("U2"), fk_join("U2.userid = U.id")]) @parametrize("completer", completers(casing=True)) @parametrize("text", join_condition_texts) def test_cased_join_conditions(completer, text): result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set( - [alias("U"), alias("U2"), fk_join("U2.UserID = U.ID")] - ) + assert completions_to_set(result) == completions_to_set([alias("U"), alias("U2"), fk_join("U2.UserID = U.ID")]) @parametrize("completer", completers(casing=False)) @@ -435,9 +389,7 @@ def test_suggested_join_conditions_with_invalid_qualifier(completer, text): ) def test_suggested_join_conditions_with_invalid_table(completer, text, ref): result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set( - [alias("users"), alias(ref)] - ) + assert completions_to_set(result) == completions_to_set([alias("users"), alias(ref)]) @parametrize("completer", completers(casing=False, aliasing=False)) @@ -531,8 +483,7 @@ def test_aliased_joins(completer, text): def test_suggested_joins_quoted_schema_qualified_table(completer, text): result = get_result(completer, text) assert completions_to_set(result) == completions_to_set( - testdata.schemas_and_from_clause_items() - + [join('public.users ON users.id = "Users".userid')] + testdata.schemas_and_from_clause_items() + [join('public.users ON users.id = "Users".userid')] ) @@ -547,14 +498,12 @@ def 
test_suggested_joins_quoted_schema_qualified_table(completer, text): def test_suggested_aliases_after_on(completer, text): position = len("SELECT u.name, o.id FROM users u JOIN orders o ON ") result = get_result(completer, text, position) - assert completions_to_set(result) == completions_to_set( - [ - alias("u"), - name_join("o.id = u.id"), - name_join("o.email = u.email"), - alias("o"), - ] - ) + assert completions_to_set(result) == completions_to_set([ + alias("u"), + name_join("o.id = u.id"), + name_join("o.email = u.email"), + alias("o"), + ]) @parametrize("completer", completers()) @@ -582,14 +531,12 @@ def test_suggested_aliases_after_on_right_side(completer, text): def test_suggested_tables_after_on(completer, text): position = len("SELECT users.name, orders.id FROM users JOIN orders ON ") result = get_result(completer, text, position) - assert completions_to_set(result) == completions_to_set( - [ - name_join("orders.id = users.id"), - name_join("orders.email = users.email"), - alias("users"), - alias("orders"), - ] - ) + assert completions_to_set(result) == completions_to_set([ + name_join("orders.id = users.id"), + name_join("orders.email = users.email"), + alias("users"), + alias("orders"), + ]) @parametrize("completer", completers(casing=False)) @@ -601,13 +548,9 @@ def test_suggested_tables_after_on(completer, text): ], ) def test_suggested_tables_after_on_right_side(completer, text): - position = len( - "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = " - ) + position = len("SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = ") result = get_result(completer, text, position) - assert completions_to_set(result) == completions_to_set( - [alias("users"), alias("orders")] - ) + assert completions_to_set(result) == completions_to_set([alias("users"), alias("orders")]) @parametrize("completer", completers(casing=False)) @@ -620,9 +563,7 @@ def test_suggested_tables_after_on_right_side(completer, text): ) def 
test_join_using_suggests_common_columns(completer, text): result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set( - [column("id"), column("email")] - ) + assert completions_to_set(result) == completions_to_set([column("id"), column("email")]) @parametrize("completer", completers(casing=False)) @@ -638,9 +579,7 @@ def test_join_using_suggests_common_columns(completer, text): def test_join_using_suggests_from_last_table(completer, text): position = text.index("()") + 1 result = get_result(completer, text, position) - assert completions_to_set(result) == completions_to_set( - [column("id"), column("email")] - ) + assert completions_to_set(result) == completions_to_set([column("id"), column("email")]) @parametrize("completer", completers(casing=False)) @@ -653,9 +592,7 @@ def test_join_using_suggests_from_last_table(completer, text): ) def test_join_using_suggests_columns_after_first_column(completer, text): result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set( - [column("id"), column("email")] - ) + assert completions_to_set(result) == completions_to_set([column("id"), column("email")]) @parametrize("completer", completers(casing=False, aliasing=False)) @@ -669,9 +606,7 @@ def test_join_using_suggests_columns_after_first_column(completer, text): ) def test_table_names_after_from(completer, text): result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set( - testdata.schemas_and_from_clause_items() - ) + assert completions_to_set(result) == completions_to_set(testdata.schemas_and_from_clause_items()) assert [c.text for c in result] == [ "public", "orders", @@ -691,9 +626,7 @@ def test_table_names_after_from(completer, text): @parametrize("completer", completers(casing=False, qualify=no_qual)) def test_auto_escaped_col_names(completer): result = get_result(completer, 'SELECT from "select"', len("SELECT ")) - assert completions_to_set(result) == 
completions_to_set( - testdata.columns_functions_and_keywords("select") - ) + assert completions_to_set(result) == completions_to_set(testdata.columns_functions_and_keywords("select")) @parametrize("completer", completers(aliasing=False)) @@ -717,9 +650,7 @@ def test_allow_leading_double_quote_in_last_word(completer): ) def test_suggest_datatype(text, completer): result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set( - testdata.schemas() + testdata.types() + testdata.builtin_datatypes() - ) + assert completions_to_set(result) == completions_to_set(testdata.schemas() + testdata.types() + testdata.builtin_datatypes()) @parametrize("completer", completers(casing=False)) @@ -731,19 +662,13 @@ def test_suggest_columns_from_escaped_table_alias(completer): @parametrize("completer", completers(casing=False, qualify=no_qual)) def test_suggest_columns_from_set_returning_function(completer): result = get_result(completer, "select from set_returning_func()", len("select ")) - assert completions_to_set(result) == completions_to_set( - testdata.columns_functions_and_keywords("set_returning_func", typ="functions") - ) + assert completions_to_set(result) == completions_to_set(testdata.columns_functions_and_keywords("set_returning_func", typ="functions")) @parametrize("completer", completers(casing=False)) def test_suggest_columns_from_aliased_set_returning_function(completer): - result = get_result( - completer, "select f. from set_returning_func() f", len("select f.") - ) - assert completions_to_set(result) == completions_to_set( - testdata.columns("set_returning_func", typ="functions") - ) + result = get_result(completer, "select f. 
from set_returning_func() f", len("select f.")) + assert completions_to_set(result) == completions_to_set(testdata.columns("set_returning_func", typ="functions")) @parametrize("completer", completers(casing=False)) @@ -751,9 +676,7 @@ def test_join_functions_using_suggests_common_columns(completer): text = """SELECT * FROM set_returning_func() f1 INNER JOIN set_returning_func() f2 USING (""" result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set( - testdata.columns("set_returning_func", typ="functions") - ) + assert completions_to_set(result) == completions_to_set(testdata.columns("set_returning_func", typ="functions")) @parametrize("completer", completers(casing=False)) @@ -762,8 +685,7 @@ def test_join_functions_on_suggests_columns_and_join_conditions(completer): INNER JOIN set_returning_func() f2 ON f1.""" result = get_result(completer, text) assert completions_to_set(result) == completions_to_set( - [name_join("y = f2.y"), name_join("x = f2.x")] - + testdata.columns("set_returning_func", typ="functions") + [name_join("y = f2.y"), name_join("x = f2.x")] + testdata.columns("set_returning_func", typ="functions") ) @@ -880,10 +802,7 @@ def test_wildcard_column_expansion_with_two_tables(completer): completions = get_result(completer, text, position) - cols = ( - '"select".id, "select".insert, "select"."ABC", ' - "u.id, u.parentid, u.email, u.first_name, u.last_name" - ) + cols = '"select".id, "select".insert, "select"."ABC", u.id, u.parentid, u.email, u.first_name, u.last_name' expected = [wildcard_expansion(cols)] assert completions == expected @@ -922,18 +841,14 @@ def test_suggest_columns_from_quoted_table(completer): @parametrize("text", ["SELECT * FROM ", "SELECT * FROM Orders o CROSS JOIN "]) def test_schema_or_visible_table_completion(completer, text): result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set( - testdata.schemas_and_from_clause_items() - ) + assert 
completions_to_set(result) == completions_to_set(testdata.schemas_and_from_clause_items()) @parametrize("completer", completers(casing=False, aliasing=True)) @parametrize("text", ["SELECT * FROM "]) def test_table_aliases(completer, text): result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set( - testdata.schemas() + aliased_rels - ) + assert completions_to_set(result) == completions_to_set(testdata.schemas() + aliased_rels) @parametrize("completer", completers(casing=False, aliasing=True)) @@ -965,43 +880,37 @@ def test_duplicate_table_aliases(completer, text): @parametrize("text", ["SELECT * FROM Orders o CROSS JOIN "]) def test_duplicate_aliases_with_casing(completer, text): result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set( - [ - schema("PUBLIC"), - table("Orders O2"), - table("Users U"), - table('"Users" U'), - table('"select" s'), - view("User_Emails UE"), - view("Functions F"), - function("_custom_fun() cf"), - function("Custom_Fun() CF"), - function("Custom_Func1() CF"), - function("custom_func2() cf"), - function( - "set_returning_func(x := , y := ) srf", - display="set_returning_func(x, y) srf", - ), - ] - ) + assert completions_to_set(result) == completions_to_set([ + schema("PUBLIC"), + table("Orders O2"), + table("Users U"), + table('"Users" U'), + table('"select" s'), + view("User_Emails UE"), + view("Functions F"), + function("_custom_fun() cf"), + function("Custom_Fun() CF"), + function("Custom_Func1() CF"), + function("custom_func2() cf"), + function( + "set_returning_func(x := , y := ) srf", + display="set_returning_func(x, y) srf", + ), + ]) @parametrize("completer", completers(casing=True, aliasing=True)) @parametrize("text", ["SELECT * FROM "]) def test_aliases_with_casing(completer, text): result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set( - [schema("PUBLIC")] + cased_aliased_rels - ) + assert 
completions_to_set(result) == completions_to_set([schema("PUBLIC")] + cased_aliased_rels) @parametrize("completer", completers(casing=True, aliasing=False)) @parametrize("text", ["SELECT * FROM "]) def test_table_casing(completer, text): result = get_result(completer, text) - assert completions_to_set(result) == completions_to_set( - [schema("PUBLIC")] + cased_rels - ) + assert completions_to_set(result) == completions_to_set([schema("PUBLIC")] + cased_rels) @parametrize("completer", completers(casing=False)) @@ -1028,12 +937,10 @@ def test_suggest_cte_names(completer): SELECT * FROM """ result = get_result(completer, text) - expected = completions_to_set( - [ - Completion("cte1", 0, display_meta="table"), - Completion("cte2", 0, display_meta="table"), - ] - ) + expected = completions_to_set([ + Completion("cte1", 0, display_meta="table"), + Completion("cte2", 0, display_meta="table"), + ]) assert expected <= completions_to_set(result) @@ -1101,12 +1008,10 @@ def test_set_schema(completer): @parametrize("completer", completers()) def test_special_name_completion(completer): result = get_result(completer, "\\t") - assert completions_to_set(result) == completions_to_set( - [ - Completion( - text="\\timing", - start_position=-2, - display_meta="Toggle timing of commands.", - ) - ] - ) + assert completions_to_set(result) == completions_to_set([ + Completion( + text="\\timing", + start_position=-2, + display_meta="Toggle timing of commands.", + ) + ]) diff --git a/tests/test_sqlcompletion.py b/tests/test_sqlcompletion.py index 744fadb02..028170d58 100644 --- a/tests/test_sqlcompletion.py +++ b/tests/test_sqlcompletion.py @@ -18,9 +18,7 @@ import pytest -def cols_etc( - table, schema=None, alias=None, is_function=False, parent=None, last_keyword=None -): +def cols_etc(table, schema=None, alias=None, is_function=False, parent=None, last_keyword=None): """Returns the expected select-clause suggestions for a single-table select.""" return { @@ -46,7 +44,7 @@ def 
test_select_suggests_cols_with_qualified_table_scope(): def test_cte_does_not_crash(): sql = "WITH CTE AS (SELECT F.* FROM Foo F WHERE F.Bar > 23) SELECT C.* FROM CTE C WHERE C.FooID BETWEEN 123 AND 234;" for i in range(len(sql)): - suggestions = suggest_type(sql[: i + 1], sql[: i + 1]) + suggest_type(sql[: i + 1], sql[: i + 1]) @pytest.mark.parametrize("expression", ['SELECT * FROM "tabl" WHERE ']) @@ -117,9 +115,7 @@ def test_select_suggests_cols_and_funcs(): } -@pytest.mark.parametrize( - "expression", ["INSERT INTO ", "COPY ", "UPDATE ", "DESCRIBE "] -) +@pytest.mark.parametrize("expression", ["INSERT INTO ", "COPY ", "UPDATE ", "DESCRIBE "]) def test_suggests_tables_views_and_schemas(expression): suggestions = suggest_type(expression, expression) assert set(suggestions) == {Table(schema=None), View(schema=None), Schema()} @@ -140,7 +136,7 @@ def test_suggest_tables_views_schemas_and_functions(expression): ) def test_suggest_after_join_with_two_tables(expression): suggestions = suggest_type(expression, expression) - tables = tuple([(None, "foo", None, False), (None, "bar", None, False)]) + tables = ((None, "foo", None, False), (None, "bar", None, False)) assert set(suggestions) == { FromClauseItem(schema=None, table_refs=tables), Join(tables, None), @@ -148,9 +144,7 @@ def test_suggest_after_join_with_two_tables(expression): } -@pytest.mark.parametrize( - "expression", ["SELECT * FROM foo JOIN ", "SELECT * FROM foo JOIN bar"] -) +@pytest.mark.parametrize("expression", ["SELECT * FROM foo JOIN ", "SELECT * FROM foo JOIN bar"]) def test_suggest_after_join_with_one_table(expression): suggestions = suggest_type(expression, expression) tables = ((None, "foo", None, False),) @@ -161,9 +155,7 @@ def test_suggest_after_join_with_one_table(expression): } -@pytest.mark.parametrize( - "expression", ["INSERT INTO sch.", "COPY sch.", "DESCRIBE sch."] -) +@pytest.mark.parametrize("expression", ["INSERT INTO sch.", "COPY sch.", "DESCRIBE sch."]) def 
test_suggest_qualified_tables_and_views(expression): suggestions = suggest_type(expression, expression) assert set(suggestions) == {Table(schema="sch"), View(schema="sch")} @@ -193,7 +185,7 @@ def test_suggest_qualified_tables_views_and_functions(expression): @pytest.mark.parametrize("expression", ["SELECT * FROM foo JOIN sch."]) def test_suggest_qualified_tables_views_functions_and_joins(expression): suggestions = suggest_type(expression, expression) - tbls = tuple([(None, "foo", None, False)]) + tbls = ((None, "foo", None, False),) assert set(suggestions) == { FromClauseItem(schema="sch", table_refs=tbls), Join(tbls, "sch"), @@ -210,9 +202,7 @@ def test_truncate_suggests_qualified_tables(): assert set(suggestions) == {Table(schema="sch")} -@pytest.mark.parametrize( - "text", ["SELECT DISTINCT ", "INSERT INTO foo SELECT DISTINCT "] -) +@pytest.mark.parametrize("text", ["SELECT DISTINCT ", "INSERT INTO foo SELECT DISTINCT "]) def test_distinct_suggests_cols(text): suggestions = suggest_type(text, text) assert set(suggestions) == { @@ -233,9 +223,7 @@ def test_distinct_suggests_cols(text): ), ], ) -def test_distinct_and_order_by_suggestions_with_aliases( - text, text_before, last_keyword -): +def test_distinct_and_order_by_suggestions_with_aliases(text, text_before, last_keyword): suggestions = suggest_type(text, text_before) assert set(suggestions) == { Column( @@ -309,34 +297,24 @@ def test_into_suggests_tables_and_schemas(): assert set(suggestion) == {Table(schema=None), View(schema=None), Schema()} -@pytest.mark.parametrize( - "text", ["INSERT INTO abc (", "INSERT INTO abc () SELECT * FROM hij;"] -) +@pytest.mark.parametrize("text", ["INSERT INTO abc (", "INSERT INTO abc () SELECT * FROM hij;"]) def test_insert_into_lparen_suggests_cols(text): suggestions = suggest_type(text, "INSERT INTO abc (") - assert suggestions == ( - Column(table_refs=((None, "abc", None, False),), context="insert"), - ) + assert suggestions == (Column(table_refs=((None, "abc", None, 
False),), context="insert"),) def test_insert_into_lparen_partial_text_suggests_cols(): suggestions = suggest_type("INSERT INTO abc (i", "INSERT INTO abc (i") - assert suggestions == ( - Column(table_refs=((None, "abc", None, False),), context="insert"), - ) + assert suggestions == (Column(table_refs=((None, "abc", None, False),), context="insert"),) def test_insert_into_lparen_comma_suggests_cols(): suggestions = suggest_type("INSERT INTO abc (id,", "INSERT INTO abc (id,") - assert suggestions == ( - Column(table_refs=((None, "abc", None, False),), context="insert"), - ) + assert suggestions == (Column(table_refs=((None, "abc", None, False),), context="insert"),) def test_partially_typed_col_name_suggests_col_names(): - suggestions = suggest_type( - "SELECT * FROM tabl WHERE col_n", "SELECT * FROM tabl WHERE col_n" - ) + suggestions = suggest_type("SELECT * FROM tabl WHERE col_n", "SELECT * FROM tabl WHERE col_n") assert set(suggestions) == cols_etc("tabl", last_keyword="WHERE") @@ -389,9 +367,7 @@ def test_dot_suggests_cols_of_an_alias_where(sql): def test_dot_col_comma_suggests_cols_or_schema_qualified_table(): - suggestions = suggest_type( - "SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2", "SELECT t1.a, t2." - ) + suggestions = suggest_type("SELECT t1.a, t2. 
FROM tabl1 t1, tabl2 t2", "SELECT t1.a, t2.") assert set(suggestions) == { Column(table_refs=((None, "tabl2", "t2", False),)), Table(schema="t2"), @@ -452,14 +428,12 @@ def test_sub_select_table_name_completion(expression): ) def test_sub_select_table_name_completion_with_outer_table(expression): suggestion = suggest_type(expression, expression) - tbls = tuple([(None, "foo", None, False)]) + tbls = ((None, "foo", None, False),) assert set(suggestion) == {FromClauseItem(schema=None, table_refs=tbls), Schema()} def test_sub_select_col_name_completion(): - suggestions = suggest_type( - "SELECT * FROM (SELECT FROM abc", "SELECT * FROM (SELECT " - ) + suggestions = suggest_type("SELECT * FROM (SELECT FROM abc", "SELECT * FROM (SELECT ") assert set(suggestions) == { Column(table_refs=((None, "abc", None, False),), qualifiable=True), Function(schema=None), @@ -469,16 +443,12 @@ def test_sub_select_col_name_completion(): @pytest.mark.xfail def test_sub_select_multiple_col_name_completion(): - suggestions = suggest_type( - "SELECT * FROM (SELECT a, FROM abc", "SELECT * FROM (SELECT a, " - ) + suggestions = suggest_type("SELECT * FROM (SELECT a, FROM abc", "SELECT * FROM (SELECT a, ") assert set(suggestions) == cols_etc("abc") def test_sub_select_dot_col_name_completion(): - suggestions = suggest_type( - "SELECT * FROM (SELECT t. FROM tabl t", "SELECT * FROM (SELECT t." - ) + suggestions = suggest_type("SELECT * FROM (SELECT t. 
FROM tabl t", "SELECT * FROM (SELECT t.") assert set(suggestions) == { Column(table_refs=((None, "tabl", "t", False),)), Table(schema="t"), @@ -492,7 +462,7 @@ def test_sub_select_dot_col_name_completion(): def test_join_suggests_tables_and_schemas(tbl_alias, join_type): text = f"SELECT * FROM abc {tbl_alias} {join_type} JOIN " suggestion = suggest_type(text, text) - tbls = tuple([(None, "abc", tbl_alias or None, False)]) + tbls = ((None, "abc", tbl_alias or None, False),) assert set(suggestion) == { FromClauseItem(schema=None, table_refs=tbls), Schema(), @@ -505,7 +475,7 @@ def test_left_join_with_comma(): suggestions = suggest_type(text, text) # tbls should also include (None, 'bar', 'b', False) # but there's a bug with commas - tbls = tuple([(None, "foo", "f", False)]) + tbls = ((None, "foo", "f", False),) assert set(suggestions) == {FromClauseItem(schema=None, table_refs=tbls), Schema()} @@ -627,9 +597,7 @@ def test_on_suggests_tables_and_join_conditions_right_side(sql): ) def test_join_using_suggests_common_columns(text): tables = ((None, "abc", None, False), (None, "def", None, False)) - assert set(suggest_type(text, text)) == { - Column(table_refs=tables, require_last_table=True) - } + assert set(suggest_type(text, text)) == {Column(table_refs=tables, require_last_table=True)} def test_suggest_columns_after_multiple_joins(): @@ -643,14 +611,10 @@ def test_suggest_columns_after_multiple_joins(): def test_2_statements_2nd_current(): - suggestions = suggest_type( - "select * from a; select * from ", "select * from a; select * from " - ) + suggestions = suggest_type("select * from a; select * from ", "select * from a; select * from ") assert set(suggestions) == {FromClauseItem(schema=None), Schema()} - suggestions = suggest_type( - "select * from a; select from b", "select * from a; select " - ) + suggestions = suggest_type("select * from a; select from b", "select * from a; select ") assert set(suggestions) == { Column(table_refs=((None, "b", None, False),), 
qualifiable=True), Function(schema=None), @@ -658,9 +622,7 @@ def test_2_statements_2nd_current(): } # Should work even if first statement is invalid - suggestions = suggest_type( - "select * from; select * from ", "select * from; select * from " - ) + suggestions = suggest_type("select * from; select * from ", "select * from; select * from ") assert set(suggestions) == {FromClauseItem(schema=None), Schema()} @@ -679,9 +641,7 @@ def test_3_statements_2nd_current(): ) assert set(suggestions) == {FromClauseItem(schema=None), Schema()} - suggestions = suggest_type( - "select * from a; select from b; select * from c", "select * from a; select " - ) + suggestions = suggest_type("select * from a; select from b; select * from c", "select * from a; select ") assert set(suggestions) == cols_etc("b", last_keyword="SELECT") @@ -773,9 +733,7 @@ def test_statements_with_cursor_before_function_body(text): def test_create_db_with_template(): - suggestions = suggest_type( - "create database foo with template ", "create database foo with template " - ) + suggestions = suggest_type("create database foo with template ", "create database foo with template ") assert set(suggestions) == {Database()} @@ -814,9 +772,7 @@ def test_cast_operator_suggests_types(text): } -@pytest.mark.parametrize( - "text", ["SELECT foo::bar.", "SELECT foo::bar.baz", "SELECT (x + y)::bar."] -) +@pytest.mark.parametrize("text", ["SELECT foo::bar.", "SELECT foo::bar.baz", "SELECT (x + y)::bar."]) def test_cast_operator_suggests_schema_qualified_types(text): assert set(suggest_type(text, text)) == { Datatype(schema="bar"), @@ -844,7 +800,7 @@ def test_alter_column_type_suggests_types(): "CREATE FUNCTION foo (bar INT, baz ", "SELECT * FROM foo() AS bar (baz ", "SELECT * FROM foo() AS bar (baz INT, qux ", - # make sure this doesnt trigger special completion + # make sure this doesn't trigger special completion "CREATE TABLE foo (dt d", ], ) @@ -962,3 +918,13 @@ def test_handle_unrecognized_kw_generously(): 
@pytest.mark.parametrize("sql", ["ALTER ", "ALTER TABLE foo ALTER "]) def test_keyword_after_alter(sql): assert Keyword("ALTER") in set(suggest_type(sql, sql)) + + +def test_suggestion_when_setting_search_path(): + sql_set = "SET " + suggestion_set = suggest_type(sql_set, sql_set) + assert set(suggestion_set) == {Keyword("SET")} + + sql_set_search_path_to = "SET search_path TO " + suggestion_set_search_path_to = suggest_type(sql_set_search_path_to, sql_set_search_path_to) + assert set(suggestion_set_search_path_to) == {Schema()} diff --git a/tests/test_ssh_tunnel.py b/tests/test_ssh_tunnel.py new file mode 100644 index 000000000..c8670141b --- /dev/null +++ b/tests/test_ssh_tunnel.py @@ -0,0 +1,176 @@ +import os +from unittest.mock import patch, MagicMock, ANY + +import pytest +from configobj import ConfigObj +from click.testing import CliRunner +from sshtunnel import SSHTunnelForwarder + +from pgcli.main import cli, notify_callback, PGCli +from pgcli.pgexecute import PGExecute + + +@pytest.fixture +def mock_ssh_tunnel_forwarder() -> MagicMock: + mock_ssh_tunnel_forwarder = MagicMock(SSHTunnelForwarder, local_bind_ports=[1111], autospec=True) + with patch( + "pgcli.main.sshtunnel.SSHTunnelForwarder", + return_value=mock_ssh_tunnel_forwarder, + ) as mock: + yield mock + + +@pytest.fixture +def mock_pgexecute() -> MagicMock: + with patch.object(PGExecute, "__init__", return_value=None) as mock_pgexecute: + yield mock_pgexecute + + +def test_ssh_tunnel(mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock) -> None: + # Test with just a host + tunnel_url = "some.host" + db_params = { + "database": "dbname", + "host": "db.host", + "user": "db_user", + "passwd": "db_passwd", + } + expected_tunnel_params = { + "local_bind_address": ("127.0.0.1",), + "remote_bind_address": (db_params["host"], 5432), + "ssh_address_or_host": (tunnel_url, 22), + "logger": ANY, + } + + pgcli = PGCli(ssh_tunnel_url=tunnel_url) + pgcli.connect(**db_params) + + 
mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params) + mock_ssh_tunnel_forwarder.return_value.start.assert_called_once() + mock_pgexecute.assert_called_once() + + call_args, call_kwargs = mock_pgexecute.call_args + assert call_args == ( + db_params["database"], + db_params["user"], + db_params["passwd"], + "127.0.0.1", + pgcli.ssh_tunnel.local_bind_ports[0], + "", + notify_callback, + ) + mock_ssh_tunnel_forwarder.reset_mock() + mock_pgexecute.reset_mock() + + # Test with a full url and with a specific db port + tunnel_user = "tunnel_user" + tunnel_passwd = "tunnel_pass" + tunnel_host = "some.other.host" + tunnel_port = 1022 + tunnel_url = f"ssh://{tunnel_user}:{tunnel_passwd}@{tunnel_host}:{tunnel_port}" + db_params["port"] = 1234 + + expected_tunnel_params["remote_bind_address"] = ( + db_params["host"], + db_params["port"], + ) + expected_tunnel_params["ssh_address_or_host"] = (tunnel_host, tunnel_port) + expected_tunnel_params["ssh_username"] = tunnel_user + expected_tunnel_params["ssh_password"] = tunnel_passwd + + pgcli = PGCli(ssh_tunnel_url=tunnel_url) + pgcli.connect(**db_params) + + mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params) + mock_ssh_tunnel_forwarder.return_value.start.assert_called_once() + mock_pgexecute.assert_called_once() + + call_args, call_kwargs = mock_pgexecute.call_args + assert call_args == ( + db_params["database"], + db_params["user"], + db_params["passwd"], + "127.0.0.1", + pgcli.ssh_tunnel.local_bind_ports[0], + "", + notify_callback, + ) + mock_ssh_tunnel_forwarder.reset_mock() + mock_pgexecute.reset_mock() + + # Test with DSN + dsn = f"user={db_params['user']} password={db_params['passwd']} host={db_params['host']} port={db_params['port']}" + + pgcli = PGCli(ssh_tunnel_url=tunnel_url) + pgcli.connect(dsn=dsn) + + expected_dsn = f"user={db_params['user']} password={db_params['passwd']} host=127.0.0.1 port={pgcli.ssh_tunnel.local_bind_ports[0]}" + + 
mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params) + mock_pgexecute.assert_called_once() + + call_args, call_kwargs = mock_pgexecute.call_args + assert expected_dsn in call_args + + +def test_cli_with_tunnel() -> None: + runner = CliRunner() + tunnel_url = "mytunnel" + with patch.object(PGCli, "__init__", autospec=True, return_value=None) as mock_pgcli: + runner.invoke(cli, ["--ssh-tunnel", tunnel_url]) + mock_pgcli.assert_called_once() + call_args, call_kwargs = mock_pgcli.call_args + assert call_kwargs["ssh_tunnel_url"] == tunnel_url + + +def test_config(tmpdir: os.PathLike, mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock) -> None: + pgclirc = str(tmpdir.join("rcfile")) + + tunnel_user = "tunnel_user" + tunnel_passwd = "tunnel_pass" + tunnel_host = "tunnel.host" + tunnel_port = 1022 + tunnel_url = f"{tunnel_user}:{tunnel_passwd}@{tunnel_host}:{tunnel_port}" + + tunnel2_url = "tunnel2.host" + + config = ConfigObj() + config.filename = pgclirc + config["ssh tunnels"] = {} + config["ssh tunnels"][r"\.com$"] = tunnel_url + config["ssh tunnels"][r"^hello-"] = tunnel2_url + config.write() + + # Unmatched host + pgcli = PGCli(pgclirc_file=pgclirc) + pgcli.connect(host="unmatched.host") + mock_ssh_tunnel_forwarder.assert_not_called() + + # Host matching first tunnel + pgcli = PGCli(pgclirc_file=pgclirc) + pgcli.connect(host="matched.host.com") + mock_ssh_tunnel_forwarder.assert_called_once() + call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args + assert call_kwargs["ssh_address_or_host"] == (tunnel_host, tunnel_port) + assert call_kwargs["ssh_username"] == tunnel_user + assert call_kwargs["ssh_password"] == tunnel_passwd + mock_ssh_tunnel_forwarder.reset_mock() + + # Host matching second tunnel + pgcli = PGCli(pgclirc_file=pgclirc) + pgcli.connect(host="hello-i-am-matched") + mock_ssh_tunnel_forwarder.assert_called_once() + + call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args + assert 
call_kwargs["ssh_address_or_host"] == (tunnel2_url, 22) + mock_ssh_tunnel_forwarder.reset_mock() + + # Host matching both tunnels (will use the first one matched) + pgcli = PGCli(pgclirc_file=pgclirc) + pgcli.connect(host="hello-i-am-matched.com") + mock_ssh_tunnel_forwarder.assert_called_once() + + call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args + assert call_kwargs["ssh_address_or_host"] == (tunnel_host, tunnel_port) + assert call_kwargs["ssh_username"] == tunnel_user + assert call_kwargs["ssh_password"] == tunnel_passwd diff --git a/tests/utils.py b/tests/utils.py index 460ea4694..e6dad62a7 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,8 +1,6 @@ import pytest -import psycopg2 -import psycopg2.extras +import psycopg from pgcli.main import format_output, OutputSettings -from pgcli.pgexecute import register_json_typecasters from os import getenv POSTGRES_USER = getenv("PGUSER", "postgres") @@ -12,12 +10,12 @@ def db_connection(dbname=None): - conn = psycopg2.connect( + conn = psycopg.connect( user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD, port=POSTGRES_PORT, - database=dbname, + dbname=dbname, ) conn.autocommit = True return conn @@ -26,11 +24,10 @@ def db_connection(dbname=None): try: conn = db_connection() CAN_CONNECT_TO_DB = True - SERVER_VERSION = conn.server_version - json_types = register_json_typecasters(conn, lambda x: x) - JSON_AVAILABLE = "json" in json_types - JSONB_AVAILABLE = "jsonb" in json_types -except: + SERVER_VERSION = conn.info.parameter_status("server_version") + JSON_AVAILABLE = True + JSONB_AVAILABLE = True +except Exception: CAN_CONNECT_TO_DB = JSON_AVAILABLE = JSONB_AVAILABLE = False SERVER_VERSION = 0 @@ -41,21 +38,17 @@ def db_connection(dbname=None): ) -requires_json = pytest.mark.skipif( - not JSON_AVAILABLE, reason="Postgres server unavailable or json type not defined" -) +requires_json = pytest.mark.skipif(not JSON_AVAILABLE, reason="Postgres server unavailable or json type not defined") 
-requires_jsonb = pytest.mark.skipif( - not JSONB_AVAILABLE, reason="Postgres server unavailable or jsonb type not defined" -) +requires_jsonb = pytest.mark.skipif(not JSONB_AVAILABLE, reason="Postgres server unavailable or jsonb type not defined") def create_db(dbname): with db_connection().cursor() as cur: try: cur.execute("""CREATE DATABASE _test_db""") - except: + except Exception: pass @@ -70,16 +63,12 @@ def drop_tables(conn): ) -def run( - executor, sql, join=False, expanded=False, pgspecial=None, exception_formatter=None -): +def run(executor, sql, join=False, expanded=False, pgspecial=None, exception_formatter=None): "Return string output for the sql to be run" results = executor.run(sql, pgspecial, exception_formatter) formatted = [] - settings = OutputSettings( - table_format="psql", dcmlfmt="d", floatfmt="g", expanded=expanded - ) + settings = OutputSettings(table_format="psql", dcmlfmt="d", floatfmt="g", expanded=expanded) for title, rows, headers, status, sql, success, is_special in results: formatted.extend(format_output(title, rows, headers, status, settings)) if join: @@ -89,7 +78,4 @@ def run( def completions_to_set(completions): - return { - (completion.display_text, completion.display_meta_text) - for completion in completions - } + return {(completion.display_text, completion.display_meta_text) for completion in completions} diff --git a/tox.ini b/tox.ini index c2d4239b5..1b8ea025a 100644 --- a/tox.ini +++ b/tox.ini @@ -1,13 +1,31 @@ [tox] -envlist = py36, py37, py38, py39 +envlist = py + [testenv] -deps = pytest>=2.7.0,<=3.0.7 - mock>=1.0.1 - behave>=1.2.4 - pexpect==3.3 -commands = py.test - behave tests/features +skip_install = true +deps = uv +commands = uv pip install -e .[dev] + coverage run -m pytest -v tests + coverage report -m passenv = PGHOST PGPORT PGUSER PGPASSWORD + +[testenv:style] +skip_install = true +deps = ruff +commands = ruff check + ruff format --diff + +[testenv:integration] +skip_install = true +deps = uv +commands = uv 
pip install -e .[dev] + behave tests/features --no-capture + +[testenv:rest] +skip_install = true +deps = uv +commands = uv pip install -e .[dev] + docutils --halt=warning changelog.rst