diff options
Diffstat (limited to 'python/google/protobuf/text_format.py')
-rwxr-xr-x | python/google/protobuf/text_format.py | 149 |
1 files changed, 97 insertions, 52 deletions
diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py index 2cbd21bc..998cd681 100755 --- a/python/google/protobuf/text_format.py +++ b/python/google/protobuf/text_format.py @@ -55,15 +55,15 @@ from google.protobuf.internal import type_checkers from google.protobuf import descriptor from google.protobuf import text_encoding -__all__ = ['MessageToString', 'PrintMessage', 'PrintField', 'PrintFieldValue', - 'Merge'] +__all__ = ['MessageToString', 'Parse', 'PrintMessage', 'PrintField', + 'PrintFieldValue', 'Merge', 'MessageToBytes'] _INTEGER_CHECKERS = (type_checkers.Uint32ValueChecker(), type_checkers.Int32ValueChecker(), type_checkers.Uint64ValueChecker(), type_checkers.Int64ValueChecker()) -_FLOAT_INFINITY = re.compile('-?inf(?:inity)?f?', re.IGNORECASE) -_FLOAT_NAN = re.compile('nanf?', re.IGNORECASE) +_FLOAT_INFINITY = re.compile('-?inf(?:inity)?f?$', re.IGNORECASE) +_FLOAT_NAN = re.compile('nanf?$', re.IGNORECASE) _FLOAT_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_FLOAT, descriptor.FieldDescriptor.CPPTYPE_DOUBLE]) _QUOTES = frozenset(("'", '"')) @@ -121,6 +121,7 @@ class TextWriter(object): def MessageToString(message, as_utf8=False, as_one_line=False, + use_short_repeated_primitives=False, pointy_brackets=False, use_index_order=False, float_format=None, @@ -128,6 +129,7 @@ def MessageToString(message, descriptor_pool=None, indent=0, message_formatter=None): + # type: (...) -> str """Convert protobuf message to text format. Floating point values can be formatted compactly with 15 digits of @@ -137,8 +139,11 @@ def MessageToString(message, Args: message: The protocol buffers message. - as_utf8: Produce text output in UTF8 format. + as_utf8: Return unescaped Unicode for non-ASCII characters. + In Python 3 actual Unicode characters may appear as is in strings. + In Python 2 the return value will be valid UTF-8 rather than only ASCII. as_one_line: Don't introduce newlines between fields. + use_short_repeated_primitives: Use short repeated format for primitives. pointy_brackets: If True, use angle brackets instead of curly braces for nesting. use_index_order: If True, fields of a proto message will be printed using @@ -150,7 +155,7 @@ def MessageToString(message, (per the "Format Specification Mini-Language"); otherwise, str() is used. use_field_number: If True, print field numbers instead of names. descriptor_pool: A DescriptorPool used to resolve Any types. - indent: The indent level, in terms of spaces, for pretty print. + indent: The initial indent level, in terms of spaces, for pretty print. message_formatter: A function(message, indent, as_one_line): unicode|None to custom format selected sub-messages (usually based on message type). Use to pretty print parts of the protobuf for easier diffing. @@ -159,7 +164,8 @@ def MessageToString(message, A string of the text formatted protocol buffer message. """ out = TextWriter(as_utf8) - printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets, + printer = _Printer(out, indent, as_utf8, as_one_line, + use_short_repeated_primitives, pointy_brackets, use_index_order, float_format, use_field_number, descriptor_pool, message_formatter) printer.PrintMessage(message) @@ -170,6 +176,16 @@ def MessageToString(message, return result +def MessageToBytes(message, **kwargs): + # type: (...) -> bytes + """Convert protobuf message to encoded text format. See MessageToString.""" + text = MessageToString(message, **kwargs) + if isinstance(text, bytes): + return text + codec = 'utf-8' if kwargs.get('as_utf8') else 'ascii' + return text.encode(codec) + + def _IsMapEntry(field): return (field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and field.message_type.has_options and @@ -181,13 +197,15 @@ def PrintMessage(message, indent=0, as_utf8=False, as_one_line=False, + use_short_repeated_primitives=False, pointy_brackets=False, use_index_order=False, float_format=None, use_field_number=False, descriptor_pool=None, message_formatter=None): - printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets, + printer = _Printer(out, indent, as_utf8, as_one_line, + use_short_repeated_primitives, pointy_brackets, use_index_order, float_format, use_field_number, descriptor_pool, message_formatter) printer.PrintMessage(message) @@ -199,12 +217,14 @@ def PrintField(field, indent=0, as_utf8=False, as_one_line=False, + use_short_repeated_primitives=False, pointy_brackets=False, use_index_order=False, float_format=None, message_formatter=None): """Print a single field name/value pair.""" - printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets, + printer = _Printer(out, indent, as_utf8, as_one_line, + use_short_repeated_primitives, pointy_brackets, use_index_order, float_format, message_formatter) printer.PrintField(field, value) @@ -215,12 +235,14 @@ def PrintFieldValue(field, indent=0, as_utf8=False, as_one_line=False, + use_short_repeated_primitives=False, pointy_brackets=False, use_index_order=False, float_format=None, message_formatter=None): """Print a single field value (not including name).""" - printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets, + printer = _Printer(out, indent, as_utf8, as_one_line, + use_short_repeated_primitives, pointy_brackets, use_index_order, float_format, message_formatter) printer.PrintFieldValue(field, value) @@ -258,6 +280,7 @@ class _Printer(object): indent=0, as_utf8=False, as_one_line=False, + use_short_repeated_primitives=False, pointy_brackets=False, use_index_order=False, float_format=None, @@ -273,9 +296,12 @@ class _Printer(object): Args: out: To record the text format result. - indent: The indent level for pretty print. - as_utf8: Produce text output in UTF8 format. + indent: The initial indent level for pretty print. + as_utf8: Return unescaped Unicode for non-ASCII characters. + In Python 3 actual Unicode characters may appear as is in strings. + In Python 2 the return value will be valid UTF-8 rather than ASCII. as_one_line: Don't introduce newlines between fields. + use_short_repeated_primitives: Use short repeated format for primitives. pointy_brackets: If True, use angle brackets instead of curly braces for nesting. use_index_order: If True, print fields of a proto message using the order @@ -294,6 +320,7 @@ class _Printer(object): self.indent = indent self.as_utf8 = as_utf8 self.as_one_line = as_one_line + self.use_short_repeated_primitives = use_short_repeated_primitives self.pointy_brackets = pointy_brackets self.use_index_order = use_index_order self.float_format = float_format @@ -303,11 +330,13 @@ class _Printer(object): def _TryPrintAsAnyMessage(self, message): """Serializes if message is a google.protobuf.Any field.""" + if '/' not in message.type_url: + return False packed_message = _BuildMessageFromTypeName(message.TypeName(), self.descriptor_pool) if packed_message: packed_message.MergeFromString(message.value) - self.out.write('%s[%s]' % (self.indent * ' ', message.type_url)) + self.out.write('%s[%s] ' % (self.indent * ' ', message.type_url)) self._PrintMessageFieldValue(packed_message) self.out.write(' ' if self.as_one_line else '\n') return True @@ -351,13 +380,18 @@ class _Printer(object): entry_submsg = value.GetEntryClass()(key=key, value=value[key]) self.PrintField(field, entry_submsg) elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: - for element in value: - self.PrintField(field, element) + if (self.use_short_repeated_primitives + and field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_MESSAGE + and field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_STRING): + self._PrintShortRepeatedPrimitivesValue(field, value) + else: + for element in value: + self.PrintField(field, element) else: self.PrintField(field, value) - def PrintField(self, field, value): - """Print a single field name/value pair.""" + def _PrintFieldName(self, field): + """Print field name.""" out = self.out out.write(' ' * self.indent) if self.use_field_number: @@ -381,13 +415,25 @@ class _Printer(object): if field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_MESSAGE: # The colon is optional in this case, but our cross-language golden files # don't include it. - out.write(': ') + out.write(':') + def PrintField(self, field, value): + """Print a single field name/value pair.""" + self._PrintFieldName(field) + self.out.write(' ') self.PrintFieldValue(field, value) - if self.as_one_line: - out.write(' ') - else: - out.write('\n') + self.out.write(' ' if self.as_one_line else '\n') + + def _PrintShortRepeatedPrimitivesValue(self, field, value): + # Note: this is called only when value has at least one element. + self._PrintFieldName(field) + self.out.write('[') + for i in xrange(len(value) - 1): + self.PrintFieldValue(field, value[i]) + self.out.write(', ') + self.PrintFieldValue(field, value[-1]) + self.out.write(']') + self.out.write(' ' if self.as_one_line else '\n') def _PrintMessageFieldValue(self, value): if self.pointy_brackets: @@ -398,11 +444,11 @@ class _Printer(object): closeb = '}' if self.as_one_line: - self.out.write(' %s ' % openb) + self.out.write('%s ' % openb) self.PrintMessage(value) self.out.write(closeb) else: - self.out.write(' %s\n' % openb) + self.out.write('%s\n' % openb) self.indent += 2 self.PrintMessage(value) self.indent -= 2 @@ -428,12 +474,12 @@ class _Printer(object): out.write(str(value)) elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING: out.write('\"') - if isinstance(value, six.text_type): + if isinstance(value, six.text_type) and (six.PY2 or not self.as_utf8): out_value = value.encode('utf-8') else: out_value = value if field.type == descriptor.FieldDescriptor.TYPE_BYTES: - # We need to escape non-UTF8 chars in TYPE_BYTES field. + # We always need to escape all binary data in TYPE_BYTES fields. out_as_utf8 = False else: out_as_utf8 = self.as_utf8 @@ -487,12 +533,7 @@ def Parse(text, Raises: ParseError: On text parsing problems. """ - if not isinstance(text, str): - if six.PY3: - text = text.decode('utf-8') - else: - text = text.encode('utf-8') - return ParseLines(text.split('\n'), + return ParseLines(text.split(b'\n' if isinstance(text, bytes) else u'\n'), message, allow_unknown_extension, allow_field_number, @@ -523,13 +564,8 @@ def Merge(text, Raises: ParseError: On text parsing problems. """ - if not isinstance(text, str): - if six.PY3: - text = text.decode('utf-8') - else: - text = text.encode('utf-8') return MergeLines( - text.split('\n'), + text.split(b'\n' if isinstance(text, bytes) else u'\n'), message, allow_unknown_extension, allow_field_number, @@ -570,6 +606,9 @@ def MergeLines(lines, descriptor_pool=None): """Parses a text representation of a protocol message into a message. + Like ParseLines(), but allows repeated values for a non-repeated field, and + uses the last one. + Args: lines: An iterable of lines of a message's text representation. message: A protocol buffer message to merge into. @@ -601,22 +640,12 @@ class _Parser(object): self.allow_field_number = allow_field_number self.descriptor_pool = descriptor_pool - def ParseFromString(self, text, message): - """Parses a text representation of a protocol message into a message.""" - if not isinstance(text, str): - text = text.decode('utf-8') - return self.ParseLines(text.split('\n'), message) - def ParseLines(self, lines, message): """Parses a text representation of a protocol message into a message.""" self._allow_multiple_scalars = False self._ParseOrMerge(lines, message) return message - def MergeFromString(self, text, message): - """Merges a text representation of a protocol message into a message.""" - return self._MergeLines(text.split('\n'), message) - def MergeLines(self, lines, message): """Merges a text representation of a protocol message into a message.""" self._allow_multiple_scalars = True @@ -633,7 +662,14 @@ class _Parser(object): Raises: ParseError: On text parsing problems. """ - tokenizer = Tokenizer(lines) + # Tokenize expects native str lines. + if six.PY2: + str_lines = (line if isinstance(line, str) else line.encode('utf-8') + for line in lines) + else: + str_lines = (line if isinstance(line, str) else line.decode('utf-8') + for line in lines) + tokenizer = Tokenizer(str_lines) while not tokenizer.AtEnd(): self._MergeField(tokenizer, message) @@ -1019,7 +1055,9 @@ class Tokenizer(object): r'[a-zA-Z_][0-9a-zA-Z_+-]*', # an identifier r'([0-9+-]|(\.[0-9]))[0-9a-zA-Z_.+-]*', # a number ] + [ # quoted str for each quote mark - r'{qt}([^{qt}\n\\]|\\.)*({qt}|\\?$)'.format(qt=mark) for mark in _QUOTES + # Avoid backtracking! https://stackoverflow.com/a/844267 + r'{qt}[^{qt}\n\\]*((\\.)+[^{qt}\n\\]*)*({qt}|\\?$)'.format(qt=mark) + for mark in _QUOTES ])) _IDENTIFIER = re.compile(r'[^\d\W]\w*') @@ -1316,7 +1354,8 @@ class Tokenizer(object): def ParseError(self, message): """Creates and *returns* a ParseError for the current token.""" - return ParseError(message, self._line + 1, self._column + 1) + return ParseError('\'' + self._current_line + '\': ' + message, + self._line + 1, self._column + 1) def _StringParseError(self, e): return self.ParseError('Couldn\'t parse string: ' + str(e)) @@ -1490,6 +1529,12 @@ def _ParseAbstractInteger(text, is_long=False): ValueError: Thrown Iff the text is not a valid integer. """ # Do the actual parsing. Exception handling is propagated to caller. + orig_text = text + c_octal_match = re.match(r'(-?)0(\d+)$', text) + if c_octal_match: + # Python 3 no longer supports 0755 octal syntax without the 'o', so + # we always use the '0o' prefix for multi-digit numbers starting with 0. + text = c_octal_match.group(1) + '0o' + c_octal_match.group(2) try: # We force 32-bit values to int and 64-bit values to long to make # alternate implementations where the distinction is more significant @@ -1499,7 +1544,7 @@ def _ParseAbstractInteger(text, is_long=False): else: return int(text, 0) except ValueError: - raise ValueError('Couldn\'t parse integer: %s' % text) + raise ValueError('Couldn\'t parse integer: %s' % orig_text) def ParseFloat(text): |