Initial commit

This commit is contained in:
2025-12-10 22:47:38 +01:00
parent e98d8c5c1b
commit f78c4d389d
2870 changed files with 641720 additions and 0 deletions

View File

@@ -0,0 +1,29 @@
from .editor import open_in_editor as open_in_editor
from .exc import AutogenerateDiffsDetected as AutogenerateDiffsDetected
from .exc import CommandError as CommandError
from .langhelpers import _with_legacy_names as _with_legacy_names
from .langhelpers import asbool as asbool
from .langhelpers import dedupe_tuple as dedupe_tuple
from .langhelpers import Dispatcher as Dispatcher
from .langhelpers import EMPTY_DICT as EMPTY_DICT
from .langhelpers import immutabledict as immutabledict
from .langhelpers import memoized_property as memoized_property
from .langhelpers import ModuleClsProxy as ModuleClsProxy
from .langhelpers import not_none as not_none
from .langhelpers import rev_id as rev_id
from .langhelpers import to_list as to_list
from .langhelpers import to_tuple as to_tuple
from .langhelpers import unique_list as unique_list
from .messaging import err as err
from .messaging import format_as_comma as format_as_comma
from .messaging import msg as msg
from .messaging import obfuscate_url_pw as obfuscate_url_pw
from .messaging import status as status
from .messaging import warn as warn
from .messaging import warn_deprecated as warn_deprecated
from .messaging import write_outstream as write_outstream
from .pyfiles import coerce_resource_to_filename as coerce_resource_to_filename
from .pyfiles import load_python_file as load_python_file
from .pyfiles import pyc_file_from_path as pyc_file_from_path
from .pyfiles import template_to_file as template_to_file
from .sqla_compat import sqla_2 as sqla_2

View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.

View File

@@ -0,0 +1,146 @@
# mypy: no-warn-unused-ignores
from __future__ import annotations
from configparser import ConfigParser
import io
import os
from pathlib import Path
import sys
import typing
from typing import Any
from typing import Iterator
from typing import List
from typing import Optional
from typing import Sequence
from typing import Union
if True:
# zimports hack for too-long names
from sqlalchemy.util import ( # noqa: F401
inspect_getfullargspec as inspect_getfullargspec,
)
from sqlalchemy.util.compat import ( # noqa: F401
inspect_formatargspec as inspect_formatargspec,
)
is_posix = os.name == "posix"
py314 = sys.version_info >= (3, 14)
py313 = sys.version_info >= (3, 13)
py312 = sys.version_info >= (3, 12)
py311 = sys.version_info >= (3, 11)
py310 = sys.version_info >= (3, 10)
py39 = sys.version_info >= (3, 9)
# produce a wrapper that allows encoded text to stream
# into a given buffer, but doesn't close it.
# not sure of a more idiomatic approach to this.
class EncodedIO(io.TextIOWrapper):
def close(self) -> None:
pass
if py39:
from importlib import resources as _resources
importlib_resources = _resources
from importlib import metadata as _metadata
importlib_metadata = _metadata
from importlib.metadata import EntryPoint as EntryPoint
else:
import importlib_resources # type:ignore # noqa
import importlib_metadata # type:ignore # noqa
from importlib_metadata import EntryPoint # type:ignore # noqa
if py311:
import tomllib as tomllib
else:
import tomli as tomllib # type: ignore # noqa
if py312:
def path_walk(
path: Path, *, top_down: bool = True
) -> Iterator[tuple[Path, list[str], list[str]]]:
return Path.walk(path)
def path_relative_to(
path: Path, other: Path, *, walk_up: bool = False
) -> Path:
return path.relative_to(other, walk_up=walk_up)
else:
def path_walk(
path: Path, *, top_down: bool = True
) -> Iterator[tuple[Path, list[str], list[str]]]:
for root, dirs, files in os.walk(path, topdown=top_down):
yield Path(root), dirs, files
def path_relative_to(
path: Path, other: Path, *, walk_up: bool = False
) -> Path:
"""
Calculate the relative path of 'path' with respect to 'other',
optionally allowing 'path' to be outside the subtree of 'other'.
OK I used AI for this, sorry
"""
try:
return path.relative_to(other)
except ValueError:
if walk_up:
other_ancestors = list(other.parents) + [other]
for ancestor in other_ancestors:
try:
return path.relative_to(ancestor)
except ValueError:
continue
raise ValueError(
f"{path} is not in the same subtree as {other}"
)
else:
raise
def importlib_metadata_get(group: str) -> Sequence[EntryPoint]:
ep = importlib_metadata.entry_points()
if hasattr(ep, "select"):
return ep.select(group=group)
else:
return ep.get(group, ()) # type: ignore
def formatannotation_fwdref(
annotation: Any, base_module: Optional[Any] = None
) -> str:
"""vendored from python 3.7"""
# copied over _formatannotation from sqlalchemy 2.0
if isinstance(annotation, str):
return annotation
if getattr(annotation, "__module__", None) == "typing":
return repr(annotation).replace("typing.", "").replace("~", "")
if isinstance(annotation, type):
if annotation.__module__ in ("builtins", base_module):
return repr(annotation.__qualname__)
return annotation.__module__ + "." + annotation.__qualname__
elif isinstance(annotation, typing.TypeVar):
return repr(annotation).replace("~", "")
return repr(annotation).replace("~", "")
def read_config_parser(
file_config: ConfigParser,
file_argument: Sequence[Union[str, os.PathLike[str]]],
) -> List[str]:
if py310:
return file_config.read(file_argument, encoding="locale")
else:
return file_config.read(file_argument)

View File

@@ -0,0 +1,81 @@
from __future__ import annotations
import os
from os.path import exists
from os.path import join
from os.path import splitext
from subprocess import check_call
from typing import Dict
from typing import List
from typing import Mapping
from typing import Optional
from .compat import is_posix
from .exc import CommandError
def open_in_editor(
filename: str, environ: Optional[Dict[str, str]] = None
) -> None:
"""
Opens the given file in a text editor. If the environment variable
``EDITOR`` is set, this is taken as preference.
Otherwise, a list of commonly installed editors is tried.
If no editor matches, an :py:exc:`OSError` is raised.
:param filename: The filename to open. Will be passed verbatim to the
editor command.
:param environ: An optional drop-in replacement for ``os.environ``. Used
mainly for testing.
"""
env = os.environ if environ is None else environ
try:
editor = _find_editor(env)
check_call([editor, filename])
except Exception as exc:
raise CommandError("Error executing editor (%s)" % (exc,)) from exc
def _find_editor(environ: Mapping[str, str]) -> str:
candidates = _default_editors()
for i, var in enumerate(("EDITOR", "VISUAL")):
if var in environ:
user_choice = environ[var]
if exists(user_choice):
return user_choice
if os.sep not in user_choice:
candidates.insert(i, user_choice)
for candidate in candidates:
path = _find_executable(candidate, environ)
if path is not None:
return path
raise OSError(
"No suitable editor found. Please set the "
'"EDITOR" or "VISUAL" environment variables'
)
def _find_executable(
candidate: str, environ: Mapping[str, str]
) -> Optional[str]:
# Assuming this is on the PATH, we need to determine it's absolute
# location. Otherwise, ``check_call`` will fail
if not is_posix and splitext(candidate)[1] != ".exe":
candidate += ".exe"
for path in environ.get("PATH", "").split(os.pathsep):
value = join(path, candidate)
if exists(value):
return value
return None
def _default_editors() -> List[str]:
# Look for an editor. Prefer the user's choice by env-var, fall back to
# most commonly installed editor (nano/vim)
if is_posix:
return ["sensible-editor", "editor", "nano", "vim", "code"]
else:
return ["code.exe", "notepad++.exe", "notepad.exe"]

View File

@@ -0,0 +1,25 @@
from __future__ import annotations
from typing import Any
from typing import List
from typing import Tuple
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from alembic.autogenerate import RevisionContext
class CommandError(Exception):
pass
class AutogenerateDiffsDetected(CommandError):
def __init__(
self,
message: str,
revision_context: RevisionContext,
diffs: List[Tuple[Any, ...]],
) -> None:
super().__init__(message)
self.revision_context = revision_context
self.diffs = diffs

View File

@@ -0,0 +1,332 @@
from __future__ import annotations
import collections
from collections.abc import Iterable
import textwrap
from typing import Any
from typing import Callable
from typing import cast
from typing import Dict
from typing import List
from typing import Mapping
from typing import MutableMapping
from typing import NoReturn
from typing import Optional
from typing import overload
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
import uuid
import warnings
from sqlalchemy.util import asbool as asbool # noqa: F401
from sqlalchemy.util import immutabledict as immutabledict # noqa: F401
from sqlalchemy.util import to_list as to_list # noqa: F401
from sqlalchemy.util import unique_list as unique_list
from .compat import inspect_getfullargspec
if True:
# zimports workaround :(
from sqlalchemy.util import ( # noqa: F401
memoized_property as memoized_property,
)
EMPTY_DICT: Mapping[Any, Any] = immutabledict()
_T = TypeVar("_T", bound=Any)
_C = TypeVar("_C", bound=Callable[..., Any])
class _ModuleClsMeta(type):
def __setattr__(cls, key: str, value: Callable[..., Any]) -> None:
super().__setattr__(key, value)
cls._update_module_proxies(key) # type: ignore
class ModuleClsProxy(metaclass=_ModuleClsMeta):
"""Create module level proxy functions for the
methods on a given class.
The functions will have a compatible signature
as the methods.
"""
_setups: Dict[
Type[Any],
Tuple[
Set[str],
List[Tuple[MutableMapping[str, Any], MutableMapping[str, Any]]],
],
] = collections.defaultdict(lambda: (set(), []))
@classmethod
def _update_module_proxies(cls, name: str) -> None:
attr_names, modules = cls._setups[cls]
for globals_, locals_ in modules:
cls._add_proxied_attribute(name, globals_, locals_, attr_names)
def _install_proxy(self) -> None:
attr_names, modules = self._setups[self.__class__]
for globals_, locals_ in modules:
globals_["_proxy"] = self
for attr_name in attr_names:
globals_[attr_name] = getattr(self, attr_name)
def _remove_proxy(self) -> None:
attr_names, modules = self._setups[self.__class__]
for globals_, locals_ in modules:
globals_["_proxy"] = None
for attr_name in attr_names:
del globals_[attr_name]
@classmethod
def create_module_class_proxy(
cls,
globals_: MutableMapping[str, Any],
locals_: MutableMapping[str, Any],
) -> None:
attr_names, modules = cls._setups[cls]
modules.append((globals_, locals_))
cls._setup_proxy(globals_, locals_, attr_names)
@classmethod
def _setup_proxy(
cls,
globals_: MutableMapping[str, Any],
locals_: MutableMapping[str, Any],
attr_names: Set[str],
) -> None:
for methname in dir(cls):
cls._add_proxied_attribute(methname, globals_, locals_, attr_names)
@classmethod
def _add_proxied_attribute(
cls,
methname: str,
globals_: MutableMapping[str, Any],
locals_: MutableMapping[str, Any],
attr_names: Set[str],
) -> None:
if not methname.startswith("_"):
meth = getattr(cls, methname)
if callable(meth):
locals_[methname] = cls._create_method_proxy(
methname, globals_, locals_
)
else:
attr_names.add(methname)
@classmethod
def _create_method_proxy(
cls,
name: str,
globals_: MutableMapping[str, Any],
locals_: MutableMapping[str, Any],
) -> Callable[..., Any]:
fn = getattr(cls, name)
def _name_error(name: str, from_: Exception) -> NoReturn:
raise NameError(
"Can't invoke function '%s', as the proxy object has "
"not yet been "
"established for the Alembic '%s' class. "
"Try placing this code inside a callable."
% (name, cls.__name__)
) from from_
globals_["_name_error"] = _name_error
translations = getattr(fn, "_legacy_translations", [])
if translations:
spec = inspect_getfullargspec(fn)
if spec[0] and spec[0][0] == "self":
spec[0].pop(0)
outer_args = inner_args = "*args, **kw"
translate_str = "args, kw = _translate(%r, %r, %r, args, kw)" % (
fn.__name__,
tuple(spec),
translations,
)
def translate(
fn_name: str, spec: Any, translations: Any, args: Any, kw: Any
) -> Any:
return_kw = {}
return_args = []
for oldname, newname in translations:
if oldname in kw:
warnings.warn(
"Argument %r is now named %r "
"for method %s()." % (oldname, newname, fn_name)
)
return_kw[newname] = kw.pop(oldname)
return_kw.update(kw)
args = list(args)
if spec[3]:
pos_only = spec[0][: -len(spec[3])]
else:
pos_only = spec[0]
for arg in pos_only:
if arg not in return_kw:
try:
return_args.append(args.pop(0))
except IndexError:
raise TypeError(
"missing required positional argument: %s"
% arg
)
return_args.extend(args)
return return_args, return_kw
globals_["_translate"] = translate
else:
outer_args = "*args, **kw"
inner_args = "*args, **kw"
translate_str = ""
func_text = textwrap.dedent(
"""\
def %(name)s(%(args)s):
%(doc)r
%(translate)s
try:
p = _proxy
except NameError as ne:
_name_error('%(name)s', ne)
return _proxy.%(name)s(%(apply_kw)s)
e
"""
% {
"name": name,
"translate": translate_str,
"args": outer_args,
"apply_kw": inner_args,
"doc": fn.__doc__,
}
)
lcl: MutableMapping[str, Any] = {}
exec(func_text, cast("Dict[str, Any]", globals_), lcl)
return cast("Callable[..., Any]", lcl[name])
def _with_legacy_names(translations: Any) -> Any:
def decorate(fn: _C) -> _C:
fn._legacy_translations = translations # type: ignore[attr-defined]
return fn
return decorate
def rev_id() -> str:
return uuid.uuid4().hex[-12:]
@overload
def to_tuple(x: Any, default: Tuple[Any, ...]) -> Tuple[Any, ...]: ...
@overload
def to_tuple(x: None, default: Optional[_T] = ...) -> _T: ...
@overload
def to_tuple(
x: Any, default: Optional[Tuple[Any, ...]] = None
) -> Tuple[Any, ...]: ...
def to_tuple(
x: Any, default: Optional[Tuple[Any, ...]] = None
) -> Optional[Tuple[Any, ...]]:
if x is None:
return default
elif isinstance(x, str):
return (x,)
elif isinstance(x, Iterable):
return tuple(x)
else:
return (x,)
def dedupe_tuple(tup: Tuple[str, ...]) -> Tuple[str, ...]:
return tuple(unique_list(tup))
class Dispatcher:
def __init__(self, uselist: bool = False) -> None:
self._registry: Dict[Tuple[Any, ...], Any] = {}
self.uselist = uselist
def dispatch_for(
self, target: Any, qualifier: str = "default"
) -> Callable[[_C], _C]:
def decorate(fn: _C) -> _C:
if self.uselist:
self._registry.setdefault((target, qualifier), []).append(fn)
else:
assert (target, qualifier) not in self._registry
self._registry[(target, qualifier)] = fn
return fn
return decorate
def dispatch(self, obj: Any, qualifier: str = "default") -> Any:
if isinstance(obj, str):
targets: Sequence[Any] = [obj]
elif isinstance(obj, type):
targets = obj.__mro__
else:
targets = type(obj).__mro__
for spcls in targets:
if qualifier != "default" and (spcls, qualifier) in self._registry:
return self._fn_or_list(self._registry[(spcls, qualifier)])
elif (spcls, "default") in self._registry:
return self._fn_or_list(self._registry[(spcls, "default")])
else:
raise ValueError("no dispatch function for object: %s" % obj)
def _fn_or_list(
self, fn_or_list: Union[List[Callable[..., Any]], Callable[..., Any]]
) -> Callable[..., Any]:
if self.uselist:
def go(*arg: Any, **kw: Any) -> None:
if TYPE_CHECKING:
assert isinstance(fn_or_list, Sequence)
for fn in fn_or_list:
fn(*arg, **kw)
return go
else:
return fn_or_list # type: ignore
def branch(self) -> Dispatcher:
"""Return a copy of this dispatcher that is independently
writable."""
d = Dispatcher()
if self.uselist:
d._registry.update(
(k, [fn for fn in self._registry[k]]) for k in self._registry
)
else:
d._registry.update(self._registry)
return d
def not_none(value: Optional[_T]) -> _T:
assert value is not None
return value

View File

@@ -0,0 +1,122 @@
from __future__ import annotations
from collections.abc import Iterable
from contextlib import contextmanager
import logging
import sys
import textwrap
from typing import Iterator
from typing import Optional
from typing import TextIO
from typing import Union
import warnings
from sqlalchemy.engine import url
log = logging.getLogger(__name__)
# disable "no handler found" errors
logging.getLogger("alembic").addHandler(logging.NullHandler())
try:
import fcntl
import termios
import struct
ioctl = fcntl.ioctl(0, termios.TIOCGWINSZ, struct.pack("HHHH", 0, 0, 0, 0))
_h, TERMWIDTH, _hp, _wp = struct.unpack("HHHH", ioctl)
if TERMWIDTH <= 0: # can occur if running in emacs pseudo-tty
TERMWIDTH = None
except (ImportError, OSError):
TERMWIDTH = None
def write_outstream(
stream: TextIO, *text: Union[str, bytes], quiet: bool = False
) -> None:
if quiet:
return
encoding = getattr(stream, "encoding", "ascii") or "ascii"
for t in text:
if not isinstance(t, bytes):
t = t.encode(encoding, "replace")
t = t.decode(encoding)
try:
stream.write(t)
except OSError:
# suppress "broken pipe" errors.
# no known way to handle this on Python 3 however
# as the exception is "ignored" (noisily) in TextIOWrapper.
break
@contextmanager
def status(
status_msg: str, newline: bool = False, quiet: bool = False
) -> Iterator[None]:
msg(status_msg + " ...", newline, flush=True, quiet=quiet)
try:
yield
except:
if not quiet:
write_outstream(sys.stdout, " FAILED\n")
raise
else:
if not quiet:
write_outstream(sys.stdout, " done\n")
def err(message: str, quiet: bool = False) -> None:
log.error(message)
msg(f"FAILED: {message}", quiet=quiet)
sys.exit(-1)
def obfuscate_url_pw(input_url: str) -> str:
return url.make_url(input_url).render_as_string(hide_password=True)
def warn(msg: str, stacklevel: int = 2) -> None:
warnings.warn(msg, UserWarning, stacklevel=stacklevel)
def warn_deprecated(msg: str, stacklevel: int = 2) -> None:
warnings.warn(msg, DeprecationWarning, stacklevel=stacklevel)
def msg(
msg: str, newline: bool = True, flush: bool = False, quiet: bool = False
) -> None:
if quiet:
return
if TERMWIDTH is None:
write_outstream(sys.stdout, msg)
if newline:
write_outstream(sys.stdout, "\n")
else:
# left indent output lines
indent = " "
lines = textwrap.wrap(
msg,
TERMWIDTH,
initial_indent=indent,
subsequent_indent=indent,
)
if len(lines) > 1:
for line in lines[0:-1]:
write_outstream(sys.stdout, line, "\n")
write_outstream(sys.stdout, lines[-1], ("\n" if newline else ""))
if flush:
sys.stdout.flush()
def format_as_comma(value: Optional[Union[str, Iterable[str]]]) -> str:
if value is None:
return ""
elif isinstance(value, str):
return value
elif isinstance(value, Iterable):
return ", ".join(value)
else:
raise ValueError("Don't know how to comma-format %r" % value)

View File

@@ -0,0 +1,153 @@
from __future__ import annotations
import atexit
from contextlib import ExitStack
import importlib
import importlib.machinery
import importlib.util
import os
import pathlib
import re
import tempfile
from types import ModuleType
from typing import Any
from typing import Optional
from typing import Union
from mako import exceptions
from mako.template import Template
from . import compat
from .exc import CommandError
def template_to_file(
template_file: Union[str, os.PathLike[str]],
dest: Union[str, os.PathLike[str]],
output_encoding: str,
*,
append_with_newlines: bool = False,
**kw: Any,
) -> None:
template = Template(filename=_preserving_path_as_str(template_file))
try:
output = template.render_unicode(**kw).encode(output_encoding)
except:
with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as ntf:
ntf.write(
exceptions.text_error_template()
.render_unicode()
.encode(output_encoding)
)
fname = ntf.name
raise CommandError(
"Template rendering failed; see %s for a "
"template-oriented traceback." % fname
)
else:
with open(dest, "ab" if append_with_newlines else "wb") as f:
if append_with_newlines:
f.write("\n\n".encode(output_encoding))
f.write(output)
def coerce_resource_to_filename(fname_or_resource: str) -> pathlib.Path:
"""Interpret a filename as either a filesystem location or as a package
resource.
Names that are non absolute paths and contain a colon
are interpreted as resources and coerced to a file location.
"""
# TODO: there seem to be zero tests for the package resource codepath
if not os.path.isabs(fname_or_resource) and ":" in fname_or_resource:
tokens = fname_or_resource.split(":")
# from https://importlib-resources.readthedocs.io/en/latest/migration.html#pkg-resources-resource-filename # noqa E501
file_manager = ExitStack()
atexit.register(file_manager.close)
ref = compat.importlib_resources.files(tokens[0])
for tok in tokens[1:]:
ref = ref / tok
fname_or_resource = file_manager.enter_context( # type: ignore[assignment] # noqa: E501
compat.importlib_resources.as_file(ref)
)
return pathlib.Path(fname_or_resource)
def pyc_file_from_path(
path: Union[str, os.PathLike[str]],
) -> Optional[pathlib.Path]:
"""Given a python source path, locate the .pyc."""
pathpath = pathlib.Path(path)
candidate = pathlib.Path(
importlib.util.cache_from_source(pathpath.as_posix())
)
if candidate.exists():
return candidate
# even for pep3147, fall back to the old way of finding .pyc files,
# to support sourceless operation
ext = pathpath.suffix
for ext in importlib.machinery.BYTECODE_SUFFIXES:
if pathpath.with_suffix(ext).exists():
return pathpath.with_suffix(ext)
else:
return None
def load_python_file(
dir_: Union[str, os.PathLike[str]], filename: Union[str, os.PathLike[str]]
) -> ModuleType:
"""Load a file from the given path as a Python module."""
dir_ = pathlib.Path(dir_)
filename_as_path = pathlib.Path(filename)
filename = filename_as_path.name
module_id = re.sub(r"\W", "_", filename)
path = dir_ / filename
ext = path.suffix
if ext == ".py":
if path.exists():
module = load_module_py(module_id, path)
else:
pyc_path = pyc_file_from_path(path)
if pyc_path is None:
raise ImportError("Can't find Python file %s" % path)
else:
module = load_module_py(module_id, pyc_path)
elif ext in (".pyc", ".pyo"):
module = load_module_py(module_id, path)
else:
assert False
return module
def load_module_py(
module_id: str, path: Union[str, os.PathLike[str]]
) -> ModuleType:
spec = importlib.util.spec_from_file_location(module_id, path)
assert spec
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore
return module
def _preserving_path_as_str(path: Union[str, os.PathLike[str]]) -> str:
"""receive str/pathlike and return a string.
Does not convert an incoming string path to a Path first, to help with
unit tests that are doing string path round trips without OS-specific
processing if not necessary.
"""
if isinstance(path, str):
return path
elif isinstance(path, pathlib.PurePath):
return str(path)
else:
return str(pathlib.Path(path))

View File

@@ -0,0 +1,495 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
import contextlib
import re
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import Optional
from typing import Protocol
from typing import Set
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from sqlalchemy import __version__
from sqlalchemy import schema
from sqlalchemy import sql
from sqlalchemy import types as sqltypes
from sqlalchemy.schema import CheckConstraint
from sqlalchemy.schema import Column
from sqlalchemy.schema import ForeignKeyConstraint
from sqlalchemy.sql import visitors
from sqlalchemy.sql.base import DialectKWArgs
from sqlalchemy.sql.elements import BindParameter
from sqlalchemy.sql.elements import ColumnClause
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.elements import UnaryExpression
from sqlalchemy.sql.naming import _NONE_NAME as _NONE_NAME # type: ignore[attr-defined] # noqa: E501
from sqlalchemy.sql.visitors import traverse
from typing_extensions import TypeGuard
if TYPE_CHECKING:
from sqlalchemy import ClauseElement
from sqlalchemy import Identity
from sqlalchemy import Index
from sqlalchemy import Table
from sqlalchemy.engine import Connection
from sqlalchemy.engine import Dialect
from sqlalchemy.engine import Transaction
from sqlalchemy.sql.base import ColumnCollection
from sqlalchemy.sql.compiler import SQLCompiler
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import SchemaItem
_CE = TypeVar("_CE", bound=Union["ColumnElement[Any]", "SchemaItem"])
class _CompilerProtocol(Protocol):
def __call__(self, element: Any, compiler: Any, **kw: Any) -> str: ...
def _safe_int(value: str) -> Union[int, str]:
try:
return int(value)
except:
return value
_vers = tuple(
[_safe_int(x) for x in re.findall(r"(\d+|[abc]\d)", __version__)]
)
# https://docs.sqlalchemy.org/en/latest/changelog/changelog_14.html#change-0c6e0cc67dfe6fac5164720e57ef307d
sqla_14_18 = _vers >= (1, 4, 18)
sqla_14_26 = _vers >= (1, 4, 26)
sqla_2 = _vers >= (2,)
sqlalchemy_version = __version__
if TYPE_CHECKING:
def compiles(
element: Type[ClauseElement], *dialects: str
) -> Callable[[_CompilerProtocol], _CompilerProtocol]: ...
else:
from sqlalchemy.ext.compiler import compiles # noqa: I100,I202
identity_has_dialect_kwargs = issubclass(schema.Identity, DialectKWArgs)
def _get_identity_options_dict(
identity: Union[Identity, schema.Sequence, None],
dialect_kwargs: bool = False,
) -> Dict[str, Any]:
if identity is None:
return {}
elif identity_has_dialect_kwargs:
assert hasattr(identity, "_as_dict")
as_dict = identity._as_dict()
if dialect_kwargs:
assert isinstance(identity, DialectKWArgs)
as_dict.update(identity.dialect_kwargs)
else:
as_dict = {}
if isinstance(identity, schema.Identity):
# always=None means something different than always=False
as_dict["always"] = identity.always
if identity.on_null is not None:
as_dict["on_null"] = identity.on_null
# attributes common to Identity and Sequence
attrs = (
"start",
"increment",
"minvalue",
"maxvalue",
"nominvalue",
"nomaxvalue",
"cycle",
"cache",
"order",
)
as_dict.update(
{
key: getattr(identity, key, None)
for key in attrs
if getattr(identity, key, None) is not None
}
)
return as_dict
if sqla_2:
from sqlalchemy.sql.base import _NoneName
else:
from sqlalchemy.util import symbol as _NoneName # type: ignore[assignment]
_ConstraintName = Union[None, str, _NoneName]
_ConstraintNameDefined = Union[str, _NoneName]
def constraint_name_defined(
name: _ConstraintName,
) -> TypeGuard[_ConstraintNameDefined]:
return name is _NONE_NAME or isinstance(name, (str, _NoneName))
def constraint_name_string(name: _ConstraintName) -> TypeGuard[str]:
return isinstance(name, str)
def constraint_name_or_none(name: _ConstraintName) -> Optional[str]:
return name if constraint_name_string(name) else None
AUTOINCREMENT_DEFAULT = "auto"
@contextlib.contextmanager
def _ensure_scope_for_ddl(
connection: Optional[Connection],
) -> Iterator[None]:
try:
in_transaction = connection.in_transaction # type: ignore[union-attr]
except AttributeError:
# catch for MockConnection, None
in_transaction = None
pass
# yield outside the catch
if in_transaction is None:
yield
else:
if not in_transaction():
assert connection is not None
with connection.begin():
yield
else:
yield
def _safe_begin_connection_transaction(
connection: Connection,
) -> Transaction:
transaction = connection.get_transaction()
if transaction:
return transaction
else:
return connection.begin()
def _safe_commit_connection_transaction(
connection: Connection,
) -> None:
transaction = connection.get_transaction()
if transaction:
transaction.commit()
def _safe_rollback_connection_transaction(
connection: Connection,
) -> None:
transaction = connection.get_transaction()
if transaction:
transaction.rollback()
def _get_connection_in_transaction(connection: Optional[Connection]) -> bool:
try:
in_transaction = connection.in_transaction # type: ignore
except AttributeError:
# catch for MockConnection
return False
else:
return in_transaction()
def _idx_table_bound_expressions(idx: Index) -> Iterable[ColumnElement[Any]]:
return idx.expressions # type: ignore
def _copy(schema_item: _CE, **kw) -> _CE:
if hasattr(schema_item, "_copy"):
return schema_item._copy(**kw)
else:
return schema_item.copy(**kw) # type: ignore[union-attr]
def _connectable_has_table(
connectable: Connection, tablename: str, schemaname: Union[str, None]
) -> bool:
return connectable.dialect.has_table(connectable, tablename, schemaname)
def _exec_on_inspector(inspector, statement, **params):
with inspector._operation_context() as conn:
return conn.execute(statement, params)
def _nullability_might_be_unset(metadata_column):
from sqlalchemy.sql import schema
return metadata_column._user_defined_nullable is schema.NULL_UNSPECIFIED
def _server_default_is_computed(*server_default) -> bool:
return any(isinstance(sd, schema.Computed) for sd in server_default)
def _server_default_is_identity(*server_default) -> bool:
return any(isinstance(sd, schema.Identity) for sd in server_default)
def _table_for_constraint(constraint: Constraint) -> Table:
if isinstance(constraint, ForeignKeyConstraint):
table = constraint.parent
assert table is not None
return table # type: ignore[return-value]
else:
return constraint.table
def _columns_for_constraint(constraint):
if isinstance(constraint, ForeignKeyConstraint):
return [fk.parent for fk in constraint.elements]
elif isinstance(constraint, CheckConstraint):
return _find_columns(constraint.sqltext)
else:
return list(constraint.columns)
def _resolve_for_variant(type_, dialect):
if _type_has_variants(type_):
base_type, mapping = _get_variant_mapping(type_)
return mapping.get(dialect.name, base_type)
else:
return type_
if hasattr(sqltypes.TypeEngine, "_variant_mapping"): # 2.0
def _type_has_variants(type_):
return bool(type_._variant_mapping)
def _get_variant_mapping(type_):
return type_, type_._variant_mapping
else:
def _type_has_variants(type_):
return type(type_) is sqltypes.Variant
def _get_variant_mapping(type_):
return type_.impl, type_.mapping
def _fk_spec(constraint: ForeignKeyConstraint) -> Any:
if TYPE_CHECKING:
assert constraint.columns is not None
assert constraint.elements is not None
assert isinstance(constraint.parent, Table)
source_columns = [
constraint.columns[key].name for key in constraint.column_keys
]
source_table = constraint.parent.name
source_schema = constraint.parent.schema
target_schema = constraint.elements[0].column.table.schema
target_table = constraint.elements[0].column.table.name
target_columns = [element.column.name for element in constraint.elements]
ondelete = constraint.ondelete
onupdate = constraint.onupdate
deferrable = constraint.deferrable
initially = constraint.initially
return (
source_schema,
source_table,
source_columns,
target_schema,
target_table,
target_columns,
onupdate,
ondelete,
deferrable,
initially,
)
def _fk_is_self_referential(constraint: ForeignKeyConstraint) -> bool:
spec = constraint.elements[0]._get_colspec()
tokens = spec.split(".")
tokens.pop(-1) # colname
tablekey = ".".join(tokens)
assert constraint.parent is not None
return tablekey == constraint.parent.key
def _is_type_bound(constraint: Constraint) -> bool:
# this deals with SQLAlchemy #3260, don't copy CHECK constraints
# that will be generated by the type.
# new feature added for #3260
return constraint._type_bound
def _find_columns(clause):
"""locate Column objects within the given expression."""
cols: Set[ColumnElement[Any]] = set()
traverse(clause, {}, {"column": cols.add})
return cols
def _remove_column_from_collection(
collection: ColumnCollection, column: Union[Column[Any], ColumnClause[Any]]
) -> None:
"""remove a column from a ColumnCollection."""
# workaround for older SQLAlchemy, remove the
# same object that's present
assert column.key is not None
to_remove = collection[column.key]
# SQLAlchemy 2.0 will use more ReadOnlyColumnCollection
# (renamed from ImmutableColumnCollection)
if hasattr(collection, "_immutable") or hasattr(collection, "_readonly"):
collection._parent.remove(to_remove)
else:
collection.remove(to_remove)
def _textual_index_column(
table: Table, text_: Union[str, TextClause, ColumnElement[Any]]
) -> Union[ColumnElement[Any], Column[Any]]:
"""a workaround for the Index construct's severe lack of flexibility"""
if isinstance(text_, str):
c = Column(text_, sqltypes.NULLTYPE)
table.append_column(c)
return c
elif isinstance(text_, TextClause):
return _textual_index_element(table, text_)
elif isinstance(text_, _textual_index_element):
return _textual_index_column(table, text_.text)
elif isinstance(text_, sql.ColumnElement):
return _copy_expression(text_, table)
else:
raise ValueError("String or text() construct expected")
def _copy_expression(expression: _CE, target_table: Table) -> _CE:
def replace(col):
if (
isinstance(col, Column)
and col.table is not None
and col.table is not target_table
):
if col.name in target_table.c:
return target_table.c[col.name]
else:
c = _copy(col)
target_table.append_column(c)
return c
else:
return None
return visitors.replacement_traverse( # type: ignore[call-overload]
expression, {}, replace
)
class _textual_index_element(sql.ColumnElement):
"""Wrap around a sqlalchemy text() construct in such a way that
we appear like a column-oriented SQL expression to an Index
construct.
The issue here is that currently the Postgresql dialect, the biggest
recipient of functional indexes, keys all the index expressions to
the corresponding column expressions when rendering CREATE INDEX,
so the Index we create here needs to have a .columns collection that
is the same length as the .expressions collection. Ultimately
SQLAlchemy should support text() expressions in indexes.
See SQLAlchemy issue 3174.
"""
__visit_name__ = "_textual_idx_element"
def __init__(self, table: Table, text: TextClause) -> None:
self.table = table
self.text = text
self.key = text.text
self.fake_column = schema.Column(self.text.text, sqltypes.NULLTYPE)
table.append_column(self.fake_column)
def get_children(self, **kw):
return [self.fake_column]
@compiles(_textual_index_element)
def _render_textual_index_column(
element: _textual_index_element, compiler: SQLCompiler, **kw
) -> str:
return compiler.process(element.text, **kw)
class _literal_bindparam(BindParameter):
pass
@compiles(_literal_bindparam)
def _render_literal_bindparam(
element: _literal_bindparam, compiler: SQLCompiler, **kw
) -> str:
return compiler.render_literal_bindparam(element, **kw)
def _get_constraint_final_name(
constraint: Union[Index, Constraint], dialect: Optional[Dialect]
) -> Optional[str]:
if constraint.name is None:
return None
assert dialect is not None
# for SQLAlchemy 1.4 we would like to have the option to expand
# the use of "deferred" names for constraints as well as to have
# some flexibility with "None" name and similar; make use of new
# SQLAlchemy API to return what would be the final compiled form of
# the name for this dialect.
return dialect.identifier_preparer.format_constraint(
constraint, _alembic_quote=False
)
def _constraint_is_named(
constraint: Union[Constraint, Index], dialect: Optional[Dialect]
) -> bool:
if constraint.name is None:
return False
assert dialect is not None
name = dialect.identifier_preparer.format_constraint(
constraint, _alembic_quote=False
)
return name is not None
def is_expression_index(index: Index) -> bool:
for expr in index.expressions:
if is_expression(expr):
return True
return False
def is_expression(expr: Any) -> bool:
while isinstance(expr, UnaryExpression):
expr = expr.element
if not isinstance(expr, ColumnClause) or expr.is_literal:
return True
return False