Commit

fixup! Support spooled protocol

mdesmet committed Nov 25, 2024
1 parent e339a49 commit 2dc1909
Showing 3 changed files with 102 additions and 77 deletions.
3 changes: 2 additions & 1 deletion setup.py
@@ -83,13 +83,14 @@
     ],
     python_requires=">=3.9",
     install_requires=[
+        "lz4",
         "python-dateutil",
         "pytz",
         # requests CVE https://github.com/advisories/GHSA-j8r2-6x86-q33q
         "requests>=2.31.0",
         "typing_extensions",
         "tzlocal",
-        "lz4",
+        "zstandard",
     ],
     extras_require={
         "all": all_require,
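
The two new runtime dependencies back the segment encodings handled in `trino/client.py` below. As a minimal sketch (not from the repository), the `json+zstd` path round-trips like this; zstd frames carry their own content size, so no extra metadata is needed:

```python
import json

import zstandard

rows = [[1, "a"]]
payload = json.dumps(rows).encode("utf-8")
compressed = zstandard.ZstdCompressor().compress(payload)

# A zstd frame records its uncompressed size, so decompress() needs no hint,
# matching the json+zstd branch of _load_spooled_segment below.
decompressed = zstandard.ZstdDecompressor().decompress(compressed)
assert json.loads(decompressed) == rows
```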
98 changes: 96 additions & 2 deletions trino/client.py
@@ -34,8 +34,10 @@
"""
from __future__ import annotations

import base64
import copy
import functools
import json
import os
import random
import re
@@ -46,10 +48,13 @@
 from datetime import datetime
 from email.utils import parsedate_to_datetime
 from time import sleep
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
 from zoneinfo import ZoneInfo
 
+import lz4.block
 import requests
+import zstandard
+from typing_extensions import NotRequired, TypedDict
 from tzlocal import get_localzone_name  # type: ignore
 
 import trino.logging
@@ -858,7 +863,14 @@ def fetch(self) -> List[List[Any]]:
         if not self._row_mapper:
             return []
 
-        return self._row_mapper.map(status.rows)
+        rows = status.rows
+        if isinstance(rows, dict):
+            # spooled protocol
+            encoding = rows["encoding"]
+            segments = rows["segments"]
+            return list(SegmentIterator(segments, encoding, self._row_mapper, self._request))
+        else:
+            return self._row_mapper.map(rows)
 
     def cancel(self) -> None:
         """Cancel the current query"""
@@ -934,3 +946,85 @@ def _parse_retry_after_header(retry_after):
     retry_date = parsedate_to_datetime(retry_after)
     now = datetime.utcnow()
     return (retry_date - now).total_seconds()


+# Trino Spooled protocol transfer objects
+SpooledSegmentMetadata = TypedDict('SpooledSegmentMetadata', {'uncompressedSize': str})
+SpooledSegment = TypedDict(
+    'SpooledSegment',
+    {
+        'encoding': str,
+        'uri': str,
+        'ackUri': NotRequired[str],
+        'data': List[List[Any]],
+        'metadata': SpooledSegmentMetadata
+    }
+)
+
+
+class SegmentIterator:
+    def __init__(self, segments: List[SpooledSegment], encoding: str, row_mapper: RowMapper, request: TrinoRequest):
+        self._segments = iter(segments)
+        self._encoding = encoding
+        self._row_mapper = row_mapper
+        self._request = request
+        self._rows: Iterator[List[List[Any]]] = iter([])
+        self._finished = False
+        self._current_segment: Optional[SpooledSegment] = None
+
+    def __iter__(self) -> Iterator[List[Any]]:
+        return self
+
+    def __next__(self) -> List[Any]:
+        # If rows are exhausted, fetch the next segment
+        while True:
+            try:
+                return next(self._rows)
+            except StopIteration:
+                if self._current_segment and "ackUri" in self._current_segment:
+                    ack_uri = self._current_segment["ackUri"]
+                    http_response = self._request._get(ack_uri)
+                    if not http_response.ok:
+                        self._request.raise_response_error(http_response)
+                if self._finished:
+                    raise StopIteration
+                self._load_next_row_set()
+
+    def _load_next_row_set(self):
+        try:
+            self._current_segment = next(self._segments)
+            segment_type = self._current_segment["type"]
+
+            if segment_type == "inline":
+                data = self._current_segment["data"]
+                decoded_string = base64.b64decode(data)
+                rows = self._row_mapper.map(json.loads(decoded_string))
+                self._rows = iter(rows)
+
+            elif segment_type == "spooled":
+                uri = self._current_segment["uri"]
+                decoded_string = self._load_spooled_segment(uri, self._encoding)
+                rows = self._row_mapper.map(json.loads(decoded_string))
+                self._rows = iter(rows)
+            else:
+                raise ValueError(f"Unsupported segment type: {segment_type}")
+
+        except StopIteration:
+            self._finished = True
+
+    def _load_spooled_segment(self, uri, encoding):
+        http_response = self._request._get(uri, stream=True)
+        if not http_response.ok:
+            self._request.raise_response_error(http_response)
+
+        content = http_response.content
+        if encoding == "json+zstd":
+            zstd_decompressor = zstandard.ZstdDecompressor()
+            return zstd_decompressor.decompress(content).decode('utf-8')
+        elif encoding == "json+lz4":
+            expected_size = self._current_segment["metadata"]["uncompressedSize"]
+            return lz4.block.decompress(content, uncompressed_size=int(expected_size)).decode('utf-8')
+        elif encoding == "json":
+            return content.decode('utf-8')
+        else:
+            raise ValueError(f"Unsupported encoding: {encoding}")
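
For context, here is a sketch (not part of the commit) of the spooled-protocol payload that the new `fetch()` path dispatches on: when `status.rows` is a dict rather than a list, it carries an encoding plus a list of segments for `SegmentIterator` to consume. The shape is inferred from `SegmentIterator` and the `SpooledSegment` TypedDict above; the URIs, data, and sizes are invented for illustration.

```python
# Hypothetical spooled-protocol `rows` payload (illustrative values only).
spooled_rows = {
    "encoding": "json+zstd",
    "segments": [
        {
            # Small result sets may arrive inline as base64-encoded JSON rows;
            # per the inline branch above, this data is not compressed.
            "type": "inline",
            "data": "W1sxLCAiYSJdXQ==",  # base64 of [[1, "a"]]
            "metadata": {"uncompressedSize": "10"},
        },
        {
            # Larger segments are spooled to a URI, downloaded lazily, and
            # acknowledged via ackUri once their rows have been consumed.
            "type": "spooled",
            "uri": "https://example.invalid/segments/1",
            "ackUri": "https://example.invalid/segments/1/ack",
            "metadata": {"uncompressedSize": "2048"},
        },
    ],
}
```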
78 changes: 4 additions & 74 deletions trino/mapper.py
@@ -2,16 +2,12 @@

 import abc
 import base64
-import json
 import uuid
 from datetime import date, datetime, time, timedelta, timezone, tzinfo
 from decimal import Decimal
-from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union
+from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar
 from zoneinfo import ZoneInfo
 
-import lz4.block
-import requests
-import zstandard
 from dateutil.relativedelta import relativedelta
 
 import trino.exceptions
@@ -353,76 +349,10 @@ class RowMapper:
     def __init__(self, columns: List[ValueMapper[Any]]):
         self.columns = columns
 
-    def map(self, rows: Union[List[Any], Dict[str, Any]]) -> List[List[Any]]:
-        if isinstance(rows, list) and len(self.columns) == 0:
+    def map(self, rows: List[List[Any]]) -> List[List[Any]]:
+        if len(self.columns) == 0:
             return rows
-        if isinstance(rows, dict):
-            # spooled protocol
-            # TODO: implement ackUrl acknowledging
-            # TODO: refactor to cleaner code
-            # TODO: should probably stream rows
-            encoding = rows["encoding"]
-            segments = rows["segments"]
-            rows_to_return = []
-            for segment in segments:
-                segment_type = segment["type"]
-                if segment_type == "inline":
-                    data = segment["data"]
-                    decoded_string = base64.b64decode(data)
-                    rows_to_return += [self._map_row(row) for row in json.loads(decoded_string)]
-                elif segment_type == "spooled":
-                    if encoding == "json+zstd":
-                        response = requests.get(segment["uri"], stream=True)
-                        if response.status_code == 200:
-                            dctx = zstandard.ZstdDecompressor()
-                            # TODO: Investigate why streaming didn't work
-                            # with dctx.stream_reader(response.raw) as decompressed_stream:
-                            #     decompressed_data = io.TextIOWrapper(decompressed_stream, encoding="utf-8")
-                            #     rows_to_return += [self._map_row(row) for row in json.load(decompressed_data)]
-                            decompressed_data = dctx.decompress(response.content)
-                            decompressed_text = decompressed_data.decode('utf-8')
-                            rows_to_return += [self._map_row(row) for row in json.loads(decompressed_text)]
-                        else:
-                            raise Exception("TODO: implement retrying")
-                    elif encoding == "json+lz4":
-                        response = requests.get(segment["uri"], stream=True)
-                        if response.status_code == 200:
-                            # Read all compressed data
-                            # Decompress the data
-                            content = response.content
-
-                            expected_decompressed_size = segment["metadata"]["uncompressedSize"]
-                            decompressed_data = lz4.block.decompress(
-                                content,
-                                uncompressed_size=expected_decompressed_size
-                            )
-
-                            # Check if decompressed data size matches the expected size
-                            decompressed_size = len(decompressed_data)
-                            if decompressed_size != expected_decompressed_size:
-                                raise Exception(
-                                    f"Decompressed size does not match expected size, expected "
-                                    f"{expected_decompressed_size}, got {decompressed_size}")
-
-                            # Decode decompressed data and process JSON
-                            decompressed_text = decompressed_data.decode('utf-8')
-                            rows_to_return += [self._map_row(row) for row in json.loads(decompressed_text)]
-                        else:
-                            raise Exception("TODO: implement retrying")
-                    elif encoding == "json":
-                        response = requests.get(segment["uri"], stream=True)
-                        if response.status_code == 200:
-                            rows_to_return += [self._map_row(row) for row in response.json()]
-                        else:
-                            raise Exception("TODO: implement retrying")
-                    else:
-                        raise ValueError(f"Unsupported encoding: {encoding}")
-                else:
-                    raise ValueError(f"Unsupported segment type: {segment_type}")
-            return rows_to_return
-        else:
-            # legacy driver
-            return [self._map_row(row) for row in rows]
+        return [self._map_row(row) for row in rows]
 
     def _map_row(self, row: List[Any]) -> List[Any]:
         return [self._map_value(value, self.columns[index]) for index, value in enumerate(row)]
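
To make the wire format concrete, here is a minimal round-trip sketch (not from the repository) of the `json+lz4` encoding whose handling moved from `mapper.py` into `client.py`'s `_load_spooled_segment`: a raw lz4 block over a JSON array of rows, whose uncompressed size travels separately in the segment metadata, which is why the decompress call must be given `uncompressed_size` explicitly.

```python
import json

import lz4.block

# A segment body is a JSON array of rows, compressed as a raw lz4 block
# with no embedded size header (store_size=False).
rows = [[1, "a"], [2, "b"]]
payload = json.dumps(rows).encode("utf-8")
compressed = lz4.block.compress(payload, store_size=False)

# The server carries the uncompressed size out of band, mirroring the
# segment's metadata["uncompressedSize"] field (a string in the protocol).
metadata = {"uncompressedSize": str(len(payload))}

# Decompression needs the expected size passed in, exactly as
# _load_spooled_segment does with int(expected_size).
decompressed = lz4.block.decompress(
    compressed,
    uncompressed_size=int(metadata["uncompressedSize"]),
)
assert json.loads(decompressed) == rows
```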
