Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Refactor dialect #271

Merged
merged 4 commits into from
Nov 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion data_diff/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ def _apply_config(config: Dict[str, Any], run_name: str, kw: Dict[str, Any]):
try:
args = run_args.pop(index)
except KeyError:
raise ConfigParseError(f"Could not find source #{index}: Expecting a key of '{index}' containing '.database' and '.table'.")
raise ConfigParseError(
f"Could not find source #{index}: Expecting a key of '{index}' containing '.database' and '.table'."
)
for attr in ("database", "table"):
if attr not in args:
raise ConfigParseError(f"Running 'run.{run_name}': Connection #{index} is missing attribute '{attr}'.")
Expand Down
228 changes: 119 additions & 109 deletions data_diff/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
from data_diff.queries import Expr, Compiler, table, Select, SKIP, Explain
from .database_types import (
AbstractDatabase,
AbstractDialect,
AbstractMixin_MD5,
AbstractMixin_NormalizeValue,
ColType,
Integer,
Decimal,
Expand Down Expand Up @@ -99,6 +102,116 @@ def apply_query(callback: Callable[[str], Any], sql_code: Union[str, ThreadLocal
return callback(sql_code)


class BaseDialect(AbstractDialect, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
    """Default implementations for rendering SQL fragments and parsing column types.

    The methods defined here build SQL text (LIMIT clauses, literals, VALUES
    lists, ...) and map database type names to ``ColType`` instances.
    """

    SUPPORTS_PRIMARY_KEY = False
    # Type name (as given to parse_type) -> ColType class. Empty by default;
    # looked up by _parse_type_repr().
    TYPE_CLASSES: Dict[str, type] = {}

    def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
        """Render a ``LIMIT`` clause for ``limit``.

        Raises:
            NotImplementedError: if a truthy ``offset`` is given; this
                implementation only renders LIMIT.
        """
        if offset:
            raise NotImplementedError("No support for OFFSET in query")

        return f"LIMIT {limit}"

    def concat(self, items: List[str]) -> str:
        """Render ``concat(a, b, ...)`` over the given SQL expressions.

        Asserts that at least two items are supplied.
        """
        assert len(items) > 1
        joined_exprs = ", ".join(items)
        return f"concat({joined_exprs})"

    def is_distinct_from(self, a: str, b: str) -> str:
        """Render ``a is distinct from b`` for two SQL expressions."""
        return f"{a} is distinct from {b}"

    def timestamp_value(self, t: DbTime) -> str:
        """Render ``t`` as a single-quoted string built from ``t.isoformat()``."""
        return f"'{t.isoformat()}'"

    def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
        """Render an expression that normalizes a UUID column value.

        ``String_UUID`` columns are wrapped in ``TRIM(...)``; any other
        coltype is passed to ``self.to_string``.
        """
        if isinstance(coltype, String_UUID):
            return f"TRIM({value})"
        # NOTE(review): to_string is not defined in this class; presumably it
        # is supplied by one of the base classes - confirm.
        return self.to_string(value)

    def random(self) -> str:
        """Return the SQL expression ``RANDOM()``."""
        return "RANDOM()"

    def explain_as_text(self, query: str) -> str:
        """Prefix ``query`` with ``EXPLAIN``."""
        return f"EXPLAIN {query}"

    def _constant_value(self, v):
        """Render a single Python value as a SQL literal.

        ``None`` becomes ``NULL``; ``str`` and ``UUID`` values are
        single-quoted; ``datetime`` values become ``timestamp '...'``;
        anything else is rendered with ``repr``.
        """
        if v is None:
            return "NULL"
        elif isinstance(v, str):
            # Double embedded single quotes (standard SQL escaping) so that a
            # value such as O'Brien cannot terminate the literal early.
            # Dialects with a different escaping rule should override this.
            escaped = v.replace("'", "''")
            return f"'{escaped}'"
        elif isinstance(v, datetime):
            # TODO use self.timestamp_value
            return f"timestamp '{v}'"
        elif isinstance(v, UUID):
            return f"'{v}'"
        return repr(v)

    def constant_values(self, rows) -> str:
        """Render ``rows`` (an iterable of row iterables) as a ``VALUES`` list."""
        values = ", ".join("(%s)" % ", ".join(self._constant_value(v) for v in row) for row in rows)
        return f"VALUES {values}"

    def type_repr(self, t) -> str:
        """Return the SQL type name for ``t``.

        Strings are returned unchanged. Otherwise ``t`` must be one of
        ``int``, ``str``, ``bool``, ``float`` or ``datetime``; any other
        value raises ``KeyError``.
        """
        if isinstance(t, str):
            return t
        return {
            int: "INT",
            str: "VARCHAR",
            bool: "BOOLEAN",
            float: "FLOAT",
            datetime: "TIMESTAMP",
        }[t]

    def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]:
        """Look up ``type_repr`` in ``TYPE_CLASSES``; return None when absent."""
        return self.TYPE_CLASSES.get(type_repr)

    def parse_type(
        self,
        table_path: DbPath,
        col_name: str,
        type_repr: str,
        datetime_precision: Optional[int] = None,
        numeric_precision: Optional[int] = None,
        numeric_scale: Optional[int] = None,
    ) -> ColType:
        """Build the ``ColType`` instance describing one column.

        Args:
            table_path: Not used by this implementation.
            col_name: Not used by this implementation.
            type_repr: Type name, resolved through ``_parse_type_repr``.
            datetime_precision: Precision for temporal types; falls back to
                ``DEFAULT_DATETIME_PRECISION`` when None.
            numeric_precision: Precision for float types, converted with
                ``_convert_db_precision_to_digits``; falls back to
                ``DEFAULT_NUMERIC_PRECISION`` when None.
            numeric_scale: Passed as ``precision`` for decimal types; None is
                treated as 0.

        Returns:
            ``UnknownColType(type_repr)`` when the type name is not
            recognised, otherwise an instance of the resolved class.

        Raises:
            TypeError: if the resolved class is none of the handled
                ColType families.
        """

        cls = self._parse_type_repr(type_repr)
        if not cls:
            return UnknownColType(type_repr)

        if issubclass(cls, TemporalType):
            return cls(
                precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION,
                rounds=self.ROUNDS_ON_PREC_LOSS,
            )

        elif issubclass(cls, Integer):
            return cls()

        elif issubclass(cls, Decimal):
            if numeric_scale is None:
                numeric_scale = 0  # Needed for Oracle.
            return cls(precision=numeric_scale)

        elif issubclass(cls, Float):
            # assert numeric_scale is None
            return cls(
                precision=self._convert_db_precision_to_digits(
                    numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION
                )
            )

        elif issubclass(cls, (Text, Native_UUID)):
            return cls()

        raise TypeError(f"Parsing {type_repr} returned an unknown type '{cls}'.")

    def _convert_db_precision_to_digits(self, p: int) -> int:
        """Convert from binary precision, used by floats, to decimal precision.

        Computes ``floor(log10(2**p))``.
        """
        # See: https://en.wikipedia.org/wiki/Single-precision_floating-point_format
        return math.floor(math.log(2**p, 10))


class Database(AbstractDatabase):
"""Base abstract class for databases.

Expand All @@ -107,10 +220,10 @@ class Database(AbstractDatabase):
Instantiated using :meth:`~data_diff.connect`
"""

TYPE_CLASSES: Dict[str, type] = {}
default_schema: str = None
dialect: AbstractDialect = None

SUPPORTS_ALPHANUMS = True
SUPPORTS_PRIMARY_KEY = False
SUPPORTS_UNIQUE_CONSTAINT = False

_interactive = False
Expand Down Expand Up @@ -169,56 +282,6 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = list):
def enable_interactive(self):
    """Set the ``_interactive`` flag on this instance to True."""
    self._interactive = True

def _convert_db_precision_to_digits(self, p: int) -> int:
"""Convert from binary precision, used by floats, to decimal precision."""
# See: https://en.wikipedia.org/wiki/Single-precision_floating-point_format
return math.floor(math.log(2**p, 10))

def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]:
    """Look up ``type_repr`` in ``self.TYPE_CLASSES``; return None when it is absent."""
    return self.TYPE_CLASSES.get(type_repr)

def _parse_type(
    self,
    table_path: DbPath,
    col_name: str,
    type_repr: str,
    datetime_precision: int = None,
    numeric_precision: int = None,
    numeric_scale: int = None,
) -> ColType:
    """Build the ``ColType`` instance describing one column.

    ``table_path`` and ``col_name`` are accepted but not used here. The type
    name is resolved with ``_parse_type_repr``; an unrecognised name yields
    ``UnknownColType(type_repr)``. A resolved class that belongs to none of
    the handled families raises ``TypeError``.
    """
    col_cls = self._parse_type_repr(type_repr)
    if not col_cls:
        return UnknownColType(type_repr)

    if issubclass(col_cls, TemporalType):
        if datetime_precision is None:
            datetime_precision = DEFAULT_DATETIME_PRECISION
        return col_cls(precision=datetime_precision, rounds=self.ROUNDS_ON_PREC_LOSS)

    if issubclass(col_cls, Integer):
        return col_cls()

    if issubclass(col_cls, Decimal):
        # A missing scale is treated as 0. Needed for Oracle.
        scale = 0 if numeric_scale is None else numeric_scale
        return col_cls(precision=scale)

    if issubclass(col_cls, Float):
        # assert numeric_scale is None
        binary_precision = DEFAULT_NUMERIC_PRECISION if numeric_precision is None else numeric_precision
        return col_cls(precision=self._convert_db_precision_to_digits(binary_precision))

    if issubclass(col_cls, (Text, Native_UUID)):
        return col_cls()

    raise TypeError(f"Parsing {type_repr} returned an unknown type '{col_cls}'.")

def select_table_schema(self, path: DbPath) -> str:
schema, table = self._normalize_table_path(path)

Expand Down Expand Up @@ -257,7 +320,9 @@ def _process_table_schema(
):
accept = {i.lower() for i in filter_columns}

col_dict = {row[0]: self._parse_type(path, *row) for name, row in raw_schema.items() if name.lower() in accept}
col_dict = {
row[0]: self.dialect.parse_type(path, *row) for name, row in raw_schema.items() if name.lower() in accept
}

self._refine_coltypes(path, col_dict, where)

Expand All @@ -274,7 +339,7 @@ def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType], whe
if not text_columns:
return

fields = [self.normalize_uuid(self.quote(c), String_UUID()) for c in text_columns]
fields = [self.dialect.normalize_uuid(self.dialect.quote(c), String_UUID()) for c in text_columns]
samples_by_row = self.query(table(*table_path).select(*fields).where(where or SKIP).limit(sample_size), list)
if not samples_by_row:
raise ValueError(f"Table {table_path} is empty.")
Expand Down Expand Up @@ -321,58 +386,6 @@ def _normalize_table_path(self, path: DbPath) -> DbPath:
def parse_table_name(self, name: str) -> DbPath:
    """Convert ``name`` into a ``DbPath`` by delegating to ``parse_table_name``."""
    # NOTE(review): the bare name below is not this method; presumably it is a
    # module-level helper of the same name imported elsewhere in the file - confirm.
    return parse_table_name(name)

def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
    """Render a ``LIMIT`` clause; a truthy ``offset`` is rejected.

    Raises:
        NotImplementedError: when ``offset`` is truthy.
    """
    if not offset:
        return f"LIMIT {limit}"

    raise NotImplementedError("No support for OFFSET in query")

def concat(self, items: List[str]) -> str:
    """Render ``concat(a, b, ...)`` over the given SQL expressions.

    Asserts that more than one item is supplied.
    """
    assert len(items) > 1
    return "concat(" + ", ".join(items) + ")"

def is_distinct_from(self, a: str, b: str) -> str:
    """Render ``a is distinct from b`` for two SQL expressions."""
    return f"{a} is distinct from {b}"

def timestamp_value(self, t: DbTime) -> str:
    """Render ``t`` as a single-quoted string built from ``t.isoformat()``."""
    return f"'{t.isoformat()}'"

def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
    """Render an expression that normalizes a UUID column value.

    ``String_UUID`` columns are wrapped in ``TRIM(...)``; every other coltype
    is handed to ``self.to_string``.
    """
    if not isinstance(coltype, String_UUID):
        return self.to_string(value)
    return f"TRIM({value})"

def random(self) -> str:
    """Return the SQL expression ``RANDOM()``."""
    return "RANDOM()"

def _constant_value(self, v):
if v is None:
return "NULL"
elif isinstance(v, str):
return f"'{v}'"
elif isinstance(v, datetime):
# TODO use self.timestamp_value
return f"timestamp '{v}'"
elif isinstance(v, UUID):
return f"'{v}'"
return repr(v)

def constant_values(self, rows) -> str:
    """Render ``rows`` (an iterable of row iterables) as a SQL ``VALUES`` list.

    Each cell is rendered with ``self._constant_value``.
    """
    rendered_rows = []
    for row in rows:
        cells = ", ".join(self._constant_value(cell) for cell in row)
        rendered_rows.append(f"({cells})")
    return "VALUES " + ", ".join(rendered_rows)

def type_repr(self, t) -> str:
    """Return the SQL type name for ``t``.

    A string is returned unchanged; otherwise ``t`` must be one of ``int``,
    ``str``, ``bool``, ``float`` or ``datetime``, and anything else raises
    ``KeyError``.
    """
    if not isinstance(t, str):
        sql_names = {
            int: "INT",
            str: "VARCHAR",
            bool: "BOOLEAN",
            float: "FLOAT",
            datetime: "TIMESTAMP",
        }
        return sql_names[t]
    return t

def _query_cursor(self, c, sql_code: str):
assert isinstance(sql_code, str), sql_code
try:
Expand All @@ -389,9 +402,6 @@ def _query_conn(self, conn, sql_code: Union[str, ThreadLocalInterpreter]) -> lis
callback = partial(self._query_cursor, c)
return apply_query(callback, sql_code)

def explain_as_text(self, query: str) -> str:
    """Prefix ``query`` with ``EXPLAIN``."""
    return f"EXPLAIN {query}"


class ThreadedDatabase(Database):
"""Access the database through singleton threads.
Expand Down
Loading