SEA: normalise column names in metadata queries #661


Open

wants to merge 25 commits into main from col-normalisation

Changes from all commits

Commits (25)
9de6c8b
init col norm
varun-edachali-dbx Jul 31, 2025
f5982f0
partial working dump
varun-edachali-dbx Aug 3, 2025
d97d875
refactor
varun-edachali-dbx Aug 3, 2025
c5c9859
remove callback methods for transformations
varun-edachali-dbx Aug 4, 2025
311e217
merge with main
varun-edachali-dbx Aug 4, 2025
2b1442f
remove redundant COLUMN_DATA_MAPPING
varun-edachali-dbx Aug 4, 2025
1d515e3
rename transformation functions to normalise for metadata cols
varun-edachali-dbx Aug 4, 2025
2be0c86
make mock result set be of type SeaResultSet
varun-edachali-dbx Aug 4, 2025
92c2da4
remove redundant comments
varun-edachali-dbx Aug 4, 2025
99481e9
use SqlType for type conv
varun-edachali-dbx Aug 4, 2025
f90a75e
verified: get catalogs
varun-edachali-dbx Aug 4, 2025
946e513
verif: get schemas
varun-edachali-dbx Aug 4, 2025
070b931
verif: TABLE_COLUMNS from jdbc
varun-edachali-dbx Aug 4, 2025
939c542
verif: COLUMN_COLUMNS from JDBC
varun-edachali-dbx Aug 4, 2025
7d3174f
make stuff missing from SEA None in SEA mapping
varun-edachali-dbx Aug 4, 2025
82e9c4f
remove hardcoding in SqlType
varun-edachali-dbx Aug 4, 2025
b2ae83c
move helper type name extractor out of class
varun-edachali-dbx Aug 4, 2025
9fb0444
Merge branch 'main' into col-normalisation
varun-edachali-dbx Aug 4, 2025
4be6808
ensure SeaResultSet resp
varun-edachali-dbx Aug 4, 2025
0dad966
clean up conversion code
varun-edachali-dbx Aug 4, 2025
a28596b
fix type codes
varun-edachali-dbx Aug 4, 2025
639bafa
simplify docstring
varun-edachali-dbx Aug 4, 2025
b0b58fb
nit: reduce repetition
varun-edachali-dbx Aug 4, 2025
55d8c75
test metadata mappings
varun-edachali-dbx Aug 4, 2025
ce591ce
Merge branch 'main' into col-normalisation
varun-edachali-dbx Aug 6, 2025
22 changes: 18 additions & 4 deletions src/databricks/sql/backend/sea/backend.py
@@ -19,6 +19,7 @@
WaitTimeout,
MetadataCommands,
)
from databricks.sql.backend.sea.utils.metadata_mappings import MetadataColumnMappings
from databricks.sql.backend.sea.utils.normalize import normalize_sea_type_to_thrift
from databricks.sql.thrift_api.TCLIService import ttypes

@@ -700,7 +701,10 @@ def get_catalogs(
async_op=False,
enforce_embedded_schema_correctness=False,
)
assert result is not None, "execute_command returned None in synchronous mode"
assert isinstance(
result, SeaResultSet
), "Expected SeaResultSet from SEA backend"
result.prepare_metadata_columns(MetadataColumnMappings.CATALOG_COLUMNS)
return result

def get_schemas(
@@ -733,7 +737,10 @@ def get_schemas(
async_op=False,
enforce_embedded_schema_correctness=False,
)
assert result is not None, "execute_command returned None in synchronous mode"
assert isinstance(
result, SeaResultSet
), "Expected SeaResultSet from SEA backend"
result.prepare_metadata_columns(MetadataColumnMappings.SCHEMA_COLUMNS)
return result

def get_tables(
@@ -774,13 +781,17 @@ def get_tables(
async_op=False,
enforce_embedded_schema_correctness=False,
)
assert result is not None, "execute_command returned None in synchronous mode"
assert isinstance(
result, SeaResultSet
), "Expected SeaResultSet from SEA backend"

# Apply client-side filtering by table_types
from databricks.sql.backend.sea.utils.filters import ResultSetFilter

result = ResultSetFilter.filter_tables_by_type(result, table_types)

result.prepare_metadata_columns(MetadataColumnMappings.TABLE_COLUMNS)

return result

def get_columns(
@@ -821,5 +832,8 @@ def get_columns(
async_op=False,
enforce_embedded_schema_correctness=False,
)
assert result is not None, "execute_command returned None in synchronous mode"
assert isinstance(
result, SeaResultSet
), "Expected SeaResultSet from SEA backend"
result.prepare_metadata_columns(MetadataColumnMappings.COLUMN_COLUMNS)
return result
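
Illustrative usage sketch (not part of this diff): after this change, metadata calls routed through the SEA backend should expose the Thrift/JDBC-style column layout to callers. Connection parameters below are placeholders, and the kwarg that selects the SEA backend is assumed to be set separately since it is not shown in this diff.

from databricks import sql

# Placeholder credentials; assume the SEA backend is enabled via whatever
# connection kwarg your driver version uses for it.
conn = sql.connect(
    server_hostname="<host>",
    http_path="<http-path>",
    access_token="<token>",
)
cursor = conn.cursor()
cursor.tables(catalog_name="main", schema_name="default")

# With prepare_metadata_columns applied, cursor.description now uses the
# Thrift metadata names (e.g. TABLE_CAT, TABLE_SCHEM, TABLE_NAME, TABLE_TYPE),
# and columns that SEA does not return are filled with None.
print([col[0] for col in cursor.description])

cursor.close()
conn.close()
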
126 changes: 125 additions & 1 deletion src/databricks/sql/backend/sea/result_set.py
@@ -1,11 +1,12 @@
from __future__ import annotations

from typing import Any, List, Optional, TYPE_CHECKING
from typing import Any, List, Optional, TYPE_CHECKING, Dict, Union

import logging

from databricks.sql.backend.sea.models.base import ResultData, ResultManifest
from databricks.sql.backend.sea.utils.conversion import SqlTypeConverter
from databricks.sql.backend.sea.utils.result_column import ResultColumn

try:
import pyarrow
@@ -82,6 +83,10 @@ def __init__(
arrow_schema_bytes=execute_response.arrow_schema_bytes,
)

self._metadata_columns: Optional[List[ResultColumn]] = None
# new index -> old index
self._column_index_mapping: Optional[Dict[int, Union[int, None]]] = None

def _convert_json_types(self, row: List[str]) -> List[Any]:
"""
Convert string values in the row to appropriate Python types based on column metadata.
@@ -160,6 +165,7 @@ def fetchmany_json(self, size: int) -> List[List[str]]:
raise ValueError(f"size argument for fetchmany is {size} but must be >= 0")

results = self.results.next_n_rows(size)
results = self._normalise_json_metadata_cols(results)
self._next_row_index += len(results)

return results
@@ -173,6 +179,7 @@ def fetchall_json(self) -> List[List[str]]:
"""

results = self.results.remaining_rows()
results = self._normalise_json_metadata_cols(results)
self._next_row_index += len(results)

return results
Expand All @@ -198,6 +205,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
results = self.results.next_n_rows(size)
if isinstance(self.results, JsonQueue):
results = self._convert_json_to_arrow_table(results)
results = self._normalise_arrow_metadata_cols(results)

self._next_row_index += results.num_rows

Expand All @@ -211,6 +219,7 @@ def fetchall_arrow(self) -> "pyarrow.Table":
results = self.results.remaining_rows()
if isinstance(self.results, JsonQueue):
results = self._convert_json_to_arrow_table(results)
results = self._normalise_arrow_metadata_cols(results)

self._next_row_index += results.num_rows

@@ -263,3 +272,118 @@ def fetchall(self) -> List[Row]:
return self._create_json_table(self.fetchall_json())
else:
return self._convert_arrow_table(self.fetchall_arrow())

def prepare_metadata_columns(self, metadata_columns: List[ResultColumn]) -> None:
"""
Prepare result set for metadata column normalization.

Args:
metadata_columns: List of ResultColumn objects defining the expected columns
and their mappings from SEA column names
"""
self._metadata_columns = metadata_columns
self._prepare_column_mapping()

def _prepare_column_mapping(self) -> None:
"""
Prepare column index mapping for metadata queries.
Updates description to use Thrift column names.
"""
# Ensure description is available
if not self.description:
raise ValueError("Cannot prepare column mapping without result description")

# Build mapping from SEA column names to their indices
sea_column_indices = {}
for idx, col in enumerate(self.description):
sea_column_indices[col[0]] = idx

# Create new description and index mapping
new_description = []
self._column_index_mapping = {} # Maps new index -> old index

for new_idx, result_column in enumerate(self._metadata_columns or []):
# Determine the old index and get column metadata
if (
result_column.sea_col_name
and result_column.sea_col_name in sea_column_indices
):
old_idx = sea_column_indices[result_column.sea_col_name]
old_col = self.description[old_idx]
# Use original column metadata
display_size, internal_size, precision, scale, null_ok = old_col[2:7]
else:
old_idx = None
# Use None values for missing columns
display_size, internal_size, precision, scale, null_ok = (
None,
None,
None,
None,
True,
)

# Set the mapping
self._column_index_mapping[new_idx] = old_idx

# Create the new description entry
new_description.append(
(
result_column.thrift_col_name, # Thrift (normalised) name
result_column.thrift_col_type, # Expected type
display_size,
internal_size,
precision,
scale,
null_ok,
)
)

self.description = new_description

def _normalise_arrow_metadata_cols(self, table: "pyarrow.Table") -> "pyarrow.Table":
"""Transform arrow table columns for metadata normalization."""
if not self._metadata_columns or len(table.schema) == 0:
return table

# Reorder columns and add missing ones
new_columns = []
column_names = []

for new_idx, result_column in enumerate(self._metadata_columns or []):
old_idx = (
self._column_index_mapping.get(new_idx, None)
if self._column_index_mapping
else None
)

column = (
pyarrow.nulls(table.num_rows)
if old_idx is None
else table.column(old_idx)
)
new_columns.append(column)

column_names.append(result_column.thrift_col_name)

return pyarrow.Table.from_arrays(new_columns, names=column_names)

def _normalise_json_metadata_cols(self, rows: List[List[str]]) -> List[List[Any]]:
"""Transform JSON rows for metadata normalization."""
if not self._metadata_columns or len(rows) == 0:
return rows

transformed_rows = []
for row in rows:
new_row = []
for new_idx, result_column in enumerate(self._metadata_columns or []):
old_idx = (
self._column_index_mapping.get(new_idx, None)
if self._column_index_mapping
else None
)

value = None if old_idx is None else row[old_idx]
new_row.append(value)
transformed_rows.append(new_row)
return transformed_rows
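
A minimal sketch of the index-mapping idea behind _prepare_column_mapping and _normalise_json_metadata_cols, using a stand-in ResultColumn. The field names are taken from the diff above; the real class in databricks/sql/backend/sea/utils/result_column.py may differ, and the type values here are illustrative.

from dataclasses import dataclass
from typing import Optional

@dataclass
class ResultColumn:
    thrift_col_name: str            # name exposed to callers (JDBC/Thrift style)
    sea_col_name: Optional[str]     # name returned by SEA, None if SEA omits the column
    thrift_col_type: str            # expected type name for the description tuple (illustrative)

# Suppose SEA returned columns named "catalog" and "namespace":
sea_description = [("catalog", "string"), ("namespace", "string")]
expected = [
    ResultColumn("TABLE_CAT", "catalog", "string"),
    ResultColumn("TABLE_SCHEM", "namespace", "string"),
    ResultColumn("REMARKS", None, "string"),   # missing from SEA -> filled with None
]

# Build the new-index -> old-index mapping the way _prepare_column_mapping does.
sea_indices = {name: i for i, (name, _) in enumerate(sea_description)}
index_mapping = {
    new_idx: sea_indices.get(col.sea_col_name) if col.sea_col_name else None
    for new_idx, col in enumerate(expected)
}

# Normalise one row the way _normalise_json_metadata_cols does.
row = ["main", "default"]
normalised = [None if old_idx is None else row[old_idx] for old_idx in index_mapping.values()]
print(normalised)  # ['main', 'default', None]
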
56 changes: 32 additions & 24 deletions src/databricks/sql/backend/sea/utils/conversion.py
@@ -11,6 +11,8 @@
from dateutil import parser
from typing import Callable, Dict, Optional

from databricks.sql.thrift_api.TCLIService import ttypes

logger = logging.getLogger(__name__)


@@ -48,6 +50,14 @@ def _convert_decimal(
return result


def _get_type_name(thrift_type_id: int) -> str:
type_name = ttypes.TTypeId._VALUES_TO_NAMES[thrift_type_id]
type_name = type_name.lower()
if type_name.endswith("_type"):
type_name = type_name[:-5]
return type_name


class SqlType:
"""
SQL type constants based on Thrift TTypeId values.
@@ -57,42 +67,40 @@ class SqlType:
"""

# Numeric types
TINYINT = "tinyint" # Maps to TTypeId.TINYINT_TYPE
SMALLINT = "smallint" # Maps to TTypeId.SMALLINT_TYPE
INT = "int" # Maps to TTypeId.INT_TYPE
BIGINT = "bigint" # Maps to TTypeId.BIGINT_TYPE
FLOAT = "float" # Maps to TTypeId.FLOAT_TYPE
DOUBLE = "double" # Maps to TTypeId.DOUBLE_TYPE
DECIMAL = "decimal" # Maps to TTypeId.DECIMAL_TYPE
TINYINT = _get_type_name(ttypes.TTypeId.TINYINT_TYPE)
SMALLINT = _get_type_name(ttypes.TTypeId.SMALLINT_TYPE)
INT = _get_type_name(ttypes.TTypeId.INT_TYPE)
BIGINT = _get_type_name(ttypes.TTypeId.BIGINT_TYPE)
FLOAT = _get_type_name(ttypes.TTypeId.FLOAT_TYPE)
DOUBLE = _get_type_name(ttypes.TTypeId.DOUBLE_TYPE)
DECIMAL = _get_type_name(ttypes.TTypeId.DECIMAL_TYPE)

# Boolean type
BOOLEAN = "boolean" # Maps to TTypeId.BOOLEAN_TYPE
BOOLEAN = _get_type_name(ttypes.TTypeId.BOOLEAN_TYPE)

# Date/Time types
DATE = "date" # Maps to TTypeId.DATE_TYPE
TIMESTAMP = "timestamp" # Maps to TTypeId.TIMESTAMP_TYPE
INTERVAL_YEAR_MONTH = (
"interval_year_month" # Maps to TTypeId.INTERVAL_YEAR_MONTH_TYPE
)
INTERVAL_DAY_TIME = "interval_day_time" # Maps to TTypeId.INTERVAL_DAY_TIME_TYPE
DATE = _get_type_name(ttypes.TTypeId.DATE_TYPE)
TIMESTAMP = _get_type_name(ttypes.TTypeId.TIMESTAMP_TYPE)
INTERVAL_YEAR_MONTH = _get_type_name(ttypes.TTypeId.INTERVAL_YEAR_MONTH_TYPE)
INTERVAL_DAY_TIME = _get_type_name(ttypes.TTypeId.INTERVAL_DAY_TIME_TYPE)

# String types
CHAR = "char" # Maps to TTypeId.CHAR_TYPE
VARCHAR = "varchar" # Maps to TTypeId.VARCHAR_TYPE
STRING = "string" # Maps to TTypeId.STRING_TYPE
CHAR = _get_type_name(ttypes.TTypeId.CHAR_TYPE)
VARCHAR = _get_type_name(ttypes.TTypeId.VARCHAR_TYPE)
STRING = _get_type_name(ttypes.TTypeId.STRING_TYPE)

# Binary type
BINARY = "binary" # Maps to TTypeId.BINARY_TYPE
BINARY = _get_type_name(ttypes.TTypeId.BINARY_TYPE)

# Complex types
ARRAY = "array" # Maps to TTypeId.ARRAY_TYPE
MAP = "map" # Maps to TTypeId.MAP_TYPE
STRUCT = "struct" # Maps to TTypeId.STRUCT_TYPE
ARRAY = _get_type_name(ttypes.TTypeId.ARRAY_TYPE)
MAP = _get_type_name(ttypes.TTypeId.MAP_TYPE)
STRUCT = _get_type_name(ttypes.TTypeId.STRUCT_TYPE)

# Other types
NULL = "null" # Maps to TTypeId.NULL_TYPE
UNION = "union" # Maps to TTypeId.UNION_TYPE
USER_DEFINED = "user_defined" # Maps to TTypeId.USER_DEFINED_TYPE
NULL = _get_type_name(ttypes.TTypeId.NULL_TYPE)
UNION = _get_type_name(ttypes.TTypeId.UNION_TYPE)
USER_DEFINED = _get_type_name(ttypes.TTypeId.USER_DEFINED_TYPE)


class SqlTypeConverter:
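
For reference, a quick check of what the derived SqlType constants evaluate to. Import paths are taken from this diff; this assumes the package is installed from this branch.

from databricks.sql.backend.sea.utils.conversion import SqlType
from databricks.sql.thrift_api.TCLIService import ttypes

# The constants are now derived from the Thrift enum via _get_type_name rather than
# hardcoded strings, e.g. TTypeId.INT_TYPE ("INT_TYPE") -> "int".
print(ttypes.TTypeId._VALUES_TO_NAMES[ttypes.TTypeId.INT_TYPE])  # "INT_TYPE"
print(SqlType.INT)                 # "int"
print(SqlType.INTERVAL_DAY_TIME)   # "interval_day_time"
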