Skip to content

Commit

Permalink
Map INTERVAL types to Python types
Browse files Browse the repository at this point in the history
Co-authored-by: Damian Owsianny <[email protected]>
  • Loading branch information
hovaesco and damian3031 authored Oct 24, 2024
1 parent bac6ae7 commit 4c57774
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 8 deletions.
83 changes: 75 additions & 8 deletions tests/integration/test_types_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from zoneinfo import ZoneInfo

import pytest
from dateutil.relativedelta import relativedelta

import trino
from tests.integration.conftest import trino_version
Expand Down Expand Up @@ -733,14 +734,80 @@ def create_timezone(timezone_str: str) -> tzinfo:
return ZoneInfo(timezone_str)


def test_interval(trino_connection):
    """Exercise INTERVAL values through ``SqlTest``.

    Each field pairs a SQL expression with the Python value registered as
    its expectation; NULL intervals are paired with ``None`` and non-NULL
    intervals with their string form.
    """
    interval_checks = (
        SqlTest(trino_connection)
        .add_field(sql="CAST(null AS INTERVAL YEAR TO MONTH)", python=None)
        .add_field(sql="CAST(null AS INTERVAL DAY TO SECOND)", python=None)
        .add_field(sql="INTERVAL '3' MONTH", python='0-3')
        .add_field(sql="INTERVAL '2' DAY", python='2 00:00:00.000')
        .add_field(sql="INTERVAL '-2' DAY", python='-2 00:00:00.000')
    )
    interval_checks.execute()
def test_interval_year_to_month(trino_connection):
    """Exercise INTERVAL YEAR TO MONTH values against ``relativedelta`` expectations.

    Every case is a ``(sql, expected)`` pair that is registered on a single
    ``SqlTest`` via ``add_field`` in the order listed, then executed once.
    """
    cases = [
        ("CAST(null AS INTERVAL YEAR TO MONTH)", None),
        ("INTERVAL '10' YEAR", relativedelta(years=10)),
        ("INTERVAL '-5' YEAR", relativedelta(years=-5)),
        ("INTERVAL '3' MONTH", relativedelta(months=3)),
        ("INTERVAL '-18' MONTH", relativedelta(years=-1, months=-6)),
        ("INTERVAL '30' MONTH", relativedelta(years=2, months=6)),
        # max supported INTERVAL in Trino
        ("INTERVAL '178956970-7' YEAR TO MONTH", relativedelta(years=178956970, months=7)),
        # min supported INTERVAL in Trino
        ("INTERVAL '-178956970-8' YEAR TO MONTH", relativedelta(years=-178956970, months=-8)),
    ]

    sql_test = SqlTest(trino_connection)
    for sql, expected in cases:
        # Re-bind on every step so the fluent-chain semantics are kept as-is.
        sql_test = sql_test.add_field(sql=sql, python=expected)
    sql_test.execute()


def test_interval_day_to_second(trino_connection):
    """Exercise INTERVAL DAY TO SECOND values against ``timedelta`` expectations.

    The first part registers ``(sql, expected)`` pairs on a single ``SqlTest``
    and executes it. The second part runs two expressions that sit just
    outside the range Python's ``timedelta`` supports through
    ``SqlExpectFailureTest``.
    """
    cases = [
        ("CAST(null AS INTERVAL DAY TO SECOND)", None),
        ("INTERVAL '2' DAY", timedelta(days=2)),
        ("INTERVAL '-2' DAY", timedelta(days=-2)),
        ("INTERVAL '-2' SECOND", timedelta(seconds=-2)),
        # Sub-millisecond digits in the literal are not part of the expectation.
        (
            "INTERVAL '1 11:11:11.116555' DAY TO SECOND",
            timedelta(days=1, seconds=40271, microseconds=116000),
        ),
        ("INTERVAL '-5 23:59:57.000' DAY TO SECOND", timedelta(days=-6, seconds=3)),
        ("INTERVAL '12 10:45' DAY TO MINUTE", timedelta(days=12, seconds=38700)),
        ("INTERVAL '45:32.123' MINUTE TO SECOND", timedelta(seconds=2732, microseconds=123000)),
        ("INTERVAL '32.123' SECOND", timedelta(seconds=32, microseconds=123000)),
        # max supported timedelta in Python
        (
            "INTERVAL '999999999 23:59:59.999' DAY TO SECOND",
            timedelta(days=999999999, hours=23, minutes=59, seconds=59, milliseconds=999),
        ),
        # min supported timedelta in Python
        ("INTERVAL '-999999999' DAY", timedelta(days=-999999999)),
    ]

    sql_test = SqlTest(trino_connection)
    for sql, expected in cases:
        # Re-bind on every step so the fluent-chain semantics are kept as-is.
        sql_test = sql_test.add_field(sql=sql, python=expected)
    sql_test.execute()

    # One step beyond each end of the timedelta range.
    out_of_range = (
        "INTERVAL '1000000000' DAY",
        "INTERVAL '-999999999 00:00:00.001' DAY TO SECOND",
    )
    for failing_sql in out_of_range:
        SqlExpectFailureTest(trino_connection).execute(failing_sql)


def test_array(trino_connection):
Expand Down
40 changes: 40 additions & 0 deletions trino/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar
from zoneinfo import ZoneInfo

from dateutil.relativedelta import relativedelta

import trino.exceptions
from trino.types import (
POWERS_OF_TEN,
Expand Down Expand Up @@ -167,6 +169,40 @@ def _fraction_to_decimal(fractional_str: str) -> Decimal:
return Decimal(fractional_str or 0) / POWERS_OF_TEN[len(fractional_str)]


class IntervalYearToMonthMapper(ValueMapper[relativedelta]):
    """Map an ``interval year to month`` column value to a ``relativedelta``."""

    def map(self, value: Any) -> Optional[relativedelta]:
        """Parse a ``[-]<years>-<months>`` string.

        :param value: the raw column value, or ``None`` for SQL NULL.
        :return: ``None`` for ``None`` input, otherwise a ``relativedelta``
            whose years and months both carry the sign of the input.
        """
        if value is None:
            return None

        # A single leading '-' negates the whole interval; the '-' between
        # years and months is only a separator.
        sign = -1 if value[0] == "-" else 1
        magnitude = value[1:] if sign < 0 else value

        year_text, month_text = magnitude.split('-')
        return relativedelta(
            years=sign * int(year_text),
            months=sign * int(month_text),
        )


class IntervalDayToSecondMapper(ValueMapper[timedelta]):
    """Map an ``interval day to second`` column value to a ``timedelta``."""

    def map(self, value: Any) -> Optional[timedelta]:
        """Parse a ``[-]<days> <hours>:<minutes>:<seconds>.<millis>`` string.

        :param value: the raw column value, or ``None`` for SQL NULL.
        :return: ``None`` for ``None`` input, otherwise the matching
            ``timedelta``; a leading '-' negates every component.
        :raises trino.exceptions.TrinoDataError: if ``timedelta`` raises
            ``OverflowError`` for the parsed components.
        """
        if value is None:
            return None

        sign = -1 if value[0] == "-" else 1
        magnitude = value[1:] if sign < 0 else value

        day_text, clock_text = magnitude.split(' ')
        hour_text, minute_text, second_text = clock_text.split(':')
        whole_second_text, milli_text = second_text.split('.')

        # The fractional part is passed to timedelta as a millisecond count.
        components = {
            "days": sign * int(day_text),
            "hours": sign * int(hour_text),
            "minutes": sign * int(minute_text),
            "seconds": sign * int(whole_second_text),
            "milliseconds": sign * int(milli_text),
        }

        try:
            return timedelta(**components)
        except OverflowError as e:
            error_str = (
                f"Could not convert '{value}' into the associated python type, as the value "
                "exceeds the maximum or minimum limit."
            )
            raise trino.exceptions.TrinoDataError(error_str) from e


class ArrayValueMapper(ValueMapper[List[Optional[Any]]]):
def __init__(self, mapper: ValueMapper[Any]):
self.mapper = mapper
Expand Down Expand Up @@ -271,6 +307,10 @@ def _create_value_mapper(self, column: Dict[str, Any]) -> ValueMapper[Any]:
return TimestampValueMapper(self._get_precision(column))
if col_type == 'timestamp with time zone':
return TimestampWithTimeZoneValueMapper(self._get_precision(column))
if col_type == 'interval year to month':
return IntervalYearToMonthMapper()
if col_type == 'interval day to second':
return IntervalDayToSecondMapper()

# structural types
if col_type == 'array':
Expand Down

0 comments on commit 4c57774

Please sign in to comment.