Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use other fallback coders for protobuf message base class #33432

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions sdks/python/apache_beam/coders/coders.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@

import google.protobuf.wrappers_pb2
import proto
from google.protobuf import message

from apache_beam.coders import coder_impl
from apache_beam.coders.avro_record import AvroRecord
Expand All @@ -65,7 +66,6 @@
from apache_beam.utils import proto_utils

if TYPE_CHECKING:
from google.protobuf import message # pylint: disable=ungrouped-imports
from apache_beam.coders.typecoders import CoderRegistry
from apache_beam.runners.pipeline_context import PipelineContext

Expand Down Expand Up @@ -1039,11 +1039,18 @@ def __hash__(self):

@classmethod
def from_type_hint(cls, typehint, unused_registry):
  """Builds a coder for a concrete protobuf message class.

  Args:
    typehint: the candidate class. It is accepted only when it is a subclass
      of one of ``proto_utils.message_types`` and is not
      ``google.protobuf.message.Message`` itself.
    unused_registry: not used by this method.

  Returns:
    ``cls(typehint)`` for an accepted typehint.

  Raises:
    ValueError: if the typehint is ``message.Message`` itself or is not a
      subclass of any of ``proto_utils.message_types``.
  """
  # The typehint must be a strict subclass of google.protobuf.message.Message.
  # ProtoCoder cannot work with message.Message itself, as deserialization of
  # a serialized proto requires knowledge of the desired concrete proto
  # subclass which is not stored in the encoded bytes themselves. If this
  # occurs, an error is raised and the system defaults to other fallback
  # coders.
  if (issubclass(typehint, proto_utils.message_types) and
      typehint is not message.Message):
    return cls(typehint)
  raise ValueError((
      'Expected a strict subclass of google.protobuf.message.Message'
      ', but got a %s' % typehint))

def to_type_hint(self):
Expand Down
18 changes: 18 additions & 0 deletions sdks/python/apache_beam/coders/coders_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import proto
import pytest
from google.protobuf import message

import apache_beam as beam
from apache_beam import typehints
Expand Down Expand Up @@ -86,6 +87,23 @@ def test_proto_coder(self):
self.assertEqual(ma, real_coder.decode(real_coder.encode(ma)))
self.assertEqual(ma.__class__, real_coder.to_type_hint())

def test_proto_coder_on_protobuf_message_subclasses(self):
  """Checks the coder chosen when message.Message is used as the typehint.

  This replicates a scenario where users provide message.Message as the
  output typehint for a Map function, even though the actual output messages
  are subclasses of message.Message.
  """
  outer_msg = test_message.MessageA()
  outer_msg.field1 = 'hello world'
  nested_msg = outer_msg.field2.add()
  nested_msg.field1 = True

  # For the google.protobuf.message.Message base class the registry is
  # expected to hand back FastPrimitivesCoder rather than ProtoCoder.
  fallback_coder = coders_registry.get_coder(message.Message)
  self.assertEqual(fallback_coder, coders.FastPrimitivesCoder())

  # A concrete message must still survive an encode/decode round trip.
  round_tripped = fallback_coder.decode(fallback_coder.encode(outer_msg))
  self.assertEqual(outer_msg, round_tripped)


class DeterministicProtoCoderTest(unittest.TestCase):
def test_deterministic_proto_coder(self):
Expand Down
Loading