diff --git a/.coveragerc b/.coveragerc deleted file mode 100644 index aea4994..0000000 --- a/.coveragerc +++ /dev/null @@ -1,3 +0,0 @@ -[run] -parallel = True -source = litecli diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 0000000..c433202 --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,2 @@ +# Black +f767afc80bd5bcc8f1b1cc1a134babc2dec4d239 \ No newline at end of file diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index e995453..3e14cc7 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -5,4 +5,4 @@ ## Checklist -- [ ] I've added this contribution to the `changelog.md` file. +- [ ] I've added this contribution to the `CHANGELOG.md` file. diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..9c4973d --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,62 @@ +name: litecli + +on: + pull_request: + paths-ignore: + - '**.md' + - 'AUTHORS' + +jobs: + tests: + name: Tests + runs-on: ubuntu-latest + + strategy: + fail-fast: false + matrix: + python-version: ["3.10", "3.11", "3.12", "3.13"] + + steps: + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + + - uses: astral-sh/setup-uv@d9e0f98d3fc6adb07d1e3d37f3043649ddad06a1 # v6.5.0 + with: + version: "latest" + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: uv sync --all-extras -p ${{ matrix.python-version }} + + - name: Run unit tests + run: uv run tox -e py${{ matrix.python-version }} + + tests-no-extras: + name: Tests Without Extras + runs-on: ubuntu-latest + + strategy: + fail-fast: false + matrix: + python-version: ["3.13"] # Just the latest version + + steps: + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + + - uses: astral-sh/setup-uv@d9e0f98d3fc6adb07d1e3d37f3043649ddad06a1 # v6.5.0 + with: + version: "latest" + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: uv sync --extra dev -p ${{ matrix.python-version }} + + - name: Run unit tests + run: uv run tox -e py${{ matrix.python-version }} diff --git a/.github/workflows/codex-review.yml b/.github/workflows/codex-review.yml new file mode 100644 index 0000000..525c810 --- /dev/null +++ b/.github/workflows/codex-review.yml @@ -0,0 +1,72 @@ +name: Codex Review + +on: + pull_request_target: + types: [opened, reopened, synchronize, ready_for_review] + +jobs: + codex-review: + if: github.event.pull_request.draft == false + runs-on: ubuntu-latest + permissions: + contents: read + outputs: + final_message: ${{ steps.run_codex.outputs.final-message }} + + steps: + - name: Check out PR merge commit + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + with: + ref: refs/pull/${{ github.event.pull_request.number }}/merge + + - name: Fetch base and head refs + run: | + git fetch --no-tags origin \ + ${{ github.event.pull_request.base.ref }} \ + +refs/pull/${{ github.event.pull_request.number }}/head + + - name: Run Codex review + id: run_codex + uses: openai/codex-action@v1 + with: + openai-api-key: ${{ secrets.OPENAI_API_KEY }} + prompt: | + You are reviewing PR #${{ github.event.pull_request.number }} for ${{ github.repository }}. + + Only review changes introduced by this PR: + git log --oneline ${{ github.event.pull_request.base.sha }}...${{ github.event.pull_request.head.sha }} + + Focus on: + - correctness bugs and regressions + - security concerns + - missing tests or edge cases + + Keep feedback concise and actionable. + + Pull request title and body: + ---- + ${{ github.event.pull_request.title }} + ${{ github.event.pull_request.body }} + + post-feedback: + runs-on: ubuntu-latest + needs: codex-review + if: needs.codex-review.outputs.final_message != '' + permissions: + issues: write + pull-requests: write + + steps: + - name: Post Codex review as PR comment + uses: actions/github-script@v7 + env: + CODEX_FINAL_MESSAGE: ${{ needs.codex-review.outputs.final_message }} + with: + github-token: ${{ github.token }} + script: | + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + body: process.env.CODEX_FINAL_MESSAGE, + }); diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..e0b5d8a --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,29 @@ +name: Lint + +on: + pull_request: + paths-ignore: + - '**.md' + - 'AUTHORS' + +jobs: + linters: + name: Linters + runs-on: ubuntu-latest + + steps: + - name: Check out Git repository + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + + # remember to sync the ruff-check version number with pyproject.toml + - name: Run ruff check + uses: astral-sh/ruff-action@57714a7c8a2e59f32539362ba31877a1957dded1 # v3.5.1 + with: + version: 0.11.5 + + # remember to sync the ruff-check version number with pyproject.toml + - name: Run ruff format + uses: astral-sh/ruff-action@57714a7c8a2e59f32539362ba31877a1957dded1 # v3.5.1 + with: + version: 0.11.5 + args: 'format --check' diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000..fbdf8c1 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,84 @@ +name: Publish Python Package + +on: + release: + types: [created] + +permissions: + contents: read + +jobs: + test: + runs-on: ubuntu-latest + + strategy: + matrix: + python-version: ["3.10", "3.11", "3.12", "3.13"] + + steps: + - uses: actions/checkout@v4 + - uses: astral-sh/setup-uv@v1 + with: + version: "latest" + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: uv sync --all-extras -p ${{ matrix.python-version }} + + - name: Run unit tests + run: uv run tox -e py${{ matrix.python-version }} + + - name: Run Style Checks + run: uv run tox -e style + + build: + runs-on: ubuntu-latest + needs: [test] + + strategy: + matrix: + python-version: ["3.13"] + + steps: + - uses: actions/checkout@v4 + - uses: astral-sh/setup-uv@v1 + with: + version: "latest" + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: uv sync --all-extras -p ${{ matrix.python-version }} + + - name: Build + run: uv build + + - name: Store the distribution packages + uses: actions/upload-artifact@v4 + with: + name: python-packages + path: dist/ + + publish: + name: Publish to PyPI + runs-on: ubuntu-latest + if: startsWith(github.ref, 'refs/tags/') + needs: [build] + environment: release + permissions: + id-token: write + steps: + - name: Download distribution packages + uses: actions/download-artifact@v4 + with: + name: python-packages + path: dist/ + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.github/workflows/typecheck.yml b/.github/workflows/typecheck.yml new file mode 100644 index 0000000..9d723be --- /dev/null +++ b/.github/workflows/typecheck.yml @@ -0,0 +1,37 @@ +name: Typecheck + +on: + pull_request: + paths-ignore: + - '**/*.md' + - 'AUTHORS' + +jobs: + typecheck: + name: Typecheck + runs-on: ubuntu-latest + + strategy: + matrix: + python-version: ["3.13"] + + steps: + - name: Check out Git repository + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + + - name: Set up Python + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 + with: + python-version: ${{ matrix.python-version }} + + - uses: astral-sh/setup-uv@d9e0f98d3fc6adb07d1e3d37f3043649ddad06a1 # v6.5.0 + with: + version: 'latest' + + - name: Install dependencies + run: uv sync --all-extras + + - name: Run ty + run: | + cd litecli + uv run ty check -v diff --git a/.gitignore b/.gitignore index ab7d02e..63c3eb6 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ /src /test/behave.ini /litecli_env +/.venv /.eggs .vagrant diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b268b62..159c2b6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,10 @@ repos: -- repo: https://github.com/ambv/black - rev: stable - hooks: - - id: black - language_version: python3.7 +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.6.4 + hooks: + # Run the linter. + - id: ruff + args: [ --fix ] + # Run the formatter. + - id: ruff-format diff --git a/.tox/.pkg/file.lock b/.tox/.pkg/file.lock new file mode 100644 index 0000000..e69de29 diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 01a5df0..0000000 --- a/.travis.yml +++ /dev/null @@ -1,32 +0,0 @@ -language: python -python: - - "2.7" - - "3.4" - - "3.5" - - "3.6" - -matrix: - include: - - python: 3.7 - dist: xenial - sudo: true - -install: - - pip install -r requirements-dev.txt - - if [[ $TRAVIS_PYTHON_VERSION == '3.6' ]]; then pip install black; fi - - pip install -e . - -script: - - ./setup.py test --pytest-args="--cov-report= --cov=litecli" - - coverage report - - if [[ $TRAVIS_PYTHON_VERSION == '3.6' ]]; then ./setup.py lint; fi - -after_success: - - codecov - -notifications: - webhooks: - urls: - - YOUR_WEBHOOK_URL - on_success: change # options: [always|never|change] default: always - on_failure: always # options: [always|never|change] default: always diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..8009d19 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,56 @@ +# Repository Guidelines + +## Project Structure & Module Organization +- `litecli/`: Core package. Entry point `main.py`, SQL execution in `sqlexecute.py`, completion in `packages/completion_engine.py`, special commands under `packages/special/`. +- `tests/`: Pytest suite (files like `test_*.py`). Test data under `tests/data/`. +- `screenshots/`: Images used in README. +- Config template: `litecli/liteclirc` (user config is created under `~/.config/litecli/config` or `%LOCALAPPDATA%/dbcli/litecli/config`). + +## Build, Test, and Development Commands +- Create env: `python -m venv .venv && source .venv/bin/activate`. +- Install dev deps: `pip install -e .[dev]`. +- Run all tests + coverage: `tox`. +- Extra tests with SQLean: `tox -e sqlean` (installs `[sqlean]` extras). +- Run tests directly: `pytest -q` or focused: `pytest -k keyword`. +- Launch CLI locally: `litecli path/to.db`. + +### Ruff (lint/format) +- Full style pass: `tox -e style` (runs `ruff check --fix` and `ruff format`). +- Direct commands: + - Lint: `ruff check` (add `--fix` to auto-fix) + - Format: `ruff format` + +## ty (type checking) +- Repo-wide `ty check -v` +- Per-package: `ty check litecli -v` +- Notes: + - Config is in `pyproject.toml` (target Python 3.9, stricter settings). + +## Coding Style & Naming Conventions +- Formatter/linter: Ruff (configured via `.pre-commit-config.yaml` and `tox`). +- Indentation: 4 spaces. Line length: 140 (see `pyproject.toml`). +- Naming: modules/functions/variables `snake_case`; classes `CamelCase`; tests `test_*.py`. +- Keep imports sorted and unused code removed (ruff enforces). +- Use lowercase type hints for dict, list, tuples etc. +- Use | for Unions and | None for Optional. + +## Testing Guidelines +- Framework: Pytest with coverage (`coverage run -m pytest` via tox). +- Location: place new tests in `tests/` alongside related module area. +- Conventions: name files `test_.py`; use fixtures from `tests/conftest.py`. +- Quick check: `pytest -q`; coverage report via `tox` or `coverage report -m`. + +## Commit & Pull Request Guidelines +- Commits: imperative mood, concise scope (e.g., `fix: handle NULL types`). Reference issues (`#123`) when relevant. +- Update `CHANGELOG.md` for user-visible changes. +- PRs: include clear description, steps to reproduce/verify, and screenshots or snippets for CLI output when helpful. Use the PR template. +- Ensure CI passes (tests + ruff). Re-run `tox -e style` before requesting review. + +## Changelog Discipline +- Always add an "Unreleased" section at the top of `CHANGELOG.md` when making changes. +- Keep entries succinct; avoid overly detailed technical notes. +- Group under "Features", "Bug Fixes", and "Internal" when applicable. + +## Security & Configuration Tips +- Do not commit local databases or secrets. Use files under `tests/data/` for fixtures. +- User settings live outside the repo; document defaults by editing `litecli/liteclirc`. diff --git "a/C:\\Users\\litecli\\litecli_test.db" "b/C:\\Users\\litecli\\litecli_test.db" new file mode 100644 index 0000000..e69de29 diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..e429af4 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,321 @@ +## Unreleased + +### Bug Fixes + +- Expand `~` in configured log file paths before opening the log. + +### Internal + +- Add a GitHub Actions workflow to run Codex review on pull requests. +- Drop Python 3.9 from test matrices and tooling targets. + +## 1.19.0 - 2026-01-30 + +### Features + +- Make LLM support optional and installable via `litecli[ai]`. + +### Bug Fixes + +- Avoid completion refresh crashes when no database is connected. + +### Internal + +- Clean up ty type-checking for optional sqlean/llm imports. + +## 1.18.0 + +### Internal + +- Switch mypy to ty for type checking. [(#242)](https://github.com/dbcli/litecli/pull/242/files) +- Add sqlean-stubs for type checking. [(#243)(https://github.com/dbcli/litecli/pull/243/files)] + +## 1.17.0 - 2025-09-28 + +### Features + +* Add support for opening 'file:' URIs with parameters. [(#234)](https://github.com/dbcli/litecli/pull/234) + +### Bug Fixes + +* Avoid Click 8.1.* to prevent messing up the pager when the PAGER env var has a string with spaces. + +### Internal + +- Add type checking using mypy. + +## 1.16.0 - 2025-08-16 + +### Features + +* Use [sqlean](https://antonz.org/sqlean/) when available. It's a drop-in replacement for sqlite3. +* Add support for `.output` to write the results to a file. +* The 'llm' library is now a default dependency not installed on demand. +* The `\llm` command now has three modes. Succinct, Regular and Verbose. + + Succinct = `\llm-` - This will return just the sql query. No explanation. + Regular = `\llm` - This will return just the sql query and the explanation. + Verbose = `\llm+` - This will print the prompt sent to the LLM and the sql query and the explanation. + +### Bug Fixes + +* Fix missing sqlite extensions using sqlean. Note. support only limited set of extensions. [(#119)](https://github.com/dbcli/litecli/issues/119) + + +## 1.15.0 - 2025-03-15 + +### Features +* Add logs while invoking `\llm`and `\\m+` command. [(#215)](https://github.com/dbcli/litecli/pull/215) +* Support `--help` in the `\llm`and `\llm+` command. ([#214](https://github.com/dbcli/litecli/pull/214)) +* Make the history file location configurable. ([#206](https://github.com/dbcli/litecli/issues/206)) +* Add dot command to list views. + +### Bug Fixes + +* Fix a bug where the `\llm` command on alternate invocations weren't detected correctly. (#211) +* Do not escape upper table or column name. [(#185)](https://github.com/dbcli/litecli/issues/185) +* Return indices when `.schema` command is run. Also update the output to contain the `sql` for the `indexes` command. [(#149)](https://github.com/dbcli/litecli/issues/149) + +### Internal + +* Fix typo `pormpt`to `prompt` in `special/llm.py`. +* Update pip install to work in both bash and zsh. + + +## 1.14.4 - 2025-01-31 + +### Bug Fixes + +* Fix the usage instructions in the `\llm` command. + +## 1.14.3 - 2025-01-29 + +### Bug Fixes + +* Fix [misleading "0 rows affected" status for CTEs](https://github.com/dbcli/litecli/issues/203) + by never displaying rows affected when the connector tells us -1 +* Show an error message when `\llm "question"` is invoked without a database connection. + +## 1.14.2 - 2025-01-26 + +### Bug Fixes + +* Catch errors surfaced by `llm` cli and surface them as runtime errors. + +## 1.14.1 - 2025-01-25 + +### Bug Fixes + +* Capture stderr in addition to stdout when capturing output from `llm` cli. + +## 1.14.0 - 2025-01-22 + +### Features + +* Add LLM feature to ask an LLM to create a SQL query. + - This adds a new `\llm` special command + - eg: `\llm "Who is the largest customer based on revenue?"` + +### Bug Fixes + +* Fix the [windows path](https://github.com/dbcli/litecli/issues/187) shown in prompt to remove escaping. +* Fix a bug where if column name was same as table name it was [crashing](https://github.com/dbcli/litecli/issues/155) the autocompletion. + +### Internal + +* Change min required python version to 3.9+ + +## 1.13.2 - 2024-11-24 + +### Internal + +* Read the version from the git tag using setuptools-scm + +## 1.13.0 - 2024-11-23 + +### Features + +* Add `\pipe_once` / `\|` commands for sending output to a command + +## 1.12.4 - 2024-11-11 + +### Bug Fixes + +* Fix the syntax error when `\d tbl` is used. + +## 1.12.3 - 2024-09-10 + +### Bug Fixes + +* Specify build system in `pyproject.toml` +* Don't install tests + +## 1.12.2 - 2024-09-07 + +### Bug Fixes + +* Fix the missing packages due to invalid pyproject.toml config + +## 1.12.1 - 2024-09-07 (Yanked) + +### Internal Changes + +* Modernize the project with following changes: + * pyproject.toml instead of setup.py + * Use ruff for linting and formatting + * Update GH actions to use uv and tox + * Use GH actions to release a new version + +## 1.11.1 - 2024-07-04 + +### Bug Fixes + +* Fix the escape sequence warning. + +## 1.11.0 - 2024-05-03 + +### Improvements + +* When an empty `\d` is invoked the list of tables are returned instead of an error. +* Show SQLite version at startup. + +### Bug Fixes + +* Support a single item in the startup commands in the config. (bug #176) + +## 1.10.1 - 2024-3-23 + +### Bug Fixes + +* Do not crash at start up if ~/.config/litecli is not writeable. [#172](https://github.com/dbcli/litecli/issues/172) + +## 1.10.0 - 2022-11-19 + +### Features + +* Adding support for startup commands being set in liteclirc and executed on startup. Limited to commands already implemented in litecli. ([[#56](https://github.com/dbcli/litecli/issues/56)]) + +### Bug Fixes + +* Fix [[#146](https://github.com/dbcli/litecli/issues/146)], making sure `.once` + can be used more than once in a session. +* Fixed setting `successful = True` only when query is executed without exceptions so + failing queries get `successful = False` in `query_history`. +* Changed `master` to `main` in CONTRIBUTING.md to reflect GitHubs new default branch + naming. +* Fixed `.once -o ` by opening the output file once per statement instead + of for every line of output ([#148](https://github.com/dbcli/litecli/issues/148)). +* Use the sqlite3 API to cancel a running query on interrupt + ([#164](https://github.com/dbcli/litecli/issues/164)). +* Skip internal indexes in the .schema output + ([#170](https://github.com/dbcli/litecli/issues/170)). + +## 1.9.0 - 2022-06-06 + +### Features + +* Add support for ANSI escape sequences for coloring the prompt. +* Add support for `.indexes` command. +* Add an option to turn off the auto-completion menu. Completion menu can be + triggered by pressed the `` key when this option is set to False. Fixes + [#105](https://github.com/dbcli/litecli/issues/105). + +### Bug Fixes + +* Fix [#120](https://github.com/dbcli/litecli/issues/120). Make the `.read` command actually read and execute the commands from a file. +* Fix [#96](https://github.com/dbcli/litecli/issues/96) the crash in VI mode when pressing `r`. + +## 1.8.0 - 2022-03-29 + +### Features + +* Update compatible Python versions. (Thanks: [blazewicz]) +* Add support for Python 3.10. (Thanks: [blazewicz]) +* Drop support for Python 3.6. (Thanks: [blazewicz]) + +### Bug Fixes + +* Upgrade cli_helpers to workaround Pygments regression. +* Use get_terminal_size from shutil instead of click. + +## 1.7.0 - 2022-01-11 + +### Features + +* Add config option show_bottom_toolbar. + +### Bug Fixes + +* Pin pygments version to prevent breaking change. + +## 1.6.0 - 2021-03-15 + +### Features + +* Add verbose feature to `favorite_query` command. (Thanks: [Zhaolong Zhu]) + * `\f query` does not show the full SQL. + * `\f+ query` shows the full SQL. +* Add prompt format of file's basename. (Thanks: [elig0n]) + +### Bug Fixes + +* Fix compatibility with sqlparse >= 0.4.0. (Thanks: [chocolateboy]) +* Fix invalid utf-8 exception. (Thanks: [Amjith]) + +## 1.4.1 - 2020-07-27 + +### Bug Fixes + +* Fix setup.py to set `long_description_content_type` as markdown. + +## 1.4.0 - 2020-07-27 + +### Features + +* Add NULLS FIRST and NULLS LAST to keywords. (Thanks: [Amjith]) + +## 1.3.2 - 2020-03-11 + +* Fix the completion engine to work with newer sqlparse. + +## 1.3.1 - 2020-03-11 + +* Remove the version pinning of sqlparse package. + +## 1.3.0 - 2020-02-11 + +### Features + +* Added `.import` command for importing data from file into table. (Thanks: [Zhaolong Zhu]) +* Upgraded to prompt-toolkit 3.x. + +## 1.2.0 - 2019-10-26 + +### Features + +* Enhance the `describe` command. (Thanks: [Amjith]) +* Autocomplete table names for special commands. (Thanks: [Amjith]) + +## 1.1.0 - 2019-07-14 + +### Features + +* Added `.read` command for reading scripts. +* Added `.load` command for loading extension libraries. (Thanks: [Zhiming Wang]) +* Add support for using `?` as a placeholder in the favorite queries. (Thanks: [Amjith]) +* Added shift-tab to select the previous entry in the completion menu. [Amjith] +* Added `describe` and `desc` keywords. + +### Bug Fixes + +* Clear error message when directory does not exist. (Thanks: [Irina Truong]) + +## 1.0.0 - 2019-01-04 + +* To new beginnings. :tada: + +[Amjith]: https://blog.amjith.com +[chocolateboy]: https://github.com/chocolateboy +[Irina Truong]: https://github.com/j-bennet +[Zhaolong Zhu]: https://github.com/zzl0 +[Zhiming Wang]: https://github.com/zmwangx diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ad22cdf..a4d5f2d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,6 +1,6 @@ # Development Guide -This is a guide for developers who would like to contribute to this project. It is recommended to use Python 3.6 and above for development. +This is a guide for developers who would like to contribute to this project. It is recommended to use Python 3.10 and above for development. If you're interested in contributing to litecli, thank you. We'd love your help! You'll always get credit for your work. @@ -24,8 +24,7 @@ You'll always get credit for your work. ```bash $ cd litecli - $ pip install virtualenv - $ virtualenv litecli_dev + $ python -m venv .venv ``` We've just created a virtual environment that we'll use to install all the dependencies @@ -33,7 +32,7 @@ You'll always get credit for your work. need to activate the virtual environment: ```bash - $ source litecli_dev/bin/activate + $ source .venv/bin/activate ``` When you're done working, you can deactivate the virtual environment: @@ -45,11 +44,10 @@ You'll always get credit for your work. 5. Install the dependencies and development tools: ```bash - $ pip install -r requirements-dev.txt - $ pip install --editable . + $ pip install --editable ".[dev]" ``` -6. Create a branch for your bugfix or feature based off the `master` branch: +6. Create a branch for your bugfix or feature based off the `main` branch: ```bash $ git checkout -b @@ -58,7 +56,7 @@ You'll always get credit for your work. 7. While you work on your bugfix or feature, be sure to pull the latest changes from `upstream`. This ensures that your local codebase is up-to-date: ```bash - $ git pull upstream master + $ git pull upstream main ``` 8. When your work is ready for the litecli team to review it, push your branch to your fork: @@ -75,18 +73,10 @@ You'll always get credit for your work. While you work on litecli, it's important to run the tests to make sure your code hasn't broken any existing functionality. To run the tests, just type in: -```bash -$ ./setup.py test -``` - -litecli supports Python 2.7 and 3.4+. You can test against multiple versions of -Python by running tox: - ```bash $ tox ``` - ### CLI Tests Some CLI tests expect the program `ex` to be a symbolic link to `vim`. @@ -102,18 +92,12 @@ $ readlink -f $(which ex) ## Coding Style -litecli uses [black](https://github.com/ambv/black) to format the source code. Make sure to install black. - -It's easy to check the style of your code, just run: - -```bash -$ ./setup.py lint -``` +Litecli uses [ruff](https://docs.astral.sh/ruff/) to format the source code. -If you see any style issues, you can automatically fix them by running: +To check the style and fix any violations, run: ```bash -$ ./setup.py lint --fix +$ tox -e style ``` Be sure to commit and push any stylistic fixes. diff --git a/MANIFEST.in b/MANIFEST.in index c7e08e7..f1ff0f6 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,6 +1,8 @@ include *.txt *.py -include LICENSE changelog.md +include LICENSE CHANGELOG.md include tox.ini recursive-include tests *.py recursive-include tests *.txt +recursive-include tests *.csv +recursive-include tests liteclirc recursive-include litecli AUTHORS diff --git a/README.md b/README.md index 468615f..d4142eb 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,13 @@ # litecli -[![Build Status](https://travis-ci.org/dbcli/litecli.svg?branch=master)](https://travis-ci.org/dbcli/litecli) +[![GitHub Actions](https://github.com/dbcli/litecli/actions/workflows/ci.yml/badge.svg)](https://github.com/dbcli/litecli/actions/workflows/ci.yml "GitHub Actions") [Docs](https://litecli.com) A command-line client for SQLite databases that has auto-completion and syntax highlighting. -![Completion](screenshots/litecli.png) -![CompletionGif](screenshots/litecli.gif) +![Completion](https://raw.githubusercontent.com/dbcli/litecli/refs/heads/main/screenshots/litecli.png) +![CompletionGif](https://raw.githubusercontent.com/dbcli/litecli/refs/heads/main/screenshots/litecli.gif) ## Installation @@ -16,37 +16,27 @@ If you already know how to install python packages, then you can install it via You might need sudo on linux. ``` -$ pip install -U litecli -``` - -The package is also available on Arch Linux through AUR in two versions: [litecli](https://aur.archlinux.org/packages/litecli/) is based the latest release (git tag) and [litecli-git](https://aur.archlinux.org/packages/litecli-git/) is based on the master branch of the git repo. You can install them manually or with an AUR helper such as `yay`: - -``` -$ yay -S litecli -``` -or - -``` -$ yay -S litecli-git +$ pip install -U litecli[sqlean] ``` For MacOS users, you can also use Homebrew to install it: ``` -$ brew tap dbcli/tap $ brew install litecli ``` ## Usage - $ litecli --help - - Usage: litecli [OPTIONS] [DATABASE] +``` +$ litecli --help - Examples: - - litecli sqlite_db_name +Usage: litecli [OPTIONS] [DATABASE] + +Examples: + - litecli sqlite_db_name +``` -A config file is automatically created at `~/.config/litecli/config` at first launch. See the file itself for a description of all available options. +A config file is automatically created at `~/.config/litecli/config` at first launch. For Windows machines a config file is created at `~\AppData\Local\dbcli\litecli\config` at first launch. See the file itself for a description of all available options. ## Docs diff --git a/TODO b/TODO deleted file mode 100644 index 7c854dc..0000000 --- a/TODO +++ /dev/null @@ -1,3 +0,0 @@ -* [] Sort by frecency. -* [] Add completions when an attach database command is run. -* [] Add behave tests. diff --git a/TODO.md b/TODO.md new file mode 100644 index 0000000..58e2ebc --- /dev/null +++ b/TODO.md @@ -0,0 +1,6 @@ +* [ ] Change to use ruff +* [ ] Automate the release process via GH actions. [Article](https://simonwillison.net/2024/Jan/16/python-lib-pypi/) + +* [] Sort by frecency. +* [] Add completions when an attach database command is run. +* [] Add behave tests. diff --git a/changelog.md b/changelog.md deleted file mode 100644 index c3f54fe..0000000 --- a/changelog.md +++ /dev/null @@ -1,41 +0,0 @@ -TBD -=== - -Bug Fixes: ----------- - -* - -Features: ---------- - -* - - -1.1.0 -===== - -Bug Fixes: ----------- - -* Clear error message when directory does not exist. (Thanks: [Irina Truong]) - -Features: ---------- - -* Added `.read` command for reading scripts. -* Added `.load` command for loading extension libraries. (Thanks: [Zhiming Wang]) -* Add support for using `?` as a placeholder in the favorite queries. (Thanks: [Amjith]) -* Added shift-tab to select the previous entry in the completion menu. [Amjith] -* Added `describe` and `desc` keywords. - -1.0.0 -===== - -* To new beginnings. :tada: - - - -[Amjith]: https://blog.amjith.com -[Zhiming Wang]: https://github.com/zmwangx -[Irina Truong]: https://github.com/j-bennet diff --git a/litecli/AUTHORS b/litecli/AUTHORS index d5265de..194cdc7 100644 --- a/litecli/AUTHORS +++ b/litecli/AUTHORS @@ -18,3 +18,4 @@ Contributors: * Zhaolong Zhu * Zhiming Wang * Shawn M. Chapla + * Paweł Sacawa diff --git a/litecli/__init__.py b/litecli/__init__.py index 6849410..3c05333 100644 --- a/litecli/__init__.py +++ b/litecli/__init__.py @@ -1 +1,5 @@ -__version__ = "1.1.0" +from __future__ import annotations + +import importlib.metadata + +__version__ = importlib.metadata.version("litecli") diff --git a/litecli/clibuffer.py b/litecli/clibuffer.py index a57192a..cd67aa8 100644 --- a/litecli/clibuffer.py +++ b/litecli/clibuffer.py @@ -1,14 +1,18 @@ -from __future__ import unicode_literals +from __future__ import annotations + +from typing import Any from prompt_toolkit.enums import DEFAULT_BUFFER -from prompt_toolkit.filters import Condition +from prompt_toolkit.filters import Condition, Filter from prompt_toolkit.application import get_app -def cli_is_multiline(cli): +def cli_is_multiline(cli: Any) -> Filter: @Condition - def cond(): - doc = get_app().layout.get_buffer_by_name(DEFAULT_BUFFER).document + def cond() -> bool: + buf = get_app().layout.get_buffer_by_name(DEFAULT_BUFFER) + assert buf is not None + doc = buf.document if not cli.multi_line: return False @@ -18,7 +22,7 @@ def cond(): return cond -def _multiline_exception(text): +def _multiline_exception(text: str) -> bool: orig = text text = text.strip() diff --git a/litecli/clistyle.py b/litecli/clistyle.py index 7527315..b364872 100644 --- a/litecli/clistyle.py +++ b/litecli/clistyle.py @@ -1,18 +1,20 @@ -from __future__ import unicode_literals +from __future__ import annotations import logging +from typing import cast import pygments.styles -from pygments.token import string_to_tokentype, Token +from prompt_toolkit.styles import Style, merge_styles +from prompt_toolkit.styles.pygments import style_from_pygments_cls +from prompt_toolkit.styles.style import _MergedStyle from pygments.style import Style as PygmentsStyle +from pygments.token import Token, _TokenType, string_to_tokentype from pygments.util import ClassNotFound -from prompt_toolkit.styles.pygments import style_from_pygments_cls -from prompt_toolkit.styles import merge_styles, Style logger = logging.getLogger(__name__) # map Pygments tokens (ptk 1.0) to class names (ptk 2.0). -TOKEN_TO_PROMPT_STYLE = { +TOKEN_TO_PROMPT_STYLE: dict[_TokenType, str] = { Token.Menu.Completions.Completion.Current: "completion-menu.completion.current", Token.Menu.Completions.Completion: "completion-menu.completion", Token.Menu.Completions.Meta.Current: "completion-menu.meta.completion.current", @@ -41,10 +43,10 @@ } # reverse dict for cli_helpers, because they still expect Pygments tokens. -PROMPT_STYLE_TO_TOKEN = {v: k for k, v in TOKEN_TO_PROMPT_STYLE.items()} +PROMPT_STYLE_TO_TOKEN: dict[str, _TokenType] = {v: k for k, v in TOKEN_TO_PROMPT_STYLE.items()} -def parse_pygments_style(token_name, style_object, style_dict): +def parse_pygments_style(token_name: str, style_object: PygmentsStyle | dict, style_dict: dict[str, str]) -> tuple[_TokenType, str]: """Parse token type and style string. :param token_name: str name of Pygments token. Example: "Token.String" @@ -53,20 +55,20 @@ def parse_pygments_style(token_name, style_object, style_dict): """ token_type = string_to_tokentype(token_name) - try: + if isinstance(style_object, PygmentsStyle): other_token_type = string_to_tokentype(style_dict[token_name]) return token_type, style_object.styles[other_token_type] - except AttributeError as err: + else: return token_type, style_dict[token_name] -def style_factory(name, cli_style): +def style_factory(name: str, cli_style: dict[str, str]) -> _MergedStyle: try: style = pygments.styles.get_style_by_name(name) except ClassNotFound: style = pygments.styles.get_style_by_name("native") - prompt_styles = [] + prompt_styles: list[tuple[str, str]] = [] # prompt-toolkit used pygments tokens for styling before, switched to style # names in 2.0. Convert old token types to new style names, for backwards compatibility. for token in cli_style: @@ -84,13 +86,11 @@ def style_factory(name, cli_style): # https://github.com/jonathanslenders/python-prompt-toolkit/blob/master/prompt_toolkit/styles/defaults.py prompt_styles.append((token, cli_style[token])) - override_style = Style([("bottom-toolbar", "noreverse")]) - return merge_styles( - [style_from_pygments_cls(style), override_style, Style(prompt_styles)] - ) + override_style: Style = Style([("bottom-toolbar", "noreverse")]) + return merge_styles([style_from_pygments_cls(style), override_style, Style(prompt_styles)]) -def style_factory_output(name, cli_style): +def style_factory_output(name: str, cli_style: dict[str, str]) -> PygmentsStyle: try: style = pygments.styles.get_style_by_name(name).styles except ClassNotFound: @@ -111,4 +111,5 @@ class OutputStyle(PygmentsStyle): default_style = "" styles = style - return OutputStyle + # mypy does not complain but ty complains: error[invalid-return-type]: Return type does not match returned value. Hence added cast. + return cast(OutputStyle, PygmentsStyle) diff --git a/litecli/clitoolbar.py b/litecli/clitoolbar.py index 05d0bfd..e3bc6ee 100644 --- a/litecli/clitoolbar.py +++ b/litecli/clitoolbar.py @@ -1,37 +1,31 @@ -from __future__ import unicode_literals +from __future__ import annotations + +from typing import Callable, Any from prompt_toolkit.key_binding.vi_state import InputMode from prompt_toolkit.enums import EditingMode from prompt_toolkit.application import get_app -def create_toolbar_tokens_func(cli, show_fish_help): - """ - Return a function that generates the toolbar tokens. - """ +def create_toolbar_tokens_func(cli: Any, show_fish_help: Callable[[], bool]) -> Callable[[], list[tuple[str, str]]]: + """Return a function that generates the toolbar tokens.""" - def get_toolbar_tokens(): - result = [] + def get_toolbar_tokens() -> list[tuple[str, str]]: + result: list[tuple[str, str]] = [] result.append(("class:bottom-toolbar", " ")) if cli.multi_line: - result.append( - ("class:bottom-toolbar", " (Semi-colon [;] will end the line) ") - ) + result.append(("class:bottom-toolbar", " (Semi-colon [;] will end the line) ")) if cli.multi_line: result.append(("class:bottom-toolbar.on", "[F3] Multiline: ON ")) else: result.append(("class:bottom-toolbar.off", "[F3] Multiline: OFF ")) if cli.prompt_app.editing_mode == EditingMode.VI: - result.append( - ("class:botton-toolbar.on", "Vi-mode ({})".format(_get_vi_mode())) - ) + result.append(("class:botton-toolbar.on", "Vi-mode ({})".format(_get_vi_mode()))) if show_fish_help(): - result.append( - ("class:bottom-toolbar", " Right-arrow to complete suggestion") - ) + result.append(("class:bottom-toolbar", " Right-arrow to complete suggestion")) if cli.completion_refresher.is_refreshing(): result.append(("class:bottom-toolbar", " Refreshing completions...")) @@ -41,11 +35,12 @@ def get_toolbar_tokens(): return get_toolbar_tokens -def _get_vi_mode(): +def _get_vi_mode() -> str: """Get the current vi mode for display.""" return { InputMode.INSERT: "I", InputMode.NAVIGATION: "N", InputMode.REPLACE: "R", InputMode.INSERT_MULTIPLE: "M", + InputMode.REPLACE_SINGLE: "R", }[get_app().vi_state.input_mode] diff --git a/litecli/compat.py b/litecli/compat.py deleted file mode 100644 index 7316261..0000000 --- a/litecli/compat.py +++ /dev/null @@ -1,9 +0,0 @@ -# -*- coding: utf-8 -*- -"""Platform and Python version compatibility support.""" - -import sys - - -PY2 = sys.version_info[0] == 2 -PY3 = sys.version_info[0] == 3 -WIN = sys.platform in ("win32", "cygwin") diff --git a/litecli/completion_refresher.py b/litecli/completion_refresher.py index 9602070..4e76faa 100644 --- a/litecli/completion_refresher.py +++ b/litecli/completion_refresher.py @@ -1,20 +1,27 @@ +from __future__ import annotations + import threading -from .packages.special.main import COMMANDS from collections import OrderedDict +from typing import Callable, cast +from .packages.special.main import COMMANDS from .sqlcompleter import SQLCompleter from .sqlexecute import SQLExecute class CompletionRefresher(object): + refreshers: dict[str, Callable] = OrderedDict() - refreshers = OrderedDict() - - def __init__(self): - self._completer_thread = None + def __init__(self) -> None: + self._completer_thread: threading.Thread | None = None self._restart_refresh = threading.Event() - def refresh(self, executor, callbacks, completer_options=None): + def refresh( + self, + executor: SQLExecute, + callbacks: Callable | list[Callable], + completer_options: dict | None = None, + ) -> list[tuple]: """Creates a SQLCompleter object and populates it with the relevant completion suggestions in a background thread. @@ -37,27 +44,26 @@ def refresh(self, executor, callbacks, completer_options=None): # if DB is memory, needed to use same connection # So can't use same connection with different thread self._bg_refresh(executor, callbacks, completer_options) + return [(None, None, None, "Auto-completion refresh started in the background.")] else: self._completer_thread = threading.Thread( target=self._bg_refresh, args=(executor, callbacks, completer_options), name="completion_refresh", ) - self._completer_thread.setDaemon(True) + self._completer_thread.daemon = True self._completer_thread.start() - return [ - ( - None, - None, - None, - "Auto-completion refresh started in the background.", - ) - ] - - def is_refreshing(self): - return self._completer_thread and self._completer_thread.is_alive() - - def _bg_refresh(self, sqlexecute, callbacks, completer_options): + return [(None, None, None, "Auto-completion refresh started in the background.")] + + def is_refreshing(self) -> bool: + return bool(self._completer_thread and self._completer_thread.is_alive()) + + def _bg_refresh( + self, + sqlexecute: SQLExecute, + callbacks: Callable | list[Callable], + completer_options: dict, + ) -> None: completer = SQLCompleter(**completer_options) e = sqlexecute @@ -65,12 +71,14 @@ def _bg_refresh(self, sqlexecute, callbacks, completer_options): # if DB is memory, needed to use same connection executor = sqlexecute else: - # Create a new sqlexecute method to popoulate the completions. + # Create a new sqlexecute method to populate the completions. executor = SQLExecute(e.dbname) # If callbacks is a single function then push it into a list. if callable(callbacks): - callbacks = [callbacks] + callbacks_list: list[Callable] = [callbacks] + else: + callbacks_list = list(cast(list[Callable], callbacks)) while 1: for refresher in self.refreshers.values(): @@ -79,7 +87,7 @@ def _bg_refresh(self, sqlexecute, callbacks, completer_options): self._restart_refresh.clear() break else: - # Break out of while loop if the for loop finishes natually + # Break out of while loop if the for loop finishes naturally # without hitting the break statement. break @@ -87,16 +95,16 @@ def _bg_refresh(self, sqlexecute, callbacks, completer_options): # break statement. continue - for callback in callbacks: + for callback in callbacks_list: callback(completer) -def refresher(name, refreshers=CompletionRefresher.refreshers): +def refresher(name: str, refreshers: dict[str, Callable] = CompletionRefresher.refreshers) -> Callable: """Decorator to add the decorated function to the dictionary of refreshers. Any function decorated with a @refresher will be executed as part of the completion refresh routine.""" - def wrapper(wrapped): + def wrapper(wrapped: Callable) -> Callable: refreshers[name] = wrapped return wrapped @@ -104,28 +112,29 @@ def wrapper(wrapped): @refresher("databases") -def refresh_databases(completer, executor): +def refresh_databases(completer: SQLCompleter, executor: SQLExecute) -> None: completer.extend_database_names(executor.databases()) @refresher("schemata") -def refresh_schemata(completer, executor): +def refresh_schemata(completer: SQLCompleter, executor: SQLExecute) -> None: # name of the current database. completer.extend_schemata(executor.dbname) completer.set_dbname(executor.dbname) @refresher("tables") -def refresh_tables(completer, executor): - completer.extend_relations(executor.tables(), kind="tables") - completer.extend_columns(executor.table_columns(), kind="tables") +def refresh_tables(completer: SQLCompleter, executor: SQLExecute) -> None: + table_cols = list(executor.table_columns()) + completer.extend_relations(table_cols, kind="tables") + completer.extend_columns(table_cols, kind="tables") @refresher("functions") -def refresh_functions(completer, executor): +def refresh_functions(completer: SQLCompleter, executor: SQLExecute) -> None: completer.extend_functions(executor.functions()) @refresher("special_commands") -def refresh_special(completer, executor): - completer.extend_special_commands(COMMANDS.keys()) +def refresh_special(completer: SQLCompleter, executor: SQLExecute) -> None: + completer.extend_special_commands(list(COMMANDS.keys())) diff --git a/litecli/config.py b/litecli/config.py index 1c7fb25..953bc16 100644 --- a/litecli/config.py +++ b/litecli/config.py @@ -1,30 +1,34 @@ +from __future__ import annotations + import errno -import shutil import os import platform -from os.path import expanduser, exists, dirname +import shutil +from os.path import dirname, exists, expanduser + from configobj import ConfigObj -def config_location(): +def config_location() -> str: if "XDG_CONFIG_HOME" in os.environ: return "%s/litecli/" % expanduser(os.environ["XDG_CONFIG_HOME"]) elif platform.system() == "Windows": - return os.getenv("USERPROFILE") + "\\AppData\\Local\\dbcli\\litecli\\" + userprofile = os.getenv("USERPROFILE", "") + return userprofile + "\\AppData\\Local\\dbcli\\litecli\\" else: return expanduser("~/.config/litecli/") -def load_config(usr_cfg, def_cfg=None): +def load_config(usr_cfg: str, def_cfg: str | None = None) -> ConfigObj: cfg = ConfigObj() - cfg.merge(ConfigObj(def_cfg, interpolation=False)) + if def_cfg: + cfg.merge(ConfigObj(def_cfg, interpolation=False)) cfg.merge(ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8")) cfg.filename = expanduser(usr_cfg) - return cfg -def ensure_dir_exists(path): +def ensure_dir_exists(path: str) -> None: parent_dir = expanduser(dirname(path)) try: os.makedirs(parent_dir) @@ -34,29 +38,31 @@ def ensure_dir_exists(path): raise -def write_default_config(source, destination, overwrite=False): +def write_default_config(source: str, destination: str, overwrite: bool = False) -> None: destination = expanduser(destination) if not overwrite and exists(destination): return - ensure_dir_exists(destination) - shutil.copyfile(source, destination) -def upgrade_config(config, def_config): +def upgrade_config(config: str, def_config: str) -> None: cfg = load_config(config, def_config) cfg.write() -def get_config(liteclirc_file=None): +def get_config(liteclirc_file: str | None = None) -> ConfigObj: from litecli import __file__ as package_root - package_root = os.path.dirname(package_root) + package_root = os.path.dirname(str(package_root)) - liteclirc_file = liteclirc_file or "%sconfig" % config_location() + liteclirc_file = liteclirc_file or f"{config_location()}config" default_config = os.path.join(package_root, "liteclirc") - write_default_config(default_config, liteclirc_file) + try: + write_default_config(default_config, liteclirc_file) + except OSError: + # If we can't write to the config file, just use the default config + return load_config(default_config) return load_config(liteclirc_file, default_config) diff --git a/litecli/encodingutils.py b/litecli/encodingutils.py deleted file mode 100644 index bd23820..0000000 --- a/litecli/encodingutils.py +++ /dev/null @@ -1,36 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals - -from litecli.compat import PY2 - - -if PY2: - text_type = unicode - binary_type = str -else: - text_type = str - binary_type = bytes - - -def unicode2utf8(arg): - """Convert strings to UTF8-encoded bytes. - - Only in Python 2. In Python 3 the args are expected as unicode. - - """ - - if PY2 and isinstance(arg, text_type): - return arg.encode("utf-8") - return arg - - -def utf8tounicode(arg): - """Convert UTF8-encoded bytes to strings. - - Only in Python 2. In Python 3 the errors are returned as strings. - - """ - - if PY2 and isinstance(arg, binary_type): - return arg.decode("utf-8") - return arg diff --git a/litecli/key_bindings.py b/litecli/key_bindings.py index 44d59d2..e2e097b 100644 --- a/litecli/key_bindings.py +++ b/litecli/key_bindings.py @@ -1,24 +1,27 @@ -from __future__ import unicode_literals +from __future__ import annotations import logging +from typing import Any + from prompt_toolkit.enums import EditingMode from prompt_toolkit.filters import completion_is_selected from prompt_toolkit.key_binding import KeyBindings +from prompt_toolkit.key_binding.key_processor import KeyPressEvent _logger = logging.getLogger(__name__) -def cli_bindings(cli): +def cli_bindings(cli: Any) -> KeyBindings: """Custom key bindings for cli.""" kb = KeyBindings() @kb.add("f3") - def _(event): + def _(_event: KeyPressEvent) -> None: """Enable/Disable Multiline Mode.""" _logger.debug("Detected F3 key.") cli.multi_line = not cli.multi_line @kb.add("f4") - def _(event): + def _(event: KeyPressEvent) -> None: """Toggle between Vi and Emacs mode.""" _logger.debug("Detected F4 key.") if cli.key_bindings == "vi": @@ -29,7 +32,7 @@ def _(event): cli.key_bindings = "vi" @kb.add("tab") - def _(event): + def _(event: KeyPressEvent) -> None: """Force autocompletion at cursor.""" _logger.debug("Detected key.") b = event.app.current_buffer @@ -39,7 +42,7 @@ def _(event): b.start_completion(select_first=True) @kb.add("s-tab") - def _(event): + def _(event: KeyPressEvent) -> None: """Force autocompletion at cursor.""" _logger.debug("Detected key.") b = event.app.current_buffer @@ -49,7 +52,7 @@ def _(event): b.start_completion(select_last=True) @kb.add("c-space") - def _(event): + def _(event: KeyPressEvent) -> None: """ Initialize autocompletion at cursor. @@ -67,7 +70,7 @@ def _(event): b.start_completion(select_first=False) @kb.add("enter", filter=completion_is_selected) - def _(event): + def _(event: KeyPressEvent) -> None: """Makes the enter key work as the tab key only when showing the menu. In other words, don't execute query when enter is pressed in @@ -81,4 +84,12 @@ def _(event): b = event.app.current_buffer b.complete_state = None + @kb.add("right", filter=completion_is_selected) + def _(event: KeyPressEvent) -> None: + """Accept the completion that is selected in the dropdown menu.""" + _logger.debug("Detected right-arrow key.") + + b = event.app.current_buffer + b.complete_state = None + return kb diff --git a/litecli/lexer.py b/litecli/lexer.py index 678eb3f..9260a17 100644 --- a/litecli/lexer.py +++ b/litecli/lexer.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from pygments.lexer import inherit from pygments.lexers.sql import MySqlLexer from pygments.token import Keyword @@ -6,4 +8,10 @@ class LiteCliLexer(MySqlLexer): """Extends SQLite lexer to add keywords.""" - tokens = {"root": [(r"\brepair\b", Keyword), (r"\boffset\b", Keyword), inherit]} + tokens = { + "root": [ + (r"\brepair\b", Keyword), + (r"\boffset\b", Keyword), + inherit, + ] + } diff --git a/litecli/liteclirc b/litecli/liteclirc index e3331d1..ec162d3 100644 --- a/litecli/liteclirc +++ b/litecli/liteclirc @@ -18,6 +18,12 @@ destructive_warning = True # %USERPROFILE% is typically C:\Users\{username} log_file = default +# history_file location. +# In Unix/Linux: ~/.config/litecli/history +# In Windows: %USERPROFILE%\AppData\Local\dbcli\litecli\history +# %USERPROFILE% is typically C:\Users\{username} +history_file = default + # Default log level. Possible values: "CRITICAL", "ERROR", "WARNING", "INFO" # and "DEBUG". "NONE" disables logging. log_level = INFO @@ -41,7 +47,7 @@ table_format = ascii # manni, igor, xcode, vim, autumn, vs, rrt, native, perldoc, borland, tango, emacs, # friendly, monokai, paraiso, colorful, murphy, bw, pastie, paraiso, trac, default, # fruity. -# Screenshots at http://mycli.net/syntax +# See the LiteCLI README for syntax examples syntax_style = default # Keybindings: Possible values: emacs, vi. @@ -52,18 +58,27 @@ key_bindings = emacs # Enabling this option will show the suggestions in a wider menu. Thus more items are suggested. wider_completion_menu = False +# Autocompletion is on by default. This can be truned off by setting this +# option to False. Pressing tab will still trigger completion. +autocompletion = True + # litecli prompt # \D - The full current date # \d - Database name +# \f - File basename of the "main" database # \m - Minutes of the current time # \n - Newline # \P - AM/PM # \R - The current time, in 24-hour military time (0-23) # \r - The current time, standard 12-hour time (1-12) # \s - Seconds of the current time +# \x1b[...m - insert ANSI escape sequence prompt = '\d> ' prompt_continuation = '-> ' +# Show/hide the informational toolbar with function keymap at the footer. +show_bottom_toolbar = True + # Skip intro info on startup and outro info on exit less_chatty = False @@ -108,6 +123,12 @@ output.header = "#00ff5f bold" output.odd-row = "" output.even-row = "" - # Favorite queries. [favorite_queries] + +# Startup commands +# litecli commands or sqlite commands to be executed on startup. +# some of them will require you to have a database attached. +# they will be executed in the same order as they appear in the list. +[startup_commands] +#commands = ".tables", "pragma foreign_keys = ON;" diff --git a/litecli/main.py b/litecli/main.py index 23f5f79..fa732c3 100644 --- a/litecli/main.py +++ b/litecli/main.py @@ -1,54 +1,65 @@ -from __future__ import unicode_literals -from __future__ import print_function +from __future__ import annotations +import itertools +import logging import os +import re +import shutil import sys -import traceback -import logging import threading -from time import time +import traceback +from collections import namedtuple from datetime import datetime from io import open -from collections import namedtuple -from sqlite3 import OperationalError +from time import time +from typing import Any, Generator, Iterable, Literal, TextIO, cast -from cli_helpers.tabular_output import TabularOutputFormatter -from cli_helpers.tabular_output import preprocessors import click import sqlparse -from prompt_toolkit.completion import DynamicCompleter -from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode -from prompt_toolkit.shortcuts import PromptSession, CompleteStyle -from prompt_toolkit.styles.pygments import style_from_pygments_cls +from cli_helpers.tabular_output import TabularOutputFormatter, preprocessors +from prompt_toolkit.auto_suggest import AutoSuggestFromHistory +from prompt_toolkit.completion import Completion, DynamicCompleter from prompt_toolkit.document import Document +from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode from prompt_toolkit.filters import HasFocus, IsDone +from prompt_toolkit.formatted_text import ANSI +from prompt_toolkit.history import FileHistory from prompt_toolkit.layout.processors import ( - HighlightMatchingBracketProcessor, ConditionalProcessor, + HighlightMatchingBracketProcessor, ) from prompt_toolkit.lexers import PygmentsLexer -from prompt_toolkit.history import FileHistory -from prompt_toolkit.auto_suggest import AutoSuggestFromHistory +from prompt_toolkit.shortcuts import CompleteStyle, PromptSession -from .packages.special.main import NO_QUERY -from .packages.prompt_utils import confirm, confirm_destructive_query -from .packages import special -from .sqlcompleter import SQLCompleter -from .clitoolbar import create_toolbar_tokens_func -from .clistyle import style_factory, style_factory_output -from .sqlexecute import SQLExecute +from .__init__ import __version__ from .clibuffer import cli_is_multiline +from .clistyle import style_factory, style_factory_output +from .clitoolbar import create_toolbar_tokens_func from .completion_refresher import CompletionRefresher from .config import config_location, ensure_dir_exists, get_config from .key_bindings import cli_bindings -from .encodingutils import utf8tounicode, text_type from .lexer import LiteCliLexer -from .__init__ import __version__ +from .packages import special from .packages.filepaths import dir_path_exists +from .packages.prompt_utils import confirm, confirm_destructive_query +from .packages.special.main import NO_QUERY +from .sqlcompleter import SQLCompleter +from .sqlexecute import SQLExecute -import itertools -click.disable_unicode_literals_warning = True +def _load_sqlite3() -> Any: + try: + import sqlean + except ImportError: + import sqlite3 + + return sqlite3 + return sqlean + + +_sqlite3 = _load_sqlite3() +OperationalError = _sqlite3.OperationalError +sqlite_version = _sqlite3.sqlite_version # Query tuples are used for maintaining history Query = namedtuple("Query", ["query", "successful", "mutating"]) @@ -57,21 +68,20 @@ class LiteCli(object): - default_prompt = "\\d> " max_len_prompt = 45 def __init__( self, - sqlexecute=None, - prompt=None, - logfile=None, - auto_vertical_output=False, - warn=None, - liteclirc=None, - ): + sqlexecute: SQLExecute | None = None, + prompt: str | None = None, + logfile: TextIO | None = None, + auto_vertical_output: bool = False, + warn: bool | None = None, + liteclirc: str | None = None, + ) -> None: self.sqlexecute = sqlexecute - self.logfile = logfile + self.logfile: TextIO | Literal[False] | None = logfile # Load config. c = self.config = get_config(liteclirc) @@ -80,20 +90,21 @@ def __init__( self.key_bindings = c["main"]["key_bindings"] special.set_favorite_queries(self.config) self.formatter = TabularOutputFormatter(format_name=c["main"]["table_format"]) - self.formatter.litecli = self + # self.formatter.litecli = self, ty raises unresolved-attribute, hence use dynamic assignment + setattr(self.formatter, "litecli", self) self.syntax_style = c["main"]["syntax_style"] self.less_chatty = c["main"].as_bool("less_chatty") + self.show_bottom_toolbar = c["main"].as_bool("show_bottom_toolbar") self.cli_style = c["colors"] self.output_style = style_factory_output(self.syntax_style, self.cli_style) self.wider_completion_menu = c["main"].as_bool("wider_completion_menu") + self.autocompletion = c["main"].as_bool("autocompletion") c_dest_warning = c["main"].as_bool("destructive_warning") self.destructive_warning = c_dest_warning if warn is None else warn self.login_path_as_host = c["main"].as_bool("login_path_as_host") # read from cli argument or user config file - self.auto_vertical_output = auto_vertical_output or c["main"].as_bool( - "auto_vertical_output" - ) + self.auto_vertical_output = auto_vertical_output or c["main"].as_bool("auto_vertical_output") # audit log if self.logfile is None and "audit_log" in c["main"]: @@ -106,6 +117,11 @@ def __init__( fg="red", ) self.logfile = False + # Load startup commands. + try: + self.startup_commands = c["startup_commands"] + except KeyError: # Redundant given the load_config() function that merges in the standard config, but put here to avoid fail if user do not have updated config file. + self.startup_commands = None self.completion_refresher = CompletionRefresher() @@ -113,13 +129,11 @@ def __init__( self.initialize_logging() prompt_cnf = self.read_my_cnf_files(["prompt"])["prompt"] - self.prompt_format = ( - prompt or prompt_cnf or c["main"]["prompt"] or self.default_prompt - ) + self.prompt_format = prompt or prompt_cnf or c["main"]["prompt"] or self.default_prompt self.prompt_continuation_format = c["main"]["prompt_continuation"] keyword_casing = c["main"].get("keyword_casing", "auto") - self.query_history = [] + self.query_history: list[Query] = [] # Initialize completer. self.completer = SQLCompleter( @@ -131,9 +145,9 @@ def __init__( # Register custom special commands. self.register_special_commands() - self.prompt_app = None + self.prompt_app: PromptSession | None = None - def register_special_commands(self): + def register_special_commands(self) -> None: special.register_special_command( self.change_db, ".open", @@ -159,10 +173,11 @@ def register_special_commands(self): ) special.register_special_command( self.execute_from_file, - "source", + ".read", "\\. filename", "Execute commands from file.", - aliases=("\\.",), + case_sensitive=True, + aliases=("\\.", "source"), ) special.register_special_command( self.change_prompt_format, @@ -173,7 +188,7 @@ def register_special_commands(self): case_sensitive=True, ) - def change_table_format(self, arg, **_): + def change_table_format(self, arg: str, **_: Any) -> Generator[tuple[None, None, None, str], None, None]: try: self.formatter.format_name = arg yield (None, None, None, "Changed table format to {}".format(arg)) @@ -183,21 +198,26 @@ def change_table_format(self, arg, **_): msg += "\n\t{}".format(table_type) yield (None, None, None, msg) - def change_db(self, arg, **_): + def change_db(self, arg: str | None, **_: Any) -> Iterable[tuple]: if arg is None: + assert self.sqlexecute is not None self.sqlexecute.connect() else: + assert self.sqlexecute is not None self.sqlexecute.connect(database=arg) self.refresh_completions() + # guard so that ty doesn't complain + dbname = self.sqlexecute.dbname if self.sqlexecute is not None else "" + yield ( None, None, None, - 'You are now connected to database "%s"' % (self.sqlexecute.dbname), + 'You are now connected to database "%s"' % (dbname), ) - def execute_from_file(self, arg, **_): + def execute_from_file(self, arg: str | None, **_: Any) -> Iterable[tuple[Any, ...]]: if not arg: message = "Missing required argument, filename." return [(None, None, None, message)] @@ -211,9 +231,10 @@ def execute_from_file(self, arg, **_): message = "Wise choice. Command execution stopped." return [(None, None, None, message)] - return self.sqlexecute.run(query) + assert self.sqlexecute is not None + return cast(Iterable[tuple[Any, ...]], self.sqlexecute.run(query)) - def change_prompt_format(self, arg, **_): + def change_prompt_format(self, arg: str | None, **_: Any) -> Iterable[tuple]: """ Change the prompt format. """ @@ -224,12 +245,16 @@ def change_prompt_format(self, arg, **_): self.prompt_format = self.get_prompt(arg) return [(None, None, None, "Changed prompt format to %s" % arg)] - def initialize_logging(self): - + def initialize_logging(self) -> None: log_file = self.config["main"]["log_file"] if log_file == "default": log_file = config_location() + "log" - ensure_dir_exists(log_file) + log_file = os.path.expanduser(log_file) + try: + ensure_dir_exists(log_file) + except OSError: + # Unable to create log file, log to temp directory instead. + log_file = "/tmp/litecli.log" log_level = self.config["main"]["log_level"] @@ -244,7 +269,7 @@ def initialize_logging(self): # Disable logging if value is NONE by switching to a no-op handler # Set log level to a high value so it doesn't even waste cycles getting called. if log_level.upper() == "NONE": - handler = logging.NullHandler() + handler: logging.Handler = logging.NullHandler() log_level = "CRITICAL" elif dir_path_exists(log_file): handler = logging.FileHandler(log_file) @@ -256,10 +281,7 @@ def initialize_logging(self): ) return - formatter = logging.Formatter( - "%(asctime)s (%(process)d/%(threadName)s) " - "%(name)s %(levelname)s - %(message)s" - ) + formatter = logging.Formatter("%(asctime)s (%(process)d/%(threadName)s) %(name)s %(levelname)s - %(message)s") handler.setFormatter(formatter) @@ -272,7 +294,7 @@ def initialize_logging(self): root_logger.debug("Initializing litecli logging.") root_logger.debug("Log file %r.", log_file) - def read_my_cnf_files(self, keys): + def read_my_cnf_files(self, keys: Iterable[str]) -> dict[str, str | None]: """ Reads a list of config files and merges them. The last one will win. :param files: list of files to read @@ -283,7 +305,7 @@ def read_my_cnf_files(self, keys): sections = ["main"] - def get(key): + def get(key: str) -> str | None: result = None for sect in cnf: if sect in sections and key in cnf[sect]: @@ -292,20 +314,19 @@ def get(key): return {x: get(x) for x in keys} - def connect(self, database=""): - - cnf = {"database": None} + def connect(self, database: str | None = "") -> None: + cnf: dict[str, str | None] = {"database": None} cnf = self.read_my_cnf_files(cnf.keys()) # Fall back to config values only if user did not specify a value. - database = database or cnf["database"] + db_value: str | None = database or cnf["database"] # Connect to the database. - def _connect(): - self.sqlexecute = SQLExecute(database) + def _connect() -> None: + self.sqlexecute = SQLExecute(db_value) try: _connect() @@ -315,8 +336,8 @@ def _connect(): self.echo(str(e), err=True, fg="red") exit(1) - def handle_editor_command(self, text): - """Editor command is any query that is prefixed or suffixed by a '\e'. + def handle_editor_command(self, text: str) -> str: + R"""Editor command is any query that is prefixed or suffixed by a '\e'. The reason for a while loop is because a user might edit a query multiple times. For eg: @@ -337,6 +358,7 @@ def handle_editor_command(self, text): raise RuntimeError(message) while True: try: + assert self.prompt_app is not None text = self.prompt_app.prompt(default=sql) break except KeyboardInterrupt: @@ -345,21 +367,24 @@ def handle_editor_command(self, text): continue return text - def run_cli(self): + def run_cli(self) -> None: iterations = 0 sqlexecute = self.sqlexecute + assert sqlexecute is not None logger = self.logger self.configure_pager() self.refresh_completions() - history_file = config_location() + "history" + history_file = self.config["main"]["history_file"] + if history_file == "default": + history_file = config_location() + "history" + history_file = os.path.expanduser(history_file) if dir_path_exists(history_file): history = FileHistory(history_file) else: history = None self.echo( - 'Error: Unable to open the history file "{}". ' - "Your query history will not be saved.".format(history_file), + 'Error: Unable to open the history file "{}". Your query history will not be saved.'.format(history_file), err=True, fg="red", ) @@ -367,30 +392,69 @@ def run_cli(self): key_bindings = cli_bindings(self) if not self.less_chatty: - print("Version:", __version__) - print("Mail: https://groups.google.com/forum/#!forum/litecli-users") + print(f"LiteCli: {__version__} (SQLite: {sqlite_version})") print("GitHub: https://github.com/dbcli/litecli") - # print("Home: https://litecli.com") - def get_message(): + def get_message() -> ANSI: prompt = self.get_prompt(self.prompt_format) - if ( - self.prompt_format == self.default_prompt - and len(prompt) > self.max_len_prompt - ): + if self.prompt_format == self.default_prompt and len(prompt) > self.max_len_prompt: prompt = self.get_prompt("\\d> ") - return [("class:prompt", prompt)] + prompt = prompt.replace("\\x1b", "\x1b") + return ANSI(prompt) - def get_continuation(width, line_number, is_soft_wrap): + def get_continuation(width: int, line_number: int, is_soft_wrap: int) -> list[tuple[str, str]]: continuation = " " * (width - 1) + " " return [("class:continuation", continuation)] - def show_suggestion_tip(): + def show_suggestion_tip() -> bool: return iterations < 2 - def one_iteration(text=None): + def output_res(res: Iterable[tuple[Any, Any, Any, str | None]], start: float) -> bool: + result_count = 0 + mutating = False + for title, cur, headers, status in res: + logger.debug("headers: %r", headers) + logger.debug("rows: %r", cur) + logger.debug("status: %r", status) + threshold = 1000 + if is_select(status) and cur and cur.rowcount > threshold: + self.echo( + "The result set has more than {} rows.".format(threshold), + fg="red", + ) + if not confirm("Do you want to continue?"): + self.echo("Aborted!", err=True, fg="red") + break + + if self.auto_vertical_output: + assert self.prompt_app is not None + max_width = self.prompt_app.output.get_size().columns + else: + max_width = None + + formatted = self.format_output(title, cur, headers, special.is_expanded_output(), max_width) + + t = time() - start + try: + if result_count > 0: + self.echo("") + try: + self.output(formatted, status) + except KeyboardInterrupt: + pass + self.echo("Time: %0.03fs" % t) + except KeyboardInterrupt: + pass + + start = time() + result_count += 1 + mutating = mutating or is_mutating(status) + return mutating + + def one_iteration(text: str | None = None) -> None: if text is None: try: + assert self.prompt_app is not None text = self.prompt_app.prompt() except KeyboardInterrupt: return @@ -405,6 +469,33 @@ def one_iteration(text=None): self.echo(str(e), err=True, fg="red") return + while special.is_llm_command(text): + try: + start = time() + assert self.sqlexecute is not None + conn = self.sqlexecute.conn + assert conn is not None + cur = conn.cursor() + context, sql, duration = special.handle_llm(text, cur) + if context: + click.echo("LLM Reponse:") + click.echo(context) + click.echo("---") + click.echo(f"Time: {duration:.2f} seconds") + assert self.prompt_app is not None + text = self.prompt_app.prompt(default=sql) + except KeyboardInterrupt: + return + except special.FinishIteration as e: + if e.results: + output_res(e.results, start) + return + except RuntimeError as e: + logger.error("sql: %r, error: %r", text, e) + logger.error("traceback: %r", traceback.format_exc()) + self.echo(str(e), err=True, fg="red") + return + if not text.strip(): return @@ -418,9 +509,6 @@ def one_iteration(text=None): self.echo("Wise choice!") return - # Keep track of whether or not the query is mutating. In case - # of a multi-statement query, the overall query is considered - # mutating if any one of the component statements is mutating mutating = False try: @@ -435,74 +523,32 @@ def one_iteration(text=None): successful = False start = time() res = sqlexecute.run(text) - self.formatter.query = text + # Set query attribute dynamically on formatter + setattr(self.formatter, "query", text) successful = True - result_count = 0 - for title, cur, headers, status in res: - logger.debug("headers: %r", headers) - logger.debug("rows: %r", cur) - logger.debug("status: %r", status) - threshold = 1000 - if is_select(status) and cur and cur.rowcount > threshold: - self.echo( - "The result set has more than {} rows.".format(threshold), - fg="red", - ) - if not confirm("Do you want to continue?"): - self.echo("Aborted!", err=True, fg="red") - break - - if self.auto_vertical_output: - max_width = self.prompt_app.output.get_size().columns - else: - max_width = None - - formatted = self.format_output( - title, cur, headers, special.is_expanded_output(), max_width - ) - - t = time() - start - try: - if result_count > 0: - self.echo("") - try: - self.output(formatted, status) - except KeyboardInterrupt: - pass - self.echo("Time: %0.03fs" % t) - except KeyboardInterrupt: - pass - - start = time() - result_count += 1 - mutating = mutating or is_mutating(status) special.unset_once_if_written() + # Keep track of whether or not the query is mutating. In case + # of a multi-statement query, the overall query is considered + # mutating if any one of the component statements is mutating + mutating = output_res(res, start) + special.unset_pipe_once_if_written() except EOFError as e: raise e except KeyboardInterrupt: - # get last connection id - connection_id_to_kill = sqlexecute.connection_id - logger.debug("connection id to kill: %r", connection_id_to_kill) - # Restart connection to the database - sqlexecute.connect() try: - for title, cur, headers, status in sqlexecute.run( - "kill %s" % connection_id_to_kill - ): - status_str = str(status).lower() - if status_str.find("ok") > -1: - logger.debug( - "cancelled query, connection id: %r, sql: %r", - connection_id_to_kill, - text, - ) - self.echo("cancelled query", err=True, fg="red") + # since connection can be sqlite3 or sqlean, it's hard to annotate the type for interrupt. so ignore the type hint warning. + conn = sqlexecute.conn + if conn is not None: + conn.interrupt() # type: ignore[attr-defined] except Exception as e: self.echo( "Encountered error while cancelling query: {}".format(e), err=True, fg="red", ) + else: + logger.debug("cancelled query") + self.echo("cancelled query", err=True, fg="red") except NotImplementedError: self.echo("Not Yet Implemented.", fg="yellow") except OperationalError as e: @@ -515,9 +561,9 @@ def one_iteration(text=None): logger.debug("Reconnected successfully.") one_iteration(text) return # OK to just return, cuz the recursion call runs to the end. - except OperationalError as e: - logger.debug("Reconnect failed. e: %r", e) - self.echo(str(e), err=True, fg="red") + except OperationalError as ex: + logger.debug("Reconnect failed. e: %r", ex) + self.echo(str(ex), err=True, fg="red") # If reconnection failed, don't proceed further. return else: @@ -529,10 +575,6 @@ def one_iteration(text=None): logger.error("traceback: %r", traceback.format_exc()) self.echo(str(e), err=True, fg="red") else: - if is_dropping_database(text, self.sqlexecute.dbname): - self.sqlexecute.dbname = None - self.sqlexecute.connect() - # Refresh the table names and column names if necessary. if need_completion_refresh(text): self.refresh_completions(reset=need_completion_reset(text)) @@ -549,8 +591,10 @@ def one_iteration(text=None): else: complete_style = CompleteStyle.COLUMN - with self._completer_lock: + if not self.autocompletion: + complete_style = CompleteStyle.READLINE_LIKE + with self._completer_lock: if self.key_bindings == "vi": editing_mode = EditingMode.VI else: @@ -560,8 +604,8 @@ def one_iteration(text=None): lexer=PygmentsLexer(LiteCliLexer), reserve_space_for_menu=self.get_reserved_space(), message=get_message, - prompt_continuation=get_continuation, - bottom_toolbar=get_toolbar_tokens, + prompt_continuation=cast(Any, get_continuation), + bottom_toolbar=get_toolbar_tokens if self.show_bottom_toolbar else None, complete_style=complete_style, input_processors=[ ConditionalProcessor( @@ -585,6 +629,42 @@ def one_iteration(text=None): search_ignore_case=True, ) + def startup_commands() -> None: + if self.startup_commands: + if "commands" in self.startup_commands: + if isinstance(self.startup_commands["commands"], str): + commands = [self.startup_commands["commands"]] + else: + commands = self.startup_commands["commands"] + for command in commands: + try: + res = sqlexecute.run(command) + except Exception as e: + click.echo(command) + self.echo(str(e), err=True, fg="red") + else: + click.echo(command) + for title, cur, headers, status in res: + if title == "dot command not implemented": + self.echo( + "The SQLite dot command '" + command.split(" ", 1)[0] + "' is not yet implemented.", + fg="yellow", + ) + else: + output = self.format_output(title, cur, headers) + for line in output: + self.echo(line) + else: + self.echo( + "Could not read commands. The startup commands needs to be formatted as: \n commands = 'command1', 'command2', ...", + fg="yellow", + ) + + try: + startup_commands() + except Exception as e: + self.echo("Could not execute all startup commands: \n" + str(e), fg="yellow") + try: while True: one_iteration() @@ -594,12 +674,12 @@ def one_iteration(text=None): if not self.less_chatty: self.echo("Goodbye!") - def log_output(self, output): + def log_output(self, output: str) -> None: """Log the output in the audit log, if it's enabled.""" if self.logfile: - click.echo(utf8tounicode(output), file=self.logfile) + click.echo(output, file=self.logfile) - def echo(self, s, **kwargs): + def echo(self, s: str, **kwargs: Any) -> None: """Print a message to stdout. The message will be logged in the audit log, if enabled. @@ -610,20 +690,16 @@ def echo(self, s, **kwargs): self.log_output(s) click.secho(s, **kwargs) - def get_output_margin(self, status=None): + def get_output_margin(self, status: str | None = None) -> int: """Get the output margin (number of rows for the prompt, footer and timing message.""" - margin = ( - self.get_reserved_space() - + self.get_prompt(self.prompt_format).count("\n") - + 2 - ) + margin = self.get_reserved_space() + self.get_prompt(self.prompt_format).count("\n") + 2 if status: margin += 1 + status.count("\n") return margin - def output(self, output, status=None): + def output(self, output: Iterable[str], status: str | None = None) -> None: """Output text to stdout or a pager command. The status text is not outputted to pager or files. @@ -634,6 +710,7 @@ def output(self, output, status=None): """ if output: + assert self.prompt_app is not None size = self.prompt_app.output.get_size() margin = self.get_output_margin(status) @@ -645,6 +722,7 @@ def output(self, output, status=None): self.log_output(line) special.write_tee(line) special.write_once(line) + special.write_pipe_once(line) if fits or output_via_pager: # buffering @@ -675,7 +753,7 @@ def output(self, output, status=None): self.log_output(status) click.secho(status) - def configure_pager(self): + def configure_pager(self) -> None: # Provide sane defaults for less if they are empty. if not os.environ.get("LESS"): os.environ["LESS"] = "-RXF" @@ -690,10 +768,11 @@ def configure_pager(self): if cnf["skip-pager"] or not self.config["main"].as_bool("enable_pager"): special.disable_pager() - def refresh_completions(self, reset=False): + def refresh_completions(self, reset: bool = False) -> list[tuple]: if reset: with self._completer_lock: self.completer.reset_completions() + assert self.sqlexecute is not None self.completion_refresher.refresh( self.sqlexecute, self._on_completions_refreshed, @@ -703,13 +782,10 @@ def refresh_completions(self, reset=False): }, ) - return [ - (None, None, None, "Auto-completion refresh started in the background.") - ] + return [(None, None, None, "Auto-completion refresh started in the background.")] - def _on_completions_refreshed(self, new_completer): - """Swap the completer object in cli with the newly created completer. - """ + def _on_completions_refreshed(self, new_completer: SQLCompleter) -> None: + """Swap the completer object in cli with the newly created completer.""" with self._completer_lock: self.completer = new_completer @@ -718,40 +794,53 @@ def _on_completions_refreshed(self, new_completer): # "Refreshing completions..." indicator self.prompt_app.app.invalidate() - def get_completions(self, text, cursor_positition): + def get_completions(self, text: str, cursor_positition: int) -> Iterable[Completion]: with self._completer_lock: - return self.completer.get_completions( - Document(text=text, cursor_position=cursor_positition), None - ) + return self.completer.get_completions(Document(text=text, cursor_position=cursor_positition), None) - def get_prompt(self, string): - self.logger.debug("Getting prompt") + def get_prompt(self, string: str) -> str: + self.logger.debug("Getting prompt %r", string) sqlexecute = self.sqlexecute + assert sqlexecute is not None now = datetime.now() - string = string.replace("\\d", sqlexecute.dbname or "(none)") - string = string.replace("\\n", "\n") - string = string.replace("\\D", now.strftime("%a %b %d %H:%M:%S %Y")) - string = string.replace("\\m", now.strftime("%M")) - string = string.replace("\\P", now.strftime("%p")) - string = string.replace("\\R", now.strftime("%H")) - string = string.replace("\\r", now.strftime("%I")) - string = string.replace("\\s", now.strftime("%S")) - string = string.replace("\\_", " ") - return string - - def run_query(self, query, new_line=True): + + # Prepare the replacements dictionary + replacements = { + r"\d": sqlexecute.dbname or "(none)", + r"\f": os.path.basename(sqlexecute.dbname or "(none)"), + r"\n": "\n", + r"\D": now.strftime("%a %b %d %H:%M:%S %Y"), + r"\m": now.strftime("%M"), + r"\P": now.strftime("%p"), + r"\R": now.strftime("%H"), + r"\r": now.strftime("%I"), + r"\s": now.strftime("%S"), + r"\_": " ", + } + # Compile a regex pattern that matches any of the keys in replacements + pattern = re.compile("|".join(re.escape(key) for key in replacements.keys())) + + # Define the replacement function + def replacer(match: re.Match[str]) -> str: + return replacements[match.group(0)] + + # Perform the substitution + return pattern.sub(replacer, string) + + def run_query(self, query: str, new_line: bool = True) -> None: """Runs *query*.""" + assert self.sqlexecute is not None results = self.sqlexecute.run(query) for result in results: title, cur, headers, status = result - self.formatter.query = query + setattr(self.formatter, "query", query) output = self.format_output(title, cur, headers) for line in output: click.echo(line, nl=new_line) - def format_output(self, title, cur, headers, expanded=False, max_width=None): + def format_output(self, title: Any, cur: Any, headers: Any, expanded: bool = False, max_width: int | None = None) -> Iterable[str]: expanded = expanded or self.formatter.format_name == "vertical" - output = [] + output_iter: Iterable[str] = [] output_kwargs = { "dialect": "unix", @@ -762,18 +851,12 @@ def format_output(self, title, cur, headers, expanded=False, max_width=None): } if title: # Only print the title if it's not None. - output = itertools.chain(output, [title]) + output_iter = itertools.chain(output_iter, [title]) if cur: column_types = None if hasattr(cur, "description"): - - def get_col_type(col): - # col_type = FIELD_TYPES.get(col[1], text_type) - # return col_type if type(col_type) is type else text_type - return text_type - - column_types = [get_col_type(col) for col in cur.description] + column_types = [str(col) for col in cur.description] if max_width is not None: cur = list(cur) @@ -783,51 +866,45 @@ def get_col_type(col): headers, format_name="vertical" if expanded else None, column_types=column_types, - **output_kwargs + **output_kwargs, ) - if isinstance(formatted, (text_type)): + if isinstance(formatted, str): formatted = formatted.splitlines() formatted = iter(formatted) first_line = next(formatted) formatted = itertools.chain([first_line], formatted) - if ( - not expanded - and max_width - and headers - and cur - and len(first_line) > max_width - ): + if not expanded and max_width and headers and cur and len(first_line) > max_width: formatted = self.formatter.format_output( cur, headers, format_name="vertical", column_types=column_types, - **output_kwargs + **output_kwargs, ) - if isinstance(formatted, (text_type)): + if isinstance(formatted, str): formatted = iter(formatted.splitlines()) - output = itertools.chain(output, formatted) + output_iter = itertools.chain(output_iter, formatted) - return output + return output_iter - def get_reserved_space(self): + def get_reserved_space(self) -> int: """Get the number of lines to reserve for the completion menu.""" reserved_space_ratio = 0.45 max_reserved_space = 8 - _, height = click.get_terminal_size() + _, height = shutil.get_terminal_size() return min(int(round(height * reserved_space_ratio)), max_reserved_space) - def get_last_query(self): + def get_last_query(self) -> str | None: """Get the last query executed or None.""" return self.query_history[-1][0] if self.query_history else None @click.command() -@click.option("-V", "--version", is_flag=True, help="Output litecli's version.") +@click.version_option(__version__, "-V", "--version") @click.option("-D", "--database", "dbname", help="Database to use.") @click.option( "-R", @@ -852,28 +929,23 @@ def get_last_query(self): is_flag=True, help="Automatically switch to vertical output mode if the result is wider than the terminal width.", ) -@click.option( - "-t", "--table", is_flag=True, help="Display batch output in table format." -) +@click.option("-t", "--table", is_flag=True, help="Display batch output in table format.") @click.option("--csv", is_flag=True, help="Display batch output in CSV format.") -@click.option( - "--warn/--no-warn", default=None, help="Warn before running a destructive query." -) +@click.option("--warn/--no-warn", default=None, help="Warn before running a destructive query.") @click.option("-e", "--execute", type=str, help="Execute command and quit.") @click.argument("database", default="", nargs=1) def cli( - database, - dbname, - version, - prompt, - logfile, - auto_vertical_output, - table, - csv, - warn, - execute, - liteclirc, -): + database: str, + dbname: str, + prompt: str | None, + logfile: TextIO | None, + auto_vertical_output: bool, + table: bool, + csv: bool, + warn: bool | None, + execute: str | None, + liteclirc: str, +) -> None: """A SQLite terminal client with auto-completion and syntax highlighting. \b @@ -881,11 +953,6 @@ def cli( - litecli lite_database """ - - if version: - print("Version:", __version__) - sys.exit(0) - litecli = LiteCli( prompt=prompt, logfile=logfile, @@ -899,7 +966,7 @@ def cli( litecli.connect(database) - litecli.logger.debug("Launch Params: \n" "\tdatabase: %r", database) + litecli.logger.debug("Launch Params: \n\tdatabase: %r", database) # --execute argument if execute: @@ -926,10 +993,7 @@ def cli( except (FileNotFoundError, OSError): litecli.logger.warning("Unable to open TTY as stdin.") - if ( - litecli.destructive_warning - and confirm_destructive_query(stdin_text) is False - ): + if litecli.destructive_warning and confirm_destructive_query(stdin_text) is False: exit(0) try: new_line = True @@ -946,7 +1010,7 @@ def cli( exit(1) -def need_completion_refresh(queries): +def need_completion_refresh(queries: str) -> bool: """Determines if the completion needs a refresh by checking if the sql statement is an alter, create, drop or change db.""" for query in sqlparse.split(queries): @@ -964,34 +1028,10 @@ def need_completion_refresh(queries): return True except Exception: return False + return False -def is_dropping_database(queries, dbname): - """Determine if the query is dropping a specific database.""" - if dbname is None: - return False - - def normalize_db_name(db): - return db.lower().strip('`"') - - dbname = normalize_db_name(dbname) - - for query in sqlparse.parse(queries): - if query.get_name() is None: - continue - - first_token = query.token_first(skip_cm=True) - _, second_token = query.token_next(0, skip_cm=True) - database_name = normalize_db_name(query.get_name()) - if ( - first_token.value.lower() == "drop" - and second_token.value.lower() in ("database", "schema") - and database_name == dbname - ): - return True - - -def need_completion_reset(queries): +def need_completion_reset(queries: str) -> bool: """Determines if the statement is a database switch such as 'use' or '\\u'. When a database is changed the existing completions must be reset before we start the completion refresh for the new database. @@ -1003,9 +1043,10 @@ def need_completion_reset(queries): return True except Exception: return False + return False -def is_mutating(status): +def is_mutating(status: str | None) -> bool: """Determines if the statement is mutating based on the status.""" if not status: return False @@ -1026,7 +1067,7 @@ def is_mutating(status): return status.split(None, 1)[0].lower() in mutating -def is_select(status): +def is_select(status: str | None) -> bool: """Returns true if the first word in status is 'select'.""" if not status: return False diff --git a/litecli/packages/completion_engine.py b/litecli/packages/completion_engine.py index 21ca81d..4083b4d 100644 --- a/litecli/packages/completion_engine.py +++ b/litecli/packages/completion_engine.py @@ -1,22 +1,14 @@ -from __future__ import print_function -import os -import sys -import sqlparse -from sqlparse.sql import Comparison, Identifier, Where -from sqlparse.compat import text_type -from .parseutils import last_word, extract_tables, find_prev_keyword -from .special import parse_special_command +from __future__ import annotations -PY2 = sys.version_info[0] == 2 -PY3 = sys.version_info[0] == 3 +from typing import Any, cast -if PY3: - string_types = str -else: - string_types = basestring +import sqlparse +from sqlparse.sql import Comparison, Identifier, Where, Token +from .parseutils import last_word, extract_tables, find_prev_keyword +from .special.main import parse_special_command -def suggest_type(full_text, text_before_cursor): +def suggest_type(full_text: str, text_before_cursor: str) -> list[dict[str, Any]]: """Takes the full_text that is typed so far and also the text before the cursor to suggest completion type and scope. @@ -26,7 +18,7 @@ def suggest_type(full_text, text_before_cursor): word_before_cursor = last_word(text_before_cursor, include="many_punctuations") - identifier = None + identifier: Identifier | None = None # here should be removed once sqlparse has been fixed try: @@ -61,7 +53,7 @@ def suggest_type(full_text, text_before_cursor): stmt_start, stmt_end = 0, 0 for statement in parsed: - stmt_len = len(text_type(statement)) + stmt_len = len(str(statement)) stmt_start, stmt_end = stmt_end, stmt_end + stmt_len if stmt_end >= current_pos: @@ -81,19 +73,21 @@ def suggest_type(full_text, text_before_cursor): # Be careful here because trivial whitespace is parsed as a statement, # but the statement won't have a first token tok1 = statement.token_first() - if tok1 and tok1.value in [".", "\\", "source"]: + if tok1 and tok1.value.startswith("."): + return suggest_special(text_before_cursor) + elif tok1 and tok1.value.startswith("\\"): + return suggest_special(text_before_cursor) + elif tok1 and tok1.value.startswith("source"): return suggest_special(text_before_cursor) elif text_before_cursor and text_before_cursor.startswith(".open "): return suggest_special(text_before_cursor) last_token = statement and statement.token_prev(len(statement.tokens))[1] or "" - return suggest_based_on_last_token( - last_token, text_before_cursor, full_text, identifier - ) + return suggest_based_on_last_token(last_token, text_before_cursor, full_text, identifier) -def suggest_special(text): +def suggest_special(text: str) -> list[dict[str, Any]]: text = text.lstrip() cmd, _, arg = parse_special_command(text) @@ -110,20 +104,52 @@ def suggest_special(text): if cmd in ["\\f", "\\fs", "\\fd"]: return [{"type": "favoritequery"}] - if cmd in ["\\d", "\\dt", "\\dt+", ".schema"]: + if cmd in ["\\d", "\\dt", "\\dt+", ".schema", ".indexes"]: return [ {"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}, ] - elif cmd in ["\\.", "source", ".open"]: + + if cmd in ["\\.", "source", ".open", ".read"]: return [{"type": "file_name"}] + if cmd in [".import"]: + # Usage: .import filename table + if _expecting_arg_idx(arg, text) == 1: + return [{"type": "file_name"}] + else: + return [{"type": "table", "schema": []}] + + if cmd in [".llm", ".ai", "\\llm", "\\ai"]: + return [{"type": "llm"}] + return [{"type": "keyword"}, {"type": "special"}] -def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier): - if isinstance(token, string_types): +def _expecting_arg_idx(arg: str, text: str) -> int: + """Return the index of expecting argument. + + >>> _expecting_arg_idx("./da", ".import ./da") + 1 + >>> _expecting_arg_idx("./data.csv", ".import ./data.csv") + 1 + >>> _expecting_arg_idx("./data.csv", ".import ./data.csv ") + 2 + >>> _expecting_arg_idx("./data.csv t", ".import ./data.csv t") + 2 + """ + args = arg.split() + return len(args) + int(text[-1].isspace()) + + +def suggest_based_on_last_token( + token: str | Token | None, + text_before_cursor: str, + full_text: str, + identifier: Identifier | None, +) -> list[dict[str, Any]]: + if isinstance(token, str): token_v = token.lower() elif isinstance(token, Comparison): # If 'token' is a Comparison type such as @@ -139,13 +165,15 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier # 'where foo > 5 and '. We need to look "inside" token.tokens to handle # suggestions in complicated where clauses correctly prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor) - return suggest_based_on_last_token( - prev_keyword, text_before_cursor, full_text, identifier - ) + return suggest_based_on_last_token(prev_keyword, text_before_cursor, full_text, identifier) else: + assert token is not None token_v = token.value.lower() - is_operand = lambda x: x and any([x.endswith(op) for op in ["+", "-", "*", "/"]]) + def is_operand(x: str | None) -> bool: + if not x: + return False + return any([x.endswith(op) for op in ["+", "-", "*", "/"]]) if not token: return [{"type": "keyword"}, {"type": "special"}] @@ -164,9 +192,7 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier # Suggest columns/functions AND keywords. (If we wanted to be # really fancy, we could suggest only array-typed columns) - column_suggestions = suggest_based_on_last_token( - "where", text_before_cursor, full_text, identifier - ) + column_suggestions = suggest_based_on_last_token("where", text_before_cursor, full_text, identifier) # Check for a subquery expression (cases 3 & 4) where = p.tokens[-1] @@ -191,7 +217,7 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier # suggest columns that are present in more than one table return [{"type": "column", "tables": tables, "drop_unique": True}] elif p.token_first().value.lower() == "select": - # If the lparen is preceeded by a space chances are we're about to + # If the lparen is preceded by a space chances are we're about to # do a sub-select. if last_word(text_before_cursor, "all_punctuations").startswith("("): return [{"type": "keyword"}] @@ -200,7 +226,7 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier # We're probably in a function argument list return [{"type": "column", "tables": extract_tables(full_text)}] - elif token_v in ("set", "by", "distinct"): + elif token_v in ("set", "order by", "distinct"): return [{"type": "column", "tables": extract_tables(full_text)}] elif token_v == "as": # Don't suggest anything for an alias @@ -217,7 +243,7 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier return [{"type": "user"}] elif token_v in ("select", "where", "having"): # Check for a table alias or schema qualification - parent = (identifier and identifier.get_parent_name()) or [] + parent = _get_parent_name(identifier) tables = extract_tables(full_text) if parent: @@ -236,11 +262,10 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier {"type": "alias", "aliases": aliases}, {"type": "keyword"}, ] - elif (token_v.endswith("join") and token.is_keyword) or ( - token_v - in ("copy", "from", "update", "into", "describe", "truncate", "desc", "explain") + elif (token_v.endswith("join") and isinstance(token, Token) and token.is_keyword) or ( + token_v in ("copy", "from", "update", "into", "describe", "truncate", "desc", "explain") ): - schema = (identifier and identifier.get_parent_name()) or [] + schema = _get_parent_name(identifier) # Suggest tables from either the currently-selected schema or the # public schema if no schema has been specified @@ -259,14 +284,14 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier elif token_v in ("table", "view", "function"): # E.g. 'DROP FUNCTION ', 'ALTER TABLE ' rel_type = token_v - schema = (identifier and identifier.get_parent_name()) or [] + schema = _get_parent_name(identifier) if schema: return [{"type": rel_type, "schema": schema}] else: return [{"type": "schema"}, {"type": rel_type, "schema": []}] elif token_v == "on": tables = extract_tables(full_text) # [(schema, table, alias), ...] - parent = (identifier and identifier.get_parent_name()) or [] + parent = _get_parent_name(identifier) if parent: # "ON parent." # parent can be either a schema name or table alias @@ -299,14 +324,20 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier elif token_v.endswith(",") or is_operand(token_v) or token_v in ["=", "and", "or"]: prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor) if prev_keyword: - return suggest_based_on_last_token( - prev_keyword, text_before_cursor, full_text, identifier - ) + return suggest_based_on_last_token(prev_keyword, text_before_cursor, full_text, identifier) else: return [] else: return [{"type": "keyword"}] -def identifies(id, schema, table, alias): - return id == alias or id == table or (schema and (id == schema + "." + table)) +def identifies(id: Any, schema: str | None, table: str, alias: str | None) -> bool: + return (id == alias) or (id == table) or (schema is not None and (id == schema + "." + table)) + + +def _get_parent_name(identifier: Identifier | None) -> str | list[str]: + if identifier is None: + return [] + + parent = identifier.get_parent_name() + return cast(str, parent) if parent else [] diff --git a/litecli/packages/filepaths.py b/litecli/packages/filepaths.py index 2f01046..fe71192 100644 --- a/litecli/packages/filepaths.py +++ b/litecli/packages/filepaths.py @@ -1,26 +1,25 @@ # -*- coding: utf-8 +from __future__ import annotations -from __future__ import unicode_literals -from litecli.encodingutils import text_type import os -def list_path(root_dir): +def list_path(root_dir: str) -> list[str]: """List directory if exists. :param dir: str :return: list """ - res = [] + res: list[str] = [] if os.path.isdir(root_dir): for name in os.listdir(root_dir): res.append(name) return res -def complete_path(curr_dir, last_dir): +def complete_path(curr_dir: str, last_dir: str) -> str | None: """Return the path to complete that matches the last entered component. If the last entered component is ~, expanded path would not @@ -35,9 +34,10 @@ def complete_path(curr_dir, last_dir): return curr_dir elif last_dir == "~": return os.path.join(last_dir, curr_dir) + return None -def parse_path(root_dir): +def parse_path(root_dir: str) -> tuple[str, str, int]: """Split path into head and last component for the completer. Also return position where last component starts. @@ -53,7 +53,7 @@ def parse_path(root_dir): return base_dir, last_dir, position -def suggest_path(root_dir): +def suggest_path(root_dir: str) -> list[str]: """List all files and subdirectories in a directory. If the directory is not specified, suggest root directory, @@ -64,10 +64,10 @@ def suggest_path(root_dir): """ if not root_dir: - return map(text_type, [os.path.abspath(os.sep), "~", os.curdir, os.pardir]) + return [str(x) for x in [os.path.abspath(os.sep), "~", os.curdir, os.pardir]] if "~" in root_dir: - root_dir = text_type(os.path.expanduser(root_dir)) + root_dir = str(os.path.expanduser(root_dir)) if not os.path.exists(root_dir): root_dir, _ = os.path.split(root_dir) @@ -75,7 +75,7 @@ def suggest_path(root_dir): return list_path(root_dir) -def dir_path_exists(path): +def dir_path_exists(path: str) -> bool: """Check if the directory path exists for a given file. For example, for a file /home/user/.cache/litecli/log, check if diff --git a/litecli/packages/parseutils.py b/litecli/packages/parseutils.py index 92fe365..1a5cd6d 100644 --- a/litecli/packages/parseutils.py +++ b/litecli/packages/parseutils.py @@ -1,10 +1,13 @@ -from __future__ import print_function +from __future__ import annotations + import re +from typing import Generator, Iterable, Literal + import sqlparse -from sqlparse.sql import IdentifierList, Identifier, Function -from sqlparse.tokens import Keyword, DML, Punctuation +from sqlparse.sql import Function, Identifier, IdentifierList, Token, TokenList +from sqlparse.tokens import DML, Keyword, Punctuation -cleanup_regex = { +cleanup_regex: dict[str, re.Pattern[str]] = { # This matches only alphanumerics and underscores. "alphanum_underscore": re.compile(r"(\w+)$"), # This matches everything except spaces, parens, colon, and comma @@ -12,12 +15,14 @@ # This matches everything except spaces, parens, colon, comma, and period "most_punctuations": re.compile(r"([^\.():,\s]+)$"), # This matches everything except a space. - "all_punctuations": re.compile("([^\s]+)$"), + "all_punctuations": re.compile(r"([^\s]+)$"), } +LAST_WORD_INCLUDE_TYPE = Literal["alphanum_underscore", "many_punctuations", "most_punctuations", "all_punctuations"] -def last_word(text, include="alphanum_underscore"): - """ + +def last_word(text: str, include: LAST_WORD_INCLUDE_TYPE = "alphanum_underscore") -> str: + R""" Find the last word in a sentence. >>> last_word('abc') @@ -41,9 +46,9 @@ def last_word(text, include="alphanum_underscore"): >>> last_word('bac $def', include='most_punctuations') '$def' >>> last_word('bac \def', include='most_punctuations') - '\\\\def' + '\\def' >>> last_word('bac \def;', include='most_punctuations') - '\\\\def;' + '\\def;' >>> last_word('bac::def', include='most_punctuations') 'def' """ @@ -63,8 +68,7 @@ def last_word(text, include="alphanum_underscore"): # This code is borrowed from sqlparse example script. -# -def is_subselect(parsed): +def is_subselect(parsed: TokenList) -> bool: if not parsed.is_group: return False for item in parsed.tokens: @@ -79,7 +83,7 @@ def is_subselect(parsed): return False -def extract_from_part(parsed, stop_at_punctuation=True): +def extract_from_part(parsed: TokenList, stop_at_punctuation: bool = True) -> Generator[Token, None, None]: tbl_prefix_seen = False for item in parsed.tokens: if tbl_prefix_seen: @@ -96,17 +100,18 @@ def extract_from_part(parsed, stop_at_punctuation=True): # Also 'SELECT * FROM abc JOIN def' will trigger this elif # condition. So we need to ignore the keyword JOIN and its variants # INNER JOIN, FULL OUTER JOIN, etc. - elif ( - item.ttype is Keyword - and (not item.value.upper() == "FROM") - and (not item.value.upper().endswith("JOIN")) - ): + elif item.ttype is Keyword and (not item.value.upper() == "FROM") and (not item.value.upper().endswith("JOIN")): return else: yield item - elif ( - item.ttype is Keyword or item.ttype is Keyword.DML - ) and item.value.upper() in ("COPY", "FROM", "INTO", "UPDATE", "TABLE", "JOIN"): + elif (item.ttype is Keyword or item.ttype is Keyword.DML) and item.value.upper() in ( + "COPY", + "FROM", + "INTO", + "UPDATE", + "TABLE", + "JOIN", + ): tbl_prefix_seen = True # 'SELECT a, FROM abc' will detect FROM as part of the column list. # So this check here is necessary. @@ -117,8 +122,8 @@ def extract_from_part(parsed, stop_at_punctuation=True): break -def extract_table_identifiers(token_stream): - """yields tuples of (schema_name, table_name, table_alias)""" +def extract_table_identifiers(token_stream: Iterable[Token]) -> Generator[tuple[str | None, str, str | None], None, None]: + """Yield tuples of (schema_name, table_name, table_alias).""" for item in token_stream: if isinstance(item, IdentifierList): @@ -146,8 +151,8 @@ def extract_table_identifiers(token_stream): # extract_tables is inspired from examples in the sqlparse lib. -def extract_tables(sql): - """Extract the table names from an SQL statment. +def extract_tables(sql: str) -> list[tuple[str | None, str, str | None]]: + """Extract the table names from an SQL statement. Returns a list of (schema, table, alias) tuples @@ -165,8 +170,8 @@ def extract_tables(sql): return list(extract_table_identifiers(stream)) -def find_prev_keyword(sql): - """ Find the last sql keyword in an SQL statement +def find_prev_keyword(sql: str) -> tuple[Token | None, str]: + """Find the last sql keyword in an SQL statement Returns the value of the last keyword, and the text of the query with everything after the last keyword stripped @@ -180,9 +185,7 @@ def find_prev_keyword(sql): logical_operators = ("AND", "OR", "NOT", "BETWEEN") for t in reversed(flattened): - if t.value == "(" or ( - t.is_keyword and (t.value.upper() not in logical_operators) - ): + if t.value == "(" or (t.is_keyword and (t.value.upper() not in logical_operators)): # Find the location of token t in the original parsed statement # We can't use parsed.token_index(t) because t may be a child token # inside a TokenList, in which case token_index thows an error @@ -201,14 +204,14 @@ def find_prev_keyword(sql): return None, "" -def query_starts_with(query, prefixes): +def query_starts_with(query: str, prefixes: Iterable[str]) -> bool: """Check if the query starts with any item from *prefixes*.""" prefixes = [prefix.lower() for prefix in prefixes] formatted_sql = sqlparse.format(query.lower(), strip_comments=True) return bool(formatted_sql) and formatted_sql.split()[0] in prefixes -def queries_start_with(queries, prefixes): +def queries_start_with(queries: str, prefixes: Iterable[str]) -> bool: """Check if any queries start with any item from *prefixes*.""" for query in sqlparse.split(queries): if query and query_starts_with(query, prefixes) is True: @@ -216,7 +219,7 @@ def queries_start_with(queries, prefixes): return False -def is_destructive(queries): +def is_destructive(queries: str) -> bool: """Returns if any of the queries in *queries* is destructive.""" keywords = ("drop", "shutdown", "delete", "truncate", "alter") return queries_start_with(queries, keywords) diff --git a/litecli/packages/prompt_utils.py b/litecli/packages/prompt_utils.py index d9ad2b6..22d2318 100644 --- a/litecli/packages/prompt_utils.py +++ b/litecli/packages/prompt_utils.py @@ -1,38 +1,57 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals - +from __future__ import annotations import sys + import click + from .parseutils import is_destructive +from typing import Any -def confirm_destructive_query(queries): - """Check if the query is destructive and prompts the user to confirm. +class ConfirmBoolParamType(click.ParamType): + name = "confirmation" - Returns: - * None if the query is non-destructive or we can't prompt the user. - * True if the query is destructive and the user wants to proceed. - * False if the query is destructive and the user doesn't want to proceed. + def convert(self, value: bool | str, param: click.Parameter | None, ctx: click.Context | None) -> bool: + if isinstance(value, bool): + return value + value = value.lower() + if value in ("yes", "y"): + return True + if value in ("no", "n"): + return False + self.fail(f"{value} is not a valid boolean", param, ctx) + + def __repr__(self) -> str: + return "BOOL" + +BOOLEAN_TYPE = ConfirmBoolParamType() + + +def confirm_destructive_query(queries: str) -> bool | None: + """Check if the query is destructive and prompt to confirm. + + Returns: + - None: non-destructive or cannot prompt (non-tty). + - True: destructive and user consents. + - False: destructive and user declines. """ - prompt_text = ( - "You're about to run a destructive command.\n" "Do you want to proceed? (y/n)" - ) + prompt_text = "You're about to run a destructive command.\nDo you want to proceed? (y/n)" if is_destructive(queries) and sys.stdin.isatty(): - return prompt(prompt_text, type=bool) + return bool(prompt(prompt_text, type=BOOLEAN_TYPE)) + return None -def confirm(*args, **kwargs): - """Prompt for confirmation (yes/no) and handle any abort exceptions.""" +def confirm(*args: Any, **kwargs: Any) -> bool: + """Prompt for confirmation (yes/no) and handle aborts.""" try: return click.confirm(*args, **kwargs) except click.Abort: return False -def prompt(*args, **kwargs): - """Prompt the user for input and handle any abort exceptions.""" +def prompt(*args: Any, **kwargs: Any) -> Any: + """Prompt the user for input and handle aborts. Returns the value from click.prompt.""" try: return click.prompt(*args, **kwargs) except click.Abort: diff --git a/litecli/packages/special/__init__.py b/litecli/packages/special/__init__.py index fd2b18c..5eddb63 100644 --- a/litecli/packages/special/__init__.py +++ b/litecli/packages/special/__init__.py @@ -1,12 +1,45 @@ -__all__ = [] +# ruff: noqa +from __future__ import annotations +from types import FunctionType -def export(defn): +from typing import TypeVar + +__all__: list[str] = [] + +_Exported = TypeVar("_Exported") + + +def export(defn: _Exported) -> _Exported: """Decorator to explicitly mark functions that are exposed in a lib.""" - globals()[defn.__name__] = defn - __all__.append(defn.__name__) + # ty requires an explicit callable/type check to access __name__. + if isinstance(defn, (type, FunctionType)): + globals()[defn.__name__] = defn + __all__.append(defn.__name__) return defn from . import dbcommands from . import iocommands +from . import llm +from . import utils +from .main import CommandNotFound, register_special_command, execute +from .iocommands import ( + set_favorite_queries, + editor_command, + get_filename, + get_editor_query, + open_external_editor, + is_expanded_output, + set_expanded_output, + write_tee, + unset_once_if_written, + unset_pipe_once_if_written, + disable_pager, + set_pager, + is_pager_enabled, + write_once, + write_pipe_once, + close_tee, +) +from .llm import is_llm_command, handle_llm, FinishIteration diff --git a/litecli/packages/special/dbcommands.py b/litecli/packages/special/dbcommands.py index 9d5d84e..5391e8a 100644 --- a/litecli/packages/special/dbcommands.py +++ b/litecli/packages/special/dbcommands.py @@ -1,14 +1,18 @@ -from __future__ import unicode_literals +from __future__ import annotations + +import csv import logging import os +import sys import platform import shlex -from sqlite3 import ProgrammingError +from typing import Any, cast + from litecli import __version__ from litecli.packages.special import iocommands -from litecli.packages.special.utils import format_uptime -from .main import special_command, RAW_QUERY, PARSED_QUERY, ArgumentMissing +from .main import special_command, RAW_QUERY, PARSED_QUERY +from .types import DBCursor log = logging.getLogger(__name__) @@ -21,19 +25,24 @@ case_sensitive=True, aliases=("\\dt",), ) -def list_tables(cur, arg=None, arg_type=PARSED_QUERY, verbose=False): +def list_tables( + cur: DBCursor, + arg: str | None = None, + arg_type: int = PARSED_QUERY, + verbose: bool = False, +) -> list[tuple]: if arg: - args = ("{0}%".format(arg),) + args: tuple[str, ...] = ("{0}%".format(arg),) query = """ SELECT name FROM sqlite_master - WHERE type IN ('table','view') AND name LIKE ? AND name NOT LIKE 'sqlite_%' + WHERE type IN ('table','view') AND name LIKE ? AND name NOT LIKE 'sqlite_%' AND name NOT LIKE 'sqlean_%' ORDER BY 1 """ else: args = tuple() query = """ SELECT name FROM sqlite_master - WHERE type IN ('table','view') AND name NOT LIKE 'sqlite_%' + WHERE type IN ('table','view') AND name NOT LIKE 'sqlite_%' AND name NOT LIKE 'sqlean_%' ORDER BY 1 """ @@ -55,6 +64,45 @@ def list_tables(cur, arg=None, arg_type=PARSED_QUERY, verbose=False): return [(None, tables, headers, status)] +@special_command( + ".views", + "\\dv", + "List views.", + arg_type=PARSED_QUERY, + case_sensitive=True, + aliases=("\\dv",), +) +def list_views( + cur: DBCursor, + arg: str | None = None, + arg_type: int = PARSED_QUERY, + verbose: bool = False, +) -> list[tuple]: + if arg: + args: tuple[str, ...] = ("{0}%".format(arg),) + query = """ + SELECT name FROM sqlite_master + WHERE type = 'view' AND name LIKE ? AND name NOT LIKE 'sqlite_%' AND name NOT LIKE 'sqlean_%' + ORDER BY 1 + """ + else: + args = tuple() + query = """ + SELECT name FROM sqlite_master + WHERE type = 'view' AND name NOT LIKE 'sqlite_%' AND name NOT LIKE 'sqlean_%' + ORDER BY 1 + """ + log.debug(query) + cur.execute(query, args) + views = cur.fetchall() + status = "" + if cur.description: + headers = [x[0] for x in cur.description] + else: + return [(None, None, None, "")] + return [(None, views, headers, status)] + + @special_command( ".schema", ".schema[+] [table]", @@ -62,18 +110,19 @@ def list_tables(cur, arg=None, arg_type=PARSED_QUERY, verbose=False): arg_type=PARSED_QUERY, case_sensitive=True, ) -def show_schema(cur, arg=None, **_): +def show_schema(cur: DBCursor, arg: str | None = None, **_: Any) -> list[tuple]: if arg: - args = (arg,) + args: tuple[str, ...] = (arg,) query = """ SELECT sql FROM sqlite_master - WHERE name==? + WHERE tbl_name==? AND sql IS NOT NULL ORDER BY tbl_name, type DESC, name """ else: args = tuple() query = """ SELECT sql FROM sqlite_master + WHERE sql IS NOT NULL ORDER BY tbl_name, type DESC, name """ @@ -97,7 +146,7 @@ def show_schema(cur, arg=None, **_): case_sensitive=True, aliases=("\\l",), ) -def list_databases(cur, **_): +def list_databases(cur: DBCursor, **_: Any) -> list[tuple]: query = "PRAGMA database_list" log.debug(query) cur.execute(query) @@ -108,6 +157,46 @@ def list_databases(cur, **_): return [(None, None, None, "")] +@special_command( + ".indexes", + ".indexes [tablename]", + "List indexes.", + arg_type=PARSED_QUERY, + case_sensitive=True, + aliases=("\\di",), +) +def list_indexes( + cur: DBCursor, + arg: str | None = None, + arg_type: int = PARSED_QUERY, + verbose: bool = False, +) -> list[tuple]: + if arg: + args: tuple[str, ...] = ("{0}%".format(arg),) + query = """ + SELECT name, sql FROM sqlite_master + WHERE type = 'index' AND tbl_name LIKE ? AND name NOT LIKE 'sqlite_%' + ORDER BY 1 + """ + else: + args = tuple() + query = """ + SELECT name, sql FROM sqlite_master + WHERE type = 'index' AND name NOT LIKE 'sqlite_%' + ORDER BY 1 + """ + + log.debug(query) + cur.execute(query, args) + indexes = cur.fetchall() + status = "" + if cur.description: + headers = [x[0] for x in cur.description] + else: + return [(None, None, None, "")] + return [(None, indexes, headers, status)] + + @special_command( ".status", "\\s", @@ -116,7 +205,7 @@ def list_databases(cur, **_): aliases=("\\s",), case_sensitive=True, ) -def status(cur, **_): +def status(cur: DBCursor, **_: Any) -> list[tuple]: # Create output buffers. footer = [] footer.append("--------------") @@ -133,9 +222,8 @@ def status(cur, **_): query = "SELECT file from pragma_database_list() where name = 'main';" log.debug(query) cur.execute(query) - db = cur.fetchone()[0] - if db is None: - db = "" + row = cur.fetchone() + db = row[0] if row else "" footer.append("Current database: " + db) if iocommands.is_pager_enabled(): @@ -158,7 +246,7 @@ def status(cur, **_): arg_type=PARSED_QUERY, case_sensitive=True, ) -def load_extension(cur, arg, **_): +def load_extension(cur: DBCursor, arg: str, **_: Any) -> list[tuple]: args = shlex.split(arg) if len(args) != 1: raise TypeError(".load accepts exactly one path") @@ -175,18 +263,15 @@ def load_extension(cur, arg, **_): "Description of a table", arg_type=PARSED_QUERY, case_sensitive=True, - aliases=("\\d", "describe", "desc"), + aliases=("\\d", "desc"), ) -def describe(cur, arg, **_): +def describe(cur: DBCursor, arg: str | None, **_: Any) -> list[tuple]: if arg: - args = (arg,) query = """ PRAGMA table_info({}) - """.format( - arg - ) + """.format(arg) else: - raise ArgumentMissing("Table name required.") + return cast(list[tuple[Any, ...]], list_tables(cur)) log.debug(query) cur.execute(query) @@ -201,18 +286,54 @@ def describe(cur, arg, **_): @special_command( - ".read", - ".read path", - "Read input from path", + ".import", + ".import filename table", + "Import data from filename into an existing table", arg_type=PARSED_QUERY, case_sensitive=True, ) -def read_script(cur, arg, **_): - args = shlex.split(arg) - if len(args) != 1: - raise TypeError(".read accepts exactly one path") - path = args[0] - with open(path, "r") as f: - script = f.read() - cur.executescript(script) - return [(None, None, None, "")] +def import_file(cur: DBCursor, arg: str | None = None, **_: Any) -> list[tuple]: + def split(s: str) -> list[str]: + # this is a modification of shlex.split function, just to make it support '`', + # because table name might contain '`' character. + lex = shlex.shlex(s, posix=True) + lex.whitespace_split = True + lex.commenters = "" + lex.quotes += "`" + return list(lex) + + if arg is None: + raise TypeError("Usage: .import filename table") + args = split(arg) + log.debug("[arg = %r], [args = %r]", arg, args) + if len(args) != 2: + raise TypeError("Usage: .import filename table") + + filename, table = args + cur.execute('PRAGMA table_info("%s")' % table) + ncols = len(cur.fetchall()) + insert_tmpl = 'INSERT INTO "%s" VALUES (?%s)' % (table, ",?" * (ncols - 1)) + + with open(filename, "r") as csvfile: + dialect = csv.Sniffer().sniff(csvfile.read(1024)) + csvfile.seek(0) + reader = csv.reader(csvfile, dialect) + + cur.execute("BEGIN") + ninserted, nignored = 0, 0 + for i, row in enumerate(reader): + if len(row) != ncols: + print( + "%s:%d expected %d columns but found %d - ignored" % (filename, i, ncols, len(row)), + file=sys.stderr, + ) + nignored += 1 + continue + cur.execute(insert_tmpl, row) + ninserted += 1 + cur.execute("COMMIT") + + status = "Inserted %d rows into %s" % (ninserted, table) + if nignored > 0: + status += " (%d rows are ignored)" % nignored + return [(None, None, None, status)] diff --git a/litecli/packages/special/favoritequeries.py b/litecli/packages/special/favoritequeries.py index 7da6fbf..3dd2a89 100644 --- a/litecli/packages/special/favoritequeries.py +++ b/litecli/packages/special/favoritequeries.py @@ -1,10 +1,12 @@ # -*- coding: utf-8 -*- -from __future__ import unicode_literals +from __future__ import annotations +import builtins +from typing import Any, cast -class FavoriteQueries(object): - section_name = "favorite_queries" +class FavoriteQueries(object): + section_name: str = "favorite_queries" usage = """ Favorite Queries are a way to save frequently used queries @@ -35,22 +37,24 @@ class FavoriteQueries(object): simple: Deleted """ - def __init__(self, config): + def __init__(self, config: Any) -> None: self.config = config - def list(self): - return self.config.get(self.section_name, []) + def list(self) -> builtins.list[str]: + section = cast(dict[str, str], self.config.get(self.section_name, {})) + return list(section.keys()) - def get(self, name): - return self.config.get(self.section_name, {}).get(name, None) + def get(self, name: str) -> str | None: + section = cast(dict[str, str], self.config.get(self.section_name, {})) + return section.get(name) - def save(self, name, query): + def save(self, name: str, query: str) -> None: if self.section_name not in self.config: self.config[self.section_name] = {} self.config[self.section_name][name] = query self.config.write() - def delete(self, name): + def delete(self, name: str) -> str: try: del self.config[self.section_name][name] except KeyError: diff --git a/litecli/packages/special/iocommands.py b/litecli/packages/special/iocommands.py index a9036aa..434a96f 100644 --- a/litecli/packages/special/iocommands.py +++ b/litecli/packages/special/iocommands.py @@ -1,44 +1,51 @@ -from __future__ import unicode_literals -import os -import re +from __future__ import annotations + import locale import logging -import subprocess +import os +import re import shlex +import subprocess from io import open from time import sleep +from typing import Any, Generator, TextIO import click import sqlparse from configobj import ConfigObj +from ..prompt_utils import confirm_destructive_query from . import export -from .main import special_command, NO_QUERY, PARSED_QUERY from .favoritequeries import FavoriteQueries +from .main import NO_QUERY, PARSED_QUERY, special_command from .utils import handle_cd_command -from litecli.packages.prompt_utils import confirm_destructive_query -use_expanded_output = False -PAGER_ENABLED = True -tee_file = None -once_file = written_to_once_file = None -favoritequeries = FavoriteQueries(ConfigObj()) +use_expanded_output: bool = False +PAGER_ENABLED: bool = True +tee_file: TextIO | None = None +once_file: TextIO | None = None +written_to_once_file: bool = False +pipe_once_process: subprocess.Popen[str] | None = None +written_to_pipe_once_process: bool = False +favoritequeries: FavoriteQueries = FavoriteQueries(ConfigObj()) + +log = logging.getLogger(__name__) @export -def set_favorite_queries(config): +def set_favorite_queries(config: Any) -> None: global favoritequeries favoritequeries = FavoriteQueries(config) @export -def set_pager_enabled(val): +def set_pager_enabled(val: bool) -> None: global PAGER_ENABLED PAGER_ENABLED = val @export -def is_pager_enabled(): +def is_pager_enabled() -> bool: return PAGER_ENABLED @@ -51,7 +58,7 @@ def is_pager_enabled(): aliases=("\\P",), case_sensitive=True, ) -def set_pager(arg, **_): +def set_pager(arg: str, **_: Any) -> list[tuple]: if arg: os.environ["PAGER"] = arg msg = "PAGER set to %s." % arg @@ -76,27 +83,24 @@ def set_pager(arg, **_): aliases=("\\n",), case_sensitive=True, ) -def disable_pager(): +def disable_pager() -> list[tuple]: set_pager_enabled(False) return [(None, None, None, "Pager disabled.")] @export -def set_expanded_output(val): +def set_expanded_output(val: bool) -> None: global use_expanded_output use_expanded_output = val @export -def is_expanded_output(): +def is_expanded_output() -> bool: return use_expanded_output -_logger = logging.getLogger(__name__) - - @export -def editor_command(command): +def editor_command(command: str) -> bool: """ Is this an external editor command? :param command: string @@ -107,21 +111,22 @@ def editor_command(command): @export -def get_filename(sql): +def get_filename(sql: str) -> str | None: if sql.strip().startswith("\\e"): - command, _, filename = sql.partition(" ") + _cmd, _sep, filename = sql.partition(" ") return filename.strip() or None + return None @export -def get_editor_query(sql): +def get_editor_query(sql: str) -> str: """Get the query part of an editor command.""" sql = sql.strip() # The reason we can't simply do .strip('\e') is that it strips characters, # not a substring. So it'll strip "e" in the end of the sql also! # Ex: "select * from style\e" -> "select * from styl". - pattern = re.compile("(^\\\e|\\\e$)") + pattern = re.compile(r"(^\\e|\\e$)") while pattern.search(sql): sql = pattern.sub("", sql) @@ -129,43 +134,29 @@ def get_editor_query(sql): @export -def open_external_editor(filename=None, sql=None): - """Open external editor, wait for the user to type in their query, return - the query. - - :return: list with one tuple, query as first element. - - """ - - message = None - filename = filename.strip().split(" ", 1)[0] if filename else None - +def open_external_editor(filename: str | None = None, sql: str | None = None) -> tuple[str, str | None]: + """Open external editor, wait for the user to type in their query, return the query.""" + message: str | None = None sql = sql or "" MARKER = "# Type your query above this line.\n" - # Populate the editor buffer with the partial sql (if available) and a - # placeholder comment. - query = click.edit( - "{sql}\n\n{marker}".format(sql=sql, marker=MARKER), - filename=filename, - extension=".sql", - ) - if filename: + filename = filename.strip().split(" ", 1)[0] + click.edit(filename=filename) try: with open(filename, encoding="utf-8") as f: - query = f.read() + text = f.read() except IOError: - message = "Error reading file: %s." % filename + message = f"Error reading file: {filename}." + text = sql + return (text, message) - if query is not None: - query = query.split(MARKER, 1)[0].rstrip("\n") + edited = click.edit(f"{sql}\n\n{MARKER}", extension=".sql") + if edited: + edited = edited.split(MARKER, 1)[0].rstrip("\n") else: - # Don't return None for the caller to deal with. - # Empty string is ok. - query = sql - - return (query, message) + edited = sql + return (edited, None) @special_command( @@ -175,14 +166,14 @@ def open_external_editor(filename=None, sql=None): arg_type=PARSED_QUERY, case_sensitive=True, ) -def execute_favorite_query(cur, arg, **_): +def execute_favorite_query(cur: Any, arg: str, verbose: bool = False, **_: Any) -> Generator[tuple, None, None]: """Returns (title, rows, headers, status)""" if arg == "": for result in list_favorite_queries(): yield result """Parse out favorite name and optional substitution parameters""" - name, _, arg_str = arg.partition(" ") + name, _sep, arg_str = arg.partition(" ") args = shlex.split(arg_str) query = favoritequeries.get(name) @@ -192,7 +183,7 @@ def execute_favorite_query(cur, arg, **_): elif "?" in query: for sql in sqlparse.split(query): sql = sql.rstrip(";") - title = "> %s" % (sql) + title = "> %s" % (sql) if verbose else None cur.execute(sql, args) if cur.description: headers = [x[0] for x in cur.description] @@ -204,9 +195,10 @@ def execute_favorite_query(cur, arg, **_): if arg_error: yield (None, None, None, arg_error) else: + assert query, "query should be non-empty" for sql in sqlparse.split(query): sql = sql.rstrip(";") - title = "> %s" % (sql) + title = "> %s" % (sql) if verbose else None cur.execute(sql) if cur.description: headers = [x[0] for x in cur.description] @@ -215,7 +207,7 @@ def execute_favorite_query(cur, arg, **_): yield (title, None, None, None) -def list_favorite_queries(): +def list_favorite_queries() -> list[tuple]: """List of all favorite queries. Returns (title, rows, headers, status)""" @@ -229,8 +221,8 @@ def list_favorite_queries(): return [("", rows, headers, status)] -def subst_favorite_query_args(query, args): - """replace positional parameters ($1...$N) in query.""" +def subst_favorite_query_args(query: str, args: list[str]) -> list[str | None]: + """Replace positional parameters ($1...$N or ?) in query.""" for idx, val in enumerate(args): shell_subst_var = "$" + str(idx + 1) question_subst_var = "?" @@ -241,11 +233,10 @@ def subst_favorite_query_args(query, args): else: return [ None, - "Too many arguments.\nQuery does not have enough place holders to substitute.\n" - + query, + "Too many arguments.\nQuery does not have enough place holders to substitute.\n" + query, ] - match = re.search("\\?|\\$\d+", query) + match = re.search(r"\?|\$\d+", query) if match: return [ None, @@ -256,7 +247,7 @@ def subst_favorite_query_args(query, args): @special_command("\\fs", "\\fs name query", "Save a favorite query.") -def save_favorite_query(arg, **_): +def save_favorite_query(arg: str, **_: Any) -> list[tuple]: """Save a new favorite query. Returns (title, rows, headers, status)""" @@ -264,7 +255,7 @@ def save_favorite_query(arg, **_): if not arg: return [(None, None, None, usage)] - name, _, query = arg.partition(" ") + name, _sep, query = arg.partition(" ") # If either name or query is missing then print the usage and complain. if (not name) or (not query): @@ -275,9 +266,8 @@ def save_favorite_query(arg, **_): @special_command("\\fd", "\\fd [name]", "Delete a favorite query.") -def delete_favorite_query(arg, **_): - """Delete an existing favorite query. - """ +def delete_favorite_query(arg: str, **_: Any) -> list[tuple]: + """Delete an existing favorite query.""" usage = "Syntax: \\fd name.\n\n" + favoritequeries.usage if not arg: return [(None, None, None, usage)] @@ -287,8 +277,8 @@ def delete_favorite_query(arg, **_): return [(None, None, None, status)] -@special_command("system", "system [command]", "Execute a system shell commmand.") -def execute_system_command(arg, **_): +@special_command("system", "system [command]", "Execute a system shell command.") +def execute_system_command(arg: str, **_: Any) -> list[tuple]: """Execute a system shell command.""" usage = "Syntax: system [command].\n" @@ -306,19 +296,17 @@ def execute_system_command(arg, **_): args = arg.split(" ") process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) output, error = process.communicate() - response = output if not error else error - + raw = output if not error else error # Python 3 returns bytes. This needs to be decoded to a string. - if isinstance(response, bytes): - encoding = locale.getpreferredencoding(False) - response = response.decode(encoding) + encoding = locale.getpreferredencoding(False) + response: str = raw.decode(encoding) if isinstance(raw, bytes) else str(raw) return [(None, None, None, response)] except OSError as e: return [(None, None, None, "OSError: %s" % e.strerror)] -def parseargfile(arg): +def parseargfile(arg: str) -> tuple[str, str]: if arg.startswith("-o "): mode = "w" filename = arg[3:] @@ -329,19 +317,23 @@ def parseargfile(arg): if not filename: raise TypeError("You must provide a filename.") - return {"file": os.path.expanduser(filename), "mode": mode} + return (os.path.expanduser(filename), mode) @special_command( - "tee", - "tee [-o] filename", + ".output", + ".output [-o] filename", "Append all results to an output file (overwrite using -o).", + aliases=("tee",), ) -def set_tee(arg, **_): +def set_tee(arg: str, **_: Any) -> list[tuple]: global tee_file try: - tee_file = open(**parseargfile(arg)) + file, mode = parseargfile(arg) + from typing import cast + + tee_file = cast(TextIO, open(file, mode)) except (IOError, OSError) as e: raise OSError("Cannot write to file '{}': {}".format(e.filename, e.strerror)) @@ -349,7 +341,7 @@ def set_tee(arg, **_): @export -def close_tee(): +def close_tee() -> None: global tee_file if tee_file: tee_file.close() @@ -357,13 +349,13 @@ def close_tee(): @special_command("notee", "notee", "Stop writing results to an output file.") -def no_tee(arg, **_): +def no_tee(arg: str, **_: Any) -> list[tuple]: close_tee() return [(None, None, None, "")] @export -def write_tee(output): +def write_tee(output: str) -> None: global tee_file if tee_file: click.echo(output, file=tee_file, nl=False) @@ -377,38 +369,89 @@ def write_tee(output): "Append next result to an output file (overwrite using -o).", aliases=("\\o", "\\once"), ) -def set_once(arg, **_): - global once_file +def set_once(arg: str, **_: Any) -> list[tuple]: + global once_file, written_to_once_file + try: + file, mode = parseargfile(arg) + from typing import cast - once_file = parseargfile(arg) + once_file = cast(TextIO, open(file, mode)) + except (IOError, OSError) as e: + raise OSError("Cannot write to file '{}': {}".format(e.filename, e.strerror)) + written_to_once_file = False return [(None, None, None, "")] @export -def write_once(output): +def write_once(output: str) -> None: global once_file, written_to_once_file if output and once_file: - try: - f = open(**once_file) - except (IOError, OSError) as e: - once_file = None - raise OSError( - "Cannot write to file '{}': {}".format(e.filename, e.strerror) - ) - - with f: - click.echo(output, file=f, nl=False) - click.echo("\n", file=f, nl=False) + click.echo(output, file=once_file, nl=False) + click.echo("\n", file=once_file, nl=False) + once_file.flush() written_to_once_file = True @export -def unset_once_if_written(): +def unset_once_if_written() -> None: """Unset the once file, if it has been written to.""" - global once_file - if written_to_once_file: + global once_file, written_to_once_file + if once_file and written_to_once_file: + once_file.close() once_file = None + written_to_once_file = False + + +@special_command( + "\\pipe_once", + "\\| command", + "Send next result to a subprocess.", + aliases=("\\|",), +) +def set_pipe_once(arg: str, **_: Any) -> list[tuple]: + global pipe_once_process, written_to_pipe_once_process + pipe_once_cmd = shlex.split(arg) + if len(pipe_once_cmd) == 0: + raise OSError("pipe_once requires a command") + written_to_pipe_once_process = False + pipe_once_process = subprocess.Popen( + pipe_once_cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + bufsize=1, + encoding="UTF-8", + universal_newlines=True, + ) + return [(None, None, None, "")] + + +@export +def write_pipe_once(output: str) -> None: + global pipe_once_process, written_to_pipe_once_process + if output and pipe_once_process: + try: + click.echo(output, file=pipe_once_process.stdin, nl=False) + click.echo("\n", file=pipe_once_process.stdin, nl=False) + except (IOError, OSError) as e: + pipe_once_process.terminate() + raise OSError("Failed writing to pipe_once subprocess: {}".format(e.strerror)) + written_to_pipe_once_process = True + + +@export +def unset_pipe_once_if_written() -> None: + """Unset the pipe_once cmd, if it has been written to.""" + global pipe_once_process, written_to_pipe_once_process + if written_to_pipe_once_process and pipe_once_process: + (stdout_data, stderr_data) = pipe_once_process.communicate() + if len(stdout_data) > 0: + print(stdout_data.rstrip("\n")) + if len(stderr_data) > 0: + print(stderr_data.rstrip("\n")) + pipe_once_process = None + written_to_pipe_once_process = False @special_command( @@ -416,7 +459,7 @@ def unset_once_if_written(): "watch [seconds] [-c] query", "Executes the query every [seconds] seconds (by default 5).", ) -def watch_query(arg, **kwargs): +def watch_query(arg: str, **kwargs: Any) -> Generator[tuple, None, None]: usage = """Syntax: watch [seconds] [-c] query. * seconds: The interval at the query will be repeated, in seconds. By default 5. @@ -425,7 +468,7 @@ def watch_query(arg, **kwargs): if not arg: yield (None, None, None, usage) raise StopIteration - seconds = 5 + seconds: float = 5.0 clear_screen = False statement = None while statement is None: @@ -451,9 +494,7 @@ def watch_query(arg, **kwargs): elif destructive_prompt is True: click.secho("Your call!") cur = kwargs["cur"] - sql_list = [ - (sql.rstrip(";"), "> {0!s}".format(sql)) for sql in sqlparse.split(statement) - ] + sql_list = [(sql.rstrip(";"), "> {0!s}".format(sql)) for sql in sqlparse.split(statement)] old_pager_enabled = is_pager_enabled() while True: if clear_screen: @@ -462,7 +503,7 @@ def watch_query(arg, **kwargs): # Somewhere in the code the pager its activated after every yield, # so we disable it in every iteration set_pager_enabled(False) - for (sql, title) in sql_list: + for sql, title in sql_list: cur.execute(sql) if cur.description: headers = [x[0] for x in cur.description] diff --git a/litecli/packages/special/llm.py b/litecli/packages/special/llm.py new file mode 100644 index 0000000..e4391d7 --- /dev/null +++ b/litecli/packages/special/llm.py @@ -0,0 +1,443 @@ +from __future__ import annotations + +import contextlib +import importlib +import io +import logging +import os +import pprint +import re +import shlex +import sys +from runpy import run_module +from time import time +from typing import Any + +import click + +from . import export +from .main import Verbosity, parse_special_command +from .types import DBCursor + + +def _load_llm_module() -> Any | None: + try: + return importlib.import_module("llm") + except ImportError: + return None + + +def _load_llm_cli_module() -> Any | None: + try: + return importlib.import_module("llm.cli") + except ImportError: + return None + + +llm_module = _load_llm_module() +llm_cli_module = _load_llm_cli_module() + +# Alias for tests and patching. +llm = llm_module + +LLM_IMPORTED = llm_module is not None + +cli: click.Command | None +if llm_cli_module is not None: + llm_cli = getattr(llm_cli_module, "cli", None) + cli = llm_cli if isinstance(llm_cli, click.Command) else None +else: + cli = None + +LLM_CLI_IMPORTED = cli is not None + +log = logging.getLogger(__name__) + +LLM_TEMPLATE_NAME = "litecli-llm-template" +LLM_CLI_COMMANDS: list[str] = list(cli.commands.keys()) if isinstance(cli, click.Group) else [] +# Mapping of model_id to None used for completion tree leaves. +if llm_module is not None: + get_models = getattr(llm_module, "get_models", None) + MODELS: dict[str, None] = {x.model_id: None for x in get_models()} if callable(get_models) else {} +else: + MODELS = {} + + +def run_external_cmd( + cmd: str, + *args: str, + capture_output: bool = False, + restart_cli: bool = False, + raise_exception: bool = True, +) -> tuple[int, str]: + original_exe = sys.executable + original_args = sys.argv + + try: + sys.argv = [cmd] + list(args) + code: int = 0 + + if capture_output: + buffer = io.StringIO() + stack = contextlib.ExitStack() + stack.enter_context(contextlib.redirect_stdout(buffer)) + stack.enter_context(contextlib.redirect_stderr(buffer)) + redirect: contextlib.AbstractContextManager[Any] = stack + else: + redirect = contextlib.nullcontext() + + with redirect: + try: + run_module(cmd, run_name="__main__") + except SystemExit as e: + exit_code = e.code + if isinstance(exit_code, int): + code = exit_code + else: + code = 1 + if code != 0 and raise_exception: + if capture_output: + raise RuntimeError(buffer.getvalue()) + else: + raise RuntimeError(f"Command {cmd} failed with exit code {code}.") + except Exception as e: + code = 1 + if raise_exception: + if capture_output: + raise RuntimeError(buffer.getvalue()) + else: + raise RuntimeError(f"Command {cmd} failed: {e}") + + if restart_cli and code == 0: + os.execv(original_exe, [original_exe] + original_args) + + if capture_output: + return code, buffer.getvalue() + else: + return code, "" + finally: + sys.argv = original_args + + +def build_command_tree(cmd: click.Command) -> dict[str, Any] | None: + """Recursively build a command tree for a Click app. + + Args: + cmd (click.Command or click.Group): The Click command/group to inspect. + + Returns: + dict | None: A nested dictionary representing the command structure, + or None for leaf commands. + """ + tree: dict[str, Any] = {} + if isinstance(cmd, click.Group): + for name, subcmd in cmd.commands.items(): + if cmd.name == "models" and name == "default": + tree[name] = MODELS + else: + # Recursively build the tree for subcommands + tree[name] = build_command_tree(subcmd) + else: + # Leaf command with no subcommands + return None + return tree + + +# Generate the tree +COMMAND_TREE: dict[str, Any] | None = build_command_tree(cli) if cli is not None else {} + + +def get_completions(tokens: list[str], tree: dict[str, Any] | None = COMMAND_TREE) -> list[str]: + """Get autocompletions for the current command tokens. + + Args: + tree (dict | None): The command tree. + tokens (list[str]): List of tokens (command arguments). + + Returns: + list[str]: List of possible completions. + """ + if not LLM_CLI_IMPORTED: + return [] + for token in tokens: + if token.startswith("-"): + # Skip options (flags) + continue + if tree and token in tree: + tree = tree[token] + else: + # No completions available + return [] + + # Return possible completions (keys of the current tree level) + return list(tree.keys()) if tree else [] + + +@export +class FinishIteration(Exception): + def __init__(self, results: Any | None = None) -> None: + self.results: Any | None = results + + +USAGE = """ +Use an LLM to create SQL queries to answer questions from your database. +Examples: + +# Ask a question. +> \\llm 'Most visited urls?' + +# List available models +> \\llm models +gpt-4o +gpt-3.5-turbo +qwq + +# Change default model +> \\llm models default llama3 + +# Set api key (not required for local models) +> \\llm keys set openai + + +# Install a model plugin +> \\llm install llm-ollama +llm-ollama installed. + +# Plugins directory +# https://llm.datasette.io/en/stable/plugins/directory.html +""" + +NEED_DEPENDENCIES = """ +To enable LLM features you need to install litecli with AI support: + + pip install 'litecli[ai]' + +or install LLM libraries separately + + pip install llm + +This is required to use the \\llm command. +""" + +_SQL_CODE_FENCE = r"```sql\n(.*?)\n```" +PROMPT = """ +You are a helpful assistant who is a SQLite expert. You are embedded in a SQLite +cli tool called litecli. + +Answer this question: + +$question + +Use the following context if it is relevant to answering the question. If the +question is not about the current database then ignore the context. + +You are connected to a SQLite database with the following schema: + +$db_schema + +Here is a sample row of data from each table: + +$sample_data + +If the answer can be found using a SQL query, include a sql query in a code +fence such as this one: + +```sql +SELECT count(*) FROM table_name; +``` +Keep your explanation concise and focused on the question asked. +""" + + +def ensure_litecli_template(replace: bool = False) -> None: + """ + Create a template called litecli with the default prompt. + """ + if not replace: + # Check if it already exists. + code, _ = run_external_cmd("llm", "templates", "show", LLM_TEMPLATE_NAME, capture_output=True, raise_exception=False) + if code == 0: # Template already exists. No need to create it. + return + + run_external_cmd("llm", PROMPT, "--save", LLM_TEMPLATE_NAME) + return + + +@export +def handle_llm(text: str, cur: DBCursor) -> tuple[str, str, float]: + """This function handles the special command `\\llm`. + + If it deals with a question that results in a SQL query then it will return + the query. + If it deals with a subcommand like `models` or `keys` then it will raise + FinishIteration() which will be caught by the main loop AND print any + output that was supplied (or None). + """ + # Determine invocation mode: regular, verbose (+), or succinct (-) + _, mode, arg = parse_special_command(text) + is_verbose = mode is Verbosity.VERBOSE + is_succinct = mode is Verbosity.SUCCINCT + + if not LLM_IMPORTED: + output = [(None, None, None, NEED_DEPENDENCIES)] + raise FinishIteration(output) + + if not arg.strip(): # No question provided. Print usage and bail. + output = [(None, None, None, USAGE)] + raise FinishIteration(output) + + parts = shlex.split(arg) + + restart = False + # If the parts has `-c` then capture the output and check for fenced SQL. + # User is continuing a previous question. + # eg: \llm -m ollama -c "Show only the top 5 results" + if "-c" in parts: + capture_output = True + use_context = False + # If the parts has `prompt` command without `-c` then use context to the prompt. + # \llm -m ollama prompt "Most visited urls?" + elif "prompt" in parts: # User might invoke prompt with an option flag in the first argument. + capture_output = True + use_context = True + elif "install" in parts or "uninstall" in parts: + capture_output = False + use_context = False + restart = True + # If the parts starts with any of the known LLM_CLI_COMMANDS then invoke + # the llm and don't capture output. This is to handle commands like `models` or `keys`. + elif parts[0] in LLM_CLI_COMMANDS: + capture_output = False + use_context = False + # If the user wants to use --help option to see each command and it's description + elif "--help" == parts[0]: + capture_output = False + use_context = False + # If the parts doesn't have any known LLM_CLI_COMMANDS then the user is + # invoking a question. eg: \llm -m ollama "Most visited urls?" + elif not set(parts).intersection(LLM_CLI_COMMANDS): + capture_output = True + use_context = True + # User invoked llm with a question without `prompt` subcommand. Capture the + # output and check for fenced SQL. eg: \llm "Most visited urls?" + else: + capture_output = True + use_context = True + + if not use_context: + args = parts + if capture_output: + click.echo("Calling llm command") + start = time() + _, result = run_external_cmd("llm", *args, capture_output=capture_output) + end = time() + match = re.search(_SQL_CODE_FENCE, result, re.DOTALL) + if match: + sql = match.group(1).strip() + else: + output = [(None, None, None, result)] + raise FinishIteration(output) + + context = "" if is_succinct else result + return context, sql, end - start + else: + run_external_cmd("llm", *args, restart_cli=restart) + raise FinishIteration(None) + + try: + ensure_litecli_template() + # Measure end-to-end LLM command invocation (schema gathering and LLM call) + start = time() + result, sql, prompt_text = sql_using_llm(cur=cur, question=arg, verbose=is_verbose) + end = time() + context = "" if is_succinct else result + if is_verbose and prompt_text is not None: + click.echo("LLM Prompt:") + click.echo(prompt_text) + click.echo("---") + return context, sql, end - start + except Exception as e: + # Something went wrong. Raise an exception and bail. + raise RuntimeError(e) + + +@export +def is_llm_command(command: str) -> bool: + """ + Is this an llm/ai command? + """ + cmd, _, _ = parse_special_command(command) + return cmd in ("\\llm", "\\ai", ".llm", ".ai") + + +@export +def sql_using_llm( + cur: DBCursor, + question: str | None = None, + verbose: bool = False, +) -> tuple[str, str, str | None]: + if cur is None: + raise RuntimeError("Connect to a datbase and try again.") + schema_query = """ + SELECT sql FROM sqlite_master + WHERE sql IS NOT NULL + ORDER BY tbl_name, type DESC, name + """ + tables_query = """ + SELECT name FROM sqlite_master + WHERE type IN ('table','view') AND name NOT LIKE 'sqlite_%' AND name NOT LIKE 'sqlean_%' + ORDER BY 1 + """ + click.echo("Preparing schema information to feed the llm") + sample_row_query = "SELECT * FROM {table} LIMIT 1" + log.debug(schema_query) + cur.execute(schema_query) + db_schema = "\n".join([x for (x,) in cur.fetchall()]) + + log.debug(tables_query) + cur.execute(tables_query) + sample_data = {} + for (table,) in cur.fetchall(): + sample_row = sample_row_query.format(table=table) + cur.execute(sample_row) + if cur.description is None: + continue + cols = [x[0] for x in cur.description] + row = cur.fetchone() + if row is None: # Skip empty tables + continue + sample_data[table] = list(zip(cols, row)) + + args = [ + "--template", + LLM_TEMPLATE_NAME, + "--param", + "db_schema", + db_schema, + "--param", + "sample_data", + sample_data, + "--param", + "question", + question, + " ", # Dummy argument to prevent llm from waiting on stdin + ] + click.echo("Invoking llm command with schema information") + # Ensure all args are strings for sys.argv safety inside run_module + str_args = [str(a) for a in args] + _, result = run_external_cmd("llm", *str_args, capture_output=True) + click.echo("Received response from the llm command") + match = re.search(_SQL_CODE_FENCE, result, re.DOTALL) + sql = match.group(1).strip() if match else "" + + # When verbose, build and return the rendered prompt text + prompt_text = None + if verbose: + # Render the prompt by substituting schema, sample_data, and question + prompt_text = PROMPT + prompt_text = prompt_text.replace("$db_schema", db_schema) + prompt_text = prompt_text.replace("$sample_data", pprint.pformat(sample_data)) + prompt_text = prompt_text.replace("$question", question or "") + if verbose: + return result, sql, prompt_text + return result, sql, None diff --git a/litecli/packages/special/main.py b/litecli/packages/special/main.py index 3dd0e77..08f7bd5 100644 --- a/litecli/packages/special/main.py +++ b/litecli/packages/special/main.py @@ -1,11 +1,20 @@ -from __future__ import unicode_literals +from __future__ import annotations import logging from collections import namedtuple +from enum import Enum +from typing import Any, Callable, cast from . import export log = logging.getLogger(__name__) +try: + import llm # noqa: F401 + + LLM_IMPORTED = True +except ImportError: + LLM_IMPORTED = False + NO_QUERY = 0 PARSED_QUERY = 1 RAW_QUERY = 2 @@ -36,25 +45,42 @@ class CommandNotFound(Exception): pass +class Verbosity(Enum): + """Invocation verbosity: succinct (-), normal, or verbose (+).""" + + SUCCINCT = "succinct" + NORMAL = "normal" + VERBOSE = "verbose" + + @export -def parse_special_command(sql): +def parse_special_command(sql: str) -> tuple[str, "Verbosity", str]: + """ + Parse a special command, extracting the base command name, verbosity + (normal, verbose (+), or succinct (-)), and the remaining argument. + Mirrors the behavior used in similar CLI tools. + """ command, _, arg = sql.partition(" ") - verbose = "+" in command - command = command.strip().replace("+", "") - return (command, verbose, arg.strip()) + verbosity = Verbosity.NORMAL + if "+" in command: + verbosity = Verbosity.VERBOSE + elif "-" in command: + verbosity = Verbosity.SUCCINCT + command = command.strip().strip("+-") + return (command, verbosity, arg.strip()) @export def special_command( - command, - shortcut, - description, - arg_type=PARSED_QUERY, - hidden=False, - case_sensitive=False, - aliases=(), -): - def wrapper(wrapped): + command: str, + shortcut: str, + description: str, + arg_type: int = PARSED_QUERY, + hidden: bool = False, + case_sensitive: bool = False, + aliases: tuple[str, ...] = (), +) -> Callable: + def wrapper(wrapped: Callable) -> Callable: register_special_command( wrapped, command, @@ -72,19 +98,17 @@ def wrapper(wrapped): @export def register_special_command( - handler, - command, - shortcut, - description, - arg_type=PARSED_QUERY, - hidden=False, - case_sensitive=False, - aliases=(), -): + handler: Callable, + command: str, + shortcut: str, + description: str, + arg_type: int = PARSED_QUERY, + hidden: bool = False, + case_sensitive: bool = False, + aliases: tuple[str, ...] = (), +) -> None: cmd = command.lower() if not case_sensitive else command - COMMANDS[cmd] = SpecialCommand( - handler, command, shortcut, description, arg_type, hidden, case_sensitive - ) + COMMANDS[cmd] = SpecialCommand(handler, command, shortcut, description, arg_type, hidden, case_sensitive) for alias in aliases: cmd = alias.lower() if not case_sensitive else alias COMMANDS[cmd] = SpecialCommand( @@ -99,11 +123,11 @@ def register_special_command( @export -def execute(cur, sql): +def execute(cur: Any, sql: str) -> list[tuple[Any, ...]]: """Execute a special command and return the results. If the special command is not supported a KeyError will be raised. """ - command, verbose, arg = parse_special_command(sql) + command, verbosity, arg = parse_special_command(sql) if (command not in COMMANDS) and (command.lower() not in COMMANDS): raise CommandNotFound @@ -116,17 +140,20 @@ def execute(cur, sql): raise CommandNotFound("Command not found: %s" % command) if special_cmd.arg_type == NO_QUERY: - return special_cmd.handler() + return cast(list[tuple[Any, ...]], special_cmd.handler()) elif special_cmd.arg_type == PARSED_QUERY: - return special_cmd.handler(cur=cur, arg=arg, verbose=verbose) + return cast( + list[tuple[Any, ...]], + special_cmd.handler(cur=cur, arg=arg, verbose=(verbosity == Verbosity.VERBOSE)), + ) elif special_cmd.arg_type == RAW_QUERY: - return special_cmd.handler(cur=cur, query=sql) + return cast(list[tuple[Any, ...]], special_cmd.handler(cur=cur, query=sql)) + raise CommandNotFound(f"Command type not found: {command}") -@special_command( - "help", "\\?", "Show this help.", arg_type=NO_QUERY, aliases=("\\?", "?") -) -def show_help(): # All the parameters are ignored. + +@special_command("help", "\\?", "Show this help.", arg_type=NO_QUERY, aliases=("\\?", "?")) +def show_help() -> list[tuple]: # All the parameters are ignored. headers = ["Command", "Shortcut", "Description"] result = [] @@ -138,7 +165,7 @@ def show_help(): # All the parameters are ignored. @special_command(".exit", "\\q", "Exit.", arg_type=NO_QUERY, aliases=("\\q", "exit")) @special_command("quit", "\\q", "Quit.", arg_type=NO_QUERY) -def quit(*_args): +def quit(*_args: Any) -> None: raise EOFError @@ -156,5 +183,19 @@ def quit(*_args): arg_type=NO_QUERY, case_sensitive=True, ) -def stub(): +def stub() -> None: raise NotImplementedError + + +if LLM_IMPORTED: + + @special_command( + "\\llm", + "\\ai", + "Use LLM to construct a SQL query.", + arg_type=NO_QUERY, + case_sensitive=False, + aliases=(".ai", ".llm"), + ) + def llm_stub() -> None: + raise NotImplementedError diff --git a/litecli/packages/special/types.py b/litecli/packages/special/types.py new file mode 100644 index 0000000..71f1be5 --- /dev/null +++ b/litecli/packages/special/types.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from typing import Any, Protocol, Sequence + + +class DBCursor(Protocol): + """Minimal DB-API cursor protocol used by special modules.""" + + description: Sequence[Sequence[Any]] | None + + # Optional attribute on many DB-API cursors + connection: Any + + def execute(self, sql: str, params: Any = ...) -> Any: ... + + def fetchall(self) -> list[tuple[Any, ...]]: ... + + def fetchone(self) -> tuple[Any, ...] | None: ... diff --git a/litecli/packages/special/utils.py b/litecli/packages/special/utils.py index eed9306..6470ee1 100644 --- a/litecli/packages/special/utils.py +++ b/litecli/packages/special/utils.py @@ -1,8 +1,11 @@ +from __future__ import annotations + + import os import subprocess -def handle_cd_command(arg): +def handle_cd_command(arg: str) -> tuple[bool, str | None]: """Handles a `cd` shell command by calling python's os.chdir.""" CD_CMD = "cd" tokens = arg.split(CD_CMD + " ") @@ -17,7 +20,7 @@ def handle_cd_command(arg): return False, e.strerror -def format_uptime(uptime_in_seconds): +def format_uptime(uptime_in_seconds: str) -> str: """Format number of seconds into human-readable string. :param uptime_in_seconds: The server uptime in seconds. @@ -32,7 +35,7 @@ def format_uptime(uptime_in_seconds): h, m = divmod(m, 60) d, h = divmod(h, 24) - uptime_values = [] + uptime_values: list[str] = [] for value, unit in ((d, "days"), (h, "hours"), (m, "min"), (s, "sec")): if value == 0 and not uptime_values: @@ -46,3 +49,85 @@ def format_uptime(uptime_in_seconds): uptime = " ".join(uptime_values) return uptime + + +def check_if_sqlitedotcommand(command: object) -> bool: + """Does a check if the command supplied is in the list of SQLite dot commands. + + :param command: A command (str) supplied from the user + :returns: True/False + """ + + sqlite3dotcommands = [ + ".archive", + ".auth", + ".backup", + ".bail", + ".binary", + ".cd", + ".changes", + ".check", + ".clone", + ".connection", + ".databases", + ".dbconfig", + ".dbinfo", + ".dump", + ".echo", + ".eqp", + ".excel", + ".exit", + ".expert", + ".explain", + ".filectrl", + ".fullschema", + ".headers", + ".help", + ".import", + ".imposter", + ".indexes", + ".limit", + ".lint", + ".load", + ".log", + ".mode", + ".nonce", + ".nullvalue", + ".once", + ".open", + ".output", + ".parameter", + ".print", + ".progress", + ".prompt", + ".quit", + ".read", + ".recover", + ".restore", + ".save", + ".scanstats", + ".schema", + ".selftest", + ".separator", + ".session", + ".sha3sum", + ".shell", + ".show", + ".stats", + ".system", + ".tables", + ".testcase", + ".testctrl", + ".timeout", + ".timer", + ".trace", + ".vfsinfo", + ".vfslist", + ".vfsname", + ".width", + ] + + if isinstance(command, str): + head = command.split(" ", 1)[0].lower() + return head in sqlite3dotcommands + return False diff --git a/litecli/sqlcompleter.py b/litecli/sqlcompleter.py index 6950a3c..0263447 100644 --- a/litecli/sqlcompleter.py +++ b/litecli/sqlcompleter.py @@ -1,21 +1,24 @@ -from __future__ import print_function -from __future__ import unicode_literals +from __future__ import annotations + import logging -from re import compile, escape from collections import Counter +from re import compile, escape +from typing import Any, Collection, Generator, Iterable, Literal, Sequence -from prompt_toolkit.completion import Completer, Completion +from prompt_toolkit.completion import CompleteEvent, Completer, Completion +from prompt_toolkit.completion.base import Document from .packages.completion_engine import suggest_type -from .packages.parseutils import last_word +from .packages.filepaths import complete_path, parse_path, suggest_path +from .packages.parseutils import LAST_WORD_INCLUDE_TYPE, last_word +from .packages.special import llm from .packages.special.iocommands import favoritequeries -from .packages.filepaths import parse_path, complete_path, suggest_path _logger = logging.getLogger(__name__) class SQLCompleter(Completer): - keywords = [ + keywords: list[str] = [ "ABORT", "ACTION", "ADD", @@ -117,13 +120,15 @@ class SQLCompleter(Completer): "NOT", "NOTHING", "NULL", + "NULLS FIRST", + "NULLS LAST", "NUMERIC", "NVARCHAR", "OF", "OFFSET", "ON", "OR", - "ORDER", + "ORDER BY", "OUTER", "OVER", "PARTITION", @@ -179,7 +184,7 @@ class SQLCompleter(Completer): "WITHOUT", ] - functions = [ + functions: list[str] = [ "ABS", "AVG", "CHANGES", @@ -250,53 +255,49 @@ class SQLCompleter(Completer): "TRIM", ] - def __init__(self, supported_formats=(), keyword_casing="auto"): + def __init__(self, supported_formats: Iterable[str] = (), keyword_casing: Literal["upper", "lower", "auto"] = "auto"): super(self.__class__, self).__init__() - self.reserved_words = set() + self.reserved_words: set[str] = set() for x in self.keywords: self.reserved_words.update(x.split()) - self.name_pattern = compile("^[_a-z][_a-z0-9\$]*$") + self.name_pattern = compile(r"^[_a-zA-Z][_a-zA-Z0-9\$]*$") - self.special_commands = [] - self.table_formats = supported_formats + self.special_commands: list[str] = [] + self.table_formats: list[str] = list(supported_formats) if keyword_casing not in ("upper", "lower", "auto"): keyword_casing = "auto" - self.keyword_casing = keyword_casing + self.keyword_casing: Literal["upper", "lower", "auto"] = keyword_casing self.reset_completions() - def escape_name(self, name): - if name and ( - (not self.name_pattern.match(name)) - or (name.upper() in self.reserved_words) - or (name.upper() in self.functions) - ): + def escape_name(self, name: str) -> str: + if name and ((not self.name_pattern.match(name)) or (name.upper() in self.reserved_words) or (name.upper() in self.functions)): name = "`%s`" % name return name - def unescape_name(self, name): + def unescape_name(self, name: str) -> str: """Unquote a string.""" if name and name[0] == '"' and name[-1] == '"': name = name[1:-1] return name - def escaped_names(self, names): + def escaped_names(self, names: Iterable[str]) -> list[str]: return [self.escape_name(name) for name in names] - def extend_special_commands(self, special_commands): + def extend_special_commands(self, special_commands: Iterable[str]) -> None: # Special commands are not part of all_completions since they can only # be at the beginning of a line. self.special_commands.extend(special_commands) - def extend_database_names(self, databases): + def extend_database_names(self, databases: Iterable[str]) -> None: self.databases.extend(databases) - def extend_keywords(self, additional_keywords): + def extend_keywords(self, additional_keywords: Iterable[str]) -> None: self.keywords.extend(additional_keywords) self.all_completions.update(additional_keywords) - def extend_schemata(self, schema): + def extend_schemata(self, schema: str | None) -> None: if schema is None: return metadata = self.dbmetadata["tables"] @@ -307,7 +308,7 @@ def extend_schemata(self, schema): metadata[schema] = {} self.all_completions.update(schema) - def extend_relations(self, data, kind): + def extend_relations(self, data: Iterable[Sequence[str]], kind: str) -> None: """Extend metadata for tables or views :param data: list of (rel_name, ) tuples @@ -321,6 +322,7 @@ def extend_relations(self, data, kind): try: data = [self.escaped_names(d) for d in data] except Exception: + _logger.exception("Failed to get relation names.") data = [] # dbmetadata['tables'][$schema_name][$table_name] should be a list of @@ -338,7 +340,7 @@ def extend_relations(self, data, kind): ) self.all_completions.add(relname[0]) - def extend_columns(self, column_data, kind): + def extend_columns(self, column_data: Iterable[Sequence[str]], kind: str) -> None: """Extend column metadata :param column_data: list of (rel_name, column_name) tuples @@ -352,6 +354,7 @@ def extend_columns(self, column_data, kind): try: column_data = [self.escaped_names(d) for d in column_data] except Exception: + _logger.exception("Failed to get column names.") column_data = [] metadata = self.dbmetadata[kind] @@ -359,7 +362,7 @@ def extend_columns(self, column_data, kind): metadata[self.dbname][relname].append(column) self.all_completions.add(column) - def extend_functions(self, func_data): + def extend_functions(self, func_data: Iterable[Sequence[str]]) -> None: # 'func_data' is a generator object. It can throw an exception while # being consumed. This could happen if the user has launched the app # without specifying a database name. This exception must be handled to @@ -367,6 +370,7 @@ def extend_functions(self, func_data): try: func_data = [self.escaped_names(d) for d in func_data] except Exception: + _logger.exception("Failed to get function names.") func_data = [] # dbmetadata['functions'][$schema_name][$function_name] should return @@ -377,24 +381,24 @@ def extend_functions(self, func_data): metadata[self.dbname][func[0]] = None self.all_completions.add(func[0]) - def set_dbname(self, dbname): + def set_dbname(self, dbname: str | None) -> None: self.dbname = dbname - def reset_completions(self): - self.databases = [] + def reset_completions(self) -> None: + self.databases: list[str] = [] self.dbname = "" - self.dbmetadata = {"tables": {}, "views": {}, "functions": {}} - self.all_completions = set(self.keywords + self.functions) + self.dbmetadata: dict[str, Any] = {"tables": {}, "views": {}, "functions": {}} + self.all_completions: set[str] = set(self.keywords + self.functions) @staticmethod def find_matches( - text, - collection, - start_only=False, - fuzzy=True, - casing=None, - punctuations="most_punctuations", - ): + text: str, + collection: Collection[str], + start_only: bool = False, + fuzzy: bool = True, + casing: str | None = None, + punctuations: LAST_WORD_INCLUDE_TYPE = "most_punctuations", + ) -> Generator[Completion, None, None]: """Find completion matches for the given text. Given the user's input text and a collection of available @@ -430,23 +434,23 @@ def find_matches( if casing == "auto": casing = "lower" if last and last[-1].islower() else "upper" - def apply_case(kw): + def apply_case(kw: str) -> str: if casing == "upper": return kw.upper() return kw.lower() - return ( - Completion(z if casing is None else apply_case(z), -len(text)) - for x, y, z in sorted(completions) - ) + return (Completion(z if casing is None else apply_case(z), -len(text)) for x, y, z in sorted(completions)) - def get_completions(self, document, complete_event): + def get_completions( + self, + document: Document, + complete_event: CompleteEvent | None, + ) -> Iterable[Completion]: word_before_cursor = document.get_word_before_cursor(WORD=True) - completions = [] + completions: list[Completion] = [] suggestions = suggest_type(document.text, document.text_before_cursor) for suggestion in suggestions: - _logger.debug("Suggestion type: %r", suggestion["type"]) if suggestion["type"] == "column": @@ -457,11 +461,7 @@ def get_completions(self, document, complete_event): # drop_unique is used for 'tb11 JOIN tbl2 USING (...' # which should suggest only columns that appear in more than # one table - scoped_cols = [ - col - for (col, count) in Counter(scoped_cols).items() - if count > 1 and col != "*" - ] + scoped_cols = [col for (col, count) in Counter(scoped_cols).items() if count > 1 and col != "*"] cols = self.find_matches(word_before_cursor, scoped_cols) completions.extend(cols) @@ -487,19 +487,19 @@ def get_completions(self, document, complete_event): completions.extend(predefined_funcs) elif suggestion["type"] == "table": - tables = self.populate_schema_objects(suggestion["schema"], "tables") - tables = self.find_matches(word_before_cursor, tables) - completions.extend(tables) + table_names = self.populate_schema_objects(suggestion["schema"], "tables") + table_matches = self.find_matches(word_before_cursor, table_names) + completions.extend(table_matches) elif suggestion["type"] == "view": - views = self.populate_schema_objects(suggestion["schema"], "views") - views = self.find_matches(word_before_cursor, views) - completions.extend(views) + view_names = self.populate_schema_objects(suggestion["schema"], "views") + view_matches = self.find_matches(word_before_cursor, view_names) + completions.extend(view_matches) elif suggestion["type"] == "alias": aliases = suggestion["aliases"] - aliases = self.find_matches(word_before_cursor, aliases) - completions.extend(aliases) + alias_matches = self.find_matches(word_before_cursor, aliases) + completions.extend(alias_matches) elif suggestion["type"] == "database": dbs = self.find_matches(word_before_cursor, self.databases) @@ -534,17 +534,28 @@ def get_completions(self, document, complete_event): ) completions.extend(queries) elif suggestion["type"] == "table_format": - formats = self.find_matches( - word_before_cursor, self.table_formats, start_only=True, fuzzy=False - ) + formats = self.find_matches(word_before_cursor, self.table_formats, start_only=True, fuzzy=False) completions.extend(formats) elif suggestion["type"] == "file_name": file_names = self.find_files(word_before_cursor) completions.extend(file_names) + elif suggestion["type"] == "llm": + if not word_before_cursor: + tokens = document.text.split()[1:] + else: + tokens = document.text.split()[1:-1] + possible_entries = llm.get_completions(tokens) + subcommands = self.find_matches( + word_before_cursor, + possible_entries, + start_only=False, + fuzzy=True, + ) + completions.extend(subcommands) return completions - def find_files(self, word): + def find_files(self, word: str) -> Generator[Completion, None, None]: """Yield matching directory or file names. :param word: @@ -558,7 +569,7 @@ def find_files(self, word): if suggestion: yield Completion(suggestion, position) - def populate_scoped_cols(self, scoped_tbls): + def populate_scoped_cols(self, scoped_tbls: list[tuple[str | None, str, str | None]]) -> list[str]: """Find all columns in a set of scoped_tables :param scoped_tbls: list of (schema, table, alias) tuples :return: list of column names @@ -596,15 +607,15 @@ def populate_scoped_cols(self, scoped_tbls): return columns - def populate_schema_objects(self, schema, obj_type): + def populate_schema_objects(self, schema: str | None, obj_type: str) -> list[str]: """Returns list of tables or functions for a (optional) schema""" metadata = self.dbmetadata[obj_type] schema = schema or self.dbname try: - objects = metadata[schema].keys() + keys = list(metadata[schema].keys()) except KeyError: # schema doesn't exist - objects = [] + keys = [] - return objects + return keys diff --git a/litecli/sqlexecute.py b/litecli/sqlexecute.py index 7ef103c..5e52b73 100644 --- a/litecli/sqlexecute.py +++ b/litecli/sqlexecute.py @@ -1,13 +1,25 @@ +from __future__ import annotations + import logging -import sqlite3 -import uuid +import os.path from contextlib import closing -from sqlite3 import OperationalError +from typing import Any, Generator, Iterable, cast +from urllib.parse import urlparse import sqlparse -import os.path -from .packages import special +try: + import sqlean as _sqlite3 + + _sqlite3.extensions.enable_all() +except ImportError: + import sqlite3 as _sqlite3 + +from litecli.packages import special +from litecli.packages.special.utils import check_if_sqlitedotcommand + +sqlite3 = cast(Any, _sqlite3) +OperationalError = sqlite3.OperationalError _logger = logging.getLogger(__name__) @@ -18,7 +30,6 @@ class SQLExecute(object): - databases_query = """ PRAGMA database_list """ @@ -26,52 +37,68 @@ class SQLExecute(object): tables_query = """ SELECT name FROM sqlite_master - WHERE type IN ('table','view') AND name NOT LIKE 'sqlite_%' + WHERE type IN ('table','view') AND name NOT LIKE 'sqlite_%' AND name NOT LIKE 'sqlean_%' ORDER BY 1 """ table_columns_query = """ SELECT m.name as tableName, p.name as columnName FROM sqlite_master m - LEFT OUTER JOIN pragma_table_info((m.name)) p ON m.name <> p.name - WHERE m.type IN ('table','view') AND m.name NOT LIKE 'sqlite_%' + JOIN pragma_table_info((m.name)) p + WHERE m.type IN ('table','view') AND m.name NOT LIKE 'sqlite_%' AND m.name NOT LIKE 'sqlean_%' ORDER BY tableName, columnName """ + indexes_query = """ + SELECT name, sql + FROM sqlite_master + WHERE type = 'index' AND name NOT LIKE 'sqlite_%' + ORDER BY 1 + """ + functions_query = '''SELECT ROUTINE_NAME FROM INFORMATION_SCHEMA.ROUTINES WHERE ROUTINE_TYPE="FUNCTION" AND ROUTINE_SCHEMA = "%s"''' - def __init__(self, database): - self.dbname = database - self._server_type = None - self.connection_id = None - self.conn = None + def __init__(self, database: str | None): + self.dbname: str | None = database + self._server_type: tuple[str, str] | None = None + # Connection can be sqlite3.Connection or sqlean.sqlite3 connection. + self.conn: Any | None = None if not database: _logger.debug("Database is not specified. Skip connection.") return self.connect() - def connect(self, database=None): + def connect(self, database: str | None = None) -> None: db = database or self.dbname - _logger.debug("Connection DB Params: \n" "\tdatabase: %r", database) - - db_name = os.path.expanduser(db) - db_dir_name = os.path.dirname(os.path.abspath(db_name)) - if not os.path.exists(db_dir_name): - raise Exception("Path does not exist: {}".format(db_dir_name)) + _logger.debug("Connection DB Params: \n\tdatabase: %r", db) + if db is None: + # Nothing to connect to. + return - conn = sqlite3.connect(database=db_name, isolation_level=None) + location = urlparse(db) + if location.scheme and location.scheme == "file": + uri = True + db_name = db + db_filename = location.path + else: + uri = False + db_filename = db_name = os.path.expanduser(db) + db_dir_name = os.path.dirname(os.path.abspath(db_filename)) + if not os.path.exists(db_dir_name): + raise Exception("Path does not exist: {}".format(db_dir_name)) + + conn = sqlite3.connect(database=db_name, isolation_level=None, uri=uri) + conn.text_factory = lambda x: x.decode("utf-8", "backslashreplace") if self.conn: self.conn.close() self.conn = conn # Update them after the connection is made to ensure that it was a # successful connection. - self.dbname = db - # retrieve connection id - self.reset_connection_id() + self.dbname = db_filename - def run(self, statement): + def run(self, statement: str) -> Iterable[tuple]: """Execute the sql in the database and return the results. The results are a list of tuples. Each tuple has 4 values (title, rows, headers, status). @@ -108,9 +135,7 @@ def run(self, statement): or sql.startswith("exit") or sql.startswith("quit") ): - _logger.debug( - "Not connected to database. Will not run statement: %s.", sql - ) + _logger.debug("Not connected to database. Will not run statement: %s.", sql) raise OperationalError("Not connected to database.") # yield ('Not connected to database', None, None, None) # return @@ -121,11 +146,15 @@ def run(self, statement): for result in special.execute(cur, sql): yield result except special.CommandNotFound: # Regular SQL - _logger.debug("Regular sql statement. sql: %r", sql) - cur.execute(sql) - yield self.get_result(cur) - - def get_result(self, cursor): + if check_if_sqlitedotcommand(sql): + yield ("dot command not implemented", None, None, None) + else: + _logger.debug("Regular sql statement. sql: %r", sql) + assert cur is not None + cur.execute(sql) + yield self.get_result(cur) + + def get_result(self, cursor: Any) -> tuple[str | None, list | None, list | None, str]: """Get the current result's data from the cursor.""" title = headers = None @@ -133,37 +162,43 @@ def get_result(self, cursor): # e.g. SELECT. if cursor.description is not None: headers = [x[0] for x in cursor.description] - status = "{0} row{1} in set" + status = "{count} row{s} in set" cursor = list(cursor) rowcount = len(cursor) else: _logger.debug("No rows in result.") - status = "Query OK, {0} row{1} affected" - rowcount = 0 if cursor.rowcount == -1 else cursor.rowcount + if cursor.rowcount == -1: + status = "Query OK" + else: + status = "Query OK, {count} row{s} affected" + rowcount = cursor.rowcount cursor = None - status = status.format(rowcount, "" if rowcount == 1 else "s") + status = status.format(count=rowcount, s="" if rowcount == 1 else "s") return (title, cursor, headers, status) - def tables(self): + def tables(self) -> Generator[tuple[str], None, None]: """Yields table names""" - + if not self.conn: + return with closing(self.conn.cursor()) as cur: _logger.debug("Tables Query. sql: %r", self.tables_query) cur.execute(self.tables_query) for row in cur: yield row - def table_columns(self): + def table_columns(self) -> Generator[tuple[str, str], None, None]: """Yields column names""" + if not self.conn: + return with closing(self.conn.cursor()) as cur: _logger.debug("Columns Query. sql: %r", self.table_columns_query) cur.execute(self.table_columns_query) for row in cur: yield row - def databases(self): + def databases(self) -> Generator[str, None, None]: if not self.conn: return @@ -172,41 +207,16 @@ def databases(self): for row in cur.execute(self.databases_query): yield row[1] - def functions(self): + def functions(self) -> Iterable[tuple]: """Yields tuples of (schema_name, function_name)""" - + if not self.conn: + return with closing(self.conn.cursor()) as cur: _logger.debug("Functions Query. sql: %r", self.functions_query) cur.execute(self.functions_query % self.dbname) for row in cur: yield row - def show_candidates(self): - with closing(self.conn.cursor()) as cur: - _logger.debug("Show Query. sql: %r", self.show_candidates_query) - try: - cur.execute(self.show_candidates_query) - except sqlite3.DatabaseError as e: - _logger.error("No show completions due to %r", e) - yield "" - else: - for row in cur: - yield (row[0].split(None, 1)[-1],) - - def server_type(self): + def server_type(self) -> tuple[str, str]: self._server_type = ("sqlite3", "3") return self._server_type - - def get_connection_id(self): - if not self.connection_id: - self.reset_connection_id() - return self.connection_id - - def reset_connection_id(self): - # Remember current connection id - _logger.debug("Get current connection id") - # res = self.run('select connection_id()') - self.connection_id = uuid.uuid4() - # for title, cur, headers, status in res: - # self.connection_id = cur.fetchone()[0] - _logger.debug("Current connection id: %s", self.connection_id) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..ecd8a1e --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,78 @@ +[project] +name = "litecli" +dynamic = ["version"] +description = "CLI for SQLite Databases with auto-completion and syntax highlighting." +readme = "README.md" +requires-python = ">=3.10" +license = { text = "BSD" } +authors = [{ name = "dbcli", email = "litecli-users@googlegroups.com" }] +urls = { "homepage" = "https://github.com/dbcli/litecli" } +dependencies = [ + "cli-helpers[styles]>=2.2.1", + "click>=4.1,!=8.1.*", + "configobj>=5.0.5", + "prompt-toolkit>=3.0.3,<4.0.0", + "pygments>=1.6", + "sqlparse>=0.4.4" +] + +[build-system] +requires = [ + "setuptools>=64.0", + "setuptools-scm>=8;python_version>='3.8'", + "setuptools-scm<8;python_version<'3.8'", +] +build-backend = "setuptools.build_meta" + +[tool.setuptools_scm] + +[project.scripts] +litecli = "litecli.main:cli" + +[project.optional-dependencies] +ai = [ + "llm>=0.25.0", + "setuptools", # Required by llm commands to install models + "pip", +] +sqlean = ["sqlean-py>=3.47.0", + "sqlean-stubs>=0.0.3"] + +dev = [ + "behave>=1.2.6", + "coverage>=7.2.7", + "pexpect>=4.9.0", + "pytest>=7.4.4", + "pytest-cov>=4.1.0", + "tox>=4.8.0", + "pdbpp>=0.10.3", + "llm>=0.25.0", + "setuptools", + "pip", + "ty>=0.0.4" +] + +[tool.setuptools.packages.find] +exclude = ["screenshots", "tests*"] + +[tool.setuptools.package-data] +litecli = ["liteclirc", "AUTHORS"] + +[tool.ruff] +line-length = 140 + +[tool.ty.environment] +python-version = "3.10" +root = [".", "litecli", "litecli/packages", "litecli/packages/special"] + + +[tool.ty.src] +exclude = [ + '**/build/', + '**/dist/', + '**/.tox/', + '**/.venv/', + '**/.mypy_cache/', + '**/.pytest_cache/', + '**/.ruff_cache/' +] diff --git a/release.py b/release.py deleted file mode 100644 index 264a4c3..0000000 --- a/release.py +++ /dev/null @@ -1,130 +0,0 @@ -#!/usr/bin/env python -"""A script to publish a release of litecli to PyPI.""" - -from __future__ import print_function -import io -from optparse import OptionParser -import re -import subprocess -import sys - -import click - -DEBUG = False -CONFIRM_STEPS = False -DRY_RUN = False - - -def skip_step(): - """ - Asks for user's response whether to run a step. Default is yes. - :return: boolean - """ - global CONFIRM_STEPS - - if CONFIRM_STEPS: - return not click.confirm("--- Run this step?", default=True) - return False - - -def run_step(*args): - """ - Prints out the command and asks if it should be run. - If yes (default), runs it. - :param args: list of strings (command and args) - """ - global DRY_RUN - - cmd = args - print(" ".join(cmd)) - if skip_step(): - print("--- Skipping...") - elif DRY_RUN: - print("--- Pretending to run...") - else: - subprocess.check_output(cmd) - - -def version(version_file): - _version_re = re.compile( - r'__version__\s+=\s+(?P[\'"])(?P.*)(?P=quote)' - ) - - with io.open(version_file, encoding="utf-8") as f: - ver = _version_re.search(f.read()).group("version") - - return ver - - -def commit_for_release(version_file, ver): - run_step("git", "reset") - run_step("git", "add", version_file) - run_step("git", "commit", "--message", "Releasing version {}".format(ver)) - - -def create_git_tag(tag_name): - run_step("git", "tag", tag_name) - - -def create_distribution_files(): - run_step("python", "setup.py", "sdist", "bdist_wheel") - - -def upload_distribution_files(): - run_step("twine", "upload", "dist/*") - - -def push_to_github(): - run_step("git", "push", "origin", "master") - - -def push_tags_to_github(): - run_step("git", "push", "--tags", "origin") - - -def checklist(questions): - for question in questions: - if not click.confirm("--- {}".format(question), default=False): - sys.exit(1) - - -if __name__ == "__main__": - if DEBUG: - subprocess.check_output = lambda x: x - - ver = version("litecli/__init__.py") - print("Releasing Version:", ver) - - parser = OptionParser() - parser.add_option( - "-c", - "--confirm-steps", - action="store_true", - dest="confirm_steps", - default=False, - help=( - "Confirm every step. If the step is not " "confirmed, it will be skipped." - ), - ) - parser.add_option( - "-d", - "--dry-run", - action="store_true", - dest="dry_run", - default=False, - help="Print out, but not actually run any steps.", - ) - - popts, pargs = parser.parse_args() - CONFIRM_STEPS = popts.confirm_steps - DRY_RUN = popts.dry_run - - if not click.confirm("Are you sure?", default=False): - sys.exit(1) - - commit_for_release("litecli/__init__.py", ver) - create_git_tag("v{}".format(ver)) - create_distribution_files() - push_to_github() - push_tags_to_github() - upload_distribution_files() diff --git a/requirements-dev.txt b/requirements-dev.txt deleted file mode 100644 index b95211a..0000000 --- a/requirements-dev.txt +++ /dev/null @@ -1,9 +0,0 @@ -mock -pytest>=3.6 -pytest-cov -tox -behave -pexpect -coverage -codecov -click diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 40eab0a..0000000 --- a/setup.cfg +++ /dev/null @@ -1,18 +0,0 @@ -[bdist_wheel] -universal = 1 - -[tool:pytest] -addopts = --capture=sys - --showlocals - --doctest-modules - --doctest-ignore-import-errors - --ignore=setup.py - --ignore=litecli/magic.py - --ignore=litecli/packages/parseutils.py - --ignore=test/features - -[pep8] -rev = master -docformatter = True -diff = True -error-status = True diff --git a/setup.py b/setup.py deleted file mode 100755 index 056b99a..0000000 --- a/setup.py +++ /dev/null @@ -1,73 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -import ast -from io import open -import re -import sys -import subprocess -from setuptools import Command, setup, find_packages - -_version_re = re.compile(r"__version__\s+=\s+(.*)") - -with open("litecli/__init__.py", "rb") as f: - version = str( - ast.literal_eval(_version_re.search(f.read().decode("utf-8")).group(1)) - ) - - -def open_file(filename): - """Open and read the file *filename*.""" - with open(filename) as f: - return f.read() - - -readme = open_file("README.md") - -install_requirements = [ - "click >= 4.1", - "Pygments >= 1.6", - "prompt_toolkit>=2.0.0,<2.1.0", - "sqlparse>=0.2.2,<0.3.0", - "configobj >= 5.0.5", - "cli_helpers[styles] >= 1.0.1", -] - - -setup( - name="litecli", - author="dbcli", - author_email="litecli-users@googlegroups.com", - license="BSD", - version=version, - url="https://github.com/dbcli/litecli", - packages=find_packages(), - package_data={"litecli": ["liteclirc", "AUTHORS"]}, - description="CLI for SQLite Databases with auto-completion and syntax " - "highlighting.", - long_description=readme, - install_requires=install_requirements, - # cmdclass={"test": test, "lint": lint}, - entry_points={ - "console_scripts": ["litecli = litecli.main:cli"], - "distutils.commands": ["lint = tasks:lint", "test = tasks:test"], - }, - classifiers=[ - "Intended Audience :: Developers", - "License :: OSI Approved :: BSD License", - "Operating System :: Unix", - "Programming Language :: Python", - "Programming Language :: Python :: 2", - "Programming Language :: Python :: 2.7", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.4", - "Programming Language :: Python :: 3.5", - "Programming Language :: Python :: 3.6", - "Programming Language :: Python :: 3.7", - "Programming Language :: SQL", - "Topic :: Database", - "Topic :: Database :: Front-Ends", - "Topic :: Software Development", - "Topic :: Software Development :: Libraries :: Python Modules", - ], -) diff --git a/tasks.py b/tasks.py deleted file mode 100644 index 5e68107..0000000 --- a/tasks.py +++ /dev/null @@ -1,128 +0,0 @@ -# -*- coding: utf-8 -*- -"""Common development tasks for setup.py to use.""" - -import re -import subprocess -import sys - -from setuptools import Command -from setuptools.command.test import test as TestCommand - - -class BaseCommand(Command, object): - """The base command for project tasks.""" - - user_options = [] - - default_cmd_options = ("verbose", "quiet", "dry_run") - - def __init__(self, *args, **kwargs): - super(BaseCommand, self).__init__(*args, **kwargs) - self.verbose = False - - def initialize_options(self): - """Override the distutils abstract method.""" - pass - - def finalize_options(self): - """Override the distutils abstract method.""" - # Distutils uses incrementing integers for verbosity. - self.verbose = bool(self.verbose) - - def call_and_exit(self, cmd, shell=True): - """Run the *cmd* and exit with the proper exit code.""" - sys.exit(subprocess.call(cmd, shell=shell)) - - def call_in_sequence(self, cmds, shell=True): - """Run multiple commmands in a row, exiting if one fails.""" - for cmd in cmds: - if subprocess.call(cmd, shell=shell) == 1: - sys.exit(1) - - def apply_options(self, cmd, options=()): - """Apply command-line options.""" - for option in self.default_cmd_options + options: - cmd = self.apply_option(cmd, option, active=getattr(self, option, False)) - return cmd - - def apply_option(self, cmd, option, active=True): - """Apply a command-line option.""" - return re.sub( - r"{{{}\:(?P