Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Refactor tests oct2022 #253

Merged
merged 7 commits into from
Oct 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ jobs:
env:
DATADIFF_SNOWFLAKE_URI: '${{ secrets.DATADIFF_SNOWFLAKE_URI }}'
DATADIFF_PRESTO_URI: '${{ secrets.DATADIFF_PRESTO_URI }}'
DATADIFF_TRINO_URI: '${{ secrets.DATADIFF_TRINO_URI }}'
DATADIFF_CLICKHOUSE_URI: 'clickhouse://clickhouse:Password1@localhost:9000/clickhouse'
DATADIFF_VERTICA_URI: 'vertica://vertica:Password1@localhost:5433/vertica'
run: |
Expand Down
8 changes: 7 additions & 1 deletion data_diff/databases/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import datetime
import math
import sys
import logging
Expand Down Expand Up @@ -120,6 +121,10 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = list):
compiler = Compiler(self)
if isinstance(sql_ast, Generator):
sql_code = ThreadLocalInterpreter(compiler, sql_ast)
elif isinstance(sql_ast, list):
for i in sql_ast[:-1]:
self.query(i)
return self.query(sql_ast[-1], res_type)
else:
sql_code = compiler.compile(sql_ast)
if sql_code is SKIP:
Expand Down Expand Up @@ -249,7 +254,7 @@ def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType], whe
if not text_columns:
return

fields = [self.normalize_uuid(c, String_UUID()) for c in text_columns]
fields = [self.normalize_uuid(self.quote(c), String_UUID()) for c in text_columns]
samples_by_row = self.query(table(*table_path).select(*fields).where(where or SKIP).limit(sample_size), list)
if not samples_by_row:
raise ValueError(f"Table {table_path} is empty.")
Expand Down Expand Up @@ -329,6 +334,7 @@ def type_repr(self, t) -> str:
str: "VARCHAR",
bool: "BOOLEAN",
float: "FLOAT",
datetime: "TIMESTAMP",
}[t]

def _query_cursor(self, c, sql_code: str):
Expand Down
4 changes: 1 addition & 3 deletions data_diff/databases/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,6 @@ def is_autocommit(self) -> bool:

def type_repr(self, t) -> str:
try:
return {
str: "STRING",
}[t]
return {str: "STRING", float: "FLOAT64"}[t]
except KeyError:
return super().type_repr(t)
4 changes: 4 additions & 0 deletions data_diff/databases/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,7 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str:
)
"""
return value

@property
def is_autocommit(self) -> bool:
return True
4 changes: 4 additions & 0 deletions data_diff/databases/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,7 @@ def parse_table_name(self, name: str) -> DbPath:

def close(self):
self._conn.close()

@property
def is_autocommit(self) -> bool:
return True
6 changes: 6 additions & 0 deletions data_diff/databases/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,9 @@ def is_autocommit(self) -> bool:

def explain_as_text(self, query: str) -> str:
return f"EXPLAIN (FORMAT TEXT) {query}"

def type_repr(self, t) -> str:
try:
return {float: "REAL"}[t]
except KeyError:
return super().type_repr(t)
50 changes: 5 additions & 45 deletions data_diff/joindiff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

from runtype import dataclass

from data_diff.databases.database_types import DbPath, NumericType
from data_diff.databases.base import QueryError
from .databases.database_types import DbPath, NumericType
from .query_utils import append_to_table, drop_table


from .utils import safezip
Expand Down Expand Up @@ -48,7 +48,7 @@ def sample(table_expr):
return table_expr.order_by(Random()).limit(10)


def create_temp_table(c: Compiler, path: TablePath, expr: Expr):
def create_temp_table(c: Compiler, path: TablePath, expr: Expr) -> str:
db = c.database
if isinstance(db, BigQuery):
return f"create table {c.compile(path)} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {c.compile(expr)}"
Expand All @@ -60,42 +60,6 @@ def create_temp_table(c: Compiler, path: TablePath, expr: Expr):
return f"create temporary table {c.compile(path)} as {c.compile(expr)}"


def drop_table_oracle(name: DbPath):
t = table(name)
# Experience shows double drop is necessary
with suppress(QueryError):
yield t.drop()
yield t.drop()
yield commit


def drop_table(name: DbPath):
t = table(name)
yield t.drop(if_exists=True)
yield commit


def append_to_table_oracle(path: DbPath, expr: Expr):
"""See append_to_table"""
assert expr.schema, expr
t = table(path, schema=expr.schema)
with suppress(QueryError):
yield t.create() # uses expr.schema
yield commit
yield t.insert_expr(expr)
yield commit


def append_to_table(path: DbPath, expr: Expr):
"""Append to table"""
assert expr.schema, expr
t = table(path, schema=expr.schema)
yield t.create(if_not_exists=True) # uses expr.schema
yield commit
yield t.insert_expr(expr)
yield commit


def bool_to_int(x):
return if_(x, 1, 0)

Expand Down Expand Up @@ -170,10 +134,7 @@ def _diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult

bg_funcs = [partial(self._test_duplicate_keys, table1, table2)] if self.validate_unique_key else []
if self.materialize_to_table:
if isinstance(db, Oracle):
db.query(drop_table_oracle(self.materialize_to_table))
else:
db.query(drop_table(self.materialize_to_table))
drop_table(db, self.materialize_to_table)

with self._run_in_background(*bg_funcs):

Expand Down Expand Up @@ -348,6 +309,5 @@ def exclusive_rows(expr):
def _materialize_diff(self, db, diff_rows, segment_index=None):
assert self.materialize_to_table

f = append_to_table_oracle if isinstance(db, Oracle) else append_to_table
db.query(f(self.materialize_to_table, diff_rows.limit(self.write_limit)))
append_to_table(db, self.materialize_to_table, diff_rows.limit(self.write_limit))
logger.info("Materialized diff to table '%s'.", ".".join(self.materialize_to_table))
7 changes: 6 additions & 1 deletion data_diff/queries/api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Optional

from data_diff.utils import CaseAwareMapping, CaseSensitiveDict
from .ast_classes import *
from .base import args_as_tuple

Expand Down Expand Up @@ -30,11 +32,14 @@ def cte(expr: Expr, *, name: Optional[str] = None, params: Sequence[str] = None)
return Cte(expr, name, params)


def table(*path: str, schema: Schema = None) -> TablePath:
def table(*path: str, schema: Union[dict, CaseAwareMapping] = None) -> TablePath:
if len(path) == 1 and isinstance(path[0], tuple):
(path,) = path
if not all(isinstance(i, str) for i in path):
raise TypeError(f"All elements of table path must be of type 'str'. Got: {path}")
if schema and not isinstance(schema, CaseAwareMapping):
assert isinstance(schema, dict)
schema = CaseSensitiveDict(schema)
return TablePath(path, schema)


Expand Down
73 changes: 64 additions & 9 deletions data_diff/queries/ast_classes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import field
from datetime import datetime
from typing import Any, Generator, Optional, Sequence, Tuple, Union
from typing import Any, Generator, List, Optional, Sequence, Tuple, Union
from uuid import UUID

from runtype import dataclass

Expand Down Expand Up @@ -298,18 +299,29 @@ class TablePath(ExprNode, ITable):
path: DbPath
schema: Optional[Schema] = field(default=None, repr=False)

def create(self, if_not_exists=False):
if not self.schema:
raise ValueError("Schema must have a value to create table")
return CreateTable(self, if_not_exists=if_not_exists)
def create(self, source_table: ITable = None, *, if_not_exists=False):
if source_table is None and not self.schema:
raise ValueError("Either schema or source table needed to create table")
if isinstance(source_table, TablePath):
source_table = source_table.select()
return CreateTable(self, source_table, if_not_exists=if_not_exists)

def drop(self, if_exists=False):
return DropTable(self, if_exists=if_exists)

def insert_values(self, rows):
raise NotImplementedError()
def truncate(self):
return TruncateTable(self)

def insert_rows(self, rows, *, columns=None):
rows = list(rows)
return InsertToTable(self, ConstantTable(rows), columns=columns)

def insert_row(self, *values, columns=None):
return InsertToTable(self, ConstantTable([values]), columns=columns)

def insert_expr(self, expr: Expr):
if isinstance(expr, TablePath):
expr = expr.select()
return InsertToTable(self, expr)

@property
Expand Down Expand Up @@ -592,6 +604,29 @@ def compile(self, c: Compiler) -> str:
return c.database.random()


@dataclass
class ConstantTable(ExprNode):
rows: Sequence[Sequence]

def compile(self, c: Compiler) -> str:
raise NotImplementedError()

def _value(self, v):
if v is None:
return "NULL"
elif isinstance(v, str):
return f"'{v}'"
elif isinstance(v, datetime):
return f"timestamp '{v}'"
elif isinstance(v, UUID):
return f"'{v}'"
return repr(v)

def compile_for_insert(self, c: Compiler):
values = ", ".join("(%s)" % ", ".join(self._value(v) for v in row) for row in self.rows)
return f"VALUES {values}"


@dataclass
class Explain(ExprNode):
select: Select
Expand All @@ -610,11 +645,15 @@ class Statement(Compilable):
@dataclass
class CreateTable(Statement):
path: TablePath
source_table: Expr = None
if_not_exists: bool = False

def compile(self, c: Compiler) -> str:
schema = ", ".join(f"{k} {c.database.type_repr(v)}" for k, v in self.path.schema.items())
ne = "IF NOT EXISTS " if self.if_not_exists else ""
if self.source_table:
return f"CREATE TABLE {ne}{c.compile(self.path)} AS {c.compile(self.source_table)}"

schema = ", ".join(f"{c.database.quote(k)} {c.database.type_repr(v)}" for k, v in self.path.schema.items())
return f"CREATE TABLE {ne}{c.compile(self.path)}({schema})"


Expand All @@ -628,14 +667,30 @@ def compile(self, c: Compiler) -> str:
return f"DROP TABLE {ie}{c.compile(self.path)}"


@dataclass
class TruncateTable(Statement):
path: TablePath

def compile(self, c: Compiler) -> str:
return f"TRUNCATE TABLE {c.compile(self.path)}"


@dataclass
class InsertToTable(Statement):
# TODO Support insert for only some columns
path: TablePath
expr: Expr
columns: List[str] = None

def compile(self, c: Compiler) -> str:
return f"INSERT INTO {c.compile(self.path)} {c.compile(self.expr)}"
if isinstance(self.expr, ConstantTable):
expr = self.expr.compile_for_insert(c)
else:
expr = c.compile(self.expr)

columns = f"(%s)" % ", ".join(map(c.quote, self.columns)) if self.columns is not None else ""

return f"INSERT INTO {c.compile(self.path)}{columns} {expr}"


@dataclass
Expand Down
55 changes: 55 additions & 0 deletions data_diff/query_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"Module for query utilities that didn't make it into the query-builder (yet)"

from contextlib import suppress

from data_diff.databases.database_types import DbPath
from data_diff.databases.base import QueryError

from .databases import Oracle
from .queries import table, commit, Expr

def _drop_table_oracle(name: DbPath):
t = table(name)
# Experience shows double drop is necessary
with suppress(QueryError):
yield t.drop()
yield t.drop()
yield commit


def _drop_table(name: DbPath):
t = table(name)
yield t.drop(if_exists=True)
yield commit


def drop_table(db, tbl):
if isinstance(db, Oracle):
db.query(_drop_table_oracle(tbl))
else:
db.query(_drop_table(tbl))


def _append_to_table_oracle(path: DbPath, expr: Expr):
"""See append_to_table"""
assert expr.schema, expr
t = table(path, schema=expr.schema)
with suppress(QueryError):
yield t.create() # uses expr.schema
yield commit
yield t.insert_expr(expr)
yield commit


def _append_to_table(path: DbPath, expr: Expr):
"""Append to table"""
assert expr.schema, expr
t = table(path, schema=expr.schema)
yield t.create(if_not_exists=True) # uses expr.schema
yield commit
yield t.insert_expr(expr)
yield commit

def append_to_table(db, path, expr):
f = _append_to_table_oracle if isinstance(db, Oracle) else _append_to_table
db.query(f(path, expr))
Loading