This repository was archived by the owner on May 17, 2024. It is now read-only.

tests for unique key constraints (if possible) instead of always actively validating (+ tests) #257

Merged · 5 commits · Nov 1, 2022
17 changes: 17 additions & 0 deletions data_diff/databases/base.py
@@ -110,6 +110,8 @@ class Database(AbstractDatabase):
TYPE_CLASSES: Dict[str, type] = {}
default_schema: str = None
SUPPORTS_ALPHANUMS = True
SUPPORTS_PRIMARY_KEY = False
SUPPORTS_UNIQUE_CONSTAINT = False

_interactive = False

@@ -235,6 +237,21 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
assert len(d) == len(rows)
return d

def select_table_unique_columns(self, path: DbPath) -> str:
schema, table = self._normalize_table_path(path)

return (
"SELECT column_name "
"FROM information_schema.key_column_usage "
f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
)

def query_table_unique_columns(self, path: DbPath) -> List[str]:
if not self.SUPPORTS_UNIQUE_CONSTAINT:
raise NotImplementedError("This database doesn't support 'unique' constraints")
res = self.query(self.select_table_unique_columns(path), List[str])
return list(res)

def _process_table_schema(
self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None
):
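For orientation, a hedged sketch of how the two new hooks interact; the table path ("public", "ratings") below is illustrative, not from this PR:

# Illustrative sketch only -- the path ("public", "ratings") is hypothetical.
# For that path, select_table_unique_columns() renders roughly:
#
#     SELECT column_name
#     FROM information_schema.key_column_usage
#     WHERE table_name = 'ratings' AND table_schema = 'public'
#
# query_table_unique_columns() executes it and returns e.g. ["id"], while
# dialects that leave SUPPORTS_UNIQUE_CONSTAINT = False raise NotImplementedError.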
5 changes: 4 additions & 1 deletion data_diff/databases/bigquery.py
@@ -1,4 +1,4 @@
from typing import Union
from typing import List, Union
from .database_types import Timestamp, Datetime, Integer, Decimal, Float, Text, DbPath, FractionalType, TemporalType
from .base import Database, import_helper, parse_table_name, ConnectError, apply_query
from .base import TIMESTAMP_PRECISION_POS, ThreadLocalInterpreter
@@ -78,6 +78,9 @@ def select_table_schema(self, path: DbPath) -> str:
f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
)

def query_table_unique_columns(self, path: DbPath) -> List[str]:
return []

def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
if coltype.rounds:
timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))"
10 changes: 10 additions & 0 deletions data_diff/databases/database_types.py
@@ -254,6 +254,16 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
"""
...

@abstractmethod
def select_table_unique_columns(self, path: DbPath) -> str:
"Provide SQL for selecting the names of unique columns in the table"
...

@abstractmethod
def query_table_unique_columns(self, path: DbPath) -> List[str]:
"""Query the table for its unique columns for table in 'path', and return {column}"""
...

@abstractmethod
def _process_table_schema(
self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None
2 changes: 2 additions & 0 deletions data_diff/databases/mysql.py
@@ -39,6 +39,8 @@ class MySQL(ThreadedDatabase):
}
ROUNDS_ON_PREC_LOSS = True
SUPPORTS_ALPHANUMS = False
SUPPORTS_PRIMARY_KEY = True
SUPPORTS_UNIQUE_CONSTAINT = True

def __init__(self, *, thread_count, **kw):
self._args = kw
1 change: 1 addition & 0 deletions data_diff/databases/oracle.py
@@ -38,6 +38,7 @@ class Oracle(ThreadedDatabase):
"VARCHAR2": Text,
}
ROUNDS_ON_PREC_LOSS = True
SUPPORTS_PRIMARY_KEY = True

def __init__(self, *, host, database, thread_count, **kw):
self.kwargs = dict(dsn=f"{host}/{database}" if database else host, **kw)
2 changes: 2 additions & 0 deletions data_diff/databases/postgresql.py
@@ -46,6 +46,8 @@ class PostgreSQL(ThreadedDatabase):
"uuid": Native_UUID,
}
ROUNDS_ON_PREC_LOSS = True
SUPPORTS_PRIMARY_KEY = True
SUPPORTS_UNIQUE_CONSTAINT = True

default_schema = "public"

5 changes: 4 additions & 1 deletion data_diff/databases/snowflake.py
@@ -1,4 +1,4 @@
from typing import Union
from typing import Union, List
import logging

from .database_types import Timestamp, TimestampTZ, Decimal, Float, Text, FractionalType, TemporalType, DbPath
@@ -95,3 +95,6 @@ def is_autocommit(self) -> bool:

def explain_as_text(self, query: str) -> str:
return f"EXPLAIN USING TEXT {query}"

def query_table_unique_columns(self, path: DbPath) -> List[str]:
return []
16 changes: 11 additions & 5 deletions data_diff/joindiff_tables.py
@@ -195,18 +195,24 @@ def _diff_segments(
if not is_xa:
yield "+", tuple(b_row)

def _test_duplicate_keys(self, table1, table2):
def _test_duplicate_keys(self, table1: TableSegment, table2: TableSegment):
logger.debug("Testing for duplicate keys")

# Test duplicate keys
for ts in [table1, table2]:
unique = ts.database.query_table_unique_columns(ts.table_path) if ts.database.SUPPORTS_UNIQUE_CONSTAINT else []

t = ts.make_select()
key_columns = ts.key_columns

q = t.select(total=Count(), total_distinct=Count(Concat(this[key_columns]), distinct=True))
total, total_distinct = ts.database.query(q, tuple)
if total != total_distinct:
raise ValueError("Duplicate primary keys")
unvalidated = list(set(key_columns) - set(unique))
if unvalidated:
# Validate that there are no duplicate keys
self.stats["validated_unique_keys"] = self.stats.get("validated_unique_keys", []) + [unvalidated]
q = t.select(total=Count(), total_distinct=Count(Concat(this[unvalidated]), distinct=True))
total, total_distinct = ts.database.query(q, tuple)
if total != total_distinct:
raise ValueError("Duplicate primary keys")

def _test_null_keys(self, table1, table2):
logger.debug("Testing for null keys")
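For clarity, a minimal self-contained sketch of the new skip logic (the column names are illustrative):

# Minimal sketch of the validation-skip decision; names are illustrative.
key_columns = ["id", "userid"]   # keys used for diffing
unique = ["id"]                  # keys already covered by a unique constraint
unvalidated = list(set(key_columns) - set(unique))
# unvalidated == ["userid"]: only this key still needs the
# COUNT(*) vs COUNT(DISTINCT ...) check. If every key column were covered,
# unvalidated would be empty and the duplicate-key query would be skipped.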
14 changes: 10 additions & 4 deletions data_diff/queries/ast_classes.py
@@ -306,13 +306,13 @@ def compile(self, c: Compiler) -> str:
return ".".join(map(c.quote, path))

# Statement shorthands
def create(self, source_table: ITable = None, *, if_not_exists=False):
def create(self, source_table: ITable = None, *, if_not_exists=False, primary_keys=None):
if source_table is None and not self.schema:
raise ValueError("Either schema or source table needed to create table")
if isinstance(source_table, TablePath):
source_table = source_table.select()
return CreateTable(self, source_table, if_not_exists=if_not_exists)
return CreateTable(self, source_table, if_not_exists=if_not_exists, primary_keys=primary_keys)

def drop(self, if_exists=False):
return DropTable(self, if_exists=if_exists)
@@ -641,14 +641,20 @@ class CreateTable(Statement):
path: TablePath
source_table: Expr = None
if_not_exists: bool = False
primary_keys: List[str] = None

def compile(self, c: Compiler) -> str:
ne = "IF NOT EXISTS " if self.if_not_exists else ""
if self.source_table:
return f"CREATE TABLE {ne}{c.compile(self.path)} AS {c.compile(self.source_table)}"

schema = ", ".join(f"{c.quote(k)} {c.database.type_repr(v)}" for k, v in self.path.schema.items())
return f"CREATE TABLE {ne}{c.compile(self.path)}({schema})"
schema = ", ".join(f"{c.database.quote(k)} {c.database.type_repr(v)}" for k, v in self.path.schema.items())
pks = (
", PRIMARY KEY (%s)" % ", ".join(self.primary_keys)
if self.primary_keys and c.database.SUPPORTS_PRIMARY_KEY
else ""
)
return f"CREATE TABLE {ne}{c.compile(self.path)}({schema}{pks})"


@dataclass
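To show what the new pks branch emits, a standalone sketch (the table name, quoting, and type names here are assumed, not taken from a specific dialect):

# Standalone illustration of the PRIMARY KEY clause; names/types are assumed.
primary_keys = ["id", "userid"]
SUPPORTS_PRIMARY_KEY = True  # stand-in for c.database.SUPPORTS_PRIMARY_KEY

pks = (
    ", PRIMARY KEY (%s)" % ", ".join(primary_keys)
    if primary_keys and SUPPORTS_PRIMARY_KEY
    else ""
)
print(f"CREATE TABLE ratings(id int, userid int{pks})")
# -> CREATE TABLE ratings(id int, userid int, PRIMARY KEY (id, userid))
# With SUPPORTS_PRIMARY_KEY = False, the clause is omitted.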
1 change: 1 addition & 0 deletions tests/common.py
@@ -149,6 +149,7 @@ def tearDown(self):


def _parameterized_class_per_conn(test_databases):
test_databases = set(test_databases)
names = [(cls.__name__, cls) for cls in CONN_STRINGS if cls in test_databases]
return parameterized_class(("name", "db_cls"), names)

46 changes: 46 additions & 0 deletions tests/test_joindiff.py
@@ -273,3 +273,49 @@ def test_null_pks(self):

x = self.differ.diff_tables(self.table, self.table2)
self.assertRaises(ValueError, list, x)


@test_each_database_in_list(d for d in TEST_DATABASES if d.SUPPORTS_PRIMARY_KEY and d.SUPPORTS_UNIQUE_CONSTAINT)
class TestUniqueConstraint(TestPerDatabase):
def setUp(self):
super().setUp()

self.src_table = table(
self.table_src_path,
schema={"id": int, "userid": int, "movieid": int, "rating": float},
)
self.dst_table = table(
self.table_dst_path,
schema={"id": int, "userid": int, "movieid": int, "rating": float},
)

self.connection.query(
[self.src_table.create(primary_keys=["id"]), self.dst_table.create(primary_keys=["id", "userid"]), commit]
)

self.differ = JoinDiffer()

def test_unique_constraint(self):
self.connection.query(
[
self.src_table.insert_rows([[1, 1, 1, 9], [2, 2, 2, 9]]),
self.dst_table.insert_rows([[1, 1, 1, 9], [2, 2, 2, 9]]),
commit,
]
)

# Test no active validation
table = TableSegment(self.connection, self.table_src_path, ("id",), case_sensitive=False)
table2 = TableSegment(self.connection, self.table_dst_path, ("id",), case_sensitive=False)

res = list(self.differ.diff_tables(table, table2))
assert not res
assert "validated_unique_keys" not in self.differ.stats

# Test active validation
table = TableSegment(self.connection, self.table_src_path, ("userid",), case_sensitive=False)
table2 = TableSegment(self.connection, self.table_dst_path, ("userid",), case_sensitive=False)

res = list(self.differ.diff_tables(table, table2))
assert not res
self.assertEqual(self.differ.stats["validated_unique_keys"], [["userid"]])