diff options
Diffstat (limited to 'python/google/protobuf/text_format.py')
-rwxr-xr-x | python/google/protobuf/text_format.py | 257 |
1 files changed, 180 insertions, 77 deletions
diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py index 06b79d77..2cbd21bc 100755 --- a/python/google/protobuf/text_format.py +++ b/python/google/protobuf/text_format.py @@ -126,7 +126,8 @@ def MessageToString(message, float_format=None, use_field_number=False, descriptor_pool=None, - indent=0): + indent=0, + message_formatter=None): """Convert protobuf message to text format. Floating point values can be formatted compactly with 15 digits of @@ -140,14 +141,19 @@ def MessageToString(message, as_one_line: Don't introduce newlines between fields. pointy_brackets: If True, use angle brackets instead of curly braces for nesting. - use_index_order: If True, print fields of a proto message using the order - defined in source code instead of the field number. By default, use the - field number order. + use_index_order: If True, fields of a proto message will be printed using + the order defined in source code instead of the field number, extensions + will be printed at the end of the message and their relative order is + determined by the extension number. By default, use the field number + order. float_format: If set, use this to specify floating point number formatting (per the "Format Specification Mini-Language"); otherwise, str() is used. use_field_number: If True, print field numbers instead of names. descriptor_pool: A DescriptorPool used to resolve Any types. indent: The indent level, in terms of spaces, for pretty print. + message_formatter: A function(message, indent, as_one_line): unicode|None + to custom format selected sub-messages (usually based on message type). + Use to pretty print parts of the protobuf for easier diffing. Returns: A string of the text formatted protocol buffer message. 
@@ -155,7 +161,7 @@ def MessageToString(message, out = TextWriter(as_utf8) printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets, use_index_order, float_format, use_field_number, - descriptor_pool) + descriptor_pool, message_formatter) printer.PrintMessage(message) result = out.getvalue() out.close() @@ -179,10 +185,11 @@ def PrintMessage(message, use_index_order=False, float_format=None, use_field_number=False, - descriptor_pool=None): + descriptor_pool=None, + message_formatter=None): printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets, use_index_order, float_format, use_field_number, - descriptor_pool) + descriptor_pool, message_formatter) printer.PrintMessage(message) @@ -194,10 +201,11 @@ def PrintField(field, as_one_line=False, pointy_brackets=False, use_index_order=False, - float_format=None): + float_format=None, + message_formatter=None): """Print a single field name/value pair.""" printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets, - use_index_order, float_format) + use_index_order, float_format, message_formatter) printer.PrintField(field, value) @@ -209,10 +217,11 @@ def PrintFieldValue(field, as_one_line=False, pointy_brackets=False, use_index_order=False, - float_format=None): + float_format=None, + message_formatter=None): """Print a single field value (not including name).""" printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets, - use_index_order, float_format) + use_index_order, float_format, message_formatter) printer.PrintFieldValue(field, value) @@ -228,13 +237,16 @@ def _BuildMessageFromTypeName(type_name, descriptor_pool): wasn't found matching type_name. 
""" # pylint: disable=g-import-not-at-top - from google.protobuf import message_factory - factory = message_factory.MessageFactory(descriptor_pool) + if descriptor_pool is None: + from google.protobuf import descriptor_pool as pool_mod + descriptor_pool = pool_mod.Default() + from google.protobuf import symbol_database + database = symbol_database.Default() try: message_descriptor = descriptor_pool.FindMessageTypeByName(type_name) except KeyError: return None - message_type = factory.GetPrototype(message_descriptor) + message_type = database.GetPrototype(message_descriptor) return message_type() @@ -250,7 +262,8 @@ class _Printer(object): use_index_order=False, float_format=None, use_field_number=False, - descriptor_pool=None): + descriptor_pool=None, + message_formatter=None): """Initialize the Printer. Floating point values can be formatted compactly with 15 digits of @@ -273,6 +286,9 @@ class _Printer(object): used. use_field_number: If True, print field numbers instead of names. descriptor_pool: A DescriptorPool used to resolve Any types. + message_formatter: A function(message, indent, as_one_line): unicode|None + to custom format selected sub-messages (usually based on message type). + Use to pretty print parts of the protobuf for easier diffing. 
""" self.out = out self.indent = indent @@ -283,6 +299,7 @@ class _Printer(object): self.float_format = float_format self.use_field_number = use_field_number self.descriptor_pool = descriptor_pool + self.message_formatter = message_formatter def _TryPrintAsAnyMessage(self, message): """Serializes if message is a google.protobuf.Any field.""" @@ -297,28 +314,41 @@ class _Printer(object): else: return False + def _TryCustomFormatMessage(self, message): + formatted = self.message_formatter(message, self.indent, self.as_one_line) + if formatted is None: + return False + + out = self.out + out.write(' ' * self.indent) + out.write(formatted) + out.write(' ' if self.as_one_line else '\n') + return True + def PrintMessage(self, message): """Convert protobuf message to text format. Args: message: The protocol buffers message. """ + if self.message_formatter and self._TryCustomFormatMessage(message): + return if (message.DESCRIPTOR.full_name == _ANY_FULL_TYPE_NAME and - self.descriptor_pool and self._TryPrintAsAnyMessage(message)): + self._TryPrintAsAnyMessage(message)): return fields = message.ListFields() if self.use_index_order: - fields.sort(key=lambda x: x[0].index) + fields.sort( + key=lambda x: x[0].number if x[0].is_extension else x[0].index) for field, value in fields: if _IsMapEntry(field): for key in sorted(value): - # This is slow for maps with submessage entires because it copies the + # This is slow for maps with submessage entries because it copies the # entire tree. Unfortunately this would take significant refactoring # of this file to work around. # # TODO(haberman): refactor and optimize if this becomes an issue. 
- entry_submsg = field.message_type._concrete_class(key=key, - value=value[key]) + entry_submsg = value.GetEntryClass()(key=key, value=value[key]) self.PrintField(field, entry_submsg) elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: for element in value: @@ -423,15 +453,33 @@ class _Printer(object): def Parse(text, message, allow_unknown_extension=False, - allow_field_number=False): + allow_field_number=False, + descriptor_pool=None): """Parses a text representation of a protocol message into a message. + NOTE: for historical reasons this function does not clear the input + message. This is different from what the binary msg.ParseFrom(...) does. + + Example + a = MyProto() + a.repeated_field.append('test') + b = MyProto() + + text_format.Parse(repr(a), b) + text_format.Parse(repr(a), b) # repeated_field contains ["test", "test"] + + # Binary version: + b.ParseFromString(a.SerializeToString()) # repeated_field is now "test" + + Caller is responsible for clearing the message as needed. + Args: text: Message text representation. message: A protocol buffer message to merge into. allow_unknown_extension: if True, skip over missing extensions and keep parsing allow_field_number: if True, both field number and field name are allowed. + descriptor_pool: A DescriptorPool used to resolve Any types. Returns: The same message passed as argument. @@ -440,9 +488,15 @@ def Parse(text, ParseError: On text parsing problems. """ if not isinstance(text, str): - text = text.decode('utf-8') - return ParseLines( - text.split('\n'), message, allow_unknown_extension, allow_field_number) + if six.PY3: + text = text.decode('utf-8') + else: + text = text.encode('utf-8') + return ParseLines(text.split('\n'), + message, + allow_unknown_extension, + allow_field_number, + descriptor_pool=descriptor_pool) def Merge(text, @@ -469,6 +523,11 @@ def Merge(text, Raises: ParseError: On text parsing problems. 
""" + if not isinstance(text, str): + if six.PY3: + text = text.decode('utf-8') + else: + text = text.encode('utf-8') return MergeLines( text.split('\n'), message, @@ -480,7 +539,8 @@ def Merge(text, def ParseLines(lines, message, allow_unknown_extension=False, - allow_field_number=False): + allow_field_number=False, + descriptor_pool=None): """Parses a text representation of a protocol message into a message. Args: @@ -497,7 +557,9 @@ def ParseLines(lines, Raises: ParseError: On text parsing problems. """ - parser = _Parser(allow_unknown_extension, allow_field_number) + parser = _Parser(allow_unknown_extension, + allow_field_number, + descriptor_pool=descriptor_pool) return parser.ParseLines(lines, message) @@ -514,6 +576,7 @@ def MergeLines(lines, allow_unknown_extension: if True, skip over missing extensions and keep parsing allow_field_number: if True, both field number and field name are allowed. + descriptor_pool: A DescriptorPool used to resolve Any types. Returns: The same message passed as argument. @@ -585,11 +648,30 @@ class _Parser(object): ParseError: In case of text parsing problems. """ message_descriptor = message.DESCRIPTOR - if (hasattr(message_descriptor, 'syntax') and - message_descriptor.syntax == 'proto3'): - # Proto3 doesn't represent presence so we can't test if multiple - # scalars have occurred. We have to allow them. 
- self._allow_multiple_scalars = True + if (message_descriptor.full_name == _ANY_FULL_TYPE_NAME and + tokenizer.TryConsume('[')): + type_url_prefix, packed_type_name = self._ConsumeAnyTypeUrl(tokenizer) + tokenizer.Consume(']') + tokenizer.TryConsume(':') + if tokenizer.TryConsume('<'): + expanded_any_end_token = '>' + else: + tokenizer.Consume('{') + expanded_any_end_token = '}' + expanded_any_sub_message = _BuildMessageFromTypeName(packed_type_name, + self.descriptor_pool) + if not expanded_any_sub_message: + raise ParseError('Type %s not found in descriptor pool' % + packed_type_name) + while not tokenizer.TryConsume(expanded_any_end_token): + if tokenizer.AtEnd(): + raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % + (expanded_any_end_token,)) + self._MergeField(tokenizer, expanded_any_sub_message) + message.Pack(expanded_any_sub_message, + type_url_prefix=type_url_prefix) + return + if tokenizer.TryConsume('['): name = [tokenizer.ConsumeIdentifier()] while tokenizer.TryConsume('.'): @@ -608,7 +690,11 @@ class _Parser(object): field = None else: raise tokenizer.ParseErrorPreviousToken( - 'Extension "%s" not registered.' % name) + 'Extension "%s" not registered. ' + 'Did you import the _pb2 module which defines it? ' + 'If you are trying to place the extension in the MessageSet ' + 'field of another message that is in an Any or MessageSet field, ' + 'that message\'s _pb2 module must be imported as well' % name) elif message_descriptor != field.containing_type: raise tokenizer.ParseErrorPreviousToken( 'Extension "%s" does not extend message type "%s".' % @@ -666,11 +752,12 @@ class _Parser(object): if (field.label == descriptor.FieldDescriptor.LABEL_REPEATED and tokenizer.TryConsume('[')): # Short repeated format, e.g. 
"foo: [1, 2, 3]" - while True: - merger(tokenizer, message, field) - if tokenizer.TryConsume(']'): - break - tokenizer.Consume(',') + if not tokenizer.TryConsume(']'): + while True: + merger(tokenizer, message, field) + if tokenizer.TryConsume(']'): + break + tokenizer.Consume(',') else: merger(tokenizer, message, field) @@ -687,17 +774,17 @@ class _Parser(object): def _ConsumeAnyTypeUrl(self, tokenizer): """Consumes a google.protobuf.Any type URL and returns the type name.""" # Consume "type.googleapis.com/". - tokenizer.ConsumeIdentifier() + prefix = [tokenizer.ConsumeIdentifier()] tokenizer.Consume('.') - tokenizer.ConsumeIdentifier() + prefix.append(tokenizer.ConsumeIdentifier()) tokenizer.Consume('.') - tokenizer.ConsumeIdentifier() + prefix.append(tokenizer.ConsumeIdentifier()) tokenizer.Consume('/') # Consume the fully-qualified type name. name = [tokenizer.ConsumeIdentifier()] while tokenizer.TryConsume('.'): name.append(tokenizer.ConsumeIdentifier()) - return '.'.join(name) + return '.'.join(prefix), '.'.join(name) def _MergeMessageField(self, tokenizer, message, field): """Merges a single scalar field into a message. 
@@ -718,45 +805,29 @@ class _Parser(object): tokenizer.Consume('{') end_token = '}' - if (field.message_type.full_name == _ANY_FULL_TYPE_NAME and - tokenizer.TryConsume('[')): - packed_type_name = self._ConsumeAnyTypeUrl(tokenizer) - tokenizer.Consume(']') - tokenizer.TryConsume(':') - if tokenizer.TryConsume('<'): - expanded_any_end_token = '>' - else: - tokenizer.Consume('{') - expanded_any_end_token = '}' - if not self.descriptor_pool: - raise ParseError('Descriptor pool required to parse expanded Any field') - expanded_any_sub_message = _BuildMessageFromTypeName(packed_type_name, - self.descriptor_pool) - if not expanded_any_sub_message: - raise ParseError('Type %s not found in descriptor pool' % - packed_type_name) - while not tokenizer.TryConsume(expanded_any_end_token): - if tokenizer.AtEnd(): - raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % - (expanded_any_end_token,)) - self._MergeField(tokenizer, expanded_any_sub_message) - if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: - any_message = getattr(message, field.name).add() - else: - any_message = getattr(message, field.name) - any_message.Pack(expanded_any_sub_message) - elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: if field.is_extension: sub_message = message.Extensions[field].add() elif is_map_entry: - # pylint: disable=protected-access - sub_message = field.message_type._concrete_class() + sub_message = getattr(message, field.name).GetEntryClass()() else: sub_message = getattr(message, field.name).add() else: if field.is_extension: + if (not self._allow_multiple_scalars and + message.HasExtension(field)): + raise tokenizer.ParseErrorPreviousToken( + 'Message type "%s" should not have multiple "%s" extensions.' % + (message.DESCRIPTOR.full_name, field.full_name)) sub_message = message.Extensions[field] else: + # Also apply _allow_multiple_scalars to message field. 
+ # TODO(jieluo): Change to _allow_singular_overwrites. + if (not self._allow_multiple_scalars and + message.HasField(field.name)): + raise tokenizer.ParseErrorPreviousToken( + 'Message type "%s" should not have multiple "%s" fields.' % + (message.DESCRIPTOR.full_name, field.name)) sub_message = getattr(message, field.name) sub_message.SetInParent() @@ -773,6 +844,12 @@ class _Parser(object): else: getattr(message, field.name)[sub_message.key] = sub_message.value + @staticmethod + def _IsProto3Syntax(message): + message_descriptor = message.DESCRIPTOR + return (hasattr(message_descriptor, 'syntax') and + message_descriptor.syntax == 'proto3') + def _MergeScalarField(self, tokenizer, message, field): """Merges a single scalar field into a message. @@ -822,15 +899,20 @@ class _Parser(object): else: getattr(message, field.name).append(value) else: + # Proto3 doesn't represent presence so we can't test if multiple scalars + # have occurred. We have to allow them. + can_check_presence = not self._IsProto3Syntax(message) if field.is_extension: - if not self._allow_multiple_scalars and message.HasExtension(field): + if (not self._allow_multiple_scalars and can_check_presence and + message.HasExtension(field)): raise tokenizer.ParseErrorPreviousToken( 'Message type "%s" should not have multiple "%s" extensions.' % (message.DESCRIPTOR.full_name, field.full_name)) else: message.Extensions[field] = value else: - if not self._allow_multiple_scalars and message.HasField(field.name): + if (not self._allow_multiple_scalars and can_check_presence and + message.HasField(field.name)): raise tokenizer.ParseErrorPreviousToken( 'Message type "%s" should not have multiple "%s" fields.' 
% (message.DESCRIPTOR.full_name, field.name)) @@ -870,7 +952,7 @@ def _SkipField(tokenizer): tokenizer.ConsumeIdentifier() tokenizer.Consume(']') else: - tokenizer.ConsumeIdentifier() + tokenizer.ConsumeIdentifierOrNumber() _SkipFieldContents(tokenizer) @@ -1025,6 +1107,22 @@ class Tokenizer(object): self.NextToken() return result + def ConsumeCommentOrTrailingComment(self): + """Consumes a comment, returns a 2-tuple (trailing bool, comment str).""" + + # Tokenizer initializes _previous_line and _previous_column to 0. As the + # tokenizer starts, it looks like there is a previous token on the line. + just_started = self._line == 0 and self._column == 0 + + before_parsing = self._previous_line + comment = self.ConsumeComment() + + # A trailing comment is a comment on the same line as the previous token. + trailing = (self._previous_line == before_parsing + and not just_started) + + return trailing, comment + def TryConsumeIdentifier(self): try: self.ConsumeIdentifier() @@ -1065,7 +1163,7 @@ class Tokenizer(object): """ result = self.token if not self._IDENTIFIER_OR_NUMBER.match(result): - raise self.ParseError('Expected identifier or number.') + raise self.ParseError('Expected identifier or number, got %s.' % result) self.NextToken() return result @@ -1448,9 +1546,9 @@ def ParseBool(text): Raises: ValueError: If text is not a valid boolean. """ - if text in ('true', 't', '1'): + if text in ('true', 't', '1', 'True'): return True - elif text in ('false', 'f', '0'): + elif text in ('false', 'f', '0', 'False'): return False else: raise ValueError('Expected "true" or "false".') @@ -1483,6 +1581,11 @@ def ParseEnum(field, value): (enum_descriptor.full_name, value)) else: # Numeric value. + if hasattr(field.file, 'syntax'): + # Attribute is checked for compatibility. + if field.file.syntax == 'proto3': + # Proto3 accepts numeric unknown enums. 
+ return number enum_value = enum_descriptor.values_by_number.get(number, None) if enum_value is None: raise ValueError('Enum type "%s" has no value with number %d.' % |