Diffstat (limited to 'python/google/protobuf/text_format.py')
 python/google/protobuf/text_format.py | 257
 1 file changed, 180 insertions(+), 77 deletions(-)
diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py
index 06b79d77..2cbd21bc 100755
--- a/python/google/protobuf/text_format.py
+++ b/python/google/protobuf/text_format.py
@@ -126,7 +126,8 @@ def MessageToString(message,
float_format=None,
use_field_number=False,
descriptor_pool=None,
- indent=0):
+ indent=0,
+ message_formatter=None):
"""Convert protobuf message to text format.
Floating point values can be formatted compactly with 15 digits of
@@ -140,14 +141,19 @@ def MessageToString(message,
as_one_line: Don't introduce newlines between fields.
pointy_brackets: If True, use angle brackets instead of curly braces for
nesting.
- use_index_order: If True, print fields of a proto message using the order
- defined in source code instead of the field number. By default, use the
- field number order.
+ use_index_order: If True, fields of a proto message will be printed using
+ the order defined in source code instead of the field number; extensions
+ will be printed at the end of the message and their relative order is
+ determined by the extension number. By default, the field number order is
+ used.
float_format: If set, use this to specify floating point number formatting
(per the "Format Specification Mini-Language"); otherwise, str() is used.
use_field_number: If True, print field numbers instead of names.
descriptor_pool: A DescriptorPool used to resolve Any types.
indent: The indent level, in terms of spaces, for pretty print.
+ message_formatter: A function(message, indent, as_one_line): unicode|None
+ to custom format selected sub-messages (usually based on message type).
+ Use to pretty print parts of the protobuf for easier diffing.
Returns:
A string of the text formatted protocol buffer message.
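For reference, a minimal sketch of how the new message_formatter hook can be used; the Timestamp example and the exact output are illustrative, only the hook contract (return a unicode string to take over formatting of a message, or None to fall back to the default printer) comes from this change:

    from google.protobuf import text_format
    from google.protobuf import timestamp_pb2


    def format_timestamps(message, indent, as_one_line):
      # Custom one-line rendering for Timestamp sub-messages; returning None
      # makes the printer fall back to the normal field-by-field output.
      del indent, as_one_line  # unused in this sketch
      if isinstance(message, timestamp_pb2.Timestamp):
        return 'Timestamp(%s)' % message.ToJsonString()
      return None


    ts = timestamp_pb2.Timestamp(seconds=1500000000)
    print(text_format.MessageToString(ts, message_formatter=format_timestamps))
    # Timestamp(2017-07-14T02:40:00Z)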
@@ -155,7 +161,7 @@ def MessageToString(message,
out = TextWriter(as_utf8)
printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets,
use_index_order, float_format, use_field_number,
- descriptor_pool)
+ descriptor_pool, message_formatter)
printer.PrintMessage(message)
result = out.getvalue()
out.close()
@@ -179,10 +185,11 @@ def PrintMessage(message,
use_index_order=False,
float_format=None,
use_field_number=False,
- descriptor_pool=None):
+ descriptor_pool=None,
+ message_formatter=None):
printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets,
use_index_order, float_format, use_field_number,
- descriptor_pool)
+ descriptor_pool, message_formatter)
printer.PrintMessage(message)
@@ -194,10 +201,11 @@ def PrintField(field,
as_one_line=False,
pointy_brackets=False,
use_index_order=False,
- float_format=None):
+ float_format=None,
+ message_formatter=None):
"""Print a single field name/value pair."""
printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets,
- use_index_order, float_format)
+ use_index_order, float_format, message_formatter)
printer.PrintField(field, value)
@@ -209,10 +217,11 @@ def PrintFieldValue(field,
as_one_line=False,
pointy_brackets=False,
use_index_order=False,
- float_format=None):
+ float_format=None,
+ message_formatter=None):
"""Print a single field value (not including name)."""
printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets,
- use_index_order, float_format)
+ use_index_order, float_format, message_formatter)
printer.PrintFieldValue(field, value)
@@ -228,13 +237,16 @@ def _BuildMessageFromTypeName(type_name, descriptor_pool):
wasn't found matching type_name.
"""
# pylint: disable=g-import-not-at-top
- from google.protobuf import message_factory
- factory = message_factory.MessageFactory(descriptor_pool)
+ if descriptor_pool is None:
+ from google.protobuf import descriptor_pool as pool_mod
+ descriptor_pool = pool_mod.Default()
+ from google.protobuf import symbol_database
+ database = symbol_database.Default()
try:
message_descriptor = descriptor_pool.FindMessageTypeByName(type_name)
except KeyError:
return None
- message_type = factory.GetPrototype(message_descriptor)
+ message_type = database.GetPrototype(message_descriptor)
return message_type()
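For orientation (not part of the patch itself): the rewritten helper resolves a type name against the given pool, or the default pool, and asks the default symbol database for a generated class. A small sketch using an existing well-known type; FindMessageTypeByName and GetPrototype are the same calls the new code makes:

    from google.protobuf import descriptor_pool
    from google.protobuf import duration_pb2  # importing registers Duration
    from google.protobuf import symbol_database

    pool = descriptor_pool.Default()
    database = symbol_database.Default()

    # The same two steps _BuildMessageFromTypeName performs:
    # type name -> descriptor, descriptor -> generated message class.
    descriptor = pool.FindMessageTypeByName('google.protobuf.Duration')
    message_class = database.GetPrototype(descriptor)
    print(message_class().DESCRIPTOR.full_name)  # google.protobuf.Duration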
@@ -250,7 +262,8 @@ class _Printer(object):
use_index_order=False,
float_format=None,
use_field_number=False,
- descriptor_pool=None):
+ descriptor_pool=None,
+ message_formatter=None):
"""Initialize the Printer.
Floating point values can be formatted compactly with 15 digits of
@@ -273,6 +286,9 @@ class _Printer(object):
used.
use_field_number: If True, print field numbers instead of names.
descriptor_pool: A DescriptorPool used to resolve Any types.
+ message_formatter: A function(message, indent, as_one_line): unicode|None
+ to custom format selected sub-messages (usually based on message type).
+ Use to pretty print parts of the protobuf for easier diffing.
"""
self.out = out
self.indent = indent
@@ -283,6 +299,7 @@ class _Printer(object):
self.float_format = float_format
self.use_field_number = use_field_number
self.descriptor_pool = descriptor_pool
+ self.message_formatter = message_formatter
def _TryPrintAsAnyMessage(self, message):
"""Serializes if message is a google.protobuf.Any field."""
@@ -297,28 +314,41 @@ class _Printer(object):
else:
return False
+ def _TryCustomFormatMessage(self, message):
+ formatted = self.message_formatter(message, self.indent, self.as_one_line)
+ if formatted is None:
+ return False
+
+ out = self.out
+ out.write(' ' * self.indent)
+ out.write(formatted)
+ out.write(' ' if self.as_one_line else '\n')
+ return True
+
def PrintMessage(self, message):
"""Convert protobuf message to text format.
Args:
message: The protocol buffers message.
"""
+ if self.message_formatter and self._TryCustomFormatMessage(message):
+ return
if (message.DESCRIPTOR.full_name == _ANY_FULL_TYPE_NAME and
- self.descriptor_pool and self._TryPrintAsAnyMessage(message)):
+ self._TryPrintAsAnyMessage(message)):
return
fields = message.ListFields()
if self.use_index_order:
- fields.sort(key=lambda x: x[0].index)
+ fields.sort(
+ key=lambda x: x[0].number if x[0].is_extension else x[0].index)
for field, value in fields:
if _IsMapEntry(field):
for key in sorted(value):
- # This is slow for maps with submessage entires because it copies the
+ # This is slow for maps with submessage entries because it copies the
# entire tree. Unfortunately this would take significant refactoring
# of this file to work around.
#
# TODO(haberman): refactor and optimize if this becomes an issue.
- entry_submsg = field.message_type._concrete_class(key=key,
- value=value[key])
+ entry_submsg = value.GetEntryClass()(key=key, value=value[key])
self.PrintField(field, entry_submsg)
elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
for element in value:
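For context on the map-entry branch above: map fields print as one key/value entry message per key, sorted by key. Struct is used below only because it is a readily available message with a map field:

    from google.protobuf import struct_pb2
    from google.protobuf import text_format

    s = struct_pb2.Struct()
    s.fields['answer'].number_value = 42

    # Each map entry prints as a nested message with `key` and `value` fields.
    print(text_format.MessageToString(s))
    # fields {
    #   key: "answer"
    #   value {
    #     number_value: 42.0
    #   }
    # }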
@@ -423,15 +453,33 @@ class _Printer(object):
def Parse(text,
message,
allow_unknown_extension=False,
- allow_field_number=False):
+ allow_field_number=False,
+ descriptor_pool=None):
"""Parses a text representation of a protocol message into a message.
+ NOTE: for historical reasons this function does not clear the input
+ message. This is different from what the binary msg.ParseFrom(...) does.
+
+ Example
+ a = MyProto()
+ a.repeated_field.append('test')
+ b = MyProto()
+
+ text_format.Parse(repr(a), b)
+ text_format.Parse(repr(a), b) # repeated_field contains ["test", "test"]
+
+ # Binary version:
+ b.ParseFromString(a.SerializeToString()) # repeated_field is now "test"
+
+ Caller is responsible for clearing the message as needed.
+
Args:
text: Message text representation.
message: A protocol buffer message to merge into.
allow_unknown_extension: if True, skip over missing extensions and keep
parsing
allow_field_number: if True, both field number and field name are allowed.
+ descriptor_pool: A DescriptorPool used to resolve Any types.
Returns:
The same message passed as argument.
@@ -440,9 +488,15 @@ def Parse(text,
ParseError: On text parsing problems.
"""
if not isinstance(text, str):
- text = text.decode('utf-8')
- return ParseLines(
- text.split('\n'), message, allow_unknown_extension, allow_field_number)
+ if six.PY3:
+ text = text.decode('utf-8')
+ else:
+ text = text.encode('utf-8')
+ return ParseLines(text.split('\n'),
+ message,
+ allow_unknown_extension,
+ allow_field_number,
+ descriptor_pool=descriptor_pool)
def Merge(text,
@@ -469,6 +523,11 @@ def Merge(text,
Raises:
ParseError: On text parsing problems.
"""
+ if not isinstance(text, str):
+ if six.PY3:
+ text = text.decode('utf-8')
+ else:
+ text = text.encode('utf-8')
return MergeLines(
text.split('\n'),
message,
@@ -480,7 +539,8 @@ def Merge(text,
def ParseLines(lines,
message,
allow_unknown_extension=False,
- allow_field_number=False):
+ allow_field_number=False,
+ descriptor_pool=None):
"""Parses a text representation of a protocol message into a message.
Args:
@@ -497,7 +557,9 @@ def ParseLines(lines,
Raises:
ParseError: On text parsing problems.
"""
- parser = _Parser(allow_unknown_extension, allow_field_number)
+ parser = _Parser(allow_unknown_extension,
+ allow_field_number,
+ descriptor_pool=descriptor_pool)
return parser.ParseLines(lines, message)
@@ -514,6 +576,7 @@ def MergeLines(lines,
allow_unknown_extension: if True, skip over missing extensions and keep
parsing
allow_field_number: if True, both field number and field name are allowed.
+ descriptor_pool: A DescriptorPool used to resolve Any types.
Returns:
The same message passed as argument.
@@ -585,11 +648,30 @@ class _Parser(object):
ParseError: In case of text parsing problems.
"""
message_descriptor = message.DESCRIPTOR
- if (hasattr(message_descriptor, 'syntax') and
- message_descriptor.syntax == 'proto3'):
- # Proto3 doesn't represent presence so we can't test if multiple
- # scalars have occurred. We have to allow them.
- self._allow_multiple_scalars = True
+ if (message_descriptor.full_name == _ANY_FULL_TYPE_NAME and
+ tokenizer.TryConsume('[')):
+ type_url_prefix, packed_type_name = self._ConsumeAnyTypeUrl(tokenizer)
+ tokenizer.Consume(']')
+ tokenizer.TryConsume(':')
+ if tokenizer.TryConsume('<'):
+ expanded_any_end_token = '>'
+ else:
+ tokenizer.Consume('{')
+ expanded_any_end_token = '}'
+ expanded_any_sub_message = _BuildMessageFromTypeName(packed_type_name,
+ self.descriptor_pool)
+ if not expanded_any_sub_message:
+ raise ParseError('Type %s not found in descriptor pool' %
+ packed_type_name)
+ while not tokenizer.TryConsume(expanded_any_end_token):
+ if tokenizer.AtEnd():
+ raise tokenizer.ParseErrorPreviousToken('Expected "%s".' %
+ (expanded_any_end_token,))
+ self._MergeField(tokenizer, expanded_any_sub_message)
+ message.Pack(expanded_any_sub_message,
+ type_url_prefix=type_url_prefix)
+ return
+
if tokenizer.TryConsume('['):
name = [tokenizer.ConsumeIdentifier()]
while tokenizer.TryConsume('.'):
@@ -608,7 +690,11 @@ class _Parser(object):
field = None
else:
raise tokenizer.ParseErrorPreviousToken(
- 'Extension "%s" not registered.' % name)
+ 'Extension "%s" not registered. '
+ 'Did you import the _pb2 module which defines it? '
+ 'If you are trying to place the extension in the MessageSet '
+ 'field of another message that is in an Any or MessageSet field, '
+ 'that message\'s _pb2 module must be imported as well' % name)
elif message_descriptor != field.containing_type:
raise tokenizer.ParseErrorPreviousToken(
'Extension "%s" does not extend message type "%s".' %
@@ -666,11 +752,12 @@ class _Parser(object):
if (field.label == descriptor.FieldDescriptor.LABEL_REPEATED and
tokenizer.TryConsume('[')):
# Short repeated format, e.g. "foo: [1, 2, 3]"
- while True:
- merger(tokenizer, message, field)
- if tokenizer.TryConsume(']'):
- break
- tokenizer.Consume(',')
+ if not tokenizer.TryConsume(']'):
+ while True:
+ merger(tokenizer, message, field)
+ if tokenizer.TryConsume(']'):
+ break
+ tokenizer.Consume(',')
else:
merger(tokenizer, message, field)
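The added guard makes the empty short-repeated form legal. A quick sketch using FieldMask, which happens to have a repeated string field:

    from google.protobuf import field_mask_pb2
    from google.protobuf import text_format

    mask = field_mask_pb2.FieldMask()
    text_format.Parse('paths: ["a", "b.c"]', mask)   # short repeated form
    print(list(mask.paths))  # ['a', 'b.c']

    # Previously this raised ParseError; with this change an empty list parses.
    text_format.Parse('paths: []', field_mask_pb2.FieldMask())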
@@ -687,17 +774,17 @@ class _Parser(object):
def _ConsumeAnyTypeUrl(self, tokenizer):
"""Consumes a google.protobuf.Any type URL and returns the type name."""
# Consume "type.googleapis.com/".
- tokenizer.ConsumeIdentifier()
+ prefix = [tokenizer.ConsumeIdentifier()]
tokenizer.Consume('.')
- tokenizer.ConsumeIdentifier()
+ prefix.append(tokenizer.ConsumeIdentifier())
tokenizer.Consume('.')
- tokenizer.ConsumeIdentifier()
+ prefix.append(tokenizer.ConsumeIdentifier())
tokenizer.Consume('/')
# Consume the fully-qualified type name.
name = [tokenizer.ConsumeIdentifier()]
while tokenizer.TryConsume('.'):
name.append(tokenizer.ConsumeIdentifier())
- return '.'.join(name)
+ return '.'.join(prefix), '.'.join(name)
def _MergeMessageField(self, tokenizer, message, field):
"""Merges a single scalar field into a message.
@@ -718,45 +805,29 @@ class _Parser(object):
tokenizer.Consume('{')
end_token = '}'
- if (field.message_type.full_name == _ANY_FULL_TYPE_NAME and
- tokenizer.TryConsume('[')):
- packed_type_name = self._ConsumeAnyTypeUrl(tokenizer)
- tokenizer.Consume(']')
- tokenizer.TryConsume(':')
- if tokenizer.TryConsume('<'):
- expanded_any_end_token = '>'
- else:
- tokenizer.Consume('{')
- expanded_any_end_token = '}'
- if not self.descriptor_pool:
- raise ParseError('Descriptor pool required to parse expanded Any field')
- expanded_any_sub_message = _BuildMessageFromTypeName(packed_type_name,
- self.descriptor_pool)
- if not expanded_any_sub_message:
- raise ParseError('Type %s not found in descriptor pool' %
- packed_type_name)
- while not tokenizer.TryConsume(expanded_any_end_token):
- if tokenizer.AtEnd():
- raise tokenizer.ParseErrorPreviousToken('Expected "%s".' %
- (expanded_any_end_token,))
- self._MergeField(tokenizer, expanded_any_sub_message)
- if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
- any_message = getattr(message, field.name).add()
- else:
- any_message = getattr(message, field.name)
- any_message.Pack(expanded_any_sub_message)
- elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
+ if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
if field.is_extension:
sub_message = message.Extensions[field].add()
elif is_map_entry:
- # pylint: disable=protected-access
- sub_message = field.message_type._concrete_class()
+ sub_message = getattr(message, field.name).GetEntryClass()()
else:
sub_message = getattr(message, field.name).add()
else:
if field.is_extension:
+ if (not self._allow_multiple_scalars and
+ message.HasExtension(field)):
+ raise tokenizer.ParseErrorPreviousToken(
+ 'Message type "%s" should not have multiple "%s" extensions.' %
+ (message.DESCRIPTOR.full_name, field.full_name))
sub_message = message.Extensions[field]
else:
+ # Also apply _allow_multiple_scalars to message fields.
+ # TODO(jieluo): Change to _allow_singular_overwrites.
+ if (not self._allow_multiple_scalars and
+ message.HasField(field.name)):
+ raise tokenizer.ParseErrorPreviousToken(
+ 'Message type "%s" should not have multiple "%s" fields.' %
+ (message.DESCRIPTOR.full_name, field.name))
sub_message = getattr(message, field.name)
sub_message.SetInParent()
@@ -773,6 +844,12 @@ class _Parser(object):
else:
getattr(message, field.name)[sub_message.key] = sub_message.value
+ @staticmethod
+ def _IsProto3Syntax(message):
+ message_descriptor = message.DESCRIPTOR
+ return (hasattr(message_descriptor, 'syntax') and
+ message_descriptor.syntax == 'proto3')
+
def _MergeScalarField(self, tokenizer, message, field):
"""Merges a single scalar field into a message.
@@ -822,15 +899,20 @@ class _Parser(object):
else:
getattr(message, field.name).append(value)
else:
+ # Proto3 doesn't represent presence so we can't test if multiple scalars
+ # have occurred. We have to allow them.
+ can_check_presence = not self._IsProto3Syntax(message)
if field.is_extension:
- if not self._allow_multiple_scalars and message.HasExtension(field):
+ if (not self._allow_multiple_scalars and can_check_presence and
+ message.HasExtension(field)):
raise tokenizer.ParseErrorPreviousToken(
'Message type "%s" should not have multiple "%s" extensions.' %
(message.DESCRIPTOR.full_name, field.full_name))
else:
message.Extensions[field] = value
else:
- if not self._allow_multiple_scalars and message.HasField(field.name):
+ if (not self._allow_multiple_scalars and can_check_presence and
+ message.HasField(field.name)):
raise tokenizer.ParseErrorPreviousToken(
'Message type "%s" should not have multiple "%s" fields.' %
(message.DESCRIPTOR.full_name, field.name))
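To illustrate the presence check above: proto3 scalars carry no presence information, so Parse cannot detect a duplicated singular field and keeps the last value, whereas a proto2 message would raise ParseError for the same input. Duration (a proto3 type) stands in for any proto3 message:

    from google.protobuf import duration_pb2  # a proto3 message type
    from google.protobuf import text_format

    d = duration_pb2.Duration()
    # Allowed under proto3: presence cannot be checked, so the last value wins.
    text_format.Parse('seconds: 1 seconds: 2', d)
    print(d.seconds)  # 2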
@@ -870,7 +952,7 @@ def _SkipField(tokenizer):
tokenizer.ConsumeIdentifier()
tokenizer.Consume(']')
else:
- tokenizer.ConsumeIdentifier()
+ tokenizer.ConsumeIdentifierOrNumber()
_SkipFieldContents(tokenizer)
@@ -1025,6 +1107,22 @@ class Tokenizer(object):
self.NextToken()
return result
+ def ConsumeCommentOrTrailingComment(self):
+ """Consumes a comment, returns a 2-tuple (trailing bool, comment str)."""
+
+ # Tokenizer initializes _previous_line and _previous_column to 0. As the
+ # tokenizer starts, it looks like there is a previous token on the line.
+ just_started = self._line == 0 and self._column == 0
+
+ before_parsing = self._previous_line
+ comment = self.ConsumeComment()
+
+ # A trailing comment is a comment on the same line as the previous token.
+ trailing = (self._previous_line == before_parsing
+ and not just_started)
+
+ return trailing, comment
+
def TryConsumeIdentifier(self):
try:
self.ConsumeIdentifier()
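For context on ConsumeCommentOrTrailingComment: a trailing comment shares a line with the token before it, a leading comment sits on its own line, and ordinary parsing skips both. Illustration only, with Duration as an arbitrary message:

    from google.protobuf import duration_pb2
    from google.protobuf import text_format

    TEXT = """
    # leading comment, on its own line
    seconds: 1  # trailing comment, on the same line as the value
    """

    d = duration_pb2.Duration()
    text_format.Parse(TEXT, d)  # comments are skipped by the parser
    print(d.seconds)  # 1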
@@ -1065,7 +1163,7 @@ class Tokenizer(object):
"""
result = self.token
if not self._IDENTIFIER_OR_NUMBER.match(result):
- raise self.ParseError('Expected identifier or number.')
+ raise self.ParseError('Expected identifier or number, got %s.' % result)
self.NextToken()
return result
@@ -1448,9 +1546,9 @@ def ParseBool(text):
Raises:
ValueError: If text is not a valid boolean.
"""
- if text in ('true', 't', '1'):
+ if text in ('true', 't', '1', 'True'):
return True
- elif text in ('false', 'f', '0'):
+ elif text in ('false', 'f', '0', 'False'):
return False
else:
raise ValueError('Expected "true" or "false".')
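With this change ParseBool also accepts the Python-style capitalized literals:

    from google.protobuf import text_format

    for text in ('true', 't', '1', 'True'):
      assert text_format.ParseBool(text)
    for text in ('false', 'f', '0', 'False'):
      assert not text_format.ParseBool(text)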
@@ -1483,6 +1581,11 @@ def ParseEnum(field, value):
(enum_descriptor.full_name, value))
else:
# Numeric value.
+ if hasattr(field.file, 'syntax'):
+ # The syntax attribute is checked here for compatibility with older descriptors.
+ if field.file.syntax == 'proto3':
+ # Proto3 accepts numeric values for unknown enums.
+ return number
enum_value = enum_descriptor.values_by_number.get(number, None)
if enum_value is None:
raise ValueError('Enum type "%s" has no value with number %d.' %
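Illustrating the proto3 branch above; NullValue from the well-known Struct types is used only because it is a proto3 enum with a single defined value:

    from google.protobuf import struct_pb2
    from google.protobuf import text_format

    value = struct_pb2.Value()
    # proto3 enums are open: an unknown numeric value is kept rather than
    # rejected with ValueError, which is what the new branch implements.
    text_format.Parse('null_value: 7', value)
    print(value.null_value)  # 7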