diff options
author | Feng Xiao <xfxyjwf@gmail.com> | 2018-08-08 17:00:41 -0700 |
---|---|---|
committer | Feng Xiao <xfxyjwf@gmail.com> | 2018-08-08 17:00:41 -0700 |
commit | 6bbe197e9c1b6fc38cbdc45e3bf83fa7ced792a3 (patch) | |
tree | e575738adf52d24b883cca5e8928a5ded31caba1 /python/google/protobuf/text_format.py | |
parent | e7746f487cb9cca685ffb1b3d7dccc5554b618a4 (diff) | |
download | protobuf-6bbe197e9c1b6fc38cbdc45e3bf83fa7ced792a3.tar.gz protobuf-6bbe197e9c1b6fc38cbdc45e3bf83fa7ced792a3.tar.bz2 protobuf-6bbe197e9c1b6fc38cbdc45e3bf83fa7ced792a3.zip |
Down-integrate from google3.
Diffstat (limited to 'python/google/protobuf/text_format.py')
-rwxr-xr-x | python/google/protobuf/text_format.py | 134 |
1 files changed, 88 insertions, 46 deletions
diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py index 2cbd21bc..5dd41830 100755 --- a/python/google/protobuf/text_format.py +++ b/python/google/protobuf/text_format.py @@ -55,15 +55,15 @@ from google.protobuf.internal import type_checkers from google.protobuf import descriptor from google.protobuf import text_encoding -__all__ = ['MessageToString', 'PrintMessage', 'PrintField', 'PrintFieldValue', - 'Merge'] +__all__ = ['MessageToString', 'Parse', 'PrintMessage', 'PrintField', + 'PrintFieldValue', 'Merge', 'MessageToBytes'] _INTEGER_CHECKERS = (type_checkers.Uint32ValueChecker(), type_checkers.Int32ValueChecker(), type_checkers.Uint64ValueChecker(), type_checkers.Int64ValueChecker()) -_FLOAT_INFINITY = re.compile('-?inf(?:inity)?f?', re.IGNORECASE) -_FLOAT_NAN = re.compile('nanf?', re.IGNORECASE) +_FLOAT_INFINITY = re.compile('-?inf(?:inity)?f?$', re.IGNORECASE) +_FLOAT_NAN = re.compile('nanf?$', re.IGNORECASE) _FLOAT_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_FLOAT, descriptor.FieldDescriptor.CPPTYPE_DOUBLE]) _QUOTES = frozenset(("'", '"')) @@ -121,6 +121,7 @@ class TextWriter(object): def MessageToString(message, as_utf8=False, as_one_line=False, + use_short_repeated_primitives=False, pointy_brackets=False, use_index_order=False, float_format=None, @@ -128,6 +129,7 @@ def MessageToString(message, descriptor_pool=None, indent=0, message_formatter=None): + # type: (...) -> str """Convert protobuf message to text format. Floating point values can be formatted compactly with 15 digits of @@ -137,8 +139,11 @@ def MessageToString(message, Args: message: The protocol buffers message. - as_utf8: Produce text output in UTF8 format. + as_utf8: Return unescaped Unicode for non-ASCII characters. + In Python 3 actual Unicode characters may appear as is in strings. + In Python 2 the return value will be valid UTF-8 rather than only ASCII. as_one_line: Don't introduce newlines between fields. + use_short_repeated_primitives: Use short repeated format for primitives. pointy_brackets: If True, use angle brackets instead of curly braces for nesting. use_index_order: If True, fields of a proto message will be printed using @@ -159,7 +164,8 @@ def MessageToString(message, A string of the text formatted protocol buffer message. """ out = TextWriter(as_utf8) - printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets, + printer = _Printer(out, indent, as_utf8, as_one_line, + use_short_repeated_primitives, pointy_brackets, use_index_order, float_format, use_field_number, descriptor_pool, message_formatter) printer.PrintMessage(message) @@ -170,6 +176,16 @@ def MessageToString(message, return result +def MessageToBytes(message, **kwargs): + # type: (...) -> bytes + """Convert protobuf message to encoded text format. See MessageToString.""" + text = MessageToString(message, **kwargs) + if isinstance(text, bytes): + return text + codec = 'utf-8' if kwargs.get('as_utf8') else 'ascii' + return text.encode(codec) + + def _IsMapEntry(field): return (field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and field.message_type.has_options and @@ -181,13 +197,15 @@ def PrintMessage(message, indent=0, as_utf8=False, as_one_line=False, + use_short_repeated_primitives=False, pointy_brackets=False, use_index_order=False, float_format=None, use_field_number=False, descriptor_pool=None, message_formatter=None): - printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets, + printer = _Printer(out, indent, as_utf8, as_one_line, + use_short_repeated_primitives, pointy_brackets, use_index_order, float_format, use_field_number, descriptor_pool, message_formatter) printer.PrintMessage(message) @@ -199,12 +217,14 @@ def PrintField(field, indent=0, as_utf8=False, as_one_line=False, + use_short_repeated_primitives=False, pointy_brackets=False, use_index_order=False, float_format=None, message_formatter=None): """Print a single field name/value pair.""" - printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets, + printer = _Printer(out, indent, as_utf8, as_one_line, + use_short_repeated_primitives, pointy_brackets, use_index_order, float_format, message_formatter) printer.PrintField(field, value) @@ -215,12 +235,14 @@ def PrintFieldValue(field, indent=0, as_utf8=False, as_one_line=False, + use_short_repeated_primitives=False, pointy_brackets=False, use_index_order=False, float_format=None, message_formatter=None): """Print a single field value (not including name).""" - printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets, + printer = _Printer(out, indent, as_utf8, as_one_line, + use_short_repeated_primitives, pointy_brackets, use_index_order, float_format, message_formatter) printer.PrintFieldValue(field, value) @@ -258,6 +280,7 @@ class _Printer(object): indent=0, as_utf8=False, as_one_line=False, + use_short_repeated_primitives=False, pointy_brackets=False, use_index_order=False, float_format=None, @@ -274,8 +297,11 @@ class _Printer(object): Args: out: To record the text format result. indent: The indent level for pretty print. - as_utf8: Produce text output in UTF8 format. + as_utf8: Return unescaped Unicode for non-ASCII characters. + In Python 3 actual Unicode characters may appear as is in strings. + In Python 2 the return value will be valid UTF-8 rather than ASCII. as_one_line: Don't introduce newlines between fields. + use_short_repeated_primitives: Use short repeated format for primitives. pointy_brackets: If True, use angle brackets instead of curly braces for nesting. use_index_order: If True, print fields of a proto message using the order @@ -294,6 +320,7 @@ class _Printer(object): self.indent = indent self.as_utf8 = as_utf8 self.as_one_line = as_one_line + self.use_short_repeated_primitives = use_short_repeated_primitives self.pointy_brackets = pointy_brackets self.use_index_order = use_index_order self.float_format = float_format @@ -351,13 +378,18 @@ class _Printer(object): entry_submsg = value.GetEntryClass()(key=key, value=value[key]) self.PrintField(field, entry_submsg) elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: - for element in value: - self.PrintField(field, element) + if (self.use_short_repeated_primitives + and field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_MESSAGE + and field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_STRING): + self._PrintShortRepeatedPrimitivesValue(field, value) + else: + for element in value: + self.PrintField(field, element) else: self.PrintField(field, value) - def PrintField(self, field, value): - """Print a single field name/value pair.""" + def _PrintFieldName(self, field): + """Print field name.""" out = self.out out.write(' ' * self.indent) if self.use_field_number: @@ -383,11 +415,22 @@ class _Printer(object): # don't include it. out.write(': ') + def PrintField(self, field, value): + """Print a single field name/value pair.""" + self._PrintFieldName(field) self.PrintFieldValue(field, value) - if self.as_one_line: - out.write(' ') - else: - out.write('\n') + self.out.write(' ' if self.as_one_line else '\n') + + def _PrintShortRepeatedPrimitivesValue(self, field, value): + # Note: this is called only when value has at least one element. + self._PrintFieldName(field) + self.out.write('[') + for i in xrange(len(value) - 1): + self.PrintFieldValue(field, value[i]) + self.out.write(', ') + self.PrintFieldValue(field, value[-1]) + self.out.write(']') + self.out.write(' ' if self.as_one_line else '\n') def _PrintMessageFieldValue(self, value): if self.pointy_brackets: @@ -428,12 +471,12 @@ class _Printer(object): out.write(str(value)) elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING: out.write('\"') - if isinstance(value, six.text_type): + if isinstance(value, six.text_type) and (six.PY2 or not self.as_utf8): out_value = value.encode('utf-8') else: out_value = value if field.type == descriptor.FieldDescriptor.TYPE_BYTES: - # We need to escape non-UTF8 chars in TYPE_BYTES field. + # We always need to escape all binary data in TYPE_BYTES fields. out_as_utf8 = False else: out_as_utf8 = self.as_utf8 @@ -487,12 +530,7 @@ def Parse(text, Raises: ParseError: On text parsing problems. """ - if not isinstance(text, str): - if six.PY3: - text = text.decode('utf-8') - else: - text = text.encode('utf-8') - return ParseLines(text.split('\n'), + return ParseLines(text.split(b'\n' if isinstance(text, bytes) else u'\n'), message, allow_unknown_extension, allow_field_number, @@ -523,13 +561,8 @@ def Merge(text, Raises: ParseError: On text parsing problems. """ - if not isinstance(text, str): - if six.PY3: - text = text.decode('utf-8') - else: - text = text.encode('utf-8') return MergeLines( - text.split('\n'), + text.split(b'\n' if isinstance(text, bytes) else u'\n'), message, allow_unknown_extension, allow_field_number, @@ -570,6 +603,9 @@ def MergeLines(lines, descriptor_pool=None): """Parses a text representation of a protocol message into a message. + Like ParseLines(), but allows repeated values for a non-repeated field, and + uses the last one. + Args: lines: An iterable of lines of a message's text representation. message: A protocol buffer message to merge into. @@ -601,22 +637,12 @@ class _Parser(object): self.allow_field_number = allow_field_number self.descriptor_pool = descriptor_pool - def ParseFromString(self, text, message): - """Parses a text representation of a protocol message into a message.""" - if not isinstance(text, str): - text = text.decode('utf-8') - return self.ParseLines(text.split('\n'), message) - def ParseLines(self, lines, message): """Parses a text representation of a protocol message into a message.""" self._allow_multiple_scalars = False self._ParseOrMerge(lines, message) return message - def MergeFromString(self, text, message): - """Merges a text representation of a protocol message into a message.""" - return self._MergeLines(text.split('\n'), message) - def MergeLines(self, lines, message): """Merges a text representation of a protocol message into a message.""" self._allow_multiple_scalars = True @@ -633,7 +659,14 @@ class _Parser(object): Raises: ParseError: On text parsing problems. """ - tokenizer = Tokenizer(lines) + # Tokenize expects native str lines. + if six.PY2: + str_lines = (line if isinstance(line, str) else line.encode('utf-8') + for line in lines) + else: + str_lines = (line if isinstance(line, str) else line.decode('utf-8') + for line in lines) + tokenizer = Tokenizer(str_lines) while not tokenizer.AtEnd(): self._MergeField(tokenizer, message) @@ -1019,7 +1052,9 @@ class Tokenizer(object): r'[a-zA-Z_][0-9a-zA-Z_+-]*', # an identifier r'([0-9+-]|(\.[0-9]))[0-9a-zA-Z_.+-]*', # a number ] + [ # quoted str for each quote mark - r'{qt}([^{qt}\n\\]|\\.)*({qt}|\\?$)'.format(qt=mark) for mark in _QUOTES + # Avoid backtracking! https://stackoverflow.com/a/844267 + r'{qt}[^{qt}\n\\]*((\\.)+[^{qt}\n\\]*)*({qt}|\\?$)'.format(qt=mark) + for mark in _QUOTES ])) _IDENTIFIER = re.compile(r'[^\d\W]\w*') @@ -1316,7 +1351,8 @@ class Tokenizer(object): def ParseError(self, message): """Creates and *returns* a ParseError for the current token.""" - return ParseError(message, self._line + 1, self._column + 1) + return ParseError('\'' + self._current_line + '\': ' + message, + self._line + 1, self._column + 1) def _StringParseError(self, e): return self.ParseError('Couldn\'t parse string: ' + str(e)) @@ -1490,6 +1526,12 @@ def _ParseAbstractInteger(text, is_long=False): ValueError: Thrown Iff the text is not a valid integer. """ # Do the actual parsing. Exception handling is propagated to caller. + orig_text = text + c_octal_match = re.match(r'(-?)0(\d+)$', text) + if c_octal_match: + # Python 3 no longer supports 0755 octal syntax without the 'o', so + # we always use the '0o' prefix for multi-digit numbers starting with 0. + text = c_octal_match.group(1) + '0o' + c_octal_match.group(2) try: # We force 32-bit values to int and 64-bit values to long to make # alternate implementations where the distinction is more significant @@ -1499,7 +1541,7 @@ def _ParseAbstractInteger(text, is_long=False): else: return int(text, 0) except ValueError: - raise ValueError('Couldn\'t parse integer: %s' % text) + raise ValueError('Couldn\'t parse integer: %s' % orig_text) def ParseFloat(text): |