
SEA: Normalise Column Values from Metadata queries #662


Draft · wants to merge 9 commits into base: col-normalisation
23 changes: 22 additions & 1 deletion src/databricks/sql/backend/sea/backend.py
@@ -20,7 +20,12 @@
     MetadataCommands,
 )
 from databricks.sql.backend.sea.utils.metadata_mappings import MetadataColumnMappings
+from databricks.sql.backend.sea.utils.metadata_transforms import (
+    create_table_catalog_transform,
+)
+from databricks.sql.backend.sea.utils.normalize import normalize_sea_type_to_thrift
+from databricks.sql.backend.sea.utils.result_column import ResultColumn
 from databricks.sql.backend.sea.utils.conversion import SqlType
 from databricks.sql.thrift_api.TCLIService import ttypes

 if TYPE_CHECKING:
@@ -740,7 +745,23 @@ def get_schemas(
         assert isinstance(
             result, SeaResultSet
         ), "Expected SeaResultSet from SEA backend"
-        result.prepare_metadata_columns(MetadataColumnMappings.SCHEMA_COLUMNS)
+
+        # Create dynamic schema columns with catalog name bound to TABLE_CATALOG
+        schema_columns = []
+        for col in MetadataColumnMappings.SCHEMA_COLUMNS:
+            if col.thrift_col_name == "TABLE_CATALOG":
+                # Create a new column with the catalog transform bound
+                dynamic_col = ResultColumn(
+                    col.thrift_col_name,
+                    col.sea_col_name,
+                    col.thrift_col_type,
+                    create_table_catalog_transform(catalog_name),
+                )
+                schema_columns.append(dynamic_col)
+            else:
+                schema_columns.append(col)
+
+        result.prepare_metadata_columns(schema_columns)
         return result

     def get_tables(
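Note on the approach: get_schemas rebuilds the column list per call so that the requested catalog_name is captured in a closure, while the shared SCHEMA_COLUMNS list stays untouched. A minimal standalone sketch of that pattern (Column and make_catalog_transform are illustrative stand-ins, not driver code):

from dataclasses import dataclass, replace
from typing import Any, Callable, Optional


@dataclass(frozen=True)
class Column:
    # Stand-in for ResultColumn, for illustration only.
    name: str
    transform: Optional[Callable[[Any], Any]] = None


def make_catalog_transform(catalog_name: str) -> Callable[[Any], Any]:
    # The closure captures the catalog requested by *this* call.
    return lambda _value: catalog_name


shared_columns = [Column("TABLE_CATALOG"), Column("TABLE_SCHEM")]
bound = [
    replace(col, transform=make_catalog_transform("main"))
    if col.name == "TABLE_CATALOG"
    else col
    for col in shared_columns
]
assert bound[0].transform(None) == "main"
assert shared_columns[0].transform is None  # shared list untouched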
17 changes: 13 additions & 4 deletions src/databricks/sql/backend/sea/result_set.py
@@ -320,7 +320,7 @@ def _prepare_column_mapping(self) -> None:
                 None,
                 None,
                 None,
-                True,
+                None,
             )

         # Set the mapping
@@ -356,14 +356,20 @@ def _normalise_arrow_metadata_cols(self, table: "pyarrow.Table") -> "pyarrow.Table":
                 if self._column_index_mapping
                 else None
             )

             column = (
                 pyarrow.nulls(table.num_rows)
                 if old_idx is None
                 else table.column(old_idx)
             )
-            new_columns.append(column)
-
+            # Apply transform if available
+            if result_column.transform_value:
+                # Convert to list, apply transform, and convert back
+                values = column.to_pylist()
+                transformed_values = [result_column.transform_value(v) for v in values]
+                column = pyarrow.array(transformed_values)
+
+            new_columns.append(column)
             column_names.append(result_column.thrift_col_name)

         return pyarrow.Table.from_arrays(new_columns, names=column_names)
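The Arrow path materializes each column to Python objects, transforms element-wise, then rebuilds an Arrow array — simple, though not vectorized. A self-contained sketch of the same round-trip (assumes pyarrow is installed; apply_transform is a hypothetical helper):

import pyarrow


def apply_transform(column, fn):
    # Materialize, transform element by element, rebuild as an Arrow array.
    return pyarrow.array([fn(v) for v in column.to_pylist()])


table = pyarrow.table({"isNullable": [True, False, None]})
out = apply_transform(table.column(0), lambda v: "YES" if v else "NO")
print(out.to_pylist())  # ['YES', 'NO', 'NO']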
@@ -382,8 +388,11 @@ def _normalise_json_metadata_cols(self, rows: List[List[str]]) -> List[List[Any]]:
                     if self._column_index_mapping
                     else None
                 )

                 value = None if old_idx is None else row[old_idx]

+                # Apply transform if available
+                if result_column.transform_value:
+                    value = result_column.transform_value(value)
                 new_row.append(value)
             transformed_rows.append(new_row)
         return transformed_rows
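The JSON path does the same cell by cell: remap the index, then apply the column's transform if one is set. A rough standalone equivalent of the loop (normalise_rows and the two-column mapping are hypothetical, for illustration only):

from typing import Any, Callable, List, Optional

Transform = Optional[Callable[[Any], Any]]


def normalise_rows(
    rows: List[List[Any]],
    index_mapping: List[Optional[int]],  # new index -> old index (None = missing)
    transforms: List[Transform],  # one optional transform per new column
) -> List[List[Any]]:
    out = []
    for row in rows:
        new_row = []
        for new_idx, old_idx in enumerate(index_mapping):
            value = None if old_idx is None else row[old_idx]
            fn = transforms[new_idx]
            new_row.append(fn(value) if fn else value)
        out.append(new_row)
    return out


rows = [["tbl1", True], ["tbl2", False]]
print(normalise_rows(rows, [0, 1], [None, lambda v: "YES" if v else "NO"]))
# [['tbl1', 'YES'], ['tbl2', 'NO']]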
30 changes: 25 additions & 5 deletions src/databricks/sql/backend/sea/utils/metadata_mappings.py
@@ -1,5 +1,13 @@
from databricks.sql.backend.sea.utils.result_column import ResultColumn
from databricks.sql.backend.sea.utils.conversion import SqlType
+from databricks.sql.backend.sea.utils.metadata_transforms import (
+    transform_remarks,
+    transform_is_autoincrement,
+    transform_is_nullable,
+    transform_nullable,
+    transform_data_type,
+    transform_ordinal_position,
+)


class MetadataColumnMappings:
@@ -18,7 +26,9 @@ class MetadataColumnMappings:
     SCHEMA_COLUMN = ResultColumn("TABLE_SCHEM", "namespace", SqlType.STRING)
     TABLE_NAME_COLUMN = ResultColumn("TABLE_NAME", "tableName", SqlType.STRING)
     TABLE_TYPE_COLUMN = ResultColumn("TABLE_TYPE", "tableType", SqlType.STRING)
-    REMARKS_COLUMN = ResultColumn("REMARKS", "remarks", SqlType.STRING)
+    REMARKS_COLUMN = ResultColumn(
+        "REMARKS", "remarks", SqlType.STRING, transform_remarks
+    )
     TYPE_CATALOG_COLUMN = ResultColumn("TYPE_CAT", None, SqlType.STRING)
     TYPE_SCHEM_COLUMN = ResultColumn("TYPE_SCHEM", None, SqlType.STRING)
     TYPE_NAME_COLUMN = ResultColumn("TYPE_NAME", None, SqlType.STRING)
@@ -28,7 +38,9 @@ class MetadataColumnMappings:
     REF_GENERATION_COLUMN = ResultColumn("REF_GENERATION", None, SqlType.STRING)

     COL_NAME_COLUMN = ResultColumn("COLUMN_NAME", "col_name", SqlType.STRING)
-    DATA_TYPE_COLUMN = ResultColumn("DATA_TYPE", None, SqlType.INT)
+    DATA_TYPE_COLUMN = ResultColumn(
+        "DATA_TYPE", "columnType", SqlType.INT, transform_data_type
+    )
     COLUMN_TYPE_COLUMN = ResultColumn("TYPE_NAME", "columnType", SqlType.STRING)
     COLUMN_SIZE_COLUMN = ResultColumn("COLUMN_SIZE", "columnSize", SqlType.INT)
     BUFFER_LENGTH_COLUMN = ResultColumn("BUFFER_LENGTH", None, SqlType.TINYINT)
@@ -43,22 +55,30 @@ class MetadataColumnMappings:
"ORDINAL_POSITION",
"ordinalPosition",
SqlType.INT,
transform_ordinal_position,
)

NULLABLE_COLUMN = ResultColumn("NULLABLE", None, SqlType.INT)
NULLABLE_COLUMN = ResultColumn(
"NULLABLE", "isNullable", SqlType.INT, transform_nullable
)
COLUMN_DEF_COLUMN = ResultColumn("COLUMN_DEF", None, SqlType.STRING)
SQL_DATA_TYPE_COLUMN = ResultColumn("SQL_DATA_TYPE", None, SqlType.INT)
SQL_DATETIME_SUB_COLUMN = ResultColumn("SQL_DATETIME_SUB", None, SqlType.INT)
CHAR_OCTET_LENGTH_COLUMN = ResultColumn("CHAR_OCTET_LENGTH", None, SqlType.INT)
IS_NULLABLE_COLUMN = ResultColumn("IS_NULLABLE", "isNullable", SqlType.STRING)
IS_NULLABLE_COLUMN = ResultColumn(
"IS_NULLABLE", "isNullable", SqlType.STRING, transform_is_nullable
)

SCOPE_CATALOG_COLUMN = ResultColumn("SCOPE_CATALOG", None, SqlType.STRING)
SCOPE_SCHEMA_COLUMN = ResultColumn("SCOPE_SCHEMA", None, SqlType.STRING)
SCOPE_TABLE_COLUMN = ResultColumn("SCOPE_TABLE", None, SqlType.STRING)
SOURCE_DATA_TYPE_COLUMN = ResultColumn("SOURCE_DATA_TYPE", None, SqlType.SMALLINT)

IS_AUTO_INCREMENT_COLUMN = ResultColumn(
"IS_AUTOINCREMENT", "isAutoIncrement", SqlType.STRING
"IS_AUTO_INCREMENT",
"isAutoIncrement",
SqlType.STRING,
transform_is_autoincrement,
)
IS_GENERATED_COLUMN = ResultColumn(
"IS_GENERATEDCOLUMN", "isGenerated", SqlType.STRING
83 changes: 83 additions & 0 deletions src/databricks/sql/backend/sea/utils/metadata_transforms.py
@@ -0,0 +1,83 @@
"""Simple transformation functions for metadata value normalization."""


def transform_is_autoincrement(value):
    """Transform IS_AUTOINCREMENT: boolean (or None) to YES/NO string."""
    if isinstance(value, bool) or value is None:
        # None is treated as "NO"
        return "YES" if value else "NO"
    return value


def transform_is_nullable(value):
    """Transform IS_NULLABLE: true/false to YES/NO string."""
    if value is True or value == "true":
        return "YES"
    elif value is False or value == "false":
        return "NO"
    return value


def transform_remarks(value):
    """Transform REMARKS: None to empty string."""
    if value is None:
        return ""
    return value


def transform_nullable(value):
    """Transform NULLABLE column: boolean/string to integer."""
    if value is True or value == "true" or value == "YES":
        return 1
    elif value is False or value == "false" or value == "NO":
        return 0
    return value


# Type code mapping based on JDBC specification (java.sql.Types)
TYPE_CODE_MAP = {
    "STRING": 12,  # VARCHAR
    "VARCHAR": 12,  # VARCHAR
    "CHAR": 1,  # CHAR
    "INT": 4,  # INTEGER
    "INTEGER": 4,  # INTEGER
    "BIGINT": -5,  # BIGINT
    "SMALLINT": 5,  # SMALLINT
    "TINYINT": -6,  # TINYINT
    "DOUBLE": 8,  # DOUBLE
    "FLOAT": 6,  # FLOAT
    "REAL": 7,  # REAL
    "DECIMAL": 3,  # DECIMAL
    "NUMERIC": 2,  # NUMERIC
    "BOOLEAN": 16,  # BOOLEAN
    "DATE": 91,  # DATE
    "TIMESTAMP": 93,  # TIMESTAMP
    "BINARY": -2,  # BINARY
    "ARRAY": 2003,  # ARRAY
    "MAP": 2002,  # STRUCT (java.sql.Types has no MAP code)
    "STRUCT": 2002,  # STRUCT
}


def transform_data_type(value):
    """Transform DATA_TYPE: type name to JDBC type code."""
    if isinstance(value, str):
        # Handle parameterized types like DECIMAL(10,2)
        base_type = value.split("(")[0].upper()
        return TYPE_CODE_MAP.get(base_type, value)
    return value


def transform_ordinal_position(value):
    """Transform ORDINAL_POSITION: decrement by 1 (1-based to 0-based)."""
    if isinstance(value, int):
        return value - 1
    return value


def create_table_catalog_transform(catalog_name):
    """Factory function to create TABLE_CATALOG transform with bound catalog name."""

    def transform_table_catalog(value):
        """Transform TABLE_CATALOG: return the catalog name for all rows."""
        return catalog_name

    return transform_table_catalog
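Taking the functions above at face value, a few illustrative calls:

from databricks.sql.backend.sea.utils.metadata_transforms import (
    create_table_catalog_transform,
    transform_data_type,
    transform_is_nullable,
    transform_ordinal_position,
)

print(transform_data_type("DECIMAL(10,2)"))  # 3  (base type DECIMAL)
print(transform_data_type("GEOGRAPHY"))      # 'GEOGRAPHY' (unknown, passed through)
print(transform_is_nullable("true"))         # 'YES'
print(transform_ordinal_position(1))         # 0  (1-based -> 0-based)
print(create_table_catalog_transform("main")("ignored"))  # 'main'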
4 changes: 3 additions & 1 deletion src/databricks/sql/backend/sea/utils/result_column.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Callable, Any


@dataclass(frozen=True)
@@ -11,8 +11,10 @@ class ResultColumn:
         thrift_col_name: Column name as returned by Thrift (e.g., "TABLE_CAT")
         sea_col_name: Server result column name from SEA (e.g., "catalog")
         thrift_col_type: SQL type name
+        transform_value: Optional callback to transform values for this column
     """

     thrift_col_name: str
     sea_col_name: Optional[str]  # None if SEA doesn't return this column
     thrift_col_type: str
+    transform_value: Optional[Callable[[Any], Any]] = None
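With the new optional field, columns that need no normalization are declared exactly as before (transform_value defaults to None). A small sketch built on the types shown in the diff:

from databricks.sql.backend.sea.utils.conversion import SqlType
from databricks.sql.backend.sea.utils.metadata_transforms import transform_nullable
from databricks.sql.backend.sea.utils.result_column import ResultColumn

plain = ResultColumn("TABLE_NAME", "tableName", SqlType.STRING)
mapped = ResultColumn("NULLABLE", "isNullable", SqlType.INT, transform_nullable)

assert plain.transform_value is None        # passes values through unchanged
assert mapped.transform_value("true") == 1  # string boolean -> JDBC int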
13 changes: 11 additions & 2 deletions tests/e2e/test_driver.py
@@ -562,8 +562,17 @@ def test_get_schemas(self):
         finally:
             cursor.execute("DROP DATABASE IF EXISTS {}".format(database_name))

-    def test_get_catalogs(self):
-        with self.cursor({}) as cursor:
+    @pytest.mark.parametrize(
+        "backend_params",
+        [
+            {},
+            {
+                "use_sea": True,
+            },
+        ],
+    )
+    def test_get_catalogs(self, backend_params):
+        with self.cursor(backend_params) as cursor:
             cursor.catalogs()
             cursor.fetchall()
             catalogs_desc = cursor.description
2 changes: 1 addition & 1 deletion tests/unit/test_metadata_mappings.py
@@ -89,7 +89,7 @@ def test_column_columns_mapping(self):
"TABLE_SCHEM": ("namespace", SqlType.STRING),
"TABLE_NAME": ("tableName", SqlType.STRING),
"COLUMN_NAME": ("col_name", SqlType.STRING),
"DATA_TYPE": (None, SqlType.INT),
"DATA_TYPE": ("columnType", SqlType.INT),
"TYPE_NAME": ("columnType", SqlType.STRING),
"COLUMN_SIZE": ("columnSize", SqlType.INT),
"DECIMAL_DIGITS": ("decimalDigits", SqlType.INT),
3 changes: 0 additions & 3 deletions tests/unit/test_sea_backend.py
@@ -826,9 +826,6 @@ def test_get_schemas(self, sea_client, sea_session_id, mock_cursor):

         # Verify prepare_metadata_columns was called for successful cases
         assert mock_result_set.prepare_metadata_columns.call_count == 2
-        mock_result_set.prepare_metadata_columns.assert_called_with(
-            MetadataColumnMappings.SCHEMA_COLUMNS
-        )

     def test_get_tables(self, sea_client, sea_session_id, mock_cursor):
         """Test the get_tables method with various parameter combinations."""